[Tensor] overriding parameters() for Module using ColoTensor (#889)

Author: Jiarui Fang
Date: 2022-04-27 15:28:59 +08:00
Committed by: GitHub
Parent: daf59ff72e
Commit: 26c49639d8
3 changed files with 74 additions and 6 deletions


@@ -165,7 +165,12 @@ class ColoTensor(object):
         self._torch_tensor.backward(gradient=gradient, retain_graph=retain_graph)
 
     def __add__(self, o) -> "ColoTensor":
-        return ColoTensor.init_from_torch_tensor(self.torch_tensor() + o.torch_tensor())
+        if isinstance(o, ColoTensor):
+            return ColoTensor.init_from_torch_tensor(self.torch_tensor() + o.torch_tensor())
+        elif isinstance(o, torch.Tensor):
+            return ColoTensor.init_from_torch_tensor(self.torch_tensor() + o)
+        else:
+            raise TypeError(f'{type(o)} is not supported in ColoTensor __add__')
 
     def __truediv__(self, o) -> "ColoTensor":
         return ColoTensor.init_from_torch_tensor(self.torch_tensor() / o)
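With the widened __add__, a ColoTensor can now be added to another ColoTensor or to a plain torch.Tensor, and any other operand raises TypeError. A minimal usage sketch of the new behaviour, assuming only the init_from_torch_tensor and torch_tensor APIs visible in the diff above:

    import torch
    from colossalai.tensor import ColoTensor

    a = ColoTensor.init_from_torch_tensor(torch.ones(2, 2))
    b = torch.full((2, 2), 2.0)

    c = a + a    # ColoTensor + ColoTensor, the previously supported path
    d = a + b    # ColoTensor + torch.Tensor, newly supported
    # a + 1.0 now raises TypeError instead of failing inside the addition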


@@ -1,10 +1,68 @@
 from colossalai.utils.cuda import get_current_device
 from .utils import InsertPostInitMethodToModuleSubClasses
 
 import torch
 # from colossalai.logging import get_dist_logger
 from colossalai.tensor import ColoTensor
+import types
 
-# _orig_torch_empty = torch.empty
+from torch import nn
+from typing import Iterator, Tuple, Union
+
+
+def ColoModulize(module):
+    """
+    Replacing the parameters() and named_parameters() with our customized ones
+    """
+
+    def named_params_with_colotensor(
+        module: nn.Module,
+        prefix: str = '',
+        recurse: bool = True,
+    ) -> Iterator[Tuple[str, Union[nn.Parameter, ColoTensor]]]:
+        modules = module.named_modules(prefix=prefix) if recurse else [(prefix, module)]
+
+        memo = set()
+        for mod_prefix, mod in modules:
+            # find all ColoTensor params
+            for name, val in vars(mod).items():
+                if isinstance(val, ColoTensor) and val not in memo:
+                    memo.add(val)
+                    name = mod_prefix + ('.' if mod_prefix else '') + name
+                    yield name, val
+
+        # find all nn.Parameters
+        for name, val in module.old_named_parameters(recurse=recurse):
+            yield name, val
+
+    def fake_parameters(self, *args, **kargs):
+        for name, p in named_params_with_colotensor(self, *args, **kargs):
+            if isinstance(p, ColoTensor):
+                yield p.torch_tensor()
+            elif isinstance(p, torch.Tensor):
+                yield p
+
+    def fake_named_parameters(self, *args, **kargs):
+        for name, p in named_params_with_colotensor(self, *args, **kargs):
+            if isinstance(p, ColoTensor):
+                yield name, p.torch_tensor()
+            elif isinstance(p, torch.Tensor):
+                yield name, p
+
+    def colo_parameters(self, *args, **kargs):
+        for _, p in named_params_with_colotensor(self, *args, **kargs):
+            yield p
+
+    def colo_named_parameters(self, *args, **kargs):
+        for name, p in named_params_with_colotensor(self, *args, **kargs):
+            yield name, p
+
+    module.old_named_parameters = module.named_parameters
+    module.old_parameters = module.parameters
+
+    funcType = types.MethodType
+    module.parameters = funcType(fake_parameters, module)
+    module.named_parameters = funcType(fake_named_parameters, module)
+    module.colo_parameters = funcType(colo_parameters, module)
+    module.colo_named_parameters = funcType(colo_named_parameters, module)
+    module._colo_visited = True
+
 
 class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
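ColoModulize keeps the original bound methods as old_parameters/old_named_parameters, rebinds parameters()/named_parameters() to wrappers that unwrap ColoTensor payloads, and adds colo_parameters()/colo_named_parameters() that expose the ColoTensor objects themselves. A toy sketch of the same types.MethodType rebinding pattern on a plain nn.Linear; the tagged_named_parameters name is hypothetical and only illustrates the mechanism:

    import types
    from torch import nn

    def tagged_named_parameters(self, *args, **kwargs):
        # delegate to the saved original iterator, as fake_named_parameters does above
        for name, p in self.old_named_parameters(*args, **kwargs):
            yield 'patched.' + name, p

    m = nn.Linear(4, 4)
    m.old_named_parameters = m.named_parameters              # keep the original bound method
    m.named_parameters = types.MethodType(tagged_named_parameters, m)

    print([name for name, _ in m.named_parameters()])        # ['patched.weight', 'patched.bias']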
@@ -24,8 +82,11 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
         The function to call at the end of the constructor of each module.
         FIXME(fjr) The module may be passed to this function multiple times?
         """
+        if hasattr(module, '_colo_visited'):
+            return
+
         name_list = []
-        for name, param in module.named_parameters():
+        for name, param in module.named_parameters(recurse=False):
             if isinstance(param, ColoTensor):
                 continue
             name_list.append((name, param))
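Because _post_init_method fires once per constructed submodule, the new _colo_visited guard prevents re-processing a module, and recurse=False restricts each call to the parameters the module owns directly. A toy illustration of this guard pattern; convert_params is a hypothetical stand-in for the real conversion logic:

    from torch import nn

    def convert_params(module: nn.Module):
        if hasattr(module, '_colo_visited'):
            return                                           # already processed, skip
        # recurse=False: only this module's own parameters; children are handled
        # when the post-init hook fires for them
        own = [name for name, _ in module.named_parameters(recurse=False)]
        module._colo_visited = True
        print(own)

    model = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
    for sub in model.modules():
        convert_params(sub)    # prints [] for the Sequential, ['weight', 'bias'] for each Linear
    convert_params(model)      # second call on the root is a no-op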
@@ -35,3 +96,5 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
             delattr(module, name)
             setattr(module, name,
                     ColoTensor.init_from_torch_tensor(tensor=param.to(self._device), save_payload=save_torch_payload))
+
+        ColoModulize(module)
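Taken together, building a model under ColoInitContext converts each freshly initialized nn.Parameter into a ColoTensor and then patches the module, so the stock parameters()/named_parameters() keep yielding plain torch tensors (the payloads) while colo_parameters()/colo_named_parameters() expose the ColoTensor wrappers. A hedged sketch of the intended behaviour; the import path and constructor arguments of ColoInitContext are assumptions, not taken from this diff:

    import torch
    from torch import nn
    from colossalai.tensor import ColoTensor
    from colossalai.utils.model.colo_init_context import ColoInitContext   # path assumed

    with ColoInitContext(device=torch.device('cpu')):        # device kwarg assumed
        model = nn.Linear(8, 8)

    for name, p in model.named_parameters():
        assert isinstance(p, torch.Tensor)                   # unwrapped payloads
    for name, p in model.colo_named_parameters():
        assert isinstance(p, (ColoTensor, torch.Tensor))     # ColoTensor wrappers for converted params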


@@ -48,7 +48,7 @@ def run_1d_row_tp():
     model_torch = model_torch.cuda()
 
     # A naive way to set spec for all weights in Linear
-    for name, p in named_params_with_colotensor(model):
+    for name, p in model.colo_named_parameters():
         if not isinstance(p, ColoTensor):
             continue
         if 'weight' in name and 'LayerNorm' not in name and 'ln' not in name and 'embed' not in name:
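The test now iterates the bound colo_named_parameters() instead of the free named_params_with_colotensor helper. Its name filter selects only Linear weights for the parallel spec, skipping LayerNorm and embedding weights; a standalone sketch of that predicate with illustrative GPT-style parameter names:

    def is_linear_weight(name: str) -> bool:
        # same condition as the test above
        return ('weight' in name and 'LayerNorm' not in name
                and 'ln' not in name and 'embed' not in name)

    for name in ['ln_f.weight', 'h.0.attn.c_attn.weight', 'word_embeddings.weight', 'h.0.ln_1.bias']:
        print(name, is_linear_weight(name))
    # -> False, True, False, False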