[chatgpt] fix lora save bug (#3099)

* fix colo-strategy

* polish

* fix lora

* fix ddp

* polish

* polish
BlueRum 2023-03-10 17:58:10 +08:00 committed by GitHub
parent 018936a3f3
commit c9dd036592
3 changed files with 29 additions and 2 deletions

chatgpt/models/lora.py

@@ -74,6 +74,8 @@ class LoraLinear(lora.LoRALayer, nn.Module):
             # Merge the weights and mark it
             if self.r > 0:
                 self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
+                delattr(self, 'lora_A')
+                delattr(self, 'lora_B')
             self.merged = True
 
     def forward(self, x: torch.Tensor):
@@ -125,3 +127,4 @@ class LoRAModule(nn.Module):
             return
         convert_to_lora_recursively(self, self.lora_rank)
         lora.mark_only_lora_as_trainable(self, self.lora_train_bias)
+

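Note: with the delattr calls above, once eval() has folded the low-rank update into the base weight, lora_A and lora_B disappear from the module, so the state dict written at save time looks like a plain nn.Linear. A minimal standalone sketch of that merge follows (illustrative shapes and names, not the repo's code; the repo's version additionally handles the fan_in_fan_out transpose via T()):

import torch
import torch.nn as nn

# Illustrative sizes; scaling corresponds to lora_alpha / r in LoRA.
out_features, in_features, rank, scaling = 8, 4, 2, 1.0

base = nn.Linear(in_features, out_features, bias=False)
lora_A = torch.randn(rank, in_features) * 0.01   # A: r x in_features
lora_B = torch.zeros(out_features, rank)         # B: out_features x r

with torch.no_grad():
    # W' = W + B @ A * scaling -- the same identity eval() applies above.
    base.weight += (lora_B @ lora_A) * scaling

# Only the fused weight remains to serialize; with lora_A / lora_B removed,
# the checkpoint can be loaded back into a vanilla Linear.
torch.save(base.state_dict(), "merged_linear.pt")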
chatgpt/trainer/strategies/colossalai.py

@@ -6,11 +6,13 @@ import torch.distributed as dist
 import torch.nn as nn
 import torch.optim as optim
 from chatgpt.models.base import Actor
+from chatgpt.models.lora import LoraLinear
 from torch.optim import Optimizer
 
 import colossalai
 from colossalai.nn.optimizer import CPUAdam, HybridAdam
 from colossalai.nn.parallel import ZeroDDP, zero_model_wrapper, zero_optim_wrapper
+from colossalai.nn.parallel.utils import get_static_torch_model
 from colossalai.tensor import ProcessGroup, ShardSpec
 from colossalai.utils import get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext
@@ -143,6 +145,20 @@ class ColossalAIStrategy(DDPStrategy):
 
     def save_model(self, model: nn.Module, path: str, only_rank0: bool = False) -> None:
         unwrapped_model = self._unwrap_model(model)
+        # TODO : better way to get torch model from gemini model
+        # to get torch model from gemini model
+        if isinstance(unwrapped_model, ZeroDDP):
+            state_dict = unwrapped_model.state_dict()
+            unwrapped_model = get_static_torch_model(unwrapped_model)
+            if only_rank0 and dist.get_rank() != 0:
+                return
+            unwrapped_model.load_state_dict(state_dict)
+        # merge lora_weights into weights
+        for module in unwrapped_model.modules():
+            if isinstance(module, LoraLinear):
+                module.merge_weights=True
+                module.eval()
+        # get state_dict and save
         state_dict = unwrapped_model.state_dict()
         if only_rank0 and dist.get_rank() != 0:
             return
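Note: taken together, the Gemini path now (1) materializes a plain torch model from the ZeroDDP wrapper via get_static_torch_model, (2) flips every LoraLinear into merge mode so eval() folds B @ A into the weight, and (3) saves the consolidated state dict on rank 0 only. A condensed sketch of that flow follows (save_lora_checkpoint is an illustrative helper name, and it omits the extra state_dict round-trip shown in the diff above):

import torch
import torch.distributed as dist
import torch.nn as nn

from chatgpt.models.lora import LoraLinear
from colossalai.nn.parallel import ZeroDDP
from colossalai.nn.parallel.utils import get_static_torch_model


def save_lora_checkpoint(model: nn.Module, path: str, only_rank0: bool = True) -> None:
    # 1. Gather the sharded Gemini model into an ordinary torch model.
    if isinstance(model, ZeroDDP):
        model = get_static_torch_model(model)
        if only_rank0 and dist.get_rank() != 0:
            return

    # 2. Fold the LoRA update into the base weights; eval() performs the merge
    #    and drops lora_A / lora_B from the module.
    for module in model.modules():
        if isinstance(module, LoraLinear):
            module.merge_weights = True
            module.eval()

    # 3. Write the merged, consolidated state dict.
    torch.save(model.state_dict(), path)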

chatgpt/trainer/strategies/ddp.py

@@ -6,6 +6,7 @@ import torch
 import torch.distributed as dist
 import torch.nn as nn
 from chatgpt.models.base import Actor
+from chatgpt.models.lora import LoraLinear
 from chatgpt.replay_buffer import ReplayBuffer
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Optimizer
@@ -72,10 +73,17 @@ class DDPStrategy(NaiveStrategy):
         return model.module
 
     def save_model(self, model: nn.Module, path: str, only_rank0: bool = False) -> None:
+        for module in model.modules():
+            if isinstance(module, LoraLinear):
+                module.merge_weights=True
+                module.eval()
         if only_rank0 and dist.get_rank() != 0:
             return
-        super().save_model(model, path, only_rank0)
+        model = model.model.module
+        state_dict = model.state_dict()
+        torch.save(state_dict, path)
 
     def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
         if only_rank0 and dist.get_rank() != 0:
             return
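Note: the DDP path mirrors the Gemini one: merge the LoRA layers first, then unwrap (model.model.module) and torch.save the plain state dict instead of delegating to the naive strategy's save. Because the saved keys now match the unwrapped backbone, restoring the checkpoint is an ordinary load_state_dict. A hedged sketch follows, where build_base_model stands in for however the actor's backbone is constructed (not part of this commit):

from typing import Callable

import torch
import torch.nn as nn


def load_merged_checkpoint(build_base_model: Callable[[], nn.Module],
                           path: str,
                           device: str = "cpu") -> nn.Module:
    # The checkpoint written by the new save_model holds merged weights only,
    # so no lora_A / lora_B keys remain to trip up strict loading.
    model = build_base_model()
    state_dict = torch.load(path, map_location=device)
    model.load_state_dict(state_dict)
    return model.eval()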