add bert for unitest and sharded model is not able to pass the bert case

This commit is contained in:
jiaruifang
2022-03-09 10:39:02 +08:00
committed by ver217
parent e35f66b159
commit d984dff88c
6 changed files with 104 additions and 14 deletions

View File

@@ -75,7 +75,7 @@ def check_grads_padding(model, zero_model, loose=False):
if zero_grad.size(0) > grad.size(0):
zero_grad = zero_grad[:grad.size(0)]
assert grad.dtype == zero_grad.dtype
assert allclose(grad, zero_grad, loose=loose), f'{grad} vs {zero_grad}'
assert allclose(grad, zero_grad, loose=loose), f'diff: {grad - zero_grad}'
def check_params_padding(model, zero_model, loose=False):