8201194: Handle local variable declarations in lambda deduplication

Reviewed-by: vromero
This commit is contained in:
Liam Miller-Cushon 2018-04-05 14:39:04 -07:00
parent f9e5a41e1a
commit 5acbe5ff92
5 changed files with 95 additions and 89 deletions

View file

@ -183,15 +183,7 @@ public class LambdaToMethod extends TreeTranslator {
public int hashCode() { public int hashCode() {
int hashCode = this.hashCode; int hashCode = this.hashCode;
if (hashCode == 0) { if (hashCode == 0) {
this.hashCode = hashCode = TreeHasher.hash(tree, sym -> { this.hashCode = hashCode = TreeHasher.hash(tree, symbol.params());
if (sym.owner == symbol) {
int idx = symbol.params().indexOf(sym);
if (idx != -1) {
return idx;
}
}
return null;
});
} }
return hashCode; return hashCode;
} }
@ -203,17 +195,7 @@ public class LambdaToMethod extends TreeTranslator {
} }
DedupedLambda that = (DedupedLambda) o; DedupedLambda that = (DedupedLambda) o;
return types.isSameType(symbol.asType(), that.symbol.asType()) return types.isSameType(symbol.asType(), that.symbol.asType())
&& new TreeDiffer((lhs, rhs) -> { && new TreeDiffer(symbol.params(), that.symbol.params()).scan(tree, that.tree);
if (lhs.owner == symbol) {
int idx = symbol.params().indexOf(lhs);
if (idx != -1) {
if (Objects.equals(idx, that.symbol.params().indexOf(rhs))) {
return true;
}
}
}
return null;
}).scan(tree, that.tree);
} }
} }

View file

