Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-01 09:07:51 +00:00)

[tensor] a shorter shard and replicate spec (#1245)
@@ -1,5 +1,8 @@
 from .process_group import ProcessGroup
 from .tensor_spec import ColoTensorSpec
+from .distspec import shard as ShardSpec
+from .distspec import replicate as ReplicaSpec
+
 from .compute_spec import ComputeSpec, ComputePattern
 from .colo_tensor import ColoTensor
 from .colo_parameter import ColoParameter
@@ -11,5 +14,5 @@ from . import distspec
 __all__ = [
     'ColoTensor', 'convert_parameter', 'ComputePattern', 'ComputeSpec', 'named_params_with_colotensor', 'ColoParameter',
     'distspec', 'DistSpecManager', 'ParamOpHook', 'ParamOpHookManager', 'ChunkManager', 'TensorState', 'ProcessGroup',
-    'ColoTensorSpec', 'TensorSpec'
+    'ColoTensorSpec', 'TensorSpec', 'ShardSpec', 'ReplicaSpec'
 ]
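
With the package __init__ updated, the shorter aliases are importable directly from colossalai.tensor; the remaining hunks below switch the ColoTensor implementation itself over to them. A minimal before/after sketch of the import style (assuming colossalai is installed; actually creating distributed tensors additionally requires an initialized distributed environment):

    # Old style: reach the dist specs through the distspec module.
    from colossalai.tensor import distspec
    replica = distspec.replicate()

    # New style after this commit: shorter aliases exported at the package root.
    from colossalai.tensor import ShardSpec, ReplicaSpec
    replica = ReplicaSpec()    # alias for distspec.replicate
    # ShardSpec is the matching alias for distspec.shard
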
@@ -5,7 +5,7 @@ import torch
 from functools import lru_cache

 from colossalai.tensor import ColoTensorSpec
-from colossalai.tensor import distspec, ProcessGroup
+from colossalai.tensor import ProcessGroup, ReplicaSpec
 from colossalai.tensor.dist_spec_mgr import DistSpecManager
 from colossalai.tensor.distspec import _DistSpec, DistPlacementPattern
 from typing import Optional, Set, Callable
@@ -51,21 +51,21 @@ class ColoTensor(torch.Tensor):
     """ Data Structure for Tensor in Colossal-AI. It is a subclass of torch.Tensor.
     Args:
         data (torch.Tensor): a torch tensor used as the payload of the colotensor.
-        spec (ColoTensorSpec, optional): the tensor spec of initialization. Defaults to ColoTensorSpec(distspec.replicate()).
+        spec (ColoTensorSpec, optional): the tensor spec of initialization. Defaults to ColoTensorSpec(ReplicaSpec()).

     The signature of the function has to be consistent with the __new__ except for the 1st arg.
     The class should be initialized with a torch tensor in the following ways.
     1. directly init.
     >>> pg = ProcessGroup()
-    >>> colo_t1 = ColoTensor(torch.randn(2,3), spec = ColoTensorSpec(pg, distspec.replicate()))
+    >>> colo_t1 = ColoTensor(torch.randn(2,3), spec = ColoTensorSpec(pg, ReplicaSpec()))
     >>> # If initialized in a shard model, the tensor passed in is one shard of the global tensor.
-    >>> shard_spec = distspec.shard(process_group=ProcessGroup(tp=world_size),
+    >>> shard_spec = ShardSpec(process_group=ProcessGroup(tp=world_size),
     >>>                           dims=[0],
     >>>                           num_partitions=[world_size])
     >>> tensor_spec = ColoTensorSpec(pg, shard_spec)
     >>> colo_t2 = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)
     2. use static method from_torch_tensor
-    >>> colo_t = ColoTensor.from_torch_tensor(torch.randn(2,3), spec = ColoTensorSpec(pg, distspec.replicate()))
+    >>> colo_t = ColoTensor.from_torch_tensor(torch.randn(2,3), spec = ColoTensorSpec(pg, ReplicaSpec()))
     """

     def __new__(cls, data: torch.Tensor, spec: ColoTensorSpec) -> 'ColoTensor':
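
Taken together, the two docstring examples amount to the construction pattern below. This is a sketch only, following the signatures exactly as they appear in the docstring above; it assumes a launched distributed job where world_size matches the number of processes, and uses a locally created shard in place of t_ref:

    import torch
    from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup, ReplicaSpec, ShardSpec

    world_size = 4  # assumed to equal the number of launched processes

    # Example 1: a replicated ColoTensor.
    pg = ProcessGroup()
    colo_t1 = ColoTensor(torch.randn(2, 3), spec=ColoTensorSpec(pg, ReplicaSpec()))

    # Example 2: a sharded ColoTensor; each rank passes in its own shard of the
    # global tensor (dimension 0 split into world_size partitions).
    shard_spec = ShardSpec(process_group=ProcessGroup(tp=world_size),
                           dims=[0],
                           num_partitions=[world_size])
    tensor_spec = ColoTensorSpec(pg, shard_spec)
    local_shard = torch.randn(2, 3)  # stands in for t_ref.clone() from the docstring
    colo_t2 = ColoTensor.from_torch_tensor(local_shard, tensor_spec)
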
@@ -85,7 +85,7 @@ class ColoTensor(torch.Tensor):
         # If no spec is given, use a DP process group and a replicated dist spec
         if spec is None:
             self.has_initialized = False
-            self.dist_spec = distspec.replicate()
+            self.dist_spec = ReplicaSpec()
             self.compute_spec = None
             self.process_group = ProcessGroup()
         else:
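
Because spec is optional in both the constructor and from_torch_tensor, this default branch means a ColoTensor created without any spec ends up on a default ProcessGroup with a ReplicaSpec dist spec and no compute spec. A hedged usage sketch:

    # No spec given: falls back to a default ProcessGroup and ReplicaSpec(),
    # i.e. a data-parallel, fully replicated tensor.
    t = ColoTensor.from_torch_tensor(torch.randn(2, 3))
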
@@ -194,13 +194,14 @@ class ColoTensor(torch.Tensor):
         """to_replicate_
         an inline member function, converting dist spec of the tensor to REPLICATE
         """
-        self._redistribute(dist_spec=distspec.replicate())
+        self._redistribute(dist_spec=ReplicaSpec())

     def to_replicate(self) -> 'ColoTensor':
         """to_replicate
         converting dist spec of the tensor to REPLICATE
         """
-        return self.redistribute(distspec.replicate())
+        return self.redistribute(ReplicaSpec())
+

     @staticmethod
     def from_torch_tensor(tensor: torch.Tensor, spec: Optional[ColoTensorSpec] = None) -> 'ColoTensor':
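
Both conversion helpers now build their target spec with ReplicaSpec(). A short usage sketch (assuming colo_t2 is the sharded tensor from the earlier example; to_replicate returns a new replicated ColoTensor, while to_replicate_ converts in place):

    # Out-of-place: returns a replicated ColoTensor.
    replicated = colo_t2.to_replicate()

    # In-place: rewrites the tensor's dist spec to REPLICATE.
    colo_t2.to_replicate_()

    # Equivalent explicit call through the redistribute API with the new alias.
    replicated = colo_t2.redistribute(ReplicaSpec())
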
@@ -234,7 +235,7 @@ class ColoTensor(torch.Tensor):
         """
         if self.is_replicate():
             return super().view(*args)
-        replicated_t = self.redistribute(dist_spec=distspec.replicate())
+        replicated_t = self.redistribute(dist_spec=ReplicaSpec())
         return replicated_t.view(*args)

     def size_global(self, args: Optional[int] = None):
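
The view path follows the same pattern: a replicated tensor is viewed directly, while any other dist spec is first redistributed with ReplicaSpec() and the view is taken on the replicated result. A hedged sketch of the caller's side, with behavior as shown in the hunk above:

    # On a replicated ColoTensor this is a plain torch view; on a sharded one
    # the implementation above first redistributes to ReplicaSpec().
    flat = colo_t2.view(-1)
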