mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-21 14:18:52 +00:00
regex output parser (#435)
This commit is contained in:
parent
c994ce6b7f
commit
9ec01dfc16
@ -1,24 +1,6 @@
|
|||||||
# flake8: noqa
|
# flake8: noqa
|
||||||
import re
|
|
||||||
from typing import Dict
|
|
||||||
|
|
||||||
from langchain.prompts import PromptTemplate
|
from langchain.prompts import PromptTemplate
|
||||||
from langchain.prompts.base import BaseOutputParser
|
from langchain.prompts.base import RegexParser
|
||||||
|
|
||||||
|
|
||||||
class QAGenerationOutputParser(BaseOutputParser):
    """Extract a question/answer pair from raw LLM output."""

    def parse(self, text: str) -> Dict[str, str]:
        """Return the parsed pair as ``{"query": ..., "answer": ...}``.

        The text is searched for a ``QUESTION: ...`` line immediately
        followed by an ``ANSWER: ...`` line.

        Raises:
            ValueError: if the text does not contain that pattern.
        """
        found = re.search(r"QUESTION: (.*?)\nANSWER: (.*)", text)
        # Guard clause: bail out first so the success path stays flat.
        if found is None:
            raise ValueError(f"Could not parse output: {text}")
        # The pattern has exactly two capture groups: question, then answer.
        query, answer = found.groups()
        return {"query": query, "answer": answer}
|
|
||||||
|
|
||||||
|
|
||||||
template = """You are a teacher coming up with questions to ask on a quiz.
|
template = """You are a teacher coming up with questions to ask on a quiz.
|
||||||
Given the following document, please generate a question and answer based on that document.
|
Given the following document, please generate a question and answer based on that document.
|
||||||
@ -35,6 +17,9 @@ These questions should be detailed and be based explicitly on information in the
|
|||||||
<Begin Document>
|
<Begin Document>
|
||||||
{doc}
|
{doc}
|
||||||
<End Document>"""
|
<End Document>"""
|
||||||
PROMPT = PromptTemplate(
|
output_parser = RegexParser(
|
||||||
input_variables=["doc"], template=template, output_parser=QAGenerationOutputParser()
|
regex=r"QUESTION: (.*?)\nANSWER: (.*)", output_keys=["question", "answer"]
|
||||||
|
)
|
||||||
|
PROMPT = PromptTemplate(
|
||||||
|
input_variables=["doc"], template=template, output_parser=output_parser
|
||||||
)
|
)
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
"""BasePrompt schema definition."""
|
"""BasePrompt schema definition."""
|
||||||
import json
|
import json
|
||||||
|
import re
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Callable, Dict, List, Optional, Union
|
from typing import Any, Callable, Dict, List, Optional, Union
|
||||||
@ -55,7 +56,7 @@ class BaseOutputParser(ABC):
|
|||||||
"""Parse the output of an LLM call."""
|
"""Parse the output of an LLM call."""
|
||||||
|
|
||||||
|
|
||||||
class ListOutputParser(ABC):
|
class ListOutputParser(BaseOutputParser):
|
||||||
"""Class to parse the output of an LLM call to a list."""
|
"""Class to parse the output of an LLM call to a list."""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@ -63,6 +64,21 @@ class ListOutputParser(ABC):
|
|||||||
"""Parse the output of an LLM call."""
|
"""Parse the output of an LLM call."""
|
||||||
|
|
||||||
|
|
||||||
|
class RegexParser(BaseOutputParser, BaseModel):
    """Class to parse the output into a dictionary.

    The configured ``regex`` is searched for in the LLM output and its
    capture groups are mapped, in order, onto ``output_keys``.
    """

    # Pattern handed to ``re.search``; it should define one capture group
    # per entry in ``output_keys``.
    regex: str
    # Dictionary keys for the captured values; the first key receives
    # capture group 1, the second receives group 2, and so on.
    output_keys: List[str]

    def parse(self, text: str) -> Dict[str, str]:
        """Parse the output of an LLM call.

        Args:
            text: raw text to search with ``self.regex``.

        Returns:
            A dictionary mapping each name in ``self.output_keys`` to the
            text captured by the corresponding regex group.

        Raises:
            ValueError: if ``self.regex`` does not match anywhere in ``text``.
            IndexError: if the pattern defines fewer capture groups than
                there are output keys (raised by ``match.group``).
        """
        match = re.search(self.regex, text)
        if match is None:
            raise ValueError(f"Could not parse output: {text}")
        # Capture groups are numbered from 1; group 0 is the entire match,
        # so enumeration must start at 1 to pair keys with their captures.
        return {
            key: match.group(index)
            for index, key in enumerate(self.output_keys, start=1)
        }
|
||||||
|
|
||||||
|
|
||||||
class BasePromptTemplate(BaseModel, ABC):
|
class BasePromptTemplate(BaseModel, ABC):
|
||||||
"""Base prompt should expose the format method, returning a prompt."""
|
"""Base prompt should expose the format method, returning a prompt."""
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user