[pipeline] add chatglm (#4363)

* add pipeline policy and bert forward to be done

* add bertmodel pipeline forward and make tests

* add Bert_Policy and test for policy

* update formatting

* update formatting

* update the code

* fix bugs

* fix name confilt

* add bloom model and policy ,revise the base class of policy

* revise

* revision

* add bert_for_pretraining

* add bert_for_pretraining forward and policy

* fix typos

* cancel warning

* change the imediate output to default dict

* change the default output of get_shared_params

* add chatglm

* add

* chatglm

* chatglm

* finish chatglm

* deletes

* fix rmsnorm

* chatglm

* fix chatglm shard

* init
This commit is contained in:
Jianghai 2023-08-04 14:55:31 +08:00 committed by Hongxin Liu
parent b1feeced8e
commit a88e92251d
9 changed files with 1828 additions and 57 deletions

View File

@ -0,0 +1,189 @@
""" PyTorch ChatGLM model. """
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.nn import CrossEntropyLoss, LayerNorm
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig
from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import (
ChatGLMForConditionalGeneration,
ChatGLMModel,
GLMBlock,
)
class ChatGLMPipelineForwards:
'''
This class serves as a micro library for ChatGLM model forwards under pipeline parallelism.
'''
@staticmethod
def chatglm_model_forward(
self: ChatGLMModel,
input_ids,
position_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.BoolTensor] = None,
full_attention_mask: Optional[torch.BoolTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
inputs_embeds: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
):
logger = logging.get_logger(__name__)
output_hidden_states = (output_hidden_states
if output_hidden_states is not None else self.config.output_hidden_states)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future.
if past_key_values:
logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.')
past_key_values = None
if output_hidden_states:
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
output_hidden_states = False
if use_cache:
logger.warning_once('use_cache=True is not supported for pipeline models at the moment.')
use_cache = False
if stage_manager.is_first_stage():
batch_size, seq_length = input_ids.shape
if inputs_embeds is None:
inputs_embeds = self.embedding(input_ids)
hidden_states = inputs_embeds
else:
seq_length, batch_size = hidden_states.shape[:2]
if self.pre_seq_len is not None:
if past_key_values is None:
past_key_values = self.get_prompt(batch_size=batch_size,
device=input_ids.device,
dtype=inputs_embeds.dtype)
if attention_mask is not None:
attention_mask = torch.cat([attention_mask.new_ones((batch_size, self.pre_seq_len)), attention_mask],
dim=-1)
if full_attention_mask is None:
if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
# Rotary positional embeddings
rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
if position_ids is not None:
rotary_pos_emb = rotary_pos_emb[position_ids]
else:
rotary_pos_emb = rotary_pos_emb[None, :seq_length]
rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
if not past_key_values:
past_key_values = [None for _ in range(self.num_layers)]
presents = () if use_cache else None
if self.encoder.gradient_checkpointing and self.encoder.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
use_cache = False
all_self_attentions = None
all_hidden_states = () if output_hidden_states else None
start_idx, end_idx = stage_index[0], stage_index[1]
for idx in range(start_idx, end_idx):
layer = self.encoder._get_layer(idx)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.encoder.gradient_checkpointing and self.encoder.training:
layer_ret = torch.utils.checkpoint.checkpoint(layer, hidden_states, attention_mask, rotary_pos_emb,
past_key_values[idx], use_cache)
else:
layer_ret = layer(hidden_states,
full_attention_mask,
rotary_pos_emb,
kv_cache=past_key_values[idx],
use_cache=use_cache)
hidden_states, kv_cache = layer_ret
if use_cache:
presents = presents + (kv_cache,)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if stage_manager.is_last_stage():
# final layer_norm
if self.encoder.post_layer_norm:
hidden_states = self.encoder.final_layernorm(hidden_states)
if not return_dict:
return tuple(
v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=presents,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
else:
return {'hidden_states': hidden_states}
@staticmethod
def chatglm_for_conditional_generation_forward(
self: ChatGLMForConditionalGeneration,
input_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
return_last_logit: Optional[bool] = False,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
):
logger = logging.get_logger(__name__)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)
transformer_outputs = ChatGLMPipelineForwards.chatglm_model_forward(
self.transformer,
input_ids=input_ids,
position_ids=position_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index,
)
if stage_manager.is_last_stage():
hidden_states = transformer_outputs[0]
if return_last_logit:
hidden_states = hidden_states[-1:]
lm_logits = self.transformer.output_layer(hidden_states)
lm_logits = lm_logits.transpose(0, 1).contiguous()
loss = None
if labels is not None:
lm_logits = lm_logits.to(torch.float32)
# Shift so that tokens < n predict n
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss(ignore_index=-100)
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
lm_logits = lm_logits.to(hidden_states.dtype)
loss = loss.to(hidden_states.dtype)
if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=lm_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
else:
return transformer_outputs

View File

@ -0,0 +1,58 @@
from transformers import PretrainedConfig
class ChatGLMConfig(PretrainedConfig):
model_type = "chatglm"
def __init__(self,
num_layers=28,
padded_vocab_size=65024,
hidden_size=4096,
ffn_hidden_size=13696,
kv_channels=128,
num_attention_heads=32,
seq_length=2048,
hidden_dropout=0.0,
attention_dropout=0.0,
layernorm_epsilon=1e-5,
rmsnorm=True,
apply_residual_connection_post_layernorm=False,
post_layer_norm=True,
add_bias_linear=False,
add_qkv_bias=False,
bias_dropout_fusion=True,
multi_query_attention=False,
multi_query_group_num=1,
apply_query_key_layer_scaling=True,
attention_softmax_in_fp32=True,
fp32_residual_connection=False,
quantization_bit=0,
pre_seq_len=None,
prefix_projection=False,
**kwargs):
self.num_layers = num_layers
self.vocab_size = padded_vocab_size
self.padded_vocab_size = padded_vocab_size
self.hidden_size = hidden_size
self.ffn_hidden_size = ffn_hidden_size
self.kv_channels = kv_channels
self.num_attention_heads = num_attention_heads
self.seq_length = seq_length
self.hidden_dropout = hidden_dropout
self.attention_dropout = attention_dropout
self.layernorm_epsilon = layernorm_epsilon
self.rmsnorm = rmsnorm
self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
self.post_layer_norm = post_layer_norm
self.add_bias_linear = add_bias_linear
self.add_qkv_bias = add_qkv_bias
self.bias_dropout_fusion = bias_dropout_fusion
self.multi_query_attention = multi_query_attention
self.multi_query_group_num = multi_query_group_num
self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
self.attention_softmax_in_fp32 = attention_softmax_in_fp32
self.fp32_residual_connection = fp32_residual_connection
self.quantization_bit = quantization_bit
self.pre_seq_len = pre_seq_len
self.prefix_projection = prefix_projection
super().__init__(**kwargs)

File diff suppressed because it is too large Load Diff

View File

@ -1,32 +1,46 @@
from typing import Dict, Union from functools import partial
from typing import Callable, Dict, List, Optional, Tuple, Union
import torch.nn as nn import torch.nn as nn
from torch import Tensor
from transformers.modeling_outputs import BaseModelOutputWithPast
import colossalai.shardformer.layer as col_nn import colossalai.shardformer.layer as col_nn
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.modeling.chatglm import ChatGLMPipelineForwards
from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig
from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import (
ChatGLMForConditionalGeneration,
ChatGLMModel,
GLMBlock,
)
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = ['ChatGLMModelPolicy', 'ChatGLMForConditionalGenerationPolicy'] __all__ = ['ChatGLMPolicy', 'ChatGLMModelPolicy', 'ChatGLMForConditionalGenerationPolicy']
class ChatGLMModelPolicy(Policy): class ChatGLMPolicy(Policy):
def config_sanity_check(self): def config_sanity_check(self):
pass pass
def preprocess(self): def preprocess(self):
# Resize embedding # Resize embedding
vocab_size = self.model.config.padded_vocab_size if self.shard_config.enable_tensor_parallelism:
world_size = self.shard_config.tensor_parallel_size vocab_size = self.model.config.padded_vocab_size
world_size = self.shard_config.tensor_parallel_size
if vocab_size % world_size != 0:
new_vocab_size = vocab_size + world_size - vocab_size % world_size
self.model.resize_token_embeddings(new_vocab_size)
if vocab_size % world_size != 0:
new_vocab_size = vocab_size + world_size - vocab_size % world_size
self.model.resize_token_embeddings(new_vocab_size)
return self.model return self.model
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
from tests.kit.model_zoo.transformers.chatglm2_6b.modeling_chatglm import ChatGLMModel, GLMBlock from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMModel, GLMBlock
policy = {} policy = {}
@ -112,9 +126,91 @@ class ChatGLMModelPolicy(Policy):
def postprocess(self): def postprocess(self):
return self.model return self.model
def get_held_layers(self) -> List[nn.Module]:
"""Get pipeline layers for current stage."""
assert self.pipeline_stage_manager is not None
if self.model.__class__.__name__ == 'ChatGLMModel':
module = self.model
else:
module = self.model.transformer
stage_manager = self.pipeline_stage_manager
held_layers = []
layers_per_stage = self.distribute_layers(module.num_layers, stage_manager.num_stages)
if stage_manager.is_first_stage():
held_layers.append(module.embedding)
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
held_layers.extend(module.encoder.layers[start_idx:end_idx])
if stage_manager.is_last_stage():
if module.encoder.post_layer_norm:
held_layers.append(module.encoder.final_layernorm)
# rotary_pos_emb is needed for all stages
held_layers.append(module.rotary_pos_emb)
return held_layers
def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
"""If under pipeline parallel setting, replacing the original forward method of huggingface
to customized forward method, and add this changing to policy."""
if not self.pipeline_stage_manager:
raise ValueError("set_pipeline_forward method can only be called when pipeline parallel is enabled.")
stage_manager = self.pipeline_stage_manager
if self.model.__class__.__name__ == 'ChatGLMModel':
module = self.model
else:
module = self.model.transformer
layers_per_stage = Policy.distribute_layers(module.num_layers, stage_manager.num_stages)
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
method_replacement = {'forward': partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
class ChatGLMModelPolicy(ChatGLMPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers.models.gpt2.modeling_gpt2 import GPT2Model
policy = super().module_policy()
if self.pipeline_stage_manager is not None:
self.set_pipeline_forward(model_cls=ChatGLMModel,
new_forward=ChatGLMPipelineForwards.chatglm_model_forward,
policy=policy)
return policy
def get_held_layers(self) -> List[nn.Module]:
return super().get_held_layers()
def get_shared_params(self) -> List[Dict[int, Tensor]]:
"""No shared params in ChatGLMModel."""
return []
class ChatGLMForConditionalGenerationPolicy(ChatGLMModelPolicy): class ChatGLMForConditionalGenerationPolicy(ChatGLMModelPolicy):
def module_policy(self): def module_policy(self):
policy = super().module_policy() policy = super().module_policy()
if self.pipeline_stage_manager is not None:
self.set_pipeline_forward(model_cls=ChatGLMForConditionalGeneration,
new_forward=ChatGLMPipelineForwards.chatglm_for_conditional_generation_forward,
policy=policy)
return policy return policy
def get_held_layers(self) -> List[nn.Module]:
held_layers = super().get_held_layers()
if self.pipeline_stage_manager.is_last_stage():
held_layers.append(self.model.transformer.output_layer)
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
"""No shared params in ChatGLMForConditionalGenerationModel."""
return []

View File

@ -1,9 +1,11 @@
import torch import torch
import transformers import transformers
from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig
from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel
from ..registry import ModelAttribute, model_zoo from ..registry import ModelAttribute, model_zoo
from .chatglm2_6b.configuration_chatglm import ChatGLMConfig
from .chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel
# ================================ # ================================
# Register single-sentence ChatGLM # Register single-sentence ChatGLM
@ -20,15 +22,18 @@ def data_gen():
output_transform_fn = lambda x: x output_transform_fn = lambda x: x
# define loss function # define loss function
loss_fn_for_chatglm_model = lambda x: x.last_hidden_state.mean() loss_fn_for_chatglm_model = lambda x: x.last_hidden_state.sum()
loss_fn = lambda x: x.logits.mean() loss_fn = lambda x: x.logits.sum()
config = ChatGLMConfig(num_layers=1, config = ChatGLMConfig(num_layers=1,
padded_vocab_size=65024, padded_vocab_size=65024,
hidden_size=64, hidden_size=64,
num_attention_heads=8, num_attention_heads=8,
rmsnorm=False, rmsnorm=True,
original_rope=True, original_rope=True,
use_cache=True) use_cache=True,
torch_dtype=torch.float32)
model_zoo.register(name='transformers_chatglm', model_zoo.register(name='transformers_chatglm',
model_fn=lambda: ChatGLMModel(config, empty_init=False), model_fn=lambda: ChatGLMModel(config, empty_init=False),

View File

@ -1,39 +0,0 @@
from colossalai.shardformer.policies.t5 import T5BasePolicy
def test_t5_pipeline_distribution():
num_test_cases = 8
test_dict = {
'num_encoder_layers': [2, 1, 3, 2, 3, 2, 10, 5],
'num_decoder_layers': [2, 8, 0, 2, 1, 5, 6, 22],
'num_stages': [2, 2, 2, 4, 4, 4, 8, 8],
'decoder_starting_stage': [1, 1, 2, 2, 3, 1, 5, 2]
}
for i in range(num_test_cases):
_, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(test_dict['num_encoder_layers'][i],
test_dict['num_decoder_layers'][i],
test_dict['num_stages'][i])
assert test_dict['decoder_starting_stage'][i] == decoder_starting_stage
def test_t5_pipeline_layers():
num_test_cases = 4
test_dict = {
'num_encoder_layers': [2, 3, 2, 4],
'num_decoder_layers': [2, 0, 2, 8],
'num_stages': [2, 2, 4, 4],
'layers_per_stage': [[[0, 2], [0, 2]], [[0, 1], [1, 3]], [[0, 1], [1, 2], [0, 1], [1, 2]],
[[0, 4], [0, 3], [3, 6], [6, 8]]]
}
for i in range(num_test_cases):
layers_per_stage, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(
test_dict['num_encoder_layers'][i], test_dict['num_decoder_layers'][i], test_dict['num_stages'][i])
for stage in range(test_dict['num_stages'][i]):
start_idx, end_idx = test_dict['layers_per_stage'][i][stage]
predicted_start, predicted_end = T5BasePolicy.get_t5_stage_index(layers_per_stage, stage,
decoder_starting_stage)
assert start_idx == predicted_start
assert end_idx == predicted_end

View File

@ -1,5 +1,6 @@
import copy import copy
from contextlib import nullcontext from contextlib import nullcontext
from typing import Optional
from typing import Any, Callable, Dict, List, Optional from typing import Any, Callable, Dict, List, Optional
import torch import torch
@ -15,6 +16,7 @@ from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.lazy import LazyInitContext from colossalai.lazy import LazyInitContext
from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer.policies.auto_policy import Policy
from colossalai.shardformer._utils import getattr_ from colossalai.shardformer._utils import getattr_
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
@ -39,7 +41,8 @@ def build_pipeline_model(model_fn,
stage_manager=None, stage_manager=None,
enable_fused_normalization=False, enable_fused_normalization=False,
enable_tensor_parallelism=False, enable_tensor_parallelism=False,
use_lazy_init: bool = False): use_lazy_init: bool = False,
policy: Optional[Policy] = None):
ctx = LazyInitContext() if use_lazy_init else nullcontext() ctx = LazyInitContext() if use_lazy_init else nullcontext()
with ctx: with ctx:
# create new model # create new model
@ -54,7 +57,7 @@ def build_pipeline_model(model_fn,
pipeline_stage_manager=stage_manager) pipeline_stage_manager=stage_manager)
shard_former = ShardFormer(shard_config=shard_config) shard_former = ShardFormer(shard_config=shard_config)
sharded_model, shared_params = shard_former.optimize(model_copy) sharded_model, shared_params = shard_former.optimize(model_copy, policy=policy)
return org_model.cuda(), sharded_model.cuda() return org_model.cuda(), sharded_model.cuda()

View File

@ -60,7 +60,7 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
shard_weight = shard_chatglm_model.embedding.word_embeddings.weight shard_weight = shard_chatglm_model.embedding.word_embeddings.weight
if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] shard_grad_list = [torch.zeros_like(shard_grad) for _ in range(2)]
torch.distributed.all_gather(shard_grad_list, shard_grad) torch.distributed.all_gather(shard_grad_list, shard_grad)
all_shard_grad = torch.cat(shard_grad_list, dim=0) all_shard_grad = torch.cat(shard_grad_list, dim=0)
else: else:

View File

@ -0,0 +1,86 @@
import copy
import os
import pytest
import torch
import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.logging import disable_existing_loggers
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.policies.chatglm import ChatGLMForConditionalGenerationPolicy, ChatGLMModelPolicy
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
from colossalai.testing import (
assert_hf_output_close,
clear_cache_before_run,
parameterize,
rerun_if_address_is_in_use,
spawn,
)
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_model, build_pipeline_model, run_forward
@parameterize('enable_fused_normalization', [False])
@parameterize('enable_tensor_parallelism', [False])
@parameterize('use_lazy_init', [False])
def run_chatglm_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
DP_DIM, PP_DIM = 0, 1
DP_SIZE, PP_SIZE = 2, 2
pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
sub_model_zoo = model_zoo.get_sub_registry('transformers_chatglm')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
# create new model for test
inputs = data_gen_fn()
inputs = {k: v.cuda() for k, v in inputs.items()}
input_ids = inputs['input_ids']
hidden_size = 64
batch_size, seq_len = input_ids.shape
hidden_state_shape = (seq_len, batch_size, hidden_size)
if name == "transformers_chatglm":
_, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
enable_tensor_parallelism, use_lazy_init, ChatGLMModelPolicy())
if stage_manager.is_last_stage():
hidden_states = torch.randn(*hidden_state_shape).cuda()
inputs['input_ids'] = None
inputs['hidden_states'] = hidden_states
outputs = sharded_model(**inputs)
if stage_manager.is_last_stage():
assert outputs[0].shape == hidden_state_shape
else:
assert outputs['hidden_states'].shape == hidden_state_shape
if name == "transformers_chatglm_for_conditional_generation":
_, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
enable_tensor_parallelism, use_lazy_init,
ChatGLMForConditionalGenerationPolicy())
if stage_manager.is_last_stage():
hidden_states = torch.randn(*hidden_state_shape).cuda()
inputs['input_ids'] = None
inputs['hidden_states'] = hidden_states
outputs = sharded_model(**inputs)
if stage_manager.is_last_stage():
assert outputs[0].shape == (batch_size, seq_len, 65024)
else:
assert outputs['hidden_states'].shape == hidden_state_shape
torch.cuda.empty_cache()
def check_chatglm(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_chatglm_test()
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_chatglm():
spawn(check_chatglm, 4)
if __name__ == "__main__":
test_chatglm()