Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-11-29 14:24:55 +00:00.
[Gemini] fix the convert_to_torch_module bug (#2269)
This commit is contained in:
@@ -2,7 +2,6 @@ import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from colossalai.gemini.chunk import Chunk
|
||||
from colossalai.tensor import ColoTensor
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
||||
@@ -22,6 +21,7 @@ def get_temp_total_chunk_on_cuda(chunk: Chunk):
|
||||
return total_temp
|
||||
|
||||
|
||||
# TODO() not work for module where two params share the same tensor.
|
||||
def _add_param(model, name, param):
|
||||
name_list = name.split('.')
|
||||
module = model._modules[name_list[0]]
|
||||
@@ -30,7 +30,7 @@ def _add_param(model, name, param):
|
||||
module._parameters[name_list[-1]] = param
|
||||
|
||||
|
||||
def convert_to_torch_module(gemini_ddp_model) -> torch.nn.Module:
|
||||
def convert_to_torch_module(gemini_ddp_model: 'GeminiDDP') -> torch.nn.Module:
|
||||
"""convert_to_torch_module
|
||||
|
||||
Args:
|
||||
@@ -39,11 +39,12 @@ def convert_to_torch_module(gemini_ddp_model) -> torch.nn.Module:
|
||||
Returns:
|
||||
torch.nn.Module: a torch model contains the params of gemini_ddp_model
|
||||
"""
|
||||
from colossalai.nn.parallel import GeminiDDP
|
||||
assert isinstance(gemini_ddp_model, GeminiDDP)
|
||||
module = gemini_ddp_model.module
|
||||
|
||||
for n, p in module.named_parameters():
|
||||
if isinstance(p, ColoTensor):
|
||||
p.to_replicate_()
|
||||
_add_param(module, n, p.data)
|
||||
# replace ColoTensor to torch.nn.Tensor in module
|
||||
for n, p in gemini_ddp_model.torch_named_parameters():
|
||||
_add_param(module, n, p)
|
||||
|
||||
return module
|
||||
|
||||
Reference in New Issue
Block a user