mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-02 01:28:31 +00:00
[Inference]Adapt temperature processing logic (#5689)
* Adapt temperature processing logic * add ValueError for top_p and top_k * add GQA Test * fix except_msg
This commit is contained in:
@@ -328,12 +328,14 @@ class RequestHandler:
|
||||
"""
|
||||
Sample tokens for finished requests.
|
||||
"""
|
||||
|
||||
# do logit processor
|
||||
# NOTE: need to decide the granularity to process logits (sequence or batch)
|
||||
config_dict = generation_config.to_dict()
|
||||
for type in ["top_k", "top_p", "min_p"]:
|
||||
if type in config_dict and config_dict[type] is not None:
|
||||
logits = logit_processor(type, logits, config_dict[type])
|
||||
if generation_config.do_sample:
|
||||
# NOTE: need to decide the granularity to process logits (sequence or batch)
|
||||
config_dict = generation_config.to_dict()
|
||||
for type in ["temperature", "top_k", "top_p"]:
|
||||
if type in config_dict and config_dict[type] is not None:
|
||||
logits = logit_processor(type, logits, config_dict[type])
|
||||
|
||||
# calculate probs
|
||||
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
|
||||
|
Reference in New Issue
Block a user