[Inference] Support logic for ignoring the EOS token (#5693)

* Adapt temperature processing logic
* Add ValueError for top_p and top_k
* Add GQA test
* Fix except_msg
* Support ignoring the EOS token
* Rename variable
* Fix annotation
```diff
@@ -60,6 +60,7 @@ class Sequence:
         eos_token_id (int): The eos token id for this inference process.
         pad_token_id (int): The pad token id for this inference process.
         max_output_len (int): Maximum output length.
+        ignore_eos(bool): Whether to ignore the EOS token and continue generating tokens when encountering the EOS token.
     """
 
     request_id: int
@@ -70,6 +71,8 @@ class Sequence:
     eos_token_id: int
     pad_token_id: int
     max_output_len: int = 256
+    # NOTE(caidi) This is a temporary solution. It's better to move the logic to turn on or off the flag in sampling module in future.
+    ignore_eos: bool = False
 
     def __post_init__(self):
         self.output_token_id = []
@@ -107,7 +110,9 @@ class Sequence:
             return True
 
         if self.output_token_id:
-            if self.output_token_id[-1] == self.eos_token_id or self.output_len >= self.max_output_len:
+            if (
+                self.output_token_id[-1] == self.eos_token_id and not self.ignore_eos
+            ) or self.output_len >= self.max_output_len:
                 self.status = RequestStatus.COMPLETED
                 return True
```
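To illustrate the behavioural effect of the new flag, here is a minimal, self-contained sketch of the completion check. It is not the actual ColossalAI `Sequence` class: `SimpleSequence`, its fields, and the trimmed-down `RequestStatus` enum are simplified stand-ins for illustration only.

```python
# Minimal sketch of the completion check with ignore_eos.
# Assumption: this mirrors only the termination logic shown in the diff above,
# not the full colossalai.inference.struct.Sequence implementation.
from dataclasses import dataclass, field
from enum import Enum, auto
from typing import List


class RequestStatus(Enum):
    RUNNING = auto()
    COMPLETED = auto()


@dataclass
class SimpleSequence:
    eos_token_id: int
    max_output_len: int = 256
    ignore_eos: bool = False  # when True, an EOS token alone no longer ends generation
    status: RequestStatus = RequestStatus.RUNNING
    output_token_id: List[int] = field(default_factory=list)

    @property
    def output_len(self) -> int:
        return len(self.output_token_id)

    def check_finish(self) -> bool:
        if self.status == RequestStatus.COMPLETED:
            return True
        if self.output_token_id:
            # EOS ends the sequence only if ignore_eos is False; the length cap
            # applies regardless, so generation still terminates eventually.
            if (
                self.output_token_id[-1] == self.eos_token_id and not self.ignore_eos
            ) or self.output_len >= self.max_output_len:
                self.status = RequestStatus.COMPLETED
                return True
        return False


seq = SimpleSequence(eos_token_id=2, max_output_len=4, ignore_eos=True)
seq.output_token_id.append(2)          # EOS generated
print(seq.check_finish())              # False: EOS is ignored
seq.output_token_id.extend([5, 7, 2])  # output reaches max_output_len
print(seq.check_finish())              # True: the length cap still stops generation
```

As the example shows, `ignore_eos=True` only suppresses EOS-based termination; the `max_output_len` bound remains the hard stop, which is why the flag can be enabled safely for benchmarking or forced-length generation.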