support pp training

Tong Li 2024-08-12 10:13:03 +00:00
parent ceb1e262e7
commit 38c84a1aa0
3 changed files with 82 additions and 52 deletions

View File

@@ -17,6 +17,7 @@ from coati.experience_maker import Experience
 from torch.optim import Optimizer
 
 from colossalai.booster import Booster
+from colossalai.booster import Plugin
 
 from .utils import is_rank_0
@@ -38,6 +39,7 @@ class SLTrainer(ABC):
         max_epochs: int,
         model: nn.Module,
         optimizer: Optimizer,
+        plugin: Plugin,
         start_epoch: int = 0,
     ) -> None:
         super().__init__()
@@ -45,6 +47,7 @@ class SLTrainer(ABC):
         self.max_epochs = max_epochs
         self.model = model
         self.optimizer = optimizer
+        self.plugin = plugin
         self.start_epoch = start_epoch
 
     @abstractmethod
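
Note: plugin is now a positional parameter of SLTrainer.__init__, placed before start_epoch, so every subclass and call site must be updated to pass it through. A minimal sketch of an updated subclass (the class name and defaults here are hypothetical, not part of the commit):

    from colossalai.booster import Booster, Plugin

    class MyTrainer(SLTrainer):
        def __init__(self, booster: Booster, model, optimizer, plugin: Plugin, max_epochs: int = 1):
            # plugin must be forwarded positionally, matching the new signature above
            super().__init__(booster, max_epochs, model, optimizer, plugin, start_epoch=0)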

View File

@@ -6,14 +6,16 @@ import os
 from typing import Optional
 
 import torch
+import torch.distributed as dist
 from coati.trainer.utils import all_reduce_mean
 from coati.utils import AccumulativeMeanMeter, save_checkpoint
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler
 from torch.utils.data import DataLoader
-from tqdm import trange
+from tqdm import tqdm, trange
 
 from colossalai.booster import Booster
+from colossalai.booster.plugin import HybridParallelPlugin, Plugin
 from colossalai.cluster import DistCoordinator
 
 from .base import SLTrainer
@@ -40,6 +42,7 @@ class SFTTrainer(SLTrainer):
         optim: Optimizer,
         lr_scheduler: _LRScheduler,
         max_epochs: int = 2,
+        plugin: Plugin = None,
         accumulation_steps: int = 8,
         apply_loss_mask: bool = True,
         start_epoch=0,
@@ -47,7 +50,7 @@ class SFTTrainer(SLTrainer):
         save_dir: str = None,
         coordinator: Optional[DistCoordinator] = None,
     ) -> None:
-        super().__init__(booster, max_epochs, model, optim, start_epoch=start_epoch)
+        super().__init__(booster, max_epochs, model, optim, plugin, start_epoch=start_epoch)
 
         self.accumulation_steps = accumulation_steps
         self.scheduler = lr_scheduler
@@ -94,60 +97,82 @@ class SFTTrainer(SLTrainer):
     def _train(self, epoch: int):
         self.model.train()
-        step_bar = trange(
-            len(self.train_dataloader) // self.accumulation_steps,
-            desc=f"Epoch {epoch + 1}/{self.max_epochs}",
-            disable=not is_rank_0(),
-        )
-        for i, batch in enumerate(self.train_dataloader):
-            batch = to_device(batch, torch.cuda.current_device())
-            batch_size = batch["input_ids"].size(0)
-            outputs = self.model(
-                batch["input_ids"],
-                attention_mask=batch["attention_mask"],
-                labels=batch["labels"] if self.apply_loss_mask else batch["input_ids"],
-            )
-            loss = outputs.loss
-
-            self.booster.backward(loss=loss, optimizer=self.optimizer)
-
-            loss_mean = all_reduce_mean(tensor=loss)
-            self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
-
-            # Gradient accumulation
-            if (i + 1) % self.accumulation_steps == 0:
-                self.optimizer.step()
-                self.optimizer.zero_grad()
-                self.scheduler.step()
-
-                step_bar.set_postfix({"train/loss": self.accumulative_meter.get("loss")})
-                if self.writer:
-                    self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step)
-                    self.writer.add_scalar("train/lr", self.scheduler.get_last_lr()[0], self.num_train_step)
-                self.num_train_step += 1
-                self.accumulative_meter.reset()
-                step_bar.update()
-
-                # Save checkpoint
-                if (
-                    self.save_dir is not None
-                    and self.save_interval is not None
-                    and (self.num_train_step + 1) % self.save_interval == 0
-                ):
-                    save_checkpoint(
-                        save_dir=self.save_dir,
-                        booster=self.booster,
-                        model=self.model,
-                        optimizer=self.optimizer,
-                        lr_scheduler=self.scheduler,
-                        epoch=epoch,
-                        step=self.num_train_step + 1,
-                        batch_size=batch_size,
-                        coordinator=self.coordinator,
-                    )
-                    self.coordinator.print_on_master(
-                        f"Saved checkpoint at epoch {epoch} step {self.num_train_step} at folder {self.save_dir}"
-                    )
+        if isinstance(self.plugin, HybridParallelPlugin) and self.plugin.pp_size > 1:
+            data_iter = iter(self.train_dataloader)
+            step_bar = tqdm(
+                range(len(self.train_dataloader)),
+                desc="Step",
+                disable=not (dist.get_rank() == dist.get_world_size() - 1),
+            )
+            for step in step_bar:
+                outputs = self.booster.execute_pipeline(
+                    data_iter,
+                    self.model,
+                    criterion=lambda outputs, inputs: outputs[0],
+                    optimizer=self.optimizer,
+                    return_loss=True,
+                )
+                loss = outputs["loss"]
+                if dist.get_rank() == dist.get_world_size() - 1:
+                    step_bar.set_postfix({"train/loss": loss.item()})
+                    step_bar.update()
+                self.optimizer.step()
+                self.optimizer.zero_grad()
+        else:
+            step_bar = trange(
+                len(self.train_dataloader) // self.accumulation_steps,
+                desc=f"Epoch {epoch + 1}/{self.max_epochs}",
+                disable=not is_rank_0(),
+            )
+            for i, batch in enumerate(self.train_dataloader):
+                batch = to_device(batch, torch.cuda.current_device())
+                batch_size = batch["input_ids"].size(0)
+                outputs = self.model(
+                    batch["input_ids"],
+                    attention_mask=batch["attention_mask"],
+                    labels=batch["labels"] if self.apply_loss_mask else batch["input_ids"],
+                )
+                loss = outputs.loss
+
+                self.booster.backward(loss=loss, optimizer=self.optimizer)
+
+                loss_mean = all_reduce_mean(tensor=loss)
+                self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
+
+                # Gradient accumulation
+                if (i + 1) % self.accumulation_steps == 0:
+                    self.optimizer.step()
+                    self.optimizer.zero_grad()
+                    self.scheduler.step()
+
+                    step_bar.set_postfix({"train/loss": self.accumulative_meter.get("loss")})
+                    if self.writer:
+                        self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step)
+                        self.writer.add_scalar("train/lr", self.scheduler.get_last_lr()[0], self.num_train_step)
+                    self.num_train_step += 1
+                    self.accumulative_meter.reset()
+                    step_bar.update()
+
+                    # Save checkpoint
+                    if (
+                        self.save_dir is not None
+                        and self.save_interval is not None
+                        and (self.num_train_step + 1) % self.save_interval == 0
+                    ):
+                        save_checkpoint(
+                            save_dir=self.save_dir,
+                            booster=self.booster,
+                            model=self.model,
+                            optimizer=self.optimizer,
+                            lr_scheduler=self.scheduler,
+                            epoch=epoch,
+                            step=self.num_train_step + 1,
+                            batch_size=batch_size,
+                            coordinator=self.coordinator,
+                        )
+                        self.coordinator.print_on_master(
+                            f"Saved checkpoint at epoch {epoch} step {self.num_train_step} at folder {self.save_dir}"
+                        )
         step_bar.close()
 
     def _eval(self, epoch: int):
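
The new branch routes training through the booster when the plugin is a HybridParallelPlugin with pp_size > 1: booster.execute_pipeline consumes the batch iterator itself and runs the forward/backward pipeline schedule internally (so there is no explicit booster.backward call), and the loss is materialized only on the last pipeline stage, which is why the progress bar and loss logging are gated on dist.get_rank() == dist.get_world_size() - 1. A condensed standalone sketch of one such step (variable names are placeholders; the criterion lambda picks outputs[0] because a Hugging Face causal LM returns the loss first when labels are supplied):

    # One pipeline-parallel optimization step, mirroring the branch added above.
    data_iter = iter(train_dataloader)
    outputs = booster.execute_pipeline(
        data_iter,                                     # the booster pulls microbatches itself
        model,
        criterion=lambda outputs, inputs: outputs[0],  # outputs[0] is the LM loss
        optimizer=optimizer,                           # backward runs inside the schedule
        return_loss=True,
    )
    if dist.get_rank() == dist.get_world_size() - 1:
        print(f"loss: {outputs['loss'].item()}")       # loss exists only on the last stage
    optimizer.step()
    optimizer.zero_grad()

Note that gradient accumulation, the LR scheduler step, and checkpoint saving are skipped on the pipeline path in this commit; only the non-pp branch keeps them.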

View File

@@ -114,7 +114,7 @@ def train(args):
             parallel_output=False,
             max_norm=args.grad_clip,
             precision=args.mixed_precision,
-            microbatch_size=args.batch_size,
+            microbatch_size=args.microbatch_size,
         )
     else:
         raise ValueError(f"Unknown plugin {args.plugin}")
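
This change fixes the microbatch sizing: the plugin's microbatch_size was previously tied to the full args.batch_size, so every step produced exactly one microbatch and the pipeline stages ran one after another. Decoupling the two lets the schedule split each batch into batch_size // microbatch_size microbatches that overlap across stages. For illustration only (the numbers are assumed, not from the commit):

    batch_size, microbatch_size = 8, 1                 # assumed example values
    num_microbatches = batch_size // microbatch_size   # 8 microbatches can fill the pipeline
    # With the old wiring (microbatch_size == batch_size) this was always 1,
    # i.e. no overlap between pipeline stages.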
@@ -269,6 +269,7 @@ def train(args):
         model=model,
         booster=booster,
         optim=optim,
+        plugin=plugin,
         lr_scheduler=lr_scheduler,
         max_epochs=args.max_epochs,
         accumulation_steps=args.accumulation_steps,
@@ -344,6 +345,7 @@ if __name__ == "__main__":
     parser.add_argument("--use_wandb", default=False, action="store_true")
     parser.add_argument("--grad_checkpoint", default=False, action="store_true")
     parser.add_argument("--use_flash_attn", default=False, action="store_true")
+    parser.add_argument("--microbatch_size", type=int, default=1)
    args = parser.parse_args()
    if args.config_file is not None:
        os.makedirs(os.path.dirname(args.config_file), exist_ok=True)