[workflow] fixed oom tests (#5275)

* [workflow] fixed oom tests

* polish

* polish

* polish
This commit is contained in:
Frank Lee
2024-01-16 18:55:13 +08:00
committed by GitHub
parent 04244aaaf1
commit d69cd2eb89
19 changed files with 50 additions and 582 deletions

View File

@@ -38,11 +38,11 @@ else:
]
@clear_cache_before_run()
@parameterize("shard", [True, False])
@parameterize("model_name", ["transformers_llama_for_casual_lm"])
@parameterize("size_per_shard", [32])
@parameterize("test_config", TEST_CONFIGS)
@clear_cache_before_run()
def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_config: dict):
(model_fn, data_gen_fn, output_transform_fn, loss_fn, _) = next(
iter(model_zoo.get_sub_registry(model_name).values())
@@ -145,3 +145,7 @@ def run_dist(rank, world_size, port):
@rerun_if_address_is_in_use()
def test_hybrid_ckpIO(world_size):
spawn(run_dist, world_size)
if __name__ == "__main__":
test_hybrid_ckpIO(4)