mirror of
https://github.com/csunny/DB-GPT.git
model server fix message model
commit 661a7b5697
parent 96c516ab55
@@ -3,7 +3,6 @@
 import gradio as gr
 from langchain.agents import AgentType, initialize_agent, load_tools
-from langchain.embeddings.huggingface import HuggingFaceEmbeddings
 from llama_index import (
     Document,
     GPTSimpleVectorIndex,
@@ -12,7 +11,7 @@ from llama_index import (
     ServiceContext,
 )
 
-from pilot.model.vicuna_llm import VicunaEmbeddingLLM, VicunaRequestLLM
+from pilot.model.llm_out.vicuna_llm import VicunaEmbeddingLLM, VicunaRequestLLM
 
 
 def agent_demo():
@@ -49,7 +48,7 @@ def get_answer(q):
 
 
 def get_similar(q):
-    from pilot.vector_store.extract_tovec import knownledge_tovec, knownledge_tovec_st
+    from pilot.vector_store.extract_tovec import knownledge_tovec_st
 
     docsearch = knownledge_tovec_st("./datasets/plan.md")
     docs = docsearch.similarity_search_with_score(q, k=1)
pilot/model/llm_out/__init__.py (new empty file)
@@ -1,5 +1,6 @@
 #!/usr/bin/env python3
 # -*- coding:utf-8 -*-
+import copy
 
 import torch
 
@@ -8,7 +9,7 @@ from pilot.conversation import ROLE_ASSISTANT, ROLE_USER
 
 @torch.inference_mode()
 def chatglm_generate_stream(
     model, tokenizer, params, device, context_len=2048, stream_interval=2
 ):
     """Generate text using chatglm model's chat api"""
     prompt = params["prompt"]
@@ -28,25 +29,34 @@ def chatglm_generate_stream(
     generate_kwargs["temperature"] = temperature
 
     # TODO, Fix this
-    hist = []
 
     messages = prompt.split(stop)
+    #
+    # # Add history conversation
+    hist = []
+    once_conversation = []
+    for message in messages:
+        if len(message) <= 0:
+            continue
 
-    # Add history chat to hist for model.
-    for i in range(1, len(messages) - 2, 2):
-        hist.append(
-            (
-                messages[i].split(ROLE_USER + ":")[1],
-                messages[i + 1].split(ROLE_ASSISTANT + ":")[1],
-            )
-        )
+        if "human:" in message:
+            once_conversation.append(message.split("human:")[1])
+        # elif "system:" in message:
+        #     once_conversation.append(f"""###system:{message.split("system:")[1]} """)
+        elif "ai:" in message:
+            once_conversation.append(message.split("ai:")[1])
+            last_conversation = copy.deepcopy(once_conversation)
+            hist.append(last_conversation)
+            once_conversation = []
+        # else:
+        #     once_conversation.append(f"""###system:{message} """)
 
-    query = messages[-2].split(ROLE_USER + ":")[1]
+    query = messages[-1].split("human:")[1]
     print("Query Message: ", query)
-    output = ""
-    i = 0
+    # output = ""
+    # i = 0
 
     for i, (response, new_hist) in enumerate(
         model.stream_chat(tokenizer, query, hist, **generate_kwargs)
     ):
         if echo:
             output = query + " " + response
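For reference, a minimal, self-contained sketch of the new history parsing introduced in the hunk above. The helper name and the sample prompt are illustrative only; the commit applies this logic inline inside chatglm_generate_stream, assuming a prompt separated by the stop token ("###") with "human:" / "ai:" role prefixes.

import copy

def parse_chatglm_prompt(prompt: str, stop: str = "###"):
    """Rebuild (history, query) from a flat prompt, mirroring the added lines above."""
    messages = prompt.split(stop)
    hist, once_conversation = [], []
    for message in messages:
        if len(message) <= 0:
            continue
        if "human:" in message:
            once_conversation.append(message.split("human:")[1])
        elif "ai:" in message:
            once_conversation.append(message.split("ai:")[1])
            hist.append(copy.deepcopy(once_conversation))  # one finished (human, ai) turn
            once_conversation = []
    query = messages[-1].split("human:")[1]  # the pending user question
    return hist, query

hist, query = parse_chatglm_prompt(
    "###human: hello ###ai: hi ###human: what is DB-GPT? ###ai: a DB assistant ###human: summarize that"
)
print(hist)   # [[' hello ', ' hi '], [' what is DB-GPT? ', ' a DB assistant ']]
print(query)  # ' summarize that'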
@@ -1,4 +1,5 @@
 import torch
+import copy
 from threading import Thread
 from transformers import TextIteratorStreamer, StoppingCriteriaList, StoppingCriteria
 from pilot.conversation import ROLE_ASSISTANT, ROLE_USER
@@ -8,19 +9,35 @@ def guanaco_generate_output(model, tokenizer, params, device, context_len=2048):
 
     print(params)
     stop = params.get("stop", "###")
-    messages = params["prompt"]
+    prompt = params["prompt"]
+    messages = prompt.split(stop)
+    #
+    # # Add history conversation
+    # hist = []
+    # once_conversation = []
+    # for message in messages[:-1]:
+    #     if len(message) <= 0:
+    #         continue
+    #
+    #     if "human:" in message:
+    #         once_conversation.append(f"""###system:{message.split("human:")[1]} """ )
+    #     elif "system:" in message:
+    #         once_conversation.append(f"""###system:{message.split("system:")[1]} """)
+    #     elif "ai:" in message:
+    #         once_conversation.append(f"""###system:{message.split("ai:")[1]} """)
+    #         last_conversation = copy.deepcopy(once_conversation)
+    #         hist.append("".join(last_conversation))
+    #         once_conversation = []
+    #     else:
+    #         once_conversation.append(f"""###system:{message} """)
+    #
+    #
+    #
+    #
+    #
+    # query = "".join(hist)
 
-    hist = []
-    for i in range(1, len(messages) - 2, 2):
-        hist.append(
-            (
-                messages[i].split(ROLE_USER + ":")[1],
-                messages[i + 1].split(ROLE_ASSISTANT + ":")[1],
-            )
-        )
-
-    query = messages[-2].split(ROLE_USER + ":")[1]
+    query = prompt
     print("Query Message: ", query)
 
     input_ids = tokenizer(query, return_tensors="pt").input_ids
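Net effect of the hunk above: the ROLE_USER / ROLE_ASSISTANT reconstruction is removed and the replacement history-building code is left commented out, so for guanaco the request prompt now reaches the tokenizer unchanged. A small sketch under that reading; the gpt2 tokenizer is only a stand-in for the actual guanaco tokenizer.

from transformers import AutoTokenizer

# Stand-in tokenizer; any Hugging Face tokenizer shows the shape of the call used above.
tokenizer = AutoTokenizer.from_pretrained("gpt2")

params = {"prompt": "###human: explain this change ###", "stop": "###"}

query = params["prompt"]  # new behaviour: forward the prompt as-is
input_ids = tokenizer(query, return_tensors="pt").input_ids
print(input_ids.shape)  # torch.Size([1, N]) for some token count N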
@@ -4,7 +4,7 @@
 from functools import cache
 from typing import List
 
-from pilot.model.inference import generate_stream
+from pilot.model.llm_out.vicuna_base_llm import generate_stream
 
 
 class BaseChatAdpter:
@@ -55,7 +55,7 @@ class ChatGLMChatAdapter(BaseChatAdpter):
         return "chatglm" in model_path
 
     def get_generate_stream_func(self):
-        from pilot.model.chatglm_llm import chatglm_generate_stream
+        from pilot.model.llm_out.chatglm_llm import chatglm_generate_stream
 
         return chatglm_generate_stream
 
@@ -91,7 +91,7 @@ class GuanacoChatAdapter(BaseChatAdpter):
         return "guanaco" in model_path
 
     def get_generate_stream_func(self):
-        from pilot.model.guanaco_llm import guanaco_generate_output
+        from pilot.model.llm_out.guanaco_llm import guanaco_generate_output
 
         return guanaco_generate_output
 
@@ -101,7 +101,7 @@ class ProxyllmChatAdapter(BaseChatAdpter):
         return "proxyllm" in model_path
 
     def get_generate_stream_func(self):
-        from pilot.model.proxy_llm import proxyllm_generate_stream
+        from pilot.model.llm_out.proxy_llm import proxyllm_generate_stream
 
         return proxyllm_generate_stream
 
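Every adapter above now imports its generate function from the new pilot.model.llm_out package. For orientation, a condensed sketch of the dispatch pattern these hunks modify; the match signature and the surrounding structure are inferred from the context lines of this diff, not copied from the full file.

class BaseChatAdpter:
    """Pick a model-specific generate_stream implementation by model path."""

    def match(self, model_path: str) -> bool:
        return False

    def get_generate_stream_func(self):
        # Default: the vicuna-style streamer, relocated to pilot.model.llm_out.
        from pilot.model.llm_out.vicuna_base_llm import generate_stream

        return generate_stream


class ChatGLMChatAdapter(BaseChatAdpter):
    def match(self, model_path: str) -> bool:
        return "chatglm" in model_path

    def get_generate_stream_func(self):
        # Deferred import so chatglm-only dependencies load only when this adapter is selected.
        from pilot.model.llm_out.chatglm_llm import chatglm_generate_stream

        return chatglm_generate_stream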
@@ -19,7 +19,7 @@ sys.path.append(ROOT_PATH)
 
 from pilot.configs.config import Config
 from pilot.configs.model_config import *
-from pilot.model.inference import generate_output, generate_stream, get_embeddings
+from pilot.model.llm_out.vicuna_base_llm import get_embeddings
 from pilot.model.loader import ModelLoader
 from pilot.server.chat_adapter import get_llm_chat_adapter
 
@@ -6,7 +6,7 @@ from langchain.prompts import PromptTemplate
 from pilot.configs.model_config import VECTOR_SEARCH_TOP_K
 from pilot.conversation import conv_qa_prompt_template, conv_db_summary_templates
 from pilot.logs import logger
-from pilot.model.vicuna_llm import VicunaLLM
+from pilot.model.llm_out.vicuna_llm import VicunaLLM
 from pilot.vector_store.file_loader import KnownLedge2Vector
 
 
@@ -8,7 +8,7 @@ from langchain.text_splitter import CharacterTextSplitter
 from langchain.vectorstores import Chroma
 
 from pilot.configs.model_config import DATASETS_DIR, VECTORE_PATH
-from pilot.model.vicuna_llm import VicunaEmbeddingLLM
+from pilot.model.llm_out.vicuna_llm import VicunaEmbeddingLLM
 
 embeddings = VicunaEmbeddingLLM()
 
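Taken together, the remaining hunks are pure import-path updates: callers that previously imported inference helpers from pilot.model.inference or pilot.model.vicuna_llm now import them from the new pilot.model.llm_out package. The combined import below merges the individual one-line changes shown above for illustration.

# Before this commit:
# from pilot.model.inference import generate_stream, get_embeddings
# from pilot.model.vicuna_llm import VicunaLLM, VicunaEmbeddingLLM

# After this commit:
from pilot.model.llm_out.vicuna_base_llm import generate_stream, get_embeddings
from pilot.model.llm_out.vicuna_llm import VicunaLLM, VicunaEmbeddingLLM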