update rm

This commit is contained in:
Tong Li 2024-08-12 11:27:42 +00:00
parent 38c84a1aa0
commit 5a24b0dc31
2 changed files with 4 additions and 2 deletions

View File

@@ -15,7 +15,7 @@ from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
from colossalai.booster import Booster from colossalai.booster import Booster, Plugin
from colossalai.cluster import DistCoordinator from colossalai.cluster import DistCoordinator
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
@@ -48,6 +48,7 @@ class RewardModelTrainer(SLTrainer):
model: Any, model: Any,
booster: Booster, booster: Booster,
optimizer: Optimizer, optimizer: Optimizer,
plugin: Plugin,
lr_scheduler: _LRScheduler, lr_scheduler: _LRScheduler,
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
loss_fn: Optional[Callable] = None, loss_fn: Optional[Callable] = None,
@@ -59,7 +60,7 @@ class RewardModelTrainer(SLTrainer):
save_dir: str = None, save_dir: str = None,
coordinator: DistCoordinator = None, coordinator: DistCoordinator = None,
) -> None: ) -> None:
super().__init__(booster, max_epochs=max_epochs, model=model, optimizer=optimizer, start_epoch=start_epoch) super().__init__(booster, max_epochs=max_epochs, model=model, optimizer=optimizer, plugin=plugin, start_epoch=start_epoch)
self.actor_scheduler = lr_scheduler self.actor_scheduler = lr_scheduler
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.loss_fn = loss_fn if loss_fn is not None else LogSigLoss(beta=beta) self.loss_fn = loss_fn if loss_fn is not None else LogSigLoss(beta=beta)

View File

@@ -262,6 +262,7 @@ def train(args):
model, model,
booster, booster,
optim, optim,
plugin,
lr_scheduler, lr_scheduler,
tokenizer, tokenizer,
loss_fn=loss_fn, loss_fn=loss_fn,