diff --git a/CHANGELOG.md b/CHANGELOG.md index ef731086..54d5079e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,10 @@ ## v0.0.9 (2025-??-??) +### Changed + +* Separate the generic replaybuffer into a separate crate (`border-generic-replay-buffer`). + ## v0.0.8 (2025-05-17) ### Added diff --git a/Cargo.toml b/Cargo.toml index cc3696e1..c78dd7d8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,7 @@ [workspace] members = [ "border-core", + "border-generic-replay-buffer", "border-tensorboard", "border-mlflow-tracking", "border-py-gym-env", @@ -23,6 +24,7 @@ repository = "https://github.com/laboroai/border" keywords = ["reinforcement", "learning", "rl"] categories = ["science"] license = "MIT OR Apache-2.0" +readme = "README.md" [workspace.dependencies] clap = { version = "4.5.8", features = ["derive"] } diff --git a/README.md b/README.md index 1425e022..de03769e 100644 --- a/README.md +++ b/README.md @@ -11,18 +11,19 @@ Border consists of the following crates: * Core and utility * [border-core](https://crates.io/crates/border-core) ([doc](https://docs.rs/border-core/latest/border_core/)) provides basic traits and functions for environments and reinforcement learning (RL) agents. - * [border-tensorboard](https://crates.io/crates/border-tensorboard) ([doc](https://docs.rs/border-core/latest/border_tensorboard/)) implements the `TensorboardRecorder` struct for writing records that can be visualized in Tensorboard, based on [tensorboard-rs](https://crates.io/crates/tensorboard-rs). - * [border-mlflow-tracking](https://crates.io/crates/border-mlflow-tracking) ([doc](https://docs.rs/border-core/latest/border_mlflow_tracking/)) provides MLflow tracking support for logging metrics during training via REST API. - * [border-async-trainer](https://crates.io/crates/border-async-trainer) ([doc](https://docs.rs/border-core/latest/border_async_trainer/)) defines traits and functions for asynchronous training of RL agents using multiple actors. Each actor runs a sampling process in parallel, where an agent interacts with an environment to collect samples for a shared replay buffer. + * [border-generic-replay-buffer](https://crates.io/crates/border-generic-replay-buffer) ([doc](https://docs.rs/border-generic-replay-buffer/latest/border_generic_replay_buffer/)) provides a generic implementation of replay buffer. + * [border-tensorboard](https://crates.io/crates/border-tensorboard) ([doc](https://docs.rs/border-tensorboard/latest/border_tensorboard/)) implements the `TensorboardRecorder` struct for writing records that can be visualized in Tensorboard, based on [tensorboard-rs](https://crates.io/crates/tensorboard-rs). + * [border-mlflow-tracking](https://crates.io/crates/border-mlflow-tracking) ([doc](https://docs.rs/border-mlflow-tracking/latest/border_mlflow_tracking/)) provides MLflow tracking support for logging metrics during training via REST API. + * [border-async-trainer](https://crates.io/crates/border-async-trainer) ([doc](https://docs.rs/border-async-trainer/latest/border_async_trainer/)) defines traits and functions for asynchronous training of RL agents using multiple actors. Each actor runs a sampling process in parallel, where an agent interacts with an environment to collect samples for a shared replay buffer. * [border](https://crates.io/crates/border) serves as a collection of examples. * Environment - * [border-py-gym-env](https://crates.io/crates/border-py-gym-env) ([doc](https://docs.rs/border-core/latest/border_py_gym_env/)) provides a wrapper for [Gymnasium](https://gymnasium.farama.org) environments written in Python. - * [border-atari-env](https://crates.io/crates/border-atari-env) ([doc](https://docs.rs/border-core/latest/border_atari_env/)) implements a wrapper for [atari-env](https://crates.io/crates/atari-env), which is part of [gym-rs](https://crates.io/crates/gym-rs). - * [border-minari](https://crates.io/crates/border-minari) ([doc](https://docs.rs/border-core/latest/border_minari/)) provides a wrapper for [Minari](https://minari.farama.org). + * [border-py-gym-env](https://crates.io/crates/border-py-gym-env) ([doc](https://docs.rs/border-py-gym-env/latest/border_py_gym_env/)) provides a wrapper for [Gymnasium](https://gymnasium.farama.org) environments written in Python. + * [border-atari-env](https://crates.io/crates/border-atari-env) ([doc](https://docs.rs/border-atari-env/latest/border_atari_env/)) implements a wrapper for [atari-env](https://crates.io/crates/atari-env), which is part of [gym-rs](https://crates.io/crates/gym-rs). + * [border-minari](https://crates.io/crates/border-minari) ([doc](https://docs.rs/border-minari/latest/border_minari/)) provides a wrapper for [Minari](https://minari.farama.org). * Agent - * [border-tch-agent](https://crates.io/crates/border-tch-agent) ([doc](https://docs.rs/border-core/latest/border_tch_agent/)) implements RL agents based on [tch](https://crates.io/crates/tch), including Deep Q Network (DQN), Implicit Quantile Network (IQN), and Soft Actor-Critic (SAC). - * [border-candle-agent](https://crates.io/crates/border-candle-agent) ([doc](https://docs.rs/border-core/latest/border_candle_agent/)) implements RL agents based on [candle](https://crates.io/crates/candle-core). - * [border-policy-no-backend](https://crates.io/crates/border-policy-no-backend) ([doc](https://docs.rs/border-core/latest/border_policy_no_backend/)) implements policies that are independent of any deep learning backend, such as Torch. + * [border-tch-agent](https://crates.io/crates/border-tch-agent) ([doc](https://docs.rs/border-tch-agent/latest/border_tch_agent/)) implements RL agents based on [tch](https://crates.io/crates/tch), including Deep Q Network (DQN), Implicit Quantile Network (IQN), and Soft Actor-Critic (SAC). + * [border-candle-agent](https://crates.io/crates/border-candle-agent) ([doc](https://docs.rs/border-candle-agent/latest/border_candle_agent/)) implements RL agents based on [candle](https://crates.io/crates/candle-core). + * [border-policy-no-backend](https://crates.io/crates/border-policy-no-backend) ([doc](https://docs.rs/border-policy-no-backend/latest/border_policy_no_backend/)) implements policies that are independent of any deep learning backend, such as Torch. ## Status @@ -38,16 +39,17 @@ Docker configuration files for development and testing are available in the [dev ## License -Crates | License ---------------------------|------------------ -`border-core` | MIT OR Apache-2.0 -`border-tensorboard` | MIT OR Apache-2.0 -`border-mlflow-tracking` | MIT OR Apache-2.0 -`border-async-trainer` | MIT OR Apache-2.0 -`border-py-gym-env` | MIT OR Apache-2.0 -`border-atari-env` | GPL-2.0-or-later -`border-minari` | MIT OR Apache-2.0 -`border-tch-agent` | MIT OR Apache-2.0 -`border-candle-agent` | MIT OR Apache-2.0 -`border-policy-no-backend`| MIT OR Apache-2.0 -`border` | GPL-2.0-or-later +Crates | License +------------------------------|------------------ +`border-core` | MIT OR Apache-2.0 +`border-generic-replay-buffer`| MIT OR Apache-2.0 +`border-tensorboard` | MIT OR Apache-2.0 +`border-mlflow-tracking` | MIT OR Apache-2.0 +`border-async-trainer` | MIT OR Apache-2.0 +`border-py-gym-env` | MIT OR Apache-2.0 +`border-atari-env` | GPL-2.0-or-later +`border-minari` | MIT OR Apache-2.0 +`border-tch-agent` | MIT OR Apache-2.0 +`border-candle-agent` | MIT OR Apache-2.0 +`border-policy-no-backend` | MIT OR Apache-2.0 +`border` | GPL-2.0-or-later diff --git a/border-async-trainer/Cargo.toml b/border-async-trainer/Cargo.toml index edbcbbdc..8eb6386a 100644 --- a/border-async-trainer/Cargo.toml +++ b/border-async-trainer/Cargo.toml @@ -24,4 +24,5 @@ thiserror = { workspace = true } [dev-dependencies] env_logger = { workspace = true } +border-generic-replay-buffer = { version = "0.0.9", path = "../border-generic-replay-buffer" } test-log = "0.2.8" \ No newline at end of file diff --git a/border-async-trainer/README.md b/border-async-trainer/README.md index e69de29b..923d2d71 100644 --- a/border-async-trainer/README.md +++ b/border-async-trainer/README.md @@ -0,0 +1,160 @@ +Asynchronous trainer with parallel sampling processes. + +The code might look like below. + +``` +# use serde::{Deserialize, Serialize}; +# use border_generic_replay_buffer::test::{ +# TestAgent, TestAgentConfig, TestEnv, TestObs, TestObsBatch, +# TestAct, TestActBatch +# }; +# use border_core::Env as _; +# use border_async_trainer::{ +# //test::{TestAgent, TestAgentConfig, TestEnv}, +# ActorManager, ActorManagerConfig, AsyncTrainer, AsyncTrainerConfig, +# }; +# use border_generic_replay_buffer::{ +# GenericReplayBuffer, GenericReplayBufferConfig, +# SimpleStepProcessorConfig, SimpleStepProcessor +# }; +# use border_core::{ +# record::{Recorder, NullRecorder}, DefaultEvaluator, +# }; +# +# use std::path::{Path, PathBuf}; +# +# fn agent_config() -> TestAgentConfig { +# TestAgentConfig +# } +# +# fn env_config() -> usize { +# 0 +# } + +type Env = TestEnv; +type ObsBatch = TestObsBatch; +type ActBatch = TestActBatch; +type ReplayBuffer = GenericReplayBuffer; +type StepProcessor = SimpleStepProcessor; + +// Create a new agent by wrapping the existing agent in order to implement SyncModel. +struct TestAgent2(TestAgent); + +impl border_core::Configurable for TestAgent2 { + type Config = TestAgentConfig; + + fn build(config: Self::Config) -> Self { + Self(TestAgent::build(config)) + } +} + +impl border_core::Agent for TestAgent2 { + // Boilerplate code to delegate the method calls to the inner agent. + fn train(&mut self) { + self.0.train(); + } + + // For other methods ... +# fn is_train(&self) -> bool { +# self.0.is_train() +# } +# +# fn eval(&mut self) { +# self.0.eval(); +# } +# +# fn opt_with_record(&mut self, buffer: &mut ReplayBuffer) -> border_core::record::Record { +# self.0.opt_with_record(buffer) +# } +# +# fn save_params(&self, path: &Path) -> anyhow::Result> { +# self.0.save_params(path) +# } +# +# fn load_params(&mut self, path: &Path) -> anyhow::Result<()> { +# self.0.load_params(path) +# } +# +# fn opt(&mut self, buffer: &mut ReplayBuffer) { +# self.0.opt_with_record(buffer); +# } +# +# fn as_any_ref(&self) -> &dyn std::any::Any { +# self +# } +# +# fn as_any_mut(&mut self) -> &mut dyn std::any::Any { +# self +# } +} + +impl border_core::Policy for TestAgent2 { + // Boilerplate code to delegate the method calls to the inner agent. + // ... +# fn sample(&mut self, obs: &TestObs) -> TestAct { +# self.0.sample(obs) +# } +} + +impl border_async_trainer::SyncModel for TestAgent2{ + // Self::ModelInfo shold include the model parameters. + type ModelInfo = usize; + + + fn model_info(&self) -> (usize, Self::ModelInfo) { + // Extracts the model parameters and returns them as Self::ModelInfo. + // The first element of the tuple is the number of optimization steps. + (0, 0) + } + + fn sync_model(&mut self, _model_info: &Self::ModelInfo) { + // implements synchronization of the model based on the _model_info + } +} + +let agent_configs: Vec<_> = vec![agent_config()]; +let env_config_train = env_config(); +let env_config_eval = env_config(); +let replay_buffer_config = GenericReplayBufferConfig::default(); +let step_proc_config = SimpleStepProcessorConfig::default(); +let actor_man_config = ActorManagerConfig::default(); +let async_trainer_config = AsyncTrainerConfig::default(); +let mut recorder: Box> = Box::new(NullRecorder::new()); +let mut evaluator = DefaultEvaluator::::new(&env_config_eval, 0, 1).unwrap(); + +border_async_trainer::util::train_async::( + &agent_config(), + &agent_configs, + &env_config_train, + &env_config_eval, + &step_proc_config, + &replay_buffer_config, + &actor_man_config, + &async_trainer_config, + &mut recorder, + &mut evaluator, +); +``` + +Training process consists of the following two components: + +* [`ActorManager`] manages [`Actor`]s, each of which runs a thread for interacting + [`Agent`] and [`Env`] and taking samples. Those samples will be sent to + the replay buffer in [`AsyncTrainer`]. +* [`AsyncTrainer`] is responsible for training of an agent. It also runs a thread + for pushing samples from [`ActorManager`] into a replay buffer. + +The `Agent` must implement [`SyncModel`] trait in order to synchronize the model of +the agent in [`Actor`] with the trained agent in [`AsyncTrainer`]. The trait has +the ability to import and export the information of the model as +[`SyncModel`]`::ModelInfo`. + +The `Agent` in [`AsyncTrainer`] is responsible for training, typically with a GPU, +while the `Agent`s in [`Actor`]s in [`ActorManager`] is responsible for sampling +using CPU. + +Both [`AsyncTrainer`] and [`ActorManager`] are running in the same machine and +communicate by channels. + +[`Agent`]: border_core::Agent +[`Env`]: border_core::Env \ No newline at end of file diff --git a/border-async-trainer/src/actor/base.rs b/border-async-trainer/src/actor/base.rs index cae4412a..8c2fc432 100644 --- a/border-async-trainer/src/actor/base.rs +++ b/border-async-trainer/src/actor/base.rs @@ -1,6 +1,6 @@ use crate::{ActorStat, PushedItemMessage, ReplayBufferProxy, ReplayBufferProxyConfig, SyncModel}; use border_core::{ - Agent, Configurable, Env, ExperienceBufferBase, ReplayBufferBase, Sampler, StepProcessor, + Agent, Configurable, Env, ExperienceBuffer, ReplayBuffer, Sampler, StepProcessor, }; use crossbeam_channel::Sender; use log::{debug, info}; @@ -21,7 +21,7 @@ use std::{ /// B-->|Env::Obs|A /// B-->|Step<E: Env>|C[StepProcessor] /// end -/// C-->|ReplayBufferBase::PushedItem|F[ReplayBufferProxy] +/// C-->|ReplayBuffer::PushedItem|F[ReplayBufferProxy] /// ``` /// /// In [`Actor`], an [`Agent`] runs on an [`Env`] and generates [`Step`] objects. @@ -41,7 +41,7 @@ where A: Agent + Configurable + SyncModel + 'static, E: Env, P: StepProcessor, - R: ExperienceBufferBase + ReplayBufferBase, + R: ExperienceBuffer + ReplayBuffer, { /// Stops sampling process if this field is set to `true`. id: usize, @@ -60,7 +60,7 @@ where A: Agent + Configurable + SyncModel + 'static, E: Env, P: StepProcessor, - R: ExperienceBufferBase + ReplayBufferBase, + R: ExperienceBuffer + ReplayBuffer, { pub fn build( id: usize, diff --git a/border-async-trainer/src/actor_manager/base.rs b/border-async-trainer/src/actor_manager/base.rs index 7e6ef59c..f1d6ad20 100644 --- a/border-async-trainer/src/actor_manager/base.rs +++ b/border-async-trainer/src/actor_manager/base.rs @@ -2,7 +2,7 @@ use crate::{ Actor, ActorManagerConfig, ActorStat, PushedItemMessage, ReplayBufferProxyConfig, SyncModel, }; use border_core::{ - Agent, Configurable, Env, ExperienceBufferBase, ReplayBufferBase, StepProcessor, + Agent, Configurable, Env, ExperienceBuffer, ReplayBuffer, StepProcessor, }; use crossbeam_channel::{bounded, /*unbounded,*/ Receiver, Sender}; use log::info; @@ -25,7 +25,7 @@ where A: Agent + Configurable + SyncModel, E: Env, P: StepProcessor, - R: ExperienceBufferBase + ReplayBufferBase, + R: ExperienceBuffer + ReplayBuffer, { /// Configurations of [`Agent`]s. agent_configs: Vec, @@ -72,7 +72,7 @@ where A: Agent + Configurable + SyncModel + 'static, E: Env, P: StepProcessor, - R: ExperienceBufferBase + Send + 'static + ReplayBufferBase, + R: ExperienceBuffer + Send + 'static + ReplayBuffer, A::Config: Send + 'static, E::Config: Send + 'static, P::Config: Send + 'static, diff --git a/border-async-trainer/src/async_trainer/base.rs b/border-async-trainer/src/async_trainer/base.rs index 49f5ef89..b0e15deb 100644 --- a/border-async-trainer/src/async_trainer/base.rs +++ b/border-async-trainer/src/async_trainer/base.rs @@ -2,7 +2,7 @@ use crate::{AsyncTrainStat, AsyncTrainerConfig, PushedItemMessage, SyncModel}; use anyhow::Result; use border_core::{ record::{Record, RecordValue::Scalar, Recorder}, - Agent, Configurable, Env, Evaluator, ExperienceBufferBase, ReplayBufferBase, + Agent, Configurable, Env, Evaluator, ExperienceBuffer, ReplayBuffer, }; use crossbeam_channel::{Receiver, Sender}; use log::{debug, info}; @@ -21,7 +21,7 @@ use std::{ /// ```mermaid /// flowchart LR /// subgraph ActorManager -/// E[Actor]-->|ReplayBufferBase::PushedItem|H[ReplayBufferProxy] +/// E[Actor]-->|ReplayBuffer::PushedItem|H[ReplayBufferProxy] /// F[Actor]-->H /// G[Actor]-->H /// end @@ -31,36 +31,36 @@ use std::{ /// /// subgraph I[AsyncTrainer] /// H-->|PushedItemMessage|J[ReplayBuffer] -/// J-->|ReplayBufferBase::Batch|K[Agent] +/// J-->|ReplayBuffer::Batch|K[Agent] /// end /// ``` /// /// * The [`Agent`] in [`AsyncTrainer`] (left) is trained with batches -/// of type [`ReplayBufferBase::Batch`], which are taken from the replay buffer. +/// of type [`ReplayBuffer::Batch`], which are taken from the replay buffer. /// * The model parameters of the [`Agent`] in [`AsyncTrainer`] are wrapped in /// [`SyncModel::ModelInfo`] and periodically sent to the [`Agent`]s in [`Actor`]s. /// [`Agent`] must implement [`SyncModel`] to synchronize the model parameters. /// * In [`ActorManager`] (right), [`Actor`]s sample transitions, which have type -/// [`ReplayBufferBase::Item`], and push the transitions into +/// [`ReplayBuffer::Item`], and push the transitions into /// [`ReplayBufferProxy`]. -/// * [`ReplayBufferProxy`] has a type parameter of [`ReplayBufferBase`] and the proxy accepts -/// [`ReplayBufferBase::Item`]. +/// * [`ReplayBufferProxy`] has a type parameter of [`ReplayBuffer`] and the proxy accepts +/// [`ReplayBuffer::Item`]. /// * The proxy sends the transitions into the replay buffer in the [`AsyncTrainer`]. /// /// [`ActorManager`]: crate::ActorManager /// [`Actor`]: crate::Actor -/// [`ReplayBufferBase::Item`]: border_core::ReplayBufferBase::PushedItem -/// [`ReplayBufferBase::Batch`]: border_core::ReplayBufferBase::PushedBatch +/// [`ReplayBuffer::Item`]: border_core::ReplayBuffer::PushedItem +/// [`ReplayBuffer::Batch`]: border_core::ReplayBuffer::PushedBatch /// [`ReplayBufferProxy`]: crate::ReplayBufferProxy -/// [`ReplayBufferBase`]: border_core::ReplayBufferBase +/// [`ReplayBuffer`]: border_core::ReplayBuffer /// [`SyncModel::ModelInfo`]: crate::SyncModel::ModelInfo /// [`Agent`]: border_core::Agent pub struct AsyncTrainer where A: Agent + Configurable + SyncModel, E: Env, - // R: ReplayBufferBase + Sync + Send + 'static, - R: ExperienceBufferBase + ReplayBufferBase, + // R: ReplayBuffer + Sync + Send + 'static, + R: ExperienceBuffer + ReplayBuffer, R::Item: Send + 'static, { /// Configuration of [`Env`]. Note that it is used only for evaluation, not for training. @@ -130,8 +130,8 @@ impl AsyncTrainer where A: Agent + Configurable + SyncModel + 'static, E: Env, - // R: ReplayBufferBase + Sync + Send + 'static, - R: ExperienceBufferBase + ReplayBufferBase, + // R: ReplayBuffer + Sync + Send + 'static, + R: ExperienceBuffer + ReplayBuffer, R::Item: Send + 'static, { /// Creates [`AsyncTrainer`]. @@ -231,7 +231,7 @@ where ) -> Result<()> where E: Env, - R: ReplayBufferBase, + R: ReplayBuffer, D: Evaluator, { // Evaluation @@ -288,14 +288,14 @@ where /// In the training loop, the following values will be pushed into the given recorder: /// /// * `samples_total` - Total number of samples pushed into the replay buffer. - /// Here, a "sample" is an item in [`ExperienceBufferBase::Item`]. + /// Here, a "sample" is an item in [`ExperienceBuffer::Item`]. /// * `opt_steps_per_sec` - The number of optimization steps per second. /// * `samples_per_sec` - The number of samples per second. /// * `samples_per_opt_steps` - The number of samples per optimization step. /// /// These values will typically be monitored with tensorboard. /// - /// [`ExperienceBufferBase::Item`]: border_core::ExperienceBufferBase::Item + /// [`ExperienceBuffer::Item`]: border_core::ExperienceBuffer::Item pub fn train( &mut self, recorder: &mut Box>, diff --git a/border-async-trainer/src/lib.rs b/border-async-trainer/src/lib.rs index 52a8a7eb..e7ebdeaf 100644 --- a/border-async-trainer/src/lib.rs +++ b/border-async-trainer/src/lib.rs @@ -1,163 +1,4 @@ -//! Asynchronous trainer with parallel sampling processes. -//! -//! The code might look like below. -//! -//! ``` -//! # use serde::{Deserialize, Serialize}; -//! # use border_core::test::{ -//! # TestAgent, TestAgentConfig, TestEnv, TestObs, TestObsBatch, -//! # TestAct, TestActBatch -//! # }; -//! # use border_core::Env as _; -//! # use border_async_trainer::{ -//! # //test::{TestAgent, TestAgentConfig, TestEnv}, -//! # ActorManager, ActorManagerConfig, AsyncTrainer, AsyncTrainerConfig, -//! # }; -//! # use border_core::{ -//! # generic_replay_buffer::{ -//! # SimpleReplayBuffer, SimpleReplayBufferConfig, -//! # SimpleStepProcessorConfig, SimpleStepProcessor -//! # }, -//! # record::{Recorder, NullRecorder}, DefaultEvaluator, -//! # }; -//! # -//! # use std::path::{Path, PathBuf}; -//! # -//! # fn agent_config() -> TestAgentConfig { -//! # TestAgentConfig -//! # } -//! # -//! # fn env_config() -> usize { -//! # 0 -//! # } -//! -//! type Env = TestEnv; -//! type ObsBatch = TestObsBatch; -//! type ActBatch = TestActBatch; -//! type ReplayBuffer = SimpleReplayBuffer; -//! type StepProcessor = SimpleStepProcessor; -//! -//! // Create a new agent by wrapping the existing agent in order to implement SyncModel. -//! struct TestAgent2(TestAgent); -//! -//! impl border_core::Configurable for TestAgent2 { -//! type Config = TestAgentConfig; -//! -//! fn build(config: Self::Config) -> Self { -//! Self(TestAgent::build(config)) -//! } -//! } -//! -//! impl border_core::Agent for TestAgent2 { -//! // Boilerplate code to delegate the method calls to the inner agent. -//! fn train(&mut self) { -//! self.0.train(); -//! } -//! -//! // For other methods ... -//! # fn is_train(&self) -> bool { -//! # self.0.is_train() -//! # } -//! # -//! # fn eval(&mut self) { -//! # self.0.eval(); -//! # } -//! # -//! # fn opt_with_record(&mut self, buffer: &mut ReplayBuffer) -> border_core::record::Record { -//! # self.0.opt_with_record(buffer) -//! # } -//! # -//! # fn save_params(&self, path: &Path) -> anyhow::Result> { -//! # self.0.save_params(path) -//! # } -//! # -//! # fn load_params(&mut self, path: &Path) -> anyhow::Result<()> { -//! # self.0.load_params(path) -//! # } -//! # -//! # fn opt(&mut self, buffer: &mut ReplayBuffer) { -//! # self.0.opt_with_record(buffer); -//! # } -//! # -//! # fn as_any_ref(&self) -> &dyn std::any::Any { -//! # self -//! # } -//! # -//! # fn as_any_mut(&mut self) -> &mut dyn std::any::Any { -//! # self -//! # } -//! } -//! -//! impl border_core::Policy for TestAgent2 { -//! // Boilerplate code to delegate the method calls to the inner agent. -//! // ... -//! # fn sample(&mut self, obs: &TestObs) -> TestAct { -//! # self.0.sample(obs) -//! # } -//! } -//! -//! impl border_async_trainer::SyncModel for TestAgent2{ -//! // Self::ModelInfo shold include the model parameters. -//! type ModelInfo = usize; -//! -//! -//! fn model_info(&self) -> (usize, Self::ModelInfo) { -//! // Extracts the model parameters and returns them as Self::ModelInfo. -//! // The first element of the tuple is the number of optimization steps. -//! (0, 0) -//! } -//! -//! fn sync_model(&mut self, _model_info: &Self::ModelInfo) { -//! // implements synchronization of the model based on the _model_info -//! } -//! } -//! -//! let agent_configs: Vec<_> = vec![agent_config()]; -//! let env_config_train = env_config(); -//! let env_config_eval = env_config(); -//! let replay_buffer_config = SimpleReplayBufferConfig::default(); -//! let step_proc_config = SimpleStepProcessorConfig::default(); -//! let actor_man_config = ActorManagerConfig::default(); -//! let async_trainer_config = AsyncTrainerConfig::default(); -//! let mut recorder: Box> = Box::new(NullRecorder::new()); -//! let mut evaluator = DefaultEvaluator::::new(&env_config_eval, 0, 1).unwrap(); -//! -//! border_async_trainer::util::train_async::( -//! &agent_config(), -//! &agent_configs, -//! &env_config_train, -//! &env_config_eval, -//! &step_proc_config, -//! &replay_buffer_config, -//! &actor_man_config, -//! &async_trainer_config, -//! &mut recorder, -//! &mut evaluator, -//! ); -//! ``` -//! -//! Training process consists of the following two components: -//! -//! * [`ActorManager`] manages [`Actor`]s, each of which runs a thread for interacting -//! [`Agent`] and [`Env`] and taking samples. Those samples will be sent to -//! the replay buffer in [`AsyncTrainer`]. -//! * [`AsyncTrainer`] is responsible for training of an agent. It also runs a thread -//! for pushing samples from [`ActorManager`] into a replay buffer. -//! -//! The `Agent` must implement [`SyncModel`] trait in order to synchronize the model of -//! the agent in [`Actor`] with the trained agent in [`AsyncTrainer`]. The trait has -//! the ability to import and export the information of the model as -//! [`SyncModel`]`::ModelInfo`. -//! -//! The `Agent` in [`AsyncTrainer`] is responsible for training, typically with a GPU, -//! while the `Agent`s in [`Actor`]s in [`ActorManager`] is responsible for sampling -//! using CPU. -//! -//! Both [`AsyncTrainer`] and [`ActorManager`] are running in the same machine and -//! communicate by channels. -//! -//! [`Agent`]: border_core::Agent -//! [`Env`]: border_core::Env +#![doc = include_str!("../README.md")] mod actor; mod actor_manager; mod async_trainer; @@ -198,7 +39,7 @@ pub mod test { obs: Vec, } - impl border_core::generic_replay_buffer::BatchBase for TestObsBatch { + impl border_generic_replay_buffer::BatchBase for TestObsBatch { fn new(capacity: usize) -> Self { Self { obs: vec![0; capacity], @@ -240,7 +81,7 @@ pub mod test { } } - impl border_core::generic_replay_buffer::BatchBase for TestActBatch { + impl border_generic_replay_buffer::BatchBase for TestActBatch { fn new(capacity: usize) -> Self { Self { act: vec![0; capacity], @@ -337,7 +178,7 @@ pub mod test { } type ReplayBuffer = - border_core::generic_replay_buffer::SimpleReplayBuffer; + border_generic_replay_buffer::GenericReplayBuffer; /// Agent for testing. pub struct TestAgent {} diff --git a/border-async-trainer/src/messages.rs b/border-async-trainer/src/messages.rs index 32070a91..dd26a6ed 100644 --- a/border-async-trainer/src/messages.rs +++ b/border-async-trainer/src/messages.rs @@ -1,4 +1,4 @@ -/// Message containing a [`ReplayBufferBase`](border_core::ReplayBufferBase)`::Item`. +/// Message containing a [`ReplayBuffer`](border_core::ReplayBuffer)`::Item`. /// /// It will be sent from [`Actor`](crate::Actor) to [`ActorManager`](crate::ActorManager). pub struct PushedItemMessage { diff --git a/border-async-trainer/src/replay_buffer_proxy.rs b/border-async-trainer/src/replay_buffer_proxy.rs index 263c5beb..1a34758e 100644 --- a/border-async-trainer/src/replay_buffer_proxy.rs +++ b/border-async-trainer/src/replay_buffer_proxy.rs @@ -1,6 +1,6 @@ use crate::PushedItemMessage; use anyhow::Result; -use border_core::{ExperienceBufferBase, ReplayBufferBase}; +use border_core::{ExperienceBuffer, ReplayBuffer}; use crossbeam_channel::Sender; use std::marker::PhantomData; @@ -14,7 +14,7 @@ pub struct ReplayBufferProxyConfig { } /// A wrapper of replay buffer for asynchronous trainer. -pub struct ReplayBufferProxy { +pub struct ReplayBufferProxy { id: usize, /// Sender of [PushedItemMessage]. @@ -29,7 +29,7 @@ pub struct ReplayBufferProxy { phantom: PhantomData, } -impl ReplayBufferProxy { +impl ReplayBufferProxy { pub fn build_with_sender( id: usize, config: &ReplayBufferProxyConfig, @@ -46,7 +46,7 @@ impl ReplayBufferProxy { } } -impl ExperienceBufferBase for ReplayBufferProxy { +impl ExperienceBuffer for ReplayBufferProxy { type Item = R::Item; fn push(&mut self, tr: Self::Item) -> Result<()> { @@ -76,7 +76,7 @@ impl ExperienceBufferBase for ReplayBufferProxy { } } -impl ReplayBufferBase for ReplayBufferProxy { +impl ReplayBuffer for ReplayBufferProxy { type Config = ReplayBufferProxyConfig; type Batch = R::Batch; diff --git a/border-async-trainer/src/util.rs b/border-async-trainer/src/util.rs index eff40041..5c4a6d6e 100644 --- a/border-async-trainer/src/util.rs +++ b/border-async-trainer/src/util.rs @@ -3,7 +3,7 @@ use crate::{ actor_stats_fmt, ActorManager, ActorManagerConfig, AsyncTrainer, AsyncTrainerConfig, SyncModel, }; use border_core::{ - record::Recorder, Agent, Configurable, Env, Evaluator, ExperienceBufferBase, ReplayBufferBase, + record::Recorder, Agent, Configurable, Env, Evaluator, ExperienceBuffer, ReplayBuffer, StepProcessor, }; use crossbeam_channel::unbounded; @@ -42,7 +42,7 @@ pub fn train_async( ) where A: Agent + Configurable + SyncModel + 'static, E: Env, - R: ExperienceBufferBase + Send + 'static + ReplayBufferBase, + R: ExperienceBuffer + Send + 'static + ReplayBuffer, S: StepProcessor, A::Config: Send + 'static, E::Config: Send + 'static, diff --git a/border-atari-env/Cargo.toml b/border-atari-env/Cargo.toml index 6b1a1273..26ea4e66 100644 --- a/border-atari-env/Cargo.toml +++ b/border-atari-env/Cargo.toml @@ -18,6 +18,7 @@ border-core = { version = "0.0.9", path = "../border-core" } image = { workspace = true } tch = { workspace = true, optional = true } border-tch-agent = { version = "0.0.9", path = "../border-tch-agent", optional = true } +border-generic-replay-buffer = { version = "0.0.9", path = "../border-generic-replay-buffer" } candle-core = { workspace = true, optional = true } serde = { workspace = true, features = ["derive"] } itertools = "0.10.1" diff --git a/border-atari-env/README.md b/border-atari-env/README.md index e69de29b..70e864a4 100644 --- a/border-atari-env/README.md +++ b/border-atari-env/README.md @@ -0,0 +1,86 @@ +A thin wrapper of [`atari-env`](https://crates.io/crates/atari-env) for [`Border`](https://crates.io/crates/border). + +The code under [atari_env] is adapted from the +[`atari-env`](https://crates.io/crates/atari-env) crate +(rev = `0ef0422f953d79e96b32ad14284c9600bd34f335`), +because the crate registered in crates.io does not implement +[`atari_env::AtariEnv::lives()`] method, which is required for episodic life environments. + +This environment applies some preprocessing to observation as in +[`atari_wrapper.py`](https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py). + +You need to place Atari Rom directories under the directory specified by environment variable +`ATARI_ROM_DIR`. An easy way to do this is to use [AutoROM](https://pypi.org/project/AutoROM/) +Python package. + +```bash +pip install autorom +mkdir $HOME/atari_rom +AutoROM --install-dir $HOME/atari_rom +export ATARI_ROM_DIR=$HOME/atari_rom +``` + +Here is an example of running Pong environment with a random policy. + +```no_run +use anyhow::Result; +use border_atari_env::{ + BorderAtariAct, BorderAtariActRawFilter, BorderAtariEnv, BorderAtariEnvConfig, + BorderAtariObs, BorderAtariObsRawFilter, +}; +use border_core::{Env as _, Policy, DefaultEvaluator, Evaluator as _, NullReplayBuffer, Agent}; + +# type Obs = BorderAtariObs; +# type Act = BorderAtariAct; +# type ObsFilter = BorderAtariObsRawFilter; +# type ActFilter = BorderAtariActRawFilter; +# type EnvConfig = BorderAtariEnvConfig; +# type Env = BorderAtariEnv; +# +# #[derive(Clone)] +# struct RandomPolicyConfig { +# pub n_acts: usize, +# } +# +# struct RandomPolicy { +# n_acts: usize, +# } +# +# impl RandomPolicy { +# pub fn build(n_acts: usize) -> Self { +# Self { n_acts } +# } +# } +# +# impl Policy for RandomPolicy { +# fn sample(&mut self, _: &Obs) -> Act { +# fastrand::u8(..self.n_acts as u8).into() +# } +# } +# +# impl Agent for RandomPolicy {} +# +# fn env_config(name: String) -> EnvConfig { +# EnvConfig::default().name(name) +# } +# +fn main() -> Result<()> { +# env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init(); +# fastrand::seed(42); +# + // Creates Pong environment + let env_config = env_config("pong".to_string()); + + // Creates a random policy + let n_acts = 4; + let mut policy = Box::new(RandomPolicy::build(n_acts)) as _; + + // Runs evaluation + let env_config = env_config.render(true); + let _ = DefaultEvaluator::new(&env_config, 42, 5)? + .evaluate(&mut policy); + + Ok(()) +} +``` +[`atari_env::AtariEnv::lives()`]: atari_env::AtariEnv::lives diff --git a/border-atari-env/src/lib.rs b/border-atari-env/src/lib.rs index 7a71c57d..6fbbaa92 100644 --- a/border-atari-env/src/lib.rs +++ b/border-atari-env/src/lib.rs @@ -1,89 +1,4 @@ -//! A thin wrapper of [`atari-env`](https://crates.io/crates/atari-env) for [`Border`](https://crates.io/crates/border). -//! -//! The code under [atari_env] is adapted from the -//! [`atari-env`](https://crates.io/crates/atari-env) crate -//! (rev = `0ef0422f953d79e96b32ad14284c9600bd34f335`), -//! because the crate registered in crates.io does not implement -//! [`atari_env::AtariEnv::lives()`] method, which is required for episodic life environments. -//! -//! This environment applies some preprocessing to observation as in -//! [`atari_wrapper.py`](https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py). -//! -//! You need to place Atari Rom directories under the directory specified by environment variable -//! `ATARI_ROM_DIR`. An easy way to do this is to use [AutoROM](https://pypi.org/project/AutoROM/) -//! Python package. -//! -//! ```bash -//! pip install autorom -//! mkdir $HOME/atari_rom -//! AutoROM --install-dir $HOME/atari_rom -//! export ATARI_ROM_DIR=$HOME/atari_rom -//! ``` -//! -//! Here is an example of running Pong environment with a random policy. -//! -//! ```no_run -//! use anyhow::Result; -//! use border_atari_env::{ -//! BorderAtariAct, BorderAtariActRawFilter, BorderAtariEnv, BorderAtariEnvConfig, -//! BorderAtariObs, BorderAtariObsRawFilter, -//! }; -//! use border_core::{Env as _, Policy, DefaultEvaluator, Evaluator as _, NullReplayBuffer, Agent}; -//! -//! # type Obs = BorderAtariObs; -//! # type Act = BorderAtariAct; -//! # type ObsFilter = BorderAtariObsRawFilter; -//! # type ActFilter = BorderAtariActRawFilter; -//! # type EnvConfig = BorderAtariEnvConfig; -//! # type Env = BorderAtariEnv; -//! # -//! # #[derive(Clone)] -//! # struct RandomPolicyConfig { -//! # pub n_acts: usize, -//! # } -//! # -//! # struct RandomPolicy { -//! # n_acts: usize, -//! # } -//! # -//! # impl RandomPolicy { -//! # pub fn build(n_acts: usize) -> Self { -//! # Self { n_acts } -//! # } -//! # } -//! # -//! # impl Policy for RandomPolicy { -//! # fn sample(&mut self, _: &Obs) -> Act { -//! # fastrand::u8(..self.n_acts as u8).into() -//! # } -//! # } -//! # -//! # impl Agent for RandomPolicy {} -//! # -//! # fn env_config(name: String) -> EnvConfig { -//! # EnvConfig::default().name(name) -//! # } -//! # -//! fn main() -> Result<()> { -//! # env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init(); -//! # fastrand::seed(42); -//! # -//! // Creates Pong environment -//! let env_config = env_config("pong".to_string()); -//! -//! // Creates a random policy -//! let n_acts = 4; -//! let mut policy = Box::new(RandomPolicy::build(n_acts)) as _; -//! -//! // Runs evaluation -//! let env_config = env_config.render(true); -//! let _ = DefaultEvaluator::new(&env_config, 42, 5)? -//! .evaluate(&mut policy); -//! -//! Ok(()) -//! } -//! ``` -//! [`atari_env::AtariEnv::lives()`]: atari_env::AtariEnv::lives +#![doc = include_str!("../README.md")] mod act; pub mod atari_env; mod env; diff --git a/border-atari-env/src/util/test.rs b/border-atari-env/src/util/test.rs index 350d9af0..e8b7d5c7 100644 --- a/border-atari-env/src/util/test.rs +++ b/border-atari-env/src/util/test.rs @@ -5,10 +5,10 @@ use crate::{ }; use anyhow::Result; use border_core::{ - generic_replay_buffer::{BatchBase, SimpleReplayBuffer}, record::Record, - Agent as Agent_, Configurable, Policy, ReplayBufferBase, + Agent as Agent_, Configurable, Policy, ReplayBuffer as ReplayBuffer_, }; +use border_generic_replay_buffer::{BatchBase, GenericReplayBuffer}; use serde::Deserialize; use std::ptr::copy; @@ -17,13 +17,13 @@ pub type Act = BorderAtariAct; pub type ObsFilter = BorderAtariObsRawFilter; pub type ActFilter = BorderAtariActRawFilter; pub type EnvConfig = BorderAtariEnvConfig; -pub type ReplayBuffer = SimpleReplayBuffer; +pub type ReplayBuffer = GenericReplayBuffer; pub type Env = BorderAtariEnv; pub type Agent = RandomAgent; const FRAME_IN_BYTES: usize = 84 * 84; -/// Consists the observation part of a batch in [SimpleReplayBuffer]. +/// Consists the observation part of a batch in [`GenericReplayBuffer`]. pub struct ObsBatch { /// The number of samples in the batch. pub n: usize, @@ -78,7 +78,7 @@ impl From for ObsBatch { } } -/// Consists the action part of a batch in [SimpleReplayBuffer]. +/// Consists the action part of a batch in [`GenericReplayBuffer`]. pub struct ActBatch { /// The number of samples in the batch. pub n: usize, @@ -164,7 +164,7 @@ impl Configurable for RandomAgent { } } -impl Agent_ for RandomAgent { +impl Agent_ for RandomAgent { fn train(&mut self) { self.train = true; } diff --git a/border-candle-agent/Cargo.toml b/border-candle-agent/Cargo.toml index f1a2306a..4a09755f 100644 --- a/border-candle-agent/Cargo.toml +++ b/border-candle-agent/Cargo.toml @@ -11,6 +11,7 @@ readme = "README.md" [dependencies] border-core = { version = "0.0.9", path = "../border-core" } +border-generic-replay-buffer = { version = "0.0.9", path = "../border-generic-replay-buffer" } border-async-trainer = { version = "0.0.9", path = "../border-async-trainer", optional = true } serde = { workspace = true, features = ["derive"] } serde_yaml = { workspace = true } diff --git a/border-candle-agent/src/awac/base.rs b/border-candle-agent/src/awac/base.rs index 23dc7241..564938ca 100644 --- a/border-candle-agent/src/awac/base.rs +++ b/border-candle-agent/src/awac/base.rs @@ -9,7 +9,7 @@ use crate::{ use anyhow::Result; use border_core::{ record::{Record, RecordValue}, - Agent, Configurable, Env, Policy, ReplayBufferBase, TransitionBatch, + Agent, Configurable, Env, Policy, ReplayBuffer, TransitionBatch, }; use candle_core::{Device, Tensor, D}; use candle_nn::{loss::mse, ops::softmax}; @@ -53,7 +53,7 @@ where E: Env, Q: SubModel2, P: SubModel1, - R: ReplayBufferBase, + R: ReplayBuffer, E::Obs: Into + Into, E::Act: Into + Into, Q::Input2: From + Into, @@ -282,7 +282,7 @@ where E: Env + 'static, Q: SubModel2 + 'static, P: SubModel1 + 'static, - R: ReplayBufferBase + 'static, + R: ReplayBuffer + 'static, E::Obs: Into + Into, E::Act: Into + Into + From, Q::Input2: From + Into, diff --git a/border-candle-agent/src/bc/base.rs b/border-candle-agent/src/bc/base.rs index 93f60b6d..3060c8ce 100644 --- a/border-candle-agent/src/bc/base.rs +++ b/border-candle-agent/src/bc/base.rs @@ -4,7 +4,7 @@ use crate::{model::SubModel1, util::OutDim}; use anyhow::Result; use border_core::{ record::{Record, RecordValue}, - Agent, Configurable, Env, Policy, ReplayBufferBase, TransitionBatch, + Agent, Configurable, Env, Policy, ReplayBuffer, TransitionBatch, }; use candle_core::{shape::D, DType, Device, Tensor}; use candle_nn::loss::mse; @@ -92,7 +92,7 @@ impl Agent for Bc where E: Env, P: SubModel1, - R: ReplayBufferBase, + R: ReplayBuffer, E::Obs: Into, E::Act: From, P::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone, @@ -157,7 +157,7 @@ impl Bc where E: Env, P: SubModel1, - R: ReplayBufferBase, + R: ReplayBuffer, P::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone, R::Batch: TransitionBatch, ::ObsBatch: Into, diff --git a/border-candle-agent/src/dqn/base.rs b/border-candle-agent/src/dqn/base.rs index 894a061e..9186c0c2 100644 --- a/border-candle-agent/src/dqn/base.rs +++ b/border-candle-agent/src/dqn/base.rs @@ -7,7 +7,7 @@ use crate::{ use anyhow::Result; use border_core::{ record::{Record, RecordValue}, - Agent, Configurable, Env, Policy, ReplayBufferBase, TransitionBatch, + Agent, Configurable, Env, Policy, ReplayBuffer, TransitionBatch, }; use candle_core::{shape::D, DType, Device, Tensor}; use candle_nn::loss::mse; @@ -50,7 +50,7 @@ impl Dqn where E: Env, Q: SubModel1, - R: ReplayBufferBase, + R: ReplayBuffer, Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone, R::Batch: TransitionBatch, ::ObsBatch: Into, @@ -280,7 +280,7 @@ impl Agent for Dqn where E: Env + 'static, Q: SubModel1 + 'static, - R: ReplayBufferBase + 'static, + R: ReplayBuffer + 'static, E::Obs: Into, E::Act: From, Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone, @@ -367,7 +367,7 @@ impl SyncModel for Dqn where E: Env, Q: SubModel1, - R: ReplayBufferBase, + R: ReplayBuffer, E::Obs: Into, E::Act: From, Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone, diff --git a/border-candle-agent/src/iql/base.rs b/border-candle-agent/src/iql/base.rs index 0738a3ab..85b13eff 100644 --- a/border-candle-agent/src/iql/base.rs +++ b/border-candle-agent/src/iql/base.rs @@ -9,7 +9,7 @@ use crate::{ use anyhow::Result; use border_core::{ record::{Record, RecordValue}, - Agent, Configurable, Env, Policy, ReplayBufferBase, TransitionBatch, + Agent, Configurable, Env, Policy, ReplayBuffer, TransitionBatch, }; use candle_core::{Device, Tensor, D}; use candle_nn::{loss::mse, ops::softmax}; @@ -59,7 +59,7 @@ where Q: SubModel2, P: SubModel1, V: SubModel1, - R: ReplayBufferBase, + R: ReplayBuffer, E::Obs: Into, E::Act: Into, A: Clone, @@ -259,7 +259,7 @@ where Q: SubModel2 + 'static, P: SubModel1 + 'static, V: SubModel1 + 'static, - R: ReplayBufferBase + 'static, + R: ReplayBuffer + 'static, E::Obs: Into + Into, E::Act: Into + From, O: 'static, diff --git a/border-candle-agent/src/sac/base.rs b/border-candle-agent/src/sac/base.rs index 214f7190..51b2be61 100644 --- a/border-candle-agent/src/sac/base.rs +++ b/border-candle-agent/src/sac/base.rs @@ -9,7 +9,7 @@ use crate::{ use anyhow::Result; use border_core::{ record::{Record, RecordValue}, - Agent, Configurable, Env, Policy, ReplayBufferBase, TransitionBatch, + Agent, Configurable, Env, Policy, ReplayBuffer, TransitionBatch, }; use candle_core::{Device, Tensor, D}; use candle_nn::loss::mse; @@ -50,7 +50,7 @@ where E: Env, Q: SubModel2, P: SubModel1, - R: ReplayBufferBase, + R: ReplayBuffer, E::Obs: Into + Into, E::Act: Into, Q::Input2: From, @@ -215,7 +215,7 @@ where E: Env + 'static, Q: SubModel2 + 'static, P: SubModel1 + 'static, - R: ReplayBufferBase + 'static, + R: ReplayBuffer + 'static, E::Obs: Into + Into, E::Act: Into + From, Q::Input2: From, diff --git a/border-candle-agent/src/tensor_batch.rs b/border-candle-agent/src/tensor_batch.rs index ef17e183..c5038038 100644 --- a/border-candle-agent/src/tensor_batch.rs +++ b/border-candle-agent/src/tensor_batch.rs @@ -1,4 +1,4 @@ -use border_core::generic_replay_buffer::BatchBase; +use border_generic_replay_buffer::BatchBase; use candle_core::{error::Result, DType, Device, IndexOp, Tensor}; /// Adds capability of constructing [`Tensor`] with a static method. diff --git a/border-core/README.md b/border-core/README.md index e69de29b..c010bb21 100644 --- a/border-core/README.md +++ b/border-core/README.md @@ -0,0 +1,71 @@ +Core components for reinforcement learning. + +# Observation and Action + +The [`Obs`] and [`Act`] traits provide abstractions for observations and actions in environments. + +# Environment + +The [`Env`] trait serves as the fundamental abstraction for environments. It defines four associated types: +`Config`, `Obs`, `Act`, and `Info`. The `Obs` and `Act` types represent concrete implementations of +environment observations and actions, respectively. These types must implement the [`Obs`] and [`Act`] traits. +Environments implementing [`Env`] generate [`Step`] objects at each interaction step through the +[`Env::step()`] method. The [`Info`] type stores additional information from each agent-environment interaction, +which may be empty (implemented as a zero-sized struct). The `Config` type represents environment configurations +and is used during environment construction. + +# Policy + +The [`Policy`] trait represents a decision-making policy. The [`Policy::sample()`] method takes an +`E::Obs` and generates an `E::Act`. Policies can be either probabilistic or deterministic, depending on the +implementation. + +# Agent + +In this crate, an [`Agent`] is defined as a trainable [`Policy`]. +Agents operate in either training or evaluation mode. During training, the agent's policy may be probabilistic +to facilitate exploration, while in evaluation mode, it typically becomes deterministic. + +The [`Agent::opt()`] method executes a single optimization step. The specific implementation of an optimization +step varies between agents and may include multiple stochastic gradient descent steps. Training samples are +obtained from the [`ReplayBuffer`]. + +This trait also provides methods for saving and loading trained policy parameters to and from a directory. + +# Batch + +The [`TransitionBatch`] trait represents a batch of transitions in the form `(o_t, r_t, a_t, o_t+1)`. +This trait is used for training [`Agent`]s using reinforcement learning algorithms. + +# Replay Buffer and Experience Buffer + +The [`ReplayBuffer`] trait provides an abstraction for replay buffers. Its associated type +[`ReplayBuffer::Batch`] represents samples used for training [`Agent`]s. Agents must implement the +[`Agent::opt()`] method, where [`ReplayBuffer::Batch`] must have appropriate type or trait bounds +for training the agent. + +While [`ReplayBuffer`] focuses on generating training batches, the [`ExperienceBuffer`] trait +handles sample storage. The [`ExperienceBuffer::push()`] method stores samples of type +[`ExperienceBuffer::Item`], typically obtained through environment interactions. + +# Step Processor + +The [`StepProcessor`] trait plays a crucial role in the training pipeline by transforming environment +interactions into training samples. It processes [`Step`] objects, which contain the current +observation, action, reward, and next observation, into a format suitable for the replay buffer. + +# Trainer + +The [`Trainer`] manages the training loop and related objects. A [`Trainer`] instance is configured with +training parameters such as the maximum number of optimization steps and the directory for saving agent +parameters during training. The [`Trainer::train`] method executes online training of an agent in an environment. +During the training loop, the agent interacts with the environment to collect samples and perform optimization +steps, while simultaneously recording various metrics. + +# Evaluator + +The [`Evaluator`] trait is used to evaluate a policy's (`P`) performance in an environment (`E`). +An instance of this type is provided to the [`Trainer`] for policy evaluation during training. +[`DefaultEvaluator`] serves as the default implementation of [`Evaluator`]. This evaluator +runs the policy in the environment for a specified number of episodes. At the start of each episode, +the environment is reset using [`Env::reset_with_index()`] to control specific evaluation conditions. diff --git a/border-core/src/base.rs b/border-core/src/base.rs index c73ed89f..f4585ca6 100644 --- a/border-core/src/base.rs +++ b/border-core/src/base.rs @@ -15,7 +15,7 @@ pub use agent::Agent; pub use batch::TransitionBatch; pub use env::Env; pub use policy::{Configurable, Policy}; -pub use replay_buffer::{ExperienceBufferBase, NullReplayBuffer, ReplayBufferBase}; +pub use replay_buffer::{ExperienceBuffer, NullReplayBuffer, ReplayBuffer}; use std::fmt::Debug; pub use step::{Info, Step, StepProcessor}; diff --git a/border-core/src/base/agent.rs b/border-core/src/base/agent.rs index ad2a8ecc..c1e54bb8 100644 --- a/border-core/src/base/agent.rs +++ b/border-core/src/base/agent.rs @@ -3,7 +3,7 @@ //! The [`Agent`] trait extends [`Policy`] with training capabilities, allowing the policy to //! learn from interactions with the environment. It provides methods for training, evaluation, //! parameter optimization, and model persistence. -use super::{Env, Policy, ReplayBufferBase}; +use super::{Env, Policy, ReplayBuffer}; use crate::record::Record; use anyhow::Result; use std::path::{Path, PathBuf}; @@ -21,7 +21,7 @@ use std::path::{Path, PathBuf}; /// /// During training, the agent uses a replay buffer to store and sample experiences, /// which are then used to update the policy's parameters through optimization steps. -pub trait Agent: Policy { +pub trait Agent: Policy { /// Switches the agent to training mode. /// /// In training mode, the policy may become stochastic to facilitate exploration. diff --git a/border-core/src/base/replay_buffer.rs b/border-core/src/base/replay_buffer.rs index 680be7a6..72aa9005 100644 --- a/border-core/src/base/replay_buffer.rs +++ b/border-core/src/base/replay_buffer.rs @@ -22,7 +22,7 @@ use anyhow::Result; /// items: Vec, /// } /// -/// impl ExperienceBufferBase for SimpleBuffer { +/// impl ExperienceBuffer for SimpleBuffer { /// type Item = T; /// /// fn push(&mut self, tr: T) -> Result<()> { @@ -35,7 +35,7 @@ use anyhow::Result; /// } /// } /// ``` -pub trait ExperienceBufferBase { +pub trait ExperienceBuffer { /// The type of items stored in the buffer. /// /// This can be any type that represents an experience or transition @@ -64,14 +64,14 @@ pub trait ExperienceBufferBase { /// Interface for replay buffers that generate batches for training. /// /// This trait provides functionality for sampling batches of experiences -/// for training agents. It is independent of [`ExperienceBufferBase`] and +/// for training agents. It is independent of [`ExperienceBuffer`] and /// focuses solely on the batch generation process. /// /// # Associated Types /// /// * `Config` - Configuration parameters for the buffer /// * `Batch` - The type of batch generated for training -pub trait ReplayBufferBase { +pub trait ReplayBuffer { /// Configuration parameters for the replay buffer. /// /// This type must implement `Clone` to support building multiple instances @@ -131,7 +131,7 @@ pub trait ReplayBufferBase { /// This struct is used as a placeholder when a replay buffer is not needed. pub struct NullReplayBuffer; -impl ReplayBufferBase for NullReplayBuffer { +impl ReplayBuffer for NullReplayBuffer { type Batch = (); type Config = (); diff --git a/border-core/src/dummy.rs b/border-core/src/dummy.rs index 9de88b23..67b55b17 100644 --- a/border-core/src/dummy.rs +++ b/border-core/src/dummy.rs @@ -90,7 +90,7 @@ impl crate::TransitionBatch for DummyBatch { /// Dummy replay buffer. pub struct DummyReplayBuffer; -impl crate::ReplayBufferBase for DummyReplayBuffer { +impl crate::ReplayBuffer for DummyReplayBuffer { type Batch = DummyBatch; type Config = usize; diff --git a/border-core/src/evaluator.rs b/border-core/src/evaluator.rs index f241830c..95f4ab79 100644 --- a/border-core/src/evaluator.rs +++ b/border-core/src/evaluator.rs @@ -8,7 +8,7 @@ //! - Monitor training progress //! - Validate the generalization of learned policies -use crate::{record::Record, Agent, Env, ReplayBufferBase}; +use crate::{record::Record, Agent, Env, ReplayBuffer}; use anyhow::Result; mod default_evaluator; pub use default_evaluator::DefaultEvaluator; @@ -36,7 +36,7 @@ pub use default_evaluator::DefaultEvaluator; /// impl Evaluator for CustomEvaluator { /// fn evaluate(&mut self, agent: &mut Box>) -> Result /// where -/// R: ReplayBufferBase, +/// R: ReplayBuffer, /// { /// // Custom evaluation logic /// // ... @@ -79,5 +79,5 @@ pub trait Evaluator { /// [`Trainer`]: crate::Trainer fn evaluate(&mut self, agent: &mut Box>) -> Result<(f32, Record)> where - R: ReplayBufferBase; + R: ReplayBuffer; } diff --git a/border-core/src/evaluator/default_evaluator.rs b/border-core/src/evaluator/default_evaluator.rs index 10b41719..ac1fd0e9 100644 --- a/border-core/src/evaluator/default_evaluator.rs +++ b/border-core/src/evaluator/default_evaluator.rs @@ -4,7 +4,7 @@ //! and calculates the average return across all episodes. use super::Evaluator; -use crate::{record::Record, Agent, Env, ReplayBufferBase}; +use crate::{record::Record, Agent, Env, ReplayBuffer}; use anyhow::Result; /// A default implementation of the [`Evaluator`] trait. @@ -63,7 +63,7 @@ impl Evaluator for DefaultEvaluator { /// - The environment fails to step fn evaluate(&mut self, policy: &mut Box>) -> Result<(f32, Record)> where - R: ReplayBufferBase, + R: ReplayBuffer, { let mut r_total = 0f32; diff --git a/border-core/src/generic_replay_buffer/base.rs b/border-core/src/generic_replay_buffer/base.rs index 14b3a623..d9dda3dc 100644 --- a/border-core/src/generic_replay_buffer/base.rs +++ b/border-core/src/generic_replay_buffer/base.rs @@ -9,7 +9,7 @@ mod iw_scheduler; mod sum_tree; use super::{config::PerConfig, BatchBase, GenericTransitionBatch, SimpleReplayBufferConfig}; -use crate::{ExperienceBufferBase, ReplayBufferBase, TransitionBatch}; +use crate::{ExperienceBuffer, ReplayBuffer, TransitionBatch}; use anyhow::Result; pub use iw_scheduler::IwScheduler; use rand::{rngs::StdRng, RngCore, SeedableRng}; @@ -267,7 +267,7 @@ where } } -impl ExperienceBufferBase for SimpleReplayBuffer +impl ExperienceBuffer for SimpleReplayBuffer where O: BatchBase, A: BatchBase, @@ -316,7 +316,7 @@ where } } -impl ReplayBufferBase for SimpleReplayBuffer +impl ReplayBuffer for SimpleReplayBuffer where O: BatchBase, A: BatchBase, diff --git a/border-core/src/lib.rs b/border-core/src/lib.rs index e374d79e..b924f2f3 100644 --- a/border-core/src/lib.rs +++ b/border-core/src/lib.rs @@ -1,329 +1,16 @@ +#![doc = include_str!("../README.md")] #![warn(missing_docs)] -//! Core components for reinforcement learning. -//! -//! # Observation and Action -//! -//! The [`Obs`] and [`Act`] traits provide abstractions for observations and actions in environments. -//! -//! # Environment -//! -//! The [`Env`] trait serves as the fundamental abstraction for environments. It defines four associated types: -//! `Config`, `Obs`, `Act`, and `Info`. The `Obs` and `Act` types represent concrete implementations of -//! environment observations and actions, respectively. These types must implement the [`Obs`] and [`Act`] traits. -//! Environments implementing [`Env`] generate [`Step`] objects at each interaction step through the -//! [`Env::step()`] method. The [`Info`] type stores additional information from each agent-environment interaction, -//! which may be empty (implemented as a zero-sized struct). The `Config` type represents environment configurations -//! and is used during environment construction. -//! -//! # Policy -//! -//! The [`Policy`] trait represents a decision-making policy. The [`Policy::sample()`] method takes an -//! `E::Obs` and generates an `E::Act`. Policies can be either probabilistic or deterministic, depending on the -//! implementation. -//! -//! # Agent -//! -//! In this crate, an [`Agent`] is defined as a trainable [`Policy`]. -//! Agents operate in either training or evaluation mode. During training, the agent's policy may be probabilistic -//! to facilitate exploration, while in evaluation mode, it typically becomes deterministic. -//! -//! The [`Agent::opt()`] method executes a single optimization step. The specific implementation of an optimization -//! step varies between agents and may include multiple stochastic gradient descent steps. Training samples are -//! obtained from the [`ReplayBufferBase`]. -//! -//! This trait also provides methods for saving and loading trained policy parameters to and from a directory. -//! -//! # Batch -//! -//! The [`TransitionBatch`] trait represents a batch of transitions in the form `(o_t, r_t, a_t, o_t+1)`. -//! This trait is used for training [`Agent`]s using reinforcement learning algorithms. -//! -//! # Replay Buffer and Experience Buffer -//! -//! The [`ReplayBufferBase`] trait provides an abstraction for replay buffers. Its associated type -//! [`ReplayBufferBase::Batch`] represents samples used for training [`Agent`]s. Agents must implement the -//! [`Agent::opt()`] method, where [`ReplayBufferBase::Batch`] must have appropriate type or trait bounds -//! for training the agent. -//! -//! While [`ReplayBufferBase`] focuses on generating training batches, the [`ExperienceBufferBase`] trait -//! handles sample storage. The [`ExperienceBufferBase::push()`] method stores samples of type -//! [`ExperienceBufferBase::Item`], typically obtained through environment interactions. -//! -//! ## Reference Implementation -//! -//! [`SimpleReplayBuffer`] implements both [`ReplayBufferBase`] and [`ExperienceBufferBase`]. -//! This type takes two parameters, `O` and `A`, representing observation and action types in the replay buffer. -//! Both `O` and `A` must implement [`BatchBase`], which provides sample storage functionality similar to `Vec`. -//! The associated types `Item` and `Batch` are both [`GenericTransitionBatch`], representing sets of -//! `(o_t, r_t, a_t, o_t+1)` transitions. -//! -//! # Step Processor -//! -//! The [`StepProcessor`] trait plays a crucial role in the training pipeline by transforming environment -//! interactions into training samples. It processes [`Step`] objects, which contain the current -//! observation, action, reward, and next observation, into a format suitable for the replay buffer. -//! -//! The [`SimpleStepProcessor`] is a concrete implementation that: -//! 1. Maintains the previous observation to construct complete transitions -//! 2. Converts environment-specific observations and actions (`E::Obs` and `E::Act`) into batch-compatible -//! types (`O` and `A`) using the `From` trait -//! 3. Generates [`GenericTransitionBatch`] objects containing the complete transition -//! `(o_t, a_t, o_t+1, r_t, is_terminated, is_truncated)` -//! 4. Handles episode termination by properly resetting the previous observation -//! -//! This processor is essential for implementing temporal difference learning algorithms, as it ensures -//! that transitions are properly formatted and stored in the replay buffer for training. -//! -//! [`SimpleStepProcessor`] can be used with [`SimpleReplayBuffer`]. It converts `E::Obs` and -//! `E::Act` into their respective [`BatchBase`] types and generates [`GenericTransitionBatch`]. This conversion -//! relies on the trait bounds `O: From` and `A: From`. -//! -//! # Trainer -//! -//! The [`Trainer`] manages the training loop and related objects. A [`Trainer`] instance is configured with -//! training parameters such as the maximum number of optimization steps and the directory for saving agent -//! parameters during training. The [`Trainer::train`] method executes online training of an agent in an environment. -//! During the training loop, the agent interacts with the environment to collect samples and perform optimization -//! steps, while simultaneously recording various metrics. -//! -//! # Evaluator -//! -//! The [`Evaluator`] trait is used to evaluate a policy's (`P`) performance in an environment (`E`). -//! An instance of this type is provided to the [`Trainer`] for policy evaluation during training. -//! [`DefaultEvaluator`] serves as the default implementation of [`Evaluator`]. This evaluator -//! runs the policy in the environment for a specified number of episodes. At the start of each episode, -//! the environment is reset using [`Env::reset_with_index()`] to control specific evaluation conditions. -//! -//! [`SimpleReplayBuffer`]: generic_replay_buffer::SimpleReplayBuffer -//! [`SimpleReplayBuffer`]: generic_replay_buffer::SimpleReplayBuffer -//! [`BatchBase`]: generic_replay_buffer::BatchBase -//! [`GenericTransitionBatch`]: generic_replay_buffer::GenericTransitionBatch -//! [`SimpleStepProcessor`]: generic_replay_buffer::SimpleStepProcessor -//! [`SimpleStepProcessor`]: generic_replay_buffer::SimpleStepProcessor pub mod dummy; pub mod error; mod evaluator; -pub mod generic_replay_buffer; pub mod record; mod base; pub use base::{ - Act, Agent, Configurable, Env, ExperienceBufferBase, Info, NullReplayBuffer, Obs, Policy, - ReplayBufferBase, Step, StepProcessor, TransitionBatch, + Act, Agent, Configurable, Env, ExperienceBuffer, Info, NullReplayBuffer, Obs, Policy, + ReplayBuffer, Step, StepProcessor, TransitionBatch, }; mod trainer; pub use evaluator::{DefaultEvaluator, Evaluator}; pub use trainer::{Sampler, Trainer, TrainerConfig}; - -// TODO: Consider to compile this module only for tests. -/// Agent and Env for testing. -pub mod test { - use serde::{Deserialize, Serialize}; - - /// Obs for testing. - #[derive(Clone, Debug)] - pub struct TestObs { - obs: usize, - } - - impl crate::Obs for TestObs { - fn len(&self) -> usize { - 1 - } - } - - /// Batch of obs for testing. - pub struct TestObsBatch { - obs: Vec, - } - - impl crate::generic_replay_buffer::BatchBase for TestObsBatch { - fn new(capacity: usize) -> Self { - Self { - obs: vec![0; capacity], - } - } - - fn push(&mut self, i: usize, data: Self) { - self.obs[i] = data.obs[0]; - } - - fn sample(&self, ixs: &Vec) -> Self { - let obs = ixs.iter().map(|ix| self.obs[*ix]).collect(); - Self { obs } - } - } - - impl From for TestObsBatch { - fn from(obs: TestObs) -> Self { - Self { obs: vec![obs.obs] } - } - } - - /// Act for testing. - #[derive(Clone, Debug)] - pub struct TestAct { - act: usize, - } - - impl crate::Act for TestAct {} - - /// Batch of act for testing. - pub struct TestActBatch { - act: Vec, - } - - impl From for TestActBatch { - fn from(act: TestAct) -> Self { - Self { act: vec![act.act] } - } - } - - impl crate::generic_replay_buffer::BatchBase for TestActBatch { - fn new(capacity: usize) -> Self { - Self { - act: vec![0; capacity], - } - } - - fn push(&mut self, i: usize, data: Self) { - self.act[i] = data.act[0]; - } - - fn sample(&self, ixs: &Vec) -> Self { - let act = ixs.iter().map(|ix| self.act[*ix]).collect(); - Self { act } - } - } - - /// Info for testing. - pub struct TestInfo {} - - impl crate::Info for TestInfo {} - - /// Environment for testing. - pub struct TestEnv { - state_init: usize, - state: usize, - } - - impl crate::Env for TestEnv { - type Config = usize; - type Obs = TestObs; - type Act = TestAct; - type Info = TestInfo; - - fn reset(&mut self, _is_done: Option<&Vec>) -> anyhow::Result { - self.state = self.state_init; - Ok(TestObs { obs: self.state }) - } - - fn reset_with_index(&mut self, _ix: usize) -> anyhow::Result { - self.state = self.state_init; - Ok(TestObs { obs: self.state }) - } - - fn step_with_reset(&mut self, a: &Self::Act) -> (crate::Step, crate::record::Record) - where - Self: Sized, - { - self.state = self.state + a.act; - let step = crate::Step { - obs: TestObs { obs: self.state }, - act: a.clone(), - reward: vec![0.0], - is_terminated: vec![0], - is_truncated: vec![0], - info: TestInfo {}, - init_obs: Some(TestObs { - obs: self.state_init, - }), - }; - return (step, crate::record::Record::empty()); - } - - fn step(&mut self, a: &Self::Act) -> (crate::Step, crate::record::Record) - where - Self: Sized, - { - self.state = self.state + a.act; - let step = crate::Step { - obs: TestObs { obs: self.state }, - act: a.clone(), - reward: vec![0.0], - is_terminated: vec![0], - is_truncated: vec![0], - info: TestInfo {}, - init_obs: Some(TestObs { - obs: self.state_init, - }), - }; - return (step, crate::record::Record::empty()); - } - - fn build(config: &Self::Config, _seed: i64) -> anyhow::Result - where - Self: Sized, - { - Ok(Self { - state_init: *config, - state: 0, - }) - } - } - - type ReplayBuffer = - crate::generic_replay_buffer::SimpleReplayBuffer; - - /// Agent for testing. - pub struct TestAgent {} - - #[derive(Clone, Deserialize, Serialize)] - /// Config of agent for testing. - pub struct TestAgentConfig; - - impl crate::Agent for TestAgent { - fn train(&mut self) {} - - fn is_train(&self) -> bool { - false - } - - fn eval(&mut self) {} - - fn opt_with_record(&mut self, _buffer: &mut ReplayBuffer) -> crate::record::Record { - crate::record::Record::empty() - } - - fn save_params(&self, _path: &std::path::Path) -> anyhow::Result> { - Ok(vec![]) - } - - fn load_params(&mut self, _path: &std::path::Path) -> anyhow::Result<()> { - Ok(()) - } - - fn as_any_ref(&self) -> &dyn std::any::Any { - self - } - - fn as_any_mut(&mut self) -> &mut dyn std::any::Any { - self - } - } - - impl crate::Policy for TestAgent { - fn sample(&mut self, _obs: &TestObs) -> TestAct { - TestAct { act: 1 } - } - } - - impl crate::Configurable for TestAgent { - type Config = TestAgentConfig; - - fn build(_config: Self::Config) -> Self { - Self {} - } - } -} diff --git a/border-core/src/record/base.rs b/border-core/src/record/base.rs index 5e97d0b9..849c6743 100644 --- a/border-core/src/record/base.rs +++ b/border-core/src/record/base.rs @@ -106,7 +106,7 @@ impl Record { /// # Returns /// /// An iterator over the record's keys - pub fn keys(&self) -> Keys { + pub fn keys(&self) -> Keys<'_, String, RecordValue> { self.0.keys() } diff --git a/border-core/src/record/buffered_recorder.rs b/border-core/src/record/buffered_recorder.rs index 0357b365..267c78cb 100644 --- a/border-core/src/record/buffered_recorder.rs +++ b/border-core/src/record/buffered_recorder.rs @@ -6,7 +6,7 @@ //! learning environments. use super::{Record, Recorder}; -use crate::{Env, ReplayBufferBase}; +use crate::{Env, ReplayBuffer}; use std::marker::PhantomData; /// A recorder that buffers sequences of observations and actions in memory. @@ -19,12 +19,12 @@ use std::marker::PhantomData; /// # Type Parameters /// /// * `E` - The environment type that implements the [`Env`] trait -/// * `R` - The replay buffer type that implements the [`ReplayBufferBase`] trait +/// * `R` - The replay buffer type that implements the [`ReplayBuffer`] trait #[derive(Default)] pub struct BufferedRecorder where E: Env, - R: ReplayBufferBase, + R: ReplayBuffer, { /// The internal buffer storing the sequence of records buf: Vec, @@ -35,7 +35,7 @@ where impl BufferedRecorder where E: Env, - R: ReplayBufferBase, + R: ReplayBuffer, { /// Creates a new empty buffered recorder. /// @@ -57,7 +57,7 @@ where /// # Returns /// /// An iterator over references to the [`Record`]s in the buffer. - pub fn iter(&self) -> std::slice::Iter { + pub fn iter(&self) -> std::slice::Iter<'_, Record> { self.buf.iter() } } @@ -65,7 +65,7 @@ where impl Recorder for BufferedRecorder where E: Env, - R: ReplayBufferBase, + R: ReplayBuffer, { /// Writes a [`Record`] to the internal buffer. /// diff --git a/border-core/src/record/null_recorder.rs b/border-core/src/record/null_recorder.rs index b78a2929..53d3ab26 100644 --- a/border-core/src/record/null_recorder.rs +++ b/border-core/src/record/null_recorder.rs @@ -7,7 +7,7 @@ use std::marker::PhantomData; use super::{Record, Recorder}; -use crate::{Env, ReplayBufferBase}; +use crate::{Env, ReplayBuffer}; /// A recorder that discards all records without storing them. /// @@ -22,11 +22,11 @@ use crate::{Env, ReplayBufferBase}; /// # Type Parameters /// /// * `E` - The environment type that implements the [`Env`] trait -/// * `R` - The replay buffer type that implements the [`ReplayBufferBase`] trait +/// * `R` - The replay buffer type that implements the [`ReplayBuffer`] trait pub struct NullRecorder where E: Env, - R: ReplayBufferBase, + R: ReplayBuffer, { /// Phantom data to hold the type parameters phantom: PhantomData<(E, R)>, @@ -35,7 +35,7 @@ where impl NullRecorder where E: Env, - R: ReplayBufferBase, + R: ReplayBuffer, { /// Creates a new null recorder. /// @@ -52,7 +52,7 @@ where impl Recorder for NullRecorder where E: Env, - R: ReplayBufferBase, + R: ReplayBuffer, { /// Discards the given record without storing it. /// diff --git a/border-core/src/record/recorder.rs b/border-core/src/record/recorder.rs index 768a93e2..3aa49e35 100644 --- a/border-core/src/record/recorder.rs +++ b/border-core/src/record/recorder.rs @@ -6,7 +6,7 @@ //! recording strategies. use super::Record; -use crate::{Agent, Env, ReplayBufferBase}; +use crate::{Agent, Env, ReplayBuffer}; use anyhow::Result; use std::path::Path; @@ -22,11 +22,11 @@ use std::path::Path; /// # Type Parameters /// /// * `E` - The environment type that implements the [`Env`] trait -/// * `R` - The replay buffer type that implements the [`ReplayBufferBase`] trait +/// * `R` - The replay buffer type that implements the [`ReplayBuffer`] trait pub trait Recorder where E: Env, - R: ReplayBufferBase, + R: ReplayBuffer, { /// Writes a record to the recorder's output destination. /// diff --git a/border-core/src/trainer.rs b/border-core/src/trainer.rs index 87e403bf..3385ce42 100644 --- a/border-core/src/trainer.rs +++ b/border-core/src/trainer.rs @@ -10,7 +10,7 @@ use std::time::{Duration, SystemTime}; use crate::{ record::{Record, RecordValue::Scalar, Recorder}, - Agent, Env, Evaluator, ExperienceBufferBase, ReplayBufferBase, StepProcessor, + Agent, Env, Evaluator, ExperienceBuffer, ReplayBuffer, StepProcessor, }; use anyhow::Result; pub use config::TrainerConfig; @@ -201,7 +201,7 @@ impl Trainer { ) -> Result<(Record, bool)> where E: Env, - R: ReplayBufferBase, + R: ReplayBuffer, { if self.env_steps < self.warmup_period { Ok((Record::empty(), false)) @@ -237,7 +237,7 @@ impl Trainer { ) -> Result<()> where E: Env, - R: ReplayBufferBase, + R: ReplayBuffer, D: Evaluator, { // Evaluation @@ -276,7 +276,7 @@ impl Trainer { where E: Env, P: StepProcessor, - R: ExperienceBufferBase + ReplayBufferBase, + R: ExperienceBuffer + ReplayBuffer, D: Evaluator, { let mut sampler = Sampler::new(env, step_proc); @@ -336,7 +336,7 @@ impl Trainer { ) -> Result<()> where E: Env, - R: ReplayBufferBase, + R: ReplayBuffer, D: Evaluator, { // Return empty record diff --git a/border-core/src/trainer/sampler.rs b/border-core/src/trainer/sampler.rs index 4f96cbae..6c7a9203 100644 --- a/border-core/src/trainer/sampler.rs +++ b/border-core/src/trainer/sampler.rs @@ -21,7 +21,7 @@ //! 3. Performance Monitoring: //! * Monitor episode length //! * Record environment metrics -use crate::{record::Record, Agent, Env, ExperienceBufferBase, ReplayBufferBase, StepProcessor}; +use crate::{record::Record, Agent, Env, ExperienceBuffer, ReplayBuffer, StepProcessor}; use anyhow::Result; /// Manages the sampling of experiences from the environment. @@ -102,8 +102,8 @@ where buffer: &mut R_, ) -> Result where - R: ExperienceBufferBase + ReplayBufferBase, - R_: ExperienceBufferBase, + R: ExperienceBuffer + ReplayBuffer, + R_: ExperienceBuffer, { // Reset environment(s) if required if self.prev_obs.is_none() { diff --git a/border-generic-replay-buffer/Cargo.toml b/border-generic-replay-buffer/Cargo.toml new file mode 100644 index 00000000..fc828362 --- /dev/null +++ b/border-generic-replay-buffer/Cargo.toml @@ -0,0 +1,28 @@ +[package] +name = "border-generic-replay-buffer" +description = "A generic implementation of replaybuffer for Border" +version.workspace = true +edition.workspace = true +repository.workspace = true +keywords.workspace = true +categories.workspace = true +license.workspace = true +readme = "README.md" + +[dependencies] +border-core = { version = "0.0.9", path = "../border-core" } +serde = { workspace = true, features = ["derive"] } +serde_yaml = { workspace = true } +# log = { workspace = true } +# thiserror = { workspace = true } +anyhow = { workspace = true } +# chrono = { workspace = true } +# aquamarine = { workspace = true } +fastrand = { workspace = true } +segment-tree = { workspace = true } +# xxhash-rust = { workspace = true } +# Consider to replace with fastrand +rand = { workspace = true } + +[dev-dependencies] +tempdir = { workspace = true } diff --git a/border-generic-replay-buffer/README.md b/border-generic-replay-buffer/README.md new file mode 100644 index 00000000..a0a204e5 --- /dev/null +++ b/border-generic-replay-buffer/README.md @@ -0,0 +1,75 @@ +Generic implementation of replay buffers for reinforcement learning. + +This module provides a flexible implementation of replay buffers +that can handle arbitrary observation and action types. It supports both +standard experience replay and prioritized experience replay (PER). + +# Key Components + +- [`GenericReplayBuffer`]: A generic replay buffer implementation +- [`GenericTransitionBatch`]: A generic batch structure for transitions +- [`SimpleStepProcessor`]: A processor for converting environment steps to transitions +- [`PerConfig`]: Configuration for prioritized experience replay + +# Features + +- Generic type support for observations and actions +- Efficient batch processing +- Prioritized experience replay with importance sampling +- Configurable weight normalization +- Step processing for non-vectorized environments + +# [`BatchBase`] + +The [`BatchBase`] trait represents batches of observations and actions, serving dual purposes in the reinforcement learning pipeline: + +1. **Training Batch Component**: Forms the building blocks for training batches used by agents during optimization +2. **Storage Container**: Acts as the internal storage mechanism within [`GenericReplayBuffer`] for efficiently managing observation and action data + +# [`GenericTransitionBatch`] + +The [`GenericTransitionBatch`] trait represents a batch of transitions in the form `(o_t, r_t, a_t, o_t+1)`. +This is composed of structs that implement the [`BatchBase`] trait and +used for training [`Agent`]s using reinforcement learning algorithms. + +# [`SimpleStepProcessor`] + +The [`SimpleStepProcessor`] is an implementation of [`StepProcessor`] that: + +1. Maintains the previous observation to construct complete transitions +2. Converts environment-specific observations and actions (`E::Obs` and `E::Act`) into batch-compatible + types (`O` and `A`) using the `From` trait +3. Generates [`GenericTransitionBatch`] objects containing the complete transition + `(o_t, a_t, o_t+1, r_t, is_terminated, is_truncated)` +4. Handles episode termination by properly resetting the previous observation + +This processor is essential for implementing temporal difference learning algorithms, as it ensures +that transitions are properly formatted and stored in the replay buffer for training. + +[`SimpleStepProcessor`] can be used with [`GenericReplayBuffer`]. It converts `E::Obs` and +`E::Act` into their respective [`BatchBase`] types and generates [`GenericTransitionBatch`]. This conversion +relies on the trait bounds `O: From` and `A: From`. + +# [`GenericReplayBuffer`] + +[`GenericReplayBuffer`] implements both [`ReplayBuffer`] and [`ExperienceBuffer`]. +This type takes two parameters, `O` and `A`, representing observation and action types in the replay buffer. +Both `O` and `A` must implement [`BatchBase`], which provides sample storage functionality similar to `Vec`. +The associated types `Item` and `Batch` are both [`GenericTransitionBatch`], representing sets of +`(o_t, r_t, a_t, o_t+1)` transitions. + +## Type Compatibility + +Typically, the associated types `SimpleStepProcessor::Output` and `GenericReplayBuffer::Item` need to be matched. This ensures that the step data from environment interactions matches the samples expected by the replay buffer, guaranteeing type compatibility throughout the reinforcement learning pipeline. + +[`GenericReplayBuffer`]: crate::GenericReplayBuffer +[`GenericReplayBuffer`]: crate::GenericReplayBuffer +[`BatchBase`]: crate::BatchBase +[`GenericTransitionBatch`]: crate::GenericTransitionBatch +[`SimpleStepProcessor`]: crate::SimpleStepProcessor +[`SimpleStepProcessor`]: crate::SimpleStepProcessor +[`BatchBase`]: crate::BatchBase +[`ReplayBuffer`]: border_core::ReplayBuffer +[`ExperienceBuffer`]: border_core::ExperienceBuffer +[`Agent`]: border_core::Agent +[`StepProcessor`]: border_core::StepProcessor diff --git a/border-generic-replay-buffer/src/batch.rs b/border-generic-replay-buffer/src/batch.rs new file mode 100644 index 00000000..a115fb76 --- /dev/null +++ b/border-generic-replay-buffer/src/batch.rs @@ -0,0 +1,206 @@ +//! Generic implementation of transition batches for reinforcement learning. +//! +//! This module provides a generic implementation of transition batches that can handle +//! arbitrary observation and action types. It supports the following features: +//! - Efficient batch processing +//! - Weighting for prioritized experience replay +//! - Transition sampling and management + +use border_core::TransitionBatch; + +/// A trait defining basic batch operations. +/// +/// This trait provides fundamental operations for efficiently managing batches of +/// observations and actions. +/// +/// # Type Parameters +/// +/// * `Self` - The batch type, representing batches of observations or actions. +/// +/// # Examples +/// +/// ```ignore +/// struct TensorBatch { +/// data: Vec, +/// shape: Vec, +/// } +/// +/// impl BatchBase for TensorBatch { +/// fn new(capacity: usize) -> Self { +/// Self { +/// data: Vec::with_capacity(capacity), +/// shape: vec![], +/// } +/// } +/// +/// fn push(&mut self, ix: usize, data: Self) { +/// // Data addition logic +/// } +/// +/// fn sample(&self, ixs: &Vec) -> Self { +/// // Sampling logic +/// } +/// } +/// ``` +pub trait BatchBase { + /// Creates a new batch with the specified capacity. + /// + /// # Arguments + /// + /// * `capacity` - Initial capacity of the batch + fn new(capacity: usize) -> Self; + + /// Adds data at the specified index. + /// + /// # Arguments + /// + /// * `ix` - Index where data should be added + /// * `data` - Data to be added + fn push(&mut self, ix: usize, data: Self); + + /// Retrieves samples from the specified indices. + /// + /// # Arguments + /// + /// * `ixs` - List of indices to sample from + /// + /// # Returns + /// + /// A new batch containing the sampled data + fn sample(&self, ixs: &Vec) -> Self; +} + +/// A generic structure representing transitions in reinforcement learning. +/// +/// This structure efficiently manages reinforcement learning transitions +/// (observations, actions, rewards, etc.). It also includes support for +/// prioritized experience replay (PER). +/// +/// # Type Parameters +/// +/// * `O` - Observation type, must implement `BatchBase` +/// * `A` - Action type, must implement `BatchBase` +/// +/// # Examples +/// +/// ```ignore +/// let batch = GenericTransitionBatch::::with_capacity(32); +/// ``` +pub struct GenericTransitionBatch +where + O: BatchBase, + A: BatchBase, +{ + /// Current observations + pub obs: O, + + /// Selected actions + pub act: A, + + /// Next state observations + pub next_obs: O, + + /// Transition rewards + pub reward: Vec, + + /// Episode termination flags + pub is_terminated: Vec, + + /// Episode truncation flags + pub is_truncated: Vec, + + /// Weights for prioritized experience replay + pub weight: Option>, + + /// Indices of sampled transitions + pub ix_sample: Option>, +} + +impl TransitionBatch for GenericTransitionBatch +where + O: BatchBase, + A: BatchBase, +{ + type ObsBatch = O; + type ActBatch = A; + + /// Decomposes the batch into its individual components. + /// + /// # Returns + /// + /// A tuple containing the following elements: + /// 1. Observations + /// 2. Actions + /// 3. Next observations + /// 4. Rewards + /// 5. Termination flags + /// 6. Truncation flags + /// 7. Sample indices + /// 8. Weights + fn unpack( + self, + ) -> ( + Self::ObsBatch, + Self::ActBatch, + Self::ObsBatch, + Vec, + Vec, + Vec, + Option>, + Option>, + ) { + ( + self.obs, + self.act, + self.next_obs, + self.reward, + self.is_terminated, + self.is_truncated, + self.ix_sample, + self.weight, + ) + } + + /// Returns the number of transitions in the batch. + fn len(&self) -> usize { + self.reward.len() + } + + /// Returns a reference to the batch of observations. + fn obs(&self) -> &Self::ObsBatch { + &self.obs + } + + /// Returns a reference to the batch of actions. + fn act(&self) -> &Self::ActBatch { + &self.act + } +} + +impl GenericTransitionBatch +where + O: BatchBase, + A: BatchBase, +{ + /// Creates a new batch with the specified capacity. + /// + /// # Arguments + /// + /// * `capacity` - Initial capacity of the batch + /// + /// # Returns + /// + /// A new `GenericTransitionBatch` instance + pub fn with_capacity(capacity: usize) -> Self { + Self { + obs: O::new(capacity), + act: A::new(capacity), + next_obs: O::new(capacity), + reward: Vec::with_capacity(capacity), + is_terminated: Vec::with_capacity(capacity), + is_truncated: Vec::with_capacity(capacity), + weight: None, + ix_sample: None, + } + } +} diff --git a/border-generic-replay-buffer/src/config.rs b/border-generic-replay-buffer/src/config.rs new file mode 100644 index 00000000..8e41614a --- /dev/null +++ b/border-generic-replay-buffer/src/config.rs @@ -0,0 +1,294 @@ +//! Configuration for the replay buffer implementation. +//! +//! This module provides configuration structures for the replay buffer, including: +//! - Basic buffer configuration (capacity, seed) +//! - Prioritized Experience Replay (PER) configuration +//! - Serialization and deserialization support + +use super::{WeightNormalizer, WeightNormalizer::All}; +use anyhow::Result; +use serde::{Deserialize, Serialize}; +use std::{ + default::Default, + fs::File, + io::{BufReader, Write}, + path::Path, +}; + +/// Configuration for Prioritized Experience Replay (PER). +/// +/// This structure defines the parameters for prioritized sampling in the replay buffer. +/// It controls how transitions are sampled based on their importance and how +/// importance weights are calculated and normalized. +/// +/// # Fields +/// +/// * `alpha` - Controls the degree of prioritization (0 = uniform sampling) +/// * `beta_0` - Initial value for importance sampling weights +/// * `beta_final` - Final value for importance sampling weights +/// * `n_opts_final` - Number of optimization steps to reach `beta_final` +/// * `normalize` - Method for normalizing importance weights +/// +/// # Examples +/// +/// ```rust +/// use border_generic_replay_buffer::{PerConfig, WeightNormalizer}; +/// +/// let config = PerConfig::default() +/// .alpha(0.6) +/// .beta_0(0.4) +/// .beta_final(1.0) +/// .n_opts_final(500_000) +/// .normalize(WeightNormalizer::All); +/// ``` +#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)] +pub struct PerConfig { + /// Exponent for prioritization. Higher values increase the bias towards + /// high-priority transitions. A value of 0 results in uniform sampling. + pub alpha: f32, + + /// Initial value of the importance sampling exponent. Lower values reduce + /// the impact of importance sampling weights. + pub beta_0: f32, + + /// Final value of the importance sampling exponent. Typically set to 1.0 + /// to fully compensate for the non-uniform sampling. + pub beta_final: f32, + + /// Number of optimization steps after which `beta` reaches its final value. + /// This allows for a gradual increase in the impact of importance sampling. + pub n_opts_final: usize, + + /// Method for normalizing importance sampling weights. Controls how the + /// weights are scaled to prevent numerical instability. + pub normalize: WeightNormalizer, +} + +impl Default for PerConfig { + /// Creates a default PER configuration with commonly used values: + /// - `alpha = 0.6` (moderate prioritization) + /// - `beta_0 = 0.4` (initial importance sampling) + /// - `beta_final = 1.0` (full compensation) + /// - `n_opts_final = 500_000` (gradual increase) + /// - `normalize = All` (normalize all weights) + fn default() -> Self { + Self { + alpha: 0.6, + beta_0: 0.4, + beta_final: 1.0, + n_opts_final: 500_000, + normalize: All, + } + } +} + +impl PerConfig { + /// Sets the prioritization exponent `alpha`. + /// + /// # Arguments + /// + /// * `alpha` - The new value for the prioritization exponent + /// + /// # Returns + /// + /// The modified configuration + pub fn alpha(mut self, alpha: f32) -> Self { + self.alpha = alpha; + self + } + + /// Sets the initial importance sampling exponent `beta_0`. + /// + /// # Arguments + /// + /// * `beta_0` - The new initial value for the importance sampling exponent + /// + /// # Returns + /// + /// The modified configuration + pub fn beta_0(mut self, beta_0: f32) -> Self { + self.beta_0 = beta_0; + self + } + + /// Sets the final importance sampling exponent `beta_final`. + /// + /// # Arguments + /// + /// * `beta_final` - The new final value for the importance sampling exponent + /// + /// # Returns + /// + /// The modified configuration + pub fn beta_final(mut self, beta_final: f32) -> Self { + self.beta_final = beta_final; + self + } + + /// Sets the number of optimization steps to reach the final beta value. + /// + /// # Arguments + /// + /// * `n_opts_final` - The new number of optimization steps + /// + /// # Returns + /// + /// The modified configuration + pub fn n_opts_final(mut self, n_opts_final: usize) -> Self { + self.n_opts_final = n_opts_final; + self + } + + /// Sets the method for normalizing importance weights. + /// + /// # Arguments + /// + /// * `normalize` - The new normalization method + /// + /// # Returns + /// + /// The modified configuration + pub fn normalize(mut self, normalize: WeightNormalizer) -> Self { + self.normalize = normalize; + self + } +} + +/// Configuration for the replay buffer. +/// +/// This structure defines the basic parameters for the replay buffer, +/// including its capacity, random seed, and optional PER configuration. +/// +/// # Fields +/// +/// * `capacity` - Maximum number of transitions to store +/// * `seed` - Random seed for sampling +/// * `per_config` - Optional configuration for prioritized experience replay +/// +/// # Examples +/// +/// ```rust +/// use border_generic_replay_buffer::{GenericReplayBufferConfig, PerConfig}; +/// +/// // Basic configuration +/// let config = GenericReplayBufferConfig::default() +/// .capacity(10000) +/// .seed(42); +/// +/// // Configuration with PER +/// let config_with_per = GenericReplayBufferConfig::default() +/// .capacity(10000) +/// .seed(42) +/// .per_config(Some(PerConfig::default())); +/// ``` +#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)] +pub struct GenericReplayBufferConfig { + /// Maximum number of transitions that can be stored in the buffer. + /// When the buffer is full, new transitions replace the oldest ones. + pub capacity: usize, + + /// Random seed used for sampling transitions. This ensures reproducibility + /// of the sampling process when the same seed is used. + pub seed: u64, + + /// Optional configuration for prioritized experience replay. If `None`, + /// transitions are sampled uniformly at random. + pub per_config: Option, +} + +impl Default for GenericReplayBufferConfig { + /// Creates a default replay buffer configuration with commonly used values: + /// - `capacity = 10000` (moderate buffer size) + /// - `seed = 42` (fixed random seed) + /// - `per_config = None` (uniform sampling) + fn default() -> Self { + Self { + capacity: 10000, + seed: 42, + per_config: None, + } + } +} + +impl GenericReplayBufferConfig { + /// Sets the capacity of the replay buffer. + /// + /// # Arguments + /// + /// * `capacity` - The new capacity for the buffer + /// + /// # Returns + /// + /// The modified configuration + pub fn capacity(mut self, capacity: usize) -> Self { + self.capacity = capacity; + self + } + + /// Sets the random seed for sampling. + /// + /// # Arguments + /// + /// * `seed` - The new random seed + /// + /// # Returns + /// + /// The modified configuration + pub fn seed(mut self, seed: u64) -> Self { + self.seed = seed; + self + } + + /// Sets the configuration for prioritized experience replay. + /// + /// # Arguments + /// + /// * `per_config` - The new PER configuration + /// + /// # Returns + /// + /// The modified configuration + pub fn per_config(mut self, per_config: Option) -> Self { + self.per_config = per_config; + self + } + + /// Loads the configuration from a YAML file. + /// + /// # Arguments + /// + /// * `path` - Path to the configuration file + /// + /// # Returns + /// + /// The loaded configuration + /// + /// # Errors + /// + /// Returns an error if the file cannot be read or parsed + pub fn load(path: impl AsRef) -> Result { + let file = File::open(path)?; + let rdr = BufReader::new(file); + let b = serde_yaml::from_reader(rdr)?; + Ok(b) + } + + /// Saves the configuration to a YAML file. + /// + /// # Arguments + /// + /// * `path` - Path where the configuration should be saved + /// + /// # Returns + /// + /// `Ok(())` if the configuration was saved successfully + /// + /// # Errors + /// + /// Returns an error if the file cannot be written + pub fn save(&self, path: impl AsRef) -> Result<()> { + let mut file = File::create(path)?; + file.write_all(serde_yaml::to_string(&self)?.as_bytes())?; + Ok(()) + } +} diff --git a/border-generic-replay-buffer/src/iw_scheduler.rs b/border-generic-replay-buffer/src/iw_scheduler.rs new file mode 100644 index 00000000..13c02ce0 --- /dev/null +++ b/border-generic-replay-buffer/src/iw_scheduler.rs @@ -0,0 +1,46 @@ +//! Scheduling the exponent of importance weight for PER. +use serde::{Deserialize, Serialize}; + +/// Scheduler of the exponent of importance weight for PER. +#[derive(Deserialize, Serialize, Clone, Debug, PartialEq)] +pub struct IwScheduler { + /// Initial value of $\beta$. + pub beta_0: f32, + + /// Final value of $\beta$. + pub beta_final: f32, + + /// Optimization steps when beta reaches its final value. + pub n_opts_final: usize, + + /// Current optimizatioin steps. + pub n_opts: usize, +} + +impl IwScheduler { + /// Creates a scheduler. + pub fn new(beta_0: f32, beta_final: f32, n_opts_final: usize) -> Self { + Self { + beta_0, + beta_final, + n_opts_final, + n_opts: 0, + } + } + + /// Gets the exponents of importance sampling weight. + pub fn beta(&self) -> f32 { + let n_opts = self.n_opts; + if n_opts >= self.n_opts_final { + self.beta_final + } else { + let d = self.beta_final - self.beta_0; + self.beta_0 + d * (n_opts as f32 / self.n_opts_final as f32) + } + } + + /// Add optimization steps for scheduling beta through training. + pub fn add_n_opts(&mut self) { + self.n_opts += 1; + } +} diff --git a/border-generic-replay-buffer/src/lib.rs b/border-generic-replay-buffer/src/lib.rs new file mode 100644 index 00000000..9dba02f2 --- /dev/null +++ b/border-generic-replay-buffer/src/lib.rs @@ -0,0 +1,225 @@ +#![doc = include_str!("../README.md")] +mod batch; +mod config; +mod iw_scheduler; +mod replay_buffer; +mod step_proc; +mod sum_tree; +pub use batch::{BatchBase, GenericTransitionBatch}; +pub use config::{GenericReplayBufferConfig, PerConfig}; +pub use iw_scheduler::IwScheduler; +pub use replay_buffer::GenericReplayBuffer; +pub use step_proc::{SimpleStepProcessor, SimpleStepProcessorConfig}; +pub use sum_tree::WeightNormalizer; + +// TODO: Consider to compile this module only for tests. +/// Agent and Env for testing. +pub mod test { + use border_core::{record::Record, Act, Agent, Configurable, Env, Info, Obs, Policy, Step}; + use serde::{Deserialize, Serialize}; + + /// Obs for testing. + #[derive(Clone, Debug)] + pub struct TestObs { + obs: usize, + } + + impl Obs for TestObs { + fn len(&self) -> usize { + 1 + } + } + + /// Batch of obs for testing. + pub struct TestObsBatch { + obs: Vec, + } + + impl crate::BatchBase for TestObsBatch { + fn new(capacity: usize) -> Self { + Self { + obs: vec![0; capacity], + } + } + + fn push(&mut self, i: usize, data: Self) { + self.obs[i] = data.obs[0]; + } + + fn sample(&self, ixs: &Vec) -> Self { + let obs = ixs.iter().map(|ix| self.obs[*ix]).collect(); + Self { obs } + } + } + + impl From for TestObsBatch { + fn from(obs: TestObs) -> Self { + Self { obs: vec![obs.obs] } + } + } + + /// Act for testing. + #[derive(Clone, Debug)] + pub struct TestAct { + act: usize, + } + + impl Act for TestAct {} + + /// Batch of act for testing. + pub struct TestActBatch { + act: Vec, + } + + impl From for TestActBatch { + fn from(act: TestAct) -> Self { + Self { act: vec![act.act] } + } + } + + impl crate::BatchBase for TestActBatch { + fn new(capacity: usize) -> Self { + Self { + act: vec![0; capacity], + } + } + + fn push(&mut self, i: usize, data: Self) { + self.act[i] = data.act[0]; + } + + fn sample(&self, ixs: &Vec) -> Self { + let act = ixs.iter().map(|ix| self.act[*ix]).collect(); + Self { act } + } + } + + /// Info for testing. + pub struct TestInfo {} + + impl Info for TestInfo {} + + /// Environment for testing. + pub struct TestEnv { + state_init: usize, + state: usize, + } + + impl Env for TestEnv { + type Config = usize; + type Obs = TestObs; + type Act = TestAct; + type Info = TestInfo; + + fn reset(&mut self, _is_done: Option<&Vec>) -> anyhow::Result { + self.state = self.state_init; + Ok(TestObs { obs: self.state }) + } + + fn reset_with_index(&mut self, _ix: usize) -> anyhow::Result { + self.state = self.state_init; + Ok(TestObs { obs: self.state }) + } + + fn step_with_reset(&mut self, a: &Self::Act) -> (Step, Record) + where + Self: Sized, + { + self.state = self.state + a.act; + let step = Step { + obs: TestObs { obs: self.state }, + act: a.clone(), + reward: vec![0.0], + is_terminated: vec![0], + is_truncated: vec![0], + info: TestInfo {}, + init_obs: Some(TestObs { + obs: self.state_init, + }), + }; + return (step, Record::empty()); + } + + fn step(&mut self, a: &TestAct) -> (Step, Record) + where + Self: Sized, + { + self.state = self.state + a.act; + let step = Step { + obs: TestObs { obs: self.state }, + act: a.clone(), + reward: vec![0.0], + is_terminated: vec![0], + is_truncated: vec![0], + info: TestInfo {}, + init_obs: Some(TestObs { + obs: self.state_init, + }), + }; + return (step, Record::empty()); + } + + fn build(config: &Self::Config, _seed: i64) -> anyhow::Result + where + Self: Sized, + { + Ok(Self { + state_init: *config, + state: 0, + }) + } + } + + type ReplayBuffer = crate::GenericReplayBuffer; + + /// Agent for testing. + pub struct TestAgent {} + + #[derive(Clone, Deserialize, Serialize)] + /// Config of agent for testing. + pub struct TestAgentConfig; + + impl Agent for TestAgent { + fn train(&mut self) {} + + fn is_train(&self) -> bool { + false + } + + fn eval(&mut self) {} + + fn opt_with_record(&mut self, _buffer: &mut ReplayBuffer) -> Record { + Record::empty() + } + + fn save_params(&self, _path: &std::path::Path) -> anyhow::Result> { + Ok(vec![]) + } + + fn load_params(&mut self, _path: &std::path::Path) -> anyhow::Result<()> { + Ok(()) + } + + fn as_any_ref(&self) -> &dyn std::any::Any { + self + } + + fn as_any_mut(&mut self) -> &mut dyn std::any::Any { + self + } + } + + impl Policy for TestAgent { + fn sample(&mut self, _obs: &TestObs) -> TestAct { + TestAct { act: 1 } + } + } + + impl Configurable for TestAgent { + type Config = TestAgentConfig; + + fn build(_config: Self::Config) -> Self { + Self {} + } + } +} diff --git a/border-generic-replay-buffer/src/replay_buffer.rs b/border-generic-replay-buffer/src/replay_buffer.rs new file mode 100644 index 00000000..6a434291 --- /dev/null +++ b/border-generic-replay-buffer/src/replay_buffer.rs @@ -0,0 +1,427 @@ +//! Generic implementation of replay buffers for reinforcement learning. +//! +//! This module provides a generic implementation of replay buffers that can store +//! and sample transitions of arbitrary observation and action types. It supports: +//! - Standard experience replay +//! - Prioritized experience replay (PER) +//! - Importance sampling weights for off-policy learning + +use super::{ + config::{GenericReplayBufferConfig, PerConfig}, + BatchBase, GenericTransitionBatch, +}; +use crate::iw_scheduler::IwScheduler; +use crate::sum_tree::SumTree; +use anyhow::Result; +use border_core::{ExperienceBuffer, ReplayBuffer, TransitionBatch}; +use rand::{rngs::StdRng, RngCore, SeedableRng}; + +/// State management for Prioritized Experience Replay (PER). +/// +/// This struct maintains the necessary state for PER, including: +/// - A sum tree for efficient priority sampling +/// - An importance weight scheduler for adjusting sample weights +struct PerState { + /// A sum tree data structure for efficient priority sampling. + sum_tree: SumTree, + + /// Scheduler for importance sampling weights. + iw_scheduler: IwScheduler, +} + +impl PerState { + /// Creates a new PER state with the given configuration. + /// + /// # Arguments + /// + /// * `capacity` - Maximum number of transitions to store + /// * `per_config` - Configuration for prioritized experience replay + fn new(capacity: usize, per_config: &PerConfig) -> Self { + Self { + sum_tree: SumTree::new(capacity, per_config.alpha, per_config.normalize), + iw_scheduler: IwScheduler::new( + per_config.beta_0, + per_config.beta_final, + per_config.n_opts_final, + ), + } + } +} + +/// A generic implementation of a replay buffer for reinforcement learning. +/// +/// This buffer can store transitions of arbitrary observation and action types, +/// making it suitable for a wide range of reinforcement learning tasks. It supports: +/// - Standard experience replay +/// - Prioritized experience replay (optional) +/// - Efficient sampling and storage +/// +/// # Type Parameters +/// +/// * `O` - The type of observations, must implement [`BatchBase`] +/// * `A` - The type of actions, must implement [`BatchBase`] +/// +/// # Examples +/// +/// ```ignore +/// let config = GenericReplayBufferConfig { +/// capacity: 10000, +/// per_config: Some(PerConfig { +/// alpha: 0.6, +/// beta_0: 0.4, +/// beta_final: 1.0, +/// n_opts_final: 100000, +/// normalize: true, +/// }), +/// }; +/// +/// let mut buffer = GenericReplayBuffer::::build(&config); +/// +/// // Add transitions +/// buffer.push(transition)?; +/// +/// // Sample a batch +/// let batch = buffer.batch(32)?; +/// ``` +pub struct GenericReplayBuffer +where + O: BatchBase, + A: BatchBase, +{ + /// Maximum number of transitions that can be stored. + capacity: usize, + + /// Current insertion index. + i: usize, + + /// Current number of stored transitions. + size: usize, + + /// Storage for observations. + obs: O, + + /// Storage for actions. + act: A, + + /// Storage for next observations. + next_obs: O, + + /// Storage for rewards. + reward: Vec, + + /// Storage for termination flags. + is_terminated: Vec, + + /// Storage for truncation flags. + is_truncated: Vec, + + /// Random number generator for sampling. + rng: StdRng, + + /// State for prioritized experience replay, if enabled. + per_state: Option, +} + +impl GenericReplayBuffer +where + O: BatchBase, + A: BatchBase, +{ + /// Pushes rewards into the buffer at the specified index. + /// + /// # Arguments + /// + /// * `i` - Starting index for insertion + /// * `b` - Vector of rewards to insert + #[inline] + fn push_reward(&mut self, i: usize, b: &Vec) { + let mut j = i; + for r in b.iter() { + self.reward[j] = *r; + j += 1; + if j == self.capacity { + j = 0; + } + } + } + + /// Pushes termination flags into the buffer at the specified index. + /// + /// # Arguments + /// + /// * `i` - Starting index for insertion + /// * `b` - Vector of termination flags to insert + #[inline] + fn push_is_terminated(&mut self, i: usize, b: &Vec) { + let mut j = i; + for d in b.iter() { + self.is_terminated[j] = *d; + j += 1; + if j == self.capacity { + j = 0; + } + } + } + + /// Pushes truncation flags into the buffer at the specified index. + /// + /// # Arguments + /// + /// * `i` - Starting index for insertion + /// * `b` - Vector of truncation flags to insert + fn push_is_truncated(&mut self, i: usize, b: &Vec) { + let mut j = i; + for d in b.iter() { + self.is_truncated[j] = *d; + j += 1; + if j == self.capacity { + j = 0; + } + } + } + + /// Samples rewards for the given indices. + /// + /// # Arguments + /// + /// * `ixs` - Indices to sample from + /// + /// # Returns + /// + /// Vector of sampled rewards + fn sample_reward(&self, ixs: &Vec) -> Vec { + ixs.iter().map(|ix| self.reward[*ix]).collect() + } + + /// Samples termination flags for the given indices. + /// + /// # Arguments + /// + /// * `ixs` - Indices to sample from + /// + /// # Returns + /// + /// Vector of sampled termination flags + fn sample_is_terminated(&self, ixs: &Vec) -> Vec { + ixs.iter().map(|ix| self.is_terminated[*ix]).collect() + } + + /// Samples truncation flags for the given indices. + /// + /// # Arguments + /// + /// * `ixs` - Indices to sample from + /// + /// # Returns + /// + /// Vector of sampled truncation flags + fn sample_is_truncated(&self, ixs: &Vec) -> Vec { + ixs.iter().map(|ix| self.is_truncated[*ix]).collect() + } + + /// Sets priorities for newly added samples in prioritized experience replay. + /// + /// # Arguments + /// + /// * `batch_size` - Number of new samples to prioritize + fn set_priority(&mut self, batch_size: usize) { + let sum_tree = &mut self.per_state.as_mut().unwrap().sum_tree; + let max_p = sum_tree.max(); + + for j in 0..batch_size { + let i = (self.i + j) % self.capacity; + sum_tree.add(i, max_p); + } + } + + /// Returns a batch containing all actions in the buffer. + /// + /// # Warning + /// + /// This method should be used with caution on large replay buffers + /// as it may consume significant memory. + pub fn whole_actions(&self) -> A { + let ixs = (0..self.size).collect::>(); + self.act.sample(&ixs) + } + + /// Returns the number of terminated episodes in the buffer. + pub fn num_terminated_flags(&self) -> usize { + self.is_terminated + .iter() + .map(|is_terminated| *is_terminated as usize) + .sum() + } + + /// Returns the number of truncated episodes in the buffer. + pub fn num_truncated_flags(&self) -> usize { + self.is_truncated + .iter() + .map(|is_truncated| *is_truncated as usize) + .sum() + } + + /// Returns the sum of all rewards in the buffer. + pub fn sum_rewards(&self) -> f32 { + self.reward.iter().sum() + } +} + +impl ExperienceBuffer for GenericReplayBuffer +where + O: BatchBase, + A: BatchBase, +{ + type Item = GenericTransitionBatch; + + /// Returns the current number of transitions in the buffer. + fn len(&self) -> usize { + self.size + } + + /// Adds a new transition to the buffer. + /// + /// # Arguments + /// + /// * `tr` - The transition to add + /// + /// # Returns + /// + /// `Ok(())` if the transition was added successfully + /// + /// # Errors + /// + /// Returns an error if the buffer is full and cannot accept more transitions + fn push(&mut self, tr: Self::Item) -> Result<()> { + let len = tr.len(); // batch size + let (obs, act, next_obs, reward, is_terminated, is_truncated, _, _) = tr.unpack(); + self.obs.push(self.i, obs); + self.act.push(self.i, act); + self.next_obs.push(self.i, next_obs); + self.push_reward(self.i, &reward); + self.push_is_terminated(self.i, &is_terminated); + self.push_is_truncated(self.i, &is_truncated); + + if self.per_state.is_some() { + self.set_priority(len) + }; + + self.i = (self.i + len) % self.capacity; + self.size += len; + if self.size >= self.capacity { + self.size = self.capacity; + } + + Ok(()) + } +} + +impl ReplayBuffer for GenericReplayBuffer +where + O: BatchBase, + A: BatchBase, +{ + type Config = GenericReplayBufferConfig; + type Batch = GenericTransitionBatch; + + /// Creates a new replay buffer with the given configuration. + /// + /// # Arguments + /// + /// * `config` - Configuration for the replay buffer + /// + /// # Returns + /// + /// A new instance of the replay buffer + fn build(config: &Self::Config) -> Self { + let capacity = config.capacity; + let per_state = match &config.per_config { + Some(per_config) => Some(PerState::new(capacity, per_config)), + None => None, + }; + + Self { + capacity, + i: 0, + size: 0, + obs: O::new(capacity), + act: A::new(capacity), + next_obs: O::new(capacity), + reward: vec![0.; capacity], + is_terminated: vec![0; capacity], + is_truncated: vec![0; capacity], + rng: StdRng::seed_from_u64(config.seed as _), + per_state, + } + } + + /// Samples a batch of transitions from the buffer. + /// + /// If prioritized experience replay is enabled, samples are selected + /// according to their priorities. Otherwise, uniform random sampling is used. + /// + /// # Arguments + /// + /// * `size` - Number of transitions to sample + /// + /// # Returns + /// + /// A batch of sampled transitions + /// + /// # Errors + /// + /// Returns an error if: + /// - The buffer is empty + /// - The requested batch size is larger than the buffer size + fn batch(&mut self, size: usize) -> Result { + let (ixs, weight) = if let Some(per_state) = &self.per_state { + let sum_tree = &per_state.sum_tree; + let beta = per_state.iw_scheduler.beta(); + let (ixs, weight) = sum_tree.sample(size, beta); + let ixs = ixs.iter().map(|&ix| ix as usize).collect(); + (ixs, Some(weight)) + } else { + let ixs = (0..size) + // .map(|_| self.rng.usize(..self.size)) + .map(|_| (self.rng.next_u32() as usize) % self.size) + .collect::>(); + let weight = None; + (ixs, weight) + }; + + Ok(Self::Batch { + obs: self.obs.sample(&ixs), + act: self.act.sample(&ixs), + next_obs: self.next_obs.sample(&ixs), + reward: self.sample_reward(&ixs), + is_terminated: self.sample_is_terminated(&ixs), + is_truncated: self.sample_is_truncated(&ixs), + ix_sample: Some(ixs), + weight, + }) + } + + /// Updates the priorities of transitions in the buffer. + /// + /// This method is used in prioritized experience replay to adjust + /// the sampling probabilities based on TD errors. + /// + /// # Arguments + /// + /// * `ixs` - Optional indices of transitions to update + /// * `td_errs` - Optional TD errors for the transitions + fn update_priority(&mut self, ixs: &Option>, td_errs: &Option>) { + if let Some(per_state) = &mut self.per_state { + let ixs = ixs + .as_ref() + .expect("ixs should be Some(_) in update_priority()."); + let td_errs = td_errs + .as_ref() + .expect("td_errs should be Some(_) in update_priority()."); + for (&ix, &td_err) in ixs.iter().zip(td_errs.iter()) { + per_state.sum_tree.update(ix, td_err); + } + per_state.iw_scheduler.add_n_opts(); + } + } +} diff --git a/border-generic-replay-buffer/src/step_proc.rs b/border-generic-replay-buffer/src/step_proc.rs new file mode 100644 index 00000000..e6ff6eb8 --- /dev/null +++ b/border-generic-replay-buffer/src/step_proc.rs @@ -0,0 +1,138 @@ +//! Generic implementation of step processing for reinforcement learning. +//! +//! This module provides a generic implementation of the `StepProcessor` trait, +//! which handles the conversion of environment steps into transitions suitable +//! for training. It supports: +//! - 1-step TD backup for non-vectorized environments +//! - Generic observation and action types +//! - Efficient batch processing + +use super::{BatchBase, GenericTransitionBatch}; +use border_core::{Env, Obs, Step, StepProcessor}; +use std::{default::Default, marker::PhantomData}; + +/// Configuration for the simple step processor. +#[derive(Clone, Debug)] +pub struct SimpleStepProcessorConfig {} + +impl Default for SimpleStepProcessorConfig { + /// Creates a new default configuration. + fn default() -> Self { + Self {} + } +} + +/// A generic implementation of the `StepProcessor` trait. +/// +/// This processor converts environment steps into transitions suitable for +/// training reinforcement learning agents. It supports 1-step TD backup +/// for non-vectorized environments, meaning that each step contains exactly +/// one observation. +/// +/// # Type Parameters +/// +/// * `E` - The environment type, must implement `Env` +/// * `O` - The observation batch type, must implement `BatchBase` and `From` +/// * `A` - The action batch type, must implement `BatchBase` and `From` +pub struct SimpleStepProcessor { + /// The previous observation, used to construct transitions. + prev_obs: Option, + /// Phantom data to hold the generic type parameters. + phantom: PhantomData<(E, A)>, +} + +impl StepProcessor for SimpleStepProcessor +where + E: Env, + O: BatchBase + From, + A: BatchBase + From, +{ + type Config = SimpleStepProcessorConfig; + type Output = GenericTransitionBatch; + + /// Creates a new step processor with the given configuration. + /// + /// # Arguments + /// + /// * `_config` - The configuration for the processor + /// + /// # Returns + /// + /// A new instance of the step processor + fn build(_config: &Self::Config) -> Self { + Self { + prev_obs: None, + phantom: PhantomData, + } + } + + /// Resets the processor with an initial observation. + /// + /// This method must be called before processing any steps to initialize + /// the processor with the starting state of the environment. + /// + /// # Arguments + /// + /// * `init_obs` - The initial observation from the environment + fn reset(&mut self, init_obs: E::Obs) { + self.prev_obs = Some(init_obs.into()); + } + + /// Processes a step from the environment into a transition. + /// + /// This method converts an environment step into a transition suitable + /// for training. It handles: + /// - Converting observations and actions to the appropriate batch types + /// - Managing the previous observation for constructing transitions + /// - Handling episode termination and truncation + /// + /// # Arguments + /// + /// * `step` - The step to process + /// + /// # Returns + /// + /// A transition batch containing the processed step + /// + /// # Panics + /// + /// This method will panic if: + /// - The step contains more than one observation + /// - `reset()` has not been called before processing steps + /// - The step is terminal but does not contain an initial observation + fn process(&mut self, step: Step) -> Self::Output { + assert_eq!(step.obs.len(), 1); + + let batch = if self.prev_obs.is_none() { + panic!("prev_obs is not set. Forgot to call reset()?"); + } else { + let is_done = step.is_done(); + let next_obs = step.obs.clone().into(); + let obs = self.prev_obs.replace(step.obs.into()).unwrap(); + let act = step.act.into(); + let reward = step.reward; + let is_terminated = step.is_terminated; + let is_truncated = step.is_truncated; + let ix_sample = None; + let weight = None; + + if is_done { + self.prev_obs + .replace(step.init_obs.expect("Failed to unwrap init_obs").into()); + } + + GenericTransitionBatch { + obs, + act, + next_obs, + reward, + is_terminated, + is_truncated, + ix_sample, + weight, + } + }; + + batch + } +} diff --git a/border-generic-replay-buffer/src/sum_tree.rs b/border-generic-replay-buffer/src/sum_tree.rs new file mode 100644 index 00000000..39afaf4f --- /dev/null +++ b/border-generic-replay-buffer/src/sum_tree.rs @@ -0,0 +1,217 @@ +//! Sum tree for prioritized sampling. +//! +//! Code is adapted from and +/// +use segment_tree::{ + ops::{MaxIgnoreNaN, MinIgnoreNaN}, + SegmentPoint, +}; +use serde::{Deserialize, Serialize}; + +#[derive(Copy, Debug, Clone, Deserialize, Serialize, PartialEq)] +/// Specifies how to normalize the importance weights in a prioritized batch. +pub enum WeightNormalizer { + /// Normalize weights by the maximum weight of all samples in the buffer. + All, + /// Normalize weights by the maximum weight of samples in the batch. + Batch, +} + +#[derive(Debug)] +pub struct SumTree { + eps: f32, + alpha: f32, + capacity: usize, + n_samples: usize, + tree: Vec, + min_tree: SegmentPoint, + max_tree: SegmentPoint, + normalize: WeightNormalizer, +} + +impl SumTree { + pub fn new(capacity: usize, alpha: f32, normalize: WeightNormalizer) -> Self { + Self { + eps: 1e-8, + alpha, + capacity, + n_samples: 0, + tree: vec![0f32; 2 * capacity - 1], + min_tree: SegmentPoint::build(vec![f32::MAX; capacity], MinIgnoreNaN), + max_tree: SegmentPoint::build(vec![1e-8f32; capacity], MaxIgnoreNaN), + normalize, + } + } + + fn propagate(&mut self, ix: usize, change: f32) { + let parent = (ix - 1) / 2; + self.tree[parent] += change; + if parent != 0 { + self.propagate(parent, change); + } + } + + fn retrieve(&self, ix: usize, s: f32) -> usize { + let left = 2 * ix + 1; + let right = left + 1; + + if left >= self.tree.len() { + return ix; + } + + if s <= self.tree[left] || self.tree[right] == 0f32 { + return self.retrieve(left, s); + } else { + return self.retrieve(right, s - self.tree[left]); + } + } + + pub fn total(&self) -> f32 { + return self.tree[0]; + } + + pub fn max(&self) -> f32 { + self.max_tree + .query(0, self.max_tree.len()) + .powf(1.0 / self.alpha) + } + + /// Add priority value at `ix`-th element in the sum tree. + /// + /// The alpha-th power of the priority value is taken when addition. + pub fn add(&mut self, ix: usize, p: f32) { + debug_assert!(ix <= self.n_samples); + + self.update(ix, p); + + if self.n_samples < self.capacity { + self.n_samples += 1; + } + } + + /// Update priority value at `ix`-th element in the sum tree. + pub fn update(&mut self, ix: usize, p: f32) { + debug_assert!(ix < self.capacity); + + let p = (p + self.eps).powf(self.alpha); + self.min_tree.modify(ix, p); + self.max_tree.modify(ix, p); + let ix = ix + self.capacity - 1; + let change = p - self.tree[ix]; + if change.is_nan() { + println!("{:?}, {:?}", p, self.tree[ix]); + panic!(); + } + self.tree[ix] = p; + self.propagate(ix, change); + } + + /// Get the maximal index of the sum tree where the sum of priority values is less than `s`. + pub fn get(&self, s: f32) -> usize { + let ix = self.retrieve(0, s); + debug_assert!(ix >= (self.capacity - 1)); + ix + 1 - self.capacity + } + + /// Samples indices for batch and returns normalized weights. + /// + /// The weight is $w_i=\left(N^{-1}P(i)^{-1}\right)^{\beta}$ + /// and it will be normalized by $max_i w_i$. + pub fn sample(&self, batch_size: usize, beta: f32) -> (Vec, Vec) { + let p_sum = &self.total(); + let ps = (0..batch_size) + .map(|_| p_sum * fastrand::f32()) + .collect::>(); + let indices = ps.iter().map(|&p| self.get(p)).collect::>(); + // let indices = (0..batch_size) + // .map(|_| self.get(p_sum * fastrand::f32())) + // .collect::>(); + + let n = self.n_samples as f32 / p_sum; + let ws = indices + .iter() + .map(|ix| self.tree[ix + self.capacity - 1]) + .map(|p| (n * p).powf(-beta)) + .collect::>(); + + // normalizer within all samples + let w_max_inv = match self.normalize { + WeightNormalizer::All => (n * self.min_tree.query(0, self.n_samples)).powf(beta), + WeightNormalizer::Batch => 1f32 / ws.iter().fold(0.0 / 0.0, |m, v| v.max(m)), + }; + let ws = ws.iter().map(|w| w * w_max_inv).collect::>(); + + if p_sum.is_nan() || w_max_inv.is_nan() || ws.iter().sum::().is_nan() { + println!("self.n_samples: {:?}", self.n_samples); + println!("p_sum: {:?}", p_sum); + println!("w_max_inv: {:?}", w_max_inv); + println!("ps: {:?}", ps); + println!("indices: {:?}", indices); + println!("{:?}", ws); + panic!(); + } + + let ixs = indices.iter().map(|&ix| ix as i64).collect(); + + (ixs, ws) + } + + #[allow(dead_code)] + pub fn print_tree(&self) { + let mut nl = 1; + + for i in 0..self.tree.len() { + print!("{} ", self.tree[i]); + if i == 2 * nl - 2 { + println!(); + nl *= 2; + } + } + println!("max = {}", self.max()); + // println!("min = {}", self.min()); + println!("total = {}", self.total()); + } +} + +#[cfg(test)] +mod tests { + use super::{SumTree, WeightNormalizer::Batch}; + + #[test] + fn test_sum_tree_odd() { + let data = vec![0.5f32, 0.2, 0.8, 0.3, 1.1, 2.5, 3.9]; + let mut sum_tree = SumTree::new(8, 1.0, Batch); + for ix in 0..data.len() { + sum_tree.add(ix, data[ix]); + } + sum_tree.print_tree(); + println!(); + + assert_eq!(sum_tree.get(0.0), 0); + assert_eq!(sum_tree.get(0.4), 0); + assert_eq!(sum_tree.get(0.5), 0); + assert_eq!(sum_tree.get(0.6), 1); + assert_eq!(sum_tree.get(1.2), 2); + assert_eq!(sum_tree.get(1.6), 3); + assert_eq!(sum_tree.get(2.0), 4); + assert_eq!(sum_tree.get(2.8), 4); + + sum_tree.update(7, 2.0); + sum_tree.print_tree(); + println!(); + + // let (ixs, ws) = sum_tree.sample(10, 1.0); + // println!("{:?}", ixs); + // println!("{:?}", ws); + // println!(); + + // let n_samples = 1000000; + // let (ixs, _) = sum_tree.sample(n_samples, 1.0); + // debug_assert!(ixs.iter().all(|&ix| ix < data.len() as i64)); + // (0..5).for_each(|ix| { + // let p = data[ix] / sum_tree.total() * (n_samples as f32); + // let n = ixs.iter().filter(|&&e| e == ix as i64).collect::>().len(); + // println!("ix={:?}: {:?} (p={:?})", ix, n, p); + // }) + } +} diff --git a/border-minari/Cargo.toml b/border-minari/Cargo.toml index 362e5965..35991c5d 100644 --- a/border-minari/Cargo.toml +++ b/border-minari/Cargo.toml @@ -11,6 +11,7 @@ readme = "README.md" [dependencies] border-core = { version = "0.0.9", path = "../border-core" } +border-generic-replay-buffer = { version = "0.0.9", path = "../border-generic-replay-buffer" } numpy = { workspace = true } pyo3 = { workspace = true, default-features = false, features = [ "auto-initialize", "macros" diff --git a/border-minari/README.md b/border-minari/README.md index e69de29b..181c7ab4 100644 --- a/border-minari/README.md +++ b/border-minari/README.md @@ -0,0 +1,76 @@ +A wrapper for [Minari](https://minari.farama.org) environments. + +This crate provides a Rust interface for Minari datasets, which are collections of offline reinforcement learning data. +It allows users to load and interact with Minari datasets in a way that is compatible with the Border framework. + +# Features + +- **Dataset Loading**: Load Minari datasets from disk or from the Minari registry. +- **Environment Interaction**: Interact with the loaded datasets using the Border environment interface. +- **Data Access**: Access observations, actions, rewards, and other data from the datasets. + +# Example + +The following example demonstrates how to: +1. Load a D4RL Kitchen dataset +2. Create a replay buffer from a specific episode +3. Recover the environment state +4. Replay the actions from the dataset + +This is particularly useful for: +- Analyzing expert demonstrations +- Testing environment behavior +- Validating dataset quality +- Reproducing recorded trajectories + +```no_run +# use anyhow::Result; +use border_core::Env; +use border_minari::{d4rl::kitchen::ndarray::KitchenConverter, MinariDataset}; +# use numpy::convert; +# use std::num; + +fn main() -> Result<()> { + // Load the D4RL Kitchen dataset + let dataset = MinariDataset::load_dataset("D4RL/kitchen/complete-v1", true)?; + + // Create a converter for handling observation and action types + let mut converter = KitchenConverter {}; + + // Create a replay buffer containing only the sixth episode + let replay_buffer = dataset.create_replay_buffer(&mut converter, Some(vec![5]))?; + + // Recover the environment state from the dataset + // The 'false' parameter indicates not to use the initial state + // 'human' indicates the agent type + let mut env = dataset.recover_environment(converter, false, "human")?; + + // Get the sequence of actions from the replay buffer + let actions = replay_buffer.whole_actions(); + + // Reset the environment and replay the actions + env.reset(None)?; + for ix in 0..actions.action.shape()[0] { + let act = actions.get(ix); + let _ = env.step(&act); + } + + Ok(()) +} +``` + +The example uses the following key components: +- [`KitchenConverter`]: Handles conversion between Python and Rust types for the Kitchen environment +- [`MinariDataset`]: Manages the dataset and provides methods for data access +- [`Env`]: The Border environment interface for interaction + +[`KitchenConverter`]: crate::d4rl::kitchen::ndarray::KitchenConverter +[`MinariDataset`]: crate::MinariDataset +[`Env`]: border_core::Env + +# Integration with Border + +This crate implements the [`Env`] trait from `border-core`, making it compatible with other Border components +such as agents, policies, and trainers. It can be used in both online and offline reinforcement learning scenarios. + +[`Env`]: border_core::Env diff --git a/border-minari/src/converter.rs b/border-minari/src/converter.rs index b27e78c6..f5141c0c 100644 --- a/border-minari/src/converter.rs +++ b/border-minari/src/converter.rs @@ -1,5 +1,6 @@ use anyhow::Result; -use border_core::{generic_replay_buffer::BatchBase, Act, Obs}; +use border_core::{Act, Obs}; +use border_generic_replay_buffer::BatchBase; use pyo3::{PyAny, PyObject, Python}; /// Conversion trait for observation and action. diff --git a/border-minari/src/d4rl/antmaze/candle.rs b/border-minari/src/d4rl/antmaze/candle.rs index faa15fdb..2847ae9e 100644 --- a/border-minari/src/d4rl/antmaze/candle.rs +++ b/border-minari/src/d4rl/antmaze/candle.rs @@ -4,7 +4,7 @@ use crate::{ MinariConverter, MinariDataset, }; use anyhow::Result; -use border_core::generic_replay_buffer::BatchBase; +use border_generic_replay_buffer::BatchBase; use candle_core::{DType, Device, Tensor}; use ndarray::{ArrayBase, ArrayD, Axis, Slice}; use pyo3::{types::PyIterator, PyAny, PyObject, Python}; diff --git a/border-minari/src/d4rl/antmaze/ndarray.rs b/border-minari/src/d4rl/antmaze/ndarray.rs index c8cd7e40..d549e2de 100644 --- a/border-minari/src/d4rl/antmaze/ndarray.rs +++ b/border-minari/src/d4rl/antmaze/ndarray.rs @@ -6,7 +6,7 @@ use crate::{ MinariConverter, }; use anyhow::Result; -use border_core::generic_replay_buffer::BatchBase; +use border_generic_replay_buffer::BatchBase; use ndarray::{s, ArrayD, Axis, IxDyn, Slice}; use pyo3::{PyAny, PyObject, Python}; diff --git a/border-minari/src/d4rl/kitchen/candle.rs b/border-minari/src/d4rl/kitchen/candle.rs index 925021ef..b1dd2b7e 100644 --- a/border-minari/src/d4rl/kitchen/candle.rs +++ b/border-minari/src/d4rl/kitchen/candle.rs @@ -9,7 +9,7 @@ use crate::{ MinariConverter, }; use anyhow::Result; -use border_core::generic_replay_buffer::BatchBase; +use border_generic_replay_buffer::BatchBase; use candle_core::{DType, Device, Tensor}; use ndarray::{ArrayBase, ArrayD, Axis, Slice}; use pyo3::{PyAny, PyObject, Python}; diff --git a/border-minari/src/d4rl/kitchen/ndarray.rs b/border-minari/src/d4rl/kitchen/ndarray.rs index 0a7a3a3e..e305acdc 100644 --- a/border-minari/src/d4rl/kitchen/ndarray.rs +++ b/border-minari/src/d4rl/kitchen/ndarray.rs @@ -6,7 +6,7 @@ use crate::{ MinariConverter, }; use anyhow::Result; -use border_core::generic_replay_buffer::BatchBase; +use border_generic_replay_buffer::BatchBase; use ndarray::{s, ArrayD, Axis, IxDyn, Slice}; use pyo3::{PyAny, PyObject, Python}; diff --git a/border-minari/src/d4rl/pointmaze/ndarray.rs b/border-minari/src/d4rl/pointmaze/ndarray.rs index 10d5a87e..336571a1 100644 --- a/border-minari/src/d4rl/pointmaze/ndarray.rs +++ b/border-minari/src/d4rl/pointmaze/ndarray.rs @@ -6,7 +6,7 @@ use crate::{ MinariConverter, }; use anyhow::Result; -use border_core::generic_replay_buffer::BatchBase; +use border_generic_replay_buffer::BatchBase; use ndarray::{s, ArrayD, Axis, IxDyn, Slice}; use pyo3::{PyAny, PyObject, Python}; diff --git a/border-minari/src/dataset.rs b/border-minari/src/dataset.rs index 11861ba6..a60833d6 100644 --- a/border-minari/src/dataset.rs +++ b/border-minari/src/dataset.rs @@ -1,8 +1,8 @@ use crate::{util, MinariConverter, MinariEnv}; use anyhow::Result; -use border_core::{ - generic_replay_buffer::{GenericTransitionBatch, SimpleReplayBuffer, SimpleReplayBufferConfig}, - ExperienceBufferBase, ReplayBufferBase, +use border_core::{ExperienceBuffer, ReplayBuffer}; +use border_generic_replay_buffer::{ + GenericTransitionBatch, GenericReplayBuffer, GenericReplayBufferConfig, }; use pyo3::{ types::{IntoPyDict, PyIterator}, @@ -65,7 +65,7 @@ impl MinariDataset { &self, converter: &mut T, episode_indices: Option>, - ) -> Result> + ) -> Result> where T::ObsBatch: std::fmt::Debug, T::ActBatch: std::fmt::Debug, @@ -75,7 +75,7 @@ impl MinariDataset { let num_transitions = self.get_num_transitions(episode_indices.clone())?; // Prepare replay buffer - let mut replay_buffer = SimpleReplayBuffer::build(&SimpleReplayBufferConfig { + let mut replay_buffer = GenericReplayBuffer::build(&GenericReplayBufferConfig { capacity: num_transitions, seed: 0, per_config: None, diff --git a/border-minari/src/evaluator.rs b/border-minari/src/evaluator.rs index 5ad28436..5021abd9 100644 --- a/border-minari/src/evaluator.rs +++ b/border-minari/src/evaluator.rs @@ -1,7 +1,7 @@ //! Evaluator for Minari environments. use crate::{MinariConverter, MinariEnv}; use anyhow::Result; -use border_core::{record::Record, Agent, Env, Evaluator, ReplayBufferBase}; +use border_core::{record::Record, Agent, Env, Evaluator, ReplayBuffer}; /// An evaluator for Minari environments. /// @@ -22,7 +22,7 @@ impl Evaluator> for MinariEvaluator { /// The average return over episodes is returned. /// If the environment has ref_min_score and ref_max_score, the normalized score is also returned /// in the record. - fn evaluate( + fn evaluate( &mut self, policy: &mut Box, R>>, ) -> Result<(f32, Record)> { diff --git a/border-minari/src/lib.rs b/border-minari/src/lib.rs index ffec0690..3d3cbbaa 100644 --- a/border-minari/src/lib.rs +++ b/border-minari/src/lib.rs @@ -1,80 +1,4 @@ -//! A wrapper for [Minari](https://minari.farama.org) environments. -//! -//! This crate provides a Rust interface for Minari datasets, which are collections of offline reinforcement learning data. -//! It allows users to load and interact with Minari datasets in a way that is compatible with the Border framework. -//! -//! # Features -//! -//! - **Dataset Loading**: Load Minari datasets from disk or from the Minari registry. -//! - **Environment Interaction**: Interact with the loaded datasets using the Border environment interface. -//! - **Data Access**: Access observations, actions, rewards, and other data from the datasets. -//! -//! # Example -//! -//! The following example demonstrates how to: -//! 1. Load a D4RL Kitchen dataset -//! 2. Create a replay buffer from a specific episode -//! 3. Recover the environment state -//! 4. Replay the actions from the dataset -//! -//! This is particularly useful for: -//! - Analyzing expert demonstrations -//! - Testing environment behavior -//! - Validating dataset quality -//! - Reproducing recorded trajectories -//! -//! ```no_run -//! # use anyhow::Result; -//! use border_core::Env; -//! use border_minari::{d4rl::kitchen::ndarray::KitchenConverter, MinariDataset}; -//! # use numpy::convert; -//! # use std::num; -//! -//! fn main() -> Result<()> { -//! // Load the D4RL Kitchen dataset -//! let dataset = MinariDataset::load_dataset("D4RL/kitchen/complete-v1", true)?; -//! -//! // Create a converter for handling observation and action types -//! let mut converter = KitchenConverter {}; -//! -//! // Create a replay buffer containing only the sixth episode -//! let replay_buffer = dataset.create_replay_buffer(&mut converter, Some(vec![5]))?; -//! -//! // Recover the environment state from the dataset -//! // The 'false' parameter indicates not to use the initial state -//! // 'human' indicates the agent type -//! let mut env = dataset.recover_environment(converter, false, "human")?; -//! -//! // Get the sequence of actions from the replay buffer -//! let actions = replay_buffer.whole_actions(); -//! -//! // Reset the environment and replay the actions -//! env.reset(None)?; -//! for ix in 0..actions.action.shape()[0] { -//! let act = actions.get(ix); -//! let _ = env.step(&act); -//! } -//! -//! Ok(()) -//! } -//! ``` -//! -//! The example uses the following key components: -//! - [`KitchenConverter`]: Handles conversion between Python and Rust types for the Kitchen environment -//! - [`MinariDataset`]: Manages the dataset and provides methods for data access -//! - [`Env`]: The Border environment interface for interaction -//! -//! [`KitchenConverter`]: crate::d4rl::kitchen::ndarray::KitchenConverter -//! [`MinariDataset`]: crate::MinariDataset -//! [`Env`]: border_core::Env -//! -//! # Integration with Border -//! -//! This crate implements the [`Env`] trait from `border-core`, making it compatible with other Border components -//! such as agents, policies, and trainers. It can be used in both online and offline reinforcement learning scenarios. -//! -//! [`Env`]: border_core::Env - +#![doc = include_str!("../README.md")] mod converter; pub mod d4rl; mod dataset; diff --git a/border-minari/src/util/candle/tensor_batch.rs b/border-minari/src/util/candle/tensor_batch.rs index 6051657c..0680a88d 100644 --- a/border-minari/src/util/candle/tensor_batch.rs +++ b/border-minari/src/util/candle/tensor_batch.rs @@ -1,4 +1,4 @@ -use border_core::generic_replay_buffer::BatchBase; +use border_generic_replay_buffer::BatchBase; use candle_core::{/*error::Result, DType,*/ Device, Tensor}; // /// Adds capability of constructing [`Tensor`] with a static method. diff --git a/border-mlflow-tracking/README.md b/border-mlflow-tracking/README.md index dab7cc1d..62b737c3 100644 --- a/border-mlflow-tracking/README.md +++ b/border-mlflow-tracking/README.md @@ -1,16 +1,25 @@ -Support [MLflow](https://mlflow.org) tracking to manage experiments. +A logger for border-core crate. -Before running the program using this crate, run a tracking server with the following command: +This crate is based on [MLflow](https://mlflow.org) tracking. + +# Setup + +To use this crate, you need to start an MLflow tracking server first. You can do this by running: ```bash mlflow server --host 127.0.0.1 --port 8080 ``` -Then, training configurations and metrices can be logged to the tracking server. -The following code provides an example. Nested configuration parameters will be flattened, +Before running the program using this crate, you need to set the `MLFLOW_DEFAULT_ARTIFACT_ROOT` +environment variable to specify where model parameters and artifacts will be saved during training. +Typically, you should set this to the `mlruns` directory of your MLflow installation. + +# Example + +The following code is an example. Nested configuration parameters will be flattened, logged like `hyper_params.param1`, `hyper_params.param2`. -```rust +```no_run use anyhow::Result; use border_core::record::{Record, RecordValue, Recorder}; use border_mlflow_tracking::MlflowTrackingClient; @@ -66,7 +75,8 @@ fn main() -> Result<()> { }; // Set experiment for runs - let client = MlflowTrackingClient::new("http://localhost:8080").set_experiment_id("Default")?; + let client = MlflowTrackingClient::new("http://localhost:8080") + .set_experiment("Default")?; // Create recorders for logging let mut recorder_run1 = client.create_recorder("")?; @@ -103,3 +113,11 @@ fn main() -> Result<()> { Ok(()) } ``` + +## Save model parameters during training + +[`MlflowTrackingClient`] relies on the `MLFLOW_DEFAULT_ARTIFACT_ROOT` environment variable +to locate where model parameters are saved during training. Note that this environment variable +should be set for the program using this crate, not for the tracking server program. +Currently, only saving to the local file system is supported. + diff --git a/border-mlflow-tracking/src/client.rs b/border-mlflow-tracking/src/client.rs index dc2c2ccb..f11c450f 100644 --- a/border-mlflow-tracking/src/client.rs +++ b/border-mlflow-tracking/src/client.rs @@ -1,6 +1,6 @@ use crate::{system_time_as_millis, Experiment, MlflowTrackingRecorder, Run}; use anyhow::Result; -use border_core::{Env, ReplayBufferBase}; +use border_core::{Env, ReplayBuffer}; use log::info; use reqwest::blocking::Client; use serde::{Deserialize, Serialize}; @@ -212,7 +212,7 @@ impl MlflowTrackingClient { ) -> Result> where E: Env, - R: ReplayBufferBase, + R: ReplayBuffer, { let run_name = run_name.as_ref(); let run = { diff --git a/border-mlflow-tracking/src/lib.rs b/border-mlflow-tracking/src/lib.rs index 698e181f..bf119554 100644 --- a/border-mlflow-tracking/src/lib.rs +++ b/border-mlflow-tracking/src/lib.rs @@ -1,126 +1,4 @@ -//! A logger for border-core crate. -//! -//! This crate is based on [MLflow](https://mlflow.org) tracking. -//! -//! # Setup -//! -//! To use this crate, you need to start an MLflow tracking server first. You can do this by running: -//! -//! ```bash -//! mlflow server --host 127.0.0.1 --port 8080 -//! ``` -//! -//! Before running the program using this crate, you need to set the `MLFLOW_DEFAULT_ARTIFACT_ROOT` -//! environment variable to specify where model parameters and artifacts will be saved during training. -//! Typically, you should set this to the `mlruns` directory of your MLflow installation. -//! -//! # Example -//! -//! The following code is an example. Nested configuration parameters will be flattened, -//! logged like `hyper_params.param1`, `hyper_params.param2`. -//! -//! ```no_run -//! use anyhow::Result; -//! use border_core::record::{Record, RecordValue, Recorder}; -//! use border_mlflow_tracking::MlflowTrackingClient; -//! use serde::Serialize; -//! -//! // Nested Configuration struct -//! #[derive(Debug, Serialize)] -//! struct Config { -//! env_params: String, -//! hyper_params: HyperParameters, -//! } -//! -//! #[derive(Debug, Serialize)] -//! struct HyperParameters { -//! param1: i64, -//! param2: Param2, -//! param3: Param3, -//! } -//! -//! #[derive(Debug, Serialize)] -//! enum Param2 { -//! Variant1, -//! Variant2(f32), -//! } -//! -//! #[derive(Debug, Serialize)] -//! struct Param3 { -//! dataset_name: String, -//! } -//! -//! fn main() -> Result<()> { -//! env_logger::init(); -//! -//! let config1 = Config { -//! env_params: "env1".to_string(), -//! hyper_params: HyperParameters { -//! param1: 0, -//! param2: Param2::Variant1, -//! param3: Param3 { -//! dataset_name: "a".to_string(), -//! }, -//! }, -//! }; -//! let config2 = Config { -//! env_params: "env2".to_string(), -//! hyper_params: HyperParameters { -//! param1: 0, -//! param2: Param2::Variant2(3.0), -//! param3: Param3 { -//! dataset_name: "a".to_string(), -//! }, -//! }, -//! }; -//! -//! // Set experiment for runs -//! let client = MlflowTrackingClient::new("http://localhost:8080") -//! .set_experiment("Default")?; -//! -//! // Create recorders for logging -//! let mut recorder_run1 = client.create_recorder("")?; -//! let mut recorder_run2 = client.create_recorder("")?; -//! recorder_run1.log_params(&config1)?; -//! recorder_run2.log_params(&config2)?; -//! -//! // Logging while training -//! for opt_steps in 0..100 { -//! let opt_steps = opt_steps as f32; -//! -//! // Create a record -//! let mut record = Record::empty(); -//! record.insert("opt_steps", RecordValue::Scalar(opt_steps)); -//! record.insert("Loss", RecordValue::Scalar((-1f32 * opt_steps).exp())); -//! -//! // Log metrices in the record -//! recorder_run1.write(record); -//! } -//! -//! // Logging while training -//! for opt_steps in 0..100 { -//! let opt_steps = opt_steps as f32; -//! -//! // Create a record -//! let mut record = Record::empty(); -//! record.insert("opt_steps", RecordValue::Scalar(opt_steps)); -//! record.insert("Loss", RecordValue::Scalar((-0.5f32 * opt_steps).exp())); -//! -//! // Log metrices in the record -//! recorder_run2.write(record); -//! } -//! -//! Ok(()) -//! } -//! ``` -//! -//! ## Save model parameters during training -//! -//! [`MlflowTrackingClient`] relies on the `MLFLOW_DEFAULT_ARTIFACT_ROOT` environment variable -//! to locate where model parameters are saved during training. Note that this environment variable -//! should be set for the program using this crate, not for the tracking server program. -//! Currently, only saving to the local file system is supported. -//! +#![doc = include_str!("../README.md")] mod client; mod experiment; mod recorder; diff --git a/border-mlflow-tracking/src/recorder.rs b/border-mlflow-tracking/src/recorder.rs index 61b4db1d..38d84424 100644 --- a/border-mlflow-tracking/src/recorder.rs +++ b/border-mlflow-tracking/src/recorder.rs @@ -2,7 +2,7 @@ use crate::{system_time_as_millis, Run}; use anyhow::Result; use border_core::{ record::{RecordStorage, RecordValue, Recorder}, - Agent, Env, ReplayBufferBase, + Agent, Env, ReplayBuffer, }; use chrono::{DateTime, Duration, Local, SecondsFormat}; use reqwest::blocking::Client; @@ -64,7 +64,7 @@ struct SetTagParams<'a> { pub struct MlflowTrackingRecorder where E: Env, - R: ReplayBufferBase, + R: ReplayBuffer, { client: Client, base_url: String, @@ -81,7 +81,7 @@ where impl MlflowTrackingRecorder where E: Env, - R: ReplayBufferBase, + R: ReplayBuffer, { /// Create a new instance of `MlflowTrackingRecorder`. /// @@ -190,7 +190,7 @@ where impl Recorder for MlflowTrackingRecorder where E: Env, - R: ReplayBufferBase, + R: ReplayBuffer, { fn write(&mut self, record: border_core::record::Record) { let url = format!("{}/api/2.0/mlflow/runs/log-metric", self.base_url); @@ -284,7 +284,7 @@ where impl Drop for MlflowTrackingRecorder where E: Env, - R: ReplayBufferBase, + R: ReplayBuffer, { /// Update run's status to "FINISHED" when dropped. /// diff --git a/border-policy-no-backend/src/lib.rs b/border-policy-no-backend/src/lib.rs index 93053528..a4a49260 100644 --- a/border-policy-no-backend/src/lib.rs +++ b/border-policy-no-backend/src/lib.rs @@ -1,4 +1,4 @@ -//! Policy with no backend. +#![doc = include_str!("../README.md")] mod mat; mod mlp; diff --git a/border-py-gym-env/Cargo.toml b/border-py-gym-env/Cargo.toml index 5516304c..bd63fe6b 100644 --- a/border-py-gym-env/Cargo.toml +++ b/border-py-gym-env/Cargo.toml @@ -11,6 +11,7 @@ readme = "README.md" [dependencies] border-core = { version = "0.0.9", path = "../border-core" } +border-generic-replay-buffer = { version = "0.0.9", path = "../border-generic-replay-buffer", optional = true } numpy = { workspace = true } pyo3 = { workspace = true, default-features = false, features = [ "auto-initialize", @@ -42,4 +43,5 @@ features = ["candle"] no-default-features = true [features] -candle = [ "candle-core" ] +candle = [ "candle-core", "border-generic-replay-buffer" ] +tch = ["dep:tch", "border-generic-replay-buffer"] \ No newline at end of file diff --git a/border-py-gym-env/README.md b/border-py-gym-env/README.md index e69de29b..085e6e7f 100644 --- a/border-py-gym-env/README.md +++ b/border-py-gym-env/README.md @@ -0,0 +1,44 @@ +A wrapper of [Gymnasium](https://gymnasium.farama.org) environments on Python. + +[`GymEnv`] is a wrapper of [Gymnasium](https://gymnasium.farama.org) based on [`PyO3`](https://github.com/PyO3/pyo3). +It has been tested on some of [classic control](https://gymnasium.farama.org/environments/classic_control/) and +[Gymnasium-Robotics](https://robotics.farama.org) environments. + +In order to bridge Python and Rust, we need to convert Python objects to Rust objects and vice versa. +This crate provides the [`GymEnvConverter`] trait to handle these conversions. + +# Type Conversion + +The [`GymEnvConverter`] trait provides a unified interface for converting between Python and Rust types: + +* `filt_obs`: Converts Python observations to Rust types +* `filt_act`: Converts Rust actions to Python types + +# Implementations + +This crate provides several implementations of [`GymEnvConverter`]: + +* [`ndarray::NdarrayConverter`]: Handles conversions for environments using ndarray types +* [`candle::CandleConverter`]: Handles conversions for environments using Candle tensor types + (requires `candle` feature flag) +* [`tch::TchConverter`]: Handles conversions for environments using Tch tensor types + (requires `tch` feature flag) + +To use Candle or Tch converters, enable the corresponding feature in your `Cargo.toml`: + +```toml +[dependencies] +border-py-gym-env = { version = "0.1.0", features = ["candle"] } # For Candle support +# or +border-py-gym-env = { version = "0.1.0", features = ["tch"] } # For Tch support +``` + +Each implementation supports different types of observations and actions: + +* Array observations (e.g., CartPole) +* Dictionary observations (e.g., FetchPickAndPlace) +* Discrete actions (e.g., CartPole) +* Continuous actions (e.g., Pendulum) + +[`Policy`]: border_core::Policy +[`ArrayD`]: https://docs.rs/ndarray/0.15.1/ndarray/type.ArrayD.html diff --git a/border-py-gym-env/src/candle/tensor_batch.rs b/border-py-gym-env/src/candle/tensor_batch.rs index 09bdc97e..5976922d 100644 --- a/border-py-gym-env/src/candle/tensor_batch.rs +++ b/border-py-gym-env/src/candle/tensor_batch.rs @@ -1,4 +1,4 @@ -use border_core::generic_replay_buffer::BatchBase; +use border_generic_replay_buffer::BatchBase; use candle_core::{error::Result, DType, Device, Tensor}; /// Adds capability of constructing [`Tensor`] with a static method. diff --git a/border-py-gym-env/src/lib.rs b/border-py-gym-env/src/lib.rs index bd34abb0..9965bb4e 100644 --- a/border-py-gym-env/src/lib.rs +++ b/border-py-gym-env/src/lib.rs @@ -1,48 +1,6 @@ #![allow(rustdoc::broken_intra_doc_links)] -//! A wrapper of [Gymnasium](https://gymnasium.farama.org) environments on Python. -//! -//! [`GymEnv`] is a wrapper of [Gymnasium](https://gymnasium.farama.org) based on [`PyO3`](https://github.com/PyO3/pyo3). -//! It has been tested on some of [classic control](https://gymnasium.farama.org/environments/classic_control/) and -//! [Gymnasium-Robotics](https://robotics.farama.org) environments. -//! -//! In order to bridge Python and Rust, we need to convert Python objects to Rust objects and vice versa. -//! This crate provides the [`GymEnvConverter`] trait to handle these conversions. -//! -//! # Type Conversion -//! -//! The [`GymEnvConverter`] trait provides a unified interface for converting between Python and Rust types: -//! -//! * `filt_obs`: Converts Python observations to Rust types -//! * `filt_act`: Converts Rust actions to Python types -//! -//! # Implementations -//! -//! This crate provides several implementations of [`GymEnvConverter`]: -//! -//! * [`ndarray::NdarrayConverter`]: Handles conversions for environments using ndarray types -//! * [`candle::CandleConverter`]: Handles conversions for environments using Candle tensor types -//! (requires `candle` feature flag) -//! * [`tch::TchConverter`]: Handles conversions for environments using Tch tensor types -//! (requires `tch` feature flag) -//! -//! To use Candle or Tch converters, enable the corresponding feature in your `Cargo.toml`: -//! -//! ```toml -//! [dependencies] -//! border-py-gym-env = { version = "0.1.0", features = ["candle"] } # For Candle support -//! # or -//! border-py-gym-env = { version = "0.1.0", features = ["tch"] } # For Tch support -//! ``` -//! -//! Each implementation supports different types of observations and actions: -//! -//! * Array observations (e.g., CartPole) -//! * Dictionary observations (e.g., FetchPickAndPlace) -//! * Discrete actions (e.g., CartPole) -//! * Continuous actions (e.g., Pendulum) -//! -//! [`Policy`]: border_core::Policy -//! [`ArrayD`]: https://docs.rs/ndarray/0.15.1/ndarray/type.ArrayD.html +#![doc = include_str!("../README.md")] + mod base; #[cfg(feature = "candle")] pub mod candle; diff --git a/border-py-gym-env/src/tch/tensor_batch.rs b/border-py-gym-env/src/tch/tensor_batch.rs index 43f05740..91c88c83 100644 --- a/border-py-gym-env/src/tch/tensor_batch.rs +++ b/border-py-gym-env/src/tch/tensor_batch.rs @@ -1,4 +1,4 @@ -use border_core::generic_replay_buffer::BatchBase; +use border_generic_replay_buffer::BatchBase; use tch::{Device, Tensor}; /// Adds capability of constructing [`Tensor`] with a static method. diff --git a/border-tch-agent/Cargo.toml b/border-tch-agent/Cargo.toml index e763004a..63a51c38 100644 --- a/border-tch-agent/Cargo.toml +++ b/border-tch-agent/Cargo.toml @@ -11,6 +11,7 @@ readme = "README.md" [dependencies] border-core = { version = "0.0.9", path = "../border-core" } +border-generic-replay-buffer = { version = "0.0.9", path = "../border-generic-replay-buffer" } border-async-trainer = { version = "0.0.9", path = "../border-async-trainer", optional = true } serde = { workspace = true, features = ["derive"] } serde_yaml = { workspace = true } diff --git a/border-tch-agent/src/dqn/base.rs b/border-tch-agent/src/dqn/base.rs index 2ba66a46..1d350ea3 100644 --- a/border-tch-agent/src/dqn/base.rs +++ b/border-tch-agent/src/dqn/base.rs @@ -7,7 +7,7 @@ use crate::{ use anyhow::Result; use border_core::{ record::{Record, RecordValue}, - Agent, Configurable, Env, Policy, ReplayBufferBase, TransitionBatch, + Agent, Configurable, Env, Policy, ReplayBuffer, TransitionBatch, }; use serde::{de::DeserializeOwned, Serialize}; use std::{ @@ -51,7 +51,7 @@ impl Dqn where E: Env, Q: SubModel, - R: ReplayBufferBase, + R: ReplayBuffer, Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone, R::Batch: TransitionBatch, ::ObsBatch: Into, @@ -289,7 +289,7 @@ impl Agent for Dqn where E: Env + 'static, Q: SubModel + 'static, - R: ReplayBufferBase + 'static, + R: ReplayBuffer + 'static, E::Obs: Into, E::Act: From, Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone, @@ -378,7 +378,7 @@ impl SyncModel for Dqn where E: Env, Q: SubModel, - R: ReplayBufferBase, + R: ReplayBuffer, E::Obs: Into, E::Act: From, Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone, diff --git a/border-tch-agent/src/iqn/base.rs b/border-tch-agent/src/iqn/base.rs index 51c1a1d5..75c028fa 100644 --- a/border-tch-agent/src/iqn/base.rs +++ b/border-tch-agent/src/iqn/base.rs @@ -7,7 +7,7 @@ use crate::{ use anyhow::Result; use border_core::{ record::{Record, RecordValue}, - Agent, Configurable, Env, Policy, ReplayBufferBase, TransitionBatch, + Agent, Configurable, Env, Policy, ReplayBuffer, TransitionBatch, }; use log::trace; use serde::{de::DeserializeOwned, Serialize}; @@ -53,7 +53,7 @@ where E: Env, F: SubModel, M: SubModel, - R: ReplayBufferBase, + R: ReplayBuffer, F::Config: DeserializeOwned + Serialize, M::Config: DeserializeOwned + Serialize + OutDim, R::Batch: TransitionBatch, @@ -275,7 +275,7 @@ where E: Env + 'static, F: SubModel + 'static, M: SubModel + 'static, - R: ReplayBufferBase + 'static, + R: ReplayBuffer + 'static, E::Obs: Into, E::Act: From, F::Config: DeserializeOwned + Serialize + Clone, diff --git a/border-tch-agent/src/sac.rs b/border-tch-agent/src/sac.rs index 8f99c1cb..aaf987a8 100644 --- a/border-tch-agent/src/sac.rs +++ b/border-tch-agent/src/sac.rs @@ -5,15 +5,15 @@ //! ```no_run //! # use anyhow::Result; //! use border_core::{ -//! # Env as Env_, Obs as Obs_, Act as Act_, Step, test::{ -//! # TestAct as TestAct_, TestActBatch as TestActBatch_, -//! # TestEnv as TestEnv_, -//! # TestObs as TestObs_, TestObsBatch as TestObsBatch_, -//! # }, +//! # Env as Env_, Obs as Obs_, Act as Act_, Step, //! # record::Record, -//! # generic_replay_buffer::{SimpleReplayBuffer, BatchBase}, //! Configurable, //! }; +//! use border_generic_replay_buffer::{GenericReplayBuffer, BatchBase, test::{ +//! # TestAct as TestAct_, TestActBatch as TestActBatch_, +//! # TestEnv as TestEnv_, +//! # TestObs as TestObs_, TestObsBatch as TestObsBatch_, +//! # }}; //! use border_tch_agent::{ //! sac::{ActorConfig, CriticConfig, Sac, SacConfig}, //! mlp::{Mlp, Mlp2, MlpConfig}, @@ -136,7 +136,7 @@ //! # type Env = TestEnv; //! # type ObsBatch = TestObsBatch; //! # type ActBatch = TestActBatch; -//! # type ReplayBuffer = SimpleReplayBuffer; +//! # type ReplayBuffer = GenericReplayBuffer; //! # //! const DIM_OBS: i64 = 3; //! const DIM_ACT: i64 = 1; diff --git a/border-tch-agent/src/sac/base.rs b/border-tch-agent/src/sac/base.rs index d78f0293..046fa3e1 100644 --- a/border-tch-agent/src/sac/base.rs +++ b/border-tch-agent/src/sac/base.rs @@ -6,7 +6,7 @@ use crate::{ use anyhow::Result; use border_core::{ record::{Record, RecordValue}, - Agent, Configurable, Env, Policy, ReplayBufferBase, TransitionBatch, + Agent, Configurable, Env, Policy, ReplayBuffer, TransitionBatch, }; use serde::{de::DeserializeOwned, Serialize}; // use log::info; @@ -60,7 +60,7 @@ where E: Env, Q: SubModel2, P: SubModel, - R: ReplayBufferBase, + R: ReplayBuffer, E::Obs: Into + Into, E::Act: Into, Q::Input2: From, @@ -284,7 +284,7 @@ where E: Env + 'static, Q: SubModel2 + 'static, P: SubModel + 'static, - R: ReplayBufferBase + 'static, + R: ReplayBuffer + 'static, E::Obs: Into + Into, E::Act: Into + From, Q::Input2: From, @@ -362,7 +362,7 @@ where E: Env, Q: SubModel2, P: SubModel, - R: ReplayBufferBase, + R: ReplayBuffer, E::Obs: Into + Into, E::Act: Into + From, Q::Input2: From, diff --git a/border-tch-agent/src/tensor_batch.rs b/border-tch-agent/src/tensor_batch.rs index 721e92be..719000ff 100644 --- a/border-tch-agent/src/tensor_batch.rs +++ b/border-tch-agent/src/tensor_batch.rs @@ -1,4 +1,4 @@ -use border_core::generic_replay_buffer::BatchBase; +use border_generic_replay_buffer::BatchBase; use tch::{Device, Tensor}; /// Adds capability of constructing [`Tensor`] with a static method. diff --git a/border-tensorboard/src/lib.rs b/border-tensorboard/src/lib.rs index 59243794..cfbedde2 100644 --- a/border-tensorboard/src/lib.rs +++ b/border-tensorboard/src/lib.rs @@ -5,7 +5,7 @@ use anyhow::Result; use border_core::{ record::{Record, RecordValue, Recorder}, - Env, ReplayBufferBase, + Env, ReplayBuffer, }; use std::{ marker::PhantomData, @@ -17,7 +17,7 @@ use tensorboard_rs::summary_writer::SummaryWriter; pub struct TensorboardRecorder where E: Env, - R: ReplayBufferBase, + R: ReplayBuffer, { model_dir: PathBuf, writer: SummaryWriter, @@ -30,7 +30,7 @@ where impl TensorboardRecorder where E: Env, - R: ReplayBufferBase, + R: ReplayBuffer, { /// Construct a [`TensorboardRecorder`]. /// @@ -56,7 +56,7 @@ where impl Recorder for TensorboardRecorder where E: Env, - R: ReplayBufferBase, + R: ReplayBuffer, { /// Writes a given [`Record`] into a TFRecord. /// diff --git a/examples/atari/dqn_atari/src/main.rs b/examples/atari/dqn_atari/src/main.rs index 79d6f610..2100ff05 100644 --- a/examples/atari/dqn_atari/src/main.rs +++ b/examples/atari/dqn_atari/src/main.rs @@ -5,7 +5,7 @@ use anyhow::Result; use args::Args; use border_core::{ generic_replay_buffer::SimpleStepProcessorConfig, record::Recorder, Agent, Configurable, - Env as _, Evaluator as _, ReplayBufferBase, StepProcessor, Trainer, + Env as _, Evaluator as _, ReplayBuffer, StepProcessor, Trainer, }; use border_mlflow_tracking::MlflowTrackingClient; use border_tensorboard::TensorboardRecorder; diff --git a/examples/atari/dqn_atari_tch/src/main.rs b/examples/atari/dqn_atari_tch/src/main.rs index bc98d47d..54342f08 100644 --- a/examples/atari/dqn_atari_tch/src/main.rs +++ b/examples/atari/dqn_atari_tch/src/main.rs @@ -5,7 +5,7 @@ use anyhow::Result; use args::Args; use border_core::{ generic_replay_buffer::SimpleStepProcessorConfig, record::Recorder, Agent, Configurable, - Env as _, Evaluator as _, ReplayBufferBase, StepProcessor, Trainer, + Env as _, Evaluator as _, ReplayBuffer, StepProcessor, Trainer, }; use border_mlflow_tracking::MlflowTrackingClient; use border_tensorboard::TensorboardRecorder; diff --git a/examples/d4rl/awac_pen/Cargo.toml b/examples/d4rl/awac_pen/Cargo.toml index 04ac03a8..7e9858c7 100644 --- a/examples/d4rl/awac_pen/Cargo.toml +++ b/examples/d4rl/awac_pen/Cargo.toml @@ -16,6 +16,7 @@ border-minari = { version = "0.0.9", path = "../../../border-minari", features = ] } border-candle-agent = { version = "0.0.9", path = "../../../border-candle-agent" } border-core = { version = "0.0.9", path = "../../../border-core" } +border-generic-replay-buffer = { version = "0.0.9", path = "../../../border-generic-replay-buffer" } border-tensorboard = { version = "0.0.9", path = "../../../border-tensorboard" } border-mlflow-tracking = { version = "0.0.9", path = "../../../border-mlflow-tracking" } serde = "1.0.194" diff --git a/examples/d4rl/awac_pen/src/main.rs b/examples/d4rl/awac_pen/src/main.rs index f32e7799..e4f0c45f 100644 --- a/examples/d4rl/awac_pen/src/main.rs +++ b/examples/d4rl/awac_pen/src/main.rs @@ -10,11 +10,10 @@ use border_candle_agent::{ Activation, }; use border_core::{ - generic_replay_buffer::{BatchBase, SimpleReplayBuffer}, - record::Recorder, - Agent, Configurable, Env, Evaluator, ExperienceBufferBase, ReplayBufferBase, Trainer, + record::Recorder, Agent, Configurable, Env, Evaluator, ExperienceBuffer, ReplayBuffer, Trainer, TrainerConfig, TransitionBatch, }; +use border_generic_replay_buffer::{BatchBase, GenericReplayBuffer}; use border_minari::{ d4rl::pen::candle::{PenConverter, PenConverterConfig}, MinariConverter, MinariDataset, MinariEnv, MinariEvaluator, @@ -188,7 +187,7 @@ where E: Env + 'static, E::Obs: Into, E::Act: From + Into, - R: ReplayBufferBase + 'static, + R: ReplayBuffer + 'static, R::Batch: TransitionBatch, ::ObsBatch: Into + Clone, ::ActBatch: Into + Clone, @@ -200,7 +199,7 @@ where fn create_replay_buffer( converter: &mut T, dataset: &MinariDataset, -) -> Result> +) -> Result> where T: MinariConverter, T::ObsBatch: BatchBase + Debug + Into, @@ -215,7 +214,7 @@ where fn create_recorder(config: &PenConfig) -> Result>> where E: Env + 'static, - R: ReplayBufferBase + 'static, + R: ReplayBuffer + 'static, { log::info!("Create recorder"); if let Some(mlflow_run_name) = &config.args.mlflow_run_name { @@ -282,7 +281,7 @@ where T::ObsBatch: std::fmt::Debug + Into + 'static + Clone, T::ActBatch: std::fmt::Debug + Into + 'static + Clone, { - let mut agent: Box, SimpleReplayBuffer>> = + let mut agent: Box, GenericReplayBuffer>> = create_agent(&config); let recorder = create_recorder(&config)?; // used for loading a trained model let mut evaluator = create_evaluator(&config.args, converter, &dataset, true)?; diff --git a/examples/d4rl/bc_pen/Cargo.toml b/examples/d4rl/bc_pen/Cargo.toml index ea3b4030..67bbfb98 100644 --- a/examples/d4rl/bc_pen/Cargo.toml +++ b/examples/d4rl/bc_pen/Cargo.toml @@ -17,6 +17,7 @@ border-minari = { version = "0.0.9", path = "../../../border-minari", features = ] } border-candle-agent = { version = "0.0.9", path = "../../../border-candle-agent" } border-core = { version = "0.0.9", path = "../../../border-core" } +border-generic-replay-buffer = { version = "0.0.9", path = "../../../border-generic-replay-buffer" } border-tensorboard = { version = "0.0.9", path = "../../../border-tensorboard" } border-mlflow-tracking = { version = "0.0.9", path = "../../../border-mlflow-tracking" } serde = "1.0.194" diff --git a/examples/d4rl/bc_pen/src/main.rs b/examples/d4rl/bc_pen/src/main.rs index 944e7e63..e8cd9dc3 100644 --- a/examples/d4rl/bc_pen/src/main.rs +++ b/examples/d4rl/bc_pen/src/main.rs @@ -5,11 +5,10 @@ use border_candle_agent::{ Activation, }; use border_core::{ - generic_replay_buffer::{BatchBase, SimpleReplayBuffer}, - record::Recorder, - Agent, Configurable, Env, Evaluator, ExperienceBufferBase, ReplayBufferBase, Trainer, + record::Recorder, Agent, Configurable, Env, Evaluator, ExperienceBuffer, ReplayBuffer, Trainer, TrainerConfig, TransitionBatch, }; +use border_generic_replay_buffer::{BatchBase, GenericReplayBuffer}; use border_minari::{ d4rl::pen::candle::{PenConverter, PenConverterConfig}, MinariConverter, MinariDataset, MinariEnv, MinariEvaluator, @@ -161,7 +160,7 @@ where E: Env + 'static, E::Obs: Into, E::Act: From + Into, - R: ReplayBufferBase + 'static, + R: ReplayBuffer + 'static, R::Batch: TransitionBatch, ::ObsBatch: Into + Clone, ::ActBatch: Into + Clone, @@ -173,7 +172,7 @@ where fn create_replay_buffer( converter: &mut T, dataset: &MinariDataset, -) -> Result> +) -> Result> where T: MinariConverter, T::ObsBatch: BatchBase + Debug + Into, @@ -188,7 +187,7 @@ where fn create_recorder(config: &PenConfig) -> Result>> where E: Env + 'static, - R: ReplayBufferBase + 'static, + R: ReplayBuffer + 'static, { log::info!("Create recorder"); if let Some(mlflow_run_name) = &config.args.mlflow_run_name { @@ -203,7 +202,7 @@ where let model_dir = format!("{}/{}", MODEL_DIR, config.args.env); Ok(Box::new(TensorboardRecorder::new( &model_dir, &model_dir, false, - ))) + ))) } } @@ -255,7 +254,7 @@ where T::ObsBatch: std::fmt::Debug + Into + 'static + Clone, T::ActBatch: std::fmt::Debug + Into + 'static + Clone, { - let mut agent: Box, SimpleReplayBuffer>> = + let mut agent: Box, GenericReplayBuffer>> = create_agent(&config); let recorder = create_recorder(&config)?; // used for loading a trained model let mut evaluator = create_evaluator(&config.args, converter, &dataset, true)?; diff --git a/examples/d4rl/iql_pen/Cargo.toml b/examples/d4rl/iql_pen/Cargo.toml index 04ac03a8..7e9858c7 100644 --- a/examples/d4rl/iql_pen/Cargo.toml +++ b/examples/d4rl/iql_pen/Cargo.toml @@ -16,6 +16,7 @@ border-minari = { version = "0.0.9", path = "../../../border-minari", features = ] } border-candle-agent = { version = "0.0.9", path = "../../../border-candle-agent" } border-core = { version = "0.0.9", path = "../../../border-core" } +border-generic-replay-buffer = { version = "0.0.9", path = "../../../border-generic-replay-buffer" } border-tensorboard = { version = "0.0.9", path = "../../../border-tensorboard" } border-mlflow-tracking = { version = "0.0.9", path = "../../../border-mlflow-tracking" } serde = "1.0.194" diff --git a/examples/d4rl/iql_pen/src/main.rs b/examples/d4rl/iql_pen/src/main.rs index 0106ddb2..3b3ffb8e 100644 --- a/examples/d4rl/iql_pen/src/main.rs +++ b/examples/d4rl/iql_pen/src/main.rs @@ -10,11 +10,10 @@ use border_candle_agent::{ Activation, }; use border_core::{ - generic_replay_buffer::{BatchBase, SimpleReplayBuffer}, - record::Recorder, - Agent, Configurable, Env, Evaluator, ExperienceBufferBase, ReplayBufferBase, Trainer, + record::Recorder, Agent, Configurable, Env, Evaluator, ExperienceBuffer, ReplayBuffer, Trainer, TrainerConfig, TransitionBatch, }; +use border_generic_replay_buffer::{BatchBase, SimpleReplayBuffer}; use border_minari::{ d4rl::pen::candle::{PenConverter, PenConverterConfig}, MinariConverter, MinariDataset, MinariEnv, MinariEvaluator, @@ -197,7 +196,7 @@ where E: Env + 'static, E::Obs: Into, E::Act: From + Into, - R: ReplayBufferBase + 'static, + R: ReplayBuffer + 'static, R::Batch: TransitionBatch, ::ObsBatch: Into + Clone, ::ActBatch: Into + Clone, @@ -209,7 +208,7 @@ where fn create_replay_buffer( converter: &mut T, dataset: &MinariDataset, -) -> Result> +) -> Result> where T: MinariConverter, T::ObsBatch: BatchBase + Debug + Into, @@ -224,7 +223,7 @@ where fn create_recorder(config: &PenConfig) -> Result>> where E: Env + 'static, - R: ReplayBufferBase + 'static, + R: ReplayBuffer + 'static, { log::info!("Create recorder"); if let Some(mlflow_run_name) = &config.args.mlflow_run_name { @@ -239,7 +238,7 @@ where let model_dir = format!("{}/{}", MODEL_DIR, config.args.env); Ok(Box::new(TensorboardRecorder::new( &model_dir, &model_dir, false, - ))) + ))) } } @@ -291,7 +290,7 @@ where T::ObsBatch: std::fmt::Debug + Into + 'static + Clone, T::ActBatch: std::fmt::Debug + Into + 'static + Clone, { - let mut agent: Box, SimpleReplayBuffer>> = + let mut agent: Box, GenericReplayBuffer>> = create_agent(&config); let recorder = create_recorder(&config)?; // used for loading a trained model let mut evaluator = create_evaluator(&config.args, converter, &dataset, true)?; diff --git a/examples/gym/awac_pendulum/src/main.rs b/examples/gym/awac_pendulum/src/main.rs index 5cba63b5..96d04658 100644 --- a/examples/gym/awac_pendulum/src/main.rs +++ b/examples/gym/awac_pendulum/src/main.rs @@ -15,7 +15,7 @@ use border_core::{ SimpleStepProcessorConfig, }, record::Recorder, - Agent, Configurable, DefaultEvaluator, Env as _, Evaluator as _, ReplayBufferBase, + Agent, Configurable, DefaultEvaluator, Env as _, Evaluator as _, ReplayBuffer, StepProcessor, Trainer, TrainerConfig, }; use border_mlflow_tracking::MlflowTrackingClient; diff --git a/examples/gym/convert_policy/src/main.rs b/examples/gym/convert_policy/src/main.rs index 259b19c3..97379191 100644 --- a/examples/gym/convert_policy/src/main.rs +++ b/examples/gym/convert_policy/src/main.rs @@ -101,7 +101,7 @@ mod dummy { pub struct DummyReplayBuffer; - impl border_core::ReplayBufferBase for DummyReplayBuffer { + impl border_core::ReplayBuffer for DummyReplayBuffer { type Batch = DummyBatch; type Config = usize; diff --git a/examples/gym/dqn_cartpole/Cargo.toml b/examples/gym/dqn_cartpole/Cargo.toml index 1b8d0887..881d77eb 100644 --- a/examples/gym/dqn_cartpole/Cargo.toml +++ b/examples/gym/dqn_cartpole/Cargo.toml @@ -15,6 +15,7 @@ border-py-gym-env = { version = "0.0.9", path = "../../../border-py-gym-env", fe ] } border-candle-agent = { version = "0.0.9", path = "../../../border-candle-agent" } border-core = { version = "0.0.9", path = "../../../border-core" } +border-generic-replay-buffer = { version = "0.0.9", path = "../../../border-generic-replay-buffer" } border-tensorboard = { version = "0.0.9", path = "../../../border-tensorboard" } border-mlflow-tracking = { version = "0.0.9", path = "../../../border-mlflow-tracking" } serde = "1.0.194" diff --git a/examples/gym/dqn_cartpole/src/main.rs b/examples/gym/dqn_cartpole/src/main.rs index d032fb8f..17f8558f 100644 --- a/examples/gym/dqn_cartpole/src/main.rs +++ b/examples/gym/dqn_cartpole/src/main.rs @@ -7,13 +7,11 @@ use border_candle_agent::{ Activation, }; use border_core::{ - generic_replay_buffer::{ - SimpleReplayBuffer, SimpleReplayBufferConfig, SimpleStepProcessor, - SimpleStepProcessorConfig, - }, - record::Recorder, - Agent, Configurable, DefaultEvaluator, Env as _, Evaluator as _, ReplayBufferBase, - StepProcessor, Trainer, TrainerConfig, + record::Recorder, Agent, Configurable, DefaultEvaluator, Env as _, Evaluator as _, + ReplayBuffer as _, StepProcessor, Trainer, TrainerConfig, +}; +use border_generic_replay_buffer::{ + GenericReplayBuffer, GenericReplayBufferConfig, SimpleStepProcessor, SimpleStepProcessorConfig, }; use border_mlflow_tracking::MlflowTrackingClient; use border_py_gym_env::{ @@ -31,7 +29,7 @@ use clap::Parser; use serde::Serialize; type Env = GymEnv; -type ReplayBuffer = SimpleReplayBuffer; +type ReplayBuffer = GenericReplayBuffer; type StepProc = SimpleStepProcessor; type Evaluator = DefaultEvaluator; @@ -164,7 +162,7 @@ impl DqnCartpoleConfig { fn train(args: &Args, max_opts: usize, model_dir: &str, eval_interval: usize) -> Result<()> { let config = DqnCartpoleConfig::new(DIM_OBS, DIM_ACT, max_opts, eval_interval)?; let step_proc_config = SimpleStepProcessorConfig {}; - let replay_buffer_config = SimpleReplayBufferConfig::default().capacity(REPLAY_BUFFER_CAPACITY); + let replay_buffer_config = GenericReplayBufferConfig::default().capacity(REPLAY_BUFFER_CAPACITY); let mut recorder = create_recorder(&args, model_dir, Some(&config))?; let mut trainer = Trainer::build(config.trainer_config.clone()); diff --git a/examples/gym/dqn_cartpole_tch/Cargo.toml b/examples/gym/dqn_cartpole_tch/Cargo.toml index 7517abfe..0d19f8df 100644 --- a/examples/gym/dqn_cartpole_tch/Cargo.toml +++ b/examples/gym/dqn_cartpole_tch/Cargo.toml @@ -15,6 +15,7 @@ border-py-gym-env = { version = "0.0.9", path = "../../../border-py-gym-env", fe ] } border-tch-agent = { version = "0.0.9", path = "../../../border-tch-agent" } border-core = { version = "0.0.9", path = "../../../border-core" } +border-generic-replay-buffer = { version = "0.0.9", path = "../../../border-generic-replay-buffer" } border-tensorboard = { version = "0.0.9", path = "../../../border-tensorboard" } border-mlflow-tracking = { version = "0.0.9", path = "../../../border-mlflow-tracking" } serde = "1.0.194" diff --git a/examples/gym/dqn_cartpole_tch/src/main.rs b/examples/gym/dqn_cartpole_tch/src/main.rs index 078eca0f..8d273722 100644 --- a/examples/gym/dqn_cartpole_tch/src/main.rs +++ b/examples/gym/dqn_cartpole_tch/src/main.rs @@ -1,13 +1,12 @@ use anyhow::Result; use border_core::{ - generic_replay_buffer::{ - SimpleReplayBuffer, SimpleReplayBufferConfig, SimpleStepProcessor, - SimpleStepProcessorConfig, - }, - record::Recorder, - Agent, Configurable, DefaultEvaluator, Env as _, Evaluator as _, ReplayBufferBase, - StepProcessor, Trainer, TrainerConfig, + record::Recorder, Agent, Configurable, DefaultEvaluator, Env as _, Evaluator as _, + ReplayBuffer as _, StepProcessor, Trainer, TrainerConfig, }; +use border_generic_replay_buffer::{ + GenericReplayBuffer, GenericReplayBufferConfig, SimpleStepProcessor, SimpleStepProcessorConfig, +}; + use border_mlflow_tracking::MlflowTrackingClient; use border_py_gym_env::{ tch::{NdarrayConverter, NdarrayConverterConfig, TensorBatch}, @@ -19,12 +18,12 @@ use border_tch_agent::{ util::CriticLoss, }; use border_tensorboard::TensorboardRecorder; -use tch::Device; use clap::Parser; use serde::Serialize; +use tch::Device; type Env = GymEnv; -type ReplayBuffer = SimpleReplayBuffer; +type ReplayBuffer = GenericReplayBuffer; type StepProc = SimpleStepProcessor; type Evaluator = DefaultEvaluator; @@ -157,7 +156,7 @@ impl DqnCartpoleConfig { fn train(args: &Args, max_opts: usize, model_dir: &str, eval_interval: usize) -> Result<()> { let config = DqnCartpoleConfig::new(DIM_OBS, DIM_ACT, max_opts, eval_interval)?; let step_proc_config = SimpleStepProcessorConfig {}; - let replay_buffer_config = SimpleReplayBufferConfig::default().capacity(REPLAY_BUFFER_CAPACITY); + let replay_buffer_config = GenericReplayBufferConfig::default().capacity(REPLAY_BUFFER_CAPACITY); let mut recorder = create_recorder(&args, model_dir, Some(&config))?; let mut trainer = Trainer::build(config.trainer_config.clone()); diff --git a/examples/gym/sac_fetch_reach/Cargo.toml b/examples/gym/sac_fetch_reach/Cargo.toml index 7b38c67e..6bb20fbc 100644 --- a/examples/gym/sac_fetch_reach/Cargo.toml +++ b/examples/gym/sac_fetch_reach/Cargo.toml @@ -15,6 +15,7 @@ border-py-gym-env = { version = "0.0.9", path = "../../../border-py-gym-env", fe ] } border-candle-agent = { version = "0.0.9", path = "../../../border-candle-agent" } border-core = { version = "0.0.9", path = "../../../border-core" } +border-generic-replay-buffer = { version = "0.0.9", path = "../../../border-generic-replay-buffer" } border-tensorboard = { version = "0.0.9", path = "../../../border-tensorboard" } border-mlflow-tracking = { version = "0.0.9", path = "../../../border-mlflow-tracking" } serde = "1.0.194" diff --git a/examples/gym/sac_fetch_reach/MUJOCO_LOG.TXT b/examples/gym/sac_fetch_reach/MUJOCO_LOG.TXT new file mode 100644 index 00000000..af0d24c6 --- /dev/null +++ b/examples/gym/sac_fetch_reach/MUJOCO_LOG.TXT @@ -0,0 +1,33 @@ +Sun Aug 10 11:01:19 2025 +WARNING: Nan, Inf or huge value in QACC at DOF 0. The simulation is unstable. Time = 1.6800. + +Sun Aug 10 11:01:44 2025 +WARNING: Nan, Inf or huge value in QACC at DOF 0. The simulation is unstable. Time = 0.4000. + +Sun Aug 10 11:01:44 2025 +WARNING: Nan, Inf or huge value in QACC at DOF 0. The simulation is unstable. Time = 0.4000. + +Sun Aug 10 11:01:44 2025 +WARNING: Nan, Inf or huge value in QACC at DOF 0. The simulation is unstable. Time = 0.4000. + +Sun Aug 10 11:01:44 2025 +WARNING: Nan, Inf or huge value in QACC at DOF 0. The simulation is unstable. Time = 0.4000. + +Sun Aug 10 11:01:44 2025 +WARNING: Nan, Inf or huge value in QACC at DOF 0. The simulation is unstable. Time = 0.4000. + +Sun Aug 10 11:01:45 2025 +WARNING: Nan, Inf or huge value in QACC at DOF 0. The simulation is unstable. Time = 0.4000. + +Sun Aug 10 11:01:45 2025 +WARNING: Nan, Inf or huge value in QACC at DOF 0. The simulation is unstable. Time = 0.4000. + +Sun Aug 10 11:01:45 2025 +WARNING: Nan, Inf or huge value in QACC at DOF 0. The simulation is unstable. Time = 0.4000. + +Sun Aug 10 11:01:45 2025 +WARNING: Nan, Inf or huge value in QACC at DOF 0. The simulation is unstable. Time = 0.4000. + +Sun Aug 10 11:01:45 2025 +WARNING: Nan, Inf or huge value in QACC at DOF 0. The simulation is unstable. Time = 0.4000. + diff --git a/examples/gym/sac_fetch_reach/src/main.rs b/examples/gym/sac_fetch_reach/src/main.rs index b0b3a370..6ba9995c 100644 --- a/examples/gym/sac_fetch_reach/src/main.rs +++ b/examples/gym/sac_fetch_reach/src/main.rs @@ -7,14 +7,13 @@ use border_candle_agent::{ Activation, }; use border_core::{ - generic_replay_buffer::{ - SimpleReplayBuffer, SimpleReplayBufferConfig, SimpleStepProcessor, - SimpleStepProcessorConfig, - }, - record::Recorder, - Agent, Configurable, DefaultEvaluator, Env as _, Evaluator as _, ReplayBufferBase, - StepProcessor, Trainer, TrainerConfig, + record::Recorder, Agent, Configurable, DefaultEvaluator, Env as _, Evaluator as _, + ReplayBuffer as _, StepProcessor, Trainer, TrainerConfig, +}; +use border_generic_replay_buffer::{ + GenericReplayBuffer, GenericReplayBufferConfig, SimpleStepProcessor, SimpleStepProcessorConfig, }; + use border_mlflow_tracking::MlflowTrackingClient; use border_py_gym_env::{ candle::{ @@ -31,7 +30,7 @@ use clap::Parser; use serde::Serialize; type Env = GymEnv; -type ReplayBuffer = SimpleReplayBuffer; +type ReplayBuffer = GenericReplayBuffer; type StepProc = SimpleStepProcessor; type Evaluator = DefaultEvaluator; @@ -222,7 +221,7 @@ impl SacFetchReachConfig { fn train(args: &Args, max_opts: usize, model_dir: &str, eval_interval: usize) -> Result<()> { let config = SacFetchReachConfig::new(DIM_OBS, DIM_ACT, max_opts, eval_interval)?; let step_proc_config = SimpleStepProcessorConfig {}; - let replay_buffer_config = SimpleReplayBufferConfig::default().capacity(REPLAY_BUFFER_CAPACITY); + let replay_buffer_config = GenericReplayBufferConfig::default().capacity(REPLAY_BUFFER_CAPACITY); let mut recorder = create_recorder(&args, model_dir, Some(&config))?; let mut trainer = Trainer::build(config.trainer_config.clone()); diff --git a/examples/gym/sac_pendulum/Cargo.toml b/examples/gym/sac_pendulum/Cargo.toml index a8dd6faf..8e83e6bc 100644 --- a/examples/gym/sac_pendulum/Cargo.toml +++ b/examples/gym/sac_pendulum/Cargo.toml @@ -15,6 +15,7 @@ border-py-gym-env = { version = "0.0.9", path = "../../../border-py-gym-env", fe ] } border-candle-agent = { version = "0.0.9", path = "../../../border-candle-agent" } border-core = { version = "0.0.9", path = "../../../border-core" } +border-generic-replay-buffer = { version = "0.0.9", path = "../../../border-generic-replay-buffer" } border-tensorboard = { version = "0.0.9", path = "../../../border-tensorboard" } border-mlflow-tracking = { version = "0.0.9", path = "../../../border-mlflow-tracking" } serde = "1.0.194" diff --git a/examples/gym/sac_pendulum/src/main.rs b/examples/gym/sac_pendulum/src/main.rs index 838443eb..7f6bd690 100644 --- a/examples/gym/sac_pendulum/src/main.rs +++ b/examples/gym/sac_pendulum/src/main.rs @@ -10,14 +10,13 @@ use border_candle_agent::{ Activation, }; use border_core::{ - generic_replay_buffer::{ - SimpleReplayBuffer, SimpleReplayBufferConfig, SimpleStepProcessor, - SimpleStepProcessorConfig, - }, - record::Recorder, - Agent, Configurable, DefaultEvaluator, Env as _, Evaluator as _, ReplayBufferBase, - StepProcessor, Trainer, TrainerConfig, + record::Recorder, Agent, Configurable, DefaultEvaluator, Env as _, Evaluator as _, + ReplayBuffer as _, StepProcessor, Trainer, TrainerConfig, +}; +use border_generic_replay_buffer::{ + GenericReplayBuffer, GenericReplayBufferConfig, SimpleStepProcessor, SimpleStepProcessorConfig, }; + use border_mlflow_tracking::MlflowTrackingClient; use border_py_gym_env::{ candle::{ @@ -34,7 +33,7 @@ use clap::Parser; use serde::Serialize; type Env = GymEnv; -type ReplayBuffer = SimpleReplayBuffer; +type ReplayBuffer = GenericReplayBuffer; type StepProc = SimpleStepProcessor; type Evaluator = DefaultEvaluator; @@ -190,7 +189,7 @@ impl SacPendulumConfig { fn train(args: &Args, max_opts: usize, model_dir: &str, eval_interval: usize) -> Result<()> { let config = SacPendulumConfig::new(DIM_OBS, DIM_ACT, max_opts, eval_interval)?; let step_proc_config = SimpleStepProcessorConfig {}; - let replay_buffer_config = SimpleReplayBufferConfig::default().capacity(REPLAY_BUFFER_CAPACITY); + let replay_buffer_config = GenericReplayBufferConfig::default().capacity(REPLAY_BUFFER_CAPACITY); let mut recorder = create_recorder(&args, model_dir, Some(&config))?; let mut trainer = Trainer::build(config.trainer_config.clone()); diff --git a/examples/gym/sac_pendulum_tch/Cargo.toml b/examples/gym/sac_pendulum_tch/Cargo.toml index 7517abfe..0d19f8df 100644 --- a/examples/gym/sac_pendulum_tch/Cargo.toml +++ b/examples/gym/sac_pendulum_tch/Cargo.toml @@ -15,6 +15,7 @@ border-py-gym-env = { version = "0.0.9", path = "../../../border-py-gym-env", fe ] } border-tch-agent = { version = "0.0.9", path = "../../../border-tch-agent" } border-core = { version = "0.0.9", path = "../../../border-core" } +border-generic-replay-buffer = { version = "0.0.9", path = "../../../border-generic-replay-buffer" } border-tensorboard = { version = "0.0.9", path = "../../../border-tensorboard" } border-mlflow-tracking = { version = "0.0.9", path = "../../../border-mlflow-tracking" } serde = "1.0.194" diff --git a/examples/gym/sac_pendulum_tch/src/main.rs b/examples/gym/sac_pendulum_tch/src/main.rs index 5ec5b0b1..fc23b0dc 100644 --- a/examples/gym/sac_pendulum_tch/src/main.rs +++ b/examples/gym/sac_pendulum_tch/src/main.rs @@ -1,13 +1,12 @@ use anyhow::Result; use border_core::{ - generic_replay_buffer::{ - SimpleReplayBuffer, SimpleReplayBufferConfig, SimpleStepProcessor, - SimpleStepProcessorConfig, - }, - record::Recorder, - Agent, Configurable, DefaultEvaluator, Env as _, Evaluator as _, ReplayBufferBase, - StepProcessor, Trainer, TrainerConfig, + record::Recorder, Agent, Configurable, DefaultEvaluator, Env as _, Evaluator as _, + ReplayBuffer as _, StepProcessor, Trainer, TrainerConfig, }; +use border_generic_replay_buffer::{ + GenericReplayBuffer, GenericReplayBufferConfig, SimpleStepProcessor, SimpleStepProcessorConfig, +}; + use border_mlflow_tracking::MlflowTrackingClient; use border_py_gym_env::{ tch::{NdarrayConverter, NdarrayConverterConfig, TensorBatch}, @@ -24,7 +23,7 @@ use serde::Serialize; use tch::Device; type Env = GymEnv; -type ReplayBuffer = SimpleReplayBuffer; +type ReplayBuffer = GenericReplayBuffer; type StepProc = SimpleStepProcessor; type Evaluator = DefaultEvaluator; @@ -153,7 +152,7 @@ impl SacPendulumConfig { fn train(args: &Args, max_opts: usize, model_dir: &str, eval_interval: usize) -> Result<()> { let config = SacPendulumConfig::new(DIM_OBS, DIM_ACT, max_opts, eval_interval)?; let step_proc_config = SimpleStepProcessorConfig {}; - let replay_buffer_config = SimpleReplayBufferConfig::default().capacity(REPLAY_BUFFER_CAPACITY); + let replay_buffer_config = GenericReplayBufferConfig::default().capacity(REPLAY_BUFFER_CAPACITY); let mut recorder = create_recorder(&args, model_dir, Some(&config))?; let mut trainer = Trainer::build(config.trainer_config.clone());