ColossalAI/tests/test_shardformer/test_model/_utils.py
Edenzzzz 43995ee436
[Feature] Distributed optimizers: Lamb, Galore, CAME and Adafactor (#5694)
* [feat] Add distributed lamb; minor fixes in DeviceMesh (#5476)

* init: add dist lamb; add debiasing for lamb

* dist lamb tester mostly done

* all tests passed

* add comments

* all tests passed. Removed debugging statements

* moved setup_distributed inside plugin. Added dist layout caching

* organize better

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [hotfix] Improve tester precision by removing ZeRO on vanilla lamb (#5576)

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [optim] add distributed came (#5526)

* test CAME under LowLevelZeroOptimizer wrapper

* test CAME TP row and col pass

* test CAME zero pass

* came zero add master and worker param id convert

* came zero test pass

* came zero test pass

* test distributed came passed

* reform code, Modify some expressions and add comments

* minor fix of test came

* minor fix of dist_came and test

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* minor fix of dist_came and test

* rebase dist-optim

* rebase dist-optim

* fix remaining comments

* add test dist came using booster api

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [optim] Distributed Adafactor (#5484)

* [feature] solve conflict; update optimizer readme;

* [feature] update optimize readme;

* [fix] fix testcase;

* [feature] Add transformer-bert to testcase;solve a bug related to indivisible shape (induction in use_zero and tp is row parallel);

* [feature] Add transformers_bert model zoo in testcase;

* [feature] add user documentation to docs/source/feature.

* [feature] add API Reference & Sample to optimizer Readme; add state check for bert exam;

* [feature] modify user documentation;

* [fix] fix readme format issue;

* [fix] add zero=0 in testcase; cached augment in dict;

* [fix] fix percision issue;

* [feature] add distributed rms;

* [feature] remove useless comment in testcase;

* [fix] Remove useless test; open zero test; remove fp16 test in bert exam;

* [feature] Extract distributed rms function;

* [feature] add booster + lowlevelzeroPlugin in test;

* [feature] add Start_with_booster_API case in md; add Supporting Information in md;

* [fix] Also remove state movement in base adafactor;

* [feature] extract factor function;

* [feature] add LowLevelZeroPlugin test;

* [fix] add tp=False and zero=True in logic;

* [fix] fix use zero logic;

* [feature] add row residue logic in column parallel factor;

* [feature] add check optim state func;

* [feature] Remove duplicate logic;

* [feature] update optim state check func and percision test bug;

* [fix] update/fix optim state; Still exist percision issue;

* [fix] Add use_zero check in _rms; Add plugin support info in Readme; Add Dist Adafactor init Info;

* [feature] removed print & comments in utils;

* [feature] uodate Readme;

* [feature] add LowLevelZeroPlugin test with Bert model zoo;

* [fix] fix logic in _rms;

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [fix] remove comments in testcase;

* [feature] add zh-Han Readme;

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Feature] refractor dist came; fix percision error; add low level zero test with bert model zoo; (#5676)

* [feature] daily update;

* [fix] fix dist came;

* [feature] refractor dist came; fix percision error; add low level zero test with bert model zoo;

* [fix] open rms; fix low level zero test; fix dist came test function name;

* [fix] remove redundant test;

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Feature] Add Galore (Adam, Adafactor) and distributed GaloreAdamW8bit (#5570)

* init: add dist lamb; add debiasing for lamb

* dist lamb tester mostly done

* all tests passed

* add comments

* all tests passed. Removed debugging statements

* moved setup_distributed inside plugin. Added dist layout caching

* organize better

* update comments

* add initial distributed galore

* add initial distributed galore

* add galore set param utils; change setup_distributed interface

* projected grad precision passed

* basic precision tests passed

* tests passed; located svd precision issue in fwd-bwd; banned these tests

* Plugin DP + TP tests passed

* move get_shard_dim to d_tensor

* add comments

* remove useless files

* remove useless files

* fix zero typo

* improve interface

* remove moe changes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix import

* fix deepcopy

* update came & adafactor to main

* fix param map

* fix typo

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Hotfix] Remove one buggy test case from dist_adafactor for now (#5692)


Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: chongqichuizi875 <107315010+chongqichuizi875@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: duanjunwen <54985467+duanjunwen@users.noreply.github.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
2024-05-14 13:52:45 +08:00

443 lines
15 KiB
Python

import copy
from contextlib import nullcontext
from typing import Any, Callable, Dict, List, Optional

import torch
import torch.distributed as dist
from torch import Tensor
from torch.distributed import ProcessGroup
from torch.nn import Module
from torch.optim import Adam, Optimizer
from torch.testing import assert_close

from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin, LowLevelZeroPlugin
from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule
from colossalai.checkpoint_io.utils import gather_distributed_param
from colossalai.lazy import LazyInitContext
from colossalai.nn.optimizer import DistGaloreAwamW
from colossalai.nn.optimizer.galore import get_galore_param_groups
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer._utils import getattr_
from colossalai.shardformer.policies.auto_policy import Policy
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
from colossalai.tensor.padded_tensor.api import is_padded_tensor, to_unpadded_tensor


def build_model(
    model_fn,
    enable_fused_normalization=True,
    enable_tensor_parallelism=True,
    enable_flash_attention=False,
    enable_jit_fused=False,
    enable_sequence_parallelism=False,
    use_lazy_init: bool = False,
    dtype=torch.float32,
):
    ctx = LazyInitContext() if use_lazy_init else nullcontext()
    with ctx:
        # create new model
        org_model = model_fn()
        model_copy = copy.deepcopy(org_model)
    if use_lazy_init:
        ctx.materialize(org_model)
    # shard model
    shard_config = ShardConfig(
        enable_fused_normalization=enable_fused_normalization,
        enable_tensor_parallelism=enable_tensor_parallelism,
        enable_flash_attention=enable_flash_attention,
        enable_jit_fused=enable_jit_fused,
        enable_sequence_parallelism=enable_sequence_parallelism,
    )
    model_copy = copy.deepcopy(org_model)
    shard_former = ShardFormer(shard_config=shard_config)
    sharded_model, shared_params = shard_former.optimize(model_copy)
    return org_model.cuda().to(dtype), sharded_model.cuda().to(dtype)


def build_pipeline_model(
    model_fn,
    stage_manager=None,
    enable_fused_normalization=False,
    enable_tensor_parallelism=False,
    use_lazy_init: bool = False,
    policy: Optional[Policy] = None,
):
    ctx = LazyInitContext() if use_lazy_init else nullcontext()
    with ctx:
        # create new model
        org_model = model_fn()
        model_copy = copy.deepcopy(org_model)
    if use_lazy_init:
        ctx.materialize(org_model)

    # shard model
    shard_config = ShardConfig(
        enable_fused_normalization=enable_fused_normalization,
        enable_tensor_parallelism=enable_tensor_parallelism,
        pipeline_stage_manager=stage_manager,
    )

    shard_former = ShardFormer(shard_config=shard_config)
    sharded_model, shared_params = shard_former.optimize(model_copy, policy=policy)
    return org_model.cuda(), sharded_model.cuda()


def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
    # prepare input
    data = data_gen_fn()
    data = {k: v.cuda() for k, v in data.items()}
    # switch to train mode
    original_model.train()
    sharded_model.train()
    # run forward
    org_output = original_model(**data)
    org_output = output_transform_fn(org_output)
    org_loss = loss_fn(org_output)

    shard_output = sharded_model(**data)
    shard_output = output_transform_fn(shard_output)
    shard_loss = loss_fn(shard_output)
    return org_output, org_loss, shard_output, shard_loss


def check_state_dict(org_model: Module, sharded_model: Module, name: str = ""):
    org_sd = org_model.state_dict()
    shard_sd = sharded_model.state_dict()
    for k, v in org_sd.items():
        assert k in shard_sd, f"{name} {k} not in sharded model"
        shard_v = shard_sd[k]
        assert v.shape == shard_v.shape, f"{name} {k} shape mismatch, {v.shape} vs {shard_v.shape}"
        assert v.dtype == shard_v.dtype, f"{name} {k} dtype mismatch, {v.dtype} vs {shard_v.dtype}"
        assert torch.equal(v, shard_v), f"{name} {k} value mismatch"


def build_model_from_hybrid_plugin(
    model_fn: Callable, loss_fn: Callable, test_config: Dict[str, Any], optim_class=Adam, sharded_optim_class=Adam
):
    use_lazy_init = False
    if "use_lazy_init" in test_config:
        use_lazy_init = test_config.pop("use_lazy_init")

    ctx = LazyInitContext() if use_lazy_init else nullcontext()
    with ctx:
        org_model = model_fn()
        sharded_model = copy.deepcopy(org_model)
    if use_lazy_init:
        ctx.materialize(org_model)
    org_model = org_model.cuda()
    if sharded_optim_class == DistGaloreAwamW:
        # Disable clipping and block-wise quantization
        org_optimizer = optim_class(
            get_galore_param_groups(org_model, weight_decay=0, rank=4),
            lr=1e-3,
            percentile_clipping=101,
            block_wise=False,
            min_8bit_size=1e10,
        )
        sharded_optimizer = sharded_optim_class(
            get_galore_param_groups(sharded_model, weight_decay=0, rank=4),
            lr=1e-3,
            percentile_clipping=101,
            block_wise=False,
            min_8bit_size=1e10,
        )
    else:
        org_optimizer = optim_class(org_model.parameters(), lr=1e-3)
        sharded_optimizer = sharded_optim_class(sharded_model.parameters(), lr=1e-3)

    criterion = loss_fn
    plugin = HybridParallelPlugin(**test_config)
    booster = Booster(plugin=plugin)
    sharded_model, sharded_optimizer, criterion, _, _ = booster.boost(sharded_model, sharded_optimizer, criterion)
    return (
        org_model,
        org_optimizer,
        sharded_model,
        sharded_optimizer,
        criterion,
        booster,
    )
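

# Note on `test_config`: the builder above pops "use_lazy_init" from the dict it
# receives, so the caller's dict is mutated before the remaining keys are passed
# to the plugin. The sketch below illustrates only that side effect. `build` is a
# hypothetical stand-in, not part of this file, and "tp_size" is just an example key.

```python
from typing import Any, Dict


def build(test_config: Dict[str, Any]) -> bool:
    """Hypothetical stand-in that mimics only the pop() at the top of the builder."""
    use_lazy_init = False
    if "use_lazy_init" in test_config:
        use_lazy_init = test_config.pop("use_lazy_init")
    return use_lazy_init


config = {"tp_size": 2, "use_lazy_init": True}

# Passing a shallow copy leaves the caller's dict untouched.
first = build(dict(config))
second = build(dict(config))

# Passing the same dict twice removes the key on the first call,
# so the second call silently falls back to the default.
shared = dict(config)
third = build(shared)
fourth = build(shared)
```

# If a test reuses one config dict across several builder calls, passing a copy
# (as in the first two calls) is one way to keep the flag consistent.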


def build_model_from_low_level_zero_plugin(
    model_fn: Callable, loss_fn: Callable, test_config: Dict[str, Any], optim_class=Adam, sharded_optim_class=Adam
):
    use_lazy_init = False
    if "use_lazy_init" in test_config:
        use_lazy_init = test_config.pop("use_lazy_init")

    ctx = LazyInitContext() if use_lazy_init else nullcontext()
    with ctx:
        org_model = model_fn()
        sharded_model = copy.deepcopy(org_model)
    if use_lazy_init:
        ctx.materialize(org_model)

    org_model = org_model.cuda()
    org_optimizer = optim_class(org_model.parameters(), lr=1e-3)
    sharded_optimizer = sharded_optim_class(sharded_model.parameters(), lr=1e-3)
    criterion = loss_fn

    plugin = LowLevelZeroPlugin(**test_config)
    booster = Booster(plugin=plugin)

    sharded_model, sharded_optimizer, criterion, _, _ = booster.boost(sharded_model, sharded_optimizer, criterion)
    return org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster


def run_forward_backward_with_hybrid_plugin(
    org_model: Module,
    sharded_model: Module,
    sharded_optimizer: Optimizer,
    data_gen_fn: Callable,
    output_transform_fn: Callable,
    criterion: Callable,
    booster: Booster,
):
    org_model.cuda()
    sharded_model.cuda()

    def _criterion(outputs, inputs):
        outputs = output_transform_fn(outputs)
        loss = criterion(outputs)
        return loss

    data = data_gen_fn()

    # keep independent copies of the inputs for the sharded and the original model
    shard_test_data = {}
    for k, v in data.items():
        shard_test_data[k] = data[k].clone()
    unshard_test_data = {}
    for k, v in data.items():
        unshard_test_data[k] = data[k].clone()

    sharded_model.train()
    if booster.plugin.stage_manager is not None:
        # pipeline parallelism: repeat the batch dimension 4 times before running the schedule
        for k, v in shard_test_data.items():
            if torch.is_tensor(v) or "Tensor" in v.__class__.__name__:
                new_shape = [1] * v.dim()
                new_shape[0] = 4
                shard_test_data[k] = v.to("cuda").repeat(*new_shape)

        data_iter = iter([shard_test_data])
        sharded_output = booster.execute_pipeline(
            data_iter,
            sharded_model,
            _criterion,
            sharded_optimizer,
            return_loss=True,
            return_outputs=True,
        )
        sharded_loss = sharded_output["loss"]

    else:
        shard_test_data = {k: v.cuda() for k, v in shard_test_data.items()}
        sharded_output = sharded_model(**shard_test_data)
        sharded_loss = criterion(sharded_output)
        sharded_optimizer.backward(sharded_loss)

    org_model.train()
    if booster.plugin.stage_manager is not None:
        # feed the original model the same repeated batch
        for k, v in unshard_test_data.items():
            if torch.is_tensor(v) or "Tensor" in v.__class__.__name__:
                new_shape = [1] * v.dim()
                new_shape[0] = 4
                unshard_test_data[k] = v.to("cuda").repeat(*new_shape)
    unshard_test_data = {k: v.cuda() for k, v in unshard_test_data.items()}
    org_output = org_model(**unshard_test_data)
    org_loss = criterion(org_output)
    org_loss.backward()

    return org_loss, org_output, sharded_loss, sharded_output


def run_forward_backward_with_low_level_zero_plugin(
    org_model: Module,
    sharded_model: Module,
    sharded_optimizer: Optimizer,
    data_gen_fn: Callable,
    output_transform_fn: Callable,
    criterion: Callable,
    booster: Booster,
):
    get_accelerator().get_current_device()
    org_model.cuda()
    sharded_model.cuda()

    def _criterion(outputs, inputs):
        outputs = output_transform_fn(outputs)
        loss = criterion(outputs)
        return loss

    data = data_gen_fn()

    # data = {
    #     k: v.to(device) if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()
    # }
    data = {k: v.cuda() for k, v in data.items()}

    sharded_model.train()
    sharded_output = sharded_model(**data)
    sharded_loss = criterion(sharded_output)
    sharded_optimizer.backward(sharded_loss)

    org_model.train()
    org_output = org_model(**data)
    org_loss = criterion(org_output)
    org_loss.backward()

    return org_loss, org_output, sharded_loss, sharded_output


def check_output_hidden_state(
    org_output: Tensor,
    sharded_output: Tensor,
    stage_manager: Optional[PipelineStageManager] = None,
    atol: float = 1e-5,
    rtol: float = 1e-3,
):
    org_hidden_state = org_output.last_hidden_state

    if stage_manager and stage_manager.is_last_stage(ignore_chunk=True):
        sharded_hidden_state = sharded_output["outputs"]["last_hidden_state"]
    else:
        sharded_hidden_state = sharded_output.last_hidden_state

    assert_close(org_hidden_state.float(), sharded_hidden_state.float(), atol=atol, rtol=rtol)


def check_loss(org_loss: Tensor, sharded_loss: Tensor, atol: float = 1e-5, rtol: float = 1e-3):
    assert_close(org_loss.float(), sharded_loss.float(), atol=atol, rtol=rtol)


def check_weight(
    org_model: Module,
    sharded_model: Module,
    layer_suffix: List[str],
    tp_group: Optional[ProcessGroup] = None,
    dim: int = 0,
    atol: float = 1e-5,
    rtol: float = 1e-3,
    verbose: bool = False,
):
    for suffix in layer_suffix:
        org_weight = getattr_(org_model, suffix).weight
        sharded_weight = getattr_(sharded_model, suffix).weight

        # skip if layer is not held by this process
        if sharded_weight is None:
            continue

        if is_distributed_tensor(sharded_weight) or is_customized_distributed_tensor(sharded_weight):
            sharded_weight = gather_distributed_param(sharded_weight, keep_vars=False)

        if is_padded_tensor(sharded_weight):
            sharded_weight = to_unpadded_tensor(sharded_weight)

        if verbose and dist.get_rank() == 0:
            print(f"'{suffix}' weight: {org_weight}, {sharded_weight}")

        assert_close(org_weight.float(), sharded_weight.float(), atol=atol, rtol=rtol)


def get_grad_tensors_for_check(
    org_model: Module,
    sharded_model: Module,
    layer_suffix: List[str],
    tp_group: Optional[ProcessGroup] = None,
    dim: int = 0,
    atol: float = 1e-5,
    rtol: float = 1e-3,
    verbose: bool = False,
    name: Optional[str] = None,
):
    grad_to_check = {}
    for suffix in layer_suffix:
        org_grad = getattr_(org_model, suffix).weight.grad
        shard_grad = getattr_(sharded_model, suffix).weight.grad
        shard_weight = getattr_(sharded_model, suffix).weight
        if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
            shard_grad_list = [torch.zeros_like(shard_grad).to("cuda") for _ in range(dist.get_world_size(tp_group))]
            dist.all_gather(shard_grad_list, shard_grad, tp_group)
            shard_grad = torch.cat(shard_grad_list, dim=dim)

        # embedding may be resized when using tensor parallel
        if shard_grad.shape[0] > org_grad.shape[0]:
            shard_grad = shard_grad[: org_grad.shape[0], :]
        if verbose and dist.get_rank() == 0:
            print(f"'{suffix}' grad: {org_grad}, {shard_grad}")

        grad_to_check[suffix] = {
            "org_grad": org_grad.float(),
            "shard_grad": shard_grad.float(),
            "rtol": rtol,
            "atol": atol,
        }

    return grad_to_check
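

# A minimal sketch of the gather-and-trim step used above. Plain nested lists
# stand in for tensors so it runs without torch or a process group.
# `gather_and_trim` is a hypothetical name, and the sketch assumes row-wise
# sharding (dim=0) with any extra rows appended at the end of the last shard.

```python
from typing import List

Matrix = List[List[float]]


def gather_and_trim(shards: List[Matrix], org_rows: int) -> Matrix:
    """Concatenate per-rank shards along dim 0, then drop any extra rows."""
    # Stand-in for dist.all_gather followed by torch.cat(..., dim=0).
    full = [row for shard in shards for row in shard]
    # The sharded embedding may have more rows than the original one.
    if len(full) > org_rows:
        full = full[:org_rows]
    return full


# Two ranks hold 3 rows each, but the original gradient has only 5 rows.
rank0 = [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]
rank1 = [[4.0, 4.0], [5.0, 5.0], [0.0, 0.0]]  # last row is the extra one
grad = gather_and_trim([rank0, rank1], org_rows=5)
```

# After trimming, `grad` has the same number of rows as the original gradient,
# which is what lets the element-wise comparison run.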


# used by sam/blip2
def check_grad(
    org_model: Module,
    sharded_model: Module,
    layer_suffix: List[str],
    tp_group: Optional[ProcessGroup] = None,
    dim: int = 0,
    atol: float = 1e-5,
    rtol: float = 1e-3,
    verbose: bool = False,
):
    for suffix in layer_suffix:
        org_grad = getattr_(org_model, suffix).weight.grad
        shard_grad = getattr_(sharded_model, suffix).weight.grad
        shard_weight = getattr_(sharded_model, suffix).weight
        # if verbose and dist.get_rank() == 0:
        #     print("shard_weight", shard_weight)
        #     print("org_grad", org_grad)
        if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
            shard_grad_list = [torch.zeros_like(shard_grad).to("cuda") for _ in range(dist.get_world_size(tp_group))]
            dist.all_gather(shard_grad_list, shard_grad, tp_group)
            shard_grad = torch.cat(shard_grad_list, dim=dim)

        # embedding may be resized when using tensor parallel
        if shard_grad.shape[0] > org_grad.shape[0]:
            shard_grad = shard_grad[: org_grad.shape[0], :]
        if verbose and dist.get_rank() == 0:
            print(f"'{suffix}' grad: {org_grad}, {shard_grad}")

        assert_close(org_grad.float(), shard_grad.float(), rtol=rtol, atol=atol)


def unwrap_model(
    module: Module,
    base_model_class_name: Optional[str] = None,
    base_model_attribute_name: Optional[str] = None,
):
    if isinstance(module, HybridParallelModule):
        module = module.unwrap()
    if base_model_class_name is None:
        return module
    if module.__class__.__name__ == base_model_class_name:
        return module
    return getattr(module, base_model_attribute_name, None)


def check_all_grad_tensors(check_tensors):
    """
    Compare every gradient pair collected by `get_grad_tensors_for_check`.

    Each value in `check_tensors` is a dict with the keys:
        "org_grad": tensor to be compared from the original model
        "shard_grad": tensor to be compared from the sharded model
        "rtol", "atol": tolerances passed to `assert_close`
    """
    for suffix, check_info in check_tensors.items():
        org_grad = check_info["org_grad"]
        shard_grad = check_info["shard_grad"]
        rtol = check_info["rtol"]
        atol = check_info["atol"]
        assert_close(org_grad, shard_grad, atol=atol, rtol=rtol)
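

# For reference, `torch.testing.assert_close(actual, expected)` treats two values
# as close when |actual - expected| <= atol + rtol * |expected|. The scalar sketch
# below applies that rule with the default tolerances used throughout this file.
# `is_close` is a hypothetical helper, not a torch API, and it ignores the NaN,
# dtype and device checks that the real function also performs.

```python
def is_close(actual: float, expected: float, atol: float = 1e-5, rtol: float = 1e-3) -> bool:
    """Scalar version of the closeness rule: |a - e| <= atol + rtol * |e|."""
    return abs(actual - expected) <= atol + rtol * abs(expected)


# A relative error of 0.05% passes with rtol=1e-3 ...
ok = is_close(1.0005, 1.0)
# ... while a relative error of 0.5% does not.
bad = is_close(1.005, 1.0)
# Near zero the absolute tolerance dominates.
tiny = is_close(5e-6, 0.0)
```

# With these defaults, values of magnitude around 1 may differ by roughly 1e-3,
# while values near zero must agree to within 1e-5.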