mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-08-23 10:11:37 +00:00)

support pp training

commit 38c84a1aa0 (parent ceb1e262e7)
coati/trainer/base.py

@@ -17,6 +17,7 @@ from coati.experience_maker import Experience
 from torch.optim import Optimizer
 
 from colossalai.booster import Booster
+from colossalai.booster import Plugin
 
 from .utils import is_rank_0
 
@@ -38,6 +39,7 @@ class SLTrainer(ABC):
         max_epochs: int,
         model: nn.Module,
         optimizer: Optimizer,
+        plugin: Plugin,
         start_epoch: int = 0,
     ) -> None:
         super().__init__()
@@ -45,6 +47,7 @@ class SLTrainer(ABC):
         self.max_epochs = max_epochs
         self.model = model
         self.optimizer = optimizer
+        self.plugin = plugin
         self.start_epoch = start_epoch
 
     @abstractmethod
coati/trainer/sft.py

@@ -6,14 +6,16 @@ import os
 from typing import Optional
 
 import torch
+import torch.distributed as dist
 from coati.trainer.utils import all_reduce_mean
 from coati.utils import AccumulativeMeanMeter, save_checkpoint
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler
 from torch.utils.data import DataLoader
-from tqdm import trange
+from tqdm import tqdm, trange
 
 from colossalai.booster import Booster
+from colossalai.booster.plugin import HybridParallelPlugin, Plugin
 from colossalai.cluster import DistCoordinator
 
 from .base import SLTrainer
@@ -40,6 +42,7 @@ class SFTTrainer(SLTrainer):
         optim: Optimizer,
         lr_scheduler: _LRScheduler,
         max_epochs: int = 2,
+        plugin: Plugin = None,
         accumulation_steps: int = 8,
         apply_loss_mask: bool = True,
         start_epoch=0,
@@ -47,7 +50,7 @@ class SFTTrainer(SLTrainer):
         save_dir: str = None,
         coordinator: Optional[DistCoordinator] = None,
     ) -> None:
-        super().__init__(booster, max_epochs, model, optim, start_epoch=start_epoch)
+        super().__init__(booster, max_epochs, model, optim, plugin, start_epoch=start_epoch)
 
         self.accumulation_steps = accumulation_steps
         self.scheduler = lr_scheduler
@@ -94,60 +97,82 @@ class SFTTrainer(SLTrainer):
 
     def _train(self, epoch: int):
         self.model.train()
-        step_bar = trange(
-            len(self.train_dataloader) // self.accumulation_steps,
-            desc=f"Epoch {epoch + 1}/{self.max_epochs}",
-            disable=not is_rank_0(),
-        )
-        for i, batch in enumerate(self.train_dataloader):
-            batch = to_device(batch, torch.cuda.current_device())
-            batch_size = batch["input_ids"].size(0)
-            outputs = self.model(
-                batch["input_ids"],
-                attention_mask=batch["attention_mask"],
-                labels=batch["labels"] if self.apply_loss_mask else batch["input_ids"],
-            )
-            loss = outputs.loss
-
-            self.booster.backward(loss=loss, optimizer=self.optimizer)
-
-            loss_mean = all_reduce_mean(tensor=loss)
-            self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
-
-            # Gradient accumulation
-            if (i + 1) % self.accumulation_steps == 0:
-                self.optimizer.step()
-                self.optimizer.zero_grad()
-                self.scheduler.step()
-
-                step_bar.set_postfix({"train/loss": self.accumulative_meter.get("loss")})
-                if self.writer:
-                    self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step)
-                    self.writer.add_scalar("train/lr", self.scheduler.get_last_lr()[0], self.num_train_step)
-                    self.num_train_step += 1
-                self.accumulative_meter.reset()
-                step_bar.update()
-
-                # Save checkpoint
-                if (
-                    self.save_dir is not None
-                    and self.save_interval is not None
-                    and (self.num_train_step + 1) % self.save_interval == 0
-                ):
-                    save_checkpoint(
-                        save_dir=self.save_dir,
-                        booster=self.booster,
-                        model=self.model,
-                        optimizer=self.optimizer,
-                        lr_scheduler=self.scheduler,
-                        epoch=epoch,
-                        step=self.num_train_step + 1,
-                        batch_size=batch_size,
-                        coordinator=self.coordinator,
-                    )
-                    self.coordinator.print_on_master(
-                        f"Saved checkpoint at epoch {epoch} step {self.num_train_step} at folder {self.save_dir}"
-                    )
+        if isinstance(self.plugin, HybridParallelPlugin) and self.plugin.pp_size > 1:
+            data_iter = iter(self.train_dataloader)
+            step_bar = tqdm(
+                range(len(self.train_dataloader)),
+                desc="Step",
+                disable=not (dist.get_rank() == dist.get_world_size() - 1),
+            )
+            for step in step_bar:
+                outputs = self.booster.execute_pipeline(
+                    data_iter,
+                    self.model,
+                    criterion=lambda outputs, inputs: outputs[0],
+                    optimizer=self.optimizer,
+                    return_loss=True,
+                )
+                loss = outputs["loss"]
+                if dist.get_rank() == dist.get_world_size() - 1:
+                    step_bar.set_postfix({"train/loss": loss.item()})
+                    step_bar.update()
+                self.optimizer.step()
+                self.optimizer.zero_grad()
+        else:
+            step_bar = trange(
+                len(self.train_dataloader) // self.accumulation_steps,
+                desc=f"Epoch {epoch + 1}/{self.max_epochs}",
+                disable=not is_rank_0(),
+            )
+            for i, batch in enumerate(self.train_dataloader):
+                batch = to_device(batch, torch.cuda.current_device())
+                batch_size = batch["input_ids"].size(0)
+                outputs = self.model(
+                    batch["input_ids"],
+                    attention_mask=batch["attention_mask"],
+                    labels=batch["labels"] if self.apply_loss_mask else batch["input_ids"],
+                )
+                loss = outputs.loss
+
+                self.booster.backward(loss=loss, optimizer=self.optimizer)
+
+                loss_mean = all_reduce_mean(tensor=loss)
+                self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
+
+                # Gradient accumulation
+                if (i + 1) % self.accumulation_steps == 0:
+                    self.optimizer.step()
+                    self.optimizer.zero_grad()
+                    self.scheduler.step()
+
+                    step_bar.set_postfix({"train/loss": self.accumulative_meter.get("loss")})
+                    if self.writer:
+                        self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step)
+                        self.writer.add_scalar("train/lr", self.scheduler.get_last_lr()[0], self.num_train_step)
+                        self.num_train_step += 1
+                    self.accumulative_meter.reset()
+                    step_bar.update()
+
+                    # Save checkpoint
+                    if (
+                        self.save_dir is not None
+                        and self.save_interval is not None
+                        and (self.num_train_step + 1) % self.save_interval == 0
+                    ):
+                        save_checkpoint(
+                            save_dir=self.save_dir,
+                            booster=self.booster,
+                            model=self.model,
+                            optimizer=self.optimizer,
+                            lr_scheduler=self.scheduler,
+                            epoch=epoch,
+                            step=self.num_train_step + 1,
+                            batch_size=batch_size,
+                            coordinator=self.coordinator,
+                        )
+                        self.coordinator.print_on_master(
+                            f"Saved checkpoint at epoch {epoch} step {self.num_train_step} at folder {self.save_dir}"
+                        )
         step_bar.close()
 
     def _eval(self, epoch: int):
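The heart of the change is this pipeline branch: when the trainer holds a HybridParallelPlugin with pp_size > 1, the model is physically split across stages, so a plain forward/backward no longer works. Instead, booster.execute_pipeline drives the pipeline schedule over a batch iterator, and only the last pipeline stage materializes the loss. Below is a minimal standalone sketch of the same pattern, assuming model, optimizer, and train_dataloader have already been wrapped by booster.boost with such a plugin (all names are illustrative):

import torch.distributed as dist
from tqdm import tqdm

def train_epoch_pp(booster, model, optimizer, train_dataloader):
    model.train()
    # Pipeline schedules pull microbatches from an iterator themselves,
    # so the loop does not index batches directly.
    data_iter = iter(train_dataloader)
    # The loss only exists on the last stage, so only that rank reports progress.
    is_last_rank = dist.get_rank() == dist.get_world_size() - 1
    step_bar = tqdm(range(len(train_dataloader)), desc="Step", disable=not is_last_rank)
    for _ in step_bar:
        outputs = booster.execute_pipeline(
            data_iter,
            model,
            # With labels supplied, the model returns (loss, logits, ...),
            # so the criterion simply selects the first element.
            criterion=lambda outputs, inputs: outputs[0],
            optimizer=optimizer,
            return_loss=True,
        )
        if is_last_rank:
            step_bar.set_postfix({"train/loss": outputs["loss"].item()})
        optimizer.step()
        optimizer.zero_grad()

One nuance: iterating the tqdm bar already advances it once per step, so the sketch omits the explicit step_bar.update() that the diff performs on the last rank, which advances the bar a second time.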
examples/training_scripts/train_sft.py

@@ -114,7 +114,7 @@ def train(args):
             parallel_output=False,
             max_norm=args.grad_clip,
             precision=args.mixed_precision,
-            microbatch_size=args.batch_size,
+            microbatch_size=args.microbatch_size,
         )
     else:
         raise ValueError(f"Unknown plugin {args.plugin}")
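For reference, this plugin branch is what switches the pipeline path on: a HybridParallelPlugin with pp_size > 1 is exactly the condition that _train checks. A minimal sketch of building the booster this way; the tp_size/pp_size values are placeholders (the script derives its settings from CLI flags), while parallel_output, max_norm, precision, and microbatch_size mirror the hunk above:

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

plugin = HybridParallelPlugin(
    tp_size=1,             # placeholder: degree of tensor parallelism
    pp_size=2,             # pp_size > 1 activates the pipeline branch in _train
    parallel_output=False,
    max_norm=1.0,          # stands in for args.grad_clip
    precision="bf16",      # stands in for args.mixed_precision
    microbatch_size=1,     # now read from args.microbatch_size
)
booster = Booster(plugin=plugin)

The one-line fix in this hunk matters: if the microbatch size equals the full batch size, each step contains a single microbatch and the pipeline stages can never overlap; a dedicated flag lets one batch be split into several in-flight chunks.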
@@ -269,6 +269,7 @@ def train(args):
         model=model,
         booster=booster,
         optim=optim,
+        plugin=plugin,
         lr_scheduler=lr_scheduler,
         max_epochs=args.max_epochs,
         accumulation_steps=args.accumulation_steps,
@@ -344,6 +345,7 @@ if __name__ == "__main__":
     parser.add_argument("--use_wandb", default=False, action="store_true")
     parser.add_argument("--grad_checkpoint", default=False, action="store_true")
     parser.add_argument("--use_flash_attn", default=False, action="store_true")
+    parser.add_argument("--microbatch_size", type=int, default=1)
     args = parser.parse_args()
     if args.config_file is not None:
         os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
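What the new flag controls: the dataloader still yields batches of args.batch_size samples, and the pipeline schedule slices each batch into chunks of args.microbatch_size that are kept in flight across stages. A toy, self-contained illustration of the relationship (flag names mirror the script; the divisibility check is the usual pipeline-parallel constraint, not something this diff enforces):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument("--microbatch_size", type=int, default=1)
args = parser.parse_args(["--batch_size", "8", "--microbatch_size", "2"])

# Each optimizer step pushes batch_size / microbatch_size chunks through the pipeline.
assert args.batch_size % args.microbatch_size == 0, "batch must split evenly into microbatches"
print(f"{args.batch_size // args.microbatch_size} microbatches per step")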