No tutorial on finetuning procedure #1430
Asif-Iqbal-Bhatti started this conversation in General
Hello,
I am interested in fine-tuning the UMA model for my specific use case but have been unable to find a clear tutorial on the process. I understand that a Hydra YAML configuration file is required, but the documentation I have found so far is difficult to interpret.
My dataset is currently in .extxyz format, converted from VASP OUTCAR files using ASE; the stress tensor in my data is stored as a full 3×3 matrix. The underlying DFT calculations follow the MPRelaxSet settings to stay compatible with the Materials Project database, and I am using the raw VASP outputs rather than the post-processed Materials Project entries.
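For reference, this is roughly how I do the conversion and the train/validation split; the directory layout and file names below are just placeholders for my local setup:

```python
# Minimal sketch of the OUTCAR -> extxyz conversion (paths are placeholders).
# ase.io.read attaches energy, forces and stress from the OUTCAR via a
# SinglePointCalculator, and ase.io.write carries them into the extxyz file.
from pathlib import Path

from ase.io import read, write

frames = []
for outcar in sorted(Path("vasp_runs").glob("*/OUTCAR")):
    # index=":" reads every ionic step; use index=-1 for only the final structure
    frames.extend(read(outcar, index=":", format="vasp-out"))

# simple 90/10 split into training and validation files
split = int(0.9 * len(frames))
write("train.extxyz", frames[:split])
write("val.extxyz", frames[split:])

# quick check that the properties survived the round trip
atoms = read("train.extxyz", index=0)
print(atoms.get_potential_energy())
print(atoms.get_forces().shape)        # (n_atoms, 3)
print(atoms.get_stress(voigt=False))   # full 3x3 stress tensor
```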
Could you please provide:

1. An example YAML configuration file for fine-tuning? My current attempt is attached below:
```yaml
defaults:

job:
  device_type: CUDA
  scheduler:
    mode: LOCAL
    ranks_per_node: 1
    num_nodes: 1
  debug: true
  run_dir: uma_finetune_runs/
  run_name: uma_finetune
  logger:
    _target_: fairchem.core.common.logger.WandBSingletonLogger.init_wandb
    _partial_: true
    entity: example
    project: uma_finetune
    mode: online

base_model_name: uma-s-1p1
max_neighbors: 300
epochs: 100
steps: null
batch_size: 2
lr: 4e-4
weight_decay: 1e-3
evaluate_every_n_steps: 100
checkpoint_every_n_steps: 1000

train_dataset:
  _target_: fairchem.core.datasets.mt_concat_dataset.create_concat_dataset
  dataset_configs:
    DATASET_NAME: ${data.train_dataset}
  combined_dataset_config:
    sampling:
      type: temperature
      temperature: 1.0

train_dataloader:
  _target_: fairchem.core.components.common.dataloader_builder.get_dataloader
  dataset: ${train_dataset}
  batch_sampler_fn:
    _target_: fairchem.core.common.data_parallel.BalancedBatchSampler
    _partial_: true
    batch_size: ${batch_size}
    shuffle: true
    seed: 0
  num_workers: 0
  collate_fn:
    _target_: fairchem.core.units.mlip_unit.mlip_unit.mt_collater_adapter
    tasks: ${data.tasks_list}

val_dataset:
  _target_: fairchem.core.datasets.mt_concat_dataset.create_concat_dataset
  dataset_configs:
    DATASET_NAME: ${data.val_dataset}
  combined_dataset_config:
    sampling:
      type: temperature
      temperature: 1.0

eval_dataloader:
  _target_: fairchem.core.components.common.dataloader_builder.get_dataloader
  dataset: ${val_dataset}
  batch_sampler_fn:
    _target_: fairchem.core.common.data_parallel.BalancedBatchSampler
    _partial_: true
    batch_size: ${batch_size}
    shuffle: false
    seed: 0
  num_workers: 0
  collate_fn:
    _target_: fairchem.core.units.mlip_unit.mlip_unit.mt_collater_adapter
    tasks: ${data.tasks_list}

runner:
  _target_: fairchem.core.components.train.train_runner.TrainEvalRunner
  train_dataloader: ${train_dataloader}
  eval_dataloader: ${eval_dataloader}
  train_eval_unit:
    _target_: fairchem.core.units.mlip_unit.mlip_unit.MLIPTrainEvalUnit
    job_config: ${job}
    tasks: ${data.tasks_list}
    model:
      _target_: fairchem.core.units.mlip_unit.mlip_unit.initialize_finetuning_model
      checkpoint_location:
        _target_: fairchem.core.calculate.pretrained_mlip.pretrained_checkpoint_path_from_name
        model_name: ${base_model_name}
      overrides:
        backbone:
          otf_graph: true
          max_neighbors: ${max_neighbors}
          regress_stress: ${data.regress_stress}
          always_use_pbc: false
        pass_through_head_outputs: ${data.pass_through_head_outputs}
      heads: ${data.heads}
    optimizer_fn:
      _target_: torch.optim.AdamW
      _partial_: true
      lr: ${lr}
      weight_decay: ${weight_decay}
    cosine_lr_scheduler_fn:
      _target_: fairchem.core.units.mlip_unit.mlip_unit._get_consine_lr_scheduler
      _partial_: true
      warmup_factor: 0.2
      warmup_epochs: 0.01
      lr_min_factor: 0.01
      epochs: ${epochs}
      steps: ${steps}
    print_every: 10
    clip_grad_norm: 100
  max_epochs: ${epochs}
  max_steps: ${steps}
  evaluate_every_n_steps: ${evaluate_every_n_steps}
  callbacks:
    checkpoint_every_n_steps: ${checkpoint_every_n_steps}
    max_saved_checkpoints: 5
```
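Before launching, I sanity-check that this file at least parses and that I have not broken the key structure. The ${data.*} interpolations come from a separate data config group, so I leave them unresolved; the file name below is just what I call it locally:

```python
# Quick structural check of the fine-tuning config (no Hydra launch, just parsing).
from omegaconf import OmegaConf

cfg = OmegaConf.load("uma_finetune.yaml")  # placeholder local file name

# top-level keys the trainer is expected to find
print(list(cfg.keys()))

# spot-check a few scalar hyperparameters
print(cfg.base_model_name, cfg.lr, cfg.batch_size, cfg.epochs)

# dump the tree back out without resolving the ${...} interpolations
print(OmegaConf.to_yaml(cfg, resolve=False))
```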
2. Your opinion on whether it is advisable to fine-tune the uma-s-1p1 model on my MP-style data, or whether differences between my dataset and the original training data would make this approach less effective.
Also, during training I cannot see any training or validation loss curves in W&B. Where are the validation metrics logged or stored?
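For reference, this is how I have been trying to pull the metrics back out with the W&B public API; the run path below is a placeholder, while the entity and project match the logger block of my config:

```python
# Pull logged metrics for a run via the W&B public API.
# "example/uma_finetune/<run_id>" is a placeholder run path.
import wandb

api = wandb.Api()
run = api.run("example/uma_finetune/<run_id>")

# history() returns a pandas DataFrame of everything the run logged
history = run.history()
print(history.columns.tolist())

# any column containing "val" or "loss" should hold the curves I am after
loss_cols = [c for c in history.columns if "loss" in c.lower() or "val" in c.lower()]
print(history[loss_cols].dropna(how="all").tail())
```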