diff --git a/tests/test_gemini/test_runtime_mem_tracer.py b/tests/test_gemini/test_runtime_mem_tracer.py index 2806b8cb0..ff55ac54d 100644 --- a/tests/test_gemini/test_runtime_mem_tracer.py +++ b/tests/test_gemini/test_runtime_mem_tracer.py @@ -10,17 +10,6 @@ from tests.components_to_test import run_fwd_bwd from tests.components_to_test.registry import non_distributed_component_funcs -def run_fwd_bwd(model, data, label, criterion, enable_autocast=False, dtype=torch.half): - with torch.cuda.amp.autocast(enabled=enable_autocast): - if criterion: - y = model(data) - loss = criterion(y, label) - else: - loss = model(data, label) - loss = loss.to(dtype) - model.backward(loss) - - def test_runtime_mem_tracer(): test_models = ['gpt2', 'bert', 'simple_net', 'repeated_computed_layers', 'nested_model', 'albert'] @@ -28,7 +17,7 @@ def test_runtime_mem_tracer(): get_components_func = non_distributed_component_funcs.get_callable(model_name) model_builder, train_dataloader, _, _, criterion = get_components_func() - with ColoInitContext(device=torch.device('cpu')): + with ColoInitContext(device='cpu'): model = model_builder(checkpoint=False) model_bk = deepcopy(model) @@ -40,7 +29,7 @@ def test_runtime_mem_tracer(): data = data.cuda() label = label.cuda() - run_fwd_bwd(runtime_mem_tracer, data, label, criterion, False) + run_fwd_bwd(runtime_mem_tracer, data, label, criterion, optimizer=runtime_mem_tracer) for p1, p2 in zip(model_bk.parameters(), model.parameters()): torch.allclose(p1.to(torch.half), p2)