diff --git a/pilot/conversation.py b/pilot/conversation.py
index d52a51b41..877e61a80 100644
--- a/pilot/conversation.py
+++ b/pilot/conversation.py
@@ -251,8 +251,6 @@ conv_qa_prompt_template = """ 基于以下已知的信息, 专业、详细的回
 
 {question}
 """
-
-
 default_conversation = conv_one_shot
 
 conversation_types = {
diff --git a/pilot/model/llm/llm_utils.py b/pilot/model/llm/llm_utils.py
index 7a1fb47bd..a68860ee6 100644
--- a/pilot/model/llm/llm_utils.py
+++ b/pilot/model/llm/llm_utils.py
@@ -2,11 +2,51 @@
 # -*- coding: utf-8 -*-
 
 import abc
+import time
+import functools
 from typing import List, Optional
 from pilot.model.llm.base import Message
 from pilot.conversation import conv_templates, Conversation, conv_one_shot, auto_dbgpt_one_shot
 from pilot.configs.config import Config
+
+# TODO Rewrite this
+def retry_stream_api(
+    num_retries: int = 10,
+    backoff_base: float = 2.0,
+    warn_user: bool = True
+):
+    """Retry a Vicuna Server call.
+
+    Args:
+        num_retries int: Number of retries. Defaults to 10.
+        backoff_base float: Base for exponential backoff. Defaults to 2.
+        warn_user bool: Whether to warn the user. Defaults to True.
+    """
+    # Plain strings, not f-strings: "{backoff}" is a placeholder to be filled
+    # in later with str.format, and the rate-limit message has no placeholders.
+    retry_limit_msg = "Error: Reached rate limit, passing..."
+    backoff_msg = "Error: API Bad gateway. Waiting {backoff} seconds..."
+
+    def _wrapper(func):
+        @functools.wraps(func)
+        def _wrapped(*args, **kwargs):
+            user_warned = not warn_user  # TODO: unused; wire up user warnings
+            num_attempts = num_retries + 1  # +1 for the first attempt
+            for attempt in range(1, num_attempts + 1):
+                try:
+                    return func(*args, **kwargs)
+                except Exception as e:
+                    # getattr: a generic Exception has no http_status attribute;
+                    # reading it directly would raise AttributeError and mask the
+                    # real error instead of retrying on 502 / re-raising otherwise.
+                    if getattr(e, "http_status", None) != 502 or attempt == num_attempts:
+                        raise
+
+                    backoff = backoff_base ** (attempt + 2)
+                    time.sleep(backoff)
+        return _wrapped
+    return _wrapper
+
 # Overly simple abstraction util we create something better
 # simple retry mechanism when getting a rate error or a bad gateway
 def create_chat_competion(
@@ -31,8 +71,10 @@ def create_chat_competion(
     temperature = cfg.temperature
 
     # TODO request vicuna model get response
+    # convert vicuna message to chat completion.
     for plugin in cfg.plugins:
-        pass
+        if plugin.can_handle_chat_completion():
+            pass
 
 
 class ChatIO(abc.ABC):
diff --git a/pilot/model/loader.py b/pilot/model/loader.py
index 747585fa4..55458ff4a 100644
--- a/pilot/model/loader.py
+++ b/pilot/model/loader.py
@@ -10,7 +10,7 @@ from transformers import (
 
 from fastchat.serve.compression import compress_module
 
-class ModelLoader:
+class ModelLoader():
     """Model loader is a class for model load
 
       Args: model_path