ColossalAI/tests/test_shardformer/test_model/_utils.py
flybird11111 af6aa9ed06
[plugin] hybrid support zero bubble pipeline (#6060)
* hybrid support zbv

* fix

fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* Update zero_bubble_pp.py

* fix

* fix-ci

* fix

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

* fix

* [zerobubble]Support ZeroBubble Pipeline (#6034)

* [feat] add zerobubble pp (just a frame now); add POC test for dx_dw; add test for zerobubble;

* [feat] add dw test;

* [fix] fix weight not close;

* [update] update text;

* [feat] add test run_fwd_bwd automatic scheduling;

* [feat] split communication and calculation; fix pop empty send_bwd_buffer error;

* [feat] add test for p & p grad;

* [feat] add comments for ZBV func;

* [fix] rm useless assign and comments;

* [fix] fix ci test; add pytest;

* [feat] add run_fwd_bwd_with_microbatch  (replace input) & test; add p&p.grad assert close test & all pass;

* [feat] add apply v_schedule graph; p & p.grad assert err exist;

* [fix] update

* [feat] fix ci; add assert;

* [feat] fix poc format

* [feat] fix func name & ci; add comments;

* [fix] fix poc test; add comments in poc;

* [feat] add optim backward_b_by_grad

* [feat] fix optimizer bwd b & w; support return accum loss & output

* [feat] add fwd_bwd_step, run_fwd_only;

* [fix] fix optim bwd; add license for v_schedule; remove redundant attributes; fix schedule loop "while"--> "for"; add communication dict;

* [fix] fix communication_map;

* [feat] update test; rm comments;

* [fix] rm zbv in hybridplugin

* [fix] fix optim bwd;

* [fix] fix optim bwd;

* [fix] rm output.data after send fwd;

* [fix] fix bwd step if condition; remove useless comments and format info;

* [fix] fix detach output & release output;

* [fix] rm requir_grad for output;

* [fix] fix requir grad position and detach position and input&output local buffer append position;

* [feat] add memory assertation;

* [fix] fix mem check;

* [fix] mem assertation'

* [fix] fix mem assertation

* [fix] fix mem; use a new model shape; only assert mem less and equal than theo;

* [fix] fix model zoo import;

* [fix] fix redundant detach & clone; add buffer assertation in the end;

* [fix] add output_obj_grad assert None at bwd b step; replace input_obj.require_grad_ with treemap;

* [fix] update optim state dict assert (include param group & state); fix mem assert after add optim;

* [fix] add testcase with microbatch 4;

* hybrid support zbv

* fix

fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update zero_bubble_pp.py

* fix

* fix-ci

* fix

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

* fix

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: duanjunwen <935724073@qq.com>
2024-09-27 14:48:55 +08:00

464 lines
16 KiB
Python

import copy
from contextlib import nullcontext
from typing import Any, Callable, Dict, List, Optional, Type
import torch
import torch.distributed as dist
from torch import Tensor
from torch import distributed as dist
from torch.distributed import ProcessGroup
from torch.nn import Module
from torch.optim import Adam, Optimizer
from torch.testing import assert_close
from transformers.modeling_outputs import BaseModelOutputWithPast
from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin, LowLevelZeroPlugin
from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule
from colossalai.checkpoint_io.utils import gather_distributed_param
from colossalai.lazy import LazyInitContext
from colossalai.nn.optimizer import GaLoreAdamW8bit
from colossalai.nn.optimizer.galore import get_galore_param_groups
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer._utils import getattr_
from colossalai.shardformer.policies.auto_policy import Policy
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
from colossalai.tensor.padded_tensor.api import is_padded_tensor, to_unpadded_tensor
def build_model(
model_fn,
enable_fused_normalization=True,
enable_tensor_parallelism=True,
enable_flash_attention=False,
enable_jit_fused=False,
enable_sequence_parallelism=False,
use_lazy_init: bool = False,
dtype=torch.float32,
):
# create new model
ctx = LazyInitContext() if use_lazy_init else nullcontext()
with ctx:
# create new model
org_model = model_fn()
model_copy = copy.deepcopy(org_model)
if use_lazy_init:
ctx.materialize(org_model)
# shard model
shard_config = ShardConfig(
enable_fused_normalization=enable_fused_normalization,
enable_tensor_parallelism=enable_tensor_parallelism,
enable_flash_attention=enable_flash_attention,
enable_jit_fused=enable_jit_fused,
enable_sequence_parallelism=enable_sequence_parallelism,
)
model_copy = copy.deepcopy(org_model)
shard_former = ShardFormer(shard_config=shard_config)
sharded_model, shared_params = shard_former.optimize(model_copy)
return org_model.cuda().to(dtype), sharded_model.cuda().to(dtype)
def build_pipeline_model(
model_fn,
stage_manager=None,
enable_fused_normalization=False,
enable_tensor_parallelism=False,
use_lazy_init: bool = False,
policy: Optional[Policy] = None,
):
ctx = LazyInitContext() if use_lazy_init else nullcontext()
with ctx:
# create new model
org_model = model_fn()
model_copy = copy.deepcopy(org_model)
if use_lazy_init:
ctx.materialize(org_model)
# shard model
shard_config = ShardConfig(
enable_fused_normalization=enable_fused_normalization,
enable_tensor_parallelism=enable_tensor_parallelism,
pipeline_stage_manager=stage_manager,
)
shard_former = ShardFormer(shard_config=shard_config)
sharded_model, shared_params = shard_former.optimize(model_copy, policy=policy)
return org_model.cuda(), sharded_model.cuda()
def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
# prepare input
data = data_gen_fn()
data = {k: v.cuda() for k, v in data.items()}
# switch to train mode
original_model.train()
sharded_model.train()
# run forward
org_output = original_model(**data)
org_output = output_transform_fn(org_output)
org_loss = loss_fn(org_output)
shard_output = sharded_model(**data)
shard_output = output_transform_fn(shard_output)
shard_loss = loss_fn(shard_output)
return org_output, org_loss, shard_output, shard_loss
def check_state_dict(org_model: Module, sharded_model: Module, name: str = ""):
org_sd = org_model.state_dict()
shard_sd = sharded_model.state_dict()
for k, v in org_sd.items():
assert k in shard_sd, f"{name} {k} not in sharded model"
shard_v = shard_sd[k]
assert v.shape == shard_v.shape, f"{name} {k} shape mismatch, {v.shape} vs {shard_v.shape}"
assert v.dtype == shard_v.dtype, f"{name} {k} dtype mismatch, {v.dtype} vs {shard_v.dtype}"
assert torch.equal(v, shard_v), f"{name} {k} value mismatch"
def build_model_from_hybrid_plugin(
model_fn: Callable,
loss_fn: Callable,
test_config: Dict[str, Any],
optim_class=Adam,
sharded_optim_class=Adam,
pluggin_cls: Type[HybridParallelPlugin] = HybridParallelPlugin,
):
use_lazy_init = False
if "use_lazy_init" in test_config:
use_lazy_init = test_config.pop("use_lazy_init")
ctx = LazyInitContext() if use_lazy_init else nullcontext()
with ctx:
org_model = model_fn()
sharded_model = copy.deepcopy(org_model)
if use_lazy_init:
ctx.materialize(org_model)
org_model = org_model.cuda()
if optim_class == GaLoreAdamW8bit:
# Disable clipping and block-wise quantization
org_optimizer = optim_class(
get_galore_param_groups(org_model, weight_decay=0, rank=4),
lr=1e-3,
percentile_clipping=101,
block_wise=False,
min_8bit_size=1e10,
)
sharded_optimizer = sharded_optim_class(
get_galore_param_groups(sharded_model, weight_decay=0, rank=4),
lr=1e-3,
percentile_clipping=101,
block_wise=False,
min_8bit_size=1e10,
)
else:
org_optimizer = optim_class(org_model.parameters(), lr=1e-3)
sharded_optimizer = sharded_optim_class(sharded_model.parameters(), lr=1e-3)
criterion = loss_fn
plugin = pluggin_cls(**test_config)
booster = Booster(plugin=plugin)
sharded_model, sharded_optimizer, criterion, _, _ = booster.boost(sharded_model, sharded_optimizer, criterion)
return (
org_model,
org_optimizer,
sharded_model,
sharded_optimizer,
criterion,
booster,
)
def build_model_from_low_level_zero_plugin(
model_fn: Callable, loss_fn: Callable, test_config: Dict[str, Any], optim_class=Adam, sharded_optim_class=Adam
):
use_lazy_init = False
if "use_lazy_init" in test_config:
use_lazy_init = test_config.pop("use_lazy_init")
ctx = LazyInitContext() if use_lazy_init else nullcontext()
with ctx:
org_model = model_fn()
sharded_model = copy.deepcopy(org_model)
if use_lazy_init:
ctx.materialize(org_model)
org_model = org_model.cuda()
org_optimizer = optim_class(org_model.parameters(), lr=1e-3)
sharded_optimizer = sharded_optim_class(sharded_model.parameters(), lr=1e-3)
criterion = loss_fn
plugin = LowLevelZeroPlugin(**test_config)
booster = Booster(plugin=plugin)
sharded_model, sharded_optimizer, criterion, _, _ = booster.boost(sharded_model, sharded_optimizer, criterion)
return org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster
def run_forward_backward_with_hybrid_plugin(
org_model: Module,
sharded_model: Module,
sharded_optimizer: Optimizer,
data_gen_fn: Callable,
output_transform_fn: Callable,
criterion: Callable,
booster: Booster,
):
org_model.cuda()
sharded_model.cuda()
def _criterion(outputs, inputs):
outputs = output_transform_fn(outputs)
loss = criterion(outputs)
return loss
data = data_gen_fn()
shard_test_data = {}
for k, v in data.items():
shard_test_data[k] = data[k].clone()
unshard_test_data = {}
for k, v in data.items():
unshard_test_data[k] = data[k].clone()
sharded_model.train()
if booster.plugin.stage_manager is not None:
for k, v in shard_test_data.items():
if torch.is_tensor(v) or "Tensor" in v.__class__.__name__:
new_shape = [1] * v.dim()
new_shape[0] = 4
shard_test_data[k] = v.to("cuda").repeat(*new_shape)
data_iter = iter([shard_test_data])
sharded_output = booster.execute_pipeline(
data_iter,
sharded_model,
_criterion,
sharded_optimizer,
return_loss=True,
return_outputs=True,
)
sharded_loss = sharded_output["loss"]
else:
shard_test_data = {k: v.cuda() for k, v in shard_test_data.items()}
sharded_output = sharded_model(**shard_test_data)
sharded_loss = criterion(sharded_output)
sharded_optimizer.backward(sharded_loss)
org_model.train()
if booster.plugin.stage_manager is not None:
for k, v in unshard_test_data.items():
if torch.is_tensor(v) or "Tensor" in v.__class__.__name__:
new_shape = [1] * v.dim()
new_shape[0] = 4
unshard_test_data[k] = v.to("cuda").repeat(*new_shape)
unshard_test_data = {k: v.cuda() for k, v in unshard_test_data.items()}
org_output = org_model(**unshard_test_data)
org_loss = criterion(org_output)
org_loss.backward()
return org_loss, org_output, sharded_loss, sharded_output
def run_forward_backward_with_low_level_zero_plugin(
org_model: Module,
sharded_model: Module,
sharded_optimizer: Optimizer,
data_gen_fn: Callable,
output_transform_fn: Callable,
criterion: Callable,
booster: Booster,
):
get_accelerator().get_current_device()
org_model.cuda()
sharded_model.cuda()
def _criterion(outputs, inputs):
outputs = output_transform_fn(outputs)
loss = criterion(outputs)
return loss
data = data_gen_fn()
# data = {
# k: v.to(device) if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()
# }
data = {k: v.cuda() for k, v in data.items()}
sharded_model.train()
sharded_output = sharded_model(**data)
sharded_loss = criterion(sharded_output)
sharded_optimizer.backward(sharded_loss)
org_model.train()
org_output = org_model(**data)
org_loss = criterion(org_output)
org_loss.backward()
return org_loss, org_output, sharded_loss, sharded_output
def check_output_hidden_state(
org_output: BaseModelOutputWithPast,
sharded_output: BaseModelOutputWithPast,
stage_manager: Optional[PipelineStageManager] = None,
atol: float = 1e-5,
rtol: float = 1e-3,
shard_config: Optional[ShardConfig] = None,
):
org_hidden_state = org_output.last_hidden_state
if stage_manager:
if stage_manager.use_zbv:
if stage_manager.is_first_stage(ignore_chunk=True):
sharded_hidden_state = sharded_output["outputs"]["last_hidden_state"]
else:
sharded_hidden_state = sharded_output.last_hidden_state
elif stage_manager.is_last_stage(ignore_chunk=True):
sharded_hidden_state = sharded_output["outputs"]["last_hidden_state"]
else:
sharded_hidden_state = sharded_output.last_hidden_state
else:
sharded_hidden_state = sharded_output.last_hidden_state
# Check if the output sequence is gathered before cross entropy
if shard_config is not None:
seq_dim = 1
sp_group = shard_config.sequence_parallel_process_group
sp_size = shard_config.sequence_parallel_size
if org_hidden_state.shape[seq_dim] == sharded_hidden_state.shape[seq_dim] * sp_size:
org_hidden_state = org_hidden_state.chunk(sp_size, dim=seq_dim)[dist.get_rank(sp_group)]
assert_close(org_hidden_state.float(), sharded_hidden_state.float(), atol=atol, rtol=rtol)
def check_loss(org_loss: Tensor, sharded_loss: Tensor, atol: float = 1e-5, rtol: float = 1e-3):
assert_close(org_loss.float(), sharded_loss.float(), atol=atol, rtol=rtol)
def check_weight(
org_model: Module,
sharded_model: Module,
layer_suffix: List[str],
tp_group: Optional[ProcessGroup] = None,
dim: int = 0,
atol: float = 1e-5,
rtol: float = 1e-3,
verbose: bool = False,
):
for suffix in layer_suffix:
org_weight = getattr_(org_model, suffix).weight
sharded_weight = getattr_(sharded_model, suffix).weight
# skip if layer is not held by this process
if sharded_weight is None:
continue
if is_distributed_tensor(sharded_weight) or is_customized_distributed_tensor(sharded_weight):
sharded_weight = gather_distributed_param(sharded_weight, keep_vars=False)
if is_padded_tensor(sharded_weight):
sharded_weight = to_unpadded_tensor(sharded_weight)
if verbose and dist.get_rank() == 0:
print(f"'{suffix}' weight: {org_weight}, {sharded_weight}")
assert_close(org_weight.float(), sharded_weight.float(), atol=atol, rtol=rtol)
def get_grad_tensors_for_check(
org_model: Module,
sharded_model: Module,
layer_suffix: List[str],
tp_group: ProcessGroup = None,
dim: int = 0,
atol: float = 1e-5,
rtol: float = 1e-3,
verbose: bool = False,
name: str = None,
):
grad_to_check = {}
for suffix in layer_suffix:
org_grad = getattr_(org_model, suffix).weight.grad
shard_grad = getattr_(sharded_model, suffix).weight.grad
shard_weight = getattr_(sharded_model, suffix).weight
if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
shard_grad_list = [torch.zeros_like(shard_grad).to("cuda") for _ in range(dist.get_world_size(tp_group))]
dist.all_gather(shard_grad_list, shard_grad, tp_group)
shard_grad = torch.cat(shard_grad_list, dim=dim)
# embedding may be resized when using tensor parallel
try:
if shard_grad.shape[0] > org_grad.shape[0]:
shard_grad = shard_grad[: org_grad.shape[0], :]
except:
pass
if verbose and dist.get_rank() == 0:
print(f"'{suffix}' grad: {org_grad}, {shard_grad}")
grad_to_check[suffix] = {
"org_grad": org_grad.float(),
"shard_grad": shard_grad.float(),
"rtol": rtol,
"atol": atol,
}
return grad_to_check
# used by sam/blip2
def check_grad(
org_model: Module,
sharded_model: Module,
layer_suffix: List[str],
tp_group: ProcessGroup = None,
dim: int = 0,
atol: float = 1e-5,
rtol: float = 1e-3,
verbose: bool = False,
):
for suffix in layer_suffix:
org_grad = getattr_(org_model, suffix).weight.grad
shard_grad = getattr_(sharded_model, suffix).weight.grad
shard_weight = getattr_(sharded_model, suffix).weight
if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
shard_grad_list = [torch.zeros_like(shard_grad).to("cuda") for _ in range(dist.get_world_size(tp_group))]
dist.all_gather(shard_grad_list, shard_grad, tp_group)
shard_grad = torch.cat(shard_grad_list, dim=dim)
# embedding may be resized when using tensor parallel
if shard_grad.shape[0] > org_grad.shape[0]:
shard_grad = shard_grad[: org_grad.shape[0], :]
if verbose and dist.get_rank() == 0:
print(f"'{suffix}' grad: {org_grad}, {shard_grad}")
assert_close(org_grad.float(), shard_grad.float(), rtol=rtol, atol=atol)
def unwrap_model(
module: Module,
base_model_class_name: Optional[str] = None,
base_model_attribute_name: Optional[str] = None,
):
if isinstance(module, HybridParallelModule):
module = module.unwrap()
if base_model_class_name is None:
return module
if module.__class__.__name__ == base_model_class_name:
return module
return getattr(module, base_model_attribute_name, None)
def check_all_grad_tensors(check_tensors):
"""
"org_grad": tensor to be compared from the original model
"shard_grad": tensor to be compared from the sharded model
"""
for idx, (suffix, check_info) in enumerate(check_tensors.items()):
org_grad = check_info["org_grad"]
shard_grad = check_info["shard_grad"]
rtol = check_info["rtol"]
atol = check_info["atol"]
assert_close(org_grad, shard_grad, atol=atol, rtol=rtol)