From b92313903f36e863c91b3e9bac621dba31cf52c0 Mon Sep 17 00:00:00 2001
From: Yuanchen <70520919+chengeharrison@users.noreply.github.com>
Date: Wed, 5 Apr 2023 09:45:42 +0800
Subject: [PATCH] fix save_model indent error in ppo trainer (#3450)

Co-authored-by: Yuanchen Xu
---
 applications/Chat/coati/trainer/ppo.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/applications/Chat/coati/trainer/ppo.py b/applications/Chat/coati/trainer/ppo.py
index 84254d50d..6b99855be 100644
--- a/applications/Chat/coati/trainer/ppo.py
+++ b/applications/Chat/coati/trainer/ppo.py
@@ -117,6 +117,9 @@ class PPOTrainer(Trainer):
 
         return {'reward': experience.reward.mean().item()}
 
+    def save_model(self, path: str, only_rank0: bool = False, tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
+        self.strategy.save_model(model=self.actor, path=path, only_rank0=only_rank0, tokenizer=tokenizer)
+
 
 def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> None:
     origin_model = strategy._unwrap_actor(actor)
@@ -129,7 +132,3 @@ def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, acto
         new_kwargs['update_model_kwargs_fn'] = update_model_kwargs_fn
 
     return new_kwargs
-
-
-def save_model(self, path: str, only_rank0: bool = False, tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
-    self.strategy.save_model(model=self.actor, path=path, only_rank0=only_rank0, tokenizer=tokenizer)
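
For context, a minimal usage sketch of the relocated method: with this patch applied, save_model is an instance method of PPOTrainer (instead of a stray module-level function after the class), so checkpointing is invoked on the trainer and delegated to its strategy. The helper function, tokenizer choice, and checkpoint path below are hypothetical illustrations, not part of this patch, and the PPOTrainer construction (strategy, actor, critic, reward model, ...) is assumed to happen elsewhere.

```python
from typing import Optional
from transformers import PreTrainedTokenizerBase

def save_actor_checkpoint(trainer, path: str,
                          tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
    """Persist the actor weights of an already-built coati PPOTrainer.

    `trainer` is assumed to be a PPOTrainer constructed and fitted elsewhere.
    With this fix, save_model is called on the trainer instance and forwards
    to strategy.save_model; only_rank0=True writes the checkpoint from rank 0
    only in distributed settings.
    """
    trainer.save_model(path=path, only_rank0=True, tokenizer=tokenizer)

# Example call (hypothetical names/paths):
# tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
# save_actor_checkpoint(trainer, "./actor_checkpoint", tokenizer)
```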