[refactor] remove gpc dependency in colotensor's _ops (#1189)

This commit is contained in:
Jiarui Fang
2022-07-04 18:54:37 +08:00
committed by GitHub
parent abf6a262dc
commit 060b917daf
33 changed files with 499 additions and 357 deletions

View File

@@ -78,6 +78,12 @@ class ColoTensor(torch.Tensor):
def is_model_data(self) -> bool:
return self._type == TensorType.MODEL
def get_process_group(self) -> 'ProcessGroup':
return self._tensor_spec.dist_spec.process_group
def get_tp_world_size(self) -> int:
return self._tensor_spec.dist_spec.process_group.tp_world_size()
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None: