mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 03:52:01 +00:00
[refactor] move process group from _DistSpec to ColoTensor. (#1203)
This commit is contained in:
@@ -3,6 +3,7 @@ from contextlib import contextmanager
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Tuple, Any
|
||||
from colossalai.tensor.colo_tensor import ColoTensor
|
||||
from colossalai.tensor import ColoTensorSpec
|
||||
|
||||
|
||||
class ParamOpHook(ABC):
|
||||
@@ -129,7 +130,7 @@ def _get_colo_tensors_info(*args) -> list:
|
||||
info = []
|
||||
for arg in args:
|
||||
if isinstance(arg, ColoTensor):
|
||||
info.append((arg.__class__, arg.tensor_spec))
|
||||
info.append((arg.__class__, ColoTensorSpec(arg.get_process_group(), arg.dist_spec, arg.compute_spec)))
|
||||
else:
|
||||
info.append(None)
|
||||
return info
|
||||
|
Reference in New Issue
Block a user