From 4a541aa27c26edde9bc9ef3421e72a5ff6693f04 Mon Sep 17 00:00:00 2001 From: Tong Li Date: Mon, 12 Aug 2024 10:13:03 +0000 Subject: [PATCH 01/14] support pp training --- .../ColossalChat/coati/trainer/base.py | 3 + .../ColossalChat/coati/trainer/sft.py | 127 +++++++++++------- .../examples/training_scripts/train_sft.py | 4 +- 3 files changed, 82 insertions(+), 52 deletions(-) diff --git a/applications/ColossalChat/coati/trainer/base.py b/applications/ColossalChat/coati/trainer/base.py index 63c903a51..2e63fc5c8 100755 --- a/applications/ColossalChat/coati/trainer/base.py +++ b/applications/ColossalChat/coati/trainer/base.py @@ -17,6 +17,7 @@ from coati.experience_maker import Experience from torch.optim import Optimizer from colossalai.booster import Booster +from colossalai.booster import Plugin from .utils import is_rank_0 @@ -38,6 +39,7 @@ class SLTrainer(ABC): max_epochs: int, model: nn.Module, optimizer: Optimizer, + plugin: Plugin, start_epoch: int = 0, ) -> None: super().__init__() @@ -45,6 +47,7 @@ class SLTrainer(ABC): self.max_epochs = max_epochs self.model = model self.optimizer = optimizer + self.plugin = plugin self.start_epoch = start_epoch @abstractmethod diff --git a/applications/ColossalChat/coati/trainer/sft.py b/applications/ColossalChat/coati/trainer/sft.py index d37676ada..ebdfd5024 100755 --- a/applications/ColossalChat/coati/trainer/sft.py +++ b/applications/ColossalChat/coati/trainer/sft.py @@ -6,14 +6,16 @@ import os from typing import Optional import torch +import torch.distributed as dist from coati.trainer.utils import all_reduce_mean from coati.utils import AccumulativeMeanMeter, save_checkpoint from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler from torch.utils.data import DataLoader -from tqdm import trange +from tqdm import tqdm, trange from colossalai.booster import Booster +from colossalai.booster.plugin import HybridParallelPlugin, Plugin from colossalai.cluster import DistCoordinator from .base import SLTrainer @@ -40,6 +42,7 @@ class SFTTrainer(SLTrainer): optim: Optimizer, lr_scheduler: _LRScheduler, max_epochs: int = 2, + plugin: Plugin = None, accumulation_steps: int = 8, apply_loss_mask: bool = True, start_epoch=0, @@ -47,7 +50,7 @@ class SFTTrainer(SLTrainer): save_dir: str = None, coordinator: Optional[DistCoordinator] = None, ) -> None: - super().__init__(booster, max_epochs, model, optim, start_epoch=start_epoch) + super().__init__(booster, max_epochs, model, optim, plugin, start_epoch=start_epoch) self.accumulation_steps = accumulation_steps self.scheduler = lr_scheduler @@ -94,60 +97,82 @@ class SFTTrainer(SLTrainer): def _train(self, epoch: int): self.model.train() - step_bar = trange( - len(self.train_dataloader) // self.accumulation_steps, - desc=f"Epoch {epoch + 1}/{self.max_epochs}", - disable=not is_rank_0(), - ) - for i, batch in enumerate(self.train_dataloader): - batch = to_device(batch, torch.cuda.current_device()) - batch_size = batch["input_ids"].size(0) - outputs = self.model( - batch["input_ids"], - attention_mask=batch["attention_mask"], - labels=batch["labels"] if self.apply_loss_mask else batch["input_ids"], + if isinstance(self.plugin, HybridParallelPlugin) and self.plugin.pp_size > 1: + data_iter = iter(self.train_dataloader) + step_bar = tqdm( + range(len(self.train_dataloader)), + desc="Step", + disable=not (dist.get_rank() == dist.get_world_size() - 1), ) - loss = outputs.loss - - self.booster.backward(loss=loss, optimizer=self.optimizer) - - loss_mean = all_reduce_mean(tensor=loss) - 
self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item()) - - # Gradient accumulation - if (i + 1) % self.accumulation_steps == 0: + for step in step_bar: + outputs = self.booster.execute_pipeline( + data_iter, + self.model, + criterion=lambda outputs, inputs: outputs[0], + optimizer=self.optimizer, + return_loss=True, + ) + loss = outputs["loss"] + if dist.get_rank() == dist.get_world_size() - 1: + step_bar.set_postfix({"train/loss": loss.item()}) + step_bar.update() self.optimizer.step() self.optimizer.zero_grad() - self.scheduler.step() + else: + step_bar = trange( + len(self.train_dataloader) // self.accumulation_steps, + desc=f"Epoch {epoch + 1}/{self.max_epochs}", + disable=not is_rank_0(), + ) + for i, batch in enumerate(self.train_dataloader): + batch = to_device(batch, torch.cuda.current_device()) + batch_size = batch["input_ids"].size(0) + outputs = self.model( + batch["input_ids"], + attention_mask=batch["attention_mask"], + labels=batch["labels"] if self.apply_loss_mask else batch["input_ids"], + ) + loss = outputs.loss - step_bar.set_postfix({"train/loss": self.accumulative_meter.get("loss")}) - if self.writer: - self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step) - self.writer.add_scalar("train/lr", self.scheduler.get_last_lr()[0], self.num_train_step) - self.num_train_step += 1 - self.accumulative_meter.reset() - step_bar.update() + self.booster.backward(loss=loss, optimizer=self.optimizer) - # Save checkpoint - if ( - self.save_dir is not None - and self.save_interval is not None - and (self.num_train_step + 1) % self.save_interval == 0 - ): - save_checkpoint( - save_dir=self.save_dir, - booster=self.booster, - model=self.model, - optimizer=self.optimizer, - lr_scheduler=self.scheduler, - epoch=epoch, - step=self.num_train_step + 1, - batch_size=batch_size, - coordinator=self.coordinator, - ) - self.coordinator.print_on_master( - f"Saved checkpoint at epoch {epoch} step {self.num_train_step} at folder {self.save_dir}" - ) + loss_mean = all_reduce_mean(tensor=loss) + self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item()) + + # Gradient accumulation + if (i + 1) % self.accumulation_steps == 0: + self.optimizer.step() + self.optimizer.zero_grad() + self.scheduler.step() + + step_bar.set_postfix({"train/loss": self.accumulative_meter.get("loss")}) + if self.writer: + self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step) + self.writer.add_scalar("train/lr", self.scheduler.get_last_lr()[0], self.num_train_step) + self.num_train_step += 1 + self.accumulative_meter.reset() + step_bar.update() + + # Save checkpoint + if ( + self.save_dir is not None + and self.save_interval is not None + and (self.num_train_step + 1) % self.save_interval == 0 + ): + save_checkpoint( + save_dir=self.save_dir, + booster=self.booster, + model=self.model, + optimizer=self.optimizer, + lr_scheduler=self.scheduler, + epoch=epoch, + step=self.num_train_step + 1, + batch_size=batch_size, + coordinator=self.coordinator, + ) + self.coordinator.print_on_master( + f"Saved checkpoint at epoch {epoch} step {self.num_train_step} at folder {self.save_dir}" + ) step_bar.close() def _eval(self, epoch: int): diff --git a/applications/ColossalChat/examples/training_scripts/train_sft.py b/applications/ColossalChat/examples/training_scripts/train_sft.py index c4ef3b783..62acad32f 100755 --- a/applications/ColossalChat/examples/training_scripts/train_sft.py +++ 
b/applications/ColossalChat/examples/training_scripts/train_sft.py @@ -114,7 +114,7 @@ def train(args): parallel_output=False, max_norm=args.grad_clip, precision=args.mixed_precision, - microbatch_size=args.batch_size, + microbatch_size=args.microbatch_size, ) else: raise ValueError(f"Unknown plugin {args.plugin}") @@ -269,6 +269,7 @@ def train(args): model=model, booster=booster, optim=optim, + plugin=plugin, lr_scheduler=lr_scheduler, max_epochs=args.max_epochs, accumulation_steps=args.accumulation_steps, @@ -344,6 +345,7 @@ if __name__ == "__main__": parser.add_argument("--use_wandb", default=False, action="store_true") parser.add_argument("--grad_checkpoint", default=False, action="store_true") parser.add_argument("--use_flash_attn", default=False, action="store_true") + parser.add_argument("--microbatch_size", type=int, default=1) args = parser.parse_args() if args.config_file is not None: os.makedirs(os.path.dirname(args.config_file), exist_ok=True) From 515f8e4a438c2520bbdb89561bd502651fa75158 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 12 Aug 2024 10:15:34 +0000 Subject: [PATCH 02/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- applications/ColossalChat/coati/trainer/base.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/applications/ColossalChat/coati/trainer/base.py b/applications/ColossalChat/coati/trainer/base.py index 2e63fc5c8..bef4ccc3e 100755 --- a/applications/ColossalChat/coati/trainer/base.py +++ b/applications/ColossalChat/coati/trainer/base.py @@ -16,8 +16,7 @@ from coati.experience_buffer import NaiveExperienceBuffer from coati.experience_maker import Experience from torch.optim import Optimizer -from colossalai.booster import Booster -from colossalai.booster import Plugin +from colossalai.booster import Booster, Plugin from .utils import is_rank_0 From 123107ff288a5a9d95efd26e1f8968a7a6183009 Mon Sep 17 00:00:00 2001 From: Tong Li Date: Mon, 12 Aug 2024 11:27:42 +0000 Subject: [PATCH 03/14] update rm --- applications/ColossalChat/coati/trainer/rm.py | 5 +++-- .../ColossalChat/examples/training_scripts/train_rm.py | 1 + 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/applications/ColossalChat/coati/trainer/rm.py b/applications/ColossalChat/coati/trainer/rm.py index b9e84ef55..849a90a27 100755 --- a/applications/ColossalChat/coati/trainer/rm.py +++ b/applications/ColossalChat/coati/trainer/rm.py @@ -15,7 +15,7 @@ from torch.optim.lr_scheduler import _LRScheduler from torch.utils.data import DataLoader from transformers import PreTrainedTokenizerBase -from colossalai.booster import Booster +from colossalai.booster import Booster, Plugin from colossalai.cluster import DistCoordinator from colossalai.utils import get_current_device @@ -48,6 +48,7 @@ class RewardModelTrainer(SLTrainer): model: Any, booster: Booster, optimizer: Optimizer, + plugin: Plugin, lr_scheduler: _LRScheduler, tokenizer: PreTrainedTokenizerBase, loss_fn: Optional[Callable] = None, @@ -59,7 +60,7 @@ class RewardModelTrainer(SLTrainer): save_dir: str = None, coordinator: DistCoordinator = None, ) -> None: - super().__init__(booster, max_epochs=max_epochs, model=model, optimizer=optimizer, start_epoch=start_epoch) + super().__init__(booster, max_epochs=max_epochs, model=model, optimizer=optimizer, plugin=plugin, start_epoch=start_epoch) self.actor_scheduler = lr_scheduler self.tokenizer = tokenizer self.loss_fn = loss_fn if loss_fn is 
not None else LogSigLoss(beta=beta) diff --git a/applications/ColossalChat/examples/training_scripts/train_rm.py b/applications/ColossalChat/examples/training_scripts/train_rm.py index 4c0a782b4..5ea1a06ac 100755 --- a/applications/ColossalChat/examples/training_scripts/train_rm.py +++ b/applications/ColossalChat/examples/training_scripts/train_rm.py @@ -262,6 +262,7 @@ def train(args): model, booster, optim, + plugin, lr_scheduler, tokenizer, loss_fn=loss_fn, From 2c926141f335ccaef5d630287be50588122587e7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 12 Aug 2024 11:29:21 +0000 Subject: [PATCH 04/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- applications/ColossalChat/coati/trainer/rm.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/applications/ColossalChat/coati/trainer/rm.py b/applications/ColossalChat/coati/trainer/rm.py index 849a90a27..82e4625b9 100755 --- a/applications/ColossalChat/coati/trainer/rm.py +++ b/applications/ColossalChat/coati/trainer/rm.py @@ -48,7 +48,7 @@ class RewardModelTrainer(SLTrainer): model: Any, booster: Booster, optimizer: Optimizer, - plugin: Plugin, + plugin: Plugin, lr_scheduler: _LRScheduler, tokenizer: PreTrainedTokenizerBase, loss_fn: Optional[Callable] = None, @@ -60,7 +60,9 @@ class RewardModelTrainer(SLTrainer): save_dir: str = None, coordinator: DistCoordinator = None, ) -> None: - super().__init__(booster, max_epochs=max_epochs, model=model, optimizer=optimizer, plugin=plugin, start_epoch=start_epoch) + super().__init__( + booster, max_epochs=max_epochs, model=model, optimizer=optimizer, plugin=plugin, start_epoch=start_epoch + ) self.actor_scheduler = lr_scheduler self.tokenizer = tokenizer self.loss_fn = loss_fn if loss_fn is not None else LogSigLoss(beta=beta) From 7d9907f0aef9208a4e933acc041b1346e986574d Mon Sep 17 00:00:00 2001 From: Tong Li Date: Mon, 12 Aug 2024 11:35:14 +0000 Subject: [PATCH 05/14] refactor --- applications/ColossalChat/coati/trainer/dpo.py | 5 +++-- applications/ColossalChat/coati/trainer/kto.py | 5 +++-- applications/ColossalChat/coati/trainer/orpo.py | 5 +++-- .../ColossalChat/examples/training_scripts/train_dpo.py | 1 + .../ColossalChat/examples/training_scripts/train_kto.py | 1 + .../ColossalChat/examples/training_scripts/train_orpo.py | 1 + 6 files changed, 12 insertions(+), 6 deletions(-) diff --git a/applications/ColossalChat/coati/trainer/dpo.py b/applications/ColossalChat/coati/trainer/dpo.py index 24ddca654..063ea233e 100755 --- a/applications/ColossalChat/coati/trainer/dpo.py +++ b/applications/ColossalChat/coati/trainer/dpo.py @@ -16,7 +16,7 @@ from torch.utils.data import DataLoader from tqdm import trange from transformers import PreTrainedTokenizerBase -from colossalai.booster import Booster +from colossalai.booster import Booster, Plugin from colossalai.cluster import DistCoordinator from colossalai.utils import get_current_device @@ -50,6 +50,7 @@ class DPOTrainer(SLTrainer): ref_model: Any, booster: Booster, actor_optim: Optimizer, + plugin: Plugin, actor_lr_scheduler: _LRScheduler, tokenizer: PreTrainedTokenizerBase, max_epochs: int = 1, @@ -63,7 +64,7 @@ class DPOTrainer(SLTrainer): save_dir: str = None, coordinator: DistCoordinator = None, ) -> None: - super().__init__(booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, start_epoch=start_epoch) + super().__init__(booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, 
plugin=plugin, start_epoch=start_epoch) self.ref_model = ref_model self.actor_scheduler = actor_lr_scheduler self.tokenizer = tokenizer diff --git a/applications/ColossalChat/coati/trainer/kto.py b/applications/ColossalChat/coati/trainer/kto.py index 6462ba816..dd7dabfe6 100755 --- a/applications/ColossalChat/coati/trainer/kto.py +++ b/applications/ColossalChat/coati/trainer/kto.py @@ -17,7 +17,7 @@ from torch.utils.data import DataLoader from tqdm import trange from transformers import PreTrainedTokenizerBase -from colossalai.booster import Booster +from colossalai.booster import Booster, Plugin from colossalai.cluster import DistCoordinator from colossalai.utils import get_current_device @@ -53,6 +53,7 @@ class KTOTrainer(SLTrainer): ref_model: Any, booster: Booster, actor_optim: Optimizer, + plugin: Plugin, actor_lr_scheduler: _LRScheduler, tokenizer: PreTrainedTokenizerBase, max_epochs: int = 1, @@ -66,7 +67,7 @@ class KTOTrainer(SLTrainer): save_dir: str = None, coordinator: DistCoordinator = None, ) -> None: - super().__init__(booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, start_epoch=start_epoch) + super().__init__(booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, plugin=plugin, start_epoch=start_epoch) self.ref_model = ref_model self.actor_scheduler = actor_lr_scheduler self.tokenizer = tokenizer diff --git a/applications/ColossalChat/coati/trainer/orpo.py b/applications/ColossalChat/coati/trainer/orpo.py index c2f75771c..9a3adcd73 100644 --- a/applications/ColossalChat/coati/trainer/orpo.py +++ b/applications/ColossalChat/coati/trainer/orpo.py @@ -16,7 +16,7 @@ from torch.utils.data import DataLoader from tqdm import trange from transformers import PreTrainedTokenizerBase -from colossalai.booster import Booster +from colossalai.booster import Booster, Plugin from colossalai.cluster import DistCoordinator from colossalai.utils import get_current_device @@ -48,6 +48,7 @@ class ORPOTrainer(SLTrainer): actor: Any, booster: Booster, actor_optim: Optimizer, + plugin: Plugin, actor_lr_scheduler: _LRScheduler, tokenizer: PreTrainedTokenizerBase, max_epochs: int = 1, @@ -59,7 +60,7 @@ class ORPOTrainer(SLTrainer): save_dir: str = None, coordinator: DistCoordinator = None, ) -> None: - super().__init__(booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, start_epoch=start_epoch) + super().__init__(booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, plugin=plugin, start_epoch=start_epoch) self.actor_scheduler = actor_lr_scheduler self.tokenizer = tokenizer self.odds_ratio_loss_fn = OddsRatioLoss() diff --git a/applications/ColossalChat/examples/training_scripts/train_dpo.py b/applications/ColossalChat/examples/training_scripts/train_dpo.py index d88750aeb..3b324ee78 100755 --- a/applications/ColossalChat/examples/training_scripts/train_dpo.py +++ b/applications/ColossalChat/examples/training_scripts/train_dpo.py @@ -267,6 +267,7 @@ def train(args): ref_model=ref_model, booster=booster, actor_optim=optim, + plugin=plugin, actor_lr_scheduler=lr_scheduler, tokenizer=tokenizer, max_epochs=args.max_epochs, diff --git a/applications/ColossalChat/examples/training_scripts/train_kto.py b/applications/ColossalChat/examples/training_scripts/train_kto.py index 598fd8062..931c16577 100755 --- a/applications/ColossalChat/examples/training_scripts/train_kto.py +++ b/applications/ColossalChat/examples/training_scripts/train_kto.py @@ -286,6 +286,7 @@ def train(args): ref_model=ref_model, booster=booster, actor_optim=optim, + 
plugin=plugin, actor_lr_scheduler=lr_scheduler, tokenizer=tokenizer, max_epochs=args.max_epochs, diff --git a/applications/ColossalChat/examples/training_scripts/train_orpo.py b/applications/ColossalChat/examples/training_scripts/train_orpo.py index 87860f7ea..0f2fbfa2b 100755 --- a/applications/ColossalChat/examples/training_scripts/train_orpo.py +++ b/applications/ColossalChat/examples/training_scripts/train_orpo.py @@ -250,6 +250,7 @@ def train(args): actor=model, booster=booster, actor_optim=optim, + plugin=plugin, actor_lr_scheduler=lr_scheduler, tokenizer=tokenizer, max_epochs=args.max_epochs, From 49f7428cbf5232bc7c3e8cf7bf493adaf0084a25 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 12 Aug 2024 11:36:42 +0000 Subject: [PATCH 06/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- applications/ColossalChat/coati/trainer/dpo.py | 4 +++- applications/ColossalChat/coati/trainer/kto.py | 4 +++- applications/ColossalChat/coati/trainer/orpo.py | 4 +++- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/applications/ColossalChat/coati/trainer/dpo.py b/applications/ColossalChat/coati/trainer/dpo.py index 063ea233e..faa7a90d9 100755 --- a/applications/ColossalChat/coati/trainer/dpo.py +++ b/applications/ColossalChat/coati/trainer/dpo.py @@ -64,7 +64,9 @@ class DPOTrainer(SLTrainer): save_dir: str = None, coordinator: DistCoordinator = None, ) -> None: - super().__init__(booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, plugin=plugin, start_epoch=start_epoch) + super().__init__( + booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, plugin=plugin, start_epoch=start_epoch + ) self.ref_model = ref_model self.actor_scheduler = actor_lr_scheduler self.tokenizer = tokenizer diff --git a/applications/ColossalChat/coati/trainer/kto.py b/applications/ColossalChat/coati/trainer/kto.py index dd7dabfe6..f0b23afb6 100755 --- a/applications/ColossalChat/coati/trainer/kto.py +++ b/applications/ColossalChat/coati/trainer/kto.py @@ -67,7 +67,9 @@ class KTOTrainer(SLTrainer): save_dir: str = None, coordinator: DistCoordinator = None, ) -> None: - super().__init__(booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, plugin=plugin, start_epoch=start_epoch) + super().__init__( + booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, plugin=plugin, start_epoch=start_epoch + ) self.ref_model = ref_model self.actor_scheduler = actor_lr_scheduler self.tokenizer = tokenizer diff --git a/applications/ColossalChat/coati/trainer/orpo.py b/applications/ColossalChat/coati/trainer/orpo.py index 9a3adcd73..761fd305a 100644 --- a/applications/ColossalChat/coati/trainer/orpo.py +++ b/applications/ColossalChat/coati/trainer/orpo.py @@ -60,7 +60,9 @@ class ORPOTrainer(SLTrainer): save_dir: str = None, coordinator: DistCoordinator = None, ) -> None: - super().__init__(booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, plugin=plugin, start_epoch=start_epoch) + super().__init__( + booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, plugin=plugin, start_epoch=start_epoch + ) self.actor_scheduler = actor_lr_scheduler self.tokenizer = tokenizer self.odds_ratio_loss_fn = OddsRatioLoss() From a8356da3c7125fdda2d4f7c0a944063589a590a5 Mon Sep 17 00:00:00 2001 From: Tong Li Date: Tue, 13 Aug 2024 02:45:53 +0000 Subject: [PATCH 07/14] update test case --- applications/ColossalChat/tests/test_train.sh | 2 +- 1 
file changed, 1 insertion(+), 1 deletion(-) diff --git a/applications/ColossalChat/tests/test_train.sh b/applications/ColossalChat/tests/test_train.sh index c26b25c83..621f66449 100755 --- a/applications/ColossalChat/tests/test_train.sh +++ b/applications/ColossalChat/tests/test_train.sh @@ -30,7 +30,7 @@ MODEL_SAVE_PATH=$TEMP_DIR/rlhf_models MODELS_DIR=$TEMP_DIR/models_config # Skip those tests due to CI tests timeout MODELS=('llama') -ADVANCED_PLUGINS=('zero2' 'sp_split_gather' 'sp_ring' 'sp_all_to_all' 'tp_zero2' '3d' 'gemini' 'gemini_auto' 'zero2_cpu') # pp is still buggy +ADVANCED_PLUGINS=('zero2' 'sp_split_gather' 'sp_ring' 'sp_all_to_all' 'tp_zero2' '3d' 'gemini' 'gemini_auto' 'zero2_cpu', 'tp_pp', 'pp') PLUGINS=('zero2' '3d' 'gemini' 'gemini_auto' 'zero2_cpu') LORA_RANK=('0') # skip to reduce CI execution time, can pass all locally LORA_CONFIG_ENABLE="--lora_config $BASE_DIR/examples/training_scripts/lora_config.json" From 8ce504d05cc32e625e5112f83790fa558b5a4997 Mon Sep 17 00:00:00 2001 From: Tong Li Date: Tue, 13 Aug 2024 02:47:52 +0000 Subject: [PATCH 08/14] fix --- applications/ColossalChat/tests/test_train.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/applications/ColossalChat/tests/test_train.sh b/applications/ColossalChat/tests/test_train.sh index 621f66449..f81b31550 100755 --- a/applications/ColossalChat/tests/test_train.sh +++ b/applications/ColossalChat/tests/test_train.sh @@ -30,7 +30,7 @@ MODEL_SAVE_PATH=$TEMP_DIR/rlhf_models MODELS_DIR=$TEMP_DIR/models_config # Skip those tests due to CI tests timeout MODELS=('llama') -ADVANCED_PLUGINS=('zero2' 'sp_split_gather' 'sp_ring' 'sp_all_to_all' 'tp_zero2' '3d' 'gemini' 'gemini_auto' 'zero2_cpu', 'tp_pp', 'pp') +ADVANCED_PLUGINS=('zero2' 'sp_split_gather' 'sp_ring' 'sp_all_to_all' 'tp_zero2' '3d' 'gemini' 'gemini_auto' 'zero2_cpu' 'tp_pp' 'pp') PLUGINS=('zero2' '3d' 'gemini' 'gemini_auto' 'zero2_cpu') LORA_RANK=('0') # skip to reduce CI execution time, can pass all locally LORA_CONFIG_ENABLE="--lora_config $BASE_DIR/examples/training_scripts/lora_config.json" From 4a5bfc55a65e7a54341a4f7ceb32542190a4eeaf Mon Sep 17 00:00:00 2001 From: Tong Li Date: Tue, 13 Aug 2024 04:02:21 +0000 Subject: [PATCH 09/14] change to 4 --- applications/ColossalChat/tests/test_train.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/applications/ColossalChat/tests/test_train.sh b/applications/ColossalChat/tests/test_train.sh index fd8a5960b..8bc895c7f 100755 --- a/applications/ColossalChat/tests/test_train.sh +++ b/applications/ColossalChat/tests/test_train.sh @@ -15,7 +15,7 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() { echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" } -set_n_least_used_CUDA_VISIBLE_DEVICES 2 +set_n_least_used_CUDA_VISIBLE_DEVICES 4 set -xu @@ -175,7 +175,7 @@ for lora_rank in ${LORA_RANK[@]}; do for split in $(seq -f "%05g" 0 0); do dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_sft/arrow/part-$split") done - colossalai run --nproc_per_node 2 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_sft.py \ + colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_sft.py \ --pretrain $pretrain \ --tokenizer_dir $tokenizer_dir \ --dataset ${dataset[@]} \ From 0b2b454b97d55d1f974c28951fc5465b4ff24a8b Mon Sep 17 00:00:00 2001 From: Tong Li Date: Tue, 13 Aug 2024 06:48:54 +0000 Subject: [PATCH 10/14] fix eval --- .../ColossalChat/coati/trainer/sft.py | 82 +++++++++++++------ 1 file changed, 59 insertions(+), 23 deletions(-) diff --git 
a/applications/ColossalChat/coati/trainer/sft.py b/applications/ColossalChat/coati/trainer/sft.py index ebdfd5024..6322cb8df 100755 --- a/applications/ColossalChat/coati/trainer/sft.py +++ b/applications/ColossalChat/coati/trainer/sft.py @@ -182,27 +182,63 @@ class SFTTrainer(SLTrainer): self.accumulative_meter.reset() self.model.eval() with torch.no_grad(): - step_bar = trange( - len(self.eval_dataloader), - desc=f"Epoch {epoch + 1}/{self.max_epochs}", - disable=not is_rank_0(), - ) - for batch in self.eval_dataloader: - batch = to_device(batch, torch.cuda.current_device()) - outputs = self.model( - batch["input_ids"], - attention_mask=batch["attention_mask"], - labels=batch["labels"] if self.apply_loss_mask else batch["input_ids"], + if isinstance(self.plugin, HybridParallelPlugin) and self.plugin.pp_size > 1: + data_iter = iter(self.eval_dataloader) + step_bar = tqdm( + range(len(self.eval_dataloader)), + desc="Step", + disable=not (dist.get_rank() == dist.get_world_size() - 1), ) - loss_mean = all_reduce_mean(tensor=outputs.loss) - self.accumulative_meter.add("loss", loss_mean.item(), count_update=batch["input_ids"].size(0)) - step_bar.update() - loss_mean = self.accumulative_meter.get("loss") - msg = "Evaluation Result:\n" - for tag in ["loss"]: - msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n" - self.coordinator.print_on_master(msg) - os.makedirs(self.save_dir, exist_ok=True) - with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f: - f.write(msg) - step_bar.close() + for step in step_bar: + outputs = self.booster.execute_pipeline( + data_iter, + self.model, + criterion=lambda outputs, inputs: outputs[0], + optimizer=self.optimizer, + return_loss=True, + ) + loss = outputs["loss"] + if dist.get_rank() == dist.get_world_size() - 1: + step_bar.set_postfix({"eval/loss": loss.item()}) + self.accumulative_meter.add("loss", loss.item()) + step_bar.update() + + if dist.get_rank() == dist.get_world_size() - 1: + loss_mean = self.accumulative_meter.get("loss") + msg = "Evaluation Result:\n" + for tag in ["loss"]: + msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n" + print(msg) + if self.save_dir is not None: + os.makedirs(self.save_dir, exist_ok=True) + with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f: + f.write(msg) + step_bar.close() + + else: + step_bar = trange( + len(self.eval_dataloader), + desc=f"Epoch {epoch + 1}/{self.max_epochs}", + disable=not is_rank_0(), + ) + for batch in self.eval_dataloader: + batch = to_device(batch, torch.cuda.current_device()) + outputs = self.model( + batch["input_ids"], + attention_mask=batch["attention_mask"], + labels=batch["labels"] if self.apply_loss_mask else batch["input_ids"], + ) + loss_mean = all_reduce_mean(tensor=outputs.loss) + self.accumulative_meter.add("loss", loss_mean.item(), count_update=batch["input_ids"].size(0)) + step_bar.update() + + loss_mean = self.accumulative_meter.get("loss") + msg = "Evaluation Result:\n" + for tag in ["loss"]: + msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n" + self.coordinator.print_on_master(msg) + if self.save_dir is not None: + os.makedirs(self.save_dir, exist_ok=True) + with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f: + f.write(msg) + step_bar.close() From 74ee10e77dfa9cf242d2df5a321831927db679c8 Mon Sep 17 00:00:00 2001 From: Tong Li Date: Tue, 13 Aug 2024 07:32:25 +0000 Subject: [PATCH 11/14] test --- applications/ColossalChat/tests/test_train.sh | 2 +- 1 file changed, 1 
insertion(+), 1 deletion(-) diff --git a/applications/ColossalChat/tests/test_train.sh b/applications/ColossalChat/tests/test_train.sh index 8bc895c7f..8666b52a5 100755 --- a/applications/ColossalChat/tests/test_train.sh +++ b/applications/ColossalChat/tests/test_train.sh @@ -30,7 +30,7 @@ MODEL_SAVE_PATH=$TEMP_DIR/rlhf_models MODELS_DIR=$TEMP_DIR/models_config # Skip those tests due to CI tests timeout MODELS=('llama') -ADVANCED_PLUGINS=('zero2' 'sp_split_gather' 'sp_ring' 'sp_all_to_all' 'tp_zero2' '3d' 'gemini' 'gemini_auto' 'zero2_cpu' 'tp_pp' 'pp') +ADVANCED_PLUGINS=('zero2' 'sp_split_gather' 'sp_ring' 'sp_all_to_all' 'tp_zero2' '3d' 'gemini' 'gemini_auto' 'zero2_cpu') PLUGINS=('zero2' '3d' 'gemini' 'gemini_auto' 'zero2_cpu') LORA_RANK=('0') # skip to reduce CI execution time, can pass all locally LORA_CONFIG_ENABLE="--lora_config $BASE_DIR/examples/training_scripts/lora_config.json" From 22218d31e1f7093f4f117418dfa54d5c35db1790 Mon Sep 17 00:00:00 2001 From: Tong Li Date: Tue, 13 Aug 2024 07:53:10 +0000 Subject: [PATCH 12/14] add pp --- applications/ColossalChat/tests/test_train.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/applications/ColossalChat/tests/test_train.sh b/applications/ColossalChat/tests/test_train.sh index 8666b52a5..7b3b4ab4f 100755 --- a/applications/ColossalChat/tests/test_train.sh +++ b/applications/ColossalChat/tests/test_train.sh @@ -30,7 +30,7 @@ MODEL_SAVE_PATH=$TEMP_DIR/rlhf_models MODELS_DIR=$TEMP_DIR/models_config # Skip those tests due to CI tests timeout MODELS=('llama') -ADVANCED_PLUGINS=('zero2' 'sp_split_gather' 'sp_ring' 'sp_all_to_all' 'tp_zero2' '3d' 'gemini' 'gemini_auto' 'zero2_cpu') +ADVANCED_PLUGINS=('zero2' 'sp_split_gather' 'sp_ring' 'sp_all_to_all' 'tp_zero2' '3d' 'gemini' 'gemini_auto' 'zero2_cpu' 'pp') PLUGINS=('zero2' '3d' 'gemini' 'gemini_auto' 'zero2_cpu') LORA_RANK=('0') # skip to reduce CI execution time, can pass all locally LORA_CONFIG_ENABLE="--lora_config $BASE_DIR/examples/training_scripts/lora_config.json" From 2422341d0360900062de317bd31ced22e5bb6b07 Mon Sep 17 00:00:00 2001 From: Tong Li Date: Tue, 13 Aug 2024 09:35:03 +0000 Subject: [PATCH 13/14] hotfix --- applications/ColossalChat/tests/test_train.sh | 2 +- colossalai/booster/plugin/hybrid_parallel_plugin.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/applications/ColossalChat/tests/test_train.sh b/applications/ColossalChat/tests/test_train.sh index 7b3b4ab4f..3b06495cb 100755 --- a/applications/ColossalChat/tests/test_train.sh +++ b/applications/ColossalChat/tests/test_train.sh @@ -30,7 +30,7 @@ MODEL_SAVE_PATH=$TEMP_DIR/rlhf_models MODELS_DIR=$TEMP_DIR/models_config # Skip those tests due to CI tests timeout MODELS=('llama') -ADVANCED_PLUGINS=('zero2' 'sp_split_gather' 'sp_ring' 'sp_all_to_all' 'tp_zero2' '3d' 'gemini' 'gemini_auto' 'zero2_cpu' 'pp') +ADVANCED_PLUGINS=('zero2' 'sp_split_gather' 'sp_ring' 'sp_all_to_all' 'tp_zero2' '3d' 'gemini' 'gemini_auto' 'zero2_cpu' 'pp' 'tp_pp') PLUGINS=('zero2' '3d' 'gemini' 'gemini_auto' 'zero2_cpu') LORA_RANK=('0') # skip to reduce CI execution time, can pass all locally LORA_CONFIG_ENABLE="--lora_config $BASE_DIR/examples/training_scripts/lora_config.json" diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index d2933a4af..faf1f0218 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -1328,7 +1328,7 @@ class 
HybridParallelPlugin(PipelinePluginBase): # run with gradients accumulation if model.require_grad_sync == False or ( isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False - ): + ) or not torch.is_grad_enabled(): return outputs # Synchronize the grads of shared parameters of the model. From 2789c9ee6d4e0c3067f42988ce2b595e797876a6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 13 Aug 2024 09:36:22 +0000 Subject: [PATCH 14/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- colossalai/booster/plugin/hybrid_parallel_plugin.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index faf1f0218..e5acdb051 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -1326,9 +1326,11 @@ class HybridParallelPlugin(PipelinePluginBase): ) # run with gradients accumulation - if model.require_grad_sync == False or ( - isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False - ) or not torch.is_grad_enabled(): + if ( + model.require_grad_sync == False + or (isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False) + or not torch.is_grad_enabled() + ): return outputs # Synchronize the grads of shared parameters of the model.
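
To make the moving parts of this series easier to follow, below is a minimal sketch of how the pipeline-parallel SFT path introduced above is expected to be driven. It is an illustration, not the repository's training script: the checkpoint path, the toy dataset, the optimizer choice, and the hyperparameters are assumptions. Only the `HybridParallelPlugin(pp_size > 1, microbatch_size=...)` configuration, the `booster.execute_pipeline(...)` call with `criterion=lambda outputs, inputs: outputs[0]`, and the last-stage handling of the loss are taken directly from the PATCH 01 diff of `sft.py` and `train_sft.py`.

```python
import torch
import torch.distributed as dist
from torch.optim import AdamW
from transformers import AutoModelForCausalLM

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# Launch one process per GPU, e.g. `colossalai run --nproc_per_node 2 this_script.py`.
colossalai.launch_from_torch()

plugin = HybridParallelPlugin(
    tp_size=1,
    pp_size=2,          # pp_size > 1 is what selects the pipeline branch in SFTTrainer._train/_eval
    microbatch_size=1,  # surfaced to train_sft.py as --microbatch_size in PATCH 01
    precision="bf16",
)
booster = Booster(plugin=plugin)

# Placeholders: the checkpoint path and the toy dataset below are assumptions;
# train_sft.py builds the real model and dataloader from --pretrain and the
# tokenized SFT dataset before boosting.
model = AutoModelForCausalLM.from_pretrained("PATH/TO/LLAMA")
optimizer = AdamW(model.parameters(), lr=2e-5)


class ToySFTDataset(torch.utils.data.Dataset):
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        ids = torch.randint(0, 1000, (128,))
        return {"input_ids": ids, "attention_mask": torch.ones_like(ids), "labels": ids.clone()}


train_dataloader = plugin.prepare_dataloader(ToySFTDataset(), batch_size=2, shuffle=True)
model, optimizer, *_ = booster.boost(model, optimizer)

# One optimizer step of the pipeline branch, mirroring SFTTrainer._train:
data_iter = iter(train_dataloader)
outputs = booster.execute_pipeline(
    data_iter,
    model,
    criterion=lambda outputs, inputs: outputs[0],  # the model's loss is its first output
    optimizer=optimizer,
    return_loss=True,
)
if dist.get_rank() == dist.get_world_size() - 1:  # the loss only materializes on the last stage
    print("train/loss:", outputs["loss"].item())
optimizer.step()
optimizer.zero_grad()
```

On the script side, the `3d` plugin branch of `train_sft.py` now reads `--microbatch_size` instead of reusing `--batch_size` (PATCH 01), and the other trainers receive the same `plugin` argument so they can detect pipeline parallelism in the same way (PATCHes 03 and 05). The `not torch.is_grad_enabled()` guard added to `HybridParallelPlugin.execute_pipeline` in PATCH 13 skips gradient synchronization when autograd is disabled, which is what lets `_eval` (PATCH 10) reuse the same pipeline call under `torch.no_grad()`. The CI matrix exercises the new path through the `pp` and `tp_pp` entries of `ADVANCED_PLUGINS` and the bump to 4 GPUs (PATCHes 07–13).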