Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
ValueHistories = "98cad3c8-aec3-5f06-8e41-884608649ab7"

[compat]
DecisionFocusedLearningBenchmarks = "0.5.0, 0.6"
DecisionFocusedLearningBenchmarks = "0.6.1"
DocStringExtensions = "0.9.5"
Flux = "0.16.9"
InferOpt = "0.7.1"
Expand Down
3 changes: 2 additions & 1 deletion src/DecisionFocusedLearningAlgorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ include("algorithms/abstract_algorithm.jl")
include("algorithms/supervised/fyl.jl")
include("algorithms/supervised/anticipative_imitation.jl")
include("algorithms/supervised/dagger.jl")
include("algorithms/mirror_descent/mirror_descent.jl")

export TrainingContext

Expand All @@ -41,7 +42,7 @@ export AbstractMetric,

export AbstractAlgorithm, AbstractImitationAlgorithm
export PerturbedFenchelYoungLossImitation,
DAgger, AnticipativeImitation, train_policy!, train_policy
DAgger, AnticipativeImitation, train_policy!, train_policy, MirrorDescent
export AbstractPolicy, DFLPolicy

end
239 changes: 239 additions & 0 deletions src/algorithms/mirror_descent/mirror_descent.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,239 @@
"""
$TYPEDEF

Mirror Descent algorithm for learning coordinated solutions.

This algorithm is designed for stochastic benchmarks.

Reference: <https://arxiv.org/abs/2505.04757>

# Fields
$TYPEDFIELDS
"""
@kwdef struct MirrorDescent{A<:PerturbedFenchelYoungLossImitation} <: AbstractAlgorithm
"inner imitation algorithm for supervised learning"
inner_algorithm::A = PerturbedFenchelYoungLossImitation()
end

# Helper function to augment a dataset with anticipative solutions
function _augment_with_anticipative(dataset, anticipative_solver)
return map(dataset) do sample
y = anticipative_solver(sample.scenario; sample.context...)
return DataSample(sample; y=y)
end
end

# Helper function to create a perturbed sample
function _perturbed_sample(sample, model, perturbed_solver, is_minimization, κ)
θ = model(sample.x)
signed_θ = is_minimization ? -κ * θ : κ * θ
y = perturbed_solver(signed_θ; scenario=sample.scenario, sample.context...)
return DataSample(sample; y=y)
end

# Helper function to augment a dataset with perturbed solutions
function _augment_with_perturbed(dataset, model, perturbed_solver, is_minimization; κ=1.0)
return map(dataset) do sample
return _perturbed_sample(sample, model, perturbed_solver, is_minimization, κ)
end
end

# Helper function to augment a dataset with perturbed solutions in-place
function _augment_with_perturbed!(dataset, model, perturbed_solver, is_minimization; κ=1.0)
for i in eachindex(dataset)
dataset[i] = _perturbed_sample(
dataset[i], model, perturbed_solver, is_minimization, κ
)
end
return dataset
end

# Helper function to run the mirror descent loop for a given number of iterations
function _mirror_descent_loop(
algorithm,
policy,
input_dataset,
perturbed_solver,
is_minimization;
md_iters,
epochs,
κ,
metrics,
verbose,
)
# Allocate the perturbed dataset once. Subsequent iterations mutate in place.
dataset = _augment_with_perturbed(
input_dataset, policy.statistical_model, perturbed_solver, is_minimization; κ
)
return map(1:md_iters) do n_it
verbose && println("Mirror descent iteration $n_it / $md_iters")
if n_it > 1
_augment_with_perturbed!(
dataset, policy.statistical_model, perturbed_solver, is_minimization; κ
)
end
return train_policy!(algorithm.inner_algorithm, policy, dataset; epochs, metrics)
end
end

"""
$TYPEDSIGNATURES

Train a DFLPolicy using the Mirror Descent algorithm on a provided training dataset.

When `imitation_start=true`, the first iteration is a pure imitation step using
`anticipative_solver`; subsequent iterations are the mirror descent loop using
`perturbed_anticipative_solver`.

# Arguments
- `iterations=10`: total number of mirror descent iterations (includes the imitation step
when `imitation_start=true`)
- `epochs=10`: number of inner training epochs per mirror descent iteration
- `κ=1.0`: scaling factor applied to `θ` before passing it to the perturbed solver
- `metrics::Tuple=()`: metrics forwarded to the inner training algorithm
- `verbose=false`: if true, prints progress at each iteration
- `imitation_start=true`: if true, run a pure imitation step against the
anticipative solver as the first iteration
- `is_minimization=true`: set to false if the objective is a maximization problem
"""
function train_policy!(
algorithm::MirrorDescent,
policy::DFLPolicy,
train_dataset,
anticipative_solver,
perturbed_anticipative_solver;
epochs=10,
iterations=10,
κ=1.0,
metrics::Tuple=(),
verbose::Bool=false,
imitation_start::Bool=true,
is_minimization::Bool=true,
)
if imitation_start
verbose && println("Imitation step")
dataset = _augment_with_anticipative(train_dataset, anticipative_solver)
h_imitation = train_policy!(
algorithm.inner_algorithm, policy, dataset; epochs, metrics
)
md_iters = iterations - 1
md_iters >= 1 || return [h_imitation]
rest = _mirror_descent_loop(
algorithm,
policy,
dataset,
perturbed_anticipative_solver,
is_minimization;
md_iters,
epochs,
κ,
metrics,
verbose,
)
return pushfirst!(rest, h_imitation)
end

