[shardformer/sequence parallel] support gpt2 seq parallel with pp/dp/tp (#4460)

* support gpt2 seq parallel with pp/dp/tp

* fix a bug when waiting for stream done

* delete unused gpt2_seq file
This commit is contained in:
Bin Jia
2023-08-18 11:21:53 +08:00
committed by GitHub
parent a78daf6180
commit 7c8be77081
6 changed files with 268 additions and 240 deletions

View File

@@ -235,6 +235,10 @@ class HybridParallelPlugin(PipelinePluginBase):
assert dist.get_world_size() % (
tp_size * pp_size
) == 0, f'world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}'
if enable_sequence_parallelism:
assert tp_size > 1, 'Sequence parallelism must be enabled when using tensor parallelism'
# TODO(ver217): support zero
assert zero_stage == 0, 'zero is not support yet'
self.tp_size = tp_size