[fix] fix test zerobubble

This commit is contained in:
duanjunwen 2024-10-28 06:06:07 +00:00
parent 6377aa0fff
commit 5aee4261a6

View File

@ -82,7 +82,7 @@ class LlamaPipelineForwards:
elif input_ids is not None: elif input_ids is not None:
batch_size, seq_length = input_ids.shape[:2] batch_size, seq_length = input_ids.shape[:2]
elif inputs_embeds is not None: elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape[:2] batch_size, seq_length = inputs_embeds.shape[:2]
else: else:
raise ValueError("You have to specify either input_ids or inputs_embeds") raise ValueError("You have to specify either input_ids or inputs_embeds")
if inputs_embeds is None: if inputs_embeds is None: