mirror of https://github.com/hwchase17/langchain.git
@@ -4,7 +4,6 @@ The main idea behind the crawling module is to identify additional links
 that are worth exploring to find more documents that may be relevant for being
 able to answer the question correctly.
 """
-import abc
 import json
 import urllib.parse
 from bs4 import PageElement, BeautifulSoup
@@ -12,6 +11,7 @@ from typing import List, Dict, Any, Tuple
 
 from langchain.base_language import BaseLanguageModel
 from langchain.chains.classification.multiselection import MultiSelectChain
+from langchain.chains.research.typedefs import BlobCrawler
 from langchain.document_loaders.base import BaseBlobParser
 from langchain.document_loaders.blob_loaders import Blob
 from langchain.document_loaders.parsers.html.markdownify import MarkdownifyHTMLParser
@@ -82,14 +82,6 @@ def _extract_records(blob: Blob) -> Tuple[List[Dict[str, Any]], Tuple[str, ...]]
     )
 
 
-class BlobCrawler(abc.ABC):
-    """Crawl a blob and identify links to related content."""
-
-    @abc.abstractmethod
-    def crawl(self, blob: Blob, query: str) -> List[str]:
-        """Explore the blob and identify links to related content that is relevant to the query."""
-
-
 class ChainCrawler(BlobCrawler):
     def __init__(self, chain: MultiSelectChain, parser: BaseBlobParser) -> None:
         """Crawl the blob using an LLM."""
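The hunk above shows only the first lines of ChainCrawler. As a quick illustration of the BlobCrawler contract it now imports from typedefs, here is a minimal non-LLM implementation; AnchorTagCrawler is a hypothetical name, and the sketch assumes the blob holds HTML and simply ignores the query:

# --- illustrative sketch, not part of this commit ---
import urllib.parse
from typing import List

from bs4 import BeautifulSoup

from langchain.chains.research.typedefs import BlobCrawler
from langchain.document_loaders.blob_loaders import Blob


class AnchorTagCrawler(BlobCrawler):
    """Toy crawler: return every link found in an HTML blob, ignoring the query."""

    def crawl(self, blob: Blob, query: str) -> List[str]:
        soup = BeautifulSoup(blob.as_string(), "html.parser")
        links = []
        for anchor in soup.find_all("a", href=True):
            # Resolve relative hrefs against the blob's source URL, if known.
            links.append(urllib.parse.urljoin(blob.source or "", anchor["href"]))
        return links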
@@ -1,11 +1,12 @@
 """Module for initiating a set of searches relevant for answering the question."""
 import abc
 
 import asyncio
 import itertools
 from bs4 import BeautifulSoup
 from typing import Sequence, List, Mapping, Any
 
 from langchain import PromptTemplate, LLMChain, serpapi
 from langchain.base_language import BaseLanguageModel
+from langchain.chains.research.typedefs import AbstractQueryGenerator, AbstractSearcher
 from langchain.schema import BaseOutputParser
@@ -118,7 +119,7 @@ def _deduplicate_objects(
     return deduped
 
 
-def run_searches(queries: Sequence[str]) -> List[Mapping[str, Any]]:
+def _run_searches(queries: Sequence[str]) -> List[Mapping[str, Any]]:
    """Run the given queries and return the unique results.
 
    Args:
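_deduplicate_objects itself sits outside this hunk; only its closing `return deduped` is visible above. Judging from the call sites, which pass a list of result mappings and the key "link", it presumably keeps the first object seen per key value. A minimal sketch under that assumption:

# --- illustrative sketch, not part of this commit ---
from typing import Any, List, Mapping


def _deduplicate_objects(
    objects: List[Mapping[str, Any]], key: str
) -> List[Mapping[str, Any]]:
    """Keep the first object seen for each distinct value of `key`."""
    seen = set()
    deduped = []
    for obj in objects:
        value = obj.get(key)
        if value not in seen:
            seen.add(value)
            deduped.append(obj)
    return deduped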
@@ -133,31 +134,46 @@ def run_searches(queries: Sequence[str]) -> List[Mapping[str, Any]]:
         result = wrapper.results(query)
         organic_results = result["organic_results"]
         results.extend(organic_results)
-
-    unique_results = _deduplicate_objects(results, "link")
-    return unique_results
+    return results
 
 
+async def _arun_searches(queries: Sequence[str]) -> List[Mapping[str, Any]]:
+    """Run the given queries and return their organic results."""
+    wrapper = serpapi.SerpAPIWrapper()
+    tasks = [wrapper.aresults(query) for query in queries]
+    results = await asyncio.gather(*tasks)
+
+    return list(
+        itertools.chain.from_iterable(result["organic_results"] for result in results)
+    )
+
+
 # PUBLIC API
 
 
-def generate_queries(llm: BaseLanguageModel, question: str) -> List[str]:
-    """Generate queries using a Chain."""
-    chain = LLMChain(prompt=QUERY_GENERATION_PROMPT, llm=llm)
-    queries = chain.predict_and_parse(question=question)
-    return queries
+class GenericQueryGenerator(AbstractQueryGenerator):
+    def __init__(self, llm: BaseLanguageModel) -> None:
+        """Initialize the query generator."""
+        self.llm = llm
+
+    def generate_queries(self, question: str) -> List[str]:
+        """Generate queries using a Chain."""
+        chain = LLMChain(prompt=QUERY_GENERATION_PROMPT, llm=self.llm)
+        queries = chain.predict_and_parse(question=question)
+        return queries
 
 
-class AbstractSearcher(abc.ABC):
-    @abc.abstractmethod
-    def search(self, query: str) -> List[Mapping[str, Any]]:
-        """Run a search for the given query.
-
-        Args:
-            query: the query to run the search for.
-        """
-        raise NotImplementedError()
-
-    # def search_all(self, queries: Sequence[str]) -> List[Mapping[str, Any]]:
-    #     """Run a search for all the given queries."""
-    #     raise NotImplementedError
+class GenericSearcher(AbstractSearcher):
+    def __init__(self, search_engine: str = "serp") -> None:
+        """Initialize the searcher.
+
+        Args:
+            search_engine: the search engine to use. Placeholder for future work.
+        """
+        if search_engine != "serp":
+            raise NotImplementedError("Only serp is supported at the moment.")
+
+    async def asearch(self, queries: Sequence[str]) -> List[Mapping[str, Any]]:
+        """Run searches for the given queries."""
+        results = await _arun_searches(queries)
+        return _deduplicate_objects(results, "link")
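A usage sketch for the two classes added above. It assumes an OpenAI model for query generation (OPENAI_API_KEY set) and a SerpAPI key for the searcher (SERPAPI_API_KEY set); the question string is made up:

# --- illustrative sketch, not part of this commit ---
import asyncio

from langchain.llms import OpenAI

llm = OpenAI(temperature=0)  # needs OPENAI_API_KEY

generator = GenericQueryGenerator(llm)
queries = generator.generate_queries("When was the Eiffel Tower built?")

searcher = GenericSearcher()  # needs SERPAPI_API_KEY
results = asyncio.run(searcher.asearch(queries))
for result in results[:5]:
    print(result["link"], "-", result.get("title"))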
langchain/chains/research/typedefs.py (new file, 37 lines)
@@ -0,0 +1,37 @@
+from typing import List, Sequence, Mapping, Any
+
+import abc
+
+from langchain.document_loaders.blob_loaders import Blob
+
+
+class AbstractQueryGenerator(abc.ABC):
+    """Abstract class for generating queries."""
+
+    @abc.abstractmethod
+    def generate_queries(self, question: str) -> List[str]:
+        """Generate queries for the given question."""
+        raise NotImplementedError()
+
+
+class AbstractSearcher(abc.ABC):
+    """Abstract class for running searches."""
+
+    def search(self, queries: Sequence[str]) -> List[Mapping[str, Any]]:
+        """Run searches for the given queries.
+
+        Args:
+            queries: the queries to run searches for.
+
+        Returns:
+            a list of search results.
+        """
+        raise NotImplementedError()
+
+
+class BlobCrawler(abc.ABC):
+    """Crawl a blob and identify links to related content."""
+
+    @abc.abstractmethod
+    def crawl(self, blob: Blob, query: str) -> List[str]:
+        """Explore the blob and identify links to related content that is relevant to the query."""
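To show how the three interfaces in this new file are meant to compose, here is a hypothetical single research iteration: generate queries for a question, search, then crawl each hit for follow-up links. `research_step` is not part of the commit, and fetching pages with `requests` plus building blobs via `Blob.from_data` is an assumption about how callers would wire this up:

# --- illustrative sketch, not part of this commit ---
from typing import List

import requests

from langchain.chains.research.typedefs import (
    AbstractQueryGenerator,
    AbstractSearcher,
    BlobCrawler,
)
from langchain.document_loaders.blob_loaders import Blob


def research_step(
    generator: AbstractQueryGenerator,
    searcher: AbstractSearcher,
    crawler: BlobCrawler,
    question: str,
) -> List[str]:
    """One research iteration: query generation, search, then link discovery."""
    queries = generator.generate_queries(question)
    results = searcher.search(queries)
    links: List[str] = []
    for result in results:
        response = requests.get(result["link"], timeout=10)
        blob = Blob.from_data(response.content, path=result["link"])
        links.extend(crawler.crawl(blob, question))
    return links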