core[patch]: fix beta, deprecated typing (#18877)

**Description:** 

While not technically incorrect, the TypeVar used for the `@beta`
decorator prevented pyright (and thus most vscode users) from correctly
seeing the types of functions/classes decorated with `@beta`.

This is in part due to a small bug in pyright
(https://github.com/microsoft/pyright/issues/7448 ) - however, the
`Type` bound in the typevar `C = TypeVar("C", Type, Callable)` is not
doing anything - classes are `Callables` by default, so by my
understanding binding to `Type` does not actually provide any more
safety - the modified annotation still works correctly for both
functions, properties, and classes.

---------

Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
Luca Dorigo 2024-03-28 23:33:43 +01:00 committed by GitHub
parent 263ee78886
commit f19229c564
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 31 additions and 21 deletions

View File

@ -403,7 +403,7 @@ class _RedisCacheBase(BaseCache, ABC):
if results: if results:
for _, text in results.items(): for _, text in results.items():
try: try:
generations.append(loads(text)) generations.append(loads(cast(str, text)))
except Exception: except Exception:
logger.warning( logger.warning(
"Retrieving a cache value that could not be deserialized " "Retrieving a cache value that could not be deserialized "

View File

@ -9,11 +9,12 @@ https://github.com/matplotlib/matplotlib/blob/main/lib/matplotlib/_api/deprecati
This module is for internal use only. Do not use it in your own code. This module is for internal use only. Do not use it in your own code.
We may change the API at any time with no warning. We may change the API at any time with no warning.
""" """
import contextlib import contextlib
import functools import functools
import inspect import inspect
import warnings import warnings
from typing import Any, Callable, Generator, Type, TypeVar from typing import Any, Callable, Generator, Type, TypeVar, Union, cast
from langchain_core._api.internal import is_caller_internal from langchain_core._api.internal import is_caller_internal
@ -25,7 +26,7 @@ class LangChainBetaWarning(DeprecationWarning):
# PUBLIC API # PUBLIC API
T = TypeVar("T", Type, Callable) T = TypeVar("T", bound=Union[Callable[..., Any], Type])
def beta( def beta(
@ -143,7 +144,7 @@ def beta(
obj.__init__ = functools.wraps(obj.__init__)( # type: ignore[misc] obj.__init__ = functools.wraps(obj.__init__)( # type: ignore[misc]
warn_if_direct_instance warn_if_direct_instance
) )
return obj return cast(T, obj)
elif isinstance(obj, property): elif isinstance(obj, property):
if not _obj_type: if not _obj_type:
@ -202,7 +203,7 @@ def beta(
""" """
wrapper = functools.wraps(wrapped)(wrapper) wrapper = functools.wraps(wrapped)(wrapper)
wrapper.__doc__ = new_doc wrapper.__doc__ = new_doc
return wrapper return cast(T, wrapper)
old_doc = inspect.cleandoc(old_doc or "").strip("\n") old_doc = inspect.cleandoc(old_doc or "").strip("\n")
@ -225,9 +226,10 @@ def beta(
) )
if inspect.iscoroutinefunction(obj): if inspect.iscoroutinefunction(obj):
return finalize(awarning_emitting_wrapper, new_doc) finalized = finalize(awarning_emitting_wrapper, new_doc)
else: else:
return finalize(warning_emitting_wrapper, new_doc) finalized = finalize(warning_emitting_wrapper, new_doc)
return cast(T, finalized)
return beta return beta

View File

@ -14,7 +14,7 @@ import contextlib
import functools import functools
import inspect import inspect
import warnings import warnings
from typing import Any, Callable, Generator, Type, TypeVar from typing import Any, Callable, Generator, Type, TypeVar, Union, cast
from langchain_core._api.internal import is_caller_internal from langchain_core._api.internal import is_caller_internal
@ -30,7 +30,7 @@ class LangChainPendingDeprecationWarning(PendingDeprecationWarning):
# PUBLIC API # PUBLIC API
T = TypeVar("T", Type, Callable) T = TypeVar("T", bound=Union[Type, Callable[..., Any]])
def deprecated( def deprecated(
@ -182,7 +182,7 @@ def deprecated(
obj.__init__ = functools.wraps(obj.__init__)( # type: ignore[misc] obj.__init__ = functools.wraps(obj.__init__)( # type: ignore[misc]
warn_if_direct_instance warn_if_direct_instance
) )
return obj return cast(T, obj)
elif isinstance(obj, property): elif isinstance(obj, property):
if not _obj_type: if not _obj_type:
@ -241,7 +241,7 @@ def deprecated(
""" """
wrapper = functools.wraps(wrapped)(wrapper) wrapper = functools.wraps(wrapped)(wrapper)
wrapper.__doc__ = new_doc wrapper.__doc__ = new_doc
return wrapper return cast(T, wrapper)
old_doc = inspect.cleandoc(old_doc or "").strip("\n") old_doc = inspect.cleandoc(old_doc or "").strip("\n")
@ -267,9 +267,10 @@ def deprecated(
) )
if inspect.iscoroutinefunction(obj): if inspect.iscoroutinefunction(obj):
return finalize(awarning_emitting_wrapper, new_doc) finalized = finalize(awarning_emitting_wrapper, new_doc)
else: else:
return finalize(warning_emitting_wrapper, new_doc) finalized = finalize(warning_emitting_wrapper, new_doc)
return cast(T, finalized)
return deprecate return deprecate

View File

@ -308,7 +308,7 @@ def convert_to_openai_function(
elif isinstance(function, type) and issubclass(function, BaseModel): elif isinstance(function, type) and issubclass(function, BaseModel):
return cast(Dict, convert_pydantic_to_openai_function(function)) return cast(Dict, convert_pydantic_to_openai_function(function))
elif isinstance(function, BaseTool): elif isinstance(function, BaseTool):
return format_tool_to_openai_function(function) return cast(Dict, format_tool_to_openai_function(function))
elif callable(function): elif callable(function):
return convert_python_function_to_openai_function(function) return convert_python_function_to_openai_function(function)
else: else:

View File

@ -23,7 +23,9 @@ def _fake_runnable(
class FakeStructuredChatModel(FakeListChatModel): class FakeStructuredChatModel(FakeListChatModel):
"""Fake ChatModel for testing purposes.""" """Fake ChatModel for testing purposes."""
def with_structured_output(self, schema: Union[Dict, Type[BaseModel]]) -> Runnable: def with_structured_output(
self, schema: Union[Dict, Type[BaseModel]], **kwargs: Any
) -> Runnable:
return RunnableLambda(partial(_fake_runnable, schema)) return RunnableLambda(partial(_fake_runnable, schema))
@property @property

View File

@ -1,5 +1,5 @@
import asyncio import asyncio
from typing import Any, List, Optional, Sequence from typing import Any, List, Optional, Sequence, Type, cast
from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship
from langchain_core.documents import Document from langchain_core.documents import Document
@ -93,9 +93,14 @@ def optional_enum_field(
return Field(..., description=description + additional_info, **field_kwargs) return Field(..., description=description + additional_info, **field_kwargs)
class _Graph(BaseModel):
nodes: Optional[List]
relationships: Optional[List]
def create_simple_model( def create_simple_model(
node_labels: Optional[List[str]] = None, rel_types: Optional[List[str]] = None node_labels: Optional[List[str]] = None, rel_types: Optional[List[str]] = None
) -> Any: ) -> Type[_Graph]:
""" """
Simple model allows to limit node and/or relationship types. Simple model allows to limit node and/or relationship types.
Doesn't have any node or relationship properties. Doesn't have any node or relationship properties.
@ -128,7 +133,7 @@ def create_simple_model(
rel_types, description="The type of the relationship.", is_rel=True rel_types, description="The type of the relationship.", is_rel=True
) )
class DynamicGraph(BaseModel): class DynamicGraph(_Graph):
"""Represents a graph document consisting of nodes and relationships.""" """Represents a graph document consisting of nodes and relationships."""
nodes: Optional[List[SimpleNode]] = Field(description="List of nodes") nodes: Optional[List[SimpleNode]] = Field(description="List of nodes")
@ -194,7 +199,7 @@ class LLMGraphTransformer:
llm: BaseLanguageModel, llm: BaseLanguageModel,
allowed_nodes: List[str] = [], allowed_nodes: List[str] = [],
allowed_relationships: List[str] = [], allowed_relationships: List[str] = [],
prompt: Optional[ChatPromptTemplate] = default_prompt, prompt: ChatPromptTemplate = default_prompt,
strict_mode: bool = True, strict_mode: bool = True,
) -> None: ) -> None:
if not hasattr(llm, "with_structured_output"): if not hasattr(llm, "with_structured_output"):
@ -217,7 +222,7 @@ class LLMGraphTransformer:
an LLM based on the model's schema and constraints. an LLM based on the model's schema and constraints.
""" """
text = document.page_content text = document.page_content
raw_schema = self.chain.invoke({"input": text}) raw_schema = cast(_Graph, self.chain.invoke({"input": text}))
nodes = ( nodes = (
[map_to_base_node(node) for node in raw_schema.nodes] [map_to_base_node(node) for node in raw_schema.nodes]
if raw_schema.nodes if raw_schema.nodes
@ -268,7 +273,7 @@ class LLMGraphTransformer:
graph document. graph document.
""" """
text = document.page_content text = document.page_content
raw_schema = await self.chain.ainvoke({"input": text}) raw_schema = cast(_Graph, await self.chain.ainvoke({"input": text}))
nodes = ( nodes = (
[map_to_base_node(node) for node in raw_schema.nodes] [map_to_base_node(node) for node in raw_schema.nodes]