From 773955abfaa3b5aef832ad4a33ce053183edee0e Mon Sep 17 00:00:00 2001
From: Yuanchen <70520919+chengeharrison@users.noreply.github.com>
Date: Tue, 4 Apr 2023 15:30:01 +0800
Subject: [PATCH] fix save_model in naive and ddp strategy (#3436)

Co-authored-by: Yuanchen Xu
---
 .../Chat/coati/trainer/strategies/ddp.py   | 34 ++++++++++++++----
 .../Chat/coati/trainer/strategies/naive.py | 27 ++++++++++++---
 2 files changed, 49 insertions(+), 12 deletions(-)

diff --git a/applications/Chat/coati/trainer/strategies/ddp.py b/applications/Chat/coati/trainer/strategies/ddp.py
index 83cbbe633..8a8c4b3c2 100644
--- a/applications/Chat/coati/trainer/strategies/ddp.py
+++ b/applications/Chat/coati/trainer/strategies/ddp.py
@@ -1,3 +1,5 @@
+from typing import Optional
+
 import os
 import random
 
@@ -5,12 +7,13 @@ import numpy as np
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from coati.models.base import Actor
+from coati.models.base import LM, Actor, RewardModel
 from coati.models.lora import LoraLinear
 from coati.replay_buffer import ReplayBuffer
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader
+from transformers.tokenization_utils_base import PreTrainedTokenizerBase
 
 from .base import Strategy
 from .naive import NaiveStrategy
@@ -72,17 +75,32 @@ class DDPStrategy(NaiveStrategy):
         model: DDP = Strategy._unwrap_actor(actor)
         return model.module
 
-    def save_model(self, model: nn.Module, path: str, only_rank0: bool = False) -> None:
+    def save_model(self, model: nn.Module, path: str, only_rank0: bool = False, tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
+        if only_rank0 and dist.get_rank() != 0:
+            return None
+
         for module in model.modules():
             if isinstance(module, LoraLinear):
                 module.merge_weights = True
                 module.eval()
-
-        if only_rank0 and dist.get_rank() != 0:
-            return
-        model = model.model.module
-        state_dict = model.state_dict()
-        torch.save(state_dict, path)
+
+        if isinstance(model, RewardModel):
+            state_dict = model.state_dict()
+            if only_rank0 and dist.get_rank() != 0:
+                return
+            torch.save(state_dict, path)
+        else:
+            try:
+                if isinstance(model, LM):
+                    model = model.model
+                model.save_pretrained(path)
+                if tokenizer is not None:
+                    tokenizer.save_pretrained(path)
+            except AttributeError:
+                state_dict = model.state_dict()
+                if only_rank0 and dist.get_rank() != 0:
+                    return
+                torch.save(state_dict, path)
 
     def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
         if only_rank0 and dist.get_rank() != 0:
diff --git a/applications/Chat/coati/trainer/strategies/naive.py b/applications/Chat/coati/trainer/strategies/naive.py
index 80768d7e6..bb47e5ab2 100644
--- a/applications/Chat/coati/trainer/strategies/naive.py
+++ b/applications/Chat/coati/trainer/strategies/naive.py
@@ -1,11 +1,14 @@
-from typing import Any
+from typing import Any, Optional
 
 import torch
 import torch.nn as nn
 import torch.optim as optim
 from coati.replay_buffer import ReplayBuffer
+from coati.models.base import LM, RewardModel
+from coati.models.lora import LoraLinear
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader
+from transformers.tokenization_utils_base import PreTrainedTokenizerBase
 
 from .base import Strategy
 
@@ -38,9 +41,25 @@ class NaiveStrategy(Strategy):
                           pin_memory=pin_memory,
                           collate_fn=replay_buffer.collate_fn)
 
-    def save_model(self, model: nn.Module, path: str, only_rank0: bool = False) -> None:
-        unwrapped_model = self._unwrap_model(model)
-        torch.save(unwrapped_model.state_dict(), path)
+    def save_model(self, model: nn.Module, path: str, only_rank0: bool = False, tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
+        for module in model.modules():
+            if isinstance(module, LoraLinear):
+                module.merge_weights = True
+                module.eval()
+
+        if isinstance(model, RewardModel):
+            state_dict = model.state_dict()
+            torch.save(state_dict, path)
+        else:
+            try:
+                if isinstance(model, LM):
+                    model = model.model
+                model.save_pretrained(path)
+                if tokenizer is not None:
+                    tokenizer.save_pretrained(path)
+            except AttributeError:
+                state_dict = model.state_dict()
+                torch.save(state_dict, path)
 
     def load_model(self, model: nn.Module, path: str, map_location: Any = None, strict: bool = True) -> None:
         unwrapped_model = self._unwrap_model(model)
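
Note: a minimal sketch of a call under the new `save_model` signature, assuming
coati is importable and using NaiveStrategy; the checkpoint name ('gpt2') and
the output directory are illustrative, not taken from the patch:

    from coati.trainer.strategies import NaiveStrategy
    from transformers import AutoModelForCausalLM, AutoTokenizer

    strategy = NaiveStrategy()
    model = AutoModelForCausalLM.from_pretrained('gpt2')    # illustrative checkpoint
    tokenizer = AutoTokenizer.from_pretrained('gpt2')

    # The model is not a RewardModel and exposes save_pretrained, so save_model
    # takes the save_pretrained branch: it writes a reloadable HuggingFace
    # checkpoint directory and, because a tokenizer is passed, the tokenizer
    # files alongside it.
    strategy.save_model(model, 'sft_checkpoint', tokenizer=tokenizer)

The try/except AttributeError keeps the old behaviour as the fallback: a model
without save_pretrained still gets its state_dict written with torch.save, and
a RewardModel is always saved as a raw state_dict.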