Support resuming training from a checkpoint

This commit is contained in:
YeAnbang
2025-08-12 08:10:56 +00:00
parent 08a1244ef1
commit e589ec505e
5 changed files with 98 additions and 10 deletions

View File

@@ -18,6 +18,13 @@ os.environ["no_proxy"] = "127.0.0.1,localhost"
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B")
parser.add_argument(
"-cp",
"--checkpoint-path",
type=str,
default=None,
help="Path to the checkpoint to load the model from. If not provided, the model will be loaded from the model path.",
)
parser.add_argument("-d", "--dataset", type=str, default="data.jsonl")
parser.add_argument(
"-ed",
@@ -226,8 +233,10 @@ if __name__ == "__main__":
os.environ["TOKENIZERS_PARALLELISM"] = "false" # Disable tokenizers parallelism to avoid deadlock
inference_model_config = dict(path=args.model)
train_model_config = dict(path=args.model, use_flash_attention_2=True, use_cache=False)
inference_model_config = dict(path=args.model, checkpoint_path=args.checkpoint_path)
train_model_config = dict(
path=args.model, use_flash_attention_2=True, use_cache=False, checkpoint_path=args.checkpoint_path
)
generate_config = dict(top_k=args.top_k, top_p=args.top_p, temperature=args.temperature)
if args.backend == "transformers":