diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index 0a2b151d4..66d77b48a 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -160,7 +160,7 @@ def run_forward_backward_with_hybrid_plugin( input_shape = data["input_ids"].shape for k, v in data.items(): if v.shape == input_shape: - data[k] = v.repeat(input_shape[:-1] + (input_shape[-1] * times,)) + data[k] = v.repeat((1, ) * (v.dim() - 1) + (times,)) sharded_model.train() if booster.plugin.stage_manager is not None: