From 7a91e19a4c998240eddeefe628f089589a5ccc5d Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Tue, 16 Sep 2025 09:29:14 +0800 Subject: [PATCH 01/19] Implement a function to collect the model's execution stats. --- graph_net/torch/collect_stats.py | 227 +++++++++++++++++++++++++++++++ graph_net/torch/test_compiler.py | 3 +- 2 files changed, 229 insertions(+), 1 deletion(-) create mode 100644 graph_net/torch/collect_stats.py diff --git a/graph_net/torch/collect_stats.py b/graph_net/torch/collect_stats.py new file mode 100644 index 000000000..a4cfe43d4 --- /dev/null +++ b/graph_net/torch/collect_stats.py @@ -0,0 +1,227 @@ +import argparse +import os +import sys +import math +import importlib +import inspect +from typing import Type +from dataclasses import dataclass, field +from collections import defaultdict + +import torch +from torch.fx.passes.shape_prop import ShapeProp +from graph_net.torch import utils + + +def is_single_model_dir(model_dir): + return os.path.isfile(f"{model_dir}/graph_net.json") + + +def load_class_from_file(file_path: str, class_name: str) -> Type[torch.nn.Module]: + spec = importlib.util.spec_from_file_location("unnamed", file_path) + unnamed = importlib.util.module_from_spec(spec) + spec.loader.exec_module(unnamed) + model_class = getattr(unnamed, class_name, None) + return model_class + + +def get_argument_types(model_class, func_name): + arg_types = {} + for name, func in inspect.getmembers(model_class, predicate=inspect.isfunction): + if name == func_name: + for arg_name, arg in inspect.signature(func).parameters.items(): + if arg_name != "self": + arg_types[arg_name] = ( + None if arg.annotation is inspect._empty else arg.annotation + ) + return arg_types + + +def get_input_dict(model_path, device): + inputs_params = utils.load_converted_from_text(f"{model_path}") + params = inputs_params["weight_info"] + for tensor_meta in params.values(): + if hasattr(tensor_meta, "device"): + tensor_meta.device = device + return { + k: utils.replay_tensor(v).to(torch.device(device)) for k, v in params.items() + } + + +@dataclass +class OpStat: + op_name: str + dtype: set[str] = field(default_factory=set) + count: int = 0 + + +def collect_op_stats(model, input_dict): + # Use meta tensors as input to avoid actually running the model + meta_input_dict = {} + for name, x in input_dict.items(): + meta_input_dict[name] = ( + torch.empty_like(x, device="meta") if isinstance(x, torch.Tensor) else x + ) + + # FX symbolic trace + traced = torch.fx.symbolic_trace(model) + # print(traced.graph) + + node_outputs = {} + op_stats = {} + for node in traced.graph.nodes: + op_name = None + dtype = None + if node.op == "placeholder": + node_outputs[node.name] = meta_input_dict[node.target] + op_name = node.op + dtype = node_outputs[node.name].dtype + elif node.op in ["call_function", "call_method", "call_module"]: + node_args = [] + for arg in node.args: + node_args.append( + node_outputs[arg.name] if hasattr(arg, "name") else arg + ) + node_kwargs = {} + for k, v in node.kwargs.items(): + node_kwargs[k] = node_outputs[v.name] if hasattr(v, "name") else v + + if node.op == "call_module": + # classname of module + submod = dict(traced.named_modules())[node.target] + op_name = submod.__class__.__name__ + try: + out = submod(*node_args, **node_kwargs) + node_outputs[node.name] = out + dtype = out.dtype if isinstance(out, torch.Tensor) else None + except Exception: + node_outputs[node.name] = None + elif node.op in ["call_function", "call_method"]: + op_name = ( + node.target.__name__ if node.op == 
"call_function" else node.target + ) + try: + out = node.target(*node_args, **node_kwargs) + node_outputs[node.name] = out + dtype = out.dtype if isinstance(out, torch.Tensor) else None + except Exception: + print(f"dtype inference failed: op_name={op_name}") + node_outputs[node.name] = None + elif node.op == "output": + op_name = node.op + node_args = [] + for arg in node.args: + node_args.append( + node_outputs[arg.name] if hasattr(arg, "name") else arg + ) + node_outputs[node.name] = node_args[0] if len(node_args) == 1 else node_args + dtype = ( + node_args[0].dtype if isinstance(node_args[0], torch.Tensor) else None + ) + else: + assert False, f"node.op: {node.op}" + + if op_name is not None: + dtype_str = str(dtype).replace("torch.", "") if dtype is not None else None + if op_stats.get(op_name, None) is None: + op_stats[op_name] = OpStat(op_name, {dtype_str}, 1) + else: + op_stats[op_name].dtype.add(dtype_str) + op_stats[op_name].count = op_stats[op_name].count + 1 + return op_stats + + +def collect_model_stats(model_path, device, log_prompt): + print(f"Collect information for {model_path}") + model_class = load_class_from_file( + os.path.join(model_path, "model.py"), "GraphModule" + ) + model = model_class() + input_dict = get_input_dict(model_path, device) + + num_ops = 0 + num_inputs = 0 + num_outputs = 0 + dtypes = set() + op_stats = collect_op_stats(model, input_dict) + for op_name, stat in op_stats.items(): + if op_name == "placeholder": + num_inputs += stat.count + elif op_name == "output": + num_outputs += stat.count + else: + num_ops += stat.count + for v in stat.dtype: + if v is not None: + dtypes.add(v) + + arg_types = get_argument_types(model_class, "forward") + num_params = 0 + param_dtypes = set() + for name, arg_type in arg_types.items(): + if arg_type == torch.nn.parameter.Parameter: + count = math.prod(input_dict[name].shape) + # print(f"Parameter {name}: {count}") + num_params += count + param_dtypes.add(str(input_dict[name].dtype).replace("torch.", "")) + num_params_in_billion = num_params / 1e9 + + dtypes_str = "[" + ",".join(dtypes) + "]" + param_dtypes_str = "[" + ",".join(param_dtypes) + "]" + print( + f"{log_prompt} [ModelStats] model_path:{model_path} num_inputs:{num_inputs} num_outputs:{num_outputs} num_ops:{num_ops} num_params:{num_params_in_billion}B param_dtypes:{param_dtypes_str} op_dtypes:{dtypes_str}", + file=sys.stderr, + flush=True, + ) + + +def main(args): + if args.model_path is not None: + assert os.path.isdir(args.model_path) + assert is_single_model_dir(args.model_path) + collect_model_stats(args.model_path, args.device, args.log_prompt) + else: + graph_net_samples_path = ( + (graph_net.torch.samples_util.get_default_samples_directory()) + if args.graph_net_samples_path is None + else args.graph_net_samples_path + ) + for root, dirs, files in os.walk(graph_net_samples_path): + if is_single_model_dir(root): + collect_model_stats(root, args.device, args.log_prompt) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Validate a computation graph sample. return 0 if success" + ) + parser.add_argument( + "--device", + type=str, + required=False, + default="cuda", + help="Device for testing the compiler (e.g., 'cpu' or 'cuda')", + ) + parser.add_argument( + "--model-path", + type=str, + required=False, + default=None, + help="Computation graph sample directory. 
e.g '../../samples/torch/resnet18'", + ) + parser.add_argument( + "--graph-net-samples-path", + type=str, + required=False, + default=None, + help="GraphNet samples directory. e.g '../../samples'", + ) + parser.add_argument( + "--log-prompt", + type=str, + required=False, + default="graph-net-collect-stats-log", + help="Log prompt for stats log filtering.", + ) + args = parser.parse_args() + main(args=args) diff --git a/graph_net/torch/test_compiler.py b/graph_net/torch/test_compiler.py index 5922991c5..1348c4ddc 100644 --- a/graph_net/torch/test_compiler.py +++ b/graph_net/torch/test_compiler.py @@ -1,4 +1,3 @@ -from . import utils import argparse import importlib.util import inspect @@ -14,6 +13,8 @@ import json import numpy as np import platform + +from graph_net.torch import utils from graph_net.torch.backend.graph_compiler_backend import GraphCompilerBackend from graph_net.torch.backend.tvm_backend import TvmBackend from graph_net.torch.backend.xla_backend import XlaBackend From bcf9d5a1a4c9d3fd67eb755fc0695253304e6725 Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Tue, 16 Sep 2025 11:13:12 +0800 Subject: [PATCH 02/19] Add support of get_attr and simplify some codes. --- graph_net/torch/collect_stats.py | 38 +++++++++++++++++++++----------- 1 file changed, 25 insertions(+), 13 deletions(-) diff --git a/graph_net/torch/collect_stats.py b/graph_net/torch/collect_stats.py index a4cfe43d4..d85a6edba 100644 --- a/graph_net/torch/collect_stats.py +++ b/graph_net/torch/collect_stats.py @@ -55,6 +55,14 @@ class OpStat: count: int = 0 +def resolve_get_attr(gm: torch.fx.GraphModule, node: torch.fx.Node): + attr_itr = node.target.split(".") + val = gm + for a in attr_itr: + val = getattr(val, a) + return val + + def collect_op_stats(model, input_dict): # Use meta tensors as input to avoid actually running the model meta_input_dict = {} @@ -77,14 +85,14 @@ def collect_op_stats(model, input_dict): op_name = node.op dtype = node_outputs[node.name].dtype elif node.op in ["call_function", "call_method", "call_module"]: - node_args = [] - for arg in node.args: - node_args.append( - node_outputs[arg.name] if hasattr(arg, "name") else arg - ) - node_kwargs = {} - for k, v in node.kwargs.items(): - node_kwargs[k] = node_outputs[v.name] if hasattr(v, "name") else v + node_args = torch.fx.map_arg( + node.args, + lambda n: node_outputs[n.name] if isinstance(n, torch.fx.Node) else n, + ) + node_kwargs = torch.fx.map_arg( + node.kwargs, + lambda n: node_outputs[n.name] if isinstance(n, torch.fx.Node) else n, + ) if node.op == "call_module": # classname of module @@ -107,13 +115,17 @@ def collect_op_stats(model, input_dict): except Exception: print(f"dtype inference failed: op_name={op_name}") node_outputs[node.name] = None + elif node.op == "get_attr": + val = resolve_get_attr(traced, node) + out = val.to(device="meta") if isinstance(val, torch.Tensor) else val + node_outputs[node.name] = out + dtype = out.dtype if isinstance(out, torch.Tensor) else None elif node.op == "output": op_name = node.op - node_args = [] - for arg in node.args: - node_args.append( - node_outputs[arg.name] if hasattr(arg, "name") else arg - ) + node_args = torch.fx.map_arg( + node.args, + lambda n: node_outputs[n.name] if isinstance(n, torch.fx.Node) else n, + ) node_outputs[node.name] = node_args[0] if len(node_args) == 1 else node_args dtype = ( node_args[0].dtype if isinstance(node_args[0], torch.Tensor) else None From 1415926ff842253225db942236fa2ac9009e64d5 Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Tue, 16 Sep 2025 14:24:15 
+0800 Subject: [PATCH 03/19] Fix support of call_method. --- graph_net/torch/collect_stats.py | 81 ++++++++++++++++++-------------- 1 file changed, 46 insertions(+), 35 deletions(-) diff --git a/graph_net/torch/collect_stats.py b/graph_net/torch/collect_stats.py index d85a6edba..addb6c18a 100644 --- a/graph_net/torch/collect_stats.py +++ b/graph_net/torch/collect_stats.py @@ -60,10 +60,19 @@ def resolve_get_attr(gm: torch.fx.GraphModule, node: torch.fx.Node): val = gm for a in attr_itr: val = getattr(val, a) - return val + out = val.to(device="meta") if isinstance(val, torch.Tensor) else val + return out def collect_op_stats(model, input_dict): + # FX symbolic trace + try: + traced = torch.fx.symbolic_trace(model) + # print(traced.graph) + except Exception: + print("Failed to FX symbolic trace") + return None + # Use meta tensors as input to avoid actually running the model meta_input_dict = {} for name, x in input_dict.items(): @@ -71,10 +80,6 @@ def collect_op_stats(model, input_dict): torch.empty_like(x, device="meta") if isinstance(x, torch.Tensor) else x ) - # FX symbolic trace - traced = torch.fx.symbolic_trace(model) - # print(traced.graph) - node_outputs = {} op_stats = {} for node in traced.graph.nodes: @@ -84,7 +89,7 @@ def collect_op_stats(model, input_dict): node_outputs[node.name] = meta_input_dict[node.target] op_name = node.op dtype = node_outputs[node.name].dtype - elif node.op in ["call_function", "call_method", "call_module"]: + elif node.op in ["call_function", "call_module", "call_method"]: node_args = torch.fx.map_arg( node.args, lambda n: node_outputs[n.name] if isinstance(n, torch.fx.Node) else n, @@ -96,28 +101,32 @@ def collect_op_stats(model, input_dict): if node.op == "call_module": # classname of module - submod = dict(traced.named_modules())[node.target] + submod = traced.get_submodule(node.target) op_name = submod.__class__.__name__ - try: - out = submod(*node_args, **node_kwargs) - node_outputs[node.name] = out - dtype = out.dtype if isinstance(out, torch.Tensor) else None - except Exception: - node_outputs[node.name] = None - elif node.op in ["call_function", "call_method"]: - op_name = ( - node.target.__name__ if node.op == "call_function" else node.target + op_func = submod + elif node.op == "call_function": + op_name = node.target.__name__ + op_func = node.target + elif node.op == "call_method": + op_name = node.target + self_obj = ( + node_outputs[node.args[0].name] + if isinstance(node.args[0], torch.fx.Node) + else node.args[0] ) - try: - out = node.target(*node_args, **node_kwargs) - node_outputs[node.name] = out - dtype = out.dtype if isinstance(out, torch.Tensor) else None - except Exception: - print(f"dtype inference failed: op_name={op_name}") - node_outputs[node.name] = None + op_func = getattr(self_obj, node.target) + node_args = node_args[1:] + + try: + out = op_func(*node_args, **node_kwargs) + node_outputs[node.name] = out + dtype = out.dtype if isinstance(out, torch.Tensor) else None + except Exception: + print(f"dtype inference failed: node.op={node.op}, op_name={op_name}") + node_outputs[node.name] = None elif node.op == "get_attr": - val = resolve_get_attr(traced, node) - out = val.to(device="meta") if isinstance(val, torch.Tensor) else val + op_name = node.op + out = resolve_get_attr(traced, node) node_outputs[node.name] = out dtype = out.dtype if isinstance(out, torch.Tensor) else None elif node.op == "output": @@ -156,18 +165,20 @@ def collect_model_stats(model_path, device, log_prompt): num_outputs = 0 dtypes = set() op_stats = 
collect_op_stats(model, input_dict) - for op_name, stat in op_stats.items(): - if op_name == "placeholder": - num_inputs += stat.count - elif op_name == "output": - num_outputs += stat.count - else: - num_ops += stat.count - for v in stat.dtype: - if v is not None: - dtypes.add(v) + if op_stats is not None: + for op_name, stat in op_stats.items(): + if op_name == "placeholder": + num_inputs += stat.count + elif op_name == "output": + num_outputs += stat.count + else: + num_ops += stat.count + for v in stat.dtype: + if v is not None: + dtypes.add(v) arg_types = get_argument_types(model_class, "forward") + num_inputs = len(arg_types) if op_stats is None else num_inputs num_params = 0 param_dtypes = set() for name, arg_type in arg_types.items(): From dbcadfac61fb577323ba02ba7ae9430a08787dca Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Tue, 16 Sep 2025 16:37:08 +0800 Subject: [PATCH 04/19] Support _native_multi_head_attention. --- graph_net/torch/collect_stats.py | 76 +++++++++++++++++++++----------- 1 file changed, 50 insertions(+), 26 deletions(-) diff --git a/graph_net/torch/collect_stats.py b/graph_net/torch/collect_stats.py index addb6c18a..52bdbd0b7 100644 --- a/graph_net/torch/collect_stats.py +++ b/graph_net/torch/collect_stats.py @@ -55,6 +55,20 @@ class OpStat: count: int = 0 +def resolve_native_multi_head_attention(*args, **kwargs): + query, key, value = args[0], args[1], args[2] + seq_len, batch_size, embed_dim = query.shape + attn_output = torch.empty( + (seq_len, batch_size, embed_dim), dtype=query.dtype, device="meta" + ) + + # seq_len_k = key.shape[0] + # num_heads = args[4] + # attn_output_weights = torch.empty((batch_size, num_heads, seq_len, seq_len_k), + # dtype=query.dtype, device='meta') + return attn_output # , attn_output_weights + + def resolve_get_attr(gm: torch.fx.GraphModule, node: torch.fx.Node): attr_itr = node.target.split(".") val = gm @@ -65,13 +79,13 @@ def resolve_get_attr(gm: torch.fx.GraphModule, node: torch.fx.Node): def collect_op_stats(model, input_dict): - # FX symbolic trace try: + # FX symbolic trace traced = torch.fx.symbolic_trace(model) # print(traced.graph) except Exception: print("Failed to FX symbolic trace") - return None + return False, None # Use meta tensors as input to avoid actually running the model meta_input_dict = {} @@ -80,8 +94,9 @@ def collect_op_stats(model, input_dict): torch.empty_like(x, device="meta") if isinstance(x, torch.Tensor) else x ) - node_outputs = {} + is_complete = True op_stats = {} + node_outputs = {} for node in traced.graph.nodes: op_name = None dtype = None @@ -99,31 +114,35 @@ def collect_op_stats(model, input_dict): lambda n: node_outputs[n.name] if isinstance(n, torch.fx.Node) else n, ) - if node.op == "call_module": - # classname of module - submod = traced.get_submodule(node.target) - op_name = submod.__class__.__name__ - op_func = submod - elif node.op == "call_function": - op_name = node.target.__name__ - op_func = node.target - elif node.op == "call_method": - op_name = node.target - self_obj = ( - node_outputs[node.args[0].name] - if isinstance(node.args[0], torch.fx.Node) - else node.args[0] - ) - op_func = getattr(self_obj, node.target) - node_args = node_args[1:] - try: - out = op_func(*node_args, **node_kwargs) + if node.op == "call_module": + # classname of module + submod = traced.get_submodule(node.target) + op_name = submod.__class__.__name__ + op_func = submod + elif node.op == "call_function": + op_name = node.target.__name__ + op_func = node.target + elif node.op == "call_method": + op_name 
= node.target + self_obj = ( + node_outputs[node.args[0].name] + if isinstance(node.args[0], torch.fx.Node) + else node.args[0] + ) + op_func = getattr(self_obj, node.target) + node_args = node_args[1:] + + if op_name == "_native_multi_head_attention": + out = resolve_native_multi_head_attention(*node_args, **node_kwargs) + else: + out = op_func(*node_args, **node_kwargs) node_outputs[node.name] = out dtype = out.dtype if isinstance(out, torch.Tensor) else None except Exception: print(f"dtype inference failed: node.op={node.op}, op_name={op_name}") node_outputs[node.name] = None + is_complete = False elif node.op == "get_attr": op_name = node.op out = resolve_get_attr(traced, node) @@ -149,11 +168,16 @@ def collect_op_stats(model, input_dict): else: op_stats[op_name].dtype.add(dtype_str) op_stats[op_name].count = op_stats[op_name].count + 1 - return op_stats + return is_complete, op_stats def collect_model_stats(model_path, device, log_prompt): - print(f"Collect information for {model_path}") + if not hasattr(collect_model_stats, "_counter"): + collect_model_stats._counter = 0 + else: + collect_model_stats._counter += 1 + print(f"[{collect_model_stats._counter}] Collect information for {model_path}") + model_class = load_class_from_file( os.path.join(model_path, "model.py"), "GraphModule" ) @@ -164,7 +188,7 @@ def collect_model_stats(model_path, device, log_prompt): num_inputs = 0 num_outputs = 0 dtypes = set() - op_stats = collect_op_stats(model, input_dict) + is_complete, op_stats = collect_op_stats(model, input_dict) if op_stats is not None: for op_name, stat in op_stats.items(): if op_name == "placeholder": @@ -192,7 +216,7 @@ def collect_model_stats(model_path, device, log_prompt): dtypes_str = "[" + ",".join(dtypes) + "]" param_dtypes_str = "[" + ",".join(param_dtypes) + "]" print( - f"{log_prompt} [ModelStats] model_path:{model_path} num_inputs:{num_inputs} num_outputs:{num_outputs} num_ops:{num_ops} num_params:{num_params_in_billion}B param_dtypes:{param_dtypes_str} op_dtypes:{dtypes_str}", + f"{log_prompt} [ModelStats] model_path:{model_path} num_inputs:{num_inputs} num_outputs:{num_outputs} num_ops:{num_ops} num_params:{num_params_in_billion}B param_dtypes:{param_dtypes_str} op_dtypes:{dtypes_str} is_complete:{is_complete}", file=sys.stderr, flush=True, ) From 8161df726e546799e4e02ccf11e7ffdeda70d121 Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Tue, 16 Sep 2025 22:29:32 +0800 Subject: [PATCH 05/19] Fix several ops and change to use subprocess for multiple tests. 
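
Beyond the op fixes, this commit isolates each sample in its own interpreter. A minimal sketch of the isolation pattern, with a placeholder model path (the real loop below walks the samples directory):

    import subprocess

    # Run one sample per process so a crash, OOM, or hang in a single
    # model cannot abort the whole sweep; the timeout bounds stuck traces.
    cmd = [
        "python", "-m", "graph_net.torch.collect_stats",
        "--device=cuda", "--model-path=/path/to/sample",
    ]
    result = subprocess.run(
        cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
        text=True, timeout=600,
    )
    if result.returncode == 0:
        print(result.stdout)
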
--- graph_net/torch/collect_stats.py | 75 +++++++++++++++++++++++++------- 1 file changed, 60 insertions(+), 15 deletions(-) diff --git a/graph_net/torch/collect_stats.py b/graph_net/torch/collect_stats.py index 52bdbd0b7..59ea9489d 100644 --- a/graph_net/torch/collect_stats.py +++ b/graph_net/torch/collect_stats.py @@ -4,6 +4,7 @@ import math import importlib import inspect +import subprocess from typing import Type from dataclasses import dataclass, field from collections import defaultdict @@ -62,11 +63,31 @@ def resolve_native_multi_head_attention(*args, **kwargs): (seq_len, batch_size, embed_dim), dtype=query.dtype, device="meta" ) - # seq_len_k = key.shape[0] - # num_heads = args[4] - # attn_output_weights = torch.empty((batch_size, num_heads, seq_len, seq_len_k), - # dtype=query.dtype, device='meta') - return attn_output # , attn_output_weights + # TODO(Xreki): get value from args + need_weights = False + if need_weights: + seq_len_k = key.shape[0] + num_heads = args[4] + attn_output_weights = torch.empty( + (batch_size, num_heads, seq_len, seq_len_k), + dtype=query.dtype, + device="meta", + ) + return attn_output, attn_output_weights + else: + return attn_output + + +def resolve_tensor_to(tensor, *args, **kwargs): + if isinstance(args[0], torch.dtype): + dtype = args[0] + else: + dtype = tensor.dtype + return torch.empty(tensor.shape, dtype=dtype, device="meta") + + +def resolve_tensor_item(tensor): + return torch.empty((), dtype=tensor.dtype, device="meta") def resolve_get_attr(gm: torch.fx.GraphModule, node: torch.fx.Node): @@ -115,6 +136,7 @@ def collect_op_stats(model, input_dict): ) try: + # if True: if node.op == "call_module": # classname of module submod = traced.get_submodule(node.target) @@ -133,8 +155,15 @@ def collect_op_stats(model, input_dict): op_func = getattr(self_obj, node.target) node_args = node_args[1:] + # print(f"node.op={node.op}, op_name={op_name}, node.args={node.args}") if op_name == "_native_multi_head_attention": out = resolve_native_multi_head_attention(*node_args, **node_kwargs) + elif op_name == "to": + out = resolve_tensor_to( + node_outputs[node.args[0].name], *node_args, **node_kwargs + ) + elif op_name == "item": + out = resolve_tensor_item(node_outputs[node.args[0].name]) else: out = op_func(*node_args, **node_kwargs) node_outputs[node.name] = out @@ -172,12 +201,6 @@ def collect_op_stats(model, input_dict): def collect_model_stats(model_path, device, log_prompt): - if not hasattr(collect_model_stats, "_counter"): - collect_model_stats._counter = 0 - else: - collect_model_stats._counter += 1 - print(f"[{collect_model_stats._counter}] Collect information for {model_path}") - model_class = load_class_from_file( os.path.join(model_path, "model.py"), "GraphModule" ) @@ -187,16 +210,18 @@ def collect_model_stats(model_path, device, log_prompt): num_ops = 0 num_inputs = 0 num_outputs = 0 + ops_count_info = [] dtypes = set() is_complete, op_stats = collect_op_stats(model, input_dict) if op_stats is not None: - for op_name, stat in op_stats.items(): + for op_name, stat in sorted(op_stats.items()): if op_name == "placeholder": num_inputs += stat.count elif op_name == "output": num_outputs += stat.count else: num_ops += stat.count + ops_count_info.append(f"{op_name}={stat.count}") for v in stat.dtype: if v is not None: dtypes.add(v) @@ -213,11 +238,11 @@ def collect_model_stats(model_path, device, log_prompt): param_dtypes.add(str(input_dict[name].dtype).replace("torch.", "")) num_params_in_billion = num_params / 1e9 + ops_str = "[" + 
",".join(ops_count_info) + "]" dtypes_str = "[" + ",".join(dtypes) + "]" param_dtypes_str = "[" + ",".join(param_dtypes) + "]" print( - f"{log_prompt} [ModelStats] model_path:{model_path} num_inputs:{num_inputs} num_outputs:{num_outputs} num_ops:{num_ops} num_params:{num_params_in_billion}B param_dtypes:{param_dtypes_str} op_dtypes:{dtypes_str} is_complete:{is_complete}", - file=sys.stderr, + f"{log_prompt} [ModelStats] model_path:{model_path} num_inputs:{num_inputs} num_outputs:{num_outputs} num_ops:{num_ops} num_params:{num_params_in_billion}B param_dtypes:{param_dtypes_str} op_dtypes:{dtypes_str} is_complete:{is_complete} ops:{ops_str}", flush=True, ) @@ -226,6 +251,7 @@ def main(args): if args.model_path is not None: assert os.path.isdir(args.model_path) assert is_single_model_dir(args.model_path) + print(f"Collect information for {args.model_path}") collect_model_stats(args.model_path, args.device, args.log_prompt) else: graph_net_samples_path = ( @@ -233,9 +259,28 @@ def main(args): if args.graph_net_samples_path is None else args.graph_net_samples_path ) + i = 0 for root, dirs, files in os.walk(graph_net_samples_path): if is_single_model_dir(root): - collect_model_stats(root, args.device, args.log_prompt) + print(f"[{i}] Collect information for {root}") + cmd = [ + "python", + "-m", + "graph_net.torch.collect_stats", + f"--device={args.device}", + f"--model-path={root}", + f"--log-prompt={args.log_prompt}", + ] + result = subprocess.run( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + timeout=600, + ) + if result.returncode == 0: + print(result.stdout) + i += 1 if __name__ == "__main__": From 07558f28173233f6290c98104989a39bdbe5fdc8 Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Wed, 17 Sep 2025 15:24:42 +0800 Subject: [PATCH 06/19] Support another method with make_fx. 
--- graph_net/torch/collect_stats.py | 89 ++++++++++++++++++++++++++++---- 1 file changed, 79 insertions(+), 10 deletions(-) diff --git a/graph_net/torch/collect_stats.py b/graph_net/torch/collect_stats.py index 59ea9489d..065082635 100644 --- a/graph_net/torch/collect_stats.py +++ b/graph_net/torch/collect_stats.py @@ -10,7 +10,7 @@ from collections import defaultdict import torch -from torch.fx.passes.shape_prop import ShapeProp +from functorch import make_fx from graph_net.torch import utils @@ -99,13 +99,13 @@ def resolve_get_attr(gm: torch.fx.GraphModule, node: torch.fx.Node): return out -def collect_op_stats(model, input_dict): +def collect_op_stats_manual(model, input_dict): try: # FX symbolic trace traced = torch.fx.symbolic_trace(model) # print(traced.graph) except Exception: - print("Failed to FX symbolic trace") + print("Failed to FX symbolic_trace") return False, None # Use meta tensors as input to avoid actually running the model @@ -136,7 +136,6 @@ def collect_op_stats(model, input_dict): ) try: - # if True: if node.op == "call_module": # classname of module submod = traced.get_submodule(node.target) @@ -200,23 +199,94 @@ def collect_op_stats(model, input_dict): return is_complete, op_stats +def collect_op_stats_with_make_fx(model, input_dict, arg_types): + # Use meta tensors as input to avoid actually running the model + meta_input_list = [] + for arg_name in arg_types.keys(): + x = input_dict[arg_name] + meta_x = ( + torch.empty_like(x, device="meta") if isinstance(x, torch.Tensor) else x + ) + meta_input_list.append(meta_x) + + try: + # Generate FX Graph, and automatically fill in meta information + fx_model = make_fx(model)(*meta_input_list) + except Exception: + print("Failed to execute make_fx") + return False, None + + is_complete = True + op_stats = {} + for node in fx_model.graph.nodes: + op_name = None + if node.op == "call_module": + # classname of module + submod = traced.get_submodule(node.target) + op_name = submod.__class__.__name__ + elif node.op == "call_function": + op_name = node.target.__name__ + elif node.op == "call_method": + op_name = node.target + elif node.op in ["placeholder", "output", "get_attr"]: + op_name = node.op + else: + assert False, f"node.op: {node.op}" + + dtype = None + if node.op != "output": + if "tensor_meta" in node.meta: + tensor_meta = node.meta["tensor_meta"] + dtype = tensor_meta.dtype + # print(f"node.op={node.op}, node.target={node.target}, dtype={tensor_meta.dtype}") + else: + print( + f"node.op={node.op}, node.target={node.target} has no tensor_meta!" 
+ ) + is_complete = False + + op_name = ( + op_name.replace(".default", "") + .replace(".Tensor", "") + .replace(".Scalar", "") + ) + dtype_str = str(dtype).replace("torch.", "") + if op_stats.get(op_name, None) is None: + op_stats[op_name] = OpStat(op_name, {dtype_str}, 1) + else: + op_stats[op_name].dtype.add(dtype_str) + op_stats[op_name].count = op_stats[op_name].count + 1 + return is_complete, op_stats + + +def collect_op_stats(model, input_dict, arg_types): + is_complete_manual, op_stats_manual = collect_op_stats_manual(model, input_dict) + if not is_complete_manual: + is_complete_make_fx, op_stats_make_fx = collect_op_stats_with_make_fx( + model, input_dict, arg_types + ) + if is_complete_make_fx or op_stats_manual is None: + return "make_fx", is_complete_make_fx, op_stats_make_fx + return "manual", is_complete_manual, op_stats_manual + + def collect_model_stats(model_path, device, log_prompt): model_class = load_class_from_file( os.path.join(model_path, "model.py"), "GraphModule" ) model = model_class() + arg_types = get_argument_types(model_class, "forward") input_dict = get_input_dict(model_path, device) num_ops = 0 - num_inputs = 0 num_outputs = 0 ops_count_info = [] dtypes = set() - is_complete, op_stats = collect_op_stats(model, input_dict) + method, is_complete, op_stats = collect_op_stats(model, input_dict, arg_types) if op_stats is not None: for op_name, stat in sorted(op_stats.items()): if op_name == "placeholder": - num_inputs += stat.count + pass elif op_name == "output": num_outputs += stat.count else: @@ -226,8 +296,7 @@ def collect_model_stats(model_path, device, log_prompt): if v is not None: dtypes.add(v) - arg_types = get_argument_types(model_class, "forward") - num_inputs = len(arg_types) if op_stats is None else num_inputs + num_inputs = len(arg_types) num_params = 0 param_dtypes = set() for name, arg_type in arg_types.items(): @@ -242,7 +311,7 @@ def collect_model_stats(model_path, device, log_prompt): dtypes_str = "[" + ",".join(dtypes) + "]" param_dtypes_str = "[" + ",".join(param_dtypes) + "]" print( - f"{log_prompt} [ModelStats] model_path:{model_path} num_inputs:{num_inputs} num_outputs:{num_outputs} num_ops:{num_ops} num_params:{num_params_in_billion}B param_dtypes:{param_dtypes_str} op_dtypes:{dtypes_str} is_complete:{is_complete} ops:{ops_str}", + f"{log_prompt} [ModelStats] model_path:{model_path} num_inputs:{num_inputs} num_outputs:{num_outputs} num_ops:{num_ops} num_params:{num_params_in_billion}B param_dtypes:{param_dtypes_str} op_dtypes:{dtypes_str} method:{method} is_complete:{is_complete} ops:{ops_str}", flush=True, ) From 256c75fba235aca136eb7c7ed9c570a08d94f787 Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Wed, 17 Sep 2025 15:58:47 +0800 Subject: [PATCH 07/19] Optimize the dtypes stats. 
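
With this change each OpStat keeps a per-dtype histogram instead of a flat set, so the log can report how often an op ran in each dtype. A small worked example with made-up counts, using the OpStat dataclass from this file:

    from graph_net.torch.collect_stats import OpStat

    # Before: dtype = {"float32", "float16"}   (set, counts lost)
    # After:  op_dtypes = {"float32": 3, "float16": 1}
    stat = OpStat("add", {"float32": 3}, 3)
    stat.op_dtypes["float16"] = stat.op_dtypes.get("float16", 0) + 1
    stat.count += 1
    assert stat.op_dtypes == {"float32": 3, "float16": 1}
    assert stat.count == 4
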
--- graph_net/torch/collect_stats.py | 78 ++++++++++++++++++++------------ 1 file changed, 50 insertions(+), 28 deletions(-) diff --git a/graph_net/torch/collect_stats.py b/graph_net/torch/collect_stats.py index 065082635..d9f0d9058 100644 --- a/graph_net/torch/collect_stats.py +++ b/graph_net/torch/collect_stats.py @@ -52,7 +52,7 @@ def get_input_dict(model_path, device): @dataclass class OpStat: op_name: str - dtype: set[str] = field(default_factory=set) + op_dtypes: dict[str, int] = field(default_factory=dict) count: int = 0 @@ -124,7 +124,6 @@ def collect_op_stats_manual(model, input_dict): if node.op == "placeholder": node_outputs[node.name] = meta_input_dict[node.target] op_name = node.op - dtype = node_outputs[node.name].dtype elif node.op in ["call_function", "call_module", "call_method"]: node_args = torch.fx.map_arg( node.args, @@ -190,11 +189,13 @@ def collect_op_stats_manual(model, input_dict): assert False, f"node.op: {node.op}" if op_name is not None: - dtype_str = str(dtype).replace("torch.", "") if dtype is not None else None + dtype_str = str(dtype).replace("torch.", "") if op_stats.get(op_name, None) is None: - op_stats[op_name] = OpStat(op_name, {dtype_str}, 1) + op_stats[op_name] = OpStat(op_name, {dtype_str: 1}, 1) else: - op_stats[op_name].dtype.add(dtype_str) + op_stats[op_name].op_dtypes[dtype_str] = ( + op_stats[op_name].op_dtypes.get(dtype_str, 0) + 1 + ) op_stats[op_name].count = op_stats[op_name].count + 1 return is_complete, op_stats @@ -234,7 +235,7 @@ def collect_op_stats_with_make_fx(model, input_dict, arg_types): assert False, f"node.op: {node.op}" dtype = None - if node.op != "output": + if node.op not in ["placeholder", "output"]: if "tensor_meta" in node.meta: tensor_meta = node.meta["tensor_meta"] dtype = tensor_meta.dtype @@ -252,9 +253,11 @@ def collect_op_stats_with_make_fx(model, input_dict, arg_types): ) dtype_str = str(dtype).replace("torch.", "") if op_stats.get(op_name, None) is None: - op_stats[op_name] = OpStat(op_name, {dtype_str}, 1) + op_stats[op_name] = OpStat(op_name, {dtype_str: 1}, 1) else: - op_stats[op_name].dtype.add(dtype_str) + op_stats[op_name].op_dtypes[dtype_str] = ( + op_stats[op_name].op_dtypes.get(dtype_str, 0) + 1 + ) op_stats[op_name].count = op_stats[op_name].count + 1 return is_complete, op_stats @@ -280,8 +283,8 @@ def collect_model_stats(model_path, device, log_prompt): num_ops = 0 num_outputs = 0 - ops_count_info = [] - dtypes = set() + ops_count_dict = {} + op_dtypes = {} method, is_complete, op_stats = collect_op_stats(model, input_dict, arg_types) if op_stats is not None: for op_name, stat in sorted(op_stats.items()): @@ -291,29 +294,48 @@ def collect_model_stats(model_path, device, log_prompt): num_outputs += stat.count else: num_ops += stat.count - ops_count_info.append(f"{op_name}={stat.count}") - for v in stat.dtype: - if v is not None: - dtypes.add(v) + ops_count_dict[op_name] = stat.count + for dtype_str, num in stat.op_dtypes.items(): + if dtype_str is not None and dtype_str != "None": + op_dtypes[dtype_str] = op_dtypes.get(dtype_str, 0) + num - num_inputs = len(arg_types) num_params = 0 - param_dtypes = set() + model_size = 0 + input_dtypes = {} + param_dtypes = {} for name, arg_type in arg_types.items(): if arg_type == torch.nn.parameter.Parameter: - count = math.prod(input_dict[name].shape) + param_numel = math.prod(input_dict[name].shape) # print(f"Parameter {name}: {count}") - num_params += count - param_dtypes.add(str(input_dict[name].dtype).replace("torch.", "")) - num_params_in_billion = num_params / 
1e9 - - ops_str = "[" + ",".join(ops_count_info) + "]" - dtypes_str = "[" + ",".join(dtypes) + "]" - param_dtypes_str = "[" + ",".join(param_dtypes) + "]" - print( - f"{log_prompt} [ModelStats] model_path:{model_path} num_inputs:{num_inputs} num_outputs:{num_outputs} num_ops:{num_ops} num_params:{num_params_in_billion}B param_dtypes:{param_dtypes_str} op_dtypes:{dtypes_str} method:{method} is_complete:{is_complete} ops:{ops_str}", - flush=True, - ) + num_params += 1 + model_size += param_numel + dtype_str = str(input_dict[name].dtype).replace("torch.", "") + param_dtypes[dtype_str] = param_dtypes.get(dtype_str, 0) + 1 + else: + dtype_str = str(input_dict[name].dtype).replace("torch.", "") + input_dtypes[dtype_str] = input_dtypes.get(dtype_str, 0) + 1 + model_size_in_billion = model_size / 1e9 + num_inputs = len(arg_types) - num_params + + def dict_to_string(d): + kv_list = [f"{k}={v}" for k, v in d.items()] + return "{" + ",".join(kv_list) + "}" + + log_fields = [log_prompt, "[ModelStats]"] + log_fields.append(f"model_path:{model_path}") + log_fields.append(f"num_inputs:{num_inputs}") + log_fields.append(f"num_params:{num_params}") + log_fields.append(f"num_outputs:{num_outputs}") + log_fields.append(f"num_ops:{num_ops}") + log_fields.append(f"model_size:{model_size_in_billion}B") + log_fields.append(f"input_dtypes:{dict_to_string(input_dtypes)}") + log_fields.append(f"param_dtypes:{dict_to_string(param_dtypes)}") + log_fields.append(f"op_dtypes:{dict_to_string(op_dtypes)}") + log_fields.append(f"ops:{dict_to_string(ops_count_dict)}") + log_fields.append(f"method:{method}") + log_fields.append(f"is_complete:{is_complete}") + + print(" ".join(log_fields), flush=True) def main(args): From a3fb5ae7822de54530f27f0953ff7606c6bb9cbf Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Wed, 17 Sep 2025 16:00:50 +0800 Subject: [PATCH 08/19] Enable to print error messages. --- graph_net/torch/collect_stats.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/graph_net/torch/collect_stats.py b/graph_net/torch/collect_stats.py index d9f0d9058..c679417c4 100644 --- a/graph_net/torch/collect_stats.py +++ b/graph_net/torch/collect_stats.py @@ -369,8 +369,9 @@ def main(args): text=True, timeout=600, ) - if result.returncode == 0: - print(result.stdout) + print(result.stdout) + if result.returncode != 0: + print(result.stderr) i += 1 From d10dcc370d08170cf23939e98116876395ecb427 Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Wed, 17 Sep 2025 21:56:49 +0800 Subject: [PATCH 09/19] Fix several problems. 
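
One fix here is a fallback for ops without meta support: re-run the op once on empty real tensors of the same shape and dtype, purely to learn the output dtype. A minimal sketch of the idea (the helper name is illustrative, not the exact code below):

    import torch

    def dtype_via_real_run(op, meta_x, device="cpu"):
        # No meta kernel? Materialize an uninitialized real tensor of
        # the same shape/dtype, run the op once, read the result dtype.
        real_x = torch.empty_like(meta_x, device=device)
        out = op(real_x)
        return out.dtype if isinstance(out, torch.Tensor) else None

    x = torch.empty(3, 3, device="meta")
    print(dtype_via_real_run(torch.sigmoid, x))  # torch.float32
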
--- graph_net/torch/collect_stats.py | 91 ++++++++++++++++++++++++-------- 1 file changed, 69 insertions(+), 22 deletions(-) diff --git a/graph_net/torch/collect_stats.py b/graph_net/torch/collect_stats.py index c679417c4..b560917a1 100644 --- a/graph_net/torch/collect_stats.py +++ b/graph_net/torch/collect_stats.py @@ -79,7 +79,7 @@ def resolve_native_multi_head_attention(*args, **kwargs): def resolve_tensor_to(tensor, *args, **kwargs): - if isinstance(args[0], torch.dtype): + if len(args) > 0 and isinstance(args[0], torch.dtype): dtype = args[0] else: dtype = tensor.dtype @@ -99,7 +99,40 @@ def resolve_get_attr(gm: torch.fx.GraphModule, node: torch.fx.Node): return out -def collect_op_stats_manual(model, input_dict): +def convert_real_to_meta(x): + if isinstance(x, torch.Tensor) and not x.is_meta: + return torch.empty_like(x, device="meta") + elif isinstance(x, (list, tuple)): + return type(x)(convert_real_to_meta(v) for v in x) + elif isinstance(x, dict): + return {k: convert_real_to_meta(v) for k, v in x.items()} + else: + return x + + +def convert_meta_to_real(x, device): + if isinstance(x, torch.Tensor) and x.is_meta: + return torch.empty_like(x, device=device) + elif isinstance(x, (list, tuple)): + return type(x)(convert_meta_to_real(v, device) for v in x) + elif isinstance(x, dict): + return {k: convert_meta_to_real(v, device) for k, v in x.items()} + else: + return x + + +def resolve_with_real_tensor(op_func, device, meta_args, meta_kwargs): + try: + real_args = convert_meta_to_real(meta_args, device) + real_kwargs = convert_meta_to_real(meta_kwargs, device) + + real_out = op_func(*real_args, **real_kwargs) + return convert_real_to_meta(real_out) + except Exception: + return None + + +def collect_op_stats_manual(model, input_dict, device): try: # FX symbolic trace traced = torch.fx.symbolic_trace(model) @@ -109,11 +142,19 @@ def collect_op_stats_manual(model, input_dict): return False, None # Use meta tensors as input to avoid actually running the model - meta_input_dict = {} - for name, x in input_dict.items(): - meta_input_dict[name] = ( - torch.empty_like(x, device="meta") if isinstance(x, torch.Tensor) else x - ) + meta_input_dict = convert_real_to_meta(input_dict) + + def get_output_dtype(out): + if isinstance(out, torch.Tensor): + return out.dtype + if ( + isinstance(out, (list, tuple)) + and len(out) > 0 + and isinstance(out[0], torch.Tensor) + ): + return out[0].dtype + else: + return None is_complete = True op_stats = {} @@ -157,6 +198,7 @@ def collect_op_stats_manual(model, input_dict): if op_name == "_native_multi_head_attention": out = resolve_native_multi_head_attention(*node_args, **node_kwargs) elif op_name == "to": + # print(f"node.op={node.op}, op_name={op_name}, node.args={node.args}") out = resolve_tensor_to( node_outputs[node.args[0].name], *node_args, **node_kwargs ) @@ -165,16 +207,22 @@ def collect_op_stats_manual(model, input_dict): else: out = op_func(*node_args, **node_kwargs) node_outputs[node.name] = out - dtype = out.dtype if isinstance(out, torch.Tensor) else None + dtype = get_output_dtype(out) except Exception: - print(f"dtype inference failed: node.op={node.op}, op_name={op_name}") - node_outputs[node.name] = None - is_complete = False + out = resolve_with_real_tensor(op_func, device, node_args, node_kwargs) + node_outputs[node.name] = out + if out is not None: + dtype = get_output_dtype(out) + else: + print( + f"dtype inference failed: node.op={node.op}, op_name={op_name}" + ) + is_complete = False elif node.op == "get_attr": op_name = node.op out 
= resolve_get_attr(traced, node) node_outputs[node.name] = out - dtype = out.dtype if isinstance(out, torch.Tensor) else None + dtype = get_output_dtype(out) elif node.op == "output": op_name = node.op node_args = torch.fx.map_arg( @@ -182,9 +230,7 @@ def collect_op_stats_manual(model, input_dict): lambda n: node_outputs[n.name] if isinstance(n, torch.fx.Node) else n, ) node_outputs[node.name] = node_args[0] if len(node_args) == 1 else node_args - dtype = ( - node_args[0].dtype if isinstance(node_args[0], torch.Tensor) else None - ) + dtype = get_output_dtype(node_args[0]) else: assert False, f"node.op: {node.op}" @@ -205,10 +251,7 @@ def collect_op_stats_with_make_fx(model, input_dict, arg_types): meta_input_list = [] for arg_name in arg_types.keys(): x = input_dict[arg_name] - meta_x = ( - torch.empty_like(x, device="meta") if isinstance(x, torch.Tensor) else x - ) - meta_input_list.append(meta_x) + meta_input_list.append(convert_real_to_meta(x)) try: # Generate FX Graph, and automatically fill in meta information @@ -262,8 +305,10 @@ def collect_op_stats_with_make_fx(model, input_dict, arg_types): return is_complete, op_stats -def collect_op_stats(model, input_dict, arg_types): - is_complete_manual, op_stats_manual = collect_op_stats_manual(model, input_dict) +def collect_op_stats(model, input_dict, arg_types, device): + is_complete_manual, op_stats_manual = collect_op_stats_manual( + model, input_dict, device + ) if not is_complete_manual: is_complete_make_fx, op_stats_make_fx = collect_op_stats_with_make_fx( model, input_dict, arg_types @@ -285,7 +330,9 @@ def collect_model_stats(model_path, device, log_prompt): num_outputs = 0 ops_count_dict = {} op_dtypes = {} - method, is_complete, op_stats = collect_op_stats(model, input_dict, arg_types) + method, is_complete, op_stats = collect_op_stats( + model, input_dict, arg_types, device + ) if op_stats is not None: for op_name, stat in sorted(op_stats.items()): if op_name == "placeholder": From f159a3d7d06b6c4d81fbc8549dffc6ead71e520c Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Thu, 18 Sep 2025 11:12:20 +0800 Subject: [PATCH 10/19] Support to rerun the failed cases only. --- graph_net/torch/collect_stats.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/graph_net/torch/collect_stats.py b/graph_net/torch/collect_stats.py index b560917a1..5647e6043 100644 --- a/graph_net/torch/collect_stats.py +++ b/graph_net/torch/collect_stats.py @@ -397,9 +397,21 @@ def main(args): if args.graph_net_samples_path is None else args.graph_net_samples_path ) + + previous_failed_model_pathes = [] + if args.previous_collect_result_path is not None: + with open(args.previous_collect_result_path, "r") as f: + for line in f.readlines(): + if "[ModelStats]" in line: + fields = line.strip().split() + model_path = fields[2].split(":")[-1] + is_complete = fields[-1].split(":")[-1] + if is_complete == "False": + previous_failed_model_pathes.append(model_path) + i = 0 for root, dirs, files in os.walk(graph_net_samples_path): - if is_single_model_dir(root): + if is_single_model_dir(root) and root in previous_failed_model_pathes: print(f"[{i}] Collect information for {root}") cmd = [ "python", @@ -447,6 +459,13 @@ def main(args): default=None, help="GraphNet samples directory. 
e.g '../../samples'", ) + parser.add_argument( + "--previous-collect-result-path", + type=str, + required=False, + default=None, + help="Previous collect result path, use to recollect the failed cases", + ) parser.add_argument( "--log-prompt", type=str, @@ -455,4 +474,5 @@ def main(args): help="Log prompt for stats log filtering.", ) args = parser.parse_args() + print(f"[CollectStats Arguments] {args}") main(args=args) From ee5fd22f09d886c2073da71faf9758c0833463a6 Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Thu, 18 Sep 2025 15:40:17 +0800 Subject: [PATCH 11/19] Implement method using torch.compile with customized backend. --- graph_net/paddle/validate.py | 2 - graph_net/torch/collect_stats.py | 343 +++++++++++++++++++------------ 2 files changed, 217 insertions(+), 128 deletions(-) diff --git a/graph_net/paddle/validate.py b/graph_net/paddle/validate.py index 9570d6cc2..bd9c9e377 100644 --- a/graph_net/paddle/validate.py +++ b/graph_net/paddle/validate.py @@ -36,8 +36,6 @@ def _extract_forward_source(model_path, class_name): source = f.read() tree = ast.parse(source) - forward_code = None - for node in tree.body: if isinstance(node, ast.ClassDef) and node.name == class_name: for fn in node.body: diff --git a/graph_net/torch/collect_stats.py b/graph_net/torch/collect_stats.py index 5647e6043..8d6be7a05 100644 --- a/graph_net/torch/collect_stats.py +++ b/graph_net/torch/collect_stats.py @@ -1,6 +1,7 @@ import argparse import os import sys +import ast import math import importlib import inspect @@ -26,16 +27,37 @@ def load_class_from_file(file_path: str, class_name: str) -> Type[torch.nn.Modul return model_class -def get_argument_types(model_class, func_name): - arg_types = {} +def get_argument_name_and_types(model_class, func_name): + argument_name2types = {} for name, func in inspect.getmembers(model_class, predicate=inspect.isfunction): if name == func_name: for arg_name, arg in inspect.signature(func).parameters.items(): if arg_name != "self": - arg_types[arg_name] = ( + argument_name2types[arg_name] = ( None if arg.annotation is inspect._empty else arg.annotation ) - return arg_types + return argument_name2types + + +def get_number_of_returns(file_path, class_name, func_name): + source = None + with open(f"{file_path}", "r") as f: + source = f.read() + + tree = ast.parse(source) + for node in tree.body: + if isinstance(node, ast.ClassDef) and node.name == class_name: + for f in node.body: + if isinstance(f, ast.FunctionDef) and f.name == func_name: + for stmt in ast.walk(f): + if isinstance(stmt, ast.Return): + if stmt.value is None: + return 0 + elif isinstance(stmt.value, ast.Tuple): + return len(stmt.value.elts) + else: + return 1 + return 0 def get_input_dict(model_path, device): @@ -55,6 +77,12 @@ class OpStat: op_dtypes: dict[str, int] = field(default_factory=dict) count: int = 0 + def update(self, other): + if isinstance(other, OpStat) and self.op_name == other.op_name: + self.count += other.count + for name, count in other.op_dtypes.items(): + self.op_dtypes[name] = self.op_dtypes.get(name, 0) + count + def resolve_native_multi_head_attention(*args, **kwargs): query, key, value = args[0], args[1], args[2] @@ -132,19 +160,23 @@ def resolve_with_real_tensor(op_func, device, meta_args, meta_kwargs): return None -def collect_op_stats_manual(model, input_dict, device): - try: - # FX symbolic trace - traced = torch.fx.symbolic_trace(model) - # print(traced.graph) - except Exception: - print("Failed to FX symbolic_trace") - return False, None 
+torch._dynamo.config.capture_scalar_outputs = True +torch._dynamo.config.capture_dynamic_output_shape_ops = True +torch._dynamo.config.capture_sparse_compute = True +torch._dynamo.config.raise_on_ctx_manager_usage = False +torch._dynamo.config.allow_rnn = True - # Use meta tensors as input to avoid actually running the model - meta_input_dict = convert_real_to_meta(input_dict) - def get_output_dtype(out): +class GraphMetaExecutor: + def __init__(self, device): + self.device = device + self.op_stats = {} + self.is_complete = True + self.num_ops = 0 + self.num_ops_misses_dtypes = 0 + self.subgraph_counter = 0 + + def get_output_dtype(self, out): if isinstance(out, torch.Tensor): return out.dtype if ( @@ -156,102 +188,165 @@ def get_output_dtype(out): else: return None - is_complete = True - op_stats = {} - node_outputs = {} - for node in traced.graph.nodes: + def get_op_name_and_func(self, gm, node, node_outputs): op_name = None - dtype = None - if node.op == "placeholder": - node_outputs[node.name] = meta_input_dict[node.target] - op_name = node.op - elif node.op in ["call_function", "call_module", "call_method"]: - node_args = torch.fx.map_arg( - node.args, - lambda n: node_outputs[n.name] if isinstance(n, torch.fx.Node) else n, - ) - node_kwargs = torch.fx.map_arg( - node.kwargs, - lambda n: node_outputs[n.name] if isinstance(n, torch.fx.Node) else n, - ) - - try: - if node.op == "call_module": - # classname of module - submod = traced.get_submodule(node.target) - op_name = submod.__class__.__name__ - op_func = submod - elif node.op == "call_function": - op_name = node.target.__name__ - op_func = node.target - elif node.op == "call_method": - op_name = node.target - self_obj = ( - node_outputs[node.args[0].name] - if isinstance(node.args[0], torch.fx.Node) - else node.args[0] - ) - op_func = getattr(self_obj, node.target) - node_args = node_args[1:] - - # print(f"node.op={node.op}, op_name={op_name}, node.args={node.args}") - if op_name == "_native_multi_head_attention": - out = resolve_native_multi_head_attention(*node_args, **node_kwargs) - elif op_name == "to": - # print(f"node.op={node.op}, op_name={op_name}, node.args={node.args}") - out = resolve_tensor_to( - node_outputs[node.args[0].name], *node_args, **node_kwargs - ) - elif op_name == "item": - out = resolve_tensor_item(node_outputs[node.args[0].name]) - else: - out = op_func(*node_args, **node_kwargs) - node_outputs[node.name] = out - dtype = get_output_dtype(out) - except Exception: - out = resolve_with_real_tensor(op_func, device, node_args, node_kwargs) - node_outputs[node.name] = out - if out is not None: - dtype = get_output_dtype(out) - else: - print( - f"dtype inference failed: node.op={node.op}, op_name={op_name}" - ) - is_complete = False - elif node.op == "get_attr": - op_name = node.op - out = resolve_get_attr(traced, node) - node_outputs[node.name] = out - dtype = get_output_dtype(out) - elif node.op == "output": - op_name = node.op - node_args = torch.fx.map_arg( - node.args, - lambda n: node_outputs[n.name] if isinstance(n, torch.fx.Node) else n, - ) - node_outputs[node.name] = node_args[0] if len(node_args) == 1 else node_args - dtype = get_output_dtype(node_args[0]) - else: - assert False, f"node.op: {node.op}" - + op_func = None + try: + if node.op == "call_module": + # classname of module + submod = gm.get_submodule(node.target) + op_name = submod.__class__.__name__ + op_func = submod + elif node.op == "call_function": + op_name = node.target.__name__ + op_func = node.target + elif node.op == "call_method": + 
op_name = node.target + self_obj = ( + node_outputs[node.args[0].name] + if isinstance(node.args[0], torch.fx.Node) + else node.args[0] + ) + op_func = getattr(self_obj, node.target) + elif node.op in ["get_attr", "placeholder", "output"]: + op_name = node.op + except Exception: + pass + return op_name, op_func + + def update_op_stats(self, op_stats, op_name, op_dtype): if op_name is not None: - dtype_str = str(dtype).replace("torch.", "") + dtype_str = str(op_dtype).replace("torch.", "") if op_stats.get(op_name, None) is None: op_stats[op_name] = OpStat(op_name, {dtype_str: 1}, 1) else: op_stats[op_name].op_dtypes[dtype_str] = ( op_stats[op_name].op_dtypes.get(dtype_str, 0) + 1 ) - op_stats[op_name].count = op_stats[op_name].count + 1 - return is_complete, op_stats + op_stats[op_name].count += 1 + + def __call__(self, gm: torch.fx.GraphModule, sample_inputs): + # Use meta tensors as input to avoid actually running the model + meta_sample_inputs = convert_real_to_meta(sample_inputs) + + op_stats = {} + num_ops_misses_dtypes = 0 + + input_idx = 0 + node_outputs = {} + for node in gm.graph.nodes: + out = None + op_dtype = None + op_name, op_func = self.get_op_name_and_func(gm, node, node_outputs) + if node.op == "placeholder": + out = meta_sample_inputs[input_idx] + input_idx += 1 + elif node.op in ["call_function", "call_module", "call_method"]: + try: + node_args = torch.fx.map_arg( + node.args, + lambda n: node_outputs[n.name] + if isinstance(n, torch.fx.Node) + else n, + ) + node_kwargs = torch.fx.map_arg( + node.kwargs, + lambda n: node_outputs[n.name] + if isinstance(n, torch.fx.Node) + else n, + ) + if node.op == "call_method": + node_args = node_args[1:] + + if op_name == "_native_multi_head_attention": + out = resolve_native_multi_head_attention( + *node_args, **node_kwargs + ) + elif op_name == "to": + out = resolve_tensor_to( + node_outputs[node.args[0].name], *node_args, **node_kwargs + ) + elif op_name == "item": + out = resolve_tensor_item(node_outputs[node.args[0].name]) + else: + assert op_func is not None, f"op_func of {node} is None." + out = op_func(*node_args, **node_kwargs) + except Exception: + out = resolve_with_real_tensor( + op_func, self.device, node_args, node_kwargs + ) + if out is None: + if num_ops_misses_dtypes == 0: + print( + f"dtype inference failed: node.op={node.op}, op_name={op_name}" + ) + num_ops_misses_dtypes += 1 + elif node.op == "get_attr": + out = resolve_get_attr(gm, node) + elif node.op == "output": + pass + else: + assert False, f"node.op: {node.op}" + + if out is not None: + node_outputs[node.name] = out + op_dtype = self.get_output_dtype(out) + + if node.op not in ["placeholder", "output"]: + self.update_op_stats(op_stats, op_name, op_dtype) + + if num_ops_misses_dtypes > 0: + self.is_complete = False + self.num_ops_misses_dtypes += num_ops_misses_dtypes + num_ops = 0 + for name, stat in op_stats.items(): + num_ops += stat.count + if name in self.op_stats.keys(): + self.op_stats[name].update(stat) + else: + self.op_stats[name] = stat + self.num_ops += num_ops + self.subgraph_counter += 1 + return gm.forward + + def summary(self): + print( + f"Totally {self.subgraph_counter} subgraphs, {self.num_ops} operators, and {self.num_ops_misses_dtypes} operators failed to inference dtypes." 
+ ) + + +def collect_op_stats_with_compile(model, sample_inputs, device): + assert isinstance(model, torch.nn.Module), f"{type(model)=}" + try: + meta_executor = GraphMetaExecutor(device) + compiled_model = torch.compile(model, backend=meta_executor) + compiled_model(*sample_inputs) + meta_executor.summary() + return meta_executor.is_complete, meta_executor.op_stats + except Exception: + print("Failed with torch.compile") + return False, None -def collect_op_stats_with_make_fx(model, input_dict, arg_types): +def collect_op_stats_with_symbolic_trace(model, sample_inputs, device): + assert isinstance(model, torch.nn.Module), f"{type(model)=}" + try: + # FX symbolic trace + traced = torch.fx.symbolic_trace(model) + # print(traced.graph) + except Exception: + print("Failed with symbolic_trace") + return False, None + + meta_executor = GraphMetaExecutor(device) + meta_executor(traced, sample_inputs) + meta_executor.summary() + return meta_executor.is_complete, meta_executor.op_stats + + +def collect_op_stats_with_make_fx(model, sample_inputs): # Use meta tensors as input to avoid actually running the model - meta_input_list = [] - for arg_name in arg_types.keys(): - x = input_dict[arg_name] - meta_input_list.append(convert_real_to_meta(x)) + meta_input_list = convert_real_to_meta(sample_inputs) try: # Generate FX Graph, and automatically fill in meta information @@ -266,7 +361,7 @@ def collect_op_stats_with_make_fx(model, input_dict, arg_types): op_name = None if node.op == "call_module": # classname of module - submod = traced.get_submodule(node.target) + submod = fx_model.get_submodule(node.target) op_name = submod.__class__.__name__ elif node.op == "call_function": op_name = node.target.__name__ @@ -305,41 +400,38 @@ def collect_op_stats_with_make_fx(model, input_dict, arg_types): return is_complete, op_stats -def collect_op_stats(model, input_dict, arg_types, device): - is_complete_manual, op_stats_manual = collect_op_stats_manual( - model, input_dict, device +def collect_op_stats(model, sample_inputs, device): + is_complete_symbolic, op_stats_symbolic = collect_op_stats_with_symbolic_trace( + model, sample_inputs, device ) - if not is_complete_manual: - is_complete_make_fx, op_stats_make_fx = collect_op_stats_with_make_fx( - model, input_dict, arg_types + if not is_complete_symbolic: + is_complete_compile, op_stats_compile = collect_op_stats_with_compile( + model, sample_inputs, device ) - if is_complete_make_fx or op_stats_manual is None: - return "make_fx", is_complete_make_fx, op_stats_make_fx - return "manual", is_complete_manual, op_stats_manual + if is_complete_compile or op_stats_symbolic is None: + return "torch.compile", is_complete_compile, op_stats_compile + return "symbolic_trace", is_complete_symbolic, op_stats_symbolic def collect_model_stats(model_path, device, log_prompt): - model_class = load_class_from_file( - os.path.join(model_path, "model.py"), "GraphModule" - ) + file_path = os.path.join(model_path, "model.py") + model_class = load_class_from_file(file_path, "GraphModule") model = model_class() - arg_types = get_argument_types(model_class, "forward") + argument_name2types = get_argument_name_and_types(model_class, "forward") + num_outputs = get_number_of_returns(file_path, "GraphModule", "forward") + input_dict = get_input_dict(model_path, device) + ordered_input_list = [ + input_dict[arg_name] for arg_name in argument_name2types.keys() + ] num_ops = 0 - num_outputs = 0 ops_count_dict = {} op_dtypes = {} - method, is_complete, op_stats = collect_op_stats( - model, 
input_dict, arg_types, device - ) + method, is_complete, op_stats = collect_op_stats(model, ordered_input_list, device) if op_stats is not None: for op_name, stat in sorted(op_stats.items()): - if op_name == "placeholder": - pass - elif op_name == "output": - num_outputs += stat.count - else: + if op_name not in ["placeholder", "output"]: num_ops += stat.count ops_count_dict[op_name] = stat.count for dtype_str, num in stat.op_dtypes.items(): @@ -350,7 +442,7 @@ def collect_model_stats(model_path, device, log_prompt): model_size = 0 input_dtypes = {} param_dtypes = {} - for name, arg_type in arg_types.items(): + for name, arg_type in argument_name2types.items(): if arg_type == torch.nn.parameter.Parameter: param_numel = math.prod(input_dict[name].shape) # print(f"Parameter {name}: {count}") @@ -362,7 +454,7 @@ def collect_model_stats(model_path, device, log_prompt): dtype_str = str(input_dict[name].dtype).replace("torch.", "") input_dtypes[dtype_str] = input_dtypes.get(dtype_str, 0) + 1 model_size_in_billion = model_size / 1e9 - num_inputs = len(arg_types) - num_params + num_inputs = len(argument_name2types) - num_params def dict_to_string(d): kv_list = [f"{k}={v}" for k, v in d.items()] @@ -474,5 +566,4 @@ def main(args): help="Log prompt for stats log filtering.", ) args = parser.parse_args() - print(f"[CollectStats Arguments] {args}") main(args=args) From 777f8dd52ce6dc8e7afdf55eed9e4ddd74607758 Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Thu, 18 Sep 2025 16:29:27 +0800 Subject: [PATCH 12/19] Update the log format. --- graph_net/torch/collect_stats.py | 38 +++++++++++++++++--------------- 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/graph_net/torch/collect_stats.py b/graph_net/torch/collect_stats.py index 8d6be7a05..3a17b3655 100644 --- a/graph_net/torch/collect_stats.py +++ b/graph_net/torch/collect_stats.py @@ -457,24 +457,26 @@ def collect_model_stats(model_path, device, log_prompt): num_inputs = len(argument_name2types) - num_params def dict_to_string(d): - kv_list = [f"{k}={v}" for k, v in d.items()] - return "{" + ",".join(kv_list) + "}" - - log_fields = [log_prompt, "[ModelStats]"] - log_fields.append(f"model_path:{model_path}") - log_fields.append(f"num_inputs:{num_inputs}") - log_fields.append(f"num_params:{num_params}") - log_fields.append(f"num_outputs:{num_outputs}") - log_fields.append(f"num_ops:{num_ops}") - log_fields.append(f"model_size:{model_size_in_billion}B") - log_fields.append(f"input_dtypes:{dict_to_string(input_dtypes)}") - log_fields.append(f"param_dtypes:{dict_to_string(param_dtypes)}") - log_fields.append(f"op_dtypes:{dict_to_string(op_dtypes)}") - log_fields.append(f"ops:{dict_to_string(ops_count_dict)}") - log_fields.append(f"method:{method}") - log_fields.append(f"is_complete:{is_complete}") - - print(" ".join(log_fields), flush=True) + kv_list = [f"{k}:{v}" for k, v in d.items()] + return " ".join(kv_list) + + def print_with_log_prompt(key, value): + print( + f"{log_prompt} [ModelStats.{key}] model_path:{model_path} {value}", + flush=True, + ) + + print_with_log_prompt("num_inputs", num_inputs) + print_with_log_prompt("num_params", num_params) + print_with_log_prompt("num_outputs", num_outputs) + print_with_log_prompt("num_ops", num_ops) + print_with_log_prompt("model_size", f"{model_size_in_billion}B") + print_with_log_prompt("input_dtypes", dict_to_string(input_dtypes)) + print_with_log_prompt("param_dtypes", dict_to_string(param_dtypes)) + print_with_log_prompt("op_dtypes", dict_to_string(op_dtypes)) + print_with_log_prompt("ops", 
dict_to_string(ops_count_dict)) + print_with_log_prompt("method", method) + print_with_log_prompt("is_complete", is_complete) def main(args): From 9cf8a86f2531374fa0b05ab0a65c9d3097116f36 Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Thu, 18 Sep 2025 16:50:15 +0800 Subject: [PATCH 13/19] Add source and heuristic_tag. --- graph_net/torch/collect_stats.py | 40 +++++++++++++++++++++++++++++++- 1 file changed, 39 insertions(+), 1 deletion(-) diff --git a/graph_net/torch/collect_stats.py b/graph_net/torch/collect_stats.py index 3a17b3655..6ad271912 100644 --- a/graph_net/torch/collect_stats.py +++ b/graph_net/torch/collect_stats.py @@ -3,6 +3,7 @@ import sys import ast import math +import json import importlib import inspect import subprocess @@ -60,6 +61,36 @@ def get_number_of_returns(file_path, class_name, func_name): return 0 +def read_graph_source_and_tag(model_path): + try: + with open(os.path.join(model_path, "graph_net.json"), "r") as f: + data = json.load(f) + return data["source"], data["heuristic_tag"] + except Exception: + if "cosyvoice" in model_path: + return "cosyvoice", "audio" + elif "torchaudio" in model_path: + return "torchaudio", "audio" + elif "ultralytics" in model_path: + return "ultralytics", "computer_vision" + elif "torchvision" in model_path: + return "torchvision", "computer_vision" + elif "timm" in model_path: + return "timm", "computer_vision" + elif "mmseg" in model_path: + return "mmseg", "computer_vision" + elif "mmpose" in model_path: + return "mmpose", "computer_vision" + elif "torchgeometric" in model_path: + return "torchgeometric", "other" + elif "transformers-auto-model" in model_path: + return "huggingface_hub", "unknown" + elif "nemo" in model_path: + return "nemo", "unknown" + else: + return "unknown", "unknown" + + def get_input_dict(model_path, device): inputs_params = utils.load_converted_from_text(f"{model_path}") params = inputs_params["weight_info"] @@ -456,6 +487,8 @@ def collect_model_stats(model_path, device, log_prompt): model_size_in_billion = model_size / 1e9 num_inputs = len(argument_name2types) - num_params + source, heuristic_tag = read_graph_source_and_tag(model_path) + def dict_to_string(d): kv_list = [f"{k}:{v}" for k, v in d.items()] return " ".join(kv_list) @@ -475,6 +508,8 @@ def print_with_log_prompt(key, value): print_with_log_prompt("param_dtypes", dict_to_string(param_dtypes)) print_with_log_prompt("op_dtypes", dict_to_string(op_dtypes)) print_with_log_prompt("ops", dict_to_string(ops_count_dict)) + print_with_log_prompt("source", source) + print_with_log_prompt("heuristic_tag", heuristic_tag) print_with_log_prompt("method", method) print_with_log_prompt("is_complete", is_complete) @@ -505,7 +540,10 @@ def main(args): i = 0 for root, dirs, files in os.walk(graph_net_samples_path): - if is_single_model_dir(root) and root in previous_failed_model_pathes: + if is_single_model_dir(root) and ( + args.previous_collect_result_path is None + or root in previous_failed_model_pathes + ): print(f"[{i}] Collect information for {root}") cmd = [ "python", From b1eb293a259b5d74681c0461fa2af8534e814295 Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Thu, 18 Sep 2025 19:08:07 +0800 Subject: [PATCH 14/19] Add timestamp in log. 
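
Long sweeps over a samples tree can run for hours, so prefix the progress
line with a wall-clock timestamp to make the output easy to correlate with
other logs. A minimal standalone sketch of the formatting used below
(illustrative only; the model path is made up, and strftime's "%f" yields
microseconds, so the trailing three digits are sliced off to keep
millisecond precision):

    from datetime import datetime

    timestamp_sec = datetime.now().timestamp()
    dt = datetime.fromtimestamp(timestamp_sec)
    # "%Y-%m-%d %H:%M:%S.%f" prints microseconds; [:-3] truncates to ms.
    formatted_dt = dt.strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
    print(f"[{formatted_dt}] Collect information for ../../samples/demo-model")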
--- graph_net/torch/collect_stats.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/graph_net/torch/collect_stats.py b/graph_net/torch/collect_stats.py index 6ad271912..2ea85793a 100644 --- a/graph_net/torch/collect_stats.py +++ b/graph_net/torch/collect_stats.py @@ -7,6 +7,7 @@ import importlib import inspect import subprocess +from datetime import datetime from typing import Type from dataclasses import dataclass, field from collections import defaultdict @@ -518,7 +519,10 @@ def main(args): if args.model_path is not None: assert os.path.isdir(args.model_path) assert is_single_model_dir(args.model_path) - print(f"Collect information for {args.model_path}") + timestamp_sec = datetime.now().timestamp() + dt = datetime.fromtimestamp(timestamp_sec) + formatted_dt = dt.strftime("%Y-%m-%d %H:%M:%S.%f")[:-3] + print(f"[{formatted_dt}] Collect information for {args.model_path}") collect_model_stats(args.model_path, args.device, args.log_prompt) else: graph_net_samples_path = ( From 7575957925a795d379ceaf707a0f0685b78a418d Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Fri, 19 Sep 2025 09:06:39 +0800 Subject: [PATCH 15/19] Remove the make_fx implementation. --- graph_net/torch/collect_stats.py | 64 ++------------------------------ 1 file changed, 3 insertions(+), 61 deletions(-) diff --git a/graph_net/torch/collect_stats.py b/graph_net/torch/collect_stats.py index 2ea85793a..a46965380 100644 --- a/graph_net/torch/collect_stats.py +++ b/graph_net/torch/collect_stats.py @@ -13,7 +13,6 @@ from collections import defaultdict import torch -from functorch import make_fx from graph_net.torch import utils @@ -376,62 +375,6 @@ def collect_op_stats_with_symbolic_trace(model, sample_inputs, device): return meta_executor.is_complete, meta_executor.op_stats -def collect_op_stats_with_make_fx(model, sample_inputs): - # Use meta tensors as input to avoid actually running the model - meta_input_list = convert_real_to_meta(sample_inputs) - - try: - # Generate FX Graph, and automatically fill in meta information - fx_model = make_fx(model)(*meta_input_list) - except Exception: - print("Failed to execute make_fx") - return False, None - - is_complete = True - op_stats = {} - for node in fx_model.graph.nodes: - op_name = None - if node.op == "call_module": - # classname of module - submod = fx_model.get_submodule(node.target) - op_name = submod.__class__.__name__ - elif node.op == "call_function": - op_name = node.target.__name__ - elif node.op == "call_method": - op_name = node.target - elif node.op in ["placeholder", "output", "get_attr"]: - op_name = node.op - else: - assert False, f"node.op: {node.op}" - - dtype = None - if node.op not in ["placeholder", "output"]: - if "tensor_meta" in node.meta: - tensor_meta = node.meta["tensor_meta"] - dtype = tensor_meta.dtype - # print(f"node.op={node.op}, node.target={node.target}, dtype={tensor_meta.dtype}") - else: - print( - f"node.op={node.op}, node.target={node.target} has no tensor_meta!" 
-                )
-                is_complete = False
-
-            op_name = (
-                op_name.replace(".default", "")
-                .replace(".Tensor", "")
-                .replace(".Scalar", "")
-            )
-            dtype_str = str(dtype).replace("torch.", "")
-            if op_stats.get(op_name, None) is None:
-                op_stats[op_name] = OpStat(op_name, {dtype_str: 1}, 1)
-            else:
-                op_stats[op_name].op_dtypes[dtype_str] = (
-                    op_stats[op_name].op_dtypes.get(dtype_str, 0) + 1
-                )
-            op_stats[op_name].count = op_stats[op_name].count + 1
-    return is_complete, op_stats
-
-
 def collect_op_stats(model, sample_inputs, device):
     is_complete_symbolic, op_stats_symbolic = collect_op_stats_with_symbolic_trace(
         model, sample_inputs, device
@@ -470,23 +413,22 @@ def collect_model_stats(model_path, device, log_prompt):
         if dtype_str is not None and dtype_str != "None":
             op_dtypes[dtype_str] = op_dtypes.get(dtype_str, 0) + num
 
-    num_params = 0
     model_size = 0
     input_dtypes = {}
     param_dtypes = {}
     for name, arg_type in argument_name2types.items():
         if arg_type == torch.nn.parameter.Parameter:
             param_numel = math.prod(input_dict[name].shape)
-            # print(f"Parameter {name}: {count}")
-            num_params += 1
             model_size += param_numel
             dtype_str = str(input_dict[name].dtype).replace("torch.", "")
             param_dtypes[dtype_str] = param_dtypes.get(dtype_str, 0) + 1
         else:
             dtype_str = str(input_dict[name].dtype).replace("torch.", "")
             input_dtypes[dtype_str] = input_dtypes.get(dtype_str, 0) + 1
+
-    model_size_in_billion = model_size / 1e9
-    num_inputs = len(argument_name2types) - num_params
+    model_size_in_billion = model_size / 1e9
+    num_params = len(param_dtypes)
+    num_inputs = len(input_dtypes)
 
     source, heuristic_tag = read_graph_source_and_tag(model_path)

From 0ce778f4caa0c7941798b4d48befbab6921717dc Mon Sep 17 00:00:00 2001
From: Liu Yiqun
Date: Mon, 22 Sep 2025 16:43:49 +0800
Subject: [PATCH 16/19] Add paddle implementation.

---
 graph_net/paddle/collect_stats.py | 378 ++++++++++++++++++++++++++++++
 1 file changed, 378 insertions(+)
 create mode 100644 graph_net/paddle/collect_stats.py

diff --git a/graph_net/paddle/collect_stats.py b/graph_net/paddle/collect_stats.py
new file mode 100644
index 000000000..c9e12f0b0
--- /dev/null
+++ b/graph_net/paddle/collect_stats.py
@@ -0,0 +1,378 @@
+import argparse
+import os
+import re
+import sys
+import ast
+import math
+import json
+import importlib
+import inspect
+import subprocess
+from datetime import datetime
+from typing import Type
+from dataclasses import dataclass, field
+from collections import defaultdict
+
+import paddle
+from graph_net.paddle import utils
+
+
+def is_single_model_dir(model_dir):
+    return os.path.isfile(f"{model_dir}/graph_net.json")
+
+
+def load_class_from_file(file_path: str, class_name: str) -> Type[paddle.nn.Layer]:
+    spec = importlib.util.spec_from_file_location("unnamed", file_path)
+    unnamed = importlib.util.module_from_spec(spec)
+    spec.loader.exec_module(unnamed)
+    model_class = getattr(unnamed, class_name, None)
+    return model_class
+
+
+def get_argument_name_and_types(model_class, func_name):
+    argument_name2types = {}
+    for name, func in inspect.getmembers(model_class, predicate=inspect.isfunction):
+        if name == func_name:
+            for arg_name, arg in inspect.signature(func).parameters.items():
+                if arg_name != "self":
+                    argument_name2types[arg_name] = (
+                        None if arg.annotation is inspect._empty else arg.annotation
+                    )
+    return argument_name2types
+
+
+def get_number_of_returns(file_path, class_name, func_name):
+    source = None
+    with open(f"{file_path}", "r") as f:
+        source = f.read()
+
+    tree = ast.parse(source)
+    for node in tree.body:
+        if isinstance(node, ast.ClassDef) and node.name == class_name:
+            for f
in node.body:
+                if isinstance(f, ast.FunctionDef) and f.name == func_name:
+                    for stmt in ast.walk(f):
+                        if isinstance(stmt, ast.Return):
+                            if stmt.value is None:
+                                return 0
+                            elif isinstance(stmt.value, ast.Tuple):
+                                return len(stmt.value.elts)
+                            else:
+                                return 1
+    return 0
+
+
+def read_graph_source_and_tag(model_path):
+    try:
+        with open(os.path.join(model_path, "graph_net.json"), "r") as f:
+            data = json.load(f)
+        return data["source"], data["heuristic_tag"]
+    except Exception:
+        if "PaddleX" in model_path:
+            return "PaddleX", "computer_vision"
+        elif "PaddleNLP" in model_path:
+            return "PaddleNLP", "nlp"
+        elif "PaddleScience" in model_path:
+            return "PaddleScience", "scientific_computing"
+        else:
+            return "unknown", "unknown"
+
+
+def get_input_spec(model_path):
+    inputs_params_list = utils.load_converted_list_from_text(f"{model_path}")
+    input_spec = [None] * len(inputs_params_list)
+    for i, v in enumerate(inputs_params_list):
+        dtype = v["info"]["dtype"]
+        shape = v["info"]["shape"]
+        input_spec[i] = paddle.static.InputSpec(shape, dtype)
+    return input_spec
+
+
+@dataclass
+class OpStat:
+    op_name: str
+    op_dtypes: dict[str, int] = field(default_factory=dict)
+    count: int = 0
+
+    def update(self, other):
+        if isinstance(other, OpStat) and self.op_name == other.op_name:
+            self.count += other.count
+            for name, count in other.op_dtypes.items():
+                self.op_dtypes[name] = self.op_dtypes.get(name, 0) + count
+
+
+class ProgramAnalyzer:
+    def __init__(self):
+        self.op_stats = {}
+        self.input_dict = {}
+        self.num_ops = 0
+        self.num_ops_misses_dtypes = 0
+        self.is_complete = True
+
+    def update_op_stats(self, op_name, op_dtype):
+        if op_name is not None:
+            dtype_str = str(op_dtype).replace("paddle.", "")
+            if self.op_stats.get(op_name, None) is None:
+                self.op_stats[op_name] = OpStat(op_name, {dtype_str: 1}, 1)
+            else:
+                self.op_stats[op_name].op_dtypes[dtype_str] = (
+                    self.op_stats[op_name].op_dtypes.get(dtype_str, 0) + 1
+                )
+            self.op_stats[op_name].count += 1
+
+    def parse_pir_value_dtypes(self, type_str):
+        short_form2dtype = {
+            "f32": "float32",
+            "f16": "float16",
+            "bf16": "bfloat16",
+            "i64": "int64",
+        }
+        # type_str: "vec[tensor<1x18x13x9xf32>,tensor<1x9x13x9xf32>]"
+        matches = re.findall(r"tensor<([^>]+)>", type_str)
+        dtype_strs = []
+        for s in matches:
+            parts = s.split("x")
+            assert len(parts) > 0
+
+            dtype = parts[-1].lower()
+            dtype_strs.append(short_form2dtype[dtype])
+        return dtype_strs
+
+    def __call__(self, program):
+        assert isinstance(program, paddle.base.libpaddle.pir.Program)
+
+        self.op_stats = {}
+        self.num_ops_misses_dtypes = 0
+        self.num_ops = 0
+        for block in program.blocks:
+            for op in block.ops:
+                op_name = None
+                op_dtype = None
+                if op.name() == "pd_op.data":
+                    op_name = "data"
+                    op_attrs = op.attrs()
+                    op_dtype = op_attrs["dtype"]
+                    self.input_dict[op_attrs["name"]] = {
+                        "dtype": str(op_dtype).replace("paddle.", ""),
+                        "shape": op_attrs["shape"],
+                    }
+                elif op.name().startswith("pd_op."):
+                    self.num_ops += 1
+                    op_name = op.name().replace("pd_op.", "")
+                    try:
+                        if len(op.results()) > 0:
+                            out = op.results()[0]
+                            if out.is_dense_tensor_type():
+                                op_dtype = out.dtype
+                            else:
+                                # for paddle.base.libpaddle.pir.VectorType, but cannot be accurately determined
+                                if op_name in ["split", "split_with_num", "meshgrid"]:
+                                    op_dtype = self.parse_pir_value_dtypes(
+                                        str(out.type())
+                                    )[0]
+                                else:
+                                    assert False, f"Unsupported op: {op}"
+                    except Exception:
+                        if self.num_ops_misses_dtypes == 0:
+                            print(f"dtype inference failed for {op_name}")
+                    if op_dtype is not None:
+                        
self.update_op_stats(op_name, op_dtype)
+                    else:
+                        self.num_ops_misses_dtypes += 1
+                elif not op.name().startswith("builtin."):
+                    assert False, f"Unrecognized op: {op}"
+
+        if self.num_ops_misses_dtypes > 0:
+            self.is_complete = False
+
+    def summary(self):
+        print(
+            f"Totally {self.num_ops} operators, and {self.num_ops_misses_dtypes} operators failed to infer dtypes."
+        )
+
+
+def collect_op_stats(model, model_path):
+    assert isinstance(model, paddle.nn.Layer), f"{type(model)=}"
+    try:
+        static_model = paddle.jit.to_static(
+            model,
+            input_spec=get_input_spec(model_path),
+            full_graph=True,
+            backend=None,
+        )
+        static_model.eval()
+        program = static_model.forward.concrete_program.main_program
+
+        program_analyzer = ProgramAnalyzer()
+        program_analyzer(program)
+        program_analyzer.summary()
+        return program_analyzer
+    except Exception:
+        print("Failed with to_static")
+        return None
+
+
+def collect_model_stats(model_path, log_prompt):
+    file_path = os.path.join(model_path, "model.py")
+    model_class = load_class_from_file(file_path, "GraphModule")
+    model = model_class()
+    num_outputs = get_number_of_returns(file_path, "GraphModule", "forward")
+
+    model_size = 0
+    input_dtypes = {}
+    param_dtypes = {}
+    ops_count_dict = {}
+    op_dtypes = {}
+
+    program_analyzer = collect_op_stats(model, model_path)
+    if program_analyzer is not None:
+        for op_name, stat in sorted(program_analyzer.op_stats.items()):
+            ops_count_dict[op_name] = stat.count
+            for dtype_str, num in stat.op_dtypes.items():
+                if dtype_str is not None and dtype_str != "None":
+                    op_dtypes[dtype_str] = op_dtypes.get(dtype_str, 0) + num
+
+    inputs_params = utils.load_converted_from_text(f"{model_path}")
+    params = inputs_params["weight_info"]
+    inputs = inputs_params["input_info"]
+
+    for name, value in program_analyzer.input_dict.items():
+        dtype_str = value["dtype"]
+        if name in params.keys():
+            param_numel = math.prod(value["shape"])
+            model_size += param_numel
+            param_dtypes[dtype_str] = param_dtypes.get(dtype_str, 0) + 1
+        elif name in inputs.keys():
+            input_dtypes[dtype_str] = input_dtypes.get(dtype_str, 0) + 1
+
+    model_size_in_billion = model_size / 1e9
+    num_params = sum(param_dtypes.values())
+    num_inputs = sum(input_dtypes.values())
+    num_ops = sum(ops_count_dict.values())
+    source, heuristic_tag = read_graph_source_and_tag(model_path)
+    method = "to_static"
+    is_complete = (
+        program_analyzer.is_complete if program_analyzer is not None else False
+    )
+
+    def dict_to_string(d):
+        kv_list = [f"{k}:{v}" for k, v in d.items()]
+        return " ".join(kv_list)
+
+    def print_with_log_prompt(key, value):
+        print(
+            f"{log_prompt} [ModelStats.{key}] model_path:{model_path} {value}",
+            flush=True,
+        )
+
+    print_with_log_prompt("num_inputs", num_inputs)
+    print_with_log_prompt("num_params", num_params)
+    print_with_log_prompt("num_outputs", num_outputs)
+    print_with_log_prompt("num_ops", num_ops)
+    print_with_log_prompt("model_size", f"{model_size_in_billion}B")
+    print_with_log_prompt("input_dtypes", dict_to_string(input_dtypes))
+    print_with_log_prompt("param_dtypes", dict_to_string(param_dtypes))
+    print_with_log_prompt("op_dtypes", dict_to_string(op_dtypes))
+    print_with_log_prompt("ops", dict_to_string(ops_count_dict))
+    print_with_log_prompt("source", source)
+    print_with_log_prompt("heuristic_tag", heuristic_tag)
+    print_with_log_prompt("method", method)
+    print_with_log_prompt("is_complete", is_complete)
+
+
+def main(args):
+    if args.model_path is not None:
+        assert os.path.isdir(args.model_path)
+        assert 
is_single_model_dir(args.model_path)
+        timestamp_sec = datetime.now().timestamp()
+        dt = datetime.fromtimestamp(timestamp_sec)
+        formatted_dt = dt.strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
+        print(f"[{formatted_dt}] Collect information for {args.model_path}")
+        collect_model_stats(args.model_path, args.log_prompt)
+    else:
+        graph_net_samples_path = (
+            (graph_net.paddle.samples_util.get_default_samples_directory())
+            if args.graph_net_samples_path is None
+            else args.graph_net_samples_path
+        )
+
+        previous_failed_model_pathes = []
+        if args.previous_collect_result_path is not None:
+            with open(args.previous_collect_result_path, "r") as f:
+                for line in f.readlines():
+                    if "[ModelStats]" in line:
+                        fields = line.strip().split()
+                        model_path = fields[2].split(":")[-1]
+                        is_complete = fields[-1].split(":")[-1]
+                        if is_complete == "False":
+                            previous_failed_model_pathes.append(model_path)
+
+        i = 0
+        for root, dirs, files in os.walk(graph_net_samples_path):
+            if is_single_model_dir(root) and (
+                args.previous_collect_result_path is None
+                or root in previous_failed_model_pathes
+            ):
+                print(f"[{i}] Collect information for {root}")
+                cmd = [
+                    "python",
+                    "-m",
+                    "graph_net.paddle.collect_stats",
+                    f"--device={args.device}",
+                    f"--model-path={root}",
+                    f"--log-prompt={args.log_prompt}",
+                ]
+                result = subprocess.run(
+                    cmd,
+                    stdout=subprocess.PIPE,
+                    stderr=subprocess.PIPE,
+                    text=True,
+                    timeout=600,
+                )
+                print(result.stdout)
+                if result.returncode != 0:
+                    print(result.stderr)
+                i += 1
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="Collect stats for computation graph samples. Returns 0 on success."
+    )
+    parser.add_argument(
+        "--device",
+        type=str,
+        required=False,
+        default="cuda",
+        help="Device for testing the compiler (e.g., 'cpu' or 'cuda')",
+    )
+    parser.add_argument(
+        "--model-path",
+        type=str,
+        required=False,
+        default=None,
+        help="Computation graph sample directory. e.g '../../paddle_samples/PaddleX/ResNet18'",
+    )
+    parser.add_argument(
+        "--graph-net-samples-path",
+        type=str,
+        required=False,
+        default=None,
+        help="GraphNet samples directory. e.g '../../paddle_samples'",
+    )
+    parser.add_argument(
+        "--previous-collect-result-path",
+        type=str,
+        required=False,
+        default=None,
+        help="Previous collect result path, used to recollect the failed cases",
+    )
+    parser.add_argument(
+        "--log-prompt",
+        type=str,
+        required=False,
+        default="graph-net-collect-stats-log",
+        help="Log prompt for stats log filtering.",
+    )
+    args = parser.parse_args()
+    main(args=args)

From 1a3dc301d3cb59dd28292fc9f35089872ec47b7c Mon Sep 17 00:00:00 2001
From: Liu Yiqun
Date: Mon, 22 Sep 2025 17:46:28 +0800
Subject: [PATCH 17/19] Reorganize some code.
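
Move the framework-agnostic helpers (OpStat, ModelStats, print_model_stats,
load_class_from_file, get_argument_name_and_types, get_number_of_returns)
into a shared graph_net/collect_stats_util.py so the torch and paddle
collectors stop duplicating them. A sketch of how a backend is expected to
use the shared module after this change (field values are made up for
illustration):

    from graph_net import collect_stats_util

    stats = collect_stats_util.ModelStats(
        model_path="../../paddle_samples/PaddleX/ResNet18",
        num_inputs=1,
        num_params=62,
        num_ops=77,
        model_size_in_billion=0.011,
        input_dtypes={"float32": 1},
    )
    # Prints one "<log_prompt> [ModelStats.<key>] model_path:<path> <value>"
    # line per field.
    collect_stats_util.print_model_stats(stats, log_prompt="demo-prompt")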
--- graph_net/collect_stats_util.py | 101 +++++++++++++++++++++ graph_net/paddle/collect_stats.py | 141 ++++++------------------------ 2 files changed, 129 insertions(+), 113 deletions(-) create mode 100644 graph_net/collect_stats_util.py diff --git a/graph_net/collect_stats_util.py b/graph_net/collect_stats_util.py new file mode 100644 index 000000000..72ffd666a --- /dev/null +++ b/graph_net/collect_stats_util.py @@ -0,0 +1,101 @@ +import ast +import importlib +import inspect +from dataclasses import dataclass, field +from typing import Dict + + +@dataclass +class OpStat: + op_name: str + op_dtypes: dict[str, int] = field(default_factory=dict) + count: int = 0 + + def update(self, other): + if isinstance(other, OpStat) and self.op_name == other.op_name: + self.count += other.count + for name, count in other.op_dtypes.items(): + self.op_dtypes[name] = self.op_dtypes.get(name, 0) + count + + +@dataclass +class ModelStats: + model_path: str + num_inputs: int = None + num_params: int = None + num_outputs: int = None + num_ops: int = None + model_size_in_billion: float = None + input_dtypes: Dict[str, int] = field(default_factory=dict) + param_dtypes: Dict[str, int] = field(default_factory=dict) + op_dtypes: Dict[str, int] = field(default_factory=dict) + ops: Dict[str, int] = field(default_factory=dict) + source: str = None + heuristic_tag: str = None + + +def print_model_stats(stats, log_prompt): + assert isinstance(stats, ModelStats), f"{type(stats)=}" + + def dict_to_string(d): + kv_list = [f"{k}:{v}" for k, v in d.items()] + return " ".join(kv_list) + + def print_with_log_prompt(key, value): + print( + f"{log_prompt} [ModelStats.{key}] model_path:{stats.model_path} {value}", + flush=True, + ) + + print_with_log_prompt("num_inputs", stats.num_inputs) + print_with_log_prompt("num_params", stats.num_params) + print_with_log_prompt("num_outputs", stats.num_outputs) + print_with_log_prompt("num_ops", stats.num_ops) + print_with_log_prompt("model_size", f"{stats.model_size_in_billion}B") + print_with_log_prompt("input_dtypes", dict_to_string(stats.input_dtypes)) + print_with_log_prompt("param_dtypes", dict_to_string(stats.param_dtypes)) + print_with_log_prompt("op_dtypes", dict_to_string(stats.op_dtypes)) + print_with_log_prompt("ops", dict_to_string(stats.ops)) + print_with_log_prompt("source", stats.source) + print_with_log_prompt("heuristic_tag", stats.heuristic_tag) + + +def load_class_from_file(file_path, class_name): + spec = importlib.util.spec_from_file_location("unnamed", file_path) + unnamed = importlib.util.module_from_spec(spec) + spec.loader.exec_module(unnamed) + model_class = getattr(unnamed, class_name, None) + return model_class + + +def get_argument_name_and_types(model_class, func_name): + argument_name2types = {} + for name, func in inspect.getmembers(model_class, predicate=inspect.isfunction): + if name == func_name: + for arg_name, arg in inspect.signature(func).parameters.items(): + if arg_name != "self": + argument_name2types[arg_name] = ( + None if arg.annotation is inspect._empty else arg.annotation + ) + return argument_name2types + + +def get_number_of_returns(file_path, class_name, func_name): + source = None + with open(f"{file_path}", "r") as f: + source = f.read() + + tree = ast.parse(source) + for node in tree.body: + if isinstance(node, ast.ClassDef) and node.name == class_name: + for f in node.body: + if isinstance(f, ast.FunctionDef) and f.name == func_name: + for stmt in ast.walk(f): + if isinstance(stmt, ast.Return): + if stmt.value is None: + return 0 + 
elif isinstance(stmt.value, ast.Tuple): + return len(stmt.value.elts) + else: + return 1 + return 0 diff --git a/graph_net/paddle/collect_stats.py b/graph_net/paddle/collect_stats.py index c9e12f0b0..bfee92a8c 100644 --- a/graph_net/paddle/collect_stats.py +++ b/graph_net/paddle/collect_stats.py @@ -2,17 +2,12 @@ import os import re import sys -import ast import math -import importlib -import inspect import subprocess from datetime import datetime -from typing import Type -from dataclasses import dataclass, field -from collections import defaultdict import paddle +from graph_net import collect_stats_util from graph_net.paddle import utils @@ -20,47 +15,6 @@ def is_single_model_dir(model_dir): return os.path.isfile(f"{model_dir}/graph_net.json") -def load_class_from_file(file_path: str, class_name: str) -> Type[paddle.nn.Layer]: - spec = importlib.util.spec_from_file_location("unnamed", file_path) - unnamed = importlib.util.module_from_spec(spec) - spec.loader.exec_module(unnamed) - model_class = getattr(unnamed, class_name, None) - return model_class - - -def get_argument_name_and_types(model_class, func_name): - argument_name2types = {} - for name, func in inspect.getmembers(model_class, predicate=inspect.isfunction): - if name == func_name: - for arg_name, arg in inspect.signature(func).parameters.items(): - if arg_name != "self": - argument_name2types[arg_name] = ( - None if arg.annotation is inspect._empty else arg.annotation - ) - return argument_name2types - - -def get_number_of_returns(file_path, class_name, func_name): - source = None - with open(f"{file_path}", "r") as f: - source = f.read() - - tree = ast.parse(source) - for node in tree.body: - if isinstance(node, ast.ClassDef) and node.name == class_name: - for f in node.body: - if isinstance(f, ast.FunctionDef) and f.name == func_name: - for stmt in ast.walk(f): - if isinstance(stmt, ast.Return): - if stmt.value is None: - return 0 - elif isinstance(stmt.value, ast.Tuple): - return len(stmt.value.elts) - else: - return 1 - return 0 - - def read_graph_source_and_tag(model_path): try: with open(os.path.join(model_path, "graph_net.json"), "r") as f: @@ -87,19 +41,6 @@ def get_input_spec(model_path): return input_spec -@dataclass -class OpStat: - op_name: str - op_dtypes: dict[str, int] = field(default_factory=dict) - count: int = 0 - - def update(self, other): - if isinstance(other, OpStat) and self.op_name == other.op_name: - self.count += other.count - for name, count in other.op_dtypes.items(): - self.op_dtypes[name] = self.op_dtypes.get(name, 0) + count - - class ProgramAnalyzer: def __init__(self): self.op_stats = {} @@ -112,7 +53,9 @@ def update_op_stats(self, op_name, op_dtype): if op_name is not None: dtype_str = str(op_dtype).replace("paddle.", "") if self.op_stats.get(op_name, None) is None: - self.op_stats[op_name] = OpStat(op_name, {dtype_str: 1}, 1) + self.op_stats[op_name] = collect_stats_util.OpStat( + op_name, {dtype_str: 1}, 1 + ) else: self.op_stats[op_name].op_dtypes[dtype_str] = ( self.op_stats[op_name].op_dtypes.get(dtype_str, 0) + 1 @@ -213,9 +156,8 @@ def collect_op_stats(model, model_path): def collect_model_stats(model_path, log_prompt): file_path = os.path.join(model_path, "model.py") - model_class = load_class_from_file(file_path, "GraphModule") + model_class = collect_stats_util.load_class_from_file(file_path, "GraphModule") model = model_class() - num_outputs = get_number_of_returns(file_path, "GraphModule", "forward") model_size = 0 input_dtypes = {} @@ -244,39 +186,33 @@ def 
collect_model_stats(model_path, log_prompt):
         elif name in inputs.keys():
             input_dtypes[dtype_str] = input_dtypes.get(dtype_str, 0) + 1
 
-    model_size_in_billion = model_size / 1e9
-    num_params = sum(param_dtypes.values())
-    num_inputs = sum(input_dtypes.values())
-    num_ops = sum(ops_count_dict.values())
+    num_outputs = collect_stats_util.get_number_of_returns(
+        file_path, "GraphModule", "forward"
+    )
+    num_ops = program_analyzer.num_ops if program_analyzer is not None else 0
     source, heuristic_tag = read_graph_source_and_tag(model_path)
-    method = "to_static"
     is_complete = (
         program_analyzer.is_complete if program_analyzer is not None else False
     )
+    print(
+        f"model_stats collection information: model_path={model_path}, method=to_static, is_ops_complete={is_complete}"
+    )
 
-    def dict_to_string(d):
-        kv_list = [f"{k}:{v}" for k, v in d.items()]
-        return " ".join(kv_list)
-
-    def print_with_log_prompt(key, value):
-        print(
-            f"{log_prompt} [ModelStats.{key}] model_path:{model_path} {value}",
-            flush=True,
-        )
-
-    print_with_log_prompt("num_inputs", num_inputs)
-    print_with_log_prompt("num_params", num_params)
-    print_with_log_prompt("num_outputs", num_outputs)
-    print_with_log_prompt("num_ops", num_ops)
-    print_with_log_prompt("model_size", f"{model_size_in_billion}B")
-    print_with_log_prompt("input_dtypes", dict_to_string(input_dtypes))
-    print_with_log_prompt("param_dtypes", dict_to_string(param_dtypes))
-    print_with_log_prompt("op_dtypes", dict_to_string(op_dtypes))
-    print_with_log_prompt("ops", dict_to_string(ops_count_dict))
-    print_with_log_prompt("source", source)
-    print_with_log_prompt("heuristic_tag", heuristic_tag)
-    print_with_log_prompt("method", method)
-    print_with_log_prompt("is_complete", is_complete)
+    stats = collect_stats_util.ModelStats(
+        model_path=model_path,
+        num_inputs=sum(input_dtypes.values()),
+        num_params=sum(param_dtypes.values()),
+        num_outputs=num_outputs,
+        num_ops=num_ops,
+        model_size_in_billion=model_size / 1e9,
+        input_dtypes=input_dtypes,
+        param_dtypes=param_dtypes,
+        op_dtypes=op_dtypes,
+        ops=ops_count_dict,
+        source=source,
+        heuristic_tag=heuristic_tag,
+    )
+    collect_stats_util.print_model_stats(stats, log_prompt)
 
 
 def main(args):
@@ -295,23 +231,9 @@ def main(args):
             else args.graph_net_samples_path
         )
 
-        previous_failed_model_pathes = []
-        if args.previous_collect_result_path is not None:
-            with open(args.previous_collect_result_path, "r") as f:
-                for line in f.readlines():
-                    if "[ModelStats]" in line:
-                        fields = line.strip().split()
-                        model_path = fields[2].split(":")[-1]
-                        is_complete = fields[-1].split(":")[-1]
-                        if is_complete == "False":
-                            previous_failed_model_pathes.append(model_path)
-
         i = 0
         for root, dirs, files in os.walk(graph_net_samples_path):
-            if is_single_model_dir(root) and (
-                args.previous_collect_result_path is None
-                or root in previous_failed_model_pathes
-            ):
+            if is_single_model_dir(root):
                 print(f"[{i}] Collect information for {root}")
                 cmd = [
                     "python",
@@ -359,13 +281,6 @@ def main(args):
         default=None,
         help="GraphNet samples directory. e.g '../../paddle_samples'",
     )
-    parser.add_argument(
-        "--previous-collect-result-path",
-        type=str,
-        required=False,
-        default=None,
-        help="Previous collect result path, used to recollect the failed cases",
-    )
     parser.add_argument(
         "--log-prompt",
         type=str,

From 5db711765fee4861511e3757c9093864f2692f55 Mon Sep 17 00:00:00 2001
From: Liu Yiqun
Date: Mon, 22 Sep 2025 21:05:44 +0800
Subject: [PATCH 18/19] Reorganize code.
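
Apply the same reorganization to the torch collector:
collect_op_stats_with_compile and collect_op_stats_with_symbolic_trace now
return the GraphMetaExecutor itself (or None on failure) instead of an
(is_complete, op_stats) tuple, the path-based source/heuristic_tag guesses
are replaced by reading graph_net.json, and the
--previous-collect-result-path recollection flow is removed. A sketch of
the resulting calling convention (illustrative):

    method, meta_executor = collect_op_stats(model, ordered_input_list, "cuda")
    if meta_executor is not None:
        # op_stats maps op name -> OpStat aggregated across traced subgraphs.
        print(method, meta_executor.is_complete, meta_executor.num_ops)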
--- graph_net/paddle/collect_stats.py | 7 +- graph_net/torch/collect_stats.py | 198 ++++++++---------------------- 2 files changed, 54 insertions(+), 151 deletions(-) diff --git a/graph_net/paddle/collect_stats.py b/graph_net/paddle/collect_stats.py index bfee92a8c..845cab101 100644 --- a/graph_net/paddle/collect_stats.py +++ b/graph_net/paddle/collect_stats.py @@ -108,7 +108,12 @@ def __call__(self, program): op_dtype = out.dtype else: # for paddle.base.libpaddle.pir.VectorType, but cannot be accurately determined - if op_name in ["split", "split_with_num", "meshgrid"]: + if op_name in [ + "split", + "split_with_num", + "meshgrid", + "distribute_fpn_proposals", + ]: op_dtype = self.parse_pir_value_dtypes( str(out.type()) )[0] diff --git a/graph_net/torch/collect_stats.py b/graph_net/torch/collect_stats.py index a46965380..b31a48ce8 100644 --- a/graph_net/torch/collect_stats.py +++ b/graph_net/torch/collect_stats.py @@ -1,18 +1,13 @@ import argparse import os import sys -import ast import math import json -import importlib -import inspect import subprocess from datetime import datetime -from typing import Type -from dataclasses import dataclass, field -from collections import defaultdict import torch +from graph_net import collect_stats_util from graph_net.torch import utils @@ -20,75 +15,13 @@ def is_single_model_dir(model_dir): return os.path.isfile(f"{model_dir}/graph_net.json") -def load_class_from_file(file_path: str, class_name: str) -> Type[torch.nn.Module]: - spec = importlib.util.spec_from_file_location("unnamed", file_path) - unnamed = importlib.util.module_from_spec(spec) - spec.loader.exec_module(unnamed) - model_class = getattr(unnamed, class_name, None) - return model_class - - -def get_argument_name_and_types(model_class, func_name): - argument_name2types = {} - for name, func in inspect.getmembers(model_class, predicate=inspect.isfunction): - if name == func_name: - for arg_name, arg in inspect.signature(func).parameters.items(): - if arg_name != "self": - argument_name2types[arg_name] = ( - None if arg.annotation is inspect._empty else arg.annotation - ) - return argument_name2types - - -def get_number_of_returns(file_path, class_name, func_name): - source = None - with open(f"{file_path}", "r") as f: - source = f.read() - - tree = ast.parse(source) - for node in tree.body: - if isinstance(node, ast.ClassDef) and node.name == class_name: - for f in node.body: - if isinstance(f, ast.FunctionDef) and f.name == func_name: - for stmt in ast.walk(f): - if isinstance(stmt, ast.Return): - if stmt.value is None: - return 0 - elif isinstance(stmt.value, ast.Tuple): - return len(stmt.value.elts) - else: - return 1 - return 0 - - def read_graph_source_and_tag(model_path): try: with open(os.path.join(model_path, "graph_net.json"), "r") as f: data = json.load(f) return data["source"], data["heuristic_tag"] except Exception: - if "cosyvoice" in model_path: - return "cosyvoice", "audio" - elif "torchaudio" in model_path: - return "torchaudio", "audio" - elif "ultralytics" in model_path: - return "ultralytics", "computer_vision" - elif "torchvision" in model_path: - return "torchvision", "computer_vision" - elif "timm" in model_path: - return "timm", "computer_vision" - elif "mmseg" in model_path: - return "mmseg", "computer_vision" - elif "mmpose" in model_path: - return "mmpose", "computer_vision" - elif "torchgeometric" in model_path: - return "torchgeometric", "other" - elif "transformers-auto-model" in model_path: - return "huggingface_hub", "unknown" - elif "nemo" in model_path: - 
return "nemo", "unknown" - else: - return "unknown", "unknown" + return "unknown", "unknown" def get_input_dict(model_path, device): @@ -102,19 +35,6 @@ def get_input_dict(model_path, device): } -@dataclass -class OpStat: - op_name: str - op_dtypes: dict[str, int] = field(default_factory=dict) - count: int = 0 - - def update(self, other): - if isinstance(other, OpStat) and self.op_name == other.op_name: - self.count += other.count - for name, count in other.op_dtypes.items(): - self.op_dtypes[name] = self.op_dtypes.get(name, 0) + count - - def resolve_native_multi_head_attention(*args, **kwargs): query, key, value = args[0], args[1], args[2] seq_len, batch_size, embed_dim = query.shape @@ -249,7 +169,9 @@ def update_op_stats(self, op_stats, op_name, op_dtype): if op_name is not None: dtype_str = str(op_dtype).replace("torch.", "") if op_stats.get(op_name, None) is None: - op_stats[op_name] = OpStat(op_name, {dtype_str: 1}, 1) + op_stats[op_name] = collect_stats_util.OpStat( + op_name, {dtype_str: 1}, 1 + ) else: op_stats[op_name].op_dtypes[dtype_str] = ( op_stats[op_name].op_dtypes.get(dtype_str, 0) + 1 @@ -353,10 +275,10 @@ def collect_op_stats_with_compile(model, sample_inputs, device): compiled_model = torch.compile(model, backend=meta_executor) compiled_model(*sample_inputs) meta_executor.summary() - return meta_executor.is_complete, meta_executor.op_stats + return meta_executor except Exception: print("Failed with torch.compile") - return False, None + return None def collect_op_stats_with_symbolic_trace(model, sample_inputs, device): @@ -364,50 +286,50 @@ def collect_op_stats_with_symbolic_trace(model, sample_inputs, device): try: # FX symbolic trace traced = torch.fx.symbolic_trace(model) - # print(traced.graph) except Exception: print("Failed with symbolic_trace") - return False, None + return None meta_executor = GraphMetaExecutor(device) meta_executor(traced, sample_inputs) meta_executor.summary() - return meta_executor.is_complete, meta_executor.op_stats + return meta_executor def collect_op_stats(model, sample_inputs, device): - is_complete_symbolic, op_stats_symbolic = collect_op_stats_with_symbolic_trace( + meta_executor_symbolic = collect_op_stats_with_symbolic_trace( model, sample_inputs, device ) - if not is_complete_symbolic: - is_complete_compile, op_stats_compile = collect_op_stats_with_compile( + if meta_executor_symbolic is None or not meta_executor_symbolic.is_complete: + meta_executor_compile = collect_op_stats_with_compile( model, sample_inputs, device ) - if is_complete_compile or op_stats_symbolic is None: - return "torch.compile", is_complete_compile, op_stats_compile - return "symbolic_trace", is_complete_symbolic, op_stats_symbolic + if meta_executor_symbolic is None or ( + meta_executor_compile is not None and meta_executor_compile.is_complete + ): + return "torch.compile", meta_executor_compile + return "symbolic_trace", meta_executor_symbolic def collect_model_stats(model_path, device, log_prompt): file_path = os.path.join(model_path, "model.py") - model_class = load_class_from_file(file_path, "GraphModule") + model_class = collect_stats_util.load_class_from_file(file_path, "GraphModule") model = model_class() - argument_name2types = get_argument_name_and_types(model_class, "forward") - num_outputs = get_number_of_returns(file_path, "GraphModule", "forward") + argument_name2types = collect_stats_util.get_argument_name_and_types( + model_class, "forward" + ) input_dict = get_input_dict(model_path, device) ordered_input_list = [ input_dict[arg_name] for 
arg_name in argument_name2types.keys() ] - num_ops = 0 ops_count_dict = {} op_dtypes = {} - method, is_complete, op_stats = collect_op_stats(model, ordered_input_list, device) - if op_stats is not None: - for op_name, stat in sorted(op_stats.items()): + method, meta_executor = collect_op_stats(model, ordered_input_list, device) + if meta_executor is not None: + for op_name, stat in sorted(meta_executor.op_stats.items()): if op_name not in ["placeholder", "output"]: - num_ops += stat.count ops_count_dict[op_name] = stat.count for dtype_str, num in stat.op_dtypes.items(): if dtype_str is not None and dtype_str != "None": @@ -426,35 +348,32 @@ def collect_model_stats(model_path, device, log_prompt): dtype_str = str(input_dict[name].dtype).replace("torch.", "") input_dtypes[dtype_str] = input_dtypes.get(dtype_str, 0) + 1 - model_size_in_billion = model_size / 1e9 - num_params = len(param_dtypes) - num_inputs = len(input_dtypes) - + num_outputs = collect_stats_util.get_number_of_returns( + file_path, "GraphModule", "forward" + ) + num_ops = meta_executor.num_ops if meta_executor is not None else 0 source, heuristic_tag = read_graph_source_and_tag(model_path) - def dict_to_string(d): - kv_list = [f"{k}:{v}" for k, v in d.items()] - return " ".join(kv_list) - - def print_with_log_prompt(key, value): - print( - f"{log_prompt} [ModelStats.{key}] model_path:{model_path} {value}", - flush=True, - ) + is_complete = meta_executor.is_complete if meta_executor is not None else False + print( + f"model_stats collection information: model_path={model_path}, method={method}, is_ops_complete={is_complete}" + ) - print_with_log_prompt("num_inputs", num_inputs) - print_with_log_prompt("num_params", num_params) - print_with_log_prompt("num_outputs", num_outputs) - print_with_log_prompt("num_ops", num_ops) - print_with_log_prompt("model_size", f"{model_size_in_billion}B") - print_with_log_prompt("input_dtypes", dict_to_string(input_dtypes)) - print_with_log_prompt("param_dtypes", dict_to_string(param_dtypes)) - print_with_log_prompt("op_dtypes", dict_to_string(op_dtypes)) - print_with_log_prompt("ops", dict_to_string(ops_count_dict)) - print_with_log_prompt("source", source) - print_with_log_prompt("heuristic_tag", heuristic_tag) - print_with_log_prompt("method", method) - print_with_log_prompt("is_complete", is_complete) + stats = collect_stats_util.ModelStats( + model_path=model_path, + num_inputs=sum(input_dtypes.values()), + num_params=sum(param_dtypes.values()), + num_outputs=num_outputs, + num_ops=num_ops, + model_size_in_billion=model_size / 1e9, + input_dtypes=input_dtypes, + param_dtypes=param_dtypes, + op_dtypes=op_dtypes, + ops=ops_count_dict, + source=source, + heuristic_tag=heuristic_tag, + ) + collect_stats_util.print_model_stats(stats, log_prompt) def main(args): @@ -473,23 +392,9 @@ def main(args): else args.graph_net_samples_path ) - previous_failed_model_pathes = [] - if args.previous_collect_result_path is not None: - with open(args.previous_collect_result_path, "r") as f: - for line in f.readlines(): - if "[ModelStats]" in line: - fields = line.strip().split() - model_path = fields[2].split(":")[-1] - is_complete = fields[-1].split(":")[-1] - if is_complete == "False": - previous_failed_model_pathes.append(model_path) - i = 0 for root, dirs, files in os.walk(graph_net_samples_path): - if is_single_model_dir(root) and ( - args.previous_collect_result_path is None - or root in previous_failed_model_pathes - ): + if is_single_model_dir(root): print(f"[{i}] Collect information for {root}") cmd = 
[
                     "python",
@@ -537,13 +442,6 @@ def main(args):
         default=None,
         help="GraphNet samples directory. e.g '../../samples'",
     )
-    parser.add_argument(
-        "--previous-collect-result-path",
-        type=str,
-        required=False,
-        default=None,
-        help="Previous collect result path, use to recollect the failed cases",
-    )
     parser.add_argument(
         "--log-prompt",
         type=str,

From 6a5b8db218d3e489a1607e0004fb972f1446be24 Mon Sep 17 00:00:00 2001
From: Liu Yiqun
Date: Tue, 23 Sep 2025 15:21:12 +0800
Subject: [PATCH 19/19] Support collecting input shapes and use json to dump
 lists and dicts.
---
 graph_net/collect_stats_util.py   | 15 +++++++--------
 graph_net/paddle/collect_stats.py | 10 +++++++---
 graph_net/torch/collect_stats.py  | 13 ++++++++++---
 3 files changed, 24 insertions(+), 14 deletions(-)

diff --git a/graph_net/collect_stats_util.py b/graph_net/collect_stats_util.py
index 72ffd666a..fe98579d8 100644
--- a/graph_net/collect_stats_util.py
+++ b/graph_net/collect_stats_util.py
@@ -1,4 +1,5 @@
 import ast
+import json
 import importlib
 import inspect
 from dataclasses import dataclass, field
@@ -28,6 +29,7 @@ class ModelStats:
     model_size_in_billion: float = None
     input_dtypes: Dict[str, int] = field(default_factory=dict)
     param_dtypes: Dict[str, int] = field(default_factory=dict)
+    input_shapes: Dict[str, list] = field(default_factory=dict)
     op_dtypes: Dict[str, int] = field(default_factory=dict)
     ops: Dict[str, int] = field(default_factory=dict)
     source: str = None
@@ -37,10 +39,6 @@ class ModelStats:
 def print_model_stats(stats, log_prompt):
     assert isinstance(stats, ModelStats), f"{type(stats)=}"
 
-    def dict_to_string(d):
-        kv_list = [f"{k}:{v}" for k, v in d.items()]
-        return " ".join(kv_list)
-
     def print_with_log_prompt(key, value):
         print(
             f"{log_prompt} [ModelStats.{key}] model_path:{stats.model_path} {value}",
@@ -52,10 +50,11 @@ def print_with_log_prompt(key, value):
     print_with_log_prompt("num_outputs", stats.num_outputs)
     print_with_log_prompt("num_ops", stats.num_ops)
     print_with_log_prompt("model_size", f"{stats.model_size_in_billion}B")
-    print_with_log_prompt("input_dtypes", dict_to_string(stats.input_dtypes))
-    print_with_log_prompt("param_dtypes", dict_to_string(stats.param_dtypes))
-    print_with_log_prompt("op_dtypes", dict_to_string(stats.op_dtypes))
-    print_with_log_prompt("ops", dict_to_string(stats.ops))
+    print_with_log_prompt("input_dtypes", json.dumps(stats.input_dtypes))
+    print_with_log_prompt("param_dtypes", json.dumps(stats.param_dtypes))
+    print_with_log_prompt("input_shapes", json.dumps(stats.input_shapes))
+    print_with_log_prompt("op_dtypes", json.dumps(stats.op_dtypes))
+    print_with_log_prompt("ops", json.dumps(stats.ops))
     print_with_log_prompt("source", stats.source)
     print_with_log_prompt("heuristic_tag", stats.heuristic_tag)

diff --git a/graph_net/paddle/collect_stats.py b/graph_net/paddle/collect_stats.py
index 845cab101..34f9c366d 100644
--- a/graph_net/paddle/collect_stats.py
+++ b/graph_net/paddle/collect_stats.py
@@ -109,10 +109,11 @@ def __call__(self, program):
                             else:
                                 # for paddle.base.libpaddle.pir.VectorType, but cannot be accurately determined
                                 if op_name in [
+                                    "broadcast_tensors",
+                                    "distribute_fpn_proposals",
+                                    "meshgrid",
                                     "split",
                                     "split_with_num",
-                                    "meshgrid",
-                                    "distribute_fpn_proposals",
                                 ]:
                                     op_dtype = self.parse_pir_value_dtypes(
                                         str(out.type())
@@ -165,6 +166,7 @@ def collect_model_stats(model_path, log_prompt):
     model = model_class()
 
     model_size = 0
+    input_shapes = set()
     input_dtypes = {}
     param_dtypes = {}
     ops_count_dict = {}
     op_dtypes = {}
@@ -190,6 +192,7 @@ def 
collect_model_stats(model_path, log_prompt): param_dtypes[dtype_str] = param_dtypes.get(dtype_str, 0) + 1 elif name in inputs.keys(): input_dtypes[dtype_str] = input_dtypes.get(dtype_str, 0) + 1 + input_shapes.add(str(value["shape"])) num_outputs = collect_stats_util.get_number_of_returns( file_path, "GraphModule", "forward" @@ -200,7 +203,7 @@ def collect_model_stats(model_path, log_prompt): program_analyzer.is_complete if program_analyzer is not None else False ) print( - f"model_stats collection information: model_path={model_path}, method=to_static, is_ops_complete={is_complete}" + f"model_stats collection information: model_path={model_path} method=to_static is_ops_complete={is_complete}" ) stats = collect_stats_util.ModelStats( @@ -212,6 +215,7 @@ def collect_model_stats(model_path, log_prompt): model_size_in_billion=model_size / 1e9, input_dtypes=input_dtypes, param_dtypes=param_dtypes, + input_shapes=list(input_shapes), op_dtypes=op_dtypes, ops=ops_count_dict, source=source, diff --git a/graph_net/torch/collect_stats.py b/graph_net/torch/collect_stats.py index b31a48ce8..78a6502a0 100644 --- a/graph_net/torch/collect_stats.py +++ b/graph_net/torch/collect_stats.py @@ -336,17 +336,23 @@ def collect_model_stats(model_path, device, log_prompt): op_dtypes[dtype_str] = op_dtypes.get(dtype_str, 0) + num model_size = 0 + input_shapes = set() input_dtypes = {} param_dtypes = {} for name, arg_type in argument_name2types.items(): - if arg_type == torch.nn.parameter.Parameter: + if ( + name.startswith("L_self_modules_") + or arg_type == torch.nn.parameter.Parameter + ): + # Some parameters like L_self_modules_bn1_buffers_running_mean_ are torch.Tensor. param_numel = math.prod(input_dict[name].shape) model_size += param_numel dtype_str = str(input_dict[name].dtype).replace("torch.", "") param_dtypes[dtype_str] = param_dtypes.get(dtype_str, 0) + 1 - else: + elif arg_type == torch.Tensor: dtype_str = str(input_dict[name].dtype).replace("torch.", "") input_dtypes[dtype_str] = input_dtypes.get(dtype_str, 0) + 1 + input_shapes.add(str(list(input_dict[name].shape))) num_outputs = collect_stats_util.get_number_of_returns( file_path, "GraphModule", "forward" @@ -356,7 +362,7 @@ def collect_model_stats(model_path, device, log_prompt): is_complete = meta_executor.is_complete if meta_executor is not None else False print( - f"model_stats collection information: model_path={model_path}, method={method}, is_ops_complete={is_complete}" + f"model_stats collection information: model_path={model_path} method={method} is_ops_complete={is_complete}" ) stats = collect_stats_util.ModelStats( @@ -368,6 +374,7 @@ def collect_model_stats(model_path, device, log_prompt): model_size_in_billion=model_size / 1e9, input_dtypes=input_dtypes, param_dtypes=param_dtypes, + input_shapes=list(input_shapes), op_dtypes=op_dtypes, ops=ops_count_dict, source=source,