From 709e121cd5a98e01177e47dd4fe8a0833ff0af8a Mon Sep 17 00:00:00 2001 From: RichardoLuo <50363844+RichardoLuo@users.noreply.github.com> Date: Tue, 18 Jul 2023 18:04:02 +0800 Subject: [PATCH] [NFC] polish applications/Chat/coati/models/generation.py code style (#4275) --- applications/Chat/coati/models/generation.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/applications/Chat/coati/models/generation.py b/applications/Chat/coati/models/generation.py index 0156e2284..d96ad78a8 100644 --- a/applications/Chat/coati/models/generation.py +++ b/applications/Chat/coati/models/generation.py @@ -5,7 +5,6 @@ import torch.distributed as dist import torch.nn as nn import torch.nn.functional as F - try: from transformers.generation_logits_process import ( LogitsProcessorList, @@ -148,12 +147,12 @@ def generate(model: nn.Module, @torch.no_grad() -def generate_with_actor(actor_model: nn.Module, - input_ids: torch.Tensor, - return_action_mask: bool = True, - **kwargs - ) -> Union[Tuple[torch.LongTensor, torch.LongTensor], - Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]: +def generate_with_actor( + actor_model: nn.Module, + input_ids: torch.Tensor, + return_action_mask: bool = True, + **kwargs +) -> Union[Tuple[torch.LongTensor, torch.LongTensor], Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]: """Generate token sequence with actor model. Refer to `generate` for more details. """ # generate sequences