mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-08 04:23:35 +00:00
fix(model): fix remote reranker api & fix long_term memory (#2648)
Co-authored-by: dong <dongzhancai@iie2.com>
This commit is contained in:
@@ -2,13 +2,26 @@ import asyncio
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import datetime
|
||||
|
||||
from typing_extensions import Annotated, Doc
|
||||
|
||||
from dbgpt.agent import AgentContext, AgentMemory, LLMConfig, UserProxyAgent
|
||||
from dbgpt.agent import (
|
||||
AgentContext,
|
||||
AgentMemory,
|
||||
HybridMemory,
|
||||
LLMConfig,
|
||||
LongTermMemory,
|
||||
SensoryMemory,
|
||||
ShortTermMemory,
|
||||
UserProxyAgent,
|
||||
)
|
||||
from dbgpt.agent.expand.actions.react_action import ReActAction, Terminate
|
||||
from dbgpt.agent.expand.react_agent import ReActAgent
|
||||
from dbgpt.agent.resource import ToolPack, tool
|
||||
from dbgpt.rag.embedding import OpenAPIEmbeddings
|
||||
from dbgpt_ext.storage.vector_store.chroma_store import ChromaStore, ChromaVectorConfig
|
||||
|
||||
logging.basicConfig(
|
||||
stream=sys.stdout,
|
||||
@@ -57,7 +70,23 @@ async def main():
|
||||
provider=os.getenv("LLM_PROVIDER", "proxy/siliconflow"),
|
||||
name=os.getenv("LLM_MODEL_NAME", "Qwen/Qwen2.5-Coder-32B-Instruct"),
|
||||
)
|
||||
agent_memory = AgentMemory()
|
||||
short_memory = ShortTermMemory(buffer_size=1)
|
||||
sensor_memory = SensoryMemory()
|
||||
embedding_fn = OpenAPIEmbeddings(
|
||||
api_url="https://api.siliconflow.cn/v1/embeddings",
|
||||
api_key=os.getenv("SILICONFLOW_API_KEY"),
|
||||
model_name="BAAI/bge-large-zh-v1.5",
|
||||
)
|
||||
vector_store = ChromaStore(
|
||||
ChromaVectorConfig(persist_path="pilot/data"),
|
||||
name="react_mem",
|
||||
embedding_fn=embedding_fn,
|
||||
)
|
||||
long_memory = LongTermMemory(ThreadPoolExecutor(), vector_store)
|
||||
|
||||
agent_memory = AgentMemory(
|
||||
memory=HybridMemory(datetime.now(), sensor_memory, short_memory, long_memory)
|
||||
)
|
||||
agent_memory.gpts_memory.init(conv_id="test456")
|
||||
|
||||
# It is important to set the temperature to a low value to get a better result
|
||||
@@ -81,7 +110,7 @@ async def main():
|
||||
await user_proxy.initiate_chat(
|
||||
recipient=tool_engineer,
|
||||
reviewer=user_proxy,
|
||||
message="Calculate the product of 10 and 99",
|
||||
message="Calculate the product of 10 and 99, and then add 1 to the result, and finally divide the result by 2.",
|
||||
)
|
||||
await user_proxy.initiate_chat(
|
||||
recipient=tool_engineer,
|
||||
|
Reference in New Issue
Block a user