Add simple node properties to llm graph transformer (#21369)

Add support for simple node properties in the LLM graph transformer.

The linter and dynamically created pydantic classes don't get along,
hence I added two `# type: ignore` comments.
This commit is contained in:
Tomaz Bratanic 2024-05-07 17:41:09 +02:00 committed by GitHub
parent 080af0ec53
commit 0bf7596839
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,6 +1,6 @@
import asyncio import asyncio
import json import json
from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, cast from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union, cast
from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship
from langchain_core.documents import Document from langchain_core.documents import Document
@ -12,7 +12,7 @@ from langchain_core.prompts import (
HumanMessagePromptTemplate, HumanMessagePromptTemplate,
PromptTemplate, PromptTemplate,
) )
from langchain_core.pydantic_v1 import BaseModel, Field from langchain_core.pydantic_v1 import BaseModel, Field, create_model
examples = [ examples = [
{ {
@ -122,10 +122,34 @@ default_prompt = ChatPromptTemplate.from_messages(
) )
def _get_additional_info(input_type: str) -> str:
# Check if the input_type is one of the allowed values
if input_type not in ["node", "relationship", "property"]:
raise ValueError("input_type must be 'node', 'relationship', or 'property'")
# Perform actions based on the input_type
if input_type == "node":
return (
"Ensure you use basic or elementary types for node labels.\n"
"For example, when you identify an entity representing a person, "
"always label it as **'Person'**. Avoid using more specific terms "
"like 'Mathematician' or 'Scientist'"
)
elif input_type == "relationship":
return (
"Instead of using specific and momentary types such as "
"'BECAME_PROFESSOR', use more general and timeless relationship types like "
"'PROFESSOR'. However, do not sacrifice any accuracy for generality"
)
elif input_type == "property":
return ""
return ""
def optional_enum_field( def optional_enum_field(
enum_values: Optional[List[str]] = None, enum_values: Optional[List[str]] = None,
description: str = "", description: str = "",
is_rel: bool = False, input_type: str = "node",
**field_kwargs: Any, **field_kwargs: Any,
) -> Any: ) -> Any:
"""Utility function to conditionally create a field with an enum constraint.""" """Utility function to conditionally create a field with an enum constraint."""
@ -137,18 +161,7 @@ def optional_enum_field(
**field_kwargs, **field_kwargs,
) )
else: else:
node_info = ( additional_info = _get_additional_info(input_type)
"Ensure you use basic or elementary types for node labels.\n"
"For example, when you identify an entity representing a person, "
"always label it as **'Person'**. Avoid using more specific terms "
"like 'Mathematician' or 'Scientist'"
)
rel_info = (
"Instead of using specific and momentary types such as "
"'BECAME_PROFESSOR', use more general and timeless relationship types like "
"'PROFESSOR'. However, do not sacrifice any accuracy for generality"
)
additional_info = rel_info if is_rel else node_info
return Field(..., description=description + additional_info, **field_kwargs) return Field(..., description=description + additional_info, **field_kwargs)
@ -255,21 +268,53 @@ For the following text, extract entities and relations as in the provided exampl
def create_simple_model( def create_simple_model(
node_labels: Optional[List[str]] = None, rel_types: Optional[List[str]] = None node_labels: Optional[List[str]] = None,
rel_types: Optional[List[str]] = None,
node_properties: Union[bool, List[str]] = False,
) -> Type[_Graph]: ) -> Type[_Graph]:
""" """
Simple model allows to limit node and/or relationship types. Simple model allows to limit node and/or relationship types.
Doesn't have any node or relationship properties. Doesn't have any node or relationship properties.
""" """
class SimpleNode(BaseModel): node_fields: Dict[str, Tuple[Any, Any]] = {
"""Represents a node in a graph with associated properties.""" "id": (
str,
id: str = Field(description="Name or human-readable unique identifier.") Field(..., description="Name or human-readable unique identifier."),
type: str = optional_enum_field( ),
node_labels, description="The type or label of the node." "type": (
str,
optional_enum_field(
node_labels,
description="The type or label of the node.",
input_type="node",
),
),
}
if node_properties:
if isinstance(node_properties, list) and "id" in node_properties:
raise ValueError("The node property 'id' is reserved and cannot be used.")
# Map True to empty array
node_properties_mapped: List[str] = (
[] if node_properties is True else node_properties
) )
class Property(BaseModel):
"""A single property consisting of key and value"""
key: str = optional_enum_field(
node_properties_mapped,
description="Property key.",
input_type="property",
)
value: str = Field(..., description="value")
node_fields["properties"] = (
Optional[List[Property]],
Field(None, description="List of node properties"),
)
SimpleNode = create_model("SimpleNode", **node_fields) # type: ignore
class SimpleRelationship(BaseModel): class SimpleRelationship(BaseModel):
"""Represents a directed relationship between two nodes in a graph.""" """Represents a directed relationship between two nodes in a graph."""
@ -277,22 +322,28 @@ def create_simple_model(
description="Name or human-readable unique identifier of source node" description="Name or human-readable unique identifier of source node"
) )
source_node_type: str = optional_enum_field( source_node_type: str = optional_enum_field(
node_labels, description="The type or label of the source node." node_labels,
description="The type or label of the source node.",
input_type="node",
) )
target_node_id: str = Field( target_node_id: str = Field(
description="Name or human-readable unique identifier of target node" description="Name or human-readable unique identifier of target node"
) )
target_node_type: str = optional_enum_field( target_node_type: str = optional_enum_field(
node_labels, description="The type or label of the target node." node_labels,
description="The type or label of the target node.",
input_type="node",
) )
type: str = optional_enum_field( type: str = optional_enum_field(
rel_types, description="The type of the relationship.", is_rel=True rel_types,
description="The type of the relationship.",
input_type="relationship",
) )
class DynamicGraph(_Graph): class DynamicGraph(_Graph):
"""Represents a graph document consisting of nodes and relationships.""" """Represents a graph document consisting of nodes and relationships."""
nodes: Optional[List[SimpleNode]] = Field(description="List of nodes") nodes: Optional[List[SimpleNode]] = Field(description="List of nodes") # type: ignore
relationships: Optional[List[SimpleRelationship]] = Field( relationships: Optional[List[SimpleRelationship]] = Field(
description="List of relationships" description="List of relationships"
) )
@ -302,7 +353,11 @@ def create_simple_model(
def map_to_base_node(node: Any) -> Node:
    """Map the SimpleNode to the base Node.

    Entries found on ``node.properties`` (objects exposing ``key`` and
    ``value``) are copied into the base node's property dict, with every key
    normalized through ``format_property_key``. A node that has no
    ``properties`` attribute, or whose ``properties`` is empty or ``None``,
    produces an empty dict.
    """
    # The attribute only exists when the dynamic model was built with node
    # properties, hence the defaulted lookup.
    extracted = getattr(node, "properties", None) or []
    props = {format_property_key(item.key): item.value for item in extracted}
    return Node(id=node.id, type=node.type, properties=props)
def map_to_base_relationship(rel: Any) -> Relationship: def map_to_base_relationship(rel: Any) -> Relationship:
@ -378,6 +433,7 @@ def _format_nodes(nodes: List[Node]) -> List[Node]:
Node( Node(
id=el.id.title() if isinstance(el.id, str) else el.id, id=el.id.title() if isinstance(el.id, str) else el.id,
type=el.type.capitalize(), type=el.type.capitalize(),
properties=el.properties,
) )
for el in nodes for el in nodes
] ]
@ -394,6 +450,15 @@ def _format_relationships(rels: List[Relationship]) -> List[Relationship]:
] ]
def format_property_key(s: str) -> str:
    """Convert a whitespace-separated property key to camelCase.

    The first word is lower-cased, every following word is capitalized
    (first character upper-case, the rest lower-case), and the words are
    joined without a separator. Input that contains no words (empty or
    whitespace-only) is returned unchanged.
    """
    tokens = s.split()
    if tokens:
        head, *tail = tokens
        return head.lower() + "".join(map(str.capitalize, tail))
    return s
