[chat] refactor actor class (#3968)

* refactor: separate log_probs fn from Actor forward fn

* refactor: separate generate fn from Actor class

* feat: update unwrap_model and get_base_model
  * unwrap_model returns model not wrapped by Strategy
  * get_base_model returns HF model for Actor, Critic and RewardModel

* feat: simplify Strategy.prepare

* style: remove get_base_model method of Actor

* perf: tokenize text in batches (see sketch at the end)

* refactor: move calc_action_log_probs to utils of model (see sketch below)

* test: update test with new forward fn

* style: rename forward fn args

* fix: do not unwrap model in save_model fn of naive strategy

* test: add gemini test for train_prompts

* fix: fix _set_default_generate_kwargs
Authored by Wenhao Chen on 2023-06-13 13:31:56 +08:00; committed by GitHub.
Parent: b3ab7fbabf | Commit: 9d02590c9a
14 changed files with 151 additions and 120 deletions
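
The first items in the message separate generation and per-token log-prob computation from Actor.forward and move calc_action_log_probs into the model utils. As a rough sketch of what such a helper could look like — the name comes from the commit message, but the signature, tensor shapes and body below are assumptions for illustration, not the code in this commit:

import torch
import torch.nn.functional as F

def calc_action_log_probs(logits: torch.Tensor,
                          sequences: torch.LongTensor,
                          num_actions: int) -> torch.Tensor:
    """Log-probabilities of the generated (action) tokens.

    logits:      (B, S, vocab) model output over the full sequence
    sequences:   (B, S) prompt tokens followed by generated tokens
    num_actions: number of generated tokens at the end of each sequence
    """
    # log-prob of every candidate next token at each position
    log_probs = F.log_softmax(logits[:, :-1, :], dim=-1)
    # pick the log-prob of the token that actually follows each position
    log_probs = torch.gather(log_probs, dim=-1,
                             index=sequences[:, 1:].unsqueeze(-1)).squeeze(-1)
    # keep only the generated continuation
    return log_probs[:, -num_actions:]

Keeping this out of forward lets the actor's forward pass stay a plain logits computation, with callers such as the PPO trainer deciding when action log-probs are actually needed.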

View File

@@ -4,7 +4,6 @@ from typing import Any, List, Optional, Tuple, Union
 import torch
 import torch.nn as nn
-from coati.models.base import Actor, get_base_model
 from coati.replay_buffer import ReplayBuffer
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader
@@ -69,21 +68,16 @@ class Strategy(ABC):
             Union[List[ModelOrModelOptimPair], ModelOrModelOptimPair]: Models or model-optimizer-pairs in the original order.
         """
-        def prepare_model(model: nn.Module):
-            if isinstance(model, Actor):
-                return Actor(self.setup_model(model.get_base_model()))
-            return self.setup_model(model)
         rets = []
         for arg in models_or_model_optim_pairs:
             if isinstance(arg, tuple):
                 assert len(arg) == 2, f'Expect (model, optimizer) pair, got a tuple with size "{len(arg)}"'
                 model, optimizer = arg
-                model = prepare_model(model)
-                optimizer = self.setup_optimizer(optimizer, get_base_model(model))
+                model = self.setup_model(model)
+                optimizer = self.setup_optimizer(optimizer, model)
                 rets.append((model, optimizer))
             elif isinstance(arg, nn.Module):
-                rets.append(prepare_model(arg))
+                rets.append(self.setup_model(model))
             else:
                 raise RuntimeError(f'Expect model or (model, optimizer) pair, got {type(arg)}')
@@ -93,16 +87,15 @@ class Strategy(ABC):
     @staticmethod
     def unwrap_model(model: nn.Module) -> nn.Module:
-        """Get the unwrapped model from a wrapped model. Useful for getting original huggingface model.
-        For Actor, it will unwrap `actor.model`.
+        """Get the unwrapped model from a wrapped model made by Strategy.prepare.

         Args:
             model (nn.Module): the model to unwrap

         Returns:
-            nn.Module: the original model (usually a huggingface model)
+            nn.Module: the original model
         """
-        return get_base_model(model)
+        return model

     @abstractmethod
     def save_model(self, model: nn.Module, path: str, only_rank0: bool = True) -> None:
@@ -133,4 +126,4 @@ class Strategy(ABC):
     @abstractmethod
     def get_model_state_dict_shard(self, model: nn.Module, **config):
-        pass
+        pass
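
With this hunk, prepare no longer special-cases Actor: it simply runs setup_model/setup_optimizer on whatever it is handed, and the base unwrap_model becomes the identity, with each subclass stripping only the wrapper it added itself. A sketch of how calling code might use the simplified API (the variable names are illustrative, not taken from this diff):

# strategy: a NaiveStrategy / DDPStrategy / ColossalAIStrategy instance.
# prepare returns results in the same order as its arguments:
# (model, optimizer) tuples stay tuples, bare modules stay bare modules.
(actor, actor_optim), (critic, critic_optim), reward_model = strategy.prepare(
    (actor, actor_optim), (critic, critic_optim), reward_model)

# ... training ...

# unwrap_model now only removes the wrapper added by Strategy.prepare
# (e.g. DDP or ZeroDDP); it no longer digs the HF model out of an Actor.
plain_actor = strategy.unwrap_model(actor)

Per the commit message, code that needs the underlying HuggingFace model now calls get_base_model on the Actor, Critic or RewardModel itself rather than relying on the strategy to do it.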

View File

@@ -5,7 +5,6 @@ import torch
 import torch.distributed as dist
 import torch.nn as nn
 import torch.optim as optim
-from coati.models.base import get_base_model
 from torch.optim import Optimizer
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase
@@ -153,14 +152,13 @@ class ColossalAIStrategy(DDPStrategy):
     def save_model(self, model: nn.Module, path: str, only_rank0: bool = True) -> None:
         if only_rank0 and dist.get_rank() != 0 and self.stage != 3:
             return
-        base_model = get_base_model(model)
         if self.stage == 3:
-            assert isinstance(base_model, ZeroDDP)
+            assert isinstance(model, ZeroDDP)
             # for stage 3, state_dict() method should be called on every rank
-            state_dict = base_model.state_dict(only_rank_0=only_rank0)
+            state_dict = model.state_dict(only_rank_0=only_rank0)
         else:
             # only_rank0 is false or rank == 0
-            state_dict = base_model.state_dict()
+            state_dict = model.state_dict()
         if only_rank0 and dist.get_rank() != 0:
             return
         torch.save(state_dict, path)
@@ -172,11 +170,10 @@ class ColossalAIStrategy(DDPStrategy):
         torch.save(optimizer.state_dict(), path)

     def unwrap_model(self, model: nn.Module) -> nn.Module:
-        base_model: Union[nn.Module, ZeroDDP] = get_base_model(model)
         if self.stage == 3:
-            assert isinstance(base_model, ZeroDDP)
-            return base_model.module
-        return base_model
+            assert isinstance(model, ZeroDDP)
+            return model.module
+        return model

     def save_pretrained(self,
                         model: nn.Module,
@@ -196,5 +193,5 @@ class ColossalAIStrategy(DDPStrategy):
         # if isinstance(module, LoraLinear):
         #     module.merge_weights = True
         #     module.eval()
-        base_model: ZeroDDP = get_base_model(model)
-        yield from base_model.state_dict_shard(max_shard_size=1024, only_rank_0=False)
+        assert isinstance(model, ZeroDDP)
+        yield from model.state_dict_shard(max_shard_size=1024, only_rank_0=False)
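
The stage-3 branch in save_model is the subtle part: under ZeRO stage 3 the parameters are sharded across ranks, so ZeroDDP.state_dict() is a collective call that every rank has to enter even though only rank 0 writes the file. A sketch of the caller-side implication (assuming a model already wrapped by ColossalAIStrategy.prepare; names are illustrative):

# With stage != 3, non-zero ranks may skip save_model entirely.
# With stage == 3, every rank must call it so the collective
# state_dict(only_rank_0=True) gather can complete; rank 0 then
# writes the assembled checkpoint to disk.
strategy.save_model(wrapped_model, 'checkpoint.pt', only_rank0=True)

# unwrap_model peels off the ZeroDDP wrapper under stage 3 and returns
# the module that was originally passed to prepare.
original_module = strategy.unwrap_model(wrapped_model)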

View File

@@ -69,8 +69,8 @@ class DDPStrategy(NaiveStrategy):
         return DistributedSampler(dataset, dist.get_world_size(), dist.get_rank())

     def unwrap_model(self, model: nn.Module) -> nn.Module:
-        base_model: DDP = super().unwrap_model(model)
-        return base_model.module
+        assert isinstance(model, DDP)
+        return model.module

     def save_pretrained(self,
                         model: nn.Module,

View File

@@ -58,14 +58,13 @@ class NaiveStrategy(Strategy):
                           collate_fn=replay_buffer.collate_fn)

     def save_model(self, model: nn.Module, path: str, only_rank0: bool = True) -> None:
-        base_model = get_base_model(model)
-        state_dict = base_model.state_dict()
+        state_dict = model.state_dict()
         torch.save(state_dict, path)

     def load_model(self, model: nn.Module, path: str, map_location: Any = None, strict: bool = True) -> None:
-        base_model = get_base_model(model)
+        unwrapped_model = self.unwrap_model(model)
         state_dict = torch.load(path, map_location=map_location)
-        base_model.load_state_dict(state_dict, strict=strict)
+        unwrapped_model.load_state_dict(state_dict, strict=strict)

     def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
         torch.save(optimizer.state_dict(), path)
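
The "perf: tokenize text in batches" item from the commit message does not appear in the hunks above; with a HuggingFace tokenizer the usual pattern is to pass the whole list of strings in one call instead of looping over them. A generic sketch (the model name and lengths are placeholders, not values from this PR):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships without a pad token

texts = ['first prompt', 'a second, somewhat longer prompt', 'third prompt']

# One call over the whole batch: the tokenizer pads to the longest
# sequence and returns stacked tensors instead of per-string lists.
batch = tokenizer(texts, return_tensors='pt', padding=True,
                  truncation=True, max_length=96)
input_ids = batch['input_ids']            # shape: (len(texts), padded_len)
attention_mask = batch['attention_mask']

Batching the tokenizer call avoids a Python-level loop per sample and lets the fast (Rust-backed) tokenizers parallelize internally.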