support pp (pipeline-parallel) training

Repository: https://github.com/hpcaitech/ColossalAI.git
Commit: 38c84a1aa0 (parent: ceb1e262e7)

The commit threads a Booster Plugin through the coati trainer base class, adds a pipeline-parallel branch to the SFT trainer's training loop, and exposes a new --microbatch_size argument in the SFT training script so the pipeline microbatch size is no longer tied to the per-step batch size.
Base trainer (class SLTrainer):

@@ -17,6 +17,7 @@ from coati.experience_maker import Experience
 from torch.optim import Optimizer
 
 from colossalai.booster import Booster
+from colossalai.booster import Plugin
 
 from .utils import is_rank_0
 
@@ -38,6 +39,7 @@ class SLTrainer(ABC):
         max_epochs: int,
         model: nn.Module,
         optimizer: Optimizer,
+        plugin: Plugin,
         start_epoch: int = 0,
     ) -> None:
         super().__init__()
@@ -45,6 +47,7 @@ class SLTrainer(ABC):
         self.max_epochs = max_epochs
         self.model = model
         self.optimizer = optimizer
+        self.plugin = plugin
         self.start_epoch = start_epoch
 
     @abstractmethod
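Taken together, the two hunks above leave the base constructor looking roughly like the sketch below. This is an approximation for orientation, not the file's exact contents: the booster assignment and the abstract _train stub are inferred from surrounding context, and only the plugin parameter and the self.plugin attribute are new in this commit.

from abc import ABC, abstractmethod

import torch.nn as nn
from torch.optim import Optimizer

from colossalai.booster import Booster
from colossalai.booster import Plugin


class SLTrainer(ABC):
    # Approximate end state of the base trainer's constructor after this commit.
    def __init__(
        self,
        booster: Booster,
        max_epochs: int,
        model: nn.Module,
        optimizer: Optimizer,
        plugin: Plugin,
        start_epoch: int = 0,
    ) -> None:
        super().__init__()
        self.booster = booster          # assignment inferred from the trainer's later use of self.booster
        self.max_epochs = max_epochs
        self.model = model
        self.optimizer = optimizer
        self.plugin = plugin            # new: lets subclasses branch on the plugin type
        self.start_epoch = start_epoch

    @abstractmethod
    def _train(self, epoch: int):
        """Implemented by concrete trainers such as SFTTrainer."""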
SFT trainer (class SFTTrainer):

@@ -6,14 +6,16 @@ import os
 from typing import Optional
 
 import torch
+import torch.distributed as dist
 from coati.trainer.utils import all_reduce_mean
 from coati.utils import AccumulativeMeanMeter, save_checkpoint
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler
 from torch.utils.data import DataLoader
-from tqdm import trange
+from tqdm import tqdm, trange
 
 from colossalai.booster import Booster
+from colossalai.booster.plugin import HybridParallelPlugin, Plugin
 from colossalai.cluster import DistCoordinator
 
 from .base import SLTrainer
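The three new imports line up with the training path added further down: torch.distributed supplies the rank checks that restrict logging to the last pipeline stage, tqdm drives the per-step progress bar used there, and HybridParallelPlugin / Plugin are needed for the isinstance check and the new type annotations.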
@@ -40,6 +42,7 @@ class SFTTrainer(SLTrainer):
         optim: Optimizer,
         lr_scheduler: _LRScheduler,
         max_epochs: int = 2,
+        plugin: Plugin = None,
         accumulation_steps: int = 8,
         apply_loss_mask: bool = True,
         start_epoch=0,
@@ -47,7 +50,7 @@ class SFTTrainer(SLTrainer):
         save_dir: str = None,
         coordinator: Optional[DistCoordinator] = None,
     ) -> None:
-        super().__init__(booster, max_epochs, model, optim, start_epoch=start_epoch)
+        super().__init__(booster, max_epochs, model, optim, plugin, start_epoch=start_epoch)
 
         self.accumulation_steps = accumulation_steps
         self.scheduler = lr_scheduler
@@ -94,6 +97,28 @@ class SFTTrainer(SLTrainer):
 
     def _train(self, epoch: int):
         self.model.train()
+        if isinstance(self.plugin, HybridParallelPlugin) and self.plugin.pp_size > 1:
+            data_iter = iter(self.train_dataloader)
+            step_bar = tqdm(
+                range(len(self.train_dataloader)),
+                desc="Step",
+                disable=not (dist.get_rank() == dist.get_world_size() - 1),
+            )
+            for step in step_bar:
+                outputs = self.booster.execute_pipeline(
+                    data_iter,
+                    self.model,
+                    criterion=lambda outputs, inputs: outputs[0],
+                    optimizer=self.optimizer,
+                    return_loss=True,
+                )
+                loss = outputs["loss"]
+                if dist.get_rank() == dist.get_world_size() - 1:
+                    step_bar.set_postfix({"train/loss": loss.item()})
+                    step_bar.update()
+                self.optimizer.step()
+                self.optimizer.zero_grad()
+        else:
             step_bar = trange(
                 len(self.train_dataloader) // self.accumulation_steps,
                 desc=f"Epoch {epoch + 1}/{self.max_epochs}",
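This branch is the core of the commit: when the model was boosted with a HybridParallelPlugin whose pp_size is greater than 1, the trainer hands the dataloader iterator to booster.execute_pipeline, which runs the pipeline schedule (microbatch splitting, forward, backward, and inter-stage communication) internally; the loss only materializes on the last pipeline stage, which is why the progress bar and the loss readout are gated on dist.get_rank() == dist.get_world_size() - 1. The sketch below shows the same call pattern outside the trainer. It is an illustration, not code from the commit: the tiny GPT-2 model, the random dataset, the flag values, and the launch command are all assumptions, and it needs a distributed launch with at least two devices.

# Standalone sketch of the pipeline-parallel step used by the new _train branch.
# Assumptions: >= 2 GPUs, launched with e.g. `colossalai run --nproc_per_node 2 pp_sketch.py`;
# the model, dataset, and hyperparameters below are placeholders.
import colossalai
import torch
import torch.distributed as dist
from torch.utils.data import TensorDataset
from transformers import GPT2Config, GPT2LMHeadModel

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin


def build_dummy_dataset(vocab_size: int, seq_len: int = 32, n: int = 64) -> TensorDataset:
    # Random token ids stand in for a tokenized SFT dataset.
    return TensorDataset(torch.randint(0, vocab_size, (n, seq_len)))


def collate(batch):
    input_ids = torch.stack([sample[0] for sample in batch])
    # Causal-LM style batch: the model computes the loss itself when labels are given.
    return {"input_ids": input_ids, "attention_mask": torch.ones_like(input_ids), "labels": input_ids}


def main():
    colossalai.launch_from_torch()  # older releases require launch_from_torch(config={})
    config = GPT2Config(n_layer=4, n_head=4, n_embd=128)
    model = GPT2LMHeadModel(config)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    # pp_size > 1 is exactly what triggers the new branch in SFTTrainer._train.
    plugin = HybridParallelPlugin(tp_size=1, pp_size=2, precision="bf16", microbatch_size=1)  # use "fp16" if bf16 is unsupported
    booster = Booster(plugin=plugin)

    dataloader = plugin.prepare_dataloader(
        build_dummy_dataset(config.vocab_size), batch_size=2, shuffle=True, collate_fn=collate
    )
    model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader)

    data_iter = iter(dataloader)
    for _ in range(len(dataloader)):
        # criterion mirrors the trainer: outputs[0] is the LM loss on the last stage.
        outputs = booster.execute_pipeline(
            data_iter,
            model,
            criterion=lambda outputs, inputs: outputs[0],
            optimizer=optimizer,
            return_loss=True,
        )
        if dist.get_rank() == dist.get_world_size() - 1:
            print("train/loss:", outputs["loss"].item())  # only the last pipeline stage holds the loss
        optimizer.step()
        optimizer.zero_grad()


if __name__ == "__main__":
    main()

Passing the raw iterator (rather than a single batch) matters because each execute_pipeline call pulls one batch from it and slices that batch into microbatches according to the plugin's microbatch_size.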
SFT training script (def train(args) and the argument parser):

@@ -114,7 +114,7 @@ def train(args):
             parallel_output=False,
             max_norm=args.grad_clip,
             precision=args.mixed_precision,
-            microbatch_size=args.batch_size,
+            microbatch_size=args.microbatch_size,
         )
     else:
         raise ValueError(f"Unknown plugin {args.plugin}")
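The one-word change above fixes how the hybrid plugin is configured: previously the script reused args.batch_size as the pipeline microbatch size, so every training step formed a single microbatch and the pipeline stages ran strictly one after another. Decoupling the two lets each step's batch be split into several smaller microbatches that overlap across stages. A rough sketch of the relationship, with a hypothetical num_microbatches helper that is not part of the codebase:

# Rough illustration, not code from the commit.
def num_microbatches(batch_size: int, microbatch_size: int) -> int:
    # Under a pipeline schedule, each per-step batch is cut into
    # batch_size // microbatch_size microbatches that are pushed through the
    # stages back to back; more, smaller microbatches mean less idle "bubble"
    # time per stage.
    assert batch_size % microbatch_size == 0, "microbatch_size must divide the per-step batch size"
    return batch_size // microbatch_size


# Old wiring (microbatch_size=args.batch_size): always exactly 1 microbatch.
# New wiring: e.g. batch_size=8 with --microbatch_size 1 yields 8 microbatches.
print(num_microbatches(8, 1))  # -> 8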
@@ -269,6 +269,7 @@ def train(args):
        model=model,
        booster=booster,
        optim=optim,
+        plugin=plugin,
        lr_scheduler=lr_scheduler,
        max_epochs=args.max_epochs,
        accumulation_steps=args.accumulation_steps,
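Forwarding plugin= into the trainer is what makes the new pipeline branch reachable at all; with any other plugin, or with the default plugin=None, SFTTrainer keeps its original gradient-accumulation loop. A hypothetical helper, not part of the commit, that restates the condition the trainer checks:

from typing import Optional

from colossalai.booster.plugin import HybridParallelPlugin, Plugin


def uses_pipeline(plugin: Optional[Plugin]) -> bool:
    # Only a HybridParallelPlugin with more than one pipeline stage takes the
    # execute_pipeline path; anything else falls back to the original loop.
    return isinstance(plugin, HybridParallelPlugin) and plugin.pp_size > 1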
@@ -344,6 +345,7 @@ if __name__ == "__main__":
     parser.add_argument("--use_wandb", default=False, action="store_true")
     parser.add_argument("--grad_checkpoint", default=False, action="store_true")
     parser.add_argument("--use_flash_attn", default=False, action="store_true")
+    parser.add_argument("--microbatch_size", type=int, default=1)
     args = parser.parse_args()
     if args.config_file is not None:
         os.makedirs(os.path.dirname(args.config_file), exist_ok=True)