[testing] add beit model for unit testings (#2196)

* [testing] add beit model

* [beit] fix bugs

* [beit] fix bugs

* [testing] fix bugs
This commit is contained in:
HELSON
2022-12-26 17:35:36 +08:00
committed by GitHub
parent 5682e6d346
commit a3100bd50d
5 changed files with 58 additions and 7 deletions

View File

@@ -25,7 +25,12 @@ from tests.components_to_test.registry import non_distributed_component_funcs
def run_model_test(init_device_type, shard_strategy_class):
logger = get_dist_logger("test_zero_init")
for get_components_func in non_distributed_component_funcs:
for name, get_components_func in non_distributed_component_funcs._registry.items():
# because the ZeroInitContext automatically turns parameters to fp16
# and the beit model use tensor.erfinv_() function to initialize weights
# tensor.erfinv_() doesn't support Half in CPU, we omit the beit model
if name == 'beit':
continue
model_builder, _, _, _, _ = get_components_func()
if init_device_type == 'cuda':
init_device = get_current_device()
@@ -70,4 +75,4 @@ def test_zero_init_context(world_size):
if __name__ == '__main__':
test_zero_init_context(4)
test_zero_init_context(1)