mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-04 10:34:41 +00:00
Support overall loss, update KTO logging
This commit is contained in:
@@ -41,6 +41,7 @@ class SFTTrainer(SLTrainer):
|
||||
lr_scheduler: _LRScheduler,
|
||||
max_epochs: int = 2,
|
||||
accumulation_steps: int = 8,
|
||||
apply_loss_mask: bool = True,
|
||||
start_epoch=0,
|
||||
save_interval: int = None,
|
||||
save_dir: str = None,
|
||||
@@ -55,6 +56,7 @@ class SFTTrainer(SLTrainer):
|
||||
self.coordinator = coordinator
|
||||
self.num_train_step = 0
|
||||
self.num_eval_step = 0
|
||||
self.apply_loss_mask = apply_loss_mask
|
||||
self.accumulative_meter = AccumulativeMeanMeter()
|
||||
|
||||
def _before_fit(
|
||||
@@ -100,7 +102,11 @@ class SFTTrainer(SLTrainer):
|
||||
for i, batch in enumerate(self.train_dataloader):
|
||||
batch = to_device(batch, torch.cuda.current_device())
|
||||
batch_size = batch["input_ids"].size(0)
|
||||
outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
|
||||
outputs = self.model(
|
||||
batch["input_ids"],
|
||||
attention_mask=batch["attention_mask"],
|
||||
labels=batch["labels"] if self.apply_loss_mask else batch["input_ids"],
|
||||
)
|
||||
loss = outputs.loss
|
||||
|
||||
self.booster.backward(loss=loss, optimizer=self.optimizer)
|
||||
@@ -158,7 +164,11 @@ class SFTTrainer(SLTrainer):
|
||||
)
|
||||
for batch in self.eval_dataloader:
|
||||
batch = to_device(batch, torch.cuda.current_device())
|
||||
outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
|
||||
outputs = self.model(
|
||||
batch["input_ids"],
|
||||
attention_mask=batch["attention_mask"],
|
||||
labels=batch["labels"] if self.apply_loss_mask else batch["input_ids"],
|
||||
)
|
||||
loss_mean = all_reduce_mean(tensor=outputs.loss)
|
||||
self.accumulative_meter.add("loss", loss_mean.item(), count_update=batch["input_ids"].size(0))
|
||||
step_bar.update()
|
||||
|
Reference in New Issue
Block a user