From cdb93ab5ca10ebfcfc0039b26cd4b2cdb668d844 Mon Sep 17 00:00:00 2001
From: Jacob Lee
Date: Wed, 12 Jul 2023 22:12:41 -0700
Subject: [PATCH] Adds OpenAI functions powered document metadata tagger (#7521)

Adds a new document transformer that automatically extracts metadata for a document based on an input schema.

I also moved `document_transformers.py` to `document_transformers/__init__.py` to group it with this new transformer - it didn't seem to cause issues in the notebook, but let me know if I've done something wrong there.

Also had a linter issue I couldn't figure out:

```
MacBook-Pro:langchain jacoblee$ make lint
poetry run mypy .
docs/dist/conf.py: error: Duplicate module named "conf" (also at "./docs/api_reference/conf.py")
docs/dist/conf.py: note: See https://mypy.readthedocs.io/en/stable/running_mypy.html#mapping-file-paths-to-modules for more info
docs/dist/conf.py: note: Common resolutions include: a) using `--exclude` to avoid checking one of them, b) adding `__init__.py` somewhere, c) using `--explicit-package-bases` or adjusting MYPYPATH
Found 1 error in 1 file (errors prevented further checking)
make: *** [lint] Error 2
```

@rlancemartin @baskaryan

---------

Co-authored-by: Bagatur
---
 .../modules/callbacks/get_started.mdx          |   2 +-
 langchain/chains/openai_functions/tagging.py   |  20 ++-
 langchain/document_transformers/__init__.py    |   3 +
 .../document_transformers/openai_functions.py  | 141 ++++++++++++++++++
 4 files changed, 160 insertions(+), 6 deletions(-)
 create mode 100644 langchain/document_transformers/openai_functions.py

diff --git a/docs/snippets/modules/callbacks/get_started.mdx b/docs/snippets/modules/callbacks/get_started.mdx
index bbd39c6cb33..7e4974da969 100644
--- a/docs/snippets/modules/callbacks/get_started.mdx
+++ b/docs/snippets/modules/callbacks/get_started.mdx
@@ -131,7 +131,7 @@ chain.run(number=2, callbacks=[handler])
 The `callbacks` argument is available on most objects throughout the API (Chains, Models, Tools, Agents, etc.) in two different places:
 
 - **Constructor callbacks**: defined in the constructor, eg. `LLMChain(callbacks=[handler], tags=['a-tag'])`, which will be used for all calls made on that object, and will be scoped to that object only, eg. if you pass a handler to the `LLMChain` constructor, it will not be used by the Model attached to that chain.
-- **Request callbacks**: defined in the `call()`/`run()`/`apply()` methods used for issuing a request, eg. `chain.call(inputs, callbacks=[handler])`, which will be used for that specific request only, and all sub-requests that it contains (eg. a call to an LLMChain triggers a call to a Model, which uses the same handler passed in the `call()` method).
+- **Request callbacks**: defined in the `run()`/`apply()` methods used for issuing a request, eg. `chain.run(input, callbacks=[handler])`, which will be used for that specific request only, and all sub-requests that it contains (eg. a call to an LLMChain triggers a call to a Model, which uses the same handler passed in the `call()` method).
 
 The `verbose` argument is available on most objects throughout the API (Chains, Models, Tools, Agents, etc.) as a constructor argument, eg. `LLMChain(verbose=True)`, and it is equivalent to passing a `ConsoleCallbackHandler` to the `callbacks` argument of that object and all child objects. This is useful for debugging, as it will log all events to the console.
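
For reference, here is a minimal sketch of the two callback entry points the updated snippet describes. The handler, LLM, and prompt are only illustrative; `StdOutCallbackHandler`, `LLMChain`, and `chain.run(number=2, callbacks=[handler])` are the same objects used in the surrounding callbacks docs:

```python
from langchain.callbacks import StdOutCallbackHandler
from langchain.chains import LLMChain
from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate

handler = StdOutCallbackHandler()
llm = OpenAI()
prompt = PromptTemplate.from_template("1 + {number} = ")

# Constructor callbacks: scoped to this chain object and used for every call it
# makes, but not inherited by the Model attached to the chain.
chain = LLMChain(llm=llm, prompt=prompt, callbacks=[handler])
chain.run(number=2)

# Request callbacks: passed with a single request and propagated to all
# sub-runs of that request (e.g. the model call the chain makes).
chain = LLMChain(llm=llm, prompt=prompt)
chain.run(number=2, callbacks=[handler])
```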
diff --git a/langchain/chains/openai_functions/tagging.py b/langchain/chains/openai_functions/tagging.py
index 4bddaabba19..d39ad36ca55 100644
--- a/langchain/chains/openai_functions/tagging.py
+++ b/langchain/chains/openai_functions/tagging.py
@@ -1,4 +1,4 @@
-from typing import Any
+from typing import Any, Optional
 
 from langchain.chains.base import Chain
 from langchain.chains.llm import LLMChain
@@ -26,7 +26,12 @@ Passage:
 """
 
 
-def create_tagging_chain(schema: dict, llm: BaseLanguageModel) -> Chain:
+def create_tagging_chain(
+    schema: dict,
+    llm: BaseLanguageModel,
+    prompt: Optional[ChatPromptTemplate] = None,
+    **kwargs: Any
+) -> Chain:
     """Creates a chain that extracts information from a passage.
 
     Args:
@@ -37,7 +42,7 @@ def create_tagging_chain(schema: dict, llm: BaseLanguageModel) -> Chain:
         Chain (LLMChain) that can be used to extract information from a passage.
     """
     function = _get_tagging_function(schema)
-    prompt = ChatPromptTemplate.from_template(_TAGGING_TEMPLATE)
+    prompt = prompt or ChatPromptTemplate.from_template(_TAGGING_TEMPLATE)
     output_parser = JsonOutputFunctionsParser()
     llm_kwargs = get_llm_kwargs(function)
     chain = LLMChain(
@@ -45,12 +50,16 @@ def create_tagging_chain(schema: dict, llm: BaseLanguageModel) -> Chain:
         prompt=prompt,
         llm_kwargs=llm_kwargs,
         output_parser=output_parser,
+        **kwargs,
     )
     return chain
 
 
 def create_tagging_chain_pydantic(
-    pydantic_schema: Any, llm: BaseLanguageModel
+    pydantic_schema: Any,
+    llm: BaseLanguageModel,
+    prompt: Optional[ChatPromptTemplate] = None,
+    **kwargs: Any
 ) -> Chain:
     """Creates a chain that extracts information from a passage.
 
@@ -63,7 +72,7 @@ def create_tagging_chain_pydantic(
     """
     openai_schema = pydantic_schema.schema()
     function = _get_tagging_function(openai_schema)
-    prompt = ChatPromptTemplate.from_template(_TAGGING_TEMPLATE)
+    prompt = prompt or ChatPromptTemplate.from_template(_TAGGING_TEMPLATE)
     output_parser = PydanticOutputFunctionsParser(pydantic_schema=pydantic_schema)
     llm_kwargs = get_llm_kwargs(function)
     chain = LLMChain(
@@ -71,5 +80,6 @@ def create_tagging_chain_pydantic(
         prompt=prompt,
         llm_kwargs=llm_kwargs,
         output_parser=output_parser,
+        **kwargs,
     )
     return chain
diff --git a/langchain/document_transformers/__init__.py b/langchain/document_transformers/__init__.py
index 363ea3d7bea..a71d9bd0b17 100644
--- a/langchain/document_transformers/__init__.py
+++ b/langchain/document_transformers/__init__.py
@@ -16,4 +16,7 @@ __all__ = [
     "EmbeddingsClusteringFilter",
     "EmbeddingsRedundantFilter",
     "get_stateful_documents",
+    "OpenAIMetadataTagger",
 ]
+
+from langchain.document_transformers.openai_functions import OpenAIMetadataTagger
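
A short sketch of the new optional `prompt` argument to `create_tagging_chain` added above. The schema, the custom template text, and the input passage are illustrative only; like the default `_TAGGING_TEMPLATE`, a custom prompt should expose a single `{input}` variable that receives the passage:

```python
from langchain.chains.openai_functions import create_tagging_chain
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate

schema = {
    "properties": {
        "sentiment": {"type": "string", "enum": ["positive", "negative"]},
        "language": {"type": "string"},
    },
    "required": ["sentiment", "language"],
}

# Must be an OpenAI model that supports functions
llm = ChatOpenAI(temperature=0, model="gpt-3.5-turbo-0613")

# Hypothetical replacement for the default tagging prompt; it still needs
# the {input} placeholder that the chain fills with the passage.
prompt = ChatPromptTemplate.from_template(
    "Extract the requested properties from the following customer message.\n\n"
    "Passage:\n{input}"
)

chain = create_tagging_chain(schema, llm, prompt=prompt)
result = chain.run("This product is fantastic, best purchase I've made all year!")
# e.g. {"sentiment": "positive", "language": "English"}
```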
diff --git a/langchain/document_transformers/openai_functions.py b/langchain/document_transformers/openai_functions.py
new file mode 100644
index 00000000000..96de42a2b95
--- /dev/null
+++ b/langchain/document_transformers/openai_functions.py
@@ -0,0 +1,141 @@
+"""Document transformers that use OpenAI Functions models"""
+from typing import Any, Dict, Optional, Sequence, Type, Union
+
+from pydantic import BaseModel
+
+from langchain.chains.llm import LLMChain
+from langchain.chains.openai_functions import create_tagging_chain
+from langchain.prompts import ChatPromptTemplate
+from langchain.schema import BaseDocumentTransformer, BaseLanguageModel, Document
+
+
+class OpenAIMetadataTagger(BaseDocumentTransformer, BaseModel):
+    """Extract metadata tags from document contents using OpenAI functions.
+
+    Example:
+        .. code-block:: python
+
+            from langchain.chat_models import ChatOpenAI
+            from langchain.document_transformers import OpenAIMetadataTagger
+            from langchain.schema import Document
+
+            schema = {
+                "properties": {
+                    "movie_title": { "type": "string" },
+                    "critic": { "type": "string" },
+                    "tone": {
+                        "type": "string",
+                        "enum": ["positive", "negative"]
+                    },
+                    "rating": {
+                        "type": "integer",
+                        "description": "The number of stars the critic rated the movie"
+                    }
+                },
+                "required": ["movie_title", "critic", "tone"]
+            }
+
+            # Must be an OpenAI model that supports functions
+            llm = ChatOpenAI(temperature=0, model="gpt-3.5-turbo-0613")
+            tagging_chain = create_tagging_chain(schema, llm)
+            document_transformer = OpenAIMetadataTagger(tagging_chain=tagging_chain)
+            original_documents = [
+                Document(page_content="Review of The Bee Movie\nBy Roger Ebert\n\nThis is the greatest movie ever made. 4 out of 5 stars."),
+                Document(page_content="Review of The Godfather\nBy Anonymous\n\nThis movie was super boring. 1 out of 5 stars.", metadata={"reliable": False}),
+            ]
+
+            enhanced_documents = document_transformer.transform_documents(original_documents)
+    """  # noqa: E501
+
+    tagging_chain: LLMChain
+    """The chain used to extract metadata from each document."""
+
+    def transform_documents(
+        self, documents: Sequence[Document], **kwargs: Any
+    ) -> Sequence[Document]:
+        """Automatically extract and populate metadata
+        for each document according to the provided schema."""
+
+        new_documents = []
+
+        for document in documents:
+            extracted_metadata: Dict = self.tagging_chain.run(document.page_content)  # type: ignore[assignment]  # noqa: E501
+            new_document = Document(
+                page_content=document.page_content,
+                metadata={**extracted_metadata, **document.metadata},
+            )
+            new_documents.append(new_document)
+        return new_documents
+
+    async def atransform_documents(
+        self, documents: Sequence[Document], **kwargs: Any
+    ) -> Sequence[Document]:
+        raise NotImplementedError
+
+
+def create_metadata_tagger(
+    metadata_schema: Union[Dict[str, Any], Type[BaseModel]],
+    llm: BaseLanguageModel,
+    prompt: Optional[ChatPromptTemplate] = None,
+    *,
+    tagging_chain_kwargs: Optional[Dict] = None,
+) -> OpenAIMetadataTagger:
+    """Create a DocumentTransformer that uses an OpenAI function chain to automatically
+    tag documents with metadata based on their content and an input schema.
+
+    Args:
+        metadata_schema: Either a dictionary or pydantic.BaseModel class. If a dictionary
+            is passed in, it's assumed to already be a valid JsonSchema.
+            For best results, pydantic.BaseModels should have docstrings describing what
+            the schema represents and descriptions for the parameters.
+        llm: Language model to use, assumed to support the OpenAI function-calling API.
+        prompt: BasePromptTemplate to pass to the model.
+
+    Returns:
+        An OpenAIMetadataTagger instance that will tag passed documents with metadata.
+
+    Example:
+        .. code-block:: python
+
+            from langchain.chat_models import ChatOpenAI
+            from langchain.document_transformers import create_metadata_tagger
+            from langchain.schema import Document
+
+            schema = {
+                "properties": {
+                    "movie_title": { "type": "string" },
+                    "critic": { "type": "string" },
+                    "tone": {
+                        "type": "string",
+                        "enum": ["positive", "negative"]
+                    },
+                    "rating": {
+                        "type": "integer",
+                        "description": "The number of stars the critic rated the movie"
+                    }
+                },
+                "required": ["movie_title", "critic", "tone"]
+            }
+
+            # Must be an OpenAI model that supports functions
+            llm = ChatOpenAI(temperature=0, model="gpt-3.5-turbo-0613")
+
+            document_transformer = create_metadata_tagger(schema, llm)
+            original_documents = [
+                Document(page_content="Review of The Bee Movie\nBy Roger Ebert\n\nThis is the greatest movie ever made. 4 out of 5 stars."),
+                Document(page_content="Review of The Godfather\nBy Anonymous\n\nThis movie was super boring. 1 out of 5 stars.", metadata={"reliable": False}),
+            ]
+
+            enhanced_documents = document_transformer.transform_documents(original_documents)
+    """  # noqa: E501
+    metadata_schema = (
+        metadata_schema
+        if isinstance(metadata_schema, dict)
+        else metadata_schema.schema()
+    )
+    _tagging_chain_kwargs = tagging_chain_kwargs or {}
+    tagging_chain = create_tagging_chain(
+        metadata_schema, llm, prompt=prompt, **_tagging_chain_kwargs
+    )
+    return OpenAIMetadataTagger(tagging_chain=tagging_chain)
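
Finally, a minimal sketch of the pydantic path that `create_metadata_tagger` also accepts. The `Properties` model, its field descriptions, and the sample document are hypothetical; the import goes through `langchain.document_transformers.openai_functions`, since only `OpenAIMetadataTagger` is re-exported from the package `__init__` in this patch:

```python
from typing import Optional

from pydantic import BaseModel, Field

from langchain.chat_models import ChatOpenAI
from langchain.document_transformers.openai_functions import create_metadata_tagger
from langchain.schema import Document


class Properties(BaseModel):
    """Metadata about a movie review."""

    movie_title: str
    critic: str
    tone: str = Field(..., description="Either 'positive' or 'negative'")
    rating: Optional[int] = Field(
        None, description="The number of stars the critic rated the movie"
    )


# Must be an OpenAI model that supports functions
llm = ChatOpenAI(temperature=0, model="gpt-3.5-turbo-0613")
document_transformer = create_metadata_tagger(Properties, llm)

documents = [
    Document(
        page_content="Review of The Bee Movie\nBy Roger Ebert\n\n"
        "This is the greatest movie ever made. 4 out of 5 stars."
    ),
]

enhanced_documents = document_transformer.transform_documents(documents)
# Each returned Document keeps its page_content and gains the extracted fields
# in metadata, e.g. {"movie_title": "The Bee Movie", "critic": "Roger Ebert", ...}
```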