Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-04 18:40:28 +00:00
[shardformer] support inplace sharding (#4251)
* [shardformer] embedding support inplace sharding
* [shardformer] linear support inplace sharding
* [shardformer] layernorm support inplace sharding
* [shardformer] qkv support inplace sharding
* [test] update shardformer layer test
* [shardformer] fix shared param sharding
* [shardformer] fix bert policy
* [shardformer] fix bloom policy
* [shardformer] fix llama policy
* [shardformer] fix opt policy
* [shardformer] fix t5 policy
* [shardformer] fix fused qkv linear
* [shardformer] fix bugs
* force sync
* [test] fix bugs
* [test] fix transformer version
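For context, "inplace sharding" here means replacing a module's parameters with their local shards on the existing module object, rather than constructing a new sharded module, so references to the module (and to shared parameters) remain valid after sharding. Below is a minimal, hypothetical sketch of the idea for a linear layer split along the output dimension; shard_linear_inplace is an illustration only, not a ColossalAI API.

import torch
import torch.nn as nn

def shard_linear_inplace(linear: nn.Linear, rank: int, world_size: int) -> nn.Linear:
    # Keep the same module object; only swap its parameters for local shards.
    assert linear.out_features % world_size == 0, 'out_features must divide world_size'
    shard_size = linear.out_features // world_size
    start = rank * shard_size
    with torch.no_grad():
        # Slice out this rank's rows and re-register them as parameters in place.
        linear.weight = nn.Parameter(linear.weight[start:start + shard_size].clone())
        if linear.bias is not None:
            linear.bias = nn.Parameter(linear.bias[start:start + shard_size].clone())
    linear.out_features = shard_size
    return linear

layer = nn.Linear(8, 8)
shard_linear_inplace(layer, rank=0, world_size=2)
print(layer.weight.shape)  # torch.Size([4, 8])

ColossalAI's real sharded layers additionally handle process groups, gradient reduction, and gathering of outputs; the sketch shows only the in-place parameter replacement that this PR extends across embedding, linear, layernorm, and fused-QKV layers.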
@@ -1,8 +1,10 @@
 import copy
 from contextlib import nullcontext
 
 import torch
+from torch.nn import Module
 
 from colossalai.lazy import LazyInitContext
+from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig, ShardFormer
 
@@ -61,3 +63,14 @@ def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn,
     shard_output = output_transform_fn(shard_output)
     shard_loss = loss_fn(shard_output)
     return org_output, org_loss, shard_output, shard_loss
+
+
+def check_state_dict(org_model: Module, sharded_model: Module, name: str = ''):
+    org_sd = org_model.state_dict()
+    shard_sd = sharded_model.state_dict()
+    for k, v in org_sd.items():
+        assert k in shard_sd, f'{name} {k} not in sharded model'
+        shard_v = shard_sd[k]
+        assert v.shape == shard_v.shape, f'{name} {k} shape mismatch, {v.shape} vs {shard_v.shape}'
+        assert v.dtype == shard_v.dtype, f'{name} {k} dtype mismatch, {v.dtype} vs {shard_v.dtype}'
+        assert torch.equal(v, shard_v), f'{name} {k} value mismatch'
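The new check_state_dict helper asserts that the original and sharded models expose identical state dicts (same keys, shapes, dtypes, and values), which can only hold when sharding leaves parameter shapes untouched. A hedged usage sketch follows; the enable_tensor_parallelism flag and the optimize signature reflect the ShardFormer API of this era and should be read as assumptions, and a distributed environment (e.g. initialized via colossalai.launch_from_torch) is assumed.

import copy

from transformers import BertConfig, BertModel

from colossalai.shardformer import ShardConfig, ShardFormer

org_model = BertModel(BertConfig(num_hidden_layers=2))
sharded_model = copy.deepcopy(org_model)

# With tensor parallelism disabled, in-place sharding should leave every
# parameter's name, shape, dtype, and value unchanged, so the strict
# comparison in check_state_dict can pass.
shard_config = ShardConfig(enable_tensor_parallelism=False)  # assumed flag
sharded_model, _ = ShardFormer(shard_config=shard_config).optimize(sharded_model)

check_state_dict(org_model, sharded_model, name='bert')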