Mirror of https://github.com/hpcaitech/ColossalAI.git
[pipeline] support fp32 for HybridPlugin/merge shardformer test and pipeline test into one file (#4354)
* add naive optimizer for 3DPlugin / refactor gpt2 shardformer test
* merge tests of PP/DP/TP combinations into one test file
* fix bug when syncing grads for DP in HybridPlugin
* update supported precisions for 3DPlugin / fix bug when shifting tp_degree
* improve the passing of lazy_init
* modify lazy_init / use sync_shared_params
Committed by: Hongxin Liu
Parent: f13954cd58
Commit: 0ceec8f9a9
@@ -42,6 +42,8 @@ class HybridParallelModule(ModelWrapper):
             module = module.half().cuda()
         elif precision == 'bf16':
             module = module.to(dtype=torch.bfloat16).cuda()
+        else:
+            module = module.cuda()    # train without AMP
         # TODO(ver217): support TP+DP
         super().__init__(module)
 
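The hunk above adds the fp32 branch to the precision cast in HybridParallelModule. A minimal standalone sketch of the same dispatch, assuming a CUDA device is available (cast_module is a hypothetical helper, not part of the plugin):

import torch
import torch.nn as nn

def cast_module(module: nn.Module, precision: str) -> nn.Module:
    # Mirrors the branch added above: fp16/bf16 cast, plain .cuda() for fp32.
    if precision == 'fp16':
        return module.half().cuda()
    elif precision == 'bf16':
        return module.to(dtype=torch.bfloat16).cuda()
    else:
        return module.cuda()    # fp32: train without AMP

model = cast_module(nn.Linear(8, 8), 'fp32')
assert next(model.parameters()).dtype == torch.float32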
@@ -61,6 +63,7 @@ class HybridParallelModule(ModelWrapper):
         for p in self.module.parameters():
             if p.grad is not None:
                 dist.all_reduce(p.grad, group=self.dp_group)
+                p.grad.div_(self.dp_group.size())
 
 
 def init_pipeline_optimizer(optim: Optimizer, model: Module):
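This is the DP gradient-sync fix from the commit message: dist.all_reduce defaults to a SUM reduction, so without the added division every rank ends up holding the sum of per-rank gradients instead of their mean, which silently scales the effective learning rate by the data-parallel size. A sketch of the corrected reduction, assuming torch.distributed has already been initialized (e.g. via torchrun):

import torch
import torch.distributed as dist

def sync_dp_grads(model: torch.nn.Module, dp_group: dist.ProcessGroup) -> None:
    # Sum gradients over the data-parallel group, then divide to get the mean.
    for p in model.parameters():
        if p.grad is not None:
            dist.all_reduce(p.grad, group=dp_group)
            p.grad.div_(dp_group.size())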
@@ -72,7 +75,15 @@ def init_pipeline_optimizer(optim: Optimizer, model: Module):
     optim.__setstate__({'param_groups': new_param_groups})
 
 
-class HybridParallelOptimizer(MixedPrecisionOptimizer):
+class HybridParallelNaiveOptimizer(OptimizerWrapper):
+
+    def __init__(self, optim: Optimizer, model: Module, use_pipeline: bool):
+        if use_pipeline:
+            init_pipeline_optimizer(optim, model)
+        super().__init__(optim)
+
+
+class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
 
     def __init__(self,
                  optim: Optimizer,
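The old HybridParallelOptimizer becomes HybridParallelAMPOptimizer, and the new naive variant covers fp32: full precision needs no loss scaling, so the wrapper only prunes the param groups for pipeline parallelism and otherwise delegates to the inner optimizer. An illustrative delegating wrapper in that spirit (a stand-in, not ColossalAI's OptimizerWrapper):

from torch.optim import Optimizer

class DelegatingWrapper:
    # Minimal stand-in for OptimizerWrapper: step() and zero_grad() simply
    # forward to the inner optimizer, with no mixed-precision machinery.
    def __init__(self, optim: Optimizer):
        self.optim = optim

    def step(self, *args, **kwargs):
        return self.optim.step(*args, **kwargs)

    def zero_grad(self, *args, **kwargs):
        return self.optim.zero_grad(*args, **kwargs)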
@@ -192,7 +203,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         return ['cuda']
 
     def supported_precisions(self) -> List[str]:
-        return ['fp16', 'bf16']
+        return ['fp16', 'bf16', 'fp32']
 
     def control_device(self) -> bool:
         return True
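With 'fp32' now listed, the plugin can be configured without AMP. A hedged usage sketch; the launch call and the tp_size/pp_size values are illustrative and not taken from this diff:

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# Assumes the process group was launched via torchrun.
colossalai.launch_from_torch(config={})
plugin = HybridParallelPlugin(tp_size=2, pp_size=2, precision='fp32')
booster = Booster(plugin=plugin)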
@@ -218,12 +229,17 @@ class HybridParallelPlugin(PipelinePluginBase):
             model = HybridParallelModule(model, self.precision, self.shard_config, self.dp_group)
         if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
             if self.zero_stage == 0:
-                optimizer = HybridParallelOptimizer(optimizer,
-                                                    model,
-                                                    use_pipeline=self.enable_pipeline_parallelism,
-                                                    precision=self.precision,
-                                                    max_norm=self.max_norm,
-                                                    **self.amp_config)
+                if self.precision in ['fp16', 'bf16']:
+                    optimizer = HybridParallelAMPOptimizer(optimizer,
+                                                           model,
+                                                           use_pipeline=self.enable_pipeline_parallelism,
+                                                           precision=self.precision,
+                                                           max_norm=self.max_norm,
+                                                           **self.amp_config)
+                else:
+                    optimizer = HybridParallelNaiveOptimizer(optimizer,
+                                                             model,
+                                                             use_pipeline=self.enable_pipeline_parallelism)
             else:
                 optimizer = HybridParallelZeroOptimizer(optimizer,
                                                         model,
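The dispatch is a two-level decision: ZeRO stage first, then precision. A condensed sketch of just the selection logic, returning class names as strings for illustration:

def pick_optimizer_wrapper(precision: str, zero_stage: int) -> str:
    # Mirrors the branching above: ZeRO wins when enabled; otherwise AMP
    # for reduced precision and the naive wrapper for fp32.
    if zero_stage > 0:
        return 'HybridParallelZeroOptimizer'
    if precision in ('fp16', 'bf16'):
        return 'HybridParallelAMPOptimizer'
    return 'HybridParallelNaiveOptimizer'

assert pick_optimizer_wrapper('fp32', 0) == 'HybridParallelNaiveOptimizer'
assert pick_optimizer_wrapper('fp16', 0) == 'HybridParallelAMPOptimizer'
assert pick_optimizer_wrapper('bf16', 1) == 'HybridParallelZeroOptimizer'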
@@ -241,7 +257,8 @@ class HybridParallelPlugin(PipelinePluginBase):
                          data_iter: Iterator,
                          model: HybridParallelModule,
                          criterion: Callable[[Any, Any], torch.Tensor],
-                         optimizer: Union[HybridParallelOptimizer, HybridParallelZeroOptimizer],
+                         optimizer: Union[HybridParallelNaiveOptimizer, HybridParallelAMPOptimizer,
+                                          HybridParallelZeroOptimizer],
                          return_loss: bool = True,
                          return_outputs: bool = False) -> dict:
         assert self.enable_pipeline_parallelism, 'pipeline parallelism is not enabled'
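execute_pipeline now accepts any of the three wrapper types. A hedged call sketch, where plugin, data_iter, model, criterion and optim stand for objects already prepared elsewhere:

# Assumes pipeline parallelism is enabled on the plugin.
outputs = plugin.execute_pipeline(data_iter, model, criterion, optim,
                                  return_loss=True, return_outputs=False)
# The returned dict carries the loss when return_loss=True ('loss' key assumed).
loss = outputs.get('loss')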
@@ -250,7 +267,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         with ctx:
             outputs = self.schedule.forward_backward_step(model, optimizer, data_iter, criterion, return_loss,
                                                           return_outputs)
-        # model.sync_shared_params()
+        model.sync_shared_params()
         if isinstance(optimizer, HybridParallelZeroOptimizer):
             optimizer.sync_grad()
         else:
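Un-commenting model.sync_shared_params() matters for tied weights: under pipeline parallelism, a parameter shared across stages (e.g. input embedding and output head) accumulates only its local gradient on each stage, so the copies must be all-reduced after the backward pass to stay identical. A sketch of that synchronization over a hypothetical shared-parameter process group:

import torch.distributed as dist

def sync_shared_params(shared_params, shared_group) -> None:
    # All-reduce the gradients of parameters tied across pipeline stages so
    # every copy sees the same update on the next optimizer step.
    for p in shared_params:
        if p.grad is not None:
            dist.all_reduce(p.grad, group=shared_group)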