Thanks for visiting codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 113 additions & 23 deletions src/planner/binder/tableref/bind_pivot.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,25 +69,60 @@ static void ExtractPivotExpressions(ParsedExpression &expr, case_insensitive_set
expr, [&](ParsedExpression &child) { ExtractPivotExpressions(child, handled_columns); });
}

void ExtractPivotAggregateExpression(ClientContext &context, ParsedExpression &expr,
vector<reference<FunctionExpression>> &aggregates) {
if (expr.GetExpressionType() == ExpressionType::FUNCTION) {
auto &aggr_function = expr.Cast<FunctionExpression>();
//! Aggregate handler that simply records a reference to every aggregate function it is handed
struct ExtractPivotAggregateOperator {
	//! State threaded through the traversal: the aggregate function references gathered so far
	using TYPE = vector<reference<FunctionExpression>>;

	//! Appends "aggr_function" to "aggregates"; the owning pointer "expr" is not modified
	static void HandleAggregate(unique_ptr<ParsedExpression> &expr, FunctionExpression &aggr_function,
	                            TYPE &aggregates) {
		aggregates.emplace_back(aggr_function);
	}
};

//! Aggregate handler that exchanges the aggregate found in the expression tree with a replacement expression
struct ReplacePivotAggregateOperator {
	//! State threaded through the traversal: the replacement expression going in,
	//! the extracted aggregate once the exchange has happened
	using TYPE = unique_ptr<ParsedExpression>;

	//! Puts "replacement_expr" where the aggregate "expr" used to be and hands the aggregate back through
	//! "replacement_expr". Throws a BinderException if "replacement_expr" is not a column reference.
	static void HandleAggregate(unique_ptr<ParsedExpression> &expr, FunctionExpression &aggr_function,
	                            TYPE &replacement_expr) {
		// once an exchange has taken place "replacement_expr" holds the extracted aggregate instead of a
		// column reference, so reaching this point again means a second aggregate was encountered
		const bool is_column_ref = replacement_expr->type == ExpressionType::COLUMN_REF;
		if (!is_column_ref) {
			throw BinderException(*expr, "Pivot expression can only have one aggregate");
		}
		// trade places: the tree now points at the replacement, the caller receives the aggregate
		std::swap(expr, replacement_expr);
	}
};

// Walks "expr" looking for aggregate function calls and passes each one to OP::HandleAggregate.
// Function expressions are looked up in the catalog; when the entry found is an aggregate the handler is invoked
// with the caller-supplied "aggregates" state and the walk does not descend below that aggregate.
// A column reference reached outside of an aggregate raises a BinderException.
// All other expressions are traversed recursively through their children.
// NOTE(review): this span is taken from a rendered diff, so the removed and the added version of several
// statements appear side by side (e.g. the two lookup_info declarations) - confirm against the actual file.
template <class OP>
void TemplatedHandlePivotAggregate(ClientContext &context, unique_ptr<ParsedExpression> &expr,
typename OP::TYPE &aggregates) {
if (expr->GetExpressionType() == ExpressionType::FUNCTION) {
auto &aggr_function = expr->Cast<FunctionExpression>();

// look up the function in the catalog to make sure it is an aggregate and not a scalar function
EntryLookupInfo lookup_info(CatalogType::AGGREGATE_FUNCTION_ENTRY, aggr_function.function_name, expr);
EntryLookupInfo lookup_info(CatalogType::AGGREGATE_FUNCTION_ENTRY, aggr_function.function_name, *expr);
auto &entry = Catalog::GetEntry(context, aggr_function.catalog, aggr_function.schema, lookup_info);
if (entry.type == CatalogType::AGGREGATE_FUNCTION_ENTRY) {
// aggregate: hand it to the operator and stop descending
aggregates.push_back(aggr_function);
OP::HandleAggregate(expr, aggr_function, aggregates);
return;
}
}
// column references are only legal inside an aggregate, which would have returned above
if (expr.GetExpressionType() == ExpressionType::COLUMN_REF) {
throw BinderException(expr, "Columns can only be referenced within the aggregate of a PIVOT expression");
if (expr->GetExpressionType() == ExpressionType::COLUMN_REF) {
throw BinderException(*expr, "Columns can only be referenced within the aggregate of a PIVOT expression");
}
// anything else: recurse into the children, passing the owning pointers so the operator may replace them
ParsedExpressionIterator::EnumerateChildren(
expr, [&](ParsedExpression &child) { ExtractPivotAggregateExpression(context, child, aggregates); });
ParsedExpressionIterator::EnumerateChildren(*expr, [&](unique_ptr<ParsedExpression> &child) {
TemplatedHandlePivotAggregate<OP>(context, child, aggregates);
});
}

//! Gathers a reference to each aggregate function call found inside "expr" into "aggregates"
void ExtractPivotAggregateExpression(ClientContext &context, unique_ptr<ParsedExpression> &expr,
                                     vector<reference<FunctionExpression>> &aggregates) {
	using Collector = ExtractPivotAggregateOperator;
	TemplatedHandlePivotAggregate<Collector>(context, expr, aggregates);
}

//! Exchanges the aggregate function call inside "expr" with "replacement_expr";
//! the aggregate that was taken out is handed back through "replacement_expr"
void ReplacePivotAggregateExpression(ClientContext &context, unique_ptr<ParsedExpression> &expr,
                                     unique_ptr<ParsedExpression> &replacement_expr) {
	using Replacer = ReplacePivotAggregateOperator;
	TemplatedHandlePivotAggregate<Replacer>(context, expr, replacement_expr);
}

static unique_ptr<SelectNode> ConstructInitialGrouping(PivotRef &ref, vector<unique_ptr<ParsedExpression>> all_columns,
Expand Down Expand Up @@ -149,7 +184,7 @@ static unique_ptr<SelectNode> PivotFilteredAggregate(ClientContext &context, Piv
auto copied_aggr = aggregate->Copy();

vector<reference<FunctionExpression>> aggregates;
ExtractPivotAggregateExpression(context, *copied_aggr, aggregates);
ExtractPivotAggregateExpression(context, copied_aggr, aggregates);
D_ASSERT(aggregates.size() == 1);

auto &aggr = aggregates[0].get().Cast<FunctionExpression>();
Expand All @@ -174,7 +209,7 @@ struct PivotBindState {
vector<string> internal_aggregate_names;
};

static unique_ptr<SelectNode> PivotInitialAggregate(PivotBindState &bind_state, PivotRef &ref,
static unique_ptr<SelectNode> PivotInitialAggregate(ClientContext &context, PivotBindState &bind_state, PivotRef &ref,
vector<unique_ptr<ParsedExpression>> all_columns,
const case_insensitive_set_t &handled_columns) {
auto subquery_stage1 = ConstructInitialGrouping(ref, std::move(all_columns), handled_columns);
Expand Down Expand Up @@ -205,10 +240,22 @@ static unique_ptr<SelectNode> PivotInitialAggregate(PivotBindState &bind_state,
// finally add the aggregates
for (auto &aggregate : ref.aggregates) {
auto aggregate_alias = "__internal_pivot_aggregate" + std::to_string(++aggregate_count);
bind_state.aggregate_names.push_back(aggregate->GetAlias());
auto aggr_name = aggregate->GetAlias();
if (aggr_name.empty() && ref.aggregates.size() > 1) {
aggr_name = aggregate->ToString();
}
if (!aggr_name.empty()) {
aggr_name = "_" + aggr_name;
}

unique_ptr<ParsedExpression> aggregate_ref;
aggregate_ref = make_uniq<ColumnRefExpression>(aggregate_alias);
ReplacePivotAggregateExpression(context, aggregate, aggregate_ref);

bind_state.aggregate_names.push_back(std::move(aggr_name));
bind_state.internal_aggregate_names.push_back(aggregate_alias);
aggregate->SetAlias(std::move(aggregate_alias));
subquery_stage1->select_list.push_back(std::move(aggregate));
aggregate_ref->SetAlias(std::move(aggregate_alias));
subquery_stage1->select_list.push_back(std::move(aggregate_ref));
}
return subquery_stage1;
}
Expand Down Expand Up @@ -280,6 +327,16 @@ static unique_ptr<SelectNode> PivotListAggregate(PivotBindState &bind_state, Piv
return subquery_stage2;
}

//! Recursively rewrites the column references inside "expr": the first name component of every
//! column reference encountered is overwritten with "name"
void ReplacePivotColumnRef(ParsedExpression &expr, const string &name) {
	if (expr.GetExpressionType() != ExpressionType::COLUMN_REF) {
		// not a column reference itself: keep searching in the children
		ParsedExpressionIterator::EnumerateChildren(
		    expr, [&name](ParsedExpression &node) { ReplacePivotColumnRef(node, name); });
		return;
	}
	expr.Cast<ColumnRefExpression>().column_names[0] = name;
}

static unique_ptr<SelectNode> PivotFinalOperator(PivotBindState &bind_state, PivotRef &ref,
unique_ptr<SelectNode> subquery,
vector<PivotValueElement> pivot_values) {
Expand All @@ -295,7 +352,34 @@ static unique_ptr<SelectNode> PivotFinalOperator(PivotBindState &bind_state, Piv
bound_pivot->bound_aggregate_names = std::move(bind_state.aggregate_names);
bound_pivot->source = std::move(subquery_ref);

final_pivot_operator->select_list.push_back(make_uniq<StarExpression>());
for (auto &group_name : bound_pivot->bound_group_names) {
final_pivot_operator->select_list.push_back(make_uniq<ColumnRefExpression>(group_name));
}
// gather aggregate names
vector<string> aggregate_names;
for (auto &pivot_value : bound_pivot->bound_pivot_values) {
for (idx_t aggr_idx = 0; aggr_idx < ref.aggregates.size(); aggr_idx++) {
auto aggr = ref.aggregates[aggr_idx]->Copy();
auto &aggr_name = bound_pivot->bound_aggregate_names[aggr_idx];
auto pivot_aggr_name = pivot_value.name + aggr_name;
aggregate_names.push_back(std::move(pivot_aggr_name));
}
}
QueryResult::DeduplicateColumns(aggregate_names);

idx_t aggr_name_idx = 0;
for (idx_t pivot_idx = 0; pivot_idx < bound_pivot->bound_pivot_values.size(); pivot_idx++) {
for (idx_t aggr_idx = 0; aggr_idx < ref.aggregates.size(); aggr_idx++) {
auto aggr = ref.aggregates[aggr_idx]->Copy();
auto &pivot_aggr_name = aggregate_names[aggr_name_idx++];
// replace column ref with name
ReplacePivotColumnRef(*aggr, pivot_aggr_name);
aggr->alias = pivot_aggr_name;

final_pivot_operator->select_list.push_back(std::move(aggr));
}
}

final_pivot_operator->from_table = std::move(bound_pivot);
return final_pivot_operator;
}
Expand Down Expand Up @@ -325,6 +409,15 @@ void ExtractPivotAggregates(BoundTableRef &node, vector<unique_ptr<Expression>>
}
}

//! Builds the output column name for one (pivot value, aggregate) combination.
//! The result is the pivot value name, followed by "_" and the aggregate name whenever there is
//! more than one aggregate or the aggregate name is non-empty.
string GetPivotAggregateName(const PivotValueElement &pivot_value, const string &aggr_name, idx_t aggregate_count) {
	const bool append_aggregate_name = aggregate_count > 1 || !aggr_name.empty();
	if (!append_aggregate_name) {
		// a single aggregate without a name: the pivot value name is used as-is
		return pivot_value.name;
	}
	return pivot_value.name + "_" + aggr_name;
}

unique_ptr<BoundTableRef> Binder::BindBoundPivot(PivotRef &ref) {
// bind the child table in a child binder
auto result = make_uniq<BoundPivotRef>();
Expand Down Expand Up @@ -355,11 +448,8 @@ unique_ptr<BoundTableRef> Binder::BindBoundPivot(PivotRef &ref) {
for (idx_t aggr_idx = 0; aggr_idx < ref.bound_aggregate_names.size(); aggr_idx++) {
auto &aggr = aggregates[aggr_idx];
auto &aggr_name = ref.bound_aggregate_names[aggr_idx];
auto name = pivot_value.name;
if (aggregates.size() > 1 || !aggr_name.empty()) {
// if there are multiple aggregates specified we add the name of the aggregate as well
name += "_" + (aggr_name.empty() ? aggr->GetName() : aggr_name);
}
auto name = pivot_value.name + aggr_name;

string pivot_str;
for (auto &value : pivot_value.values) {
auto str = value.ToString();
Expand Down Expand Up @@ -399,7 +489,7 @@ unique_ptr<SelectNode> Binder::BindPivot(PivotRef &ref, vector<unique_ptr<Parsed
throw BinderException(*aggr, "Pivot expression cannot contain window functions");
}
idx_t aggregate_count = pivot_aggregates.size();
ExtractPivotAggregateExpression(context, *aggr, pivot_aggregates);
ExtractPivotAggregateExpression(context, aggr, pivot_aggregates);
if (pivot_aggregates.size() != aggregate_count + 1) {
string error_str = pivot_aggregates.size() == aggregate_count
? "but no aggregates were found"
Expand Down Expand Up @@ -497,7 +587,7 @@ unique_ptr<SelectNode> Binder::BindPivot(PivotRef &ref, vector<unique_ptr<Parsed
PivotBindState bind_state;
// Pivot Stage 1
// SELECT {groups}, {pivots}, {aggregates} FROM {from_clause} GROUP BY {groups}, {pivots}
auto subquery_stage1 = PivotInitialAggregate(bind_state, ref, std::move(all_columns), handled_columns);
auto subquery_stage1 = PivotInitialAggregate(context, bind_state, ref, std::move(all_columns), handled_columns);

// Pivot stage 2
// SELECT {groups}, LIST({pivots}), LIST({aggregates}) FROM [Q1] GROUP BY {groups}
Expand Down
38 changes: 38 additions & 0 deletions test/sql/pivot/pivot_operator_expression.test
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# name: test/sql/pivot/pivot_operator_expression.test
# description: Test expressions in pivot syntax
# group: [pivot]

statement ok
PRAGMA enable_verification

# set up the monthly_sales table (month is stored as TEXT) used by the pivot queries below
statement ok
CREATE OR REPLACE TABLE monthly_sales(empid INT, amount INT, month TEXT);

statement ok
INSERT INTO monthly_sales VALUES
(1, 10000, '1-JAN'),
(1, 400, '1-JAN'),
(2, 4500, '1-JAN'),
(2, 35000, '1-JAN'),
(1, 5000, '2-FEB'),
(1, 3000, '2-FEB'),
(2, 200, '2-FEB'),
(2, 90500, '2-FEB'),
(2, 2500, '3-MAR'),
(2, 9500, '3-MAR'),
(1, 8000, '4-APR'),
(1, 10000, '4-APR'),
(2, 800, '4-APR'),
(2, 4500, '4-APR');

query IIIII rowsort
PIVOT monthly_sales ON MONTH USING COALESCE(SUM(AMOUNT), 0)
----
1 10400 8000 0 18000
2 39500 90700 12000 5300

query I
SELECT mode(column_type) FROM (DESCRIBE PIVOT monthly_sales ON MONTH USING SUM(AMOUNT)::INTEGER)
----
INTEGER
Loading