[chatgpt] fix lora save bug (#3099)

* fix colo-strategy

* polish

* fix lora

* fix ddp

* polish

* polish
BlueRum 2023-03-10 17:58:10 +08:00 committed by GitHub
parent 018936a3f3
commit c9dd036592
3 changed files with 29 additions and 2 deletions

chatgpt/models/lora.py

@@ -74,6 +74,8 @@ class LoraLinear(lora.LoRALayer, nn.Module):
            # Merge the weights and mark it
            if self.r > 0:
                self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
                delattr(self, 'lora_A')
                delattr(self, 'lora_B')
            self.merged = True

    def forward(self, x: torch.Tensor):
@@ -125,3 +127,4 @@ class LoRAModule(nn.Module):
            return
        convert_to_lora_recursively(self, self.lora_rank)
        lora.mark_only_lora_as_trainable(self, self.lora_train_bias)

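The two added delattr calls matter for saving: once eval() has folded the low-rank update into weight, dropping lora_A and lora_B means a merged LoraLinear serializes like a plain nn.Linear, with no LoRA keys left in the state dict. A minimal standalone sketch of the same merge arithmetic (plain tensors with illustrative shapes, not the LoraLinear class itself):

import torch

# Illustrative shapes only: an (out_features x in_features) base weight plus
# rank-r factors lora_B (out_features x r) and lora_A (r x in_features).
out_features, in_features, r, scaling = 8, 16, 4, 1.0
weight = torch.randn(out_features, in_features)
lora_A = torch.randn(r, in_features)
lora_B = torch.randn(out_features, r)

# The same merge LoraLinear.eval() performs: W <- W + (B @ A) * scaling.
merged_weight = weight + (lora_B @ lora_A) * scaling

x = torch.randn(2, in_features)
# After merging, a single matmul reproduces the base path plus the LoRA path.
y_merged = x @ merged_weight.t()
y_two_path = x @ weight.t() + (x @ lora_A.t()) @ lora_B.t() * scaling
assert torch.allclose(y_merged, y_two_path, atol=1e-5)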
chatgpt/trainer/strategies/colossalai.py

@@ -6,11 +6,13 @@ import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from chatgpt.models.base import Actor
from chatgpt.models.lora import LoraLinear
from torch.optim import Optimizer
import colossalai
from colossalai.nn.optimizer import CPUAdam, HybridAdam
from colossalai.nn.parallel import ZeroDDP, zero_model_wrapper, zero_optim_wrapper
from colossalai.nn.parallel.utils import get_static_torch_model
from colossalai.tensor import ProcessGroup, ShardSpec
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
@@ -143,6 +145,20 @@ class ColossalAIStrategy(DDPStrategy):
    def save_model(self, model: nn.Module, path: str, only_rank0: bool = False) -> None:
        unwrapped_model = self._unwrap_model(model)
        # TODO : better way to get torch model from gemini model
        # to get torch model from gemini model
        if isinstance(unwrapped_model, ZeroDDP):
            state_dict = unwrapped_model.state_dict()
            unwrapped_model = get_static_torch_model(unwrapped_model)
            if only_rank0 and dist.get_rank() != 0:
                return
            unwrapped_model.load_state_dict(state_dict)
        # merge lora_weights into weights
        for module in unwrapped_model.modules():
            if isinstance(module, LoraLinear):
                module.merge_weights=True
                module.eval()
        # get state_dict and save
        state_dict = unwrapped_model.state_dict()
        if only_rank0 and dist.get_rank() != 0:
            return

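For reference, this save path would be exercised roughly as below (a hedged sketch: build_actor is a hypothetical stand-in for however the training script constructs a LoRA-enabled model, the constructor arguments mirror the repo's examples, and 'actor_checkpoint.pt' is a placeholder path):

from chatgpt.trainer.strategies import ColossalAIStrategy

strategy = ColossalAIStrategy(stage=3, placement_policy='cuda')

with strategy.model_init_context():
    actor = build_actor(lora_rank=8)   # hypothetical helper, any LoRA-enabled model

# ... training on the strategy-prepared model ...

# With the change above, the gemini (ZeroDDP) wrapper is converted back to a
# static torch model, LoraLinear weights are merged by setting merge_weights
# and calling eval(), and only rank 0 writes the file when only_rank0=True.
strategy.save_model(actor, 'actor_checkpoint.pt', only_rank0=True)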
chatgpt/trainer/strategies/ddp.py

@@ -6,6 +6,7 @@ import torch
import torch.distributed as dist
import torch.nn as nn
from chatgpt.models.base import Actor
from chatgpt.models.lora import LoraLinear
from chatgpt.replay_buffer import ReplayBuffer
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
@@ -72,10 +73,17 @@ class DDPStrategy(NaiveStrategy):
        return model.module

    def save_model(self, model: nn.Module, path: str, only_rank0: bool = False) -> None:
        for module in model.modules():
            if isinstance(module, LoraLinear):
                module.merge_weights=True
                module.eval()

        if only_rank0 and dist.get_rank() != 0:
            return
        super().save_model(model, path, only_rank0)
        model = model.model.module
        state_dict = model.state_dict()
        torch.save(state_dict, path)

    def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
        if only_rank0 and dist.get_rank() != 0:
            return
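A quick sanity check of what both strategy changes buy (sketch only; the path is the placeholder used above, and the key test is purely illustrative):

import torch

# After save_model, the checkpoint should contain merged plain Linear weights
# and no leftover lora_A / lora_B entries.
state_dict = torch.load('actor_checkpoint.pt', map_location='cpu')
assert not any('lora_' in key for key in state_dict), 'LoRA factors should be merged away'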