update conversation

This commit is contained in:
csunny 2023-05-09 21:48:47 +08:00
parent fa965999e1
commit bfbbf0ba88
3 changed files with 40 additions and 4 deletions

View File

@ -251,8 +251,6 @@ conv_qa_prompt_template = """ 基于以下已知的信息, 专业、详细的回
{question}
"""
default_conversation = conv_one_shot
conversation_types = {

View File

@ -2,11 +2,47 @@
# -*- coding: utf-8 -*-
import abc
import time
import functools
from typing import List, Optional
from pilot.model.llm.base import Message
from pilot.conversation import conv_templates, Conversation, conv_one_shot, auto_dbgpt_one_shot
from pilot.configs.config import Config
# TODO Rewrite this
def retry_stream_api(
num_retries: int = 10,
backoff_base: float = 2.0,
warn_user: bool = True
):
"""Retry an Vicuna Server call.
Args:
num_retries int: Number of retries. Defaults to 10.
backoff_base float: Base for exponential backoff. Defaults to 2.
warn_user bool: Whether to warn the user. Defaults to True.
"""
retry_limit_msg = f"Error: Reached rate limit, passing..."
backoff_msg = (f"Error: API Bad gateway. Waiting {{backoff}} seconds...")
def _wrapper(func):
@functools.wraps(func)
def _wrapped(*args, **kwargs):
user_warned = not warn_user
num_attempts = num_retries + 1 # +1 for the first attempt
for attempt in range(1, num_attempts + 1):
try:
return func(*args, **kwargs)
except Exception as e:
if (e.http_status != 502) or (attempt == num_attempts):
raise
backoff = backoff_base ** (attempt + 2)
time.sleep(backoff)
return _wrapped
return _wrapper
# Overly simple abstraction util we create something better
# simple retry mechanism when getting a rate error or a bad gateway
def create_chat_competion(
@ -31,8 +67,10 @@ def create_chat_competion(
temperature = cfg.temperature
# TODO request vicuna model get response
# convert vicuna message to chat completion.
for plugin in cfg.plugins:
pass
if plugin.can_handle_chat_completion():
pass
class ChatIO(abc.ABC):

View File

@ -10,7 +10,7 @@ from transformers import (
from fastchat.serve.compression import compress_module
class ModelLoader:
class ModelLoader():
"""Model loader is a class for model load
Args: model_path