[tensor] revert local view back (#1178)

This commit is contained in:
Jiarui Fang
2022-06-27 18:38:34 +08:00
committed by GitHub
parent 0dd4e2bbfb
commit 1b657f9ce1
4 changed files with 10 additions and 20 deletions

View File

@@ -101,13 +101,3 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
# TODO(jzy) we don't support object reflection now.
# distspec cannot be pickled or rebuilt because it's tightly connected to runtime attribute `process_group`.
raise NotImplementedError
#### the ColoParameter should use the torch.Tensor's builtin methodes ###
def view(self, *args) -> 'ColoTensor':
return super().view_base(*args)
def size(self, *args, **kwargs) -> torch.Size:
# import inspect
# print(*['{:40}| {}:{}\n'.format(x.function, x.filename, x.lineno) for x in inspect.stack()])
return super().size_base(*args, **kwargs)

View File

@@ -147,13 +147,13 @@ class ColoTensor(torch.Tensor):
##### override builtin functions which must use tensor in replicate placement ####
def view_base(self, *args) -> 'ColoTensor':
def view_local(self, *args) -> 'ColoTensor':
return super().view(*args)
def size_base(self, *args, **kwargs) -> torch.Size:
def size_local(self, *args, **kwargs) -> torch.Size:
return super().size(*args, **kwargs)
def view(self, *args) -> 'ColoTensor':
def view_global(self, *args) -> 'ColoTensor':
"""override the torch buildin view()
the args passed in must be in a replicate placement.
Returns:
@@ -167,7 +167,7 @@ class ColoTensor(torch.Tensor):
self._tensor_spec.dist_spec = distspec.replicate()
return super().view(*args)
def size(self, args: Optional[int] = None):
def size_global(self, args: Optional[int] = None):
"""override the torch buildin size()
the shape passed in must be in a replicate placement.
Returns: