mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-17 23:18:36 +00:00
[unit test] Refactored test cases with component func (#339)
* refactored test with component func * fixed bug
This commit is contained in:
@@ -43,7 +43,7 @@ class DummyDataLoader(DummyDataGenerator):
|
||||
@non_distributed_component_funcs.register(name='nested_model')
|
||||
def get_training_components():
|
||||
|
||||
def model_builder(checkpoint):
|
||||
def model_builder(checkpoint=True):
|
||||
return NestedNet(checkpoint)
|
||||
|
||||
trainloader = DummyDataLoader()
|
||||
|
@@ -3,12 +3,23 @@ from abc import ABC, abstractmethod
|
||||
|
||||
class DummyDataGenerator(ABC):
|
||||
|
||||
def __init__(self, length=10):
|
||||
self.length = length
|
||||
|
||||
@abstractmethod
|
||||
def generate(self):
|
||||
pass
|
||||
|
||||
def __iter__(self):
|
||||
self.step = 0
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
return self.generate()
|
||||
if self.step < self.length:
|
||||
self.step += 1
|
||||
return self.generate()
|
||||
else:
|
||||
raise StopIteration
|
||||
|
||||
def __len__(self):
|
||||
return self.length
|
||||
|
Reference in New Issue
Block a user