mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-11-24 06:51:14 +00:00
[chat] refactor strategy class with booster api (#3987)
* refactor: adapt boost API in base and naive strategies
* fix: initialize plugin after setup_distributed
* fix: fix save_pretrained fn
* refactor: adapt boost API in DDPStrategy
* to: add _post_init check
* to: fix ddp backward, modify ddp dataloader and unwrap
* feat: adapt boost API in ColossalAIStrategy
* fix: call setup_distributed before use get_current_device
* fix: fix save_model and save_optimizer
* test: remove save_sharded_optimizer test
* style: apply formatter
* fix: fix stage check and add comments
* feat: allow dict type arg in strategy.prepare
* to: temporarily remove lr_scheduler for testing
* style: simplify init of ColossalAIStrategy
* fix: fix lr_scheduler in sft and rm
* style: modify comments
* test: add train_prompts tests
* fix: fix inference only case and use in train_prompts
* test: skip failed tests in ci
* style: fix CodeFactor check
* fix: do not use model.to('cpu') with GeminiPlugin
* test: enable colossalai_gemini tests
* test: set CUDA_VISIBLE_DEVICES in ci
* docs: add note
This commit is contained in:
@@ -19,8 +19,10 @@ from colossalai.nn.optimizer import HybridAdam
|
||||
|
||||
def get_model_numel(model: nn.Module, strategy: Strategy) -> int:
|
||||
numel = sum(p.numel() for p in model.parameters())
|
||||
if isinstance(strategy, ColossalAIStrategy) and strategy.stage == 3 and strategy.shard_init:
|
||||
numel *= dist.get_world_size()
|
||||
if isinstance(strategy, ColossalAIStrategy):
|
||||
from colossalai.booster.plugin import GeminiPlugin
|
||||
if isinstance(strategy.plugin, GeminiPlugin) and strategy.shard_init:
|
||||
numel *= dist.get_world_size()
|
||||
return numel
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user