mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 20:40:34 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable
|
||||
|
||||
__all__ = ['ModelZooRegistry', 'ModelAttribute', 'model_zoo']
|
||||
__all__ = ["ModelZooRegistry", "ModelAttribute", "model_zoo"]
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -14,6 +14,7 @@ class ModelAttribute:
|
||||
has_control_flow (bool): Whether the model contains branching in its forward method.
|
||||
has_stochastic_depth_prob (bool): Whether the model contains stochastic depth probability. Often seen in the torchvision models.
|
||||
"""
|
||||
|
||||
has_control_flow: bool = False
|
||||
has_stochastic_depth_prob: bool = False
|
||||
|
||||
@@ -23,13 +24,15 @@ class ModelZooRegistry(dict):
|
||||
A registry to map model names to model and data generation functions.
|
||||
"""
|
||||
|
||||
def register(self,
|
||||
name: str,
|
||||
model_fn: Callable,
|
||||
data_gen_fn: Callable,
|
||||
output_transform_fn: Callable,
|
||||
loss_fn: Callable = None,
|
||||
model_attribute: ModelAttribute = None):
|
||||
def register(
|
||||
self,
|
||||
name: str,
|
||||
model_fn: Callable,
|
||||
data_gen_fn: Callable,
|
||||
output_transform_fn: Callable,
|
||||
loss_fn: Callable = None,
|
||||
model_attribute: ModelAttribute = None,
|
||||
):
|
||||
"""
|
||||
Register a model and data generation function.
|
||||
|
||||
@@ -71,7 +74,7 @@ class ModelZooRegistry(dict):
|
||||
if keyword in k:
|
||||
new_dict[k] = v
|
||||
|
||||
assert len(new_dict) > 0, f'No model found with keyword {keyword}'
|
||||
assert len(new_dict) > 0, f"No model found with keyword {keyword}"
|
||||
return new_dict
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user