import torch.nn as nn
from peft import PeftModel


class ModelWrapper(nn.Module):
    """
    A wrapper class to define the common interface used by booster.

    Args:
        module (nn.Module): The model to be wrapped.
    """

    def __init__(self, module: nn.Module) -> None:
        super().__init__()
        self.module = module

    def unwrap(self, unwrap_peft: bool = True):
        """
        Unwrap the model to return the original model for checkpoint saving/loading.
        """
        # Wrappers may be nested; recurse until the innermost module is reached.
        if isinstance(self.module, ModelWrapper):
            model = self.module.unwrap()
        else:
            model = self.module
        # Optionally strip the PEFT adapter wrapper to expose the base model.
        if unwrap_peft and isinstance(model, PeftModel):
            model = model.get_base_model()
        return model

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)

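
# A minimal usage sketch (illustrative, not part of this module): wrappers may
# be nested, and ``unwrap()`` peels them off recursively before checkpoint I/O.
#
#   inner = ModelWrapper(nn.Linear(4, 4))
#   outer = ModelWrapper(inner)
#   assert isinstance(outer.unwrap(), nn.Linear)  # the original module is recovered
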
class AMPModelMixin:
    """This mixin class defines the interface for AMP training."""

    def update_master_params(self):
        """
        Update the master parameters for AMP training.
        """
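
# A minimal sketch (hypothetical, not this library's actual implementation) of
# how a subclass might combine both interfaces: keep fp32 master copies of the
# fp16 working weights and resync them, e.g. after loading a checkpoint.
#
#   class MyAMPModel(ModelWrapper, AMPModelMixin):
#       def __init__(self, module: nn.Module) -> None:
#           super().__init__(module.half())  # train with fp16 working weights
#           self.master_params = [p.detach().clone().float() for p in self.module.parameters()]
#
#       def update_master_params(self):
#           # Refresh the fp32 masters from the current fp16 working weights.
#           for master, working in zip(self.master_params, self.module.parameters()):
#               master.data.copy_(working.data.float())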