@@ -55,6 +55,20 @@ class OpStat:
     count: int = 0
 
 
+def resolve_native_multi_head_attention(*args, **kwargs):
+    query, key, value = args[0], args[1], args[2]
+    seq_len, batch_size, embed_dim = query.shape
+    attn_output = torch.empty(
+        (seq_len, batch_size, embed_dim), dtype=query.dtype, device="meta"
+    )
+
+    # seq_len_k = key.shape[0]
+    # num_heads = args[4]
+    # attn_output_weights = torch.empty((batch_size, num_heads, seq_len, seq_len_k),
+    #                                   dtype=query.dtype, device='meta')
+    return attn_output  # , attn_output_weights
+
+
 def resolve_get_attr(gm: torch.fx.GraphModule, node: torch.fx.Node):
     attr_itr = node.target.split(".")
     val = gm
@@ -65,8 +79,8 @@ def resolve_get_attr(gm: torch.fx.GraphModule, node: torch.fx.Node):
 
 
 def collect_op_stats(model, input_dict):
-    # FX symbolic trace
     try:
+        # FX symbolic trace
        traced = torch.fx.symbolic_trace(model)
         # print(traced.graph)
     except Exception:
@@ -118,7 +132,10 @@ def collect_op_stats(model, input_dict):
                 node_args = node_args[1:]
 
             try:
-                out = op_func(*node_args, **node_kwargs)
+                if op_name == "_native_multi_head_attention":
+                    out = resolve_native_multi_head_attention(*node_args, **node_kwargs)
+                else:
+                    out = op_func(*node_args, **node_kwargs)
                 node_outputs[node.name] = out
                 dtype = out.dtype if isinstance(out, torch.Tensor) else None
             except Exception:
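
The resolver never executes the real `torch._native_multi_head_attention` kernel; it only fabricates an output tensor with the expected shape and dtype on the meta device so that the shape/dtype bookkeeping in `collect_op_stats` can keep walking the traced graph. A minimal sketch of exercising it in isolation, with made-up shapes in the `(seq_len, batch_size, embed_dim)` layout the resolver assumes:

```python
import torch

# Illustrative meta tensors; the real query/key/value come from the traced
# graph, these shapes are invented for the sketch.
query = torch.empty((16, 4, 64), device="meta")
key = torch.empty((20, 4, 64), device="meta")
value = torch.empty((20, 4, 64), device="meta")

out = resolve_native_multi_head_attention(query, key, value)
assert out.shape == (16, 4, 64)
assert out.device.type == "meta" and out.dtype == query.dtype
```

Only `attn_output` is produced; the commented-out `attn_output_weights` lines would only be needed if a downstream node consumed the second output of the native op.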