mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-16 22:52:25 +00:00
[workflow] fixed oom tests (#5275)
* [workflow] fixed oom tests * polish * polish * polish
This commit is contained in:
@@ -38,11 +38,11 @@ else:
|
||||
]
|
||||
|
||||
|
||||
@clear_cache_before_run()
|
||||
@parameterize("shard", [True, False])
|
||||
@parameterize("model_name", ["transformers_llama_for_casual_lm"])
|
||||
@parameterize("size_per_shard", [32])
|
||||
@parameterize("test_config", TEST_CONFIGS)
|
||||
@clear_cache_before_run()
|
||||
def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_config: dict):
|
||||
(model_fn, data_gen_fn, output_transform_fn, loss_fn, _) = next(
|
||||
iter(model_zoo.get_sub_registry(model_name).values())
|
||||
@@ -145,3 +145,7 @@ def run_dist(rank, world_size, port):
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_hybrid_ckpIO(world_size):
|
||||
spawn(run_dist, world_size)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_hybrid_ckpIO(4)
|
||||
|
Reference in New Issue
Block a user