Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-08 12:30:42 +00:00
[chat] remove naive strategy and split colossalai strategy (#4094)
* feat: remove on_learn_epoch fn as not used
* revert: add _on_learn_epoch fn
* to: remove the use of NaiveStrategy
* test: remove NaiveStrategy tests
* feat: remove NaiveStrategy
* style: modify comments and params
* feat: split ColossalAIStrategy into LowLevelZeroStrategy and GeminiStrategy
* fix: remove naive
* fix: align with modified colossal strategy
* fix: fix ddp _try_init_dist arg
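For code that constructed ColossalAIStrategy directly, the diff below shows the intended mapping: the old stage=3 path becomes GeminiStrategy, while stages 1 and 2 become LowLevelZeroStrategy. A minimal before/after sketch, using only the constructor arguments that appear in this commit:

# before: one class, ZeRO stage chosen via `stage`
# strategy = ColossalAIStrategy(stage=3, placement_policy='cuda', initial_scale=2**5)
# strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')

# after: stage 3 -> GeminiStrategy, stages 1-2 -> LowLevelZeroStrategy
from coati.trainer.strategies import GeminiStrategy, LowLevelZeroStrategy

gemini_strategy = GeminiStrategy(placement_policy='cuda', initial_scale=2**5)
zero2_strategy = LowLevelZeroStrategy(stage=2, placement_policy='cuda')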
@@ -8,7 +8,7 @@ from coati.models.base import RewardModel
 from coati.models.opt import OPTActor, OPTCritic
 from coati.trainer import PPOTrainer
 from coati.trainer.callbacks import PerformanceEvaluator
-from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, Strategy
+from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy, Strategy
 from torch.optim import Adam
 from torch.utils.data import DataLoader
 from transformers import AutoTokenizer
@@ -19,10 +19,8 @@ from colossalai.nn.optimizer import HybridAdam
 
 
 def get_model_numel(model: nn.Module, strategy: Strategy) -> int:
     numel = sum(p.numel() for p in model.parameters())
-    if isinstance(strategy, ColossalAIStrategy):
-        from colossalai.booster.plugin import GeminiPlugin
-        if isinstance(strategy.plugin, GeminiPlugin) and strategy.shard_init:
-            numel *= dist.get_world_size()
+    if isinstance(strategy, GeminiStrategy) and strategy.shard_init:
+        numel *= dist.get_world_size()
     return numel
 
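The hunk above replaces the plugin check with a direct isinstance check on GeminiStrategy. A short sketch of how the updated helper reads, with a comment on the world-size multiplication (the assumption being that shard_init means each rank only holds its own parameter shard, so the local count underestimates the full model):

import torch.distributed as dist
import torch.nn as nn

from coati.trainer.strategies import GeminiStrategy, Strategy


def get_model_numel(model: nn.Module, strategy: Strategy) -> int:
    # Local parameter count on this rank.
    numel = sum(p.numel() for p in model.parameters())
    # With shard_init, Gemini shards parameters across ranks at initialization,
    # so scale the local count by the world size to approximate the full model.
    if isinstance(strategy, GeminiStrategy) and strategy.shard_init:
        numel *= dist.get_world_size()
    return numel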
@@ -78,17 +76,17 @@ def main(args):
     if args.strategy == 'ddp':
         strategy = DDPStrategy()
     elif args.strategy == 'colossalai_gemini':
-        strategy = ColossalAIStrategy(stage=3, placement_policy='cuda', initial_scale=2**5)
+        strategy = GeminiStrategy(placement_policy='cuda', initial_scale=2**5)
     elif args.strategy == 'colossalai_gemini_cpu':
-        strategy = ColossalAIStrategy(stage=3, placement_policy='cpu', initial_scale=2**5)
+        strategy = GeminiStrategy(placement_policy='cpu', initial_scale=2**5)
     elif args.strategy == 'colossalai_zero2':
-        strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
+        strategy = LowLevelZeroStrategy(stage=2, placement_policy='cuda')
     elif args.strategy == 'colossalai_zero2_cpu':
-        strategy = ColossalAIStrategy(stage=2, placement_policy='cpu')
+        strategy = LowLevelZeroStrategy(stage=2, placement_policy='cpu')
     elif args.strategy == 'colossalai_zero1':
-        strategy = ColossalAIStrategy(stage=1, placement_policy='cuda')
+        strategy = LowLevelZeroStrategy(stage=1, placement_policy='cuda')
     elif args.strategy == 'colossalai_zero1_cpu':
-        strategy = ColossalAIStrategy(stage=1, placement_policy='cpu')
+        strategy = LowLevelZeroStrategy(stage=1, placement_policy='cpu')
     else:
         raise ValueError(f'Unsupported strategy "{args.strategy}"')
 
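The elif chain above can be read as a table from CLI name to strategy class and constructor arguments. A hypothetical helper (not part of this commit) that mirrors the same mapping, assuming the signatures shown in the hunk:

from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy, Strategy


def build_strategy(name: str) -> Strategy:
    # Hypothetical dispatcher mirroring the elif chain in the diff above.
    if name == 'ddp':
        return DDPStrategy()
    gemini_kwargs = {
        'colossalai_gemini': dict(placement_policy='cuda', initial_scale=2**5),
        'colossalai_gemini_cpu': dict(placement_policy='cpu', initial_scale=2**5),
    }
    zero_kwargs = {
        'colossalai_zero2': dict(stage=2, placement_policy='cuda'),
        'colossalai_zero2_cpu': dict(stage=2, placement_policy='cpu'),
        'colossalai_zero1': dict(stage=1, placement_policy='cuda'),
        'colossalai_zero1_cpu': dict(stage=1, placement_policy='cpu'),
    }
    if name in gemini_kwargs:
        return GeminiStrategy(**gemini_kwargs[name])
    if name in zero_kwargs:
        return LowLevelZeroStrategy(**zero_kwargs[name])
    raise ValueError(f'Unsupported strategy "{name}"')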
@@ -83,8 +83,8 @@ def main(args):
         env_info=env_info_maker,
         kl_coef=0.1,
         debug=args.debug,
-        # sync_models_from_trainers=True,
-        # generation kwargs:
+        # sync_models_from_trainers=True,
+        # generation kwargs:
         max_length=512,
         do_sample=True,
         temperature=1.0,
@@ -153,10 +153,10 @@ if __name__ == '__main__':
     parser.add_argument('--num_trainers', type=int, default=1)
     parser.add_argument('--trainer_strategy',
                         choices=[
-                            'naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_gemini_cpu',
+                            'ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_gemini_cpu',
                             'colossalai_zero2_cpu'
                         ],
-                        default='naive')
+                        default='ddp')
     parser.add_argument('--maker_strategy', choices=['naive'], default='naive')
     parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
     parser.add_argument('--critic_model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
@@ -87,8 +87,8 @@ def main(args):
         env_info=env_info_maker,
         kl_coef=0.1,
         debug=args.debug,
-        # sync_models_from_trainers=True,
-        # generation kwargs:
+        # sync_models_from_trainers=True,
+        # generation kwargs:
         max_length=512,
         do_sample=True,
         temperature=1.0,
@@ -164,10 +164,10 @@ if __name__ == '__main__':
     parser.add_argument('--num_trainers', type=int, default=1)
    parser.add_argument('--trainer_strategy',
                        choices=[
-                            'naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_gemini_cpu',
+                            'ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_gemini_cpu',
                            'colossalai_zero2_cpu'
                        ],
-                        default='naive')
+                        default='ddp')
    parser.add_argument('--maker_strategy', choices=['naive'], default='naive')
    parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
    parser.add_argument('--critic_model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])