diff --git a/src/umm/cli/mathvista_eval.py b/src/umm/cli/mathvista_eval.py index e72f35a..d1ef1f3 100644 --- a/src/umm/cli/mathvista_eval.py +++ b/src/umm/cli/mathvista_eval.py @@ -12,6 +12,17 @@ from tqdm import tqdm from umm.core.config import load_config +from umm.eval.distributed import ( + barrier, + cleanup_distributed, + cleanup_shards, + load_shard_items, + maybe_init_distributed, + merge_shards, + rank_shard_path, + sum_across_ranks, +) +from umm.eval.runner import run_sharded_inference DS_COLLECTIONS = { @@ -44,7 +55,6 @@ def _quick_extract(response: str, problem: dict) -> str | None: - """Try rule-based extraction before falling back to LLM.""" if not response: return "" question_type = problem.get("question_type", "") @@ -64,18 +74,16 @@ def _quick_extract(response: str, problem: dict) -> str | None: except (ValueError, TypeError): pass - # Try regex for "Final answer: ..." or "Answer: ..." match = re.search(r"(?:Final answer:|Answer:)\s*(.*)", response, re.IGNORECASE) if match: ans = match.group(1).strip() if ans: return ans - return None # need LLM extraction + return None def _build_extract_prompt(query: str, response: str) -> str: - """Build the full prompt for LLM-based answer extraction.""" test_prompt = f"{query}\n\n{response}" return f"{_EXTRACT_DEMO_PROMPT.strip()}\n\n{test_prompt}\n\nExtracted answer: " @@ -87,13 +95,11 @@ def _run_llm_extraction( use_quick_extract: bool = True, ) -> dict[str, Any]: """Extract answers from model responses using a local LLM (e.g. 
Qwen3-32B).""" - # First pass: check which items already have extraction or can be rule-extracted already_done = 0 need_llm = [] for pid, problem in results.items(): if "choices" not in problem: continue - # Skip if extraction was already done in a previous run if "extraction" in problem and problem["extraction"]: already_done += 1 continue @@ -117,7 +123,6 @@ def _run_llm_extraction( flush=True, ) - # Only load the heavy LLM if there are items that actually need it if not need_llm: print("[mathvista] all items already extracted, skipping LLM load", flush=True) return results @@ -156,11 +161,9 @@ def _run_llm_extraction( ) generated = outputs[0][inputs["input_ids"].shape[1]:] extraction = tokenizer.decode(generated, skip_special_tokens=True).strip() - # Handle Qwen3 thinking format: strip ... block think_match = re.search(r"\s*(.*)", extraction, re.DOTALL) if think_match: extraction = think_match.group(1).strip() - # Take only the first line as the extracted answer extraction = extraction.split("\n")[0].strip() problem["extraction"] = extraction @@ -209,7 +212,6 @@ def _extract_text(output: Any) -> str: text = _extract_text(item) if text: return text - # Handle adapters that return {"understandings": [{"response": "..."}]} for list_key in ("understandings",): container = output.get(list_key) if isinstance(container, list): @@ -236,7 +238,6 @@ def _load_eval_cfg(config_path: str) -> tuple[dict[str, Any], dict[str, Any], di def _find_latest_results(out_dir: Path, ds_name: str) -> Path | None: - """Find the most recent results JSON for a dataset in out_dir.""" candidates = sorted(out_dir.glob(f"{ds_name}_*.json")) candidates = [c for c in candidates if "_checkpoint" not in c.name and "_score" not in c.name] if candidates: @@ -288,7 +289,6 @@ def run_mathvista_eval_command(args: Any) -> int: gt_file = mathvista_cfg.get("gt_file") resume = bool(mathvista_cfg.get("resume", False)) - # Mode support (like wise): generate / score / full mode = str(mathvista_cfg.get("mode", 
"full")).strip().lower() if mode not in ("full", "generate", "score"): print(f"[mathvista] unknown mode '{mode}', defaulting to 'full'", flush=True) @@ -297,14 +297,12 @@ def run_mathvista_eval_command(args: Any) -> int: run_gen = mode in ("full", "generate") run_score = mode in ("full", "score") - # LLM extraction config (replaces OpenAI) llm_extract_cfg = mathvista_cfg.get("llm_extract", {}) if not isinstance(llm_extract_cfg, dict): llm_extract_cfg = {} llm_model_path = str(llm_extract_cfg.get("model_path", "")).strip() llm_max_new_tokens = int(llm_extract_cfg.get("max_new_tokens", 2048)) - # Legacy OpenAI config (fallback when llm_extract is not configured) run_extract_legacy = bool(mathvista_cfg.get("run_extract", False)) run_calculation_legacy = bool(mathvista_cfg.get("run_calculation", False)) openai_api_key = mathvista_cfg.get("openai_api_key") @@ -320,195 +318,249 @@ def run_mathvista_eval_command(args: Any) -> int: "mode": mode, } - # ── Phase 1: Generation ── - if run_gen: - from datasets import load_dataset - from PIL import Image + dist_info = maybe_init_distributed() + try: + summary["world_size"] = dist_info.world_size + if dist_info.world_size > 1 and run_gen: + print( + f"[mathvista] distributed inference enabled: rank={dist_info.rank}, " + f"local_rank={dist_info.local_rank}, world_size={dist_info.world_size}", + flush=True, + ) + + local_total_written = 0 - from umm.inference import InferencePipeline + # ── Phase 1: Generation ── + if run_gen: + from datasets import load_dataset + from PIL import Image - pipeline = InferencePipeline(backbone_name=backbone, backbone_cfg=backbone_cfg) + from umm.inference import InferencePipeline - for ds_name in datasets: - entry = DS_COLLECTIONS.get(ds_name) - if not entry: - raise ValueError(f"Unknown MathVista dataset: {ds_name}") + pipeline = InferencePipeline(backbone_name=backbone, backbone_cfg=backbone_cfg) + + for ds_name in datasets: + entry = DS_COLLECTIONS.get(ds_name) + if not entry: + raise 
ValueError(f"Unknown MathVista dataset: {ds_name}") + + dataset_root = str(mathvista_cfg.get("root", entry["root"])) + split = str(mathvista_cfg.get("split", entry["split"])) + dataset = load_dataset( + dataset_root, + cache_dir=str(_resolve_path(cache_dir, repo_root)) if cache_dir else None, + ) + data = dataset[split] + + checkpoint_jsonl = out_dir / f"{ds_name}_checkpoint.jsonl" + shard_path = rank_shard_path(checkpoint_jsonl, dist_info.rank, dist_info.world_size) + + done_pids: set[str] = set() + results_file: Path | None = None + + if resume: + shard_items = load_shard_items(shard_path) + if shard_items: + done_pids = {str(it.get("pid", "")) for it in shard_items} + print( + f"[mathvista] {ds_name}: rank {dist_info.rank} resume " + f"from shard: {len(done_pids)} done", + flush=True, + ) + elif dist_info.world_size <= 1: + # Single-card fallback: a previous completed run can be loaded + # whole. Multi-card runs rely on the rank shard only. + prior = _find_latest_results(out_dir, ds_name) + if prior: + print( + f"[mathvista] resume: using completed file {prior}", + flush=True, + ) + results_file = prior + + if results_file is None: + print( + f"[mathvista] {ds_name}: total={len(data)}, rank={dist_info.rank}, " + f"done={len(done_pids)}", + flush=True, + ) + + def payload_fn(data_item: Any) -> dict[str, Any]: + pid = data_item.get("pid") + if pid is None: + raise ValueError("MathVista sample missing `pid`.") + image = data_item.get("decoded_image") + if image is None: + raise ValueError("MathVista sample missing `decoded_image`.") + if not isinstance(image, Image.Image): + raise ValueError("Expected `decoded_image` to be a PIL image.") + image_dir.mkdir(parents=True, exist_ok=True) + image_path = image_dir / f"{pid}.png" + image.save(image_path, format="PNG") + question = data_item.get("query") + if question is None: + raise ValueError("MathVista sample missing `query`.") + prompt = COT_INSTRUCTION.format(question=question) if use_cot else question + return { + 
"backbone": backbone, + "task": "understanding", + "prompt": prompt, + "images": [str(image_path)], + "params": request_params, + "metadata": {"pid": pid, "dataset": ds_name}, + } + + def record_fn(data_item: Any, raw: Any, _idx: int) -> dict[str, Any]: + response = _extract_text(raw) + item = dict(data_item) + item.pop("decoded_image", None) + item["response"] = response + # Ensure pid is JSON-friendly so resume works. + item["pid"] = str(item.get("pid", "")) + return item + + n_written = run_sharded_inference( + infer_fn=pipeline.run, + dist_info=dist_info, + shard_path=shard_path, + samples=data, + total=len(data), + payload_fn=payload_fn, + record_fn=record_fn, + sample_id_fn=lambda d: str(d.get("pid", "")), + done_ids=done_pids, + max_samples=max_samples, + log_prefix=f"mathvista/{ds_name}/rank{dist_info.rank}", + ) + local_total_written += n_written + + barrier(dist_info) + + time_prefix = time.strftime("%y%m%d%H%M%S", time.localtime()) + results_file = out_dir / f"{ds_name}_{time_prefix}.json" + + if dist_info.rank == 0: + merged = merge_shards(checkpoint_jsonl) + results: dict[str, Any] = {item["pid"]: item for item in merged} + results_file.write_text(json.dumps(results, indent=2), encoding="utf-8") + cleanup_shards(checkpoint_jsonl) + if dist_info.world_size <= 1 and checkpoint_jsonl.exists(): + checkpoint_jsonl.unlink() + + barrier(dist_info) + + summary[f"{ds_name}_output_path"] = str(results_file) + + if mode == "generate": + print(f"[mathvista] generation phase done, outputs={out_dir}", flush=True) + + # Free GPU memory from generation pipeline before scoring (rank 0 only + # touches scoring; other ranks return after the barriers above). 
+ del pipeline + import gc + gc.collect() + import torch as _torch + if _torch.cuda.is_available(): + _torch.cuda.empty_cache() + print( + f"[mathvista] rank {dist_info.rank}: released generation pipeline GPU memory", + flush=True, + ) - dataset_root = str(mathvista_cfg.get("root", entry["root"])) - split = str(mathvista_cfg.get("split", entry["split"])) - dataset = load_dataset( - dataset_root, - cache_dir=str(_resolve_path(cache_dir, repo_root)) if cache_dir else None, + # Multi-card: ranks > 0 finish here; only rank 0 proceeds with scoring + + # summary writing (scoring loads its own LLM with device_map="auto"). + local_written_all = sum_across_ranks(local_total_written, dist_info) + if dist_info.rank != 0: + print( + f"[umm eval] rank {dist_info.rank} finished MathVista shard: " + f"samples_written={local_total_written}", + flush=True, ) - data = dataset[split] + return 0 - checkpoint_json = out_dir / f"{ds_name}_checkpoint.json" - results: dict[str, Any] = {} - results_file: Path | None = None + summary["samples_written"] = local_written_all - if resume: - if checkpoint_json.exists(): - results = json.loads(checkpoint_json.read_text("utf-8")) - print(f"[mathvista] resume from checkpoint: {len(results)} done", flush=True) + # ── Phase 2: Scoring (extract + calculate) — rank 0 only ── + if run_score: + for ds_name in datasets: + results_file = None + if f"{ds_name}_output_path" in summary: + results_file = Path(summary[f"{ds_name}_output_path"]) else: results_file = _find_latest_results(out_dir, ds_name) - if results_file: - print(f"[mathvista] resume: using completed file {results_file}", flush=True) - - if results_file is None: - print( - f"[mathvista] {ds_name}: {len(data)} total, {len(results)} done, " - f"{len(data) - len(results)} remaining", - flush=True, - ) - for idx, data_item in enumerate(tqdm(data, desc=f"mathvista/{ds_name}", file=sys.stdout), start=1): - pid = data_item.get("pid") - if pid is None: - raise ValueError("MathVista sample missing 
`pid`.") - if str(pid) in results: - continue - - image = data_item.get("decoded_image") - if image is None: - raise ValueError("MathVista sample missing `decoded_image`.") - if not isinstance(image, Image.Image): - raise ValueError("Expected `decoded_image` to be a PIL image.") - - image_dir.mkdir(parents=True, exist_ok=True) - image_path = image_dir / f"{pid}.png" - image.save(image_path, format="PNG") - - question = data_item.get("query") - if question is None: - raise ValueError("MathVista sample missing `query`.") + if results_file is None or not results_file.exists(): + raise FileNotFoundError( + f"No results file found for {ds_name} in {out_dir}. " + f"Run generation phase first (mode: generate)." + ) + + print(f"[mathvista] scoring {ds_name} from {results_file}", flush=True) + results = json.loads(results_file.read_text("utf-8")) + + if llm_model_path: + results = _run_llm_extraction( + results, + model_path=llm_model_path, + max_new_tokens=llm_max_new_tokens, + use_quick_extract=use_cot, + ) + results_file.write_text(json.dumps(results, indent=2), encoding="utf-8") + print(f"[mathvista] saved extractions to {results_file}", flush=True) + elif run_extract_legacy: + cmd = [ + sys.executable, + "src/umm/eval/internvl_chat/eval/mathvista/extract_answer.py", + "--output_file", + results_file.name, + "--output_dir", + str(out_dir), + ] if use_cot: - prompt = COT_INSTRUCTION.format(question=question) - else: - prompt = question - - payload = { - "backbone": backbone, - "task": "understanding", - "prompt": prompt, - "images": [str(image_path)], - "params": request_params, - "metadata": {"pid": pid, "dataset": ds_name}, - } - response = _extract_text(pipeline.run(payload)) - - item = dict(data_item) - item.pop("decoded_image", None) - item["response"] = response - results[str(pid)] = item - - checkpoint_json.write_text(json.dumps(results, indent=2), encoding="utf-8") - - if max_samples > 0 and len(results) >= max_samples: - break - - time_prefix = 
time.strftime("%y%m%d%H%M%S", time.localtime()) - results_file = out_dir / f"{ds_name}_{time_prefix}.json" - results_file.write_text(json.dumps(results, indent=2), encoding="utf-8") - - if checkpoint_json.exists(): - checkpoint_json.unlink() - - summary[f"{ds_name}_output_path"] = str(results_file) - - if mode == "generate": - print(f"[mathvista] generation phase done, outputs={out_dir}", flush=True) - - # Free GPU memory from generation pipeline before scoring - del pipeline - import gc - gc.collect() - import torch as _torch - if _torch.cuda.is_available(): - _torch.cuda.empty_cache() - print("[mathvista] released generation pipeline GPU memory", flush=True) - - # ── Phase 2: Scoring (extract + calculate) ── - if run_score: - for ds_name in datasets: - # Find results file from generation phase (or previous run) - results_file = None - if f"{ds_name}_output_path" in summary: - results_file = Path(summary[f"{ds_name}_output_path"]) - else: - results_file = _find_latest_results(out_dir, ds_name) - if results_file is None or not results_file.exists(): - raise FileNotFoundError( - f"No results file found for {ds_name} in {out_dir}. " - f"Run generation phase first (mode: generate)." 
- ) - - print(f"[mathvista] scoring {ds_name} from {results_file}", flush=True) - results = json.loads(results_file.read_text("utf-8")) - - # ── Extract answers ── - if llm_model_path: - # Use local Qwen model for extraction - results = _run_llm_extraction( - results, - model_path=llm_model_path, - max_new_tokens=llm_max_new_tokens, - use_quick_extract=use_cot, - ) - # Save results with extraction field - results_file.write_text(json.dumps(results, indent=2), encoding="utf-8") - print(f"[mathvista] saved extractions to {results_file}", flush=True) - elif run_extract_legacy: - # Fallback: use legacy OpenAI-based extract_answer.py + cmd.append("--quick_extract") + env = None + if isinstance(openai_api_key, str) and openai_api_key.strip(): + env = dict(os.environ) + env["OPENAI_API_KEY"] = openai_api_key.strip() + proc = subprocess.run(cmd, cwd=str(repo_root), capture_output=True, text=True, env=env) + print(proc.stdout) + if proc.returncode != 0: + if proc.stderr: + print(proc.stderr, file=sys.stderr) + raise RuntimeError(f"MathVista extract_answer failed with return code {proc.returncode}") + summary[f"{ds_name}_extract_stdout"] = proc.stdout + + score_file = results_file.with_name(f"{results_file.stem}_score.json") cmd = [ sys.executable, - "src/umm/eval/internvl_chat/eval/mathvista/extract_answer.py", + "src/umm/eval/internvl_chat/eval/mathvista/calculate_score.py", "--output_file", results_file.name, "--output_dir", str(out_dir), + "--score_file", + score_file.name, ] - if use_cot: - cmd.append("--quick_extract") - env = None - if isinstance(openai_api_key, str) and openai_api_key.strip(): - env = dict(os.environ) - env["OPENAI_API_KEY"] = openai_api_key.strip() - proc = subprocess.run(cmd, cwd=str(repo_root), capture_output=True, text=True, env=env) + if isinstance(gt_file, str) and gt_file.strip(): + cmd.extend(["--gt_file", gt_file.strip()]) + proc = subprocess.run(cmd, cwd=str(repo_root), capture_output=True, text=True) print(proc.stdout) if proc.returncode 
!= 0: if proc.stderr: print(proc.stderr, file=sys.stderr) - raise RuntimeError(f"MathVista extract_answer failed with return code {proc.returncode}") - summary[f"{ds_name}_extract_stdout"] = proc.stdout - - # ── Calculate scores ── - score_file = results_file.with_name(f"{results_file.stem}_score.json") - cmd = [ - sys.executable, - "src/umm/eval/internvl_chat/eval/mathvista/calculate_score.py", - "--output_file", - results_file.name, - "--output_dir", - str(out_dir), - "--score_file", - score_file.name, - ] - if isinstance(gt_file, str) and gt_file.strip(): - cmd.extend(["--gt_file", gt_file.strip()]) - proc = subprocess.run(cmd, cwd=str(repo_root), capture_output=True, text=True) - print(proc.stdout) - if proc.returncode != 0: - if proc.stderr: - print(proc.stderr, file=sys.stderr) - raise RuntimeError(f"MathVista calculate_score failed with return code {proc.returncode}") - summary[f"{ds_name}_score_file"] = str(score_file) - summary[f"{ds_name}_score_stdout"] = proc.stdout - - if isinstance(score_output_path, str) and score_output_path: - score_path = _resolve_path(score_output_path, repo_root) - score_path.parent.mkdir(parents=True, exist_ok=True) - score_path.write_text(json.dumps(summary, indent=2), encoding="utf-8") - print(f"[umm eval] wrote MathVista summary to {score_path}") - - print(f"[umm eval] completed MathVista (mode={mode}) for backbone={backbone}, outputs={out_dir}") - return 0 + raise RuntimeError(f"MathVista calculate_score failed with return code {proc.returncode}") + summary[f"{ds_name}_score_file"] = str(score_file) + summary[f"{ds_name}_score_stdout"] = proc.stdout + + if isinstance(score_output_path, str) and score_output_path: + score_path = _resolve_path(score_output_path, repo_root) + score_path.parent.mkdir(parents=True, exist_ok=True) + score_path.write_text(json.dumps(summary, indent=2), encoding="utf-8") + print(f"[umm eval] wrote MathVista summary to {score_path}") + + print( + f"[umm eval] completed MathVista (mode={mode}) for 
backbone={backbone}, " + f"outputs={out_dir}, world_size={dist_info.world_size}" + ) + return 0 + finally: + cleanup_distributed(dist_info) diff --git a/src/umm/cli/mmbench_eval.py b/src/umm/cli/mmbench_eval.py index 7b14fa1..97d5f64 100644 --- a/src/umm/cli/mmbench_eval.py +++ b/src/umm/cli/mmbench_eval.py @@ -1,12 +1,8 @@ from __future__ import annotations import base64 -import copy import json import os -import random -import string -import sys import time from io import BytesIO from pathlib import Path @@ -14,9 +10,20 @@ import pandas as pd from PIL import Image -from tqdm import tqdm from umm.core.config import load_config +from umm.eval.distributed import ( + barrier, + cleanup_distributed, + cleanup_shards, + load_shard_items, + maybe_init_distributed, + merge_shards, + rank_shard_path, + sum_across_ranks, +) +from umm.eval.runner import run_sharded_inference +from umm.inference import InferencePipeline DS_COLLECTIONS = { @@ -50,387 +57,9 @@ "type": "dev", "language": "cn", }, - # V11 test splits — released with GT labels after the online eval service closed - # on 2026-03-31 (see open-compass/MMBench#61). Same TSV schema as dev (four - # rotations per question, index = base + k * 1_000_000 for k = 0..3). - "MMBench_TEST_EN_V11": { - "root": "/datasets/mmbench/MMBench_TEST_EN_V11.tsv", - "type": "test", - "language": "en", - }, - "MMBench_TEST_CN_V11": { - "root": "/datasets/mmbench/MMBench_TEST_CN_V11.tsv", - "type": "test", - "language": "cn", - }, } -# Port of VLMEvalKit MMB_abbrs (vlmeval/dataset/utils/multiple_choice.py:17-24) -MMB_ABBRS = { - "coarse_perception": "CP", - "finegrained_perception (instance-level)": "FP-S", - "finegrained_perception (cross-instance)": "FP-C", - "logic_reasoning": "LR", - "relation_reasoning": "RR", - "attribute_reasoning": "AR", -} - - -# VLMEvalKit uses the same English scaffold for both EN and CN splits — only the -# question / option / hint fields themselves are localised inside the TSV. 
-# (vlmeval/dataset/image_mcq.py:210-245) -_ANSWER_INSTRUCTION = "Please select the correct answer from the options above. \n" - - -# --------------------------------------------------------------------------- -# Generation-time prompt (ports vlmeval/dataset/image_mcq.py:210-245) -# --------------------------------------------------------------------------- - - -def _build_prompt(question: str, options: "dict[str, str]", hint: str | None) -> str: - options_prompt = "Options:\n" - for key, item in options.items(): - options_prompt += f"{key}. {item}\n" - prompt = "" - if hint: - prompt += f"Hint: {hint}\n" - prompt += f"Question: {question}\n" - if options: - prompt += options_prompt - prompt += _ANSWER_INSTRUCTION - return prompt - - -# --------------------------------------------------------------------------- -# Exact-matching extraction — ports vlmeval/utils/matching_util.py:12-116 -# --------------------------------------------------------------------------- - - -_REJECT_TO_ANSWER = [ - "Sorry, I can't help with images of people yet.", - "I can't process this file.", - "I'm sorry, but without the image provided", - "Cannot determine the answer", -] -_EXACT_CHARS_TO_SPACE = ".()[],:;!*#{}" - - -def _can_infer_option(answer: str, choices: "dict[str, str]") -> "str | bool": - if "Failed to obtain answer via API" in answer: - return False - for err in _REJECT_TO_ANSWER: - if err in answer: - return "Z" - - answer_mod = copy.copy(answer) - for c in _EXACT_CHARS_TO_SPACE: - answer_mod = answer_mod.replace(c, " ") - splits = [x.strip() for x in answer_mod.split()] - - def count_choice(tokens, cands, prefix: str = "", suffix: str = "") -> int: - return sum(1 for c in cands if (prefix + c + suffix) in tokens) - - count = count_choice(splits, choices) - if count == 1: - for ch in choices: - if ch in splits and splits.index(ch) > len(splits) - 5: - return ch - elif count == 0 and count_choice(splits, {"Z", ""}) == 1: - return "Z" - return False - - -def 
_can_infer_text(answer: str, choices: "dict[str, str]") -> "str | bool": - answer_low = answer.lower() - total_len = sum(len(str(v)) for v in choices.values()) - if len(answer_low) > 2 * total_len: - return False - lowered = {k: str(v).lower() for k, v in choices.items()} - cands = [k for k, v in lowered.items() if v in answer_low] - if len(cands) == 1: - return cands[0] - return False - - -def _can_infer(answer: Any, choices: "dict[str, str]") -> "str | bool": - answer = str(answer) - copt = _can_infer_option(answer, choices) - return copt if copt else _can_infer_text(answer, choices) - - -# --------------------------------------------------------------------------- -# LLM judge — ports vlmeval/dataset/utils/multiple_choice.py build_prompt(_cn) -# and build_option_str / cn_string. -# --------------------------------------------------------------------------- - - -_JUDGE_PROMPT_EN = ( - "You are an AI assistant who will help me to match " - "an answer with several options of a single-choice question. " - "You are provided with a question, several options, and an answer, " - "and you need to find which option is most similar to the answer. " - "If the meaning of all options are significantly different from the answer, output Z. " - "Your should output a single uppercase character in A, B, C, D (if they are valid options), and Z. \n" - "Example 1: \n" - "Question: What is the main object in image?\nOptions: A. teddy bear B. rabbit C. cat D. dog\n" - "Answer: a cute teddy bear\nYour output: A\n" - "Example 2: \n" - "Question: What is the main object in image?\nOptions: A. teddy bear B. rabbit C. cat D. dog\n" - "Answer: Spider\nYour output: Z\n" - "Example 3: \n" - "Question: {}?\nOptions: {}\nAnswer: {}\nYour output: " -) - - -_JUDGE_PROMPT_CN = ( - "你是一个帮助我匹配答案与单选题中多个选项的 AI 助手。" - "你会被提供:一个问题,多个选项,一个答案。你的任务是找到与答案意义最相近的选项。" - "如果所有选项的意义都与答案显著不同,则输出 Z。" - "你应该输出一个单个的大写字母,例如 A, B, C, D(如果它们是有效选项),或 Z。" - "例 1:" - "问题: 图中最主要的物体是什么?\n选项: A. 泰迪熊 B. 兔子 C. 猫 D. 
狗\n答案: 一只可爱的泰迪熊\n输出: A\n" - "例 2: \n" - "问题: 图中最主要的物体是什么?\n选项: A. 泰迪熊 B. 兔子 C. 猫 D. 狗\n答案: 蜘蛛\n输出: Z\n" - "例 3: \n" - "问题: {}?\n选项: {}\n答案: {}\n输出: " -) - - -def _build_option_str(options: "dict[str, str]") -> str: - s = "There are several options: \n" - for c, content in options.items(): - if content is None: - continue - if isinstance(content, float) and pd.isna(content): - continue - s += f"{c}. {content}\n" - return s - - -def _cn_string(text: str) -> bool: - return any("\u4e00" <= ch <= "\u9fff" for ch in str(text)) - - -def _build_judge_prompt(question: str, options: "dict[str, str]", prediction: str, language: str) -> str: - option_str = _build_option_str(options) - if language == "cn" or _cn_string(question): - tmpl = _JUDGE_PROMPT_CN - else: - tmpl = _JUDGE_PROMPT_EN - return tmpl.format(question, option_str, prediction) - - -# --------------------------------------------------------------------------- -# Judge model bundle — lazy loaded, shared across all extractions in a run. 
-# --------------------------------------------------------------------------- - - -class _JudgeBundle: - def __init__(self, model_path: str, max_new_tokens: int = 32) -> None: - import torch - from transformers import AutoModelForCausalLM, AutoTokenizer - - print(f"[mmbench] loading judge LLM: {model_path}", flush=True) - self._torch = torch - self.max_new_tokens = max_new_tokens - self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) - self.model = AutoModelForCausalLM.from_pretrained( - model_path, - torch_dtype="auto", - device_map="auto", - trust_remote_code=True, - ) - self.model.eval() - - def generate(self, prompt: str) -> str: - messages = [{"role": "user", "content": prompt}] - text = self.tokenizer.apply_chat_template( - messages, - tokenize=False, - add_generation_prompt=True, - enable_thinking=False, - ) - inputs = self.tokenizer(text, return_tensors="pt").to(self.model.device) - with self._torch.no_grad(): - outputs = self.model.generate( - **inputs, - max_new_tokens=self.max_new_tokens, - do_sample=False, - ) - generated = outputs[0][inputs["input_ids"].shape[1]:] - raw = self.tokenizer.decode(generated, skip_special_tokens=True).strip() - # Strip Qwen3-style ... 
if present so can_infer sees the letter - if "" in raw: - raw = raw.split("", 1)[1].strip() - return raw - - def close(self) -> None: - del self.model - del self.tokenizer - if self._torch.cuda.is_available(): - self._torch.cuda.empty_cache() - - -# --------------------------------------------------------------------------- -# extract_answer_from_item — ports multiple_choice.py:359-406 -# --------------------------------------------------------------------------- - - -def _extract_answer_from_item(item: "dict[str, Any]", judge: "_JudgeBundle | None") -> "dict[str, str]": - """Return {'opt': letter, 'log': ...}, matching VLMEvalKit's extract_answer_from_item.""" - choices = item["choices"] - prediction = str(item.get("prediction", "")) - - ret = _can_infer(prediction, choices) - if ret: - return {"opt": str(ret), "log": prediction} - if judge is None: - return { - "opt": "Z", - "log": "Failed in Prefetch, no LLM-based answer matching under `exact_matching` policy.", - } - - prompt = _build_judge_prompt( - question=str(item.get("question", "")), - options=choices, - prediction=prediction, - language=item.get("language", "en"), - ) - for _ in range(3): - ans = judge.generate(prompt) - if "Failed to obtain answer via API" in ans: - continue - ret = _can_infer(ans, choices) - if ret: - return {"opt": str(ret), "log": ans} - # Random fallback (matches VLMEvalKit — the random choice ensures we still - # produce a concrete letter rather than silently skipping the item). 
- options_pool = list(choices) + (["Z"] if "Z" not in choices else []) - return {"opt": random.choice(options_pool), "log": "Failed to predict, thus randomly generate one."} - - -# --------------------------------------------------------------------------- -# Circular evaluation — ports multiple_choice.py:409-471 -# --------------------------------------------------------------------------- - - -def _prefetch_answer(item: "dict[str, Any]") -> "str | bool": - return _can_infer(str(item.get("prediction", "")), item["choices"]) - - -def _prefetch_circular_group( - sub_items: "list[dict[str, Any]]", -) -> "tuple[dict | None, list, list]": - """Returns (result_dict_or_None, gts, preds). - - If a definitive result can be decided via can_infer alone (all matched, or - any rotation's prefetched answer conflicts with its GT), returns a dict with - hit/log. Otherwise returns None (LLM fallback is needed). - """ - gts: list = [] - preds: list = [] - for i, item in enumerate(sub_items): - gts.append(item["gt_answer"]) - preds.append(_prefetch_answer(item)) - if preds[-1] and gts[-1] != preds[-1]: - return ( - { - "hit": 0, - "log": ( - f"Failed in Prefetching Rolling {i}: Answer is {gts[-1]}, " - f"Prediction is {item.get('prediction', '')}, " - f"Pre-fetched is {preds[-1]}." - ), - }, - gts, - preds, - ) - if all(g == p for g, p in zip(gts, preds)): - return {"hit": 1, "log": "Succeed During Pre-fetching"}, gts, preds - return None, gts, preds - - -def _eval_circular_group( - sub_items: "list[dict[str, Any]]", - judge: "_JudgeBundle | None", -) -> "dict[str, Any]": - """Evaluate one circular group (1-4 rotations of the same question). - - Returns {'hit': 0 or 1, 'log': ...}. hit=1 requires every rotation's - extracted letter to match its own rotated GT. 
- """ - result, gts, preds = _prefetch_circular_group(sub_items) - if result is not None: - return result - - log = "" - for i, item in enumerate(sub_items): - if preds[i]: - log += f"Rolling {i} Matched.\n" - continue - res = _extract_answer_from_item(item, judge) - opt, match_log = res["opt"], res["log"] - preds[i] = opt - if preds[i] != gts[i]: - log += ( - f"Failed in Rolling {i}: Answer is {gts[i]}; " - f"Prediction is {item.get('prediction', '')}; " - f"Pre-fetched is {preds[i]}; Match Log is {match_log}.\n" - ) - return {"hit": 0, "log": log} - log += ( - f"Rolling {i}: Answer is {gts[i]}, " - f"Prediction is {item.get('prediction', '')}, Pre-fetched is {preds[i]}.\n" - ) - return {"hit": 1, "log": log} - - -# --------------------------------------------------------------------------- -# report_acc — ports multiple_choice.py:77-100 -# --------------------------------------------------------------------------- - - -def _report_acc(df: pd.DataFrame) -> "dict[str, float]": - """Return a flat dict keyed by 'split=|' → accuracy in [0,1]. - - Empty (split, category) cells return NaN to match VLMEvalKit's - np.mean([]) behaviour rather than 0.0 (which would distort - average-of-averages aggregations downstream). 
- """ - res: "dict[str, float]" = {} - if "split" in df.columns: - splits = sorted(df["split"].dropna().unique().tolist()) - else: - df = df.copy() - df["split"] = "none" - splits = ["none"] - - for group in (None, "l2-category", "category"): - if group is None: - for sp in splits: - sub = df[df["split"] == sp]["hit"] - val = float(sub.mean()) if len(sub) else float("nan") - res[f"split={sp}|Overall"] = val - elif group not in df.columns: - continue - else: - abilities = sorted(df[group].dropna().unique().tolist()) - for ab in abilities: - ab_name = MMB_ABBRS.get(ab, ab) - sub_df = df[df[group] == ab] - for sp in splits: - cell = sub_df[sub_df["split"] == sp]["hit"] - val = float(cell.mean()) if len(cell) else float("nan") - res[f"split={sp}|{ab_name}"] = val - return res - - -# --------------------------------------------------------------------------- -# Helpers (unchanged) -# --------------------------------------------------------------------------- - - def _resolve_path(path_str: str, repo_root: Path) -> Path: path = Path(path_str).expanduser() if not path.is_absolute(): @@ -467,7 +96,6 @@ def _extract_text(output: Any) -> str: text = _extract_text(item) if text: return text - # Handle adapters that return {"understandings": [{"response": "..."}]} for list_key in ("understandings",): container = output.get(list_key) if isinstance(container, list): @@ -483,7 +111,21 @@ def _extract_text(output: Any) -> str: return "" -def _load_eval_cfg(config_path: str) -> "tuple[dict, dict, dict]": +def _post_process(pred: str, option: dict[str, str]) -> str: + pred = pred.strip() + option_candidate = list(option.keys()) + if len(pred) == 1: + return pred + if len(pred) > 1 and pred[0] in option_candidate: + return pred[0] + if len(pred) > 1 and pred[0] not in option_candidate: + for k, v in option.items(): + if v in pred: + return k + return pred + + +def _load_eval_cfg(config_path: str) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]: raw_cfg = 
load_config(config_path) eval_cfg = raw_cfg.get("eval", {}) if isinstance(raw_cfg.get("eval"), dict) else {} mmbench_cfg = raw_cfg.get("mmbench", {}) if isinstance(raw_cfg.get("mmbench"), dict) else {} @@ -493,6 +135,18 @@ def _load_eval_cfg(config_path: str) -> "tuple[dict, dict, dict]": return eval_cfg, mmbench_cfg, inference_cfg +def _build_prompt(question: str, options: dict[str, str], hint: str | None, language: str) -> str: + if hint: + question = f"{hint}\n{question}" + for key, item in options.items(): + question += f"\n{key}. {item}" + if language == "cn": + suffix = "请直接回答选项字母。" + else: + suffix = "Answer with the option's letter from the given choices directly." + return f"{question}\n{suffix}".strip() + + def _decode_image(image_b64: str, image_dir: Path, row_index: int) -> str: image_dir.mkdir(parents=True, exist_ok=True) image = Image.open(BytesIO(base64.b64decode(image_b64))).convert("RGB") @@ -502,10 +156,10 @@ def _decode_image(image_b64: str, image_dir: Path, row_index: int) -> str: def _get_dataset_paths( - datasets: "list[str]", + datasets: list[str], repo_root: Path, - override_paths: "dict[str, Any]", -) -> "dict[str, Path]": + override_paths: dict[str, Any], +) -> dict[str, Path]: resolved: dict[str, Path] = {} for name in datasets: if name in override_paths: @@ -518,46 +172,14 @@ def _get_dataset_paths( return resolved -def _find_latest_jsonl(out_dir: Path, ds_name: str) -> "Path | None": - candidates = sorted(out_dir.glob(f"{ds_name}_*.jsonl")) - candidates = [c for c in candidates if "_checkpoint" not in c.name] - if candidates: - return max(candidates, key=lambda p: p.stat().st_mtime) - return None - - -def _parse_options(row: pd.Series) -> "dict[str, str]": - """Collect A/B/C/D/E/... 
option letters whose cell is non-null, preserving TSV order.""" +def _options_for(row: pd.Series) -> dict[str, str]: options: dict[str, str] = {} - for cand in string.ascii_uppercase: + for cand in ["A", "B", "C", "D", "E"]: if cand in row and not pd.isna(row[cand]): options[cand] = row[cand] return options -def _resolve_image_blob(image_map: "dict[int, str]", idx: int) -> str: - """Resolve MMBench's short-circuit image storage (a ≤64-char string that - points to another row's index). Ports image_base.py:154-164.""" - value = image_map.get(idx, "") - if not isinstance(value, str): - return "" - if len(value) <= 64: - # It's a redirect to another row's full base64 blob - try: - target = int(value) - except (TypeError, ValueError): - return value - target_val = image_map.get(target, "") - if isinstance(target_val, str) and len(target_val) > 64: - return target_val - return value - - -# --------------------------------------------------------------------------- -# Main entry point -# --------------------------------------------------------------------------- - - def run_mmbench_eval_command(args: Any) -> int: config_path = str(args.config) eval_cfg, mmbench_cfg, inference_cfg = _load_eval_cfg(config_path) @@ -577,7 +199,7 @@ def run_mmbench_eval_command(args: Any) -> int: raise ValueError("`inference.backbone_cfg` must be a dict when provided.") request_cfg = inference_cfg.get("request", {}) - request_params: "dict[str, Any]" = {} + request_params: dict[str, Any] = {} if isinstance(request_cfg, dict): params = request_cfg.get("params", {}) if isinstance(params, dict): @@ -600,41 +222,33 @@ def run_mmbench_eval_command(args: Any) -> int: resume = bool(mmbench_cfg.get("resume", False)) resume_jsonl = mmbench_cfg.get("resume_jsonl") - mode = str(mmbench_cfg.get("mode", "generate")).strip().lower() - if mode not in ("full", "generate", "score"): - print(f"[mmbench] unknown mode '{mode}', defaulting to 'generate'", flush=True) - mode = "generate" - run_gen = mode in 
("full", "generate") - run_score = mode in ("full", "score") - - llm_extract_cfg = mmbench_cfg.get("llm_extract", {}) - if not isinstance(llm_extract_cfg, dict): - llm_extract_cfg = {} - llm_model_path = str(llm_extract_cfg.get("model_path", "")).strip() - llm_max_new_tokens = int(llm_extract_cfg.get("max_new_tokens", 32)) - dataset_paths = _get_dataset_paths( datasets=datasets, repo_root=repo_root, override_paths=mmbench_cfg.get("dataset_paths", {}) if isinstance(mmbench_cfg.get("dataset_paths"), dict) else {}, ) - out_dir.mkdir(parents=True, exist_ok=True) - - summary: "dict[str, Any]" = { - "benchmark": "mmbench", - "backbone": backbone, - "out_dir": str(out_dir), - "datasets": datasets, - "mode": mode, - } + dist_info = maybe_init_distributed() + try: + out_dir.mkdir(parents=True, exist_ok=True) + pipeline = InferencePipeline(backbone_name=backbone, backbone_cfg=backbone_cfg) - # ── Phase 1: Generation ── - if run_gen: - from umm.inference import InferencePipeline + summary: dict[str, Any] = { + "benchmark": "mmbench", + "backbone": backbone, + "out_dir": str(out_dir), + "datasets": datasets, + "world_size": dist_info.world_size, + } - pipeline = InferencePipeline(backbone_name=backbone, backbone_cfg=backbone_cfg) + if dist_info.world_size > 1: + print( + f"[mmbench] distributed inference enabled: rank={dist_info.rank}, " + f"local_rank={dist_info.local_rank}, world_size={dist_info.world_size}", + flush=True, + ) + local_total_written = 0 for ds_name in datasets: dataset_path = dataset_paths[ds_name] if not dataset_path.exists(): @@ -643,316 +257,157 @@ def run_mmbench_eval_command(args: Any) -> int: entry = DS_COLLECTIONS.get(ds_name, {}) language = str(entry.get("language", "en")) df = pd.read_csv(dataset_path, sep="\t") - df["index"] = df["index"].astype(int) - - # Resolve MMBench's image short-circuits so every row has its own blob. 
- image_map: "dict[int, str]" = {} - if "image" in df.columns: - for _, row in df.iterrows(): - image_map[int(row["index"])] = str(row["image"]) if not pd.isna(row["image"]) else "" checkpoint_jsonl = out_dir / f"{ds_name}_checkpoint.jsonl" - outputs: "list[dict[str, Any]]" = [] - done_indices: "set[int]" = set() + shard_path = rank_shard_path(checkpoint_jsonl, dist_info.rank, dist_info.world_size) + done_indices: set[Any] = set() if resume: - if checkpoint_jsonl.exists(): - with checkpoint_jsonl.open("r", encoding="utf-8") as reader: - for line in reader: - line = line.strip() - if not line: - continue - item = json.loads(line) - outputs.append(item) - done_indices.add(int(item["index"])) - print(f"[mmbench] resume from checkpoint: {len(outputs)} done", flush=True) - else: - jsonl_path: "Path | None" = None + shard_items = load_shard_items(shard_path) + if shard_items: + done_indices = {int(item["index"]) for item in shard_items} + print( + f"[mmbench] {ds_name}: rank {dist_info.rank} resume from " + f"checkpoint: {len(done_indices)} done", + flush=True, + ) + elif dist_info.world_size <= 1: + # Single-card fallback: recover from a previously-completed + # output JSONL. Multi-card runs rely on the rank shard only. 
+ jsonl_path: Path | None = None if isinstance(resume_jsonl, str) and resume_jsonl: jsonl_path = _resolve_path(resume_jsonl, repo_root) else: - jsonl_path = _find_latest_jsonl(out_dir, ds_name) + candidates = sorted(out_dir.glob(f"{ds_name}_*.jsonl")) + candidates = [c for c in candidates if "_checkpoint" not in c.name] + if candidates: + jsonl_path = max(candidates, key=lambda p: p.stat().st_mtime) if jsonl_path and jsonl_path.exists(): - with jsonl_path.open("r", encoding="utf-8") as reader: - for line in reader: - line = line.strip() - if not line: - continue - item = json.loads(line) - outputs.append(item) - done_indices.add(int(item["index"])) - print(f"[mmbench] resume from {jsonl_path}: {len(outputs)} done", flush=True) + prior = load_shard_items(jsonl_path) + done_indices = {int(item["index"]) for item in prior} + print( + f"[mmbench] resume from {jsonl_path}: " + f"{len(done_indices)} done", + flush=True, + ) + assigned_total = ( + (len(df) + dist_info.world_size - 1 - dist_info.rank) // dist_info.world_size + if dist_info.world_size > 1 + else len(df) + ) print( - f"[mmbench] {ds_name}: {len(df)} total, {len(done_indices)} done, " - f"{len(df) - len(done_indices)} remaining", + f"[mmbench] {ds_name}: total={len(df)}, rank={dist_info.rank}, " + f"assigned={assigned_total}, done={len(done_indices)}, " + f"remaining={max(0, assigned_total - len(done_indices))}", flush=True, ) - with checkpoint_jsonl.open("a", encoding="utf-8") as ckpt_writer: - for _, row in tqdm(df.iterrows(), total=len(df), desc=f"mmbench/{ds_name}", file=sys.stdout): - row_index = int(row["index"]) - if row_index in done_indices: - continue - - image_b64 = _resolve_image_blob(image_map, row_index) - if not image_b64: - # Skip rows without a usable image (rare). 
- continue - image_path = _decode_image(image_b64, image_dir, row_index=row_index) - options = _parse_options(row) - hint = None - if "hint" in row and not pd.isna(row["hint"]): - hint = str(row["hint"]) - - question_text = str(row["question"]) - prompt = _build_prompt(question_text, options, hint) - payload = { - "backbone": backbone, - "task": "understanding", - "prompt": prompt, - "images": [image_path], - "params": request_params, - "metadata": {"index": row_index, "dataset": ds_name}, - } - prediction = _extract_text(pipeline.run(payload)) - gt = None - if "answer" in row and not pd.isna(row["answer"]): - gt = str(row["answer"]).strip().upper() - item = { - "index": row_index, - "question": question_text, - "options": options, - "prediction": prediction, - "gt_answer": gt, - "language": language, - } - outputs.append(item) - done_indices.add(row_index) - ckpt_writer.write(json.dumps(item) + "\n") - ckpt_writer.flush() - os.fsync(ckpt_writer.fileno()) - - if max_samples > 0 and len(outputs) >= max_samples: - break + def payload_fn(pair: tuple[Any, pd.Series]) -> dict[str, Any]: + _, row = pair + row_index = int(row["index"]) + image_path = _decode_image(str(row["image"]), image_dir, row_index=row_index) + options = _options_for(row) + hint = None + if "hint" in row and not pd.isna(row["hint"]): + hint = row["hint"] + question = _build_prompt(str(row["question"]), options, hint, language=language) + return { + "backbone": backbone, + "task": "understanding", + "prompt": question, + "images": [image_path], + "params": request_params, + "metadata": {"index": row_index, "dataset": ds_name}, + } + + def record_fn(pair: tuple[Any, pd.Series], raw: Any, _idx: int) -> dict[str, Any]: + _, row = pair + row_index = int(row["index"]) + options = _options_for(row) + hint = None + if "hint" in row and not pd.isna(row["hint"]): + hint = row["hint"] + question = _build_prompt(str(row["question"]), options, hint, language=language) + response = _extract_text(raw) + pred = 
_post_process(response, options) + return { + "question": question, + "answer": pred, + "gt_answers": row["answer"] if "answer" in row else None, + "index": row_index, + } + + n_written = run_sharded_inference( + infer_fn=pipeline.run, + dist_info=dist_info, + shard_path=shard_path, + samples=df.iterrows(), + total=len(df), + payload_fn=payload_fn, + record_fn=record_fn, + sample_id_fn=lambda pair: int(pair[1]["index"]), + done_ids=done_indices, + max_samples=max_samples, + log_prefix=f"mmbench/{ds_name}/rank{dist_info.rank}", + ) + local_total_written += n_written + + barrier(dist_info) time_prefix = time.strftime("%y%m%d%H%M%S", time.localtime()) + results_file = f"{ds_name}_{time_prefix}.xlsx" + output_path = out_dir / results_file jsonl_path_out = out_dir / f"{ds_name}_{time_prefix}.jsonl" - with jsonl_path_out.open("w", encoding="utf-8") as writer: - for item in outputs: - writer.write(json.dumps(item) + "\n") - - if checkpoint_jsonl.exists(): - checkpoint_jsonl.unlink() - - summary[f"{ds_name}_output_jsonl"] = str(jsonl_path_out) - - if mode == "generate": - print(f"[mmbench] generation phase done, outputs={out_dir}", flush=True) - - # Release generation GPU memory before loading the judge LLM. - del pipeline - import gc - gc.collect() - try: - import torch as _torch - if _torch.cuda.is_available(): - _torch.cuda.empty_cache() - except ImportError: - pass - - # ── Phase 2: Circular scoring (VLMEvalKit-compatible) ── - if run_score: - judge: "_JudgeBundle | None" = None - try: - for ds_name in datasets: - jsonl_path = None - if f"{ds_name}_output_jsonl" in summary: - jsonl_path = Path(summary[f"{ds_name}_output_jsonl"]) - else: - jsonl_path = _find_latest_jsonl(out_dir, ds_name) - if jsonl_path is None or not jsonl_path.exists(): - raise FileNotFoundError( - f"No generation output found for {ds_name} in {out_dir}. " - f"Run generation phase first (mode: generate)." 
- ) - - print(f"[mmbench] scoring {ds_name} from {jsonl_path}", flush=True) - items: "list[dict[str, Any]]" = [] - with jsonl_path.open("r", encoding="utf-8") as reader: - for line in reader: - line = line.strip() - if not line: - continue - items.append(json.loads(line)) - - # Re-hydrate missing keys from the source TSV when needed - # (back-compat with JSONLs produced by the pre-rewrite code). - dataset_path = dataset_paths[ds_name] - df = pd.read_csv(dataset_path, sep="\t") if dataset_path.exists() else pd.DataFrame() - if not df.empty: - df["index"] = df["index"].astype(int) - df_by_idx = df.set_index("index", drop=False) - language_default = str(DS_COLLECTIONS.get(ds_name, {}).get("language", "en")) - for item in items: - idx = int(item["index"]) - row = df_by_idx.loc[idx] if idx in df_by_idx.index else None - if "options" not in item and row is not None: - item["options"] = _parse_options(row) - if "prediction" not in item: - item["prediction"] = item.get("response", item.get("answer", "")) - if "gt_answer" not in item and row is not None: - if "answer" in row and not pd.isna(row["answer"]): - item["gt_answer"] = str(row["answer"]).strip().upper() - if "language" not in item: - item["language"] = language_default - if "question" not in item and row is not None and "question" in row: - item["question"] = str(row["question"]) - - # Prepare per-item `choices` (copy of options, used by can_infer). - for item in items: - item["choices"] = dict(item.get("options", {})) - - # Group by g_index = index % 1e6 → circular rotation group - groups: "dict[int, list[dict[str, Any]]]" = {} - for item in items: - g = int(item["index"]) % 1_000_000 - groups.setdefault(g, []).append(item) - # Sort rotations inside each group by their original index so - # rotation k=0 (smallest) comes first, matching VLMEvalKit's order. - for g in groups: - groups[g].sort(key=lambda it: int(it["index"])) - - # Fast-path: resolve as many groups as possible with can_infer only. 
- group_results: "dict[int, dict[str, Any]]" = {} - pending_groups: "list[tuple[int, list[dict[str, Any]]]]" = [] - for g, rows in groups.items(): - pre_res, _gts, _preds = _prefetch_circular_group(rows) - if pre_res is not None: - group_results[g] = pre_res - else: - pending_groups.append((g, rows)) - # Load the judge once if any groups still need LLM extraction. - if pending_groups and llm_model_path and judge is None: - judge = _JudgeBundle(llm_model_path, max_new_tokens=llm_max_new_tokens) + if dist_info.rank == 0: + merged_outputs = merge_shards(checkpoint_jsonl) - if pending_groups: - print( - f"[mmbench] {ds_name}: {len(group_results)} groups decided by can_infer, " - f"{len(pending_groups)} need LLM judging", - flush=True, - ) - for g, rows in tqdm(pending_groups, desc=f"mmbench/{ds_name}/judge", file=sys.stdout): - group_results[g] = _eval_circular_group(rows, judge=judge) + cur_df = df.copy() + if "mmbench" in ds_name: + cur_df = cur_df.drop(columns=["hint", "category", "source", "image", "comment", "l2-category"]) + cur_df.insert(6, "prediction", None) else: - print( - f"[mmbench] {ds_name}: all {len(group_results)} groups decided by can_infer", - flush=True, - ) + cur_df = cur_df.drop(columns=["category", "image"]) + cur_df.insert(8, "prediction", None) - df_scored: pd.DataFrame - if not df.empty: - df_work = df.copy() - df_work["g_index"] = df_work["index"].astype(int) % 1_000_000 - df_scored = df_work[df_work["index"] == df_work["g_index"]].copy() - expected_g = set(int(g) for g in df_scored["g_index"].tolist()) - missing = expected_g - set(group_results.keys()) - if missing: - raise RuntimeError( - f"[mmbench] {ds_name}: {len(missing)} groups in TSV missing from " - f"generation outputs (e.g. {sorted(missing)[:5]}). " - f"Generation phase skipped some rows; rerun with mode: full or " - f"mode: generate to refill the JSONL before scoring." 
- ) - df_scored["hit"] = df_scored["g_index"].map(lambda g: group_results[int(g)]["hit"]) - df_scored["log"] = df_scored["g_index"].map(lambda g: group_results[int(g)]["log"]) - else: - # No TSV available (unlikely): fall back to a minimal frame. - rows = [] - for g, rows_list in groups.items(): - head = rows_list[0] - rows.append( - { - "index": head["index"], - "g_index": g, - "hit": group_results[g]["hit"], - } - ) - df_scored = pd.DataFrame(rows) - - metrics = _report_acc(df_scored) - overall_vals = [v for k, v in metrics.items() if k.endswith("|Overall")] - overall_acc_pct = round(100.0 * (overall_vals[0] if overall_vals else 0.0), 2) - question_count = int(len(df_scored)) - hit_count = int(df_scored["hit"].sum()) if "hit" in df_scored.columns else 0 - # Write the annotated items (now with extraction+hit) back to JSONL. - # Each item gets its group's hit attached for easy inspection. - for item in items: - g = int(item["index"]) % 1_000_000 - res = group_results.get(g, {"hit": 0, "log": ""}) - item["group_hit"] = int(res.get("hit", 0)) - with jsonl_path.open("w", encoding="utf-8") as writer: - for item in items: + for item in merged_outputs: + cur_df.loc[df["index"] == item["index"], "prediction"] = item["answer"] + + cur_df.to_excel(output_path, index=False, engine="openpyxl") # pip install openpyxl + with jsonl_path_out.open("w", encoding="utf-8") as writer: + for item in merged_outputs: writer.write(json.dumps(item) + "\n") - score_path = jsonl_path.with_name(f"{jsonl_path.stem}_score.json") - score_payload = { - "overall": { - "accuracy": overall_acc_pct, - "hit_count": hit_count, - "question_count": question_count, - "mode": "circular", - }, - "metrics": {k: round(100.0 * v, 2) for k, v in metrics.items()}, - } - score_path.write_text(json.dumps(score_payload, indent=2), encoding="utf-8") - print( - f"[mmbench] {ds_name} Overall Acc = {overall_acc_pct}% " - f"({hit_count}/{question_count} groups)", - flush=True, - ) - - # Persist the flat acc CSV too for 
parity with VLMEvalKit. - if not df_scored.empty: - csv_path = jsonl_path.with_name(f"{jsonl_path.stem}_acc.csv") - acc_rows = [{"metric": k, "accuracy": round(100.0 * v, 2)} for k, v in metrics.items()] - pd.DataFrame(acc_rows).to_csv(csv_path, index=False) - summary[f"{ds_name}_acc_csv"] = str(csv_path) - - # Optional xlsx with predictions filled in (requires openpyxl). - # Wise image ships without openpyxl; if missing, silently skip. - is_mmbench_schema = ds_name.startswith("mmbench_") or ds_name.startswith("MMBench_") - if not df.empty and is_mmbench_schema: - try: - import openpyxl # noqa: F401 - xlsx_path = jsonl_path.with_suffix(".xlsx") - cur_df = df.copy() - drop_cols = [c for c in ("hint", "category", "source", "image", "comment", "l2-category") if c in cur_df.columns] - if drop_cols: - cur_df = cur_df.drop(columns=drop_cols) - # For circular V11 the xlsx still lists all 4 rotations; each row - # gets the extracted letter from the corresponding item. - pred_by_idx: "dict[int, str]" = {} - for item in items: - pred_by_idx[int(item["index"])] = str(item.get("extraction", item.get("prediction", ""))) - cur_df["prediction"] = cur_df["index"].map(lambda k: pred_by_idx.get(int(k), "")) - cur_df.to_excel(xlsx_path, index=False, engine="openpyxl") - summary[f"{ds_name}_xlsx"] = str(xlsx_path) - except ImportError: - print("[mmbench] openpyxl unavailable, skipping xlsx export", flush=True) - - summary[f"{ds_name}_score_file"] = str(score_path) - summary[f"{ds_name}_accuracy"] = overall_acc_pct - finally: - if judge is not None: - judge.close() - - if isinstance(score_output_path, str) and score_output_path: - score_path = _resolve_path(score_output_path, repo_root) - score_path.parent.mkdir(parents=True, exist_ok=True) - score_path.write_text(json.dumps(summary, indent=2), encoding="utf-8") - print(f"[umm eval] wrote MMBench summary to {score_path}") - - print(f"[umm eval] completed MMBench (mode={mode}) for backbone={backbone}, outputs={out_dir}") - return 0 + 
cleanup_shards(checkpoint_jsonl) + if dist_info.world_size <= 1 and checkpoint_jsonl.exists(): + checkpoint_jsonl.unlink() + + summary[f"{ds_name}_output_path"] = str(output_path) + summary[f"{ds_name}_output_jsonl"] = str(jsonl_path_out) + + barrier(dist_info) + + total_written_all = sum_across_ranks(local_total_written, dist_info) + if dist_info.rank != 0: + print( + f"[umm eval] rank {dist_info.rank} finished MMBench shard: " + f"samples_written={local_total_written}", + flush=True, + ) + return 0 + + summary["samples_written"] = total_written_all + if isinstance(score_output_path, str) and score_output_path: + score_path = _resolve_path(score_output_path, repo_root) + score_path.parent.mkdir(parents=True, exist_ok=True) + score_path.write_text(json.dumps(summary, indent=2), encoding="utf-8") + print(f"[umm eval] wrote MMBench summary to {score_path}") + + print( + f"[umm eval] completed MMBench for backbone={backbone}, outputs={out_dir}, " + f"samples_written={total_written_all}, world_size={dist_info.world_size}" + ) + return 0 + finally: + cleanup_distributed(dist_info) diff --git a/src/umm/cli/mme_eval.py b/src/umm/cli/mme_eval.py index 57d9b62..fc5619e 100644 --- a/src/umm/cli/mme_eval.py +++ b/src/umm/cli/mme_eval.py @@ -1,14 +1,24 @@ from __future__ import annotations import json -import os import re import subprocess import sys from pathlib import Path -from typing import Any +from typing import Any, Iterator from umm.core.config import load_config +from umm.eval.distributed import ( + barrier, + cleanup_distributed, + cleanup_shards, + load_shard_items, + maybe_init_distributed, + merge_shards, + rank_shard_path, + sum_across_ranks, +) +from umm.eval.runner import run_sharded_inference from umm.inference import InferencePipeline @@ -60,8 +70,7 @@ def _extract_text(output: Any) -> str: def _post_process(response: str) -> str: response = response.replace("\n", "").replace("不是", "No").replace("是", "Yes").replace("否", "No") response = 
response.lower().replace("true", "yes").replace("false", "no") - response = re.sub(re.compile(r"[\u4e00-\u9fa5]"), "", response) - # Extract first yes/no, discard repetitive garbage + response = re.sub(re.compile(r"[一-龥]"), "", response) response = response.strip() match = re.match(r"^[^a-z]*(yes|no)\b", response) if match: @@ -127,105 +136,169 @@ def run_mme_eval_command(args: Any) -> int: if not image_root.exists(): raise FileNotFoundError(f"MME image root not found: {image_root}") - out_dir.mkdir(parents=True, exist_ok=True) - pipeline = InferencePipeline(backbone_name=backbone, backbone_cfg=backbone_cfg) - - txt_files = sorted([p for p in dataset_root.iterdir() if p.suffix == ".txt"]) - if not txt_files: - raise FileNotFoundError(f"No .txt files found in MME root: {dataset_root}") - - total_written = 0 - missing_images = 0 - skipped_rows = 0 - for task_txt in txt_files: - task_name = task_txt.stem - out_file = out_dir / task_txt.name - - # Resume: count already-completed lines and skip them - existing_lines = 0 - if out_file.exists(): - existing_lines = sum(1 for ln in out_file.open("r", encoding="utf-8") if ln.strip()) - if existing_lines > 0: - print(f"[mme] {task_name}: resuming after {existing_lines} existing lines", flush=True) - - with task_txt.open("r", encoding="utf-8") as fin, out_file.open("a", encoding="utf-8") as fout: - done = 0 - for idx, line in enumerate(fin, start=1): - row = line.strip().split("\t") - if len(row) != 3: - skipped_rows += 1 - continue - img, question, gt = row - - img_path = image_root / task_name / img - if not img_path.exists(): - img_path = image_root / task_name / "images" / img - if not img_path.exists(): - missing_images += 1 - continue - - # Skip rows already written in a previous run - done += 1 - if done <= existing_lines: - continue - + dist_info = maybe_init_distributed() + try: + out_dir.mkdir(parents=True, exist_ok=True) + pipeline = InferencePipeline(backbone_name=backbone, backbone_cfg=backbone_cfg) + + txt_files 
= sorted([p for p in dataset_root.iterdir() if p.suffix == ".txt"]) + if not txt_files: + raise FileNotFoundError(f"No .txt files found in MME root: {dataset_root}") + + if dist_info.world_size > 1: + print( + f"[mme] distributed inference enabled: rank={dist_info.rank}, " + f"local_rank={dist_info.local_rank}, world_size={dist_info.world_size}", + flush=True, + ) + + total_written = 0 + # Counters incremented by rank 0 only so the totals are not multiplied + # by world_size — every rank reads the full TSV during pre-filtering. + missing_images = 0 + skipped_rows = 0 + + for task_txt in txt_files: + task_name = task_txt.stem + out_file = out_dir / task_txt.name + checkpoint_jsonl = out_dir / f"{task_name}_checkpoint.jsonl" + shard_path = rank_shard_path(checkpoint_jsonl, dist_info.rank, dist_info.world_size) + + done_idx = {int(it["_sample_idx"]) for it in load_shard_items(shard_path)} + if done_idx: + print( + f"[mme] {task_name}: rank {dist_info.rank} resuming after " + f"{len(done_idx)} existing shard items", + flush=True, + ) + + def iter_valid_rows() -> Iterator[tuple[int, str, str, str, str]]: + nonlocal missing_images, skipped_rows + counter = 0 + with task_txt.open("r", encoding="utf-8") as fin: + for line in fin: + row = line.strip().split("\t") + if len(row) != 3: + if dist_info.rank == 0: + skipped_rows += 1 + continue + img, question, gt = row + img_path = image_root / task_name / img + if not img_path.exists(): + img_path = image_root / task_name / "images" / img + if not img_path.exists(): + if dist_info.rank == 0: + missing_images += 1 + continue + counter += 1 + yield counter, img, question, gt, str(img_path) + + def payload_fn(sample: tuple[int, str, str, str, str]) -> dict[str, Any]: + _, img, question, _gt, img_path = sample prompt = f"{question} {prompt_suffix}".strip() - payload = { + print(f"[mme] rank {dist_info.rank} | {task_name}: {img} ... 
inferring", flush=True) + return { "backbone": backbone, "task": "understanding", "prompt": prompt, - "images": [str(img_path)], + "images": [img_path], "params": request_params, } - print(f"[mme] {task_name} | {idx}: {img} ... inferring", flush=True) - output = pipeline.run(payload) - response = _post_process(_extract_text(output)) + + def record_fn(sample: tuple[int, str, str, str, str], raw: Any, _idx: int) -> dict[str, Any]: + _, img, question, gt, _img_path = sample + prompt = f"{question} {prompt_suffix}".strip() + response = _post_process(_extract_text(raw)) if not response.strip(): print(f"[mme] WARNING empty response: {task_name}/{img}", flush=True) - print(f"[mme] {task_name} | {idx}: {img} -> {response[:80]!r}", flush=True) - print(img, prompt, gt, response, sep="\t", file=fout) - fout.flush() - os.fsync(fout.fileno()) - total_written += 1 - if max_samples > 0 and idx >= max_samples: - break - - summary: dict[str, Any] = { - "benchmark": "mme", - "backbone": backbone, - "out_dir": str(out_dir), - "samples_written": total_written, - } + print(f"[mme] rank {dist_info.rank} | {task_name}: {img} -> {response[:80]!r}", flush=True) + return { + "img": img, + "prompt": prompt, + "gt": gt, + "response": response, + } - # Check if out_dir has any result files at all (including from previous runs) - has_results = any(p.suffix == ".txt" and p.stat().st_size > 0 for p in out_dir.iterdir()) if out_dir.exists() else False - if total_written == 0 and not has_results: + n_written = run_sharded_inference( + infer_fn=pipeline.run, + dist_info=dist_info, + shard_path=shard_path, + samples=iter_valid_rows(), + payload_fn=payload_fn, + record_fn=record_fn, + sample_id_fn=lambda sample: sample[0], + done_ids=done_idx, + max_samples=max_samples, + log_prefix=f"mme/{task_name}/rank{dist_info.rank}", + ) + total_written += n_written + + barrier(dist_info) + + if dist_info.rank == 0: + merged = merge_shards(checkpoint_jsonl) + with out_file.open("w", encoding="utf-8") as fout: + 
for item in merged: + print(item["img"], item["prompt"], item["gt"], item["response"], sep="\t", file=fout) + cleanup_shards(checkpoint_jsonl) + if dist_info.world_size <= 1 and checkpoint_jsonl.exists(): + checkpoint_jsonl.unlink() + + barrier(dist_info) + + total_written_all = sum_across_ranks(total_written, dist_info) + + if dist_info.rank != 0: + print( + f"[umm eval] rank {dist_info.rank} finished shard: " + f"samples_written={total_written}", + flush=True, + ) + return 0 + + summary: dict[str, Any] = { + "benchmark": "mme", + "backbone": backbone, + "out_dir": str(out_dir), + "samples_written": total_written_all, + "world_size": dist_info.world_size, + } + + has_results = ( + any(p.suffix == ".txt" and p.stat().st_size > 0 for p in out_dir.iterdir()) + if out_dir.exists() + else False + ) + if total_written_all == 0 and not has_results: + print( + "[umm eval] warning: no MME samples were written. " + "Skipping calculation. Check `mme.root` and `mme.image_root`." + ) + run_calculation = False + + if run_calculation: + cmd = [sys.executable, str(calculation_script), "--results_dir", str(out_dir)] + proc = subprocess.run(cmd, cwd=str(repo_root), capture_output=True, text=True) + print(proc.stdout) + if proc.returncode != 0: + if proc.stderr: + print(proc.stderr, file=sys.stderr) + raise RuntimeError(f"MME calculation failed with return code {proc.returncode}") + summary["calculation_stdout"] = proc.stdout + + if isinstance(score_output_path, str) and score_output_path: + score_path = _resolve_path(score_output_path, repo_root) + score_path.parent.mkdir(parents=True, exist_ok=True) + score_path.write_text(json.dumps(summary, indent=2), encoding="utf-8") + print(f"[umm eval] wrote MME summary to {score_path}") + + summary["missing_images"] = missing_images + summary["skipped_rows"] = skipped_rows print( - "[umm eval] warning: no MME samples were written. " - "Skipping calculation. Check `mme.root` and `mme.image_root`." 
+ f"[umm eval] completed MME for backbone={backbone}, outputs={out_dir}, " + f"samples_written={total_written_all}, missing_images={missing_images}, " + f"skipped_rows={skipped_rows}, world_size={dist_info.world_size}" ) - run_calculation = False - - if run_calculation: - cmd = [sys.executable, str(calculation_script), "--results_dir", str(out_dir)] - proc = subprocess.run(cmd, cwd=str(repo_root), capture_output=True, text=True) - print(proc.stdout) - if proc.returncode != 0: - if proc.stderr: - print(proc.stderr, file=sys.stderr) - raise RuntimeError(f"MME calculation failed with return code {proc.returncode}") - summary["calculation_stdout"] = proc.stdout - - if isinstance(score_output_path, str) and score_output_path: - score_path = _resolve_path(score_output_path, repo_root) - score_path.parent.mkdir(parents=True, exist_ok=True) - score_path.write_text(json.dumps(summary, indent=2), encoding="utf-8") - print(f"[umm eval] wrote MME summary to {score_path}") - - summary["missing_images"] = missing_images - summary["skipped_rows"] = skipped_rows - print( - f"[umm eval] completed MME for backbone={backbone}, outputs={out_dir}, " - f"samples_written={total_written}, missing_images={missing_images}, skipped_rows={skipped_rows}" - ) - return 0 + return 0 + finally: + cleanup_distributed(dist_info) diff --git a/src/umm/cli/mmmu_eval.py b/src/umm/cli/mmmu_eval.py index 6ac84ea..a0d3e07 100644 --- a/src/umm/cli/mmmu_eval.py +++ b/src/umm/cli/mmmu_eval.py @@ -1,6 +1,7 @@ from __future__ import annotations import json +import os import random import subprocess import sys @@ -10,11 +11,21 @@ from datasets import concatenate_datasets, load_dataset from PIL import Image -from tqdm import tqdm from umm.core.config import load_config -from umm.inference import InferencePipeline +from umm.eval.distributed import ( + barrier, + cleanup_distributed, + cleanup_shards, + load_shard_items, + maybe_init_distributed, + merge_shards, + rank_shard_path, + sum_across_ranks, +) from 
def _load_mmmu_subject_from_local_snapshot(
    root_path: Path,
    subject: str,
    split: str,
    cache_dir: str | None = None,
) -> Any:
    """Load one MMMU subject from parquet files in a local snapshot directory.

    Collects every ``{split}-*.parquet`` file under ``root_path / subject``
    (sorted by path) and loads them as a single split through the ``parquet``
    builder of ``load_dataset``.

    Args:
        root_path: Directory that holds one sub-directory per subject.
        subject: Name of the subject sub-directory.
        split: Split name; also the filename prefix of the parquet files.
        cache_dir: Optional cache directory forwarded to ``load_dataset``.
            ``None`` (the default) keeps the library's default cache.

    Returns:
        The object returned by ``load_dataset`` for the requested split.

    Raises:
        FileNotFoundError: If the subject directory has no matching parquet file.
    """
    subject_dir = root_path / subject
    parquet_files = sorted(subject_dir.glob(f"{split}-*.parquet"))
    if not parquet_files:
        raise FileNotFoundError(
            f"Missing MMMU parquet for subject={subject}, split={split}, expected files under {subject_dir}"
        )
    return load_dataset(
        "parquet",
        data_files={split: [str(path) for path in parquet_files]},
        split=split,
        cache_dir=cache_dir,
    )


def _load_mmmu_dataset(root: str, split: str, cache_dir: str | None) -> Any:
    """Load every MMMU subject for ``split`` and concatenate them.

    When ``root`` (after ``~`` expansion) is an existing directory it is read
    as a local parquet snapshot; otherwise ``root`` is passed to
    ``load_dataset`` as the dataset path with the subject as the config name.
    Subjects are visited in the order of ``data_utils.CAT_SHORT2LONG.values()``.

    Args:
        root: Local snapshot directory or dataset path for ``load_dataset``.
        split: Split to load for every subject.
        cache_dir: Cache directory handed to ``load_dataset`` on both paths.

    Returns:
        The result of ``concatenate_datasets`` over all subjects.

    Raises:
        FileNotFoundError: On the local path, if a subject has no parquet file.
    """
    root_path = Path(root).expanduser()
    # ``is_dir()`` is already False for a missing path, so no separate exists().
    use_local_snapshot = root_path.is_dir()

    per_subject = []
    for subject in data_utils.CAT_SHORT2LONG.values():
        if use_local_snapshot:
            # Forward cache_dir so both code paths honour the caller's setting.
            per_subject.append(
                _load_mmmu_subject_from_local_snapshot(
                    root_path, subject, split, cache_dir=cache_dir
                )
            )
        else:
            per_subject.append(
                load_dataset(root, subject, split=split, cache_dir=cache_dir)
            )
    return concatenate_datasets(per_subject)
"mmmu", - "backbone": backbone, - "out_dir": str(out_dir), - "datasets": datasets, - } - - random.seed(seed) - for ds_name in datasets: - split = "validation" - if ds_name.endswith("test"): - split = "test" - elif ds_name.endswith("dev"): - split = "dev" - elif ds_name.endswith("validation"): + dist_info = maybe_init_distributed() + try: + out_dir.mkdir(parents=True, exist_ok=True) + pipeline = InferencePipeline(backbone_name=backbone, backbone_cfg=backbone_cfg) + + summary: dict[str, Any] = { + "benchmark": "mmmu", + "backbone": backbone, + "out_dir": str(out_dir), + "datasets": datasets, + "world_size": dist_info.world_size, + } + + if dist_info.world_size > 1: + print( + f"[mmmu] distributed inference enabled: rank={dist_info.rank}, " + f"local_rank={dist_info.local_rank}, world_size={dist_info.world_size}", + flush=True, + ) + + random.seed(seed) + local_total_written = 0 + for ds_name in datasets: split = "validation" + if ds_name.endswith("test"): + split = "test" + elif ds_name.endswith("dev"): + split = "dev" + elif ds_name.endswith("validation"): + split = "validation" + + dataset = _load_mmmu_dataset( + root=dataset_root, + split=split, + cache_dir=str(cache_dir_path), + ) + + checkpoint_jsonl = out_dir / f"{ds_name}_checkpoint.jsonl" + shard_path = rank_shard_path(checkpoint_jsonl, dist_info.rank, dist_info.world_size) + + done_ids: set[Any] = { + str(it.get("data_id", "")) for it in load_shard_items(shard_path) + } + if done_ids: + print( + f"[mmmu] {ds_name}: rank {dist_info.rank} resuming after " + f"{len(done_ids)} shard items", + flush=True, + ) - dataset = _load_mmmu_dataset( - root=dataset_root, - split=split, - cache_dir=str(cache_dir_path), - ) - - # Resume: load checkpoint JSONL if exists - checkpoint_jsonl = out_dir / f"{ds_name}_checkpoint.jsonl" - done_ids: set[str] = set() - outputs: list[dict[str, Any]] = [] - if checkpoint_jsonl.exists(): - with checkpoint_jsonl.open("r", encoding="utf-8") as reader: - for line in reader: - line = 
line.strip() - if not line: - continue - item = json.loads(line) - outputs.append(item) - done_ids.add(str(item.get("data_id", ""))) - print(f"[mmmu] resume: {len(done_ids)} done, skipping completed items", flush=True) - - total = len(dataset) - remaining = total - len(done_ids) - print(f"[mmmu] {ds_name}: {total} total, {len(done_ids)} done, {remaining} remaining", flush=True) - - with checkpoint_jsonl.open("a", encoding="utf-8") as ckpt_writer: - for idx, sample in enumerate(tqdm(dataset, desc=f"mmmu/{ds_name}", file=sys.stdout), start=1): + total = len(dataset) + assigned_total = ( + (total + dist_info.world_size - 1 - dist_info.rank) // dist_info.world_size + if dist_info.world_size > 1 + else total + ) + print( + f"[mmmu] {ds_name}: total={total}, rank={dist_info.rank}, " + f"assigned={assigned_total}, done={len(done_ids)}, " + f"remaining={max(0, assigned_total - len(done_ids))}", + flush=True, + ) + + def payload_fn(sample: Any) -> dict[str, Any]: data = data_utils.process_single_sample(sample) data_id = str(data["id"]) - if data_id in done_ids: - continue - question = str(data["question"]).strip() question_type = str(data["question_type"]) - options = eval(data["options"]) if isinstance(data.get("options"), str) else data.get("options", []) + options = ( + eval(data["options"]) + if isinstance(data.get("options"), str) + else data.get("options", []) + ) if not isinstance(options, list): options = [] - prompt = _build_prompt(question, question_type, options, prompt_cfg) - index2ans, all_choices = data_utils.get_multi_choice_info(options) if options else ({}, []) - image_paths = _coerce_image_paths( data.get("image", []), image_dir=image_dir, data_id=data_id, max_images=max_images, ) - - payload = { + return { "backbone": backbone, "task": "understanding", "prompt": prompt, @@ -276,14 +319,29 @@ def run_mmmu_eval_command(args: Any) -> int: "params": request_params, "metadata": {"question_type": question_type, "data_id": data_id}, } - output = 
pipeline.run(payload) - response = _extract_text(output) + + def record_fn(sample: Any, raw: Any, _idx: int) -> dict[str, Any]: + data = data_utils.process_single_sample(sample) + data_id = str(data["id"]) + question = str(data["question"]).strip() + question_type = str(data["question_type"]) + options = ( + eval(data["options"]) + if isinstance(data.get("options"), str) + else data.get("options", []) + ) + if not isinstance(options, list): + options = [] + prompt = _build_prompt(question, question_type, options, prompt_cfg) + index2ans, all_choices = ( + data_utils.get_multi_choice_info(options) if options else ({}, []) + ) + response = _extract_text(raw) if question_type == "multiple-choice" and all_choices and index2ans: pred = eval_utils.parse_multi_choice_response(response, all_choices, index2ans) else: pred = response - - item = { + return { "question": question, "answer": pred, "gt_answers": data.get("answer"), @@ -292,52 +350,86 @@ def run_mmmu_eval_command(args: Any) -> int: "prompt": prompt, "raw_response": response, } - outputs.append(item) - ckpt_writer.write(json.dumps(item) + "\n") - ckpt_writer.flush() - - if max_samples > 0 and len(outputs) >= max_samples: - break - - time_prefix = time.strftime("%y%m%d%H%M%S", time.localtime()) - output_json = out_dir / f"{ds_name}_{time_prefix}.json" - output_jsonl = out_dir / f"{ds_name}_{time_prefix}.jsonl" - - output_dict = {item["data_id"]: item["answer"] for item in outputs} - output_json.write_text(json.dumps(output_dict, indent=4), encoding="utf-8") - with output_jsonl.open("w", encoding="utf-8") as writer: - for item in outputs: - writer.write(json.dumps(item) + "\n") - - # Clean up checkpoint after successful completion - if checkpoint_jsonl.exists(): - checkpoint_jsonl.unlink() - - summary[f"{ds_name}_output_path"] = str(output_json) - summary[f"{ds_name}_output_jsonl"] = str(output_jsonl) - - if run_calculation and split == "validation": - cmd = [ - sys.executable, - str(calculation_script), - 
"--output_path", - str(output_json), - "--answer_path", - str(answer_path), - ] - proc = subprocess.run(cmd, cwd=str(repo_root), capture_output=True, text=True) - print(proc.stdout) - if proc.returncode != 0: - if proc.stderr: - print(proc.stderr, file=sys.stderr) - raise RuntimeError(f"MMMU calculation failed with return code {proc.returncode}") - summary[f"{ds_name}_calculation_stdout"] = proc.stdout - - if isinstance(score_output_path, str) and score_output_path: - score_path = _resolve_path(score_output_path, repo_root) - score_path.parent.mkdir(parents=True, exist_ok=True) - score_path.write_text(json.dumps(summary, indent=2), encoding="utf-8") - print(f"[umm eval] wrote MMMU summary to {score_path}") - - print(f"[umm eval] completed MMMU for backbone={backbone}, outputs={out_dir}") - return 0 + + def sample_id_fn(sample: Any) -> str: + data = data_utils.process_single_sample(sample) + return str(data["id"]) + + n_written = run_sharded_inference( + infer_fn=pipeline.run, + dist_info=dist_info, + shard_path=shard_path, + samples=dataset, + total=total, + payload_fn=payload_fn, + record_fn=record_fn, + sample_id_fn=sample_id_fn, + done_ids=done_ids, + max_samples=max_samples, + log_prefix=f"mmmu/{ds_name}/rank{dist_info.rank}", + ) + local_total_written += n_written + + barrier(dist_info) + + time_prefix = time.strftime("%y%m%d%H%M%S", time.localtime()) + output_json = out_dir / f"{ds_name}_{time_prefix}.json" + output_jsonl = out_dir / f"{ds_name}_{time_prefix}.jsonl" + + if dist_info.rank == 0: + merged_outputs = merge_shards(checkpoint_jsonl) + output_dict = {item["data_id"]: item["answer"] for item in merged_outputs} + output_json.write_text(json.dumps(output_dict, indent=4), encoding="utf-8") + with output_jsonl.open("w", encoding="utf-8") as writer: + for item in merged_outputs: + writer.write(json.dumps(item) + "\n") + + cleanup_shards(checkpoint_jsonl) + if dist_info.world_size <= 1 and checkpoint_jsonl.exists(): + checkpoint_jsonl.unlink() + + 
summary[f"{ds_name}_output_path"] = str(output_json) + summary[f"{ds_name}_output_jsonl"] = str(output_jsonl) + + if run_calculation and split == "validation": + cmd = [ + sys.executable, + str(calculation_script), + "--output_path", + str(output_json), + "--answer_path", + str(answer_path), + ] + proc = subprocess.run(cmd, cwd=str(repo_root), capture_output=True, text=True) + print(proc.stdout) + if proc.returncode != 0: + if proc.stderr: + print(proc.stderr, file=sys.stderr) + raise RuntimeError(f"MMMU calculation failed with return code {proc.returncode}") + summary[f"{ds_name}_calculation_stdout"] = proc.stdout + + barrier(dist_info) + + total_written_all = sum_across_ranks(local_total_written, dist_info) + if dist_info.rank != 0: + print( + f"[umm eval] rank {dist_info.rank} finished MMMU shard: " + f"samples_written={local_total_written}", + flush=True, + ) + return 0 + + summary["samples_written"] = total_written_all + if isinstance(score_output_path, str) and score_output_path: + score_path = _resolve_path(score_output_path, repo_root) + score_path.parent.mkdir(parents=True, exist_ok=True) + score_path.write_text(json.dumps(summary, indent=2), encoding="utf-8") + print(f"[umm eval] wrote MMMU summary to {score_path}") + + print( + f"[umm eval] completed MMMU for backbone={backbone}, outputs={out_dir}, " + f"samples_written={total_written_all}, world_size={dist_info.world_size}" + ) + return 0 + finally: + cleanup_distributed(dist_info) diff --git a/src/umm/cli/mmvet_eval.py b/src/umm/cli/mmvet_eval.py index b277b5d..0af6c54 100644 --- a/src/umm/cli/mmvet_eval.py +++ b/src/umm/cli/mmvet_eval.py @@ -1,15 +1,24 @@ from __future__ import annotations import json -import sys import time from pathlib import Path -from typing import Any +from typing import Any, Iterator from PIL import Image -from tqdm import tqdm from umm.core.config import load_config +from umm.eval.distributed import ( + barrier, + cleanup_distributed, + cleanup_shards, + load_shard_items, + 
def _normalize_output_key(question_id: Any) -> str:
    """Return ``question_id`` as a string that always carries the ``v1_`` prefix.

    Ids that already start with ``v1_`` are returned unchanged (after ``str``
    conversion); anything else gets the prefix prepended.
    """
    key = str(question_id)
    if not key.startswith("v1_"):
        key = f"v1_{key}"
    return key
dataset: {ds_name}") + image_root_value = dataset_paths.get("image_root") + question_value = dataset_paths.get("question") + if not image_root_value or not question_value: + raise ValueError( + "MM-Vet requires `mmvet.dataset_paths.image_root` and " + "`mmvet.dataset_paths.question` to be set in the YAML config." + ) + image_root = _resolve_path(str(image_root_value), repo_root) + question_path = _resolve_path(str(question_value), repo_root) + if not image_root.exists(): + raise FileNotFoundError(f"MM-Vet image root not found: {image_root}") + if not question_path.exists(): + raise FileNotFoundError(f"MM-Vet question file not found: {question_path}") + + checkpoint_jsonl = out_dir / f"{ds_name}_checkpoint.jsonl" + shard_path = rank_shard_path(checkpoint_jsonl, dist_info.rank, dist_info.world_size) + + done_keys = { + str(it.get("output_key", "")) for it in load_shard_items(shard_path) + } + if done_keys: + print( + f"[mmvet] {ds_name}: rank {dist_info.rank} resuming after " + f"{len(done_keys)} shard items", + flush=True, + ) + + lines = [l.strip() for l in question_path.read_text("utf-8").splitlines() if l.strip()] + print( + f"[mmvet] {ds_name}: total={len(lines)}, rank={dist_info.rank}, " + f"done={len(done_keys)}", + flush=True, + ) - for ds_name in datasets: - entry = DS_COLLECTIONS.get(ds_name) - if not entry and ds_name not in dataset_paths: - raise ValueError(f"Unknown MM-Vet dataset: {ds_name}") - image_root_value = dataset_paths.get("image_root") - question_value = dataset_paths.get("question") - if not image_root_value or not question_value: - raise ValueError( - "MM-Vet requires `mmvet.dataset_paths.image_root` and " - "`mmvet.dataset_paths.question` to be set in the YAML config." 
+ def iter_rows() -> Iterator[dict[str, Any]]: + for line in lines: + yield json.loads(line) + + def payload_fn(row: dict[str, Any]) -> dict[str, Any]: + image_name = row["image"] + question = row["text"] + question_id = row["question_id"] + image_path = image_root / image_name + if not image_path.exists(): + raise FileNotFoundError(f"MM-Vet image not found: {image_path}") + try: + with Image.open(image_path) as img: + img.verify() + except Exception as exc: + raise RuntimeError(f"Failed to open image {image_path}: {exc}") from exc + return { + "backbone": backbone, + "task": "understanding", + "prompt": question, + "images": [str(image_path)], + "params": request_params, + "metadata": {"question_id": question_id, "dataset": ds_name}, + } + + def record_fn(row: dict[str, Any], raw: Any, _idx: int) -> dict[str, Any]: + response = _extract_text(raw) + output_key = _normalize_output_key(row["question_id"]) + return { + "output_key": output_key, + "response": response, + } + + def sample_id_fn(row: dict[str, Any]) -> str: + return _normalize_output_key(row["question_id"]) + + n_written = run_sharded_inference( + infer_fn=pipeline.run, + dist_info=dist_info, + shard_path=shard_path, + samples=iter_rows(), + total=len(lines), + payload_fn=payload_fn, + record_fn=record_fn, + sample_id_fn=sample_id_fn, + done_ids=done_keys, + max_samples=max_samples, + log_prefix=f"mmvet/{ds_name}/rank{dist_info.rank}", ) - image_root = _resolve_path(str(image_root_value), repo_root) - question_path = _resolve_path(str(question_value), repo_root) - if not image_root.exists(): - raise FileNotFoundError(f"MM-Vet image root not found: {image_root}") - if not question_path.exists(): - raise FileNotFoundError(f"MM-Vet question file not found: {question_path}") - - # Resume: load checkpoint if exists - checkpoint_json = out_dir / f"{ds_name}_checkpoint.json" - outputs: dict[str, str] = {} - if checkpoint_json.exists(): - outputs = json.loads(checkpoint_json.read_text("utf-8")) - print(f"[mmvet] 
resume: {len(outputs)} done, skipping completed items", flush=True) - - lines = [l.strip() for l in question_path.read_text("utf-8").splitlines() if l.strip()] - print(f"[mmvet] {ds_name}: {len(lines)} total, {len(outputs)} done", flush=True) - - for idx, line in enumerate(tqdm(lines, desc=f"mmvet/{ds_name}", file=sys.stdout), start=1): - row = json.loads(line) - image_name = row["image"] - question = row["text"] - question_id = row["question_id"] - # question_id already has "v1_" prefix in the dataset - output_key = str(question_id) if str(question_id).startswith("v1_") else f"v1_{question_id}" - - if output_key in outputs: - continue - - image_path = image_root / image_name - if not image_path.exists(): - raise FileNotFoundError(f"MM-Vet image not found: {image_path}") - - try: - with Image.open(image_path) as img: - img.verify() - except Exception as exc: - raise RuntimeError(f"Failed to open image {image_path}: {exc}") from exc - - payload = { - "backbone": backbone, - "task": "understanding", - "prompt": question, - "images": [str(image_path)], - "params": request_params, - "metadata": {"question_id": question_id, "dataset": ds_name}, - } - response = _extract_text(pipeline.run(payload)) - outputs[output_key] = response + local_total_written += n_written - # Write checkpoint after each item - checkpoint_json.write_text(json.dumps(outputs, indent=2), encoding="utf-8") + barrier(dist_info) - if max_samples > 0 and idx >= max_samples: - break + time_prefix = time.strftime("%y%m%d%H%M%S", time.localtime()) + results_file = out_dir / f"{ds_name}_{time_prefix}.json" - time_prefix = time.strftime("%y%m%d%H%M%S", time.localtime()) - results_file = out_dir / f"{ds_name}_{time_prefix}.json" - results_file.write_text(json.dumps(outputs, indent=2), encoding="utf-8") + if dist_info.rank == 0: + merged = merge_shards(checkpoint_jsonl) + outputs: dict[str, str] = {item["output_key"]: item["response"] for item in merged} + results_file.write_text(json.dumps(outputs, 
indent=2), encoding="utf-8") - # Clean up checkpoint after successful completion - if checkpoint_json.exists(): - checkpoint_json.unlink() + cleanup_shards(checkpoint_jsonl) + if dist_info.world_size <= 1 and checkpoint_jsonl.exists(): + checkpoint_jsonl.unlink() - summary[f"{ds_name}_output_path"] = str(results_file) + summary[f"{ds_name}_output_path"] = str(results_file) - if isinstance(score_output_path, str) and score_output_path: - score_path = _resolve_path(score_output_path, repo_root) - score_path.parent.mkdir(parents=True, exist_ok=True) - score_path.write_text(json.dumps(summary, indent=2), encoding="utf-8") - print(f"[umm eval] wrote MM-Vet summary to {score_path}") + barrier(dist_info) - print(f"[umm eval] completed MM-Vet for backbone={backbone}, outputs={out_dir}") - return 0 + total_written_all = sum_across_ranks(local_total_written, dist_info) + if dist_info.rank != 0: + print( + f"[umm eval] rank {dist_info.rank} finished MM-Vet shard: " + f"samples_written={local_total_written}", + flush=True, + ) + return 0 + + summary["samples_written"] = total_written_all + if isinstance(score_output_path, str) and score_output_path: + score_path = _resolve_path(score_output_path, repo_root) + score_path.parent.mkdir(parents=True, exist_ok=True) + score_path.write_text(json.dumps(summary, indent=2), encoding="utf-8") + print(f"[umm eval] wrote MM-Vet summary to {score_path}") + + print( + f"[umm eval] completed MM-Vet for backbone={backbone}, outputs={out_dir}, " + f"samples_written={total_written_all}, world_size={dist_info.world_size}" + ) + return 0 + finally: + cleanup_distributed(dist_info) diff --git a/src/umm/eval/distributed.py b/src/umm/eval/distributed.py new file mode 100644 index 0000000..89c9a06 --- /dev/null +++ b/src/umm/eval/distributed.py @@ -0,0 +1,164 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Distributed-evaluation plumbing shared by understanding-class eval CLIs. 
@dataclass(frozen=True)
class DistInfo:
    """Immutable description of this process's place in a (possibly) distributed run."""

    # Global rank, read from ``RANK`` (0 when unset).
    rank: int
    # Number of processes, read from ``WORLD_SIZE`` (1 when unset).
    world_size: int
    # Rank on the local node, read from ``LOCAL_RANK`` (0 when unset).
    local_rank: int
    # Set to True by ``maybe_init_distributed`` when ``WORLD_SIZE > 1``; the
    # collective helpers below are no-ops while this is False.
    enabled: bool


