fix(model): fix remote reranker api & fix long_term memory (#2648)

Co-authored-by: dong <dongzhancai@iie2.com>
Cooper 2025-04-24 11:17:20 +08:00 committed by GitHub
parent e78d5370ac
commit 0b3ac35ede
4 changed files with 43 additions and 7 deletions

View File

@@ -2,13 +2,26 @@ import asyncio
 import logging
 import os
 import sys
+from concurrent.futures import ThreadPoolExecutor
+from datetime import datetime
 
 from typing_extensions import Annotated, Doc
 
-from dbgpt.agent import AgentContext, AgentMemory, LLMConfig, UserProxyAgent
+from dbgpt.agent import (
+    AgentContext,
+    AgentMemory,
+    HybridMemory,
+    LLMConfig,
+    LongTermMemory,
+    SensoryMemory,
+    ShortTermMemory,
+    UserProxyAgent,
+)
 from dbgpt.agent.expand.actions.react_action import ReActAction, Terminate
 from dbgpt.agent.expand.react_agent import ReActAgent
 from dbgpt.agent.resource import ToolPack, tool
+from dbgpt.rag.embedding import OpenAPIEmbeddings
+from dbgpt_ext.storage.vector_store.chroma_store import ChromaStore, ChromaVectorConfig
 
 logging.basicConfig(
     stream=sys.stdout,
@@ -57,7 +70,23 @@ async def main():
         provider=os.getenv("LLM_PROVIDER", "proxy/siliconflow"),
         name=os.getenv("LLM_MODEL_NAME", "Qwen/Qwen2.5-Coder-32B-Instruct"),
     )
-    agent_memory = AgentMemory()
+
+    short_memory = ShortTermMemory(buffer_size=1)
+    sensor_memory = SensoryMemory()
+    embedding_fn = OpenAPIEmbeddings(
+        api_url="https://api.siliconflow.cn/v1/embeddings",
+        api_key=os.getenv("SILICONFLOW_API_KEY"),
+        model_name="BAAI/bge-large-zh-v1.5",
+    )
+    vector_store = ChromaStore(
+        ChromaVectorConfig(persist_path="pilot/data"),
+        name="react_mem",
+        embedding_fn=embedding_fn,
+    )
+    long_memory = LongTermMemory(ThreadPoolExecutor(), vector_store)
+    agent_memory = AgentMemory(
+        memory=HybridMemory(datetime.now(), sensor_memory, short_memory, long_memory)
+    )
     agent_memory.gpts_memory.init(conv_id="test456")
 
     # It is important to set the temperature to a low value to get a better result
@@ -81,7 +110,7 @@ async def main():
     await user_proxy.initiate_chat(
         recipient=tool_engineer,
         reviewer=user_proxy,
-        message="Calculate the product of 10 and 99",
+        message="Calculate the product of 10 and 99, and then add 1 to the result, and finally divide the result by 2.",
     )
     await user_proxy.initiate_chat(
         recipient=tool_engineer,
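Note: the new long-term layer is only as useful as the embedding endpoint behind it. A minimal smoke test for the embedding configuration used in this example, assuming OpenAPIEmbeddings follows DB-GPT's usual Embeddings contract (embed_documents returning one vector per input text):

import os

from dbgpt.rag.embedding import OpenAPIEmbeddings

# Same endpoint, key, and model as the example above.
embedding_fn = OpenAPIEmbeddings(
    api_url="https://api.siliconflow.cn/v1/embeddings",
    api_key=os.getenv("SILICONFLOW_API_KEY"),
    model_name="BAAI/bge-large-zh-v1.5",
)

# One vector per input text; bge-large-zh-v1.5 embeddings are 1024-dimensional.
vectors = embedding_fn.embed_documents(["DB-GPT hybrid memory smoke test"])
print(len(vectors), len(vectors[0]))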

View File

@@ -144,14 +144,16 @@ class LongTermRetriever(TimeWeightedEmbeddingRetriever):
         rescored_docs = []
         for doc in filtered_docs:
             if _METADATA_LAST_ACCESSED_AT in doc.metadata:
-                last_accessed_time = doc.metadata[_METADATA_LAST_ACCESSED_AT]
+                last_accessed_time = datetime.fromtimestamp(
+                    float(doc.metadata[_METADATA_LAST_ACCESSED_AT])
+                )
                 hours_passed = self._get_hours_passed(current_time, last_accessed_time)
                 time_score = (1.0 - self.decay_rate) ** hours_passed
 
                 # Add importance score if available
                 importance_score = 0
                 if _METADAT_IMPORTANCE in doc.metadata:
-                    importance_score = doc.metadata[_METADAT_IMPORTANCE]
+                    importance_score = float(doc.metadata[_METADAT_IMPORTANCE])
 
                 # Combine scores
                 combined_score = doc.score + time_score + importance_score
@@ -242,7 +244,7 @@ class LongTermMemory(Memory, Generic[T]):
         memory_idx = len(self.memory_retriever.memory_stream)
 
         metadata = self._metadata
         metadata[_METADAT_IMPORTANCE] = importance
-        metadata[_METADATA_LAST_ACCESSED_AT] = last_accessed_time
+        metadata[_METADATA_LAST_ACCESSED_AT] = last_accessed_time.timestamp()
         if self.session_id:
             metadata[_METADATA_SESSION_ID] = self.session_id
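Note: the two hunks above are two halves of one fix. Chroma, like most vector stores, only accepts primitive metadata values, so the raw datetime written before this change could not round-trip through the store; the write path now persists datetime.timestamp() (a float) and the read path rebuilds a datetime with datetime.fromtimestamp() before computing the time decay. A self-contained sketch of the round trip (the decay_rate value here is illustrative):

from datetime import datetime

decay_rate = 0.01  # illustrative; the retriever's configured decay_rate applies in practice

# Write path: store the access time as a float, the only form the vector
# store's metadata accepts.
metadata = {"last_accessed_at": datetime.now().timestamp()}

# Read path: rebuild the datetime and apply the time-weighted decay.
last_accessed_time = datetime.fromtimestamp(float(metadata["last_accessed_at"]))
hours_passed = (datetime.now() - last_accessed_time).total_seconds() / 3600
time_score = (1.0 - decay_rate) ** hours_passed  # ~1.0 right after access, decays toward 0
print(round(time_score, 4))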

View File

@@ -607,6 +607,7 @@ class APIServer(BaseComponent):
             "input": texts,
             "model": model,
             "query": query,
+            "worker_type": WorkerType.RERANKER.value,
         }
         scores = await worker_manager.embeddings(params)
         return scores[0]
@@ -780,13 +781,16 @@ async def create_relevance(
     request: RelevanceRequest, api_server: APIServer = Depends(get_api_server)
 ):
     """Generate relevance scores for a query and a list of documents."""
-    await api_server.get_model_instances_or_raise(request.model, worker_type="text2vec")
+    await api_server.get_model_instances_or_raise(
+        request.model, worker_type=WorkerType.RERANKER.value
+    )
 
     with root_tracer.start_span(
         "dbgpt.model.apiserver.generate_relevance",
         metadata={
             "model": request.model,
             "query": request.query,
+            "worker_type": WorkerType.RERANKER.value,
         },
     ):
         scores = await api_server.relevance_generate(
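Note: before this change the instance lookup was pinned to worker_type="text2vec", so relevance requests could never be routed to a deployed reranker worker. With WorkerType.RERANKER.value on both the lookup and the embeddings params, rerank traffic now reaches the right worker type. A hedged smoke test against a local deployment (the /api/v1/relevance route, port 8100, auth token, model name, and the documents field name are assumptions; adjust to your setup):

import requests

resp = requests.post(
    "http://127.0.0.1:8100/api/v1/relevance",
    headers={"Authorization": "Bearer EMPTY"},
    json={
        "model": "bge-reranker-base",  # a rerank model deployed on a reranker worker
        "query": "What is the product of 10 and 99?",
        "documents": [
            "10 * 99 = 990",
            "The weather is nice today.",
        ],
    },
    timeout=30,
)
resp.raise_for_status()
print(resp.json())  # one relevance score per input document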

View File

@@ -39,6 +39,7 @@ class PromptRequest(BaseModel):
 class EmbeddingsRequest(BaseModel):
     model: str
     input: List[str]
+    worker_type: str = WorkerType.TEXT2VEC.value
     span_id: Optional[str] = None
     query: Optional[str] = None
     """For rerank model, query is required"""