Mirror of https://github.com/hpcaitech/ColossalAI.git
Synced 2025-09-17 15:11:20 +00:00
[chat] refactor actor class (#3968)
* refactor: separate log_probs fn from Actor forward fn
* refactor: separate generate fn from Actor class
* feat: update unwrap_model and get_base_model
  * unwrap_model returns model not wrapped by Strategy
  * get_base_model returns HF model for Actor, Critic and RewardModel
* feat: simplify Strategy.prepare
* style: remove get_base_model method of Actor
* perf: tokenize text in batches
* refactor: move calc_action_log_probs to utils of model
* test: update test with new forward fn
* style: rename forward fn args
* fix: do not unwrap model in save_model fn of naive strategy
* test: add gemini test for train_prompts
* fix: fix _set_default_generate_kwargs
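The message above moves `calc_action_log_probs` into the model utils and splits log-prob computation out of `Actor.forward`. As a rough sketch of what such a helper computes (the signature and shapes here are assumptions, not the repository's exact code), it gathers the log-probabilities of the generated action tokens from the language-model logits:

    import torch
    import torch.nn.functional as F

    def calc_action_log_probs(logits: torch.Tensor,
                              sequences: torch.LongTensor,
                              num_actions: int) -> torch.Tensor:
        # logits: (B, S, V) from a causal LM; position t scores token t+1.
        # sequences: (B, S) prompt tokens followed by generated tokens.
        log_probs = F.log_softmax(logits, dim=-1)
        # Shift by one so each token is scored by the preceding position's logits.
        token_log_probs = log_probs[:, :-1, :].gather(
            dim=-1, index=sequences[:, 1:].unsqueeze(-1)).squeeze(-1)
        # Keep only the trailing action (generated) tokens.
        return token_log_probs[:, -num_actions:]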
applications/Chat/coati/trainer/strategies/base.py
@@ -4,7 +4,6 @@ from typing import Any, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
-from coati.models.base import Actor, get_base_model
 from coati.replay_buffer import ReplayBuffer
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader
@@ -69,21 +68,16 @@ class Strategy(ABC):
             Union[List[ModelOrModelOptimPair], ModelOrModelOptimPair]: Models or model-optimizer-pairs in the original order.
         """
 
-        def prepare_model(model: nn.Module):
-            if isinstance(model, Actor):
-                return Actor(self.setup_model(model.get_base_model()))
-            return self.setup_model(model)
-
         rets = []
         for arg in models_or_model_optim_pairs:
             if isinstance(arg, tuple):
                 assert len(arg) == 2, f'Expect (model, optimizer) pair, got a tuple with size "{len(arg)}"'
                 model, optimizer = arg
-                model = prepare_model(model)
-                optimizer = self.setup_optimizer(optimizer, get_base_model(model))
+                model = self.setup_model(model)
+                optimizer = self.setup_optimizer(optimizer, model)
                 rets.append((model, optimizer))
             elif isinstance(arg, nn.Module):
-                rets.append(prepare_model(arg))
+                rets.append(self.setup_model(model))
             else:
                 raise RuntimeError(f'Expect model or (model, optimizer) pair, got {type(arg)}')
 
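With the `prepare_model` special case removed, `Strategy.prepare` treats an `Actor` like any other `nn.Module`. A hedged usage sketch (variable names are illustrative) mirroring how a trainer script would call it:

    # Actor, critic and reward model all flow through setup_model uniformly;
    # nothing re-wraps the result in Actor afterwards.
    (actor, actor_optim), (critic, critic_optim), reward_model = strategy.prepare(
        (actor, actor_optim), (critic, critic_optim), reward_model)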
@@ -93,16 +87,15 @@ class Strategy(ABC):
 
     @staticmethod
     def unwrap_model(model: nn.Module) -> nn.Module:
-        """Get the unwrapped model from a wrapped model. Useful for getting original huggingface model.
-        For Actor, it will unwrap `actor.model`.
+        """Get the unwrapped model from a wrapped model made by Strategy.prepare.
 
         Args:
             model (nn.Module): the model to unwrap
 
         Returns:
-            nn.Module: the original model (usually a huggingface model)
+            nn.Module: the original model
         """
-        return get_base_model(model)
+        return model
 
     @abstractmethod
     def save_model(self, model: nn.Module, path: str, only_rank0: bool = True) -> None:
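Per the commit message, `unwrap_model` now undoes only the wrapping applied by `Strategy.prepare`, while `get_base_model` reaches the HF model inside an `Actor`, `Critic` or `RewardModel`. A small sketch of the intended split (assuming `get_base_model` stays importable from `coati.models.base`):

    from coati.models.base import get_base_model

    unwrapped = strategy.unwrap_model(actor)  # strips DDP/ZeroDDP wrappers only
    hf_model = get_base_model(unwrapped)      # the underlying transformers model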
@@ -133,4 +126,4 @@ class Strategy(ABC):
 
     @abstractmethod
     def get_model_state_dict_shard(self, model: nn.Module, **config):
-        pass
+        pass
applications/Chat/coati/trainer/strategies/colossalai.py
@@ -5,7 +5,6 @@ import torch
 import torch.distributed as dist
 import torch.nn as nn
 import torch.optim as optim
-from coati.models.base import get_base_model
 from torch.optim import Optimizer
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase
 
@@ -153,14 +152,13 @@ class ColossalAIStrategy(DDPStrategy):
     def save_model(self, model: nn.Module, path: str, only_rank0: bool = True) -> None:
         if only_rank0 and dist.get_rank() != 0 and self.stage != 3:
             return
-        base_model = get_base_model(model)
         if self.stage == 3:
-            assert isinstance(base_model, ZeroDDP)
+            assert isinstance(model, ZeroDDP)
             # for stage 3, state_dict() method should be called on every rank
-            state_dict = base_model.state_dict(only_rank_0=only_rank0)
+            state_dict = model.state_dict(only_rank_0=only_rank0)
         else:
             # only_rank0 is false or rank == 0
-            state_dict = base_model.state_dict()
+            state_dict = model.state_dict()
         if only_rank0 and dist.get_rank() != 0:
             return
         torch.save(state_dict, path)
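The asymmetric guards exist because ZeRO stage 3 shards parameters across ranks: `state_dict()` is then a collective that every rank must enter, and only afterwards does rank 0 write the file. A minimal sketch of the pattern (file name illustrative):

    # All ranks participate in the gather; only rank 0 persists the result.
    state_dict = model.state_dict(only_rank_0=True)  # collective under ZeroDDP
    if dist.get_rank() == 0:
        torch.save(state_dict, 'actor.pt')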
@@ -172,11 +170,10 @@ class ColossalAIStrategy(DDPStrategy):
         torch.save(optimizer.state_dict(), path)
 
     def unwrap_model(self, model: nn.Module) -> nn.Module:
-        base_model: Union[nn.Module, ZeroDDP] = get_base_model(model)
         if self.stage == 3:
-            assert isinstance(base_model, ZeroDDP)
-            return base_model.module
-        return base_model
+            assert isinstance(model, ZeroDDP)
+            return model.module
+        return model
 
     def save_pretrained(self,
                         model: nn.Module,
@@ -196,5 +193,5 @@ class ColossalAIStrategy(DDPStrategy):
         # if isinstance(module, LoraLinear):
         #     module.merge_weights = True
         #     module.eval()
-        base_model: ZeroDDP = get_base_model(model)
-        yield from base_model.state_dict_shard(max_shard_size=1024, only_rank_0=False)
+        assert isinstance(model, ZeroDDP)
+        yield from model.state_dict_shard(max_shard_size=1024, only_rank_0=False)
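`state_dict_shard` yields the state dict in bounded-size pieces so no single rank has to materialize the full model. A hypothetical consumer (the exact shape of the yielded items is an assumption):

    # Stream shards to disk instead of building one huge dict in memory.
    for i, shard in enumerate(strategy.get_model_state_dict_shard(model)):
        torch.save(shard, f'model_shard_{i}.pt')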
applications/Chat/coati/trainer/strategies/ddp.py
@@ -69,8 +69,8 @@ class DDPStrategy(NaiveStrategy):
         return DistributedSampler(dataset, dist.get_world_size(), dist.get_rank())
 
     def unwrap_model(self, model: nn.Module) -> nn.Module:
-        base_model: DDP = super().unwrap_model(model)
-        return base_model.module
+        assert isinstance(model, DDP)
+        return model.module
 
     def save_pretrained(self,
                         model: nn.Module,
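For plain DDP the inverse of wrapping is just `.module`; a small sketch, assuming `setup_model` wraps with torch's `DistributedDataParallel` and the process group is already initialized:

    from torch.nn.parallel import DistributedDataParallel as DDP

    wrapped = DDP(model, device_ids=[torch.cuda.current_device()])
    assert strategy.unwrap_model(wrapped) is model  # .module is the original module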
applications/Chat/coati/trainer/strategies/naive.py
@@ -58,14 +58,13 @@ class NaiveStrategy(Strategy):
                             collate_fn=replay_buffer.collate_fn)
 
     def save_model(self, model: nn.Module, path: str, only_rank0: bool = True) -> None:
-        base_model = get_base_model(model)
-        state_dict = base_model.state_dict()
+        state_dict = model.state_dict()
         torch.save(state_dict, path)
 
     def load_model(self, model: nn.Module, path: str, map_location: Any = None, strict: bool = True) -> None:
-        base_model = get_base_model(model)
+        unwrapped_model = self.unwrap_model(model)
         state_dict = torch.load(path, map_location=map_location)
-        base_model.load_state_dict(state_dict, strict=strict)
+        unwrapped_model.load_state_dict(state_dict, strict=strict)
 
     def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
         torch.save(optimizer.state_dict(), path)
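Together, the two changes keep saving and loading symmetric under the naive strategy: `save_model` serializes the model as given, and `load_model` routes through `unwrap_model` before applying the state dict. A hedged round-trip sketch (path illustrative):

    strategy.save_model(model, 'checkpoint.pt')
    strategy.load_model(model, 'checkpoint.pt', strict=True)  # loads via unwrap_model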