diff --git a/colossalai/inference/config.py b/colossalai/inference/config.py
index 977aab07c..a68400fb0 100644
--- a/colossalai/inference/config.py
+++ b/colossalai/inference/config.py
@@ -111,6 +111,7 @@ class InferenceConfig:
         use_cuda_graph (bool): Whether to enforce CUDA graph execution. If False, we will disable CUDA graph and always execute the model in eager mode. If True, we will use eager execution in hybrid.
         max_context_len_to_capture (int): max context len that could be captured by CUDA Graph, per sequence
         high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False.
+        ignore_eos(bool): Whether to ignore the EOS token and continue generating tokens when encountering the EOS token.
     """
 
     # NOTE: arrange configs according to their importance and frequency of usage
@@ -156,6 +157,7 @@ class InferenceConfig:
     # cuda_graph
     use_cuda_graph: bool = False  # NOTE only when we have the graph for specific decoding batch size can we use the cuda graph for inference
     max_context_len_to_capture: int = 512
+    ignore_eos: bool = False
 
     def __post_init__(self):
         self.max_context_len_to_capture = self.max_input_len + self.max_output_len
diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py
index 73fe7df9b..04eb620c5 100644
--- a/colossalai/inference/core/engine.py
+++ b/colossalai/inference/core/engine.py
@@ -662,6 +662,7 @@ class InferenceEngine:
                 self.tokenizer.eos_token_id,
                 self.tokenizer.pad_token_id,
                 max_output_len=max_new_tokens,
+                ignore_eos=self.inference_config.ignore_eos,
             )
             self.request_handler.add_sequence(sequence)
diff --git a/colossalai/inference/struct.py b/colossalai/inference/struct.py
index 148b2bf88..db4820f51 100644
--- a/colossalai/inference/struct.py
+++ b/colossalai/inference/struct.py
@@ -60,6 +60,7 @@ class Sequence:
        eos_token_id (int): The eos token id for this inference process.
        pad_token_id (int): The pad token id for this inference process.
        max_output_len (int): Maximum output length.
+       ignore_eos(bool): Whether to ignore the EOS token and continue generating tokens when encountering the EOS token.
     """
 
     request_id: int
@@ -70,6 +71,8 @@ class Sequence:
     eos_token_id: int
     pad_token_id: int
     max_output_len: int = 256
+    # NOTE(caidi) This is a temporary solution. It's better to move the logic to turn on or off the flag in sampling module in future.
+    ignore_eos: bool = False
 
     def __post_init__(self):
         self.output_token_id = []
@@ -107,7 +110,9 @@ class Sequence:
             return True
 
         if self.output_token_id:
-            if self.output_token_id[-1] == self.eos_token_id or self.output_len >= self.max_output_len:
+            if (
+                self.output_token_id[-1] == self.eos_token_id and not self.ignore_eos
+            ) or self.output_len >= self.max_output_len:
                 self.status = RequestStatus.COMPLETED
                 return True
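
For reference, a minimal usage sketch of the new flag is below. It assumes the entry points that appear in this diff (`InferenceConfig`, `InferenceEngine`) plus a Hugging Face causal LM; the model name, constructor arguments, and the `generate` call are illustrative and may differ slightly between versions.

```python
# Minimal sketch (assumed API shapes, not the exact engine signature):
# with ignore_eos=True, Sequence.check_finish() skips the EOS check, so each
# request stops only when output_len reaches max_output_len. This is handy
# when benchmarking with fixed-length outputs.
from transformers import AutoModelForCausalLM, AutoTokenizer

from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine

model_name = "meta-llama/Llama-2-7b-hf"  # any causal LM supported by the engine
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# The new field defaults to False, preserving the old stop-at-EOS behavior.
config = InferenceConfig(max_input_len=256, max_output_len=128, ignore_eos=True)
engine = InferenceEngine(model, tokenizer, config)

# Every sequence created from these prompts inherits ignore_eos from the config
# (see the add_request change in engine.py above).
outputs = engine.generate(prompts=["Benchmark prompt with a fixed output length."])
print(outputs)
```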