colo init context add device attr. (#866)

Jiarui Fang
2022-04-25 14:24:26 +08:00
committed by GitHub
parent 2238758c2e
commit d01d3b8cb0
3 changed files with 36 additions and 12 deletions

@@ -1,3 +1,4 @@
+from colossalai.utils.cuda import get_current_device
 from .utils import InsertPostInitMethodToModuleSubClasses
 import torch
 # from colossalai.logging import get_dist_logger
@@ -8,9 +9,15 @@ from colossalai.tensor import ColoTensor
 
 class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
 
-    def __init__(self, lazy_memory_allocate=False):
+    def __init__(self, lazy_memory_allocate: bool = False, device: torch.device = torch.device('cpu')):
+        """
+        Args:
+            lazy_memory_allocate (bool, optional): whether to allocate memory for the parameter tensors. Defaults to False.
+            device (torch.device, optional): the device on which the parameters are initialized. Defaults to torch.device('cpu').
+        """
         super().__init__()
         self._lazy_memory_allocate = lazy_memory_allocate
+        self._device = device
 
     def _post_init_method(self, module: torch.nn.Module, *args, **kwargs):
         """
@@ -26,4 +33,5 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
         save_torch_payload = True if not self._lazy_memory_allocate else False
         for name, param in name_list:
             delattr(module, name)
-            setattr(module, name, ColoTensor.init_from_torch_tensor(tensor=param, save_payload=save_torch_payload))
+            setattr(module, name,
+                    ColoTensor.init_from_torch_tensor(tensor=param.to(self._device), save_payload=save_torch_payload))
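
For context: the effect of this change is that every parameter created inside the context is moved to the requested device before being wrapped as a ColoTensor. A minimal usage sketch follows; the import path is an assumption based on the repository layout at the time and is not shown in this diff, and the toy module is made up for illustration:

import torch
import torch.nn as nn
# assumed import path, not confirmed by this diff
from colossalai.utils.model.colo_init_context import ColoInitContext

# Parameters of modules constructed inside the context are converted
# to ColoTensors whose payload lives on the given device.
with ColoInitContext(device=torch.device('cpu')):
    model = nn.Linear(16, 8)  # hypothetical toy module

The get_current_device import added in the first hunk suggests the intended GPU path, presumably something like ColoInitContext(device=get_current_device()).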