langchain[patch]: deprecate various chains (#25310)

- [x] NatBotChain: move to community, deprecate langchain version.
Update to use `prompt | llm | output_parser` instead of LLMChain.
- [x] LLMMathChain: deprecate + add langgraph replacement example to API
ref
- [x] HypotheticalDocumentEmbedder (embedder): update to use `prompt |
llm | output_parser` instead of LLMChain
- [x] FlareChain: update to use `prompt | llm | output_parser` instead
of LLMChain
- [x] ConstitutionalChain: deprecate + add langgraph replacement example
to API ref
- [x] LLMChainExtractor (document compressor): update to use `prompt |
llm | output_parser` instead of LLMChain
- [x] LLMChainFilter (document compressor): update to use `prompt | llm
| output_parser` instead of LLMChain
- [x] RePhraseQueryRetriever (retriever): update to use `prompt | llm |
output_parser` instead of LLMChain
ccurme
2024-08-15 10:49:26 -04:00
committed by GitHub
parent 66e30efa61
commit 8afbab4cf6
19 changed files with 1166 additions and 126 deletions
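The recurring change across the checklist items above is replacing `LLMChain(llm=llm, prompt=prompt)` with an LCEL composition. A minimal sketch of the pattern, with ChatOpenAI used purely as a stand-in model (any chat model works; assumes langchain-openai is installed):

.. code-block:: python

    from langchain_core.output_parsers import StrOutputParser
    from langchain_core.prompts import ChatPromptTemplate
    from langchain_openai import ChatOpenAI

    prompt = ChatPromptTemplate.from_template("Summarize in one sentence: {text}")
    llm = ChatOpenAI(model="gpt-4o-mini")

    # Before: LLMChain(llm=llm, prompt=prompt).predict(text=...)
    # After: a runnable pipeline invoked with a dict of inputs
    chain = prompt | llm | StrOutputParser()
    summary = chain.invoke({"text": "LLMChain is replaced by prompt | llm | output_parser."})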

View File

@@ -2,6 +2,7 @@
from typing import Any, Dict, List, Optional
from langchain_core._api import deprecated
from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate
@@ -13,9 +14,151 @@ from langchain.chains.constitutional_ai.prompts import CRITIQUE_PROMPT, REVISION
from langchain.chains.llm import LLMChain
@deprecated(
since="0.2.13",
message=(
"This class is deprecated and will be removed in langchain 1.0. "
"See API reference for replacement: "
"https://api.python.langchain.com/en/latest/chains/langchain.chains.constitutional_ai.base.ConstitutionalChain.html" # noqa: E501
),
removal="1.0",
)
class ConstitutionalChain(Chain):
"""Chain for applying constitutional principles.
Note: this class is deprecated. See below for a replacement implementation
using LangGraph. The benefits of this implementation are:
- Uses LLM tool calling features instead of parsing string responses;
- Support for both token-by-token and step-by-step streaming;
- Support for checkpointing and memory of chat history;
- Easier to modify or extend (e.g., with additional tools, structured responses, etc.)
Install LangGraph with:
.. code-block:: bash
pip install -U langgraph
.. code-block:: python
from typing import List, Optional, Tuple
from langchain.chains.constitutional_ai.prompts import (
CRITIQUE_PROMPT,
REVISION_PROMPT,
)
from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from langgraph.graph import END, START, StateGraph
from typing_extensions import Annotated, TypedDict
llm = ChatOpenAI(model="gpt-4o-mini")
class Critique(TypedDict):
\"\"\"Generate a critique, if needed.\"\"\"
critique_needed: Annotated[bool, ..., "Whether or not a critique is needed."]
critique: Annotated[str, ..., "If needed, the critique."]
critique_prompt = ChatPromptTemplate.from_template(
"Critique this response according to the critique request. "
"If no critique is needed, specify that.\\n\\n"
"Query: {query}\\n\\n"
"Response: {response}\\n\\n"
"Critique request: {critique_request}"
)
revision_prompt = ChatPromptTemplate.from_template(
"Revise this response according to the critique and reivsion request.\\n\\n"
"Query: {query}\\n\\n"
"Response: {response}\\n\\n"
"Critique request: {critique_request}\\n\\n"
"Critique: {critique}\\n\\n"
"If the critique does not identify anything worth changing, ignore the "
"revision request and return 'No revisions needed'. If the critique "
"does identify something worth changing, revise the response based on "
"the revision request.\\n\\n"
"Revision Request: {revision_request}"
)
chain = llm | StrOutputParser()
critique_chain = critique_prompt | llm.with_structured_output(Critique)
revision_chain = revision_prompt | llm | StrOutputParser()
class State(TypedDict):
query: str
constitutional_principles: List[ConstitutionalPrinciple]
initial_response: str
critiques_and_revisions: List[Tuple[str, str]]
response: str
async def generate_response(state: State):
\"\"\"Generate initial response.\"\"\"
response = await chain.ainvoke(state["query"])
return {"response": response, "initial_response": response}
async def critique_and_revise(state: State):
\"\"\"Critique and revise response according to principles.\"\"\"
critiques_and_revisions = []
response = state["initial_response"]
for principle in state["constitutional_principles"]:
critique = await critique_chain.ainvoke(
{
"query": state["query"],
"response": response,
"critique_request": principle.critique_request,
}
)
if critique["critique_needed"]:
revision = await revision_chain.ainvoke(
{
"query": state["query"],
"response": response,
"critique_request": principle.critique_request,
"critique": critique["critique"],
"revision_request": principle.revision_request,
}
)
response = revision
critiques_and_revisions.append((critique["critique"], revision))
else:
critiques_and_revisions.append((critique["critique"], ""))
return {
"critiques_and_revisions": critiques_and_revisions,
"response": response,
}
graph = StateGraph(State)
graph.add_node("generate_response", generate_response)
graph.add_node("critique_and_revise", critique_and_revise)
graph.add_edge(START, "generate_response")
graph.add_edge("generate_response", "critique_and_revise")
graph.add_edge("critique_and_revise", END)
app = graph.compile()
.. code-block:: python
constitutional_principles = [
ConstitutionalPrinciple(
critique_request="Tell if this answer is good.",
revision_request="Give a better answer.",
)
]
query = "What is the meaning of life? Answer in 10 words or fewer."
async for step in app.astream(
{"query": query, "constitutional_principles": constitutional_principles},
stream_mode="values",
):
subset = ["initial_response", "critiques_and_revisions", "response"]
print({k: v for k, v in step.items() if k in subset})
Example:
.. code-block:: python
@@ -44,7 +187,7 @@ class ConstitutionalChain(Chain):
)
constitutional_chain.run(question="What is the meaning of life?")
"""
""" # noqa: E501
chain: LLMChain
constitutional_principles: List[ConstitutionalPrinciple]

View File

@@ -1,7 +1,6 @@
from __future__ import annotations
import re
from abc import abstractmethod
from typing import Any, Dict, List, Optional, Sequence, Tuple
import numpy as np
@@ -9,10 +8,12 @@ from langchain_core.callbacks import (
CallbackManagerForChainRun,
)
from langchain_core.language_models import BaseLanguageModel
from langchain_core.outputs import Generation
from langchain_core.messages import AIMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import Field
from langchain_core.retrievers import BaseRetriever
from langchain_core.runnables import Runnable
from langchain.chains.base import Chain
from langchain.chains.flare.prompts import (
@@ -23,51 +24,14 @@ from langchain.chains.flare.prompts import (
from langchain.chains.llm import LLMChain
class _ResponseChain(LLMChain):
"""Base class for chains that generate responses."""
prompt: BasePromptTemplate = PROMPT
@classmethod
def is_lc_serializable(cls) -> bool:
return False
@property
def input_keys(self) -> List[str]:
return self.prompt.input_variables
def generate_tokens_and_log_probs(
self,
_input: Dict[str, Any],
*,
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Tuple[Sequence[str], Sequence[float]]:
llm_result = self.generate([_input], run_manager=run_manager)
return self._extract_tokens_and_log_probs(llm_result.generations[0])
@abstractmethod
def _extract_tokens_and_log_probs(
self, generations: List[Generation]
) -> Tuple[Sequence[str], Sequence[float]]:
"""Extract tokens and log probs from response."""
class _OpenAIResponseChain(_ResponseChain):
"""Chain that generates responses from user input and context."""
llm: BaseLanguageModel
def _extract_tokens_and_log_probs(
self, generations: List[Generation]
) -> Tuple[Sequence[str], Sequence[float]]:
tokens = []
log_probs = []
for gen in generations:
if gen.generation_info is None:
raise ValueError
tokens.extend(gen.generation_info["logprobs"]["tokens"])
log_probs.extend(gen.generation_info["logprobs"]["token_logprobs"])
return tokens, log_probs
def _extract_tokens_and_log_probs(response: AIMessage) -> Tuple[List[str], List[float]]:
"""Extract tokens and log probabilities from chat model response."""
tokens = []
log_probs = []
for token in response.response_metadata["logprobs"]["content"]:
tokens.append(token["token"])
log_probs.append(token["logprob"])
return tokens, log_probs
class QuestionGeneratorChain(LLMChain):
@@ -111,9 +75,9 @@ class FlareChain(Chain):
"""Chain that combines a retriever, a question generator,
and a response generator."""
question_generator_chain: QuestionGeneratorChain
question_generator_chain: Runnable
"""Chain that generates questions from uncertain spans."""
response_chain: _ResponseChain
response_chain: Runnable
"""Chain that generates responses from user input and context."""
output_parser: FinishedOutputParser = Field(default_factory=FinishedOutputParser)
"""Parser that determines whether the chain is finished."""
@@ -152,12 +116,16 @@ class FlareChain(Chain):
for question in questions:
docs.extend(self.retriever.invoke(question))
context = "\n\n".join(d.page_content for d in docs)
result = self.response_chain.predict(
user_input=user_input,
context=context,
response=response,
callbacks=callbacks,
result = self.response_chain.invoke(
{
"user_input": user_input,
"context": context,
"response": response,
},
{"callbacks": callbacks},
)
if isinstance(result, AIMessage):
result = result.content
marginal, finished = self.output_parser.parse(result)
return marginal, finished
@@ -178,13 +146,18 @@ class FlareChain(Chain):
for span in low_confidence_spans
]
callbacks = _run_manager.get_child()
question_gen_outputs = self.question_generator_chain.apply(
question_gen_inputs, callbacks=callbacks
)
questions = [
output[self.question_generator_chain.output_keys[0]]
for output in question_gen_outputs
]
if isinstance(self.question_generator_chain, LLMChain):
question_gen_outputs = self.question_generator_chain.apply(
question_gen_inputs, callbacks=callbacks
)
questions = [
output[self.question_generator_chain.output_keys[0]]
for output in question_gen_outputs
]
else:
questions = self.question_generator_chain.batch(
question_gen_inputs, config={"callbacks": callbacks}
)
_run_manager.on_text(
f"Generated Questions: {questions}", color="yellow", end="\n"
)
@@ -206,8 +179,10 @@ class FlareChain(Chain):
f"Current Response: {response}", color="blue", end="\n"
)
_input = {"user_input": user_input, "context": "", "response": response}
tokens, log_probs = self.response_chain.generate_tokens_and_log_probs(
_input, run_manager=_run_manager
tokens, log_probs = _extract_tokens_and_log_probs(
self.response_chain.invoke(
_input, {"callbacks": _run_manager.get_child()}
)
)
low_confidence_spans = _low_confidence_spans(
tokens,
@@ -251,18 +226,16 @@ class FlareChain(Chain):
FlareChain class with the given language model.
"""
try:
from langchain_openai import OpenAI
from langchain_openai import ChatOpenAI
except ImportError:
raise ImportError(
"OpenAI is required for FlareChain. "
"Please install langchain-openai."
"pip install langchain-openai"
)
question_gen_chain = QuestionGeneratorChain(llm=llm)
response_llm = OpenAI(
max_tokens=max_generation_len, model_kwargs={"logprobs": 1}, temperature=0
)
response_chain = _OpenAIResponseChain(llm=response_llm)
llm = ChatOpenAI(max_tokens=max_generation_len, logprobs=True, temperature=0)
response_chain = PROMPT | llm
question_gen_chain = QUESTION_GENERATOR_PROMPT | llm | StrOutputParser()
return cls(
question_generator_chain=question_gen_chain,
response_chain=response_chain,
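The new module-level `_extract_tokens_and_log_probs` assumes the response chain is a chat model invoked with logprobs enabled, so per-token data arrives in `response_metadata` rather than `generation_info`. A hedged sketch of that metadata with ChatOpenAI (model name is illustrative):

.. code-block:: python

    from langchain_openai import ChatOpenAI

    llm = ChatOpenAI(model="gpt-4o-mini", logprobs=True, temperature=0)
    msg = llm.invoke("Say hello in one word.")
    # Each entry carries the token text and its log probability, which is what
    # _extract_tokens_and_log_probs iterates over.
    for entry in msg.response_metadata["logprobs"]["content"]:
        print(entry["token"], entry["logprob"])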

View File

@@ -11,7 +11,9 @@ import numpy as np
from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.embeddings import Embeddings
from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import BasePromptTemplate
from langchain_core.runnables import Runnable
from langchain.chains.base import Chain
from langchain.chains.hyde.prompts import PROMPT_MAP
@@ -25,7 +27,7 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
"""
base_embeddings: Embeddings
llm_chain: LLMChain
llm_chain: Runnable
class Config:
arbitrary_types_allowed = True
@@ -34,12 +36,15 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
@property
def input_keys(self) -> List[str]:
"""Input keys for Hyde's LLM chain."""
return self.llm_chain.input_keys
return self.llm_chain.input_schema.schema()["required"]
@property
def output_keys(self) -> List[str]:
"""Output keys for Hyde's LLM chain."""
return self.llm_chain.output_keys
if isinstance(self.llm_chain, LLMChain):
return self.llm_chain.output_keys
else:
return ["text"]
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Call the base embeddings."""
@@ -51,9 +56,12 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
def embed_query(self, text: str) -> List[float]:
"""Generate a hypothetical document and embedded it."""
var_name = self.llm_chain.input_keys[0]
result = self.llm_chain.generate([{var_name: text}])
documents = [generation.text for generation in result.generations[0]]
var_name = self.input_keys[0]
result = self.llm_chain.invoke({var_name: text})
if isinstance(self.llm_chain, LLMChain):
documents = [result[self.output_keys[0]]]
else:
documents = [result]
embeddings = self.embed_documents(documents)
return self.combine_embeddings(embeddings)
@@ -64,7 +72,9 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
) -> Dict[str, str]:
"""Call the internal llm chain."""
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
return self.llm_chain(inputs, callbacks=_run_manager.get_child())
return self.llm_chain.invoke(
inputs, config={"callbacks": _run_manager.get_child()}
)
@classmethod
def from_llm(
@@ -86,7 +96,7 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
f"of {list(PROMPT_MAP.keys())}."
)
llm_chain = LLMChain(llm=llm, prompt=prompt)
llm_chain = prompt | llm | StrOutputParser()
return cls(base_embeddings=base_embeddings, llm_chain=llm_chain, **kwargs)
@property
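With `llm_chain` now a plain runnable, `from_llm` is unchanged from the caller's perspective. A hedged usage sketch (assumes langchain-openai and an OpenAI key; "web_search" is one of the keys in PROMPT_MAP):

.. code-block:: python

    from langchain.chains.hyde.base import HypotheticalDocumentEmbedder
    from langchain_openai import ChatOpenAI, OpenAIEmbeddings

    hyde = HypotheticalDocumentEmbedder.from_llm(
        llm=ChatOpenAI(model="gpt-4o-mini"),
        base_embeddings=OpenAIEmbeddings(),
        prompt_key="web_search",
    )
    # embed_query now runs prompt | llm | StrOutputParser() and embeds the single
    # hypothetical document it returns.
    vector = hyde.embed_query("What causes the aurora borealis?")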

View File

@@ -7,6 +7,7 @@ import re
import warnings
from typing import Any, Dict, List, Optional
from langchain_core._api import deprecated
from langchain_core.callbacks import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
@@ -20,16 +21,132 @@ from langchain.chains.llm import LLMChain
from langchain.chains.llm_math.prompt import PROMPT
@deprecated(
since="0.2.13",
message=(
"This class is deprecated and will be removed in langchain 1.0. "
"See API reference for replacement: "
"https://api.python.langchain.com/en/latest/chains/langchain.chains.llm_math.base.LLMMathChain.html" # noqa: E501
),
removal="1.0",
)
class LLMMathChain(Chain):
"""Chain that interprets a prompt and executes python code to do math.
Note: this class is deprecated. See below for a replacement implementation
using LangGraph. The benefits of this implementation are:
- Uses LLM tool calling features;
- Support for both token-by-token and step-by-step streaming;
- Support for checkpointing and memory of chat history;
- Easier to modify or extend (e.g., with additional tools, structured responses, etc.)
Install LangGraph with:
.. code-block:: bash
pip install -U langgraph
.. code-block:: python
import math
from typing import Annotated, Sequence
from langchain_core.messages import BaseMessage
from langchain_core.runnables import RunnableConfig
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from langgraph.graph import END, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt.tool_node import ToolNode
import numexpr
from typing_extensions import TypedDict
@tool
def calculator(expression: str) -> str:
\"\"\"Calculate expression using Python's numexpr library.
Expression should be a single line mathematical expression
that solves the problem.
Examples:
"37593 * 67" for "37593 times 67"
"37593**(1/5)" for "37593^(1/5)"
\"\"\"
local_dict = {"pi": math.pi, "e": math.e}
return str(
numexpr.evaluate(
expression.strip(),
global_dict={}, # restrict access to globals
local_dict=local_dict, # add common mathematical functions
)
)
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
tools = [calculator]
llm_with_tools = llm.bind_tools(tools, tool_choice="any")
class ChainState(TypedDict):
\"\"\"LangGraph state.\"\"\"
messages: Annotated[Sequence[BaseMessage], add_messages]
async def acall_chain(state: ChainState, config: RunnableConfig):
last_message = state["messages"][-1]
response = await llm_with_tools.ainvoke(state["messages"], config)
return {"messages": [response]}
async def acall_model(state: ChainState, config: RunnableConfig):
response = await llm.ainvoke(state["messages"], config)
return {"messages": [response]}
graph_builder = StateGraph(ChainState)
graph_builder.add_node("call_tool", acall_chain)
graph_builder.add_node("execute_tool", ToolNode(tools))
graph_builder.add_node("call_model", acall_model)
graph_builder.set_entry_point("call_tool")
graph_builder.add_edge("call_tool", "execute_tool")
graph_builder.add_edge("execute_tool", "call_model")
graph_builder.add_edge("call_model", END)
chain = graph_builder.compile()
.. code-block:: python
example_query = "What is 551368 divided by 82"
events = chain.astream(
{"messages": [("user", example_query)]},
stream_mode="values",
)
async for event in events:
event["messages"][-1].pretty_print()
.. code-block:: none
================================ Human Message =================================
What is 551368 divided by 82
================================== Ai Message ==================================
Tool Calls:
calculator (call_MEiGXuJjJ7wGU4aOT86QuGJS)
Call ID: call_MEiGXuJjJ7wGU4aOT86QuGJS
Args:
expression: 551368 / 82
================================= Tool Message =================================
Name: calculator
6724.0
================================== Ai Message ==================================
551368 divided by 82 equals 6724.
Example:
.. code-block:: python
from langchain.chains import LLMMathChain
from langchain_community.llms import OpenAI
llm_math = LLMMathChain.from_llm(OpenAI())
"""
""" # noqa: E501
llm_chain: LLMChain
llm: Optional[BaseLanguageModel] = None

View File

@@ -5,15 +5,27 @@ from __future__ import annotations
import warnings
from typing import Any, Dict, List, Optional
from langchain_core._api import deprecated
from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import StrOutputParser
from langchain_core.pydantic_v1 import root_validator
from langchain_core.runnables import Runnable
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.natbot.prompt import PROMPT
@deprecated(
since="0.2.13",
message=(
"Importing NatBotChain from langchain is deprecated and will be removed in "
"langchain 1.0. Please import from langchain_community instead: "
"from langchain_community.chains.natbot import NatBotChain. "
"You may need to pip install -U langchain-community."
),
removal="1.0",
)
class NatBotChain(Chain):
"""Implement an LLM driven browser.
@@ -37,7 +49,7 @@ class NatBotChain(Chain):
natbot = NatBotChain.from_default("Buy me a new hat.")
"""
llm_chain: LLMChain
llm_chain: Runnable
objective: str
"""Objective that NatBot is tasked with completing."""
llm: Optional[BaseLanguageModel] = None
@@ -60,7 +72,7 @@ class NatBotChain(Chain):
"class method."
)
if "llm_chain" not in values and values["llm"] is not None:
values["llm_chain"] = LLMChain(llm=values["llm"], prompt=PROMPT)
values["llm_chain"] = PROMPT | values["llm"] | StrOutputParser()
return values
@classmethod
@@ -77,7 +89,7 @@ class NatBotChain(Chain):
cls, llm: BaseLanguageModel, objective: str, **kwargs: Any
) -> NatBotChain:
"""Load from LLM."""
llm_chain = LLMChain(llm=llm, prompt=PROMPT)
llm_chain = PROMPT | llm | StrOutputParser()
return cls(llm_chain=llm_chain, objective=objective, **kwargs)
@property
@@ -104,12 +116,14 @@ class NatBotChain(Chain):
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
url = inputs[self.input_url_key]
browser_content = inputs[self.input_browser_content_key]
llm_cmd = self.llm_chain.predict(
objective=self.objective,
url=url[:100],
previous_command=self.previous_command,
browser_content=browser_content[:4500],
callbacks=_run_manager.get_child(),
llm_cmd = self.llm_chain.invoke(
{
"objective": self.objective,
"url": url[:100],
"previous_command": self.previous_command,
"browser_content": browser_content[:4500],
},
config={"callbacks": _run_manager.get_child()},
)
llm_cmd = llm_cmd.strip()
self.previous_command = llm_cmd
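Per the deprecation message, the forward-looking import path is langchain_community. A hedged sketch of the same call pattern exercised by the deleted test below (assumes langchain-community and langchain-openai are installed):

.. code-block:: python

    from langchain_community.chains.natbot import NatBotChain
    from langchain_openai import ChatOpenAI

    natbot = NatBotChain.from_llm(
        ChatOpenAI(model="gpt-4o-mini"), objective="Buy me a new hat."
    )
    # execute() truncates the URL and browser content, runs the runnable chain,
    # and returns the next browser command.
    command = natbot.execute("https://example.com", "<button>Buy hat</button>")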

View File

@@ -8,8 +8,9 @@ from typing import Any, Callable, Dict, Optional, Sequence, cast
from langchain_core.callbacks.manager import Callbacks
from langchain_core.documents import Document
from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import BaseOutputParser
from langchain_core.output_parsers import BaseOutputParser, StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import Runnable
from langchain.chains.llm import LLMChain
from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
@@ -49,12 +50,15 @@ class LLMChainExtractor(BaseDocumentCompressor):
"""Document compressor that uses an LLM chain to extract
the relevant parts of documents."""
llm_chain: LLMChain
llm_chain: Runnable
"""LLM wrapper to use for compressing documents."""
get_input: Callable[[str, Document], dict] = default_get_input
"""Callable for constructing the chain input from the query and a Document."""
class Config:
arbitrary_types_allowed = True
def compress_documents(
self,
documents: Sequence[Document],
@@ -65,10 +69,13 @@ class LLMChainExtractor(BaseDocumentCompressor):
compressed_docs = []
for doc in documents:
_input = self.get_input(query, doc)
output_dict = self.llm_chain.invoke(_input, config={"callbacks": callbacks})
output = output_dict[self.llm_chain.output_key]
if self.llm_chain.prompt.output_parser is not None:
output = self.llm_chain.prompt.output_parser.parse(output)
output_ = self.llm_chain.invoke(_input, config={"callbacks": callbacks})
if isinstance(self.llm_chain, LLMChain):
output = output_[self.llm_chain.output_key]
if self.llm_chain.prompt.output_parser is not None:
output = self.llm_chain.prompt.output_parser.parse(output)
else:
output = output_
if len(output) == 0:
continue
compressed_docs.append(
@@ -85,9 +92,7 @@ class LLMChainExtractor(BaseDocumentCompressor):
"""Compress page content of raw documents asynchronously."""
outputs = await asyncio.gather(
*[
self.llm_chain.apredict_and_parse(
**self.get_input(query, doc), callbacks=callbacks
)
self.llm_chain.ainvoke(self.get_input(query, doc), callbacks=callbacks)
for doc in documents
]
)
@@ -111,5 +116,9 @@ class LLMChainExtractor(BaseDocumentCompressor):
"""Initialize from LLM."""
_prompt = prompt if prompt is not None else _get_default_chain_prompt()
_get_input = get_input if get_input is not None else default_get_input
llm_chain = LLMChain(llm=llm, prompt=_prompt, **(llm_chain_kwargs or {}))
if _prompt.output_parser is not None:
parser = _prompt.output_parser
else:
parser = StrOutputParser()
llm_chain = _prompt | llm | parser
return cls(llm_chain=llm_chain, get_input=_get_input) # type: ignore[arg-type]

View File

@@ -5,7 +5,9 @@ from typing import Any, Callable, Dict, Optional, Sequence
from langchain_core.callbacks.manager import Callbacks
from langchain_core.documents import Document
from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import BasePromptTemplate, PromptTemplate
from langchain_core.runnables import Runnable
from langchain_core.runnables.config import RunnableConfig
from langchain.chains import LLMChain
@@ -32,13 +34,16 @@ def default_get_input(query: str, doc: Document) -> Dict[str, Any]:
class LLMChainFilter(BaseDocumentCompressor):
"""Filter that drops documents that aren't relevant to the query."""
llm_chain: LLMChain
llm_chain: Runnable
"""LLM wrapper to use for filtering documents.
The chain prompt is expected to have a BooleanOutputParser."""
get_input: Callable[[str, Document], dict] = default_get_input
"""Callable for constructing the chain input from the query and a Document."""
class Config:
arbitrary_types_allowed = True
def compress_documents(
self,
documents: Sequence[Document],
@@ -56,11 +61,15 @@ class LLMChainFilter(BaseDocumentCompressor):
documents,
)
for output_dict, doc in outputs:
for output_, doc in outputs:
include_doc = None
output = output_dict[self.llm_chain.output_key]
if self.llm_chain.prompt.output_parser is not None:
include_doc = self.llm_chain.prompt.output_parser.parse(output)
if isinstance(self.llm_chain, LLMChain):
output = output_[self.llm_chain.output_key]
if self.llm_chain.prompt.output_parser is not None:
include_doc = self.llm_chain.prompt.output_parser.parse(output)
else:
if isinstance(output_, bool):
include_doc = output_
if include_doc:
filtered_docs.append(doc)
@@ -82,11 +91,15 @@ class LLMChainFilter(BaseDocumentCompressor):
),
documents,
)
for output_dict, doc in outputs:
for output_, doc in outputs:
include_doc = None
output = output_dict[self.llm_chain.output_key]
if self.llm_chain.prompt.output_parser is not None:
include_doc = self.llm_chain.prompt.output_parser.parse(output)
if isinstance(self.llm_chain, LLMChain):
output = output_[self.llm_chain.output_key]
if self.llm_chain.prompt.output_parser is not None:
include_doc = self.llm_chain.prompt.output_parser.parse(output)
else:
if isinstance(output_, bool):
include_doc = output_
if include_doc:
filtered_docs.append(doc)
@@ -110,5 +123,9 @@ class LLMChainFilter(BaseDocumentCompressor):
A LLMChainFilter that uses the given language model.
"""
_prompt = prompt if prompt is not None else _get_default_chain_prompt()
llm_chain = LLMChain(llm=llm, prompt=_prompt)
if _prompt.output_parser is not None:
parser = _prompt.output_parser
else:
parser = StrOutputParser()
llm_chain = _prompt | llm | parser
return cls(llm_chain=llm_chain, **kwargs)

View File

@@ -7,11 +7,11 @@ from langchain_core.callbacks import (
)
from langchain_core.documents import Document
from langchain_core.language_models import BaseLLM
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import BasePromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.retrievers import BaseRetriever
from langchain.chains.llm import LLMChain
from langchain_core.runnables import Runnable
logger = logging.getLogger(__name__)
@@ -30,7 +30,7 @@ class RePhraseQueryRetriever(BaseRetriever):
Then, retrieve docs for the re-phrased query."""
retriever: BaseRetriever
llm_chain: LLMChain
llm_chain: Runnable
@classmethod
def from_llm(
@@ -51,8 +51,7 @@ class RePhraseQueryRetriever(BaseRetriever):
Returns:
RePhraseQueryRetriever
"""
llm_chain = LLMChain(llm=llm, prompt=prompt)
llm_chain = prompt | llm | StrOutputParser()
return cls(
retriever=retriever,
llm_chain=llm_chain,
@@ -72,8 +71,9 @@ class RePhraseQueryRetriever(BaseRetriever):
Returns:
Relevant documents for re-phrased question
"""
response = self.llm_chain(query, callbacks=run_manager.get_child())
re_phrased_question = response["text"]
re_phrased_question = self.llm_chain.invoke(
query, {"callbacks": run_manager.get_child()}
)
logger.info(f"Re-phrased question: {re_phrased_question}")
docs = self.retriever.invoke(
re_phrased_question, config={"callbacks": run_manager.get_child()}
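A hedged end-to-end sketch of the retriever with the new runnable chain (assumes langchain-openai; InMemoryVectorStore is used only to keep the example self-contained):

.. code-block:: python

    from langchain_core.vectorstores import InMemoryVectorStore
    from langchain_openai import ChatOpenAI, OpenAIEmbeddings
    from langchain.retrievers import RePhraseQueryRetriever

    vectorstore = InMemoryVectorStore.from_texts(
        ["RePhraseQueryRetriever rewrites noisy user queries before retrieval."],
        embedding=OpenAIEmbeddings(),
    )
    retriever = RePhraseQueryRetriever.from_llm(
        retriever=vectorstore.as_retriever(),
        llm=ChatOpenAI(model="gpt-4o-mini", temperature=0),
    )
    # The query is re-phrased by prompt | llm | StrOutputParser(), then handed to
    # the underlying retriever.
    docs = retriever.invoke("hi there! i want to know about the rephrase retriever pls")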

View File

@@ -1,59 +0,0 @@
"""Test functionality related to natbot."""
from typing import Any, Dict, List, Optional
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain.chains.natbot.base import NatBotChain
class FakeLLM(LLM):
"""Fake LLM wrapper for testing purposes."""
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Return `foo` if longer than 10000 words, else `bar`."""
if len(prompt) > 10000:
return "foo"
else:
return "bar"
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "fake"
def get_num_tokens(self, text: str) -> int:
return len(text.split())
@property
def _identifying_params(self) -> Dict[str, Any]:
return {}
def test_proper_inputs() -> None:
"""Test that natbot shortens inputs correctly."""
nat_bot_chain = NatBotChain.from_llm(FakeLLM(), objective="testing")
url = "foo" * 10000
browser_content = "foo" * 10000
output = nat_bot_chain.execute(url, browser_content)
assert output == "bar"
def test_variable_key_naming() -> None:
"""Test that natbot handles variable key naming correctly."""
nat_bot_chain = NatBotChain.from_llm(
FakeLLM(),
objective="testing",
input_url_key="u",
input_browser_content_key="b",
output_key="c",
)
output = nat_bot_chain.execute("foo", "foo")
assert output == "bar"

View File

@@ -0,0 +1,84 @@
from langchain_core.documents import Document
from langchain_core.language_models import FakeListChatModel
from langchain.retrievers.document_compressors import LLMChainExtractor
def test_llm_chain_extractor() -> None:
documents = [
Document(
page_content=(
"The sky is blue. Candlepin bowling is popular in New England."
),
metadata={"a": 1},
),
Document(
page_content=(
"Mercury is the closest planet to the Sun. "
"Candlepin bowling balls are smaller."
),
metadata={"b": 2},
),
Document(page_content="The moon is round.", metadata={"c": 3}),
]
llm = FakeListChatModel(
responses=[
"Candlepin bowling is popular in New England.",
"Candlepin bowling balls are smaller.",
"NO_OUTPUT",
]
)
doc_compressor = LLMChainExtractor.from_llm(llm)
output = doc_compressor.compress_documents(
documents, "Tell me about Candlepin bowling."
)
expected = [
Document(
page_content="Candlepin bowling is popular in New England.",
metadata={"a": 1},
),
Document(
page_content="Candlepin bowling balls are smaller.", metadata={"b": 2}
),
]
assert output == expected
async def test_llm_chain_extractor_async() -> None:
documents = [
Document(
page_content=(
"The sky is blue. Candlepin bowling is popular in New England."
),
metadata={"a": 1},
),
Document(
page_content=(
"Mercury is the closest planet to the Sun. "
"Candlepin bowling balls are smaller."
),
metadata={"b": 2},
),
Document(page_content="The moon is round.", metadata={"c": 3}),
]
llm = FakeListChatModel(
responses=[
"Candlepin bowling is popular in New England.",
"Candlepin bowling balls are smaller.",
"NO_OUTPUT",
]
)
doc_compressor = LLMChainExtractor.from_llm(llm)
output = await doc_compressor.acompress_documents(
documents, "Tell me about Candlepin bowling."
)
expected = [
Document(
page_content="Candlepin bowling is popular in New England.",
metadata={"a": 1},
),
Document(
page_content="Candlepin bowling balls are smaller.", metadata={"b": 2}
),
]
assert output == expected

View File

@@ -0,0 +1,46 @@
from langchain_core.documents import Document
from langchain_core.language_models import FakeListChatModel
from langchain.retrievers.document_compressors import LLMChainFilter
def test_llm_chain_filter() -> None:
documents = [
Document(
page_content="Candlepin bowling is popular in New England.",
metadata={"a": 1},
),
Document(
page_content="Candlepin bowling balls are smaller.",
metadata={"b": 2},
),
Document(page_content="The moon is round.", metadata={"c": 3}),
]
llm = FakeListChatModel(responses=["YES", "YES", "NO"])
doc_compressor = LLMChainFilter.from_llm(llm)
output = doc_compressor.compress_documents(
documents, "Tell me about Candlepin bowling."
)
expected = documents[:2]
assert output == expected
async def test_llm_chain_filter_async() -> None:
documents = [
Document(
page_content="Candlepin bowling is popular in New England.",
metadata={"a": 1},
),
Document(
page_content="Candlepin bowling balls are smaller.",
metadata={"b": 2},
),
Document(page_content="The moon is round.", metadata={"c": 3}),
]
llm = FakeListChatModel(responses=["YES", "YES", "NO"])
doc_compressor = LLMChainFilter.from_llm(llm)
output = await doc_compressor.acompress_documents(
documents, "Tell me about Candlepin bowling."
)
expected = documents[:2]
assert output == expected