Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-08-22 17:58:56 +00:00)

Commit 7d9907f0ae: refactor
Parent commit: 56fd2dc5d2
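In short: DPOTrainer, KTOTrainer, and ORPOTrainer each gain a plugin: Plugin constructor parameter and forward it to SLTrainer through super().__init__, and the three training scripts pass plugin=plugin at the corresponding call sites. The hunks below show the three trainer classes first, then the three call sites.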
@@ -16,7 +16,7 @@ from torch.utils.data import DataLoader
 from tqdm import trange
 from transformers import PreTrainedTokenizerBase
 
-from colossalai.booster import Booster
+from colossalai.booster import Booster, Plugin
 from colossalai.cluster import DistCoordinator
 from colossalai.utils import get_current_device
 
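For context, the Plugin imported here is the booster plugin base class. A minimal sketch of the surrounding setup, assuming the usual ColossalAI pattern; the concrete LowLevelZeroPlugin and its arguments are illustrative, not part of this commit:

    from colossalai.booster import Booster, Plugin
    from colossalai.booster.plugin import LowLevelZeroPlugin

    # Illustrative setup (assumed): a concrete plugin configures distributed
    # training, and the Booster is built around it.
    plugin: Plugin = LowLevelZeroPlugin(stage=2, precision="bf16")
    booster = Booster(plugin=plugin)
    # After this commit, the same `plugin` object is also passed to the trainer.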
@@ -50,6 +50,7 @@ class DPOTrainer(SLTrainer):
         ref_model: Any,
         booster: Booster,
         actor_optim: Optimizer,
+        plugin: Plugin,
         actor_lr_scheduler: _LRScheduler,
         tokenizer: PreTrainedTokenizerBase,
         max_epochs: int = 1,

@@ -63,7 +64,7 @@ class DPOTrainer(SLTrainer):
         save_dir: str = None,
         coordinator: DistCoordinator = None,
     ) -> None:
-        super().__init__(booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, start_epoch=start_epoch)
+        super().__init__(booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, plugin=plugin, start_epoch=start_epoch)
         self.ref_model = ref_model
         self.actor_scheduler = actor_lr_scheduler
         self.tokenizer = tokenizer

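The new super().__init__ call implies that SLTrainer now accepts a plugin keyword. SLTrainer itself is not part of this diff, so the following is only a sketch of the implied base-class side, with an assumed body:

    # Assumed shape of the base class after this refactor (not shown in the diff):
    class SLTrainer:
        def __init__(self, booster, max_epochs, model, optimizer, plugin=None, start_epoch=0):
            self.booster = booster
            self.max_epochs = max_epochs
            self.model = model
            self.optimizer = optimizer
            self.plugin = plugin  # stored so subclasses can inspect parallelism settings
            self.start_epoch = start_epoch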
@@ -17,7 +17,7 @@ from torch.utils.data import DataLoader
 from tqdm import trange
 from transformers import PreTrainedTokenizerBase
 
-from colossalai.booster import Booster
+from colossalai.booster import Booster, Plugin
 from colossalai.cluster import DistCoordinator
 from colossalai.utils import get_current_device
 
@@ -53,6 +53,7 @@ class KTOTrainer(SLTrainer):
         ref_model: Any,
         booster: Booster,
         actor_optim: Optimizer,
+        plugin: Plugin,
         actor_lr_scheduler: _LRScheduler,
         tokenizer: PreTrainedTokenizerBase,
         max_epochs: int = 1,

@@ -66,7 +67,7 @@ class KTOTrainer(SLTrainer):
         save_dir: str = None,
         coordinator: DistCoordinator = None,
     ) -> None:
-        super().__init__(booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, start_epoch=start_epoch)
+        super().__init__(booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, plugin=plugin, start_epoch=start_epoch)
         self.ref_model = ref_model
         self.actor_scheduler = actor_lr_scheduler
         self.tokenizer = tokenizer

@@ -16,7 +16,7 @@ from torch.utils.data import DataLoader
 from tqdm import trange
 from transformers import PreTrainedTokenizerBase
 
-from colossalai.booster import Booster
+from colossalai.booster import Booster, Plugin
 from colossalai.cluster import DistCoordinator
 from colossalai.utils import get_current_device
 
@@ -48,6 +48,7 @@ class ORPOTrainer(SLTrainer):
         actor: Any,
         booster: Booster,
         actor_optim: Optimizer,
+        plugin: Plugin,
         actor_lr_scheduler: _LRScheduler,
         tokenizer: PreTrainedTokenizerBase,
         max_epochs: int = 1,

@@ -59,7 +60,7 @@ class ORPOTrainer(SLTrainer):
         save_dir: str = None,
         coordinator: DistCoordinator = None,
     ) -> None:
-        super().__init__(booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, start_epoch=start_epoch)
+        super().__init__(booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, plugin=plugin, start_epoch=start_epoch)
         self.actor_scheduler = actor_lr_scheduler
         self.tokenizer = tokenizer
         self.odds_ratio_loss_fn = OddsRatioLoss()

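As context for the ORPO hunk above: OddsRatioLoss predates this commit and is untouched here. Assuming it follows the standard ORPO formulation, the term it computes can be sketched as follows; the function name and signature are hypothetical:

    import torch
    import torch.nn.functional as F

    # Sketch of the ORPO odds-ratio term (assumed standard formulation; the
    # actual OddsRatioLoss implementation is not part of this diff).
    def odds_ratio_loss(chosen_logps: torch.Tensor, rejected_logps: torch.Tensor) -> torch.Tensor:
        # log odds(y|x) = log p - log(1 - p), computed stably from log-probs
        log_odds_chosen = chosen_logps - torch.log1p(-torch.exp(chosen_logps))
        log_odds_rejected = rejected_logps - torch.log1p(-torch.exp(rejected_logps))
        return -F.logsigmoid(log_odds_chosen - log_odds_rejected).mean()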
@@ -267,6 +267,7 @@ def train(args):
         ref_model=ref_model,
         booster=booster,
         actor_optim=optim,
+        plugin=plugin,
         actor_lr_scheduler=lr_scheduler,
         tokenizer=tokenizer,
         max_epochs=args.max_epochs,

@@ -286,6 +286,7 @@ def train(args):
         ref_model=ref_model,
         booster=booster,
         actor_optim=optim,
+        plugin=plugin,
         actor_lr_scheduler=lr_scheduler,
         tokenizer=tokenizer,
         max_epochs=args.max_epochs,

@@ -250,6 +250,7 @@ def train(args):
         actor=model,
         booster=booster,
         actor_optim=optim,
+        plugin=plugin,
         actor_lr_scheduler=lr_scheduler,
         tokenizer=tokenizer,
         max_epochs=args.max_epochs,

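The diff does not show what the trainers do with the stored plugin. One plausible use, stated purely as an assumption, is letting training-loop code branch on the booster's parallelism settings:

    from colossalai.booster.plugin import HybridParallelPlugin

    # Hypothetical helper (an assumption, not in this commit): with the plugin
    # available, a trainer can detect pipeline parallelism and adapt its loop.
    def uses_pipeline_parallelism(plugin) -> bool:
        return isinstance(plugin, HybridParallelPlugin) and plugin.pp_size > 1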