Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-21 17:40:33 +00:00)
[chat] refactor model save/load logic (#3654)
* [chat] strategy refactor unwrap model
* [chat] strategy refactor save model
* [chat] add docstr
* [chat] refactor trainer save model
* [chat] fix strategy typing
* [chat] refactor trainer save model
* [chat] update readme
* [chat] fix unit test
@@ -1,5 +1,24 @@
+import torch.nn as nn
+
 from .actor import Actor
 from .critic import Critic
 from .reward_model import RewardModel
 
-__all__ = ['Actor', 'Critic', 'RewardModel']
+
+def get_base_model(model: nn.Module) -> nn.Module:
+    """Get the base model of our wrapper classes.
+    For Actor, its base model is ``actor.model`` and it's usually a ``transformers.PreTrainedModel``.
+    For Critic and RewardModel, its base model is itself.
+
+    Args:
+        model (nn.Module): model to get base model from
+
+    Returns:
+        nn.Module: the base model
+    """
+    if isinstance(model, Actor):
+        return model.get_base_model()
+    return model
+
+
+__all__ = ['Actor', 'Critic', 'RewardModel', 'get_base_model']
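The new `get_base_model()` helper is the single entry point for reaching the underlying (usually huggingface) module. Below is a minimal, self-contained sketch of that contract; the `Actor` class here is a stand-in, not the real coati implementation:

import torch.nn as nn


class Actor(nn.Module):
    """Stand-in for coati's Actor wrapper: it holds an inner `model`."""

    def __init__(self, model: nn.Module) -> None:
        super().__init__()
        self.model = model

    def get_base_model(self) -> nn.Module:
        return self.model


def get_base_model(model: nn.Module) -> nn.Module:
    # Actor is unwrapped to its inner model; Critic/RewardModel are returned as-is.
    return model.get_base_model() if isinstance(model, Actor) else model


inner = nn.Linear(4, 4)
assert get_base_model(Actor(inner)) is inner    # wrapped -> inner module
assert get_base_model(inner) is inner           # plain module -> itself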
@@ -199,15 +199,9 @@ class PPOTrainer(Trainer):
 
         return {'reward': experience.reward.mean().item()}
 
-    def save_model(self,
-                   path: str,
-                   only_rank0: bool = False,
-                   tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
-        self.strategy.save_model(model=self.actor, path=path, only_rank0=only_rank0, tokenizer=tokenizer)
-
 
 def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> None:
-    origin_model = strategy._unwrap_actor(actor)
+    origin_model = strategy.unwrap_model(actor)
     new_kwargs = {**generate_kwargs}
     # use huggingface models method directly
     if 'prepare_inputs_fn' not in generate_kwargs and hasattr(origin_model, 'prepare_inputs_for_generation'):
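With `PPOTrainer.save_model()` gone, checkpointing is driven through the strategy object instead. A hedged sketch of the new call site; `strategy`, `trainer`, and the path are assumed names, not taken from the diff:

def save_actor_checkpoint(strategy, trainer, path: str = 'actor_checkpoint.pt') -> None:
    # The strategy now owns the save logic; only_rank0 defaults to True in the new API.
    strategy.save_model(trainer.actor, path, only_rank0=True)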
@@ -1,5 +1,5 @@
 from datetime import datetime
-from typing import Optional, List
+from typing import List, Optional
 
 import pandas as pd
 import torch
@@ -9,8 +9,8 @@ from torch.utils.data import DataLoader, Dataset, DistributedSampler
 from tqdm import tqdm
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase
 
-from .callbacks import Callback
 from .base import Trainer
+from .callbacks import Callback
 from .strategies import Strategy
 from .utils import is_rank_0
 
@@ -41,20 +41,18 @@ class RewardModelTrainer(Trainer):
                 train_dataloader: DataLoader,
-                valid_dataloader: DataLoader,
+                eval_dataloader: DataLoader,
                 batch_size: int = 1,
                 max_epochs: int = 1,
                 callbacks: List[Callback] = [],
                 ) -> None:
        super().__init__(strategy, max_epochs, callbacks=callbacks)
-        train_sampler = None
-
        self.train_dataloader = train_dataloader
-        self.valid_dataloader = valid_dataloader
+        self.eval_dataloader = eval_dataloader
 
-        self.model = strategy.setup_model(model)
+        self.model = model
        self.loss_fn = loss_fn
-        self.optimizer = strategy.setup_optimizer(optim, self.model)
+        self.optimizer = optim
        self.scheduler = lr_scheduler.CosineAnnealingLR(self.optimizer, self.train_dataloader.__len__() // 100)
 
    def eval_acc(self, dataloader):
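The trainer no longer calls `strategy.setup_model()` / `strategy.setup_optimizer()` itself, so the caller is expected to hand in already-prepared objects. A hedged sketch of one way to do that with `strategy.prepare()`; the wiring and keyword order shown here are illustrative, not copied from the diff:

def build_rm_trainer(strategy, model, optim, loss_fn, train_dataloader, eval_dataloader):
    # Wrap the model/optimizer for the chosen backend before the trainer sees them.
    model, optim = strategy.prepare((model, optim))
    return RewardModelTrainer(model=model, strategy=strategy, optim=optim, loss_fn=loss_fn,
                              train_dataloader=train_dataloader, eval_dataloader=eval_dataloader)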
@@ -123,9 +121,3 @@ class RewardModelTrainer(Trainer):
                 epoch_bar.update()
                 step_bar.set_postfix({'dist': dist, 'acc': acc})
             step_bar.close()
-
-    def save_model(self,
-                   path: str,
-                   only_rank0: bool = False,
-                   tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
-        self.strategy.save_model(model=self.model, path=path, only_rank0=only_rank0, tokenizer=tokenizer)
@@ -49,8 +49,8 @@ class SFTTrainer(Trainer):
         super().__init__(strategy, max_epochs, callbacks=callbacks)
         self.train_dataloader = train_dataloader
         self.eval_dataloader = eval_dataloader
-
-        (self.model, self.optimizer) = strategy.prepare((model, optim))
+        self.model = model
+        self.optimizer = optim
 
         self.accimulation_steps = accimulation_steps
         num_update_steps_per_epoch = len(train_dataloader) // self.accimulation_steps
@@ -133,9 +133,3 @@ class SFTTrainer(Trainer):
                 logger.info(f'Eval Epoch {epoch}/{self.max_epochs} loss {loss_mean}')
 
             # epoch_bar.update()
-
-    def save_model(self,
-                   path: str,
-                   only_rank0: bool = False,
-                   tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
-        self.strategy.save_model(model=self.model, path=path, only_rank0=only_rank0, tokenizer=tokenizer)
@@ -2,10 +2,9 @@ from abc import ABC, abstractmethod
 from contextlib import nullcontext
 from typing import Any, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
 import torch.nn as nn
-from coati.models.base import Actor, Critic, RewardModel
+from coati.models.base import Actor, get_base_model
 from coati.replay_buffer import ReplayBuffer
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader
@@ -72,8 +71,8 @@ class Strategy(ABC):
 
         def prepare_model(model: nn.Module):
             if isinstance(model, Actor):
-                return Actor(self.setup_model(self._unwrap_model(model)))
-            return self.setup_model(self._unwrap_model(model))
+                return Actor(self.setup_model(model.get_base_model()))
+            return self.setup_model(model)
 
         rets = []
         for arg in models_or_model_optim_pairs:
@@ -81,7 +80,7 @@ class Strategy(ABC):
                 assert len(arg) == 2, f'Expect (model, optimizer) pair, got a tuple with size "{len(arg)}"'
                 model, optimizer = arg
                 model = prepare_model(model)
-                optimizer = self.setup_optimizer(optimizer, self._unwrap_model(model))
+                optimizer = self.setup_optimizer(optimizer, get_base_model(model))
                 rets.append((model, optimizer))
             elif isinstance(arg, nn.Module):
                 rets.append(prepare_model(arg))
@@ -93,31 +92,20 @@ class Strategy(ABC):
         return rets
 
     @staticmethod
-    def _unwrap_model(model: nn.Module) -> nn.Module:
-        """Useful for saving state dict. As actor is wrapped by Actor class again in `prepare()`, we should unwrap it before saving.
-
-        Args:
-            model (nn.Module): an actor or a critic
-        """
-        if isinstance(model, Actor):
-            return model.model
-        return model
-
-    @staticmethod
-    def _unwrap_actor(actor: Actor) -> nn.Module:
-        """Get `actor.model` from a wrapped (by `prepare()`) actor. Useful for getting original huggingface model.
+    def unwrap_model(model: nn.Module) -> nn.Module:
+        """Get the unwrapped model from a wrapped model. Useful for getting original huggingface model.
+        For Actor, it will unwrap `actor.model`.
 
         Args:
-            actor (Actor): a wrapped actor
+            model (nn.Module): the model to unwrap
+
         Returns:
             nn.Module: the original model (usually a huggingface model)
         """
-        return Strategy._unwrap_model(actor)
+        return get_base_model(model)
 
     @abstractmethod
-    def save_model(self,
-                   model: nn.Module,
-                   path: str,
-                   only_rank0: bool = False,
-                   tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
+    def save_model(self, model: nn.Module, path: str, only_rank0: bool = True) -> None:
         pass
 
     @abstractmethod
@@ -134,3 +122,11 @@ class Strategy(ABC):
 
     def setup_sampler(self, dataset) -> DistributedSampler:
         return DistributedSampler(dataset, 1, 0)
+
+    @abstractmethod
+    def save_pretrained(self,
+                        model: nn.Module,
+                        path: str,
+                        only_rank0: bool = True,
+                        tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
+        pass
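The base `Strategy` now distinguishes two save paths: `save_model()` writes a raw state dict of the unwrapped model, while the new abstract `save_pretrained()` produces a huggingface-style directory and can also persist the tokenizer. A hedged usage sketch; `strategy`, `actor`, `tokenizer`, and the paths are assumed names:

def checkpoint_actor(strategy, actor, tokenizer=None) -> None:
    # Raw torch state dict of the base model (e.g. for resuming training).
    strategy.save_model(actor, 'actor.pt', only_rank0=True)
    # Huggingface-style export (config + weights), optionally with the tokenizer files.
    strategy.save_pretrained(actor, 'actor_pretrained', only_rank0=True, tokenizer=tokenizer)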
@@ -5,10 +5,8 @@ import torch
 import torch.distributed as dist
 import torch.nn as nn
 import torch.optim as optim
-from coati.models.base import Actor, RewardModel
-from coati.models.lora import LoraLinear
+from coati.models.base import get_base_model
 from torch.optim import Optimizer
-from transformers.modeling_utils import PreTrainedModel
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase
 
 import colossalai
@@ -17,9 +15,7 @@ from colossalai.nn.optimizer import CPUAdam, HybridAdam
 from colossalai.tensor import ProcessGroup, ShardSpec
 from colossalai.utils import get_current_device
 from colossalai.zero import ColoInitContext, ZeroDDP, zero_model_wrapper, zero_optim_wrapper
-from colossalai.zero.gemini.utils import get_static_torch_model
 
-from .base import Strategy
 from .ddp import DDPStrategy
 
 logger = get_dist_logger(__name__)
@@ -141,7 +137,7 @@ class ColossalAIStrategy(DDPStrategy):
             model = zero_model_wrapper(model, zero_stage=self.stage, gemini_config=self.gemini_config)
 
         if self.stage != 3 and self.precision == 'fp16':
-            model = model.half()
+            model = model.half().cuda()
         return model
 
     def setup_optimizer(self, optimizer: optim.Optimizer, model: nn.Module) -> optim.Optimizer:
@@ -154,47 +150,39 @@ class ColossalAIStrategy(DDPStrategy):
     def optimizer_step(self, optimizer: optim.Optimizer, **kwargs) -> None:
         optimizer.step()
 
-    @staticmethod
-    def _unwrap_actor(actor: Actor) -> nn.Module:
-        model: Union[nn.Module, ZeroDDP] = Strategy._unwrap_actor(actor)
-        if isinstance(model, ZeroDDP):
-            return model.module
-        return model
-
-    def save_model(self,
-                   model: nn.Module,
-                   path: str,
-                   only_rank0: bool = True,
-                   tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
-        if only_rank0 and dist.get_rank() != 0:
-            return None
-        unwrapped_model = self._unwrap_model(model)
-        # TODO : better way to get torch model from gemini model
-        # to get torch model from gemini model
-        if isinstance(unwrapped_model, RewardModel):
-            state_dict = unwrapped_model.state_dict()
-            if only_rank0 and dist.get_rank() != 0:
-                return
-            torch.save(state_dict, path)
+    def save_model(self, model: nn.Module, path: str, only_rank0: bool = True) -> None:
+        if only_rank0 and dist.get_rank() != 0 and self.stage != 3:
+            return
+        base_model = get_base_model(model)
+        if self.stage == 3:
+            assert isinstance(base_model, ZeroDDP)
+            # for stage 3, state_dict() method should be called on every rank
+            state_dict = base_model.state_dict(only_rank_0=only_rank0)
         else:
-            try:
-                logger.info(f'Saving model to {path}', ranks=[0])
-                unwrapped_model.save_pretrained(path)
-                logger.info(f'Model saved to {path} Successfully', ranks=[0])
-                if tokenizer is not None:
-                    logger.info(f'Saving tokenizer to {path}', ranks=[0])
-                    tokenizer.save_pretrained(path)
-                    logger.info(f'Tokenizer saved to {path} Successfully', ranks=[0])
-            except AttributeError:
-                state_dict = unwrapped_model.state_dict()
-                if only_rank0 and dist.get_rank() != 0:
-                    return
-                torch.save(state_dict, path)
+            # only_rank0 is false or rank == 0
+            state_dict = base_model.state_dict()
+        if only_rank0 and dist.get_rank() != 0:
+            return
+        torch.save(state_dict, path)
 
     def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
         if only_rank0:
             raise RuntimeError(
                 f'Optimizer states are sharded when using ColossalAIStrategy. Only rank0 is not supported.')
         torch.save(optimizer.state_dict(), path)
+
+    def unwrap_model(self, model: nn.Module) -> nn.Module:
+        base_model: Union[nn.Module, ZeroDDP] = get_base_model(model)
+        if self.stage == 3:
+            assert isinstance(base_model, ZeroDDP)
+            return base_model.module
+        return base_model
+
+    def save_pretrained(self,
+                        model: nn.Module,
+                        path: str,
+                        only_rank0: bool = True,
+                        tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
+        if self.stage == 3:
+            raise RuntimeError('ColossalAI strategy with stage-3 does not support save_pretrained() now')
+        super().save_pretrained(model, path, only_rank0, tokenizer)
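The stage-3 branch is ordered the way it is because ZeRO stage 3 shards parameters across ranks: `state_dict()` effectively gathers the full weights, so every rank must enter it, and only the final `torch.save` may be skipped on non-zero ranks. A hedged, generic sketch of that "collective state_dict, rank-0 write" pattern using plain `torch.distributed` rather than ZeroDDP:

import torch
import torch.distributed as dist


def save_sharded(model: torch.nn.Module, path: str, only_rank0: bool = True) -> None:
    # When parameters are sharded, state_dict() must run on every rank,
    # so the rank gate cannot be placed before this call.
    state_dict = model.state_dict()
    if only_rank0 and dist.get_rank() != 0:
        return
    torch.save(state_dict, path)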
@@ -6,14 +6,12 @@ import numpy as np
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from coati.models.base import Actor, RewardModel
 from coati.replay_buffer import ReplayBuffer
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase
 
-from .base import Strategy
 from .naive import NaiveStrategy
 from .sampler import DistributedSampler
@@ -68,34 +66,10 @@ class DDPStrategy(NaiveStrategy):
                           pin_memory=pin_memory,
                           collate_fn=replay_buffer.collate_fn)
 
-    @staticmethod
-    def _unwrap_actor(actor: Actor) -> nn.Module:
-        model: DDP = Strategy._unwrap_actor(actor)
-        return model.module
-
-    def save_model(self,
-                   model: nn.Module,
-                   path: str,
-                   only_rank0: bool = False,
-                   tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
+    def save_model(self, model: nn.Module, path: str, only_rank0: bool = True) -> None:
         if only_rank0 and dist.get_rank() != 0:
-            return None
-
-        if isinstance(model, RewardModel):
-            state_dict = model.state_dict()
-            if only_rank0 and dist.get_rank() != 0:
-                return
-            torch.save(state_dict, path)
-        else:
-            try:
-                model.save_pretrained(path)
-                if tokenizer is not None:
-                    tokenizer.save_pretrained(path)
-            except AttributeError:
-                state_dict = model.state_dict()
-                if only_rank0 and dist.get_rank() != 0:
-                    return
-                torch.save(state_dict, path)
+            return
+        super().save_model(model, path, only_rank0)
 
     def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
         if only_rank0 and dist.get_rank() != 0:
@@ -104,3 +78,16 @@ class DDPStrategy(NaiveStrategy):
 
     def setup_sampler(self, dataset) -> DistributedSampler:
         return DistributedSampler(dataset, dist.get_world_size(), dist.get_rank())
+
+    def unwrap_model(self, model: nn.Module) -> nn.Module:
+        base_model: DDP = super().unwrap_model(model)
+        return base_model.module
+
+    def save_pretrained(self,
+                        model: nn.Module,
+                        path: str,
+                        only_rank0: bool = True,
+                        tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
+        if only_rank0 and dist.get_rank() != 0:
+            return
+        super().save_pretrained(model, path, only_rank0, tokenizer)
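`DDPStrategy.unwrap_model()` peels two layers: the coati wrapper first (via the base-class `unwrap_model()` / `get_base_model()`), then the `DistributedDataParallel` wrapper through its `.module` attribute. A small sketch of the DDP half of that, using only torch:

import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def strip_ddp(model: nn.Module) -> nn.Module:
    # DDP keeps the original network under `.module`.
    return model.module if isinstance(model, DDP) else model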
@@ -3,10 +3,11 @@ from typing import Any, Optional
 import torch
 import torch.nn as nn
 import torch.optim as optim
-from coati.models.base import RewardModel
+from coati.models.base import get_base_model
 from coati.replay_buffer import ReplayBuffer
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader
+from transformers.modeling_utils import PreTrainedModel
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase
 
 from .base import Strategy
@@ -40,27 +41,15 @@ class NaiveStrategy(Strategy):
                           pin_memory=pin_memory,
                           collate_fn=replay_buffer.collate_fn)
 
-    def save_model(self,
-                   model: nn.Module,
-                   path: str,
-                   only_rank0: bool = False,
-                   tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
-        if isinstance(model, RewardModel):
-            state_dict = model.state_dict()
-            torch.save(state_dict, path)
-        else:
-            try:
-                model.save_pretrained(path)
-                if tokenizer is not None:
-                    tokenizer.save_pretrained(path)
-            except AttributeError:
-                state_dict = model.state_dict()
-                torch.save(state_dict, path)
+    def save_model(self, model: nn.Module, path: str, only_rank0: bool = True) -> None:
+        base_model = get_base_model(model)
+        state_dict = base_model.state_dict()
+        torch.save(state_dict, path)
 
     def load_model(self, model: nn.Module, path: str, map_location: Any = None, strict: bool = True) -> None:
-        unwrapped_model = self._unwrap_model(model)
+        base_model = get_base_model(model)
         state_dict = torch.load(path, map_location=map_location)
-        unwrapped_model.load_state_dict(state_dict, strict=strict)
+        base_model.load_state_dict(state_dict, strict=strict)
 
     def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
         torch.save(optimizer.state_dict(), path)
@@ -68,3 +57,14 @@ class NaiveStrategy(Strategy):
     def load_optimizer(self, optimizer: Optimizer, path: str, map_location: Any = None) -> None:
         state_dict = torch.load(path, map_location=map_location)
         optimizer.load_state_dict(state_dict)
+
+    def save_pretrained(self,
+                        model: nn.Module,
+                        path: str,
+                        only_rank0: bool = True,
+                        tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
+        unwrapped_model = self.unwrap_model(model)
+        assert isinstance(unwrapped_model, PreTrainedModel)
+        unwrapped_model.save_pretrained(path)
+        if tokenizer is not None:
+            tokenizer.save_pretrained(path)
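`NaiveStrategy.save_pretrained()` asserts that the unwrapped model is a `transformers.PreTrainedModel` and then relies on the standard huggingface export. A hedged sketch of the equivalent direct transformers calls; the model name and output path are illustrative:

from transformers import AutoModelForCausalLM, AutoTokenizer


def export_pretrained(path: str = 'exported_model') -> None:
    model = AutoModelForCausalLM.from_pretrained('gpt2')
    tokenizer = AutoTokenizer.from_pretrained('gpt2')
    model.save_pretrained(path)        # writes config.json and the weight files
    tokenizer.save_pretrained(path)    # writes the tokenizer files alongside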