mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-17 07:47:25 +00:00
update:format
This commit is contained in:
parent
be1a792d3c
commit
f2f28fee42
@ -0,0 +1 @@
|
||||
LlamaIndex是一个数据框架,旨在帮助您构建LLM应用程序。它包括一个向量存储索引和一个简单的目录阅读器,可以帮助您处理和操作数据。此外,LlamaIndex还提供了一个GPT Index,可以用于数据增强和生成更好的LM模型。
|
@ -82,7 +82,7 @@ class ChatGLMAdapater(BaseLLMAdaper):
|
||||
)
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
|
||||
class GuanacoAdapter(BaseLLMAdaper):
|
||||
"""TODO Support guanaco"""
|
||||
|
||||
@ -97,7 +97,6 @@ class GuanacoAdapter(BaseLLMAdaper):
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
|
||||
class GuanacoAdapter(BaseLLMAdaper):
|
||||
"""TODO Support guanaco"""
|
||||
|
||||
|
@ -3,7 +3,6 @@ from threading import Thread
|
||||
from transformers import TextIteratorStreamer, StoppingCriteriaList, StoppingCriteria
|
||||
|
||||
|
||||
|
||||
def guanaco_stream_generate_output(model, tokenizer, params, device, context_len=2048):
|
||||
"""Fork from: https://github.com/KohakuBlueleaf/guanaco-lora/blob/main/generate.py"""
|
||||
tokenizer.bos_token_id = 1
|
||||
@ -19,7 +18,7 @@ def guanaco_stream_generate_output(model, tokenizer, params, device, context_len
|
||||
streamer = TextIteratorStreamer(
|
||||
tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
|
||||
)
|
||||
|
||||
|
||||
tokenizer.bos_token_id = 1
|
||||
stop_token_ids = [0]
|
||||
|
||||
@ -52,4 +51,4 @@ def guanaco_stream_generate_output(model, tokenizer, params, device, context_len
|
||||
for new_text in streamer:
|
||||
out += new_text
|
||||
yield new_text
|
||||
return out
|
||||
return out
|
||||
|
@ -28,7 +28,9 @@ _DEFAULT_TEMPLATE_EN = """ Based on the known information below, provide users w
|
||||
{question}
|
||||
"""
|
||||
|
||||
_DEFAULT_TEMPLATE = _DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == 'en' else _DEFAULT_TEMPLATE_ZH
|
||||
_DEFAULT_TEMPLATE = (
|
||||
_DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH
|
||||
)
|
||||
|
||||
|
||||
PROMPT_SEP = SeparatorStyle.SINGLE.value
|
||||
|
@ -29,7 +29,9 @@ _DEFAULT_TEMPLATE_EN = """ Based on the known information below, provide users w
|
||||
{question}
|
||||
"""
|
||||
|
||||
_DEFAULT_TEMPLATE = _DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == 'en' else _DEFAULT_TEMPLATE_ZH
|
||||
_DEFAULT_TEMPLATE = (
|
||||
_DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH
|
||||
)
|
||||
|
||||
|
||||
PROMPT_SEP = SeparatorStyle.SINGLE.value
|
||||
|
@ -28,7 +28,9 @@ _DEFAULT_TEMPLATE_EN = """ Based on the known information below, provide users w
|
||||
{question}
|
||||
"""
|
||||
|
||||
_DEFAULT_TEMPLATE = _DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == 'en' else _DEFAULT_TEMPLATE_ZH
|
||||
_DEFAULT_TEMPLATE = (
|
||||
_DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH
|
||||
)
|
||||
|
||||
|
||||
PROMPT_SEP = SeparatorStyle.SINGLE.value
|
||||
|
@ -59,6 +59,7 @@ class ChatGLMChatAdapter(BaseChatAdpter):
|
||||
|
||||
return chatglm_generate_stream
|
||||
|
||||
|
||||
class GuanacoChatAdapter(BaseChatAdpter):
|
||||
"""Model chat adapter for Guanaco"""
|
||||
|
||||
@ -66,10 +67,13 @@ class GuanacoChatAdapter(BaseChatAdpter):
|
||||
return "guanaco" in model_path
|
||||
|
||||
def get_generate_stream_func(self):
|
||||
from pilot.model.llm_out.guanaco_stream_llm import guanaco_stream_generate_output
|
||||
from pilot.model.llm_out.guanaco_stream_llm import (
|
||||
guanaco_stream_generate_output,
|
||||
)
|
||||
|
||||
return guanaco_generate_output
|
||||
|
||||
|
||||
class CodeT5ChatAdapter(BaseChatAdpter):
|
||||
|
||||
"""Model chat adapter for CodeT5"""
|
||||
|
@ -15,12 +15,12 @@ class EncodeTextLoader(BaseLoader):
|
||||
|
||||
def load(self) -> List[Document]:
|
||||
"""Load from file path."""
|
||||
with open(self.file_path, 'rb') as f:
|
||||
with open(self.file_path, "rb") as f:
|
||||
raw_text = f.read()
|
||||
result = chardet.detect(raw_text)
|
||||
if result['encoding'] is None:
|
||||
text = raw_text.decode('utf-8')
|
||||
if result["encoding"] is None:
|
||||
text = raw_text.decode("utf-8")
|
||||
else:
|
||||
text = raw_text.decode(result['encoding'])
|
||||
text = raw_text.decode(result["encoding"])
|
||||
metadata = {"source": self.file_path}
|
||||
return [Document(page_content=text, metadata=metadata)]
|
||||
|
@ -20,13 +20,14 @@ CFG = Config()
|
||||
|
||||
KnowledgeEmbeddingType = {
|
||||
".txt": (MarkdownEmbedding, {}),
|
||||
".md": (MarkdownEmbedding,{}),
|
||||
".md": (MarkdownEmbedding, {}),
|
||||
".pdf": (PDFEmbedding, {}),
|
||||
".doc": (WordEmbedding, {}),
|
||||
".docx": (WordEmbedding, {}),
|
||||
".csv": (CSVEmbedding, {}),
|
||||
}
|
||||
|
||||
|
||||
class KnowledgeEmbedding:
|
||||
def __init__(
|
||||
self,
|
||||
@ -34,7 +35,6 @@ class KnowledgeEmbedding:
|
||||
vector_store_config,
|
||||
file_type: Optional[str] = "default",
|
||||
file_path: Optional[str] = None,
|
||||
|
||||
):
|
||||
"""Initialize with Loader url, model_name, vector_store_config"""
|
||||
self.file_path = file_path
|
||||
@ -62,13 +62,20 @@ class KnowledgeEmbedding:
|
||||
extension = "." + self.file_path.rsplit(".", 1)[-1]
|
||||
if extension in KnowledgeEmbeddingType:
|
||||
knowledge_class, knowledge_args = KnowledgeEmbeddingType[extension]
|
||||
embedding = knowledge_class(self.file_path, model_name=self.model_name, vector_store_config=self.vector_store_config, **knowledge_args)
|
||||
embedding = knowledge_class(
|
||||
self.file_path,
|
||||
model_name=self.model_name,
|
||||
vector_store_config=self.vector_store_config,
|
||||
**knowledge_args,
|
||||
)
|
||||
return embedding
|
||||
raise ValueError(f"Unsupported knowledge file type '{extension}'")
|
||||
return embedding
|
||||
|
||||
def similar_search(self, text, topk):
|
||||
vector_client = VectorStoreConnector(CFG.VECTOR_STORE_TYPE, self.vector_store_config)
|
||||
vector_client = VectorStoreConnector(
|
||||
CFG.VECTOR_STORE_TYPE, self.vector_store_config
|
||||
)
|
||||
return vector_client.similar_search(text, topk)
|
||||
|
||||
def vector_exist(self):
|
||||
|
@ -20,6 +20,7 @@ class PDFEmbedding(SourceEmbedding):
|
||||
self.model_name = model_name
|
||||
self.vector_store_config = vector_store_config
|
||||
self.encoding = encoding
|
||||
|
||||
@register
|
||||
def read(self):
|
||||
"""Load from pdf path."""
|
||||
|
@ -40,7 +40,6 @@ class LocalKnowledgeInit:
|
||||
client.source_embedding()
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--vector_name", type=str, default="default")
|
||||
|
Loading…
Reference in New Issue
Block a user