mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-09 04:50:17 +00:00
fix vllm
This commit is contained in:
@@ -15,18 +15,14 @@ if __name__ == "__main__":
|
||||
parser.add_argument("-tbs", "--train-batch-size", type=int, default=32)
|
||||
parser.add_argument("-tmbs", "--train-microbatch-size", type=int, default=1)
|
||||
parser.add_argument("-b", "--backend", type=str, default="transformers")
|
||||
parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple, GRPO"])
|
||||
parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple", "GRPO", "EvalGRPO"])
|
||||
args = parser.parse_args()
|
||||
|
||||
ray.init(address="local", namespace="ray-example")
|
||||
|
||||
inference_model_config = dict(path=args.model)
|
||||
train_model_config = dict(path=args.model)
|
||||
generate_config = dict(
|
||||
top_k=50,
|
||||
top_p=0.9,
|
||||
temperature=1.0,
|
||||
)
|
||||
generate_config = dict(top_k=50, top_p=0.9, temperature=0.7)
|
||||
|
||||
if args.backend == "transformers":
|
||||
inference_model_config.update(
|
||||
@@ -52,19 +48,13 @@ if __name__ == "__main__":
|
||||
)
|
||||
)
|
||||
elif args.backend == "vllm":
|
||||
inference_model_config.update(
|
||||
dict(
|
||||
gpu_memory_utilization=0.7,
|
||||
)
|
||||
)
|
||||
inference_model_config.update(dict(gpu_memory_utilization=0.7, enforce_eager=True, enable_chunked_prefill=True))
|
||||
generate_config.update(
|
||||
dict(
|
||||
max_tokens=2048,
|
||||
ignore_eos=True,
|
||||
include_stop_str_in_output=True,
|
||||
stop=["</answer>"],
|
||||
temperature=0.7,
|
||||
top_p=0.95,
|
||||
)
|
||||
)
|
||||
else:
|
||||
@@ -97,6 +87,6 @@ if __name__ == "__main__":
|
||||
plugin_config={},
|
||||
inference_backend=args.backend,
|
||||
master_addr="localhost",
|
||||
master_port=29504,
|
||||
master_port=29503,
|
||||
core_algo=args.algo,
|
||||
)
|
||||
|
Reference in New Issue
Block a user