mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 12:30:42 +00:00
reconstruct chat trainer and fix training script (#3588)
Co-authored-by: Yuanchen Xu <yuanchen.xu00@gmail.com>
This commit is contained in:
@@ -156,8 +156,10 @@ def main(args):
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
callbacks=[performance_evaluator])
|
||||
|
||||
random_prompts = torch.randint(tokenizer.vocab_size, (1000, 400), device=torch.cuda.current_device())
|
||||
trainer.fit(random_prompts,
|
||||
random_prompts = torch.randint(tokenizer.vocab_size, (1000, 1, 400), device=torch.cuda.current_device())
|
||||
random_attention_mask = torch.randint(1, (1000, 1, 400), device=torch.cuda.current_device()).to(torch.bool)
|
||||
random_pretrain = [{'input_ids':random_prompts[i], 'labels':random_prompts[i], 'attention_mask':random_attention_mask[i]} for i in range(1000)]
|
||||
trainer.fit(random_prompts, random_pretrain,
|
||||
num_episodes=args.num_episodes,
|
||||
max_timesteps=args.max_timesteps,
|
||||
update_timesteps=args.update_timesteps)
|
||||
|
@@ -149,8 +149,10 @@ def main(args):
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
callbacks=[performance_evaluator])
|
||||
|
||||
random_prompts = torch.randint(tokenizer.vocab_size, (1000, 400), device=torch.cuda.current_device())
|
||||
trainer.fit(random_prompts,
|
||||
random_prompts = torch.randint(tokenizer.vocab_size, (1000, 1, 400), device=torch.cuda.current_device())
|
||||
random_attention_mask = torch.randint(1, (1000, 1, 400), device=torch.cuda.current_device()).to(torch.bool)
|
||||
random_pretrain = [{'input_ids':random_prompts[i], 'labels':random_prompts[i], 'attention_mask':random_attention_mask[i]} for i in range(1000)]
|
||||
trainer.fit(random_prompts, random_pretrain,
|
||||
num_episodes=args.num_episodes,
|
||||
max_timesteps=args.max_timesteps,
|
||||
update_timesteps=args.update_timesteps)
|
||||
|
Reference in New Issue
Block a user