[ENH] Add pretrain/fine-tune lifecycle to BaseModel v2 #2220
echo-xiao wants to merge 5 commits into sktime:main
Conversation
Force-pushed af0c8e0 to e79e0b1
Codecov Report: ✅ All modified and coverable lines are covered by tests.

@@ Coverage Diff @@
##           main   #2220   +/- ##
=======================================
  Coverage      ?  86.68%
=======================================
  Files         ?     165
  Lines         ?    9782
  Branches      ?       0
=======================================
  Hits          ?    8480
  Misses        ?    1302
  Partials      ?       0
=======================================
Hi @phoeenniixx, the PR implements the BaseModel v2 core, and I'd love to get your feedback. Does this pretrain/on_fit_start() workflow look reasonable to you?
phoeenniixx
left a comment
Thanks! This looks good!
Sorry, I don't have a complete mental model of how pretraining should look... But I have added some doubts (some of which may feel redundant :))
- How would inference mode look here? After freezing the backbone?
- Also, if we load weights from HF, how should we fine-tune the model? I think there can be multiple strategies?
- I think we should think about how to integrate this with the FMs; I am not sure if the current BaseModel is a good place for this?
- Should the pretrain etc. methods be in BasePkg, similar to the way fit, predict are?
FYI @agobbifbk, @PranavBhatP, @fkiraly what are your thoughts on this?
Also, please don't share the GSoC proposal here. I think you should follow the official channels and apply at the GSoC portal and the sktime form...
    for param in self.parameters():
        param.requires_grad_(True)

    def _post_init_load_pretrained(self) -> None:
Where (and how) is this method used? I mean why not directly use load_pretrained_weights?
When BaseModel.__init__ runs, self.model doesn't exist yet. _post_init_load_pretrained() is called after the model is fully built.
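A minimal sketch of that deferred-loading pattern (the method name `_post_init_load_pretrained` is from the diff; everything else here is illustrative, not the PR's actual implementation):

```python
import torch
from torch import nn


class BaseModel(nn.Module):
    """Sketch: defer weight loading until self.model exists."""

    def __init__(self, pretrained_weights=None):
        super().__init__()
        self.pretrained_weights = pretrained_weights
        # self.model is built by the subclass *after* this __init__ returns,
        # so weights cannot be loaded here.

    def _post_init_load_pretrained(self) -> None:
        # Called by the subclass once self.model is fully built.
        if self.pretrained_weights is not None:
            state = torch.load(self.pretrained_weights, map_location="cpu")
            self.model.load_state_dict(state)


class MyModel(BaseModel):
    def __init__(self, pretrained_weights=None):
        super().__init__(pretrained_weights)
        self.model = nn.Linear(4, 2)       # backbone built here
        self._post_init_load_pretrained()  # safe: self.model now exists
```

Calling the hook directly from `__init__` of the base class would fail with an `AttributeError`, which is why it runs as a post-init step.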
    >>> model.pretrain(dm_large, trainer_kwargs={"max_epochs": 20})  # doctest: +SKIP
    >>> from lightning import Trainer  # doctest: +SKIP
    >>> trainer = Trainer(max_epochs=5)  # doctest: +SKIP
    >>> trainer.fit(model, dm_target)  # backbone frozen automatically  # doctest: +SKIP
Would it be better to use BasePkg here?
    ... pretrained_weights="hf://org/my-pretrained-model/weights.pt",
    ... )
    >>> trainer = Trainer(max_epochs=5)  # doctest: +SKIP
    >>> trainer.fit(model, dm_target)  # fine-tunes from pretrained  # doctest: +SKIP
Would it be better to use BasePkg here?
    else:
        raise ValueError(f"Scheduler {self.lr_scheduler} not supported.")

    def pretrain(
I see, you are right. pretrain() and _pretrain() should move to BasePkg to be consistent with how fit() and train() are constructed. BaseModel should only keep the lifecycle state and the freeze/weight-loading hooks.
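One way that split could look, as a hedged sketch (all names except `pretrain` and `freeze` behavior are assumptions, not code from the PR):

```python
from torch import nn


class BaseModel(nn.Module):
    """Sketch: the model keeps only lifecycle state plus freeze/weight hooks."""

    def __init__(self):
        super().__init__()
        self.model = nn.Linear(4, 2)  # stand-in backbone
        self.pretrained = False       # hypothetical lifecycle flag

    def freeze_backbone(self):
        # Hook the package layer calls before fine-tuning.
        for param in self.model.parameters():
            param.requires_grad_(False)


class BasePkg:
    """Sketch: the package layer orchestrates pretrain(), as it does fit()."""

    def __init__(self, model: BaseModel):
        self.model = model

    def pretrain(self, datamodule=None, **trainer_kwargs):
        # In the real PR this would run a Lightning Trainer on `datamodule`;
        # here we only record the lifecycle transition.
        self.model.pretrained = True

    def finetune(self, datamodule=None, **trainer_kwargs):
        # Freezing stays a model-level concern, invoked from the package.
        self.model.freeze_backbone()
```

This keeps the Trainer orchestration out of BaseModel while the model still owns its own freezing logic.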
| """ | ||
| target = getattr(self, "model", self) | ||
| for param in target.parameters(): | ||
| param.requires_grad_(False) |
What if I want to freeze just a part of the model? Like the first few layers?
Subclasses can override it.
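A sketch of what such an override could look like (the hook name `freeze_backbone` and the `n_frozen` parameter are hypothetical, chosen for illustration):

```python
from torch import nn


class PartialFreezeModel(nn.Module):
    """Sketch: freeze only the first few submodules instead of everything."""

    def __init__(self, n_frozen: int = 2):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(8, 8), nn.ReLU(),
            nn.Linear(8, 8), nn.ReLU(),
            nn.Linear(8, 1),
        )
        self.n_frozen = n_frozen

    def freeze_backbone(self):
        # Override of the base hook: freeze only the first `n_frozen`
        # submodules and leave the remaining layers trainable.
        for i, module in enumerate(self.model):
            trainable = i >= self.n_frozen
            for param in module.parameters():
                param.requires_grad_(trainable)
```

The base class's freeze-everything behavior stays the default; subclasses with layer-wise fine-tuning needs replace only this one method.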
Thanks for @phoeenniixx's feedback. For the problems below:
I received the email saying my application was not complete. I have just completed and submitted it. I kindly ask if you could reconsider my application, given the forms are now complete. I am a Data Scientist who has long relied on sktime in my professional work, and my proposal outlines my deep architectural vision for the project. Please give my proposal a chance if you could.

Reference Issues/PRs
close #2105
What does this implement/fix? Explain your changes.
Implements the pretrain -> fine-tune -> predict lifecycle for BaseModel v2.
Changes:
_base_model_v2.py
What should a reviewer concentrate their feedback on?
Did you add any tests for the change?
tests/test_models/test_basemodel_pretrain.py
Any other comments?
NHiTS_v2 POC integration will be added as follow-up after #2186 is merged.
PR checklist
pre-commit install.