add debug mode

This commit is contained in:
Harrison Chase
2023-03-21 07:54:11 -07:00
parent 86085bc1e4
commit 81b87a6c20

View File

@@ -38,6 +38,7 @@ class LLMChain(Chain, BaseModel):
llm: BaseLanguageModel
output_key: str = "text" #: :meta private:
output_parser: Optional[BaseOutputParser] = None
debug: bool = True
@root_validator()
def validate_output_parser(cls, values: Dict) -> Dict:
@@ -77,7 +78,10 @@ class LLMChain(Chain, BaseModel):
:meta private:
"""
return [self.output_key]
if not self.debug:
return [self.output_key]
else:
return [self.output_key, "raw", "error"]
def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
return self.apply([inputs])[0]
@@ -155,12 +159,28 @@ class LLMChain(Chain, BaseModel):
def _get_final_output(
self, generations: List[Generation], prompt_value: PromptValue
) -> Any:
) -> Dict:
"""Get the final output from a list of generations for a prompt."""
completion = generations[0].text
if self.output_parser is not None:
completion = self.output_parser.parse_with_prompt(completion, prompt_value)
return completion
try:
new_completion = self.output_parser.parse_with_prompt(
completion, prompt_value
)
result = {self.output_key: new_completion}
if self.debug:
result["raw"] = completion
result["errors"] = []
except Exception as e:
if self.debug:
result = {
self.output_key: None,
"raw": completion,
"error": [repr(e)],
}
else:
result = {self.output_key: completion}
return result
def create_outputs(
self, response: LLMResult, prompts: List[PromptValue]
@@ -168,7 +188,7 @@ class LLMChain(Chain, BaseModel):
"""Create outputs from response."""
return [
# Get the text of the top generated string.
{self.output_key: self._get_final_output(generation, prompts[i])}
self._get_final_output(generation, prompts[i])
for i, generation in enumerate(response.generations)
]