mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-03 10:06:44 +00:00
fix precommit
This commit is contained in:
@@ -16,8 +16,6 @@ if HAS_COMMAND:
|
||||
# ===============================
|
||||
|
||||
def data_gen():
|
||||
|
||||
|
||||
input_ids = torch.Tensor(
|
||||
[
|
||||
[1, 15043, 29892, 590, 11203, 338, 274, 1082, 1, 15043, 29892, 590, 11203, 338, 274, 1082],
|
||||
|
@@ -79,10 +79,24 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
||||
else:
|
||||
atol, rtol = 5e-3, 5e-3
|
||||
row_layer_grads = get_grad_tensors_for_check(
|
||||
command_model, shard_command_model, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False
|
||||
command_model,
|
||||
shard_command_model,
|
||||
row_layer_for_check,
|
||||
tp_group,
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
dim=0,
|
||||
verbose=False,
|
||||
)
|
||||
col_layer_grads = get_grad_tensors_for_check(
|
||||
command_model, shard_command_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False
|
||||
command_model,
|
||||
shard_command_model,
|
||||
col_layer_for_check,
|
||||
tp_group,
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
dim=1,
|
||||
verbose=False,
|
||||
)
|
||||
norm_layer_grads = get_grad_tensors_for_check(
|
||||
command_model,
|
||||
@@ -121,7 +135,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
||||
else:
|
||||
atol, rtol = 5e-3, 5e-3
|
||||
check_weight(
|
||||
command_model, shard_command_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False
|
||||
command_model,
|
||||
shard_command_model,
|
||||
col_layer_for_check,
|
||||
tp_group,
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
dim=1,
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
# check grads
|
||||
|
Reference in New Issue
Block a user