print results

This commit is contained in:
Tong Li 2025-03-13 16:51:22 +08:00
parent 4702d57841
commit 45ac6c6cb2

View File

@ -154,7 +154,11 @@ class SimpleProducer(BaseProducer):
@torch.no_grad()
def rollout(self, input_ids, attention_mask, **kwargs):
return self.model.generate(input_ids, attention_mask, **kwargs)
rollouts = self.model.generate(input_ids, attention_mask, **kwargs)
if self.producer_idx == 1:
print("Rollout example:\n", self.tokenizer.decode(rollouts["input_ids"][0][0], skip_special_tokens=True))
return rollouts
def load_state_dict(self, state_dict):
self.model.load_state_dict(state_dict)