
[ENH] Add pretrain/fine-tune lifecycle to BaseModel v2#2220

Open
echo-xiao wants to merge 5 commits into sktime:main from echo-xiao:enh/basemodel-pretrain-hook

Conversation

@echo-xiao
Contributor

@echo-xiao echo-xiao commented Mar 20, 2026

Reference Issues/PRs

close #2105

What does this implement/fix? Explain your changes.

Implements the pretrain -> fine-tune -> predict lifecycle for BaseModel v2.

Changes:

_base_model_v2.py

  • pretrained_weights init param
  • fine_tune_strategy init param
  • pretrain(): pretrain on global/panel data
  • _pretrain(): hook for subclasses to override
  • _post_init_load_pretrained(): called by subclasses after the model is built
  • load_pretrained_weights(): load weights from a HuggingFace path
  • _freeze_backbone() / _unfreeze_backbone(): freeze utilities
  • docstring examples for the three modes (pretrain -> fine-tune, cold start, train from scratch)
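To make the intended control flow concrete, here is a minimal plain-Python sketch of the lifecycle (no torch/lightning; the class and method names mirror the PR, but all bodies are illustrative assumptions, not the actual implementation):

```python
class BaseModelSketch:
    """Illustrative stand-in for BaseModel v2 (not the real class)."""

    def __init__(self, pretrained_weights=None, fine_tune_strategy="freeze_backbone"):
        self.pretrained_weights = pretrained_weights
        self.fine_tune_strategy = fine_tune_strategy
        self.is_pretrained_ = False

    def pretrain(self, datamodule, trainer_kwargs=None):
        """Pretrain on global/panel data, set the flag, return self."""
        self._pretrain(datamodule, trainer_kwargs or {})
        self.is_pretrained_ = True
        return self

    def _pretrain(self, datamodule, trainer_kwargs):
        """Hook for subclasses to override; the base implementation is a no-op."""


class ToyModel(BaseModelSketch):
    def _pretrain(self, datamodule, trainer_kwargs):
        # subclass-specific pretraining work would go here
        self.seen_data = datamodule


model = ToyModel().pretrain("dm_large")
print(model.is_pretrained_)  # True
```

The key design point is that user code calls the public `pretrain()`, which owns the lifecycle state, while subclasses only implement `_pretrain()`.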

What should a reviewer concentrate their feedback on?

  1. design of the _pretrain() hook: the most important part
  2. fine_tune_strategy default value
  3. load_pretrained_weights(): is this pattern acceptable for all v2 subclasses?
  4. _post_init_load_pretrained() pattern: is this pattern acceptable for all v2 subclasses?
  5. whether all of these changes match "Pre-training, global learning, and fine-tuning API" enhancement-proposals#41

Did you add any tests for the change?

tests/test_models/test_basemodel_pretrain.py

  1. pretrain() sets and returns self
  2. _pretrain() subclass hook
  3. weights change after pretraining
  4. load_pretrained_weights() with different formats

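As a rough illustration of tests 1 and 3 (pretrain() returns self; weights change after pretraining), a hedged sketch using a plain-Python stub in place of a real network (all names hypothetical):

```python
class StubModel:
    """Hypothetical stub: 'weights' are a plain list of floats."""

    def __init__(self):
        self.weights = [0.0, 0.0]
        self.is_pretrained_ = False

    def pretrain(self, data):
        # stand-in for a real optimization step
        self.weights = [w + 1.0 for w in self.weights]
        self.is_pretrained_ = True
        return self


def test_weights_change_after_pretraining():
    model = StubModel()
    before = list(model.weights)
    result = model.pretrain(data=[1, 2, 3])
    assert result is model            # test 1: pretrain() returns self
    assert model.is_pretrained_       # pretrain() sets the flag
    assert model.weights != before    # test 3: weights changed


test_weights_change_after_pretraining()
```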
Any other comments?

NHiTS_v2 POC integration will be added as follow-up after #2186 is merged.

PR checklist

  • The PR title starts with either [ENH], [MNT], [DOC], or [BUG]. [BUG] - bugfix, [MNT] - CI, test framework, [ENH] - adding or improving code, [DOC] - writing or improving documentation or docstrings.
  • Added/modified tests
  • Used pre-commit hooks when committing to ensure that code is compliant with hooks. Install hooks with pre-commit install.

@echo-xiao echo-xiao changed the title [ENH] Add pretrain/fine-tune lifecycle to BaseModel v2 with NHiTS_v2 … [ENH] Add pretrain/fine-tune lifecycle to BaseModel v2 Mar 20, 2026
@echo-xiao echo-xiao force-pushed the enh/basemodel-pretrain-hook branch from af0c8e0 to e79e0b1 Compare March 20, 2026 23:13
@codecov

codecov bot commented Mar 21, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
⚠️ Please upload report for BASE (main@5600398). Learn more about missing BASE report.

Additional details and impacted files
@@           Coverage Diff           @@
##             main    #2220   +/-   ##
=======================================
  Coverage        ?   86.68%           
=======================================
  Files           ?      165           
  Lines           ?     9782           
  Branches        ?        0           
=======================================
  Hits            ?     8480           
  Misses          ?     1302           
  Partials        ?        0           
Flag     Coverage Δ
cpu      86.68% <100.00%> (?)
pytest   86.68% <100.00%> (?)

Flags with carried forward coverage won't be shown.


@echo-xiao
Contributor Author

echo-xiao commented Mar 26, 2026

Hi @phoeenniixx, the PR implements the BaseModel v2 core, and I'd love to get your feedback. Does this pretrain + on_fit_start() workflow look reasonable to you?

[diagram: Freeze Backbone Flow]

Member

@phoeenniixx phoeenniixx left a comment


Thanks! This looks good!
Sorry, I don't have a complete mental model yet of how pretraining should look... But I have added some doubts (some of which may feel redundant :))

  • How would inference mode look here? After freezing the backbone?
  • Also, if we load weights from hf, how should we fine-tune the model? I think there can be multiple strategies?
  • I think we should think about how to integrate this with the FMs; I am not sure the current BaseModel is a good place for this?
  • Should the pretrain etc methods be in Basepkg similar to the way fit, predict is?

FYI @agobbifbk, @PranavBhatP, @fkiraly what are your thoughts on this?

Also, please don't share the GSoC proposal here; I think you should follow the official channels and apply at the GSoC portal and the sktime form...

for param in self.parameters():
    param.requires_grad_(True)

def _post_init_load_pretrained(self) -> None:
Member


Where (and how) is this method used? I mean why not directly use load_pretrained_weights?

Contributor Author


When BaseModel.__init__ runs, self.model doesn't exist yet. _post_init_load_pretrained() is called after the model is fully built.
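A minimal sketch of that ordering constraint (plain Python; the dict-based "model" and attribute names are hypothetical stand-ins for a real network):

```python
class Base:
    def __init__(self, pretrained_weights=None):
        self.pretrained_weights = pretrained_weights
        # NOTE: self.model does not exist yet at this point

    def _post_init_load_pretrained(self):
        # safe to touch self.model: subclasses call this after building it
        if self.pretrained_weights is not None:
            self.model["loaded_from"] = self.pretrained_weights


class Subclass(Base):
    def __init__(self, pretrained_weights=None):
        super().__init__(pretrained_weights)
        self.model = {"layers": 3}          # network is built here
        self._post_init_load_pretrained()   # ...so loading is now safe


m = Subclass(pretrained_weights="weights.pt")
print(m.model["loaded_from"])  # weights.pt
```

Calling load_pretrained_weights() directly from Base.__init__ would fail, because self.model is only created in the subclass constructor.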

>>> model.pretrain(dm_large, trainer_kwargs={"max_epochs": 20}) # doctest: +SKIP
>>> from lightning import Trainer # doctest: +SKIP
>>> trainer = Trainer(max_epochs=5) # doctest: +SKIP
>>> trainer.fit(model, dm_target) # backbone frozen automatically # doctest: +SKIP
Member


Would it be better to use Basepkg here?

... pretrained_weights="hf://org/my-pretrained-model/weights.pt",
... )
>>> trainer = Trainer(max_epochs=5) # doctest: +SKIP
>>> trainer.fit(model, dm_target) # fine-tunes from pretrained # doctest: +SKIP
Member


Would it be better to use Basepkg here?

else:
    raise ValueError(f"Scheduler {self.lr_scheduler} not supported.")

def pretrain(
Member


Should it be in Basepkg?

Contributor Author


I see, you are right. pretrain() and _pretrain() should move to Basepkg to be consistent with how fit() and train() are constructed. BaseModel should only keep the lifecycle state and the freeze/weight-loading hooks.
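A rough sketch of that split (plain Python, all names hypothetical; the pkg-level object drives the lifecycle while the model keeps only state and hooks):

```python
class BaseModelSketch:
    """Model side: keeps only lifecycle state and hooks."""

    def __init__(self):
        self.is_pretrained_ = False
        self.frozen = False

    def _freeze_backbone(self):
        self.frozen = True

    def _pretrain(self, data):
        """Subclass hook; no-op in the base class."""


class BasePkgSketch:
    """Pkg side: drives the lifecycle, analogous to fit()/predict()."""

    def __init__(self, model):
        self.model = model

    def pretrain(self, data):
        self.model._pretrain(data)
        self.model.is_pretrained_ = True
        return self


pkg = BasePkgSketch(BaseModelSketch())
pkg.pretrain("dm_large")
print(pkg.model.is_pretrained_)  # True
```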

"""
target = getattr(self, "model", self)
for param in target.parameters():
    param.requires_grad_(False)
Member


What if I just want to freeze just a part of the model? Like first few layers?

Contributor Author


Subclasses can override _freeze_backbone() to freeze only part of the model.
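For example, a subclass could override the hook to freeze only the first few layers. A plain-Python sketch (all names hypothetical; a torch version would call `param.requires_grad_(False)` instead of setting the flag used here):

```python
class Param:
    """Stand-in for a torch parameter with a requires_grad flag."""

    def __init__(self):
        self.requires_grad = True


class PartialFreezeModel:
    def __init__(self, n_layers=4):
        # each "layer" is just a list of parameters
        self.layers = [[Param(), Param()] for _ in range(n_layers)]

    def _freeze_backbone(self, n_frozen=2):
        # override: freeze only the first n_frozen layers
        for layer in self.layers[:n_frozen]:
            for p in layer:
                p.requires_grad = False


model = PartialFreezeModel()
model._freeze_backbone(n_frozen=2)
frozen = [all(not p.requires_grad for p in layer) for layer in model.layers]
print(frozen)  # [True, True, False, False]
```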

@phoeenniixx phoeenniixx added the enhancement New feature or request label Mar 29, 2026
@echo-xiao
Contributor Author

Thanks @phoeenniixx for the feedback. On the questions below:

  1. How the inference mode would look here? After freezing the backbone?
    -- Inference works as normal: frozen parameters still participate in the forward pass; training updates only the unfrozen parts.

  2. Also, if we load weights from hf, how should we finetune the model? I think there can be multiple strats?
    -- Currently supports freeze_backbone and full. Other strategies can be added as follow-ups.

  3. I think we should think of how to integrate this with the FMs, I am not sure if the current BaseModel is a good place for this?
    -- This PR provides the base infrastructure (load_pretrained_weights, is_pretrained_, freeze_backbone) for FM integration to build on. Happy to align with the FM project contributors on interface design.

  4. Should the pretrain etc methods be in Basepkg similar to the way fit, predict is?
    -- Agreed
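On point 2, the two currently supported strategies could dispatch roughly like this (sketch; the strategy names come from the PR, but the dispatch function and Toy class are assumptions for illustration):

```python
def apply_fine_tune_strategy(model, strategy):
    """Hypothetical dispatch for the supported fine-tune strategies."""
    if strategy == "freeze_backbone":
        model._freeze_backbone()    # backbone frozen, head stays trainable
    elif strategy == "full":
        model._unfreeze_backbone()  # every parameter stays trainable
    else:
        raise ValueError(f"Unknown fine_tune_strategy: {strategy!r}")


class Toy:
    def __init__(self):
        self.backbone_trainable = True

    def _freeze_backbone(self):
        self.backbone_trainable = False

    def _unfreeze_backbone(self):
        self.backbone_trainable = True


t = Toy()
apply_fine_tune_strategy(t, "freeze_backbone")
print(t.backbone_trainable)  # False
```

New strategies (e.g. layer-wise unfreezing) would then only need another branch plus a model hook.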

@echo-xiao
Contributor Author

@phoeenniixx @fkiraly

I received the email saying my application was incomplete. I have just completed and submitted it. I kindly ask if you could reconsider my application given the forms are now complete.

I am a Data Scientist who has long relied on sktime in my professional work, and my proposal outlines my architectural vision for the project. Please give my proposal a chance if you could.


Labels

enhancement New feature or request

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[ENH] Extend TFT v2 with pretrain/fine-tune lifecycle as proof-of-concept for #41

2 participants