mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-01 09:07:51 +00:00
[tensor] revert local view back (#1178)
This commit is contained in:
@@ -101,13 +101,3 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
|
||||
# TODO(jzy) we don't support object reflection now.
|
||||
# distspec cannot be pickled or rebuilt because it's tightly connected to runtime attribute `process_group`.
|
||||
raise NotImplementedError
|
||||
|
||||
#### the ColoParameter should use the torch.Tensor's builtin methodes ###
|
||||
|
||||
def view(self, *args) -> 'ColoTensor':
|
||||
return super().view_base(*args)
|
||||
|
||||
def size(self, *args, **kwargs) -> torch.Size:
|
||||
# import inspect
|
||||
# print(*['{:40}| {}:{}\n'.format(x.function, x.filename, x.lineno) for x in inspect.stack()])
|
||||
return super().size_base(*args, **kwargs)
|
||||
|
@@ -147,13 +147,13 @@ class ColoTensor(torch.Tensor):
|
||||
|
||||
##### override builtin functions which must use tensor in replicate placement ####
|
||||
|
||||
def view_base(self, *args) -> 'ColoTensor':
|
||||
def view_local(self, *args) -> 'ColoTensor':
|
||||
return super().view(*args)
|
||||
|
||||
def size_base(self, *args, **kwargs) -> torch.Size:
|
||||
def size_local(self, *args, **kwargs) -> torch.Size:
|
||||
return super().size(*args, **kwargs)
|
||||
|
||||
def view(self, *args) -> 'ColoTensor':
|
||||
def view_global(self, *args) -> 'ColoTensor':
|
||||
"""override the torch buildin view()
|
||||
the args passed in must be in a replicate placement.
|
||||
Returns:
|
||||
@@ -167,7 +167,7 @@ class ColoTensor(torch.Tensor):
|
||||
self._tensor_spec.dist_spec = distspec.replicate()
|
||||
return super().view(*args)
|
||||
|
||||
def size(self, args: Optional[int] = None):
|
||||
def size_global(self, args: Optional[int] = None):
|
||||
"""override the torch buildin size()
|
||||
the shape passed in must be in a replicate placement.
|
||||
Returns:
|
||||
|
Reference in New Issue
Block a user