[chatgpt] fix lora save bug (#3099)
* fix colo-strategy
* polish
* fix lora
* fix ddp
* polish
* polish
This commit is contained in: commit c9dd036592 (parent 018936a3f3)
```diff
@@ -74,6 +74,8 @@ class LoraLinear(lora.LoRALayer, nn.Module):
             # Merge the weights and mark it
             if self.r > 0:
                 self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
+                delattr(self, 'lora_A')
+                delattr(self, 'lora_B')
             self.merged = True
 
     def forward(self, x: torch.Tensor):
```
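Merging works because the LoRA forward is linear in its weights: folding the scaled low-rank product `lora_B @ lora_A` into `weight` once gives the same outputs as keeping the two paths separate, and the new `delattr` calls then drop the redundant factors so they no longer appear in `state_dict()`, which appears to be the save bug this commit addresses. A minimal numeric sketch of that equivalence (shapes are arbitrary, for illustration only):

```python
import torch

# Illustrative shapes only: a 4->4 linear with LoRA rank 2.
out_features, in_features, r, scaling = 4, 4, 2, 0.5

weight = torch.randn(out_features, in_features)
lora_A = torch.randn(r, in_features)
lora_B = torch.randn(out_features, r)
x = torch.randn(1, in_features)

# Unmerged forward: base path plus the scaled low-rank path.
y_unmerged = x @ weight.T + (x @ lora_A.T @ lora_B.T) * scaling

# Merged forward: fold B @ A into the dense weight once.
merged = weight + (lora_B @ lora_A) * scaling
y_merged = x @ merged.T

assert torch.allclose(y_unmerged, y_merged, atol=1e-5)
```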
```diff
@@ -125,3 +127,4 @@ class LoRAModule(nn.Module):
             return
         convert_to_lora_recursively(self, self.lora_rank)
         lora.mark_only_lora_as_trainable(self, self.lora_train_bias)
+
```
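For context, `convert_to_lora_recursively` swaps the linear layers of a module tree for LoRA-augmented ones before `lora.mark_only_lora_as_trainable` freezes everything else. A plausible sketch of such a converter, not the library's actual code (`to_lora` is a hypothetical stand-in for the real per-layer conversion):

```python
import torch
import torch.nn as nn

# Hypothetical per-layer conversion: copy the linear and attach low-rank factors.
def to_lora(linear: nn.Linear, lora_rank: int) -> nn.Linear:
    new = nn.Linear(linear.in_features, linear.out_features, bias=linear.bias is not None)
    new.load_state_dict(linear.state_dict())
    new.lora_A = nn.Parameter(torch.zeros(lora_rank, linear.in_features))
    new.lora_B = nn.Parameter(torch.zeros(linear.out_features, lora_rank))
    return new

def convert_to_lora_recursively(module: nn.Module, lora_rank: int) -> None:
    # Replace each nn.Linear child in place, recursing into all other submodules.
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            setattr(module, name, to_lora(child, lora_rank))
        else:
            convert_to_lora_recursively(child, lora_rank)

# Usage: after conversion, every linear layer carries lora_A / lora_B factors.
mlp = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 2))
convert_to_lora_recursively(mlp, lora_rank=4)
assert any('lora_A' in key for key in mlp.state_dict())
```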
```diff
@@ -6,11 +6,13 @@ import torch.distributed as dist
 import torch.nn as nn
 import torch.optim as optim
 from chatgpt.models.base import Actor
+from chatgpt.models.lora import LoraLinear
 from torch.optim import Optimizer
 
 import colossalai
 from colossalai.nn.optimizer import CPUAdam, HybridAdam
 from colossalai.nn.parallel import ZeroDDP, zero_model_wrapper, zero_optim_wrapper
+from colossalai.nn.parallel.utils import get_static_torch_model
 from colossalai.tensor import ProcessGroup, ShardSpec
 from colossalai.utils import get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext
```
```diff
@@ -143,6 +145,20 @@ class ColossalAIStrategy(DDPStrategy):
 
     def save_model(self, model: nn.Module, path: str, only_rank0: bool = False) -> None:
         unwrapped_model = self._unwrap_model(model)
+        # TODO: find a better way to get a torch model from a gemini model
+        # gather the full static torch model from the gemini (ZeroDDP) model
+        if isinstance(unwrapped_model, ZeroDDP):
+            state_dict = unwrapped_model.state_dict()
+            unwrapped_model = get_static_torch_model(unwrapped_model)
+            if only_rank0 and dist.get_rank() != 0:
+                return
+            unwrapped_model.load_state_dict(state_dict)
+        # merge the LoRA weights into the base weights
+        for module in unwrapped_model.modules():
+            if isinstance(module, LoraLinear):
+                module.merge_weights = True
+                module.eval()
+        # get the state dict and save
         state_dict = unwrapped_model.state_dict()
         if only_rank0 and dist.get_rank() != 0:
             return
```
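In the gemini path above, `get_static_torch_model` gathers the sharded parameters on every rank before non-zero ranks return, and the LoRA merge runs before the final `state_dict()` call so the checkpoint contains only plain dense weights. A self-contained toy, not the library's code, that mimics `LoraLinear`'s merge-on-eval behavior under the same save-time loop:

```python
import torch
import torch.nn as nn

# Toy stand-in for LoraLinear: folds its low-rank update into .weight when
# eval() runs with merge_weights set, then drops the LoRA parameters.
class ToyLoraLinear(nn.Linear):

    def __init__(self, in_features: int, out_features: int, r: int = 4, scaling: float = 1.0):
        super().__init__(in_features, out_features)
        self.r, self.scaling = r, scaling
        self.merged, self.merge_weights = False, False
        self.lora_A = nn.Parameter(torch.zeros(r, in_features))
        self.lora_B = nn.Parameter(torch.zeros(out_features, r))

    def eval(self):
        super().eval()
        if self.merge_weights and not self.merged:
            if self.r > 0:
                self.weight.data += (self.lora_B @ self.lora_A) * self.scaling
                delattr(self, 'lora_A')   # same fix as the diff: drop the
                delattr(self, 'lora_B')   # LoRA factors once they are merged
            self.merged = True
        return self

model = nn.Sequential(ToyLoraLinear(8, 8), nn.ReLU(), ToyLoraLinear(8, 2))

# The save-time pass from the diff: mark every LoRA layer, then eval() it.
for module in model.modules():
    if isinstance(module, ToyLoraLinear):
        module.merge_weights = True
        module.eval()

# The checkpoint now holds only plain linear weights, with no lora_* keys.
assert not any('lora_' in key for key in model.state_dict())
```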
```diff
@@ -6,6 +6,7 @@ import torch
 import torch.distributed as dist
 import torch.nn as nn
 from chatgpt.models.base import Actor
+from chatgpt.models.lora import LoraLinear
 from chatgpt.replay_buffer import ReplayBuffer
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Optimizer
```
```diff
@@ -72,10 +73,17 @@ class DDPStrategy(NaiveStrategy):
         return model.module
 
     def save_model(self, model: nn.Module, path: str, only_rank0: bool = False) -> None:
+        for module in model.modules():
+            if isinstance(module, LoraLinear):
+                module.merge_weights = True
+                module.eval()
+
         if only_rank0 and dist.get_rank() != 0:
             return
-        super().save_model(model, path, only_rank0)
+        model = model.model.module
+        state_dict = model.state_dict()
+        torch.save(state_dict, path)
 
     def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
         if only_rank0 and dist.get_rank() != 0:
             return
```
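The replaced `super().save_model(...)` call delegated unwrapping to the naive strategy; the new code reaches through the wrapper chain itself via `model.model.module`. A hedged sketch of the chain this assumes, using stand-in classes rather than the real `Actor` or torch DDP:

```python
import torch.nn as nn

# Stand-ins for chatgpt.models.base.Actor and torch's DDP wrapper, just to
# show the attribute chain `model.model.module` used in the diff.
class ActorStandIn(nn.Module):
    def __init__(self, net: nn.Module):
        super().__init__()
        self.model = net          # the Actor holds its network as .model

class DDPStandIn(nn.Module):
    def __init__(self, net: nn.Module):
        super().__init__()
        self.module = net         # DDP exposes the wrapped net as .module

net = nn.Linear(4, 4)
actor = ActorStandIn(DDPStandIn(net))
assert actor.model.module is net  # mirrors `model = model.model.module`
```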