mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-21 20:13:39 +00:00
parent
f2ef3ff54a
commit
4a810756f8
@ -11,6 +11,7 @@ from langchain.schema import (
|
|||||||
|
|
||||||
|
|
||||||
def import_context() -> Any:
|
def import_context() -> Any:
|
||||||
|
"""Import the `getcontext` package."""
|
||||||
try:
|
try:
|
||||||
import getcontext # noqa: F401
|
import getcontext # noqa: F401
|
||||||
from getcontext.generated.models import (
|
from getcontext.generated.models import (
|
||||||
@ -30,7 +31,9 @@ def import_context() -> Any:
|
|||||||
|
|
||||||
|
|
||||||
class ContextCallbackHandler(BaseCallbackHandler):
|
class ContextCallbackHandler(BaseCallbackHandler):
|
||||||
"""Callback Handler that records transcripts to Context (https://getcontext.ai).
|
"""Callback Handler that records transcripts to the Context service.
|
||||||
|
|
||||||
|
(https://getcontext.ai).
|
||||||
|
|
||||||
Keyword Args:
|
Keyword Args:
|
||||||
token (optional): The token with which to authenticate requests to Context.
|
token (optional): The token with which to authenticate requests to Context.
|
||||||
|
@ -1,4 +1,17 @@
|
|||||||
"""Chains are easily reusable components which can be linked together."""
|
"""Chains are easily reusable components which can be linked together.
|
||||||
|
|
||||||
|
Chains should be used to encode a sequence of calls to components like
|
||||||
|
models, document retrievers, other chains, etc., and provide a simple interface
|
||||||
|
to this sequence.
|
||||||
|
|
||||||
|
The Chain interface makes it easy to create apps that are:
|
||||||
|
- Stateful: add Memory to any Chain to give it state,
|
||||||
|
- Observable: pass Callbacks to a Chain to execute additional functionality,
|
||||||
|
like logging, outside the main sequence of component calls,
|
||||||
|
- Composable: the Chain API is flexible enough that it is easy to combine
|
||||||
|
Chains with other components, including other Chains.
|
||||||
|
"""
|
||||||
|
|
||||||
from langchain.chains.api.base import APIChain
|
from langchain.chains.api.base import APIChain
|
||||||
from langchain.chains.api.openapi.chain import OpenAPIEndpointChain
|
from langchain.chains.api.openapi.chain import OpenAPIEndpointChain
|
||||||
from langchain.chains.combine_documents.base import AnalyzeDocumentChain
|
from langchain.chains.combine_documents.base import AnalyzeDocumentChain
|
||||||
|
@ -72,13 +72,13 @@ class Chain(Serializable, ABC):
|
|||||||
"""Whether or not run in verbose mode. In verbose mode, some intermediate logs
|
"""Whether or not run in verbose mode. In verbose mode, some intermediate logs
|
||||||
will be printed to the console. Defaults to `langchain.verbose` value."""
|
will be printed to the console. Defaults to `langchain.verbose` value."""
|
||||||
tags: Optional[List[str]] = None
|
tags: Optional[List[str]] = None
|
||||||
"""Optional list of tags associated with the chain. Defaults to None
|
"""Optional list of tags associated with the chain. Defaults to None.
|
||||||
These tags will be associated with each call to this chain,
|
These tags will be associated with each call to this chain,
|
||||||
and passed as arguments to the handlers defined in `callbacks`.
|
and passed as arguments to the handlers defined in `callbacks`.
|
||||||
You can use these to eg identify a specific instance of a chain with its use case.
|
You can use these to eg identify a specific instance of a chain with its use case.
|
||||||
"""
|
"""
|
||||||
metadata: Optional[Dict[str, Any]] = None
|
metadata: Optional[Dict[str, Any]] = None
|
||||||
"""Optional metadata associated with the chain. Defaults to None
|
"""Optional metadata associated with the chain. Defaults to None.
|
||||||
This metadata will be associated with each call to this chain,
|
This metadata will be associated with each call to this chain,
|
||||||
and passed as arguments to the handlers defined in `callbacks`.
|
and passed as arguments to the handlers defined in `callbacks`.
|
||||||
You can use these to eg identify a specific instance of a chain with its use case.
|
You can use these to eg identify a specific instance of a chain with its use case.
|
||||||
@ -118,12 +118,12 @@ class Chain(Serializable, ABC):
|
|||||||
@property
|
@property
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def input_keys(self) -> List[str]:
|
def input_keys(self) -> List[str]:
|
||||||
"""Return the keys expected to be in the chain input."""
|
"""Keys expected to be in the chain input."""
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def output_keys(self) -> List[str]:
|
def output_keys(self) -> List[str]:
|
||||||
"""Return the keys expected to be in the chain output."""
|
"""Keys expected to be in the chain output."""
|
||||||
|
|
||||||
def _validate_inputs(self, inputs: Dict[str, Any]) -> None:
|
def _validate_inputs(self, inputs: Dict[str, Any]) -> None:
|
||||||
"""Check that all inputs are present."""
|
"""Check that all inputs are present."""
|
||||||
@ -391,7 +391,7 @@ class Chain(Serializable, ABC):
|
|||||||
metadata: Optional[Dict[str, Any]] = None,
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Convenience method for executing chain when there's a single string output.
|
"""Execute chain when there's a single string output.
|
||||||
|
|
||||||
The main difference between this method and `Chain.__call__` is that this method
|
The main difference between this method and `Chain.__call__` is that this method
|
||||||
can only be used for chains that return a single string output. If a Chain
|
can only be used for chains that return a single string output. If a Chain
|
||||||
@ -465,7 +465,7 @@ class Chain(Serializable, ABC):
|
|||||||
metadata: Optional[Dict[str, Any]] = None,
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Convenience method for executing chain when there's a single string output.
|
"""Execute chain when there's a single string output.
|
||||||
|
|
||||||
The main difference between this method and `Chain.__call__` is that this method
|
The main difference between this method and `Chain.__call__` is that this method
|
||||||
can only be used for chains that return a single string output. If a Chain
|
can only be used for chains that return a single string output. If a Chain
|
||||||
@ -532,7 +532,7 @@ class Chain(Serializable, ABC):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def dict(self, **kwargs: Any) -> Dict:
|
def dict(self, **kwargs: Any) -> Dict:
|
||||||
"""Return dictionary representation of chain.
|
"""Dictionary representation of chain.
|
||||||
|
|
||||||
Expects `Chain._chain_type` property to be implemented and for memory to be
|
Expects `Chain._chain_type` property to be implemented and for memory to be
|
||||||
null.
|
null.
|
||||||
|
@ -22,7 +22,7 @@ class AsyncCombineDocsProtocol(Protocol):
|
|||||||
"""Interface for the combine_docs method."""
|
"""Interface for the combine_docs method."""
|
||||||
|
|
||||||
async def __call__(self, docs: List[Document], **kwargs: Any) -> str:
|
async def __call__(self, docs: List[Document], **kwargs: Any) -> str:
|
||||||
"""Async nterface for the combine_docs method."""
|
"""Async interface for the combine_docs method."""
|
||||||
|
|
||||||
|
|
||||||
def _split_list_of_docs(
|
def _split_list_of_docs(
|
||||||
@ -78,7 +78,7 @@ async def _acollapse_docs(
|
|||||||
|
|
||||||
|
|
||||||
class ReduceDocumentsChain(BaseCombineDocumentsChain):
|
class ReduceDocumentsChain(BaseCombineDocumentsChain):
|
||||||
"""Combining documents by recursively reducing them.
|
"""Combine documents by recursively reducing them.
|
||||||
|
|
||||||
This involves
|
This involves
|
||||||
|
|
||||||
@ -206,7 +206,7 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain):
|
|||||||
callbacks: Callbacks = None,
|
callbacks: Callbacks = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Tuple[str, dict]:
|
) -> Tuple[str, dict]:
|
||||||
"""Combine multiple documents recursively.
|
"""Async combine multiple documents recursively.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
docs: List of documents to combine, assumed that each one is less than
|
docs: List of documents to combine, assumed that each one is less than
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
"""Combining documents by doing a first pass and then refining on more documents."""
|
"""Combine documents by doing a first pass and then refining on more documents."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
@ -161,7 +161,8 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):
|
|||||||
async def acombine_docs(
|
async def acombine_docs(
|
||||||
self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
|
self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
|
||||||
) -> Tuple[str, dict]:
|
) -> Tuple[str, dict]:
|
||||||
"""Combine by mapping first chain over all, then stuffing into final chain.
|
"""Async combine by mapping a first chain over all, then stuffing
|
||||||
|
into a final chain.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
docs: List of documents to combine
|
docs: List of documents to combine
|
||||||
|
@ -167,7 +167,7 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
|
|||||||
async def acombine_docs(
|
async def acombine_docs(
|
||||||
self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
|
self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
|
||||||
) -> Tuple[str, dict]:
|
) -> Tuple[str, dict]:
|
||||||
"""Stuff all documents into one prompt and pass to LLM.
|
"""Async stuff all documents into one prompt and pass to LLM.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
docs: List of documents to join together into one variable
|
docs: List of documents to join together into one variable
|
||||||
|
@ -80,12 +80,12 @@ class ConstitutionalChain(Chain):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def input_keys(self) -> List[str]:
|
def input_keys(self) -> List[str]:
|
||||||
"""Defines the input keys."""
|
"""Input keys."""
|
||||||
return self.chain.input_keys
|
return self.chain.input_keys
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def output_keys(self) -> List[str]:
|
def output_keys(self) -> List[str]:
|
||||||
"""Defines the output keys."""
|
"""Output keys."""
|
||||||
if self.return_intermediate_steps:
|
if self.return_intermediate_steps:
|
||||||
return ["output", "critiques_and_revisions", "initial_output"]
|
return ["output", "critiques_and_revisions", "initial_output"]
|
||||||
return ["output"]
|
return ["output"]
|
||||||
|
@ -23,6 +23,8 @@ from langchain.schema.language_model import BaseLanguageModel
|
|||||||
|
|
||||||
|
|
||||||
class _ResponseChain(LLMChain):
|
class _ResponseChain(LLMChain):
|
||||||
|
"""Base class for chains that generate responses."""
|
||||||
|
|
||||||
prompt: BasePromptTemplate = PROMPT
|
prompt: BasePromptTemplate = PROMPT
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -46,6 +48,8 @@ class _ResponseChain(LLMChain):
|
|||||||
|
|
||||||
|
|
||||||
class _OpenAIResponseChain(_ResponseChain):
|
class _OpenAIResponseChain(_ResponseChain):
|
||||||
|
"""Chain that generates responses from user input and context."""
|
||||||
|
|
||||||
llm: OpenAI = Field(
|
llm: OpenAI = Field(
|
||||||
default_factory=lambda: OpenAI(
|
default_factory=lambda: OpenAI(
|
||||||
max_tokens=32, model_kwargs={"logprobs": 1}, temperature=0
|
max_tokens=32, model_kwargs={"logprobs": 1}, temperature=0
|
||||||
@ -66,10 +70,14 @@ class _OpenAIResponseChain(_ResponseChain):
|
|||||||
|
|
||||||
|
|
||||||
class QuestionGeneratorChain(LLMChain):
|
class QuestionGeneratorChain(LLMChain):
|
||||||
|
"""Chain that generates questions from uncertain spans."""
|
||||||
|
|
||||||
prompt: BasePromptTemplate = QUESTION_GENERATOR_PROMPT
|
prompt: BasePromptTemplate = QUESTION_GENERATOR_PROMPT
|
||||||
|
"""Prompt template for the chain."""
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def input_keys(self) -> List[str]:
|
def input_keys(self) -> List[str]:
|
||||||
|
"""Input keys for the chain."""
|
||||||
return ["user_input", "context", "response"]
|
return ["user_input", "context", "response"]
|
||||||
|
|
||||||
|
|
||||||
@ -95,22 +103,36 @@ def _low_confidence_spans(
|
|||||||
|
|
||||||
|
|
||||||
class FlareChain(Chain):
|
class FlareChain(Chain):
|
||||||
|
"""Chain that combines a retriever, a question generator,
|
||||||
|
and a response generator."""
|
||||||
|
|
||||||
question_generator_chain: QuestionGeneratorChain
|
question_generator_chain: QuestionGeneratorChain
|
||||||
|
"""Chain that generates questions from uncertain spans."""
|
||||||
response_chain: _ResponseChain = Field(default_factory=_OpenAIResponseChain)
|
response_chain: _ResponseChain = Field(default_factory=_OpenAIResponseChain)
|
||||||
|
"""Chain that generates responses from user input and context."""
|
||||||
output_parser: FinishedOutputParser = Field(default_factory=FinishedOutputParser)
|
output_parser: FinishedOutputParser = Field(default_factory=FinishedOutputParser)
|
||||||
|
"""Parser that determines whether the chain is finished."""
|
||||||
retriever: BaseRetriever
|
retriever: BaseRetriever
|
||||||
|
"""Retriever that retrieves relevant documents from a user input."""
|
||||||
min_prob: float = 0.2
|
min_prob: float = 0.2
|
||||||
|
"""Minimum probability for a token to be considered low confidence."""
|
||||||
min_token_gap: int = 5
|
min_token_gap: int = 5
|
||||||
|
"""Minimum number of tokens between two low confidence spans."""
|
||||||
num_pad_tokens: int = 2
|
num_pad_tokens: int = 2
|
||||||
|
"""Number of tokens to pad around a low confidence span."""
|
||||||
max_iter: int = 10
|
max_iter: int = 10
|
||||||
|
"""Maximum number of iterations."""
|
||||||
start_with_retrieval: bool = True
|
start_with_retrieval: bool = True
|
||||||
|
"""Whether to start with retrieval."""
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def input_keys(self) -> List[str]:
|
def input_keys(self) -> List[str]:
|
||||||
|
"""Input keys for the chain."""
|
||||||
return ["user_input"]
|
return ["user_input"]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def output_keys(self) -> List[str]:
|
def output_keys(self) -> List[str]:
|
||||||
|
"""Output keys for the chain."""
|
||||||
return ["response"]
|
return ["response"]
|
||||||
|
|
||||||
def _do_generation(
|
def _do_generation(
|
||||||
@ -213,6 +235,16 @@ class FlareChain(Chain):
|
|||||||
def from_llm(
|
def from_llm(
|
||||||
cls, llm: BaseLanguageModel, max_generation_len: int = 32, **kwargs: Any
|
cls, llm: BaseLanguageModel, max_generation_len: int = 32, **kwargs: Any
|
||||||
) -> FlareChain:
|
) -> FlareChain:
|
||||||
|
"""Creates a FlareChain from a language model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
llm: Language model to use.
|
||||||
|
max_generation_len: Maximum length of the generated response.
|
||||||
|
**kwargs: Additional arguments to pass to the constructor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
FlareChain class with the given language model.
|
||||||
|
"""
|
||||||
question_gen_chain = QuestionGeneratorChain(llm=llm)
|
question_gen_chain = QuestionGeneratorChain(llm=llm)
|
||||||
response_llm = OpenAI(
|
response_llm = OpenAI(
|
||||||
max_tokens=max_generation_len, model_kwargs={"logprobs": 1}, temperature=0
|
max_tokens=max_generation_len, model_kwargs={"logprobs": 1}, temperature=0
|
||||||
|
@ -5,7 +5,10 @@ from langchain.schema import BaseOutputParser
|
|||||||
|
|
||||||
|
|
||||||
class FinishedOutputParser(BaseOutputParser[Tuple[str, bool]]):
|
class FinishedOutputParser(BaseOutputParser[Tuple[str, bool]]):
|
||||||
|
"""Output parser that checks if the output is finished."""
|
||||||
|
|
||||||
finished_value: str = "FINISHED"
|
finished_value: str = "FINISHED"
|
||||||
|
"""Value that indicates the output is finished."""
|
||||||
|
|
||||||
def parse(self, text: str) -> Tuple[str, bool]:
|
def parse(self, text: str) -> Tuple[str, bool]:
|
||||||
cleaned = text.strip()
|
cleaned = text.strip()
|
||||||
|
@ -25,7 +25,7 @@ class GraphQAChain(Chain):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def input_keys(self) -> List[str]:
|
def input_keys(self) -> List[str]:
|
||||||
"""Return the input keys.
|
"""Input keys.
|
||||||
|
|
||||||
:meta private:
|
:meta private:
|
||||||
"""
|
"""
|
||||||
@ -33,7 +33,7 @@ class GraphQAChain(Chain):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def output_keys(self) -> List[str]:
|
def output_keys(self) -> List[str]:
|
||||||
"""Return the output keys.
|
"""Output keys.
|
||||||
|
|
||||||
:meta private:
|
:meta private:
|
||||||
"""
|
"""
|
||||||
|
@ -18,8 +18,8 @@ INTERMEDIATE_STEPS_KEY = "intermediate_steps"
|
|||||||
|
|
||||||
|
|
||||||
def extract_cypher(text: str) -> str:
|
def extract_cypher(text: str) -> str:
|
||||||
"""
|
"""Extract Cypher code from a text.
|
||||||
Extract Cypher code from a text.
|
|
||||||
Args:
|
Args:
|
||||||
text: Text to extract Cypher code from.
|
text: Text to extract Cypher code from.
|
||||||
|
|
||||||
|
@ -28,7 +28,7 @@ class HugeGraphQAChain(Chain):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def input_keys(self) -> List[str]:
|
def input_keys(self) -> List[str]:
|
||||||
"""Return the input keys.
|
"""Input keys.
|
||||||
|
|
||||||
:meta private:
|
:meta private:
|
||||||
"""
|
"""
|
||||||
@ -36,7 +36,7 @@ class HugeGraphQAChain(Chain):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def output_keys(self) -> List[str]:
|
def output_keys(self) -> List[str]:
|
||||||
"""Return the output keys.
|
"""Output keys.
|
||||||
|
|
||||||
:meta private:
|
:meta private:
|
||||||
"""
|
"""
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
"""Chain that interprets a prompt and executes bash code to perform bash operations."""
|
"""Chain that interprets a prompt and executes bash operations."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
@ -19,7 +19,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class LLMBashChain(Chain):
|
class LLMBashChain(Chain):
|
||||||
"""Chain that interprets a prompt and executes bash code to perform bash operations.
|
"""Chain that interprets a prompt and executes bash operations.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
@ -16,7 +16,7 @@ DEFAULT_HEADERS = {
|
|||||||
|
|
||||||
|
|
||||||
class LLMRequestsChain(Chain):
|
class LLMRequestsChain(Chain):
|
||||||
"""Chain that hits a URL and then uses an LLM to parse results."""
|
"""Chain that requests a URL and then uses an LLM to parse results."""
|
||||||
|
|
||||||
llm_chain: LLMChain
|
llm_chain: LLMChain
|
||||||
requests_wrapper: TextRequestsWrapper = Field(
|
requests_wrapper: TextRequestsWrapper = Field(
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
"""Chain that interprets a prompt and executes python code to do math."""
|
"""Chain that interprets a prompt and executes python code to do symbolic math."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import re
|
import re
|
||||||
@ -18,7 +18,7 @@ from langchain.prompts.base import BasePromptTemplate
|
|||||||
|
|
||||||
|
|
||||||
class LLMSymbolicMathChain(Chain):
|
class LLMSymbolicMathChain(Chain):
|
||||||
"""Chain that interprets a prompt and executes python code to do math.
|
"""Chain that interprets a prompt and executes python code to do symbolic math.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
@ -13,7 +13,7 @@ from langchain.schema.messages import HumanMessage, SystemMessage
|
|||||||
|
|
||||||
|
|
||||||
class FactWithEvidence(BaseModel):
|
class FactWithEvidence(BaseModel):
|
||||||
"""Class representing single statement.
|
"""Class representing a single statement.
|
||||||
|
|
||||||
Each fact has a body and a list of sources.
|
Each fact has a body and a list of sources.
|
||||||
If there are multiple facts make sure to break them apart
|
If there are multiple facts make sure to break them apart
|
||||||
|
@ -188,9 +188,14 @@ def openapi_spec_to_openai_fn(
|
|||||||
|
|
||||||
|
|
||||||
class SimpleRequestChain(Chain):
|
class SimpleRequestChain(Chain):
|
||||||
|
"""Chain for making a simple request to an API endpoint."""
|
||||||
|
|
||||||
request_method: Callable
|
request_method: Callable
|
||||||
|
"""Method to use for making the request."""
|
||||||
output_key: str = "response"
|
output_key: str = "response"
|
||||||
|
"""Key to use for the output of the request."""
|
||||||
input_key: str = "function"
|
input_key: str = "function"
|
||||||
|
"""Key to use for the input of the request."""
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def input_keys(self) -> List[str]:
|
def input_keys(self) -> List[str]:
|
||||||
|
@ -16,7 +16,7 @@ from langchain.schema.messages import HumanMessage, SystemMessage
|
|||||||
|
|
||||||
|
|
||||||
class AnswerWithSources(BaseModel):
|
class AnswerWithSources(BaseModel):
|
||||||
"""An answer to the question being asked, with sources."""
|
"""An answer to the question, with sources."""
|
||||||
|
|
||||||
answer: str = Field(..., description="Answer to the question that was asked")
|
answer: str = Field(..., description="Answer to the question that was asked")
|
||||||
sources: List[str] = Field(
|
sources: List[str] = Field(
|
||||||
@ -30,7 +30,8 @@ def create_qa_with_structure_chain(
|
|||||||
output_parser: str = "base",
|
output_parser: str = "base",
|
||||||
prompt: Optional[Union[PromptTemplate, ChatPromptTemplate]] = None,
|
prompt: Optional[Union[PromptTemplate, ChatPromptTemplate]] = None,
|
||||||
) -> LLMChain:
|
) -> LLMChain:
|
||||||
"""Create a question answering chain that returns an answer with sources.
|
"""Create a question answering chain that returns an answer with sources
|
||||||
|
based on schema.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
llm: Language model to use for the chain.
|
llm: Language model to use for the chain.
|
||||||
|
@ -34,7 +34,8 @@ def create_tagging_chain(
|
|||||||
prompt: Optional[ChatPromptTemplate] = None,
|
prompt: Optional[ChatPromptTemplate] = None,
|
||||||
**kwargs: Any
|
**kwargs: Any
|
||||||
) -> Chain:
|
) -> Chain:
|
||||||
"""Creates a chain that extracts information from a passage.
|
"""Creates a chain that extracts information from a passage
|
||||||
|
based on a schema.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
schema: The schema of the entities to extract.
|
schema: The schema of the entities to extract.
|
||||||
@ -63,7 +64,8 @@ def create_tagging_chain_pydantic(
|
|||||||
prompt: Optional[ChatPromptTemplate] = None,
|
prompt: Optional[ChatPromptTemplate] = None,
|
||||||
**kwargs: Any
|
**kwargs: Any
|
||||||
) -> Chain:
|
) -> Chain:
|
||||||
"""Creates a chain that extracts information from a passage.
|
"""Creates a chain that extracts information from a passage
|
||||||
|
based on a pydantic schema.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
pydantic_schema: The pydantic schema of the entities to extract.
|
pydantic_schema: The pydantic schema of the entities to extract.
|
||||||
|
@ -10,6 +10,8 @@ from langchain.schema.language_model import BaseLanguageModel
|
|||||||
|
|
||||||
|
|
||||||
class BasePromptSelector(BaseModel, ABC):
|
class BasePromptSelector(BaseModel, ABC):
|
||||||
|
"""Base class for prompt selectors."""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_prompt(self, llm: BaseLanguageModel) -> BasePromptTemplate:
|
def get_prompt(self, llm: BaseLanguageModel) -> BasePromptTemplate:
|
||||||
"""Get default prompt for a language model."""
|
"""Get default prompt for a language model."""
|
||||||
@ -19,11 +21,21 @@ class ConditionalPromptSelector(BasePromptSelector):
|
|||||||
"""Prompt collection that goes through conditionals."""
|
"""Prompt collection that goes through conditionals."""
|
||||||
|
|
||||||
default_prompt: BasePromptTemplate
|
default_prompt: BasePromptTemplate
|
||||||
|
"""Default prompt to use if no conditionals match."""
|
||||||
conditionals: List[
|
conditionals: List[
|
||||||
Tuple[Callable[[BaseLanguageModel], bool], BasePromptTemplate]
|
Tuple[Callable[[BaseLanguageModel], bool], BasePromptTemplate]
|
||||||
] = Field(default_factory=list)
|
] = Field(default_factory=list)
|
||||||
|
"""List of conditionals and prompts to use if the conditionals match."""
|
||||||
|
|
||||||
def get_prompt(self, llm: BaseLanguageModel) -> BasePromptTemplate:
|
def get_prompt(self, llm: BaseLanguageModel) -> BasePromptTemplate:
|
||||||
|
"""Get default prompt for a language model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
llm: Language model to get prompt for.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Prompt to use for the language model.
|
||||||
|
"""
|
||||||
for condition, prompt in self.conditionals:
|
for condition, prompt in self.conditionals:
|
||||||
if condition(llm):
|
if condition(llm):
|
||||||
return prompt
|
return prompt
|
||||||
|
@ -15,13 +15,20 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
|
|||||||
|
|
||||||
|
|
||||||
class QAGenerationChain(Chain):
|
class QAGenerationChain(Chain):
|
||||||
|
"""Base class for question-answer generation chains."""
|
||||||
|
|
||||||
llm_chain: LLMChain
|
llm_chain: LLMChain
|
||||||
|
"""LLM Chain that generates responses from user input and context."""
|
||||||
text_splitter: TextSplitter = Field(
|
text_splitter: TextSplitter = Field(
|
||||||
default=RecursiveCharacterTextSplitter(chunk_overlap=500)
|
default=RecursiveCharacterTextSplitter(chunk_overlap=500)
|
||||||
)
|
)
|
||||||
|
"""Text splitter that splits the input into chunks."""
|
||||||
input_key: str = "text"
|
input_key: str = "text"
|
||||||
|
"""Key of the input to the chain."""
|
||||||
output_key: str = "questions"
|
output_key: str = "questions"
|
||||||
|
"""Key of the output of the chain."""
|
||||||
k: Optional[int] = None
|
k: Optional[int] = None
|
||||||
|
"""Number of questions to generate."""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_llm(
|
def from_llm(
|
||||||
@ -30,6 +37,17 @@ class QAGenerationChain(Chain):
|
|||||||
prompt: Optional[BasePromptTemplate] = None,
|
prompt: Optional[BasePromptTemplate] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> QAGenerationChain:
|
) -> QAGenerationChain:
|
||||||
|
"""
|
||||||
|
Create a QAGenerationChain from a language model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
llm: a language model
|
||||||
|
prompt: a prompt template
|
||||||
|
**kwargs: additional arguments
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
a QAGenerationChain class
|
||||||
|
"""
|
||||||
_prompt = prompt or PROMPT_SELECTOR.get_prompt(llm)
|
_prompt = prompt or PROMPT_SELECTOR.get_prompt(llm)
|
||||||
chain = LLMChain(llm=llm, prompt=_prompt)
|
chain = LLMChain(llm=llm, prompt=_prompt)
|
||||||
return cls(llm_chain=chain, **kwargs)
|
return cls(llm_chain=chain, **kwargs)
|
||||||
|
@ -31,7 +31,7 @@ from langchain.schema.language_model import BaseLanguageModel
|
|||||||
|
|
||||||
|
|
||||||
class BaseQAWithSourcesChain(Chain, ABC):
|
class BaseQAWithSourcesChain(Chain, ABC):
|
||||||
"""Question answering with sources over documents."""
|
"""Question answering chain with sources over documents."""
|
||||||
|
|
||||||
combine_documents_chain: BaseCombineDocumentsChain
|
combine_documents_chain: BaseCombineDocumentsChain
|
||||||
"""Chain to use to combine documents."""
|
"""Chain to use to combine documents."""
|
||||||
|
@ -155,7 +155,7 @@ def load_qa_with_sources_chain(
|
|||||||
verbose: Optional[bool] = None,
|
verbose: Optional[bool] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> BaseCombineDocumentsChain:
|
) -> BaseCombineDocumentsChain:
|
||||||
"""Load question answering with sources chain.
|
"""Load a question answering with sources chain.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
llm: Language Model to use in the chain.
|
llm: Language Model to use in the chain.
|
||||||
|
@ -27,6 +27,8 @@ from langchain.schema.language_model import BaseLanguageModel
|
|||||||
|
|
||||||
|
|
||||||
class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]):
|
class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]):
|
||||||
|
"""Output parser that parses a structured query."""
|
||||||
|
|
||||||
ast_parse: Callable
|
ast_parse: Callable
|
||||||
"""Callable that parses dict into internal representation of query language."""
|
"""Callable that parses dict into internal representation of query language."""
|
||||||
|
|
||||||
@ -57,6 +59,16 @@ class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]):
|
|||||||
allowed_comparators: Optional[Sequence[Comparator]] = None,
|
allowed_comparators: Optional[Sequence[Comparator]] = None,
|
||||||
allowed_operators: Optional[Sequence[Operator]] = None,
|
allowed_operators: Optional[Sequence[Operator]] = None,
|
||||||
) -> StructuredQueryOutputParser:
|
) -> StructuredQueryOutputParser:
|
||||||
|
"""
|
||||||
|
Create a structured query output parser from components.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
allowed_comparators: allowed comparators
|
||||||
|
allowed_operators: allowed operators
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
a structured query output parser
|
||||||
|
"""
|
||||||
ast_parser = get_parser(
|
ast_parser = get_parser(
|
||||||
allowed_comparators=allowed_comparators, allowed_operators=allowed_operators
|
allowed_comparators=allowed_comparators, allowed_operators=allowed_operators
|
||||||
)
|
)
|
||||||
|
@ -53,7 +53,17 @@ def _to_snake_case(name: str) -> str:
|
|||||||
|
|
||||||
|
|
||||||
class Expr(BaseModel):
|
class Expr(BaseModel):
|
||||||
|
"""Base class for all expressions."""
|
||||||
|
|
||||||
def accept(self, visitor: Visitor) -> Any:
|
def accept(self, visitor: Visitor) -> Any:
|
||||||
|
"""Accept a visitor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
visitor: visitor to accept
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
result of visiting
|
||||||
|
"""
|
||||||
return getattr(visitor, f"visit_{_to_snake_case(self.__class__.__name__)}")(
|
return getattr(visitor, f"visit_{_to_snake_case(self.__class__.__name__)}")(
|
||||||
self
|
self
|
||||||
)
|
)
|
||||||
@ -99,6 +109,11 @@ class Operation(FilterDirective):
|
|||||||
|
|
||||||
|
|
||||||
class StructuredQuery(Expr):
|
class StructuredQuery(Expr):
|
||||||
|
"""A structured query."""
|
||||||
|
|
||||||
query: str
|
query: str
|
||||||
|
"""Query string."""
|
||||||
filter: Optional[FilterDirective]
|
filter: Optional[FilterDirective]
|
||||||
|
"""Filtering expression."""
|
||||||
limit: Optional[int]
|
limit: Optional[int]
|
||||||
|
"""Limit on the number of results."""
|
||||||
|
@ -54,7 +54,8 @@ GRAMMAR = """
|
|||||||
@v_args(inline=True)
|
@v_args(inline=True)
|
||||||
class QueryTransformer(Transformer):
|
class QueryTransformer(Transformer):
|
||||||
"""Transforms a query string into an IR representation
|
"""Transforms a query string into an IR representation
|
||||||
(intermediate representation)."""
|
(intermediate representation).
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -25,12 +25,14 @@ from langchain.vectorstores.base import VectorStore
|
|||||||
|
|
||||||
|
|
||||||
class BaseRetrievalQA(Chain):
|
class BaseRetrievalQA(Chain):
|
||||||
|
"""Base class for question-answering chains."""
|
||||||
|
|
||||||
combine_documents_chain: BaseCombineDocumentsChain
|
combine_documents_chain: BaseCombineDocumentsChain
|
||||||
"""Chain to use to combine the documents."""
|
"""Chain to use to combine the documents."""
|
||||||
input_key: str = "query" #: :meta private:
|
input_key: str = "query" #: :meta private:
|
||||||
output_key: str = "result" #: :meta private:
|
output_key: str = "result" #: :meta private:
|
||||||
return_source_documents: bool = False
|
return_source_documents: bool = False
|
||||||
"""Return the source documents."""
|
"""Return the source documents or not."""
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
"""Configuration for this pydantic object."""
|
"""Configuration for this pydantic object."""
|
||||||
@ -41,7 +43,7 @@ class BaseRetrievalQA(Chain):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def input_keys(self) -> List[str]:
|
def input_keys(self) -> List[str]:
|
||||||
"""Return the input keys.
|
"""Input keys.
|
||||||
|
|
||||||
:meta private:
|
:meta private:
|
||||||
"""
|
"""
|
||||||
@ -49,7 +51,7 @@ class BaseRetrievalQA(Chain):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def output_keys(self) -> List[str]:
|
def output_keys(self) -> List[str]:
|
||||||
"""Return the output keys.
|
"""Output keys.
|
||||||
|
|
||||||
:meta private:
|
:meta private:
|
||||||
"""
|
"""
|
||||||
|
@ -27,6 +27,16 @@ class RouterChain(Chain, ABC):
|
|||||||
return ["destination", "next_inputs"]
|
return ["destination", "next_inputs"]
|
||||||
|
|
||||||
def route(self, inputs: Dict[str, Any], callbacks: Callbacks = None) -> Route:
|
def route(self, inputs: Dict[str, Any], callbacks: Callbacks = None) -> Route:
|
||||||
|
"""
|
||||||
|
Route inputs to a destination chain.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
inputs: inputs to the chain
|
||||||
|
callbacks: callbacks to use for the chain
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
a Route object
|
||||||
|
"""
|
||||||
result = self(inputs, callbacks=callbacks)
|
result = self(inputs, callbacks=callbacks)
|
||||||
return Route(result["destination"], result["next_inputs"])
|
return Route(result["destination"], result["next_inputs"])
|
||||||
|
|
||||||
|
@ -12,7 +12,7 @@ from langchain.vectorstores.base import VectorStore
|
|||||||
|
|
||||||
|
|
||||||
class EmbeddingRouterChain(RouterChain):
|
class EmbeddingRouterChain(RouterChain):
|
||||||
"""Class that uses embeddings to route between options."""
|
"""Chain that uses embeddings to route between options."""
|
||||||
|
|
||||||
vectorstore: VectorStore
|
vectorstore: VectorStore
|
||||||
routing_keys: List[str] = ["query"]
|
routing_keys: List[str] = ["query"]
|
||||||
|
Loading…
Reference in New Issue
Block a user