1 change: 0 additions & 1 deletion areal/api/cli_args.py
@@ -1274,7 +1274,6 @@ def __post_init__(self):
"SAPO is not compatible with `use_decoupled_loss=True`. "
"Please set `actor.use_decoupled_loss=false` in your configuration."
)

super().__post_init__()


9 changes: 9 additions & 0 deletions areal/api/engine_api.py
@@ -3,6 +3,7 @@
import abc
from collections.abc import Callable
from concurrent.futures import Future
from contextlib import nullcontext
from typing import TYPE_CHECKING, Any

import torch
@@ -263,6 +264,10 @@ def prepare_batch(
"""
raise NotImplementedError()

def prepare_batch_context(self):
"""Return a context manager for rollout batch preparation."""
return nullcontext()

@abc.abstractmethod
def set_version(self, version: int):
"""Set the current weight version in the training engine.
Expand Down Expand Up @@ -680,6 +685,10 @@ def update_weights_from_disk(self, meta: WeightUpdateMeta) -> Future[None]:
"""
raise NotImplementedError()

def sync_weights_from_disk(self, meta: WeightUpdateMeta) -> None:
"""Update weights from disk in a blocking manner."""
self.update_weights_from_disk(meta).result()

def set_version(self, version: int) -> None:
"""Set the current weight version in the inference engine.

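A minimal sketch of how these two new hooks are meant to compose (not code from this PR; `actor`, `rollout`, and `meta` are hypothetical placeholders, and argument lists are elided):

    # prepare_batch_context() defaults to a no-op nullcontext; the colocated
    # engine implementations below override it to disable torch_memory_saver
    # while the rollout batch is prepared.
    with actor.prepare_batch_context():
        batch = actor.prepare_batch(...)  # placeholder arguments

    # sync_weights_from_disk() is simply the blocking form of the Future-based
    # update_weights_from_disk().
    rollout.sync_weights_from_disk(meta)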
1 change: 0 additions & 1 deletion areal/api/io_struct.py
@@ -257,7 +257,6 @@ def from_fsdp_xccl(
base_model_name=base_model_name,
)


@dataclass
class HttpRequest:
"""Represents an HTTP request to be sent to a remote inference server."""
7 changes: 7 additions & 0 deletions areal/engine/fsdp_engine.py
@@ -718,6 +718,13 @@ def process_output(logits: torch.Tensor, ctx_dict: dict[str, Any]) -> None:
def export_stats(self) -> dict[str, float]:
return stats_tracker.export_all(reduce_group=self.data_parallel_group)

def prepare_batch_context(self):
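        # Assumption (not stated in this diff): disabling torch_memory_saver
        # here keeps tensors allocated during rollout batch preparation out of
        # the saver's pauseable pool; the torch.version.hip check skips this on
        # ROCm, where the disable path is presumably unsupported.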
return (
torch_memory_saver.disable()
if self.is_offload and not torch.version.hip
else nullcontext()
)

def offload(self) -> None:
"""Offload model memory to CPU using torch_memory_saver.

7 changes: 7 additions & 0 deletions areal/engine/megatron_engine.py
@@ -767,6 +767,13 @@ def export_stats(self) -> dict[str, float]:
data.update(data_list[0])
return data

def prepare_batch_context(self):
return (
torch_memory_saver.disable()
if self.is_offload and not torch.version.hip
else nullcontext()
)

def offload(self) -> None:
"""Offload model memory to CPU using torch_memory_saver.

7 changes: 7 additions & 0 deletions areal/experimental/engine/archon_engine.py
@@ -718,6 +718,13 @@ def export_stats(self) -> dict[str, float]:
data.update(data_list[0])
return data

def prepare_batch_context(self):
return (
torch_memory_saver.disable()
if self.is_offload and not torch.version.hip
else nullcontext()
)

def get_device_stats(self) -> DeviceRuntimeInfo:
return DeviceRuntimeInfo.get_current()

2 changes: 2 additions & 0 deletions areal/infra/__init__.py
@@ -1,6 +1,7 @@
"""Core components for AREAL."""

from . import workflow_context
from .colocated import ColocatedOrchestrator
from .controller import RolloutController, TrainController
from .launcher import (
LocalLauncher,
@@ -22,6 +23,7 @@
)

__all__ = [
"ColocatedOrchestrator",
"RemoteInfBackendProtocol",
"RemoteInfEngine",
"StalenessManager",
114 changes: 114 additions & 0 deletions areal/infra/colocated.py
@@ -0,0 +1,114 @@
"""Colocated (GPU time-sharing) orchestration for on-policy training.

In colocated mode, the training engine and inference engine share the same
GPUs and alternate between offloaded/onloaded states.
"""

from __future__ import annotations

from typing import TYPE_CHECKING

import torch.distributed as dist

from areal.api.io_struct import WeightUpdateMeta
from areal.utils import logging

if TYPE_CHECKING:
from areal.api import InferenceEngine, TrainEngine

logger = logging.getLogger("Colocated")


class ColocatedOrchestrator:
"""Orchestrate GPU ownership between colocated training and inference."""

def __init__(
self,
train_engine: TrainEngine,
inf_engine: InferenceEngine,
) -> None:
self._train_engine: TrainEngine = train_engine
self._inf_engine: InferenceEngine = inf_engine
self._train_on_gpu: bool = True
self._inf_on_gpu: bool = True

def _is_rollout_coordinator(self) -> bool:
return not dist.is_initialized() or dist.get_rank() == 0

def _barrier(self) -> None:
if not dist.is_initialized():
return
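        # Prefer the training engine's CPU (e.g. Gloo) group when available:
        # a device-side barrier could require GPU state that may be mid-offload
        # at this point (an inference from the colocated flow, not stated in
        # this PR).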
cpu_group = self._train_engine.cpu_group
if cpu_group is None:
dist.barrier()
return
dist.barrier(group=cpu_group)

def initial_offload_training(self) -> None:
"""Offload training once so inference owns the GPU before first rollout."""
if not self._train_on_gpu:
logger.warning(
"initial_offload_training called but training engine is already off GPU."
)
return

if self._is_rollout_coordinator():
logger.info("Initial offload: moving training engine off GPU")
self._train_engine.offload()
self._train_on_gpu = False

def prepare_for_training(self) -> None:
"""Switch GPU ownership from inference to training."""
if self._train_on_gpu:
logger.debug("Training engine already on GPU, skipping switch")
return

if self._is_rollout_coordinator():
logger.info("Switching to training mode")

# Pause local submission on every rank first so no new requests are queued.
self._inf_engine.pause()

# Only one coordinator should touch the shared rollout servers.
if self._is_rollout_coordinator():
self._inf_engine.pause_generation()
if self._inf_on_gpu:
self._inf_engine.offload()

self._barrier()
self._inf_on_gpu = False

# All training ranks must participate in the training-engine collective.
self._train_engine.onload()
self._train_on_gpu = True

def prepare_for_inference(self, meta: WeightUpdateMeta) -> None:
"""Switch GPU ownership from training to inference and sync weights."""
if self._inf_on_gpu:
logger.debug("Inference engine already on GPU, skipping switch")
return

if self._is_rollout_coordinator():
logger.info("Switching to inference mode")

if self._train_on_gpu:
self._train_engine.offload()
self._train_on_gpu = False

self._barrier()

if self._is_rollout_coordinator():
self._inf_engine.onload()

if meta.version is None:
raise ValueError("Colocated disk weight sync requires meta.version.")

# The colocated flow publishes the ready signal before the trainer later
# calls rollout.set_version(new_version), so align the rollout-side wait
# key here.
self._inf_engine.set_version(meta.version)
self._inf_engine.sync_weights_from_disk(meta)
self._inf_engine.continue_generation()

self._barrier()
self._inf_engine.resume()
self._inf_on_gpu = True
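
Taken together, the intended call order is roughly the following (a sketch, not code from this PR; `actor`, `rollout`, `meta`, and the loop bounds are hypothetical placeholders):

    orch = ColocatedOrchestrator(train_engine=actor, inf_engine=rollout)
    orch.initial_offload_training()        # inference owns the GPUs first

    for step in range(max_steps):
        batch = ...                        # rollout runs while training is off-GPU
        orch.prepare_for_training()        # pause + offload rollout, onload trainer
        ...                                # run the training step
        meta = ...                         # WeightUpdateMeta with meta.version set
        orch.prepare_for_inference(meta)   # offload trainer, onload rollout, sync weights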
69 changes: 60 additions & 9 deletions areal/infra/controller/rollout_controller.py
@@ -1,6 +1,7 @@
from __future__ import annotations

import asyncio
import copy
import shutil
import threading
import traceback
@@ -32,6 +33,7 @@
InferenceEngineConfig,
PerfTracerConfig,
SchedulingSpec,
SchedulingStrategyType,
)
from areal.infra.rpc.serialization import deserialize_value
from areal.infra.utils.concurrent import run_async_task
@@ -46,6 +48,11 @@

logger = logging.getLogger("RolloutController")

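# torch_memory_saver (TMS) is typically injected via an LD_PRELOAD hook and
# configured through TMS_INIT_* env vars; a forked rollout worker must not
# inherit these training-side settings.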
_ROLLOUT_FORK_UNSET_ENV_KEYS = [
"LD_PRELOAD",
"TMS_INIT_ENABLE",
"TMS_INIT_ENABLE_CPU_BACKUP",
]

# NOTE: remote task input has a slightly different
# type annotation, which disallows workflow object or types
@@ -150,6 +157,25 @@ def _engine_name(self, rank: int) -> str:
"""
return f"{self._worker_role}/{rank}"

@staticmethod
def _build_create_worker_kwargs(job: Job) -> dict[str, Any]:
"""Build scheduler kwargs for rollout worker creation.

For colocated rollout->actor scheduling with fork=True, unset inherited
actor-side TMS (torch_memory_saver) env vars at the fork boundary so the
rollout child and the SGLang server it launches do not pick up
training-side memory-saver hooks.
"""
strategy = job.scheduling_strategy
if (
SchedulingStrategyType(strategy.type) == SchedulingStrategyType.colocation
and strategy.target == "actor"
and strategy.fork
):
return {
"fork_unset_env_keys": list(_ROLLOUT_FORK_UNSET_ENV_KEYS),
}
return {}

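The scheduler side of `fork_unset_env_keys` is not shown in this diff; presumably it sanitizes the child environment along these lines when forking (a sketch under that assumption):

    import os

    def sanitized_fork_env(fork_unset_env_keys: list[str]) -> dict[str, str]:
        # Drop LD_PRELOAD / TMS_* so the forked rollout worker (and the SGLang
        # server it launches) does not load training-side memory-saver hooks.
        env = dict(os.environ)
        for key in fork_unset_env_keys:
            env.pop(key, None)
        return env
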
def initialize(
self,
role: str,
@@ -231,8 +257,16 @@ async def _async_initialize(
**kwargs,
):
# Create workers via scheduler
logger.info("Creating workers via scheduler...")
worker_ids = self.scheduler.create_workers(job=job)
+ create_worker_kwargs = self._build_create_worker_kwargs(job)
+ if create_worker_kwargs:
+     logger.info(
+         "Creating workers via scheduler with fork env cleanup: "
+         f"{create_worker_kwargs}"
+     )
+ else:
+     logger.info("Creating workers via scheduler...")
+
+ worker_ids = self.scheduler.create_workers(job=job, **create_worker_kwargs)
logger.info(f"Workers created: {worker_ids}")

# Wait for workers to be ready
@@ -565,12 +599,12 @@ def update_weights_disk():

@app.route("/callback/pause_generation", methods=["POST"])
def pause_generation():
- self._callback_loop.run_until_complete(self.pause_generation())
+ self.pause_generation()
return jsonify({"status": "ok"})

@app.route("/callback/continue_generation", methods=["POST"])
def continue_generation():
- self._callback_loop.run_until_complete(self.continue_generation())
+ self.continue_generation()
return jsonify({"status": "ok"})

@app.route("/callback/rollout_complete", methods=["POST"])
@@ -1018,16 +1052,27 @@ async def update_weights_from_distributed(
)

async def update_weights_from_disk(self, meta: WeightUpdateMeta):
- meta.clear_checkpoint_after_load = False
- await self._collective_rpc_async("update_weights_from_disk", meta=meta)
- shutil.rmtree(meta.path, ignore_errors=True)
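+ # Shallow-copy the meta so only the per-worker RPC sees
+ # clear_checkpoint_after_load=False; the controller then honors the
+ # caller's original flag with a single rmtree below.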
+ update_meta = copy.copy(meta)
+ update_meta.clear_checkpoint_after_load = False
+ await self._collective_rpc_async("update_weights_from_disk", meta=update_meta)
+ if meta.clear_checkpoint_after_load and meta.path is not None:
+     shutil.rmtree(meta.path, ignore_errors=True)

+ def sync_weights_from_disk(self, meta: WeightUpdateMeta) -> None:
+     run_async_task(self.update_weights_from_disk, meta)

- async def pause_generation(self):
+ async def _pause_generation_async(self):
await self._collective_rpc_async("pause_generation")

- async def continue_generation(self):
+ def pause_generation(self):
+     run_async_task(self._pause_generation_async)

+ async def _continue_generation_async(self):
await self._collective_rpc_async("continue_generation")

+ def continue_generation(self):
+     run_async_task(self._continue_generation_async)

def set_version(self, version: int) -> None:
with self._version_lock:
self._version = version
@@ -1041,6 +1086,12 @@ def get_version(self) -> int:
with self._version_lock:
return self._version

def offload(self) -> None:
self._collective_rpc("offload", http_timeout=60.0)

def onload(self, tags: list[str] | None = None) -> None:
self._collective_rpc("onload", tags=tags, http_timeout=60.0)

def pause(self):
self.dispatcher.pause()
self._collective_rpc("pause", http_timeout=60.0)