@@ -4,15 +4,26 @@ from __future__ import annotations
import csv
from bs4 import BeautifulSoup
from io import StringIO
from typing import Sequence, Mapping, Any, Optional, Dict, List, cast, Set
from itertools import islice
from typing import (
    Sequence,
    Mapping,
    Any,
    Optional,
    Dict,
    List,
    cast,
    Set,
    TypeVar,
    Iterable,
    Iterator,
)

from langchain import LLMChain, PromptTemplate
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.schema import BaseOutputParser
from typing import TypedDict


MULTI_SELECT_TEMPLATE = """\
Here is a table in CSV format:
@@ -114,25 +125,11 @@ def _write_records_to_string(
    return buffer.getvalue()


class MultiSelectionInput(TypedDict):
    """Input for the multi-selection chain."""

    question: str
    records: Sequence[Mapping[str, Any]]
    delimiter: str
    columns: Optional[Sequence[str]]
T = TypeVar("T")


class MultiSelectionOutput(TypedDict):
    """Output for the multi-selection chain."""

    records: Sequence[Mapping[str, Any]]


from itertools import islice


def batch(iterable, size):
def batch(iterable: Iterable[T], size) -> Iterator[List[T]]:
    """Batch an iterable into chunks of size `size`."""
    iterator = iter(iterable)
    while True:
        batch = list(islice(iterator, size))
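The hunk above ends inside `batch`'s loop, so the termination logic is not visible here. A minimal self-contained sketch of the typed helper, with the empty-chunk check assumed rather than taken from the diff, would look like this:

```python
from itertools import islice
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def batch(iterable: Iterable[T], size: int) -> Iterator[List[T]]:
    """Yield lists of at most `size` items from `iterable`."""
    iterator = iter(iterable)
    while True:
        chunk = list(islice(iterator, size))
        if not chunk:  # assumed termination; the diff cuts off before this point
            return
        yield chunk


# Example: list(batch(range(5), 2)) == [[0, 1], [2, 3], [4]]
```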
@@ -160,7 +157,7 @@ class MultiSelectChain(Chain):
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> MultiSelectionOutput:
    ) -> Dict[str, Any]:
        """Run the chain."""
        choices = inputs["choices"]
        question = inputs["question"]
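The caller-facing contract visible in this hunk is that `_call` reads the `"choices"` and `"question"` input keys and now returns a plain `Dict[str, Any]` rather than the removed `MultiSelectionOutput` TypedDict. A hedged usage sketch under those assumptions (the chain's construction is not shown in this diff, so the wrapper below takes an already-built chain and is illustrative only):

```python
from typing import Any, Dict, List, Mapping

from langchain.chains.base import Chain


def select_records(
    select_chain: Chain, records: List[Mapping[str, Any]], question: str
) -> Dict[str, Any]:
    # Only the "choices"/"question" input keys and the Dict[str, Any] return type
    # come from the hunk above; the rest of this wrapper is illustrative.
    return select_chain({"choices": records, "question": question})


# e.g. select_records(multi_select_chain, [{"id": 0, "name": "Alice"}], "Who is Alice?")
```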
@@ -13,8 +13,6 @@ from langchain.callbacks.manager import (
)
from langchain.chains.base import Chain
from langchain.chains.classification.multiselection import (
    IDParser,
    _extract_content_from_tag,
    MultiSelectChain,
)
from langchain.document_loaders.base import BaseBlobParser
@@ -23,43 +21,14 @@ from langchain.document_loaders.parsers.html.markdownify import MarkdownifyHTMLP
from langchain.schema import BaseOutputParser, Document
from langchain.text_splitter import TextSplitter

Parser = MarkdownifyHTMLParser(tags_to_remove=("svg", "img", "script", "style", "a"))


class AHrefExtractor(BaseOutputParser[List[str]]):
    """An output parser that extracts all a-href links."""

    def parse(self, text: str) -> List[str]:
        return _extract_href_tags(text)


URL_CRAWLING_PROMPT = PromptTemplate.from_template(
    """\
Here is a list of URLs extracted from a page titled: `{title}`.

```csv
{urls}
```

---

Here is a question:

{question}

---


Please output the ids of the URLs that may contain content relevant to answering the question. \
Use only the information in the CSV table of URLs to determine relevance.

Format your answer inside <ids> tags, separating the ids with commas.

For example, if URLs 132 and 133 are relevant, you would write: <ids>132,133</ids>

Begin:""",
    output_parser=IDParser(),
)
def _extract_content_from_tag(html: str, tag: str) -> List[str]:
    """Extract content from the given tag."""
    soup = BeautifulSoup(html, "lxml")
    queries = []
    for query in soup.find_all(tag):
        queries.append(query.text)
    return queries
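`URL_CRAWLING_PROMPT` asks the model to wrap the chosen ids in `<ids>` tags, and `_extract_content_from_tag` pulls tag contents out with BeautifulSoup. `IDParser` itself is only referenced in this diff, so the sketch below shows how such a parser could sit on top of that helper; it is an illustration, not the implementation from this commit:

```python
from typing import List

from bs4 import BeautifulSoup


def _extract_content_from_tag(html: str, tag: str) -> List[str]:
    """Same behaviour as the helper in the hunk above: collect the text of every `tag`."""
    soup = BeautifulSoup(html, "lxml")
    return [element.text for element in soup.find_all(tag)]


def parse_ids(completion: str) -> List[int]:
    """Hypothetical IDParser-style parsing of a completion such as '<ids>132,133</ids>'."""
    ids: List[int] = []
    for content in _extract_content_from_tag(completion, "ids"):
        ids.extend(int(part) for part in content.split(",") if part.strip().isdigit())
    return ids


# parse_ids("<ids>132,133</ids>") == [132, 133]
```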
def _get_surrounding_text(tag: PageElement, n: int, *, is_before: bool = True) -> str:
@@ -94,7 +63,7 @@ def _get_surrounding_text(tag: PageElement, n: int, *, is_before: bool = True) -
    return text


def get_ahref_snippets(html: str, num_chars: int = 0) -> Dict[str, Any]:
def _get_ahref_snippets(html: str, num_chars: int = 0) -> Dict[str, Any]:
    """Get a list of <a> tags as snippets from the given html.

    Args:
@@ -155,11 +124,12 @@ class QueryExtractor(BaseOutputParser[List[str]]):
# TODO(Eugene): add a version that works for chat models as well w/ human system message?
QUERY_GENERATION_PROMPT = PromptTemplate.from_template(
    """\
Suggest a few different search queries that could be used to identify web-pages that could answer \
the following question.
Suggest a few different search queries that could be used to identify web-pages \
that could answer the following question.

If the question is about a named entity, start by listing very general searches (e.g., just the named entity) \
and then suggest searches more scoped to the question.
If the question is about a named entity, start by listing very general \
searches (e.g., just the named entity) and then suggest searches more \
scoped to the question.

Input: ```Where did John Snow from Cambridge, UK work?```
Output: <query>John Snow</query>
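The query-generation prompt expects each suggestion wrapped in a `<query>` tag (e.g. `<query>John Snow</query>`). The body of `QueryExtractor` lies outside this hunk, so the class below is a hypothetical stand-in that follows the same `BaseOutputParser` pattern as `AHrefExtractor` earlier in this file:

```python
from typing import List

from bs4 import BeautifulSoup
from langchain.schema import BaseOutputParser


class QueryExtractorSketch(BaseOutputParser[List[str]]):
    """Illustrative only; the real QueryExtractor is defined outside this hunk."""

    def parse(self, text: str) -> List[str]:
        # Collect the text of every <query> tag in the completion.
        soup = BeautifulSoup(text, "lxml")
        return [element.text for element in soup.find_all("query")]


# QueryExtractorSketch().parse("Output: <query>John Snow</query>") == ["John Snow"]
```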
@@ -209,7 +179,17 @@ def generate_queries(llm: BaseLanguageModel, question: str) -> List[str]:
def _deduplicate_objects(
    dicts: Sequence[Mapping[str, Any]], key: str
) -> List[Mapping[str, Any]]:
    """Deduplicate objects by the given key."""
    """Deduplicate objects by the given key.

    TODO(Eugene): add a way to add weights to the objects.

    Args:
        dicts: a list of dictionaries to deduplicate.
        key: the key to deduplicate by.

    Returns:
        a list of deduplicated dictionaries.
    """
    unique_values = set()
    deduped: List[Mapping[str, Any]] = []
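The hunk stops right after `unique_values` and `deduped` are initialized. Given the docstring, a straightforward completion (the loop body is assumed, not taken from the diff) could look like:

```python
from typing import Any, List, Mapping, Sequence


def _deduplicate_objects(
    dicts: Sequence[Mapping[str, Any]], key: str
) -> List[Mapping[str, Any]]:
    """Keep the first occurrence of each distinct value of `key`."""
    unique_values = set()
    deduped: List[Mapping[str, Any]] = []
    for d in dicts:  # assumed loop body; the diff ends before it
        value = d[key]
        if value not in unique_values:
            unique_values.add(value)
            deduped.append(d)
    return deduped


# _deduplicate_objects([{"url": "a"}, {"url": "a"}, {"url": "b"}], key="url")
# == [{"url": "a"}, {"url": "b"}]
```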
@@ -250,10 +230,10 @@ class BlobCrawler(abc.ABC):
    """Explore the blob and identify links to related content that is relevant to the query."""


def _extract_records(blob: Blob) -> Tuple[List[Mapping[str, Any]], Tuple[str, ...]]:
def _extract_records(blob: Blob) -> Tuple[List[Dict[str, Any]], Tuple[str, ...]]:
    """Extract records from a blob."""
    if blob.mimetype == "text/html":
        info = get_ahref_snippets(blob.as_string(), num_chars=100)
        info = _get_ahref_snippets(blob.as_string(), num_chars=100)
        return (
            [
                {