mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-15 14:12:02 +00:00
[pipeline] add chatglm (#4363)
* add pipeline policy and bert forward to be done * add bertmodel pipeline forward and make tests * add Bert_Policy and test for policy * update formatting * update formatting * update the code * fix bugs * fix name confilt * add bloom model and policy ,revise the base class of policy * revise * revision * add bert_for_pretraining * add bert_for_pretraining forward and policy * fix typos * cancel warning * change the imediate output to default dict * change the default output of get_shared_params * add chatglm * add * chatglm * chatglm * finish chatglm * deletes * fix rmsnorm * chatglm * fix chatglm shard * init
This commit is contained in:
@@ -1,9 +1,11 @@
|
||||
import torch
|
||||
import transformers
|
||||
|
||||
from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig
|
||||
from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel
|
||||
|
||||
from ..registry import ModelAttribute, model_zoo
|
||||
from .chatglm2_6b.configuration_chatglm import ChatGLMConfig
|
||||
from .chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel
|
||||
|
||||
|
||||
# ================================
|
||||
# Register single-sentence ChatGLM
|
||||
@@ -20,15 +22,18 @@ def data_gen():
|
||||
output_transform_fn = lambda x: x
|
||||
|
||||
# define loss function
|
||||
loss_fn_for_chatglm_model = lambda x: x.last_hidden_state.mean()
|
||||
loss_fn = lambda x: x.logits.mean()
|
||||
loss_fn_for_chatglm_model = lambda x: x.last_hidden_state.sum()
|
||||
loss_fn = lambda x: x.logits.sum()
|
||||
|
||||
config = ChatGLMConfig(num_layers=1,
|
||||
padded_vocab_size=65024,
|
||||
hidden_size=64,
|
||||
num_attention_heads=8,
|
||||
rmsnorm=False,
|
||||
rmsnorm=True,
|
||||
original_rope=True,
|
||||
use_cache=True)
|
||||
use_cache=True,
|
||||
torch_dtype=torch.float32)
|
||||
|
||||
|
||||
model_zoo.register(name='transformers_chatglm',
|
||||
model_fn=lambda: ChatGLMModel(config, empty_init=False),
|
||||
|
@@ -1,39 +0,0 @@
|
||||
from colossalai.shardformer.policies.t5 import T5BasePolicy
|
||||
|
||||
|
||||
def test_t5_pipeline_distribution():
|
||||
num_test_cases = 8
|
||||
test_dict = {
|
||||
'num_encoder_layers': [2, 1, 3, 2, 3, 2, 10, 5],
|
||||
'num_decoder_layers': [2, 8, 0, 2, 1, 5, 6, 22],
|
||||
'num_stages': [2, 2, 2, 4, 4, 4, 8, 8],
|
||||
'decoder_starting_stage': [1, 1, 2, 2, 3, 1, 5, 2]
|
||||
}
|
||||
|
||||
for i in range(num_test_cases):
|
||||
_, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(test_dict['num_encoder_layers'][i],
|
||||
test_dict['num_decoder_layers'][i],
|
||||
test_dict['num_stages'][i])
|
||||
assert test_dict['decoder_starting_stage'][i] == decoder_starting_stage
|
||||
|
||||
|
||||
def test_t5_pipeline_layers():
|
||||
num_test_cases = 4
|
||||
test_dict = {
|
||||
'num_encoder_layers': [2, 3, 2, 4],
|
||||
'num_decoder_layers': [2, 0, 2, 8],
|
||||
'num_stages': [2, 2, 4, 4],
|
||||
'layers_per_stage': [[[0, 2], [0, 2]], [[0, 1], [1, 3]], [[0, 1], [1, 2], [0, 1], [1, 2]],
|
||||
[[0, 4], [0, 3], [3, 6], [6, 8]]]
|
||||
}
|
||||
|
||||
for i in range(num_test_cases):
|
||||
layers_per_stage, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(
|
||||
test_dict['num_encoder_layers'][i], test_dict['num_decoder_layers'][i], test_dict['num_stages'][i])
|
||||
|
||||
for stage in range(test_dict['num_stages'][i]):
|
||||
start_idx, end_idx = test_dict['layers_per_stage'][i][stage]
|
||||
predicted_start, predicted_end = T5BasePolicy.get_t5_stage_index(layers_per_stage, stage,
|
||||
decoder_starting_stage)
|
||||
assert start_idx == predicted_start
|
||||
assert end_idx == predicted_end
|
@@ -1,5 +1,6 @@
|
||||
import copy
|
||||
from contextlib import nullcontext
|
||||
from typing import Optional
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
@@ -15,6 +16,7 @@ from colossalai.booster.plugin import HybridParallelPlugin
|
||||
from colossalai.lazy import LazyInitContext
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer import ShardConfig, ShardFormer
|
||||
from colossalai.shardformer.policies.auto_policy import Policy
|
||||
from colossalai.shardformer._utils import getattr_
|
||||
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
|
||||
|
||||
@@ -39,7 +41,8 @@ def build_pipeline_model(model_fn,
|
||||
stage_manager=None,
|
||||
enable_fused_normalization=False,
|
||||
enable_tensor_parallelism=False,
|
||||
use_lazy_init: bool = False):
|
||||
use_lazy_init: bool = False,
|
||||
policy: Optional[Policy] = None):
|
||||
ctx = LazyInitContext() if use_lazy_init else nullcontext()
|
||||
with ctx:
|
||||
# create new model
|
||||
@@ -54,7 +57,7 @@ def build_pipeline_model(model_fn,
|
||||
pipeline_stage_manager=stage_manager)
|
||||
|
||||
shard_former = ShardFormer(shard_config=shard_config)
|
||||
sharded_model, shared_params = shard_former.optimize(model_copy)
|
||||
sharded_model, shared_params = shard_former.optimize(model_copy, policy=policy)
|
||||
return org_model.cuda(), sharded_model.cuda()
|
||||
|
||||
|
||||
|
@@ -60,7 +60,7 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
|
||||
shard_weight = shard_chatglm_model.embedding.word_embeddings.weight
|
||||
|
||||
if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
|
||||
shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
|
||||
shard_grad_list = [torch.zeros_like(shard_grad) for _ in range(2)]
|
||||
torch.distributed.all_gather(shard_grad_list, shard_grad)
|
||||
all_shard_grad = torch.cat(shard_grad_list, dim=0)
|
||||
else:
|
||||
|
@@ -0,0 +1,86 @@
|
||||
import copy
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import colossalai
|
||||
from colossalai.cluster import ProcessGroupMesh
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer.policies.chatglm import ChatGLMForConditionalGenerationPolicy, ChatGLMModelPolicy
|
||||
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
|
||||
from colossalai.testing import (
|
||||
assert_hf_output_close,
|
||||
clear_cache_before_run,
|
||||
parameterize,
|
||||
rerun_if_address_is_in_use,
|
||||
spawn,
|
||||
)
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
from tests.test_shardformer.test_model._utils import build_model, build_pipeline_model, run_forward
|
||||
|
||||
|
||||
@parameterize('enable_fused_normalization', [False])
|
||||
@parameterize('enable_tensor_parallelism', [False])
|
||||
@parameterize('use_lazy_init', [False])
|
||||
def run_chatglm_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
|
||||
DP_DIM, PP_DIM = 0, 1
|
||||
DP_SIZE, PP_SIZE = 2, 2
|
||||
pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
|
||||
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
|
||||
sub_model_zoo = model_zoo.get_sub_registry('transformers_chatglm')
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
# create new model for test
|
||||
inputs = data_gen_fn()
|
||||
inputs = {k: v.cuda() for k, v in inputs.items()}
|
||||
input_ids = inputs['input_ids']
|
||||
hidden_size = 64
|
||||
batch_size, seq_len = input_ids.shape
|
||||
hidden_state_shape = (seq_len, batch_size, hidden_size)
|
||||
if name == "transformers_chatglm":
|
||||
_, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
|
||||
enable_tensor_parallelism, use_lazy_init, ChatGLMModelPolicy())
|
||||
if stage_manager.is_last_stage():
|
||||
hidden_states = torch.randn(*hidden_state_shape).cuda()
|
||||
inputs['input_ids'] = None
|
||||
inputs['hidden_states'] = hidden_states
|
||||
outputs = sharded_model(**inputs)
|
||||
if stage_manager.is_last_stage():
|
||||
assert outputs[0].shape == hidden_state_shape
|
||||
|
||||
else:
|
||||
assert outputs['hidden_states'].shape == hidden_state_shape
|
||||
|
||||
if name == "transformers_chatglm_for_conditional_generation":
|
||||
_, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
|
||||
enable_tensor_parallelism, use_lazy_init,
|
||||
ChatGLMForConditionalGenerationPolicy())
|
||||
if stage_manager.is_last_stage():
|
||||
hidden_states = torch.randn(*hidden_state_shape).cuda()
|
||||
inputs['input_ids'] = None
|
||||
inputs['hidden_states'] = hidden_states
|
||||
outputs = sharded_model(**inputs)
|
||||
if stage_manager.is_last_stage():
|
||||
assert outputs[0].shape == (batch_size, seq_len, 65024)
|
||||
else:
|
||||
assert outputs['hidden_states'].shape == hidden_state_shape
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def check_chatglm(rank, world_size, port):
|
||||
disable_existing_loggers()
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_chatglm_test()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
@clear_cache_before_run()
|
||||
def test_chatglm():
|
||||
spawn(check_chatglm, 4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_chatglm()
|
Reference in New Issue
Block a user