diff --git a/applications/Chat/coati/models/base/actor.py b/applications/Chat/coati/models/base/actor.py index 2034d5cc8..6842f81d9 100644 --- a/applications/Chat/coati/models/base/actor.py +++ b/applications/Chat/coati/models/base/actor.py @@ -21,16 +21,13 @@ class Actor(LoRAModule): self.model = model self.convert_to_lora() - def forward(self, - input_ids: torch.LongTensor, - attention_mask: Optional[torch.Tensor] = None, - **model_kwargs, # HACK: `generate` method may pass more kwargs - ) -> torch.Tensor: + def forward( + self, + input_ids: torch.LongTensor, + attention_mask: Optional[torch.Tensor] = None, + **model_kwargs, # HACK: `generate` method may pass more kwargs + ) -> torch.Tensor: """Returns model output. """ - output = self.model( - input_ids, - attention_mask=attention_mask, - **model_kwargs - ) + output = self.model(input_ids, attention_mask=attention_mask, **model_kwargs) return output