From f19229c564144fbdd6ee9c4e140a23b0df2d7f6d Mon Sep 17 00:00:00 2001 From: Luca Dorigo Date: Thu, 28 Mar 2024 23:33:43 +0100 Subject: [PATCH] core[patch]: fix beta, deprecated typing (#18877) **Description:** While not technically incorrect, the TypeVar used for the `@beta` decorator prevented pyright (and thus most vscode users) from correctly seeing the types of functions/classes decorated with `@beta`. This is in part due to a small bug in pyright (https://github.com/microsoft/pyright/issues/7448 ) - however, the `Type` bound in the typevar `C = TypeVar("C", Type, Callable)` is not doing anything - classes are `Callables` by default, so by my understanding binding to `Type` does not actually provide any more safety - the modified annotation still works correctly for both functions, properties, and classes. --------- Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com> Co-authored-by: Bagatur --- libs/community/langchain_community/cache.py | 2 +- libs/core/langchain_core/_api/beta_decorator.py | 14 ++++++++------ libs/core/langchain_core/_api/deprecation.py | 13 +++++++------ .../langchain_core/utils/function_calling.py | 2 +- .../tests/unit_tests/prompts/test_structured.py | 4 +++- .../graph_transformers/llm.py | 17 +++++++++++------ 6 files changed, 31 insertions(+), 21 deletions(-) diff --git a/libs/community/langchain_community/cache.py b/libs/community/langchain_community/cache.py index c2517270f03..113f029df79 100644 --- a/libs/community/langchain_community/cache.py +++ b/libs/community/langchain_community/cache.py @@ -403,7 +403,7 @@ class _RedisCacheBase(BaseCache, ABC): if results: for _, text in results.items(): try: - generations.append(loads(text)) + generations.append(loads(cast(str, text))) except Exception: logger.warning( "Retrieving a cache value that could not be deserialized " diff --git a/libs/core/langchain_core/_api/beta_decorator.py b/libs/core/langchain_core/_api/beta_decorator.py index 19f5db11df5..84c18c581e7 100644 --- a/libs/core/langchain_core/_api/beta_decorator.py +++ b/libs/core/langchain_core/_api/beta_decorator.py @@ -9,11 +9,12 @@ https://github.com/matplotlib/matplotlib/blob/main/lib/matplotlib/_api/deprecati This module is for internal use only. Do not use it in your own code. We may change the API at any time with no warning. """ + import contextlib import functools import inspect import warnings -from typing import Any, Callable, Generator, Type, TypeVar +from typing import Any, Callable, Generator, Type, TypeVar, Union, cast from langchain_core._api.internal import is_caller_internal @@ -25,7 +26,7 @@ class LangChainBetaWarning(DeprecationWarning): # PUBLIC API -T = TypeVar("T", Type, Callable) +T = TypeVar("T", bound=Union[Callable[..., Any], Type]) def beta( @@ -143,7 +144,7 @@ def beta( obj.__init__ = functools.wraps(obj.__init__)( # type: ignore[misc] warn_if_direct_instance ) - return obj + return cast(T, obj) elif isinstance(obj, property): if not _obj_type: @@ -202,7 +203,7 @@ def beta( """ wrapper = functools.wraps(wrapped)(wrapper) wrapper.__doc__ = new_doc - return wrapper + return cast(T, wrapper) old_doc = inspect.cleandoc(old_doc or "").strip("\n") @@ -225,9 +226,10 @@ def beta( ) if inspect.iscoroutinefunction(obj): - return finalize(awarning_emitting_wrapper, new_doc) + finalized = finalize(awarning_emitting_wrapper, new_doc) else: - return finalize(warning_emitting_wrapper, new_doc) + finalized = finalize(warning_emitting_wrapper, new_doc) + return cast(T, finalized) return beta diff --git a/libs/core/langchain_core/_api/deprecation.py b/libs/core/langchain_core/_api/deprecation.py index 52c4e9c8dfa..484d591f1ba 100644 --- a/libs/core/langchain_core/_api/deprecation.py +++ b/libs/core/langchain_core/_api/deprecation.py @@ -14,7 +14,7 @@ import contextlib import functools import inspect import warnings -from typing import Any, Callable, Generator, Type, TypeVar +from typing import Any, Callable, Generator, Type, TypeVar, Union, cast from langchain_core._api.internal import is_caller_internal @@ -30,7 +30,7 @@ class LangChainPendingDeprecationWarning(PendingDeprecationWarning): # PUBLIC API -T = TypeVar("T", Type, Callable) +T = TypeVar("T", bound=Union[Type, Callable[..., Any]]) def deprecated( @@ -182,7 +182,7 @@ def deprecated( obj.__init__ = functools.wraps(obj.__init__)( # type: ignore[misc] warn_if_direct_instance ) - return obj + return cast(T, obj) elif isinstance(obj, property): if not _obj_type: @@ -241,7 +241,7 @@ def deprecated( """ wrapper = functools.wraps(wrapped)(wrapper) wrapper.__doc__ = new_doc - return wrapper + return cast(T, wrapper) old_doc = inspect.cleandoc(old_doc or "").strip("\n") @@ -267,9 +267,10 @@ def deprecated( ) if inspect.iscoroutinefunction(obj): - return finalize(awarning_emitting_wrapper, new_doc) + finalized = finalize(awarning_emitting_wrapper, new_doc) else: - return finalize(warning_emitting_wrapper, new_doc) + finalized = finalize(warning_emitting_wrapper, new_doc) + return cast(T, finalized) return deprecate diff --git a/libs/core/langchain_core/utils/function_calling.py b/libs/core/langchain_core/utils/function_calling.py index a406b87097e..860259a93e3 100644 --- a/libs/core/langchain_core/utils/function_calling.py +++ b/libs/core/langchain_core/utils/function_calling.py @@ -308,7 +308,7 @@ def convert_to_openai_function( elif isinstance(function, type) and issubclass(function, BaseModel): return cast(Dict, convert_pydantic_to_openai_function(function)) elif isinstance(function, BaseTool): - return format_tool_to_openai_function(function) + return cast(Dict, format_tool_to_openai_function(function)) elif callable(function): return convert_python_function_to_openai_function(function) else: diff --git a/libs/core/tests/unit_tests/prompts/test_structured.py b/libs/core/tests/unit_tests/prompts/test_structured.py index a8a352be04d..92fba1249c9 100644 --- a/libs/core/tests/unit_tests/prompts/test_structured.py +++ b/libs/core/tests/unit_tests/prompts/test_structured.py @@ -23,7 +23,9 @@ def _fake_runnable( class FakeStructuredChatModel(FakeListChatModel): """Fake ChatModel for testing purposes.""" - def with_structured_output(self, schema: Union[Dict, Type[BaseModel]]) -> Runnable: + def with_structured_output( + self, schema: Union[Dict, Type[BaseModel]], **kwargs: Any + ) -> Runnable: return RunnableLambda(partial(_fake_runnable, schema)) @property diff --git a/libs/experimental/langchain_experimental/graph_transformers/llm.py b/libs/experimental/langchain_experimental/graph_transformers/llm.py index e671bc24754..6f281d29083 100644 --- a/libs/experimental/langchain_experimental/graph_transformers/llm.py +++ b/libs/experimental/langchain_experimental/graph_transformers/llm.py @@ -1,5 +1,5 @@ import asyncio -from typing import Any, List, Optional, Sequence +from typing import Any, List, Optional, Sequence, Type, cast from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship from langchain_core.documents import Document @@ -93,9 +93,14 @@ def optional_enum_field( return Field(..., description=description + additional_info, **field_kwargs) +class _Graph(BaseModel): + nodes: Optional[List] + relationships: Optional[List] + + def create_simple_model( node_labels: Optional[List[str]] = None, rel_types: Optional[List[str]] = None -) -> Any: +) -> Type[_Graph]: """ Simple model allows to limit node and/or relationship types. Doesn't have any node or relationship properties. @@ -128,7 +133,7 @@ def create_simple_model( rel_types, description="The type of the relationship.", is_rel=True ) - class DynamicGraph(BaseModel): + class DynamicGraph(_Graph): """Represents a graph document consisting of nodes and relationships.""" nodes: Optional[List[SimpleNode]] = Field(description="List of nodes") @@ -194,7 +199,7 @@ class LLMGraphTransformer: llm: BaseLanguageModel, allowed_nodes: List[str] = [], allowed_relationships: List[str] = [], - prompt: Optional[ChatPromptTemplate] = default_prompt, + prompt: ChatPromptTemplate = default_prompt, strict_mode: bool = True, ) -> None: if not hasattr(llm, "with_structured_output"): @@ -217,7 +222,7 @@ class LLMGraphTransformer: an LLM based on the model's schema and constraints. """ text = document.page_content - raw_schema = self.chain.invoke({"input": text}) + raw_schema = cast(_Graph, self.chain.invoke({"input": text})) nodes = ( [map_to_base_node(node) for node in raw_schema.nodes] if raw_schema.nodes @@ -268,7 +273,7 @@ class LLMGraphTransformer: graph document. """ text = document.page_content - raw_schema = await self.chain.ainvoke({"input": text}) + raw_schema = cast(_Graph, await self.chain.ainvoke({"input": text})) nodes = ( [map_to_base_node(node) for node in raw_schema.nodes]