Compare commits

...

3 Commits

Author SHA1 Message Date
Bagatur
a020da7390 fmt 2024-02-23 15:39:56 -08:00
Bagatur
088825634a fmt 2024-02-23 15:31:06 -08:00
Bagatur
6bd70375c1 rfc: with_structured_output List 2024-02-23 15:22:49 -08:00

View File

@@ -23,6 +23,8 @@ from typing import (
TypeVar,
Union,
cast,
get_args,
get_origin,
overload,
)
@@ -61,7 +63,13 @@ from langchain_core.output_parsers import (
)
from langchain_core.output_parsers.base import OutputParserLike
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
from langchain_core.pydantic_v1 import (
BaseModel,
Field,
SecretStr,
create_model,
root_validator,
)
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
from langchain_core.tools import BaseTool
from langchain_core.utils import (
@@ -887,12 +895,18 @@ class ChatOpenAI(BaseChatModel):
""" # noqa: E501
if kwargs:
raise ValueError(f"Received unsupported arguments {kwargs}")
is_pydantic_schema = _is_pydantic_class(schema)
is_pydantic_schema = isinstance(schema, type) and issubclass(schema, BaseModel)
is_list_type = get_origin(schema) is list
tool_schema = schema if not is_list_type else get_args(schema)[0]
if method == "function_calling":
llm = self.bind_tools([schema], tool_choice=True)
if is_pydantic_schema:
llm = self.bind_tools([tool_schema], tool_choice=not is_list_type)
if is_list_type:
output_parser: OutputParserLike = PydanticToolsParser(
tools=[schema], first_tool_only=True
tools=[tool_schema],
)
elif is_pydantic_schema:
output_parser = PydanticToolsParser(
tools=[tool_schema], first_tool_only=True
)
else:
key_name = convert_to_openai_tool(schema)["function"]["name"]
@@ -901,11 +915,16 @@ class ChatOpenAI(BaseChatModel):
)
elif method == "json_mode":
llm = self.bind(response_format={"type": "json_object"})
output_parser = (
PydanticOutputParser(pydantic_object=schema)
if is_pydantic_schema
else JsonOutputParser()
)
if is_list_type:
ListSchema = create_model("ListSchema", __root__=(schema, []))
# Handle case where array is nested under a key.
output_parser = PydanticOutputParser(
pydantic_object=ListSchema
).with_fallbacks([JsonOutputParser() | (lambda x: list(x.values())[0])])
elif is_pydantic_schema:
output_parser = PydanticOutputParser(pydantic_object=tool_schema)
else:
output_parser = JsonOutputParser()
else:
raise ValueError(
f"Unrecognized method argument. Expected one of 'function_calling' or "