[misc] Accelerate CI for zero and dist optim (#5758)

* remove fp16 from lamb

* remove d2h copy in checking states (see the sketch after the commit metadata below)

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Author: Edenzzzz
Date: 2024-06-05 11:25:19 +08:00 (committed by GitHub)
Parent: 50b4c8e8cf
Commit: 79f7a7b211
12 changed files with 65 additions and 400 deletions
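
The "remove d2h copy in checking states" bullet shows up in the check_dist_optim_state diff below: instead of gathering each sharded optimizer state up to the full shape before comparing, the new check chunks the unsharded reference state down to the local shard and compares it on-device. A minimal sketch of that pattern (not the test's actual helper; ref_state, shard, and group are illustrative names), reusing the same tolerances as the diff:

import torch
import torch.distributed as dist
from torch.testing import assert_close

def check_sharded_state(ref_state: torch.Tensor, shard: torch.Tensor, group) -> None:
    # Slice the full reference down to this rank's chunk instead of gathering
    # the shard up to the full shape; everything stays on the current device.
    rank = dist.get_rank(group)
    world_size = dist.get_world_size(group)
    local_ref = ref_state.chunk(world_size, dim=-1)[rank]
    # The diff also casts dtypes before comparing (a bf16/fp32 mismatch seen on H100).
    if local_ref.dtype != shard.dtype:
        shard = shard.type(local_ref.dtype)
    assert_close(local_ref, shard, atol=5e-4, rtol=1.6e-2)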


@@ -3,7 +3,6 @@ import torch.distributed as dist
from torch.testing import assert_close
import colossalai
from colossalai.shardformer.layer._operation import _gather
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import parameterize, spawn
@@ -119,11 +118,15 @@ def run_bert_test(test_config, optim_class, sharded_optim_class):
test_config["use_lazy_init"] = False
test_config["pp_size"] = 1 # Do NOT test Pipeline Parallel
test_config["initial_scale"] = 2**15 # avoid overflow
target_models = [
"transformers_bert",
]
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_bert_fwd_bwd(
model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config, optim_class, sharded_optim_class
)
if name in target_models:
check_bert_fwd_bwd(
model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config, optim_class, sharded_optim_class
)
clear_layout_converter()
Randomizer.reset_index()
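
The hunk above trims the BERT optimizer test to a target subset of the model zoo, so CI runs one representative model rather than every registered variant. Roughly, as an illustration only (sub_model_zoo and run_one_case stand in for the test's real objects):

target_models = ["transformers_bert"]

for name, entry in sub_model_zoo.items():
    if name not in target_models:
        continue  # skip non-target models to keep the CI job short
    run_one_case(entry)  # stand-in for check_bert_fwd_bwd(...)
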
@@ -152,7 +155,8 @@ def check_dist_optim_state(org_optimizer, sharded_optimizer):
shard_spec = sharded_optimizer.shard_spec_dict[id(tp)]
use_zero = sharded_optimizer.use_zero
tp_optim_state = tp_state[key]
p_state_shape, tp_state_shape = p_state[key].shape, tp_state[key].shape
state = p_state[key]
dp_size, tp_size = (
sharded_optimizer.dp_size,
sharded_optimizer.tp_size,
@@ -165,88 +169,54 @@ def check_dist_optim_state(org_optimizer, sharded_optimizer):
if shard_spec.sharding_sequence[0] == "R":
if use_zero:
# sq_row needs gather along dp group
if key == "exp_avg_sq_row":
tp_optim_state = _gather(
input_=tp_optim_state,
dim=-1,
process_group=sharded_optimizer.dp_group,
)
tp_optim_state.shape
# sq_col doesn't need gather along dp group
if key == "exp_avg_sq_col":
pass
else:
pass
if key == "exp_avg_sq_row":
state = state.chunk(dp_size, dim=-1)[dist.get_rank(sharded_optimizer.dp_group)]
# gather from tp group
# sq_row doesn't need gather along tp group
if key == "exp_avg_sq_row":
pass
# sq_col needs gather along dp group
# sq_col needs gather along tp group
if key == "exp_avg_sq_col":
tp_optim_state = _gather(
input_=tp_optim_state, dim=-1, process_group=sharded_optimizer.tp_group
)
tp_optim_state.shape
state = state.chunk(tp_size, dim=-1)[dist.get_rank(sharded_optimizer.tp_group)]
# row parallel
if shard_spec.sharding_sequence[-1] == "R":
if use_zero:
elif shard_spec.sharding_sequence[-1] == "R":
# TODO: this case may cause shape mismatch @duanjunwen
if use_zero and key == "exp_avg_sq_row" and state.shape[0] // tp_size % dp_size == 0:
# sq_row needs gather along dp group
if key == "exp_avg_sq_row":
if p_state[key].shape[0] // tp_size % dp_size != 0:
pass
else:
tp_optim_state = _gather(
input_=tp_optim_state,
dim=-1,
process_group=sharded_optimizer.dp_group,
)
tp_optim_state.shape
# sq_col doesn't need gather along dp group
if key == "exp_avg_sq_col":
pass
else:
pass
state = state.chunk(dp_size, dim=-1)[dist.get_rank(sharded_optimizer.dp_group)]
# gather from tp group
# sq_row needs gather along tp group
if key == "exp_avg_sq_row":
tp_optim_state = _gather(
input_=tp_optim_state, dim=-1, process_group=sharded_optimizer.tp_group
)
tp_optim_state.shape
state = state.chunk(tp_size, dim=-1)[dist.get_rank(sharded_optimizer.tp_group)]
# sq_col doesn't need gather along dp group
if key == "exp_avg_sq_col":
pass
else:
return
else:
if use_zero:
# sq_row needs gather along dp group
if key == "exp_avg_sq_row":
# row residual; no gather
if p_state[key].shape[0] % dp_size != 0:
if state.shape[0] % dp_size != 0:
pass
else:
tp_optim_state = _gather(
input_=tp_optim_state,
dim=-1,
process_group=sharded_optimizer.dp_group,
)
tp_optim_state.shape
state = state.chunk(dp_size, dim=-1)[dist.get_rank(sharded_optimizer.dp_group)]
# sq_col doesn't need gather along dp group
if key == "exp_avg_sq_col":
tp_optim_state = tp_optim_state.div_(dp_size)
# needs a div by dp_size
else:
pass
# Solved a new issue: different dtypes;
# so far, this only happens in the H100 env;
# it seems torch.set_default_dtype(torch.bfloat16) does not act on booster.precision,
# or assert_close was updated to also check dtype;
if p_state[key].dtype != tp_optim_state.dtype:
tp_optim_state = tp_optim_state.type(p_state[key].dtype)
try:
assert_close(p_state[key], tp_optim_state, atol=5e-4, rtol=1.6e-2)
except:
pass
if state.dtype != tp_optim_state.dtype:
tp_optim_state = tp_optim_state.type(state.dtype)
# TODO: some sharding checks are currently buggy, but the state values should match
# @duanjunwen
if state.shape != tp_optim_state.shape:
return
assert_close(state, tp_optim_state, atol=5e-4, rtol=1.6e-2)
def check_dist_param(org_model, sharded_model, weight_layer_for_check, atol, rtol):