From d3499cc90b2b1650d0e360f3d40cc063d1ba3cdf Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Wed, 11 Sep 2024 11:06:30 -0400 Subject: [PATCH] langchain[patch]: Assign appropriate default for Optional/Any types (#26325) This PR was autogenerated using gritql ``` engine marzano(0.1) language python class_definition(name=$C, $body, superclasses=$S) where { $C <: ! "Config", // Does not work in this scope, but works after class_definition $body <: block($statements), $statements <: some bubble assignment(left=$x, right=$y, type=$t) as $A where { or { $y <: `Field($z)`, $x <: "model_config" } }, // And has either Any or Optional fields without a default $statements <: some bubble assignment(left=$x, right=$y, type=$t) as $A where { $t <: or { r"Optional.*", r"Any", r"Union[None, .*]", r"Union[.*, None, .*]", r"Union[.*, None]", }, $y <: ., // Match empty node $t => `$t = None`, }, } ``` ```shell grit apply 'class_definition(name=$C, $body, superclasses=$S) where { $C <: ! "Config", // Does not work in this scope, but works after class_definition $body <: block($statements), $statements <: some bubble assignment(left=$x, right=$y, type=$t) as $A where { or { $y <: `Field($z)`, $x <: "model_config" } }, // And has either Any or Optional fields without a default $statements <: some bubble assignment(left=$x, right=$y, type=$t) as $A where { $t <: or { r"Optional.*", r"Any", r"Union[None, .*]", r"Union[.*, None, .*]", r"Union[.*, None]", }, $y <: ., // Match empty node $t => `$t = None`, }, } ' --language python . ``` --- libs/langchain/langchain/chains/api/base.py | 2 +- .../langchain/chains/conversational_retrieval/base.py | 2 +- .../langchain/langchain/chains/elasticsearch_database/base.py | 2 +- libs/langchain/langchain/chains/moderation.py | 4 ++-- .../retrievers/document_compressors/embeddings_filter.py | 2 +- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/libs/langchain/langchain/chains/api/base.py b/libs/langchain/langchain/chains/api/base.py index 4387cb623e0..0cb2dbd2855 100644 --- a/libs/langchain/langchain/chains/api/base.py +++ b/libs/langchain/langchain/chains/api/base.py @@ -198,7 +198,7 @@ try: api_docs: str question_key: str = "question" #: :meta private: output_key: str = "output" #: :meta private: - limit_to_domains: Optional[Sequence[str]] + limit_to_domains: Optional[Sequence[str]] = Field(default_factory=list) """Use to limit the domains that can be accessed by the API chain. * For example, to limit to just the domain `https://www.example.com`, set diff --git a/libs/langchain/langchain/chains/conversational_retrieval/base.py b/libs/langchain/langchain/chains/conversational_retrieval/base.py index 3c653c43336..0c983fdc5ad 100644 --- a/libs/langchain/langchain/chains/conversational_retrieval/base.py +++ b/libs/langchain/langchain/chains/conversational_retrieval/base.py @@ -92,7 +92,7 @@ class BaseConversationalRetrievalChain(Chain): get_chat_history: Optional[Callable[[List[CHAT_TURN_TYPE]], str]] = None """An optional function to get a string of the chat history. If None is provided, will use a default.""" - response_if_no_docs_found: Optional[str] + response_if_no_docs_found: Optional[str] = None """If specified, the chain will return a fixed response if no docs are found for the question. """ diff --git a/libs/langchain/langchain/chains/elasticsearch_database/base.py b/libs/langchain/langchain/chains/elasticsearch_database/base.py index 85bf7de93d2..b45b3e2c174 100644 --- a/libs/langchain/langchain/chains/elasticsearch_database/base.py +++ b/libs/langchain/langchain/chains/elasticsearch_database/base.py @@ -40,7 +40,7 @@ class ElasticsearchDatabaseChain(Chain): """Chain for creating the ES query.""" answer_chain: Runnable """Chain for answering the user question.""" - database: Any + database: Any = None """Elasticsearch database to connect to of type elasticsearch.Elasticsearch.""" top_k: int = 10 """Number of results to return from the query""" diff --git a/libs/langchain/langchain/chains/moderation.py b/libs/langchain/langchain/chains/moderation.py index 670b4773e3d..52590a597c0 100644 --- a/libs/langchain/langchain/chains/moderation.py +++ b/libs/langchain/langchain/chains/moderation.py @@ -28,8 +28,8 @@ class OpenAIModerationChain(Chain): moderation = OpenAIModerationChain() """ - client: Any #: :meta private: - async_client: Any #: :meta private: + client: Any = None #: :meta private: + async_client: Any = None #: :meta private: model_name: Optional[str] = None """Moderation model name to use.""" error: bool = False diff --git a/libs/langchain/langchain/retrievers/document_compressors/embeddings_filter.py b/libs/langchain/langchain/retrievers/document_compressors/embeddings_filter.py index 8e3f1dbf43f..d71e7f1b3c2 100644 --- a/libs/langchain/langchain/retrievers/document_compressors/embeddings_filter.py +++ b/libs/langchain/langchain/retrievers/document_compressors/embeddings_filter.py @@ -36,7 +36,7 @@ class EmbeddingsFilter(BaseDocumentCompressor): k: Optional[int] = 20 """The number of relevant documents to return. Can be set to None, in which case `similarity_threshold` must be specified. Defaults to 20.""" - similarity_threshold: Optional[float] + similarity_threshold: Optional[float] = None """Threshold for determining when two documents are similar enough to be considered redundant. Defaults to None, must be specified if `k` is set to None."""