diff --git a/examples/app.py b/examples/app.py
index 31b6e0f26..52086c20b 100644
--- a/examples/app.py
+++ b/examples/app.py
@@ -3,7 +3,6 @@
 
 import gradio as gr
 from langchain.agents import AgentType, initialize_agent, load_tools
-from langchain.embeddings.huggingface import HuggingFaceEmbeddings
 from llama_index import (
     Document,
     GPTSimpleVectorIndex,
@@ -12,7 +11,7 @@ from llama_index import (
     ServiceContext,
 )
 
-from pilot.model.vicuna_llm import VicunaEmbeddingLLM, VicunaRequestLLM
+from pilot.model.llm_out.vicuna_llm import VicunaEmbeddingLLM, VicunaRequestLLM
 
 
 def agent_demo():
@@ -49,7 +48,7 @@ def get_answer(q):
 
 
 def get_similar(q):
-    from pilot.vector_store.extract_tovec import knownledge_tovec, knownledge_tovec_st
+    from pilot.vector_store.extract_tovec import knownledge_tovec_st
 
     docsearch = knownledge_tovec_st("./datasets/plan.md")
     docs = docsearch.similarity_search_with_score(q, k=1)
diff --git a/pilot/model/llm_out/__init__.py b/pilot/model/llm_out/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pilot/model/chatglm_llm.py b/pilot/model/llm_out/chatglm_llm.py
similarity index 52%
rename from pilot/model/chatglm_llm.py
rename to pilot/model/llm_out/chatglm_llm.py
index 4d72af072..b451e910c 100644
--- a/pilot/model/chatglm_llm.py
+++ b/pilot/model/llm_out/chatglm_llm.py
@@ -1,5 +1,6 @@
 #!/usr/bin/env python3
 # -*- coding:utf-8 -*-
+import copy
 
 import torch
 
@@ -8,7 +9,7 @@ from pilot.conversation import ROLE_ASSISTANT, ROLE_USER
 
 @torch.inference_mode()
 def chatglm_generate_stream(
-    model, tokenizer, params, device, context_len=2048, stream_interval=2
+    model, tokenizer, params, device, context_len=2048, stream_interval=2
 ):
     """Generate text using chatglm model's chat api"""
     prompt = params["prompt"]
@@ -28,25 +29,34 @@
         generate_kwargs["temperature"] = temperature
 
     # TODO, Fix this
-    hist = []
-    messages = prompt.split(stop)
-
+    messages = prompt.split(stop)
+    #
+    # # Add history conversation
+    hist = []
+    once_conversation = []
+    for message in messages:
+        if len(message) <= 0:
+            continue
 
-    # Add history chat to hist for model.
-    for i in range(1, len(messages) - 2, 2):
-        hist.append(
-            (
-                messages[i].split(ROLE_USER + ":")[1],
-                messages[i + 1].split(ROLE_ASSISTANT + ":")[1],
-            )
-        )
+        if "human:" in message:
+            once_conversation.append(message.split("human:")[1])
+        # elif "system:" in message:
+        #     once_conversation.append(f"""###system:{message.split("system:")[1]} """)
+        elif "ai:" in message:
+            once_conversation.append(message.split("ai:")[1])
+            last_conversation = copy.deepcopy(once_conversation)
+            hist.append(last_conversation)
+            once_conversation = []
+        # else:
+        #     once_conversation.append(f"""###system:{message} """)
 
-    query = messages[-2].split(ROLE_USER + ":")[1]
+    query = messages[-1].split("human:")[1]
     print("Query Message: ", query)
-    output = ""
-    i = 0
+    # output = ""
+    # i = 0
+
     for i, (response, new_hist) in enumerate(
-        model.stream_chat(tokenizer, query, hist, **generate_kwargs)
+        model.stream_chat(tokenizer, query, hist, **generate_kwargs)
     ):
         if echo:
             output = query + " " + response
diff --git a/pilot/model/guanaco_llm.py b/pilot/model/llm_out/guanaco_llm.py
similarity index 64%
rename from pilot/model/guanaco_llm.py
rename to pilot/model/llm_out/guanaco_llm.py
index 03f2d1687..6c209b565 100644
--- a/pilot/model/guanaco_llm.py
+++ b/pilot/model/llm_out/guanaco_llm.py
@@ -1,4 +1,5 @@
 import torch
+import copy
 from threading import Thread
 from transformers import TextIteratorStreamer, StoppingCriteriaList, StoppingCriteria
 from pilot.conversation import ROLE_ASSISTANT, ROLE_USER
@@ -8,19 +9,35 @@
 def guanaco_generate_output(model, tokenizer, params, device, context_len=2048):
     print(params)
     stop = params.get("stop", "###")
-    messages = params["prompt"]
+    prompt = params["prompt"]
+    messages = prompt.split(stop)
+    #
+    # # Add history conversation
+    # hist = []
+    # once_conversation = []
+    # for message in messages[:-1]:
+    #     if len(message) <= 0:
+    #         continue
+    #
+    #     if "human:" in message:
+    #         once_conversation.append(f"""###system:{message.split("human:")[1]} """ )
+    #     elif "system:" in message:
+    #         once_conversation.append(f"""###system:{message.split("system:")[1]} """)
+    #     elif "ai:" in message:
+    #         once_conversation.append(f"""###system:{message.split("ai:")[1]} """)
+    #         last_conversation = copy.deepcopy(once_conversation)
+    #         hist.append("".join(last_conversation))
+    #         once_conversation = []
+    #     else:
+    #         once_conversation.append(f"""###system:{message} """)
+    #
+    #
+    #
+    #
+    #
+    # query = "".join(hist)
 
-
-    hist = []
-    for i in range(1, len(messages) - 2, 2):
-        hist.append(
-            (
-                messages[i].split(ROLE_USER + ":")[1],
-                messages[i + 1].split(ROLE_ASSISTANT + ":")[1],
-            )
-        )
-
-    query = messages[-2].split(ROLE_USER + ":")[1]
+    query = prompt
     print("Query Message: ", query)
 
     input_ids = tokenizer(query, return_tensors="pt").input_ids
diff --git a/pilot/model/proxy_llm.py b/pilot/model/llm_out/proxy_llm.py
similarity index 100%
rename from pilot/model/proxy_llm.py
rename to pilot/model/llm_out/proxy_llm.py
diff --git a/pilot/model/inference.py b/pilot/model/llm_out/vicuna_base_llm.py
similarity index 100%
rename from pilot/model/inference.py
rename to pilot/model/llm_out/vicuna_base_llm.py
diff --git a/pilot/model/vicuna_llm.py b/pilot/model/llm_out/vicuna_llm.py
similarity index 100%
rename from pilot/model/vicuna_llm.py
rename to pilot/model/llm_out/vicuna_llm.py
diff --git a/pilot/server/chat_adapter.py b/pilot/server/chat_adapter.py
index 39737112b..86901dea3 100644
--- a/pilot/server/chat_adapter.py
+++ b/pilot/server/chat_adapter.py
@@ -4,7 +4,7 @@
 from functools import cache
 from typing import List
 
-from pilot.model.inference import generate_stream
+from pilot.model.llm_out.vicuna_base_llm import generate_stream
 
 
 class BaseChatAdpter:
@@ -55,7 +55,7 @@ class ChatGLMChatAdapter(BaseChatAdpter):
         return "chatglm" in model_path
 
     def get_generate_stream_func(self):
-        from pilot.model.chatglm_llm import chatglm_generate_stream
+        from pilot.model.llm_out.chatglm_llm import chatglm_generate_stream
 
         return chatglm_generate_stream
 
@@ -91,7 +91,7 @@ class GuanacoChatAdapter(BaseChatAdpter):
         return "guanaco" in model_path
 
     def get_generate_stream_func(self):
-        from pilot.model.guanaco_llm import guanaco_generate_output
+        from pilot.model.llm_out.guanaco_llm import guanaco_generate_output
 
         return guanaco_generate_output
 
@@ -101,7 +101,7 @@ class ProxyllmChatAdapter(BaseChatAdpter):
         return "proxyllm" in model_path
 
     def get_generate_stream_func(self):
-        from pilot.model.proxy_llm import proxyllm_generate_stream
+        from pilot.model.llm_out.proxy_llm import proxyllm_generate_stream
 
         return proxyllm_generate_stream
diff --git a/pilot/server/llmserver.py b/pilot/server/llmserver.py
index 376e27852..d2730e0d5 100644
--- a/pilot/server/llmserver.py
+++ b/pilot/server/llmserver.py
@@ -19,7 +19,7 @@ sys.path.append(ROOT_PATH)
 
 from pilot.configs.config import Config
 from pilot.configs.model_config import *
-from pilot.model.inference import generate_output, generate_stream, get_embeddings
+from pilot.model.llm_out.vicuna_base_llm import get_embeddings
 from pilot.model.loader import ModelLoader
 from pilot.server.chat_adapter import get_llm_chat_adapter
 
diff --git a/pilot/server/vectordb_qa.py b/pilot/server/vectordb_qa.py
index ff794322f..9faae5eb8 100644
--- a/pilot/server/vectordb_qa.py
+++ b/pilot/server/vectordb_qa.py
@@ -6,7 +6,7 @@ from langchain.prompts import PromptTemplate
 
 from pilot.configs.model_config import VECTOR_SEARCH_TOP_K
 from pilot.conversation import conv_qa_prompt_template, conv_db_summary_templates
 from pilot.logs import logger
-from pilot.model.vicuna_llm import VicunaLLM
+from pilot.model.llm_out.vicuna_llm import VicunaLLM
 from pilot.vector_store.file_loader import KnownLedge2Vector
 
diff --git a/pilot/vector_store/extract_tovec.py b/pilot/vector_store/extract_tovec.py
index 1032876cf..a79960477 100644
--- a/pilot/vector_store/extract_tovec.py
+++ b/pilot/vector_store/extract_tovec.py
@@ -8,7 +8,7 @@
 from langchain.text_splitter import CharacterTextSplitter
 from langchain.vectorstores import Chroma
 
 from pilot.configs.model_config import DATASETS_DIR, VECTORE_PATH
-from pilot.model.vicuna_llm import VicunaEmbeddingLLM
+from pilot.model.llm_out.vicuna_llm import VicunaEmbeddingLLM
 
 embeddings = VicunaEmbeddingLLM()
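
Reviewer notes (appended commentary, not part of the patch):

The patch moves every generation/embedding backend out of pilot/model/ and into the new pilot/model/llm_out/ package (chatglm_llm.py, guanaco_llm.py, proxy_llm.py, vicuna_llm.py, plus inference.py, which is renamed to vicuna_base_llm.py), then updates every import site. It also drops imports that were already unused (HuggingFaceEmbeddings and knownledge_tovec in examples/app.py; generate_output and generate_stream in llmserver.py, where the chat adapters now resolve the per-model stream function instead). A minimal sketch of the import paths before and after, under the layout this patch introduces:

    # Before this patch:
    #   from pilot.model.inference import generate_stream
    #   from pilot.model.chatglm_llm import chatglm_generate_stream
    # After this patch:
    from pilot.model.llm_out.vicuna_base_llm import generate_stream
    from pilot.model.llm_out.chatglm_llm import chatglm_generate_stream
    from pilot.model.llm_out.guanaco_llm import guanaco_generate_output
    from pilot.model.llm_out.proxy_llm import proxyllm_generate_stream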
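
The rewritten chatglm_generate_stream no longer pairs prompt segments by position into (USER, ASSISTANT) tuples; it scans the segments for "human:"/"ai:" labels and closes one history round per "ai:" turn, then treats the trailing segment as the pending question. A standalone sketch of that parsing on a hypothetical prompt (the sample text is invented; in the server the prompt arrives as params["prompt"] and the separator as params.get("stop", "###")):

    import copy

    stop = "###"
    prompt = (
        "A curious user chats with an assistant."
        "###human: What is DB-GPT?"
        "###ai: A local, private LLM framework."
        "###human: How do I run it?"
    )

    messages = prompt.split(stop)

    hist = []
    once_conversation = []
    for message in messages:
        if len(message) <= 0:
            continue
        if "human:" in message:
            once_conversation.append(message.split("human:")[1])
        elif "ai:" in message:
            once_conversation.append(message.split("ai:")[1])
            # An "ai:" segment closes one round; snapshot it into the history.
            last_conversation = copy.deepcopy(once_conversation)
            hist.append(last_conversation)
            once_conversation = []

    # The trailing segment is the pending user question.
    query = messages[-1].split("human:")[1]

    print(hist)   # [[' What is DB-GPT?', ' A local, private LLM framework.']]
    print(query)  # " How do I run it?" (leading space preserved)

Each completed round is a two-element list that unpacks like the (query, response) pairs ChatGLM's stream_chat iterates over; note that consecutive "human:" segments with no "ai:" reply in between would accumulate into a single longer round.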
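
guanaco_generate_output, by contrast, now tokenizes the full prompt directly (query = prompt) and keeps its history reconstruction only as commented-out code. Judging from the Thread and TextIteratorStreamer imports at the top of the file, the surrounding machinery follows the standard transformers background-thread streaming pattern; a rough sketch under that assumption (function name and parameter values here are illustrative, not taken from the patch):

    from threading import Thread
    from transformers import TextIteratorStreamer

    def stream_generate(model, tokenizer, prompt, max_new_tokens=512):
        # Tokenize the raw prompt, as guanaco_generate_output now does.
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids
        # The streamer yields decoded text while generate() produces
        # tokens on a background thread.
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
        thread = Thread(
            target=model.generate,
            kwargs=dict(
                input_ids=input_ids,
                streamer=streamer,
                max_new_tokens=max_new_tokens,
            ),
        )
        thread.start()
        partial = ""
        for new_text in streamer:
            partial += new_text
            yield partial  # emit the accumulated output so far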