[Inference] Adapt temperature processing logic (#5689)

* Adapt temperature processing logic

* Add ValueError for top_p and top_k (a hedged sketch of such checks follows this list)

* Add GQA test

* Fix except_msg
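
The ValueError change lives in one of the changed files not shown in this hunk, so the commit's actual validation code is not visible here. As a minimal sketch only, range checks for these two sampling parameters typically look like the following (the function name and exact bounds are assumptions, not the commit's code):

    # Hypothetical sketch; the real validation is in a file not shown in this diff.
    def check_sampling_params(top_p=None, top_k=None):
        # top_p is a cumulative-probability threshold, so it must lie in (0, 1].
        if top_p is not None and not 0.0 < top_p <= 1.0:
            raise ValueError(f"top_p must be in (0.0, 1.0], but got {top_p}.")
        # top_k keeps the k highest-probability tokens, so it must be a positive integer.
        if top_k is not None and (not isinstance(top_k, int) or top_k <= 0):
            raise ValueError(f"top_k must be a positive integer, but got {top_k}.")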
Author: yuehuayingxueluo
Date: 2024-05-08 17:58:29 +08:00
Committed by: GitHub
Parent: 12e7c28d5e
Commit: 9c2fe7935f

3 changed files with 36 additions and 6 deletions


@@ -328,12 +328,14 @@ class RequestHandler:
"""
Sample tokens for finished requests.
"""
# do logit processor
# NOTE: need to decide the granularity to process logits (sequence or batch)
config_dict = generation_config.to_dict()
for type in ["top_k", "top_p", "min_p"]:
if type in config_dict and config_dict[type] is not None:
logits = logit_processor(type, logits, config_dict[type])
if generation_config.do_sample:
# NOTE: need to decide the granularity to process logits (sequence or batch)
config_dict = generation_config.to_dict()
for type in ["temperature", "top_k", "top_p"]:
if type in config_dict and config_dict[type] is not None:
logits = logit_processor(type, logits, config_dict[type])
# calculate probs
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
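
For context, here is a self-contained sketch of the three transforms the loop above dispatches through logit_processor. This is illustrative PyTorch, not ColossalAI's implementation; the function names are made up:

    import torch

    def process_temperature(logits: torch.Tensor, temperature: float) -> torch.Tensor:
        # Dividing by temperature < 1.0 sharpens the distribution; > 1.0 flattens it.
        return logits / temperature

    def process_top_k(logits: torch.Tensor, top_k: int) -> torch.Tensor:
        # Mask every logit below the k-th largest in each row.
        kth_value = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
        return logits.masked_fill(logits < kth_value, float("-inf"))

    def process_top_p(logits: torch.Tensor, top_p: float) -> torch.Tensor:
        # Nucleus filtering: keep the smallest set of highest-probability tokens
        # whose cumulative probability reaches top_p, mask the rest.
        sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
        cum_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
        remove = cum_probs > top_p
        remove[..., 1:] = remove[..., :-1].clone()  # shift so the top token always survives
        remove[..., 0] = False
        mask = remove.scatter(-1, sorted_idx, remove)
        return logits.masked_fill(mask, float("-inf"))

After these transforms, sampling proceeds exactly as the context lines show: probs = torch.softmax(logits, dim=-1), followed by a draw such as torch.multinomial(probs, num_samples=1) to pick the next token for each sequence.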