Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-05-04 14:38:10 +00:00)
[Inference]Adapt temperature processing logic (#5689)
* Adapt temperature processing logic
* add ValueError for top_p and top_k
* add GQA Test
* fix except_msg
parent 12e7c28d5e
commit 9c2fe7935f
@@ -328,10 +328,12 @@ class RequestHandler:
         """
         Sample tokens for finished requests.
         """

         # do logit processor
+        if generation_config.do_sample:
             # NOTE: need to decide the granularity to process logits (sequence or batch)
             config_dict = generation_config.to_dict()
-            for type in ["top_k", "top_p", "min_p"]:
+            for type in ["temperature", "top_k", "top_p"]:
                 if type in config_dict and config_dict[type] is not None:
                     logits = logit_processor(type, logits, config_dict[type])
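For context, the loop above hands each sampling parameter to a processor registered under its name. Below is a minimal, self-contained sketch of that register-and-dispatch pattern; the dictionary `_LOGIT_PROCESSOR_MAP` and the toy `config_dict` are illustrative assumptions, while `register_logit_processor` and `logit_processor` are the names that actually appear in the diff.

import torch

# Illustrative registry; the real module keeps its own mapping.
_LOGIT_PROCESSOR_MAP = {}


def register_logit_processor(process_type):
    """Decorator that registers a processor function under `process_type`."""

    def register(func):
        _LOGIT_PROCESSOR_MAP[process_type] = func
        return func

    return register


def logit_processor(process_type, logits, attr):
    """Look up the processor registered for `process_type` and apply it."""
    return _LOGIT_PROCESSOR_MAP[process_type](logits, attr)


@register_logit_processor("temperature")
def temperature_logit_process(logits, temperature: float):
    return logits if temperature == 1.0 else logits / temperature


# Driven the same way as the hunk above: walk a config dict and apply
# whichever processors have a value set (top_k/top_p stay untouched here).
config_dict = {"temperature": 0.7, "top_k": None, "top_p": None}
logits = torch.randn(1, 32000)
for type in ["temperature", "top_k", "top_p"]:
    if type in config_dict and config_dict[type] is not None:
        logits = logit_processor(type, logits, config_dict[type])

Registering `top_k` and `top_p` the same way is what lets the dispatch loop stay agnostic about which sampling parameters are enabled.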
@@ -17,11 +17,30 @@ def register_logit_processor(process_type):
     return register


+@register_logit_processor("temperature")
+def temperature_logit_process(logits, temperature: float):
+    """
+    apply temperature scaling.
+    """
+
+    if not isinstance(temperature, float) or not (0.0 < temperature <= 1.0):
+        except_msg = f"'temperature={temperature}' should be a strictly positive float, less than or equal to 1.0 and greater than 0."
+        if temperature == 0.0:
+            except_msg += "if you want to use greedy decoding strategies, set `do_sample=False`."
+        raise ValueError(except_msg)
+
+    return logits if temperature == 1.0 else logits / temperature
+
+
 @register_logit_processor("top_k")
 def top_k_logit_processor(logits, top_k: int):
     """
     top_k logit processor
     """
+
+    if not isinstance(top_k, int) or top_k <= 0:
+        raise ValueError(f"`top_k` should be a strictly positive integer, but got {top_k}.")
+
     indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
     logits[indices_to_remove] = -float("inf")
     return logits
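As a quick illustration of what the new temperature processor does, the snippet below (an independent sketch that only assumes torch and the `logits / temperature` formula from the hunk above) shows how a lower temperature sharpens the softmax distribution; the printed values are approximate.

import torch
import torch.nn.functional as F

logits = torch.tensor([2.0, 1.0, 0.5])

# temperature = 1.0 leaves the distribution unchanged: roughly [0.63, 0.23, 0.14]
print(F.softmax(logits / 1.0, dim=-1))

# temperature = 0.5 concentrates mass on the top logit: roughly [0.84, 0.11, 0.04]
print(F.softmax(logits / 0.5, dim=-1))

# Values outside (0.0, 1.0] are rejected by the new check; temperature == 0.0
# additionally hints that greedy decoding should use do_sample=False instead.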
@@ -32,6 +51,10 @@ def top_p_logit_processor(logits, top_p: float):
     """
     top_p logit processor
     """
+
+    if top_p < 0 or top_p > 1.0:
+        raise ValueError(f"`top_p` should be a float > 0 and < 1, but got {top_p}.")
+
     sorted_logits, sorted_indices = torch.sort(logits, descending=True)
     cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

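The hunk is cut off after the cumulative-probability line. For completeness, here is one common way nucleus (top_p) filtering finishes from that point; this is a sketch of the standard technique, not necessarily the exact masking used in the repository.

import torch
import torch.nn.functional as F


def top_p_filter_sketch(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    # Sort logits and accumulate probabilities, as in the hunk above.
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

    # Mask tokens once the cumulative probability exceeds top_p, but always
    # keep the single most probable token.
    sorted_mask = cumulative_probs > top_p
    sorted_mask[..., 0] = False

    # Map the mask back to the original (unsorted) token positions.
    mask = sorted_mask.scatter(-1, sorted_indices, sorted_mask)
    return logits.masked_fill(mask, -float("inf"))


filtered = top_p_filter_sketch(torch.randn(1, 32000), top_p=0.9)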
@@ -28,7 +28,12 @@ def check_inference_engine(use_engine=False, prompt_template=None, do_sample=Tru
     tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
     model = LlamaForCausalLM(
         LlamaConfig(
-            vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16
+            vocab_size=50000,
+            hidden_size=512,
+            intermediate_size=1536,
+            num_attention_heads=4,
+            num_key_value_heads=2,
+            num_hidden_layers=16,
         )
     ).cuda()
     model = model.eval()
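The substantive change in this test is `num_key_value_heads=2`, which is smaller than `num_attention_heads=4` and therefore exercises grouped-query attention (GQA): pairs of query heads share one key/value head. The back-of-the-envelope arithmetic below is illustrative only, not code from the test, and shows why the smaller KV-head count matters for inference memory.

# KV-cache size per token and per layer for the test config above.
hidden_size = 512
num_attention_heads = 4
num_key_value_heads = 2
head_dim = hidden_size // num_attention_heads  # 128

# K and V each store num_key_value_heads * head_dim values per token per layer.
gqa_kv_values = 2 * num_key_value_heads * head_dim  # 512
mha_kv_values = 2 * num_attention_heads * head_dim  # 1024 (if every head kept its own KV)

print(gqa_kv_values / mha_kv_values)  # 0.5 -> this config halves the KV cache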