feat(ChatKnowledge): ChatKnowledge Support Keyword Retrieve (#1624)
Co-authored-by: Fangyin Cheng <staneyffer@gmail.com>
@@ -17,24 +17,21 @@ from dbgpt.rag.assembler import EmbeddingAssembler
 from dbgpt.rag.embedding import DefaultEmbeddingFactory
 from dbgpt.rag.knowledge import KnowledgeFactory
 from dbgpt.rag.retriever.rerank import CrossEncoderRanker
-from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
-from dbgpt.storage.vector_store.connector import VectorStoreConnector
+from dbgpt.storage.vector_store.chroma_store import ChromaStore, ChromaVectorConfig
 
 
 def _create_vector_connector():
     """Create vector connector."""
-    print(f"persist_path:{os.path.join(PILOT_PATH, 'data')}")
-    return VectorStoreConnector.from_default(
-        "Chroma",
-        vector_store_config=ChromaVectorConfig(
-            name="example_cross_encoder_rerank",
-            persist_path=os.path.join(PILOT_PATH, "data"),
-        ),
+    config = ChromaVectorConfig(
+        persist_path=PILOT_PATH,
+        name="embedding_rag_test",
+        embedding_fn=DefaultEmbeddingFactory(
+            default_model_name=os.path.join(MODEL_PATH, "text2vec-large-chinese"),
+        ).create(),
     )
+
+    return ChromaStore(config)
 
 
 async def main():
     file_path = os.path.join(ROOT_PATH, "docs/docs/awel/awel.md")
@@ -45,7 +42,7 @@ async def main():
     assembler = EmbeddingAssembler.load_from_knowledge(
         knowledge=knowledge,
         chunk_parameters=chunk_parameters,
-        vector_store_connector=vector_connector,
+        index_store=vector_connector,
     )
     assembler.persist()
     # get embeddings retriever
@@ -57,7 +54,7 @@ async def main():
     print("before rerank results:\n")
     for i, chunk in enumerate(chunks):
         print(f"----{i+1}.chunk content:{chunk.content}\n score:{chunk.score}")
-    # cross-encoder rerank
+    # cross-encoder rerankpython
     cross_encoder_model = os.path.join(MODEL_PATH, "bge-reranker-base")
     rerank = CrossEncoderRanker(topk=3, model=cross_encoder_model)
     new_chunks = rerank.rank(chunks, query=query)
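Taken together, the hunks above drop the `VectorStoreConnector` indirection: the example now builds a `ChromaStore` directly from a `ChromaVectorConfig` and passes it to the assembler as `index_store`. A minimal sketch of the resulting flow, assuming the modules, model paths, query string, and top-k values shown elsewhere in this commit (treat those values as illustrative):

```python
import asyncio
import os

from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH, ROOT_PATH
from dbgpt.rag import ChunkParameters
from dbgpt.rag.assembler import EmbeddingAssembler
from dbgpt.rag.embedding import DefaultEmbeddingFactory
from dbgpt.rag.knowledge import KnowledgeFactory
from dbgpt.rag.retriever.rerank import CrossEncoderRanker
from dbgpt.storage.vector_store.chroma_store import ChromaStore, ChromaVectorConfig


def _create_vector_store() -> ChromaStore:
    # Build the store directly from its config; no VectorStoreConnector wrapper.
    config = ChromaVectorConfig(
        persist_path=PILOT_PATH,
        name="embedding_rag_test",
        embedding_fn=DefaultEmbeddingFactory(
            default_model_name=os.path.join(MODEL_PATH, "text2vec-large-chinese"),
        ).create(),
    )
    return ChromaStore(config)


async def main():
    knowledge = KnowledgeFactory.from_file_path(
        os.path.join(ROOT_PATH, "docs/docs/awel/awel.md")
    )
    assembler = EmbeddingAssembler.load_from_knowledge(
        knowledge=knowledge,
        chunk_parameters=ChunkParameters(chunk_strategy="CHUNK_BY_SIZE"),
        index_store=_create_vector_store(),  # was: vector_store_connector=...
    )
    assembler.persist()
    query = "what is awel talk about"
    retriever = assembler.as_retriever(3)
    chunks = await retriever.aretrieve_with_scores(query, 0.3)
    # Rerank the retrieved chunks with the cross-encoder used in this example.
    rerank = CrossEncoderRanker(
        topk=3, model=os.path.join(MODEL_PATH, "bge-reranker-base")
    )
    print(rerank.rank(chunks, query=query))


if __name__ == "__main__":
    asyncio.run(main())
```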
@@ -4,8 +4,7 @@ from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH
 from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector
 from dbgpt.rag.assembler import DBSchemaAssembler
 from dbgpt.rag.embedding import DefaultEmbeddingFactory
-from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
-from dbgpt.storage.vector_store.connector import VectorStoreConnector
+from dbgpt.storage.vector_store.chroma_store import ChromaStore, ChromaVectorConfig
 
 """DB struct rag example.
     pre-requirements:
@@ -46,27 +45,27 @@ def _create_temporary_connection():
 
 def _create_vector_connector():
     """Create vector connector."""
-    return VectorStoreConnector.from_default(
-        "Chroma",
-        vector_store_config=ChromaVectorConfig(
-            name="db_schema_vector_store_name",
-            persist_path=os.path.join(PILOT_PATH, "data"),
-        ),
+    config = ChromaVectorConfig(
+        persist_path=PILOT_PATH,
+        name="dbschema_rag_test",
+        embedding_fn=DefaultEmbeddingFactory(
+            default_model_name=os.path.join(MODEL_PATH, "text2vec-large-chinese"),
+        ).create(),
     )
+
+    return ChromaStore(config)
 
 
 if __name__ == "__main__":
     connection = _create_temporary_connection()
-    vector_connector = _create_vector_connector()
+    index_store = _create_vector_connector()
     assembler = DBSchemaAssembler.load_from_connection(
         connector=connection,
-        vector_store_connector=vector_connector,
+        index_store=index_store,
     )
     assembler.persist()
     # get db schema retriever
     retriever = assembler.as_retriever(top_k=1)
     chunks = retriever.retrieve("show columns from user")
     print(f"db schema rag example results:{[chunk.content for chunk in chunks]}")
+    index_store.delete_vector_name("dbschema_rag_test")
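The same swap applies to the DB-schema path: `DBSchemaAssembler.load_from_connection` now receives the store as `index_store`. A condensed sketch under the assumption that `SQLiteTempConnector.create_temporary_db()` can stand in for the example's `_create_temporary_connection()` helper, whose body is not part of this diff:

```python
import os

from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector
from dbgpt.rag.assembler import DBSchemaAssembler
from dbgpt.rag.embedding import DefaultEmbeddingFactory
from dbgpt.storage.vector_store.chroma_store import ChromaStore, ChromaVectorConfig

if __name__ == "__main__":
    # Assumption: create_temporary_db() replaces the example's private helper.
    connection = SQLiteTempConnector.create_temporary_db()
    index_store = ChromaStore(
        ChromaVectorConfig(
            persist_path=PILOT_PATH,
            name="dbschema_rag_test",
            embedding_fn=DefaultEmbeddingFactory(
                default_model_name=os.path.join(MODEL_PATH, "text2vec-large-chinese"),
            ).create(),
        )
    )
    assembler = DBSchemaAssembler.load_from_connection(
        connector=connection,
        index_store=index_store,  # was: vector_store_connector=...
    )
    assembler.persist()
    retriever = assembler.as_retriever(top_k=1)
    chunks = retriever.retrieve("show columns from user")
    print([chunk.content for chunk in chunks])
    index_store.delete_vector_name("dbschema_rag_test")
```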
@@ -6,8 +6,7 @@ from dbgpt.rag import ChunkParameters
 from dbgpt.rag.assembler import EmbeddingAssembler
 from dbgpt.rag.embedding import DefaultEmbeddingFactory
 from dbgpt.rag.knowledge import KnowledgeFactory
-from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
-from dbgpt.storage.vector_store.connector import VectorStoreConnector
+from dbgpt.storage.vector_store.chroma_store import ChromaStore, ChromaVectorConfig
 
 """Embedding rag example.
     pre-requirements:
@@ -24,28 +23,27 @@ from dbgpt.storage.vector_store.connector import VectorStoreConnector
 
 def _create_vector_connector():
     """Create vector connector."""
-    return VectorStoreConnector.from_default(
-        "Chroma",
-        vector_store_config=ChromaVectorConfig(
-            name="db_schema_vector_store_name",
-            persist_path=os.path.join(PILOT_PATH, "data"),
-        ),
+    config = ChromaVectorConfig(
+        persist_path=PILOT_PATH,
+        name="embedding_rag_test",
+        embedding_fn=DefaultEmbeddingFactory(
+            default_model_name=os.path.join(MODEL_PATH, "text2vec-large-chinese"),
+        ).create(),
     )
+
+    return ChromaStore(config)
 
 
 async def main():
     file_path = os.path.join(ROOT_PATH, "docs/docs/awel/awel.md")
     knowledge = KnowledgeFactory.from_file_path(file_path)
-    vector_connector = _create_vector_connector()
+    vector_store = _create_vector_connector()
     chunk_parameters = ChunkParameters(chunk_strategy="CHUNK_BY_SIZE")
     # get embedding assembler
     assembler = EmbeddingAssembler.load_from_knowledge(
         knowledge=knowledge,
         chunk_parameters=chunk_parameters,
-        vector_store_connector=vector_connector,
+        index_store=vector_store,
     )
     assembler.persist()
     # get embeddings retriever
@@ -6,9 +6,11 @@ from dbgpt.model.proxy.llms.chatgpt import OpenAILLMClient
 from dbgpt.rag import ChunkParameters
 from dbgpt.rag.assembler import EmbeddingAssembler
 from dbgpt.rag.knowledge import KnowledgeFactory
-from dbgpt.storage.knowledge_graph.knowledge_graph import BuiltinKnowledgeGraphConfig
-from dbgpt.storage.vector_store.base import VectorStoreConfig
-from dbgpt.storage.vector_store.connector import VectorStoreConnector
+from dbgpt.rag.retriever import RetrieverStrategy
+from dbgpt.storage.knowledge_graph.knowledge_graph import (
+    BuiltinKnowledgeGraph,
+    BuiltinKnowledgeGraphConfig,
+)
 
 """GraphRAG example.
     pre-requirements:
@@ -31,9 +33,8 @@ from dbgpt.storage.vector_store.connector import VectorStoreConnector
 
 def _create_kg_connector():
     """Create knowledge graph connector."""
-    return VectorStoreConnector(
-        vector_store_type="KnowledgeGraph",
-        vector_store_config=VectorStoreConfig(
+    return BuiltinKnowledgeGraph(
+        config=BuiltinKnowledgeGraphConfig(
             name="graph_rag_test",
             embedding_fn=None,
             llm_client=OpenAILLMClient(),
@@ -45,22 +46,23 @@ def _create_kg_connector():
 async def main():
     file_path = os.path.join(ROOT_PATH, "examples/test_files/tranformers_story.md")
     knowledge = KnowledgeFactory.from_file_path(file_path)
-    vector_connector = _create_kg_connector()
+    graph_store = _create_kg_connector()
     chunk_parameters = ChunkParameters(chunk_strategy="CHUNK_BY_SIZE")
     # get embedding assembler
-    assembler = EmbeddingAssembler.load_from_knowledge(
+    assembler = await EmbeddingAssembler.aload_from_knowledge(
         knowledge=knowledge,
         chunk_parameters=chunk_parameters,
-        vector_store_connector=vector_connector,
+        index_store=graph_store,
+        retrieve_strategy=RetrieverStrategy.GRAPH,
     )
-    assembler.persist()
+    await assembler.apersist()
     # get embeddings retriever
     retriever = assembler.as_retriever(3)
    chunks = await retriever.aretrieve_with_scores(
         "What actions has Megatron taken ?", score_threshold=0.3
     )
     print(f"embedding rag example results:{chunks}")
-    vector_connector.delete_vector_name("graph_rag_test")
+    graph_store.delete_vector_name("graph_rag_test")
 
 
 if __name__ == "__main__":
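The GraphRAG example additionally moves to the async assembler API: the graph store is built directly and the assembler is created with `aload_from_knowledge`, persisted with `apersist`, and told to use `RetrieverStrategy.GRAPH`. A minimal sketch assuming an OpenAI-compatible key is configured for `OpenAILLMClient`; the `model_name` field is an assumption, since the hunk above cuts off before the end of the config:

```python
import asyncio
import os

from dbgpt.configs.model_config import ROOT_PATH
from dbgpt.model.proxy.llms.chatgpt import OpenAILLMClient
from dbgpt.rag import ChunkParameters
from dbgpt.rag.assembler import EmbeddingAssembler
from dbgpt.rag.knowledge import KnowledgeFactory
from dbgpt.rag.retriever import RetrieverStrategy
from dbgpt.storage.knowledge_graph.knowledge_graph import (
    BuiltinKnowledgeGraph,
    BuiltinKnowledgeGraphConfig,
)


async def main():
    knowledge = KnowledgeFactory.from_file_path(
        os.path.join(ROOT_PATH, "examples/test_files/tranformers_story.md")
    )
    graph_store = BuiltinKnowledgeGraph(
        config=BuiltinKnowledgeGraphConfig(
            name="graph_rag_test",
            embedding_fn=None,
            llm_client=OpenAILLMClient(),
            model_name="gpt-3.5-turbo",  # assumed; not shown in the hunk above
        )
    )
    # Async counterparts of load_from_knowledge/persist introduced by this commit.
    assembler = await EmbeddingAssembler.aload_from_knowledge(
        knowledge=knowledge,
        chunk_parameters=ChunkParameters(chunk_strategy="CHUNK_BY_SIZE"),
        index_store=graph_store,
        retrieve_strategy=RetrieverStrategy.GRAPH,
    )
    await assembler.apersist()
    retriever = assembler.as_retriever(3)
    chunks = await retriever.aretrieve_with_scores(
        "What actions has Megatron taken ?", score_threshold=0.3
    )
    print(chunks)
    graph_store.delete_vector_name("graph_rag_test")


if __name__ == "__main__":
    asyncio.run(main())
```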
examples/rag/keyword_rag_example.py (new file, 55 lines)
@@ -0,0 +1,55 @@
+import asyncio
+import os
+
+from dbgpt.configs.model_config import ROOT_PATH
+from dbgpt.rag import ChunkParameters
+from dbgpt.rag.assembler import EmbeddingAssembler
+from dbgpt.rag.knowledge import KnowledgeFactory
+from dbgpt.storage.full_text.elasticsearch import (
+    ElasticDocumentConfig,
+    ElasticDocumentStore,
+)
+
+"""Keyword rag example.
+    pre-requirements:
+    set your Elasticsearch environment.
+
+    Examples:
+        ..code-block:: shell
+            python examples/rag/keyword_rag_example.py
+"""
+
+
+def _create_es_connector():
+    """Create es connector."""
+    config = ElasticDocumentConfig(
+        name="keyword_rag_test",
+        uri="localhost",
+        port="9200",
+        user="elastic",
+        password="dbgpt",
+    )
+
+    return ElasticDocumentStore(config)
+
+
+async def main():
+    file_path = os.path.join(ROOT_PATH, "docs/docs/awel/awel.md")
+    knowledge = KnowledgeFactory.from_file_path(file_path)
+    keyword_store = _create_es_connector()
+    chunk_parameters = ChunkParameters(chunk_strategy="CHUNK_BY_SIZE")
+    # get embedding assembler
+    assembler = EmbeddingAssembler.load_from_knowledge(
+        knowledge=knowledge,
+        chunk_parameters=chunk_parameters,
+        index_store=keyword_store,
+    )
+    assembler.persist()
+    # get embeddings retriever
+    retriever = assembler.as_retriever(3)
+    chunks = await retriever.aretrieve_with_scores("what is awel talk about", 0.3)
+    print(f"keyword rag example results:{chunks}")
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
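The new example plugs an `ElasticDocumentStore` into the same `index_store` slot that the vector examples use, so keyword and embedding retrieval can be built side by side from the same knowledge. A hypothetical combination of the two stores (not part of this commit), reusing the connection settings and model paths shown above:

```python
import asyncio
import os

from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH, ROOT_PATH
from dbgpt.rag import ChunkParameters
from dbgpt.rag.assembler import EmbeddingAssembler
from dbgpt.rag.embedding import DefaultEmbeddingFactory
from dbgpt.rag.knowledge import KnowledgeFactory
from dbgpt.storage.full_text.elasticsearch import (
    ElasticDocumentConfig,
    ElasticDocumentStore,
)
from dbgpt.storage.vector_store.chroma_store import ChromaStore, ChromaVectorConfig


async def main():
    knowledge = KnowledgeFactory.from_file_path(
        os.path.join(ROOT_PATH, "docs/docs/awel/awel.md")
    )
    chunk_parameters = ChunkParameters(chunk_strategy="CHUNK_BY_SIZE")
    # Keyword (full-text) index backed by Elasticsearch, as in the new example.
    keyword_store = ElasticDocumentStore(
        ElasticDocumentConfig(
            name="keyword_rag_test",
            uri="localhost",
            port="9200",
            user="elastic",
            password="dbgpt",
        )
    )
    # Embedding index backed by Chroma, as in the other examples in this commit.
    vector_store = ChromaStore(
        ChromaVectorConfig(
            persist_path=PILOT_PATH,
            name="embedding_rag_test",
            embedding_fn=DefaultEmbeddingFactory(
                default_model_name=os.path.join(MODEL_PATH, "text2vec-large-chinese"),
            ).create(),
        )
    )
    results = {}
    for label, store in [("keyword", keyword_store), ("embedding", vector_store)]:
        assembler = EmbeddingAssembler.load_from_knowledge(
            knowledge=knowledge,
            chunk_parameters=chunk_parameters,
            index_store=store,  # both stores satisfy the same index_store interface
        )
        assembler.persist()
        retriever = assembler.as_retriever(3)
        results[label] = await retriever.aretrieve_with_scores(
            "what is awel talk about", 0.3
        )
    print(results)


if __name__ == "__main__":
    asyncio.run(main())
```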
@@ -14,35 +14,33 @@ from dbgpt.rag import ChunkParameters
 from dbgpt.rag.assembler import EmbeddingAssembler
 from dbgpt.rag.embedding import DefaultEmbeddingFactory
 from dbgpt.rag.knowledge import KnowledgeFactory
-from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
-from dbgpt.storage.vector_store.connector import VectorStoreConnector
+from dbgpt.storage.vector_store.chroma_store import ChromaStore, ChromaVectorConfig
 from dbgpt.storage.vector_store.filters import MetadataFilter, MetadataFilters
 
 
 def _create_vector_connector():
     """Create vector connector."""
-    return VectorStoreConnector.from_default(
-        "Chroma",
-        vector_store_config=ChromaVectorConfig(
-            name="example_metadata_filter_name",
-            persist_path=os.path.join(PILOT_PATH, "data"),
-        ),
+    config = ChromaVectorConfig(
+        persist_path=PILOT_PATH,
+        name="metadata_rag_test",
+        embedding_fn=DefaultEmbeddingFactory(
+            default_model_name=os.path.join(MODEL_PATH, "text2vec-large-chinese"),
+        ).create(),
     )
+
+    return ChromaStore(config)
 
 
 async def main():
     file_path = os.path.join(ROOT_PATH, "docs/docs/awel/awel.md")
     knowledge = KnowledgeFactory.from_file_path(file_path)
-    vector_connector = _create_vector_connector()
+    vector_store = _create_vector_connector()
     chunk_parameters = ChunkParameters(chunk_strategy="CHUNK_BY_MARKDOWN_HEADER")
     # get embedding assembler
     assembler = EmbeddingAssembler.load_from_knowledge(
         knowledge=knowledge,
         chunk_parameters=chunk_parameters,
-        vector_store_connector=vector_connector,
+        index_store=vector_store,
     )
     assembler.persist()
     # get embeddings retriever
@@ -54,6 +52,7 @@ async def main():
         "what is awel talk about", 0.0, filters
     )
     print(f"embedding rag example results:{chunks}")
+    vector_store.delete_vector_name("metadata_rag_test")
 
 
 if __name__ == "__main__":
@@ -31,8 +31,7 @@ from dbgpt.rag import ChunkParameters
 from dbgpt.rag.assembler import EmbeddingAssembler
 from dbgpt.rag.embedding import OpenAPIEmbeddings
 from dbgpt.rag.knowledge import KnowledgeFactory
-from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
-from dbgpt.storage.vector_store.connector import VectorStoreConnector
+from dbgpt.storage.vector_store.chroma_store import ChromaStore, ChromaVectorConfig
 
 
 def _create_embeddings(
@@ -54,33 +53,32 @@ def _create_embeddings(
 
 def _create_vector_connector():
     """Create vector connector."""
-
-    return VectorStoreConnector.from_default(
-        "Chroma",
-        vector_store_config=ChromaVectorConfig(
-            name="example_embedding_api_vector_store_name",
-            persist_path=os.path.join(PILOT_PATH, "data"),
-        ),
+    config = ChromaVectorConfig(
+        persist_path=PILOT_PATH,
+        name="embedding_api_rag_test",
+        embedding_fn=_create_embeddings(),
     )
+
+    return ChromaStore(config)
 
 
 async def main():
     file_path = os.path.join(ROOT_PATH, "docs/docs/awel/awel.md")
     knowledge = KnowledgeFactory.from_file_path(file_path)
-    vector_connector = _create_vector_connector()
+    vector_store = _create_vector_connector()
     chunk_parameters = ChunkParameters(chunk_strategy="CHUNK_BY_SIZE")
     # get embedding assembler
     assembler = EmbeddingAssembler.load_from_knowledge(
         knowledge=knowledge,
         chunk_parameters=chunk_parameters,
-        vector_store_connector=vector_connector,
+        index_store=vector_store,
     )
     assembler.persist()
     # get embeddings retriever
     retriever = assembler.as_retriever(3)
     chunks = await retriever.aretrieve_with_scores("what is awel talk about", 0.3)
     print(f"embedding rag example results:{chunks}")
+    vector_store.delete_vector_name("embedding_api_rag_test")
 
 
 if __name__ == "__main__":
@@ -15,8 +15,7 @@ from dbgpt.rag.evaluation.retriever import (
 )
 from dbgpt.rag.knowledge import KnowledgeFactory
 from dbgpt.rag.operators import EmbeddingRetrieverOperator
-from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
-from dbgpt.storage.vector_store.connector import VectorStoreConnector
+from dbgpt.storage.vector_store.chroma_store import ChromaStore, ChromaVectorConfig
 
 
 def _create_embeddings(
@@ -28,19 +27,16 @@ def _create_embeddings(
     ).create()
 
 
-def _create_vector_connector(
-    embeddings: Embeddings, space_name: str = "retriever_evaluation_example"
-) -> VectorStoreConnector:
+def _create_vector_connector(embeddings: Embeddings):
     """Create vector connector."""
-    return VectorStoreConnector.from_default(
-        "Chroma",
-        vector_store_config=ChromaVectorConfig(
-            name=space_name,
-            persist_path=os.path.join(PILOT_PATH, "data"),
-        ),
+    config = ChromaVectorConfig(
+        persist_path=PILOT_PATH,
+        name="embedding_rag_test",
+        embedding_fn=embeddings,
     )
+
+    return ChromaStore(config)
 
 
 async def main():
     file_path = os.path.join(ROOT_PATH, "docs/docs/awel/awel.md")
@@ -52,7 +48,7 @@ async def main():
     assembler = EmbeddingAssembler.load_from_knowledge(
         knowledge=knowledge,
         chunk_parameters=chunk_parameters,
-        vector_store_connector=vector_connector,
+        index_store=vector_connector,
     )
     assembler.persist()
 
@@ -11,7 +11,7 @@
 retriever_task = DBSchemaRetrieverOperator(
     connector=_create_temporary_connection()
     top_k=1,
-    vector_store_connector=vector_store_connector
+    index_store=vector_store_connector
 )
 ```
 
@@ -27,31 +27,29 @@ from typing import Dict, List
 
 from dbgpt._private.config import Config
 from dbgpt._private.pydantic import BaseModel, Field
-from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG, PILOT_PATH
+from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH
 from dbgpt.core import Chunk
-from dbgpt.core.awel import DAG, HttpTrigger, InputOperator, JoinOperator, MapOperator
+from dbgpt.core.awel import DAG, HttpTrigger, JoinOperator, MapOperator
 from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector
 from dbgpt.rag.embedding import DefaultEmbeddingFactory
 from dbgpt.rag.operators import DBSchemaAssemblerOperator, DBSchemaRetrieverOperator
-from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
-from dbgpt.storage.vector_store.connector import VectorStoreConnector
+from dbgpt.storage.vector_store.chroma_store import ChromaStore, ChromaVectorConfig
 
 CFG = Config()
 
 
 def _create_vector_connector():
     """Create vector connector."""
-    return VectorStoreConnector.from_default(
-        "Chroma",
-        vector_store_config=ChromaVectorConfig(
-            name="vector_name",
-            persist_path=os.path.join(PILOT_PATH, "data"),
-        ),
+    config = ChromaVectorConfig(
+        persist_path=os.path.join(PILOT_PATH, "data"),
+        name="vector_name",
         embedding_fn=DefaultEmbeddingFactory(
-            default_model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
+            default_model_name=os.path.join(MODEL_PATH, "text2vec-large-chinese"),
         ).create(),
     )
+
+    return ChromaStore(config)
 
 
 def _create_temporary_connection():
     """Create a temporary database connection for testing."""
@@ -104,17 +102,17 @@ with DAG("simple_rag_db_schema_example") as dag:
     )
     request_handle_task = RequestHandleOperator()
     query_operator = MapOperator(lambda request: request["query"])
-    vector_store_connector = _create_vector_connector()
+    index_store = _create_vector_connector()
     connector = _create_temporary_connection()
     assembler_task = DBSchemaAssemblerOperator(
         connector=connector,
-        vector_store_connector=vector_store_connector,
+        index_store=index_store,
     )
     join_operator = JoinOperator(combine_function=_join_fn)
     retriever_task = DBSchemaRetrieverOperator(
         connector=_create_temporary_connection(),
         top_k=1,
-        vector_store_connector=vector_store_connector,
+        index_store=index_store,
     )
     result_parse_task = MapOperator(lambda chunks: [chunk.content for chunk in chunks])
     trigger >> assembler_task >> join_operator
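In the AWEL DAGs the change is confined to operator construction: `DBSchemaAssemblerOperator`, `DBSchemaRetrieverOperator`, `EmbeddingAssemblerOperator`, and `EmbeddingRetrieverOperator` now take the store as `index_store`. A trimmed sketch of the schema wiring that omits the HTTP trigger, request handling, and join logic of the full example; `SQLiteTempConnector.create_temporary_db()` is assumed to stand in for the example's `_create_temporary_connection()` helper:

```python
import os

from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH
from dbgpt.core.awel import DAG, MapOperator
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector
from dbgpt.rag.embedding import DefaultEmbeddingFactory
from dbgpt.rag.operators import DBSchemaAssemblerOperator, DBSchemaRetrieverOperator
from dbgpt.storage.vector_store.chroma_store import ChromaStore, ChromaVectorConfig


def _create_index_store() -> ChromaStore:
    return ChromaStore(
        ChromaVectorConfig(
            persist_path=os.path.join(PILOT_PATH, "data"),
            name="vector_name",
            embedding_fn=DefaultEmbeddingFactory(
                default_model_name=os.path.join(MODEL_PATH, "text2vec-large-chinese"),
            ).create(),
        )
    )


with DAG("db_schema_rag_sketch") as dag:
    index_store = _create_index_store()
    connector = SQLiteTempConnector.create_temporary_db()  # assumed helper
    assembler_task = DBSchemaAssemblerOperator(
        connector=connector,
        index_store=index_store,  # was: vector_store_connector=...
    )
    retriever_task = DBSchemaRetrieverOperator(
        connector=connector,
        top_k=1,
        index_store=index_store,  # was: vector_store_connector=...
    )
    result_parse_task = MapOperator(lambda chunks: [chunk.content for chunk in chunks])
    retriever_task >> result_parse_task
```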
@@ -16,30 +16,28 @@ from typing import Dict, List
 
 from dbgpt._private.config import Config
 from dbgpt._private.pydantic import BaseModel, Field
-from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG, PILOT_PATH
+from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG, MODEL_PATH, PILOT_PATH
 from dbgpt.core.awel import DAG, HttpTrigger, MapOperator
 from dbgpt.rag.embedding import DefaultEmbeddingFactory
 from dbgpt.rag.knowledge import KnowledgeType
 from dbgpt.rag.operators import EmbeddingAssemblerOperator, KnowledgeOperator
-from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
-from dbgpt.storage.vector_store.connector import VectorStoreConnector
+from dbgpt.storage.vector_store.chroma_store import ChromaStore, ChromaVectorConfig
 
 CFG = Config()
 
 
-def _create_vector_connector() -> VectorStoreConnector:
+def _create_vector_connector():
     """Create vector connector."""
-    return VectorStoreConnector.from_default(
-        "Chroma",
-        vector_store_config=ChromaVectorConfig(
-            name="vector_name",
-            persist_path=os.path.join(PILOT_PATH, "data"),
-        ),
+    config = ChromaVectorConfig(
+        persist_path=PILOT_PATH,
+        name="embedding_rag_test",
         embedding_fn=DefaultEmbeddingFactory(
-            default_model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
+            default_model_name=os.path.join(MODEL_PATH, "text2vec-large-chinese"),
         ).create(),
     )
+
+    return ChromaStore(config)
 
 
 class TriggerReqBody(BaseModel):
     url: str = Field(..., description="url")
@@ -75,10 +73,10 @@ with DAG("simple_sdk_rag_embedding_example") as dag:
     )
     request_handle_task = RequestHandleOperator()
     knowledge_operator = KnowledgeOperator(knowledge_type=KnowledgeType.URL.name)
-    vector_connector = _create_vector_connector()
+    vector_store = _create_vector_connector()
     url_parser_operator = MapOperator(map_function=lambda x: x["url"])
     embedding_operator = EmbeddingAssemblerOperator(
-        vector_store_connector=vector_connector,
+        index_store=vector_store,
     )
     output_task = ResultOperator()
     (
@@ -31,7 +31,7 @@ from typing import Dict, List
 
 from dbgpt._private.config import Config
 from dbgpt._private.pydantic import BaseModel, Field
-from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG, PILOT_PATH
+from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH
 from dbgpt.core import Chunk
 from dbgpt.core.awel import DAG, HttpTrigger, JoinOperator, MapOperator
 from dbgpt.model.proxy import OpenAILLMClient
@@ -41,8 +41,7 @@ from dbgpt.rag.operators import (
     QueryRewriteOperator,
     RerankOperator,
 )
-from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
-from dbgpt.storage.vector_store.connector import VectorStoreConnector
+from dbgpt.storage.vector_store.chroma_store import ChromaStore, ChromaVectorConfig
 
 CFG = Config()
 
@@ -78,21 +77,19 @@ def _context_join_fn(context_dict: Dict, chunks: List[Chunk]) -> Dict:
 
 def _create_vector_connector():
     """Create vector connector."""
-    model_name = os.getenv("EMBEDDING_MODEL", "text2vec")
-    return VectorStoreConnector.from_default(
-        "Chroma",
-        vector_store_config=ChromaVectorConfig(
-            name="vector_name",
-            persist_path=os.path.join(PILOT_PATH, "data"),
-        ),
+    config = ChromaVectorConfig(
+        persist_path=PILOT_PATH,
+        name="embedding_rag_test",
         embedding_fn=DefaultEmbeddingFactory(
-            default_model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
+            default_model_name=os.path.join(MODEL_PATH, "text2vec-large-chinese"),
        ).create(),
     )
+
+    return ChromaStore(config)
 
 
 with DAG("simple_sdk_rag_retriever_example") as dag:
-    vector_connector = _create_vector_connector()
+    vector_store = _create_vector_connector()
     trigger = HttpTrigger(
         "/examples/rag/retrieve", methods="POST", request_body=TriggerReqBody
     )
@@ -102,11 +99,11 @@ with DAG("simple_sdk_rag_retriever_example") as dag:
     rewrite_operator = QueryRewriteOperator(llm_client=OpenAILLMClient())
     retriever_context_operator = EmbeddingRetrieverOperator(
         top_k=3,
-        vector_store_connector=vector_connector,
+        index_store=vector_store,
     )
     retriever_operator = EmbeddingRetrieverOperator(
         top_k=3,
-        vector_store_connector=vector_connector,
+        index_store=vector_store,
     )
     rerank_operator = RerankOperator()
     model_parse_task = MapOperator(lambda out: out.to_dict())