diff --git a/docs/src/piccolo/query_clauses/as_of.rst b/docs/src/piccolo/query_clauses/as_of.rst index 2d8407e42..97527033b 100644 --- a/docs/src/piccolo/query_clauses/as_of.rst +++ b/docs/src/piccolo/query_clauses/as_of.rst @@ -3,6 +3,8 @@ as_of ===== +.. note:: Cockroach only. + You can use ``as_of`` clause with the following queries: * :ref:`Select` @@ -21,5 +23,3 @@ This generates an ``AS OF SYSTEM TIME`` clause. See `documentation `_. This is very useful for performance, as it will reduce transaction contention across a cluster. - -Currently only supported on Cockroach Engine. diff --git a/docs/src/piccolo/query_clauses/index.rst b/docs/src/piccolo/query_clauses/index.rst index abb3fa18b..f9167ff63 100644 --- a/docs/src/piccolo/query_clauses/index.rst +++ b/docs/src/piccolo/query_clauses/index.rst @@ -26,6 +26,7 @@ by modifying the return values. ./freeze ./group_by ./offset + ./on_conflict ./output ./returning diff --git a/docs/src/piccolo/query_clauses/on_conflict.rst b/docs/src/piccolo/query_clauses/on_conflict.rst new file mode 100644 index 000000000..9ae8444e3 --- /dev/null +++ b/docs/src/piccolo/query_clauses/on_conflict.rst @@ -0,0 +1,229 @@ +.. _on_conflict: + +on_conflict +=========== + +.. hint:: This is an advanced topic, and first time learners of Piccolo + can skip if they want. + +You can use the ``on_conflict`` clause with the following queries: + +* :ref:`Insert` + +Introduction +------------ + +When inserting rows into a table, if a unique constraint fails on one or more +of the rows, then the insertion fails. + +Using the ``on_conflict`` clause, we can instead tell the database to ignore +the error (using ``DO NOTHING``), or to update the row (using ``DO UPDATE``). + +This is sometimes called an **upsert** (update if it already exists else insert). + +Example data +------------ + +If we have the following table: + +.. code-block:: python + + class Band(Table): + name = Varchar(unique=True) + popularity = Integer() + +With this data: + +.. csv-table:: + :file: ./on_conflict/bands.csv + +Let's try inserting another row with the same ``name``, and we'll get an error: + +.. code-block:: python + + >>> await Band.insert( + ... Band(name="Pythonistas", popularity=1200) + ... ) + Unique constraint error! + +``DO NOTHING`` +-------------- + +To ignore the error: + +.. code-block:: python + + >>> await Band.insert( + ... Band(name="Pythonistas", popularity=1200) + ... ).on_conflict( + ... action="DO NOTHING" + ... ) + +If we fetch the data from the database, we'll see that it hasn't changed: + +.. code-block:: python + + >>> await Band.select().where(Band.name == "Pythonistas").first() + {'id': 1, 'name': 'Pythonistas', 'popularity': 1000} + + +``DO UPDATE`` +------------- + +Instead, if we want to update the ``popularity``: + +.. code-block:: python + + >>> await Band.insert( + ... Band(name="Pythonistas", popularity=1200) + ... ).on_conflict( + ... action="DO UPDATE", + ... values=[Band.popularity] + ... ) + +If we fetch the data from the database, we'll see that it was updated: + +.. code-block:: python + + >>> await Band.select().where(Band.name == "Pythonistas").first() + {'id': 1, 'name': 'Pythonistas', 'popularity': 1200} + +``target`` +---------- + +Using the ``target`` argument, we can specify which constraint we're concerned +with. By specifying ``target=Band.name`` we're only concerned with the unique +constraint for the ``band`` column. If you omit the ``target`` argument, then +it works for all constraints on the table. + +.. code-block:: python + :emphasize-lines: 5 + + >>> await Band.insert( + ... Band(name="Pythonistas", popularity=1200) + ... ).on_conflict( + ... action="DO NOTHING", + ... target=Band.name + ... ) + +If you want to target a composite unique constraint, you can do so by passing +in a tuple of columns: + +.. code-block:: python + :emphasize-lines: 5 + + >>> await Band.insert( + ... Band(name="Pythonistas", popularity=1200) + ... ).on_conflict( + ... action="DO NOTHING", + ... target=(Band.name, Band.popularity) + ... ) + +You can also specify the name of a constraint using a string: + +.. code-block:: python + :emphasize-lines: 5 + + >>> await Band.insert( + ... Band(name="Pythonistas", popularity=1200) + ... ).on_conflict( + ... action="DO NOTHING", + ... target='some_constraint' + ... ) + +``values`` +---------- + +This lets us specify which values to update when a conflict occurs. + +By specifying a :class:`Column `, this means that +the new value for that column will be used: + +.. code-block:: python + :emphasize-lines: 6 + + # The new popularity will be 1200. + >>> await Band.insert( + ... Band(name="Pythonistas", popularity=1200) + ... ).on_conflict( + ... action="DO UPDATE", + ... values=[Band.popularity] + ... ) + +Instead, we can specify a custom value using a tuple: + +.. code-block:: python + :emphasize-lines: 6 + + # The new popularity will be 1111. + >>> await Band.insert( + ... Band(name="Pythonistas", popularity=1200) + ... ).on_conflict( + ... action="DO UPDATE", + ... values=[(Band.popularity, 1111)] + ... ) + +If we want to update all of the values, we can use :meth:`all_columns`. + +.. code-block:: python + :emphasize-lines: 5 + + >>> await Band.insert( + ... Band(id=1, name="Pythonistas", popularity=1200) + ... ).on_conflict( + ... action="DO UPDATE", + ... values=Band.all_columns() + ... ) + +``where`` +--------- + +This can be used with ``DO UPDATE``. It gives us more control over whether the +update should be made: + +.. code-block:: python + :emphasize-lines: 6 + + >>> await Band.insert( + ... Band(id=1, name="Pythonistas", popularity=1200) + ... ).on_conflict( + ... action="DO UPDATE", + ... values=[Band.popularity], + ... where=Band.popularity < 1000 + ... ) + +Multiple ``on_conflict`` clauses +-------------------------------- + +SQLite allows you to specify multiple ``ON CONFLICT`` clauses, but Postgres and +Cockroach don't. + +.. code-block:: python + + >>> await Band.insert( + ... Band(name="Pythonistas", popularity=1200) + ... ).on_conflict( + ... action="DO UPDATE", + ... ... + ... ).on_conflict( + ... action="DO NOTHING", + ... ... + ... ) + +Learn more +---------- + +* `Postgres docs `_ +* `Cockroach docs `_ +* `SQLite docs `_ + +Source +------ + +.. currentmodule:: piccolo.query.methods.insert + +.. automethod:: Insert.on_conflict + +.. autoclass:: OnConflictAction + :members: + :undoc-members: diff --git a/docs/src/piccolo/query_clauses/on_conflict/bands.csv b/docs/src/piccolo/query_clauses/on_conflict/bands.csv new file mode 100644 index 000000000..d796928a1 --- /dev/null +++ b/docs/src/piccolo/query_clauses/on_conflict/bands.csv @@ -0,0 +1,2 @@ +id,name,popularity +1,Pythonistas,1000 diff --git a/docs/src/piccolo/query_types/insert.rst b/docs/src/piccolo/query_types/insert.rst index eda460de1..f8d1de007 100644 --- a/docs/src/piccolo/query_types/insert.rst +++ b/docs/src/piccolo/query_types/insert.rst @@ -3,33 +3,47 @@ Insert ====== -This is used to insert rows into the table. - -.. code-block:: python - - >>> await Band.insert(Band(name="Pythonistas")) - [{'id': 3}] - -We can insert multiple rows in one go: +This is used to bulk insert rows into the table: .. code-block:: python await Band.insert( + Band(name="Pythonistas") Band(name="Darts"), Band(name="Gophers") ) ------------------------------------------------------------------------------- -add ---- +``add`` +------- -You can also compose it as follows: +If we later decide to insert additional rows, we can use the ``add`` method: .. code-block:: python - await Band.insert().add( - Band(name="Darts") - ).add( - Band(name="Gophers") - ) + query = Band.insert(Band(name="Pythonistas")) + + if other_bands: + query = query.add( + Band(name="Darts"), + Band(name="Gophers") + ) + + await query + +------------------------------------------------------------------------------- + +Query clauses +------------- + +on_conflict +~~~~~~~~~~~ + +See :ref:`On_Conflict`. + + +returning +~~~~~~~~~ + +See :ref:`Returning`. diff --git a/piccolo/apps/asgi/commands/new.py b/piccolo/apps/asgi/commands/new.py index aedabdf93..f4b0e16e3 100644 --- a/piccolo/apps/asgi/commands/new.py +++ b/piccolo/apps/asgi/commands/new.py @@ -12,7 +12,7 @@ SERVERS = ["uvicorn", "Hypercorn"] ROUTERS = ["starlette", "fastapi", "blacksheep", "litestar"] ROUTER_DEPENDENCIES = { - "litestar": ["litestar>=2.0.0a3"], + "litestar": ["litestar==2.0.0a3"], } diff --git a/piccolo/query/methods/insert.py b/piccolo/query/methods/insert.py index 9f31f445a..283bdde0b 100644 --- a/piccolo/query/methods/insert.py +++ b/piccolo/query/methods/insert.py @@ -2,9 +2,16 @@ import typing as t -from piccolo.custom_types import TableInstance +from typing_extensions import Literal + +from piccolo.custom_types import Combinable, TableInstance from piccolo.query.base import Query -from piccolo.query.mixins import AddDelegate, ReturningDelegate +from piccolo.query.mixins import ( + AddDelegate, + OnConflictAction, + OnConflictDelegate, + ReturningDelegate, +) from piccolo.querystring import QueryString if t.TYPE_CHECKING: # pragma: no cover @@ -15,7 +22,7 @@ class Insert( t.Generic[TableInstance], Query[TableInstance, t.List[t.Dict[str, t.Any]]] ): - __slots__ = ("add_delegate", "returning_delegate") + __slots__ = ("add_delegate", "on_conflict_delegate", "returning_delegate") def __init__( self, table: t.Type[TableInstance], *instances: TableInstance, **kwargs @@ -23,6 +30,7 @@ def __init__( super().__init__(table, **kwargs) self.add_delegate = AddDelegate() self.returning_delegate = ReturningDelegate() + self.on_conflict_delegate = OnConflictDelegate() self.add(*instances) ########################################################################### @@ -36,6 +44,43 @@ def returning(self: Self, *columns: Column) -> Self: self.returning_delegate.returning(columns) return self + def on_conflict( + self: Self, + target: t.Optional[t.Union[str, Column, t.Tuple[Column, ...]]] = None, + action: t.Union[ + OnConflictAction, Literal["DO NOTHING", "DO UPDATE"] + ] = OnConflictAction.do_nothing, + values: t.Optional[ + t.Sequence[t.Union[Column, t.Tuple[Column, t.Any]]] + ] = None, + where: t.Optional[Combinable] = None, + ) -> Self: + if ( + self.engine_type == "sqlite" + and self.table._meta.db.get_version_sync() < 3.24 + ): + raise NotImplementedError( + "SQLite versions lower than 3.24 don't support ON CONFLICT" + ) + + if ( + self.engine_type in ("postgres", "cockroach") + and len(self.on_conflict_delegate._on_conflict.on_conflict_items) + == 1 + ): + raise NotImplementedError( + "Postgres and Cockroach only support a single ON CONFLICT " + "clause." + ) + + self.on_conflict_delegate.on_conflict( + target=target, + action=action, + values=values, + where=where, + ) + return self + ########################################################################### def _raw_response_callback(self, results): @@ -70,16 +115,27 @@ def default_querystrings(self) -> t.Sequence[QueryString]: engine_type = self.engine_type + on_conflict = self.on_conflict_delegate._on_conflict + if on_conflict.on_conflict_items: + querystring = QueryString( + "{}{}", + querystring, + on_conflict.querystring, + query_type="insert", + table=self.table, + ) + if engine_type in ("postgres", "cockroach") or ( engine_type == "sqlite" and self.table._meta.db.get_version_sync() >= 3.35 ): - if self.returning_delegate._returning: + returning = self.returning_delegate._returning + if returning: return [ QueryString( "{}{}", querystring, - self.returning_delegate._returning.querystring, + returning.querystring, query_type="insert", table=self.table, ) diff --git a/piccolo/query/methods/objects.py b/piccolo/query/methods/objects.py index db9e43ccc..6892fc95c 100644 --- a/piccolo/query/methods/objects.py +++ b/piccolo/query/methods/objects.py @@ -230,6 +230,8 @@ def callback( return self def as_of(self, interval: str = "-1s") -> Objects: + if self.engine_type != "cockroach": + raise NotImplementedError("Only CockroachDB supports AS OF") self.as_of_delegate.as_of(interval) return self diff --git a/piccolo/query/methods/select.py b/piccolo/query/methods/select.py index f94e173dc..4e1949f55 100644 --- a/piccolo/query/methods/select.py +++ b/piccolo/query/methods/select.py @@ -356,13 +356,8 @@ def columns(self: Self, *columns: t.Union[Selectable, str]) -> Self: def distinct( self: Self, *, on: t.Optional[t.Sequence[Column]] = None ) -> Self: - if on is not None and self.engine_type not in ( - "postgres", - "cockroach", - ): - raise ValueError( - "Only Postgres and Cockroach supports DISTINCT ON" - ) + if on is not None and self.engine_type == "sqlite": + raise NotImplementedError("SQLite doesn't support DISTINCT ON") self.distinct_delegate.distinct(enabled=True, on=on) return self @@ -377,6 +372,9 @@ def group_by(self: Self, *columns: t.Union[Column, str]) -> Self: return self def as_of(self: Self, interval: str = "-1s") -> Self: + if self.engine_type != "cockroach": + raise NotImplementedError("Only CockroachDB supports AS OF") + self.as_of_delegate.as_of(interval) return self diff --git a/piccolo/query/mixins.py b/piccolo/query/mixins.py index 109b48629..43d126436 100644 --- a/piccolo/query/mixins.py +++ b/piccolo/query/mixins.py @@ -7,6 +7,8 @@ from dataclasses import dataclass, field from enum import Enum, auto +from typing_extensions import Literal + from piccolo.columns import And, Column, Or, Where from piccolo.columns.column_types import ForeignKey from piccolo.custom_types import Combinable @@ -581,9 +583,10 @@ class OffsetDelegate: Typically used in conjunction with order_by and limit. - Example usage: + Example usage:: + + .offset(100) - .offset(100) """ _offset: t.Optional[Offset] = None @@ -613,12 +616,173 @@ def __str__(self): @dataclass class GroupByDelegate: """ - Used to group results - needed when doing aggregation. + Used to group results - needed when doing aggregation:: + + .group_by(Band.name) - .group_by(Band.name) """ _group_by: t.Optional[GroupBy] = None def group_by(self, *columns: Column): self._group_by = GroupBy(columns=columns) + + +class OnConflictAction(str, Enum): + """ + Specify which action to take on conflict. + """ + + do_nothing = "DO NOTHING" + do_update = "DO UPDATE" + + +@dataclass +class OnConflictItem: + target: t.Optional[t.Union[str, Column, t.Tuple[Column, ...]]] = None + action: t.Optional[OnConflictAction] = None + values: t.Optional[ + t.Sequence[t.Union[Column, t.Tuple[Column, t.Any]]] + ] = None + where: t.Optional[Combinable] = None + + @property + def target_string(self) -> str: + target = self.target + assert target + + def to_string(value) -> str: + if isinstance(value, Column): + return f'"{value._meta.db_column_name}"' + else: + raise ValueError("OnConflict.target isn't a valid type") + + if isinstance(target, str): + return f'ON CONSTRAINT "{target}"' + elif isinstance(target, Column): + return f"({to_string(target)})" + elif isinstance(target, tuple): + columns_str = ", ".join([to_string(i) for i in target]) + return f"({columns_str})" + else: + raise ValueError("OnConflict.target isn't a valid type") + + @property + def action_string(self) -> QueryString: + action = self.action + if isinstance(action, OnConflictAction): + if action == OnConflictAction.do_nothing: + return QueryString(OnConflictAction.do_nothing.value) + elif action == OnConflictAction.do_update: + values = [] + query = f"{OnConflictAction.do_update.value} SET" + + if not self.values: + raise ValueError("No values specified for `on conflict`") + + for value in self.values: + if isinstance(value, Column): + column_name = value._meta.db_column_name + query += f' "{column_name}"=EXCLUDED."{column_name}",' + elif isinstance(value, tuple): + column = value[0] + value_ = value[1] + if isinstance(column, Column): + column_name = column._meta.db_column_name + else: + raise ValueError("Unsupported column type") + + query += f' "{column_name}"={{}},' + values.append(value_) + + return QueryString(query.rstrip(","), *values) + + raise ValueError("OnConflict.action isn't a valid type") + + @property + def querystring(self) -> QueryString: + query = " ON CONFLICT" + values = [] + + if self.target: + query += f" {self.target_string}" + + if self.action: + query += " {}" + values.append(self.action_string) + + if self.where: + query += " WHERE {}" + values.append(self.where.querystring) + + return QueryString(query, *values) + + def __str__(self) -> str: + return self.querystring.__str__() + + +@dataclass +class OnConflict: + """ + Multiple `ON CONFLICT` statements are allowed - which is why we have this + parent class. + """ + + on_conflict_items: t.List[OnConflictItem] = field(default_factory=list) + + @property + def querystring(self) -> QueryString: + query = "".join("{}" for i in self.on_conflict_items) + return QueryString( + query, *[i.querystring for i in self.on_conflict_items] + ) + + def __str__(self) -> str: + return self.querystring.__str__() + + +@dataclass +class OnConflictDelegate: + """ + Used with insert queries to specify what to do when a query fails due to + a constraint:: + + .on_conflict(action='DO NOTHING') + + .on_conflict(action='DO UPDATE', values=[Band.popularity]) + + .on_conflict(action='DO UPDATE', values=[(Band.popularity, 1)]) + + """ + + _on_conflict: OnConflict = field(default_factory=OnConflict) + + def on_conflict( + self, + target: t.Optional[t.Union[str, Column, t.Tuple[Column, ...]]] = None, + action: t.Union[ + OnConflictAction, Literal["DO NOTHING", "DO UPDATE"] + ] = OnConflictAction.do_nothing, + values: t.Optional[ + t.Sequence[t.Union[Column, t.Tuple[Column, t.Any]]] + ] = None, + where: t.Optional[Combinable] = None, + ): + action_: OnConflictAction + if isinstance(action, OnConflictAction): + action_ = action + elif isinstance(action, str): + action_ = OnConflictAction(action.upper()) + else: + raise ValueError("Unrecognised `on conflict` action.") + + if where and action_ == OnConflictAction.do_nothing: + raise ValueError( + "The `where` option can only be used with DO NOTHING." + ) + + self._on_conflict.on_conflict_items.append( + OnConflictItem( + target=target, action=action_, values=values, where=where + ) + ) diff --git a/tests/table/test_insert.py b/tests/table/test_insert.py index 474497f25..1c5fab732 100644 --- a/tests/table/test_insert.py +++ b/tests/table/test_insert.py @@ -1,8 +1,22 @@ +import sqlite3 +from unittest import TestCase + import pytest -from tests.base import DBTestCase, engine_version_lt, is_running_sqlite +from piccolo.columns import Integer, Varchar +from piccolo.query.methods.insert import OnConflictAction +from piccolo.table import Table +from piccolo.utils.lazy_loader import LazyLoader +from tests.base import ( + DBTestCase, + engine_version_lt, + engines_only, + is_running_sqlite, +) from tests.example_apps.music.tables import Band, Manager +asyncpg = LazyLoader("asyncpg", globals(), "asyncpg") + class TestInsert(DBTestCase): def test_insert(self): @@ -76,3 +90,385 @@ def test_insert_returning_alias(self): ) self.assertListEqual(response, [{"manager_name": "Maz"}]) + + +@pytest.mark.skipif( + is_running_sqlite() and engine_version_lt(3.24), + reason="SQLite version not supported", +) +class TestOnConflict(TestCase): + class Band(Table): + name = Varchar(unique=True) + popularity = Integer() + + def setUp(self) -> None: + Band = self.Band + Band.create_table().run_sync() + self.band = Band({Band.name: "Pythonistas", Band.popularity: 1000}) + self.band.save().run_sync() + + def tearDown(self) -> None: + Band = self.Band + Band.alter().drop_table().run_sync() + + def test_do_update(self): + """ + Make sure that `DO UPDATE` works. + """ + Band = self.Band + + new_popularity = self.band.popularity + 1000 + + Band.insert( + Band(name=self.band.name, popularity=new_popularity) + ).on_conflict( + target=Band.name, + action="DO UPDATE", + values=[Band.popularity], + ).run_sync() + + self.assertListEqual( + Band.select().run_sync(), + [ + { + "id": self.band.id, + "name": self.band.name, + "popularity": new_popularity, # changed + } + ], + ) + + def test_do_update_tuple_values(self): + """ + Make sure we can use tuples in ``values``. + """ + Band = self.Band + + new_popularity = self.band.popularity + 1000 + new_name = "Rustaceans" + + Band.insert( + Band( + id=self.band.id, + name=new_name, + popularity=new_popularity, + ) + ).on_conflict( + action="DO UPDATE", + target=Band.id, + values=[ + (Band.name, new_name), + (Band.popularity, new_popularity + 2000), + ], + ).run_sync() + + self.assertListEqual( + Band.select().run_sync(), + [ + { + "id": self.band.id, + "name": new_name, + "popularity": new_popularity + 2000, + } + ], + ) + + def test_do_update_no_values(self): + """ + Make sure that `DO UPDATE` with no `values` raises an exception. + """ + Band = self.Band + + new_popularity = self.band.popularity + 1000 + + with self.assertRaises(ValueError) as manager: + Band.insert( + Band(name=self.band.name, popularity=new_popularity) + ).on_conflict( + target=Band.name, + action="DO UPDATE", + ).run_sync() + + self.assertEqual( + manager.exception.__str__(), + "No values specified for `on conflict`", + ) + + @engines_only("postgres", "cockroach") + def test_target_tuple(self): + """ + Make sure that a composite unique constraint can be used as a target. + + We only run it on Postgres and Cockroach because we use ALTER TABLE + to add a contraint, which SQLite doesn't support. + """ + Band = self.Band + + # Add a composite unique constraint: + Band.raw( + "ALTER TABLE band ADD CONSTRAINT id_name_unique UNIQUE (id, name)" + ).run_sync() + + Band.insert( + Band( + id=self.band.id, + name=self.band.name, + popularity=self.band.popularity, + ) + ).on_conflict( + target=(Band.id, Band.name), + action="DO NOTHING", + ).run_sync() + + @engines_only("postgres", "cockroach") + def test_target_string(self): + """ + Make sure we can explicitly specify the name of target constraint using + a string. + + We just test this on Postgres for now, as we have to get the constraint + name from the database. + """ + Band = self.Band + + constraint_name = Band.raw( + """ + SELECT constraint_name + FROM information_schema.constraint_column_usage + WHERE column_name = 'name' + AND table_name = 'band'; + """ + ).run_sync()[0]["constraint_name"] + + query = Band.insert(Band(name=self.band.name)).on_conflict( + target=constraint_name, + action="DO NOTHING", + ) + self.assertIn(f'ON CONSTRAINT "{constraint_name}"', query.__str__()) + query.run_sync() + + def test_violate_non_target(self): + """ + Make sure that if we specify a target constraint, but violate a + different constraint, then we still get the error. + """ + Band = self.Band + + new_popularity = self.band.popularity + 1000 + + with self.assertRaises(Exception) as manager: + Band.insert( + Band(name=self.band.name, popularity=new_popularity) + ).on_conflict( + target=Band.id, # Target the primary key instead. + action="DO UPDATE", + values=[Band.popularity], + ).run_sync() + + if self.Band._meta.db.engine_type in ("postgres", "cockroach"): + self.assertIsInstance( + manager.exception, asyncpg.exceptions.UniqueViolationError + ) + elif self.Band._meta.db.engine_type == "sqlite": + self.assertIsInstance(manager.exception, sqlite3.IntegrityError) + + def test_where(self): + """ + Make sure we can pass in a `where` argument. + """ + Band = self.Band + + new_popularity = self.band.popularity + 1000 + + query = Band.insert( + Band(name=self.band.name, popularity=new_popularity) + ).on_conflict( + target=Band.name, + action="DO UPDATE", + values=[Band.popularity], + where=Band.popularity < self.band.popularity, + ) + + self.assertIn( + f'WHERE "band"."popularity" < {self.band.popularity}', + query.__str__(), + ) + + query.run_sync() + + def test_do_nothing_where(self): + """ + Make sure an error is raised if `where` is used with `DO NOTHING`. + """ + Band = self.Band + + with self.assertRaises(ValueError) as manager: + Band.insert(Band()).on_conflict( + action="DO NOTHING", + where=Band.popularity < self.band.popularity, + ) + + self.assertEqual( + manager.exception.__str__(), + "The `where` option can only be used with DO NOTHING.", + ) + + def test_do_nothing(self): + """ + Make sure that `DO NOTHING` works. + """ + Band = self.Band + + new_popularity = self.band.popularity + 1000 + + Band.insert( + Band(name="Pythonistas", popularity=new_popularity) + ).on_conflict(action="DO NOTHING").run_sync() + + self.assertListEqual( + Band.select().run_sync(), + [ + { + "id": self.band.id, + "name": self.band.name, + "popularity": self.band.popularity, + } + ], + ) + + @engines_only("sqlite") + def test_multiple_do_update(self): + """ + Make sure multiple `ON CONFLICT` clauses work for SQLite. + """ + Band = self.Band + + new_popularity = self.band.popularity + 1000 + + # Conflicting with name - should update. + Band.insert( + Band(name="Pythonistas", popularity=new_popularity) + ).on_conflict(action="DO NOTHING", target=Band.id).on_conflict( + action="DO UPDATE", target=Band.name, values=[Band.popularity] + ).run_sync() + + self.assertListEqual( + Band.select().run_sync(), + [ + { + "id": self.band.id, + "name": self.band.name, + "popularity": new_popularity, # changed + } + ], + ) + + @engines_only("sqlite") + def test_multiple_do_nothing(self): + """ + Make sure multiple `ON CONFLICT` clauses work for SQLite. + """ + Band = self.Band + + new_popularity = self.band.popularity + 1000 + + # Conflicting with ID - should be ignored. + Band.insert( + Band( + id=self.band.id, + name="Pythonistas", + popularity=new_popularity, + ) + ).on_conflict(action="DO NOTHING", target=Band.id).on_conflict( + action="DO UPDATE", + target=Band.name, + values=[Band.popularity], + ).run_sync() + + self.assertListEqual( + Band.select().run_sync(), + [ + { + "id": self.band.id, + "name": self.band.name, + "popularity": self.band.popularity, + } + ], + ) + + @engines_only("postgres", "cockroach") + def test_mutiple_error(self): + """ + Postgres and Cockroach don't support multiple `ON CONFLICT` clauses. + """ + with self.assertRaises(NotImplementedError) as manager: + Band = self.Band + + Band.insert(Band()).on_conflict(action="DO NOTHING").on_conflict( + action="DO UPDATE", + ).run_sync() + + assert manager.exception.__str__() == ( + "Postgres and Cockroach only support a single ON CONFLICT clause." + ) + + def test_all_columns(self): + """ + We can use ``all_columns`` instead of specifying the ``values`` + manually. + """ + Band = self.Band + + new_popularity = self.band.popularity + 1000 + new_name = "Rustaceans" + + # Conflicting with ID - should be ignored. + q = Band.insert( + Band( + id=self.band.id, + name=new_name, + popularity=new_popularity, + ) + ).on_conflict( + action="DO UPDATE", + target=Band.id, + values=Band.all_columns(), + ) + q.run_sync() + + self.assertListEqual( + Band.select().run_sync(), + [ + { + "id": self.band.id, + "name": new_name, + "popularity": new_popularity, + } + ], + ) + + def test_enum(self): + """ + A string literal can be passed in, or an enum, to determine the action. + Make sure that the enum works. + """ + Band = self.Band + + Band.insert( + Band( + id=self.band.id, + name=self.band.name, + popularity=self.band.popularity, + ) + ).on_conflict(action=OnConflictAction.do_nothing).run_sync() + + self.assertListEqual( + Band.select().run_sync(), + [ + { + "id": self.band.id, + "name": self.band.name, + "popularity": self.band.popularity, + } + ], + ) diff --git a/tests/table/test_select.py b/tests/table/test_select.py index 2ce429c2d..892972aec 100644 --- a/tests/table/test_select.py +++ b/tests/table/test_select.py @@ -1364,12 +1364,12 @@ def test_distinct_on_sqlite(self): SQLite doesn't support ``DISTINCT ON``, so a ``ValueError`` should be raised. """ - with self.assertRaises(ValueError) as manager: + with self.assertRaises(NotImplementedError) as manager: Album.select().distinct(on=[Album.band]) self.assertEqual( manager.exception.__str__(), - "Only Postgres and Cockroach supports DISTINCT ON", + "SQLite doesn't support DISTINCT ON", ) @engines_only("postgres", "cockroach")