Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-05 11:02:05 +00:00)
[hotfix] fix shape error in backward when using ColoTensor (#1298)
@@ -204,12 +204,14 @@ class ColoTensor(torch.Tensor):
             ColoTensor: a redistributed colotensor
         """
         if pg is not None and pg != self.get_process_group():
-            print('here _redistribute')
             # if the pg is not equal, convert the current tensor to replicated
-            self._redistribute(ReplicaSpec())
-            self.process_group = pg
-        ret = DistSpecManager.handle_trans_spec(self, self.dist_spec, dist_spec, self.process_group)
-        return ColoTensor.from_torch_tensor(ret, ColoTensorSpec(self.process_group, dist_attr=dist_spec))
+            handled = self.redistribute(ReplicaSpec())
+        else:
+            handled = self
+            pg = self.process_group
+
+        ret = DistSpecManager.handle_trans_spec(handled, handled.dist_spec, dist_spec, pg)
+        return ColoTensor.from_torch_tensor(ret, ColoTensorSpec(pg=pg, dist_attr=dist_spec))
 
     def to_replicate_(self):
         """to_replicate_
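In short, the old code mutated `self` in place (calling the private `_redistribute` and reassigning `self.process_group`) before computing the result, whereas the fixed code derives a replicated copy `handled` out of place and leaves the original tensor untouched; that side effect on `self` is presumably what produced the shape error in backward that the hotfix title refers to. Below is a minimal, library-free sketch of that control-flow change. The `Spec` and `Tensor` classes are hypothetical stand-ins, not ColossalAI APIs, and the real method additionally moves the payload via `DistSpecManager.handle_trans_spec`.

from dataclasses import dataclass, replace
from typing import Optional


@dataclass(frozen=True)
class Spec:
    """Hypothetical stand-in for ColoTensorSpec: process group + dist spec."""
    process_group: str
    dist_spec: str


@dataclass
class Tensor:
    """Hypothetical stand-in for ColoTensor, just enough to show the fix."""
    data: list
    spec: Spec

    def redistribute(self, dist_spec: str, pg: Optional[str] = None) -> "Tensor":
        if pg is not None and pg != self.spec.process_group:
            # New behaviour: build a replicated *copy* out of place;
            # self is never mutated (the old code reassigned self's spec here).
            handled = Tensor(list(self.data), replace(self.spec, dist_spec="REPLICATE"))
        else:
            handled = self
            pg = self.spec.process_group
        # Stand-in for DistSpecManager.handle_trans_spec(handled, ...):
        # the result is derived from `handled`, and `self` is left untouched.
        return Tensor(list(handled.data), Spec(process_group=pg, dist_spec=dist_spec))


t = Tensor(data=[1, 2, 3], spec=Spec("pg0", "SHARD"))
out = t.redistribute("REPLICATE", pg="pg1")
assert t.spec == Spec("pg0", "SHARD")           # original layout unchanged (the fix)
assert out.spec == Spec("pg1", "REPLICATE")     # result carries the new pg and spec

The key behavioural difference the sketch captures is that redistributing to a different process group no longer changes the source tensor's process group or dist spec as a side effect; only the returned tensor carries the new layout.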