# else
return _mirror_descent_loop(
algorithm,
policy,
train_dataset,
perturbed_anticipative_solver,
is_minimization;
md_iters=iterations,
epochs,
κ,
metrics,
verbose,
)
end

"""
$TYPEDSIGNATURES

Generate a dataset for the provided benchmark and train a DFLPolicy using the Mirror Descent algorithm.

This high-level wrapper builds every component (`model`, `maximizer`,
`anticipative_solver`, `parametric_anticipative_solver`, `train_dataset`) from the
benchmark, each exposed as an optional keyword so callers can override any of them
without dropping to [`train_policy!`](@ref).

# Arguments
- `dataset_size=30`: number of samples in the training dataset
(used when `train_dataset` is not provided)
- `nb_scenarios=1`: number of scenarios per instance
(used when `train_dataset` is not provided)
- `context_per_instance=1`: number of contexts per instance
(used when `train_dataset` is not provided)
- `seed=nothing`: random seed for reproducibility
(used in `model` and `train_dataset` when not provided)
- `model`: statistical model to wrap in the policy
(defaults to `generate_statistical_model(benchmark; seed)`)
- `maximizer`: combinatorial oracle to wrap in the policy
(defaults to `generate_maximizer(benchmark)`)
- `anticipative_solver`: oracle used in pure-imitation iterations
(defaults to `generate_anticipative_solver(benchmark)`)
- `parametric_anticipative_solver`: parametric oracle wrapped in `PerturbedAdditive` for
mirror-descent iterations (defaults to `generate_parametric_anticipative_solver(benchmark)`)
- `train_dataset`: training dataset (defaults to `generate_dataset(benchmark, dataset_size; ...)`)
- `epochs=10`: number of inner training epochs per mirror descent iteration
- `iterations=10`: total number of mirror descent iterations
- `κ=1.0`: scaling factor applied to `θ` before passing it to the perturbed solver
- `metrics::Tuple=()`: metrics forwarded to the inner training algorithm
- `verbose=false`: if true, prints a banner at each iteration
- `imitation_start=true`: if true, run a pure imitation step against the anticipative solver as the
first iteration
"""
function train_policy(
algorithm::MirrorDescent,
benchmark::ExogenousStochasticBenchmark;
dataset_size=30,
nb_scenarios=1,
context_per_instance=1,
seed=nothing,
model=generate_statistical_model(benchmark; seed=seed),
maximizer=generate_maximizer(benchmark),
anticipative_solver=generate_anticipative_solver(benchmark),
parametric_anticipative_solver=generate_parametric_anticipative_solver(benchmark),
train_dataset=generate_dataset(
benchmark,
dataset_size;
nb_scenarios=nb_scenarios,
contexts_per_instance=context_per_instance,
seed=seed,
),
epochs=10,
iterations=10,
κ=1.0,
metrics::Tuple=(),
verbose::Bool=false,
imitation_start::Bool=true,
)
policy = DFLPolicy(model, maximizer)

(; nb_samples, ε, threaded) = algorithm.inner_algorithm
perturbed_anticipative_solver = PerturbedAdditive(
(θ; scenario, kwargs...) -> parametric_anticipative_solver(θ, scenario; kwargs...);
ε=κ * ε,
nb_samples=nb_samples,
seed=seed,
threaded=threaded,
)

histories_per_iteration = train_policy!(
algorithm,
policy,
train_dataset,
anticipative_solver,
perturbed_anticipative_solver;
epochs=epochs,
iterations=iterations,
κ=κ,
metrics=metrics,
verbose=verbose,
imitation_start=imitation_start,
is_minimization=is_minimization_problem(benchmark),
)

return histories_per_iteration, policy
end
3 changes: 2 additions & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ InferOpt = "4846b161-c94e-4150-8dac-c7ae193c601f"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
ValueHistories = "98cad3c8-aec3-5f06-8e41-884608649ab7"

Expand All @@ -16,7 +17,7 @@ DecisionFocusedLearningAlgorithms = {path = ".."}
[compat]
Aqua = "0.8"
DecisionFocusedLearningAlgorithms = "0.2.0"
DecisionFocusedLearningBenchmarks = "0.5"
DecisionFocusedLearningBenchmarks = "0.6.1"
Documenter = "1"
JuliaFormatter = "2"
MLUtils = "0.4"
Expand Down
Loading