Mirror of https://github.com/csunny/DB-GPT.git
feat(rag): Support RAG SDK (#1322)
@@ -11,6 +11,9 @@ follow_imports = skip
[mypy-dbgpt.serve.*]
follow_imports = skip

[mypy-dbgpt.model.*]
follow_imports = skip

[mypy-dbgpt.util.*]
follow_imports = skip
CONTRIBUTING.md
@@ -1,5 +1,11 @@
# Contribution

To contribute to this GitHub project, you can follow these steps:
First of all, thank you for considering contributing to this project.
It's people like you that make it a reality for the community. There are many ways to contribute, and we appreciate all of them.

This guide will help you get started with contributing to this project.

## Fork The Repository

1. Fork the repository you want to contribute to by clicking the "Fork" button on the project page.

@@ -8,71 +14,107 @@ To contribute to this GitHub project, you can follow these steps:
```
git clone https://github.com/<YOUR-GITHUB-USERNAME>/DB-GPT
```
Please replace `<YOUR-GITHUB-USERNAME>` with your GitHub username.

3. Install the project requirements

## Create A New Development Environment

1. Create a new virtual environment using the following command:
```
# Make sure python >= 3.10
conda create -n dbgpt_env python=3.10
conda activate dbgpt_env
```

2. Change to the project directory using the following command:
```
cd DB-GPT
```

3. Install the project from the local source using the following command:
```
# it will take some minutes
pip install -e ".[default]"
```

4. Install pre-commit hooks
4. Install development requirements
```
pip install -r requirements/dev-requirements.txt
pip install -r requirements/lint-requirements.txt
```

5. Install pre-commit hooks
```
pre-commit install
```

5. Create a new branch for your changes using the following command:
6. Install the `make` command
The `make` command is installed by default on most Unix-based systems. If you do not
have it, you can install it by searching on the internet.

## New Branch And Make Changes

1. Create a new branch for your changes using the following command:
```
git checkout -b "branch-name"
git checkout -b <branch-name>
```
Please replace `<branch-name>` with a descriptive name for your branch.

6. Make your changes to the code or documentation.
2. Make your changes to the code or documentation.

- Example: Improve User Interface or Add Documentation.
3. Add tests for your changes if necessary.

7. Format the code using the following command:
4. Format your code using the following command:
```
make fmt
```

8. Add the changes to the staging area using the following command:
5. Run the tests using the following command:
```
git add .
make test
```

9. Make sure the tests pass and your code lints using the following command:
6. Check types using the following command:
```
make pre-commit
make mypy
```

10. Commit the changes with a meaningful commit message using the following command:
7. Check lint using the following command:
```
make fmt-check
```

8. If all checks pass, you can add and commit your changes using the following commands:
```
git add xxxx
```
Make sure to replace `xxxx` with the files you want to commit.

Then commit your changes using the following command:
```
git commit -m "your commit message"
```
11. Push the changes to your forked repository using the following command:
Please replace `your commit message` with a meaningful commit message.

It will take some time to get used to the process, but it's worth it. It runs
all git hooks and checks before you commit; if a check fails, you need to fix the
issues and then re-commit.

9. Push the changes to your forked repository using the following command:
```
git push origin branch-name
git push origin <branch-name>
```
12. Go to the GitHub website and navigate to your forked repository.

13. Click the "New pull request" button.
## Create A Pull Request

14. Select the branch you just pushed to and the branch you want to merge into on the original repository.
1. Go to the GitHub website and navigate to your forked repository.

15. Add a description of your changes and click the "Create pull request" button.
2. Click the "New pull request" button.

16. Wait for the project maintainer to review your changes and provide feedback.
3. Select the branch you just pushed to and the branch you want to merge into on the original repository.
Write the necessary information about your changes and click "Create pull request".

17. Make any necessary changes based on feedback and repeat steps 5-12 until your changes are accepted and merged into the main project.
4. Wait for the project maintainer to review your changes and provide feedback.

18. Once your changes are merged, you can update your forked repository and local copy of the repository with the following commands:

```
git fetch upstream
git checkout master
git merge upstream/master
```
Finally, delete the branch you created with the following command:
```
git branch -d branch-name
```
That's it, you made it 🐣⭐⭐
Makefile
@@ -26,8 +26,7 @@ testenv: $(VENV)/.testenv

$(VENV)/.testenv: $(VENV)/bin/activate
	# $(VENV_BIN)/pip install -e ".[framework]"
	# $(VENV_BIN)/pip install -e ".[knowledge]"
	# the openai optional dependency includes the framework and knowledge dependencies
	# the openai optional dependency includes the framework and rag dependencies
	$(VENV_BIN)/pip install -e ".[openai]"
	touch $(VENV)/.testenv

@@ -100,7 +99,7 @@ package: clean-dist ## Package the project for distribution
	IS_DEV_MODE=false python setup.py sdist bdist_wheel

.PHONY: upload
upload: package ## Upload the package to PyPI
upload: ## Upload the package to PyPI
	# upload to testpypi: twine upload --repository testpypi dist/*
	twine upload dist/*
@@ -1 +1 @@
version = "0.5.1"
version = "0.5.2"
@@ -82,8 +82,10 @@ class LocalEmbeddingFactory(EmbeddingFactory):
        return self._model

    def _load_model(self) -> "Embeddings":
        from dbgpt.model.cluster.embedding.loader import EmbeddingLoader
        from dbgpt.model.cluster.worker.embedding_worker import _parse_embedding_params
        from dbgpt.model.adapter.embeddings_loader import (
            EmbeddingLoader,
            _parse_embedding_params,
        )
        from dbgpt.model.parameter import (
            EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
            BaseEmbeddingModelParameters,
@@ -27,6 +27,8 @@ from dbgpt.component import ComponentType
from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG
from dbgpt.core import Chunk
from dbgpt.model import DefaultLLMClient
from dbgpt.rag.assembler.embedding import EmbeddingAssembler
from dbgpt.rag.assembler.summary import SummaryAssembler
from dbgpt.rag.chunk_manager import ChunkParameters
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
from dbgpt.rag.knowledge.base import ChunkStrategy, KnowledgeType
@@ -36,8 +38,6 @@ from dbgpt.rag.text_splitter.text_splitter import (
    SpacyTextSplitter,
)
from dbgpt.serve.rag.api.schemas import KnowledgeSyncRequest
from dbgpt.serve.rag.assembler.embedding import EmbeddingAssembler
from dbgpt.serve.rag.assembler.summary import SummaryAssembler
from dbgpt.serve.rag.models.models import KnowledgeSpaceDao, KnowledgeSpaceEntity
from dbgpt.serve.rag.service.service import Service, SyncStatus
from dbgpt.storage.vector_store.base import VectorStoreConfig
@@ -347,6 +347,7 @@ class KnowledgeService:
        assembler = EmbeddingAssembler.load_from_knowledge(
            knowledge=knowledge,
            chunk_parameters=chunk_parameters,
            embeddings=embedding_fn,
            vector_store_connector=vector_store_connector,
        )
        chunk_docs = assembler.get_chunks()
@@ -710,6 +710,11 @@ class DAG:
            self.print_tree()
        return _visualize_dag(self, view=view, **kwargs)

    def show(self, mermaid: bool = False) -> Any:
        """Return the graph of current DAG."""
        dot, mermaid_str = _get_graph(self)
        return mermaid_str if mermaid else dot

    def __enter__(self):
        """Enter a DAG context."""
        DAGVar.enter_dag(self)
@@ -813,26 +818,12 @@ def _handle_dag_nodes(
        _handle_dag_nodes(is_down_to_up, level, node, func)


def _visualize_dag(
    dag: DAG, view: bool = True, generate_mermaid: bool = True, **kwargs
) -> Optional[str]:
    """Visualize the DAG.

    Args:
        dag (DAG): The DAG to visualize
        view (bool, optional): Whether view the DAG graph. Defaults to True.
        generate_mermaid (bool, optional): Whether to generate a Mermaid syntax file.
            Defaults to True.

    Returns:
        Optional[str]: The filename of the DAG graph
    """
def _get_graph(dag: DAG):
    try:
        from graphviz import Digraph
    except ImportError:
        logger.warn("Can't import graphviz, skip visualize DAG")
        return None
        return None, None
    dot = Digraph(name=dag.dag_id)
    mermaid_str = "graph TD;\n"  # Initialize Mermaid graph definition
    # Record the added edges to avoid adding duplicate edges
@@ -851,6 +842,26 @@ def _visualize_dag(

    for root in dag.root_nodes:
        add_edges(root)
    return dot, mermaid_str


def _visualize_dag(
    dag: DAG, view: bool = True, generate_mermaid: bool = True, **kwargs
) -> Optional[str]:
    """Visualize the DAG.

    Args:
        dag (DAG): The DAG to visualize
        view (bool, optional): Whether view the DAG graph. Defaults to True.
        generate_mermaid (bool, optional): Whether to generate a Mermaid syntax file.
            Defaults to True.

    Returns:
        Optional[str]: The filename of the DAG graph
    """
    dot, mermaid_str = _get_graph(dag)
    if not dot:
        return None
    filename = f"dag-vis-{dag.dag_id}.gv"
    if "filename" in kwargs:
        filename = kwargs["filename"]
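The hunk above factors graph construction out of `_visualize_dag` into a shared `_get_graph` helper and adds a `DAG.show` method that returns either the Graphviz object or the Mermaid definition. A minimal usage sketch, assuming `DAG` and `MapOperator` are exported from `dbgpt.core.awel` and that the optional `graphviz` package is installed:

```python
from dbgpt.core.awel import DAG, MapOperator

with DAG("vis_example_dag") as dag:
    task = MapOperator(map_function=lambda x: x + 1)

# show() returns a graphviz Digraph by default; with mermaid=True it returns
# the Mermaid definition string built by _get_graph ("graph TD;\n...").
print(dag.show(mermaid=True))
```

If `graphviz` cannot be imported, `_get_graph` returns `(None, None)`, so `show` yields `None` for both forms.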
@@ -1,12 +1,13 @@
"""Common operators of AWEL."""
import asyncio
import logging
from typing import Awaitable, Callable, Dict, Generic, List, Optional, Union
from typing import Any, Awaitable, Callable, Dict, Generic, List, Optional, Union

from ..dag.base import DAGContext
from ..task.base import (
    IN,
    OUT,
    SKIP_DATA,
    InputContext,
    InputSource,
    JoinFunc,
@@ -276,6 +277,11 @@ class InputOperator(BaseOperator, Generic[OUT]):
        curr_task_ctx.set_task_output(task_output)
        return task_output

    @classmethod
    def dummy_input(cls, dummy_data: Any = SKIP_DATA, **kwargs) -> "InputOperator[OUT]":
        """Create a dummy InputOperator with a given input value."""
        return cls(input_source=InputSource.from_data(dummy_data), **kwargs)


class TriggerOperator(InputOperator[OUT], Generic[OUT]):
    """Operator node that triggers the DAG to run."""
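`InputOperator.dummy_input` gives DAG authors a one-line way to seed a flow with a fixed value instead of wiring up an `InputSource` by hand. A sketch under the same import assumptions as above:

```python
from dbgpt.core.awel import DAG, InputOperator, MapOperator

with DAG("dummy_input_dag") as dag:
    # Equivalent to InputOperator(input_source=InputSource.from_data("hello"))
    seed = InputOperator.dummy_input(dummy_data="hello")
    upper = MapOperator(map_function=lambda s: s.upper())
    seed >> upper
```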
@@ -1,6 +1,17 @@
"""Module to define the data source connectors."""

from typing import Any

from .base import BaseConnector  # noqa: F401
from .rdbms.base import RDBMSConnector  # noqa: F401


def __getattr__(name: str) -> Any:
    if name == "RDBMSConnector":
        from .rdbms.base import RDBMSConnector  # noqa: F401

        return RDBMSConnector
    else:
        raise AttributeError(f"Could not find: {name} in datasource")


__ALL__ = ["BaseConnector", "RDBMSConnector"]
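The rewritten `dbgpt/datasource/__init__.py` swaps the eager `RDBMSConnector` import for a module-level `__getattr__` (PEP 562), so the SQLAlchemy-backed `rdbms` module is only imported on first attribute access. Illustration:

```python
import dbgpt.datasource as ds  # cheap: only .base is imported eagerly

# First access triggers the module-level __getattr__ above, which imports
# .rdbms.base (and its heavy dependencies) on demand.
RDBMSConnector = ds.RDBMSConnector
```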
@@ -12,10 +12,10 @@ from ..base import BaseConnector
class DatasourceOperator(MapOperator[str, Any]):
    """The Datasource Operator."""

    def __init__(self, connection: BaseConnector, **kwargs):
    def __init__(self, connector: BaseConnector, **kwargs):
        """Create the datasource operator."""
        super().__init__(**kwargs)
        self._connection = connection
        self._connector = connector

    async def map(self, input_value: str) -> Any:
        """Execute the query."""
@@ -23,4 +23,4 @@ class DatasourceOperator(MapOperator[str, Any]):

    def query(self, input_value: str) -> Any:
        """Execute the query."""
        return self._connection.run_to_df(input_value)
        return self._connector.run_to_df(input_value)
@@ -3,11 +3,11 @@
from __future__ import annotations

import logging
import re
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, cast
from urllib.parse import quote
from urllib.parse import quote_plus as urlquote

import regex as re
import sqlalchemy
import sqlparse
from sqlalchemy import MetaData, Table, create_engine, inspect, select, text
@@ -1,25 +1,27 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Union, cast
import logging
from typing import List, Optional, Type, cast

from dbgpt.model.parameter import BaseEmbeddingModelParameters, ProxyEmbeddingParameters
from dbgpt.util.parameter_utils import _get_dict_from_obj
from dbgpt.configs.model_config import get_device
from dbgpt.core import Embeddings
from dbgpt.model.parameter import (
    BaseEmbeddingModelParameters,
    EmbeddingModelParameters,
    ProxyEmbeddingParameters,
)
from dbgpt.util.parameter_utils import EnvArgumentParser, _get_dict_from_obj
from dbgpt.util.system_utils import get_system_info
from dbgpt.util.tracer import SpanType, SpanTypeRunName, root_tracer

if TYPE_CHECKING:
    from langchain.embeddings.base import Embeddings as LangChainEmbeddings

    from dbgpt.rag.embedding import Embeddings, HuggingFaceEmbeddings
logger = logging.getLogger(__name__)


class EmbeddingLoader:
    def __init__(self) -> None:
        pass

    def load(
        self, model_name: str, param: BaseEmbeddingModelParameters
    ) -> "Union[LangChainEmbeddings, Embeddings]":
    def load(self, model_name: str, param: BaseEmbeddingModelParameters) -> Embeddings:
        metadata = {
            "model_name": model_name,
            "run_service": SpanTypeRunName.EMBEDDING_MODEL.value,
@@ -33,7 +35,9 @@ class EmbeddingLoader:
        if model_name in ["proxy_openai", "proxy_azure"]:
            from langchain.embeddings import OpenAIEmbeddings

            return OpenAIEmbeddings(**param.build_kwargs())
            from dbgpt.rag.embedding._wrapped import WrappedEmbeddings

            return WrappedEmbeddings(OpenAIEmbeddings(**param.build_kwargs()))
        elif model_name in ["proxy_http_openapi"]:
            from dbgpt.rag.embedding import OpenAPIEmbeddings

@@ -51,3 +55,28 @@ class EmbeddingLoader:

        kwargs = param.build_kwargs(model_name=param.model_path)
        return HuggingFaceEmbeddings(**kwargs)


def _parse_embedding_params(
    model_name: Optional[str] = None,
    model_path: Optional[str] = None,
    command_args: List[str] = None,
    param_cls: Optional[Type] = EmbeddingModelParameters,
    **kwargs,
):
    model_args = EnvArgumentParser()
    env_prefix = EnvArgumentParser.get_env_prefix(model_name)
    model_params: BaseEmbeddingModelParameters = model_args.parse_args_into_dataclass(
        param_cls,
        env_prefixes=[env_prefix],
        command_args=command_args,
        model_name=model_name,
        model_path=model_path,
        **kwargs,
    )
    if not model_params.device:
        model_params.device = get_device()
    logger.info(
        f"[EmbeddingsModelWorker] Parameters of device is None, use {model_params.device}"
    )
    return model_params
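`_parse_embedding_params` now lives next to `EmbeddingLoader`, which always returns the unified `dbgpt.core.Embeddings` interface (wrapping LangChain backends in `WrappedEmbeddings` where needed). A sketch of how the pair might be used together; the model name and path here are hypothetical:

```python
from dbgpt.model.adapter.embeddings_loader import (
    EmbeddingLoader,
    _parse_embedding_params,
)
from dbgpt.model.parameter import EmbeddingModelParameters

# Resolve parameters from env vars / command args, then load the model
params = _parse_embedding_params(
    model_name="text2vec",
    model_path="/models/text2vec-large-chinese",  # hypothetical local path
    command_args=[],
    param_cls=EmbeddingModelParameters,
)
embeddings = EmbeddingLoader().load("text2vec", params)
```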
@@ -1,10 +1,12 @@
import logging
from typing import Dict, List, Optional, Type
from typing import Dict, List, Type

from dbgpt.configs.model_config import get_device
from dbgpt.core import ModelMetadata
from dbgpt.model.adapter.embeddings_loader import (
    EmbeddingLoader,
    _parse_embedding_params,
)
from dbgpt.model.adapter.loader import _get_model_real_path
from dbgpt.model.cluster.embedding.loader import EmbeddingLoader
from dbgpt.model.cluster.worker_base import ModelWorker
from dbgpt.model.parameter import (
    EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
@@ -13,14 +15,13 @@ from dbgpt.model.parameter import (
    WorkerType,
)
from dbgpt.util.model_utils import _clear_model_cache
from dbgpt.util.parameter_utils import EnvArgumentParser

logger = logging.getLogger(__name__)


class EmbeddingsModelWorker(ModelWorker):
    def __init__(self) -> None:
        from dbgpt.rag.embedding import Embeddings, HuggingFaceEmbeddings
        from dbgpt.rag.embedding import Embeddings

        self._embeddings_impl: Embeddings = None
        self._model_params = None
@@ -97,26 +98,3 @@ class EmbeddingsModelWorker(ModelWorker):
        logger.info(f"Receive embeddings request, model: {model}")
        input: List[str] = params["input"]
        return self._embeddings_impl.embed_documents(input)


def _parse_embedding_params(
    model_name: str,
    model_path: str,
    command_args: List[str] = None,
    param_cls: Optional[Type] = EmbeddingModelParameters,
):
    model_args = EnvArgumentParser()
    env_prefix = EnvArgumentParser.get_env_prefix(model_name)
    model_params: BaseEmbeddingModelParameters = model_args.parse_args_into_dataclass(
        param_cls,
        env_prefixes=[env_prefix],
        command_args=command_args,
        model_name=model_name,
        model_path=model_path,
    )
    if not model_params.device:
        model_params.device = get_device()
    logger.info(
        f"[EmbeddingsModelWorker] Parameters of device is None, use {model_params.device}"
    )
    return model_params
@@ -1 +1,11 @@
"""Module of RAG."""

from dbgpt.core import Chunk, Document  # noqa: F401

from .chunk_manager import ChunkParameters  # noqa: F401

__ALL__ = [
    "Chunk",
    "Document",
    "ChunkParameters",
]
dbgpt/rag/assembler/__init__.py (new file)
@@ -0,0 +1,16 @@
"""Assembler Module For RAG.

The Assembler is a module that is responsible for assembling the knowledge.
"""

from .base import BaseAssembler  # noqa: F401
from .db_schema import DBSchemaAssembler  # noqa: F401
from .embedding import EmbeddingAssembler  # noqa: F401
from .summary import SummaryAssembler  # noqa: F401

__all__ = [
    "BaseAssembler",
    "DBSchemaAssembler",
    "EmbeddingAssembler",
    "SummaryAssembler",
]
@@ -1,36 +1,41 @@
"""Base Assembler."""
from abc import ABC, abstractmethod
from typing import Any, List, Optional

from dbgpt.core import Chunk
from dbgpt.rag.chunk_manager import ChunkManager, ChunkParameters
from dbgpt.rag.extractor.base import Extractor
from dbgpt.rag.knowledge.base import Knowledge
from dbgpt.rag.retriever.base import BaseRetriever
from dbgpt.util.tracer import root_tracer

from ..chunk_manager import ChunkManager, ChunkParameters
from ..extractor.base import Extractor
from ..knowledge.base import Knowledge
from ..retriever.base import BaseRetriever


class BaseAssembler(ABC):
    """Base Assembler"""
    """Base Assembler."""

    def __init__(
        self,
        knowledge: Optional[Knowledge] = None,
        knowledge: Knowledge,
        chunk_parameters: Optional[ChunkParameters] = None,
        extractor: Optional[Extractor] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize with Assembler arguments.

        Args:
            knowledge: (Knowledge) Knowledge datasource.
            chunk_parameters: (Optional[ChunkParameters]) ChunkManager to use for chunking.
            extractor: (Optional[Extractor]) Extractor to use for summarization."""
            knowledge(Knowledge): Knowledge datasource.
            chunk_parameters: (Optional[ChunkParameters]) ChunkManager to use for
                chunking.
            extractor(Optional[Extractor]): Extractor to use for summarization.
        """
        self._knowledge = knowledge
        self._chunk_parameters = chunk_parameters or ChunkParameters()
        self._extractor = extractor
        self._chunk_manager = ChunkManager(
            knowledge=self._knowledge, chunk_parameter=self._chunk_parameters
        )
        self._chunks = None
        self._chunks: List[Chunk] = []
        metadata = {
            "knowledge_cls": self._knowledge.__class__.__name__
            if self._knowledge
dbgpt/rag/assembler/db_schema.py (new file)
@@ -0,0 +1,135 @@
"""DBSchemaAssembler."""
from typing import Any, List, Optional

from dbgpt.core import Chunk, Embeddings
from dbgpt.datasource.base import BaseConnector
from dbgpt.storage.vector_store.connector import VectorStoreConnector

from ..assembler.base import BaseAssembler
from ..chunk_manager import ChunkParameters
from ..embedding.embedding_factory import DefaultEmbeddingFactory
from ..knowledge.datasource import DatasourceKnowledge
from ..retriever.db_schema import DBSchemaRetriever


class DBSchemaAssembler(BaseAssembler):
    """DBSchemaAssembler.

    Example:
        .. code-block:: python

            from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector
            from dbgpt.serve.rag.assembler.db_struct import DBSchemaAssembler
            from dbgpt.storage.vector_store.connector import VectorStoreConnector
            from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig

            connection = SQLiteTempConnector.create_temporary_db()
            assembler = DBSchemaAssembler.load_from_connection(
                connector=connection,
                embedding_model=embedding_model_path,
            )
            assembler.persist()
            # get db struct retriever
            retriever = assembler.as_retriever(top_k=3)
    """

    def __init__(
        self,
        connector: BaseConnector,
        vector_store_connector: VectorStoreConnector,
        chunk_parameters: Optional[ChunkParameters] = None,
        embedding_model: Optional[str] = None,
        embeddings: Optional[Embeddings] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize with Embedding Assembler arguments.

        Args:
            connector: (BaseConnector) BaseConnector connection.
            vector_store_connector: (VectorStoreConnector) VectorStoreConnector to use.
            chunk_manager: (Optional[ChunkManager]) ChunkManager to use for chunking.
            embedding_model: (Optional[str]) Embedding model to use.
            embeddings: (Optional[Embeddings]) Embeddings to use.
        """
        knowledge = DatasourceKnowledge(connector)
        self._connector = connector
        self._vector_store_connector = vector_store_connector

        self._embedding_model = embedding_model
        if self._embedding_model and not embeddings:
            embeddings = DefaultEmbeddingFactory(
                default_model_name=self._embedding_model
            ).create(self._embedding_model)

        if (
            embeddings
            and self._vector_store_connector.vector_store_config.embedding_fn is None
        ):
            self._vector_store_connector.vector_store_config.embedding_fn = embeddings

        super().__init__(
            knowledge=knowledge,
            chunk_parameters=chunk_parameters,
            **kwargs,
        )

    @classmethod
    def load_from_connection(
        cls,
        connector: BaseConnector,
        vector_store_connector: VectorStoreConnector,
        chunk_parameters: Optional[ChunkParameters] = None,
        embedding_model: Optional[str] = None,
        embeddings: Optional[Embeddings] = None,
    ) -> "DBSchemaAssembler":
        """Load document embedding into vector store from path.

        Args:
            connector: (BaseConnector) BaseConnector connection.
            vector_store_connector: (VectorStoreConnector) VectorStoreConnector to use.
            chunk_parameters: (Optional[ChunkParameters]) ChunkManager to use for
                chunking.
            embedding_model: (Optional[str]) Embedding model to use.
            embeddings: (Optional[Embeddings]) Embeddings to use.
        Returns:
            DBSchemaAssembler
        """
        return cls(
            connector=connector,
            vector_store_connector=vector_store_connector,
            embedding_model=embedding_model,
            chunk_parameters=chunk_parameters,
            embeddings=embeddings,
        )

    def get_chunks(self) -> List[Chunk]:
        """Return chunk ids."""
        return self._chunks

    def persist(self) -> List[str]:
        """Persist chunks into vector store.

        Returns:
            List[str]: List of chunk ids.
        """
        return self._vector_store_connector.load_document(self._chunks)

    def _extract_info(self, chunks) -> List[Chunk]:
        """Extract info from chunks."""
        return []

    def as_retriever(self, top_k: int = 4, **kwargs) -> DBSchemaRetriever:
        """Create DBSchemaRetriever.

        Args:
            top_k(int): default 4.

        Returns:
            DBSchemaRetriever
        """
        return DBSchemaRetriever(
            top_k=top_k,
            connector=self._connector,
            is_embeddings=True,
            vector_store_connector=self._vector_store_connector,
        )
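A hypothetical end-to-end sketch of the new assembler, following the docstring example above; constructing the `VectorStoreConnector` is assumed to happen elsewhere:

```python
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector
from dbgpt.rag.assembler import DBSchemaAssembler

connector = SQLiteTempConnector.create_temporary_db()
assembler = DBSchemaAssembler.load_from_connection(
    connector=connector,
    vector_store_connector=vector_store_connector,  # assumed configured elsewhere
)
assembler.persist()  # index one chunk per table summary into the vector store
retriever = assembler.as_retriever(top_k=3)
chunks = retriever.retrieve("user")  # schema chunks relevant to the query
```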
@@ -1,20 +1,20 @@
import os
"""Embedding Assembler."""
from typing import Any, List, Optional

from dbgpt.core import Chunk
from dbgpt.rag.chunk_manager import ChunkParameters
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
from dbgpt.rag.knowledge.base import Knowledge
from dbgpt.rag.retriever.embedding import EmbeddingRetriever
from dbgpt.serve.rag.assembler.base import BaseAssembler
from dbgpt.core import Chunk, Embeddings
from dbgpt.storage.vector_store.connector import VectorStoreConnector

from ..assembler.base import BaseAssembler
from ..chunk_manager import ChunkParameters
from ..embedding.embedding_factory import DefaultEmbeddingFactory
from ..knowledge.base import Knowledge
from ..retriever.embedding import EmbeddingRetriever


class EmbeddingAssembler(BaseAssembler):
    """Embedding Assembler
    """Embedding Assembler.

    Example:

    .. code-block:: python

        from dbgpt.rag.assembler import EmbeddingAssembler

@@ -30,35 +30,37 @@ class EmbeddingAssembler(BaseAssembler):
    def __init__(
        self,
        knowledge: Knowledge,
        vector_store_connector: VectorStoreConnector,
        chunk_parameters: Optional[ChunkParameters] = None,
        embedding_model: Optional[str] = None,
        embedding_factory: Optional[EmbeddingFactory] = None,
        vector_store_connector: Optional[VectorStoreConnector] = None,
        embeddings: Optional[Embeddings] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize with Embedding Assembler arguments.

        Args:
            knowledge: (Knowledge) Knowledge datasource.
            chunk_parameters: (Optional[ChunkParameters]) ChunkManager to use for chunking.
            vector_store_connector: (VectorStoreConnector) VectorStoreConnector to use.
            chunk_parameters: (Optional[ChunkParameters]) ChunkManager to use for
                chunking.
            embedding_model: (Optional[str]) Embedding model to use.
            embedding_factory: (Optional[EmbeddingFactory]) EmbeddingFactory to use.
            vector_store_connector: (Optional[VectorStoreConnector]) VectorStoreConnector to use.
            embeddings: (Optional[Embeddings]) Embeddings to use.
        """
        if knowledge is None:
            raise ValueError("knowledge datasource must be provided.")
        self._vector_store_connector = vector_store_connector
        from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory

        self._embedding_model = embedding_model
        if self._embedding_model:
            embedding_factory = embedding_factory or DefaultEmbeddingFactory(
        if self._embedding_model and not embeddings:
            embeddings = DefaultEmbeddingFactory(
                default_model_name=self._embedding_model
            )
            self.embedding_fn = embedding_factory.create(self._embedding_model)
            if self._vector_store_connector.vector_store_config.embedding_fn is None:
                self._vector_store_connector.vector_store_config.embedding_fn = (
                    self.embedding_fn
                )
            ).create(self._embedding_model)

        if (
            embeddings
            and self._vector_store_connector.vector_store_config.embedding_fn is None
        ):
            self._vector_store_connector.vector_store_config.embedding_fn = embeddings

        super().__init__(
            knowledge=knowledge,
@@ -70,32 +72,30 @@ class EmbeddingAssembler(BaseAssembler):
    def load_from_knowledge(
        cls,
        knowledge: Knowledge,
        vector_store_connector: VectorStoreConnector,
        chunk_parameters: Optional[ChunkParameters] = None,
        embedding_model: Optional[str] = None,
        embedding_factory: Optional[EmbeddingFactory] = None,
        vector_store_connector: Optional[VectorStoreConnector] = None,
        embeddings: Optional[Embeddings] = None,
    ) -> "EmbeddingAssembler":
        """Load document embedding into vector store from path.

        Args:
            knowledge: (Knowledge) Knowledge datasource.
            chunk_parameters: (Optional[ChunkParameters]) ChunkManager to use for chunking.
            vector_store_connector: (VectorStoreConnector) VectorStoreConnector to use.
            chunk_parameters: (Optional[ChunkParameters]) ChunkManager to use for
                chunking.
            embedding_model: (Optional[str]) Embedding model to use.
            embedding_factory: (Optional[EmbeddingFactory]) EmbeddingFactory to use.
            vector_store_connector: (Optional[VectorStoreConnector]) VectorStoreConnector to use.
            embeddings: (Optional[Embeddings]) Embeddings to use.

        Returns:
            EmbeddingAssembler
        """
        from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory

        embedding_factory = embedding_factory or DefaultEmbeddingFactory(
            default_model_name=embedding_model or os.getenv("EMBEDDING_MODEL_PATH")
        )
        return cls(
            knowledge=knowledge,
            vector_store_connector=vector_store_connector,
            chunk_parameters=chunk_parameters,
            embedding_model=embedding_model,
            embedding_factory=embedding_factory,
            vector_store_connector=vector_store_connector,
            embeddings=embeddings,
        )

    def persist(self) -> List[str]:
@@ -108,12 +108,14 @@ class EmbeddingAssembler(BaseAssembler):

    def _extract_info(self, chunks) -> List[Chunk]:
        """Extract info from chunks."""
        pass
        return []

    def as_retriever(self, top_k: int = 4, **kwargs) -> EmbeddingRetriever:
        """Create a retriever.

    def as_retriever(self, top_k: Optional[int] = 4) -> EmbeddingRetriever:
        """
        Args:
            top_k:(Optional[int]), default 4
            top_k(int): default 4.

        Returns:
            EmbeddingRetriever
        """
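With the new keyword layout, `EmbeddingAssembler.load_from_knowledge` takes the vector store connector as a required argument and accepts a plain `Embeddings` instance instead of an embedding factory. A sketch; `KnowledgeFactory.from_file_path` and the file path are assumptions, and the `VectorStoreConnector` is again assumed to be configured elsewhere:

```python
from dbgpt.rag.assembler import EmbeddingAssembler
from dbgpt.rag.knowledge import KnowledgeFactory

knowledge = KnowledgeFactory.from_file_path("./docs/example.md")  # hypothetical file
assembler = EmbeddingAssembler.load_from_knowledge(
    knowledge=knowledge,
    vector_store_connector=vector_store_connector,  # assumed configured elsewhere
)
assembler.persist()
chunks = assembler.as_retriever(top_k=3).retrieve("What does this project do?")
```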
@@ -1,16 +1,19 @@
"""Summary Assembler."""
import os
from typing import Any, List, Optional

from dbgpt.core import Chunk, LLMClient
from dbgpt.rag.chunk_manager import ChunkParameters
from dbgpt.rag.extractor.base import Extractor
from dbgpt.rag.knowledge.base import Knowledge
from dbgpt.rag.retriever.base import BaseRetriever
from dbgpt.serve.rag.assembler.base import BaseAssembler

from ..assembler.base import BaseAssembler
from ..chunk_manager import ChunkParameters
from ..extractor.base import Extractor
from ..knowledge.base import Knowledge
from ..retriever.base import BaseRetriever


class SummaryAssembler(BaseAssembler):
    """Summary Assembler
    """Summary Assembler.

    Example:
        .. code-block:: python

@@ -40,6 +43,7 @@ class SummaryAssembler(BaseAssembler):
        **kwargs: Any,
    ) -> None:
        """Initialize with Embedding Assembler arguments.

        Args:
            knowledge: (Knowledge) Knowledge datasource.
            chunk_manager: (Optional[ChunkManager]) ChunkManager to use for chunking.
@@ -51,13 +55,24 @@ class SummaryAssembler(BaseAssembler):
        if knowledge is None:
            raise ValueError("knowledge datasource must be provided.")

        self._model_name = model_name or os.getenv("LLM_MODEL")
        self._llm_client = llm_client
        from dbgpt.rag.extractor.summary import SummaryExtractor
        model_name = model_name or os.getenv("LLM_MODEL")

        self._extractor = extractor or SummaryExtractor(
            llm_client=self._llm_client, model_name=self._model_name, language=language
        if not extractor:
            from ..extractor.summary import SummaryExtractor

            if not llm_client:
                raise ValueError("llm_client must be provided.")
            if not model_name:
                raise ValueError("model_name must be provided.")
            extractor = SummaryExtractor(
                llm_client=llm_client,
                model_name=model_name,
                language=language,
            )
        if not extractor:
            raise ValueError("extractor must be provided.")

        self._extractor: Extractor = extractor
        super().__init__(
            knowledge=knowledge,
            chunk_parameters=chunk_parameters,
@@ -77,9 +92,11 @@ class SummaryAssembler(BaseAssembler):
        **kwargs: Any,
    ) -> "SummaryAssembler":
        """Load document embedding into vector store from path.

        Args:
            knowledge: (Knowledge) Knowledge datasource.
            chunk_parameters: (Optional[ChunkParameters]) ChunkManager to use for chunking.
            chunk_parameters: (Optional[ChunkParameters]) ChunkManager to use for
                chunking.
            model_name: (Optional[str]) llm model to use.
            llm_client: (Optional[LLMClient]) LLMClient to use.
            extractor: (Optional[Extractor]) Extractor to use for summarization.
@@ -107,6 +124,8 @@ class SummaryAssembler(BaseAssembler):

    def _extract_info(self, chunks) -> List[Chunk]:
        """Extract info from chunks."""
        return []

    def as_retriever(self, **kwargs: Any) -> BaseRetriever:
        """Return a retriever."""
        raise NotImplementedError
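Per the new validation above, a `SummaryAssembler` built without a custom extractor must receive both `llm_client` and `model_name` (the latter may also come from the `LLM_MODEL` environment variable). A construction sketch; the `OpenAILLMClient` import path is an assumption:

```python
from dbgpt.model.proxy import OpenAILLMClient  # assumed proxy-client location
from dbgpt.rag.assembler import SummaryAssembler

assembler = SummaryAssembler.load_from_knowledge(
    knowledge=knowledge,  # a Knowledge instance, e.g. from KnowledgeFactory
    llm_client=OpenAILLMClient(),
    model_name="gpt-3.5-turbo",
)
```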
@@ -3,11 +3,11 @@ from unittest.mock import MagicMock
import pytest

from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector
from dbgpt.rag.assembler.embedding import EmbeddingAssembler
from dbgpt.rag.chunk_manager import ChunkParameters, SplitterType
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
from dbgpt.rag.knowledge.base import Knowledge
from dbgpt.rag.text_splitter.text_splitter import CharacterTextSplitter
from dbgpt.serve.rag.assembler.embedding import EmbeddingAssembler
from dbgpt.storage.vector_store.connector import VectorStoreConnector


@@ -69,7 +69,7 @@ def test_load_knowledge(
    assembler = EmbeddingAssembler(
        knowledge=mock_knowledge,
        chunk_parameters=mock_chunk_parameters,
        embedding_factory=mock_embedding_factory,
        embeddings=mock_embedding_factory.create(),
        vector_store_connector=mock_vector_store_connector,
    )
    assembler.load_knowledge(knowledge=mock_knowledge)
@@ -1,13 +1,12 @@
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock

import pytest

from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector
from dbgpt.rag.assembler.db_schema import DBSchemaAssembler
from dbgpt.rag.chunk_manager import ChunkParameters, SplitterType
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
from dbgpt.rag.knowledge.base import Knowledge
from dbgpt.rag.text_splitter.text_splitter import CharacterTextSplitter
from dbgpt.serve.rag.assembler.db_schema import DBSchemaAssembler
from dbgpt.storage.vector_store.connector import VectorStoreConnector


@@ -51,14 +50,8 @@ def mock_vector_store_connector():
    return MagicMock(spec=VectorStoreConnector)


@pytest.fixture
def mock_knowledge():
    return MagicMock(spec=Knowledge)


def test_load_knowledge(
    mock_db_connection,
    mock_knowledge,
    mock_chunk_parameters,
    mock_embedding_factory,
    mock_vector_store_connector,
@@ -67,10 +60,9 @@ def test_load_knowledge(
    mock_chunk_parameters.text_splitter = CharacterTextSplitter()
    mock_chunk_parameters.splitter_type = SplitterType.USER_DEFINE
    assembler = DBSchemaAssembler(
        connection=mock_db_connection,
        connector=mock_db_connection,
        chunk_parameters=mock_chunk_parameters,
        embedding_factory=mock_embedding_factory,
        embeddings=mock_embedding_factory.create(),
        vector_store_connector=mock_vector_store_connector,
    )
    assembler.load_knowledge(knowledge=mock_knowledge)
    assert len(assembler._chunks) == 1
@@ -1,6 +1,10 @@
"""Module for embedding related classes and functions."""

from .embedding_factory import DefaultEmbeddingFactory, EmbeddingFactory  # noqa: F401
from .embedding_factory import (  # noqa: F401
    DefaultEmbeddingFactory,
    EmbeddingFactory,
    WrappedEmbeddingFactory,
)
from .embeddings import (  # noqa: F401
    Embeddings,
    HuggingFaceBgeEmbeddings,
@@ -21,4 +25,5 @@ __ALL__ = [
    "OpenAPIEmbeddings",
    "DefaultEmbeddingFactory",
    "EmbeddingFactory",
    "WrappedEmbeddingFactory",
]
dbgpt/rag/embedding/_wrapped.py (new file)
@@ -0,0 +1,32 @@
"""Wraps the third-party language model embeddings to the common interface."""

from typing import TYPE_CHECKING, List

from dbgpt.core import Embeddings

if TYPE_CHECKING:
    from langchain.embeddings.base import Embeddings as LangChainEmbeddings


class WrappedEmbeddings(Embeddings):
    """Wraps the third-party language model embeddings to the common interface."""

    def __init__(self, embeddings: "LangChainEmbeddings") -> None:
        """Create a new WrappedEmbeddings."""
        self._embeddings = embeddings

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed search docs."""
        return self._embeddings.embed_documents(texts)

    def embed_query(self, text: str) -> List[float]:
        """Embed query text."""
        return self._embeddings.embed_query(text)

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Asynchronous Embed search docs."""
        return await self._embeddings.aembed_documents(texts)

    async def aembed_query(self, text: str) -> List[float]:
        """Asynchronous Embed query text."""
        return await self._embeddings.aembed_query(text)
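A usage sketch for the wrapper above, assuming the optional `langchain` dependency is installed:

```python
from langchain.embeddings import OpenAIEmbeddings  # third-party backend

from dbgpt.rag.embedding._wrapped import WrappedEmbeddings

embeddings = WrappedEmbeddings(OpenAIEmbeddings())
doc_vectors = embeddings.embed_documents(["hello", "world"])
query_vector = embeddings.embed_query("hello")
```

The wrapper simply delegates, so the sync and async methods behave exactly like the underlying LangChain implementation while satisfying the `dbgpt.core.Embeddings` interface.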
@@ -1,15 +1,14 @@
"""EmbeddingFactory class and DefaultEmbeddingFactory class."""

from __future__ import annotations

import logging
import os
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Optional, Type
from typing import Any, Optional, Type

from dbgpt.component import BaseComponent, SystemApp
from dbgpt.rag.embedding.embeddings import HuggingFaceEmbeddings
from dbgpt.core import Embeddings

if TYPE_CHECKING:
    from dbgpt.rag.embedding.embeddings import Embeddings
logger = logging.getLogger(__name__)


class EmbeddingFactory(BaseComponent, ABC):
@@ -20,7 +19,7 @@ class EmbeddingFactory(BaseComponent, ABC):
    @abstractmethod
    def create(
        self, model_name: Optional[str] = None, embedding_cls: Optional[Type] = None
    ) -> "Embeddings":
    ) -> Embeddings:
        """Create an embedding instance.

        Args:
@@ -39,12 +38,19 @@ class DefaultEmbeddingFactory(EmbeddingFactory):
        self,
        system_app: Optional[SystemApp] = None,
        default_model_name: Optional[str] = None,
        default_model_path: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Create a new DefaultEmbeddingFactory."""
        super().__init__(system_app=system_app)
        if not default_model_path:
            default_model_path = default_model_name
        if not default_model_name:
            default_model_name = default_model_path
        self._default_model_name = default_model_name
        self.kwargs = kwargs
        self._default_model_path = default_model_path
        self._kwargs = kwargs
        self._model = self._load_model()

    def init_app(self, system_app):
        """Init the app."""
@@ -52,20 +58,166 @@ class DefaultEmbeddingFactory(EmbeddingFactory):

    def create(
        self, model_name: Optional[str] = None, embedding_cls: Optional[Type] = None
    ) -> "Embeddings":
    ) -> Embeddings:
        """Create an embedding instance.

        Args:
            model_name (str): The model name.
            embedding_cls (Type): The embedding class.
        """
        if not model_name:
            model_name = self._default_model_name

        new_kwargs = {k: v for k, v in self.kwargs.items()}
        new_kwargs["model_name"] = model_name

        if embedding_cls:
            return embedding_cls(**new_kwargs)
        else:
            return HuggingFaceEmbeddings(**new_kwargs)
            raise NotImplementedError
        return self._model

    def _load_model(self) -> Embeddings:
        from dbgpt.model.adapter.embeddings_loader import (
            EmbeddingLoader,
            _parse_embedding_params,
        )
        from dbgpt.model.parameter import (
            EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
            BaseEmbeddingModelParameters,
            EmbeddingModelParameters,
        )

        param_cls = EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get(
            self._default_model_name, EmbeddingModelParameters
        )
        model_params: BaseEmbeddingModelParameters = _parse_embedding_params(
            model_name=self._default_model_name,
            model_path=self._default_model_path,
            param_cls=param_cls,
            **self._kwargs,
        )
        logger.info(model_params)
        loader = EmbeddingLoader()
        # Ignore model_name args
        model_name = self._default_model_name or model_params.model_name
        if not model_name:
            raise ValueError("model_name must be provided.")
        return loader.load(model_name, model_params)

    @classmethod
    def openai(
        cls,
        api_url: Optional[str] = None,
        api_key: Optional[str] = None,
        model_name: str = "text-embedding-3-small",
        timeout: int = 60,
        **kwargs: Any,
    ) -> Embeddings:
        """Create an OpenAI embeddings.

        If api_url and api_key are not provided, we will try to get them from
        environment variables.

        Args:
            api_url (Optional[str], optional): The api url. Defaults to None.
            api_key (Optional[str], optional): The api key. Defaults to None.
            model_name (str, optional): The model name.
                Defaults to "text-embedding-3-small".
            timeout (int, optional): The timeout. Defaults to 60.

        Returns:
            Embeddings: The embeddings instance.
        """
        api_url = (
            api_url
            or os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1") + "/embeddings"
        )
        api_key = api_key or os.getenv("OPENAI_API_KEY")
        if not api_key:
            raise ValueError("api_key must be provided.")
        return cls.remote(
            api_url=api_url,
            api_key=api_key,
            model_name=model_name,
            timeout=timeout,
            **kwargs,
        )

    @classmethod
    def default(
        cls, model_name: str, model_path: Optional[str] = None, **kwargs: Any
    ) -> Embeddings:
        """Create a default embeddings.

        It will try to load the model from the model name or model path.

        Args:
            model_name (str): The model name.
            model_path (Optional[str], optional): The model path. Defaults to None.
                if not provided, it will use the model name as the model path to load
                the model.

        Returns:
            Embeddings: The embeddings instance.
        """
        return cls(
            default_model_name=model_name, default_model_path=model_path, **kwargs
        ).create()

    @classmethod
    def remote(
        cls,
        api_url: str = "http://localhost:8100/api/v1/embeddings",
        api_key: Optional[str] = None,
        model_name: str = "text2vec",
        timeout: int = 60,
        **kwargs: Any,
    ) -> Embeddings:
        """Create a remote embeddings.

        Create a remote embeddings which API compatible with the OpenAI's API. So if
        your model is compatible with OpenAI's API, you can use this method to create
        a remote embeddings.

        Args:
            api_url (str, optional): The api url. Defaults to
                "http://localhost:8100/api/v1/embeddings".
            api_key (Optional[str], optional): The api key. Defaults to None.
            model_name (str, optional): The model name. Defaults to "text2vec".
            timeout (int, optional): The timeout. Defaults to 60.
        """
        from .embeddings import OpenAPIEmbeddings

        return OpenAPIEmbeddings(
            api_url=api_url,
            api_key=api_key,
            model_name=model_name,
            timeout=timeout,
            **kwargs,
        )


class WrappedEmbeddingFactory(EmbeddingFactory):
    """The default embedding factory."""

    def __init__(
        self,
        system_app: Optional[SystemApp] = None,
        embeddings: Optional[Embeddings] = None,
        **kwargs: Any,
    ) -> None:
        """Create a new DefaultEmbeddingFactory."""
        super().__init__(system_app=system_app)
        if not embeddings:
            raise ValueError("embeddings must be provided.")
        self._model = embeddings

    def init_app(self, system_app):
        """Init the app."""
        pass

    def create(
        self, model_name: Optional[str] = None, embedding_cls: Optional[Type] = None
    ) -> Embeddings:
        """Create an embedding instance.

        Args:
            model_name (str): The model name.
            embedding_cls (Type): The embedding class.
        """
        if embedding_cls:
            raise NotImplementedError
        return self._model
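The three new classmethods cover the common construction paths for the factory above. A sketch, with hypothetical model paths and endpoints:

```python
from dbgpt.rag.embedding import DefaultEmbeddingFactory

# Local model loaded through the embedding loader (hypothetical path)
local_emb = DefaultEmbeddingFactory.default(
    "text2vec", model_path="/models/text2vec-large-chinese"
)

# OpenAI embeddings; falls back to OPENAI_API_BASE / OPENAI_API_KEY env vars
openai_emb = DefaultEmbeddingFactory.openai(model_name="text-embedding-3-small")

# Any OpenAI-compatible embedding endpoint, e.g. a locally deployed service
remote_emb = DefaultEmbeddingFactory.remote(
    api_url="http://localhost:8100/api/v1/embeddings",
    api_key="EMPTY",
    model_name="text2vec",
)
```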
@@ -1,23 +1,50 @@
"""Module Of Knowledge."""

from .base import ChunkStrategy, Knowledge, KnowledgeType  # noqa: F401
from .csv import CSVKnowledge  # noqa: F401
from .docx import DocxKnowledge  # noqa: F401
from .factory import KnowledgeFactory  # noqa: F401
from .html import HTMLKnowledge  # noqa: F401
from .markdown import MarkdownKnowledge  # noqa: F401
from .pdf import PDFKnowledge  # noqa: F401
from .pptx import PPTXKnowledge  # noqa: F401
from .string import StringKnowledge  # noqa: F401
from .txt import TXTKnowledge  # noqa: F401
from .url import URLKnowledge  # noqa: F401
from typing import Any, Dict

__ALL__ = [
_MODULE_CACHE: Dict[str, Any] = {}


def __getattr__(name: str):
    # Lazy load
    import importlib

    if name in _MODULE_CACHE:
        return _MODULE_CACHE[name]

    _LIBS = {
        "KnowledgeFactory": "factory",
        "Knowledge": "base",
        "KnowledgeType": "base",
        "ChunkStrategy": "base",
        "CSVKnowledge": "csv",
        "DatasourceKnowledge": "datasource",
        "DocxKnowledge": "docx",
        "HTMLKnowledge": "html",
        "MarkdownKnowledge": "markdown",
        "PDFKnowledge": "pdf",
        "PPTXKnowledge": "pptx",
        "StringKnowledge": "string",
        "TXTKnowledge": "txt",
        "URLKnowledge": "url",
    }

    if name in _LIBS:
        module_path = "." + _LIBS[name]
        module = importlib.import_module(module_path, __name__)
        attr = getattr(module, name)
        _MODULE_CACHE[name] = attr
        return attr
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")


__all__ = [
    "KnowledgeFactory",
    "Knowledge",
    "KnowledgeType",
    "ChunkStrategy",
    "CSVKnowledge",
    "DatasourceKnowledge",
    "DocxKnowledge",
    "HTMLKnowledge",
    "MarkdownKnowledge",
@@ -25,6 +25,7 @@ class DocumentType(Enum):
    DOCX = "docx"
    TXT = "txt"
    HTML = "html"
    DATASOURCE = "datasource"


class KnowledgeType(Enum):
dbgpt/rag/knowledge/datasource.py (new file)
@@ -0,0 +1,57 @@
"""Datasource Knowledge."""
from typing import Any, List, Optional

from dbgpt.core import Document
from dbgpt.datasource import BaseConnector

from ..summary.rdbms_db_summary import _parse_db_summary
from .base import ChunkStrategy, DocumentType, Knowledge, KnowledgeType


class DatasourceKnowledge(Knowledge):
    """Datasource Knowledge."""

    def __init__(
        self,
        connector: BaseConnector,
        summary_template: str = "{table_name}({columns})",
        knowledge_type: Optional[KnowledgeType] = KnowledgeType.DOCUMENT,
        **kwargs: Any,
    ) -> None:
        """Create Datasource Knowledge with Knowledge arguments.

        Args:
            connector(BaseConnector): the datasource connector
            summary_template(str): template used to render each table summary
            knowledge_type(KnowledgeType, optional): knowledge type
        """
        self._connector = connector
        self._summary_template = summary_template
        super().__init__(knowledge_type=knowledge_type, **kwargs)

    def _load(self) -> List[Document]:
        """Load datasource document from data_loader."""
        docs = []
        for table_summary in _parse_db_summary(self._connector, self._summary_template):
            docs.append(
                Document(content=table_summary, metadata={"source": "database"})
            )
        return docs

    @classmethod
    def support_chunk_strategy(cls) -> List[ChunkStrategy]:
        """Return support chunk strategy."""
        return [
            ChunkStrategy.CHUNK_BY_SIZE,
            ChunkStrategy.CHUNK_BY_SEPARATOR,
        ]

    @classmethod
    def type(cls) -> KnowledgeType:
        """Knowledge type of Datasource."""
        return KnowledgeType.DOCUMENT

    @classmethod
    def document_type(cls) -> DocumentType:
        """Return document type."""
        return DocumentType.DATASOURCE
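A construction sketch for the new knowledge type; it assumes the base `Knowledge` class exposes a public `load()` that delegates to `_load`:

```python
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector
from dbgpt.rag.knowledge.datasource import DatasourceKnowledge

connector = SQLiteTempConnector.create_temporary_db()
knowledge = DatasourceKnowledge(
    connector, summary_template="{table_name}({columns})"
)
docs = knowledge.load()  # one Document per rendered table summary
```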
@@ -156,6 +156,7 @@ class KnowledgeFactory:
        """Get all knowledge subclasses."""
        from dbgpt.rag.knowledge.base import Knowledge  # noqa: F401
        from dbgpt.rag.knowledge.csv import CSVKnowledge  # noqa: F401
        from dbgpt.rag.knowledge.datasource import DatasourceKnowledge  # noqa: F401
        from dbgpt.rag.knowledge.docx import DocxKnowledge  # noqa: F401
        from dbgpt.rag.knowledge.html import HTMLKnowledge  # noqa: F401
        from dbgpt.rag.knowledge.markdown import MarkdownKnowledge  # noqa: F401
@@ -1,8 +1,14 @@
"""Module for RAG operators."""

from .datasource import DatasourceRetrieverOperator  # noqa: F401
from .db_schema import DBSchemaRetrieverOperator  # noqa: F401
from .embedding import EmbeddingRetrieverOperator  # noqa: F401
from .db_schema import (  # noqa: F401
    DBSchemaAssemblerOperator,
    DBSchemaRetrieverOperator,
)
from .embedding import (  # noqa: F401
    EmbeddingAssemblerOperator,
    EmbeddingRetrieverOperator,
)
from .evaluation import RetrieverEvaluatorOperator  # noqa: F401
from .knowledge import KnowledgeOperator  # noqa: F401
from .rerank import RerankOperator  # noqa: F401
@@ -12,7 +18,9 @@ from .summary import SummaryAssemblerOperator  # noqa: F401
__all__ = [
    "DatasourceRetrieverOperator",
    "DBSchemaRetrieverOperator",
    "DBSchemaAssemblerOperator",
    "EmbeddingRetrieverOperator",
    "EmbeddingAssemblerOperator",
    "KnowledgeOperator",
    "RerankOperator",
    "QueryRewriteOperator",
@@ -1,3 +1,4 @@
"""Base Assembler Operator."""
from abc import abstractmethod

from dbgpt.core.awel import MapOperator
@@ -20,4 +21,4 @@ class AssemblerOperator(MapOperator[IN, OUT]):

    @abstractmethod
    def assemble(self, input_value: IN) -> OUT:
        """assemble knowledge for input value."""
        """Assemble knowledge for input value."""
@@ -1,21 +1,21 @@
"""Datasource operator for RDBMS database."""

from typing import Any
from typing import Any, List

from dbgpt.core.interface.operators.retriever import RetrieverOperator
from dbgpt.datasource.rdbms.base import RDBMSConnector
from dbgpt.datasource.base import BaseConnector
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary


class DatasourceRetrieverOperator(RetrieverOperator[Any, Any]):
class DatasourceRetrieverOperator(RetrieverOperator[Any, List[str]]):
    """The Datasource Retriever Operator."""

    def __init__(self, connection: RDBMSConnector, **kwargs):
    def __init__(self, connector: BaseConnector, **kwargs):
        """Create a new DatasourceRetrieverOperator."""
        super().__init__(**kwargs)
        self._connection = connection
        self._connector = connector

    def retrieve(self, input_value: Any) -> Any:
    def retrieve(self, input_value: Any) -> List[str]:
        """Retrieve the database summary."""
        summary = _parse_db_summary(self._connection)
        summary = _parse_db_summary(self._connector)
        return summary
@@ -1,18 +1,22 @@
"""The DBSchema Retriever Operator."""

from typing import Any, Optional
from typing import List, Optional

from dbgpt.core import Chunk
from dbgpt.core.interface.operators.retriever import RetrieverOperator
from dbgpt.datasource.rdbms.base import RDBMSConnector
from dbgpt.rag.retriever.db_schema import DBSchemaRetriever
from dbgpt.datasource.base import BaseConnector
from dbgpt.storage.vector_store.connector import VectorStoreConnector

from ..assembler.db_schema import DBSchemaAssembler
from ..retriever.db_schema import DBSchemaRetriever
from .assembler import AssemblerOperator


class DBSchemaRetrieverOperator(RetrieverOperator[Any, Any]):
class DBSchemaRetrieverOperator(RetrieverOperator[str, List[Chunk]]):
    """The DBSchema Retriever Operator.

    Args:
        connection (RDBMSConnector): The connection.
        connector (BaseConnector): The connector.
        top_k (int, optional): The top k. Defaults to 4.
        vector_store_connector (VectorStoreConnector, optional): The vector store
            connector. Defaults to None.
@@ -22,21 +26,57 @@ class DBSchemaRetrieverOperator(RetrieverOperator[Any, Any]):
        self,
        vector_store_connector: VectorStoreConnector,
        top_k: int = 4,
        connection: Optional[RDBMSConnector] = None,
        connector: Optional[BaseConnector] = None,
        **kwargs
    ):
        """Create a new DBSchemaRetrieverOperator."""
        super().__init__(**kwargs)
        self._retriever = DBSchemaRetriever(
            top_k=top_k,
            connection=connection,
            connector=connector,
            vector_store_connector=vector_store_connector,
        )

    def retrieve(self, query: Any) -> Any:
    def retrieve(self, query: str) -> List[Chunk]:
        """Retrieve the table schemas.

        Args:
            query (IN): query.
            query (str): The query.
        """
        return self._retriever.retrieve(query)


class DBSchemaAssemblerOperator(AssemblerOperator[BaseConnector, List[Chunk]]):
    """The DBSchema Assembler Operator."""

    def __init__(
        self,
        connector: BaseConnector,
        vector_store_connector: VectorStoreConnector,
        **kwargs
    ):
        """Create a new DBSchemaAssemblerOperator.

        Args:
            connector (BaseConnector): The connector.
            vector_store_connector (VectorStoreConnector): The vector store connector.
        """
        self._vector_store_connector = vector_store_connector
        self._connector = connector
        super().__init__(**kwargs)

    def assemble(self, dummy_value) -> List[Chunk]:
        """Persist the database schema.

        Args:
            dummy_value: Dummy value, not used.

        Returns:
            List[Chunk]: The chunks.
        """
        assembler = DBSchemaAssembler.load_from_connection(
            connector=self._connector,
            vector_store_connector=self._vector_store_connector,
        )
        assembler.persist()
        return assembler.get_chunks()
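Taken together, the assembler covers the write path (persist parsed table schemas into the vector store) and the retriever the read path. A minimal construction sketch, assuming a `connector` and `vector_store_connector` built as in the `simple_rag_db_schema_example` later in this diff:

```python
from dbgpt.core.awel import DAG
from dbgpt.rag.operators import DBSchemaAssemblerOperator, DBSchemaRetrieverOperator

# `connector` and `vector_store_connector` are assumed to exist, built as in
# the simple_rag_db_schema_example shown later in this diff.
with DAG("db_schema_rag_dag") as dag:
    # Write path: parse table schemas from the connector and persist them.
    assembler_task = DBSchemaAssemblerOperator(
        connector=connector,
        vector_store_connector=vector_store_connector,
    )
    # Read path: embed the query and recall the top-k matching schema chunks.
    retriever_task = DBSchemaRetrieverOperator(
        connector=connector,
        top_k=1,
        vector_store_connector=vector_store_connector,
    )
    # In the full example below, a JoinOperator sequences the two branches so
    # retrieval only runs after the schema has been persisted.
```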
@@ -5,11 +5,16 @@ from typing import List, Optional, Union

from dbgpt.core import Chunk
from dbgpt.core.interface.operators.retriever import RetrieverOperator
from dbgpt.rag.retriever.embedding import EmbeddingRetriever
from dbgpt.rag.retriever.rerank import Ranker
from dbgpt.rag.retriever.rewrite import QueryRewrite
from dbgpt.storage.vector_store.connector import VectorStoreConnector

from ..assembler.embedding import EmbeddingAssembler
from ..chunk_manager import ChunkParameters
from ..knowledge import Knowledge
from ..retriever.embedding import EmbeddingRetriever
from ..retriever.rerank import Ranker
from ..retriever.rewrite import QueryRewrite
from .assembler import AssemblerOperator


class EmbeddingRetrieverOperator(RetrieverOperator[Union[str, List[str]], List[Chunk]]):
    """The Embedding Retriever Operator."""
@@ -43,3 +48,36 @@ class EmbeddingRetrieverOperator(RetrieverOperator[Union[str, List[str]], List[C
                for q in query
            ]
            return reduce(lambda x, y: x + y, candidates)


class EmbeddingAssemblerOperator(AssemblerOperator[Knowledge, List[Chunk]]):
    """The Embedding Assembler Operator."""

    def __init__(
        self,
        vector_store_connector: VectorStoreConnector,
        chunk_parameters: Optional[ChunkParameters] = ChunkParameters(
            chunk_strategy="CHUNK_BY_SIZE"
        ),
        **kwargs
    ):
        """Create a new EmbeddingAssemblerOperator.

        Args:
            vector_store_connector (VectorStoreConnector): The vector store connector.
            chunk_parameters (Optional[ChunkParameters], optional): The chunk
                parameters. Defaults to ChunkParameters(chunk_strategy="CHUNK_BY_SIZE").
        """
        self._chunk_parameters = chunk_parameters
        self._vector_store_connector = vector_store_connector
        super().__init__(**kwargs)

    def assemble(self, knowledge: Knowledge) -> List[Chunk]:
        """Assemble knowledge for input value."""
        assembler = EmbeddingAssembler.load_from_knowledge(
            knowledge=knowledge,
            chunk_parameters=self._chunk_parameters,
            vector_store_connector=self._vector_store_connector,
        )
        assembler.persist()
        return assembler.get_chunks()
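Note the retriever's input type `Union[str, List[str]]`: a list of queries retrieves per query and concatenates the candidate chunks with `reduce`. A small sketch of the multi-query path, assuming a `vector_connector` built as in the cookbook added later in this diff:

```python
import asyncio

from dbgpt.core.awel import DAG
from dbgpt.rag.operators import EmbeddingRetrieverOperator

with DAG("multi_query_retriever_dag") as dag:
    retriever_task = EmbeddingRetrieverOperator(
        top_k=3,
        vector_store_connector=vector_connector,  # assumed: see the cookbook below
    )

# A list input runs one retrieval per query and flattens the results
# into a single List[Chunk].
chunks = asyncio.run(retriever_task.call(["What is AWEL?", "What is RAG?"]))
print(len(chunks))
```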
@@ -1,6 +1,6 @@
"""Knowledge Operator."""

from typing import Any, Optional
from typing import Optional

from dbgpt.core.awel import MapOperator
from dbgpt.core.awel.flow import (
@@ -14,7 +14,7 @@ from dbgpt.rag.knowledge.base import Knowledge, KnowledgeType
from dbgpt.rag.knowledge.factory import KnowledgeFactory


class KnowledgeOperator(MapOperator[Any, Any]):
class KnowledgeOperator(MapOperator[str, Knowledge]):
    """Knowledge Factory Operator."""

    metadata = ViewMetadata(
@@ -26,7 +26,7 @@ class KnowledgeOperator(MapOperator[Any, Any]):
            IOField.build_from(
                "knowledge datasource",
                "knowledge datasource",
                dict,
                str,
                "knowledge datasource",
            )
        ],
@@ -85,7 +85,7 @@ class KnowledgeOperator(MapOperator[Any, Any]):
        self._datasource = datasource
        self._knowledge_type = KnowledgeType.get_by_value(knowledge_type)

    async def map(self, datasource: Any) -> Knowledge:
    async def map(self, datasource: str) -> Knowledge:
        """Create knowledge from datasource."""
        if self._datasource:
            datasource = self._datasource
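The operator's contract is now explicit: it maps a plain datasource string (for example, a URL) to a `Knowledge` object. A minimal sketch, reusing the URL knowledge type that the cookbook below also uses:

```python
import asyncio

from dbgpt.core.awel import DAG
from dbgpt.rag.knowledge import KnowledgeType
from dbgpt.rag.operators import KnowledgeOperator

with DAG("knowledge_dag") as dag:
    # Input: a datasource string (here a URL). Output: a Knowledge instance.
    knowledge_task = KnowledgeOperator(knowledge_type=KnowledgeType.URL.name)

knowledge = asyncio.run(knowledge_task.call("https://docs.dbgpt.site/docs/latest/awel/"))
print(type(knowledge))
```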
@@ -1,12 +1,12 @@
"""The Rerank Operator."""
from typing import Any, List, Optional
from typing import List, Optional

from dbgpt.core import Chunk
from dbgpt.core.awel import MapOperator
from dbgpt.rag.retriever.rerank import RANK_FUNC, DefaultRanker


class RerankOperator(MapOperator[Any, Any]):
class RerankOperator(MapOperator[List[Chunk], List[Chunk]]):
    """The Rerank Operator."""

    def __init__(
@@ -7,7 +7,7 @@ from typing import Any, Optional

from dbgpt.core import LLMClient
from dbgpt.core.awel import MapOperator
from dbgpt.datasource.rdbms.base import RDBMSConnector
from dbgpt.datasource.base import BaseConnector
from dbgpt.rag.schemalinker.schema_linking import SchemaLinking
from dbgpt.storage.vector_store.connector import VectorStoreConnector
@@ -17,7 +17,7 @@ class SchemaLinkingOperator(MapOperator[Any, Any]):

    def __init__(
        self,
        connection: RDBMSConnector,
        connector: BaseConnector,
        model_name: str,
        llm: LLMClient,
        top_k: int = 5,
@@ -27,14 +27,14 @@ class SchemaLinkingOperator(MapOperator[Any, Any]):
        """Create the schema linking operator.

        Args:
            connection (RDBMSConnector): The connection.
            connector (BaseConnector): The connector.
            llm (Optional[LLMClient]): base llm
        """
        super().__init__(**kwargs)

        self._schema_linking = SchemaLinking(
            top_k=top_k,
            connection=connection,
            connector=connector,
            llm=llm,
            model_name=model_name,
            vector_store_connector=vector_store_connector,
@@ -4,9 +4,9 @@ from typing import Any, Optional

from dbgpt.core import LLMClient
from dbgpt.core.awel.flow import IOField, OperatorCategory, Parameter, ViewMetadata
from dbgpt.rag.assembler.summary import SummaryAssembler
from dbgpt.rag.knowledge.base import Knowledge
from dbgpt.serve.rag.assembler.summary import SummaryAssembler
from dbgpt.serve.rag.operators.base import AssemblerOperator
from dbgpt.rag.operators.assembler import AssemblerOperator


class SummaryAssemblerOperator(AssemblerOperator[Any, Any]):
@@ -3,7 +3,7 @@ from functools import reduce
from typing import List, Optional, cast

from dbgpt.core import Chunk
from dbgpt.datasource.rdbms.base import RDBMSConnector
from dbgpt.datasource.base import BaseConnector
from dbgpt.rag.retriever.base import BaseRetriever
from dbgpt.rag.retriever.rerank import DefaultRanker, Ranker
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
@@ -18,7 +18,7 @@ class DBSchemaRetriever(BaseRetriever):
        self,
        vector_store_connector: VectorStoreConnector,
        top_k: int = 4,
        connection: Optional[RDBMSConnector] = None,
        connector: Optional[BaseConnector] = None,
        query_rewrite: bool = False,
        rerank: Optional[Ranker] = None,
        **kwargs
@@ -28,7 +28,7 @@ class DBSchemaRetriever(BaseRetriever):
        Args:
            vector_store_connector (VectorStoreConnector): vector store connector
            top_k (int): top k
            connection (Optional[RDBMSConnector]): RDBMSConnector connection.
            connector (Optional[BaseConnector]): The connector.
            query_rewrite (bool): query rewrite
            rerank (Ranker): rerank

@@ -65,7 +65,7 @@ class DBSchemaRetriever(BaseRetriever):
                return connect


            connection = _create_temporary_connection()
            connector = _create_temporary_connection()
            vector_store_config = ChromaVectorConfig(name="vector_store_name")
            embedding_model_path = "{your_embedding_model_path}"
            embedding_fn = embedding_factory.create(model_name=embedding_model_path)
@@ -76,14 +76,16 @@ class DBSchemaRetriever(BaseRetriever):
            )
            # get db struct retriever
            retriever = DBSchemaRetriever(
                top_k=3, vector_store_connector=vector_connector
                top_k=3,
                vector_store_connector=vector_connector,
                connector=connector,
            )
            chunks = retriever.retrieve("show columns from table")
            result = [chunk.content for chunk in chunks]
            print(f"db struct rag example results:{result}")
        """
        self._top_k = top_k
        self._connection = connection
        self._connector = connector
        self._query_rewrite = query_rewrite
        self._vector_store_connector = vector_store_connector
        self._need_embeddings = False
@@ -108,9 +110,9 @@ class DBSchemaRetriever(BaseRetriever):
            ]
            return cast(List[Chunk], reduce(lambda x, y: x + y, candidates))
        else:
            if not self._connection:
            if not self._connector:
                raise RuntimeError("RDBMSConnector connection is required.")
            table_summaries = _parse_db_summary(self._connection)
            table_summaries = _parse_db_summary(self._connector)
            return [Chunk(content=table_summary) for table_summary in table_summaries]

    def _retrieve_with_score(self, query: str, score_threshold: float) -> List[Chunk]:
@@ -173,6 +175,6 @@ class DBSchemaRetriever(BaseRetriever):
        """Similar search."""
        from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary

        if not self._connection:
        if not self._connector:
            raise RuntimeError("RDBMSConnector connection is required.")
        return _parse_db_summary(self._connection)
        return _parse_db_summary(self._connector)
@@ -24,7 +24,7 @@ def mock_vector_store_connector():
@pytest.fixture
def dbstruct_retriever(mock_db_connection, mock_vector_store_connector):
    return DBSchemaRetriever(
        connection=mock_db_connection,
        connector=mock_db_connection,
        vector_store_connector=mock_vector_store_connector,
    )
@@ -10,7 +10,7 @@ from dbgpt.core import (
    ModelMessageRoleType,
    ModelRequest,
)
from dbgpt.datasource.rdbms.base import RDBMSConnector
from dbgpt.datasource.base import BaseConnector
from dbgpt.rag.schemalinker.base_linker import BaseSchemaLinker
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
from dbgpt.storage.vector_store.connector import VectorStoreConnector
@@ -42,7 +42,7 @@ class SchemaLinking(BaseSchemaLinker):

    def __init__(
        self,
        connection: RDBMSConnector,
        connector: BaseConnector,
        model_name: str,
        llm: LLMClient,
        top_k: int = 5,
@@ -52,19 +52,19 @@ class SchemaLinking(BaseSchemaLinker):
        """Create the schema linking instance.

        Args:
            connection (Optional[RDBMSConnector]): RDBMSConnector connection.
            connector (Optional[BaseConnector]): The connector.
            llm (Optional[LLMClient]): base llm
        """
        super().__init__(**kwargs)
        self._top_k = top_k
        self._connection = connection
        self._connector = connector
        self._llm = llm
        self._model_name = model_name
        self._vector_store_connector = vector_store_connector

    def _schema_linking(self, query: str) -> List:
        """Get all db schema info."""
        table_summaries = _parse_db_summary(self._connection)
        table_summaries = _parse_db_summary(self._connector)
        chunks = [Chunk(content=table_summary) for table_summary in table_summaries]
        chunks_content = [chunk.content for chunk in chunks]
        return chunks_content
@@ -97,10 +97,10 @@ class DBSummaryClient:
            vector_store_config=vector_store_config,
        )
        if not vector_connector.vector_name_exists():
            from dbgpt.serve.rag.assembler.db_schema import DBSchemaAssembler
            from dbgpt.rag.assembler.db_schema import DBSchemaAssembler

            db_assembler = DBSchemaAssembler.load_from_connection(
                connection=db_summary_client.db, vector_store_connector=vector_connector
                connector=db_summary_client.db, vector_store_connector=vector_connector
            )
            if len(db_assembler.get_chunks()) > 0:
                db_assembler.persist()
@@ -3,7 +3,7 @@
from typing import TYPE_CHECKING, List, Optional

from dbgpt._private.config import Config
from dbgpt.datasource.rdbms.base import RDBMSConnector
from dbgpt.datasource import BaseConnector
from dbgpt.rag.summary.db_summary import DBSummary

if TYPE_CHECKING:
@@ -64,12 +64,12 @@ class RdbmsSummary(DBSummary):


def _parse_db_summary(
    conn: RDBMSConnector, summary_template: str = "{table_name}({columns})"
    conn: BaseConnector, summary_template: str = "{table_name}({columns})"
) -> List[str]:
    """Get db summary for database.

    Args:
        conn (RDBMSConnector): database connection
        conn (BaseConnector): database connection
        summary_template (str): summary template
    """
    tables = conn.get_table_names()
@@ -81,12 +81,12 @@ def _parse_db_summary(


def _parse_table_summary(
    conn: RDBMSConnector, summary_template: str, table_name: str
    conn: BaseConnector, summary_template: str, table_name: str
) -> str:
    """Get table summary for table.

    Args:
        conn (RDBMSConnector): database connection
        conn (BaseConnector): database connection
        summary_template (str): summary template
        table_name (str): table name
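`_parse_db_summary` renders one summary string per table from `summary_template`. A worked illustration of the default template (the table and column names here are hypothetical):

```python
summary_template = "{table_name}({columns})"

# For a hypothetical `user` table with columns id, name and email:
print(summary_template.format(table_name="user", columns="id, name, email"))
# -> user(id, name, email)
```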
@@ -1,153 +0,0 @@
from typing import Any, List, Optional

from dbgpt.core import Chunk
from dbgpt.datasource.rdbms.base import RDBMSConnector
from dbgpt.rag.chunk_manager import ChunkManager, ChunkParameters
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
from dbgpt.rag.knowledge.base import ChunkStrategy, Knowledge
from dbgpt.rag.knowledge.factory import KnowledgeFactory
from dbgpt.rag.retriever.db_schema import DBSchemaRetriever
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
from dbgpt.serve.rag.assembler.base import BaseAssembler
from dbgpt.storage.vector_store.connector import VectorStoreConnector


class DBSchemaAssembler(BaseAssembler):
    """DBSchemaAssembler
    Example:
    .. code-block:: python

        from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector
        from dbgpt.serve.rag.assembler.db_struct import DBSchemaAssembler
        from dbgpt.storage.vector_store.connector import VectorStoreConnector
        from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig

        connection = SQLiteTempConnector.create_temporary_db()
        assembler = DBSchemaAssembler.load_from_connection(
            connection=connection,
            embedding_model=embedding_model_path,
        )
        assembler.persist()
        # get db struct retriever
        retriever = assembler.as_retriever(top_k=3)
    """

    def __init__(
        self,
        connection: RDBMSConnector = None,
        chunk_parameters: Optional[ChunkParameters] = None,
        embedding_model: Optional[str] = None,
        embedding_factory: Optional[EmbeddingFactory] = None,
        vector_store_connector: Optional[VectorStoreConnector] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize with Embedding Assembler arguments.
        Args:
            connection: (RDBMSConnector) RDBMSConnector connection.
            knowledge: (Knowledge) Knowledge datasource.
            chunk_manager: (Optional[ChunkManager]) ChunkManager to use for chunking.
            embedding_model: (Optional[str]) Embedding model to use.
            embedding_factory: (Optional[EmbeddingFactory]) EmbeddingFactory to use.
            vector_store_connector: (Optional[VectorStoreConnector]) VectorStoreConnector to use.
        """
        if connection is None:
            raise ValueError("datasource connection must be provided.")
        self._connection = connection
        self._vector_store_connector = vector_store_connector
        from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory

        self._embedding_model = embedding_model
        if self._embedding_model:
            embedding_factory = embedding_factory or DefaultEmbeddingFactory(
                default_model_name=self._embedding_model
            )
            self.embedding_fn = embedding_factory.create(self._embedding_model)
        if self._vector_store_connector.vector_store_config.embedding_fn is None:
            self._vector_store_connector.vector_store_config.embedding_fn = (
                self.embedding_fn
            )

        super().__init__(
            chunk_parameters=chunk_parameters,
            **kwargs,
        )

    @classmethod
    def load_from_connection(
        cls,
        connection: RDBMSConnector = None,
        knowledge: Optional[Knowledge] = None,
        chunk_parameters: Optional[ChunkParameters] = None,
        embedding_model: Optional[str] = None,
        embedding_factory: Optional[EmbeddingFactory] = None,
        vector_store_connector: Optional[VectorStoreConnector] = None,
    ) -> "DBSchemaAssembler":
        """Load document embedding into vector store from path.
        Args:
            connection: (RDBMSConnector) RDBMSDatabase connection.
            knowledge: (Knowledge) Knowledge datasource.
            chunk_parameters: (Optional[ChunkParameters]) ChunkManager to use for chunking.
            embedding_model: (Optional[str]) Embedding model to use.
            embedding_factory: (Optional[EmbeddingFactory]) EmbeddingFactory to use.
            vector_store_connector: (Optional[VectorStoreConnector]) VectorStoreConnector to use.
        Returns:
            DBSchemaAssembler
        """
        embedding_factory = embedding_factory
        chunk_parameters = chunk_parameters or ChunkParameters(
            chunk_strategy=ChunkStrategy.CHUNK_BY_SIZE.name, chunk_overlap=0
        )

        return cls(
            connection=connection,
            knowledge=knowledge,
            embedding_model=embedding_model,
            chunk_parameters=chunk_parameters,
            embedding_factory=embedding_factory,
            vector_store_connector=vector_store_connector,
        )

    def load_knowledge(self, knowledge: Optional[Knowledge] = None) -> None:
        table_summaries = _parse_db_summary(self._connection)
        self._chunks = []
        self._knowledge = knowledge
        for table_summary in table_summaries:
            from dbgpt.rag.knowledge.base import KnowledgeType

            self._knowledge = KnowledgeFactory.from_text(
                text=table_summary, knowledge_type=KnowledgeType.DOCUMENT
            )
            self._chunk_parameters.chunk_size = len(table_summary)
            self._chunk_manager = ChunkManager(
                knowledge=self._knowledge, chunk_parameter=self._chunk_parameters
            )
            self._chunks.extend(self._chunk_manager.split(self._knowledge.load()))

    def get_chunks(self) -> List[Chunk]:
        """Return chunk ids."""
        return self._chunks

    def persist(self) -> List[str]:
        """Persist chunks into vector store.

        Returns:
            List[str]: List of chunk ids.
        """
        return self._vector_store_connector.load_document(self._chunks)

    def _extract_info(self, chunks) -> List[Chunk]:
        """Extract info from chunks."""

    def as_retriever(self, top_k: Optional[int] = 4) -> DBSchemaRetriever:
        """
        Args:
            top_k:(Optional[int]), default 4
        Returns:
            DBSchemaRetriever
        """
        return DBSchemaRetriever(
            top_k=top_k,
            connection=self._connection,
            is_embeddings=True,
            vector_store_connector=self._vector_store_connector,
        )
@@ -1,36 +0,0 @@
from typing import Any, Optional

from dbgpt.core.awel.task.base import IN
from dbgpt.datasource.rdbms.base import RDBMSConnector
from dbgpt.serve.rag.assembler.db_schema import DBSchemaAssembler
from dbgpt.serve.rag.operators.base import AssemblerOperator
from dbgpt.storage.vector_store.connector import VectorStoreConnector


class DBSchemaAssemblerOperator(AssemblerOperator[Any, Any]):
    """The DBSchema Assembler Operator.
    Args:
        connection (RDBMSConnector): The connection.
        chunk_parameters (Optional[ChunkParameters], optional): The chunk parameters. Defaults to None.
        vector_store_connector (VectorStoreConnector, optional): The vector store connector. Defaults to None.
    """

    def __init__(
        self,
        connection: RDBMSConnector = None,
        vector_store_connector: Optional[VectorStoreConnector] = None,
        **kwargs
    ):
        self._connection = connection
        self._vector_store_connector = vector_store_connector
        self._assembler = DBSchemaAssembler.load_from_connection(
            connection=self._connection,
            vector_store_connector=self._vector_store_connector,
        )
        super().__init__(**kwargs)

    def assemble(self, input_value: IN) -> Any:
        """assemble knowledge for input value."""
        if self._vector_store_connector:
            self._assembler.persist()
        return self._assembler.get_chunks()
@@ -1,44 +0,0 @@
from typing import Any, Optional

from dbgpt.core.awel.task.base import IN
from dbgpt.rag.chunk_manager import ChunkParameters
from dbgpt.rag.knowledge.base import Knowledge
from dbgpt.serve.rag.assembler.embedding import EmbeddingAssembler
from dbgpt.serve.rag.operators.base import AssemblerOperator
from dbgpt.storage.vector_store.connector import VectorStoreConnector


class EmbeddingAssemblerOperator(AssemblerOperator[Any, Any]):
    """The Embedding Assembler Operator.
    Args:
        knowledge (Knowledge): The knowledge.
        chunk_parameters (Optional[ChunkParameters], optional): The chunk parameters. Defaults to None.
        vector_store_connector (VectorStoreConnector, optional): The vector store connector. Defaults to None.
    """

    def __init__(
        self,
        chunk_parameters: Optional[ChunkParameters] = ChunkParameters(
            chunk_strategy="CHUNK_BY_SIZE"
        ),
        vector_store_connector: VectorStoreConnector = None,
        **kwargs
    ):
        """
        Args:
            chunk_parameters (Optional[ChunkParameters], optional): The chunk parameters. Defaults to ChunkParameters(chunk_strategy="CHUNK_BY_SIZE").
            vector_store_connector (VectorStoreConnector, optional): The vector store connector. Defaults to None.
        """
        self._chunk_parameters = chunk_parameters
        self._vector_store_connector = vector_store_connector
        super().__init__(**kwargs)

    def assemble(self, knowledge: IN) -> Any:
        """assemble knowledge for input value."""
        assembler = EmbeddingAssembler.load_from_knowledge(
            knowledge=knowledge,
            chunk_parameters=self._chunk_parameters,
            vector_store_connector=self._vector_store_connector,
        )
        assembler.persist()
        return assembler.get_chunks()
@@ -23,6 +23,7 @@ from dbgpt.configs.model_config import (
)
from dbgpt.core import Chunk
from dbgpt.core.awel.dag.dag_manager import DAGManager
from dbgpt.rag.assembler import EmbeddingAssembler
from dbgpt.rag.chunk_manager import ChunkParameters
from dbgpt.rag.embedding import EmbeddingFactory
from dbgpt.rag.knowledge import ChunkStrategy, KnowledgeFactory, KnowledgeType
@@ -43,7 +44,6 @@ from ..api.schemas import (
    SpaceServeRequest,
    SpaceServeResponse,
)
from ..assembler.embedding import EmbeddingAssembler
from ..config import SERVE_CONFIG_KEY_PREFIX, SERVE_SERVICE_COMPONENT_NAME, ServeConfig
from ..models.models import KnowledgeSpaceDao, KnowledgeSpaceEntity
@@ -27,8 +27,7 @@ class ChromaVectorConfig(VectorStoreConfig):

    persist_path: str = Field(
        default=os.getenv("CHROMA_PERSIST_PATH", None),
        description="The password of vector store, if not set, will use the default "
        "password.",
        description="the persist path of vector store.",
    )
    collection_metadata: dict = Field(
        default=None,
@@ -1,19 +1,21 @@
import asyncio
from typing import Any, Coroutine, List

from dbgpt.app.scene import BaseChat, ChatFactory

chat_factory = ChatFactory()


async def llm_chat_response_nostream(chat_scene: str, **chat_param):
    """llm_chat_response_nostream"""
    from dbgpt.app.scene import BaseChat, ChatFactory

    chat_factory = ChatFactory()
    chat: BaseChat = chat_factory.get_implementation(chat_scene, **chat_param)
    res = await chat.get_llm_response()
    return res


async def llm_chat_response(chat_scene: str, **chat_param):
    from dbgpt.app.scene import BaseChat, ChatFactory

    chat_factory = ChatFactory()
    chat: BaseChat = chat_factory.get_implementation(chat_scene, **chat_param)
    return chat.stream_call()
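Both functions now import `dbgpt.app.scene` and build the `ChatFactory` at call time rather than at module import, so merely importing this module no longer pulls in the heavy app stack. The pattern in isolation, as a runnable sketch with a stand-in stdlib dependency:

```python
def fetch_one(query: str):
    # Deferred import: the module is loaded on the first call,
    # not when this file is imported.
    import sqlite3

    with sqlite3.connect(":memory:") as conn:
        return conn.execute(query).fetchone()

print(fetch_one("SELECT 1"))
```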
@@ -226,7 +226,7 @@ class EnvArgumentParser:
        **kwargs,
    ) -> Any:
        """Parse parameters from environment variables and command lines and populate them into data class"""
        parser = argparse.ArgumentParser()
        parser = argparse.ArgumentParser(allow_abbrev=False)
        for field in fields(dataclass_type):
            env_var_value: Any = _genenv_ignoring_key_case_with_prefixes(
                field.name, env_prefixes
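`allow_abbrev=False` stops argparse from treating unique flag prefixes as valid flags, which matters here because an unrelated command-line option could otherwise silently match a dataclass field. A self-contained illustration:

```python
import argparse

loose = argparse.ArgumentParser(allow_abbrev=True)
loose.add_argument("--model_name")

strict = argparse.ArgumentParser(allow_abbrev=False)
strict.add_argument("--model_name")

# With abbreviation enabled, the prefix --model silently matches --model_name.
print(loose.parse_known_args(["--model", "vicuna"]))
# -> (Namespace(model_name='vicuna'), [])

# With allow_abbrev=False, the unknown flag is left in the extras list instead.
print(strict.parse_known_args(["--model", "vicuna"]))
# -> (Namespace(model_name=None), ['--model', 'vicuna'])
```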
@@ -34,4 +34,10 @@ API_KEYS - The list of API keys that are allowed to access the API. Each of the
API_KEYS=dbgpt
```

## Installation

If you use Python, you should install the official DB-GPT Client package from PyPI:

```bash
pip install "dbgpt[client]>=0.5.2"
```
@@ -40,7 +40,7 @@ awel-tutorial
## Adding DB-GPT Dependency

```bash
poetry add "dbgpt>=0.5.1rc0"
poetry add "dbgpt>=0.5.1"
```

## First Hello World
docs/docs/awel/cookbook/first_rag_with_awel.md (new file, 358 lines)
@@ -0,0 +1,358 @@
# RAG With AWEL

In this example, we will show how to use the AWEL library to create a RAG program.

Now, let us create a Python file `first_rag_with_awel.py`.

In this example, we will load your knowledge from a URL and store it in a vector store.

### Install Dependencies

First, you need to install the `dbgpt` library.

```bash
pip install "dbgpt[rag]>=0.5.2"
```

### Prepare Embedding Model

To store the knowledge in a vector store, we need an embedding model. DB-GPT supports many embedding models; here are some of them:

import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';

<Tabs
  defaultValue="openai"
  values={[
    {label: 'Open AI(API)', value: 'openai'},
    {label: 'text2vec(local)', value: 'text2vec'},
    {label: 'Embedding API Server(cluster)', value: 'remote_embedding'},
  ]}>
  <TabItem value="openai">

```python
from dbgpt.rag.embedding import DefaultEmbeddingFactory

embeddings = DefaultEmbeddingFactory.openai()
```
  </TabItem>

  <TabItem value="text2vec">

```python
from dbgpt.rag.embedding import DefaultEmbeddingFactory

embeddings = DefaultEmbeddingFactory.default("/data/models/text2vec-large-chinese")
```
  </TabItem>

  <TabItem value="remote_embedding">

If you have deployed the [DB-GPT cluster](/docs/installation/model_service/cluster) and the [API server](/docs/installation/advanced_usage/OpenAI_SDK_call), you can connect to the API server to get the embeddings.

```python
from dbgpt.rag.embedding import DefaultEmbeddingFactory

embeddings = DefaultEmbeddingFactory.remote(
    api_url="http://localhost:8100/api/v1/embeddings",
    api_key="{your_api_key}",
    model_name="text2vec"
)
```
  </TabItem>
</Tabs>

### Load Knowledge And Store In Vector Store

Then we can create a DAG which loads the knowledge from a URL and stores it in a vector store.

```python
import asyncio
import shutil
from dbgpt.core.awel import DAG
from dbgpt.rag import ChunkParameters
from dbgpt.rag.knowledge import KnowledgeType
from dbgpt.rag.operators import EmbeddingAssemblerOperator, KnowledgeOperator
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
from dbgpt.storage.vector_store.connector import VectorStoreConnector

# Delete the old vector store directory (/tmp/awel_rag_test_vector_store)
shutil.rmtree("/tmp/awel_rag_test_vector_store", ignore_errors=True)

vector_connector = VectorStoreConnector.from_default(
    "Chroma",
    vector_store_config=ChromaVectorConfig(
        name="test_vstore",
        persist_path="/tmp/awel_rag_test_vector_store",
    ),
    embedding_fn=embeddings
)

with DAG("load_knowledge_dag") as knowledge_dag:
    # Load knowledge from the URL
    knowledge_task = KnowledgeOperator(knowledge_type=KnowledgeType.URL.name)
    assembler_task = EmbeddingAssemblerOperator(
        vector_store_connector=vector_connector,
        chunk_parameters=ChunkParameters(chunk_strategy="CHUNK_BY_SIZE")
    )
    knowledge_task >> assembler_task

chunks = asyncio.run(assembler_task.call("https://docs.dbgpt.site/docs/latest/awel/"))
print(f"Chunk length: {len(chunks)}")
```

### Retrieve Knowledge From Vector Store

Then you can retrieve the knowledge from the vector store.

```python
from dbgpt.core.awel import MapOperator
from dbgpt.rag.operators import EmbeddingRetrieverOperator

with DAG("retriever_dag") as retriever_dag:
    retriever_task = EmbeddingRetrieverOperator(
        top_k=3,
        vector_store_connector=vector_connector,
    )
    content_task = MapOperator(lambda cks: "\n".join(c.content for c in cks))
    retriever_task >> content_task

chunks = asyncio.run(content_task.call("What is the AWEL?"))
print(chunks)
```

### Prepare LLM

To build a RAG program, we need an LLM. Here are some of the LLMs that DB-GPT supports:

<Tabs
  defaultValue="openai"
  values={[
    {label: 'Open AI(API)', value: 'openai'},
    {label: 'YI(API)', value: 'yi_proxy'},
    {label: 'API Server(cluster)', value: 'model_service'},
  ]}>
  <TabItem value="openai">

First, you should install the `openai` library.

```bash
pip install openai
```
Then set your API key in the environment variable `OPENAI_API_KEY`.

```python
from dbgpt.model.proxy import OpenAILLMClient

llm_client = OpenAILLMClient()
```
  </TabItem>

  <TabItem value="yi_proxy">

You should have a YI account and get the API key from the YI official website.

First, you should install the `openai` library.

```bash
pip install openai
```

Then set your API key in the environment variable `YI_API_KEY`.

```python
from dbgpt.model.proxy import YiLLMClient

llm_client = YiLLMClient()
```
  </TabItem>

  <TabItem value="model_service">

If you have deployed the [DB-GPT cluster](/docs/installation/model_service/cluster) and the [API server](/docs/installation/advanced_usage/OpenAI_SDK_call), you can connect to the API server to get the LLM model.

The API is compatible with the OpenAI API, so you can use the OpenAILLMClient to connect to the API server.

First, you should install the `openai` library.

```bash
pip install openai
```

```python
from dbgpt.model.proxy import OpenAILLMClient

llm_client = OpenAILLMClient(api_base="http://localhost:8100/api/v1/", api_key="{your_api_key}")
```
  </TabItem>
</Tabs>

### Create RAG Program

Lastly, we can create a RAG program with the retrieved knowledge.

```python
from dbgpt.core.awel import InputOperator, JoinOperator, InputSource
from dbgpt.core.operators import PromptBuilderOperator, RequestBuilderOperator
from dbgpt.model.operators import LLMOperator

prompt = """Based on the known information below, provide users with professional and concise answers to their questions.
If the answer cannot be obtained from the provided content, please say:
"The information provided in the knowledge base is not sufficient to answer this question.".
It is forbidden to make up information randomly. When answering, it is best to summarize according to points 1.2.3.
known information:
{context}
question:
{question}
"""

with DAG("llm_rag_dag") as rag_dag:
    input_task = InputOperator(input_source=InputSource.from_callable())
    retriever_task = EmbeddingRetrieverOperator(
        top_k=3,
        vector_store_connector=vector_connector,
    )
    content_task = MapOperator(lambda cks: "\n".join(c.content for c in cks))

    merge_task = JoinOperator(lambda context, question: {"context": context, "question": question})

    prompt_task = PromptBuilderOperator(prompt)
    # The model is gpt-3.5-turbo; you can replace it with other models.
    req_build_task = RequestBuilderOperator(model="gpt-3.5-turbo")
    llm_task = LLMOperator(llm_client=llm_client)
    result_task = MapOperator(lambda r: r.text)

    input_task >> retriever_task >> content_task >> merge_task
    input_task >> merge_task

    merge_task >> prompt_task >> req_build_task >> llm_task >> result_task

print(asyncio.run(result_task.call("What is the AWEL?")))
```
The output will be:

```bash
AWEL stands for Agentic Workflow Expression Language, which is a set of intelligent agent workflow expression language designed for large model application development. It simplifies the process by providing functionality and flexibility through its layered API design architecture, including the operator layer, AgentFrame layer, and DSL layer. Its goal is to allow developers to focus on business logic for LLMs applications without having to deal with intricate model and environment details.
```

Congratulations! You have created a RAG program with AWEL.

### Full Code

Let's look at the full code of `first_rag_with_awel.py`:

```python
import asyncio
import shutil
from dbgpt.core.awel import DAG, MapOperator, InputOperator, JoinOperator, InputSource
from dbgpt.core.operators import PromptBuilderOperator, RequestBuilderOperator
from dbgpt.rag import ChunkParameters
from dbgpt.rag.knowledge import KnowledgeType
from dbgpt.rag.operators import EmbeddingAssemblerOperator, KnowledgeOperator, EmbeddingRetrieverOperator
from dbgpt.rag.embedding import DefaultEmbeddingFactory
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
from dbgpt.storage.vector_store.connector import VectorStoreConnector
from dbgpt.model.operators import LLMOperator
from dbgpt.model.proxy import OpenAILLMClient

# Here we use the OpenAI embedding model; if you want to use other models, you can
# replace it according to the previous example.
embeddings = DefaultEmbeddingFactory.openai()
# Here we use the OpenAI LLM; if you want to use other models, you can replace
# it according to the previous example.
llm_client = OpenAILLMClient()

# Delete the old vector store directory (/tmp/awel_rag_test_vector_store)
shutil.rmtree("/tmp/awel_rag_test_vector_store", ignore_errors=True)

vector_connector = VectorStoreConnector.from_default(
    "Chroma",
    vector_store_config=ChromaVectorConfig(
        name="test_vstore",
        persist_path="/tmp/awel_rag_test_vector_store",
    ),
    embedding_fn=embeddings
)

with DAG("load_knowledge_dag") as knowledge_dag:
    # Load knowledge from the URL
    knowledge_task = KnowledgeOperator(knowledge_type=KnowledgeType.URL.name)
    assembler_task = EmbeddingAssemblerOperator(
        vector_store_connector=vector_connector,
        chunk_parameters=ChunkParameters(chunk_strategy="CHUNK_BY_SIZE")
    )
    knowledge_task >> assembler_task

chunks = asyncio.run(assembler_task.call("https://docs.dbgpt.site/docs/latest/awel/"))
print(f"Chunk length: {len(chunks)}\n")


prompt = """Based on the known information below, provide users with professional and concise answers to their questions.
If the answer cannot be obtained from the provided content, please say:
"The information provided in the knowledge base is not sufficient to answer this question.".
It is forbidden to make up information randomly. When answering, it is best to summarize according to points 1.2.3.
known information:
{context}
question:
{question}
"""


with DAG("llm_rag_dag") as rag_dag:
    input_task = InputOperator(input_source=InputSource.from_callable())
    retriever_task = EmbeddingRetrieverOperator(
        top_k=3,
        vector_store_connector=vector_connector,
    )
    content_task = MapOperator(lambda cks: "\n".join(c.content for c in cks))

    merge_task = JoinOperator(lambda context, question: {"context": context, "question": question})

    prompt_task = PromptBuilderOperator(prompt)
    # The model is gpt-3.5-turbo; you can replace it with other models.
    req_build_task = RequestBuilderOperator(model="gpt-3.5-turbo")
    llm_task = LLMOperator(llm_client=llm_client)
    result_task = MapOperator(lambda r: r.text)

    input_task >> retriever_task >> content_task >> merge_task
    input_task >> merge_task

    merge_task >> prompt_task >> req_build_task >> llm_task >> result_task

print(asyncio.run(result_task.call("What is the AWEL?")))
```

### Visualize DAGs

And we can visualize the DAGs with the following code:

```python
knowledge_dag.visualize_dag()
rag_dag.visualize_dag()
```
If you execute the code in a Jupyter notebook, you can see the DAGs in the notebook.

```python
display(knowledge_dag.show())
display(rag_dag.show())
```

The graph of the `knowledge_dag` is:

<p align="left">
  <img src={'/img/awel/cookbook/first_rag_knowledge_dag.png'} width="1000px"/>
</p>

And the graph of the `rag_dag` is:

<p align="left">
  <img src={'/img/awel/cookbook/first_rag_rag_dag.png'} width="1000px"/>
</p>
@@ -68,6 +68,10 @@ const sidebars = {
        type: "doc",
        id: "awel/cookbook/multi_round_chat_withllm"
      },
      {
        type: "doc",
        id: "awel/cookbook/first_rag_with_awel"
      }
    ],
    link: {
      type: 'generated-index',
docs/static/img/awel/cookbook/first_rag_knowledge_dag.png (new binary file, 86 KiB; binary file not shown)
docs/static/img/awel/cookbook/first_rag_rag_dag.png (new binary file, 185 KiB; binary file not shown)
@@ -10,7 +10,7 @@ from dbgpt.core.awel import DAG, HttpTrigger, JoinOperator, MapOperator
from dbgpt.datasource.rdbms.base import RDBMSConnector
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector
from dbgpt.model.proxy import OpenAILLMClient
from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory
from dbgpt.rag.embedding import DefaultEmbeddingFactory
from dbgpt.rag.operators.schema_linking import SchemaLinkingOperator
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
from dbgpt.storage.vector_store.connector import VectorStoreConnector
@@ -181,13 +181,13 @@ class SqlGenOperator(MapOperator[Any, Any]):
class SqlExecOperator(MapOperator[Any, Any]):
    """The Sql Execution Operator."""

    def __init__(self, connection: Optional[RDBMSConnector] = None, **kwargs):
    def __init__(self, connector: Optional[RDBMSConnector] = None, **kwargs):
        """
        Args:
            connection (Optional[RDBMSConnector]): RDBMSConnector connection
        """
        super().__init__(**kwargs)
        self._connection = connection
        self._connector = connector

    def map(self, sql: str) -> DataFrame:
        """retrieve table schemas.
@@ -196,7 +196,7 @@ class SqlExecOperator(MapOperator[Any, Any]):
        Return:
            str: sql execution
        """
        dataframe = self._connection.run_to_df(command=sql, fetch="all")
        dataframe = self._connector.run_to_df(command=sql, fetch="all")
        print(f"sql data is \n{dataframe}")
        return dataframe

@@ -237,12 +237,12 @@ with DAG("simple_nl_schema_sql_chart_example") as dag:
    llm = OpenAILLMClient()
    model_name = "gpt-3.5-turbo"
    retriever_task = SchemaLinkingOperator(
        connection=_create_temporary_connection(), llm=llm, model_name=model_name
        connector=_create_temporary_connection(), llm=llm, model_name=model_name
    )
    prompt_join_operator = JoinOperator(combine_function=_prompt_join_fn)
    sql_gen_operator = SqlGenOperator(llm=llm, model_name=model_name)
    sql_exec_operator = SqlExecOperator(connection=_create_temporary_connection())
    draw_chart_operator = ChartDrawOperator(connection=_create_temporary_connection())
    sql_exec_operator = SqlExecOperator(connector=_create_temporary_connection())
    draw_chart_operator = ChartDrawOperator(connector=_create_temporary_connection())
    trigger >> request_handle_task >> query_operator >> prompt_join_operator
    (
        trigger
@@ -33,7 +33,7 @@ from typing import Dict
from dbgpt._private.pydantic import BaseModel, Field
from dbgpt.core.awel import DAG, HttpTrigger, MapOperator
from dbgpt.model.proxy import OpenAILLMClient
from dbgpt.rag.operators.rewrite import QueryRewriteOperator
from dbgpt.rag.operators import QueryRewriteOperator


class TriggerReqBody(BaseModel):
@@ -31,9 +31,8 @@ from typing import Dict
from dbgpt._private.pydantic import BaseModel, Field
from dbgpt.core.awel import DAG, HttpTrigger, MapOperator
from dbgpt.model.proxy import OpenAILLMClient
from dbgpt.rag.knowledge.base import KnowledgeType
from dbgpt.rag.operators.knowledge import KnowledgeOperator
from dbgpt.rag.operators.summary import SummaryAssemblerOperator
from dbgpt.rag.knowledge import KnowledgeType
from dbgpt.rag.operators import KnowledgeOperator, SummaryAssemblerOperator


class TriggerReqBody(BaseModel):
@@ -2,8 +2,8 @@ import os

from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector
from dbgpt.rag.assembler import DBSchemaAssembler
from dbgpt.rag.embedding import DefaultEmbeddingFactory
from dbgpt.serve.rag.assembler.db_schema import DBSchemaAssembler
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
from dbgpt.storage.vector_store.connector import VectorStoreConnector

@@ -62,7 +62,7 @@ if __name__ == "__main__":
    connection = _create_temporary_connection()
    vector_connector = _create_vector_connector()
    assembler = DBSchemaAssembler.load_from_connection(
        connection=connection,
        connector=connection,
        vector_store_connector=vector_connector,
    )
    assembler.persist()
@@ -2,10 +2,10 @@ import asyncio
import os

from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH, ROOT_PATH
from dbgpt.rag.chunk_manager import ChunkParameters
from dbgpt.rag import ChunkParameters
from dbgpt.rag.assembler import EmbeddingAssembler
from dbgpt.rag.embedding import DefaultEmbeddingFactory
from dbgpt.rag.knowledge import KnowledgeFactory
from dbgpt.serve.rag.assembler.embedding import EmbeddingAssembler
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
from dbgpt.storage.vector_store.connector import VectorStoreConnector
@@ -27,10 +27,10 @@ import os
from typing import Optional

from dbgpt.configs.model_config import PILOT_PATH, ROOT_PATH
from dbgpt.rag.chunk_manager import ChunkParameters
from dbgpt.rag import ChunkParameters
from dbgpt.rag.assembler import EmbeddingAssembler
from dbgpt.rag.embedding import OpenAPIEmbeddings
from dbgpt.rag.knowledge import KnowledgeFactory
from dbgpt.serve.rag.assembler.embedding import EmbeddingAssembler
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
from dbgpt.storage.vector_store.connector import VectorStoreConnector
@@ -4,12 +4,12 @@ from typing import Optional

from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH, ROOT_PATH
from dbgpt.core import Embeddings
from dbgpt.rag.chunk_manager import ChunkParameters
from dbgpt.rag import ChunkParameters
from dbgpt.rag.assembler import EmbeddingAssembler
from dbgpt.rag.embedding import DefaultEmbeddingFactory
from dbgpt.rag.evaluation import RetrieverEvaluator
from dbgpt.rag.knowledge import KnowledgeFactory
from dbgpt.rag.operators import EmbeddingRetrieverOperator
from dbgpt.serve.rag.assembler.embedding import EmbeddingAssembler
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
from dbgpt.storage.vector_store.connector import VectorStoreConnector
@@ -3,13 +3,13 @@
If you do not set vector_store_connector, it will return all table schemas in the database.
```
retriever_task = DBSchemaRetrieverOperator(
    connection=_create_temporary_connection()
    connector=_create_temporary_connection()
)
```
If you set vector_store_connector, it will recall the top-k most similar table schemas from the database.
```
retriever_task = DBSchemaRetrieverOperator(
    connection=_create_temporary_connection()
    connector=_create_temporary_connection(),
    top_k=1,
    vector_store_connector=vector_store_connector
)
```
@@ -30,11 +30,10 @@ from pydantic import BaseModel, Field
from dbgpt._private.config import Config
from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG, PILOT_PATH
from dbgpt.core import Chunk
from dbgpt.core.awel import DAG, HttpTrigger, JoinOperator, MapOperator
from dbgpt.core.awel import DAG, HttpTrigger, InputOperator, JoinOperator, MapOperator
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector
from dbgpt.rag.embedding import DefaultEmbeddingFactory
from dbgpt.rag.operators import DBSchemaRetrieverOperator
from dbgpt.serve.rag.operators.db_schema import DBSchemaAssemblerOperator
from dbgpt.rag.operators import DBSchemaAssemblerOperator, DBSchemaRetrieverOperator
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
from dbgpt.storage.vector_store.connector import VectorStoreConnector

@@ -107,18 +106,19 @@ with DAG("simple_rag_db_schema_example") as dag:
    request_handle_task = RequestHandleOperator()
    query_operator = MapOperator(lambda request: request["query"])
    vector_store_connector = _create_vector_connector()
    connector = _create_temporary_connection()
    assembler_task = DBSchemaAssemblerOperator(
        connection=_create_temporary_connection(),
        connector=connector,
        vector_store_connector=vector_store_connector,
    )
    join_operator = JoinOperator(combine_function=_join_fn)
    retriever_task = DBSchemaRetrieverOperator(
        connection=_create_temporary_connection(),
        connector=_create_temporary_connection(),
        top_k=1,
        vector_store_connector=vector_store_connector,
    )
    result_parse_task = MapOperator(lambda chunks: [chunk.content for chunk in chunks])
    trigger >> request_handle_task >> assembler_task >> join_operator
    trigger >> assembler_task >> join_operator
    trigger >> request_handle_task >> query_operator >> join_operator
    join_operator >> retriever_task >> result_parse_task
@@ -17,12 +17,11 @@ from typing import Dict, List
from pydantic import BaseModel, Field

from dbgpt._private.config import Config
from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG, MODEL_PATH, PILOT_PATH
from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG, PILOT_PATH
from dbgpt.core.awel import DAG, HttpTrigger, MapOperator
from dbgpt.rag.embedding import DefaultEmbeddingFactory
from dbgpt.rag.knowledge import KnowledgeType
from dbgpt.rag.operators import KnowledgeOperator
from dbgpt.serve.rag.operators.embedding import EmbeddingAssemblerOperator
from dbgpt.rag.operators import EmbeddingAssemblerOperator, KnowledgeOperator
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
from dbgpt.storage.vector_store.connector import VectorStoreConnector
@@ -22,15 +22,17 @@


import asyncio
import os

from dbgpt.configs.model_config import ROOT_PATH
from dbgpt.model.proxy import OpenAILLMClient
from dbgpt.rag.chunk_manager import ChunkParameters
from dbgpt.rag import ChunkParameters
from dbgpt.rag.assembler import SummaryAssembler
from dbgpt.rag.knowledge import KnowledgeFactory
from dbgpt.serve.rag.assembler.summary import SummaryAssembler


async def main():
    file_path = "./docs/docs/awel.md"
    file_path = os.path.join(ROOT_PATH, "docs/docs/awel/awel.md")
    llm_client = OpenAILLMClient()
    knowledge = KnowledgeFactory.from_file_path(file_path)
    chunk_parameters = ChunkParameters(chunk_strategy="CHUNK_BY_SIZE")
@@ -117,7 +117,7 @@ class SQLResultOperator(JoinOperator[Dict]):
with DAG("simple_sdk_llm_sql_example") as dag:
    db_connection = _create_temporary_connection()
    input_task = InputOperator(input_source=SimpleCallDataInputSource())
    retriever_task = DatasourceRetrieverOperator(connection=db_connection)
    retriever_task = DatasourceRetrieverOperator(connector=db_connection)
    # Merge the input data and the table structure information.
    prompt_input_task = JoinOperator(combine_function=_join_func)
    prompt_task = PromptBuilderOperator(_sql_prompt())
@@ -125,7 +125,7 @@ with DAG("simple_sdk_llm_sql_example") as dag:
    llm_task = BaseLLMOperator(OpenAILLMClient())
    out_parse_task = SQLOutputParser()
    sql_parse_task = MapOperator(map_function=lambda x: x["sql"])
    db_query_task = DatasourceOperator(connection=db_connection)
    db_query_task = DatasourceOperator(connector=db_connection)
    sql_result_task = SQLResultOperator()
    input_task >> prompt_input_task
    input_task >> retriever_task >> prompt_input_task
setup.py (45 changed lines)
@@ -18,7 +18,7 @@ with open("README.md", mode="r", encoding="utf-8") as fh:
IS_DEV_MODE = os.getenv("IS_DEV_MODE", "true").lower() == "true"
# If you modify the version, please modify the version in the following files:
# dbgpt/_version.py
DB_GPT_VERSION = os.getenv("DB_GPT_VERSION", "0.5.1")
DB_GPT_VERSION = os.getenv("DB_GPT_VERSION", "0.5.2")

BUILD_NO_CACHE = os.getenv("BUILD_NO_CACHE", "true").lower() == "true"
LLAMA_CPP_GPU_ACCELERATION = (
@@ -370,8 +370,13 @@ def core_requires():
        # For AWEL type checking
        "typeguard",
    ]
    # For DB-GPT python client SDK
    setup_spec.extras["client"] = setup_spec.extras["core"] + [
        "httpx",
        "fastapi==0.98.0",
    ]
    # Simple command line dependencies
    setup_spec.extras["cli"] = setup_spec.extras["core"] + [
    setup_spec.extras["cli"] = setup_spec.extras["client"] + [
        "prettytable",
        "click",
        "psutil==5.9.4",
@@ -382,10 +387,7 @@ def core_requires():
    # we core unit test.
    # The dependency "framework" is too large for now.
    setup_spec.extras["simple_framework"] = setup_spec.extras["cli"] + [
        "pydantic<2,>=1",
        "httpx",
        "jinja2",
        "fastapi==0.98.0",
        "uvicorn",
        "shortuuid",
        # change from fixed version 2.0.22 to variable version, because other
@@ -397,11 +399,12 @@ def core_requires():
        # TODO: pympler has not been updated for a long time and needs to
        # find a new toolkit.
        "pympler",
        "sqlparse==0.4.4",
        "duckdb==0.8.1",
        "duckdb-engine",
        # lightweight python library for scheduling jobs
        "schedule",
        # For datasource subpackage
        "sqlparse==0.4.4",
    ]
    # TODO: remove fschat from simple_framework
    if BUILD_FROM_SOURCE:
@@ -418,7 +421,6 @@ def core_requires():
        "pandas==2.0.3",
        "auto-gpt-plugin-template",
        "gTTS==2.3.1",
        "langchain>=0.0.286",
        "pymysql",
        "jsonschema",
        # TODO move transformers to default
@@ -439,9 +441,10 @@ def core_requires():

def knowledge_requires():
    """
    pip install "dbgpt[knowledge]"
    pip install "dbgpt[rag]"
    """
    setup_spec.extras["knowledge"] = [
    setup_spec.extras["rag"] = setup_spec.extras["vstore"] + [
        "langchain>=0.0.286",
        "spacy==3.5.3",
        "chromadb==0.4.10",
        "markdown",
@@ -547,8 +550,7 @@ def all_vector_store_requires():
    pip install "dbgpt[vstore]"
    """
    setup_spec.extras["vstore"] = [
        "grpcio==1.47.5",  # maybe delete it
        "pymilvus==2.2.1",
        "pymilvus",
        "weaviate-client",
    ]

@@ -559,6 +561,7 @@ def all_datasource_requires():
    """

    setup_spec.extras["datasource"] = [
        # "sqlparse==0.4.4",
        "pymssql",
        "pymysql",
        "pyspark",
@@ -586,7 +589,7 @@ def openai_requires():
        setup_spec.extras["openai"].append("openai")

    setup_spec.extras["openai"] += setup_spec.extras["framework"]
    setup_spec.extras["openai"] += setup_spec.extras["knowledge"]
    setup_spec.extras["openai"] += setup_spec.extras["rag"]


def gpt4all_requires():
@@ -624,7 +627,8 @@ def default_requires():
        "chardet",
    ]
    setup_spec.extras["default"] += setup_spec.extras["framework"]
    setup_spec.extras["default"] += setup_spec.extras["knowledge"]
    setup_spec.extras["default"] += setup_spec.extras["rag"]
    setup_spec.extras["default"] += setup_spec.extras["datasource"]
    setup_spec.extras["default"] += setup_spec.extras["torch"]
    setup_spec.extras["default"] += setup_spec.extras["quantization"]
    setup_spec.extras["default"] += setup_spec.extras["cache"]
@@ -645,12 +649,12 @@ def init_install_requires():

    core_requires()
    torch_requires()
    knowledge_requires()
    llama_cpp_requires()
    quantization_requires()

    all_vector_store_requires()
    all_datasource_requires()
    knowledge_requires()
    openai_requires()
    gpt4all_requires()
    vllm_requires()
@@ -675,12 +679,14 @@ else:
        "dbgpt._private.*",
        "dbgpt.cli",
        "dbgpt.cli.*",
        "dbgpt.client",
        "dbgpt.client.*",
        "dbgpt.configs",
        "dbgpt.configs.*",
        "dbgpt.core",
        "dbgpt.core.*",
        "dbgpt.util",
        "dbgpt.util.*",
        "dbgpt.datasource",
        "dbgpt.datasource.*",
        "dbgpt.model",
        "dbgpt.model.proxy",
        "dbgpt.model.proxy.*",
@@ -688,6 +694,13 @@ else:
        "dbgpt.model.operators.*",
        "dbgpt.model.utils",
        "dbgpt.model.utils.*",
        "dbgpt.model.adapter",
        "dbgpt.rag",
        "dbgpt.rag.*",
        "dbgpt.storage",
        "dbgpt.storage.*",
        "dbgpt.util",
        "dbgpt.util.*",
    ],
)