[chat] refactor strategy class with booster api (#3987)

* refactor: adapt boost API in base and naive strategies (see the Booster sketch after this list)

* fix: initialize plugin after setup_distributed

* fix: fix save_pretrained fn

* refactor: adapt boost API in DDPStrategy

* to: add _post_init check

* to: fix ddp backward, modify ddp dataloader and unwrap

* feat: adapt boost API in ColossalAIStrategy

* fix: call setup_distributed before using get_current_device

* fix: fix save_model and save_optimizer

* test: remove save_sharded_optimizer test

* style: apply formatter

* fix: fix stage check and add comments

* feat: allow dict type arg in strategy.prepare

* to: temporarily remove lr_scheduler for testing

* style: simplify init of ColossalAIStrategy

* fix: fix lr_scheduler in sft and rm

* style: modify comments

* test: add train_prompts tests

* fix: fix inference-only case and use it in train_prompts

* test: skip failed tests in ci

* style: fix CodeFactor check

* fix: do not use model.to('cpu') with GeminiPlugin

* test: enable colossalai_gemini tests

* test: set CUDA_VISIBLE_DEVICES in ci

* docs: add note
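
For context, the Booster flow these strategies now delegate to looks roughly like the sketch below. This is a minimal sketch assuming ColossalAI's public booster entry points, not code from this PR, and it must run inside a launched distributed context (e.g. via colossalai.launch_from_torch):

import torch
import torch.nn as nn
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

model = nn.Linear(8, 2)
optimizer = torch.optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss()

booster = Booster(plugin=TorchDDPPlugin())
# boost() returns plugin-wrapped versions of whatever is passed in;
# the unused dataloader and lr_scheduler slots come back as None here.
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)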
Author: Wenhao Chen
Date: 2023-06-25 17:36:21 +08:00
Committed by: GitHub
Parent: b463651f3e
Commit: 153b957a1b
13 changed files with 350 additions and 290 deletions


@@ -19,8 +19,10 @@ from colossalai.nn.optimizer import HybridAdam
 def get_model_numel(model: nn.Module, strategy: Strategy) -> int:
     numel = sum(p.numel() for p in model.parameters())
-    if isinstance(strategy, ColossalAIStrategy) and strategy.stage == 3 and strategy.shard_init:
-        numel *= dist.get_world_size()
+    if isinstance(strategy, ColossalAIStrategy):
+        from colossalai.booster.plugin import GeminiPlugin
+        if isinstance(strategy.plugin, GeminiPlugin) and strategy.shard_init:
+            numel *= dist.get_world_size()
     return numel
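
Why the multiplication is kept: with Gemini's shard_init, parameters are sharded across ranks at initialization (ZeRO-3 style), so sum(p.numel()) on any one rank sees only about 1/world_size of the weights, and scaling by dist.get_world_size() recovers the global count. The old strategy.stage == 3 check becomes a plugin-type check because, as the diff above suggests, the ZeRO stage is tracked by the plugin rather than the strategy after the Booster refactor. A tiny standalone illustration (values are hypothetical):

# Each of the world_size ranks holds ~1/world_size of the weights under
# shard_init, so the per-rank count must be scaled back up.
world_size = 4
global_numel = 1_000_000
local_numel = global_numel // world_size  # what sum(p.numel()) sees per rank
assert local_numel * world_size == global_numel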