diff --git a/libs/core/langchain_core/beta/runnables/context.py b/libs/core/langchain_core/beta/runnables/context.py index 235d8f905e5..1e54f6a583f 100644 --- a/libs/core/langchain_core/beta/runnables/context.py +++ b/libs/core/langchain_core/beta/runnables/context.py @@ -108,8 +108,6 @@ def _config_with_context( raise ValueError( f"Deadlock detected between context keys {key} and {dep}" ) - if len(getters) < 1: - raise ValueError(f"Expected at least one getter for context key {key}") if len(setters) != 1: raise ValueError(f"Expected exactly one setter for context key {key}") setter_idx = setters[0][1] @@ -118,7 +116,8 @@ def _config_with_context( f"Context setter for key {key} must be defined after all getters." ) - context_funcs[getters[0][0].id] = partial(getter, events[key], values) + if getters: + context_funcs[getters[0][0].id] = partial(getter, events[key], values) context_funcs[setters[0][0].id] = partial(setter, events[key], values) return patch_config(config, configurable=context_funcs) diff --git a/libs/langchain/langchain/chains/base.py b/libs/langchain/langchain/chains/base.py index ca3548a752d..dc0822fc066 100644 --- a/libs/langchain/langchain/chains/base.py +++ b/libs/langchain/langchain/chains/base.py @@ -9,6 +9,7 @@ from pathlib import Path from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Type, Union import yaml +from langchain_core.beta.runnables.context import Context from langchain_core.load.dump import dumpd from langchain_core.memory import BaseMemory from langchain_core.outputs import RunInfo @@ -20,7 +21,9 @@ from langchain_core.pydantic_v1 import ( validator, ) from langchain_core.runnables import Runnable, RunnableConfig, RunnableSerializable +from langchain_core.runnables.base import RunnableLambda from langchain_core.runnables.configurable import ConfigurableFieldSpec +from langchain_core.runnables.passthrough import RunnablePassthrough from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.manager import ( @@ -674,6 +677,22 @@ class RunnableChain(Chain, ABC): def as_runnable(self) -> Runnable: ... + def as_runnable_wrapped(self, return_only_outputs: bool = False) -> Runnable: + context = Context.create_scope("runnable-chain") + + def prep_outputs(all: Dict[str, Any]) -> Dict[str, Any]: + print("before outprep", all) + return self.prep_outputs(all["inputs"], all["outputs"], return_only_outputs) + + return ( + self.prep_inputs + | RunnableLambda(lambda i: print("after prep", i) or i.copy()) + | context.setter("inputs") + | self.as_runnable() + | {"outputs": RunnablePassthrough(), "inputs": context.getter("inputs")} + | prep_outputs + ) + @property def InputType(self) -> Type[Dict[str, Any]]: return self.as_runnable().InputType @@ -692,39 +711,51 @@ class RunnableChain(Chain, ABC): @property def config_specs(self) -> List[ConfigurableFieldSpec]: - return self.as_runnable().config_specs + return self.as_runnable_wrapped().config_specs def invoke( self, input: Dict[str, Any], config: Optional[RunnableConfig] = None, + return_only_outputs: bool = False, **kwargs: Any, ) -> Dict[str, Any]: - return self.as_runnable().invoke(input, config, **kwargs) + return self.as_runnable_wrapped(return_only_outputs).invoke( + input, config, **kwargs + ) async def ainvoke( self, input: Dict[str, Any], config: Optional[RunnableConfig] = None, + return_only_outputs: bool = False, **kwargs: Any, ) -> Dict[str, Any]: - return await self.as_runnable().ainvoke(input, config, **kwargs) + return await self.as_runnable_wrapped(return_only_outputs).ainvoke( + input, config, **kwargs + ) def stream( self, input: Dict[str, Any], config: Optional[RunnableConfig] = None, + return_only_outputs: bool = False, **kwargs: Any, ) -> Iterator[Dict[str, Any]]: - yield from self.as_runnable().stream(input, config, **kwargs) + yield from self.as_runnable_wrapped(return_only_outputs).stream( + input, config, **kwargs + ) async def astream( self, input: Dict[str, Any], config: Optional[RunnableConfig] = None, + return_only_outputs: bool = False, **kwargs: Any, ) -> AsyncIterator[Dict[str, Any]]: - async for item in self.as_runnable().astream(input, config, **kwargs): + async for item in self.as_runnable_wrapped(return_only_outputs).astream( + input, config, **kwargs + ): yield item def batch( @@ -733,9 +764,10 @@ class RunnableChain(Chain, ABC): config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, *, return_exceptions: bool = False, + return_only_outputs: bool = False, **kwargs: Any, ) -> List[Dict[str, Any]]: - return self.as_runnable().batch( + return self.as_runnable_wrapped(return_only_outputs).batch( inputs, config, **kwargs, return_exceptions=return_exceptions ) @@ -745,9 +777,10 @@ class RunnableChain(Chain, ABC): config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, *, return_exceptions: bool = False, + return_only_outputs: bool = False, **kwargs: Any, ) -> List[Dict[str, Any]]: - return await self.as_runnable().abatch( + return await self.as_runnable_wrapped(return_only_outputs).abatch( inputs, config, **kwargs, return_exceptions=return_exceptions ) @@ -755,15 +788,21 @@ class RunnableChain(Chain, ABC): self, input: Iterator[Dict[str, Any]], config: Optional[RunnableConfig] = None, + return_only_outputs: bool = False, **kwargs: Any | None, ) -> Iterator[Dict[str, Any]]: - yield from self.as_runnable().transform(input, config, **kwargs) + yield from self.as_runnable_wrapped(return_only_outputs).transform( + input, config, **kwargs + ) async def atransform( self, input: AsyncIterator[Dict[str, Any]], config: Optional[RunnableConfig] = None, + return_only_outputs: bool = False, **kwargs: Any | None, ) -> AsyncIterator[Dict[str, Any]]: - async for chunk in super().atransform(input, config, **kwargs): + async for chunk in self.as_runnable_wrapped(return_only_outputs).atransform( + input, config, **kwargs + ): yield chunk diff --git a/libs/langchain/langchain/chains/conversational_retrieval/base.py b/libs/langchain/langchain/chains/conversational_retrieval/base.py index 147610f9764..759835289bc 100644 --- a/libs/langchain/langchain/chains/conversational_retrieval/base.py +++ b/libs/langchain/langchain/chains/conversational_retrieval/base.py @@ -15,8 +15,13 @@ from langchain_core.messages import BaseMessage from langchain_core.prompts import BasePromptTemplate from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator from langchain_core.retrievers import BaseRetriever -from langchain_core.runnables import Runnable, RunnableConfig -from langchain_core.runnables.base import RunnableMap +from langchain_core.runnables import ( + Runnable, + RunnableConfig, + RunnableLambda, + RunnableMap, +) +from langchain_core.runnables.passthrough import RunnablePassthrough from langchain_core.vectorstores import VectorStore from langchain.callbacks.manager import ( @@ -297,25 +302,33 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain): get_chat_history = self.get_chat_history or _get_chat_history def get_new_question(inputs: Dict[str, Any]): - return ( - self.question_generator - if inputs["chat_history"] - else inputs["question"] - ) + if inputs["chat_history"]: + return self.question_generator + else: + return inputs["question"] def get_answer(inputs: Dict[str, Any]): - return ( - self.response_if_no_docs_found - if self.response_if_no_docs_found is not None + if ( + self.response_if_no_docs_found is not None and len(inputs["input_documents"]) == 0 - else self.combine_docs_chain - ) + ): + return self.response_if_no_docs_found + else: + return self.combine_docs_chain | itemgetter( + self.combine_docs_chain.output_key + ) + + output_map = {self.output_key: RunnablePassthrough()} + if self.return_source_documents: + output_map["source_documents"] = context.getter("input_documents") + if self.return_generated_question: + output_map["generated_question"] = context.getter("new_question") return ( RunnableMap( question=itemgetter("question") | context.setter("question"), chat_history=itemgetter("chat_history") - | get_chat_history + | RunnableLambda(get_chat_history) | context.setter("chat_history"), ) | get_new_question @@ -330,11 +343,8 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain): "new_question" if self.rephrase_question else "question" ), } - | { - self.output_key: get_answer, - "source_documents": context.getter("input_documents"), - "generated_question": context.getter("new_question"), - } + | get_answer + | output_map ) def _reduce_tokens_below_limit(self, docs: List[Document]) -> List[Document]: diff --git a/libs/langchain/langchain/chains/llm.py b/libs/langchain/langchain/chains/llm.py index 3834c79c41d..3fc036d7140 100644 --- a/libs/langchain/langchain/chains/llm.py +++ b/libs/langchain/langchain/chains/llm.py @@ -80,7 +80,7 @@ class LLMChain(RunnableChain): return ( self.prompt | self.llm.bind(**self.llm_kwargs) - | {self.output_key: self.output_parser} + | {self.output_key: self.prompt.output_parser or self.output_parser} ) @property diff --git a/libs/langchain/tests/unit_tests/chains/test_conversation_retrieval.py b/libs/langchain/tests/unit_tests/chains/test_conversation_retrieval.py index d7a56603bbb..a84b871c120 100644 --- a/libs/langchain/tests/unit_tests/chains/test_conversation_retrieval.py +++ b/libs/langchain/tests/unit_tests/chains/test_conversation_retrieval.py @@ -27,8 +27,10 @@ async def atest_simple() -> None: verbose=True, ) got = await qa_chain.acall("What is the answer?") + memory.clear() assert got["chat_history"][1].content == fixed_resp assert got["answer"] == fixed_resp + assert got == await qa_chain.ainvoke("What is the answer?") @pytest.mark.asyncio @@ -37,7 +39,7 @@ async def atest_fixed_message_response_when_docs_found() -> None: answer = "I know the answer!" llm = FakeListLLM(responses=[answer]) retriever = SequentialRetriever( - sequential_responses=[[Document(page_content=answer)]] + sequential_responses=[[Document(page_content=answer)]], cycle=True ) memory = ConversationBufferMemory( k=1, output_key="answer", memory_key="chat_history", return_messages=True @@ -52,8 +54,10 @@ async def atest_fixed_message_response_when_docs_found() -> None: verbose=True, ) got = await qa_chain.acall("What is the answer?") + memory.clear() assert got["chat_history"][1].content == answer assert got["answer"] == answer + assert got == await qa_chain.ainvoke("What is the answer?") def test_fixed_message_response_when_no_docs_found() -> None: @@ -74,8 +78,10 @@ def test_fixed_message_response_when_no_docs_found() -> None: verbose=True, ) got = qa_chain("What is the answer?") + memory.clear() assert got["chat_history"][1].content == fixed_resp assert got["answer"] == fixed_resp + assert got == qa_chain.invoke("What is the answer?") def test_fixed_message_response_when_docs_found() -> None: @@ -83,7 +89,7 @@ def test_fixed_message_response_when_docs_found() -> None: answer = "I know the answer!" llm = FakeListLLM(responses=[answer]) retriever = SequentialRetriever( - sequential_responses=[[Document(page_content=answer)]] + sequential_responses=[[Document(page_content=answer)]], cycle=True ) memory = ConversationBufferMemory( k=1, output_key="answer", memory_key="chat_history", return_messages=True @@ -98,5 +104,7 @@ def test_fixed_message_response_when_docs_found() -> None: verbose=True, ) got = qa_chain("What is the answer?") + memory.clear() assert got["chat_history"][1].content == answer assert got["answer"] == answer + assert got == qa_chain.invoke("What is the answer?") diff --git a/libs/langchain/tests/unit_tests/chains/test_llm.py b/libs/langchain/tests/unit_tests/chains/test_llm.py index 0179cd135f2..63b8ce3eac0 100644 --- a/libs/langchain/tests/unit_tests/chains/test_llm.py +++ b/libs/langchain/tests/unit_tests/chains/test_llm.py @@ -51,6 +51,7 @@ def test_valid_call(fake_llm_chain: LLMChain) -> None: """Test valid call of LLM chain.""" output = fake_llm_chain({"bar": "baz"}) assert output == {"bar": "baz", "text1": "foo"} + assert fake_llm_chain.invoke({"bar": "baz"}) == output # Test with stop words. output = fake_llm_chain({"bar": "baz", "stop": ["foo"]}) @@ -58,6 +59,18 @@ def test_valid_call(fake_llm_chain: LLMChain) -> None: assert output == {"bar": "baz", "stop": ["foo"], "text1": "bar"} +async def test_valid_acall(fake_llm_chain: LLMChain) -> None: + """Test valid call of LLM chain.""" + output = await fake_llm_chain.acall({"bar": "baz"}) + assert output == {"bar": "baz", "text1": "foo"} + assert await fake_llm_chain.ainvoke({"bar": "baz"}) == output + + # Test with stop words. + output = await fake_llm_chain.acall({"bar": "baz", "stop": ["foo"]}) + # Response should be `bar` now. + assert output == {"bar": "baz", "stop": ["foo"], "text1": "bar"} + + def test_predict_method(fake_llm_chain: LLMChain) -> None: """Test predict method works.""" output = fake_llm_chain.predict(bar="baz") @@ -73,3 +86,4 @@ def test_predict_and_parse() -> None: chain = LLMChain(prompt=prompt, llm=llm) output = chain.predict_and_parse(foo="foo") assert output == ["foo", "bar"] + assert output == chain.invoke({"foo": "foo"})["text"] diff --git a/libs/langchain/tests/unit_tests/retrievers/sequential_retriever.py b/libs/langchain/tests/unit_tests/retrievers/sequential_retriever.py index 45f7a6934d8..3ac54d512d2 100644 --- a/libs/langchain/tests/unit_tests/retrievers/sequential_retriever.py +++ b/libs/langchain/tests/unit_tests/retrievers/sequential_retriever.py @@ -8,16 +8,20 @@ class SequentialRetriever(BaseRetriever): sequential_responses: List[List[Document]] response_index: int = 0 + cycle: bool = False def _get_relevant_documents( # type: ignore[override] self, query: str, ) -> List[Document]: if self.response_index >= len(self.sequential_responses): - return [] - else: - self.response_index += 1 - return self.sequential_responses[self.response_index - 1] + if self.cycle: + self.response_index = 0 + else: + return [] + + self.response_index += 1 + return self.sequential_responses[self.response_index - 1] async def _aget_relevant_documents( # type: ignore[override] self,