Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions penzai/core/shapecheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ class DimVar(Mapping):
and the inner name is the named shape of a single axis in that collection.
"""

name: str | tuple[str, str | int]
name: str | tuple[str, named_axes.AxisName]

def __len__(self) -> int:
return 1
Expand Down Expand Up @@ -148,7 +148,7 @@ class KnownDim:
from_keypath: Optional keypath that indicates where this size was bound.
"""

name: str | tuple[str, str | int]
name: str | tuple[str, named_axes.AxisName]
size: int
from_keypath: str | None = None

Expand Down Expand Up @@ -618,7 +618,6 @@ def _named_inline_multidimvars(
binding = solutions[key.name]
assert isinstance(binding.value, dict)
for subkey, subval in binding.value.items():
assert isinstance(subkey, str)
if subkey in new_pattern:
return (
_UnsatisfiedConstraint(
Expand Down Expand Up @@ -972,9 +971,13 @@ def add_constraints(keypath, pattern: Any, value: Any):
)
elif remaining_value_keys:
# No unpack pattern, but we ran out of keys.
# Use list() instead of sorted() because axis names may be
# non-string hashables (e.g. TmpPosAxisMarker) that are not
# mutually orderable. OrderedSet preserves insertion order, which
# gives a deterministic message without requiring comparison.
subfailures.append(
" Unexpected names in value's named shape:"
f" {sorted(remaining_value_keys)}"
f" {list(remaining_value_keys)}"
)
if subfailures:
failures.append((
Expand Down Expand Up @@ -1038,7 +1041,6 @@ def add_constraints(keypath, pattern: Any, value: Any):
):
found = solutions[name[0]].value[name[1]]
else:
assert isinstance(name[1], str)
if isinstance(solutions[name[0]].value, dict):
found = solutions[name[0]].value.get(name[1])
if found != binding.value:
Expand Down Expand Up @@ -1090,7 +1092,7 @@ class DimensionVariableSubstitution:
axis name to sizes.
"""

size_variables: dict[str | tuple[str, str | int], int | DimVar]
size_variables: dict[str | tuple[str, named_axes.AxisName], int | DimVar]
sequence_variables: dict[str, tuple[int | DimVar | MultiDimVar, ...]]
mapping_variables: dict[
str, dict[Any | MultiDimVar, int | DimVar | RemainingAxisPlaceholder]
Expand Down
98 changes: 98 additions & 0 deletions tests/core/shapecheck_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -748,6 +748,104 @@ def test_get_and_substitute_dimension_variables(self):
}
self.assertEqual(res, expected)

def test_nonstring_axis_names_in_check_structure(self):
"""Regression test for issue #132.

pz.chk.check_structure should work when named array axes use non-string
AxisNames (e.g. TmpPosAxisMarker), since AxisName = Hashable allows any
hashable as an axis name.

The crash occurs when a solved **var() MultiDimVar binding contains
non-string keys and is passed back via known_vars for inlining.
"""
nonstr_axis = pz.nx.TmpPosAxisMarker()

# Step 1: solve the batch var against an array with a non-string axis name.
input_arr = pz.nx.zeros({"in_axis": 2, nonstr_axis: 3})
input_pattern = pz.chk.ArraySpec(
named_shape={"in_axis": 2, **pz.chk.var("batch")}
)
dimvars = pz.chk.check_structure(input_arr, input_pattern)
self.assertEqual(dict(dimvars), {"batch": {nonstr_axis: 3}})

# Step 2: use those dimvars as known_vars for a second check (this
# previously crashed at shapecheck.py:621 with AssertionError).
output_arr = pz.nx.zeros({nonstr_axis: 3, "out_axis": 5})
output_pattern = pz.chk.ArraySpec(
named_shape={**pz.chk.var("batch"), "out_axis": 5}
)
result = pz.chk.check_structure(
output_arr, output_pattern, known_vars=dimvars
)
# The result includes the known_vars variables re-confirmed as consistent.
self.assertEqual(dict(result), {"batch": {nonstr_axis: 3}})

def test_linear_layer_with_tmpposaxismarker_batch(self):
"""Regression test for issue #132: Linear layer with TmpPosAxisMarker batch axis.

This is the exact repro from the issue report: calling a Linear layer on a
NamedArray that has a TmpPosAxisMarker as one of its named axes.
"""
my_layer = pz.nn.Linear.from_config(
"my_layer",
jax.random.key(0),
input_axes={"in_axis": 2},
output_axes={"out_axis": 3},
)
nonstr_batch_axis = pz.nx.TmpPosAxisMarker()
# This previously crashed with AssertionError at shapecheck.py:621.
result = my_layer(pz.nx.zeros({"in_axis": 2, nonstr_batch_axis: 3}))
self.assertIn(nonstr_batch_axis, result.named_axes)
self.assertIn("out_axis", result.named_axes)

def test_nonstring_axis_name_mismatch_error_message(self):
"""Non-string axis names in error paths should produce StructureMismatchError.

When a named shape mismatch occurs and the value contains a single
non-string AxisName not present in the pattern (the 'unexpected names'
error path), check_structure must raise StructureMismatchError rather than
crashing with TypeError inside sorted(). This exercises the error-message
path in the named constraint solver (shapecheck.py line ~976).
"""
nonstr_axis = pz.nx.TmpPosAxisMarker()

# Array has an extra non-string axis that the pattern does not mention.
arr = pz.nx.zeros({nonstr_axis: 3, "out_axis": 5})
pattern = pz.chk.ArraySpec(named_shape={"out_axis": 5})
with self.assertRaises(pz.chk.StructureMismatchError):
pz.chk.check_structure(arr, pattern)

# Non-string axis present with the wrong size.
arr2 = pz.nx.zeros({nonstr_axis: 3, "out_axis": 5})
pattern2 = pz.chk.ArraySpec(named_shape={nonstr_axis: 99, "out_axis": 5})
with self.assertRaises(pz.chk.StructureMismatchError):
pz.chk.check_structure(arr2, pattern2)

def test_multiple_nonstring_axis_names_unexpected_error_path(self):
"""Two non-orderable non-string axis names in unexpected-names error path.

When the value has extra non-string axis names not present in the pattern
and no unpack var(**) to absorb them, the error-formatting path at
shapecheck.py line ~976 previously called sorted(remaining_value_keys),
which crashes with TypeError for non-orderable types (e.g. two distinct
TmpPosAxisMarker instances). The fix replaces sorted() with list() since
OrderedSet already has deterministic insertion order.

This is in-scope for issue #132: the crash was latent on main but becomes
reachable in realistic use once non-string axis names can flow through
check_structure (which is what #132 enables).
"""
axis_a = pz.nx.TmpPosAxisMarker()
axis_b = pz.nx.TmpPosAxisMarker()

# Two non-orderable non-string axis names that are unexpected by the pattern.
# Previously crashed with TypeError inside sorted(); now must raise
# StructureMismatchError with a readable message.
arr = pz.nx.zeros({axis_a: 3, axis_b: 5, "out": 7})
pattern = pz.chk.ArraySpec(named_shape={"out": 7})
with self.assertRaises(pz.chk.StructureMismatchError):
pz.chk.check_structure(arr, pattern)


if __name__ == "__main__":
absltest.main()