model server: fix message model

yhjun1026 2023-06-01 16:34:51 +08:00
parent 96c516ab55
commit 661a7b5697
11 changed files with 63 additions and 37 deletions

View File

@@ -3,7 +3,6 @@
import gradio as gr
from langchain.agents import AgentType, initialize_agent, load_tools
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from llama_index import (
Document,
GPTSimpleVectorIndex,
@@ -12,7 +11,7 @@ from llama_index import (
ServiceContext,
)
from pilot.model.vicuna_llm import VicunaEmbeddingLLM, VicunaRequestLLM
from pilot.model.llm_out.vicuna_llm import VicunaEmbeddingLLM, VicunaRequestLLM
def agent_demo():
@@ -49,7 +48,7 @@ def get_answer(q):
def get_similar(q):
from pilot.vector_store.extract_tovec import knownledge_tovec, knownledge_tovec_st
from pilot.vector_store.extract_tovec import knownledge_tovec_st
docsearch = knownledge_tovec_st("./datasets/plan.md")
docs = docsearch.similarity_search_with_score(q, k=1)
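
For orientation, a minimal sketch of how the slimmed-down similarity lookup can be driven, assuming knownledge_tovec_st returns a LangChain-style vector store as the surrounding code implies; the show_similar wrapper and the print formatting are illustrative, not part of the commit.

```python
# Illustrative wrapper around the retained knownledge_tovec_st path
# (assumes it returns a LangChain-compatible vector store).
from pilot.vector_store.extract_tovec import knownledge_tovec_st

def show_similar(question: str):
    docsearch = knownledge_tovec_st("./datasets/plan.md")
    # similarity_search_with_score yields (Document, score) pairs.
    for doc, score in docsearch.similarity_search_with_score(question, k=1):
        print(f"score={score:.4f} text={doc.page_content[:80]}")
```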

View File

View File

@@ -1,5 +1,6 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
import copy
import torch
@@ -8,7 +9,7 @@ from pilot.conversation import ROLE_ASSISTANT, ROLE_USER
@torch.inference_mode()
def chatglm_generate_stream(
model, tokenizer, params, device, context_len=2048, stream_interval=2
model, tokenizer, params, device, context_len=2048, stream_interval=2
):
"""Generate text using chatglm model's chat api"""
prompt = params["prompt"]
@@ -28,25 +29,34 @@ def chatglm_generate_stream(
generate_kwargs["temperature"] = temperature
# TODO, Fix this
hist = []
messages = prompt.split(stop)
#
# # Add history conversation
hist = []
once_conversation = []
for message in messages:
if len(message) <= 0:
continue
# Add history chat to hist for model.
for i in range(1, len(messages) - 2, 2):
hist.append(
(
messages[i].split(ROLE_USER + ":")[1],
messages[i + 1].split(ROLE_ASSISTANT + ":")[1],
)
)
if "human:" in message:
once_conversation.append(message.split("human:")[1])
# elif "system:" in message:
# once_conversation.append(f"""###system:{message.split("system:")[1]} """)
elif "ai:" in message:
once_conversation.append(message.split("ai:")[1])
last_conversation = copy.deepcopy(once_conversation)
hist.append(last_conversation)
once_conversation = []
# else:
# once_conversation.append(f"""###system:{message} """)
query = messages[-2].split(ROLE_USER + ":")[1]
query = messages[-1].split("human:")[1]
print("Query Message: ", query)
output = ""
i = 0
# output = ""
# i = 0
for i, (response, new_hist) in enumerate(
model.stream_chat(tokenizer, query, hist, **generate_kwargs)
model.stream_chat(tokenizer, query, hist, **generate_kwargs)
):
if echo:
output = query + " " + response
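
The heart of this hunk is the switch from index-based history pairing to marker-based parsing. A small, self-contained sketch of that parsing logic under an assumed prompt layout ("human:"/"ai:" turns separated by the stop token) shows how hist and query end up:

```python
# Sketch of the new marker-based history parsing (prompt text and stop token
# are illustrative assumptions, not values from the commit).
import copy

prompt = "human: hello###ai: hi there###human: what is DB-GPT?"
stop = "###"

messages = prompt.split(stop)
hist, once_conversation = [], []
for message in messages:
    if not message:
        continue
    if "human:" in message:
        once_conversation.append(message.split("human:")[1])
    elif "ai:" in message:
        # An "ai:" turn closes the pending (user, assistant) exchange.
        once_conversation.append(message.split("ai:")[1])
        hist.append(copy.deepcopy(once_conversation))
        once_conversation = []

query = messages[-1].split("human:")[1]
print(hist)   # [[' hello', ' hi there']]
print(query)  # ' what is DB-GPT?'
```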

View File

@@ -1,4 +1,5 @@
import torch
import copy
from threading import Thread
from transformers import TextIteratorStreamer, StoppingCriteriaList, StoppingCriteria
from pilot.conversation import ROLE_ASSISTANT, ROLE_USER
@@ -8,19 +9,35 @@ def guanaco_generate_output(model, tokenizer, params, device, context_len=2048):
print(params)
stop = params.get("stop", "###")
messages = params["prompt"]
prompt = params["prompt"]
messages = prompt.split(stop)
#
# # Add history conversation
# hist = []
# once_conversation = []
# for message in messages[:-1]:
# if len(message) <= 0:
# continue
#
# if "human:" in message:
# once_conversation.append(f"""###system:{message.split("human:")[1]} """ )
# elif "system:" in message:
# once_conversation.append(f"""###system:{message.split("system:")[1]} """)
# elif "ai:" in message:
# once_conversation.append(f"""###system:{message.split("ai:")[1]} """)
# last_conversation = copy.deepcopy(once_conversation)
# hist.append("".join(last_conversation))
# once_conversation = []
# else:
# once_conversation.append(f"""###system:{message} """)
#
#
#
#
#
# query = "".join(hist)
hist = []
for i in range(1, len(messages) - 2, 2):
hist.append(
(
messages[i].split(ROLE_USER + ":")[1],
messages[i + 1].split(ROLE_ASSISTANT + ":")[1],
)
)
query = messages[-2].split(ROLE_USER + ":")[1]
query = prompt
print("Query Message: ", query)
input_ids = tokenizer(query, return_tensors="pt").input_ids
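
The tail of this hunk feeds query straight into the tokenizer; a hedged sketch of the streaming pattern that the file's imports (Thread, TextIteratorStreamer) point to, with illustrative generation arguments and helper name:

```python
# Sketch of threaded streaming generation with TextIteratorStreamer;
# max_new_tokens and the stream_output name are illustrative assumptions.
from threading import Thread
from transformers import TextIteratorStreamer

def stream_output(model, tokenizer, query, device, max_new_tokens=512):
    input_ids = tokenizer(query, return_tensors="pt").input_ids.to(device)
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    kwargs = dict(input_ids=input_ids, max_new_tokens=max_new_tokens, streamer=streamer)
    # generate() runs in the background; the streamer yields decoded text chunks.
    Thread(target=model.generate, kwargs=kwargs).start()

    output = ""
    for new_text in streamer:
        output += new_text
        yield output
```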

View File

@@ -4,7 +4,7 @@
from functools import cache
from typing import List
from pilot.model.inference import generate_stream
from pilot.model.llm_out.vicuna_base_llm import generate_stream
class BaseChatAdpter:
@@ -55,7 +55,7 @@ class ChatGLMChatAdapter(BaseChatAdpter):
return "chatglm" in model_path
def get_generate_stream_func(self):
from pilot.model.chatglm_llm import chatglm_generate_stream
from pilot.model.llm_out.chatglm_llm import chatglm_generate_stream
return chatglm_generate_stream
@@ -91,7 +91,7 @@ class GuanacoChatAdapter(BaseChatAdpter):
return "guanaco" in model_path
def get_generate_stream_func(self):
from pilot.model.guanaco_llm import guanaco_generate_output
from pilot.model.llm_out.guanaco_llm import guanaco_generate_output
return guanaco_generate_output
@@ -101,7 +101,7 @@ class ProxyllmChatAdapter(BaseChatAdpter):
return "proxyllm" in model_path
def get_generate_stream_func(self):
from pilot.model.proxy_llm import proxyllm_generate_stream
from pilot.model.llm_out.proxy_llm import proxyllm_generate_stream
return proxyllm_generate_stream
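
All three adapters follow the same shape: match on the model path, then lazily import the generator from the relocated pilot.model.llm_out package. A condensed sketch of that pattern (method bodies beyond what the diff shows are assumptions):

```python
# Condensed sketch of the adapter pattern; the BaseChatAdpter spelling follows
# the source, and the base-class bodies are assumed for illustration.
class BaseChatAdpter:
    def match(self, model_path: str) -> bool:
        return False

    def get_generate_stream_func(self):
        raise NotImplementedError

class ChatGLMChatAdapter(BaseChatAdpter):
    def match(self, model_path: str) -> bool:
        return "chatglm" in model_path

    def get_generate_stream_func(self):
        # Deferred import keeps unrelated model backends from loading.
        from pilot.model.llm_out.chatglm_llm import chatglm_generate_stream

        return chatglm_generate_stream
```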

View File

@@ -19,7 +19,7 @@ sys.path.append(ROOT_PATH)
from pilot.configs.config import Config
from pilot.configs.model_config import *
from pilot.model.inference import generate_output, generate_stream, get_embeddings
from pilot.model.llm_out.vicuna_base_llm import get_embeddings
from pilot.model.loader import ModelLoader
from pilot.server.chat_adapter import get_llm_chat_adapter

View File

@@ -6,7 +6,7 @@ from langchain.prompts import PromptTemplate
from pilot.configs.model_config import VECTOR_SEARCH_TOP_K
from pilot.conversation import conv_qa_prompt_template, conv_db_summary_templates
from pilot.logs import logger
from pilot.model.vicuna_llm import VicunaLLM
from pilot.model.llm_out.vicuna_llm import VicunaLLM
from pilot.vector_store.file_loader import KnownLedge2Vector
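
A hedged sketch of how the relocated VicunaLLM might be combined with a prompt template in the spirit of this file's imports; the template text and chain wiring are illustrative assumptions, not the repo's conv_qa_prompt_template:

```python
# Illustrative QA call; assumes VicunaLLM implements LangChain's LLM interface.
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate

from pilot.model.llm_out.vicuna_llm import VicunaLLM

template = PromptTemplate(
    input_variables=["context", "question"],
    template="Use the context to answer.\nContext: {context}\nQuestion: {question}",
)
chain = LLMChain(llm=VicunaLLM(), prompt=template)
answer = chain.run(context="DB-GPT is a database assistant.", question="What is DB-GPT?")
```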

View File

@@ -8,7 +8,7 @@ from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import Chroma
from pilot.configs.model_config import DATASETS_DIR, VECTORE_PATH
from pilot.model.vicuna_llm import VicunaEmbeddingLLM
from pilot.model.llm_out.vicuna_llm import VicunaEmbeddingLLM
embeddings = VicunaEmbeddingLLM()
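
Finally, a sketch of how the relocated VicunaEmbeddingLLM can back a Chroma store, mirroring this file's imports; the loader, paths, and chunk sizes are illustrative assumptions:

```python
# Illustrative embedding pipeline; assumes VicunaEmbeddingLLM implements
# LangChain's Embeddings interface, as its usage here suggests.
from langchain.document_loaders import TextLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import Chroma

from pilot.model.llm_out.vicuna_llm import VicunaEmbeddingLLM

embeddings = VicunaEmbeddingLLM()

def build_store(doc_path: str, persist_dir: str):
    docs = TextLoader(doc_path).load()
    chunks = CharacterTextSplitter(chunk_size=1000, chunk_overlap=100).split_documents(docs)
    return Chroma.from_documents(chunks, embeddings, persist_directory=persist_dir)
```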