mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 12:30:42 +00:00
[test] merge old components to test to model zoo (#4945)
* [test] add custom models in model zoo * [test] update legacy test * [test] update model zoo * [test] update gemini test * [test] remove components to test
This commit is contained in:
@@ -9,6 +9,7 @@ from .comparison import (
|
||||
)
|
||||
from .pytest_wrapper import run_on_environment_flag
|
||||
from .utils import (
|
||||
DummyDataloader,
|
||||
clear_cache_before_run,
|
||||
free_port,
|
||||
parameterize,
|
||||
@@ -34,4 +35,5 @@ __all__ = [
|
||||
"run_on_environment_flag",
|
||||
"check_state_dict_equal",
|
||||
"assert_hf_output_close",
|
||||
"DummyDataloader",
|
||||
]
|
||||
|
@@ -273,3 +273,24 @@ def clear_cache_before_run():
|
||||
return _clear_cache
|
||||
|
||||
return _wrap_func
|
||||
|
||||
|
||||
class DummyDataloader:
|
||||
def __init__(self, data_gen_fn: Callable, length: int = 10):
|
||||
self.data_gen_fn = data_gen_fn
|
||||
self.length = length
|
||||
self.step = 0
|
||||
|
||||
def __iter__(self):
|
||||
self.step = 0
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
if self.step < self.length:
|
||||
self.step += 1
|
||||
return self.data_gen_fn()
|
||||
else:
|
||||
raise StopIteration
|
||||
|
||||
def __len__(self):
|
||||
return self.length
|
||||
|
Reference in New Issue
Block a user