mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 04:24:47 +00:00
[chat] refactor strategy class with booster api (#3987)
* refactor: adapt boost API in base and naive strategies * fix: initialize plugin after setup_distributed * fix: fix save_pretrained fn * refactor: adapt boost API in DDPStrategy * to: add _post_init check * to: fix ddp backward, modify ddp dataloader and unwrap * feat: adapt boost API in ColossalAIStrategy * fix: call setup_distributed before use get_current_device * fix: fix save_model and save_optimizer * test: remove save_sharded_optimizer test * style: apply formatter * fix: fix stage check and add comments * feat: allow dict type arg in strategy.prepare * to: temporarily remove lr_scheduler for testing * style: simplify init of ColossalAIStrategy * fix: fix lr_scheduler in sft and rm * style: modify comments * test: add train_prompts tests * fix: fix inference only case and use in train_prompts * test: skip failed tests in ci * style: fix CodeFactor check * fix: do not use model.to('cpu') with GeminiPlugin * test: enable colossalai_gemini tests * test: set CUDA_VISIBLE_DEVICES in ci * docs: add note
This commit is contained in:
@@ -19,8 +19,10 @@ from colossalai.nn.optimizer import HybridAdam
|
||||
|
||||
def get_model_numel(model: nn.Module, strategy: Strategy) -> int:
|
||||
numel = sum(p.numel() for p in model.parameters())
|
||||
if isinstance(strategy, ColossalAIStrategy) and strategy.stage == 3 and strategy.shard_init:
|
||||
numel *= dist.get_world_size()
|
||||
if isinstance(strategy, ColossalAIStrategy):
|
||||
from colossalai.booster.plugin import GeminiPlugin
|
||||
if isinstance(strategy.plugin, GeminiPlugin) and strategy.shard_init:
|
||||
numel *= dist.get_world_size()
|
||||
return numel
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user