feat(ChatKnowledge): ChatKnowledge Support Keyword Retrieve (#1624)

Co-authored-by: Fangyin Cheng <staneyffer@gmail.com>
This commit is contained in:
Aries-ckt
2024-06-13 13:49:17 +08:00
committed by GitHub
parent 162e2c9b1c
commit 58d08780d6
86 changed files with 948 additions and 440 deletions

View File

@@ -17,24 +17,21 @@ from dbgpt.rag.assembler import EmbeddingAssembler
from dbgpt.rag.embedding import DefaultEmbeddingFactory
from dbgpt.rag.knowledge import KnowledgeFactory
from dbgpt.rag.retriever.rerank import CrossEncoderRanker
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
from dbgpt.storage.vector_store.connector import VectorStoreConnector
from dbgpt.storage.vector_store.chroma_store import ChromaStore, ChromaVectorConfig
def _create_vector_connector():
"""Create vector connector."""
print(f"persist_path:{os.path.join(PILOT_PATH, 'data')}")
return VectorStoreConnector.from_default(
"Chroma",
vector_store_config=ChromaVectorConfig(
name="example_cross_encoder_rerank",
persist_path=os.path.join(PILOT_PATH, "data"),
),
config = ChromaVectorConfig(
persist_path=PILOT_PATH,
name="embedding_rag_test",
embedding_fn=DefaultEmbeddingFactory(
default_model_name=os.path.join(MODEL_PATH, "text2vec-large-chinese"),
).create(),
)
return ChromaStore(config)
async def main():
file_path = os.path.join(ROOT_PATH, "docs/docs/awel/awel.md")
@@ -45,7 +42,7 @@ async def main():
assembler = EmbeddingAssembler.load_from_knowledge(
knowledge=knowledge,
chunk_parameters=chunk_parameters,
vector_store_connector=vector_connector,
index_store=vector_connector,
)
assembler.persist()
# get embeddings retriever
@@ -57,7 +54,7 @@ async def main():
print("before rerank results:\n")
for i, chunk in enumerate(chunks):
print(f"----{i+1}.chunk content:{chunk.content}\n score:{chunk.score}")
# cross-encoder rerank
# cross-encoder rerank
cross_encoder_model = os.path.join(MODEL_PATH, "bge-reranker-base")
rerank = CrossEncoderRanker(topk=3, model=cross_encoder_model)
new_chunks = rerank.rank(chunks, query=query)