Mirror of https://github.com/csunny/DB-GPT.git
docs: Config documents & i18n supports (#2365)
@@ -1,31 +1,32 @@
"""Translate the po file content to Chinese using LLM."""

from typing import List, Dict, Any
import argparse
import asyncio
import logging
import os
from pathlib import Path
from typing import NamedTuple, List
import argparse
from typing import Any, Dict, List, NamedTuple

from dbgpt.core import (
    SystemPromptTemplate,
    HumanPromptTemplate,
    ChatPromptTemplate,
    HumanPromptTemplate,
    ModelOutput,
    LLMClient,
    SystemPromptTemplate,
)
from dbgpt.core.operators import PromptBuilderOperator, RequestBuilderOperator
from dbgpt.core.awel import (
    DAG,
    MapOperator,
    InputOperator,
    InputSource,
    JoinOperator,
    IteratorTrigger,
    JoinOperator,
    MapOperator,
)
from dbgpt.model import AdaptiveLLMClient
from dbgpt.core.operators import PromptBuilderOperator, RequestBuilderOperator
from dbgpt.model import AutoLLMClient
from dbgpt.model.operators import LLMOperator
from dbgpt.model.proxy.base import TiktokenProxyTokenizer

logger = logging.getLogger(__name__)

# Adapted from https://baoyu.io/blog/prompt-engineering/my-translator-bot
PROMPT_ZH = """
@@ -66,8 +67,10 @@ msgstr ""
- 常见的 AI 相关术语请根据下表进行翻译,保持一致性
- 以下是常见的 AI 相关术语词汇对应表:
{vocabulary}
- 如果已经存在对应的翻译( msgstr 不为空),请你分析原文和翻译,看看是否有更好的翻译方式,如果有请进行修改。
- 如果已经存在对应的翻译( msgstr 不为空),请你分析原文和翻译,看看是否有更好的翻译方式,如果有请进行\
修改,直接给我最终优化的内容,不要单独再给一份优化前的版本!
- 直接给我内容,不要包含在markdown代码块中,具体参考样例。
- 不要给额外的解释!


策略:保持原有格式,不要遗漏任何信息,遵守原意的前提下让内容更通俗易懂、符合{language}表达习惯,但要保留原有格式不变。
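For orientation, each unit this prompt operates on is a gettext block like the hypothetical sample below (illustrative only, not an entry from the repository); the model is expected to echo the block back unchanged except for the filled-in msgstr:

    #: ../dbgpt/core/awel/dag.py:42
    msgid "Build AWEL workflows with operators."
    msgstr ""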
@@ -173,74 +176,86 @@ vocabulary_map = {

class ModuleInfo(NamedTuple):
    """Module information container"""
    base_module: str  # Base module name (e.g., dbgpt)
    sub_module: str  # Sub module name (e.g., core) or file name without .py
    full_path: str  # Full path to the module or file

    base_module: str  # Base module name (e.g., dbgpt)
    sub_module: str  # Sub module name (e.g., core) or file name without .py
    full_path: str  # Full path to the module or file


def find_modules(root_path: str = None) -> List[ModuleInfo]:
    """
    Find all DBGpt modules, including:
    1. First-level submodules (directories with __init__.py)
    2. Python files directly under base module directory

    Args:
        root_path: Root path containing the packages directory. If None, uses current ROOT_PATH

    Returns:
        List of ModuleInfo containing module details
    """
    if root_path is None:
        from dbgpt.configs.model_config import ROOT_PATH

        root_path = ROOT_PATH

    base_path = Path(root_path) / "packages"
    all_modules = []

    # Iterate through all packages
    for pkg_dir in base_path.iterdir():
        if not pkg_dir.is_dir():
            continue

        src_dir = pkg_dir / "src"
        if not src_dir.is_dir():
            continue

        # Find the base module directory
        try:
            base_module_dir = next(src_dir.iterdir())
            if not base_module_dir.is_dir():
                continue

            # Check if it's a Python module
            if not (base_module_dir / "__init__.py").exists():
                continue

            # Scan first-level submodules (directories)
            for item in base_module_dir.iterdir():
                # Handle directories with __init__.py
                if (item.is_dir() and
                        not item.name.startswith('__') and
                        (item / "__init__.py").exists()):
                    all_modules.append(ModuleInfo(
                        base_module=base_module_dir.name,
                        sub_module=item.name,
                        full_path=str(item.absolute())
                    ))
                if (
                    item.is_dir()
                    and not item.name.startswith("__")
                    and (item / "__init__.py").exists()
                ):
                    all_modules.append(
                        ModuleInfo(
                            base_module=base_module_dir.name,
                            sub_module=item.name,
                            full_path=str(item.absolute()),
                        )
                    )
                # Handle Python files (excluding __init__.py and private files)
                elif (item.is_file() and
                        item.suffix == '.py' and
                        not item.name.startswith('__')):
                    all_modules.append(ModuleInfo(
                        base_module=base_module_dir.name,
                        sub_module=item.stem,  # filename without .py
                        full_path=str(item.absolute())
                    ))

                elif (
                    item.is_file()
                    and item.suffix == ".py"
                    and not item.name.startswith("__")
                ):
                    all_modules.append(
                        ModuleInfo(
                            base_module=base_module_dir.name,
                            sub_module=item.stem,  # filename without .py
                            full_path=str(item.absolute()),
                        )
                    )

        except StopIteration:
            continue

    return sorted(all_modules, key=lambda x: (x.base_module, x.sub_module))


class ReadPoFileOperator(MapOperator[str, List[str]]):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
@@ -254,11 +269,16 @@ class ReadPoFileOperator(MapOperator[str, List[str]]):


class ParsePoFileOperator(MapOperator[List[str], List[str]]):
    _HEADER_SHARE_DATA_KEY = "header_lines"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    async def map(self, content_lines: List[str]) -> List[str]:
        block_lines = extract_messages_with_comments(content_lines)
        block_lines, header_lines = extract_messages_with_comments(content_lines)
        await self.current_dag_context.save_to_share_data(
            self._HEADER_SHARE_DATA_KEY, header_lines
        )
        return block_lines


@@ -268,6 +288,7 @@ def extract_messages_with_comments(lines: List[str]):
    has_start = False
    has_msgid = False
    sep = "#: .."
    header_lines = []
    for line in lines:
        if line.startswith(sep):
            has_start = True
@@ -285,33 +306,37 @@ def extract_messages_with_comments(lines: List[str]):
        elif has_start:
            current_msg.append(line)
        else:
            print("Skip line:", line)
            logger.debug(f"Skip line: {line}")
            if not has_start:
                header_lines.append(line)
    if current_msg:
        messages.append("".join(current_msg))

    return messages
    return messages, header_lines


class BatchOperator(JoinOperator[str]):
    def __init__(
        self,
        model_name: str = "deepseek-chat",  # or "gpt-4"
        max_new_token: int = 4096,
        **kwargs,
    ):
        self._tokenizer = TiktokenProxyTokenizer()
        self._model_name = model_name
        self._max_new_token = max_new_token
        super().__init__(combine_function=self.batch_run, **kwargs)

    async def batch_run(self, blocks: List[str], ext_dict: Dict[str, Any]) -> str:
        max_new_token = ext_dict.get("max_new_token", self._max_new_token)
        input_token = ext_dict.get("input_token", 512)
        max_new_token = ext_dict.get("max_new_token", 4096)
        parallel_num = ext_dict.get("parallel_num", 5)
        provider = ext_dict.get("provider", "proxy/deepseek")
        model_name = ext_dict.get("model_name", self._model_name)
        count_token_model = ext_dict.get("count_token_model", "cl100k_base")
        support_system_role = ext_dict.get("support_system_role", True)
        llm_client = AdaptiveLLMClient(provider=provider, name=model_name)
        batch_blocks = await self.split_blocks(llm_client, blocks, model_name, max_new_token)
        llm_client = AutoLLMClient(provider=provider, name=model_name)
        batch_blocks = await self.split_blocks(
            llm_client, blocks, count_token_model, input_token
        )
        new_blocks = []
        for block in batch_blocks:
            new_blocks.append({"user_input": "".join(block), **ext_dict})
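A minimal standalone sketch of the header/block split that extract_messages_with_comments performs after this change (simplified: the has_msgid bookkeeping is omitted, and the names are illustrative, not part of the patch):

    def split_po_content(lines):
        """Everything before the first '#: ..' comment is the .po header;
        the rest is grouped into one string per message block."""
        sep = "#: .."
        header_lines, blocks, current, started = [], [], [], False
        for line in lines:
            if line.startswith(sep):
                started = True
                if current:
                    blocks.append("".join(current))
                    current = []
            if started:
                current.append(line)
            else:
                header_lines.append(line)
        if current:
            blocks.append("".join(current))
        return blocks, header_lines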
@@ -324,14 +349,16 @@ class BatchOperator(JoinOperator[str]):
        new_temp = PROMPT_ZH + "\n\n" + "{user_input}"
        messages = [HumanPromptTemplate.from_template(new_temp)]
        with DAG("split_blocks_dag"):
            trigger = IteratorTrigger(data=InputSource.from_iterable(new_blocks))
            trigger = IteratorTrigger(
                data=InputSource.from_iterable(new_blocks), max_retries=3
            )
            prompt_task = PromptBuilderOperator(
                ChatPromptTemplate(
                    messages=messages,
                )
            )
            model_pre_handle_task = RequestBuilderOperator(
                model=model_name, temperature=0.1, max_new_tokens=4096
                model=model_name, temperature=0.1, max_new_tokens=max_new_token
            )
            llm_task = LLMOperator(llm_client)
            out_parse_task = OutputParser()
@@ -345,19 +372,36 @@ class BatchOperator(JoinOperator[str]):
            )
        results = await trigger.trigger(parallel_num=parallel_num)
        outs = []
        for _, out_data in results:
        for input_data, out_data in results:
            user_input: str = input_data["user_input"]
            if not out_data:
                raise ValueError("Output data is empty.")

            # Count 'msgstr' in user_input
            count_msgstr = user_input.count("msgstr")
            count_out_msgstr = out_data.count("msgstr")
            if count_msgstr != count_out_msgstr:
                logger.error(f"Input: {user_input}\n\n" + "==" * 100)
                logger.error(f"Output: {out_data}")
                raise ValueError(
                    f"Output msgstr count {count_out_msgstr} is not equal to input {count_msgstr}."
                )
            outs.append(out_data)
        return "\n\n".join(outs)

    async def split_blocks(
        self, llm_client: AdaptiveLLMClient, blocks: List[str], model_nam: str, max_new_token: int
        self,
        llm_client: AutoLLMClient,
        blocks: List[str],
        model_name: str,
        input_token: int,
    ) -> List[List[str]]:
        batch_blocks = []
        last_block_end = 0
        while last_block_end < len(blocks):
            start = last_block_end
            split_point = await self.bin_search(
                llm_client, blocks[start:], model_nam, max_new_token
                llm_client, blocks[start:], model_name, input_token
            )
            new_end = start + split_point + 1
            batch_blocks.append(blocks[start:new_end])
@@ -368,25 +412,31 @@ class BatchOperator(JoinOperator[str]):

        # Check all blocks are within the token limit
        for block in batch_blocks:
            block_tokens = await llm_client.count_token(model_nam, "".join(block))
            if block_tokens > max_new_token:
            block_tokens = await llm_client.count_token(model_name, "".join(block))
            if block_tokens > input_token:
                raise ValueError(
                    f"Block size {block_tokens} exceeds the max token limit "
                    f"{max_new_token}, your bin_search function is wrong."
                    f"{input_token}, your bin_search function is wrong."
                )
        return batch_blocks

    async def bin_search(
        self, llm_client: AdaptiveLLMClient, blocks: List[str], model_nam: str, max_new_token: int
        self,
        llm_client: AutoLLMClient,
        blocks: List[str],
        model_name: str,
        input_token: int,
    ) -> int:
        """Binary search to find the split point."""
        l, r = 0, len(blocks) - 1
        while l < r:
            mid = l + r + 1 >> 1
            current_tokens = await llm_client.count_token(
                model_nam, "".join(blocks[: mid + 1])
                model_name, "".join(blocks[: mid + 1])
            )
            if current_tokens <= max_new_token:
            if current_tokens < 0:
                raise ValueError("Count token error.")
            if current_tokens <= input_token:
                l = mid
            else:
                r = mid - 1
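The packing strategy above (greedy batches bounded by input_token, found with a rightmost-fit binary search) can be sketched without the LLM client; here len() stands in for count_token and all names are illustrative:

    def pack_blocks(blocks, budget, count=len):
        """Each batch is the longest prefix of the remaining blocks whose
        joined "token" count stays within budget."""
        batches, start = [], 0
        while start < len(blocks):
            lo, hi = 0, len(blocks) - start - 1
            while lo < hi:
                mid = (lo + hi + 1) >> 1
                if count("".join(blocks[start : start + mid + 1])) <= budget:
                    lo = mid
                else:
                    hi = mid - 1
            batches.append(blocks[start : start + lo + 1])
            start += lo + 1
        return batches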
@@ -398,6 +448,13 @@ class OutputParser(MapOperator[ModelOutput, str]):
        super().__init__(**kwargs)

    async def map(self, model_output: ModelOutput) -> str:
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f"Model output: {model_output}")
        if not model_output.success:
            raise ValueError(
                f"Model output failed: {model_output.error_code}, {model_output.text}, "
                f"finish_reason: {model_output.finish_reason}"
            )
        content = model_output.text
        return content.strip()

@@ -406,21 +463,26 @@ class SaveTranslatedPoFileOperator(JoinOperator[str]):
    def __init__(self, **kwargs):
        super().__init__(combine_function=self.save_file, **kwargs)

    async def save_file(self, translated_content: str, file_path: str) -> str:
    async def save_file(self, translated_content: str, params: str) -> str:
        header_lines = await self.current_dag_context.get_from_share_data(
            ParsePoFileOperator._HEADER_SHARE_DATA_KEY
        )
        return await self.blocking_func_to_async(
            self._save_file, translated_content, file_path
            self._save_file, translated_content, params, header_lines
        )

    def _save_file(self, translated_content: str, params) -> str:
    def _save_file(self, translated_content: str, params, header_lines) -> str:
        file_path = params["file_path"]
        override = params["override"]
        output_file = file_path.replace(".po", "_ai_translated.po")
        with open(output_file, "w") as f:
            f.write(translated_content)
        if override:
            lines = "".join(header_lines)
            save_content = lines + translated_content
            # Override the original file
            with open(file_path, "w") as f:
                f.write(translated_content)
                f.write(save_content)
        return translated_content


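Note how the header lines travel between the two operators through DAG share data rather than through an edge of the graph; reduced to the two calls that appear above (the literal key "header_lines" is the value of _HEADER_SHARE_DATA_KEY), the round trip is:

    # in ParsePoFileOperator.map
    await self.current_dag_context.save_to_share_data("header_lines", header_lines)
    # in SaveTranslatedPoFileOperator.save_file
    header_lines = await self.current_dag_context.get_from_share_data("header_lines")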
@@ -430,7 +492,7 @@ with DAG("translate_po_dag") as dag:
    read_po_file_task = ReadPoFileOperator()
    parse_po_file_task = ParsePoFileOperator()
    # ChatGPT can't work if the max_new_token is too large
    batch_task = BatchOperator(max_new_token=1024)
    batch_task = BatchOperator()
    save_translated_po_file_task = SaveTranslatedPoFileOperator()
    (
        input_task
@@ -442,7 +504,13 @@ with DAG("translate_po_dag") as dag:
    input_task >> MapOperator(lambda x: x["ext_dict"]) >> batch_task

    batch_task >> save_translated_po_file_task
    input_task >> MapOperator(lambda x: {"file_path": x["file_path"], "override": x["override"]}) >> save_translated_po_file_task
    (
        input_task
        >> MapOperator(
            lambda x: {"file_path": x["file_path"], "override": x["override"]}
        )
        >> save_translated_po_file_task
    )


async def run_translate_po_dag(
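The wiring above fans one call payload out to two branches; roughly, the dict passed to task.call() is split like this (a plain-Python restatement with an abbreviated, illustrative ext_dict):

    call_data = {"file_path": "example.po", "ext_dict": {"input_token": 512}, "override": False}
    batch_input = call_data["ext_dict"]  # routed to batch_task
    save_params = {"file_path": call_data["file_path"], "override": call_data["override"]}  # routed to save_translated_po_file_task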
@@ -450,14 +518,16 @@ async def run_translate_po_dag(
    language: str,
    language_desc: str,
    module_name: str,
    input_token: int = 512,
    max_new_token: int = 1024,
    parallel_num=10,
    provider: str = "proxy/deepseek",
    model_name: str = "deepseek-chat",
    override: bool=False,
    override: bool = False,
    support_system_role: bool = True,
):
    from dbgpt.configs.model_config import ROOT_PATH

    if "zhipu" in provider:
        support_system_role = False

@@ -479,6 +549,7 @@ async def run_translate_po_dag(
        "example_1_output": example_1_output_1,
        "example_2_input": example_2_input,
        "example_2_output": example_2_output,
        "input_token": input_token,
        "max_new_token": max_new_token,
        "parallel_num": parallel_num,
        "provider": provider,
@@ -486,15 +557,18 @@ async def run_translate_po_dag(
        "support_system_role": support_system_role,
    }
    try:
        result = await task.call({"file_path": full_path, "ext_dict": ext_dict, "override": override})
        result = await task.call(
            {"file_path": full_path, "ext_dict": ext_dict, "override": override}
        )
        return result
    except Exception as e:
        print(f"Error in {module_name}: {e}")
        raise e


if __name__ == "__main__":

    from dbgpt.configs.model_config import ROOT_PATH
    from dbgpt.util.utils import setup_logging

    all_modules = find_modules(ROOT_PATH)
    str_all_modules = [f"{m.base_module}.{m.sub_module}" for m in all_modules]
@@ -519,21 +593,28 @@ if __name__ == "__main__":
        default="zh_CN",
        help="Language to translate, 'all' for all languages, split by ','.",
    )
    parser.add_argument("--max_new_token", type=int, default=1024)
    parser.add_argument("--input_token", type=int, default=512)
    parser.add_argument("--max_new_token", type=int, default=4096)
    parser.add_argument("--parallel_num", type=int, default=10)
    parser.add_argument("--provider", type=str, default="proxy/deepseek")
    parser.add_argument("--model_name", type=str, default="deepseek-chat")
    parser.add_argument("--override", action="store_true")
    parser.add_argument("--log_level", type=str, default="INFO")

    args = parser.parse_args()
    print(f"args: {args}")
    log_level = args.log_level
    setup_logging("dbgpt", default_logger_level=log_level)

    provider = args.provider
    model_name = args.model_name
    override = args.override
    # modules = ["app", "core", "model", "rag", "serve", "storage", "util"]
    modules = str_all_modules if args.modules == "all" else args.modules.strip().split(",")
    max_new_token = args.max_new_token
    modules = (
        str_all_modules if args.modules == "all" else args.modules.strip().split(",")
    )
    _input_token = args.input_token
    _max_new_token = args.max_new_token
    parallel_num = args.parallel_num
    langs = lang_map.keys() if args.lang == "all" else args.lang.strip().split(",")

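Assuming the script is invoked directly (the file name translate_util.py and the --modules flag, which is defined earlier in the file and read via args.modules, are inferred rather than shown in this hunk), a typical run with the new flags would look like:

    python translate_util.py --modules all --lang zh_CN --input_token 512 --max_new_token 4096 --provider proxy/deepseek --model_name deepseek-chat --override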
@@ -552,10 +633,11 @@ if __name__ == "__main__":
                lang,
                lang_desc,
                module,
                max_new_token,
                _input_token,
                _max_new_token,
                parallel_num,
                provider,
                model_name,
                override=override
                override=override,
            )
        )