diff --git a/libs/experimental/langchain_experimental/graph_transformers/llm.py b/libs/experimental/langchain_experimental/graph_transformers/llm.py
index 5b432a94bfc..f9f6fc87926 100644
--- a/libs/experimental/langchain_experimental/graph_transformers/llm.py
+++ b/libs/experimental/langchain_experimental/graph_transformers/llm.py
@@ -13,6 +13,7 @@ from langchain_core.prompts import (
     PromptTemplate,
 )
 from langchain_core.pydantic_v1 import BaseModel, Field, create_model
+from langchain_core.runnables import RunnableConfig
 
 examples = [
     {
@@ -710,13 +711,15 @@ class LLMGraphTransformer:
         prompt = prompt or default_prompt
         self.chain = prompt | structured_llm
 
-    def process_response(self, document: Document) -> GraphDocument:
+    def process_response(
+        self, document: Document, config: Optional[RunnableConfig] = None
+    ) -> GraphDocument:
         """
         Processes a single document, transforming it into a graph document using
         an LLM based on the model's schema and constraints.
         """
         text = document.page_content
-        raw_schema = self.chain.invoke({"input": text})
+        raw_schema = self.chain.invoke({"input": text}, config=config)
         if self._function_call:
             raw_schema = cast(Dict[Any, Any], raw_schema)
             nodes, relationships = _convert_to_graph_document(raw_schema)
@@ -765,7 +768,7 @@ class LLMGraphTransformer:
         return GraphDocument(nodes=nodes, relationships=relationships, source=document)
 
     def convert_to_graph_documents(
-        self, documents: Sequence[Document]
+        self, documents: Sequence[Document], config: Optional[RunnableConfig] = None
     ) -> List[GraphDocument]:
         """Convert a sequence of documents into graph documents.
 
@@ -776,15 +779,17 @@ class LLMGraphTransformer:
         Returns:
             Sequence[GraphDocument]: The transformed documents as graphs.
         """
-        return [self.process_response(document) for document in documents]
+        return [self.process_response(document, config) for document in documents]
 
-    async def aprocess_response(self, document: Document) -> GraphDocument:
+    async def aprocess_response(
+        self, document: Document, config: Optional[RunnableConfig] = None
+    ) -> GraphDocument:
         """
         Asynchronously processes a single document, transforming it
         into a graph document.
         """
         text = document.page_content
-        raw_schema = await self.chain.ainvoke({"input": text})
+        raw_schema = await self.chain.ainvoke({"input": text}, config=config)
         raw_schema = cast(Dict[Any, Any], raw_schema)
         nodes, relationships = _convert_to_graph_document(raw_schema)
 
@@ -811,13 +816,13 @@ class LLMGraphTransformer:
         return GraphDocument(nodes=nodes, relationships=relationships, source=document)
 
     async def aconvert_to_graph_documents(
-        self, documents: Sequence[Document]
+        self, documents: Sequence[Document], config: Optional[RunnableConfig] = None
    ) -> List[GraphDocument]:
         """
         Asynchronously convert a sequence of documents into graph documents.
         """
         tasks = [
-            asyncio.create_task(self.aprocess_response(document))
+            asyncio.create_task(self.aprocess_response(document, config))
             for document in documents
         ]
         results = await asyncio.gather(*tasks)
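
For reviewers, a minimal usage sketch (not part of the diff) of how a `RunnableConfig` can now be threaded through `convert_to_graph_documents`; the `ChatOpenAI` model choice, tags, metadata, and sample document below are placeholder assumptions, not part of this change:

```python
# Usage sketch only — model, tags, and sample text are illustrative.
from langchain_core.documents import Document
from langchain_core.runnables import RunnableConfig
from langchain_experimental.graph_transformers import LLMGraphTransformer
from langchain_openai import ChatOpenAI  # assumes langchain-openai is installed

llm = ChatOpenAI(model="gpt-4o", temperature=0)
transformer = LLMGraphTransformer(llm=llm)

docs = [Document(page_content="Marie Curie won two Nobel Prizes.")]

# RunnableConfig is a TypedDict; tags/metadata/callbacks set here are now
# forwarded to the chain.invoke()/ainvoke() calls inside process_response.
config: RunnableConfig = {"tags": ["graph-extraction"], "metadata": {"source": "demo"}}

graph_documents = transformer.convert_to_graph_documents(docs, config=config)
```

The async path accepts the same `config` via `aconvert_to_graph_documents` / `aprocess_response`.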