[unit test] Refactored test cases with component func (#339)

* refactored test with component func

* fixed bug
This commit is contained in:
Frank Lee
2022-03-11 14:09:09 +08:00
parent de46450461
commit 526a318032
11 changed files with 148 additions and 420 deletions

View File

@@ -43,7 +43,7 @@ class DummyDataLoader(DummyDataGenerator):
@non_distributed_component_funcs.register(name='nested_model')
def get_training_components():
def model_builder(checkpoint):
def model_builder(checkpoint=True):
return NestedNet(checkpoint)
trainloader = DummyDataLoader()

View File

@@ -3,12 +3,23 @@ from abc import ABC, abstractmethod
class DummyDataGenerator(ABC):
def __init__(self, length=10):
self.length = length
@abstractmethod
def generate(self):
pass
def __iter__(self):
self.step = 0
return self
def __next__(self):
return self.generate()
if self.step < self.length:
self.step += 1
return self.generate()
else:
raise StopIteration
def __len__(self):
return self.length