mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-23 07:09:31 +00:00
core[patch]: fix beta, deprecated typing (#18877)
**Description:** While not technically incorrect, the TypeVar used for the `@beta` decorator prevented pyright (and thus most vscode users) from correctly seeing the types of functions/classes decorated with `@beta`. This is in part due to a small bug in pyright (https://github.com/microsoft/pyright/issues/7448 ) - however, the `Type` bound in the typevar `C = TypeVar("C", Type, Callable)` is not doing anything - classes are `Callables` by default, so by my understanding binding to `Type` does not actually provide any more safety - the modified annotation still works correctly for both functions, properties, and classes. --------- Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com> Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
263ee78886
commit
f19229c564
@ -403,7 +403,7 @@ class _RedisCacheBase(BaseCache, ABC):
|
||||
if results:
|
||||
for _, text in results.items():
|
||||
try:
|
||||
generations.append(loads(text))
|
||||
generations.append(loads(cast(str, text)))
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Retrieving a cache value that could not be deserialized "
|
||||
|
@ -9,11 +9,12 @@ https://github.com/matplotlib/matplotlib/blob/main/lib/matplotlib/_api/deprecati
|
||||
This module is for internal use only. Do not use it in your own code.
|
||||
We may change the API at any time with no warning.
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
import functools
|
||||
import inspect
|
||||
import warnings
|
||||
from typing import Any, Callable, Generator, Type, TypeVar
|
||||
from typing import Any, Callable, Generator, Type, TypeVar, Union, cast
|
||||
|
||||
from langchain_core._api.internal import is_caller_internal
|
||||
|
||||
@ -25,7 +26,7 @@ class LangChainBetaWarning(DeprecationWarning):
|
||||
# PUBLIC API
|
||||
|
||||
|
||||
T = TypeVar("T", Type, Callable)
|
||||
T = TypeVar("T", bound=Union[Callable[..., Any], Type])
|
||||
|
||||
|
||||
def beta(
|
||||
@ -143,7 +144,7 @@ def beta(
|
||||
obj.__init__ = functools.wraps(obj.__init__)( # type: ignore[misc]
|
||||
warn_if_direct_instance
|
||||
)
|
||||
return obj
|
||||
return cast(T, obj)
|
||||
|
||||
elif isinstance(obj, property):
|
||||
if not _obj_type:
|
||||
@ -202,7 +203,7 @@ def beta(
|
||||
"""
|
||||
wrapper = functools.wraps(wrapped)(wrapper)
|
||||
wrapper.__doc__ = new_doc
|
||||
return wrapper
|
||||
return cast(T, wrapper)
|
||||
|
||||
old_doc = inspect.cleandoc(old_doc or "").strip("\n")
|
||||
|
||||
@ -225,9 +226,10 @@ def beta(
|
||||
)
|
||||
|
||||
if inspect.iscoroutinefunction(obj):
|
||||
return finalize(awarning_emitting_wrapper, new_doc)
|
||||
finalized = finalize(awarning_emitting_wrapper, new_doc)
|
||||
else:
|
||||
return finalize(warning_emitting_wrapper, new_doc)
|
||||
finalized = finalize(warning_emitting_wrapper, new_doc)
|
||||
return cast(T, finalized)
|
||||
|
||||
return beta
|
||||
|
||||
|
@ -14,7 +14,7 @@ import contextlib
|
||||
import functools
|
||||
import inspect
|
||||
import warnings
|
||||
from typing import Any, Callable, Generator, Type, TypeVar
|
||||
from typing import Any, Callable, Generator, Type, TypeVar, Union, cast
|
||||
|
||||
from langchain_core._api.internal import is_caller_internal
|
||||
|
||||
@ -30,7 +30,7 @@ class LangChainPendingDeprecationWarning(PendingDeprecationWarning):
|
||||
# PUBLIC API
|
||||
|
||||
|
||||
T = TypeVar("T", Type, Callable)
|
||||
T = TypeVar("T", bound=Union[Type, Callable[..., Any]])
|
||||
|
||||
|
||||
def deprecated(
|
||||
@ -182,7 +182,7 @@ def deprecated(
|
||||
obj.__init__ = functools.wraps(obj.__init__)( # type: ignore[misc]
|
||||
warn_if_direct_instance
|
||||
)
|
||||
return obj
|
||||
return cast(T, obj)
|
||||
|
||||
elif isinstance(obj, property):
|
||||
if not _obj_type:
|
||||
@ -241,7 +241,7 @@ def deprecated(
|
||||
"""
|
||||
wrapper = functools.wraps(wrapped)(wrapper)
|
||||
wrapper.__doc__ = new_doc
|
||||
return wrapper
|
||||
return cast(T, wrapper)
|
||||
|
||||
old_doc = inspect.cleandoc(old_doc or "").strip("\n")
|
||||
|
||||
@ -267,9 +267,10 @@ def deprecated(
|
||||
)
|
||||
|
||||
if inspect.iscoroutinefunction(obj):
|
||||
return finalize(awarning_emitting_wrapper, new_doc)
|
||||
finalized = finalize(awarning_emitting_wrapper, new_doc)
|
||||
else:
|
||||
return finalize(warning_emitting_wrapper, new_doc)
|
||||
finalized = finalize(warning_emitting_wrapper, new_doc)
|
||||
return cast(T, finalized)
|
||||
|
||||
return deprecate
|
||||
|
||||
|
@ -308,7 +308,7 @@ def convert_to_openai_function(
|
||||
elif isinstance(function, type) and issubclass(function, BaseModel):
|
||||
return cast(Dict, convert_pydantic_to_openai_function(function))
|
||||
elif isinstance(function, BaseTool):
|
||||
return format_tool_to_openai_function(function)
|
||||
return cast(Dict, format_tool_to_openai_function(function))
|
||||
elif callable(function):
|
||||
return convert_python_function_to_openai_function(function)
|
||||
else:
|
||||
|
@ -23,7 +23,9 @@ def _fake_runnable(
|
||||
class FakeStructuredChatModel(FakeListChatModel):
|
||||
"""Fake ChatModel for testing purposes."""
|
||||
|
||||
def with_structured_output(self, schema: Union[Dict, Type[BaseModel]]) -> Runnable:
|
||||
def with_structured_output(
|
||||
self, schema: Union[Dict, Type[BaseModel]], **kwargs: Any
|
||||
) -> Runnable:
|
||||
return RunnableLambda(partial(_fake_runnable, schema))
|
||||
|
||||
@property
|
||||
|
@ -1,5 +1,5 @@
|
||||
import asyncio
|
||||
from typing import Any, List, Optional, Sequence
|
||||
from typing import Any, List, Optional, Sequence, Type, cast
|
||||
|
||||
from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship
|
||||
from langchain_core.documents import Document
|
||||
@ -93,9 +93,14 @@ def optional_enum_field(
|
||||
return Field(..., description=description + additional_info, **field_kwargs)
|
||||
|
||||
|
||||
class _Graph(BaseModel):
|
||||
nodes: Optional[List]
|
||||
relationships: Optional[List]
|
||||
|
||||
|
||||
def create_simple_model(
|
||||
node_labels: Optional[List[str]] = None, rel_types: Optional[List[str]] = None
|
||||
) -> Any:
|
||||
) -> Type[_Graph]:
|
||||
"""
|
||||
Simple model allows to limit node and/or relationship types.
|
||||
Doesn't have any node or relationship properties.
|
||||
@ -128,7 +133,7 @@ def create_simple_model(
|
||||
rel_types, description="The type of the relationship.", is_rel=True
|
||||
)
|
||||
|
||||
class DynamicGraph(BaseModel):
|
||||
class DynamicGraph(_Graph):
|
||||
"""Represents a graph document consisting of nodes and relationships."""
|
||||
|
||||
nodes: Optional[List[SimpleNode]] = Field(description="List of nodes")
|
||||
@ -194,7 +199,7 @@ class LLMGraphTransformer:
|
||||
llm: BaseLanguageModel,
|
||||
allowed_nodes: List[str] = [],
|
||||
allowed_relationships: List[str] = [],
|
||||
prompt: Optional[ChatPromptTemplate] = default_prompt,
|
||||
prompt: ChatPromptTemplate = default_prompt,
|
||||
strict_mode: bool = True,
|
||||
) -> None:
|
||||
if not hasattr(llm, "with_structured_output"):
|
||||
@ -217,7 +222,7 @@ class LLMGraphTransformer:
|
||||
an LLM based on the model's schema and constraints.
|
||||
"""
|
||||
text = document.page_content
|
||||
raw_schema = self.chain.invoke({"input": text})
|
||||
raw_schema = cast(_Graph, self.chain.invoke({"input": text}))
|
||||
nodes = (
|
||||
[map_to_base_node(node) for node in raw_schema.nodes]
|
||||
if raw_schema.nodes
|
||||
@ -268,7 +273,7 @@ class LLMGraphTransformer:
|
||||
graph document.
|
||||
"""
|
||||
text = document.page_content
|
||||
raw_schema = await self.chain.ainvoke({"input": text})
|
||||
raw_schema = cast(_Graph, await self.chain.ainvoke({"input": text}))
|
||||
|
||||
nodes = (
|
||||
[map_to_base_node(node) for node in raw_schema.nodes]
|
||||
|
Loading…
Reference in New Issue
Block a user