langchain[patch]: deprecate various chains (#25310)
- [x] NatbotChain: move to community, deprecate langchain version. Update to use `prompt | llm | output_parser` instead of LLMChain.
- [x] LLMMathChain: deprecate + add langgraph replacement example to API ref
- [x] HypotheticalDocumentEmbedder (retriever): update to use `prompt | llm | output_parser` instead of LLMChain
- [x] FlareChain: update to use `prompt | llm | output_parser` instead of LLMChain
- [x] ConstitutionalChain: deprecate + add langgraph replacement example to API ref
- [x] LLMChainExtractor (document compressor): update to use `prompt | llm | output_parser` instead of LLMChain
- [x] LLMChainFilter (document compressor): update to use `prompt | llm | output_parser` instead of LLMChain
- [x] RePhraseQueryRetriever (retriever): update to use `prompt | llm | output_parser` instead of LLMChain
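All of these updates follow the same migration: an `LLMChain` wrapping a prompt is replaced by a runnable composed with LCEL. A minimal sketch of that pattern (the specific prompt and the use of `ChatOpenAI` here are illustrative assumptions, not part of this diff):

```python
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI

prompt = ChatPromptTemplate.from_template("Rephrase the question: {question}")
llm = ChatOpenAI(model="gpt-4o-mini")

# Before (deprecated): LLMChain returns a dict keyed by its output key ("text").
# chain = LLMChain(llm=llm, prompt=prompt)
# rephrased = chain.invoke({"question": "..."})["text"]

# After: compose prompt | llm | output parser; invoke returns the parsed string directly.
chain = prompt | llm | StrOutputParser()
rephrased = chain.invoke({"question": "what is up with the weather today"})
```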
@@ -2,6 +2,7 @@
from typing import Any, Dict, List, Optional

from langchain_core._api import deprecated
from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate
@@ -13,9 +14,151 @@ from langchain.chains.constitutional_ai.prompts import CRITIQUE_PROMPT, REVISION
from langchain.chains.llm import LLMChain


@deprecated(
    since="0.2.13",
    message=(
        "This class is deprecated and will be removed in langchain 1.0. "
        "See API reference for replacement: "
        "https://api.python.langchain.com/en/latest/chains/langchain.chains.constitutional_ai.base.ConstitutionalChain.html" # noqa: E501
    ),
    removal="1.0",
)
class ConstitutionalChain(Chain):
    """Chain for applying constitutional principles.

    Note: this class is deprecated. See below for a replacement implementation
    using LangGraph. The benefits of this implementation are:

    - Uses LLM tool calling features instead of parsing string responses;
    - Support for both token-by-token and step-by-step streaming;
    - Support for checkpointing and memory of chat history;
    - Easier to modify or extend (e.g., with additional tools, structured responses, etc.)

    Install LangGraph with:

    .. code-block:: bash

        pip install -U langgraph

    .. code-block:: python

        from typing import List, Optional, Tuple

        from langchain.chains.constitutional_ai.prompts import (
            CRITIQUE_PROMPT,
            REVISION_PROMPT,
        )
        from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple
        from langchain_core.output_parsers import StrOutputParser
        from langchain_core.prompts import ChatPromptTemplate
        from langchain_openai import ChatOpenAI
        from langgraph.graph import END, START, StateGraph
        from typing_extensions import Annotated, TypedDict

        llm = ChatOpenAI(model="gpt-4o-mini")

        class Critique(TypedDict):
            \"\"\"Generate a critique, if needed.\"\"\"
            critique_needed: Annotated[bool, ..., "Whether or not a critique is needed."]
            critique: Annotated[str, ..., "If needed, the critique."]

        critique_prompt = ChatPromptTemplate.from_template(
            "Critique this response according to the critique request. "
            "If no critique is needed, specify that.\\n\\n"
            "Query: {query}\\n\\n"
            "Response: {response}\\n\\n"
            "Critique request: {critique_request}"
        )

        revision_prompt = ChatPromptTemplate.from_template(
            "Revise this response according to the critique and revision request.\\n\\n"
            "Query: {query}\\n\\n"
            "Response: {response}\\n\\n"
            "Critique request: {critique_request}\\n\\n"
            "Critique: {critique}\\n\\n"
            "If the critique does not identify anything worth changing, ignore the "
            "revision request and return 'No revisions needed'. If the critique "
            "does identify something worth changing, revise the response based on "
            "the revision request.\\n\\n"
            "Revision Request: {revision_request}"
        )

        chain = llm | StrOutputParser()
        critique_chain = critique_prompt | llm.with_structured_output(Critique)
        revision_chain = revision_prompt | llm | StrOutputParser()


        class State(TypedDict):
            query: str
            constitutional_principles: List[ConstitutionalPrinciple]
            initial_response: str
            critiques_and_revisions: List[Tuple[str, str]]
            response: str


        async def generate_response(state: State):
            \"\"\"Generate initial response.\"\"\"
            response = await chain.ainvoke(state["query"])
            return {"response": response, "initial_response": response}

        async def critique_and_revise(state: State):
            \"\"\"Critique and revise response according to principles.\"\"\"
            critiques_and_revisions = []
            response = state["initial_response"]
            for principle in state["constitutional_principles"]:
                critique = await critique_chain.ainvoke(
                    {
                        "query": state["query"],
                        "response": response,
                        "critique_request": principle.critique_request,
                    }
                )
                if critique["critique_needed"]:
                    revision = await revision_chain.ainvoke(
                        {
                            "query": state["query"],
                            "response": response,
                            "critique_request": principle.critique_request,
                            "critique": critique["critique"],
                            "revision_request": principle.revision_request,
                        }
                    )
                    response = revision
                    critiques_and_revisions.append((critique["critique"], revision))
                else:
                    critiques_and_revisions.append((critique["critique"], ""))
            return {
                "critiques_and_revisions": critiques_and_revisions,
                "response": response,
            }

        graph = StateGraph(State)
        graph.add_node("generate_response", generate_response)
        graph.add_node("critique_and_revise", critique_and_revise)

        graph.add_edge(START, "generate_response")
        graph.add_edge("generate_response", "critique_and_revise")
        graph.add_edge("critique_and_revise", END)
        app = graph.compile()

    .. code-block:: python

        constitutional_principles = [
            ConstitutionalPrinciple(
                critique_request="Tell if this answer is good.",
                revision_request="Give a better answer.",
            )
        ]

        query = "What is the meaning of life? Answer in 10 words or fewer."

        async for step in app.astream(
            {"query": query, "constitutional_principles": constitutional_principles},
            stream_mode="values",
        ):
            subset = ["initial_response", "critiques_and_revisions", "response"]
            print({k: v for k, v in step.items() if k in subset})

    Example:
        .. code-block:: python

@@ -44,7 +187,7 @@ class ConstitutionalChain(Chain):
            )

            constitutional_chain.run(question="What is the meaning of life?")
    """
    """ # noqa: E501

    chain: LLMChain
    constitutional_principles: List[ConstitutionalPrinciple]
@@ -1,7 +1,6 @@
from __future__ import annotations

import re
from abc import abstractmethod
from typing import Any, Dict, List, Optional, Sequence, Tuple

import numpy as np
@@ -9,10 +8,12 @@ from langchain_core.callbacks import (
    CallbackManagerForChainRun,
)
from langchain_core.language_models import BaseLanguageModel
from langchain_core.outputs import Generation
from langchain_core.messages import AIMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import Field
from langchain_core.retrievers import BaseRetriever
from langchain_core.runnables import Runnable

from langchain.chains.base import Chain
from langchain.chains.flare.prompts import (
@@ -23,51 +24,14 @@ from langchain.chains.flare.prompts import (
from langchain.chains.llm import LLMChain


class _ResponseChain(LLMChain):
    """Base class for chains that generate responses."""

    prompt: BasePromptTemplate = PROMPT

    @classmethod
    def is_lc_serializable(cls) -> bool:
        return False

    @property
    def input_keys(self) -> List[str]:
        return self.prompt.input_variables

    def generate_tokens_and_log_probs(
        self,
        _input: Dict[str, Any],
        *,
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Tuple[Sequence[str], Sequence[float]]:
        llm_result = self.generate([_input], run_manager=run_manager)
        return self._extract_tokens_and_log_probs(llm_result.generations[0])

    @abstractmethod
    def _extract_tokens_and_log_probs(
        self, generations: List[Generation]
    ) -> Tuple[Sequence[str], Sequence[float]]:
        """Extract tokens and log probs from response."""


class _OpenAIResponseChain(_ResponseChain):
    """Chain that generates responses from user input and context."""

    llm: BaseLanguageModel

    def _extract_tokens_and_log_probs(
        self, generations: List[Generation]
    ) -> Tuple[Sequence[str], Sequence[float]]:
        tokens = []
        log_probs = []
        for gen in generations:
            if gen.generation_info is None:
                raise ValueError
            tokens.extend(gen.generation_info["logprobs"]["tokens"])
            log_probs.extend(gen.generation_info["logprobs"]["token_logprobs"])
        return tokens, log_probs
def _extract_tokens_and_log_probs(response: AIMessage) -> Tuple[List[str], List[float]]:
    """Extract tokens and log probabilities from chat model response."""
    tokens = []
    log_probs = []
    for token in response.response_metadata["logprobs"]["content"]:
        tokens.append(token["token"])
        log_probs.append(token["logprob"])
    return tokens, log_probs


class QuestionGeneratorChain(LLMChain):
@@ -111,9 +75,9 @@ class FlareChain(Chain):
    """Chain that combines a retriever, a question generator,
    and a response generator."""

    question_generator_chain: QuestionGeneratorChain
    question_generator_chain: Runnable
    """Chain that generates questions from uncertain spans."""
    response_chain: _ResponseChain
    response_chain: Runnable
    """Chain that generates responses from user input and context."""
    output_parser: FinishedOutputParser = Field(default_factory=FinishedOutputParser)
    """Parser that determines whether the chain is finished."""
@@ -152,12 +116,16 @@ class FlareChain(Chain):
        for question in questions:
            docs.extend(self.retriever.invoke(question))
        context = "\n\n".join(d.page_content for d in docs)
        result = self.response_chain.predict(
            user_input=user_input,
            context=context,
            response=response,
            callbacks=callbacks,
        result = self.response_chain.invoke(
            {
                "user_input": user_input,
                "context": context,
                "response": response,
            },
            {"callbacks": callbacks},
        )
        if isinstance(result, AIMessage):
            result = result.content
        marginal, finished = self.output_parser.parse(result)
        return marginal, finished

@@ -178,13 +146,18 @@ class FlareChain(Chain):
            for span in low_confidence_spans
        ]
        callbacks = _run_manager.get_child()
        question_gen_outputs = self.question_generator_chain.apply(
            question_gen_inputs, callbacks=callbacks
        )
        questions = [
            output[self.question_generator_chain.output_keys[0]]
            for output in question_gen_outputs
        ]
        if isinstance(self.question_generator_chain, LLMChain):
            question_gen_outputs = self.question_generator_chain.apply(
                question_gen_inputs, callbacks=callbacks
            )
            questions = [
                output[self.question_generator_chain.output_keys[0]]
                for output in question_gen_outputs
            ]
        else:
            questions = self.question_generator_chain.batch(
                question_gen_inputs, config={"callbacks": callbacks}
            )
        _run_manager.on_text(
            f"Generated Questions: {questions}", color="yellow", end="\n"
        )
@@ -206,8 +179,10 @@ class FlareChain(Chain):
                f"Current Response: {response}", color="blue", end="\n"
            )
            _input = {"user_input": user_input, "context": "", "response": response}
            tokens, log_probs = self.response_chain.generate_tokens_and_log_probs(
                _input, run_manager=_run_manager
            tokens, log_probs = _extract_tokens_and_log_probs(
                self.response_chain.invoke(
                    _input, {"callbacks": _run_manager.get_child()}
                )
            )
            low_confidence_spans = _low_confidence_spans(
                tokens,
@@ -251,18 +226,16 @@ class FlareChain(Chain):
            FlareChain class with the given language model.
        """
        try:
            from langchain_openai import OpenAI
            from langchain_openai import ChatOpenAI
        except ImportError:
            raise ImportError(
                "OpenAI is required for FlareChain. "
                "Please install langchain-openai."
                "pip install langchain-openai"
            )
        question_gen_chain = QuestionGeneratorChain(llm=llm)
        response_llm = OpenAI(
            max_tokens=max_generation_len, model_kwargs={"logprobs": 1}, temperature=0
        )
        response_chain = _OpenAIResponseChain(llm=response_llm)
        llm = ChatOpenAI(max_tokens=max_generation_len, logprobs=True, temperature=0)
        response_chain = PROMPT | llm
        question_gen_chain = QUESTION_GENERATOR_PROMPT | llm | StrOutputParser()
        return cls(
            question_generator_chain=question_gen_chain,
            response_chain=response_chain,
@@ -11,7 +11,9 @@ import numpy as np
from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.embeddings import Embeddings
from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import BasePromptTemplate
from langchain_core.runnables import Runnable

from langchain.chains.base import Chain
from langchain.chains.hyde.prompts import PROMPT_MAP
@@ -25,7 +27,7 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
    """

    base_embeddings: Embeddings
    llm_chain: LLMChain
    llm_chain: Runnable

    class Config:
        arbitrary_types_allowed = True
@@ -34,12 +36,15 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
    @property
    def input_keys(self) -> List[str]:
        """Input keys for Hyde's LLM chain."""
        return self.llm_chain.input_keys
        return self.llm_chain.input_schema.schema()["required"]

    @property
    def output_keys(self) -> List[str]:
        """Output keys for Hyde's LLM chain."""
        return self.llm_chain.output_keys
        if isinstance(self.llm_chain, LLMChain):
            return self.llm_chain.output_keys
        else:
            return ["text"]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Call the base embeddings."""
@@ -51,9 +56,12 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):

    def embed_query(self, text: str) -> List[float]:
        """Generate a hypothetical document and embedded it."""
        var_name = self.llm_chain.input_keys[0]
        result = self.llm_chain.generate([{var_name: text}])
        documents = [generation.text for generation in result.generations[0]]
        var_name = self.input_keys[0]
        result = self.llm_chain.invoke({var_name: text})
        if isinstance(self.llm_chain, LLMChain):
            documents = [result[self.output_keys[0]]]
        else:
            documents = [result]
        embeddings = self.embed_documents(documents)
        return self.combine_embeddings(embeddings)

@@ -64,7 +72,9 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
    ) -> Dict[str, str]:
        """Call the internal llm chain."""
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        return self.llm_chain(inputs, callbacks=_run_manager.get_child())
        return self.llm_chain.invoke(
            inputs, config={"callbacks": _run_manager.get_child()}
        )

    @classmethod
    def from_llm(
@@ -86,7 +96,7 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
                f"of {list(PROMPT_MAP.keys())}."
            )

        llm_chain = LLMChain(llm=llm, prompt=prompt)
        llm_chain = prompt | llm | StrOutputParser()
        return cls(base_embeddings=base_embeddings, llm_chain=llm_chain, **kwargs)

    @property
@@ -7,6 +7,7 @@ import re
import warnings
from typing import Any, Dict, List, Optional

from langchain_core._api import deprecated
from langchain_core.callbacks import (
    AsyncCallbackManagerForChainRun,
    CallbackManagerForChainRun,
@@ -20,16 +21,132 @@ from langchain.chains.llm import LLMChain
from langchain.chains.llm_math.prompt import PROMPT


@deprecated(
    since="0.2.13",
    message=(
        "This class is deprecated and will be removed in langchain 1.0. "
        "See API reference for replacement: "
        "https://api.python.langchain.com/en/latest/chains/langchain.chains.llm_math.base.LLMMathChain.html" # noqa: E501
    ),
    removal="1.0",
)
class LLMMathChain(Chain):
    """Chain that interprets a prompt and executes python code to do math.

    Note: this class is deprecated. See below for a replacement implementation
    using LangGraph. The benefits of this implementation are:

    - Uses LLM tool calling features;
    - Support for both token-by-token and step-by-step streaming;
    - Support for checkpointing and memory of chat history;
    - Easier to modify or extend (e.g., with additional tools, structured responses, etc.)

    Install LangGraph with:

    .. code-block:: bash

        pip install -U langgraph

    .. code-block:: python

        import math
        from typing import Annotated, Sequence

        from langchain_core.messages import BaseMessage
        from langchain_core.runnables import RunnableConfig
        from langchain_core.tools import tool
        from langchain_openai import ChatOpenAI
        from langgraph.graph import END, StateGraph
        from langgraph.graph.message import add_messages
        from langgraph.prebuilt.tool_node import ToolNode
        import numexpr
        from typing_extensions import TypedDict

        @tool
        def calculator(expression: str) -> str:
            \"\"\"Calculate expression using Python's numexpr library.

            Expression should be a single line mathematical expression
            that solves the problem.

            Examples:
                "37593 * 67" for "37593 times 67"
                "37593**(1/5)" for "37593^(1/5)"
            \"\"\"
            local_dict = {"pi": math.pi, "e": math.e}
            return str(
                numexpr.evaluate(
                    expression.strip(),
                    global_dict={}, # restrict access to globals
                    local_dict=local_dict, # add common mathematical functions
                )
            )

        llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
        tools = [calculator]
        llm_with_tools = llm.bind_tools(tools, tool_choice="any")

        class ChainState(TypedDict):
            \"\"\"LangGraph state.\"\"\"

            messages: Annotated[Sequence[BaseMessage], add_messages]

        async def acall_chain(state: ChainState, config: RunnableConfig):
            last_message = state["messages"][-1]
            response = await llm_with_tools.ainvoke(state["messages"], config)
            return {"messages": [response]}

        async def acall_model(state: ChainState, config: RunnableConfig):
            response = await llm.ainvoke(state["messages"], config)
            return {"messages": [response]}

        graph_builder = StateGraph(ChainState)
        graph_builder.add_node("call_tool", acall_chain)
        graph_builder.add_node("execute_tool", ToolNode(tools))
        graph_builder.add_node("call_model", acall_model)
        graph_builder.set_entry_point("call_tool")
        graph_builder.add_edge("call_tool", "execute_tool")
        graph_builder.add_edge("execute_tool", "call_model")
        graph_builder.add_edge("call_model", END)
        chain = graph_builder.compile()

    .. code-block:: python

        example_query = "What is 551368 divided by 82"

        events = chain.astream(
            {"messages": [("user", example_query)]},
            stream_mode="values",
        )
        async for event in events:
            event["messages"][-1].pretty_print()

    .. code-block:: none

        ================================ Human Message =================================

        What is 551368 divided by 82
        ================================== Ai Message ==================================
        Tool Calls:
          calculator (call_MEiGXuJjJ7wGU4aOT86QuGJS)
          Call ID: call_MEiGXuJjJ7wGU4aOT86QuGJS
          Args:
            expression: 551368 / 82
        ================================= Tool Message =================================
        Name: calculator

        6724.0
        ================================== Ai Message ==================================

        551368 divided by 82 equals 6724.

    Example:
        .. code-block:: python

            from langchain.chains import LLMMathChain
            from langchain_community.llms import OpenAI
            llm_math = LLMMathChain.from_llm(OpenAI())
    """
    """ # noqa: E501

    llm_chain: LLMChain
    llm: Optional[BaseLanguageModel] = None
@@ -5,15 +5,27 @@ from __future__ import annotations
import warnings
from typing import Any, Dict, List, Optional

from langchain_core._api import deprecated
from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import StrOutputParser
from langchain_core.pydantic_v1 import root_validator
from langchain_core.runnables import Runnable

from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.natbot.prompt import PROMPT


@deprecated(
    since="0.2.13",
    message=(
        "Importing NatBotChain from langchain is deprecated and will be removed in "
        "langchain 1.0. Please import from langchain_community instead: "
        "from langchain_community.chains.natbot import NatBotChain. "
        "You may need to pip install -U langchain-community."
    ),
    removal="1.0",
)
class NatBotChain(Chain):
    """Implement an LLM driven browser.

@@ -37,7 +49,7 @@ class NatBotChain(Chain):
            natbot = NatBotChain.from_default("Buy me a new hat.")
    """

    llm_chain: LLMChain
    llm_chain: Runnable
    objective: str
    """Objective that NatBot is tasked with completing."""
    llm: Optional[BaseLanguageModel] = None
@@ -60,7 +72,7 @@ class NatBotChain(Chain):
                "class method."
            )
        if "llm_chain" not in values and values["llm"] is not None:
            values["llm_chain"] = LLMChain(llm=values["llm"], prompt=PROMPT)
            values["llm_chain"] = PROMPT | values["llm"] | StrOutputParser()
        return values

    @classmethod
@@ -77,7 +89,7 @@ class NatBotChain(Chain):
        cls, llm: BaseLanguageModel, objective: str, **kwargs: Any
    ) -> NatBotChain:
        """Load from LLM."""
        llm_chain = LLMChain(llm=llm, prompt=PROMPT)
        llm_chain = PROMPT | llm | StrOutputParser()
        return cls(llm_chain=llm_chain, objective=objective, **kwargs)

    @property
@@ -104,12 +116,14 @@ class NatBotChain(Chain):
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        url = inputs[self.input_url_key]
        browser_content = inputs[self.input_browser_content_key]
        llm_cmd = self.llm_chain.predict(
            objective=self.objective,
            url=url[:100],
            previous_command=self.previous_command,
            browser_content=browser_content[:4500],
            callbacks=_run_manager.get_child(),
        llm_cmd = self.llm_chain.invoke(
            {
                "objective": self.objective,
                "url": url[:100],
                "previous_command": self.previous_command,
                "browser_content": browser_content[:4500],
            },
            config={"callbacks": _run_manager.get_child()},
        )
        llm_cmd = llm_cmd.strip()
        self.previous_command = llm_cmd
@@ -8,8 +8,9 @@ from typing import Any, Callable, Dict, Optional, Sequence, cast
from langchain_core.callbacks.manager import Callbacks
from langchain_core.documents import Document
from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import BaseOutputParser
from langchain_core.output_parsers import BaseOutputParser, StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import Runnable

from langchain.chains.llm import LLMChain
from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
@@ -49,12 +50,15 @@ class LLMChainExtractor(BaseDocumentCompressor):
    """Document compressor that uses an LLM chain to extract
    the relevant parts of documents."""

    llm_chain: LLMChain
    llm_chain: Runnable
    """LLM wrapper to use for compressing documents."""

    get_input: Callable[[str, Document], dict] = default_get_input
    """Callable for constructing the chain input from the query and a Document."""

    class Config:
        arbitrary_types_allowed = True

    def compress_documents(
        self,
        documents: Sequence[Document],
@@ -65,10 +69,13 @@ class LLMChainExtractor(BaseDocumentCompressor):
        compressed_docs = []
        for doc in documents:
            _input = self.get_input(query, doc)
            output_dict = self.llm_chain.invoke(_input, config={"callbacks": callbacks})
            output = output_dict[self.llm_chain.output_key]
            if self.llm_chain.prompt.output_parser is not None:
                output = self.llm_chain.prompt.output_parser.parse(output)
            output_ = self.llm_chain.invoke(_input, config={"callbacks": callbacks})
            if isinstance(self.llm_chain, LLMChain):
                output = output_[self.llm_chain.output_key]
                if self.llm_chain.prompt.output_parser is not None:
                    output = self.llm_chain.prompt.output_parser.parse(output)
            else:
                output = output_
            if len(output) == 0:
                continue
            compressed_docs.append(
@@ -85,9 +92,7 @@ class LLMChainExtractor(BaseDocumentCompressor):
        """Compress page content of raw documents asynchronously."""
        outputs = await asyncio.gather(
            *[
                self.llm_chain.apredict_and_parse(
                    **self.get_input(query, doc), callbacks=callbacks
                )
                self.llm_chain.ainvoke(self.get_input(query, doc), callbacks=callbacks)
                for doc in documents
            ]
        )
@@ -111,5 +116,9 @@ class LLMChainExtractor(BaseDocumentCompressor):
        """Initialize from LLM."""
        _prompt = prompt if prompt is not None else _get_default_chain_prompt()
        _get_input = get_input if get_input is not None else default_get_input
        llm_chain = LLMChain(llm=llm, prompt=_prompt, **(llm_chain_kwargs or {}))
        if _prompt.output_parser is not None:
            parser = _prompt.output_parser
        else:
            parser = StrOutputParser()
        llm_chain = _prompt | llm | parser
        return cls(llm_chain=llm_chain, get_input=_get_input) # type: ignore[arg-type]
@@ -5,7 +5,9 @@ from typing import Any, Callable, Dict, Optional, Sequence
from langchain_core.callbacks.manager import Callbacks
from langchain_core.documents import Document
from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import BasePromptTemplate, PromptTemplate
from langchain_core.runnables import Runnable
from langchain_core.runnables.config import RunnableConfig

from langchain.chains import LLMChain
@@ -32,13 +34,16 @@ def default_get_input(query: str, doc: Document) -> Dict[str, Any]:
class LLMChainFilter(BaseDocumentCompressor):
    """Filter that drops documents that aren't relevant to the query."""

    llm_chain: LLMChain
    llm_chain: Runnable
    """LLM wrapper to use for filtering documents.
    The chain prompt is expected to have a BooleanOutputParser."""

    get_input: Callable[[str, Document], dict] = default_get_input
    """Callable for constructing the chain input from the query and a Document."""

    class Config:
        arbitrary_types_allowed = True

    def compress_documents(
        self,
        documents: Sequence[Document],
@@ -56,11 +61,15 @@ class LLMChainFilter(BaseDocumentCompressor):
            documents,
        )

        for output_dict, doc in outputs:
        for output_, doc in outputs:
            include_doc = None
            output = output_dict[self.llm_chain.output_key]
            if self.llm_chain.prompt.output_parser is not None:
                include_doc = self.llm_chain.prompt.output_parser.parse(output)
            if isinstance(self.llm_chain, LLMChain):
                output = output_[self.llm_chain.output_key]
                if self.llm_chain.prompt.output_parser is not None:
                    include_doc = self.llm_chain.prompt.output_parser.parse(output)
            else:
                if isinstance(output_, bool):
                    include_doc = output_
            if include_doc:
                filtered_docs.append(doc)

@@ -82,11 +91,15 @@ class LLMChainFilter(BaseDocumentCompressor):
            ),
            documents,
        )
        for output_dict, doc in outputs:
        for output_, doc in outputs:
            include_doc = None
            output = output_dict[self.llm_chain.output_key]
            if self.llm_chain.prompt.output_parser is not None:
                include_doc = self.llm_chain.prompt.output_parser.parse(output)
            if isinstance(self.llm_chain, LLMChain):
                output = output_[self.llm_chain.output_key]
                if self.llm_chain.prompt.output_parser is not None:
                    include_doc = self.llm_chain.prompt.output_parser.parse(output)
            else:
                if isinstance(output_, bool):
                    include_doc = output_
            if include_doc:
                filtered_docs.append(doc)

@@ -110,5 +123,9 @@ class LLMChainFilter(BaseDocumentCompressor):
            A LLMChainFilter that uses the given language model.
        """
        _prompt = prompt if prompt is not None else _get_default_chain_prompt()
        llm_chain = LLMChain(llm=llm, prompt=_prompt)
        if _prompt.output_parser is not None:
            parser = _prompt.output_parser
        else:
            parser = StrOutputParser()
        llm_chain = _prompt | llm | parser
        return cls(llm_chain=llm_chain, **kwargs)
@@ -7,11 +7,11 @@ from langchain_core.callbacks import (
)
from langchain_core.documents import Document
from langchain_core.language_models import BaseLLM
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import BasePromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.retrievers import BaseRetriever

from langchain.chains.llm import LLMChain
from langchain_core.runnables import Runnable

logger = logging.getLogger(__name__)

@@ -30,7 +30,7 @@ class RePhraseQueryRetriever(BaseRetriever):
    Then, retrieve docs for the re-phrased query."""

    retriever: BaseRetriever
    llm_chain: LLMChain
    llm_chain: Runnable

    @classmethod
    def from_llm(
@@ -51,8 +51,7 @@ class RePhraseQueryRetriever(BaseRetriever):
        Returns:
            RePhraseQueryRetriever
        """

        llm_chain = LLMChain(llm=llm, prompt=prompt)
        llm_chain = prompt | llm | StrOutputParser()
        return cls(
            retriever=retriever,
            llm_chain=llm_chain,
@@ -72,8 +71,9 @@ class RePhraseQueryRetriever(BaseRetriever):
        Returns:
            Relevant documents for re-phrased question
        """
        response = self.llm_chain(query, callbacks=run_manager.get_child())
        re_phrased_question = response["text"]
        re_phrased_question = self.llm_chain.invoke(
            query, {"callbacks": run_manager.get_child()}
        )
        logger.info(f"Re-phrased question: {re_phrased_question}")
        docs = self.retriever.invoke(
            re_phrased_question, config={"callbacks": run_manager.get_child()}
@@ -1,59 +0,0 @@
"""Test functionality related to natbot."""

from typing import Any, Dict, List, Optional

from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM

from langchain.chains.natbot.base import NatBotChain


class FakeLLM(LLM):
    """Fake LLM wrapper for testing purposes."""

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Return `foo` if longer than 10000 words, else `bar`."""
        if len(prompt) > 10000:
            return "foo"
        else:
            return "bar"

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "fake"

    def get_num_tokens(self, text: str) -> int:
        return len(text.split())

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        return {}


def test_proper_inputs() -> None:
    """Test that natbot shortens inputs correctly."""
    nat_bot_chain = NatBotChain.from_llm(FakeLLM(), objective="testing")
    url = "foo" * 10000
    browser_content = "foo" * 10000
    output = nat_bot_chain.execute(url, browser_content)
    assert output == "bar"


def test_variable_key_naming() -> None:
    """Test that natbot handles variable key naming correctly."""
    nat_bot_chain = NatBotChain.from_llm(
        FakeLLM(),
        objective="testing",
        input_url_key="u",
        input_browser_content_key="b",
        output_key="c",
    )
    output = nat_bot_chain.execute("foo", "foo")
    assert output == "bar"
@@ -0,0 +1,84 @@
from langchain_core.documents import Document
from langchain_core.language_models import FakeListChatModel

from langchain.retrievers.document_compressors import LLMChainExtractor


def test_llm_chain_extractor() -> None:
    documents = [
        Document(
            page_content=(
                "The sky is blue. Candlepin bowling is popular in New England."
            ),
            metadata={"a": 1},
        ),
        Document(
            page_content=(
                "Mercury is the closest planet to the Sun. "
                "Candlepin bowling balls are smaller."
            ),
            metadata={"b": 2},
        ),
        Document(page_content="The moon is round.", metadata={"c": 3}),
    ]
    llm = FakeListChatModel(
        responses=[
            "Candlepin bowling is popular in New England.",
            "Candlepin bowling balls are smaller.",
            "NO_OUTPUT",
        ]
    )
    doc_compressor = LLMChainExtractor.from_llm(llm)
    output = doc_compressor.compress_documents(
        documents, "Tell me about Candlepin bowling."
    )
    expected = documents = [
        Document(
            page_content="Candlepin bowling is popular in New England.",
            metadata={"a": 1},
        ),
        Document(
            page_content="Candlepin bowling balls are smaller.", metadata={"b": 2}
        ),
    ]
    assert output == expected


async def test_llm_chain_extractor_async() -> None:
    documents = [
        Document(
            page_content=(
                "The sky is blue. Candlepin bowling is popular in New England."
            ),
            metadata={"a": 1},
        ),
        Document(
            page_content=(
                "Mercury is the closest planet to the Sun. "
                "Candlepin bowling balls are smaller."
            ),
            metadata={"b": 2},
        ),
        Document(page_content="The moon is round.", metadata={"c": 3}),
    ]
    llm = FakeListChatModel(
        responses=[
            "Candlepin bowling is popular in New England.",
            "Candlepin bowling balls are smaller.",
            "NO_OUTPUT",
        ]
    )
    doc_compressor = LLMChainExtractor.from_llm(llm)
    output = await doc_compressor.acompress_documents(
        documents, "Tell me about Candlepin bowling."
    )
    expected = documents = [
        Document(
            page_content="Candlepin bowling is popular in New England.",
            metadata={"a": 1},
        ),
        Document(
            page_content="Candlepin bowling balls are smaller.", metadata={"b": 2}
        ),
    ]
    assert output == expected
@@ -0,0 +1,46 @@
from langchain_core.documents import Document
from langchain_core.language_models import FakeListChatModel

from langchain.retrievers.document_compressors import LLMChainFilter


def test_llm_chain_filter() -> None:
    documents = [
        Document(
            page_content="Candlepin bowling is popular in New England.",
            metadata={"a": 1},
        ),
        Document(
            page_content="Candlepin bowling balls are smaller.",
            metadata={"b": 2},
        ),
        Document(page_content="The moon is round.", metadata={"c": 3}),
    ]
    llm = FakeListChatModel(responses=["YES", "YES", "NO"])
    doc_compressor = LLMChainFilter.from_llm(llm)
    output = doc_compressor.compress_documents(
        documents, "Tell me about Candlepin bowling."
    )
    expected = documents[:2]
    assert output == expected


async def test_llm_chain_extractor_async() -> None:
    documents = [
        Document(
            page_content="Candlepin bowling is popular in New England.",
            metadata={"a": 1},
        ),
        Document(
            page_content="Candlepin bowling balls are smaller.",
            metadata={"b": 2},
        ),
        Document(page_content="The moon is round.", metadata={"c": 3}),
    ]
    llm = FakeListChatModel(responses=["YES", "YES", "NO"])
    doc_compressor = LLMChainFilter.from_llm(llm)
    output = await doc_compressor.acompress_documents(
        documents, "Tell me about Candlepin bowling."
    )
    expected = documents[:2]
    assert output == expected