mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 04:24:47 +00:00
[shardformer] adapted T5 and LLaMa test to use kit (#4049)
* [shardformer] adapted T5 and LLaMa test to use kit * polish code
This commit is contained in:
@@ -28,27 +28,35 @@ class ModelZooRegistry(dict):
|
||||
model_fn: Callable,
|
||||
data_gen_fn: Callable,
|
||||
output_transform_fn: Callable,
|
||||
loss_fn: Callable = None,
|
||||
model_attribute: ModelAttribute = None):
|
||||
"""
|
||||
Register a model and data generation function.
|
||||
|
||||
Examples:
|
||||
>>> # Register
|
||||
>>> model_zoo = ModelZooRegistry()
|
||||
>>> model_zoo.register('resnet18', resnet18, resnet18_data_gen)
|
||||
>>> # Run the model
|
||||
>>> data = resnet18_data_gen() # do not input any argument
|
||||
>>> model = resnet18() # do not input any argument
|
||||
>>> out = model(**data)
|
||||
|
||||
```python
|
||||
# normal forward workflow
|
||||
model = resnet18()
|
||||
data = resnet18_data_gen()
|
||||
output = model(**data)
|
||||
transformed_output = output_transform_fn(output)
|
||||
loss = loss_fn(transformed_output)
|
||||
|
||||
# Register
|
||||
model_zoo = ModelZooRegistry()
|
||||
model_zoo.register('resnet18', resnet18, resnet18_data_gen, output_transform_fn, loss_fn)
|
||||
```
|
||||
|
||||
Args:
|
||||
name (str): Name of the model.
|
||||
model_fn (callable): A function that returns a model. **It must not contain any arguments.**
|
||||
output_transform_fn (callable): A function that transforms the output of the model into Dict.
|
||||
data_gen_fn (callable): A function that returns a data sample in the form of Dict. **It must not contain any arguments.**
|
||||
model_fn (Callable): A function that returns a model. **It must not contain any arguments.**
|
||||
data_gen_fn (Callable): A function that returns a data sample in the form of Dict. **It must not contain any arguments.**
|
||||
output_transform_fn (Callable): A function that transforms the output of the model into Dict.
|
||||
loss_fn (Callable): A function to compute the loss from the given output. Defaults to None.
|
||||
model_attribute (ModelAttribute): Attributes of the model. Defaults to None.
|
||||
"""
|
||||
self[name] = (model_fn, data_gen_fn, output_transform_fn, model_attribute)
|
||||
self[name] = (model_fn, data_gen_fn, output_transform_fn, loss_fn, model_attribute)
|
||||
|
||||
def get_sub_registry(self, keyword: str):
|
||||
"""
|
||||
|
Reference in New Issue
Block a user