ColossalAI/colossalai/legacy/nn/_ops/view.py
Hongxin Liu 554aa9592e
[legacy] move communication and nn to legacy and refactor logger (#4671)
* [legacy] move communication to legacy (#4640)

* [legacy] refactor logger and clean up legacy codes (#4654)

* [legacy] make logger independent to gpc

* [legacy] make optim independent to registry

* [legacy] move test engine to legacy

* [legacy] move nn to legacy (#4656)

* [legacy] move nn to legacy

* [checkpointio] fix save hf config

* [test] remove useledd rpc pp test

* [legacy] fix nn init

* [example] skip tutorial hybriad parallel example

* [devops] test doc check

* [devops] test doc check
2023-09-11 16:24:28 +08:00

97 lines
3.1 KiB
Python

import operator
from functools import reduce
from typing import Optional, Union
import torch
from colossalai.tensor import ColoTensor, ColoTensorSpec, ReplicaSpec
from colossalai.tensor.op_wrapper import colo_op_impl
def _all_int(my_iter):
return all(isinstance(i, int) for i in my_iter)
def _get_valid_shape(shape):
if isinstance(shape, list):
if _all_int(shape):
return tuple(shape)
else:
raise RuntimeError("expects type(int) but finds an other type")
elif isinstance(shape, tuple):
if _all_int(shape):
return shape
else:
return _get_valid_shape(shape[0])
else:
raise RuntimeError("expects an iterable array but finds '{}'".format(type(shape)))
def _shape_infer(org_sp, tgt_sp):
cnt = 0
pos = 0
for idx, dim in enumerate(tgt_sp):
if dim < -1:
raise RuntimeError("invalid shape dimension {}".format(dim))
elif dim == -1:
cnt += 1
pos = idx
if cnt > 1:
raise RuntimeError("only one dimension can be inferred")
org_prod = reduce(operator.mul, org_sp, 1)
tgt_prod = reduce(operator.mul, tgt_sp, 1)
if cnt == 0:
if org_prod != tgt_prod:
raise RuntimeError("shape '{}' is invalid for input of size {}".format(tgt_sp, org_prod))
else:
return tgt_sp
elif org_prod % tgt_prod != 0:
raise RuntimeError("shape '{}' is invalid for input of size {}".format(tgt_sp, org_prod))
infer_dim = -(org_prod // tgt_prod)
return tgt_sp[:pos] + (infer_dim,) + tgt_sp[pos + 1:]
@colo_op_impl(torch.Tensor.view)
def colo_view(self: ColoTensor, *shape) -> 'ColoTensor':
"""Handles ``__torch_function__`` dispatch for ``torch.Tensor.view``.
Changes the shape of the current tensor.
"""
assert isinstance(self, ColoTensor)
# apply original `view` function for replicated colo tensors
if self.is_replicate():
return self.view(*shape)
cur_sp = self.size()
org_sp = self.size_global()
# parse the passed arguments
tgt_sp = _get_valid_shape(shape)
# get the correct shape from inference
inf_sp = _shape_infer(org_sp, tgt_sp)
if self.is_shard_1drow() and org_sp[0] == inf_sp[0]:
new_shape = (cur_sp[0],) + tgt_sp[1:]
res = self.view(*new_shape)
elif self.is_shard_1dcol() and org_sp[-1] == inf_sp[-1]:
new_shape = tgt_sp[:-1] + (cur_sp[-1],)
res = self.view(*new_shape)
else:
replicated_t = self.redistribute(dist_spec=ReplicaSpec())
return ColoTensor.from_torch_tensor(tensor=replicated_t.view(*shape),
spec=ColoTensorSpec(self.get_process_group()))
return ColoTensor.from_torch_tensor(tensor=res,
spec=ColoTensorSpec(pg=self.get_process_group(), dist_attr=self.dist_spec))
@colo_op_impl(torch.Tensor.size)
def colo_size(self: ColoTensor, dim: Optional[int] = None) -> Union[torch.Size, int]:
size = self.size_global()
if dim is None:
return size
else:
return size[dim]