[chat] refactor model save/load logic (#3654)

* [chat] strategy refactor unwrap model

* [chat] strategy refactor save model

* [chat] add docstr

* [chat] refactor trainer save model

* [chat] fix strategy typing

* [chat] refactor trainer save model

* [chat] update readme

* [chat] fix unit test
This commit is contained in:
Hongxin Liu
2023-04-27 18:41:49 +08:00
committed by GitHub
parent 6ef7011462
commit 842768a174
14 changed files with 155 additions and 181 deletions

View File

@@ -1,5 +1,24 @@
import torch.nn as nn
from .actor import Actor
from .critic import Critic
from .reward_model import RewardModel
__all__ = ['Actor', 'Critic', 'RewardModel']
def get_base_model(model: nn.Module) -> nn.Module:
    """Return the base model behind one of our wrapper classes.

    An ``Actor`` is unwrapped by delegating to its own ``get_base_model()``
    method. Any other module (for example a ``Critic`` or ``RewardModel``)
    is treated as its own base model and is handed back unchanged.

    Args:
        model (nn.Module): the module to unwrap.

    Returns:
        nn.Module: the result of ``model.get_base_model()`` when ``model`` is
        an ``Actor``; otherwise ``model`` itself.
    """
    # Only Actor wraps another model; everything else passes straight through.
    return model.get_base_model() if isinstance(model, Actor) else model
__all__ = ['Actor', 'Critic', 'RewardModel', 'get_base_model']