mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-05 12:48:12 +00:00
Extraction Chain - Custom Prompt (#9828)
# Description This change allows you to customize the prompt used in `create_extraction_chain` as well as `create_extraction_chain_pydantic`. It also adds the `verbose` argument to `create_extraction_chain_pydantic` - because `create_extraction_chain` had it already and `create_extraction_chain_pydantic` did not. # Issue N/A # Dependencies N/A # Twitter https://twitter.com/CamAHutchison
This commit is contained in:
parent
33f43cc1b0
commit
7d8bb78e5c
@ -1,4 +1,4 @@
|
|||||||
from typing import Any, List
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
from langchain.chains.base import Chain
|
from langchain.chains.base import Chain
|
||||||
from langchain.chains.llm import LLMChain
|
from langchain.chains.llm import LLMChain
|
||||||
@ -13,6 +13,7 @@ from langchain.output_parsers.openai_functions import (
|
|||||||
)
|
)
|
||||||
from langchain.prompts import ChatPromptTemplate
|
from langchain.prompts import ChatPromptTemplate
|
||||||
from langchain.pydantic_v1 import BaseModel
|
from langchain.pydantic_v1 import BaseModel
|
||||||
|
from langchain.schema import BasePromptTemplate
|
||||||
from langchain.schema.language_model import BaseLanguageModel
|
from langchain.schema.language_model import BaseLanguageModel
|
||||||
|
|
||||||
|
|
||||||
@ -43,13 +44,17 @@ Passage:
|
|||||||
|
|
||||||
|
|
||||||
def create_extraction_chain(
|
def create_extraction_chain(
|
||||||
schema: dict, llm: BaseLanguageModel, verbose: bool = False
|
schema: dict,
|
||||||
|
llm: BaseLanguageModel,
|
||||||
|
prompt: Optional[BasePromptTemplate] = None,
|
||||||
|
verbose: bool = False,
|
||||||
) -> Chain:
|
) -> Chain:
|
||||||
"""Creates a chain that extracts information from a passage.
|
"""Creates a chain that extracts information from a passage.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
schema: The schema of the entities to extract.
|
schema: The schema of the entities to extract.
|
||||||
llm: The language model to use.
|
llm: The language model to use.
|
||||||
|
prompt: The prompt to use for extraction.
|
||||||
verbose: Whether to run in verbose mode. In verbose mode, some intermediate
|
verbose: Whether to run in verbose mode. In verbose mode, some intermediate
|
||||||
logs will be printed to the console. Defaults to `langchain.verbose` value.
|
logs will be printed to the console. Defaults to `langchain.verbose` value.
|
||||||
|
|
||||||
@ -57,12 +62,12 @@ def create_extraction_chain(
|
|||||||
Chain that can be used to extract information from a passage.
|
Chain that can be used to extract information from a passage.
|
||||||
"""
|
"""
|
||||||
function = _get_extraction_function(schema)
|
function = _get_extraction_function(schema)
|
||||||
prompt = ChatPromptTemplate.from_template(_EXTRACTION_TEMPLATE)
|
extraction_prompt = prompt or ChatPromptTemplate.from_template(_EXTRACTION_TEMPLATE)
|
||||||
output_parser = JsonKeyOutputFunctionsParser(key_name="info")
|
output_parser = JsonKeyOutputFunctionsParser(key_name="info")
|
||||||
llm_kwargs = get_llm_kwargs(function)
|
llm_kwargs = get_llm_kwargs(function)
|
||||||
chain = LLMChain(
|
chain = LLMChain(
|
||||||
llm=llm,
|
llm=llm,
|
||||||
prompt=prompt,
|
prompt=extraction_prompt,
|
||||||
llm_kwargs=llm_kwargs,
|
llm_kwargs=llm_kwargs,
|
||||||
output_parser=output_parser,
|
output_parser=output_parser,
|
||||||
verbose=verbose,
|
verbose=verbose,
|
||||||
@ -71,13 +76,19 @@ def create_extraction_chain(
|
|||||||
|
|
||||||
|
|
||||||
def create_extraction_chain_pydantic(
|
def create_extraction_chain_pydantic(
|
||||||
pydantic_schema: Any, llm: BaseLanguageModel
|
pydantic_schema: Any,
|
||||||
|
llm: BaseLanguageModel,
|
||||||
|
prompt: Optional[BasePromptTemplate] = None,
|
||||||
|
verbose: bool = False,
|
||||||
) -> Chain:
|
) -> Chain:
|
||||||
"""Creates a chain that extracts information from a passage using pydantic schema.
|
"""Creates a chain that extracts information from a passage using pydantic schema.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
pydantic_schema: The pydantic schema of the entities to extract.
|
pydantic_schema: The pydantic schema of the entities to extract.
|
||||||
llm: The language model to use.
|
llm: The language model to use.
|
||||||
|
prompt: The prompt to use for extraction.
|
||||||
|
verbose: Whether to run in verbose mode. In verbose mode, some intermediate
|
||||||
|
logs will be printed to the console. Defaults to `langchain.verbose` value.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Chain that can be used to extract information from a passage.
|
Chain that can be used to extract information from a passage.
|
||||||
@ -92,15 +103,16 @@ def create_extraction_chain_pydantic(
|
|||||||
)
|
)
|
||||||
|
|
||||||
function = _get_extraction_function(openai_schema)
|
function = _get_extraction_function(openai_schema)
|
||||||
prompt = ChatPromptTemplate.from_template(_EXTRACTION_TEMPLATE)
|
extraction_prompt = prompt or ChatPromptTemplate.from_template(_EXTRACTION_TEMPLATE)
|
||||||
output_parser = PydanticAttrOutputFunctionsParser(
|
output_parser = PydanticAttrOutputFunctionsParser(
|
||||||
pydantic_schema=PydanticSchema, attr_name="info"
|
pydantic_schema=PydanticSchema, attr_name="info"
|
||||||
)
|
)
|
||||||
llm_kwargs = get_llm_kwargs(function)
|
llm_kwargs = get_llm_kwargs(function)
|
||||||
chain = LLMChain(
|
chain = LLMChain(
|
||||||
llm=llm,
|
llm=llm,
|
||||||
prompt=prompt,
|
prompt=extraction_prompt,
|
||||||
llm_kwargs=llm_kwargs,
|
llm_kwargs=llm_kwargs,
|
||||||
output_parser=output_parser,
|
output_parser=output_parser,
|
||||||
|
verbose=verbose,
|
||||||
)
|
)
|
||||||
return chain
|
return chain
|
||||||
|
Loading…
Reference in New Issue
Block a user