feat(core): Support opentelemetry exporter (#1690)

Fangyin Cheng
2024-07-05 15:20:21 +08:00
committed by GitHub
parent 84fc1fc7fe
commit bf978d2bf9
39 changed files with 1176 additions and 218 deletions

View File

@@ -9,6 +9,7 @@ from dbgpt._private.pydantic import EXTRA_FORBID, BaseModel, ConfigDict, Field
from dbgpt.core import Embeddings
from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource
from dbgpt.util.i18n_utils import _
from dbgpt.util.tracer import DBGPT_TRACER_SPAN_ID, root_tracer
DEFAULT_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
DEFAULT_INSTRUCT_MODEL = "hkunlp/instructor-large"
@@ -655,6 +656,9 @@ class OpenAPIEmbeddings(BaseModel, Embeddings):
timeout: int = Field(
default=60, description="The timeout for the request in seconds."
)
pass_trace_id: bool = Field(
default=True, description="Whether to pass the trace ID to the API."
)
session: Optional[requests.Session] = None
@@ -688,10 +692,15 @@ class OpenAPIEmbeddings(BaseModel, Embeddings):
corresponds to a single input text.
"""
# Call OpenAI Embedding API
headers = {}
if self.pass_trace_id:
# Set the trace ID if available
headers[DBGPT_TRACER_SPAN_ID] = root_tracer.get_current_span_id()
res = self.session.post( # type: ignore
self.api_url,
json={"input": texts, "model": self.model_name},
timeout=self.timeout,
headers=headers,
)
return _handle_request_result(res)
@@ -717,6 +726,9 @@ class OpenAPIEmbeddings(BaseModel, Embeddings):
List[float] corresponds to a single input text.
"""
headers = {"Authorization": f"Bearer {self.api_key}"}
if self.pass_trace_id:
# Set the trace ID if available
headers[DBGPT_TRACER_SPAN_ID] = root_tracer.get_current_span_id()
async with aiohttp.ClientSession(
headers=headers, timeout=aiohttp.ClientTimeout(total=self.timeout)
) as session:
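
In short, OpenAPIEmbeddings (and OpenAPIRerankEmbeddings in the next file) now forward the caller's span ID to the remote service in the DBGPT_TRACER_SPAN_ID header whenever pass_trace_id is enabled, so the server can stitch its own spans into the caller's trace. A minimal usage sketch follows; the import path, constructor arguments, endpoint URL, and the embed_documents method name are assumptions based on the fields shown in this diff, not verified API details.

# Hypothetical usage sketch (not part of this commit).
from dbgpt.rag.embedding import OpenAPIEmbeddings  # import path assumed
from dbgpt.util.tracer import root_tracer

embeddings = OpenAPIEmbeddings(
    api_url="http://localhost:8100/api/v1/embeddings",  # example endpoint
    api_key="EMPTY",
    model_name="text2vec",
    pass_trace_id=True,  # default; set False to omit the tracing header
)

with root_tracer.start_span("my_app.embed_documents"):
    # The POST request carries the current span ID in the DBGPT_TRACER_SPAN_ID
    # header, so the remote embedding service can link its spans to this trace.
    vectors = embeddings.embed_documents(["hello DB-GPT", "opentelemetry export"])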

View File

@@ -8,6 +8,7 @@ import requests
from dbgpt._private.pydantic import EXTRA_FORBID, BaseModel, ConfigDict, Field
from dbgpt.core import RerankEmbeddings
from dbgpt.util.tracer import DBGPT_TRACER_SPAN_ID, root_tracer
class CrossEncoderRerankEmbeddings(BaseModel, RerankEmbeddings):
@@ -78,6 +79,9 @@ class OpenAPIRerankEmbeddings(BaseModel, RerankEmbeddings):
timeout: int = Field(
default=60, description="The timeout for the request in seconds."
)
pass_trace_id: bool = Field(
default=True, description="Whether to pass the trace ID to the API."
)
session: Optional[requests.Session] = None
@@ -112,9 +116,13 @@ class OpenAPIRerankEmbeddings(BaseModel, RerankEmbeddings):
"""
if not candidates:
return []
headers = {}
if self.pass_trace_id:
# Set the trace ID if available
headers[DBGPT_TRACER_SPAN_ID] = root_tracer.get_current_span_id()
data = {"model": self.model_name, "query": query, "documents": candidates}
response = self.session.post( # type: ignore
- self.api_url, json=data, timeout=self.timeout
+ self.api_url, json=data, timeout=self.timeout, headers=headers
)
response.raise_for_status()
return response.json()["data"]
@@ -122,6 +130,9 @@ class OpenAPIRerankEmbeddings(BaseModel, RerankEmbeddings):
async def apredict(self, query: str, candidates: List[str]) -> List[float]:
"""Predict the rank scores of the candidates asynchronously."""
headers = {"Authorization": f"Bearer {self.api_key}"}
if self.pass_trace_id:
# Set the trace ID if available
headers[DBGPT_TRACER_SPAN_ID] = root_tracer.get_current_span_id()
async with aiohttp.ClientSession(
headers=headers, timeout=aiohttp.ClientTimeout(total=self.timeout)
) as session:

View File

@@ -1,4 +1,5 @@
"""DBSchema retriever."""
from functools import reduce
from typing import List, Optional, cast
@@ -10,6 +11,8 @@ from dbgpt.rag.retriever.rerank import DefaultRanker, Ranker
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
from dbgpt.storage.vector_store.filters import MetadataFilters
from dbgpt.util.chat_util import run_async_tasks
from dbgpt.util.executor_utils import blocking_func_to_async_no_executor
from dbgpt.util.tracer import root_tracer
class DBSchemaRetriever(BaseRetriever):
@@ -155,7 +158,12 @@ class DBSchemaRetriever(BaseRetriever):
"""
if self._need_embeddings:
queries = [query]
- candidates = [self._similarity_search(query, filters) for query in queries]
+ candidates = [
+ self._similarity_search(
+ query, filters, root_tracer.get_current_span_id()
+ )
+ for query in queries
+ ]
result_candidates = await run_async_tasks(
tasks=candidates, concurrency_limit=1
)
@@ -166,7 +174,8 @@ class DBSchemaRetriever(BaseRetriever):
)
table_summaries = await run_async_tasks(
- tasks=[self._aparse_db_summary()], concurrency_limit=1
+ tasks=[self._aparse_db_summary(root_tracer.get_current_span_id())],
+ concurrency_limit=1,
)
return [Chunk(content=table_summary) for table_summary in table_summaries]
@@ -186,15 +195,33 @@ class DBSchemaRetriever(BaseRetriever):
return await self._aretrieve(query, filters)
async def _similarity_search(
- self, query, filters: Optional[MetadataFilters] = None
+ self,
+ query,
+ filters: Optional[MetadataFilters] = None,
+ parent_span_id: Optional[str] = None,
) -> List[Chunk]:
"""Similar search."""
- return self._index_store.similar_search(query, self._top_k, filters)
+ with root_tracer.start_span(
+ "dbgpt.rag.retriever.db_schema._similarity_search",
+ parent_span_id,
+ metadata={"query": query},
+ ):
+ return await blocking_func_to_async_no_executor(
+ self._index_store.similar_search, query, self._top_k, filters
+ )
- async def _aparse_db_summary(self) -> List[str]:
+ async def _aparse_db_summary(
+ self, parent_span_id: Optional[str] = None
+ ) -> List[str]:
"""Similar search."""
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
if not self._connector:
raise RuntimeError("RDBMSConnector connection is required.")
- return _parse_db_summary(self._connector)
+ with root_tracer.start_span(
+ "dbgpt.rag.retriever.db_schema._aparse_db_summary",
+ parent_span_id,
+ ):
+ return await blocking_func_to_async_no_executor(
+ _parse_db_summary, self._connector
+ )
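
The retriever changes in this file (and in the embedding retriever below) follow one pattern: capture root_tracer.get_current_span_id() in the caller, pass it into the coroutine as parent_span_id, open a child span there, and off-load the blocking index call with blocking_func_to_async_no_executor. The explicit parent ID is presumably needed because the coroutines are scheduled through run_async_tasks, where the implicit current-span context is not available. Below is a standalone sketch of the pattern; the span name and the do_search helper are illustrative, not part of the codebase.

# Illustrative sketch of the span-linking pattern (not part of this commit).
from typing import List, Optional

from dbgpt.util.executor_utils import blocking_func_to_async_no_executor
from dbgpt.util.tracer import root_tracer


def do_search(query: str) -> List[str]:
    # Stand-in for a blocking call such as index_store.similar_search.
    return [f"result for {query}"]


async def traced_search(query: str, parent_span_id: Optional[str] = None) -> List[str]:
    # Link this span to the caller's trace via the explicitly passed span ID.
    with root_tracer.start_span(
        "example.traced_search", parent_span_id, metadata={"query": query}
    ):
        # Run the blocking lookup without blocking the event loop.
        return await blocking_func_to_async_no_executor(do_search, query)


# Caller side: capture the current span ID before scheduling the coroutine, e.g.
# results = await traced_search("user table", root_tracer.get_current_span_id())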

View File

@@ -10,6 +10,7 @@ from dbgpt.rag.retriever.rerank import DefaultRanker, Ranker
from dbgpt.rag.retriever.rewrite import QueryRewrite
from dbgpt.storage.vector_store.filters import MetadataFilters
from dbgpt.util.chat_util import run_async_tasks
from dbgpt.util.executor_utils import blocking_func_to_async_no_executor
from dbgpt.util.tracer import root_tracer
@@ -140,7 +141,10 @@ class EmbeddingRetriever(BaseRetriever):
queries = [query]
if self._query_rewrite:
candidates_tasks = [
- self._similarity_search(query, filters) for query in queries
+ self._similarity_search(
+ query, filters, root_tracer.get_current_span_id()
+ )
+ for query in queries
]
chunks = await self._run_async_tasks(candidates_tasks)
context = "\n".join([chunk.content for chunk in chunks])
@@ -148,7 +152,10 @@ class EmbeddingRetriever(BaseRetriever):
origin_query=query, context=context, nums=1
)
queries.extend(new_queries)
- candidates = [self._similarity_search(query, filters) for query in queries]
+ candidates = [
+ self._similarity_search(query, filters, root_tracer.get_current_span_id())
+ for query in queries
+ ]
new_candidates = await run_async_tasks(tasks=candidates, concurrency_limit=1)
return new_candidates
@@ -170,16 +177,19 @@ class EmbeddingRetriever(BaseRetriever):
queries = [query]
if self._query_rewrite:
with root_tracer.start_span(
- "EmbeddingRetriever.query_rewrite.similarity_search",
+ "dbgpt.rag.retriever.embeddings.query_rewrite.similarity_search",
metadata={"query": query, "score_threshold": score_threshold},
):
candidates_tasks = [
- self._similarity_search(query, filters) for query in queries
+ self._similarity_search(
+ query, filters, root_tracer.get_current_span_id()
+ )
+ for query in queries
]
chunks = await self._run_async_tasks(candidates_tasks)
context = "\n".join([chunk.content for chunk in chunks])
with root_tracer.start_span(
- "EmbeddingRetriever.query_rewrite.rewrite",
+ "dbgpt.rag.retriever.embeddings.query_rewrite.rewrite",
metadata={"query": query, "context": context, "nums": 1},
):
new_queries = await self._query_rewrite.rewrite(
@@ -188,11 +198,13 @@ class EmbeddingRetriever(BaseRetriever):
queries.extend(new_queries)
with root_tracer.start_span(
- "EmbeddingRetriever.similarity_search_with_score",
+ "dbgpt.rag.retriever.embeddings.similarity_search_with_score",
metadata={"query": query, "score_threshold": score_threshold},
):
candidates_with_score = [
- self._similarity_search_with_score(query, score_threshold, filters)
+ self._similarity_search_with_score(
+ query, score_threshold, filters, root_tracer.get_current_span_id()
+ )
for query in queries
]
res_candidates_with_score = await run_async_tasks(
@@ -203,7 +215,7 @@ class EmbeddingRetriever(BaseRetriever):
)
with root_tracer.start_span(
- "EmbeddingRetriever.rerank",
+ "dbgpt.rag.retriever.embeddings.rerank",
metadata={
"query": query,
"score_threshold": score_threshold,
@@ -216,10 +228,22 @@ class EmbeddingRetriever(BaseRetriever):
return new_candidates_with_score
async def _similarity_search(
- self, query, filters: Optional[MetadataFilters] = None
+ self,
+ query,
+ filters: Optional[MetadataFilters] = None,
+ parent_span_id: Optional[str] = None,
) -> List[Chunk]:
"""Similar search."""
- return self._index_store.similar_search(query, self._top_k, filters)
+ with root_tracer.start_span(
+ "dbgpt.rag.retriever.embeddings.similarity_search",
+ parent_span_id,
+ metadata={
+ "query": query,
+ },
+ ):
+ return await blocking_func_to_async_no_executor(
+ self._index_store.similar_search, query, self._top_k, filters
+ )
async def _run_async_tasks(self, tasks) -> List[Chunk]:
"""Run async tasks."""
@@ -228,9 +252,25 @@ class EmbeddingRetriever(BaseRetriever):
return cast(List[Chunk], candidates)
async def _similarity_search_with_score(
- self, query, score_threshold, filters: Optional[MetadataFilters] = None
+ self,
+ query,
+ score_threshold,
+ filters: Optional[MetadataFilters] = None,
+ parent_span_id: Optional[str] = None,
) -> List[Chunk]:
"""Similar search with score."""
- return await self._index_store.asimilar_search_with_scores(
- query, self._top_k, score_threshold, filters
- )
+ with root_tracer.start_span(
+ "dbgpt.rag.retriever.embeddings._do_similarity_search_with_score",
+ parent_span_id,
+ metadata={
+ "query": query,
+ "score_threshold": score_threshold,
+ },
+ ):
+ return await blocking_func_to_async_no_executor(
+ self._index_store.similar_search_with_scores,
+ query,
+ self._top_k,
+ score_threshold,
+ filters,
+ )

View File

@@ -1,11 +1,11 @@
"""Rerank module for RAG retriever."""
import asyncio
from abc import ABC, abstractmethod
from typing import Callable, List, Optional
from dbgpt.core import Chunk, RerankEmbeddings
from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource
from dbgpt.util.executor_utils import blocking_func_to_async_no_executor
from dbgpt.util.i18n_utils import _
RANK_FUNC = Callable[[List[Chunk]], List[Chunk]]
@@ -54,8 +54,8 @@ class Ranker(ABC):
Return:
List[Chunk]
"""
- return await asyncio.get_running_loop().run_in_executor(
- None, self.rank, candidates_with_scores, query
+ return await blocking_func_to_async_no_executor(
+ self.rank, candidates_with_scores, query
)
def _filter(self, candidates_with_scores: List) -> List[Chunk]:
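
The rank wrapper above replaces a direct asyncio run_in_executor call with the new blocking_func_to_async_no_executor helper, the same utility adopted by the retrievers earlier in this commit. Judging only by the call it replaces, the helper behaves roughly like the sketch below; this is an illustrative approximation, not the actual dbgpt implementation.

# Rough, assumed shape of the helper (illustrative only).
import asyncio
from functools import partial
from typing import Any, Callable


async def blocking_func_to_async_no_executor(
    func: Callable[..., Any], *args: Any, **kwargs: Any
) -> Any:
    # Equivalent in effect to the removed inline call:
    #     asyncio.get_running_loop().run_in_executor(None, func, *args)
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, partial(func, *args, **kwargs))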