mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 12:30:42 +00:00
[shardformer] support inplace sharding (#4251)
* [shardformer] embedding support inplace sharding * [shardformer] linear support inplace sharding * [shardformer] layernorm support inplace sharding * [shardformer] qkv support inplace sharding * [test] update shardformer layer test * [shardformer] fix shared param sharding * [shardformer] fix bert policy * [shardformer] fix bloom policy * [shardformer] fix llama policy * [shardformer] fix opt policy * [shardformer] fix t5 policy * [shardformer] fix fused qkv linear * [shardformer] fix bugs * force sync * [test] fix bugs * [test] fix transformer version
This commit is contained in:
@@ -14,7 +14,7 @@ from colossalai.testing import (
|
||||
spawn,
|
||||
)
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
from tests.test_shardformer.test_model._utils import build_model, run_forward
|
||||
from tests.test_shardformer.test_model._utils import build_model, check_state_dict, run_forward
|
||||
|
||||
|
||||
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
|
||||
@@ -88,6 +88,7 @@ def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
|
||||
use_lazy_init)
|
||||
check_state_dict(org_model, sharded_model, name=name)
|
||||
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
Reference in New Issue
Block a user