Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-09-07 12:00:46 +00:00)
fix(model): fix remote reranker api & fix long_term memory (#2648)
Co-authored-by: dong <dongzhancai@iie2.com>
This commit is contained in:
@@ -144,14 +144,16 @@ class LongTermRetriever(TimeWeightedEmbeddingRetriever):
|
||||
rescored_docs = []
|
||||
for doc in filtered_docs:
|
||||
if _METADATA_LAST_ACCESSED_AT in doc.metadata:
|
||||
- last_accessed_time = doc.metadata[_METADATA_LAST_ACCESSED_AT]
|
||||
+ last_accessed_time = datetime.fromtimestamp(
|
||||
+ float(doc.metadata[_METADATA_LAST_ACCESSED_AT])
|
||||
+ )
|
||||
hours_passed = self._get_hours_passed(current_time, last_accessed_time)
|
||||
time_score = (1.0 - self.decay_rate) ** hours_passed
|
||||
|
||||
# Add importance score if available
|
||||
importance_score = 0
|
||||
if _METADAT_IMPORTANCE in doc.metadata:
|
||||
- importance_score = doc.metadata[_METADAT_IMPORTANCE]
|
||||
+ importance_score = float(doc.metadata[_METADAT_IMPORTANCE])
|
||||
|
||||
# Combine scores
|
||||
combined_score = doc.score + time_score + importance_score
|
||||
@@ -242,7 +244,7 @@ class LongTermMemory(Memory, Generic[T]):
|
||||
memory_idx = len(self.memory_retriever.memory_stream)
|
||||
metadata = self._metadata
|
||||
metadata[_METADAT_IMPORTANCE] = importance
|
||||
- metadata[_METADATA_LAST_ACCESSED_AT] = last_accessed_time
|
||||
+ metadata[_METADATA_LAST_ACCESSED_AT] = last_accessed_time.timestamp()
|
||||
if self.session_id:
|
||||
metadata[_METADATA_SESSION_ID] = self.session_id
|
||||
|
||||
|
@@ -607,6 +607,7 @@ class APIServer(BaseComponent):
|
||||
"input": texts,
|
||||
"model": model,
|
||||
"query": query,
|
||||
+ "worker_type": WorkerType.RERANKER.value,
|
||||
}
|
||||
scores = await worker_manager.embeddings(params)
|
||||
return scores[0]
|
||||
@@ -780,13 +781,16 @@ async def create_relevance(
|
||||
request: RelevanceRequest, api_server: APIServer = Depends(get_api_server)
|
||||
):
|
||||
"""Generate relevance scores for a query and a list of documents."""
|
||||
- await api_server.get_model_instances_or_raise(request.model, worker_type="text2vec")
|
||||
+ await api_server.get_model_instances_or_raise(
|
||||
+ request.model, worker_type=WorkerType.RERANKER.value
|
||||
+ )
|
||||
|
||||
with root_tracer.start_span(
|
||||
"dbgpt.model.apiserver.generate_relevance",
|
||||
metadata={
|
||||
"model": request.model,
|
||||
"query": request.query,
|
||||
+ "worker_type": WorkerType.RERANKER.value,
|
||||
},
|
||||
):
|
||||
scores = await api_server.relevance_generate(
|
||||
|
@@ -39,6 +39,7 @@ class PromptRequest(BaseModel):
|
||||
class EmbeddingsRequest(BaseModel):
|
||||
model: str
|
||||
input: List[str]
|
||||
+ worker_type: str = WorkerType.TEXT2VEC.value
|
||||
span_id: Optional[str] = None
|
||||
query: Optional[str] = None
|
||||
"""For rerank model, query is required"""
|
||||
|
Reference in New Issue
Block a user