Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-05 19:13:01 +00:00)
[shardformer] support lazy init (#4202)
* [shardformer] support lazy init
* [shardformer] linear support lazy init
* [shardformer] embedding support lazy init
* [shardformer] norm support lazy init
* [shardformer] fused linear support lazy init
* [test] update shardformer test layer
* [test] shardformer with lazy init fit ddp
* [lazy] hotfix deepcopy of param
* [shardformer] fix bert policy and update test
* [shardformer] fix bloom policy and update test
* [shardformer] fix opt policy and update test
* [shardformer] fix t5 policy and update test
* [shardformer] fix gpt2 policy and update test
* [shardformer] fix llama policy and update test
@@ -9,8 +9,8 @@ import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter

from colossalai.lazy import LazyInitContext
from colossalai.nn import init as init
from colossalai.nn.layer.utils import divide
from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise, sharded_tensor_to_param
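The only visible change in this hunk is the new import of LazyInitContext from colossalai.lazy: the shardformer layer converters below now materialize lazily initialized modules before touching their weights. A minimal sketch of the usage pattern this enables follows, assuming LazyInitContext works as a context manager that defers parameter allocation; that context-manager usage and the nn.Embedding sizes are illustrative, only LazyInitContext.materialize(module) itself appears in this diff.

import torch.nn as nn

from colossalai.lazy import LazyInitContext

# Build the module under lazy init: parameters are described but not allocated,
# so a large embedding table does not consume memory up front (assumed
# context-manager behaviour of LazyInitContext).
with LazyInitContext():
    embedding = nn.Embedding(50304, 4096)

# Before a shardformer converter reads attributes or shards the weight, it calls
# materialize() so that embedding.weight becomes a real, allocated tensor.
LazyInitContext.materialize(embedding)
print(embedding.weight.shape)  # torch.Size([50304, 4096])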
@@ -95,6 +95,7 @@ class Embedding1D(ParallelModule):
        r"""
        Build a 1D parallelized Embedding from a native nn.Embedding module.
        """
        LazyInitContext.materialize(module)
        # get the attributes
        num_embedding = module.num_embeddings
        embedding_dim = module.embedding_dim
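The hunk above shows the pattern this commit applies to each converter: call LazyInitContext.materialize(module) first, and only then read the native module's attributes. Below is a hedged sketch of such a converter; the name convert_embedding and its signature are illustrative stand-ins for the real from_native_module-style entry point, which this hunk does not show.

import torch.nn as nn

from colossalai.lazy import LazyInitContext


def convert_embedding(module: nn.Embedding):
    """Illustrative converter: materialize a possibly-lazy nn.Embedding, then
    read the attributes needed to build a parallelized replacement."""
    # If module was built under LazyInitContext, its weight is still a lazy
    # placeholder; materialize() turns it into a real tensor in place.
    LazyInitContext.materialize(module)

    # Only after materialization is it safe to read the attributes used for
    # sharding, exactly as the diff does before constructing Embedding1D.
    num_embeddings = module.num_embeddings
    embedding_dim = module.embedding_dim
    return num_embeddings, embedding_dim, module.weight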
@@ -223,6 +224,7 @@ class VocabParallelEmbedding1D(ParallelModule):
        r"""
        Convert a native pytorch embedding module to a parallel module.
        """
        LazyInitContext.materialize(module)
        # get the origin attributes
        num_embeddings = module.num_embeddings
        embedding_dim = module.embedding_dim
@@ -243,6 +245,7 @@ class VocabParallelEmbedding1D(ParallelModule):
                                             process_group=process_group,
                                             *args,
                                             **kwargs)

        with torch.no_grad():
            # shard and slice the weight along the vocabulary(num_embeddings) dimension
            # the shape of the weight is (num_embeddings, embedding_dim)
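The comments in this last hunk describe what VocabParallelEmbedding1D does with the materialized weight: under torch.no_grad(), the (num_embeddings, embedding_dim) table is split along the vocabulary dimension so that each tensor-parallel rank keeps a contiguous block of rows. The following sketch illustrates that slicing with plain PyTorch; the actual code presumably goes through shard_rowwise and sharded_tensor_to_param from colossalai.tensor.d_tensor.api, which are imported above but not shown in this hunk.

import torch


def slice_vocab_dim(weight: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    """Illustrative slice of an embedding table of shape
    (num_embeddings, embedding_dim) for one tensor-parallel rank."""
    num_embeddings = weight.shape[0]
    assert num_embeddings % world_size == 0, "assume the vocabulary divides evenly"
    rows_per_rank = num_embeddings // world_size
    start = rank * rows_per_rank
    with torch.no_grad():
        # Each rank keeps only its contiguous block of vocabulary rows.
        return weight[start:start + rows_per_rank].clone()


# Example: an 8-row vocabulary split across 2 ranks gives 4 rows per rank.
full = torch.randn(8, 4)
print(slice_vocab_dim(full, rank=0, world_size=2).shape)  # torch.Size([4, 4])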