fix save_model in naive and ddp strategy (#3436)

Co-authored-by: Yuanchen Xu <yuanchen.xu00@gmail.com>
Yuanchen authored 2023-04-04 15:30:01 +08:00, committed by GitHub
parent 1beb85cc25
commit 773955abfa
2 changed files with 49 additions and 12 deletions

coati/trainer/strategies/ddp.py

@@ -1,3 +1,5 @@
+from typing import Optional
+
 import os
 import random
 
@@ -5,12 +7,13 @@ import numpy as np
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from coati.models.base import Actor
+from coati.models.base import LM, Actor, RewardModel
 from coati.models.lora import LoraLinear
 from coati.replay_buffer import ReplayBuffer
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader
+from transformers.tokenization_utils_base import PreTrainedTokenizerBase
 
 from .base import Strategy
 from .naive import NaiveStrategy
@@ -72,17 +75,32 @@ class DDPStrategy(NaiveStrategy):
         model: DDP = Strategy._unwrap_actor(actor)
         return model.module
 
-    def save_model(self, model: nn.Module, path: str, only_rank0: bool = False) -> None:
+    def save_model(self, model: nn.Module, path: str, only_rank0: bool = False, tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
+        if only_rank0 and dist.get_rank() != 0:
+            return None
+
         for module in model.modules():
             if isinstance(module, LoraLinear):
                 module.merge_weights = True
                 module.eval()
 
-        if only_rank0 and dist.get_rank() != 0:
-            return
-        model = model.model.module
-        state_dict = model.state_dict()
-        torch.save(state_dict, path)
+        if isinstance(model, RewardModel):
+            state_dict = model.state_dict()
+            if only_rank0 and dist.get_rank() != 0:
+                return
+            torch.save(state_dict, path)
+        else:
+            try:
+                if isinstance(model, LM):
+                    model = model.model
+                model.save_pretrained(path)
+                if tokenizer is not None:
+                    tokenizer.save_pretrained(path)
+            except AttributeError:
+                state_dict = model.state_dict()
+                if only_rank0 and dist.get_rank() != 0:
+                    return
+                torch.save(state_dict, path)
 
     def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
         if only_rank0 and dist.get_rank() != 0:
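
The DDP path now distinguishes checkpoint formats: a RewardModel keeps the plain torch.save state-dict format (it has no save_pretrained of its own), while actors and LM wrappers are unwrapped to their underlying transformers model and written with save_pretrained, optionally alongside the tokenizer. A minimal standalone sketch of that dispatch, assuming a stock transformers model; the helper name save_checkpoint and the paths are illustrative, not part of this commit:

# Sketch of the save dispatch this commit introduces, outside the trainer:
# Hugging Face-style modules serialize with save_pretrained(), anything
# else falls back to a raw state dict.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def save_checkpoint(model, path, tokenizer=None):
    try:
        # HF models write a directory (config.json + weight files)
        # that from_pretrained() can reload directly.
        model.save_pretrained(path)
        if tokenizer is not None:
            tokenizer.save_pretrained(path)
    except AttributeError:
        # Modules without save_pretrained() (e.g. a custom reward model
        # with a value head) are saved as a plain state dict instead.
        torch.save(model.state_dict(), path)

model = AutoModelForCausalLM.from_pretrained('gpt2')
tokenizer = AutoTokenizer.from_pretrained('gpt2')
save_checkpoint(model, './actor_ckpt', tokenizer)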

coati/trainer/strategies/naive.py

@@ -1,11 +1,14 @@
-from typing import Any
+from typing import Any, Optional
 
 import torch
 import torch.nn as nn
 import torch.optim as optim
 from coati.replay_buffer import ReplayBuffer
+from coati.models.base import LM, RewardModel
+from coati.models.lora import LoraLinear
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader
+from transformers.tokenization_utils_base import PreTrainedTokenizerBase
 
 from .base import Strategy
@@ -38,9 +41,25 @@ class NaiveStrategy(Strategy):
                           pin_memory=pin_memory,
                           collate_fn=replay_buffer.collate_fn)
 
-    def save_model(self, model: nn.Module, path: str, only_rank0: bool = False) -> None:
-        unwrapped_model = self._unwrap_model(model)
-        torch.save(unwrapped_model.state_dict(), path)
+    def save_model(self, model: nn.Module, path: str, only_rank0: bool = False, tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
+        for module in model.modules():
+            if isinstance(module, LoraLinear):
+                module.merge_weights = True
+                module.eval()
+
+        if isinstance(model, RewardModel):
+            state_dict = model.state_dict()
+            torch.save(state_dict, path)
+        else:
+            try:
+                if isinstance(model, LM):
+                    model = model.model
+                model.save_pretrained(path)
+                if tokenizer is not None:
+                    tokenizer.save_pretrained(path)
+            except AttributeError:
+                state_dict = model.state_dict()
+                torch.save(state_dict, path)
 
     def load_model(self, model: nn.Module, path: str, map_location: Any = None, strict: bool = True) -> None:
         unwrapped_model = self._unwrap_model(model)
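
Downstream code has to account for both formats when loading: Hugging Face directories reload via from_pretrained, reward-model checkpoints via load_state_dict. A hedged sketch; the paths are illustrative, and rebuilding the RewardModel is elided because its constructor arguments depend on the training script:

# Reloading the two checkpoint formats save_model now emits.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Actor/LM checkpoints are Hugging Face directories:
backbone = AutoModelForCausalLM.from_pretrained('./actor_ckpt')
tokenizer = AutoTokenizer.from_pretrained('./actor_ckpt')

# RewardModel checkpoints are plain state dicts, so the module must be
# rebuilt with the same architecture before the weights are loaded:
# reward_model = RewardModel(...)  # as constructed during training
# reward_model.load_state_dict(torch.load('./rm_ckpt.pt'))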