diff --git a/src/py_avro_schema/_schemas.py b/src/py_avro_schema/_schemas.py index 70b9108..c43425f 100644 --- a/src/py_avro_schema/_schemas.py +++ b/src/py_avro_schema/_schemas.py @@ -179,6 +179,7 @@ def schema( namespace: Optional[str] = None, names: Optional[NamesType] = None, options: Option = Option(0), + processing: set[type] | None = None, ) -> JSONType: """ Generate and return an Avro schema for a given Python type @@ -193,12 +194,17 @@ def schema( """ if names is None: names = [] - schema_obj = _schema_obj(py_type, namespace=namespace, options=options) + schema_obj = _schema_obj(py_type, namespace=namespace, options=options, processing=processing) schema_data = schema_obj.data(names=names) return schema_data -def _schema_obj(py_type: Type, namespace: Optional[str] = None, options: Option = Option(0)) -> "Schema": +def _schema_obj( + py_type: Type, + namespace: Optional[str] = None, + options: Option = Option(0), + processing: set[type] | None = None, +) -> "Schema": """ Dispatch to relevant schema classes @@ -206,10 +212,11 @@ def _schema_obj(py_type: Type, namespace: Optional[str] = None, options: Option :param namespace: The Avro namespace to add to schemas. :param options: Schema generation options. """ + processing = processing or set() # Find concrete Schema subclasses defined in the current module for schema_class in sorted(_SCHEMA_CLASSES, key=lambda c: getattr(c, "__py_avro_priority", 0)): # Find the first schema class that handles py_type - schema_obj = schema_class(py_type, namespace=namespace, options=options) # type: ignore + schema_obj = schema_class(py_type, namespace=namespace, options=options, processing=processing) # type: ignore if schema_obj: return schema_obj raise TypeNotSupportedError(f"Cannot generate Avro schema for Python type {py_type}") @@ -229,7 +236,13 @@ def validate_name(value: str) -> str: class Schema(abc.ABC): """Schema base""" - def __new__(cls, py_type: Type, namespace: Optional[str] = None, options: Option = Option(0)): + def __new__( + cls, + py_type: Type, + namespace: Optional[str] = None, + options: Option = Option(0), + processing: set[type] | None = None, + ): """ Create an instance of this schema class if it handles py_type @@ -242,17 +255,25 @@ def __new__(cls, py_type: Type, namespace: Optional[str] = None, options: Option else: return None - def __init__(self, py_type: Type, namespace: Optional[str] = None, options: Option = Option(0)): + def __init__( + self, + py_type: Type, + namespace: Optional[str] = None, + options: Option = Option(0), + processing: set[type] | None = None, + ): """ A schema base :param py_type: The Python class to generate a schema for. :param namespace: The Avro namespace to add to schemas. :param options: Schema generation options. + :param processing: Internal parameter to track types currently being processed (for circular references). """ self.py_type = py_type self.options = options self._namespace = namespace # Namespace override + self.processing = processing or set() @property def namespace_override(self) -> Optional[str]: @@ -361,7 +382,13 @@ def data(self, names: NamesType) -> JSONObj: class LiteralSchema(Schema): """An Avro schema of any type for a Python Literal type, e.g. ``Literal[""]``""" - def __init__(self, py_type: Type[Any], namespace: Optional[str] = None, options: Option = Option(0)): + def __init__( + self, + py_type: Type[Any], + namespace: Optional[str] = None, + options: Option = Option(0), + **kwargs, + ): """ An Avro schema of any type for a Python Literal type, e.g. ``Literal[""]`` @@ -395,7 +422,13 @@ def data(self, names: NamesType) -> JSONType: class FinalSchema(Schema): """An Avro schema for Python ``typing.Final``""" - def __init__(self, py_type: Type, namespace: Optional[str] = None, options: Option = Option(0)): + def __init__( + self, + py_type: Type, + namespace: Optional[str] = None, + options: Option = Option(0), + **kwargs, + ): """An Avro schema for Python ``typing.Final``""" super().__init__(py_type, namespace, options) py_type = _type_from_annotated(py_type) @@ -684,6 +717,7 @@ def __init__( py_type: Type[collections.abc.MutableSequence], namespace: Optional[str] = None, options: Option = Option(0), + **kwargs, ): """ An Avro array schema for a given Python sequence @@ -728,6 +762,7 @@ def __init__( py_type: type[collections.abc.MutableSet], namespace: str | None = None, options: Option = Option(0), + **kwargs, ): """ An Avro array schema for a given Python sequence @@ -757,6 +792,7 @@ def __init__( py_type: Type[collections.abc.MutableMapping], namespace: Optional[str] = None, options: Option = Option(0), + processing: set[type] | None = None, ): """ An Avro map schema for a given Python mapping @@ -764,8 +800,9 @@ def __init__( :param py_type: The Python class to generate a schema for. :param namespace: The Avro namespace to add to schemas. :param options: Schema generation options. + :param processing: Internal parameter to track types currently being processed (for circular references). """ - super().__init__(py_type, namespace=namespace, options=options) + super().__init__(py_type, namespace=namespace, options=options, processing=processing) py_type = _type_from_annotated(py_type) args = get_args(py_type) if args[0] != str and not issubclass(args[0], StrEnum): @@ -797,7 +834,13 @@ def handles_type(cls, py_type: Type) -> bool: return origin == Union or origin == union_type return origin == Union - def __init__(self, py_type: Type[Union[Any]], namespace: Optional[str] = None, options: Option = Option(0)): + def __init__( + self, + py_type: Type[Union[Any]], + namespace: Optional[str] = None, + options: Option = Option(0), + **kwargs, + ): """ An Avro union schema for a given Python union type @@ -859,7 +902,13 @@ def make_default(self, py_default: Any) -> JSONType: class NamedSchema(Schema): """A named Avro schema base class""" - def __init__(self, py_type: Type, namespace: Optional[str] = None, options: Option = Option(0)): + def __init__( + self, + py_type: Type, + namespace: Optional[str] = None, + options: Option = Option(0), + processing: set[type] | None = None, + ): """ A named Avro schema base class @@ -867,7 +916,7 @@ def __init__(self, py_type: Type, namespace: Optional[str] = None, options: Opti :param namespace: The Avro namespace to add to schemas. :param options: Schema generation options. """ - super().__init__(py_type, namespace=namespace, options=options) + super().__init__(py_type, namespace=namespace, options=options, processing=processing) py_type = _type_from_annotated(py_type) self.name = py_type.__name__ @@ -915,7 +964,13 @@ def handles_type(cls, py_type: Type) -> bool: """Whether this schema class can represent a given Python class""" return _is_class(py_type, enum.Enum) - def __init__(self, py_type: Type[enum.Enum], namespace: Optional[str] = None, options: Option = Option(0)): + def __init__( + self, + py_type: Type[enum.Enum], + namespace: Optional[str] = None, + options: Option = Option(0), + **kwargs, + ): """ An Avro enum schema for a Python enum with string values @@ -981,15 +1036,25 @@ def data_before_deduplication(self, names: NamesType) -> JSONObj: class RecordSchema(NamedSchema): """An Avro record schema base class""" - def __init__(self, py_type: Type, namespace: Optional[str] = None, options: Option = Option(0)): + def __init__( + self, + py_type: Type, + namespace: Optional[str] = None, + options: Option = Option(0), + processing: set[type] | None = None, + ): """ An Avro record schema base class :param py_type: The Python class to generate a schema for. :param namespace: The Avro namespace to add to schemas. :param options: Schema generation options. + :param processing: Internal parameter to track types currently being processed (for circular references). """ - super().__init__(py_type, namespace=namespace, options=options) + super().__init__(py_type, namespace=namespace, options=options, processing=processing) + self.processing = processing or set() + # Add this type to the processing set to detect circular references + self.processing.add(py_type) self.record_fields: collections.abc.Sequence[RecordField] = [] def data_before_deduplication(self, names: NamesType) -> JSONObj: @@ -1008,6 +1073,7 @@ def data_before_deduplication(self, names: NamesType) -> JSONObj: doc = _doc_for_class(self.py_type) if doc: record_schema["doc"] = doc + self.processing.discard(self.py_type) return record_schema @@ -1023,6 +1089,7 @@ def __init__( default: Any = dataclasses.MISSING, docs: str = "", options: Option = Option(0), + processing: set[type] | None = None, ): """ An Avro record field @@ -1035,6 +1102,8 @@ def __init__( :param docs: Field documentation or description :param options: Schema generation options """ + if processing is None: + processing = set() if aliases is None: aliases = [] self.py_type = py_type @@ -1044,7 +1113,14 @@ def __init__( self.default = default self.docs = docs self.options = options - self.schema = _schema_obj(self.py_type, namespace=self._namespace, options=options) + + _type = self.py_type + # Check for circular dependency + if self.py_type in processing and hasattr(self.py_type, "__name__"): + # This is a circular reference - use a ForwardRef to break the cycle + _type = ForwardRef(py_type.__name__) # type: ignore + + self.schema = _schema_obj(_type, namespace=self._namespace, options=options, processing=processing) if self.default != dataclasses.MISSING: if isinstance(self.schema, UnionSchema): @@ -1083,7 +1159,13 @@ def handles_type(cls, py_type: Type) -> bool: py_type = _type_from_annotated(py_type) return dataclasses.is_dataclass(py_type) - def __init__(self, py_type: Type, namespace: Optional[str] = None, options: Option = Option(0)): + def __init__( + self, + py_type: Type, + namespace: Optional[str] = None, + options: Option = Option(0), + processing: set[type] | None = None, + ): """ An Avro record schema for a given Python dataclass @@ -1091,7 +1173,7 @@ def __init__(self, py_type: Type, namespace: Optional[str] = None, options: Opti :param namespace: The Avro namespace to add to schemas. :param options: Schema generation options. """ - super().__init__(py_type, namespace=namespace, options=options) + super().__init__(py_type, namespace=namespace, options=options, processing=processing) py_type = _type_from_annotated(py_type) self.py_fields = dataclasses.fields(py_type) self.record_fields = [self._record_field(field) for field in self.py_fields] @@ -1109,6 +1191,7 @@ def _record_field(self, py_field: dataclasses.Field) -> RecordField: default=default, aliases=aliases, options=self.options, + processing=self.processing, ) return field_obj @@ -1131,7 +1214,13 @@ def handles_type(cls, py_type: Type) -> bool: py_type = _type_from_annotated(py_type) return hasattr(py_type, "__pydantic_private__") - def __init__(self, py_type: Type[pydantic.BaseModel], namespace: Optional[str] = None, options: Option = Option(0)): + def __init__( + self, + py_type: Type[pydantic.BaseModel], + namespace: Optional[str] = None, + options: Option = Option(0), + processing: set[type] | None = None, + ): """ An Avro record schema for a given Pydantic model class @@ -1139,7 +1228,7 @@ def __init__(self, py_type: Type[pydantic.BaseModel], namespace: Optional[str] = :param namespace: The Avro namespace to add to schemas. :param options: Schema generation options. """ - super().__init__(py_type, namespace=namespace, options=options) + super().__init__(py_type, namespace=namespace, options=options, processing=processing) if Option.USE_CLASS_ALIAS in self.options: self.name = py_type.model_config.get("title") or self.name self.py_fields = py_type.model_fields @@ -1159,6 +1248,7 @@ def _record_field(self, name: str, py_field: pydantic.fields.FieldInfo) -> Recor aliases=aliases, docs=py_field.description or "", options=self.options, + processing=self.processing, ) return field_obj @@ -1205,10 +1295,16 @@ def handles_type(cls, py_type: Type) -> bool: # If we are subclassing a string, used the "named string" approach and (inspect.isclass(py_type) and not issubclass(py_type, str)) # and any other class with typed annotations - and bool(get_type_hints(py_type)) + and has_annotations(py_type) ) - def __init__(self, py_type: Type, namespace: Optional[str] = None, options: Option = Option(0)): + def __init__( + self, + py_type: Type, + namespace: Optional[str] = None, + options: Option = Option(0), + processing: set[type] | None = None, + ): """ An Avro record schema for a plain Python class with type hints @@ -1216,7 +1312,7 @@ def __init__(self, py_type: Type, namespace: Optional[str] = None, options: Opti :param namespace: The Avro namespace to add to schemas. :param options: Schema generation options. """ - super().__init__(py_type, namespace=namespace, options=options) + super().__init__(py_type, namespace=namespace, options=options, processing=processing) py_type = _type_from_annotated(py_type) # Try to get resolved type hints, but fall back to raw annotations if there are unresolved forward refs @@ -1251,6 +1347,7 @@ def _record_field(self, py_field: tuple[str, Type]) -> RecordField: default=default, aliases=aliases, options=self.options, + processing=self.processing, ) return field_obj @@ -1271,7 +1368,13 @@ def handles_type(cls, py_type: Type) -> bool: """Whether this schema can represent a TypedDict""" return is_typeddict(py_type) - def __init__(self, py_type: Type, namespace: str | None = None, options: Option = Option(0)): + def __init__( + self, + py_type: Type, + namespace: str | None = None, + options: Option = Option(0), + processing: set[type] | None = None, + ): """ An Avro record schema for a given Python TypedDict @@ -1279,7 +1382,7 @@ def __init__(self, py_type: Type, namespace: str | None = None, options: Option :param namespace: The Avro namespace to add to schemas. :param options: Schema generation options. """ - super().__init__(py_type, namespace=namespace, options=options) + super().__init__(py_type, namespace=namespace, options=options, processing=processing) py_type = _type_from_annotated(py_type) self.is_total = py_type.__dict__.get("__total__", True) self.py_fields: dict[str, Type] = get_type_hints(py_type, include_extras=True) @@ -1306,6 +1409,7 @@ def _record_field(self, py_field: tuple[str, Type]) -> RecordField: namespace=self.namespace_override, aliases=aliases, options=self.options, + processing=self.processing, ) return field_obj @@ -1355,10 +1459,17 @@ def _is_list_any(py_type: Type) -> bool: def is_logically_json(py_type: Type) -> bool: """Returns whether a given type is logically a JSON and can be serialized as such""" - return _is_list_any(py_type) or _is_list_dict_str_any(py_type) or _is_dict_str_any(py_type) + try: + return _is_list_any(py_type) or _is_list_dict_str_any(py_type) or _is_dict_str_any(py_type) + except RecursionError: + return False -def _is_class(py_type: Any, of_types: Union[Type, Tuple[Type, ...]], include_subclasses: bool = True) -> bool: +def _is_class( + py_type: Any, + of_types: Union[Type, Tuple[Type, ...]], + include_subclasses: bool = True, +) -> bool: """Return whether the given type is a (sub) class of a type or types""" py_type = _type_from_annotated(py_type) if include_subclasses: @@ -1381,3 +1492,13 @@ def _type_from_annotated(py_type: Type) -> Type: return args[0] else: return py_type + + +def has_annotations(py_type: Type) -> bool: + """Checks if a type has annotations""" + py_type = _type_from_annotated(py_type) + try: + return bool(get_type_hints(py_type)) + except Exception: + pass + return hasattr(py_type, "__annotations__") diff --git a/tests/test_plain_class.py b/tests/test_plain_class.py index a90da2c..613b4dd 100644 --- a/tests/test_plain_class.py +++ b/tests/test_plain_class.py @@ -10,7 +10,7 @@ # specific language governing permissions and limitations under the License. import re -from typing import Annotated, Final +from typing import Annotated, Final, ForwardRef import pytest @@ -179,3 +179,30 @@ def __init__(self): } assert_schema(PyType, expected) + + +class PyType: + backend: ForwardRef("Backend") + value: str + + +class Backend: + py_type: PyType + + +def test_circular_references(): + expected = { + "fields": [ + { + "name": "py_type", + "type": { + "fields": [{"name": "backend", "type": "Backend"}, {"name": "value", "type": "string"}], + "name": "PyType", + "type": "record", + }, + } + ], + "name": "Backend", + "type": "record", + } + assert_schema(Backend, expected)