optimized context test time consumption (#446)

This commit is contained in:
Frank Lee
2022-03-17 14:40:52 +08:00
committed by GitHub
parent 496cbb0760
commit b72b8445c6
8 changed files with 169 additions and 357 deletions

View File

@@ -5,6 +5,7 @@ import pytest
import torch.multiprocessing as mp
from colossalai.amp import convert_to_naive_amp
from tests.components_to_test.registry import non_distributed_component_funcs
from colossalai.testing import assert_close_loose
from colossalai.utils import free_port
from functools import partial
@@ -48,7 +49,7 @@ def run_naive_amp():
# forward pass
amp_output = amp_model(data)
torch_output = torch_model(data)
assert torch.allclose(amp_output, torch_output, rtol=1e-3, atol=1e-3), f'{amp_output} vs {torch_output}'
assert_close_loose(amp_output, torch_output)
# backward
amp_optimizer.backward(amp_output.mean())
@@ -56,7 +57,7 @@ def run_naive_amp():
# check grad
for amp_param, torch_param in zip(amp_model.parameters(), torch_model.parameters()):
torch.allclose(amp_param.grad, torch_param.grad.half(), rtol=1e-3, atol=1e-3)
assert_close_loose(amp_param.grad, torch_param.grad.half())
# step
amp_optimizer.step()
@@ -64,7 +65,7 @@ def run_naive_amp():
# check updated param
for amp_param, torch_param in zip(amp_model.parameters(), torch_model.parameters()):
torch.allclose(amp_param, torch_param.half(), rtol=1e-3, atol=1e-3)
assert_close_loose(amp_param, torch_param.half())
def run_dist(rank, world_size, port):