diff --git a/libs/experimental/langchain_experimental/graph_transformers/llm.py b/libs/experimental/langchain_experimental/graph_transformers/llm.py
index 12116315f2e..5a752060b02 100644
--- a/libs/experimental/langchain_experimental/graph_transformers/llm.py
+++ b/libs/experimental/langchain_experimental/graph_transformers/llm.py
@@ -5,9 +5,67 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, cast
 from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship
 from langchain_core.documents import Document
 from langchain_core.language_models import BaseLanguageModel
-from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.messages import SystemMessage
+from langchain_core.output_parsers import JsonOutputParser
+from langchain_core.prompts import (
+    ChatPromptTemplate,
+    HumanMessagePromptTemplate,
+    PromptTemplate,
+)
 from langchain_core.pydantic_v1 import BaseModel, Field
 
+examples = [
+    {
+        "text": (
+            "Adam is a software engineer in Microsoft since 2009, "
+            "and last year he got an award as the Best Talent"
+        ),
+        "head": "Adam",
+        "head_type": "Person",
+        "relation": "WORKS_FOR",
+        "tail": "Microsoft",
+        "tail_type": "Company",
+    },
+    {
+        "text": (
+            "Adam is a software engineer in Microsoft since 2009, "
+            "and last year he got an award as the Best Talent"
+        ),
+        "head": "Adam",
+        "head_type": "Person",
+        "relation": "HAS_AWARD",
+        "tail": "Best Talent",
+        "tail_type": "Award",
+    },
+    {
+        "text": (
+            "Microsoft is a tech company that provide "
+            "several products such as Microsoft Word"
+        ),
+        "head": "Microsoft Word",
+        "head_type": "Product",
+        "relation": "PRODUCED_BY",
+        "tail": "Microsoft",
+        "tail_type": "Company",
+    },
+    {
+        "text": "Microsoft Word is a lightweight app that accessible offline",
+        "head": "Microsoft Word",
+        "head_type": "Product",
+        "relation": "HAS_CHARACTERISTIC",
+        "tail": "lightweight app",
+        "tail_type": "Characteristic",
+    },
+    {
+        "text": "Microsoft Word is a lightweight app that accessible offline",
+        "head": "Microsoft Word",
+        "head_type": "Product",
+        "relation": "HAS_CHARACTERISTIC",
+        "tail": "accessible offline",
+        "tail_type": "Characteristic",
+    },
+]
+
 system_prompt = (
     "# Knowledge Graph Instructions for GPT-4\n"
     "## 1. Overview\n"
@@ -99,6 +157,103 @@ class _Graph(BaseModel):
     relationships: Optional[List]
 
 
+class UnstructuredRelation(BaseModel):
+    head: str = Field(
+        description=(
+            "extracted head entity like Microsoft, Apple, John. "
+            "Must use human-readable unique identifier."
+        )
+    )
+    head_type: str = Field(
+        description="type of the extracted head entity like Person, Company, etc"
+    )
+    relation: str = Field(description="relation between the head and the tail entities")
+    tail: str = Field(
+        description=(
+            "extracted tail entity like Microsoft, Apple, John. "
+            "Must use human-readable unique identifier."
+        )
+    )
+    tail_type: str = Field(
+        description="type of the extracted tail entity like Person, Company, etc"
+    )
+
+
+def create_unstructured_prompt(
+    node_labels: Optional[List[str]] = None, rel_types: Optional[List[str]] = None
+) -> ChatPromptTemplate:
+    node_labels_str = str(node_labels) if node_labels else ""
+    rel_types_str = str(rel_types) if rel_types else ""
+    base_string_parts = [
+        "You are a top-tier algorithm designed for extracting information in "
+        "structured formats to build a knowledge graph. Your task is to identify "
+        "the entities and relations requested with the user prompt from a given "
+        "text. You must generate the output in a JSON format containing a list "
+        'with JSON objects. Each object should have the keys: "head", '
+        '"head_type", "relation", "tail", and "tail_type". The "head" '
+        "key must contain the text of the extracted entity with one of the types "
+        "from the provided list in the user prompt.",
+        f'The "head_type" key must contain the type of the extracted head entity, '
+        f"which must be one of the types from {node_labels_str}."
+        if node_labels
+        else "",
+        f'The "relation" key must contain the type of relation between the "head" '
+        f'and the "tail", which must be one of the relations from {rel_types_str}.'
+        if rel_types
+        else "",
+        f'The "tail" key must represent the text of an extracted entity which is '
+        f'the tail of the relation, and the "tail_type" key must contain the type '
+        f"of the tail entity from {node_labels_str}."
+        if node_labels
+        else "",
+        "Attempt to extract as many entities and relations as you can. Maintain "
+        "Entity Consistency: When extracting entities, it's vital to ensure "
+        'consistency. If an entity, such as "John Doe", is mentioned multiple '
+        "times in the text but is referred to by different names or pronouns "
+        '(e.g., "Joe", "he"), always use the most complete identifier for '
+        "that entity. The knowledge graph should be coherent and easily "
+        "understandable, so maintaining consistency in entity references is "
+        "crucial.",
+        "IMPORTANT NOTES:\n- Don't add any explanation and text.",
+    ]
+    system_prompt = "\n".join(filter(None, base_string_parts))
+
+    system_message = SystemMessage(content=system_prompt)
+    parser = JsonOutputParser(pydantic_object=UnstructuredRelation)
+
+    human_prompt = PromptTemplate(
+        template="""Based on the following example, extract entities and
+relations from the provided text.\n\n
+Use the following entity types, don't use other entity that is not defined below:
+# ENTITY TYPES:
+{node_labels}
+
+Use the following relation types, don't use other relation that is not defined below:
+# RELATION TYPES:
+{rel_types}
+
+Below are a number of examples of text and their extracted entities and relationships.
+{examples}
+
+For the following text, extract entities and relations as in the provided example.
+{format_instructions}\nText: {input}""",
+        input_variables=["input"],
+        partial_variables={
+            "format_instructions": parser.get_format_instructions(),
+            "node_labels": node_labels,
+            "rel_types": rel_types,
+            "examples": examples,
+        },
+    )
+
+    human_message_prompt = HumanMessagePromptTemplate(prompt=human_prompt)
+
+    chat_prompt = ChatPromptTemplate.from_messages(
+        [system_message, human_message_prompt]
+    )
+    return chat_prompt
+
+
 def create_simple_model(
     node_labels: Optional[List[str]] = None, rel_types: Optional[List[str]] = None
 ) -> Type[_Graph]:
@@ -317,22 +472,38 @@ class LLMGraphTransformer:
         llm: BaseLanguageModel,
         allowed_nodes: List[str] = [],
         allowed_relationships: List[str] = [],
-        prompt: ChatPromptTemplate = default_prompt,
+        prompt: Optional[ChatPromptTemplate] = None,
         strict_mode: bool = True,
     ) -> None:
