[shardformer] support inplace sharding (#4251)

* [shardformer] embedding support inplace sharding (a sketch of the inplace-sharding idea follows this change list)

* [shardformer] linear support inplace sharding

* [shardformer] layernorm support inplace sharding

* [shardformer] qkv support inplace sharding

* [test] update shardformer layer test

* [shardformer] fix shared param sharding

* [shardformer] fix bert policy

* [shardformer] fix bloom policy

* [shardformer] fix llama policy

* [shardformer] fix opt policy

* [shardformer] fix t5 policy

* [shardformer] fix fused qkv linear

* [shardformer] fix bugs

* force sync

* [test] fix bugs

* [test] fix transformer version
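
What "inplace sharding" means here, as a rough illustration (an assumption drawn from the commit message, not ColossalAI's actual implementation): the existing module keeps its identity while its parameters are replaced by their local shards, instead of a new module being built with freshly allocated weights. The helper name, the explicit rank/world_size arguments, and the column-wise split below are hypothetical.

import torch
from torch import nn


def shard_linear_col_inplace(linear: nn.Linear, rank: int, world_size: int) -> nn.Linear:
    """Hypothetical helper: keep only this rank's slice of an nn.Linear, in place."""
    # Split the weight along the output dimension and keep the local shard.
    linear.weight.data = torch.chunk(linear.weight.data, world_size, dim=0)[rank].contiguous()
    if linear.bias is not None:
        linear.bias.data = torch.chunk(linear.bias.data, world_size, dim=0)[rank].contiguous()
    linear.out_features = linear.weight.shape[0]
    # No new nn.Linear is created, so tied weights or other references to this
    # module keep pointing at the same, now-sharded, parameters.
    return linear
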
Author: Hongxin Liu
Date:   2023-07-20 10:39:06 +08:00
Parent: 2a2eacfaf1
Commit: d921ce8391

26 changed files with 371 additions and 340 deletions


@@ -1,8 +1,10 @@
import copy
from contextlib import nullcontext
import torch
from torch.nn import Module
from colossalai.lazy import LazyInitContext
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
@@ -61,3 +63,14 @@ def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn,
    shard_output = output_transform_fn(shard_output)
    shard_loss = loss_fn(shard_output)
    return org_output, org_loss, shard_output, shard_loss


def check_state_dict(org_model: Module, sharded_model: Module, name: str = ''):
    org_sd = org_model.state_dict()
    shard_sd = sharded_model.state_dict()
    for k, v in org_sd.items():
        assert k in shard_sd, f'{name} {k} not in sharded model'
        shard_v = shard_sd[k]
        assert v.shape == shard_v.shape, f'{name} {k} shape mismatch, {v.shape} vs {shard_v.shape}'
        assert v.dtype == shard_v.dtype, f'{name} {k} dtype mismatch, {v.dtype} vs {shard_v.dtype}'
        assert torch.equal(v, shard_v), f'{name} {k} value mismatch'
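
For context, a rough sketch of how the new check_state_dict helper might be driven from a test. It assumes the ShardConfig/ShardFormer API imported above, that ShardFormer.optimize returns the sharded model together with its shared parameters, and that the distributed environment is already initialized; build_and_check_state_dict and the default ShardConfig() are illustrative, not the repository's actual test code.

import copy

from colossalai.shardformer import ShardConfig, ShardFormer


def build_and_check_state_dict(model_fn, name: str = ''):
    org_model = model_fn()
    # Sharding now happens in place, so shard a deep copy and keep the original intact.
    model_copy = copy.deepcopy(org_model)
    shard_former = ShardFormer(shard_config=ShardConfig())
    sharded_model, _shared_params = shard_former.optimize(model_copy)
    # check_state_dict asserts that every key of the original state dict exists in the
    # sharded model with matching shape, dtype and values.
    check_state_dict(org_model, sharded_model, name=name)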