add profiling

This commit is contained in:
Tong Li 2025-06-26 17:49:53 +08:00
parent 71ef6b32c6
commit 58cb4fb4f7
2 changed files with 24 additions and 9 deletions

View File

@@ -134,8 +134,12 @@ class BaseConsumer:
def calculate_effective_group_to_raw_group_mapping(self, step):
effective_group_to_raw_group_mapping = {}
for buffer_idx in range(len(self.buffer)):
if self.buffer[buffer_idx][0] is not None and (self.buffer[buffer_idx][-1] <= step - self.n_behind):
effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = buffer_idx
if self.buffer[buffer_idx][0] is not None:
if self.n_behind == 0:
effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = buffer_idx
else:
if self.buffer[buffer_idx][-1] <= step - self.n_behind:
effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = buffer_idx
return effective_group_to_raw_group_mapping
def loop(self) -> None:

View File

@@ -246,14 +246,25 @@ if __name__ == "__main__":
tensor_parallel_size=args.producer_tensor_parallel_size,
)
)
generate_config.update(
dict(
max_tokens=args.max_new_tokens, # max new tokens
ignore_eos=True if args.reward_type == "think_answer_tags" else False,
include_stop_str_in_output=True,
stop=["</answer>"] if args.reward_type == "think_answer_tags" else None,
if args.enable_profiling:
# If profiling is enabled, we force model to generate to max_new_tokens
inference_model_config.update(
dict(
max_tokens=args.max_new_tokens, # max new tokens
ignore_eos=True,
include_stop_str_in_output=True,
stop=None,
)
)
else:
generate_config.update(
dict(
max_tokens=args.max_new_tokens, # max new tokens
ignore_eos=True if args.reward_type == "think_answer_tags" else False,
include_stop_str_in_output=True,
stop=["</answer>"] if args.reward_type == "think_answer_tags" else None,
)
)
)
eval_generation_config = {"temperature": 0.6} # used to update generation config for evaluation
else:
raise ValueError(f"Unsupported backend: {args.backend}")