[shardformer] add util functions for shardformer tests/fix sync_shared_param (#4366)

* add util functions for shardformer tests & rewrite gpt2 test

* fix shared_params & embedding/merging

* fix precision
This commit is contained in:
Baizhou Zhang 2023-08-03 17:50:15 +08:00 committed by Hongxin Liu
parent 5c6f183192
commit b1feeced8e
4 changed files with 189 additions and 113 deletions

View File

@ -37,7 +37,8 @@ class HybridParallelModule(ModelWrapper):
self.shared_param_process_groups = []
for shared_param in self.shared_params:
if len(shared_param) > 0:
self.stage_manager.init_process_group_by_stages(list(shared_param.keys()))
self.shared_param_process_groups.append(
self.stage_manager.init_process_group_by_stages(list(shared_param.keys())))
if precision == 'fp16':
module = module.half().cuda()
elif precision == 'bf16':

View File

@ -72,7 +72,9 @@ config = transformers.GPT2Config(n_layer=2,
embd_pdrop=0,
resid_pdrop=0,
summary_first_dropout=0,
hidden_dropout=0)
hidden_dropout=0,
problem_type="single_label_classification",
pad_token_id=50256)
config_for_token_classification = copy.deepcopy(config)
config_for_token_classification.num_labels = 2

View File

@ -1,11 +1,19 @@
import copy
from contextlib import nullcontext
from typing import Any, Callable, Dict, List, Optional
import torch
import torch.distributed as dist
from torch import Tensor
from torch import distributed as dist
from torch.distributed import ProcessGroup
from torch.nn import Module
from torch.optim import Adam, Optimizer
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.lazy import LazyInitContext
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer._utils import getattr_
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
@ -79,20 +87,151 @@ def check_state_dict(org_model: Module, sharded_model: Module, name: str = ''):
assert torch.equal(v, shard_v), f'{name} {k} value mismatch'
def check_grad(original_model, sharded_model, layer_suffix, atol=1e-5, rtol=1e-5, dim=0, verbose=False):
def build_model_from_hybrid_plugin(model_fn: Callable, loss_fn: Callable, test_config: Dict[str, Any]):
use_lazy_init = False
if 'use_lazy_init' in test_config:
use_lazy_init = test_config.pop('use_lazy_init')
if use_lazy_init:
ctx = LazyInitContext()
else:
ctx = nullcontext()
plugin = HybridParallelPlugin(**test_config)
booster = Booster(plugin=plugin)
with ctx:
org_model = model_fn().cuda()
sharded_model = copy.deepcopy(org_model)
if use_lazy_init:
org_model = ctx.materialize(org_model)
org_optimizer = Adam(org_model.parameters(), lr=1e-3)
sharded_optimizer = Adam(sharded_model.parameters(), lr=1e-3)
criterion = loss_fn
sharded_model, sharded_optimizer, criterion, _, _ = booster.boost(sharded_model, sharded_optimizer, criterion)
return org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster
def run_forward_backward_with_hybrid_plugin(org_model: Module, sharded_model: Module, sharded_optimizer: Optimizer,
data_gen_fn: Callable, output_transform_fn: Callable, criterion: Callable,
booster: Booster):
def _criterion(outputs, inputs):
outputs = output_transform_fn(outputs)
loss = criterion(outputs)
return loss
data = data_gen_fn()
sharded_model.train()
if booster.plugin.stage_manager is not None:
data = {
k: v.to('cuda').repeat(4, 1) if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v
for k, v in data.items()
}
data_iter = iter([data])
sharded_output = booster.execute_pipeline(data_iter,
sharded_model,
_criterion,
sharded_optimizer,
return_loss=True,
return_outputs=True)
sharded_loss = sharded_output['loss']
else:
data = {k: v.cuda() for k, v in data.items()}
sharded_output = sharded_model(**data)
sharded_loss = criterion(sharded_output)
sharded_loss.backward()
org_model.train()
org_output = org_model(**data)
org_loss = criterion(org_output)
org_loss.backward()
return org_loss, org_output, sharded_loss, sharded_output
def check_output_hidden_state(org_output: Tensor,
sharded_output: Tensor,
stage_manager: Optional[PipelineStageManager] = None,
atol: float = 1e-5,
rtol: float = 1e-3):
org_hidden_state = org_output.last_hidden_state
if stage_manager is None:
sharded_hidden_state = sharded_output.last_hidden_state
if stage_manager and stage_manager.is_last_stage():
sharded_hidden_state = torch.cat([output.last_hidden_state for output in sharded_output['outputs']], dim=0)
assert torch.allclose(org_hidden_state, sharded_hidden_state, atol=atol, rtol=rtol), \
f"shard model's output hidden state is not equal to origin model's last hidden state\n{org_hidden_state}\n{sharded_hidden_state}"
def check_loss(org_loss: Tensor, sharded_loss: Tensor, atol: float = 1e-5, rtol: float = 1e-3):
assert torch.allclose(org_loss, sharded_loss, atol=atol, rtol=rtol), \
f"shard model loss is not equal to origin model loss\n{org_loss}\n{sharded_loss}"
def check_weight(org_model: Module,
sharded_model: Module,
layer_suffix: List[str],
tp_group: Optional[ProcessGroup] = None,
dim: int = 0,
atol: float = 1e-5,
rtol: float = 1e-3,
verbose: bool = False):
for suffix in layer_suffix:
org_grad = getattr_(original_model, suffix).weight.grad
org_weight = getattr_(org_model, suffix).weight
sharded_weight = getattr_(sharded_model, suffix).weight
if is_distributed_tensor(sharded_weight) or is_customized_distributed_tensor(sharded_weight):
sharded_weight_list = [
torch.zeros([*sharded_weight.shape]).to('cuda') for _ in range(dist.get_world_size(tp_group))
]
dist.all_gather(sharded_weight_list, sharded_weight, tp_group)
sharded_weight = torch.cat(sharded_weight_list, dim=dim)
if verbose and dist.get_rank() == 0:
print(f"'{suffix}' weight: {org_weight}, {sharded_weight}")
assert torch.allclose(org_weight, sharded_weight, atol=atol, rtol=rtol), \
f"shard model weight is not equal to origin model weight\n{org_weight}\n{sharded_weight}"
def check_grad(org_model: Module,
sharded_model: Module,
layer_suffix: List[str],
tp_group: ProcessGroup = None,
dim: int = 0,
atol: float = 1e-5,
rtol: float = 1e-3,
verbose: bool = False):
for suffix in layer_suffix:
org_grad = getattr_(org_model, suffix).weight.grad
shard_grad = getattr_(sharded_model, suffix).weight.grad
shard_weight = getattr_(sharded_model, suffix).weight
if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(dist.get_world_size())]
shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
all_shard_grad = torch.cat(shard_grad_list, dim=dim)
else:
all_shard_grad = shard_grad
shard_grad_list = [
torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(dist.get_world_size(tp_group))
]
dist.all_gather(shard_grad_list, shard_grad, tp_group)
shard_grad = torch.cat(shard_grad_list, dim=dim)
# embedding may be resized when using tensor parallel
if shard_grad.shape[0] > org_grad.shape[0]:
shard_grad = shard_grad[:org_grad.shape[0], :]
if verbose and dist.get_rank() == 0:
print(f"'{suffix}' grad: {org_grad}, {all_shard_grad}")
print(f"'{suffix}' grad: {org_grad}, {shard_grad}")
assert torch.allclose(
org_grad, all_shard_grad, rtol=rtol, atol=atol
), f"error attribute '{suffix}', orgin model grad is not equal to shard model grad\n{org_grad}\n{all_shard_grad}"
org_grad, shard_grad, rtol=rtol, atol=atol
), f"error attribute '{suffix}', orgin model grad is not equal to shard model grad\n{org_grad}\n{shard_grad}"

