docstrings chains (#7892)

Added/updated docstrings.

@baskaryan
Leonid Ganeline 2023-07-18 18:25:27 -07:00 committed by GitHub
parent f2ef3ff54a
commit 4a810756f8
29 changed files with 170 additions and 40 deletions

View File

@@ -11,6 +11,7 @@ from langchain.schema import (
 def import_context() -> Any:
+    """Import the `getcontext` package."""
     try:
         import getcontext  # noqa: F401
         from getcontext.generated.models import (
@@ -30,7 +31,9 @@ def import_context() -> Any:
 class ContextCallbackHandler(BaseCallbackHandler):
-    """Callback Handler that records transcripts to Context (https://getcontext.ai).
+    """Callback Handler that records transcripts to the Context service.
+    (https://getcontext.ai).
 
     Keyword Args:
         token (optional): The token with which to authenticate requests to Context.

View File

@@ -1,4 +1,17 @@
-"""Chains are easily reusable components which can be linked together."""
+"""Chains are easily reusable components which can be linked together.
+
+Chains should be used to encode a sequence of calls to components like
+models, document retrievers, other chains, etc., and provide a simple interface
+to this sequence.
+
+The Chain interface makes it easy to create apps that are:
+- Stateful: add Memory to any Chain to give it state,
+- Observable: pass Callbacks to a Chain to execute additional functionality,
+  like logging, outside the main sequence of component calls,
+- Composable: the Chain API is flexible enough that it is easy to combine
+  Chains with other components, including other Chains.
+"""
+
 from langchain.chains.api.base import APIChain
 from langchain.chains.api.openapi.chain import OpenAPIEndpointChain
 from langchain.chains.combine_documents.base import AnalyzeDocumentChain
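For context on the composability this new module docstring describes, here is a minimal sketch of a stateful chain. Everything in it is illustrative rather than part of the commit: the prompt, the model choice, and an OpenAI API key in the environment are all assumptions.

from langchain.chains import LLMChain
from langchain.llms import OpenAI
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate

# One chain: prompt -> model. Memory makes it stateful across calls,
# which is the "Stateful" property the docstring describes.
prompt = PromptTemplate(
    input_variables=["history", "question"],
    template="{history}\nQ: {question}\nA:",
)
chain = LLMChain(
    llm=OpenAI(temperature=0),
    prompt=prompt,
    memory=ConversationBufferMemory(memory_key="history"),
)
print(chain.run("What is LangChain?"))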

View File

@@ -72,13 +72,13 @@ class Chain(Serializable, ABC):
     """Whether or not run in verbose mode. In verbose mode, some intermediate logs
     will be printed to the console. Defaults to `langchain.verbose` value."""
     tags: Optional[List[str]] = None
-    """Optional list of tags associated with the chain. Defaults to None
+    """Optional list of tags associated with the chain. Defaults to None.
     These tags will be associated with each call to this chain,
     and passed as arguments to the handlers defined in `callbacks`.
     You can use these to eg identify a specific instance of a chain with its use case.
     """
     metadata: Optional[Dict[str, Any]] = None
-    """Optional metadata associated with the chain. Defaults to None
+    """Optional metadata associated with the chain. Defaults to None.
     This metadata will be associated with each call to this chain,
     and passed as arguments to the handlers defined in `callbacks`.
     You can use these to eg identify a specific instance of a chain with its use case.
@@ -118,12 +118,12 @@ class Chain(Serializable, ABC):
     @property
     @abstractmethod
     def input_keys(self) -> List[str]:
-        """Return the keys expected to be in the chain input."""
+        """Keys expected to be in the chain input."""
 
     @property
     @abstractmethod
     def output_keys(self) -> List[str]:
-        """Return the keys expected to be in the chain output."""
+        """Keys expected to be in the chain output."""
 
     def _validate_inputs(self, inputs: Dict[str, Any]) -> None:
         """Check that all inputs are present."""
@@ -391,7 +391,7 @@ class Chain(Serializable, ABC):
         metadata: Optional[Dict[str, Any]] = None,
         **kwargs: Any,
     ) -> str:
-        """Convenience method for executing chain when there's a single string output.
+        """Execute chain when there's a single string output.
 
         The main difference between this method and `Chain.__call__` is that this method
         can only be used for chains that return a single string output. If a Chain
@@ -465,7 +465,7 @@ class Chain(Serializable, ABC):
         metadata: Optional[Dict[str, Any]] = None,
         **kwargs: Any,
     ) -> str:
-        """Convenience method for executing chain when there's a single string output.
+        """Execute chain when there's a single string output.
 
         The main difference between this method and `Chain.__call__` is that this method
         can only be used for chains that return a single string output. If a Chain
@@ -532,7 +532,7 @@ class Chain(Serializable, ABC):
     )
 
     def dict(self, **kwargs: Any) -> Dict:
-        """Return dictionary representation of chain.
+        """Dictionary representation of chain.
 
         Expects `Chain._chain_type` property to be implemented and for memory to be
         null.
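To make the reworded `run` docstring concrete, a small sketch of `run` versus `Chain.__call__` (illustrative prompt; assumes an OpenAI API key):

from langchain.chains import LLMChain
from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate

chain = LLMChain(
    llm=OpenAI(temperature=0),
    prompt=PromptTemplate.from_template("Say {word} three times."),
)

text = chain.run("hello")  # single string output only
result = chain("hello")    # full dict, e.g. {"word": "hello", "text": "..."}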

View File

@@ -22,7 +22,7 @@ class AsyncCombineDocsProtocol(Protocol):
     """Interface for the combine_docs method."""
 
     async def __call__(self, docs: List[Document], **kwargs: Any) -> str:
-        """Async nterface for the combine_docs method."""
+        """Async interface for the combine_docs method."""
 
 def _split_list_of_docs(
@@ -78,7 +78,7 @@ async def _acollapse_docs(
 class ReduceDocumentsChain(BaseCombineDocumentsChain):
-    """Combining documents by recursively reducing them.
+    """Combine documents by recursively reducing them.
 
     This involves
@@ -206,7 +206,7 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain):
         callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> Tuple[str, dict]:
-        """Combine multiple documents recursively.
+        """Async combine multiple documents recursively.
 
         Args:
             docs: List of documents to combine, assumed that each one is less than
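A rough sketch of how ReduceDocumentsChain is typically assembled, matching the recursive-reduction behavior documented above. The prompt and the `document_variable_name` value are illustrative assumptions, not part of the commit:

from langchain.chains import LLMChain
from langchain.chains.combine_documents.reduce import ReduceDocumentsChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate

llm_chain = LLMChain(
    llm=OpenAI(temperature=0),
    prompt=PromptTemplate.from_template("Summarize:\n{context}"),
)
# Each reduction step stuffs one batch of documents into a single prompt;
# ReduceDocumentsChain repeats this until the result fits in one call.
stuff_chain = StuffDocumentsChain(
    llm_chain=llm_chain, document_variable_name="context"
)
reduce_chain = ReduceDocumentsChain(combine_documents_chain=stuff_chain)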

View File

@@ -1,4 +1,4 @@
-"""Combining documents by doing a first pass and then refining on more documents."""
+"""Combine documents by doing a first pass and then refining on more documents."""
 
 from __future__ import annotations
@@ -161,7 +161,8 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):
     async def acombine_docs(
         self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
     ) -> Tuple[str, dict]:
-        """Combine by mapping first chain over all, then stuffing into final chain.
+        """Async combine by mapping a first chain over all, then stuffing
+        into a final chain.
 
         Args:
             docs: List of documents to combine

View File

@@ -167,7 +167,7 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
     async def acombine_docs(
         self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
     ) -> Tuple[str, dict]:
-        """Stuff all documents into one prompt and pass to LLM.
+        """Async stuff all documents into one prompt and pass to LLM.
 
         Args:
             docs: List of documents to join together into one variable

View File

@@ -80,12 +80,12 @@ class ConstitutionalChain(Chain):
     @property
     def input_keys(self) -> List[str]:
-        """Defines the input keys."""
+        """Input keys."""
         return self.chain.input_keys
 
     @property
     def output_keys(self) -> List[str]:
-        """Defines the output keys."""
+        """Output keys."""
         if self.return_intermediate_steps:
             return ["output", "critiques_and_revisions", "initial_output"]
         return ["output"]

View File

@@ -23,6 +23,8 @@ from langchain.schema.language_model import BaseLanguageModel
 class _ResponseChain(LLMChain):
+    """Base class for chains that generate responses."""
+
     prompt: BasePromptTemplate = PROMPT
 
     @property
@@ -46,6 +48,8 @@ class _ResponseChain(LLMChain):
 class _OpenAIResponseChain(_ResponseChain):
+    """Chain that generates responses from user input and context."""
+
     llm: OpenAI = Field(
         default_factory=lambda: OpenAI(
             max_tokens=32, model_kwargs={"logprobs": 1}, temperature=0
@@ -66,10 +70,14 @@ class _OpenAIResponseChain(_ResponseChain):
 class QuestionGeneratorChain(LLMChain):
+    """Chain that generates questions from uncertain spans."""
+
     prompt: BasePromptTemplate = QUESTION_GENERATOR_PROMPT
+    """Prompt template for the chain."""
 
     @property
     def input_keys(self) -> List[str]:
+        """Input keys for the chain."""
         return ["user_input", "context", "response"]
@@ -95,22 +103,36 @@ def _low_confidence_spans(
 class FlareChain(Chain):
+    """Chain that combines a retriever, a question generator,
+    and a response generator."""
+
     question_generator_chain: QuestionGeneratorChain
+    """Chain that generates questions from uncertain spans."""
     response_chain: _ResponseChain = Field(default_factory=_OpenAIResponseChain)
+    """Chain that generates responses from user input and context."""
     output_parser: FinishedOutputParser = Field(default_factory=FinishedOutputParser)
+    """Parser that determines whether the chain is finished."""
     retriever: BaseRetriever
+    """Retriever that retrieves relevant documents from a user input."""
     min_prob: float = 0.2
+    """Minimum probability for a token to be considered low confidence."""
     min_token_gap: int = 5
+    """Minimum number of tokens between two low confidence spans."""
     num_pad_tokens: int = 2
+    """Number of tokens to pad around a low confidence span."""
     max_iter: int = 10
+    """Maximum number of iterations."""
     start_with_retrieval: bool = True
+    """Whether to start with retrieval."""
 
     @property
     def input_keys(self) -> List[str]:
+        """Input keys for the chain."""
         return ["user_input"]
 
     @property
     def output_keys(self) -> List[str]:
+        """Output keys for the chain."""
         return ["response"]
 
     def _do_generation(
@@ -213,6 +235,16 @@ class FlareChain(Chain):
     def from_llm(
         cls, llm: BaseLanguageModel, max_generation_len: int = 32, **kwargs: Any
     ) -> FlareChain:
+        """Creates a FlareChain from a language model.
+
+        Args:
+            llm: Language model to use.
+            max_generation_len: Maximum length of the generated response.
+            **kwargs: Additional arguments to pass to the constructor.
+
+        Returns:
+            FlareChain class with the given language model.
+        """
         question_gen_chain = QuestionGeneratorChain(llm=llm)
         response_llm = OpenAI(
             max_tokens=max_generation_len, model_kwargs={"logprobs": 1}, temperature=0
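A hedged usage sketch for the `from_llm` constructor documented above. The retriever contents, the thresholds, and the use of FAISS/OpenAI are illustrative assumptions:

from langchain.chains import FlareChain
from langchain.embeddings import OpenAIEmbeddings
from langchain.llms import OpenAI
from langchain.vectorstores import FAISS

retriever = FAISS.from_texts(
    ["FLARE retrieves when the model is uncertain."],
    OpenAIEmbeddings(),
).as_retriever()

# Extra fields such as `retriever` and `min_prob` flow through **kwargs
# to the FlareChain constructor.
flare = FlareChain.from_llm(
    OpenAI(temperature=0),
    max_generation_len=64,
    retriever=retriever,
    min_prob=0.3,
)
answer = flare.run("What does FLARE do?")  # input key: "user_input"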

View File

@@ -5,7 +5,10 @@ from langchain.schema import BaseOutputParser
 class FinishedOutputParser(BaseOutputParser[Tuple[str, bool]]):
+    """Output parser that checks if the output is finished."""
+
     finished_value: str = "FINISHED"
+    """Value that indicates the output is finished."""
 
     def parse(self, text: str) -> Tuple[str, bool]:
         cleaned = text.strip()
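Illustrative use of the parser (the sample text and the exact cleanup of surrounding whitespace are assumptions; the sentinel defaults to "FINISHED"):

from langchain.chains.flare.prompts import FinishedOutputParser

parser = FinishedOutputParser()
response, finished = parser.parse("Paris is the capital of France. FINISHED")
# `response` is the answer text with the sentinel removed;
# `finished` is True, which tells FlareChain to stop iterating.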

View File

@@ -25,7 +25,7 @@ class GraphQAChain(Chain):
     @property
     def input_keys(self) -> List[str]:
-        """Return the input keys.
+        """Input keys.
 
         :meta private:
         """
@@ -33,7 +33,7 @@ class GraphQAChain(Chain):
     @property
     def output_keys(self) -> List[str]:
-        """Return the output keys.
+        """Output keys.
 
         :meta private:
         """

View File

@ -18,8 +18,8 @@ INTERMEDIATE_STEPS_KEY = "intermediate_steps"
def extract_cypher(text: str) -> str: def extract_cypher(text: str) -> str:
""" """Extract Cypher code from a text.
Extract Cypher code from a text.
Args: Args:
text: Text to extract Cypher code from. text: Text to extract Cypher code from.
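A quick illustration of `extract_cypher`, assuming its usual behavior of pulling Cypher out of fenced code blocks and otherwise returning the text unchanged:

from langchain.chains.graph_qa.cypher import extract_cypher

llm_output = "Here is the query:\n```\nMATCH (m:Movie) RETURN m.title\n```"
print(extract_cypher(llm_output))            # fenced Cypher statement only
print(extract_cypher("MATCH (n) RETURN n"))  # no fences: returned as-is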

View File

@@ -28,7 +28,7 @@ class HugeGraphQAChain(Chain):
     @property
     def input_keys(self) -> List[str]:
-        """Return the input keys.
+        """Input keys.
 
         :meta private:
         """
@@ -36,7 +36,7 @@ class HugeGraphQAChain(Chain):
     @property
     def output_keys(self) -> List[str]:
-        """Return the output keys.
+        """Output keys.
 
         :meta private:
         """

View File

@@ -1,4 +1,4 @@
-"""Chain that interprets a prompt and executes bash code to perform bash operations."""
+"""Chain that interprets a prompt and executes bash operations."""
 from __future__ import annotations
 
 import logging
@@ -19,7 +19,7 @@ logger = logging.getLogger(__name__)
 class LLMBashChain(Chain):
-    """Chain that interprets a prompt and executes bash code to perform bash operations.
+    """Chain that interprets a prompt and executes bash operations.
 
     Example:
         .. code-block:: python

View File

@@ -16,7 +16,7 @@ DEFAULT_HEADERS = {
 class LLMRequestsChain(Chain):
-    """Chain that hits a URL and then uses an LLM to parse results."""
+    """Chain that requests a URL and then uses an LLM to parse results."""
 
     llm_chain: LLMChain
     requests_wrapper: TextRequestsWrapper = Field(

View File

@@ -1,4 +1,4 @@
-"""Chain that interprets a prompt and executes python code to do math."""
+"""Chain that interprets a prompt and executes python code to do symbolic math."""
 from __future__ import annotations
 
 import re
@@ -18,7 +18,7 @@ from langchain.prompts.base import BasePromptTemplate
 class LLMSymbolicMathChain(Chain):
-    """Chain that interprets a prompt and executes python code to do math.
+    """Chain that interprets a prompt and executes python code to do symbolic math.
 
     Example:
         .. code-block:: python

View File

@@ -13,7 +13,7 @@ from langchain.schema.messages import HumanMessage, SystemMessage
 class FactWithEvidence(BaseModel):
-    """Class representing single statement.
+    """Class representing a single statement.
 
     Each fact has a body and a list of sources.
     If there are multiple facts make sure to break them apart

View File

@@ -188,9 +188,14 @@ def openapi_spec_to_openai_fn(
 class SimpleRequestChain(Chain):
+    """Chain for making a simple request to an API endpoint."""
+
     request_method: Callable
+    """Method to use for making the request."""
     output_key: str = "response"
+    """Key to use for the output of the request."""
     input_key: str = "function"
+    """Key to use for the input of the request."""
 
     @property
     def input_keys(self) -> List[str]:

View File

@@ -16,7 +16,7 @@ from langchain.schema.messages import HumanMessage, SystemMessage
 class AnswerWithSources(BaseModel):
-    """An answer to the question being asked, with sources."""
+    """An answer to the question, with sources."""
 
     answer: str = Field(..., description="Answer to the question that was asked")
     sources: List[str] = Field(
@@ -30,7 +30,8 @@ def create_qa_with_structure_chain(
     output_parser: str = "base",
     prompt: Optional[Union[PromptTemplate, ChatPromptTemplate]] = None,
 ) -> LLMChain:
-    """Create a question answering chain that returns an answer with sources.
+    """Create a question answering chain that returns an answer with sources
+    based on schema.
 
     Args:
         llm: Language model to use for the chain.
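A sketch of the schema-based variant this docstring now names. The pydantic model and the chat model choice are illustrative; a function-calling-capable OpenAI chat model is assumed:

from typing import List

from pydantic import BaseModel, Field

from langchain.chains.openai_functions.qa_with_structure import (
    create_qa_with_structure_chain,
)
from langchain.chat_models import ChatOpenAI

class ShortAnswer(BaseModel):
    """Illustrative schema the answer must conform to."""
    answer: str = Field(..., description="Answer to the question")
    sources: List[str] = Field(..., description="Sources used for the answer")

chain = create_qa_with_structure_chain(ChatOpenAI(temperature=0), ShortAnswer)
# The default prompt expects `question` and `context` inputs.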

View File

@@ -34,7 +34,8 @@ def create_tagging_chain(
     prompt: Optional[ChatPromptTemplate] = None,
     **kwargs: Any
 ) -> Chain:
-    """Creates a chain that extracts information from a passage.
+    """Creates a chain that extracts information from a passage
+    based on a schema.
 
     Args:
         schema: The schema of the entities to extract.
@@ -63,7 +64,8 @@ def create_tagging_chain_pydantic(
     prompt: Optional[ChatPromptTemplate] = None,
     **kwargs: Any
 ) -> Chain:
-    """Creates a chain that extracts information from a passage.
+    """Creates a chain that extracts information from a passage
+    based on a pydantic schema.
 
     Args:
         pydantic_schema: The pydantic schema of the entities to extract.
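For reference, a short sketch of `create_tagging_chain` with a JSON-schema dict (the schema contents and the model are illustrative; assumes a function-calling-capable OpenAI chat model):

from langchain.chains import create_tagging_chain
from langchain.chat_models import ChatOpenAI

schema = {
    "properties": {
        "sentiment": {"type": "string", "enum": ["positive", "neutral", "negative"]},
        "language": {"type": "string"},
    },
    "required": ["sentiment", "language"],
}
chain = create_tagging_chain(schema, ChatOpenAI(temperature=0))
print(chain.run("Estoy muy contento con este producto."))
# e.g. {'sentiment': 'positive', 'language': 'Spanish'}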

View File

@@ -10,6 +10,8 @@ from langchain.schema.language_model import BaseLanguageModel
 class BasePromptSelector(BaseModel, ABC):
+    """Base class for prompt selectors."""
+
     @abstractmethod
     def get_prompt(self, llm: BaseLanguageModel) -> BasePromptTemplate:
         """Get default prompt for a language model."""
@@ -19,11 +21,21 @@ class ConditionalPromptSelector(BasePromptSelector):
     """Prompt collection that goes through conditionals."""
 
     default_prompt: BasePromptTemplate
+    """Default prompt to use if no conditionals match."""
     conditionals: List[
         Tuple[Callable[[BaseLanguageModel], bool], BasePromptTemplate]
     ] = Field(default_factory=list)
+    """List of conditionals and prompts to use if the conditionals match."""
 
     def get_prompt(self, llm: BaseLanguageModel) -> BasePromptTemplate:
+        """Get default prompt for a language model.
+
+        Args:
+            llm: Language model to get prompt for.
+
+        Returns:
+            Prompt to use for the language model.
+        """
         for condition, prompt in self.conditionals:
             if condition(llm):
                 return prompt
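A minimal sketch of the conditional selection documented above, using the module's `is_chat_model` predicate (the prompts themselves are illustrative):

from langchain.chains.prompt_selector import (
    ConditionalPromptSelector,
    is_chat_model,
)
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate, PromptTemplate

selector = ConditionalPromptSelector(
    default_prompt=PromptTemplate.from_template("Summarize: {text}"),
    conditionals=[
        (is_chat_model, ChatPromptTemplate.from_template("Summarize: {text}")),
    ],
)
prompt = selector.get_prompt(ChatOpenAI())  # matches the chat conditional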

View File

@@ -15,13 +15,20 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
 class QAGenerationChain(Chain):
+    """Base class for question-answer generation chains."""
+
     llm_chain: LLMChain
+    """LLM Chain that generates responses from user input and context."""
     text_splitter: TextSplitter = Field(
         default=RecursiveCharacterTextSplitter(chunk_overlap=500)
     )
+    """Text splitter that splits the input into chunks."""
     input_key: str = "text"
+    """Key of the input to the chain."""
     output_key: str = "questions"
+    """Key of the output of the chain."""
     k: Optional[int] = None
+    """Number of questions to generate."""
 
     @classmethod
     def from_llm(
@@ -30,6 +37,17 @@ class QAGenerationChain(Chain):
         prompt: Optional[BasePromptTemplate] = None,
         **kwargs: Any,
     ) -> QAGenerationChain:
+        """
+        Create a QAGenerationChain from a language model.
+
+        Args:
+            llm: a language model
+            prompt: a prompt template
+            **kwargs: additional arguments
+
+        Returns:
+            a QAGenerationChain class
+        """
         _prompt = prompt or PROMPT_SELECTOR.get_prompt(llm)
         chain = LLMChain(llm=llm, prompt=_prompt)
         return cls(llm_chain=chain, **kwargs)
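Sketch of the documented constructor in use (the model choice and input text are illustrative; assumes an OpenAI API key):

from langchain.chains import QAGenerationChain
from langchain.chat_models import ChatOpenAI

chain = QAGenerationChain.from_llm(ChatOpenAI(temperature=0))
qa_pairs = chain.run("LangChain is a framework for building LLM applications.")
# -> list of {"question": ..., "answer": ...} dicts, one per text chunk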

View File

@@ -31,7 +31,7 @@ from langchain.schema.language_model import BaseLanguageModel
 class BaseQAWithSourcesChain(Chain, ABC):
-    """Question answering with sources over documents."""
+    """Question answering chain with sources over documents."""
 
     combine_documents_chain: BaseCombineDocumentsChain
     """Chain to use to combine documents."""

View File

@@ -155,7 +155,7 @@ def load_qa_with_sources_chain(
     verbose: Optional[bool] = None,
     **kwargs: Any,
 ) -> BaseCombineDocumentsChain:
-    """Load question answering with sources chain.
+    """Load a question answering with sources chain.
 
     Args:
         llm: Language Model to use in the chain.
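A short usage sketch for the loader (the document contents and chain_type are illustrative; assumes an OpenAI API key):

from langchain.chains.qa_with_sources.loading import load_qa_with_sources_chain
from langchain.llms import OpenAI
from langchain.schema import Document

chain = load_qa_with_sources_chain(OpenAI(temperature=0), chain_type="stuff")
docs = [
    Document(
        page_content="LangChain was released in 2022.",
        metadata={"source": "notes.txt"},
    )
]
result = chain(
    {"input_documents": docs, "question": "When was LangChain released?"}
)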

View File

@@ -27,6 +27,8 @@ from langchain.schema.language_model import BaseLanguageModel
 class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]):
+    """Output parser that parses a structured query."""
+
     ast_parse: Callable
     """Callable that parses dict into internal representation of query language."""
@@ -57,6 +59,16 @@ class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]):
         allowed_comparators: Optional[Sequence[Comparator]] = None,
         allowed_operators: Optional[Sequence[Operator]] = None,
     ) -> StructuredQueryOutputParser:
+        """
+        Create a structured query output parser from components.
+
+        Args:
+            allowed_comparators: allowed comparators
+            allowed_operators: allowed operators
+
+        Returns:
+            a structured query output parser
+        """
         ast_parser = get_parser(
             allowed_comparators=allowed_comparators, allowed_operators=allowed_operators
         )
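Sketch of `from_components` with an explicitly restricted grammar (the particular comparators and operators are illustrative):

from langchain.chains.query_constructor.base import StructuredQueryOutputParser
from langchain.chains.query_constructor.ir import Comparator, Operator

parser = StructuredQueryOutputParser.from_components(
    allowed_comparators=[Comparator.EQ, Comparator.GT],
    allowed_operators=[Operator.AND],
)
# parser.parse() takes the LLM's json-in-markdown output and returns a
# StructuredQuery with `query`, `filter`, and `limit` fields.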

View File

@ -53,7 +53,17 @@ def _to_snake_case(name: str) -> str:
class Expr(BaseModel): class Expr(BaseModel):
"""Base class for all expressions."""
def accept(self, visitor: Visitor) -> Any: def accept(self, visitor: Visitor) -> Any:
"""Accept a visitor.
Args:
visitor: visitor to accept
Returns:
result of visiting
"""
return getattr(visitor, f"visit_{_to_snake_case(self.__class__.__name__)}")( return getattr(visitor, f"visit_{_to_snake_case(self.__class__.__name__)}")(
self self
) )
@ -99,6 +109,11 @@ class Operation(FilterDirective):
class StructuredQuery(Expr): class StructuredQuery(Expr):
"""A structured query."""
query: str query: str
"""Query string."""
filter: Optional[FilterDirective] filter: Optional[FilterDirective]
"""Filtering expression."""
limit: Optional[int] limit: Optional[int]
"""Limit on the number of results."""

View File

@ -54,7 +54,8 @@ GRAMMAR = """
@v_args(inline=True) @v_args(inline=True)
class QueryTransformer(Transformer): class QueryTransformer(Transformer):
"""Transforms a query string into an IR representation """Transforms a query string into an IR representation
(intermediate representation).""" (intermediate representation).
"""
def __init__( def __init__(
self, self,

View File

@@ -25,12 +25,14 @@ from langchain.vectorstores.base import VectorStore
 class BaseRetrievalQA(Chain):
+    """Base class for question-answering chains."""
+
     combine_documents_chain: BaseCombineDocumentsChain
     """Chain to use to combine the documents."""
     input_key: str = "query"  #: :meta private:
     output_key: str = "result"  #: :meta private:
     return_source_documents: bool = False
-    """Return the source documents."""
+    """Return the source documents or not."""
 
     class Config:
         """Configuration for this pydantic object."""
@@ -41,7 +43,7 @@ class BaseRetrievalQA(Chain):
     @property
     def input_keys(self) -> List[str]:
-        """Return the input keys.
+        """Input keys.
 
         :meta private:
         """
@@ -49,7 +51,7 @@ class BaseRetrievalQA(Chain):
     @property
     def output_keys(self) -> List[str]:
-        """Return the output keys.
+        """Output keys.
 
         :meta private:
         """

View File

@@ -27,6 +27,16 @@ class RouterChain(Chain, ABC):
         return ["destination", "next_inputs"]
 
     def route(self, inputs: Dict[str, Any], callbacks: Callbacks = None) -> Route:
+        """
+        Route inputs to a destination chain.
+
+        Args:
+            inputs: inputs to the chain
+            callbacks: callbacks to use for the chain
+
+        Returns:
+            a Route object
+        """
         result = self(inputs, callbacks=callbacks)
         return Route(result["destination"], result["next_inputs"])
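Hypothetical usage of `route` (here `router` stands for any concrete RouterChain subclass, e.g. an LLMRouterChain; the input is illustrative):

from langchain.chains.router.base import Route

route = router.route({"input": "What is 2 + 2?"})
assert isinstance(route, Route)
route.destination  # e.g. "math" -- name of the chain to dispatch to
route.next_inputs  # dict of inputs to pass to that destination chain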

View File

@@ -12,7 +12,7 @@ from langchain.vectorstores.base import VectorStore
 class EmbeddingRouterChain(RouterChain):
-    """Class that uses embeddings to route between options."""
+    """Chain that uses embeddings to route between options."""
 
     vectorstore: VectorStore
     routing_keys: List[str] = ["query"]