refactor: RAG Refactor (#985)

Co-authored-by: Aralhi <xiaoping0501@gmail.com>
Co-authored-by: csunny <cfqsunny@163.com>
Aries-ckt · 2024-01-03 09:45:26 +08:00 · committed by GitHub
parent 90775aad50
commit 9ad70a2961
206 changed files with 5766 additions and 2419 deletions


@@ -11,6 +11,7 @@ from dbgpt.component import ComponentType
from dbgpt.core.interface.prompt import PromptTemplate
from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
from dbgpt.core.interface.message import OnceConversation
from dbgpt.model.cluster import WorkerManagerFactory
from dbgpt.util import get_or_create_event_loop
from dbgpt.util.executor_utils import ExecutorFactory, blocking_func_to_async
from dbgpt.util.tracer import root_tracer, trace
@@ -58,6 +59,9 @@ class BaseChat(ABC):
chat_param["model_name"] if chat_param["model_name"] else CFG.LLM_MODEL
)
self.llm_echo = False
self.worker_manager = CFG.SYSTEM_APP.get_component(
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
).create()
self.model_cache_enable = chat_param.get("model_cache_enable", False)
### load prompt template
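
Note (not part of the diff): the added lines above resolve the shared worker manager once in BaseChat.__init__ through the component registry, and the rest of this commit reuses it for model metadata and token counting. A minimal, self-contained sketch of that pattern, assuming a `system_app` handle and an example model name that are not in the commit; the component lookup and the two async calls are the ones shown in this diff:

from dbgpt.component import ComponentType
from dbgpt.model.cluster import WorkerManagerFactory


async def inspect_model(system_app, model_name: str = "example-model"):
    # Resolve the registered WorkerManagerFactory and create a worker manager,
    # mirroring the lookup added to BaseChat.__init__ above.
    worker_manager = system_app.get_component(
        ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
    ).create()
    # The same async APIs that prompt_context_token_adapt relies on later in this diff.
    metadata = await worker_manager.get_model_metadata({"model": model_name})
    token_count = await worker_manager.count_token(
        {"model": model_name, "prompt": "hello world"}
    )
    return metadata.context_length, token_count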
@@ -162,6 +166,10 @@ class BaseChat(ABC):
"BaseChat.__call_base.prompt_template.format", metadata=metadata
):
current_prompt = self.prompt_template.format(**input_values)
### prompt context token adapt according to llm max context length
current_prompt = await self.prompt_context_token_adapt(
prompt=current_prompt
)
self.current_message.add_system_message(current_prompt)
llm_messages = self.generate_llm_messages()
@@ -169,6 +177,7 @@ class BaseChat(ABC):
            # Not new server mode, we convert the message format(List[ModelMessage]) to list of dict
            # fix the error of "Object of type ModelMessage is not JSON serializable" when passing the payload to request.post
            llm_messages = list(map(lambda m: m.dict(), llm_messages))
        payload = {
            "model": self.llm_model,
            "prompt": self.generate_llm_text(),
@@ -431,6 +440,39 @@ class BaseChat(ABC):
                return message.content
        return None

    async def prompt_context_token_adapt(self, prompt) -> str:
        """prompt token adapt according to llm max context length"""
        model_metadata = await self.worker_manager.get_model_metadata(
            {"model": self.llm_model}
        )
        current_token_count = await self.worker_manager.count_token(
            {"model": self.llm_model, "prompt": prompt}
        )
        if current_token_count == -1:
            logger.warning(
                "tiktoken not installed, please `pip install tiktoken` first"
            )
        template_define_token_count = 0
        if len(self.prompt_template.template_define) > 0:
            template_define_token_count = await self.worker_manager.count_token(
                {
                    "model": self.llm_model,
                    "prompt": self.prompt_template.template_define,
                }
            )
            current_token_count += template_define_token_count
        if (
            current_token_count + self.prompt_template.max_new_tokens
        ) > model_metadata.context_length:
            prompt = prompt[
                : (
                    model_metadata.context_length
                    - self.prompt_template.max_new_tokens
                    - template_define_token_count
                )
            ]
        return prompt
    def generate(self, p) -> str:
        """
        generate context for LLM input
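
Note (worked example, not part of the diff): to make the budget check in prompt_context_token_adapt concrete, the numbers below are invented; in the method the real values come from get_model_metadata, count_token, and the prompt template:

# All values are hypothetical stand-ins for the ones the method fetches at runtime.
context_length = 4096          # model_metadata.context_length
max_new_tokens = 1024          # prompt_template.max_new_tokens
template_define_tokens = 96    # count_token(...) over prompt_template.template_define
prompt_tokens = 3200           # count_token(...) over the formatted prompt
prompt = "x" * 12800           # stand-in prompt string (roughly 4 characters per token)

total = prompt_tokens + template_define_tokens                        # 3296 input tokens
if total + max_new_tokens > context_length:                           # 4320 > 4096 -> truncate
    keep = context_length - max_new_tokens - template_define_tokens   # 2976
    # The method slices the prompt *string* with this token-derived budget, so the cut is
    # an approximation: a token usually spans several characters, which means more text
    # is dropped than the token budget strictly requires.
    prompt = prompt[:keep]

assert len(prompt) == 2976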