mirror of https://github.com/hpcaitech/ColossalAI.git
[testing] add beit model for unit testings (#2196)
* [testing] add beit model
* [beit] fix bugs
* [beit] fix bugs
* [testing] fix bugs
@@ -26,7 +26,7 @@ from tests.test_tensor.common_utils import debug_print, set_seed
 # this model is large enough to slice to chunks
 TEST_MODELS = ['gpt2']
 # these models are too small, all parameters in these models are compacted into one chunk
-EXAMPLE_MODELS = ['albert', 'hanging_param_model', 'bert', 'simple_net', 'nested_model', 'repeated_computed_layers']
+EXAMPLE_MODELS = ['albert', 'beit', 'bert', 'hanging_param_model', 'nested_model', 'repeated_computed_layers']


 def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
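The 'beit' entry added above has to resolve to a model builder in the test suite's example-model zoo. The sketch below is a hypothetical version of such a component, assuming timm's Beit class and a builder/data-generator shape modeled on the other EXAMPLE_MODELS; the function names, hyperparameters, and batch shapes are illustrative assumptions, not code from this PR.

# Hypothetical sketch of a 'beit' test component; not taken from this PR.
import torch
from timm.models.beit import Beit

def beit_model_builder(checkpoint: bool = False):
    # Deliberately tiny, so that all parameters are compacted into one
    # chunk, matching the comment above EXAMPLE_MODELS.
    return Beit(img_size=64, patch_size=16, embed_dim=32,
                depth=2, num_heads=4, num_classes=10)

def beit_data_gen(device='cpu'):
    # One dummy image batch plus integer class labels for the head.
    data = torch.rand(4, 3, 64, 64, device=device)
    label = torch.randint(low=0, high=10, size=(4,), device=device)
    return data, label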
@@ -142,7 +142,7 @@ def exam_tiny_example(placement_policy, model_name: str):

     torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim)
     loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim)
-    assert_close(torch_loss, loss)
+    assert_close(torch_loss, loss, rtol=1.5e-6, atol=2e-5)    # atol should be 2e-5 for torch lower than 1.12

     zero_optim.step()
     torch_optim.step()
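For reference, torch.testing.assert_close treats two tensors as close when |actual - expected| <= atol + rtol * |expected|, so the explicit tolerances above loosen the float32 defaults (rtol=1.3e-6, atol=1e-5) just enough to absorb the cross-version numeric drift the inline comment mentions. A self-contained illustration with made-up values:

# Illustration of the tolerance rule behind the changed assertion;
# the tensor values are made up for demonstration.
import torch
from torch.testing import assert_close

expected = torch.tensor(1.0)
actual = torch.tensor(1.000015)    # off by 1.5e-5

# Passes: 1.5e-5 <= atol + rtol * |expected| = 2e-5 + 1.5e-6 * 1.0
assert_close(actual, expected, rtol=1.5e-6, atol=2e-5)

# The float32 defaults (rtol=1.3e-6, atol=1e-5) would reject this pair,
# which is why the test pins explicit tolerances.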