diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py index 3e4a5277a..a3ae22a79 100644 --- a/applications/ColossalChat/coati/distributed/producer.py +++ b/applications/ColossalChat/coati/distributed/producer.py @@ -154,7 +154,11 @@ class SimpleProducer(BaseProducer): @torch.no_grad() def rollout(self, input_ids, attention_mask, **kwargs): - return self.model.generate(input_ids, attention_mask, **kwargs) + rollouts = self.model.generate(input_ids, attention_mask, **kwargs) + if self.producer_idx == 1: + print("Rollout example:\n", self.tokenizer.decode(rollouts["input_ids"][0][0], skip_special_tokens=True)) + + return rollouts def load_state_dict(self, state_dict): self.model.load_state_dict(state_dict)