From 537d035d4d9a5c516f754adfdd0d96fb0a3fd9c7 Mon Sep 17 00:00:00 2001 From: caramdache Date: Wed, 14 Sep 2022 11:41:19 +0200 Subject: [PATCH 1/9] Add support for filter expression in GroupConcat GroupConcat did not support fitler expression. This PR adds support based on Django `Count` aggregate. https://github.com/adamchainz/django-mysql/blob/main/src/django_mysql/models/aggregates.py --- src/django_mysql/models/aggregates.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/src/django_mysql/models/aggregates.py b/src/django_mysql/models/aggregates.py index 417c8ab3..f2957560 100644 --- a/src/django_mysql/models/aggregates.py +++ b/src/django_mysql/models/aggregates.py @@ -28,6 +28,7 @@ class GroupConcat(Aggregate): def __init__( self, expression: Expression, + filter = None, distinct: bool = False, separator: str | None = None, ordering: str | None = None, @@ -38,7 +39,7 @@ def __init__( # This can/will be improved to SetTextField or ListTextField extra["output_field"] = CharField() - super().__init__(expression, **extra) + super().__init__(expression, filter=filter, **extra) self.distinct = distinct self.separator = separator @@ -53,6 +54,17 @@ def as_sql( connection: BaseDatabaseWrapper, **extra_context: Any, ) -> tuple[str, tuple[Any, ...]]: + if self.filter: + extra_context["distinct"] = "DISTINCT " if self.distinct else "" + copy = self.copy() + copy.filter = None + source_expressions = copy.get_source_expressions() + condition = When(self.filter, then=source_expressions[0]) + copy.set_source_expressions([Case(condition)] + source_expressions[1:]) + return super(Aggregate, copy).as_sql( + compiler, connection, **extra_context + ) + connection.ops.check_expression_support(self) sql = ["GROUP_CONCAT("] if self.distinct: From 452d3c8251bb1b8d076b47d5ec777a90d8a7bbc7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 14 Sep 2022 09:42:08 +0000 Subject: [PATCH 2/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/django_mysql/models/aggregates.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/django_mysql/models/aggregates.py b/src/django_mysql/models/aggregates.py index f2957560..0adc9a3c 100644 --- a/src/django_mysql/models/aggregates.py +++ b/src/django_mysql/models/aggregates.py @@ -28,7 +28,7 @@ class GroupConcat(Aggregate): def __init__( self, expression: Expression, - filter = None, + filter=None, distinct: bool = False, separator: str | None = None, ordering: str | None = None, @@ -61,9 +61,7 @@ def as_sql( source_expressions = copy.get_source_expressions() condition = When(self.filter, then=source_expressions[0]) copy.set_source_expressions([Case(condition)] + source_expressions[1:]) - return super(Aggregate, copy).as_sql( - compiler, connection, **extra_context - ) + return super(Aggregate, copy).as_sql(compiler, connection, **extra_context) connection.ops.check_expression_support(self) sql = ["GROUP_CONCAT("] From 24c0e35768643aa9044678eb2cbfdd34f6adfa52 Mon Sep 17 00:00:00 2001 From: caramdache Date: Wed, 14 Sep 2022 14:07:16 +0200 Subject: [PATCH 3/9] Update src/django_mysql/models/aggregates.py Co-authored-by: Adam Johnson --- src/django_mysql/models/aggregates.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/django_mysql/models/aggregates.py b/src/django_mysql/models/aggregates.py index 0adc9a3c..e1001396 100644 --- a/src/django_mysql/models/aggregates.py +++ b/src/django_mysql/models/aggregates.py @@ -28,7 +28,7 @@ class GroupConcat(Aggregate): def __init__( self, expression: Expression, - filter=None, + filter: Any | None=None, distinct: bool = False, separator: str | None = None, ordering: str | None = None, From 97ea273832c862424546a0918c53bf747f8b95aa Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 14 Sep 2022 12:07:33 +0000 Subject: [PATCH 4/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/django_mysql/models/aggregates.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/django_mysql/models/aggregates.py b/src/django_mysql/models/aggregates.py index e1001396..a309181a 100644 --- a/src/django_mysql/models/aggregates.py +++ b/src/django_mysql/models/aggregates.py @@ -28,7 +28,7 @@ class GroupConcat(Aggregate): def __init__( self, expression: Expression, - filter: Any | None=None, + filter: Any | None = None, distinct: bool = False, separator: str | None = None, ordering: str | None = None, From 65b301c4339f7b5e06756368023eaf731c9f0c7a Mon Sep 17 00:00:00 2001 From: caramdache Date: Wed, 14 Sep 2022 14:10:16 +0200 Subject: [PATCH 5/9] Add missing name and allow_distinct variables I missed this in the first commit --- src/django_mysql/models/aggregates.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/django_mysql/models/aggregates.py b/src/django_mysql/models/aggregates.py index a309181a..1562a85f 100644 --- a/src/django_mysql/models/aggregates.py +++ b/src/django_mysql/models/aggregates.py @@ -24,6 +24,8 @@ class BitXor(Aggregate): class GroupConcat(Aggregate): function = "GROUP_CONCAT" + name = "GroupConcat" + allow_distinct = True def __init__( self, From b3e4a2c7901b7be93722efb0a6845de281336363 Mon Sep 17 00:00:00 2001 From: caramdache Date: Wed, 14 Sep 2022 15:12:33 +0200 Subject: [PATCH 6/9] Add support for ordering and separator when filter is present --- src/django_mysql/models/aggregates.py | 32 +++++++++++++++++++++------ 1 file changed, 25 insertions(+), 7 deletions(-) diff --git a/src/django_mysql/models/aggregates.py b/src/django_mysql/models/aggregates.py index 1562a85f..1efaa3d3 100644 --- a/src/django_mysql/models/aggregates.py +++ b/src/django_mysql/models/aggregates.py @@ -23,8 +23,10 @@ class BitXor(Aggregate): class GroupConcat(Aggregate): + template = "%(function)s(%(distinct)s%(expressions)s%(order_by)s%(separator)s)" function = "GROUP_CONCAT" name = "GroupConcat" + output_field = CharField() allow_distinct = True def __init__( @@ -56,6 +58,15 @@ def as_sql( connection: BaseDatabaseWrapper, **extra_context: Any, ) -> tuple[str, tuple[Any, ...]]: + def expr_sql(): + expr_parts = [] + params = [] + for arg in self.source_expressions: + arg_sql, arg_params = compiler.compile(arg) + expr_parts.append(arg_sql) + params.extend(arg_params) + return self.arg_joiner.join(expr_parts) + if self.filter: extra_context["distinct"] = "DISTINCT " if self.distinct else "" copy = self.copy() @@ -63,6 +74,19 @@ def as_sql( source_expressions = copy.get_source_expressions() condition = When(self.filter, then=source_expressions[0]) copy.set_source_expressions([Case(condition)] + source_expressions[1:]) + + extra_context["order_by"] = ( + f" ORDER BY {expr_sql()} {self.ordering}" + if self.ordering else + "" + ) + + extra_context["separator"] = ( + f" SEPARATOR '{self.separator}' " + if self.separator else + "" + ) + return super(Aggregate, copy).as_sql(compiler, connection, **extra_context) connection.ops.check_expression_support(self) @@ -70,13 +94,7 @@ def as_sql( if self.distinct: sql.append("DISTINCT ") - expr_parts = [] - params = [] - for arg in self.source_expressions: - arg_sql, arg_params = compiler.compile(arg) - expr_parts.append(arg_sql) - params.extend(arg_params) - expr_sql = self.arg_joiner.join(expr_parts) + expr_sql = expr_sql() sql.append(expr_sql) From 451597bbe1410b3055deda7ec3dead3ce3329d4c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 14 Sep 2022 13:12:54 +0000 Subject: [PATCH 7/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/django_mysql/models/aggregates.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/django_mysql/models/aggregates.py b/src/django_mysql/models/aggregates.py index 1efaa3d3..a1cdab9a 100644 --- a/src/django_mysql/models/aggregates.py +++ b/src/django_mysql/models/aggregates.py @@ -76,15 +76,11 @@ def expr_sql(): copy.set_source_expressions([Case(condition)] + source_expressions[1:]) extra_context["order_by"] = ( - f" ORDER BY {expr_sql()} {self.ordering}" - if self.ordering else - "" + f" ORDER BY {expr_sql()} {self.ordering}" if self.ordering else "" ) extra_context["separator"] = ( - f" SEPARATOR '{self.separator}' " - if self.separator else - "" + f" SEPARATOR '{self.separator}' " if self.separator else "" ) return super(Aggregate, copy).as_sql(compiler, connection, **extra_context) From 96e7c9a349d4f356985b378c098206b18d60915f Mon Sep 17 00:00:00 2001 From: caramdache Date: Thu, 15 Sep 2022 09:30:52 +0200 Subject: [PATCH 8/9] Fix bug introduced in legacy path due to the introduction of an auxiliary function. --- src/django_mysql/models/aggregates.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/django_mysql/models/aggregates.py b/src/django_mysql/models/aggregates.py index a1cdab9a..38a0efe9 100644 --- a/src/django_mysql/models/aggregates.py +++ b/src/django_mysql/models/aggregates.py @@ -65,7 +65,7 @@ def expr_sql(): arg_sql, arg_params = compiler.compile(arg) expr_parts.append(arg_sql) params.extend(arg_params) - return self.arg_joiner.join(expr_parts) + return self.arg_joiner.join(expr_parts), params if self.filter: extra_context["distinct"] = "DISTINCT " if self.distinct else "" @@ -75,12 +75,18 @@ def expr_sql(): condition = When(self.filter, then=source_expressions[0]) copy.set_source_expressions([Case(condition)] + source_expressions[1:]) + expr_sql, _ = expr_sql() + extra_context["order_by"] = ( - f" ORDER BY {expr_sql()} {self.ordering}" if self.ordering else "" + f" ORDER BY {expr_sql} {self.ordering}" + if self.ordering else + "" ) extra_context["separator"] = ( - f" SEPARATOR '{self.separator}' " if self.separator else "" + f" SEPARATOR '{self.separator}' " + if self.separator else + "" ) return super(Aggregate, copy).as_sql(compiler, connection, **extra_context) @@ -90,7 +96,7 @@ def expr_sql(): if self.distinct: sql.append("DISTINCT ") - expr_sql = expr_sql() + expr_sql, params = expr_sql() sql.append(expr_sql) From 8c5c5e13dda14224ad653cb65502b6d3a81ade42 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 15 Sep 2022 07:31:09 +0000 Subject: [PATCH 9/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/django_mysql/models/aggregates.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/src/django_mysql/models/aggregates.py b/src/django_mysql/models/aggregates.py index 38a0efe9..a3b23307 100644 --- a/src/django_mysql/models/aggregates.py +++ b/src/django_mysql/models/aggregates.py @@ -76,17 +76,13 @@ def expr_sql(): copy.set_source_expressions([Case(condition)] + source_expressions[1:]) expr_sql, _ = expr_sql() - + extra_context["order_by"] = ( - f" ORDER BY {expr_sql} {self.ordering}" - if self.ordering else - "" + f" ORDER BY {expr_sql} {self.ordering}" if self.ordering else "" ) extra_context["separator"] = ( - f" SEPARATOR '{self.separator}' " - if self.separator else - "" + f" SEPARATOR '{self.separator}' " if self.separator else "" ) return super(Aggregate, copy).as_sql(compiler, connection, **extra_context)