View File

@ -1,107 +1,48 @@
import copy
from contextlib import nullcontext
import pytest
import torch
from torch import distributed as dist
from torch.optim import Adam
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.lazy.lazy_init import LazyInitContext
from colossalai.logging import disable_existing_loggers
from colossalai.tensor.d_tensor.api import (
clear_layout_converter,
is_customized_distributed_tensor,
is_distributed_tensor,
)
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_model, check_grad, check_state_dict, run_forward
from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
check_grad,
check_loss,
check_output_hidden_state,
check_weight,
run_forward_backward_with_hybrid_plugin,
)
def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
use_lazy_init = False
if 'use_lazy_init' in test_config:
use_lazy_init = test_config.pop('use_lazy_init')
org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \
build_model_from_hybrid_plugin(model_fn, loss_fn, test_config)
if use_lazy_init:
ctx = LazyInitContext()
else:
ctx = nullcontext()
# prepare booster
plugin = HybridParallelPlugin(**test_config)
booster = Booster(plugin=plugin)
stage_manager = plugin.stage_manager
org_loss, org_output, sharded_loss, sharded_output = \
run_forward_backward_with_hybrid_plugin(
org_model,
sharded_model,
sharded_optimizer,
data_gen_fn,
output_transform_fn,
criterion,
booster)
# prepare models and optimizers
with ctx:
org_model = model_fn().cuda()
sharded_model = copy.deepcopy(org_model)
if use_lazy_init:
org_model = ctx.materialize(org_model)
org_optimizer = Adam(org_model.parameters(), lr=1e-3)
sharded_optimizer = Adam(sharded_model.parameters(), lr=1e-3)
criterion = loss_fn
sharded_model, sharded_optimizer, criterion, _, _ = booster.boost(sharded_model, sharded_optimizer, criterion)
def _criterion(outputs, inputs):
outputs = output_transform_fn(outputs)
loss = criterion(outputs)
return loss
# do forward and backward
data = data_gen_fn()
sharded_model.train()
if stage_manager:
data = {
k: v.to('cuda').repeat(4, 1) if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v
for k, v in data.items()
}
data_iter = iter([data])
sharded_output = booster.execute_pipeline(data_iter,
sharded_model,
_criterion,
sharded_optimizer,
return_loss=True,
return_outputs=True)
sharded_loss = sharded_output['loss']
else:
data = {k: v.cuda() for k, v in data.items()}
sharded_output = sharded_model(**data)
sharded_loss = criterion(sharded_output)
sharded_loss.backward()
org_model.train()
org_output = org_model(**data)
org_loss = criterion(org_output)
org_loss.backward()
stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
# check last hidden state
if org_model.__class__.__name__ == 'GPT2Model':
org_hidden_state = org_output.last_hidden_state
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=1e-5, rtol=1e-3)
if stage_manager is None:
sharded_hidden_state = sharded_output.last_hidden_state
if stage_manager and stage_manager.is_last_stage():
sharded_hidden_state = torch.cat([output.last_hidden_state for output in sharded_output['outputs']],
dim=0)
assert torch.allclose(org_hidden_state, sharded_hidden_state, atol=1e-5, rtol=1e-3), \
f"shard model's output hidden state is not equal to origin model's last hidden state\n{org_hidden_state}\n{sharded_hidden_state}"
# check loss
assert torch.allclose(org_loss, sharded_loss, atol=1e-5, rtol=1e-3), \
f"shard model loss is not equal to origin model loss\n{org_loss}\n{sharded_loss}"
check_loss(org_loss, sharded_loss, atol=1e-5, rtol=1e-3)
# unwrap model
if org_model.__class__.__name__ == 'GPT2Model':
@ -111,27 +52,19 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
gpt2 = org_model.transformer
sharded_gpt2 = sharded_model.unwrap().transformer
# check grad
col_layer_for_check = ['h[0].mlp.c_fc']
row_layer_for_check = ['h[0].mlp.c_proj']
check_grad(gpt2, sharded_gpt2, col_layer_for_check, atol=1e-6, rtol=1e-3, dim=1, verbose=False)
check_grad(gpt2, sharded_gpt2, row_layer_for_check, atol=1e-6, rtol=1e-3, dim=0, verbose=False)
row_layer_for_check = ['wte', 'h[0].mlp.c_proj']
# check grad
if stage_manager is None or stage_manager.is_first_stage():
check_grad(gpt2, sharded_gpt2, col_layer_for_check, tp_group, atol=1e-4, rtol=1e-3, dim=1, verbose=False)
check_grad(gpt2, sharded_gpt2, row_layer_for_check, tp_group, atol=1e-4, rtol=1e-3, dim=0, verbose=False)
# check weights after optimizer.step()
org_optimizer.step()
sharded_optimizer.step()
if stage_manager is None or stage_manager.is_first_stage():
org_weight = org_model.h[0].mlp.c_fc.weight
shard_weight = sharded_model.h[0].mlp.c_fc.weight
if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
shard_weight_list = [torch.zeros([*shard_weight.shape]).to('cuda') for _ in range(plugin.tp_size)]
dist.all_gather(shard_weight_list, shard_weight, plugin.tp_group)
shard_weight = torch.cat(shard_weight_list, dim=1)
assert torch.allclose(org_weight, shard_weight, atol=5e-3, rtol=1e-3), \
f"shard model weight is not equal to origin model weight\n{org_weight}\n{shard_weight}"
check_weight(gpt2, sharded_gpt2, col_layer_for_check, tp_group, atol=5e-3, rtol=1e-3, dim=1, verbose=False)
torch.cuda.empty_cache()
@ -156,9 +89,11 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
@clear_cache_before_run()
def run_gpt2_test(test_config):
# TODO: add plugin_config for TP+DP after supporting & debugging it
# TODO: add test_config for TP+DP after supporting & debugging it
# {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True}
# TODO: add test_config for flash attention & jit operator after supporting
sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
test_config['precision'] = 'float' # Do not use fp16/bf16 in testing
@ -175,7 +110,6 @@ def check_gpt2(rank, world_size, port):
run_gpt2_test()
@pytest.mark.skip('Have some bug caused by merge')
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()