precision alignment

This commit is contained in:
yuehuayingxueluo
2024-01-02 18:30:11 +08:00
committed by FrankLeeeee
parent 62968588d1
commit 9489dc64d8
5 changed files with 45 additions and 47 deletions

View File

@@ -21,8 +21,8 @@ def multinomial_sample(
"""
Sample tokens in a random phase.
"""
max_best_of = generation_config.best_of
random_results = torch.multinomial(probs, num_samples=max_best_of, replacement=True).cpu()
# max_best_of = generation_config.best_of
random_results = torch.multinomial(probs, num_samples=1, replacement=True).cpu()
return random_results
@@ -44,7 +44,8 @@ def beam_search_sample(
# NOTE: this beam search sample function is wrong now.
"""
beam_width = generation_config.best_of
# beam_width = generation_config.best_of
beam_width = 1
results = []
if is_prompt:
# Prompt phase.