def get_dist_info() -> DistInfo:
    """Read RANK / WORLD_SIZE / LOCAL_RANK from the environment. Does not init."""
    return DistInfo(
        rank=int(os.environ.get("RANK", "0")),
        world_size=int(os.environ.get("WORLD_SIZE", "1")),
        local_rank=int(os.environ.get("LOCAL_RANK", "0")),
        enabled=False,
    )


def maybe_init_distributed() -> DistInfo:
    """Initialize ``torch.distributed`` when ``WORLD_SIZE > 1``; otherwise noop.

    In the distributed case the CUDA device is set to ``local_rank`` when CUDA
    is available, and a process group is created with ``env://`` rendezvous
    (``nccl`` with CUDA, ``gloo`` without) unless one already exists.

    Returns:
        A ``DistInfo`` with ``enabled=True`` in the distributed case, or the
        plain environment snapshot (``enabled=False``) otherwise.
    """
    info = get_dist_info()
    if info.world_size <= 1:
        return info

    # Imported lazily so single-process callers never pay for importing torch.
    import torch
    import torch.distributed as dist

    if torch.cuda.is_available():
        torch.cuda.set_device(info.local_rank)
    if not dist.is_initialized():
        backend = "nccl" if torch.cuda.is_available() else "gloo"
        dist.init_process_group(backend=backend, init_method="env://")
    # NOTE: ``enabled`` is also True when the group was already initialized
    # before this call, so ``cleanup_distributed`` will destroy that group too.
    return DistInfo(
        rank=info.rank,
        world_size=info.world_size,
        local_rank=info.local_rank,
        enabled=True,
    )


