[pipeline] support fp32 for HybridPlugin/merge shardformer test and pipeline test into one file (#4354)

* add naive optimizer for 3DPlugin/refactor gpt2 shardformer test

* merge tests of PP/DP/TP combinations into one test file

* fix bug when syncing grads for DP in HybridPlugin

* update supported precisions for 3DPlugin/fix bug when shifting tp_degree

* improve how lazy_init is passed

* modify lazy_init/use sync_shared_params

Author: Baizhou Zhang
Date: 2023-08-01 17:29:09 +08:00
Committed by: Hongxin Liu
Parent: f13954cd58
Commit: 0ceec8f9a9
8 changed files with 187 additions and 142 deletions
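
Since this change adds fp32 to the plugin's supported precisions, below is a minimal sketch of opting into it through the public Booster API. The tp_size/pp_size values and the toy model are illustrative, and a distributed environment must already be launched (e.g. via colossalai.launch_from_torch); treat this as a sketch, not the test code from this commit.

# Sketch: selecting fp32 with HybridParallelPlugin after this change.
import torch
import torch.nn as nn
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# precision="fp32" is the option this commit adds support for;
# tp_size/pp_size here are illustrative values.
plugin = HybridParallelPlugin(tp_size=2, pp_size=1, precision="fp32")
booster = Booster(plugin=plugin)

model = nn.Linear(1024, 1024)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# boost() also returns criterion/dataloader/lr_scheduler slots; unused here.
model, optimizer, *_ = booster.boost(model, optimizer)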

@@ -456,12 +456,12 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
         if self.parallel_input:
             assert input_.shape[-1] == self.weight.shape[0], \
                 'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format(
-                    input_.shape, self.weight.shape, self.weight.shape[-1])
+                    input_.shape, self.weight.shape, self.weight.shape[0])
             input_ = input_
         else:
             assert divide(input_.shape[-1], self.num_partitions) == self.weight.shape[0], \
                 'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format(
-                    input_.shape, self.weight.shape, self.weight.shape[-1] * self.num_partitions)
+                    input_.shape, self.weight.shape, self.weight.shape[0] * self.num_partitions)
             input_ = split_forward_gather_backward(input_, dim=-1, process_group=self.process_group)
         if self.stream_chunk_num > 1:
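
The diff only corrects the error messages: the assertions already compare against self.weight.shape[0]. For a Conv1D-style weight stored as (in_features, out_features), the dimension matched against the input's last dim is weight.shape[0], not weight.shape[-1]. A minimal sketch of that shape convention (sizes are illustrative, not from the repo):

# Sketch of the row-parallel shape check being reported above.
# A row-parallel shard keeps in_features / num_partitions rows, so the
# input's last dim is compared against weight.shape[0].
import torch

num_partitions = 4
in_features, out_features = 1024, 768
weight_shard = torch.empty(in_features // num_partitions, out_features)  # (256, 768)

# Case 1: input already split along the feature dim (parallel_input=True).
parallel_input = torch.empty(8, in_features // num_partitions)
assert parallel_input.shape[-1] == weight_shard.shape[0]

# Case 2: full input, split inside forward (parallel_input=False).
full_input = torch.empty(8, in_features)
assert full_input.shape[-1] == weight_shard.shape[0] * num_partitions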