Skip to content

Commit 3af85cf

Browse files
authored
【Hackathon 9th Sprint No.86】feat: implement AgentUnittestGenerator for both paddle/torch run_model.py (#422)
* feat: add AgentUnittestGenerator class for run_model.py in both paddle and torch. * feat: refactor code to improve style, use a seperate method to inject AgentUnittestGenerator, adopt jinja2 as render engine * style: fix code style issues to pass pre-commit check * chore: revert run_model.py
1 parent 417240e commit 3af85cf

File tree

6 files changed

+570
-2
lines changed

6 files changed

+570
-2
lines changed

graph_net/paddle/run_model.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
import os
2-
import sys
32
import json
43
import base64
54
import argparse
6-
from typing import Type
75

86
os.environ["FLAGS_logging_pir_py_code_dir"] = "/tmp/dump"
97

graph_net/paddle/sample_passes/__init__.py

Whitespace-only changes.
Lines changed: 233 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,233 @@
1+
from pathlib import Path
2+
from typing import Any, Dict
3+
4+
import paddle
5+
from jinja2 import Template
6+
7+
from graph_net.sample_pass.sample_pass import SamplePass
8+
9+
10+
PADDLE_UNITTEST_TEMPLATE = r"""
11+
import importlib.util
12+
import os
13+
import unittest
14+
from typing import Any, Dict
15+
16+
import numpy as np
17+
import paddle
18+
19+
20+
def _get_classes(file_path: str):
21+
spec = importlib.util.spec_from_file_location("agent_meta", file_path)
22+
module = importlib.util.module_from_spec(spec)
23+
spec.loader.exec_module(module)
24+
return [
25+
(name, cls)
26+
for name, cls in vars(module).items()
27+
if isinstance(cls, type)
28+
]
29+
30+
31+
def _convert_meta_classes_to_wrappers(file_path: str):
32+
current_device = paddle.device.get_device()
33+
for _, cls in _get_classes(file_path):
34+
attrs = {
35+
k: v for k, v in vars(cls).items() if not k.startswith("__") and not callable(v)
36+
}
37+
dtype_attr = attrs.get("dtype", "float32")
38+
dtype = getattr(paddle, str(dtype_attr).split(".")[-1])
39+
shape = [1 if dim is None else dim for dim in attrs.get("shape", [])]
40+
info = {
41+
"shape": shape,
42+
"dtype": dtype,
43+
"device": attrs.get("device", current_device),
44+
"mean": attrs.get("mean"),
45+
"std": attrs.get("std"),
46+
"min_val": attrs.get("min_val", 0),
47+
"max_val": attrs.get("max_val", 2),
48+
}
49+
data = attrs.get("data")
50+
if data is not None and not isinstance(data, paddle.Tensor):
51+
data = paddle.to_tensor(data, dtype=dtype).reshape(info["shape"])
52+
yield {"info": info, "data": data, "name": attrs.get("name")}
53+
54+
55+
def _convert_meta_to_tensors(model_path: str):
56+
weight_meta = os.path.join(model_path, "weight_meta.py")
57+
input_meta = os.path.join(model_path, "input_meta.py")
58+
weight_info = {
59+
item["name"]: item for item in _convert_meta_classes_to_wrappers(weight_meta)
60+
}
61+
input_info = {
62+
item["name"]: item for item in _convert_meta_classes_to_wrappers(input_meta)
63+
}
64+
return {"weight_info": weight_info, "input_info": input_info}
65+
66+
67+
def _init_integer_tensor(dtype, shape, min_val, max_val, use_numpy: bool):
68+
if use_numpy:
69+
array = np.random.randint(low=min_val, high=max_val + 1, size=shape, dtype=dtype)
70+
return paddle.to_tensor(array)
71+
return paddle.randint(low=min_val, high=max_val + 1, shape=shape, dtype=dtype)
72+
73+
74+
def _init_float_tensor(shape, mean, std, min_val, max_val, use_numpy: bool):
75+
if use_numpy:
76+
if mean is not None and std is not None:
77+
array = np.random.normal(0, 1, shape) * std * 0.2 + mean
78+
array = np.clip(array, min_val, max_val)
79+
else:
80+
array = np.random.uniform(low=min_val, high=max_val, size=shape)
81+
return paddle.to_tensor(array)
82+
if mean is not None and std is not None:
83+
tensor = paddle.randn(shape, dtype="float32") * std * 0.2 + mean
84+
tensor = paddle.clip(tensor, min=min_val, max=max_val)
85+
return tensor
86+
return paddle.uniform(shape=shape, dtype="float32", min=min_val, max=max_val)
87+
88+
89+
def _replay_tensor(info: Dict[str, Any], use_numpy: bool):
90+
device = info["info"].get("device", paddle.device.get_device())
91+
dtype = info["info"].get("dtype", paddle.float32)
92+
shape = [1 if dim is None else dim for dim in info["info"].get("shape", [])]
93+
mean = info["info"].get("mean")
94+
std = info["info"].get("std")
95+
min_val = info["info"].get("min_val", 0)
96+
max_val = info["info"].get("max_val", 2)
97+
if info.get("data") is not None:
98+
return paddle.reshape(info["data"], shape).to(dtype).to(device)
99+
if dtype in [paddle.int32, paddle.int64, paddle.bool]:
100+
init_dtype = "int32" if dtype == paddle.bool else "int64"
101+
if dtype == paddle.bool:
102+
min_val, max_val = 0, 1
103+
return _init_integer_tensor(init_dtype, shape, min_val, max_val, use_numpy).to(dtype).to(device)
104+
tensor = _init_float_tensor(shape, mean, std, min_val, max_val, use_numpy)
105+
return tensor.to(dtype).to(device)
106+
107+
108+
def _get_dummy_tensor(info: Dict[str, Any]):
109+
device = info["info"].get("device", paddle.device.get_device())
110+
dtype = info["info"].get("dtype", paddle.float32)
111+
shape = [1 if dim is None else dim for dim in info["info"].get("shape", [])]
112+
if info.get("data") is not None:
113+
return paddle.reshape(info["data"], shape).to(dtype).to(device)
114+
return paddle.empty(shape=shape, dtype=dtype, device=device)
115+
116+
117+
def _load_graph_module(model_path: str):
118+
source_path = os.path.join(model_path, "model.py")
119+
spec = importlib.util.spec_from_file_location("agent_graph_module", source_path)
120+
module = importlib.util.module_from_spec(spec)
121+
spec.loader.exec_module(module)
122+
return module.GraphModule
123+
124+
125+
class AgentGraphTest(unittest.TestCase):
126+
def setUp(self):
127+
self.model_path = os.path.dirname(__file__)
128+
self.target_device = "{{ target_device }}"
129+
self.use_numpy = {{ use_numpy_flag }}
130+
paddle.set_device(self.target_device)
131+
self.GraphModule = _load_graph_module(self.model_path)
132+
self.meta = _convert_meta_to_tensors(self.model_path)
133+
134+
def _with_device(self, info: Dict[str, Any]):
135+
cloned = {"info": dict(info["info"]), "data": info.get("data")}
136+
cloned["info"]["device"] = self.target_device
137+
return cloned
138+
139+
def test_forward_runs(self):
140+
model = self.GraphModule()
141+
inputs = {k: _replay_tensor(self._with_device(v), self.use_numpy) for k, v in self.meta["input_info"].items()}
142+
params = {k: _replay_tensor(self._with_device(v), self.use_numpy) for k, v in self.meta["weight_info"].items()}
143+
model.__graph_net_file_path__ = self.model_path
144+
output = model(**params, **inputs)
145+
self.assertIsNotNone(output)
146+
147+
148+
if __name__ == "__main__":
149+
unittest.main()
150+
"""
151+
152+
153+
class AgentUnittestGenerator:
154+
"""Generate standalone unittest scripts for Paddle samples."""
155+
156+
def __init__(self, config: Dict[str, Any]):
157+
defaults = {
158+
"model_path": None,
159+
"output_path": None,
160+
"output_dir": None,
161+
"force_device": "auto", # auto / cpu / gpu
162+
"use_numpy": True,
163+
}
164+
merged = {**defaults, **(config or {})}
165+
if not merged["model_path"]:
166+
raise ValueError("AgentUnittestGenerator requires 'model_path' in config")
167+
168+
self.model_path = Path(merged["model_path"]).resolve()
169+
self.output_path = (
170+
Path(merged["output_path"]) if merged.get("output_path") else None
171+
)
172+
self.output_dir = (
173+
Path(merged["output_dir"]) if merged.get("output_dir") else None
174+
)
175+
self.force_device = merged["force_device"]
176+
self.use_numpy = merged["use_numpy"]
177+
178+
def __call__(self, model):
179+
self.generate()
180+
return model
181+
182+
def generate(self):
183+
output_path = self._resolve_output_path()
184+
target_device = self._choose_device()
185+
rendered = Template(PADDLE_UNITTEST_TEMPLATE).render(
186+
target_device=target_device, use_numpy_flag=self.use_numpy
187+
)
188+
output_path.parent.mkdir(parents=True, exist_ok=True)
189+
output_path.write_text(rendered, encoding="utf-8")
190+
print(f"[Agent] unittest generated: {output_path} (device={target_device})")
191+
192+
def _resolve_output_path(self) -> Path:
193+
if self.output_path:
194+
return self.output_path
195+
target_dir = self.output_dir or self.model_path
196+
return Path(target_dir) / f"{self.model_path.name}_test.py"
197+
198+
def _choose_device(self) -> str:
199+
if self.force_device == "cpu":
200+
return "cpu"
201+
if self.force_device == "gpu":
202+
return "gpu"
203+
return "gpu" if paddle.device.is_compiled_with_cuda() else "cpu"
204+
205+
206+
class AgentUnittestGeneratorPass(SamplePass):
207+
"""SamplePass wrapper to generate Paddle unittests via model_path_handler."""
208+
209+
def __init__(self, config=None):
210+
super().__init__(config)
211+
212+
def declare_config(
213+
self,
214+
model_path_prefix: str,
215+
output_dir: str = None,
216+
force_device: str = "auto",
217+
use_numpy: bool = True,
218+
):
219+
pass
220+
221+
def __call__(self, rel_model_path: str):
222+
model_path_prefix = Path(self.config["model_path_prefix"])
223+
target_root = Path(self.config.get("output_dir") or model_path_prefix)
224+
model_path = model_path_prefix / rel_model_path
225+
generator = AgentUnittestGenerator(
226+
{
227+
"model_path": str(model_path),
228+
"output_dir": str(target_root / rel_model_path),
229+
"force_device": self.config["force_device"],
230+
"use_numpy": self.config["use_numpy"],
231+
}
232+
)
233+
generator.generate()
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
#!/usr/bin/env bash
2+
# set -euo pipefail
3+
4+
# Smoke tests for AgentUnittestGenerator using model_path_handler + sample pass.
5+
6+
ROOT_DIR="$(cd "$(dirname "$0")/../.." && pwd)"
7+
GRAPH_NET_ROOT=$(python -c "import graph_net, os; print(os.path.dirname(os.path.dirname(graph_net.__file__)))")
8+
HANDLER_PATH="$GRAPH_NET_ROOT/graph_net/torch/sample_passes/agent_unittest_generator.py"
9+
MODEL_PATH_PREFIX="$ROOT_DIR"
10+
OUTPUT_DIR="$ROOT_DIR"
11+
12+
HANDLER_CONFIG=$(base64 -w 0 <<EOF
13+
{
14+
"handler_path": "$HANDLER_PATH",
15+
"handler_class_name": "AgentUnittestGeneratorPass",
16+
"handler_config": {
17+
"model_path_prefix": "$MODEL_PATH_PREFIX",
18+
"output_dir": "$OUTPUT_DIR",
19+
"force_device": "auto",
20+
"use_dummy_inputs": false
21+
}
22+
}
23+
EOF
24+
)
25+
26+
run_case() {
27+
local rel_sample_path="$1"
28+
local name="$2"
29+
echo "[AgentTest] running $name sample at $rel_sample_path"
30+
python -m graph_net.model_path_handler \
31+
--model-path "$rel_sample_path" \
32+
--handler-config "$HANDLER_CONFIG"
33+
}
34+
35+
run_case "samples/torchvision/resnet18" "CV (torchvision/resnet18)"
36+
run_case "samples/transformers-auto-model/albert-base-v2" "NLP (transformers-auto-model/albert-base-v2)"
37+
38+
echo "[AgentTest] done. Generated *_test.py files should now exist beside the samples."

graph_net/torch/sample_passes/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)