[refactor] move process group from _DistSpec to ColoTensor. (#1203)

This commit is contained in:
Jiarui Fang
2022-07-06 16:15:16 +08:00
committed by GitHub
parent 5da87ce35d
commit ae7d3f4927
34 changed files with 452 additions and 367 deletions

View File

@@ -3,6 +3,7 @@ from contextlib import contextmanager
from abc import ABC, abstractmethod
from typing import List, Tuple, Any
from colossalai.tensor.colo_tensor import ColoTensor
from colossalai.tensor import ColoTensorSpec
class ParamOpHook(ABC):
@@ -129,7 +130,7 @@ def _get_colo_tensors_info(*args) -> list:
info = []
for arg in args:
if isinstance(arg, ColoTensor):
info.append((arg.__class__, arg.tensor_spec))
info.append((arg.__class__, ColoTensorSpec(arg.get_process_group(), arg.dist_spec, arg.compute_spec)))
else:
info.append(None)
return info