fix bugs in sampler

This commit is contained in:
yuehuayingxueluo
2024-01-04 15:03:18 +08:00
committed by FrankLeeeee
parent 02c1bf8b2a
commit bbfebfb9fc
4 changed files with 12 additions and 8 deletions

View File

@@ -21,7 +21,7 @@ def multinomial_sample(
"""
Sample tokens in a random phase.
"""
random_results = torch.multinomial(probs, num_samples=1, replacement=True).cpu()
random_results = torch.multinomial(probs, num_samples=1).squeeze(1)
return random_results