diff --git a/areal/api/cli_args.py b/areal/api/cli_args.py index 83a58b291..ff41d5101 100644 --- a/areal/api/cli_args.py +++ b/areal/api/cli_args.py @@ -1274,7 +1274,6 @@ def __post_init__(self): "SAPO is not compatible with `use_decoupled_loss=True`. " "Please set `actor.use_decoupled_loss=false` in your configuration." ) - super().__post_init__() diff --git a/areal/api/engine_api.py b/areal/api/engine_api.py index 09d86622f..818a98d60 100644 --- a/areal/api/engine_api.py +++ b/areal/api/engine_api.py @@ -3,6 +3,7 @@ import abc from collections.abc import Callable from concurrent.futures import Future +from contextlib import nullcontext from typing import TYPE_CHECKING, Any import torch @@ -263,6 +264,10 @@ def prepare_batch( """ raise NotImplementedError() + def prepare_batch_context(self): + """Return a context manager for rollout batch preparation.""" + return nullcontext() + @abc.abstractmethod def set_version(self, version: int): """Set the current weight version in the training engine. @@ -680,6 +685,10 @@ def update_weights_from_disk(self, meta: WeightUpdateMeta) -> Future[None]: """ raise NotImplementedError() + def sync_weights_from_disk(self, meta: WeightUpdateMeta) -> None: + """Update weights from disk in a blocking manner.""" + self.update_weights_from_disk(meta).result() + def set_version(self, version: int) -> None: """Set the current weight version in the inference engine. diff --git a/areal/api/io_struct.py b/areal/api/io_struct.py index 63b0f2d47..b36324844 100644 --- a/areal/api/io_struct.py +++ b/areal/api/io_struct.py @@ -257,7 +257,6 @@ def from_fsdp_xccl( base_model_name=base_model_name, ) - @dataclass class HttpRequest: """Represents an HTTP request to be sent to a remote inference server.""" diff --git a/areal/engine/fsdp_engine.py b/areal/engine/fsdp_engine.py index ba72247c9..291e067f6 100644 --- a/areal/engine/fsdp_engine.py +++ b/areal/engine/fsdp_engine.py @@ -718,6 +718,13 @@ def process_output(logits: torch.Tensor, ctx_dict: dict[str, Any]) -> None: def export_stats(self) -> dict[str, float]: return stats_tracker.export_all(reduce_group=self.data_parallel_group) + def prepare_batch_context(self): + return ( + torch_memory_saver.disable() + if self.is_offload and not torch.version.hip + else nullcontext() + ) + def offload(self) -> None: """Offload model memory to CPU using torch_memory_saver. diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index ce29f518e..805670369 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -767,6 +767,13 @@ def export_stats(self) -> dict[str, float]: data.update(data_list[0]) return data + def prepare_batch_context(self): + return ( + torch_memory_saver.disable() + if self.is_offload and not torch.version.hip + else nullcontext() + ) + def offload(self) -> None: """Offload model memory to CPU using torch_memory_saver. 
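Note: `prepare_batch_context()` (added to `TrainEngine` above, with CUDA-only overrides in the FSDP and Megatron engines) exists so that rollout batch preparation can run while the training engine is offloaded without torch_memory_saver tracking the rollout-side allocations. A minimal sketch of the intended call pattern, assuming a placeholder `engine` object implementing this API (`dataloader` and `workflow` are likewise placeholders, not names from this patch):

with engine.prepare_batch_context():
    # On an offloaded CUDA engine (is_offload=True, not torch.version.hip) this
    # body runs under torch_memory_saver.disable(), so allocations made during
    # batch preparation bypass the hooked allocator; otherwise it degrades to a
    # plain contextlib.nullcontext().
    batch = engine.prepare_batch(dataloader, workflow=workflow)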
diff --git a/areal/experimental/engine/archon_engine.py b/areal/experimental/engine/archon_engine.py index 9447cb6f7..aa28764f1 100644 --- a/areal/experimental/engine/archon_engine.py +++ b/areal/experimental/engine/archon_engine.py @@ -718,6 +718,13 @@ def export_stats(self) -> dict[str, float]: data.update(data_list[0]) return data + def prepare_batch_context(self): + return ( + torch_memory_saver.disable() + if self.is_offload and not torch.version.hip + else nullcontext() + ) + def get_device_stats(self) -> DeviceRuntimeInfo: return DeviceRuntimeInfo.get_current() diff --git a/areal/infra/__init__.py b/areal/infra/__init__.py index 76e3139ca..078353fae 100644 --- a/areal/infra/__init__.py +++ b/areal/infra/__init__.py @@ -1,6 +1,7 @@ """Core components for AREAL.""" from . import workflow_context +from .colocated import ColocatedOrchestrator from .controller import RolloutController, TrainController from .launcher import ( LocalLauncher, @@ -22,6 +23,7 @@ ) __all__ = [ + "ColocatedOrchestrator", "RemoteInfBackendProtocol", "RemoteInfEngine", "StalenessManager", diff --git a/areal/infra/colocated.py b/areal/infra/colocated.py new file mode 100644 index 000000000..d7ee78079 --- /dev/null +++ b/areal/infra/colocated.py @@ -0,0 +1,114 @@ +"""Colocated (GPU time-sharing) orchestration for on-policy training. + +In colocated mode, the training engine and inference engine share the same +GPUs and alternate between offloaded/onloaded states. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import torch.distributed as dist + +from areal.api.io_struct import WeightUpdateMeta +from areal.utils import logging + +if TYPE_CHECKING: + from areal.api import InferenceEngine, TrainEngine + +logger = logging.getLogger("Colocated") + + +class ColocatedOrchestrator: + """Orchestrate GPU ownership between colocated training and inference.""" + + def __init__( + self, + train_engine: TrainEngine, + inf_engine: InferenceEngine, + ) -> None: + self._train_engine: TrainEngine = train_engine + self._inf_engine: InferenceEngine = inf_engine + self._train_on_gpu: bool = True + self._inf_on_gpu: bool = True + + def _is_rollout_coordinator(self) -> bool: + return not dist.is_initialized() or dist.get_rank() == 0 + + def _barrier(self) -> None: + if not dist.is_initialized(): + return + cpu_group = self._train_engine.cpu_group + if cpu_group is None: + dist.barrier() + return + dist.barrier(group=cpu_group) + + def initial_offload_training(self) -> None: + """Offload training once so inference owns the GPU before first rollout.""" + if not self._train_on_gpu: + logger.warning( + "initial_offload_training called but training engine is already off GPU." + ) + return + + if self._is_rollout_coordinator(): + logger.info("Initial offload: moving training engine off GPU") + self._train_engine.offload() + self._train_on_gpu = False + + def prepare_for_training(self) -> None: + """Switch GPU ownership from inference to training.""" + if self._train_on_gpu: + logger.debug("Training engine already on GPU, skipping switch") + return + + if self._is_rollout_coordinator(): + logger.info("Switching to training mode") + + # Pause local submission on every rank first so no new requests are queued. + self._inf_engine.pause() + + # Only one coordinator should touch the shared rollout servers. 
+ if self._is_rollout_coordinator(): + self._inf_engine.pause_generation() + if self._inf_on_gpu: + self._inf_engine.offload() + + self._barrier() + self._inf_on_gpu = False + + # All training ranks must participate in the training-engine collective. + self._train_engine.onload() + self._train_on_gpu = True + + def prepare_for_inference(self, meta: WeightUpdateMeta) -> None: + """Switch GPU ownership from training to inference and sync weights.""" + if self._inf_on_gpu: + logger.debug("Inference engine already on GPU, skipping switch") + return + + if self._is_rollout_coordinator(): + logger.info("Switching to inference mode") + + if self._train_on_gpu: + self._train_engine.offload() + self._train_on_gpu = False + + self._barrier() + + if self._is_rollout_coordinator(): + self._inf_engine.onload() + + if meta.version is None: + raise ValueError("Colocated disk weight sync requires meta.version.") + + # The colocated flow publishes the ready signal before the trainer later calls + # rollout.set_version(new_version), so align the rollout-side wait key here. + self._inf_engine.set_version(meta.version) + self._inf_engine.sync_weights_from_disk(meta) + self._inf_engine.continue_generation() + + self._barrier() + self._inf_engine.resume() + self._inf_on_gpu = True diff --git a/areal/infra/controller/rollout_controller.py b/areal/infra/controller/rollout_controller.py index 9314b3150..4363a48e7 100644 --- a/areal/infra/controller/rollout_controller.py +++ b/areal/infra/controller/rollout_controller.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import copy import shutil import threading import traceback @@ -32,6 +33,7 @@ InferenceEngineConfig, PerfTracerConfig, SchedulingSpec, + SchedulingStrategyType, ) from areal.infra.rpc.serialization import deserialize_value from areal.infra.utils.concurrent import run_async_task @@ -46,6 +48,11 @@ logger = logging.getLogger("RolloutController") +_ROLLOUT_FORK_UNSET_ENV_KEYS = [ + "LD_PRELOAD", + "TMS_INIT_ENABLE", + "TMS_INIT_ENABLE_CPU_BACKUP", +] # NOTE: remote task input has a slightly different # type annotation, which disallows workflow object or types @@ -150,6 +157,25 @@ def _engine_name(self, rank: int) -> str: """ return f"{self._worker_role}/{rank}" + @staticmethod + def _build_create_worker_kwargs(job: Job) -> dict[str, Any]: + """Build scheduler kwargs for rollout worker creation. + + For colocated rollout->actor with fork=True, sanitize inherited actor-side + TMS env vars at the fork boundary so the rollout child and its SGLang server + do not inherit training-side memory-saver hooks.
+ """ + strategy = job.scheduling_strategy + if ( + SchedulingStrategyType(strategy.type) == SchedulingStrategyType.colocation + and strategy.target == "actor" + and strategy.fork + ): + return { + "fork_unset_env_keys": list(_ROLLOUT_FORK_UNSET_ENV_KEYS), + } + return {} + def initialize( self, role: str, @@ -231,8 +257,16 @@ async def _async_initialize( **kwargs, ): # Create workers via scheduler - logger.info("Creating workers via scheduler...") - worker_ids = self.scheduler.create_workers(job=job) + create_worker_kwargs = self._build_create_worker_kwargs(job) + if create_worker_kwargs: + logger.info( + "Creating workers via scheduler with fork env cleanup: " + f"{create_worker_kwargs}" + ) + else: + logger.info("Creating workers via scheduler...") + + worker_ids = self.scheduler.create_workers(job=job, **create_worker_kwargs) logger.info(f"Workers created: {worker_ids}") # Wait for workers to be ready @@ -565,12 +599,12 @@ def update_weights_disk(): @app.route("/callback/pause_generation", methods=["POST"]) def pause_generation(): - self._callback_loop.run_until_complete(self.pause_generation()) + self.pause_generation() return jsonify({"status": "ok"}) @app.route("/callback/continue_generation", methods=["POST"]) def continue_generation(): - self._callback_loop.run_until_complete(self.continue_generation()) + self.continue_generation() return jsonify({"status": "ok"}) @app.route("/callback/rollout_complete", methods=["POST"]) @@ -1018,16 +1052,27 @@ async def update_weights_from_distributed( ) async def update_weights_from_disk(self, meta: WeightUpdateMeta): - meta.clear_checkpoint_after_load = False - await self._collective_rpc_async("update_weights_from_disk", meta=meta) - shutil.rmtree(meta.path, ignore_errors=True) + update_meta = copy.copy(meta) + update_meta.clear_checkpoint_after_load = False + await self._collective_rpc_async("update_weights_from_disk", meta=update_meta) + if meta.clear_checkpoint_after_load and meta.path is not None: + shutil.rmtree(meta.path, ignore_errors=True) + + def sync_weights_from_disk(self, meta: WeightUpdateMeta) -> None: + run_async_task(self.update_weights_from_disk, meta) - async def pause_generation(self): + async def _pause_generation_async(self): await self._collective_rpc_async("pause_generation") - async def continue_generation(self): + def pause_generation(self): + run_async_task(self._pause_generation_async) + + async def _continue_generation_async(self): await self._collective_rpc_async("continue_generation") + def continue_generation(self): + run_async_task(self._continue_generation_async) + def set_version(self, version: int) -> None: with self._version_lock: self._version = version @@ -1041,6 +1086,12 @@ def get_version(self) -> int: with self._version_lock: return self._version + def offload(self) -> None: + self._collective_rpc("offload", http_timeout=60.0) + + def onload(self, tags: list[str] | None = None) -> None: + self._collective_rpc("onload", tags=tags, http_timeout=60.0) + def pause(self): self.dispatcher.pause() self._collective_rpc("pause", http_timeout=60.0) diff --git a/areal/infra/controller/train_controller.py b/areal/infra/controller/train_controller.py index d5b8473ca..7d9ccaa98 100644 --- a/areal/infra/controller/train_controller.py +++ b/areal/infra/controller/train_controller.py @@ -1,4 +1,5 @@ import asyncio +from contextlib import nullcontext from typing import Any import torch @@ -139,6 +140,7 @@ def __init__( self._own_process_group = False self.rollout: RolloutController = None + self.is_offload = False def 
create_process_group(self, parallel_strategy: ParallelStrategy | None = None): """Placeholder method for process group creation. @@ -584,6 +586,17 @@ def update_weights(self, meta: WeightUpdateMeta): self._check_rollout_engine_connected() self._custom_function_call("update_weights", meta=meta) + def offload(self) -> None: + self._custom_function_call("offload") + self.is_offload = True + + def onload(self) -> None: + self._custom_function_call("onload") + self.is_offload = False + + def prepare_batch_context(self): + return nullcontext() + def get_device_stats(self): return self._custom_function_call("get_device_stats") @@ -603,15 +616,28 @@ async def _call(): return await asyncio.gather(*tasks) run_async_task(_call) - + def save_perf_tracer(self, step: int | None = None, force: bool = False) -> None: - self._custom_function_call("save_perf_tracer", step=step, force=force) + async def _call(): + tasks = [ + self.scheduler.async_call_engine( + worker_id=worker.id, + method="save_perf_tracer", + engine_name=self._engine_name(rank), + step=step, + force=force, + ) + for rank, worker in enumerate(self.workers) + ] + return await asyncio.gather(*tasks) + + run_async_task(_call) def prepare_batch( self, dataloader: StatefulDataLoader, workflow: WorkflowLike, - workflow_kwargs: dict[str, Any], + workflow_kwargs: dict[str, Any] | None = None, should_accept_fn: str | None = None, group_size: int = 1, dynamic_bs: bool = False, @@ -629,7 +655,7 @@ def rollout_batch( self, data: list[dict[str, Any]], workflow: WorkflowLike, - workflow_kwargs: dict[str, Any], + workflow_kwargs: dict[str, Any] | None = None, should_accept_fn: str | None = None, group_size: int = 1, ) -> list[dict[str, Any]]: diff --git a/areal/infra/launcher/local.py b/areal/infra/launcher/local.py index 5ae27be90..e89a7b1f4 100644 --- a/areal/infra/launcher/local.py +++ b/areal/infra/launcher/local.py @@ -365,6 +365,7 @@ def local_main(config, run_id: int = 0): tms_env_vars = get_tms_env_vars() else: tms_env_vars = {} + # Launch trainer entrypoint if alloc_mode.type_ != AllocationType.LLM_SERVER_ONLY: gpu = nprocs = alloc_mode.train.world_size diff --git a/areal/infra/rpc/rpc_server.py b/areal/infra/rpc/rpc_server.py index 5e2a261c4..e0462d6ac 100644 --- a/areal/infra/rpc/rpc_server.py +++ b/areal/infra/rpc/rpc_server.py @@ -203,13 +203,15 @@ def fork_worker(): This endpoint spawns a new RPC server process as a child of this worker. The child inherits the same environment (including CUDA_VISIBLE_DEVICES) - but runs as an independent process with its own engine registry. + unless env overrides/unsets are explicitly requested. 
Expected JSON payload: { - "role": "ref", # Role name for the forked worker - "worker_index": 0, # Worker index - "command": "areal.infra.rpc.rpc_server" # Optional: custom module to run + "role": "ref", # Role name for the forked worker + "worker_index": 0, # Worker index + "command": "areal.infra.rpc.rpc_server", # Optional custom module to run + "env_overrides": {"FOO": "bar"}, # Optional env overrides for child + "unset_env_keys": ["LD_PRELOAD"] # Optional env keys removed from child } Returns: @@ -229,20 +231,27 @@ def fork_worker(): role = data.get("role") worker_index = data.get("worker_index") - command = data.get("command") # Optional custom module path + command = data.get("command") + env_overrides = data.get("env_overrides") or {} + unset_env_keys = data.get("unset_env_keys") or [] if role is None: return jsonify({"error": "Missing 'role' field in request"}), 400 if worker_index is None: return jsonify({"error": "Missing 'worker_index' field in request"}), 400 + if not isinstance(env_overrides, dict): + return jsonify({"error": "'env_overrides' must be a dictionary"}), 400 + if not isinstance(unset_env_keys, list) or not all( + isinstance(key, str) for key in unset_env_keys + ): + return jsonify({"error": "'unset_env_keys' must be a list[str]"}), 400 + if not all(isinstance(key, str) for key in env_overrides): + return jsonify({"error": "All env_overrides keys must be strings"}), 400 - # Allocate a free port for the child process ports = find_free_ports(1, exclude_ports=_allocated_ports) child_port = ports[0] _allocated_ports.add(child_port) - # Build command for child process - # Use custom module if specified, otherwise default to rpc_server module = command if command else "areal.infra.rpc.rpc_server" cmd = [ sys.executable, @@ -275,8 +284,6 @@ def fork_worker(): f"on port {child_port}" ) - # Build shell command with tee/sed for streaming logs to terminal and files - # This matches LocalScheduler's logging pattern log_dir = ( Path(_fileroot) / "logs" @@ -290,23 +297,32 @@ def fork_worker(): logger.info(f"Forked worker logs will be written to: {log_file}") - # Use streaming log utility for terminal, role log, and merged log output + child_env = os.environ.copy() + for key in unset_env_keys: + child_env.pop(key, None) + for key, value in env_overrides.items(): + child_env[key] = str(value) + + if unset_env_keys or env_overrides: + logger.info( + f"Fork child env patch for role '{role}' index {worker_index}: " + f"unset={unset_env_keys}, overrides={list(env_overrides.keys())}" + ) + child_process = run_with_streaming_logs( cmd, log_file, merged_log, role, - env=os.environ.copy(), + env=child_env, ) with _forked_children_lock: _forked_children.append(child_process) _forked_children_map[(role, worker_index)] = child_process - # Wait for child to be ready child_host = _server_host if not _wait_for_worker_ready(child_host, child_port): - # Cleanup on failure try: kill_process_tree(child_process.pid, timeout=3, graceful=True) except Exception: @@ -668,7 +684,16 @@ def call_engine_method(): def execute_in_engine_thread(): try: # Broadcast args when engine is a TrainEngine and has been initialized - if isinstance(engine, TrainEngine) and engine.initialized: + NON_COLLECTIVE_TRAIN_ENGINE_METHODS = { + "config_perf_tracer", + # More pure control methods can be added here later if needed. + } + should_broadcast = ( + isinstance(engine, TrainEngine) + and engine.initialized + and method_name not in NON_COLLECTIVE_TRAIN_ENGINE_METHODS + ) + if should_broadcast: logger.debug( f"Broadcasting data for TrainEngine method:
{method_name}" ) diff --git a/areal/infra/scheduler/local.py b/areal/infra/scheduler/local.py index 0dd1f5474..e053c5061 100644 --- a/areal/infra/scheduler/local.py +++ b/areal/infra/scheduler/local.py @@ -60,6 +60,17 @@ class WorkerInfo: log_file: str env_vars: dict[str, str] = field(default_factory=dict) +def _apply_env_patch( + base_env: dict[str, str], + env_overrides: dict[str, Any] | None = None, + unset_env_keys: list[str] | None = None, +) -> dict[str, str]: + env = dict(base_env) + for key in unset_env_keys or []: + env.pop(key, None) + for key, value in (env_overrides or {}).items(): + env[key] = str(value) + return env def _get_device_count_safely() -> int | None: """ @@ -258,6 +269,8 @@ async def _fork_single_worker( target_wi: WorkerInfo, target_role: str, command: str | None = None, + env_overrides: dict[str, Any] | None = None, + unset_env_keys: list[str] | None = None, ) -> WorkerInfo: """Fork a single worker asynchronously. @@ -266,6 +279,10 @@ async def _fork_single_worker( command : str, optional Custom module path to run instead of the default rpc_server. If specified, the forked process runs this module. + env_overrides : dict[str, Any], optional + Environment variables to override/add in the forked child. + unset_env_keys : list[str], optional + Environment variable names to remove from the forked child. """ worker_id = f"{role}/{idx}" target_url = ( @@ -273,9 +290,14 @@ async def _fork_single_worker( ) try: - payload = {"role": role, "worker_index": idx} + payload: dict[str, Any] = {"role": role, "worker_index": idx} if command is not None: payload["command"] = command + if env_overrides: + payload["env_overrides"] = env_overrides + if unset_env_keys: + payload["unset_env_keys"] = unset_env_keys + async with session.post( target_url, json=payload, @@ -330,7 +352,11 @@ async def _fork_single_worker( gpu_devices=target_wi.gpu_devices, # Inherited from target created_at=time.time(), log_file=str(self.log_dir / f"{role}.log"), - env_vars=target_wi.env_vars.copy(), # Inherited from target + env_vars=_apply_env_patch( + target_wi.env_vars, + env_overrides=env_overrides, + unset_env_keys=unset_env_keys, + ), ) async def _kill_forked_worker( @@ -399,6 +425,8 @@ async def _create_forked_workers_async( target_role: str, target_workers: list[WorkerInfo], command: str | None = None, + env_overrides: dict[str, Any] | None = None, + unset_env_keys: list[str] | None = None, ) -> list[str]: """Create forked workers concurrently using async requests. @@ -407,22 +435,31 @@ async def _create_forked_workers_async( command : str, optional Custom module path to run instead of the default rpc_server. If specified, the forked processes run this module. + env_overrides : dict[str, Any], optional + Environment variables to override/add in forked children. + unset_env_keys : list[str], optional + Environment variable names to remove from forked children. 
""" timeout = aiohttp.ClientTimeout(total=120.0) async with aiohttp.ClientSession( timeout=timeout, connector=get_default_connector(), ) as session: - # Launch all fork requests concurrently with exception handling tasks = [ self._fork_single_worker( - session, role, idx, target_wi, target_role, command + session, + role, + idx, + target_wi, + target_role, + command=command, + env_overrides=env_overrides, + unset_env_keys=unset_env_keys, ) for idx, target_wi in enumerate(target_workers) ] results = await asyncio.gather(*tasks, return_exceptions=True) - # Separate successful workers from failures workers = [] failed_indices = [] for idx, result in enumerate(results): @@ -434,13 +471,11 @@ async def _create_forked_workers_async( else: workers.append(result) - # If any fork failed, cleanup successful workers and raise if failed_indices: if workers: logger.warning( f"Cleaning up {len(workers)} successfully forked workers due to partial failure" ) - # Kill the forked processes via parent RPC servers try: await self._cleanup_forked_workers_async(role, target_role, workers) except Exception as cleanup_error: @@ -461,7 +496,6 @@ async def _create_forked_workers_async( f"created {len(workers)} new worker processes" ) - # Configure forked workers if exp_config is available if self.exp_config is not None: for worker_rank, worker_info in enumerate(workers): self._configure_worker(worker_info, worker_rank) @@ -473,6 +507,8 @@ def fork_workers( role: str, target_role: str, command: str | None = None, + env_overrides: dict[str, Any] | None = None, + unset_env_keys: list[str] | None = None, ) -> list[str]: """Fork new worker processes from existing workers. @@ -488,6 +524,10 @@ def fork_workers( command : str, optional Custom module path to run instead of the default rpc_server. If specified, the forked process runs this module. + env_overrides : dict[str, Any], optional + Environment variables to override/add in the forked child. + unset_env_keys : list[str], optional + Environment variable names to remove from the forked child. 
Returns ------- @@ -504,10 +544,11 @@ def fork_workers( role, target_role, target_workers, - command, + command=command, + env_overrides=env_overrides, + unset_env_keys=unset_env_keys, ) except Exception: - # Cleanup on failure if role in self._workers: del self._workers[role] if role in self._colocated_roles: @@ -555,6 +596,8 @@ def create_workers(self, job: Job, *args, **kwargs) -> list[str]: ) schedulings = self._prepare_worker_specs(role, num_workers, job.tasks) + fork_env_overrides = kwargs.get("fork_env_overrides") + fork_unset_env_keys = kwargs.get("fork_unset_env_keys") strategy = job.scheduling_strategy strategy_type = SchedulingStrategyType(strategy.type) @@ -589,7 +632,12 @@ def create_workers(self, job: Job, *args, **kwargs) -> list[str]: # Check if fork mode is enabled if strategy.fork: # Fork mode: spawn new processes on same GPUs via /fork endpoint - worker_ids = self.fork_workers(role, colocate_role) + worker_ids = self.fork_workers( + role, + colocate_role, + env_overrides=fork_env_overrides, + unset_env_keys=fork_unset_env_keys, + ) else: # Reuse existing workers - no new processes spawned worker_ids = [w.worker.id for w in target_workers] diff --git a/areal/trainer/rl_trainer.py b/areal/trainer/rl_trainer.py index 6014885c6..cda992592 100644 --- a/areal/trainer/rl_trainer.py +++ b/areal/trainer/rl_trainer.py @@ -3,6 +3,7 @@ import functools import os from collections.abc import Callable +from contextlib import contextmanager from copy import deepcopy from typing import TYPE_CHECKING, Any @@ -41,11 +42,12 @@ SlurmScheduler, current_platform, ) -from areal.utils import logging, perf_tracer, seeding, stats_tracker +from areal.utils import logging, name_resolve, names, perf_tracer, seeding, stats_tracker from areal.utils.dataloader import create_dataloader from areal.utils.environ import is_single_controller from areal.utils.evaluator import Evaluator from areal.utils.hf_utils import load_hf_processor_and_tokenizer +from areal.utils.offload import get_tms_env_vars from areal.utils.perf_tracer import Category from areal.utils.recover import RecoverHandler from areal.utils.saver import Saver @@ -105,6 +107,7 @@ def __init__( logging.setup_file_logging(StatsLogger.get_log_path(config.stats_logger)) self.config = config + self._colocated: bool = self._is_colocated_rollout(self.config.rollout) self.processor, self.tokenizer = load_hf_processor_and_tokenizer( config.tokenizer_path ) @@ -119,7 +122,7 @@ def __init__( self.allocation_mode = AllocationMode.from_str(config.allocation_mode) # Validate config before proceeding with weight initialization - self._validate_cfg() + self._validate_cfg(train_dataset) self._amend_xccl_weight_update_envvar() @@ -186,44 +189,70 @@ def __init__( "ft_spec": ft_spec, "alloc_mode": self.allocation_mode, } - self.actor.initialize(**engine_init_kwargs, role="actor") - if self.critic is not None: - self.critic.initialize(**engine_init_kwargs, role="critic") - if self.ref is not None: - self.ref.initialize(**engine_init_kwargs, role="ref") - self.teacher = None - if config.teacher is not None: - self.teacher = self._create_teacher(config.teacher) - teacher_allocation_mode = AllocationMode.from_str( - config.teacher.allocation_mode + if self._colocated: + # SPMD: actor first, then connect to existing SGLang/vLLM server + self.actor.initialize(**engine_init_kwargs, role="actor") + self.rollout = self._init_rollout( + config.rollout, is_eval=False, lora_path=None ) - teacher_init_kwargs = { - "addr": None, - "ft_spec": ft_spec, - "alloc_mode": 
teacher_allocation_mode, - } - self.teacher.initialize(**teacher_init_kwargs, role="teacher") - # Save initial LoRA weights if enabled (for inference server pre-loading) - initial_lora_path = self._save_initial_lora_weights() + # Save initial LoRA weights if needed. + # In colocated mode the rollout was started without LoRA, so the + # initial adapter must be synced to the inference engine later + # (after ColocatedOrchestrator is set up). + self._initial_lora_path = self._save_initial_lora_weights() - # Initialize inference with LoRA path - self.rollout = self._init_rollout( - config.rollout, is_eval=False, lora_path=initial_lora_path - ) - # Online mode detection: skip eval rollout for efficiency. - openai_cfg = config.rollout.openai - self._online_mode = train_dataset is None or ( - openai_cfg is not None and openai_cfg.mode == "online" - ) - - self.eval_rollout = None - if not self._online_mode: + # No critic / ref / teacher in colocated mode + self.teacher = None self.eval_rollout = self._init_rollout( - config.rollout, is_eval=True, lora_path=initial_lora_path + config.rollout, + is_eval=True, + lora_path=self._initial_lora_path, + ) + + else: + # Standard mode: original initialization order preserved exactly + self.actor.initialize(**engine_init_kwargs, role="actor") + if self.critic is not None: + self.critic.initialize(**engine_init_kwargs, role="critic") + if self.ref is not None: + self.ref.initialize(**engine_init_kwargs, role="ref") + + self.teacher = None + if config.teacher is not None: + self.teacher = self._create_teacher(config.teacher) + teacher_allocation_mode = AllocationMode.from_str( + config.teacher.allocation_mode + ) + teacher_init_kwargs = { + "addr": None, + "ft_spec": ft_spec, + "alloc_mode": teacher_allocation_mode, + } + self.teacher.initialize(**teacher_init_kwargs, role="teacher") + + # Save initial LoRA weights if enabled (for inference server pre-loading) + initial_lora_path = self._save_initial_lora_weights() + + # Initialize inference with LoRA path + self.rollout = self._init_rollout( + rollout_config=config.rollout, + is_eval=False, + lora_path=initial_lora_path, + ) + # Online mode detection: skip eval rollout for efficiency. + openai_cfg = config.rollout.openai + self._online_mode = train_dataset is None or ( + openai_cfg is not None and openai_cfg.mode == "online" ) + self.eval_rollout = None + if not self._online_mode: + self.eval_rollout = self._init_rollout( + config.rollout, is_eval=True, lora_path=initial_lora_path + ) + # Proxy worker initialization (lazy, for AgentWorkflow support) self._proxy_started = False @@ -268,7 +297,29 @@ def __init__( ) self.actor.connect_engine(self.rollout, self.weight_update_meta) - # Set up evaluation (skip in online mode) + # Initialize colocated orchestrator if enabled + self.colocated_orch = None + if self._colocated: + from areal.infra.colocated import ColocatedOrchestrator + + self.colocated_orch = ColocatedOrchestrator( + train_engine=self.actor, + inf_engine=self.rollout, + ) + + _initial_lora_meta = None + if self._initial_lora_path is not None: + _initial_lora_meta = self.weight_update_meta.with_version(0) + self._save_actor_weights_for_rollout(_initial_lora_meta) + + self.colocated_orch.initial_offload_training() + + if _initial_lora_meta is not None: + self.rollout.sync_weights_from_disk(_initial_lora_meta) + + logger.info("Colocated mode enabled via rollout.scheduling_strategy.") + + # Set up evaluation helpers. 
self.evaluator = Evaluator(config.evaluator, ft_spec) # Set up save as HF model @@ -278,16 +329,52 @@ def __init__( # Set up statistics logging (wandb, tensoboard, etc.) self.stats_logger = StatsLogger(config, ft_spec) + # In colocated mode, train stats must be exported before switching GPU + # ownership to inference, but the final commit should still happen + # after evaluation so eval metrics can be logged together. + self._pending_train_stats_for_commit: dict[str, float] | None = None + # Set up checkpointing for recover - self.recover_info = self.recover_handler.load( - self.actor, - self.saver, - self.evaluator, - self.stats_logger, - self.train_dataloader, - inference_engine=self.rollout, - weight_update_meta=self.weight_update_meta, - ) + if self._colocated: + # In colocated mode, the actor is already offloaded. The standard + # recover flow (update_engine.update_weights) cannot work because + # _update_weights_from_disk needs GPU parameters. Instead, skip + # inference_engine sync in recover_handler.load and handle it + # manually via the ColocatedOrchestrator. + self.recover_info = self.recover_handler.load( + self.actor, + self.saver, + self.evaluator, + self.stats_logger, + self.train_dataloader, + inference_engine=None, + weight_update_meta=None, + ) + if self.recover_info is not None: + # Recovered from checkpoint — sync weights to inference engine. + # The actor is offloaded; onload it, save weights, offload + # again, and update inference. + assert self.colocated_orch is not None + global_step = self.recover_info.last_step_info.global_step + recovery_version = global_step + 1 + versioned_meta = self.weight_update_meta.with_version(recovery_version) + # save() must be called while actor is on GPU; onload first. + self.colocated_orch.prepare_for_training() + self._save_actor_weights_for_rollout(versioned_meta) + self.actor.set_version(recovery_version) + self.rollout.set_version(recovery_version) + # offload training, onload inference, and load new weights + self.colocated_orch.prepare_for_inference(versioned_meta) + else: + self.recover_info = self.recover_handler.load( + self.actor, + self.saver, + self.evaluator, + self.stats_logger, + self.train_dataloader, + inference_engine=self.rollout, + weight_update_meta=self.weight_update_meta, + ) self._config_perf_tracer() @@ -347,6 +434,8 @@ def train( "epoch_step": step, }, ), + self.actor.prepare_batch_context(), + self._colocated_prepare_batch_context(global_step), ): rollout_batch = self.actor.prepare_batch( self.train_dataloader, @@ -457,7 +546,8 @@ def train( self.critic.get_device_stats().log("ppo critic update") # pause inference for updating weights, save, and evaluation - self.rollout.pause() + if not self._colocated: + self.rollout.pause() with ( stats_tracker.record_timing("update_weights"), @@ -467,18 +557,29 @@ def train( args={"global_step": global_step}, ), ): - # Use versioned path for weight updates new_version = global_step + 1 versioned_meta = self.weight_update_meta.with_version(new_version) - self.actor.update_weights(versioned_meta) + + if self._colocated: + # Colocated mode: save weights to the versioned disk path + # while training engine is still on GPU, then publish the + # rendezvous signal expected by remote rollout workers. 
+ self._save_actor_weights_for_rollout(versioned_meta) + else: + # Standard mode: use FSDP's update_weights (xccl or disk) + self.actor.update_weights(versioned_meta) self.actor.set_version(new_version) if self.critic is not None: self.critic.set_version(new_version) - self.rollout.set_version(new_version) - if self.eval_rollout is not None: - self.eval_rollout.set_version(new_version) + if not self._colocated: + self.rollout.set_version(new_version) + if self.eval_rollout is not None: + self.eval_rollout.set_version(new_version) + + # In colocated mode, save HF and recover checkpoint BEFORE switching + # to inference, since these operations need the train engine on GPU. with ( stats_tracker.record_timing("save"), perf_tracer.trace_scope( @@ -501,6 +602,38 @@ def train( epoch=epoch, epoch_step=step, global_step=global_step ) + with ( + stats_tracker.record_timing("clear_batches"), + perf_tracer.trace_scope( + "train.clear_batches", + category=Category.INSTR, + args={"global_step": global_step}, + ), + ): + # Since all RTensor objects are affiliated IPs, + # calling `clear_batches` once should be sufficient. + self.actor.clear_batches(rollout_batch, adv_batch) + + if self._colocated: + self._capture_train_stats_snapshot() + + # === Colocated mode: switch from training to inference === + if self._colocated: + assert self.colocated_orch is not None + with ( + stats_tracker.record_timing("colocated_switch_to_inference"), + perf_tracer.trace_scope( + "train.colocated_switch_to_inference", + category=Category.COMM, + args={"global_step": global_step}, + ), + ): + self.colocated_orch.prepare_for_inference(versioned_meta) + + self.rollout.set_version(new_version) + if self.eval_rollout is not None: + self.eval_rollout.set_version(new_version) + with ( stats_tracker.record_timing("eval"), perf_tracer.trace_scope( @@ -517,18 +650,6 @@ def train( global_step=global_step, ) - with ( - stats_tracker.record_timing("clear_batches"), - perf_tracer.trace_scope( - "train.clear_batches", - category=Category.INSTR, - args={"global_step": global_step}, - ), - ): - # Since all RTensor objects are affiliated IPs, - # calling `clear_batches` once should be sufficient. 
- self.actor.clear_batches(rollout_batch, adv_batch) - with perf_tracer.trace_scope( "train.log_stats", category=Category.INSTR, @@ -553,6 +674,10 @@ def close(self): self.ref.destroy() if self.critic is not None: self.critic.destroy() + + if self._colocated and getattr(self.actor, "is_offload", False): + self.actor.onload() + self.actor.destroy() perf_tracer.save(force=True) @@ -587,6 +712,86 @@ def _save_perf_tracer(self, step: int): self.rollout.save_perf_tracer(step=step) perf_tracer.save(step=step) + @staticmethod + def _is_colocated_rollout(rollout_config: InferenceEngineConfig) -> bool: + strategy = getattr(rollout_config, "scheduling_strategy", None) + return ( + strategy is not None + and getattr(strategy, "type", None) == SchedulingStrategyType.colocation + and getattr(strategy, "target", None) == "actor" + ) + + @contextmanager + def _colocated_prepare_batch_context(self, global_step: int): + if self.colocated_orch is None: + yield + return + + try: + yield + except Exception: + raise + else: + with ( + stats_tracker.record_timing("colocated_switch_to_train"), + perf_tracer.trace_scope( + "train.colocated_switch_to_train", + category=Category.COMM, + args={"global_step": global_step}, + ), + ): + self.colocated_orch.prepare_for_training() + + def _save_actor_weights_for_rollout(self, meta: WeightUpdateMeta) -> None: + if meta.type != "disk": + raise ValueError( + "Colocated rollout sync only supports disk-based weight updates. " + f"Got '{meta.type}'." + ) + + self.actor.save( + SaveLoadMeta( + path=meta.path, + weight_format="hf", + with_optim=False, + tokenizer=self.tokenizer, + processor=self.processor, + ) + ) + self._publish_disk_weight_update_ready(meta) + + def _publish_disk_weight_update_ready(self, meta: WeightUpdateMeta) -> None: + import time + + if meta.version is None: + raise ValueError("Colocated disk weight sync requires meta.version.") + + if not dist.is_initialized(): + update_name = names.update_weights_from_disk( + self.config.experiment_name, + self.config.trial_name, + meta.version, + ) + name_resolve.add( + update_name, + str(time.time()), + keepalive_ttl=120, + ) + return + + if dist.get_rank() == 0: + update_name = names.update_weights_from_disk( + self.config.experiment_name, + self.config.trial_name, + meta.version, + ) + name_resolve.add( + update_name, + str(time.time()), + keepalive_ttl=120, + ) + dist.barrier(group=self.actor.cpu_group) + def _init_scheduler(self) -> Scheduler: cfg = self.config.scheduler if cfg.type == "local": @@ -615,6 +820,13 @@ def _amend_xccl_weight_update_envvar(self): if not is_single_controller(): # These environs are set by the launcher in the SPMD mode. return + + tms_env_vars = None + if self.config.enable_offload or self._colocated: + tms_env_vars = get_tms_env_vars() + for spec in self.config.actor.scheduling_spec: + spec.env_vars.update(tms_env_vars) + if self.allocation_mode.gen_backend != "sglang": return @@ -709,11 +921,10 @@ def _init_rollout( is_eval: bool = False, lora_path: str | None = None, ) -> InferenceEngine | RolloutController: - if lora_path is not None and not is_single_controller(): + if lora_path is not None and not is_single_controller() and not self._colocated: raise ValueError( - "LoRA is only supported in single-controller mode. " - "Use `python3 train.py scheduler.type=local` instead of " - "`python3 -m areal.infra.launcher.local`." + "LoRA is only supported in single-controller mode or when rollout is " + "colocated with actor via rollout.scheduling_strategy and actor.weight_update_mode=disk." 
) # Create a working copy of config config = deepcopy(rollout_config) @@ -914,19 +1125,41 @@ def _evaluate( dist.barrier(group=self.actor.cpu_group) current_platform.synchronize() - def _export_and_commit_stats(self, epoch: int, epoch_step: int, global_step: int): - # Upload statistics to the logger (e.g., wandb) + def _capture_train_stats_snapshot(self) -> None: + """Capture train-side stats before colocated eval takes over the GPU.""" + if self._pending_train_stats_for_commit is not None: + logger.warning( + "Overwriting pending train stats snapshot before previous commit." + ) stats = self.actor.export_stats() stats.update(self.rollout.export_stats()) + self._pending_train_stats_for_commit = stats + + def _export_eval_stats_snapshot(self) -> dict[str, float]: + stats: dict[str, float] = {} if self.eval_rollout is not None: stats.update(self.eval_rollout.export_stats()) + return stats + + def _export_and_commit_stats(self, epoch: int, epoch_step: int, global_step: int): + # Upload statistics to the logger (e.g., wandb) + pending_train_stats = self._pending_train_stats_for_commit + self._pending_train_stats_for_commit = None + + if pending_train_stats is not None: + stats = dict(pending_train_stats) + else: + stats = self.actor.export_stats() + stats.update(self.rollout.export_stats()) + + stats.update(self._export_eval_stats_snapshot()) self.stats_logger.commit(epoch, epoch_step, global_step, stats) dist.barrier(group=self.actor.cpu_group) current_platform.synchronize() - def _validate_cfg(self): - """validate config for incompatible settings before weight initialization, to avoid wasted resources on spawning workers and loading models.""" + def _validate_cfg(self, train_dataset: Dataset | None): + """Validate config before weight initialization to fail fast on unsupported setups.""" if ( self.allocation_mode.gen_backend == "vllm" and self.config.rollout.return_routed_experts @@ -936,6 +1169,47 @@ def _validate_cfg(self): "Please disable return_routed_experts or switch to SGLang backend." ) + if not self._colocated: + return + + if self.config.cluster.n_nodes != 1: + raise ValueError( + "Colocated mode only supports single-node runs. " + f"Got cluster.n_nodes={self.config.cluster.n_nodes}." + ) + + if self.config.actor.weight_update_mode != "disk": + raise ValueError( + "Colocated mode requires actor.weight_update_mode='disk'. " + f"Got '{self.config.actor.weight_update_mode}'." + ) + + if train_dataset is None: + raise ValueError( + "Colocated mode does not support online training and requires a train_dataset." + ) + + openai_cfg = self.config.rollout.openai + if openai_cfg is not None and openai_cfg.mode == "online": + raise ValueError( + "Colocated mode does not support rollout.openai.mode='online'." + ) + + if self.config.critic is not None: + raise ValueError( + "Colocated mode only supports actor-only training: critic is not supported." + ) + + if self.config.ref is not None or self.config.actor.kl_ctl > 0: + raise ValueError( + "Colocated mode only supports actor-only training: ref/kl_ctl is not supported." + ) + + if self.config.teacher is not None: + raise ValueError( + "Colocated mode only supports actor-only training: teacher is not supported." + ) + def _requires_proxy_workflow(self, workflow: WorkflowLike | None) -> bool: """Check if workflow requires proxy workers (i.e., not a RolloutWorkflow). 
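Note: to summarize the control flow added to `train()` and `__init__` above, one colocated step hands the GPUs back and forth in a fixed order. A simplified sketch using names from this patch (timing scopes, checkpointing, and error handling omitted):

# After the PPO update, while the actor still owns the GPU:
new_version = global_step + 1
versioned_meta = self.weight_update_meta.with_version(new_version)
# HF-format save to the versioned disk path, plus the name_resolve ready
# signal that remote rollout workers wait on:
self._save_actor_weights_for_rollout(versioned_meta)
self.actor.set_version(new_version)
# Hand the GPU to inference: offload actor, onload rollout, disk weight sync.
self.colocated_orch.prepare_for_inference(versioned_meta)
self.rollout.set_version(new_version)
# The reverse hand-off is deferred: _colocated_prepare_batch_context() calls
# colocated_orch.prepare_for_training() only after the next rollout batch has
# been collected, which pauses generation, offloads the inference engine, and
# onloads the actor again.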
diff --git a/areal/utils/stats_tracker.py b/areal/utils/stats_tracker.py index 1bcf625fb..167051aab 100644 --- a/areal/utils/stats_tracker.py +++ b/areal/utils/stats_tracker.py @@ -198,7 +198,9 @@ def _aggregate(self, key, reduce_group): elif reduce_type == ReduceType.MAX: result[key] = self._max_of(key, reduce_group) elif reduce_type == ReduceType.SCALAR: - if current_platform.is_initialized(): + if reduce_group is None: + device = "cpu" + elif current_platform.is_initialized(): device = current_platform.device_type else: device = "cpu" diff --git a/examples/math/gsm8k_grpo_colocated.yaml b/examples/math/gsm8k_grpo_colocated.yaml new file mode 100644 index 000000000..f94ead3d8 --- /dev/null +++ b/examples/math/gsm8k_grpo_colocated.yaml @@ -0,0 +1,161 @@ +# GSM8K GRPO with colocated (GPU time-sharing) training + +experiment_name: gsm8k-grpo-colocated +trial_name: trial0 + +seed: 1 +enable_offload: true # Enables torch_memory_saver for train-side offload/onload +total_train_epochs: 1 +tokenizer_path: ${actor.path} + +cluster: + n_nodes: 1 + n_gpus_per_node: 8 + fileroot: /tmp/areal/experiments + name_resolve: + type: nfs + nfs_record_root: /tmp/areal/name_resolve + +# Training and inference share the same 8 GPUs +allocation_mode: sglang[rollout]:d8p1t1|fsdp[actor]:d8p1t1 + +rollout: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + max_concurrent_rollouts: 256 + queue_size: null + consumer_batch_size: ${train_dataset.batch_size} + max_head_offpolicyness: 0 # On-policy: no off-policyness allowed + scheduling_strategy: + type: colocation + target: actor + scheduling_spec: ${actor.scheduling_spec} + enable_rollout_tracing: false + fileroot: ${cluster.fileroot} + tokenizer_path: ${tokenizer_path} + dump_to_file: true + +gconfig: + n_samples: 4 + min_new_tokens: 0 + max_new_tokens: 1024 + greedy: false + temperature: 1.0 + +actor: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + path: /data/data/models--Qwen--Qwen2.5-1.5B-Instruct/snapshots/989aa7980e4cf806f80c7fef2b1adb7bc71aa306 + init_from_scratch: false + disable_dropout: true + gradient_checkpointing: true + dtype: bfloat16 + mb_spec: + max_tokens_per_mb: 10240 + optimizer: + type: adam + lr: 1.70e-5 + weight_decay: 0.017 + beta1: 0.9 + beta2: 0.999 + eps: 1e-8 + lr_scheduler_type: constant + gradient_clipping: 1.0 + warmup_steps_proportion: 0.001 + eps_clip: 0.4 + temperature: ${gconfig.temperature} + reward_scaling: 10.0 + reward_bias: -0.5 + kl_ctl: 0.0 + ppo_n_minibatches: 1 + recompute_logprob: true + weight_update_mode: disk # Colocated mode uses ordinary disk sync under cluster.fileroot + max_new_tokens: ${gconfig.max_new_tokens} + scheduling_spec: + - task_type: worker + port_count: 2 + gpu: 1 + mem: 32 + cmd: python3 -m areal.infra.rpc.rpc_server + env_vars: {} + +# No reference model in colocated mode to save GPU memory +ref: null + +# SGLang configuration +sglang: + model_path: ${actor.path} + random_seed: ${seed} + skip_tokenizer_init: true + dtype: ${actor.dtype} + max_running_requests: null + context_length: 32768 + mem_fraction_static: 0.8 + enable_memory_saver: true # Enable memory saver for offload/onload support + +vllm: + model: ${actor.path} + seed: ${seed} + skip_tokenizer_init: false + dtype: ${actor.dtype} + max_model_len: 32768 + gpu_memory_utilization: 0.8 + enable_sleep_mode: true # Enable sleep mode for offload/onload support + +# Datasets +train_dataset: + batch_size: 256 + shuffle: true + pin_memory: true + num_workers: 4 + path: 
/data/data/datasets--openai--gsm8k/snapshots/cc7b047b6e5bb11b4f1af84efc572db110a51b3c + type: rl + max_length: 1024 + +valid_dataset: + batch_size: 256 + pin_memory: true + num_workers: 4 + path: /data/data/datasets--openai--gsm8k/snapshots/cc7b047b6e5bb11b4f1af84efc572db110a51b3c + type: rl + +# Utilities +saver: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: null + +recover: + mode: disabled + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: 3600 + +evaluator: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: null + +stats_logger: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + swanlab: + mode: cloud + +perf_tracer: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + enabled: false + session_tracer: + enabled: false diff --git a/tests/test_colocated_engine.py b/tests/test_colocated_engine.py new file mode 100644 index 000000000..fc0fbbce1 --- /dev/null +++ b/tests/test_colocated_engine.py @@ -0,0 +1,582 @@ +"""Unit tests for colocated orchestration and scheduler-driven trainer behavior.""" + +from __future__ import annotations + +import asyncio +import shutil +import tempfile +from pathlib import Path +from types import SimpleNamespace +from typing import Any, cast +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from areal.api.cli_args import SchedulingStrategy, SchedulingStrategyType +from areal.api.io_struct import WeightUpdateMeta +from areal.infra.colocated import ColocatedOrchestrator +from areal.infra.controller.rollout_controller import RolloutController +from areal.infra.controller.train_controller import TrainController +from areal.infra.remote_inf_engine import RemoteInfEngine +from areal.infra.scheduler.local import LocalScheduler, _apply_env_patch +from areal.trainer.rl_trainer import PPOTrainer +from areal.utils import names + + +@pytest.fixture +def mock_train_engine(): + engine = MagicMock() + engine.offload = MagicMock() + engine.onload = MagicMock() + return engine + + +@pytest.fixture +def mock_inf_engine(): + engine = MagicMock() + engine.pause = MagicMock() + engine.resume = MagicMock() + engine.pause_generation = MagicMock() + engine.continue_generation = MagicMock() + engine.offload = MagicMock() + engine.onload = MagicMock() + engine.sync_weights_from_disk = MagicMock() + return engine + + +@pytest.fixture +def orchestrator(mock_train_engine, mock_inf_engine): + return ColocatedOrchestrator( + train_engine=mock_train_engine, + inf_engine=mock_inf_engine, + ) + + +class TestColocatedOrchestrator: + def test_initial_state(self, orchestrator): + assert orchestrator._train_on_gpu is True + assert orchestrator._inf_on_gpu is True + + def test_initial_offload_training(self, orchestrator, mock_train_engine): + orchestrator.initial_offload_training() + + mock_train_engine.offload.assert_called_once() + assert orchestrator._train_on_gpu is False + assert orchestrator._inf_on_gpu is True + + def test_prepare_for_training_switches_gpu_owner( + self, orchestrator, mock_train_engine, mock_inf_engine + ): + orchestrator.initial_offload_training() + mock_train_engine.offload.reset_mock() + + with patch("areal.infra.colocated.dist.is_initialized", return_value=False): 
+ orchestrator.prepare_for_training() + + mock_inf_engine.pause.assert_called_once() + mock_inf_engine.pause_generation.assert_called_once() + mock_inf_engine.offload.assert_called_once() + mock_train_engine.onload.assert_called_once() + assert orchestrator._train_on_gpu is True + assert orchestrator._inf_on_gpu is False + + def test_prepare_for_training_orders_local_and_remote_rollout_shutdown( + self, orchestrator, mock_train_engine, mock_inf_engine + ): + events: list[str] = [] + mock_inf_engine.pause.side_effect = lambda: events.append("pause") + mock_inf_engine.pause_generation.side_effect = lambda: events.append( + "pause_generation" + ) + mock_inf_engine.offload.side_effect = lambda: events.append("inf_offload") + mock_train_engine.onload.side_effect = lambda: events.append("train_onload") + + orchestrator.initial_offload_training() + events.clear() + + with patch("areal.infra.colocated.dist.is_initialized", return_value=False): + orchestrator.prepare_for_training() + + assert events == [ + "pause", + "pause_generation", + "inf_offload", + "train_onload", + ] + + def test_prepare_for_training_only_coordinator_controls_shared_rollout_server( + self, orchestrator, mock_train_engine, mock_inf_engine + ): + orchestrator.initial_offload_training() + mock_train_engine.offload.reset_mock() + + with ( + patch("areal.infra.colocated.dist.is_initialized", return_value=True), + patch("areal.infra.colocated.dist.get_rank", return_value=3), + patch("areal.infra.colocated.dist.barrier") as mock_barrier, + ): + orchestrator.prepare_for_training() + + mock_inf_engine.pause.assert_called_once() + mock_inf_engine.pause_generation.assert_not_called() + mock_inf_engine.offload.assert_not_called() + mock_train_engine.onload.assert_called_once() + mock_barrier.assert_called() + assert orchestrator._train_on_gpu is True + assert orchestrator._inf_on_gpu is False + + def test_prepare_for_inference_switches_gpu_owner_and_syncs_weights( + self, orchestrator, mock_train_engine, mock_inf_engine + ): + orchestrator.initial_offload_training() + with patch("areal.infra.colocated.dist.is_initialized", return_value=False): + orchestrator.prepare_for_training() + mock_inf_engine.pause.reset_mock() + mock_inf_engine.pause_generation.reset_mock() + mock_inf_engine.offload.reset_mock() + mock_inf_engine.onload.reset_mock() + mock_inf_engine.sync_weights_from_disk.reset_mock() + mock_inf_engine.continue_generation.reset_mock() + mock_inf_engine.resume.reset_mock() + mock_train_engine.onload.reset_mock() + mock_train_engine.offload.reset_mock() + + meta = WeightUpdateMeta(type="disk", path="/tmp/weight_update_v1") + with patch("areal.infra.colocated.dist.is_initialized", return_value=False): + orchestrator.prepare_for_inference(meta) + + mock_train_engine.offload.assert_called_once() + mock_inf_engine.onload.assert_called_once() + mock_inf_engine.sync_weights_from_disk.assert_called_once_with(meta) + mock_inf_engine.continue_generation.assert_called_once() + mock_inf_engine.resume.assert_called_once() + assert orchestrator._train_on_gpu is False + assert orchestrator._inf_on_gpu is True + + def test_prepare_for_inference_only_coordinator_controls_shared_rollout_server( + self, orchestrator, mock_train_engine, mock_inf_engine + ): + orchestrator.initial_offload_training() + with patch("areal.infra.colocated.dist.is_initialized", return_value=False): + orchestrator.prepare_for_training() + mock_train_engine.offload.reset_mock() + mock_inf_engine.onload.reset_mock() + mock_inf_engine.sync_weights_from_disk.reset_mock() + 
mock_inf_engine.continue_generation.reset_mock() + mock_inf_engine.resume.reset_mock() + + meta = WeightUpdateMeta(type="disk", path="/tmp/weight_update_v3") + with patch( + "areal.infra.colocated.dist.is_initialized", return_value=True + ): + with patch("areal.infra.colocated.dist.get_rank", return_value=5): + with patch("areal.infra.colocated.dist.barrier") as mock_barrier: + orchestrator.prepare_for_inference(meta) + + mock_train_engine.offload.assert_called_once() + mock_inf_engine.onload.assert_not_called() + mock_inf_engine.sync_weights_from_disk.assert_not_called() + mock_inf_engine.continue_generation.assert_not_called() + mock_inf_engine.resume.assert_called_once() + mock_barrier.assert_called() + assert orchestrator._train_on_gpu is False + assert orchestrator._inf_on_gpu is True + + def test_prepare_calls_are_idempotent( + self, orchestrator, mock_train_engine, mock_inf_engine + ): + orchestrator.initial_offload_training() + with patch("areal.infra.colocated.dist.is_initialized", return_value=False): + orchestrator.prepare_for_training() + orchestrator.prepare_for_training() + + mock_inf_engine.pause.assert_called_once() + mock_inf_engine.pause_generation.assert_called_once() + mock_inf_engine.offload.assert_called_once() + mock_train_engine.onload.assert_called_once() + + +class TestTrainControllerColocatedInterfaces: + def test_offload_updates_state_and_dispatches(self): + controller = TrainController.__new__(TrainController) + controller._custom_function_call = MagicMock() + controller.is_offload = False + + controller.offload() + + controller._custom_function_call.assert_called_once_with("offload") + assert controller.is_offload is True + + def test_onload_updates_state_and_dispatches(self): + controller = TrainController.__new__(TrainController) + controller._custom_function_call = MagicMock() + controller.is_offload = True + + controller.onload() + + controller._custom_function_call.assert_called_once_with("onload") + assert controller.is_offload is False + + def test_prepare_batch_context_is_noop(self): + controller = TrainController.__new__(TrainController) + + with controller.prepare_batch_context(): + pass + + +class TestRolloutControllerColocatedInterfaces: + def test_sync_weights_from_disk_uses_run_async_task(self): + controller = RolloutController.__new__(RolloutController) + meta = WeightUpdateMeta(type="disk", path="/tmp/weight_update_v2") + + with patch( + "areal.infra.controller.rollout_controller.run_async_task" + ) as mock_run_async_task: + controller.sync_weights_from_disk(meta) + + mock_run_async_task.assert_called_once_with( + controller.update_weights_from_disk, meta + ) + + def test_pause_generation_and_continue_generation_use_run_async_task(self): + controller = RolloutController.__new__(RolloutController) + + with patch( + "areal.infra.controller.rollout_controller.run_async_task" + ) as mock_run_async_task: + controller.pause_generation() + controller.continue_generation() + + assert mock_run_async_task.call_count == 2 + assert mock_run_async_task.call_args_list[0].args[0].__name__ == "_pause_generation_async" + assert ( + mock_run_async_task.call_args_list[1].args[0].__name__ + == "_continue_generation_async" + ) + + def test_offload_and_onload_delegate_to_collective_rpc(self): + controller = RolloutController.__new__(RolloutController) + controller._collective_rpc = MagicMock() + + controller.offload() + controller.onload(tags=["lora"]) + + assert controller._collective_rpc.call_args_list == [ + (("offload",), {"http_timeout": 60.0}), + (("onload",), 
+    def test_update_weights_from_disk_does_not_mutate_original_meta(self):
+        controller = RolloutController.__new__(RolloutController)
+        controller._collective_rpc_async = AsyncMock()
+
+        temp_dir = Path(tempfile.mkdtemp(prefix="areal-colocated-test-"))
+        try:
+            meta = WeightUpdateMeta(
+                type="disk",
+                path=str(temp_dir),
+                clear_checkpoint_after_load=True,
+            )
+
+            asyncio.run(controller.update_weights_from_disk(meta))
+
+            assert meta.clear_checkpoint_after_load is True
+            assert not temp_dir.exists()
+            await_args = controller._collective_rpc_async.await_args
+            assert await_args is not None
+            sent_meta = await_args.kwargs["meta"]
+            assert sent_meta.clear_checkpoint_after_load is False
+            assert sent_meta.path == meta.path
+        finally:
+            shutil.rmtree(temp_dir, ignore_errors=True)
+
+    def test_build_create_worker_kwargs_clears_tms_env_for_actor_colocated_fork(self):
+        controller = RolloutController.__new__(RolloutController)
+        job = SimpleNamespace(
+            scheduling_strategy=SchedulingStrategy(
+                type=SchedulingStrategyType.colocation,
+                target="actor",
+                fork=True,
+            )
+        )
+
+        assert controller._build_create_worker_kwargs(job) == {
+            "fork_unset_env_keys": [
+                "LD_PRELOAD",
+                "TMS_INIT_ENABLE",
+                "TMS_INIT_ENABLE_CPU_BACKUP",
+            ]
+        }
+
+    @pytest.mark.parametrize(
+        "strategy",
+        [
+            SchedulingStrategy(
+                type=SchedulingStrategyType.colocation,
+                target="actor",
+                fork=False,
+            ),
+            SchedulingStrategy(
+                type=SchedulingStrategyType.colocation,
+                target="critic",
+                fork=True,
+            ),
+            SchedulingStrategy(type=SchedulingStrategyType.separation),
+        ],
+    )
+    def test_build_create_worker_kwargs_skips_non_actor_fork_colocation(self, strategy):
+        controller = RolloutController.__new__(RolloutController)
+        job = SimpleNamespace(scheduling_strategy=strategy)
+
+        assert controller._build_create_worker_kwargs(job) == {}
+
+
+class TestLocalSchedulerForkEnvPatch:
+    def test_apply_env_patch_unsets_and_overrides(self):
+        patched = _apply_env_patch(
+            {
+                "LD_PRELOAD": "/tmp/libtms.so",
+                "TMS_INIT_ENABLE": "1",
+                "KEEP": "yes",
+            },
+            env_overrides={"NEW_KEY": "value"},
+            unset_env_keys=["LD_PRELOAD", "TMS_INIT_ENABLE", "MISSING_KEY"],
+        )
+
+        assert patched == {
+            "KEEP": "yes",
+            "NEW_KEY": "value",
+        }
+
+    def test_create_workers_forwards_fork_env_cleanup(self):
+        scheduler = LocalScheduler.__new__(LocalScheduler)
+        scheduler._workers = cast(Any, {"actor": [object()]})
+        scheduler._colocated_roles = {}
+        scheduler._prepare_worker_specs = MagicMock(return_value=[SimpleNamespace()])
+        scheduler.fork_workers = MagicMock(return_value=["rollout/0"])
+
+        job = SimpleNamespace(
+            role="rollout",
+            replicas=1,
+            tasks=[SimpleNamespace()],
+            scheduling_strategy=SchedulingStrategy(
+                type=SchedulingStrategyType.colocation,
+                target="actor",
+                fork=True,
+            ),
+        )
+
+        worker_ids = scheduler.create_workers(
+            job=job,
+            fork_unset_env_keys=[
+                "LD_PRELOAD",
+                "TMS_INIT_ENABLE",
+                "TMS_INIT_ENABLE_CPU_BACKUP",
+            ],
+        )
+
+        assert worker_ids == ["rollout/0"]
+        scheduler.fork_workers.assert_called_once_with(
+            "rollout",
+            "actor",
+            env_overrides=None,
+            unset_env_keys=[
+                "LD_PRELOAD",
+                "TMS_INIT_ENABLE",
+                "TMS_INIT_ENABLE_CPU_BACKUP",
+            ],
+        )
+        assert scheduler._colocated_roles["rollout"] == "actor"
+
+
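+# Minimal PPOTrainer stub for the validation tests below: it populates just
+# enough config for _validate_cfg and the colocated-scheduling helpers, and
+# constructs no real engines or processes.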
+def _make_validation_trainer(
+    *,
+    colocated: bool = True,
+    weight_update_mode: str = "disk",
+) -> Any:
+    trainer = cast(Any, PPOTrainer.__new__(PPOTrainer))
+    trainer.allocation_mode = SimpleNamespace(gen_backend="sglang")
+    scheduling_strategy = SchedulingStrategy(
+        type=(
+            SchedulingStrategyType.colocation
+            if colocated
+            else SchedulingStrategyType.separation
+        ),
+        target="actor" if colocated else None,
+    )
+    trainer._colocated = colocated
+    trainer.config = SimpleNamespace(
+        enable_offload=False,
+        actor=SimpleNamespace(
+            kl_ctl=0,
+            weight_update_mode=weight_update_mode,
+            scheduling_spec=[SimpleNamespace(env_vars={})],
+        ),
+        rollout=SimpleNamespace(
+            return_routed_experts=False,
+            openai=None,
+            scheduling_strategy=scheduling_strategy,
+            scheduling_spec=[SimpleNamespace(env_vars={})],
+        ),
+        critic=None,
+        ref=None,
+        teacher=None,
+        cluster=SimpleNamespace(n_nodes=1),
+        experiment_name="gsm8k-grpo-colocated",
+        trial_name="trial0",
+    )
+    return trainer
+
+
+class TestPPOTrainerColocatedScheduling:
+    def test_is_colocated_rollout_detects_actor_colocation(self):
+        rollout_cfg = SimpleNamespace(
+            scheduling_strategy=SchedulingStrategy(
+                type=SchedulingStrategyType.colocation,
+                target="actor",
+            )
+        )
+
+        assert cast(Any, PPOTrainer)._is_colocated_rollout(rollout_cfg) is True
+
+    def test_is_colocated_rollout_rejects_other_topologies(self):
+        rollout_cfg = SimpleNamespace(
+            scheduling_strategy=SchedulingStrategy(
+                type=SchedulingStrategyType.colocation,
+                target="critic",
+            )
+        )
+
+        assert cast(Any, PPOTrainer)._is_colocated_rollout(rollout_cfg) is False
+
+    def test_validate_cfg_allows_single_controller(self):
+        trainer = _make_validation_trainer()
+
+        with patch("areal.trainer.rl_trainer.is_single_controller", return_value=True):
+            trainer._validate_cfg(train_dataset=object())
+
+    def test_validate_cfg_rejects_multi_node(self):
+        trainer = _make_validation_trainer()
+        trainer.config.cluster.n_nodes = 2
+
+        with pytest.raises(ValueError, match="single-node runs"):
+            trainer._validate_cfg(train_dataset=object())
+
+    def test_validate_cfg_rejects_non_disk_weight_update(self):
+        trainer = _make_validation_trainer(weight_update_mode="xccl")
+
+        with pytest.raises(ValueError, match="weight_update_mode='disk'"):
+            trainer._validate_cfg(train_dataset=object())
+
+    def test_validate_cfg_rejects_missing_train_dataset(self):
+        trainer = _make_validation_trainer()
+
+        with pytest.raises(ValueError, match="requires a train_dataset"):
+            trainer._validate_cfg(train_dataset=None)
+
+    def test_validate_cfg_rejects_online_mode(self):
+        trainer = _make_validation_trainer()
+        trainer.config.rollout.openai = SimpleNamespace(mode="online")
+
+        with pytest.raises(ValueError, match="rollout.openai.mode='online'"):
+            trainer._validate_cfg(train_dataset=object())
+
+    @pytest.mark.parametrize(
+        ("mutate", "expected_error"),
+        [
+            (
+                lambda trainer: setattr(trainer.config, "critic", object()),
+                "critic is not supported",
+            ),
+            (
+                lambda trainer: setattr(trainer.config, "ref", object()),
+                "ref/kl_ctl is not supported",
+            ),
+            (
+                lambda trainer: setattr(trainer.config.actor, "kl_ctl", 0.1),
+                "ref/kl_ctl is not supported",
+            ),
+            (
+                lambda trainer: setattr(trainer.config, "teacher", object()),
+                "teacher is not supported",
+            ),
+        ],
+    )
+    def test_validate_cfg_rejects_non_actor_only_components(
+        self, mutate, expected_error
+    ):
+        trainer = _make_validation_trainer()
+        mutate(trainer)
+
+        with pytest.raises(ValueError, match=expected_error):
+            trainer._validate_cfg(train_dataset=object())
+
+    def test_validate_cfg_skips_colocated_restrictions_for_standard_mode(self):
+        trainer = _make_validation_trainer(colocated=False)
+        trainer.config.cluster.n_nodes = 4
+        trainer.config.rollout.openai = SimpleNamespace(mode="online")
+        trainer.config.critic = object()
+        trainer.config.ref = object()
+        trainer.config.teacher = object()
+
+        trainer._validate_cfg(train_dataset=None)
+
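+    # LD_PRELOAD/TMS_INIT_* come from get_tms_env_vars and are presumably the
+    # torch_memory_saver hooks: they should reach the colocated actor's
+    # environment but never the rollout servers' (see the fork env cleanup
+    # tests above).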
+    def test_amend_xccl_weight_update_envvar_injects_tms_for_colocated_controller(self):
+        trainer = _make_validation_trainer()
+        trainer.allocation_mode = SimpleNamespace(gen_backend="vllm")
+
+        with (
+            patch("areal.trainer.rl_trainer.is_single_controller", return_value=True),
+            patch(
+                "areal.trainer.rl_trainer.get_tms_env_vars",
+                return_value={"LD_PRELOAD": "/tmp/libtms.so", "TMS_INIT_ENABLE": "1"},
+            ),
+        ):
+            trainer._amend_xccl_weight_update_envvar()
+
+        assert (
+            trainer.config.actor.scheduling_spec[0].env_vars["LD_PRELOAD"]
+            == "/tmp/libtms.so"
+        )
+        assert "LD_PRELOAD" not in trainer.config.rollout.scheduling_spec[0].env_vars
+
+    def test_publish_disk_weight_update_ready_uses_rollout_version(self):
+        trainer = _make_validation_trainer()
+        trainer.rollout = SimpleNamespace(get_version=MagicMock(return_value=0))
+        meta = WeightUpdateMeta(type="disk", path="/tmp/weight_update_v7", version=7)
+
+        with patch("areal.trainer.rl_trainer.name_resolve.add") as mock_add:
+            trainer._publish_disk_weight_update_ready(meta)
+
+        mock_add.assert_called_once_with(
+            names.update_weights_from_disk(
+                "gsm8k-grpo-colocated",
+                "trial0",
+                7,
+            ),
+            mock_add.call_args.args[1],
+            keepalive_ttl=120,
+        )
+
+
+class TestRemoteInfEngineDiskWeightSync:
+    def test_update_weights_from_disk_uses_engine_version_for_rendezvous(self):
+        engine = cast(Any, RemoteInfEngine.__new__(RemoteInfEngine))
+        engine.backend = MagicMock()
+        engine.config = SimpleNamespace(
+            experiment_name="exp",
+            trial_name="trial",
+            request_retries=2,
+            request_timeout=30.0,
+        )
+        engine.addresses = ["127.0.0.1:8000"]
+        engine.get_version = MagicMock(return_value=0)
+        engine.logger = MagicMock()
+
+        meta = WeightUpdateMeta(type="disk", path="/tmp/weight_update_v5", version=5)
+        fake_future = MagicMock()
+
+        with patch("areal.infra.remote_inf_engine.get_executor") as mock_get_executor:
+            mock_get_executor.return_value.submit.return_value = fake_future
+            engine.update_weights_from_disk(meta)
+
+        submit_args = mock_get_executor.return_value.submit.call_args.args
+        assert submit_args[4] == 0
+        fake_future.add_done_callback.assert_called_once()
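+        # Only registration of the completion callback is asserted here; what
+        # the callback does when the future resolves (presumably version
+        # bookkeeping and logging) is outside this test's scope.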