[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
pre-commit-ci[bot]
2024-06-14 08:05:06 +00:00
parent 1016bb3257
commit 2a7fa2e7d0
7 changed files with 44 additions and 36 deletions
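
The hunks below show long single-line calls being split onto one argument per line with a trailing comma. That is the behaviour of the black formatter when a call no longer fits in the configured line length; the following is a minimal sketch of reproducing the fix locally, assuming black is one of the hooks configured in the repository's .pre-commit-config.yaml (the config itself is not part of this diff).

# Minimal sketch: reproduce the auto-fix with black's Python API.
# Assumes the `black` package is installed; which hooks actually ran on
# pre-commit.ci is determined by the repository's .pre-commit-config.yaml.
import black

src = (
    "row_layer_grads = get_grad_tensors_for_check(\n"
    "    command_model, shard_command_model, row_layer_for_check, tp_group,"
    " atol=atol, rtol=rtol, dim=0, verbose=False\n"
    ")\n"
)

# When the argument line exceeds the line-length limit (black's default is 88
# columns; the repository may configure a different value), black explodes the
# call to one argument per line and adds a trailing comma, as in the diff below.
print(black.format_str(src, mode=black.Mode()))

Running pre-commit run --all-files locally applies the same hooks that pre-commit.ci runs in CI.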


@@ -16,8 +16,6 @@ if HAS_COMMAND:
    # ===============================
    def data_gen():
        input_ids = torch.Tensor(
            [
                [1, 15043, 29892, 590, 11203, 338, 274, 1082, 1, 15043, 29892, 590, 11203, 338, 274, 1082],


@@ -79,10 +79,24 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
         else:
             atol, rtol = 5e-3, 5e-3
         row_layer_grads = get_grad_tensors_for_check(
-            command_model, shard_command_model, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False
+            command_model,
+            shard_command_model,
+            row_layer_for_check,
+            tp_group,
+            atol=atol,
+            rtol=rtol,
+            dim=0,
+            verbose=False,
         )
         col_layer_grads = get_grad_tensors_for_check(
-            command_model, shard_command_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False
+            command_model,
+            shard_command_model,
+            col_layer_for_check,
+            tp_group,
+            atol=atol,
+            rtol=rtol,
+            dim=1,
+            verbose=False,
         )
         norm_layer_grads = get_grad_tensors_for_check(
             command_model,
@@ -121,7 +135,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
         else:
             atol, rtol = 5e-3, 5e-3
         check_weight(
-            command_model, shard_command_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False
+            command_model,
+            shard_command_model,
+            col_layer_for_check,
+            tp_group,
+            atol=atol,
+            rtol=rtol,
+            dim=1,
+            verbose=False,
         )
     # check grads