def cleanup_distributed(info: DistInfo) -> None:
    """Destroy the process group if ``info.enabled`` and a group is initialized."""
    if not info.enabled:
        return
    import torch.distributed as dist

    if dist.is_initialized():
        dist.destroy_process_group()


def barrier(info: DistInfo) -> None:
    """Run ``dist.barrier()``; a no-op when ``info.enabled`` is False or no group exists."""
    if not info.enabled:
        return
    import torch.distributed as dist

    if dist.is_initialized():
        dist.barrier()


def sum_across_ranks(value: int, info: DistInfo) -> int:
    """Return the sum of ``value`` over all ranks.

    Returns ``value`` unchanged when ``info.enabled`` is False or no process
    group is initialized. Otherwise the value is all-reduced (SUM) as a
    one-element ``long`` tensor, placed on CUDA when available and on CPU
    otherwise.
    """
    if not info.enabled:
        return value
    import torch
    import torch.distributed as dist

    if not dist.is_initialized():
        return value
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tensor = torch.tensor([value], dtype=torch.long, device=device)
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    return int(tensor.item())


def rank_shard_path(base: Path, rank: int, world_size: int) -> Path:
    """Return the rank-local shard path for ``base``.

    Single-card mode (``world_size <= 1``) returns ``base`` unchanged so the
    caller's existing on-disk filenames are preserved. Otherwise the shard is
    a sibling of ``base`` named ``{stem}.rank{rank}{suffix}``.
    """
    if world_size <= 1:
        return base
    return base.parent / f"{base.stem}.rank{rank}{base.suffix}"


