Add tests

Nuno Campos
2023-12-20 15:17:00 -08:00
parent 582461945f
commit 775bd6fb90
7 changed files with 111 additions and 37 deletions

View File

@@ -108,8 +108,6 @@ def _config_with_context(
raise ValueError(
f"Deadlock detected between context keys {key} and {dep}"
)
- if len(getters) < 1:
-     raise ValueError(f"Expected at least one getter for context key {key}")
if len(setters) != 1:
raise ValueError(f"Expected exactly one setter for context key {key}")
setter_idx = setters[0][1]
@@ -118,7 +116,8 @@ def _config_with_context(
f"Context setter for key {key} must be defined after all getters."
)
- context_funcs[getters[0][0].id] = partial(getter, events[key], values)
+ if getters:
+     context_funcs[getters[0][0].id] = partial(getter, events[key], values)
context_funcs[setters[0][0].id] = partial(setter, events[key], values)
return patch_config(config, configurable=context_funcs)
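For reference, a minimal illustrative sketch (not part of this commit) of how the beta Context setter/getter pair is wired into an LCEL sequence, and what the relaxed check above now permits:

from langchain_core.beta.runnables.context import Context
from langchain_core.runnables import RunnableLambda, RunnablePassthrough

# The setter stashes the incoming value under "input" and passes it through;
# a getter later in the sequence reads it back.
chain = (
    Context.setter("input")
    | RunnableLambda(str.upper)
    | {"upper": RunnablePassthrough(), "original": Context.getter("input")}
)
chain.invoke("hello")  # {"upper": "HELLO", "original": "hello"}

# With the "at least one getter" check removed, a key that is set but never
# read back is now accepted as well:
set_only = Context.setter("input") | RunnableLambda(len)
set_only.invoke("hello")  # 5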

View File

@@ -9,6 +9,7 @@ from pathlib import Path
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Type, Union
import yaml
+ from langchain_core.beta.runnables.context import Context
from langchain_core.load.dump import dumpd
from langchain_core.memory import BaseMemory
from langchain_core.outputs import RunInfo
@@ -20,7 +21,9 @@ from langchain_core.pydantic_v1 import (
validator,
)
from langchain_core.runnables import Runnable, RunnableConfig, RunnableSerializable
+ from langchain_core.runnables.base import RunnableLambda
from langchain_core.runnables.configurable import ConfigurableFieldSpec
+ from langchain_core.runnables.passthrough import RunnablePassthrough
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import (
@@ -674,6 +677,22 @@ class RunnableChain(Chain, ABC):
def as_runnable(self) -> Runnable:
...
+ def as_runnable_wrapped(self, return_only_outputs: bool = False) -> Runnable:
+     context = Context.create_scope("runnable-chain")
+     def prep_outputs(all: Dict[str, Any]) -> Dict[str, Any]:
+         print("before outprep", all)
+         return self.prep_outputs(all["inputs"], all["outputs"], return_only_outputs)
+     return (
+         self.prep_inputs
+         | RunnableLambda(lambda i: print("after prep", i) or i.copy())
+         | context.setter("inputs")
+         | self.as_runnable()
+         | {"outputs": RunnablePassthrough(), "inputs": context.getter("inputs")}
+         | prep_outputs
+     )
@property
def InputType(self) -> Type[Dict[str, Any]]:
return self.as_runnable().InputType
@@ -692,39 +711,51 @@ class RunnableChain(Chain, ABC):
@property
def config_specs(self) -> List[ConfigurableFieldSpec]:
- return self.as_runnable().config_specs
+ return self.as_runnable_wrapped().config_specs
def invoke(
self,
input: Dict[str, Any],
config: Optional[RunnableConfig] = None,
+ return_only_outputs: bool = False,
**kwargs: Any,
) -> Dict[str, Any]:
- return self.as_runnable().invoke(input, config, **kwargs)
+ return self.as_runnable_wrapped(return_only_outputs).invoke(
+     input, config, **kwargs
+ )
async def ainvoke(
self,
input: Dict[str, Any],
config: Optional[RunnableConfig] = None,
+ return_only_outputs: bool = False,
**kwargs: Any,
) -> Dict[str, Any]:
- return await self.as_runnable().ainvoke(input, config, **kwargs)
+ return await self.as_runnable_wrapped(return_only_outputs).ainvoke(
+     input, config, **kwargs
+ )
def stream(
self,
input: Dict[str, Any],
config: Optional[RunnableConfig] = None,
+ return_only_outputs: bool = False,
**kwargs: Any,
) -> Iterator[Dict[str, Any]]:
- yield from self.as_runnable().stream(input, config, **kwargs)
+ yield from self.as_runnable_wrapped(return_only_outputs).stream(
+     input, config, **kwargs
+ )
async def astream(
self,
input: Dict[str, Any],
config: Optional[RunnableConfig] = None,
+ return_only_outputs: bool = False,
**kwargs: Any,
) -> AsyncIterator[Dict[str, Any]]:
- async for item in self.as_runnable().astream(input, config, **kwargs):
+ async for item in self.as_runnable_wrapped(return_only_outputs).astream(
+     input, config, **kwargs
+ ):
yield item
def batch(
@@ -733,9 +764,10 @@ class RunnableChain(Chain, ABC):
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
+ return_only_outputs: bool = False,
**kwargs: Any,
) -> List[Dict[str, Any]]:
- return self.as_runnable().batch(
+ return self.as_runnable_wrapped(return_only_outputs).batch(
inputs, config, **kwargs, return_exceptions=return_exceptions
)
@@ -745,9 +777,10 @@ class RunnableChain(Chain, ABC):
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
+ return_only_outputs: bool = False,
**kwargs: Any,
) -> List[Dict[str, Any]]:
- return await self.as_runnable().abatch(
+ return await self.as_runnable_wrapped(return_only_outputs).abatch(
inputs, config, **kwargs, return_exceptions=return_exceptions
)
@@ -755,15 +788,21 @@ class RunnableChain(Chain, ABC):
self,
input: Iterator[Dict[str, Any]],
config: Optional[RunnableConfig] = None,
+ return_only_outputs: bool = False,
**kwargs: Any | None,
) -> Iterator[Dict[str, Any]]:
- yield from self.as_runnable().transform(input, config, **kwargs)
+ yield from self.as_runnable_wrapped(return_only_outputs).transform(
+     input, config, **kwargs
+ )
async def atransform(
self,
input: AsyncIterator[Dict[str, Any]],
config: Optional[RunnableConfig] = None,
+ return_only_outputs: bool = False,
**kwargs: Any | None,
) -> AsyncIterator[Dict[str, Any]]:
- async for chunk in super().atransform(input, config, **kwargs):
+ async for chunk in self.as_runnable_wrapped(return_only_outputs).atransform(
+     input, config, **kwargs
+ ):
yield chunk
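Assuming the fake LLMChain fixture used in the tests further down (output key "text1"), the wrapped entry points would behave roughly like this; the chain construction is elided and the keys are taken from that test:

# Hypothetical usage sketch of the wrapped entry points.
result = chain.invoke({"bar": "baz"})
# inputs are stashed in the context scope and merged back by prep_outputs:
# {"bar": "baz", "text1": "foo"}
result = chain.invoke({"bar": "baz"}, return_only_outputs=True)
# prep_outputs drops the inputs: {"text1": "foo"}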

View File

@@ -15,8 +15,13 @@ from langchain_core.messages import BaseMessage
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator
from langchain_core.retrievers import BaseRetriever
- from langchain_core.runnables import Runnable, RunnableConfig
- from langchain_core.runnables.base import RunnableMap
+ from langchain_core.runnables import (
+     Runnable,
+     RunnableConfig,
+     RunnableLambda,
+     RunnableMap,
+ )
+ from langchain_core.runnables.passthrough import RunnablePassthrough
from langchain_core.vectorstores import VectorStore
from langchain.callbacks.manager import (
@@ -297,25 +302,33 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain):
get_chat_history = self.get_chat_history or _get_chat_history
def get_new_question(inputs: Dict[str, Any]):
- return (
-     self.question_generator
-     if inputs["chat_history"]
-     else inputs["question"]
- )
+ if inputs["chat_history"]:
+     return self.question_generator
+ else:
+     return inputs["question"]
def get_answer(inputs: Dict[str, Any]):
- return (
-     self.response_if_no_docs_found
-     if self.response_if_no_docs_found is not None
+ if (
+     self.response_if_no_docs_found is not None
and len(inputs["input_documents"]) == 0
-     else self.combine_docs_chain
- )
+ ):
+     return self.response_if_no_docs_found
+ else:
+     return self.combine_docs_chain | itemgetter(
+         self.combine_docs_chain.output_key
+     )
+ output_map = {self.output_key: RunnablePassthrough()}
+ if self.return_source_documents:
+     output_map["source_documents"] = context.getter("input_documents")
+ if self.return_generated_question:
+     output_map["generated_question"] = context.getter("new_question")
return (
RunnableMap(
question=itemgetter("question") | context.setter("question"),
chat_history=itemgetter("chat_history")
- | get_chat_history
+ | RunnableLambda(get_chat_history)
| context.setter("chat_history"),
)
| get_new_question
@@ -330,11 +343,8 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain):
"new_question" if self.rephrase_question else "question"
),
}
- | {
-     self.output_key: get_answer,
-     "source_documents": context.getter("input_documents"),
-     "generated_question": context.getter("new_question"),
- }
+ | get_answer
+ | output_map
)
def _reduce_tokens_below_limit(self, docs: List[Document]) -> List[Document]:
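The rewritten chain leans on LCEL's dynamic-routing behavior: a plain function piped into a sequence is coerced to a RunnableLambda, and if it returns a Runnable (as get_new_question and get_answer may), that runnable is invoked on the same input. A self-contained sketch of that behavior, independent of this chain:

from langchain_core.runnables import RunnableLambda

upper = RunnableLambda(lambda s: s.upper())

def route(text: str):
    # Return either a Runnable to delegate to, or a plain value to pass through.
    return upper if text.startswith("!") else text

chain = RunnableLambda(route)
chain.invoke("!hi")  # "!HI"  (delegated to `upper`)
chain.invoke("hi")   # "hi"   (returned as-is)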

View File

@@ -80,7 +80,7 @@ class LLMChain(RunnableChain):
return (
self.prompt
| self.llm.bind(**self.llm_kwargs)
- | {self.output_key: self.output_parser}
+ | {self.output_key: self.prompt.output_parser or self.output_parser}
)
@property

View File

@@ -27,8 +27,10 @@ async def atest_simple() -> None:
verbose=True,
)
got = await qa_chain.acall("What is the answer?")
+ memory.clear()
assert got["chat_history"][1].content == fixed_resp
assert got["answer"] == fixed_resp
+ assert got == await qa_chain.ainvoke("What is the answer?")
@pytest.mark.asyncio
@@ -37,7 +39,7 @@ async def atest_fixed_message_response_when_docs_found() -> None:
answer = "I know the answer!"
llm = FakeListLLM(responses=[answer])
retriever = SequentialRetriever(
- sequential_responses=[[Document(page_content=answer)]]
+ sequential_responses=[[Document(page_content=answer)]], cycle=True
)
memory = ConversationBufferMemory(
k=1, output_key="answer", memory_key="chat_history", return_messages=True
@@ -52,8 +54,10 @@ async def atest_fixed_message_response_when_docs_found() -> None:
verbose=True,
)
got = await qa_chain.acall("What is the answer?")
+ memory.clear()
assert got["chat_history"][1].content == answer
assert got["answer"] == answer
+ assert got == await qa_chain.ainvoke("What is the answer?")
def test_fixed_message_response_when_no_docs_found() -> None:
@@ -74,8 +78,10 @@ def test_fixed_message_response_when_no_docs_found() -> None:
verbose=True,
)
got = qa_chain("What is the answer?")
+ memory.clear()
assert got["chat_history"][1].content == fixed_resp
assert got["answer"] == fixed_resp
+ assert got == qa_chain.invoke("What is the answer?")
def test_fixed_message_response_when_docs_found() -> None:
@@ -83,7 +89,7 @@ def test_fixed_message_response_when_docs_found() -> None:
answer = "I know the answer!"
llm = FakeListLLM(responses=[answer])
retriever = SequentialRetriever(
- sequential_responses=[[Document(page_content=answer)]]
+ sequential_responses=[[Document(page_content=answer)]], cycle=True
)
memory = ConversationBufferMemory(
k=1, output_key="answer", memory_key="chat_history", return_messages=True
@@ -98,5 +104,7 @@ def test_fixed_message_response_when_docs_found() -> None:
verbose=True,
)
got = qa_chain("What is the answer?")
+ memory.clear()
assert got["chat_history"][1].content == answer
assert got["answer"] == answer
+ assert got == qa_chain.invoke("What is the answer?")

View File

@@ -51,6 +51,7 @@ def test_valid_call(fake_llm_chain: LLMChain) -> None:
"""Test valid call of LLM chain."""
output = fake_llm_chain({"bar": "baz"})
assert output == {"bar": "baz", "text1": "foo"}
+ assert fake_llm_chain.invoke({"bar": "baz"}) == output
# Test with stop words.
output = fake_llm_chain({"bar": "baz", "stop": ["foo"]})
@@ -58,6 +59,18 @@ def test_valid_call(fake_llm_chain: LLMChain) -> None:
assert output == {"bar": "baz", "stop": ["foo"], "text1": "bar"}
+ async def test_valid_acall(fake_llm_chain: LLMChain) -> None:
+     """Test valid call of LLM chain."""
+     output = await fake_llm_chain.acall({"bar": "baz"})
+     assert output == {"bar": "baz", "text1": "foo"}
+     assert await fake_llm_chain.ainvoke({"bar": "baz"}) == output
+     # Test with stop words.
+     output = await fake_llm_chain.acall({"bar": "baz", "stop": ["foo"]})
+     # Response should be `bar` now.
+     assert output == {"bar": "baz", "stop": ["foo"], "text1": "bar"}
def test_predict_method(fake_llm_chain: LLMChain) -> None:
"""Test predict method works."""
output = fake_llm_chain.predict(bar="baz")
@@ -73,3 +86,4 @@ def test_predict_and_parse() -> None:
chain = LLMChain(prompt=prompt, llm=llm)
output = chain.predict_and_parse(foo="foo")
assert output == ["foo", "bar"]
+ assert output == chain.invoke({"foo": "foo"})["text"]

View File

@@ -8,16 +8,20 @@ class SequentialRetriever(BaseRetriever):
sequential_responses: List[List[Document]]
response_index: int = 0
+ cycle: bool = False
def _get_relevant_documents( # type: ignore[override]
self,
query: str,
) -> List[Document]:
if self.response_index >= len(self.sequential_responses):
-     return []
- else:
-     self.response_index += 1
-     return self.sequential_responses[self.response_index - 1]
+     if self.cycle:
+         self.response_index = 0
+     else:
+         return []
+ self.response_index += 1
+ return self.sequential_responses[self.response_index - 1]
async def _aget_relevant_documents( # type: ignore[override]
self,