Mirror of https://github.com/hpcaitech/ColossalAI.git

[chatgpt] fix lora save bug (#3099)

Commit message:
* fix colo-strategy
* polish
* fix lora
* fix ddp
* polish
* polish

parent: 018936a3f3
commit: c9dd036592
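This commit makes saved LoRA checkpoints load like ordinary state dicts. Three changes cooperate: LoraLinear.eval() now deletes its low-rank factors once they are merged into the base weight, LoRAModule freezes all non-LoRA parameters after conversion, and both ColossalAIStrategy.save_model and DDPStrategy.save_model trigger the merge before serializing. Per the imports below, the model code lives in the chatgpt.models.lora module; the strategy file paths are not visible in this diff.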
@@ -74,6 +74,8 @@ class LoraLinear(lora.LoRALayer, nn.Module):
             # Merge the weights and mark it
             if self.r > 0:
                 self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
+                delattr(self, 'lora_A')
+                delattr(self, 'lora_B')
             self.merged = True

     def forward(self, x: torch.Tensor):
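The two delattr calls are the new lines. LoraLinear.eval() folds the low-rank update into the base weight via `self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling`; without the deletions, lora_A and lora_B would stay registered on the module and be written into the checkpoint next to the already-merged weight. The merge is exact because both paths are linear in the input. A minimal sketch in plain PyTorch (all names here are illustrative, not from the repo; the fan_in_fan_out transposition is omitted):

import torch

torch.manual_seed(0)
d_in, d_out, r, scaling = 8, 4, 2, 0.5

W = torch.randn(d_out, d_in)    # frozen base weight
A = torch.randn(r, d_in)        # plays the role of lora_A
B = torch.randn(d_out, r)       # plays the role of lora_B
x = torch.randn(3, d_in)

# Unmerged forward: base path plus scaled low-rank path
y_unmerged = x @ W.t() + (x @ A.t()) @ B.t() * scaling

# Merged forward: fold the update into the weight once, as eval() does
y_merged = x @ (W + (B @ A) * scaling).t()

assert torch.allclose(y_unmerged, y_merged, atol=1e-6)

After the merge, the module's state dict has the same keys and shapes as a plain nn.Linear, which is what lets the checkpoint load into a LoRA-free model.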
@@ -125,3 +127,4 @@ class LoRAModule(nn.Module):
             return
         convert_to_lora_recursively(self, self.lora_rank)
+        lora.mark_only_lora_as_trainable(self, self.lora_train_bias)
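The added lora.mark_only_lora_as_trainable call (the upstream loralib API, judging by the `lora.` prefix used throughout this file) freezes every parameter except the LoRA factors, with lora_train_bias controlling whether biases stay trainable. A quick sanity check on a converted model (sketch; `model` stands for any LoRAModule subclass):

trainable = [name for name, p in model.named_parameters() if p.requires_grad]
# With lora_train_bias='none', only names containing 'lora_' should appear.
print(trainable)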
@@ -6,11 +6,13 @@ import torch.distributed as dist
 import torch.nn as nn
 import torch.optim as optim
 from chatgpt.models.base import Actor
+from chatgpt.models.lora import LoraLinear
 from torch.optim import Optimizer

 import colossalai
 from colossalai.nn.optimizer import CPUAdam, HybridAdam
 from colossalai.nn.parallel import ZeroDDP, zero_model_wrapper, zero_optim_wrapper
+from colossalai.nn.parallel.utils import get_static_torch_model
 from colossalai.tensor import ProcessGroup, ShardSpec
 from colossalai.utils import get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext
@@ -143,6 +145,20 @@ class ColossalAIStrategy(DDPStrategy):

     def save_model(self, model: nn.Module, path: str, only_rank0: bool = False) -> None:
         unwrapped_model = self._unwrap_model(model)
+        # TODO: better way to get torch model from gemini model
+        # to get torch model from gemini model
+        if isinstance(unwrapped_model, ZeroDDP):
+            state_dict = unwrapped_model.state_dict()
+            unwrapped_model = get_static_torch_model(unwrapped_model)
+            if only_rank0 and dist.get_rank() != 0:
+                return
+            unwrapped_model.load_state_dict(state_dict)
+        # merge lora_weights into weights
+        for module in unwrapped_model.modules():
+            if isinstance(module, LoraLinear):
+                module.merge_weights = True
+                module.eval()
+        # get state_dict and save
         state_dict = unwrapped_model.state_dict()
         if only_rank0 and dist.get_rank() != 0:
             return
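The Gemini path is the delicate part: under ZeroDDP the parameters are sharded, so the method first gathers a full state dict, rebuilds an ordinary torch model with get_static_torch_model, and reloads the gathered weights into it before merging. Note that the first rank check sits after unwrapped_model.state_dict(): gathering a sharded state dict is presumably a collective that every rank must join before non-zero ranks may return. The merge-then-save pattern the method ends with, in isolation (a sketch assuming LoraLinear from chatgpt.models.lora, not the strategy's exact code):

import torch
import torch.nn as nn

from chatgpt.models.lora import LoraLinear

def save_merged(model: nn.Module, path: str) -> None:
    # Arm the merge flag on every LoRA layer, then flip it to eval mode:
    # LoraLinear.eval() adds lora_B @ lora_A * scaling into the base weight
    # and deletes the factors, so the saved state dict matches the base model.
    for module in model.modules():
        if isinstance(module, LoraLinear):
            module.merge_weights = True
            module.eval()
    torch.save(model.state_dict(), path)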
@@ -6,6 +6,7 @@ import torch
 import torch.distributed as dist
 import torch.nn as nn
 from chatgpt.models.base import Actor
+from chatgpt.models.lora import LoraLinear
 from chatgpt.replay_buffer import ReplayBuffer
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Optimizer
@@ -72,10 +73,17 @@ class DDPStrategy(NaiveStrategy):
         return model.module

     def save_model(self, model: nn.Module, path: str, only_rank0: bool = False) -> None:
+        for module in model.modules():
+            if isinstance(module, LoraLinear):
+                module.merge_weights = True
+                module.eval()
+
         if only_rank0 and dist.get_rank() != 0:
             return
-        super().save_model(model, path, only_rank0)
+        model = model.model.module
+        state_dict = model.state_dict()
+        torch.save(state_dict, path)

     def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
         if only_rank0 and dist.get_rank() != 0:
             return