[ColossalChat] Add PP support (#6001)

* support pp training

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update rm

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* refactor

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update test case

* fix

* change to 4

* fix eval

* test

* add pp

* hotfix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* support pp training

* update rm

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* refactor

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update test case

* fix

* change to 4

* fix eval

* test

* add pp

* hotfix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update

* skip pp eval

* update all reduce

* update sft

* update ignore

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update no cache

* add eval

* remove fi

* remove debug

* remove parentheses to avoid warning

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Revert "add eval"

This reverts commit 3ab2f6fa32.

* add all reduce

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Tong Li authored on 2024-08-21 10:47:39 +08:00, committed by GitHub
parent 0d3b0bd864
commit 39e2597426
16 changed files with 241 additions and 115 deletions
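
For orientation, the snippet below is a minimal sketch of how the new plugin argument introduced by this PR is meant to be wired up: build a HybridParallelPlugin with pp_size > 1, boost the model with it, and hand the same plugin to SFTTrainer so the trainer takes the pipeline-parallel branch added in the diff. This is not code from the PR; the exact plugin arguments and launch call are typical ColossalAI usage, and model, optim, lr_scheduler and train_dataloader are placeholders assumed to be built elsewhere.

import colossalai
from coati.trainer import SFTTrainer
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.cluster import DistCoordinator

# model, optim, lr_scheduler, train_dataloader are assumed to exist already
# (e.g. a causal LM, its optimizer/scheduler, and an SFT dataloader).
colossalai.launch_from_torch()
coordinator = DistCoordinator()

# pp_size > 1 is what routes SFTTrainer into the new pipeline-parallel code path.
plugin = HybridParallelPlugin(tp_size=1, pp_size=2, precision="bf16", microbatch_size=1)
booster = Booster(plugin=plugin)
model, optim, _, train_dataloader, lr_scheduler = booster.boost(
    model, optim, dataloader=train_dataloader, lr_scheduler=lr_scheduler
)

trainer = SFTTrainer(
    model=model,
    booster=booster,
    optim=optim,
    lr_scheduler=lr_scheduler,
    max_epochs=1,
    plugin=plugin,  # new argument added in this PR
    accumulation_steps=1,
    coordinator=coordinator,
)

After boosting, training proceeds through the trainer's fit entry point as before; only the plugin plumbing shown in the diff below is new.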


@@ -6,14 +6,16 @@ import os
 from typing import Optional
 
 import torch
+import torch.distributed as dist
 from coati.trainer.utils import all_reduce_mean
 from coati.utils import AccumulativeMeanMeter, save_checkpoint
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler
 from torch.utils.data import DataLoader
-from tqdm import trange
+from tqdm import tqdm, trange
 
 from colossalai.booster import Booster
+from colossalai.booster.plugin import HybridParallelPlugin, Plugin
 from colossalai.cluster import DistCoordinator
 
 from .base import SLTrainer
@@ -40,6 +42,7 @@ class SFTTrainer(SLTrainer):
         optim: Optimizer,
         lr_scheduler: _LRScheduler,
         max_epochs: int = 2,
+        plugin: Plugin = None,
         accumulation_steps: int = 8,
         apply_loss_mask: bool = True,
         start_epoch=0,
@@ -47,7 +50,7 @@ class SFTTrainer(SLTrainer):
         save_dir: str = None,
         coordinator: Optional[DistCoordinator] = None,
     ) -> None:
-        super().__init__(booster, max_epochs, model, optim, start_epoch=start_epoch)
+        super().__init__(booster, max_epochs, model, optim, plugin, start_epoch=start_epoch)
         self.accumulation_steps = accumulation_steps
         self.scheduler = lr_scheduler
@@ -94,60 +97,85 @@ class SFTTrainer(SLTrainer):
     def _train(self, epoch: int):
         self.model.train()
-        step_bar = trange(
-            len(self.train_dataloader) // self.accumulation_steps,
-            desc=f"Epoch {epoch + 1}/{self.max_epochs}",
-            disable=not is_rank_0(),
-        )
-        for i, batch in enumerate(self.train_dataloader):
-            batch = to_device(batch, torch.cuda.current_device())
-            batch_size = batch["input_ids"].size(0)
-            outputs = self.model(
-                batch["input_ids"],
-                attention_mask=batch["attention_mask"],
-                labels=batch["labels"] if self.apply_loss_mask else batch["input_ids"],
+        if isinstance(self.plugin, HybridParallelPlugin) and self.plugin.pp_size > 1:
+            data_iter = iter(self.train_dataloader)
+            step_bar = tqdm(
+                range(len(self.train_dataloader)),
+                desc="Step",
+                disable=not (dist.get_rank() == dist.get_world_size() - 1),
             )
-            loss = outputs.loss
+            for step in step_bar:
+                outputs = self.booster.execute_pipeline(
+                    data_iter,
+                    self.model,
+                    criterion=lambda outputs, inputs: outputs[0],
+                    optimizer=self.optimizer,
+                    return_loss=True,
+                )
+                loss = outputs["loss"]

-            self.booster.backward(loss=loss, optimizer=self.optimizer)
+                if self.booster.plugin.stage_manager.is_last_stage():
+                    global_loss = all_reduce_mean(loss, self.plugin)
+                    if dist.get_rank() == dist.get_world_size() - 1:
+                        step_bar.set_postfix({"train/loss": global_loss.item()})

-            loss_mean = all_reduce_mean(tensor=loss)
-            self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
-
-            # Gradient accumulation
-            if (i + 1) % self.accumulation_steps == 0:
                 self.optimizer.step()
                 self.optimizer.zero_grad()
-                self.scheduler.step()
+        else:
+            step_bar = trange(
+                len(self.train_dataloader) // self.accumulation_steps,
+                desc=f"Epoch {epoch + 1}/{self.max_epochs}",
+                disable=not is_rank_0(),
+            )
+            for i, batch in enumerate(self.train_dataloader):
+                batch = to_device(batch, torch.cuda.current_device())
+                batch_size = batch["input_ids"].size(0)
+                outputs = self.model(
+                    batch["input_ids"],
+                    attention_mask=batch["attention_mask"],
+                    labels=batch["labels"] if self.apply_loss_mask else batch["input_ids"],
+                )
+                loss = outputs.loss

-                step_bar.set_postfix({"train/loss": self.accumulative_meter.get("loss")})
-                if self.writer:
-                    self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step)
-                    self.writer.add_scalar("train/lr", self.scheduler.get_last_lr()[0], self.num_train_step)
-                    self.num_train_step += 1
-                self.accumulative_meter.reset()
-                step_bar.update()
+                self.booster.backward(loss=loss, optimizer=self.optimizer)

-                # Save checkpoint
-                if (
-                    self.save_dir is not None
-                    and self.save_interval is not None
-                    and (self.num_train_step + 1) % self.save_interval == 0
-                ):
-                    save_checkpoint(
-                        save_dir=self.save_dir,
-                        booster=self.booster,
-                        model=self.model,
-                        optimizer=self.optimizer,
-                        lr_scheduler=self.scheduler,
-                        epoch=epoch,
-                        step=self.num_train_step + 1,
-                        batch_size=batch_size,
-                        coordinator=self.coordinator,
-                    )
-                    self.coordinator.print_on_master(
-                        f"Saved checkpoint at epoch {epoch} step {self.num_train_step} at folder {self.save_dir}"
-                    )
+                loss_mean = all_reduce_mean(tensor=loss)
+                self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
+
+                # Gradient accumulation
+                if (i + 1) % self.accumulation_steps == 0:
+                    self.optimizer.step()
+                    self.optimizer.zero_grad()
+                    self.scheduler.step()
+
+                    step_bar.set_postfix({"train/loss": self.accumulative_meter.get("loss")})
+                    if self.writer:
+                        self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step)
+                        self.writer.add_scalar("train/lr", self.scheduler.get_last_lr()[0], self.num_train_step)
+                        self.num_train_step += 1
+                    self.accumulative_meter.reset()
+                    step_bar.update()
+
+                    # Save checkpoint
+                    if (
+                        self.save_dir is not None
+                        and self.save_interval is not None
+                        and (self.num_train_step + 1) % self.save_interval == 0
+                    ):
+                        save_checkpoint(
+                            save_dir=self.save_dir,
+                            booster=self.booster,
+                            model=self.model,
+                            optimizer=self.optimizer,
+                            lr_scheduler=self.scheduler,
+                            epoch=epoch,
+                            step=self.num_train_step + 1,
+                            batch_size=batch_size,
+                            coordinator=self.coordinator,
+                        )
+                        self.coordinator.print_on_master(
+                            f"Saved checkpoint at epoch {epoch} step {self.num_train_step} at folder {self.save_dir}"
+                        )

         step_bar.close()

     def _eval(self, epoch: int):
@@ -157,27 +185,64 @@ class SFTTrainer(SLTrainer):
         self.accumulative_meter.reset()
         self.model.eval()
         with torch.no_grad():
-            step_bar = trange(
-                len(self.eval_dataloader),
-                desc=f"Epoch {epoch + 1}/{self.max_epochs}",
-                disable=not is_rank_0(),
-            )
-            for batch in self.eval_dataloader:
-                batch = to_device(batch, torch.cuda.current_device())
-                outputs = self.model(
-                    batch["input_ids"],
-                    attention_mask=batch["attention_mask"],
-                    labels=batch["labels"] if self.apply_loss_mask else batch["input_ids"],
+            if isinstance(self.plugin, HybridParallelPlugin) and self.plugin.pp_size > 1:
+                data_iter = iter(self.eval_dataloader)
+                step_bar = tqdm(
+                    range(len(self.eval_dataloader)),
+                    desc="Step",
+                    disable=not (dist.get_rank() == dist.get_world_size() - 1),
                 )
-                loss_mean = all_reduce_mean(tensor=outputs.loss)
-                self.accumulative_meter.add("loss", loss_mean.item(), count_update=batch["input_ids"].size(0))
-                step_bar.update()
-            loss_mean = self.accumulative_meter.get("loss")
-            msg = "Evaluation Result:\n"
-            for tag in ["loss"]:
-                msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n"
-            self.coordinator.print_on_master(msg)
-            os.makedirs(self.save_dir, exist_ok=True)
-            with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
-                f.write(msg)
-            step_bar.close()
+                for step in step_bar:
+                    outputs = self.booster.execute_pipeline(
+                        data_iter,
+                        self.model,
+                        criterion=lambda outputs, inputs: outputs[0],
+                        optimizer=self.optimizer,
+                        return_loss=True,
+                    )
+                    loss = outputs["loss"]
+
+                    if self.booster.plugin.stage_manager.is_last_stage():
+                        global_loss = all_reduce_mean(loss, self.plugin)
+                        if dist.get_rank() == dist.get_world_size() - 1:
+                            step_bar.set_postfix({"eval/loss": global_loss.item()})
+                            self.accumulative_meter.add("loss", global_loss.item())
+
+                if dist.get_rank() == dist.get_world_size() - 1:
+                    loss_mean = self.accumulative_meter.get("loss")
+                    msg = "Evaluation Result:\n"
+                    for tag in ["loss"]:
+                        msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n"
+                    print(msg)
+                    if self.save_dir is not None:
+                        os.makedirs(self.save_dir, exist_ok=True)
+                        with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
+                            f.write(msg)
+                    step_bar.close()
+            else:
+                step_bar = trange(
+                    len(self.eval_dataloader),
+                    desc=f"Epoch {epoch + 1}/{self.max_epochs}",
+                    disable=not is_rank_0(),
+                )
+                for batch in self.eval_dataloader:
+                    batch = to_device(batch, torch.cuda.current_device())
+                    outputs = self.model(
+                        batch["input_ids"],
+                        attention_mask=batch["attention_mask"],
+                        labels=batch["labels"] if self.apply_loss_mask else batch["input_ids"],
+                    )
+                    loss_mean = all_reduce_mean(tensor=outputs.loss)
+                    self.accumulative_meter.add("loss", loss_mean.item(), count_update=batch["input_ids"].size(0))
+                    step_bar.update()
+
+                loss_mean = self.accumulative_meter.get("loss")
+                msg = "Evaluation Result:\n"
+                for tag in ["loss"]:
+                    msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n"
+                self.coordinator.print_on_master(msg)
+                if self.save_dir is not None:
+                    os.makedirs(self.save_dir, exist_ok=True)
+                    with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
+                        f.write(msg)
+                step_bar.close()
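
Two details of the pipeline branch above are worth spelling out. First, with pp_size > 1 the loss returned by execute_pipeline is only meaningful on the last pipeline stage, which is why progress-bar updates and the eval summary are gated on dist.get_rank() == dist.get_world_size() - 1. Second, the branch reports loss through all_reduce_mean(loss, self.plugin), i.e. an average over data-parallel replicas rather than the whole world. A plausible shape for such a helper is sketched below; this is an illustrative guess at the coati.trainer.utils.all_reduce_mean update, not the exact code from this PR, and it assumes the plugin exposes its data-parallel process group as dp_group (as HybridParallelPlugin does).

import torch
import torch.distributed as dist


def all_reduce_mean(tensor: torch.Tensor, plugin=None) -> torch.Tensor:
    """Average a (scalar) tensor across ranks, in place.

    When a plugin is given, reduce only over its data-parallel group so that
    tensor- and pipeline-parallel ranks holding the same replica are not
    counted twice; otherwise fall back to the default (world) process group.
    """
    group = plugin.dp_group if plugin is not None else None
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group)
    tensor.div_(dist.get_world_size(group=group))
    return tensor

With this shape, the trainer's existing call all_reduce_mean(tensor=loss) keeps its old world-wide behaviour, while the new pipeline branch passes the plugin to restrict the reduction to data-parallel peers.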