def _convert_to_graph_document( def _convert_to_graph_document(
raw_schema: Dict[Any, Any], raw_schema: Dict[Any, Any],
) -> Tuple[List[Node], List[Relationship]]: ) -> Tuple[List[Node], List[Relationship]]:
@ -474,6 +539,7 @@ class LLMGraphTransformer:
allowed_relationships: List[str] = [], allowed_relationships: List[str] = [],
prompt: Optional[ChatPromptTemplate] = None, prompt: Optional[ChatPromptTemplate] = None,
strict_mode: bool = True, strict_mode: bool = True,
node_properties: Union[bool, List[str]] = False,
) -> None: ) -> None:
self.allowed_nodes = allowed_nodes self.allowed_nodes = allowed_nodes
self.allowed_relationships = allowed_relationships self.allowed_relationships = allowed_relationships
@ -485,6 +551,12 @@ class LLMGraphTransformer:
except NotImplementedError: except NotImplementedError:
self._function_call = False self._function_call = False
if not self._function_call: if not self._function_call:
if node_properties:
raise ValueError(
"The 'node_properties' parameter cannot be used "
"in combination with a LLM that doesn't support "
"native function calling."
)
try: try:
import json_repair import json_repair
@ -500,7 +572,9 @@ class LLMGraphTransformer:
self.chain = prompt | llm self.chain = prompt | llm
else: else:
# Define chain # Define chain
schema = create_simple_model(allowed_nodes, allowed_relationships) schema = create_simple_model(
allowed_nodes, allowed_relationships, node_properties
)
structured_llm = llm.with_structured_output(schema, include_raw=True) structured_llm = llm.with_structured_output(schema, include_raw=True)
prompt = prompt or default_prompt prompt = prompt or default_prompt
self.chain = prompt | structured_llm self.chain = prompt | structured_llm