From 1ce1a10f2bc7e5a26499c80c9ba14f4d520ba8fe Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Wed, 1 May 2024 10:04:30 -0400 Subject: [PATCH] langchain[patch],community[minor]: Move graph index creator (#20795) Move graph index creator to community --- .../graphs/index_creator.py | 99 +++++++++++++++++++ libs/langchain/langchain/indexes/__init__.py | 2 +- libs/langchain/langchain/indexes/graph.py | 48 +-------- .../langchain/indexes/prompts/__init__.py | 11 +++ .../tests/unit_tests/indexes/test_api.py | 4 +- 5 files changed, 116 insertions(+), 48 deletions(-) create mode 100644 libs/community/langchain_community/graphs/index_creator.py diff --git a/libs/community/langchain_community/graphs/index_creator.py b/libs/community/langchain_community/graphs/index_creator.py new file mode 100644 index 00000000000..9394da264e2 --- /dev/null +++ b/libs/community/langchain_community/graphs/index_creator.py @@ -0,0 +1,99 @@ +from typing import Optional, Type + + +from langchain_core.pydantic_v1 import BaseModel +from langchain_core.language_models import BaseLanguageModel +from langchain_core.prompts import BasePromptTemplate +from langchain_core.prompts.prompt import PromptTemplate + +from langchain_community.graphs import NetworkxEntityGraph +from langchain_community.graphs.networkx_graph import KG_TRIPLE_DELIMITER +from langchain_community.graphs.networkx_graph import parse_triples + +# flake8: noqa + +_DEFAULT_KNOWLEDGE_TRIPLE_EXTRACTION_TEMPLATE = ( + "You are a networked intelligence helping a human track knowledge triples" + " about all relevant people, things, concepts, etc. and integrating" + " them with your knowledge stored within your weights" + " as well as that stored in a knowledge graph." + " Extract all of the knowledge triples from the text." + " A knowledge triple is a clause that contains a subject, a predicate," + " and an object. The subject is the entity being described," + " the predicate is the property of the subject that is being" + " described, and the object is the value of the property.\n\n" + "EXAMPLE\n" + "It's a state in the US. It's also the number 1 producer of gold in the US.\n\n" + f"Output: (Nevada, is a, state){KG_TRIPLE_DELIMITER}(Nevada, is in, US)" + f"{KG_TRIPLE_DELIMITER}(Nevada, is the number 1 producer of, gold)\n" + "END OF EXAMPLE\n\n" + "EXAMPLE\n" + "I'm going to the store.\n\n" + "Output: NONE\n" + "END OF EXAMPLE\n\n" + "EXAMPLE\n" + "Oh huh. I know Descartes likes to drive antique scooters and play the mandolin.\n" + f"Output: (Descartes, likes to drive, antique scooters){KG_TRIPLE_DELIMITER}(Descartes, plays, mandolin)\n" + "END OF EXAMPLE\n\n" + "EXAMPLE\n" + "{text}" + "Output:" +) + +KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT = PromptTemplate( + input_variables=["text"], + template=_DEFAULT_KNOWLEDGE_TRIPLE_EXTRACTION_TEMPLATE, +) + + +class GraphIndexCreator(BaseModel): + """Functionality to create graph index.""" + + llm: Optional[BaseLanguageModel] = None + graph_type: Type[NetworkxEntityGraph] = NetworkxEntityGraph + + def from_text( + self, text: str, prompt: BasePromptTemplate = KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT + ) -> NetworkxEntityGraph: + """Create graph index from text.""" + if self.llm is None: + raise ValueError("llm should not be None") + graph = self.graph_type() + # Temporary local scoped import while community does not depend on + # langchain explicitly + try: + from langchain.chains import LLMChain + except ImportError: + raise ImportError( + "Please install langchain to use this functionality. " + "You can install it with `pip install langchain`." + ) + chain = LLMChain(llm=self.llm, prompt=prompt) + output = chain.predict(text=text) + knowledge = parse_triples(output) + for triple in knowledge: + graph.add_triple(triple) + return graph + + async def afrom_text( + self, text: str, prompt: BasePromptTemplate = KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT + ) -> NetworkxEntityGraph: + """Create graph index from text asynchronously.""" + if self.llm is None: + raise ValueError("llm should not be None") + graph = self.graph_type() + # Temporary local scoped import while community does not depend on + # langchain explicitly + try: + from langchain.chains import LLMChain + except ImportError: + raise ImportError( + "Please install langchain to use this functionality. " + "You can install it with `pip install langchain`." + ) + chain = LLMChain(llm=self.llm, prompt=prompt) + output = await chain.apredict(text=text) + knowledge = parse_triples(output) + for triple in knowledge: + graph.add_triple(triple) + return graph diff --git a/libs/langchain/langchain/indexes/__init__.py b/libs/langchain/langchain/indexes/__init__.py index f6c9c5d535a..7e928112d01 100644 --- a/libs/langchain/langchain/indexes/__init__.py +++ b/libs/langchain/langchain/indexes/__init__.py @@ -11,10 +11,10 @@ Importantly, Index keeps on working even if the content being written is derived via a set of transformations from some source content (e.g., indexing children documents that were derived from parent documents by chunking.) """ +from langchain_community.graphs.index_creator import GraphIndexCreator from langchain_core.indexing.api import IndexingResult, aindex, index from langchain.indexes._sql_record_manager import SQLRecordManager -from langchain.indexes.graph import GraphIndexCreator from langchain.indexes.vectorstore import VectorstoreIndexCreator __all__ = [ diff --git a/libs/langchain/langchain/indexes/graph.py b/libs/langchain/langchain/indexes/graph.py index dc8e2ab38ae..aeaa1c21e27 100644 --- a/libs/langchain/langchain/indexes/graph.py +++ b/libs/langchain/langchain/indexes/graph.py @@ -1,47 +1,5 @@ """Graph Index Creator.""" -from typing import Optional, Type +from langchain_community.graphs.index_creator import GraphIndexCreator +from langchain_community.graphs.networkx_graph import NetworkxEntityGraph -from langchain_community.graphs.networkx_graph import NetworkxEntityGraph, parse_triples -from langchain_core.language_models import BaseLanguageModel -from langchain_core.prompts import BasePromptTemplate -from langchain_core.pydantic_v1 import BaseModel - -from langchain.chains.llm import LLMChain -from langchain.indexes.prompts.knowledge_triplet_extraction import ( - KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT, -) - - -class GraphIndexCreator(BaseModel): - """Functionality to create graph index.""" - - llm: Optional[BaseLanguageModel] = None - graph_type: Type[NetworkxEntityGraph] = NetworkxEntityGraph - - def from_text( - self, text: str, prompt: BasePromptTemplate = KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT - ) -> NetworkxEntityGraph: - """Create graph index from text.""" - if self.llm is None: - raise ValueError("llm should not be None") - graph = self.graph_type() - chain = LLMChain(llm=self.llm, prompt=prompt) - output = chain.predict(text=text) - knowledge = parse_triples(output) - for triple in knowledge: - graph.add_triple(triple) - return graph - - async def afrom_text( - self, text: str, prompt: BasePromptTemplate = KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT - ) -> NetworkxEntityGraph: - """Create graph index from text asynchronously.""" - if self.llm is None: - raise ValueError("llm should not be None") - graph = self.graph_type() - chain = LLMChain(llm=self.llm, prompt=prompt) - output = await chain.apredict(text=text) - knowledge = parse_triples(output) - for triple in knowledge: - graph.add_triple(triple) - return graph +__all__ = ["GraphIndexCreator", "NetworkxEntityGraph"] diff --git a/libs/langchain/langchain/indexes/prompts/__init__.py b/libs/langchain/langchain/indexes/prompts/__init__.py index 1a5833cd2a5..55f9b194788 100644 --- a/libs/langchain/langchain/indexes/prompts/__init__.py +++ b/libs/langchain/langchain/indexes/prompts/__init__.py @@ -1 +1,12 @@ """Relevant prompts for constructing indexes.""" +from langchain_core._api import warn_deprecated + +warn_deprecated( + since="0.1.47", + message=( + "langchain.indexes.prompts will be removed in the future." + "If you're relying on these prompts, please open an issue on " + "GitHub to explain your use case." + ), + pending=True, +) diff --git a/libs/langchain/tests/unit_tests/indexes/test_api.py b/libs/langchain/tests/unit_tests/indexes/test_api.py index fa59c71b5bc..0a60c2b5369 100644 --- a/libs/langchain/tests/unit_tests/indexes/test_api.py +++ b/libs/langchain/tests/unit_tests/indexes/test_api.py @@ -3,8 +3,7 @@ from langchain.indexes import __all__ def test_all() -> None: """Use to catch obvious breaking changes.""" - assert __all__ == sorted(__all__, key=str.lower) - assert __all__ == [ + expected = [ "aindex", "GraphIndexCreator", "index", @@ -12,3 +11,4 @@ def test_all() -> None: "SQLRecordManager", "VectorstoreIndexCreator", ] + assert __all__ == sorted(expected, key=lambda x: x.lower())