[Inference]Adapt temperature processing logic (#5689)

* Adapt temperature processing logic

* add ValueError for top_p and top_k

* add GQA Test

* fix except_msg
This commit is contained in:
yuehuayingxueluo
2024-05-08 17:58:29 +08:00
committed by GitHub
parent 12e7c28d5e
commit 9c2fe7935f
3 changed files with 36 additions and 6 deletions

View File

@@ -17,11 +17,30 @@ def register_logit_processor(process_type):
return register
@register_logit_processor("temperature")
def temperature_logit_process(logits, temperature: float):
"""
apply temperature scaling.
"""
if not isinstance(temperature, float) or not (0.0 < temperature <= 1.0):
except_msg = f"'temperature={temperature}' should be a strictly positive float, less than or equal to 1.0 and greater than 0."
if temperature == 0.0:
except_msg += "if you want to use greedy decoding strategies, set `do_sample=False`."
raise ValueError(except_msg)
return logits if temperature == 1.0 else logits / temperature
@register_logit_processor("top_k")
def top_k_logit_processor(logits, top_k: int):
"""
top_k logit processor
"""
if not isinstance(top_k, int) or top_k <= 0:
raise ValueError(f"`top_k` should be a strictly positive integer, but got {top_k}.")
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits[indices_to_remove] = -float("inf")
return logits
@@ -32,6 +51,10 @@ def top_p_logit_processor(logits, top_p: float):
"""
top_p logit processor
"""
if top_p < 0 or top_p > 1.0:
raise ValueError(f"`top_p` should be a float > 0 and < 1, but got {top_p}.")
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)