Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit b8e105f

Browse files
yuanzhanhkuhighker
authored andcommitted
Disable AGG(IF()) rewrite for nondeterministic functions
Not sure if non-deterministic functions could cause issues. But to be safe, disable the rewrite when the IF contains any non-deterministic functions to avoid some corner cases.
1 parent 72644f6 commit b8e105f

File tree

2 files changed

+30
-1
lines changed

2 files changed

+30
-1
lines changed

presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteAggregationIfToFilter.java

+4-1
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import com.facebook.presto.spi.relation.VariableReferenceExpression;
3030
import com.facebook.presto.sql.planner.iterative.Rule;
3131
import com.facebook.presto.sql.relational.Expressions;
32+
import com.facebook.presto.sql.relational.RowExpressionDeterminismEvaluator;
3233
import com.google.common.collect.ImmutableList;
3334
import com.google.common.collect.ImmutableMap;
3435
import com.google.common.collect.ImmutableSortedSet;
@@ -71,10 +72,12 @@ public class RewriteAggregationIfToFilter
7172
.with(source().matching(project().capturedAs(CHILD)));
7273

7374
private final FunctionAndTypeManager functionAndTypeManager;
75+
private final RowExpressionDeterminismEvaluator rowExpressionDeterminismEvaluator;
7476

7577
public RewriteAggregationIfToFilter(FunctionAndTypeManager functionAndTypeManager)
7678
{
7779
this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionManager is null");
80+
rowExpressionDeterminismEvaluator = new RowExpressionDeterminismEvaluator(functionAndTypeManager);
7881
}
7982

8083
@Override
@@ -200,7 +203,7 @@ private boolean shouldRewriteAggregation(Aggregation aggregation, ProjectNode so
200203
return false;
201204
}
202205
RowExpression sourceExpression = sourceProject.getAssignments().get((VariableReferenceExpression) aggregation.getArguments().get(0));
203-
if (!(sourceExpression instanceof SpecialFormExpression)) {
206+
if (!(sourceExpression instanceof SpecialFormExpression) || !rowExpressionDeterminismEvaluator.isDeterministic(sourceExpression)) {
204207
return false;
205208
}
206209
SpecialFormExpression expression = (SpecialFormExpression) sourceExpression;

presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRewriteAggregationIfToFilter.java

+26
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import java.util.Optional;
2828

2929
import static com.facebook.presto.common.type.BigintType.BIGINT;
30+
import static com.facebook.presto.common.type.DoubleType.DOUBLE;
3031
import static com.facebook.presto.common.type.VarcharType.VARCHAR;
3132
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.aggregation;
3233
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression;
@@ -72,6 +73,31 @@ public void testDoesNotFireForIfWithElse()
7273
}).doesNotFire();
7374
}
7475

76+
@Test
77+
public void testDoesNotFireForNonDeterministicFunction()
78+
{
79+
tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
80+
.on(p -> {
81+
VariableReferenceExpression a = p.variable("a", DOUBLE);
82+
VariableReferenceExpression ds = p.variable("ds", VARCHAR);
83+
return p.aggregation(ap -> ap.globalGrouping().step(AggregationNode.Step.FINAL)
84+
.addAggregation(p.variable("expr"), p.rowExpression("sum(a)"))
85+
.source(p.project(
86+
assignment(a, p.rowExpression("IF(ds > '2021-07-01', random())")),
87+
p.values(ds))));
88+
}).doesNotFire();
89+
tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
90+
.on(p -> {
91+
VariableReferenceExpression a = p.variable("a", BIGINT);
92+
VariableReferenceExpression ds = p.variable("ds", VARCHAR);
93+
return p.aggregation(ap -> ap.globalGrouping().step(AggregationNode.Step.FINAL)
94+
.addAggregation(p.variable("expr"), p.rowExpression("sum(a)"))
95+
.source(p.project(
96+
assignment(a, p.rowExpression("IF(random() > DOUBLE '0.1', 1)")),
97+
p.values(ds))));
98+
}).doesNotFire();
99+
}
100+
75101
@Test
76102
public void testFireOneAggregation()
77103
{

0 commit comments

Comments
 (0)