fix(model): fix remote reranker api & fix long_term memory (#2648)

Co-authored-by: dong <dongzhancai@iie2.com>
Cooper authored 2025-04-24 11:17:20 +08:00; committed by GitHub
parent e78d5370ac
commit 0b3ac35ede
4 changed files with 43 additions and 7 deletions

Changed file 1 of 4:

@@ -2,13 +2,26 @@ import asyncio
 import logging
 import os
 import sys
+from concurrent.futures import ThreadPoolExecutor
+from datetime import datetime
 
 from typing_extensions import Annotated, Doc
 
-from dbgpt.agent import AgentContext, AgentMemory, LLMConfig, UserProxyAgent
+from dbgpt.agent import (
+    AgentContext,
+    AgentMemory,
+    HybridMemory,
+    LLMConfig,
+    LongTermMemory,
+    SensoryMemory,
+    ShortTermMemory,
+    UserProxyAgent,
+)
 from dbgpt.agent.expand.actions.react_action import ReActAction, Terminate
 from dbgpt.agent.expand.react_agent import ReActAgent
 from dbgpt.agent.resource import ToolPack, tool
+from dbgpt.rag.embedding import OpenAPIEmbeddings
+from dbgpt_ext.storage.vector_store.chroma_store import ChromaStore, ChromaVectorConfig
 
 logging.basicConfig(
     stream=sys.stdout,
@@ -57,7 +70,23 @@ async def main():
         provider=os.getenv("LLM_PROVIDER", "proxy/siliconflow"),
         name=os.getenv("LLM_MODEL_NAME", "Qwen/Qwen2.5-Coder-32B-Instruct"),
     )
-    agent_memory = AgentMemory()
+
+    short_memory = ShortTermMemory(buffer_size=1)
+    sensor_memory = SensoryMemory()
+    embedding_fn = OpenAPIEmbeddings(
+        api_url="https://api.siliconflow.cn/v1/embeddings",
+        api_key=os.getenv("SILICONFLOW_API_KEY"),
+        model_name="BAAI/bge-large-zh-v1.5",
+    )
+    vector_store = ChromaStore(
+        ChromaVectorConfig(persist_path="pilot/data"),
+        name="react_mem",
+        embedding_fn=embedding_fn,
+    )
+    long_memory = LongTermMemory(ThreadPoolExecutor(), vector_store)
+    agent_memory = AgentMemory(
+        memory=HybridMemory(datetime.now(), sensor_memory, short_memory, long_memory)
+    )
     agent_memory.gpts_memory.init(conv_id="test456")
 
     # It is important to set the temperature to a low value to get a better result
@@ -81,7 +110,7 @@ async def main():
     await user_proxy.initiate_chat(
         recipient=tool_engineer,
         reviewer=user_proxy,
-        message="Calculate the product of 10 and 99",
+        message="Calculate the product of 10 and 99, and then add 1 to the result, and finally divide the result by 2.",
     )
     await user_proxy.initiate_chat(
         recipient=tool_engineer,
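
The example now swaps the default in-process AgentMemory() for a persistent, layered one: HybridMemory stacks a SensoryMemory and a one-slot ShortTermMemory in front of a Chroma-backed LongTermMemory, and the extended prompt forces a multi-step tool chain whose expected answer is (10 * 99 + 1) / 2 = 495.5. A minimal sketch that factors the same wiring into a reusable helper — the helper name and api_key parameter are additions for illustration; every constructor call follows the diff:

    from concurrent.futures import ThreadPoolExecutor
    from datetime import datetime

    from dbgpt.agent import (
        AgentMemory,
        HybridMemory,
        LongTermMemory,
        SensoryMemory,
        ShortTermMemory,
    )
    from dbgpt.rag.embedding import OpenAPIEmbeddings
    from dbgpt_ext.storage.vector_store.chroma_store import ChromaStore, ChromaVectorConfig

    def build_agent_memory(api_key: str) -> AgentMemory:
        """Wire the three memory layers the way the example does."""
        embedding_fn = OpenAPIEmbeddings(
            api_url="https://api.siliconflow.cn/v1/embeddings",
            api_key=api_key,
            model_name="BAAI/bge-large-zh-v1.5",
        )
        vector_store = ChromaStore(
            ChromaVectorConfig(persist_path="pilot/data"),
            name="react_mem",
            embedding_fn=embedding_fn,
        )
        long_memory = LongTermMemory(ThreadPoolExecutor(), vector_store)
        return AgentMemory(
            memory=HybridMemory(
                datetime.now(),
                SensoryMemory(),
                ShortTermMemory(buffer_size=1),  # tiny buffer: items age out quickly
                long_memory,
            )
        )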

Changed file 2 of 4:

@@ -144,14 +144,16 @@ class LongTermRetriever(TimeWeightedEmbeddingRetriever):
         rescored_docs = []
         for doc in filtered_docs:
             if _METADATA_LAST_ACCESSED_AT in doc.metadata:
-                last_accessed_time = doc.metadata[_METADATA_LAST_ACCESSED_AT]
+                last_accessed_time = datetime.fromtimestamp(
+                    float(doc.metadata[_METADATA_LAST_ACCESSED_AT])
+                )
                 hours_passed = self._get_hours_passed(current_time, last_accessed_time)
                 time_score = (1.0 - self.decay_rate) ** hours_passed
 
                 # Add importance score if available
                 importance_score = 0
                 if _METADAT_IMPORTANCE in doc.metadata:
-                    importance_score = doc.metadata[_METADAT_IMPORTANCE]
+                    importance_score = float(doc.metadata[_METADAT_IMPORTANCE])
 
                 # Combine scores
                 combined_score = doc.score + time_score + importance_score
@@ -242,7 +244,7 @@ class LongTermMemory(Memory, Generic[T]):
         memory_idx = len(self.memory_retriever.memory_stream)
         metadata = self._metadata
         metadata[_METADAT_IMPORTANCE] = importance
-        metadata[_METADATA_LAST_ACCESSED_AT] = last_accessed_time
+        metadata[_METADATA_LAST_ACCESSED_AT] = last_accessed_time.timestamp()
         if self.session_id:
             metadata[_METADATA_SESSION_ID] = self.session_id
 
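
The two hunks are the two ends of one fix: a raw datetime is not a primitive metadata value for the vector store, so the write path in LongTermMemory now stores a float Unix timestamp and the read path in LongTermRetriever converts it back (and coerces importance to float) before scoring. A stdlib-only sketch of the round trip and the rescoring formula, with illustrative values and plain string keys standing in for the module's _METADATA_* constants:

    from datetime import datetime

    # Write side: persist the access time as a plain float, which any
    # vector store can keep as primitive metadata.
    metadata = {
        "last_accessed_at": datetime.now().timestamp(),
        "importance": 0.5,
    }

    # Read side: convert back before doing datetime arithmetic;
    # float() also tolerates stores that return metadata as strings.
    last_accessed = datetime.fromtimestamp(float(metadata["last_accessed_at"]))
    hours_passed = (datetime.now() - last_accessed).total_seconds() / 3600

    # Combined score = similarity + exponential time decay + importance.
    decay_rate = 0.01   # illustrative; the retriever's own default applies
    similarity = 0.8    # illustrative doc.score
    combined = (
        similarity
        + (1.0 - decay_rate) ** hours_passed
        + float(metadata["importance"])
    )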

Changed file 3 of 4:

@@ -607,6 +607,7 @@ class APIServer(BaseComponent):
             "input": texts,
             "model": model,
             "query": query,
+            "worker_type": WorkerType.RERANKER.value,
         }
         scores = await worker_manager.embeddings(params)
         return scores[0]
@@ -780,13 +781,16 @@ async def create_relevance(
     request: RelevanceRequest, api_server: APIServer = Depends(get_api_server)
 ):
     """Generate relevance scores for a query and a list of documents."""
-    await api_server.get_model_instances_or_raise(request.model, worker_type="text2vec")
+    await api_server.get_model_instances_or_raise(
+        request.model, worker_type=WorkerType.RERANKER.value
+    )
 
     with root_tracer.start_span(
         "dbgpt.model.apiserver.generate_relevance",
         metadata={
             "model": request.model,
             "query": request.query,
+            "worker_type": WorkerType.RERANKER.value,
         },
     ):
         scores = await api_server.relevance_generate(
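
This is the "remote reranker api" half of the commit title: a rerank model registers under the reranker worker type, so looking it up with worker_type="text2vec" reported model-not-found even when the model was healthy, and the forwarded worker_type now lets the worker manager route the request to the rerank worker rather than an embedding worker. A hedged client-side sketch of exercising the endpoint — the URL, port, model name, and the documents field name are assumptions based on DB-GPT's OpenAI-style apiserver, not taken from this diff:

    import requests

    # Assumed local apiserver address and route; adjust to your deployment.
    url = "http://localhost:8100/api/v1/relevance"

    payload = {
        "model": "bge-reranker-v2-m3",  # assumed: registered as a reranker worker
        "query": "How does DB-GPT route rerank requests?",
        "documents": [
            "The apiserver resolves model instances by name and worker type.",
            "Bananas are rich in potassium.",
        ],
    }

    resp = requests.post(url, json=payload, timeout=30)
    resp.raise_for_status()
    print(resp.json())  # expected: one relevance score per input document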

Changed file 4 of 4:

@@ -39,6 +39,7 @@ class PromptRequest(BaseModel):
 class EmbeddingsRequest(BaseModel):
     model: str
     input: List[str]
+    worker_type: str = WorkerType.TEXT2VEC.value
     span_id: Optional[str] = None
     query: Optional[str] = None
     """For rerank model, query is required"""