optimized context test time consumption (#446)
@@ -5,6 +5,7 @@ import pytest
 import torch.multiprocessing as mp
 from colossalai.amp import convert_to_naive_amp
 from tests.components_to_test.registry import non_distributed_component_funcs
+from colossalai.testing import assert_close_loose
 from colossalai.utils import free_port
 from functools import partial
 
@@ -48,7 +49,7 @@ def run_naive_amp():
     # forward pass
     amp_output = amp_model(data)
     torch_output = torch_model(data)
-    assert torch.allclose(amp_output, torch_output, rtol=1e-3, atol=1e-3), f'{amp_output} vs {torch_output}'
+    assert_close_loose(amp_output, torch_output)
 
     # backward
     amp_optimizer.backward(amp_output.mean())
@@ -56,7 +57,7 @@ def run_naive_amp():
 
     # check grad
     for amp_param, torch_param in zip(amp_model.parameters(), torch_model.parameters()):
-        torch.allclose(amp_param.grad, torch_param.grad.half(), rtol=1e-3, atol=1e-3)
+        assert_close_loose(amp_param.grad, torch_param.grad.half())
 
     # step
     amp_optimizer.step()
@@ -64,7 +65,7 @@ def run_naive_amp():
 
     # check updated param
     for amp_param, torch_param in zip(amp_model.parameters(), torch_model.parameters()):
-        torch.allclose(amp_param, torch_param.half(), rtol=1e-3, atol=1e-3)
+        assert_close_loose(amp_param, torch_param.half())
 
 
 def run_dist(rank, world_size, port):
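
Note on the new assertion helper: the diff only shows the call sites, not the definition, so the sketch below is an assumption about how colossalai.testing.assert_close_loose might look — a thin wrapper around torch.allclose using the same loose tolerances (rtol=1e-3, atol=1e-3) as the inline assertions it replaces, failing with both values in the message. Centralizing the tolerances in one helper keeps the AMP tests consistent and shortens each check to a single call.

import torch


def assert_close_loose(a: torch.Tensor, b: torch.Tensor,
                       rtol: float = 1e-3, atol: float = 1e-3) -> None:
    # Hypothetical sketch of the helper used above: compare two tensors with
    # loose tolerances and report both values on mismatch, mirroring the
    # replaced `assert torch.allclose(...)` statements.
    assert torch.allclose(a, b, rtol=rtol, atol=atol), f'{a} vs {b}'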