diff --git a/src/include/duckdb/optimizer/filter_pushdown.hpp b/src/include/duckdb/optimizer/filter_pushdown.hpp index 8a2bd63a2216..c2fc87a5249f 100644 --- a/src/include/duckdb/optimizer/filter_pushdown.hpp +++ b/src/include/duckdb/optimizer/filter_pushdown.hpp @@ -74,6 +74,9 @@ class FilterPushdown { unique_ptr PushdownLeftJoin(unique_ptr op, unordered_set &left_bindings, unordered_set &right_bindings); + // Pushdown a full outer join (JoinType::OUTER); only filters on coalesced equi-join keys are pushed down + unique_ptr PushdownOuterJoin(unique_ptr op, unordered_set &left_bindings, + unordered_set &right_bindings); unique_ptr PushdownSemiAntiJoin(unique_ptr op); // Pushdown a mark join unique_ptr PushdownMarkJoin(unique_ptr op, unordered_set &left_bindings, diff --git a/src/include/duckdb/planner/expression_iterator.hpp b/src/include/duckdb/planner/expression_iterator.hpp index e40ab22cc3f3..5b2e2e8a81eb 100644 --- a/src/include/duckdb/planner/expression_iterator.hpp +++ b/src/include/duckdb/planner/expression_iterator.hpp @@ -27,6 +27,8 @@ class ExpressionIterator { static void EnumerateExpression(unique_ptr &expr, const std::function &callback); + static void EnumerateExpression(unique_ptr &expr, + const std::function &child)> &callback); static void VisitExpressionClass(const Expression &expr, ExpressionClass expr_class, const std::function &callback); diff --git a/src/optimizer/filter_pushdown.cpp b/src/optimizer/filter_pushdown.cpp index 461d797675fa..dcb79fe60ade 100644 --- a/src/optimizer/filter_pushdown.cpp +++ b/src/optimizer/filter_pushdown.cpp @@ -160,6 +160,9 @@ unique_ptr FilterPushdown::PushdownJoin(unique_ptr result; switch (join.join_type) { + case JoinType::OUTER: + result = PushdownOuterJoin(std::move(op), left_bindings, right_bindings); + break; case JoinType::INNER: // AsOf joins can't push anything into the RHS, so treat it as a left join if (op->type == LogicalOperatorType::LOGICAL_ASOF_JOIN) { diff --git a/src/optimizer/pushdown/CMakeLists.txt b/src/optimizer/pushdown/CMakeLists.txt index 7efbea01f336..424768fb3c57 
100644
--- a/src/optimizer/pushdown/CMakeLists.txt
+++ b/src/optimizer/pushdown/CMakeLists.txt
@@ -10,6 +10,7 @@ add_library_unity(
   pushdown_limit.cpp
   pushdown_left_join.cpp
   pushdown_mark_join.cpp
+  pushdown_outer_join.cpp
   pushdown_projection.cpp
   pushdown_semi_anti_join.cpp
   pushdown_set_operation.cpp
diff --git a/src/optimizer/pushdown/pushdown_outer_join.cpp b/src/optimizer/pushdown/pushdown_outer_join.cpp
new file mode 100644
index 000000000000..3d81c68c12f3
--- /dev/null
+++ b/src/optimizer/pushdown/pushdown_outer_join.cpp
@@ -0,0 +1,237 @@
+#include "duckdb/optimizer/filter_pushdown.hpp"
+#include "duckdb/parser/expression_map.hpp"
+#include "duckdb/planner/expression/bound_comparison_expression.hpp"
+#include "duckdb/planner/expression/bound_constant_expression.hpp"
+#include "duckdb/planner/expression_iterator.hpp"
+#include "duckdb/planner/operator/logical_comparison_join.hpp"
+#include "duckdb/planner/operator/logical_empty_result.hpp"
+
+namespace duckdb {
+
+using Filter = FilterPushdown::Filter;
+
+//! A representation of a coalesce expression that removes unnecessary usages of
+//! `coalesce` in its arguments when such usages don't affect the final value of
+//! the expression.
+//! E.g. `coalesce(a, coalesce(b, c))` is equivalent to `coalesce(a, b, c)` so
+//! that similarly, with some abuse of notation,
+//! `FlattenedCoalesce::Of({a, coalesce(b, c)}) == FlattenedCoalesce::Of({a, b, c})`
+struct FlattenedCoalesce {
+public:
+	vector<reference_wrapper<Expression>> args;
+
+	bool operator==(const FlattenedCoalesce &other) const {
+		if (args.size() != other.args.size()) {
+			return false;
+		}
+
+		for (idx_t i = 0; i < args.size(); i++) {
+			if (!args[i].get().Equals(other.args[i].get())) {
+				return false;
+			}
+		}
+
+		return true;
+	}
+
+	static FlattenedCoalesce Of(const vector<reference_wrapper<Expression>> &expressions) {
+		vector<reference_wrapper<Expression>> args {};
+		for (auto expr : expressions) {
+			EnumerateFlattenedCoalesceArgs(expr, [&](Expression &arg) { args.push_back(arg); });
+		}
+		return {args};
+	}
+
+private:
+	static void EnumerateFlattenedCoalesceArgs(Expression &expr, const std::function<void(Expression &)> &callback) {
+		if (expr.GetExpressionType() == ExpressionType::OPERATOR_COALESCE) {
+			ExpressionIterator::EnumerateChildren(
+			    expr, [&](Expression &arg) { EnumerateFlattenedCoalesceArgs(arg, callback); });
+		} else {
+			callback(expr);
+		}
+	}
+};
+
+struct FlattenedCoalesceHash {
+	hash_t operator()(const FlattenedCoalesce &coalesce) const {
+		hash_t hash = 0;
+		for (auto arg : coalesce.args) {
+			hash = CombineHash(hash, arg.get().Hash());
+		}
+		return hash;
+	}
+};
+
+//! Replace all occurrences of `exprs_to_replace` in `expr` with `replacement_expr`
+static unique_ptr<Expression> ReplaceIn(unique_ptr<Expression> expr, const expression_set_t &exprs_to_replace,
+                                        const Expression &replacement_expr) {
+	ExpressionIterator::EnumerateExpression(expr, [&](unique_ptr<Expression> &sub_expr) {
+		if (exprs_to_replace.find(*sub_expr) != exprs_to_replace.end()) {
+			sub_expr = replacement_expr.Copy();
+		}
+	});
+
+	return std::move(expr);
+}
+
+//! True if replacing all the `args` expressions occurring in `expr` with a
+//! NULL constant of the same type would make `expr` a scalar value.
+static bool ExprIsFunctionOnlyOf(const Expression &expr, const expression_set_t &args) {
+	auto expr_to_check = expr.Copy();
+
+	ExpressionIterator::EnumerateExpression(expr_to_check, [&](unique_ptr<Expression> &sub_expr) {
+		if (args.find(*sub_expr) != args.end()) {
+			auto null_value = make_uniq<BoundConstantExpression>(Value(sub_expr->return_type));
+			sub_expr = std::move(null_value);
+		}
+	});
+
+	return expr_to_check->IsScalar();
+}
+
+//! Whenever a filter is of the form `P(coalesce(l, r))` or `P(coalesce(r, l))`
+//! where `P` is some predicate that depends only on `coalesce(l, r)` and there
+//! is a join condition of the form `l = r` where `l` and `r` are join keys for
+//! the left and right table respectively, then pushdown `P(l)` to the left
+//! table, `P(r)` to the right table, and remove the original filter.
+//! Returns true if at least one filter was pushed down.
+//! NOTE(review): the rewrite assumes `l = r` implies `P(l) = P(r)`. Key types
+//! whose equality is coarser than identity (possibly INTERVAL, which compares
+//! normalized values, or -0.0 vs 0.0) may violate that - confirm.
+static bool
+PushDownFiltersOnCoalescedEqualJoinKeys(vector<unique_ptr<Filter>> &filters,
+                                        const vector<JoinCondition> &join_conditions,
+                                        const std::function<void(unique_ptr<Expression> filter)> &pushdown_left,
+                                        const std::function<void(unique_ptr<Expression> filter)> &pushdown_right) {
+	// Generate set of all possible coalesced join keys expressions to later
+	// discover filters on such expressions which are candidates for pushdown
+	unordered_map<FlattenedCoalesce, reference_wrapper<const JoinCondition>, FlattenedCoalesceHash>
+	    join_cond_by_coalesced_join_keys;
+
+	for (auto &cond : join_conditions) {
+		if (cond.comparison == ExpressionType::COMPARE_EQUAL) {
+			auto left = std::ref(*cond.left);
+			auto right = std::ref(*cond.right);
+			auto coalesce_left_right = FlattenedCoalesce::Of({left, right});
+			auto coalesce_right_left = FlattenedCoalesce::Of({right, left});
+			join_cond_by_coalesced_join_keys.emplace(coalesce_left_right, std::ref(cond));
+			join_cond_by_coalesced_join_keys.emplace(coalesce_right_left, std::ref(cond));
+		}
+	}
+
+	if (join_cond_by_coalesced_join_keys.empty()) {
+		return false;
+	}
+
+	bool has_applied_pushdown = false;
+	for (idx_t i = 0; i < filters.size(); i++) {
+		auto &filter = filters[i]->filter;
+		if (filter->IsVolatile() || filter->CanThrow()) {
+			continue;
+		}
+
+		// occurrences of equivalent coalesce expressions on the same join keys
+		// which need to be replaced if the filter is to be pushed down
+		expression_set_t coalesce_exprs_to_replace;
+		const JoinCondition *join_cond_ptr = nullptr;
+		bool many_non_equivalent_coalesce_exprs = false;
+
+		ExpressionIterator::EnumerateExpression(filter, [&](Expression &sub_expr) {
+			if (many_non_equivalent_coalesce_exprs ||
+			    sub_expr.GetExpressionType() != ExpressionType::OPERATOR_COALESCE) {
+				return;
+			}
+
+			auto sub_expr_flattened_coalesce = FlattenedCoalesce::Of({sub_expr});
+			auto join_cond_it = join_cond_by_coalesced_join_keys.find(sub_expr_flattened_coalesce);
+			if (join_cond_it == join_cond_by_coalesced_join_keys.end()) {
+				return;
+			}
+
+			auto new_join_cond_ptr = &join_cond_it->second.get();
+			if (join_cond_ptr && new_join_cond_ptr != join_cond_ptr) {
+				many_non_equivalent_coalesce_exprs = true;
+				return;
+			}
+
+			join_cond_ptr = new_join_cond_ptr;
+			coalesce_exprs_to_replace.insert(sub_expr);
+		});
+
+		if (coalesce_exprs_to_replace.empty() || many_non_equivalent_coalesce_exprs ||
+		    !ExprIsFunctionOnlyOf(*filter, coalesce_exprs_to_replace)) {
+			continue;
+		}
+
+		auto left_filter = ReplaceIn(filter->Copy(), coalesce_exprs_to_replace, *join_cond_ptr->left);
+		auto right_filter = ReplaceIn(filter->Copy(), coalesce_exprs_to_replace, *join_cond_ptr->right);
+		pushdown_left(std::move(left_filter));
+		pushdown_right(std::move(right_filter));
+		filters.erase_at(i);
+		has_applied_pushdown = true;
+		i--;
+	}
+
+	return has_applied_pushdown;
+}
+
+unique_ptr<LogicalOperator> FilterPushdown::PushdownOuterJoin(unique_ptr<LogicalOperator> op,
+                                                              unordered_set<idx_t> &left_bindings,
+                                                              unordered_set<idx_t> &right_bindings) {
+
+	if (op->type != LogicalOperatorType::LOGICAL_COMPARISON_JOIN) {
+		return FinishPushdown(std::move(op));
+	}
+
+	auto &join = op->Cast<LogicalComparisonJoin>();
+	D_ASSERT(join.join_type == JoinType::OUTER);
+
+	// `AddFilter` reports UNSATISFIABLE when the filters pushed into one side contradict each other, e.g.
+	// `COALESCE(t1.id, t2.id) = 1 AND COALESCE(t2.id, t1.id) = 2` pushes both `t1.id = 1` and `t1.id = 2` into
+	// the left side. The contradicting filter is then not retained by that side's FilterCombiner, so ignoring the
+	// result would silently drop it and return rows that the original filter rejects.
+	bool left_unsatisfiable = false;
+	bool right_unsatisfiable = false;
+	FilterPushdown left_pushdown(optimizer, convert_mark_joins), right_pushdown(optimizer, convert_mark_joins);
+	auto has_applied_pushdown = PushDownFiltersOnCoalescedEqualJoinKeys(
+	    filters, join.conditions,
+	    [&](unique_ptr<Expression> filter) {
+		    if (left_pushdown.AddFilter(std::move(filter)) == FilterResult::UNSATISFIABLE) {
+			    left_unsatisfiable = true;
+		    }
+	    },
+	    [&](unique_ptr<Expression> filter) {
+		    if (right_pushdown.AddFilter(std::move(filter)) == FilterResult::UNSATISFIABLE) {
+			    right_unsatisfiable = true;
+		    }
+	    });
+
+	if (!has_applied_pushdown) {
+		return FinishPushdown(std::move(op));
+	}
+
+	if (left_unsatisfiable && right_unsatisfiable) {
+		// neither input can produce a row, so the full outer join cannot either
+		return make_uniq<LogicalEmptyResult>(std::move(op));
+	}
+
+	// A side whose pushed filters are unsatisfiable is replaced by an empty result while the other side is kept:
+	// every filter pushed into one side has a twin on the other side with the (equal) join key substituted, so a
+	// row of the kept side that passes its filters cannot have had a match on the unsatisfiable side.
+	if (left_unsatisfiable) {
+		op->children[0] = make_uniq<LogicalEmptyResult>(std::move(op->children[0]));
+	} else {
+		left_pushdown.GenerateFilters();
+		op->children[0] = left_pushdown.Rewrite(std::move(op->children[0]));
+	}
+	if (right_unsatisfiable) {
+		op->children[1] = make_uniq<LogicalEmptyResult>(std::move(op->children[1]));
+	} else {
+		right_pushdown.GenerateFilters();
+		op->children[1] = right_pushdown.Rewrite(std::move(op->children[1]));
+	}
+	return PushFinalFilters(std::move(op));
+}
+
+} // namespace duckdb
diff --git a/src/planner/expression_iterator.cpp b/src/planner/expression_iterator.cpp
index 2c6ef4ae7572..8f3140e176eb 100644
--- a/src/planner/expression_iterator.cpp
+++ b/src/planner/expression_iterator.cpp
@@ -152,6 +152,18 @@ void ExpressionIterator::EnumerateExpression(unique_ptr<Expression> &expr,
 	                                      [&](unique_ptr<Expression> &child) { EnumerateExpression(child, callback); });
 }
 
+//! Pre-order traversal that hands `callback` the owning pointer of each node, so the callback may replace the node
+//! in place; afterwards the children of the node as it stands after the callback (i.e. of any replacement) are visited.
+void ExpressionIterator::EnumerateExpression(unique_ptr<Expression> &expr,
+                                             const std::function<void(unique_ptr<Expression> &child)> &callback) {
+	if (!expr) {
+		return;
+	}
+	callback(expr);
+	ExpressionIterator::EnumerateChildren(*expr,
+	                                      [&](unique_ptr<Expression> &child) { EnumerateExpression(child, callback); });
+}
+
 void ExpressionIterator::VisitExpressionClass(const Expression &expr, ExpressionClass expr_class,
                                               const std::function<void(const Expression &child)> &callback) {
 	if (expr.GetExpressionClass() == expr_class) {
diff --git a/test/optimizer/pushdown/pushdown_filter_on_coalesced_equal_outer_join_keys.test b/test/optimizer/pushdown/pushdown_filter_on_coalesced_equal_outer_join_keys.test
new file mode 100644
index 000000000000..329cf462056f
--- /dev/null
+++ b/test/optimizer/pushdown/pushdown_filter_on_coalesced_equal_outer_join_keys.test
@@ -0,0 +1,181 @@
+# name: test/optimizer/pushdown/pushdown_filter_on_coalesced_equal_outer_join_keys.test
+# description: Test pushdown of filters on coalesced join keys compared for equality in the join condition
+# group: [pushdown]
+
+# enable query verification
+statement ok
+PRAGMA enable_verification
+
+statement ok
+CREATE TABLE t1 AS FROM VALUES
+(1),
+(2),
+(NULL),
+(4) t(id);
+
+statement ok
+CREATE TABLE t2 AS FROM VALUES
+(1),
+(3),
+(NULL),
+(4) t(id);
+
+statement ok
+CREATE TABLE t3 AS FROM VALUES
+(1),
+(3),
+(NULL),
+(4) t(id);
+
+query I
+SELECT id FROM t1 FULL OUTER JOIN t2 USING (id) WHERE id >=2 ORDER BY id
+----
+2
+3
+4
+
+# both unmatched rows with a NULL key (one per side) must be returned
+query I
+SELECT id FROM t1 FULL OUTER JOIN t2 USING (id) WHERE id IS NULL
+----
+NULL
+NULL
+
+statement ok
+set explain_output='optimized_only';
+
+# the optimized plan must equal the plan of the same query with the filter pushed into both sides by hand
+query II nosort single_join
+EXPLAIN SELECT id
+FROM (SELECT id FROM t1) FULL OUTER JOIN (SELECT id FROM t2) USING (id)
+WHERE id >= 2
+----
+
+query II nosort single_join
+EXPLAIN SELECT id
+FROM (SELECT id FROM t1 WHERE id >= 2) FULL OUTER JOIN (SELECT id FROM t2 WHERE id >= 2) USING (id)
+----
+
+# same plan-equality check, with IS NULL as the filtering predicate
+query II nosort single_join_isnull_filter
+EXPLAIN SELECT id
+FROM (SELECT id FROM t1) FULL OUTER JOIN (SELECT id FROM t2) USING (id)
+WHERE id IS NULL
+----
+
+query II nosort single_join_isnull_filter
+EXPLAIN SELECT id
+FROM (SELECT id FROM t1 WHERE id IS NULL) FULL OUTER JOIN (SELECT id FROM t2 WHERE id IS NULL) USING (id)
+----
+
+# same plan-equality check for a chain of two full outer joins
+query II nosort multiple_joins
+EXPLAIN SELECT id
+FROM (SELECT id FROM t1) FULL OUTER JOIN (SELECT id FROM t2) USING (id) FULL OUTER JOIN (SELECT id FROM t3) USING (id) WHERE id >= 2;
+----
+
+query II nosort multiple_joins
+EXPLAIN SELECT id
+FROM (SELECT id FROM t1 WHERE id >= 2)
+FULL OUTER JOIN (SELECT id FROM t2 WHERE id >= 2) USING (id)
+FULL OUTER JOIN (SELECT id FROM t3 WHERE id >= 2) USING (id);
+----
+
+# should pushdown filter with multiple occurrences of the same coalesced join keys
+query II nosort multiple_occurrences_of_the_same_coalesced_join_keys
+EXPLAIN
+SELECT id FROM (SELECT id FROM t1) FULL OUTER JOIN (SELECT id FROM t2) USING (id)
+WHERE id >= 2 OR id IN (1, 4) OR id IS NULL;
+----
+
+query II nosort multiple_occurrences_of_the_same_coalesced_join_keys
+EXPLAIN SELECT id
+FROM (SELECT id FROM t1 WHERE id >= 2 OR id IN (1, 4) OR id IS NULL)
+FULL OUTER JOIN (SELECT id FROM t2 WHERE id >= 2 OR id IN (1, 4) OR id IS NULL)
+USING (id);
+----
+
+# should pushdown filter containing different but equivalent coalesced join keys
+query II
+EXPLAIN SELECT t1.id, t2.id
+FROM t1 FULL OUTER JOIN t2 ON t1.id = t2.id
+WHERE COALESCE(t1.id, t2.id) >= 2 OR COALESCE(t2.id, t1.id) IS NULL;
+----
+logical_opt :.*SEQ_SCAN.*Filters.*\(id >= 2\) OR \(id IS NULL\).*
+
+# should not pushdown a volatile filter
+query II
+EXPLAIN SELECT id
+FROM t1 FULL OUTER JOIN t2 USING (id)
+WHERE trunc(random() * id) >= 2
+----
+logical_opt :.*SEQ_SCAN.*Filters.*
+
+# coalescing right and left is optimized like coalescing left and right
+query II nosort left_right_coalesce
+EXPLAIN SELECT t1.id, t2.id
+FROM t1 FULL OUTER JOIN t2 ON t1.id = t2.id
+WHERE COALESCE(t1.id, t2.id) >= 2;
+----
+
+query II nosort left_right_coalesce
+EXPLAIN SELECT t1.id, t2.id
+FROM t1 FULL OUTER JOIN t2 ON t1.id = t2.id
+WHERE COALESCE(t2.id, t1.id) >= 2;
+----
+
+statement ok
+CREATE TABLE t4 AS FROM VALUES
+  (1, 20),
+  (2, NULL),
+  (3, 16)
+  t(id, a);
+
+statement ok
+CREATE TABLE t5 AS FROM VALUES
+  (1, NULL),
+  (1, 30)
+  t(id, a);
+
+# should not pushdown filter on coalesced keys that also depend on other columns
+query II
+EXPLAIN SELECT id FROM t1 FULL OUTER JOIN t4 USING (id) WHERE id = a;
+----
+logical_opt :.*SEQ_SCAN.*Filters.*
+
+# should not pushdown single filter containing coalesced keys from different join conditions
+query II
+EXPLAIN SELECT id FROM t4 FULL OUTER JOIN t5 USING (id, a) WHERE id IN (a, 1)
+----
+logical_opt :.*SEQ_SCAN.*Filters.*
+
+query II nosort nullif_func_join_keys
+EXPLAIN SELECT *
+FROM (SELECT id, a FROM t4 WHERE nullif(id, a) < 3) AS t4
+FULL OUTER JOIN (SELECT id, a FROM t5 WHERE nullif(id, a) < 3) AS t5
+ON nullif(t4.id, t4.a) = nullif(t5.id, t5.a)
+----
+
+# filters on coalesced keys that are functions (nullif) of the input columns should be pushed down too
+query II nosort nullif_func_join_keys
+EXPLAIN SELECT *
+FROM (SELECT id, a FROM t4) AS t4
+FULL OUTER JOIN (SELECT id, a FROM t5) AS t5
+ON nullif(t4.id, t4.a) = nullif(t5.id, t5.a)
+WHERE coalesce(nullif(t4.id, t4.a), nullif(t5.id, t5.a)) < 3;
+----
+
+query II nosort list_func_join_keys
+EXPLAIN SELECT * FROM
+(SELECT id, a FROM t4) AS t4
+FULL OUTER JOIN (SELECT id, a FROM t5) AS t5
+ON [t4.id, t4.a] = [t5.id, t5.a]
+WHERE coalesce([t4.id, t4.a], [t5.id, t5.a])[0] < 4;
+----
+
+query II nosort list_func_join_keys
+EXPLAIN SELECT * FROM
+(SELECT id, a FROM t4, WHERE [id, a][0] < 4) AS t4
+FULL OUTER JOIN (SELECT id, a FROM t5 WHERE [id, a][0] < 4) AS t5
+ON [t4.id, t4.a] = [t5.id, t5.a]
+----