[pipeline] support fp32 for HybridPlugin/merge shardformer test and pipeline test into one file (#4354)

* add naive optimizer for 3DPlugin/refactor gpt2 shardformer test

* merge tests of PP/DP/TP combinations into one test file

* fix bug when syncing grads for DP in HybridPlugin

* update supported precisions for 3DPlugin/fix bug when shifting tp_degree

* improve the passing of lazy_init

* modify lazy_init/use sync_shared_params
Author:    Baizhou Zhang
Date:      2023-08-01 17:29:09 +08:00
Committer: Hongxin Liu
Parent:    f13954cd58
Commit:    0ceec8f9a9

8 changed files with 187 additions and 142 deletions
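For reference, a minimal sketch of how the newly accepted 'fp32' precision could be selected once this lands. The constructor arguments (tp_size, pp_size, num_microbatches) and the user model are illustrative assumptions, not part of this diff:

# Sketch only: selecting full-precision training with the hybrid plugin.
# Assumes the distributed environment has already been launched
# (e.g. via colossalai.launch_from_torch); MyModel is a hypothetical user model.
import torch
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

plugin = HybridParallelPlugin(tp_size=2, pp_size=2, num_microbatches=4,
                              precision='fp32')      # 'fp32' is newly supported by this commit
booster = Booster(plugin=plugin)

model = MyModel()                                     # hypothetical user model
optimizer = torch.optim.Adam(model.parameters())
model, optimizer, *_ = booster.boost(model, optimizer)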


@@ -42,6 +42,8 @@ class HybridParallelModule(ModelWrapper):
             module = module.half().cuda()
         elif precision == 'bf16':
             module = module.to(dtype=torch.bfloat16).cuda()
+        else:
+            module = module.cuda()    # train without AMP
         # TODO(ver217): support TP+DP
         super().__init__(module)
@@ -61,6 +63,7 @@ class HybridParallelModule(ModelWrapper):
         for p in self.module.parameters():
             if p.grad is not None:
                 dist.all_reduce(p.grad, group=self.dp_group)
+                p.grad.div_(self.dp_group.size())


 def init_pipeline_optimizer(optim: Optimizer, model: Module):
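The added `p.grad.div_(self.dp_group.size())` turns the all-reduced sum into a mean over the data-parallel group, which is the grad-sync bug fix mentioned in the commit message. The same pattern in isolation (function and argument names here are illustrative, not the plugin's internals):

# Illustrative sketch: average gradients across a data-parallel process group.
import torch
import torch.distributed as dist

def average_dp_grads(model: torch.nn.Module, dp_group: dist.ProcessGroup) -> None:
    world_size = dist.get_world_size(group=dp_group)
    for p in model.parameters():
        if p.grad is not None:
            dist.all_reduce(p.grad, group=dp_group)   # sum of local gradients
            p.grad.div_(world_size)                   # sum -> mean, as in the fix above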
@@ -72,7 +75,15 @@ def init_pipeline_optimizer(optim: Optimizer, model: Module):
     optim.__setstate__({'param_groups': new_param_groups})


-class HybridParallelOptimizer(MixedPrecisionOptimizer):
+class HybridParallelNaiveOptimizer(OptimizerWrapper):
+
+    def __init__(self, optim: Optimizer, model: Module, use_pipeline: bool):
+        if use_pipeline:
+            init_pipeline_optimizer(optim, model)
+        super().__init__(optim)
+
+
+class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):

     def __init__(self,
                  optim: Optimizer,
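HybridParallelNaiveOptimizer wraps the user optimizer without any AMP machinery; with pipeline parallelism enabled, init_pipeline_optimizer first narrows the param groups to the parameters held by the local stage (see the __setstate__ call above). A generic sketch of such a plain delegating wrapper (not ColossalAI's OptimizerWrapper itself):

# Generic sketch of the plain wrapper pattern, assuming nothing beyond torch.
# ColossalAI's OptimizerWrapper provides the same delegation plus extra hooks.
from torch.optim import Optimizer

class PlainOptimizerWrapper:
    def __init__(self, optim: Optimizer):
        self.optim = optim

    def step(self, *args, **kwargs):
        return self.optim.step(*args, **kwargs)       # no loss scaling/unscaling for fp32

    def zero_grad(self, *args, **kwargs):
        return self.optim.zero_grad(*args, **kwargs)

    @property
    def param_groups(self):
        return self.optim.param_groups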
@@ -192,7 +203,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         return ['cuda']

     def supported_precisions(self) -> List[str]:
-        return ['fp16', 'bf16']
+        return ['fp16', 'bf16', 'fp32']

     def control_device(self) -> bool:
         return True
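With 'fp32' added, a caller can check the requested precision against the plugin before boosting; a small illustrative guard (assuming the `plugin` object from the sketch near the top):

# Sketch: validate a requested precision against the plugin's reported capabilities.
requested = 'fp32'
assert requested in plugin.supported_precisions(), \
    f"precision {requested!r} is not supported by HybridParallelPlugin"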
@@ -218,12 +229,17 @@ class HybridParallelPlugin(PipelinePluginBase):
             model = HybridParallelModule(model, self.precision, self.shard_config, self.dp_group)
         if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
             if self.zero_stage == 0:
-                optimizer = HybridParallelOptimizer(optimizer,
-                                                    model,
-                                                    use_pipeline=self.enable_pipeline_parallelism,
-                                                    precision=self.precision,
-                                                    max_norm=self.max_norm,
-                                                    **self.amp_config)
+                if self.precision in ['fp16', 'bf16']:
+                    optimizer = HybridParallelAMPOptimizer(optimizer,
+                                                           model,
+                                                           use_pipeline=self.enable_pipeline_parallelism,
+                                                           precision=self.precision,
+                                                           max_norm=self.max_norm,
+                                                           **self.amp_config)
+                else:
+                    optimizer = HybridParallelNaiveOptimizer(optimizer,
+                                                             model,
+                                                             use_pipeline=self.enable_pipeline_parallelism)
             else:
                 optimizer = HybridParallelZeroOptimizer(optimizer,
                                                         model,
@@ -241,7 +257,8 @@ class HybridParallelPlugin(PipelinePluginBase):
                          data_iter: Iterator,
                          model: HybridParallelModule,
                          criterion: Callable[[Any, Any], torch.Tensor],
-                         optimizer: Union[HybridParallelOptimizer, HybridParallelZeroOptimizer],
+                         optimizer: Union[HybridParallelNaiveOptimizer, HybridParallelAMPOptimizer,
+                                          HybridParallelZeroOptimizer],
                          return_loss: bool = True,
                          return_outputs: bool = False) -> dict:
         assert self.enable_pipeline_parallelism, 'pipeline parallelism is not enabled'
@@ -250,7 +267,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         with ctx:
             outputs = self.schedule.forward_backward_step(model, optimizer, data_iter, criterion, return_loss,
                                                           return_outputs)
-        # model.sync_shared_params()
+        model.sync_shared_params()
         if isinstance(optimizer, HybridParallelZeroOptimizer):
             optimizer.sync_grad()
         else:
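Re-enabling model.sync_shared_params() makes the wrapper reconcile parameters that are shared between pipeline stages (typically tied embeddings) after the forward-backward step. Roughly, the sync amounts to an all-reduce over the ranks that hold a copy; the names below are placeholders, not the plugin's attributes:

# Rough sketch: keep the gradient of a parameter shared by several pipeline
# stages consistent by all-reducing it over the group of ranks holding a copy.
import torch
import torch.distributed as dist

def sync_shared_grad(shared_param: torch.nn.Parameter,
                     shared_group: dist.ProcessGroup) -> None:
    if shared_param.grad is not None:
        dist.all_reduce(shared_param.grad, group=shared_group)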