Support overall loss, update KTO logging

YeAnbang
2024-08-02 06:51:38 +00:00
parent 75c963686f
commit 0b2d55c4ab
15 changed files with 119 additions and 119 deletions

@@ -278,6 +278,7 @@ def train(args):
beta=args.beta,
gamma=args.gamma,
length_normalization=args.length_normalization,
apply_loss_mask=not args.disable_loss_mask,
)
trainer.fit(
@@ -346,6 +347,7 @@ if __name__ == "__main__":
default=False,
help="Disable the reference model (enabled by default)",
)
parser.add_argument("--disable_loss_mask", default=False, action="store_true")
parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path")
parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints")

@@ -297,6 +297,7 @@ def train(args):
beta=args.beta,
desirable_weight=args.desirable_weight,
undesirable_weight=args.undesirable_weight,
apply_loss_mask=not args.disable_loss_mask,
)
trainer.fit(
@@ -341,6 +342,7 @@ if __name__ == "__main__":
parser.add_argument("--beta", type=float, default=0.1, help="beta in KTO loss")
parser.add_argument("--desirable_weight", type=float, default=1.0, help="desirable_weight in KTO loss")
parser.add_argument("--undesirable_weight", type=float, default=1.0, help="undesirable_weight in KTO loss")
parser.add_argument("--disable_loss_mask", default=False, action="store_true")
parser.add_argument("--enable_sequence_parallelism", default=False, action="store_true")
parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
parser.add_argument("--zero_cpu_offload", default=False, action="store_true")

@@ -259,6 +259,7 @@ def train(args):
save_dir=args.save_dir,
coordinator=coordinator,
lam=args.lam,
apply_loss_mask=not args.disable_loss_mask,
)
trainer.fit(
@@ -301,6 +302,7 @@ if __name__ == "__main__":
parser.add_argument("--pp", type=int, default=1)
parser.add_argument("--sp", type=int, default=1)
parser.add_argument("--lam", type=float, default=0.1, help="lambda in ORPO loss")
parser.add_argument("--disable_loss_mask", default=False, action="store_true")
parser.add_argument("--enable_sequence_parallelism", default=False, action="store_true")
parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
parser.add_argument("--zero_cpu_offload", default=False, action="store_true")

@@ -411,6 +411,7 @@ def train(args):
use_cache=True,
do_sample=True,
temperature=0.7,
apply_loss_mask=not args.disable_loss_mask,
accumulation_steps=args.accumulation_steps,
save_dir=args.save_path,
save_interval=args.save_interval,
@@ -498,6 +499,7 @@ if __name__ == "__main__":
parser.add_argument("--critic_lr", type=float, default=9e-6)
parser.add_argument("--kl_coef", type=float, default=0.1)
parser.add_argument("--ptx_coef", type=float, default=0.0)
parser.add_argument("--disable_loss_mask", default=False, action="store_true")
parser.add_argument("--max_length", type=int, default=2048)
parser.add_argument("--max_seq_len", type=int, default=256)
parser.add_argument("--log_dir", default="logs", type=str)

@@ -272,6 +272,7 @@ def train(args):
lr_scheduler=lr_scheduler,
max_epochs=args.max_epochs,
accumulation_steps=args.accumulation_steps,
apply_loss_mask=not args.disable_loss_mask,
start_epoch=start_epoch,
save_interval=args.save_interval,
save_dir=args.save_path,
@@ -317,6 +318,7 @@ if __name__ == "__main__":
parser.add_argument("--tp", type=int, default=1)
parser.add_argument("--pp", type=int, default=1)
parser.add_argument("--sp", type=int, default=1)
parser.add_argument("--disable_loss_mask", default=False, action="store_true")
parser.add_argument("--enable_sequence_parallelism", default=False, action="store_true")
parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
parser.add_argument("--zero_cpu_offload", default=False, action="store_true")