mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-21 06:33:41 +00:00
Add tests
This commit is contained in:
@@ -108,8 +108,6 @@ def _config_with_context(
|
||||
raise ValueError(
|
||||
f"Deadlock detected between context keys {key} and {dep}"
|
||||
)
|
||||
if len(getters) < 1:
|
||||
raise ValueError(f"Expected at least one getter for context key {key}")
|
||||
if len(setters) != 1:
|
||||
raise ValueError(f"Expected exactly one setter for context key {key}")
|
||||
setter_idx = setters[0][1]
|
||||
@@ -118,7 +116,8 @@ def _config_with_context(
|
||||
f"Context setter for key {key} must be defined after all getters."
|
||||
)
|
||||
|
||||
context_funcs[getters[0][0].id] = partial(getter, events[key], values)
|
||||
if getters:
|
||||
context_funcs[getters[0][0].id] = partial(getter, events[key], values)
|
||||
context_funcs[setters[0][0].id] = partial(setter, events[key], values)
|
||||
|
||||
return patch_config(config, configurable=context_funcs)
|
||||
|
||||
@@ -9,6 +9,7 @@ from pathlib import Path
|
||||
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Type, Union
|
||||
|
||||
import yaml
|
||||
from langchain_core.beta.runnables.context import Context
|
||||
from langchain_core.load.dump import dumpd
|
||||
from langchain_core.memory import BaseMemory
|
||||
from langchain_core.outputs import RunInfo
|
||||
@@ -20,7 +21,9 @@ from langchain_core.pydantic_v1 import (
|
||||
validator,
|
||||
)
|
||||
from langchain_core.runnables import Runnable, RunnableConfig, RunnableSerializable
|
||||
from langchain_core.runnables.base import RunnableLambda
|
||||
from langchain_core.runnables.configurable import ConfigurableFieldSpec
|
||||
from langchain_core.runnables.passthrough import RunnablePassthrough
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.callbacks.manager import (
|
||||
@@ -674,6 +677,22 @@ class RunnableChain(Chain, ABC):
|
||||
def as_runnable(self) -> Runnable:
|
||||
...
|
||||
|
||||
def as_runnable_wrapped(self, return_only_outputs: bool = False) -> Runnable:
|
||||
context = Context.create_scope("runnable-chain")
|
||||
|
||||
def prep_outputs(all: Dict[str, Any]) -> Dict[str, Any]:
|
||||
print("before outprep", all)
|
||||
return self.prep_outputs(all["inputs"], all["outputs"], return_only_outputs)
|
||||
|
||||
return (
|
||||
self.prep_inputs
|
||||
| RunnableLambda(lambda i: print("after prep", i) or i.copy())
|
||||
| context.setter("inputs")
|
||||
| self.as_runnable()
|
||||
| {"outputs": RunnablePassthrough(), "inputs": context.getter("inputs")}
|
||||
| prep_outputs
|
||||
)
|
||||
|
||||
@property
|
||||
def InputType(self) -> Type[Dict[str, Any]]:
|
||||
return self.as_runnable().InputType
|
||||
@@ -692,39 +711,51 @@ class RunnableChain(Chain, ABC):
|
||||
|
||||
@property
|
||||
def config_specs(self) -> List[ConfigurableFieldSpec]:
|
||||
return self.as_runnable().config_specs
|
||||
return self.as_runnable_wrapped().config_specs
|
||||
|
||||
def invoke(
|
||||
self,
|
||||
input: Dict[str, Any],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
return_only_outputs: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, Any]:
|
||||
return self.as_runnable().invoke(input, config, **kwargs)
|
||||
return self.as_runnable_wrapped(return_only_outputs).invoke(
|
||||
input, config, **kwargs
|
||||
)
|
||||
|
||||
async def ainvoke(
|
||||
self,
|
||||
input: Dict[str, Any],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
return_only_outputs: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, Any]:
|
||||
return await self.as_runnable().ainvoke(input, config, **kwargs)
|
||||
return await self.as_runnable_wrapped(return_only_outputs).ainvoke(
|
||||
input, config, **kwargs
|
||||
)
|
||||
|
||||
def stream(
|
||||
self,
|
||||
input: Dict[str, Any],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
return_only_outputs: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[Dict[str, Any]]:
|
||||
yield from self.as_runnable().stream(input, config, **kwargs)
|
||||
yield from self.as_runnable_wrapped(return_only_outputs).stream(
|
||||
input, config, **kwargs
|
||||
)
|
||||
|
||||
async def astream(
|
||||
self,
|
||||
input: Dict[str, Any],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
return_only_outputs: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[Dict[str, Any]]:
|
||||
async for item in self.as_runnable().astream(input, config, **kwargs):
|
||||
async for item in self.as_runnable_wrapped(return_only_outputs).astream(
|
||||
input, config, **kwargs
|
||||
):
|
||||
yield item
|
||||
|
||||
def batch(
|
||||
@@ -733,9 +764,10 @@ class RunnableChain(Chain, ABC):
|
||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||
*,
|
||||
return_exceptions: bool = False,
|
||||
return_only_outputs: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> List[Dict[str, Any]]:
|
||||
return self.as_runnable().batch(
|
||||
return self.as_runnable_wrapped(return_only_outputs).batch(
|
||||
inputs, config, **kwargs, return_exceptions=return_exceptions
|
||||
)
|
||||
|
||||
@@ -745,9 +777,10 @@ class RunnableChain(Chain, ABC):
|
||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||
*,
|
||||
return_exceptions: bool = False,
|
||||
return_only_outputs: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> List[Dict[str, Any]]:
|
||||
return await self.as_runnable().abatch(
|
||||
return await self.as_runnable_wrapped(return_only_outputs).abatch(
|
||||
inputs, config, **kwargs, return_exceptions=return_exceptions
|
||||
)
|
||||
|
||||
@@ -755,15 +788,21 @@ class RunnableChain(Chain, ABC):
|
||||
self,
|
||||
input: Iterator[Dict[str, Any]],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
return_only_outputs: bool = False,
|
||||
**kwargs: Any | None,
|
||||
) -> Iterator[Dict[str, Any]]:
|
||||
yield from self.as_runnable().transform(input, config, **kwargs)
|
||||
yield from self.as_runnable_wrapped(return_only_outputs).transform(
|
||||
input, config, **kwargs
|
||||
)
|
||||
|
||||
async def atransform(
|
||||
self,
|
||||
input: AsyncIterator[Dict[str, Any]],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
return_only_outputs: bool = False,
|
||||
**kwargs: Any | None,
|
||||
) -> AsyncIterator[Dict[str, Any]]:
|
||||
async for chunk in super().atransform(input, config, **kwargs):
|
||||
async for chunk in self.as_runnable_wrapped(return_only_outputs).atransform(
|
||||
input, config, **kwargs
|
||||
):
|
||||
yield chunk
|
||||
|
||||
@@ -15,8 +15,13 @@ from langchain_core.messages import BaseMessage
|
||||
from langchain_core.prompts import BasePromptTemplate
|
||||
from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from langchain_core.runnables import Runnable, RunnableConfig
|
||||
from langchain_core.runnables.base import RunnableMap
|
||||
from langchain_core.runnables import (
|
||||
Runnable,
|
||||
RunnableConfig,
|
||||
RunnableLambda,
|
||||
RunnableMap,
|
||||
)
|
||||
from langchain_core.runnables.passthrough import RunnablePassthrough
|
||||
from langchain_core.vectorstores import VectorStore
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
@@ -297,25 +302,33 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain):
|
||||
get_chat_history = self.get_chat_history or _get_chat_history
|
||||
|
||||
def get_new_question(inputs: Dict[str, Any]):
|
||||
return (
|
||||
self.question_generator
|
||||
if inputs["chat_history"]
|
||||
else inputs["question"]
|
||||
)
|
||||
if inputs["chat_history"]:
|
||||
return self.question_generator
|
||||
else:
|
||||
return inputs["question"]
|
||||
|
||||
def get_answer(inputs: Dict[str, Any]):
|
||||
return (
|
||||
self.response_if_no_docs_found
|
||||
if self.response_if_no_docs_found is not None
|
||||
if (
|
||||
self.response_if_no_docs_found is not None
|
||||
and len(inputs["input_documents"]) == 0
|
||||
else self.combine_docs_chain
|
||||
)
|
||||
):
|
||||
return self.response_if_no_docs_found
|
||||
else:
|
||||
return self.combine_docs_chain | itemgetter(
|
||||
self.combine_docs_chain.output_key
|
||||
)
|
||||
|
||||
output_map = {self.output_key: RunnablePassthrough()}
|
||||
if self.return_source_documents:
|
||||
output_map["source_documents"] = context.getter("input_documents")
|
||||
if self.return_generated_question:
|
||||
output_map["generated_question"] = context.getter("new_question")
|
||||
|
||||
return (
|
||||
RunnableMap(
|
||||
question=itemgetter("question") | context.setter("question"),
|
||||
chat_history=itemgetter("chat_history")
|
||||
| get_chat_history
|
||||
| RunnableLambda(get_chat_history)
|
||||
| context.setter("chat_history"),
|
||||
)
|
||||
| get_new_question
|
||||
@@ -330,11 +343,8 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain):
|
||||
"new_question" if self.rephrase_question else "question"
|
||||
),
|
||||
}
|
||||
| {
|
||||
self.output_key: get_answer,
|
||||
"source_documents": context.getter("input_documents"),
|
||||
"generated_question": context.getter("new_question"),
|
||||
}
|
||||
| get_answer
|
||||
| output_map
|
||||
)
|
||||
|
||||
def _reduce_tokens_below_limit(self, docs: List[Document]) -> List[Document]:
|
||||
|
||||
@@ -80,7 +80,7 @@ class LLMChain(RunnableChain):
|
||||
return (
|
||||
self.prompt
|
||||
| self.llm.bind(**self.llm_kwargs)
|
||||
| {self.output_key: self.output_parser}
|
||||
| {self.output_key: self.prompt.output_parser or self.output_parser}
|
||||
)
|
||||
|
||||
@property
|
||||
|
||||
@@ -27,8 +27,10 @@ async def atest_simple() -> None:
|
||||
verbose=True,
|
||||
)
|
||||
got = await qa_chain.acall("What is the answer?")
|
||||
memory.clear()
|
||||
assert got["chat_history"][1].content == fixed_resp
|
||||
assert got["answer"] == fixed_resp
|
||||
assert got == await qa_chain.ainvoke("What is the answer?")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -37,7 +39,7 @@ async def atest_fixed_message_response_when_docs_found() -> None:
|
||||
answer = "I know the answer!"
|
||||
llm = FakeListLLM(responses=[answer])
|
||||
retriever = SequentialRetriever(
|
||||
sequential_responses=[[Document(page_content=answer)]]
|
||||
sequential_responses=[[Document(page_content=answer)]], cycle=True
|
||||
)
|
||||
memory = ConversationBufferMemory(
|
||||
k=1, output_key="answer", memory_key="chat_history", return_messages=True
|
||||
@@ -52,8 +54,10 @@ async def atest_fixed_message_response_when_docs_found() -> None:
|
||||
verbose=True,
|
||||
)
|
||||
got = await qa_chain.acall("What is the answer?")
|
||||
memory.clear()
|
||||
assert got["chat_history"][1].content == answer
|
||||
assert got["answer"] == answer
|
||||
assert got == await qa_chain.ainvoke("What is the answer?")
|
||||
|
||||
|
||||
def test_fixed_message_response_when_no_docs_found() -> None:
|
||||
@@ -74,8 +78,10 @@ def test_fixed_message_response_when_no_docs_found() -> None:
|
||||
verbose=True,
|
||||
)
|
||||
got = qa_chain("What is the answer?")
|
||||
memory.clear()
|
||||
assert got["chat_history"][1].content == fixed_resp
|
||||
assert got["answer"] == fixed_resp
|
||||
assert got == qa_chain.invoke("What is the answer?")
|
||||
|
||||
|
||||
def test_fixed_message_response_when_docs_found() -> None:
|
||||
@@ -83,7 +89,7 @@ def test_fixed_message_response_when_docs_found() -> None:
|
||||
answer = "I know the answer!"
|
||||
llm = FakeListLLM(responses=[answer])
|
||||
retriever = SequentialRetriever(
|
||||
sequential_responses=[[Document(page_content=answer)]]
|
||||
sequential_responses=[[Document(page_content=answer)]], cycle=True
|
||||
)
|
||||
memory = ConversationBufferMemory(
|
||||
k=1, output_key="answer", memory_key="chat_history", return_messages=True
|
||||
@@ -98,5 +104,7 @@ def test_fixed_message_response_when_docs_found() -> None:
|
||||
verbose=True,
|
||||
)
|
||||
got = qa_chain("What is the answer?")
|
||||
memory.clear()
|
||||
assert got["chat_history"][1].content == answer
|
||||
assert got["answer"] == answer
|
||||
assert got == qa_chain.invoke("What is the answer?")
|
||||
|
||||
@@ -51,6 +51,7 @@ def test_valid_call(fake_llm_chain: LLMChain) -> None:
|
||||
"""Test valid call of LLM chain."""
|
||||
output = fake_llm_chain({"bar": "baz"})
|
||||
assert output == {"bar": "baz", "text1": "foo"}
|
||||
assert fake_llm_chain.invoke({"bar": "baz"}) == output
|
||||
|
||||
# Test with stop words.
|
||||
output = fake_llm_chain({"bar": "baz", "stop": ["foo"]})
|
||||
@@ -58,6 +59,18 @@ def test_valid_call(fake_llm_chain: LLMChain) -> None:
|
||||
assert output == {"bar": "baz", "stop": ["foo"], "text1": "bar"}
|
||||
|
||||
|
||||
async def test_valid_acall(fake_llm_chain: LLMChain) -> None:
|
||||
"""Test valid call of LLM chain."""
|
||||
output = await fake_llm_chain.acall({"bar": "baz"})
|
||||
assert output == {"bar": "baz", "text1": "foo"}
|
||||
assert await fake_llm_chain.ainvoke({"bar": "baz"}) == output
|
||||
|
||||
# Test with stop words.
|
||||
output = await fake_llm_chain.acall({"bar": "baz", "stop": ["foo"]})
|
||||
# Response should be `bar` now.
|
||||
assert output == {"bar": "baz", "stop": ["foo"], "text1": "bar"}
|
||||
|
||||
|
||||
def test_predict_method(fake_llm_chain: LLMChain) -> None:
|
||||
"""Test predict method works."""
|
||||
output = fake_llm_chain.predict(bar="baz")
|
||||
@@ -73,3 +86,4 @@ def test_predict_and_parse() -> None:
|
||||
chain = LLMChain(prompt=prompt, llm=llm)
|
||||
output = chain.predict_and_parse(foo="foo")
|
||||
assert output == ["foo", "bar"]
|
||||
assert output == chain.invoke({"foo": "foo"})["text"]
|
||||
|
||||
@@ -8,16 +8,20 @@ class SequentialRetriever(BaseRetriever):
|
||||
|
||||
sequential_responses: List[List[Document]]
|
||||
response_index: int = 0
|
||||
cycle: bool = False
|
||||
|
||||
def _get_relevant_documents( # type: ignore[override]
|
||||
self,
|
||||
query: str,
|
||||
) -> List[Document]:
|
||||
if self.response_index >= len(self.sequential_responses):
|
||||
return []
|
||||
else:
|
||||
self.response_index += 1
|
||||
return self.sequential_responses[self.response_index - 1]
|
||||
if self.cycle:
|
||||
self.response_index = 0
|
||||
else:
|
||||
return []
|
||||
|
||||
self.response_index += 1
|
||||
return self.sequential_responses[self.response_index - 1]
|
||||
|
||||
async def _aget_relevant_documents( # type: ignore[override]
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user