Mirror of https://github.com/hpcaitech/ColossalAI.git
[shardformer] support lazy init (#4202)
* [shardformer] support lazy init
* [shardformer] linear support lazy init
* [shardformer] embedding support lazy init
* [shardformer] norm support lazy init
* [shardformer] fused linear support lazy init
* [test] update shardformer test layer
* [test] shardformer with lazy init fit ddp
* [lazy] hotfix deepcopy of param
* [shardformer] fix bert policy and update test
* [shardformer] fix bloom policy and update test
* [shardformer] fix opt policy and update test
* [shardformer] fix t5 policy and update test
* [shardformer] fix gpt2 policy and update test
* [shardformer] fix llama policy and update test
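In outline, the updated test exercises the following pattern: build the model under LazyInitContext (or a plain nullcontext when lazy init is disabled) and let ShardFormer shard, and for lazy parameters materialize, the model. The snippet below is a minimal sketch distilled from the diff that follows; the helper name build_sharded_model and the way tp_process_group is passed in are illustrative, not part of the test itself.

from contextlib import nullcontext

from colossalai.lazy import LazyInitContext
from colossalai.shardformer import ShardConfig, ShardFormer


def build_sharded_model(model_fn, tp_process_group, lazy_init: bool):
    # Hypothetical helper: the real test inlines this logic in check_shardformer_with_ddp.
    # With lazy_init=True, parameters are created lazily and materialized when
    # ShardFormer shards the model, as introduced by this commit.
    ctx = LazyInitContext() if lazy_init else nullcontext()
    with ctx:
        model = model_fn().cuda()

    shard_config = ShardConfig(tensor_parallel_process_group=tp_process_group,
                               enable_fused_normalization=True)
    shardformer = ShardFormer(shard_config=shard_config)
    sharded_model, _ = shardformer.optimize(model)
    return sharded_model

The diff also moves colossalai.launch into a run_dist wrapper so that lazy_init can be swept with @parameterize while spawn still drives the distributed worker processes.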
@@ -1,3 +1,5 @@
+from contextlib import nullcontext
+
 import pytest
 import torch
 import torch.distributed as dist
@@ -5,15 +7,15 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 
 import colossalai
 from colossalai.cluster import DistCoordinator
+from colossalai.lazy import LazyInitContext
 from colossalai.logging import disable_existing_loggers
 from colossalai.shardformer import ShardConfig, ShardFormer
-from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
+from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
 from tests.kit.model_zoo import model_zoo
 
 
-def check_shardformer_with_ddp(rank, world_size, port):
-    disable_existing_loggers()
-    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+@parameterize('lazy_init', [True, False])
+def check_shardformer_with_ddp(lazy_init: bool):
 
     sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
 
@@ -41,9 +43,12 @@ def check_shardformer_with_ddp(rank, world_size, port):
     shard_config = ShardConfig(tensor_parallel_process_group=tp_process_group, enable_fused_normalization=True)
     shardformer = ShardFormer(shard_config=shard_config)
 
+    ctx = LazyInitContext() if lazy_init else nullcontext()
+
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
         # create and shard model
-        model = model_fn().cuda()
+        with ctx:
+            model = model_fn().cuda()
         sharded_model, _ = shardformer.optimize(model)
 
         # add ddp
@@ -65,13 +70,18 @@ def check_shardformer_with_ddp(rank, world_size, port):
         torch.cuda.empty_cache()
 
 
+def run_dist(rank, world_size, port):
+    disable_existing_loggers()
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    check_shardformer_with_ddp()
+
+
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 @clear_cache_before_run()
 def test_gpt2():
-    spawn(check_shardformer_with_ddp, 4)
+    spawn(run_dist, 4)
 
 
 if __name__ == "__main__":
     test_gpt2()