[Chat]Add Peft support & fix the ptx bug (#3433)

* Update ppo.py

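Fix a bug where input_ids, labels and attention_mask were each read from a separate next(iter(dataloader)) call instead of from one shared batch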

* Add peft model support in SFT and Prompts training

PEFT model support is added in stage 1 and stage 3, so the trained artifacts are only the small LoRA additions instead of the full set of model files.

* Delete test_prompts.txt

* Delete test_pretrained.txt

Move the PEFT code to a community folder.

* Move the demo sft to community

* delete dirty files

Add instructions to install peft from source

* Remove Chinese comments

* remove the Chinese comments
This commit is contained in:
YY Lin
2023-04-06 11:54:52 +08:00
committed by GitHub
parent 73afb63594
commit 62f4e2eb07
6 changed files with 781 additions and 3 deletions

View File

@@ -92,9 +92,10 @@ class PPOTrainer(Trainer):
# ptx loss
if self.ptx_coef != 0:
ptx = next(iter(self.pretrain_dataloader))['input_ids'].to(torch.cuda.current_device())
label = next(iter(self.pretrain_dataloader))['labels'].to(torch.cuda.current_device())[:, 1:]
attention_mask = next(iter(self.pretrain_dataloader))['attention_mask'].to(torch.cuda.current_device())
batch = next(iter(self.pretrain_dataloader))
ptx = batch['input_ids'].to(torch.cuda.current_device())
label = batch['labels'].to(torch.cuda.current_device())[:, 1:]
attention_mask = batch['attention_mask'].to(torch.cuda.current_device())
ptx_log_probs = self.actor.get_base_model()(ptx, attention_mask=attention_mask)['logits'][..., :-1, :]
ptx_loss = self.ptx_loss_fn(ptx_log_probs.view(-1, ptx_log_probs.size(-1)), label.view(-1))
actor_loss = ptx_loss * self.ptx_coef + actor_loss * (1 - self.ptx_coef)