[Inference] Support the logic related to ignoring EOS token (#5693)

* Adapt temperature processing logic

* add ValueError for top_p and top_k

* add GQA Test

* fix except_msg

* support ignore EOS token

* change variable's name

* fix annotation
Author: yuehuayingxueluo
Date: 2024-05-08 19:59:10 +08:00 (committed by GitHub)
Parent: 9c2fe7935f
Commit: d482922035
3 changed files with 9 additions and 1 deletion


@@ -60,6 +60,7 @@ class Sequence:
         eos_token_id (int): The eos token id for this inference process.
         pad_token_id (int): The pad token id for this inference process.
         max_output_len (int): Maximum output length.
+        ignore_eos (bool): Whether to ignore the EOS token and continue generating tokens when it is encountered.
     """

     request_id: int
@@ -70,6 +71,8 @@ class Sequence:
     eos_token_id: int
     pad_token_id: int
     max_output_len: int = 256
+    # NOTE(caidi) This is a temporary solution. It would be better to move the logic for turning this flag on or off into the sampling module in the future.
+    ignore_eos: bool = False

     def __post_init__(self):
         self.output_token_id = []
@@ -107,7 +110,9 @@ class Sequence:
             return True

         if self.output_token_id:
-            if self.output_token_id[-1] == self.eos_token_id or self.output_len >= self.max_output_len:
+            if (
+                self.output_token_id[-1] == self.eos_token_id and not self.ignore_eos
+            ) or self.output_len >= self.max_output_len:
                 self.status = RequestStatus.COMPLETED
                 return True
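
To illustrate the behavioral change, below is a minimal, self-contained sketch of the new finish condition. The class and field names are simplified stand-ins for illustration only, not the actual colossalai Sequence (status handling, output_len bookkeeping, and the other dataclass fields are omitted):

```python
from dataclasses import dataclass, field
from typing import List

@dataclass
class _SequenceSketch:
    # Simplified stand-in keeping only the fields relevant to the finish check.
    eos_token_id: int
    max_output_len: int = 256
    ignore_eos: bool = False
    output_token_id: List[int] = field(default_factory=list)

    def check_finish(self) -> bool:
        # Mirrors the updated condition: an EOS token only finishes the sequence
        # when ignore_eos is False; the length cap always applies.
        if self.output_token_id:
            if (
                self.output_token_id[-1] == self.eos_token_id and not self.ignore_eos
            ) or len(self.output_token_id) >= self.max_output_len:
                return True
        return False

seq = _SequenceSketch(eos_token_id=2, max_output_len=4, ignore_eos=True)
seq.output_token_id = [5, 2]        # EOS emitted, but ignored
assert not seq.check_finish()
seq.output_token_id = [5, 2, 7, 9]  # reaches max_output_len
assert seq.check_finish()
```

With ignore_eos=True the sequence only stops at max_output_len, which is useful, e.g., for benchmarking generation at a fixed output length.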