mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-22 01:48:07 +00:00
[chat] fix gemini strategy (#4698)
* [chat] fix gemini strategy * [chat] fix gemini strategy * [chat] fix gemini strategy * [chat] fix gemini strategy * g# This is a combination of 2 commits. [chat] fix gemini strategy fox * [chat] fix gemini strategy update llama2 example [chat] fix gemini strategy * [fix] fix gemini strategy * [fix] fix gemini strategy * [fix] fix gemini strategy * [fix] fix gemini strategy * [fix] fix gemini strategy * [fix] fix gemini strategy * [fix] fix gemini strategy * [fix] fix gemini strategy * [fix] fix gemini strategy * [fix] fix gemini strategy * fix * fix * fix * fix * fix * Update train_prompts.py
This commit is contained in:
@@ -30,3 +30,4 @@ class Actor(LoRAModule):
|
||||
"""Returns model output."""
|
||||
output = self.model(input_ids, attention_mask=attention_mask, **model_kwargs)
|
||||
return output
|
||||
|
||||
|
@@ -71,11 +71,11 @@ def get_strategy_from_args(strategy: str):
|
||||
if strategy == "ddp":
|
||||
strategy_ = DDPStrategy()
|
||||
elif strategy == "colossalai_gemini":
|
||||
strategy_ = GeminiStrategy(placement_policy="cuda", initial_scale=2**5)
|
||||
strategy_ = GeminiStrategy(placement_policy="static", initial_scale=2**5)
|
||||
elif strategy == "colossalai_zero2":
|
||||
strategy_ = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
|
||||
elif strategy == "colossalai_gemini_cpu":
|
||||
strategy_ = GeminiStrategy(placement_policy="cpu", initial_scale=2**5)
|
||||
strategy_ = GeminiStrategy(placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5)
|
||||
elif strategy == "colossalai_zero2_cpu":
|
||||
strategy_ = LowLevelZeroStrategy(stage=2, placement_policy="cpu")
|
||||
else:
|
||||
|
@@ -110,8 +110,8 @@ class Strategy(ABC):
|
||||
"""
|
||||
return model
|
||||
|
||||
def save_model(self, model: nn.Module, path: str, only_rank0: bool = True, **kwargs) -> None:
|
||||
self.booster.save_model(model, path, shard=not only_rank0, **kwargs)
|
||||
def save_model(self, model: nn.Module, path: str, shard: bool = False, **kwargs) -> None:
|
||||
self.booster.save_model(model, path, shard=shard, **kwargs)
|
||||
|
||||
def load_model(self, model: nn.Module, path: str, strict: bool = True) -> None:
|
||||
self.booster.load_model(model, path, strict)
|
||||
|
@@ -6,7 +6,6 @@ import torch.nn as nn
|
||||
import colossalai
|
||||
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin
|
||||
from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
|
||||
from colossalai.lazy.lazy_init import LazyInitContext
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.zero.gemini.gemini_ddp import GeminiDDP
|
||||
|
||||
@@ -130,6 +129,9 @@ class GeminiStrategy(DDPStrategy):
|
||||
seed: int = 42,
|
||||
shard_init: bool = False, # only for stage 3
|
||||
placement_policy: str = "auto",
|
||||
shard_param_frac: float = 1.0, # only for static placement
|
||||
offload_optim_frac: float = 0.0, # only for static placement
|
||||
offload_param_frac: float = 0.0, # only for static placement
|
||||
pin_memory: bool = True, # only for stage 3
|
||||
force_outputs_fp32: bool = False, # only for stage 3
|
||||
search_range_m: int = 32, # only for stage 3
|
||||
@@ -160,6 +162,9 @@ class GeminiStrategy(DDPStrategy):
|
||||
plugin_initializer = lambda: GeminiPlugin(
|
||||
chunk_init_device=get_current_device(),
|
||||
placement_policy=placement_policy,
|
||||
shard_param_frac=shard_param_frac,
|
||||
offload_optim_frac=offload_optim_frac,
|
||||
offload_param_frac=offload_param_frac,
|
||||
precision="fp16",
|
||||
pin_memory=pin_memory,
|
||||
force_outputs_fp32=force_outputs_fp32,
|
||||
@@ -188,7 +193,7 @@ class GeminiStrategy(DDPStrategy):
|
||||
colossalai.launch_from_torch({}, seed=self.seed)
|
||||
|
||||
def model_init_context(self):
|
||||
return LazyInitContext(default_device=get_current_device())
|
||||
return super().model_init_context()
|
||||
|
||||
def unwrap_model(self, model: nn.Module) -> nn.Module:
|
||||
assert isinstance(model, GeminiDDP)
|
||||
|
@@ -87,9 +87,9 @@ class DDPStrategy(Strategy):
|
||||
return model.unwrap()
|
||||
|
||||
def save_pretrained(
|
||||
self, model: nn.Module, path: str, only_rank0: bool = True, tokenizer: Optional[PreTrainedTokenizerBase] = None
|
||||
self, model: nn.Module, path: str, shard: bool = False, tokenizer: Optional[PreTrainedTokenizerBase] = None
|
||||
) -> None:
|
||||
if not only_rank0 or dist.get_rank() == 0:
|
||||
if dist.get_rank() == 0:
|
||||
unwrapped_model = self.unwrap_model(model)
|
||||
assert isinstance(unwrapped_model, (Actor, Critic, RewardModel))
|
||||
pretrained_model = unwrapped_model.model
|
||||
@@ -98,19 +98,19 @@ class DDPStrategy(Strategy):
|
||||
pretrained_model.save_pretrained(path, save_function=lambda *args, **kwargs: None)
|
||||
if tokenizer is not None:
|
||||
tokenizer.save_pretrained(path)
|
||||
model_path = os.path.join(path, "pytorch_model.bin")
|
||||
self.save_model(model, model_path, only_rank0=only_rank0)
|
||||
|
||||
model_path = os.path.join(path, "pytorch_model.bin")
|
||||
self.save_model(model, model_path, shard=shard)
|
||||
def _replace_keys(model_path: str, replace_fn: Callable):
|
||||
state_dict = torch.load(model_path, map_location="cpu")
|
||||
state_dict = {replace_fn(k): v for k, v in state_dict.items()}
|
||||
torch.save(state_dict, model_path)
|
||||
|
||||
# FIXME: save_model would add "model." prefix to keys of pytorch_model.bin
|
||||
# HACK: rename keys of pytorch_model.bin
|
||||
if dist.get_rank() == 0:
|
||||
_replace_keys(model_path, lambda k: k.replace("model.", "", 1))
|
||||
|
||||
|
||||
def get_model_state_dict_shard(self, model: nn.Module, **config):
|
||||
# TODO: implement sharding on naive strategy
|
||||
model = self.unwrap_model(model)
|
||||
|
Reference in New Issue
Block a user