mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-10-04 08:35:40 +00:00
feat(ChatKnowledge): ChatKnowledge Support Keyword Retrieve (#1624)
Co-authored-by: Fangyin Cheng <staneyffer@gmail.com>
This commit is contained in:
@@ -17,24 +17,21 @@ from dbgpt.rag.assembler import EmbeddingAssembler
|
||||
from dbgpt.rag.embedding import DefaultEmbeddingFactory
|
||||
from dbgpt.rag.knowledge import KnowledgeFactory
|
||||
from dbgpt.rag.retriever.rerank import CrossEncoderRanker
|
||||
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
from dbgpt.storage.vector_store.chroma_store import ChromaStore, ChromaVectorConfig
|
||||
|
||||
|
||||
def _create_vector_connector():
|
||||
"""Create vector connector."""
|
||||
print(f"persist_path:{os.path.join(PILOT_PATH, 'data')}")
|
||||
return VectorStoreConnector.from_default(
|
||||
"Chroma",
|
||||
vector_store_config=ChromaVectorConfig(
|
||||
name="example_cross_encoder_rerank",
|
||||
persist_path=os.path.join(PILOT_PATH, "data"),
|
||||
),
|
||||
config = ChromaVectorConfig(
|
||||
persist_path=PILOT_PATH,
|
||||
name="embedding_rag_test",
|
||||
embedding_fn=DefaultEmbeddingFactory(
|
||||
default_model_name=os.path.join(MODEL_PATH, "text2vec-large-chinese"),
|
||||
).create(),
|
||||
)
|
||||
|
||||
return ChromaStore(config)
|
||||
|
||||
|
||||
async def main():
|
||||
file_path = os.path.join(ROOT_PATH, "docs/docs/awel/awel.md")
|
||||
@@ -45,7 +42,7 @@ async def main():
|
||||
assembler = EmbeddingAssembler.load_from_knowledge(
|
||||
knowledge=knowledge,
|
||||
chunk_parameters=chunk_parameters,
|
||||
vector_store_connector=vector_connector,
|
||||
index_store=vector_connector,
|
||||
)
|
||||
assembler.persist()
|
||||
# get embeddings retriever
|
||||
@@ -57,7 +54,7 @@ async def main():
|
||||
print("before rerank results:\n")
|
||||
for i, chunk in enumerate(chunks):
|
||||
print(f"----{i+1}.chunk content:{chunk.content}\n score:{chunk.score}")
|
||||
# cross-encoder rerank
|
||||
# cross-encoder rerankpython
|
||||
cross_encoder_model = os.path.join(MODEL_PATH, "bge-reranker-base")
|
||||
rerank = CrossEncoderRanker(topk=3, model=cross_encoder_model)
|
||||
new_chunks = rerank.rank(chunks, query=query)
|
||||
|
Reference in New Issue
Block a user