mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-03 10:06:44 +00:00
[Shardformer] Support the Qwen2 model (#5699)
* feat: support qwen2 model * fix: modify model config and add Qwen2RMSNorm * fix qwen2 model conflicts * test: add qwen2 shard test * to: add qwen2 auto policy * support qwen model * fix the conflicts * add try catch * add transformers version for qwen2 * add the ColoAttention for the qwen2 model * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add the unit test version check * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix the test input bug * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix the version check * fix the version check --------- Co-authored-by: Wenhao Chen <cwher@outlook.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -17,3 +17,8 @@ try:
|
||||
from .mistral import *
|
||||
except ImportError:
|
||||
print("This version of transformers doesn't support mistral.")
|
||||
|
||||
try:
|
||||
from .qwen2 import *
|
||||
except ImportError:
|
||||
print("This version of transformers doesn't support qwen2.")
|
||||
|
89
tests/kit/model_zoo/transformers/qwen2.py
Normal file
89
tests/kit/model_zoo/transformers/qwen2.py
Normal file
@@ -0,0 +1,89 @@
|
||||
import torch
|
||||
import transformers
|
||||
|
||||
from ..registry import ModelAttribute, model_zoo
|
||||
|
||||
try:
|
||||
from transformers import Qwen2Config
|
||||
|
||||
HAS_QWEN2 = True
|
||||
except ImportError:
|
||||
HAS_QWEN2 = False
|
||||
|
||||
if HAS_QWEN2:
|
||||
# ===============================
|
||||
# Register Qwen2
|
||||
# ===============================
|
||||
|
||||
def data_gen():
|
||||
# the input ids are corresponding to the sentence
|
||||
# 'Hello, my dog is cute'
|
||||
#
|
||||
# the code is give below:
|
||||
# -----------------------------------
|
||||
# from transformers import Qwen2TokenizerFast
|
||||
# tokenizer = Qwen2TokenizerFast.from_pretrained("Qwen/Qwen1.5-7B-Chat")
|
||||
# input = 'Hello, my dog is cute'
|
||||
# tokenized_input = tokenizer(input, return_tensors='pt').to('cuda')
|
||||
# -----------------------------------
|
||||
|
||||
input_ids = torch.Tensor(
|
||||
[[9707, 11, 847, 5562, 374, 13, 123, 18838], [9707, 11, 847, 5562, 374, 17, 89, 18838]]
|
||||
).long()
|
||||
attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1]]).long()
|
||||
return dict(input_ids=input_ids, attention_mask=attention_mask)
|
||||
|
||||
# label is needed for casual lm
|
||||
def data_gen_for_casual_lm():
|
||||
data = data_gen()
|
||||
labels = data["input_ids"].clone()
|
||||
data["labels"] = labels
|
||||
return data
|
||||
|
||||
# transform the output to a dict
|
||||
output_transform_fn = lambda x: x
|
||||
|
||||
# function to get the loss
|
||||
loss_fn = lambda output: output["last_hidden_state"].mean()
|
||||
loss_fn_for_casual_lm = lambda output: output["loss"]
|
||||
loss_fn_for_seq_classification = lambda output: output["logits"].mean()
|
||||
|
||||
config = Qwen2Config(
|
||||
hidden_size=128,
|
||||
intermediate_size=256,
|
||||
max_window_layers=4,
|
||||
num_attention_heads=16,
|
||||
num_hidden_layers=4,
|
||||
num_key_value_heads=16,
|
||||
)
|
||||
|
||||
config.pad_token_id = 0
|
||||
|
||||
# register the following models
|
||||
# transformers.Qwen2Model,
|
||||
# transformers.Qwen2ForCausalLM,
|
||||
# transformers.Qwen2ForSequenceClassification,
|
||||
model_zoo.register(
|
||||
name="transformers_qwen2",
|
||||
model_fn=lambda: transformers.Qwen2Model(config),
|
||||
data_gen_fn=data_gen,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn,
|
||||
model_attribute=ModelAttribute(has_control_flow=True),
|
||||
)
|
||||
model_zoo.register(
|
||||
name="transformers_qwen2_for_casual_lm",
|
||||
model_fn=lambda: transformers.Qwen2ForCausalLM(config),
|
||||
data_gen_fn=data_gen_for_casual_lm,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn_for_casual_lm,
|
||||
model_attribute=ModelAttribute(has_control_flow=True),
|
||||
)
|
||||
model_zoo.register(
|
||||
name="transformers_qwen2_for_sequence_classification",
|
||||
model_fn=lambda: transformers.Qwen2ForSequenceClassification(config),
|
||||
data_gen_fn=data_gen,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn_for_seq_classification,
|
||||
model_attribute=ModelAttribute(has_control_flow=True),
|
||||
)
|
235
tests/test_shardformer/test_model/test_shard_qwen2.py
Normal file
235
tests/test_shardformer/test_model/test_shard_qwen2.py
Normal file
@@ -0,0 +1,235 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import transformers
|
||||
|
||||
import colossalai
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.shardformer.layer.utils import Randomizer
|
||||
from colossalai.tensor.d_tensor.api import clear_layout_converter
|
||||
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
from tests.test_shardformer.test_model._utils import (
|
||||
build_model_from_hybrid_plugin,
|
||||
check_all_grad_tensors,
|
||||
check_loss,
|
||||
check_output_hidden_state,
|
||||
check_weight,
|
||||
get_grad_tensors_for_check,
|
||||
run_forward_backward_with_hybrid_plugin,
|
||||
unwrap_model,
|
||||
)
|
||||
|
||||
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
|
||||
|
||||
|
||||
def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
|
||||
org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin(
|
||||
model_fn, loss_fn, test_config
|
||||
)
|
||||
|
||||
org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
|
||||
org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
|
||||
)
|
||||
|
||||
stage_manager = booster.plugin.stage_manager
|
||||
tp_group = booster.plugin.tp_group
|
||||
|
||||
# unwrap model
|
||||
qwen2_model = unwrap_model(org_model, "Qwen2Model", "model")
|
||||
shard_qwen2_model = unwrap_model(sharded_model, "Qwen2Model", "model")
|
||||
|
||||
row_layer_for_check = ["layers[0].self_attn.q_proj", "embed_tokens"]
|
||||
col_layer_for_check = ["layers[0].self_attn.o_proj"]
|
||||
|
||||
# Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
|
||||
grads_to_check = {}
|
||||
if (stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True)) and booster.plugin.zero_stage == 0:
|
||||
if test_config["precision"] == "fp32":
|
||||
atol, rtol = 1e-6, 1e-4
|
||||
else:
|
||||
atol, rtol = 5e-3, 5e-3
|
||||
row_layer_grads = get_grad_tensors_for_check(
|
||||
qwen2_model, shard_qwen2_model, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False
|
||||
)
|
||||
col_layer_grads = get_grad_tensors_for_check(
|
||||
qwen2_model, shard_qwen2_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False
|
||||
)
|
||||
grads_to_check.update(col_layer_grads)
|
||||
grads_to_check.update(row_layer_grads)
|
||||
|
||||
# optimizer executes step
|
||||
org_optimizer.step()
|
||||
sharded_optimizer.step()
|
||||
|
||||
# check last hidden state & loss
|
||||
if stage_manager is None or stage_manager.is_last_stage(ignore_chunk=True):
|
||||
if test_config["precision"] == "fp32":
|
||||
atol, rtol = 1e-5, 1e-3
|
||||
else:
|
||||
atol, rtol = 5e-3, 5e-3
|
||||
|
||||
if org_model.__class__.__name__ == "Qwen2Model":
|
||||
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
|
||||
|
||||
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
|
||||
|
||||
# check weights
|
||||
if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True):
|
||||
if test_config["precision"] == "fp32":
|
||||
atol, rtol = 1e-4, 1e-3
|
||||
else:
|
||||
atol, rtol = 5e-3, 5e-3
|
||||
check_weight(
|
||||
qwen2_model, shard_qwen2_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False
|
||||
)
|
||||
|
||||
# check grads
|
||||
check_all_grad_tensors(grads_to_check)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
@parameterize(
|
||||
"test_config",
|
||||
[
|
||||
{
|
||||
"tp_size": 2,
|
||||
"pp_size": 2,
|
||||
"num_microbatches": 2,
|
||||
"enable_all_optimization": True,
|
||||
"use_lazy_init": True,
|
||||
"precision": "fp16",
|
||||
"initial_scale": 1,
|
||||
},
|
||||
{
|
||||
"tp_size": 1,
|
||||
"pp_size": 2,
|
||||
"num_microbatches": 4,
|
||||
"use_lazy_init": False,
|
||||
"precision": "fp32",
|
||||
},
|
||||
{
|
||||
"tp_size": 4,
|
||||
"pp_size": 1,
|
||||
"enable_all_optimization": True,
|
||||
"use_lazy_init": False,
|
||||
"precision": "fp32",
|
||||
},
|
||||
{
|
||||
"tp_size": 1,
|
||||
"pp_size": 4,
|
||||
"num_microbatches": 4,
|
||||
"enable_all_optimization": False,
|
||||
"use_lazy_init": False,
|
||||
"precision": "fp32",
|
||||
},
|
||||
{"tp_size": 2, "pp_size": 1, "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32"},
|
||||
{
|
||||
"tp_size": 2,
|
||||
"pp_size": 1,
|
||||
"enable_all_optimization": True,
|
||||
"use_lazy_init": True,
|
||||
"zero_stage": 2,
|
||||
"precision": "fp16",
|
||||
"initial_scale": 1,
|
||||
},
|
||||
{
|
||||
"tp_size": 1,
|
||||
"pp_size": 2,
|
||||
"num_microbatches": 2,
|
||||
"enable_all_optimization": True,
|
||||
"use_lazy_init": True,
|
||||
"zero_stage": 1,
|
||||
"precision": "fp16",
|
||||
"initial_scale": 1,
|
||||
},
|
||||
],
|
||||
)
|
||||
def run_qwen2_test(test_config):
|
||||
sub_model_zoo = model_zoo.get_sub_registry("transformers_qwen2")
|
||||
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
|
||||
|
||||
clear_layout_converter()
|
||||
Randomizer.reset_index()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
@parameterize(
|
||||
"test_config",
|
||||
[
|
||||
{
|
||||
"tp_size": 2,
|
||||
"pp_size": 2,
|
||||
"num_microbatches": 4,
|
||||
"enable_all_optimization": False,
|
||||
"use_lazy_init": False,
|
||||
"precision": "fp32",
|
||||
"initial_scale": 1,
|
||||
},
|
||||
{
|
||||
"tp_size": 2,
|
||||
"pp_size": 2,
|
||||
"num_microbatches": 4,
|
||||
"enable_all_optimization": False,
|
||||
"use_lazy_init": False,
|
||||
"precision": "fp16",
|
||||
"zero_stage": 1,
|
||||
"initial_scale": 1,
|
||||
},
|
||||
{
|
||||
"tp_size": 2,
|
||||
"pp_size": 2,
|
||||
"pp_style": "interleaved",
|
||||
"num_model_chunks": 2,
|
||||
"num_microbatches": 4,
|
||||
"enable_all_optimization": False,
|
||||
"precision": "fp16",
|
||||
"zero_stage": 1,
|
||||
"initial_scale": 1,
|
||||
},
|
||||
],
|
||||
)
|
||||
def run_qwen2_3d_test(test_config):
|
||||
sub_model_zoo = model_zoo.get_sub_registry("transformers_qwen2")
|
||||
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
|
||||
|
||||
clear_layout_converter()
|
||||
Randomizer.reset_index()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def check_qwen2(rank, world_size, port):
|
||||
disable_existing_loggers()
|
||||
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
|
||||
run_qwen2_test()
|
||||
|
||||
|
||||
def check_qwen2_3d(rank, world_size, port):
|
||||
disable_existing_loggers()
|
||||
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
|
||||
run_qwen2_3d_test()
|
||||
|
||||
|
||||
@pytest.mark.skipif(transformers.__version__ < "4.39.1", reason="Requires transformers version 4.39.1 or later")
|
||||
@rerun_if_address_is_in_use()
|
||||
@clear_cache_before_run()
|
||||
def test_qwen2():
|
||||
spawn(check_qwen2, 4)
|
||||
|
||||
|
||||
@pytest.mark.skipif(transformers.__version__ < "4.39.1", reason="Requires transformers version 4.39.1 or later")
|
||||
@rerun_if_address_is_in_use()
|
||||
@clear_cache_before_run()
|
||||
def test_qwen2_3d():
|
||||
spawn(check_qwen2_3d, 8)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_qwen2()
|
||||
test_qwen2_3d()
|
Reference in New Issue
Block a user