Compare commits

...

4 Commits

Author     SHA1         Message                                             Date
Dev 2049   ea166b091e   improve                                             2023-05-18 12:39:51 -07:00
Dev 2049   f8389cf4fd   Merge branch 'master' into dev2049/retrieval_eval   2023-05-18 12:16:46 -07:00
Dev 2049   2798832e65   undo                                                2023-05-17 15:10:13 -07:00
Dev 2049   7fac949d12   wip                                                 2023-05-17 15:08:47 -07:00
7 changed files with 161 additions and 7 deletions

View File

@@ -94,6 +94,25 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
        # Call predict on the LLM.
        return await self.llm_chain.apredict(callbacks=callbacks, **inputs), {}

    def combine_docs_and_parse(
        self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
    ) -> Tuple[Any, dict]:
        """Stuff all documents into one prompt and pass to LLM."""
        inputs = self._get_inputs(docs, **kwargs)
        # Call predict on the LLM.
        return self.llm_chain.predict_and_parse(callbacks=callbacks, **inputs), {}

    async def acombine_docs_and_parse(
        self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
    ) -> Tuple[str, dict]:
        """Stuff all documents into one prompt and pass to LLM."""
        inputs = self._get_inputs(docs, **kwargs)
        # Call predict on the LLM.
        return (
            await self.llm_chain.apredict_and_parse(callbacks=callbacks, **inputs),
            {},
        )

    @property
    def _chain_type(self) -> str:
        return "stuff_documents_chain"

View File

@@ -0,0 +1,50 @@
from typing import List, Optional

from langchain import LLMChain, OpenAI, PromptTemplate
from langchain.base_language import BaseLanguageModel
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.evaluation.retrieval.prompts import (
    GRADE_DOCS_PROMPT,
    GRADE_SINGLE_DOC_PROMPT,
)
from langchain.schema import Document


def grade_documents(
    documents: List[Document],
    question: str,
    llm: Optional[BaseLanguageModel] = None,
) -> List[int]:
    _llm = llm or OpenAI(temperature=0)
    if len(documents) == 1:
        return [grade_single_document(documents[0], question, llm=_llm)]
    llm_chain = LLMChain(llm=_llm, prompt=GRADE_DOCS_PROMPT)
    _documents = [
        Document(page_content=d.page_content, metadata={"i": i})
        for i, d in enumerate(documents)
    ]
    document_prompt = PromptTemplate(
        template="DOCUMENT {i}:\n{page_content}", input_variables=["i", "page_content"]
    )
    eval_chain = StuffDocumentsChain(
        llm_chain=llm_chain,
        document_prompt=document_prompt,
        document_variable_name="documents",
    )
    score, _ = eval_chain.combine_docs_and_parse(_documents, question=question)
    return score


def grade_single_document(
    document: Document,
    question: str,
    llm: Optional[BaseLanguageModel] = None,
) -> int:
    _llm = llm or OpenAI(temperature=0)
    llm_chain = LLMChain(llm=_llm, prompt=GRADE_SINGLE_DOC_PROMPT)
    eval_chain = StuffDocumentsChain(
        llm_chain=llm_chain,
        document_variable_name="document",
    )
    score, _ = eval_chain.combine_docs_and_parse([document], question=question)
    return score[0]
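
A usage sketch for the new grading helpers; the import path of this new module is not shown in this view, so the example assumes grade_documents and grade_single_document are already imported, and the question, documents, and expected scores are invented. An OpenAI API key is required.

from typing import List

from langchain.schema import Document

docs: List[Document] = [
    Document(page_content="Paris has been the capital of France since 508 AD."),
    Document(page_content="Bananas are a good source of potassium."),
]

# One 1-5 relevance score per document; something like [5, 1] would be
# expected here, though actual values depend on the model.
scores = grade_documents(docs, question="What is the capital of France?")
print(scores)

# Single-document grading goes through GRADE_SINGLE_DOC_PROMPT instead.
single = grade_single_document(docs[0], question="What is the capital of France?")
print(single)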

View File

@@ -0,0 +1,63 @@
from typing import List

from langchain import PromptTemplate
from langchain.output_parsers.regex import ListRegexParser


class GradeOutputParser(ListRegexParser[int]):
    regex = r"\D*(\d+)"
    _cast = int


TEMPLATE = """\
>> INSTRUCTIONS:
Given a question and a list of documents, score how relevant each document is to the question. \
Return integer scores between 1-5, where 1 means the document is completely irrelevant to the question \
and 5 means the document answers the question exactly.

>> FORMATTING INSTRUCTIONS:
Return a comma separated list of scores, with one score for each document. Do not label
the scores or add any other text. Do not return a score outside the allowed range.

>> QUESTION:
{question}

>> CANDIDATE DOCUMENTS:
{documents}

>> RELEVANCE SCORES:
"""

GRADE_DOCS_PROMPT = PromptTemplate(
    input_variables=["question", "documents"],
    template=TEMPLATE,
    output_parser=GradeOutputParser(),
)

SINGLE_DOC_TEMPLATE = """\
>> INSTRUCTIONS:
Given a question and a document, score how relevant the document is to the question. \
Return an integer score between 1-5, where 1 means all of the document is completely irrelevant to the question \
and 5 means that some part of the document answers the question exactly.
*Remember*, a document is considered to be relevant if *ANY* part of the document is relevant. \

>> FORMATTING INSTRUCTIONS:
Return a single integer score. Do not label
the score or add any other text. Do not return a score outside the allowed range.

>> QUESTION:
{question}

>> CANDIDATE DOCUMENT:
{document}

>> RELEVANCE SCORES:
"""

GRADE_SINGLE_DOC_PROMPT = PromptTemplate(
    input_variables=["question", "document"],
    template=SINGLE_DOC_TEMPLATE,
    output_parser=GradeOutputParser(),
)
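
The regex the grading parser relies on can be checked in isolation with only the standard library; the example replies below are invented.

import re

# GradeOutputParser applies re.findall with this pattern and casts each match to int.
pattern = r"\D*(\d+)"

print(re.findall(pattern, "3, 5, 1"))           # ['3', '5', '1'] -> parsed as [3, 5, 1]
print(re.findall(pattern, "Scores: 4 and 2"))   # ['4', '2']      -> parsed as [4, 2]
print(re.findall(pattern, "no digits at all"))  # []  -> ListRegexParser raises OutputParserException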

View File

@@ -121,7 +121,7 @@ class GenerativeAgentMemory(BaseMemory):
        score = self.chain(prompt).run(memory_content=memory_content).strip()
        if self.verbose:
            logger.info(f"Importance score: {score}")
-       match = re.search(r"^\D*(\d+)", score)
+       match = re.search(r"\D*(\d+)", score)
        if match:
            return (float(match.group(1)) / 10) * self.importance_weight
        else:

View File

@@ -1,9 +1,31 @@
from __future__ import annotations

import re
-from typing import Dict, List, Optional
+from typing import Callable, Dict, Generic, List, Optional, TypeVar

-from langchain.schema import BaseOutputParser
+from langchain.schema import BaseOutputParser, OutputParserException

_PARSED_T = TypeVar("_PARSED_T")


class ListRegexParser(BaseOutputParser[List[_PARSED_T]], Generic[_PARSED_T]):
    """Class to parse output using a regex."""

    regex: str
    _cast: Callable[[str], _PARSED_T]

    def parse(self, text: str) -> List[_PARSED_T]:
        """Parse the output of an LLM call."""
        matches = re.findall(self.regex, text)
        if matches:
            return [self._cast(m) for m in matches]
        else:
            raise OutputParserException(f"Could not parse output: {text}")

    @property
    def _type(self) -> str:
        """Return the type key."""
        return "list_regex_parser"


class RegexParser(BaseOutputParser):
@@ -25,7 +47,7 @@ class RegexParser(BaseOutputParser):
            return {key: match.group(i + 1) for i, key in enumerate(self.output_keys)}
        else:
            if self.default_output_key is None:
-                raise ValueError(f"Could not parse output: {text}")
+                raise OutputParserException(f"Could not parse output: {text}")
            else:
                return {
                    key: text if key == self.default_output_key else ""
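
A sketch of reusing the new ListRegexParser outside of grading; the FloatListParser class and its regex are made up, and it only runs on this branch since ListRegexParser is introduced here.

from langchain.output_parsers.regex import ListRegexParser
from langchain.schema import OutputParserException


class FloatListParser(ListRegexParser[float]):
    regex = r"(\d+(?:\.\d+)?)"
    _cast = float


parser = FloatListParser()
print(parser.parse("0.5, 2, 3.75"))  # -> [0.5, 2.0, 3.75]

try:
    parser.parse("no numbers here")
except OutputParserException as err:
    print(err)  # Could not parse output: no numbers here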

View File

@@ -302,17 +302,17 @@ class BaseRetriever(ABC):
Memory = BaseMemory

-T = TypeVar("T")
+PARSED_T = TypeVar("PARSED_T")


-class BaseOutputParser(BaseModel, ABC, Generic[T]):
+class BaseOutputParser(BaseModel, ABC, Generic[PARSED_T]):
    """Class to parse the output of an LLM call.

    Output parsers help structure language model responses.
    """

    @abstractmethod
-    def parse(self, text: str) -> T:
+    def parse(self, text: str) -> PARSED_T:
        """Parse the output of an LLM call.

        A method which takes in a string (assumed output of a language model )
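
The TypeVar rename is purely a type-level change; subclasses still parametrize BaseOutputParser with whatever their parse method returns. A small sketch with a made-up parser class:

from typing import List

from langchain.schema import BaseOutputParser


class CommaSeparatedParser(BaseOutputParser[List[str]]):
    """Toy parser: split a comma separated completion into a list of strings."""

    def parse(self, text: str) -> List[str]:
        return [part.strip() for part in text.split(",")]


print(CommaSeparatedParser().parse("alpha, beta, gamma"))  # ['alpha', 'beta', 'gamma']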