[chat] fix gemini strategy (#4698)

* [chat] fix gemini strategy

* This is a combination of 2 commits:

  [chat] fix gemini strategy

  fix

* [chat] fix gemini strategy

  update llama2 example

* [fix] fix gemini strategy

* fix

* Update train_prompts.py
Author: flybird11111 (committed by GitHub)
Date: 2023-09-27 13:15:32 +08:00
Parent: bbbcac26e8
Commit: be400a0936
16 changed files with 49 additions and 40 deletions


@@ -30,3 +30,4 @@ class Actor(LoRAModule):
         """Returns model output."""
         output = self.model(input_ids, attention_mask=attention_mask, **model_kwargs)
         return output
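For context, Actor is a thin wrapper whose forward simply delegates to the wrapped model, so callers receive the raw model output. A usage sketch (assuming the wrapped model is a Hugging Face CausalLM; variable names are illustrative):

    # actor(...) returns the raw transformers output object,
    # so the caller reads logits from it directly.
    output = actor(input_ids, attention_mask=attention_mask)
    logits = output.logits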


@@ -71,11 +71,11 @@ def get_strategy_from_args(strategy: str):
     if strategy == "ddp":
         strategy_ = DDPStrategy()
     elif strategy == "colossalai_gemini":
-        strategy_ = GeminiStrategy(placement_policy="cuda", initial_scale=2**5)
+        strategy_ = GeminiStrategy(placement_policy="static", initial_scale=2**5)
     elif strategy == "colossalai_zero2":
         strategy_ = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
     elif strategy == "colossalai_gemini_cpu":
-        strategy_ = GeminiStrategy(placement_policy="cpu", initial_scale=2**5)
+        strategy_ = GeminiStrategy(placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5)
     elif strategy == "colossalai_zero2_cpu":
         strategy_ = LowLevelZeroStrategy(stage=2, placement_policy="cpu")
     else:
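Both Gemini presets now use static placement; the CPU variant expresses "everything on CPU" through offload fractions instead of a dedicated policy. A minimal sketch of the resulting semantics (assuming the coati.trainer.strategies import path; fraction semantics follow GeminiPlugin, where 0.0 keeps state on GPU and 1.0 offloads it fully to CPU):

    from coati.trainer.strategies import GeminiStrategy

    # "colossalai_gemini": static placement, everything stays on GPU
    # (replaces the removed placement_policy="cuda").
    gpu_strategy = GeminiStrategy(placement_policy="static", initial_scale=2**5)

    # "colossalai_gemini_cpu": optimizer state and parameters fully offloaded
    # to CPU (replaces the removed placement_policy="cpu").
    cpu_strategy = GeminiStrategy(
        placement_policy="static",
        offload_optim_frac=1.0,
        offload_param_frac=1.0,
        initial_scale=2**5,
    )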


@@ -110,8 +110,8 @@ class Strategy(ABC):
         """
         return model

-    def save_model(self, model: nn.Module, path: str, only_rank0: bool = True, **kwargs) -> None:
-        self.booster.save_model(model, path, shard=not only_rank0, **kwargs)
+    def save_model(self, model: nn.Module, path: str, shard: bool = False, **kwargs) -> None:
+        self.booster.save_model(model, path, shard=shard, **kwargs)

     def load_model(self, model: nn.Module, path: str, strict: bool = True) -> None:
         self.booster.load_model(model, path, strict)
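Callers now choose the checkpoint layout explicitly instead of routing it through only_rank0. A usage sketch (paths are illustrative):

    # Single consolidated checkpoint file (the default).
    strategy.save_model(model, "checkpoints/pytorch_model.bin", shard=False)

    # Sharded checkpoint, split across files, for models too large for one file.
    strategy.save_model(model, "checkpoints/", shard=True)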


@@ -6,7 +6,6 @@ import torch.nn as nn
 import colossalai
 from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin
 from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
-from colossalai.lazy.lazy_init import LazyInitContext
 from colossalai.utils import get_current_device
 from colossalai.zero.gemini.gemini_ddp import GeminiDDP

@@ -130,6 +129,9 @@ class GeminiStrategy(DDPStrategy):
         seed: int = 42,
         shard_init: bool = False,  # only for stage 3
         placement_policy: str = "auto",
+        shard_param_frac: float = 1.0,  # only for static placement
+        offload_optim_frac: float = 0.0,  # only for static placement
+        offload_param_frac: float = 0.0,  # only for static placement
         pin_memory: bool = True,  # only for stage 3
         force_outputs_fp32: bool = False,  # only for stage 3
         search_range_m: int = 32,  # only for stage 3

@@ -160,6 +162,9 @@ class GeminiStrategy(DDPStrategy):
         plugin_initializer = lambda: GeminiPlugin(
             chunk_init_device=get_current_device(),
             placement_policy=placement_policy,
+            shard_param_frac=shard_param_frac,
+            offload_optim_frac=offload_optim_frac,
+            offload_param_frac=offload_param_frac,
             precision="fp16",
             pin_memory=pin_memory,
             force_outputs_fp32=force_outputs_fp32,

@@ -188,7 +193,7 @@ class GeminiStrategy(DDPStrategy):
         colossalai.launch_from_torch({}, seed=self.seed)

     def model_init_context(self):
-        return LazyInitContext(default_device=get_current_device())
+        return super().model_init_context()

     def unwrap_model(self, model: nn.Module) -> nn.Module:
         assert isinstance(model, GeminiDDP)
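Taken together, the three new arguments let one static configuration cover the old "cuda" and "cpu" policies and anything in between. An illustrative middle ground (values are made up for the example):

    # Keep parameters fully sharded on GPU but push half of the optimizer
    # state to CPU to save device memory.
    strategy = GeminiStrategy(
        placement_policy="static",
        shard_param_frac=1.0,    # fraction of parameters sharded across ranks
        offload_optim_frac=0.5,  # fraction of optimizer state offloaded to CPU
        offload_param_frac=0.0,  # fraction of parameters offloaded to CPU
        initial_scale=2**5,
    )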


@@ -87,9 +87,9 @@ class DDPStrategy(Strategy):
         return model.unwrap()

     def save_pretrained(
-        self, model: nn.Module, path: str, only_rank0: bool = True, tokenizer: Optional[PreTrainedTokenizerBase] = None
+        self, model: nn.Module, path: str, shard: bool = False, tokenizer: Optional[PreTrainedTokenizerBase] = None
     ) -> None:
-        if not only_rank0 or dist.get_rank() == 0:
+        if dist.get_rank() == 0:
             unwrapped_model = self.unwrap_model(model)
             assert isinstance(unwrapped_model, (Actor, Critic, RewardModel))
             pretrained_model = unwrapped_model.model

@@ -98,19 +98,19 @@ class DDPStrategy(Strategy):
             pretrained_model.save_pretrained(path, save_function=lambda *args, **kwargs: None)
             if tokenizer is not None:
                 tokenizer.save_pretrained(path)
-            model_path = os.path.join(path, "pytorch_model.bin")
-            self.save_model(model, model_path, only_rank0=only_rank0)
+        model_path = os.path.join(path, "pytorch_model.bin")
+        self.save_model(model, model_path, shard=shard)

         def _replace_keys(model_path: str, replace_fn: Callable):
             state_dict = torch.load(model_path, map_location="cpu")
             state_dict = {replace_fn(k): v for k, v in state_dict.items()}
             torch.save(state_dict, model_path)

         # FIXME: save_model would add "model." prefix to keys of pytorch_model.bin
         # HACK: rename keys of pytorch_model.bin
         if dist.get_rank() == 0:
             _replace_keys(model_path, lambda k: k.replace("model.", "", 1))

     def get_model_state_dict_shard(self, model: nn.Module, **config):
         # TODO: implement sharding on naive strategy
         model = self.unwrap_model(model)
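The key-renaming hack exists because save_model serializes the Actor/Critic wrapper, whose underlying transformer lives in a self.model attribute, so every key in pytorch_model.bin gains a "model." prefix. A standalone sketch of the same fix (the function name is hypothetical):

    import torch

    def strip_wrapper_prefix(model_path: str, prefix: str = "model.") -> None:
        # Load the checkpoint on CPU, drop the first occurrence of the wrapper
        # prefix from each key, and write the state dict back in place.
        state_dict = torch.load(model_path, map_location="cpu")
        state_dict = {k.replace(prefix, "", 1): v for k, v in state_dict.items()}
        torch.save(state_dict, model_path)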