mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 11:02:05 +00:00
[tensor] fix some unittests (#1234)
This commit is contained in:
@@ -72,7 +72,7 @@ class ColoTensor(torch.Tensor):

     def __init__(self, data: torch.Tensor, spec: Optional[ColoTensorSpec] = None) -> None:
         # If not set spec, use a DP process group and replicate dist spec
-        if not spec:
+        if spec is None:
             self.has_initialized = False
             self.dist_spec = distspec.replicate()
             self.compute_spec = None
@@ -81,7 +81,10 @@ class ColoTensor(torch.Tensor):
             self.has_initialized = True
             self.dist_spec = spec.dist_attr
             self.compute_spec = spec.compute_attr
-            self.process_group = spec.pg
+            if spec.pg is None:
+                self.process_group = ProcessGroup()
+            else:
+                self.process_group = spec.pg

         self._type = TensorType.NONMODEL
         self._graph_node = None
@@ -125,7 +128,7 @@ class ColoTensor(torch.Tensor):
             dist_spec (_DistSpec): target dist spec.
         """
         assert isinstance(dist_spec, _DistSpec)
-        assert self.process_group
+        assert self.process_group is not None
         self._convert_to_dist_spec(dist_spec)

     def set_tensor_spec(self, dist_spec, compute_spec):
Reference in New Issue
Block a user