diff --git a/colossalai/inference/engine/engine.py b/colossalai/inference/engine/engine.py
index aeb637b96..e122b821b 100644
--- a/colossalai/inference/engine/engine.py
+++ b/colossalai/inference/engine/engine.py
@@ -156,7 +156,8 @@ class InferenceEngine:
             input_list, self.max_input_len, self.max_output_len, self.cache_manager_list[0]
         )
         # bind the infer state to the model (not lm model)
-        self.model.model.infer_state = batch_infer_state
+        model_to_bind = self.model.model if hasattr(self.model, "model") else self.model.transformer
+        model_to_bind.infer_state = batch_infer_state
         if generation_config is not None:
             generation_config.max_new_tokens = self.max_output_len
         else:
diff --git a/colossalai/inference/engine/modeling/chatglm2.py b/colossalai/inference/engine/modeling/chatglm2.py
index 2eac235d5..d6ac78df0 100644
--- a/colossalai/inference/engine/modeling/chatglm2.py
+++ b/colossalai/inference/engine/modeling/chatglm2.py
@@ -1,6 +1,7 @@
 from typing import List, Optional, Tuple
 
 import torch
+from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.utils import logging
 
 from colossalai.inference.kv_cache import BatchInferState
@@ -83,6 +84,7 @@ class ChatGLM2InferenceForwards:
         stage_index: Optional[List[int]] = None,
         shard_config: ShardConfig = None,
     ):
+        # This function is only used when pipeline is enabled.
        logger = logging.get_logger(__name__)
 
         if output_attentions:
@@ -136,11 +138,12 @@ class ChatGLM2InferenceForwards:
         stage_index: Optional[List[int]] = None,
         shard_config: ShardConfig = None,
     ):
+        infer_state = infer_state or getattr(self, "infer_state", None)
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
 
-        if stage_manager.is_first_stage():
+        if stage_manager is None or stage_manager.is_first_stage():
             if input_ids is not None and inputs_embeds is not None:
                 raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
             elif input_ids is not None:
@@ -229,7 +232,7 @@ class ChatGLM2InferenceForwards:
         )
 
         # Run encoder.
-        hidden_states = self.encoder(
+        hidden_states, next_cache = self.encoder(
             hidden_states,
             full_attention_mask,
             kv_caches=past_key_values,
@@ -246,6 +249,11 @@ class ChatGLM2InferenceForwards:
         infer_state.seq_len += 1
         infer_state.max_len_in_batch += 1
 
+        if stage_manager is None:
+            if not return_dict:
+                return (hidden_states, next_cache)
+            return BaseModelOutputWithPast(last_hidden_state=hidden_states, past_key_values=next_cache)
+
         return {"hidden_states": hidden_states}
 
     @staticmethod
@@ -264,10 +272,15 @@ class ChatGLM2InferenceForwards:
         hidden_states = hidden_states.transpose(0, 1).contiguous()
 
         infer_state.decode_layer_id = 0
+
+        if stage_index is None:
+            stage_index = (0, len(self.layers))
         start_idx, end_idx = stage_index[0], stage_index[1]
         if kv_caches is None:
             kv_caches = tuple([None] * (end_idx - start_idx + 1))
+        # for HF api compatibility, kv-cache must be returned
+        next_decoder_cache = () if use_cache else None
         for idx, kv_cache in zip(range(start_idx, end_idx), kv_caches):
             layer = self.layers[idx]
             layer_ret = layer(
@@ -279,15 +292,19 @@ class ChatGLM2InferenceForwards:
             )
 
             infer_state.decode_layer_id += 1
-            hidden_states, _ = layer_ret
+            hidden_states, next_kv_cache = layer_ret
+            if use_cache:
+                next_decoder_cache += (next_kv_cache,)
 
         hidden_states = hidden_states.transpose(0, 1).contiguous()
 
-        if self.post_layer_norm and (stage_manager.is_last_stage() or stage_manager.num_stages == 1):
+        if self.post_layer_norm and (stage_manager is None or stage_manager.is_last_stage()):
             # Final layer norm.
             hidden_states = self.final_layernorm(hidden_states)
 
-        return hidden_states
+        next_cache = next_decoder_cache if use_cache else None
+
+        return hidden_states, next_cache
 
     @staticmethod
     def chatglm_glmblock_forward(
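
For illustration, here is a minimal, self-contained sketch of the backbone-binding dispatch introduced in the `engine.py` hunk: a causal-LM wrapper may expose its transformer backbone as `.model` (LLaMA-style) or `.transformer` (ChatGLM-style), and the engine binds the batch infer state to whichever attribute exists. The names `LlamaStyleLM`, `ChatGLMStyleLM`, and `bind_infer_state` are illustrative stand-ins, not part of the patch.

```python
class _Backbone:
    """Stand-in for the transformer backbone that receives the infer state."""
    infer_state = None


class LlamaStyleLM:
    def __init__(self):
        self.model = _Backbone()        # backbone exposed as `.model`


class ChatGLMStyleLM:
    def __init__(self):
        self.transformer = _Backbone()  # backbone exposed as `.transformer`


def bind_infer_state(lm, batch_infer_state):
    # Mirrors the dispatch added in engine.py: prefer `.model`, fall back to `.transformer`.
    model_to_bind = lm.model if hasattr(lm, "model") else lm.transformer
    model_to_bind.infer_state = batch_infer_state
    return model_to_bind


for lm in (LlamaStyleLM(), ChatGLMStyleLM()):
    bound = bind_infer_state(lm, batch_infer_state={"seq_len": 1})
    assert bound.infer_state == {"seq_len": 1}
```

The `hasattr` check keeps the engine agnostic to the wrapper's attribute naming, so ChatGLM2 can reuse the same binding path as models whose backbone is exposed as `.model`, without a model-specific branch in the engine.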