[Gemini] fix the convert_to_torch_module bug (#2269)

Author: Jiarui Fang
Date: 2023-01-03 15:55:35 +08:00
Committed by: GitHub
Parent: 879df8b943
Commit: af32022f74

4 changed files with 52 additions and 25 deletions


@@ -2,7 +2,6 @@ import torch
 import torch.distributed as dist
 
 from colossalai.gemini.chunk import Chunk
-from colossalai.tensor import ColoTensor
 from colossalai.utils import get_current_device
@@ -22,6 +21,7 @@ def get_temp_total_chunk_on_cuda(chunk: Chunk):
     return total_temp
 
+# TODO() not work for module where two params share the same tensor.
 def _add_param(model, name, param):
     name_list = name.split('.')
     module = model._modules[name_list[0]]
@@ -30,7 +30,7 @@ def _add_param(model, name, param):
     module._parameters[name_list[-1]] = param
 
-def convert_to_torch_module(gemini_ddp_model) -> torch.nn.Module:
+def convert_to_torch_module(gemini_ddp_model: 'GeminiDDP') -> torch.nn.Module:
     """convert_to_torch_module
 
     Args:
@@ -39,11 +39,12 @@ def convert_to_torch_module(gemini_ddp_model) -> torch.nn.Module:
     Returns:
         torch.nn.Module: a torch model contains the params of gemini_ddp_model
     """
+    from colossalai.nn.parallel import GeminiDDP
+    assert isinstance(gemini_ddp_model, GeminiDDP)
     module = gemini_ddp_model.module
-    for n, p in module.named_parameters():
-        if isinstance(p, ColoTensor):
-            p.to_replicate_()
-            _add_param(module, n, p.data)
+    # replace ColoTensor to torch.nn.Tensor in module
+    for n, p in gemini_ddp_model.torch_named_parameters():
+        _add_param(module, n, p)
 
     return module
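
What the fix does, as far as the diff shows: convert_to_torch_module now type-checks its argument through a function-local GeminiDDP import, and instead of replicating each ColoTensor in place (p.to_replicate_()) and re-registering p.data, it pulls already-materialized torch tensors from the wrapper via gemini_ddp_model.torch_named_parameters() and grafts them onto the wrapped module with _add_param, which is why the top-level ColoTensor import can be dropped.

A minimal usage sketch, not part of the commit; the import path of convert_to_torch_module and the GeminiDDP constructor arguments are assumptions based on the ColossalAI 0.2.x API:

import torch
import torch.nn as nn

from colossalai.nn.parallel import GeminiDDP
from colossalai.nn.parallel.utils import convert_to_torch_module  # assumed module path
from colossalai.utils import get_current_device

# wrap a plain module with Gemini DDP (constructor arguments are illustrative)
model = GeminiDDP(nn.Linear(16, 16), device=get_current_device(), placement_policy='auto')

# ... forward/backward/step with a Gemini-aware optimizer ...

# recover an ordinary torch.nn.Module whose parameters are plain torch.Tensor,
# e.g. to torch.save() it or run inference without ColossalAI
torch_model = convert_to_torch_module(model)
torch.save(torch_model.state_dict(), 'checkpoint.pt')

Per the TODO added above _add_param, the conversion still does not handle modules in which two parameter names share the same tensor.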