ColossalAI/colossalai/interface/model.py
Hongxin Liu 807e01a4ba
[zero] hotfix master param sync ()
* [zero] add method to update master params
* [zero] update zero plugin
* [plugin] update low level zero plugin
2023-09-05 15:04:02 +08:00


import torch.nn as nn


class ModelWrapper(nn.Module):
    """
    A wrapper class to define the common interface used by booster.

    Args:
        module (nn.Module): The model to be wrapped.
    """

    def __init__(self, module: nn.Module) -> None:
        super().__init__()
        self.module = module

    def unwrap(self):
        """
        Unwrap the model to return the original model for checkpoint saving/loading.
        """
        if isinstance(self.module, ModelWrapper):
            return self.module.unwrap()
        return self.module

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)
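
ModelWrapper.unwrap() recurses through nested wrappers, so stacked wrappers still yield the innermost module for checkpoint saving/loading. A minimal usage sketch, not part of this file (the toy nn.Linear model is an assumption):

import torch
import torch.nn as nn

# Nested wrappers unwrap back to the original module.
model = nn.Linear(4, 2)
wrapped = ModelWrapper(ModelWrapper(model))
assert wrapped.unwrap() is model   # recursion strips both wrapper layers
out = wrapped(torch.randn(1, 4))   # forward delegates to the wrapped module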
class AMPModelMixin:
    """This mixin class defines the interface for AMP training."""

    def update_master_params(self):
        """
        Update the master parameters for AMP training.
        """
        pass
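
A class combining both interfaces would typically keep fp32 master copies of the fp16 working parameters and refresh them on demand, in the spirit of the commit message above. A minimal sketch under assumed names (HalfPrecisionModel and its master_params attribute are hypothetical illustrations, not the actual ColossalAI plugin classes):

import torch
import torch.nn as nn

class HalfPrecisionModel(ModelWrapper, AMPModelMixin):
    """Hypothetical AMP wrapper: fp16 working weights, fp32 masters."""

    def __init__(self, module: nn.Module) -> None:
        super().__init__(module.half())
        # fp32 master copies, paired one-to-one with the fp16 working params.
        self.master_params = [p.detach().clone().float() for p in self.module.parameters()]

    def update_master_params(self):
        # Refresh the fp32 masters from the current fp16 working values,
        # e.g. after loading a checkpoint into self.module.
        with torch.no_grad():
            for master, working in zip(self.master_params, self.module.parameters()):
                master.copy_(working.float())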