refactor: Refactor proxy LLM (#1064)

This commit is contained in:
Fangyin Cheng
2024-01-14 21:01:37 +08:00
committed by GitHub
parent a035433170
commit 22bfd01c4b
95 changed files with 2049 additions and 1294 deletions

View File

@@ -1,13 +1,13 @@
from functools import reduce
from typing import List, Optional
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
from dbgpt.util.chat_util import run_async_tasks
from dbgpt.datasource.rdbms.base import RDBMSDatabase
from dbgpt.rag.chunk import Chunk
from dbgpt.rag.retriever.base import BaseRetriever
from dbgpt.rag.retriever.rerank import Ranker, DefaultRanker
from dbgpt.rag.retriever.rerank import DefaultRanker, Ranker
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
from dbgpt.storage.vector_store.connector import VectorStoreConnector
from dbgpt.util.chat_util import run_async_tasks
class DBSchemaRetriever(BaseRetriever):
@@ -29,50 +29,59 @@ class DBSchemaRetriever(BaseRetriever):
query_rewrite (bool): query rewrite
rerank (Ranker): rerank
vector_store_connector (VectorStoreConnector): vector store connector
code example:
.. code-block:: python
>>> from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect
>>> from dbgpt.serve.rag.assembler.db_schema import DBSchemaAssembler
>>> from dbgpt.storage.vector_store.connector import VectorStoreConnector
>>> from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
>>> from dbgpt.rag.retriever.embedding import EmbeddingRetriever
def _create_temporary_connection():
connect = SQLiteTempConnect.create_temporary_db()
connect.create_temp_tables(
{
"user": {
"columns": {
"id": "INTEGER PRIMARY KEY",
"name": "TEXT",
"age": "INTEGER",
},
"data": [
(1, "Tom", 10),
(2, "Jerry", 16),
(3, "Jack", 18),
(4, "Alice", 20),
(5, "Bob", 22),
],
Examples:
.. code-block:: python
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect
from dbgpt.serve.rag.assembler.db_schema import DBSchemaAssembler
from dbgpt.storage.vector_store.connector import VectorStoreConnector
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
from dbgpt.rag.retriever.embedding import EmbeddingRetriever
def _create_temporary_connection():
connect = SQLiteTempConnect.create_temporary_db()
connect.create_temp_tables(
{
"user": {
"columns": {
"id": "INTEGER PRIMARY KEY",
"name": "TEXT",
"age": "INTEGER",
},
"data": [
(1, "Tom", 10),
(2, "Jerry", 16),
(3, "Jack", 18),
(4, "Alice", 20),
(5, "Bob", 22),
],
}
}
}
)
return connect
connection = _create_temporary_connection()
vector_store_config = ChromaVectorConfig(name="vector_store_name")
embedding_model_path = "{your_embedding_model_path}"
embedding_fn = embedding_factory.create(model_name=embedding_model_path)
vector_connector = VectorStoreConnector.from_default(
"Chroma",
vector_store_config=vector_store_config,
embedding_fn=embedding_fn,
)
return connect
connection = _create_temporary_connection()
vector_store_config = ChromaVectorConfig(name="vector_store_name")
embedding_model_path = "{your_embedding_model_path}"
embedding_fn = embedding_factory.create(
model_name=embedding_model_path
)
vector_connector = VectorStoreConnector.from_default(
"Chroma",
vector_store_config=vector_store_config,
embedding_fn=embedding_fn
)
# get db struct retriever
retriever = DBSchemaRetriever(top_k=3, vector_store_connector=vector_connector)
chunks = retriever.retrieve("show columns from table")
print(f"db struct rag example results:{[chunk.content for chunk in chunks]}")
# get db struct retriever
retriever = DBSchemaRetriever(
top_k=3, vector_store_connector=vector_connector
)
chunks = retriever.retrieve("show columns from table")
print(
f"db struct rag example results:{[chunk.content for chunk in chunks]}"
)
"""
self._top_k = top_k