refactor evaluation
@@ -174,14 +174,16 @@ def train(args):
 
     # Check if the user specified weights fit into the theoretical lower and upper bounds from Eq. (8) of https://arxiv.org/abs/2402.01306
     actual_ratio = (args.desirable_weight * num_desirable) / (args.undesirable_weight * num_undesirable)
-    if actual_ratio <= 1:
-        raise AssertionError(
-            f"Desirable weight and undesirable weight are not within the theoretical bounds, [1, 4/3]. Actual ratio: {actual_ratio}, please increase desirable weight or decrease undesirable weight."
-        )
-    elif actual_ratio > 4 / 3:
-        raise AssertionError(
-            f"Desirable weight and undesirable weight are not within the theoretical bounds, [1, 4/3]. Actual ratio: {actual_ratio}, please decrease desirable weight or increase undesirable weight."
-        )
+    if actual_ratio < 1 or actual_ratio > 4 / 3:
+        if not args.auto_weight:
+            raise AssertionError(
+                f"Desirable weight and undesirable weight are not within the theoretical bounds, [1, 4/3]. Actual ratio: {actual_ratio}, please increase/decrease desirable weight or decrease/increase undesirable weight."
+            )
+        else:
+            args.desirable_weight = args.desirable_weight / actual_ratio
+            coordinator.print_on_master(
+                f"Desirable weight and undesirable weight are not within the theoretical bounds, [1, 4/3]. Actual ratio: {actual_ratio}, auto weight is enabled, set desirable weight to {args.desirable_weight} and undesirable weight to {args.undesirable_weight}"
+            )
 
     data_collator = DataCollatorForKTODataset(tokenizer=tokenizer, max_length=args.max_length)
 
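For reference, the check above enforces the bound that the script's comment attributes to Eq. (8) of the KTO paper (https://arxiv.org/abs/2402.01306): desirable_weight * num_desirable / (undesirable_weight * num_undesirable) must lie in [1, 4/3]. Below is a minimal standalone sketch of what the new --auto_weight branch does; the variable names mirror the script, but the counts are purely illustrative and not taken from this commit.

# Standalone sketch of the --auto_weight rescaling above (illustrative numbers only).
desirable_weight, undesirable_weight = 1.0, 1.0
num_desirable, num_undesirable = 3000, 5000  # imbalanced data: ratio falls below the lower bound of 1

actual_ratio = (desirable_weight * num_desirable) / (undesirable_weight * num_undesirable)  # 0.6
if actual_ratio < 1 or actual_ratio > 4 / 3:
    # Rescale desirable_weight so the ratio becomes exactly 1, as the auto_weight branch does.
    desirable_weight = desirable_weight / actual_ratio
print(desirable_weight)  # 1.666..., i.e. 5000 / 3000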
@@ -304,9 +306,12 @@ def train(args):
         LORA_MANAGER.merge_weights = True
         model.eval()
     # save model checkpoint after fitting on only rank0
-    coordinator.print_on_master("Start saving final model checkpoint")
-    booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
-    coordinator.print_on_master(f"Saved final model checkpoint at epoch {args.max_epochs} at folder {args.save_dir}")
+    if args.save_dir is not None:
+        coordinator.print_on_master("Start saving final model checkpoint")
+        booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
+        coordinator.print_on_master(
+            f"Saved final model checkpoint at epoch {args.max_epochs} at folder {args.save_dir}"
+        )
 
     coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
 
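The hunk above makes the final save opt-in: booster.save_model(..., shard=True) now runs only when --save_dir is given. As a rough companion, the sketch below shows one way the resulting sharded checkpoint could be reloaded in a later run. It is not part of this commit: the plugin choice, base-model name, and save_dir value are assumptions, and the process is expected to have been launched with ColossalAI in the same way the training script is.

# Hedged sketch (not from this commit): reloading the sharded checkpoint written by
# booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True).
# Assumes the distributed process is already initialized (e.g. via colossalai.launch_from_torch)
# and that a plugin compatible with the one used for training is chosen; names are placeholders.
import os
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin
from transformers import AutoModelForCausalLM

save_dir = "output"                                          # placeholder for args.save_dir
booster = Booster(plugin=TorchDDPPlugin())
model = AutoModelForCausalLM.from_pretrained("base-model")   # placeholder base model
model, *_ = booster.boost(model)
booster.load_model(model, os.path.join(save_dir, "modeling"))  # same "modeling" subfolder as the save above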
@@ -343,8 +348,8 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--checkpoint_path", type=str, default=None, help="Checkpoint path if need to resume training form a checkpoint"
|
||||
)
|
||||
parser.add_argument("--config_file", type=str, default="config_file", help="Config file")
|
||||
parser.add_argument("--save_dir", type=str, default="output")
|
||||
parser.add_argument("--config_file", type=str, default=None, help="Config file")
|
||||
parser.add_argument("--save_dir", type=str, default=None)
|
||||
parser.add_argument("--max_length", type=int, default=2048, help="Model max length")
|
||||
parser.add_argument("--max_epochs", type=int, default=3)
|
||||
parser.add_argument("--batch_size", type=int, default=4)
|
||||
@@ -359,14 +364,16 @@ if __name__ == "__main__":
|
||||
)
|
||||
parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints")
|
||||
parser.add_argument("--merge_lora_weights", type=bool, default=True)
|
||||
parser.add_argument("--auto_weight", default=False, action="store_true")
|
||||
parser.add_argument("--lr", type=float, default=5e-6)
|
||||
parser.add_argument("--accumulation_steps", type=int, default=8)
|
||||
parser.add_argument("--log_dir", default="logs", type=str)
|
||||
parser.add_argument("--log_dir", default=None, type=str)
|
||||
parser.add_argument("--use_wandb", default=False, action="store_true")
|
||||
parser.add_argument("--grad_checkpoint", default=False, action="store_true")
|
||||
parser.add_argument("--use_flash_attn", default=False, action="store_true")
|
||||
args = parser.parse_args()
|
||||
os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
|
||||
with open(args.config_file, "w") as f:
|
||||
json.dump(args.__dict__, f, indent=4)
|
||||
if args.config_file is not None:
|
||||
os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
|
||||
with open(args.config_file, "w") as f:
|
||||
json.dump(args.__dict__, f, indent=4)
|
||||
train(args)
|
||||
|
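The last hunk guards the config dump on --config_file (which now defaults to None, like --save_dir and --log_dir), so runs that do not ask for a config copy no longer touch the filesystem. One caveat survives the change: if a bare filename such as config.json is passed, os.path.dirname returns "" and os.makedirs("", exist_ok=True) raises FileNotFoundError. A slightly more defensive variant, shown below as a hedged sketch rather than as part of the commit, only creates the directory when there is one to create.

import json
import os

def dump_config(args) -> None:
    # Hedged sketch: same behavior as the guarded block above, plus handling of bare filenames.
    if args.config_file is None:
        return
    config_dir = os.path.dirname(args.config_file)
    if config_dir:  # skip makedirs when the path has no directory component
        os.makedirs(config_dir, exist_ok=True)
    with open(args.config_file, "w") as f:
        json.dump(args.__dict__, f, indent=4)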