Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-04 02:26:51 +00:00
[colotensor] add Tensor.view op and its unit test (#1343)
[colotensor] add megatron initialization for gpt2
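Before the diff, a quick orientation. Nothing in the snippet below is ColossalAI code; the shapes and the 2-way split are invented for illustration. It only shows, in plain PyTorch, why a sharded tensor needs separate local/global notions of size and view: each rank stores only a slice, so the local size differs from the global one, and a global view is only safe once the tensor is replicated (gathered).

import torch

# A "global" parameter as it would look if fully replicated on every rank.
global_t = torch.arange(4 * 6).reshape(4, 6)

# Megatron-style column sharding: split dim 1 across 2 ranks (simulated in one process).
shards = torch.chunk(global_t, chunks=2, dim=1)
local_t = shards[0]                   # what "rank 0" actually stores

print(local_t.size())                 # torch.Size([4, 3]) -- the local size
print(global_t.size())                # torch.Size([4, 6]) -- what a size_global-style API must report

# A global view is only meaningful on the replicated tensor: flattening the
# local shard yields a different element order than flattening the global one.
print(global_t.view(-1)[:6])          # tensor([0, 1, 2, 3, 4, 5])
print(local_t.reshape(-1)[:6])        # tensor([0, 1, 2, 6, 7, 8])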
@@ -22,28 +22,30 @@ def _get_my_nowrap_functions() -> Set[Callable]:
     }


-def _convert_output(output, pg: ProcessGroup):
+def _convert_output(output, colo_spec: ColoTensorSpec):
     if type(output) == torch.Tensor:
-        return ColoTensor.from_torch_tensor(output, ColoTensorSpec(pg))
+        return ColoTensor.from_torch_tensor(output, colo_spec)
     elif isinstance(output, (list, tuple)):
-        return type(output)(_convert_output(o, pg) for o in output)
+        return type(output)(_convert_output(o, colo_spec) for o in output)
     else:
         return output


-def _scan_for_pg_from_args(args, kwargs) -> ProcessGroup:
+def _get_spec_from_args(args, kwargs) -> ColoTensorSpec:
     for elem in args:
         if isinstance(elem, ColoTensor):
             pg = elem.get_process_group()
-            return pg
+            dp = elem.dist_spec
+            return ColoTensorSpec(pg, dp)
         elif isinstance(elem, (list, tuple)):
-            pg = _scan_for_pg_from_args(elem, {})
-            if pg is not None:
-                return pg
+            spec = _get_spec_from_args(elem, {})
+            if spec is not None:
+                return spec
     for k, v in kwargs.items():
         if isinstance(v, ColoTensor):
             pg = v.get_process_group()
-            return pg
+            dp = v.dist_spec
+            return ColoTensorSpec(pg, dp)
     return None
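The point of this hunk: the argument scan now returns a full ColoTensorSpec (process group plus distribution spec) rather than a bare ProcessGroup, so outputs re-wrapped by _convert_output inherit the layout of the inputs instead of only their process group. Below is a self-contained sketch of that scan-and-rewrap pattern with stand-in classes; Spec, Wrapped, and the string values are invented for the example and are not ColossalAI API.

from dataclasses import dataclass
from typing import Any, Optional

@dataclass
class Spec:                      # stand-in for ColoTensorSpec: group + layout
    pg: str
    dist_spec: str

class Wrapped:                   # stand-in for ColoTensor
    def __init__(self, value, spec: Spec):
        self.value, self.spec = value, spec

def get_spec_from_args(args, kwargs) -> Optional[Spec]:
    # Scan positional args (recursing into lists/tuples), then kwargs,
    # and return the spec of the first wrapped value found.
    for elem in args:
        if isinstance(elem, Wrapped):
            return Spec(elem.spec.pg, elem.spec.dist_spec)
        if isinstance(elem, (list, tuple)):
            spec = get_spec_from_args(elem, {})
            if spec is not None:
                return spec
    for v in kwargs.values():
        if isinstance(v, Wrapped):
            return Spec(v.spec.pg, v.spec.dist_spec)
    return None

def convert_output(output: Any, spec: Spec) -> Any:
    # Re-wrap raw outputs (anything not already Wrapped) with the spec
    # recovered from the inputs, recursing into lists/tuples.
    if isinstance(output, (list, tuple)):
        return type(output)(convert_output(o, spec) for o in output)
    if not isinstance(output, Wrapped):
        return Wrapped(output, spec)
    return output

x = Wrapped(3, Spec(pg="tp_group_0", dist_spec="shard(dim=0, 2 parts)"))
out = convert_output(x.value + 1, get_spec_from_args((x,), {}))
print(out.value, out.spec)       # 4 Spec(pg='tp_group_0', dist_spec='shard(dim=0, 2 parts)')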
@@ -170,11 +172,11 @@ class ColoTensor(torch.Tensor):
             if func in _get_my_nowrap_functions():
                 return ret
             else:
-                pg = _scan_for_pg_from_args(args, kwargs)
-                return _convert_output(ret, pg)
+                colo_spec = _get_spec_from_args(args, kwargs)
+                return _convert_output(ret, colo_spec)

     def __repr__(self):
-        return f'ColoTensor:\n{super().__repr__()}\n{self.dist_spec}\n{self.process_group}'
+        return f'ColoTensor:\n{super().__repr__()}\n{self.dist_spec}\n{self.process_group}\n{self.compute_spec}'

     def _redistribute(self, dist_spec: _DistSpec) -> None:
         """_redistribute
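This hunk wires the new spec scan into ColoTensor.__torch_function__, the hook PyTorch calls for every torch operation on the subclass. A minimal, runnable illustration of that mechanism follows; MyTensor and its tag attribute are invented stand-ins (carrying a string where ColoTensor carries a spec) and are not ColossalAI code.

import torch

class MyTensor(torch.Tensor):
    """Toy subclass: every op result is re-wrapped and keeps a 'tag'."""

    @staticmethod
    def __new__(cls, data, tag=None):
        t = torch.Tensor._make_subclass(cls, torch.as_tensor(data, dtype=torch.float32))
        t.tag = tag
        return t

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # Recover metadata from the inputs (the role _get_spec_from_args plays).
        tag = next((getattr(a, "tag", None) for a in args if isinstance(a, MyTensor)), None)
        with torch._C.DisableTorchFunction():
            ret = func(*args, **kwargs)
            # Re-wrap tensor outputs so the metadata survives (the role _convert_output plays).
            if isinstance(ret, torch.Tensor):
                ret = ret.as_subclass(MyTensor)
                ret.tag = tag
            return ret

x = MyTensor([1.0, 2.0, 3.0], tag="col-parallel")
y = x * 2 + 1
print(type(y).__name__, y.tag)   # MyTensor col-parallel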
@@ -243,50 +245,32 @@ class ColoTensor(torch.Tensor):
         memo[id(self)] = tensor
         return tensor

-    ##### override builtin functions which must use tensor in replicate placement ####
+    # override builtin functions which must use tensor in replicate placement #

-    def view_local(self, *args) -> 'ColoTensor':
-        return super().view(*args)
+    def size_local(self, *args) -> torch.Size:
+        with torch._C.DisableTorchFunction():
+            return super().size(*args)

-    def size_local(self, *args, **kwargs) -> torch.Size:
-        return super().size(*args, **kwargs)
-
-    def view_global(self, *args) -> 'ColoTensor':
-        """override the torch buildin view()
-        the args passed in must be in a replicate placement.
-        Returns:
-            ColoTensor: a tensor after viewed.
-        """
-        if self.is_replicate():
-            return super().view(*args)
-        replicated_t = self.redistribute(dist_spec=ReplicaSpec())
-        return replicated_t.view(*args)
-
-    def size_global(self, args: Optional[int] = None) -> torch.Size:
+    def size_global(self, *args) -> torch.Size:
         """override the torch buildin size()
         the shape passed in must be in a replicate placement.
         Returns:
             ColoTensor: a tensor after viewed.
         """
         if self.is_replicate():
-            if args is not None:
-                return super().size(args)
-            else:
-                return super().size()
-
+            return self.size_local(*args)
         spec = self.dist_spec
         dims = spec.dims
         num_partitions = spec.num_partitions
-        # import inspect
-        # print(*['{:40}| {}:{}\n'.format(x.function, x.filename, x.lineno) for x in inspect.stack()])

-        size_list = list(super().size())
+        size_list = list(self.size_local())
         for dim, num_partition in zip(dims, num_partitions):
             size_list[dim] *= num_partition
-        if args is not None:
-            return size_list[args]
-        else:
+        if args == ():
             return torch.Size(size_list)
+        else:
+            return size_list[args[0]]

     # Some API for dist spec check
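The size_global branch above reconstructs the global shape from the local shard's shape and the shard spec by scaling each sharded dimension by its partition count. A standalone rendering of just that arithmetic follows; the function name and the example shapes are illustrative, not part of the ColossalAI API.

import torch

def global_size(local_size, dims, num_partitions, *args):
    # For each sharded dim, the global extent is the local extent times the
    # number of partitions on that dim. `args` follows torch.Tensor.size():
    # empty -> full torch.Size, one int -> the extent of that dim.
    size_list = list(local_size)
    for dim, num_partition in zip(dims, num_partitions):
        size_list[dim] *= num_partition
    if args == ():
        return torch.Size(size_list)
    return size_list[args[0]]

# A (1024, 64) local shard, sharded 4 ways along dim 0 -> global (4096, 64).
local = torch.empty(1024, 64)
print(global_size(local.size(), dims=[0], num_partitions=[4]))   # torch.Size([4096, 64])
print(global_size(local.size(), [0], [4], 0))                     # 4096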