chore: Add pylint for DB-GPT core lib (#1076)
```diff
@@ -1,10 +1,15 @@
+"""The output parser is used to parse the output of an LLM call.
+
+TODO: Make this more general and clear.
+"""
+
 from __future__ import annotations
 
 import json
 import logging
 from abc import ABC
 from dataclasses import asdict
-from typing import Any, Dict, TypeVar, Union
+from typing import Any, TypeVar, Union
 
 from dbgpt.core import ModelOutput
 from dbgpt.core.awel import MapOperator
```
```diff
@@ -22,11 +27,16 @@ class BaseOutputParser(MapOperator[ModelOutput, Any], ABC):
     """
 
     def __init__(self, is_stream_out: bool = True, **kwargs):
+        """Create a new output parser."""
         super().__init__(**kwargs)
         self.is_stream_out = is_stream_out
         self.data_schema = None
 
     def update(self, data_schema):
+        """Update the data schema.
+
+        TODO: Remove this method.
+        """
         self.data_schema = data_schema
 
     def __post_process_code(self, code):
```
```diff
@@ -40,9 +50,16 @@ class BaseOutputParser(MapOperator[ModelOutput, Any], ABC):
         return code
 
     def parse_model_stream_resp_ex(self, chunk: ResponseTye, skip_echo_len):
-        data = _parse_model_response(chunk)
-        """ TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
-        """
+        """Parse the output of an LLM call.
+
+        Args:
+            chunk (ResponseTye): The output of an LLM call.
+            skip_echo_len (int): The length of the prompt to skip.
+        """
+        data = _parse_model_response(chunk)
+        # TODO: Multi mode output handler, rewrite this for multi model, use adapter
+        # mode.
+
         model_context = data.get("model_context")
         has_echo = False
         if model_context and "prompt_echo_len_char" in model_context:
```
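For context: `skip_echo_len` exists because some model backends echo the prompt back at the head of the streamed text, and the parser slices that prefix off before returning output. A minimal sketch of the idea with hypothetical names (the real method additionally reads `prompt_echo_len_char` from the parsed model context, which is not shown in this hunk):

```python
def skip_prompt_echo(streamed_text: str, skip_echo_len: int) -> str:
    """Drop the echoed prompt prefix from a streamed completion."""
    return streamed_text[skip_echo_len:].strip()

prompt = "Q: how many users?\nA: "
streamed = prompt + "SELECT COUNT(*) FROM users;"
print(skip_prompt_echo(streamed, len(prompt)))  # SELECT COUNT(*) FROM users;
```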
```diff
@@ -65,6 +82,7 @@ class BaseOutputParser(MapOperator[ModelOutput, Any], ABC):
         return output
 
     def parse_model_nostream_resp(self, response: ResponseTye, sep: str):
+        """Parse the output of an LLM call."""
         resp_obj_ex = _parse_model_response(response)
         if isinstance(resp_obj_ex, str):
             resp_obj_ex = json.loads(resp_obj_ex)
```
```diff
@@ -89,7 +107,8 @@ class BaseOutputParser(MapOperator[ModelOutput, Any], ABC):
             return ai_response
         else:
             raise ValueError(
-                f"""Model server error!code={resp_obj_ex["error_code"]}, errmsg is {resp_obj_ex["text"]}"""
+                f"Model server error!code={resp_obj_ex['error_code']}, error msg is "
+                f"{resp_obj_ex['text']}"
             )
 
     def _illegal_json_ends(self, s):
```
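The rewritten message relies on implicit concatenation of adjacent f-string literals, which keeps each source line within the lint line-length limit while still producing a single string:

```python
error_code, text = 1, "model worker timeout"
message = (
    f"Model server error!code={error_code}, error msg is "
    f"{text}"
)
print(message)  # Model server error!code=1, error msg is model worker timeout
```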
```diff
@@ -117,7 +136,7 @@ class BaseOutputParser(MapOperator[ModelOutput, Any], ABC):
 
             temp_json = self._illegal_json_ends(temp_json)
             return temp_json
-        except Exception as e:
+        except Exception:
             raise ValueError("Failed to find a valid json in LLM response!" + temp_json)
 
     def _json_interception(self, s, is_json_array: bool = False):
```
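`_illegal_json_ends` itself is not shown in this hunk; judging by its name and call site, it repairs malformed JSON endings before `json.loads` runs. A hedged sketch of that kind of cleanup, not the actual implementation:

```python
import json

def strip_trailing_commas(s: str) -> str:
    # Replace ",}" / ", ]" style endings, a common defect in LLM-emitted JSON.
    for bad, good in ((", }", "}"), (",}", "}"), (", ]", "]"), (",]", "]")):
        s = s.replace(bad, good)
    return s

broken = '{"sql": "SELECT 1", "thoughts": "done",}'
print(json.loads(strip_trailing_commas(broken)))
# {'sql': 'SELECT 1', 'thoughts': 'done'}
```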
```diff
@@ -150,17 +169,17 @@ class BaseOutputParser(MapOperator[ModelOutput, Any], ABC):
                         break
             assert count == 0
             return s[i : j + 1]
-        except Exception as e:
+        except Exception:
             return ""
 
-    def parse_prompt_response(self, model_out_text) -> T:
-        """
-        parse model out text to prompt define response
+    def parse_prompt_response(self, model_out_text) -> Any:
+        """Parse model out text to prompt define response.
+
         Args:
-            model_out_text:
+            model_out_text: The output of an LLM call.
 
         Returns:
-
+            Any: The parsed output of an LLM call.
         """
         cleaned_output = model_out_text.rstrip()
         if "```json" in cleaned_output:
```
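The brace counting visible at the top of this hunk (`count`, `s[i : j + 1]`) suggests `_json_interception` extracts the first balanced JSON object or array from surrounding prose. A simplified standalone version of the technique, which deliberately ignores the edge case of braces inside string literals:

```python
def extract_json_block(text: str) -> str:
    """Return the first balanced {...} object found in text, or ""."""
    start = text.find("{")
    if start < 0:
        return ""
    depth = 0
    for i, ch in enumerate(text[start:], start):
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:  # nesting closed: the object ends here
                return text[start : i + 1]
    return ""

out = 'Sure! Here is the plan: {"thoughts": "ok", "sql": "SELECT 1"} Done.'
print(extract_json_block(out))  # {"thoughts": "ok", "sql": "SELECT 1"}
```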
```diff
@@ -194,12 +213,15 @@ class BaseOutputParser(MapOperator[ModelOutput, Any], ABC):
     def parse_view_response(
         self, ai_text, data, parse_prompt_response: Any = None
     ) -> str:
-        """
-        parse the ai response info to user view
+        """Parse the AI response info to user view.
+
         Args:
-            text:
+            ai_text (str): The output of an LLM call.
+            data (dict): The data has been handled by some scene.
+            parse_prompt_response (Any): The prompt response has been parsed.
 
         Returns:
             str: The parsed output of an LLM call.
+
         """
         return ai_text
```
```diff
@@ -240,10 +262,14 @@ def _parse_model_response(response: ResponseTye):
 
 
 class SQLOutputParser(BaseOutputParser):
+    """Parse the SQL output of an LLM call."""
+
     def __init__(self, is_stream_out: bool = False, **kwargs):
+        """Create a new SQL output parser."""
         super().__init__(is_stream_out=is_stream_out, **kwargs)
 
     def parse_model_nostream_resp(self, response: ResponseTye, sep: str):
+        """Parse the output of an LLM call."""
         model_out_text = super().parse_model_nostream_resp(response, sep)
         clean_str = super().parse_prompt_response(model_out_text)
         return json.loads(clean_str, strict=True)
```
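`SQLOutputParser.parse_model_nostream_resp` chains three steps: recover the assistant text via the base class, strip the markdown fence via `parse_prompt_response`, then parse strictly with `json.loads`. A hypothetical walkthrough of the same three steps on a fabricated response text (the fence-stripping one-liner below stands in for the parser's own cleanup logic):

```python
import json

raw_text = 'Answer:\n```json\n{"sql": "SELECT COUNT(*) FROM users"}\n```'

# Step 2 equivalent: cut the JSON body out of the ```json fence.
fenced = raw_text.split("```json", 1)[1].split("```", 1)[0]
# Step 3: strict parse into a dict, matching the parser's return type.
result = json.loads(fenced.strip(), strict=True)
print(result["sql"])  # SELECT COUNT(*) FROM users
```

Note that `is_stream_out` defaults to `False` here, unlike the base class: SQL generation scenes consume the whole completion at once rather than token by token.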