diff --git a/.gitignore b/.gitignore
index 7139a8d9d..5e74f2378 100644
--- a/.gitignore
+++ b/.gitignore
@@ -131,4 +131,8 @@ dmypy.json
 .pyre/
 .DS_Store
 logs
+<<<<<<< HEAD
+=======
+
+>>>>>>> 2eeb22ccae5b2a5d695ac267044129f3b672c3dc
 .vectordb
\ No newline at end of file
diff --git a/README.md b/README.md
index e40174b68..e43d956c0 100644
--- a/README.md
+++ b/README.md
@@ -1,8 +1,10 @@
 # DB-GPT
-A Open Database-GPT Experiment
+An Open Database-GPT Experiment: a fully localized project.
 
 ![GitHub Repo stars](https://img.shields.io/github/stars/csunny/db-gpt?style=social)
 
+A database-oriented GPT experiment project. The model and the data are deployed entirely locally, fully safeguarding data privacy and security. The project can also be deployed on-premises and connected directly to a private database to work with private data.
+
 [DB-GPT](https://github.com/csunny/DB-GPT) 是一个实验性的开源应用程序，它基于[FastChat](https://github.com/lm-sys/FastChat)，并使用[vicuna-13b](https://huggingface.co/Tribbiani/vicuna-13b)作为基础模型。此外，此程序结合了[langchain](https://github.com/hwchase17/langchain)和[llama-index](https://github.com/jerryjliu/llama_index)基于现有知识库进行[In-Context Learning](https://arxiv.org/abs/2301.00234)来对其进行数据库相关知识的增强。它可以进行SQL生成、SQL诊断、数据库知识问答等一系列的工作。
 
 
@@ -30,6 +32,26 @@ Run on an RTX 4090 GPU (The origin mov not sped up!, [YouTube地址](https://www
 
 
 
+# Dependencies
+1. First, install the Python requirements.
+```
+python>=3.9
+pip install -r requirements.txt
+```
+Or, if you use a conda environment, you can use this command:
+```
+cd DB-GPT
+conda env create -f environment.yml
+```
+
+2. MySQL Install
+
+The examples in this project connect to MySQL and run SQL generation, so you need a local MySQL instance for testing. Docker is recommended:
+```
+docker run --name=mysql -p 3306:3306 -e MYSQL_ROOT_PASSWORD=aa123456 -dit mysql:latest
+```
+This password is for testing only; change it if necessary.
+
 # Install
 1. 基础模型下载
 关于基础模型, 可以根据[vicuna](https://github.com/lm-sys/FastChat/blob/main/README.md#model-weights)合成教程进行合成。
@@ -52,3 +74,7 @@ python webserver.py
 - SQL-diagnosis
 
 总的来说，它是一个用于数据库的复杂且创新的AI工具。如果您对如何在工作中使用或实施DB-GPT有任何具体问题, 请联系我, 我会尽力提供帮助, 同时也欢迎大家参与到项目建设中, 做一些有趣的事情。
+
+
+# Licence
+[MIT](https://github.com/csunny/DB-GPT/blob/main/LICENSE)
\ No newline at end of file
diff --git a/asserts/exeable.png b/asserts/exeable.png
new file mode 100644
index 000000000..47ee94f7d
Binary files /dev/null and b/asserts/exeable.png differ
diff --git a/environment.yml b/environment.yml
index 3ba070e56..ea7415df0 100644
--- a/environment.yml
+++ b/environment.yml
@@ -61,3 +61,7 @@ dependencies:
     - gradio-client==0.0.8
     - wandb
     - fschat=0.1.10
+    - llama-index=0.5.27
+    - pymysql
+    - unstructured==0.6.3
+    - pytesseract==0.3.10
diff --git a/pilot/chain/audio.py b/pilot/chain/audio.py
new file mode 100644
index 000000000..8b197119c
--- /dev/null
+++ b/pilot/chain/audio.py
@@ -0,0 +1,2 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
\ No newline at end of file
diff --git a/pilot/chain/visual.py b/pilot/chain/visual.py
new file mode 100644
index 000000000..1f776fc63
--- /dev/null
+++ b/pilot/chain/visual.py
@@ -0,0 +1,2 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
\ No newline at end of file
diff --git a/pilot/server/chatbot.py b/pilot/client/auto.py
similarity index 50%
rename from pilot/server/chatbot.py
rename to pilot/client/auto.py
index 97206f2d5..78477809e 100644
--- a/pilot/server/chatbot.py
+++ b/pilot/client/auto.py
@@ -1,3 +1,3 @@
 #!/usr/bin/env python3
-# -*- coding:utf-8 -*-
+# -*- coding: utf-8 -*-
 
diff --git a/pilot/client/chart.py b/pilot/client/chart.py
new file mode 100644
index 000000000..6988cfe11
--- /dev/null
+++ b/pilot/client/chart.py
@@ -0,0 +1,2 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
\ No newline at end of file
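The Dependencies section of the README above assumes a reachable MySQL instance. As a quick, hypothetical smoke test (not part of this patch), the container it describes can be checked with the `pymysql` driver that this change adds to the requirements; host, port, and password below simply mirror the README's `docker run` command.
```
# Hypothetical connectivity check for the MySQL container from the README
# (root / aa123456 on localhost:3306); pymysql is added by this patch.
import pymysql

conn = pymysql.connect(host="127.0.0.1", port=3306, user="root", password="aa123456")
try:
    with conn.cursor() as cursor:
        cursor.execute("SHOW DATABASES")
        for (name,) in cursor.fetchall():
            print(name)
finally:
    conn.close()
```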
diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py
index ca25b3224..a5c27d9d2 100644
--- a/pilot/configs/model_config.py
+++ b/pilot/configs/model_config.py
@@ -11,6 +11,7 @@ PILOT_PATH = os.path.join(ROOT_PATH, "pilot")
 VECTORE_PATH = os.path.join(PILOT_PATH, "vector_store")
 LOGDIR = os.path.join(ROOT_PATH, "logs")
 DATASETS_DIR = os.path.join(PILOT_PATH, "datasets")
+DATA_DIR = os.path.join(PILOT_PATH, "data")
 
 nltk.data.path = [os.path.join(PILOT_PATH, "nltk_data")] + nltk.data.path
 
@@ -22,6 +23,7 @@ LLM_MODEL_CONFIG = {
 }
 
+VECTOR_SEARCH_TOP_K = 3
 LLM_MODEL = "vicuna-13b"
 LIMIT_MODEL_CONCURRENCY = 5
 MAX_POSITION_EMBEDDINGS = 2048
diff --git a/pilot/conversation.py b/pilot/conversation.py
index e88ceaccb..2dc8df2b9 100644
--- a/pilot/conversation.py
+++ b/pilot/conversation.py
@@ -146,8 +146,24 @@ conv_vicuna_v1 = Conversation(
     sep2="",
 )
 
+
+conv_qa_prompt_template = """ 基于以下已知的信息, 专业、详细的回答用户的问题。
+    如果无法从提供的内容中获取答案, 请说: "知识库中提供的内容不足以回答此问题", 但是你可以给出一些与问题相关答案的建议:
+
+    已知内容:
+    {context}
+    问题:
+    {question}
+"""
+
 default_conversation = conv_one_shot
 
+conversation_types = {
+    "native": "LLM原生对话",
+    "default_knownledge": "默认知识库对话",
+    "custome": "新增知识库对话",
+}
+
 conv_templates = {
     "conv_one_shot": conv_one_shot,
     "vicuna_v1": conv_vicuna_v1,
diff --git a/pilot/model/vicuna_llm.py b/pilot/model/vicuna_llm.py
index f17a17a00..2337a3bbf 100644
--- a/pilot/model/vicuna_llm.py
+++ b/pilot/model/vicuna_llm.py
@@ -10,33 +10,29 @@ from typing import Any, Mapping, Optional, List
 from langchain.llms.base import LLM
 from pilot.configs.model_config import *
 
-class VicunaRequestLLM(LLM):
+class VicunaLLM(LLM):
+
+    vicuna_generate_path = "generate_stream"
+    def _call(self, prompt: str, temperature: float, max_new_tokens: int, stop: Optional[List[str]] = None) -> str:
 
-    vicuna_generate_path = "generate"
-    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
-        if isinstance(stop, list):
-            stop = stop + ["Observation:"]
-
-        skip_echo_len = len(prompt.replace("</s>", " ")) + 1
         params = {
             "prompt": prompt,
-            "temperature": 0.7,
-            "max_new_tokens": 1024,
+            "temperature": temperature,
+            "max_new_tokens": max_new_tokens,
             "stop": stop
         }
         response = requests.post(
             url=urljoin(VICUNA_MODEL_SERVER, self.vicuna_generate_path),
             data=json.dumps(params),
         )
-        response.raise_for_status()
-        # for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
-        #     if chunk:
-        #         data = json.loads(chunk.decode())
-        #         if data["error_code"] == 0:
-        #             output = data["text"][skip_echo_len:].strip()
-        #             output = self.post_process_code(output)
-        #             yield output
-        return response.json()["response"]
+
+        skip_echo_len = len(params["prompt"]) + 1 - params["prompt"].count("</s>") * 3
+        for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
+            if chunk:
+                data = json.loads(chunk.decode())
+                if data["error_code"] == 0:
+                    output = data["text"][skip_echo_len:].strip()
+                    yield output
 
     @property
     def _llm_type(self) -> str:
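A minimal usage sketch (not part of the patch) of the reworked `VicunaLLM`: because `_call` is now written as a generator that streams partial results from the model server, a caller iterates over it directly rather than going through the LangChain `__call__` wrapper. The prompt text below is only a placeholder.
```
# Hypothetical sketch, assuming the vicuna-13b model server is reachable at
# VICUNA_MODEL_SERVER: stream an answer through the reworked VicunaLLM.
from pilot.model.vicuna_llm import VicunaLLM

llm = VicunaLLM()

answer = ""
for output in llm._call(
    prompt="A chat between a human and an assistant. Human: What is a database index? Assistant:",
    temperature=0.7,
    max_new_tokens=512,
    stop=["###"],
):
    answer = output  # each yielded chunk is the full, echo-stripped answer so far
print(answer)
```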
diff --git a/pilot/pturning/lora/finetune.py b/pilot/pturning/lora/finetune.py
new file mode 100644
index 000000000..6cd9935ed
--- /dev/null
+++ b/pilot/pturning/lora/finetune.py
@@ -0,0 +1,180 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+import os
+import json
+import transformers
+from transformers import LlamaTokenizer, LlamaForCausalLM
+
+from typing import List
+from peft import (
+    LoraConfig,
+    get_peft_model,
+    get_peft_model_state_dict,
+    prepare_model_for_int8_training,
+)
+
+import torch
+from datasets import load_dataset
+import pandas as pd
+
+
+from pilot.configs.model_config import DATA_DIR, LLM_MODEL, LLM_MODEL_CONFIG
+device = "cuda" if torch.cuda.is_available() else "cpu"
+CUTOFF_LEN = 50
+
+df = pd.read_csv(os.path.join(DATA_DIR, "BTC_Tweets_Updated.csv"))
+
+def sentiment_score_to_name(score: float):
+    if score > 0:
+        return "Positive"
+    elif score < 0:
+        return "Negative"
+    return "Neutral"
+
+
+dataset_data = [
+    {
+        "instruction": "Detect the sentiment of the tweet.",
+        "input": row_dict["Tweet"],
+        "output": sentiment_score_to_name(row_dict["New_Sentiment_State"])
+    }
+    for row_dict in df.to_dict(orient="records")
+]
+
+with open(os.path.join(DATA_DIR, "alpaca-bitcoin-sentiment-dataset.json"), "w") as f:
+    json.dump(dataset_data, f)
+
+
+data = load_dataset("json", data_files=os.path.join(DATA_DIR, "alpaca-bitcoin-sentiment-dataset.json"))
+print(data["train"])
+
+BASE_MODEL = LLM_MODEL_CONFIG[LLM_MODEL]
+model = LlamaForCausalLM.from_pretrained(
+    BASE_MODEL,
+    torch_dtype=torch.float16,
+    device_map="auto",
+    offload_folder=os.path.join(DATA_DIR, "vicuna-lora")
+)
+
+tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL)
+tokenizer.pad_token_id = (0)
+tokenizer.padding_side = "left"
+
+def generate_prompt(data_point):
+    return f"""Below is an instruction that describes a task, paired with an input that provides further context.
+    Write a response that appropriately completes the request.  #noqa:
+
+    ### Instruct:
+    {data_point["instruction"]}
+    ### Input
+    {data_point["input"]}
+    ### Response
+    {data_point["output"]}
+    """
+
+def tokenize(prompt, add_eos_token=True):
+    result = tokenizer(
+        prompt,
+        truncation=True,
+        max_length=CUTOFF_LEN,
+        padding=False,
+        return_tensors=None,
+    )
+
+    if (result["input_ids"][-1] != tokenizer.eos_token_id and len(result["input_ids"]) < CUTOFF_LEN and add_eos_token):
+        result["input_ids"].append(tokenizer.eos_token_id)
+        result["attention_mask"].append(1)
+
+    result["labels"] = result["input_ids"].copy()
+    return result
+
+def generate_and_tokenize_prompt(data_point):
+    full_prompt = generate_prompt(data_point)
+    tokenized_full_prompt = tokenize(full_prompt)
+    return tokenized_full_prompt
+
+
+train_val = data["train"].train_test_split(
+    test_size=200, shuffle=True, seed=42
+)
+
+train_data = (
+    train_val["train"].map(generate_and_tokenize_prompt)
+)
+
+val_data = (
+    train_val["test"].map(generate_and_tokenize_prompt)
+)
+
+# Training
+LORA_R = 8
+LORA_ALPHA = 16
+LORA_DROPOUT = 0.05
+LORA_TARGET_MODULES = [
+    "q_proj",
+    "v_proj",
+]
+
+BATCH_SIZE = 128
+MICRO_BATCH_SIZE = 4
+GRADIENT_ACCUMULATION_STEPS = BATCH_SIZE // MICRO_BATCH_SIZE
+LEARNING_RATE = 3e-4
+TRAIN_STEPS = 300
+OUTPUT_DIR = "experiments"
+
+# We can now prepare the model for training
+model = prepare_model_for_int8_training(model)
+config = LoraConfig(
+    r=LORA_R,
+    lora_alpha=LORA_ALPHA,
+    target_modules=LORA_TARGET_MODULES,
+    lora_dropout=LORA_DROPOUT,
+    bias="none",
+    task_type="CAUSAL_LM",
+)
+
+model = get_peft_model(model, config)
+model.print_trainable_parameters()
+
+training_arguments = transformers.TrainingArguments(
+    per_device_train_batch_size=MICRO_BATCH_SIZE,
+    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
+    warmup_steps=100,
+    max_steps=TRAIN_STEPS,
+    no_cuda=True,
+    learning_rate=LEARNING_RATE,
+    logging_steps=10,
+    optim="adamw_torch",
+    evaluation_strategy="steps",
+    save_strategy="steps",
+    eval_steps=50,
+    save_steps=50,
+    output_dir=OUTPUT_DIR,
+    save_total_limit=3,
+    load_best_model_at_end=True,
+    report_to="tensorboard"
+)
+
+data_collator = transformers.DataCollatorForSeq2Seq(
+    tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
+)
+
+trainer = transformers.Trainer(
+    model=model,
+    train_dataset=train_data,
+    eval_dataset=val_data,
+    args=training_arguments,
+    data_collator=data_collator
+)
+
+model.config.use_cache = False
+old_state_dict = model.state_dict
+model.state_dict = (
+    lambda self, *_, **__: get_peft_model_state_dict(
+        self, old_state_dict()
+    )
+).__get__(model, type(model))
+
+trainer.train()
+model.save_pretrained(OUTPUT_DIR)
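For context, a hypothetical sketch (not part of this patch) of loading the adapter that `finetune.py` saves to its `experiments` output directory back onto the base model for inference. The `LLM_MODEL_CONFIG` lookup and prompt headings mirror the script above; the tweet text is a made-up placeholder.
```
# Hypothetical sketch: reload the LoRA adapter produced by finetune.py
# (model.save_pretrained(OUTPUT_DIR) with OUTPUT_DIR = "experiments").
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer
from peft import PeftModel

from pilot.configs.model_config import LLM_MODEL, LLM_MODEL_CONFIG

BASE_MODEL = LLM_MODEL_CONFIG[LLM_MODEL]
tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL)
model = LlamaForCausalLM.from_pretrained(BASE_MODEL, torch_dtype=torch.float16, device_map="auto")
model = PeftModel.from_pretrained(model, "experiments")  # attach the fine-tuned LoRA weights
model.eval()

prompt = (
    "Below is an instruction that describes a task, paired with an input that provides further context.\n"
    "Write a response that appropriately completes the request.\n\n"
    "### Instruct:\nDetect the sentiment of the tweet.\n"
    "### Input\nBitcoin is rallying today!\n"
    "### Response\n"
)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=8)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```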
diff --git a/pilot/server/embdserver.py b/pilot/server/embdserver.py
index 5e0ad9294..6599a18ad 100644
--- a/pilot/server/embdserver.py
+++ b/pilot/server/embdserver.py
@@ -4,29 +4,44 @@
 import requests
 import json
 import time
+import uuid
 from urllib.parse import urljoin
 
 import gradio as gr
 from pilot.configs.model_config import *
 
-vicuna_base_uri = "http://192.168.31.114:21002/"
-vicuna_stream_path = "worker_generate_stream"
-vicuna_status_path = "worker_get_status"
+from pilot.conversation import conv_qa_prompt_template, conv_templates
+from langchain.prompts import PromptTemplate
 
-def generate(prompt):
+vicuna_stream_path = "generate_stream"
+
+def generate(query):
+
+    template_name = "conv_one_shot"
+    state = conv_templates[template_name].copy()
+
+    pt = PromptTemplate(
+        template=conv_qa_prompt_template,
+        input_variables=["context", "question"]
+    )
+
+    result = pt.format(context="This page covers how to use the Chroma ecosystem within LangChain. It is broken into two parts: installation and setup, and then references to specific Chroma wrappers.",
+                       question=query)
+
+    print(result)
+
+    state.append_message(state.roles[0], result)
+    state.append_message(state.roles[1], None)
+
+    prompt = state.get_prompt()
     params = {
         "model": "vicuna-13b",
         "prompt": prompt,
         "temperature": 0.7,
-        "max_new_tokens": 512,
+        "max_new_tokens": 1024,
         "stop": "###"
     }
 
-    sts_response = requests.post(
-        url=urljoin(vicuna_base_uri, vicuna_status_path)
-    )
-    print(sts_response.text)
-
     response = requests.post(
-        url=urljoin(vicuna_base_uri, vicuna_stream_path), data=json.dumps(params)
+        url=urljoin(VICUNA_MODEL_SERVER, vicuna_stream_path), data=json.dumps(params)
     )
 
     skip_echo_len = len(params["prompt"]) + 1 - params["prompt"].count("</s>") * 3
@@ -34,11 +49,10 @@ def generate(prompt):
         if chunk:
             data = json.loads(chunk.decode())
             if data["error_code"] == 0:
-                output = data["text"]
+                output = data["text"][skip_echo_len:].strip()
+                state.messages[-1][-1] = output + "▌"
                 yield(output)
-
-            time.sleep(0.02)
-
+
 if __name__ == "__main__":
     print(LLM_MODEL)
     with gr.Blocks() as demo:
diff --git a/pilot/server/vectordb_qa.py b/pilot/server/vectordb_qa.py
new file mode 100644
index 000000000..71a9b881d
--- /dev/null
+++ b/pilot/server/vectordb_qa.py
@@ -0,0 +1,29 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+from pilot.vector_store.file_loader import KnownLedge2Vector
+from langchain.prompts import PromptTemplate
+from pilot.conversation import conv_qa_prompt_template
+from pilot.configs.model_config import VECTOR_SEARCH_TOP_K
+from pilot.model.vicuna_llm import VicunaLLM
+
+class KnownLedgeBaseQA:
+
+    def __init__(self) -> None:
+        k2v = KnownLedge2Vector()
+        self.vector_store = k2v.init_vector_store()
+        self.llm = VicunaLLM()
+
+    def get_similar_answer(self, query):
+
+        prompt = PromptTemplate(
+            template=conv_qa_prompt_template,
+            input_variables=["context", "question"]
+        )
+
+        retriever = self.vector_store.as_retriever(search_kwargs={"k": VECTOR_SEARCH_TOP_K})
+        docs = retriever.get_relevant_documents(query=query)
+
+        context = [d.page_content for d in docs]
+        result = prompt.format(context="\n".join(context), question=query)
+        return result
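A brief, hypothetical sketch (not part of the patch) of what `KnownLedgeBaseQA` produces: `get_similar_answer` retrieves the top-k documents from the Chroma store and fills `conv_qa_prompt_template`, and `http_bot` in `webserver.py` below then substitutes this grounded prompt back into the conversation. The question string is only a placeholder.
```
# Hypothetical sketch: build a knowledge-grounded prompt for one question,
# the same step webserver.py performs for the default knowledge-base mode.
from pilot.server.vectordb_qa import KnownLedgeBaseQA

qa = KnownLedgeBaseQA()                 # loads the persisted Chroma store via KnownLedge2Vector
query = "什么是OceanBase"                # placeholder question, same one file_loader.py uses
grounded_prompt = qa.get_similar_answer(query)

# conv_qa_prompt_template rendered with the retrieved context and the question
print(grounded_prompt)
```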
diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py
index e89d205fa..b31bfde7f 100644
--- a/pilot/server/webserver.py
+++ b/pilot/server/webserver.py
@@ -11,6 +11,7 @@ import datetime
 import requests
 from urllib.parse import urljoin
 from pilot.configs.model_config import DB_SETTINGS
+from pilot.server.vectordb_qa import KnownLedgeBaseQA
 from pilot.connections.mysql_conn import MySQLOperator
 from pilot.vector_store.extract_tovec import get_vector_storelist, load_knownledge_from_doc, knownledge_tovec_st
 
@@ -19,6 +20,7 @@ from pilot.configs.model_config import LOGDIR, VICUNA_MODEL_SERVER, LLM_MODEL, D
 from pilot.conversation import (
     default_conversation,
     conv_templates,
+    conversation_types,
     SeparatorStyle
 )
 
@@ -149,7 +151,7 @@ def post_process_code(code):
     code = sep.join(blocks)
     return code
 
-def http_bot(state, db_selector, temperature, max_new_tokens, request: gr.Request):
+def http_bot(state, mode, db_selector, temperature, max_new_tokens, request: gr.Request):
     start_tstamp = time.time()
     model_name = LLM_MODEL
 
@@ -170,7 +172,8 @@ def http_bot(state, db_selector, temperature, max_new_tokens, request: gr.Reques
 
     query = state.messages[-2][1]
 
-    # prompt 中添加上下文提示
+    # Add a context hint to the prompt when chatting over existing knowledge. Should the hint go only into the first turn, or into every turn?
+    # If the user's questions range widely across topics, the hint should be added on every turn.
     if db_selector:
         new_state.append_message(new_state.roles[0], gen_sqlgen_conversation(dbname) + query)
         new_state.append_message(new_state.roles[1], None)
@@ -180,13 +183,11 @@ def http_bot(state, db_selector, temperature, max_new_tokens, request: gr.Reques
         new_state.append_message(new_state.roles[1], None)
     state = new_state
 
-    # try:
-    #     if not db_selector:
-    #         sim_q = get_simlar(query)
-    #         print("********vector similar info*************: ", sim_q)
-    #         state.append_message(new_state.roles[0], sim_q + query)
-    # except Exception as e:
-    #     print(e)
+    if mode == conversation_types["default_knownledge"] and not db_selector:
+        query = state.messages[-2][1]
+        knqa = KnownLedgeBaseQA()
+        state.messages[-2][1] = knqa.get_similar_answer(query)
+
 
     prompt = state.get_prompt()
 
@@ -222,7 +223,7 @@ def http_bot(state, db_selector, temperature, max_new_tokens, request: gr.Reques
                     state.messages[-1][-1] = output
                     yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
                     return
-                time.sleep(0.02)
+
     except requests.exceptions.RequestException as e:
         state.messages[-1][-1] = server_error_msg + f" (error_code: 4)"
         yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
@@ -231,6 +232,7 @@ def http_bot(state, db_selector, temperature, max_new_tokens, request: gr.Reques
     state.messages[-1][-1] = state.messages[-1][-1][:-1]
     yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
 
+    # Log this run
     finish_tstamp = time.time()
     logger.info(f"{output}")
 
@@ -266,7 +268,7 @@ def change_tab(tab):
     pass
 
 def change_mode(mode):
-    if mode == "默认知识库对话":
+    if mode in ["默认知识库对话", "LLM原生对话"]:
         return gr.update(visible=False)
     else:
         return gr.update(visible=True)
@@ -281,7 +283,7 @@ def build_single_model_ui():
     """
     learn_more_markdown = """
     ### Licence
-    The service is a research preview intended for non-commercial use only. subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA
+    The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of Vicuna-13B
     """
 
     state = gr.State()
 
@@ -318,7 +320,8 @@ def build_single_model_ui():
                 show_label=True).style(container=False)
 
         with gr.TabItem("知识问答", elem_id="QA"):
-            mode = gr.Radio(["默认知识库对话", "新增知识库"], show_label=False, value="默认知识库对话")
+
+            mode = gr.Radio(["LLM原生对话", "默认知识库对话", "新增知识库对话"], show_label=False, value="LLM原生对话")
             vs_setting = gr.Accordion("配置知识库", open=False)
             mode.change(fn=change_mode, inputs=mode, outputs=vs_setting)
             with vs_setting:
@@ -363,7 +366,7 @@ def build_single_model_ui():
     btn_list = [regenerate_btn, clear_btn]
     regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then(
         http_bot,
-        [state, db_selector, temperature, max_output_tokens],
+        [state, mode, db_selector, temperature, max_output_tokens],
         [state, chatbot] + btn_list,
     )
     clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list)
@@ -372,7 +375,7 @@ def build_single_model_ui():
         add_text, [state, textbox], [state, chatbot, textbox] + btn_list
     ).then(
         http_bot,
-        [state, db_selector, temperature, max_output_tokens],
+        [state, mode, db_selector, temperature, max_output_tokens],
         [state, chatbot] + btn_list,
     )
 
@@ -380,7 +383,7 @@ def build_single_model_ui():
         add_text, [state, textbox], [state, chatbot, textbox] + btn_list
     ).then(
         http_bot,
-        [state, db_selector, temperature, max_output_tokens],
+        [state, mode, db_selector, temperature, max_output_tokens],
         [state, chatbot] + btn_list
     )
 
diff --git a/pilot/vector_store/file_loader.py b/pilot/vector_store/file_loader.py
index 269ec05f1..881c8106f 100644
--- a/pilot/vector_store/file_loader.py
+++ b/pilot/vector_store/file_loader.py
@@ -10,9 +10,8 @@ from langchain.text_splitter import CharacterTextSplitter
 from langchain.document_loaders import UnstructuredFileLoader, UnstructuredPDFLoader, TextLoader
 from langchain.chains import VectorDBQA
 from langchain.embeddings import HuggingFaceEmbeddings
-from pilot.configs.model_config import VECTORE_PATH, DATASETS_DIR, LLM_MODEL_CONFIG
+from pilot.configs.model_config import VECTORE_PATH, DATASETS_DIR, LLM_MODEL_CONFIG, VECTOR_SEARCH_TOP_K
 
-VECTOR_SEARCH_TOP_K = 5
 
 class KnownLedge2Vector:
 
@@ -26,21 +25,21 @@ class KnownLedge2Vector:
         self.embeddings = HuggingFaceEmbeddings(model_name=self.model_name)
 
     def init_vector_store(self):
-        documents = self.load_knownlege()
         persist_dir = os.path.join(VECTORE_PATH, ".vectordb")
         print("向量数据库持久化地址: ", persist_dir)
         if os.path.exists(persist_dir):
             # 从本地持久化文件中Load
+            print("从本地向量加载数据...")
             vector_store = Chroma(persist_directory=persist_dir, embedding_function=self.embeddings)
-            vector_store.add_documents(documents=documents)
+            # vector_store.add_documents(documents=documents)
         else:
+            documents = self.load_knownlege()
             # 重新初始化
             vector_store = Chroma.from_documents(documents=documents, embedding=self.embeddings, persist_directory=persist_dir)
             vector_store.persist()
-            vector_store = None
-        return persist_dir
+        return vector_store
 
     def load_knownlege(self):
         docments = []
@@ -70,10 +69,23 @@ class KnownLedge2Vector:
         return docs
 
     def _load_from_url(self, url):
+        """Load data from a URL address"""
         pass
+
+    def query(self, q):
+        """Query similar documents from the vector store"""
+        vector_store = self.init_vector_store()
+        docs = vector_store.similarity_search_with_score(q, k=self.top_k)
+        for doc in docs:
+            dc, s = doc
+            yield s, dc
 
 if __name__ == "__main__":
     k2v = KnownLedge2Vector()
-    k2v.init_vector_store()
+
+    persist_dir = os.path.join(VECTORE_PATH, ".vectordb")
+    print(persist_dir)
+    for s, dc in k2v.query("什么是OceanBase"):
+        print(s, dc.page_content, dc.metadata)
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index f0ddf8fb5..0582f5a41 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -52,4 +52,6 @@ gradio-client==0.0.8
 wandb
 fschat=0.1.10
 llama-index=0.5.27
-pymysql
\ No newline at end of file
+pymysql
+unstructured==0.6.3
+pytesseract==0.3.10
\ No newline at end of file
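Finally, a hypothetical sketch (not part of the diff) of the streaming protocol that ties the pieces above together: `embdserver.py`, `webserver.py`, and the reworked `VicunaLLM` all POST to the model server's `generate_stream` endpoint and read back NUL-delimited JSON chunks. The prompt text, `stream=True` flag, and token budget below are illustrative only.
```
# Hypothetical raw client for the model server's generate_stream endpoint;
# VICUNA_MODEL_SERVER comes from pilot/configs/model_config.py.
import json
import requests
from urllib.parse import urljoin

from pilot.configs.model_config import VICUNA_MODEL_SERVER

params = {
    "model": "vicuna-13b",
    "prompt": "A chat between a human and an assistant. Human: Hello! Assistant:",  # placeholder
    "temperature": 0.7,
    "max_new_tokens": 256,
    "stop": "###",
}
response = requests.post(
    url=urljoin(VICUNA_MODEL_SERVER, "generate_stream"),
    data=json.dumps(params),
    stream=True,  # not set in the patch itself; avoids buffering the whole reply
)

answer = ""
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
    if chunk:
        data = json.loads(chunk.decode())
        if data["error_code"] == 0:
            answer = data["text"]  # cumulative text so far, including the echoed prompt
print(answer)
```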