mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-07-06 12:07:00 +00:00
parent
039b7ed3bc
commit
c9427a323f
@ -7,7 +7,7 @@ from .cache_mgr import CachedParamMgr
|
|||||||
from torch.nn.parameter import Parameter
|
from torch.nn.parameter import Parameter
|
||||||
from .._utils import dual_all_to_all
|
from .._utils import dual_all_to_all
|
||||||
|
|
||||||
from colossalai.tensor import ColoParameter, ShardSpec, ComputeSpec, ComputePattern, ProcessGroup
|
from colossalai.tensor import ColoParameter, ShardSpec, ComputeSpec, ComputePattern, ProcessGroup, ColoTensorSpec
|
||||||
|
|
||||||
|
|
||||||
def get_partition(embedding_dim, rank, world_size) -> Tuple[int, int, bool]:
|
def get_partition(embedding_dim, rank, world_size) -> Tuple[int, int, bool]:
|
||||||
@ -57,13 +57,15 @@ class ParallelFreqAwareEmbeddingBag(BaseEmbeddingBag):
|
|||||||
self.embedding_dim_per_partition = self.partition_end_index - self.partition_start_index
|
self.embedding_dim_per_partition = self.partition_end_index - self.partition_start_index
|
||||||
|
|
||||||
if _weight is None:
|
if _weight is None:
|
||||||
self._weight.process_group = ProcessGroup(tp_degree=self.world_size)
|
colo_tensor_spec = ColoTensorSpec(pg=ProcessGroup(tp_degree=self.world_size),
|
||||||
|
dist_attr=ShardSpec(dims=[-1], num_partitions=[self.world_size]),
|
||||||
|
compute_attr=ComputePattern.TP1D)
|
||||||
self._weight = ColoParameter.from_torch_tensor(torch.empty(self.num_embeddings,
|
self._weight = ColoParameter.from_torch_tensor(torch.empty(self.num_embeddings,
|
||||||
self.embedding_dim_per_partition,
|
self.embedding_dim_per_partition,
|
||||||
device='cpu',
|
device='cpu',
|
||||||
dtype=dtype),
|
dtype=dtype),
|
||||||
requires_grad=True,
|
requires_grad=True,
|
||||||
spec=ShardSpec(dims=[-1], num_partitions=[self.world_size]))
|
spec=colo_tensor_spec)
|
||||||
self.init_parameters()
|
self.init_parameters()
|
||||||
else:
|
else:
|
||||||
assert isinstance(_weight, ColoParameter), "initialized weight must in type of ColoParameter"
|
assert isinstance(_weight, ColoParameter), "initialized weight must in type of ColoParameter"
|
||||||
|
Loading…
Reference in New Issue
Block a user