diff --git a/src/umm/cli/mathvista_eval.py b/src/umm/cli/mathvista_eval.py
index e72f35a..d1ef1f3 100644
--- a/src/umm/cli/mathvista_eval.py
+++ b/src/umm/cli/mathvista_eval.py
@@ -12,6 +12,17 @@
from tqdm import tqdm
from umm.core.config import load_config
+from umm.eval.distributed import (
+ barrier,
+ cleanup_distributed,
+ cleanup_shards,
+ load_shard_items,
+ maybe_init_distributed,
+ merge_shards,
+ rank_shard_path,
+ sum_across_ranks,
+)
+from umm.eval.runner import run_sharded_inference
DS_COLLECTIONS = {
@@ -44,7 +55,6 @@
def _quick_extract(response: str, problem: dict) -> str | None:
- """Try rule-based extraction before falling back to LLM."""
if not response:
return ""
question_type = problem.get("question_type", "")
@@ -64,18 +74,16 @@ def _quick_extract(response: str, problem: dict) -> str | None:
except (ValueError, TypeError):
pass
- # Try regex for "Final answer: ..." or "Answer: ..."
match = re.search(r"(?:Final answer:|Answer:)\s*(.*)", response, re.IGNORECASE)
if match:
ans = match.group(1).strip()
if ans:
return ans
- return None # need LLM extraction
+ return None
def _build_extract_prompt(query: str, response: str) -> str:
- """Build the full prompt for LLM-based answer extraction."""
test_prompt = f"{query}\n\n{response}"
return f"{_EXTRACT_DEMO_PROMPT.strip()}\n\n{test_prompt}\n\nExtracted answer: "
@@ -87,13 +95,11 @@ def _run_llm_extraction(
use_quick_extract: bool = True,
) -> dict[str, Any]:
"""Extract answers from model responses using a local LLM (e.g. Qwen3-32B)."""
- # First pass: check which items already have extraction or can be rule-extracted
already_done = 0
need_llm = []
for pid, problem in results.items():
if "choices" not in problem:
continue
- # Skip if extraction was already done in a previous run
if "extraction" in problem and problem["extraction"]:
already_done += 1
continue
@@ -117,7 +123,6 @@ def _run_llm_extraction(
flush=True,
)
- # Only load the heavy LLM if there are items that actually need it
if not need_llm:
print("[mathvista] all items already extracted, skipping LLM load", flush=True)
return results
@@ -156,11 +161,9 @@ def _run_llm_extraction(
)
generated = outputs[0][inputs["input_ids"].shape[1]:]
extraction = tokenizer.decode(generated, skip_special_tokens=True).strip()
- # Handle Qwen3 thinking format: strip ... block
think_match = re.search(r"\s*(.*)", extraction, re.DOTALL)
if think_match:
extraction = think_match.group(1).strip()
- # Take only the first line as the extracted answer
extraction = extraction.split("\n")[0].strip()
problem["extraction"] = extraction
@@ -209,7 +212,6 @@ def _extract_text(output: Any) -> str:
text = _extract_text(item)
if text:
return text
- # Handle adapters that return {"understandings": [{"response": "..."}]}
for list_key in ("understandings",):
container = output.get(list_key)
if isinstance(container, list):
@@ -236,7 +238,6 @@ def _load_eval_cfg(config_path: str) -> tuple[dict[str, Any], dict[str, Any], di
def _find_latest_results(out_dir: Path, ds_name: str) -> Path | None:
- """Find the most recent results JSON for a dataset in out_dir."""
candidates = sorted(out_dir.glob(f"{ds_name}_*.json"))
candidates = [c for c in candidates if "_checkpoint" not in c.name and "_score" not in c.name]
if candidates:
@@ -288,7 +289,6 @@ def run_mathvista_eval_command(args: Any) -> int:
gt_file = mathvista_cfg.get("gt_file")
resume = bool(mathvista_cfg.get("resume", False))
- # Mode support (like wise): generate / score / full
mode = str(mathvista_cfg.get("mode", "full")).strip().lower()
if mode not in ("full", "generate", "score"):
print(f"[mathvista] unknown mode '{mode}', defaulting to 'full'", flush=True)
@@ -297,14 +297,12 @@ def run_mathvista_eval_command(args: Any) -> int:
run_gen = mode in ("full", "generate")
run_score = mode in ("full", "score")
- # LLM extraction config (replaces OpenAI)
llm_extract_cfg = mathvista_cfg.get("llm_extract", {})
if not isinstance(llm_extract_cfg, dict):
llm_extract_cfg = {}
llm_model_path = str(llm_extract_cfg.get("model_path", "")).strip()
llm_max_new_tokens = int(llm_extract_cfg.get("max_new_tokens", 2048))
- # Legacy OpenAI config (fallback when llm_extract is not configured)
run_extract_legacy = bool(mathvista_cfg.get("run_extract", False))
run_calculation_legacy = bool(mathvista_cfg.get("run_calculation", False))
openai_api_key = mathvista_cfg.get("openai_api_key")
@@ -320,195 +318,249 @@ def run_mathvista_eval_command(args: Any) -> int:
"mode": mode,
}
- # ── Phase 1: Generation ──
- if run_gen:
- from datasets import load_dataset
- from PIL import Image
+ dist_info = maybe_init_distributed()
+ try:
+ summary["world_size"] = dist_info.world_size
+ if dist_info.world_size > 1 and run_gen:
+ print(
+ f"[mathvista] distributed inference enabled: rank={dist_info.rank}, "
+ f"local_rank={dist_info.local_rank}, world_size={dist_info.world_size}",
+ flush=True,
+ )
+
+ local_total_written = 0
- from umm.inference import InferencePipeline
+ # ── Phase 1: Generation ──
+ if run_gen:
+ from datasets import load_dataset
+ from PIL import Image
- pipeline = InferencePipeline(backbone_name=backbone, backbone_cfg=backbone_cfg)
+ from umm.inference import InferencePipeline
- for ds_name in datasets:
- entry = DS_COLLECTIONS.get(ds_name)
- if not entry:
- raise ValueError(f"Unknown MathVista dataset: {ds_name}")
+ pipeline = InferencePipeline(backbone_name=backbone, backbone_cfg=backbone_cfg)
+
+ for ds_name in datasets:
+ entry = DS_COLLECTIONS.get(ds_name)
+ if not entry:
+ raise ValueError(f"Unknown MathVista dataset: {ds_name}")
+
+ dataset_root = str(mathvista_cfg.get("root", entry["root"]))
+ split = str(mathvista_cfg.get("split", entry["split"]))
+ dataset = load_dataset(
+ dataset_root,
+ cache_dir=str(_resolve_path(cache_dir, repo_root)) if cache_dir else None,
+ )
+ data = dataset[split]
+
+ checkpoint_jsonl = out_dir / f"{ds_name}_checkpoint.jsonl"
+ shard_path = rank_shard_path(checkpoint_jsonl, dist_info.rank, dist_info.world_size)
+
+ done_pids: set[str] = set()
+ results_file: Path | None = None
+
+ if resume:
+ shard_items = load_shard_items(shard_path)
+ if shard_items:
+ done_pids = {str(it.get("pid", "")) for it in shard_items}
+ print(
+ f"[mathvista] {ds_name}: rank {dist_info.rank} resume "
+ f"from shard: {len(done_pids)} done",
+ flush=True,
+ )
+ elif dist_info.world_size <= 1:
+ # Single-card fallback: a previous completed run can be loaded
+ # whole. Multi-card runs rely on the rank shard only.
+ prior = _find_latest_results(out_dir, ds_name)
+ if prior:
+ print(
+ f"[mathvista] resume: using completed file {prior}",
+ flush=True,
+ )
+ results_file = prior
+
+ if results_file is None:
+ print(
+ f"[mathvista] {ds_name}: total={len(data)}, rank={dist_info.rank}, "
+ f"done={len(done_pids)}",
+ flush=True,
+ )
+
+ # Build the inference request for one dataset sample.
+ # Validates that the sample carries `pid`, a PIL `decoded_image` and `query`,
+ # writes the image to `image_dir/<pid>.png`, and returns the payload dict
+ # handed to `run_sharded_inference` as `payload_fn`.
+ # Raises ValueError when any of the three fields is missing or the image is
+ # not a PIL image.
+ def payload_fn(data_item: Any) -> dict[str, Any]:
+ pid = data_item.get("pid")
+ if pid is None:
+ raise ValueError("MathVista sample missing `pid`.")
+ image = data_item.get("decoded_image")
+ if image is None:
+ raise ValueError("MathVista sample missing `decoded_image`.")
+ if not isinstance(image, Image.Image):
+ raise ValueError("Expected `decoded_image` to be a PIL image.")
+ # The image is persisted to disk because the payload references it by path.
+ image_dir.mkdir(parents=True, exist_ok=True)
+ image_path = image_dir / f"{pid}.png"
+ image.save(image_path, format="PNG")
+ question = data_item.get("query")
+ if question is None:
+ raise ValueError("MathVista sample missing `query`.")
+ # With chain-of-thought enabled the question is wrapped in COT_INSTRUCTION;
+ # otherwise it is sent verbatim.
+ prompt = COT_INSTRUCTION.format(question=question) if use_cot else question
+ return {
+ "backbone": backbone,
+ "task": "understanding",
+ "prompt": prompt,
+ "images": [str(image_path)],
+ "params": request_params,
+ "metadata": {"pid": pid, "dataset": ds_name},
+ }
+
+ # Turn one (sample, raw model output) pair into the record passed back to
+ # `run_sharded_inference` as `record_fn`. `_idx` is accepted but unused.
+ def record_fn(data_item: Any, raw: Any, _idx: int) -> dict[str, Any]:
+ response = _extract_text(raw)
+ # Work on a shallow copy so the dataset sample itself is not mutated.
+ item = dict(data_item)
+ # Drop the in-memory image before storing the record
+ # (presumably because it is not JSON-serializable — confirm).
+ item.pop("decoded_image", None)
+ item["response"] = response
+ # Ensure pid is JSON-friendly so resume works.
+ item["pid"] = str(item.get("pid", ""))
+ return item
+
+ n_written = run_sharded_inference(
+ infer_fn=pipeline.run,
+ dist_info=dist_info,
+ shard_path=shard_path,
+ samples=data,
+ total=len(data),
+ payload_fn=payload_fn,
+ record_fn=record_fn,
+ sample_id_fn=lambda d: str(d.get("pid", "")),
+ done_ids=done_pids,
+ max_samples=max_samples,
+ log_prefix=f"mathvista/{ds_name}/rank{dist_info.rank}",
+ )
+ local_total_written += n_written
+
+ barrier(dist_info)
+
+ time_prefix = time.strftime("%y%m%d%H%M%S", time.localtime())
+ results_file = out_dir / f"{ds_name}_{time_prefix}.json"
+
+ if dist_info.rank == 0:
+ merged = merge_shards(checkpoint_jsonl)
+ results: dict[str, Any] = {item["pid"]: item for item in merged}
+ results_file.write_text(json.dumps(results, indent=2), encoding="utf-8")
+ cleanup_shards(checkpoint_jsonl)
+ if dist_info.world_size <= 1 and checkpoint_jsonl.exists():
+ checkpoint_jsonl.unlink()
+
+ barrier(dist_info)
+
+ summary[f"{ds_name}_output_path"] = str(results_file)
+
+ if mode == "generate":
+ print(f"[mathvista] generation phase done, outputs={out_dir}", flush=True)
+
+ # Free GPU memory from generation pipeline before scoring (rank 0 only
+ # touches scoring; other ranks return after the barriers above).
+ del pipeline
+ import gc
+ gc.collect()
+ import torch as _torch
+ if _torch.cuda.is_available():
+ _torch.cuda.empty_cache()
+ print(
+ f"[mathvista] rank {dist_info.rank}: released generation pipeline GPU memory",
+ flush=True,
+ )
- dataset_root = str(mathvista_cfg.get("root", entry["root"]))
- split = str(mathvista_cfg.get("split", entry["split"]))
- dataset = load_dataset(
- dataset_root,
- cache_dir=str(_resolve_path(cache_dir, repo_root)) if cache_dir else None,
+ # Multi-card: ranks > 0 finish here; only rank 0 proceeds with scoring +
+ # summary writing (scoring loads its own LLM with device_map="auto").
+ local_written_all = sum_across_ranks(local_total_written, dist_info)
+ if dist_info.rank != 0:
+ print(
+ f"[umm eval] rank {dist_info.rank} finished MathVista shard: "
+ f"samples_written={local_total_written}",
+ flush=True,
)
- data = dataset[split]
+ return 0
- checkpoint_json = out_dir / f"{ds_name}_checkpoint.json"
- results: dict[str, Any] = {}
- results_file: Path | None = None
+ summary["samples_written"] = local_written_all
- if resume:
- if checkpoint_json.exists():
- results = json.loads(checkpoint_json.read_text("utf-8"))
- print(f"[mathvista] resume from checkpoint: {len(results)} done", flush=True)
+ # ── Phase 2: Scoring (extract + calculate) — rank 0 only ──
+ if run_score:
+ for ds_name in datasets:
+ results_file = None
+ if f"{ds_name}_output_path" in summary:
+ results_file = Path(summary[f"{ds_name}_output_path"])
else:
results_file = _find_latest_results(out_dir, ds_name)
- if results_file:
- print(f"[mathvista] resume: using completed file {results_file}", flush=True)
-
- if results_file is None:
- print(
- f"[mathvista] {ds_name}: {len(data)} total, {len(results)} done, "
- f"{len(data) - len(results)} remaining",
- flush=True,
- )
- for idx, data_item in enumerate(tqdm(data, desc=f"mathvista/{ds_name}", file=sys.stdout), start=1):
- pid = data_item.get("pid")
- if pid is None:
- raise ValueError("MathVista sample missing `pid`.")
- if str(pid) in results:
- continue
-
- image = data_item.get("decoded_image")
- if image is None:
- raise ValueError("MathVista sample missing `decoded_image`.")
- if not isinstance(image, Image.Image):
- raise ValueError("Expected `decoded_image` to be a PIL image.")
-
- image_dir.mkdir(parents=True, exist_ok=True)
- image_path = image_dir / f"{pid}.png"
- image.save(image_path, format="PNG")
-
- question = data_item.get("query")
- if question is None:
- raise ValueError("MathVista sample missing `query`.")
+ if results_file is None or not results_file.exists():
+ raise FileNotFoundError(
+ f"No results file found for {ds_name} in {out_dir}. "
+ f"Run generation phase first (mode: generate)."
+ )
+
+ print(f"[mathvista] scoring {ds_name} from {results_file}", flush=True)
+ results = json.loads(results_file.read_text("utf-8"))
+
+ if llm_model_path:
+ results = _run_llm_extraction(
+ results,
+ model_path=llm_model_path,
+ max_new_tokens=llm_max_new_tokens,
+ use_quick_extract=use_cot,
+ )
+ results_file.write_text(json.dumps(results, indent=2), encoding="utf-8")
+ print(f"[mathvista] saved extractions to {results_file}", flush=True)
+ elif run_extract_legacy:
+ cmd = [
+ sys.executable,
+ "src/umm/eval/internvl_chat/eval/mathvista/extract_answer.py",
+ "--output_file",
+ results_file.name,
+ "--output_dir",
+ str(out_dir),
+ ]
if use_cot:
- prompt = COT_INSTRUCTION.format(question=question)
- else:
- prompt = question
-
- payload = {
- "backbone": backbone,
- "task": "understanding",
- "prompt": prompt,
- "images": [str(image_path)],
- "params": request_params,
- "metadata": {"pid": pid, "dataset": ds_name},
- }
- response = _extract_text(pipeline.run(payload))
-
- item = dict(data_item)
- item.pop("decoded_image", None)
- item["response"] = response
- results[str(pid)] = item
-
- checkpoint_json.write_text(json.dumps(results, indent=2), encoding="utf-8")
-
- if max_samples > 0 and len(results) >= max_samples:
- break
-
- time_prefix = time.strftime("%y%m%d%H%M%S", time.localtime())
- results_file = out_dir / f"{ds_name}_{time_prefix}.json"
- results_file.write_text(json.dumps(results, indent=2), encoding="utf-8")
-
- if checkpoint_json.exists():
- checkpoint_json.unlink()
-
- summary[f"{ds_name}_output_path"] = str(results_file)
-
- if mode == "generate":
- print(f"[mathvista] generation phase done, outputs={out_dir}", flush=True)
-
- # Free GPU memory from generation pipeline before scoring
- del pipeline
- import gc
- gc.collect()
- import torch as _torch
- if _torch.cuda.is_available():
- _torch.cuda.empty_cache()
- print("[mathvista] released generation pipeline GPU memory", flush=True)
-
- # ── Phase 2: Scoring (extract + calculate) ──
- if run_score:
- for ds_name in datasets:
- # Find results file from generation phase (or previous run)
- results_file = None
- if f"{ds_name}_output_path" in summary:
- results_file = Path(summary[f"{ds_name}_output_path"])
- else:
- results_file = _find_latest_results(out_dir, ds_name)
- if results_file is None or not results_file.exists():
- raise FileNotFoundError(
- f"No results file found for {ds_name} in {out_dir}. "
- f"Run generation phase first (mode: generate)."
- )
-
- print(f"[mathvista] scoring {ds_name} from {results_file}", flush=True)
- results = json.loads(results_file.read_text("utf-8"))
-
- # ── Extract answers ──
- if llm_model_path:
- # Use local Qwen model for extraction
- results = _run_llm_extraction(
- results,
- model_path=llm_model_path,
- max_new_tokens=llm_max_new_tokens,
- use_quick_extract=use_cot,
- )
- # Save results with extraction field
- results_file.write_text(json.dumps(results, indent=2), encoding="utf-8")
- print(f"[mathvista] saved extractions to {results_file}", flush=True)
- elif run_extract_legacy:
- # Fallback: use legacy OpenAI-based extract_answer.py
+ cmd.append("--quick_extract")
+ env = None
+ if isinstance(openai_api_key, str) and openai_api_key.strip():
+ env = dict(os.environ)
+ env["OPENAI_API_KEY"] = openai_api_key.strip()
+ proc = subprocess.run(cmd, cwd=str(repo_root), capture_output=True, text=True, env=env)
+ print(proc.stdout)
+ if proc.returncode != 0:
+ if proc.stderr:
+ print(proc.stderr, file=sys.stderr)
+ raise RuntimeError(f"MathVista extract_answer failed with return code {proc.returncode}")
+ summary[f"{ds_name}_extract_stdout"] = proc.stdout
+
+ score_file = results_file.with_name(f"{results_file.stem}_score.json")
cmd = [
sys.executable,
- "src/umm/eval/internvl_chat/eval/mathvista/extract_answer.py",
+ "src/umm/eval/internvl_chat/eval/mathvista/calculate_score.py",
"--output_file",
results_file.name,
"--output_dir",
str(out_dir),
+ "--score_file",
+ score_file.name,
]
- if use_cot:
- cmd.append("--quick_extract")
- env = None
- if isinstance(openai_api_key, str) and openai_api_key.strip():
- env = dict(os.environ)
- env["OPENAI_API_KEY"] = openai_api_key.strip()
- proc = subprocess.run(cmd, cwd=str(repo_root), capture_output=True, text=True, env=env)
+ if isinstance(gt_file, str) and gt_file.strip():
+ cmd.extend(["--gt_file", gt_file.strip()])
+ proc = subprocess.run(cmd, cwd=str(repo_root), capture_output=True, text=True)
print(proc.stdout)
if proc.returncode != 0:
if proc.stderr:
print(proc.stderr, file=sys.stderr)
- raise RuntimeError(f"MathVista extract_answer failed with return code {proc.returncode}")
- summary[f"{ds_name}_extract_stdout"] = proc.stdout
-
- # ── Calculate scores ──
- score_file = results_file.with_name(f"{results_file.stem}_score.json")
- cmd = [
- sys.executable,
- "src/umm/eval/internvl_chat/eval/mathvista/calculate_score.py",
- "--output_file",
- results_file.name,
- "--output_dir",
- str(out_dir),
- "--score_file",
- score_file.name,
- ]
- if isinstance(gt_file, str) and gt_file.strip():
- cmd.extend(["--gt_file", gt_file.strip()])
- proc = subprocess.run(cmd, cwd=str(repo_root), capture_output=True, text=True)
- print(proc.stdout)
- if proc.returncode != 0:
- if proc.stderr:
- print(proc.stderr, file=sys.stderr)
- raise RuntimeError(f"MathVista calculate_score failed with return code {proc.returncode}")
- summary[f"{ds_name}_score_file"] = str(score_file)
- summary[f"{ds_name}_score_stdout"] = proc.stdout
-
- if isinstance(score_output_path, str) and score_output_path:
- score_path = _resolve_path(score_output_path, repo_root)
- score_path.parent.mkdir(parents=True, exist_ok=True)
- score_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")
- print(f"[umm eval] wrote MathVista summary to {score_path}")
-
- print(f"[umm eval] completed MathVista (mode={mode}) for backbone={backbone}, outputs={out_dir}")
- return 0
+ raise RuntimeError(f"MathVista calculate_score failed with return code {proc.returncode}")
+ summary[f"{ds_name}_score_file"] = str(score_file)
+ summary[f"{ds_name}_score_stdout"] = proc.stdout
+
+ if isinstance(score_output_path, str) and score_output_path:
+ score_path = _resolve_path(score_output_path, repo_root)
+ score_path.parent.mkdir(parents=True, exist_ok=True)
+ score_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")
+ print(f"[umm eval] wrote MathVista summary to {score_path}")
+
+ print(
+ f"[umm eval] completed MathVista (mode={mode}) for backbone={backbone}, "
+ f"outputs={out_dir}, world_size={dist_info.world_size}"
+ )
+ return 0
+ finally:
+ cleanup_distributed(dist_info)
diff --git a/src/umm/cli/mmbench_eval.py b/src/umm/cli/mmbench_eval.py
index 7b14fa1..97d5f64 100644
--- a/src/umm/cli/mmbench_eval.py
+++ b/src/umm/cli/mmbench_eval.py
@@ -1,12 +1,8 @@
from __future__ import annotations
import base64
-import copy
import json
import os
-import random
-import string
-import sys
import time
from io import BytesIO
from pathlib import Path
@@ -14,9 +10,20 @@
import pandas as pd
from PIL import Image
-from tqdm import tqdm
from umm.core.config import load_config
+from umm.eval.distributed import (
+ barrier,
+ cleanup_distributed,
+ cleanup_shards,
+ load_shard_items,
+ maybe_init_distributed,
+ merge_shards,
+ rank_shard_path,
+ sum_across_ranks,
+)
+from umm.eval.runner import run_sharded_inference
+from umm.inference import InferencePipeline
DS_COLLECTIONS = {
@@ -50,387 +57,9 @@
"type": "dev",
"language": "cn",
},
- # V11 test splits — released with GT labels after the online eval service closed
- # on 2026-03-31 (see open-compass/MMBench#61). Same TSV schema as dev (four
- # rotations per question, index = base + k * 1_000_000 for k = 0..3).
- "MMBench_TEST_EN_V11": {
- "root": "/datasets/mmbench/MMBench_TEST_EN_V11.tsv",
- "type": "test",
- "language": "en",
- },
- "MMBench_TEST_CN_V11": {
- "root": "/datasets/mmbench/MMBench_TEST_CN_V11.tsv",
- "type": "test",
- "language": "cn",
- },
}
-# Port of VLMEvalKit MMB_abbrs (vlmeval/dataset/utils/multiple_choice.py:17-24)
-MMB_ABBRS = {
- "coarse_perception": "CP",
- "finegrained_perception (instance-level)": "FP-S",
- "finegrained_perception (cross-instance)": "FP-C",
- "logic_reasoning": "LR",
- "relation_reasoning": "RR",
- "attribute_reasoning": "AR",
-}
-
-
-# VLMEvalKit uses the same English scaffold for both EN and CN splits — only the
-# question / option / hint fields themselves are localised inside the TSV.
-# (vlmeval/dataset/image_mcq.py:210-245)
-_ANSWER_INSTRUCTION = "Please select the correct answer from the options above. \n"
-
-
-# ---------------------------------------------------------------------------
-# Generation-time prompt (ports vlmeval/dataset/image_mcq.py:210-245)
-# ---------------------------------------------------------------------------
-
-
-def _build_prompt(question: str, options: "dict[str, str]", hint: str | None) -> str:
- options_prompt = "Options:\n"
- for key, item in options.items():
- options_prompt += f"{key}. {item}\n"
- prompt = ""
- if hint:
- prompt += f"Hint: {hint}\n"
- prompt += f"Question: {question}\n"
- if options:
- prompt += options_prompt
- prompt += _ANSWER_INSTRUCTION
- return prompt
-
-
-# ---------------------------------------------------------------------------
-# Exact-matching extraction — ports vlmeval/utils/matching_util.py:12-116
-# ---------------------------------------------------------------------------
-
-
-_REJECT_TO_ANSWER = [
- "Sorry, I can't help with images of people yet.",
- "I can't process this file.",
- "I'm sorry, but without the image provided",
- "Cannot determine the answer",
-]
-_EXACT_CHARS_TO_SPACE = ".()[],:;!*#{}"
-
-
-def _can_infer_option(answer: str, choices: "dict[str, str]") -> "str | bool":
- if "Failed to obtain answer via API" in answer:
- return False
- for err in _REJECT_TO_ANSWER:
- if err in answer:
- return "Z"
-
- answer_mod = copy.copy(answer)
- for c in _EXACT_CHARS_TO_SPACE:
- answer_mod = answer_mod.replace(c, " ")
- splits = [x.strip() for x in answer_mod.split()]
-
- def count_choice(tokens, cands, prefix: str = "", suffix: str = "") -> int:
- return sum(1 for c in cands if (prefix + c + suffix) in tokens)
-
- count = count_choice(splits, choices)
- if count == 1:
- for ch in choices:
- if ch in splits and splits.index(ch) > len(splits) - 5:
- return ch
- elif count == 0 and count_choice(splits, {"Z", ""}) == 1:
- return "Z"
- return False
-
-
-def _can_infer_text(answer: str, choices: "dict[str, str]") -> "str | bool":
- answer_low = answer.lower()
- total_len = sum(len(str(v)) for v in choices.values())
- if len(answer_low) > 2 * total_len:
- return False
- lowered = {k: str(v).lower() for k, v in choices.items()}
- cands = [k for k, v in lowered.items() if v in answer_low]
- if len(cands) == 1:
- return cands[0]
- return False
-
-
-def _can_infer(answer: Any, choices: "dict[str, str]") -> "str | bool":
- answer = str(answer)
- copt = _can_infer_option(answer, choices)
- return copt if copt else _can_infer_text(answer, choices)
-
-
-# ---------------------------------------------------------------------------
-# LLM judge — ports vlmeval/dataset/utils/multiple_choice.py build_prompt(_cn)
-# and build_option_str / cn_string.
-# ---------------------------------------------------------------------------
-
-
-_JUDGE_PROMPT_EN = (
- "You are an AI assistant who will help me to match "
- "an answer with several options of a single-choice question. "
- "You are provided with a question, several options, and an answer, "
- "and you need to find which option is most similar to the answer. "
- "If the meaning of all options are significantly different from the answer, output Z. "
- "Your should output a single uppercase character in A, B, C, D (if they are valid options), and Z. \n"
- "Example 1: \n"
- "Question: What is the main object in image?\nOptions: A. teddy bear B. rabbit C. cat D. dog\n"
- "Answer: a cute teddy bear\nYour output: A\n"
- "Example 2: \n"
- "Question: What is the main object in image?\nOptions: A. teddy bear B. rabbit C. cat D. dog\n"
- "Answer: Spider\nYour output: Z\n"
- "Example 3: \n"
- "Question: {}?\nOptions: {}\nAnswer: {}\nYour output: "
-)
-
-
-_JUDGE_PROMPT_CN = (
- "你是一个帮助我匹配答案与单选题中多个选项的 AI 助手。"
- "你会被提供:一个问题,多个选项,一个答案。你的任务是找到与答案意义最相近的选项。"
- "如果所有选项的意义都与答案显著不同,则输出 Z。"
- "你应该输出一个单个的大写字母,例如 A, B, C, D(如果它们是有效选项),或 Z。"
- "例 1:"
- "问题: 图中最主要的物体是什么?\n选项: A. 泰迪熊 B. 兔子 C. 猫 D. 狗\n答案: 一只可爱的泰迪熊\n输出: A\n"
- "例 2: \n"
- "问题: 图中最主要的物体是什么?\n选项: A. 泰迪熊 B. 兔子 C. 猫 D. 狗\n答案: 蜘蛛\n输出: Z\n"
- "例 3: \n"
- "问题: {}?\n选项: {}\n答案: {}\n输出: "
-)
-
-
-def _build_option_str(options: "dict[str, str]") -> str:
- s = "There are several options: \n"
- for c, content in options.items():
- if content is None:
- continue
- if isinstance(content, float) and pd.isna(content):
- continue
- s += f"{c}. {content}\n"
- return s
-
-
-def _cn_string(text: str) -> bool:
- return any("\u4e00" <= ch <= "\u9fff" for ch in str(text))
-
-
-def _build_judge_prompt(question: str, options: "dict[str, str]", prediction: str, language: str) -> str:
- option_str = _build_option_str(options)
- if language == "cn" or _cn_string(question):
- tmpl = _JUDGE_PROMPT_CN
- else:
- tmpl = _JUDGE_PROMPT_EN
- return tmpl.format(question, option_str, prediction)
-
-
-# ---------------------------------------------------------------------------
-# Judge model bundle — lazy loaded, shared across all extractions in a run.
-# ---------------------------------------------------------------------------
-
-
-class _JudgeBundle:
- def __init__(self, model_path: str, max_new_tokens: int = 32) -> None:
- import torch
- from transformers import AutoModelForCausalLM, AutoTokenizer
-
- print(f"[mmbench] loading judge LLM: {model_path}", flush=True)
- self._torch = torch
- self.max_new_tokens = max_new_tokens
- self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
- self.model = AutoModelForCausalLM.from_pretrained(
- model_path,
- torch_dtype="auto",
- device_map="auto",
- trust_remote_code=True,
- )
- self.model.eval()
-
- def generate(self, prompt: str) -> str:
- messages = [{"role": "user", "content": prompt}]
- text = self.tokenizer.apply_chat_template(
- messages,
- tokenize=False,
- add_generation_prompt=True,
- enable_thinking=False,
- )
- inputs = self.tokenizer(text, return_tensors="pt").to(self.model.device)
- with self._torch.no_grad():
- outputs = self.model.generate(
- **inputs,
- max_new_tokens=self.max_new_tokens,
- do_sample=False,
- )
- generated = outputs[0][inputs["input_ids"].shape[1]:]
- raw = self.tokenizer.decode(generated, skip_special_tokens=True).strip()
- # Strip Qwen3-style ... if present so can_infer sees the letter
- if "" in raw:
- raw = raw.split("", 1)[1].strip()
- return raw
-
- def close(self) -> None:
- del self.model
- del self.tokenizer
- if self._torch.cuda.is_available():
- self._torch.cuda.empty_cache()
-
-
-# ---------------------------------------------------------------------------
-# extract_answer_from_item — ports multiple_choice.py:359-406
-# ---------------------------------------------------------------------------
-
-
-def _extract_answer_from_item(item: "dict[str, Any]", judge: "_JudgeBundle | None") -> "dict[str, str]":
- """Return {'opt': letter, 'log': ...}, matching VLMEvalKit's extract_answer_from_item."""
- choices = item["choices"]
- prediction = str(item.get("prediction", ""))
-
- ret = _can_infer(prediction, choices)
- if ret:
- return {"opt": str(ret), "log": prediction}
- if judge is None:
- return {
- "opt": "Z",
- "log": "Failed in Prefetch, no LLM-based answer matching under `exact_matching` policy.",
- }
-
- prompt = _build_judge_prompt(
- question=str(item.get("question", "")),
- options=choices,
- prediction=prediction,
- language=item.get("language", "en"),
- )
- for _ in range(3):
- ans = judge.generate(prompt)
- if "Failed to obtain answer via API" in ans:
- continue
- ret = _can_infer(ans, choices)
- if ret:
- return {"opt": str(ret), "log": ans}
- # Random fallback (matches VLMEvalKit — the random choice ensures we still
- # produce a concrete letter rather than silently skipping the item).
- options_pool = list(choices) + (["Z"] if "Z" not in choices else [])
- return {"opt": random.choice(options_pool), "log": "Failed to predict, thus randomly generate one."}
-
-
-# ---------------------------------------------------------------------------
-# Circular evaluation — ports multiple_choice.py:409-471
-# ---------------------------------------------------------------------------
-
-
-def _prefetch_answer(item: "dict[str, Any]") -> "str | bool":
- return _can_infer(str(item.get("prediction", "")), item["choices"])
-
-
-def _prefetch_circular_group(
- sub_items: "list[dict[str, Any]]",
-) -> "tuple[dict | None, list, list]":
- """Returns (result_dict_or_None, gts, preds).
-
- If a definitive result can be decided via can_infer alone (all matched, or
- any rotation's prefetched answer conflicts with its GT), returns a dict with
- hit/log. Otherwise returns None (LLM fallback is needed).
- """
- gts: list = []
- preds: list = []
- for i, item in enumerate(sub_items):
- gts.append(item["gt_answer"])
- preds.append(_prefetch_answer(item))
- if preds[-1] and gts[-1] != preds[-1]:
- return (
- {
- "hit": 0,
- "log": (
- f"Failed in Prefetching Rolling {i}: Answer is {gts[-1]}, "
- f"Prediction is {item.get('prediction', '')}, "
- f"Pre-fetched is {preds[-1]}."
- ),
- },
- gts,
- preds,
- )
- if all(g == p for g, p in zip(gts, preds)):
- return {"hit": 1, "log": "Succeed During Pre-fetching"}, gts, preds
- return None, gts, preds
-
-
-def _eval_circular_group(
- sub_items: "list[dict[str, Any]]",
- judge: "_JudgeBundle | None",
-) -> "dict[str, Any]":
- """Evaluate one circular group (1-4 rotations of the same question).
-
- Returns {'hit': 0 or 1, 'log': ...}. hit=1 requires every rotation's
- extracted letter to match its own rotated GT.
- """
- result, gts, preds = _prefetch_circular_group(sub_items)
- if result is not None:
- return result
-
- log = ""
- for i, item in enumerate(sub_items):
- if preds[i]:
- log += f"Rolling {i} Matched.\n"
- continue
- res = _extract_answer_from_item(item, judge)
- opt, match_log = res["opt"], res["log"]
- preds[i] = opt
- if preds[i] != gts[i]:
- log += (
- f"Failed in Rolling {i}: Answer is {gts[i]}; "
- f"Prediction is {item.get('prediction', '')}; "
- f"Pre-fetched is {preds[i]}; Match Log is {match_log}.\n"
- )
- return {"hit": 0, "log": log}
- log += (
- f"Rolling {i}: Answer is {gts[i]}, "
- f"Prediction is {item.get('prediction', '')}, Pre-fetched is {preds[i]}.\n"
- )
- return {"hit": 1, "log": log}
-
-
-# ---------------------------------------------------------------------------
-# report_acc — ports multiple_choice.py:77-100
-# ---------------------------------------------------------------------------
-
-
-def _report_acc(df: pd.DataFrame) -> "dict[str, float]":
- """Return a flat dict keyed by 'split=|' → accuracy in [0,1].
-
- Empty (split, category) cells return NaN to match VLMEvalKit's
- np.mean([]) behaviour rather than 0.0 (which would distort
- average-of-averages aggregations downstream).
- """
- res: "dict[str, float]" = {}
- if "split" in df.columns:
- splits = sorted(df["split"].dropna().unique().tolist())
- else:
- df = df.copy()
- df["split"] = "none"
- splits = ["none"]
-
- for group in (None, "l2-category", "category"):
- if group is None:
- for sp in splits:
- sub = df[df["split"] == sp]["hit"]
- val = float(sub.mean()) if len(sub) else float("nan")
- res[f"split={sp}|Overall"] = val
- elif group not in df.columns:
- continue
- else:
- abilities = sorted(df[group].dropna().unique().tolist())
- for ab in abilities:
- ab_name = MMB_ABBRS.get(ab, ab)
- sub_df = df[df[group] == ab]
- for sp in splits:
- cell = sub_df[sub_df["split"] == sp]["hit"]
- val = float(cell.mean()) if len(cell) else float("nan")
- res[f"split={sp}|{ab_name}"] = val
- return res
-
-
-# ---------------------------------------------------------------------------
-# Helpers (unchanged)
-# ---------------------------------------------------------------------------
-
-
def _resolve_path(path_str: str, repo_root: Path) -> Path:
path = Path(path_str).expanduser()
if not path.is_absolute():
@@ -467,7 +96,6 @@ def _extract_text(output: Any) -> str:
text = _extract_text(item)
if text:
return text
- # Handle adapters that return {"understandings": [{"response": "..."}]}
for list_key in ("understandings",):
container = output.get(list_key)
if isinstance(container, list):
@@ -483,7 +111,21 @@ def _extract_text(output: Any) -> str:
return ""
-def _load_eval_cfg(config_path: str) -> "tuple[dict, dict, dict]":
+def _post_process(pred: str, option: dict[str, str]) -> str:
+ pred = pred.strip()
+ option_candidate = list(option.keys())
+ if len(pred) == 1:
+ return pred
+ if len(pred) > 1 and pred[0] in option_candidate:
+ return pred[0]
+ if len(pred) > 1 and pred[0] not in option_candidate:
+ for k, v in option.items():
+ if v in pred:
+ return k
+ return pred
+
+
+def _load_eval_cfg(config_path: str) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]:
raw_cfg = load_config(config_path)
eval_cfg = raw_cfg.get("eval", {}) if isinstance(raw_cfg.get("eval"), dict) else {}
mmbench_cfg = raw_cfg.get("mmbench", {}) if isinstance(raw_cfg.get("mmbench"), dict) else {}
@@ -493,6 +135,18 @@ def _load_eval_cfg(config_path: str) -> "tuple[dict, dict, dict]":
return eval_cfg, mmbench_cfg, inference_cfg
+def _build_prompt(question: str, options: dict[str, str], hint: str | None, language: str) -> str:
+ if hint:
+ question = f"{hint}\n{question}"
+ for key, item in options.items():
+ question += f"\n{key}. {item}"
+ if language == "cn":
+ suffix = "请直接回答选项字母。"
+ else:
+ suffix = "Answer with the option's letter from the given choices directly."
+ return f"{question}\n{suffix}".strip()
+
+
def _decode_image(image_b64: str, image_dir: Path, row_index: int) -> str:
image_dir.mkdir(parents=True, exist_ok=True)
image = Image.open(BytesIO(base64.b64decode(image_b64))).convert("RGB")
@@ -502,10 +156,10 @@ def _decode_image(image_b64: str, image_dir: Path, row_index: int) -> str:
def _get_dataset_paths(
- datasets: "list[str]",
+ datasets: list[str],
repo_root: Path,
- override_paths: "dict[str, Any]",
-) -> "dict[str, Path]":
+ override_paths: dict[str, Any],
+) -> dict[str, Path]:
resolved: dict[str, Path] = {}
for name in datasets:
if name in override_paths:
@@ -518,46 +172,14 @@ def _get_dataset_paths(
return resolved
-def _find_latest_jsonl(out_dir: Path, ds_name: str) -> "Path | None":
- candidates = sorted(out_dir.glob(f"{ds_name}_*.jsonl"))
- candidates = [c for c in candidates if "_checkpoint" not in c.name]
- if candidates:
- return max(candidates, key=lambda p: p.stat().st_mtime)
- return None
-
-
-def _parse_options(row: pd.Series) -> "dict[str, str]":
- """Collect A/B/C/D/E/... option letters whose cell is non-null, preserving TSV order."""
+def _options_for(row: pd.Series) -> dict[str, str]:
options: dict[str, str] = {}
- for cand in string.ascii_uppercase:
+ for cand in ["A", "B", "C", "D", "E"]:
if cand in row and not pd.isna(row[cand]):
options[cand] = row[cand]
return options
-def _resolve_image_blob(image_map: "dict[int, str]", idx: int) -> str:
- """Resolve MMBench's short-circuit image storage (a ≤64-char string that
- points to another row's index). Ports image_base.py:154-164."""
- value = image_map.get(idx, "")
- if not isinstance(value, str):
- return ""
- if len(value) <= 64:
- # It's a redirect to another row's full base64 blob
- try:
- target = int(value)
- except (TypeError, ValueError):
- return value
- target_val = image_map.get(target, "")
- if isinstance(target_val, str) and len(target_val) > 64:
- return target_val
- return value
-
-
-# ---------------------------------------------------------------------------
-# Main entry point
-# ---------------------------------------------------------------------------
-
-
def run_mmbench_eval_command(args: Any) -> int:
config_path = str(args.config)
eval_cfg, mmbench_cfg, inference_cfg = _load_eval_cfg(config_path)
@@ -577,7 +199,7 @@ def run_mmbench_eval_command(args: Any) -> int:
raise ValueError("`inference.backbone_cfg` must be a dict when provided.")
request_cfg = inference_cfg.get("request", {})
- request_params: "dict[str, Any]" = {}
+ request_params: dict[str, Any] = {}
if isinstance(request_cfg, dict):
params = request_cfg.get("params", {})
if isinstance(params, dict):
@@ -600,41 +222,33 @@ def run_mmbench_eval_command(args: Any) -> int:
resume = bool(mmbench_cfg.get("resume", False))
resume_jsonl = mmbench_cfg.get("resume_jsonl")
- mode = str(mmbench_cfg.get("mode", "generate")).strip().lower()
- if mode not in ("full", "generate", "score"):
- print(f"[mmbench] unknown mode '{mode}', defaulting to 'generate'", flush=True)
- mode = "generate"
- run_gen = mode in ("full", "generate")
- run_score = mode in ("full", "score")
-
- llm_extract_cfg = mmbench_cfg.get("llm_extract", {})
- if not isinstance(llm_extract_cfg, dict):
- llm_extract_cfg = {}
- llm_model_path = str(llm_extract_cfg.get("model_path", "")).strip()
- llm_max_new_tokens = int(llm_extract_cfg.get("max_new_tokens", 32))
-
dataset_paths = _get_dataset_paths(
datasets=datasets,
repo_root=repo_root,
override_paths=mmbench_cfg.get("dataset_paths", {}) if isinstance(mmbench_cfg.get("dataset_paths"), dict) else {},
)
- out_dir.mkdir(parents=True, exist_ok=True)
-
- summary: "dict[str, Any]" = {
- "benchmark": "mmbench",
- "backbone": backbone,
- "out_dir": str(out_dir),
- "datasets": datasets,
- "mode": mode,
- }
+ dist_info = maybe_init_distributed()
+ try:
+ out_dir.mkdir(parents=True, exist_ok=True)
+ pipeline = InferencePipeline(backbone_name=backbone, backbone_cfg=backbone_cfg)
- # ── Phase 1: Generation ──
- if run_gen:
- from umm.inference import InferencePipeline
+ summary: dict[str, Any] = {
+ "benchmark": "mmbench",
+ "backbone": backbone,
+ "out_dir": str(out_dir),
+ "datasets": datasets,
+ "world_size": dist_info.world_size,
+ }
- pipeline = InferencePipeline(backbone_name=backbone, backbone_cfg=backbone_cfg)
+ if dist_info.world_size > 1:
+ print(
+ f"[mmbench] distributed inference enabled: rank={dist_info.rank}, "
+ f"local_rank={dist_info.local_rank}, world_size={dist_info.world_size}",
+ flush=True,
+ )
+ local_total_written = 0
for ds_name in datasets:
dataset_path = dataset_paths[ds_name]
if not dataset_path.exists():
@@ -643,316 +257,157 @@ def run_mmbench_eval_command(args: Any) -> int:
entry = DS_COLLECTIONS.get(ds_name, {})
language = str(entry.get("language", "en"))
df = pd.read_csv(dataset_path, sep="\t")
- df["index"] = df["index"].astype(int)
-
- # Resolve MMBench's image short-circuits so every row has its own blob.
- image_map: "dict[int, str]" = {}
- if "image" in df.columns:
- for _, row in df.iterrows():
- image_map[int(row["index"])] = str(row["image"]) if not pd.isna(row["image"]) else ""
checkpoint_jsonl = out_dir / f"{ds_name}_checkpoint.jsonl"
- outputs: "list[dict[str, Any]]" = []
- done_indices: "set[int]" = set()
+ shard_path = rank_shard_path(checkpoint_jsonl, dist_info.rank, dist_info.world_size)
+ done_indices: set[Any] = set()
if resume:
- if checkpoint_jsonl.exists():
- with checkpoint_jsonl.open("r", encoding="utf-8") as reader:
- for line in reader:
- line = line.strip()
- if not line:
- continue
- item = json.loads(line)
- outputs.append(item)
- done_indices.add(int(item["index"]))
- print(f"[mmbench] resume from checkpoint: {len(outputs)} done", flush=True)
- else:
- jsonl_path: "Path | None" = None
+ shard_items = load_shard_items(shard_path)
+ if shard_items:
+ done_indices = {int(item["index"]) for item in shard_items}
+ print(
+ f"[mmbench] {ds_name}: rank {dist_info.rank} resume from "
+ f"checkpoint: {len(done_indices)} done",
+ flush=True,
+ )
+ elif dist_info.world_size <= 1:
+ # Single-card fallback: recover from a previously-completed
+ # output JSONL. Multi-card runs rely on the rank shard only.
+ jsonl_path: Path | None = None
if isinstance(resume_jsonl, str) and resume_jsonl:
jsonl_path = _resolve_path(resume_jsonl, repo_root)
else:
- jsonl_path = _find_latest_jsonl(out_dir, ds_name)
+ candidates = sorted(out_dir.glob(f"{ds_name}_*.jsonl"))
+ candidates = [c for c in candidates if "_checkpoint" not in c.name]
+ if candidates:
+ jsonl_path = max(candidates, key=lambda p: p.stat().st_mtime)
if jsonl_path and jsonl_path.exists():
- with jsonl_path.open("r", encoding="utf-8") as reader:
- for line in reader:
- line = line.strip()
- if not line:
- continue
- item = json.loads(line)
- outputs.append(item)
- done_indices.add(int(item["index"]))
- print(f"[mmbench] resume from {jsonl_path}: {len(outputs)} done", flush=True)
+ prior = load_shard_items(jsonl_path)
+ done_indices = {int(item["index"]) for item in prior}
+ print(
+ f"[mmbench] resume from {jsonl_path}: "
+ f"{len(done_indices)} done",
+ flush=True,
+ )
+ assigned_total = (
+ (len(df) + dist_info.world_size - 1 - dist_info.rank) // dist_info.world_size
+ if dist_info.world_size > 1
+ else len(df)
+ )
print(
- f"[mmbench] {ds_name}: {len(df)} total, {len(done_indices)} done, "
- f"{len(df) - len(done_indices)} remaining",
+ f"[mmbench] {ds_name}: total={len(df)}, rank={dist_info.rank}, "
+ f"assigned={assigned_total}, done={len(done_indices)}, "
+ f"remaining={max(0, assigned_total - len(done_indices))}",
flush=True,
)
- with checkpoint_jsonl.open("a", encoding="utf-8") as ckpt_writer:
- for _, row in tqdm(df.iterrows(), total=len(df), desc=f"mmbench/{ds_name}", file=sys.stdout):
- row_index = int(row["index"])
- if row_index in done_indices:
- continue
-
- image_b64 = _resolve_image_blob(image_map, row_index)
- if not image_b64:
- # Skip rows without a usable image (rare).
- continue
- image_path = _decode_image(image_b64, image_dir, row_index=row_index)
- options = _parse_options(row)
- hint = None
- if "hint" in row and not pd.isna(row["hint"]):
- hint = str(row["hint"])
-
- question_text = str(row["question"])
- prompt = _build_prompt(question_text, options, hint)
- payload = {
- "backbone": backbone,
- "task": "understanding",
- "prompt": prompt,
- "images": [image_path],
- "params": request_params,
- "metadata": {"index": row_index, "dataset": ds_name},
- }
- prediction = _extract_text(pipeline.run(payload))
- gt = None
- if "answer" in row and not pd.isna(row["answer"]):
- gt = str(row["answer"]).strip().upper()
- item = {
- "index": row_index,
- "question": question_text,
- "options": options,
- "prediction": prediction,
- "gt_answer": gt,
- "language": language,
- }
- outputs.append(item)
- done_indices.add(row_index)
- ckpt_writer.write(json.dumps(item) + "\n")
- ckpt_writer.flush()
- os.fsync(ckpt_writer.fileno())
-
- if max_samples > 0 and len(outputs) >= max_samples:
- break
+ def payload_fn(pair: tuple[Any, pd.Series]) -> dict[str, Any]:
+ _, row = pair
+ row_index = int(row["index"])
+ image_path = _decode_image(str(row["image"]), image_dir, row_index=row_index)
+ options = _options_for(row)
+ hint = None
+ if "hint" in row and not pd.isna(row["hint"]):
+ hint = row["hint"]
+ question = _build_prompt(str(row["question"]), options, hint, language=language)
+ return {
+ "backbone": backbone,
+ "task": "understanding",
+ "prompt": question,
+ "images": [image_path],
+ "params": request_params,
+ "metadata": {"index": row_index, "dataset": ds_name},
+ }
+
+ def record_fn(pair: tuple[Any, pd.Series], raw: Any, _idx: int) -> dict[str, Any]:
+ _, row = pair
+ row_index = int(row["index"])
+ options = _options_for(row)
+ hint = None
+ if "hint" in row and not pd.isna(row["hint"]):
+ hint = row["hint"]
+ question = _build_prompt(str(row["question"]), options, hint, language=language)
+ response = _extract_text(raw)
+ pred = _post_process(response, options)
+ return {
+ "question": question,
+ "answer": pred,
+ "gt_answers": row["answer"] if "answer" in row else None,
+ "index": row_index,
+ }
+
+ n_written = run_sharded_inference(
+ infer_fn=pipeline.run,
+ dist_info=dist_info,
+ shard_path=shard_path,
+ samples=df.iterrows(),
+ total=len(df),
+ payload_fn=payload_fn,
+ record_fn=record_fn,
+ sample_id_fn=lambda pair: int(pair[1]["index"]),
+ done_ids=done_indices,
+ max_samples=max_samples,
+ log_prefix=f"mmbench/{ds_name}/rank{dist_info.rank}",
+ )
+ local_total_written += n_written
+
+ barrier(dist_info)
time_prefix = time.strftime("%y%m%d%H%M%S", time.localtime())
+ results_file = f"{ds_name}_{time_prefix}.xlsx"
+ output_path = out_dir / results_file
jsonl_path_out = out_dir / f"{ds_name}_{time_prefix}.jsonl"
- with jsonl_path_out.open("w", encoding="utf-8") as writer:
- for item in outputs:
- writer.write(json.dumps(item) + "\n")
-
- if checkpoint_jsonl.exists():
- checkpoint_jsonl.unlink()
-
- summary[f"{ds_name}_output_jsonl"] = str(jsonl_path_out)
-
- if mode == "generate":
- print(f"[mmbench] generation phase done, outputs={out_dir}", flush=True)
-
- # Release generation GPU memory before loading the judge LLM.
- del pipeline
- import gc
- gc.collect()
- try:
- import torch as _torch
- if _torch.cuda.is_available():
- _torch.cuda.empty_cache()
- except ImportError:
- pass
-
- # ── Phase 2: Circular scoring (VLMEvalKit-compatible) ──
- if run_score:
- judge: "_JudgeBundle | None" = None
- try:
- for ds_name in datasets:
- jsonl_path = None
- if f"{ds_name}_output_jsonl" in summary:
- jsonl_path = Path(summary[f"{ds_name}_output_jsonl"])
- else:
- jsonl_path = _find_latest_jsonl(out_dir, ds_name)
- if jsonl_path is None or not jsonl_path.exists():
- raise FileNotFoundError(
- f"No generation output found for {ds_name} in {out_dir}. "
- f"Run generation phase first (mode: generate)."
- )
-
- print(f"[mmbench] scoring {ds_name} from {jsonl_path}", flush=True)
- items: "list[dict[str, Any]]" = []
- with jsonl_path.open("r", encoding="utf-8") as reader:
- for line in reader:
- line = line.strip()
- if not line:
- continue
- items.append(json.loads(line))
-
- # Re-hydrate missing keys from the source TSV when needed
- # (back-compat with JSONLs produced by the pre-rewrite code).
- dataset_path = dataset_paths[ds_name]
- df = pd.read_csv(dataset_path, sep="\t") if dataset_path.exists() else pd.DataFrame()
- if not df.empty:
- df["index"] = df["index"].astype(int)
- df_by_idx = df.set_index("index", drop=False)
- language_default = str(DS_COLLECTIONS.get(ds_name, {}).get("language", "en"))
- for item in items:
- idx = int(item["index"])
- row = df_by_idx.loc[idx] if idx in df_by_idx.index else None
- if "options" not in item and row is not None:
- item["options"] = _parse_options(row)
- if "prediction" not in item:
- item["prediction"] = item.get("response", item.get("answer", ""))
- if "gt_answer" not in item and row is not None:
- if "answer" in row and not pd.isna(row["answer"]):
- item["gt_answer"] = str(row["answer"]).strip().upper()
- if "language" not in item:
- item["language"] = language_default
- if "question" not in item and row is not None and "question" in row:
- item["question"] = str(row["question"])
-
- # Prepare per-item `choices` (copy of options, used by can_infer).
- for item in items:
- item["choices"] = dict(item.get("options", {}))
-
- # Group by g_index = index % 1e6 → circular rotation group
- groups: "dict[int, list[dict[str, Any]]]" = {}
- for item in items:
- g = int(item["index"]) % 1_000_000
- groups.setdefault(g, []).append(item)
- # Sort rotations inside each group by their original index so
- # rotation k=0 (smallest) comes first, matching VLMEvalKit's order.
- for g in groups:
- groups[g].sort(key=lambda it: int(it["index"]))
-
- # Fast-path: resolve as many groups as possible with can_infer only.
- group_results: "dict[int, dict[str, Any]]" = {}
- pending_groups: "list[tuple[int, list[dict[str, Any]]]]" = []
- for g, rows in groups.items():
- pre_res, _gts, _preds = _prefetch_circular_group(rows)
- if pre_res is not None:
- group_results[g] = pre_res
- else:
- pending_groups.append((g, rows))
- # Load the judge once if any groups still need LLM extraction.
- if pending_groups and llm_model_path and judge is None:
- judge = _JudgeBundle(llm_model_path, max_new_tokens=llm_max_new_tokens)
+ if dist_info.rank == 0:
+ merged_outputs = merge_shards(checkpoint_jsonl)
- if pending_groups:
- print(
- f"[mmbench] {ds_name}: {len(group_results)} groups decided by can_infer, "
- f"{len(pending_groups)} need LLM judging",
- flush=True,
- )
- for g, rows in tqdm(pending_groups, desc=f"mmbench/{ds_name}/judge", file=sys.stdout):
- group_results[g] = _eval_circular_group(rows, judge=judge)
+ cur_df = df.copy()
+ if "mmbench" in ds_name:
+ cur_df = cur_df.drop(columns=["hint", "category", "source", "image", "comment", "l2-category"])
+ cur_df.insert(6, "prediction", None)
else:
- print(
- f"[mmbench] {ds_name}: all {len(group_results)} groups decided by can_infer",
- flush=True,
- )
+ cur_df = cur_df.drop(columns=["category", "image"])
+ cur_df.insert(8, "prediction", None)
- df_scored: pd.DataFrame
- if not df.empty:
- df_work = df.copy()
- df_work["g_index"] = df_work["index"].astype(int) % 1_000_000
- df_scored = df_work[df_work["index"] == df_work["g_index"]].copy()
- expected_g = set(int(g) for g in df_scored["g_index"].tolist())
- missing = expected_g - set(group_results.keys())
- if missing:
- raise RuntimeError(
- f"[mmbench] {ds_name}: {len(missing)} groups in TSV missing from "
- f"generation outputs (e.g. {sorted(missing)[:5]}). "
- f"Generation phase skipped some rows; rerun with mode: full or "
- f"mode: generate to refill the JSONL before scoring."
- )
- df_scored["hit"] = df_scored["g_index"].map(lambda g: group_results[int(g)]["hit"])
- df_scored["log"] = df_scored["g_index"].map(lambda g: group_results[int(g)]["log"])
- else:
- # No TSV available (unlikely): fall back to a minimal frame.
- rows = []
- for g, rows_list in groups.items():
- head = rows_list[0]
- rows.append(
- {
- "index": head["index"],
- "g_index": g,
- "hit": group_results[g]["hit"],
- }
- )
- df_scored = pd.DataFrame(rows)
-
- metrics = _report_acc(df_scored)
- overall_vals = [v for k, v in metrics.items() if k.endswith("|Overall")]
- overall_acc_pct = round(100.0 * (overall_vals[0] if overall_vals else 0.0), 2)
- question_count = int(len(df_scored))
- hit_count = int(df_scored["hit"].sum()) if "hit" in df_scored.columns else 0
- # Write the annotated items (now with extraction+hit) back to JSONL.
- # Each item gets its group's hit attached for easy inspection.
- for item in items:
- g = int(item["index"]) % 1_000_000
- res = group_results.get(g, {"hit": 0, "log": ""})
- item["group_hit"] = int(res.get("hit", 0))
- with jsonl_path.open("w", encoding="utf-8") as writer:
- for item in items:
+ for item in merged_outputs:
+ cur_df.loc[df["index"] == item["index"], "prediction"] = item["answer"]
+
+ cur_df.to_excel(output_path, index=False, engine="openpyxl") # pip install openpyxl
+ with jsonl_path_out.open("w", encoding="utf-8") as writer:
+ for item in merged_outputs:
writer.write(json.dumps(item) + "\n")
- score_path = jsonl_path.with_name(f"{jsonl_path.stem}_score.json")
- score_payload = {
- "overall": {
- "accuracy": overall_acc_pct,
- "hit_count": hit_count,
- "question_count": question_count,
- "mode": "circular",
- },
- "metrics": {k: round(100.0 * v, 2) for k, v in metrics.items()},
- }
- score_path.write_text(json.dumps(score_payload, indent=2), encoding="utf-8")
- print(
- f"[mmbench] {ds_name} Overall Acc = {overall_acc_pct}% "
- f"({hit_count}/{question_count} groups)",
- flush=True,
- )
-
- # Persist the flat acc CSV too for parity with VLMEvalKit.
- if not df_scored.empty:
- csv_path = jsonl_path.with_name(f"{jsonl_path.stem}_acc.csv")
- acc_rows = [{"metric": k, "accuracy": round(100.0 * v, 2)} for k, v in metrics.items()]
- pd.DataFrame(acc_rows).to_csv(csv_path, index=False)
- summary[f"{ds_name}_acc_csv"] = str(csv_path)
-
- # Optional xlsx with predictions filled in (requires openpyxl).
- # Wise image ships without openpyxl; if missing, silently skip.
- is_mmbench_schema = ds_name.startswith("mmbench_") or ds_name.startswith("MMBench_")
- if not df.empty and is_mmbench_schema:
- try:
- import openpyxl # noqa: F401
- xlsx_path = jsonl_path.with_suffix(".xlsx")
- cur_df = df.copy()
- drop_cols = [c for c in ("hint", "category", "source", "image", "comment", "l2-category") if c in cur_df.columns]
- if drop_cols:
- cur_df = cur_df.drop(columns=drop_cols)
- # For circular V11 the xlsx still lists all 4 rotations; each row
- # gets the extracted letter from the corresponding item.
- pred_by_idx: "dict[int, str]" = {}
- for item in items:
- pred_by_idx[int(item["index"])] = str(item.get("extraction", item.get("prediction", "")))
- cur_df["prediction"] = cur_df["index"].map(lambda k: pred_by_idx.get(int(k), ""))
- cur_df.to_excel(xlsx_path, index=False, engine="openpyxl")
- summary[f"{ds_name}_xlsx"] = str(xlsx_path)
- except ImportError:
- print("[mmbench] openpyxl unavailable, skipping xlsx export", flush=True)
-
- summary[f"{ds_name}_score_file"] = str(score_path)
- summary[f"{ds_name}_accuracy"] = overall_acc_pct
- finally:
- if judge is not None:
- judge.close()
-
- if isinstance(score_output_path, str) and score_output_path:
- score_path = _resolve_path(score_output_path, repo_root)
- score_path.parent.mkdir(parents=True, exist_ok=True)
- score_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")
- print(f"[umm eval] wrote MMBench summary to {score_path}")
-
- print(f"[umm eval] completed MMBench (mode={mode}) for backbone={backbone}, outputs={out_dir}")
- return 0
+ cleanup_shards(checkpoint_jsonl)
+ if dist_info.world_size <= 1 and checkpoint_jsonl.exists():
+ checkpoint_jsonl.unlink()
+
+ summary[f"{ds_name}_output_path"] = str(output_path)
+ summary[f"{ds_name}_output_jsonl"] = str(jsonl_path_out)
+
+ barrier(dist_info)
+
+ total_written_all = sum_across_ranks(local_total_written, dist_info)
+ if dist_info.rank != 0:
+ print(
+ f"[umm eval] rank {dist_info.rank} finished MMBench shard: "
+ f"samples_written={local_total_written}",
+ flush=True,
+ )
+ return 0
+
+ summary["samples_written"] = total_written_all
+ if isinstance(score_output_path, str) and score_output_path:
+ score_path = _resolve_path(score_output_path, repo_root)
+ score_path.parent.mkdir(parents=True, exist_ok=True)
+ score_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")
+ print(f"[umm eval] wrote MMBench summary to {score_path}")
+
+ print(
+ f"[umm eval] completed MMBench for backbone={backbone}, outputs={out_dir}, "
+ f"samples_written={total_written_all}, world_size={dist_info.world_size}"
+ )
+ return 0
+ finally:
+ cleanup_distributed(dist_info)
diff --git a/src/umm/cli/mme_eval.py b/src/umm/cli/mme_eval.py
index 57d9b62..fc5619e 100644
--- a/src/umm/cli/mme_eval.py
+++ b/src/umm/cli/mme_eval.py
@@ -1,14 +1,24 @@
from __future__ import annotations
import json
-import os
import re
import subprocess
import sys
from pathlib import Path
-from typing import Any
+from typing import Any, Iterator
from umm.core.config import load_config
+from umm.eval.distributed import (
+ barrier,
+ cleanup_distributed,
+ cleanup_shards,
+ load_shard_items,
+ maybe_init_distributed,
+ merge_shards,
+ rank_shard_path,
+ sum_across_ranks,
+)
+from umm.eval.runner import run_sharded_inference
from umm.inference import InferencePipeline
@@ -60,8 +70,7 @@ def _extract_text(output: Any) -> str:
def _post_process(response: str) -> str:
response = response.replace("\n", "").replace("不是", "No").replace("是", "Yes").replace("否", "No")
response = response.lower().replace("true", "yes").replace("false", "no")
- response = re.sub(re.compile(r"[\u4e00-\u9fa5]"), "", response)
- # Extract first yes/no, discard repetitive garbage
+ response = re.sub(re.compile(r"[一-龥]"), "", response)
response = response.strip()
match = re.match(r"^[^a-z]*(yes|no)\b", response)
if match:
@@ -127,105 +136,169 @@ def run_mme_eval_command(args: Any) -> int:
if not image_root.exists():
raise FileNotFoundError(f"MME image root not found: {image_root}")
- out_dir.mkdir(parents=True, exist_ok=True)
- pipeline = InferencePipeline(backbone_name=backbone, backbone_cfg=backbone_cfg)
-
- txt_files = sorted([p for p in dataset_root.iterdir() if p.suffix == ".txt"])
- if not txt_files:
- raise FileNotFoundError(f"No .txt files found in MME root: {dataset_root}")
-
- total_written = 0
- missing_images = 0
- skipped_rows = 0
- for task_txt in txt_files:
- task_name = task_txt.stem
- out_file = out_dir / task_txt.name
-
- # Resume: count already-completed lines and skip them
- existing_lines = 0
- if out_file.exists():
- existing_lines = sum(1 for ln in out_file.open("r", encoding="utf-8") if ln.strip())
- if existing_lines > 0:
- print(f"[mme] {task_name}: resuming after {existing_lines} existing lines", flush=True)
-
- with task_txt.open("r", encoding="utf-8") as fin, out_file.open("a", encoding="utf-8") as fout:
- done = 0
- for idx, line in enumerate(fin, start=1):
- row = line.strip().split("\t")
- if len(row) != 3:
- skipped_rows += 1
- continue
- img, question, gt = row
-
- img_path = image_root / task_name / img
- if not img_path.exists():
- img_path = image_root / task_name / "images" / img
- if not img_path.exists():
- missing_images += 1
- continue
-
- # Skip rows already written in a previous run
- done += 1
- if done <= existing_lines:
- continue
-
+ dist_info = maybe_init_distributed()
+ try:
+ out_dir.mkdir(parents=True, exist_ok=True)
+ pipeline = InferencePipeline(backbone_name=backbone, backbone_cfg=backbone_cfg)
+
+ txt_files = sorted([p for p in dataset_root.iterdir() if p.suffix == ".txt"])
+ if not txt_files:
+ raise FileNotFoundError(f"No .txt files found in MME root: {dataset_root}")
+
+ if dist_info.world_size > 1:
+ print(
+ f"[mme] distributed inference enabled: rank={dist_info.rank}, "
+ f"local_rank={dist_info.local_rank}, world_size={dist_info.world_size}",
+ flush=True,
+ )
+
+ total_written = 0
+ # Counters incremented by rank 0 only so the totals are not multiplied
+ # by world_size — every rank reads the full TSV during pre-filtering.
+ missing_images = 0
+ skipped_rows = 0
+
+ for task_txt in txt_files:
+ task_name = task_txt.stem
+ out_file = out_dir / task_txt.name
+ checkpoint_jsonl = out_dir / f"{task_name}_checkpoint.jsonl"
+ shard_path = rank_shard_path(checkpoint_jsonl, dist_info.rank, dist_info.world_size)
+
+ done_idx = {int(it["_sample_idx"]) for it in load_shard_items(shard_path)}
+ if done_idx:
+ print(
+ f"[mme] {task_name}: rank {dist_info.rank} resuming after "
+ f"{len(done_idx)} existing shard items",
+ flush=True,
+ )
+
+ def iter_valid_rows() -> Iterator[tuple[int, str, str, str, str]]:
+ nonlocal missing_images, skipped_rows
+ counter = 0
+ with task_txt.open("r", encoding="utf-8") as fin:
+ for line in fin:
+ row = line.strip().split("\t")
+ if len(row) != 3:
+ if dist_info.rank == 0:
+ skipped_rows += 1
+ continue
+ img, question, gt = row
+ img_path = image_root / task_name / img
+ if not img_path.exists():
+ img_path = image_root / task_name / "images" / img
+ if not img_path.exists():
+ if dist_info.rank == 0:
+ missing_images += 1
+ continue
+ counter += 1
+ yield counter, img, question, gt, str(img_path)
+
+ def payload_fn(sample: tuple[int, str, str, str, str]) -> dict[str, Any]:
+ _, img, question, _gt, img_path = sample
prompt = f"{question} {prompt_suffix}".strip()
- payload = {
+ print(f"[mme] rank {dist_info.rank} | {task_name}: {img} ... inferring", flush=True)
+ return {
"backbone": backbone,
"task": "understanding",
"prompt": prompt,
- "images": [str(img_path)],
+ "images": [img_path],
"params": request_params,
}
- print(f"[mme] {task_name} | {idx}: {img} ... inferring", flush=True)
- output = pipeline.run(payload)
- response = _post_process(_extract_text(output))
+
+ def record_fn(sample: tuple[int, str, str, str, str], raw: Any, _idx: int) -> dict[str, Any]:
+ _, img, question, gt, _img_path = sample
+ prompt = f"{question} {prompt_suffix}".strip()
+ response = _post_process(_extract_text(raw))
if not response.strip():
print(f"[mme] WARNING empty response: {task_name}/{img}", flush=True)
- print(f"[mme] {task_name} | {idx}: {img} -> {response[:80]!r}", flush=True)
- print(img, prompt, gt, response, sep="\t", file=fout)
- fout.flush()
- os.fsync(fout.fileno())
- total_written += 1
- if max_samples > 0 and idx >= max_samples:
- break
-
- summary: dict[str, Any] = {
- "benchmark": "mme",
- "backbone": backbone,
- "out_dir": str(out_dir),
- "samples_written": total_written,
- }
+ print(f"[mme] rank {dist_info.rank} | {task_name}: {img} -> {response[:80]!r}", flush=True)
+ return {
+ "img": img,
+ "prompt": prompt,
+ "gt": gt,
+ "response": response,
+ }
- # Check if out_dir has any result files at all (including from previous runs)
- has_results = any(p.suffix == ".txt" and p.stat().st_size > 0 for p in out_dir.iterdir()) if out_dir.exists() else False
- if total_written == 0 and not has_results:
+ n_written = run_sharded_inference(
+ infer_fn=pipeline.run,
+ dist_info=dist_info,
+ shard_path=shard_path,
+ samples=iter_valid_rows(),
+ payload_fn=payload_fn,
+ record_fn=record_fn,
+ sample_id_fn=lambda sample: sample[0],
+ done_ids=done_idx,
+ max_samples=max_samples,
+ log_prefix=f"mme/{task_name}/rank{dist_info.rank}",
+ )
+ total_written += n_written
+
+ barrier(dist_info)
+
+ if dist_info.rank == 0:
+ merged = merge_shards(checkpoint_jsonl)
+ with out_file.open("w", encoding="utf-8") as fout:
+ for item in merged:
+ print(item["img"], item["prompt"], item["gt"], item["response"], sep="\t", file=fout)
+ cleanup_shards(checkpoint_jsonl)
+ if dist_info.world_size <= 1 and checkpoint_jsonl.exists():
+ checkpoint_jsonl.unlink()
+
+ barrier(dist_info)
+
+ total_written_all = sum_across_ranks(total_written, dist_info)
+
+ if dist_info.rank != 0:
+ print(
+ f"[umm eval] rank {dist_info.rank} finished shard: "
+ f"samples_written={total_written}",
+ flush=True,
+ )
+ return 0
+
+ summary: dict[str, Any] = {
+ "benchmark": "mme",
+ "backbone": backbone,
+ "out_dir": str(out_dir),
+ "samples_written": total_written_all,
+ "world_size": dist_info.world_size,
+ }
+
+ has_results = (
+ any(p.suffix == ".txt" and p.stat().st_size > 0 for p in out_dir.iterdir())
+ if out_dir.exists()
+ else False
+ )
+ if total_written_all == 0 and not has_results:
+ print(
+ "[umm eval] warning: no MME samples were written. "
+ "Skipping calculation. Check `mme.root` and `mme.image_root`."
+ )
+ run_calculation = False
+
+ if run_calculation:
+ cmd = [sys.executable, str(calculation_script), "--results_dir", str(out_dir)]
+ proc = subprocess.run(cmd, cwd=str(repo_root), capture_output=True, text=True)
+ print(proc.stdout)
+ if proc.returncode != 0:
+ if proc.stderr:
+ print(proc.stderr, file=sys.stderr)
+ raise RuntimeError(f"MME calculation failed with return code {proc.returncode}")
+ summary["calculation_stdout"] = proc.stdout
+
+ if isinstance(score_output_path, str) and score_output_path:
+ score_path = _resolve_path(score_output_path, repo_root)
+ score_path.parent.mkdir(parents=True, exist_ok=True)
+ score_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")
+ print(f"[umm eval] wrote MME summary to {score_path}")
+
+ summary["missing_images"] = missing_images
+ summary["skipped_rows"] = skipped_rows
print(
- "[umm eval] warning: no MME samples were written. "
- "Skipping calculation. Check `mme.root` and `mme.image_root`."
+ f"[umm eval] completed MME for backbone={backbone}, outputs={out_dir}, "
+ f"samples_written={total_written_all}, missing_images={missing_images}, "
+ f"skipped_rows={skipped_rows}, world_size={dist_info.world_size}"
)
- run_calculation = False
-
- if run_calculation:
- cmd = [sys.executable, str(calculation_script), "--results_dir", str(out_dir)]
- proc = subprocess.run(cmd, cwd=str(repo_root), capture_output=True, text=True)
- print(proc.stdout)
- if proc.returncode != 0:
- if proc.stderr:
- print(proc.stderr, file=sys.stderr)
- raise RuntimeError(f"MME calculation failed with return code {proc.returncode}")
- summary["calculation_stdout"] = proc.stdout
-
- if isinstance(score_output_path, str) and score_output_path:
- score_path = _resolve_path(score_output_path, repo_root)
- score_path.parent.mkdir(parents=True, exist_ok=True)
- score_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")
- print(f"[umm eval] wrote MME summary to {score_path}")
-
- summary["missing_images"] = missing_images
- summary["skipped_rows"] = skipped_rows
- print(
- f"[umm eval] completed MME for backbone={backbone}, outputs={out_dir}, "
- f"samples_written={total_written}, missing_images={missing_images}, skipped_rows={skipped_rows}"
- )
- return 0
+ return 0
+ finally:
+ cleanup_distributed(dist_info)
diff --git a/src/umm/cli/mmmu_eval.py b/src/umm/cli/mmmu_eval.py
index 6ac84ea..a0d3e07 100644
--- a/src/umm/cli/mmmu_eval.py
+++ b/src/umm/cli/mmmu_eval.py
@@ -1,6 +1,7 @@
from __future__ import annotations
import json
+import os
import random
import subprocess
import sys
@@ -10,11 +11,21 @@
from datasets import concatenate_datasets, load_dataset
from PIL import Image
-from tqdm import tqdm
from umm.core.config import load_config
-from umm.inference import InferencePipeline
+from umm.eval.distributed import (
+ barrier,
+ cleanup_distributed,
+ cleanup_shards,
+ load_shard_items,
+ maybe_init_distributed,
+ merge_shards,
+ rank_shard_path,
+ sum_across_ranks,
+)
from umm.eval.internvl_chat.eval.mmmu import data_utils, eval_utils
+from umm.eval.runner import run_sharded_inference
+from umm.inference import InferencePipeline
def _resolve_path(path_str: str, repo_root: Path) -> Path:
@@ -53,7 +64,6 @@ def _extract_text(output: Any) -> str:
text = _extract_text(item)
if text:
return text
- # Handle adapters that return {"understandings": [{"response": "..."}]}
for list_key in ("understandings",):
container = output.get(list_key)
if isinstance(container, list):
@@ -79,8 +89,27 @@ def _load_eval_cfg(config_path: str) -> tuple[dict[str, Any], dict[str, Any], di
return eval_cfg, mmmu_cfg, inference_cfg
+def _load_mmmu_subject_from_local_snapshot(root_path: Path, subject: str, split: str) -> Any:
+    """Load one MMMU subject's split from parquet files in a local snapshot.
+
+    Looks for ``{split}-*.parquet`` files directly under ``root_path / subject``
+    and loads them, in sorted path order, through the ``parquet`` dataset builder.
+
+    Raises:
+        FileNotFoundError: if no matching parquet file exists for the subject.
+    """
+    subject_dir = root_path / subject
+    shard_files = [str(found) for found in sorted(subject_dir.glob(f"{split}-*.parquet"))]
+    if shard_files:
+        return load_dataset("parquet", data_files={split: shard_files}, split=split)
+    raise FileNotFoundError(
+        f"Missing MMMU parquet for subject={subject}, split={split}, expected files under {subject_dir}"
+    )
+
+
def _load_mmmu_dataset(root: str, split: str, cache_dir: str | None) -> Any:
+ """Load ``split`` for every MMMU subject and concatenate the results.
+
+ When ``root`` (after ``~`` expansion) is an existing directory it is read as
+ a local parquet snapshot and ``cache_dir`` is not used; otherwise ``root`` is
+ passed to ``load_dataset`` unchanged, together with ``cache_dir``.
+ """
+ root_path = Path(root).expanduser()
datasets_list = []
+ # Local snapshot: each subject is loaded from ``root_path / subject``.
+ if root_path.exists() and root_path.is_dir():
+ for subject in data_utils.CAT_SHORT2LONG.values():
+ datasets_list.append(_load_mmmu_subject_from_local_snapshot(root_path, subject, split))
+ return concatenate_datasets(datasets_list)
+ # Fallback: ``root`` is not a local directory, so let ``load_dataset`` resolve it.
for subject in data_utils.CAT_SHORT2LONG.values():
datasets_list.append(load_dataset(root, subject, split=split, cache_dir=cache_dir))
return concatenate_datasets(datasets_list)
@@ -200,75 +229,89 @@ def run_mmmu_eval_command(args: Any) -> int:
max_images = int(mmmu_cfg.get("max_images", 1) or 1)
seed = int(mmmu_cfg.get("seed", 0) or 0)
- out_dir.mkdir(parents=True, exist_ok=True)
- pipeline = InferencePipeline(backbone_name=backbone, backbone_cfg=backbone_cfg)
-
- summary: dict[str, Any] = {
- "benchmark": "mmmu",
- "backbone": backbone,
- "out_dir": str(out_dir),
- "datasets": datasets,
- }
-
- random.seed(seed)
- for ds_name in datasets:
- split = "validation"
- if ds_name.endswith("test"):
- split = "test"
- elif ds_name.endswith("dev"):
- split = "dev"
- elif ds_name.endswith("validation"):
+ dist_info = maybe_init_distributed()
+ try:
+ out_dir.mkdir(parents=True, exist_ok=True)
+ pipeline = InferencePipeline(backbone_name=backbone, backbone_cfg=backbone_cfg)
+
+ summary: dict[str, Any] = {
+ "benchmark": "mmmu",
+ "backbone": backbone,
+ "out_dir": str(out_dir),
+ "datasets": datasets,
+ "world_size": dist_info.world_size,
+ }
+
+ if dist_info.world_size > 1:
+ print(
+ f"[mmmu] distributed inference enabled: rank={dist_info.rank}, "
+ f"local_rank={dist_info.local_rank}, world_size={dist_info.world_size}",
+ flush=True,
+ )
+
+ random.seed(seed)
+ local_total_written = 0
+ for ds_name in datasets:
split = "validation"
+ if ds_name.endswith("test"):
+ split = "test"
+ elif ds_name.endswith("dev"):
+ split = "dev"
+ elif ds_name.endswith("validation"):
+ split = "validation"
+
+ dataset = _load_mmmu_dataset(
+ root=dataset_root,
+ split=split,
+ cache_dir=str(cache_dir_path),
+ )
+
+ checkpoint_jsonl = out_dir / f"{ds_name}_checkpoint.jsonl"
+ shard_path = rank_shard_path(checkpoint_jsonl, dist_info.rank, dist_info.world_size)
+
+ done_ids: set[Any] = {
+ str(it.get("data_id", "")) for it in load_shard_items(shard_path)
+ }
+ if done_ids:
+ print(
+ f"[mmmu] {ds_name}: rank {dist_info.rank} resuming after "
+ f"{len(done_ids)} shard items",
+ flush=True,
+ )
- dataset = _load_mmmu_dataset(
- root=dataset_root,
- split=split,
- cache_dir=str(cache_dir_path),
- )
-
- # Resume: load checkpoint JSONL if exists
- checkpoint_jsonl = out_dir / f"{ds_name}_checkpoint.jsonl"
- done_ids: set[str] = set()
- outputs: list[dict[str, Any]] = []
- if checkpoint_jsonl.exists():
- with checkpoint_jsonl.open("r", encoding="utf-8") as reader:
- for line in reader:
- line = line.strip()
- if not line:
- continue
- item = json.loads(line)
- outputs.append(item)
- done_ids.add(str(item.get("data_id", "")))
- print(f"[mmmu] resume: {len(done_ids)} done, skipping completed items", flush=True)
-
- total = len(dataset)
- remaining = total - len(done_ids)
- print(f"[mmmu] {ds_name}: {total} total, {len(done_ids)} done, {remaining} remaining", flush=True)
-
- with checkpoint_jsonl.open("a", encoding="utf-8") as ckpt_writer:
- for idx, sample in enumerate(tqdm(dataset, desc=f"mmmu/{ds_name}", file=sys.stdout), start=1):
+ total = len(dataset)
+ assigned_total = (
+ (total + dist_info.world_size - 1 - dist_info.rank) // dist_info.world_size
+ if dist_info.world_size > 1
+ else total
+ )
+ print(
+ f"[mmmu] {ds_name}: total={total}, rank={dist_info.rank}, "
+ f"assigned={assigned_total}, done={len(done_ids)}, "
+ f"remaining={max(0, assigned_total - len(done_ids))}",
+ flush=True,
+ )
+
+ def payload_fn(sample: Any) -> dict[str, Any]:
data = data_utils.process_single_sample(sample)
data_id = str(data["id"])
- if data_id in done_ids:
- continue
-
question = str(data["question"]).strip()
question_type = str(data["question_type"])
- options = eval(data["options"]) if isinstance(data.get("options"), str) else data.get("options", [])
+ options = (
+ eval(data["options"])
+ if isinstance(data.get("options"), str)
+ else data.get("options", [])
+ )
if not isinstance(options, list):
options = []
-
prompt = _build_prompt(question, question_type, options, prompt_cfg)
- index2ans, all_choices = data_utils.get_multi_choice_info(options) if options else ({}, [])
-
image_paths = _coerce_image_paths(
data.get("image", []),
image_dir=image_dir,
data_id=data_id,
max_images=max_images,
)
-
- payload = {
+ return {
"backbone": backbone,
"task": "understanding",
"prompt": prompt,
@@ -276,14 +319,29 @@ def run_mmmu_eval_command(args: Any) -> int:
"params": request_params,
"metadata": {"question_type": question_type, "data_id": data_id},
}
- output = pipeline.run(payload)
- response = _extract_text(output)
+
+ def record_fn(sample: Any, raw: Any, _idx: int) -> dict[str, Any]:
+ data = data_utils.process_single_sample(sample)
+ data_id = str(data["id"])
+ question = str(data["question"]).strip()
+ question_type = str(data["question_type"])
+ options = (
+ eval(data["options"])
+ if isinstance(data.get("options"), str)
+ else data.get("options", [])
+ )
+ if not isinstance(options, list):
+ options = []
+ prompt = _build_prompt(question, question_type, options, prompt_cfg)
+ index2ans, all_choices = (
+ data_utils.get_multi_choice_info(options) if options else ({}, [])
+ )
+ response = _extract_text(raw)
if question_type == "multiple-choice" and all_choices and index2ans:
pred = eval_utils.parse_multi_choice_response(response, all_choices, index2ans)
else:
pred = response
-
- item = {
+ return {
"question": question,
"answer": pred,
"gt_answers": data.get("answer"),
@@ -292,52 +350,86 @@ def run_mmmu_eval_command(args: Any) -> int:
"prompt": prompt,
"raw_response": response,
}
- outputs.append(item)
- ckpt_writer.write(json.dumps(item) + "\n")
- ckpt_writer.flush()
-
- if max_samples > 0 and len(outputs) >= max_samples:
- break
-
- time_prefix = time.strftime("%y%m%d%H%M%S", time.localtime())
- output_json = out_dir / f"{ds_name}_{time_prefix}.json"
- output_jsonl = out_dir / f"{ds_name}_{time_prefix}.jsonl"
-
- output_dict = {item["data_id"]: item["answer"] for item in outputs}
- output_json.write_text(json.dumps(output_dict, indent=4), encoding="utf-8")
- with output_jsonl.open("w", encoding="utf-8") as writer:
- for item in outputs:
- writer.write(json.dumps(item) + "\n")
-
- # Clean up checkpoint after successful completion
- if checkpoint_jsonl.exists():
- checkpoint_jsonl.unlink()
-
- summary[f"{ds_name}_output_path"] = str(output_json)
- summary[f"{ds_name}_output_jsonl"] = str(output_jsonl)
-
- if run_calculation and split == "validation":
- cmd = [
- sys.executable,
- str(calculation_script),
- "--output_path",
- str(output_json),
- "--answer_path",
- str(answer_path),
- ]
- proc = subprocess.run(cmd, cwd=str(repo_root), capture_output=True, text=True)
- print(proc.stdout)
- if proc.returncode != 0:
- if proc.stderr:
- print(proc.stderr, file=sys.stderr)
- raise RuntimeError(f"MMMU calculation failed with return code {proc.returncode}")
- summary[f"{ds_name}_calculation_stdout"] = proc.stdout
-
- if isinstance(score_output_path, str) and score_output_path:
- score_path = _resolve_path(score_output_path, repo_root)
- score_path.parent.mkdir(parents=True, exist_ok=True)
- score_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")
- print(f"[umm eval] wrote MMMU summary to {score_path}")
-
- print(f"[umm eval] completed MMMU for backbone={backbone}, outputs={out_dir}")
- return 0
+
+ def sample_id_fn(sample: Any) -> str:
+ data = data_utils.process_single_sample(sample)
+ return str(data["id"])
+
+ n_written = run_sharded_inference(
+ infer_fn=pipeline.run,
+ dist_info=dist_info,
+ shard_path=shard_path,
+ samples=dataset,
+ total=total,
+ payload_fn=payload_fn,
+ record_fn=record_fn,
+ sample_id_fn=sample_id_fn,
+ done_ids=done_ids,
+ max_samples=max_samples,
+ log_prefix=f"mmmu/{ds_name}/rank{dist_info.rank}",
+ )
+ local_total_written += n_written
+
+ barrier(dist_info)
+
+ time_prefix = time.strftime("%y%m%d%H%M%S", time.localtime())
+ output_json = out_dir / f"{ds_name}_{time_prefix}.json"
+ output_jsonl = out_dir / f"{ds_name}_{time_prefix}.jsonl"
+
+ if dist_info.rank == 0:
+ merged_outputs = merge_shards(checkpoint_jsonl)
+ output_dict = {item["data_id"]: item["answer"] for item in merged_outputs}
+ output_json.write_text(json.dumps(output_dict, indent=4), encoding="utf-8")
+ with output_jsonl.open("w", encoding="utf-8") as writer:
+ for item in merged_outputs:
+ writer.write(json.dumps(item) + "\n")
+
+ cleanup_shards(checkpoint_jsonl)
+ if dist_info.world_size <= 1 and checkpoint_jsonl.exists():
+ checkpoint_jsonl.unlink()
+
+ summary[f"{ds_name}_output_path"] = str(output_json)
+ summary[f"{ds_name}_output_jsonl"] = str(output_jsonl)
+
+ if run_calculation and split == "validation":
+ cmd = [
+ sys.executable,
+ str(calculation_script),
+ "--output_path",
+ str(output_json),
+ "--answer_path",
+ str(answer_path),
+ ]
+ proc = subprocess.run(cmd, cwd=str(repo_root), capture_output=True, text=True)
+ print(proc.stdout)
+ if proc.returncode != 0:
+ if proc.stderr:
+ print(proc.stderr, file=sys.stderr)
+ raise RuntimeError(f"MMMU calculation failed with return code {proc.returncode}")
+ summary[f"{ds_name}_calculation_stdout"] = proc.stdout
+
+ barrier(dist_info)
+
+ total_written_all = sum_across_ranks(local_total_written, dist_info)
+ if dist_info.rank != 0:
+ print(
+ f"[umm eval] rank {dist_info.rank} finished MMMU shard: "
+ f"samples_written={local_total_written}",
+ flush=True,
+ )
+ return 0
+
+ summary["samples_written"] = total_written_all
+ if isinstance(score_output_path, str) and score_output_path:
+ score_path = _resolve_path(score_output_path, repo_root)
+ score_path.parent.mkdir(parents=True, exist_ok=True)
+ score_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")
+ print(f"[umm eval] wrote MMMU summary to {score_path}")
+
+ print(
+ f"[umm eval] completed MMMU for backbone={backbone}, outputs={out_dir}, "
+ f"samples_written={total_written_all}, world_size={dist_info.world_size}"
+ )
+ return 0
+ finally:
+ cleanup_distributed(dist_info)
diff --git a/src/umm/cli/mmvet_eval.py b/src/umm/cli/mmvet_eval.py
index b277b5d..0af6c54 100644
--- a/src/umm/cli/mmvet_eval.py
+++ b/src/umm/cli/mmvet_eval.py
@@ -1,15 +1,24 @@
from __future__ import annotations
import json
-import sys
import time
from pathlib import Path
-from typing import Any
+from typing import Any, Iterator
from PIL import Image
-from tqdm import tqdm
from umm.core.config import load_config
+from umm.eval.distributed import (
+ barrier,
+ cleanup_distributed,
+ cleanup_shards,
+ load_shard_items,
+ maybe_init_distributed,
+ merge_shards,
+ rank_shard_path,
+ sum_across_ranks,
+)
+from umm.eval.runner import run_sharded_inference
from umm.inference import InferencePipeline
@@ -56,7 +65,6 @@ def _extract_text(output: Any) -> str:
text = _extract_text(item)
if text:
return text
- # Handle adapters that return {"understandings": [{"response": "..."}]}
for list_key in ("understandings",):
container = output.get(list_key)
if isinstance(container, list):
@@ -82,6 +90,11 @@ def _load_eval_cfg(config_path: str) -> tuple[dict[str, Any], dict[str, Any], di
return eval_cfg, mmvet_cfg, inference_cfg
+def _normalize_output_key(question_id: Any) -> str:
+    """Return ``question_id`` as a string that carries the ``v1_`` prefix.
+
+    Ids whose string form already starts with ``v1_`` are returned as that
+    string; anything else gets the prefix prepended.
+    """
+    key = str(question_id)
+    if key.startswith("v1_"):
+        return key
+    return "v1_" + key
+
+
def run_mmvet_eval_command(args: Any) -> int:
config_path = str(args.config)
eval_cfg, mmvet_cfg, inference_cfg = _load_eval_cfg(config_path)
@@ -125,97 +138,154 @@ def run_mmvet_eval_command(args: Any) -> int:
if not isinstance(dataset_paths, dict):
dataset_paths = {}
- out_dir.mkdir(parents=True, exist_ok=True)
- pipeline = InferencePipeline(backbone_name=backbone, backbone_cfg=backbone_cfg)
+ dist_info = maybe_init_distributed()
+ try:
+ out_dir.mkdir(parents=True, exist_ok=True)
+ pipeline = InferencePipeline(backbone_name=backbone, backbone_cfg=backbone_cfg)
+
+ summary: dict[str, Any] = {
+ "benchmark": "mmvet",
+ "backbone": backbone,
+ "out_dir": str(out_dir),
+ "datasets": datasets,
+ "world_size": dist_info.world_size,
+ }
+
+ if dist_info.world_size > 1:
+ print(
+ f"[mmvet] distributed inference enabled: rank={dist_info.rank}, "
+ f"local_rank={dist_info.local_rank}, world_size={dist_info.world_size}",
+ flush=True,
+ )
- summary: dict[str, Any] = {
- "benchmark": "mmvet",
- "backbone": backbone,
- "out_dir": str(out_dir),
- "datasets": datasets,
- }
+ local_total_written = 0
+ for ds_name in datasets:
+ entry = DS_COLLECTIONS.get(ds_name)
+ if not entry and ds_name not in dataset_paths:
+ raise ValueError(f"Unknown MM-Vet dataset: {ds_name}")
+ image_root_value = dataset_paths.get("image_root")
+ question_value = dataset_paths.get("question")
+ if not image_root_value or not question_value:
+ raise ValueError(
+ "MM-Vet requires `mmvet.dataset_paths.image_root` and "
+ "`mmvet.dataset_paths.question` to be set in the YAML config."
+ )
+ image_root = _resolve_path(str(image_root_value), repo_root)
+ question_path = _resolve_path(str(question_value), repo_root)
+ if not image_root.exists():
+ raise FileNotFoundError(f"MM-Vet image root not found: {image_root}")
+ if not question_path.exists():
+ raise FileNotFoundError(f"MM-Vet question file not found: {question_path}")
+
+ checkpoint_jsonl = out_dir / f"{ds_name}_checkpoint.jsonl"
+ shard_path = rank_shard_path(checkpoint_jsonl, dist_info.rank, dist_info.world_size)
+
+ done_keys = {
+ str(it.get("output_key", "")) for it in load_shard_items(shard_path)
+ }
+ if done_keys:
+ print(
+ f"[mmvet] {ds_name}: rank {dist_info.rank} resuming after "
+ f"{len(done_keys)} shard items",
+ flush=True,
+ )
+
+ lines = [l.strip() for l in question_path.read_text("utf-8").splitlines() if l.strip()]
+ print(
+ f"[mmvet] {ds_name}: total={len(lines)}, rank={dist_info.rank}, "
+ f"done={len(done_keys)}",
+ flush=True,
+ )
- for ds_name in datasets:
- entry = DS_COLLECTIONS.get(ds_name)
- if not entry and ds_name not in dataset_paths:
- raise ValueError(f"Unknown MM-Vet dataset: {ds_name}")
- image_root_value = dataset_paths.get("image_root")
- question_value = dataset_paths.get("question")
- if not image_root_value or not question_value:
- raise ValueError(
- "MM-Vet requires `mmvet.dataset_paths.image_root` and "
- "`mmvet.dataset_paths.question` to be set in the YAML config."
+ def iter_rows() -> Iterator[dict[str, Any]]:
+ for line in lines:
+ yield json.loads(line)
+
+ def payload_fn(row: dict[str, Any]) -> dict[str, Any]:
+ image_name = row["image"]
+ question = row["text"]
+ question_id = row["question_id"]
+ image_path = image_root / image_name
+ if not image_path.exists():
+ raise FileNotFoundError(f"MM-Vet image not found: {image_path}")
+ try:
+ with Image.open(image_path) as img:
+ img.verify()
+ except Exception as exc:
+ raise RuntimeError(f"Failed to open image {image_path}: {exc}") from exc
+ return {
+ "backbone": backbone,
+ "task": "understanding",
+ "prompt": question,
+ "images": [str(image_path)],
+ "params": request_params,
+ "metadata": {"question_id": question_id, "dataset": ds_name},
+ }
+
+ def record_fn(row: dict[str, Any], raw: Any, _idx: int) -> dict[str, Any]:
+ response = _extract_text(raw)
+ output_key = _normalize_output_key(row["question_id"])
+ return {
+ "output_key": output_key,
+ "response": response,
+ }
+
+ def sample_id_fn(row: dict[str, Any]) -> str:
+ return _normalize_output_key(row["question_id"])
+
+ n_written = run_sharded_inference(
+ infer_fn=pipeline.run,
+ dist_info=dist_info,
+ shard_path=shard_path,
+ samples=iter_rows(),
+ total=len(lines),
+ payload_fn=payload_fn,
+ record_fn=record_fn,
+ sample_id_fn=sample_id_fn,
+ done_ids=done_keys,
+ max_samples=max_samples,
+ log_prefix=f"mmvet/{ds_name}/rank{dist_info.rank}",
)
- image_root = _resolve_path(str(image_root_value), repo_root)
- question_path = _resolve_path(str(question_value), repo_root)
- if not image_root.exists():
- raise FileNotFoundError(f"MM-Vet image root not found: {image_root}")
- if not question_path.exists():
- raise FileNotFoundError(f"MM-Vet question file not found: {question_path}")
-
- # Resume: load checkpoint if exists
- checkpoint_json = out_dir / f"{ds_name}_checkpoint.json"
- outputs: dict[str, str] = {}
- if checkpoint_json.exists():
- outputs = json.loads(checkpoint_json.read_text("utf-8"))
- print(f"[mmvet] resume: {len(outputs)} done, skipping completed items", flush=True)
-
- lines = [l.strip() for l in question_path.read_text("utf-8").splitlines() if l.strip()]
- print(f"[mmvet] {ds_name}: {len(lines)} total, {len(outputs)} done", flush=True)
-
- for idx, line in enumerate(tqdm(lines, desc=f"mmvet/{ds_name}", file=sys.stdout), start=1):
- row = json.loads(line)
- image_name = row["image"]
- question = row["text"]
- question_id = row["question_id"]
- # question_id already has "v1_" prefix in the dataset
- output_key = str(question_id) if str(question_id).startswith("v1_") else f"v1_{question_id}"
-
- if output_key in outputs:
- continue
-
- image_path = image_root / image_name
- if not image_path.exists():
- raise FileNotFoundError(f"MM-Vet image not found: {image_path}")
-
- try:
- with Image.open(image_path) as img:
- img.verify()
- except Exception as exc:
- raise RuntimeError(f"Failed to open image {image_path}: {exc}") from exc
-
- payload = {
- "backbone": backbone,
- "task": "understanding",
- "prompt": question,
- "images": [str(image_path)],
- "params": request_params,
- "metadata": {"question_id": question_id, "dataset": ds_name},
- }
- response = _extract_text(pipeline.run(payload))
- outputs[output_key] = response
+ local_total_written += n_written
- # Write checkpoint after each item
- checkpoint_json.write_text(json.dumps(outputs, indent=2), encoding="utf-8")
+ barrier(dist_info)
- if max_samples > 0 and idx >= max_samples:
- break
+ time_prefix = time.strftime("%y%m%d%H%M%S", time.localtime())
+ results_file = out_dir / f"{ds_name}_{time_prefix}.json"
- time_prefix = time.strftime("%y%m%d%H%M%S", time.localtime())
- results_file = out_dir / f"{ds_name}_{time_prefix}.json"
- results_file.write_text(json.dumps(outputs, indent=2), encoding="utf-8")
+ if dist_info.rank == 0:
+ merged = merge_shards(checkpoint_jsonl)
+ outputs: dict[str, str] = {item["output_key"]: item["response"] for item in merged}
+ results_file.write_text(json.dumps(outputs, indent=2), encoding="utf-8")
- # Clean up checkpoint after successful completion
- if checkpoint_json.exists():
- checkpoint_json.unlink()
+ cleanup_shards(checkpoint_jsonl)
+ if dist_info.world_size <= 1 and checkpoint_jsonl.exists():
+ checkpoint_jsonl.unlink()
- summary[f"{ds_name}_output_path"] = str(results_file)
+ summary[f"{ds_name}_output_path"] = str(results_file)
- if isinstance(score_output_path, str) and score_output_path:
- score_path = _resolve_path(score_output_path, repo_root)
- score_path.parent.mkdir(parents=True, exist_ok=True)
- score_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")
- print(f"[umm eval] wrote MM-Vet summary to {score_path}")
+ barrier(dist_info)
- print(f"[umm eval] completed MM-Vet for backbone={backbone}, outputs={out_dir}")
- return 0
+ total_written_all = sum_across_ranks(local_total_written, dist_info)
+ if dist_info.rank != 0:
+ print(
+ f"[umm eval] rank {dist_info.rank} finished MM-Vet shard: "
+ f"samples_written={local_total_written}",
+ flush=True,
+ )
+ return 0
+
+ summary["samples_written"] = total_written_all
+ if isinstance(score_output_path, str) and score_output_path:
+ score_path = _resolve_path(score_output_path, repo_root)
+ score_path.parent.mkdir(parents=True, exist_ok=True)
+ score_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")
+ print(f"[umm eval] wrote MM-Vet summary to {score_path}")
+
+ print(
+ f"[umm eval] completed MM-Vet for backbone={backbone}, outputs={out_dir}, "
+ f"samples_written={total_written_all}, world_size={dist_info.world_size}"
+ )
+ return 0
+ finally:
+ cleanup_distributed(dist_info)
diff --git a/src/umm/eval/distributed.py b/src/umm/eval/distributed.py
new file mode 100644
index 0000000..89c9a06
--- /dev/null
+++ b/src/umm/eval/distributed.py
@@ -0,0 +1,164 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Distributed-evaluation plumbing shared by understanding-class eval CLIs.
+
+These helpers are intentionally narrow: they cover process-group init/teardown,
+rank-shard path naming, and JSONL shard merge/cleanup. Per-benchmark output
+formatting and scoring stay in each CLI.
+
+Single-card mode (``WORLD_SIZE <= 1``) is a noop everywhere — no process group
+is initialized and ``rank_shard_path`` returns the base path unchanged, so the
+caller's on-disk filenames are unaffected.
+
+`torch` is imported lazily inside functions so that callers that never enter
+distributed mode don't pay the import cost.
+"""
+from __future__ import annotations
+
+import json
+import os
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any
+
+
+@dataclass(frozen=True)
+class DistInfo:
+ """Frozen record of one process's rank, world size and local rank."""
+
+ # ``get_dist_info`` fills these from the RANK / WORLD_SIZE / LOCAL_RANK
+ # environment variables (defaults 0 / 1 / 0).
+ rank: int
+ world_size: int
+ local_rank: int
+ enabled: bool # True iff this process actually init'd a process group
+
+
+def get_dist_info() -> DistInfo:
+    """Build a ``DistInfo`` from the environment without initializing anything.
+
+    ``RANK`` and ``LOCAL_RANK`` default to ``0`` and ``WORLD_SIZE`` to ``1`` when
+    unset; ``enabled`` is always ``False`` in the returned value.
+    """
+    env = os.environ
+    rank = int(env.get("RANK", "0"))
+    world_size = int(env.get("WORLD_SIZE", "1"))
+    local_rank = int(env.get("LOCAL_RANK", "0"))
+    return DistInfo(rank=rank, world_size=world_size, local_rank=local_rank, enabled=False)
+
+
+def maybe_init_distributed() -> DistInfo:
+    """Set up ``torch.distributed`` for multi-process runs.
+
+    With ``WORLD_SIZE <= 1`` the environment-derived ``DistInfo`` is returned
+    as-is (``enabled=False``) and ``torch`` is never imported. Otherwise the
+    CUDA device is pinned to ``LOCAL_RANK`` when CUDA is available, a process
+    group is created via ``env://`` unless one already exists (``nccl`` with
+    CUDA, ``gloo`` without), and a copy with ``enabled=True`` is returned.
+    """
+    env_info = get_dist_info()
+    if env_info.world_size <= 1:
+        return env_info
+
+    # Deferred imports: single-process callers never reach this point.
+    import torch
+    import torch.distributed as dist
+
+    if torch.cuda.is_available():
+        torch.cuda.set_device(env_info.local_rank)
+    if not dist.is_initialized():
+        if torch.cuda.is_available():
+            backend = "nccl"
+        else:
+            backend = "gloo"
+        dist.init_process_group(backend=backend, init_method="env://")
+    return DistInfo(
+        enabled=True,
+        rank=env_info.rank,
+        world_size=env_info.world_size,
+        local_rank=env_info.local_rank,
+    )
+
+
+def cleanup_distributed(info: DistInfo) -> None:
+    """Destroy the process group.
+
+    No-op unless ``info.enabled`` is set and a process group is currently
+    initialized.
+    """
+    if info.enabled:
+        import torch.distributed as dist
+
+        if dist.is_initialized():
+            dist.destroy_process_group()
+
+
+def barrier(info: DistInfo) -> None:
+    """Call ``dist.barrier()``.
+
+    No-op unless ``info.enabled`` is set and a process group is currently
+    initialized.
+    """
+    if info.enabled:
+        import torch.distributed as dist
+
+        if dist.is_initialized():
+            dist.barrier()
+
+
+def sum_across_ranks(value: int, info: DistInfo) -> int:
+    """Return the sum of ``value`` over all ranks.
+
+    When ``info.enabled`` is false or no process group is initialized, ``value``
+    is returned unchanged. Otherwise it is all-reduced (SUM) through a
+    one-element ``long`` tensor placed on CUDA when available, else on CPU.
+    """
+    if info.enabled:
+        import torch
+        import torch.distributed as dist
+
+        if dist.is_initialized():
+            target = "cuda" if torch.cuda.is_available() else "cpu"
+            total = torch.tensor([value], dtype=torch.long, device=target)
+            dist.all_reduce(total, op=dist.ReduceOp.SUM)
+            return int(total.item())
+    return value
+
+
+def rank_shard_path(base: Path, rank: int, world_size: int) -> Path:
+    """Map ``base`` to the shard file owned by ``rank``.
+
+    For ``world_size > 1`` the result is ``{stem}.rank{rank}{suffix}`` in the
+    same directory as ``base``. For ``world_size <= 1`` ``base`` itself is
+    returned, so single-process runs keep their existing filename.
+    """
+    if world_size > 1:
+        shard_name = f"{base.stem}.rank{rank}{base.suffix}"
+        return base.parent / shard_name
+    return base
+
+
+def _shard_glob_pattern(base: Path) -> str:
+    """Return the glob pattern ``{stem}.rank*{suffix}`` for ``base``'s rank shards."""
+    stem, suffix = base.stem, base.suffix
+    return stem + ".rank*" + suffix
+
+
+def load_shard_items(shard: Path) -> list[dict[str, Any]]:
+    """Parse a JSONL shard into a list of objects, in file order.
+
+    Returns an empty list when ``shard`` does not exist. Lines that are empty
+    after stripping whitespace are skipped; every other line is decoded with
+    ``json.loads``, whose errors propagate to the caller.
+    """
+    if not shard.exists():
+        return []
+    with shard.open("r", encoding="utf-8") as reader:
+        stripped = (raw.strip() for raw in reader)
+        return [json.loads(text) for text in stripped if text]
+
+
+def merge_shards(base: Path) -> list[dict[str, Any]]:
+    """Merge all rank shards for ``base`` into one list ordered by ``_sample_idx``.
+
+    Reads ``base`` itself (when it exists) followed by every file matching
+    ``{stem}.rank*{suffix}`` in sorted path order, so shards left behind by a
+    run with a different world_size are still picked up.
+
+    Because such leftovers can overlap with the current run's shards, items
+    are de-duplicated on ``_sample_idx`` (injected by the runner): the first
+    occurrence in read order wins. Items without ``_sample_idx`` are never
+    de-duplicated and sort as index ``0``. The sort is stable, so ties keep
+    their read order.
+
+    Returns:
+        The merged items, without the sort key.
+    """
+    candidates: list[Path] = []
+    if base.exists():
+        candidates.append(base)
+    candidates.extend(sorted(base.parent.glob(_shard_glob_pattern(base))))
+
+    merged: list[tuple[int, dict[str, Any]]] = []
+    seen_indices: set[int] = set()
+    for path in candidates:
+        for item in load_shard_items(path):
+            if "_sample_idx" in item:
+                key = int(item["_sample_idx"])
+                # Dedup on the index itself. The previous (key, id(item)) pair
+                # was unique per parsed dict, so duplicates were never dropped.
+                if key in seen_indices:
+                    continue
+                seen_indices.add(key)
+            else:
+                # No index to compare on: keep the row rather than collapsing
+                # every un-indexed item into one.
+                key = 0
+            merged.append((key, item))
+    merged.sort(key=lambda pair: pair[0])
+    return [item for _, item in merged]
+
+
+def cleanup_shards(base: Path) -> None:
+    """Remove every rank shard of ``base`` (files named ``{stem}.rank*{suffix}``).
+
+    ``base`` itself is left alone; removing or replacing it is the caller's
+    job. A shard that has already disappeared is ignored.
+    """
+    pattern = _shard_glob_pattern(base)
+    for shard in base.parent.glob(pattern):
+        shard.unlink(missing_ok=True)
diff --git a/src/umm/eval/runner.py b/src/umm/eval/runner.py
new file mode 100644
index 0000000..4a3888b
--- /dev/null
+++ b/src/umm/eval/runner.py
@@ -0,0 +1,99 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Sharded inference loop shared by understanding-class eval CLIs.
+
+The runner is intentionally minimal: it iterates a caller-supplied ``samples``
+iterable, sends one payload per assigned sample through ``infer_fn``, and
+appends a JSONL line (with an injected ``_sample_idx``) to a per-rank shard.
+
+Per-benchmark output formatting (Excel, scoring, calculation scripts) stays in
+the calling CLI behind ``if rank == 0:`` — the runner does not own those.
+
+Single-card mode is supported by passing a ``DistInfo`` with ``world_size <= 1``;
+all rank/skip checks become noops.
+"""
+from __future__ import annotations
+
+import json
+import os
+import sys
+from pathlib import Path
+from typing import Any, Callable, Hashable, Iterable, Optional, TypeVar
+
+from tqdm import tqdm
+
+from umm.eval.distributed import DistInfo
+
+
+T = TypeVar("T")
+
+
+def run_sharded_inference(
+    *,
+    infer_fn: Callable[[dict[str, Any]], Any],
+    dist_info: DistInfo,
+    shard_path: Path,
+    samples: Iterable[T],
+    payload_fn: Callable[[T], dict[str, Any]],
+    record_fn: Callable[[T, Any, int], dict[str, Any]],
+    total: Optional[int] = None,
+    sample_id_fn: Optional[Callable[[T], Hashable]] = None,
+    done_ids: Optional[set[Hashable]] = None,
+    max_samples: int = 0,
+    log_prefix: str = "eval",
+) -> int:
+    """Run inference over this rank's share of ``samples``, appending to ``shard_path``.
+
+    Samples are numbered from 1 in iteration order. For each one, in order:
+
+    1. If ``max_samples > 0`` and the index exceeds it, iteration stops. The
+       cap applies to the global index, not to this rank's count.
+    2. If ``world_size > 1`` and ``(index - 1) % world_size != rank``, the
+       sample belongs to another rank and is skipped.
+    3. If ``sample_id_fn`` is given and its key is already in ``done_ids``, the
+       sample is skipped (resume).
+    4. Otherwise ``record_fn(sample, infer_fn(payload_fn(sample)), index)`` is
+       copied into a dict, tagged with ``"_sample_idx"``, written as one JSON
+       line, flushed and fsynced.
+
+    Args:
+        infer_fn: Called with the payload built by ``payload_fn``.
+        dist_info: Supplies ``rank`` and ``world_size``.
+        shard_path: JSONL file opened in append mode; parent dirs are created.
+        samples: Iterable to walk; every rank walks the full iterable.
+        payload_fn: Builds the inference payload for a sample.
+        record_fn: Builds the output record from (sample, raw output, index).
+        total: Passed to ``tqdm`` as the progress-bar length.
+        sample_id_fn: Optional key function used for resume bookkeeping.
+        done_ids: Keys to skip. When given, this set is updated in place with
+            each non-``None`` key written.
+        max_samples: Global index cap; ``0`` or less disables it.
+        log_prefix: Progress-bar description (the bar is shown on rank 0 only).
+
+    Returns:
+        The number of records this call wrote to ``shard_path``.
+    """
+    completed = done_ids if done_ids is not None else set()
+    my_rank, n_ranks = dist_info.rank, dist_info.world_size
+    is_sharded = n_ranks > 1
+    is_capped = max_samples > 0
+
+    shard_path.parent.mkdir(parents=True, exist_ok=True)
+    written = 0
+    with shard_path.open("a", encoding="utf-8") as out:
+        progress = tqdm(
+            samples,
+            total=total,
+            desc=log_prefix,
+            file=sys.stdout,
+            disable=(my_rank != 0),
+        )
+        for sample_idx, sample in enumerate(progress, start=1):
+            if is_capped and sample_idx > max_samples:
+                break
+            if is_sharded and (sample_idx - 1) % n_ranks != my_rank:
+                continue
+
+            key = None
+            if sample_id_fn is not None:
+                key = sample_id_fn(sample)
+                if key in completed:
+                    continue
+
+            raw = infer_fn(payload_fn(sample))
+            record = dict(record_fn(sample, raw, sample_idx))
+            record["_sample_idx"] = sample_idx
+            out.write(json.dumps(record, ensure_ascii=False) + "\n")
+            # Push each record to disk right away so a crash loses at most the
+            # sample in flight.
+            out.flush()
+            os.fsync(out.fileno())
+            written += 1
+            if key is not None:
+                completed.add(key)
+
+    return written