[Tensor] init ColoParameter (#914)

This commit is contained in:
Jiarui Fang
2022-05-06 12:57:14 +08:00
committed by GitHub
parent 193d629311
commit ab95ec9aea
6 changed files with 77 additions and 44 deletions

View File

@@ -1,6 +1,6 @@
from .utils import InsertPostInitMethodToModuleSubClasses
import torch
from colossalai.tensor import ColoTensor
from colossalai.tensor import ColoTensor, ColoParameter
import types
from torch import nn
@@ -100,10 +100,7 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
tensor_detached = param.to(self._device).detach()
tensor_detached.requires_grad = requires_grad
setattr(
module, name,
ColoTensor.init_from_torch_tensor(tensor=tensor_detached,
save_payload=save_torch_payload,
is_model_data=True))
setattr(module, name,
ColoParameter.init_from_torch_tensor(tensor=tensor_detached, save_payload=save_torch_payload))
ColoModulize(module)