diff --git a/libs/experimental/langchain_experimental/graph_transformers/llm.py b/libs/experimental/langchain_experimental/graph_transformers/llm.py index 6f281d29083..12116315f2e 100644 --- a/libs/experimental/langchain_experimental/graph_transformers/llm.py +++ b/libs/experimental/langchain_experimental/graph_transformers/llm.py @@ -1,5 +1,6 @@ import asyncio -from typing import Any, List, Optional, Sequence, Type, cast +import json +from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, cast from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship from langchain_core.documents import Document @@ -146,16 +147,133 @@ def create_simple_model( def map_to_base_node(node: Any) -> Node: """Map the SimpleNode to the base Node.""" - return Node(id=node.id.title(), type=node.type.capitalize()) + return Node(id=node.id, type=node.type) def map_to_base_relationship(rel: Any) -> Relationship: """Map the SimpleRelationship to the base Relationship.""" - source = Node(id=rel.source_node_id.title(), type=rel.source_node_type.capitalize()) - target = Node(id=rel.target_node_id.title(), type=rel.target_node_type.capitalize()) - return Relationship( - source=source, target=target, type=rel.type.replace(" ", "_").upper() - ) + source = Node(id=rel.source_node_id, type=rel.source_node_type) + target = Node(id=rel.target_node_id, type=rel.target_node_type) + return Relationship(source=source, target=target, type=rel.type) + + +def _parse_and_clean_json( + argument_json: Dict[str, Any], +) -> Tuple[List[Node], List[Relationship]]: + nodes = [] + for node in argument_json["nodes"]: + if not node.get("id"): # Id is mandatory, skip this node + continue + nodes.append( + Node( + id=node["id"], + type=node.get("type"), + ) + ) + relationships = [] + for rel in argument_json["relationships"]: + # Mandatory props + if ( + not rel.get("source_node_id") + or not rel.get("target_node_id") + or not rel.get("type") + ): + continue + + # Node type copying if needed from node list + if not rel.get("source_node_type"): + try: + rel["source_node_type"] = [ + el.get("type") + for el in argument_json["nodes"] + if el["id"] == rel["source_node_id"] + ][0] + except IndexError: + rel["source_node_type"] = None + if not rel.get("target_node_type"): + try: + rel["target_node_type"] = [ + el.get("type") + for el in argument_json["nodes"] + if el["id"] == rel["target_node_id"] + ][0] + except IndexError: + rel["target_node_type"] = None + + source_node = Node( + id=rel["source_node_id"], + type=rel["source_node_type"], + ) + target_node = Node( + id=rel["target_node_id"], + type=rel["target_node_type"], + ) + relationships.append( + Relationship( + source=source_node, + target=target_node, + type=rel["type"], + ) + ) + return nodes, relationships + + +def _format_nodes(nodes: List[Node]) -> List[Node]: + return [ + Node( + id=el.id.title() if isinstance(el.id, str) else el.id, + type=el.type.capitalize(), + ) + for el in nodes + ] + + +def _format_relationships(rels: List[Relationship]) -> List[Relationship]: + return [ + Relationship( + source=_format_nodes([el.source])[0], + target=_format_nodes([el.target])[0], + type=el.type.replace(" ", "_").upper(), + ) + for el in rels + ] + + +def _convert_to_graph_document( + raw_schema: Dict[Any, Any], +) -> Tuple[List[Node], List[Relationship]]: + # If there are validation errors + if not raw_schema["parsed"]: + try: + try: # OpenAI type response + argument_json = json.loads( + raw_schema["raw"].additional_kwargs["tool_calls"][0]["function"][ + "arguments" + ] + ) + except Exception: # Google type response + argument_json = json.loads( + raw_schema["raw"].additional_kwargs["function_call"]["arguments"] + ) + + nodes, relationships = _parse_and_clean_json(argument_json) + except Exception: # If we can't parse JSON + return ([], []) + else: # If there are no validation errors use parsed pydantic object + parsed_schema: _Graph = raw_schema["parsed"] + nodes = ( + [map_to_base_node(node) for node in parsed_schema.nodes] + if parsed_schema.nodes + else [] + ) + + relationships = ( + [map_to_base_relationship(rel) for rel in parsed_schema.relationships] + if parsed_schema.relationships + else [] + ) + # Title / Capitalize + return _format_nodes(nodes), _format_relationships(relationships) class LLMGraphTransformer: @@ -213,7 +331,7 @@ class LLMGraphTransformer: # Define chain schema = create_simple_model(allowed_nodes, allowed_relationships) - structured_llm = llm.with_structured_output(schema) + structured_llm = llm.with_structured_output(schema, include_raw=True) self.chain = prompt | structured_llm def process_response(self, document: Document) -> GraphDocument: @@ -222,33 +340,29 @@ class LLMGraphTransformer: an LLM based on the model's schema and constraints. """ text = document.page_content - raw_schema = cast(_Graph, self.chain.invoke({"input": text})) - nodes = ( - [map_to_base_node(node) for node in raw_schema.nodes] - if raw_schema.nodes - else [] - ) - relationships = ( - [map_to_base_relationship(rel) for rel in raw_schema.relationships] - if raw_schema.relationships - else [] - ) + raw_schema = self.chain.invoke({"input": text}) + raw_schema = cast(Dict[Any, Any], raw_schema) + nodes, relationships = _convert_to_graph_document(raw_schema) # Strict mode filtering if self.strict_mode and (self.allowed_nodes or self.allowed_relationships): if self.allowed_nodes: - nodes = [node for node in nodes if node.type in self.allowed_nodes] + lower_allowed_nodes = [el.lower() for el in self.allowed_nodes] + nodes = [ + node for node in nodes if node.type.lower() in lower_allowed_nodes + ] relationships = [ rel for rel in relationships - if rel.source.type in self.allowed_nodes - and rel.target.type in self.allowed_nodes + if rel.source.type.lower() in lower_allowed_nodes + and rel.target.type.lower() in lower_allowed_nodes ] if self.allowed_relationships: relationships = [ rel for rel in relationships - if rel.type in self.allowed_relationships + if rel.type.lower() + in [el.lower() for el in self.allowed_relationships] ] return GraphDocument(nodes=nodes, relationships=relationships, source=document) @@ -273,33 +387,28 @@ class LLMGraphTransformer: graph document. """ text = document.page_content - raw_schema = cast(_Graph, await self.chain.ainvoke({"input": text})) - - nodes = ( - [map_to_base_node(node) for node in raw_schema.nodes] - if raw_schema.nodes - else [] - ) - relationships = ( - [map_to_base_relationship(rel) for rel in raw_schema.relationships] - if raw_schema.relationships - else [] - ) + raw_schema = await self.chain.ainvoke({"input": text}) + raw_schema = cast(Dict[Any, Any], raw_schema) + nodes, relationships = _convert_to_graph_document(raw_schema) if self.strict_mode and (self.allowed_nodes or self.allowed_relationships): if self.allowed_nodes: - nodes = [node for node in nodes if node.type in self.allowed_nodes] + lower_allowed_nodes = [el.lower() for el in self.allowed_nodes] + nodes = [ + node for node in nodes if node.type.lower() in lower_allowed_nodes + ] relationships = [ rel for rel in relationships - if rel.source.type in self.allowed_nodes - and rel.target.type in self.allowed_nodes + if rel.source.type.lower() in lower_allowed_nodes + and rel.target.type.lower() in lower_allowed_nodes ] if self.allowed_relationships: relationships = [ rel for rel in relationships - if rel.type in self.allowed_relationships + if rel.type.lower() + in [el.lower() for el in self.allowed_relationships] ] return GraphDocument(nodes=nodes, relationships=relationships, source=document)