Skip to content

Commit 8df8792

Browse files
authored
Merge pull request #13 from ProKil/feature-forget-datatype-decorator-warning
Node Intialization with Friendly Warning
2 parents 08710c2 + 8e2f997 commit 8df8792

6 files changed

Lines changed: 178 additions & 3 deletions

File tree

docs/applications/cogym.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,6 @@
22

33
[Co-Gym](https://github.com/SALT-NLP/collaborative-gym) is a framework for enabling and evaluating **Human-Agent Collaboration** created by Yijia Shao, Vinay Samuel, Yucheng Jiang, John Yang, and Diyi Yang.
44

5-
Co-Gym uses aact as the communciation infrastructure between humans, agents, and environments to enable asynchronous action taking and observation receiving.
5+
Co-Gym uses aact as the communciation infrastructure between humans, agents, and environments to enable asynchronous action taking and observation receiving.
66

7-
Please visit their GitHub repo for details.
7+
Please visit their GitHub repo for details.
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
redis_url = "redis://localhost:6379/0" # required
2+
extra_modules = ["exception_node"]
3+
4+
[[nodes]]
5+
node_name = "exception_node"
6+
node_class = "exception_node"
7+
8+
[nodes.node_args.print_channel_types]
9+
"tick/secs/1" = "tick"
10+
11+
[[nodes]]
12+
node_name = "tick"
13+
node_class = "tick"
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from typing import AsyncIterator
2+
from aact.nodes import Node, NodeFactory
3+
from aact.messages import Text
4+
from aact import Message
5+
6+
7+
@NodeFactory.register("exception_node")
8+
class ExceptionNode(Node[Text, Text]):
9+
def event_handler(
10+
self, input_channel: str, input_message: Message[Text]
11+
) -> AsyncIterator[tuple[str, Message[Text]]]:
12+
raise Exception("This is an exception from the node.")

src/aact/nodes/base.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import logging
44

55
from ..utils import Self
6-
from typing import Any, AsyncIterator, Generic, Type, TypeVar
6+
from typing import Any, AsyncIterator, Generic, Literal, Type, TypeVar
77
from pydantic import BaseModel, ConfigDict, ValidationError
88

99
from abc import abstractmethod
@@ -176,6 +176,30 @@ def __init__(
176176
redis_url: str = "redis://localhost:6379/0",
177177
):
178178
try:
179+
for _, input_channel_type in input_channel_types:
180+
if not issubclass(input_channel_type, DataModel):
181+
raise TypeError(
182+
f"Input channel type {input_channel_type} is not a subclass of DataModel"
183+
)
184+
elif (
185+
input_channel_type.model_fields["data_type"].annotation # type: ignore[comparison-overlap]
186+
== Literal[""]
187+
):
188+
raise TypeError(
189+
f"The input channel type {input_channel_type} needs to be registered with `@DataModelFactory.register`"
190+
)
191+
for _, output_channel_type in output_channel_types:
192+
if not issubclass(output_channel_type, DataModel):
193+
raise TypeError(
194+
f"Output channel type {output_channel_type} is not a subclass of DataModel"
195+
)
196+
elif (
197+
output_channel_type.model_fields["data_type"].annotation # type: ignore[comparison-overlap]
198+
== Literal[""]
199+
):
200+
raise TypeError(
201+
f"The output channel type {output_channel_type} needs to be registered with `@DataModelFactory.register`"
202+
)
179203
BaseModel.__init__(
180204
self,
181205
input_channel_types=dict(input_channel_types),

tests/messages/test_message.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from aact.messages import DataModel, DataModelFactory
2+
3+
4+
def test_create_data_model() -> None:
5+
# Create a new data model
6+
@DataModelFactory.register("MyDataModel")
7+
class MyDataModel(DataModel):
8+
name: str
9+
age: int
10+
11+
# Create an instance of the data model
12+
instance = MyDataModel(name="John", age=30)
13+
14+
# Validate the instance
15+
assert instance.name == "John"
16+
assert instance.age == 30
17+
18+
# Check if the instance is of type DataModel
19+
assert isinstance(instance, DataModel)

tests/nodes/test_node_creation.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
import pytest
2+
from aact.messages.base import DataModel
3+
from aact.messages.registry import DataModelFactory
4+
from aact.nodes import Node
5+
from aact.nodes import NodeFactory
6+
7+
8+
def test_node_with_non_datamodel_types() -> None:
9+
"""Test that creating a node with non-DataModel types raises a TypeError."""
10+
11+
class AWrong:
12+
pass
13+
14+
class BWrong:
15+
pass
16+
17+
# Create a node with non-DataModel types
18+
@NodeFactory.register("MyNode")
19+
class MyNode(Node[AWrong, BWrong]): # type: ignore[type-var]
20+
def __init__(self) -> None:
21+
super().__init__(
22+
node_name="MyNode",
23+
input_channel_types=[("input", AWrong)],
24+
output_channel_types=[("output", BWrong)],
25+
)
26+
27+
async def event_handler( # type: ignore[override]
28+
self, input_channel: str, input_message: AWrong
29+
) -> None:
30+
# Handle the event
31+
pass
32+
33+
# Verify the correct error is raised
34+
with pytest.raises(TypeError) as excinfo:
35+
_ = MyNode()
36+
37+
assert (
38+
"Input channel type <class 'test_node_creation.test_node_with_non_datamodel_types.<locals>.AWrong'> "
39+
"is not a subclass of DataModel" in str(excinfo.value)
40+
)
41+
42+
43+
def test_node_with_unregistered_datamodel_types() -> None:
44+
"""Test that creating a node with unregistered DataModel types raises a TypeError."""
45+
46+
class ACorrect(DataModel):
47+
pass
48+
49+
class BCorrect(DataModel):
50+
pass
51+
52+
# Create a node with unregistered DataModel types
53+
@NodeFactory.register("MyNodeCorrect")
54+
class MyNodeCorrect(Node[ACorrect, BCorrect]):
55+
def __init__(self) -> None:
56+
super().__init__(
57+
node_name="MyNode",
58+
input_channel_types=[("input", ACorrect)],
59+
output_channel_types=[("output", BCorrect)],
60+
)
61+
62+
async def event_handler( # type: ignore[override]
63+
self, input_channel: str, input_message: ACorrect
64+
) -> None:
65+
# Handle the event
66+
pass
67+
68+
# Verify the correct error is raised
69+
with pytest.raises(TypeError) as excinfo:
70+
_ = MyNodeCorrect()
71+
72+
assert (
73+
"The input channel type <class 'test_node_creation.test_node_with_unregistered_datamodel_types.<locals>.ACorrect'> "
74+
"needs to be registered with `@DataModelFactory.register`" in str(excinfo.value)
75+
)
76+
77+
78+
def test_node_with_registered_datamodel_types() -> None:
79+
"""Test that creating a node with properly registered DataModel types succeeds."""
80+
81+
@DataModelFactory.register("ACorrectNew")
82+
class ACorrectNew(DataModel):
83+
pass
84+
85+
@DataModelFactory.register("BCorrectNew")
86+
class BCorrectNew(DataModel):
87+
pass
88+
89+
# Create a node with registered DataModel types
90+
@NodeFactory.register("MyNodeCorrect")
91+
class MyNodeCorrect(Node[ACorrectNew, BCorrectNew]):
92+
def __init__(self) -> None:
93+
super().__init__(
94+
node_name="MyNode",
95+
input_channel_types=[("input", ACorrectNew)],
96+
output_channel_types=[("output", BCorrectNew)],
97+
)
98+
99+
async def event_handler( # type: ignore[override]
100+
self, input_channel: str, input_message: ACorrectNew
101+
) -> None:
102+
# Handle the event
103+
pass
104+
105+
# Verify that node creation succeeds
106+
node = MyNodeCorrect()
107+
assert isinstance(node, MyNodeCorrect)

0 commit comments

Comments
 (0)