mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-23 14:10:29 +00:00
print results
This commit is contained in:
parent
4702d57841
commit
45ac6c6cb2
@ -154,7 +154,11 @@ class SimpleProducer(BaseProducer):
|
||||
|
||||
@torch.no_grad()
|
||||
def rollout(self, input_ids, attention_mask, **kwargs):
|
||||
return self.model.generate(input_ids, attention_mask, **kwargs)
|
||||
rollouts = self.model.generate(input_ids, attention_mask, **kwargs)
|
||||
if self.producer_idx == 1:
|
||||
print("Rollout example:\n", self.tokenizer.decode(rollouts["input_ids"][0][0], skip_special_tokens=True))
|
||||
|
||||
return rollouts
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
self.model.load_state_dict(state_dict)
|
||||
|
Loading…
Reference in New Issue
Block a user