[tensor] fix some unittests (#1234)

This commit is contained in:
Jiarui Fang
2022-07-08 14:18:30 +08:00
committed by GitHub
parent a45ddf2d5f
commit 3b500984b1
7 changed files with 27 additions and 11 deletions

View File

@@ -72,7 +72,7 @@ class ColoTensor(torch.Tensor):
def __init__(self, data: torch.Tensor, spec: Optional[ColoTensorSpec] = None) -> None:
# If not set spec, use a DP process group and replicate dist spec
if not spec:
if spec is None:
self.has_initialized = False
self.dist_spec = distspec.replicate()
self.compute_spec = None
@@ -81,7 +81,10 @@ class ColoTensor(torch.Tensor):
self.has_initialized = True
self.dist_spec = spec.dist_attr
self.compute_spec = spec.compute_attr
self.process_group = spec.pg
if spec.pg is None:
self.process_group = ProcessGroup()
else:
self.process_group = spec.pg
self._type = TensorType.NONMODEL
self._graph_node = None
@@ -125,7 +128,7 @@ class ColoTensor(torch.Tensor):
dist_spec (_DistSpec): target dist spec.
"""
assert isinstance(dist_spec, _DistSpec)
assert self.process_group
assert self.process_group is not None
self._convert_to_dist_spec(dist_spec)
def set_tensor_spec(self, dist_spec, compute_spec):