[fix] fix failing test case test_shard_llama

This commit is contained in:
duanjunwen
2024-10-25 02:28:55 +00:00
parent 2eca112c90
commit d0ec221b38
5 changed files with 10 additions and 12 deletions

View File

@@ -287,6 +287,11 @@ def main():
# ==============================
dp_size = getattr(plugin, "dp_size", coordinator.world_size)
if args.config in MODEL_CONFIGS:
config = MODEL_CONFIGS[args.config]
else:
config = AutoConfig.from_pretrained(args.config, trust_remote_code=True)
torch.cuda.manual_seed(42)
dataset = RandomDataset(
num_samples=args.batch_size * args.num_steps * dp_size, max_length=args.max_length, vocab_size=config.vocab_size