mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-23 10:11:37 +00:00
update rm
This commit is contained in:
parent
38c84a1aa0
commit
5a24b0dc31
@ -15,7 +15,7 @@ from torch.optim.lr_scheduler import _LRScheduler
|
|||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from transformers import PreTrainedTokenizerBase
|
from transformers import PreTrainedTokenizerBase
|
||||||
|
|
||||||
from colossalai.booster import Booster
|
from colossalai.booster import Booster, Plugin
|
||||||
from colossalai.cluster import DistCoordinator
|
from colossalai.cluster import DistCoordinator
|
||||||
from colossalai.utils import get_current_device
|
from colossalai.utils import get_current_device
|
||||||
|
|
||||||
@ -48,6 +48,7 @@ class RewardModelTrainer(SLTrainer):
|
|||||||
model: Any,
|
model: Any,
|
||||||
booster: Booster,
|
booster: Booster,
|
||||||
optimizer: Optimizer,
|
optimizer: Optimizer,
|
||||||
|
plugin: Plugin,
|
||||||
lr_scheduler: _LRScheduler,
|
lr_scheduler: _LRScheduler,
|
||||||
tokenizer: PreTrainedTokenizerBase,
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
loss_fn: Optional[Callable] = None,
|
loss_fn: Optional[Callable] = None,
|
||||||
@ -59,7 +60,7 @@ class RewardModelTrainer(SLTrainer):
|
|||||||
save_dir: str = None,
|
save_dir: str = None,
|
||||||
coordinator: DistCoordinator = None,
|
coordinator: DistCoordinator = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(booster, max_epochs=max_epochs, model=model, optimizer=optimizer, start_epoch=start_epoch)
|
super().__init__(booster, max_epochs=max_epochs, model=model, optimizer=optimizer, plugin=plugin, start_epoch=start_epoch)
|
||||||
self.actor_scheduler = lr_scheduler
|
self.actor_scheduler = lr_scheduler
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
self.loss_fn = loss_fn if loss_fn is not None else LogSigLoss(beta=beta)
|
self.loss_fn = loss_fn if loss_fn is not None else LogSigLoss(beta=beta)
|
||||||
|
@ -262,6 +262,7 @@ def train(args):
|
|||||||
model,
|
model,
|
||||||
booster,
|
booster,
|
||||||
optim,
|
optim,
|
||||||
|
plugin,
|
||||||
lr_scheduler,
|
lr_scheduler,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
loss_fn=loss_fn,
|
loss_fn=loss_fn,
|
||||||
|
Loading…
Reference in New Issue
Block a user