mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-04 02:26:51 +00:00
[ColoTensor] rename APIs and add output_replicate to ComputeSpec (#1168)
This commit is contained in:
@@ -57,15 +57,15 @@ class ColoTensor(torch.Tensor):
|
||||
self._graph_node = None
|
||||
|
||||
@property
|
||||
def spec(self) -> TensorSpec:
|
||||
def tensor_spec(self) -> TensorSpec:
|
||||
return self._tensor_spec
|
||||
|
||||
def set_spec(self, spec: TensorSpec) -> None:
|
||||
def set_tensor_spec(self, spec: TensorSpec) -> None:
|
||||
spec = copy(spec)
|
||||
self._convert_to_dist_spec(spec.dist_spec)
|
||||
self._tensor_spec = spec
|
||||
|
||||
def has_spec(self) -> bool:
|
||||
def has_compute_spec(self) -> bool:
|
||||
return self._tensor_spec.compute_spec is not None
|
||||
|
||||
def is_model_data(self) -> bool:
|
||||
@@ -100,27 +100,27 @@ class ColoTensor(torch.Tensor):
|
||||
dist_spec (_DistSpec): the target dist. spec.
|
||||
"""
|
||||
with DistSpecManager.no_grad():
|
||||
self.data = DistSpecManager.handle_trans_spec(self, self.spec.dist_spec, dist_spec)
|
||||
self.data = DistSpecManager.handle_trans_spec(self, self.tensor_spec.dist_spec, dist_spec)
|
||||
self._tensor_spec.dist_spec = dist_spec
|
||||
|
||||
def convert_to_dist_spec(self, dist_spec: _DistSpec) -> 'ColoTensor':
|
||||
tensor_spec = copy(self._tensor_spec)
|
||||
tensor_spec.dist_spec = dist_spec
|
||||
ret = DistSpecManager.handle_trans_spec(self, self.spec.dist_spec, dist_spec)
|
||||
ret = DistSpecManager.handle_trans_spec(self, self.tensor_spec.dist_spec, dist_spec)
|
||||
return ColoTensor.from_torch_tensor(ret, tensor_spec)
|
||||
|
||||
def to_replicate_(self):
|
||||
"""to_replicate_
|
||||
an in-place member function, converting dist spec of the tensor to REPLICATE
|
||||
"""
|
||||
self.data = DistSpecManager.handle_trans_spec(self, self.spec.dist_spec, distspec.replicate())
|
||||
self.data = DistSpecManager.handle_trans_spec(self, self.tensor_spec.dist_spec, distspec.replicate())
|
||||
self._tensor_spec.dist_spec = distspec.replicate()
|
||||
|
||||
def to_replicate(self) -> 'ColoTensor':
|
||||
"""to_replicate
|
||||
converting dist spec of the tensor to REPLICATE
|
||||
"""
|
||||
return self.convert_to_dist_spec(distspec.replicate(self.spec.get_process_group()))
|
||||
return self.convert_to_dist_spec(distspec.replicate(self.tensor_spec.get_process_group()))
|
||||
|
||||
@staticmethod
|
||||
def from_torch_tensor(tensor: torch.Tensor, spec: TensorSpec = TensorSpec(distspec.replicate())) -> 'ColoTensor':
|
||||
@@ -134,16 +134,6 @@ class ColoTensor(torch.Tensor):
|
||||
else:
|
||||
with torch._C.DisableTorchFunction():
|
||||
data = self.data.clone()
|
||||
tensor = ColoTensor(data, spec=copy(self.spec))
|
||||
tensor = ColoTensor(data, spec=copy(self.tensor_spec))
|
||||
memo[id(self)] = tensor
|
||||
return tensor
|
||||
|
||||
# TODO(jiaruifang) a patch for gpt test.
|
||||
# We need to override the member function must operate on a replicated tensor
|
||||
# def view(self, *args, **kwargs):
|
||||
# self.data = DistSpecManager.handle_trans_spec(self,
|
||||
# self.spec.dist_spec,
|
||||
# distspec.replicate(self.spec.get_process_group()))
|
||||
# # self._tensor_spec.dist_spec = distspec.replicate(self.spec.get_process_group())
|
||||
# self.data.view(*args, **kwargs)
|
||||
# return ColoTensor.from_torch_tensor(self.data)
|
||||
return tensor
|
Reference in New Issue
Block a user