hpcaitech/ColossalAI: commit 38c84a1aa0 (parent ceb1e262e7)
https://github.com/hpcaitech/ColossalAI.git

support pp training
@@ -17,6 +17,7 @@ from coati.experience_maker import Experience
 from torch.optim import Optimizer
 
 from colossalai.booster import Booster
+from colossalai.booster import Plugin
 
 from .utils import is_rank_0
 
@@ -38,6 +39,7 @@ class SLTrainer(ABC):
         max_epochs: int,
         model: nn.Module,
         optimizer: Optimizer,
+        plugin: Plugin,
         start_epoch: int = 0,
     ) -> None:
         super().__init__()
@@ -45,6 +47,7 @@ class SLTrainer(ABC):
         self.max_epochs = max_epochs
         self.model = model
         self.optimizer = optimizer
+        self.plugin = plugin
         self.start_epoch = start_epoch
 
     @abstractmethod
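Note on the base-class change above: SLTrainer now receives and stores the booster plugin, so subclasses can inspect it (for example, to check for pipeline parallelism). Below is a condensed, non-authoritative sketch of how the constructor reads after these hunks; the booster parameter and the self.booster assignment presumably sit just outside the hunk context, and the abstract method body is abbreviated.

from abc import ABC, abstractmethod

import torch.nn as nn
from torch.optim import Optimizer

from colossalai.booster import Booster, Plugin


class SLTrainer(ABC):
    # Base trainer for supervised learning; it now keeps the plugin so that
    # subclasses can detect pipeline parallelism (e.g. plugin.pp_size > 1).
    def __init__(
        self,
        booster: Booster,
        max_epochs: int,
        model: nn.Module,
        optimizer: Optimizer,
        plugin: Plugin,
        start_epoch: int = 0,
    ) -> None:
        super().__init__()
        self.booster = booster
        self.max_epochs = max_epochs
        self.model = model
        self.optimizer = optimizer
        self.plugin = plugin  # new in this commit
        self.start_epoch = start_epoch

    @abstractmethod
    def _train(self, epoch: int):
        ...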
@@ -6,14 +6,16 @@ import os
 from typing import Optional
 
 import torch
 import torch.distributed as dist
 from coati.trainer.utils import all_reduce_mean
 from coati.utils import AccumulativeMeanMeter, save_checkpoint
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler
 from torch.utils.data import DataLoader
-from tqdm import trange
+from tqdm import tqdm, trange
 
 from colossalai.booster import Booster
+from colossalai.booster.plugin import HybridParallelPlugin, Plugin
 from colossalai.cluster import DistCoordinator
 
 from .base import SLTrainer
@@ -40,6 +42,7 @@ class SFTTrainer(SLTrainer):
         optim: Optimizer,
         lr_scheduler: _LRScheduler,
         max_epochs: int = 2,
+        plugin: Plugin = None,
         accumulation_steps: int = 8,
         apply_loss_mask: bool = True,
         start_epoch=0,
@@ -47,7 +50,7 @@ class SFTTrainer(SLTrainer):
         save_dir: str = None,
         coordinator: Optional[DistCoordinator] = None,
     ) -> None:
-        super().__init__(booster, max_epochs, model, optim, start_epoch=start_epoch)
+        super().__init__(booster, max_epochs, model, optim, plugin, start_epoch=start_epoch)
 
         self.accumulation_steps = accumulation_steps
         self.scheduler = lr_scheduler
@@ -94,6 +97,28 @@ class SFTTrainer(SLTrainer):
 
     def _train(self, epoch: int):
         self.model.train()
+        if isinstance(self.plugin, HybridParallelPlugin) and self.plugin.pp_size > 1:
+            data_iter = iter(self.train_dataloader)
+            step_bar = tqdm(
+                range(len(self.train_dataloader)),
+                desc="Step",
+                disable=not (dist.get_rank() == dist.get_world_size() - 1),
+            )
+            for step in step_bar:
+                outputs = self.booster.execute_pipeline(
+                    data_iter,
+                    self.model,
+                    criterion=lambda outputs, inputs: outputs[0],
+                    optimizer=self.optimizer,
+                    return_loss=True,
+                )
+                loss = outputs["loss"]
+                if dist.get_rank() == dist.get_world_size() - 1:
+                    step_bar.set_postfix({"train/loss": loss.item()})
+                    step_bar.update()
+                self.optimizer.step()
+                self.optimizer.zero_grad()
+        else:
             step_bar = trange(
                 len(self.train_dataloader) // self.accumulation_steps,
                 desc=f"Epoch {epoch + 1}/{self.max_epochs}",
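The block above is the core of the commit: when the booster was built with a HybridParallelPlugin and pp_size > 1, the trainer hands the whole dataloader iterator to booster.execute_pipeline, which drives the forward/backward schedule over micro-batches across the pipeline stages; the loss is only meaningful on the last stage, hence the rank check before updating the progress bar. Below is a minimal sketch of the same pattern outside the trainer class; the function name and argument plumbing are illustrative, while the execute_pipeline call mirrors the diff.

import torch.distributed as dist
from tqdm import tqdm

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin


def train_one_epoch(booster: Booster, plugin, model, optimizer, train_dataloader):
    # Pipeline-parallel training branch, as added by this commit (sketch).
    assert isinstance(plugin, HybridParallelPlugin) and plugin.pp_size > 1
    model.train()

    data_iter = iter(train_dataloader)
    # The loss is gathered on the last pipeline stage; the diff uses the last
    # global rank as a proxy for that stage.
    is_last_rank = dist.get_rank() == dist.get_world_size() - 1
    step_bar = tqdm(range(len(train_dataloader)), desc="Step", disable=not is_last_rank)

    for _ in step_bar:
        # execute_pipeline pulls micro-batches from data_iter itself and runs the
        # pipeline schedule; the criterion extracts the loss from the
        # HuggingFace-style model output.
        outputs = booster.execute_pipeline(
            data_iter,
            model,
            criterion=lambda outputs, inputs: outputs[0],
            optimizer=optimizer,
            return_loss=True,
        )
        if is_last_rank:
            step_bar.set_postfix({"train/loss": outputs["loss"].item()})
        optimizer.step()
        optimizer.zero_grad()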
@@ -114,7 +114,7 @@ def train(args):
             parallel_output=False,
             max_norm=args.grad_clip,
             precision=args.mixed_precision,
-            microbatch_size=args.batch_size,
+            microbatch_size=args.microbatch_size,
         )
     else:
         raise ValueError(f"Unknown plugin {args.plugin}")
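This hunk also fixes a subtle bug: microbatch_size was previously wired to args.batch_size, so every pipeline micro-batch was as large as the whole per-rank batch. With pipeline parallelism the per-rank batch is split into chunks of microbatch_size that are streamed through the stages, so the value now gets its own flag (added further below). A rough sketch of the plugin construction this hunk belongs to, assuming a HybridParallelPlugin branch; the concrete values and the tp_size/pp_size wiring are illustrative, only the keywords visible in the hunk come from the diff.

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# Illustrative stand-ins for the script's argparse values (args.*).
grad_clip = 1.0
mixed_precision = "bf16"
microbatch_size = 1  # was incorrectly args.batch_size before this commit

plugin = HybridParallelPlugin(
    tp_size=1,   # illustrative; the real script takes these from its own flags
    pp_size=2,   # > 1 activates the pipeline branch in SFTTrainer._train
    parallel_output=False,
    max_norm=grad_clip,
    precision=mixed_precision,
    microbatch_size=microbatch_size,  # size of one pipeline micro-batch per rank
)
booster = Booster(plugin=plugin)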
@@ -269,6 +269,7 @@ def train(args):
         model=model,
         booster=booster,
         optim=optim,
+        plugin=plugin,
         lr_scheduler=lr_scheduler,
         max_epochs=args.max_epochs,
         accumulation_steps=args.accumulation_steps,
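For completeness, a hypothetical sketch of how the boosted model and the plugin end up in the trainer, condensed from the script; the helper name is made up and remaining SFTTrainer arguments are simply forwarded. The point is that the same plugin instance used to build the booster is passed to the trainer, which is what lets _train() detect pipeline parallelism.

from coati.trainer import SFTTrainer


def build_trainer(booster, plugin, model, optim, lr_scheduler, train_dataloader, **trainer_kwargs):
    # booster.boost wraps the model/optimizer/dataloader/scheduler for the chosen
    # plugin; the criterion slot of the returned tuple is unused here.
    model, optim, _, train_dataloader, lr_scheduler = booster.boost(
        model, optim, dataloader=train_dataloader, lr_scheduler=lr_scheduler
    )
    return SFTTrainer(
        model=model,
        booster=booster,
        optim=optim,
        plugin=plugin,  # new keyword introduced by this commit
        lr_scheduler=lr_scheduler,
        max_epochs=trainer_kwargs.pop("max_epochs", 2),
        accumulation_steps=trainer_kwargs.pop("accumulation_steps", 8),
        **trainer_kwargs,  # e.g. save_dir, coordinator, and other script arguments
    )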
@@ -344,6 +345,7 @@ if __name__ == "__main__":
     parser.add_argument("--use_wandb", default=False, action="store_true")
     parser.add_argument("--grad_checkpoint", default=False, action="store_true")
     parser.add_argument("--use_flash_attn", default=False, action="store_true")
+    parser.add_argument("--microbatch_size", type=int, default=1)
     args = parser.parse_args()
     if args.config_file is not None:
         os.makedirs(os.path.dirname(args.config_file), exist_ok=True)