diff --git a/morango/models/core.py b/morango/models/core.py index 73f175f..67e1b54 100644 --- a/morango/models/core.py +++ b/morango/models/core.py @@ -773,6 +773,10 @@ class RecordMaxCounter(AbstractCounter): store_model = models.ForeignKey(Store, on_delete=models.CASCADE) + @property + def unique_key(self): + return f"{self.instance_id}:{self.store_model_id}" + class Meta: unique_together = ("store_model", "instance_id") diff --git a/morango/registry.py b/morango/registry.py index 5614c81..0196d27 100644 --- a/morango/registry.py +++ b/morango/registry.py @@ -5,7 +5,9 @@ import inspect import sys from collections import OrderedDict +from typing import Generator +from django.db.models import QuerySet from django.db.models.fields.related import ForeignKey from morango.constants import transfer_stages @@ -82,6 +84,14 @@ def get_models(self, profile): self.check_models_ready(profile) return list(self.profile_models.get(profile, {}).values()) + def get_model_querysets(self, profile) -> Generator[QuerySet, None, None]: + """ + Method for future enhancement to iterate over model's and their querysets in a fashion + (particularly, an order) that is aware of FK dependencies. + """ + for model in self.get_models(profile): + yield model.syncing_objects.all() + def _insert_model_in_dependency_order(self, model, profile): # When we add models to be synced, we need to make sure # that models that depend on other models are synced AFTER diff --git a/morango/sync/controller.py b/morango/sync/controller.py index 92e38e6..fa64ede 100644 --- a/morango/sync/controller.py +++ b/morango/sync/controller.py @@ -6,47 +6,35 @@ from morango.constants import transfer_statuses from morango.registry import session_middleware from morango.sync.operations import _deserialize_from_store -from morango.sync.operations import _serialize_into_store from morango.sync.operations import OperationLogger +from morango.sync.stream.serialize import serialize_into_store from morango.sync.utils import SyncSignalGroup from morango.utils import _assert - logger = logging.getLogger(__name__) -def _self_referential_fk(klass_model): - """ - Return whether this model has a self ref FK, and the name for the field - """ - for f in klass_model._meta.concrete_fields: - if f.related_model: - if issubclass(klass_model, f.related_model): - return f.attname - return None - - class MorangoProfileController(object): def __init__(self, profile): _assert(profile, "profile needs to be defined.") self.profile = profile - def serialize_into_store(self, filter=None): + def serialize_into_store(self, sync_filter=None): """ Takes data from app layer and serializes the models into the store. """ with OperationLogger("Serializing records", "Serialization complete"): - _serialize_into_store(self.profile, filter=filter) + serialize_into_store(self.profile, sync_filter=sync_filter) - def deserialize_from_store(self, skip_erroring=False, filter=None): + def deserialize_from_store(self, skip_erroring=False, sync_filter=None): """ Takes data from the store and integrates into the application. 
""" with OperationLogger("Deserializing records", "Deserialization complete"): # we first serialize to avoid deserialization merge conflicts - _serialize_into_store(self.profile, filter=filter) + serialize_into_store(self.profile, sync_filter=sync_filter) _deserialize_from_store( - self.profile, filter=filter, skip_erroring=skip_erroring + self.profile, filter=sync_filter, skip_erroring=skip_erroring ) def create_network_connection(self, base_url, **kwargs): @@ -217,7 +205,7 @@ def proceed_to_and_wait_for( if tries >= max_interval_tries: sleep(max_interval) else: - sleep(0.3 * (2 ** tries - 1)) + sleep(0.3 * (2**tries - 1)) result = self.proceed_to(target_stage, context=context) tries += 1 if callable(callback): diff --git a/morango/sync/db.py b/morango/sync/db.py new file mode 100644 index 0000000..5bf56ed --- /dev/null +++ b/morango/sync/db.py @@ -0,0 +1,39 @@ +import logging +from contextlib import contextmanager + +from django.db import connection +from django.db import transaction + +from morango.sync.backends.utils import load_backend +from morango.sync.utils import lock_partitions + + +logger = logging.getLogger(__name__) + +DBBackend = load_backend(connection) + + +@contextmanager +def begin_transaction(sync_filter, isolated=False, shared_lock=False): + """ + Starts a transaction, sets the transaction isolation level to repeatable read, and locks + affected partitions + + :param sync_filter: The filter for filtering applicable records of the sync + :type sync_filter: morango.models.certificates.Filter|None + :param isolated: Whether to alter the transaction isolation to repeatable-read + :type isolated: bool + :param shared_lock: Whether the advisory lock should be exclusive or shared + :type shared_lock: bool + """ + if isolated: + # when isolation is requested, we modify the transaction isolation of the connection for the + # duration of the transaction + with DBBackend._set_transaction_repeatable_read(): + with transaction.atomic(savepoint=False): + lock_partitions(DBBackend, sync_filter=sync_filter, shared=shared_lock) + yield + else: + with transaction.atomic(): + lock_partitions(DBBackend, sync_filter=sync_filter, shared=shared_lock) + yield diff --git a/morango/sync/operations.py b/morango/sync/operations.py index a5fa0c7..7707ecc 100644 --- a/morango/sync/operations.py +++ b/morango/sync/operations.py @@ -4,12 +4,9 @@ import logging import uuid from collections import defaultdict -from contextlib import contextmanager from django.core import exceptions -from django.core.serializers.json import DjangoJSONEncoder from django.db import connection -from django.db import transaction from django.db.models import CharField from django.db.models import Q from django.db.models import signals @@ -22,7 +19,6 @@ from morango.constants import transfer_statuses from morango.constants.capabilities import ASYNC_OPERATIONS from morango.constants.capabilities import FSIC_V2_FORMAT -from morango.errors import MorangoDatabaseError from morango.errors import MorangoInvalidFSICPartition from morango.errors import MorangoLimitExceeded from morango.errors import MorangoResumeSyncError @@ -47,13 +43,14 @@ from morango.sync.backends.utils import TemporaryTable from morango.sync.context import LocalSessionContext from morango.sync.context import NetworkSessionContext -from morango.sync.utils import lock_partitions +from morango.sync.db import begin_transaction +from morango.sync.stream.serialize import serialize_into_store from morango.sync.utils import mute_signals from morango.sync.utils import 
validate_and_create_buffer_data from morango.utils import _assert +from morango.utils import self_referential_fk from morango.utils import SETTINGS - logger = logging.getLogger(__name__) DBBackend = load_backend(connection) @@ -81,206 +78,6 @@ def _join_with_logical_operator(lst, operator): return "(({items}))".format(items=op.join(lst)) -def _self_referential_fk(model): - """ - Return whether this model has a self ref FK, and the name for the field - """ - for f in model._meta.concrete_fields: - if f.related_model: - if issubclass(model, f.related_model): - return f.attname - return None - - -@contextmanager -def _begin_transaction(sync_filter, isolated=False, shared_lock=False): - """ - Starts a transaction, sets the transaction isolation level to repeatable read, and locks - affected partitions - - :param sync_filter: The filter for filtering applicable records of the sync - :type sync_filter: morango.models.certificates.Filter|None - :param isolated: Whether to alter the transaction isolation to repeatable-read - :type isolated: bool - :param shared_lock: Whether the advisory lock should be exclusive or shared - :type shared_lock: bool - """ - if isolated: - # when isolation is requested, we modify the transaction isolation of the connection for the - # duration of the transaction - with DBBackend._set_transaction_repeatable_read(): - with transaction.atomic(savepoint=False): - lock_partitions(DBBackend, sync_filter=sync_filter, shared=shared_lock) - yield - else: - with transaction.atomic(): - lock_partitions(DBBackend, sync_filter=sync_filter, shared=shared_lock) - yield - - -def _serialize_into_store(profile, filter=None): - """ - Takes data from app layer and serializes the models into the store. - - ALGORITHM: On a per syncable model basis, we iterate through each class model and we go through 2 possible cases: - - 1. If there is a store record pertaining to that app model, we update the serialized store record with - the latest changes from the model's fields. We also update the counter's based on this device's current Instance ID. - 2. If there is no store record for this app model, we proceed to create an in memory store model and append to a list to be - bulk created on a per class model basis. 
- """ - # ensure that we write and retrieve the counter in one go for consistency - current_id = InstanceIDModel.get_current_instance_and_increment_counter() - - with _begin_transaction(filter, isolated=True): - # create Q objects for filtering by prefixes - prefix_condition = None - if filter: - prefix_condition = functools.reduce( - lambda x, y: x | y, - [Q(_morango_partition__startswith=prefix) for prefix in filter], - ) - - # filter through all models with the dirty bit turned on - for model in syncable_models.get_models(profile): - new_store_records = [] - new_rmc_records = [] - klass_queryset = model.syncing_objects.filter(_morango_dirty_bit=True) - if prefix_condition: - klass_queryset = klass_queryset.filter(prefix_condition) - store_records_dict = Store.objects.in_bulk( - id_list=klass_queryset.values_list("id", flat=True) - ) - for app_model in klass_queryset: - try: - store_model = store_records_dict[app_model.id] - - # if store record dirty and app record dirty, append store serialized to conflicting data - if store_model.dirty_bit: - store_model.conflicting_serialized_data = ( - store_model.serialized - + "\n" - + store_model.conflicting_serialized_data - ) - store_model.dirty_bit = False - - # set new serialized data on this store model - ser_dict = json.loads(store_model.serialized) - ser_dict.update(app_model.serialize()) - store_model.serialized = DjangoJSONEncoder().encode(ser_dict) - - # create or update instance and counter on the record max counter for this store model - RecordMaxCounter.objects.update_or_create( - defaults={"counter": current_id.counter}, - instance_id=current_id.id, - store_model_id=store_model.id, - ) - - # update last saved bys for this store model - store_model.last_saved_instance = current_id.id - store_model.last_saved_counter = current_id.counter - # update deleted flags in case it was previously deleted - store_model.deleted = False - store_model.hard_deleted = False - # clear last_transfer_session_id - store_model.last_transfer_session_id = None - - # update this model - store_model.save() - - except KeyError: - kwargs = { - "id": app_model.id, - "serialized": DjangoJSONEncoder().encode(app_model.serialize()), - "last_saved_instance": current_id.id, - "last_saved_counter": current_id.counter, - "model_name": app_model.morango_model_name, - "profile": app_model.morango_profile, - "partition": app_model._morango_partition, - "source_id": app_model._morango_source_id, - } - # check if model has FK pointing to it and add the value to a field on the store - self_ref_fk = _self_referential_fk(model) - if self_ref_fk: - self_ref_fk_value = getattr(app_model, self_ref_fk) - kwargs.update({"_self_ref_fk": self_ref_fk_value or ""}) - # create store model and record max counter for the app model - new_store_records.append(Store(**kwargs)) - new_rmc_records.append( - RecordMaxCounter( - store_model_id=app_model.id, - instance_id=current_id.id, - counter=current_id.counter, - ) - ) - - # bulk create store and rmc records for this class - Store.objects.bulk_create(new_store_records) - RecordMaxCounter.objects.bulk_create(new_rmc_records) - - # set dirty bit to false for all instances of this model - klass_queryset.update(update_dirty_bit_to=False) - - # get list of ids of deleted models - deleted_ids = DeletedModels.objects.filter(profile=profile).values_list( - "id", flat=True - ) - # update last_saved_bys and deleted flag of all deleted store model instances - deleted_store_records = Store.objects.filter(id__in=deleted_ids) - 
deleted_store_records.update( - dirty_bit=False, - deleted=True, - last_saved_instance=current_id.id, - last_saved_counter=current_id.counter, - ) - # update rmcs counters for deleted models that have our instance id - RecordMaxCounter.objects.filter( - instance_id=current_id.id, store_model_id__in=deleted_ids - ).update(counter=current_id.counter) - # get a list of deleted model ids that don't have an rmc for our instance id - new_rmc_ids = deleted_store_records.exclude( - recordmaxcounter__instance_id=current_id.id - ).values_list("id", flat=True) - # bulk create these new rmcs - RecordMaxCounter.objects.bulk_create( - [ - RecordMaxCounter( - store_model_id=r_id, - instance_id=current_id.id, - counter=current_id.counter, - ) - for r_id in new_rmc_ids - ] - ) - # clear deleted models table for this profile - DeletedModels.objects.filter(profile=profile).delete() - - # handle logic for hard deletion models - hard_deleted_ids = HardDeletedModels.objects.filter( - profile=profile - ).values_list("id", flat=True) - hard_deleted_store_records = Store.objects.filter(id__in=hard_deleted_ids) - hard_deleted_store_records.update( - hard_deleted=True, serialized="{}", conflicting_serialized_data="" - ) - HardDeletedModels.objects.filter(profile=profile).delete() - - # update our own database max counters after serialization - if not filter: - DatabaseMaxCounter.objects.update_or_create( - instance_id=current_id.id, - partition="", - defaults={"counter": current_id.counter}, - ) - else: - for f in filter: - DatabaseMaxCounter.objects.update_or_create( - instance_id=current_id.id, - partition=f, - defaults={"counter": current_id.counter}, - ) - - def _validate_missing_store_foreign_keys(from_model_name, to_model_name, temp_table): """ Performs validation on a bulk set of foreign keys (FKs), given a temp table with two columns, @@ -447,7 +244,7 @@ def _deserialize_from_store(profile, skip_erroring=False, filter=None): excluded_list = [] deleted_list = [] - with _begin_transaction(filter, isolated=True): + with begin_transaction(filter, isolated=True): # iterate through classes which are in foreign key dependency order for model in syncable_models.get_models(profile): deferred_fks = defaultdict(list) @@ -473,7 +270,7 @@ def _deserialize_from_store(profile, skip_erroring=False, filter=None): store_models = store_models.filter(deserialization_error="") # handle cases where a class has a single FK reference to itself - if _self_referential_fk(model): + if self_referential_fk(model): clean_parents = store_models.filter(dirty_bit=False).char_ids_list() dirty_children = ( store_models.filter(dirty_bit=True) @@ -539,7 +336,9 @@ def _deserialize_from_store(profile, skip_erroring=False, filter=None): app_model, model_deferred_fks, ) = store_model._deserialize_store_model( - fk_cache, defer_fks=True, sync_filter=filter, + fk_cache, + defer_fks=True, + sync_filter=filter, ) if app_model: app_models.append(app_model) @@ -618,7 +417,7 @@ def _queue_into_buffer_v1(transfersession): as well as the partition for the data we are syncing. 
""" filter_prefixes = Filter(transfersession.filter) - with _begin_transaction(filter_prefixes, shared_lock=True): + with begin_transaction(filter_prefixes, shared_lock=True): server_fsic = json.loads(transfersession.server_fsic) client_fsic = json.loads(transfersession.client_fsic) @@ -745,7 +544,7 @@ def _queue_into_buffer_v2(transfersession, chunk_size=200): We use raw sql queries to place data in the buffer and the record max counter buffer, which matches the conditions of the FSIC. """ sync_filter = Filter(transfersession.filter) - with _begin_transaction(sync_filter, shared_lock=True): + with begin_transaction(sync_filter, shared_lock=True): server_fsic = json.loads(transfersession.server_fsic) client_fsic = json.loads(transfersession.client_fsic) @@ -889,7 +688,7 @@ def _dequeue_into_store(transfer_session, fsic, v2_format=False): are not affected by previous cases. """ - with _begin_transaction(Filter(transfer_session.filter)): + with begin_transaction(Filter(transfer_session.filter)): with connection.cursor() as cursor: DBBackend._dequeuing_delete_rmcb_records(cursor, transfer_session.id) DBBackend._dequeuing_delete_buffered_records(cursor, transfer_session.id) @@ -1012,9 +811,11 @@ def handle(self, context): if context.request: data.update( id=context.request.data.get("id"), - records_total=context.request.data.get("records_total") - if context.is_push - else None, + records_total=( + context.request.data.get("records_total") + if context.is_push + else None + ), client_fsic=context.request.data.get("client_fsic") or "{}", ) elif context.is_server: @@ -1048,7 +849,9 @@ def handle(self, context): if context.is_producer and SETTINGS.MORANGO_SERIALIZE_BEFORE_QUEUING: try: - _serialize_into_store(context.sync_session.profile, filter=context.filter) + serialize_into_store( + context.sync_session.profile, sync_filter=context.filter + ) except OperationalError as e: # if we run into a transaction isolation error, we return a pending status to force # retrying through the controller flow @@ -1276,8 +1079,12 @@ def handle(self, context): if SETTINGS.MORANGO_DESERIALIZE_AFTER_DEQUEUING and records_transferred > 0: try: # we first serialize to avoid deserialization merge conflicts - _serialize_into_store(context.sync_session.profile, filter=context.filter) - _deserialize_from_store(context.sync_session.profile, filter=context.filter) + serialize_into_store( + context.sync_session.profile, sync_filter=context.filter + ) + _deserialize_from_store( + context.sync_session.profile, filter=context.filter + ) except OperationalError as e: # if we run into a transaction isolation error, we return a pending status to force # retrying through the controller flow diff --git a/morango/sync/stream/__init__.py b/morango/sync/stream/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/morango/sync/stream/core.py b/morango/sync/stream/core.py new file mode 100644 index 0000000..0fca03b --- /dev/null +++ b/morango/sync/stream/core.py @@ -0,0 +1,196 @@ +""" +Foundational classes for streaming ETL-like pipelines. + +Provides a modular source > transform > sink pattern where models are streamed one-by-one through +a pipeline of connected modules, reducing memory overhead. 
+""" +import abc +from typing import Any +from typing import Generic +from typing import Iterable +from typing import Iterator +from typing import List +from typing import Optional +from typing import TypeVar + +T = TypeVar("T") + + +class StreamModule(abc.ABC): + """ + Abstract base class for all stream modules + """ + + pass + + +class ReaderModule(abc.ABC): + """ + Abstract base class for all stream modules that can be read, + or rather can pipe to another module. + """ + + @abc.abstractmethod + def pipe(self, other: "StreamModule") -> "StreamModule": + """Connect this module to another module, returning a Pipeline.""" + pass + + +class PipelineModule(StreamModule): + """ + Abstract base class for all pipeline transform-like modules. + + Each module receives an iterable of items and yields transformed items. + Modules are composable: they can be chained together via the pipe method. + """ + + @abc.abstractmethod + def __call__(self, items: Iterable[Any]) -> Iterator: + """Process the incoming iterable and yield output items.""" + raise NotImplementedError + + +class Source(ReaderModule, Generic[T]): + """ + A module that represents a streaming pipeline source. + """ + + def begin(self) -> None: + """Called once before the stream begins.""" + pass + + @abc.abstractmethod + def stream(self) -> Iterator[T]: + """The primary read method that returns the iterator stream of items.""" + raise NotImplementedError + + def pipe(self, other: "PipelineModule") -> "Pipeline": + """ + Start a pipeline, by attaching this as the source of a pipeline, + and add the first transform module + """ + return Pipeline(self, [other]) + + +class Sink(StreamModule, Generic[T]): + """ + A terminal module that consumes items without yielding further output. + """ + + @abc.abstractmethod + def consume(self, item: T) -> None: + """Process the incoming item from the stream.""" + raise NotImplementedError + + def finalize(self) -> None: + """Called once after all items have been consumed.""" + pass + + +class Pipeline(ReaderModule): + """ + An ordered sequence of `PipelineModule` instances that are executed in series. The output of + each module feeds into the next. + """ + + def __init__( + self, source: Source, modules: Optional[List[PipelineModule]] = None + ) -> None: + self._source = source + self._modules = list(modules) if modules else [] + + def pipe(self, other: PipelineModule) -> "Pipeline": + """Append another pipeline module to this pipeline and return self""" + if isinstance(other, PipelineModule): + self._modules.append(other) + else: + raise ValueError("Cannot pipe another module that is not a PipelineModule") + return self + + def end(self, sink: "Sink") -> int: + """Run the source through each module in order.""" + self._source.begin() + stream = self._source.stream() + + for module in self._modules: + stream = module(stream) + + count = 0 + for item in stream: + sink.consume(item) + count += 1 + sink.finalize() + return count + + +class Transform(PipelineModule, Generic[T]): + """ + A module that transforms each incoming item one-by-one. + """ + + def __call__(self, items: Iterable[T]) -> Iterator[T]: + for item in items: + result = self.transform(item) + if result is not None: + yield result + + @abc.abstractmethod + def transform(self, item: T) -> T: + """Logic for transforming an item""" + raise NotImplementedError + + +class FlatMap(PipelineModule, Generic[T]): + """ + A module that maps each incoming item to zero or more output items, flattening the result into a + single stream. 
+ """ + + def __call__(self, items: Iterable[T]) -> Iterator[T]: + for item in items: + for result in self.flat_map(item): + yield result + + @abc.abstractmethod + def flat_map(self, item: T) -> Iterable[T]: + """Transform a single item, into multiple stream items""" + raise NotImplementedError + + +class Buffer(PipelineModule, Generic[T]): + """ + Collects incoming items into fixed-size chunks (lists). + + Inserting a buffer into the pipeline converts a stream of individual items into a stream of + lists of those items, which is useful for batching database operations such as `bulk_create`. + """ + + def __init__(self, size: int) -> None: + """ + :param size: Maximum number of items per chunk. + """ + if size < 1: + raise ValueError("Buffer size must be >= 1") + self.size = size + + def __call__(self, items: Iterable[T]) -> Iterator[List[T]]: + chunk = [] + for item in items: + chunk.append(item) + if len(chunk) >= self.size: + yield chunk + chunk = [] + if chunk: + yield chunk + + +class Unbuffer(PipelineModule, Generic[T]): + """ + Flattens a stream of iterables (e.g. like chunks from `Buffer`) back into a stream of + individual items. + """ + + def __call__(self, items: Iterable[Iterable[T]]) -> Iterator[T]: + for chunk in items: + for item in chunk: + yield item diff --git a/morango/sync/stream/serialize.py b/morango/sync/stream/serialize.py new file mode 100644 index 0000000..6544c30 --- /dev/null +++ b/morango/sync/stream/serialize.py @@ -0,0 +1,408 @@ +import json +import logging +from typing import Generator +from typing import Iterable +from typing import Iterator +from typing import List +from typing import Optional +from typing import Type + +from django.core.serializers.json import DjangoJSONEncoder +from django.db.models import Q +from typing_extensions import Literal + +from morango.models.certificates import Filter +from morango.models.core import DatabaseMaxCounter +from morango.models.core import DeletedModels +from morango.models.core import HardDeletedModels +from morango.models.core import InstanceIDModel +from morango.models.core import RecordMaxCounter +from morango.models.core import Store +from morango.models.core import SyncableModel +from morango.registry import syncable_models +from morango.sync.stream.core import Buffer +from morango.sync.stream.core import Sink +from morango.sync.stream.core import Source +from morango.sync.stream.core import Transform +from morango.sync.stream.core import Unbuffer +from morango.utils import self_referential_fk + +logger = logging.getLogger(__name__) + + +class SerializeTask(object): + """Carrier class for providing context through the pipeline""" + + __slots__ = ("model", "obj", "store", "counter") + + def __init__(self, model: Type[SyncableModel], obj: SyncableModel): + self.model = model + self.obj = obj + self.store: Optional[Store] = None + self.counter: Optional[RecordMaxCounter] = None + + @property + def is_store_update(self): + return self.store is not None and not self.store._state.adding + + @property + def is_counter_update(self): + return self.counter is not None and not self.counter._state.adding + + def set_store(self, store_obj: Store): + self.store = store_obj + + def set_counter(self, counter: RecordMaxCounter): + self.counter = counter + + def self_referential_fk(self) -> Optional[str]: + """Return the attname of the self-referential FK on *model*, or ``None``.""" + return self_referential_fk(self.model) + + +class AppModelSource(Source[SerializeTask]): + """ + Yields ``SerializeTask`` objects for every 
syncable-model record that matches the + optional *sync_filter*. + """ + + def __init__( + self, + profile: str, + sync_filter: Optional[Filter] = None, + dirty_only: bool = True, + partition_order: Literal["asc", "desc"] = "asc", + ): + self.profile = profile + self.sync_filter = sync_filter + self.dirty_only = dirty_only + self.partition_order = partition_order + self._seen = set() + + def prefix_conditions(self) -> Generator[Optional[Q], None, None]: + if self.sync_filter is None: + # yield None once, so we do one query without a partition filter (everything) + yield None + else: + partitions_prefixes = [str(prefix) for prefix in self.sync_filter] + partition_iterator = sorted( + partitions_prefixes, + reverse=self.partition_order == "desc", + ) + + for prefix in partition_iterator: + yield Q(_morango_partition__startswith=prefix) + + def stream(self) -> Generator[SerializeTask, None, None]: + for partition_condition in self.prefix_conditions(): + for qs in syncable_models.get_model_querysets(self.profile): + if partition_condition is not None: + qs = qs.filter(partition_condition) + if self.dirty_only: + qs = qs.filter(_morango_dirty_bit=True) + for obj in qs.iterator(): + # partition filtering could result in overlaps, and since we're walking + # through the partitions one by one, we should avoid duplicates. Morango + # syncable models have unique IDs across the entire profile + if obj.id not in self._seen: + self._seen.add(obj.id) + yield SerializeTask(qs.model, obj) + + +class StoreLookup(Transform[List[SerializeTask]]): + """ + For each `SerializeTask`, look up the corresponding store record (if any) + and emit the tasks back. + """ + + def __init__(self, current_id: InstanceIDModel): + self.current_id = current_id + + def transform(self, tasks: List[SerializeTask]) -> List[SerializeTask]: + store_ids = [task.obj.id for task in tasks] + stores = Store.objects.in_bulk(store_ids) + counters = {} + counters_qs = RecordMaxCounter.objects.filter( + instance_id=self.current_id.id, store_model_id__in=store_ids + ) + + for counter in counters_qs: + counters[counter.store_model_id] = counter + + for task in tasks: + store_obj = stores.get(task.obj.id) + if store_obj: + task.set_store(store_obj) + counter_obj = counters.get(task.obj.id) + if counter_obj: + task.set_counter(counter_obj) + + return tasks + + +class StoreUpdate(Transform[SerializeTask]): + """Processes the updates to the Morango store and record counters.""" + + def __init__(self, current_id: InstanceIDModel): + self.current_id = current_id + + def transform(self, task: SerializeTask) -> SerializeTask: + if task.is_store_update: + self._handle_store_update(task) + else: + self._handle_store_create(task) + + if task.is_counter_update: + task.counter.counter = self.current_id.counter + else: + task.set_counter( + RecordMaxCounter( + counter=self.current_id.counter, + instance_id=self.current_id.id, + store_model_id=task.store.id, + ) + ) + + return task + + def _handle_store_update(self, task: SerializeTask): + # if store record dirty and app record dirty, append store serialized + # to conflicting data + if task.store.dirty_bit: + task.store.conflicting_serialized_data = ( + task.store.serialized + "\n" + task.store.conflicting_serialized_data + ) + task.store.dirty_bit = False + + # set new serialized data on this store model + ser_dict = json.loads(task.store.serialized) + ser_dict.update(task.obj.serialize()) + task.store.serialized = DjangoJSONEncoder().encode(ser_dict) + + # update last saved bys + 
task.store.last_saved_instance = self.current_id.id + task.store.last_saved_counter = self.current_id.counter + # update deleted flags in case it was previously deleted + task.store.deleted = False + task.store.hard_deleted = False + # clear last_transfer_session_id + task.store.last_transfer_session_id = None + + def _handle_store_create(self, task: SerializeTask): + kwargs = { + "id": task.obj.id, + "serialized": DjangoJSONEncoder().encode(task.obj.serialize()), + "last_saved_instance": self.current_id.id, + "last_saved_counter": self.current_id.counter, + "model_name": task.obj.morango_model_name, + "profile": task.obj.morango_profile, + "partition": task.obj._morango_partition, + "source_id": task.obj._morango_source_id, + } + + self_ref_fk = task.self_referential_fk() + if self_ref_fk: + self_ref_fk_value = getattr(task.obj, self_ref_fk) + kwargs["_self_ref_fk"] = self_ref_fk_value or "" + + task.set_store(Store(**kwargs)) + + +class ModelPartitionBuffer(Buffer[List[SerializeTask]]): + """Buffers tasks into chunks that have the same model class.""" + + def __call__(self, tasks: Iterable[SerializeTask]) -> Iterator[List[SerializeTask]]: + chunk = [] + last_model = None + + for task in tasks: + if len(chunk) >= self.size or (last_model and last_model != task.model): + yield chunk + chunk = [] + last_model = task.model + chunk.append(task) + + if chunk: + yield chunk + + +class WriteSink(Sink[List[SerializeTask]]): + """ + Consumes SerializeTask objects and writes the appropriate changes to the database. + """ + + def __init__( + self, + profile: str, + current_id: InstanceIDModel, + sync_filter: Optional[Filter] = None, + ): + self.profile = profile + self.current_id = current_id + self.sync_filter = sync_filter + + def _partition_tasks(self, tasks: List[SerializeTask]): + stores_to_create = [] + stores_to_update = [] + counters_to_create = [] + counters_to_update = [] + + for task in tasks: + if task.is_store_update: + stores_to_update.append(task.store) + else: + stores_to_create.append(task.store) + + if task.is_counter_update: + counters_to_update.append(task.counter) + else: + counters_to_create.append(task.counter) + + return ( + stores_to_create, + stores_to_update, + counters_to_create, + counters_to_update, + ) + + def consume(self, tasks: List[SerializeTask]): # noqa: C901 + stores_to_create, stores_to_update, counters_to_create, counters_to_update = ( + self._partition_tasks(tasks) + ) + + if stores_to_create: + created_stores = Store.objects.bulk_create( + stores_to_create, ignore_conflicts=True + ) + for created_store in created_stores: + # if bulk_create has not marked it as saving been added, then it must have been + # a conflict, so we'll add it to the update list + if created_store._state.adding: + stores_to_update.append(created_store) + + if stores_to_update: + # TODO: bulk_update performs poorly-- is there a better way? 
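+            # A possible alternative (sketch only, unmeasured here): Django's
+            # ``QuerySet.bulk_update`` can batch these writes into fewer queries,
+            # assuming the same set of fields changes on every store record::
+            #
+            #     Store.objects.bulk_update(
+            #         stores_to_update,
+            #         [
+            #             "serialized",
+            #             "conflicting_serialized_data",
+            #             "dirty_bit",
+            #             "last_saved_instance",
+            #             "last_saved_counter",
+            #             "deleted",
+            #             "hard_deleted",
+            #             "last_transfer_session_id",
+            #         ],
+            #     )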
+ for store in stores_to_update: + store.save() + + if counters_to_create: + created_counters = RecordMaxCounter.objects.bulk_create( + counters_to_create, ignore_conflicts=True + ) + update_counter_ids = [] + for created_counter in created_counters: + # if bulk_create has not marked it as saving been added, then it must have been + # a conflict, so we'll add it to the update list + if created_counter._state.adding: + update_counter_ids.append(created_counter.store_model_id) + if update_counter_ids: + counters_to_update.extend( + RecordMaxCounter.objects.filter( + instance_id=self.current_id.id, + store_model_id__in=update_counter_ids, + ) + ) + + if counters_to_update: + # TODO: bulk_update performs poorly-- is there a better way? + for counter in counters_to_update: + counter.save() + + app_model_ids = [task.obj.id for task in tasks] + app_model = tasks[0].model + app_model.syncing_objects.filter(id__in=app_model_ids).update( + update_dirty_bit_to=False + ) + + def finalize(self): + self._handle_deleted() + self._handle_hard_deleted() + self._update_counters() + + def _handle_deleted(self): + deleted_ids = DeletedModels.objects.filter(profile=self.profile).values_list( + "id", flat=True + ) + + deleted_store_records = Store.objects.filter(id__in=deleted_ids) + deleted_store_records.update( + dirty_bit=False, + deleted=True, + last_saved_instance=self.current_id.id, + last_saved_counter=self.current_id.counter, + ) + + # update rmcs counters for deleted models that have our instance id + RecordMaxCounter.objects.filter( + instance_id=self.current_id.id, store_model_id__in=deleted_ids + ).update(counter=self.current_id.counter) + + # get a list of deleted model ids that don't have an rmc for our instance id + new_rmc_ids = deleted_store_records.exclude( + recordmaxcounter__instance_id=self.current_id.id + ).values_list("id", flat=True) + + RecordMaxCounter.objects.bulk_create( + [ + RecordMaxCounter( + store_model_id=r_id, + instance_id=self.current_id.id, + counter=self.current_id.counter, + ) + for r_id in new_rmc_ids + ] + ) + # clear deleted models table for this profile + DeletedModels.objects.filter(profile=self.profile).delete() + + def _handle_hard_deleted(self): + hard_deleted_ids = HardDeletedModels.objects.filter( + profile=self.profile + ).values_list("id", flat=True) + + hard_deleted_store_records = Store.objects.filter(id__in=hard_deleted_ids) + hard_deleted_store_records.update( + hard_deleted=True, serialized="{}", conflicting_serialized_data="" + ) + HardDeletedModels.objects.filter(profile=self.profile).delete() + + def _update_counters(self): + if not self.sync_filter: + DatabaseMaxCounter.objects.update_or_create( + instance_id=self.current_id.id, + partition="", + defaults={"counter": self.current_id.counter}, + ) + else: + for f in self.sync_filter: + DatabaseMaxCounter.objects.update_or_create( + instance_id=self.current_id.id, + partition=f, + defaults={"counter": self.current_id.counter}, + ) + + +def serialize_into_store( + profile: str, sync_filter: Optional[Filter] = None, dirty_only: bool = True +): + """ + Constructs and executes the serialization pipeline, streaming dirty app models + one-by-one through a pipeline that updates the Morango store and metadata. 
+ """ + from morango.models.core import InstanceIDModel + from morango.sync.db import begin_transaction + + current_id = InstanceIDModel.get_current_instance_and_increment_counter() + + with begin_transaction(sync_filter, isolated=True): + # Execute the main pipeline (consumes the source through to the sink). + result_count = ( + AppModelSource(profile, sync_filter=sync_filter, dirty_only=dirty_only) + .pipe(Buffer(size=500)) + .pipe(StoreLookup(current_id)) + .pipe(Unbuffer()) + .pipe(StoreUpdate(current_id)) + .pipe(ModelPartitionBuffer(size=500)) + .end(WriteSink(profile, current_id, sync_filter=sync_filter)) + ) + logger.info(f"Serialization done: {result_count} records") diff --git a/morango/utils.py b/morango/utils.py index d69038a..e4589fe 100644 --- a/morango/utils.py +++ b/morango/utils.py @@ -129,3 +129,14 @@ def _assert(condition, message, error_type=AssertionError): """ if not condition: raise error_type(message) + + +def self_referential_fk(klass_model): + """ + Return whether this model has a self ref FK, and the name for the field + """ + for f in klass_model._meta.concrete_fields: + if f.related_model: + if issubclass(klass_model, f.related_model): + return f.attname + return None diff --git a/setup.py b/setup.py index 1126f58..a71c62f 100644 --- a/setup.py +++ b/setup.py @@ -29,6 +29,7 @@ "djangorestframework>3.10", "django-ipware==4.0.2", "requests", + "typing-extensions==4.1.1", "ifcfg", ], license="MIT", diff --git a/tests/testapp/tests/sync/stream/__init__.py b/tests/testapp/tests/sync/stream/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/testapp/tests/sync/stream/test_core.py b/tests/testapp/tests/sync/stream/test_core.py new file mode 100644 index 0000000..d858b1e --- /dev/null +++ b/tests/testapp/tests/sync/stream/test_core.py @@ -0,0 +1,111 @@ +from django.test import SimpleTestCase + +from morango.sync.stream.core import Buffer +from morango.sync.stream.core import FlatMap +from morango.sync.stream.core import Pipeline +from morango.sync.stream.core import Sink +from morango.sync.stream.core import Source +from morango.sync.stream.core import Transform +from morango.sync.stream.core import Unbuffer + + +class FakeSource(Source): + def stream(self): + yield 1 + yield 2 + yield 3 + + +class FakeTransform(Transform): + def transform(self, item): + return item * 2 + + +class FakeFlatMap(FlatMap): + def flat_map(self, item): + return [item, item + 0.5] + + +class FakeSink(Sink): + def __init__(self): + self.consumed = [] + + def consume(self, item): + self.consumed.append(item) + + +class SourceTestCase(SimpleTestCase): + def test_stream(self): + source = FakeSource() + self.assertEqual([1, 2, 3], list(source.stream())) + + def test_pipe(self): + source = FakeSource() + transform = FakeTransform() + pipeline = source.pipe(transform) + self.assertIsInstance(pipeline, Pipeline) + self.assertEqual(pipeline._source, source) + self.assertEqual(pipeline._modules, [transform]) + + +class TransformTestCase(SimpleTestCase): + def test_transform_call(self): + transform = FakeTransform() + result = list(transform([1, 2, 3])) + self.assertEqual([2, 4, 6], result) + + def test_transform_skips_none(self): + class SkipTransform(Transform): + def transform(self, item): + return item if item > 1 else None + + transform = SkipTransform() + result = list(transform([1, 2])) + self.assertEqual([2], result) + + +class FlatMapTestCase(SimpleTestCase): + def test_flat_map_call(self): + flat_map = FakeFlatMap() + result = list(flat_map([1, 2])) + 
self.assertEqual([1, 1.5, 2, 2.5], result) + + +class BufferTestCase(SimpleTestCase): + def test_buffer(self): + buff = Buffer(size=2) + result = list(buff([1, 2, 3])) + self.assertEqual([[1, 2], [3]], result) + + def test_buffer_invalid_size(self): + with self.assertRaises(ValueError): + Buffer(size=0) + + +class UnbufferTestCase(SimpleTestCase): + def test_unbuffer(self): + unbuff = Unbuffer() + result = list(unbuff([[1, 2], [3]])) + self.assertEqual([1, 2, 3], result) + + +class PipelineTestCase(SimpleTestCase): + def test_pipeline_execution(self): + source = FakeSource() + transform = FakeTransform() + sink = FakeSink() + + pipeline = source.pipe(transform) + count = pipeline.end(sink) + + self.assertEqual(count, 3) + self.assertEqual([2, 4, 6], sink.consumed) + + def test_pipeline_chaining(self): + source = FakeSource() + pipeline = source.pipe(FakeTransform()).pipe(FakeTransform()) + sink = FakeSink() + + count = pipeline.end(sink) + self.assertEqual(count, 3) + self.assertEqual([4, 8, 12], sink.consumed) diff --git a/tests/testapp/tests/sync/stream/test_serialize.py b/tests/testapp/tests/sync/stream/test_serialize.py new file mode 100644 index 0000000..6b1be07 --- /dev/null +++ b/tests/testapp/tests/sync/stream/test_serialize.py @@ -0,0 +1,401 @@ +import json + +import mock +from django.db.models import Q +from django.test import SimpleTestCase + +from morango.models.certificates import Filter +from morango.models.core import InstanceIDModel +from morango.models.core import RecordMaxCounter +from morango.models.core import Store +from morango.models.core import SyncableModel +from morango.sync.stream.serialize import AppModelSource +from morango.sync.stream.serialize import ModelPartitionBuffer +from morango.sync.stream.serialize import SerializeTask +from morango.sync.stream.serialize import StoreLookup +from morango.sync.stream.serialize import StoreUpdate +from morango.sync.stream.serialize import WriteSink + + +class SerializeTaskTestCase(SimpleTestCase): + def setUp(self): + self.model = mock.Mock(spec_set=SyncableModel) + self.obj = mock.Mock(spec=SyncableModel) + self.task = SerializeTask(self.model, self.obj) + + def test_is_store_update(self): + self.assertFalse(self.task.is_store_update) + store = mock.Mock(spec_set=Store)() + store._state.adding = False + self.task.set_store(store) + self.assertTrue(self.task.is_store_update) + + def test_is_counter_update(self): + self.assertFalse(self.task.is_counter_update) + counter = mock.Mock(spec_set=RecordMaxCounter)() + counter._state.adding = False + self.task.set_counter(counter) + self.assertTrue(self.task.is_counter_update) + + @mock.patch("morango.sync.stream.serialize.self_referential_fk") + def test_self_referential_fk(self, mock_self_referential_fk): + mock_self_referential_fk.return_value = "self_ref_fk" + self.assertEqual(self.task.self_referential_fk(), "self_ref_fk") + mock_self_referential_fk.assert_called_once_with(self.model) + + +class AppModelSourceTestCase(SimpleTestCase): + def setUp(self): + self.model = mock.Mock(spec_set=SyncableModel) + + def test_prefix_conditions__none(self): + source = AppModelSource(profile="test") + conditions = list(source.prefix_conditions()) + self.assertEqual(conditions, [None]) + + def test_prefix_conditions__with_filter(self): + sync_filter = Filter("a\nb") + source = AppModelSource(profile="test", sync_filter=sync_filter) + conditions = list(source.prefix_conditions()) + self.assertEqual(len(conditions), 2) + self.assertEqual( + str(conditions[0]), "(AND: 
('_morango_partition__startswith', 'a'))" + ) + + @mock.patch("morango.sync.stream.serialize.syncable_models.get_model_querysets") + def test_stream__no_partition(self, mock_get_model_querysets): + qs = mock.Mock() + mock_get_model_querysets.return_value = [qs] + model = qs.model + obj = mock.Mock(id="123") + qs.filter.return_value = qs + qs.iterator.return_value = [obj] + + source = AppModelSource(profile="test") + tasks = list(source.stream()) + + self.assertEqual(len(tasks), 1) + self.assertEqual(tasks[0].model, model) + self.assertEqual(tasks[0].obj, obj) + mock_get_model_querysets.assert_called_once_with("test") + qs.filter.assert_called_once_with(_morango_dirty_bit=True) + + @mock.patch("morango.sync.stream.serialize.syncable_models.get_model_querysets") + def test_stream__seen_once(self, mock_get_model_querysets): + qs = mock.Mock() + mock_get_model_querysets.return_value = [qs] + model = qs.model + obj = mock.Mock(id="123") + qs.filter.return_value = qs + qs.iterator.return_value = [obj, obj] + + source = AppModelSource(profile="test") + tasks = list(source.stream()) + + self.assertEqual(len(tasks), 1) + self.assertEqual(tasks[0].model, model) + self.assertEqual(tasks[0].obj, obj) + mock_get_model_querysets.assert_called_once_with("test") + qs.filter.assert_called_once_with(_morango_dirty_bit=True) + + @mock.patch("morango.sync.stream.serialize.syncable_models.get_model_querysets") + def test_stream__partition(self, mock_get_model_querysets): + qs = mock.Mock() + mock_get_model_querysets.return_value = [qs] + model = qs.model + obj = mock.Mock(id="123") + qs.filter.return_value = qs + qs.iterator.return_value = [obj, obj] + + source = AppModelSource( + profile="test", sync_filter=Filter("a"), dirty_only=False + ) + tasks = list(source.stream()) + + self.assertEqual(len(tasks), 1) + self.assertEqual(tasks[0].model, model) + self.assertEqual(tasks[0].obj, obj) + mock_get_model_querysets.assert_called_once_with("test") + qs.filter.assert_called_once_with(Q(_morango_partition__startswith="a")) + + +class StoreLookupTestCase(SimpleTestCase): + @mock.patch("morango.models.core.RecordMaxCounter.objects.filter") + @mock.patch("morango.models.core.Store.objects.in_bulk") + def test_transform(self, mock_bulk, mock_qs): + current_id = mock.Mock(spec=InstanceIDModel) + current_id.id = "inst_1" + lookup = StoreLookup(current_id) + + obj1 = mock.Mock(id="obj_1") + task1 = SerializeTask(mock.Mock(), obj1) + obj2 = mock.Mock(id="obj_2") + task2 = SerializeTask(mock.Mock(), obj2) + + store_obj1 = mock.Mock(spec=Store) + store_obj2 = mock.Mock(spec=Store) + mock_bulk.return_value = {"obj_1": store_obj1, "obj_2": store_obj2} + + counter_obj = mock.Mock(spec=RecordMaxCounter, store_model_id="obj_1") + mock_qs.return_value = [counter_obj] + + results = lookup.transform([task1, task2]) + + self.assertEqual(results[0].store, store_obj1) + self.assertEqual(results[0].counter, counter_obj) + self.assertEqual(results[1].store, store_obj2) + self.assertEqual(results[1].counter, None) + + +class StoreUpdateTestCase(SimpleTestCase): + @mock.patch("morango.sync.stream.serialize.StoreUpdate._handle_store_create") + def test_transform__creates(self, mock_handle_store_create): + current_id = mock.Mock(id="inst_1", counter=10) + update = StoreUpdate(current_id) + + store = Store(id="123", serialized=json.dumps({"old": 1}), dirty_bit=False) + + task = SerializeTask(mock.Mock(), mock.Mock()) + mock_handle_store_create.side_effect = lambda _: task.set_store(store) + + update.transform(task) + 
mock_handle_store_create.assert_called_once_with(task) + self.assertEqual(task.counter.instance_id, "inst_1") + self.assertEqual(task.counter.counter, 10) + self.assertEqual(task.counter.store_model_id, "123") + + @mock.patch("morango.sync.stream.serialize.StoreUpdate._handle_store_update") + def test_transform__updates(self, mock_handle_store_update): + current_id = mock.Mock(id="inst_1", counter=10) + update = StoreUpdate(current_id) + + store = Store(serialized=json.dumps({"old": 1}), dirty_bit=False) + store._state.adding = False + counter = RecordMaxCounter(counter=5) + + task = SerializeTask(mock.Mock(), mock.Mock()) + task.set_store(store) + task.set_counter(counter) + + update.transform(task) + mock_handle_store_update.assert_called_once_with(task) + self.assertEqual(task.counter.counter, 10) + + def test_handle_store_update(self): + current_id = mock.Mock(id="inst_1", counter=10) + update = StoreUpdate(current_id) + + store = Store(serialized=json.dumps({"old": 1}), dirty_bit=False) + obj = mock.Mock() + obj.serialize.return_value = {"new": 2} + + task = SerializeTask(mock.Mock(), obj) + task.set_store(store) + + update._handle_store_update(task) + + ser_data = json.loads(task.store.serialized) + self.assertEqual(ser_data["old"], 1) + self.assertEqual(ser_data["new"], 2) + + +class ModelPartitionBufferTestCase(SimpleTestCase): + def test_buffer_splits_on_model_change(self): + buff = ModelPartitionBuffer(size=10) + m1, m2 = mock.Mock(), mock.Mock() + tasks = [ + SerializeTask(m1, mock.Mock()), + SerializeTask(m1, mock.Mock()), + SerializeTask(m2, mock.Mock()), + ] + + chunks = list(buff(tasks)) + self.assertEqual(len(chunks), 2) + self.assertEqual(chunks[0][0].model, m1) + self.assertEqual(chunks[0][1].model, m1) + self.assertEqual(chunks[1][0].model, m2) + + +class WriteSinkTestCase(SimpleTestCase): + def setUp(self): + self.profile = "test" + self.current_id = mock.Mock(spec=InstanceIDModel) + self.current_id.id = "inst_1" + self.current_id.counter = 10 + self.sink = WriteSink(self.profile, self.current_id) + + @mock.patch("morango.sync.stream.serialize.RecordMaxCounter.objects.bulk_create") + @mock.patch("morango.sync.stream.serialize.Store.objects.bulk_create") + def test_consume(self, mock_store_bulk, mock_counter_bulk): + model = mock.Mock() + obj1 = mock.Mock(id="obj_1") + task1 = SerializeTask(model, obj1) + store1 = mock.Mock(spec_set=Store, id="obj_1")() + store1._state.adding = True + task1.set_store(store1) + counter1 = mock.Mock(spec_set=RecordMaxCounter)() + counter1._state.adding = True + task1.set_counter(counter1) + + obj2 = mock.Mock(id="obj_2") + task2 = SerializeTask(model, obj2) + store2 = mock.Mock(spec_set=Store, id="obj_2")() + store2._state.adding = False + task2.set_store(store2) + counter2 = mock.Mock(spec_set=RecordMaxCounter)() + counter2._state.adding = False + task2.set_counter(counter2) + + def _bulk_store_create(objs, **kwargs): + for obj in objs: + obj._state.adding = False + return objs + + def _bulk_counter_create(objs, **kwargs): + for obj in objs: + obj._state.adding = False + return objs + + mock_store_bulk.side_effect = _bulk_store_create + mock_counter_bulk.side_effect = _bulk_counter_create + + self.sink.consume([task1, task2]) + + # Verify Store operations + mock_store_bulk.assert_called_once_with([store1], ignore_conflicts=True) + store1.save.assert_not_called() + store2.save.assert_called_once() + + # Verify Counter operations + mock_counter_bulk.assert_called_once_with([counter1], ignore_conflicts=True) + 
counter1.save.assert_not_called() + counter2.save.assert_called_once() + + # Verify dirty bit update on app models + model.syncing_objects.filter.assert_called_once_with(id__in=["obj_1", "obj_2"]) + model.syncing_objects.filter().update.assert_called_once_with( + update_dirty_bit_to=False + ) + + @mock.patch("morango.sync.stream.serialize.RecordMaxCounter.objects.filter") + @mock.patch("morango.sync.stream.serialize.RecordMaxCounter.objects.bulk_create") + @mock.patch("morango.sync.stream.serialize.Store.objects.bulk_create") + def test_consume__create_fail( + self, mock_store_bulk, mock_counter_bulk, mock_counter_filter + ): + model = mock.Mock() + obj1 = mock.Mock(id="obj_1") + task1 = SerializeTask(model, obj1) + store1 = mock.Mock(spec_set=Store, id="obj_1")() + store1._state.adding = True + task1.set_store(store1) + counter1 = mock.Mock(spec_set=RecordMaxCounter)() + counter1._state.adding = True + counter1.store_model_id = "123" + task1.set_counter(counter1) + + # Mock bulk_create returns + mock_store_bulk.return_value = [store1] + mock_counter_bulk.return_value = [counter1] + mock_counter_filter.return_value = [counter1] + + self.sink.consume([task1]) + + mock_counter_filter.assert_called_once_with( + instance_id="inst_1", store_model_id__in=["123"] + ) + + # Verify Store operations + mock_store_bulk.assert_called_once_with([store1], ignore_conflicts=True) + store1.save.assert_called_once() + + # Verify Counter operations + mock_counter_bulk.assert_called_once_with([counter1], ignore_conflicts=True) + counter1.save.assert_called_once() + + # Verify dirty bit update on app models + model.syncing_objects.filter.assert_called_once_with(id__in=["obj_1"]) + model.syncing_objects.filter().update.assert_called_once_with( + update_dirty_bit_to=False + ) + + @mock.patch("morango.sync.stream.serialize.RecordMaxCounter.objects.filter") + @mock.patch("morango.sync.stream.serialize.RecordMaxCounter.objects.bulk_create") + @mock.patch("morango.sync.stream.serialize.Store.objects.filter") + @mock.patch("morango.sync.stream.serialize.DeletedModels.objects.filter") + def test_handle_deleted( + self, mock_deleted_filter, mock_store_filter, mock_rmc_bulk, mock_counter_filter + ): + mock_deleted_filter.return_value.values_list.return_value = ["del_1"] + mock_store_records = mock.Mock() + mock_store_filter.return_value = mock_store_records + mock_store_records.exclude.return_value.values_list.return_value = ["del_1"] + + self.sink._handle_deleted() + + mock_deleted_filter.assert_called_with(profile=self.profile) + mock_store_filter.assert_called_once_with(id__in=["del_1"]) + mock_store_records.update.assert_called_once_with( + dirty_bit=False, + deleted=True, + last_saved_instance=self.current_id.id, + last_saved_counter=self.current_id.counter, + ) + mock_counter_filter.assert_called_once_with( + instance_id=self.current_id.id, store_model_id__in=["del_1"] + ) + mock_counter_filter().update.assert_called_once_with( + counter=self.current_id.counter + ) + mock_store_records.exclude.assert_called_once_with( + recordmaxcounter__instance_id=self.current_id.id + ) + mock_store_records.exclude().values_list.assert_called_once_with( + "id", flat=True + ) + mock_rmc_bulk.assert_called_once() + rmc = mock_rmc_bulk.call_args[0][0][0] + self.assertEqual(rmc.store_model_id, "del_1") + self.assertEqual(rmc.instance_id, self.current_id.id) + self.assertEqual(rmc.counter, self.current_id.counter) + mock_deleted_filter().delete.assert_called_once() + + 
@mock.patch("morango.sync.stream.serialize.HardDeletedModels.objects.filter") + @mock.patch("morango.sync.stream.serialize.Store.objects.filter") + def test_handle_hard_deleted(self, mock_store_filter, mock_hard_deleted_filter): + mock_hard_deleted_filter.return_value.values_list.return_value = ["hard_del_1"] + mock_store_records = mock.Mock() + mock_store_filter.return_value = mock_store_records + + self.sink._handle_hard_deleted() + + mock_hard_deleted_filter.assert_called_with(profile=self.profile) + mock_store_records.update.assert_called_once_with( + hard_deleted=True, serialized="{}", conflicting_serialized_data="" + ) + mock_hard_deleted_filter().delete.assert_called_once() + + @mock.patch( + "morango.sync.stream.serialize.DatabaseMaxCounter.objects.update_or_create" + ) + def test_update_counters(self, mock_update_or_create): + self.sink.sync_filter = Filter("a") + self.sink._update_counters() + self.assertEqual(mock_update_or_create.call_count, 1) + mock_update_or_create.assert_called_with( + instance_id=self.current_id.id, + partition="a", + defaults={"counter": self.current_id.counter}, + ) + + @mock.patch( + "morango.sync.stream.serialize.DatabaseMaxCounter.objects.update_or_create" + ) + def test_update_counters__no_filter(self, mock_update_or_create): + self.sink._update_counters() + self.assertEqual(mock_update_or_create.call_count, 1) + mock_update_or_create.assert_called_with( + instance_id=self.current_id.id, + partition="", + defaults={"counter": self.current_id.counter}, + ) diff --git a/tests/testapp/tests/sync/test_controller.py b/tests/testapp/tests/sync/test_controller.py index f009e21..efccae7 100644 --- a/tests/testapp/tests/sync/test_controller.py +++ b/tests/testapp/tests/sync/test_controller.py @@ -21,7 +21,6 @@ from morango.models.core import InstanceIDModel from morango.models.core import RecordMaxCounter from morango.models.core import Store -from morango.sync.controller import _self_referential_fk from morango.sync.controller import MorangoProfileController from morango.sync.controller import SessionController @@ -108,8 +107,8 @@ def test_last_saved_instance_updates(self): old_instance_id = Store.objects.first().last_saved_instance with EnvironmentVarGuard() as env: - env['MORANGO_SYSTEM_ID'] = 'new_sys_id' - (new_id, _) = InstanceIDModel.get_or_create_current_instance(clear_cache=True) + env["MORANGO_SYSTEM_ID"] = "new_sys_id" + new_id, _ = InstanceIDModel.get_or_create_current_instance(clear_cache=True) Facility.objects.all().update(name=self.new_name) self.mc.serialize_into_store() @@ -291,7 +290,7 @@ def test_store_hard_delete_propagates(self): class RecordMaxCounterUpdatesDuringSerialization(TestCase): def setUp(self): - (self.current_id, _) = InstanceIDModel.get_or_create_current_instance() + self.current_id, _ = InstanceIDModel.get_or_create_current_instance() self.mc = MorangoProfileController("facilitydata") self.fac1 = FacilityModelFactory(name="school") self.mc.serialize_into_store() @@ -299,8 +298,8 @@ def setUp(self): def test_new_rmc_for_existing_model(self): with EnvironmentVarGuard() as env: - env['MORANGO_SYSTEM_ID'] = 'new_sys_id' - (new_id, _) = InstanceIDModel.get_or_create_current_instance(clear_cache=True) + env["MORANGO_SYSTEM_ID"] = "new_sys_id" + new_id, _ = InstanceIDModel.get_or_create_current_instance(clear_cache=True) Facility.objects.update(name="facility") self.mc.serialize_into_store() @@ -336,8 +335,8 @@ def test_update_rmc_for_existing_model(self): def test_new_rmc_for_non_existent_model(self): with EnvironmentVarGuard() as 
env: - env['MORANGO_SYSTEM_ID'] = 'new_sys_id' - (new_id, _) = InstanceIDModel.get_or_create_current_instance(clear_cache=True) + env["MORANGO_SYSTEM_ID"] = "new_sys_id" + new_id, _ = InstanceIDModel.get_or_create_current_instance(clear_cache=True) new_fac = FacilityModelFactory(name="college") self.mc.serialize_into_store() @@ -354,7 +353,7 @@ def test_new_rmc_for_non_existent_model(self): class DeserializationFromStoreIntoAppTestCase(TestCase): def setUp(self): - (self.current_id, _) = InstanceIDModel.get_or_create_current_instance() + self.current_id, _ = InstanceIDModel.get_or_create_current_instance() self.range = 10 self.mc = MorangoProfileController("facilitydata") for i in range(self.range): @@ -429,7 +428,9 @@ def test_record_with_dirty_bit_off_doesnt_deserialize(self): def test_broken_fk_leaves_store_dirty_bit(self): log_id = uuid.uuid4().hex - serialized = json.dumps({"user_id": "40de9a3fded95d7198f200c78e559353", "id": log_id}) + serialized = json.dumps( + {"user_id": "40de9a3fded95d7198f200c78e559353", "id": log_id} + ) st = StoreModelFacilityFactory( id=log_id, serialized=serialized, model_name="contentsummarylog" ) @@ -518,8 +519,12 @@ def _create_two_users_to_deserialize(self): self.mc.serialize_into_store() user.username = "changed" user2.username = "changed2" - Store.objects.filter(id=user.id).update(serialized=json.dumps(user.serialize()), dirty_bit=True) - Store.objects.filter(id=user2.id).update(serialized=json.dumps(user2.serialize()), dirty_bit=True) + Store.objects.filter(id=user.id).update( + serialized=json.dumps(user.serialize()), dirty_bit=True + ) + Store.objects.filter(id=user2.id).update( + serialized=json.dumps(user2.serialize()), dirty_bit=True + ) return user, user2 def test_regular_model_deserialization(self): @@ -543,13 +548,9 @@ def test_filtered_deserialization(self): class SelfReferentialFKDeserializationTestCase(TestCase): def setUp(self): - (self.current_id, _) = InstanceIDModel.get_or_create_current_instance() + self.current_id, _ = InstanceIDModel.get_or_create_current_instance() self.mc = MorangoProfileController("facilitydata") - def test_self_ref_fk(self): - self.assertEqual(_self_referential_fk(Facility), "parent_id") - self.assertEqual(_self_referential_fk(MyUser), None) - def test_delete_model_in_store_deletes_models_in_app(self): root = FacilityModelFactory() child1 = FacilityModelFactory(parent=root) @@ -590,10 +591,14 @@ def test_models_created_successfully(self): self.assertEqual(child2[0].parent_id, root.id) def test_deserialization_of_model_with_missing_parent(self): - self._test_deserialization_of_model_with_missing_parent(correct_self_ref_fk=True) + self._test_deserialization_of_model_with_missing_parent( + correct_self_ref_fk=True + ) def test_deserialization_of_model_with_mismatched_self_ref_fk(self): - self._test_deserialization_of_model_with_missing_parent(correct_self_ref_fk=False) + self._test_deserialization_of_model_with_missing_parent( + correct_self_ref_fk=False + ) def _test_deserialization_of_model_with_missing_parent(self, correct_self_ref_fk): root = FacilityModelFactory() @@ -619,7 +624,7 @@ def _test_deserialization_of_model_with_missing_parent(self, correct_self_ref_fk class ForeignKeyDeserializationTestCase(TestCase): def setUp(self): - (self.current_id, _) = InstanceIDModel.get_or_create_current_instance() + self.current_id, _ = InstanceIDModel.get_or_create_current_instance() self.mc = MorangoProfileController("facilitydata") def test_deserialization_of_model_with_missing_foreignkey_referent(self): @@ -640,7 
+645,10 @@ def test_deserialization_of_model_with_missing_foreignkey_referent(self): new_log.refresh_from_db() self.assertTrue(new_log.dirty_bit) - self.assertIn("my user instance with id '{}'".format(data["user_id"]), new_log.deserialization_error) + self.assertIn( + "my user instance with id '{}'".format(data["user_id"]), + new_log.deserialization_error, + ) def test_deserialization_of_model_with_disallowed_null_foreignkey(self): @@ -708,15 +716,18 @@ class SessionControllerTestCase(SimpleTestCase): def setUp(self): super(SessionControllerTestCase, self).setUp() self.middleware = [ - mock.Mock(related_stage=stage) - for stage, _ in transfer_stages.CHOICES + mock.Mock(related_stage=stage) for stage, _ in transfer_stages.CHOICES ] self.context = TestSessionContext() - self.controller = SessionController.build(middleware=self.middleware, context=self.context) + self.controller = SessionController.build( + middleware=self.middleware, context=self.context + ) @contextlib.contextmanager def _mock_method(self, method): - with mock.patch('morango.sync.controller.SessionController.{}'.format(method)) as invoke: + with mock.patch( + "morango.sync.controller.SessionController.{}".format(method) + ) as invoke: yield invoke invoke.reset_mock() @@ -726,18 +737,24 @@ def test_proceed_to__passed_stage(self): self.assertEqual(transfer_statuses.COMPLETED, result) def test_proceed_to__in_progress(self): - self.context.update(stage=transfer_stages.TRANSFERRING, stage_status=transfer_statuses.STARTED) + self.context.update( + stage=transfer_stages.TRANSFERRING, stage_status=transfer_statuses.STARTED + ) result = self.controller.proceed_to(transfer_stages.TRANSFERRING) self.assertEqual(transfer_statuses.STARTED, result) def test_proceed_to__errored(self): - self.context.update(stage=transfer_stages.TRANSFERRING, stage_status=transfer_statuses.ERRORED) + self.context.update( + stage=transfer_stages.TRANSFERRING, stage_status=transfer_statuses.ERRORED + ) result = self.controller.proceed_to(transfer_stages.TRANSFERRING) self.assertEqual(transfer_statuses.ERRORED, result) def test_proceed_to__executes_middleware__incrementally(self): - self.context.update(stage=transfer_stages.SERIALIZING, stage_status=transfer_statuses.COMPLETED) - with self._mock_method('_invoke_middleware') as mock_invoke: + self.context.update( + stage=transfer_stages.SERIALIZING, stage_status=transfer_statuses.COMPLETED + ) + with self._mock_method("_invoke_middleware") as mock_invoke: mock_invoke.return_value = transfer_statuses.STARTED result = self.controller.proceed_to(transfer_stages.QUEUING) self.assertEqual(transfer_statuses.STARTED, result) @@ -748,15 +765,19 @@ def test_proceed_to__executes_middleware__incrementally(self): mock_invoke.reset_mock() def test_proceed_to__executes_middleware__all(self): - self.context.update(stage=transfer_stages.SERIALIZING, stage_status=transfer_statuses.COMPLETED) - with self._mock_method('_invoke_middleware') as invoke: + self.context.update( + stage=transfer_stages.SERIALIZING, stage_status=transfer_statuses.COMPLETED + ) + with self._mock_method("_invoke_middleware") as invoke: invoke.return_value = transfer_statuses.COMPLETED result = self.controller.proceed_to(transfer_stages.CLEANUP) self.assertEqual(transfer_statuses.COMPLETED, result) self.assertEqual(5, len(invoke.call_args_list)) def test_proceed_to__resuming_fast_forward(self): - self.context.update(stage=transfer_stages.INITIALIZING, stage_status=transfer_statuses.PENDING) + self.context.update( + stage=transfer_stages.INITIALIZING, 
stage_status=transfer_statuses.PENDING + ) expected_stages = ( transfer_stages.INITIALIZING, transfer_stages.DESERIALIZING, @@ -766,32 +787,39 @@ def test_proceed_to__resuming_fast_forward(self): def invoke(context, middleware): self.assertIn(context.stage, expected_stages) if context.stage == transfer_stages.INITIALIZING: - context.update(stage=transfer_stages.DESERIALIZING, stage_status=transfer_statuses.PENDING) + context.update( + stage=transfer_stages.DESERIALIZING, + stage_status=transfer_statuses.PENDING, + ) return transfer_statuses.COMPLETED - with self._mock_method('_invoke_middleware') as mock_invoke: + with self._mock_method("_invoke_middleware") as mock_invoke: mock_invoke.side_effect = invoke result = self.controller.proceed_to(transfer_stages.CLEANUP) self.assertEqual(transfer_statuses.COMPLETED, result) self.assertEqual(3, len(mock_invoke.call_args_list)) def test_proceed_to_and_wait_for(self): - with self._mock_method('proceed_to') as proceed_to: + with self._mock_method("proceed_to") as proceed_to: proceed_to.side_effect = [ transfer_statuses.PENDING, transfer_statuses.PENDING, - transfer_statuses.COMPLETED + transfer_statuses.COMPLETED, ] - result = self.controller.proceed_to_and_wait_for(transfer_stages.CLEANUP, max_interval=0.1) + result = self.controller.proceed_to_and_wait_for( + transfer_stages.CLEANUP, max_interval=0.1 + ) self.assertEqual(result, transfer_statuses.COMPLETED) def test_proceed_to_and_wait_for__errored(self): - with self._mock_method('proceed_to') as proceed_to: + with self._mock_method("proceed_to") as proceed_to: proceed_to.side_effect = [ transfer_statuses.PENDING, - transfer_statuses.ERRORED + transfer_statuses.ERRORED, ] - result = self.controller.proceed_to_and_wait_for(transfer_stages.CLEANUP, max_interval=0.1) + result = self.controller.proceed_to_and_wait_for( + transfer_stages.CLEANUP, max_interval=0.1 + ) self.assertEqual(result, transfer_statuses.ERRORED) @mock.patch("morango.sync.controller.sleep") @@ -807,9 +835,11 @@ def mock_proceed_to(*args, **kwargs): return transfer_statuses.PENDING try: - with self._mock_method('proceed_to') as proceed_to: + with self._mock_method("proceed_to") as proceed_to: proceed_to.side_effect = mock_proceed_to - result = self.controller.proceed_to_and_wait_for(transfer_stages.CLEANUP, max_interval=0.1) + result = self.controller.proceed_to_and_wait_for( + transfer_stages.CLEANUP, max_interval=0.1 + ) self.assertEqual(result, transfer_statuses.COMPLETED) except OverflowError: self.fail("Overflow error raised!") @@ -823,14 +853,25 @@ def test_invoke_middleware(self): middleware = self.middleware[0] middleware.return_value = transfer_statuses.STARTED - with mock.patch.object(TestSessionContext, "update_state", wraps=context.update_state) as m: + with mock.patch.object( + TestSessionContext, "update_state", wraps=context.update_state + ) as m: result = self.controller._invoke_middleware(context, middleware) self.assertEqual(result, transfer_statuses.STARTED) context_update_calls = m.call_args_list self.assertEqual(2, len(context_update_calls)) - self.assertEqual(mock.call(stage=middleware.related_stage, stage_status=transfer_statuses.PENDING), context_update_calls[0]) - self.assertEqual(mock.call(stage=None, stage_status=transfer_statuses.STARTED), context_update_calls[1]) + self.assertEqual( + mock.call( + stage=middleware.related_stage, + stage_status=transfer_statuses.PENDING, + ), + context_update_calls[0], + ) + self.assertEqual( + mock.call(stage=None, stage_status=transfer_statuses.STARTED), + 
context_update_calls[1], + ) self.assertEqual(2, len(handler.call_args_list)) self.assertEqual(mock.call(context=context), handler.call_args_list[0]) diff --git a/tests/testapp/tests/sync/test_db.py b/tests/testapp/tests/sync/test_db.py new file mode 100644 index 0000000..0bee040 --- /dev/null +++ b/tests/testapp/tests/sync/test_db.py @@ -0,0 +1,108 @@ +import threading +import uuid +from time import sleep + +import pytest +from django.conf import settings +from django.db import connection +from django.test import override_settings +from django.test import TransactionTestCase +from django.utils import timezone + +from ..helpers import create_buffer_and_store_dummy_data +from morango.models.certificates import Filter +from morango.models.core import Store +from morango.models.core import SyncSession +from morango.models.core import TransferSession +from morango.sync.backends.utils import load_backend +from morango.sync.db import begin_transaction + + +DBBackend = load_backend(connection) + + +def _concurrent_store_write(thread_event, store_id): + while not thread_event.is_set(): + sleep(.1) + Store.objects.filter(id=store_id).delete() + connection.close() + + +class TransactionIsolationTestCase(TransactionTestCase): + serialized_rollback = True + + def _fixture_setup(self): + """Don't setup fixtures for this test case""" + pass + + @override_settings(MORANGO_TEST_POSTGRESQL=False) + def test_begin_transaction(self): + """ + Assert that we can start a transaction using our util and make some writes without + raising errors, specifically + """ + # the utility we're testing here avoids setting the isolation level when this setting is True + # because tests usually run within their own transaction. By the time the isolation level + # is attempted to be set within a test, there have been reads and writes and the isolation + # cannot be changed + self.assertFalse(connection.in_atomic_block) + with begin_transaction(None, isolated=True): + session = SyncSession.objects.create( + id=uuid.uuid4().hex, + profile="facilitydata", + last_activity_timestamp=timezone.now(), + ) + transfer_session = TransferSession.objects.create( + id=uuid.uuid4().hex, + sync_session=session, + push=True, + last_activity_timestamp=timezone.now(), + ) + create_buffer_and_store_dummy_data(transfer_session.id) + + # manual cleanup + self.assertNotEqual(0, Store.objects.all().count()) + # will cascade delete + SyncSession.objects.all().delete() + Store.objects.all().delete() + + @pytest.mark.skipif( + not getattr(settings, "MORANGO_TEST_POSTGRESQL", False), reason="Not supported" + ) + def test_transaction_isolation_handling(self): + from psycopg2.extensions import ISOLATION_LEVEL_REPEATABLE_READ + + store = Store.objects.create( + id=uuid.uuid4().hex, + last_saved_instance=uuid.uuid4().hex, + last_saved_counter=1, + partition=uuid.uuid4().hex, + profile="facilitydata", + source_id="qqq", + model_name="qqq", + ) + + concurrent_event = threading.Event() + concurrent_thread = threading.Thread( + target=_concurrent_store_write, + args=(concurrent_event, store.id), + ) + concurrent_thread.start() + + # this test is only for postgres, but we don't want the code to know it's a test + with override_settings(MORANGO_TEST_POSTGRESQL=False): + try: + self.assertNotEqual(connection.connection.isolation_level, ISOLATION_LEVEL_REPEATABLE_READ) + with begin_transaction(Filter(store.partition), isolated=True): + self.assertEqual(connection.connection.isolation_level, ISOLATION_LEVEL_REPEATABLE_READ) + s = Store.objects.get(id=store.id) + 
concurrent_event.set() + sleep(.2) + s.last_saved_counter += 1 + s.save() + raise AssertionError("Didn't raise transactional error") + except Exception as e: + self.assertTrue(DBBackend._is_transaction_isolation_error(e)) + self.assertNotEqual(connection.connection.isolation_level, ISOLATION_LEVEL_REPEATABLE_READ) + finally: + concurrent_thread.join(5) diff --git a/tests/testapp/tests/sync/test_operations.py b/tests/testapp/tests/sync/test_operations.py index 6aae39e..b0a85cb 100644 --- a/tests/testapp/tests/sync/test_operations.py +++ b/tests/testapp/tests/sync/test_operations.py @@ -1,16 +1,11 @@ import json -import threading import uuid -from time import sleep import factory import mock -import pytest -from django.conf import settings from django.db import connection from django.test import override_settings from django.test import TestCase -from django.test import TransactionTestCase from django.utils import timezone from facility_profile.models import ConditionalLog from facility_profile.models import Facility @@ -36,7 +31,6 @@ from morango.sync.context import LocalSessionContext from morango.sync.controller import MorangoProfileController from morango.sync.controller import SessionController -from morango.sync.operations import _begin_transaction from morango.sync.operations import _dequeue_into_store from morango.sync.operations import _deserialize_from_store from morango.sync.operations import _queue_into_buffer_v1 @@ -78,93 +72,6 @@ def assertRecordsNotBuffered(records): assert i.id not in rmcb_ids -def _concurrent_store_write(thread_event, store_id): - while not thread_event.is_set(): - sleep(.1) - Store.objects.filter(id=store_id).delete() - connection.close() - - -class TransactionIsolationTestCase(TransactionTestCase): - serialized_rollback = True - - def _fixture_setup(self): - """Don't setup fixtures for this test case""" - pass - - @override_settings(MORANGO_TEST_POSTGRESQL=False) - def test_begin_transaction(self): - """ - Assert that we can start a transaction using our util and make some writes without - raising errors, specifically - """ - # the utility we're testing here avoids setting the isolation level when this setting is True - # because tests usually run within their own transaction. 
By the time the isolation level - # is attempted to be set within a test, there have been reads and writes and the isolation - # cannot be changed - self.assertFalse(connection.in_atomic_block) - with _begin_transaction(None, isolated=True): - session = SyncSession.objects.create( - id=uuid.uuid4().hex, - profile="facilitydata", - last_activity_timestamp=timezone.now(), - ) - transfer_session = TransferSession.objects.create( - id=uuid.uuid4().hex, - sync_session=session, - push=True, - last_activity_timestamp=timezone.now(), - ) - create_buffer_and_store_dummy_data(transfer_session.id) - - # manual cleanup - self.assertNotEqual(0, Store.objects.all().count()) - # will cascade delete - SyncSession.objects.all().delete() - Store.objects.all().delete() - - @pytest.mark.skipif( - not getattr(settings, "MORANGO_TEST_POSTGRESQL", False), reason="Not supported" - ) - def test_transaction_isolation_handling(self): - from psycopg2.extensions import ISOLATION_LEVEL_REPEATABLE_READ - - store = Store.objects.create( - id=uuid.uuid4().hex, - last_saved_instance=uuid.uuid4().hex, - last_saved_counter=1, - partition=uuid.uuid4().hex, - profile="facilitydata", - source_id="qqq", - model_name="qqq", - ) - - concurrent_event = threading.Event() - concurrent_thread = threading.Thread( - target=_concurrent_store_write, - args=(concurrent_event, store.id), - ) - concurrent_thread.start() - - # this test is only for postgres, but we don't want the code to know it's a test - with override_settings(MORANGO_TEST_POSTGRESQL=False): - try: - self.assertNotEqual(connection.connection.isolation_level, ISOLATION_LEVEL_REPEATABLE_READ) - with _begin_transaction(Filter(store.partition), isolated=True): - self.assertEqual(connection.connection.isolation_level, ISOLATION_LEVEL_REPEATABLE_READ) - s = Store.objects.get(id=store.id) - concurrent_event.set() - sleep(.2) - s.last_saved_counter += 1 - s.save() - raise AssertionError("Didn't raise transactional error") - except Exception as e: - self.assertTrue(DBBackend._is_transaction_isolation_error(e)) - self.assertNotEqual(connection.connection.isolation_level, ISOLATION_LEVEL_REPEATABLE_READ) - finally: - concurrent_thread.join(5) - - @override_settings(MORANGO_SERIALIZE_BEFORE_QUEUING=False, MORANGO_DISABLE_FSIC_V2_FORMAT=True) class QueueStoreIntoBufferV1TestCase(TestCase): def setUp(self): diff --git a/tests/testapp/tests/test_utils.py b/tests/testapp/tests/test_utils.py index 39449bd..3e9ad43 100644 --- a/tests/testapp/tests/test_utils.py +++ b/tests/testapp/tests/test_utils.py @@ -1,23 +1,26 @@ import os -from requests import Request -from django.http.request import HttpRequest -from django.test.testcases import SimpleTestCase + import mock import pytest +from django.http.request import HttpRequest +from django.test.testcases import SimpleTestCase +from facility_profile.models import Facility +from facility_profile.models import MyUser +from requests import Request +from morango.constants import transfer_stages from morango.constants.capabilities import ALLOW_CERTIFICATE_PUSHING from morango.constants.capabilities import ASYNC_OPERATIONS from morango.constants.capabilities import FSIC_V2_FORMAT -from morango.constants import transfer_stages -from morango.utils import SETTINGS +from morango.utils import _posix_pid_exists +from morango.utils import _windows_pid_exists from morango.utils import CAPABILITIES_CLIENT_HEADER -from morango.utils import CAPABILITIES_SERVER_HEADER from morango.utils import get_capabilities -from morango.utils import 
serialize_capabilities_to_client_request from morango.utils import parse_capabilities_from_server_request from morango.utils import pid_exists -from morango.utils import _posix_pid_exists -from morango.utils import _windows_pid_exists +from morango.utils import self_referential_fk +from morango.utils import serialize_capabilities_to_client_request +from morango.utils import SETTINGS class SettingsTestCase(SimpleTestCase): @@ -110,3 +113,9 @@ def test_windows(self): pid = os.getpid() self.assertTrue(pid_exists(pid)) self.assertFalse(pid_exists(123456789)) + + +class SelfReferentialFKTestCase(SimpleTestCase): + def test_self_ref_fk(self): + self.assertEqual(self_referential_fk(Facility), "parent_id") + self.assertEqual(self_referential_fk(MyUser), None)
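
Reviewer note (not part of the patch): the mocked serializer tests near the top of this excerpt pin down roughly what the streaming sink's helpers are expected to do. The sketch below is reconstructed purely from those assertions; the class name StoreSink, the import locations, and the exact queryset filtering are assumptions, not the actual morango.sync.stream.serialize implementation.

# Illustrative sketch only, inferred from the test assertions above.
# Assumed: HardDeletedModels / DatabaseMaxCounter / Store live in morango.models.core,
# and Filter iterates its partition prefixes; the Store filter arguments are a guess.
from morango.models.core import DatabaseMaxCounter
from morango.models.core import HardDeletedModels
from morango.models.core import Store


class StoreSink(object):  # hypothetical name for the serialization sink
    def __init__(self, profile, current_id, sync_filter=None):
        self.profile = profile
        self.current_id = current_id  # InstanceIDModel for this device
        self.sync_filter = sync_filter

    def _handle_hard_deleted(self):
        # collect ids of hard-deleted models for this profile
        hard_deleted = HardDeletedModels.objects.filter(profile=self.profile)
        ids = hard_deleted.values_list("id", flat=True)
        # blank out the serialized payload on the matching store records
        Store.objects.filter(id__in=ids).update(
            hard_deleted=True, serialized="{}", conflicting_serialized_data=""
        )
        # the hard-deletion markers have been applied, so drop them
        hard_deleted.delete()

    def _update_counters(self):
        # one DatabaseMaxCounter row per partition in the filter, or "" when unfiltered
        partitions = self.sync_filter or [""]
        for partition in partitions:
            DatabaseMaxCounter.objects.update_or_create(
                instance_id=self.current_id.id,
                partition=partition,
                defaults={"counter": self.current_id.counter},
            )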
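
Reviewer note (not part of the patch): the new tests/testapp/tests/sync/test_db.py exercises begin_transaction directly with a Filter and isolated=True. A minimal usage sketch under those same assumptions follows; the filter value and the work done inside the block are illustrative only.

# Minimal usage sketch of morango.sync.db.begin_transaction, mirroring how the
# new test module calls it; the partition value below is hypothetical.
from morango.models.certificates import Filter
from morango.models.core import Store
from morango.sync.db import begin_transaction

sync_filter = Filter("5b2c3f1edb39")  # partition prefix affected by this sync (example)

with begin_transaction(sync_filter, isolated=True):
    # reads and writes inside the block run under the partition locks, and under
    # REPEATABLE READ isolation when isolated=True takes effect
    for prefix in sync_filter:  # assuming Filter iterates its partition prefixes
        count = Store.objects.filter(partition__startswith=prefix).count()
        print("store records under partition %s: %d" % (prefix, count))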
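
Reviewer note (not part of the patch): test_transaction_isolation_handling shows that, on PostgreSQL, a concurrent write to the same rows inside an isolated begin_transaction surfaces as an exception recognised by DBBackend._is_transaction_isolation_error. A caller could react to that as sketched below; the retry loop is an assumption for illustration, not behaviour this patch adds.

# Illustrative retry wrapper around an isolated transaction; the retry policy
# itself is hypothetical, only the error check comes from the patch.
from django.db import connection

from morango.sync.backends.utils import load_backend
from morango.sync.db import begin_transaction

DBBackend = load_backend(connection)


def run_with_isolation_retry(sync_filter, operation, attempts=3):
    """Run `operation` inside an isolated transaction, retrying on isolation errors."""
    for attempt in range(attempts):
        try:
            with begin_transaction(sync_filter, isolated=True):
                return operation()
        except Exception as e:
            if not DBBackend._is_transaction_isolation_error(e) or attempt == attempts - 1:
                raise
            # another writer touched the same rows under REPEATABLE READ; the
            # transaction was rolled back, so start over from a fresh snapshot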