mirror of https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-04 10:34:41 +00:00
[NFC] polish doc style for ColoTensor (#1457)
@@ -18,7 +18,7 @@ def _get_my_nowrap_functions() -> Set[Callable]:
Tensor._base.__get__,
Tensor.grad.__get__,
Tensor._grad.__get__,
Tensor.data.__get__,  # make .data return torch.Tensor rather than ColoTensor
}
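The entries above are accessor functions that are deliberately left unwrapped, so reading them yields a plain torch.Tensor instead of a ColoTensor. A minimal sketch of the intended effect (the tensor construction is an assumption, mirroring the class docstring below):

>>> colo_t = ColoTensor.from_torch_tensor(torch.randn(2, 3))
>>> type(colo_t)        # ColoTensor
>>> type(colo_t.data)   # torch.Tensor, because Tensor.data.__get__ is in the no-wrap set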
@@ -51,31 +51,31 @@ def _get_spec_from_args(args, kwargs) -> ColoTensorSpec:
class ColoTensor(torch.Tensor):
""" Data Structure for Tensor in Colossal-AI. It is a subclass of torch.Tensor.
|
||||
|
||||
The Colotensor can be initialized with a PyTorch tensor in the following ways.
|
||||
|
||||
>>> pg = ProcessGroup()
|
||||
>>> colo_t1 = ColoTensor(torch.randn(2,3), spec = ColoTensorSpec(pg, ReplicaSpec())
|
||||
>>> # The tensor passed in is a tensor after sharding but not a global tensor.
|
||||
>>> shard_spec = ShardSpec(process_group=ProcessGroup(tp=world_size),
|
||||
>>> dims=[0],
|
||||
>>> num_partitions=[world_size])
|
||||
>>> tensor_spec = ColoTensorSpec(pg, shard_spec)
|
||||
>>> colo_t2 = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)
|
||||
|
||||
Args:
|
||||
data (torch.Tensor): a torch tensor used as the payload the colotensor.
|
||||
spec (ColoTensorSpec, optional): the tensor spec of initialization. Defaults to ColoTensorSpec(ReplicaSpec()).
|
||||
|
||||
The signature of the function has to be consistent with the __new__ except for the 1st arg.
|
||||
The class should be initialized with a torch tensor in the following ways.
|
||||
1. directly init.
|
||||
>>> pg = ProcessGroup()
|
||||
>>> colo_t1 = ColoTensor(torch.randn(2,3), spec = ColoTensorSpec(pg, ReplicaSpec())
|
||||
>>> # If initializaed in a shard model, the tensor passed in is one shard of the global tensor.
|
||||
>>> shard_spec = ShardSpec(process_group=ProcessGroup(tp=world_size),
|
||||
>>> dims=[0],
|
||||
>>> num_partitions=[world_size])
|
||||
>>> tensor_spec = ColoTensorSpec(pg, shard_spec)
|
||||
>>> colo_t2 = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)
|
||||
2. use static method from_torch_tensor
|
||||
>>> colo_t = ColoTensor.from_torch_tensor(torch.randn(2,3), spec = ColoTensorSpec(pg, ReplicaSpec())
|
||||
"""
def __new__(cls, data: torch.Tensor, spec: ColoTensorSpec) -> 'ColoTensor':
"""__new__

The signature of __new__ has to be consistent with that of torch.Tensor.

Args:
    data (torch.Tensor): a torch tensor used as the payload of the ColoTensor.
    spec (ColoTensorSpec, optional): the tensor spec of initialization.

Returns:
    ColoTensor: a ColoTensor wrapping the data.
"""
@@ -115,12 +115,10 @@ class ColoTensor(torch.Tensor):
"""set_process_group

Change the pg of the ColoTensor. Note that the valid use cases are limited:
it is only valid when the existing pg is data parallel (DP) and the dist spec is REPLICATE.

Args:
    pg (ProcessGroup): target pg

Raises:
    RuntimeError:
"""
assert isinstance(pg, ProcessGroup), f"pg as type {type(pg)} is invalid"
# if the new pg is the same as the old pg, just return
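A hedged usage sketch of set_process_group under the valid use case described above (world_size and the replicated colo_t are assumptions):

>>> # colo_t is replicated and its current process group is data-parallel only
>>> colo_t.set_process_group(ProcessGroup(tp=world_size))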
@@ -139,6 +137,7 @@ class ColoTensor(torch.Tensor):
def set_dist_spec(self, dist_spec: _DistSpec):
"""set_dist_spec

Set the dist spec and change the payload accordingly.

Args:
    dist_spec (_DistSpec): target dist spec.
"""
@@ -182,6 +181,7 @@ class ColoTensor(torch.Tensor):
"""_redistribute

Note that this function does not handle the logic of backward propagation!
It is used as an internal function during model tensor initialization.

Args:
    dist_spec (_DistSpec): the target dist spec.
"""
@@ -193,12 +193,14 @@ class ColoTensor(torch.Tensor):
def redistribute(self, dist_spec: _DistSpec, pg: Optional[ProcessGroup] = None) -> 'ColoTensor':
"""redistribute

Redistribute the tensor among processes. The rules are as follows:

1. If pg is None, redistribute the tensor payload among the TP process group, and keep the
   DP process group unchanged.

2. If pg is not None and not equal to the current process group:
   First, convert the tensor to be replicated among the TP process group.
   Second, reset the process group to the new pg.
   Third, convert the tensor (now replicated among the TP process group) to the new dist_spec.

Args:
    dist_spec (_DistSpec): the new dist spec.
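A hedged sketch of the two rules above (world_size, new_world_size and new_pg are assumptions):

>>> # rule 1: pg is None, so only the TP placement changes
>>> sharded = colo_t.redistribute(ShardSpec(dims=[0], num_partitions=[world_size]))
>>> # rule 2: hand the tensor over to a different process group while changing its dist spec
>>> moved = colo_t.redistribute(ShardSpec(dims=[0], num_partitions=[new_world_size]), pg=new_pg)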
@@ -219,18 +221,31 @@ class ColoTensor(torch.Tensor):
def to_replicate_(self):
"""to_replicate_

An in-place member function, converting the dist spec of the tensor to REPLICATE.
"""
self._redistribute(dist_spec=ReplicaSpec())
def to_replicate(self) -> 'ColoTensor':
"""to_replicate

Convert the dist spec of the tensor to ReplicaSpec().
"""
return self.redistribute(ReplicaSpec())
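The two variants differ only in whether they mutate the tensor; a brief sketch:

>>> replicated = colo_t.to_replicate()   # returns a new, replicated ColoTensor
>>> colo_t.to_replicate_()               # converts colo_t to a replicated placement in place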
@staticmethod
def from_torch_tensor(tensor: torch.Tensor, spec: Optional[ColoTensorSpec] = None) -> 'ColoTensor':
"""from_torch_tensor

A static method that builds a `ColoTensor` from a PyTorch tensor.

Args:
    tensor (torch.Tensor): the pytorch tensor, which is a local tensor for this rank, not a global tensor.
    spec (Optional[ColoTensorSpec], optional): tensor spec. Defaults to None.

Returns:
    ColoTensor: a ColoTensor
"""
tensor = tensor.as_subclass(ColoTensor)
tensor.__init__(tensor, spec=spec)
return tensor
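from_torch_tensor re-tags the existing tensor as a ColoTensor via torch.Tensor.as_subclass (no copy of the storage) and then runs __init__ to attach the spec. A hedged usage sketch; pg and shard_spec are taken from the class docstring example:

>>> local_shard = torch.randn(2, 3)   # the shard held by this rank, not the global tensor
>>> colo_t = ColoTensor.from_torch_tensor(local_shard, ColoTensorSpec(pg, shard_spec))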
@@ -252,10 +267,13 @@ class ColoTensor(torch.Tensor):
return super().size(*args)
def size_global(self, *args) -> torch.Size:
"""size_global

Override the torch built-in size(). The returned shape is the global tensor shape,
i.e. the shape the tensor would have in a replicated placement.

Returns:
    torch.Size: the global tensor shape
"""
if self.is_replicate():
    return self.size_local(*args)
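A hedged sketch contrasting the local and global shapes for a tensor sharded along dim 0 (world_size is an assumption):

>>> spec = ColoTensorSpec(pg, ShardSpec(dims=[0], num_partitions=[world_size]))
>>> colo_t = ColoTensor.from_torch_tensor(torch.randn(4, 3), spec)   # each rank holds a (4, 3) shard
>>> colo_t.size_local()    # torch.Size([4, 3])
>>> colo_t.size_global()   # torch.Size([4 * world_size, 3]), the logical global shape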