mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-22 18:09:06 +00:00
[fix\ fix fail case test_shard_llama
This commit is contained in:
@@ -287,6 +287,11 @@ def main():
|
||||
# ==============================
|
||||
dp_size = getattr(plugin, "dp_size", coordinator.world_size)
|
||||
|
||||
if args.config in MODEL_CONFIGS:
|
||||
config = MODEL_CONFIGS[args.config]
|
||||
else:
|
||||
config = AutoConfig.from_pretrained(args.config, trust_remote_code=True)
|
||||
|
||||
torch.cuda.manual_seed(42)
|
||||
dataset = RandomDataset(
|
||||
num_samples=args.batch_size * args.num_steps * dp_size, max_length=args.max_length, vocab_size=config.vocab_size
|
||||
|
Reference in New Issue
Block a user