[hotfix] fix shape error in backward when using ColoTensor (#1298)

This commit is contained in:
HELSON
2022-07-13 23:06:12 +08:00
committed by GitHub
parent f83c4d6597
commit 260a55804a
4 changed files with 26 additions and 56 deletions

View File

@@ -204,12 +204,14 @@ class ColoTensor(torch.Tensor):
ColoTensor: a redistributed colotensor
"""
if pg is not None and pg != self.get_process_group():
print('here _redistribute')
# if the pg is not equal, convert the current tensor to replicated
self._redistribute(ReplicaSpec())
self.process_group = pg
ret = DistSpecManager.handle_trans_spec(self, self.dist_spec, dist_spec, self.process_group)
return ColoTensor.from_torch_tensor(ret, ColoTensorSpec(self.process_group, dist_attr=dist_spec))
handled = self.redistribute(ReplicaSpec())
else:
handled = self
pg = self.process_group
ret = DistSpecManager.handle_trans_spec(handled, handled.dist_spec, dist_spec, pg)
return ColoTensor.from_torch_tensor(ret, ColoTensorSpec(pg=pg, dist_attr=dist_spec))
def to_replicate_(self):
"""to_replicate_