diff --git a/src/include/planner/operator/logical_union.h b/src/include/planner/operator/logical_union.h index 914c4428d..604053320 100644 --- a/src/include/planner/operator/logical_union.h +++ b/src/include/planner/operator/logical_union.h @@ -12,6 +12,10 @@ class LogicalUnion : public LogicalOperator { : LogicalOperator{LogicalOperatorType::UNION_ALL, children}, expressionsToUnion{std::move(expressions)} {} + void setChildProjections(std::vector projections) { + childProjections = std::move(projections); + } + f_group_pos_set getGroupsPosToFlatten(uint32_t childIdx); void computeFactorizedSchema() override; @@ -21,6 +25,10 @@ class LogicalUnion : public LogicalOperator { binder::expression_vector getExpressionsToUnion() const { return expressionsToUnion; } + const std::vector& getChildProjections() const { + return childProjections; + } + Schema* getSchemaBeforeUnion(uint32_t idx) const { return children[idx]->getSchema(); } std::unique_ptr copy() override; @@ -32,6 +40,11 @@ class LogicalUnion : public LogicalOperator { private: binder::expression_vector expressionsToUnion; + // Non-deduplicated per-child projection lists, indexed by child ordinal then column. + // This preserves the positional correspondence with expressionsToUnion even when a + // child projects the same expression more than once (which the schema's + // expressionsInScope deduplicates). + std::vector childProjections; }; } // namespace planner diff --git a/src/planner/operator/logical_union.cpp b/src/planner/operator/logical_union.cpp index b9fdc1ba6..e0057d055 100644 --- a/src/planner/operator/logical_union.cpp +++ b/src/planner/operator/logical_union.cpp @@ -11,7 +11,11 @@ f_group_pos_set LogicalUnion::getGroupsPosToFlatten(uint32_t childIdx) { auto childSchema = children[childIdx]->getSchema(); for (auto i = 0u; i < expressionsToUnion.size(); ++i) { if (requireFlatExpression(i)) { - auto expression = childSchema->getExpressionsInScope()[i]; + // Use the non-deduplicated projection list rather than the child schema's + // expressionsInScope, which may be shorter when a child projects the same + // expression more than once (e.g. RETURN b.age, b.age). + DASSERT(childIdx < childProjections.size()); + auto expression = childProjections[childIdx][i]; groupsPos.insert(childSchema->getGroupPos(*expression)); } } @@ -39,13 +43,20 @@ std::unique_ptr LogicalUnion::copy() { for (auto i = 0u; i < getNumChildren(); ++i) { copiedChildren.push_back(getChild(i)->copy()); } - return make_unique(expressionsToUnion, std::move(copiedChildren)); + auto result = make_unique(expressionsToUnion, std::move(copiedChildren)); + result->setChildProjections(childProjections); + return result; } bool LogicalUnion::requireFlatExpression(uint32_t expressionIdx) { - for (auto& child : children) { - auto childSchema = child->getSchema(); - auto expression = childSchema->getExpressionsInScope()[expressionIdx]; + for (auto childIdx = 0u; childIdx < children.size(); ++childIdx) { + auto childSchema = children[childIdx]->getSchema(); + // Use the non-deduplicated projection list; indexing by unique name would not + // work because different arms may have different expressions at the same + // position (only types are validated to match). + DASSERT(childIdx < childProjections.size()); + DASSERT(expressionIdx < childProjections[childIdx].size()); + auto expression = childProjections[childIdx][expressionIdx]; if (childSchema->getGroup(expression)->isFlat()) { return true; } diff --git a/src/planner/query_planner.cpp b/src/planner/query_planner.cpp index ea72191ac..cc3574ffe 100644 --- a/src/planner/query_planner.cpp +++ b/src/planner/query_planner.cpp @@ -1,4 +1,5 @@ #include "binder/query/bound_regular_query.h" +#include "planner/operator/logical_projection.h" #include "planner/operator/logical_union.h" #include "planner/planner.h" @@ -27,11 +28,27 @@ LogicalPlan Planner::createUnionPlan(std::vector& childrenPlans, auto plan = LogicalPlan(); std::vector> children; children.reserve(childrenPlans.size()); + std::vector childProjections; + childProjections.reserve(childrenPlans.size()); for (auto& childPlan : childrenPlans) { children.push_back(childPlan.getLastOperator()); + // Record each child's non-deduplicated projection list so that + // LogicalUnion can look up expressions positionally without indexing + // into the schema's deduplicated expressionsInScope. + // Only LogicalProjection deduplicates via insertToScopeMayRepeat; + // other operator types (e.g. LogicalDelete) keep the full arity in + // getExpressionsInScope, so we fall back to that. + auto* lastOp = childPlan.getLastOperator().get(); + if (lastOp->getOperatorType() == LogicalOperatorType::PROJECTION) { + auto& projection = lastOp->constCast(); + childProjections.push_back(projection.getExpressionsToProject()); + } else { + childProjections.push_back(lastOp->getSchema()->getExpressionsInScope()); + } } // we compute the schema based on first child auto union_ = std::make_shared(expressions, std::move(children)); + union_->setChildProjections(std::move(childProjections)); for (auto i = 0u; i < childrenPlans.size(); ++i) { appendFlattens(union_->getGroupsPosToFlatten(i), childrenPlans[i]); union_->setChild(i, childrenPlans[i].getLastOperator()); diff --git a/src/processor/map/map_union.cpp b/src/processor/map/map_union.cpp index 3d38c1254..6366926ff 100644 --- a/src/processor/map/map_union.cpp +++ b/src/processor/map/map_union.cpp @@ -19,8 +19,11 @@ std::unique_ptr PlanMapper::mapUnionAll(const LogicalOperator* auto child = logicalOperator->getChild(i); auto childSchema = logicalUnionAll.getSchemaBeforeUnion(i); auto prevOperator = mapOperator(child.get()); + // Use the child's non-deduplicated projection list so the factorized table has + // the correct number of columns (matching expressionsToUnion.size()), even when + // the child projects the same expression more than once (e.g. RETURN b.age, b.age). auto resultCollector = createResultCollector(AccumulateType::REGULAR, - childSchema->getExpressionsInScope(), childSchema, std::move(prevOperator)); + logicalUnionAll.getChildProjections()[i], childSchema, std::move(prevOperator)); tables.push_back(resultCollector->getResultFTable()); prevOperators.push_back(std::move(resultCollector)); } diff --git a/test/test_files/issue/issue.test b/test/test_files/issue/issue.test index 84bc300e6..ea7367526 100644 --- a/test/test_files/issue/issue.test +++ b/test/test_files/issue/issue.test @@ -846,3 +846,38 @@ h -STATEMENT MATCH (a:A_uuid_count) RETURN COUNT(*), COUNT(DISTINCT a.id); ---- 1 1|1 + +-CASE 619 +-STATEMENT CREATE NODE TABLE Person(name STRING, age INT64, PRIMARY KEY (name)); +---- ok +-STATEMENT CREATE REL TABLE Knows(FROM Person TO Person); +---- ok +-STATEMENT CREATE (p1:Person {name: 'Alice', age: 30}); +---- ok +-STATEMENT CREATE (p2:Person {name: 'Bob', age: 25}); +---- ok +-STATEMENT CREATE (p3:Person {name: 'Charlie', age: 35}); +---- ok +-STATEMENT MATCH (a:Person), (b:Person), (c:Person) WHERE a.name = 'Alice' AND b.name = 'Bob' AND c.name = 'Charlie' CREATE (a)-[:Knows]->(b)-[:Knows]->(c); +---- ok +-STATEMENT MATCH p = (a)-[*1..2]-(b)-[*1..2]-(c) WHERE a.name = 'Alice' AND b.name = 'Bob' AND c.name = 'Charlie' RETURN any(x IN [1] WHERE p IS NOT NULL); +---- 1 +True + +-CASE 620 +-STATEMENT CREATE NODE TABLE Person(name STRING, age INT64, PRIMARY KEY (name)); +---- ok +-STATEMENT CREATE (p1:Person {name: 'Alice', age: 30}); +---- ok +-STATEMENT CREATE (p2:Person {name: 'Bob', age: 25}); +---- ok +-STATEMENT CREATE (p3:Person {name: 'Charlie', age: 35}); +---- ok +-STATEMENT MATCH (a:Person) RETURN 1, 2 UNION ALL MATCH (b:Person) RETURN b.age, b.age; +---- 6 +30|30 +25|25 +35|35 +1|2 +1|2 +1|2