mirror of https://github.com/hwchase17/langchain.git, synced 2026-01-24 05:50:18 +00:00

This commit is contained in:

  0    langchain/chains/research/__init__.py       Normal file
  451  langchain/chains/research/api.py            Normal file
@@ -0,0 +1,451 @@
import abc
import asyncio
import json
import urllib.parse
from bs4 import BeautifulSoup, PageElement
from typing import Sequence, List, Mapping, Any, Dict, Tuple, Optional

from langchain import PromptTemplate, LLMChain, serpapi
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import (
    CallbackManagerForChainRun,
    AsyncCallbackManagerForChainRun,
)
from langchain.chains.base import Chain
from langchain.chains.research.multiselection import (
    IDParser,
    _extract_content_from_tag,
    MultiSelectChain,
)
from langchain.document_loaders.base import BaseBlobParser
from langchain.document_loaders.blob_loaders import Blob
from langchain.document_loaders.parsers.html.markdownify import MarkdownifyHTMLParser
from langchain.schema import BaseOutputParser, Document
from langchain.text_splitter import TextSplitter

Parser = MarkdownifyHTMLParser(tags_to_remove=("svg", "img", "script", "style", "a"))


class AHrefExtractor(BaseOutputParser[List[str]]):
    """An output parser that extracts all a-href links."""

    def parse(self, text: str) -> List[str]:
        return _extract_href_tags(text)


URL_CRAWLING_PROMPT = PromptTemplate.from_template(
    """\
Here is a list of URLs extracted from a page titled: `{title}`.

```csv
{urls}
```

---

Here is a question:

{question}

---

Please output the ids of the URLs that may contain content relevant to answer the question. \
Use only the information in the CSV table of URLs to determine relevancy.

Format your answer inside of <ids> tags, separating the ids by a comma.

For example, if the URLs with ids 132 and 133 are relevant, you would write: <ids>132,133</ids>

Begin:""",
    output_parser=IDParser(),
)
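
# Illustrative sketch (not part of the original module): the IDParser attached to the
# prompt above turns a completion such as "<ids>3,7</ids>" into a list of integers.
#
#     URL_CRAWLING_PROMPT.output_parser.parse("Relevant rows: <ids>3,7</ids>")
#     # -> [3, 7]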


def _get_surrounding_text(tag: PageElement, n: int, *, is_before: bool = True) -> str:
    """Get the text surrounding the given tag in the given direction.

    Args:
        tag: the tag to get surrounding text for.
        n: number of characters to get.
        is_before: whether to get text before or after the tag.

    Returns:
        the surrounding text in the given direction.
    """
    text = ""
    current = tag.previous_element if is_before else tag.next_element

    while current and len(text) < n:
        current_text = str(current.text).strip()
        current_text = (
            current_text
            if len(current_text) + len(text) <= n
            else current_text[: n - len(text)]
        )

        if is_before:
            text = current_text + " " + text
        else:
            text = text + " " + current_text

        current = current.previous_element if is_before else current.next_element

    return text


def get_ahref_snippets(html: str, num_chars: int = 0) -> Dict[str, Any]:
    """Get a list of <a> tags as snippets from the given html.

    Args:
        html: the html to get snippets from.
        num_chars: the number of characters to get around the <a> tags.

    Returns:
        a dict with the page title and a list of snippets.
    """
    soup = BeautifulSoup(html, "html.parser")
    title = soup.title.string.strip() if soup.title and soup.title.string else ""
    snippets = []

    for idx, a_tag in enumerate(soup.find_all("a")):
        before_text = _get_surrounding_text(a_tag, num_chars, is_before=True)
        after_text = _get_surrounding_text(a_tag, num_chars, is_before=False)
        snippet = {
            "id": idx,
            "before": before_text.strip().replace("\n", " "),
            "link": (a_tag.get("href") or "").replace("\n", " ").strip(),
            "content": a_tag.text.replace("\n", " ").strip(),
            "after": after_text.strip().replace("\n", " "),
        }
        snippets.append(snippet)

    return {
        "snippets": snippets,
        "title": title,
    }
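
# Illustrative sketch (not part of the original module): anchor snippets from a small
# HTML fragment, with ~20 characters of surrounding context per link.
#
#     page = "<html><head><title>Team</title></head>"
#     page += "<body>Contact <a href='/about'>the about page</a> for details.</body></html>"
#     info = get_ahref_snippets(page, num_chars=20)
#     info["title"]                    # -> "Team"
#     info["snippets"][0]["link"]      # -> "/about"
#     info["snippets"][0]["content"]   # -> "the about page"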


def _extract_href_tags(html: str) -> List[str]:
    """Extract href tags.

    Args:
        html: the html to extract href tags from.

    Returns:
        a list of href values.
    """
    href_tags = []
    soup = BeautifulSoup(html, "html.parser")
    for a_tag in soup.find_all("a"):
        href = a_tag.get("href")
        if href:
            href_tags.append(href)
    return href_tags


class QueryExtractor(BaseOutputParser[List[str]]):
    """An output parser that extracts all queries."""

    def parse(self, text: str) -> List[str]:
        """Extract all content of <query> tags from the text."""
        return _extract_content_from_tag(text, "query")


# TODO(Eugene): add a version that works for chat models as well w/ human system message?
QUERY_GENERATION_PROMPT = PromptTemplate.from_template(
    """\
Suggest a few different search queries that could be used to identify web-pages that could answer \
the following question.

If the question is about a named entity start by listing very general searches (e.g., just the named entity) \
and then suggest searches more scoped to the question.

Input: ```Where did John Snow from Cambridge, UK work?```
Output: <query>John Snow</query>
<query>John Snow Cambridge UK</query>
<query>John Snow Cambridge UK work history</query>
<query>John Snow Cambridge UK cv</query>

Input: ```How many research papers did Jane Doe publish in 2010?```
Output: <query>Jane Doe</query>
<query>Jane Doe research papers</query>
<query>Jane Doe research</query>
<query>Jane Doe publications</query>
<query>Jane Doe publications 2010</query>

Input: ```What is the capital of France?```
Output: <query>France</query>
<query>France capital</query>
<query>France capital city</query>
<query>France capital city name</query>

Input: ```What are the symptoms of COVID-19?```
Output: <query>COVID-19</query>
<query>COVID-19 symptoms</query>
<query>COVID-19 symptoms list</query>
<query>COVID-19 symptoms list WHO</query>

Input: ```What is the revenue stream of CVS?```
Output: <query>CVS</query>
<query>CVS revenue</query>
<query>CVS revenue stream</query>
<query>CVS revenue stream business model</query>

Input: ```{question}```
Output:
""",
    output_parser=QueryExtractor(),
)


def generate_queries(llm: BaseLanguageModel, question: str) -> List[str]:
    """Generate search queries for the question using an LLMChain."""
    chain = LLMChain(prompt=QUERY_GENERATION_PROMPT, llm=llm)
    queries = chain.predict_and_parse(question=question)
    return queries
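
# Illustrative sketch (not part of the original module), assuming an OpenAI key is
# configured; any BaseLanguageModel works here:
#
#     from langchain.llms import OpenAI
#
#     queries = generate_queries(OpenAI(temperature=0), "What is the revenue stream of CVS?")
#     # -> e.g. ["CVS", "CVS revenue", "CVS revenue stream", ...]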


def _deduplicate_objects(
    dicts: Sequence[Mapping[str, Any]], key: str
) -> List[Mapping[str, Any]]:
    """Deduplicate objects by the given key."""
    unique_values = set()
    deduped: List[Mapping[str, Any]] = []

    for d in dicts:
        value = d[key]
        if value not in unique_values:
            unique_values.add(value)
            deduped.append(d)

    return deduped
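
# Illustrative sketch (not part of the original module): keep the first record seen
# for each distinct "link".
#
#     _deduplicate_objects(
#         [{"link": "a", "rank": 1}, {"link": "a", "rank": 2}, {"link": "b", "rank": 3}],
#         key="link",
#     )
#     # -> [{"link": "a", "rank": 1}, {"link": "b", "rank": 3}]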


def run_searches(queries: Sequence[str]) -> List[Mapping[str, Any]]:
    """Run the given queries and return the unique results.

    Args:
        queries: a list of queries to run

    Returns:
        a list of unique search results
    """
    wrapper = serpapi.SerpAPIWrapper()
    results = []
    for query in queries:
        result = wrapper.results(query)
        organic_results = result["organic_results"]
        results.extend(organic_results)

    unique_results = _deduplicate_objects(results, "link")
    return unique_results
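
# Illustrative sketch (not part of the original module): SerpAPIWrapper reads the
# SERPAPI_API_KEY environment variable, so a typical flow would be (hypothetical
# question shown):
#
#     queries = generate_queries(llm, "What are the symptoms of COVID-19?")
#     hits = run_searches(queries)
#     [h["link"] for h in hits]  # unique organic-result links across all queries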


class BlobCrawler(abc.ABC):
    """Crawl a blob and identify links to related content."""

    @abc.abstractmethod
    def crawl(self, blob: Blob, question: str) -> List[str]:
        """Explore the blob and identify links to content relevant to the question."""


def _extract_records(blob: Blob) -> Tuple[List[Mapping[str, Any]], Tuple[str, ...]]:
    """Extract records from a blob."""
    if blob.mimetype == "text/html":
        info = get_ahref_snippets(blob.as_string(), num_chars=100)
        return (
            [
                {
                    "content": d["content"],
                    "link": d["link"],
                    "before": d["before"],
                    "after": d["after"],
                }
                for d in info["snippets"]
            ],
            ("link", "content", "before", "after"),
        )
    elif blob.mimetype == "application/json":  # Represents search results
        data = json.loads(blob.as_string())
        results = data["results"]
        return [
            {
                "title": result["title"],
                "snippet": result["snippet"],
                "link": result["link"],
            }
            for result in results
        ], ("link", "title", "snippet")
    else:
        raise ValueError(
            f"Can only extract records from HTML/JSON blobs. Got {blob.mimetype}"
        )


class ChainCrawler(BlobCrawler):
    """Crawl a blob using an LLM-backed multi-select chain."""

    def __init__(self, chain: MultiSelectChain, parser: BaseBlobParser) -> None:
        """Initialize with the selection chain and a blob parser."""
        self.chain = chain
        self.parser = parser

    def crawl(self, blob: Blob, question: str) -> List[str]:
        """Explore the blob and suggest additional content to explore."""
        records, columns = _extract_records(blob)

        result = self.chain(
            inputs={"question": question, "choices": records, "columns": columns}
        )

        selected_records = result["selected"]

        urls = [
            # TODO(): handle absolute links
            urllib.parse.urljoin(blob.source, record["link"])
            for record in selected_records
            if "mailto:" not in record["link"]
        ]
        return urls

    @classmethod
    def from_default(
        cls,
        llm: BaseLanguageModel,
        blob_parser: BaseBlobParser = MarkdownifyHTMLParser(),
    ) -> "ChainCrawler":
        """Create a crawler from the default LLM."""
        chain = MultiSelectChain.from_default(llm)
        return cls(chain=chain, parser=blob_parser)
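
# Illustrative sketch (not part of the original module): given an HTML blob, ask the
# crawler which outgoing links look relevant to a question. Requires an LLM; names
# below (llm, html_text) are assumptions used only to show the call pattern:
#
#     crawler = ChainCrawler.from_default(llm)
#     blob = Blob(data=html_text, mimetype="text/html", path="https://example.com/page")
#     next_urls = crawler.crawl(blob, question="Who founded the company?")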


class DocumentProcessor(abc.ABC):
    """Interface for transforming a document into derived documents."""

    def transform(self, document: Document) -> List[Document]:
        """Transform the document."""
        raise NotImplementedError()


class ReadEntireDocChain(Chain):
    """Read entire document chain.

    This chain implements a brute force approach to reading an entire document.
    """

    chain: Chain
    text_splitter: TextSplitter
    max_num_docs: int = -1

    @property
    def input_keys(self) -> List[str]:
        """Return the input keys."""
        return ["doc", "question"]

    @property
    def output_keys(self) -> List[str]:
        """Return the output keys."""
        return ["document"]

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Process a long document."""
        source_document = inputs["doc"]

        if not isinstance(source_document, Document):
            raise TypeError(f"Expected a Document, got {type(source_document)}")

        question = inputs["question"]
        sub_docs = self.text_splitter.split_documents([source_document])
        if self.max_num_docs > 0:
            _sub_docs = sub_docs[: self.max_num_docs]
        else:
            _sub_docs = sub_docs

        response = self.chain(
            {"input_documents": _sub_docs, "question": question},
            callbacks=run_manager.get_child() if run_manager else None,
        )
        summary_doc = Document(
            page_content=response["output_text"],
            metadata=source_document.metadata,
        )
        return {"document": summary_doc}

    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Process a long document."""
        doc = inputs["doc"]
        question = inputs["question"]
        sub_docs = self.text_splitter.split_documents([doc])
        if self.max_num_docs > 0:
            _sub_docs = sub_docs[: self.max_num_docs]
        else:
            _sub_docs = sub_docs
        results = await self.chain.acall(
            {"input_documents": _sub_docs, "question": question},
            callbacks=run_manager.get_child() if run_manager else None,
        )
        summary_doc = Document(
            page_content=results["output_text"],
            metadata=doc.metadata,
        )

        return {"document": summary_doc}


class ParallelApply(Chain):
    """Apply a chain in parallel."""

    chain: Chain
    max_concurrency: int

    @property
    def input_keys(self) -> List[str]:
        """Return the input keys."""
        return ["inputs"]

    @property
    def output_keys(self) -> List[str]:
        """Return the output keys."""
        return ["results"]

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Run the chain."""
        # TODO(): parallelize this
        chain_inputs = inputs["inputs"]

        results = [
            self.chain(
                chain_input,
                callbacks=run_manager.get_child() if run_manager else None,
            )
            for chain_input in chain_inputs
        ]
        return {"results": results}

    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Run the chain."""
        chain_inputs = inputs["inputs"]

        results = await asyncio.gather(
            *[
                self.chain.acall(
                    chain_input,
                    callbacks=run_manager.get_child() if run_manager else None,
                )
                for chain_input in chain_inputs
            ]
        )
        return {"results": results}

  138  langchain/chains/research/fetch.py          Normal file
@@ -0,0 +1,138 @@
"""Module contains code for fetching documents from the web using playwright.

This module currently re-uses the code from the `web_base` module to avoid
re-implementing rate limiting behavior.

The module contains downloading interfaces.

Sub-classing with the given interface should allow a user to add url based
user-agents and authentication if needed.

Downloading is batched by default to allow parallelizing efficiently.
"""

import abc
import asyncio
import mimetypes
from bs4 import BeautifulSoup
from typing import Sequence, List, Any

from langchain.document_loaders import WebBaseLoader
from langchain.document_loaders.blob_loaders import Blob


def _is_javascript_required(html_content: str) -> bool:
    """Heuristic to determine whether javascript execution is required.

    Args:
        html_content (str): The HTML content to check.

    Returns:
        bool: True if javascript execution is required, False otherwise.
    """
    # Parse the HTML content using BeautifulSoup
    soup = BeautifulSoup(html_content, "lxml")

    # Count the number of HTML elements in the body
    body = soup.body
    if not body:
        return True
    num_elements = len(body.find_all())
    requires_javascript = num_elements < 1
    return requires_javascript
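
# Illustrative sketch (not part of the original module): a page with a missing or
# empty <body> is assumed to need client-side rendering.
#
#     _is_javascript_required("<html><body></body></html>")               # -> True
#     _is_javascript_required("<html><body><p>hello</p></body></html>")   # -> False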


class DownloadHandler(abc.ABC):
    """Interface for downloading a batch of urls into blobs."""

    def download(self, urls: Sequence[str]) -> List[Blob]:
        """Download the urls synchronously."""
        raise NotImplementedError()

    async def adownload(self, urls: Sequence[str]) -> List[Blob]:
        """Download the urls asynchronously."""
        raise NotImplementedError()


class PlaywrightDownloadHandler(DownloadHandler):
    """Download URLs using playwright.

    This is an implementation of the download handler that uses playwright to download
    urls. This is useful for downloading urls that require javascript to be executed.
    """

    def download(self, urls: Sequence[str]) -> List[Blob]:
        """Download a list of urls synchronously."""
        # Implement using a threadpool or using the playwright sync API.
        raise NotImplementedError()

    async def _download(self, browser: Any, url: str) -> str:
        """Download a single url asynchronously using playwright."""
        page = await browser.new_page()
        await page.goto(url, wait_until="networkidle")
        html_content = await page.content()
        return html_content

    async def adownload(self, urls: Sequence[str]) -> List[Blob]:
        """Download urls asynchronously using playwright.

        Args:
            urls: The urls to download.

        Returns:
            A blob with the rendered content for each url.
        """
        from playwright.async_api import async_playwright

        async with async_playwright() as p:
            browser = await p.chromium.launch()
            tasks = [self._download(browser, url) for url in urls]
            contents = await asyncio.gather(*tasks, return_exceptions=True)
            await browser.close()

        return _repackage_as_blobs(urls, contents)


class RequestsDownloadHandler(DownloadHandler):
    """Download URLs with plain HTTP requests via a WebBaseLoader."""

    def __init__(self, web_downloader: WebBaseLoader):
        self.web_downloader = web_downloader

    def download(self, urls: Sequence[str]) -> List[Blob]:
        """Download the urls synchronously."""
        # Implement with a threadpool.
        raise NotImplementedError()

    async def adownload(self, urls: Sequence[str]) -> List[Blob]:
        """Download a batch of urls asynchronously."""
        contents = await self.web_downloader.fetch_all(list(urls))
        return _repackage_as_blobs(urls, contents)


def _repackage_as_blobs(urls: Sequence[str], contents: Sequence[str]) -> List[Blob]:
    """Repackage the contents as blobs."""
    return [
        Blob(data=content, mimetype=mimetypes.guess_type(url)[0])
        for url, content in zip(urls, contents)
    ]


class AutoDownloadHandler(DownloadHandler):
    """Download with plain requests first, falling back to playwright when needed."""

    def __init__(self, web_downloader: WebBaseLoader) -> None:
        """Initialize the auto download handler."""
        self.requests_downloader = RequestsDownloadHandler(web_downloader)
        self.playwright_downloader = PlaywrightDownloadHandler()

    async def adownload(self, urls: Sequence[str]) -> List[Blob]:
        """Download a batch of urls asynchronously."""
        blobs = await self.requests_downloader.adownload(urls)

        # Check which pages appear to require javascript and must be re-downloaded
        # with a real browser.
        must_redownload = [
            (idx, url)
            for idx, (url, blob) in enumerate(zip(urls, blobs))
            if _is_javascript_required(blob.as_string())
        ]

        if not must_redownload:
            return blobs

        indexes, urls_to_redownload = zip(*must_redownload)
        redownloaded = await self.playwright_downloader.adownload(list(urls_to_redownload))
        for idx, blob in zip(indexes, redownloaded):
            blobs[idx] = blob
        return blobs
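
# Illustrative sketch (not part of the original module): download a batch of pages,
# letting the handler fall back to a headless browser for javascript-heavy pages
# (requires the playwright browsers to be installed for the fallback path):
#
#     handler = AutoDownloadHandler(WebBaseLoader(web_path=[]))
#     blobs = await handler.adownload(["https://example.com", "https://example.org"])
#     [b.mimetype for b in blobs]  # guessed from each url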

  218  langchain/chains/research/multiselection.py  Normal file
@@ -0,0 +1,218 @@
"""Perform classification / selection using language models."""
from __future__ import annotations

import csv
from io import StringIO
from itertools import islice
from typing import Sequence, Mapping, Any, Optional, Dict, List, cast, Set, TypedDict

from bs4 import BeautifulSoup

from langchain import LLMChain, PromptTemplate
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.schema import BaseOutputParser


MULTI_SELECT_TEMPLATE = """\
Here is a table in CSV format:

{records}

---

question:

{question}

---

Output IDs of rows that answer the question or match the question.

For example, if row id 132 and id 133 are relevant, output: <ids>132,133</ids>

---

Begin:"""


def _extract_content_from_tag(html: str, tag: str) -> List[str]:
    """Extract content from the given tag."""
    soup = BeautifulSoup(html, "html.parser")
    queries = []
    for query in soup.find_all(tag):
        queries.append(query.text)
    return queries


class IDParser(BaseOutputParser[List[int]]):
    """An output parser that extracts all IDs from the output."""

    def parse(self, text: str) -> List[int]:
        """Parse the text and return a list of IDs."""
        tags = _extract_content_from_tag(text, "ids")

        if not tags:
            return []

        if len(tags) > 1:
            # Fail if more than one <ids> group is identified.
            return []

        tag = tags[0]
        ids = tag.split(",")

        finalized_ids = []
        for idx in ids:
            idx = idx.strip()
            if idx.isdigit():
                finalized_ids.append(int(idx))
        return finalized_ids
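
# Illustrative sketch (not part of the original module): the parser keeps only the
# numeric ids found inside a single <ids> tag.
#
#     IDParser().parse("I would pick these rows: <ids>0, 2, 5</ids>")   # -> [0, 2, 5]
#     IDParser().parse("<ids>1</ids> ... <ids>2</ids>")                 # -> [] (ambiguous)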


def _write_records_to_string(
    records: Sequence[Mapping[str, Any]],
    *,
    columns: Optional[Sequence[str]] = None,
    delimiter: str = "|",
) -> str:
    """Write records to a CSV string.

    Args:
        records: a list of records; assumes that all records have all keys
        columns: a list of columns to include in the CSV
        delimiter: the delimiter to use

    Returns:
        a CSV string
    """
    buffer = StringIO()
    if columns is None:
        existing_columns: Set[str] = set()
        for record in records:
            existing_columns.update(record.keys())
        _columns: Sequence[str] = sorted(existing_columns)
    else:
        _columns = columns

    # Make sure the `id` column is always first.
    _columns_with_id_first = list(_columns)
    if "id" in _columns_with_id_first:
        _columns_with_id_first.remove("id")
    _columns_with_id_first.insert(0, "id")

    writer = csv.DictWriter(
        buffer,
        fieldnames=_columns_with_id_first,
        delimiter=delimiter,
    )
    writer.writeheader()
    writer.writerows(records)
    buffer.seek(0)
    return buffer.getvalue()
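
# Illustrative sketch (not part of the original module): the "id" column is always
# moved to the front of the pipe-delimited table.
#
#     _write_records_to_string(
#         [{"id": 0, "link": "/a", "title": "A"}, {"id": 1, "link": "/b", "title": "B"}],
#         columns=("link", "title"),
#     )
#     # -> 'id|link|title\r\n0|/a|A\r\n1|/b|B\r\n'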


class MultiSelectionInput(TypedDict):
    """Input for the multi-selection chain."""

    question: str
    records: Sequence[Mapping[str, Any]]
    delimiter: str
    columns: Optional[Sequence[str]]


class MultiSelectionOutput(TypedDict):
    """Output for the multi-selection chain."""

    records: Sequence[Mapping[str, Any]]


def batch(iterable, size):
    """Yield successive batches of the given size from the iterable."""
    iterator = iter(iterable)
    while True:
        batch = list(islice(iterator, size))
        if not batch:
            return
        yield batch


class MultiSelectChain(Chain):
    """A chain that performs multi-selection from a list of choices."""

    llm_chain: LLMChain

    @property
    def input_keys(self) -> List[str]:
        """Return the input keys."""
        return ["question", "choices"]

    @property
    def output_keys(self) -> List[str]:
        """Return the output keys."""
        return ["selected"]

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> MultiSelectionOutput:
        """Run the chain."""
        choices = inputs["choices"]
        question = inputs["question"]
        columns = inputs.get("columns", None)

        selected = []
        max_choices = 30

        for choice_batch in batch(choices, max_choices):
            records_with_ids = [
                {**record, "id": idx} for idx, record in enumerate(choice_batch)
            ]
            records_str = _write_records_to_string(
                records_with_ids, columns=columns, delimiter="|"
            )

            indexes = cast(
                List[int],
                self.llm_chain.predict_and_parse(
                    records=records_str,
                    question=question,
                    callbacks=run_manager.get_child() if run_manager else None,
                ),
            )
            valid_indexes = [idx for idx in indexes if 0 <= idx < len(choice_batch)]
            selected.extend(choice_batch[i] for i in valid_indexes)

        return {
            "selected": selected,
        }

    @property
    def _chain_type(self) -> str:
        """Return the chain type."""
        return "multilabel_binary_classifier"

    @classmethod
    def from_default(
        cls,
        llm: BaseLanguageModel,
        *,
        prompt: str = MULTI_SELECT_TEMPLATE,
        parser: BaseOutputParser = IDParser(),
    ) -> MultiSelectChain:
        """Provide a multilabel binary classifier."""
        prompt_template = PromptTemplate.from_template(prompt, output_parser=parser)
        if set(prompt_template.input_variables) != {"question", "records"}:
            raise ValueError("Prompt must contain only {question} and {records}")

        return cls(
            llm_chain=LLMChain(
                llm=llm,
                prompt=prompt_template,
            )
        )
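
# Illustrative sketch (not part of the original module): select the rows of a small
# table that match a question (requires an LLM, e.g. OpenAI; shown only to document
# the expected input/output shape):
#
#     selector = MultiSelectChain.from_default(llm)
#     result = selector(
#         {
#             "question": "Which links point at documentation?",
#             "choices": [
#                 {"link": "/docs/intro", "content": "Introduction"},
#                 {"link": "/pricing", "content": "Pricing"},
#             ],
#             "columns": ("link", "content"),
#         }
#     )
#     result["selected"]  # subset of the "choices" records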