Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-28 08:17:57 +00:00)
fix save_model in naive and ddp strategy (#3436)
Co-authored-by: Yuanchen Xu <yuanchen.xu00@gmail.com>
parent 1beb85cc25
commit 773955abfa
ddp strategy:

@@ -1,3 +1,5 @@
+from typing import Optional
+
 import os
 import random

@@ -5,12 +7,13 @@ import numpy as np
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from coati.models.base import Actor
+from coati.models.base import LM, Actor, RewardModel
 from coati.models.lora import LoraLinear
 from coati.replay_buffer import ReplayBuffer
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader
+from transformers.tokenization_utils_base import PreTrainedTokenizerBase

 from .base import Strategy
 from .naive import NaiveStrategy
@@ -72,17 +75,32 @@ class DDPStrategy(NaiveStrategy):
         model: DDP = Strategy._unwrap_actor(actor)
         return model.module

-    def save_model(self, model: nn.Module, path: str, only_rank0: bool = False) -> None:
+    def save_model(self, model: nn.Module, path: str, only_rank0: bool = False, tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
+        if only_rank0 and dist.get_rank() != 0:
+            return None
+
         for module in model.modules():
             if isinstance(module, LoraLinear):
                 module.merge_weights = True
                 module.eval()

-        if only_rank0 and dist.get_rank() != 0:
-            return
-        model = model.model.module
-        state_dict = model.state_dict()
-        torch.save(state_dict, path)
+        if isinstance(model, RewardModel):
+            state_dict = model.state_dict()
+            if only_rank0 and dist.get_rank() != 0:
+                return
+            torch.save(state_dict, path)
+        else:
+            try:
+                if isinstance(model, LM):
+                    model = model.model
+                model.save_pretrained(path)
+                if tokenizer is not None:
+                    tokenizer.save_pretrained(path)
+            except AttributeError:
+                state_dict = model.state_dict()
+                if only_rank0 and dist.get_rank() != 0:
+                    return
+                torch.save(state_dict, path)

     def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
         if only_rank0 and dist.get_rank() != 0:
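The else branch added above relies on simple duck typing: Hugging Face style models (and the coati LM wrapper around them) expose save_pretrained, while an arbitrary nn.Module does not, so the AttributeError fallback drops back to a raw state_dict dump. A minimal, self-contained sketch of that pattern; the helper name, toy module, and output path below are illustrative and not part of the commit:

import torch
import torch.nn as nn

def save_checkpoint(model: nn.Module, path: str, tokenizer=None) -> None:
    # Same branching as the new save_model: prefer the transformers-style
    # save_pretrained when the model provides it, otherwise fall back to
    # dumping the raw state_dict with torch.save.
    try:
        model.save_pretrained(path)
        if tokenizer is not None:
            tokenizer.save_pretrained(path)
    except AttributeError:
        torch.save(model.state_dict(), path)

# A plain module has no save_pretrained, so the fallback branch runs.
save_checkpoint(nn.Linear(4, 2), "toy_checkpoint.pt")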
naive strategy:

@@ -1,11 +1,14 @@
-from typing import Any
+from typing import Any, Optional

 import torch
 import torch.nn as nn
 import torch.optim as optim
 from coati.replay_buffer import ReplayBuffer
+from coati.models.base import LM, RewardModel
+from coati.models.lora import LoraLinear
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader
+from transformers.tokenization_utils_base import PreTrainedTokenizerBase

 from .base import Strategy

@@ -38,9 +41,25 @@ class NaiveStrategy(Strategy):
                             pin_memory=pin_memory,
                             collate_fn=replay_buffer.collate_fn)

-    def save_model(self, model: nn.Module, path: str, only_rank0: bool = False) -> None:
-        unwrapped_model = self._unwrap_model(model)
-        torch.save(unwrapped_model.state_dict(), path)
+    def save_model(self, model: nn.Module, path: str, only_rank0: bool = False, tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
+        for module in model.modules():
+            if isinstance(module, LoraLinear):
+                module.merge_weights = True
+                module.eval()
+
+        if isinstance(model, RewardModel):
+            state_dict = model.state_dict()
+            torch.save(state_dict, path)
+        else:
+            try:
+                if isinstance(model, LM):
+                    model = model.model
+                model.save_pretrained(path)
+                if tokenizer is not None:
+                    tokenizer.save_pretrained(path)
+            except AttributeError:
+                state_dict = model.state_dict()
+                torch.save(state_dict, path)

     def load_model(self, model: nn.Module, path: str, map_location: Any = None, strict: bool = True) -> None:
         unwrapped_model = self._unwrap_model(model)
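A practical effect of the change: for Hugging Face backed models, save_model now writes a standard save_pretrained directory, and passing the tokenizer stores it next to the weights, so the checkpoint can be reloaded with the usual transformers loaders. A hedged sketch, assuming a causal LM was previously saved to an illustrative outputs/actor directory:

from transformers import AutoModelForCausalLM, AutoTokenizer

# "outputs/actor" is an illustrative path, assumed to have been produced by a
# prior strategy.save_model(model, "outputs/actor", tokenizer=tokenizer) call.
model = AutoModelForCausalLM.from_pretrained("outputs/actor")
tokenizer = AutoTokenizer.from_pretrained("outputs/actor")

Note that only the non-RewardModel branch produces such a directory; reward models are still written as a raw state dict via torch.save.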