fix(model): fix remote reranker api & fix long_term memory (#2648)

Co-authored-by: dong <dongzhancai@iie2.com>
This commit is contained in:
Cooper
2025-04-24 11:17:20 +08:00
committed by GitHub
parent e78d5370ac
commit 0b3ac35ede
4 changed files with 43 additions and 7 deletions

View File

@@ -144,14 +144,16 @@ class LongTermRetriever(TimeWeightedEmbeddingRetriever):
rescored_docs = []
for doc in filtered_docs:
if _METADATA_LAST_ACCESSED_AT in doc.metadata:
last_accessed_time = doc.metadata[_METADATA_LAST_ACCESSED_AT]
last_accessed_time = datetime.fromtimestamp(
float(doc.metadata[_METADATA_LAST_ACCESSED_AT])
)
hours_passed = self._get_hours_passed(current_time, last_accessed_time)
time_score = (1.0 - self.decay_rate) ** hours_passed
# Add importance score if available
importance_score = 0
if _METADAT_IMPORTANCE in doc.metadata:
importance_score = doc.metadata[_METADAT_IMPORTANCE]
importance_score = float(doc.metadata[_METADAT_IMPORTANCE])
# Combine scores
combined_score = doc.score + time_score + importance_score
@@ -242,7 +244,7 @@ class LongTermMemory(Memory, Generic[T]):
memory_idx = len(self.memory_retriever.memory_stream)
metadata = self._metadata
metadata[_METADAT_IMPORTANCE] = importance
metadata[_METADATA_LAST_ACCESSED_AT] = last_accessed_time
metadata[_METADATA_LAST_ACCESSED_AT] = last_accessed_time.timestamp()
if self.session_id:
metadata[_METADATA_SESSION_ID] = self.session_id

View File

@@ -607,6 +607,7 @@ class APIServer(BaseComponent):
"input": texts,
"model": model,
"query": query,
"worker_type": WorkerType.RERANKER.value,
}
scores = await worker_manager.embeddings(params)
return scores[0]
@@ -780,13 +781,16 @@ async def create_relevance(
request: RelevanceRequest, api_server: APIServer = Depends(get_api_server)
):
"""Generate relevance scores for a query and a list of documents."""
await api_server.get_model_instances_or_raise(request.model, worker_type="text2vec")
await api_server.get_model_instances_or_raise(
request.model, worker_type=WorkerType.RERANKER.value
)
with root_tracer.start_span(
"dbgpt.model.apiserver.generate_relevance",
metadata={
"model": request.model,
"query": request.query,
"worker_type": WorkerType.RERANKER.value,
},
):
scores = await api_server.relevance_generate(

View File

@@ -39,6 +39,7 @@ class PromptRequest(BaseModel):
class EmbeddingsRequest(BaseModel):
model: str
input: List[str]
worker_type: str = WorkerType.TEXT2VEC.value
span_id: Optional[str] = None
query: Optional[str] = None
"""For rerank model, query is required"""