From 77c469e1ba948cdc6c4d6dd32ec151d653255ad9 Mon Sep 17 00:00:00 2001 From: Junming Wu Date: Tue, 18 Jul 2023 10:43:52 +0800 Subject: [PATCH] [NFC] polish applications/Chat/coati/models/base/actor.py code style (#4248) --- applications/Chat/coati/models/base/actor.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/applications/Chat/coati/models/base/actor.py b/applications/Chat/coati/models/base/actor.py index 2034d5cc8..6842f81d9 100644 --- a/applications/Chat/coati/models/base/actor.py +++ b/applications/Chat/coati/models/base/actor.py @@ -21,16 +21,13 @@ class Actor(LoRAModule): self.model = model self.convert_to_lora() - def forward(self, - input_ids: torch.LongTensor, - attention_mask: Optional[torch.Tensor] = None, - **model_kwargs, # HACK: `generate` method may pass more kwargs - ) -> torch.Tensor: + def forward( + self, + input_ids: torch.LongTensor, + attention_mask: Optional[torch.Tensor] = None, + **model_kwargs, # HACK: `generate` method may pass more kwargs + ) -> torch.Tensor: """Returns model output. """ - output = self.model( - input_ids, - attention_mask=attention_mask, - **model_kwargs - ) + output = self.model(input_ids, attention_mask=attention_mask, **model_kwargs) return output