[pipeline] Add Pipeline Forward for GPT2Model Shardformer (#4224)

* * fix typehint & docstring in sharder.py

* * update pipeline forward for GPT2Model

* * add test for pipeline forward of GPT2Model

* * add cache cleaning in gpt2 test

* * change assert to raise command
This commit is contained in:
Baizhou Zhang
2023-07-13 15:34:06 +08:00
committed by Hongxin Liu
parent 37d22f6878
commit 208ac8f2ba
5 changed files with 357 additions and 9 deletions

View File

@@ -129,7 +129,7 @@ class Linear1D_Col(ParallelModule):
**kwargs)
with torch.no_grad():
# the weigh to the linear layer is a transpose
# the weight to the linear layer is a transpose
# thus shard on row is equal to shard on column
sharded_weight = shard_rowwise(module.weight.data, process_group)
linear_1d.weight.data.copy_(sharded_weight)