core[patch]: fix beta, deprecated typing (#18877)

**Description:** 

While not technically incorrect, the TypeVar used for the `@beta`
decorator prevented pyright (and thus most vscode users) from correctly
seeing the types of functions/classes decorated with `@beta`.

This is in part due to a small bug in pyright
(https://github.com/microsoft/pyright/issues/7448 ) - however, the
`Type` bound in the typevar `C = TypeVar("C", Type, Callable)` is not
doing anything - classes are `Callables` by default, so by my
understanding binding to `Type` does not actually provide any more
safety - the modified annotation still works correctly for both
functions, properties, and classes.

---------

Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
Luca Dorigo
2024-03-28 23:33:43 +01:00
committed by GitHub
parent 263ee78886
commit f19229c564
6 changed files with 31 additions and 21 deletions

View File

@@ -1,5 +1,5 @@
import asyncio
from typing import Any, List, Optional, Sequence
from typing import Any, List, Optional, Sequence, Type, cast
from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship
from langchain_core.documents import Document
@@ -93,9 +93,14 @@ def optional_enum_field(
return Field(..., description=description + additional_info, **field_kwargs)
class _Graph(BaseModel):
nodes: Optional[List]
relationships: Optional[List]
def create_simple_model(
node_labels: Optional[List[str]] = None, rel_types: Optional[List[str]] = None
) -> Any:
) -> Type[_Graph]:
"""
Simple model allows to limit node and/or relationship types.
Doesn't have any node or relationship properties.
@@ -128,7 +133,7 @@ def create_simple_model(
rel_types, description="The type of the relationship.", is_rel=True
)
class DynamicGraph(BaseModel):
class DynamicGraph(_Graph):
"""Represents a graph document consisting of nodes and relationships."""
nodes: Optional[List[SimpleNode]] = Field(description="List of nodes")
@@ -194,7 +199,7 @@ class LLMGraphTransformer:
llm: BaseLanguageModel,
allowed_nodes: List[str] = [],
allowed_relationships: List[str] = [],
prompt: Optional[ChatPromptTemplate] = default_prompt,
prompt: ChatPromptTemplate = default_prompt,
strict_mode: bool = True,
) -> None:
if not hasattr(llm, "with_structured_output"):
@@ -217,7 +222,7 @@ class LLMGraphTransformer:
an LLM based on the model's schema and constraints.
"""
text = document.page_content
raw_schema = self.chain.invoke({"input": text})
raw_schema = cast(_Graph, self.chain.invoke({"input": text}))
nodes = (
[map_to_base_node(node) for node in raw_schema.nodes]
if raw_schema.nodes
@@ -268,7 +273,7 @@ class LLMGraphTransformer:
graph document.
"""
text = document.page_content
raw_schema = await self.chain.ainvoke({"input": text})
raw_schema = cast(_Graph, await self.chain.ainvoke({"input": text}))
nodes = (
[map_to_base_node(node) for node in raw_schema.nodes]