[test] merge old components to test to model zoo (#4945)

* [test] add custom models in model zoo

* [test] update legacy test

* [test] update model zoo

* [test] update gemini test

* [test] remove components to test
This commit is contained in:
Hongxin Liu
2023-10-20 10:35:08 +08:00
committed by GitHub
parent 3a41e8304e
commit b8e770c832
49 changed files with 461 additions and 914 deletions

View File

@@ -0,0 +1,26 @@
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
class CheckpointModule(nn.Module):
def __init__(self, checkpoint: bool = False):
super().__init__()
self.checkpoint = checkpoint
self._use_checkpoint = checkpoint
def _forward(self, *args, **kwargs):
raise NotImplementedError("CheckpointModule should implement _forward method instead of origin forward")
def forward(self, *args, **kwargs):
if self._use_checkpoint:
return checkpoint(self._forward, *args, **kwargs)
else:
return self._forward(*args, **kwargs)
def train(self, mode: bool = True):
self._use_checkpoint = self.checkpoint
return super().train(mode=mode)
def eval(self):
self._use_checkpoint = False
return super().eval()