refactor: Refactor proxy LLM (#1064)

Author: Fangyin Cheng
Date: 2024-01-14 21:01:37 +08:00
Committed by: GitHub
Parent: a035433170
Commit: 22bfd01c4b
95 changed files with 2049 additions and 1294 deletions
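Most hunks below are mechanical import re-orderings (names alphabetized; stdlib, third-party, and first-party groups ordered) plus quote and trailing-comma normalization. As a rough sketch of the transformation, assuming an isort-style tool (the diff itself does not name the tool used):

import isort  # assumes isort >= 5 is installed

# The same kind of rewrite seen throughout this commit:
old = "from pydantic import Field, BaseModel\n"
print(isort.code(old))  # -> from pydantic import BaseModel, Field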

View File

@@ -2,7 +2,7 @@ import json
import uuid
from typing import Any, Dict
-from pydantic import Field, BaseModel
+from pydantic import BaseModel, Field
class Document(BaseModel):

View File

@@ -1,5 +1,5 @@
from enum import Enum
-from typing import Optional, List, Any
+from typing import Any, List, Optional
from pydantic import BaseModel, Field

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
from abc import ABC, abstractmethod
-from typing import Any, Type, TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Type
from dbgpt.component import BaseComponent
from dbgpt.rag.embedding.embeddings import HuggingFaceEmbeddings

View File

@@ -3,7 +3,7 @@ from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
import requests
-from pydantic import Field, Extra, BaseModel
+from pydantic import BaseModel, Extra, Field
DEFAULT_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
DEFAULT_INSTRUCT_MODEL = "hkunlp/instructor-large"
@@ -54,12 +54,12 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
from .embeddings import HuggingFaceEmbeddings
model_name = "sentence-transformers/all-mpnet-base-v2"
-model_kwargs = {'device': 'cpu'}
-encode_kwargs = {'normalize_embeddings': False}
+model_kwargs = {"device": "cpu"}
+encode_kwargs = {"normalize_embeddings": False}
hf = HuggingFaceEmbeddings(
model_name=model_name,
model_kwargs=model_kwargs,
-encode_kwargs=encode_kwargs
+encode_kwargs=encode_kwargs,
)
"""
@@ -142,12 +142,12 @@ class HuggingFaceInstructEmbeddings(BaseModel, Embeddings):
from langchain.embeddings import HuggingFaceInstructEmbeddings
model_name = "hkunlp/instructor-large"
-model_kwargs = {'device': 'cpu'}
-encode_kwargs = {'normalize_embeddings': True}
+model_kwargs = {"device": "cpu"}
+encode_kwargs = {"normalize_embeddings": True}
hf = HuggingFaceInstructEmbeddings(
model_name=model_name,
model_kwargs=model_kwargs,
-encode_kwargs=encode_kwargs
+encode_kwargs=encode_kwargs,
)
"""
@@ -221,12 +221,12 @@ class HuggingFaceBgeEmbeddings(BaseModel, Embeddings):
from langchain.embeddings import HuggingFaceBgeEmbeddings
model_name = "BAAI/bge-large-en"
-model_kwargs = {'device': 'cpu'}
-encode_kwargs = {'normalize_embeddings': True}
+model_kwargs = {"device": "cpu"}
+encode_kwargs = {"normalize_embeddings": True}
hf = HuggingFaceBgeEmbeddings(
model_name=model_name,
model_kwargs=model_kwargs,
-encode_kwargs=encode_kwargs
+encode_kwargs=encode_kwargs,
)
"""
@@ -336,7 +336,7 @@ class HuggingFaceInferenceAPIEmbeddings(BaseModel, Embeddings):
hf_embeddings = HuggingFaceInferenceAPIEmbeddings(
api_key="your_api_key",
model_name="sentence-transformers/all-MiniLM-l6-v2"
model_name="sentence-transformers/all-MiniLM-l6-v2",
)
texts = ["Hello, world!", "How are you?"]
hf_embeddings.embed_documents(texts)
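The docstring above ends with embed_documents; for orientation, a hedged usage sketch of the embeddings class from this file (assumes sentence-transformers is installed; the model downloads on first use):

from dbgpt.rag.embedding.embeddings import HuggingFaceEmbeddings

# One vector per input text; all-mpnet-base-v2 produces 768-dim embeddings.
emb = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
vectors = emb.embed_documents(["Hello, world!", "How are you?"])
assert len(vectors) == 2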

View File

@@ -1,4 +1,4 @@
-from abc import abstractmethod, ABC
+from abc import ABC, abstractmethod
from typing import List
from dbgpt.core import LLMClient

View File

@@ -1,7 +1,7 @@
from typing import List, Optional
from dbgpt._private.llm_metadata import LLMMetadata
-from dbgpt.core import LLMClient, ModelRequest, ModelMessageRoleType
+from dbgpt.core import LLMClient, ModelMessageRoleType, ModelRequest
from dbgpt.rag.chunk import Chunk
from dbgpt.rag.extractor.base import Extractor
from dbgpt.util import utils

View File

@@ -1,5 +1,5 @@
import logging
-from typing import Any, Optional, Callable, Tuple, List
+from typing import Any, Callable, List, Optional, Tuple
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
@@ -87,9 +87,10 @@ class RAGGraphEngine:
def _llm_extract_triplets(self, text: str) -> List[Tuple[str, str, str]]:
"""Extract triplets from text by llm"""
+import uuid
from dbgpt.app.scene import ChatScene
from dbgpt.util.chat_util import llm_chat_response_nostream
-import uuid
chat_param = {
"chat_session_id": uuid.uuid1(),

View File

@@ -1,4 +1,5 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Type

View File

@@ -2,11 +2,11 @@ import logging
import os
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
-from typing import List, Optional, Dict, Any, Set, Callable
+from typing import Any, Callable, Dict, List, Optional, Set
from langchain.schema import Document
-from dbgpt.rag.graph.node import BaseNode, TextNode, NodeWithScore
+from dbgpt.rag.graph.node import BaseNode, NodeWithScore, TextNode
from dbgpt.rag.graph.search import BaseSearch, SearchMode
logger = logging.getLogger(__name__)
@@ -77,9 +77,10 @@ class RAGGraphSearch(BaseSearch):
async def _extract_entities_by_llm(self, text: str) -> Set[str]:
"""extract subject entities from text by llm"""
+import uuid
from dbgpt.app.scene import ChatScene
from dbgpt.util.chat_util import llm_chat_response_nostream
-import uuid
chat_param = {
"chat_session_id": uuid.uuid1(),

View File

@@ -11,9 +11,8 @@ from typing import Dict, List, Optional, Sequence, Set
from dataclasses_json import DataClassJsonMixin
from dbgpt.rag.graph.index_type import IndexStructType
-from dbgpt.rag.graph.node import TextNode, BaseNode
+from dbgpt.rag.graph.node import BaseNode, TextNode
# TODO: legacy backport of old Node class
Node = TextNode

View File

@@ -1,4 +1,5 @@
from typing import List, Optional
from llama_index.data_structs.data_structs import IndexStruct
from llama_index.storage.index_store.utils import (
index_struct_to_json,

View File

@@ -8,9 +8,9 @@ from hashlib import sha256
from typing import Any, Dict, List, Optional, Union
from langchain.schema import Document
-from dbgpt._private.pydantic import BaseModel, Field, root_validator
from typing_extensions import Self
+from dbgpt._private.pydantic import BaseModel, Field, root_validator
DEFAULT_TEXT_NODE_TMPL = "{metadata_str}\n\n{content}"
DEFAULT_METADATA_TMPL = "{key}: {value}"

View File

@@ -1,14 +1,14 @@
-from abc import abstractmethod, ABC
+from abc import ABC, abstractmethod
from enum import Enum
-from typing import Optional, Any, List
+from typing import Any, List, Optional
from dbgpt.rag.chunk import Document
from dbgpt.rag.text_splitter.text_splitter import (
-    RecursiveCharacterTextSplitter,
-    MarkdownHeaderTextSplitter,
-    ParagraphTextSplitter,
    CharacterTextSplitter,
+    MarkdownHeaderTextSplitter,
    PageTextSplitter,
+    ParagraphTextSplitter,
+    RecursiveCharacterTextSplitter,
    SeparatorTextSplitter,
)

View File

@@ -1,11 +1,12 @@
-from typing import Optional, Any, List
import csv
+from typing import Any, List, Optional
from dbgpt.rag.chunk import Document
from dbgpt.rag.knowledge.base import (
-    KnowledgeType,
-    Knowledge,
    ChunkStrategy,
    DocumentType,
+    Knowledge,
+    KnowledgeType,
)

View File

@@ -1,13 +1,14 @@
-from typing import Optional, Any, List
+from typing import Any, List, Optional
+import docx
from dbgpt.rag.chunk import Document
from dbgpt.rag.knowledge.base import (
-    KnowledgeType,
-    Knowledge,
    ChunkStrategy,
    DocumentType,
+    Knowledge,
+    KnowledgeType,
)
-import docx
class DocxKnowledge(Knowledge):

View File

@@ -1,7 +1,6 @@
-from typing import Optional
-from typing import List
+from typing import List, Optional
-from dbgpt.rag.knowledge.base import KnowledgeType, Knowledge
+from dbgpt.rag.knowledge.base import Knowledge, KnowledgeType
from dbgpt.rag.knowledge.string import StringKnowledge
from dbgpt.rag.knowledge.url import URLKnowledge
@@ -32,11 +31,21 @@ class KnowledgeFactory:
Args:
datasource: path of the file to convert
knowledge_type: type of knowledge
-Example:
+Examples:
    .. code-block:: python
-        >>> from dbgpt.rag.knowledge.factory import KnowledgeFactory
-        >>> url_knowlege = KnowledgeFactory.create(datasource="https://www.baidu.com", knowledge_type=KnowledgeType.URL)
-        >>> doc_knowlege = KnowledgeFactory.create(datasource="path/to/document.pdf", knowledge_type=KnowledgeType.DOCUMENT)
+        from dbgpt.rag.knowledge.factory import KnowledgeFactory
+        url_knowlege = KnowledgeFactory.create(
+            datasource="https://www.baidu.com", knowledge_type=KnowledgeType.URL
+        )
+        doc_knowlege = KnowledgeFactory.create(
+            datasource="path/to/document.pdf",
+            knowledge_type=KnowledgeType.DOCUMENT,
+        )
"""
match knowledge_type:
case KnowledgeType.DOCUMENT:
@@ -57,13 +66,22 @@ class KnowledgeFactory:
knowledge_type: Optional[KnowledgeType] = KnowledgeType.DOCUMENT,
) -> Knowledge:
"""Create knowledge from path
Args:
param file_path: path of the file to convert
param knowledge_type: type of knowledge
-Example:
+Examples:
    .. code-block:: python
-        >>> from dbgpt.rag.knowledge.factory import KnowledgeFactory
-        >>> doc_knowlege = KnowledgeFactory.create(datasource="path/to/document.pdf", knowledge_type=KnowledgeType.DOCUMENT)
+        from dbgpt.rag.knowledge.factory import KnowledgeFactory
+        doc_knowlege = KnowledgeFactory.create(
+            datasource="path/to/document.pdf",
+            knowledge_type=KnowledgeType.DOCUMENT,
+        )
"""
factory = cls(file_path=file_path, knowledge_type=knowledge_type)
return factory._select_document_knowledge(
@@ -76,13 +94,21 @@ class KnowledgeFactory:
knowledge_type: Optional[KnowledgeType] = KnowledgeType.URL,
) -> Knowledge:
"""Create knowledge from url
Args:
param url: url of the file to convert
param knowledge_type: type of knowledge
-Example:
+Examples:
    .. code-block:: python
-        >>> from dbgpt.rag.knowledge.factory import KnowledgeFactory
-        >>> url_knowlege = KnowledgeFactory.create(datasource="https://www.baidu.com", knowledge_type=KnowledgeType.URL)
+        from dbgpt.rag.knowledge.factory import KnowledgeFactory
+        url_knowlege = KnowledgeFactory.create(
+            datasource="https://www.baidu.com", knowledge_type=KnowledgeType.URL
+        )
"""
return URLKnowledge(
url=url,
@@ -130,14 +156,14 @@ class KnowledgeFactory:
def _get_knowledge_subclasses() -> List[Knowledge]:
"""get all knowledge subclasses"""
from dbgpt.rag.knowledge.base import Knowledge
-from dbgpt.rag.knowledge.pdf import PDFKnowledge
-from dbgpt.rag.knowledge.docx import DocxKnowledge
-from dbgpt.rag.knowledge.markdown import MarkdownKnowledge
from dbgpt.rag.knowledge.csv import CSVKnowledge
-from dbgpt.rag.knowledge.txt import TXTKnowledge
-from dbgpt.rag.knowledge.pptx import PPTXKnowledge
+from dbgpt.rag.knowledge.docx import DocxKnowledge
from dbgpt.rag.knowledge.html import HTMLKnowledge
-from dbgpt.rag.knowledge.url import URLKnowledge
+from dbgpt.rag.knowledge.markdown import MarkdownKnowledge
+from dbgpt.rag.knowledge.pdf import PDFKnowledge
+from dbgpt.rag.knowledge.pptx import PPTXKnowledge
from dbgpt.rag.knowledge.string import StringKnowledge
+from dbgpt.rag.knowledge.txt import TXTKnowledge
+from dbgpt.rag.knowledge.url import URLKnowledge
return Knowledge.__subclasses__()
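Knowledge.__subclasses__() only reports classes whose defining modules have been imported, which is why _get_knowledge_subclasses imports every loader first. A standalone sketch of the pattern (hypothetical classes, not the dbgpt ones):

class Base:
    pass


class CSVLoader(Base):
    pass


class PDFLoader(Base):
    pass


# Only subclasses that have actually been defined/imported show up here.
print([cls.__name__ for cls in Base.__subclasses__()])  # ['CSVLoader', 'PDFLoader']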

View File

@@ -1,13 +1,13 @@
-from typing import Optional, Any, List
+from typing import Any, List, Optional
import chardet
from dbgpt.rag.chunk import Document
from dbgpt.rag.knowledge.base import (
-    Knowledge,
-    KnowledgeType,
    ChunkStrategy,
    DocumentType,
+    Knowledge,
+    KnowledgeType,
)

View File

@@ -1,11 +1,11 @@
-from typing import Optional, Any, List
+from typing import Any, List, Optional
from dbgpt.rag.chunk import Document
from dbgpt.rag.knowledge.base import (
-    KnowledgeType,
-    Knowledge,
    ChunkStrategy,
    DocumentType,
+    Knowledge,
+    KnowledgeType,
)

View File

@@ -1,11 +1,11 @@
-from typing import Optional, Any, List
+from typing import Any, List, Optional
from dbgpt.rag.chunk import Document
from dbgpt.rag.knowledge.base import (
-    Knowledge,
-    KnowledgeType,
    ChunkStrategy,
    DocumentType,
+    Knowledge,
+    KnowledgeType,
)

View File

@@ -1,11 +1,11 @@
-from typing import Optional, Any, List
+from typing import Any, List, Optional
from dbgpt.rag.chunk import Document
from dbgpt.rag.knowledge.base import (
-    Knowledge,
-    KnowledgeType,
    ChunkStrategy,
    DocumentType,
+    Knowledge,
+    KnowledgeType,
)

View File

@@ -1,7 +1,7 @@
-from typing import Optional, Any, List
+from typing import Any, List, Optional
from dbgpt.rag.chunk import Document
-from dbgpt.rag.knowledge.base import KnowledgeType, Knowledge, ChunkStrategy
+from dbgpt.rag.knowledge.base import ChunkStrategy, Knowledge, KnowledgeType
class StringKnowledge(Knowledge):

View File

@@ -1,6 +1,7 @@
-import pytest
from unittest.mock import MagicMock, mock_open, patch
+import pytest
from dbgpt.rag.knowledge.csv import CSVKnowledge
MOCK_CSV_DATA = "id,name,age\n1,John Doe,30\n2,Jane Smith,25\n3,Bob Johnson,40"
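These tests feed MOCK_CSV_DATA through unittest.mock.mock_open so no real file is touched. A self-contained sketch of that pattern (hypothetical reader function, not the actual test body):

from unittest.mock import mock_open, patch

MOCK_CSV_DATA = "id,name,age\n1,John Doe,30\n2,Jane Smith,25\n3,Bob Johnson,40"


def first_row(path: str) -> str:
    with open(path) as f:
        return f.readline().strip()


with patch("builtins.open", mock_open(read_data=MOCK_CSV_DATA)):
    assert first_row("any.csv") == "id,name,age"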

View File

@@ -1,6 +1,7 @@
-import pytest
from unittest.mock import MagicMock, patch
+import pytest
from dbgpt.rag.knowledge.docx import DocxKnowledge

View File

@@ -1,6 +1,7 @@
-import pytest
from unittest.mock import mock_open, patch
+import pytest
from dbgpt.rag.knowledge.html import HTMLKnowledge
MOCK_HTML_CONTENT = b"""

View File

@@ -1,6 +1,7 @@
-import pytest
from unittest.mock import mock_open, patch
+import pytest
from dbgpt.rag.knowledge.markdown import MarkdownKnowledge
MOCK_MARKDOWN_DATA = """# Header 1

View File

@@ -1,5 +1,6 @@
+from unittest.mock import MagicMock, mock_open, patch
import pytest
-from unittest.mock import MagicMock, patch, mock_open
from dbgpt.rag.knowledge.pdf import PDFKnowledge

View File

@@ -1,6 +1,7 @@
-import pytest
from unittest.mock import mock_open, patch
+import pytest
from dbgpt.rag.knowledge.txt import TXTKnowledge
MOCK_TXT_CONTENT = b"Sample text content for testing.\nAnother line of text."

View File

@@ -1,13 +1,13 @@
-from typing import Optional, Any, List
+from typing import Any, List, Optional
import chardet
from dbgpt.rag.chunk import Document
from dbgpt.rag.knowledge.base import (
-    Knowledge,
-    KnowledgeType,
    ChunkStrategy,
    DocumentType,
+    Knowledge,
+    KnowledgeType,
)

View File

@@ -1,7 +1,7 @@
-from typing import Optional, Any, List
+from typing import Any, List, Optional
from dbgpt.rag.chunk import Document
-from dbgpt.rag.knowledge.base import KnowledgeType, Knowledge, ChunkStrategy
+from dbgpt.rag.knowledge.base import ChunkStrategy, Knowledge, KnowledgeType
class URLKnowledge(Knowledge):

View File

@@ -1,4 +1,5 @@
from typing import Any
from dbgpt.core.interface.retriever import RetrieverOperator
from dbgpt.datasource.rdbms.base import RDBMSDatabase
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary

View File

@@ -2,7 +2,7 @@ from typing import Any, List, Optional
from dbgpt.core.awel import MapOperator
from dbgpt.core.awel.task.base import IN
-from dbgpt.rag.knowledge.base import KnowledgeType, Knowledge
+from dbgpt.rag.knowledge.base import Knowledge, KnowledgeType
from dbgpt.rag.knowledge.factory import KnowledgeFactory

View File

@@ -1,4 +1,4 @@
-from typing import Any, Optional, List
+from typing import Any, List, Optional
from dbgpt.core import LLMClient
from dbgpt.core.awel import MapOperator

View File

@@ -1,4 +1,4 @@
-from typing import Any, Optional, List
+from typing import Any, List, Optional
from dbgpt.core import LLMClient
from dbgpt.core.awel import MapOperator

View File

@@ -1,6 +1,7 @@
from abc import ABC, abstractmethod
from enum import Enum
from typing import List, Tuple
from dbgpt.rag.chunk import Chunk

View File

@@ -1,13 +1,13 @@
from functools import reduce
from typing import List, Optional
-from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
-from dbgpt.util.chat_util import run_async_tasks
from dbgpt.datasource.rdbms.base import RDBMSDatabase
from dbgpt.rag.chunk import Chunk
from dbgpt.rag.retriever.base import BaseRetriever
-from dbgpt.rag.retriever.rerank import Ranker, DefaultRanker
+from dbgpt.rag.retriever.rerank import DefaultRanker, Ranker
+from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
from dbgpt.storage.vector_store.connector import VectorStoreConnector
+from dbgpt.util.chat_util import run_async_tasks
class DBSchemaRetriever(BaseRetriever):
@@ -29,50 +29,59 @@ class DBSchemaRetriever(BaseRetriever):
query_rewrite (bool): query rewrite
rerank (Ranker): rerank
vector_store_connector (VectorStoreConnector): vector store connector
-code example:
-.. code-block:: python
-    >>> from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect
-    >>> from dbgpt.serve.rag.assembler.db_schema import DBSchemaAssembler
-    >>> from dbgpt.storage.vector_store.connector import VectorStoreConnector
-    >>> from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
-    >>> from dbgpt.rag.retriever.embedding import EmbeddingRetriever
-    def _create_temporary_connection():
-        connect = SQLiteTempConnect.create_temporary_db()
-        connect.create_temp_tables(
-            {
-                "user": {
-                    "columns": {
-                        "id": "INTEGER PRIMARY KEY",
-                        "name": "TEXT",
-                        "age": "INTEGER",
-                    },
-                    "data": [
-                        (1, "Tom", 10),
-                        (2, "Jerry", 16),
-                        (3, "Jack", 18),
-                        (4, "Alice", 20),
-                        (5, "Bob", 22),
-                    ],
-                }
-            }
-        )
-        return connect
-    connection = _create_temporary_connection()
-    vector_store_config = ChromaVectorConfig(name="vector_store_name")
-    embedding_model_path = "{your_embedding_model_path}"
-    embedding_fn = embedding_factory.create(
-        model_name=embedding_model_path
-    )
-    vector_connector = VectorStoreConnector.from_default(
-        "Chroma",
-        vector_store_config=vector_store_config,
-        embedding_fn=embedding_fn
-    )
-    # get db struct retriever
-    retriever = DBSchemaRetriever(top_k=3, vector_store_connector=vector_connector)
-    chunks = retriever.retrieve("show columns from table")
-    print(f"db struct rag example results:{[chunk.content for chunk in chunks]}")
+Examples:
+    .. code-block:: python
+        from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect
+        from dbgpt.serve.rag.assembler.db_schema import DBSchemaAssembler
+        from dbgpt.storage.vector_store.connector import VectorStoreConnector
+        from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
+        from dbgpt.rag.retriever.embedding import EmbeddingRetriever
+        def _create_temporary_connection():
+            connect = SQLiteTempConnect.create_temporary_db()
+            connect.create_temp_tables(
+                {
+                    "user": {
+                        "columns": {
+                            "id": "INTEGER PRIMARY KEY",
+                            "name": "TEXT",
+                            "age": "INTEGER",
+                        },
+                        "data": [
+                            (1, "Tom", 10),
+                            (2, "Jerry", 16),
+                            (3, "Jack", 18),
+                            (4, "Alice", 20),
+                            (5, "Bob", 22),
+                        ],
+                    }
+                }
+            )
+            return connect
+        connection = _create_temporary_connection()
+        vector_store_config = ChromaVectorConfig(name="vector_store_name")
+        embedding_model_path = "{your_embedding_model_path}"
+        embedding_fn = embedding_factory.create(model_name=embedding_model_path)
+        vector_connector = VectorStoreConnector.from_default(
+            "Chroma",
+            vector_store_config=vector_store_config,
+            embedding_fn=embedding_fn,
+        )
+        # get db struct retriever
+        retriever = DBSchemaRetriever(
+            top_k=3, vector_store_connector=vector_connector
+        )
+        chunks = retriever.retrieve("show columns from table")
+        print(
+            f"db struct rag example results:{[chunk.content for chunk in chunks]}"
+        )
"""
self._top_k = top_k

View File

@@ -1,12 +1,12 @@
from functools import reduce
from typing import List, Optional
-from dbgpt.rag.retriever.rewrite import QueryRewrite
-from dbgpt.util.chat_util import run_async_tasks
from dbgpt.rag.chunk import Chunk
from dbgpt.rag.retriever.base import BaseRetriever
-from dbgpt.rag.retriever.rerank import Ranker, DefaultRanker
+from dbgpt.rag.retriever.rerank import DefaultRanker, Ranker
+from dbgpt.rag.retriever.rewrite import QueryRewrite
from dbgpt.storage.vector_store.connector import VectorStoreConnector
+from dbgpt.util.chat_util import run_async_tasks
class EmbeddingRetriever(BaseRetriever):
@@ -25,31 +25,38 @@ class EmbeddingRetriever(BaseRetriever):
query_rewrite (Optional[QueryRewrite]): query rewrite
rerank (Ranker): rerank
vector_store_connector (VectorStoreConnector): vector store connector
-code example:
+Examples:
    .. code-block:: python
-        >>> from dbgpt.storage.vector_store.connector import VectorStoreConnector
-        >>> from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
-        >>> from dbgpt.rag.retriever.embedding import EmbeddingRetriever
-        >>> from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory
-        embedding_factory = DefaultEmbeddingFactory()
-        from dbgpt.rag.retriever.embedding import EmbeddingRetriever
-        from dbgpt.storage.vector_store.connector import VectorStoreConnector
-        embedding_fn = embedding_factory.create(
-            model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
-        )
-        vector_name = "test"
-        config = ChromaVectorConfig(name=vector_name, embedding_fn=embedding_fn)
-        vector_store_connector = VectorStoreConnector(
-            vector_store_type=""Chroma"",
-            vector_store_config=config,
-        )
-        embedding_retriever = EmbeddingRetriever(
-            top_k=3, vector_store_connector=vector_store_connector
-        )
-        chunks = embedding_retriever.retrieve("your query text")
-        print(f"embedding retriever results:{[chunk.content for chunk in chunks]}")
+        from dbgpt.storage.vector_store.connector import VectorStoreConnector
+        from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
+        from dbgpt.rag.retriever.embedding import EmbeddingRetriever
+        from dbgpt.rag.embedding.embedding_factory import (
+            DefaultEmbeddingFactory,
+        )
+        embedding_factory = DefaultEmbeddingFactory()
+        from dbgpt.rag.retriever.embedding import EmbeddingRetriever
+        from dbgpt.storage.vector_store.connector import VectorStoreConnector
+        embedding_fn = embedding_factory.create(
+            model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
+        )
+        vector_name = "test"
+        config = ChromaVectorConfig(name=vector_name, embedding_fn=embedding_fn)
+        vector_store_connector = VectorStoreConnector(
+            vector_store_type="Chroma",
+            vector_store_config=config,
+        )
+        embedding_retriever = EmbeddingRetriever(
+            top_k=3, vector_store_connector=vector_store_connector
+        )
+        chunks = embedding_retriever.retrieve("your query text")
+        print(
+            f"embedding retriever results:{[chunk.content for chunk in chunks]}"
+        )
"""
self._top_k = top_k
self._query_rewrite = query_rewrite

View File

@@ -1,5 +1,6 @@
from typing import List, Optional
-from dbgpt.core import LLMClient, ModelMessage, ModelRequest, ModelMessageRoleType
+from dbgpt.core import LLMClient, ModelMessage, ModelMessageRoleType, ModelRequest
REWRITE_PROMPT_TEMPLATE_EN = """
Based on the given context {context}, Generate {nums} search queries related to: {original_query}, Provide following comma-separated format: 'queries: <queries>'":
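The template asks the model to answer on a single 'queries: <queries>' line with comma-separated items. A hedged sketch of parsing that shape (hypothetical helper; the library's actual parsing code is not shown in this hunk):

def parse_rewritten_queries(response: str, nums: int) -> list:
    """Pull at most `nums` comma-separated queries out of a 'queries: ...' reply."""
    _, _, tail = response.partition("queries:")
    return [q.strip() for q in tail.split(",") if q.strip()][:nums]


print(parse_rewritten_queries("queries: a, b, c", 2))  # -> ['a', 'b']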

View File

@@ -1,6 +1,7 @@
-import pytest
-from unittest.mock import MagicMock, patch, AsyncMock
from typing import List
+from unittest.mock import AsyncMock, MagicMock, patch
+import pytest
import dbgpt
from dbgpt.rag.chunk import Chunk

View File

@@ -1,4 +1,5 @@
+from unittest.mock import MagicMock
import pytest
from dbgpt.rag.chunk import Chunk

View File

@@ -1,12 +1,9 @@
import logging
import traceback
-from dbgpt.component import SystemApp
-from dbgpt._private.config import Config
-from dbgpt.configs.model_config import (
-    EMBEDDING_MODEL_CONFIG,
-)
+from dbgpt._private.config import Config
+from dbgpt.component import SystemApp
+from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG
from dbgpt.rag.summary.rdbms_db_summary import RdbmsSummary
logger = logging.getLogger(__name__)
@@ -44,8 +41,8 @@ class DBSummaryClient:
def get_db_summary(self, dbname, query, topk):
"""get user query related tables info"""
-from dbgpt.storage.vector_store.connector import VectorStoreConnector
from dbgpt.storage.vector_store.base import VectorStoreConfig
+from dbgpt.storage.vector_store.connector import VectorStoreConnector
vector_store_config = VectorStoreConfig(name=dbname + "_profile")
vector_connector = VectorStoreConnector.from_default(
@@ -82,8 +79,8 @@ class DBSummaryClient:
dbname(str): dbname
"""
vector_store_name = dbname + "_profile"
-from dbgpt.storage.vector_store.connector import VectorStoreConnector
from dbgpt.storage.vector_store.base import VectorStoreConfig
+from dbgpt.storage.vector_store.connector import VectorStoreConnector
vector_store_config = VectorStoreConfig(name=vector_store_name)
vector_connector = VectorStoreConnector.from_default(
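Note that VectorStoreConnector is imported inside the method rather than at module top level, a common way to break circular-import cycles between the storage and summary modules. A sketch of the idiom (the from_default arguments here are assumptions based on the calls visible elsewhere in this diff):

def _profile_connector(dbname: str):
    # Deferred (function-local) imports: resolved at call time, so a cycle
    # between dbgpt.storage and dbgpt.rag.summary never triggers at import.
    from dbgpt.storage.vector_store.base import VectorStoreConfig
    from dbgpt.storage.vector_store.connector import VectorStoreConnector

    config = VectorStoreConfig(name=dbname + "_profile")
    return VectorStoreConnector.from_default("Chroma", vector_store_config=config)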

View File

@@ -1,7 +1,8 @@
from typing import List
from dbgpt._private.config import Config
-from dbgpt.rag.summary.db_summary import DBSummary
from dbgpt.datasource.rdbms.base import RDBMSDatabase
+from dbgpt.rag.summary.db_summary import DBSummary
CFG = Config()

View File

@@ -1,6 +1,6 @@
from typing import Iterable, List
-from dbgpt.rag.chunk import Document, Chunk
+from dbgpt.rag.chunk import Chunk, Document
from dbgpt.rag.text_splitter.text_splitter import TextSplitter

View File

@@ -1,7 +1,7 @@
import copy
import logging
import re
-from abc import abstractmethod, ABC
+from abc import ABC, abstractmethod
from typing import (
Any,
Callable,
@@ -14,7 +14,7 @@ from typing import (
Union,
)
-from dbgpt.rag.chunk import Document, Chunk
+from dbgpt.rag.chunk import Chunk, Document
logger = logging.getLogger(__name__)

View File

@@ -4,7 +4,7 @@ from typing import Callable, List, Optional
from pydantic import BaseModel, Field, PrivateAttr
from dbgpt.util.global_helper import globals_helper
-from dbgpt.util.splitter_utils import split_by_sep, split_by_char
+from dbgpt.util.splitter_utils import split_by_char, split_by_sep
DEFAULT_METADATA_FORMAT_LEN = 2
DEFAULT_CHUNK_OVERLAP = 20
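DEFAULT_CHUNK_OVERLAP = 20 means adjacent chunks share a 20-token tail, so text spanning a boundary survives intact in at least one chunk. A toy illustration with small numbers (not the splitter's actual algorithm):

def sliding_chunks(tokens, size, overlap):
    # Each chunk starts `size - overlap` tokens after the previous one,
    # so consecutive chunks repeat `overlap` tokens at the seam.
    step = size - overlap
    return [tokens[i : i + size] for i in range(0, len(tokens) - overlap, step)]


print(sliding_chunks(list(range(10)), size=4, overlap=2))
# -> [[0, 1, 2, 3], [2, 3, 4, 5], [4, 5, 6, 7], [6, 7, 8, 9]]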