Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-09-12 20:53:48 +00:00)
feat(core): Support opentelemetry exporter (#1690)
@@ -9,6 +9,7 @@ from dbgpt._private.pydantic import EXTRA_FORBID, BaseModel, ConfigDict, Field
 from dbgpt.core import Embeddings
 from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource
 from dbgpt.util.i18n_utils import _
+from dbgpt.util.tracer import DBGPT_TRACER_SPAN_ID, root_tracer

 DEFAULT_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
 DEFAULT_INSTRUCT_MODEL = "hkunlp/instructor-large"
@@ -655,6 +656,9 @@ class OpenAPIEmbeddings(BaseModel, Embeddings):
     timeout: int = Field(
         default=60, description="The timeout for the request in seconds."
     )
+    pass_trace_id: bool = Field(
+        default=True, description="Whether to pass the trace ID to the API."
+    )

     session: Optional[requests.Session] = None

@@ -688,10 +692,15 @@ class OpenAPIEmbeddings(BaseModel, Embeddings):
                 corresponds to a single input text.
         """
         # Call OpenAI Embedding API
+        headers = {}
+        if self.pass_trace_id:
+            # Set the trace ID if available
+            headers[DBGPT_TRACER_SPAN_ID] = root_tracer.get_current_span_id()
         res = self.session.post(  # type: ignore
             self.api_url,
             json={"input": texts, "model": self.model_name},
             timeout=self.timeout,
+            headers=headers,
         )
         return _handle_request_result(res)

@@ -717,6 +726,9 @@ class OpenAPIEmbeddings(BaseModel, Embeddings):
                 List[float] corresponds to a single input text.
         """
         headers = {"Authorization": f"Bearer {self.api_key}"}
+        if self.pass_trace_id:
+            # Set the trace ID if available
+            headers[DBGPT_TRACER_SPAN_ID] = root_tracer.get_current_span_id()
         async with aiohttp.ClientSession(
             headers=headers, timeout=aiohttp.ClientTimeout(total=self.timeout)
         ) as session:
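The hunks above are the core of the client-side tracing support: when pass_trace_id is enabled, the embedding client copies the current span id into an HTTP header before calling the remote API. The following is a minimal sketch of that pattern outside the class, assuming a DB-GPT install; the helper name, endpoint URL, and model name are placeholders, not values from this commit.

import requests

from dbgpt.util.tracer import DBGPT_TRACER_SPAN_ID, root_tracer


def embed_with_trace(texts, api_url, model_name, pass_trace_id=True, timeout=60):
    """Illustrative helper mirroring the OpenAPIEmbeddings change above."""
    headers = {}
    if pass_trace_id:
        # Attach the caller's span id so the remote service can link its spans.
        headers[DBGPT_TRACER_SPAN_ID] = root_tracer.get_current_span_id()
    res = requests.post(
        api_url,
        json={"input": texts, "model": model_name},
        timeout=timeout,
        headers=headers,
    )
    res.raise_for_status()
    return res.json()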
@@ -8,6 +8,7 @@ import requests

 from dbgpt._private.pydantic import EXTRA_FORBID, BaseModel, ConfigDict, Field
 from dbgpt.core import RerankEmbeddings
+from dbgpt.util.tracer import DBGPT_TRACER_SPAN_ID, root_tracer


 class CrossEncoderRerankEmbeddings(BaseModel, RerankEmbeddings):
@@ -78,6 +79,9 @@ class OpenAPIRerankEmbeddings(BaseModel, RerankEmbeddings):
     timeout: int = Field(
         default=60, description="The timeout for the request in seconds."
     )
+    pass_trace_id: bool = Field(
+        default=True, description="Whether to pass the trace ID to the API."
+    )

     session: Optional[requests.Session] = None

@@ -112,9 +116,13 @@ class OpenAPIRerankEmbeddings(BaseModel, RerankEmbeddings):
         """
         if not candidates:
             return []
+        headers = {}
+        if self.pass_trace_id:
+            # Set the trace ID if available
+            headers[DBGPT_TRACER_SPAN_ID] = root_tracer.get_current_span_id()
         data = {"model": self.model_name, "query": query, "documents": candidates}
         response = self.session.post(  # type: ignore
-            self.api_url, json=data, timeout=self.timeout
+            self.api_url, json=data, timeout=self.timeout, headers=headers
         )
         response.raise_for_status()
         return response.json()["data"]
@@ -122,6 +130,9 @@ class OpenAPIRerankEmbeddings(BaseModel, RerankEmbeddings):
     async def apredict(self, query: str, candidates: List[str]) -> List[float]:
         """Predict the rank scores of the candidates asynchronously."""
         headers = {"Authorization": f"Bearer {self.api_key}"}
+        if self.pass_trace_id:
+            # Set the trace ID if available
+            headers[DBGPT_TRACER_SPAN_ID] = root_tracer.get_current_span_id()
         async with aiohttp.ClientSession(
             headers=headers, timeout=aiohttp.ClientTimeout(total=self.timeout)
         ) as session:
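The rerank client applies the same idea on both its synchronous and asynchronous paths. Below is a rough sketch of the async path, assuming the response body carries the same "data" field as the synchronous call above; the URL, API key, model name, and helper name are placeholders.

import aiohttp

from dbgpt.util.tracer import DBGPT_TRACER_SPAN_ID, root_tracer


async def rerank_with_trace(
    query, candidates, api_url, api_key, model_name, timeout=60, pass_trace_id=True
):
    """Illustrative async rerank call with trace-id propagation."""
    headers = {"Authorization": f"Bearer {api_key}"}
    if pass_trace_id:
        # Same header as the synchronous path above.
        headers[DBGPT_TRACER_SPAN_ID] = root_tracer.get_current_span_id()
    async with aiohttp.ClientSession(
        headers=headers, timeout=aiohttp.ClientTimeout(total=timeout)
    ) as session:
        async with session.post(
            api_url, json={"model": model_name, "query": query, "documents": candidates}
        ) as resp:
            resp.raise_for_status()
            payload = await resp.json()
            return payload["data"]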
@@ -1,4 +1,5 @@
 """DBSchema retriever."""
+
 from functools import reduce
 from typing import List, Optional, cast

@@ -10,6 +11,8 @@ from dbgpt.rag.retriever.rerank import DefaultRanker, Ranker
 from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
 from dbgpt.storage.vector_store.filters import MetadataFilters
 from dbgpt.util.chat_util import run_async_tasks
+from dbgpt.util.executor_utils import blocking_func_to_async_no_executor
+from dbgpt.util.tracer import root_tracer


 class DBSchemaRetriever(BaseRetriever):
@@ -155,7 +158,12 @@ class DBSchemaRetriever(BaseRetriever):
         """
         if self._need_embeddings:
             queries = [query]
-            candidates = [self._similarity_search(query, filters) for query in queries]
+            candidates = [
+                self._similarity_search(
+                    query, filters, root_tracer.get_current_span_id()
+                )
+                for query in queries
+            ]
             result_candidates = await run_async_tasks(
                 tasks=candidates, concurrency_limit=1
             )
@@ -166,7 +174,8 @@ class DBSchemaRetriever(BaseRetriever):
             )

             table_summaries = await run_async_tasks(
-                tasks=[self._aparse_db_summary()], concurrency_limit=1
+                tasks=[self._aparse_db_summary(root_tracer.get_current_span_id())],
+                concurrency_limit=1,
             )
             return [Chunk(content=table_summary) for table_summary in table_summaries]

@@ -186,15 +195,33 @@ class DBSchemaRetriever(BaseRetriever):
         return await self._aretrieve(query, filters)

     async def _similarity_search(
-        self, query, filters: Optional[MetadataFilters] = None
+        self,
+        query,
+        filters: Optional[MetadataFilters] = None,
+        parent_span_id: Optional[str] = None,
     ) -> List[Chunk]:
         """Similar search."""
-        return self._index_store.similar_search(query, self._top_k, filters)
+        with root_tracer.start_span(
+            "dbgpt.rag.retriever.db_schema._similarity_search",
+            parent_span_id,
+            metadata={"query": query},
+        ):
+            return await blocking_func_to_async_no_executor(
+                self._index_store.similar_search, query, self._top_k, filters
+            )

-    async def _aparse_db_summary(self) -> List[str]:
+    async def _aparse_db_summary(
+        self, parent_span_id: Optional[str] = None
+    ) -> List[str]:
         """Similar search."""
         from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary

         if not self._connector:
             raise RuntimeError("RDBMSConnector connection is required.")
-        return _parse_db_summary(self._connector)
+        with root_tracer.start_span(
+            "dbgpt.rag.retriever.db_schema._aparse_db_summary",
+            parent_span_id,
+        ):
+            return await blocking_func_to_async_no_executor(
+                _parse_db_summary, self._connector
+            )
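In DBSchemaRetriever the commit wraps blocking index and summary calls in tracer spans and moves them off the event loop with blocking_func_to_async_no_executor. The following is a minimal sketch of that wrapping; load_table_summary and the span name are illustrative stand-ins, not part of the commit.

from typing import List, Optional

from dbgpt.util.executor_utils import blocking_func_to_async_no_executor
from dbgpt.util.tracer import root_tracer


def load_table_summary(table: str) -> List[str]:
    # Stand-in for a blocking call such as reading schema metadata.
    return [f"summary of {table}"]


async def traced_load_table_summary(
    table: str, parent_span_id: Optional[str] = None
) -> List[str]:
    # Open a span under the caller's span, then run the blocking call off-loop.
    with root_tracer.start_span(
        "example.rag.load_table_summary",
        parent_span_id,
        metadata={"table": table},
    ):
        return await blocking_func_to_async_no_executor(load_table_summary, table)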
@@ -10,6 +10,7 @@ from dbgpt.rag.retriever.rerank import DefaultRanker, Ranker
 from dbgpt.rag.retriever.rewrite import QueryRewrite
 from dbgpt.storage.vector_store.filters import MetadataFilters
 from dbgpt.util.chat_util import run_async_tasks
+from dbgpt.util.executor_utils import blocking_func_to_async_no_executor
 from dbgpt.util.tracer import root_tracer


@@ -140,7 +141,10 @@ class EmbeddingRetriever(BaseRetriever):
         queries = [query]
         if self._query_rewrite:
             candidates_tasks = [
-                self._similarity_search(query, filters) for query in queries
+                self._similarity_search(
+                    query, filters, root_tracer.get_current_span_id()
+                )
+                for query in queries
             ]
             chunks = await self._run_async_tasks(candidates_tasks)
             context = "\n".join([chunk.content for chunk in chunks])
@@ -148,7 +152,10 @@ class EmbeddingRetriever(BaseRetriever):
                 origin_query=query, context=context, nums=1
             )
             queries.extend(new_queries)
-        candidates = [self._similarity_search(query, filters) for query in queries]
+        candidates = [
+            self._similarity_search(query, filters, root_tracer.get_current_span_id())
+            for query in queries
+        ]
         new_candidates = await run_async_tasks(tasks=candidates, concurrency_limit=1)
         return new_candidates

@@ -170,16 +177,19 @@ class EmbeddingRetriever(BaseRetriever):
         queries = [query]
         if self._query_rewrite:
             with root_tracer.start_span(
-                "EmbeddingRetriever.query_rewrite.similarity_search",
+                "dbgpt.rag.retriever.embeddings.query_rewrite.similarity_search",
                 metadata={"query": query, "score_threshold": score_threshold},
             ):
                 candidates_tasks = [
-                    self._similarity_search(query, filters) for query in queries
+                    self._similarity_search(
+                        query, filters, root_tracer.get_current_span_id()
+                    )
+                    for query in queries
                 ]
                 chunks = await self._run_async_tasks(candidates_tasks)
                 context = "\n".join([chunk.content for chunk in chunks])
             with root_tracer.start_span(
-                "EmbeddingRetriever.query_rewrite.rewrite",
+                "dbgpt.rag.retriever.embeddings.query_rewrite.rewrite",
                 metadata={"query": query, "context": context, "nums": 1},
             ):
                 new_queries = await self._query_rewrite.rewrite(
@@ -188,11 +198,13 @@ class EmbeddingRetriever(BaseRetriever):
                 queries.extend(new_queries)

         with root_tracer.start_span(
-            "EmbeddingRetriever.similarity_search_with_score",
+            "dbgpt.rag.retriever.embeddings.similarity_search_with_score",
             metadata={"query": query, "score_threshold": score_threshold},
         ):
             candidates_with_score = [
-                self._similarity_search_with_score(query, score_threshold, filters)
+                self._similarity_search_with_score(
+                    query, score_threshold, filters, root_tracer.get_current_span_id()
+                )
                 for query in queries
             ]
             res_candidates_with_score = await run_async_tasks(
@@ -203,7 +215,7 @@ class EmbeddingRetriever(BaseRetriever):
             )

         with root_tracer.start_span(
-            "EmbeddingRetriever.rerank",
+            "dbgpt.rag.retriever.embeddings.rerank",
             metadata={
                 "query": query,
                 "score_threshold": score_threshold,
@@ -216,10 +228,22 @@ class EmbeddingRetriever(BaseRetriever):
         return new_candidates_with_score

     async def _similarity_search(
-        self, query, filters: Optional[MetadataFilters] = None
+        self,
+        query,
+        filters: Optional[MetadataFilters] = None,
+        parent_span_id: Optional[str] = None,
     ) -> List[Chunk]:
         """Similar search."""
-        return self._index_store.similar_search(query, self._top_k, filters)
+        with root_tracer.start_span(
+            "dbgpt.rag.retriever.embeddings.similarity_search",
+            parent_span_id,
+            metadata={
+                "query": query,
+            },
+        ):
+            return await blocking_func_to_async_no_executor(
+                self._index_store.similar_search, query, self._top_k, filters
+            )

     async def _run_async_tasks(self, tasks) -> List[Chunk]:
         """Run async tasks."""
@@ -228,9 +252,25 @@ class EmbeddingRetriever(BaseRetriever):
         return cast(List[Chunk], candidates)

     async def _similarity_search_with_score(
-        self, query, score_threshold, filters: Optional[MetadataFilters] = None
+        self,
+        query,
+        score_threshold,
+        filters: Optional[MetadataFilters] = None,
+        parent_span_id: Optional[str] = None,
     ) -> List[Chunk]:
         """Similar search with score."""
-        return await self._index_store.asimilar_search_with_scores(
-            query, self._top_k, score_threshold, filters
-        )
+        with root_tracer.start_span(
+            "dbgpt.rag.retriever.embeddings._do_similarity_search_with_score",
+            parent_span_id,
+            metadata={
+                "query": query,
+                "score_threshold": score_threshold,
+            },
+        ):
+            return await blocking_func_to_async_no_executor(
+                self._index_store.similar_search_with_scores,
+                query,
+                self._top_k,
+                score_threshold,
+                filters,
+            )
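The EmbeddingRetriever changes rename span operations to a dotted dbgpt.rag.retriever.* scheme and pass the parent span id into the search coroutines explicitly, because those coroutines are only scheduled later by run_async_tasks. Below is a sketch of that hand-off; the span name and the dummy search are illustrative only.

from typing import List, Optional

from dbgpt.util.chat_util import run_async_tasks
from dbgpt.util.tracer import root_tracer


async def search_one(query: str, parent_span_id: Optional[str] = None) -> List[str]:
    # The span opened here is attached to the caller's span explicitly.
    with root_tracer.start_span(
        "example.rag.search_one", parent_span_id, metadata={"query": query}
    ):
        return [f"result for {query}"]


async def search_all(queries: List[str]) -> List[List[str]]:
    # Capture the current span id before the coroutines are scheduled.
    parent_span_id = root_tracer.get_current_span_id()
    tasks = [search_one(q, parent_span_id) for q in queries]
    return await run_async_tasks(tasks=tasks, concurrency_limit=1)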
@@ -1,11 +1,11 @@
 """Rerank module for RAG retriever."""

-import asyncio
 from abc import ABC, abstractmethod
 from typing import Callable, List, Optional

 from dbgpt.core import Chunk, RerankEmbeddings
 from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource
+from dbgpt.util.executor_utils import blocking_func_to_async_no_executor
 from dbgpt.util.i18n_utils import _

 RANK_FUNC = Callable[[List[Chunk]], List[Chunk]]
@@ -54,8 +54,8 @@ class Ranker(ABC):
         Return:
             List[Chunk]
         """
-        return await asyncio.get_running_loop().run_in_executor(
-            None, self.rank, candidates_with_scores, query
+        return await blocking_func_to_async_no_executor(
+            self.rank, candidates_with_scores, query
         )

     def _filter(self, candidates_with_scores: List) -> List[Chunk]:
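The last hunk swaps asyncio.get_running_loop().run_in_executor for the blocking_func_to_async_no_executor helper; both dispatch a blocking rank call from async code. A side-by-side sketch, with slow_rank as an illustrative stand-in for self.rank:

import asyncio
from typing import List

from dbgpt.util.executor_utils import blocking_func_to_async_no_executor


def slow_rank(candidates: List[str], query: str) -> List[str]:
    # Stand-in for a CPU-bound rerank step.
    return sorted(candidates, key=lambda c: query in c, reverse=True)


async def arank_old(candidates: List[str], query: str) -> List[str]:
    # Previous form: dispatch via the event loop's default executor.
    return await asyncio.get_running_loop().run_in_executor(
        None, slow_rank, candidates, query
    )


async def arank_new(candidates: List[str], query: str) -> List[str]:
    # Form used after this commit.
    return await blocking_func_to_async_no_executor(slow_rank, candidates, query)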