mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-11-29 22:37:14 +00:00
support resume training
This commit is contained in:
@@ -18,6 +18,13 @@ os.environ["no_proxy"] = "127.0.0.1,localhost"
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B")
+    parser.add_argument(
+        "-cp",
+        "--checkpoint-path",
+        type=str,
+        default=None,
+        help="Path to the checkpoint to load the model from. If not provided, the model will be loaded from the model path.",
+    )
     parser.add_argument("-d", "--dataset", type=str, default="data.jsonl")
     parser.add_argument(
         "-ed",
@@ -226,8 +233,10 @@ if __name__ == "__main__":
 
     os.environ["TOKENIZERS_PARALLELISM"] = "false"  # Disable tokenizers parallelism to avoid deadlock
 
-    inference_model_config = dict(path=args.model)
-    train_model_config = dict(path=args.model, use_flash_attention_2=True, use_cache=False)
+    inference_model_config = dict(path=args.model, checkpoint_path=args.checkpoint_path)
+    train_model_config = dict(
+        path=args.model, use_flash_attention_2=True, use_cache=False, checkpoint_path=args.checkpoint_path
+    )
     generate_config = dict(top_k=args.top_k, top_p=args.top_p, temperature=args.temperature)
 
     if args.backend == "transformers":
Reference in New Issue
Block a user