mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-22 05:29:36 +00:00
[FAW] move coloparam setting in test code. (#1429)
This commit is contained in:
parent
cb98cf5558
commit
10b3df65c8
@ -67,9 +67,6 @@ class ParallelFreqAwareEmbeddingBag(BaseEmbeddingBag):
|
|||||||
self.init_parameters()
|
self.init_parameters()
|
||||||
else:
|
else:
|
||||||
assert isinstance(_weight, ColoParameter), "initialized weight must in type of ColoParameter"
|
assert isinstance(_weight, ColoParameter), "initialized weight must in type of ColoParameter"
|
||||||
_weight.process_group = ProcessGroup(tp_degree=self.world_size)
|
|
||||||
_weight.set_tensor_spec(ShardSpec(dims=[-1], num_partitions=[self.world_size]),
|
|
||||||
ComputeSpec(ComputePattern.TP1D))
|
|
||||||
self._weight = _weight
|
self._weight = _weight
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -8,11 +8,9 @@ import random
|
|||||||
import colossalai
|
import colossalai
|
||||||
from colossalai.utils import free_port
|
from colossalai.utils import free_port
|
||||||
from colossalai.testing import rerun_if_address_is_in_use
|
from colossalai.testing import rerun_if_address_is_in_use
|
||||||
from colossalai.tensor import ColoParameter
|
from colossalai.tensor import ColoParameter, ProcessGroup, ShardSpec, ComputePattern, ComputeSpec
|
||||||
from colossalai.nn._ops.cache_embedding import CachedParamMgr, FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag
|
from colossalai.nn._ops.cache_embedding import CachedParamMgr, FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag
|
||||||
|
|
||||||
from colossalai.nn._ops.cache_embedding import CachedParamMgr, FreqAwareEmbeddingBag
|
|
||||||
|
|
||||||
NUM_EMBED, EMBED_DIM = 10, 8
|
NUM_EMBED, EMBED_DIM = 10, 8
|
||||||
BATCH_SIZE = 8
|
BATCH_SIZE = 8
|
||||||
|
|
||||||
@ -161,6 +159,11 @@ def run_parallel_freq_aware_embed(rank, world_size):
|
|||||||
weight = torch.rand(num_embed, embed_dim)
|
weight = torch.rand(num_embed, embed_dim)
|
||||||
coloweight = ColoParameter(weight.clone().detach().cpu(), requires_grad=False)
|
coloweight = ColoParameter(weight.clone().detach().cpu(), requires_grad=False)
|
||||||
|
|
||||||
|
# initialize the tensor spec for the embedding weight parameter,
|
||||||
|
# which is a ColoParameter.
|
||||||
|
coloweight.process_group = ProcessGroup(tp_degree=world_size)
|
||||||
|
coloweight.set_tensor_spec(ShardSpec(dims=[-1], num_partitions=[world_size]), ComputeSpec(ComputePattern.TP1D))
|
||||||
|
|
||||||
model = ParallelFreqAwareEmbeddingBag.from_pretrained(coloweight,
|
model = ParallelFreqAwareEmbeddingBag.from_pretrained(coloweight,
|
||||||
include_last_offset=True,
|
include_last_offset=True,
|
||||||
freeze=False,
|
freeze=False,
|
||||||
|
Loading…
Reference in New Issue
Block a user