Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-08-27 20:39:41 +00:00)
fix(model): fix remote reranker api & fix long_term memory (#2648)
Co-authored-by: dong <dongzhancai@iie2.com>
This commit is contained in: parent e78d5370ac, commit 0b3ac35ede
@@ -2,13 +2,26 @@ import asyncio
 import logging
 import os
 import sys
+from concurrent.futures import ThreadPoolExecutor
+from datetime import datetime

 from typing_extensions import Annotated, Doc

-from dbgpt.agent import AgentContext, AgentMemory, LLMConfig, UserProxyAgent
+from dbgpt.agent import (
+    AgentContext,
+    AgentMemory,
+    HybridMemory,
+    LLMConfig,
+    LongTermMemory,
+    SensoryMemory,
+    ShortTermMemory,
+    UserProxyAgent,
+)
 from dbgpt.agent.expand.actions.react_action import ReActAction, Terminate
 from dbgpt.agent.expand.react_agent import ReActAgent
 from dbgpt.agent.resource import ToolPack, tool
+from dbgpt.rag.embedding import OpenAPIEmbeddings
+from dbgpt_ext.storage.vector_store.chroma_store import ChromaStore, ChromaVectorConfig

 logging.basicConfig(
     stream=sys.stdout,
@@ -57,7 +70,23 @@ async def main():
         provider=os.getenv("LLM_PROVIDER", "proxy/siliconflow"),
         name=os.getenv("LLM_MODEL_NAME", "Qwen/Qwen2.5-Coder-32B-Instruct"),
     )
-    agent_memory = AgentMemory()
+    short_memory = ShortTermMemory(buffer_size=1)
+    sensor_memory = SensoryMemory()
+    embedding_fn = OpenAPIEmbeddings(
+        api_url="https://api.siliconflow.cn/v1/embeddings",
+        api_key=os.getenv("SILICONFLOW_API_KEY"),
+        model_name="BAAI/bge-large-zh-v1.5",
+    )
+    vector_store = ChromaStore(
+        ChromaVectorConfig(persist_path="pilot/data"),
+        name="react_mem",
+        embedding_fn=embedding_fn,
+    )
+    long_memory = LongTermMemory(ThreadPoolExecutor(), vector_store)
+
+    agent_memory = AgentMemory(
+        memory=HybridMemory(datetime.now(), sensor_memory, short_memory, long_memory)
+    )
     agent_memory.gpts_memory.init(conv_id="test456")

     # It is important to set the temperature to a low value to get a better result
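The example now builds a layered memory instead of the default AgentMemory(): sensory memory holds the rawest recent observations, a one-slot short-term buffer holds the working context, and long-term memory persists embeddings in Chroma. Restated as a commented sketch (all names come from the diff above; only the comments are added):

    sensor_memory = SensoryMemory()                # rawest recent observations
    short_memory = ShortTermMemory(buffer_size=1)  # small rolling working buffer
    long_memory = LongTermMemory(                  # embedding-backed, time-decayed recall
        ThreadPoolExecutor(),                      # executor for background writes
        vector_store,                              # e.g. the ChromaStore built above
    )
    # HybridMemory consults the layers together, seeded with the current time:
    memory = HybridMemory(datetime.now(), sensor_memory, short_memory, long_memory)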
@@ -81,7 +110,7 @@ async def main():
     await user_proxy.initiate_chat(
         recipient=tool_engineer,
         reviewer=user_proxy,
-        message="Calculate the product of 10 and 99",
+        message="Calculate the product of 10 and 99, and then add 1 to the result, and finally divide the result by 2.",
     )
     await user_proxy.initiate_chat(
         recipient=tool_engineer,
@@ -144,14 +144,16 @@ class LongTermRetriever(TimeWeightedEmbeddingRetriever):
         rescored_docs = []
         for doc in filtered_docs:
             if _METADATA_LAST_ACCESSED_AT in doc.metadata:
-                last_accessed_time = doc.metadata[_METADATA_LAST_ACCESSED_AT]
+                last_accessed_time = datetime.fromtimestamp(
+                    float(doc.metadata[_METADATA_LAST_ACCESSED_AT])
+                )
                 hours_passed = self._get_hours_passed(current_time, last_accessed_time)
                 time_score = (1.0 - self.decay_rate) ** hours_passed

                 # Add importance score if available
                 importance_score = 0
                 if _METADAT_IMPORTANCE in doc.metadata:
-                    importance_score = doc.metadata[_METADAT_IMPORTANCE]
+                    importance_score = float(doc.metadata[_METADAT_IMPORTANCE])

                 # Combine scores
                 combined_score = doc.score + time_score + importance_score
@@ -242,7 +244,7 @@ class LongTermMemory(Memory, Generic[T]):
         memory_idx = len(self.memory_retriever.memory_stream)
         metadata = self._metadata
         metadata[_METADAT_IMPORTANCE] = importance
-        metadata[_METADATA_LAST_ACCESSED_AT] = last_accessed_time
+        metadata[_METADATA_LAST_ACCESSED_AT] = last_accessed_time.timestamp()
         if self.session_id:
             metadata[_METADATA_SESSION_ID] = self.session_id

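Together, these two hunks fix the long_term memory round-trip: on write, last_accessed_time is stored as a float timestamp (vector-store metadata is generally limited to primitive types), and on read it is converted back to a datetime before the decay math runs. A minimal, self-contained sketch of the round-trip and the score it feeds (values are illustrative):

    from datetime import datetime

    # Write side: persist a datetime as a primitive float.
    stored = datetime(2025, 5, 1, 12, 0).timestamp()

    # Read side: restore it before doing time arithmetic; the old code
    # effectively compared a raw metadata value against a datetime.
    restored = datetime.fromtimestamp(float(stored))

    decay_rate = 0.01
    hours_passed = (datetime(2025, 5, 2, 12, 0) - restored).total_seconds() / 3600
    time_score = (1.0 - decay_rate) ** hours_passed  # 0.99**24 ≈ 0.79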
@@ -607,6 +607,7 @@ class APIServer(BaseComponent):
             "input": texts,
             "model": model,
             "query": query,
+            "worker_type": WorkerType.RERANKER.value,
         }
         scores = await worker_manager.embeddings(params)
         return scores[0]
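Rerank scores travel over the same channel as embeddings, so without worker_type the request was dispatched as an ordinary text2vec job. With it, the worker manager can select a reranker instance. A sketch of the parameter shape (the WorkerType enum members appear in the diff; the import path here is an assumption):

    from dbgpt.model.cluster import WorkerType  # assumed import path

    params = {
        "input": ["doc one", "doc two"],           # candidate documents to score
        "model": "bge-reranker-base",              # assumed reranker model name
        "query": "which doc is relevant?",         # rerank needs the query
        "worker_type": WorkerType.RERANKER.value,  # select reranker workers, not text2vec
    }
    # scores = await worker_manager.embeddings(params)
    # scores[0] -> the relevance scores for the single query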
@@ -780,13 +781,16 @@ async def create_relevance(
     request: RelevanceRequest, api_server: APIServer = Depends(get_api_server)
 ):
     """Generate relevance scores for a query and a list of documents."""
-    await api_server.get_model_instances_or_raise(request.model, worker_type="text2vec")
+    await api_server.get_model_instances_or_raise(
+        request.model, worker_type=WorkerType.RERANKER.value
+    )

     with root_tracer.start_span(
         "dbgpt.model.apiserver.generate_relevance",
         metadata={
             "model": request.model,
             "query": request.query,
+            "worker_type": WorkerType.RERANKER.value,
         },
     ):
         scores = await api_server.relevance_generate(
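create_relevance previously resolved model instances with worker_type="text2vec", so a model deployed only as a reranker worker could not be found; it now resolves reranker instances. A hedged client sketch against the API server (the URL, port, model name, and the documents field name are assumptions for illustration):

    import requests

    resp = requests.post(
        "http://127.0.0.1:8100/api/v1/relevance",  # assumed apiserver address and route
        headers={"Authorization": "Bearer EMPTY"},
        json={
            "model": "bge-reranker-base",          # a model deployed as a reranker worker
            "query": "What is DB-GPT?",
            "documents": [                         # assumed RelevanceRequest field
                "DB-GPT is an AI native data app development framework.",
                "Unrelated text.",
            ],
        },
    )
    print(resp.json())  # one relevance score per document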
@@ -39,6 +39,7 @@ class PromptRequest(BaseModel):
 class EmbeddingsRequest(BaseModel):
     model: str
     input: List[str]
+    worker_type: str = WorkerType.TEXT2VEC.value
     span_id: Optional[str] = None
     query: Optional[str] = None
     """For rerank model, query is required"""
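Because worker_type defaults to WorkerType.TEXT2VEC.value, existing embedding clients keep working unchanged; only rerank callers set it explicitly, and they must also supply query. A sketch (field names from the model above; the model strings are assumptions):

    # Plain embedding request: the default worker_type targets text2vec workers.
    embed_req = EmbeddingsRequest(model="bge-large-zh-v1.5", input=["hello world"])

    # Rerank request: opt in to the reranker worker and carry the query.
    rerank_req = EmbeddingsRequest(
        model="bge-reranker-base",
        input=["candidate one", "candidate two"],
        query="which candidate answers the question?",
        worker_type=WorkerType.RERANKER.value,
    )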