mirror of
https://github.com/openjdk/jdk.git
synced 2025-09-22 03:54:33 +02:00
8201194: Handle local variable declarations in lambda deduplication
Reviewed-by: vromero
This commit is contained in:
parent
f9e5a41e1a
commit
5acbe5ff92
5 changed files with 95 additions and 89 deletions
|
@ -183,15 +183,7 @@ public class LambdaToMethod extends TreeTranslator {
|
||||||
public int hashCode() {
|
public int hashCode() {
|
||||||
int hashCode = this.hashCode;
|
int hashCode = this.hashCode;
|
||||||
if (hashCode == 0) {
|
if (hashCode == 0) {
|
||||||
this.hashCode = hashCode = TreeHasher.hash(tree, sym -> {
|
this.hashCode = hashCode = TreeHasher.hash(tree, symbol.params());
|
||||||
if (sym.owner == symbol) {
|
|
||||||
int idx = symbol.params().indexOf(sym);
|
|
||||||
if (idx != -1) {
|
|
||||||
return idx;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return null;
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
return hashCode;
|
return hashCode;
|
||||||
}
|
}
|
||||||
|
@ -203,17 +195,7 @@ public class LambdaToMethod extends TreeTranslator {
|
||||||
}
|
}
|
||||||
DedupedLambda that = (DedupedLambda) o;
|
DedupedLambda that = (DedupedLambda) o;
|
||||||
return types.isSameType(symbol.asType(), that.symbol.asType())
|
return types.isSameType(symbol.asType(), that.symbol.asType())
|
||||||
&& new TreeDiffer((lhs, rhs) -> {
|
&& new TreeDiffer(symbol.params(), that.symbol.params()).scan(tree, that.tree);
|
||||||
if (lhs.owner == symbol) {
|
|
||||||
int idx = symbol.params().indexOf(lhs);
|
|
||||||
if (idx != -1) {
|
|
||||||
if (Objects.equals(idx, that.symbol.params().indexOf(rhs))) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return null;
|
|
||||||
}).scan(tree, that.tree);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -89,24 +89,34 @@ import com.sun.tools.javac.tree.JCTree.TypeBoundKind;
|
||||||
import com.sun.tools.javac.tree.TreeInfo;
|
import com.sun.tools.javac.tree.TreeInfo;
|
||||||
import com.sun.tools.javac.tree.TreeScanner;
|
import com.sun.tools.javac.tree.TreeScanner;
|
||||||
import com.sun.tools.javac.util.List;
|
import com.sun.tools.javac.util.List;
|
||||||
|
import java.util.Collection;
|
||||||
import javax.lang.model.element.ElementKind;
|
import java.util.HashMap;
|
||||||
import java.util.Iterator;
|
import java.util.Iterator;
|
||||||
|
import java.util.Map;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
import java.util.function.BiFunction;
|
|
||||||
import java.util.function.Consumer;
|
|
||||||
|
|
||||||
/** A visitor that compares two lambda bodies for structural equality. */
|
/** A visitor that compares two lambda bodies for structural equality. */
|
||||||
public class TreeDiffer extends TreeScanner {
|
public class TreeDiffer extends TreeScanner {
|
||||||
|
|
||||||
private BiFunction<Symbol, Symbol, Boolean> symbolDiffer;
|
public TreeDiffer(
|
||||||
|
Collection<? extends Symbol> symbols, Collection<? extends Symbol> otherSymbols) {
|
||||||
|
this.equiv = equiv(symbols, otherSymbols);
|
||||||
|
}
|
||||||
|
|
||||||
public TreeDiffer(BiFunction<Symbol, Symbol, Boolean> symbolDiffer) {
|
private static Map<Symbol, Symbol> equiv(
|
||||||
this.symbolDiffer = Objects.requireNonNull(symbolDiffer);
|
Collection<? extends Symbol> symbols, Collection<? extends Symbol> otherSymbols) {
|
||||||
|
Map<Symbol, Symbol> result = new HashMap<>();
|
||||||
|
Iterator<? extends Symbol> it = otherSymbols.iterator();
|
||||||
|
for (Symbol symbol : symbols) {
|
||||||
|
if (!it.hasNext()) break;
|
||||||
|
result.put(symbol, it.next());
|
||||||
|
}
|
||||||
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
private JCTree parameter;
|
private JCTree parameter;
|
||||||
private boolean result;
|
private boolean result;
|
||||||
|
private Map<Symbol, Symbol> equiv = new HashMap<>();
|
||||||
|
|
||||||
public boolean scan(JCTree tree, JCTree parameter) {
|
public boolean scan(JCTree tree, JCTree parameter) {
|
||||||
if (tree == null || parameter == null) {
|
if (tree == null || parameter == null) {
|
||||||
|
@ -172,9 +182,8 @@ public class TreeDiffer extends TreeScanner {
|
||||||
Symbol symbol = tree.sym;
|
Symbol symbol = tree.sym;
|
||||||
Symbol otherSymbol = that.sym;
|
Symbol otherSymbol = that.sym;
|
||||||
if (symbol != null && otherSymbol != null) {
|
if (symbol != null && otherSymbol != null) {
|
||||||
Boolean tmp = symbolDiffer.apply(symbol, otherSymbol);
|
if (Objects.equals(equiv.get(symbol), otherSymbol)) {
|
||||||
if (tmp != null) {
|
result = true;
|
||||||
result = tmp;
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -598,6 +607,10 @@ public class TreeDiffer extends TreeScanner {
|
||||||
&& scan(tree.nameexpr, that.nameexpr)
|
&& scan(tree.nameexpr, that.nameexpr)
|
||||||
&& scan(tree.vartype, that.vartype)
|
&& scan(tree.vartype, that.vartype)
|
||||||
&& scan(tree.init, that.init);
|
&& scan(tree.init, that.init);
|
||||||
|
if (!result) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
equiv.put(tree.sym, that.sym);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -30,26 +30,31 @@ import com.sun.tools.javac.tree.JCTree;
|
||||||
import com.sun.tools.javac.tree.JCTree.JCFieldAccess;
|
import com.sun.tools.javac.tree.JCTree.JCFieldAccess;
|
||||||
import com.sun.tools.javac.tree.JCTree.JCIdent;
|
import com.sun.tools.javac.tree.JCTree.JCIdent;
|
||||||
import com.sun.tools.javac.tree.JCTree.JCLiteral;
|
import com.sun.tools.javac.tree.JCTree.JCLiteral;
|
||||||
|
import com.sun.tools.javac.tree.JCTree.JCVariableDecl;
|
||||||
import com.sun.tools.javac.tree.TreeInfo;
|
import com.sun.tools.javac.tree.TreeInfo;
|
||||||
import com.sun.tools.javac.tree.TreeScanner;
|
import com.sun.tools.javac.tree.TreeScanner;
|
||||||
|
import java.util.Collection;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.Map;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
import java.util.function.Function;
|
|
||||||
|
|
||||||
/** A tree visitor that computes a hash code. */
|
/** A tree visitor that computes a hash code. */
|
||||||
public class TreeHasher extends TreeScanner {
|
public class TreeHasher extends TreeScanner {
|
||||||
|
|
||||||
private final Function<Symbol, Integer> symbolHasher;
|
private final Map<Symbol, Integer> symbolHashes;
|
||||||
private int result = 17;
|
private int result = 17;
|
||||||
|
|
||||||
public TreeHasher(Function<Symbol, Integer> symbolHasher) {
|
public TreeHasher(Map<Symbol, Integer> symbolHashes) {
|
||||||
this.symbolHasher = Objects.requireNonNull(symbolHasher);
|
this.symbolHashes = Objects.requireNonNull(symbolHashes);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static int hash(JCTree tree, Function<Symbol, Integer> symbolHasher) {
|
public static int hash(JCTree tree, Collection<? extends Symbol> symbols) {
|
||||||
if (tree == null) {
|
if (tree == null) {
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
TreeHasher hasher = new TreeHasher(symbolHasher);
|
Map<Symbol, Integer> symbolHashes = new HashMap<>();
|
||||||
|
symbols.forEach(s -> symbolHashes.put(s, symbolHashes.size()));
|
||||||
|
TreeHasher hasher = new TreeHasher(symbolHashes);
|
||||||
tree.accept(hasher);
|
tree.accept(hasher);
|
||||||
return hasher.result;
|
return hasher.result;
|
||||||
}
|
}
|
||||||
|
@ -85,7 +90,7 @@ public class TreeHasher extends TreeScanner {
|
||||||
public void visitIdent(JCIdent tree) {
|
public void visitIdent(JCIdent tree) {
|
||||||
Symbol sym = tree.sym;
|
Symbol sym = tree.sym;
|
||||||
if (sym != null) {
|
if (sym != null) {
|
||||||
Integer hash = symbolHasher.apply(sym);
|
Integer hash = symbolHashes.get(sym);
|
||||||
if (hash != null) {
|
if (hash != null) {
|
||||||
hash(hash);
|
hash(hash);
|
||||||
return;
|
return;
|
||||||
|
@ -99,4 +104,10 @@ public class TreeHasher extends TreeScanner {
|
||||||
hash(tree.sym);
|
hash(tree.sym);
|
||||||
super.visitSelect(tree);
|
super.visitSelect(tree);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void visitVarDef(JCVariableDecl tree) {
|
||||||
|
symbolHashes.computeIfAbsent(tree.sym, k -> symbolHashes.size());
|
||||||
|
super.visitVarDef(tree);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -77,18 +77,45 @@ public class Deduplication {
|
||||||
group((Function<Integer, Integer>) y -> j);
|
group((Function<Integer, Integer>) y -> j);
|
||||||
|
|
||||||
group(
|
group(
|
||||||
(Function<Integer, Integer>) y -> {
|
(Function<Integer, Integer>)
|
||||||
while (true) {
|
y -> {
|
||||||
break;
|
while (true) {
|
||||||
}
|
break;
|
||||||
return 42;
|
}
|
||||||
},
|
return 42;
|
||||||
(Function<Integer, Integer>) y -> {
|
},
|
||||||
while (true) {
|
(Function<Integer, Integer>)
|
||||||
break;
|
y -> {
|
||||||
}
|
while (true) {
|
||||||
return 42;
|
break;
|
||||||
});
|
}
|
||||||
|
return 42;
|
||||||
|
});
|
||||||
|
|
||||||
|
group(
|
||||||
|
(Function<Integer, Integer>)
|
||||||
|
x -> {
|
||||||
|
int y = x;
|
||||||
|
return y;
|
||||||
|
},
|
||||||
|
(Function<Integer, Integer>)
|
||||||
|
x -> {
|
||||||
|
int y = x;
|
||||||
|
return y;
|
||||||
|
});
|
||||||
|
|
||||||
|
group(
|
||||||
|
(Function<Integer, Integer>)
|
||||||
|
x -> {
|
||||||
|
int y = 0, z = x;
|
||||||
|
return y;
|
||||||
|
});
|
||||||
|
group(
|
||||||
|
(Function<Integer, Integer>)
|
||||||
|
x -> {
|
||||||
|
int y = 0, z = x;
|
||||||
|
return z;
|
||||||
|
});
|
||||||
|
|
||||||
class Local {
|
class Local {
|
||||||
int i;
|
int i;
|
||||||
|
|
|
@ -22,7 +22,7 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @test 8200301
|
* @test 8200301 8201194
|
||||||
* @summary deduplicate lambda methods with the same body, target type, and captured state
|
* @summary deduplicate lambda methods with the same body, target type, and captured state
|
||||||
* @modules jdk.jdeps/com.sun.tools.classfile jdk.compiler/com.sun.tools.javac.api
|
* @modules jdk.jdeps/com.sun.tools.classfile jdk.compiler/com.sun.tools.javac.api
|
||||||
* jdk.compiler/com.sun.tools.javac.code jdk.compiler/com.sun.tools.javac.comp
|
* jdk.compiler/com.sun.tools.javac.code jdk.compiler/com.sun.tools.javac.comp
|
||||||
|
@ -32,6 +32,7 @@
|
||||||
*/
|
*/
|
||||||
import static java.nio.charset.StandardCharsets.UTF_8;
|
import static java.nio.charset.StandardCharsets.UTF_8;
|
||||||
import static java.util.stream.Collectors.joining;
|
import static java.util.stream.Collectors.joining;
|
||||||
|
import static java.util.stream.Collectors.toList;
|
||||||
import static java.util.stream.Collectors.toMap;
|
import static java.util.stream.Collectors.toMap;
|
||||||
import static java.util.stream.Collectors.toSet;
|
import static java.util.stream.Collectors.toSet;
|
||||||
|
|
||||||
|
@ -57,7 +58,6 @@ import com.sun.tools.javac.tree.JCTree.JCIdent;
|
||||||
import com.sun.tools.javac.tree.JCTree.JCLambda;
|
import com.sun.tools.javac.tree.JCTree.JCLambda;
|
||||||
import com.sun.tools.javac.tree.JCTree.JCMethodInvocation;
|
import com.sun.tools.javac.tree.JCTree.JCMethodInvocation;
|
||||||
import com.sun.tools.javac.tree.JCTree.JCTypeCast;
|
import com.sun.tools.javac.tree.JCTree.JCTypeCast;
|
||||||
import com.sun.tools.javac.tree.JCTree.JCVariableDecl;
|
|
||||||
import com.sun.tools.javac.tree.JCTree.Tag;
|
import com.sun.tools.javac.tree.JCTree.Tag;
|
||||||
import com.sun.tools.javac.tree.TreeScanner;
|
import com.sun.tools.javac.tree.TreeScanner;
|
||||||
import com.sun.tools.javac.util.Context;
|
import com.sun.tools.javac.util.Context;
|
||||||
|
@ -70,10 +70,8 @@ import java.util.LinkedHashMap;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Locale;
|
import java.util.Locale;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Objects;
|
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.TreeSet;
|
import java.util.TreeSet;
|
||||||
import java.util.function.BiFunction;
|
|
||||||
import javax.tools.Diagnostic;
|
import javax.tools.Diagnostic;
|
||||||
import javax.tools.DiagnosticListener;
|
import javax.tools.DiagnosticListener;
|
||||||
import javax.tools.JavaFileObject;
|
import javax.tools.JavaFileObject;
|
||||||
|
@ -160,36 +158,9 @@ public class DeduplicationTest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/** Returns the parameter symbols of the given lambda. */
|
||||||
* Returns a symbol comparator that treats symbols that correspond to the same parameter of each
|
private static List<Symbol> paramSymbols(JCLambda lambda) {
|
||||||
* of the given lambdas as equal.
|
return lambda.params.stream().map(x -> x.sym).collect(toList());
|
||||||
*/
|
|
||||||
private static BiFunction<Symbol, Symbol, Boolean> paramsEqual(JCLambda lhs, JCLambda rhs) {
|
|
||||||
return (x, y) -> {
|
|
||||||
Integer idx = paramIndex(lhs, x);
|
|
||||||
if (idx != null && idx != -1) {
|
|
||||||
if (Objects.equals(idx, paramIndex(rhs, y))) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return null;
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Returns the index of the given symbol as a parameter of the given lambda, or else {@code -1}
|
|
||||||
* if is not a parameter.
|
|
||||||
*/
|
|
||||||
private static Integer paramIndex(JCLambda lambda, Symbol sym) {
|
|
||||||
if (sym != null) {
|
|
||||||
int idx = 0;
|
|
||||||
for (JCVariableDecl param : lambda.params) {
|
|
||||||
if (sym == param.sym) {
|
|
||||||
return idx;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return null;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/** A diagnostic listener that records debug messages related to lambda desugaring. */
|
/** A diagnostic listener that records debug messages related to lambda desugaring. */
|
||||||
|
@ -310,13 +281,14 @@ public class DeduplicationTest {
|
||||||
dedupedLambdas.put(lhs, first);
|
dedupedLambdas.put(lhs, first);
|
||||||
}
|
}
|
||||||
for (JCLambda rhs : curr) {
|
for (JCLambda rhs : curr) {
|
||||||
if (!new TreeDiffer(paramsEqual(lhs, rhs)).scan(lhs.body, rhs.body)) {
|
if (!new TreeDiffer(paramSymbols(lhs), paramSymbols(rhs))
|
||||||
|
.scan(lhs.body, rhs.body)) {
|
||||||
throw new AssertionError(
|
throw new AssertionError(
|
||||||
String.format(
|
String.format(
|
||||||
"expected lambdas to be equal\n%s\n%s", lhs, rhs));
|
"expected lambdas to be equal\n%s\n%s", lhs, rhs));
|
||||||
}
|
}
|
||||||
if (TreeHasher.hash(lhs, sym -> paramIndex(lhs, sym))
|
if (TreeHasher.hash(lhs, paramSymbols(lhs))
|
||||||
!= TreeHasher.hash(rhs, sym -> paramIndex(rhs, sym))) {
|
!= TreeHasher.hash(rhs, paramSymbols(rhs))) {
|
||||||
throw new AssertionError(
|
throw new AssertionError(
|
||||||
String.format(
|
String.format(
|
||||||
"expected lambdas to hash to the same value\n%s\n%s",
|
"expected lambdas to hash to the same value\n%s\n%s",
|
||||||
|
@ -334,14 +306,15 @@ public class DeduplicationTest {
|
||||||
}
|
}
|
||||||
for (JCLambda lhs : curr) {
|
for (JCLambda lhs : curr) {
|
||||||
for (JCLambda rhs : lambdaGroups.get(j)) {
|
for (JCLambda rhs : lambdaGroups.get(j)) {
|
||||||
if (new TreeDiffer(paramsEqual(lhs, rhs)).scan(lhs.body, rhs.body)) {
|
if (new TreeDiffer(paramSymbols(lhs), paramSymbols(rhs))
|
||||||
|
.scan(lhs.body, rhs.body)) {
|
||||||
throw new AssertionError(
|
throw new AssertionError(
|
||||||
String.format(
|
String.format(
|
||||||
"expected lambdas to not be equal\n%s\n%s",
|
"expected lambdas to not be equal\n%s\n%s",
|
||||||
lhs, rhs));
|
lhs, rhs));
|
||||||
}
|
}
|
||||||
if (TreeHasher.hash(lhs, sym -> paramIndex(lhs, sym))
|
if (TreeHasher.hash(lhs, paramSymbols(lhs))
|
||||||
== TreeHasher.hash(rhs, sym -> paramIndex(rhs, sym))) {
|
== TreeHasher.hash(rhs, paramSymbols(rhs))) {
|
||||||
throw new AssertionError(
|
throw new AssertionError(
|
||||||
String.format(
|
String.format(
|
||||||
"expected lambdas to hash to different values\n%s\n%s",
|
"expected lambdas to hash to different values\n%s\n%s",
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue