ColossalAI/colossalai/interface/model.py

import torch.nn as nn
from peft import PeftModel


class ModelWrapper(nn.Module):
    """
    A wrapper class to define the common interface used by the booster.

    Args:
        module (nn.Module): The model to be wrapped.
    """

    def __init__(self, module: nn.Module) -> None:
        super().__init__()
        self.module = module

    def unwrap(self, unwrap_peft: bool = True):
        """
        Unwrap the model to return the original model for checkpoint saving/loading.
        """
        if isinstance(self.module, ModelWrapper):
            # recursively unwrap nested wrappers
            model = self.module.unwrap()
        else:
            model = self.module
        if unwrap_peft and isinstance(model, PeftModel):
            # strip the PEFT wrapper to reach the base model
            model = model.get_base_model()
        return model

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)
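

# Illustrative sketch (not part of the original file): demonstrates how nested
# wrappers collapse back to the underlying module via unwrap(). The helper
# name below is hypothetical and exists only to document expected behaviour.
def _example_unwrap_nested() -> None:
    base = nn.Linear(4, 4)
    wrapped = ModelWrapper(ModelWrapper(base))
    # unwrap() recurses through nested ModelWrapper instances; with
    # unwrap_peft=True it would additionally call get_base_model() on a
    # PeftModel so that only the plain nn.Module underneath is returned.
    assert wrapped.unwrap() is base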


class AMPModelMixin:
    """This mixin class defines the interface for AMP training."""

    def update_master_params(self):
        """
        Update the master parameters for AMP training.
        """