diff --git a/pilot/model/llm/llm_utils.py b/pilot/model/llm/llm_utils.py
index 513eee492..7a1fb47bd 100644
--- a/pilot/model/llm/llm_utils.py
+++ b/pilot/model/llm/llm_utils.py
@@ -1,17 +1,19 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 
+import abc
 from typing import List, Optional
 
 from pilot.model.llm.base import Message
+from pilot.conversation import conv_templates, Conversation, conv_one_shot, auto_dbgpt_one_shot
 from pilot.configs.config import Config
 
 # Overly simple abstraction util we create something better
 # simple retry mechanism when getting a rate error or a bad gateway
 def create_chat_competion(
-    messages: List[Message],
+    conv: Conversation,
     model: Optional[str] = None,
     temperature: float = None,
-    max_tokens: Optional[int] = None,
+    max_new_tokens: Optional[int] = None,
 ) -> str:
     """Create a chat completion using the Vicuna-13b
@@ -27,6 +29,42 @@ def create_chat_competion(
     cfg = Config()
     if temperature is None:
         temperature = cfg.temperature
-
+
     # TODO request vicuna model get response
-
\ No newline at end of file
+    for plugin in cfg.plugins:
+        pass
+
+
+class ChatIO(abc.ABC):
+    @abc.abstractmethod
+    def prompt_for_input(self, role: str) -> str:
+        """Prompt for input from a role."""
+
+    @abc.abstractmethod
+    def prompt_for_output(self, role: str) -> str:
+        """Prompt for output from a role."""
+
+    @abc.abstractmethod
+    def stream_output(self, output_stream, skip_echo_len: int):
+        """Stream output."""
+
+
+class SimpleChatIO(ChatIO):
+    def prompt_for_input(self, role: str) -> str:
+        return input(f"{role}: ")
+
+    def prompt_for_output(self, role: str) -> str:
+        print(f"{role}: ", end="", flush=True)
+
+    def stream_output(self, output_stream, skip_echo_len: int):
+        pre = 0
+        for outputs in output_stream:
+            outputs = outputs[skip_echo_len:].strip()
+            now = len(outputs) - 1
+            if now > pre:
+                print(" ".join(outputs[pre:now]), end=" ", flush=True)
+                pre = now
+
+        print(" ".join(outputs[pre:]), flush=True)
+        return " ".join(outputs)
+
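
Usage sketch (not part of the change above): a minimal illustration of how the new ChatIO abstraction introduced in this diff might be driven. It assumes the pilot package from this repository is importable; CaptureChatIO and fake_stream are hypothetical names introduced here purely for illustration, and the only behaviour taken from the diff is that stream_output consumes an iterable of output strings together with a skip_echo_len that trims the echoed prompt, as SimpleChatIO does.

# Sketch only: assumes the `pilot` package from this repo is on PYTHONPATH.
from typing import Iterator

from pilot.model.llm.llm_utils import ChatIO  # abstract base class added in this diff


class CaptureChatIO(ChatIO):
    """Hypothetical ChatIO that collects the streamed text instead of word-streaming it."""

    def prompt_for_input(self, role: str) -> str:
        return input(f"{role}: ")

    def prompt_for_output(self, role: str) -> str:
        print(f"{role}: ", end="", flush=True)

    def stream_output(self, output_stream, skip_echo_len: int) -> str:
        final = ""
        for outputs in output_stream:
            # Assumption: each item is the full generation so far, prompt echo included.
            final = outputs[skip_echo_len:].strip()
        print(final, flush=True)
        return final


def fake_stream(prompt: str, answer: str) -> Iterator[str]:
    """Stand-in for a model's streaming generator: yields the prompt plus a growing answer."""
    for i in range(1, len(answer) + 1):
        yield prompt + answer[:i]


if __name__ == "__main__":
    prompt = "USER: hello ASSISTANT:"
    chatio = CaptureChatIO()
    chatio.prompt_for_output("ASSISTANT")
    chatio.stream_output(fake_stream(prompt, " Hi, how can I help?"), len(prompt))

Collecting the full text and printing it once keeps the sketch simple; SimpleChatIO in the diff instead prints pieces of the output incrementally as the stream grows.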