mirror of https://github.com/hwchase17/langchain.git
synced 2026-01-24 05:50:18 +00:00
langchain/chains/research/api.py (new file, +92 lines)
@@ -0,0 +1,92 @@
from __future__ import annotations

# `Dict` belongs to `typing`; the original `from plistlib import Dict` was a
# mis-autocompleted import.
from typing import Any, Dict, Literal, Optional, Union

from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import (
    AsyncCallbackManagerForChainRun,
    CallbackManagerForChainRun,
)
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.research.readers import DocReadingChain, ParallelApplyChain
from langchain.chains.research.search import GenericSearcher
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter


class Research(Chain):
    """A simple research chain."""

    searcher: GenericSearcher
    """The searcher to use to search for documents."""

    reader: Chain
    """The reader to use to read documents and produce an answer."""

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        # The sync path had no body in this commit; raising keeps the module
        # importable and points callers at the async entry point.
        raise NotImplementedError("Use `acall`; only the async path is implemented.")

    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Run the chain asynchronously."""
        question = inputs["question"]
        # Use the async interface for both sub-chains; the original invoked them
        # synchronously (and awaited a plain dict), which fails at runtime.
        search_results = await self.searcher.acall(
            {"question": question}, callbacks=run_manager.get_child()
        )
        documents = [result["document"] for result in search_results["results"]]
        if not documents:
            return {"answer": None}
        reader_inputs = {"doc": documents, "question": question}
        return await self.reader.acall(reader_inputs, callbacks=run_manager.get_child())

    @classmethod
    def from_llms(
        cls,
        link_selection_llm: BaseLanguageModel,
        query_generation_llm: BaseLanguageModel,
        qa_chain: LLMChain,
        *,
        top_k_per_search: int = -1,
        max_concurrency: int = 1,
        max_num_pages_per_doc: int = 100,
        text_splitter: Union[TextSplitter, Literal["recursive"]] = "recursive",
    ) -> Research:
        """Helper to create a research chain from standard LLM-related components.

        Args:
            link_selection_llm: The language model to use for link selection.
            query_generation_llm: The language model to use for query generation.
            qa_chain: The chain to use to answer the question.
            top_k_per_search: The number of documents to return per search.
            max_concurrency: The maximum number of concurrent reads.
            max_num_pages_per_doc: The maximum number of pages to read per document.
            text_splitter: The text splitter to use to split the document into
                smaller chunks.

        Returns:
            A research chain.
        """
        searcher = GenericSearcher.from_llms(
            link_selection_llm,
            query_generation_llm,
            top_k_per_search=top_k_per_search,
        )
        if isinstance(text_splitter, str):
            if text_splitter == "recursive":
                _text_splitter = RecursiveCharacterTextSplitter()
            else:
                raise ValueError(f"Invalid text splitter: {text_splitter}")
        elif isinstance(text_splitter, TextSplitter):
            _text_splitter = text_splitter
        else:
            raise TypeError(f"Invalid text splitter: {type(text_splitter)}")
        reader = ParallelApplyChain(
            chain=DocReadingChain(
                qa_chain,
                max_num_pages_per_doc=max_num_pages_per_doc,
                # Pass the resolved splitter; the original passed the raw
                # argument, which could still be the string "recursive".
                text_splitter=_text_splitter,
            ),
            max_concurrency=max_concurrency,
        )
        return cls(searcher=searcher, reader=reader)
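
For orientation, a minimal sketch of wiring this chain up. The prompt text, the
model choice, and the qa_chain input variables are assumptions for illustration;
the diff does not show what DocReadingChain actually feeds the qa_chain.

    from langchain import PromptTemplate
    from langchain.chains.llm import LLMChain
    from langchain.chat_models import ChatOpenAI
    from langchain.chains.research.api import Research

    llm = ChatOpenAI(temperature=0)
    # Hypothetical QA prompt; the variable names assume the reader passes
    # "doc" and "question" through to the qa_chain.
    qa_prompt = PromptTemplate(
        template="Use the text to answer the question.\n"
        "Text: {doc}\nQuestion: {question}\nAnswer:",
        input_variables=["doc", "question"],
    )
    qa_chain = LLMChain(llm=llm, prompt=qa_prompt)

    research = Research.from_llms(
        link_selection_llm=llm,
        query_generation_llm=llm,
        qa_chain=qa_chain,
        top_k_per_search=3,
        max_concurrency=4,
    )
    # Async-only: `_call` raises NotImplementedError.
    # result = await research.acall({"question": "Who invented the transistor?"})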

@@ -4,7 +4,6 @@ The main idea behind the crawling module is to identify additional links
 that are worth exploring to find more documents that may be relevant for being
 able to answer the question correctly.
 """
 import json
 import urllib.parse
 from bs4 import PageElement, BeautifulSoup
 from typing import List, Dict, Any, Tuple
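
As a sketch of the kind of helper such a crawling module builds on (using only
the imports shown above; this function is illustrative and not part of the
commit), extracting and resolving candidate links might look like:

    import urllib.parse
    from typing import List

    from bs4 import BeautifulSoup

    def extract_links(html: str, base_url: str) -> List[str]:
        """Collect absolute URLs from every anchor tag on a page."""
        soup = BeautifulSoup(html, "html.parser")
        return [
            urllib.parse.urljoin(base_url, anchor["href"])
            for anchor in soup.find_all("a", href=True)
        ]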

langchain/chains/research/readers.py
@@ -89,11 +89,13 @@ class DocReadingChain(Chain):
         return {"document": summary_doc}


-class ParallelApply(Chain):
-    """Utility chain to apply a given chain in parallel across input documents."""
+class ParallelApplyChain(Chain):
+    """Utility chain to apply a given chain in parallel across input documents.
+
+    This chain needs to handle a limit on concurrency.
+    """

     chain: Chain
     max_concurrency: int

     @property
     def input_keys(self) -> List[str]:
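
The concurrency limit the docstring mentions is conventionally enforced with a
semaphore. A minimal sketch of that pattern (the chain's actual implementation
is not shown in this hunk):

    import asyncio
    from typing import Any, Awaitable, Callable, List

    async def apply_with_limit(
        func: Callable[[Any], Awaitable[Any]],
        items: List[Any],
        max_concurrency: int,
    ) -> List[Any]:
        """Apply `func` to each item with at most `max_concurrency` calls in flight."""
        semaphore = asyncio.Semaphore(max_concurrency)

        async def bounded(item: Any) -> Any:
            async with semaphore:
                return await func(item)

        return await asyncio.gather(*(bounded(item) for item in items))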

langchain/chains/research/search.py
@@ -1,9 +1,9 @@
 """Module for initiating a set of searches relevant for answering the question."""
 from __future__ import annotations

 import asyncio
 import itertools
 from bs4 import BeautifulSoup
-from typing import Sequence, List, Mapping, Any, Optional, Union, Dict
+from typing import Sequence, List, Mapping, Any, Optional, Dict

 from langchain import PromptTemplate, LLMChain, serpapi
 from langchain.base_language import BaseLanguageModel
@@ -11,9 +11,9 @@ from langchain.callbacks.manager import (
     AsyncCallbackManagerForChainRun,
     CallbackManagerForChainRun,
 )
+from langchain.chains.base import Chain
 from langchain.chains.classification.multiselection import MultiSelectChain
 from langchain.schema import BaseOutputParser
-from langchain.chains.base import Chain


 def _extract_content_from_tag(html: str, tag: str) -> List[str]:

@@ -204,8 +204,8 @@ def make_query_generator(llm: BaseLanguageModel) -> LLMChain:
     """
     return LLMChain(
         llm=llm,
-        output_keys=["urls"],
-        promp=QUERY_GENERATION_PROMPT,
+        output_key="urls",
+        prompt=QUERY_GENERATION_PROMPT,
     )
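
QUERY_GENERATION_PROMPT itself is not shown in the diff. Since
predict_and_parse relies on the prompt's output parser, a plausible shape for
it, using a stock comma-separated list parser (purely illustrative), is:

    from langchain import PromptTemplate
    from langchain.output_parsers import CommaSeparatedListOutputParser

    # Hypothetical stand-in for QUERY_GENERATION_PROMPT; the real template may differ.
    query_prompt = PromptTemplate(
        template=(
            "Break the question down into at most three short search queries, "
            "comma separated.\nQuestion: {question}\nQueries:"
        ),
        input_variables=["question"],
        output_parser=CommaSeparatedListOutputParser(),
    )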

@@ -230,11 +230,12 @@ class GenericSearcher(Chain):
     To extend implementation:
     * Parameterize the search engine (to allow for non-SerpAPI search engines)
     * Expose an abstract interface for the query generator and link selection model
    * Expose promoted answers from search engines as blobs that should be summarized
     """

     query_generator: LLMChain
     """An LLM that is used to break down a complex question into a list of simpler queries."""
-    link_selection_model: LLMChain
+    link_selection_model: Chain
     """An LLM that is used to select the most relevant URLs from the search results."""
     top_k_per_search: int = -1
     """The number of top URLs to select from each search."""
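
The first extension idea above, parameterizing the search engine, could start
from an abstract interface like the following (hypothetical, not part of this
commit):

    import abc
    from typing import Any, Dict, List

    class SearchEngine(abc.ABC):
        """Hypothetical abstraction over SerpAPI and other backends."""

        @abc.abstractmethod
        async def search(self, query: str, top_k: int = -1) -> List[Dict[str, Any]]:
            """Return result records containing at least `link` and `title` keys."""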

@@ -255,20 +256,21 @@ class GenericSearcher(Chain):
         run_manager: Optional[CallbackManagerForChainRun] = None,
     ) -> Dict[str, Any]:
         question = inputs["question"]
-        response = self.query_generator.predict_and_parse(
+        queries = self.query_generator.predict_and_parse(
             callbacks=run_manager.get_child(), question=question
         )
-        queries = response["queries"]
         results = _run_searches(queries, top_k=self.top_k_per_search)
         deduped_results = _deduplicate_objects(results, "link")
         records = [
             {"link": result["link"], "title": result["title"]}
             for result in deduped_results
         ]
-        response_ = self.link_selection_model.predict_and_parse(
+        response_ = self.link_selection_model(
+            {
+                "question": question,
+                "choices": records,
+            },
             callbacks=run_manager.get_child(),
-            question=question,
-            choices=records,
         )
         return {"urls": [result["link"] for result in response_["selected"]]}
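
_deduplicate_objects is referenced here but defined outside the diff. A
plausible implementation, keeping the first record seen per key (an assumption,
not the committed code):

    from typing import Any, Dict, List, Sequence

    def _deduplicate_objects(
        objects: Sequence[Dict[str, Any]], key: str
    ) -> List[Dict[str, Any]]:
        """Drop records whose `key` value has already been seen, preserving order."""
        seen = set()
        deduped = []
        for obj in objects:
            if obj[key] not in seen:
                seen.add(obj[key])
                deduped.append(obj)
        return deduped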

@@ -278,33 +280,37 @@ class GenericSearcher(Chain):
         run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
     ) -> Dict[str, Any]:
         question = inputs["question"]
-        response = await self.query_generator.apredict_and_parse(
+        queries = await self.query_generator.apredict_and_parse(
             callbacks=run_manager.get_child(), question=question
         )
-        queries = response["queries"]
-        results = await _arun_searches(queries)
+        results = _run_searches(queries, top_k=self.top_k_per_search)
         deduped_results = _deduplicate_objects(results, "link")
         records = [
             {"link": result["link"], "title": result["title"]}
             for result in deduped_results
         ]
-        response_ = self.link_selection_model.predict_and_parse(
+        response_ = await self.link_selection_model.acall(
+            {
+                "question": question,
+                "choices": records,
+            },
             callbacks=run_manager.get_child(),
-            question=question,
-            choices=records,
         )
         return {"urls": [result["link"] for result in response_["selected"]]}

     @classmethod
-    def from_llm(
+    def from_llms(
         cls,
         link_selection_llm: BaseLanguageModel,
         query_generation_llm: BaseLanguageModel,
         *,
         top_k_per_search: int = -1,
     ) -> GenericSearcher:
         """Initialize the searcher from language models."""
         link_selection_model = MultiSelectChain.from_default(llm=link_selection_llm)
-        query_generation_model = make_query_generator(llm=query_generation_llm)
+        query_generation_model = make_query_generator(query_generation_llm)
         return cls(
             link_selection_model=link_selection_model,
             query_generator=query_generation_model,
             top_k_per_search=top_k_per_search,
         )
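
A short usage sketch for the renamed constructor (the model choice and the
question are illustrative):

    from langchain.chat_models import ChatOpenAI
    from langchain.chains.research.search import GenericSearcher

    llm = ChatOpenAI(temperature=0)
    searcher = GenericSearcher.from_llms(llm, llm, top_k_per_search=3)
    # result = await searcher.acall({"question": "What is quantum annealing?"})
    # result["urls"] -> the selected links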