Tong Li 2024-08-12 11:35:14 +00:00
parent 56fd2dc5d2
commit 7d9907f0ae
6 changed files with 12 additions and 6 deletions

View File

@@ -16,7 +16,7 @@ from torch.utils.data import DataLoader
 from tqdm import trange
 from transformers import PreTrainedTokenizerBase
 
-from colossalai.booster import Booster
+from colossalai.booster import Booster, Plugin
 from colossalai.cluster import DistCoordinator
 from colossalai.utils import get_current_device
@@ -50,6 +50,7 @@ class DPOTrainer(SLTrainer):
         ref_model: Any,
         booster: Booster,
         actor_optim: Optimizer,
+        plugin: Plugin,
         actor_lr_scheduler: _LRScheduler,
         tokenizer: PreTrainedTokenizerBase,
         max_epochs: int = 1,
@@ -63,7 +64,7 @@ class DPOTrainer(SLTrainer):
         save_dir: str = None,
         coordinator: DistCoordinator = None,
     ) -> None:
-        super().__init__(booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, start_epoch=start_epoch)
+        super().__init__(booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, plugin=plugin, start_epoch=start_epoch)
         self.ref_model = ref_model
         self.actor_scheduler = actor_lr_scheduler
         self.tokenizer = tokenizer

View File

@@ -17,7 +17,7 @@ from torch.utils.data import DataLoader
 from tqdm import trange
 from transformers import PreTrainedTokenizerBase
 
-from colossalai.booster import Booster
+from colossalai.booster import Booster, Plugin
 from colossalai.cluster import DistCoordinator
 from colossalai.utils import get_current_device
@@ -53,6 +53,7 @@ class KTOTrainer(SLTrainer):
         ref_model: Any,
         booster: Booster,
         actor_optim: Optimizer,
+        plugin: Plugin,
         actor_lr_scheduler: _LRScheduler,
         tokenizer: PreTrainedTokenizerBase,
         max_epochs: int = 1,
@@ -66,7 +67,7 @@ class KTOTrainer(SLTrainer):
         save_dir: str = None,
         coordinator: DistCoordinator = None,
     ) -> None:
-        super().__init__(booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, start_epoch=start_epoch)
+        super().__init__(booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, plugin=plugin, start_epoch=start_epoch)
         self.ref_model = ref_model
         self.actor_scheduler = actor_lr_scheduler
         self.tokenizer = tokenizer

View File

@@ -16,7 +16,7 @@ from torch.utils.data import DataLoader
 from tqdm import trange
 from transformers import PreTrainedTokenizerBase
 
-from colossalai.booster import Booster
+from colossalai.booster import Booster, Plugin
 from colossalai.cluster import DistCoordinator
 from colossalai.utils import get_current_device
@@ -48,6 +48,7 @@ class ORPOTrainer(SLTrainer):
         actor: Any,
         booster: Booster,
         actor_optim: Optimizer,
+        plugin: Plugin,
         actor_lr_scheduler: _LRScheduler,
         tokenizer: PreTrainedTokenizerBase,
         max_epochs: int = 1,
@@ -59,7 +60,7 @@ class ORPOTrainer(SLTrainer):
         save_dir: str = None,
         coordinator: DistCoordinator = None,
     ) -> None:
-        super().__init__(booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, start_epoch=start_epoch)
+        super().__init__(booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, plugin=plugin, start_epoch=start_epoch)
         self.actor_scheduler = actor_lr_scheduler
         self.tokenizer = tokenizer
         self.odds_ratio_loss_fn = OddsRatioLoss()
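Across the three trainer files above the change is identical: a `plugin: Plugin` parameter is inserted between `actor_optim` and `actor_lr_scheduler` and forwarded to the `SLTrainer` base class. The base class itself is not part of this commit, so the following is only a minimal sketch of the receiving side it implies; every field below is inferred from the `super().__init__` calls rather than taken from the source:

```python
# Hypothetical sketch of SLTrainer's signature as implied by the
# super().__init__ calls in this commit; the real class is not shown here.
from typing import Any, Optional

from torch.optim import Optimizer

from colossalai.booster import Booster, Plugin


class SLTrainer:
    def __init__(
        self,
        booster: Booster,
        max_epochs: int,
        model: Any,
        optimizer: Optimizer,
        plugin: Optional[Plugin] = None,  # the parameter this commit threads through
        start_epoch: int = 0,
    ) -> None:
        self.booster = booster
        self.max_epochs = max_epochs
        self.model = model
        self.optimizer = optimizer
        # Keeping a handle on the plugin lets subclasses branch on the
        # parallelism strategy (assumed motivation, not stated in the diff).
        self.plugin = plugin
        self.start_epoch = start_epoch
```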

View File

@@ -267,6 +267,7 @@ def train(args):
         ref_model=ref_model,
         booster=booster,
         actor_optim=optim,
+        plugin=plugin,
         actor_lr_scheduler=lr_scheduler,
         tokenizer=tokenizer,
         max_epochs=args.max_epochs,

View File

@@ -286,6 +286,7 @@ def train(args):
         ref_model=ref_model,
         booster=booster,
         actor_optim=optim,
+        plugin=plugin,
         actor_lr_scheduler=lr_scheduler,
         tokenizer=tokenizer,
         max_epochs=args.max_epochs,

View File

@@ -250,6 +250,7 @@ def train(args):
         actor=model,
         booster=booster,
         actor_optim=optim,
+        plugin=plugin,
         actor_lr_scheduler=lr_scheduler,
         tokenizer=tokenizer,
         max_epochs=args.max_epochs,
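
On the script side, the `plugin` handed to each trainer is the same object the script already uses to construct its `Booster`. A rough sketch of that wiring for the last call site above; the plugin choice, placeholder model, and `ORPOTrainer` import path are illustrative assumptions, not taken from this diff (the real scripts presumably select the plugin from `args` and run under a distributed launcher):

```python
# Illustrative wiring for the call-site hunks above; plugin choice, model,
# and import path are assumptions. Some plugins (e.g. HybridParallelPlugin)
# additionally require torch.distributed to be initialized first.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

from coati.trainer import ORPOTrainer  # assumed import path within ColossalChat

plugin = TorchDDPPlugin()  # assumed plugin choice for illustration
booster = Booster(plugin=plugin)

model = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder model
tokenizer = AutoTokenizer.from_pretrained("gpt2")
optim = torch.optim.AdamW(model.parameters(), lr=1e-5)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=1000)

trainer = ORPOTrainer(
    actor=model,
    booster=booster,
    actor_optim=optim,
    plugin=plugin,  # the same plugin used to build the booster
    actor_lr_scheduler=lr_scheduler,
    tokenizer=tokenizer,
    max_epochs=1,
)
```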