[tensor] a shorter shard and replicate spec (#1245)

Author: Jiarui Fang
Date: 2022-07-11 15:51:48 +08:00
Committed by: GitHub
Parent: 2699dfbbfd
Commit: 9bcd2fd4af
25 changed files with 91 additions and 98 deletions
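
This commit re-exports distspec.shard and distspec.replicate from colossalai.tensor under the shorter names ShardSpec and ReplicaSpec, then migrates the call sites shown below. A minimal sketch of the rename, using only constructs that appear in this diff (running it assumes a distributed environment has already been launched):

import torch
from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup, ReplicaSpec

pg = ProcessGroup()   # default process group, as in the docstring further down
# old spelling: distspec.replicate()  ->  new spelling: ReplicaSpec()
colo_t = ColoTensor(torch.randn(2, 3), spec=ColoTensorSpec(pg, ReplicaSpec()))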

View File

@@ -1,5 +1,8 @@
from .process_group import ProcessGroup
from .tensor_spec import ColoTensorSpec
+from .distspec import shard as ShardSpec
+from .distspec import replicate as ReplicaSpec
from .compute_spec import ComputeSpec, ComputePattern
from .colo_tensor import ColoTensor
from .colo_parameter import ColoParameter
@@ -11,5 +14,5 @@ from . import distspec
__all__ = [
'ColoTensor', 'convert_parameter', 'ComputePattern', 'ComputeSpec', 'named_params_with_colotensor', 'ColoParameter',
'distspec', 'DistSpecManager', 'ParamOpHook', 'ParamOpHookManager', 'ChunkManager', 'TensorState', 'ProcessGroup',
-'ColoTensorSpec', 'TensorSpec'
+'ColoTensorSpec', 'TensorSpec', 'ShardSpec', 'ReplicaSpec'
]
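
With the two aliases added above and appended to __all__, both specs are importable straight from the package root. A quick sketch; ReplicaSpec() takes no arguments as used throughout this diff, while the ShardSpec arguments mirror the dims/num_partitions pattern in the docstring later in this diff and its exact signature is an assumption:

from colossalai.tensor import ShardSpec, ReplicaSpec

replica = ReplicaSpec()                          # replicated dist spec, no arguments
shard = ShardSpec(dims=[0], num_partitions=[4])  # assumed signature: shard dim 0 into 4 partitions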

View File

@@ -5,7 +5,7 @@ import torch
from functools import lru_cache
from colossalai.tensor import ColoTensorSpec
-from colossalai.tensor import distspec, ProcessGroup
+from colossalai.tensor import ProcessGroup, ReplicaSpec
from colossalai.tensor.dist_spec_mgr import DistSpecManager
from colossalai.tensor.distspec import _DistSpec, DistPlacementPattern
from typing import Optional, Set, Callable
@@ -51,21 +51,21 @@ class ColoTensor(torch.Tensor):
""" Data Structure for Tensor in Colossal-AI. It is a subclass of torch.Tensor.
Args:
data (torch.Tensor): a torch tensor used as the payload of the colotensor.
-spec (ColoTensorSpec, optional): the tensor spec of initialization. Defaults to ColoTensorSpec(distspec.replicate()).
+spec (ColoTensorSpec, optional): the tensor spec of initialization. Defaults to ColoTensorSpec(ReplicaSpec()).
The signature of the function has to be consistent with the __new__ except for the 1st arg.
The class should be initialized with a torch tensor in the following ways.
1. directly init.
>>> pg = ProcessGroup()
->>> colo_t1 = ColoTensor(torch.randn(2,3), spec = ColoTensorSpec(pg, distspec.replicate()))
+>>> colo_t1 = ColoTensor(torch.randn(2,3), spec = ColoTensorSpec(pg, ReplicaSpec()))
>>> # If initialized in a sharded model, the tensor passed in is one shard of the global tensor.
->>> shard_spec = distspec.shard(process_group=ProcessGroup(tp=world_size),
+>>> shard_spec = ShardSpec(process_group=ProcessGroup(tp=world_size),
>>> dims=[0],
>>> num_partitions=[world_size])
>>> tensor_spec = ColoTensorSpec(pg, shard_spec)
>>> colo_t2 = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)
2. use static method from_torch_tensor
->>> colo_t = ColoTensor.from_torch_tensor(torch.randn(2,3), spec = ColoTensorSpec(pg, distspec.replicate()))
+>>> colo_t = ColoTensor.from_torch_tensor(torch.randn(2,3), spec = ColoTensorSpec(pg, ReplicaSpec()))
"""
def __new__(cls, data: torch.Tensor, spec: ColoTensorSpec) -> 'ColoTensor':
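
The docstring above covers two construction paths: direct init and the from_torch_tensor staticmethod. A sketch of both with the new spec names, assuming a launched distributed context with 4 tensor-parallel ranks; the ShardSpec keyword arguments follow the docstring and are an assumption about its exact signature:

import torch
from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup, ShardSpec, ReplicaSpec

world_size = 4   # assumed tensor-parallel size for illustration
# 1. direct init with a replicated spec
colo_t1 = ColoTensor(torch.randn(2, 3), spec=ColoTensorSpec(ProcessGroup(), ReplicaSpec()))
# 2. from_torch_tensor with a sharded spec; the tensor passed in is one shard of the global tensor
pg = ProcessGroup(tp=world_size)
shard_spec = ShardSpec(dims=[0], num_partitions=[world_size])
colo_t2 = ColoTensor.from_torch_tensor(torch.randn(2, 3), ColoTensorSpec(pg, shard_spec))
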
@@ -85,7 +85,7 @@ class ColoTensor(torch.Tensor):
# If spec is not set, use a DP process group and a replicated dist spec
if spec is None:
self.has_initialized = False
-self.dist_spec = distspec.replicate()
+self.dist_spec = ReplicaSpec()
self.compute_spec = None
self.process_group = ProcessGroup()
else:
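
The branch above is the fallback: when no ColoTensorSpec is given, the tensor gets a default ProcessGroup() and a ReplicaSpec() dist spec. A sketch of what that means for callers, using the from_torch_tensor staticmethod shown in the next hunk (behavior read off this diff, not verified against other versions):

import torch
from colossalai.tensor import ColoTensor

# spec omitted: per the branch above, dist_spec falls back to ReplicaSpec()
# and process_group to the default ProcessGroup()
t = ColoTensor.from_torch_tensor(torch.randn(2, 3))
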
@@ -194,13 +194,14 @@ class ColoTensor(torch.Tensor):
"""to_replicate_
an inline member function, converting dist spec of the tensor to REPLICATE
"""
-self._redistribute(dist_spec=distspec.replicate())
+self._redistribute(dist_spec=ReplicaSpec())
def to_replicate(self) -> 'ColoTensor':
"""to_replicate
converting dist spec of the tensor to REPLICATE
"""
-return self.redistribute(distspec.replicate())
+return self.redistribute(ReplicaSpec())
@staticmethod
def from_torch_tensor(tensor: torch.Tensor, spec: Optional[ColoTensorSpec] = None) -> 'ColoTensor':
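
Both redistribution helpers now construct ReplicaSpec() directly. A sketch separating the out-of-place and in-place variants, assuming a launched distributed context:

import torch
from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup, ReplicaSpec

colo_t = ColoTensor(torch.randn(2, 3), spec=ColoTensorSpec(ProcessGroup(), ReplicaSpec()))

replicated = colo_t.to_replicate()   # out-of-place: returns a ColoTensor with a REPLICATE dist spec
colo_t.to_replicate_()               # in-place: rewrites colo_t's own dist spec via _redistribute
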
@@ -234,7 +235,7 @@ class ColoTensor(torch.Tensor):
"""
if self.is_replicate():
return super().view(*args)
-replicated_t = self.redistribute(dist_spec=distspec.replicate())
+replicated_t = self.redistribute(dist_spec=ReplicaSpec())
return replicated_t.view(*args)
def size_global(self, args: Optional[int] = None):
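
The view path above follows the same pattern: an already-replicated tensor defers straight to torch.Tensor.view, while a sharded one is first redistributed to ReplicaSpec() and the replicated copy is viewed. A sketch of the replicated case (the sharded case additionally needs a tensor-parallel launch):

import torch
from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup, ReplicaSpec

colo_t = ColoTensor(torch.randn(2, 3), spec=ColoTensorSpec(ProcessGroup(), ReplicaSpec()))
flat = colo_t.view(-1)   # is_replicate() is True, so this goes straight to torch.Tensor.view
# For a sharded tensor, view() would first call redistribute(dist_spec=ReplicaSpec())
# and then view the gathered, replicated tensor.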