mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-24 03:03:37 +00:00
[Chat]Add Peft support & fix the ptx bug (#3433)
* Update ppo.py — fix the bug of fetching the wrong batch data * Add PEFT model support in SFT and prompt training. In stage-1 and stage-3, PEFT model support is added, so the trained artifacts will be only small LoRA additions instead of the whole bunch of files. * Delete test_prompts.txt * Delete test_pretrained.txt * Move the PEFT files to a community folder. * Move the demo SFT to community * Delete dirty files * Add instructions to install PEFT from source * Remove Chinese comments * Remove the Chinese comments
This commit is contained in:
@@ -92,9 +92,10 @@ class PPOTrainer(Trainer):
             # ptx loss
             if self.ptx_coef != 0:
-                ptx = next(iter(self.pretrain_dataloader))['input_ids'].to(torch.cuda.current_device())
-                label = next(iter(self.pretrain_dataloader))['labels'].to(torch.cuda.current_device())[:, 1:]
-                attention_mask = next(iter(self.pretrain_dataloader))['attention_mask'].to(torch.cuda.current_device())
+                batch = next(iter(self.pretrain_dataloader))
+                ptx = batch['input_ids'].to(torch.cuda.current_device())
+                label = batch['labels'].to(torch.cuda.current_device())[:, 1:]
+                attention_mask = batch['attention_mask'].to(torch.cuda.current_device())
                 ptx_log_probs = self.actor.get_base_model()(ptx, attention_mask=attention_mask)['logits'][..., :-1, :]
                 ptx_loss = self.ptx_loss_fn(ptx_log_probs.view(-1, ptx_log_probs.size(-1)), label.view(-1))
                 actor_loss = ptx_loss * self.ptx_coef + actor_loss * (1 - self.ptx_coef)
Reference in New Issue
Block a user