mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-17 07:00:37 +00:00
[shardformer] rewrite tests for opt/bloom/llama/vit/chatglm (#4395)
* rewrite opt tests * rewrite llama tests * rewrite bloom & vit tests * rewrite chatglm tests * fix LinearCol for classfiers * add judge for other tp layers, fix lazy init in util
This commit is contained in:
committed by
Hongxin Liu
parent
21e0a42fd1
commit
7711bd524a
@@ -53,7 +53,8 @@ def data_gen_for_question_answering():
|
||||
# inputs = tokenizer(question, text, return_tensors="pt")
|
||||
|
||||
input_ids = torch.tensor(
|
||||
[[57647, 1620, 23967, 620, 107373, 34, 91514, 620, 107373, 1620, 267, 35378, 48946, 18161, 48946, 18161]], dtype=torch.int64)
|
||||
[[57647, 1620, 23967, 620, 107373, 34, 91514, 620, 107373, 1620, 267, 35378, 48946, 18161, 48946, 18161]],
|
||||
dtype=torch.int64)
|
||||
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
|
||||
start_positions = torch.tensor([1], dtype=torch.int64)
|
||||
end_positions = torch.tensor([10], dtype=torch.int64)
|
||||
@@ -73,12 +74,13 @@ loss_fn_for_causal_lm = lambda x: x.loss
|
||||
loss_fn_for_classification = lambda x: x.loss
|
||||
loss_fn_for_question_answering = lambda x: x.loss
|
||||
|
||||
config = transformers.BloomConfig(n_layer=1,
|
||||
config = transformers.BloomConfig(n_layer=2,
|
||||
n_head=4,
|
||||
vocab_size=250880,
|
||||
hidden_dropout=0,
|
||||
attention_dropout=0,
|
||||
hidden_size=64)
|
||||
hidden_size=64,
|
||||
pad_token_id=50256)
|
||||
|
||||
# register the following models
|
||||
model_zoo.register(name='transformers_bloom',
|
||||
|
@@ -17,14 +17,24 @@ def data_gen():
|
||||
return dict(input_ids=input_ids, attention_mask=attention_mask)
|
||||
|
||||
|
||||
def data_gen_for_conditional_generation():
|
||||
# token classification data gen
|
||||
# `labels` is the type not the token id for token classification, 0 or 1
|
||||
data = data_gen()
|
||||
labels = data['input_ids'].clone()
|
||||
data['labels'] = labels
|
||||
return data
|
||||
|
||||
|
||||
# define output transform function
|
||||
output_transform_fn = lambda x: x
|
||||
|
||||
# define loss function
|
||||
loss_fn_for_chatglm_model = lambda x: x.last_hidden_state.sum()
|
||||
loss_fn = lambda x: x.logits.sum()
|
||||
loss_fn_for_chatglm_model = lambda x: torch.nn.functional.mse_loss(x.last_hidden_state,
|
||||
torch.ones_like(x.last_hidden_state))
|
||||
loss_fn = lambda x: x.loss
|
||||
|
||||
config = ChatGLMConfig(num_layers=1,
|
||||
config = ChatGLMConfig(num_layers=2,
|
||||
padded_vocab_size=65024,
|
||||
hidden_size=64,
|
||||
num_attention_heads=8,
|
||||
@@ -33,7 +43,6 @@ config = ChatGLMConfig(num_layers=1,
|
||||
use_cache=True,
|
||||
torch_dtype=torch.float32)
|
||||
|
||||
|
||||
model_zoo.register(name='transformers_chatglm',
|
||||
model_fn=lambda: ChatGLMModel(config, empty_init=False),
|
||||
data_gen_fn=data_gen,
|
||||
@@ -43,7 +52,7 @@ model_zoo.register(name='transformers_chatglm',
|
||||
|
||||
model_zoo.register(name="transformers_chatglm_for_conditional_generation",
|
||||
model_fn=lambda: ChatGLMForConditionalGeneration(config, empty_init=False),
|
||||
data_gen_fn=data_gen,
|
||||
data_gen_fn=data_gen_for_conditional_generation,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn,
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
||||
|
@@ -7,11 +7,7 @@ from ..registry import ModelAttribute, model_zoo
|
||||
# Register single-sentence VIT
|
||||
# ===============================
|
||||
|
||||
config = transformers.ViTConfig(
|
||||
num_hidden_layers=4,
|
||||
# hidden_size=128,
|
||||
# intermediate_size=256,
|
||||
num_attention_heads=4)
|
||||
config = transformers.ViTConfig(num_hidden_layers=4, hidden_size=128, intermediate_size=256, num_attention_heads=4)
|
||||
|
||||
|
||||
# define data gen function
|
||||
|
@@ -104,27 +104,22 @@ def build_model_from_hybrid_plugin(model_fn: Callable, loss_fn: Callable, test_c
|
||||
if 'use_lazy_init' in test_config:
|
||||
use_lazy_init = test_config.pop('use_lazy_init')
|
||||
|
||||
if use_lazy_init:
|
||||
ctx = LazyInitContext()
|
||||
else:
|
||||
ctx = nullcontext()
|
||||
|
||||
plugin = HybridParallelPlugin(**test_config)
|
||||
booster = Booster(plugin=plugin)
|
||||
|
||||
ctx = LazyInitContext() if use_lazy_init else nullcontext()
|
||||
with ctx:
|
||||
org_model = model_fn().cuda()
|
||||
org_model = model_fn()
|
||||
sharded_model = copy.deepcopy(org_model)
|
||||
|
||||
if use_lazy_init:
|
||||
org_model = ctx.materialize(org_model)
|
||||
ctx.materialize(org_model)
|
||||
|
||||
org_model = org_model.cuda()
|
||||
org_optimizer = Adam(org_model.parameters(), lr=1e-3)
|
||||
sharded_optimizer = Adam(sharded_model.parameters(), lr=1e-3)
|
||||
criterion = loss_fn
|
||||
|
||||
sharded_model, sharded_optimizer, criterion, _, _ = booster.boost(sharded_model, sharded_optimizer, criterion)
|
||||
plugin = HybridParallelPlugin(**test_config)
|
||||
booster = Booster(plugin=plugin)
|
||||
|
||||
sharded_model, sharded_optimizer, criterion, _, _ = booster.boost(sharded_model, sharded_optimizer, criterion)
|
||||
return org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster
|
||||
|
||||
|
||||
@@ -142,11 +137,12 @@ def run_forward_backward_with_hybrid_plugin(org_model: Module, sharded_model: Mo
|
||||
data = data_gen_fn()
|
||||
sharded_model.train()
|
||||
if booster.plugin.stage_manager is not None:
|
||||
data = {
|
||||
k: v.to('cuda').repeat(*([4] + [1] *
|
||||
(v.dim() - 1))) if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v
|
||||
for k, v in data.items()
|
||||
}
|
||||
for k, v in data.items():
|
||||
if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__:
|
||||
new_shape = [1] * v.dim()
|
||||
new_shape[0] = 4
|
||||
data[k] = v.to('cuda').repeat(*new_shape)
|
||||
|
||||
data_iter = iter([data])
|
||||
sharded_output = booster.execute_pipeline(data_iter,
|
||||
sharded_model,
|
||||
@@ -176,7 +172,8 @@ def check_output_hidden_state(org_output: Tensor,
|
||||
sharded_output: Tensor,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
atol: float = 1e-5,
|
||||
rtol: float = 1e-3):
|
||||
rtol: float = 1e-3,
|
||||
dim: int = 0):
|
||||
|
||||
org_hidden_state = org_output.last_hidden_state
|
||||
|
||||
@@ -184,7 +181,7 @@ def check_output_hidden_state(org_output: Tensor,
|
||||
sharded_hidden_state = sharded_output.last_hidden_state
|
||||
|
||||
if stage_manager and stage_manager.is_last_stage():
|
||||
sharded_hidden_state = torch.cat([output.last_hidden_state for output in sharded_output['outputs']], dim=0)
|
||||
sharded_hidden_state = torch.cat([output.last_hidden_state for output in sharded_output['outputs']], dim=dim)
|
||||
|
||||
assert torch.allclose(org_hidden_state.float(), sharded_hidden_state.float(), atol=atol, rtol=rtol), \
|
||||
f"shard model's output hidden state is not equal to origin model's last hidden state\n{org_hidden_state}\n{sharded_hidden_state}"
|
||||
|
@@ -3,57 +3,101 @@ import torch
|
||||
|
||||
import colossalai
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.testing import (
|
||||
assert_hf_output_close,
|
||||
clear_cache_before_run,
|
||||
parameterize,
|
||||
rerun_if_address_is_in_use,
|
||||
spawn,
|
||||
)
|
||||
from colossalai.tensor.d_tensor.api import clear_layout_converter
|
||||
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
from tests.test_shardformer.test_model._utils import build_model, check_grad, check_state_dict, run_forward
|
||||
from tests.test_shardformer.test_model._utils import (
|
||||
build_model_from_hybrid_plugin,
|
||||
check_grad,
|
||||
check_loss,
|
||||
check_output_hidden_state,
|
||||
check_weight,
|
||||
run_forward_backward_with_hybrid_plugin,
|
||||
)
|
||||
|
||||
|
||||
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
|
||||
# check forward
|
||||
org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
|
||||
output_transform_fn, loss_fn)
|
||||
assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values'])
|
||||
def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
|
||||
|
||||
# do backward
|
||||
org_loss.backward()
|
||||
shard_loss.backward()
|
||||
org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \
|
||||
build_model_from_hybrid_plugin(model_fn, loss_fn, test_config)
|
||||
|
||||
assert torch.allclose(org_loss, shard_loss,
|
||||
atol=1e-6), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
|
||||
org_loss, org_output, sharded_loss, sharded_output = \
|
||||
run_forward_backward_with_hybrid_plugin(
|
||||
org_model,
|
||||
sharded_model,
|
||||
sharded_optimizer,
|
||||
data_gen_fn,
|
||||
output_transform_fn,
|
||||
criterion,
|
||||
booster)
|
||||
|
||||
stage_manager = booster.plugin.stage_manager
|
||||
tp_group = booster.plugin.tp_group
|
||||
|
||||
# check last hidden state & loss
|
||||
if stage_manager is None or stage_manager.is_last_stage():
|
||||
|
||||
if org_model.__class__.__name__ == 'BloomModel':
|
||||
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=1e-5, rtol=1e-3)
|
||||
|
||||
check_loss(org_loss, sharded_loss, atol=1e-6, rtol=1e-3)
|
||||
|
||||
# unwrap model
|
||||
if org_model.__class__.__name__ == 'BloomModel':
|
||||
bloom = org_model
|
||||
sharded_bloom = sharded_model
|
||||
sharded_bloom = sharded_model.unwrap()
|
||||
else:
|
||||
bloom = org_model.transformer
|
||||
sharded_bloom = sharded_model.transformer
|
||||
sharded_bloom = sharded_model.unwrap().transformer
|
||||
|
||||
# check grad
|
||||
col_layer_for_check = ['h[0].self_attention.query_key_value']
|
||||
row_layer_for_check = ['h[0].self_attention.dense']
|
||||
check_grad(bloom, sharded_bloom, col_layer_for_check, atol=1e-6, rtol=1e-5, dim=0, verbose=False)
|
||||
check_grad(bloom, sharded_bloom, row_layer_for_check, atol=1e-6, rtol=1e-5, dim=1, verbose=False)
|
||||
row_layer_for_check = ['h[0].self_attention.query_key_value', 'word_embeddings']
|
||||
col_layer_for_check = ['h[0].self_attention.dense']
|
||||
if stage_manager is None or stage_manager.is_first_stage():
|
||||
check_grad(bloom, sharded_bloom, row_layer_for_check, tp_group, atol=1e-6, rtol=1e-5, dim=0, verbose=False)
|
||||
check_grad(bloom, sharded_bloom, col_layer_for_check, tp_group, atol=1e-6, rtol=1e-5, dim=1, verbose=False)
|
||||
|
||||
# check weights after optimizer.step()
|
||||
org_optimizer.step()
|
||||
sharded_optimizer.step()
|
||||
if stage_manager is None or stage_manager.is_first_stage():
|
||||
check_weight(bloom, sharded_bloom, col_layer_for_check, tp_group, atol=1e-4, rtol=1e-3, dim=1, verbose=False)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
@parameterize('enable_fused_normalization', [True, False])
|
||||
@parameterize('enable_tensor_parallelism', [True, False])
|
||||
@parameterize('enable_flash_attention', [True, False])
|
||||
@parameterize('enable_jit_fused', [True, False])
|
||||
@parameterize('use_lazy_init', [False, True])
|
||||
def run_bloom_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused,
|
||||
use_lazy_init):
|
||||
@parameterize('test_config', [{
|
||||
'tp_size': 2,
|
||||
'pp_size': 2,
|
||||
'num_microbatches': 4,
|
||||
'enable_fused_normalization': True,
|
||||
'use_lazy_init': True
|
||||
}, {
|
||||
'tp_size': 1,
|
||||
'pp_size': 2,
|
||||
'num_microbatches': 4,
|
||||
'enable_fused_normalization': False,
|
||||
'use_lazy_init': False
|
||||
}, {
|
||||
'tp_size': 4,
|
||||
'pp_size': 1,
|
||||
'enable_fused_normalization': True,
|
||||
'use_lazy_init': False
|
||||
}])
|
||||
def run_bloom_test(test_config):
|
||||
|
||||
# TODO: add test_config for TP+DP after supporting & debugging it
|
||||
# {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True}
|
||||
|
||||
# TODO: add test_config for flash attention & jit operator after supporting
|
||||
|
||||
sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom')
|
||||
test_config['precision'] = 'float' # Do not use fp16/bf16 in testing
|
||||
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
|
||||
enable_flash_attention, enable_jit_fused, use_lazy_init)
|
||||
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
|
||||
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
|
||||
|
||||
clear_layout_converter()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
@@ -67,7 +111,7 @@ def check_bloom(rank, world_size, port):
|
||||
@rerun_if_address_is_in_use()
|
||||
@clear_cache_before_run()
|
||||
def test_bloom():
|
||||
spawn(check_bloom, 2)
|
||||
spawn(check_bloom, 4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@@ -1,90 +0,0 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import colossalai
|
||||
from colossalai.cluster import ProcessGroupMesh
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer.policies.auto_policy import get_autopolicy
|
||||
from colossalai.shardformer.policies.base_policy import Policy
|
||||
from colossalai.shardformer.shard import ShardConfig
|
||||
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
|
||||
from colossalai.testing import (
|
||||
assert_hf_output_close,
|
||||
clear_cache_before_run,
|
||||
parameterize,
|
||||
rerun_if_address_is_in_use,
|
||||
spawn,
|
||||
)
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
from tests.test_shardformer.test_model._utils import build_model, build_pipeline_model, run_forward
|
||||
|
||||
|
||||
def check_bloom_model_policy(name, model: torch.nn.Module, stage_manager: PipelineStageManager):
|
||||
policy = get_autopolicy(model)
|
||||
policy.set_model(model)
|
||||
model_config = ShardConfig(pipeline_stage_manager=stage_manager, enable_tensor_parallelism=False)
|
||||
policy.set_shard_config(model_config)
|
||||
layers = policy.get_held_layers()
|
||||
if stage_manager.is_first_stage():
|
||||
assert len(layers) == 0 + 2
|
||||
else:
|
||||
if name == 'transformers_bloom':
|
||||
assert len(layers) == 1 + 1
|
||||
elif name == 'transformers_bloom_for_token_classification':
|
||||
assert len(layers) == 1 + 3
|
||||
else:
|
||||
assert len(layers) == 1 + 2
|
||||
|
||||
|
||||
def check_bloom_model_pipeline_forward(name, sharded_model, stage_manager: PipelineStageManager):
|
||||
if stage_manager.stage == 0:
|
||||
x = torch.randint(0, 1000, (1, 3)).cuda()
|
||||
attention_mask = torch.ones_like(x).cuda()
|
||||
output = sharded_model(input_ids=x, attention_mask=attention_mask)
|
||||
assert output['hidden_states'].shape == (1, 3, 64)
|
||||
else:
|
||||
attention_mask = torch.ones((1, 3)).cuda()
|
||||
hidden_states = torch.randint(0, 1000, (1, 3, 64)).to(torch.float32).cuda()
|
||||
output = sharded_model(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
assert output[0].shape[0] == 1
|
||||
|
||||
|
||||
@parameterize('enable_fused_normalization', [False])
|
||||
@parameterize('enable_tensor_parallelism', [False])
|
||||
@parameterize('use_lazy_init', [False])
|
||||
#TODO: merge this into test_shard_bloom
|
||||
def run_bloom_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
|
||||
PP_DIM = 0
|
||||
PP_SIZE = 2
|
||||
pg_mesh = ProcessGroupMesh(PP_SIZE)
|
||||
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
|
||||
|
||||
sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom')
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
|
||||
enable_tensor_parallelism, use_lazy_init)
|
||||
check_bloom_model_policy(name, org_model, stage_manager)
|
||||
check_bloom_model_pipeline_forward(name, sharded_model, stage_manager)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def check_bloom(rank, world_size, port):
|
||||
disable_existing_loggers()
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_bloom_test()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
@clear_cache_before_run()
|
||||
def test_bloom():
|
||||
spawn(check_bloom, 2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_bloom()
|
@@ -1,99 +1,126 @@
|
||||
import copy
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch import distributed as dist
|
||||
|
||||
import colossalai
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.shardformer import ShardConfig, ShardFormer
|
||||
from colossalai.shardformer.policies.chatglm import ChatGLMForConditionalGenerationPolicy, ChatGLMModelPolicy
|
||||
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
|
||||
from colossalai.testing import (
|
||||
assert_hf_output_close,
|
||||
clear_cache_before_run,
|
||||
parameterize,
|
||||
rerun_if_address_is_in_use,
|
||||
spawn,
|
||||
)
|
||||
from colossalai.tensor.d_tensor.api import clear_layout_converter
|
||||
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
from tests.test_shardformer.test_model._utils import build_model, run_forward
|
||||
from tests.test_shardformer.test_model._utils import (
|
||||
build_model_from_hybrid_plugin,
|
||||
check_grad,
|
||||
check_loss,
|
||||
check_output_hidden_state,
|
||||
check_weight,
|
||||
run_forward_backward_with_hybrid_plugin,
|
||||
)
|
||||
|
||||
|
||||
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
|
||||
# check forward
|
||||
org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
|
||||
output_transform_fn, loss_fn)
|
||||
assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values'])
|
||||
# do backward
|
||||
org_loss.backward()
|
||||
shard_loss.backward()
|
||||
def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
|
||||
|
||||
assert torch.allclose(org_loss, shard_loss,
|
||||
atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
|
||||
org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \
|
||||
build_model_from_hybrid_plugin(model_fn, loss_fn, test_config)
|
||||
|
||||
org_loss, org_output, sharded_loss, sharded_output = \
|
||||
run_forward_backward_with_hybrid_plugin(
|
||||
org_model,
|
||||
sharded_model,
|
||||
sharded_optimizer,
|
||||
data_gen_fn,
|
||||
output_transform_fn,
|
||||
criterion,
|
||||
booster)
|
||||
|
||||
stage_manager = booster.plugin.stage_manager
|
||||
tp_group = booster.plugin.tp_group
|
||||
|
||||
# check last hidden state & loss
|
||||
if stage_manager is None or stage_manager.is_last_stage():
|
||||
|
||||
if org_model.__class__.__name__ == 'ChatGLMModel':
|
||||
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=1e-5, rtol=1e-3, dim=1)
|
||||
|
||||
check_loss(org_loss, sharded_loss, atol=1e-6, rtol=1e-3)
|
||||
|
||||
# unwrap model
|
||||
if org_model.__class__.__name__ == 'ChatGLMModel':
|
||||
chatglm_model = org_model
|
||||
shard_chatglm_model = sharded_model
|
||||
shard_chatglm_model = sharded_model.unwrap()
|
||||
else:
|
||||
chatglm_model = org_model.transformer
|
||||
shard_chatglm_model = sharded_model.transformer
|
||||
shard_chatglm_model = sharded_model.unwrap().transformer
|
||||
|
||||
# check attention grad
|
||||
org_grad = chatglm_model.encoder.layers[0].self_attention.query_key_value.weight.grad
|
||||
shard_grad = shard_chatglm_model.encoder.layers[0].self_attention.query_key_value.weight.grad
|
||||
shard_weight = shard_chatglm_model.encoder.layers[0].self_attention.query_key_value.weight
|
||||
# check grad
|
||||
row_layer_for_check = ['encoder.layers[0].self_attention.query_key_value', 'embedding.word_embeddings']
|
||||
col_layer_for_check = ['encoder.layers[0].self_attention.dense']
|
||||
if stage_manager is None or stage_manager.is_first_stage():
|
||||
check_grad(chatglm_model,
|
||||
shard_chatglm_model,
|
||||
row_layer_for_check,
|
||||
tp_group,
|
||||
atol=1e-6,
|
||||
rtol=1e-3,
|
||||
dim=0,
|
||||
verbose=False)
|
||||
|
||||
if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
|
||||
shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
|
||||
shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
|
||||
all_shard_grad = torch.cat(shard_grad_list, dim=0)
|
||||
else:
|
||||
all_shard_grad = shard_grad
|
||||
assert torch.allclose(org_grad, all_shard_grad,
|
||||
atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{shard_grad}"
|
||||
check_grad(chatglm_model,
|
||||
shard_chatglm_model,
|
||||
col_layer_for_check,
|
||||
tp_group,
|
||||
atol=1e-6,
|
||||
rtol=1e-3,
|
||||
dim=1,
|
||||
verbose=False)
|
||||
|
||||
# check embedding weights
|
||||
org_grad = chatglm_model.embedding.word_embeddings.weight.grad
|
||||
shard_grad = shard_chatglm_model.embedding.word_embeddings.weight.grad
|
||||
shard_weight = shard_chatglm_model.embedding.word_embeddings.weight
|
||||
# check weights after optimizer.step()
|
||||
org_optimizer.step()
|
||||
sharded_optimizer.step()
|
||||
if stage_manager is None or stage_manager.is_first_stage():
|
||||
check_weight(chatglm_model,
|
||||
shard_chatglm_model,
|
||||
col_layer_for_check,
|
||||
tp_group,
|
||||
atol=1e-4,
|
||||
rtol=1e-3,
|
||||
dim=1,
|
||||
verbose=False)
|
||||
|
||||
if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
|
||||
shard_grad_list = [torch.zeros_like(shard_grad) for _ in range(2)]
|
||||
torch.distributed.all_gather(shard_grad_list, shard_grad)
|
||||
all_shard_grad = torch.cat(shard_grad_list, dim=0)
|
||||
else:
|
||||
all_shard_grad = shard_grad
|
||||
|
||||
assert torch.allclose(org_grad, all_shard_grad,
|
||||
atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}"
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
@parameterize('enable_fused_normalization', [True, False])
|
||||
@parameterize('enable_tensor_parallelism', [True, False])
|
||||
@parameterize('enable_flash_attention', [True, False])
|
||||
@parameterize('enable_jit_fused', [True, False])
|
||||
def run_chatglm_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused):
|
||||
@parameterize('test_config', [{
|
||||
'tp_size': 2,
|
||||
'pp_size': 2,
|
||||
'num_microbatches': 4,
|
||||
'enable_fused_normalization': True,
|
||||
'use_lazy_init': True
|
||||
}, {
|
||||
'tp_size': 1,
|
||||
'pp_size': 2,
|
||||
'num_microbatches': 4,
|
||||
'enable_fused_normalization': False,
|
||||
'use_lazy_init': False
|
||||
}, {
|
||||
'tp_size': 4,
|
||||
'pp_size': 1,
|
||||
'enable_fused_normalization': True,
|
||||
'use_lazy_init': False
|
||||
}])
|
||||
def run_chatglm_test(test_config):
|
||||
|
||||
# TODO: add test_config for TP+DP after supporting & debugging it
|
||||
# {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True}
|
||||
|
||||
# TODO: add test_config for flash attention & jit operator after supporting
|
||||
|
||||
sub_model_zoo = model_zoo.get_sub_registry('transformers_chatglm')
|
||||
test_config['precision'] = 'float' # Do not use fp16/bf16 in testing
|
||||
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
# create new model
|
||||
org_model = model_fn().cuda()
|
||||
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
|
||||
|
||||
# shard model
|
||||
shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
|
||||
enable_tensor_parallelism=enable_tensor_parallelism,
|
||||
enable_flash_attention=enable_flash_attention,
|
||||
enable_jit_fused=enable_jit_fused)
|
||||
model_copy = copy.deepcopy(org_model)
|
||||
shard_former = ShardFormer(shard_config=shard_config)
|
||||
if name == "transformers_chatglm":
|
||||
sharded_model, _ = shard_former.optimize(model_copy, ChatGLMModelPolicy())
|
||||
else:
|
||||
sharded_model, _ = shard_former.optimize(model_copy, ChatGLMForConditionalGenerationPolicy())
|
||||
sharded_model = sharded_model.cuda()
|
||||
|
||||
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
|
||||
clear_layout_converter()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
@@ -107,7 +134,7 @@ def check_chatglm(rank, world_size, port):
|
||||
@rerun_if_address_is_in_use()
|
||||
@clear_cache_before_run()
|
||||
def test_chatglm():
|
||||
spawn(check_chatglm, 2)
|
||||
spawn(check_chatglm, 4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@@ -1,86 +0,0 @@
|
||||
import copy
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import colossalai
|
||||
from colossalai.cluster import ProcessGroupMesh
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer.policies.chatglm import ChatGLMForConditionalGenerationPolicy, ChatGLMModelPolicy
|
||||
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
|
||||
from colossalai.testing import (
|
||||
assert_hf_output_close,
|
||||
clear_cache_before_run,
|
||||
parameterize,
|
||||
rerun_if_address_is_in_use,
|
||||
spawn,
|
||||
)
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
from tests.test_shardformer.test_model._utils import build_model, build_pipeline_model, run_forward
|
||||
|
||||
|
||||
@parameterize('enable_fused_normalization', [False])
|
||||
@parameterize('enable_tensor_parallelism', [False])
|
||||
@parameterize('use_lazy_init', [False])
|
||||
def run_chatglm_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
|
||||
DP_DIM, PP_DIM = 0, 1
|
||||
DP_SIZE, PP_SIZE = 2, 2
|
||||
pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
|
||||
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
|
||||
sub_model_zoo = model_zoo.get_sub_registry('transformers_chatglm')
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
# create new model for test
|
||||
inputs = data_gen_fn()
|
||||
inputs = {k: v.cuda() for k, v in inputs.items()}
|
||||
input_ids = inputs['input_ids']
|
||||
hidden_size = 64
|
||||
batch_size, seq_len = input_ids.shape
|
||||
hidden_state_shape = (seq_len, batch_size, hidden_size)
|
||||
if name == "transformers_chatglm":
|
||||
_, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
|
||||
enable_tensor_parallelism, use_lazy_init, ChatGLMModelPolicy())
|
||||
if stage_manager.is_last_stage():
|
||||
hidden_states = torch.randn(*hidden_state_shape).cuda()
|
||||
inputs['input_ids'] = None
|
||||
inputs['hidden_states'] = hidden_states
|
||||
outputs = sharded_model(**inputs)
|
||||
if stage_manager.is_last_stage():
|
||||
assert outputs[0].shape == hidden_state_shape
|
||||
|
||||
else:
|
||||
assert outputs['hidden_states'].shape == hidden_state_shape
|
||||
|
||||
if name == "transformers_chatglm_for_conditional_generation":
|
||||
_, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
|
||||
enable_tensor_parallelism, use_lazy_init,
|
||||
ChatGLMForConditionalGenerationPolicy())
|
||||
if stage_manager.is_last_stage():
|
||||
hidden_states = torch.randn(*hidden_state_shape).cuda()
|
||||
inputs['input_ids'] = None
|
||||
inputs['hidden_states'] = hidden_states
|
||||
outputs = sharded_model(**inputs)
|
||||
if stage_manager.is_last_stage():
|
||||
assert outputs[0].shape == (batch_size, seq_len, 65024)
|
||||
else:
|
||||
assert outputs['hidden_states'].shape == hidden_state_shape
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def check_chatglm(rank, world_size, port):
|
||||
disable_existing_loggers()
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_chatglm_test()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
@clear_cache_before_run()
|
||||
def test_chatglm():
|
||||
spawn(check_chatglm, 4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_chatglm()
|
@@ -2,69 +2,139 @@ import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch import distributed as dist
|
||||
|
||||
import colossalai
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.testing import (
|
||||
assert_hf_output_close,
|
||||
clear_cache_before_run,
|
||||
parameterize,
|
||||
rerun_if_address_is_in_use,
|
||||
spawn,
|
||||
)
|
||||
from colossalai.tensor.d_tensor.api import clear_layout_converter
|
||||
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
from tests.test_shardformer.test_model._utils import build_model, check_grad, check_state_dict, run_forward
|
||||
from tests.test_shardformer.test_model._utils import (
|
||||
build_model_from_hybrid_plugin,
|
||||
check_grad,
|
||||
check_loss,
|
||||
check_output_hidden_state,
|
||||
check_weight,
|
||||
run_forward_backward_with_hybrid_plugin,
|
||||
)
|
||||
|
||||
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
|
||||
|
||||
|
||||
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
|
||||
org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
|
||||
output_transform_fn, loss_fn)
|
||||
def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
|
||||
|
||||
# forward check
|
||||
assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values'], rtol=1e-5)
|
||||
org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \
|
||||
build_model_from_hybrid_plugin(model_fn, loss_fn, test_config)
|
||||
|
||||
# run backward
|
||||
org_loss.backward()
|
||||
shard_loss.backward()
|
||||
org_loss, org_output, sharded_loss, sharded_output = \
|
||||
run_forward_backward_with_hybrid_plugin(
|
||||
org_model,
|
||||
sharded_model,
|
||||
sharded_optimizer,
|
||||
data_gen_fn,
|
||||
output_transform_fn,
|
||||
criterion,
|
||||
booster)
|
||||
|
||||
assert torch.allclose(org_loss, shard_loss,
|
||||
atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
|
||||
stage_manager = booster.plugin.stage_manager
|
||||
tp_group = booster.plugin.tp_group
|
||||
|
||||
# check last hidden state & loss
|
||||
if stage_manager is None or stage_manager.is_last_stage():
|
||||
|
||||
if org_model.__class__.__name__ == 'LlamaModel':
|
||||
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=1e-5, rtol=1e-3)
|
||||
|
||||
check_loss(org_loss, sharded_loss, atol=1e-6, rtol=1e-3)
|
||||
|
||||
# unwrap model
|
||||
if hasattr(org_model, 'model'):
|
||||
llama_model = org_model.model
|
||||
shard_llama_model = sharded_model.model
|
||||
else:
|
||||
if org_model.__class__.__name__ == 'LlamaModel':
|
||||
llama_model = org_model
|
||||
shard_llama_model = sharded_model
|
||||
shard_llama_model = sharded_model.unwrap()
|
||||
else:
|
||||
llama_model = org_model.model
|
||||
shard_llama_model = sharded_model.unwrap().model
|
||||
|
||||
# check grad
|
||||
col_layer_for_check = ['layers[0].self_attn.q_proj', 'embed_tokens']
|
||||
row_layer_for_check = ['layers[0].self_attn.o_proj']
|
||||
check_grad(llama_model, shard_llama_model, col_layer_for_check, atol=1e-6, rtol=1e-4, dim=0, verbose=False)
|
||||
check_grad(llama_model, shard_llama_model, row_layer_for_check, atol=1e-6, rtol=1e-4, dim=1, verbose=False)
|
||||
row_layer_for_check = ['layers[0].self_attn.q_proj', 'embed_tokens']
|
||||
col_layer_for_check = ['layers[0].self_attn.o_proj']
|
||||
if stage_manager is None or stage_manager.is_first_stage():
|
||||
check_grad(llama_model,
|
||||
shard_llama_model,
|
||||
row_layer_for_check,
|
||||
tp_group,
|
||||
atol=1e-6,
|
||||
rtol=1e-4,
|
||||
dim=0,
|
||||
verbose=False)
|
||||
check_grad(llama_model,
|
||||
shard_llama_model,
|
||||
col_layer_for_check,
|
||||
tp_group,
|
||||
atol=1e-6,
|
||||
rtol=1e-4,
|
||||
dim=1,
|
||||
verbose=False)
|
||||
|
||||
# check weights after optimizer.step()
|
||||
org_optimizer.step()
|
||||
sharded_optimizer.step()
|
||||
if stage_manager is None or stage_manager.is_first_stage():
|
||||
check_weight(llama_model,
|
||||
shard_llama_model,
|
||||
col_layer_for_check,
|
||||
tp_group,
|
||||
atol=1e-4,
|
||||
rtol=1e-3,
|
||||
dim=1,
|
||||
verbose=False)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
@parameterize('enable_fused_normalization', [True, False])
|
||||
@parameterize('enable_tensor_parallelism', [True, False])
|
||||
@parameterize('enable_flash_attention', [True, False])
|
||||
@parameterize('use_lazy_init', [False, True])
|
||||
def run_gpt2_llama(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, use_lazy_init):
|
||||
@parameterize('test_config', [{
|
||||
'tp_size': 2,
|
||||
'pp_size': 2,
|
||||
'num_microbatches': 2,
|
||||
'enable_fused_normalization': True,
|
||||
'use_lazy_init': True
|
||||
}, {
|
||||
'tp_size': 1,
|
||||
'pp_size': 2,
|
||||
'num_microbatches': 4,
|
||||
'use_lazy_init': False
|
||||
}, {
|
||||
'tp_size': 4,
|
||||
'pp_size': 1,
|
||||
'enable_fused_normalization': True,
|
||||
'use_lazy_init': False
|
||||
}, {
|
||||
'tp_size': 1,
|
||||
'pp_size': 4,
|
||||
'num_microbatches': 4,
|
||||
'use_lazy_init': False
|
||||
}])
|
||||
def run_llama_test(test_config):
|
||||
|
||||
# TODO: add test_config for TP+DP after supporting & debugging it
|
||||
# {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True}
|
||||
|
||||
# TODO: add test_config for flash attention & jit operator after supporting
|
||||
|
||||
sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')
|
||||
test_config['precision'] = 'float' # Do not use fp16/bf16 in testing
|
||||
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
|
||||
enable_flash_attention, use_lazy_init)
|
||||
check_state_dict(org_model, sharded_model, name=name)
|
||||
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
|
||||
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
|
||||
|
||||
clear_layout_converter()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def check_llama(rank, world_size, port):
|
||||
disable_existing_loggers()
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_gpt2_llama()
|
||||
run_llama_test()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
|
@@ -1,89 +0,0 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import colossalai
|
||||
from colossalai.cluster import ProcessGroupMesh
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer.policies.auto_policy import get_autopolicy
|
||||
from colossalai.shardformer.policies.base_policy import Policy
|
||||
from colossalai.shardformer.shard import ShardConfig
|
||||
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
|
||||
from colossalai.testing import (
|
||||
assert_hf_output_close,
|
||||
clear_cache_before_run,
|
||||
parameterize,
|
||||
rerun_if_address_is_in_use,
|
||||
spawn,
|
||||
)
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
from tests.test_shardformer.test_model._utils import build_model, build_pipeline_model, run_forward
|
||||
|
||||
|
||||
def check_llama_model_policy(name, model: torch.nn.Module, stage_manager: PipelineStageManager):
|
||||
policy = get_autopolicy(model)
|
||||
policy.set_model(model)
|
||||
model_config = ShardConfig(pipeline_stage_manager=stage_manager, enable_tensor_parallelism=False)
|
||||
policy.set_shard_config(model_config)
|
||||
layers = policy.get_held_layers()
|
||||
if stage_manager.is_first_stage():
|
||||
assert len(layers) == 2 + 1
|
||||
else:
|
||||
if name == "transformers_llama":
|
||||
assert len(layers) == 2 + 1
|
||||
else:
|
||||
assert len(layers) == 2 + 2
|
||||
|
||||
|
||||
def check_llama_model_pipeline_forward(name, sharded_model, stage_manager: PipelineStageManager):
|
||||
x = torch.randint(0, 1000, (2, 3)).cuda()
|
||||
if stage_manager.stage == 0:
|
||||
attention_mask = torch.ones_like(x).cuda()
|
||||
output = sharded_model(input_ids=x, attention_mask=attention_mask)
|
||||
assert output['hidden_states'].shape == (2, 3, 128)
|
||||
else:
|
||||
hidden_states = torch.randint(0, 1000, (2, 3, 128)).to(torch.float32).cuda()
|
||||
attention_mask = torch.ones((2, 3)).cuda()
|
||||
output = sharded_model(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
assert output[0] is not None
|
||||
|
||||
|
||||
@parameterize('enable_fused_normalization', [False])
|
||||
@parameterize('enable_tensor_parallelism', [False])
|
||||
@parameterize('use_lazy_init', [False])
|
||||
#TODO: merge this into test_shard_llama
|
||||
def run_llama_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
|
||||
PP_DIM = 0
|
||||
PP_SIZE = 2
|
||||
pg_mesh = ProcessGroupMesh(PP_SIZE)
|
||||
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
|
||||
|
||||
sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')
|
||||
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
|
||||
enable_tensor_parallelism, use_lazy_init)
|
||||
check_llama_model_policy(name, org_model, stage_manager)
|
||||
check_llama_model_pipeline_forward(name, sharded_model, stage_manager)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def check_llama(rank, world_size, port):
|
||||
disable_existing_loggers()
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_llama_test()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
@clear_cache_before_run()
|
||||
def test_llama():
|
||||
spawn(check_llama, 2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_llama()
|
@@ -1,64 +1,129 @@
|
||||
import copy
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch import distributed as dist
|
||||
|
||||
import colossalai
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.testing import (
|
||||
assert_hf_output_close,
|
||||
clear_cache_before_run,
|
||||
parameterize,
|
||||
rerun_if_address_is_in_use,
|
||||
spawn,
|
||||
)
|
||||
from colossalai.tensor.d_tensor.api import clear_layout_converter
|
||||
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
from tests.test_shardformer.test_model._utils import build_model, check_grad, check_state_dict, run_forward
|
||||
from tests.test_shardformer.test_model._utils import (
|
||||
build_model_from_hybrid_plugin,
|
||||
check_grad,
|
||||
check_loss,
|
||||
check_output_hidden_state,
|
||||
check_weight,
|
||||
run_forward_backward_with_hybrid_plugin,
|
||||
)
|
||||
|
||||
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
|
||||
|
||||
|
||||
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
|
||||
org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
|
||||
output_transform_fn, loss_fn)
|
||||
assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values'], rtol=1e-5)
|
||||
def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
|
||||
|
||||
# run backward
|
||||
org_loss.backward()
|
||||
shard_loss.backward()
|
||||
org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \
|
||||
build_model_from_hybrid_plugin(model_fn, loss_fn, test_config)
|
||||
|
||||
assert torch.allclose(org_loss, shard_loss,
|
||||
atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
|
||||
org_loss, org_output, sharded_loss, sharded_output = \
|
||||
run_forward_backward_with_hybrid_plugin(
|
||||
org_model,
|
||||
sharded_model,
|
||||
sharded_optimizer,
|
||||
data_gen_fn,
|
||||
output_transform_fn,
|
||||
criterion,
|
||||
booster)
|
||||
|
||||
stage_manager = booster.plugin.stage_manager
|
||||
tp_group = booster.plugin.tp_group
|
||||
|
||||
# check last hidden state & loss
|
||||
if stage_manager is None or stage_manager.is_last_stage():
|
||||
|
||||
if org_model.__class__.__name__ == 'OPTModel':
|
||||
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=1e-5, rtol=1e-3)
|
||||
|
||||
check_loss(org_loss, sharded_loss, atol=1e-6, rtol=1e-3)
|
||||
|
||||
# unwrap model
|
||||
if hasattr(org_model, 'model'):
|
||||
opt_model = org_model.model
|
||||
shard_opt_model = sharded_model.model
|
||||
else:
|
||||
if org_model.__class__.__name__ == 'OPTModel':
|
||||
opt_model = org_model
|
||||
shard_opt_model = sharded_model
|
||||
shard_opt_model = sharded_model.unwrap()
|
||||
else:
|
||||
opt_model = org_model.model
|
||||
shard_opt_model = sharded_model.unwrap().model
|
||||
|
||||
# check grad
|
||||
col_layer_for_check = ['decoder.layers[0].self_attn.q_proj', 'decoder.embed_tokens']
|
||||
row_layer_for_check = ['decoder.layers[0].self_attn.out_proj']
|
||||
check_grad(opt_model, shard_opt_model, col_layer_for_check, atol=1e-6, rtol=1e-3, dim=0, verbose=False)
|
||||
check_grad(opt_model, shard_opt_model, row_layer_for_check, atol=1e-6, rtol=1e-3, dim=1, verbose=False)
|
||||
row_layer_for_check = ['decoder.layers[0].self_attn.q_proj', 'decoder.embed_tokens']
|
||||
col_layer_for_check = ['decoder.layers[0].self_attn.out_proj']
|
||||
if stage_manager is None or stage_manager.is_first_stage():
|
||||
check_grad(opt_model,
|
||||
shard_opt_model,
|
||||
row_layer_for_check,
|
||||
tp_group,
|
||||
atol=1e-6,
|
||||
rtol=1e-3,
|
||||
dim=0,
|
||||
verbose=False)
|
||||
check_grad(opt_model,
|
||||
shard_opt_model,
|
||||
col_layer_for_check,
|
||||
tp_group,
|
||||
atol=1e-6,
|
||||
rtol=1e-3,
|
||||
dim=1,
|
||||
verbose=False)
|
||||
|
||||
# check weights after optimizer.step()
|
||||
org_optimizer.step()
|
||||
sharded_optimizer.step()
|
||||
if stage_manager is None or stage_manager.is_first_stage():
|
||||
check_weight(opt_model,
|
||||
shard_opt_model,
|
||||
col_layer_for_check,
|
||||
tp_group,
|
||||
atol=1e-3,
|
||||
rtol=1e-3,
|
||||
dim=1,
|
||||
verbose=False)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
@parameterize('use_lazy_init', [False, True])
|
||||
@parameterize('enable_fused_normalization', [True, False])
|
||||
@parameterize('enable_tensor_parallelism', [True, False])
|
||||
@parameterize('enable_flash_attention', [True, False])
|
||||
@parameterize('enable_jit_fused', [True, False])
|
||||
def run_opt_test(use_lazy_init, enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention,
|
||||
enable_jit_fused):
|
||||
@parameterize('test_config', [{
|
||||
'tp_size': 2,
|
||||
'pp_size': 2,
|
||||
'num_microbatches': 4,
|
||||
'enable_fused_normalization': True,
|
||||
'use_lazy_init': True
|
||||
}, {
|
||||
'tp_size': 1,
|
||||
'pp_size': 2,
|
||||
'num_microbatches': 4,
|
||||
'enable_fused_normalization': False,
|
||||
'use_lazy_init': False
|
||||
}, {
|
||||
'tp_size': 4,
|
||||
'pp_size': 1,
|
||||
'enable_fused_normalization': True,
|
||||
'use_lazy_init': False
|
||||
}])
|
||||
def run_opt_test(test_config):
|
||||
|
||||
# TODO: add test_config for TP+DP after supporting & debugging it
|
||||
# {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True}
|
||||
|
||||
# TODO: add test_config for flash attention & jit operator after supporting
|
||||
|
||||
sub_model_zoo = model_zoo.get_sub_registry('transformers_opt')
|
||||
test_config['precision'] = 'float' # Do not use fp16/bf16 in testing
|
||||
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
|
||||
enable_flash_attention, enable_jit_fused, use_lazy_init)
|
||||
check_state_dict(org_model, sharded_model, name=name)
|
||||
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
|
||||
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
|
||||
|
||||
clear_layout_converter()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
|
@@ -1,70 +0,0 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import colossalai
|
||||
from colossalai.cluster import ProcessGroupMesh
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
from tests.test_shardformer.test_model._utils import build_pipeline_model
|
||||
|
||||
|
||||
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
|
||||
# TODO: add tests for forward/backward later
|
||||
pass
|
||||
|
||||
|
||||
@parameterize('enable_tensor_parallelism', [False])
|
||||
@parameterize('enable_fused_normalization', [False])
|
||||
@parameterize('use_lazy_init', [False])
|
||||
#TODO: merge this into test_shard_opt
|
||||
def run_opt_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
|
||||
DP_DIM, PP_DIM = 0, 1
|
||||
DP_SIZE, PP_SIZE = 2, 2
|
||||
pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
|
||||
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
|
||||
|
||||
sub_model_zoo = model_zoo.get_sub_registry('transformers_opt')
|
||||
for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items():
|
||||
inputs = data_gen_fn()
|
||||
inputs = {k: v.cuda() for k, v in inputs.items()}
|
||||
input_ids, _ = inputs['input_ids'], inputs['attention_mask']
|
||||
batch_size, seq_len = input_ids.shape
|
||||
hidden_size = 128
|
||||
hidden_state_shape = (batch_size, seq_len, hidden_size)
|
||||
|
||||
if not stage_manager.is_first_stage():
|
||||
# change inputs if not the first stage
|
||||
|
||||
hidden_states = torch.zeros(*hidden_state_shape).cuda()
|
||||
inputs['input_ids'] = None
|
||||
inputs['hidden_states'] = hidden_states
|
||||
|
||||
_, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
|
||||
enable_tensor_parallelism, use_lazy_init)
|
||||
sharded_model.train()
|
||||
|
||||
output = sharded_model(**inputs)
|
||||
if stage_manager.is_last_stage():
|
||||
assert output[0] is not None
|
||||
else:
|
||||
assert output['hidden_states'].shape == hidden_state_shape
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def check_opt(rank, world_size, port):
|
||||
disable_existing_loggers()
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_opt_test()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
@clear_cache_before_run()
|
||||
def test_opt():
|
||||
spawn(check_opt, 4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_opt()
|
@@ -1,60 +1,127 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import colossalai
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.testing import (
|
||||
assert_hf_output_close,
|
||||
clear_cache_before_run,
|
||||
parameterize,
|
||||
rerun_if_address_is_in_use,
|
||||
spawn,
|
||||
)
|
||||
from colossalai.shardformer.layer.utils import Randomizer
|
||||
from colossalai.tensor.d_tensor.api import clear_layout_converter
|
||||
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
from tests.test_shardformer.test_model._utils import build_model, check_grad, run_forward
|
||||
from tests.test_shardformer.test_model._utils import (
|
||||
build_model_from_hybrid_plugin,
|
||||
check_grad,
|
||||
check_loss,
|
||||
check_output_hidden_state,
|
||||
check_weight,
|
||||
run_forward_backward_with_hybrid_plugin,
|
||||
)
|
||||
|
||||
|
||||
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
|
||||
# check forward
|
||||
org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
|
||||
output_transform_fn, loss_fn)
|
||||
def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
|
||||
|
||||
assert_hf_output_close(org_output, shard_output, atol=1e-3, rtol=1e-3)
|
||||
org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \
|
||||
build_model_from_hybrid_plugin(model_fn, loss_fn, test_config)
|
||||
|
||||
# do backward
|
||||
org_loss.backward()
|
||||
shard_loss.backward()
|
||||
org_loss, org_output, sharded_loss, sharded_output = \
|
||||
run_forward_backward_with_hybrid_plugin(
|
||||
org_model,
|
||||
sharded_model,
|
||||
sharded_optimizer,
|
||||
data_gen_fn,
|
||||
output_transform_fn,
|
||||
criterion,
|
||||
booster)
|
||||
|
||||
assert torch.allclose(org_loss, shard_loss,
|
||||
atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
|
||||
stage_manager = booster.plugin.stage_manager
|
||||
tp_group = booster.plugin.tp_group
|
||||
|
||||
# check last hidden state & loss
|
||||
if stage_manager is None or stage_manager.is_last_stage():
|
||||
|
||||
if org_model.__class__.__name__ == 'ViTModel':
|
||||
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=1e-5, rtol=1e-3)
|
||||
|
||||
check_loss(org_loss, sharded_loss, atol=1e-6, rtol=1e-3)
|
||||
|
||||
# unwrap model
|
||||
if org_model.__class__.__name__ == 'ViTModel':
|
||||
vit_model = org_model
|
||||
shard_vit_model = sharded_model
|
||||
shard_vit_model = sharded_model.unwrap()
|
||||
else:
|
||||
vit_model = org_model.vit
|
||||
shard_vit_model = sharded_model.vit
|
||||
shard_vit_model = sharded_model.unwrap().vit
|
||||
|
||||
# check grad
|
||||
col_layer_for_check = ['encoder.layer[0].attention.attention.query']
|
||||
row_layer_for_check = ['encoder.layer[0].attention.output.dense']
|
||||
check_grad(vit_model, shard_vit_model, col_layer_for_check, atol=1e-5, rtol=1e-3, dim=0, verbose=False)
|
||||
check_grad(vit_model, shard_vit_model, row_layer_for_check, atol=1e-5, rtol=1e-3, dim=1, verbose=False)
|
||||
row_layer_for_check = ['encoder.layer[0].attention.attention.query', 'embeddings.patch_embeddings.projection']
|
||||
col_layer_for_check = ['encoder.layer[0].attention.output.dense']
|
||||
if stage_manager is None or stage_manager.is_first_stage():
|
||||
check_grad(vit_model,
|
||||
shard_vit_model,
|
||||
row_layer_for_check,
|
||||
tp_group,
|
||||
atol=1e-5,
|
||||
rtol=1e-3,
|
||||
dim=0,
|
||||
verbose=False)
|
||||
check_grad(vit_model,
|
||||
shard_vit_model,
|
||||
col_layer_for_check,
|
||||
tp_group,
|
||||
atol=1e-5,
|
||||
rtol=1e-3,
|
||||
dim=1,
|
||||
verbose=False)
|
||||
|
||||
# check weights after optimizer.step()
|
||||
org_optimizer.step()
|
||||
sharded_optimizer.step()
|
||||
if stage_manager is None or stage_manager.is_first_stage():
|
||||
check_weight(vit_model,
|
||||
shard_vit_model,
|
||||
col_layer_for_check,
|
||||
tp_group,
|
||||
atol=5e-3,
|
||||
rtol=1e-3,
|
||||
dim=1,
|
||||
verbose=False)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
@parameterize('enable_fused_normalization', [True, False])
|
||||
@parameterize('enable_tensor_parallelism', [True, False])
|
||||
@parameterize('enable_flash_attention', [True, False])
|
||||
@parameterize('enable_jit_fused', [True, False])
|
||||
def run_vit_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused):
|
||||
@parameterize('test_config', [{
|
||||
'tp_size': 2,
|
||||
'pp_size': 2,
|
||||
'num_microbatches': 4,
|
||||
'enable_fused_normalization': True,
|
||||
'use_lazy_init': False
|
||||
}, {
|
||||
'tp_size': 1,
|
||||
'pp_size': 2,
|
||||
'num_microbatches': 4,
|
||||
'enable_fused_normalization': False,
|
||||
'use_lazy_init': False
|
||||
}, {
|
||||
'tp_size': 4,
|
||||
'pp_size': 1,
|
||||
'enable_fused_normalization': True,
|
||||
'use_lazy_init': False
|
||||
}])
|
||||
def run_vit_test(test_config):
|
||||
|
||||
# TODO: add test_config for TP+DP after supporting & debugging it
|
||||
# {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True}
|
||||
|
||||
# TODO: add test_config for flash attention & jit operator after supporting
|
||||
# TODO: fix bug when settign lazy_init for Conv2D Layers in ViT models
|
||||
|
||||
sub_model_zoo = model_zoo.get_sub_registry('transformers_vit')
|
||||
test_config['precision'] = 'float' # Do not use fp16/bf16 in testing
|
||||
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
|
||||
enable_flash_attention, enable_jit_fused)
|
||||
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
|
||||
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
|
||||
|
||||
clear_layout_converter()
|
||||
Randomizer.reset_index()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
@@ -68,7 +135,7 @@ def check_vit(rank, world_size, port):
|
||||
@rerun_if_address_is_in_use()
|
||||
@clear_cache_before_run()
|
||||
def test_vit():
|
||||
spawn(check_vit, 2)
|
||||
spawn(check_vit, 4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@@ -1,74 +0,0 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import colossalai
|
||||
from colossalai.cluster import ProcessGroupMesh
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
from tests.test_shardformer.test_model._utils import build_pipeline_model
|
||||
|
||||
|
||||
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
|
||||
# TODO: add tests for forward/backward later
|
||||
pass
|
||||
|
||||
|
||||
@parameterize('enable_tensor_parallelism', [False])
|
||||
@parameterize('enable_fused_normalization', [False])
|
||||
@parameterize('use_lazy_init', [False])
|
||||
#TODO: merge this into test_shard_vit
|
||||
def run_vit_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
|
||||
DP_DIM, PP_DIM = 0, 1
|
||||
DP_SIZE, PP_SIZE = 2, 2
|
||||
pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
|
||||
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
|
||||
|
||||
sub_model_zoo = model_zoo.get_sub_registry('transformers_vit')
|
||||
|
||||
for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items():
|
||||
|
||||
inputs = data_gen_fn()
|
||||
inputs = {k: v.cuda() for k, v in inputs.items()}
|
||||
pixel_values = inputs['pixel_values']
|
||||
batch_size = len(pixel_values)
|
||||
hidden_size = 768
|
||||
hidden_state_shape = (batch_size, 197, hidden_size)
|
||||
|
||||
if not stage_manager.is_first_stage():
|
||||
# change inputs if not the first stage
|
||||
hidden_states = torch.randn(*hidden_state_shape).cuda()
|
||||
# inputs['pixel_values'] = None
|
||||
inputs['hidden_states'] = hidden_states
|
||||
|
||||
_, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
|
||||
enable_tensor_parallelism, use_lazy_init)
|
||||
sharded_model.train()
|
||||
|
||||
output = sharded_model(**inputs)
|
||||
if stage_manager.is_last_stage():
|
||||
if name != 'transformers_vit':
|
||||
assert output.loss is not None
|
||||
else:
|
||||
assert output['hidden_states'].shape == hidden_state_shape, \
|
||||
f'hidden_states shape is not correct, output:{output["hidden_states"].shape} is not equal to hidden_state:{hidden_state_shape}'
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def check_vit(rank, world_size, port):
|
||||
disable_existing_loggers()
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_vit_test()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
@clear_cache_before_run()
|
||||
def test_vit():
|
||||
spawn(check_vit, 4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_vit()
|
Reference in New Issue
Block a user