def _shard_glob_pattern(base: Path) -> str:
    """Return the glob pattern that matches every rank shard of ``base``."""
    return f"{base.stem}.rank*{base.suffix}"


def load_shard_items(shard: Path) -> list[dict[str, Any]]:
    """Read a JSONL shard, return [] if missing.

    Blank lines are skipped; every other line is parsed with ``json.loads``.
    """
    items: list[dict[str, Any]] = []
    if not shard.exists():
        return items
    with shard.open("r", encoding="utf-8") as reader:
        for line in reader:
            line = line.strip()
            if not line:
                continue
            items.append(json.loads(line))
    return items


def merge_shards(base: Path) -> list[dict[str, Any]]:
    """Merge all rank shards for ``base`` into a single sorted list.

    Reads ``base`` itself (if present) followed by every sibling matching
    ``{stem}.rank*{suffix}`` in sorted path order, so shards left behind by a
    run with a different world size are still picked up.

    Rows are de-duplicated on ``_sample_idx``: the first row seen for an index
    wins and later rows with the same index are dropped. Rows without a
    ``_sample_idx`` cannot be identified, so all of them are kept and sorted
    as index 0.

    Returns:
        The surviving rows sorted by ``_sample_idx``; the sort is stable, so
        rows with equal keys keep their read order.
    """
    candidates: list[Path] = []
    if base.exists():
        candidates.append(base)
    candidates.extend(sorted(base.parent.glob(_shard_glob_pattern(base))))

    merged: list[tuple[int, dict[str, Any]]] = []
    seen_indices: set[int] = set()
    for path in candidates:
        for item in load_shard_items(path):
            raw_idx = item.get("_sample_idx")
            if raw_idx is None:
                # No index to compare on; keep the row rather than guess.
                merged.append((0, item))
                continue
            sample_idx = int(raw_idx)
            # Key on the index alone. Including id(item) in the key made every
            # row unique, so duplicates were never actually filtered.
            if sample_idx in seen_indices:
                continue
            seen_indices.add(sample_idx)
            merged.append((sample_idx, item))
    merged.sort(key=lambda pair: pair[0])
    return [item for _, item in merged]


