Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-10 05:20:33 +00:00)
[misc] Accelerate CI for zero and dist optim (#5758)

* remove fp16 from lamb
* remove d2h copy in checking states

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
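In the last hunk below, the "d2h copy" cleanup shows up as `_gather` calls over the data-parallel group being replaced by chunking the already-full reference state down to this rank's shard before comparing. A minimal standalone sketch of that shard-to-shard comparison pattern; the function and tensor names here are hypothetical, not the repo's API:

```python
# Sketch only: compare a rank's optimizer-state shard against the matching
# slice of the full reference tensor, instead of all-gathering the shard.
# `ref_state`, `local_shard`, and `group` are illustrative names.
import torch
import torch.distributed as dist
from torch.testing import assert_close


def check_shard_against_reference(ref_state: torch.Tensor, local_shard: torch.Tensor, group) -> None:
    rank = dist.get_rank(group)
    world_size = dist.get_world_size(group)
    # Slice the already-full reference along the sharded dim for this rank.
    ref_slice = ref_state.chunk(world_size, dim=-1)[rank]
    # Mixed-precision runs can leave the shard in a different dtype; align first.
    if local_shard.dtype != ref_slice.dtype:
        local_shard = local_shard.type(ref_slice.dtype)
    assert_close(ref_slice, local_shard, atol=5e-4, rtol=1.6e-2)
```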
```diff
@@ -3,7 +3,6 @@ import torch.distributed as dist
 from torch.testing import assert_close
 
 import colossalai
 from colossalai.shardformer.layer._operation import _gather
 from colossalai.shardformer.layer.utils import Randomizer
 from colossalai.tensor.d_tensor.api import clear_layout_converter
 from colossalai.testing import parameterize, spawn
```
```diff
@@ -119,11 +118,15 @@ def run_bert_test(test_config, optim_class, sharded_optim_class):
     test_config["use_lazy_init"] = False
     test_config["pp_size"] = 1  # Do NOT test Pipeline Parallel
     test_config["initial_scale"] = 2**15  # avoid overflow
+    target_models = [
+        "transformers_bert",
+    ]
 
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        check_bert_fwd_bwd(
-            model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config, optim_class, sharded_optim_class
-        )
+        if name in target_models:
+            check_bert_fwd_bwd(
+                model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config, optim_class, sharded_optim_class
+            )
 
     clear_layout_converter()
     Randomizer.reset_index()
```
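The hunk above cuts CI time by running the BERT forward/backward check only for the entries named in `target_models`, rather than every variant in the sub-zoo. The same filtering pattern in isolation, with illustrative registry contents:

```python
# Sketch of the CI-narrowing pattern: iterate a registry but only exercise
# entries listed in target_models. The registry contents are made up.
sub_model_zoo = {
    "transformers_bert": "bert base case",
    "transformers_bert_for_sequence_classification": "seq-cls head",
}
target_models = ["transformers_bert"]

for name, entry in sub_model_zoo.items():
    if name not in target_models:
        continue  # skip extra task heads to keep the CI run short
    print(f"running {name}: {entry}")
```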
```diff
@@ -152,7 +155,8 @@ def check_dist_optim_state(org_optimizer, sharded_optimizer):
                     shard_spec = sharded_optimizer.shard_spec_dict[id(tp)]
                     use_zero = sharded_optimizer.use_zero
                     tp_optim_state = tp_state[key]
-                    p_state_shape, tp_state_shape = p_state[key].shape, tp_state[key].shape
+                    state = p_state[key]
+
                     dp_size, tp_size = (
                         sharded_optimizer.dp_size,
                         sharded_optimizer.tp_size,
```
```diff
@@ -165,88 +169,54 @@ def check_dist_optim_state(org_optimizer, sharded_optimizer):
                         if shard_spec.sharding_sequence[0] == "R":
                             if use_zero:
                                 # sq_row need gather alone dp group
                                 if key == "exp_avg_sq_row":
-                                    tp_optim_state = _gather(
-                                        input_=tp_optim_state,
-                                        dim=-1,
-                                        process_group=sharded_optimizer.dp_group,
-                                    )
-                                    tp_optim_state.shape
+                                    state = state.chunk(dp_size, dim=-1)[dist.get_rank(sharded_optimizer.dp_group)]
                                 # sq_col don't need gather alone dp group
                                 if key == "exp_avg_sq_col":
                                     pass
                             else:
                                 pass
 
                             # gather from tp group
                             # sq_row don need gather alone tp group
                             if key == "exp_avg_sq_row":
                                 pass
-                            # sq_col need gather alone dp group
+                            # sq_col need gather alone tp group
                             if key == "exp_avg_sq_col":
                                 tp_optim_state = _gather(
                                     input_=tp_optim_state, dim=-1, process_group=sharded_optimizer.tp_group
                                 )
-                                tp_optim_state.shape
+                                state = state.chunk(tp_size, dim=-1)[dist.get_rank(sharded_optimizer.tp_group)]
                         # row parallel
-                        if shard_spec.sharding_sequence[-1] == "R":
-                            if use_zero:
-                                # sq_row need gather alone dp group
-                                if key == "exp_avg_sq_row":
-                                    if p_state[key].shape[0] // tp_size % dp_size != 0:
-                                        pass
-                                    else:
-                                        tp_optim_state = _gather(
-                                            input_=tp_optim_state,
-                                            dim=-1,
-                                            process_group=sharded_optimizer.dp_group,
-                                        )
-                                        tp_optim_state.shape
-                                # sq_col don't need gather alone dp group
-                                if key == "exp_avg_sq_col":
-                                    pass
-                            else:
-                                pass
+                        elif shard_spec.sharding_sequence[-1] == "R":
+                            # TODO: this case may cause shape mismatch @duanjunwen
+                            if use_zero and key == "exp_avg_sq_row" and state.shape[0] // tp_size % dp_size == 0:
+                                # sq_row need gather alone dp group
+                                state = state.chunk(dp_size, dim=-1)[dist.get_rank(sharded_optimizer.dp_group)]
 
                             # gather from tp group
                             # sq_row need gather alone tp group
                             if key == "exp_avg_sq_row":
                                 tp_optim_state = _gather(
                                     input_=tp_optim_state, dim=-1, process_group=sharded_optimizer.tp_group
                                 )
-                                tp_optim_state.shape
+                                state = state.chunk(tp_size, dim=-1)[dist.get_rank(sharded_optimizer.tp_group)]
                             # sq_col don't need gather alone dp group
                             if key == "exp_avg_sq_col":
                                 pass
+                        else:
+                            return
                     else:
                         if use_zero:
                             # sq_row need gather alone dp group
                             if key == "exp_avg_sq_row":
                                 # row residule; no gather
-                                if p_state[key].shape[0] % dp_size != 0:
+                                if state.shape[0] % dp_size != 0:
                                     pass
                                 else:
-                                    tp_optim_state = _gather(
-                                        input_=tp_optim_state,
-                                        dim=-1,
-                                        process_group=sharded_optimizer.dp_group,
-                                    )
-                                    tp_optim_state.shape
+                                    state = state.chunk(dp_size, dim=-1)[dist.get_rank(sharded_optimizer.dp_group)]
                             # sq_col don't need gather alone dp group
                             if key == "exp_avg_sq_col":
                                 tp_optim_state = tp_optim_state.div_(dp_size)
                                 # need a div;
                         else:
                             pass
-                    # Sovled a New issus: different dtype;
-                    # So far, only happen in H100 env;
-                    # Seem torch.set_default_dtype(torch.bfloat16) not act on booster.percision;
-                    # Or assert_close just update to check dtype;
-                    if p_state[key].dtype != tp_optim_state.dtype:
-                        tp_optim_state = tp_optim_state.type(p_state[key].dtype)
-                    try:
-                        assert_close(p_state[key], tp_optim_state, atol=5e-4, rtol=1.6e-2)
-                    except:
-                        pass
+
+                    if state.dtype != tp_optim_state.dtype:
+                        tp_optim_state = tp_optim_state.type(state.dtype)
+                    # TODO: some sharding checks are currently buggy, but the state values should match
+                    # @duanjunwen
+                    if state.shape != tp_optim_state.shape:
+                        return
+                    assert_close(state, tp_optim_state, atol=5e-4, rtol=1.6e-2)
 
 
 def check_dist_param(org_model, sharded_model, weight_layer_for_check, atol, rtol):
```
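The rewritten tail of the check aligns dtypes explicitly, bails out on known shape mismatches, and asserts values with fixed tolerances instead of swallowing failures in a bare try/except. A tiny runnable demonstration of `torch.testing.assert_close` with the same bounds; the values are made up to sit just inside the tolerance:

```python
import torch
from torch.testing import assert_close

# atol=5e-4, rtol=1.6e-2 tolerates bf16-scale accumulation noise:
# allowed |a - b| is roughly atol + rtol * |b|, about 0.0166 for values near 1.0.
a = torch.full((4,), 1.000, dtype=torch.bfloat16)
b = torch.full((4,), 1.008, dtype=torch.bfloat16)
assert_close(a, b, atol=5e-4, rtol=1.6e-2)  # passes: diff ~0.0078 < ~0.0166
```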