[polish] polish ColoTensor and its submodules (#2537)

This commit is contained in:
HELSON
2023-02-03 11:44:10 +08:00
committed by GitHub
parent 51d4d6e718
commit 552183bb74
6 changed files with 75 additions and 65 deletions

View File

@@ -189,7 +189,12 @@ class ColoTensor(torch.Tensor):
return _convert_output(ret, colo_spec)
def __repr__(self):
return f'ColoTensor:\n{super().__repr__()}\n{self.dist_spec}\n{self.process_group}\n{self.compute_spec}'
output_list = [super(ColoTensor, self).__repr__()]
output_list.append(str(self.process_group))
output_list.append(str(self.dist_spec))
if self.compute_spec is not None:
output_list.append(str(self.compute_spec))
return "\n".join(output_list)
def _redistribute(self, dist_spec: _DistSpec) -> None:
"""_redistribute