mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-09 21:09:18 +00:00
[test] merge old components to test to model zoo (#4945)
* [test] add custom models in model zoo * [test] update legacy test * [test] update model zoo * [test] update gemini test * [test] remove components to test
This commit is contained in:
26
tests/kit/model_zoo/custom/base.py
Normal file
26
tests/kit/model_zoo/custom/base.py
Normal file
@@ -0,0 +1,26 @@
|
||||
import torch.nn as nn
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
|
||||
class CheckpointModule(nn.Module):
|
||||
def __init__(self, checkpoint: bool = False):
|
||||
super().__init__()
|
||||
self.checkpoint = checkpoint
|
||||
self._use_checkpoint = checkpoint
|
||||
|
||||
def _forward(self, *args, **kwargs):
|
||||
raise NotImplementedError("CheckpointModule should implement _forward method instead of origin forward")
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
if self._use_checkpoint:
|
||||
return checkpoint(self._forward, *args, **kwargs)
|
||||
else:
|
||||
return self._forward(*args, **kwargs)
|
||||
|
||||
def train(self, mode: bool = True):
|
||||
self._use_checkpoint = self.checkpoint
|
||||
return super().train(mode=mode)
|
||||
|
||||
def eval(self):
|
||||
self._use_checkpoint = False
|
||||
return super().eval()
|
Reference in New Issue
Block a user