fix logprob, add filtering, temperature annealing, lr descent

This commit is contained in:
YeAnbang
2025-03-21 10:24:24 +08:00
parent f983071b10
commit 16e68a071d
7 changed files with 74 additions and 27 deletions

View File

@@ -387,7 +387,7 @@ def dist_log_prob(
dtype=dtype,
)
else:
-log_prob = log_softmax(logits)
+log_prob = log_softmax(logits, dim=-1)
log_prob = log_prob.gather(dim=-1, index=labels.unsqueeze(-1))
return log_prob