multiple: structured output tracing standard metadata (#29421)

Co-authored-by: Chester Curme <chester.curme@gmail.com>
Author: Erick Friis
Date: 2025-01-29 14:00:26 -08:00
Committed by: GitHub
Parent: 284c935b08
Commit: 8f95da4eb1
9 changed files with 288 additions and 28 deletions
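This adds standard tracing metadata for structured output: `with_structured_output` (and the provider integrations touched by this PR) pass a `structured_output_format` kwarg through `bind_tools`, and the chat model entry points pop that kwarg and record it, with the schema converted to OpenAI tool format, in the run's inheritable metadata, so tracers (e.g. LangSmith) see a standard description of the requested output schema. A minimal usage sketch, not part of the diff (it assumes `langchain-openai` is installed and that the integration forwards the new kwarg, which the provider-side changes in this PR are intended to do):

```python
# Usage sketch only; the model name and prompt are examples.
from pydantic import BaseModel, Field
from langchain_openai import ChatOpenAI  # assumption: langchain-openai installed

class Joke(BaseModel):
    """A joke to tell the user."""
    setup: str = Field(description="The setup of the joke")
    punchline: str = Field(description="The punchline of the joke")

structured_llm = ChatOpenAI(model="gpt-4o-mini").with_structured_output(Joke)
# At invocation time the chat model records a "structured_output_format" entry
# (the bound kwargs plus the schema converted to OpenAI tool format) in the
# run's inheritable metadata, where tracers can pick it up.
joke = structured_llm.invoke("Tell me a joke about tracing")
```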


@@ -365,11 +365,28 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         else:
             config = ensure_config(config)
             messages = self._convert_input(input).to_messages()
+            structured_output_format = kwargs.pop("structured_output_format", None)
+            if structured_output_format:
+                try:
+                    structured_output_format_dict = {
+                        "structured_output_format": {
+                            "kwargs": structured_output_format.get("kwargs", {}),
+                            "schema": convert_to_openai_tool(
+                                structured_output_format["schema"]
+                            ),
+                        }
+                    }
+                except ValueError:
+                    structured_output_format_dict = {}
+            else:
+                structured_output_format_dict = {}
             params = self._get_invocation_params(stop=stop, **kwargs)
             options = {"stop": stop, **kwargs}
             inheritable_metadata = {
                 **(config.get("metadata") or {}),
                 **self._get_ls_params(stop=stop, **kwargs),
+                **structured_output_format_dict,
             }
             callback_manager = CallbackManager.configure(
                 config.get("callbacks"),
@@ -441,11 +458,29 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         config = ensure_config(config)
         messages = self._convert_input(input).to_messages()
+        structured_output_format = kwargs.pop("structured_output_format", None)
+        if structured_output_format:
+            try:
+                structured_output_format_dict = {
+                    "structured_output_format": {
+                        "kwargs": structured_output_format.get("kwargs", {}),
+                        "schema": convert_to_openai_tool(
+                            structured_output_format["schema"]
+                        ),
+                    }
+                }
+            except ValueError:
+                structured_output_format_dict = {}
+        else:
+            structured_output_format_dict = {}
         params = self._get_invocation_params(stop=stop, **kwargs)
         options = {"stop": stop, **kwargs}
         inheritable_metadata = {
             **(config.get("metadata") or {}),
             **self._get_ls_params(stop=stop, **kwargs),
+            **structured_output_format_dict,
         }
         callback_manager = AsyncCallbackManager.configure(
             config.get("callbacks"),
@@ -606,11 +641,28 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
             An LLMResult, which contains a list of candidate Generations for each input
                 prompt and additional model provider-specific output.
         """
+        structured_output_format = kwargs.pop("structured_output_format", None)
+        if structured_output_format:
+            try:
+                structured_output_format_dict = {
+                    "structured_output_format": {
+                        "kwargs": structured_output_format.get("kwargs", {}),
+                        "schema": convert_to_openai_tool(
+                            structured_output_format["schema"]
+                        ),
+                    }
+                }
+            except ValueError:
+                structured_output_format_dict = {}
+        else:
+            structured_output_format_dict = {}
         params = self._get_invocation_params(stop=stop, **kwargs)
         options = {"stop": stop}
         inheritable_metadata = {
             **(metadata or {}),
             **self._get_ls_params(stop=stop, **kwargs),
+            **structured_output_format_dict,
         }
         callback_manager = CallbackManager.configure(
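Because the entry is merged into `inheritable_metadata`, it is visible to any callback handler, not only the LangSmith tracer. A sketch of reading it from a custom handler (the handler class is hypothetical; the `on_chat_model_start` signature is the existing `langchain_core` one):

```python
from typing import Any, Optional
from uuid import UUID
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import BaseMessage

class StructuredOutputLogger(BaseCallbackHandler):
    def on_chat_model_start(
        self,
        serialized: dict[str, Any],
        messages: list[list[BaseMessage]],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[list[str]] = None,
        metadata: Optional[dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        fmt = (metadata or {}).get("structured_output_format")
        if fmt:
            # "schema" has already been converted to OpenAI tool format here.
            print("structured output requested:", fmt["schema"]["function"]["name"])

# e.g. structured_llm.invoke("...", config={"callbacks": [StructuredOutputLogger()]})
```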
@@ -697,11 +749,28 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
             An LLMResult, which contains a list of candidate Generations for each input
                 prompt and additional model provider-specific output.
         """
+        structured_output_format = kwargs.pop("structured_output_format", None)
+        if structured_output_format:
+            try:
+                structured_output_format_dict = {
+                    "structured_output_format": {
+                        "kwargs": structured_output_format.get("kwargs", {}),
+                        "schema": convert_to_openai_tool(
+                            structured_output_format["schema"]
+                        ),
+                    }
+                }
+            except ValueError:
+                structured_output_format_dict = {}
+        else:
+            structured_output_format_dict = {}
         params = self._get_invocation_params(stop=stop, **kwargs)
         options = {"stop": stop}
         inheritable_metadata = {
             **(metadata or {}),
             **self._get_ls_params(stop=stop, **kwargs),
+            **structured_output_format_dict,
         }
         callback_manager = AsyncCallbackManager.configure(
@@ -1240,7 +1309,12 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         if self.bind_tools is BaseChatModel.bind_tools:
             msg = "with_structured_output is not implemented for this model."
             raise NotImplementedError(msg)
-        llm = self.bind_tools([schema], tool_choice="any")
+        llm = self.bind_tools(
+            [schema],
+            tool_choice="any",
+            structured_output_format={"kwargs": {}, "schema": schema},
+        )
         if isinstance(schema, type) and is_basemodel_subclass(schema):
             output_parser: OutputParserLike = PydanticToolsParser(
                 tools=[cast(TypeBaseModel, schema)], first_tool_only=True
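The default `with_structured_output` now hands the raw schema to `bind_tools` alongside `tool_choice="any"`, so the schema the output parser enforces is the same one recorded in the trace metadata. A hedged sketch of roughly the chain this default path assembles for a Pydantic schema (`ChatOpenAI` stands in for "an integration with a `bind_tools` implementation"; its own `with_structured_output` override may route things differently, and `include_raw` handling is omitted):

```python
# Hand-assembled approximation of the default with_structured_output path above,
# for illustration only. Assumes langchain-openai is installed; model name is an example.
from pydantic import BaseModel
from langchain_core.output_parsers.openai_tools import PydanticToolsParser
from langchain_openai import ChatOpenAI

class Person(BaseModel):
    """Information about a person."""
    name: str
    age: int

llm = ChatOpenAI(model="gpt-4o-mini").bind_tools(
    [Person],
    tool_choice="any",
    structured_output_format={"kwargs": {}, "schema": Person},
)
chain = llm | PydanticToolsParser(tools=[Person], first_tool_only=True)
person = chain.invoke("Alan Turing was born in 1912.")  # -> a Person instance
```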