mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-01 16:18:27 +00:00
model server fix message model
This commit is contained in:
parent
96c516ab55
commit
661a7b5697
@ -3,7 +3,6 @@
|
||||
|
||||
import gradio as gr
|
||||
from langchain.agents import AgentType, initialize_agent, load_tools
|
||||
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
|
||||
from llama_index import (
|
||||
Document,
|
||||
GPTSimpleVectorIndex,
|
||||
@ -12,7 +11,7 @@ from llama_index import (
|
||||
ServiceContext,
|
||||
)
|
||||
|
||||
from pilot.model.vicuna_llm import VicunaEmbeddingLLM, VicunaRequestLLM
|
||||
from pilot.model.llm_out.vicuna_llm import VicunaEmbeddingLLM, VicunaRequestLLM
|
||||
|
||||
|
||||
def agent_demo():
|
||||
@ -49,7 +48,7 @@ def get_answer(q):
|
||||
|
||||
|
||||
def get_similar(q):
|
||||
from pilot.vector_store.extract_tovec import knownledge_tovec, knownledge_tovec_st
|
||||
from pilot.vector_store.extract_tovec import knownledge_tovec_st
|
||||
|
||||
docsearch = knownledge_tovec_st("./datasets/plan.md")
|
||||
docs = docsearch.similarity_search_with_score(q, k=1)
|
||||
|
0
pilot/model/llm_out/__init__.py
Normal file
0
pilot/model/llm_out/__init__.py
Normal file
@ -1,5 +1,6 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
import copy
|
||||
|
||||
import torch
|
||||
|
||||
@ -8,7 +9,7 @@ from pilot.conversation import ROLE_ASSISTANT, ROLE_USER
|
||||
|
||||
@torch.inference_mode()
|
||||
def chatglm_generate_stream(
|
||||
model, tokenizer, params, device, context_len=2048, stream_interval=2
|
||||
model, tokenizer, params, device, context_len=2048, stream_interval=2
|
||||
):
|
||||
"""Generate text using chatglm model's chat api"""
|
||||
prompt = params["prompt"]
|
||||
@ -28,25 +29,34 @@ def chatglm_generate_stream(
|
||||
generate_kwargs["temperature"] = temperature
|
||||
|
||||
# TODO, Fix this
|
||||
hist = []
|
||||
|
||||
messages = prompt.split(stop)
|
||||
#
|
||||
# # Add history conversation
|
||||
hist = []
|
||||
once_conversation = []
|
||||
for message in messages:
|
||||
if len(message) <= 0:
|
||||
continue
|
||||
|
||||
# Add history chat to hist for model.
|
||||
for i in range(1, len(messages) - 2, 2):
|
||||
hist.append(
|
||||
(
|
||||
messages[i].split(ROLE_USER + ":")[1],
|
||||
messages[i + 1].split(ROLE_ASSISTANT + ":")[1],
|
||||
)
|
||||
)
|
||||
if "human:" in message:
|
||||
once_conversation.append(message.split("human:")[1])
|
||||
# elif "system:" in message:
|
||||
# once_conversation.append(f"""###system:{message.split("system:")[1]} """)
|
||||
elif "ai:" in message:
|
||||
once_conversation.append(message.split("ai:")[1])
|
||||
last_conversation = copy.deepcopy(once_conversation)
|
||||
hist.append(last_conversation)
|
||||
once_conversation = []
|
||||
# else:
|
||||
# once_conversation.append(f"""###system:{message} """)
|
||||
|
||||
query = messages[-2].split(ROLE_USER + ":")[1]
|
||||
query = messages[-1].split("human:")[1]
|
||||
print("Query Message: ", query)
|
||||
output = ""
|
||||
i = 0
|
||||
# output = ""
|
||||
# i = 0
|
||||
|
||||
for i, (response, new_hist) in enumerate(
|
||||
model.stream_chat(tokenizer, query, hist, **generate_kwargs)
|
||||
model.stream_chat(tokenizer, query, hist, **generate_kwargs)
|
||||
):
|
||||
if echo:
|
||||
output = query + " " + response
|
@ -1,4 +1,5 @@
|
||||
import torch
|
||||
import copy
|
||||
from threading import Thread
|
||||
from transformers import TextIteratorStreamer, StoppingCriteriaList, StoppingCriteria
|
||||
from pilot.conversation import ROLE_ASSISTANT, ROLE_USER
|
||||
@ -8,19 +9,35 @@ def guanaco_generate_output(model, tokenizer, params, device, context_len=2048):
|
||||
|
||||
print(params)
|
||||
stop = params.get("stop", "###")
|
||||
messages = params["prompt"]
|
||||
prompt = params["prompt"]
|
||||
messages = prompt.split(stop)
|
||||
#
|
||||
# # Add history conversation
|
||||
# hist = []
|
||||
# once_conversation = []
|
||||
# for message in messages[:-1]:
|
||||
# if len(message) <= 0:
|
||||
# continue
|
||||
#
|
||||
# if "human:" in message:
|
||||
# once_conversation.append(f"""###system:{message.split("human:")[1]} """ )
|
||||
# elif "system:" in message:
|
||||
# once_conversation.append(f"""###system:{message.split("system:")[1]} """)
|
||||
# elif "ai:" in message:
|
||||
# once_conversation.append(f"""###system:{message.split("ai:")[1]} """)
|
||||
# last_conversation = copy.deepcopy(once_conversation)
|
||||
# hist.append("".join(last_conversation))
|
||||
# once_conversation = []
|
||||
# else:
|
||||
# once_conversation.append(f"""###system:{message} """)
|
||||
#
|
||||
#
|
||||
#
|
||||
#
|
||||
#
|
||||
# query = "".join(hist)
|
||||
|
||||
|
||||
hist = []
|
||||
for i in range(1, len(messages) - 2, 2):
|
||||
hist.append(
|
||||
(
|
||||
messages[i].split(ROLE_USER + ":")[1],
|
||||
messages[i + 1].split(ROLE_ASSISTANT + ":")[1],
|
||||
)
|
||||
)
|
||||
|
||||
query = messages[-2].split(ROLE_USER + ":")[1]
|
||||
query = prompt
|
||||
print("Query Message: ", query)
|
||||
|
||||
input_ids = tokenizer(query, return_tensors="pt").input_ids
|
@ -4,7 +4,7 @@
|
||||
from functools import cache
|
||||
from typing import List
|
||||
|
||||
from pilot.model.inference import generate_stream
|
||||
from pilot.model.llm_out.vicuna_base_llm import generate_stream
|
||||
|
||||
|
||||
class BaseChatAdpter:
|
||||
@ -55,7 +55,7 @@ class ChatGLMChatAdapter(BaseChatAdpter):
|
||||
return "chatglm" in model_path
|
||||
|
||||
def get_generate_stream_func(self):
|
||||
from pilot.model.chatglm_llm import chatglm_generate_stream
|
||||
from pilot.model.llm_out.chatglm_llm import chatglm_generate_stream
|
||||
|
||||
return chatglm_generate_stream
|
||||
|
||||
@ -91,7 +91,7 @@ class GuanacoChatAdapter(BaseChatAdpter):
|
||||
return "guanaco" in model_path
|
||||
|
||||
def get_generate_stream_func(self):
|
||||
from pilot.model.guanaco_llm import guanaco_generate_output
|
||||
from pilot.model.llm_out.guanaco_llm import guanaco_generate_output
|
||||
|
||||
return guanaco_generate_output
|
||||
|
||||
@ -101,7 +101,7 @@ class ProxyllmChatAdapter(BaseChatAdpter):
|
||||
return "proxyllm" in model_path
|
||||
|
||||
def get_generate_stream_func(self):
|
||||
from pilot.model.proxy_llm import proxyllm_generate_stream
|
||||
from pilot.model.llm_out.proxy_llm import proxyllm_generate_stream
|
||||
|
||||
return proxyllm_generate_stream
|
||||
|
||||
|
@ -19,7 +19,7 @@ sys.path.append(ROOT_PATH)
|
||||
|
||||
from pilot.configs.config import Config
|
||||
from pilot.configs.model_config import *
|
||||
from pilot.model.inference import generate_output, generate_stream, get_embeddings
|
||||
from pilot.model.llm_out.vicuna_base_llm import get_embeddings
|
||||
from pilot.model.loader import ModelLoader
|
||||
from pilot.server.chat_adapter import get_llm_chat_adapter
|
||||
|
||||
|
@ -6,7 +6,7 @@ from langchain.prompts import PromptTemplate
|
||||
from pilot.configs.model_config import VECTOR_SEARCH_TOP_K
|
||||
from pilot.conversation import conv_qa_prompt_template, conv_db_summary_templates
|
||||
from pilot.logs import logger
|
||||
from pilot.model.vicuna_llm import VicunaLLM
|
||||
from pilot.model.llm_out.vicuna_llm import VicunaLLM
|
||||
from pilot.vector_store.file_loader import KnownLedge2Vector
|
||||
|
||||
|
||||
|
@ -8,7 +8,7 @@ from langchain.text_splitter import CharacterTextSplitter
|
||||
from langchain.vectorstores import Chroma
|
||||
|
||||
from pilot.configs.model_config import DATASETS_DIR, VECTORE_PATH
|
||||
from pilot.model.vicuna_llm import VicunaEmbeddingLLM
|
||||
from pilot.model.llm_out.vicuna_llm import VicunaEmbeddingLLM
|
||||
|
||||
embeddings = VicunaEmbeddingLLM()
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user