[shardformer] add util functions for shardformer tests/fix sync_shared_param (#4366)

* add util functions for shardformer tests & rewrite gpt2 test

* fix shared_params & embedding/merging

* fix precision
Authored by Baizhou Zhang on 2023-08-03 17:50:15 +08:00; committed by Hongxin Liu
parent 5c6f183192
commit b1feeced8e
4 changed files with 189 additions and 113 deletions


@@ -37,7 +37,8 @@ class HybridParallelModule(ModelWrapper):
         self.shared_param_process_groups = []
         for shared_param in self.shared_params:
             if len(shared_param) > 0:
-                self.stage_manager.init_process_group_by_stages(list(shared_param.keys()))
+                self.shared_param_process_groups.append(
+                    self.stage_manager.init_process_group_by_stages(list(shared_param.keys())))
         if precision == 'fp16':
             module = module.half().cuda()
         elif precision == 'bf16':
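
With this change each shared parameter gets its own process group instead of the group being created and immediately discarded. A minimal sketch (illustrative only, not the plugin's actual sync code) of what such a group enables: keeping tied weights, e.g. input/output embeddings, consistent across the pipeline stages that hold a copy.

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup


def sync_shared_param_grad(param: torch.nn.Parameter, group: ProcessGroup) -> None:
    # Sum the gradient over every pipeline stage that holds a copy of the tied
    # weight, so all copies take the same optimizer step.
    if param.grad is not None:
        dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, group=group)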


@@ -72,7 +72,9 @@ config = transformers.GPT2Config(n_layer=2,
                                  embd_pdrop=0,
                                  resid_pdrop=0,
                                  summary_first_dropout=0,
-                                 hidden_dropout=0)
+                                 hidden_dropout=0,
+                                 problem_type="single_label_classification",
+                                 pad_token_id=50256)

 config_for_token_classification = copy.deepcopy(config)
 config_for_token_classification.num_labels = 2
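
The two new fields matter for the sequence-classification model in the zoo: problem_type="single_label_classification" makes the HuggingFace head use a cross-entropy loss, and pad_token_id lets GPT2ForSequenceClassification locate the last non-padding token it classifies on. A small self-contained check of that behavior (a hedged example, not taken from the commit):

import torch
import transformers

# Hedged example: exercise the two config fields added above on a tiny random model.
cfg = transformers.GPT2Config(n_layer=2,
                              num_labels=2,
                              problem_type="single_label_classification",
                              pad_token_id=50256)
model = transformers.GPT2ForSequenceClassification(cfg)

input_ids = torch.tensor([[15496, 995, 50256, 50256]])    # two real tokens, two pads
labels = torch.tensor([1])
out = model(input_ids=input_ids, labels=labels)            # pad_token_id picks the last real token
out.loss.backward()                                        # CrossEntropyLoss via problem_type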


@@ -1,11 +1,19 @@
 import copy
 from contextlib import nullcontext
+from typing import Any, Callable, Dict, List, Optional

 import torch
 import torch.distributed as dist
+from torch import Tensor
+from torch import distributed as dist
+from torch.distributed import ProcessGroup
 from torch.nn import Module
+from torch.optim import Adam, Optimizer

+from colossalai.booster import Booster
+from colossalai.booster.plugin import HybridParallelPlugin
 from colossalai.lazy import LazyInitContext
+from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig, ShardFormer
 from colossalai.shardformer._utils import getattr_
 from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
@@ -79,20 +87,151 @@ def check_state_dict(org_model: Module, sharded_model: Module, name: str = ''):
         assert torch.equal(v, shard_v), f'{name} {k} value mismatch'


-def check_grad(original_model, sharded_model, layer_suffix, atol=1e-5, rtol=1e-5, dim=0, verbose=False):
-    for suffix in layer_suffix:
-        org_grad = getattr_(original_model, suffix).weight.grad
-        shard_grad = getattr_(sharded_model, suffix).weight.grad
-        shard_weight = getattr_(sharded_model, suffix).weight
-        if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
-            shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(dist.get_world_size())]
-            shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
-            all_shard_grad = torch.cat(shard_grad_list, dim=dim)
-        else:
-            all_shard_grad = shard_grad
-        if verbose and dist.get_rank() == 0:
-            print(f"'{suffix}' grad: {org_grad}, {all_shard_grad}")
-        assert torch.allclose(
-            org_grad, all_shard_grad, rtol=rtol, atol=atol
-        ), f"error attribute '{suffix}', orgin model grad is not equal to shard model grad\n{org_grad}\n{all_shard_grad}"
+def build_model_from_hybrid_plugin(model_fn: Callable, loss_fn: Callable, test_config: Dict[str, Any]):
+    use_lazy_init = False
+    if 'use_lazy_init' in test_config:
+        use_lazy_init = test_config.pop('use_lazy_init')
+    if use_lazy_init:
+        ctx = LazyInitContext()
+    else:
+        ctx = nullcontext()
+
+    plugin = HybridParallelPlugin(**test_config)
+    booster = Booster(plugin=plugin)
+
+    with ctx:
+        org_model = model_fn().cuda()
+        sharded_model = copy.deepcopy(org_model)
+    if use_lazy_init:
+        org_model = ctx.materialize(org_model)
+
+    org_optimizer = Adam(org_model.parameters(), lr=1e-3)
+    sharded_optimizer = Adam(sharded_model.parameters(), lr=1e-3)
+    criterion = loss_fn
+
+    sharded_model, sharded_optimizer, criterion, _, _ = booster.boost(sharded_model, sharded_optimizer, criterion)
+
+    return org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster
+
+
+def run_forward_backward_with_hybrid_plugin(org_model: Module, sharded_model: Module, sharded_optimizer: Optimizer,
+                                            data_gen_fn: Callable, output_transform_fn: Callable, criterion: Callable,
+                                            booster: Booster):
+
+    def _criterion(outputs, inputs):
+        outputs = output_transform_fn(outputs)
+        loss = criterion(outputs)
+        return loss
+
+    data = data_gen_fn()
+    sharded_model.train()
+    if booster.plugin.stage_manager is not None:
+        data = {
+            k: v.to('cuda').repeat(4, 1) if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v
+            for k, v in data.items()
+        }
+        data_iter = iter([data])
+        sharded_output = booster.execute_pipeline(data_iter,
+                                                  sharded_model,
+                                                  _criterion,
+                                                  sharded_optimizer,
+                                                  return_loss=True,
+                                                  return_outputs=True)
+        sharded_loss = sharded_output['loss']
+    else:
+        data = {k: v.cuda() for k, v in data.items()}
+        sharded_output = sharded_model(**data)
+        sharded_loss = criterion(sharded_output)
+        sharded_loss.backward()
+
+    org_model.train()
+    org_output = org_model(**data)
+    org_loss = criterion(org_output)
+    org_loss.backward()
+
+    return org_loss, org_output, sharded_loss, sharded_output
+
+
+def check_output_hidden_state(org_output: Tensor,
+                              sharded_output: Tensor,
+                              stage_manager: Optional[PipelineStageManager] = None,
+                              atol: float = 1e-5,
+                              rtol: float = 1e-3):
+    org_hidden_state = org_output.last_hidden_state
+
+    if stage_manager is None:
+        sharded_hidden_state = sharded_output.last_hidden_state
+
+    if stage_manager and stage_manager.is_last_stage():
+        sharded_hidden_state = torch.cat([output.last_hidden_state for output in sharded_output['outputs']], dim=0)
+
+    assert torch.allclose(org_hidden_state, sharded_hidden_state, atol=atol, rtol=rtol), \
+        f"shard model's output hidden state is not equal to origin model's last hidden state\n{org_hidden_state}\n{sharded_hidden_state}"
+
+
+def check_loss(org_loss: Tensor, sharded_loss: Tensor, atol: float = 1e-5, rtol: float = 1e-3):
+    assert torch.allclose(org_loss, sharded_loss, atol=atol, rtol=rtol), \
+        f"shard model loss is not equal to origin model loss\n{org_loss}\n{sharded_loss}"
+
+
+def check_weight(org_model: Module,
+                 sharded_model: Module,
+                 layer_suffix: List[str],
+                 tp_group: Optional[ProcessGroup] = None,
+                 dim: int = 0,
+                 atol: float = 1e-5,
+                 rtol: float = 1e-3,
+                 verbose: bool = False):
+    for suffix in layer_suffix:
+        org_weight = getattr_(org_model, suffix).weight
+        sharded_weight = getattr_(sharded_model, suffix).weight
+
+        if is_distributed_tensor(sharded_weight) or is_customized_distributed_tensor(sharded_weight):
+            sharded_weight_list = [
+                torch.zeros([*sharded_weight.shape]).to('cuda') for _ in range(dist.get_world_size(tp_group))
+            ]
+            dist.all_gather(sharded_weight_list, sharded_weight, tp_group)
+            sharded_weight = torch.cat(sharded_weight_list, dim=dim)
+
+        if verbose and dist.get_rank() == 0:
+            print(f"'{suffix}' weight: {org_weight}, {sharded_weight}")
+
+        assert torch.allclose(org_weight, sharded_weight, atol=atol, rtol=rtol), \
+            f"shard model weight is not equal to origin model weight\n{org_weight}\n{sharded_weight}"
+
+
+def check_grad(org_model: Module,
+               sharded_model: Module,
+               layer_suffix: List[str],
+               tp_group: ProcessGroup = None,
+               dim: int = 0,
+               atol: float = 1e-5,
+               rtol: float = 1e-3,
+               verbose: bool = False):
+    for suffix in layer_suffix:
+        org_grad = getattr_(org_model, suffix).weight.grad
+        shard_grad = getattr_(sharded_model, suffix).weight.grad
+        shard_weight = getattr_(sharded_model, suffix).weight
+
+        if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
+            shard_grad_list = [
+                torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(dist.get_world_size(tp_group))
+            ]
+            dist.all_gather(shard_grad_list, shard_grad, tp_group)
+            shard_grad = torch.cat(shard_grad_list, dim=dim)
+
+        # embedding may be resized when using tensor parallel
+        if shard_grad.shape[0] > org_grad.shape[0]:
+            shard_grad = shard_grad[:org_grad.shape[0], :]
+
+        if verbose and dist.get_rank() == 0:
+            print(f"'{suffix}' grad: {org_grad}, {shard_grad}")
+
+        assert torch.allclose(
+            org_grad, shard_grad, rtol=rtol, atol=atol
+        ), f"error attribute '{suffix}', orgin model grad is not equal to shard model grad\n{org_grad}\n{shard_grad}"
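
The new slice in check_grad accounts for embeddings that get resized under tensor parallelism: the vocabulary may be padded so it splits evenly across TP ranks, so the gathered gradient can have more rows than the original. The arithmetic, with an assumed tp_size of 4 for illustration:

# Assumed numbers for illustration: GPT-2's vocab (50257) is not divisible by tp_size=4,
# so the sharded embedding is padded before splitting.
vocab_size, tp_size = 50257, 4
padded_vocab = (vocab_size + tp_size - 1) // tp_size * tp_size    # 50260
rows_per_rank = padded_vocab // tp_size                           # 12565
# After all_gather + cat, shard_grad has 50260 rows vs. org_grad's 50257, which is
# why the extra rows are dropped with shard_grad[:org_grad.shape[0], :].
assert (padded_vocab, rows_per_rank) == (50260, 12565)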


@@ -1,107 +1,48 @@
-import copy
-from contextlib import nullcontext
-
 import pytest
 import torch
 from torch import distributed as dist
-from torch.optim import Adam

 import colossalai
-from colossalai.booster import Booster
-from colossalai.booster.plugin import HybridParallelPlugin
-from colossalai.lazy.lazy_init import LazyInitContext
 from colossalai.logging import disable_existing_loggers
-from colossalai.tensor.d_tensor.api import (
-    clear_layout_converter,
-    is_customized_distributed_tensor,
-    is_distributed_tensor,
-)
+from colossalai.tensor.d_tensor.api import clear_layout_converter
 from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
 from tests.kit.model_zoo import model_zoo
-from tests.test_shardformer.test_model._utils import build_model, check_grad, check_state_dict, run_forward
+from tests.test_shardformer.test_model._utils import (
+    build_model_from_hybrid_plugin,
+    check_grad,
+    check_loss,
+    check_output_hidden_state,
+    check_weight,
+    run_forward_backward_with_hybrid_plugin,
+)


 def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
-    use_lazy_init = False
-    if 'use_lazy_init' in test_config:
-        use_lazy_init = test_config.pop('use_lazy_init')
-    if use_lazy_init:
-        ctx = LazyInitContext()
-    else:
-        ctx = nullcontext()
-
-    # prepare booster
-    plugin = HybridParallelPlugin(**test_config)
-    booster = Booster(plugin=plugin)
-    stage_manager = plugin.stage_manager
-
-    # prepare models and optimizers
-    with ctx:
-        org_model = model_fn().cuda()
-        sharded_model = copy.deepcopy(org_model)
-    if use_lazy_init:
-        org_model = ctx.materialize(org_model)
-
-    org_optimizer = Adam(org_model.parameters(), lr=1e-3)
-    sharded_optimizer = Adam(sharded_model.parameters(), lr=1e-3)
-    criterion = loss_fn
-
-    sharded_model, sharded_optimizer, criterion, _, _ = booster.boost(sharded_model, sharded_optimizer, criterion)
-
-    def _criterion(outputs, inputs):
-        outputs = output_transform_fn(outputs)
-        loss = criterion(outputs)
-        return loss
-
-    # do forward and backward
-    data = data_gen_fn()
-    sharded_model.train()
-    if stage_manager:
-        data = {
-            k: v.to('cuda').repeat(4, 1) if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v
-            for k, v in data.items()
-        }
-        data_iter = iter([data])
-        sharded_output = booster.execute_pipeline(data_iter,
-                                                  sharded_model,
-                                                  _criterion,
-                                                  sharded_optimizer,
-                                                  return_loss=True,
-                                                  return_outputs=True)
-        sharded_loss = sharded_output['loss']
-    else:
-        data = {k: v.cuda() for k, v in data.items()}
-        sharded_output = sharded_model(**data)
-        sharded_loss = criterion(sharded_output)
-        sharded_loss.backward()
-
-    org_model.train()
-    org_output = org_model(**data)
-    org_loss = criterion(org_output)
-    org_loss.backward()
-
+    org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \
+        build_model_from_hybrid_plugin(model_fn, loss_fn, test_config)
+
+    org_loss, org_output, sharded_loss, sharded_output = \
+        run_forward_backward_with_hybrid_plugin(
+            org_model,
+            sharded_model,
+            sharded_optimizer,
+            data_gen_fn,
+            output_transform_fn,
+            criterion,
+            booster)
+
+    stage_manager = booster.plugin.stage_manager
+    tp_group = booster.plugin.tp_group
+
+    # check last hidden state & loss
     if stage_manager is None or stage_manager.is_last_stage():
-
-        # check last hidden state
         if org_model.__class__.__name__ == 'GPT2Model':
-            org_hidden_state = org_output.last_hidden_state
-
-            if stage_manager is None:
-                sharded_hidden_state = sharded_output.last_hidden_state
-
-            if stage_manager and stage_manager.is_last_stage():
-                sharded_hidden_state = torch.cat([output.last_hidden_state for output in sharded_output['outputs']],
-                                                 dim=0)
-
-            assert torch.allclose(org_hidden_state, sharded_hidden_state, atol=1e-5, rtol=1e-3), \
-                f"shard model's output hidden state is not equal to origin model's last hidden state\n{org_hidden_state}\n{sharded_hidden_state}"
-
-        # check loss
-        assert torch.allclose(org_loss, sharded_loss, atol=1e-5, rtol=1e-3), \
-            f"shard model loss is not equal to origin model loss\n{org_loss}\n{sharded_loss}"
+            check_output_hidden_state(org_output, sharded_output, stage_manager, atol=1e-5, rtol=1e-3)
+
+        check_loss(org_loss, sharded_loss, atol=1e-5, rtol=1e-3)

     # unwrap model
     if org_model.__class__.__name__ == 'GPT2Model':
@@ -111,27 +52,19 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
         gpt2 = org_model.transformer
         sharded_gpt2 = sharded_model.unwrap().transformer

-        # check grad
         col_layer_for_check = ['h[0].mlp.c_fc']
-        row_layer_for_check = ['h[0].mlp.c_proj']
-        check_grad(gpt2, sharded_gpt2, col_layer_for_check, atol=1e-6, rtol=1e-3, dim=1, verbose=False)
-        check_grad(gpt2, sharded_gpt2, row_layer_for_check, atol=1e-6, rtol=1e-3, dim=0, verbose=False)
+        row_layer_for_check = ['wte', 'h[0].mlp.c_proj']
+
+        # check grad
+        if stage_manager is None or stage_manager.is_first_stage():
+            check_grad(gpt2, sharded_gpt2, col_layer_for_check, tp_group, atol=1e-4, rtol=1e-3, dim=1, verbose=False)
+            check_grad(gpt2, sharded_gpt2, row_layer_for_check, tp_group, atol=1e-4, rtol=1e-3, dim=0, verbose=False)

         # check weights after optimizer.step()
         org_optimizer.step()
         sharded_optimizer.step()
         if stage_manager is None or stage_manager.is_first_stage():
-            org_weight = org_model.h[0].mlp.c_fc.weight
-            shard_weight = sharded_model.h[0].mlp.c_fc.weight
-
-            if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
-                shard_weight_list = [torch.zeros([*shard_weight.shape]).to('cuda') for _ in range(plugin.tp_size)]
-                dist.all_gather(shard_weight_list, shard_weight, plugin.tp_group)
-                shard_weight = torch.cat(shard_weight_list, dim=1)
-
-            assert torch.allclose(org_weight, shard_weight, atol=5e-3, rtol=1e-3), \
-                f"shard model weight is not equal to origin model weight\n{org_weight}\n{shard_weight}"
+            check_weight(gpt2, sharded_gpt2, col_layer_for_check, tp_group, atol=5e-3, rtol=1e-3, dim=1, verbose=False)

     torch.cuda.empty_cache()
@@ -156,9 +89,11 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
 @clear_cache_before_run()
 def run_gpt2_test(test_config):

-    # TODO: add plugin_config for TP+DP after supporting & debugging it
+    # TODO: add test_config for TP+DP after supporting & debugging it
     # {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True}

+    # TODO: add test_config for flash attention & jit operator after supporting
+
     sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
     test_config['precision'] = 'float'    # Do not use fp16/bf16 in testing
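
The test_config dicts themselves come from the parameterization above this hunk (not shown here). As a rough sketch of their shape, built only from keys that appear in this diff; the actual entries in the test file may set additional plugin options:

# Sketch only: key names taken from this diff (tp_size, pp_size, enable_fused_normalization,
# use_lazy_init, precision); the real parameterized configs may differ.
example_test_config = {
    'tp_size': 2,
    'pp_size': 1,
    'enable_fused_normalization': True,
    'use_lazy_init': True,
}
# build_model_from_hybrid_plugin pops 'use_lazy_init' and forwards the rest to
# HybridParallelPlugin(**test_config); run_gpt2_test also forces precision to 'float'.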
@@ -175,7 +110,6 @@ def check_gpt2(rank, world_size, port):
     run_gpt2_test()


-@pytest.mark.skip('Have some bug caused by merge')
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 @clear_cache_before_run()