def cleanup_shards(base: Path) -> None:
    """Delete all rank shards matching ``{stem}.rank*{suffix}``.

    The base path itself is NOT touched; removing it is left to the caller.
    A shard that disappears between the glob and the unlink is ignored.
    """
    for shard in base.parent.glob(_shard_glob_pattern(base)):
        try:
            shard.unlink()
        except FileNotFoundError:
            pass
T = TypeVar("T")


def run_sharded_inference(
    *,
    infer_fn: Callable[[dict[str, Any]], Any],
    dist_info: DistInfo,
    shard_path: Path,
    samples: Iterable[T],
    payload_fn: Callable[[T], dict[str, Any]],
    record_fn: Callable[[T, Any, int], dict[str, Any]],
    total: Optional[int] = None,
    sample_id_fn: Optional[Callable[[T], Hashable]] = None,
    done_ids: Optional[set[Hashable]] = None,
    max_samples: int = 0,
    log_prefix: str = "eval",
) -> int:
    """Run inference on this rank's share of ``samples`` and append results to a shard.

    Samples are numbered from 1 in iteration order. For each one, in order:

    1. If ``max_samples > 0`` and the number exceeds it, iteration stops. The
       cap counts every sample, not only the ones this rank handles.
    2. With ``world_size > 1``, the sample is skipped unless
       ``(number - 1) % world_size == rank``.
    3. If ``sample_id_fn`` is given and its key is already in ``done_ids``,
       the sample is skipped.
    4. Otherwise ``infer_fn(payload_fn(sample))`` is called, the result goes
       through ``record_fn(sample, raw, number)``, and a copy of that record
       with ``_sample_idx`` set to the number is appended to ``shard_path`` as
       one JSON line, then flushed and fsynced.

    The parent directory of ``shard_path`` is created if needed and the file
    is opened in append mode. A ``done_ids`` set passed by the caller is
    updated in place with every non-``None`` key that gets written. The
    progress bar is written to stdout and disabled on every rank except 0.

    Returns:
        The number of records this call appended to ``shard_path``.
    """
    completed = set() if done_ids is None else done_ids
    my_rank = dist_info.rank
    n_ranks = dist_info.world_size
    written = 0

    shard_path.parent.mkdir(parents=True, exist_ok=True)
    with shard_path.open("a", encoding="utf-8") as out:
        progress = tqdm(
            samples,
            total=total,
            desc=log_prefix,
            file=sys.stdout,
            disable=(my_rank != 0),
        )
        for position, sample in enumerate(progress, start=1):
            # Global cap: applies before rank assignment and resume checks.
            if 0 < max_samples < position:
                break
            # Round-robin assignment of samples to ranks.
            if n_ranks > 1 and (position - 1) % n_ranks != my_rank:
                continue

            key = sample_id_fn(sample) if sample_id_fn is not None else None
            if sample_id_fn is not None and key in completed:
                continue

            raw = infer_fn(payload_fn(sample))
            row = dict(record_fn(sample, raw, position))
            row.update(_sample_idx=position)

            out.write(f"{json.dumps(row, ensure_ascii=False)}\n")
            out.flush()
            # Force the line to disk so an interrupted run can resume from it.
            os.fsync(out.fileno())

            written += 1
            if key is not None:
                completed.add(key)

    return written