diff --git a/applications/ColossalChat/coati/trainer/dpo.py b/applications/ColossalChat/coati/trainer/dpo.py index 24ddca654..063ea233e 100755 --- a/applications/ColossalChat/coati/trainer/dpo.py +++ b/applications/ColossalChat/coati/trainer/dpo.py @@ -16,7 +16,7 @@ from torch.utils.data import DataLoader from tqdm import trange from transformers import PreTrainedTokenizerBase -from colossalai.booster import Booster +from colossalai.booster import Booster, Plugin from colossalai.cluster import DistCoordinator from colossalai.utils import get_current_device @@ -50,6 +50,7 @@ class DPOTrainer(SLTrainer): ref_model: Any, booster: Booster, actor_optim: Optimizer, + plugin: Plugin, actor_lr_scheduler: _LRScheduler, tokenizer: PreTrainedTokenizerBase, max_epochs: int = 1, @@ -63,7 +64,7 @@ class DPOTrainer(SLTrainer): save_dir: str = None, coordinator: DistCoordinator = None, ) -> None: - super().__init__(booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, start_epoch=start_epoch) + super().__init__(booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, plugin=plugin, start_epoch=start_epoch) self.ref_model = ref_model self.actor_scheduler = actor_lr_scheduler self.tokenizer = tokenizer diff --git a/applications/ColossalChat/coati/trainer/kto.py b/applications/ColossalChat/coati/trainer/kto.py index 6462ba816..dd7dabfe6 100755 --- a/applications/ColossalChat/coati/trainer/kto.py +++ b/applications/ColossalChat/coati/trainer/kto.py @@ -17,7 +17,7 @@ from torch.utils.data import DataLoader from tqdm import trange from transformers import PreTrainedTokenizerBase -from colossalai.booster import Booster +from colossalai.booster import Booster, Plugin from colossalai.cluster import DistCoordinator from colossalai.utils import get_current_device @@ -53,6 +53,7 @@ class KTOTrainer(SLTrainer): ref_model: Any, booster: Booster, actor_optim: Optimizer, + plugin: Plugin, actor_lr_scheduler: _LRScheduler, tokenizer: PreTrainedTokenizerBase, max_epochs: int = 1, @@ -66,7 +67,7 @@ class KTOTrainer(SLTrainer): save_dir: str = None, coordinator: DistCoordinator = None, ) -> None: - super().__init__(booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, start_epoch=start_epoch) + super().__init__(booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, plugin=plugin, start_epoch=start_epoch) self.ref_model = ref_model self.actor_scheduler = actor_lr_scheduler self.tokenizer = tokenizer diff --git a/applications/ColossalChat/coati/trainer/orpo.py b/applications/ColossalChat/coati/trainer/orpo.py index c2f75771c..9a3adcd73 100644 --- a/applications/ColossalChat/coati/trainer/orpo.py +++ b/applications/ColossalChat/coati/trainer/orpo.py @@ -16,7 +16,7 @@ from torch.utils.data import DataLoader from tqdm import trange from transformers import PreTrainedTokenizerBase -from colossalai.booster import Booster +from colossalai.booster import Booster, Plugin from colossalai.cluster import DistCoordinator from colossalai.utils import get_current_device @@ -48,6 +48,7 @@ class ORPOTrainer(SLTrainer): actor: Any, booster: Booster, actor_optim: Optimizer, + plugin: Plugin, actor_lr_scheduler: _LRScheduler, tokenizer: PreTrainedTokenizerBase, max_epochs: int = 1, @@ -59,7 +60,7 @@ class ORPOTrainer(SLTrainer): save_dir: str = None, coordinator: DistCoordinator = None, ) -> None: - super().__init__(booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, start_epoch=start_epoch) + super().__init__(booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, plugin=plugin, start_epoch=start_epoch) self.actor_scheduler = actor_lr_scheduler self.tokenizer = tokenizer self.odds_ratio_loss_fn = OddsRatioLoss() diff --git a/applications/ColossalChat/examples/training_scripts/train_dpo.py b/applications/ColossalChat/examples/training_scripts/train_dpo.py index d88750aeb..3b324ee78 100755 --- a/applications/ColossalChat/examples/training_scripts/train_dpo.py +++ b/applications/ColossalChat/examples/training_scripts/train_dpo.py @@ -267,6 +267,7 @@ def train(args): ref_model=ref_model, booster=booster, actor_optim=optim, + plugin=plugin, actor_lr_scheduler=lr_scheduler, tokenizer=tokenizer, max_epochs=args.max_epochs, diff --git a/applications/ColossalChat/examples/training_scripts/train_kto.py b/applications/ColossalChat/examples/training_scripts/train_kto.py index 598fd8062..931c16577 100755 --- a/applications/ColossalChat/examples/training_scripts/train_kto.py +++ b/applications/ColossalChat/examples/training_scripts/train_kto.py @@ -286,6 +286,7 @@ def train(args): ref_model=ref_model, booster=booster, actor_optim=optim, + plugin=plugin, actor_lr_scheduler=lr_scheduler, tokenizer=tokenizer, max_epochs=args.max_epochs, diff --git a/applications/ColossalChat/examples/training_scripts/train_orpo.py b/applications/ColossalChat/examples/training_scripts/train_orpo.py index 87860f7ea..0f2fbfa2b 100755 --- a/applications/ColossalChat/examples/training_scripts/train_orpo.py +++ b/applications/ColossalChat/examples/training_scripts/train_orpo.py @@ -250,6 +250,7 @@ def train(args): actor=model, booster=booster, actor_optim=optim, + plugin=plugin, actor_lr_scheduler=lr_scheduler, tokenizer=tokenizer, max_epochs=args.max_epochs,