[shardformer] adapted T5 and LLaMa test to use kit (#4049)

* [shardformer] adapted T5 and LLaMa test to use kit

* polish code
Frank Lee
2023-06-21 09:32:46 +08:00
parent 4021b9a8a2
commit 58df720570
24 changed files with 239 additions and 168 deletions


@@ -28,27 +28,35 @@ class ModelZooRegistry(dict):
model_fn: Callable,
data_gen_fn: Callable,
output_transform_fn: Callable,
loss_fn: Callable = None,
model_attribute: ModelAttribute = None):
"""
Register a model and data generation function.
Examples:
>>> # Register
>>> model_zoo = ModelZooRegistry()
>>> model_zoo.register('resnet18', resnet18, resnet18_data_gen)
>>> # Run the model
>>> data = resnet18_data_gen() # do not input any argument
>>> model = resnet18() # do not input any argument
>>> out = model(**data)
```python
# normal forward workflow
model = resnet18()
data = resnet18_data_gen()
output = model(**data)
transformed_output = output_transform_fn(output)
loss = loss_fn(transformed_output)
# Register
model_zoo = ModelZooRegistry()
model_zoo.register('resnet18', resnet18, resnet18_data_gen, output_transform_fn, loss_fn)
```
Args:
name (str): Name of the model.
model_fn (callable): A function that returns a model. **It must not contain any arguments.**
output_transform_fn (callable): A function that transforms the output of the model into Dict.
data_gen_fn (callable): A function that returns a data sample in the form of Dict. **It must not contain any arguments.**
model_fn (Callable): A function that returns a model. **It must not contain any arguments.**
data_gen_fn (Callable): A function that returns a data sample in the form of Dict. **It must not contain any arguments.**
output_transform_fn (Callable): A function that transforms the output of the model into Dict.
loss_fn (Callable): A function that computes the loss from the given output. Defaults to None.
model_attribute (ModelAttribute): Attributes of the model. Defaults to None.
"""
self[name] = (model_fn, data_gen_fn, output_transform_fn, model_attribute)
self[name] = (model_fn, data_gen_fn, output_transform_fn, loss_fn, model_attribute)
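Reviewer note: the stored entry grows from a 4-tuple to a 5-tuple, so any caller that unpacks a registry entry needs to unpack `loss_fn` as well. The following is a minimal sketch of the new round trip, not the project's test code. The registry is re-declared in simplified form, and `toy_model_fn`, `toy_data_gen_fn`, `toy_output_transform_fn`, and `toy_loss_fn` are hypothetical stand-ins that use plain Python instead of a real model.

```python
from typing import Callable, Optional


class ModelZooRegistry(dict):
    """Simplified re-declaration of the registry, for illustration only."""

    def register(self,
                 name: str,
                 model_fn: Callable,
                 data_gen_fn: Callable,
                 output_transform_fn: Callable,
                 loss_fn: Optional[Callable] = None,
                 model_attribute=None):
        # Same 5-tuple layout as the new line in this diff.
        self[name] = (model_fn, data_gen_fn, output_transform_fn, loss_fn, model_attribute)


# Hypothetical stand-ins (assumptions, not part of the commit).
def toy_model_fn():
    # Takes no arguments and returns a callable "model".
    return lambda x: 2 * x


def toy_data_gen_fn():
    # Takes no arguments and returns the inputs as a dict.
    return {"x": 3}


def toy_output_transform_fn(output):
    # Wraps the raw output in a dict.
    return {"out": output}


def toy_loss_fn(transformed_output):
    # Reduces the transformed output to a scalar.
    return transformed_output["out"]


model_zoo = ModelZooRegistry()
model_zoo.register("toy", toy_model_fn, toy_data_gen_fn, toy_output_transform_fn, toy_loss_fn)

# Callers now unpack five items instead of four.
model_fn, data_gen_fn, output_transform_fn, loss_fn, model_attribute = model_zoo["toy"]

# Normal forward workflow, as described in the docstring.
model = model_fn()
data = data_gen_fn()
output = model(**data)
transformed_output = output_transform_fn(output)
loss = loss_fn(transformed_output)
```

Because `loss_fn` defaults to `None`, a caller that iterates over every registered model would probably want to check for `None` before calling it.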
def get_sub_registry(self, keyword: str):
"""