Skip to content

Commit 0cc0f18

Browse files
authored
[jax-inference-offloading] Print CI metric (#1819)
Example metrics: ``` 0: JIO_METRIC_TRANSFER|mean=0.061956|std=0.001554|min=0.059907|max=0.064638|unit=s 0: JIO_METRIC_HANDSHAKE|mean=28.636484|std=0.000000|min=28.636484|max=28.636484|unit=s 0: JIO_METRIC_LOADMODEL|mean=5.969427|std=0.000000|min=5.969427|max=5.969427|unit=s ```
1 parent 7055490 commit 0cc0f18

File tree

2 files changed

+45
-1
lines changed

2 files changed

+45
-1
lines changed

jax-inference-offloading/examples/trainer.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,16 @@
153153

154154
if jax.process_index() == 0:
155155
timer.summary(sort_by='name', precision=3)
156+
for metric_name, metric_key in [
157+
('JIO_METRIC_TRANSFER', r'transport\.run\d+$'),
158+
('JIO_METRIC_HANDSHAKE', r'create_bridge\.handshake$'),
159+
('JIO_METRIC_LOADMODEL', r'load_model$'),
160+
]:
161+
print(timer.ci_metric(
162+
metric_name,
163+
timer.node_stat(metric_key, ('mean', 'std', 'min', 'max')),
164+
unit='s'
165+
))
156166

157167
if jax.process_index() == 0:
158168
bridge.gateway.shutdown()

jax-inference-offloading/jax_inference_offloading/timer.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,12 @@
1414
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1515
# See the License for the specific language governing permissions and
1616
# limitations under the License.
17+
import re
1718
import time
19+
import numpy as np
1820
from collections import defaultdict, deque
1921
from contextlib import contextmanager
22+
from typing import Callable
2023

2124

2225
class Timer:
@@ -39,7 +42,6 @@ class Timer:
3942
parse : 0.3017 s
4043
json : 0.3017 s
4144
train : 2.8344 s
42-
4345
"""
4446

4547
def __init__(self) -> None:
@@ -113,6 +115,38 @@ def _print_by_name(self, precision: int, col_sep: str) -> None:
113115
indent = " " * depth
114116
print(f"{indent + segment:<{col_width_name}}{col_sep}{t:{col_width_time}.{precision}f}")
115117

118+
def node_stat(
119+
self,
120+
pattern: str,
121+
stats: str | Callable[[list[float]], float] | list[str | Callable[[list[float]], float]] = 'mean'
122+
) -> float | tuple[float, ...] | None:
123+
matcher = re.compile(pattern)
124+
vals = [t for name, t in self._times.items() if matcher.fullmatch(name)]
125+
if not vals:
126+
return None
127+
128+
def get_op(name):
129+
try:
130+
return getattr(np, name)
131+
except AttributeError:
132+
raise ValueError(f"Unsupported statistic function name: {name}")
133+
134+
if isinstance(stats, (list, tuple)):
135+
return {s: get_op(s)(vals) for s in stats}
136+
else:
137+
return {stats: get_op(stats)(vals)}
138+
139+
def ci_metric(
140+
self,
141+
tag,
142+
stats: dict[str, float],
143+
unit: str = 's',
144+
sep: str = '|',
145+
):
146+
return sep.join(
147+
[tag] + [f'{k}={v:.6f}' for k, v in stats.items()] + [f'unit={unit}']
148+
)
149+
116150
def reset(self) -> None:
117151
"""Clear all recorded data."""
118152
self._times.clear()

0 commit comments

Comments
 (0)