From 7bae9a20066e8bb9b74aad7661cbc1d05a750e76 Mon Sep 17 00:00:00 2001 From: taku-y Date: Sun, 10 Aug 2025 06:38:38 +0000 Subject: [PATCH 01/23] Bump versions for development --- CHANGELOG.md | 4 +++- Cargo.toml | 2 +- border-async-trainer/Cargo.toml | 4 ++-- border-atari-env/Cargo.toml | 6 +++--- border-candle-agent/Cargo.toml | 4 ++-- border-minari/Cargo.toml | 2 +- border-mlflow-tracking/Cargo.toml | 2 +- border-policy-no-backend/Cargo.toml | 4 ++-- border-py-gym-env/Cargo.toml | 2 +- border-tch-agent/Cargo.toml | 4 ++-- border-tensorboard/Cargo.toml | 2 +- examples/atari/dqn_atari/Cargo.toml | 10 +++++----- examples/atari/dqn_atari_async_tch/Cargo.toml | 12 ++++++------ examples/atari/dqn_atari_tch/Cargo.toml | 10 +++++----- examples/d4rl/awac_pen/Cargo.toml | 10 +++++----- examples/d4rl/bc_pen/Cargo.toml | 10 +++++----- examples/d4rl/iql_pen/Cargo.toml | 10 +++++----- examples/gym/awac_pendulum/Cargo.toml | 10 +++++----- examples/gym/convert_policy/Cargo.toml | 6 +++--- examples/gym/dqn_cartpole/Cargo.toml | 10 +++++----- examples/gym/dqn_cartpole_tch/Cargo.toml | 10 +++++----- examples/gym/pendulum_std/Cargo.toml | 6 +++--- examples/gym/sac_fetch_reach/Cargo.toml | 10 +++++----- examples/gym/sac_pendulum/Cargo.toml | 10 +++++----- examples/gym/sac_pendulum_tch/Cargo.toml | 10 +++++----- 25 files changed, 86 insertions(+), 84 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 10f14e07..ef731086 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,8 @@ # Changelog -## v0.0.8 (2025-??-??) +## v0.0.9 (2025-??-??) + +## v0.0.8 (2025-05-17) ### Added diff --git a/Cargo.toml b/Cargo.toml index 830e133e..498c2b07 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,7 +15,7 @@ members = [ exclude = ["docker/", "examples/"] [workspace.package] -version = "0.0.8" +version = "0.0.9" edition = "2018" rust-version = "1.84" description = "Reinforcement learning library" diff --git a/border-async-trainer/Cargo.toml b/border-async-trainer/Cargo.toml index 42d5772f..edbcbbdc 100644 --- a/border-async-trainer/Cargo.toml +++ b/border-async-trainer/Cargo.toml @@ -12,8 +12,8 @@ readme = "README.md" [dependencies] anyhow = { workspace = true } aquamarine = { workspace = true } -border-core = { version = "0.0.8", path = "../border-core" } -border-tensorboard = { version = "0.0.8", path = "../border-tensorboard" } +border-core = { version = "0.0.9", path = "../border-core" } +border-tensorboard = { version = "0.0.9", path = "../border-tensorboard" } serde = { workspace = true, features = ["derive"] } log = { workspace = true } tokio = { version = "1.14.0", features = ["full"] } diff --git a/border-atari-env/Cargo.toml b/border-atari-env/Cargo.toml index 682995ef..6b1a1273 100644 --- a/border-atari-env/Cargo.toml +++ b/border-atari-env/Cargo.toml @@ -14,17 +14,17 @@ anyhow = { workspace = true } pixels = { version = "0.2.0", optional = true } winit = { version = "0.24.0", optional = true } dirs = { workspace = true } -border-core = { version = "0.0.8", path = "../border-core" } +border-core = { version = "0.0.9", path = "../border-core" } image = { workspace = true } tch = { workspace = true, optional = true } -border-tch-agent = { version = "0.0.8", path = "../border-tch-agent", optional = true } +border-tch-agent = { version = "0.0.9", path = "../border-tch-agent", optional = true } candle-core = { workspace = true, optional = true } serde = { workspace = true, features = ["derive"] } itertools = "0.10.1" fastrand = { workspace = true } pollster = "=0.2.4" rand = { workspace = true } -border-candle-agent = { version = "0.0.8", path = "../border-candle-agent", optional = true } +border-candle-agent = { version = "0.0.9", path = "../border-candle-agent", optional = true } # The following crates are required by the code adapted from atari-env atari-env-sys = { version = "0.1.0", optional = true } diff --git a/border-candle-agent/Cargo.toml b/border-candle-agent/Cargo.toml index ba19c14c..f1a2306a 100644 --- a/border-candle-agent/Cargo.toml +++ b/border-candle-agent/Cargo.toml @@ -10,8 +10,8 @@ license.workspace = true readme = "README.md" [dependencies] -border-core = { version = "0.0.8", path = "../border-core" } -border-async-trainer = { version = "0.0.8", path = "../border-async-trainer", optional = true } +border-core = { version = "0.0.9", path = "../border-core" } +border-async-trainer = { version = "0.0.9", path = "../border-async-trainer", optional = true } serde = { workspace = true, features = ["derive"] } serde_yaml = { workspace = true } tensorboard-rs = { workspace = true } diff --git a/border-minari/Cargo.toml b/border-minari/Cargo.toml index db6705bd..362e5965 100644 --- a/border-minari/Cargo.toml +++ b/border-minari/Cargo.toml @@ -10,7 +10,7 @@ license.workspace = true readme = "README.md" [dependencies] -border-core = { version = "0.0.8", path = "../border-core" } +border-core = { version = "0.0.9", path = "../border-core" } numpy = { workspace = true } pyo3 = { workspace = true, default-features = false, features = [ "auto-initialize", "macros" diff --git a/border-mlflow-tracking/Cargo.toml b/border-mlflow-tracking/Cargo.toml index eaa52308..66829d0c 100644 --- a/border-mlflow-tracking/Cargo.toml +++ b/border-mlflow-tracking/Cargo.toml @@ -10,7 +10,7 @@ license.workspace = true readme = "README.md" [dependencies] -border-core = { version = "0.0.8", path = "../border-core" } +border-core = { version = "0.0.9", path = "../border-core" } reqwest = { workspace = true } anyhow = { workspace = true } serde = { workspace = true, features = ["derive"] } diff --git a/border-policy-no-backend/Cargo.toml b/border-policy-no-backend/Cargo.toml index e8a9088e..ecef4b24 100644 --- a/border-policy-no-backend/Cargo.toml +++ b/border-policy-no-backend/Cargo.toml @@ -10,8 +10,8 @@ license.workspace = true readme = "README.md" [dependencies] -border-core = { version = "0.0.8", path = "../border-core" } -border-tch-agent = { version = "0.0.8", path = "../border-tch-agent", optional = true } +border-core = { version = "0.0.9", path = "../border-core" } +border-tch-agent = { version = "0.0.9", path = "../border-tch-agent", optional = true } serde = { workspace = true, features = ["derive"] } log = { workspace = true } anyhow = { workspace = true } diff --git a/border-py-gym-env/Cargo.toml b/border-py-gym-env/Cargo.toml index 54e48fc7..5516304c 100644 --- a/border-py-gym-env/Cargo.toml +++ b/border-py-gym-env/Cargo.toml @@ -10,7 +10,7 @@ license.workspace = true readme = "README.md" [dependencies] -border-core = { version = "0.0.8", path = "../border-core" } +border-core = { version = "0.0.9", path = "../border-core" } numpy = { workspace = true } pyo3 = { workspace = true, default-features = false, features = [ "auto-initialize", diff --git a/border-tch-agent/Cargo.toml b/border-tch-agent/Cargo.toml index 955a5cc5..e763004a 100644 --- a/border-tch-agent/Cargo.toml +++ b/border-tch-agent/Cargo.toml @@ -10,8 +10,8 @@ license.workspace = true readme = "README.md" [dependencies] -border-core = { version = "0.0.8", path = "../border-core" } -border-async-trainer = { version = "0.0.8", path = "../border-async-trainer", optional = true } +border-core = { version = "0.0.9", path = "../border-core" } +border-async-trainer = { version = "0.0.9", path = "../border-async-trainer", optional = true } serde = { workspace = true, features = ["derive"] } serde_yaml = { workspace = true } tensorboard-rs = { workspace = true } diff --git a/border-tensorboard/Cargo.toml b/border-tensorboard/Cargo.toml index 507704d6..c91628aa 100644 --- a/border-tensorboard/Cargo.toml +++ b/border-tensorboard/Cargo.toml @@ -10,6 +10,6 @@ license.workspace = true readme = "README.md" [dependencies] -border-core = { version = "0.0.8", path = "../border-core" } +border-core = { version = "0.0.9", path = "../border-core" } tensorboard-rs = { workspace = true } anyhow = { workspace = true } \ No newline at end of file diff --git a/examples/atari/dqn_atari/Cargo.toml b/examples/atari/dqn_atari/Cargo.toml index 9bc6de5b..48f2397e 100644 --- a/examples/atari/dqn_atari/Cargo.toml +++ b/examples/atari/dqn_atari/Cargo.toml @@ -11,11 +11,11 @@ clap = { version = "4.5.8", features = ["derive"] } env_logger = "0.8.2" numpy = "0.14.1" candle-core = { version = "0.8.4", feature = ["cuda", "cudnn"] } -border-candle-agent = { version = "0.0.8", path = "../../../border-candle-agent" } -border-core = { version = "0.0.8", path = "../../../border-core" } -border-tensorboard = { version = "0.0.8", path = "../../../border-tensorboard" } -border-mlflow-tracking = { version = "0.0.8", path = "../../../border-mlflow-tracking" } -border-atari-env = { version = "0.0.8", path = "../../../border-atari-env", features = ["candle"]} +border-candle-agent = { version = "0.0.9", path = "../../../border-candle-agent" } +border-core = { version = "0.0.9", path = "../../../border-core" } +border-tensorboard = { version = "0.0.9", path = "../../../border-tensorboard" } +border-mlflow-tracking = { version = "0.0.9", path = "../../../border-mlflow-tracking" } +border-atari-env = { version = "0.0.9", path = "../../../border-atari-env", features = ["candle"]} serde = "1.0.194" serde_yaml = "0.8.7" diff --git a/examples/atari/dqn_atari_async_tch/Cargo.toml b/examples/atari/dqn_atari_async_tch/Cargo.toml index f3726629..5d2927a9 100644 --- a/examples/atari/dqn_atari_async_tch/Cargo.toml +++ b/examples/atari/dqn_atari_async_tch/Cargo.toml @@ -11,12 +11,12 @@ clap = { version = "4.5.8", features = ["derive"] } env_logger = "0.8.2" numpy = "0.14.1" tch = { version = "0.16.1" } -border-tch-agent = { version = "0.0.8", path = "../../../border-tch-agent", features = ["border-async-trainer"]} -border-core = { version = "0.0.8", path = "../../../border-core" } -border-tensorboard = { version = "0.0.8", path = "../../../border-tensorboard" } -border-mlflow-tracking = { version = "0.0.8", path = "../../../border-mlflow-tracking" } -border-atari-env = { version = "0.0.8", path = "../../../border-atari-env", features = ["tch"]} -border-async-trainer = { version = "0.0.8", path = "../../../border-async-trainer" } +border-tch-agent = { version = "0.0.9", path = "../../../border-tch-agent", features = ["border-async-trainer"]} +border-core = { version = "0.0.9", path = "../../../border-core" } +border-tensorboard = { version = "0.0.9", path = "../../../border-tensorboard" } +border-mlflow-tracking = { version = "0.0.9", path = "../../../border-mlflow-tracking" } +border-atari-env = { version = "0.0.9", path = "../../../border-atari-env", features = ["tch"]} +border-async-trainer = { version = "0.0.9", path = "../../../border-async-trainer" } serde = "1.0.194" serde_yaml = "0.8.7" diff --git a/examples/atari/dqn_atari_tch/Cargo.toml b/examples/atari/dqn_atari_tch/Cargo.toml index f017b9e5..aba78bd9 100644 --- a/examples/atari/dqn_atari_tch/Cargo.toml +++ b/examples/atari/dqn_atari_tch/Cargo.toml @@ -11,11 +11,11 @@ clap = { version = "4.5.8", features = ["derive"] } env_logger = "0.8.2" numpy = "0.14.1" tch = { version = "0.16.1" } -border-tch-agent = { version = "0.0.8", path = "../../../border-tch-agent" } -border-core = { version = "0.0.8", path = "../../../border-core" } -border-tensorboard = { version = "0.0.8", path = "../../../border-tensorboard" } -border-mlflow-tracking = { version = "0.0.8", path = "../../../border-mlflow-tracking" } -border-atari-env = { version = "0.0.8", path = "../../../border-atari-env", features = ["tch"]} +border-tch-agent = { version = "0.0.9", path = "../../../border-tch-agent" } +border-core = { version = "0.0.9", path = "../../../border-core" } +border-tensorboard = { version = "0.0.9", path = "../../../border-tensorboard" } +border-mlflow-tracking = { version = "0.0.9", path = "../../../border-mlflow-tracking" } +border-atari-env = { version = "0.0.9", path = "../../../border-atari-env", features = ["tch"]} serde = "1.0.194" serde_yaml = "0.8.7" diff --git a/examples/d4rl/awac_pen/Cargo.toml b/examples/d4rl/awac_pen/Cargo.toml index cf4532ed..04ac03a8 100644 --- a/examples/d4rl/awac_pen/Cargo.toml +++ b/examples/d4rl/awac_pen/Cargo.toml @@ -11,13 +11,13 @@ clap = { version = "4.5.8", features = ["derive"] } env_logger = "0.8.2" numpy = "0.14.1" candle-core = { version = "0.8.4", feature = ["cuda", "cudnn"] } -border-minari = { version = "0.0.8", path = "../../../border-minari", features = [ +border-minari = { version = "0.0.9", path = "../../../border-minari", features = [ "candle", ] } -border-candle-agent = { version = "0.0.8", path = "../../../border-candle-agent" } -border-core = { version = "0.0.8", path = "../../../border-core" } -border-tensorboard = { version = "0.0.8", path = "../../../border-tensorboard" } -border-mlflow-tracking = { version = "0.0.8", path = "../../../border-mlflow-tracking" } +border-candle-agent = { version = "0.0.9", path = "../../../border-candle-agent" } +border-core = { version = "0.0.9", path = "../../../border-core" } +border-tensorboard = { version = "0.0.9", path = "../../../border-tensorboard" } +border-mlflow-tracking = { version = "0.0.9", path = "../../../border-mlflow-tracking" } serde = "1.0.194" [dev-dependencies] diff --git a/examples/d4rl/bc_pen/Cargo.toml b/examples/d4rl/bc_pen/Cargo.toml index 761dd397..ea3b4030 100644 --- a/examples/d4rl/bc_pen/Cargo.toml +++ b/examples/d4rl/bc_pen/Cargo.toml @@ -12,13 +12,13 @@ env_logger = "0.8.2" numpy = "0.14.1" candle-core = { version = "0.8.4", feature = ["cuda", "cudnn"] } candle-nn = { version = "0.8.4" } -border-minari = { version = "0.0.8", path = "../../../border-minari", features = [ +border-minari = { version = "0.0.9", path = "../../../border-minari", features = [ "candle", ] } -border-candle-agent = { version = "0.0.8", path = "../../../border-candle-agent" } -border-core = { version = "0.0.8", path = "../../../border-core" } -border-tensorboard = { version = "0.0.8", path = "../../../border-tensorboard" } -border-mlflow-tracking = { version = "0.0.8", path = "../../../border-mlflow-tracking" } +border-candle-agent = { version = "0.0.9", path = "../../../border-candle-agent" } +border-core = { version = "0.0.9", path = "../../../border-core" } +border-tensorboard = { version = "0.0.9", path = "../../../border-tensorboard" } +border-mlflow-tracking = { version = "0.0.9", path = "../../../border-mlflow-tracking" } serde = "1.0.194" [dev-dependencies] diff --git a/examples/d4rl/iql_pen/Cargo.toml b/examples/d4rl/iql_pen/Cargo.toml index cf4532ed..04ac03a8 100644 --- a/examples/d4rl/iql_pen/Cargo.toml +++ b/examples/d4rl/iql_pen/Cargo.toml @@ -11,13 +11,13 @@ clap = { version = "4.5.8", features = ["derive"] } env_logger = "0.8.2" numpy = "0.14.1" candle-core = { version = "0.8.4", feature = ["cuda", "cudnn"] } -border-minari = { version = "0.0.8", path = "../../../border-minari", features = [ +border-minari = { version = "0.0.9", path = "../../../border-minari", features = [ "candle", ] } -border-candle-agent = { version = "0.0.8", path = "../../../border-candle-agent" } -border-core = { version = "0.0.8", path = "../../../border-core" } -border-tensorboard = { version = "0.0.8", path = "../../../border-tensorboard" } -border-mlflow-tracking = { version = "0.0.8", path = "../../../border-mlflow-tracking" } +border-candle-agent = { version = "0.0.9", path = "../../../border-candle-agent" } +border-core = { version = "0.0.9", path = "../../../border-core" } +border-tensorboard = { version = "0.0.9", path = "../../../border-tensorboard" } +border-mlflow-tracking = { version = "0.0.9", path = "../../../border-mlflow-tracking" } serde = "1.0.194" [dev-dependencies] diff --git a/examples/gym/awac_pendulum/Cargo.toml b/examples/gym/awac_pendulum/Cargo.toml index 5e8e382e..1a332bba 100644 --- a/examples/gym/awac_pendulum/Cargo.toml +++ b/examples/gym/awac_pendulum/Cargo.toml @@ -10,13 +10,13 @@ clap = { version = "4.5.8", features = ["derive"] } env_logger = "0.8.2" numpy = "0.14.1" candle-core = { version = "0.8.4", feature = ["cuda", "cudnn"] } -border-py-gym-env = { version = "0.0.8", path = "../../../border-py-gym-env", features = [ +border-py-gym-env = { version = "0.0.9", path = "../../../border-py-gym-env", features = [ "candle", ] } -border-candle-agent = { version = "0.0.8", path = "../../../border-candle-agent" } -border-core = { version = "0.0.8", path = "../../../border-core" } -border-tensorboard = { version = "0.0.8", path = "../../../border-tensorboard" } -border-mlflow-tracking = { version = "0.0.8", path = "../../../border-mlflow-tracking" } +border-candle-agent = { version = "0.0.9", path = "../../../border-candle-agent" } +border-core = { version = "0.0.9", path = "../../../border-core" } +border-tensorboard = { version = "0.0.9", path = "../../../border-tensorboard" } +border-mlflow-tracking = { version = "0.0.9", path = "../../../border-mlflow-tracking" } serde = "1.0.194" [dev-dependencies] diff --git a/examples/gym/convert_policy/Cargo.toml b/examples/gym/convert_policy/Cargo.toml index 94ae8dbe..fe632f2f 100644 --- a/examples/gym/convert_policy/Cargo.toml +++ b/examples/gym/convert_policy/Cargo.toml @@ -9,11 +9,11 @@ anyhow = "1.0.38" clap = { version = "4.5.8", features = ["derive"] } env_logger = "0.8.2" numpy = "0.14.1" -border-policy-no-backend = { version = "0.0.8", path = "../../../border-policy-no-backend", features = [ +border-policy-no-backend = { version = "0.0.9", path = "../../../border-policy-no-backend", features = [ "tch", ] } -border-tch-agent = { version = "0.0.8", path = "../../../border-tch-agent" } -border-core = { version = "0.0.8", path = "../../../border-core" } +border-tch-agent = { version = "0.0.9", path = "../../../border-tch-agent" } +border-core = { version = "0.0.9", path = "../../../border-core" } serde = "1.0.194" tch = "0.16.0" bincode = "1.3.3" diff --git a/examples/gym/dqn_cartpole/Cargo.toml b/examples/gym/dqn_cartpole/Cargo.toml index 3083dd16..1b8d0887 100644 --- a/examples/gym/dqn_cartpole/Cargo.toml +++ b/examples/gym/dqn_cartpole/Cargo.toml @@ -10,13 +10,13 @@ clap = { version = "4.5.8", features = ["derive"] } env_logger = "0.8.2" numpy = "0.14.1" candle-core = { version = "0.8.4", feature = ["cuda", "cudnn"] } -border-py-gym-env = { version = "0.0.8", path = "../../../border-py-gym-env", features = [ +border-py-gym-env = { version = "0.0.9", path = "../../../border-py-gym-env", features = [ "candle", ] } -border-candle-agent = { version = "0.0.8", path = "../../../border-candle-agent" } -border-core = { version = "0.0.8", path = "../../../border-core" } -border-tensorboard = { version = "0.0.8", path = "../../../border-tensorboard" } -border-mlflow-tracking = { version = "0.0.8", path = "../../../border-mlflow-tracking" } +border-candle-agent = { version = "0.0.9", path = "../../../border-candle-agent" } +border-core = { version = "0.0.9", path = "../../../border-core" } +border-tensorboard = { version = "0.0.9", path = "../../../border-tensorboard" } +border-mlflow-tracking = { version = "0.0.9", path = "../../../border-mlflow-tracking" } serde = "1.0.194" [dev-dependencies] diff --git a/examples/gym/dqn_cartpole_tch/Cargo.toml b/examples/gym/dqn_cartpole_tch/Cargo.toml index 5f186bfd..7517abfe 100644 --- a/examples/gym/dqn_cartpole_tch/Cargo.toml +++ b/examples/gym/dqn_cartpole_tch/Cargo.toml @@ -10,13 +10,13 @@ clap = { version = "4.5.8", features = ["derive"] } env_logger = "0.8.2" numpy = "0.14.1" tch = "0.16.0" -border-py-gym-env = { version = "0.0.8", path = "../../../border-py-gym-env", features = [ +border-py-gym-env = { version = "0.0.9", path = "../../../border-py-gym-env", features = [ "tch", ] } -border-tch-agent = { version = "0.0.8", path = "../../../border-tch-agent" } -border-core = { version = "0.0.8", path = "../../../border-core" } -border-tensorboard = { version = "0.0.8", path = "../../../border-tensorboard" } -border-mlflow-tracking = { version = "0.0.8", path = "../../../border-mlflow-tracking" } +border-tch-agent = { version = "0.0.9", path = "../../../border-tch-agent" } +border-core = { version = "0.0.9", path = "../../../border-core" } +border-tensorboard = { version = "0.0.9", path = "../../../border-tensorboard" } +border-mlflow-tracking = { version = "0.0.9", path = "../../../border-mlflow-tracking" } serde = "1.0.194" [dev-dependencies] diff --git a/examples/gym/pendulum_std/Cargo.toml b/examples/gym/pendulum_std/Cargo.toml index 9b41c442..b91ec096 100644 --- a/examples/gym/pendulum_std/Cargo.toml +++ b/examples/gym/pendulum_std/Cargo.toml @@ -9,11 +9,11 @@ anyhow = "1.0.38" clap = { version = "4.5.8", features = ["derive"] } env_logger = "0.8.2" numpy = "0.14.1" -border-policy-no-backend = { version = "0.0.8", path = "../../../border-policy-no-backend", features = [ +border-policy-no-backend = { version = "0.0.9", path = "../../../border-policy-no-backend", features = [ "tch", ] } -border-core = { version = "0.0.8", path = "../../../border-core" } -border-py-gym-env = { version = "0.0.8", path = "../../../border-py-gym-env" } +border-core = { version = "0.0.9", path = "../../../border-core" } +border-py-gym-env = { version = "0.0.9", path = "../../../border-py-gym-env" } serde = "1.0.194" tch = "0.16.0" bincode = "1.3.3" diff --git a/examples/gym/sac_fetch_reach/Cargo.toml b/examples/gym/sac_fetch_reach/Cargo.toml index 54bb6d51..7b38c67e 100644 --- a/examples/gym/sac_fetch_reach/Cargo.toml +++ b/examples/gym/sac_fetch_reach/Cargo.toml @@ -10,13 +10,13 @@ clap = { version = "4.5.8", features = ["derive"] } env_logger = "0.8.2" numpy = "0.14.1" candle-core = { version = "0.8.4", feature = ["cuda", "cudnn"] } -border-py-gym-env = { version = "0.0.8", path = "../../../border-py-gym-env", features = [ +border-py-gym-env = { version = "0.0.9", path = "../../../border-py-gym-env", features = [ "candle", ] } -border-candle-agent = { version = "0.0.8", path = "../../../border-candle-agent" } -border-core = { version = "0.0.8", path = "../../../border-core" } -border-tensorboard = { version = "0.0.8", path = "../../../border-tensorboard" } -border-mlflow-tracking = { version = "0.0.8", path = "../../../border-mlflow-tracking" } +border-candle-agent = { version = "0.0.9", path = "../../../border-candle-agent" } +border-core = { version = "0.0.9", path = "../../../border-core" } +border-tensorboard = { version = "0.0.9", path = "../../../border-tensorboard" } +border-mlflow-tracking = { version = "0.0.9", path = "../../../border-mlflow-tracking" } serde = "1.0.194" [dev-dependencies] diff --git a/examples/gym/sac_pendulum/Cargo.toml b/examples/gym/sac_pendulum/Cargo.toml index 80fa2f9c..a8dd6faf 100644 --- a/examples/gym/sac_pendulum/Cargo.toml +++ b/examples/gym/sac_pendulum/Cargo.toml @@ -10,13 +10,13 @@ clap = { version = "4.5.8", features = ["derive"] } env_logger = "0.8.2" numpy = "0.14.1" candle-core = { version = "0.8.4", feature = ["cuda", "cudnn"] } -border-py-gym-env = { version = "0.0.8", path = "../../../border-py-gym-env", features = [ +border-py-gym-env = { version = "0.0.9", path = "../../../border-py-gym-env", features = [ "candle", ] } -border-candle-agent = { version = "0.0.8", path = "../../../border-candle-agent" } -border-core = { version = "0.0.8", path = "../../../border-core" } -border-tensorboard = { version = "0.0.8", path = "../../../border-tensorboard" } -border-mlflow-tracking = { version = "0.0.8", path = "../../../border-mlflow-tracking" } +border-candle-agent = { version = "0.0.9", path = "../../../border-candle-agent" } +border-core = { version = "0.0.9", path = "../../../border-core" } +border-tensorboard = { version = "0.0.9", path = "../../../border-tensorboard" } +border-mlflow-tracking = { version = "0.0.9", path = "../../../border-mlflow-tracking" } serde = "1.0.194" [dev-dependencies] diff --git a/examples/gym/sac_pendulum_tch/Cargo.toml b/examples/gym/sac_pendulum_tch/Cargo.toml index 5f186bfd..7517abfe 100644 --- a/examples/gym/sac_pendulum_tch/Cargo.toml +++ b/examples/gym/sac_pendulum_tch/Cargo.toml @@ -10,13 +10,13 @@ clap = { version = "4.5.8", features = ["derive"] } env_logger = "0.8.2" numpy = "0.14.1" tch = "0.16.0" -border-py-gym-env = { version = "0.0.8", path = "../../../border-py-gym-env", features = [ +border-py-gym-env = { version = "0.0.9", path = "../../../border-py-gym-env", features = [ "tch", ] } -border-tch-agent = { version = "0.0.8", path = "../../../border-tch-agent" } -border-core = { version = "0.0.8", path = "../../../border-core" } -border-tensorboard = { version = "0.0.8", path = "../../../border-tensorboard" } -border-mlflow-tracking = { version = "0.0.8", path = "../../../border-mlflow-tracking" } +border-tch-agent = { version = "0.0.9", path = "../../../border-tch-agent" } +border-core = { version = "0.0.9", path = "../../../border-core" } +border-tensorboard = { version = "0.0.9", path = "../../../border-tensorboard" } +border-mlflow-tracking = { version = "0.0.9", path = "../../../border-mlflow-tracking" } serde = "1.0.194" [dev-dependencies] From 397e9a9be7325a3031b9bb63c36fd86d876cb724 Mon Sep 17 00:00:00 2001 From: taku-y Date: Sun, 10 Aug 2025 16:01:18 +0900 Subject: [PATCH 02/23] Bump the rust version --- .github/workflows/ci.yml | 2 +- Cargo.toml | 2 +- examples/atari/dqn_atari/Cargo.toml | 2 +- examples/atari/dqn_atari_async_tch/Cargo.toml | 2 +- examples/atari/dqn_atari_tch/Cargo.toml | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 647205b0..cdb8b8d6 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -14,7 +14,7 @@ jobs: strategy: matrix: os: [ubuntu-latest, windows-latest, macOS-latest] - rust: [1.84.0] + rust: [1.85.0] python-version: ["3.11"] steps: - uses: actions/checkout@v2 diff --git a/Cargo.toml b/Cargo.toml index 498c2b07..cc3696e1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,7 +17,7 @@ exclude = ["docker/", "examples/"] [workspace.package] version = "0.0.9" edition = "2018" -rust-version = "1.84" +rust-version = "1.85" description = "Reinforcement learning library" repository = "https://github.com/laboroai/border" keywords = ["reinforcement", "learning", "rl"] diff --git a/examples/atari/dqn_atari/Cargo.toml b/examples/atari/dqn_atari/Cargo.toml index 48f2397e..041ab00a 100644 --- a/examples/atari/dqn_atari/Cargo.toml +++ b/examples/atari/dqn_atari/Cargo.toml @@ -2,7 +2,7 @@ name = "dqn_atari" version = "0.1.0" edition = "2018" -rust-version = "1.84" +rust-version = "1.85" [dependencies] log = "0.4" diff --git a/examples/atari/dqn_atari_async_tch/Cargo.toml b/examples/atari/dqn_atari_async_tch/Cargo.toml index 5d2927a9..b5b4b22f 100644 --- a/examples/atari/dqn_atari_async_tch/Cargo.toml +++ b/examples/atari/dqn_atari_async_tch/Cargo.toml @@ -2,7 +2,7 @@ name = "dqn_atari_async_tch" version = "0.1.0" edition = "2018" -rust-version = "1.84" +rust-version = "1.85" [dependencies] log = "0.4" diff --git a/examples/atari/dqn_atari_tch/Cargo.toml b/examples/atari/dqn_atari_tch/Cargo.toml index aba78bd9..3ec1381e 100644 --- a/examples/atari/dqn_atari_tch/Cargo.toml +++ b/examples/atari/dqn_atari_tch/Cargo.toml @@ -2,7 +2,7 @@ name = "dqn_atari_tch" version = "0.1.0" edition = "2018" -rust-version = "1.84" +rust-version = "1.85" [dependencies] log = "0.4" From 01ea3bcc70f512140b7d53dde58438a79c208573 Mon Sep 17 00:00:00 2001 From: taku-y Date: Sun, 10 Aug 2025 08:34:14 +0000 Subject: [PATCH 03/23] Change names of traits RepalyBufferBase and ExperienceBufferBase --- border-async-trainer/src/actor/base.rs | 8 ++--- .../src/actor_manager/base.rs | 6 ++-- .../src/async_trainer/base.rs | 34 +++++++++---------- border-async-trainer/src/messages.rs | 2 +- .../src/replay_buffer_proxy.rs | 10 +++--- border-async-trainer/src/util.rs | 4 +-- border-atari-env/src/util/test.rs | 4 +-- border-candle-agent/src/awac/base.rs | 6 ++-- border-candle-agent/src/bc/base.rs | 6 ++-- border-candle-agent/src/dqn/base.rs | 8 ++--- border-candle-agent/src/iql/base.rs | 6 ++-- border-candle-agent/src/sac/base.rs | 6 ++-- border-core/src/base.rs | 2 +- border-core/src/base/agent.rs | 4 +-- border-core/src/base/replay_buffer.rs | 10 +++--- border-core/src/dummy.rs | 2 +- border-core/src/evaluator.rs | 6 ++-- .../src/evaluator/default_evaluator.rs | 4 +-- border-core/src/generic_replay_buffer/base.rs | 6 ++-- border-core/src/lib.rs | 22 ++++++------ border-core/src/record/base.rs | 2 +- border-core/src/record/buffered_recorder.rs | 12 +++---- border-core/src/record/null_recorder.rs | 10 +++--- border-core/src/record/recorder.rs | 6 ++-- border-core/src/trainer.rs | 10 +++--- border-core/src/trainer/sampler.rs | 6 ++-- border-minari/src/dataset.rs | 2 +- border-minari/src/evaluator.rs | 4 +-- border-mlflow-tracking/src/client.rs | 4 +-- border-mlflow-tracking/src/recorder.rs | 10 +++--- border-tch-agent/src/dqn/base.rs | 8 ++--- border-tch-agent/src/iqn/base.rs | 6 ++-- border-tch-agent/src/sac/base.rs | 8 ++--- border-tensorboard/src/lib.rs | 8 ++--- examples/atari/dqn_atari/src/main.rs | 2 +- examples/atari/dqn_atari_tch/src/main.rs | 2 +- examples/d4rl/awac_pen/src/main.rs | 6 ++-- examples/d4rl/bc_pen/src/main.rs | 6 ++-- examples/d4rl/iql_pen/src/main.rs | 6 ++-- examples/gym/awac_pendulum/src/main.rs | 2 +- examples/gym/convert_policy/src/main.rs | 2 +- examples/gym/dqn_cartpole/src/main.rs | 2 +- examples/gym/dqn_cartpole_tch/src/main.rs | 2 +- examples/gym/sac_fetch_reach/src/main.rs | 2 +- examples/gym/sac_pendulum/src/main.rs | 2 +- examples/gym/sac_pendulum_tch/src/main.rs | 2 +- 46 files changed, 144 insertions(+), 144 deletions(-) diff --git a/border-async-trainer/src/actor/base.rs b/border-async-trainer/src/actor/base.rs index cae4412a..8c2fc432 100644 --- a/border-async-trainer/src/actor/base.rs +++ b/border-async-trainer/src/actor/base.rs @@ -1,6 +1,6 @@ use crate::{ActorStat, PushedItemMessage, ReplayBufferProxy, ReplayBufferProxyConfig, SyncModel}; use border_core::{ - Agent, Configurable, Env, ExperienceBufferBase, ReplayBufferBase, Sampler, StepProcessor, + Agent, Configurable, Env, ExperienceBuffer, ReplayBuffer, Sampler, StepProcessor, }; use crossbeam_channel::Sender; use log::{debug, info}; @@ -21,7 +21,7 @@ use std::{ /// B-->|Env::Obs|A /// B-->|Step<E: Env>|C[StepProcessor] /// end -/// C-->|ReplayBufferBase::PushedItem|F[ReplayBufferProxy] +/// C-->|ReplayBuffer::PushedItem|F[ReplayBufferProxy] /// ``` /// /// In [`Actor`], an [`Agent`] runs on an [`Env`] and generates [`Step`] objects. @@ -41,7 +41,7 @@ where A: Agent + Configurable + SyncModel + 'static, E: Env, P: StepProcessor, - R: ExperienceBufferBase + ReplayBufferBase, + R: ExperienceBuffer + ReplayBuffer, { /// Stops sampling process if this field is set to `true`. id: usize, @@ -60,7 +60,7 @@ where A: Agent + Configurable + SyncModel + 'static, E: Env, P: StepProcessor, - R: ExperienceBufferBase + ReplayBufferBase, + R: ExperienceBuffer + ReplayBuffer, { pub fn build( id: usize, diff --git a/border-async-trainer/src/actor_manager/base.rs b/border-async-trainer/src/actor_manager/base.rs index 7e6ef59c..f1d6ad20 100644 --- a/border-async-trainer/src/actor_manager/base.rs +++ b/border-async-trainer/src/actor_manager/base.rs @@ -2,7 +2,7 @@ use crate::{ Actor, ActorManagerConfig, ActorStat, PushedItemMessage, ReplayBufferProxyConfig, SyncModel, }; use border_core::{ - Agent, Configurable, Env, ExperienceBufferBase, ReplayBufferBase, StepProcessor, + Agent, Configurable, Env, ExperienceBuffer, ReplayBuffer, StepProcessor, }; use crossbeam_channel::{bounded, /*unbounded,*/ Receiver, Sender}; use log::info; @@ -25,7 +25,7 @@ where A: Agent + Configurable + SyncModel, E: Env, P: StepProcessor, - R: ExperienceBufferBase + ReplayBufferBase, + R: ExperienceBuffer + ReplayBuffer, { /// Configurations of [`Agent`]s. agent_configs: Vec, @@ -72,7 +72,7 @@ where A: Agent + Configurable + SyncModel + 'static, E: Env, P: StepProcessor, - R: ExperienceBufferBase + Send + 'static + ReplayBufferBase, + R: ExperienceBuffer + Send + 'static + ReplayBuffer, A::Config: Send + 'static, E::Config: Send + 'static, P::Config: Send + 'static, diff --git a/border-async-trainer/src/async_trainer/base.rs b/border-async-trainer/src/async_trainer/base.rs index 49f5ef89..b0e15deb 100644 --- a/border-async-trainer/src/async_trainer/base.rs +++ b/border-async-trainer/src/async_trainer/base.rs @@ -2,7 +2,7 @@ use crate::{AsyncTrainStat, AsyncTrainerConfig, PushedItemMessage, SyncModel}; use anyhow::Result; use border_core::{ record::{Record, RecordValue::Scalar, Recorder}, - Agent, Configurable, Env, Evaluator, ExperienceBufferBase, ReplayBufferBase, + Agent, Configurable, Env, Evaluator, ExperienceBuffer, ReplayBuffer, }; use crossbeam_channel::{Receiver, Sender}; use log::{debug, info}; @@ -21,7 +21,7 @@ use std::{ /// ```mermaid /// flowchart LR /// subgraph ActorManager -/// E[Actor]-->|ReplayBufferBase::PushedItem|H[ReplayBufferProxy] +/// E[Actor]-->|ReplayBuffer::PushedItem|H[ReplayBufferProxy] /// F[Actor]-->H /// G[Actor]-->H /// end @@ -31,36 +31,36 @@ use std::{ /// /// subgraph I[AsyncTrainer] /// H-->|PushedItemMessage|J[ReplayBuffer] -/// J-->|ReplayBufferBase::Batch|K[Agent] +/// J-->|ReplayBuffer::Batch|K[Agent] /// end /// ``` /// /// * The [`Agent`] in [`AsyncTrainer`] (left) is trained with batches -/// of type [`ReplayBufferBase::Batch`], which are taken from the replay buffer. +/// of type [`ReplayBuffer::Batch`], which are taken from the replay buffer. /// * The model parameters of the [`Agent`] in [`AsyncTrainer`] are wrapped in /// [`SyncModel::ModelInfo`] and periodically sent to the [`Agent`]s in [`Actor`]s. /// [`Agent`] must implement [`SyncModel`] to synchronize the model parameters. /// * In [`ActorManager`] (right), [`Actor`]s sample transitions, which have type -/// [`ReplayBufferBase::Item`], and push the transitions into +/// [`ReplayBuffer::Item`], and push the transitions into /// [`ReplayBufferProxy`]. -/// * [`ReplayBufferProxy`] has a type parameter of [`ReplayBufferBase`] and the proxy accepts -/// [`ReplayBufferBase::Item`]. +/// * [`ReplayBufferProxy`] has a type parameter of [`ReplayBuffer`] and the proxy accepts +/// [`ReplayBuffer::Item`]. /// * The proxy sends the transitions into the replay buffer in the [`AsyncTrainer`]. /// /// [`ActorManager`]: crate::ActorManager /// [`Actor`]: crate::Actor -/// [`ReplayBufferBase::Item`]: border_core::ReplayBufferBase::PushedItem -/// [`ReplayBufferBase::Batch`]: border_core::ReplayBufferBase::PushedBatch +/// [`ReplayBuffer::Item`]: border_core::ReplayBuffer::PushedItem +/// [`ReplayBuffer::Batch`]: border_core::ReplayBuffer::PushedBatch /// [`ReplayBufferProxy`]: crate::ReplayBufferProxy -/// [`ReplayBufferBase`]: border_core::ReplayBufferBase +/// [`ReplayBuffer`]: border_core::ReplayBuffer /// [`SyncModel::ModelInfo`]: crate::SyncModel::ModelInfo /// [`Agent`]: border_core::Agent pub struct AsyncTrainer where A: Agent + Configurable + SyncModel, E: Env, - // R: ReplayBufferBase + Sync + Send + 'static, - R: ExperienceBufferBase + ReplayBufferBase, + // R: ReplayBuffer + Sync + Send + 'static, + R: ExperienceBuffer + ReplayBuffer, R::Item: Send + 'static, { /// Configuration of [`Env`]. Note that it is used only for evaluation, not for training. @@ -130,8 +130,8 @@ impl AsyncTrainer where A: Agent + Configurable + SyncModel + 'static, E: Env, - // R: ReplayBufferBase + Sync + Send + 'static, - R: ExperienceBufferBase + ReplayBufferBase, + // R: ReplayBuffer + Sync + Send + 'static, + R: ExperienceBuffer + ReplayBuffer, R::Item: Send + 'static, { /// Creates [`AsyncTrainer`]. @@ -231,7 +231,7 @@ where ) -> Result<()> where E: Env, - R: ReplayBufferBase, + R: ReplayBuffer, D: Evaluator, { // Evaluation @@ -288,14 +288,14 @@ where /// In the training loop, the following values will be pushed into the given recorder: /// /// * `samples_total` - Total number of samples pushed into the replay buffer. - /// Here, a "sample" is an item in [`ExperienceBufferBase::Item`]. + /// Here, a "sample" is an item in [`ExperienceBuffer::Item`]. /// * `opt_steps_per_sec` - The number of optimization steps per second. /// * `samples_per_sec` - The number of samples per second. /// * `samples_per_opt_steps` - The number of samples per optimization step. /// /// These values will typically be monitored with tensorboard. /// - /// [`ExperienceBufferBase::Item`]: border_core::ExperienceBufferBase::Item + /// [`ExperienceBuffer::Item`]: border_core::ExperienceBuffer::Item pub fn train( &mut self, recorder: &mut Box>, diff --git a/border-async-trainer/src/messages.rs b/border-async-trainer/src/messages.rs index 32070a91..dd26a6ed 100644 --- a/border-async-trainer/src/messages.rs +++ b/border-async-trainer/src/messages.rs @@ -1,4 +1,4 @@ -/// Message containing a [`ReplayBufferBase`](border_core::ReplayBufferBase)`::Item`. +/// Message containing a [`ReplayBuffer`](border_core::ReplayBuffer)`::Item`. /// /// It will be sent from [`Actor`](crate::Actor) to [`ActorManager`](crate::ActorManager). pub struct PushedItemMessage { diff --git a/border-async-trainer/src/replay_buffer_proxy.rs b/border-async-trainer/src/replay_buffer_proxy.rs index 263c5beb..1a34758e 100644 --- a/border-async-trainer/src/replay_buffer_proxy.rs +++ b/border-async-trainer/src/replay_buffer_proxy.rs @@ -1,6 +1,6 @@ use crate::PushedItemMessage; use anyhow::Result; -use border_core::{ExperienceBufferBase, ReplayBufferBase}; +use border_core::{ExperienceBuffer, ReplayBuffer}; use crossbeam_channel::Sender; use std::marker::PhantomData; @@ -14,7 +14,7 @@ pub struct ReplayBufferProxyConfig { } /// A wrapper of replay buffer for asynchronous trainer. -pub struct ReplayBufferProxy { +pub struct ReplayBufferProxy { id: usize, /// Sender of [PushedItemMessage]. @@ -29,7 +29,7 @@ pub struct ReplayBufferProxy { phantom: PhantomData, } -impl ReplayBufferProxy { +impl ReplayBufferProxy { pub fn build_with_sender( id: usize, config: &ReplayBufferProxyConfig, @@ -46,7 +46,7 @@ impl ReplayBufferProxy { } } -impl ExperienceBufferBase for ReplayBufferProxy { +impl ExperienceBuffer for ReplayBufferProxy { type Item = R::Item; fn push(&mut self, tr: Self::Item) -> Result<()> { @@ -76,7 +76,7 @@ impl ExperienceBufferBase for ReplayBufferProxy { } } -impl ReplayBufferBase for ReplayBufferProxy { +impl ReplayBuffer for ReplayBufferProxy { type Config = ReplayBufferProxyConfig; type Batch = R::Batch; diff --git a/border-async-trainer/src/util.rs b/border-async-trainer/src/util.rs index eff40041..5c4a6d6e 100644 --- a/border-async-trainer/src/util.rs +++ b/border-async-trainer/src/util.rs @@ -3,7 +3,7 @@ use crate::{ actor_stats_fmt, ActorManager, ActorManagerConfig, AsyncTrainer, AsyncTrainerConfig, SyncModel, }; use border_core::{ - record::Recorder, Agent, Configurable, Env, Evaluator, ExperienceBufferBase, ReplayBufferBase, + record::Recorder, Agent, Configurable, Env, Evaluator, ExperienceBuffer, ReplayBuffer, StepProcessor, }; use crossbeam_channel::unbounded; @@ -42,7 +42,7 @@ pub fn train_async( ) where A: Agent + Configurable + SyncModel + 'static, E: Env, - R: ExperienceBufferBase + Send + 'static + ReplayBufferBase, + R: ExperienceBuffer + Send + 'static + ReplayBuffer, S: StepProcessor, A::Config: Send + 'static, E::Config: Send + 'static, diff --git a/border-atari-env/src/util/test.rs b/border-atari-env/src/util/test.rs index 350d9af0..2efc83a2 100644 --- a/border-atari-env/src/util/test.rs +++ b/border-atari-env/src/util/test.rs @@ -7,7 +7,7 @@ use anyhow::Result; use border_core::{ generic_replay_buffer::{BatchBase, SimpleReplayBuffer}, record::Record, - Agent as Agent_, Configurable, Policy, ReplayBufferBase, + Agent as Agent_, Configurable, Policy, ReplayBuffer as ReplayBuffer_, }; use serde::Deserialize; use std::ptr::copy; @@ -164,7 +164,7 @@ impl Configurable for RandomAgent { } } -impl Agent_ for RandomAgent { +impl Agent_ for RandomAgent { fn train(&mut self) { self.train = true; } diff --git a/border-candle-agent/src/awac/base.rs b/border-candle-agent/src/awac/base.rs index 23dc7241..564938ca 100644 --- a/border-candle-agent/src/awac/base.rs +++ b/border-candle-agent/src/awac/base.rs @@ -9,7 +9,7 @@ use crate::{ use anyhow::Result; use border_core::{ record::{Record, RecordValue}, - Agent, Configurable, Env, Policy, ReplayBufferBase, TransitionBatch, + Agent, Configurable, Env, Policy, ReplayBuffer, TransitionBatch, }; use candle_core::{Device, Tensor, D}; use candle_nn::{loss::mse, ops::softmax}; @@ -53,7 +53,7 @@ where E: Env, Q: SubModel2, P: SubModel1, - R: ReplayBufferBase, + R: ReplayBuffer, E::Obs: Into + Into, E::Act: Into + Into, Q::Input2: From + Into, @@ -282,7 +282,7 @@ where E: Env + 'static, Q: SubModel2 + 'static, P: SubModel1 + 'static, - R: ReplayBufferBase + 'static, + R: ReplayBuffer + 'static, E::Obs: Into + Into, E::Act: Into + Into + From, Q::Input2: From + Into, diff --git a/border-candle-agent/src/bc/base.rs b/border-candle-agent/src/bc/base.rs index 93f60b6d..3060c8ce 100644 --- a/border-candle-agent/src/bc/base.rs +++ b/border-candle-agent/src/bc/base.rs @@ -4,7 +4,7 @@ use crate::{model::SubModel1, util::OutDim}; use anyhow::Result; use border_core::{ record::{Record, RecordValue}, - Agent, Configurable, Env, Policy, ReplayBufferBase, TransitionBatch, + Agent, Configurable, Env, Policy, ReplayBuffer, TransitionBatch, }; use candle_core::{shape::D, DType, Device, Tensor}; use candle_nn::loss::mse; @@ -92,7 +92,7 @@ impl Agent for Bc where E: Env, P: SubModel1, - R: ReplayBufferBase, + R: ReplayBuffer, E::Obs: Into, E::Act: From, P::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone, @@ -157,7 +157,7 @@ impl Bc where E: Env, P: SubModel1, - R: ReplayBufferBase, + R: ReplayBuffer, P::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone, R::Batch: TransitionBatch, ::ObsBatch: Into, diff --git a/border-candle-agent/src/dqn/base.rs b/border-candle-agent/src/dqn/base.rs index 894a061e..9186c0c2 100644 --- a/border-candle-agent/src/dqn/base.rs +++ b/border-candle-agent/src/dqn/base.rs @@ -7,7 +7,7 @@ use crate::{ use anyhow::Result; use border_core::{ record::{Record, RecordValue}, - Agent, Configurable, Env, Policy, ReplayBufferBase, TransitionBatch, + Agent, Configurable, Env, Policy, ReplayBuffer, TransitionBatch, }; use candle_core::{shape::D, DType, Device, Tensor}; use candle_nn::loss::mse; @@ -50,7 +50,7 @@ impl Dqn where E: Env, Q: SubModel1, - R: ReplayBufferBase, + R: ReplayBuffer, Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone, R::Batch: TransitionBatch, ::ObsBatch: Into, @@ -280,7 +280,7 @@ impl Agent for Dqn where E: Env + 'static, Q: SubModel1 + 'static, - R: ReplayBufferBase + 'static, + R: ReplayBuffer + 'static, E::Obs: Into, E::Act: From, Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone, @@ -367,7 +367,7 @@ impl SyncModel for Dqn where E: Env, Q: SubModel1, - R: ReplayBufferBase, + R: ReplayBuffer, E::Obs: Into, E::Act: From, Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone, diff --git a/border-candle-agent/src/iql/base.rs b/border-candle-agent/src/iql/base.rs index 0738a3ab..85b13eff 100644 --- a/border-candle-agent/src/iql/base.rs +++ b/border-candle-agent/src/iql/base.rs @@ -9,7 +9,7 @@ use crate::{ use anyhow::Result; use border_core::{ record::{Record, RecordValue}, - Agent, Configurable, Env, Policy, ReplayBufferBase, TransitionBatch, + Agent, Configurable, Env, Policy, ReplayBuffer, TransitionBatch, }; use candle_core::{Device, Tensor, D}; use candle_nn::{loss::mse, ops::softmax}; @@ -59,7 +59,7 @@ where Q: SubModel2, P: SubModel1, V: SubModel1, - R: ReplayBufferBase, + R: ReplayBuffer, E::Obs: Into, E::Act: Into, A: Clone, @@ -259,7 +259,7 @@ where Q: SubModel2 + 'static, P: SubModel1 + 'static, V: SubModel1 + 'static, - R: ReplayBufferBase + 'static, + R: ReplayBuffer + 'static, E::Obs: Into + Into, E::Act: Into + From, O: 'static, diff --git a/border-candle-agent/src/sac/base.rs b/border-candle-agent/src/sac/base.rs index 214f7190..51b2be61 100644 --- a/border-candle-agent/src/sac/base.rs +++ b/border-candle-agent/src/sac/base.rs @@ -9,7 +9,7 @@ use crate::{ use anyhow::Result; use border_core::{ record::{Record, RecordValue}, - Agent, Configurable, Env, Policy, ReplayBufferBase, TransitionBatch, + Agent, Configurable, Env, Policy, ReplayBuffer, TransitionBatch, }; use candle_core::{Device, Tensor, D}; use candle_nn::loss::mse; @@ -50,7 +50,7 @@ where E: Env, Q: SubModel2, P: SubModel1, - R: ReplayBufferBase, + R: ReplayBuffer, E::Obs: Into + Into, E::Act: Into, Q::Input2: From, @@ -215,7 +215,7 @@ where E: Env + 'static, Q: SubModel2 + 'static, P: SubModel1 + 'static, - R: ReplayBufferBase + 'static, + R: ReplayBuffer + 'static, E::Obs: Into + Into, E::Act: Into + From, Q::Input2: From, diff --git a/border-core/src/base.rs b/border-core/src/base.rs index c73ed89f..f4585ca6 100644 --- a/border-core/src/base.rs +++ b/border-core/src/base.rs @@ -15,7 +15,7 @@ pub use agent::Agent; pub use batch::TransitionBatch; pub use env::Env; pub use policy::{Configurable, Policy}; -pub use replay_buffer::{ExperienceBufferBase, NullReplayBuffer, ReplayBufferBase}; +pub use replay_buffer::{ExperienceBuffer, NullReplayBuffer, ReplayBuffer}; use std::fmt::Debug; pub use step::{Info, Step, StepProcessor}; diff --git a/border-core/src/base/agent.rs b/border-core/src/base/agent.rs index ad2a8ecc..c1e54bb8 100644 --- a/border-core/src/base/agent.rs +++ b/border-core/src/base/agent.rs @@ -3,7 +3,7 @@ //! The [`Agent`] trait extends [`Policy`] with training capabilities, allowing the policy to //! learn from interactions with the environment. It provides methods for training, evaluation, //! parameter optimization, and model persistence. -use super::{Env, Policy, ReplayBufferBase}; +use super::{Env, Policy, ReplayBuffer}; use crate::record::Record; use anyhow::Result; use std::path::{Path, PathBuf}; @@ -21,7 +21,7 @@ use std::path::{Path, PathBuf}; /// /// During training, the agent uses a replay buffer to store and sample experiences, /// which are then used to update the policy's parameters through optimization steps. -pub trait Agent: Policy { +pub trait Agent: Policy { /// Switches the agent to training mode. /// /// In training mode, the policy may become stochastic to facilitate exploration. diff --git a/border-core/src/base/replay_buffer.rs b/border-core/src/base/replay_buffer.rs index 680be7a6..72aa9005 100644 --- a/border-core/src/base/replay_buffer.rs +++ b/border-core/src/base/replay_buffer.rs @@ -22,7 +22,7 @@ use anyhow::Result; /// items: Vec, /// } /// -/// impl ExperienceBufferBase for SimpleBuffer { +/// impl ExperienceBuffer for SimpleBuffer { /// type Item = T; /// /// fn push(&mut self, tr: T) -> Result<()> { @@ -35,7 +35,7 @@ use anyhow::Result; /// } /// } /// ``` -pub trait ExperienceBufferBase { +pub trait ExperienceBuffer { /// The type of items stored in the buffer. /// /// This can be any type that represents an experience or transition @@ -64,14 +64,14 @@ pub trait ExperienceBufferBase { /// Interface for replay buffers that generate batches for training. /// /// This trait provides functionality for sampling batches of experiences -/// for training agents. It is independent of [`ExperienceBufferBase`] and +/// for training agents. It is independent of [`ExperienceBuffer`] and /// focuses solely on the batch generation process. /// /// # Associated Types /// /// * `Config` - Configuration parameters for the buffer /// * `Batch` - The type of batch generated for training -pub trait ReplayBufferBase { +pub trait ReplayBuffer { /// Configuration parameters for the replay buffer. /// /// This type must implement `Clone` to support building multiple instances @@ -131,7 +131,7 @@ pub trait ReplayBufferBase { /// This struct is used as a placeholder when a replay buffer is not needed. pub struct NullReplayBuffer; -impl ReplayBufferBase for NullReplayBuffer { +impl ReplayBuffer for NullReplayBuffer { type Batch = (); type Config = (); diff --git a/border-core/src/dummy.rs b/border-core/src/dummy.rs index 9de88b23..67b55b17 100644 --- a/border-core/src/dummy.rs +++ b/border-core/src/dummy.rs @@ -90,7 +90,7 @@ impl crate::TransitionBatch for DummyBatch { /// Dummy replay buffer. pub struct DummyReplayBuffer; -impl crate::ReplayBufferBase for DummyReplayBuffer { +impl crate::ReplayBuffer for DummyReplayBuffer { type Batch = DummyBatch; type Config = usize; diff --git a/border-core/src/evaluator.rs b/border-core/src/evaluator.rs index f241830c..95f4ab79 100644 --- a/border-core/src/evaluator.rs +++ b/border-core/src/evaluator.rs @@ -8,7 +8,7 @@ //! - Monitor training progress //! - Validate the generalization of learned policies -use crate::{record::Record, Agent, Env, ReplayBufferBase}; +use crate::{record::Record, Agent, Env, ReplayBuffer}; use anyhow::Result; mod default_evaluator; pub use default_evaluator::DefaultEvaluator; @@ -36,7 +36,7 @@ pub use default_evaluator::DefaultEvaluator; /// impl Evaluator for CustomEvaluator { /// fn evaluate(&mut self, agent: &mut Box>) -> Result /// where -/// R: ReplayBufferBase, +/// R: ReplayBuffer, /// { /// // Custom evaluation logic /// // ... @@ -79,5 +79,5 @@ pub trait Evaluator { /// [`Trainer`]: crate::Trainer fn evaluate(&mut self, agent: &mut Box>) -> Result<(f32, Record)> where - R: ReplayBufferBase; + R: ReplayBuffer; } diff --git a/border-core/src/evaluator/default_evaluator.rs b/border-core/src/evaluator/default_evaluator.rs index 10b41719..ac1fd0e9 100644 --- a/border-core/src/evaluator/default_evaluator.rs +++ b/border-core/src/evaluator/default_evaluator.rs @@ -4,7 +4,7 @@ //! and calculates the average return across all episodes. use super::Evaluator; -use crate::{record::Record, Agent, Env, ReplayBufferBase}; +use crate::{record::Record, Agent, Env, ReplayBuffer}; use anyhow::Result; /// A default implementation of the [`Evaluator`] trait. @@ -63,7 +63,7 @@ impl Evaluator for DefaultEvaluator { /// - The environment fails to step fn evaluate(&mut self, policy: &mut Box>) -> Result<(f32, Record)> where - R: ReplayBufferBase, + R: ReplayBuffer, { let mut r_total = 0f32; diff --git a/border-core/src/generic_replay_buffer/base.rs b/border-core/src/generic_replay_buffer/base.rs index 14b3a623..d9dda3dc 100644 --- a/border-core/src/generic_replay_buffer/base.rs +++ b/border-core/src/generic_replay_buffer/base.rs @@ -9,7 +9,7 @@ mod iw_scheduler; mod sum_tree; use super::{config::PerConfig, BatchBase, GenericTransitionBatch, SimpleReplayBufferConfig}; -use crate::{ExperienceBufferBase, ReplayBufferBase, TransitionBatch}; +use crate::{ExperienceBuffer, ReplayBuffer, TransitionBatch}; use anyhow::Result; pub use iw_scheduler::IwScheduler; use rand::{rngs::StdRng, RngCore, SeedableRng}; @@ -267,7 +267,7 @@ where } } -impl ExperienceBufferBase for SimpleReplayBuffer +impl ExperienceBuffer for SimpleReplayBuffer where O: BatchBase, A: BatchBase, @@ -316,7 +316,7 @@ where } } -impl ReplayBufferBase for SimpleReplayBuffer +impl ReplayBuffer for SimpleReplayBuffer where O: BatchBase, A: BatchBase, diff --git a/border-core/src/lib.rs b/border-core/src/lib.rs index e374d79e..098b68c7 100644 --- a/border-core/src/lib.rs +++ b/border-core/src/lib.rs @@ -23,13 +23,13 @@ //! //! # Agent //! -//! In this crate, an [`Agent`] is defined as a trainable [`Policy`]. +//! In this crate, an [`Agent`] is defined as a trainable [`Policy`]. //! Agents operate in either training or evaluation mode. During training, the agent's policy may be probabilistic //! to facilitate exploration, while in evaluation mode, it typically becomes deterministic. //! //! The [`Agent::opt()`] method executes a single optimization step. The specific implementation of an optimization //! step varies between agents and may include multiple stochastic gradient descent steps. Training samples are -//! obtained from the [`ReplayBufferBase`]. +//! obtained from the [`ReplayBuffer`]. //! //! This trait also provides methods for saving and loading trained policy parameters to and from a directory. //! @@ -40,18 +40,18 @@ //! //! # Replay Buffer and Experience Buffer //! -//! The [`ReplayBufferBase`] trait provides an abstraction for replay buffers. Its associated type -//! [`ReplayBufferBase::Batch`] represents samples used for training [`Agent`]s. Agents must implement the -//! [`Agent::opt()`] method, where [`ReplayBufferBase::Batch`] must have appropriate type or trait bounds +//! The [`ReplayBuffer`] trait provides an abstraction for replay buffers. Its associated type +//! [`ReplayBuffer::Batch`] represents samples used for training [`Agent`]s. Agents must implement the +//! [`Agent::opt()`] method, where [`ReplayBuffer::Batch`] must have appropriate type or trait bounds //! for training the agent. //! -//! While [`ReplayBufferBase`] focuses on generating training batches, the [`ExperienceBufferBase`] trait -//! handles sample storage. The [`ExperienceBufferBase::push()`] method stores samples of type -//! [`ExperienceBufferBase::Item`], typically obtained through environment interactions. +//! While [`ReplayBuffer`] focuses on generating training batches, the [`ExperienceBuffer`] trait +//! handles sample storage. The [`ExperienceBuffer::push()`] method stores samples of type +//! [`ExperienceBuffer::Item`], typically obtained through environment interactions. //! //! ## Reference Implementation //! -//! [`SimpleReplayBuffer`] implements both [`ReplayBufferBase`] and [`ExperienceBufferBase`]. +//! [`SimpleReplayBuffer`] implements both [`ReplayBuffer`] and [`ExperienceBuffer`]. //! This type takes two parameters, `O` and `A`, representing observation and action types in the replay buffer. //! Both `O` and `A` must implement [`BatchBase`], which provides sample storage functionality similar to `Vec`. //! The associated types `Item` and `Batch` are both [`GenericTransitionBatch`], representing sets of @@ -108,8 +108,8 @@ pub mod record; mod base; pub use base::{ - Act, Agent, Configurable, Env, ExperienceBufferBase, Info, NullReplayBuffer, Obs, Policy, - ReplayBufferBase, Step, StepProcessor, TransitionBatch, + Act, Agent, Configurable, Env, ExperienceBuffer, Info, NullReplayBuffer, Obs, Policy, + ReplayBuffer, Step, StepProcessor, TransitionBatch, }; mod trainer; diff --git a/border-core/src/record/base.rs b/border-core/src/record/base.rs index 5e97d0b9..849c6743 100644 --- a/border-core/src/record/base.rs +++ b/border-core/src/record/base.rs @@ -106,7 +106,7 @@ impl Record { /// # Returns /// /// An iterator over the record's keys - pub fn keys(&self) -> Keys { + pub fn keys(&self) -> Keys<'_, String, RecordValue> { self.0.keys() } diff --git a/border-core/src/record/buffered_recorder.rs b/border-core/src/record/buffered_recorder.rs index 0357b365..267c78cb 100644 --- a/border-core/src/record/buffered_recorder.rs +++ b/border-core/src/record/buffered_recorder.rs @@ -6,7 +6,7 @@ //! learning environments. use super::{Record, Recorder}; -use crate::{Env, ReplayBufferBase}; +use crate::{Env, ReplayBuffer}; use std::marker::PhantomData; /// A recorder that buffers sequences of observations and actions in memory. @@ -19,12 +19,12 @@ use std::marker::PhantomData; /// # Type Parameters /// /// * `E` - The environment type that implements the [`Env`] trait -/// * `R` - The replay buffer type that implements the [`ReplayBufferBase`] trait +/// * `R` - The replay buffer type that implements the [`ReplayBuffer`] trait #[derive(Default)] pub struct BufferedRecorder where E: Env, - R: ReplayBufferBase, + R: ReplayBuffer, { /// The internal buffer storing the sequence of records buf: Vec, @@ -35,7 +35,7 @@ where impl BufferedRecorder where E: Env, - R: ReplayBufferBase, + R: ReplayBuffer, { /// Creates a new empty buffered recorder. /// @@ -57,7 +57,7 @@ where /// # Returns /// /// An iterator over references to the [`Record`]s in the buffer. - pub fn iter(&self) -> std::slice::Iter { + pub fn iter(&self) -> std::slice::Iter<'_, Record> { self.buf.iter() } } @@ -65,7 +65,7 @@ where impl Recorder for BufferedRecorder where E: Env, - R: ReplayBufferBase, + R: ReplayBuffer, { /// Writes a [`Record`] to the internal buffer. /// diff --git a/border-core/src/record/null_recorder.rs b/border-core/src/record/null_recorder.rs index b78a2929..53d3ab26 100644 --- a/border-core/src/record/null_recorder.rs +++ b/border-core/src/record/null_recorder.rs @@ -7,7 +7,7 @@ use std::marker::PhantomData; use super::{Record, Recorder}; -use crate::{Env, ReplayBufferBase}; +use crate::{Env, ReplayBuffer}; /// A recorder that discards all records without storing them. /// @@ -22,11 +22,11 @@ use crate::{Env, ReplayBufferBase}; /// # Type Parameters /// /// * `E` - The environment type that implements the [`Env`] trait -/// * `R` - The replay buffer type that implements the [`ReplayBufferBase`] trait +/// * `R` - The replay buffer type that implements the [`ReplayBuffer`] trait pub struct NullRecorder where E: Env, - R: ReplayBufferBase, + R: ReplayBuffer, { /// Phantom data to hold the type parameters phantom: PhantomData<(E, R)>, @@ -35,7 +35,7 @@ where impl NullRecorder where E: Env, - R: ReplayBufferBase, + R: ReplayBuffer, { /// Creates a new null recorder. /// @@ -52,7 +52,7 @@ where impl Recorder for NullRecorder where E: Env, - R: ReplayBufferBase, + R: ReplayBuffer, { /// Discards the given record without storing it. /// diff --git a/border-core/src/record/recorder.rs b/border-core/src/record/recorder.rs index 768a93e2..3aa49e35 100644 --- a/border-core/src/record/recorder.rs +++ b/border-core/src/record/recorder.rs @@ -6,7 +6,7 @@ //! recording strategies. use super::Record; -use crate::{Agent, Env, ReplayBufferBase}; +use crate::{Agent, Env, ReplayBuffer}; use anyhow::Result; use std::path::Path; @@ -22,11 +22,11 @@ use std::path::Path; /// # Type Parameters /// /// * `E` - The environment type that implements the [`Env`] trait -/// * `R` - The replay buffer type that implements the [`ReplayBufferBase`] trait +/// * `R` - The replay buffer type that implements the [`ReplayBuffer`] trait pub trait Recorder where E: Env, - R: ReplayBufferBase, + R: ReplayBuffer, { /// Writes a record to the recorder's output destination. /// diff --git a/border-core/src/trainer.rs b/border-core/src/trainer.rs index 87e403bf..3385ce42 100644 --- a/border-core/src/trainer.rs +++ b/border-core/src/trainer.rs @@ -10,7 +10,7 @@ use std::time::{Duration, SystemTime}; use crate::{ record::{Record, RecordValue::Scalar, Recorder}, - Agent, Env, Evaluator, ExperienceBufferBase, ReplayBufferBase, StepProcessor, + Agent, Env, Evaluator, ExperienceBuffer, ReplayBuffer, StepProcessor, }; use anyhow::Result; pub use config::TrainerConfig; @@ -201,7 +201,7 @@ impl Trainer { ) -> Result<(Record, bool)> where E: Env, - R: ReplayBufferBase, + R: ReplayBuffer, { if self.env_steps < self.warmup_period { Ok((Record::empty(), false)) @@ -237,7 +237,7 @@ impl Trainer { ) -> Result<()> where E: Env, - R: ReplayBufferBase, + R: ReplayBuffer, D: Evaluator, { // Evaluation @@ -276,7 +276,7 @@ impl Trainer { where E: Env, P: StepProcessor, - R: ExperienceBufferBase + ReplayBufferBase, + R: ExperienceBuffer + ReplayBuffer, D: Evaluator, { let mut sampler = Sampler::new(env, step_proc); @@ -336,7 +336,7 @@ impl Trainer { ) -> Result<()> where E: Env, - R: ReplayBufferBase, + R: ReplayBuffer, D: Evaluator, { // Return empty record diff --git a/border-core/src/trainer/sampler.rs b/border-core/src/trainer/sampler.rs index 4f96cbae..6c7a9203 100644 --- a/border-core/src/trainer/sampler.rs +++ b/border-core/src/trainer/sampler.rs @@ -21,7 +21,7 @@ //! 3. Performance Monitoring: //! * Monitor episode length //! * Record environment metrics -use crate::{record::Record, Agent, Env, ExperienceBufferBase, ReplayBufferBase, StepProcessor}; +use crate::{record::Record, Agent, Env, ExperienceBuffer, ReplayBuffer, StepProcessor}; use anyhow::Result; /// Manages the sampling of experiences from the environment. @@ -102,8 +102,8 @@ where buffer: &mut R_, ) -> Result where - R: ExperienceBufferBase + ReplayBufferBase, - R_: ExperienceBufferBase, + R: ExperienceBuffer + ReplayBuffer, + R_: ExperienceBuffer, { // Reset environment(s) if required if self.prev_obs.is_none() { diff --git a/border-minari/src/dataset.rs b/border-minari/src/dataset.rs index 11861ba6..745a1904 100644 --- a/border-minari/src/dataset.rs +++ b/border-minari/src/dataset.rs @@ -2,7 +2,7 @@ use crate::{util, MinariConverter, MinariEnv}; use anyhow::Result; use border_core::{ generic_replay_buffer::{GenericTransitionBatch, SimpleReplayBuffer, SimpleReplayBufferConfig}, - ExperienceBufferBase, ReplayBufferBase, + ExperienceBuffer, ReplayBuffer, }; use pyo3::{ types::{IntoPyDict, PyIterator}, diff --git a/border-minari/src/evaluator.rs b/border-minari/src/evaluator.rs index 5ad28436..5021abd9 100644 --- a/border-minari/src/evaluator.rs +++ b/border-minari/src/evaluator.rs @@ -1,7 +1,7 @@ //! Evaluator for Minari environments. use crate::{MinariConverter, MinariEnv}; use anyhow::Result; -use border_core::{record::Record, Agent, Env, Evaluator, ReplayBufferBase}; +use border_core::{record::Record, Agent, Env, Evaluator, ReplayBuffer}; /// An evaluator for Minari environments. /// @@ -22,7 +22,7 @@ impl Evaluator> for MinariEvaluator { /// The average return over episodes is returned. /// If the environment has ref_min_score and ref_max_score, the normalized score is also returned /// in the record. - fn evaluate( + fn evaluate( &mut self, policy: &mut Box, R>>, ) -> Result<(f32, Record)> { diff --git a/border-mlflow-tracking/src/client.rs b/border-mlflow-tracking/src/client.rs index dc2c2ccb..f11c450f 100644 --- a/border-mlflow-tracking/src/client.rs +++ b/border-mlflow-tracking/src/client.rs @@ -1,6 +1,6 @@ use crate::{system_time_as_millis, Experiment, MlflowTrackingRecorder, Run}; use anyhow::Result; -use border_core::{Env, ReplayBufferBase}; +use border_core::{Env, ReplayBuffer}; use log::info; use reqwest::blocking::Client; use serde::{Deserialize, Serialize}; @@ -212,7 +212,7 @@ impl MlflowTrackingClient { ) -> Result> where E: Env, - R: ReplayBufferBase, + R: ReplayBuffer, { let run_name = run_name.as_ref(); let run = { diff --git a/border-mlflow-tracking/src/recorder.rs b/border-mlflow-tracking/src/recorder.rs index 61b4db1d..38d84424 100644 --- a/border-mlflow-tracking/src/recorder.rs +++ b/border-mlflow-tracking/src/recorder.rs @@ -2,7 +2,7 @@ use crate::{system_time_as_millis, Run}; use anyhow::Result; use border_core::{ record::{RecordStorage, RecordValue, Recorder}, - Agent, Env, ReplayBufferBase, + Agent, Env, ReplayBuffer, }; use chrono::{DateTime, Duration, Local, SecondsFormat}; use reqwest::blocking::Client; @@ -64,7 +64,7 @@ struct SetTagParams<'a> { pub struct MlflowTrackingRecorder where E: Env, - R: ReplayBufferBase, + R: ReplayBuffer, { client: Client, base_url: String, @@ -81,7 +81,7 @@ where impl MlflowTrackingRecorder where E: Env, - R: ReplayBufferBase, + R: ReplayBuffer, { /// Create a new instance of `MlflowTrackingRecorder`. /// @@ -190,7 +190,7 @@ where impl Recorder for MlflowTrackingRecorder where E: Env, - R: ReplayBufferBase, + R: ReplayBuffer, { fn write(&mut self, record: border_core::record::Record) { let url = format!("{}/api/2.0/mlflow/runs/log-metric", self.base_url); @@ -284,7 +284,7 @@ where impl Drop for MlflowTrackingRecorder where E: Env, - R: ReplayBufferBase, + R: ReplayBuffer, { /// Update run's status to "FINISHED" when dropped. /// diff --git a/border-tch-agent/src/dqn/base.rs b/border-tch-agent/src/dqn/base.rs index 2ba66a46..1d350ea3 100644 --- a/border-tch-agent/src/dqn/base.rs +++ b/border-tch-agent/src/dqn/base.rs @@ -7,7 +7,7 @@ use crate::{ use anyhow::Result; use border_core::{ record::{Record, RecordValue}, - Agent, Configurable, Env, Policy, ReplayBufferBase, TransitionBatch, + Agent, Configurable, Env, Policy, ReplayBuffer, TransitionBatch, }; use serde::{de::DeserializeOwned, Serialize}; use std::{ @@ -51,7 +51,7 @@ impl Dqn where E: Env, Q: SubModel, - R: ReplayBufferBase, + R: ReplayBuffer, Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone, R::Batch: TransitionBatch, ::ObsBatch: Into, @@ -289,7 +289,7 @@ impl Agent for Dqn where E: Env + 'static, Q: SubModel + 'static, - R: ReplayBufferBase + 'static, + R: ReplayBuffer + 'static, E::Obs: Into, E::Act: From, Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone, @@ -378,7 +378,7 @@ impl SyncModel for Dqn where E: Env, Q: SubModel, - R: ReplayBufferBase, + R: ReplayBuffer, E::Obs: Into, E::Act: From, Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone, diff --git a/border-tch-agent/src/iqn/base.rs b/border-tch-agent/src/iqn/base.rs index 51c1a1d5..75c028fa 100644 --- a/border-tch-agent/src/iqn/base.rs +++ b/border-tch-agent/src/iqn/base.rs @@ -7,7 +7,7 @@ use crate::{ use anyhow::Result; use border_core::{ record::{Record, RecordValue}, - Agent, Configurable, Env, Policy, ReplayBufferBase, TransitionBatch, + Agent, Configurable, Env, Policy, ReplayBuffer, TransitionBatch, }; use log::trace; use serde::{de::DeserializeOwned, Serialize}; @@ -53,7 +53,7 @@ where E: Env, F: SubModel, M: SubModel, - R: ReplayBufferBase, + R: ReplayBuffer, F::Config: DeserializeOwned + Serialize, M::Config: DeserializeOwned + Serialize + OutDim, R::Batch: TransitionBatch, @@ -275,7 +275,7 @@ where E: Env + 'static, F: SubModel + 'static, M: SubModel + 'static, - R: ReplayBufferBase + 'static, + R: ReplayBuffer + 'static, E::Obs: Into, E::Act: From, F::Config: DeserializeOwned + Serialize + Clone, diff --git a/border-tch-agent/src/sac/base.rs b/border-tch-agent/src/sac/base.rs index d78f0293..046fa3e1 100644 --- a/border-tch-agent/src/sac/base.rs +++ b/border-tch-agent/src/sac/base.rs @@ -6,7 +6,7 @@ use crate::{ use anyhow::Result; use border_core::{ record::{Record, RecordValue}, - Agent, Configurable, Env, Policy, ReplayBufferBase, TransitionBatch, + Agent, Configurable, Env, Policy, ReplayBuffer, TransitionBatch, }; use serde::{de::DeserializeOwned, Serialize}; // use log::info; @@ -60,7 +60,7 @@ where E: Env, Q: SubModel2, P: SubModel, - R: ReplayBufferBase, + R: ReplayBuffer, E::Obs: Into + Into, E::Act: Into, Q::Input2: From, @@ -284,7 +284,7 @@ where E: Env + 'static, Q: SubModel2 + 'static, P: SubModel + 'static, - R: ReplayBufferBase + 'static, + R: ReplayBuffer + 'static, E::Obs: Into + Into, E::Act: Into + From, Q::Input2: From, @@ -362,7 +362,7 @@ where E: Env, Q: SubModel2, P: SubModel, - R: ReplayBufferBase, + R: ReplayBuffer, E::Obs: Into + Into, E::Act: Into + From, Q::Input2: From, diff --git a/border-tensorboard/src/lib.rs b/border-tensorboard/src/lib.rs index 59243794..cfbedde2 100644 --- a/border-tensorboard/src/lib.rs +++ b/border-tensorboard/src/lib.rs @@ -5,7 +5,7 @@ use anyhow::Result; use border_core::{ record::{Record, RecordValue, Recorder}, - Env, ReplayBufferBase, + Env, ReplayBuffer, }; use std::{ marker::PhantomData, @@ -17,7 +17,7 @@ use tensorboard_rs::summary_writer::SummaryWriter; pub struct TensorboardRecorder where E: Env, - R: ReplayBufferBase, + R: ReplayBuffer, { model_dir: PathBuf, writer: SummaryWriter, @@ -30,7 +30,7 @@ where impl TensorboardRecorder where E: Env, - R: ReplayBufferBase, + R: ReplayBuffer, { /// Construct a [`TensorboardRecorder`]. /// @@ -56,7 +56,7 @@ where impl Recorder for TensorboardRecorder where E: Env, - R: ReplayBufferBase, + R: ReplayBuffer, { /// Writes a given [`Record`] into a TFRecord. /// diff --git a/examples/atari/dqn_atari/src/main.rs b/examples/atari/dqn_atari/src/main.rs index 79d6f610..2100ff05 100644 --- a/examples/atari/dqn_atari/src/main.rs +++ b/examples/atari/dqn_atari/src/main.rs @@ -5,7 +5,7 @@ use anyhow::Result; use args::Args; use border_core::{ generic_replay_buffer::SimpleStepProcessorConfig, record::Recorder, Agent, Configurable, - Env as _, Evaluator as _, ReplayBufferBase, StepProcessor, Trainer, + Env as _, Evaluator as _, ReplayBuffer, StepProcessor, Trainer, }; use border_mlflow_tracking::MlflowTrackingClient; use border_tensorboard::TensorboardRecorder; diff --git a/examples/atari/dqn_atari_tch/src/main.rs b/examples/atari/dqn_atari_tch/src/main.rs index bc98d47d..54342f08 100644 --- a/examples/atari/dqn_atari_tch/src/main.rs +++ b/examples/atari/dqn_atari_tch/src/main.rs @@ -5,7 +5,7 @@ use anyhow::Result; use args::Args; use border_core::{ generic_replay_buffer::SimpleStepProcessorConfig, record::Recorder, Agent, Configurable, - Env as _, Evaluator as _, ReplayBufferBase, StepProcessor, Trainer, + Env as _, Evaluator as _, ReplayBuffer, StepProcessor, Trainer, }; use border_mlflow_tracking::MlflowTrackingClient; use border_tensorboard::TensorboardRecorder; diff --git a/examples/d4rl/awac_pen/src/main.rs b/examples/d4rl/awac_pen/src/main.rs index f32e7799..36ab97a9 100644 --- a/examples/d4rl/awac_pen/src/main.rs +++ b/examples/d4rl/awac_pen/src/main.rs @@ -12,7 +12,7 @@ use border_candle_agent::{ use border_core::{ generic_replay_buffer::{BatchBase, SimpleReplayBuffer}, record::Recorder, - Agent, Configurable, Env, Evaluator, ExperienceBufferBase, ReplayBufferBase, Trainer, + Agent, Configurable, Env, Evaluator, ExperienceBuffer, ReplayBuffer, Trainer, TrainerConfig, TransitionBatch, }; use border_minari::{ @@ -188,7 +188,7 @@ where E: Env + 'static, E::Obs: Into, E::Act: From + Into, - R: ReplayBufferBase + 'static, + R: ReplayBuffer + 'static, R::Batch: TransitionBatch, ::ObsBatch: Into + Clone, ::ActBatch: Into + Clone, @@ -215,7 +215,7 @@ where fn create_recorder(config: &PenConfig) -> Result>> where E: Env + 'static, - R: ReplayBufferBase + 'static, + R: ReplayBuffer + 'static, { log::info!("Create recorder"); if let Some(mlflow_run_name) = &config.args.mlflow_run_name { diff --git a/examples/d4rl/bc_pen/src/main.rs b/examples/d4rl/bc_pen/src/main.rs index 944e7e63..148f1d3f 100644 --- a/examples/d4rl/bc_pen/src/main.rs +++ b/examples/d4rl/bc_pen/src/main.rs @@ -7,7 +7,7 @@ use border_candle_agent::{ use border_core::{ generic_replay_buffer::{BatchBase, SimpleReplayBuffer}, record::Recorder, - Agent, Configurable, Env, Evaluator, ExperienceBufferBase, ReplayBufferBase, Trainer, + Agent, Configurable, Env, Evaluator, ExperienceBuffer, ReplayBuffer, Trainer, TrainerConfig, TransitionBatch, }; use border_minari::{ @@ -161,7 +161,7 @@ where E: Env + 'static, E::Obs: Into, E::Act: From + Into, - R: ReplayBufferBase + 'static, + R: ReplayBuffer + 'static, R::Batch: TransitionBatch, ::ObsBatch: Into + Clone, ::ActBatch: Into + Clone, @@ -188,7 +188,7 @@ where fn create_recorder(config: &PenConfig) -> Result>> where E: Env + 'static, - R: ReplayBufferBase + 'static, + R: ReplayBuffer + 'static, { log::info!("Create recorder"); if let Some(mlflow_run_name) = &config.args.mlflow_run_name { diff --git a/examples/d4rl/iql_pen/src/main.rs b/examples/d4rl/iql_pen/src/main.rs index 0106ddb2..df4908e9 100644 --- a/examples/d4rl/iql_pen/src/main.rs +++ b/examples/d4rl/iql_pen/src/main.rs @@ -12,7 +12,7 @@ use border_candle_agent::{ use border_core::{ generic_replay_buffer::{BatchBase, SimpleReplayBuffer}, record::Recorder, - Agent, Configurable, Env, Evaluator, ExperienceBufferBase, ReplayBufferBase, Trainer, + Agent, Configurable, Env, Evaluator, ExperienceBuffer, ReplayBuffer, Trainer, TrainerConfig, TransitionBatch, }; use border_minari::{ @@ -197,7 +197,7 @@ where E: Env + 'static, E::Obs: Into, E::Act: From + Into, - R: ReplayBufferBase + 'static, + R: ReplayBuffer + 'static, R::Batch: TransitionBatch, ::ObsBatch: Into + Clone, ::ActBatch: Into + Clone, @@ -224,7 +224,7 @@ where fn create_recorder(config: &PenConfig) -> Result>> where E: Env + 'static, - R: ReplayBufferBase + 'static, + R: ReplayBuffer + 'static, { log::info!("Create recorder"); if let Some(mlflow_run_name) = &config.args.mlflow_run_name { diff --git a/examples/gym/awac_pendulum/src/main.rs b/examples/gym/awac_pendulum/src/main.rs index 5cba63b5..96d04658 100644 --- a/examples/gym/awac_pendulum/src/main.rs +++ b/examples/gym/awac_pendulum/src/main.rs @@ -15,7 +15,7 @@ use border_core::{ SimpleStepProcessorConfig, }, record::Recorder, - Agent, Configurable, DefaultEvaluator, Env as _, Evaluator as _, ReplayBufferBase, + Agent, Configurable, DefaultEvaluator, Env as _, Evaluator as _, ReplayBuffer, StepProcessor, Trainer, TrainerConfig, }; use border_mlflow_tracking::MlflowTrackingClient; diff --git a/examples/gym/convert_policy/src/main.rs b/examples/gym/convert_policy/src/main.rs index 259b19c3..97379191 100644 --- a/examples/gym/convert_policy/src/main.rs +++ b/examples/gym/convert_policy/src/main.rs @@ -101,7 +101,7 @@ mod dummy { pub struct DummyReplayBuffer; - impl border_core::ReplayBufferBase for DummyReplayBuffer { + impl border_core::ReplayBuffer for DummyReplayBuffer { type Batch = DummyBatch; type Config = usize; diff --git a/examples/gym/dqn_cartpole/src/main.rs b/examples/gym/dqn_cartpole/src/main.rs index d032fb8f..1357ac7e 100644 --- a/examples/gym/dqn_cartpole/src/main.rs +++ b/examples/gym/dqn_cartpole/src/main.rs @@ -12,7 +12,7 @@ use border_core::{ SimpleStepProcessorConfig, }, record::Recorder, - Agent, Configurable, DefaultEvaluator, Env as _, Evaluator as _, ReplayBufferBase, + Agent, Configurable, DefaultEvaluator, Env as _, Evaluator as _, ReplayBuffer as _, StepProcessor, Trainer, TrainerConfig, }; use border_mlflow_tracking::MlflowTrackingClient; diff --git a/examples/gym/dqn_cartpole_tch/src/main.rs b/examples/gym/dqn_cartpole_tch/src/main.rs index 078eca0f..918a2bc7 100644 --- a/examples/gym/dqn_cartpole_tch/src/main.rs +++ b/examples/gym/dqn_cartpole_tch/src/main.rs @@ -5,7 +5,7 @@ use border_core::{ SimpleStepProcessorConfig, }, record::Recorder, - Agent, Configurable, DefaultEvaluator, Env as _, Evaluator as _, ReplayBufferBase, + Agent, Configurable, DefaultEvaluator, Env as _, Evaluator as _, ReplayBuffer as _, StepProcessor, Trainer, TrainerConfig, }; use border_mlflow_tracking::MlflowTrackingClient; diff --git a/examples/gym/sac_fetch_reach/src/main.rs b/examples/gym/sac_fetch_reach/src/main.rs index b0b3a370..1e4be090 100644 --- a/examples/gym/sac_fetch_reach/src/main.rs +++ b/examples/gym/sac_fetch_reach/src/main.rs @@ -12,7 +12,7 @@ use border_core::{ SimpleStepProcessorConfig, }, record::Recorder, - Agent, Configurable, DefaultEvaluator, Env as _, Evaluator as _, ReplayBufferBase, + Agent, Configurable, DefaultEvaluator, Env as _, Evaluator as _, ReplayBuffer as _, StepProcessor, Trainer, TrainerConfig, }; use border_mlflow_tracking::MlflowTrackingClient; diff --git a/examples/gym/sac_pendulum/src/main.rs b/examples/gym/sac_pendulum/src/main.rs index 838443eb..9c78b292 100644 --- a/examples/gym/sac_pendulum/src/main.rs +++ b/examples/gym/sac_pendulum/src/main.rs @@ -15,7 +15,7 @@ use border_core::{ SimpleStepProcessorConfig, }, record::Recorder, - Agent, Configurable, DefaultEvaluator, Env as _, Evaluator as _, ReplayBufferBase, + Agent, Configurable, DefaultEvaluator, Env as _, Evaluator as _, ReplayBuffer as _, StepProcessor, Trainer, TrainerConfig, }; use border_mlflow_tracking::MlflowTrackingClient; diff --git a/examples/gym/sac_pendulum_tch/src/main.rs b/examples/gym/sac_pendulum_tch/src/main.rs index 5ec5b0b1..4f7f2749 100644 --- a/examples/gym/sac_pendulum_tch/src/main.rs +++ b/examples/gym/sac_pendulum_tch/src/main.rs @@ -5,7 +5,7 @@ use border_core::{ SimpleStepProcessorConfig, }, record::Recorder, - Agent, Configurable, DefaultEvaluator, Env as _, Evaluator as _, ReplayBufferBase, + Agent, Configurable, DefaultEvaluator, Env as _, Evaluator as _, ReplayBuffer as _, StepProcessor, Trainer, TrainerConfig, }; use border_mlflow_tracking::MlflowTrackingClient; From 2ea84ff0f36ddc531ce277700ad5c08a5ebb836a Mon Sep 17 00:00:00 2001 From: taku-y Date: Sun, 10 Aug 2025 09:53:16 +0000 Subject: [PATCH 04/23] WIP: move generic replay buffer into another crate --- Cargo.toml | 1 + border-generic-replay-buffer/Cargo.toml | 28 ++ border-generic-replay-buffer/src/batch.rs | 206 +++++++++ border-generic-replay-buffer/src/config.rs | 294 ++++++++++++ .../src/iw_scheduler.rs | 46 ++ border-generic-replay-buffer/src/lib.rs | 34 ++ .../src/replay_buffer.rs | 427 ++++++++++++++++++ border-generic-replay-buffer/src/step_proc.rs | 138 ++++++ border-generic-replay-buffer/src/sum_tree.rs | 217 +++++++++ 9 files changed, 1391 insertions(+) create mode 100644 border-generic-replay-buffer/Cargo.toml create mode 100644 border-generic-replay-buffer/src/batch.rs create mode 100644 border-generic-replay-buffer/src/config.rs create mode 100644 border-generic-replay-buffer/src/iw_scheduler.rs create mode 100644 border-generic-replay-buffer/src/lib.rs create mode 100644 border-generic-replay-buffer/src/replay_buffer.rs create mode 100644 border-generic-replay-buffer/src/step_proc.rs create mode 100644 border-generic-replay-buffer/src/sum_tree.rs diff --git a/Cargo.toml b/Cargo.toml index cc3696e1..69d9c001 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,7 @@ [workspace] members = [ "border-core", + "border-generic-replay-buffer", "border-tensorboard", "border-mlflow-tracking", "border-py-gym-env", diff --git a/border-generic-replay-buffer/Cargo.toml b/border-generic-replay-buffer/Cargo.toml new file mode 100644 index 00000000..fc828362 --- /dev/null +++ b/border-generic-replay-buffer/Cargo.toml @@ -0,0 +1,28 @@ +[package] +name = "border-generic-replay-buffer" +description = "A generic implementation of replaybuffer for Border" +version.workspace = true +edition.workspace = true +repository.workspace = true +keywords.workspace = true +categories.workspace = true +license.workspace = true +readme = "README.md" + +[dependencies] +border-core = { version = "0.0.9", path = "../border-core" } +serde = { workspace = true, features = ["derive"] } +serde_yaml = { workspace = true } +# log = { workspace = true } +# thiserror = { workspace = true } +anyhow = { workspace = true } +# chrono = { workspace = true } +# aquamarine = { workspace = true } +fastrand = { workspace = true } +segment-tree = { workspace = true } +# xxhash-rust = { workspace = true } +# Consider to replace with fastrand +rand = { workspace = true } + +[dev-dependencies] +tempdir = { workspace = true } diff --git a/border-generic-replay-buffer/src/batch.rs b/border-generic-replay-buffer/src/batch.rs new file mode 100644 index 00000000..a115fb76 --- /dev/null +++ b/border-generic-replay-buffer/src/batch.rs @@ -0,0 +1,206 @@ +//! Generic implementation of transition batches for reinforcement learning. +//! +//! This module provides a generic implementation of transition batches that can handle +//! arbitrary observation and action types. It supports the following features: +//! - Efficient batch processing +//! - Weighting for prioritized experience replay +//! - Transition sampling and management + +use border_core::TransitionBatch; + +/// A trait defining basic batch operations. +/// +/// This trait provides fundamental operations for efficiently managing batches of +/// observations and actions. +/// +/// # Type Parameters +/// +/// * `Self` - The batch type, representing batches of observations or actions. +/// +/// # Examples +/// +/// ```ignore +/// struct TensorBatch { +/// data: Vec, +/// shape: Vec, +/// } +/// +/// impl BatchBase for TensorBatch { +/// fn new(capacity: usize) -> Self { +/// Self { +/// data: Vec::with_capacity(capacity), +/// shape: vec![], +/// } +/// } +/// +/// fn push(&mut self, ix: usize, data: Self) { +/// // Data addition logic +/// } +/// +/// fn sample(&self, ixs: &Vec) -> Self { +/// // Sampling logic +/// } +/// } +/// ``` +pub trait BatchBase { + /// Creates a new batch with the specified capacity. + /// + /// # Arguments + /// + /// * `capacity` - Initial capacity of the batch + fn new(capacity: usize) -> Self; + + /// Adds data at the specified index. + /// + /// # Arguments + /// + /// * `ix` - Index where data should be added + /// * `data` - Data to be added + fn push(&mut self, ix: usize, data: Self); + + /// Retrieves samples from the specified indices. + /// + /// # Arguments + /// + /// * `ixs` - List of indices to sample from + /// + /// # Returns + /// + /// A new batch containing the sampled data + fn sample(&self, ixs: &Vec) -> Self; +} + +/// A generic structure representing transitions in reinforcement learning. +/// +/// This structure efficiently manages reinforcement learning transitions +/// (observations, actions, rewards, etc.). It also includes support for +/// prioritized experience replay (PER). +/// +/// # Type Parameters +/// +/// * `O` - Observation type, must implement `BatchBase` +/// * `A` - Action type, must implement `BatchBase` +/// +/// # Examples +/// +/// ```ignore +/// let batch = GenericTransitionBatch::::with_capacity(32); +/// ``` +pub struct GenericTransitionBatch +where + O: BatchBase, + A: BatchBase, +{ + /// Current observations + pub obs: O, + + /// Selected actions + pub act: A, + + /// Next state observations + pub next_obs: O, + + /// Transition rewards + pub reward: Vec, + + /// Episode termination flags + pub is_terminated: Vec, + + /// Episode truncation flags + pub is_truncated: Vec, + + /// Weights for prioritized experience replay + pub weight: Option>, + + /// Indices of sampled transitions + pub ix_sample: Option>, +} + +impl TransitionBatch for GenericTransitionBatch +where + O: BatchBase, + A: BatchBase, +{ + type ObsBatch = O; + type ActBatch = A; + + /// Decomposes the batch into its individual components. + /// + /// # Returns + /// + /// A tuple containing the following elements: + /// 1. Observations + /// 2. Actions + /// 3. Next observations + /// 4. Rewards + /// 5. Termination flags + /// 6. Truncation flags + /// 7. Sample indices + /// 8. Weights + fn unpack( + self, + ) -> ( + Self::ObsBatch, + Self::ActBatch, + Self::ObsBatch, + Vec, + Vec, + Vec, + Option>, + Option>, + ) { + ( + self.obs, + self.act, + self.next_obs, + self.reward, + self.is_terminated, + self.is_truncated, + self.ix_sample, + self.weight, + ) + } + + /// Returns the number of transitions in the batch. + fn len(&self) -> usize { + self.reward.len() + } + + /// Returns a reference to the batch of observations. + fn obs(&self) -> &Self::ObsBatch { + &self.obs + } + + /// Returns a reference to the batch of actions. + fn act(&self) -> &Self::ActBatch { + &self.act + } +} + +impl GenericTransitionBatch +where + O: BatchBase, + A: BatchBase, +{ + /// Creates a new batch with the specified capacity. + /// + /// # Arguments + /// + /// * `capacity` - Initial capacity of the batch + /// + /// # Returns + /// + /// A new `GenericTransitionBatch` instance + pub fn with_capacity(capacity: usize) -> Self { + Self { + obs: O::new(capacity), + act: A::new(capacity), + next_obs: O::new(capacity), + reward: Vec::with_capacity(capacity), + is_terminated: Vec::with_capacity(capacity), + is_truncated: Vec::with_capacity(capacity), + weight: None, + ix_sample: None, + } + } +} diff --git a/border-generic-replay-buffer/src/config.rs b/border-generic-replay-buffer/src/config.rs new file mode 100644 index 00000000..ffad05d5 --- /dev/null +++ b/border-generic-replay-buffer/src/config.rs @@ -0,0 +1,294 @@ +//! Configuration for the replay buffer implementation. +//! +//! This module provides configuration structures for the replay buffer, including: +//! - Basic buffer configuration (capacity, seed) +//! - Prioritized Experience Replay (PER) configuration +//! - Serialization and deserialization support + +use super::{WeightNormalizer, WeightNormalizer::All}; +use anyhow::Result; +use serde::{Deserialize, Serialize}; +use std::{ + default::Default, + fs::File, + io::{BufReader, Write}, + path::Path, +}; + +/// Configuration for Prioritized Experience Replay (PER). +/// +/// This structure defines the parameters for prioritized sampling in the replay buffer. +/// It controls how transitions are sampled based on their importance and how +/// importance weights are calculated and normalized. +/// +/// # Fields +/// +/// * `alpha` - Controls the degree of prioritization (0 = uniform sampling) +/// * `beta_0` - Initial value for importance sampling weights +/// * `beta_final` - Final value for importance sampling weights +/// * `n_opts_final` - Number of optimization steps to reach `beta_final` +/// * `normalize` - Method for normalizing importance weights +/// +/// # Examples +/// +/// ```rust +/// use border_core::generic_replay_buffer::{PerConfig, WeightNormalizer}; +/// +/// let config = PerConfig::default() +/// .alpha(0.6) +/// .beta_0(0.4) +/// .beta_final(1.0) +/// .n_opts_final(500_000) +/// .normalize(WeightNormalizer::All); +/// ``` +#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)] +pub struct PerConfig { + /// Exponent for prioritization. Higher values increase the bias towards + /// high-priority transitions. A value of 0 results in uniform sampling. + pub alpha: f32, + + /// Initial value of the importance sampling exponent. Lower values reduce + /// the impact of importance sampling weights. + pub beta_0: f32, + + /// Final value of the importance sampling exponent. Typically set to 1.0 + /// to fully compensate for the non-uniform sampling. + pub beta_final: f32, + + /// Number of optimization steps after which `beta` reaches its final value. + /// This allows for a gradual increase in the impact of importance sampling. + pub n_opts_final: usize, + + /// Method for normalizing importance sampling weights. Controls how the + /// weights are scaled to prevent numerical instability. + pub normalize: WeightNormalizer, +} + +impl Default for PerConfig { + /// Creates a default PER configuration with commonly used values: + /// - `alpha = 0.6` (moderate prioritization) + /// - `beta_0 = 0.4` (initial importance sampling) + /// - `beta_final = 1.0` (full compensation) + /// - `n_opts_final = 500_000` (gradual increase) + /// - `normalize = All` (normalize all weights) + fn default() -> Self { + Self { + alpha: 0.6, + beta_0: 0.4, + beta_final: 1.0, + n_opts_final: 500_000, + normalize: All, + } + } +} + +impl PerConfig { + /// Sets the prioritization exponent `alpha`. + /// + /// # Arguments + /// + /// * `alpha` - The new value for the prioritization exponent + /// + /// # Returns + /// + /// The modified configuration + pub fn alpha(mut self, alpha: f32) -> Self { + self.alpha = alpha; + self + } + + /// Sets the initial importance sampling exponent `beta_0`. + /// + /// # Arguments + /// + /// * `beta_0` - The new initial value for the importance sampling exponent + /// + /// # Returns + /// + /// The modified configuration + pub fn beta_0(mut self, beta_0: f32) -> Self { + self.beta_0 = beta_0; + self + } + + /// Sets the final importance sampling exponent `beta_final`. + /// + /// # Arguments + /// + /// * `beta_final` - The new final value for the importance sampling exponent + /// + /// # Returns + /// + /// The modified configuration + pub fn beta_final(mut self, beta_final: f32) -> Self { + self.beta_final = beta_final; + self + } + + /// Sets the number of optimization steps to reach the final beta value. + /// + /// # Arguments + /// + /// * `n_opts_final` - The new number of optimization steps + /// + /// # Returns + /// + /// The modified configuration + pub fn n_opts_final(mut self, n_opts_final: usize) -> Self { + self.n_opts_final = n_opts_final; + self + } + + /// Sets the method for normalizing importance weights. + /// + /// # Arguments + /// + /// * `normalize` - The new normalization method + /// + /// # Returns + /// + /// The modified configuration + pub fn normalize(mut self, normalize: WeightNormalizer) -> Self { + self.normalize = normalize; + self + } +} + +/// Configuration for the replay buffer. +/// +/// This structure defines the basic parameters for the replay buffer, +/// including its capacity, random seed, and optional PER configuration. +/// +/// # Fields +/// +/// * `capacity` - Maximum number of transitions to store +/// * `seed` - Random seed for sampling +/// * `per_config` - Optional configuration for prioritized experience replay +/// +/// # Examples +/// +/// ```rust +/// use border_generic_replay_buffer::{GenericReplayBufferConfig, PerConfig}; +/// +/// // Basic configuration +/// let config = GenericReplayBufferConfig::default() +/// .capacity(10000) +/// .seed(42); +/// +/// // Configuration with PER +/// let config_with_per = GenericReplayBufferConfig::default() +/// .capacity(10000) +/// .seed(42) +/// .per_config(Some(PerConfig::default())); +/// ``` +#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)] +pub struct GenericReplayBufferConfig { + /// Maximum number of transitions that can be stored in the buffer. + /// When the buffer is full, new transitions replace the oldest ones. + pub capacity: usize, + + /// Random seed used for sampling transitions. This ensures reproducibility + /// of the sampling process when the same seed is used. + pub seed: u64, + + /// Optional configuration for prioritized experience replay. If `None`, + /// transitions are sampled uniformly at random. + pub per_config: Option, +} + +impl Default for GenericReplayBufferConfig { + /// Creates a default replay buffer configuration with commonly used values: + /// - `capacity = 10000` (moderate buffer size) + /// - `seed = 42` (fixed random seed) + /// - `per_config = None` (uniform sampling) + fn default() -> Self { + Self { + capacity: 10000, + seed: 42, + per_config: None, + } + } +} + +impl GenericReplayBufferConfig { + /// Sets the capacity of the replay buffer. + /// + /// # Arguments + /// + /// * `capacity` - The new capacity for the buffer + /// + /// # Returns + /// + /// The modified configuration + pub fn capacity(mut self, capacity: usize) -> Self { + self.capacity = capacity; + self + } + + /// Sets the random seed for sampling. + /// + /// # Arguments + /// + /// * `seed` - The new random seed + /// + /// # Returns + /// + /// The modified configuration + pub fn seed(mut self, seed: u64) -> Self { + self.seed = seed; + self + } + + /// Sets the configuration for prioritized experience replay. + /// + /// # Arguments + /// + /// * `per_config` - The new PER configuration + /// + /// # Returns + /// + /// The modified configuration + pub fn per_config(mut self, per_config: Option) -> Self { + self.per_config = per_config; + self + } + + /// Loads the configuration from a YAML file. + /// + /// # Arguments + /// + /// * `path` - Path to the configuration file + /// + /// # Returns + /// + /// The loaded configuration + /// + /// # Errors + /// + /// Returns an error if the file cannot be read or parsed + pub fn load(path: impl AsRef) -> Result { + let file = File::open(path)?; + let rdr = BufReader::new(file); + let b = serde_yaml::from_reader(rdr)?; + Ok(b) + } + + /// Saves the configuration to a YAML file. + /// + /// # Arguments + /// + /// * `path` - Path where the configuration should be saved + /// + /// # Returns + /// + /// `Ok(())` if the configuration was saved successfully + /// + /// # Errors + /// + /// Returns an error if the file cannot be written + pub fn save(&self, path: impl AsRef) -> Result<()> { + let mut file = File::create(path)?; + file.write_all(serde_yaml::to_string(&self)?.as_bytes())?; + Ok(()) + } +} diff --git a/border-generic-replay-buffer/src/iw_scheduler.rs b/border-generic-replay-buffer/src/iw_scheduler.rs new file mode 100644 index 00000000..13c02ce0 --- /dev/null +++ b/border-generic-replay-buffer/src/iw_scheduler.rs @@ -0,0 +1,46 @@ +//! Scheduling the exponent of importance weight for PER. +use serde::{Deserialize, Serialize}; + +/// Scheduler of the exponent of importance weight for PER. +#[derive(Deserialize, Serialize, Clone, Debug, PartialEq)] +pub struct IwScheduler { + /// Initial value of $\beta$. + pub beta_0: f32, + + /// Final value of $\beta$. + pub beta_final: f32, + + /// Optimization steps when beta reaches its final value. + pub n_opts_final: usize, + + /// Current optimizatioin steps. + pub n_opts: usize, +} + +impl IwScheduler { + /// Creates a scheduler. + pub fn new(beta_0: f32, beta_final: f32, n_opts_final: usize) -> Self { + Self { + beta_0, + beta_final, + n_opts_final, + n_opts: 0, + } + } + + /// Gets the exponents of importance sampling weight. + pub fn beta(&self) -> f32 { + let n_opts = self.n_opts; + if n_opts >= self.n_opts_final { + self.beta_final + } else { + let d = self.beta_final - self.beta_0; + self.beta_0 + d * (n_opts as f32 / self.n_opts_final as f32) + } + } + + /// Add optimization steps for scheduling beta through training. + pub fn add_n_opts(&mut self) { + self.n_opts += 1; + } +} diff --git a/border-generic-replay-buffer/src/lib.rs b/border-generic-replay-buffer/src/lib.rs new file mode 100644 index 00000000..e9d567b2 --- /dev/null +++ b/border-generic-replay-buffer/src/lib.rs @@ -0,0 +1,34 @@ +//! Generic implementation of replay buffers for reinforcement learning. +//! +//! This module provides a flexible implementation of replay buffers +//! that can handle arbitrary observation and action types. It supports both +//! standard experience replay and prioritized experience replay (PER). +//! +//! # Key Components +//! +//! - [`GenericReplayBuffer`]: A generic replay buffer implementation +//! - [`GenericTransitionBatch`]: A generic batch structure for transitions +//! - [`SimpleStepProcessor`]: A processor for converting environment steps to transitions +//! - [`PerConfig`]: Configuration for prioritized experience replay +//! +//! # Features +//! +//! - Generic type support for observations and actions +//! - Efficient batch processing +//! - Prioritized experience replay with importance sampling +//! - Configurable weight normalization +//! - Step processing for non-vectorized environments + +// mod base; +mod batch; +mod iw_scheduler; +mod sum_tree; +mod config; +mod step_proc; +mod replay_buffer; +pub use sum_tree::WeightNormalizer; +pub use iw_scheduler::IwScheduler; +pub use batch::{BatchBase, GenericTransitionBatch}; +pub use config::{PerConfig, GenericReplayBufferConfig}; +pub use step_proc::{SimpleStepProcessor, SimpleStepProcessorConfig}; +pub use replay_buffer::GenericReplayBuffer; diff --git a/border-generic-replay-buffer/src/replay_buffer.rs b/border-generic-replay-buffer/src/replay_buffer.rs new file mode 100644 index 00000000..6a434291 --- /dev/null +++ b/border-generic-replay-buffer/src/replay_buffer.rs @@ -0,0 +1,427 @@ +//! Generic implementation of replay buffers for reinforcement learning. +//! +//! This module provides a generic implementation of replay buffers that can store +//! and sample transitions of arbitrary observation and action types. It supports: +//! - Standard experience replay +//! - Prioritized experience replay (PER) +//! - Importance sampling weights for off-policy learning + +use super::{ + config::{GenericReplayBufferConfig, PerConfig}, + BatchBase, GenericTransitionBatch, +}; +use crate::iw_scheduler::IwScheduler; +use crate::sum_tree::SumTree; +use anyhow::Result; +use border_core::{ExperienceBuffer, ReplayBuffer, TransitionBatch}; +use rand::{rngs::StdRng, RngCore, SeedableRng}; + +/// State management for Prioritized Experience Replay (PER). +/// +/// This struct maintains the necessary state for PER, including: +/// - A sum tree for efficient priority sampling +/// - An importance weight scheduler for adjusting sample weights +struct PerState { + /// A sum tree data structure for efficient priority sampling. + sum_tree: SumTree, + + /// Scheduler for importance sampling weights. + iw_scheduler: IwScheduler, +} + +impl PerState { + /// Creates a new PER state with the given configuration. + /// + /// # Arguments + /// + /// * `capacity` - Maximum number of transitions to store + /// * `per_config` - Configuration for prioritized experience replay + fn new(capacity: usize, per_config: &PerConfig) -> Self { + Self { + sum_tree: SumTree::new(capacity, per_config.alpha, per_config.normalize), + iw_scheduler: IwScheduler::new( + per_config.beta_0, + per_config.beta_final, + per_config.n_opts_final, + ), + } + } +} + +/// A generic implementation of a replay buffer for reinforcement learning. +/// +/// This buffer can store transitions of arbitrary observation and action types, +/// making it suitable for a wide range of reinforcement learning tasks. It supports: +/// - Standard experience replay +/// - Prioritized experience replay (optional) +/// - Efficient sampling and storage +/// +/// # Type Parameters +/// +/// * `O` - The type of observations, must implement [`BatchBase`] +/// * `A` - The type of actions, must implement [`BatchBase`] +/// +/// # Examples +/// +/// ```ignore +/// let config = GenericReplayBufferConfig { +/// capacity: 10000, +/// per_config: Some(PerConfig { +/// alpha: 0.6, +/// beta_0: 0.4, +/// beta_final: 1.0, +/// n_opts_final: 100000, +/// normalize: true, +/// }), +/// }; +/// +/// let mut buffer = GenericReplayBuffer::::build(&config); +/// +/// // Add transitions +/// buffer.push(transition)?; +/// +/// // Sample a batch +/// let batch = buffer.batch(32)?; +/// ``` +pub struct GenericReplayBuffer +where + O: BatchBase, + A: BatchBase, +{ + /// Maximum number of transitions that can be stored. + capacity: usize, + + /// Current insertion index. + i: usize, + + /// Current number of stored transitions. + size: usize, + + /// Storage for observations. + obs: O, + + /// Storage for actions. + act: A, + + /// Storage for next observations. + next_obs: O, + + /// Storage for rewards. + reward: Vec, + + /// Storage for termination flags. + is_terminated: Vec, + + /// Storage for truncation flags. + is_truncated: Vec, + + /// Random number generator for sampling. + rng: StdRng, + + /// State for prioritized experience replay, if enabled. + per_state: Option, +} + +impl GenericReplayBuffer +where + O: BatchBase, + A: BatchBase, +{ + /// Pushes rewards into the buffer at the specified index. + /// + /// # Arguments + /// + /// * `i` - Starting index for insertion + /// * `b` - Vector of rewards to insert + #[inline] + fn push_reward(&mut self, i: usize, b: &Vec) { + let mut j = i; + for r in b.iter() { + self.reward[j] = *r; + j += 1; + if j == self.capacity { + j = 0; + } + } + } + + /// Pushes termination flags into the buffer at the specified index. + /// + /// # Arguments + /// + /// * `i` - Starting index for insertion + /// * `b` - Vector of termination flags to insert + #[inline] + fn push_is_terminated(&mut self, i: usize, b: &Vec) { + let mut j = i; + for d in b.iter() { + self.is_terminated[j] = *d; + j += 1; + if j == self.capacity { + j = 0; + } + } + } + + /// Pushes truncation flags into the buffer at the specified index. + /// + /// # Arguments + /// + /// * `i` - Starting index for insertion + /// * `b` - Vector of truncation flags to insert + fn push_is_truncated(&mut self, i: usize, b: &Vec) { + let mut j = i; + for d in b.iter() { + self.is_truncated[j] = *d; + j += 1; + if j == self.capacity { + j = 0; + } + } + } + + /// Samples rewards for the given indices. + /// + /// # Arguments + /// + /// * `ixs` - Indices to sample from + /// + /// # Returns + /// + /// Vector of sampled rewards + fn sample_reward(&self, ixs: &Vec) -> Vec { + ixs.iter().map(|ix| self.reward[*ix]).collect() + } + + /// Samples termination flags for the given indices. + /// + /// # Arguments + /// + /// * `ixs` - Indices to sample from + /// + /// # Returns + /// + /// Vector of sampled termination flags + fn sample_is_terminated(&self, ixs: &Vec) -> Vec { + ixs.iter().map(|ix| self.is_terminated[*ix]).collect() + } + + /// Samples truncation flags for the given indices. + /// + /// # Arguments + /// + /// * `ixs` - Indices to sample from + /// + /// # Returns + /// + /// Vector of sampled truncation flags + fn sample_is_truncated(&self, ixs: &Vec) -> Vec { + ixs.iter().map(|ix| self.is_truncated[*ix]).collect() + } + + /// Sets priorities for newly added samples in prioritized experience replay. + /// + /// # Arguments + /// + /// * `batch_size` - Number of new samples to prioritize + fn set_priority(&mut self, batch_size: usize) { + let sum_tree = &mut self.per_state.as_mut().unwrap().sum_tree; + let max_p = sum_tree.max(); + + for j in 0..batch_size { + let i = (self.i + j) % self.capacity; + sum_tree.add(i, max_p); + } + } + + /// Returns a batch containing all actions in the buffer. + /// + /// # Warning + /// + /// This method should be used with caution on large replay buffers + /// as it may consume significant memory. + pub fn whole_actions(&self) -> A { + let ixs = (0..self.size).collect::>(); + self.act.sample(&ixs) + } + + /// Returns the number of terminated episodes in the buffer. + pub fn num_terminated_flags(&self) -> usize { + self.is_terminated + .iter() + .map(|is_terminated| *is_terminated as usize) + .sum() + } + + /// Returns the number of truncated episodes in the buffer. + pub fn num_truncated_flags(&self) -> usize { + self.is_truncated + .iter() + .map(|is_truncated| *is_truncated as usize) + .sum() + } + + /// Returns the sum of all rewards in the buffer. + pub fn sum_rewards(&self) -> f32 { + self.reward.iter().sum() + } +} + +impl ExperienceBuffer for GenericReplayBuffer +where + O: BatchBase, + A: BatchBase, +{ + type Item = GenericTransitionBatch; + + /// Returns the current number of transitions in the buffer. + fn len(&self) -> usize { + self.size + } + + /// Adds a new transition to the buffer. + /// + /// # Arguments + /// + /// * `tr` - The transition to add + /// + /// # Returns + /// + /// `Ok(())` if the transition was added successfully + /// + /// # Errors + /// + /// Returns an error if the buffer is full and cannot accept more transitions + fn push(&mut self, tr: Self::Item) -> Result<()> { + let len = tr.len(); // batch size + let (obs, act, next_obs, reward, is_terminated, is_truncated, _, _) = tr.unpack(); + self.obs.push(self.i, obs); + self.act.push(self.i, act); + self.next_obs.push(self.i, next_obs); + self.push_reward(self.i, &reward); + self.push_is_terminated(self.i, &is_terminated); + self.push_is_truncated(self.i, &is_truncated); + + if self.per_state.is_some() { + self.set_priority(len) + }; + + self.i = (self.i + len) % self.capacity; + self.size += len; + if self.size >= self.capacity { + self.size = self.capacity; + } + + Ok(()) + } +} + +impl ReplayBuffer for GenericReplayBuffer +where + O: BatchBase, + A: BatchBase, +{ + type Config = GenericReplayBufferConfig; + type Batch = GenericTransitionBatch; + + /// Creates a new replay buffer with the given configuration. + /// + /// # Arguments + /// + /// * `config` - Configuration for the replay buffer + /// + /// # Returns + /// + /// A new instance of the replay buffer + fn build(config: &Self::Config) -> Self { + let capacity = config.capacity; + let per_state = match &config.per_config { + Some(per_config) => Some(PerState::new(capacity, per_config)), + None => None, + }; + + Self { + capacity, + i: 0, + size: 0, + obs: O::new(capacity), + act: A::new(capacity), + next_obs: O::new(capacity), + reward: vec![0.; capacity], + is_terminated: vec![0; capacity], + is_truncated: vec![0; capacity], + rng: StdRng::seed_from_u64(config.seed as _), + per_state, + } + } + + /// Samples a batch of transitions from the buffer. + /// + /// If prioritized experience replay is enabled, samples are selected + /// according to their priorities. Otherwise, uniform random sampling is used. + /// + /// # Arguments + /// + /// * `size` - Number of transitions to sample + /// + /// # Returns + /// + /// A batch of sampled transitions + /// + /// # Errors + /// + /// Returns an error if: + /// - The buffer is empty + /// - The requested batch size is larger than the buffer size + fn batch(&mut self, size: usize) -> Result { + let (ixs, weight) = if let Some(per_state) = &self.per_state { + let sum_tree = &per_state.sum_tree; + let beta = per_state.iw_scheduler.beta(); + let (ixs, weight) = sum_tree.sample(size, beta); + let ixs = ixs.iter().map(|&ix| ix as usize).collect(); + (ixs, Some(weight)) + } else { + let ixs = (0..size) + // .map(|_| self.rng.usize(..self.size)) + .map(|_| (self.rng.next_u32() as usize) % self.size) + .collect::>(); + let weight = None; + (ixs, weight) + }; + + Ok(Self::Batch { + obs: self.obs.sample(&ixs), + act: self.act.sample(&ixs), + next_obs: self.next_obs.sample(&ixs), + reward: self.sample_reward(&ixs), + is_terminated: self.sample_is_terminated(&ixs), + is_truncated: self.sample_is_truncated(&ixs), + ix_sample: Some(ixs), + weight, + }) + } + + /// Updates the priorities of transitions in the buffer. + /// + /// This method is used in prioritized experience replay to adjust + /// the sampling probabilities based on TD errors. + /// + /// # Arguments + /// + /// * `ixs` - Optional indices of transitions to update + /// * `td_errs` - Optional TD errors for the transitions + fn update_priority(&mut self, ixs: &Option>, td_errs: &Option>) { + if let Some(per_state) = &mut self.per_state { + let ixs = ixs + .as_ref() + .expect("ixs should be Some(_) in update_priority()."); + let td_errs = td_errs + .as_ref() + .expect("td_errs should be Some(_) in update_priority()."); + for (&ix, &td_err) in ixs.iter().zip(td_errs.iter()) { + per_state.sum_tree.update(ix, td_err); + } + per_state.iw_scheduler.add_n_opts(); + } + } +} diff --git a/border-generic-replay-buffer/src/step_proc.rs b/border-generic-replay-buffer/src/step_proc.rs new file mode 100644 index 00000000..eb83e567 --- /dev/null +++ b/border-generic-replay-buffer/src/step_proc.rs @@ -0,0 +1,138 @@ +//! Generic implementation of step processing for reinforcement learning. +//! +//! This module provides a generic implementation of the `StepProcessor` trait, +//! which handles the conversion of environment steps into transitions suitable +//! for training. It supports: +//! - 1-step TD backup for non-vectorized environments +//! - Generic observation and action types +//! - Efficient batch processing + +use super::{BatchBase, GenericTransitionBatch}; +use border_core::{Env, Obs, StepProcessor, Step}; +use std::{default::Default, marker::PhantomData}; + +/// Configuration for the simple step processor. +#[derive(Clone, Debug)] +pub struct SimpleStepProcessorConfig {} + +impl Default for SimpleStepProcessorConfig { + /// Creates a new default configuration. + fn default() -> Self { + Self {} + } +} + +/// A generic implementation of the `StepProcessor` trait. +/// +/// This processor converts environment steps into transitions suitable for +/// training reinforcement learning agents. It supports 1-step TD backup +/// for non-vectorized environments, meaning that each step contains exactly +/// one observation. +/// +/// # Type Parameters +/// +/// * `E` - The environment type, must implement `Env` +/// * `O` - The observation batch type, must implement `BatchBase` and `From` +/// * `A` - The action batch type, must implement `BatchBase` and `From` +pub struct SimpleStepProcessor { + /// The previous observation, used to construct transitions. + prev_obs: Option, + /// Phantom data to hold the generic type parameters. + phantom: PhantomData<(E, A)>, +} + +impl StepProcessor for SimpleStepProcessor +where + E: Env, + O: BatchBase + From, + A: BatchBase + From, +{ + type Config = SimpleStepProcessorConfig; + type Output = GenericTransitionBatch; + + /// Creates a new step processor with the given configuration. + /// + /// # Arguments + /// + /// * `_config` - The configuration for the processor + /// + /// # Returns + /// + /// A new instance of the step processor + fn build(_config: &Self::Config) -> Self { + Self { + prev_obs: None, + phantom: PhantomData, + } + } + + /// Resets the processor with an initial observation. + /// + /// This method must be called before processing any steps to initialize + /// the processor with the starting state of the environment. + /// + /// # Arguments + /// + /// * `init_obs` - The initial observation from the environment + fn reset(&mut self, init_obs: E::Obs) { + self.prev_obs = Some(init_obs.into()); + } + + /// Processes a step from the environment into a transition. + /// + /// This method converts an environment step into a transition suitable + /// for training. It handles: + /// - Converting observations and actions to the appropriate batch types + /// - Managing the previous observation for constructing transitions + /// - Handling episode termination and truncation + /// + /// # Arguments + /// + /// * `step` - The step to process + /// + /// # Returns + /// + /// A transition batch containing the processed step + /// + /// # Panics + /// + /// This method will panic if: + /// - The step contains more than one observation + /// - `reset()` has not been called before processing steps + /// - The step is terminal but does not contain an initial observation + fn process(&mut self, step: Step) -> Self::Output { + assert_eq!(step.obs.len(), 1); + + let batch = if self.prev_obs.is_none() { + panic!("prev_obs is not set. Forgot to call reset()?"); + } else { + let is_done = step.is_done(); + let next_obs = step.obs.clone().into(); + let obs = self.prev_obs.replace(step.obs.into()).unwrap(); + let act = step.act.into(); + let reward = step.reward; + let is_terminated = step.is_terminated; + let is_truncated = step.is_truncated; + let ix_sample = None; + let weight = None; + + if is_done { + self.prev_obs + .replace(step.init_obs.expect("Failed to unwrap init_obs").into()); + } + + GenericTransitionBatch { + obs, + act, + next_obs, + reward, + is_terminated, + is_truncated, + ix_sample, + weight, + } + }; + + batch + } +} diff --git a/border-generic-replay-buffer/src/sum_tree.rs b/border-generic-replay-buffer/src/sum_tree.rs new file mode 100644 index 00000000..39afaf4f --- /dev/null +++ b/border-generic-replay-buffer/src/sum_tree.rs @@ -0,0 +1,217 @@ +//! Sum tree for prioritized sampling. +//! +//! Code is adapted from and +/// +use segment_tree::{ + ops::{MaxIgnoreNaN, MinIgnoreNaN}, + SegmentPoint, +}; +use serde::{Deserialize, Serialize}; + +#[derive(Copy, Debug, Clone, Deserialize, Serialize, PartialEq)] +/// Specifies how to normalize the importance weights in a prioritized batch. +pub enum WeightNormalizer { + /// Normalize weights by the maximum weight of all samples in the buffer. + All, + /// Normalize weights by the maximum weight of samples in the batch. + Batch, +} + +#[derive(Debug)] +pub struct SumTree { + eps: f32, + alpha: f32, + capacity: usize, + n_samples: usize, + tree: Vec, + min_tree: SegmentPoint, + max_tree: SegmentPoint, + normalize: WeightNormalizer, +} + +impl SumTree { + pub fn new(capacity: usize, alpha: f32, normalize: WeightNormalizer) -> Self { + Self { + eps: 1e-8, + alpha, + capacity, + n_samples: 0, + tree: vec![0f32; 2 * capacity - 1], + min_tree: SegmentPoint::build(vec![f32::MAX; capacity], MinIgnoreNaN), + max_tree: SegmentPoint::build(vec![1e-8f32; capacity], MaxIgnoreNaN), + normalize, + } + } + + fn propagate(&mut self, ix: usize, change: f32) { + let parent = (ix - 1) / 2; + self.tree[parent] += change; + if parent != 0 { + self.propagate(parent, change); + } + } + + fn retrieve(&self, ix: usize, s: f32) -> usize { + let left = 2 * ix + 1; + let right = left + 1; + + if left >= self.tree.len() { + return ix; + } + + if s <= self.tree[left] || self.tree[right] == 0f32 { + return self.retrieve(left, s); + } else { + return self.retrieve(right, s - self.tree[left]); + } + } + + pub fn total(&self) -> f32 { + return self.tree[0]; + } + + pub fn max(&self) -> f32 { + self.max_tree + .query(0, self.max_tree.len()) + .powf(1.0 / self.alpha) + } + + /// Add priority value at `ix`-th element in the sum tree. + /// + /// The alpha-th power of the priority value is taken when addition. + pub fn add(&mut self, ix: usize, p: f32) { + debug_assert!(ix <= self.n_samples); + + self.update(ix, p); + + if self.n_samples < self.capacity { + self.n_samples += 1; + } + } + + /// Update priority value at `ix`-th element in the sum tree. + pub fn update(&mut self, ix: usize, p: f32) { + debug_assert!(ix < self.capacity); + + let p = (p + self.eps).powf(self.alpha); + self.min_tree.modify(ix, p); + self.max_tree.modify(ix, p); + let ix = ix + self.capacity - 1; + let change = p - self.tree[ix]; + if change.is_nan() { + println!("{:?}, {:?}", p, self.tree[ix]); + panic!(); + } + self.tree[ix] = p; + self.propagate(ix, change); + } + + /// Get the maximal index of the sum tree where the sum of priority values is less than `s`. + pub fn get(&self, s: f32) -> usize { + let ix = self.retrieve(0, s); + debug_assert!(ix >= (self.capacity - 1)); + ix + 1 - self.capacity + } + + /// Samples indices for batch and returns normalized weights. + /// + /// The weight is $w_i=\left(N^{-1}P(i)^{-1}\right)^{\beta}$ + /// and it will be normalized by $max_i w_i$. + pub fn sample(&self, batch_size: usize, beta: f32) -> (Vec, Vec) { + let p_sum = &self.total(); + let ps = (0..batch_size) + .map(|_| p_sum * fastrand::f32()) + .collect::>(); + let indices = ps.iter().map(|&p| self.get(p)).collect::>(); + // let indices = (0..batch_size) + // .map(|_| self.get(p_sum * fastrand::f32())) + // .collect::>(); + + let n = self.n_samples as f32 / p_sum; + let ws = indices + .iter() + .map(|ix| self.tree[ix + self.capacity - 1]) + .map(|p| (n * p).powf(-beta)) + .collect::>(); + + // normalizer within all samples + let w_max_inv = match self.normalize { + WeightNormalizer::All => (n * self.min_tree.query(0, self.n_samples)).powf(beta), + WeightNormalizer::Batch => 1f32 / ws.iter().fold(0.0 / 0.0, |m, v| v.max(m)), + }; + let ws = ws.iter().map(|w| w * w_max_inv).collect::>(); + + if p_sum.is_nan() || w_max_inv.is_nan() || ws.iter().sum::().is_nan() { + println!("self.n_samples: {:?}", self.n_samples); + println!("p_sum: {:?}", p_sum); + println!("w_max_inv: {:?}", w_max_inv); + println!("ps: {:?}", ps); + println!("indices: {:?}", indices); + println!("{:?}", ws); + panic!(); + } + + let ixs = indices.iter().map(|&ix| ix as i64).collect(); + + (ixs, ws) + } + + #[allow(dead_code)] + pub fn print_tree(&self) { + let mut nl = 1; + + for i in 0..self.tree.len() { + print!("{} ", self.tree[i]); + if i == 2 * nl - 2 { + println!(); + nl *= 2; + } + } + println!("max = {}", self.max()); + // println!("min = {}", self.min()); + println!("total = {}", self.total()); + } +} + +#[cfg(test)] +mod tests { + use super::{SumTree, WeightNormalizer::Batch}; + + #[test] + fn test_sum_tree_odd() { + let data = vec![0.5f32, 0.2, 0.8, 0.3, 1.1, 2.5, 3.9]; + let mut sum_tree = SumTree::new(8, 1.0, Batch); + for ix in 0..data.len() { + sum_tree.add(ix, data[ix]); + } + sum_tree.print_tree(); + println!(); + + assert_eq!(sum_tree.get(0.0), 0); + assert_eq!(sum_tree.get(0.4), 0); + assert_eq!(sum_tree.get(0.5), 0); + assert_eq!(sum_tree.get(0.6), 1); + assert_eq!(sum_tree.get(1.2), 2); + assert_eq!(sum_tree.get(1.6), 3); + assert_eq!(sum_tree.get(2.0), 4); + assert_eq!(sum_tree.get(2.8), 4); + + sum_tree.update(7, 2.0); + sum_tree.print_tree(); + println!(); + + // let (ixs, ws) = sum_tree.sample(10, 1.0); + // println!("{:?}", ixs); + // println!("{:?}", ws); + // println!(); + + // let n_samples = 1000000; + // let (ixs, _) = sum_tree.sample(n_samples, 1.0); + // debug_assert!(ixs.iter().all(|&ix| ix < data.len() as i64)); + // (0..5).for_each(|ix| { + // let p = data[ix] / sum_tree.total() * (n_samples as f32); + // let n = ixs.iter().filter(|&&e| e == ix as i64).collect::>().len(); + // println!("ix={:?}: {:?} (p={:?})", ix, n, p); + // }) + } +} From 146df2471a64bab0bd758af453ffe7ede43e6696 Mon Sep 17 00:00:00 2001 From: taku-y Date: Sun, 10 Aug 2025 10:05:04 +0000 Subject: [PATCH 05/23] WIP: remove the original module --- border-core/src/lib.rs | 213 ----------------- border-generic-replay-buffer/src/config.rs | 2 +- border-generic-replay-buffer/src/lib.rs | 226 +++++++++++++++++- border-generic-replay-buffer/src/step_proc.rs | 2 +- 4 files changed, 221 insertions(+), 222 deletions(-) diff --git a/border-core/src/lib.rs b/border-core/src/lib.rs index 098b68c7..e8a3058c 100644 --- a/border-core/src/lib.rs +++ b/border-core/src/lib.rs @@ -103,7 +103,6 @@ pub mod dummy; pub mod error; mod evaluator; -pub mod generic_replay_buffer; pub mod record; mod base; @@ -115,215 +114,3 @@ pub use base::{ mod trainer; pub use evaluator::{DefaultEvaluator, Evaluator}; pub use trainer::{Sampler, Trainer, TrainerConfig}; - -// TODO: Consider to compile this module only for tests. -/// Agent and Env for testing. -pub mod test { - use serde::{Deserialize, Serialize}; - - /// Obs for testing. - #[derive(Clone, Debug)] - pub struct TestObs { - obs: usize, - } - - impl crate::Obs for TestObs { - fn len(&self) -> usize { - 1 - } - } - - /// Batch of obs for testing. - pub struct TestObsBatch { - obs: Vec, - } - - impl crate::generic_replay_buffer::BatchBase for TestObsBatch { - fn new(capacity: usize) -> Self { - Self { - obs: vec![0; capacity], - } - } - - fn push(&mut self, i: usize, data: Self) { - self.obs[i] = data.obs[0]; - } - - fn sample(&self, ixs: &Vec) -> Self { - let obs = ixs.iter().map(|ix| self.obs[*ix]).collect(); - Self { obs } - } - } - - impl From for TestObsBatch { - fn from(obs: TestObs) -> Self { - Self { obs: vec![obs.obs] } - } - } - - /// Act for testing. - #[derive(Clone, Debug)] - pub struct TestAct { - act: usize, - } - - impl crate::Act for TestAct {} - - /// Batch of act for testing. - pub struct TestActBatch { - act: Vec, - } - - impl From for TestActBatch { - fn from(act: TestAct) -> Self { - Self { act: vec![act.act] } - } - } - - impl crate::generic_replay_buffer::BatchBase for TestActBatch { - fn new(capacity: usize) -> Self { - Self { - act: vec![0; capacity], - } - } - - fn push(&mut self, i: usize, data: Self) { - self.act[i] = data.act[0]; - } - - fn sample(&self, ixs: &Vec) -> Self { - let act = ixs.iter().map(|ix| self.act[*ix]).collect(); - Self { act } - } - } - - /// Info for testing. - pub struct TestInfo {} - - impl crate::Info for TestInfo {} - - /// Environment for testing. - pub struct TestEnv { - state_init: usize, - state: usize, - } - - impl crate::Env for TestEnv { - type Config = usize; - type Obs = TestObs; - type Act = TestAct; - type Info = TestInfo; - - fn reset(&mut self, _is_done: Option<&Vec>) -> anyhow::Result { - self.state = self.state_init; - Ok(TestObs { obs: self.state }) - } - - fn reset_with_index(&mut self, _ix: usize) -> anyhow::Result { - self.state = self.state_init; - Ok(TestObs { obs: self.state }) - } - - fn step_with_reset(&mut self, a: &Self::Act) -> (crate::Step, crate::record::Record) - where - Self: Sized, - { - self.state = self.state + a.act; - let step = crate::Step { - obs: TestObs { obs: self.state }, - act: a.clone(), - reward: vec![0.0], - is_terminated: vec![0], - is_truncated: vec![0], - info: TestInfo {}, - init_obs: Some(TestObs { - obs: self.state_init, - }), - }; - return (step, crate::record::Record::empty()); - } - - fn step(&mut self, a: &Self::Act) -> (crate::Step, crate::record::Record) - where - Self: Sized, - { - self.state = self.state + a.act; - let step = crate::Step { - obs: TestObs { obs: self.state }, - act: a.clone(), - reward: vec![0.0], - is_terminated: vec![0], - is_truncated: vec![0], - info: TestInfo {}, - init_obs: Some(TestObs { - obs: self.state_init, - }), - }; - return (step, crate::record::Record::empty()); - } - - fn build(config: &Self::Config, _seed: i64) -> anyhow::Result - where - Self: Sized, - { - Ok(Self { - state_init: *config, - state: 0, - }) - } - } - - type ReplayBuffer = - crate::generic_replay_buffer::SimpleReplayBuffer; - - /// Agent for testing. - pub struct TestAgent {} - - #[derive(Clone, Deserialize, Serialize)] - /// Config of agent for testing. - pub struct TestAgentConfig; - - impl crate::Agent for TestAgent { - fn train(&mut self) {} - - fn is_train(&self) -> bool { - false - } - - fn eval(&mut self) {} - - fn opt_with_record(&mut self, _buffer: &mut ReplayBuffer) -> crate::record::Record { - crate::record::Record::empty() - } - - fn save_params(&self, _path: &std::path::Path) -> anyhow::Result> { - Ok(vec![]) - } - - fn load_params(&mut self, _path: &std::path::Path) -> anyhow::Result<()> { - Ok(()) - } - - fn as_any_ref(&self) -> &dyn std::any::Any { - self - } - - fn as_any_mut(&mut self) -> &mut dyn std::any::Any { - self - } - } - - impl crate::Policy for TestAgent { - fn sample(&mut self, _obs: &TestObs) -> TestAct { - TestAct { act: 1 } - } - } - - impl crate::Configurable for TestAgent { - type Config = TestAgentConfig; - - fn build(_config: Self::Config) -> Self { - Self {} - } - } -} diff --git a/border-generic-replay-buffer/src/config.rs b/border-generic-replay-buffer/src/config.rs index ffad05d5..8e41614a 100644 --- a/border-generic-replay-buffer/src/config.rs +++ b/border-generic-replay-buffer/src/config.rs @@ -32,7 +32,7 @@ use std::{ /// # Examples /// /// ```rust -/// use border_core::generic_replay_buffer::{PerConfig, WeightNormalizer}; +/// use border_generic_replay_buffer::{PerConfig, WeightNormalizer}; /// /// let config = PerConfig::default() /// .alpha(0.6) diff --git a/border-generic-replay-buffer/src/lib.rs b/border-generic-replay-buffer/src/lib.rs index e9d567b2..c4ada74b 100644 --- a/border-generic-replay-buffer/src/lib.rs +++ b/border-generic-replay-buffer/src/lib.rs @@ -21,14 +21,226 @@ // mod base; mod batch; -mod iw_scheduler; -mod sum_tree; mod config; -mod step_proc; +mod iw_scheduler; mod replay_buffer; -pub use sum_tree::WeightNormalizer; -pub use iw_scheduler::IwScheduler; +mod step_proc; +mod sum_tree; pub use batch::{BatchBase, GenericTransitionBatch}; -pub use config::{PerConfig, GenericReplayBufferConfig}; -pub use step_proc::{SimpleStepProcessor, SimpleStepProcessorConfig}; +pub use config::{GenericReplayBufferConfig, PerConfig}; +pub use iw_scheduler::IwScheduler; pub use replay_buffer::GenericReplayBuffer; +pub use step_proc::{SimpleStepProcessor, SimpleStepProcessorConfig}; +pub use sum_tree::WeightNormalizer; + +// TODO: Consider to compile this module only for tests. +/// Agent and Env for testing. +pub mod test { + use border_core::{record::Record, Act, Agent, Configurable, Env, Info, Obs, Policy, Step}; + use serde::{Deserialize, Serialize}; + + /// Obs for testing. + #[derive(Clone, Debug)] + pub struct TestObs { + obs: usize, + } + + impl Obs for TestObs { + fn len(&self) -> usize { + 1 + } + } + + /// Batch of obs for testing. + pub struct TestObsBatch { + obs: Vec, + } + + impl crate::BatchBase for TestObsBatch { + fn new(capacity: usize) -> Self { + Self { + obs: vec![0; capacity], + } + } + + fn push(&mut self, i: usize, data: Self) { + self.obs[i] = data.obs[0]; + } + + fn sample(&self, ixs: &Vec) -> Self { + let obs = ixs.iter().map(|ix| self.obs[*ix]).collect(); + Self { obs } + } + } + + impl From for TestObsBatch { + fn from(obs: TestObs) -> Self { + Self { obs: vec![obs.obs] } + } + } + + /// Act for testing. + #[derive(Clone, Debug)] + pub struct TestAct { + act: usize, + } + + impl Act for TestAct {} + + /// Batch of act for testing. + pub struct TestActBatch { + act: Vec, + } + + impl From for TestActBatch { + fn from(act: TestAct) -> Self { + Self { act: vec![act.act] } + } + } + + impl crate::BatchBase for TestActBatch { + fn new(capacity: usize) -> Self { + Self { + act: vec![0; capacity], + } + } + + fn push(&mut self, i: usize, data: Self) { + self.act[i] = data.act[0]; + } + + fn sample(&self, ixs: &Vec) -> Self { + let act = ixs.iter().map(|ix| self.act[*ix]).collect(); + Self { act } + } + } + + /// Info for testing. + pub struct TestInfo {} + + impl Info for TestInfo {} + + /// Environment for testing. + pub struct TestEnv { + state_init: usize, + state: usize, + } + + impl Env for TestEnv { + type Config = usize; + type Obs = TestObs; + type Act = TestAct; + type Info = TestInfo; + + fn reset(&mut self, _is_done: Option<&Vec>) -> anyhow::Result { + self.state = self.state_init; + Ok(TestObs { obs: self.state }) + } + + fn reset_with_index(&mut self, _ix: usize) -> anyhow::Result { + self.state = self.state_init; + Ok(TestObs { obs: self.state }) + } + + fn step_with_reset(&mut self, a: &Self::Act) -> (Step, Record) + where + Self: Sized, + { + self.state = self.state + a.act; + let step = Step { + obs: TestObs { obs: self.state }, + act: a.clone(), + reward: vec![0.0], + is_terminated: vec![0], + is_truncated: vec![0], + info: TestInfo {}, + init_obs: Some(TestObs { + obs: self.state_init, + }), + }; + return (step, Record::empty()); + } + + fn step(&mut self, a: &TestAct) -> (Step, Record) + where + Self: Sized, + { + self.state = self.state + a.act; + let step = Step { + obs: TestObs { obs: self.state }, + act: a.clone(), + reward: vec![0.0], + is_terminated: vec![0], + is_truncated: vec![0], + info: TestInfo {}, + init_obs: Some(TestObs { + obs: self.state_init, + }), + }; + return (step, Record::empty()); + } + + fn build(config: &Self::Config, _seed: i64) -> anyhow::Result + where + Self: Sized, + { + Ok(Self { + state_init: *config, + state: 0, + }) + } + } + + type ReplayBuffer = crate::GenericReplayBuffer; + + /// Agent for testing. + pub struct TestAgent {} + + #[derive(Clone, Deserialize, Serialize)] + /// Config of agent for testing. + pub struct TestAgentConfig; + + impl Agent for TestAgent { + fn train(&mut self) {} + + fn is_train(&self) -> bool { + false + } + + fn eval(&mut self) {} + + fn opt_with_record(&mut self, _buffer: &mut ReplayBuffer) -> Record { + Record::empty() + } + + fn save_params(&self, _path: &std::path::Path) -> anyhow::Result> { + Ok(vec![]) + } + + fn load_params(&mut self, _path: &std::path::Path) -> anyhow::Result<()> { + Ok(()) + } + + fn as_any_ref(&self) -> &dyn std::any::Any { + self + } + + fn as_any_mut(&mut self) -> &mut dyn std::any::Any { + self + } + } + + impl Policy for TestAgent { + fn sample(&mut self, _obs: &TestObs) -> TestAct { + TestAct { act: 1 } + } + } + + impl Configurable for TestAgent { + type Config = TestAgentConfig; + + fn build(_config: Self::Config) -> Self { + Self {} + } + } +} diff --git a/border-generic-replay-buffer/src/step_proc.rs b/border-generic-replay-buffer/src/step_proc.rs index eb83e567..e6ff6eb8 100644 --- a/border-generic-replay-buffer/src/step_proc.rs +++ b/border-generic-replay-buffer/src/step_proc.rs @@ -8,7 +8,7 @@ //! - Efficient batch processing use super::{BatchBase, GenericTransitionBatch}; -use border_core::{Env, Obs, StepProcessor, Step}; +use border_core::{Env, Obs, Step, StepProcessor}; use std::{default::Default, marker::PhantomData}; /// Configuration for the simple step processor. From 6a24b1072b7585b8cb9d1ff3215a190df1cf87cc Mon Sep 17 00:00:00 2001 From: taku-y Date: Sun, 10 Aug 2025 11:24:22 +0000 Subject: [PATCH 06/23] Tweaks on other crates and examples --- README.md | 1 + border-async-trainer/Cargo.toml | 1 + border-async-trainer/src/lib.rs | 20 +++++------ border-atari-env/Cargo.toml | 1 + border-atari-env/src/util/test.rs | 4 +-- border-candle-agent/Cargo.toml | 1 + border-candle-agent/src/tensor_batch.rs | 2 +- border-generic-replay-buffer/src/lib.rs | 1 - border-minari/Cargo.toml | 1 + border-minari/src/converter.rs | 3 +- border-minari/src/d4rl/antmaze/candle.rs | 2 +- border-minari/src/d4rl/antmaze/ndarray.rs | 2 +- border-minari/src/d4rl/kitchen/candle.rs | 2 +- border-minari/src/d4rl/kitchen/ndarray.rs | 2 +- border-minari/src/d4rl/pointmaze/ndarray.rs | 2 +- border-minari/src/dataset.rs | 10 +++--- border-minari/src/util/candle/tensor_batch.rs | 2 +- border-py-gym-env/Cargo.toml | 4 ++- border-py-gym-env/src/candle/tensor_batch.rs | 2 +- border-py-gym-env/src/tch/tensor_batch.rs | 2 +- border-tch-agent/Cargo.toml | 1 + border-tch-agent/src/sac.rs | 14 ++++---- border-tch-agent/src/tensor_batch.rs | 2 +- examples/d4rl/awac_pen/Cargo.toml | 1 + examples/d4rl/awac_pen/src/main.rs | 9 +++-- examples/d4rl/bc_pen/Cargo.toml | 1 + examples/d4rl/bc_pen/src/main.rs | 11 +++---- examples/d4rl/iql_pen/Cargo.toml | 1 + examples/d4rl/iql_pen/src/main.rs | 11 +++---- examples/gym/dqn_cartpole/Cargo.toml | 1 + examples/gym/dqn_cartpole/src/main.rs | 16 ++++----- examples/gym/dqn_cartpole_tch/Cargo.toml | 1 + examples/gym/dqn_cartpole_tch/src/main.rs | 19 +++++------ examples/gym/sac_fetch_reach/Cargo.toml | 1 + examples/gym/sac_fetch_reach/MUJOCO_LOG.TXT | 33 +++++++++++++++++++ examples/gym/sac_fetch_reach/src/main.rs | 17 +++++----- examples/gym/sac_pendulum/Cargo.toml | 1 + examples/gym/sac_pendulum/src/main.rs | 17 +++++----- examples/gym/sac_pendulum_tch/Cargo.toml | 1 + examples/gym/sac_pendulum_tch/src/main.rs | 17 +++++----- 40 files changed, 140 insertions(+), 100 deletions(-) create mode 100644 examples/gym/sac_fetch_reach/MUJOCO_LOG.TXT diff --git a/README.md b/README.md index 1425e022..6d661c82 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,7 @@ Border consists of the following crates: * Core and utility * [border-core](https://crates.io/crates/border-core) ([doc](https://docs.rs/border-core/latest/border_core/)) provides basic traits and functions for environments and reinforcement learning (RL) agents. + * [border-generic-replay-buffer](https://crates.io/crates/border-generic-replay-buffer) ([doc](https://docs.rs/border-generic-replay-buffer/latest/border_generic_replay_buffer/)) provides a generic implementation of replay buffer. * [border-tensorboard](https://crates.io/crates/border-tensorboard) ([doc](https://docs.rs/border-core/latest/border_tensorboard/)) implements the `TensorboardRecorder` struct for writing records that can be visualized in Tensorboard, based on [tensorboard-rs](https://crates.io/crates/tensorboard-rs). * [border-mlflow-tracking](https://crates.io/crates/border-mlflow-tracking) ([doc](https://docs.rs/border-core/latest/border_mlflow_tracking/)) provides MLflow tracking support for logging metrics during training via REST API. * [border-async-trainer](https://crates.io/crates/border-async-trainer) ([doc](https://docs.rs/border-core/latest/border_async_trainer/)) defines traits and functions for asynchronous training of RL agents using multiple actors. Each actor runs a sampling process in parallel, where an agent interacts with an environment to collect samples for a shared replay buffer. diff --git a/border-async-trainer/Cargo.toml b/border-async-trainer/Cargo.toml index edbcbbdc..8eb6386a 100644 --- a/border-async-trainer/Cargo.toml +++ b/border-async-trainer/Cargo.toml @@ -24,4 +24,5 @@ thiserror = { workspace = true } [dev-dependencies] env_logger = { workspace = true } +border-generic-replay-buffer = { version = "0.0.9", path = "../border-generic-replay-buffer" } test-log = "0.2.8" \ No newline at end of file diff --git a/border-async-trainer/src/lib.rs b/border-async-trainer/src/lib.rs index 52a8a7eb..29db9e8c 100644 --- a/border-async-trainer/src/lib.rs +++ b/border-async-trainer/src/lib.rs @@ -4,7 +4,7 @@ //! //! ``` //! # use serde::{Deserialize, Serialize}; -//! # use border_core::test::{ +//! # use border_generic_replay_buffer::test::{ //! # TestAgent, TestAgentConfig, TestEnv, TestObs, TestObsBatch, //! # TestAct, TestActBatch //! # }; @@ -13,11 +13,11 @@ //! # //test::{TestAgent, TestAgentConfig, TestEnv}, //! # ActorManager, ActorManagerConfig, AsyncTrainer, AsyncTrainerConfig, //! # }; +//! # use border_generic_replay_buffer::{ +//! # GenericReplayBuffer, GenericReplayBufferConfig, +//! # SimpleStepProcessorConfig, SimpleStepProcessor +//! # }; //! # use border_core::{ -//! # generic_replay_buffer::{ -//! # SimpleReplayBuffer, SimpleReplayBufferConfig, -//! # SimpleStepProcessorConfig, SimpleStepProcessor -//! # }, //! # record::{Recorder, NullRecorder}, DefaultEvaluator, //! # }; //! # @@ -34,7 +34,7 @@ //! type Env = TestEnv; //! type ObsBatch = TestObsBatch; //! type ActBatch = TestActBatch; -//! type ReplayBuffer = SimpleReplayBuffer; +//! type ReplayBuffer = GenericReplayBuffer; //! type StepProcessor = SimpleStepProcessor; //! //! // Create a new agent by wrapping the existing agent in order to implement SyncModel. @@ -115,7 +115,7 @@ //! let agent_configs: Vec<_> = vec![agent_config()]; //! let env_config_train = env_config(); //! let env_config_eval = env_config(); -//! let replay_buffer_config = SimpleReplayBufferConfig::default(); +//! let replay_buffer_config = GenericReplayBufferConfig::default(); //! let step_proc_config = SimpleStepProcessorConfig::default(); //! let actor_man_config = ActorManagerConfig::default(); //! let async_trainer_config = AsyncTrainerConfig::default(); @@ -198,7 +198,7 @@ pub mod test { obs: Vec, } - impl border_core::generic_replay_buffer::BatchBase for TestObsBatch { + impl border_generic_replay_buffer::BatchBase for TestObsBatch { fn new(capacity: usize) -> Self { Self { obs: vec![0; capacity], @@ -240,7 +240,7 @@ pub mod test { } } - impl border_core::generic_replay_buffer::BatchBase for TestActBatch { + impl border_generic_replay_buffer::BatchBase for TestActBatch { fn new(capacity: usize) -> Self { Self { act: vec![0; capacity], @@ -337,7 +337,7 @@ pub mod test { } type ReplayBuffer = - border_core::generic_replay_buffer::SimpleReplayBuffer; + border_generic_replay_buffer::GenericReplayBuffer; /// Agent for testing. pub struct TestAgent {} diff --git a/border-atari-env/Cargo.toml b/border-atari-env/Cargo.toml index 6b1a1273..26ea4e66 100644 --- a/border-atari-env/Cargo.toml +++ b/border-atari-env/Cargo.toml @@ -18,6 +18,7 @@ border-core = { version = "0.0.9", path = "../border-core" } image = { workspace = true } tch = { workspace = true, optional = true } border-tch-agent = { version = "0.0.9", path = "../border-tch-agent", optional = true } +border-generic-replay-buffer = { version = "0.0.9", path = "../border-generic-replay-buffer" } candle-core = { workspace = true, optional = true } serde = { workspace = true, features = ["derive"] } itertools = "0.10.1" diff --git a/border-atari-env/src/util/test.rs b/border-atari-env/src/util/test.rs index 2efc83a2..3534239f 100644 --- a/border-atari-env/src/util/test.rs +++ b/border-atari-env/src/util/test.rs @@ -5,10 +5,10 @@ use crate::{ }; use anyhow::Result; use border_core::{ - generic_replay_buffer::{BatchBase, SimpleReplayBuffer}, record::Record, Agent as Agent_, Configurable, Policy, ReplayBuffer as ReplayBuffer_, }; +use border_generic_replay_buffer::{BatchBase, GenericReplayBuffer}; use serde::Deserialize; use std::ptr::copy; @@ -17,7 +17,7 @@ pub type Act = BorderAtariAct; pub type ObsFilter = BorderAtariObsRawFilter; pub type ActFilter = BorderAtariActRawFilter; pub type EnvConfig = BorderAtariEnvConfig; -pub type ReplayBuffer = SimpleReplayBuffer; +pub type ReplayBuffer = GenericReplayBuffer; pub type Env = BorderAtariEnv; pub type Agent = RandomAgent; diff --git a/border-candle-agent/Cargo.toml b/border-candle-agent/Cargo.toml index f1a2306a..4a09755f 100644 --- a/border-candle-agent/Cargo.toml +++ b/border-candle-agent/Cargo.toml @@ -11,6 +11,7 @@ readme = "README.md" [dependencies] border-core = { version = "0.0.9", path = "../border-core" } +border-generic-replay-buffer = { version = "0.0.9", path = "../border-generic-replay-buffer" } border-async-trainer = { version = "0.0.9", path = "../border-async-trainer", optional = true } serde = { workspace = true, features = ["derive"] } serde_yaml = { workspace = true } diff --git a/border-candle-agent/src/tensor_batch.rs b/border-candle-agent/src/tensor_batch.rs index ef17e183..c5038038 100644 --- a/border-candle-agent/src/tensor_batch.rs +++ b/border-candle-agent/src/tensor_batch.rs @@ -1,4 +1,4 @@ -use border_core::generic_replay_buffer::BatchBase; +use border_generic_replay_buffer::BatchBase; use candle_core::{error::Result, DType, Device, IndexOp, Tensor}; /// Adds capability of constructing [`Tensor`] with a static method. diff --git a/border-generic-replay-buffer/src/lib.rs b/border-generic-replay-buffer/src/lib.rs index c4ada74b..cdd7cf68 100644 --- a/border-generic-replay-buffer/src/lib.rs +++ b/border-generic-replay-buffer/src/lib.rs @@ -19,7 +19,6 @@ //! - Configurable weight normalization //! - Step processing for non-vectorized environments -// mod base; mod batch; mod config; mod iw_scheduler; diff --git a/border-minari/Cargo.toml b/border-minari/Cargo.toml index 362e5965..35991c5d 100644 --- a/border-minari/Cargo.toml +++ b/border-minari/Cargo.toml @@ -11,6 +11,7 @@ readme = "README.md" [dependencies] border-core = { version = "0.0.9", path = "../border-core" } +border-generic-replay-buffer = { version = "0.0.9", path = "../border-generic-replay-buffer" } numpy = { workspace = true } pyo3 = { workspace = true, default-features = false, features = [ "auto-initialize", "macros" diff --git a/border-minari/src/converter.rs b/border-minari/src/converter.rs index b27e78c6..f5141c0c 100644 --- a/border-minari/src/converter.rs +++ b/border-minari/src/converter.rs @@ -1,5 +1,6 @@ use anyhow::Result; -use border_core::{generic_replay_buffer::BatchBase, Act, Obs}; +use border_core::{Act, Obs}; +use border_generic_replay_buffer::BatchBase; use pyo3::{PyAny, PyObject, Python}; /// Conversion trait for observation and action. diff --git a/border-minari/src/d4rl/antmaze/candle.rs b/border-minari/src/d4rl/antmaze/candle.rs index faa15fdb..2847ae9e 100644 --- a/border-minari/src/d4rl/antmaze/candle.rs +++ b/border-minari/src/d4rl/antmaze/candle.rs @@ -4,7 +4,7 @@ use crate::{ MinariConverter, MinariDataset, }; use anyhow::Result; -use border_core::generic_replay_buffer::BatchBase; +use border_generic_replay_buffer::BatchBase; use candle_core::{DType, Device, Tensor}; use ndarray::{ArrayBase, ArrayD, Axis, Slice}; use pyo3::{types::PyIterator, PyAny, PyObject, Python}; diff --git a/border-minari/src/d4rl/antmaze/ndarray.rs b/border-minari/src/d4rl/antmaze/ndarray.rs index c8cd7e40..d549e2de 100644 --- a/border-minari/src/d4rl/antmaze/ndarray.rs +++ b/border-minari/src/d4rl/antmaze/ndarray.rs @@ -6,7 +6,7 @@ use crate::{ MinariConverter, }; use anyhow::Result; -use border_core::generic_replay_buffer::BatchBase; +use border_generic_replay_buffer::BatchBase; use ndarray::{s, ArrayD, Axis, IxDyn, Slice}; use pyo3::{PyAny, PyObject, Python}; diff --git a/border-minari/src/d4rl/kitchen/candle.rs b/border-minari/src/d4rl/kitchen/candle.rs index 925021ef..b1dd2b7e 100644 --- a/border-minari/src/d4rl/kitchen/candle.rs +++ b/border-minari/src/d4rl/kitchen/candle.rs @@ -9,7 +9,7 @@ use crate::{ MinariConverter, }; use anyhow::Result; -use border_core::generic_replay_buffer::BatchBase; +use border_generic_replay_buffer::BatchBase; use candle_core::{DType, Device, Tensor}; use ndarray::{ArrayBase, ArrayD, Axis, Slice}; use pyo3::{PyAny, PyObject, Python}; diff --git a/border-minari/src/d4rl/kitchen/ndarray.rs b/border-minari/src/d4rl/kitchen/ndarray.rs index 0a7a3a3e..e305acdc 100644 --- a/border-minari/src/d4rl/kitchen/ndarray.rs +++ b/border-minari/src/d4rl/kitchen/ndarray.rs @@ -6,7 +6,7 @@ use crate::{ MinariConverter, }; use anyhow::Result; -use border_core::generic_replay_buffer::BatchBase; +use border_generic_replay_buffer::BatchBase; use ndarray::{s, ArrayD, Axis, IxDyn, Slice}; use pyo3::{PyAny, PyObject, Python}; diff --git a/border-minari/src/d4rl/pointmaze/ndarray.rs b/border-minari/src/d4rl/pointmaze/ndarray.rs index 10d5a87e..336571a1 100644 --- a/border-minari/src/d4rl/pointmaze/ndarray.rs +++ b/border-minari/src/d4rl/pointmaze/ndarray.rs @@ -6,7 +6,7 @@ use crate::{ MinariConverter, }; use anyhow::Result; -use border_core::generic_replay_buffer::BatchBase; +use border_generic_replay_buffer::BatchBase; use ndarray::{s, ArrayD, Axis, IxDyn, Slice}; use pyo3::{PyAny, PyObject, Python}; diff --git a/border-minari/src/dataset.rs b/border-minari/src/dataset.rs index 745a1904..a60833d6 100644 --- a/border-minari/src/dataset.rs +++ b/border-minari/src/dataset.rs @@ -1,8 +1,8 @@ use crate::{util, MinariConverter, MinariEnv}; use anyhow::Result; -use border_core::{ - generic_replay_buffer::{GenericTransitionBatch, SimpleReplayBuffer, SimpleReplayBufferConfig}, - ExperienceBuffer, ReplayBuffer, +use border_core::{ExperienceBuffer, ReplayBuffer}; +use border_generic_replay_buffer::{ + GenericTransitionBatch, GenericReplayBuffer, GenericReplayBufferConfig, }; use pyo3::{ types::{IntoPyDict, PyIterator}, @@ -65,7 +65,7 @@ impl MinariDataset { &self, converter: &mut T, episode_indices: Option>, - ) -> Result> + ) -> Result> where T::ObsBatch: std::fmt::Debug, T::ActBatch: std::fmt::Debug, @@ -75,7 +75,7 @@ impl MinariDataset { let num_transitions = self.get_num_transitions(episode_indices.clone())?; // Prepare replay buffer - let mut replay_buffer = SimpleReplayBuffer::build(&SimpleReplayBufferConfig { + let mut replay_buffer = GenericReplayBuffer::build(&GenericReplayBufferConfig { capacity: num_transitions, seed: 0, per_config: None, diff --git a/border-minari/src/util/candle/tensor_batch.rs b/border-minari/src/util/candle/tensor_batch.rs index 6051657c..0680a88d 100644 --- a/border-minari/src/util/candle/tensor_batch.rs +++ b/border-minari/src/util/candle/tensor_batch.rs @@ -1,4 +1,4 @@ -use border_core::generic_replay_buffer::BatchBase; +use border_generic_replay_buffer::BatchBase; use candle_core::{/*error::Result, DType,*/ Device, Tensor}; // /// Adds capability of constructing [`Tensor`] with a static method. diff --git a/border-py-gym-env/Cargo.toml b/border-py-gym-env/Cargo.toml index 5516304c..bd63fe6b 100644 --- a/border-py-gym-env/Cargo.toml +++ b/border-py-gym-env/Cargo.toml @@ -11,6 +11,7 @@ readme = "README.md" [dependencies] border-core = { version = "0.0.9", path = "../border-core" } +border-generic-replay-buffer = { version = "0.0.9", path = "../border-generic-replay-buffer", optional = true } numpy = { workspace = true } pyo3 = { workspace = true, default-features = false, features = [ "auto-initialize", @@ -42,4 +43,5 @@ features = ["candle"] no-default-features = true [features] -candle = [ "candle-core" ] +candle = [ "candle-core", "border-generic-replay-buffer" ] +tch = ["dep:tch", "border-generic-replay-buffer"] \ No newline at end of file diff --git a/border-py-gym-env/src/candle/tensor_batch.rs b/border-py-gym-env/src/candle/tensor_batch.rs index 09bdc97e..5976922d 100644 --- a/border-py-gym-env/src/candle/tensor_batch.rs +++ b/border-py-gym-env/src/candle/tensor_batch.rs @@ -1,4 +1,4 @@ -use border_core::generic_replay_buffer::BatchBase; +use border_generic_replay_buffer::BatchBase; use candle_core::{error::Result, DType, Device, Tensor}; /// Adds capability of constructing [`Tensor`] with a static method. diff --git a/border-py-gym-env/src/tch/tensor_batch.rs b/border-py-gym-env/src/tch/tensor_batch.rs index 43f05740..91c88c83 100644 --- a/border-py-gym-env/src/tch/tensor_batch.rs +++ b/border-py-gym-env/src/tch/tensor_batch.rs @@ -1,4 +1,4 @@ -use border_core::generic_replay_buffer::BatchBase; +use border_generic_replay_buffer::BatchBase; use tch::{Device, Tensor}; /// Adds capability of constructing [`Tensor`] with a static method. diff --git a/border-tch-agent/Cargo.toml b/border-tch-agent/Cargo.toml index e763004a..63a51c38 100644 --- a/border-tch-agent/Cargo.toml +++ b/border-tch-agent/Cargo.toml @@ -11,6 +11,7 @@ readme = "README.md" [dependencies] border-core = { version = "0.0.9", path = "../border-core" } +border-generic-replay-buffer = { version = "0.0.9", path = "../border-generic-replay-buffer" } border-async-trainer = { version = "0.0.9", path = "../border-async-trainer", optional = true } serde = { workspace = true, features = ["derive"] } serde_yaml = { workspace = true } diff --git a/border-tch-agent/src/sac.rs b/border-tch-agent/src/sac.rs index 8f99c1cb..aaf987a8 100644 --- a/border-tch-agent/src/sac.rs +++ b/border-tch-agent/src/sac.rs @@ -5,15 +5,15 @@ //! ```no_run //! # use anyhow::Result; //! use border_core::{ -//! # Env as Env_, Obs as Obs_, Act as Act_, Step, test::{ -//! # TestAct as TestAct_, TestActBatch as TestActBatch_, -//! # TestEnv as TestEnv_, -//! # TestObs as TestObs_, TestObsBatch as TestObsBatch_, -//! # }, +//! # Env as Env_, Obs as Obs_, Act as Act_, Step, //! # record::Record, -//! # generic_replay_buffer::{SimpleReplayBuffer, BatchBase}, //! Configurable, //! }; +//! use border_generic_replay_buffer::{GenericReplayBuffer, BatchBase, test::{ +//! # TestAct as TestAct_, TestActBatch as TestActBatch_, +//! # TestEnv as TestEnv_, +//! # TestObs as TestObs_, TestObsBatch as TestObsBatch_, +//! # }}; //! use border_tch_agent::{ //! sac::{ActorConfig, CriticConfig, Sac, SacConfig}, //! mlp::{Mlp, Mlp2, MlpConfig}, @@ -136,7 +136,7 @@ //! # type Env = TestEnv; //! # type ObsBatch = TestObsBatch; //! # type ActBatch = TestActBatch; -//! # type ReplayBuffer = SimpleReplayBuffer; +//! # type ReplayBuffer = GenericReplayBuffer; //! # //! const DIM_OBS: i64 = 3; //! const DIM_ACT: i64 = 1; diff --git a/border-tch-agent/src/tensor_batch.rs b/border-tch-agent/src/tensor_batch.rs index 721e92be..719000ff 100644 --- a/border-tch-agent/src/tensor_batch.rs +++ b/border-tch-agent/src/tensor_batch.rs @@ -1,4 +1,4 @@ -use border_core::generic_replay_buffer::BatchBase; +use border_generic_replay_buffer::BatchBase; use tch::{Device, Tensor}; /// Adds capability of constructing [`Tensor`] with a static method. diff --git a/examples/d4rl/awac_pen/Cargo.toml b/examples/d4rl/awac_pen/Cargo.toml index 04ac03a8..7e9858c7 100644 --- a/examples/d4rl/awac_pen/Cargo.toml +++ b/examples/d4rl/awac_pen/Cargo.toml @@ -16,6 +16,7 @@ border-minari = { version = "0.0.9", path = "../../../border-minari", features = ] } border-candle-agent = { version = "0.0.9", path = "../../../border-candle-agent" } border-core = { version = "0.0.9", path = "../../../border-core" } +border-generic-replay-buffer = { version = "0.0.9", path = "../../../border-generic-replay-buffer" } border-tensorboard = { version = "0.0.9", path = "../../../border-tensorboard" } border-mlflow-tracking = { version = "0.0.9", path = "../../../border-mlflow-tracking" } serde = "1.0.194" diff --git a/examples/d4rl/awac_pen/src/main.rs b/examples/d4rl/awac_pen/src/main.rs index 36ab97a9..e4f0c45f 100644 --- a/examples/d4rl/awac_pen/src/main.rs +++ b/examples/d4rl/awac_pen/src/main.rs @@ -10,11 +10,10 @@ use border_candle_agent::{ Activation, }; use border_core::{ - generic_replay_buffer::{BatchBase, SimpleReplayBuffer}, - record::Recorder, - Agent, Configurable, Env, Evaluator, ExperienceBuffer, ReplayBuffer, Trainer, + record::Recorder, Agent, Configurable, Env, Evaluator, ExperienceBuffer, ReplayBuffer, Trainer, TrainerConfig, TransitionBatch, }; +use border_generic_replay_buffer::{BatchBase, GenericReplayBuffer}; use border_minari::{ d4rl::pen::candle::{PenConverter, PenConverterConfig}, MinariConverter, MinariDataset, MinariEnv, MinariEvaluator, @@ -200,7 +199,7 @@ where fn create_replay_buffer( converter: &mut T, dataset: &MinariDataset, -) -> Result> +) -> Result> where T: MinariConverter, T::ObsBatch: BatchBase + Debug + Into, @@ -282,7 +281,7 @@ where T::ObsBatch: std::fmt::Debug + Into + 'static + Clone, T::ActBatch: std::fmt::Debug + Into + 'static + Clone, { - let mut agent: Box, SimpleReplayBuffer>> = + let mut agent: Box, GenericReplayBuffer>> = create_agent(&config); let recorder = create_recorder(&config)?; // used for loading a trained model let mut evaluator = create_evaluator(&config.args, converter, &dataset, true)?; diff --git a/examples/d4rl/bc_pen/Cargo.toml b/examples/d4rl/bc_pen/Cargo.toml index ea3b4030..67bbfb98 100644 --- a/examples/d4rl/bc_pen/Cargo.toml +++ b/examples/d4rl/bc_pen/Cargo.toml @@ -17,6 +17,7 @@ border-minari = { version = "0.0.9", path = "../../../border-minari", features = ] } border-candle-agent = { version = "0.0.9", path = "../../../border-candle-agent" } border-core = { version = "0.0.9", path = "../../../border-core" } +border-generic-replay-buffer = { version = "0.0.9", path = "../../../border-generic-replay-buffer" } border-tensorboard = { version = "0.0.9", path = "../../../border-tensorboard" } border-mlflow-tracking = { version = "0.0.9", path = "../../../border-mlflow-tracking" } serde = "1.0.194" diff --git a/examples/d4rl/bc_pen/src/main.rs b/examples/d4rl/bc_pen/src/main.rs index 148f1d3f..e8cd9dc3 100644 --- a/examples/d4rl/bc_pen/src/main.rs +++ b/examples/d4rl/bc_pen/src/main.rs @@ -5,11 +5,10 @@ use border_candle_agent::{ Activation, }; use border_core::{ - generic_replay_buffer::{BatchBase, SimpleReplayBuffer}, - record::Recorder, - Agent, Configurable, Env, Evaluator, ExperienceBuffer, ReplayBuffer, Trainer, + record::Recorder, Agent, Configurable, Env, Evaluator, ExperienceBuffer, ReplayBuffer, Trainer, TrainerConfig, TransitionBatch, }; +use border_generic_replay_buffer::{BatchBase, GenericReplayBuffer}; use border_minari::{ d4rl::pen::candle::{PenConverter, PenConverterConfig}, MinariConverter, MinariDataset, MinariEnv, MinariEvaluator, @@ -173,7 +172,7 @@ where fn create_replay_buffer( converter: &mut T, dataset: &MinariDataset, -) -> Result> +) -> Result> where T: MinariConverter, T::ObsBatch: BatchBase + Debug + Into, @@ -203,7 +202,7 @@ where let model_dir = format!("{}/{}", MODEL_DIR, config.args.env); Ok(Box::new(TensorboardRecorder::new( &model_dir, &model_dir, false, - ))) + ))) } } @@ -255,7 +254,7 @@ where T::ObsBatch: std::fmt::Debug + Into + 'static + Clone, T::ActBatch: std::fmt::Debug + Into + 'static + Clone, { - let mut agent: Box, SimpleReplayBuffer>> = + let mut agent: Box, GenericReplayBuffer>> = create_agent(&config); let recorder = create_recorder(&config)?; // used for loading a trained model let mut evaluator = create_evaluator(&config.args, converter, &dataset, true)?; diff --git a/examples/d4rl/iql_pen/Cargo.toml b/examples/d4rl/iql_pen/Cargo.toml index 04ac03a8..7e9858c7 100644 --- a/examples/d4rl/iql_pen/Cargo.toml +++ b/examples/d4rl/iql_pen/Cargo.toml @@ -16,6 +16,7 @@ border-minari = { version = "0.0.9", path = "../../../border-minari", features = ] } border-candle-agent = { version = "0.0.9", path = "../../../border-candle-agent" } border-core = { version = "0.0.9", path = "../../../border-core" } +border-generic-replay-buffer = { version = "0.0.9", path = "../../../border-generic-replay-buffer" } border-tensorboard = { version = "0.0.9", path = "../../../border-tensorboard" } border-mlflow-tracking = { version = "0.0.9", path = "../../../border-mlflow-tracking" } serde = "1.0.194" diff --git a/examples/d4rl/iql_pen/src/main.rs b/examples/d4rl/iql_pen/src/main.rs index df4908e9..3b3ffb8e 100644 --- a/examples/d4rl/iql_pen/src/main.rs +++ b/examples/d4rl/iql_pen/src/main.rs @@ -10,11 +10,10 @@ use border_candle_agent::{ Activation, }; use border_core::{ - generic_replay_buffer::{BatchBase, SimpleReplayBuffer}, - record::Recorder, - Agent, Configurable, Env, Evaluator, ExperienceBuffer, ReplayBuffer, Trainer, + record::Recorder, Agent, Configurable, Env, Evaluator, ExperienceBuffer, ReplayBuffer, Trainer, TrainerConfig, TransitionBatch, }; +use border_generic_replay_buffer::{BatchBase, SimpleReplayBuffer}; use border_minari::{ d4rl::pen::candle::{PenConverter, PenConverterConfig}, MinariConverter, MinariDataset, MinariEnv, MinariEvaluator, @@ -209,7 +208,7 @@ where fn create_replay_buffer( converter: &mut T, dataset: &MinariDataset, -) -> Result> +) -> Result> where T: MinariConverter, T::ObsBatch: BatchBase + Debug + Into, @@ -239,7 +238,7 @@ where let model_dir = format!("{}/{}", MODEL_DIR, config.args.env); Ok(Box::new(TensorboardRecorder::new( &model_dir, &model_dir, false, - ))) + ))) } } @@ -291,7 +290,7 @@ where T::ObsBatch: std::fmt::Debug + Into + 'static + Clone, T::ActBatch: std::fmt::Debug + Into + 'static + Clone, { - let mut agent: Box, SimpleReplayBuffer>> = + let mut agent: Box, GenericReplayBuffer>> = create_agent(&config); let recorder = create_recorder(&config)?; // used for loading a trained model let mut evaluator = create_evaluator(&config.args, converter, &dataset, true)?; diff --git a/examples/gym/dqn_cartpole/Cargo.toml b/examples/gym/dqn_cartpole/Cargo.toml index 1b8d0887..881d77eb 100644 --- a/examples/gym/dqn_cartpole/Cargo.toml +++ b/examples/gym/dqn_cartpole/Cargo.toml @@ -15,6 +15,7 @@ border-py-gym-env = { version = "0.0.9", path = "../../../border-py-gym-env", fe ] } border-candle-agent = { version = "0.0.9", path = "../../../border-candle-agent" } border-core = { version = "0.0.9", path = "../../../border-core" } +border-generic-replay-buffer = { version = "0.0.9", path = "../../../border-generic-replay-buffer" } border-tensorboard = { version = "0.0.9", path = "../../../border-tensorboard" } border-mlflow-tracking = { version = "0.0.9", path = "../../../border-mlflow-tracking" } serde = "1.0.194" diff --git a/examples/gym/dqn_cartpole/src/main.rs b/examples/gym/dqn_cartpole/src/main.rs index 1357ac7e..17f8558f 100644 --- a/examples/gym/dqn_cartpole/src/main.rs +++ b/examples/gym/dqn_cartpole/src/main.rs @@ -7,13 +7,11 @@ use border_candle_agent::{ Activation, }; use border_core::{ - generic_replay_buffer::{ - SimpleReplayBuffer, SimpleReplayBufferConfig, SimpleStepProcessor, - SimpleStepProcessorConfig, - }, - record::Recorder, - Agent, Configurable, DefaultEvaluator, Env as _, Evaluator as _, ReplayBuffer as _, - StepProcessor, Trainer, TrainerConfig, + record::Recorder, Agent, Configurable, DefaultEvaluator, Env as _, Evaluator as _, + ReplayBuffer as _, StepProcessor, Trainer, TrainerConfig, +}; +use border_generic_replay_buffer::{ + GenericReplayBuffer, GenericReplayBufferConfig, SimpleStepProcessor, SimpleStepProcessorConfig, }; use border_mlflow_tracking::MlflowTrackingClient; use border_py_gym_env::{ @@ -31,7 +29,7 @@ use clap::Parser; use serde::Serialize; type Env = GymEnv; -type ReplayBuffer = SimpleReplayBuffer; +type ReplayBuffer = GenericReplayBuffer; type StepProc = SimpleStepProcessor; type Evaluator = DefaultEvaluator; @@ -164,7 +162,7 @@ impl DqnCartpoleConfig { fn train(args: &Args, max_opts: usize, model_dir: &str, eval_interval: usize) -> Result<()> { let config = DqnCartpoleConfig::new(DIM_OBS, DIM_ACT, max_opts, eval_interval)?; let step_proc_config = SimpleStepProcessorConfig {}; - let replay_buffer_config = SimpleReplayBufferConfig::default().capacity(REPLAY_BUFFER_CAPACITY); + let replay_buffer_config = GenericReplayBufferConfig::default().capacity(REPLAY_BUFFER_CAPACITY); let mut recorder = create_recorder(&args, model_dir, Some(&config))?; let mut trainer = Trainer::build(config.trainer_config.clone()); diff --git a/examples/gym/dqn_cartpole_tch/Cargo.toml b/examples/gym/dqn_cartpole_tch/Cargo.toml index 7517abfe..0d19f8df 100644 --- a/examples/gym/dqn_cartpole_tch/Cargo.toml +++ b/examples/gym/dqn_cartpole_tch/Cargo.toml @@ -15,6 +15,7 @@ border-py-gym-env = { version = "0.0.9", path = "../../../border-py-gym-env", fe ] } border-tch-agent = { version = "0.0.9", path = "../../../border-tch-agent" } border-core = { version = "0.0.9", path = "../../../border-core" } +border-generic-replay-buffer = { version = "0.0.9", path = "../../../border-generic-replay-buffer" } border-tensorboard = { version = "0.0.9", path = "../../../border-tensorboard" } border-mlflow-tracking = { version = "0.0.9", path = "../../../border-mlflow-tracking" } serde = "1.0.194" diff --git a/examples/gym/dqn_cartpole_tch/src/main.rs b/examples/gym/dqn_cartpole_tch/src/main.rs index 918a2bc7..8d273722 100644 --- a/examples/gym/dqn_cartpole_tch/src/main.rs +++ b/examples/gym/dqn_cartpole_tch/src/main.rs @@ -1,13 +1,12 @@ use anyhow::Result; use border_core::{ - generic_replay_buffer::{ - SimpleReplayBuffer, SimpleReplayBufferConfig, SimpleStepProcessor, - SimpleStepProcessorConfig, - }, - record::Recorder, - Agent, Configurable, DefaultEvaluator, Env as _, Evaluator as _, ReplayBuffer as _, - StepProcessor, Trainer, TrainerConfig, + record::Recorder, Agent, Configurable, DefaultEvaluator, Env as _, Evaluator as _, + ReplayBuffer as _, StepProcessor, Trainer, TrainerConfig, }; +use border_generic_replay_buffer::{ + GenericReplayBuffer, GenericReplayBufferConfig, SimpleStepProcessor, SimpleStepProcessorConfig, +}; + use border_mlflow_tracking::MlflowTrackingClient; use border_py_gym_env::{ tch::{NdarrayConverter, NdarrayConverterConfig, TensorBatch}, @@ -19,12 +18,12 @@ use border_tch_agent::{ util::CriticLoss, }; use border_tensorboard::TensorboardRecorder; -use tch::Device; use clap::Parser; use serde::Serialize; +use tch::Device; type Env = GymEnv; -type ReplayBuffer = SimpleReplayBuffer; +type ReplayBuffer = GenericReplayBuffer; type StepProc = SimpleStepProcessor; type Evaluator = DefaultEvaluator; @@ -157,7 +156,7 @@ impl DqnCartpoleConfig { fn train(args: &Args, max_opts: usize, model_dir: &str, eval_interval: usize) -> Result<()> { let config = DqnCartpoleConfig::new(DIM_OBS, DIM_ACT, max_opts, eval_interval)?; let step_proc_config = SimpleStepProcessorConfig {}; - let replay_buffer_config = SimpleReplayBufferConfig::default().capacity(REPLAY_BUFFER_CAPACITY); + let replay_buffer_config = GenericReplayBufferConfig::default().capacity(REPLAY_BUFFER_CAPACITY); let mut recorder = create_recorder(&args, model_dir, Some(&config))?; let mut trainer = Trainer::build(config.trainer_config.clone()); diff --git a/examples/gym/sac_fetch_reach/Cargo.toml b/examples/gym/sac_fetch_reach/Cargo.toml index 7b38c67e..6bb20fbc 100644 --- a/examples/gym/sac_fetch_reach/Cargo.toml +++ b/examples/gym/sac_fetch_reach/Cargo.toml @@ -15,6 +15,7 @@ border-py-gym-env = { version = "0.0.9", path = "../../../border-py-gym-env", fe ] } border-candle-agent = { version = "0.0.9", path = "../../../border-candle-agent" } border-core = { version = "0.0.9", path = "../../../border-core" } +border-generic-replay-buffer = { version = "0.0.9", path = "../../../border-generic-replay-buffer" } border-tensorboard = { version = "0.0.9", path = "../../../border-tensorboard" } border-mlflow-tracking = { version = "0.0.9", path = "../../../border-mlflow-tracking" } serde = "1.0.194" diff --git a/examples/gym/sac_fetch_reach/MUJOCO_LOG.TXT b/examples/gym/sac_fetch_reach/MUJOCO_LOG.TXT new file mode 100644 index 00000000..af0d24c6 --- /dev/null +++ b/examples/gym/sac_fetch_reach/MUJOCO_LOG.TXT @@ -0,0 +1,33 @@ +Sun Aug 10 11:01:19 2025 +WARNING: Nan, Inf or huge value in QACC at DOF 0. The simulation is unstable. Time = 1.6800. + +Sun Aug 10 11:01:44 2025 +WARNING: Nan, Inf or huge value in QACC at DOF 0. The simulation is unstable. Time = 0.4000. + +Sun Aug 10 11:01:44 2025 +WARNING: Nan, Inf or huge value in QACC at DOF 0. The simulation is unstable. Time = 0.4000. + +Sun Aug 10 11:01:44 2025 +WARNING: Nan, Inf or huge value in QACC at DOF 0. The simulation is unstable. Time = 0.4000. + +Sun Aug 10 11:01:44 2025 +WARNING: Nan, Inf or huge value in QACC at DOF 0. The simulation is unstable. Time = 0.4000. + +Sun Aug 10 11:01:44 2025 +WARNING: Nan, Inf or huge value in QACC at DOF 0. The simulation is unstable. Time = 0.4000. + +Sun Aug 10 11:01:45 2025 +WARNING: Nan, Inf or huge value in QACC at DOF 0. The simulation is unstable. Time = 0.4000. + +Sun Aug 10 11:01:45 2025 +WARNING: Nan, Inf or huge value in QACC at DOF 0. The simulation is unstable. Time = 0.4000. + +Sun Aug 10 11:01:45 2025 +WARNING: Nan, Inf or huge value in QACC at DOF 0. The simulation is unstable. Time = 0.4000. + +Sun Aug 10 11:01:45 2025 +WARNING: Nan, Inf or huge value in QACC at DOF 0. The simulation is unstable. Time = 0.4000. + +Sun Aug 10 11:01:45 2025 +WARNING: Nan, Inf or huge value in QACC at DOF 0. The simulation is unstable. Time = 0.4000. + diff --git a/examples/gym/sac_fetch_reach/src/main.rs b/examples/gym/sac_fetch_reach/src/main.rs index 1e4be090..6ba9995c 100644 --- a/examples/gym/sac_fetch_reach/src/main.rs +++ b/examples/gym/sac_fetch_reach/src/main.rs @@ -7,14 +7,13 @@ use border_candle_agent::{ Activation, }; use border_core::{ - generic_replay_buffer::{ - SimpleReplayBuffer, SimpleReplayBufferConfig, SimpleStepProcessor, - SimpleStepProcessorConfig, - }, - record::Recorder, - Agent, Configurable, DefaultEvaluator, Env as _, Evaluator as _, ReplayBuffer as _, - StepProcessor, Trainer, TrainerConfig, + record::Recorder, Agent, Configurable, DefaultEvaluator, Env as _, Evaluator as _, + ReplayBuffer as _, StepProcessor, Trainer, TrainerConfig, +}; +use border_generic_replay_buffer::{ + GenericReplayBuffer, GenericReplayBufferConfig, SimpleStepProcessor, SimpleStepProcessorConfig, }; + use border_mlflow_tracking::MlflowTrackingClient; use border_py_gym_env::{ candle::{ @@ -31,7 +30,7 @@ use clap::Parser; use serde::Serialize; type Env = GymEnv; -type ReplayBuffer = SimpleReplayBuffer; +type ReplayBuffer = GenericReplayBuffer; type StepProc = SimpleStepProcessor; type Evaluator = DefaultEvaluator; @@ -222,7 +221,7 @@ impl SacFetchReachConfig { fn train(args: &Args, max_opts: usize, model_dir: &str, eval_interval: usize) -> Result<()> { let config = SacFetchReachConfig::new(DIM_OBS, DIM_ACT, max_opts, eval_interval)?; let step_proc_config = SimpleStepProcessorConfig {}; - let replay_buffer_config = SimpleReplayBufferConfig::default().capacity(REPLAY_BUFFER_CAPACITY); + let replay_buffer_config = GenericReplayBufferConfig::default().capacity(REPLAY_BUFFER_CAPACITY); let mut recorder = create_recorder(&args, model_dir, Some(&config))?; let mut trainer = Trainer::build(config.trainer_config.clone()); diff --git a/examples/gym/sac_pendulum/Cargo.toml b/examples/gym/sac_pendulum/Cargo.toml index a8dd6faf..8e83e6bc 100644 --- a/examples/gym/sac_pendulum/Cargo.toml +++ b/examples/gym/sac_pendulum/Cargo.toml @@ -15,6 +15,7 @@ border-py-gym-env = { version = "0.0.9", path = "../../../border-py-gym-env", fe ] } border-candle-agent = { version = "0.0.9", path = "../../../border-candle-agent" } border-core = { version = "0.0.9", path = "../../../border-core" } +border-generic-replay-buffer = { version = "0.0.9", path = "../../../border-generic-replay-buffer" } border-tensorboard = { version = "0.0.9", path = "../../../border-tensorboard" } border-mlflow-tracking = { version = "0.0.9", path = "../../../border-mlflow-tracking" } serde = "1.0.194" diff --git a/examples/gym/sac_pendulum/src/main.rs b/examples/gym/sac_pendulum/src/main.rs index 9c78b292..7f6bd690 100644 --- a/examples/gym/sac_pendulum/src/main.rs +++ b/examples/gym/sac_pendulum/src/main.rs @@ -10,14 +10,13 @@ use border_candle_agent::{ Activation, }; use border_core::{ - generic_replay_buffer::{ - SimpleReplayBuffer, SimpleReplayBufferConfig, SimpleStepProcessor, - SimpleStepProcessorConfig, - }, - record::Recorder, - Agent, Configurable, DefaultEvaluator, Env as _, Evaluator as _, ReplayBuffer as _, - StepProcessor, Trainer, TrainerConfig, + record::Recorder, Agent, Configurable, DefaultEvaluator, Env as _, Evaluator as _, + ReplayBuffer as _, StepProcessor, Trainer, TrainerConfig, +}; +use border_generic_replay_buffer::{ + GenericReplayBuffer, GenericReplayBufferConfig, SimpleStepProcessor, SimpleStepProcessorConfig, }; + use border_mlflow_tracking::MlflowTrackingClient; use border_py_gym_env::{ candle::{ @@ -34,7 +33,7 @@ use clap::Parser; use serde::Serialize; type Env = GymEnv; -type ReplayBuffer = SimpleReplayBuffer; +type ReplayBuffer = GenericReplayBuffer; type StepProc = SimpleStepProcessor; type Evaluator = DefaultEvaluator; @@ -190,7 +189,7 @@ impl SacPendulumConfig { fn train(args: &Args, max_opts: usize, model_dir: &str, eval_interval: usize) -> Result<()> { let config = SacPendulumConfig::new(DIM_OBS, DIM_ACT, max_opts, eval_interval)?; let step_proc_config = SimpleStepProcessorConfig {}; - let replay_buffer_config = SimpleReplayBufferConfig::default().capacity(REPLAY_BUFFER_CAPACITY); + let replay_buffer_config = GenericReplayBufferConfig::default().capacity(REPLAY_BUFFER_CAPACITY); let mut recorder = create_recorder(&args, model_dir, Some(&config))?; let mut trainer = Trainer::build(config.trainer_config.clone()); diff --git a/examples/gym/sac_pendulum_tch/Cargo.toml b/examples/gym/sac_pendulum_tch/Cargo.toml index 7517abfe..0d19f8df 100644 --- a/examples/gym/sac_pendulum_tch/Cargo.toml +++ b/examples/gym/sac_pendulum_tch/Cargo.toml @@ -15,6 +15,7 @@ border-py-gym-env = { version = "0.0.9", path = "../../../border-py-gym-env", fe ] } border-tch-agent = { version = "0.0.9", path = "../../../border-tch-agent" } border-core = { version = "0.0.9", path = "../../../border-core" } +border-generic-replay-buffer = { version = "0.0.9", path = "../../../border-generic-replay-buffer" } border-tensorboard = { version = "0.0.9", path = "../../../border-tensorboard" } border-mlflow-tracking = { version = "0.0.9", path = "../../../border-mlflow-tracking" } serde = "1.0.194" diff --git a/examples/gym/sac_pendulum_tch/src/main.rs b/examples/gym/sac_pendulum_tch/src/main.rs index 4f7f2749..fc23b0dc 100644 --- a/examples/gym/sac_pendulum_tch/src/main.rs +++ b/examples/gym/sac_pendulum_tch/src/main.rs @@ -1,13 +1,12 @@ use anyhow::Result; use border_core::{ - generic_replay_buffer::{ - SimpleReplayBuffer, SimpleReplayBufferConfig, SimpleStepProcessor, - SimpleStepProcessorConfig, - }, - record::Recorder, - Agent, Configurable, DefaultEvaluator, Env as _, Evaluator as _, ReplayBuffer as _, - StepProcessor, Trainer, TrainerConfig, + record::Recorder, Agent, Configurable, DefaultEvaluator, Env as _, Evaluator as _, + ReplayBuffer as _, StepProcessor, Trainer, TrainerConfig, }; +use border_generic_replay_buffer::{ + GenericReplayBuffer, GenericReplayBufferConfig, SimpleStepProcessor, SimpleStepProcessorConfig, +}; + use border_mlflow_tracking::MlflowTrackingClient; use border_py_gym_env::{ tch::{NdarrayConverter, NdarrayConverterConfig, TensorBatch}, @@ -24,7 +23,7 @@ use serde::Serialize; use tch::Device; type Env = GymEnv; -type ReplayBuffer = SimpleReplayBuffer; +type ReplayBuffer = GenericReplayBuffer; type StepProc = SimpleStepProcessor; type Evaluator = DefaultEvaluator; @@ -153,7 +152,7 @@ impl SacPendulumConfig { fn train(args: &Args, max_opts: usize, model_dir: &str, eval_interval: usize) -> Result<()> { let config = SacPendulumConfig::new(DIM_OBS, DIM_ACT, max_opts, eval_interval)?; let step_proc_config = SimpleStepProcessorConfig {}; - let replay_buffer_config = SimpleReplayBufferConfig::default().capacity(REPLAY_BUFFER_CAPACITY); + let replay_buffer_config = GenericReplayBufferConfig::default().capacity(REPLAY_BUFFER_CAPACITY); let mut recorder = create_recorder(&args, model_dir, Some(&config))?; let mut trainer = Trainer::build(config.trainer_config.clone()); From caf6798065981c1df887c9108f70afd43fabab8d Mon Sep 17 00:00:00 2001 From: taku-y Date: Sun, 10 Aug 2025 12:09:50 +0000 Subject: [PATCH 07/23] Update README Update README --- README.md | 45 +++++++++++++++++++++++---------------------- 1 file changed, 23 insertions(+), 22 deletions(-) diff --git a/README.md b/README.md index 6d661c82..de03769e 100644 --- a/README.md +++ b/README.md @@ -12,18 +12,18 @@ Border consists of the following crates: * Core and utility * [border-core](https://crates.io/crates/border-core) ([doc](https://docs.rs/border-core/latest/border_core/)) provides basic traits and functions for environments and reinforcement learning (RL) agents. * [border-generic-replay-buffer](https://crates.io/crates/border-generic-replay-buffer) ([doc](https://docs.rs/border-generic-replay-buffer/latest/border_generic_replay_buffer/)) provides a generic implementation of replay buffer. - * [border-tensorboard](https://crates.io/crates/border-tensorboard) ([doc](https://docs.rs/border-core/latest/border_tensorboard/)) implements the `TensorboardRecorder` struct for writing records that can be visualized in Tensorboard, based on [tensorboard-rs](https://crates.io/crates/tensorboard-rs). - * [border-mlflow-tracking](https://crates.io/crates/border-mlflow-tracking) ([doc](https://docs.rs/border-core/latest/border_mlflow_tracking/)) provides MLflow tracking support for logging metrics during training via REST API. - * [border-async-trainer](https://crates.io/crates/border-async-trainer) ([doc](https://docs.rs/border-core/latest/border_async_trainer/)) defines traits and functions for asynchronous training of RL agents using multiple actors. Each actor runs a sampling process in parallel, where an agent interacts with an environment to collect samples for a shared replay buffer. + * [border-tensorboard](https://crates.io/crates/border-tensorboard) ([doc](https://docs.rs/border-tensorboard/latest/border_tensorboard/)) implements the `TensorboardRecorder` struct for writing records that can be visualized in Tensorboard, based on [tensorboard-rs](https://crates.io/crates/tensorboard-rs). + * [border-mlflow-tracking](https://crates.io/crates/border-mlflow-tracking) ([doc](https://docs.rs/border-mlflow-tracking/latest/border_mlflow_tracking/)) provides MLflow tracking support for logging metrics during training via REST API. + * [border-async-trainer](https://crates.io/crates/border-async-trainer) ([doc](https://docs.rs/border-async-trainer/latest/border_async_trainer/)) defines traits and functions for asynchronous training of RL agents using multiple actors. Each actor runs a sampling process in parallel, where an agent interacts with an environment to collect samples for a shared replay buffer. * [border](https://crates.io/crates/border) serves as a collection of examples. * Environment - * [border-py-gym-env](https://crates.io/crates/border-py-gym-env) ([doc](https://docs.rs/border-core/latest/border_py_gym_env/)) provides a wrapper for [Gymnasium](https://gymnasium.farama.org) environments written in Python. - * [border-atari-env](https://crates.io/crates/border-atari-env) ([doc](https://docs.rs/border-core/latest/border_atari_env/)) implements a wrapper for [atari-env](https://crates.io/crates/atari-env), which is part of [gym-rs](https://crates.io/crates/gym-rs). - * [border-minari](https://crates.io/crates/border-minari) ([doc](https://docs.rs/border-core/latest/border_minari/)) provides a wrapper for [Minari](https://minari.farama.org). + * [border-py-gym-env](https://crates.io/crates/border-py-gym-env) ([doc](https://docs.rs/border-py-gym-env/latest/border_py_gym_env/)) provides a wrapper for [Gymnasium](https://gymnasium.farama.org) environments written in Python. + * [border-atari-env](https://crates.io/crates/border-atari-env) ([doc](https://docs.rs/border-atari-env/latest/border_atari_env/)) implements a wrapper for [atari-env](https://crates.io/crates/atari-env), which is part of [gym-rs](https://crates.io/crates/gym-rs). + * [border-minari](https://crates.io/crates/border-minari) ([doc](https://docs.rs/border-minari/latest/border_minari/)) provides a wrapper for [Minari](https://minari.farama.org). * Agent - * [border-tch-agent](https://crates.io/crates/border-tch-agent) ([doc](https://docs.rs/border-core/latest/border_tch_agent/)) implements RL agents based on [tch](https://crates.io/crates/tch), including Deep Q Network (DQN), Implicit Quantile Network (IQN), and Soft Actor-Critic (SAC). - * [border-candle-agent](https://crates.io/crates/border-candle-agent) ([doc](https://docs.rs/border-core/latest/border_candle_agent/)) implements RL agents based on [candle](https://crates.io/crates/candle-core). - * [border-policy-no-backend](https://crates.io/crates/border-policy-no-backend) ([doc](https://docs.rs/border-core/latest/border_policy_no_backend/)) implements policies that are independent of any deep learning backend, such as Torch. + * [border-tch-agent](https://crates.io/crates/border-tch-agent) ([doc](https://docs.rs/border-tch-agent/latest/border_tch_agent/)) implements RL agents based on [tch](https://crates.io/crates/tch), including Deep Q Network (DQN), Implicit Quantile Network (IQN), and Soft Actor-Critic (SAC). + * [border-candle-agent](https://crates.io/crates/border-candle-agent) ([doc](https://docs.rs/border-candle-agent/latest/border_candle_agent/)) implements RL agents based on [candle](https://crates.io/crates/candle-core). + * [border-policy-no-backend](https://crates.io/crates/border-policy-no-backend) ([doc](https://docs.rs/border-policy-no-backend/latest/border_policy_no_backend/)) implements policies that are independent of any deep learning backend, such as Torch. ## Status @@ -39,16 +39,17 @@ Docker configuration files for development and testing are available in the [dev ## License -Crates | License ---------------------------|------------------ -`border-core` | MIT OR Apache-2.0 -`border-tensorboard` | MIT OR Apache-2.0 -`border-mlflow-tracking` | MIT OR Apache-2.0 -`border-async-trainer` | MIT OR Apache-2.0 -`border-py-gym-env` | MIT OR Apache-2.0 -`border-atari-env` | GPL-2.0-or-later -`border-minari` | MIT OR Apache-2.0 -`border-tch-agent` | MIT OR Apache-2.0 -`border-candle-agent` | MIT OR Apache-2.0 -`border-policy-no-backend`| MIT OR Apache-2.0 -`border` | GPL-2.0-or-later +Crates | License +------------------------------|------------------ +`border-core` | MIT OR Apache-2.0 +`border-generic-replay-buffer`| MIT OR Apache-2.0 +`border-tensorboard` | MIT OR Apache-2.0 +`border-mlflow-tracking` | MIT OR Apache-2.0 +`border-async-trainer` | MIT OR Apache-2.0 +`border-py-gym-env` | MIT OR Apache-2.0 +`border-atari-env` | GPL-2.0-or-later +`border-minari` | MIT OR Apache-2.0 +`border-tch-agent` | MIT OR Apache-2.0 +`border-candle-agent` | MIT OR Apache-2.0 +`border-policy-no-backend` | MIT OR Apache-2.0 +`border` | GPL-2.0-or-later From 6dab21fd70c60f86fdc5e4bc03c1a31f1fff0ca5 Mon Sep 17 00:00:00 2001 From: taku-y Date: Sun, 10 Aug 2025 13:40:34 +0000 Subject: [PATCH 08/23] Use include_str for documentation Use include_str for documentation --- Cargo.toml | 1 + border-async-trainer/README.md | 160 +++++++++++++++++++++++ border-async-trainer/src/lib.rs | 161 +----------------------- border-atari-env/README.md | 86 +++++++++++++ border-atari-env/src/lib.rs | 87 +------------ border-atari-env/src/util/test.rs | 4 +- border-core/README.md | 71 +++++++++++ border-core/src/lib.rs | 102 +-------------- border-generic-replay-buffer/README.md | 63 ++++++++++ border-generic-replay-buffer/src/lib.rs | 22 +--- border-minari/README.md | 76 +++++++++++ border-minari/src/lib.rs | 78 +----------- border-mlflow-tracking/README.md | 30 ++++- border-mlflow-tracking/src/lib.rs | 124 +----------------- border-policy-no-backend/src/lib.rs | 2 +- border-py-gym-env/README.md | 44 +++++++ border-py-gym-env/src/lib.rs | 46 +------ 17 files changed, 536 insertions(+), 621 deletions(-) create mode 100644 border-generic-replay-buffer/README.md diff --git a/Cargo.toml b/Cargo.toml index 69d9c001..c78dd7d8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,6 +24,7 @@ repository = "https://github.com/laboroai/border" keywords = ["reinforcement", "learning", "rl"] categories = ["science"] license = "MIT OR Apache-2.0" +readme = "README.md" [workspace.dependencies] clap = { version = "4.5.8", features = ["derive"] } diff --git a/border-async-trainer/README.md b/border-async-trainer/README.md index e69de29b..923d2d71 100644 --- a/border-async-trainer/README.md +++ b/border-async-trainer/README.md @@ -0,0 +1,160 @@ +Asynchronous trainer with parallel sampling processes. + +The code might look like below. + +``` +# use serde::{Deserialize, Serialize}; +# use border_generic_replay_buffer::test::{ +# TestAgent, TestAgentConfig, TestEnv, TestObs, TestObsBatch, +# TestAct, TestActBatch +# }; +# use border_core::Env as _; +# use border_async_trainer::{ +# //test::{TestAgent, TestAgentConfig, TestEnv}, +# ActorManager, ActorManagerConfig, AsyncTrainer, AsyncTrainerConfig, +# }; +# use border_generic_replay_buffer::{ +# GenericReplayBuffer, GenericReplayBufferConfig, +# SimpleStepProcessorConfig, SimpleStepProcessor +# }; +# use border_core::{ +# record::{Recorder, NullRecorder}, DefaultEvaluator, +# }; +# +# use std::path::{Path, PathBuf}; +# +# fn agent_config() -> TestAgentConfig { +# TestAgentConfig +# } +# +# fn env_config() -> usize { +# 0 +# } + +type Env = TestEnv; +type ObsBatch = TestObsBatch; +type ActBatch = TestActBatch; +type ReplayBuffer = GenericReplayBuffer; +type StepProcessor = SimpleStepProcessor; + +// Create a new agent by wrapping the existing agent in order to implement SyncModel. +struct TestAgent2(TestAgent); + +impl border_core::Configurable for TestAgent2 { + type Config = TestAgentConfig; + + fn build(config: Self::Config) -> Self { + Self(TestAgent::build(config)) + } +} + +impl border_core::Agent for TestAgent2 { + // Boilerplate code to delegate the method calls to the inner agent. + fn train(&mut self) { + self.0.train(); + } + + // For other methods ... +# fn is_train(&self) -> bool { +# self.0.is_train() +# } +# +# fn eval(&mut self) { +# self.0.eval(); +# } +# +# fn opt_with_record(&mut self, buffer: &mut ReplayBuffer) -> border_core::record::Record { +# self.0.opt_with_record(buffer) +# } +# +# fn save_params(&self, path: &Path) -> anyhow::Result> { +# self.0.save_params(path) +# } +# +# fn load_params(&mut self, path: &Path) -> anyhow::Result<()> { +# self.0.load_params(path) +# } +# +# fn opt(&mut self, buffer: &mut ReplayBuffer) { +# self.0.opt_with_record(buffer); +# } +# +# fn as_any_ref(&self) -> &dyn std::any::Any { +# self +# } +# +# fn as_any_mut(&mut self) -> &mut dyn std::any::Any { +# self +# } +} + +impl border_core::Policy for TestAgent2 { + // Boilerplate code to delegate the method calls to the inner agent. + // ... +# fn sample(&mut self, obs: &TestObs) -> TestAct { +# self.0.sample(obs) +# } +} + +impl border_async_trainer::SyncModel for TestAgent2{ + // Self::ModelInfo shold include the model parameters. + type ModelInfo = usize; + + + fn model_info(&self) -> (usize, Self::ModelInfo) { + // Extracts the model parameters and returns them as Self::ModelInfo. + // The first element of the tuple is the number of optimization steps. + (0, 0) + } + + fn sync_model(&mut self, _model_info: &Self::ModelInfo) { + // implements synchronization of the model based on the _model_info + } +} + +let agent_configs: Vec<_> = vec![agent_config()]; +let env_config_train = env_config(); +let env_config_eval = env_config(); +let replay_buffer_config = GenericReplayBufferConfig::default(); +let step_proc_config = SimpleStepProcessorConfig::default(); +let actor_man_config = ActorManagerConfig::default(); +let async_trainer_config = AsyncTrainerConfig::default(); +let mut recorder: Box> = Box::new(NullRecorder::new()); +let mut evaluator = DefaultEvaluator::::new(&env_config_eval, 0, 1).unwrap(); + +border_async_trainer::util::train_async::( + &agent_config(), + &agent_configs, + &env_config_train, + &env_config_eval, + &step_proc_config, + &replay_buffer_config, + &actor_man_config, + &async_trainer_config, + &mut recorder, + &mut evaluator, +); +``` + +Training process consists of the following two components: + +* [`ActorManager`] manages [`Actor`]s, each of which runs a thread for interacting + [`Agent`] and [`Env`] and taking samples. Those samples will be sent to + the replay buffer in [`AsyncTrainer`]. +* [`AsyncTrainer`] is responsible for training of an agent. It also runs a thread + for pushing samples from [`ActorManager`] into a replay buffer. + +The `Agent` must implement [`SyncModel`] trait in order to synchronize the model of +the agent in [`Actor`] with the trained agent in [`AsyncTrainer`]. The trait has +the ability to import and export the information of the model as +[`SyncModel`]`::ModelInfo`. + +The `Agent` in [`AsyncTrainer`] is responsible for training, typically with a GPU, +while the `Agent`s in [`Actor`]s in [`ActorManager`] is responsible for sampling +using CPU. + +Both [`AsyncTrainer`] and [`ActorManager`] are running in the same machine and +communicate by channels. + +[`Agent`]: border_core::Agent +[`Env`]: border_core::Env \ No newline at end of file diff --git a/border-async-trainer/src/lib.rs b/border-async-trainer/src/lib.rs index 29db9e8c..e7ebdeaf 100644 --- a/border-async-trainer/src/lib.rs +++ b/border-async-trainer/src/lib.rs @@ -1,163 +1,4 @@ -//! Asynchronous trainer with parallel sampling processes. -//! -//! The code might look like below. -//! -//! ``` -//! # use serde::{Deserialize, Serialize}; -//! # use border_generic_replay_buffer::test::{ -//! # TestAgent, TestAgentConfig, TestEnv, TestObs, TestObsBatch, -//! # TestAct, TestActBatch -//! # }; -//! # use border_core::Env as _; -//! # use border_async_trainer::{ -//! # //test::{TestAgent, TestAgentConfig, TestEnv}, -//! # ActorManager, ActorManagerConfig, AsyncTrainer, AsyncTrainerConfig, -//! # }; -//! # use border_generic_replay_buffer::{ -//! # GenericReplayBuffer, GenericReplayBufferConfig, -//! # SimpleStepProcessorConfig, SimpleStepProcessor -//! # }; -//! # use border_core::{ -//! # record::{Recorder, NullRecorder}, DefaultEvaluator, -//! # }; -//! # -//! # use std::path::{Path, PathBuf}; -//! # -//! # fn agent_config() -> TestAgentConfig { -//! # TestAgentConfig -//! # } -//! # -//! # fn env_config() -> usize { -//! # 0 -//! # } -//! -//! type Env = TestEnv; -//! type ObsBatch = TestObsBatch; -//! type ActBatch = TestActBatch; -//! type ReplayBuffer = GenericReplayBuffer; -//! type StepProcessor = SimpleStepProcessor; -//! -//! // Create a new agent by wrapping the existing agent in order to implement SyncModel. -//! struct TestAgent2(TestAgent); -//! -//! impl border_core::Configurable for TestAgent2 { -//! type Config = TestAgentConfig; -//! -//! fn build(config: Self::Config) -> Self { -//! Self(TestAgent::build(config)) -//! } -//! } -//! -//! impl border_core::Agent for TestAgent2 { -//! // Boilerplate code to delegate the method calls to the inner agent. -//! fn train(&mut self) { -//! self.0.train(); -//! } -//! -//! // For other methods ... -//! # fn is_train(&self) -> bool { -//! # self.0.is_train() -//! # } -//! # -//! # fn eval(&mut self) { -//! # self.0.eval(); -//! # } -//! # -//! # fn opt_with_record(&mut self, buffer: &mut ReplayBuffer) -> border_core::record::Record { -//! # self.0.opt_with_record(buffer) -//! # } -//! # -//! # fn save_params(&self, path: &Path) -> anyhow::Result> { -//! # self.0.save_params(path) -//! # } -//! # -//! # fn load_params(&mut self, path: &Path) -> anyhow::Result<()> { -//! # self.0.load_params(path) -//! # } -//! # -//! # fn opt(&mut self, buffer: &mut ReplayBuffer) { -//! # self.0.opt_with_record(buffer); -//! # } -//! # -//! # fn as_any_ref(&self) -> &dyn std::any::Any { -//! # self -//! # } -//! # -//! # fn as_any_mut(&mut self) -> &mut dyn std::any::Any { -//! # self -//! # } -//! } -//! -//! impl border_core::Policy for TestAgent2 { -//! // Boilerplate code to delegate the method calls to the inner agent. -//! // ... -//! # fn sample(&mut self, obs: &TestObs) -> TestAct { -//! # self.0.sample(obs) -//! # } -//! } -//! -//! impl border_async_trainer::SyncModel for TestAgent2{ -//! // Self::ModelInfo shold include the model parameters. -//! type ModelInfo = usize; -//! -//! -//! fn model_info(&self) -> (usize, Self::ModelInfo) { -//! // Extracts the model parameters and returns them as Self::ModelInfo. -//! // The first element of the tuple is the number of optimization steps. -//! (0, 0) -//! } -//! -//! fn sync_model(&mut self, _model_info: &Self::ModelInfo) { -//! // implements synchronization of the model based on the _model_info -//! } -//! } -//! -//! let agent_configs: Vec<_> = vec![agent_config()]; -//! let env_config_train = env_config(); -//! let env_config_eval = env_config(); -//! let replay_buffer_config = GenericReplayBufferConfig::default(); -//! let step_proc_config = SimpleStepProcessorConfig::default(); -//! let actor_man_config = ActorManagerConfig::default(); -//! let async_trainer_config = AsyncTrainerConfig::default(); -//! let mut recorder: Box> = Box::new(NullRecorder::new()); -//! let mut evaluator = DefaultEvaluator::::new(&env_config_eval, 0, 1).unwrap(); -//! -//! border_async_trainer::util::train_async::( -//! &agent_config(), -//! &agent_configs, -//! &env_config_train, -//! &env_config_eval, -//! &step_proc_config, -//! &replay_buffer_config, -//! &actor_man_config, -//! &async_trainer_config, -//! &mut recorder, -//! &mut evaluator, -//! ); -//! ``` -//! -//! Training process consists of the following two components: -//! -//! * [`ActorManager`] manages [`Actor`]s, each of which runs a thread for interacting -//! [`Agent`] and [`Env`] and taking samples. Those samples will be sent to -//! the replay buffer in [`AsyncTrainer`]. -//! * [`AsyncTrainer`] is responsible for training of an agent. It also runs a thread -//! for pushing samples from [`ActorManager`] into a replay buffer. -//! -//! The `Agent` must implement [`SyncModel`] trait in order to synchronize the model of -//! the agent in [`Actor`] with the trained agent in [`AsyncTrainer`]. The trait has -//! the ability to import and export the information of the model as -//! [`SyncModel`]`::ModelInfo`. -//! -//! The `Agent` in [`AsyncTrainer`] is responsible for training, typically with a GPU, -//! while the `Agent`s in [`Actor`]s in [`ActorManager`] is responsible for sampling -//! using CPU. -//! -//! Both [`AsyncTrainer`] and [`ActorManager`] are running in the same machine and -//! communicate by channels. -//! -//! [`Agent`]: border_core::Agent -//! [`Env`]: border_core::Env +#![doc = include_str!("../README.md")] mod actor; mod actor_manager; mod async_trainer; diff --git a/border-atari-env/README.md b/border-atari-env/README.md index e69de29b..70e864a4 100644 --- a/border-atari-env/README.md +++ b/border-atari-env/README.md @@ -0,0 +1,86 @@ +A thin wrapper of [`atari-env`](https://crates.io/crates/atari-env) for [`Border`](https://crates.io/crates/border). + +The code under [atari_env] is adapted from the +[`atari-env`](https://crates.io/crates/atari-env) crate +(rev = `0ef0422f953d79e96b32ad14284c9600bd34f335`), +because the crate registered in crates.io does not implement +[`atari_env::AtariEnv::lives()`] method, which is required for episodic life environments. + +This environment applies some preprocessing to observation as in +[`atari_wrapper.py`](https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py). + +You need to place Atari Rom directories under the directory specified by environment variable +`ATARI_ROM_DIR`. An easy way to do this is to use [AutoROM](https://pypi.org/project/AutoROM/) +Python package. + +```bash +pip install autorom +mkdir $HOME/atari_rom +AutoROM --install-dir $HOME/atari_rom +export ATARI_ROM_DIR=$HOME/atari_rom +``` + +Here is an example of running Pong environment with a random policy. + +```no_run +use anyhow::Result; +use border_atari_env::{ + BorderAtariAct, BorderAtariActRawFilter, BorderAtariEnv, BorderAtariEnvConfig, + BorderAtariObs, BorderAtariObsRawFilter, +}; +use border_core::{Env as _, Policy, DefaultEvaluator, Evaluator as _, NullReplayBuffer, Agent}; + +# type Obs = BorderAtariObs; +# type Act = BorderAtariAct; +# type ObsFilter = BorderAtariObsRawFilter; +# type ActFilter = BorderAtariActRawFilter; +# type EnvConfig = BorderAtariEnvConfig; +# type Env = BorderAtariEnv; +# +# #[derive(Clone)] +# struct RandomPolicyConfig { +# pub n_acts: usize, +# } +# +# struct RandomPolicy { +# n_acts: usize, +# } +# +# impl RandomPolicy { +# pub fn build(n_acts: usize) -> Self { +# Self { n_acts } +# } +# } +# +# impl Policy for RandomPolicy { +# fn sample(&mut self, _: &Obs) -> Act { +# fastrand::u8(..self.n_acts as u8).into() +# } +# } +# +# impl Agent for RandomPolicy {} +# +# fn env_config(name: String) -> EnvConfig { +# EnvConfig::default().name(name) +# } +# +fn main() -> Result<()> { +# env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init(); +# fastrand::seed(42); +# + // Creates Pong environment + let env_config = env_config("pong".to_string()); + + // Creates a random policy + let n_acts = 4; + let mut policy = Box::new(RandomPolicy::build(n_acts)) as _; + + // Runs evaluation + let env_config = env_config.render(true); + let _ = DefaultEvaluator::new(&env_config, 42, 5)? + .evaluate(&mut policy); + + Ok(()) +} +``` +[`atari_env::AtariEnv::lives()`]: atari_env::AtariEnv::lives diff --git a/border-atari-env/src/lib.rs b/border-atari-env/src/lib.rs index 7a71c57d..6fbbaa92 100644 --- a/border-atari-env/src/lib.rs +++ b/border-atari-env/src/lib.rs @@ -1,89 +1,4 @@ -//! A thin wrapper of [`atari-env`](https://crates.io/crates/atari-env) for [`Border`](https://crates.io/crates/border). -//! -//! The code under [atari_env] is adapted from the -//! [`atari-env`](https://crates.io/crates/atari-env) crate -//! (rev = `0ef0422f953d79e96b32ad14284c9600bd34f335`), -//! because the crate registered in crates.io does not implement -//! [`atari_env::AtariEnv::lives()`] method, which is required for episodic life environments. -//! -//! This environment applies some preprocessing to observation as in -//! [`atari_wrapper.py`](https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py). -//! -//! You need to place Atari Rom directories under the directory specified by environment variable -//! `ATARI_ROM_DIR`. An easy way to do this is to use [AutoROM](https://pypi.org/project/AutoROM/) -//! Python package. -//! -//! ```bash -//! pip install autorom -//! mkdir $HOME/atari_rom -//! AutoROM --install-dir $HOME/atari_rom -//! export ATARI_ROM_DIR=$HOME/atari_rom -//! ``` -//! -//! Here is an example of running Pong environment with a random policy. -//! -//! ```no_run -//! use anyhow::Result; -//! use border_atari_env::{ -//! BorderAtariAct, BorderAtariActRawFilter, BorderAtariEnv, BorderAtariEnvConfig, -//! BorderAtariObs, BorderAtariObsRawFilter, -//! }; -//! use border_core::{Env as _, Policy, DefaultEvaluator, Evaluator as _, NullReplayBuffer, Agent}; -//! -//! # type Obs = BorderAtariObs; -//! # type Act = BorderAtariAct; -//! # type ObsFilter = BorderAtariObsRawFilter; -//! # type ActFilter = BorderAtariActRawFilter; -//! # type EnvConfig = BorderAtariEnvConfig; -//! # type Env = BorderAtariEnv; -//! # -//! # #[derive(Clone)] -//! # struct RandomPolicyConfig { -//! # pub n_acts: usize, -//! # } -//! # -//! # struct RandomPolicy { -//! # n_acts: usize, -//! # } -//! # -//! # impl RandomPolicy { -//! # pub fn build(n_acts: usize) -> Self { -//! # Self { n_acts } -//! # } -//! # } -//! # -//! # impl Policy for RandomPolicy { -//! # fn sample(&mut self, _: &Obs) -> Act { -//! # fastrand::u8(..self.n_acts as u8).into() -//! # } -//! # } -//! # -//! # impl Agent for RandomPolicy {} -//! # -//! # fn env_config(name: String) -> EnvConfig { -//! # EnvConfig::default().name(name) -//! # } -//! # -//! fn main() -> Result<()> { -//! # env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init(); -//! # fastrand::seed(42); -//! # -//! // Creates Pong environment -//! let env_config = env_config("pong".to_string()); -//! -//! // Creates a random policy -//! let n_acts = 4; -//! let mut policy = Box::new(RandomPolicy::build(n_acts)) as _; -//! -//! // Runs evaluation -//! let env_config = env_config.render(true); -//! let _ = DefaultEvaluator::new(&env_config, 42, 5)? -//! .evaluate(&mut policy); -//! -//! Ok(()) -//! } -//! ``` -//! [`atari_env::AtariEnv::lives()`]: atari_env::AtariEnv::lives +#![doc = include_str!("../README.md")] mod act; pub mod atari_env; mod env; diff --git a/border-atari-env/src/util/test.rs b/border-atari-env/src/util/test.rs index 3534239f..e8b7d5c7 100644 --- a/border-atari-env/src/util/test.rs +++ b/border-atari-env/src/util/test.rs @@ -23,7 +23,7 @@ pub type Agent = RandomAgent; const FRAME_IN_BYTES: usize = 84 * 84; -/// Consists the observation part of a batch in [SimpleReplayBuffer]. +/// Consists the observation part of a batch in [`GenericReplayBuffer`]. pub struct ObsBatch { /// The number of samples in the batch. pub n: usize, @@ -78,7 +78,7 @@ impl From for ObsBatch { } } -/// Consists the action part of a batch in [SimpleReplayBuffer]. +/// Consists the action part of a batch in [`GenericReplayBuffer`]. pub struct ActBatch { /// The number of samples in the batch. pub n: usize, diff --git a/border-core/README.md b/border-core/README.md index e69de29b..c010bb21 100644 --- a/border-core/README.md +++ b/border-core/README.md @@ -0,0 +1,71 @@ +Core components for reinforcement learning. + +# Observation and Action + +The [`Obs`] and [`Act`] traits provide abstractions for observations and actions in environments. + +# Environment + +The [`Env`] trait serves as the fundamental abstraction for environments. It defines four associated types: +`Config`, `Obs`, `Act`, and `Info`. The `Obs` and `Act` types represent concrete implementations of +environment observations and actions, respectively. These types must implement the [`Obs`] and [`Act`] traits. +Environments implementing [`Env`] generate [`Step`] objects at each interaction step through the +[`Env::step()`] method. The [`Info`] type stores additional information from each agent-environment interaction, +which may be empty (implemented as a zero-sized struct). The `Config` type represents environment configurations +and is used during environment construction. + +# Policy + +The [`Policy`] trait represents a decision-making policy. The [`Policy::sample()`] method takes an +`E::Obs` and generates an `E::Act`. Policies can be either probabilistic or deterministic, depending on the +implementation. + +# Agent + +In this crate, an [`Agent`] is defined as a trainable [`Policy`]. +Agents operate in either training or evaluation mode. During training, the agent's policy may be probabilistic +to facilitate exploration, while in evaluation mode, it typically becomes deterministic. + +The [`Agent::opt()`] method executes a single optimization step. The specific implementation of an optimization +step varies between agents and may include multiple stochastic gradient descent steps. Training samples are +obtained from the [`ReplayBuffer`]. + +This trait also provides methods for saving and loading trained policy parameters to and from a directory. + +# Batch + +The [`TransitionBatch`] trait represents a batch of transitions in the form `(o_t, r_t, a_t, o_t+1)`. +This trait is used for training [`Agent`]s using reinforcement learning algorithms. + +# Replay Buffer and Experience Buffer + +The [`ReplayBuffer`] trait provides an abstraction for replay buffers. Its associated type +[`ReplayBuffer::Batch`] represents samples used for training [`Agent`]s. Agents must implement the +[`Agent::opt()`] method, where [`ReplayBuffer::Batch`] must have appropriate type or trait bounds +for training the agent. + +While [`ReplayBuffer`] focuses on generating training batches, the [`ExperienceBuffer`] trait +handles sample storage. The [`ExperienceBuffer::push()`] method stores samples of type +[`ExperienceBuffer::Item`], typically obtained through environment interactions. + +# Step Processor + +The [`StepProcessor`] trait plays a crucial role in the training pipeline by transforming environment +interactions into training samples. It processes [`Step`] objects, which contain the current +observation, action, reward, and next observation, into a format suitable for the replay buffer. + +# Trainer + +The [`Trainer`] manages the training loop and related objects. A [`Trainer`] instance is configured with +training parameters such as the maximum number of optimization steps and the directory for saving agent +parameters during training. The [`Trainer::train`] method executes online training of an agent in an environment. +During the training loop, the agent interacts with the environment to collect samples and perform optimization +steps, while simultaneously recording various metrics. + +# Evaluator + +The [`Evaluator`] trait is used to evaluate a policy's (`P`) performance in an environment (`E`). +An instance of this type is provided to the [`Trainer`] for policy evaluation during training. +[`DefaultEvaluator`] serves as the default implementation of [`Evaluator`]. This evaluator +runs the policy in the environment for a specified number of episodes. At the start of each episode, +the environment is reset using [`Env::reset_with_index()`] to control specific evaluation conditions. diff --git a/border-core/src/lib.rs b/border-core/src/lib.rs index e8a3058c..b924f2f3 100644 --- a/border-core/src/lib.rs +++ b/border-core/src/lib.rs @@ -1,105 +1,5 @@ +#![doc = include_str!("../README.md")] #![warn(missing_docs)] -//! Core components for reinforcement learning. -//! -//! # Observation and Action -//! -//! The [`Obs`] and [`Act`] traits provide abstractions for observations and actions in environments. -//! -//! # Environment -//! -//! The [`Env`] trait serves as the fundamental abstraction for environments. It defines four associated types: -//! `Config`, `Obs`, `Act`, and `Info`. The `Obs` and `Act` types represent concrete implementations of -//! environment observations and actions, respectively. These types must implement the [`Obs`] and [`Act`] traits. -//! Environments implementing [`Env`] generate [`Step`] objects at each interaction step through the -//! [`Env::step()`] method. The [`Info`] type stores additional information from each agent-environment interaction, -//! which may be empty (implemented as a zero-sized struct). The `Config` type represents environment configurations -//! and is used during environment construction. -//! -//! # Policy -//! -//! The [`Policy`] trait represents a decision-making policy. The [`Policy::sample()`] method takes an -//! `E::Obs` and generates an `E::Act`. Policies can be either probabilistic or deterministic, depending on the -//! implementation. -//! -//! # Agent -//! -//! In this crate, an [`Agent`] is defined as a trainable [`Policy`]. -//! Agents operate in either training or evaluation mode. During training, the agent's policy may be probabilistic -//! to facilitate exploration, while in evaluation mode, it typically becomes deterministic. -//! -//! The [`Agent::opt()`] method executes a single optimization step. The specific implementation of an optimization -//! step varies between agents and may include multiple stochastic gradient descent steps. Training samples are -//! obtained from the [`ReplayBuffer`]. -//! -//! This trait also provides methods for saving and loading trained policy parameters to and from a directory. -//! -//! # Batch -//! -//! The [`TransitionBatch`] trait represents a batch of transitions in the form `(o_t, r_t, a_t, o_t+1)`. -//! This trait is used for training [`Agent`]s using reinforcement learning algorithms. -//! -//! # Replay Buffer and Experience Buffer -//! -//! The [`ReplayBuffer`] trait provides an abstraction for replay buffers. Its associated type -//! [`ReplayBuffer::Batch`] represents samples used for training [`Agent`]s. Agents must implement the -//! [`Agent::opt()`] method, where [`ReplayBuffer::Batch`] must have appropriate type or trait bounds -//! for training the agent. -//! -//! While [`ReplayBuffer`] focuses on generating training batches, the [`ExperienceBuffer`] trait -//! handles sample storage. The [`ExperienceBuffer::push()`] method stores samples of type -//! [`ExperienceBuffer::Item`], typically obtained through environment interactions. -//! -//! ## Reference Implementation -//! -//! [`SimpleReplayBuffer`] implements both [`ReplayBuffer`] and [`ExperienceBuffer`]. -//! This type takes two parameters, `O` and `A`, representing observation and action types in the replay buffer. -//! Both `O` and `A` must implement [`BatchBase`], which provides sample storage functionality similar to `Vec`. -//! The associated types `Item` and `Batch` are both [`GenericTransitionBatch`], representing sets of -//! `(o_t, r_t, a_t, o_t+1)` transitions. -//! -//! # Step Processor -//! -//! The [`StepProcessor`] trait plays a crucial role in the training pipeline by transforming environment -//! interactions into training samples. It processes [`Step`] objects, which contain the current -//! observation, action, reward, and next observation, into a format suitable for the replay buffer. -//! -//! The [`SimpleStepProcessor`] is a concrete implementation that: -//! 1. Maintains the previous observation to construct complete transitions -//! 2. Converts environment-specific observations and actions (`E::Obs` and `E::Act`) into batch-compatible -//! types (`O` and `A`) using the `From` trait -//! 3. Generates [`GenericTransitionBatch`] objects containing the complete transition -//! `(o_t, a_t, o_t+1, r_t, is_terminated, is_truncated)` -//! 4. Handles episode termination by properly resetting the previous observation -//! -//! This processor is essential for implementing temporal difference learning algorithms, as it ensures -//! that transitions are properly formatted and stored in the replay buffer for training. -//! -//! [`SimpleStepProcessor`] can be used with [`SimpleReplayBuffer`]. It converts `E::Obs` and -//! `E::Act` into their respective [`BatchBase`] types and generates [`GenericTransitionBatch`]. This conversion -//! relies on the trait bounds `O: From` and `A: From`. -//! -//! # Trainer -//! -//! The [`Trainer`] manages the training loop and related objects. A [`Trainer`] instance is configured with -//! training parameters such as the maximum number of optimization steps and the directory for saving agent -//! parameters during training. The [`Trainer::train`] method executes online training of an agent in an environment. -//! During the training loop, the agent interacts with the environment to collect samples and perform optimization -//! steps, while simultaneously recording various metrics. -//! -//! # Evaluator -//! -//! The [`Evaluator`] trait is used to evaluate a policy's (`P`) performance in an environment (`E`). -//! An instance of this type is provided to the [`Trainer`] for policy evaluation during training. -//! [`DefaultEvaluator`] serves as the default implementation of [`Evaluator`]. This evaluator -//! runs the policy in the environment for a specified number of episodes. At the start of each episode, -//! the environment is reset using [`Env::reset_with_index()`] to control specific evaluation conditions. -//! -//! [`SimpleReplayBuffer`]: generic_replay_buffer::SimpleReplayBuffer -//! [`SimpleReplayBuffer`]: generic_replay_buffer::SimpleReplayBuffer -//! [`BatchBase`]: generic_replay_buffer::BatchBase -//! [`GenericTransitionBatch`]: generic_replay_buffer::GenericTransitionBatch -//! [`SimpleStepProcessor`]: generic_replay_buffer::SimpleStepProcessor -//! [`SimpleStepProcessor`]: generic_replay_buffer::SimpleStepProcessor pub mod dummy; pub mod error; mod evaluator; diff --git a/border-generic-replay-buffer/README.md b/border-generic-replay-buffer/README.md new file mode 100644 index 00000000..edebbb8b --- /dev/null +++ b/border-generic-replay-buffer/README.md @@ -0,0 +1,63 @@ +Generic implementation of replay buffers for reinforcement learning. + +This module provides a flexible implementation of replay buffers +that can handle arbitrary observation and action types. It supports both +standard experience replay and prioritized experience replay (PER). + +# Key Components + +- [`GenericReplayBuffer`]: A generic replay buffer implementation +- [`GenericTransitionBatch`]: A generic batch structure for transitions +- [`SimpleStepProcessor`]: A processor for converting environment steps to transitions +- [`PerConfig`]: Configuration for prioritized experience replay + +# Features + +- Generic type support for observations and actions +- Efficient batch processing +- Prioritized experience replay with importance sampling +- Configurable weight normalization +- Step processing for non-vectorized environments + +# [`BatchBase`] + +# [`GenericTransitionBatch`] + +The [`TransitionBatch`] trait represents a batch of transitions in the form `(o_t, r_t, a_t, o_t+1)`. +This trait is used for training [`Agent`]s using reinforcement learning algorithms. + +# [`GenericReplayBuffer`] + +[`GenericReplayBuffer`] implements both [`ReplayBuffer`] and [`ExperienceBuffer`]. +This type takes two parameters, `O` and `A`, representing observation and action types in the replay buffer. +Both `O` and `A` must implement [`BatchBase`], which provides sample storage functionality similar to `Vec`. +The associated types `Item` and `Batch` are both [`GenericTransitionBatch`], representing sets of +`(o_t, r_t, a_t, o_t+1)` transitions. + +# [`SimpleStepProcessor`] + +The [`SimpleStepProcessor`] is a concrete implementation that: +1. Maintains the previous observation to construct complete transitions +2. Converts environment-specific observations and actions (`E::Obs` and `E::Act`) into batch-compatible + types (`O` and `A`) using the `From` trait +3. Generates [`GenericTransitionBatch`] objects containing the complete transition + `(o_t, a_t, o_t+1, r_t, is_terminated, is_truncated)` +4. Handles episode termination by properly resetting the previous observation + +This processor is essential for implementing temporal difference learning algorithms, as it ensures +that transitions are properly formatted and stored in the replay buffer for training. + +[`SimpleStepProcessor`] can be used with [`GenericReplayBuffer`]. It converts `E::Obs` and +`E::Act` into their respective [`BatchBase`] types and generates [`GenericTransitionBatch`]. This conversion +relies on the trait bounds `O: From` and `A: From`. + +[`GenericReplayBuffer`]: crate::GenericReplayBuffer +[`GenericReplayBuffer`]: crate::GenericReplayBuffer +[`BatchBase`]: crate::BatchBase +[`GenericTransitionBatch`]: crate::GenericTransitionBatch +[`SimpleStepProcessor`]: crate::SimpleStepProcessor +[`SimpleStepProcessor`]: crate::SimpleStepProcessor +[`BatchBase`]: crate::BatchBase +[`ReplayBuffer`]: border_core::ReplayBuffer +[`ExperienceBuffer`]: border_core::ExperienceBuffer +[`Agent`]: border_core::Agent diff --git a/border-generic-replay-buffer/src/lib.rs b/border-generic-replay-buffer/src/lib.rs index cdd7cf68..9dba02f2 100644 --- a/border-generic-replay-buffer/src/lib.rs +++ b/border-generic-replay-buffer/src/lib.rs @@ -1,24 +1,4 @@ -//! Generic implementation of replay buffers for reinforcement learning. -//! -//! This module provides a flexible implementation of replay buffers -//! that can handle arbitrary observation and action types. It supports both -//! standard experience replay and prioritized experience replay (PER). -//! -//! # Key Components -//! -//! - [`GenericReplayBuffer`]: A generic replay buffer implementation -//! - [`GenericTransitionBatch`]: A generic batch structure for transitions -//! - [`SimpleStepProcessor`]: A processor for converting environment steps to transitions -//! - [`PerConfig`]: Configuration for prioritized experience replay -//! -//! # Features -//! -//! - Generic type support for observations and actions -//! - Efficient batch processing -//! - Prioritized experience replay with importance sampling -//! - Configurable weight normalization -//! - Step processing for non-vectorized environments - +#![doc = include_str!("../README.md")] mod batch; mod config; mod iw_scheduler; diff --git a/border-minari/README.md b/border-minari/README.md index e69de29b..181c7ab4 100644 --- a/border-minari/README.md +++ b/border-minari/README.md @@ -0,0 +1,76 @@ +A wrapper for [Minari](https://minari.farama.org) environments. + +This crate provides a Rust interface for Minari datasets, which are collections of offline reinforcement learning data. +It allows users to load and interact with Minari datasets in a way that is compatible with the Border framework. + +# Features + +- **Dataset Loading**: Load Minari datasets from disk or from the Minari registry. +- **Environment Interaction**: Interact with the loaded datasets using the Border environment interface. +- **Data Access**: Access observations, actions, rewards, and other data from the datasets. + +# Example + +The following example demonstrates how to: +1. Load a D4RL Kitchen dataset +2. Create a replay buffer from a specific episode +3. Recover the environment state +4. Replay the actions from the dataset + +This is particularly useful for: +- Analyzing expert demonstrations +- Testing environment behavior +- Validating dataset quality +- Reproducing recorded trajectories + +```no_run +# use anyhow::Result; +use border_core::Env; +use border_minari::{d4rl::kitchen::ndarray::KitchenConverter, MinariDataset}; +# use numpy::convert; +# use std::num; + +fn main() -> Result<()> { + // Load the D4RL Kitchen dataset + let dataset = MinariDataset::load_dataset("D4RL/kitchen/complete-v1", true)?; + + // Create a converter for handling observation and action types + let mut converter = KitchenConverter {}; + + // Create a replay buffer containing only the sixth episode + let replay_buffer = dataset.create_replay_buffer(&mut converter, Some(vec![5]))?; + + // Recover the environment state from the dataset + // The 'false' parameter indicates not to use the initial state + // 'human' indicates the agent type + let mut env = dataset.recover_environment(converter, false, "human")?; + + // Get the sequence of actions from the replay buffer + let actions = replay_buffer.whole_actions(); + + // Reset the environment and replay the actions + env.reset(None)?; + for ix in 0..actions.action.shape()[0] { + let act = actions.get(ix); + let _ = env.step(&act); + } + + Ok(()) +} +``` + +The example uses the following key components: +- [`KitchenConverter`]: Handles conversion between Python and Rust types for the Kitchen environment +- [`MinariDataset`]: Manages the dataset and provides methods for data access +- [`Env`]: The Border environment interface for interaction + +[`KitchenConverter`]: crate::d4rl::kitchen::ndarray::KitchenConverter +[`MinariDataset`]: crate::MinariDataset +[`Env`]: border_core::Env + +# Integration with Border + +This crate implements the [`Env`] trait from `border-core`, making it compatible with other Border components +such as agents, policies, and trainers. It can be used in both online and offline reinforcement learning scenarios. + +[`Env`]: border_core::Env diff --git a/border-minari/src/lib.rs b/border-minari/src/lib.rs index ffec0690..3d3cbbaa 100644 --- a/border-minari/src/lib.rs +++ b/border-minari/src/lib.rs @@ -1,80 +1,4 @@ -//! A wrapper for [Minari](https://minari.farama.org) environments. -//! -//! This crate provides a Rust interface for Minari datasets, which are collections of offline reinforcement learning data. -//! It allows users to load and interact with Minari datasets in a way that is compatible with the Border framework. -//! -//! # Features -//! -//! - **Dataset Loading**: Load Minari datasets from disk or from the Minari registry. -//! - **Environment Interaction**: Interact with the loaded datasets using the Border environment interface. -//! - **Data Access**: Access observations, actions, rewards, and other data from the datasets. -//! -//! # Example -//! -//! The following example demonstrates how to: -//! 1. Load a D4RL Kitchen dataset -//! 2. Create a replay buffer from a specific episode -//! 3. Recover the environment state -//! 4. Replay the actions from the dataset -//! -//! This is particularly useful for: -//! - Analyzing expert demonstrations -//! - Testing environment behavior -//! - Validating dataset quality -//! - Reproducing recorded trajectories -//! -//! ```no_run -//! # use anyhow::Result; -//! use border_core::Env; -//! use border_minari::{d4rl::kitchen::ndarray::KitchenConverter, MinariDataset}; -//! # use numpy::convert; -//! # use std::num; -//! -//! fn main() -> Result<()> { -//! // Load the D4RL Kitchen dataset -//! let dataset = MinariDataset::load_dataset("D4RL/kitchen/complete-v1", true)?; -//! -//! // Create a converter for handling observation and action types -//! let mut converter = KitchenConverter {}; -//! -//! // Create a replay buffer containing only the sixth episode -//! let replay_buffer = dataset.create_replay_buffer(&mut converter, Some(vec![5]))?; -//! -//! // Recover the environment state from the dataset -//! // The 'false' parameter indicates not to use the initial state -//! // 'human' indicates the agent type -//! let mut env = dataset.recover_environment(converter, false, "human")?; -//! -//! // Get the sequence of actions from the replay buffer -//! let actions = replay_buffer.whole_actions(); -//! -//! // Reset the environment and replay the actions -//! env.reset(None)?; -//! for ix in 0..actions.action.shape()[0] { -//! let act = actions.get(ix); -//! let _ = env.step(&act); -//! } -//! -//! Ok(()) -//! } -//! ``` -//! -//! The example uses the following key components: -//! - [`KitchenConverter`]: Handles conversion between Python and Rust types for the Kitchen environment -//! - [`MinariDataset`]: Manages the dataset and provides methods for data access -//! - [`Env`]: The Border environment interface for interaction -//! -//! [`KitchenConverter`]: crate::d4rl::kitchen::ndarray::KitchenConverter -//! [`MinariDataset`]: crate::MinariDataset -//! [`Env`]: border_core::Env -//! -//! # Integration with Border -//! -//! This crate implements the [`Env`] trait from `border-core`, making it compatible with other Border components -//! such as agents, policies, and trainers. It can be used in both online and offline reinforcement learning scenarios. -//! -//! [`Env`]: border_core::Env - +#![doc = include_str!("../README.md")] mod converter; pub mod d4rl; mod dataset; diff --git a/border-mlflow-tracking/README.md b/border-mlflow-tracking/README.md index dab7cc1d..62b737c3 100644 --- a/border-mlflow-tracking/README.md +++ b/border-mlflow-tracking/README.md @@ -1,16 +1,25 @@ -Support [MLflow](https://mlflow.org) tracking to manage experiments. +A logger for border-core crate. -Before running the program using this crate, run a tracking server with the following command: +This crate is based on [MLflow](https://mlflow.org) tracking. + +# Setup + +To use this crate, you need to start an MLflow tracking server first. You can do this by running: ```bash mlflow server --host 127.0.0.1 --port 8080 ``` -Then, training configurations and metrices can be logged to the tracking server. -The following code provides an example. Nested configuration parameters will be flattened, +Before running the program using this crate, you need to set the `MLFLOW_DEFAULT_ARTIFACT_ROOT` +environment variable to specify where model parameters and artifacts will be saved during training. +Typically, you should set this to the `mlruns` directory of your MLflow installation. + +# Example + +The following code is an example. Nested configuration parameters will be flattened, logged like `hyper_params.param1`, `hyper_params.param2`. -```rust +```no_run use anyhow::Result; use border_core::record::{Record, RecordValue, Recorder}; use border_mlflow_tracking::MlflowTrackingClient; @@ -66,7 +75,8 @@ fn main() -> Result<()> { }; // Set experiment for runs - let client = MlflowTrackingClient::new("http://localhost:8080").set_experiment_id("Default")?; + let client = MlflowTrackingClient::new("http://localhost:8080") + .set_experiment("Default")?; // Create recorders for logging let mut recorder_run1 = client.create_recorder("")?; @@ -103,3 +113,11 @@ fn main() -> Result<()> { Ok(()) } ``` + +## Save model parameters during training + +[`MlflowTrackingClient`] relies on the `MLFLOW_DEFAULT_ARTIFACT_ROOT` environment variable +to locate where model parameters are saved during training. Note that this environment variable +should be set for the program using this crate, not for the tracking server program. +Currently, only saving to the local file system is supported. + diff --git a/border-mlflow-tracking/src/lib.rs b/border-mlflow-tracking/src/lib.rs index 698e181f..bf119554 100644 --- a/border-mlflow-tracking/src/lib.rs +++ b/border-mlflow-tracking/src/lib.rs @@ -1,126 +1,4 @@ -//! A logger for border-core crate. -//! -//! This crate is based on [MLflow](https://mlflow.org) tracking. -//! -//! # Setup -//! -//! To use this crate, you need to start an MLflow tracking server first. You can do this by running: -//! -//! ```bash -//! mlflow server --host 127.0.0.1 --port 8080 -//! ``` -//! -//! Before running the program using this crate, you need to set the `MLFLOW_DEFAULT_ARTIFACT_ROOT` -//! environment variable to specify where model parameters and artifacts will be saved during training. -//! Typically, you should set this to the `mlruns` directory of your MLflow installation. -//! -//! # Example -//! -//! The following code is an example. Nested configuration parameters will be flattened, -//! logged like `hyper_params.param1`, `hyper_params.param2`. -//! -//! ```no_run -//! use anyhow::Result; -//! use border_core::record::{Record, RecordValue, Recorder}; -//! use border_mlflow_tracking::MlflowTrackingClient; -//! use serde::Serialize; -//! -//! // Nested Configuration struct -//! #[derive(Debug, Serialize)] -//! struct Config { -//! env_params: String, -//! hyper_params: HyperParameters, -//! } -//! -//! #[derive(Debug, Serialize)] -//! struct HyperParameters { -//! param1: i64, -//! param2: Param2, -//! param3: Param3, -//! } -//! -//! #[derive(Debug, Serialize)] -//! enum Param2 { -//! Variant1, -//! Variant2(f32), -//! } -//! -//! #[derive(Debug, Serialize)] -//! struct Param3 { -//! dataset_name: String, -//! } -//! -//! fn main() -> Result<()> { -//! env_logger::init(); -//! -//! let config1 = Config { -//! env_params: "env1".to_string(), -//! hyper_params: HyperParameters { -//! param1: 0, -//! param2: Param2::Variant1, -//! param3: Param3 { -//! dataset_name: "a".to_string(), -//! }, -//! }, -//! }; -//! let config2 = Config { -//! env_params: "env2".to_string(), -//! hyper_params: HyperParameters { -//! param1: 0, -//! param2: Param2::Variant2(3.0), -//! param3: Param3 { -//! dataset_name: "a".to_string(), -//! }, -//! }, -//! }; -//! -//! // Set experiment for runs -//! let client = MlflowTrackingClient::new("http://localhost:8080") -//! .set_experiment("Default")?; -//! -//! // Create recorders for logging -//! let mut recorder_run1 = client.create_recorder("")?; -//! let mut recorder_run2 = client.create_recorder("")?; -//! recorder_run1.log_params(&config1)?; -//! recorder_run2.log_params(&config2)?; -//! -//! // Logging while training -//! for opt_steps in 0..100 { -//! let opt_steps = opt_steps as f32; -//! -//! // Create a record -//! let mut record = Record::empty(); -//! record.insert("opt_steps", RecordValue::Scalar(opt_steps)); -//! record.insert("Loss", RecordValue::Scalar((-1f32 * opt_steps).exp())); -//! -//! // Log metrices in the record -//! recorder_run1.write(record); -//! } -//! -//! // Logging while training -//! for opt_steps in 0..100 { -//! let opt_steps = opt_steps as f32; -//! -//! // Create a record -//! let mut record = Record::empty(); -//! record.insert("opt_steps", RecordValue::Scalar(opt_steps)); -//! record.insert("Loss", RecordValue::Scalar((-0.5f32 * opt_steps).exp())); -//! -//! // Log metrices in the record -//! recorder_run2.write(record); -//! } -//! -//! Ok(()) -//! } -//! ``` -//! -//! ## Save model parameters during training -//! -//! [`MlflowTrackingClient`] relies on the `MLFLOW_DEFAULT_ARTIFACT_ROOT` environment variable -//! to locate where model parameters are saved during training. Note that this environment variable -//! should be set for the program using this crate, not for the tracking server program. -//! Currently, only saving to the local file system is supported. -//! +#![doc = include_str!("../README.md")] mod client; mod experiment; mod recorder; diff --git a/border-policy-no-backend/src/lib.rs b/border-policy-no-backend/src/lib.rs index 93053528..a4a49260 100644 --- a/border-policy-no-backend/src/lib.rs +++ b/border-policy-no-backend/src/lib.rs @@ -1,4 +1,4 @@ -//! Policy with no backend. +#![doc = include_str!("../README.md")] mod mat; mod mlp; diff --git a/border-py-gym-env/README.md b/border-py-gym-env/README.md index e69de29b..085e6e7f 100644 --- a/border-py-gym-env/README.md +++ b/border-py-gym-env/README.md @@ -0,0 +1,44 @@ +A wrapper of [Gymnasium](https://gymnasium.farama.org) environments on Python. + +[`GymEnv`] is a wrapper of [Gymnasium](https://gymnasium.farama.org) based on [`PyO3`](https://github.com/PyO3/pyo3). +It has been tested on some of [classic control](https://gymnasium.farama.org/environments/classic_control/) and +[Gymnasium-Robotics](https://robotics.farama.org) environments. + +In order to bridge Python and Rust, we need to convert Python objects to Rust objects and vice versa. +This crate provides the [`GymEnvConverter`] trait to handle these conversions. + +# Type Conversion + +The [`GymEnvConverter`] trait provides a unified interface for converting between Python and Rust types: + +* `filt_obs`: Converts Python observations to Rust types +* `filt_act`: Converts Rust actions to Python types + +# Implementations + +This crate provides several implementations of [`GymEnvConverter`]: + +* [`ndarray::NdarrayConverter`]: Handles conversions for environments using ndarray types +* [`candle::CandleConverter`]: Handles conversions for environments using Candle tensor types + (requires `candle` feature flag) +* [`tch::TchConverter`]: Handles conversions for environments using Tch tensor types + (requires `tch` feature flag) + +To use Candle or Tch converters, enable the corresponding feature in your `Cargo.toml`: + +```toml +[dependencies] +border-py-gym-env = { version = "0.1.0", features = ["candle"] } # For Candle support +# or +border-py-gym-env = { version = "0.1.0", features = ["tch"] } # For Tch support +``` + +Each implementation supports different types of observations and actions: + +* Array observations (e.g., CartPole) +* Dictionary observations (e.g., FetchPickAndPlace) +* Discrete actions (e.g., CartPole) +* Continuous actions (e.g., Pendulum) + +[`Policy`]: border_core::Policy +[`ArrayD`]: https://docs.rs/ndarray/0.15.1/ndarray/type.ArrayD.html diff --git a/border-py-gym-env/src/lib.rs b/border-py-gym-env/src/lib.rs index bd34abb0..9965bb4e 100644 --- a/border-py-gym-env/src/lib.rs +++ b/border-py-gym-env/src/lib.rs @@ -1,48 +1,6 @@ #![allow(rustdoc::broken_intra_doc_links)] -//! A wrapper of [Gymnasium](https://gymnasium.farama.org) environments on Python. -//! -//! [`GymEnv`] is a wrapper of [Gymnasium](https://gymnasium.farama.org) based on [`PyO3`](https://github.com/PyO3/pyo3). -//! It has been tested on some of [classic control](https://gymnasium.farama.org/environments/classic_control/) and -//! [Gymnasium-Robotics](https://robotics.farama.org) environments. -//! -//! In order to bridge Python and Rust, we need to convert Python objects to Rust objects and vice versa. -//! This crate provides the [`GymEnvConverter`] trait to handle these conversions. -//! -//! # Type Conversion -//! -//! The [`GymEnvConverter`] trait provides a unified interface for converting between Python and Rust types: -//! -//! * `filt_obs`: Converts Python observations to Rust types -//! * `filt_act`: Converts Rust actions to Python types -//! -//! # Implementations -//! -//! This crate provides several implementations of [`GymEnvConverter`]: -//! -//! * [`ndarray::NdarrayConverter`]: Handles conversions for environments using ndarray types -//! * [`candle::CandleConverter`]: Handles conversions for environments using Candle tensor types -//! (requires `candle` feature flag) -//! * [`tch::TchConverter`]: Handles conversions for environments using Tch tensor types -//! (requires `tch` feature flag) -//! -//! To use Candle or Tch converters, enable the corresponding feature in your `Cargo.toml`: -//! -//! ```toml -//! [dependencies] -//! border-py-gym-env = { version = "0.1.0", features = ["candle"] } # For Candle support -//! # or -//! border-py-gym-env = { version = "0.1.0", features = ["tch"] } # For Tch support -//! ``` -//! -//! Each implementation supports different types of observations and actions: -//! -//! * Array observations (e.g., CartPole) -//! * Dictionary observations (e.g., FetchPickAndPlace) -//! * Discrete actions (e.g., CartPole) -//! * Continuous actions (e.g., Pendulum) -//! -//! [`Policy`]: border_core::Policy -//! [`ArrayD`]: https://docs.rs/ndarray/0.15.1/ndarray/type.ArrayD.html +#![doc = include_str!("../README.md")] + mod base; #[cfg(feature = "candle")] pub mod candle; From 54d4ac9744702442a825c9425fcab9cb3de8f710 Mon Sep 17 00:00:00 2001 From: taku-y Date: Sun, 10 Aug 2025 14:24:20 +0000 Subject: [PATCH 09/23] Tweak --- border-generic-replay-buffer/README.md | 32 ++++++++++++++++++-------- 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/border-generic-replay-buffer/README.md b/border-generic-replay-buffer/README.md index edebbb8b..a0a204e5 100644 --- a/border-generic-replay-buffer/README.md +++ b/border-generic-replay-buffer/README.md @@ -21,22 +21,21 @@ standard experience replay and prioritized experience replay (PER). # [`BatchBase`] -# [`GenericTransitionBatch`] +The [`BatchBase`] trait represents batches of observations and actions, serving dual purposes in the reinforcement learning pipeline: -The [`TransitionBatch`] trait represents a batch of transitions in the form `(o_t, r_t, a_t, o_t+1)`. -This trait is used for training [`Agent`]s using reinforcement learning algorithms. +1. **Training Batch Component**: Forms the building blocks for training batches used by agents during optimization +2. **Storage Container**: Acts as the internal storage mechanism within [`GenericReplayBuffer`] for efficiently managing observation and action data -# [`GenericReplayBuffer`] +# [`GenericTransitionBatch`] -[`GenericReplayBuffer`] implements both [`ReplayBuffer`] and [`ExperienceBuffer`]. -This type takes two parameters, `O` and `A`, representing observation and action types in the replay buffer. -Both `O` and `A` must implement [`BatchBase`], which provides sample storage functionality similar to `Vec`. -The associated types `Item` and `Batch` are both [`GenericTransitionBatch`], representing sets of -`(o_t, r_t, a_t, o_t+1)` transitions. +The [`GenericTransitionBatch`] trait represents a batch of transitions in the form `(o_t, r_t, a_t, o_t+1)`. +This is composed of structs that implement the [`BatchBase`] trait and +used for training [`Agent`]s using reinforcement learning algorithms. # [`SimpleStepProcessor`] -The [`SimpleStepProcessor`] is a concrete implementation that: +The [`SimpleStepProcessor`] is an implementation of [`StepProcessor`] that: + 1. Maintains the previous observation to construct complete transitions 2. Converts environment-specific observations and actions (`E::Obs` and `E::Act`) into batch-compatible types (`O` and `A`) using the `From` trait @@ -51,6 +50,18 @@ that transitions are properly formatted and stored in the replay buffer for trai `E::Act` into their respective [`BatchBase`] types and generates [`GenericTransitionBatch`]. This conversion relies on the trait bounds `O: From` and `A: From`. +# [`GenericReplayBuffer`] + +[`GenericReplayBuffer`] implements both [`ReplayBuffer`] and [`ExperienceBuffer`]. +This type takes two parameters, `O` and `A`, representing observation and action types in the replay buffer. +Both `O` and `A` must implement [`BatchBase`], which provides sample storage functionality similar to `Vec`. +The associated types `Item` and `Batch` are both [`GenericTransitionBatch`], representing sets of +`(o_t, r_t, a_t, o_t+1)` transitions. + +## Type Compatibility + +Typically, the associated types `SimpleStepProcessor::Output` and `GenericReplayBuffer::Item` need to be matched. This ensures that the step data from environment interactions matches the samples expected by the replay buffer, guaranteeing type compatibility throughout the reinforcement learning pipeline. + [`GenericReplayBuffer`]: crate::GenericReplayBuffer [`GenericReplayBuffer`]: crate::GenericReplayBuffer [`BatchBase`]: crate::BatchBase @@ -61,3 +72,4 @@ relies on the trait bounds `O: From` and `A: From`. [`ReplayBuffer`]: border_core::ReplayBuffer [`ExperienceBuffer`]: border_core::ExperienceBuffer [`Agent`]: border_core::Agent +[`StepProcessor`]: border_core::StepProcessor From 97c9eefafa29a7caf7f0b80cc811952f640cb6c5 Mon Sep 17 00:00:00 2001 From: taku-y Date: Mon, 11 Aug 2025 01:46:45 +0000 Subject: [PATCH 10/23] Update changelog --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index ef731086..54d5079e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,10 @@ ## v0.0.9 (2025-??-??) +### Changed + +* Separate the generic replaybuffer into a separate crate (`border-generic-replay-buffer`). + ## v0.0.8 (2025-05-17) ### Added From 9d340163e441b9225ec3117bd65dd91c44d0d263 Mon Sep 17 00:00:00 2001 From: taku-y Date: Mon, 11 Aug 2025 02:52:08 +0000 Subject: [PATCH 11/23] Simpify example of converting mlp model --- border-core/Cargo.toml | 4 + border-core/src/dummy.rs | 48 +++---- examples/gym/convert_policy/Cargo.toml | 2 +- examples/gym/convert_policy/src/main.rs | 179 ++---------------------- 4 files changed, 42 insertions(+), 191 deletions(-) diff --git a/border-core/Cargo.toml b/border-core/Cargo.toml index f4c59d2d..78ba8ff4 100644 --- a/border-core/Cargo.toml +++ b/border-core/Cargo.toml @@ -20,8 +20,12 @@ aquamarine = { workspace = true } fastrand = { workspace = true } segment-tree = { workspace = true } xxhash-rust = { workspace = true } +tch = { workspace = true, optional = true } # Consider to replace with fastrand rand = { workspace = true } [dev-dependencies] tempdir = { workspace = true } + +[features] +tch = ["dep:tch"] \ No newline at end of file diff --git a/border-core/src/dummy.rs b/border-core/src/dummy.rs index 67b55b17..c0b987d7 100644 --- a/border-core/src/dummy.rs +++ b/border-core/src/dummy.rs @@ -10,12 +10,12 @@ impl crate::Obs for DummyObs { } } -// TODO: Consider to make this work with feature flag tch. -// impl Into for DummyObs { -// fn into(self) -> tch::Tensor { -// unimplemented!(); -// } -// } +#[cfg(feature = "tch")] +impl Into for DummyObs { + fn into(self) -> tch::Tensor { + unimplemented!(); + } +} #[derive(Clone, Debug)] /// Dummy action. @@ -27,30 +27,30 @@ impl crate::Act for DummyAct { } } -// TODO: Consider to make this work with feature flag tch. -// impl Into for DummyAct { -// fn into(self) -> tch::Tensor { -// unimplemented!(); -// } -// } +#[cfg(feature = "tch")] +impl Into for DummyAct { + fn into(self) -> tch::Tensor { + unimplemented!(); + } +} -// TODO: Consider to make this work with feature flag tch. -// impl From for DummyAct { -// fn from(_value: tch::Tensor) -> Self { -// unimplemented!(); -// } -// } +#[cfg(feature = "tch")] +impl From for DummyAct { + fn from(_value: tch::Tensor) -> Self { + unimplemented!(); + } +} #[derive(Clone)] /// Dummy inner batch. pub struct DummyInnerBatch; -// TODO: Consider to make this work with feature flag tch. -// impl Into for DummyInnerBatch { -// fn into(self) -> tch::Tensor { -// unimplemented!(); -// } -// } +#[cfg(feature = "tch")] +impl Into for DummyInnerBatch { + fn into(self) -> tch::Tensor { + unimplemented!(); + } +} /// Dummy batch. pub struct DummyBatch; diff --git a/examples/gym/convert_policy/Cargo.toml b/examples/gym/convert_policy/Cargo.toml index fe632f2f..4beea915 100644 --- a/examples/gym/convert_policy/Cargo.toml +++ b/examples/gym/convert_policy/Cargo.toml @@ -13,7 +13,7 @@ border-policy-no-backend = { version = "0.0.9", path = "../../../border-policy-n "tch", ] } border-tch-agent = { version = "0.0.9", path = "../../../border-tch-agent" } -border-core = { version = "0.0.9", path = "../../../border-core" } +border-core = { version = "0.0.9", path = "../../../border-core", features = ["tch"] } serde = "1.0.194" tch = "0.16.0" bincode = "1.3.3" diff --git a/examples/gym/convert_policy/src/main.rs b/examples/gym/convert_policy/src/main.rs index 97379191..18119ba7 100644 --- a/examples/gym/convert_policy/src/main.rs +++ b/examples/gym/convert_policy/src/main.rs @@ -3,179 +3,26 @@ //! You need to prepare the model parameter files by `sac_pendulum_tch.rs` in advance. //! use anyhow::Result; -use border_core::{Agent, Configurable}; -use border_policy_no_backend::Mlp; +use border_core::{Agent, Configurable, dummy::*}; +use border_policy_no_backend::Mlp as MlpNoBackend; use border_tch_agent::{ - mlp, model::ModelBase, - sac::{ActorConfig, CriticConfig, SacConfig}, + mlp::{Mlp, Mlp2, MlpConfig}, + sac::{ActorConfig, CriticConfig, SacConfig, Sac}, }; use std::{fs, io::Write}; const DIM_OBS: i64 = 3; const DIM_ACT: i64 = 1; -// Dummy types -mod dummy { - use super::mlp::{Mlp, Mlp2}; - use border_tch_agent::sac::Sac as Sac_; +type Sac_ = Sac; - #[derive(Clone, Debug)] - pub struct DummyObs; - - impl border_core::Obs for DummyObs { - fn len(&self) -> usize { - unimplemented!(); - } - } - - impl Into for DummyObs { - fn into(self) -> tch::Tensor { - unimplemented!(); - } - } - - #[derive(Clone, Debug)] - pub struct DummyAct; - - impl border_core::Act for DummyAct { - fn len(&self) -> usize { - unimplemented!(); - } - } - - impl Into for DummyAct { - fn into(self) -> tch::Tensor { - unimplemented!(); - } - } - - impl From for DummyAct { - fn from(_value: tch::Tensor) -> Self { - unimplemented!(); - } - } - - #[derive(Clone)] - pub struct DummyInnerBatch; - - impl Into for DummyInnerBatch { - fn into(self) -> tch::Tensor { - unimplemented!(); - } - } - - pub struct DummyBatch; - - impl border_core::TransitionBatch for DummyBatch { - type ObsBatch = DummyInnerBatch; - type ActBatch = DummyInnerBatch; - - fn len(&self) -> usize { - unimplemented!(); - } - - fn obs(&self) -> &Self::ObsBatch { - unimplemented!(); - } - - fn act(&self) -> &Self::ActBatch { - unimplemented!(); - } - - fn unpack( - self, - ) -> ( - Self::ObsBatch, - Self::ActBatch, - Self::ObsBatch, - Vec, - Vec, - Vec, - Option>, - Option>, - ) { - unimplemented!(); - } - } - - pub struct DummyReplayBuffer; - - impl border_core::ReplayBuffer for DummyReplayBuffer { - type Batch = DummyBatch; - type Config = usize; - - fn batch(&mut self, _size: usize) -> anyhow::Result { - unimplemented!(); - } - - fn build(_config: &Self::Config) -> Self { - unimplemented!(); - } - - fn update_priority(&mut self, _ixs: &Option>, _td_err: &Option>) { - unimplemented!(); - } - } - - #[derive(Clone, Debug)] - pub struct DummyInfo; - - impl border_core::Info for DummyInfo {} - - pub struct DummyEnv; - - impl border_core::Env for DummyEnv { - type Config = usize; - type Act = DummyAct; - type Obs = DummyObs; - type Info = DummyInfo; - - fn build(_config: &Self::Config, _seed: i64) -> anyhow::Result - where - Self: Sized, - { - unimplemented!(); - } - - fn reset(&mut self, _is_done: Option<&Vec>) -> anyhow::Result { - unimplemented!(); - } - - fn reset_with_index(&mut self, _ix: usize) -> anyhow::Result { - unimplemented!(); - } - - fn step(&mut self, _a: &Self::Act) -> (border_core::Step, border_core::record::Record) - where - Self: Sized, - { - unimplemented!(); - } - - fn step_with_reset( - &mut self, - _a: &Self::Act, - ) -> (border_core::Step, border_core::record::Record) - where - Self: Sized, - { - unimplemented!(); - } - } - - pub type Env = DummyEnv; - pub type Sac = Sac_; -} - -use dummy::Sac; - -fn create_sac_config() -> SacConfig { +fn create_sac_config() -> SacConfig { // Omit learning related parameters let actor_config = ActorConfig::default() .out_dim(DIM_ACT) - .pi_config(mlp::MlpConfig::new(DIM_OBS, vec![64, 64], DIM_ACT, false)); - let critic_config = CriticConfig::default().q_config(mlp::MlpConfig::new( + .pi_config(MlpConfig::new(DIM_OBS, vec![64, 64], DIM_ACT, false)); + let critic_config = CriticConfig::default().q_config(MlpConfig::new( DIM_OBS + DIM_ACT, vec![64, 64], 1, @@ -187,21 +34,21 @@ fn create_sac_config() -> SacConfig { .device(tch::Device::Cpu) } -fn load_sac_model(src_path: &str) -> Result { +fn load_sac_model(src_path: &str) -> Result { let config = create_sac_config(); - let mut sac = Sac::build(config); + let mut sac = Sac_::build(config); sac.load_params(src_path.as_ref())?; Ok(sac) } -fn create_mlp(sac: &Sac) -> Mlp { +fn create_mlp(sac: &Sac_) -> MlpNoBackend { let vs = sac.get_policy_net().get_var_store(); let w_names = ["mlp.al0.weight", "mlp.al1.weight", "ml.weight"]; let b_names = ["mlp.al0.bias", "mlp.al1.bias", "ml.bias"]; - Mlp::from_varstore(vs, &w_names, &b_names) + MlpNoBackend::from_varstore(vs, &w_names, &b_names) } -fn serialize_to_file(mlp: &Mlp, dest_path: &str) -> Result<()> { +fn serialize_to_file(mlp: &MlpNoBackend, dest_path: &str) -> Result<()> { let encoded = bincode::serialize(mlp)?; let mut file = fs::OpenOptions::new() .create(true) From 3d165ce185cd5567bb0ca4072f3f05b4362e58b7 Mon Sep 17 00:00:00 2001 From: taku-y Date: Mon, 11 Aug 2025 02:52:20 +0000 Subject: [PATCH 12/23] Cargo fmt --- border-async-trainer/src/actor_manager/base.rs | 4 +--- border-atari-env/src/util/test.rs | 3 +-- border-minari/src/dataset.rs | 2 +- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/border-async-trainer/src/actor_manager/base.rs b/border-async-trainer/src/actor_manager/base.rs index f1d6ad20..f24abac6 100644 --- a/border-async-trainer/src/actor_manager/base.rs +++ b/border-async-trainer/src/actor_manager/base.rs @@ -1,9 +1,7 @@ use crate::{ Actor, ActorManagerConfig, ActorStat, PushedItemMessage, ReplayBufferProxyConfig, SyncModel, }; -use border_core::{ - Agent, Configurable, Env, ExperienceBuffer, ReplayBuffer, StepProcessor, -}; +use border_core::{Agent, Configurable, Env, ExperienceBuffer, ReplayBuffer, StepProcessor}; use crossbeam_channel::{bounded, /*unbounded,*/ Receiver, Sender}; use log::info; use std::{ diff --git a/border-atari-env/src/util/test.rs b/border-atari-env/src/util/test.rs index e8b7d5c7..6e284958 100644 --- a/border-atari-env/src/util/test.rs +++ b/border-atari-env/src/util/test.rs @@ -5,8 +5,7 @@ use crate::{ }; use anyhow::Result; use border_core::{ - record::Record, - Agent as Agent_, Configurable, Policy, ReplayBuffer as ReplayBuffer_, + record::Record, Agent as Agent_, Configurable, Policy, ReplayBuffer as ReplayBuffer_, }; use border_generic_replay_buffer::{BatchBase, GenericReplayBuffer}; use serde::Deserialize; diff --git a/border-minari/src/dataset.rs b/border-minari/src/dataset.rs index a60833d6..e979c002 100644 --- a/border-minari/src/dataset.rs +++ b/border-minari/src/dataset.rs @@ -2,7 +2,7 @@ use crate::{util, MinariConverter, MinariEnv}; use anyhow::Result; use border_core::{ExperienceBuffer, ReplayBuffer}; use border_generic_replay_buffer::{ - GenericTransitionBatch, GenericReplayBuffer, GenericReplayBufferConfig, + GenericReplayBuffer, GenericReplayBufferConfig, GenericTransitionBatch, }; use pyo3::{ types::{IntoPyDict, PyIterator}, From acda184c229ba7e5b0c7434dc1fe6f68d3a07533 Mon Sep 17 00:00:00 2001 From: taku-y Date: Mon, 11 Aug 2025 02:58:32 +0000 Subject: [PATCH 13/23] Support candle for dummy env --- border-core/Cargo.toml | 4 +++- border-core/src/dummy.rs | 28 ++++++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/border-core/Cargo.toml b/border-core/Cargo.toml index 78ba8ff4..b7bc2973 100644 --- a/border-core/Cargo.toml +++ b/border-core/Cargo.toml @@ -21,6 +21,7 @@ fastrand = { workspace = true } segment-tree = { workspace = true } xxhash-rust = { workspace = true } tch = { workspace = true, optional = true } +candle-core = { workspace = true, optional = true } # Consider to replace with fastrand rand = { workspace = true } @@ -28,4 +29,5 @@ rand = { workspace = true } tempdir = { workspace = true } [features] -tch = ["dep:tch"] \ No newline at end of file +tch = ["dep:tch"] +candle = ["dep:candle-core"] \ No newline at end of file diff --git a/border-core/src/dummy.rs b/border-core/src/dummy.rs index c0b987d7..2cc5e686 100644 --- a/border-core/src/dummy.rs +++ b/border-core/src/dummy.rs @@ -17,6 +17,13 @@ impl Into for DummyObs { } } +#[cfg(feature = "candle")] +impl Into for DummyObs { + fn into(self) -> candle_core::Tensor { + unimplemented!(); + } +} + #[derive(Clone, Debug)] /// Dummy action. pub struct DummyAct; @@ -41,6 +48,20 @@ impl From for DummyAct { } } +#[cfg(feature = "candle")] +impl Into for DummyAct { + fn into(self) -> candle_core::Tensor { + unimplemented!(); + } +} + +#[cfg(feature = "candle")] +impl From for DummyAct { + fn from(_value: candle_core::Tensor) -> Self { + unimplemented!(); + } +} + #[derive(Clone)] /// Dummy inner batch. pub struct DummyInnerBatch; @@ -52,6 +73,13 @@ impl Into for DummyInnerBatch { } } +#[cfg(feature = "candle")] +impl Into for DummyInnerBatch { + fn into(self) -> candle_core::Tensor { + unimplemented!(); + } +} + /// Dummy batch. pub struct DummyBatch; From b8b30dd0f4ee87c2b37c03e368e297cb71f325ba Mon Sep 17 00:00:00 2001 From: taku-y Date: Mon, 11 Aug 2025 03:39:30 +0000 Subject: [PATCH 14/23] WIP: tweaks on border-policy-no-backend --- border-policy-no-backend/src/mat.rs | 27 +++++++++++++------------ border-policy-no-backend/src/mat/tch.rs | 13 ++++++++++++ border-policy-no-backend/src/mlp.rs | 17 +--------------- border-policy-no-backend/src/mlp/tch.rs | 18 +++++++++++++++++ examples/gym/convert_policy/Cargo.toml | 2 +- 5 files changed, 47 insertions(+), 30 deletions(-) create mode 100644 border-policy-no-backend/src/mat/tch.rs create mode 100644 border-policy-no-backend/src/mlp/tch.rs diff --git a/border-policy-no-backend/src/mat.rs b/border-policy-no-backend/src/mat.rs index 1e9eb8d0..a1dff391 100644 --- a/border-policy-no-backend/src/mat.rs +++ b/border-policy-no-backend/src/mat.rs @@ -8,19 +8,20 @@ pub struct Mat { } #[cfg(feature = "tch")] -impl From for Mat { - fn from(x: tch::Tensor) -> Self { - let shape: Vec = x.size().iter().map(|e| *e as i32).collect(); - let (n, shape) = match shape.len() { - 1 => (shape[0] as usize, vec![shape[0], 1]), - 2 => ((shape[0] * shape[1]) as usize, shape), - _ => panic!("Invalid matrix size: {:?}", shape), - }; - let mut data: Vec = vec![0f32; n]; - x.f_copy_data(&mut data, n).unwrap(); - Self { data, shape } - } -} +mod tch; +// impl From for Mat { +// fn from(x: tch::Tensor) -> Self { +// let shape: Vec = x.size().iter().map(|e| *e as i32).collect(); +// let (n, shape) = match shape.len() { +// 1 => (shape[0] as usize, vec![shape[0], 1]), +// 2 => ((shape[0] * shape[1]) as usize, shape), +// _ => panic!("Invalid matrix size: {:?}", shape), +// }; +// let mut data: Vec = vec![0f32; n]; +// x.f_copy_data(&mut data, n).unwrap(); +// Self { data, shape } +// } +// } impl Mat { pub fn matmul(&self, x: &Mat) -> Self { diff --git a/border-policy-no-backend/src/mat/tch.rs b/border-policy-no-backend/src/mat/tch.rs new file mode 100644 index 00000000..e379ab9e --- /dev/null +++ b/border-policy-no-backend/src/mat/tch.rs @@ -0,0 +1,13 @@ +impl From for super::Mat { + fn from(x: tch::Tensor) -> Self { + let shape: Vec = x.size().iter().map(|e| *e as i32).collect(); + let (n, shape) = match shape.len() { + 1 => (shape[0] as usize, vec![shape[0], 1]), + 2 => ((shape[0] * shape[1]) as usize, shape), + _ => panic!("Invalid matrix size: {:?}", shape), + }; + let mut data: Vec = vec![0f32; n]; + x.f_copy_data(&mut data, n).unwrap(); + Self { data, shape } + } +} diff --git a/border-policy-no-backend/src/mlp.rs b/border-policy-no-backend/src/mlp.rs index d78f47a3..49c02ff5 100644 --- a/border-policy-no-backend/src/mlp.rs +++ b/border-policy-no-backend/src/mlp.rs @@ -2,7 +2,7 @@ use crate::Mat; use serde::{Deserialize, Serialize}; #[cfg(feature = "tch")] -use tch::nn::VarStore; +mod tch; #[derive(Clone, Debug, Deserialize, Serialize)] /// Multilayer perceptron with ReLU activation function. @@ -26,19 +26,4 @@ impl Mlp { } x.tanh() } - - #[cfg(feature = "tch")] - pub fn from_varstore(vs: &VarStore, w_names: &[&str], b_names: &[&str]) -> Self { - let vars = vs.variables(); - let ws: Vec = w_names - .iter() - .map(|name| vars[&name.to_string()].copy().into()) - .collect(); - let bs: Vec = b_names - .iter() - .map(|name| vars[&name.to_string()].copy().into()) - .collect(); - - Self { ws, bs } - } } diff --git a/border-policy-no-backend/src/mlp/tch.rs b/border-policy-no-backend/src/mlp/tch.rs new file mode 100644 index 00000000..f2bf236f --- /dev/null +++ b/border-policy-no-backend/src/mlp/tch.rs @@ -0,0 +1,18 @@ +use crate::Mat; +use tch::nn::VarStore; + +impl super::Mlp { + pub fn from_varstore(vs: &VarStore, w_names: &[&str], b_names: &[&str]) -> Self { + let vars = vs.variables(); + let ws: Vec = w_names + .iter() + .map(|name| vars[&name.to_string()].copy().into()) + .collect(); + let bs: Vec = b_names + .iter() + .map(|name| vars[&name.to_string()].copy().into()) + .collect(); + + Self { ws, bs } + } +} diff --git a/examples/gym/convert_policy/Cargo.toml b/examples/gym/convert_policy/Cargo.toml index 4beea915..a34d50cc 100644 --- a/examples/gym/convert_policy/Cargo.toml +++ b/examples/gym/convert_policy/Cargo.toml @@ -13,7 +13,7 @@ border-policy-no-backend = { version = "0.0.9", path = "../../../border-policy-n "tch", ] } border-tch-agent = { version = "0.0.9", path = "../../../border-tch-agent" } -border-core = { version = "0.0.9", path = "../../../border-core", features = ["tch"] } +border-core = { version = "0.0.9", path = "../../../border-core", features = ["tch", "candle"] } serde = "1.0.194" tch = "0.16.0" bincode = "1.3.3" From 89f8ab8e808897f8cf4de27948be47e5788261d6 Mon Sep 17 00:00:00 2001 From: taku-y Date: Mon, 11 Aug 2025 07:39:33 +0000 Subject: [PATCH 15/23] Add directory --- examples/gym/convert_policy/model/from_candle/.gitkeep | 0 examples/gym/convert_policy/model/from_tch/.gitkeep | 0 2 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 examples/gym/convert_policy/model/from_candle/.gitkeep create mode 100644 examples/gym/convert_policy/model/from_tch/.gitkeep diff --git a/examples/gym/convert_policy/model/from_candle/.gitkeep b/examples/gym/convert_policy/model/from_candle/.gitkeep new file mode 100644 index 00000000..e69de29b diff --git a/examples/gym/convert_policy/model/from_tch/.gitkeep b/examples/gym/convert_policy/model/from_tch/.gitkeep new file mode 100644 index 00000000..e69de29b From 3034928f33a448fe9e44b60704c6d49c4ba162f3 Mon Sep 17 00:00:00 2001 From: taku-y Date: Mon, 11 Aug 2025 20:22:22 +0900 Subject: [PATCH 16/23] Support candle for policy without backend --- border-candle-agent/src/sac/base.rs | 4 ++++ border-candle-agent/src/util/actor.rs | 4 ++++ border-policy-no-backend/Cargo.toml | 6 +++++- border-policy-no-backend/src/mat.rs | 16 +++------------- border-policy-no-backend/src/mat/candle.rs | 16 ++++++++++++++++ border-policy-no-backend/src/mlp.rs | 2 ++ border-policy-no-backend/src/mlp/candle.rs | 17 +++++++++++++++++ 7 files changed, 51 insertions(+), 14 deletions(-) create mode 100644 border-policy-no-backend/src/mat/candle.rs create mode 100644 border-policy-no-backend/src/mlp/candle.rs diff --git a/border-candle-agent/src/sac/base.rs b/border-candle-agent/src/sac/base.rs index 51b2be61..1dd0071f 100644 --- a/border-candle-agent/src/sac/base.rs +++ b/border-candle-agent/src/sac/base.rs @@ -147,6 +147,10 @@ where Ok(record) } + + pub fn get_policy_net(&self) -> &GaussianActor

{ + &self.actor + } } impl Policy for Sac diff --git a/border-candle-agent/src/util/actor.rs b/border-candle-agent/src/util/actor.rs index b6aeb199..111c3e07 100644 --- a/border-candle-agent/src/util/actor.rs +++ b/border-candle-agent/src/util/actor.rs @@ -264,6 +264,10 @@ where Ok(()) } + + pub fn get_var_map(&self) -> &VarMap { + &self.varmap + } } impl

Clone for GaussianActor

diff --git a/border-policy-no-backend/Cargo.toml b/border-policy-no-backend/Cargo.toml index ecef4b24..14fede9c 100644 --- a/border-policy-no-backend/Cargo.toml +++ b/border-policy-no-backend/Cargo.toml @@ -12,15 +12,19 @@ readme = "README.md" [dependencies] border-core = { version = "0.0.9", path = "../border-core" } border-tch-agent = { version = "0.0.9", path = "../border-tch-agent", optional = true } +border-candle-agent = { version = "0.0.9", path = "../border-candle-agent", optional = true } serde = { workspace = true, features = ["derive"] } log = { workspace = true } anyhow = { workspace = true } tch = { workspace = true, optional = true } +candle-core = { workspace = true, optional = true } +candle-nn = { workspace = true, optional = true } rand = { workspace = true } [dev-dependencies] tempdir = { workspace = true } -tch = { workspace = true } +# tch = { workspace = true } [features] tch = ["border-tch-agent", "dep:tch"] +candle = ["border-candle-agent", "dep:candle-core", "dep:candle-nn"] diff --git a/border-policy-no-backend/src/mat.rs b/border-policy-no-backend/src/mat.rs index a1dff391..b073c8d2 100644 --- a/border-policy-no-backend/src/mat.rs +++ b/border-policy-no-backend/src/mat.rs @@ -9,19 +9,9 @@ pub struct Mat { #[cfg(feature = "tch")] mod tch; -// impl From for Mat { -// fn from(x: tch::Tensor) -> Self { -// let shape: Vec = x.size().iter().map(|e| *e as i32).collect(); -// let (n, shape) = match shape.len() { -// 1 => (shape[0] as usize, vec![shape[0], 1]), -// 2 => ((shape[0] * shape[1]) as usize, shape), -// _ => panic!("Invalid matrix size: {:?}", shape), -// }; -// let mut data: Vec = vec![0f32; n]; -// x.f_copy_data(&mut data, n).unwrap(); -// Self { data, shape } -// } -// } + +#[cfg(feature = "candle")] +mod candle; impl Mat { pub fn matmul(&self, x: &Mat) -> Self { diff --git a/border-policy-no-backend/src/mat/candle.rs b/border-policy-no-backend/src/mat/candle.rs new file mode 100644 index 00000000..d468dbe9 --- /dev/null +++ b/border-policy-no-backend/src/mat/candle.rs @@ -0,0 +1,16 @@ +impl From for super::Mat { + fn from(x: candle_core::Tensor) -> Self { + let shape: Vec = x.dims().iter().map(|e| *e as i32).collect(); + let data = match shape.len() { + 1 => x.to_vec1::().unwrap(), + 2 => x.to_vec2::().unwrap().into_iter().flatten().collect(), + _ => panic!("Invalid matrix size: {:?}", shape), + }; + let shape = match shape.len() { + 1 => vec![shape[0], 1], + 2 => shape, + _ => panic!("Invalid matrix size: {:?}", shape), + }; + Self { data, shape } + } +} diff --git a/border-policy-no-backend/src/mlp.rs b/border-policy-no-backend/src/mlp.rs index 49c02ff5..04d850c8 100644 --- a/border-policy-no-backend/src/mlp.rs +++ b/border-policy-no-backend/src/mlp.rs @@ -1,6 +1,8 @@ use crate::Mat; use serde::{Deserialize, Serialize}; +#[cfg(feature = "candle")] +mod candle; #[cfg(feature = "tch")] mod tch; diff --git a/border-policy-no-backend/src/mlp/candle.rs b/border-policy-no-backend/src/mlp/candle.rs new file mode 100644 index 00000000..6ed4ed87 --- /dev/null +++ b/border-policy-no-backend/src/mlp/candle.rs @@ -0,0 +1,17 @@ +use crate::Mat; +use candle_nn::VarMap; + +impl super::Mlp { + pub fn from_varmap(vm: &VarMap, w_names: &[&str], b_names: &[&str]) -> Self { + let vars = vm.data().lock().unwrap(); + let ws: Vec = w_names + .iter() + .map(|name| vars.get(*name).unwrap().as_tensor().clone().into()) + .collect(); + let bs: Vec = b_names + .iter() + .map(|name| vars.get(*name).unwrap().as_tensor().clone().into()) + .collect(); + Self { ws, bs } + } +} From 2b67f11899793179216abef59063444215c43eca Mon Sep 17 00:00:00 2001 From: taku-y Date: Mon, 11 Aug 2025 20:22:48 +0900 Subject: [PATCH 17/23] Fix examples --- examples/gym/convert_policy/Cargo.toml | 16 +++-- examples/gym/convert_policy/src/candle.rs | 52 ++++++++++++++++ examples/gym/convert_policy/src/main.rs | 76 ++++++----------------- examples/gym/convert_policy/src/tch.rs | 44 +++++++++++++ examples/gym/pendulum_std/Cargo.toml | 6 +- examples/gym/pendulum_std/src/main.rs | 12 +--- examples/gym/sac_pendulum/src/main.rs | 2 +- 7 files changed, 132 insertions(+), 76 deletions(-) create mode 100644 examples/gym/convert_policy/src/candle.rs create mode 100644 examples/gym/convert_policy/src/tch.rs diff --git a/examples/gym/convert_policy/Cargo.toml b/examples/gym/convert_policy/Cargo.toml index a34d50cc..22a48fab 100644 --- a/examples/gym/convert_policy/Cargo.toml +++ b/examples/gym/convert_policy/Cargo.toml @@ -9,14 +9,18 @@ anyhow = "1.0.38" clap = { version = "4.5.8", features = ["derive"] } env_logger = "0.8.2" numpy = "0.14.1" -border-policy-no-backend = { version = "0.0.9", path = "../../../border-policy-no-backend", features = [ - "tch", -] } -border-tch-agent = { version = "0.0.9", path = "../../../border-tch-agent" } -border-core = { version = "0.0.9", path = "../../../border-core", features = ["tch", "candle"] } +border-policy-no-backend = { version = "0.0.9", path = "../../../border-policy-no-backend" } +border-tch-agent = { version = "0.0.9", path = "../../../border-tch-agent", optional = true } +border-candle-agent = { version = "0.0.9", path = "../../../border-candle-agent", optional = true } +border-core = { version = "0.0.9", path = "../../../border-core" } serde = "1.0.194" -tch = "0.16.0" +tch = { version = "0.16.0", optional = true } +candle-core = { version = "=0.8.4", optional = true } bincode = "1.3.3" [dev-dependencies] tempdir = "0.3.7" + +[features] +tch = ["border-policy-no-backend/tch", "border-tch-agent", "border-core/tch", "dep:tch"] +candle = ["border-policy-no-backend/candle", "border-candle-agent", "border-core/candle", "dep:candle-core"] diff --git a/examples/gym/convert_policy/src/candle.rs b/examples/gym/convert_policy/src/candle.rs new file mode 100644 index 00000000..28e21787 --- /dev/null +++ b/examples/gym/convert_policy/src/candle.rs @@ -0,0 +1,52 @@ +use anyhow::Result; +use border_candle_agent::{ + mlp::{Mlp, Mlp2, MlpConfig}, + sac::{Sac, SacConfig}, + util::{actor::GaussianActorConfig, critic::MultiCriticConfig}, + Activation, +}; +use border_core::{dummy::*, Agent, Configurable}; +use border_policy_no_backend::Mlp as MlpNoBackend; + +const DIM_OBS: i64 = 3; +const DIM_ACT: i64 = 1; + +type Sac_ = Sac; + +fn create_sac_config() -> SacConfig { + // Omit learning related parameters + let actor_config = GaussianActorConfig::default() + .out_dim(DIM_ACT) + .policy_config(MlpConfig::new( + DIM_OBS, + vec![64, 64], + DIM_ACT, + Activation::None, + )); + let critic_config = MultiCriticConfig::default() + .q_config(MlpConfig::new( + DIM_OBS + DIM_ACT, + vec![64, 64], + 1, + Activation::None, + )) + .n_nets(1); + SacConfig::default() + .actor_config(actor_config) + .critic_config(critic_config) + .device(candle_core::Device::Cpu) +} + +pub fn load_sac_model(src_path: &str) -> Result { + let config = create_sac_config(); + let mut sac = Sac_::build(config); + sac.load_params(src_path.as_ref())?; + Ok(sac) +} + +pub fn create_mlp(sac: &Sac_) -> MlpNoBackend { + let vm = sac.get_policy_net().get_var_map(); + let w_names = ["actor.mlp.ln0.weight", "actor.mlp.ln1.weight", "actor.mean.weight"]; + let b_names = ["actor.mlp.ln0.bias", "actor.mlp.ln1.bias", "actor.mean.bias"]; + MlpNoBackend::from_varmap(vm, &w_names, &b_names) +} diff --git a/examples/gym/convert_policy/src/main.rs b/examples/gym/convert_policy/src/main.rs index 18119ba7..eabbdeeb 100644 --- a/examples/gym/convert_policy/src/main.rs +++ b/examples/gym/convert_policy/src/main.rs @@ -2,51 +2,20 @@ //! //! You need to prepare the model parameter files by `sac_pendulum_tch.rs` in advance. //! -use anyhow::Result; -use border_core::{Agent, Configurable, dummy::*}; -use border_policy_no_backend::Mlp as MlpNoBackend; -use border_tch_agent::{ - model::ModelBase, - mlp::{Mlp, Mlp2, MlpConfig}, - sac::{ActorConfig, CriticConfig, SacConfig, Sac}, -}; -use std::{fs, io::Write}; - -const DIM_OBS: i64 = 3; -const DIM_ACT: i64 = 1; -type Sac_ = Sac; +#[cfg(all(feature = "tch", not(feature = "candle")))] +mod tch; +#[cfg(all(feature = "tch", not(feature = "candle")))] +use tch::{create_mlp, load_sac_model}; -fn create_sac_config() -> SacConfig { - // Omit learning related parameters - let actor_config = ActorConfig::default() - .out_dim(DIM_ACT) - .pi_config(MlpConfig::new(DIM_OBS, vec![64, 64], DIM_ACT, false)); - let critic_config = CriticConfig::default().q_config(MlpConfig::new( - DIM_OBS + DIM_ACT, - vec![64, 64], - 1, - false, - )); - SacConfig::default() - .actor_config(actor_config) - .critic_config(critic_config) - .device(tch::Device::Cpu) -} - -fn load_sac_model(src_path: &str) -> Result { - let config = create_sac_config(); - let mut sac = Sac_::build(config); - sac.load_params(src_path.as_ref())?; - Ok(sac) -} +#[cfg(all(feature = "candle", not(feature = "tch")))] +mod candle; +#[cfg(all(feature = "candle", not(feature = "tch")))] +use candle::{create_mlp, load_sac_model}; -fn create_mlp(sac: &Sac_) -> MlpNoBackend { - let vs = sac.get_policy_net().get_var_store(); - let w_names = ["mlp.al0.weight", "mlp.al1.weight", "ml.weight"]; - let b_names = ["mlp.al0.bias", "mlp.al1.bias", "ml.bias"]; - MlpNoBackend::from_varstore(vs, &w_names, &b_names) -} +use anyhow::Result; +use border_policy_no_backend::Mlp as MlpNoBackend; +use std::{fs, io::Write}; fn serialize_to_file(mlp: &MlpNoBackend, dest_path: &str) -> Result<()> { let encoded = bincode::serialize(mlp)?; @@ -59,8 +28,15 @@ fn serialize_to_file(mlp: &MlpNoBackend, dest_path: &str) -> Result<()> { } fn main() -> Result<()> { - let src_path = "../sac_pendulum_tch/model/best"; - let dest_path = "./model/mlp.bincode"; + #[cfg(all(feature = "tch", not(feature = "candle")))] + let (src_path, dest_path) = { + ("../sac_pendulum_tch/model/best", "./model/from_tch/mlp.bincode") + }; + + #[cfg(all(feature = "candle", not(feature = "tch")))] + let (src_path, dest_path) = { + ("../sac_pendulum/model/best", "./model/from_candle/mlp.bincode") + }; let sac = load_sac_model(src_path)?; let mlp = create_mlp(&sac); @@ -68,15 +44,3 @@ fn main() -> Result<()> { Ok(()) } - -// #[test] -// fn test() -> Result<()> { -// let src_path = "/root/border/border/examples/gym/model/tch/sac_pendulum/best"; -// let dest_path = "/root/border/border/examples/gym/model/edge/sac_pendulum/best/mlp.bincode"; - -// let sac = load_sac_model(src_path)?; -// let mlp = create_mlp(&sac); -// serialize_to_file(&mlp, dest_path)?; - -// Ok(()) -// } diff --git a/examples/gym/convert_policy/src/tch.rs b/examples/gym/convert_policy/src/tch.rs new file mode 100644 index 00000000..81a27fd0 --- /dev/null +++ b/examples/gym/convert_policy/src/tch.rs @@ -0,0 +1,44 @@ +use anyhow::Result; +use border_core::{Agent, Configurable, dummy::*}; +use border_policy_no_backend::Mlp as MlpNoBackend; +use border_tch_agent::{ + model::ModelBase, + mlp::{Mlp, Mlp2, MlpConfig}, + sac::{ActorConfig, CriticConfig, SacConfig, Sac}, +}; + +const DIM_OBS: i64 = 3; +const DIM_ACT: i64 = 1; + +type Sac_ = Sac; + +fn create_sac_config() -> SacConfig { + // Omit learning related parameters + let actor_config = ActorConfig::default() + .out_dim(DIM_ACT) + .pi_config(MlpConfig::new(DIM_OBS, vec![64, 64], DIM_ACT, false)); + let critic_config = CriticConfig::default().q_config(MlpConfig::new( + DIM_OBS + DIM_ACT, + vec![64, 64], + 1, + false, + )); + SacConfig::default() + .actor_config(actor_config) + .critic_config(critic_config) + .device(tch::Device::Cpu) +} + +pub fn load_sac_model(src_path: &str) -> Result { + let config = create_sac_config(); + let mut sac = Sac_::build(config); + sac.load_params(src_path.as_ref())?; + Ok(sac) +} + +pub fn create_mlp(sac: &Sac_) -> MlpNoBackend { + let vs = sac.get_policy_net().get_var_store(); + let w_names = ["mlp.al0.weight", "mlp.al1.weight", "ml.weight"]; + let b_names = ["mlp.al0.bias", "mlp.al1.bias", "ml.bias"]; + MlpNoBackend::from_varstore(vs, &w_names, &b_names) +} diff --git a/examples/gym/pendulum_std/Cargo.toml b/examples/gym/pendulum_std/Cargo.toml index b91ec096..d43c2274 100644 --- a/examples/gym/pendulum_std/Cargo.toml +++ b/examples/gym/pendulum_std/Cargo.toml @@ -9,13 +9,11 @@ anyhow = "1.0.38" clap = { version = "4.5.8", features = ["derive"] } env_logger = "0.8.2" numpy = "0.14.1" -border-policy-no-backend = { version = "0.0.9", path = "../../../border-policy-no-backend", features = [ - "tch", -] } +border-policy-no-backend = { version = "0.0.9", path = "../../../border-policy-no-backend" } border-core = { version = "0.0.9", path = "../../../border-core" } border-py-gym-env = { version = "0.0.9", path = "../../../border-py-gym-env" } serde = "1.0.194" -tch = "0.16.0" +# tch = "0.16.0" bincode = "1.3.3" pyo3 = { version = "=0.14.5", default-features = false } ndarray = "0.15.1" diff --git a/examples/gym/pendulum_std/src/main.rs b/examples/gym/pendulum_std/src/main.rs index a6b26963..b209c5cc 100644 --- a/examples/gym/pendulum_std/src/main.rs +++ b/examples/gym/pendulum_std/src/main.rs @@ -159,15 +159,9 @@ fn eval(path: &str, n_episodes: usize, render: bool) -> Result<()> { fn main() -> Result<()> { env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init(); - let path = "../convert_policy/model/mlp.bincode"; + // let path = "../convert_policy/model/mlp.bincode"; + // let path = "../convert_policy/model/from_tch/mlp.bincode"; + let path = "../convert_policy/model/from_candle/mlp.bincode"; let _ = eval(path, 5, true)?; Ok(()) } - -// #[test] -// fn test_pendulum_edge() -> Result<()> { -// env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init(); -// let path = "/root/border/examples/convert_policy/model/mlp.bincode"; -// let _ = eval(path, 1, false)?; -// Ok(()) -// } diff --git a/examples/gym/sac_pendulum/src/main.rs b/examples/gym/sac_pendulum/src/main.rs index 7f6bd690..9e1d43f6 100644 --- a/examples/gym/sac_pendulum/src/main.rs +++ b/examples/gym/sac_pendulum/src/main.rs @@ -49,7 +49,7 @@ const EVAL_INTERVAL: usize = 2_000; const REPLAY_BUFFER_CAPACITY: usize = 100_000; const N_EPISODES_PER_EVAL: usize = 5; const ENV_NAME: &str = "Pendulum-v1"; -const MODEL_DIR: &str = "./model/candle/sac_pendulum"; +const MODEL_DIR: &str = "./model"; const MLFLOW_EXPERIMENT_NAME: &str = "Gym"; const MLFLOW_RUN_NAME: &str = "sac-gym-pendulum-v1-candle"; const MLFLOW_TAGS: &[(&str, &str)] = &[("env", "pendulum"), ("algo", "sac"), ("backend", "candle")]; From 5e750937f791e9158c46f1bda45226af4f1d07c9 Mon Sep 17 00:00:00 2001 From: taku-y Date: Sat, 16 Aug 2025 12:13:47 +0000 Subject: [PATCH 18/23] Loose trait bound --- border-generic-replay-buffer/README.md | 2 +- border-generic-replay-buffer/src/step_proc.rs | 12 ++++++++---- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/border-generic-replay-buffer/README.md b/border-generic-replay-buffer/README.md index a0a204e5..37fab578 100644 --- a/border-generic-replay-buffer/README.md +++ b/border-generic-replay-buffer/README.md @@ -48,7 +48,7 @@ that transitions are properly formatted and stored in the replay buffer for trai [`SimpleStepProcessor`] can be used with [`GenericReplayBuffer`]. It converts `E::Obs` and `E::Act` into their respective [`BatchBase`] types and generates [`GenericTransitionBatch`]. This conversion -relies on the trait bounds `O: From` and `A: From`. +relies on the trait bounds `E::Obs: Into` and `E::Act: Into`. # [`GenericReplayBuffer`] diff --git a/border-generic-replay-buffer/src/step_proc.rs b/border-generic-replay-buffer/src/step_proc.rs index e6ff6eb8..e8eaf0bb 100644 --- a/border-generic-replay-buffer/src/step_proc.rs +++ b/border-generic-replay-buffer/src/step_proc.rs @@ -32,8 +32,10 @@ impl Default for SimpleStepProcessorConfig { /// # Type Parameters /// /// * `E` - The environment type, must implement `Env` -/// * `O` - The observation batch type, must implement `BatchBase` and `From` -/// * `A` - The action batch type, must implement `BatchBase` and `From` +/// * `O` - The observation batch type, must implement `BatchBase`. +/// `E::Obs` must implement `Into`. +/// * `A` - The action batch type, must implement `BatchBase`. +/// `E::Act` must implement `Into`. pub struct SimpleStepProcessor { /// The previous observation, used to construct transitions. prev_obs: Option, @@ -44,8 +46,10 @@ pub struct SimpleStepProcessor { impl StepProcessor for SimpleStepProcessor where E: Env, - O: BatchBase + From, - A: BatchBase + From, + O: BatchBase, + A: BatchBase, + E::Obs: Into, + E::Act: Into, { type Config = SimpleStepProcessorConfig; type Output = GenericTransitionBatch; From 0d1e77447efb0bb3e25b6c2744eaed952f71c4a5 Mon Sep 17 00:00:00 2001 From: taku-y Date: Sat, 16 Aug 2025 12:24:45 +0000 Subject: [PATCH 19/23] Remove obsolete files --- border-core/src/generic_replay_buffer.rs | 29 -- border-core/src/generic_replay_buffer/base.rs | 427 ------------------ .../base/iw_scheduler.rs | 46 -- .../generic_replay_buffer/base/sum_tree.rs | 217 --------- .../src/generic_replay_buffer/batch.rs | 206 --------- .../src/generic_replay_buffer/config.rs | 294 ------------ .../src/generic_replay_buffer/step_proc.rs | 138 ------ 7 files changed, 1357 deletions(-) delete mode 100644 border-core/src/generic_replay_buffer.rs delete mode 100644 border-core/src/generic_replay_buffer/base.rs delete mode 100644 border-core/src/generic_replay_buffer/base/iw_scheduler.rs delete mode 100644 border-core/src/generic_replay_buffer/base/sum_tree.rs delete mode 100644 border-core/src/generic_replay_buffer/batch.rs delete mode 100644 border-core/src/generic_replay_buffer/config.rs delete mode 100644 border-core/src/generic_replay_buffer/step_proc.rs diff --git a/border-core/src/generic_replay_buffer.rs b/border-core/src/generic_replay_buffer.rs deleted file mode 100644 index 6074161f..00000000 --- a/border-core/src/generic_replay_buffer.rs +++ /dev/null @@ -1,29 +0,0 @@ -//! Generic implementation of replay buffers for reinforcement learning. -//! -//! This module provides a flexible and efficient implementation of replay buffers -//! that can handle arbitrary observation and action types. It supports both -//! standard experience replay and prioritized experience replay (PER). -//! -//! # Key Components -//! -//! - [`SimpleReplayBuffer`]: A generic replay buffer implementation -//! - [`GenericTransitionBatch`]: A generic batch structure for transitions -//! - [`SimpleStepProcessor`]: A processor for converting environment steps to transitions -//! - [`PerConfig`]: Configuration for prioritized experience replay -//! -//! # Features -//! -//! - Generic type support for observations and actions -//! - Efficient batch processing -//! - Prioritized experience replay with importance sampling -//! - Configurable weight normalization -//! - Step processing for non-vectorized environments - -mod base; -mod batch; -mod config; -mod step_proc; -pub use base::{IwScheduler, SimpleReplayBuffer, WeightNormalizer}; -pub use batch::{BatchBase, GenericTransitionBatch}; -pub use config::{PerConfig, SimpleReplayBufferConfig}; -pub use step_proc::{SimpleStepProcessor, SimpleStepProcessorConfig}; diff --git a/border-core/src/generic_replay_buffer/base.rs b/border-core/src/generic_replay_buffer/base.rs deleted file mode 100644 index d9dda3dc..00000000 --- a/border-core/src/generic_replay_buffer/base.rs +++ /dev/null @@ -1,427 +0,0 @@ -//! Generic implementation of replay buffers for reinforcement learning. -//! -//! This module provides a generic implementation of replay buffers that can store -//! and sample transitions of arbitrary observation and action types. It supports: -//! - Standard experience replay -//! - Prioritized experience replay (PER) -//! - Importance sampling weights for off-policy learning - -mod iw_scheduler; -mod sum_tree; -use super::{config::PerConfig, BatchBase, GenericTransitionBatch, SimpleReplayBufferConfig}; -use crate::{ExperienceBuffer, ReplayBuffer, TransitionBatch}; -use anyhow::Result; -pub use iw_scheduler::IwScheduler; -use rand::{rngs::StdRng, RngCore, SeedableRng}; -use sum_tree::SumTree; -pub use sum_tree::WeightNormalizer; - -/// State management for Prioritized Experience Replay (PER). -/// -/// This struct maintains the necessary state for PER, including: -/// - A sum tree for efficient priority sampling -/// - An importance weight scheduler for adjusting sample weights -struct PerState { - /// A sum tree data structure for efficient priority sampling. - sum_tree: SumTree, - - /// Scheduler for importance sampling weights. - iw_scheduler: IwScheduler, -} - -impl PerState { - /// Creates a new PER state with the given configuration. - /// - /// # Arguments - /// - /// * `capacity` - Maximum number of transitions to store - /// * `per_config` - Configuration for prioritized experience replay - fn new(capacity: usize, per_config: &PerConfig) -> Self { - Self { - sum_tree: SumTree::new(capacity, per_config.alpha, per_config.normalize), - iw_scheduler: IwScheduler::new( - per_config.beta_0, - per_config.beta_final, - per_config.n_opts_final, - ), - } - } -} - -/// A generic implementation of a replay buffer for reinforcement learning. -/// -/// This buffer can store transitions of arbitrary observation and action types, -/// making it suitable for a wide range of reinforcement learning tasks. It supports: -/// - Standard experience replay -/// - Prioritized experience replay (optional) -/// - Efficient sampling and storage -/// -/// # Type Parameters -/// -/// * `O` - The type of observations, must implement [`BatchBase`] -/// * `A` - The type of actions, must implement [`BatchBase`] -/// -/// # Examples -/// -/// ```ignore -/// let config = SimpleReplayBufferConfig { -/// capacity: 10000, -/// per_config: Some(PerConfig { -/// alpha: 0.6, -/// beta_0: 0.4, -/// beta_final: 1.0, -/// n_opts_final: 100000, -/// normalize: true, -/// }), -/// }; -/// -/// let mut buffer = SimpleReplayBuffer::::build(&config); -/// -/// // Add transitions -/// buffer.push(transition)?; -/// -/// // Sample a batch -/// let batch = buffer.batch(32)?; -/// ``` -pub struct SimpleReplayBuffer -where - O: BatchBase, - A: BatchBase, -{ - /// Maximum number of transitions that can be stored. - capacity: usize, - - /// Current insertion index. - i: usize, - - /// Current number of stored transitions. - size: usize, - - /// Storage for observations. - obs: O, - - /// Storage for actions. - act: A, - - /// Storage for next observations. - next_obs: O, - - /// Storage for rewards. - reward: Vec, - - /// Storage for termination flags. - is_terminated: Vec, - - /// Storage for truncation flags. - is_truncated: Vec, - - /// Random number generator for sampling. - rng: StdRng, - - /// State for prioritized experience replay, if enabled. - per_state: Option, -} - -impl SimpleReplayBuffer -where - O: BatchBase, - A: BatchBase, -{ - /// Pushes rewards into the buffer at the specified index. - /// - /// # Arguments - /// - /// * `i` - Starting index for insertion - /// * `b` - Vector of rewards to insert - #[inline] - fn push_reward(&mut self, i: usize, b: &Vec) { - let mut j = i; - for r in b.iter() { - self.reward[j] = *r; - j += 1; - if j == self.capacity { - j = 0; - } - } - } - - /// Pushes termination flags into the buffer at the specified index. - /// - /// # Arguments - /// - /// * `i` - Starting index for insertion - /// * `b` - Vector of termination flags to insert - #[inline] - fn push_is_terminated(&mut self, i: usize, b: &Vec) { - let mut j = i; - for d in b.iter() { - self.is_terminated[j] = *d; - j += 1; - if j == self.capacity { - j = 0; - } - } - } - - /// Pushes truncation flags into the buffer at the specified index. - /// - /// # Arguments - /// - /// * `i` - Starting index for insertion - /// * `b` - Vector of truncation flags to insert - fn push_is_truncated(&mut self, i: usize, b: &Vec) { - let mut j = i; - for d in b.iter() { - self.is_truncated[j] = *d; - j += 1; - if j == self.capacity { - j = 0; - } - } - } - - /// Samples rewards for the given indices. - /// - /// # Arguments - /// - /// * `ixs` - Indices to sample from - /// - /// # Returns - /// - /// Vector of sampled rewards - fn sample_reward(&self, ixs: &Vec) -> Vec { - ixs.iter().map(|ix| self.reward[*ix]).collect() - } - - /// Samples termination flags for the given indices. - /// - /// # Arguments - /// - /// * `ixs` - Indices to sample from - /// - /// # Returns - /// - /// Vector of sampled termination flags - fn sample_is_terminated(&self, ixs: &Vec) -> Vec { - ixs.iter().map(|ix| self.is_terminated[*ix]).collect() - } - - /// Samples truncation flags for the given indices. - /// - /// # Arguments - /// - /// * `ixs` - Indices to sample from - /// - /// # Returns - /// - /// Vector of sampled truncation flags - fn sample_is_truncated(&self, ixs: &Vec) -> Vec { - ixs.iter().map(|ix| self.is_truncated[*ix]).collect() - } - - /// Sets priorities for newly added samples in prioritized experience replay. - /// - /// # Arguments - /// - /// * `batch_size` - Number of new samples to prioritize - fn set_priority(&mut self, batch_size: usize) { - let sum_tree = &mut self.per_state.as_mut().unwrap().sum_tree; - let max_p = sum_tree.max(); - - for j in 0..batch_size { - let i = (self.i + j) % self.capacity; - sum_tree.add(i, max_p); - } - } - - /// Returns a batch containing all actions in the buffer. - /// - /// # Warning - /// - /// This method should be used with caution on large replay buffers - /// as it may consume significant memory. - pub fn whole_actions(&self) -> A { - let ixs = (0..self.size).collect::>(); - self.act.sample(&ixs) - } - - /// Returns the number of terminated episodes in the buffer. - pub fn num_terminated_flags(&self) -> usize { - self.is_terminated - .iter() - .map(|is_terminated| *is_terminated as usize) - .sum() - } - - /// Returns the number of truncated episodes in the buffer. - pub fn num_truncated_flags(&self) -> usize { - self.is_truncated - .iter() - .map(|is_truncated| *is_truncated as usize) - .sum() - } - - /// Returns the sum of all rewards in the buffer. - pub fn sum_rewards(&self) -> f32 { - self.reward.iter().sum() - } -} - -impl ExperienceBuffer for SimpleReplayBuffer -where - O: BatchBase, - A: BatchBase, -{ - type Item = GenericTransitionBatch; - - /// Returns the current number of transitions in the buffer. - fn len(&self) -> usize { - self.size - } - - /// Adds a new transition to the buffer. - /// - /// # Arguments - /// - /// * `tr` - The transition to add - /// - /// # Returns - /// - /// `Ok(())` if the transition was added successfully - /// - /// # Errors - /// - /// Returns an error if the buffer is full and cannot accept more transitions - fn push(&mut self, tr: Self::Item) -> Result<()> { - let len = tr.len(); // batch size - let (obs, act, next_obs, reward, is_terminated, is_truncated, _, _) = tr.unpack(); - self.obs.push(self.i, obs); - self.act.push(self.i, act); - self.next_obs.push(self.i, next_obs); - self.push_reward(self.i, &reward); - self.push_is_terminated(self.i, &is_terminated); - self.push_is_truncated(self.i, &is_truncated); - - if self.per_state.is_some() { - self.set_priority(len) - }; - - self.i = (self.i + len) % self.capacity; - self.size += len; - if self.size >= self.capacity { - self.size = self.capacity; - } - - Ok(()) - } -} - -impl ReplayBuffer for SimpleReplayBuffer -where - O: BatchBase, - A: BatchBase, -{ - type Config = SimpleReplayBufferConfig; - type Batch = GenericTransitionBatch; - - /// Creates a new replay buffer with the given configuration. - /// - /// # Arguments - /// - /// * `config` - Configuration for the replay buffer - /// - /// # Returns - /// - /// A new instance of the replay buffer - fn build(config: &Self::Config) -> Self { - let capacity = config.capacity; - let per_state = match &config.per_config { - Some(per_config) => Some(PerState::new(capacity, per_config)), - None => None, - }; - - Self { - capacity, - i: 0, - size: 0, - obs: O::new(capacity), - act: A::new(capacity), - next_obs: O::new(capacity), - reward: vec![0.; capacity], - is_terminated: vec![0; capacity], - is_truncated: vec![0; capacity], - rng: StdRng::seed_from_u64(config.seed as _), - per_state, - } - } - - /// Samples a batch of transitions from the buffer. - /// - /// If prioritized experience replay is enabled, samples are selected - /// according to their priorities. Otherwise, uniform random sampling is used. - /// - /// # Arguments - /// - /// * `size` - Number of transitions to sample - /// - /// # Returns - /// - /// A batch of sampled transitions - /// - /// # Errors - /// - /// Returns an error if: - /// - The buffer is empty - /// - The requested batch size is larger than the buffer size - fn batch(&mut self, size: usize) -> Result { - let (ixs, weight) = if let Some(per_state) = &self.per_state { - let sum_tree = &per_state.sum_tree; - let beta = per_state.iw_scheduler.beta(); - let (ixs, weight) = sum_tree.sample(size, beta); - let ixs = ixs.iter().map(|&ix| ix as usize).collect(); - (ixs, Some(weight)) - } else { - let ixs = (0..size) - // .map(|_| self.rng.usize(..self.size)) - .map(|_| (self.rng.next_u32() as usize) % self.size) - .collect::>(); - let weight = None; - (ixs, weight) - }; - - Ok(Self::Batch { - obs: self.obs.sample(&ixs), - act: self.act.sample(&ixs), - next_obs: self.next_obs.sample(&ixs), - reward: self.sample_reward(&ixs), - is_terminated: self.sample_is_terminated(&ixs), - is_truncated: self.sample_is_truncated(&ixs), - ix_sample: Some(ixs), - weight, - }) - } - - /// Updates the priorities of transitions in the buffer. - /// - /// This method is used in prioritized experience replay to adjust - /// the sampling probabilities based on TD errors. - /// - /// # Arguments - /// - /// * `ixs` - Optional indices of transitions to update - /// * `td_errs` - Optional TD errors for the transitions - fn update_priority(&mut self, ixs: &Option>, td_errs: &Option>) { - if let Some(per_state) = &mut self.per_state { - let ixs = ixs - .as_ref() - .expect("ixs should be Some(_) in update_priority()."); - let td_errs = td_errs - .as_ref() - .expect("td_errs should be Some(_) in update_priority()."); - for (&ix, &td_err) in ixs.iter().zip(td_errs.iter()) { - per_state.sum_tree.update(ix, td_err); - } - per_state.iw_scheduler.add_n_opts(); - } - } -} diff --git a/border-core/src/generic_replay_buffer/base/iw_scheduler.rs b/border-core/src/generic_replay_buffer/base/iw_scheduler.rs deleted file mode 100644 index 13c02ce0..00000000 --- a/border-core/src/generic_replay_buffer/base/iw_scheduler.rs +++ /dev/null @@ -1,46 +0,0 @@ -//! Scheduling the exponent of importance weight for PER. -use serde::{Deserialize, Serialize}; - -/// Scheduler of the exponent of importance weight for PER. -#[derive(Deserialize, Serialize, Clone, Debug, PartialEq)] -pub struct IwScheduler { - /// Initial value of $\beta$. - pub beta_0: f32, - - /// Final value of $\beta$. - pub beta_final: f32, - - /// Optimization steps when beta reaches its final value. - pub n_opts_final: usize, - - /// Current optimizatioin steps. - pub n_opts: usize, -} - -impl IwScheduler { - /// Creates a scheduler. - pub fn new(beta_0: f32, beta_final: f32, n_opts_final: usize) -> Self { - Self { - beta_0, - beta_final, - n_opts_final, - n_opts: 0, - } - } - - /// Gets the exponents of importance sampling weight. - pub fn beta(&self) -> f32 { - let n_opts = self.n_opts; - if n_opts >= self.n_opts_final { - self.beta_final - } else { - let d = self.beta_final - self.beta_0; - self.beta_0 + d * (n_opts as f32 / self.n_opts_final as f32) - } - } - - /// Add optimization steps for scheduling beta through training. - pub fn add_n_opts(&mut self) { - self.n_opts += 1; - } -} diff --git a/border-core/src/generic_replay_buffer/base/sum_tree.rs b/border-core/src/generic_replay_buffer/base/sum_tree.rs deleted file mode 100644 index 39afaf4f..00000000 --- a/border-core/src/generic_replay_buffer/base/sum_tree.rs +++ /dev/null @@ -1,217 +0,0 @@ -//! Sum tree for prioritized sampling. -//! -//! Code is adapted from and -/// -use segment_tree::{ - ops::{MaxIgnoreNaN, MinIgnoreNaN}, - SegmentPoint, -}; -use serde::{Deserialize, Serialize}; - -#[derive(Copy, Debug, Clone, Deserialize, Serialize, PartialEq)] -/// Specifies how to normalize the importance weights in a prioritized batch. -pub enum WeightNormalizer { - /// Normalize weights by the maximum weight of all samples in the buffer. - All, - /// Normalize weights by the maximum weight of samples in the batch. - Batch, -} - -#[derive(Debug)] -pub struct SumTree { - eps: f32, - alpha: f32, - capacity: usize, - n_samples: usize, - tree: Vec, - min_tree: SegmentPoint, - max_tree: SegmentPoint, - normalize: WeightNormalizer, -} - -impl SumTree { - pub fn new(capacity: usize, alpha: f32, normalize: WeightNormalizer) -> Self { - Self { - eps: 1e-8, - alpha, - capacity, - n_samples: 0, - tree: vec![0f32; 2 * capacity - 1], - min_tree: SegmentPoint::build(vec![f32::MAX; capacity], MinIgnoreNaN), - max_tree: SegmentPoint::build(vec![1e-8f32; capacity], MaxIgnoreNaN), - normalize, - } - } - - fn propagate(&mut self, ix: usize, change: f32) { - let parent = (ix - 1) / 2; - self.tree[parent] += change; - if parent != 0 { - self.propagate(parent, change); - } - } - - fn retrieve(&self, ix: usize, s: f32) -> usize { - let left = 2 * ix + 1; - let right = left + 1; - - if left >= self.tree.len() { - return ix; - } - - if s <= self.tree[left] || self.tree[right] == 0f32 { - return self.retrieve(left, s); - } else { - return self.retrieve(right, s - self.tree[left]); - } - } - - pub fn total(&self) -> f32 { - return self.tree[0]; - } - - pub fn max(&self) -> f32 { - self.max_tree - .query(0, self.max_tree.len()) - .powf(1.0 / self.alpha) - } - - /// Add priority value at `ix`-th element in the sum tree. - /// - /// The alpha-th power of the priority value is taken when addition. - pub fn add(&mut self, ix: usize, p: f32) { - debug_assert!(ix <= self.n_samples); - - self.update(ix, p); - - if self.n_samples < self.capacity { - self.n_samples += 1; - } - } - - /// Update priority value at `ix`-th element in the sum tree. - pub fn update(&mut self, ix: usize, p: f32) { - debug_assert!(ix < self.capacity); - - let p = (p + self.eps).powf(self.alpha); - self.min_tree.modify(ix, p); - self.max_tree.modify(ix, p); - let ix = ix + self.capacity - 1; - let change = p - self.tree[ix]; - if change.is_nan() { - println!("{:?}, {:?}", p, self.tree[ix]); - panic!(); - } - self.tree[ix] = p; - self.propagate(ix, change); - } - - /// Get the maximal index of the sum tree where the sum of priority values is less than `s`. - pub fn get(&self, s: f32) -> usize { - let ix = self.retrieve(0, s); - debug_assert!(ix >= (self.capacity - 1)); - ix + 1 - self.capacity - } - - /// Samples indices for batch and returns normalized weights. - /// - /// The weight is $w_i=\left(N^{-1}P(i)^{-1}\right)^{\beta}$ - /// and it will be normalized by $max_i w_i$. - pub fn sample(&self, batch_size: usize, beta: f32) -> (Vec, Vec) { - let p_sum = &self.total(); - let ps = (0..batch_size) - .map(|_| p_sum * fastrand::f32()) - .collect::>(); - let indices = ps.iter().map(|&p| self.get(p)).collect::>(); - // let indices = (0..batch_size) - // .map(|_| self.get(p_sum * fastrand::f32())) - // .collect::>(); - - let n = self.n_samples as f32 / p_sum; - let ws = indices - .iter() - .map(|ix| self.tree[ix + self.capacity - 1]) - .map(|p| (n * p).powf(-beta)) - .collect::>(); - - // normalizer within all samples - let w_max_inv = match self.normalize { - WeightNormalizer::All => (n * self.min_tree.query(0, self.n_samples)).powf(beta), - WeightNormalizer::Batch => 1f32 / ws.iter().fold(0.0 / 0.0, |m, v| v.max(m)), - }; - let ws = ws.iter().map(|w| w * w_max_inv).collect::>(); - - if p_sum.is_nan() || w_max_inv.is_nan() || ws.iter().sum::().is_nan() { - println!("self.n_samples: {:?}", self.n_samples); - println!("p_sum: {:?}", p_sum); - println!("w_max_inv: {:?}", w_max_inv); - println!("ps: {:?}", ps); - println!("indices: {:?}", indices); - println!("{:?}", ws); - panic!(); - } - - let ixs = indices.iter().map(|&ix| ix as i64).collect(); - - (ixs, ws) - } - - #[allow(dead_code)] - pub fn print_tree(&self) { - let mut nl = 1; - - for i in 0..self.tree.len() { - print!("{} ", self.tree[i]); - if i == 2 * nl - 2 { - println!(); - nl *= 2; - } - } - println!("max = {}", self.max()); - // println!("min = {}", self.min()); - println!("total = {}", self.total()); - } -} - -#[cfg(test)] -mod tests { - use super::{SumTree, WeightNormalizer::Batch}; - - #[test] - fn test_sum_tree_odd() { - let data = vec![0.5f32, 0.2, 0.8, 0.3, 1.1, 2.5, 3.9]; - let mut sum_tree = SumTree::new(8, 1.0, Batch); - for ix in 0..data.len() { - sum_tree.add(ix, data[ix]); - } - sum_tree.print_tree(); - println!(); - - assert_eq!(sum_tree.get(0.0), 0); - assert_eq!(sum_tree.get(0.4), 0); - assert_eq!(sum_tree.get(0.5), 0); - assert_eq!(sum_tree.get(0.6), 1); - assert_eq!(sum_tree.get(1.2), 2); - assert_eq!(sum_tree.get(1.6), 3); - assert_eq!(sum_tree.get(2.0), 4); - assert_eq!(sum_tree.get(2.8), 4); - - sum_tree.update(7, 2.0); - sum_tree.print_tree(); - println!(); - - // let (ixs, ws) = sum_tree.sample(10, 1.0); - // println!("{:?}", ixs); - // println!("{:?}", ws); - // println!(); - - // let n_samples = 1000000; - // let (ixs, _) = sum_tree.sample(n_samples, 1.0); - // debug_assert!(ixs.iter().all(|&ix| ix < data.len() as i64)); - // (0..5).for_each(|ix| { - // let p = data[ix] / sum_tree.total() * (n_samples as f32); - // let n = ixs.iter().filter(|&&e| e == ix as i64).collect::>().len(); - // println!("ix={:?}: {:?} (p={:?})", ix, n, p); - // }) - } -} diff --git a/border-core/src/generic_replay_buffer/batch.rs b/border-core/src/generic_replay_buffer/batch.rs deleted file mode 100644 index 240bed10..00000000 --- a/border-core/src/generic_replay_buffer/batch.rs +++ /dev/null @@ -1,206 +0,0 @@ -//! Generic implementation of transition batches for reinforcement learning. -//! -//! This module provides a generic implementation of transition batches that can handle -//! arbitrary observation and action types. It supports the following features: -//! - Efficient batch processing -//! - Weighting for prioritized experience replay -//! - Transition sampling and management - -use crate::TransitionBatch; - -/// A trait defining basic batch operations. -/// -/// This trait provides fundamental operations for efficiently managing batches of -/// observations and actions. -/// -/// # Type Parameters -/// -/// * `Self` - The batch type, representing batches of observations or actions. -/// -/// # Examples -/// -/// ```ignore -/// struct TensorBatch { -/// data: Vec, -/// shape: Vec, -/// } -/// -/// impl BatchBase for TensorBatch { -/// fn new(capacity: usize) -> Self { -/// Self { -/// data: Vec::with_capacity(capacity), -/// shape: vec![], -/// } -/// } -/// -/// fn push(&mut self, ix: usize, data: Self) { -/// // Data addition logic -/// } -/// -/// fn sample(&self, ixs: &Vec) -> Self { -/// // Sampling logic -/// } -/// } -/// ``` -pub trait BatchBase { - /// Creates a new batch with the specified capacity. - /// - /// # Arguments - /// - /// * `capacity` - Initial capacity of the batch - fn new(capacity: usize) -> Self; - - /// Adds data at the specified index. - /// - /// # Arguments - /// - /// * `ix` - Index where data should be added - /// * `data` - Data to be added - fn push(&mut self, ix: usize, data: Self); - - /// Retrieves samples from the specified indices. - /// - /// # Arguments - /// - /// * `ixs` - List of indices to sample from - /// - /// # Returns - /// - /// A new batch containing the sampled data - fn sample(&self, ixs: &Vec) -> Self; -} - -/// A generic structure representing transitions in reinforcement learning. -/// -/// This structure efficiently manages reinforcement learning transitions -/// (observations, actions, rewards, etc.). It also includes support for -/// prioritized experience replay (PER). -/// -/// # Type Parameters -/// -/// * `O` - Observation type, must implement `BatchBase` -/// * `A` - Action type, must implement `BatchBase` -/// -/// # Examples -/// -/// ```ignore -/// let batch = GenericTransitionBatch::::with_capacity(32); -/// ``` -pub struct GenericTransitionBatch -where - O: BatchBase, - A: BatchBase, -{ - /// Current observations - pub obs: O, - - /// Selected actions - pub act: A, - - /// Next state observations - pub next_obs: O, - - /// Transition rewards - pub reward: Vec, - - /// Episode termination flags - pub is_terminated: Vec, - - /// Episode truncation flags - pub is_truncated: Vec, - - /// Weights for prioritized experience replay - pub weight: Option>, - - /// Indices of sampled transitions - pub ix_sample: Option>, -} - -impl TransitionBatch for GenericTransitionBatch -where - O: BatchBase, - A: BatchBase, -{ - type ObsBatch = O; - type ActBatch = A; - - /// Decomposes the batch into its individual components. - /// - /// # Returns - /// - /// A tuple containing the following elements: - /// 1. Observations - /// 2. Actions - /// 3. Next observations - /// 4. Rewards - /// 5. Termination flags - /// 6. Truncation flags - /// 7. Sample indices - /// 8. Weights - fn unpack( - self, - ) -> ( - Self::ObsBatch, - Self::ActBatch, - Self::ObsBatch, - Vec, - Vec, - Vec, - Option>, - Option>, - ) { - ( - self.obs, - self.act, - self.next_obs, - self.reward, - self.is_terminated, - self.is_truncated, - self.ix_sample, - self.weight, - ) - } - - /// Returns the number of transitions in the batch. - fn len(&self) -> usize { - self.reward.len() - } - - /// Returns a reference to the batch of observations. - fn obs(&self) -> &Self::ObsBatch { - &self.obs - } - - /// Returns a reference to the batch of actions. - fn act(&self) -> &Self::ActBatch { - &self.act - } -} - -impl GenericTransitionBatch -where - O: BatchBase, - A: BatchBase, -{ - /// Creates a new batch with the specified capacity. - /// - /// # Arguments - /// - /// * `capacity` - Initial capacity of the batch - /// - /// # Returns - /// - /// A new `GenericTransitionBatch` instance - pub fn with_capacity(capacity: usize) -> Self { - Self { - obs: O::new(capacity), - act: A::new(capacity), - next_obs: O::new(capacity), - reward: Vec::with_capacity(capacity), - is_terminated: Vec::with_capacity(capacity), - is_truncated: Vec::with_capacity(capacity), - weight: None, - ix_sample: None, - } - } -} diff --git a/border-core/src/generic_replay_buffer/config.rs b/border-core/src/generic_replay_buffer/config.rs deleted file mode 100644 index e3c9b50d..00000000 --- a/border-core/src/generic_replay_buffer/config.rs +++ /dev/null @@ -1,294 +0,0 @@ -//! Configuration for the replay buffer implementation. -//! -//! This module provides configuration structures for the replay buffer, including: -//! - Basic buffer configuration (capacity, seed) -//! - Prioritized Experience Replay (PER) configuration -//! - Serialization and deserialization support - -use super::{WeightNormalizer, WeightNormalizer::All}; -use anyhow::Result; -use serde::{Deserialize, Serialize}; -use std::{ - default::Default, - fs::File, - io::{BufReader, Write}, - path::Path, -}; - -/// Configuration for Prioritized Experience Replay (PER). -/// -/// This structure defines the parameters for prioritized sampling in the replay buffer. -/// It controls how transitions are sampled based on their importance and how -/// importance weights are calculated and normalized. -/// -/// # Fields -/// -/// * `alpha` - Controls the degree of prioritization (0 = uniform sampling) -/// * `beta_0` - Initial value for importance sampling weights -/// * `beta_final` - Final value for importance sampling weights -/// * `n_opts_final` - Number of optimization steps to reach `beta_final` -/// * `normalize` - Method for normalizing importance weights -/// -/// # Examples -/// -/// ```rust -/// use border_core::generic_replay_buffer::{PerConfig, WeightNormalizer}; -/// -/// let config = PerConfig::default() -/// .alpha(0.6) -/// .beta_0(0.4) -/// .beta_final(1.0) -/// .n_opts_final(500_000) -/// .normalize(WeightNormalizer::All); -/// ``` -#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)] -pub struct PerConfig { - /// Exponent for prioritization. Higher values increase the bias towards - /// high-priority transitions. A value of 0 results in uniform sampling. - pub alpha: f32, - - /// Initial value of the importance sampling exponent. Lower values reduce - /// the impact of importance sampling weights. - pub beta_0: f32, - - /// Final value of the importance sampling exponent. Typically set to 1.0 - /// to fully compensate for the non-uniform sampling. - pub beta_final: f32, - - /// Number of optimization steps after which `beta` reaches its final value. - /// This allows for a gradual increase in the impact of importance sampling. - pub n_opts_final: usize, - - /// Method for normalizing importance sampling weights. Controls how the - /// weights are scaled to prevent numerical instability. - pub normalize: WeightNormalizer, -} - -impl Default for PerConfig { - /// Creates a default PER configuration with commonly used values: - /// - `alpha = 0.6` (moderate prioritization) - /// - `beta_0 = 0.4` (initial importance sampling) - /// - `beta_final = 1.0` (full compensation) - /// - `n_opts_final = 500_000` (gradual increase) - /// - `normalize = All` (normalize all weights) - fn default() -> Self { - Self { - alpha: 0.6, - beta_0: 0.4, - beta_final: 1.0, - n_opts_final: 500_000, - normalize: All, - } - } -} - -impl PerConfig { - /// Sets the prioritization exponent `alpha`. - /// - /// # Arguments - /// - /// * `alpha` - The new value for the prioritization exponent - /// - /// # Returns - /// - /// The modified configuration - pub fn alpha(mut self, alpha: f32) -> Self { - self.alpha = alpha; - self - } - - /// Sets the initial importance sampling exponent `beta_0`. - /// - /// # Arguments - /// - /// * `beta_0` - The new initial value for the importance sampling exponent - /// - /// # Returns - /// - /// The modified configuration - pub fn beta_0(mut self, beta_0: f32) -> Self { - self.beta_0 = beta_0; - self - } - - /// Sets the final importance sampling exponent `beta_final`. - /// - /// # Arguments - /// - /// * `beta_final` - The new final value for the importance sampling exponent - /// - /// # Returns - /// - /// The modified configuration - pub fn beta_final(mut self, beta_final: f32) -> Self { - self.beta_final = beta_final; - self - } - - /// Sets the number of optimization steps to reach the final beta value. - /// - /// # Arguments - /// - /// * `n_opts_final` - The new number of optimization steps - /// - /// # Returns - /// - /// The modified configuration - pub fn n_opts_final(mut self, n_opts_final: usize) -> Self { - self.n_opts_final = n_opts_final; - self - } - - /// Sets the method for normalizing importance weights. - /// - /// # Arguments - /// - /// * `normalize` - The new normalization method - /// - /// # Returns - /// - /// The modified configuration - pub fn normalize(mut self, normalize: WeightNormalizer) -> Self { - self.normalize = normalize; - self - } -} - -/// Configuration for the replay buffer. -/// -/// This structure defines the basic parameters for the replay buffer, -/// including its capacity, random seed, and optional PER configuration. -/// -/// # Fields -/// -/// * `capacity` - Maximum number of transitions to store -/// * `seed` - Random seed for sampling -/// * `per_config` - Optional configuration for prioritized experience replay -/// -/// # Examples -/// -/// ```rust -/// use border_core::generic_replay_buffer::{SimpleReplayBufferConfig, PerConfig}; -/// -/// // Basic configuration -/// let config = SimpleReplayBufferConfig::default() -/// .capacity(10000) -/// .seed(42); -/// -/// // Configuration with PER -/// let config_with_per = SimpleReplayBufferConfig::default() -/// .capacity(10000) -/// .seed(42) -/// .per_config(Some(PerConfig::default())); -/// ``` -#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)] -pub struct SimpleReplayBufferConfig { - /// Maximum number of transitions that can be stored in the buffer. - /// When the buffer is full, new transitions replace the oldest ones. - pub capacity: usize, - - /// Random seed used for sampling transitions. This ensures reproducibility - /// of the sampling process when the same seed is used. - pub seed: u64, - - /// Optional configuration for prioritized experience replay. If `None`, - /// transitions are sampled uniformly at random. - pub per_config: Option, -} - -impl Default for SimpleReplayBufferConfig { - /// Creates a default replay buffer configuration with commonly used values: - /// - `capacity = 10000` (moderate buffer size) - /// - `seed = 42` (fixed random seed) - /// - `per_config = None` (uniform sampling) - fn default() -> Self { - Self { - capacity: 10000, - seed: 42, - per_config: None, - } - } -} - -impl SimpleReplayBufferConfig { - /// Sets the capacity of the replay buffer. - /// - /// # Arguments - /// - /// * `capacity` - The new capacity for the buffer - /// - /// # Returns - /// - /// The modified configuration - pub fn capacity(mut self, capacity: usize) -> Self { - self.capacity = capacity; - self - } - - /// Sets the random seed for sampling. - /// - /// # Arguments - /// - /// * `seed` - The new random seed - /// - /// # Returns - /// - /// The modified configuration - pub fn seed(mut self, seed: u64) -> Self { - self.seed = seed; - self - } - - /// Sets the configuration for prioritized experience replay. - /// - /// # Arguments - /// - /// * `per_config` - The new PER configuration - /// - /// # Returns - /// - /// The modified configuration - pub fn per_config(mut self, per_config: Option) -> Self { - self.per_config = per_config; - self - } - - /// Loads the configuration from a YAML file. - /// - /// # Arguments - /// - /// * `path` - Path to the configuration file - /// - /// # Returns - /// - /// The loaded configuration - /// - /// # Errors - /// - /// Returns an error if the file cannot be read or parsed - pub fn load(path: impl AsRef) -> Result { - let file = File::open(path)?; - let rdr = BufReader::new(file); - let b = serde_yaml::from_reader(rdr)?; - Ok(b) - } - - /// Saves the configuration to a YAML file. - /// - /// # Arguments - /// - /// * `path` - Path where the configuration should be saved - /// - /// # Returns - /// - /// `Ok(())` if the configuration was saved successfully - /// - /// # Errors - /// - /// Returns an error if the file cannot be written - pub fn save(&self, path: impl AsRef) -> Result<()> { - let mut file = File::create(path)?; - file.write_all(serde_yaml::to_string(&self)?.as_bytes())?; - Ok(()) - } -} diff --git a/border-core/src/generic_replay_buffer/step_proc.rs b/border-core/src/generic_replay_buffer/step_proc.rs deleted file mode 100644 index 48776333..00000000 --- a/border-core/src/generic_replay_buffer/step_proc.rs +++ /dev/null @@ -1,138 +0,0 @@ -//! Generic implementation of step processing for reinforcement learning. -//! -//! This module provides a generic implementation of the `StepProcessor` trait, -//! which handles the conversion of environment steps into transitions suitable -//! for training. It supports: -//! - 1-step TD backup for non-vectorized environments -//! - Generic observation and action types -//! - Efficient batch processing - -use super::{BatchBase, GenericTransitionBatch}; -use crate::{Env, Obs, StepProcessor}; -use std::{default::Default, marker::PhantomData}; - -/// Configuration for the simple step processor. -#[derive(Clone, Debug)] -pub struct SimpleStepProcessorConfig {} - -impl Default for SimpleStepProcessorConfig { - /// Creates a new default configuration. - fn default() -> Self { - Self {} - } -} - -/// A generic implementation of the `StepProcessor` trait. -/// -/// This processor converts environment steps into transitions suitable for -/// training reinforcement learning agents. It supports 1-step TD backup -/// for non-vectorized environments, meaning that each step contains exactly -/// one observation. -/// -/// # Type Parameters -/// -/// * `E` - The environment type, must implement `Env` -/// * `O` - The observation batch type, must implement `BatchBase` and `From` -/// * `A` - The action batch type, must implement `BatchBase` and `From` -pub struct SimpleStepProcessor { - /// The previous observation, used to construct transitions. - prev_obs: Option, - /// Phantom data to hold the generic type parameters. - phantom: PhantomData<(E, A)>, -} - -impl StepProcessor for SimpleStepProcessor -where - E: Env, - O: BatchBase + From, - A: BatchBase + From, -{ - type Config = SimpleStepProcessorConfig; - type Output = GenericTransitionBatch; - - /// Creates a new step processor with the given configuration. - /// - /// # Arguments - /// - /// * `_config` - The configuration for the processor - /// - /// # Returns - /// - /// A new instance of the step processor - fn build(_config: &Self::Config) -> Self { - Self { - prev_obs: None, - phantom: PhantomData, - } - } - - /// Resets the processor with an initial observation. - /// - /// This method must be called before processing any steps to initialize - /// the processor with the starting state of the environment. - /// - /// # Arguments - /// - /// * `init_obs` - The initial observation from the environment - fn reset(&mut self, init_obs: E::Obs) { - self.prev_obs = Some(init_obs.into()); - } - - /// Processes a step from the environment into a transition. - /// - /// This method converts an environment step into a transition suitable - /// for training. It handles: - /// - Converting observations and actions to the appropriate batch types - /// - Managing the previous observation for constructing transitions - /// - Handling episode termination and truncation - /// - /// # Arguments - /// - /// * `step` - The step to process - /// - /// # Returns - /// - /// A transition batch containing the processed step - /// - /// # Panics - /// - /// This method will panic if: - /// - The step contains more than one observation - /// - `reset()` has not been called before processing steps - /// - The step is terminal but does not contain an initial observation - fn process(&mut self, step: crate::Step) -> Self::Output { - assert_eq!(step.obs.len(), 1); - - let batch = if self.prev_obs.is_none() { - panic!("prev_obs is not set. Forgot to call reset()?"); - } else { - let is_done = step.is_done(); - let next_obs = step.obs.clone().into(); - let obs = self.prev_obs.replace(step.obs.into()).unwrap(); - let act = step.act.into(); - let reward = step.reward; - let is_terminated = step.is_terminated; - let is_truncated = step.is_truncated; - let ix_sample = None; - let weight = None; - - if is_done { - self.prev_obs - .replace(step.init_obs.expect("Failed to unwrap init_obs").into()); - } - - GenericTransitionBatch { - obs, - act, - next_obs, - reward, - is_terminated, - is_truncated, - ix_sample, - weight, - } - }; - - batch - } -} From f622b1c116739f5652616f140f56e835bbfc2bae Mon Sep 17 00:00:00 2001 From: taku-y Date: Sat, 16 Aug 2025 15:01:48 +0000 Subject: [PATCH 20/23] Add TensorBatch in border_generic_replay_buffer crate Tweak --- border-generic-replay-buffer/Cargo.toml | 10 +- border-generic-replay-buffer/src/candle.rs | 119 ++++++++++++++++++ .../src/candle/tensor_batch.rs | 0 border-generic-replay-buffer/src/lib.rs | 2 + 4 files changed, 125 insertions(+), 6 deletions(-) create mode 100644 border-generic-replay-buffer/src/candle.rs create mode 100644 border-generic-replay-buffer/src/candle/tensor_batch.rs diff --git a/border-generic-replay-buffer/Cargo.toml b/border-generic-replay-buffer/Cargo.toml index fc828362..e28573fb 100644 --- a/border-generic-replay-buffer/Cargo.toml +++ b/border-generic-replay-buffer/Cargo.toml @@ -11,18 +11,16 @@ readme = "README.md" [dependencies] border-core = { version = "0.0.9", path = "../border-core" } +candle-core = { workspace = true, optional = true } serde = { workspace = true, features = ["derive"] } serde_yaml = { workspace = true } -# log = { workspace = true } -# thiserror = { workspace = true } anyhow = { workspace = true } -# chrono = { workspace = true } -# aquamarine = { workspace = true } fastrand = { workspace = true } segment-tree = { workspace = true } -# xxhash-rust = { workspace = true } -# Consider to replace with fastrand rand = { workspace = true } [dev-dependencies] tempdir = { workspace = true } + +[features] +candle = ["candle-core"] diff --git a/border-generic-replay-buffer/src/candle.rs b/border-generic-replay-buffer/src/candle.rs new file mode 100644 index 00000000..4146a0ac --- /dev/null +++ b/border-generic-replay-buffer/src/candle.rs @@ -0,0 +1,119 @@ +use crate::BatchBase; +use candle_core::{error::Result, DType, Device, Tensor}; + +/// Adds capability of constructing [`Tensor`] with a static method. +/// +/// [`Tensor`]: https://docs.rs/candle-core/0.4.1/candle_core/struct.Tensor.html +pub trait ZeroTensor { + /// Constructs zero tensor. + fn zeros(shape: &[usize]) -> Result; +} + +impl ZeroTensor for u8 { + fn zeros(shape: &[usize]) -> Result { + Tensor::zeros(shape, DType::U8, &Device::Cpu) + } +} + +impl ZeroTensor for f32 { + fn zeros(shape: &[usize]) -> Result { + Tensor::zeros(shape, DType::F32, &Device::Cpu) + } +} + +impl ZeroTensor for i64 { + fn zeros(shape: &[usize]) -> Result { + Tensor::zeros(shape, DType::I64, &Device::Cpu) + } +} + +/// A buffer consisting of a [`Tensor`]. +/// +/// The internal buffer is `Vec`. +/// +/// [`Tensor`]: https://docs.rs/candle-core/0.4.1/candle_core/struct.Tensor.html +#[derive(Clone, Debug)] +pub struct TensorBatch { + pub buf: Vec, + pub capacity: usize, +} + +impl TensorBatch { + pub fn from_tensor(t: Tensor) -> Self { + let capacity = t.dims()[0] as _; + assert_eq!(capacity, 1); + Self { + buf: vec![t], + capacity, + } + } +} + +impl BatchBase for TensorBatch { + fn new(capacity: usize) -> Self { + Self { + buf: Vec::with_capacity(capacity), + capacity: capacity, + } + } + + /// Pushes given data. + /// + /// if ix + data.buf.len() exceeds the self.capacity, + /// the tail samples in data is placed in the head of the buffer of self. + fn push(&mut self, ix: usize, data: Self) { + if self.buf.len() == self.capacity { + for (i, sample) in data.buf.into_iter().enumerate() { + let ix_ = (ix + i) % self.capacity; + self.buf[ix_] = sample; + } + } else if self.buf.len() < self.capacity { + for (i, sample) in data.buf.into_iter().enumerate() { + if self.buf.len() < self.capacity { + self.buf.push(sample); + } else { + let ix_ = (ix + i) % self.capacity; + self.buf[ix_] = sample; + } + } + } else { + panic!("The length of the buffer is SubBatch is larger than its capacity."); + } + } + + fn sample(&self, ixs: &Vec) -> Self { + let buf = ixs.iter().map(|&ix| self.buf[ix].clone()).collect(); + Self { + buf, + capacity: ixs.len(), + } + } +} + +impl From for Tensor { + fn from(b: TensorBatch) -> Self { + Tensor::cat(&b.buf[..], 0).unwrap() + } +} + +impl From for TensorBatch { + fn from(t: Tensor) -> TensorBatch { + if t.dims()[0] == 1 { + TensorBatch { + buf: vec![t], + capacity: 1, + } + } else { + let buf = (0..t.dims()[0]) + .map(|ix| { + let ix = Tensor::from_vec(vec![ix as i64], &[1], &Device::Cpu).unwrap(); + t.index_select(&ix, 0).unwrap() + }) + .collect(); + TensorBatch { + buf, + capacity: t.dims()[0], + } + } + } +} diff --git a/border-generic-replay-buffer/src/candle/tensor_batch.rs b/border-generic-replay-buffer/src/candle/tensor_batch.rs new file mode 100644 index 00000000..e69de29b diff --git a/border-generic-replay-buffer/src/lib.rs b/border-generic-replay-buffer/src/lib.rs index 9dba02f2..93083828 100644 --- a/border-generic-replay-buffer/src/lib.rs +++ b/border-generic-replay-buffer/src/lib.rs @@ -1,5 +1,7 @@ #![doc = include_str!("../README.md")] mod batch; +#[cfg(feature = "candle")] +pub mod candle; mod config; mod iw_scheduler; mod replay_buffer; From 7711bf1ac8987e4428b458f71345fb8997e963d0 Mon Sep 17 00:00:00 2001 From: taku-y Date: Mon, 18 Aug 2025 14:55:13 +0000 Subject: [PATCH 21/23] Add MlpAgent --- border-policy-no-backend/src/lib.rs | 2 + border-policy-no-backend/src/mlp.rs | 2 + border-policy-no-backend/src/mlp_agent.rs | 59 +++++++++++++++++++++++ 3 files changed, 63 insertions(+) create mode 100644 border-policy-no-backend/src/mlp_agent.rs diff --git a/border-policy-no-backend/src/lib.rs b/border-policy-no-backend/src/lib.rs index a4a49260..6f2db1a6 100644 --- a/border-policy-no-backend/src/lib.rs +++ b/border-policy-no-backend/src/lib.rs @@ -1,6 +1,8 @@ #![doc = include_str!("../README.md")] mod mat; mod mlp; +mod mlp_agent; pub use mat::Mat; pub use mlp::Mlp; +pub use mlp_agent::MlpAgent; diff --git a/border-policy-no-backend/src/mlp.rs b/border-policy-no-backend/src/mlp.rs index 04d850c8..6d4afdcf 100644 --- a/border-policy-no-backend/src/mlp.rs +++ b/border-policy-no-backend/src/mlp.rs @@ -8,6 +8,8 @@ mod tch; #[derive(Clone, Debug, Deserialize, Serialize)] /// Multilayer perceptron with ReLU activation function. +/// +/// The tanh() function is applied to the output layer to constrain the output values to the range [-1, 1]. pub struct Mlp { /// Weights of layers. ws: Vec, diff --git a/border-policy-no-backend/src/mlp_agent.rs b/border-policy-no-backend/src/mlp_agent.rs new file mode 100644 index 00000000..0a8da688 --- /dev/null +++ b/border-policy-no-backend/src/mlp_agent.rs @@ -0,0 +1,59 @@ +use border_core::{Agent, Env, NullReplayBuffer, Policy}; +use serde::{Deserialize, Serialize}; + +use crate::{Mat, Mlp}; + +/// MLP-based agent for reinforcement learning. +/// +/// This agent uses a multilayer perceptron (MLP) as its policy network. +/// The MLP outputs actions in the range [-1, 1] due to the tanh activation +/// function applied to the output layer. +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct MlpAgent { + mlp: Mlp, +} + +impl Policy for MlpAgent +where + E: Env, + E::Obs: Into, + E::Act: From, +{ + fn sample(&mut self, obs: &E::Obs) -> E::Act { + let obs_mat: Mat = obs.clone().into(); + let act_mat = self.mlp.forward(&obs_mat); + act_mat.into() + } +} + +/// This Agent trait implementation does nothing, but is required when passing +/// this policy to border_core::Evaluator for evaluation. The Evaluator accepts +/// trait objects that implement the Agent trait, and trait objects cannot be +/// upcast to Policy trait objects, so the Agent trait object is used instead. +impl Agent for MlpAgent +where + E: Env, + E::Obs: Into, + E::Act: From, +{ +} + +impl MlpAgent { + /// Creates a new MlpAgent with the given MLP. + /// + /// # Arguments + /// + /// * `mlp` - The MLP network to use as the policy + /// + /// # Returns + /// + /// A new MlpAgent instance + pub fn new(mlp: Mlp) -> Self { + Self { mlp } + } + + /// Returns a reference to the underlying MLP. + pub fn mlp(&self) -> &Mlp { + &self.mlp + } +} From 4d5b81d9ef9a4a11b733701eef1a25b6fd4dcb81 Mon Sep 17 00:00:00 2001 From: taku-y Date: Tue, 19 Aug 2025 01:28:39 +0000 Subject: [PATCH 22/23] Add method --- border-candle-agent/src/bc/base.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/border-candle-agent/src/bc/base.rs b/border-candle-agent/src/bc/base.rs index 3060c8ce..07ce4a79 100644 --- a/border-candle-agent/src/bc/base.rs +++ b/border-candle-agent/src/bc/base.rs @@ -195,4 +195,8 @@ where ); record } + + pub fn get_policy_model(&self) -> &BcModel

{ + &self.policy_model + } } From bd6d72d39ff863a342e2d84e180c5d2532ae5a4b Mon Sep 17 00:00:00 2001 From: taku-y Date: Fri, 22 Aug 2025 14:57:37 +0000 Subject: [PATCH 23/23] Fix readme --- examples/gym/convert_policy/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/gym/convert_policy/README.md b/examples/gym/convert_policy/README.md index 6102f7a3..d5ebd723 100644 --- a/examples/gym/convert_policy/README.md +++ b/examples/gym/convert_policy/README.md @@ -6,5 +6,5 @@ which is readable using Rust's standard library. The below command loads model parameters from `../sac_pendulum_tch/model/best`, converts its format, then saves as `./model/mlp.bincode`. It will be used in `pendulum_std` example. ```bash -cargo run +cargo run --features=tch ```