langchain[patch]: structured output chain nits (#17291)

This commit is contained in:
Bagatur 2024-02-13 16:45:29 -08:00 committed by GitHub
parent 8a3b74fe1f
commit 50de7a31f0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 53 additions and 36 deletions

View File

@ -82,6 +82,7 @@ from langchain.chains.router import (
) )
from langchain.chains.sequential import SequentialChain, SimpleSequentialChain from langchain.chains.sequential import SequentialChain, SimpleSequentialChain
from langchain.chains.sql_database.query import create_sql_query_chain from langchain.chains.sql_database.query import create_sql_query_chain
from langchain.chains.structured_output import create_structured_output_runnable
from langchain.chains.summarize import load_summarize_chain from langchain.chains.summarize import load_summarize_chain
from langchain.chains.transform import TransformChain from langchain.chains.transform import TransformChain
@ -145,5 +146,6 @@ __all__ = [
"create_sql_query_chain", "create_sql_query_chain",
"create_retrieval_chain", "create_retrieval_chain",
"create_history_aware_retriever", "create_history_aware_retriever",
"create_structured_output_runnable",
"load_summarize_chain", "load_summarize_chain",
] ]

View File

@ -22,7 +22,7 @@ from langchain.output_parsers.openai_functions import (
def create_openai_fn_runnable( def create_openai_fn_runnable(
functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable]], functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable]],
llm: Runnable, llm: Runnable,
prompt: BasePromptTemplate, prompt: Optional[BasePromptTemplate] = None,
*, *,
enforce_single_function_usage: bool = True, enforce_single_function_usage: bool = True,
output_parser: Optional[Union[BaseOutputParser, BaseGenerationOutputParser]] = None, output_parser: Optional[Union[BaseOutputParser, BaseGenerationOutputParser]] = None,
@ -64,7 +64,6 @@ def create_openai_fn_runnable(
from langchain.chains.structured_output import create_openai_fn_runnable from langchain.chains.structured_output import create_openai_fn_runnable
from langchain_openai import ChatOpenAI from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field from langchain_core.pydantic_v1 import BaseModel, Field
@ -85,15 +84,8 @@ def create_openai_fn_runnable(
llm = ChatOpenAI(model="gpt-4", temperature=0) llm = ChatOpenAI(model="gpt-4", temperature=0)
prompt = ChatPromptTemplate.from_messages( structured_llm = create_openai_fn_runnable([RecordPerson, RecordDog], llm)
[ structured_llm.invoke("Harry was a chubby brown beagle who loved chicken")
("system", "You are a world class algorithm for recording entities."),
("human", "Make calls to the relevant function to record the entities in the following input: {input}"),
("human", "Tip: Make sure to answer in the correct format"),
]
)
chain = create_openai_fn_runnable([RecordPerson, RecordDog], llm, prompt)
chain.invoke({"input": "Harry was a chubby brown beagle who loved chicken"})
# -> RecordDog(name="Harry", color="brown", fav_food="chicken") # -> RecordDog(name="Harry", color="brown", fav_food="chicken")
""" # noqa: E501 """ # noqa: E501
if not functions: if not functions:
@ -103,14 +95,17 @@ def create_openai_fn_runnable(
if len(openai_functions) == 1 and enforce_single_function_usage: if len(openai_functions) == 1 and enforce_single_function_usage:
llm_kwargs["function_call"] = {"name": openai_functions[0]["name"]} llm_kwargs["function_call"] = {"name": openai_functions[0]["name"]}
output_parser = output_parser or get_openai_output_parser(functions) output_parser = output_parser or get_openai_output_parser(functions)
return prompt | llm.bind(**llm_kwargs) | output_parser if prompt:
return prompt | llm.bind(**llm_kwargs) | output_parser
else:
return llm.bind(**llm_kwargs) | output_parser
# TODO: implement mode='openai-tools'. # TODO: implement mode='openai-tools'.
def create_structured_output_runnable( def create_structured_output_runnable(
output_schema: Union[Dict[str, Any], Type[BaseModel]], output_schema: Union[Dict[str, Any], Type[BaseModel]],
llm: Runnable, llm: Runnable,
prompt: BasePromptTemplate, prompt: Optional[BasePromptTemplate] = None,
*, *,
output_parser: Optional[Union[BaseOutputParser, BaseGenerationOutputParser]] = None, output_parser: Optional[Union[BaseOutputParser, BaseGenerationOutputParser]] = None,
mode: Literal["openai-functions", "openai-json"] = "openai-functions", mode: Literal["openai-functions", "openai-json"] = "openai-functions",
@ -152,7 +147,28 @@ def create_structured_output_runnable(
from typing import Optional from typing import Optional
from langchain.chains.structured_output import create_structured_output_runnable from langchain.chains import create_structured_output_runnable
from langchain_openai import ChatOpenAI
from langchain_core.pydantic_v1 import BaseModel, Field
class Dog(BaseModel):
'''Identifying information about a dog.'''
name: str = Field(..., description="The dog's name")
color: str = Field(..., description="The dog's color")
fav_food: Optional[str] = Field(None, description="The dog's favorite food")
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
structured_llm = create_structured_output_runnable(Dog, llm, mode="openai-functions")
structured_llm.invoke("Harry was a chubby brown beagle who loved chicken")
# -> Dog(name="Harry", color="brown", fav_food="chicken")
OpenAI functions with prompt example:
.. code-block:: python
from typing import Optional
from langchain.chains import create_structured_output_runnable
from langchain_openai import ChatOpenAI from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field from langchain_core.pydantic_v1 import BaseModel, Field
@ -165,23 +181,20 @@ def create_structured_output_runnable(
fav_food: Optional[str] = Field(None, description="The dog's favorite food") fav_food: Optional[str] = Field(None, description="The dog's favorite food")
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0) llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
structured_llm = create_structured_output_runnable(Dog, llm, mode="openai-functions")
system = '''Extract information about any dogs mentioned in the user input.'''
prompt = ChatPromptTemplate.from_messages( prompt = ChatPromptTemplate.from_messages(
[ [("system", system), ("human", "{input}"),]
("system", "You are a world class algorithm for extracting information in structured formats."),
("human", "Use the given format to extract information from the following input: {input}"),
("human", "Tip: Make sure to answer in the correct format"),
]
) )
chain = create_structured_output_runnable(Dog, llm, prompt, mode="openai-functions") chain = prompt | structured_llm
chain.invoke({"input": "Harry was a chubby brown beagle who loved chicken"}) chain.invoke({"input": "Harry was a chubby brown beagle who loved chicken"})
# -> Dog(name="Harry", color="brown", fav_food="chicken") # -> Dog(name="Harry", color="brown", fav_food="chicken")
OpenAI json response format example: OpenAI json response format example:
.. code-block:: python .. code-block:: python
from typing import Optional from typing import Optional
from langchain.chains.structured_output import create_structured_output_runnable from langchain.chains import create_structured_output_runnable
from langchain_openai import ChatOpenAI from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field from langchain_core.pydantic_v1 import BaseModel, Field
@ -194,32 +207,30 @@ def create_structured_output_runnable(
fav_food: Optional[str] = Field(None, description="The dog's favorite food") fav_food: Optional[str] = Field(None, description="The dog's favorite food")
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0) llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
structured_llm = create_structured_output_runnable(Dog, llm, mode="openai-json")
system = '''You are a world class assistant for extracting information in structured JSON formats. \ system = '''You are a world class assistant for extracting information in structured JSON formats. \
Extract a valid JSON blob from the user input that matches the following JSON Schema: Extract a valid JSON blob from the user input that matches the following JSON Schema:
{output_schema}''' {output_schema}'''
prompt = ChatPromptTemplate.from_messages( prompt = ChatPromptTemplate.from_messages(
[ [("system", system), ("human", "{input}"),]
("system", system),
("human", "{input}"),
]
) )
chain = create_structured_output_runnable(Dog, llm, prompt, mode="openai-json") chain = prompt | structured_llm
chain.invoke({"input": "Harry was a chubby brown beagle who loved chicken"}) chain.invoke({"input": "Harry was a chubby brown beagle who loved chicken"})
""" # noqa: E501 """ # noqa: E501
if mode == "openai-functions": if mode == "openai-functions":
return _create_openai_functions_structured_output_runnable( return _create_openai_functions_structured_output_runnable(
output_schema, output_schema,
llm, llm,
prompt, prompt=prompt,
output_parser=output_parser, output_parser=output_parser,
enforce_single_function_usage=enforce_single_function_usage, enforce_single_function_usage=enforce_single_function_usage,
**kwargs, **kwargs,
) )
elif mode == "openai-json": elif mode == "openai-json":
return _create_openai_json_runnable( return _create_openai_json_runnable(
output_schema, llm, prompt, output_parser=output_parser, **kwargs output_schema, llm, prompt=prompt, output_parser=output_parser, **kwargs
) )
else: else:
raise ValueError( raise ValueError(
@ -263,7 +274,7 @@ def get_openai_output_parser(
def _create_openai_json_runnable( def _create_openai_json_runnable(
output_schema: Union[Dict[str, Any], Type[BaseModel]], output_schema: Union[Dict[str, Any], Type[BaseModel]],
llm: Runnable, llm: Runnable,
prompt: BasePromptTemplate, prompt: Optional[BasePromptTemplate] = None,
*, *,
output_parser: Optional[Union[BaseOutputParser, BaseGenerationOutputParser]] = None, output_parser: Optional[Union[BaseOutputParser, BaseGenerationOutputParser]] = None,
) -> Runnable: ) -> Runnable:
@ -277,17 +288,20 @@ def _create_openai_json_runnable(
output_parser = output_parser or JsonOutputParser() output_parser = output_parser or JsonOutputParser()
schema_as_dict = output_schema schema_as_dict = output_schema
if "output_schema" in prompt.input_variables:
prompt = prompt.partial(output_schema=json.dumps(schema_as_dict, indent=2))
llm = llm.bind(response_format={"type": "json_object"}) llm = llm.bind(response_format={"type": "json_object"})
return prompt | llm | output_parser if prompt:
if "output_schema" in prompt.input_variables:
prompt = prompt.partial(output_schema=json.dumps(schema_as_dict, indent=2))
return prompt | llm | output_parser
else:
return llm | output_parser
def _create_openai_functions_structured_output_runnable( def _create_openai_functions_structured_output_runnable(
output_schema: Union[Dict[str, Any], Type[BaseModel]], output_schema: Union[Dict[str, Any], Type[BaseModel]],
llm: Runnable, llm: Runnable,
prompt: BasePromptTemplate, prompt: Optional[BasePromptTemplate] = None,
*, *,
output_parser: Optional[Union[BaseOutputParser, BaseGenerationOutputParser]] = None, output_parser: Optional[Union[BaseOutputParser, BaseGenerationOutputParser]] = None,
**kwargs: Any, **kwargs: Any,
@ -315,7 +329,7 @@ def _create_openai_functions_structured_output_runnable(
return create_openai_fn_runnable( return create_openai_fn_runnable(
[function], [function],
llm, llm,
prompt, prompt=prompt,
output_parser=output_parser, output_parser=output_parser,
**kwargs, **kwargs,
) )

View File

@ -62,6 +62,7 @@ EXPECTED_ALL = [
"create_history_aware_retriever", "create_history_aware_retriever",
"create_retrieval_chain", "create_retrieval_chain",
"load_summarize_chain", "load_summarize_chain",
"create_structured_output_runnable",
] ]