mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-14 05:33:23 +00:00
fix bugs in sampler
This commit is contained in:
committed by
FrankLeeeee
parent
02c1bf8b2a
commit
bbfebfb9fc
@@ -180,9 +180,9 @@ class RequestHandler:
         """
         # do logit processor
         # NOTE: need to decide the granularity to process logits (sequence or batch)
-        for type in ["top_p", "top_k", "min_p"]:
+        for type in ["top_k", "top_p", "min_p"]:
             config_dict = generation_config.to_dict()
-            if type in config_dict:
+            if type in config_dict and config_dict[type] is not None:
                 logits = logit_processor(type, logits, config_dict[type])

         torch.cuda.synchronize()
|
@@ -21,7 +21,7 @@ def multinomial_sample(
     """
     Sample tokens in a random phase.
     """
-    random_results = torch.multinomial(probs, num_samples=1, replacement=True).cpu()
+    random_results = torch.multinomial(probs, num_samples=1).squeeze(1)
     return random_results
Reference in New Issue
Block a user