mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 11:02:05 +00:00
[tensor] torch function return colotensor (#1229)
This commit is contained in:
@@ -11,12 +11,30 @@ from colossalai.tensor.distspec import _DistSpec, DistPlacementPattern
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def _check_output(output):
|
||||
if not isinstance(output, torch.Tensor):
|
||||
raise RuntimeError
|
||||
def _convert_output(output, pg: ProcessGroup):
    """Wrap plain torch tensors found in ``output`` as ColoTensors bound to ``pg``.

    Args:
        output: result of a torch function; may be a tensor, a list/tuple
            (possibly nested) of results, or any other object.
        pg: process group handed to ``ColoTensorSpec`` for every tensor
            that gets converted.

    Returns:
        A ``ColoTensor`` when ``output`` is exactly a ``torch.Tensor``; a
        container of the same type with every element converted recursively
        when ``output`` is a list or tuple; otherwise ``output`` unchanged.
    """
    # Exact type match: instances of torch.Tensor subclasses do not take this
    # branch and are returned unchanged by the final return.
    if type(output) is torch.Tensor:
        return ColoTensor.from_torch_tensor(output, ColoTensorSpec(pg))
    if isinstance(output, (list, tuple)):
        # Rebuild the same container type from the recursively converted items.
        return type(output)(_convert_output(o, pg) for o in output)
    return output
|
||||
|
||||
|
||||
def _scan_for_pg_from_args(args, kwargs) -> Optional[ProcessGroup]:
    """Find the process group of the first ColoTensor among the call arguments.

    Positional arguments are searched first, then keyword argument values.
    Lists and tuples are searched recursively.

    Args:
        args: positional arguments of the intercepted call (any iterable).
        kwargs: keyword arguments of the intercepted call; may be empty or
            ``None``.

    Returns:
        The result of ``get_process_group()`` on the first ``ColoTensor``
        encountered, or ``None`` when no ``ColoTensor`` is present.
    """
    # Iterate over the keyword *values*; iterating the dict itself yields only
    # the key strings and would never see the tensors.
    kwarg_values = kwargs.values() if kwargs else ()
    for elem in (*args, *kwarg_values):
        if isinstance(elem, ColoTensor):
            return elem.get_process_group()
        if isinstance(elem, (list, tuple)):
            pg = _scan_for_pg_from_args(elem, {})
            if pg is not None:
                return pg
    return None
|
||||
|
||||
|
||||
class ColoTensor(torch.Tensor):
|
||||
@@ -108,6 +126,7 @@ class ColoTensor(torch.Tensor):
|
||||
dist_spec (_DistSpec): target dist spec.
|
||||
"""
|
||||
assert isinstance(dist_spec, _DistSpec)
|
||||
assert self.process_group
|
||||
self._convert_to_dist_spec(dist_spec)
|
||||
|
||||
def set_tensor_spec(self, dist_spec, compute_spec):
|
||||
@@ -136,12 +155,11 @@ class ColoTensor(torch.Tensor):
|
||||
if func in get_default_nowrap_functions():
|
||||
return ret
|
||||
else:
|
||||
# TODO(jiaruifang): it is the parallel op's duty to convert output activations
|
||||
return ret
|
||||
# return _check_output(ret)
|
||||
pg = _scan_for_pg_from_args(args, kwargs)
|
||||
return _convert_output(ret, pg)
|
||||
|
||||
def __repr__(self):
    """Return the tensor repr followed by its dist spec and process group.

    The string is the parent class ``__repr__`` prefixed with
    ``'ColoTensor: '``, then ``self.dist_spec`` and ``self.process_group``
    on their own lines.
    """
    return f'ColoTensor: {super().__repr__()}\n dist spec: {self.dist_spec}\n process group: {self.process_group}'
|
||||
|
||||
def _convert_to_dist_spec(self, dist_spec: _DistSpec) -> None:
|
||||
"""_convert_to_dist_spec
|
||||
|
Reference in New Issue
Block a user