remove debug code

This commit is contained in:
YeAnbang 2025-06-09 09:42:58 +08:00
parent 3bed6ae9ee
commit d0e12c508a

View File

@ -273,12 +273,10 @@ class BaseProducer:
[self.model.generate_config["temperature"]] * outputs["input_ids"].size(0)
).to(outputs["input_ids"].device)
bs, num_gen = outputs["input_ids"].size(0), outputs["input_ids"].size(1)
# breakpoint()
if self.grpo_config["reward_fn_type"] == "code":
test_cases = []
for prompt_id in range(bs):
test_cases.extend([outputs["test_cases"][prompt_id]] * num_gen)
# assert all([isinstance(test_cases[i], str) for i in range(len(test_cases))]), f"{[type(test_cases[i]) for i in range(len(test_cases))]}, {test_cases}"
reward_model_output = self.reward_model(
outputs["input_ids"].view((-1, outputs["input_ids"].size(-1))),
test_cases=test_cases,
@ -315,7 +313,6 @@ class BaseProducer:
print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
outputs = pre_send(outputs)
# breakpoint()
ray_broadcast_tensor_dict(
outputs, src=0, device=self.device, group_name=f"sync_data_{self.producer_idx}"
)