mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 11:32:10 +00:00
[shardformer] support lazy init (#4202)
* [shardformer] support lazy init * [shardformer] linear support lazy init * [shardformer] embedding support lazy init * [shardformer] norm support lazy init * [shardformer] fused linear support lazy init * [test] update shardformer test layer * [test] shardformer with lazy init fit ddp * [lazy] hotfix deepcopy of param * [shardformer] fix bert policy and update test * [shardformer] fix bloom policy and update test * [shardformer] fix opt policy and update test * [shardformer] fix t5 policy and update test * [shardformer] fix gpt2 policy and update test * [shardformer] fix llama policy and update test
This commit is contained in:
@@ -1,15 +1,22 @@
|
||||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from torch.testing import assert_close
|
||||
|
||||
import colossalai
|
||||
from colossalai.lazy import LazyInitContext
|
||||
from colossalai.shardformer.layer import Embedding1D
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||
|
||||
|
||||
def check_embedding_1d():
|
||||
embedding = nn.Embedding(32, 128).cuda()
|
||||
@parameterize('lazy_init', [False, True])
|
||||
def check_embedding_1d(lazy_init: bool):
|
||||
ctx = LazyInitContext() if lazy_init else nullcontext()
|
||||
|
||||
with ctx:
|
||||
embedding = nn.Embedding(32, 128).cuda()
|
||||
embedding_1d = Embedding1D.from_native_module(embedding, process_group=None)
|
||||
|
||||
assert embedding_1d.weight.shape == torch.Size([32, 64])
|
||||
|
Reference in New Issue
Block a user