mirror of https://github.com/csunny/DB-GPT.git
synced 2025-09-08 04:23:35 +00:00
refactor: RAG Refactor (#985)
Co-authored-by: Aralhi <xiaoping0501@gmail.com>
Co-authored-by: csunny <cfqsunny@163.com>
@@ -11,6 +11,7 @@ from dbgpt.component import ComponentType
from dbgpt.core.interface.prompt import PromptTemplate
from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
from dbgpt.core.interface.message import OnceConversation
from dbgpt.model.cluster import WorkerManagerFactory
from dbgpt.util import get_or_create_event_loop
from dbgpt.util.executor_utils import ExecutorFactory, blocking_func_to_async
from dbgpt.util.tracer import root_tracer, trace
@@ -58,6 +59,9 @@ class BaseChat(ABC):
            chat_param["model_name"] if chat_param["model_name"] else CFG.LLM_MODEL
        )
        self.llm_echo = False
        self.worker_manager = CFG.SYSTEM_APP.get_component(
            ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
        ).create()
        self.model_cache_enable = chat_param.get("model_cache_enable", False)

        ### load prompt template
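The worker manager created above is what the new `prompt_context_token_adapt` method later in this diff queries for model metadata and token counts. A minimal sketch of that usage, assuming only the awaitable, dict-payload calls visible in this diff; the helper name `inspect_model` is illustrative, not part of DB-GPT:

# Sketch only: relies on the WorkerManager calls that appear in this diff
# (get_model_metadata / count_token, both awaitable and taking dict payloads).
async def inspect_model(worker_manager, model_name: str, prompt: str):
    metadata = await worker_manager.get_model_metadata({"model": model_name})
    token_count = await worker_manager.count_token(
        {"model": model_name, "prompt": prompt}
    )
    # As the new method below notes, count_token returns -1 when tiktoken
    # is not installed.
    return metadata.context_length, token_count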
@@ -162,6 +166,10 @@ class BaseChat(ABC):
            "BaseChat.__call_base.prompt_template.format", metadata=metadata
        ):
            current_prompt = self.prompt_template.format(**input_values)
            ### prompt context token adapt according to llm max context length
            current_prompt = await self.prompt_context_token_adapt(
                prompt=current_prompt
            )
            self.current_message.add_system_message(current_prompt)

        llm_messages = self.generate_llm_messages()
@@ -169,6 +177,7 @@ class BaseChat(ABC):
            # Not new server mode, we convert the message format(List[ModelMessage]) to list of dict
            # fix the error of "Object of type ModelMessage is not JSON serializable" when passing the payload to request.post
            llm_messages = list(map(lambda m: m.dict(), llm_messages))

        payload = {
            "model": self.llm_model,
            "prompt": self.generate_llm_text(),
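The two comments above explain the conversion: ModelMessage objects are not JSON serializable, so they are turned into plain dicts before being placed in the request payload. A self-contained illustration of the same idea (a dataclass stands in for the real pydantic ModelMessage; only the `.dict()`-style conversion is taken from the diff):

import json
from dataclasses import dataclass, asdict

# Stand-in message type for illustration; the real ModelMessage is a pydantic
# model whose .dict() plays the same role that asdict() plays here.
@dataclass
class FakeMessage:
    role: str
    content: str

llm_messages = [FakeMessage(role="human", content="hello")]
# json.dumps(llm_messages) would raise TypeError on the raw objects; converting
# to plain dicts first, as the diff does with m.dict(), makes them serializable.
payload_messages = [asdict(m) for m in llm_messages]
print(json.dumps({"messages": payload_messages}))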
@@ -431,6 +440,39 @@ class BaseChat(ABC):
                return message.content
        return None

    async def prompt_context_token_adapt(self, prompt) -> str:
        """prompt token adapt according to llm max context length"""
        model_metadata = await self.worker_manager.get_model_metadata(
            {"model": self.llm_model}
        )
        current_token_count = await self.worker_manager.count_token(
            {"model": self.llm_model, "prompt": prompt}
        )
        if current_token_count == -1:
            logger.warning(
                "tiktoken not installed, please `pip install tiktoken` first"
            )
        template_define_token_count = 0
        if len(self.prompt_template.template_define) > 0:
            template_define_token_count = await self.worker_manager.count_token(
                {
                    "model": self.llm_model,
                    "prompt": self.prompt_template.template_define,
                }
            )
            current_token_count += template_define_token_count
        if (
            current_token_count + self.prompt_template.max_new_tokens
        ) > model_metadata.context_length:
            prompt = prompt[
                : (
                    model_metadata.context_length
                    - self.prompt_template.max_new_tokens
                    - template_define_token_count
                )
            ]
        return prompt

    def generate(self, p) -> str:
        """
        generate context for LLM input
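The new `prompt_context_token_adapt` method enforces a simple budget: the prompt tokens plus the template-definition tokens plus `max_new_tokens` must fit inside the model's `context_length`, otherwise the prompt is cut down. A standalone sketch of that check with hypothetical numbers (note that, as in the diff, the budget is counted in tokens but the slice is applied to characters, so the result is an approximation):

# Standalone sketch of the budget check done by prompt_context_token_adapt.
# All values below are illustrative and are not DB-GPT defaults.
def fit_prompt(prompt: str, prompt_tokens: int, template_tokens: int,
               max_new_tokens: int, context_length: int) -> str:
    total_input_tokens = prompt_tokens + template_tokens
    if total_input_tokens + max_new_tokens > context_length:
        # Reserve room for the template and the completion, then slice the
        # prompt. Like the diff, this slices characters using a token budget,
        # which is only a rough approximation of the true token count.
        keep = max(context_length - max_new_tokens - template_tokens, 0)
        prompt = prompt[:keep]
    return prompt

# Example: a 4096-token context with 512 tokens reserved for generation and a
# 100-token template leaves a budget of 3484 for the prompt.
print(len(fit_prompt("x" * 5000, prompt_tokens=5000, template_tokens=100,
                     max_new_tokens=512, context_length=4096)))  # 3484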