remove vocab_size args

This commit is contained in:
wangbluo
2024-06-20 08:06:39 +00:00
parent b12e9a3275
commit dba59354d7
2 changed files with 11 additions and 34 deletions

View File

@@ -320,8 +320,6 @@ class LlamaPipelineForwards:
shift_logits,
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=self.lm_head.out_features,
dtype=self.model.dtype,
)
else:
shift_logits = shift_logits.view(-1, self.config.vocab_size)