mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-14 21:51:57 +00:00
[Tensor] init ColoParameter (#914)
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
from .utils import InsertPostInitMethodToModuleSubClasses
|
||||
import torch
|
||||
from colossalai.tensor import ColoTensor
|
||||
from colossalai.tensor import ColoTensor, ColoParameter
|
||||
import types
|
||||
|
||||
from torch import nn
|
||||
@@ -100,10 +100,7 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
|
||||
tensor_detached = param.to(self._device).detach()
|
||||
tensor_detached.requires_grad = requires_grad
|
||||
|
||||
setattr(
|
||||
module, name,
|
||||
ColoTensor.init_from_torch_tensor(tensor=tensor_detached,
|
||||
save_payload=save_torch_payload,
|
||||
is_model_data=True))
|
||||
setattr(module, name,
|
||||
ColoParameter.init_from_torch_tensor(tensor=tensor_detached, save_payload=save_torch_payload))
|
||||
|
||||
ColoModulize(module)
|
||||
|
Reference in New Issue
Block a user