mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-04 10:34:41 +00:00
[Inference]Adapt temperature processing logic (#5689)
* Adapt temperature processing logic * add ValueError for top_p and top_k * add GQA Test * fix except_msg
This commit is contained in:
@@ -17,11 +17,30 @@ def register_logit_processor(process_type):
|
||||
return register
|
||||
|
||||
|
||||
@register_logit_processor("temperature")
|
||||
def temperature_logit_process(logits, temperature: float):
|
||||
"""
|
||||
apply temperature scaling.
|
||||
"""
|
||||
|
||||
if not isinstance(temperature, float) or not (0.0 < temperature <= 1.0):
|
||||
except_msg = f"'temperature={temperature}' should be a strictly positive float, less than or equal to 1.0 and greater than 0."
|
||||
if temperature == 0.0:
|
||||
except_msg += "if you want to use greedy decoding strategies, set `do_sample=False`."
|
||||
raise ValueError(except_msg)
|
||||
|
||||
return logits if temperature == 1.0 else logits / temperature
|
||||
|
||||
|
||||
@register_logit_processor("top_k")
|
||||
def top_k_logit_processor(logits, top_k: int):
|
||||
"""
|
||||
top_k logit processor
|
||||
"""
|
||||
|
||||
if not isinstance(top_k, int) or top_k <= 0:
|
||||
raise ValueError(f"`top_k` should be a strictly positive integer, but got {top_k}.")
|
||||
|
||||
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
|
||||
logits[indices_to_remove] = -float("inf")
|
||||
return logits
|
||||
@@ -32,6 +51,10 @@ def top_p_logit_processor(logits, top_p: float):
|
||||
"""
|
||||
top_p logit processor
|
||||
"""
|
||||
|
||||
if top_p < 0 or top_p > 1.0:
|
||||
raise ValueError(f"`top_p` should be a float > 0 and < 1, but got {top_p}.")
|
||||
|
||||
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
||||
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
||||
|
||||
|
Reference in New Issue
Block a user