Mirror of https://github.com/hpcaitech/ColossalAI.git
Synced 2025-09-08 12:30:42 +00:00
[misc] Accelerate CI for zero and dist optim (#5758)
* remove fp16 from lamb
* remove d2h copy in checking states

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
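The second bullet is the main CI win: the state checks now compare tensors where they already live instead of staging them through host memory. A minimal sketch of that idea, with illustrative names rather than the test suite's actual helper:

```python
# Minimal sketch with illustrative names; assumes both optimizers hold CUDA tensors.
import torch
from torch.testing import assert_close


def check_states_on_device(org_optimizer, sharded_states, atol=5e-4, rtol=1.6e-2):
    """Compare optimizer states without a device-to-host copy.

    Calling .cpu() on every state tensor forces a synchronizing D2H transfer per tensor;
    assert_close accepts CUDA tensors directly, so the comparison can stay on device.
    """
    for param, state in org_optimizer.state.items():
        for key, value in state.items():
            if not torch.is_tensor(value):
                continue  # skip non-tensor entries
            other = sharded_states[param][key]
            # Align dtypes first: assert_close checks dtype as well as values.
            if other.dtype != value.dtype:
                other = other.to(value.dtype)
            assert_close(value, other, atol=atol, rtol=rtol)
```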
@@ -3,7 +3,6 @@ import torch.distributed as dist
from torch.testing import assert_close

import colossalai
from colossalai.shardformer.layer._operation import _gather
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import parameterize, spawn
@@ -119,11 +118,15 @@ def run_bert_test(test_config, optim_class, sharded_optim_class):
    test_config["use_lazy_init"] = False
    test_config["pp_size"] = 1  # Do NOT test Pipeline Parallel
    test_config["initial_scale"] = 2**15  # avoid overflow
    target_models = [
        "transformers_bert",
    ]

    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
        check_bert_fwd_bwd(
            model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config, optim_class, sharded_optim_class
        )
        if name in target_models:
            check_bert_fwd_bwd(
                model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config, optim_class, sharded_optim_class
            )

    clear_layout_converter()
    Randomizer.reset_index()
@@ -152,7 +155,8 @@ def check_dist_optim_state(org_optimizer, sharded_optimizer):
        shard_spec = sharded_optimizer.shard_spec_dict[id(tp)]
        use_zero = sharded_optimizer.use_zero
        tp_optim_state = tp_state[key]
        p_state_shape, tp_state_shape = p_state[key].shape, tp_state[key].shape
        state = p_state[key]

        dp_size, tp_size = (
            sharded_optimizer.dp_size,
            sharded_optimizer.tp_size,
@@ -165,88 +169,54 @@ def check_dist_optim_state(org_optimizer, sharded_optimizer):
        if shard_spec.sharding_sequence[0] == "R":
            if use_zero:
                # sq_row needs gather along the dp group
                if key == "exp_avg_sq_row":
                    tp_optim_state = _gather(
                        input_=tp_optim_state,
                        dim=-1,
                        process_group=sharded_optimizer.dp_group,
                    )
                    tp_optim_state.shape
                # sq_col doesn't need gather along the dp group
                if key == "exp_avg_sq_col":
                    pass
            else:
                pass
            if key == "exp_avg_sq_row":
                state = state.chunk(dp_size, dim=-1)[dist.get_rank(sharded_optimizer.dp_group)]

            # gather from tp group
            # sq_row doesn't need gather along the tp group
            if key == "exp_avg_sq_row":
                pass
            # sq_col needs gather along the dp group
            # sq_col needs gather along the tp group
            if key == "exp_avg_sq_col":
                tp_optim_state = _gather(
                    input_=tp_optim_state, dim=-1, process_group=sharded_optimizer.tp_group
                )
                tp_optim_state.shape

                state = state.chunk(tp_size, dim=-1)[dist.get_rank(sharded_optimizer.tp_group)]
        # row parallel
        if shard_spec.sharding_sequence[-1] == "R":
            if use_zero:
        elif shard_spec.sharding_sequence[-1] == "R":
            # TODO: this case may cause shape mismatch @duanjunwen
            if use_zero and key == "exp_avg_sq_row" and state.shape[0] // tp_size % dp_size == 0:
                # sq_row needs gather along the dp group
                if key == "exp_avg_sq_row":
                    if p_state[key].shape[0] // tp_size % dp_size != 0:
                        pass
                    else:
                tp_optim_state = _gather(
                    input_=tp_optim_state,
                    dim=-1,
                    process_group=sharded_optimizer.dp_group,
                )
                tp_optim_state.shape
                # sq_col doesn't need gather along the dp group
                if key == "exp_avg_sq_col":
                    pass
            else:
                pass

                state = state.chunk(dp_size, dim=-1)[dist.get_rank(sharded_optimizer.dp_group)]

            # gather from tp group
            # sq_row needs gather along the tp group
            if key == "exp_avg_sq_row":
                tp_optim_state = _gather(
                    input_=tp_optim_state, dim=-1, process_group=sharded_optimizer.tp_group
                )
                tp_optim_state.shape
                state = state.chunk(tp_size, dim=-1)[dist.get_rank(sharded_optimizer.tp_group)]
            # sq_col doesn't need gather along the dp group
            if key == "exp_avg_sq_col":
                pass
            else:
                return
        else:
            if use_zero:
                # sq_row needs gather along the dp group
                if key == "exp_avg_sq_row":
                    # row residual; no gather
                    if p_state[key].shape[0] % dp_size != 0:
                    if state.shape[0] % dp_size != 0:
                        pass
                    else:
                        tp_optim_state = _gather(
                            input_=tp_optim_state,
                            dim=-1,
                            process_group=sharded_optimizer.dp_group,
                        )
                        tp_optim_state.shape
                        state = state.chunk(dp_size, dim=-1)[dist.get_rank(sharded_optimizer.dp_group)]
                # sq_col doesn't need gather along the dp group
                if key == "exp_avg_sq_col":
                    tp_optim_state = tp_optim_state.div_(dp_size)
                    # need a div;
            else:
                pass
        # Solved a new issue: different dtype;
        # So far, it only happens in the H100 env;
        # It seems torch.set_default_dtype(torch.bfloat16) does not act on booster precision;
        # Or assert_close was just updated to check dtype;
        if p_state[key].dtype != tp_optim_state.dtype:
            tp_optim_state = tp_optim_state.type(p_state[key].dtype)
        try:
            assert_close(p_state[key], tp_optim_state, atol=5e-4, rtol=1.6e-2)
        except:
            pass

        if state.dtype != tp_optim_state.dtype:
            tp_optim_state = tp_optim_state.type(state.dtype)
        # TODO: some sharding checks are currently buggy, but the state values should match
        # @duanjunwen
        if state.shape != tp_optim_state.shape:
            return
        assert_close(state, tp_optim_state, atol=5e-4, rtol=1.6e-2)


def check_dist_param(org_model, sharded_model, weight_layer_for_check, atol, rtol):
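The check above repeatedly follows one pattern: gather the sharded Adafactor state along the tp or dp process group, or chunk the unsharded reference down to the local shard, then compare on device. A minimal sketch of that pattern; `full_state`, `shard`, and `group` are illustrative names, not the test's actual variables:

```python
import torch.distributed as dist
from torch.testing import assert_close

from colossalai.shardformer.layer._operation import _gather


def compare_sharded_state(full_state, shard, group, dim=-1, atol=5e-4, rtol=1.6e-2):
    # Option 1: all-gather the shard across the group and compare against the full tensor.
    gathered = _gather(input_=shard, dim=dim, process_group=group)
    assert_close(full_state, gathered.type(full_state.dtype), atol=atol, rtol=rtol)

    # Option 2: chunk the full tensor down to this rank's slice and compare locally,
    # which avoids the collective entirely.
    local_ref = full_state.chunk(dist.get_world_size(group), dim=dim)[dist.get_rank(group)]
    assert_close(local_ref, shard.type(local_ref.dtype), atol=atol, rtol=rtol)
```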
@@ -7,14 +7,11 @@ from torch import nn
from torch.testing import assert_close

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.cluster import ProcessGroupMesh
from colossalai.logging import disable_existing_loggers
from colossalai.nn.optimizer.adafactor import Adafactor
from colossalai.nn.optimizer.distributed_adafactor import DistributedAdaFactor
from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row
from colossalai.shardformer.layer._operation import _gather
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor import (
    distribute_tensor,
@@ -59,7 +56,6 @@ def correctness_verify(tensor1: torch.Tensor, tensor2: torch.Tensor, dtype: torc
        rtol = 4e-3
        atol = 4e-3

    # return torch.all(tensor1.isclose(tensor2, rtol=rtol, atol=atol))
    assert_close(tensor1, tensor2, rtol=rtol, atol=atol)
@@ -194,7 +190,6 @@ def exam_dist_adafactor_base(dtype: torch.dtype, tp_zero_size: tuple[int, int]):
    # Col Parallel
    # ==============================
    weight_col_shard = shard_colwise(weight.clone(), tp_group)
    weight_col_shard_layout = get_layout(weight_col_shard)  # Layout info weight_col_shard_layout.global_shape
    weight_col_shard_shard_spec = get_sharding_spec(weight_col_shard)  # Shard spec
    weight_col_shard_flatten = nn.Parameter(weight_col_shard.clone().flatten().requires_grad_(True))
    bias_col_flatten = nn.Parameter(bias.clone().flatten().requires_grad_(True))
@@ -203,17 +198,12 @@ def exam_dist_adafactor_base(dtype: torch.dtype, tp_zero_size: tuple[int, int]):
    # Row Parallel
    # ==============================
    weight_row_shard = shard_rowwise(weight.clone(), tp_group)
    weight_row_shard_layout = get_layout(weight_row_shard)  # Layout info weight_row_shard_layout.global_shape
    weight_row_shard_shard_spec = get_sharding_spec(weight_row_shard)  # Shard spec
    weight_row_shard_flatten = nn.Parameter(
        weight_row_shard.clone().flatten().requires_grad_(True)
    )  # flatten input (not dtensor) to optimizer
    bias_row_flatten = nn.Parameter(bias.clone().flatten().requires_grad_(True))

    # base_param_group = setup_param_groups([weight, bias])
    # cp_param_group = setup_param_groups([weight_col_shard_flatten, bias_col_flatten])
    # rp_param_group = setup_param_groups([weight_row_shard_flatten, bias_row_flatten])

    # ==============================
    # Init Optimizer
    # ==============================
@@ -267,19 +257,11 @@ def exam_dist_adafactor_base(dtype: torch.dtype, tp_zero_size: tuple[int, int]):
    bias_row_flatten.grad = bias.grad.clone().flatten()
    rp_dist_optim.step()

    # gather result
    weight_col_gather = _gather(
        input_=weight_col_shard_flatten.data.view(-1, H // tp_size),
        dim=-1,
        process_group=tp_group,
    )  # gather
    weight_row_gather = _gather(input_=weight_row_shard_flatten.data, dim=-1, process_group=tp_group).view(
        -1, W
    )  # gather

    weight_row_chunk = weight.t().reshape(-1, W).chunk(tp_size, dim=-1)[dist.get_rank(tp_group)].flatten()
    weight_col_chunk = weight.reshape(-1, H).chunk(tp_size, dim=-1)[dist.get_rank(tp_group)].flatten()
    # verify
    correctness_verify(weight.data, weight_col_gather.data, dtype)
    correctness_verify(weight.data, weight_row_gather.data, dtype)
    correctness_verify(weight_col_chunk, weight_col_shard_flatten, dtype)
    correctness_verify(weight_row_chunk, weight_row_shard_flatten, dtype)

    print(f"Base Test Passed")
@@ -307,7 +289,7 @@ def exam_dist_adafactor_zero(dtype: torch.dtype, tp_zero_size: tuple[int, int]):

    base_param_group = setup_param_groups(base_model)
    tp_param_group = setup_param_groups(tp_model)
    tp_param_group_, tp_shard_spec, tp_param_shape = setup_flatten_param_groups_sharding_spec_shape(tp_model)
    # tp_param_group_, tp_shard_spec, tp_param_shape = setup_flatten_param_groups_sharding_spec_shape(tp_model)

    # ==============================
    # Optimizer Init
@@ -378,143 +360,21 @@ def exam_dist_adafactor_zero(dtype: torch.dtype, tp_zero_size: tuple[int, int]):
            if len(shard_spec.sharding_sequence) >= 2:
                # Col Parallel
                if shard_spec.sharding_sequence[0] == "R":
                    tp_p = _gather(input_=tp_p, dim=-1, process_group=tp_group)  # gather
                    p = p.chunk(tp_size, dim=-1)[dist.get_rank(tp_group)]
                # ROW Parallel
                if shard_spec.sharding_sequence[-1] == "R":
                    tp_p = _gather(input_=tp_p, dim=0, process_group=tp_group)  # gather
                    p = p.chunk(tp_size, dim=0)[dist.get_rank(tp_group)]
            else:
                # TP bias
                tp_p = _gather(input_=tp_p, dim=-1, process_group=tp_group)  # gather
        else:
            # No TP bias
            pass
        correctness_verify(p.data, tp_p.data, dtype)
        p = p.chunk(tp_size, dim=-1)[dist.get_rank(tp_group)]

        correctness_verify(p, tp_p, dtype)
    clear_layout_converter()
    Randomizer.reset_index()
    torch.cuda.empty_cache()
    print(f"Zero Test Passed")


@parameterize("dtype", [torch.float16])
@parameterize("tp_zero_size", [(1, 4)])
def exam_dist_adafactor_booster(dtype: torch.dtype, tp_zero_size: tuple[int, int]):
    tp_size, zero_size = tp_zero_size
    use_zero = True if zero_size > 1 else False
    local_rank = dist.get_rank()

    clear_layout_converter()

    proc_mesh = ProcessGroupMesh(tp_size, zero_size)
    tp_group, dp_group = proc_mesh.get_group_along_axis(0), proc_mesh.get_group_along_axis(1)

    torch.set_default_dtype(dtype)
    set_seed(42)

    # ==============================
    # Model Init
    # ==============================
    base_model = MlpModel().to(local_rank)
    # tp_model = TPModel(copy.deepcopy(base_model.linear1), copy.deepcopy(base_model.linear2), tp_group).to(local_rank)
    tp_model = copy.deepcopy(base_model).to(local_rank)

    base_param_group = setup_param_groups(base_model)
    tp_param_group = setup_param_groups(tp_model)
    tp_param_group_, tp_shard_spec, tp_param_shape = setup_flatten_param_groups_sharding_spec_shape(tp_model)

    # ==============================
    # Optimizer Init
    # ==============================
    base_optim = Adafactor(base_param_group)
    dist_optim = DistributedAdaFactor(tp_param_group)

    # Setup distributed optimizer
    if zero_size > 1:
        base_optim = LowLevelZeroOptimizer(
            base_optim,
            overlap_communication=True,
            initial_scale=128,
            partition_grad=True,
            dp_process_group=dp_group,
            verbose=True,
        )

        dist_optim = LowLevelZeroOptimizer(
            dist_optim,
            overlap_communication=True,
            initial_scale=128,
            partition_grad=True,
            dp_process_group=dp_group,
            verbose=True,
        )
        shard_to_param = dist_optim._param_store.master_to_working_param  # {id(): param tensor} but flattened
        dist_optim.optim.setup_distributed(
            tp_group=tp_group,
            dp_group=dp_group,
            shard_to_working_param=shard_to_param,
            use_zero=use_zero,
        )
    else:
        shard_to_param = set_master_param_to_shard_param(tp_param_group)
        dist_optim.setup_distributed(
            tp_group=tp_group,
            dp_group=dp_group,
            shard_to_working_param=shard_to_param,
            use_zero=use_zero,
        )

    # ==============================
    # Booster Init
    # ==============================
    plugin = LowLevelZeroPlugin()
    booster = Booster(plugin=plugin)
    criterion = lambda x: x.mean()

    tp_model, dist_optim, criterion, _, _ = booster.boost(tp_model, dist_optim, criterion)

    # ==============================
    # Correctness Verify
    # ==============================
    x = torch.randn(HEIGHT, WIDTH, device=local_rank)

    out = base_model(x)
    out_tp = tp_model(x)

    if zero_size > 1:
        dist_optim.backward(out_tp.sum())
        base_optim.backward(out.sum())
    else:
        out_tp.sum().backward()
        out.sum().backward()

    base_optim.step()
    dist_optim.step()

    base_optim.zero_grad()
    dist_optim.zero_grad()

    for p, tp_p in zip(base_param_group, tp_param_group):
        param_is_distributed = is_distributed_tensor(tp_p)
        if param_is_distributed:
            shard_spec = get_sharding_spec(tp_p)
            if len(shard_spec.sharding_sequence) >= 2:
                # Col Parallel
                if shard_spec.sharding_sequence[0] == "R":
                    tp_p = _gather(input_=tp_p, dim=-1, process_group=tp_group)  # gather
                # ROW Parallel
                if shard_spec.sharding_sequence[-1] == "R":
                    tp_p = _gather(input_=tp_p, dim=0, process_group=tp_group)  # gather
            else:
                # TP bias
                tp_p = _gather(input_=tp_p, dim=-1, process_group=tp_group)  # gather
        else:
            # No TP bias
            pass
        correctness_verify(p.data, tp_p.data, dtype)
    Randomizer.reset_index()
    torch.cuda.empty_cache()
    print(f"Booster Test Passed")


@parameterize(
    "test_config",
    [
@@ -532,14 +392,6 @@ def exam_bert_test_on_lowlevelzero_plugin(test_config):
    sub_model_zoo = model_zoo.get_sub_registry("transformers_bert")
    model_list = [
        "transformers_bert",
        "transformers_bert_for_pretraining",
        "transformers_bert_lm_head_model",
        "transformers_bert_for_masked_lm",
        "transformers_bert_for_sequence_classification",
        "transformers_bert_for_token_classification",
        "transformers_bert_for_next_sentence",
        "transformers_bert_for_mcq",
        "transformers_bert_for_question_answering",
    ]
    clear_layout_converter()
    torch.set_default_dtype(torch.bfloat16)
@@ -627,14 +479,6 @@ def exam_bert_test_on_hybrid_plugin(test_config):
    test_config["initial_scale"] = 2**16  # avoid overflow
    model_list = [
        "transformers_bert",
        "transformers_bert_for_pretraining",
        "transformers_bert_lm_head_model",
        "transformers_bert_for_masked_lm",
        "transformers_bert_for_sequence_classification",
        "transformers_bert_for_token_classification",
        "transformers_bert_for_next_sentence",
        "transformers_bert_for_mcq",
        "transformers_bert_for_question_answering",
    ]
    clear_layout_converter()
    torch.set_default_dtype(torch.bfloat16)
@@ -673,6 +517,7 @@ def exam_bert_test_on_hybrid_plugin(test_config):
        # check optim states
        check_dist_optim_state(org_optimizer, sharded_optimizer.optim)

    clear_layout_converter()
    Randomizer.reset_index()
    torch.cuda.empty_cache()
    print(f"Bert Model Zoo Test Passed")
@@ -681,11 +526,10 @@ def exam_bert_test_on_hybrid_plugin(test_config):
def run_dist(rank, world_size, port):
    disable_existing_loggers()
    colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    exam_bert_test_on_lowlevelzero_plugin()
    exam_bert_test_on_hybrid_plugin()
    exam_dist_adafactor_base()
    exam_dist_adafactor_zero()
    exam_dist_adafactor_booster()
    exam_bert_test_on_lowlevelzero_plugin()
    exam_bert_test_on_hybrid_plugin()


@pytest.mark.dist
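The booster test above wires DistributedAdaFactor under ZeRO by handing the optimizer the mapping from flattened master shards to working params. A condensed sketch of that wiring, assuming an already-launched distributed environment; the LowLevelZeroOptimizer import path is an assumption, the call arguments are taken from the hunk above:

```python
# Condensed sketch; assumes colossalai.launch() has already set up the process groups.
from colossalai.nn.optimizer.distributed_adafactor import DistributedAdaFactor
from colossalai.zero import LowLevelZeroOptimizer  # assumed import path


def build_zero_dist_adafactor(param_group, tp_group, dp_group):
    dist_optim = DistributedAdaFactor(param_group)
    dist_optim = LowLevelZeroOptimizer(
        dist_optim,
        overlap_communication=True,
        initial_scale=128,
        partition_grad=True,
        dp_process_group=dp_group,
        verbose=True,
    )
    # ZeRO keeps flattened master shards; the distributed optimizer needs the
    # master -> working-param mapping to recover each param's sharding spec.
    shard_to_param = dist_optim._param_store.master_to_working_param
    dist_optim.optim.setup_distributed(
        tp_group=tp_group,
        dp_group=dp_group,
        shard_to_working_param=shard_to_param,
        use_zero=True,
    )
    return dist_optim
```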
@@ -287,15 +287,6 @@ def exam_bert_test_on_lowlevelzero_plugin(test_config):
    # test_config["initial_scale"] = 1
    model_list = [
        "transformers_bert",
        "transformers_bert_for_pretraining",
        "transformers_bert_lm_head_model",
        "transformers_bert_for_masked_lm",
        "transformers_bert_for_sequence_classification",
        "transformers_bert_for_token_classification",
        "transformers_bert_for_next_sentence",
        "transformers_bert_for_mcq",
        "transformers_bert_for_question_answering",
        "simple_mlp",
    ]
    clear_layout_converter()
    torch.set_default_dtype(torch.bfloat16)
@@ -389,14 +380,6 @@ def exam_bert_test_on_hybrid_plugin(test_config):
    test_config["initial_scale"] = 2**16  # avoid overflow
    model_list = [
        "transformers_bert",
        "transformers_bert_for_pretraining",
        "transformers_bert_lm_head_model",
        "transformers_bert_for_masked_lm",
        "transformers_bert_for_sequence_classification",
        "transformers_bert_for_token_classification",
        "transformers_bert_for_next_sentence",
        "transformers_bert_for_mcq",
        "transformers_bert_for_question_answering",
    ]

    # pass "transformers_bert",
@@ -18,7 +18,6 @@ from tests.test_optimizer._utils import check_optim_states, run_bert_test

_ALLOWED_P_G_TYPES = [
    (torch.float, torch.float),  # pure fp32
    (torch.float, torch.half),  # fp16 amp
    (torch.float, torch.bfloat16),  # bfloat16 amp
]

@@ -264,7 +263,6 @@ def run_dist_lamb_fwd_bwd(

    torch_optim.step()
    optim.step()
    dist.barrier()
    torch_optim.zero_grad()
    optim.zero_grad()
    try:
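The `correctness_verify` helper shown earlier loosens tolerances by dtype before calling `assert_close`, which is how these lamb and Adafactor checks tolerate fp16/bf16 rounding after `optim.step()`. A minimal sketch of that pattern; only the 4e-3 reduced-precision bounds appear in the diff, so the fp32 values here are an assumption:

```python
import torch
from torch.testing import assert_close


def correctness_verify_sketch(tensor1: torch.Tensor, tensor2: torch.Tensor, dtype: torch.dtype):
    if dtype is torch.float32:
        rtol, atol = 5e-04, 5e-04  # assumed full-precision bounds; not shown in the diff
    else:
        # Looser bounds for fp16/bf16, matching the 4e-3 values visible above.
        rtol, atol = 4e-3, 4e-3
    assert_close(tensor1, tensor2, rtol=rtol, atol=atol)
```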