diff --git a/libs/experimental/langchain_experimental/graph_transformers/llm.py b/libs/experimental/langchain_experimental/graph_transformers/llm.py index 34ff03a0099..34c069b772e 100644 --- a/libs/experimental/langchain_experimental/graph_transformers/llm.py +++ b/libs/experimental/langchain_experimental/graph_transformers/llm.py @@ -1,3 +1,4 @@ +import asyncio from typing import Any, List, Optional, Sequence from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship @@ -207,29 +208,20 @@ class LLMGraphTransformer: """ text = document.page_content raw_schema = self.chain.invoke({"input": text}) - if raw_schema.nodes: - nodes = [map_to_base_node(node) for node in raw_schema.nodes] - else: - nodes = [] - if raw_schema.relationships: - relationships = [ - map_to_base_relationship(rel) for rel in raw_schema.relationships - ] - else: - relationships = [] + nodes = ( + [map_to_base_node(node) for node in raw_schema.nodes] + if raw_schema.nodes + else [] + ) + relationships = ( + [map_to_base_relationship(rel) for rel in raw_schema.relationships] + if raw_schema.relationships + else [] + ) # Strict mode filtering if self.strict_mode and (self.allowed_nodes or self.allowed_relationships): - if self.allowed_relationships and self.allowed_nodes: - nodes = [node for node in nodes if node.type in self.allowed_nodes] - relationships = [ - rel - for rel in relationships - if rel.type in self.allowed_relationships - and rel.source.type in self.allowed_nodes - and rel.target.type in self.allowed_nodes - ] - elif self.allowed_nodes and not self.allowed_relationships: + if self.allowed_nodes: nodes = [node for node in nodes if node.type in self.allowed_nodes] relationships = [ rel @@ -237,17 +229,14 @@ class LLMGraphTransformer: if rel.source.type in self.allowed_nodes and rel.target.type in self.allowed_nodes ] - if self.allowed_relationships and not self.allowed_nodes: + if self.allowed_relationships: relationships = [ rel for rel in relationships if rel.type in self.allowed_relationships ] - graph_document = GraphDocument( - nodes=nodes, relationships=relationships, source=document - ) - return graph_document + return GraphDocument(nodes=nodes, relationships=relationships, source=document) def convert_to_graph_documents( self, documents: Sequence[Document] @@ -261,8 +250,54 @@ class LLMGraphTransformer: Returns: Sequence[GraphDocument]: The transformed documents as graphs. """ - results = [] - for document in documents: - graph_document = self.process_response(document) - results.append(graph_document) + return [self.process_response(document) for document in documents] + + async def aprocess_response(self, document: Document) -> GraphDocument: + """ + Asynchronously processes a single document, transforming it into a + graph document. + """ + text = document.page_content + raw_schema = await self.chain.ainvoke({"input": text}) + + nodes = ( + [map_to_base_node(node) for node in raw_schema.nodes] + if raw_schema.nodes + else [] + ) + relationships = ( + [map_to_base_relationship(rel) for rel in raw_schema.relationships] + if raw_schema.relationships + else [] + ) + + if self.strict_mode and (self.allowed_nodes or self.allowed_relationships): + if self.allowed_nodes: + nodes = [node for node in nodes if node.type in self.allowed_nodes] + relationships = [ + rel + for rel in relationships + if rel.source.type in self.allowed_nodes + and rel.target.type in self.allowed_nodes + ] + if self.allowed_relationships: + relationships = [ + rel + for rel in relationships + if rel.type in self.allowed_relationships + ] + + return GraphDocument(nodes=nodes, relationships=relationships, source=document) + + async def aconvert_to_graph_documents( + self, documents: Sequence[Document] + ) -> List[GraphDocument]: + """ + Asynchronously convert a sequence of documents into graph documents. + """ + tasks = [ + asyncio.create_task(self.aprocess_response(document)) + for document in documents + ] + results = await asyncio.gather(*tasks) return results