Compare commits

...

4 Commits

Author      SHA1         Message                                             Date
Dev 2049    ea166b091e   improve                                             2023-05-18 12:39:51 -07:00
Dev 2049    f8389cf4fd   Merge branch 'master' into dev2049/retrieval_eval   2023-05-18 12:16:46 -07:00
Dev 2049    2798832e65   undo                                                2023-05-17 15:10:13 -07:00
Dev 2049    7fac949d12   wip                                                 2023-05-17 15:08:47 -07:00
7 changed files with 161 additions and 7 deletions

View File

@@ -94,6 +94,25 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
         # Call predict on the LLM.
         return await self.llm_chain.apredict(callbacks=callbacks, **inputs), {}
 
+    def combine_docs_and_parse(
+        self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
+    ) -> Tuple[Any, dict]:
+        """Stuff all documents into one prompt, pass to LLM, and parse the output."""
+        inputs = self._get_inputs(docs, **kwargs)
+        # Call predict_and_parse on the LLM.
+        return self.llm_chain.predict_and_parse(callbacks=callbacks, **inputs), {}
+
+    async def acombine_docs_and_parse(
+        self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
+    ) -> Tuple[Any, dict]:
+        """Stuff all documents into one prompt, pass to LLM, and parse the output."""
+        inputs = self._get_inputs(docs, **kwargs)
+        # Call apredict_and_parse on the LLM.
+        return (
+            await self.llm_chain.apredict_and_parse(callbacks=callbacks, **inputs),
+            {},
+        )
+
     @property
     def _chain_type(self) -> str:
         return "stuff_documents_chain"

View File

@@ -0,0 +1,50 @@
from typing import List, Optional

from langchain import LLMChain, OpenAI, PromptTemplate
from langchain.base_language import BaseLanguageModel
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.evaluation.retrieval.prompts import (
    GRADE_DOCS_PROMPT,
    GRADE_SINGLE_DOC_PROMPT,
)
from langchain.schema import Document


def grade_documents(
    documents: List[Document],
    question: str,
    llm: Optional[BaseLanguageModel] = None,
) -> List[int]:
    """Score each document's relevance to the question on a 1-5 scale."""
    _llm = llm or OpenAI(temperature=0)
    if len(documents) == 1:
        return [grade_single_document(documents[0], question, llm=_llm)]
    llm_chain = LLMChain(llm=_llm, prompt=GRADE_DOCS_PROMPT)
    # Tag each document with its index so scores can be matched back.
    _documents = [
        Document(page_content=d.page_content, metadata={"i": i})
        for i, d in enumerate(documents)
    ]
    document_prompt = PromptTemplate(
        template="DOCUMENT {i}:\n{page_content}", input_variables=["i", "page_content"]
    )
    eval_chain = StuffDocumentsChain(
        llm_chain=llm_chain,
        document_prompt=document_prompt,
        document_variable_name="documents",
    )
    score, _ = eval_chain.combine_docs_and_parse(_documents, question=question)
    return score


def grade_single_document(
    document: Document,
    question: str,
    llm: Optional[BaseLanguageModel] = None,
) -> int:
    """Score a single document's relevance to the question on a 1-5 scale."""
    _llm = llm or OpenAI(temperature=0)
    llm_chain = LLMChain(llm=_llm, prompt=GRADE_SINGLE_DOC_PROMPT)
    eval_chain = StuffDocumentsChain(
        llm_chain=llm_chain,
        document_variable_name="document",
    )
    score, _ = eval_chain.combine_docs_and_parse([document], question=question)
    return score[0]
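
A hypothetical call to the new helpers might look like the following; the module path langchain.evaluation.retrieval.grade is assumed (the diff does not show this file's name), and the scores in the comment are illustrative:

from langchain.schema import Document

from langchain.evaluation.retrieval.grade import grade_documents  # path assumed

docs = [
    Document(page_content="The Eiffel Tower is in Paris."),
    Document(page_content="Bananas are rich in potassium."),
]
scores = grade_documents(docs, question="Where is the Eiffel Tower?")
# One 1-5 relevance score per document, e.g. [5, 1], parsed out of a
# comma-separated completion by GRADE_DOCS_PROMPT's output parser.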

View File

@@ -0,0 +1,63 @@
from typing import List

from langchain import PromptTemplate
from langchain.output_parsers.regex import ListRegexParser


class GradeOutputParser(ListRegexParser[int]):
    regex = r"\D*(\d+)"
    _cast = int


TEMPLATE = """\
>> INSTRUCTIONS:
Given a question and a list of documents, score how relevant each document is to the question. \
Return integer scores between 1-5, where 1 means the document is completely irrelevant to the question \
and 5 means the document answers the question exactly.

>> FORMATTING INSTRUCTIONS:
Return a comma-separated list of scores, with one score for each document. Do not label \
the scores or add any other text. Do not return a score outside the allowed range.

>> QUESTION:
{question}

>> CANDIDATE DOCUMENTS:
{documents}

>> RELEVANCE SCORES:
"""

GRADE_DOCS_PROMPT = PromptTemplate(
    input_variables=["question", "documents"],
    template=TEMPLATE,
    output_parser=GradeOutputParser(),
)

SINGLE_DOC_TEMPLATE = """\
>> INSTRUCTIONS:
Given a question and a document, score how relevant the document is to the question. \
Return an integer score between 1-5, where 1 means all of the document is completely irrelevant to the question \
and 5 means that some part of the document answers the question exactly.
*Remember*: a document is considered relevant if *ANY* part of the document is relevant.

>> FORMATTING INSTRUCTIONS:
Return a single integer score. Do not label \
the score or add any other text. Do not return a score outside the allowed range.

>> QUESTION:
{question}

>> CANDIDATE DOCUMENT:
{document}

>> RELEVANCE SCORE:
"""

GRADE_SINGLE_DOC_PROMPT = PromptTemplate(
    input_variables=["question", "document"],
    template=SINGLE_DOC_TEMPLATE,
    output_parser=GradeOutputParser(),
)
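
Since GradeOutputParser is only a regex over the completion text, its behavior can be sanity-checked without calling a model; the inputs below are illustrative:

from langchain.evaluation.retrieval.prompts import GradeOutputParser

parser = GradeOutputParser()
print(parser.parse("3, 1, 5"))          # [3, 1, 5]
print(parser.parse("Scores: 4 and 2"))  # [4, 2] -- extra text is tolerated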

View File

@@ -121,7 +121,7 @@ class GenerativeAgentMemory(BaseMemory):
         score = self.chain(prompt).run(memory_content=memory_content).strip()
         if self.verbose:
             logger.info(f"Importance score: {score}")
-        match = re.search(r"^\D*(\d+)", score)
+        match = re.search(r"\D*(\d+)", score)
         if match:
             return (float(match.group(1)) / 10) * self.importance_weight
         else:
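
For reference, a quick standalone check of the relaxed pattern on a chatty, illustrative reply:

import re

reply = "Sure! I'd rate the importance of this memory an 8 out of 10."
match = re.search(r"\D*(\d+)", reply)
assert match is not None and match.group(1) == "8"  # grabs the first run of digits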

View File

@@ -1,9 +1,31 @@
 from __future__ import annotations
 
 import re
-from typing import Dict, List, Optional
+from typing import Callable, Dict, Generic, List, Optional, TypeVar
 
-from langchain.schema import BaseOutputParser
+from langchain.schema import BaseOutputParser, OutputParserException
 
+_PARSED_T = TypeVar("_PARSED_T")
+
+
+class ListRegexParser(BaseOutputParser[List[_PARSED_T]], Generic[_PARSED_T]):
+    """Class to parse output using a regex."""
+
+    regex: str
+    _cast: Callable[[str], _PARSED_T]
+
+    def parse(self, text: str) -> List[_PARSED_T]:
+        """Parse the output of an LLM call."""
+        matches = re.findall(self.regex, text)
+        if matches:
+            return [self._cast(m) for m in matches]
+        else:
+            raise OutputParserException(f"Could not parse output: {text}")
+
+    @property
+    def _type(self) -> str:
+        """Return the type key."""
+        return "list_regex_parser"
+
+
 class RegexParser(BaseOutputParser):
@@ -25,7 +47,7 @@ class RegexParser(BaseOutputParser):
             return {key: match.group(i + 1) for i, key in enumerate(self.output_keys)}
         else:
             if self.default_output_key is None:
-                raise ValueError(f"Could not parse output: {text}")
+                raise OutputParserException(f"Could not parse output: {text}")
             else:
                 return {
                     key: text if key == self.default_output_key else ""
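
A short sketch of how the new generic parser is meant to be specialized; FloatListParser is a hypothetical subclass written for illustration, not part of the diff:

from langchain.output_parsers.regex import ListRegexParser

class FloatListParser(ListRegexParser[float]):
    regex = r"(\d+(?:\.\d+)?)"
    _cast = float

parser = FloatListParser()
print(parser.parse("0.2, 0.9, 1.0"))  # [0.2, 0.9, 1.0]
# parser.parse("no numbers here") would raise OutputParserException.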

View File

@@ -302,17 +302,17 @@ class BaseRetriever(ABC):
 Memory = BaseMemory
 
-T = TypeVar("T")
+PARSED_T = TypeVar("PARSED_T")
 
 
-class BaseOutputParser(BaseModel, ABC, Generic[T]):
+class BaseOutputParser(BaseModel, ABC, Generic[PARSED_T]):
     """Class to parse the output of an LLM call.
 
     Output parsers help structure language model responses.
     """
 
     @abstractmethod
-    def parse(self, text: str) -> T:
+    def parse(self, text: str) -> PARSED_T:
         """Parse the output of an LLM call.
 
         A method which takes in a string (assumed output of a language model)