From 6377aa0fffb8fbd6862fc2b4ed536724cbe09d64 Mon Sep 17 00:00:00 2001
From: duanjunwen <935724073@qq.com>
Date: Mon, 28 Oct 2024 02:42:33 +0000
Subject: [PATCH] [fix] fix test_shard_llama ci;

---
 colossalai/shardformer/modeling/llama.py              | 2 +-
 tests/test_shardformer/test_model/test_shard_llama.py | 1 -
 2 files changed, 1 insertion(+), 2 deletions(-)

diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index 7a04c5451..47c17e749 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -82,7 +82,7 @@ class LlamaPipelineForwards:
         elif input_ids is not None:
             batch_size, seq_length = input_ids.shape[:2]
         elif inputs_embeds is not None:
-            batch_size, seq_length = inputs_embeds.shape[:2]
+            batch_size, seq_length, _ = inputs_embeds.shape
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")
         if inputs_embeds is None:
diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py
index b43e45bcf..33707a4f6 100644
--- a/tests/test_shardformer/test_model/test_shard_llama.py
+++ b/tests/test_shardformer/test_model/test_shard_llama.py
@@ -325,7 +325,6 @@ def run_llama_test(test_config):
         ).get_v_schedule()
         test_config["scheduler_nodes"] = scheduler_nodes
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        print(f"name {name}")
         if test_config.get("sequence_parallelism_mode", None) == "ring_attn" and "causal" not in name:
            continue
         try:
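
Note on the modeling/llama.py hunk: inputs_embeds is a 3-D tensor of shape
(batch_size, seq_length, hidden_size), so the patched line unpacks the full
.shape into three targets and discards the hidden dimension with _. Keeping
the old [:2] slice together with a three-target unpack would instead fail at
runtime, because the slice yields only two values. The snippet below is a
minimal standalone sketch, not part of the patch; the concrete tensor
dimensions are assumed purely for illustration.

    import torch

    # Assumed example shape: batch_size=2, seq_length=16, hidden_size=64.
    inputs_embeds = torch.randn(2, 16, 64)

    # Old form: slice off the first two dims, two unpack targets. Works.
    batch_size, seq_length = inputs_embeds.shape[:2]

    # Patched form: unpack all three dims, discard hidden_size. Works.
    batch_size, seq_length, _ = inputs_embeds.shape

    # Mixing the two raises, since the slice holds only two values.
    try:
        batch_size, seq_length, _ = inputs_embeds.shape[:2]
    except ValueError as err:
        print(err)  # not enough values to unpack (expected 3, got 2)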