mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-24 14:33:20 +00:00
print results
This commit is contained in:
parent
4702d57841
commit
45ac6c6cb2
@ -154,7 +154,11 @@ class SimpleProducer(BaseProducer):
|
|||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def rollout(self, input_ids, attention_mask, **kwargs):
|
def rollout(self, input_ids, attention_mask, **kwargs):
|
||||||
return self.model.generate(input_ids, attention_mask, **kwargs)
|
rollouts = self.model.generate(input_ids, attention_mask, **kwargs)
|
||||||
|
if self.producer_idx == 1:
|
||||||
|
print("Rollout example:\n", self.tokenizer.decode(rollouts["input_ids"][0][0], skip_special_tokens=True))
|
||||||
|
|
||||||
|
return rollouts
|
||||||
|
|
||||||
def load_state_dict(self, state_dict):
|
def load_state_dict(self, state_dict):
|
||||||
self.model.load_state_dict(state_dict)
|
self.model.load_state_dict(state_dict)
|
||||||
|
Loading…
Reference in New Issue
Block a user