Mirror of https://github.com/hwchase17/langchain.git (synced 2026-01-24 05:50:18 +00:00)
Compare commits: erick/anth ... eugene/res (38 commits)
Commits in this comparison: 27b87c64a0, c8f593b2c0, 16f41719a4, c62f497082, 38aef8c252, 1911388d9d, 37bdeb60fc, ece6b598c4, b641bde197, 68e39f7f26, 7adbcb195d, 307df3ebed, 2b42b9cb82, 13c82e6b66, f0e78d7efd, ad2b777536, f79582c548, 203cb5e307, 9c6accfa1a, 39a2d2511d, 679eb9f14f, f24e521015, 7e14388ba8, b7dabe8f50, f449d26083, 1f1db7c96a, 6a558c72a2, d381c4fad8, c3d260ffdc, 08b6c75743, 69cc8d7f73, fe89220aac, 0765679292, 95b3774479, ebfd539e43, de7ab6be16, 53a7e2b851, 65aad4c0bb
docs/modules/chains/examples/research.ipynb (new file, 469 lines)
@@ -0,0 +1,469 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "d0187afb-f460-431f-aee3-50d68bc33446",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Research Chain\n",
|
||||
"\n",
|
||||
"This is an experimental research chain that tries to answer \"researchy\" questions using information on the web.\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"For example, \n",
|
||||
"\n",
|
||||
"```\n",
|
||||
"Compile information about Albert Einstein.\n",
|
||||
"Ignore if it's a different Albert Einstein. \n",
|
||||
"Only include information you're certain about.\n",
|
||||
"\n",
|
||||
"Include:\n",
|
||||
"* education history\n",
|
||||
"* major contributions\n",
|
||||
"* names of spouse \n",
|
||||
"* date of birth\n",
|
||||
"* place of birth\n",
|
||||
"* a 3 sentence short biography\n",
|
||||
"\n",
|
||||
"Format your answer in a bullet point format for each sub-question.\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"Or replace `Albert Einstein` with another person of interest (e.g., John Smith of Boston).\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"The chain is composed of the following components:\n",
|
||||
"\n",
|
||||
"1. A searcher that searches for documents using a search engine.\n",
|
||||
" - The searcher is responsible to return a list of URLs of documents that\n",
|
||||
" may be relevant to read to be able to answer the question.\n",
|
||||
"2. A downloader that downloads the documents.\n",
|
||||
"3. An HTML to markdown parser (hard coded) that converts the HTML to markdown.\n",
|
||||
" * Conversion to markdown is lossy\n",
|
||||
" * However, it can significantly reduce the token count of the document\n",
|
||||
" * Markdown helps to preserve some styling information\n",
|
||||
" (e.g., bold, italics, links, headers) which is expected to help the reader\n",
|
||||
" to answer certain kinds of questions correctly.\n",
|
||||
"4. A reader that reads the documents and produces an answer.\n",
|
||||
"\n",
|
||||
"## Limitations\n",
|
||||
"\n",
|
||||
"* Quality of results depends on LLM used, and can be improved by providing more specialized parsers (e.g., parse only the body of articles).\n",
|
||||
"* If asking about people, provide enough information to disambiguate the person.\n",
|
||||
"* Content downloader may get blocked (e.g., if attempting to download from linkedin) -- may need to read terms of service / user agents appropriately.\n",
|
||||
"* Chain can be potentially long running (use initialization parameters to control how many options are explored) -- use async implementation as it uses more concurrency.\n",
|
||||
"* This research chain only implements a single hop at the moment; i.e.,\n",
|
||||
" it goes from the questions to a list of URLs to documents to compiling answers.\n",
|
||||
" Without continuing the crawl, web-sites that require pagnation will not be explored fully.\n",
|
||||
"* The reader chain must match the type of question. For example, the QA refine chain \n",
|
||||
" isn't good at extracting a list of entries from a long document.\n",
|
||||
" \n",
|
||||
"## Extending\n",
|
||||
"\n",
|
||||
"* Continue crawling documents to discover more relevant pages that were not surfaced by the search engine.\n",
|
||||
"* Adapt reading strategy based on nature of question.\n",
|
||||
"* Analyze the query and determine whether the query is a multi-hop query and change search/crawling strategy based on that.\n",
|
||||
"* Break components into tools that can be exposed to an agent. :)\n",
|
||||
"* Add cheaper strategies for selecting which links should be explored further (e.g., based on tf-idf similarity instead of gpt-4)\n",
|
||||
"* Add a summarization chain on top of the individually collected answers.\n",
|
||||
"* Improve strategy to ignore irrelevant information."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "4d937a38-66c6-4aa2-87bb-337101cfb112",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Requirements\n",
|
||||
"\n",
|
||||
"Please install: \n",
|
||||
"\n",
|
||||
"* `playwright` for fetching content from the web (or use the RequestsDownloadHandler)\n",
|
||||
"* `lxml` and `markdownify` for parsing HTMLs"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "7eb466b8-24fa-4acc-b0ce-06fcfa2fa9c4",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.chains.research.api import Research\n",
|
||||
"from langchain.chat_models import ChatOpenAI\n",
|
||||
"from langchain.llms import OpenAI\n",
|
||||
"from langchain.chains.question_answering import load_qa_chain\n",
|
||||
"from langchain.chains.research.download import PlaywrightDownloadHandler\n",
|
||||
"# If you don't have playwright installed, can experiment with requests\n",
|
||||
"# Be aware that some web-pages won't download properly as javascript won't be executed\n",
|
||||
"from langchain.chains.research.download import RequestsDownloadHandler "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "70474885-0acd-41b2-8050-15dd54f44f1e",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"question = \"\"\"\\\n",
|
||||
"Compile information about Albert Einstein.\n",
|
||||
"Ignore if it's a different Albert Einstein. \n",
|
||||
"Only include information you're certain about.\n",
|
||||
"\n",
|
||||
"Include:\n",
|
||||
"* education history\n",
|
||||
"* major contributions\n",
|
||||
"* names of spouse \n",
|
||||
"* date of birth\n",
|
||||
"* place of birth\n",
|
||||
"* a 3 sentence short biography\n",
|
||||
"\n",
|
||||
"Format your answer in a bullet point format for each sub-question.\n",
|
||||
"\"\"\".strip()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "6613da1c-3349-45f4-9770-19986750d548",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Instantiate LLMs"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "e74a44b4-2075-4cc6-933e-c769bf3f6002",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"llm = OpenAI(\n",
|
||||
" temperature=0, model=\"text-davinci-003\"\n",
|
||||
") # Used for the readers and the query generator\n",
|
||||
"selector_llm = ChatOpenAI(\n",
|
||||
" temperature=0, model=\"gpt-4\"\n",
|
||||
") # Used for selecting which links to explore"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "b243364a-e79b-432d-8035-3de8caf554a8",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Create a chain that can be used to extract the answer to the question above from a given document.\n",
|
||||
"\n",
|
||||
"This chain must be tailored to the task."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "2f538062-14e3-49ab-9b25-bc470eb5869c",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"qa_chain = load_qa_chain(llm, chain_type=\"refine\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "a96f3ed3-10de-4a85-9e93-a8b78d8bfbb6",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"research = Research.from_llms(\n",
|
||||
" query_generation_llm=llm,\n",
|
||||
" link_selection_llm=selector_llm,\n",
|
||||
" underlying_reader_chain=qa_chain,\n",
|
||||
" top_k_per_search=1,\n",
|
||||
" max_num_pages_per_doc=3,\n",
|
||||
" download_handler=PlaywrightDownloadHandler(),\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "3207c696-a72c-4378-b427-7d285f5fdd1c",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"results = await research.acall(inputs={\"question\": question})"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "843616b3-32d7-49c7-a42b-b0272d71f3ed",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"----------------------------------------------------------------------------------------------------------------------------------------------------------------\n",
|
||||
"https://en.wikipedia.org/wiki/Albert_Einstein\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"Albert Einstein:\n",
|
||||
"* Education history: Attended elementary school in Munich, Germany, and later attended the Swiss Federal Polytechnic School in Zurich, Switzerland.\n",
|
||||
"* Major contributions: Developed the theory of relativity, made major contributions to quantum theory, and won the Nobel Prize in Physics in 1921. He also published more than 300 scientific papers and 150 non-scientific works. He was also the first to propose the existence of black holes and gravitational waves. He was also a polyglot, speaking over 15 languages, including Afrikaans, Alemannisch, Amharic, Anarâškielâ, Angika, Old English, Abkhazian, Arabic, Aragonese, Western Armenian, Aromanian, Arpitan, Assamese, Asturian, Guarani, Aymara, Azerbaijani, South Azerbaijani, Balinese, Bambara, Bangla, Min Nan Chinese, Basa Banyumasan, Bashkir, Belarusian, Belarusian (Taraškievica orthography), Bhojpuri, Central Bikol, and Bulgarian.\n",
|
||||
"* Names of spouse: Married Mileva Marić in 1903 and Elsa Löwenthal in 1919\n",
|
||||
"----------------------------------------------------------------------------------------------------------------------------------------------------------------\n",
|
||||
"https://www.advergize.com/edu/7-albert-einstein-inventions-contributions/\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"Education History:\n",
|
||||
"* Attended Aargau Cantonal School in Switzerland from 1895-1896\n",
|
||||
"* Attended ETH Zurich from 1896-1900\n",
|
||||
"* Received a PhD from the University of Zurich in 1905\n",
|
||||
"\n",
|
||||
"Major Contributions:\n",
|
||||
"* Theory of Relativity\n",
|
||||
"* Photoelectric Effect\n",
|
||||
"* Brownian Motion\n",
|
||||
"* Bose-Einstein Condensate\n",
|
||||
"* Unified Field Theory\n",
|
||||
"* Quantum Theory of Light\n",
|
||||
"* E=mc2\n",
|
||||
"* Manhattan Project\n",
|
||||
"* Einstein’s Refrigerator\n",
|
||||
"* Sky is Blue\n",
|
||||
"* Quantum Theory of Light\n",
|
||||
"* Photoelectric Effect\n",
|
||||
"* Brownian Movement\n",
|
||||
"* Special Theory of Relativity\n",
|
||||
"* General Theory of Relativity\n",
|
||||
"* Manhattan Project\n",
|
||||
"* Einstein’s Refrigerator\n",
|
||||
"\n",
|
||||
"Names of Spouse:\n",
|
||||
"* Mileva Maric (1903-1919)\n",
|
||||
"* Elsa Löwenthal (1919-1936)\n",
|
||||
"\n",
|
||||
"Date of Birth: March 14, 1879\n",
|
||||
"\n",
|
||||
"Place of Birth: Ulm, Germany\n",
|
||||
"\n",
|
||||
"Short Biography:\n",
|
||||
"Albert Einstein was a German-born physicist who developed the theory of relativity. He is widely considered one of the most influential scientists of the 20th century and is known for his mass-energy equivalence formula\n",
|
||||
"----------------------------------------------------------------------------------------------------------------------------------------------------------------\n",
|
||||
"https://www.nobelprize.org/prizes/physics/1921/einstein/biographical/\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"Education History:\n",
|
||||
"* Attended Aargau Cantonal School in Aarau, Switzerland from 1895-1896\n",
|
||||
"* Attended ETH Zurich (Swiss Federal Institute of Technology) from 1896-1900\n",
|
||||
"* Obtained his doctorate degree from Swiss Federal Polytechnic School in Zurich in 1901\n",
|
||||
"\n",
|
||||
"Major Contributions:\n",
|
||||
"* Developed the theory of relativity\n",
|
||||
"* Developed the mass-energy equivalence formula (E=mc2)\n",
|
||||
"* Developed the law of the photoelectric effect\n",
|
||||
"* Postulated that the correct interpretation of the special theory of relativity must also furnish a theory of gravitation\n",
|
||||
"* Contributed to the problems of the theory of radiation and statistical mechanics\n",
|
||||
"* Investigated the thermal properties of light with a low radiation density and his observations laid the foundation of the photon theory of light\n",
|
||||
"* Contributed to statistical mechanics by his development of the quantum theory of a monatomic gas\n",
|
||||
"* Worked towards the unification of the basic concepts of physics, taking the opposite approach, geometrisation, to the majority of physicists\n",
|
||||
"\n",
|
||||
"Names of Spouse:\n",
|
||||
"* Mileva Marić (1903-1919)\n",
|
||||
"* Elsa Löwenthal (1919-1936)\n",
|
||||
"\n",
|
||||
"Date of Birth:\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"for doc in results[\"docs\"]:\n",
|
||||
" print(\"--\" * 80)\n",
|
||||
" print(doc.metadata[\"source\"])\n",
|
||||
" print(doc.page_content)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "edcaf496-3679-4b58-9baa-6124e1cc3435",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"If useful we can produce another summary!"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"id": "76136bff-b7df-4539-9bcb-760fc4449390",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"qa_chain = load_qa_chain(llm, chain_type=\"stuff\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"id": "b352f3f3-6777-4795-acbb-ed26ecac137d",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"summary = await qa_chain.acall(\n",
|
||||
" inputs={\"input_documents\": results[\"docs\"], \"question\": question}\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"id": "1c90c305-a89d-42e2-b975-dda039e816b6",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"Education History:\n",
|
||||
"* Attended Aargau Cantonal School in Aarau, Switzerland from 1895-1896\n",
|
||||
"* Attended ETH Zurich (Swiss Federal Institute of Technology) from 1896-1900\n",
|
||||
"* Obtained his doctorate degree from Swiss Federal Polytechnic School in Zurich in 1901\n",
|
||||
"\n",
|
||||
"Major Contributions:\n",
|
||||
"* Developed the theory of relativity\n",
|
||||
"* Developed the mass-energy equivalence formula (E=mc2)\n",
|
||||
"* Developed the law of the photoelectric effect\n",
|
||||
"* Postulated that the correct interpretation of the special theory of relativity must also furnish a theory of gravitation\n",
|
||||
"* Contributed to the problems of the theory of radiation and statistical mechanics\n",
|
||||
"* Investigated the thermal properties of light with a low radiation density and his observations laid the foundation of the photon theory of light\n",
|
||||
"* Contributed to statistical mechanics by his development of the quantum theory of a monatomic gas\n",
|
||||
"* Worked towards the unification of the basic concepts of physics, taking the opposite approach, geometrisation, to the majority of physicists\n",
|
||||
"\n",
|
||||
"Names of Spouse:\n",
|
||||
"* Mileva Marić (1903-1919)\n",
|
||||
"* Elsa Löwenthal (1919-1936)\n",
|
||||
"\n",
|
||||
"Date of Birth: March\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(summary[\"output_text\"])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "af2adfee-85d3-41af-900a-c594dc01ce16",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Under the hood"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "c307aa60-7e75-48e8-ba72-b45507ed3fe0",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"A searcher is invoked first to find URLs that are good to explore"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"id": "4e2c369c-8763-458e-9ab8-684466395890",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"{'question': \"Compile information about Albert Einstein.\\nIgnore if it's a different Albert Einstein. \\nOnly include information you're certain about.\\n\\nInclude:\\n* education history\\n* major contributions\\n* names of spouse \\n* date of birth\\n* place of birth\\n* a 3 sentence short biography\\n\\nFormat your answer in a bullet point format for each sub-question.\",\n",
|
||||
" 'urls': ['https://en.wikipedia.org/wiki/Albert_Einstein',\n",
|
||||
" 'https://www.britannica.com/biography/Albert-Einstein',\n",
|
||||
" 'https://www.advergize.com/edu/7-albert-einstein-inventions-contributions/',\n",
|
||||
" 'https://www.nobelprize.org/prizes/physics/1921/einstein/biographical/']}"
|
||||
]
|
||||
},
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"suggetions = await research.searcher.acall(inputs={\"question\": question})\n",
|
||||
"suggetions"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "24099529-2d3d-4eca-8a1f-45a2539c8842",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The webpages are downloaded"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"id": "c6f9d1b5-e513-4d8d-b325-32eacbee92b4",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"blobs = await research.downloader.adownload(suggetions[\"urls\"])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "58909ce0-fd1c-4e09-9d13-b43135ee8038",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The blobs are parsed with an HTML parser and read by the reader chain (not shown) -- see underlying code for details."
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.2"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
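For reference, the notebook above relies on Jupyter's top-level `await`. A minimal, hedged sketch of driving the same chain from a plain Python script (assuming the modules added in this diff are importable and that `OPENAI_API_KEY` and `SERPAPI_API_KEY` are set, since the searcher uses SerpAPI) might look like this:

```python
# Sketch only: mirrors the notebook flow outside Jupyter.
import asyncio

from langchain.chains.question_answering import load_qa_chain
from langchain.chains.research.api import Research
from langchain.chains.research.download import PlaywrightDownloadHandler
from langchain.chat_models import ChatOpenAI
from langchain.llms import OpenAI

llm = OpenAI(temperature=0, model="text-davinci-003")
selector_llm = ChatOpenAI(temperature=0, model="gpt-4")
qa_chain = load_qa_chain(llm, chain_type="refine")

research = Research.from_llms(
    query_generation_llm=llm,
    link_selection_llm=selector_llm,
    underlying_reader_chain=qa_chain,
    top_k_per_search=1,
    max_num_pages_per_doc=3,
    download_handler=PlaywrightDownloadHandler(),
)

question = "Compile information about Albert Einstein. Include education history and date of birth."

# Parallel reads are only implemented on the async path, so use acall via asyncio.run.
results = asyncio.run(research.acall(inputs={"question": question}))
for doc in results["docs"]:
    print(doc.metadata["source"])
    print(doc.page_content)
```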
langchain/chains/classification/__init__.py (new file, 0 lines)

langchain/chains/classification/multiselection.py (new file, 267 lines)
@@ -0,0 +1,267 @@
|
||||
"""Perform classification / selection using language models."""
|
||||
from __future__ import annotations
|
||||
|
||||
import csv
|
||||
from io import StringIO
|
||||
from itertools import islice
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
Iterable,
|
||||
Iterator,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
TypeVar,
|
||||
cast,
|
||||
)
|
||||
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
from langchain import LLMChain, PromptTemplate
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManagerForChainRun,
|
||||
)
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.schema import BaseOutputParser
|
||||
|
||||
MULTI_SELECT_TEMPLATE = """\
|
||||
Here is a table in CSV format:
|
||||
|
||||
{records}
|
||||
|
||||
---
|
||||
|
||||
question:
|
||||
|
||||
{question}
|
||||
|
||||
---
|
||||
|
||||
Output IDs of rows that answer the question or match the question.
|
||||
|
||||
For example, if row id 132 and id 133 are relevant, output: <ids>132,133</ids>
|
||||
|
||||
---
|
||||
|
||||
Begin:"""
|
||||
|
||||
|
||||
def _extract_content_from_tag(html: str, tag: str) -> List[str]:
|
||||
"""Extract content from the given tag."""
|
||||
soup = BeautifulSoup(html, "lxml")
|
||||
queries = []
|
||||
for query in soup.find_all(tag):
|
||||
queries.append(query.text)
|
||||
return queries
|
||||
|
||||
|
||||
class IDParser(BaseOutputParser[List[int]]):
|
||||
"""An output parser that extracts all IDs from the output."""
|
||||
|
||||
def parse(self, text: str) -> List[int]:
|
||||
"""Parse the text and return a list of IDs"""
|
||||
tags = _extract_content_from_tag(text, "ids")
|
||||
|
||||
if not tags:
|
||||
return []
|
||||
|
||||
if len(tags) > 1:
|
||||
# Fail if more than 1 tag group is identified
|
||||
return []
|
||||
|
||||
tag = tags[0]
|
||||
ids = tag.split(",")
|
||||
|
||||
finalized_ids = []
|
||||
for idx in ids:
|
||||
if idx.isdigit():
|
||||
finalized_ids.append(int(idx))
|
||||
return finalized_ids
|
||||
|
||||
|
||||
def _write_records_to_string(
|
||||
records: Sequence[Mapping[str, Any]],
|
||||
*,
|
||||
columns: Optional[Sequence[str]] = None,
|
||||
delimiter: str = "|",
|
||||
) -> str:
|
||||
"""Write records to a CSV string.
|
||||
|
||||
Args:
|
||||
records: a list of records, assumes that all records have all keys
|
||||
columns: a list of columns to include in the CSV
|
||||
delimiter: the delimiter to use
|
||||
|
||||
Returns:
|
||||
a CSV string
|
||||
"""
|
||||
buffer = StringIO()
|
||||
if columns is None:
|
||||
existing_columns: Set[str] = set()
|
||||
for record in records:
|
||||
existing_columns.update(record.keys())
|
||||
_columns: Sequence[str] = sorted(existing_columns)
|
||||
else:
|
||||
_columns = columns
|
||||
|
||||
# Make sure the id column is always first
|
||||
_columns_with_id_first = list(_columns)
|
||||
|
||||
if "id" in _columns_with_id_first:
|
||||
_columns_with_id_first.remove("id")
|
||||
|
||||
# Make sure the `id` column is always first
|
||||
_columns_with_id_first.insert(0, "id")
|
||||
|
||||
writer = csv.DictWriter(
|
||||
buffer,
|
||||
fieldnames=_columns_with_id_first,
|
||||
delimiter=delimiter,
|
||||
)
|
||||
writer.writeheader()
|
||||
writer.writerows(records)
|
||||
buffer.seek(0)
|
||||
return buffer.getvalue()
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def _batch(iterable: Iterable[T], size: int) -> Iterator[List[T]]:
|
||||
"""Batch an iterable into chunks of size `size`.
|
||||
|
||||
Args:
|
||||
iterable: the iterable to batch
|
||||
size: the size of each batch
|
||||
|
||||
Returns:
|
||||
iterator over batches of size `size` except for last batch which will be up
|
||||
to size `size`
|
||||
"""
|
||||
iterator = iter(iterable)
|
||||
while True:
|
||||
batch = list(islice(iterator, size))
|
||||
if not batch:
|
||||
return
|
||||
yield batch
|
||||
|
||||
|
||||
class MultiSelectChain(Chain):
|
||||
"""A chain that performs multi-selection from a list of choices."""
|
||||
|
||||
llm_chain: LLMChain
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
"""Return the input keys."""
|
||||
return ["question", "choices"]
|
||||
|
||||
@property
|
||||
def output_keys(self) -> List[str]:
|
||||
"""Return the output keys."""
|
||||
return ["selected"]
|
||||
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Run the chain."""
|
||||
choices = inputs["choices"]
|
||||
question = inputs["question"]
|
||||
columns = inputs.get("columns", None)
|
||||
|
||||
selected: List[Mapping[str, Any]] = []
|
||||
# TODO(): Balance choices into equal batches with constraint dependent
|
||||
# on context window and prompt
|
||||
max_choices = 30
|
||||
|
||||
for choice_batch in _batch(choices, max_choices):
|
||||
records_with_ids = [
|
||||
{**record, "id": idx} for idx, record in enumerate(choice_batch)
|
||||
]
|
||||
records_str = _write_records_to_string(
|
||||
records_with_ids, columns=columns, delimiter="|"
|
||||
)
|
||||
|
||||
indexes = cast(
|
||||
List[int],
|
||||
self.llm_chain.predict_and_parse(
|
||||
records=records_str,
|
||||
question=question,
|
||||
callbacks=run_manager.get_child() if run_manager else None,
|
||||
),
|
||||
)
|
||||
valid_indexes = [idx for idx in indexes if 0 <= idx < len(choice_batch)]
|
||||
selected.extend(choice_batch[i] for i in valid_indexes)
|
||||
|
||||
return {
|
||||
"selected": selected,
|
||||
}
|
||||
|
||||
async def _acall(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, Any]:
|
||||
choices = inputs["choices"]
|
||||
question = inputs["question"]
|
||||
columns = inputs.get("columns", None)
|
||||
|
||||
selected: List[Mapping[str, Any]] = []
|
||||
# TODO(): Balance choices into equal batches with constraint dependent
|
||||
# on context window and prompt
|
||||
max_choices = 30
|
||||
|
||||
for choice_batch in _batch(choices, max_choices):
|
||||
records_with_ids = [
|
||||
{**record, "id": idx} for idx, record in enumerate(choice_batch)
|
||||
]
|
||||
records_str = _write_records_to_string(
|
||||
records_with_ids, columns=columns, delimiter="|"
|
||||
)
|
||||
|
||||
indexes = cast(
|
||||
List[int],
|
||||
await self.llm_chain.apredict_and_parse(
|
||||
records=records_str,
|
||||
question=question,
|
||||
callbacks=run_manager.get_child() if run_manager else None,
|
||||
),
|
||||
)
|
||||
valid_indexes = [idx for idx in indexes if 0 <= idx < len(choice_batch)]
|
||||
selected.extend(choice_batch[i] for i in valid_indexes)
|
||||
|
||||
return {
|
||||
"selected": selected,
|
||||
}
|
||||
|
||||
@property
|
||||
def _chain_type(self) -> str:
|
||||
"""Return the chain type."""
|
||||
return "multilabel_binary_classifier"
|
||||
|
||||
@classmethod
|
||||
def from_default(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
*,
|
||||
prompt: str = MULTI_SELECT_TEMPLATE,
|
||||
parser: BaseOutputParser = IDParser(),
|
||||
) -> MultiSelectChain:
|
||||
"""Provide a multilabel binary classifier."""
|
||||
prompt_template = PromptTemplate.from_template(prompt, output_parser=parser)
|
||||
if set(prompt_template.input_variables) != {"question", "records"}:
|
||||
raise ValueError("Prompt must contain only {question} and {records}")
|
||||
|
||||
return cls(
|
||||
llm_chain=LLMChain(
|
||||
llm=llm,
|
||||
prompt=prompt_template,
|
||||
)
|
||||
)
|
||||
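The `MultiSelectChain` above can also be exercised on its own. A short sketch based on the `input_keys`/`output_keys` declared in the class; the question and records here are made up for illustration:

```python
# Sketch only: standalone use of MultiSelectChain with hypothetical data.
from langchain.chains.classification.multiselection import MultiSelectChain
from langchain.chat_models import ChatOpenAI

chain = MultiSelectChain.from_default(ChatOpenAI(temperature=0))

choices = [
    {"title": "Albert Einstein - Wikipedia", "link": "https://en.wikipedia.org/wiki/Albert_Einstein"},
    {"title": "Einstein Bros. Bagels", "link": "https://www.einsteinbros.com/"},
]

result = chain(
    {
        "question": "Which rows are about the physicist Albert Einstein?",
        "choices": choices,
        # "columns" is optional; when omitted, all keys found in the records are used.
    }
)
print(result["selected"])  # subset of `choices` picked by the LLM
```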
langchain/chains/research/__init__.py (new file, 0 lines)

langchain/chains/research/api.py (new file, 200 lines)
@@ -0,0 +1,200 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import itertools
|
||||
from typing import Any, Dict, List, Literal, Mapping, Optional, Union
|
||||
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManagerForChainRun,
|
||||
)
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.research.download import AutoDownloadHandler, DownloadHandler
|
||||
from langchain.chains.research.readers import DocReadingChain, ParallelApplyChain
|
||||
from langchain.chains.research.search import GenericSearcher
|
||||
from langchain.document_loaders.parsers.html.markdownify import MarkdownifyHTMLParser
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
|
||||
|
||||
|
||||
class Research(Chain):
|
||||
"""A research chain.
|
||||
|
||||
A research chain is composed of the following components:
|
||||
|
||||
1. A searcher that searches for documents using a search engine.
|
||||
- The searcher is responsible for returning a list of URLs of documents that
|
||||
may be relevant to read to be able to answer the question.
|
||||
2. A downloader that downloads the documents.
|
||||
3. An HTML to markdown parser (hard coded) that converts the HTML to markdown.
|
||||
* Conversion to markdown is lossy
|
||||
* However, it can significantly reduce the token count of the document
|
||||
* Markdown helps to preserve some styling information
|
||||
(e.g., bold, italics, links, headers) which is expected to help the reader
|
||||
to answer certain kinds of questions correctly.
|
||||
4. A reader that reads the documents and produces an answer.
|
||||
|
||||
Limitations:
|
||||
* This research chain only implements a single hop at the moment; i.e.,
|
||||
it goes from the questions to a list of URLs to documents to compiling
|
||||
answers.
|
||||
* The reader chain needs to match the task. For example, if using a QA refine
|
||||
chain, a task of collecting a list of entries from a long document will
|
||||
fail because the QA refine chain is not designed to handle such a task.
|
||||
|
||||
The chain can be extended to continue crawling the documents in an attempt
|
||||
to discover relevant pages that were not surfaced by the search engine.
|
||||
|
||||
Among other problems, without continuing the crawl it is impossible to
|
||||
continue getting results from pages that involve pagination.
|
||||
"""
|
||||
|
||||
searcher: GenericSearcher
|
||||
"""The searcher to use to search for documents."""
|
||||
reader: Chain
|
||||
"""The reader to use to read documents and produce an answer."""
|
||||
downloader: DownloadHandler
|
||||
"""The downloader to use to download the documents.
|
||||
|
||||
A few different implementations of the download handler have been provided.
|
||||
|
||||
Keep in mind that some websites require execution of JavaScript to load
|
||||
the DOM.
|
||||
"""
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
"""Return the input keys."""
|
||||
return ["question"]
|
||||
|
||||
@property
|
||||
def output_keys(self) -> List[str]:
|
||||
"""Return the output keys."""
|
||||
return ["docs"]
|
||||
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Run the chain synchronously."""
|
||||
question = inputs["question"]
|
||||
search_results = self.searcher(
|
||||
{"question": question},
|
||||
callbacks=run_manager.get_child() if run_manager else None,
|
||||
)
|
||||
urls = search_results["urls"]
|
||||
blobs = self.downloader.download(urls)
|
||||
parser = MarkdownifyHTMLParser()
|
||||
docs = itertools.chain.from_iterable(
|
||||
parser.lazy_parse(blob) for blob in blobs if blob is not None
|
||||
)
|
||||
_inputs = [{"doc": doc, "question": question} for doc in docs]
|
||||
results = self.reader(
|
||||
_inputs, callbacks=run_manager.get_child() if run_manager else None
|
||||
)
|
||||
return {
|
||||
"docs": [result["answer"] for result in results["inputs"]],
|
||||
}
|
||||
|
||||
async def _acall(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Run the chain asynchronously."""
|
||||
question = inputs["question"]
|
||||
search_results = await self.searcher.acall(
|
||||
{"question": question},
|
||||
callbacks=run_manager.get_child() if run_manager else None,
|
||||
)
|
||||
urls = search_results["urls"]
|
||||
blobs = await self.downloader.adownload(urls)
|
||||
parser = MarkdownifyHTMLParser()
|
||||
docs = itertools.chain.from_iterable(
|
||||
parser.lazy_parse(blob) for blob in blobs if blob is not None
|
||||
)
|
||||
_inputs = [{"doc": doc, "question": question} for doc in docs]
|
||||
results = await self.reader.acall(
|
||||
_inputs,
|
||||
callbacks=run_manager.get_child() if run_manager else None,
|
||||
)
|
||||
return {
|
||||
"docs": [result["answer"] for result in results["results"]],
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_llms(
|
||||
cls,
|
||||
*,
|
||||
query_generation_llm: BaseLanguageModel,
|
||||
link_selection_llm: BaseLanguageModel,
|
||||
underlying_reader_chain: LLMChain,
|
||||
top_k_per_search: int = -1,
|
||||
max_concurrency: int = 1,
|
||||
max_num_pages_per_doc: int = 5,
|
||||
text_splitter: Union[TextSplitter, Literal["recursive"]] = "recursive",
|
||||
download_handler: Union[DownloadHandler, Literal["auto"]] = "auto",
|
||||
text_splitter_kwargs: Optional[Mapping[str, Any]] = None,
|
||||
) -> Research:
|
||||
"""Helper to create a research chain from standard llm related components.
|
||||
|
||||
Args:
|
||||
query_generation_llm: The language model to use for query generation.
|
||||
link_selection_llm: The language model to use for link selection.
|
||||
underlying_reader_chain: The chain to use to answer the question.
|
||||
top_k_per_search: The number of documents to return per search.
|
||||
max_concurrency: The maximum number of concurrent reads.
|
||||
max_num_pages_per_doc: The maximum number of pages to read per document.
|
||||
text_splitter: The text splitter to use to split the document into
|
||||
smaller chunks.
|
||||
download_handler: The download handler to use to download the documents.
|
||||
Provide either a download handler or the name of a
|
||||
download handler.
|
||||
- "auto" swaps between using requests and playwright
|
||||
text_splitter_kwargs: The keyword arguments to pass to the text splitter.
|
||||
Only use when providing a text splitter as string.
|
||||
|
||||
Returns:
|
||||
A research chain.
|
||||
"""
|
||||
if isinstance(text_splitter, str):
|
||||
if text_splitter == "recursive":
|
||||
_text_splitter_kwargs = text_splitter_kwargs or {}
|
||||
_text_splitter: TextSplitter = RecursiveCharacterTextSplitter(
|
||||
**_text_splitter_kwargs
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid text splitter: {text_splitter}")
|
||||
elif isinstance(text_splitter, TextSplitter):
|
||||
_text_splitter = text_splitter
|
||||
else:
|
||||
raise TypeError(f"Invalid text splitter: {type(text_splitter)}")
|
||||
|
||||
if isinstance(download_handler, str):
|
||||
if download_handler == "auto":
|
||||
_download_handler: DownloadHandler = AutoDownloadHandler()
|
||||
else:
|
||||
raise ValueError(f"Invalid download handler: {download_handler}")
|
||||
elif isinstance(download_handler, DownloadHandler):
|
||||
_download_handler = download_handler
|
||||
else:
|
||||
raise TypeError(f"Invalid download handler: {type(download_handler)}")
|
||||
|
||||
searcher = GenericSearcher.from_llms(
|
||||
link_selection_llm,
|
||||
query_generation_llm,
|
||||
top_k_per_search=top_k_per_search,
|
||||
)
|
||||
|
||||
doc_reading_chain = DocReadingChain(
|
||||
chain=underlying_reader_chain,
|
||||
max_num_docs=max_num_pages_per_doc,
|
||||
text_splitter=_text_splitter,
|
||||
)
|
||||
# Can read multiple documents in parallel
|
||||
multi_reader = ParallelApplyChain(
|
||||
chain=doc_reading_chain,
|
||||
max_concurrency=max_concurrency,
|
||||
)
|
||||
return cls(searcher=searcher, reader=multi_reader, downloader=_download_handler)
|
||||
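The `from_llms` factory above exposes a few knobs worth highlighting. A hedged sketch of an alternative configuration that swaps in an explicit text splitter and the automatic download handler; the parameter values are illustrative, not recommendations:

```python
# Sketch only: alternative Research configuration.
from langchain.chains.question_answering import load_qa_chain
from langchain.chains.research.api import Research
from langchain.chat_models import ChatOpenAI
from langchain.llms import OpenAI
from langchain.text_splitter import RecursiveCharacterTextSplitter

llm = OpenAI(temperature=0)
reader_chain = load_qa_chain(llm, chain_type="refine")

research = Research.from_llms(
    query_generation_llm=llm,
    link_selection_llm=ChatOpenAI(temperature=0, model="gpt-4"),
    underlying_reader_chain=reader_chain,
    top_k_per_search=2,        # keep the 2 best hits per generated query
    max_num_pages_per_doc=5,   # read at most 5 chunks of each document
    text_splitter=RecursiveCharacterTextSplitter(chunk_size=2000, chunk_overlap=200),
    download_handler="auto",   # requests first, playwright when javascript is needed
)
```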
langchain/chains/research/crawling.py (new file, 140 lines)
@@ -0,0 +1,140 @@
|
||||
"""Module contains code for crawling a blob (e.g., HTML file) for links.
|
||||
|
||||
The main idea behind the crawling module is to identify additional links
|
||||
that are worth exploring to find more documents that may be relevant for being
|
||||
able to answer the question correctly.
|
||||
"""
|
||||
import urllib.parse
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from bs4 import BeautifulSoup, PageElement
|
||||
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.chains.classification.multiselection import MultiSelectChain
|
||||
from langchain.chains.research.typedefs import BlobCrawler
|
||||
from langchain.document_loaders.base import BaseBlobParser
|
||||
from langchain.document_loaders.blob_loaders import Blob
|
||||
from langchain.document_loaders.parsers.html.markdownify import MarkdownifyHTMLParser
|
||||
|
||||
|
||||
def _get_ahref_snippets(html: str, num_chars: int = 0) -> Dict[str, Any]:
|
||||
"""Get a list of <a> tags as snippets from the given html.
|
||||
|
||||
Args:
|
||||
html: the html to get snippets from.
|
||||
num_chars: the number of characters to get around the <a> tags.
|
||||
|
||||
Returns:
|
||||
a list of snippets.
|
||||
"""
|
||||
soup = BeautifulSoup(html, "html.parser")
|
||||
title = soup.title.string.strip() if soup.title and soup.title.string else ""
|
||||
snippets = []
|
||||
|
||||
for idx, a_tag in enumerate(soup.find_all("a")):
|
||||
before_text = _get_surrounding_text(a_tag, num_chars, is_before=True)
|
||||
after_text = _get_surrounding_text(a_tag, num_chars, is_before=False)
|
||||
snippet = {
|
||||
"id": idx,
|
||||
"before": before_text.strip().replace("\n", " "),
|
||||
"link": a_tag.get("href").replace("\n", " ").strip(),
|
||||
"content": a_tag.text.replace("\n", " ").strip(),
|
||||
"after": after_text.strip().replace("\n", " "),
|
||||
}
|
||||
snippets.append(snippet)
|
||||
|
||||
return {
|
||||
"snippets": snippets,
|
||||
"title": title,
|
||||
}
|
||||
|
||||
|
||||
def _extract_records(blob: Blob) -> Tuple[List[Dict[str, Any]], Tuple[str, ...]]:
|
||||
"""Extract records from a blob."""
|
||||
if blob.mimetype == "text/html":
|
||||
info = _get_ahref_snippets(blob.as_string(), num_chars=100)
|
||||
return (
|
||||
[
|
||||
{
|
||||
"content": d["content"],
|
||||
"link": d["link"],
|
||||
"before": d["before"],
|
||||
"after": d["after"],
|
||||
}
|
||||
for d in info["snippets"]
|
||||
],
|
||||
("link", "content", "before", "after"),
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Can only extract records from HTML/JSON blobs. Got {blob.mimetype}"
|
||||
)
|
||||
|
||||
|
||||
class ChainCrawler(BlobCrawler):
|
||||
def __init__(self, chain: MultiSelectChain, parser: BaseBlobParser) -> None:
|
||||
"""Crawl the blob using an LLM."""
|
||||
self.chain = chain
|
||||
self.parser = parser
|
||||
|
||||
def crawl(self, blob: Blob, question: str) -> List[str]:
|
||||
"""Explore the blob and suggest additional content to explore."""
|
||||
if not blob.source:
|
||||
raise NotImplementedError()
|
||||
records, columns = _extract_records(blob)
|
||||
|
||||
result = self.chain(
|
||||
inputs={"question": question, "choices": records, "columns": columns},
|
||||
)
|
||||
|
||||
selected_records = result["selected"]
|
||||
|
||||
urls = [
|
||||
# TODO(): handle absolute links
|
||||
urllib.parse.urljoin(blob.source, record["link"])
|
||||
for record in selected_records
|
||||
if "mailto:" not in record["link"]
|
||||
]
|
||||
return urls
|
||||
|
||||
@classmethod
|
||||
def from_default(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
blob_parser: BaseBlobParser = MarkdownifyHTMLParser(),
|
||||
) -> "ChainCrawler":
|
||||
"""Create a crawler from the default LLM."""
|
||||
chain = MultiSelectChain.from_default(llm)
|
||||
return cls(chain=chain, parser=blob_parser)
|
||||
|
||||
|
||||
def _get_surrounding_text(tag: PageElement, n: int, *, is_before: bool = True) -> str:
|
||||
"""Get surrounding text the given tag in the given direction.
|
||||
|
||||
Args:
|
||||
tag: the tag to get surrounding text for.
|
||||
n: number of characters to get
|
||||
is_before: Whether to get text before or after the tag.
|
||||
|
||||
Returns:
|
||||
the surrounding text in the given direction.
|
||||
"""
|
||||
text = ""
|
||||
current = tag.previous_element if is_before else tag.next_element
|
||||
|
||||
while current and len(text) < n:
|
||||
current_text = str(current.text).strip()
|
||||
current_text = (
|
||||
current_text
|
||||
if len(current_text) + len(text) <= n
|
||||
else current_text[: n - len(text)]
|
||||
)
|
||||
|
||||
if is_before:
|
||||
text = current_text + " " + text
|
||||
else:
|
||||
text = text + " " + current_text
|
||||
|
||||
current = current.previous_element if is_before else current.next_element
|
||||
|
||||
return text
|
||||
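`ChainCrawler` is not wired into the `Research` chain yet, but it can be used directly to pick follow-up links out of a downloaded page. A sketch under the assumption that the blob was produced by one of the download handlers below, so it carries a `text/html` mimetype and a source path; the page content and question are made up:

```python
# Sketch only: selecting follow-up links from a downloaded HTML blob.
from langchain.chains.research.crawling import ChainCrawler
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders.blob_loaders import Blob

crawler = ChainCrawler.from_default(ChatOpenAI(temperature=0, model="gpt-4"))

html = "<html><title>Albert Einstein</title><body><a href='/wiki/Nobel_Prize'>Nobel Prize</a></body></html>"
blob = Blob(data=html, mimetype="text/html", path="https://en.wikipedia.org/wiki/Albert_Einstein")

# Returns absolute URLs (relative hrefs are joined against the blob's source).
urls = crawler.crawl(blob, question="What prizes did Albert Einstein win?")
print(urls)
```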
langchain/chains/research/download.py (new file, 175 lines)
@@ -0,0 +1,175 @@
|
||||
"""Module contains code for fetching documents from the web using playwright.
|
||||
|
||||
This module currently re-uses the code from the `web_base` module to avoid
|
||||
re-implementing rate limiting behavior.
|
||||
|
||||
The module contains downloading interfaces.
|
||||
|
||||
Sub-classing with the given interface should allow a user to add url based
|
||||
user-agents and authentication if needed.
|
||||
|
||||
Downloading is batched by default to allow efficient parallelization.
|
||||
"""
|
||||
|
||||
import abc
|
||||
import asyncio
|
||||
import mimetypes
|
||||
from typing import Any, List, Optional, Sequence
|
||||
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
from langchain.document_loaders import WebBaseLoader
|
||||
from langchain.document_loaders.blob_loaders import Blob
|
||||
|
||||
MaybeBlob = Optional[Blob]
|
||||
|
||||
|
||||
def _is_javascript_required(html_content: str) -> bool:
|
||||
"""Heuristic to determine whether javascript execution is required.
|
||||
|
||||
Args:
|
||||
html_content (str): The HTML content to check.
|
||||
|
||||
Returns:
|
||||
bool: True if javascript execution is required, False otherwise.
|
||||
"""
|
||||
# Parse the HTML content using BeautifulSoup
|
||||
soup = BeautifulSoup(html_content, "lxml")
|
||||
|
||||
# Count the number of HTML elements
|
||||
body = soup.body
|
||||
if not body:
|
||||
return True
|
||||
num_elements = len(body.find_all())
|
||||
requires_javascript = num_elements < 1
|
||||
return requires_javascript
|
||||
|
||||
|
||||
class DownloadHandler(abc.ABC):
|
||||
def download(self, urls: Sequence[str]) -> List[MaybeBlob]:
|
||||
"""Download a batch of URLs synchronously."""
|
||||
raise NotImplementedError()
|
||||
|
||||
async def adownload(self, urls: Sequence[str]) -> List[MaybeBlob]:
|
||||
"""Download a batch of URLs asynchronously."""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class PlaywrightDownloadHandler(DownloadHandler):
|
||||
"""Download URLS using playwright.
|
||||
|
||||
This is an implementation of the download handler that uses playwright to download
|
||||
urls. This is useful for downloading urls that require javascript to be executed.
|
||||
"""
|
||||
|
||||
def __init__(self, timeout: int = 5) -> None:
|
||||
"""Initialize the download handler.
|
||||
|
||||
Args:
|
||||
timeout: The timeout in seconds to wait for a page to load.
|
||||
"""
|
||||
self.timeout = timeout
|
||||
|
||||
def download(self, urls: Sequence[str]) -> List[MaybeBlob]:
|
||||
"""Download list of urls synchronously."""
|
||||
return asyncio.run(self.adownload(urls))
|
||||
|
||||
async def _download(self, browser: Any, url: str) -> Optional[str]:
|
||||
"""Download a url asynchronously using playwright."""
|
||||
from playwright.async_api import TimeoutError
|
||||
|
||||
page = await browser.new_page()
|
||||
try:
|
||||
# Wait up to self.timeout seconds for the page to load.
|
||||
await page.goto(url, wait_until="networkidle", timeout=self.timeout * 1000)
|
||||
html_content = await page.content()
|
||||
except TimeoutError:
|
||||
html_content = None
|
||||
return html_content
|
||||
|
||||
async def adownload(self, urls: Sequence[str]) -> List[MaybeBlob]:
|
||||
"""Download a batch of URLs asynchronously using playwright.
|
||||
|
||||
Args:
|
||||
urls: The urls to download.
|
||||
|
||||
Returns:
|
||||
list of blobs containing the downloaded content.
|
||||
"""
|
||||
from playwright.async_api import async_playwright
|
||||
|
||||
async with async_playwright() as p:
|
||||
browser = await p.chromium.launch()
|
||||
tasks = [self._download(browser, url) for url in urls]
|
||||
contents = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
await browser.close()
|
||||
|
||||
return _repackage_as_blobs(urls, contents)
|
||||
|
||||
|
||||
class RequestsDownloadHandler(DownloadHandler):
|
||||
def __init__(self, web_downloader: Optional[WebBaseLoader] = None) -> None:
|
||||
"""Initialize the requests download handler."""
|
||||
self.web_downloader = web_downloader or WebBaseLoader(web_path=[])
|
||||
|
||||
def download(self, urls: Sequence[str]) -> List[MaybeBlob]:
|
||||
"""Download a batch of URLS synchronously."""
|
||||
return asyncio.run(self.adownload(urls))
|
||||
|
||||
async def adownload(self, urls: Sequence[str]) -> List[MaybeBlob]:
|
||||
"""Download a batch of urls asynchronously using playwright."""
|
||||
contents = await self.web_downloader.fetch_all(list(urls))
|
||||
return _repackage_as_blobs(urls, contents)
|
||||
|
||||
|
||||
def _repackage_as_blobs(
|
||||
urls: Sequence[str], contents: Sequence[Optional[str]]
|
||||
) -> List[MaybeBlob]:
|
||||
"""Repackage the contents as blobs."""
|
||||
blobs: List[MaybeBlob] = []
|
||||
for url, content in zip(urls, contents):
|
||||
mimetype = mimetypes.guess_type(url)[0]
|
||||
if content is None:
|
||||
blobs.append(None)
|
||||
else:
|
||||
blobs.append(Blob(data=content or "", mimetype=mimetype, path=url))
|
||||
|
||||
return blobs
|
||||
|
||||
|
||||
class AutoDownloadHandler(DownloadHandler):
|
||||
"""Download URLs using the requests library if possible.
|
||||
|
||||
Fallback to using playwright if javascript is required.
|
||||
"""
|
||||
|
||||
def __init__(self, web_downloader: Optional[WebBaseLoader] = None) -> None:
|
||||
"""Initialize the auto download handler."""
|
||||
self.requests_downloader = RequestsDownloadHandler(
|
||||
web_downloader or WebBaseLoader(web_path=[])
|
||||
)
|
||||
self.playwright_downloader = PlaywrightDownloadHandler()
|
||||
|
||||
async def adownload(self, urls: Sequence[str]) -> List[MaybeBlob]:
|
||||
"""Download a batch of urls asynchronously using playwright."""
|
||||
# First try downloading everything with the requests-based handler.
|
||||
blobs = await self.requests_downloader.adownload(urls)
|
||||
|
||||
# Check if javascript is required
|
||||
must_redownload = [
|
||||
(idx, url)
|
||||
for idx, (url, blob) in enumerate(zip(urls, blobs))
|
||||
if blob is not None and _is_javascript_required(blob.as_string())
|
||||
]
|
||||
if must_redownload:
|
||||
indexes, urls_to_redownload = zip(*must_redownload)
|
||||
new_blobs = await self.playwright_downloader.adownload(urls_to_redownload)
|
||||
|
||||
for idx, blob in zip(indexes, new_blobs):
|
||||
blobs[idx] = blob
|
||||
return blobs
|
||||
|
||||
def download(self, urls: Sequence[str]) -> List[MaybeBlob]:
|
||||
"""Download a batch of URLs synchronously."""
|
||||
return asyncio.run(self.adownload(urls))
|
||||
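A quick hedged sketch of the download handlers in isolation (the URLs are illustrative); each handler returns one `Blob` per URL, or `None` where the download failed:

```python
# Sketch only: downloading a couple of pages with the auto handler.
from langchain.chains.research.download import AutoDownloadHandler

handler = AutoDownloadHandler()
urls = [
    "https://en.wikipedia.org/wiki/Albert_Einstein",
    "https://www.nobelprize.org/prizes/physics/1921/einstein/biographical/",
]

# Synchronous wrapper around the async path; inside a running event loop use
# `await handler.adownload(urls)` instead.
blobs = handler.download(urls)
for url, blob in zip(urls, blobs):
    if blob is None:
        print(f"failed: {url}")
    else:
        print(url, blob.mimetype, len(blob.as_string()))
```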
langchain/chains/research/readers.py (new file, 151 lines)
@@ -0,0 +1,151 @@
|
||||
"""Module contains supporting chains for research use case."""
|
||||
import asyncio
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManagerForChainRun,
|
||||
)
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.schema import Document
|
||||
from langchain.text_splitter import TextSplitter
|
||||
|
||||
|
||||
class DocReadingChain(Chain):
|
||||
"""A reader chain should use one of the QA chains to answer a question.
|
||||
|
||||
This chain is also responsible for splitting the document into smaller chunks
|
||||
and then passing the chunks to an underlying QA chain.
|
||||
|
||||
A brute force chain that reads an entire document (or the first N pages).
|
||||
"""
|
||||
|
||||
chain: Chain
|
||||
"""The chain to use to answer the question."""
|
||||
|
||||
text_splitter: TextSplitter
|
||||
"""The text splitter to use to split the document into smaller chunks."""
|
||||
|
||||
max_num_docs: int
|
||||
"""The maximum number of documents to split the document into.
|
||||
|
||||
Use -1 to denote no limit to the number of pages to read.
|
||||
"""
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
"""Return the input keys."""
|
||||
return ["doc", "question"]
|
||||
|
||||
@property
|
||||
def output_keys(self) -> List[str]:
|
||||
"""Return the output keys."""
|
||||
return ["answer"]
|
||||
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Process a long document synchronously."""
|
||||
source_document = inputs["doc"]
|
||||
question = inputs["question"]
|
||||
|
||||
sub_docs = self.text_splitter.split_documents([source_document])
|
||||
if self.max_num_docs > 0:
|
||||
_sub_docs = sub_docs[: self.max_num_docs]
|
||||
else:
|
||||
_sub_docs = sub_docs
|
||||
|
||||
response = self.chain(
|
||||
{"input_documents": _sub_docs, "question": question},
|
||||
callbacks=run_manager.get_child() if run_manager else None,
|
||||
)
|
||||
summary_doc = Document(
|
||||
page_content=response["output_text"],
|
||||
metadata=source_document.metadata,
|
||||
)
|
||||
return {"answer": summary_doc}
|
||||
|
||||
async def _acall(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Process a long document asynchronously."""
|
||||
source_document = inputs["doc"]
|
||||
question = inputs["question"]
|
||||
sub_docs = self.text_splitter.split_documents([source_document])
|
||||
if self.max_num_docs > 0:
|
||||
_sub_docs = sub_docs[: self.max_num_docs]
|
||||
else:
|
||||
_sub_docs = sub_docs
|
||||
|
||||
results = await self.chain.acall(
|
||||
{"input_documents": _sub_docs, "question": question},
|
||||
callbacks=run_manager.get_child() if run_manager else None,
|
||||
)
|
||||
summary_doc = Document(
|
||||
page_content=results["output_text"],
|
||||
metadata=source_document.metadata,
|
||||
)
|
||||
|
||||
return {"answer": summary_doc}
|
||||
|
||||
|
||||
class ParallelApplyChain(Chain):
|
||||
"""Utility chain to apply a given chain in parallel across input documents.
|
||||
|
||||
This chain needs to handle a limit on concurrency.
|
||||
|
||||
WARNING: Parallelization only implemented on the async path.
|
||||
"""
|
||||
|
||||
chain: Chain
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
"""Return the input keys."""
|
||||
return ["inputs"]
|
||||
|
||||
@property
|
||||
def output_keys(self) -> List[str]:
|
||||
"""Return the output keys."""
|
||||
return ["results"]
|
||||
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Run the chain."""
|
||||
# TODO(): parallelize this
|
||||
chain_inputs = inputs["inputs"]
|
||||
|
||||
results = [
|
||||
self.chain(
|
||||
chain_input,
|
||||
callbacks=run_manager.get_child() if run_manager else None,
|
||||
)
|
||||
for chain_input in chain_inputs
|
||||
]
|
||||
return {"results": results}
|
||||
|
||||
async def _acall(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Run the chain."""
|
||||
chain_inputs = inputs["inputs"]
|
||||
|
||||
results = await asyncio.gather(
|
||||
*[
|
||||
self.chain.acall(
|
||||
chain_input,
|
||||
callbacks=run_manager.get_child() if run_manager else None,
|
||||
)
|
||||
for chain_input in chain_inputs
|
||||
]
|
||||
)
|
||||
return {"results": results}
|
||||
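To make the reader stack concrete, a hedged sketch of composing `DocReadingChain` with `ParallelApplyChain` over in-memory documents; the documents and question are made up:

```python
# Sketch only: reading two in-memory documents with the chains defined above.
import asyncio

from langchain.chains.question_answering import load_qa_chain
from langchain.chains.research.readers import DocReadingChain, ParallelApplyChain
from langchain.llms import OpenAI
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter

doc_reader = DocReadingChain(
    chain=load_qa_chain(OpenAI(temperature=0), chain_type="refine"),
    text_splitter=RecursiveCharacterTextSplitter(chunk_size=1000),
    max_num_docs=3,  # read at most 3 chunks per document; -1 means no limit
)
multi_reader = ParallelApplyChain(chain=doc_reader)

docs = [
    Document(page_content="Einstein was born in Ulm in 1879.", metadata={"source": "a"}),
    Document(page_content="Einstein received the 1921 Nobel Prize in Physics.", metadata={"source": "b"}),
]
question = "When and where was Einstein born?"

# Parallelism is only implemented on the async path.
outputs = asyncio.run(
    multi_reader.acall({"inputs": [{"doc": d, "question": question} for d in docs]})
)
for result in outputs["results"]:
    print(result["answer"].page_content)
```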
langchain/chains/research/search.py (new file, 329 lines)
@@ -0,0 +1,329 @@
|
||||
"""Module for initiating a set of searches relevant for answering the question."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import typing
|
||||
from typing import Any, Dict, List, Mapping, Optional, Sequence
|
||||
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
from langchain import LLMChain, PromptTemplate, serpapi
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManagerForChainRun,
|
||||
)
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.classification.multiselection import MultiSelectChain
|
||||
from langchain.schema import BaseOutputParser
|
||||
|
||||
|
||||
def _extract_content_from_tag(html: str, tag: str) -> List[str]:
|
||||
"""Extract content from the given tag."""
|
||||
soup = BeautifulSoup(html, "lxml")
|
||||
queries = []
|
||||
for query in soup.find_all(tag):
|
||||
queries.append(query.text)
|
||||
return queries
|
||||
|
||||
|
||||
def _extract_href_tags(html: str) -> List[str]:
|
||||
"""Extract href tags.
|
||||
|
||||
Args:
|
||||
html: the html to extract href tags from.
|
||||
|
||||
Returns:
|
||||
a list of href tags.
|
||||
"""
|
||||
href_tags = []
|
||||
soup = BeautifulSoup(html, "html.parser")
|
||||
for a_tag in soup.find_all("a"):
|
||||
href = a_tag.get("href")
|
||||
if href:
|
||||
href_tags.append(href)
|
||||
return href_tags
|
||||
|
||||
|
||||
class QueryExtractor(BaseOutputParser[List[str]]):
|
||||
"""An output parser that extracts all queries."""
|
||||
|
||||
def parse(self, text: str) -> List[str]:
|
||||
"""Extract all content of <query> from the text."""
|
||||
return _extract_content_from_tag(text, "query")
|
||||
|
||||
|
||||
QUERY_GENERATION_PROMPT = PromptTemplate.from_template(
|
||||
"""\
|
||||
Suggest a few different search queries that could be used to identify web-pages \
|
||||
that could answer the following question.
|
||||
|
||||
If the question is about a named entity start by listing very general \
|
||||
searches (e.g., just the named entity) and then suggest searches more \
|
||||
scoped to the question.
|
||||
|
||||
Input: ```Where did John Snow from Cambridge, UK work?```
|
||||
Output: <query>John Snow</query>
|
||||
<query>John Snow Cambridge UK</query>
|
||||
<query> John Snow Cambridge UK work history </query>
|
||||
<query> John Snow Cambridge UK cv </query>
|
||||
|
||||
Input: ```How many research papers did Jane Doe publish in 2010?```
|
||||
Output: <query>Jane Doe</query>
|
||||
<query>Jane Doe research papers</query>
|
||||
<query>Jane Doe research papers 2010</query>
|
||||
<query>Jane Doe publications</query>
|
||||
<query>Jane Doe publications 2010</query>
|
||||
|
||||
Input: ```What is the capital of France?```
|
||||
Output: <query>France</query>
|
||||
<query>France capital</query>
|
||||
<query>France capital city</query>
|
||||
<query>France capital city name</query>
|
||||
|
||||
Input: ```What are the symptoms of COVID-19?```
|
||||
Output: <query>COVID-19</query>
|
||||
<query>COVID-19 symptoms</query>
|
||||
<query>COVID-19 symptoms list</query>
|
||||
<query>COVID-19 symptoms list WHO</query>
|
||||
|
||||
Input: ```What is the revenue stream of CVS?```
|
||||
Output: <query>CVS</query>
|
||||
<query>CVS revenue</query>
|
||||
<query>CVS revenue stream</query>
|
||||
<query>CVS revenue stream business model</query>
|
||||
|
||||
Input: ```{question}```
|
||||
Output:
|
||||
""",
|
||||
output_parser=QueryExtractor(),
|
||||
)
|
||||
|
||||
|
||||
def _deduplicate_objects(
    dicts: Sequence[Mapping[str, Any]], key: str
) -> List[Mapping[str, Any]]:
    """Deduplicate objects by the given key.

    TODO(Eugene): add a way to add weights to the objects.

    Args:
        dicts: a list of dictionaries to deduplicate.
        key: the key to deduplicate by.

    Returns:
        a list of deduplicated dictionaries.
    """
    unique_values = set()
    deduped: List[Mapping[str, Any]] = []

    for d in dicts:
        value = d[key]
        if value not in unique_values:
            unique_values.add(value)
            deduped.append(d)

    return deduped


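# A tiny illustrative check (not part of the original change) showing that only the
# first result for a given key is kept; the links and titles are made up.
def _example_deduplicate_objects() -> None:
    """Illustrative sketch of de-duplicating search hits by their link."""
    hits = [
        {"link": "https://a.example", "title": "A"},
        {"link": "https://a.example", "title": "A (duplicate)"},
        {"link": "https://b.example", "title": "B"},
    ]
    assert [d["title"] for d in _deduplicate_objects(hits, "link")] == ["A", "B"]

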
def _run_searches(queries: Sequence[str], top_k: int = -1) -> List[Mapping[str, Any]]:
    """Run the given queries and return all the search results.

    This function can return duplicated results; de-duplication can take place later
    and take into account the frequency of appearance.

    Args:
        queries: a list of queries to run
        top_k: the number of results to keep per query; if -1, keep all results

    Returns:
        a list of search results, possibly containing duplicates
    """
    wrapper = serpapi.SerpAPIWrapper()
    results = []
    for query in queries:
        result = wrapper.results(query)
        all_organic_results = result.get("organic_results", [])
        if top_k <= 0:
            organic_results = all_organic_results
        else:
            organic_results = all_organic_results[:top_k]
        results.extend(organic_results)
    return results


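# A minimal usage sketch (not part of the original change); it assumes a SerpAPI key
# is available to SerpAPIWrapper (e.g. via the SERPAPI_API_KEY environment variable),
# and the queries are placeholders.
def _example_run_searches() -> None:
    """Illustrative sketch of fanning out several queries and de-duplicating hits."""
    hits = _run_searches(["John Snow", "John Snow Cambridge UK"], top_k=3)
    for hit in _deduplicate_objects(hits, "link"):
        print(hit["title"], hit["link"])

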
async def _arun_searches(
    queries: Sequence[str], top_k: int = -1
) -> List[Mapping[str, Any]]:
    """Run the given queries and return all the search results.

    This function can return duplicated results; de-duplication can take place later
    and take into account the frequency of appearance.

    Args:
        queries: a list of queries to run
        top_k: the number of results to keep per query; if -1, keep all results

    Returns:
        a list of search results, possibly containing duplicates
    """
    wrapper = serpapi.SerpAPIWrapper()
    tasks = [wrapper.aresults(query) for query in queries]
    results = await asyncio.gather(*tasks)

    finalized_results = []

    for result in results:
        all_organic_results = result.get("organic_results", [])
        if top_k <= 0:
            organic_results = all_organic_results
        else:
            organic_results = all_organic_results[:top_k]

        finalized_results.extend(organic_results)

    return finalized_results


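# The async variant can be driven from synchronous code as below (illustrative only;
# it also assumes a configured SerpAPI key). The searches run concurrently.
def _example_arun_searches() -> None:
    """Illustrative sketch of running the async search fan-out."""
    hits = asyncio.run(_arun_searches(["COVID-19", "COVID-19 symptoms"], top_k=3))
    print(len(hits))

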
# PUBLIC API


def make_query_generator(llm: BaseLanguageModel) -> LLMChain:
    """Use an LLM to break down a complex query into a list of simpler queries.

    The simpler queries are used to run searches against a search engine, with the goal
    of increasing recall and retrieving as many relevant results as possible.

    Query:

    Does Harrison Chase of Langchain like to eat pizza or play squash?

    May be broken down into a list of simpler queries like:

    - Harrison Chase
    - Harrison Chase Langchain
    - Harrison Chase Langchain pizza
    - Harrison Chase Langchain squash
    """
    return LLMChain(
        llm=llm,
        output_key="urls",
        prompt=QUERY_GENERATION_PROMPT,
    )


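# A minimal usage sketch (not part of the original change); the chat model choice is a
# placeholder and assumes an OpenAI API key is configured in the environment.
def _example_make_query_generator() -> None:
    """Illustrative sketch of generating simpler search queries for a question."""
    from langchain.chat_models import ChatOpenAI

    chain = make_query_generator(ChatOpenAI(temperature=0))
    queries = chain.predict_and_parse(question="What are the symptoms of COVID-19?")
    print(queries)  # e.g. ["COVID-19", "COVID-19 symptoms", ...]

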
class GenericSearcher(Chain):
    """A chain that takes a complex question and identifies a list of relevant urls.

    The chain works by:

    1. Breaking a complex question into a series of simpler queries using an LLM.
    2. Running the queries against a search engine.
    3. Selecting the most relevant urls using an LLM (can be replaced with tf-idf
       or other models).

    This chain is not meant to be used for questions requiring multiple hops to answer.

    For example, the age of Leonardo DiCaprio's girlfriend is a multi-hop question.

    This kind of question requires a slightly different approach.

    This chain is meant to handle questions for which one wants
    to collect information from multiple sources.

    To extend the implementation:

    * Parameterize the search engine (to allow for non serp api search engines)
    * Expose an abstract interface for the query generator and link selection model
    * Expose promoted answers from search engines as blobs that should be summarized
    """

    query_generator: LLMChain
    """An LLM chain used to break a complex question into a list of simpler queries."""
    link_selection_model: Chain
    """A chain used to select the most relevant urls from the search results."""
    top_k_per_search: int = -1
    """The number of top urls to select from each search."""

    @property
    def input_keys(self) -> List[str]:
        """Return the input keys."""
        return ["question"]

    @property
    def output_keys(self) -> List[str]:
        """Return the output keys."""
        return ["urls"]

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        question = inputs["question"]
        queries = typing.cast(
            List[str],
            self.query_generator.predict_and_parse(
                callbacks=run_manager.get_child() if run_manager else None,
                question=question,
            ),
        )
        results = _run_searches(queries, top_k=self.top_k_per_search)
        deduped_results = _deduplicate_objects(results, "link")
        records = [
            {"link": result["link"], "title": result["title"]}
            for result in deduped_results
        ]
        response_ = self.link_selection_model(
            {
                "question": question,
                "choices": records,
            },
            callbacks=run_manager.get_child() if run_manager else None,
        )
        return {"urls": [result["link"] for result in response_["selected"]]}

    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        question = inputs["question"]
        queries = typing.cast(
            List[str],
            await self.query_generator.apredict_and_parse(
                callbacks=run_manager.get_child() if run_manager else None,
                question=question,
            ),
        )
        # Use the async search fan-out so the searches run concurrently.
        results = await _arun_searches(queries, top_k=self.top_k_per_search)
        deduped_results = _deduplicate_objects(results, "link")
        records = [
            {"link": result["link"], "title": result["title"]}
            for result in deduped_results
        ]
        response_ = await self.link_selection_model.acall(
            {
                "question": question,
                "choices": records,
            },
            callbacks=run_manager.get_child() if run_manager else None,
        )
        return {"urls": [result["link"] for result in response_["selected"]]}

    @classmethod
    def from_llms(
        cls,
        link_selection_llm: BaseLanguageModel,
        query_generation_llm: BaseLanguageModel,
        *,
        top_k_per_search: int = -1,
    ) -> GenericSearcher:
        """Initialize the searcher from the given language models."""
        link_selection_model = MultiSelectChain.from_default(llm=link_selection_llm)
        query_generation_model = make_query_generator(query_generation_llm)
        return cls(
            link_selection_model=link_selection_model,
            query_generator=query_generation_model,
            top_k_per_search=top_k_per_search,
        )


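# A minimal end-to-end sketch (not part of the original change); it assumes both an
# OpenAI API key and a SerpAPI key are configured, and the model choice and question
# are placeholders.
def _example_generic_searcher() -> None:
    """Illustrative sketch of turning a question into a list of candidate urls."""
    from langchain.chat_models import ChatOpenAI

    llm = ChatOpenAI(temperature=0)
    searcher = GenericSearcher.from_llms(
        link_selection_llm=llm,
        query_generation_llm=llm,
        top_k_per_search=3,
    )
    outputs = searcher({"question": "Where did John Snow from Cambridge, UK work?"})
    print(outputs["urls"])

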
12
langchain/chains/research/typedefs.py
Normal file
@@ -0,0 +1,12 @@
import abc
from typing import List

from langchain.document_loaders.blob_loaders import Blob


class BlobCrawler(abc.ABC):
    """Crawl a blob and identify links to related content."""

    @abc.abstractmethod
    def crawl(self, blob: Blob, query: str) -> List[str]:
        """Explore the blob and identify links to relevant content."""

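# A minimal sketch of a concrete implementation (not part of the original change); it
# ignores the query and simply returns every href found in an HTML blob, purely to
# show the shape of the interface.
class _NaiveLinkCrawler(BlobCrawler):
    """Illustrative crawler that returns all hrefs found in an HTML blob."""

    def crawl(self, blob: Blob, query: str) -> List[str]:
        from bs4 import BeautifulSoup

        soup = BeautifulSoup(blob.as_string(), "html.parser")
        return [a["href"] for a in soup.find_all("a", href=True)]
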
74
langchain/document_loaders/parsers/html/markdownify.py
Normal file
@@ -0,0 +1,74 @@
"""Load and chunk HTMLs with potential pre-processing to clean the html."""
|
||||
|
||||
import re
|
||||
from typing import Iterator, Tuple
|
||||
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
from langchain.document_loaders.base import BaseBlobParser
|
||||
from langchain.document_loaders.blob_loaders import Blob
|
||||
from langchain.schema import Document
|
||||
|
||||
# Regular expression pattern to detect multiple new lines in a row with optional
|
||||
# whitespace in between
|
||||
CONSECUTIVE_NEW_LINES = re.compile(r"\n(\s*\n)+", flags=re.UNICODE)
|
||||
|
||||
|
||||
def _get_mini_html(html: str, *, tags_to_remove: Tuple[str, ...] = tuple()) -> str:
    """Clean up HTML tags."""
    # Parse the HTML document using BeautifulSoup
    soup = BeautifulSoup(html, "html.parser")

    # Remove all CSS stylesheets
    for stylesheet in soup.find_all("link", rel="stylesheet"):
        stylesheet.extract()

    for tag_to_remove in tags_to_remove:
        # Remove all matching tags
        for tag in soup.find_all(tag_to_remove):
            tag.extract()

    new_html = repr(soup)
    return new_html


def _clean_html(html: str, *, tags_to_remove: Tuple[str, ...] = tuple()) -> str:
    """Clean up HTML and convert to markdown using markdownify."""
    try:
        import markdownify
    except ImportError:
        raise ImportError(
            "The markdownify package is required to parse HTML files. "
            "Please install it with `pip install markdownify`."
        )
    html = _get_mini_html(html, tags_to_remove=tags_to_remove)
    md = markdownify.markdownify(html)
    return CONSECUTIVE_NEW_LINES.sub("\n\n", md).strip()


## PUBLIC API


class MarkdownifyHTMLParser(BaseBlobParser):
    """A blob parser that converts HTML content to markdown."""

    def __init__(
        self,
        tags_to_remove: Tuple[str, ...] = ("svg", "img", "script", "style"),
    ) -> None:
        """Initialize the preprocessor.

        Args:
            tags_to_remove: A tuple of tags to remove from the HTML
        """
        self.tags_to_remove = tags_to_remove

    def lazy_parse(self, blob: Blob) -> Iterator[Document]:
        """Lazily parse the blob."""
        yield Document(
            page_content=_clean_html(
                blob.as_string(), tags_to_remove=self.tags_to_remove
            ),
            metadata={"source": blob.source},
        )


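# A minimal usage sketch (not part of the original change); it assumes the optional
# `markdownify` package is installed, and the URL and HTML snippet are placeholders.
def _example_markdownify_parser() -> None:
    """Illustrative sketch of converting an HTML blob into a markdown Document."""
    blob = Blob.from_data(
        "<html><body><h1>Title</h1><p>Some <b>bold</b> text.</p></body></html>",
        path="https://example.com",
    )
    document = next(MarkdownifyHTMLParser().lazy_parse(blob))
    print(document.page_content)  # markdown rendering of the page body
    print(document.metadata)      # {"source": "https://example.com"}

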
0
tests/unit_tests/chains/research/__init__.py
Normal file
32
tests/unit_tests/chains/research/test_downloader.py
Normal file
@@ -0,0 +1,32 @@
"""Tests for the downloader."""
|
||||
import pytest
|
||||
|
||||
from langchain.chains.research.download import (
|
||||
_is_javascript_required,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.requires("lxml")
|
||||
def test_is_javascript_required() -> None:
|
||||
"""Check whether a given page should be re-downloaded with javascript executed."""
|
||||
assert not _is_javascript_required(
|
||||
"""
|
||||
<html>
|
||||
<body>
|
||||
<p>Check whether javascript is required.</p>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
)
|
||||
|
||||
assert _is_javascript_required(
|
||||
"""
|
||||
<html>
|
||||
<script>
|
||||
console.log("Javascript is required.");
|
||||
</script>
|
||||
<body>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
)
|
||||