Compare commits

...

38 Commits

Author SHA1 Message Date
Eugene Yurtsev  27b87c64a0  q  2023-06-04 08:54:41 -04:00
Eugene Yurtsev  c8f593b2c0  q  2023-06-02 14:57:15 -04:00
Eugene Yurtsev  16f41719a4  q  2023-06-02 14:45:17 -04:00
Eugene Yurtsev  c62f497082  q  2023-06-02 14:44:29 -04:00
Eugene Yurtsev  38aef8c252  q  2023-06-02 14:38:32 -04:00
Eugene Yurtsev  1911388d9d  q  2023-06-02 14:37:57 -04:00
Eugene Yurtsev  37bdeb60fc  q  2023-06-02 14:37:09 -04:00
Eugene Yurtsev  ece6b598c4  q  2023-06-02 14:22:59 -04:00
Eugene Yurtsev  b641bde197  q  2023-06-02 14:21:12 -04:00
Eugene Yurtsev  68e39f7f26  x  2023-06-02 14:14:19 -04:00
Eugene Yurtsev  7adbcb195d  q  2023-06-02 14:12:31 -04:00
Eugene Yurtsev  307df3ebed  q  2023-06-02 14:08:22 -04:00
Eugene Yurtsev  2b42b9cb82  q  2023-06-02 13:31:55 -04:00
Eugene Yurtsev  13c82e6b66  Merge branch 'master' into eugene/research_v1  2023-06-02 13:31:37 -04:00
Eugene Yurtsev  f0e78d7efd  x  2023-06-02 13:07:05 -04:00
Eugene Yurtsev  ad2b777536  x  2023-06-02 13:01:56 -04:00
Eugene Yurtsev  f79582c548  Add research example  2023-06-02 12:56:26 -04:00
Eugene Yurtsev  203cb5e307  q  2023-06-02 12:49:27 -04:00
Eugene Yurtsev  9c6accfa1a  q  2023-06-02 12:12:17 -04:00
Eugene Yurtsev  39a2d2511d  x  2023-06-02 11:23:38 -04:00
Eugene Yurtsev  679eb9f14f  x  2023-06-01 23:14:11 -04:00
Eugene Yurtsev  f24e521015  x  2023-06-01 22:57:43 -04:00
Eugene Yurtsev  7e14388ba8  q  2023-06-01 17:30:58 -04:00
Eugene Yurtsev  b7dabe8f50  x  2023-06-01 17:26:46 -04:00
Eugene Yurtsev  f449d26083  q  2023-06-01 16:39:52 -04:00
Eugene Yurtsev  1f1db7c96a  Merge branch 'master' into eugene/research_v1  2023-05-31 09:52:56 -04:00
Eugene Yurtsev  6a558c72a2  q  2023-05-31 09:52:21 -04:00
Eugene Yurtsev  d381c4fad8  x  2023-05-30 15:39:17 -04:00
Eugene Yurtsev  c3d260ffdc  x  2023-05-30 14:59:34 -04:00
Eugene Yurtsev  08b6c75743  x  2023-05-30 14:54:42 -04:00
Eugene Yurtsev  69cc8d7f73  x  2023-05-30 14:45:21 -04:00
Eugene Yurtsev  fe89220aac  x  2023-05-30 14:40:46 -04:00
Eugene Yurtsev  0765679292  x  2023-05-30 14:37:07 -04:00
Eugene Yurtsev  95b3774479  x  2023-05-30 14:15:43 -04:00
Eugene Yurtsev  ebfd539e43  x  2023-05-30 14:09:00 -04:00
Eugene Yurtsev  de7ab6be16  q  2023-05-30 13:56:33 -04:00
Eugene Yurtsev  53a7e2b851  x  2023-05-30 13:55:29 -04:00
Eugene Yurtsev  65aad4c0bb  x  2023-05-30 13:55:20 -04:00
13 changed files with 1849 additions and 0 deletions

View File

@@ -0,0 +1,469 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "d0187afb-f460-431f-aee3-50d68bc33446",
"metadata": {},
"source": [
"# Research Chain\n",
"\n",
"This is an experimental research chain that tries to answer \"researchy\" questions using information on the web.\n",
"\n",
"\n",
"For example, \n",
"\n",
"```\n",
"Compile information about Albert Einstein.\n",
"Ignore if it's a different Albert Einstein. \n",
"Only include information you're certain about.\n",
"\n",
"Include:\n",
"* education history\n",
"* major contributions\n",
"* names of spouse \n",
"* date of birth\n",
"* place of birth\n",
"* a 3 sentence short biography\n",
"\n",
"Format your answer in a bullet point format for each sub-question.\n",
"```\n",
"\n",
"Or replace `Albert Einstein` with another person of interest (e.g., John Smith of Boston).\n",
"\n",
"\n",
"The chain is composed of the following components:\n",
"\n",
"1. A searcher that searches for documents using a search engine.\n",
" - The searcher is responsible to return a list of URLs of documents that\n",
" may be relevant to read to be able to answer the question.\n",
"2. A downloader that downloads the documents.\n",
"3. An HTML to markdown parser (hard coded) that converts the HTML to markdown.\n",
" * Conversion to markdown is lossy\n",
" * However, it can significantly reduce the token count of the document\n",
" * Markdown helps to preserve some styling information\n",
" (e.g., bold, italics, links, headers) which is expected to help the reader\n",
" to answer certain kinds of questions correctly.\n",
"4. A reader that reads the documents and produces an answer.\n",
"\n",
"## Limitations\n",
"\n",
"* Quality of results depends on LLM used, and can be improved by providing more specialized parsers (e.g., parse only the body of articles).\n",
"* If asking about people, provide enough information to disambiguate the person.\n",
"* Content downloader may get blocked (e.g., if attempting to download from linkedin) -- may need to read terms of service / user agents appropriately.\n",
"* Chain can be potentially long running (use initialization parameters to control how many options are explored) -- use async implementation as it uses more concurrency.\n",
"* This research chain only implements a single hop at the moment; i.e.,\n",
" it goes from the questions to a list of URLs to documents to compiling answers.\n",
" Without continuing the crawl, web-sites that require pagnation will not be explored fully.\n",
"* The reader chain must match the type of question. For example, the QA refine chain \n",
" isn't good at extracting a list of entries from a long document.\n",
" \n",
"## Extending\n",
"\n",
"* Continue crawling documents to discover more relevant pages that were not surfaced by the search engine.\n",
"* Adapt reading strategy based on nature of question.\n",
"* Analyze the query and determine whether the query is a multi-hop query and change search/crawling strategy based on that.\n",
"* Break components into tools that can be exposed to an agent. :)\n",
"* Add cheaper strategies for selecting which links should be explored further (e.g., based on tf-idf similarity instead of gpt-4)\n",
"* Add a summarization chain on top of the individually collected answers.\n",
"* Improve strategy to ignore irrelevant information."
]
},
{
"cell_type": "markdown",
"id": "4d937a38-66c6-4aa2-87bb-337101cfb112",
"metadata": {},
"source": [
"# Requirements\n",
"\n",
"Please install: \n",
"\n",
"* `playwright` for fetching content from the web (or use the RequestsDownloadHandler)\n",
"* `lxml` and `markdownify` for parsing HTMLs"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "7eb466b8-24fa-4acc-b0ce-06fcfa2fa9c4",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from langchain.chains.research.api import Research\n",
"from langchain.chat_models import ChatOpenAI\n",
"from langchain.llms import OpenAI\n",
"from langchain.chains.question_answering import load_qa_chain\n",
"from langchain.chains.research.download import PlaywrightDownloadHandler\n",
"# If you don't have playwright installed, can experiment with requests\n",
"# Be aware that some web-pages won't download properly as javascript won't be executed\n",
"from langchain.chains.research.download import RequestsDownloadHandler "
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "70474885-0acd-41b2-8050-15dd54f44f1e",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"question = \"\"\"\\\n",
"Compile information about Albert Einstein.\n",
"Ignore if it's a different Albert Einstein. \n",
"Only include information you're certain about.\n",
"\n",
"Include:\n",
"* education history\n",
"* major contributions\n",
"* names of spouse \n",
"* date of birth\n",
"* place of birth\n",
"* a 3 sentence short biography\n",
"\n",
"Format your answer in a bullet point format for each sub-question.\n",
"\"\"\".strip()"
]
},
{
"cell_type": "markdown",
"id": "6613da1c-3349-45f4-9770-19986750d548",
"metadata": {},
"source": [
"Instantiate LLMs"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "e74a44b4-2075-4cc6-933e-c769bf3f6002",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"llm = OpenAI(\n",
" temperature=0, model=\"text-davinci-003\"\n",
") # Used for the readers and the query generator\n",
"selector_llm = ChatOpenAI(\n",
" temperature=0, model=\"gpt-4\"\n",
") # Used for selecting which links to explore"
]
},
{
"cell_type": "markdown",
"id": "b243364a-e79b-432d-8035-3de8caf554a8",
"metadata": {},
"source": [
"Create a chain that can be used to extract the answer to the question above from a given document.\n",
"\n",
"This chain must be tailored to the task."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "2f538062-14e3-49ab-9b25-bc470eb5869c",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"qa_chain = load_qa_chain(llm, chain_type=\"refine\")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "a96f3ed3-10de-4a85-9e93-a8b78d8bfbb6",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"research = Research.from_llms(\n",
" query_generation_llm=llm,\n",
" link_selection_llm=selector_llm,\n",
" underlying_reader_chain=qa_chain,\n",
" top_k_per_search=1,\n",
" max_num_pages_per_doc=3,\n",
" download_handler=PlaywrightDownloadHandler(),\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "3207c696-a72c-4378-b427-7d285f5fdd1c",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"results = await research.acall(inputs={\"question\": question})"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "843616b3-32d7-49c7-a42b-b0272d71f3ed",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------------------------------------------------------------------------------------------------------------------------------------------------------------\n",
"https://en.wikipedia.org/wiki/Albert_Einstein\n",
"\n",
"\n",
"Albert Einstein:\n",
"* Education history: Attended elementary school in Munich, Germany, and later attended the Swiss Federal Polytechnic School in Zurich, Switzerland.\n",
"* Major contributions: Developed the theory of relativity, made major contributions to quantum theory, and won the Nobel Prize in Physics in 1921. He also published more than 300 scientific papers and 150 non-scientific works. He was also the first to propose the existence of black holes and gravitational waves. He was also a polyglot, speaking over 15 languages, including Afrikaans, Alemannisch, Amharic, Anarâškielâ, Angika, Old English, Abkhazian, Arabic, Aragonese, Western Armenian, Aromanian, Arpitan, Assamese, Asturian, Guarani, Aymara, Azerbaijani, South Azerbaijani, Balinese, Bambara, Bangla, Min Nan Chinese, Basa Banyumasan, Bashkir, Belarusian, Belarusian (Taraškievica orthography), Bhojpuri, Central Bikol, and Bulgarian.\n",
"* Names of spouse: Married Mileva Marić in 1903 and Elsa Löwenthal in 1919\n",
"----------------------------------------------------------------------------------------------------------------------------------------------------------------\n",
"https://www.advergize.com/edu/7-albert-einstein-inventions-contributions/\n",
"\n",
"\n",
"Education History:\n",
"* Attended Aargau Cantonal School in Switzerland from 1895-1896\n",
"* Attended ETH Zurich from 1896-1900\n",
"* Received a PhD from the University of Zurich in 1905\n",
"\n",
"Major Contributions:\n",
"* Theory of Relativity\n",
"* Photoelectric Effect\n",
"* Brownian Motion\n",
"* Bose-Einstein Condensate\n",
"* Unified Field Theory\n",
"* Quantum Theory of Light\n",
"* E=mc2\n",
"* Manhattan Project\n",
"* Einsteins Refrigerator\n",
"* Sky is Blue\n",
"* Quantum Theory of Light\n",
"* Photoelectric Effect\n",
"* Brownian Movement\n",
"* Special Theory of Relativity\n",
"* General Theory of Relativity\n",
"* Manhattan Project\n",
"* Einsteins Refrigerator\n",
"\n",
"Names of Spouse:\n",
"* Mileva Maric (1903-1919)\n",
"* Elsa Löwenthal (1919-1936)\n",
"\n",
"Date of Birth: March 14, 1879\n",
"\n",
"Place of Birth: Ulm, Germany\n",
"\n",
"Short Biography:\n",
"Albert Einstein was a German-born physicist who developed the theory of relativity. He is widely considered one of the most influential scientists of the 20th century and is known for his mass-energy equivalence formula\n",
"----------------------------------------------------------------------------------------------------------------------------------------------------------------\n",
"https://www.nobelprize.org/prizes/physics/1921/einstein/biographical/\n",
"\n",
"\n",
"Education History:\n",
"* Attended Aargau Cantonal School in Aarau, Switzerland from 1895-1896\n",
"* Attended ETH Zurich (Swiss Federal Institute of Technology) from 1896-1900\n",
"* Obtained his doctorate degree from Swiss Federal Polytechnic School in Zurich in 1901\n",
"\n",
"Major Contributions:\n",
"* Developed the theory of relativity\n",
"* Developed the mass-energy equivalence formula (E=mc2)\n",
"* Developed the law of the photoelectric effect\n",
"* Postulated that the correct interpretation of the special theory of relativity must also furnish a theory of gravitation\n",
"* Contributed to the problems of the theory of radiation and statistical mechanics\n",
"* Investigated the thermal properties of light with a low radiation density and his observations laid the foundation of the photon theory of light\n",
"* Contributed to statistical mechanics by his development of the quantum theory of a monatomic gas\n",
"* Worked towards the unification of the basic concepts of physics, taking the opposite approach, geometrisation, to the majority of physicists\n",
"\n",
"Names of Spouse:\n",
"* Mileva Marić (1903-1919)\n",
"* Elsa Löwenthal (1919-1936)\n",
"\n",
"Date of Birth:\n",
"\n"
]
}
],
"source": [
"for doc in results[\"docs\"]:\n",
" print(\"--\" * 80)\n",
" print(doc.metadata[\"source\"])\n",
" print(doc.page_content)"
]
},
{
"cell_type": "markdown",
"id": "edcaf496-3679-4b58-9baa-6124e1cc3435",
"metadata": {},
"source": [
"If useful we can produce another summary!"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "76136bff-b7df-4539-9bcb-760fc4449390",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"qa_chain = load_qa_chain(llm, chain_type=\"stuff\")"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "b352f3f3-6777-4795-acbb-ed26ecac137d",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"summary = await qa_chain.acall(\n",
" inputs={\"input_documents\": results[\"docs\"], \"question\": question}\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "1c90c305-a89d-42e2-b975-dda039e816b6",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"Education History:\n",
"* Attended Aargau Cantonal School in Aarau, Switzerland from 1895-1896\n",
"* Attended ETH Zurich (Swiss Federal Institute of Technology) from 1896-1900\n",
"* Obtained his doctorate degree from Swiss Federal Polytechnic School in Zurich in 1901\n",
"\n",
"Major Contributions:\n",
"* Developed the theory of relativity\n",
"* Developed the mass-energy equivalence formula (E=mc2)\n",
"* Developed the law of the photoelectric effect\n",
"* Postulated that the correct interpretation of the special theory of relativity must also furnish a theory of gravitation\n",
"* Contributed to the problems of the theory of radiation and statistical mechanics\n",
"* Investigated the thermal properties of light with a low radiation density and his observations laid the foundation of the photon theory of light\n",
"* Contributed to statistical mechanics by his development of the quantum theory of a monatomic gas\n",
"* Worked towards the unification of the basic concepts of physics, taking the opposite approach, geometrisation, to the majority of physicists\n",
"\n",
"Names of Spouse:\n",
"* Mileva Marić (1903-1919)\n",
"* Elsa Löwenthal (1919-1936)\n",
"\n",
"Date of Birth: March\n"
]
}
],
"source": [
"print(summary[\"output_text\"])"
]
},
{
"cell_type": "markdown",
"id": "af2adfee-85d3-41af-900a-c594dc01ce16",
"metadata": {},
"source": [
"## Under the hood"
]
},
{
"cell_type": "markdown",
"id": "c307aa60-7e75-48e8-ba72-b45507ed3fe0",
"metadata": {},
"source": [
"A searcher is invoked first to find URLs that are good to explore"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "4e2c369c-8763-458e-9ab8-684466395890",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"{'question': \"Compile information about Albert Einstein.\\nIgnore if it's a different Albert Einstein. \\nOnly include information you're certain about.\\n\\nInclude:\\n* education history\\n* major contributions\\n* names of spouse \\n* date of birth\\n* place of birth\\n* a 3 sentence short biography\\n\\nFormat your answer in a bullet point format for each sub-question.\",\n",
" 'urls': ['https://en.wikipedia.org/wiki/Albert_Einstein',\n",
" 'https://www.britannica.com/biography/Albert-Einstein',\n",
" 'https://www.advergize.com/edu/7-albert-einstein-inventions-contributions/',\n",
" 'https://www.nobelprize.org/prizes/physics/1921/einstein/biographical/']}"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"suggetions = await research.searcher.acall(inputs={\"question\": question})\n",
"suggetions"
]
},
{
"cell_type": "markdown",
"id": "24099529-2d3d-4eca-8a1f-45a2539c8842",
"metadata": {},
"source": [
"The webpages are downloaded"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "c6f9d1b5-e513-4d8d-b325-32eacbee92b4",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"blobs = await research.downloader.adownload(suggetions[\"urls\"])"
]
},
{
"cell_type": "markdown",
"id": "58909ce0-fd1c-4e09-9d13-b43135ee8038",
"metadata": {},
"source": [
"The blobs are parsed with an HTML parser and read by the reader chain (not shown) -- see underlying code for details."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.2"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -0,0 +1,267 @@
"""Perform classification / selection using language models."""
from __future__ import annotations
import csv
from io import StringIO
from itertools import islice
from typing import (
Any,
Dict,
Iterable,
Iterator,
List,
Mapping,
Optional,
Sequence,
Set,
TypeVar,
cast,
)
from bs4 import BeautifulSoup
from langchain import LLMChain, PromptTemplate
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain.chains.base import Chain
from langchain.schema import BaseOutputParser
MULTI_SELECT_TEMPLATE = """\
Here is a table in CSV format:
{records}
---
question:
{question}
---
Output IDs of rows that answer the question or match the question.
For example, if row id 132 and id 133 are relevant, output: <ids>132,133</ids>
---
Begin:"""
def _extract_content_from_tag(html: str, tag: str) -> List[str]:
"""Extract content from the given tag."""
soup = BeautifulSoup(html, "lxml")
queries = []
for query in soup.find_all(tag):
queries.append(query.text)
return queries
class IDParser(BaseOutputParser[List[int]]):
"""An output parser that extracts all IDs from the output."""
def parse(self, text: str) -> List[int]:
"""Parse the text and return a list of IDs"""
tags = _extract_content_from_tag(text, "ids")
if not tags:
return []
if len(tags) > 1:
# Fail if more than 1 tag group is identified
return []
tag = tags[0]
ids = tag.split(",")
finalized_ids = []
for idx in ids:
if idx.isdigit():
finalized_ids.append(int(idx))
return finalized_ids
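# Illustrative sketch (not part of the original module) of the expected parser
# behavior: a single <ids>...</ids> group is extracted and only numeric entries
# are kept.
#
#   IDParser().parse("Relevant rows: <ids>3,7,12</ids>")  # -> [3, 7, 12]
#   IDParser().parse("no tag here")                        # -> []
#   IDParser().parse("<ids>1</ids> <ids>2</ids>")          # -> [] (ambiguous output)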
def _write_records_to_string(
records: Sequence[Mapping[str, Any]],
*,
columns: Optional[Sequence[str]] = None,
delimiter: str = "|",
) -> str:
"""Write records to a CSV string.
Args:
records: a list of records, assumes that all records have all keys
columns: a list of columns to include in the CSV
delimiter: the delimiter to use
Returns:
a CSV string
"""
buffer = StringIO()
if columns is None:
existing_columns: Set[str] = set()
for record in records:
existing_columns.update(record.keys())
_columns: Sequence[str] = sorted(existing_columns)
else:
_columns = columns
# Make sure the id column is always first
_columns_with_id_first = list(_columns)
if "id" in _columns_with_id_first:
_columns_with_id_first.remove("id")
# Make sure the `id` column is always first
_columns_with_id_first.insert(0, "id")
writer = csv.DictWriter(
buffer,
fieldnames=_columns_with_id_first,
delimiter=delimiter,
)
writer.writeheader()
writer.writerows(records)
buffer.seek(0)
return buffer.getvalue()
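# Illustrative sketch (not part of the original module) of the layout produced
# above with the default "|" delimiter; rows are separated by CSV line endings:
#
#   _write_records_to_string([{"id": 0, "title": "A"}, {"id": 1, "title": "B"}])
#   # id|title
#   # 0|A
#   # 1|B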
T = TypeVar("T")
def _batch(iterable: Iterable[T], size: int) -> Iterator[List[T]]:
"""Batch an iterable into chunks of size `size`.
Args:
iterable: the iterable to batch
size: the size of each batch
Returns:
an iterator over batches of size `size`; the last batch may be smaller
"""
iterator = iter(iterable)
while True:
batch = list(islice(iterator, size))
if not batch:
return
yield batch
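# Illustrative sketch (not part of the original module):
#
#   list(_batch([1, 2, 3, 4, 5], 2))  # -> [[1, 2], [3, 4], [5]]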
class MultiSelectChain(Chain):
"""A chain that performs multi-selection from a list of choices."""
llm_chain: LLMChain
@property
def input_keys(self) -> List[str]:
"""Return the input keys."""
return ["question", "choices"]
@property
def output_keys(self) -> List[str]:
"""Return the output keys."""
return ["selected"]
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
"""Run the chain."""
choices = inputs["choices"]
question = inputs["question"]
columns = inputs.get("columns", None)
selected: List[Mapping[str, Any]] = []
# TODO(): Balance choices into equal batches with constraint dependent
# on context window and prompt
max_choices = 30
for choice_batch in _batch(choices, max_choices):
records_with_ids = [
{**record, "id": idx} for idx, record in enumerate(choice_batch)
]
records_str = _write_records_to_string(
records_with_ids, columns=columns, delimiter="|"
)
indexes = cast(
List[int],
self.llm_chain.predict_and_parse(
records=records_str,
question=question,
callbacks=run_manager.get_child() if run_manager else None,
),
)
valid_indexes = [idx for idx in indexes if 0 <= idx < len(choice_batch)]
selected.extend(choice_batch[i] for i in valid_indexes)
return {
"selected": selected,
}
async def _acall(
self,
inputs: Dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
choices = inputs["choices"]
question = inputs["question"]
columns = inputs.get("columns", None)
selected: List[Mapping[str, Any]] = []
# TODO(): Balance choices into equal batches with constraint dependent
# on context window and prompt
max_choices = 30
for choice_batch in _batch(choices, max_choices):
records_with_ids = [
{**record, "id": idx} for idx, record in enumerate(choice_batch)
]
records_str = _write_records_to_string(
records_with_ids, columns=columns, delimiter="|"
)
indexes = cast(
List[int],
await self.llm_chain.apredict_and_parse(
records=records_str,
question=question,
callbacks=run_manager.get_child() if run_manager else None,
),
)
valid_indexes = [idx for idx in indexes if 0 <= idx < len(choice_batch)]
selected.extend(choice_batch[i] for i in valid_indexes)
return {
"selected": selected,
}
@property
def _chain_type(self) -> str:
"""Return the chain type."""
return "multilabel_binary_classifier"
@classmethod
def from_default(
cls,
llm: BaseLanguageModel,
*,
prompt: str = MULTI_SELECT_TEMPLATE,
parser: BaseOutputParser = IDParser(),
) -> MultiSelectChain:
"""Provide a multilabel binary classifier."""
prompt_template = PromptTemplate.from_template(prompt, output_parser=parser)
if set(prompt_template.input_variables) != {"question", "records"}:
raise ValueError("Prompt must contain only {question} and {records}")
return cls(
llm_chain=LLMChain(
llm=llm,
prompt=prompt_template,
)
)
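# Minimal usage sketch (an assumption, not part of the original module): it
# presumes an OpenAI API key is configured and that ChatOpenAI is an acceptable
# selection model; the example URLs and titles are made up.
#
#   from langchain.chat_models import ChatOpenAI
#
#   chain = MultiSelectChain.from_default(ChatOpenAI(temperature=0))
#   result = chain(
#       {
#           "question": "Which pages look like official biographies?",
#           "choices": [
#               {"link": "https://example.com/bio", "title": "Official biography"},
#               {"link": "https://example.com/shop", "title": "Merchandise store"},
#           ],
#       }
#   )
#   result["selected"]  # -> the subset of choices picked by the LLM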

View File

View File

@@ -0,0 +1,200 @@
from __future__ import annotations
import itertools
from typing import Any, Dict, List, Literal, Mapping, Optional, Union
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.research.download import AutoDownloadHandler, DownloadHandler
from langchain.chains.research.readers import DocReadingChain, ParallelApplyChain
from langchain.chains.research.search import GenericSearcher
from langchain.document_loaders.parsers.html.markdownify import MarkdownifyHTMLParser
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
class Research(Chain):
"""A research chain.
A research chain is composed of the following components:
1. A searcher that searches for documents using a search engine.
- The searcher is responsible for returning a list of URLs of documents that
may be relevant to read in order to answer the question.
2. A downloader that downloads the documents.
3. An HTML to markdown parser (hard coded) that converts the HTML to markdown.
* Conversion to markdown is lossy
* However, it can significantly reduce the token count of the document
* Markdown helps to preserve some styling information
(e.g., bold, italics, links, headers) which is expected to help the reader
to answer certain kinds of questions correctly.
4. A reader that reads the documents and produces an answer.
Limitations:
* This research chain only implements a single hop at the moment; i.e.,
it goes from the questions to a list of URLs to documents to compiling
answers.
* The reader chain needs to match the task. For example, if using a QA refine
chain, a task of collecting a list of entries from a long document will
fail because the QA refine chain is not designed to handle such a task.
The chain can be extended to continue crawling the documents in an attempt
to discover relevant pages that were not surfaced by the search engine.
Among other problems, without continuing the crawl it is impossible to
get complete results from pages that involve pagination.
"""
searcher: GenericSearcher
"""The searcher to use to search for documents."""
reader: Chain
"""The reader to use to read documents and produce an answer."""
downloader: DownloadHandler
"""The downloader to use to download the documents.
A few different implementations of the download handler have been provided.
Keep in mind that some websites require execution of JavaScript to load
the DOM.
"""
@property
def input_keys(self) -> List[str]:
"""Return the input keys."""
return ["question"]
@property
def output_keys(self) -> List[str]:
"""Return the output keys."""
return ["docs"]
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
"""Run the chain synchronously."""
question = inputs["question"]
search_results = self.searcher(
{"question": question},
callbacks=run_manager.get_child() if run_manager else None,
)
urls = search_results["urls"]
blobs = self.downloader.download(urls)
parser = MarkdownifyHTMLParser()
docs = itertools.chain.from_iterable(
parser.lazy_parse(blob) for blob in blobs if blob is not None
)
_inputs = [{"doc": doc, "question": question} for doc in docs]
results = self.reader(
_inputs, callbacks=run_manager.get_child() if run_manager else None
)
return {
"docs": [result["answer"] for result in results["inputs"]],
}
async def _acall(
self,
inputs: Dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
"""Run the chain asynchronously."""
question = inputs["question"]
search_results = await self.searcher.acall(
{"question": question},
callbacks=run_manager.get_child() if run_manager else None,
)
urls = search_results["urls"]
blobs = await self.downloader.adownload(urls)
parser = MarkdownifyHTMLParser()
docs = itertools.chain.from_iterable(
parser.lazy_parse(blob) for blob in blobs if blob is not None
)
_inputs = [{"doc": doc, "question": question} for doc in docs]
results = await self.reader.acall(
_inputs,
callbacks=run_manager.get_child() if run_manager else None,
)
return {
"docs": [result["answer"] for result in results["results"]],
}
@classmethod
def from_llms(
cls,
*,
query_generation_llm: BaseLanguageModel,
link_selection_llm: BaseLanguageModel,
underlying_reader_chain: LLMChain,
top_k_per_search: int = -1,
max_concurrency: int = 1,
max_num_pages_per_doc: int = 5,
text_splitter: Union[TextSplitter, Literal["recursive"]] = "recursive",
download_handler: Union[DownloadHandler, Literal["auto"]] = "auto",
text_splitter_kwargs: Optional[Mapping[str, Any]] = None,
) -> Research:
"""Helper to create a research chain from standard llm related components.
Args:
query_generation_llm: The language model to use for query generation.
link_selection_llm: The language model to use for link selection.
underlying_reader_chain: The chain to use to answer the question.
top_k_per_search: The number of documents to return per search.
max_concurrency: The maximum number of concurrent reads.
max_num_pages_per_doc: The maximum number of pages to read per document.
text_splitter: The text splitter to use to split the document into
smaller chunks.
download_handler: The download handler to use to download the documents.
Provide either a download handler or the name of a
download handler.
- "auto" swaps between using requests and playwright
text_splitter_kwargs: The keyword arguments to pass to the text splitter.
Only use when providing a text splitter as string.
Returns:
A research chain.
"""
if isinstance(text_splitter, str):
if text_splitter == "recursive":
_text_splitter_kwargs = text_splitter_kwargs or {}
_text_splitter: TextSplitter = RecursiveCharacterTextSplitter(
**_text_splitter_kwargs
)
else:
raise ValueError(f"Invalid text splitter: {text_splitter}")
elif isinstance(text_splitter, TextSplitter):
_text_splitter = text_splitter
else:
raise TypeError(f"Invalid text splitter: {type(text_splitter)}")
if isinstance(download_handler, str):
if download_handler == "auto":
_download_handler: DownloadHandler = AutoDownloadHandler()
else:
raise ValueError(f"Invalid download handler: {download_handler}")
elif isinstance(download_handler, DownloadHandler):
_download_handler = download_handler
else:
raise TypeError(f"Invalid download handler: {type(download_handler)}")
searcher = GenericSearcher.from_llms(
link_selection_llm,
query_generation_llm,
top_k_per_search=top_k_per_search,
)
doc_reading_chain = DocReadingChain(
chain=underlying_reader_chain,
max_num_docs=max_num_pages_per_doc,
text_splitter=_text_splitter,
)
# Can read multiple documents in parallel
multi_reader = ParallelApplyChain(
chain=doc_reading_chain,
max_concurrency=max_concurrency,
)
return cls(searcher=searcher, reader=multi_reader, downloader=_download_handler)

View File

@@ -0,0 +1,140 @@
"""Module contains code for crawling a blob (e.g., HTML file) for links.
The main idea behind the crawling module is to identify additional links
that are worth exploring to find more documents that may be relevant for
answering the question correctly.
"""
import urllib.parse
from typing import Any, Dict, List, Tuple
from bs4 import BeautifulSoup, PageElement
from langchain.base_language import BaseLanguageModel
from langchain.chains.classification.multiselection import MultiSelectChain
from langchain.chains.research.typedefs import BlobCrawler
from langchain.document_loaders.base import BaseBlobParser
from langchain.document_loaders.blob_loaders import Blob
from langchain.document_loaders.parsers.html.markdownify import MarkdownifyHTMLParser
def _get_ahref_snippets(html: str, num_chars: int = 0) -> Dict[str, Any]:
"""Get a list of <a> tags as snippets from the given html.
Args:
html: the html to get snippets from.
num_chars: the number of characters to get around the <a> tags.
Returns:
a list of snippets.
"""
soup = BeautifulSoup(html, "html.parser")
title = soup.title.string.strip() if soup.title and soup.title.string else ""
snippets = []
for idx, a_tag in enumerate(soup.find_all("a")):
before_text = _get_surrounding_text(a_tag, num_chars, is_before=True)
after_text = _get_surrounding_text(a_tag, num_chars, is_before=False)
snippet = {
"id": idx,
"before": before_text.strip().replace("\n", " "),
"link": a_tag.get("href").replace("\n", " ").strip(),
"content": a_tag.text.replace("\n", " ").strip(),
"after": after_text.strip().replace("\n", " "),
}
snippets.append(snippet)
return {
"snippets": snippets,
"title": title,
}
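# Illustrative sketch (not part of the original module) of the structure
# returned above; `html` is assumed to be a raw HTML string:
#
#   info = _get_ahref_snippets(html, num_chars=50)
#   # info["title"]    -> the page <title> text
#   # info["snippets"] -> one dict per <a> tag, e.g.
#   #   {"id": 0, "before": "...text before the link...", "link": "/about",
#   #    "content": "about", "after": "...text after the link..."}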
def _extract_records(blob: Blob) -> Tuple[List[Dict[str, Any]], Tuple[str, ...]]:
"""Extract records from a blob."""
if blob.mimetype == "text/html":
info = _get_ahref_snippets(blob.as_string(), num_chars=100)
return (
[
{
"content": d["content"],
"link": d["link"],
"before": d["before"],
"after": d["after"],
}
for d in info["snippets"]
],
("link", "content", "before", "after"),
)
else:
raise ValueError(
"Can only extract records from HTML/JSON blobs. Got {blob.mimetype}"
)
class ChainCrawler(BlobCrawler):
def __init__(self, chain: MultiSelectChain, parser: BaseBlobParser) -> None:
"""Crawl the blob using an LLM."""
self.chain = chain
self.parser = parser
def crawl(self, blob: Blob, question: str) -> List[str]:
"""Explore the blob and suggest additional content to explore."""
if not blob.source:
raise NotImplementedError()
records, columns = _extract_records(blob)
result = self.chain(
inputs={"question": question, "choices": records, "columns": columns},
)
selected_records = result["selected"]
urls = [
# TODO(): handle absolute links
urllib.parse.urljoin(blob.source, record["link"])
for record in selected_records
if "mailto:" not in record["link"]
]
return urls
@classmethod
def from_default(
cls,
llm: BaseLanguageModel,
blob_parser: BaseBlobParser = MarkdownifyHTMLParser(),
) -> "ChainCrawler":
"""Create a crawler from the default LLM."""
chain = MultiSelectChain.from_default(llm)
return cls(chain=chain, parser=blob_parser)
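# Hypothetical usage sketch (not part of the original module); `html_string`
# and the question are assumptions, and an OpenAI API key is presumed:
#
#   from langchain.chat_models import ChatOpenAI
#   from langchain.document_loaders.blob_loaders import Blob
#
#   crawler = ChainCrawler.from_default(ChatOpenAI(temperature=0))
#   blob = Blob(data=html_string, mimetype="text/html", path="https://example.com/page")
#   urls = crawler.crawl(blob, "Which links discuss the author's publications?")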
def _get_surrounding_text(tag: PageElement, n: int, *, is_before: bool = True) -> str:
"""Get surrounding text the given tag in the given direction.
Args:
tag: the tag to get surrounding text for.
n: number of characters to get
is_before: Whether to get text before or after the tag.
Returns:
the surrounding text in the given direction.
"""
text = ""
current = tag.previous_element if is_before else tag.next_element
while current and len(text) < n:
current_text = str(current.text).strip()
current_text = (
current_text
if len(current_text) + len(text) <= n
else current_text[: n - len(text)]
)
if is_before:
text = current_text + " " + text
else:
text = text + " " + current_text
current = current.previous_element if is_before else current.next_element
return text

View File

@@ -0,0 +1,175 @@
"""Module contains code for fetching documents from the web using playwright.
This module currently re-uses the code from the `web_base` module to avoid
re-implementing rate limiting behavior.
The module contains downloading interfaces.
Sub-classing the given interface should allow a user to add URL-based
user agents and authentication if needed.
Downloading is batched by default to allow efficient parallelization.
"""
import abc
import asyncio
import mimetypes
from typing import Any, List, Optional, Sequence
from bs4 import BeautifulSoup
from langchain.document_loaders import WebBaseLoader
from langchain.document_loaders.blob_loaders import Blob
MaybeBlob = Optional[Blob]
def _is_javascript_required(html_content: str) -> bool:
"""Heuristic to determine whether javascript execution is required.
Args:
html_content (str): The HTML content to check.
Returns:
bool: True if javascript execution is required, False otherwise.
"""
# Parse the HTML content using BeautifulSoup
soup = BeautifulSoup(html_content, "lxml")
# Count the number of HTML elements
body = soup.body
if not body:
return True
num_elements = len(body.find_all())
requires_javascript = num_elements < 1
return requires_javascript
class DownloadHandler(abc.ABC):
def download(self, urls: Sequence[str]) -> List[MaybeBlob]:
"""Download a batch of URLs synchronously."""
raise NotImplementedError()
async def adownload(self, urls: Sequence[str]) -> List[MaybeBlob]:
"""Download a batch of URLs asynchronously."""
raise NotImplementedError()
class PlaywrightDownloadHandler(DownloadHandler):
"""Download URLS using playwright.
This is an implementation of the download handler that uses playwright to download
urls. This is useful for downloading urls that require javascript to be executed.
"""
def __init__(self, timeout: int = 5) -> None:
"""Initialize the download handler.
Args:
timeout: The timeout in seconds to wait for a page to load.
"""
self.timeout = timeout
def download(self, urls: Sequence[str]) -> List[MaybeBlob]:
"""Download list of urls synchronously."""
return asyncio.run(self.adownload(urls))
async def _download(self, browser: Any, url: str) -> Optional[str]:
"""Download a url asynchronously using playwright."""
from playwright.async_api import TimeoutError
page = await browser.new_page()
try:
# Wait up to `self.timeout` seconds for the page to load.
await page.goto(url, wait_until="networkidle", timeout=self.timeout * 1000)
html_content = await page.content()
except TimeoutError:
html_content = None
return html_content
async def adownload(self, urls: Sequence[str]) -> List[MaybeBlob]:
"""Download a batch of URLs asynchronously using playwright.
Args:
urls: The urls to download.
Returns:
list of blobs containing the downloaded content.
"""
from playwright.async_api import async_playwright
async with async_playwright() as p:
browser = await p.chromium.launch()
tasks = [self._download(browser, url) for url in urls]
contents = await asyncio.gather(*tasks, return_exceptions=True)
await browser.close()
return _repackage_as_blobs(urls, contents)
class RequestsDownloadHandler(DownloadHandler):
def __init__(self, web_downloader: Optional[WebBaseLoader] = None) -> None:
"""Initialize the requests download handler."""
self.web_downloader = web_downloader or WebBaseLoader(web_path=[])
def download(self, urls: Sequence[str]) -> List[MaybeBlob]:
"""Download a batch of URLS synchronously."""
return asyncio.run(self.adownload(urls))
async def adownload(self, urls: Sequence[str]) -> List[MaybeBlob]:
"""Download a batch of urls asynchronously using playwright."""
download = WebBaseLoader(web_path=[]) # Place holder
contents = await download.fetch_all(list(urls))
return _repackage_as_blobs(urls, contents)
def _repackage_as_blobs(
urls: Sequence[str], contents: Sequence[Optional[str]]
) -> List[MaybeBlob]:
"""Repackage the contents as blobs."""
blobs: List[MaybeBlob] = []
for url, content in zip(urls, contents):
mimetype = mimetypes.guess_type(url)[0]
if content is None:
blobs.append(None)
else:
blobs.append(Blob(data=content or "", mimetype=mimetype, path=url))
return blobs
class AutoDownloadHandler(DownloadHandler):
"""Download URLs using the requests library if possible.
Fallback to using playwright if javascript is required.
"""
def __init__(self, web_downloader: Optional[WebBaseLoader] = None) -> None:
"""Initialize the auto download handler."""
self.requests_downloader = RequestsDownloadHandler(
web_downloader or WebBaseLoader(web_path=[])
)
self.playwright_downloader = PlaywrightDownloadHandler()
async def adownload(self, urls: Sequence[str]) -> List[MaybeBlob]:
"""Download a batch of urls asynchronously using playwright."""
# Check if javascript is required
blobs = await self.requests_downloader.adownload(urls)
# Check if javascript is required
must_redownload = [
(idx, url)
for idx, (url, blob) in enumerate(zip(urls, blobs))
if blob is not None and _is_javascript_required(blob.as_string())
]
if must_redownload:
indexes, urls_to_redownload = zip(*must_redownload)
new_blobs = await self.playwright_downloader.adownload(urls_to_redownload)
for idx, blob in zip(indexes, new_blobs):
blobs[idx] = blob
return blobs
def download(self, urls: Sequence[str]) -> List[MaybeBlob]:
"""Download a batch of URLs synchronously."""
return asyncio.run(self.adownload(urls))
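# Hedged usage sketch (not part of the original module): the auto handler tries
# the requests-based downloader first and re-downloads javascript-heavy pages
# with playwright. Assumes playwright is installed and its browsers are set up;
# the URL is just an example.
#
#   handler = AutoDownloadHandler()
#   blobs = handler.download(["https://en.wikipedia.org/wiki/Albert_Einstein"])
#   htmls = [blob.as_string() for blob in blobs if blob is not None]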

View File

@@ -0,0 +1,151 @@
"""Module contains supporting chains for research use case."""
import asyncio
from typing import Any, Dict, List, Optional
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain.chains.base import Chain
from langchain.schema import Document
from langchain.text_splitter import TextSplitter
class DocReadingChain(Chain):
"""A reader chain should use one of the QA chains to answer a question.
This chain is also responsible for splitting the document into smaller chunks
and then passing the chunks to an underlying QA chain.
A brute force chain that reads an entire document (or the first N pages).
"""
chain: Chain
"""The chain to use to answer the question."""
text_splitter: TextSplitter
"""The text splitter to use to split the document into smaller chunks."""
max_num_docs: int
"""The maximum number of documents to split the document into.
Use -1 to denote no limit to the number of pages to read.
"""
@property
def input_keys(self) -> List[str]:
"""Return the input keys."""
return ["doc", "question"]
@property
def output_keys(self) -> List[str]:
"""Return the output keys."""
return ["answer"]
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
"""Process a long document synchronously."""
source_document = inputs["doc"]
question = inputs["question"]
sub_docs = self.text_splitter.split_documents([source_document])
if self.max_num_docs > 0:
_sub_docs = sub_docs[: self.max_num_docs]
else:
_sub_docs = sub_docs
response = self.chain(
{"input_documents": _sub_docs, "question": question},
callbacks=run_manager.get_child() if run_manager else None,
)
summary_doc = Document(
page_content=response["output_text"],
metadata=source_document.metadata,
)
return {"answer": summary_doc}
async def _acall(
self,
inputs: Dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
"""Process a long document asynchronously."""
source_document = inputs["doc"]
question = inputs["question"]
sub_docs = self.text_splitter.split_documents([source_document])
if self.max_num_docs > 0:
_sub_docs = sub_docs[: self.max_num_docs]
else:
_sub_docs = sub_docs
results = await self.chain.acall(
{"input_documents": _sub_docs, "question": question},
callbacks=run_manager.get_child() if run_manager else None,
)
summary_doc = Document(
page_content=results["output_text"],
metadata=source_document.metadata,
)
return {"answer": summary_doc}
class ParallelApplyChain(Chain):
"""Utility chain to apply a given chain in parallel across input documents.
This chain needs to handle a limit on concurrency.
WARNING: Parallelization only implemented on the async path.
"""
chain: Chain
"""The chain to apply to each input."""
max_concurrency: int = 1
"""Maximum number of concurrent applications (accepted but not yet enforced)."""
@property
def input_keys(self) -> List[str]:
"""Return the input keys."""
return ["inputs"]
@property
def output_keys(self) -> List[str]:
"""Return the output keys."""
return ["results"]
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
"""Run the chain."""
# TODO(): parallelize this
chain_inputs = inputs["inputs"]
results = [
self.chain(
chain_input,
callbacks=run_manager.get_child() if run_manager else None,
)
for chain_input in chain_inputs
]
return {"results": results}
async def _acall(
self,
inputs: Dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
"""Run the chain."""
chain_inputs = inputs["inputs"]
results = await asyncio.gather(
*[
self.chain.acall(
chain_input,
callbacks=run_manager.get_child() if run_manager else None,
)
for chain_input in chain_inputs
]
)
return {"results": results}

View File

@@ -0,0 +1,329 @@
"""Module for initiating a set of searches relevant for answering the question."""
from __future__ import annotations
import asyncio
import typing
from typing import Any, Dict, List, Mapping, Optional, Sequence
from bs4 import BeautifulSoup
from langchain import LLMChain, PromptTemplate, serpapi
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain.chains.base import Chain
from langchain.chains.classification.multiselection import MultiSelectChain
from langchain.schema import BaseOutputParser
def _extract_content_from_tag(html: str, tag: str) -> List[str]:
"""Extract content from the given tag."""
soup = BeautifulSoup(html, "lxml")
queries = []
for query in soup.find_all(tag):
queries.append(query.text)
return queries
def _extract_href_tags(html: str) -> List[str]:
"""Extract href tags.
Args:
html: the html to extract href tags from.
Returns:
a list of href tags.
"""
href_tags = []
soup = BeautifulSoup(html, "html.parser")
for a_tag in soup.find_all("a"):
href = a_tag.get("href")
if href:
href_tags.append(href)
return href_tags
class QueryExtractor(BaseOutputParser[List[str]]):
"""An output parser that extracts all queries."""
def parse(self, text: str) -> List[str]:
"""Extract all content of <query> from the text."""
return _extract_content_from_tag(text, "query")
QUERY_GENERATION_PROMPT = PromptTemplate.from_template(
"""\
Suggest a few different search queries that could be used to identify web-pages \
that could answer the following question.
If the question is about a named entity start by listing very general \
searches (e.g., just the named entity) and then suggest searches more \
scoped to the question.
Input: ```Where did John Snow from Cambridge, UK work?```
Output: <query>John Snow</query>
<query>John Snow Cambridge UK</query>
<query> John Snow Cambridge UK work history </query>
<query> John Snow Cambridge UK cv </query>
Input: ```How many research papers did Jane Doe publish in 2010?```
Output: <query>Jane Doe</query>
<query>Jane Doe research papers</query>
<query>Jane Doe research papers 2010</query>
<query>Jane Doe publications</query>
<query>Jane Doe publications 2010</query>
Input: ```What is the capital of France?```
Output: <query>France</query>
<query>France capital</query>
<query>France capital city</query>
<query>France capital city name</query>
Input: ```What are the symptoms of COVID-19?```
Output: <query>COVID-19</query>
<query>COVID-19 symptoms</query>
<query>COVID-19 symptoms list</query>
<query>COVID-19 symptoms list WHO</query>
Input: ```What is the revenue stream of CVS?```
Output: <query>CVS</query>
<query>CVS revenue</query>
<query>CVS revenue stream</query>
<query>CVS revenue stream business model</query>
Input: ```{question}```
Output:
""",
output_parser=QueryExtractor(),
)
def _deduplicate_objects(
dicts: Sequence[Mapping[str, Any]], key: str
) -> List[Mapping[str, Any]]:
"""Deduplicate objects by the given key.
TODO(Eugene): add a way to add weights to the objects.
Args:
dicts: a list of dictionaries to deduplicate.
key: the key to deduplicate by.
Returns:
a list of deduplicated dictionaries.
"""
unique_values = set()
deduped: List[Mapping[str, Any]] = []
for d in dicts:
value = d[key]
if value not in unique_values:
unique_values.add(value)
deduped.append(d)
return deduped
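# Illustrative sketch (not part of the original module): the first occurrence of
# each key value is kept.
#
#   _deduplicate_objects(
#       [{"link": "a", "title": "A"}, {"link": "a", "title": "A2"}, {"link": "b", "title": "B"}],
#       "link",
#   )
#   # -> [{"link": "a", "title": "A"}, {"link": "b", "title": "B"}]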
def _run_searches(queries: Sequence[str], top_k: int = -1) -> List[Mapping[str, Any]]:
"""Run the given queries and return all the search results.
This function can return duplicated results; de-duplication can take place later
and take the frequency of appearance into account.
Args:
queries: a list of queries to run
top_k: the number of results to return, if -1 return all results
Returns:
a list of unique search results
"""
wrapper = serpapi.SerpAPIWrapper()
results = []
for query in queries:
result = wrapper.results(query)
all_organic_results = result.get("organic_results", [])
if top_k <= 0:
organic_results = all_organic_results
else:
organic_results = all_organic_results[:top_k]
results.extend(organic_results)
return results
async def _arun_searches(
queries: Sequence[str], top_k: int = -1
) -> List[Mapping[str, Any]]:
"""Run the given queries and return all the search results.
This function can return duplicated results; de-duplication can take place later
and take the frequency of appearance into account.
Args:
queries: a list of queries to run
top_k: the number of results to return, if -1 return all results
Returns:
a list of unique search results
"""
wrapper = serpapi.SerpAPIWrapper()
tasks = [wrapper.aresults(query) for query in queries]
results = await asyncio.gather(*tasks)
finalized_results = []
for result in results:
all_organic_results = result.get("organic_results", [])
if top_k <= 0:
organic_results = all_organic_results
else:
organic_results = all_organic_results[:top_k]
finalized_results.extend(organic_results)
return finalized_results
# PUBLIC API
def make_query_generator(llm: BaseLanguageModel) -> LLMChain:
"""Use an LLM to break down a complex query into a list of simpler queries.
The simpler queries are used to run searches against a search engine with the
goal of increasing recall and retrieving as many relevant results as possible.
Query:
Does Harrison Chase of Langchain like to eat pizza or play squash?
May be broken down into a list of simpler queries like:
- Harrison Chase
- Harrison Chase Langchain
- Harrison Chase Langchain pizza
- Harrison Chase Langchain squash
"""
return LLMChain(
llm=llm,
output_key="urls",
prompt=QUERY_GENERATION_PROMPT,
)
class GenericSearcher(Chain):
"""A chain that takes a complex question and identifies a list of relevant urls.
The chain works by:
1. Breaking a complex question into a series of simpler queries using an LLM.
2. Running the queries against a search engine.
3. Selecting the most relevant urls using an LLM (can be replaced with tf-idf
or other models).
This chain is not meant to be used for questions requiring multiple hops to answer.
For example, the age of Leonardo DiCaprio's girlfriend is a multi-hop question.
This kind of question requires a slightly different approach.
This chain is meant to handle questions for which one wants
to collect information from multiple sources.
To extend implementation:
* Parameterize the search engine (to allow for non serp api search engines)
* Expose an abstract interface for query generator and link selection model
* Expose promoted answers from search engines as blobs that should be summarized
"""
query_generator: LLMChain
"""An LLM used to break down a complex question into a list of simpler queries."""
link_selection_model: Chain
"""An LLM that is used to select the most relevant urls from the search results."""
top_k_per_search: int = -1
"""The number of top urls to select from each search."""
@property
def input_keys(self) -> List[str]:
"""Return the input keys."""
return ["question"]
@property
def output_keys(self) -> List[str]:
"""Return the output keys."""
return ["urls"]
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
question = inputs["question"]
queries = typing.cast(
List[str],
self.query_generator.predict_and_parse(
callbacks=run_manager.get_child() if run_manager else None,
question=question,
),
)
results = _run_searches(queries, top_k=self.top_k_per_search)
deduped_results = _deduplicate_objects(results, "link")
records = [
{"link": result["link"], "title": result["title"]}
for result in deduped_results
]
response_ = self.link_selection_model(
{
"question": question,
"choices": records,
},
callbacks=run_manager.get_child() if run_manager else None,
)
return {"urls": [result["link"] for result in response_["selected"]]}
async def _acall(
self,
inputs: Dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
question = inputs["question"]
queries = typing.cast(
List[str],
await self.query_generator.apredict_and_parse(
callbacks=run_manager.get_child() if run_manager else None,
question=question,
),
)
results = await _arun_searches(queries, top_k=self.top_k_per_search)
deduped_results = _deduplicate_objects(results, "link")
records = [
{"link": result["link"], "title": result["title"]}
for result in deduped_results
]
response_ = await self.link_selection_model.acall(
{
"question": question,
"choices": records,
},
callbacks=run_manager.get_child() if run_manager else None,
)
return {"urls": [result["link"] for result in response_["selected"]]}
@classmethod
def from_llms(
cls,
link_selection_llm: BaseLanguageModel,
query_generation_llm: BaseLanguageModel,
*,
top_k_per_search: int = -1,
) -> GenericSearcher:
"""Initialize the searcher from a language model."""
link_selection_model = MultiSelectChain.from_default(llm=link_selection_llm)
query_generation_model = make_query_generator(query_generation_llm)
return cls(
link_selection_model=link_selection_model,
query_generator=query_generation_model,
top_k_per_search=top_k_per_search,
)

View File

@@ -0,0 +1,12 @@
import abc
from typing import List
from langchain.document_loaders.blob_loaders import Blob
class BlobCrawler(abc.ABC):
"""Crawl a blob and identify links to related content."""
@abc.abstractmethod
def crawl(self, blob: Blob, query: str) -> List[str]:
"""Explore the blob and identify links to relevant content."""

View File

@@ -0,0 +1,74 @@
"""Load and chunk HTMLs with potential pre-processing to clean the html."""
import re
from typing import Iterator, Tuple
from bs4 import BeautifulSoup
from langchain.document_loaders.base import BaseBlobParser
from langchain.document_loaders.blob_loaders import Blob
from langchain.schema import Document
# Regular expression pattern to detect multiple new lines in a row with optional
# whitespace in between
CONSECUTIVE_NEW_LINES = re.compile(r"\n(\s*\n)+", flags=re.UNICODE)
def _get_mini_html(html: str, *, tags_to_remove: Tuple[str, ...] = tuple()) -> str:
"""Clean up HTML tags."""
# Parse the HTML document using BeautifulSoup
soup = BeautifulSoup(html, "html.parser")
# Remove all CSS stylesheets
for stylesheet in soup.find_all("link", rel="stylesheet"):
stylesheet.extract()
for tag_to_remove in tags_to_remove:
# Remove all matching tags
for tag in soup.find_all(tag_to_remove):
tag.extract()
new_html = repr(soup)
return new_html
def _clean_html(html: str, *, tags_to_remove: Tuple[str, ...] = tuple()) -> str:
"""Clean up HTML and convert to markdown using markdownify."""
try:
import markdownify
except ImportError:
raise ImportError(
"The markdownify package is required to parse HTML files. "
"Please install it with `pip install markdownify`."
)
html = _get_mini_html(html, tags_to_remove=tags_to_remove)
md = markdownify.markdownify(html)
return CONSECUTIVE_NEW_LINES.sub("\n\n", md).strip()
## PUBLIC API
class MarkdownifyHTMLParser(BaseBlobParser):
"""A blob parser to parse HTML content.."""
def __init__(
self,
tags_to_remove: Tuple[str, ...] = ("svg", "img", "script", "style"),
) -> None:
"""Initialize the preprocessor.
Args:
tags_to_remove: A tuple of tags to remove from the HTML
"""
self.tags_to_remove = tags_to_remove
def lazy_parse(self, blob: Blob) -> Iterator[Document]:
"""Lazily parse the blob."""
yield Document(
page_content=_clean_html(
blob.as_string(), tags_to_remove=self.tags_to_remove
),
metadata={"source": blob.source},
)
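# Hedged usage sketch (not part of the original module); `html_string` is an
# assumed raw HTML string and markdownify must be installed:
#
#   parser = MarkdownifyHTMLParser()
#   blob = Blob(data=html_string, mimetype="text/html", path="https://example.com")
#   docs = list(parser.lazy_parse(blob))
#   docs[0].page_content  # markdown with svg/img/script/style content stripped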

View File

@@ -0,0 +1,32 @@
"""Tests for the downloader."""
import pytest
from langchain.chains.research.download import (
_is_javascript_required,
)
@pytest.mark.requires("lxml")
def test_is_javascript_required() -> None:
"""Check whether a given page should be re-downloaded with javascript executed."""
assert not _is_javascript_required(
"""
<html>
<body>
<p>Check whether javascript is required.</p>
</body>
</html>
"""
)
assert _is_javascript_required(
"""
<html>
<script>
console.log("Javascript is required.");
</script>
<body>
</body>
</html>
"""
)