update logging

This commit is contained in:
YeAnbang
2025-03-21 16:12:07 +08:00
parent d8eaf0d473
commit 2aa7385c88
2 changed files with 9 additions and 5 deletions

View File

@@ -101,6 +101,9 @@ class BaseProducer:
break
outputs = self.rollout(**batch)
print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
outputs["temperature"] = torch.tensor(
[self.model.generate_config.temperature] * outputs["input_ids"].size(0)
).to(outputs["input_ids"].device)
outputs = pre_send(outputs)
ray_broadcast_tensor_dict(
outputs, src=0, device=self.device, group_name=f"sync_data_{self.producer_idx}"