import torch.nn as nn
from peft import PeftModel


class ModelWrapper(nn.Module):
    """
    A wrapper class to define the common interface used by booster.

    Args:
        module (nn.Module): The model to be wrapped.
    """

    def __init__(self, module: nn.Module) -> None:
        super().__init__()
        self.module = module

    def unwrap(self, unwrap_peft: bool = True) -> nn.Module:
        """
        Unwrap the model to return the original model for checkpoint saving/loading.

        Nested ``ModelWrapper`` layers are peeled off recursively. The
        ``unwrap_peft`` flag is propagated to every layer, so the caller's
        choice is honoured no matter how deeply the model is wrapped.

        Args:
            unwrap_peft (bool): If True (default), a ``PeftModel`` reached after
                unwrapping is replaced by the result of its ``get_base_model()``.
                If False, the ``PeftModel`` is returned unchanged.

        Returns:
            nn.Module: The underlying model with all wrapper layers removed.
        """
        if isinstance(self.module, ModelWrapper):
            # Forward the flag explicitly. Relying on the inner call's default
            # (True) would strip the PEFT wrapper even when the caller passed
            # False. NOTE(review): subclasses overriding ``unwrap`` should
            # accept ``unwrap_peft`` to satisfy this interface.
            model = self.module.unwrap(unwrap_peft=unwrap_peft)
        else:
            model = self.module
        if unwrap_peft and isinstance(model, PeftModel):
            model = model.get_base_model()
        return model

    def forward(self, *args, **kwargs):
        """Delegate the forward pass to the wrapped module."""
        return self.module(*args, **kwargs)


class AMPModelMixin:
    """This mixin class defines the interface for AMP training.

    Classes that mix this in expose ``update_master_params`` as a hook.
    """

    # NOTE(review): in the visible code this method has no statements besides
    # its docstring, so calling it does nothing; presumably AMP-aware
    # subclasses that keep separate master parameters override it — confirm.
    def update_master_params(self):
        """
        Update the master parameters for AMP training.
        """