mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-02 09:38:05 +00:00
[testing] add beit model for unit testings (#2196)
* [testing] add beit model * [beit] fix bugs * [beit] fix bugs * [testing] fix bugs
This commit is contained in:
@@ -25,7 +25,12 @@ from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
def run_model_test(init_device_type, shard_strategy_class):
|
||||
logger = get_dist_logger("test_zero_init")
|
||||
|
||||
for get_components_func in non_distributed_component_funcs:
|
||||
for name, get_components_func in non_distributed_component_funcs._registry.items():
|
||||
# because the ZeroInitContext automatically turns parameters to fp16
|
||||
# and the beit model use tensor.erfinv_() function to initialize weights
|
||||
# tensor.erfinv_() doesn't support Half in CPU, we omit the beit model
|
||||
if name == 'beit':
|
||||
continue
|
||||
model_builder, _, _, _, _ = get_components_func()
|
||||
if init_device_type == 'cuda':
|
||||
init_device = get_current_device()
|
||||
@@ -70,4 +75,4 @@ def test_zero_init_context(world_size):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_zero_init_context(4)
|
||||
test_zero_init_context(1)
|
||||
|
Reference in New Issue
Block a user