mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-11-09 03:13:54 +00:00
add profiling
This commit is contained in:
@@ -246,14 +246,25 @@ if __name__ == "__main__":
|
||||
tensor_parallel_size=args.producer_tensor_parallel_size,
|
||||
)
|
||||
)
|
||||
generate_config.update(
|
||||
dict(
|
||||
max_tokens=args.max_new_tokens, # max new tokens
|
||||
ignore_eos=True if args.reward_type == "think_answer_tags" else False,
|
||||
include_stop_str_in_output=True,
|
||||
stop=["</answer>"] if args.reward_type == "think_answer_tags" else None,
|
||||
if args.enable_profiling:
|
||||
# If profiling is enabled, we force model to generate to max_new_tokens
|
||||
inference_model_config.update(
|
||||
dict(
|
||||
max_tokens=args.max_new_tokens, # max new tokens
|
||||
ignore_eos=True,
|
||||
include_stop_str_in_output=True,
|
||||
stop=None,
|
||||
)
|
||||
)
|
||||
else:
|
||||
generate_config.update(
|
||||
dict(
|
||||
max_tokens=args.max_new_tokens, # max new tokens
|
||||
ignore_eos=True if args.reward_type == "think_answer_tags" else False,
|
||||
include_stop_str_in_output=True,
|
||||
stop=["</answer>"] if args.reward_type == "think_answer_tags" else None,
|
||||
)
|
||||
)
|
||||
)
|
||||
eval_generation_config = {"temperature": 0.6} # used to update generation config for evaluation
|
||||
else:
|
||||
raise ValueError(f"Unsupported backend: {args.backend}")
|
||||
|
||||
Reference in New Issue
Block a user