mirror of https://github.com/hwchase17/langchain.git
synced 2026-04-18 08:03:52 +00:00

Compare commits
9 Commits: langchain-...bagatur/co

| SHA1 |
|---|
| 05838ebcce |
| 243a74735b |
| 8096fa2027 |
| 7f5faf1076 |
| 062ad47c65 |
| e343aec97b |
| 775bd6fb90 |
| 582461945f |
| ad1ab2b566 |
libs/core/langchain_core/beta/runnables/context.py

@@ -108,8 +108,6 @@ def _config_with_context(
                 raise ValueError(
                     f"Deadlock detected between context keys {key} and {dep}"
                 )
-        if len(getters) < 1:
-            raise ValueError(f"Expected at least one getter for context key {key}")
         if len(setters) != 1:
             raise ValueError(f"Expected exactly one setter for context key {key}")
         setter_idx = setters[0][1]
@@ -118,7 +116,8 @@ def _config_with_context(
                 f"Context setter for key {key} must be defined after all getters."
             )
 
-        context_funcs[getters[0][0].id] = partial(getter, events[key], values)
+        if getters:
+            context_funcs[getters[0][0].id] = partial(getter, events[key], values)
         context_funcs[setters[0][0].id] = partial(setter, events[key], values)
 
     return patch_config(config, configurable=context_funcs)
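The two removed lines dropped the requirement that every context key with a setter also have at least one getter, and the new `if getters:` guard makes the getter wiring optional. As a minimal sketch (not part of the diff, using the public `Context` API this module backs), a key that is set but never read is now accepted:

    from langchain_core.beta.runnables.context import Context

    # Before this change, running the chain raised
    # "Expected at least one getter for context key snapshot".
    chain = Context.setter("snapshot") | (lambda x: x.upper())
    print(chain.invoke("hi"))  # "HI"; the stored value is simply never consumed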
libs/langchain/langchain/chains/base.py

@@ -6,9 +6,10 @@ import logging
 import warnings
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Type, Union
+from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Type, Union
 
 import yaml
+from langchain_core.beta.runnables.context import Context
 from langchain_core.load.dump import dumpd
 from langchain_core.memory import BaseMemory
 from langchain_core.outputs import RunInfo
@@ -19,7 +20,9 @@ from langchain_core.pydantic_v1 import (
     root_validator,
     validator,
 )
-from langchain_core.runnables import RunnableConfig, RunnableSerializable
+from langchain_core.runnables import Runnable, RunnableConfig, RunnableSerializable
+from langchain_core.runnables.configurable import ConfigurableFieldSpec
+from langchain_core.runnables.passthrough import RunnablePassthrough
 
 from langchain.callbacks.base import BaseCallbackManager
 from langchain.callbacks.manager import (
@@ -666,3 +669,141 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
     ) -> List[Dict[str, str]]:
         """Call the chain on all inputs in the list."""
         return [self(inputs, callbacks=callbacks) for inputs in input_list]
+
+
+class RunnableChain(Chain):
+    @abstractmethod
+    def as_runnable(self) -> Runnable:
+        ...
+
+    def as_runnable_wrapped(self, return_only_outputs: bool = False) -> Runnable:
+        context = Context.create_scope("runnable-chain")
+
+        def prep_outputs(all: Dict[str, Any]) -> Dict[str, Any]:
+            # print("before outprep", all)
+            return self.prep_outputs(all["inputs"], all["outputs"], return_only_outputs)
+
+        return (
+            self.prep_inputs
+            # | RunnableLambda(lambda i: print("after prep", i) or i.copy())
+            | context.setter("inputs")
+            | self.as_runnable()
+            | {"outputs": RunnablePassthrough(), "inputs": context.getter("inputs")}
+            | prep_outputs
+        )
+
+    @property
+    def InputType(self) -> Type[Dict[str, Any]]:
+        return self.as_runnable().InputType
+
+    @property
+    def OutputType(self) -> Type[Dict[str, Any]]:
+        return self.as_runnable().OutputType
+
+    def get_input_schema(
+        self, config: Optional[RunnableConfig] = None
+    ) -> Type[BaseModel]:
+        return self.as_runnable().get_input_schema(config)
+
+    def get_output_schema(
+        self, config: Optional[RunnableConfig] = None
+    ) -> Type[BaseModel]:
+        return self.as_runnable().get_output_schema(config)
+
+    @property
+    def config_specs(self) -> List[ConfigurableFieldSpec]:
+        return self.as_runnable_wrapped().config_specs
+
+    def invoke(
+        self,
+        input: Dict[str, Any],
+        config: Optional[RunnableConfig] = None,
+        return_only_outputs: bool = False,
+        **kwargs: Any,
+    ) -> Dict[str, Any]:
+        return self.as_runnable_wrapped(return_only_outputs).invoke(
+            input, config, **kwargs
+        )
+
+    async def ainvoke(
+        self,
+        input: Dict[str, Any],
+        config: Optional[RunnableConfig] = None,
+        return_only_outputs: bool = False,
+        **kwargs: Any,
+    ) -> Dict[str, Any]:
+        return await self.as_runnable_wrapped(return_only_outputs).ainvoke(
+            input, config, **kwargs
+        )
+
+    def stream(
+        self,
+        input: Dict[str, Any],
+        config: Optional[RunnableConfig] = None,
+        return_only_outputs: bool = False,
+        **kwargs: Any,
+    ) -> Iterator[Dict[str, Any]]:
+        yield from self.as_runnable_wrapped(return_only_outputs).stream(
+            input, config, **kwargs
+        )
+
+    async def astream(
+        self,
+        input: Dict[str, Any],
+        config: Optional[RunnableConfig] = None,
+        return_only_outputs: bool = False,
+        **kwargs: Any,
+    ) -> AsyncIterator[Dict[str, Any]]:
+        async for item in self.as_runnable_wrapped(return_only_outputs).astream(
+            input, config, **kwargs
+        ):
+            yield item
+
+    def batch(
+        self,
+        inputs: List[Dict[str, Any]],
+        config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
+        *,
+        return_exceptions: bool = False,
+        return_only_outputs: bool = False,
+        **kwargs: Any,
+    ) -> List[Dict[str, Any]]:
+        return self.as_runnable_wrapped(return_only_outputs).batch(
+            inputs, config, **kwargs, return_exceptions=return_exceptions
+        )
+
+    async def abatch(
+        self,
+        inputs: List[Dict[str, Any]],
+        config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
+        *,
+        return_exceptions: bool = False,
+        return_only_outputs: bool = False,
+        **kwargs: Any,
+    ) -> List[Dict[str, Any]]:
+        return await self.as_runnable_wrapped(return_only_outputs).abatch(
+            inputs, config, **kwargs, return_exceptions=return_exceptions
+        )
+
+    def transform(
+        self,
+        input: Iterator[Dict[str, Any]],
+        config: Optional[RunnableConfig] = None,
+        return_only_outputs: bool = False,
+        **kwargs: Any,
+    ) -> Iterator[Dict[str, Any]]:
+        yield from self.as_runnable_wrapped(return_only_outputs).transform(
+            input, config, **kwargs
+        )
+
+    async def atransform(
+        self,
+        input: AsyncIterator[Dict[str, Any]],
+        config: Optional[RunnableConfig] = None,
+        return_only_outputs: bool = False,
+        **kwargs: Any,
+    ) -> AsyncIterator[Dict[str, Any]]:
+        async for chunk in self.as_runnable_wrapped(return_only_outputs).atransform(
+            input, config, **kwargs
+        ):
+            yield chunk
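The new RunnableChain base lets any Chain that can describe itself as an LCEL Runnable pick up invoke/stream/batch and their async variants for free: everything routes through as_runnable_wrapped, which stashes the prepared inputs in a context scope so prep_outputs can merge them back into the result. An illustrative sketch (not part of the diff), using the LLMChain change later in this changeset and langchain's FakeListLLM:

    from langchain.chains import LLMChain
    from langchain.llms.fake import FakeListLLM
    from langchain_core.prompts import PromptTemplate

    chain = LLMChain(
        llm=FakeListLLM(responses=["hi"]),
        prompt=PromptTemplate.from_template("{q}"),
    )
    # invoke() now runs prep_inputs -> as_runnable() -> prep_outputs as one LCEL
    # sequence; with return_only_outputs=False the inputs are merged back in:
    chain.invoke({"q": "hello"})  # -> {"q": "hello", "text": "hi"}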
libs/langchain/langchain/chains/combine_documents/base.py

@@ -1,21 +1,48 @@
 """Base interface for chains combining documents."""
 
 from abc import ABC, abstractmethod
+from operator import itemgetter
 from typing import Any, Dict, List, Optional, Tuple, Type
 
 from langchain_core.documents import Document
+from langchain_core.prompts import BasePromptTemplate, format_document
 from langchain_core.pydantic_v1 import BaseModel, Field, create_model
+from langchain_core.runnables import Runnable, RunnableLambda, RunnablePassthrough
 from langchain_core.runnables.config import RunnableConfig
 
 from langchain.callbacks.manager import (
     AsyncCallbackManagerForChainRun,
     CallbackManagerForChainRun,
 )
-from langchain.chains.base import Chain
+from langchain.chains.base import Chain, RunnableChain
 from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
 
 
-class BaseCombineDocumentsChain(Chain, ABC):
+def _format_docs_chain(
+    input_key: str,
+    document_prompt: BasePromptTemplate,
+    document_variable_name: str,
+    document_separator: str,
+) -> Runnable:
+    def format_document_(doc: Document) -> str:
+        return format_document(doc, document_prompt)
+
+    format_docs = (
+        itemgetter(input_key)
+        | RunnableLambda(format_document_).map()
+        | document_separator.join
+    )
+
+    def pop_raw_docs(input_: dict) -> dict:
+        return {k: v for k, v in input_.items() if k != input_key}
+
+    return (
+        RunnablePassthrough.assign(**{document_variable_name: format_docs})
+        | pop_raw_docs
+    )
+
+
+class BaseCombineDocumentsChain(RunnableChain, ABC):
     """Base interface for chains combining documents.
 
     Subclasses of this chain deal with combining documents in a variety of
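The `_format_docs_chain` helper factors document formatting out of the concrete chains: it formats each Document under `input_key` with `document_prompt`, joins the results with `document_separator` under `document_variable_name`, and drops the raw docs from the payload. A hypothetical invocation (names match the diff; the expected output is read off the helper's wiring, not taken from the commits):

    from langchain_core.documents import Document
    from langchain_core.prompts import PromptTemplate

    chain = _format_docs_chain(
        input_key="input_documents",
        document_prompt=PromptTemplate.from_template("{page_content}"),
        document_variable_name="context",
        document_separator="\n\n",
    )
    out = chain.invoke(
        {"input_documents": [Document(page_content="a"), Document(page_content="b")],
         "question": "q"}
    )
    # out == {"context": "a\n\nb", "question": "q"}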
libs/langchain/langchain/chains/combine_documents/refine.py

@@ -2,12 +2,13 @@
 
 from __future__ import annotations
 
-from typing import Any, Dict, List, Tuple
+from typing import Any, Dict, List, Tuple, Union
 
 from langchain_core.documents import Document
 from langchain_core.prompts import BasePromptTemplate, format_document
 from langchain_core.prompts.prompt import PromptTemplate
 from langchain_core.pydantic_v1 import Extra, Field, root_validator
+from langchain_core.runnables import Runnable, RunnablePassthrough
 
 from langchain.callbacks.manager import Callbacks
 from langchain.chains.combine_documents.base import (
@@ -133,6 +134,44 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):
             )
         return values
 
+    def as_runnable(self) -> Runnable:
+        def format_inputs(inputs: dict) -> dict:
+            doc = inputs[self.input_key][len(inputs.get("intermediate_steps", [])) - 1]
+            inputs[self.document_variable_name] = format_document(
+                doc, self.document_prompt
+            )
+            return {
+                k: v
+                for k, v in inputs.items()
+                if k not in ("intermediate_steps", self.input_key)
+            }
+
+        first_chain = format_inputs | self.initial_llm_chain.as_runnable()
+        refine_chain = format_inputs | self.refine_llm_chain.as_runnable()
+
+        def loop(inputs: dict) -> Union[Runnable, dict]:
+            if len(inputs.get("intermediate_steps", [])) < len(inputs[self.input_key]):
+                return (
+                    RunnablePassthrough.assign(
+                        intermediate_steps=lambda x: x.get("intermediate_steps", [])
+                        + [x[self.initial_response_name]]
+                    )
+                    | RunnablePassthrough.assign(
+                        **{self.initial_response_name: refine_chain}
+                    )
+                    | loop
+                )
+            else:
+                res = {self.output_key: inputs["intermediate_steps"][-1]}
+                if self.return_intermediate_steps:
+                    res["intermediate_steps"] = inputs["intermediate_steps"]
+                return res
+
+        return (
+            RunnablePassthrough.assign(**{self.initial_response_name: first_chain})
+            | loop
+        )
+
     def combine_docs(
         self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
     ) -> Tuple[str, dict]:
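The `loop` function leans on an LCEL behavior worth calling out: when a function composed into a sequence returns a Runnable rather than a plain value, RunnableLambda invokes that Runnable on the same input, so `loop` keeps re-entering itself until every document has been refined. A minimal standalone illustration of that recursion mechanism (not from the diff):

    from langchain_core.runnables import RunnableLambda

    def countdown(n: int):
        if n <= 0:
            return "done"
        # Returning a Runnable makes LCEL invoke it on the same input,
        # bounded by the "recursion_limit" config value (default 25).
        return RunnableLambda(lambda x: x - 1) | countdown

    assert RunnableLambda(countdown).invoke(3) == "done"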
libs/langchain/langchain/chains/combine_documents/stuff.py

@@ -1,15 +1,19 @@
 """Chain that combines documents by stuffing into context."""
 
 from typing import Any, Dict, List, Optional, Tuple
 
 from langchain_core.documents import Document
 from langchain_core.prompts import BasePromptTemplate, format_document
 from langchain_core.prompts.prompt import PromptTemplate
 from langchain_core.pydantic_v1 import Extra, Field, root_validator
+from langchain_core.runnables import (
+    Runnable,
+    RunnableParallel,
+)
 
 from langchain.callbacks.manager import Callbacks
 from langchain.chains.combine_documents.base import (
     BaseCombineDocumentsChain,
+    _format_docs_chain,
 )
 from langchain.chains.llm import LLMChain
@@ -100,6 +104,18 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
             )
         return values
 
+    def as_runnable(self) -> Runnable:
+        chain = (
+            _format_docs_chain(
+                self.input_key,
+                self.document_prompt,
+                self.document_variable_name,
+                self.document_separator,
+            )
+            | self.llm_chain.as_runnable()
+        )
+        return RunnableParallel({self.output_key: chain})
+
     @property
     def input_keys(self) -> List[str]:
         extra_keys = [
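With the shared helper in place, StuffDocumentsChain's runnable form is simply "format the docs, run the LLM chain, namespace the result under output_key". A hedged sketch of how it would be driven (FakeListLLM and the default keys are assumptions for illustration, not code from the diff):

    from langchain.chains import LLMChain
    from langchain.chains.combine_documents.stuff import StuffDocumentsChain
    from langchain.llms.fake import FakeListLLM
    from langchain_core.documents import Document
    from langchain_core.prompts import PromptTemplate

    chain = StuffDocumentsChain(
        llm_chain=LLMChain(
            llm=FakeListLLM(responses=["a summary"]),
            prompt=PromptTemplate.from_template("Summarize:\n{context}"),
        ),
        document_variable_name="context",
    )
    runnable = chain.as_runnable()
    # Shape: {"input_documents": [...]} -> {"output_text": <llm_chain result>}
    runnable.invoke({"input_documents": [Document(page_content="a")]})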
libs/langchain/langchain/chains/conversational_retrieval/base.py

@@ -4,16 +4,24 @@ from __future__ import annotations
 import inspect
 import warnings
 from abc import abstractmethod
+from operator import itemgetter
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
 
+from langchain_core.beta.runnables.context import Context
 from langchain_core.documents import Document
 from langchain_core.language_models import BaseLanguageModel
 from langchain_core.messages import BaseMessage
 from langchain_core.prompts import BasePromptTemplate
 from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator
 from langchain_core.retrievers import BaseRetriever
-from langchain_core.runnables.config import RunnableConfig
+from langchain_core.runnables import (
+    Runnable,
+    RunnableConfig,
+    RunnableLambda,
+    RunnableMap,
+)
+from langchain_core.runnables.passthrough import RunnablePassthrough
 from langchain_core.vectorstores import VectorStore
 
 from langchain.callbacks.manager import (
@@ -21,7 +29,7 @@ from langchain.callbacks.manager import (
     CallbackManagerForChainRun,
     Callbacks,
 )
-from langchain.chains.base import Chain
+from langchain.chains.base import RunnableChain
 from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
 from langchain.chains.combine_documents.stuff import StuffDocumentsChain
 from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT
@@ -63,7 +71,7 @@ class InputType(BaseModel):
     """The chat history to use for retrieval."""
 
 
-class BaseConversationalRetrievalChain(Chain):
+class BaseConversationalRetrievalChain(RunnableChain):
     """Chain for chatting with an index."""
 
     combine_docs_chain: BaseCombineDocumentsChain
@@ -289,6 +297,56 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain):
     """If set, enforces that the documents returned are less than this limit.
     This is only enforced if `combine_docs_chain` is of type StuffDocumentsChain."""
 
+    def as_runnable(self) -> Runnable:
+        context = Context.create_scope("conversational_retrieval")
+        get_chat_history = self.get_chat_history or _get_chat_history
+
+        def get_new_question(inputs: Dict[str, Any]):
+            if inputs["chat_history"]:
+                return self.question_generator
+            else:
+                return inputs["question"]
+
+        def get_answer(inputs: Dict[str, Any]):
+            if (
+                self.response_if_no_docs_found is not None
+                and len(inputs["input_documents"]) == 0
+            ):
+                return self.response_if_no_docs_found
+            else:
+                return self.combine_docs_chain | itemgetter(
+                    self.combine_docs_chain.output_key
+                )
+
+        output_map = {self.output_key: RunnablePassthrough()}
+        if self.return_source_documents:
+            output_map["source_documents"] = context.getter("input_documents")
+        if self.return_generated_question:
+            output_map["generated_question"] = context.getter("new_question")
+
+        return (
+            RunnableMap(
+                question=itemgetter("question") | context.setter("question"),
+                chat_history=itemgetter("chat_history")
+                | RunnableLambda(get_chat_history)
+                | context.setter("chat_history"),
+            )
+            | get_new_question
+            | context.setter("new_question")
+            | self.retriever
+            | self._reduce_tokens_below_limit
+            | context.setter("input_documents")
+            | {
+                "input_documents": context.getter("input_documents"),
+                "chat_history": context.getter("chat_history"),
+                "question": context.getter(
+                    "new_question" if self.rephrase_question else "question"
+                ),
+            }
+            | get_answer
+            | output_map
+        )
+
     def _reduce_tokens_below_limit(self, docs: List[Document]) -> List[Document]:
         num_docs = len(docs)
 
@@ -400,6 +458,9 @@ class ChatVectorDBChain(BaseConversationalRetrievalChain):
     top_k_docs_for_context: int = 4
     search_kwargs: dict = Field(default_factory=dict)
 
+    def as_runnable(self) -> Runnable:
+        return self
+
     @property
     def _chain_type(self) -> str:
         return "chat-vector-db"
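The retrieval pipeline above uses a context scope as a side channel: the original question, chat history, and retrieved documents are stashed as they flow past, then read back at the end to build the output map, so intermediate steps never have to carry them along. A minimal illustration of the pattern (not from the diff; `scope` and the keys are hypothetical):

    from operator import itemgetter

    from langchain_core.beta.runnables.context import Context
    from langchain_core.runnables import RunnablePassthrough

    scope = Context.create_scope("example")
    chain = (
        itemgetter("question")
        | scope.setter("question")
        | (lambda q: q.upper())  # downstream sees only the transformed value
        | {"answer": RunnablePassthrough(), "question": scope.getter("question")}
    )
    chain.invoke({"question": "hi"})  # -> {"answer": "HI", "question": "hi"}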
libs/langchain/langchain/chains/llm.py

@@ -31,10 +31,10 @@ from langchain.callbacks.manager import (
     CallbackManagerForChainRun,
     Callbacks,
 )
-from langchain.chains.base import Chain
+from langchain.chains.base import RunnableChain
 
 
-class LLMChain(Chain):
+class LLMChain(RunnableChain):
     """Chain to run queries against LLMs.
 
     Example:
@@ -76,6 +76,13 @@ class LLMChain(Chain):
         extra = Extra.forbid
         arbitrary_types_allowed = True
 
+    def as_runnable(self) -> Runnable:
+        return (
+            self.prompt
+            | self.llm.bind(**self.llm_kwargs)
+            | {self.output_key: self.prompt.output_parser or self.output_parser}
+        )
+
     @property
     def input_keys(self) -> List[str]:
         """Will be whatever keys the prompt expects.
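LLMChain.as_runnable is the canonical prompt | llm | parser pipeline, with the parsed result namespaced under output_key ("text" by default). A rough standalone equivalent, not code from the diff (StrOutputParser stands in for the chain's default output parser):

    from langchain.llms.fake import FakeListLLM
    from langchain_core.output_parsers import StrOutputParser
    from langchain_core.prompts import PromptTemplate

    llm = FakeListLLM(responses=["hello"])
    chain = PromptTemplate.from_template("Say {word}") | llm | {"text": StrOutputParser()}
    assert chain.invoke({"word": "hi"}) == {"text": "hello"}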
tests: ConversationalRetrievalChain unit tests

@@ -27,8 +27,10 @@ async def atest_simple() -> None:
         verbose=True,
     )
     got = await qa_chain.acall("What is the answer?")
+    memory.clear()
     assert got["chat_history"][1].content == fixed_resp
     assert got["answer"] == fixed_resp
+    assert got == await qa_chain.ainvoke("What is the answer?")
 
 
 @pytest.mark.asyncio
@@ -37,7 +39,7 @@ async def atest_fixed_message_response_when_docs_found() -> None:
     answer = "I know the answer!"
     llm = FakeListLLM(responses=[answer])
     retriever = SequentialRetriever(
-        sequential_responses=[[Document(page_content=answer)]]
+        sequential_responses=[[Document(page_content=answer)]], cycle=True
     )
     memory = ConversationBufferMemory(
         k=1, output_key="answer", memory_key="chat_history", return_messages=True
@@ -52,8 +54,10 @@ async def atest_fixed_message_response_when_docs_found() -> None:
         verbose=True,
     )
     got = await qa_chain.acall("What is the answer?")
+    memory.clear()
     assert got["chat_history"][1].content == answer
     assert got["answer"] == answer
+    assert got == await qa_chain.ainvoke("What is the answer?")
 
 
 def test_fixed_message_response_when_no_docs_found() -> None:
@@ -74,8 +78,10 @@ def test_fixed_message_response_when_no_docs_found() -> None:
         verbose=True,
     )
     got = qa_chain("What is the answer?")
+    memory.clear()
     assert got["chat_history"][1].content == fixed_resp
     assert got["answer"] == fixed_resp
+    assert got == qa_chain.invoke("What is the answer?")
 
 
 def test_fixed_message_response_when_docs_found() -> None:
@@ -83,7 +89,7 @@ def test_fixed_message_response_when_docs_found() -> None:
     answer = "I know the answer!"
     llm = FakeListLLM(responses=[answer])
     retriever = SequentialRetriever(
-        sequential_responses=[[Document(page_content=answer)]]
+        sequential_responses=[[Document(page_content=answer)]], cycle=True
     )
     memory = ConversationBufferMemory(
         k=1, output_key="answer", memory_key="chat_history", return_messages=True
@@ -98,5 +104,7 @@ def test_fixed_message_response_when_docs_found() -> None:
         verbose=True,
     )
     got = qa_chain("What is the answer?")
+    memory.clear()
     assert got["chat_history"][1].content == answer
     assert got["answer"] == answer
+    assert got == qa_chain.invoke("What is the answer?")
tests: LLMChain unit tests

@@ -51,6 +51,7 @@ def test_valid_call(fake_llm_chain: LLMChain) -> None:
     """Test valid call of LLM chain."""
     output = fake_llm_chain({"bar": "baz"})
     assert output == {"bar": "baz", "text1": "foo"}
+    assert fake_llm_chain.invoke({"bar": "baz"}) == output
 
     # Test with stop words.
     output = fake_llm_chain({"bar": "baz", "stop": ["foo"]})
@@ -58,6 +59,18 @@ def test_valid_call(fake_llm_chain: LLMChain) -> None:
     assert output == {"bar": "baz", "stop": ["foo"], "text1": "bar"}
 
 
+async def test_valid_acall(fake_llm_chain: LLMChain) -> None:
+    """Test valid call of LLM chain."""
+    output = await fake_llm_chain.acall({"bar": "baz"})
+    assert output == {"bar": "baz", "text1": "foo"}
+    assert await fake_llm_chain.ainvoke({"bar": "baz"}) == output
+
+    # Test with stop words.
+    output = await fake_llm_chain.acall({"bar": "baz", "stop": ["foo"]})
+    # Response should be `bar` now.
+    assert output == {"bar": "baz", "stop": ["foo"], "text1": "bar"}
+
+
 def test_predict_method(fake_llm_chain: LLMChain) -> None:
     """Test predict method works."""
     output = fake_llm_chain.predict(bar="baz")
@@ -73,3 +86,4 @@ def test_predict_and_parse() -> None:
     chain = LLMChain(prompt=prompt, llm=llm)
     output = chain.predict_and_parse(foo="foo")
     assert output == ["foo", "bar"]
+    assert output == chain.invoke({"foo": "foo"})["text"]
tests: SequentialRetriever fake retriever

@@ -8,16 +8,20 @@ class SequentialRetriever(BaseRetriever):
 
     sequential_responses: List[List[Document]]
     response_index: int = 0
+    cycle: bool = False
 
     def _get_relevant_documents(  # type: ignore[override]
         self,
         query: str,
     ) -> List[Document]:
         if self.response_index >= len(self.sequential_responses):
-            return []
-        else:
-            self.response_index += 1
-            return self.sequential_responses[self.response_index - 1]
+            if self.cycle:
+                self.response_index = 0
+            else:
+                return []
+
+        self.response_index += 1
+        return self.sequential_responses[self.response_index - 1]
 
     async def _aget_relevant_documents(  # type: ignore[override]
         self,
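The new cycle flag makes the scripted retriever replay its responses from the start once exhausted, which is what lets the tests above run the same chain twice (once via __call__/acall, once via the new invoke/ainvoke path) against identical documents. Illustrative only:

    from langchain_core.documents import Document

    r = SequentialRetriever(
        sequential_responses=[[Document(page_content="a")]], cycle=True
    )
    r.get_relevant_documents("q")  # [Document(page_content="a")]
    r.get_relevant_documents("q")  # same again; with cycle=False this would be []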