Mirror of https://github.com/csunny/DB-GPT.git, synced 2025-09-13 05:01:25 +00:00
refactor: Refactor proxy LLM (#1064)
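The hunks below share one pattern: import blocks are regrouped and alphabetized (standard library first, then third-party packages, then first-party dbgpt modules), and docstring examples are reformatted with wrapped calls and trailing commas. This is consistent with running an import sorter and a code formatter over the tree (isort and black are an assumption here; the commit itself does not name its tools). A minimal sketch of the ordering convention, taken from one of the test-file hunks below:

    # Before: third-party import ahead of stdlib, names unsorted
    import pytest
    from unittest.mock import MagicMock, patch, AsyncMock

    # After: stdlib group first, names alphabetized, blank line between groups
    from unittest.mock import AsyncMock, MagicMock, patch

    import pytest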
@@ -2,7 +2,7 @@ import json
 import uuid
 from typing import Any, Dict
 
-from pydantic import Field, BaseModel
+from pydantic import BaseModel, Field
 
 
 class Document(BaseModel):
@@ -1,5 +1,5 @@
 from enum import Enum
-from typing import Optional, List, Any
+from typing import Any, List, Optional
 
 from pydantic import BaseModel, Field
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 from abc import ABC, abstractmethod
-from typing import Any, Type, TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Type
+
 from dbgpt.component import BaseComponent
 from dbgpt.rag.embedding.embeddings import HuggingFaceEmbeddings
@@ -3,7 +3,7 @@ from abc import ABC, abstractmethod
 from typing import Any, Dict, List, Optional
 
 import requests
-from pydantic import Field, Extra, BaseModel
+from pydantic import BaseModel, Extra, Field
 
 DEFAULT_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
 DEFAULT_INSTRUCT_MODEL = "hkunlp/instructor-large"
@@ -54,12 +54,12 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
             from .embeddings import HuggingFaceEmbeddings
 
             model_name = "sentence-transformers/all-mpnet-base-v2"
-            model_kwargs = {'device': 'cpu'}
-            encode_kwargs = {'normalize_embeddings': False}
+            model_kwargs = {"device": "cpu"}
+            encode_kwargs = {"normalize_embeddings": False}
             hf = HuggingFaceEmbeddings(
                 model_name=model_name,
                 model_kwargs=model_kwargs,
-                encode_kwargs=encode_kwargs
+                encode_kwargs=encode_kwargs,
             )
         """
@@ -142,12 +142,12 @@ class HuggingFaceInstructEmbeddings(BaseModel, Embeddings):
             from langchain.embeddings import HuggingFaceInstructEmbeddings
 
             model_name = "hkunlp/instructor-large"
-            model_kwargs = {'device': 'cpu'}
-            encode_kwargs = {'normalize_embeddings': True}
+            model_kwargs = {"device": "cpu"}
+            encode_kwargs = {"normalize_embeddings": True}
             hf = HuggingFaceInstructEmbeddings(
                 model_name=model_name,
                 model_kwargs=model_kwargs,
-                encode_kwargs=encode_kwargs
+                encode_kwargs=encode_kwargs,
             )
         """
@@ -221,12 +221,12 @@ class HuggingFaceBgeEmbeddings(BaseModel, Embeddings):
             from langchain.embeddings import HuggingFaceBgeEmbeddings
 
             model_name = "BAAI/bge-large-en"
-            model_kwargs = {'device': 'cpu'}
-            encode_kwargs = {'normalize_embeddings': True}
+            model_kwargs = {"device": "cpu"}
+            encode_kwargs = {"normalize_embeddings": True}
             hf = HuggingFaceBgeEmbeddings(
                 model_name=model_name,
                 model_kwargs=model_kwargs,
-                encode_kwargs=encode_kwargs
+                encode_kwargs=encode_kwargs,
             )
         """
@@ -336,7 +336,7 @@ class HuggingFaceInferenceAPIEmbeddings(BaseModel, Embeddings):
 
             hf_embeddings = HuggingFaceInferenceAPIEmbeddings(
                 api_key="your_api_key",
-                model_name="sentence-transformers/all-MiniLM-l6-v2"
+                model_name="sentence-transformers/all-MiniLM-l6-v2",
             )
             texts = ["Hello, world!", "How are you?"]
             hf_embeddings.embed_documents(texts)
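The four docstring fixes above converge on one call shape for the embedding wrappers. A runnable sketch of that shape, assuming the langchain and sentence-transformers packages the docstrings rely on are installed:

    from langchain.embeddings import HuggingFaceEmbeddings

    model_name = "sentence-transformers/all-mpnet-base-v2"
    model_kwargs = {"device": "cpu"}  # double quotes, per the reformatting above
    encode_kwargs = {"normalize_embeddings": False}
    hf = HuggingFaceEmbeddings(
        model_name=model_name,
        model_kwargs=model_kwargs,
        encode_kwargs=encode_kwargs,  # trailing comma added by the formatter
    )
    # Embeddings expose embed_documents() for batches of texts
    vectors = hf.embed_documents(["Hello, world!", "How are you?"])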
@@ -1,4 +1,4 @@
-from abc import abstractmethod, ABC
+from abc import ABC, abstractmethod
 from typing import List
 
 from dbgpt.core import LLMClient
@@ -1,7 +1,7 @@
 from typing import List, Optional
 
 from dbgpt._private.llm_metadata import LLMMetadata
-from dbgpt.core import LLMClient, ModelRequest, ModelMessageRoleType
+from dbgpt.core import LLMClient, ModelMessageRoleType, ModelRequest
 from dbgpt.rag.chunk import Chunk
 from dbgpt.rag.extractor.base import Extractor
 from dbgpt.util import utils
@@ -1,5 +1,5 @@
 import logging
-from typing import Any, Optional, Callable, Tuple, List
+from typing import Any, Callable, List, Optional, Tuple
 
 from langchain.schema import Document
 from langchain.text_splitter import RecursiveCharacterTextSplitter
@@ -87,9 +87,10 @@ class RAGGraphEngine:
 
     def _llm_extract_triplets(self, text: str) -> List[Tuple[str, str, str]]:
        """Extract triplets from text by llm"""
+        import uuid
+
        from dbgpt.app.scene import ChatScene
        from dbgpt.util.chat_util import llm_chat_response_nostream
-        import uuid
 
        chat_param = {
            "chat_session_id": uuid.uuid1(),
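Both this hunk and the matching one in RAGGraphSearch below move the function-local import uuid above the dbgpt imports, applying the same stdlib-first grouping inside function bodies. For context, uuid.uuid1() builds a time-based UUID, so each extraction call gets a distinct chat_session_id; a minimal sketch:

    import uuid

    # uuid1() mixes the host identifier with a timestamp, so successive calls
    # yield unique session ids, as used for "chat_session_id" above.
    session_id = uuid.uuid1()
    print(str(session_id))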
@@ -1,4 +1,5 @@
 from __future__ import annotations
+
 from abc import ABC, abstractmethod
 from typing import Any, Type
@@ -2,11 +2,11 @@ import logging
 import os
 from collections import defaultdict
 from concurrent.futures import ThreadPoolExecutor
-from typing import List, Optional, Dict, Any, Set, Callable
+from typing import Any, Callable, Dict, List, Optional, Set
 
 from langchain.schema import Document
 
-from dbgpt.rag.graph.node import BaseNode, TextNode, NodeWithScore
+from dbgpt.rag.graph.node import BaseNode, NodeWithScore, TextNode
 from dbgpt.rag.graph.search import BaseSearch, SearchMode
 
 logger = logging.getLogger(__name__)
@@ -77,9 +77,10 @@ class RAGGraphSearch(BaseSearch):
 
     async def _extract_entities_by_llm(self, text: str) -> Set[str]:
        """extract subject entities from text by llm"""
+        import uuid
+
        from dbgpt.app.scene import ChatScene
        from dbgpt.util.chat_util import llm_chat_response_nostream
-        import uuid
 
        chat_param = {
            "chat_session_id": uuid.uuid1(),
@@ -11,9 +11,8 @@ from typing import Dict, List, Optional, Sequence, Set
 
 from dataclasses_json import DataClassJsonMixin
 
-
 from dbgpt.rag.graph.index_type import IndexStructType
-from dbgpt.rag.graph.node import TextNode, BaseNode
+from dbgpt.rag.graph.node import BaseNode, TextNode
 
 # TODO: legacy backport of old Node class
 Node = TextNode
@@ -1,4 +1,5 @@
 from typing import List, Optional
+
 from llama_index.data_structs.data_structs import IndexStruct
 from llama_index.storage.index_store.utils import (
     index_struct_to_json,
@@ -8,9 +8,9 @@ from hashlib import sha256
 from typing import Any, Dict, List, Optional, Union
 
 from langchain.schema import Document
-from dbgpt._private.pydantic import BaseModel, Field, root_validator
 from typing_extensions import Self
 
+from dbgpt._private.pydantic import BaseModel, Field, root_validator
 
 DEFAULT_TEXT_NODE_TMPL = "{metadata_str}\n\n{content}"
 DEFAULT_METADATA_TMPL = "{key}: {value}"
@@ -1,14 +1,14 @@
-from abc import abstractmethod, ABC
+from abc import ABC, abstractmethod
 from enum import Enum
-from typing import Optional, Any, List
+from typing import Any, List, Optional
 
 from dbgpt.rag.chunk import Document
 from dbgpt.rag.text_splitter.text_splitter import (
-    RecursiveCharacterTextSplitter,
-    MarkdownHeaderTextSplitter,
-    ParagraphTextSplitter,
     CharacterTextSplitter,
+    MarkdownHeaderTextSplitter,
+    PageTextSplitter,
+    ParagraphTextSplitter,
+    RecursiveCharacterTextSplitter,
     SeparatorTextSplitter,
 )
@@ -1,11 +1,12 @@
-from typing import Optional, Any, List
 import csv
+from typing import Any, List, Optional
 
 from dbgpt.rag.chunk import Document
 from dbgpt.rag.knowledge.base import (
-    KnowledgeType,
-    Knowledge,
     ChunkStrategy,
     DocumentType,
+    Knowledge,
+    KnowledgeType,
 )
@@ -1,13 +1,14 @@
-from typing import Optional, Any, List
+from typing import Any, List, Optional
+
+import docx
 
 from dbgpt.rag.chunk import Document
 from dbgpt.rag.knowledge.base import (
-    KnowledgeType,
-    Knowledge,
     ChunkStrategy,
     DocumentType,
+    Knowledge,
+    KnowledgeType,
 )
-import docx
 
 
 class DocxKnowledge(Knowledge):
@@ -1,7 +1,6 @@
-from typing import Optional
-from typing import List
+from typing import List, Optional
 
-from dbgpt.rag.knowledge.base import KnowledgeType, Knowledge
+from dbgpt.rag.knowledge.base import Knowledge, KnowledgeType
 from dbgpt.rag.knowledge.string import StringKnowledge
 from dbgpt.rag.knowledge.url import URLKnowledge
@@ -32,11 +31,21 @@ class KnowledgeFactory:
         Args:
             datasource: path of the file to convert
             knowledge_type: type of knowledge
-        Example:
+
+        Examples:
+
             .. code-block:: python
-                >>> from dbgpt.rag.knowledge.factory import KnowledgeFactory
-                >>> url_knowlege = KnowledgeFactory.create(datasource="https://www.baidu.com", knowledge_type=KnowledgeType.URL)
-                >>> doc_knowlege = KnowledgeFactory.create(datasource="path/to/document.pdf", knowledge_type=KnowledgeType.DOCUMENT)
+
+                from dbgpt.rag.knowledge.factory import KnowledgeFactory
+
+                url_knowlege = KnowledgeFactory.create(
+                    datasource="https://www.baidu.com", knowledge_type=KnowledgeType.URL
+                )
+                doc_knowlege = KnowledgeFactory.create(
+                    datasource="path/to/document.pdf",
+                    knowledge_type=KnowledgeType.DOCUMENT,
+                )
+
         """
         match knowledge_type:
             case KnowledgeType.DOCUMENT:
@@ -57,13 +66,22 @@ class KnowledgeFactory:
         knowledge_type: Optional[KnowledgeType] = KnowledgeType.DOCUMENT,
     ) -> Knowledge:
         """Create knowledge from path
 
         Args:
             param file_path: path of the file to convert
             param knowledge_type: type of knowledge
-        Example:
+
+        Examples:
+
             .. code-block:: python
-                >>> from dbgpt.rag.knowledge.factory import KnowledgeFactory
-                >>> doc_knowlege = KnowledgeFactory.create(datasource="path/to/document.pdf", knowledge_type=KnowledgeType.DOCUMENT)
+
+                from dbgpt.rag.knowledge.factory import KnowledgeFactory
+
+                doc_knowlege = KnowledgeFactory.create(
+                    datasource="path/to/document.pdf",
+                    knowledge_type=KnowledgeType.DOCUMENT,
+                )
+
         """
         factory = cls(file_path=file_path, knowledge_type=knowledge_type)
         return factory._select_document_knowledge(
@@ -76,13 +94,21 @@ class KnowledgeFactory:
         knowledge_type: Optional[KnowledgeType] = KnowledgeType.URL,
     ) -> Knowledge:
         """Create knowledge from url
 
         Args:
             param url: url of the file to convert
             param knowledge_type: type of knowledge
-        Example:
+
+        Examples:
+
             .. code-block:: python
-                >>> from dbgpt.rag.knowledge.factory import KnowledgeFactory
-                >>> url_knowlege = KnowledgeFactory.create(datasource="https://www.baidu.com", knowledge_type=KnowledgeType.URL)
+
+                from dbgpt.rag.knowledge.factory import KnowledgeFactory
+
+                url_knowlege = KnowledgeFactory.create(
+                    datasource="https://www.baidu.com", knowledge_type=KnowledgeType.URL
+                )
+
         """
         return URLKnowledge(
             url=url,
@@ -130,14 +156,14 @@ class KnowledgeFactory:
     def _get_knowledge_subclasses() -> List[Knowledge]:
         """get all knowledge subclasses"""
         from dbgpt.rag.knowledge.base import Knowledge
-        from dbgpt.rag.knowledge.pdf import PDFKnowledge
-        from dbgpt.rag.knowledge.docx import DocxKnowledge
-        from dbgpt.rag.knowledge.markdown import MarkdownKnowledge
         from dbgpt.rag.knowledge.csv import CSVKnowledge
-        from dbgpt.rag.knowledge.txt import TXTKnowledge
-        from dbgpt.rag.knowledge.pptx import PPTXKnowledge
+        from dbgpt.rag.knowledge.docx import DocxKnowledge
         from dbgpt.rag.knowledge.html import HTMLKnowledge
-        from dbgpt.rag.knowledge.url import URLKnowledge
+        from dbgpt.rag.knowledge.markdown import MarkdownKnowledge
+        from dbgpt.rag.knowledge.pdf import PDFKnowledge
+        from dbgpt.rag.knowledge.pptx import PPTXKnowledge
         from dbgpt.rag.knowledge.string import StringKnowledge
+        from dbgpt.rag.knowledge.txt import TXTKnowledge
+        from dbgpt.rag.knowledge.url import URLKnowledge
 
         return Knowledge.__subclasses__()
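The imports in _get_knowledge_subclasses are only resorted here, but they are load-bearing: Knowledge.__subclasses__() only reports classes whose defining modules have actually been imported, so each import line registers one knowledge type. A self-contained sketch of the mechanism, with toy classes standing in for the real ones:

    class Knowledge:
        """Toy stand-in for dbgpt.rag.knowledge.base.Knowledge."""

    # Merely defining (i.e. importing) a subclass registers it with the base.
    class PDFKnowledge(Knowledge):
        pass

    class CSVKnowledge(Knowledge):
        pass

    # Mirrors `return Knowledge.__subclasses__()` in the hunk above.
    print(Knowledge.__subclasses__())  # [<class 'PDFKnowledge'>, <class 'CSVKnowledge'>]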
@@ -1,13 +1,13 @@
-from typing import Optional, Any, List
+from typing import Any, List, Optional
 
 import chardet
 
 from dbgpt.rag.chunk import Document
 from dbgpt.rag.knowledge.base import (
-    Knowledge,
-    KnowledgeType,
     ChunkStrategy,
     DocumentType,
+    Knowledge,
+    KnowledgeType,
 )
@@ -1,11 +1,11 @@
-from typing import Optional, Any, List
+from typing import Any, List, Optional
 
 from dbgpt.rag.chunk import Document
 from dbgpt.rag.knowledge.base import (
-    KnowledgeType,
-    Knowledge,
     ChunkStrategy,
     DocumentType,
+    Knowledge,
+    KnowledgeType,
 )
@@ -1,11 +1,11 @@
-from typing import Optional, Any, List
+from typing import Any, List, Optional
 
 from dbgpt.rag.chunk import Document
 from dbgpt.rag.knowledge.base import (
-    Knowledge,
-    KnowledgeType,
     ChunkStrategy,
     DocumentType,
+    Knowledge,
+    KnowledgeType,
 )
@@ -1,11 +1,11 @@
-from typing import Optional, Any, List
+from typing import Any, List, Optional
 
 from dbgpt.rag.chunk import Document
 from dbgpt.rag.knowledge.base import (
-    Knowledge,
-    KnowledgeType,
     ChunkStrategy,
     DocumentType,
+    Knowledge,
+    KnowledgeType,
 )
@@ -1,7 +1,7 @@
-from typing import Optional, Any, List
+from typing import Any, List, Optional
 
 from dbgpt.rag.chunk import Document
-from dbgpt.rag.knowledge.base import KnowledgeType, Knowledge, ChunkStrategy
+from dbgpt.rag.knowledge.base import ChunkStrategy, Knowledge, KnowledgeType
 
 
 class StringKnowledge(Knowledge):
@@ -1,6 +1,7 @@
-import pytest
 from unittest.mock import MagicMock, mock_open, patch
+
+import pytest
 
 from dbgpt.rag.knowledge.csv import CSVKnowledge
 
 MOCK_CSV_DATA = "id,name,age\n1,John Doe,30\n2,Jane Smith,25\n3,Bob Johnson,40"
@@ -1,6 +1,7 @@
-import pytest
 from unittest.mock import MagicMock, patch
+
+import pytest
 
 from dbgpt.rag.knowledge.docx import DocxKnowledge
@@ -1,6 +1,7 @@
-import pytest
 from unittest.mock import mock_open, patch
+
+import pytest
 
 from dbgpt.rag.knowledge.html import HTMLKnowledge
 
 MOCK_HTML_CONTENT = b"""
@@ -1,6 +1,7 @@
-import pytest
 from unittest.mock import mock_open, patch
+
+import pytest
 
 from dbgpt.rag.knowledge.markdown import MarkdownKnowledge
 
 MOCK_MARKDOWN_DATA = """# Header 1
|
@@ -1,5 +1,6 @@
|
||||
from unittest.mock import MagicMock, mock_open, patch
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch, mock_open
|
||||
|
||||
from dbgpt.rag.knowledge.pdf import PDFKnowledge
|
||||
|
||||
|
@@ -1,6 +1,7 @@
-import pytest
 from unittest.mock import mock_open, patch
+
+import pytest
 
 from dbgpt.rag.knowledge.txt import TXTKnowledge
 
 MOCK_TXT_CONTENT = b"Sample text content for testing.\nAnother line of text."
@@ -1,13 +1,13 @@
-from typing import Optional, Any, List
+from typing import Any, List, Optional
 
 import chardet
 
 from dbgpt.rag.chunk import Document
 from dbgpt.rag.knowledge.base import (
-    Knowledge,
-    KnowledgeType,
     ChunkStrategy,
     DocumentType,
+    Knowledge,
+    KnowledgeType,
 )
@@ -1,7 +1,7 @@
-from typing import Optional, Any, List
+from typing import Any, List, Optional
 
 from dbgpt.rag.chunk import Document
-from dbgpt.rag.knowledge.base import KnowledgeType, Knowledge, ChunkStrategy
+from dbgpt.rag.knowledge.base import ChunkStrategy, Knowledge, KnowledgeType
 
 
 class URLKnowledge(Knowledge):
@@ -1,4 +1,5 @@
 from typing import Any
+
 from dbgpt.core.interface.retriever import RetrieverOperator
 from dbgpt.datasource.rdbms.base import RDBMSDatabase
 from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
@@ -2,7 +2,7 @@ from typing import Any, List, Optional
 
 from dbgpt.core.awel import MapOperator
 from dbgpt.core.awel.task.base import IN
-from dbgpt.rag.knowledge.base import KnowledgeType, Knowledge
+from dbgpt.rag.knowledge.base import Knowledge, KnowledgeType
 from dbgpt.rag.knowledge.factory import KnowledgeFactory
@@ -1,4 +1,4 @@
-from typing import Any, Optional, List
+from typing import Any, List, Optional
 
 from dbgpt.core import LLMClient
 from dbgpt.core.awel import MapOperator
@@ -1,4 +1,4 @@
-from typing import Any, Optional, List
+from typing import Any, List, Optional
 
 from dbgpt.core import LLMClient
 from dbgpt.core.awel import MapOperator
@@ -1,6 +1,7 @@
 from abc import ABC, abstractmethod
 from enum import Enum
 from typing import List, Tuple
+
 from dbgpt.rag.chunk import Chunk
@@ -1,13 +1,13 @@
 from functools import reduce
 from typing import List, Optional
 
-from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
-from dbgpt.util.chat_util import run_async_tasks
 from dbgpt.datasource.rdbms.base import RDBMSDatabase
 from dbgpt.rag.chunk import Chunk
 from dbgpt.rag.retriever.base import BaseRetriever
-from dbgpt.rag.retriever.rerank import Ranker, DefaultRanker
+from dbgpt.rag.retriever.rerank import DefaultRanker, Ranker
+from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
 from dbgpt.storage.vector_store.connector import VectorStoreConnector
+from dbgpt.util.chat_util import run_async_tasks
 
 
 class DBSchemaRetriever(BaseRetriever):
@@ -29,50 +29,59 @@ class DBSchemaRetriever(BaseRetriever):
         query_rewrite (bool): query rewrite
         rerank (Ranker): rerank
         vector_store_connector (VectorStoreConnector): vector store connector
-        code example:
-        .. code-block:: python
-            >>> from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect
-            >>> from dbgpt.serve.rag.assembler.db_schema import DBSchemaAssembler
-            >>> from dbgpt.storage.vector_store.connector import VectorStoreConnector
-            >>> from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
-            >>> from dbgpt.rag.retriever.embedding import EmbeddingRetriever
-
-            def _create_temporary_connection():
-                connect = SQLiteTempConnect.create_temporary_db()
-                connect.create_temp_tables(
-                    {
-                        "user": {
-                            "columns": {
-                                "id": "INTEGER PRIMARY KEY",
-                                "name": "TEXT",
-                                "age": "INTEGER",
-                            },
-                            "data": [
-                                (1, "Tom", 10),
-                                (2, "Jerry", 16),
-                                (3, "Jack", 18),
-                                (4, "Alice", 20),
-                                (5, "Bob", 22),
-                            ],
-                        }
-                    }
-                )
-                return connect
-
-            connection = _create_temporary_connection()
-            vector_store_config = ChromaVectorConfig(name="vector_store_name")
-            embedding_model_path = "{your_embedding_model_path}"
-            embedding_fn = embedding_factory.create(
-                model_name=embedding_model_path
-            )
-            vector_connector = VectorStoreConnector.from_default(
-                "Chroma",
-                vector_store_config=vector_store_config,
-                embedding_fn=embedding_fn
-            )
-            # get db struct retriever
-            retriever = DBSchemaRetriever(top_k=3, vector_store_connector=vector_connector)
-            chunks = retriever.retrieve("show columns from table")
-            print(f"db struct rag example results:{[chunk.content for chunk in chunks]}")
+
+        Examples:
+
+            .. code-block:: python
+
+                from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect
+                from dbgpt.serve.rag.assembler.db_schema import DBSchemaAssembler
+                from dbgpt.storage.vector_store.connector import VectorStoreConnector
+                from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
+                from dbgpt.rag.retriever.embedding import EmbeddingRetriever
+
+
+                def _create_temporary_connection():
+                    connect = SQLiteTempConnect.create_temporary_db()
+                    connect.create_temp_tables(
+                        {
+                            "user": {
+                                "columns": {
+                                    "id": "INTEGER PRIMARY KEY",
+                                    "name": "TEXT",
+                                    "age": "INTEGER",
+                                },
+                                "data": [
+                                    (1, "Tom", 10),
+                                    (2, "Jerry", 16),
+                                    (3, "Jack", 18),
+                                    (4, "Alice", 20),
+                                    (5, "Bob", 22),
+                                ],
+                            }
+                        }
+                    )
+                    return connect
+
+
+                connection = _create_temporary_connection()
+                vector_store_config = ChromaVectorConfig(name="vector_store_name")
+                embedding_model_path = "{your_embedding_model_path}"
+                embedding_fn = embedding_factory.create(model_name=embedding_model_path)
+                vector_connector = VectorStoreConnector.from_default(
+                    "Chroma",
+                    vector_store_config=vector_store_config,
+                    embedding_fn=embedding_fn,
+                )
+                # get db struct retriever
+                retriever = DBSchemaRetriever(
+                    top_k=3, vector_store_connector=vector_connector
+                )
+                chunks = retriever.retrieve("show columns from table")
+                print(
+                    f"db struct rag example results:{[chunk.content for chunk in chunks]}"
+                )
+
         """
 
         self._top_k = top_k
@@ -1,12 +1,12 @@
 from functools import reduce
 from typing import List, Optional
 
-from dbgpt.rag.retriever.rewrite import QueryRewrite
-from dbgpt.util.chat_util import run_async_tasks
 from dbgpt.rag.chunk import Chunk
 from dbgpt.rag.retriever.base import BaseRetriever
-from dbgpt.rag.retriever.rerank import Ranker, DefaultRanker
+from dbgpt.rag.retriever.rerank import DefaultRanker, Ranker
+from dbgpt.rag.retriever.rewrite import QueryRewrite
 from dbgpt.storage.vector_store.connector import VectorStoreConnector
+from dbgpt.util.chat_util import run_async_tasks
 
 
 class EmbeddingRetriever(BaseRetriever):
@@ -25,31 +25,38 @@ class EmbeddingRetriever(BaseRetriever):
         query_rewrite (Optional[QueryRewrite]): query rewrite
         rerank (Ranker): rerank
         vector_store_connector (VectorStoreConnector): vector store connector
-        code example:
+
+        Examples:
+
             .. code-block:: python
-                >>> from dbgpt.storage.vector_store.connector import VectorStoreConnector
-                >>> from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
-                >>> from dbgpt.rag.retriever.embedding import EmbeddingRetriever
-                >>> from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory
 
-                embedding_factory = DefaultEmbeddingFactory()
-                from dbgpt.rag.retriever.embedding import EmbeddingRetriever
-                from dbgpt.storage.vector_store.connector import VectorStoreConnector
+                from dbgpt.storage.vector_store.connector import VectorStoreConnector
+                from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
+                from dbgpt.rag.retriever.embedding import EmbeddingRetriever
+                from dbgpt.rag.embedding.embedding_factory import (
+                    DefaultEmbeddingFactory,
+                )
+
+                embedding_factory = DefaultEmbeddingFactory()
+                from dbgpt.rag.retriever.embedding import EmbeddingRetriever
+                from dbgpt.storage.vector_store.connector import VectorStoreConnector
 
                 embedding_fn = embedding_factory.create(
                     model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
                 )
                 vector_name = "test"
                 config = ChromaVectorConfig(name=vector_name, embedding_fn=embedding_fn)
                 vector_store_connector = VectorStoreConnector(
-                    vector_store_type=""Chroma"",
+                    vector_store_type="Chroma",
                     vector_store_config=config,
                 )
                 embedding_retriever = EmbeddingRetriever(
                     top_k=3, vector_store_connector=vector_store_connector
                 )
                 chunks = embedding_retriever.retrieve("your query text")
-                print(f"embedding retriever results:{[chunk.content for chunk in chunks]}")
+                print(
+                    f"embedding retriever results:{[chunk.content for chunk in chunks]}"
+                )
         """
         self._top_k = top_k
         self._query_rewrite = query_rewrite
@@ -1,5 +1,6 @@
 from typing import List, Optional
-from dbgpt.core import LLMClient, ModelMessage, ModelRequest, ModelMessageRoleType
+
+from dbgpt.core import LLMClient, ModelMessage, ModelMessageRoleType, ModelRequest
 
 REWRITE_PROMPT_TEMPLATE_EN = """
 Based on the given context {context}, Generate {nums} search queries related to: {original_query}, Provide following comma-separated format: 'queries: <queries>'":
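The rewrite template above exposes three placeholders. A minimal sketch of filling it; the argument values here are illustrative only:

    # Placeholders in REWRITE_PROMPT_TEMPLATE_EN: {context}, {nums}, {original_query}
    prompt = REWRITE_PROMPT_TEMPLATE_EN.format(
        context="DB-GPT is a database-oriented LLM framework.",  # illustrative
        nums=3,
        original_query="what is DB-GPT?",  # illustrative
    )
    # The model is asked to answer in the comma-separated 'queries: <queries>' format.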
@@ -1,6 +1,7 @@
-import pytest
-from unittest.mock import MagicMock, patch, AsyncMock
 from typing import List
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
 
 import dbgpt
 from dbgpt.rag.chunk import Chunk
@@ -1,4 +1,5 @@
 from unittest.mock import MagicMock
+
 import pytest
 
 from dbgpt.rag.chunk import Chunk
@@ -1,12 +1,9 @@
 import logging
-
 import traceback
-from dbgpt.component import SystemApp
-from dbgpt._private.config import Config
-from dbgpt.configs.model_config import (
-    EMBEDDING_MODEL_CONFIG,
-)
 
+from dbgpt._private.config import Config
+from dbgpt.component import SystemApp
+from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG
 from dbgpt.rag.summary.rdbms_db_summary import RdbmsSummary
 
 logger = logging.getLogger(__name__)
@@ -44,8 +41,8 @@ class DBSummaryClient:
     def get_db_summary(self, dbname, query, topk):
         """get user query related tables info"""
-        from dbgpt.storage.vector_store.connector import VectorStoreConnector
         from dbgpt.storage.vector_store.base import VectorStoreConfig
+        from dbgpt.storage.vector_store.connector import VectorStoreConnector
 
         vector_store_config = VectorStoreConfig(name=dbname + "_profile")
         vector_connector = VectorStoreConnector.from_default(
@@ -82,8 +79,8 @@ class DBSummaryClient:
             dbname(str): dbname
         """
         vector_store_name = dbname + "_profile"
-        from dbgpt.storage.vector_store.connector import VectorStoreConnector
         from dbgpt.storage.vector_store.base import VectorStoreConfig
+        from dbgpt.storage.vector_store.connector import VectorStoreConnector
 
         vector_store_config = VectorStoreConfig(name=vector_store_name)
         vector_connector = VectorStoreConnector.from_default(
@@ -1,7 +1,8 @@
 from typing import List
 
 from dbgpt._private.config import Config
-from dbgpt.rag.summary.db_summary import DBSummary
 from dbgpt.datasource.rdbms.base import RDBMSDatabase
+from dbgpt.rag.summary.db_summary import DBSummary
 
 CFG = Config()
@@ -1,6 +1,6 @@
 from typing import Iterable, List
 
-from dbgpt.rag.chunk import Document, Chunk
+from dbgpt.rag.chunk import Chunk, Document
 from dbgpt.rag.text_splitter.text_splitter import TextSplitter
@@ -1,7 +1,7 @@
 import copy
 import logging
 import re
-from abc import abstractmethod, ABC
+from abc import ABC, abstractmethod
 from typing import (
     Any,
     Callable,
@@ -14,7 +14,7 @@ from typing import (
     Union,
 )
 
-from dbgpt.rag.chunk import Document, Chunk
+from dbgpt.rag.chunk import Chunk, Document
 
 logger = logging.getLogger(__name__)
|
@@ -4,7 +4,7 @@ from typing import Callable, List, Optional
|
||||
from pydantic import BaseModel, Field, PrivateAttr
|
||||
|
||||
from dbgpt.util.global_helper import globals_helper
|
||||
from dbgpt.util.splitter_utils import split_by_sep, split_by_char
|
||||
from dbgpt.util.splitter_utils import split_by_char, split_by_sep
|
||||
|
||||
DEFAULT_METADATA_FORMAT_LEN = 2
|
||||
DEFAULT_CHUNK_OVERLAP = 20
|
||||
|