-        if not hasattr(llm, "with_structured_output"):
-            raise ValueError(
-                "The specified LLM does not support the 'with_structured_output'. "
-                "Please ensure you are using an LLM that supports this feature."
-            )
         self.allowed_nodes = allowed_nodes
         self.allowed_relationships = allowed_relationships
         self.strict_mode = strict_mode
+        self._function_call = True
+        # Check if the LLM really supports structured output
+        try:
+            llm.with_structured_output(_Graph)
+        except NotImplementedError:
+            self._function_call = False
+        if not self._function_call:
+            try:
+                import json_repair
 
-        # Define chain
-        schema = create_simple_model(allowed_nodes, allowed_relationships)
-        structured_llm = llm.with_structured_output(schema, include_raw=True)
-        self.chain = prompt | structured_llm
+                self.json_repair = json_repair
+            except ImportError:
+                raise ImportError(
+                    "Could not import json_repair python package. "
+                    "Please install it with `pip install json-repair`."
+                )
+            prompt = prompt or create_unstructured_prompt(
+                allowed_nodes, allowed_relationships
+            )
+            self.chain = prompt | llm
+        else:
+            # Define chain
+            schema = create_simple_model(allowed_nodes, allowed_relationships)
+            structured_llm = llm.with_structured_output(schema, include_raw=True)
+            prompt = prompt or default_prompt
+            self.chain = prompt | structured_llm
 
     def process_response(self, document: Document) -> GraphDocument:
         """
@@ -341,8 +512,27 @@ class LLMGraphTransformer:
         """
         text = document.page_content
         raw_schema = self.chain.invoke({"input": text})
-        raw_schema = cast(Dict[Any, Any], raw_schema)
-        nodes, relationships = _convert_to_graph_document(raw_schema)
+        if self._function_call:
+            raw_schema = cast(Dict[Any, Any], raw_schema)
+            nodes, relationships = _convert_to_graph_document(raw_schema)
+        else:
+            nodes_set = set()
+            relationships = []
+            parsed_json = self.json_repair.loads(raw_schema.content)
+            for rel in parsed_json:
+                # Nodes need to be deduplicated using a set
+                nodes_set.add((rel["head"], rel["head_type"]))
+                nodes_set.add((rel["tail"], rel["tail_type"]))
+
+                source_node = Node(id=rel["head"], type=rel["head_type"])
+                target_node = Node(id=rel["tail"], type=rel["tail_type"])
+                relationships.append(
+                    Relationship(
+                        source=source_node, target=target_node, type=rel["relation"]
+                    )
+                )
+            # Create nodes list
+            nodes = [Node(id=el[0], type=el[1]) for el in list(nodes_set)]
 
         # Strict mode filtering
         if self.strict_mode and (self.allowed_nodes or self.allowed_relationships):