Eugene Yurtsev
2023-05-30 14:15:43 -04:00
parent ebfd539e43
commit 95b3774479
2 changed files with 44 additions and 67 deletions

langchain/chains/classification/multiselection.py

@@ -4,15 +4,26 @@ from __future__ import annotations
import csv
from bs4 import BeautifulSoup
from io import StringIO
-from typing import Sequence, Mapping, Any, Optional, Dict, List, cast, Set
+from itertools import islice
+from typing import (
+    Sequence,
+    Mapping,
+    Any,
+    Optional,
+    Dict,
+    List,
+    cast,
+    Set,
+    TypeVar,
+    Iterable,
+    Iterator,
+)
from langchain import LLMChain, PromptTemplate
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.schema import BaseOutputParser
from typing import TypedDict
MULTI_SELECT_TEMPLATE = """\
Here is a table in CSV format:
@@ -114,25 +125,11 @@ def _write_records_to_string(
    return buffer.getvalue()

-class MultiSelectionInput(TypedDict):
-    """Input for the multi-selection chain."""
-    question: str
-    records: Sequence[Mapping[str, Any]]
-    delimiter: str
-    columns: Optional[Sequence[str]]
+T = TypeVar("T")

-class MultiSelectionOutput(TypedDict):
-    """Output for the multi-selection chain."""
-    records: Sequence[Mapping[str, Any]]

-from itertools import islice
-def batch(iterable, size):
+def batch(iterable: Iterable[T], size) -> Iterator[List[T]]:
    """Batch an iterable into chunks of size `size`."""
    iterator = iter(iterable)
    while True:
        batch = list(islice(iterator, size))
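The tail of the loop falls outside the hunk context above. A complete, self-contained version of this helper would plausibly read as follows (the exhaustion check is an assumption, not shown in the diff):

```python
from itertools import islice
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def batch(iterable: Iterable[T], size: int) -> Iterator[List[T]]:
    """Batch an iterable into chunks of size `size`."""
    iterator = iter(iterable)
    while True:
        chunk = list(islice(iterator, size))
        if not chunk:  # assumed: stop once the iterator is exhausted
            return
        yield chunk
```

For example, `list(batch(range(5), 2))` yields `[[0, 1], [2, 3], [4]]`.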
@@ -160,7 +157,7 @@ class MultiSelectChain(Chain):
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
-    ) -> MultiSelectionOutput:
+    ) -> Dict[str, Any]:
        """Run the chain."""
        choices = inputs["choices"]
        question = inputs["question"]
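For orientation, `_call` reads the `choices` and `question` keys from `inputs`; a hypothetical invocation could look like the sketch below (the chain variable and the `records` output key are assumptions inferred from the removed `MultiSelectionOutput`, not shown in this diff):

```python
# Hypothetical usage; `multi_select_chain` and the "records" output key
# are assumptions, not part of the commit.
inputs = {
    "question": "Which rows mention pricing?",
    "choices": [
        {"id": 0, "title": "Home"},
        {"id": 1, "title": "Pricing"},
    ],
}
# result = multi_select_chain(inputs)  # -> Dict[str, Any], e.g. {"records": [...]}
```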

View File

@@ -13,8 +13,6 @@ from langchain.callbacks.manager import (
)
from langchain.chains.base import Chain
from langchain.chains.classification.multiselection import (
-    IDParser,
-    _extract_content_from_tag,
    MultiSelectChain,
)
from langchain.document_loaders.base import BaseBlobParser
@@ -23,43 +21,14 @@ from langchain.document_loaders.parsers.html.markdownify import MarkdownifyHTMLParser
from langchain.schema import BaseOutputParser, Document
from langchain.text_splitter import TextSplitter
Parser = MarkdownifyHTMLParser(tags_to_remove=("svg", "img", "script", "style", "a"))
-class AHrefExtractor(BaseOutputParser[List[str]]):
-    """An output parser that extracts all a-href links."""
-    def parse(self, text: str) -> List[str]:
-        return _extract_href_tags(text)

-URL_CRAWLING_PROMPT = PromptTemplate.from_template(
-    """\
-Here is a list of URLs extracted from a page titled: `{title}`.
-```csv
-{urls}
-```
----
-Here is a question:
-{question}
----
-Please output the ids of the URLs that may contain content relevant to answering the question. \
-Use only the information in the CSV table of URLs to determine relevance.
-Format your answer inside <ids> tags, separating the ids by commas.
-For example, if URLs 132 and 133 are relevant, you would write: <ids>132,133</ids>
-Begin:""",
-    output_parser=IDParser(),
-)
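The removed prompt delegated output parsing to `IDParser`, whose import is also dropped above. A standalone sketch honoring the same `<ids>...</ids>` contract might look like this (not the actual `IDParser` implementation):

```python
import re
from typing import List


def parse_ids(text: str) -> List[int]:
    """Extract comma-separated integer ids from an <ids>...</ids> span."""
    match = re.search(r"<ids>(.*?)</ids>", text, re.DOTALL)
    if match is None:
        return []
    return [int(part) for part in match.group(1).split(",") if part.strip()]


assert parse_ids("<ids>132,133</ids>") == [132, 133]
```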
+def _extract_content_from_tag(html: str, tag: str) -> List[str]:
+    """Extract content from the given tag."""
+    soup = BeautifulSoup(html, "lxml")
+    queries = []
+    for query in soup.find_all(tag):
+        queries.append(query.text)
+    return queries
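For illustration, the helper simply collects the text of every matching tag (assumes `lxml` is installed):

```python
html = "<html><head><title>Docs</title></head><body><p>hi</p></body></html>"
assert _extract_content_from_tag(html, "title") == ["Docs"]
assert _extract_content_from_tag(html, "p") == ["hi"]
```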
def _get_surrounding_text(tag: PageElement, n: int, *, is_before: bool = True) -> str:
@@ -94,7 +63,7 @@ def _get_surrounding_text(tag: PageElement, n: int, *, is_before: bool = True) -> str:
    return text

-def get_ahref_snippets(html: str, num_chars: int = 0) -> Dict[str, Any]:
+def _get_ahref_snippets(html: str, num_chars: int = 0) -> Dict[str, Any]:
    """Get a list of <a> tags as snippets from the given html.

    Args:
@@ -155,11 +124,12 @@ class QueryExtractor(BaseOutputParser[List[str]]):
# TODO(Eugene): add a version that works for chat models as well w/ human system message?
QUERY_GENERATION_PROMPT = PromptTemplate.from_template(
    """\
-Suggest a few different search queries that could be used to identify web-pages that could answer \
-the following question.
+Suggest a few different search queries that could be used to identify web-pages \
+that might answer the following question.
-If the question is about a named entity start by listing very general searches (e.g., just the named entity) \
-and then suggest searches more scoped to the question.
+If the question is about a named entity, start by listing very general \
+searches (e.g., just the named entity) and then suggest searches more \
+narrowly scoped to the question.

Input: ```Where did John Snow from Cambridge, UK work?```
Output: <query>John Snow</query>
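The `QueryExtractor` named in the hunk header presumably parses these `<query>` tags back out; a minimal sketch of that step, reusing `_extract_content_from_tag` from the hunk above, might be (the real parser may differ):

```python
# Sketch only; the actual QueryExtractor.parse may differ.
def extract_queries(text: str) -> List[str]:
    return _extract_content_from_tag(text, "query")


assert extract_queries("Output: <query>John Snow</query>") == ["John Snow"]
```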
@@ -209,7 +179,17 @@ def generate_queries(llm: BaseLanguageModel, question: str) -> List[str]:
def _deduplicate_objects(
    dicts: Sequence[Mapping[str, Any]], key: str
) -> List[Mapping[str, Any]]:
-    """Deduplicate objects by the given key."""
+    """Deduplicate objects by the given key.
+
+    TODO(Eugene): add a way to add weights to the objects.
+
+    Args:
+        dicts: a list of dictionaries to deduplicate.
+        key: the key to deduplicate by.
+
+    Returns:
+        a list of deduplicated dictionaries.
+    """
    unique_values = set()
    deduped: List[Mapping[str, Any]] = []
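The dedup loop itself falls outside the hunk context; a plausible completion that keeps the first occurrence of each key value would be:

```python
from typing import Any, List, Mapping, Sequence


def _deduplicate_objects(
    dicts: Sequence[Mapping[str, Any]], key: str
) -> List[Mapping[str, Any]]:
    unique_values = set()
    deduped: List[Mapping[str, Any]] = []
    for d in dicts:
        value = d[key]  # assumed: every record carries the key
        if value not in unique_values:  # keep only the first occurrence
            unique_values.add(value)
            deduped.append(d)
    return deduped
```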
@@ -250,10 +230,10 @@ class BlobCrawler(abc.ABC):
"""Explore the blob and identify links to related content that is relevant to the query."""
-def _extract_records(blob: Blob) -> Tuple[List[Mapping[str, Any]], Tuple[str, ...]]:
+def _extract_records(blob: Blob) -> Tuple[List[Dict[str, Any]], Tuple[str, ...]]:
    """Extract records from a blob."""
    if blob.mimetype == "text/html":
-        info = get_ahref_snippets(blob.as_string(), num_chars=100)
+        info = _get_ahref_snippets(blob.as_string(), num_chars=100)
        return (
            [
                {