@ -89,24 +89,34 @@ import com.sun.tools.javac.tree.JCTree.TypeBoundKind;
import com.sun.tools.javac.tree.TreeInfo; import com.sun.tools.javac.tree.TreeInfo;
import com.sun.tools.javac.tree.TreeScanner; import com.sun.tools.javac.tree.TreeScanner;
import com.sun.tools.javac.util.List; import com.sun.tools.javac.util.List;
import java.util.Collection;
import javax.lang.model.element.ElementKind; import java.util.HashMap;
import java.util.Iterator; import java.util.Iterator;
import java.util.Map;
import java.util.Objects; import java.util.Objects;
import java.util.function.BiFunction;
import java.util.function.Consumer;
/** A visitor that compares two lambda bodies for structural equality. */ /** A visitor that compares two lambda bodies for structural equality. */
public class TreeDiffer extends TreeScanner { public class TreeDiffer extends TreeScanner {
private BiFunction<Symbol, Symbol, Boolean> symbolDiffer; public TreeDiffer(
Collection<? extends Symbol> symbols, Collection<? extends Symbol> otherSymbols) {
this.equiv = equiv(symbols, otherSymbols);
}
public TreeDiffer(BiFunction<Symbol, Symbol, Boolean> symbolDiffer) { private static Map<Symbol, Symbol> equiv(
this.symbolDiffer = Objects.requireNonNull(symbolDiffer); Collection<? extends Symbol> symbols, Collection<? extends Symbol> otherSymbols) {
Map<Symbol, Symbol> result = new HashMap<>();
Iterator<? extends Symbol> it = otherSymbols.iterator();
for (Symbol symbol : symbols) {
if (!it.hasNext()) break;
result.put(symbol, it.next());
}
return result;
} }
private JCTree parameter; private JCTree parameter;
private boolean result; private boolean result;
private Map<Symbol, Symbol> equiv = new HashMap<>();
public boolean scan(JCTree tree, JCTree parameter) { public boolean scan(JCTree tree, JCTree parameter) {
if (tree == null || parameter == null) { if (tree == null || parameter == null) {
@ -172,9 +182,8 @@ public class TreeDiffer extends TreeScanner {
Symbol symbol = tree.sym; Symbol symbol = tree.sym;
Symbol otherSymbol = that.sym; Symbol otherSymbol = that.sym;
if (symbol != null && otherSymbol != null) { if (symbol != null && otherSymbol != null) {
Boolean tmp = symbolDiffer.apply(symbol, otherSymbol); if (Objects.equals(equiv.get(symbol), otherSymbol)) {
if (tmp != null) { result = true;
result = tmp;
return; return;
} }
} }
@ -598,6 +607,10 @@ public class TreeDiffer extends TreeScanner {
&& scan(tree.nameexpr, that.nameexpr) && scan(tree.nameexpr, that.nameexpr)
&& scan(tree.vartype, that.vartype) && scan(tree.vartype, that.vartype)
&& scan(tree.init, that.init); && scan(tree.init, that.init);
if (!result) {
return;
}
equiv.put(tree.sym, that.sym);
} }
@Override @Override

View file

@ -30,26 +30,31 @@ import com.sun.tools.javac.tree.JCTree;
import com.sun.tools.javac.tree.JCTree.JCFieldAccess; import com.sun.tools.javac.tree.JCTree.JCFieldAccess;
import com.sun.tools.javac.tree.JCTree.JCIdent; import com.sun.tools.javac.tree.JCTree.JCIdent;
import com.sun.tools.javac.tree.JCTree.JCLiteral; import com.sun.tools.javac.tree.JCTree.JCLiteral;
import com.sun.tools.javac.tree.JCTree.JCVariableDecl;
import com.sun.tools.javac.tree.TreeInfo; import com.sun.tools.javac.tree.TreeInfo;
import com.sun.tools.javac.tree.TreeScanner; import com.sun.tools.javac.tree.TreeScanner;
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects; import java.util.Objects;
import java.util.function.Function;
/** A tree visitor that computes a hash code. */ /** A tree visitor that computes a hash code. */
public class TreeHasher extends TreeScanner { public class TreeHasher extends TreeScanner {
private final Function<Symbol, Integer> symbolHasher; private final Map<Symbol, Integer> symbolHashes;
private int result = 17; private int result = 17;
public TreeHasher(Function<Symbol, Integer> symbolHasher) { public TreeHasher(Map<Symbol, Integer> symbolHashes) {
this.symbolHasher = Objects.requireNonNull(symbolHasher); this.symbolHashes = Objects.requireNonNull(symbolHashes);
} }
public static int hash(JCTree tree, Function<Symbol, Integer> symbolHasher) { public static int hash(JCTree tree, Collection<? extends Symbol> symbols) {
if (tree == null) { if (tree == null) {
return 0; return 0;
} }
TreeHasher hasher = new TreeHasher(symbolHasher); Map<Symbol, Integer> symbolHashes = new HashMap<>();
symbols.forEach(s -> symbolHashes.put(s, symbolHashes.size()));
TreeHasher hasher = new TreeHasher(symbolHashes);
tree.accept(hasher); tree.accept(hasher);
return hasher.result; return hasher.result;
} }
@ -85,7 +90,7 @@ public class TreeHasher extends TreeScanner {
public void visitIdent(JCIdent tree) { public void visitIdent(JCIdent tree) {
Symbol sym = tree.sym; Symbol sym = tree.sym;
if (sym != null) { if (sym != null) {
Integer hash = symbolHasher.apply(sym); Integer hash = symbolHashes.get(sym);
if (hash != null) { if (hash != null) {
hash(hash); hash(hash);
return; return;
@ -99,4 +104,10 @@ public class TreeHasher extends TreeScanner {
hash(tree.sym); hash(tree.sym);
super.visitSelect(tree); super.visitSelect(tree);
} }
@Override
public void visitVarDef(JCVariableDecl tree) {
symbolHashes.computeIfAbsent(tree.sym, k -> symbolHashes.size());
super.visitVarDef(tree);
}
} }

View file

@ -77,18 +77,45 @@ public class Deduplication {
group((Function<Integer, Integer>) y -> j); group((Function<Integer, Integer>) y -> j);
group( group(
(Function<Integer, Integer>) y -> { (Function<Integer, Integer>)
while (true) { y -> {
break; while (true) {
} break;
return 42; }
}, return 42;
(Function<Integer, Integer>) y -> { },
while (true) { (Function<Integer, Integer>)
break; y -> {
} while (true) {
return 42; break;
}); }
return 42;
});
group(
(Function<Integer, Integer>)
x -> {
int y = x;
return y;
},
(Function<Integer, Integer>)
x -> {
int y = x;
return y;
});
group(
(Function<Integer, Integer>)
x -> {
int y = 0, z = x;
return y;
});
group(
(Function<Integer, Integer>)
x -> {
int y = 0, z = x;
return z;
});
class Local { class Local {
int i; int i;

View file

@ -22,7 +22,7 @@
*/ */
/** /**
* @test 8200301 * @test 8200301 8201194
* @summary deduplicate lambda methods with the same body, target type, and captured state * @summary deduplicate lambda methods with the same body, target type, and captured state
* @modules jdk.jdeps/com.sun.tools.classfile jdk.compiler/com.sun.tools.javac.api * @modules jdk.jdeps/com.sun.tools.classfile jdk.compiler/com.sun.tools.javac.api
* jdk.compiler/com.sun.tools.javac.code jdk.compiler/com.sun.tools.javac.comp * jdk.compiler/com.sun.tools.javac.code jdk.compiler/com.sun.tools.javac.comp
@ -32,6 +32,7 @@
*/ */
import static java.nio.charset.StandardCharsets.UTF_8; import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.stream.Collectors.joining; import static java.util.stream.Collectors.joining;
import static java.util.stream.Collectors.toList;
import static java.util.stream.Collectors.toMap; import static java.util.stream.Collectors.toMap;
import static java.util.stream.Collectors.toSet; import static java.util.stream.Collectors.toSet;
@ -57,7 +58,6 @@ import com.sun.tools.javac.tree.JCTree.JCIdent;
import com.sun.tools.javac.tree.JCTree.JCLambda; import com.sun.tools.javac.tree.JCTree.JCLambda;
import com.sun.tools.javac.tree.JCTree.JCMethodInvocation; import com.sun.tools.javac.tree.JCTree.JCMethodInvocation;
import com.sun.tools.javac.tree.JCTree.JCTypeCast; import com.sun.tools.javac.tree.JCTree.JCTypeCast;
import com.sun.tools.javac.tree.JCTree.JCVariableDecl;
import com.sun.tools.javac.tree.JCTree.Tag; import com.sun.tools.javac.tree.JCTree.Tag;
import com.sun.tools.javac.tree.TreeScanner; import com.sun.tools.javac.tree.TreeScanner;
import com.sun.tools.javac.util.Context; import com.sun.tools.javac.util.Context;
@ -70,10 +70,8 @@ import java.util.LinkedHashMap;
import java.util.List; import java.util.List;
import java.util.Locale; import java.util.Locale;
import java.util.Map; import java.util.Map;
import java.util.Objects;
import java.util.Set; import java.util.Set;
import java.util.TreeSet; import java.util.TreeSet;
import java.util.function.BiFunction;
import javax.tools.Diagnostic; import javax.tools.Diagnostic;
import javax.tools.DiagnosticListener; import javax.tools.DiagnosticListener;
import javax.tools.JavaFileObject; import javax.tools.JavaFileObject;
@ -160,36 +158,9 @@ public class DeduplicationTest {
} }
} }
/** /** Returns the parameter symbols of the given lambda. */
* Returns a symbol comparator that treats symbols that correspond to the same parameter of each private static List<Symbol> paramSymbols(JCLambda lambda) {
* of the given lambdas as equal. return lambda.params.stream().map(x -> x.sym).collect(toList());
*/
private static BiFunction<Symbol, Symbol, Boolean> paramsEqual(JCLambda lhs, JCLambda rhs) {
return (x, y) -> {
Integer idx = paramIndex(lhs, x);
if (idx != null && idx != -1) {
if (Objects.equals(idx, paramIndex(rhs, y))) {
return true;
}
}
return null;
};
}
/**
* Returns the index of the given symbol as a parameter of the given lambda, or else {@code -1}
* if is not a parameter.
*/
private static Integer paramIndex(JCLambda lambda, Symbol sym) {
if (sym != null) {
int idx = 0;
for (JCVariableDecl param : lambda.params) {
if (sym == param.sym) {
return idx;
}
}
}
return null;
} }
/** A diagnostic listener that records debug messages related to lambda desugaring. */ /** A diagnostic listener that records debug messages related to lambda desugaring. */
@ -310,13 +281,14 @@ public class DeduplicationTest {
dedupedLambdas.put(lhs, first); dedupedLambdas.put(lhs, first);
} }
for (JCLambda rhs : curr) { for (JCLambda rhs : curr) {
if (!new TreeDiffer(paramsEqual(lhs, rhs)).scan(lhs.body, rhs.body)) { if (!new TreeDiffer(paramSymbols(lhs), paramSymbols(rhs))
.scan(lhs.body, rhs.body)) {
throw new AssertionError( throw new AssertionError(
String.format( String.format(
"expected lambdas to be equal\n%s\n%s", lhs, rhs)); "expected lambdas to be equal\n%s\n%s", lhs, rhs));
} }
if (TreeHasher.hash(lhs, sym -> paramIndex(lhs, sym)) if (TreeHasher.hash(lhs, paramSymbols(lhs))
!= TreeHasher.hash(rhs, sym -> paramIndex(rhs, sym))) { != TreeHasher.hash(rhs, paramSymbols(rhs))) {
throw new AssertionError( throw new AssertionError(
String.format( String.format(
"expected lambdas to hash to the same value\n%s\n%s", "expected lambdas to hash to the same value\n%s\n%s",
@ -334,14 +306,15 @@ public class DeduplicationTest {
} }
for (JCLambda lhs : curr) { for (JCLambda lhs : curr) {
for (JCLambda rhs : lambdaGroups.get(j)) { for (JCLambda rhs : lambdaGroups.get(j)) {
if (new TreeDiffer(paramsEqual(lhs, rhs)).scan(lhs.body, rhs.body)) { if (new TreeDiffer(paramSymbols(lhs), paramSymbols(rhs))
.scan(lhs.body, rhs.body)) {
throw new AssertionError( throw new AssertionError(
String.format( String.format(
"expected lambdas to not be equal\n%s\n%s", "expected lambdas to not be equal\n%s\n%s",
lhs, rhs)); lhs, rhs));
} }
if (TreeHasher.hash(lhs, sym -> paramIndex(lhs, sym)) if (TreeHasher.hash(lhs, paramSymbols(lhs))
== TreeHasher.hash(rhs, sym -> paramIndex(rhs, sym))) { == TreeHasher.hash(rhs, paramSymbols(rhs))) {
throw new AssertionError( throw new AssertionError(
String.format( String.format(
"expected lambdas to hash to different values\n%s\n%s", "expected lambdas to hash to different values\n%s\n%s",