Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-05 02:51:59 +00:00
[tensor] a shorter shard and replicate spec (#1245)
@@ -1,6 +1,5 @@
 import torch
 from torch.fx.node import map_arg
-from colossalai.tensor import ColoTensorSpec, distspec, ProcessGroup, ComputeSpec, ComputePattern
+from colossalai.tensor import ColoTensorSpec, distspec, ProcessGroup, ComputeSpec, ComputePattern, ShardSpec
-
 
 def weight_split(weight: torch.Tensor, dim: int) -> torch.nn.parameter.Parameter:
@@ -25,7 +24,7 @@ def weight_split(weight: torch.Tensor, dim: int) -> torch.nn.parameter.Parameter
     world_size = torch.distributed.get_world_size()
     pg = ProcessGroup(tp_degree=world_size)
 
-    spec = ColoTensorSpec(pg, distspec.shard([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    spec = ColoTensorSpec(pg, ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     # As you has constructed a Spec, why not directly convert the tensor to ColoTensor.
     setattr(weight, "fx_attr", spec)
     return weight
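
For context, the change swaps the functional distspec.shard(...) helper for the shorter ShardSpec class when building the ColoTensorSpec that weight_split attaches to a weight. Below is a minimal sketch of the updated construction; the helper name tag_weight, the nn.Linear example, and the choice of sharding along dim 0 are illustrative assumptions, and the code assumes torch.distributed has already been initialized.

import torch
import torch.nn as nn
from colossalai.tensor import ColoTensorSpec, ProcessGroup, ComputeSpec, ComputePattern, ShardSpec

def tag_weight(weight: torch.Tensor, dim: int) -> torch.Tensor:
    # Hypothetical helper mirroring weight_split's spec construction after this commit.
    world_size = torch.distributed.get_world_size()
    pg = ProcessGroup(tp_degree=world_size)
    # The shorter ShardSpec([dims], [num_partitions]) replaces distspec.shard([dims], [num_partitions]).
    spec = ColoTensorSpec(pg, ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    # Record the spec on the tensor so downstream passes can read it.
    setattr(weight, "fx_attr", spec)
    return weight

# Illustrative usage: tag a Linear layer's weight for row-wise (dim 0) 1D tensor parallelism.
linear = nn.Linear(1024, 1024)
tag_weight(linear.weight, dim=0)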