openai[patch]: enable Azure structured output, parallel_tool_calls=False, tool_choice=required (#26599)

response_format=json_schema, tool_choice=required, and parallel_tool_calls are all supported for gpt-4o on Azure.

This commit is contained in:
parent bb40a0fb32
commit e1e4f88b3e
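
For context, here is how the newly supported parameters can be exercised end to end. This is a minimal sketch, assuming an Azure gpt-4o deployment, the usual AZURE_OPENAI_API_KEY / AZURE_OPENAI_ENDPOINT environment variables, and a langchain-openai version whose base class supports method="json_schema"; the deployment and API-version strings are placeholders:

    from pydantic import BaseModel

    from langchain_openai import AzureChatOpenAI


    class Answer(BaseModel):
        """An answer with a short justification."""

        answer: str
        justification: str


    llm = AzureChatOpenAI(azure_deployment="gpt-4o", api_version="2024-08-01-preview")

    # response_format=json_schema, via structured output:
    structured_llm = llm.with_structured_output(Answer, method="json_schema")

    # tool_choice="required" and parallel_tool_calls=False now pass through to Azure:
    bound_llm = llm.bind_tools([Answer], tool_choice="required", parallel_tool_calls=False)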
langchain_openai/chat_models/azure.py:

@@ -4,37 +4,13 @@ from __future__ import annotations

 import logging
 import os
 from operator import itemgetter
-from typing import (
-    Any,
-    Callable,
-    Dict,
-    List,
-    Literal,
-    Optional,
-    Sequence,
-    Type,
-    TypedDict,
-    TypeVar,
-    Union,
-    overload,
-)
+from typing import Any, Callable, Dict, List, Optional, Type, TypedDict, TypeVar, Union

 import openai
 from langchain_core.language_models import LanguageModelInput
 from langchain_core.language_models.chat_models import LangSmithParams
 from langchain_core.messages import BaseMessage
 from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
 from langchain_core.output_parsers.base import OutputParserLike
 from langchain_core.output_parsers.openai_tools import (
     JsonOutputKeyToolsParser,
     PydanticToolsParser,
 )
 from langchain_core.outputs import ChatResult
 from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
 from langchain_core.tools import BaseTool
 from langchain_core.utils import from_env, secret_from_env
 from langchain_core.utils.function_calling import convert_to_openai_tool
 from langchain_core.utils.pydantic import is_basemodel_subclass
 from pydantic import BaseModel, Field, SecretStr, model_validator
 from typing_extensions import Self
@@ -538,6 +514,7 @@ class AzureChatOpenAI(BaseChatOpenAI):
         default_factory=from_env("OPENAI_API_TYPE", default="azure")
     )
     """Legacy, for openai<1.0.0 support."""
+
     validate_base_url: bool = True
     """If legacy arg openai_api_base is passed in, try to infer if it is a base_url or
     azure_endpoint and update client params accordingly.
@@ -550,6 +527,28 @@ class AzureChatOpenAI(BaseChatOpenAI):
         Used for tracing and token counting. Does NOT affect completion.
     """
+
+    disabled_params: Optional[Dict[str, Any]] = Field(default=None)
+    """Parameters of the OpenAI client or chat.completions endpoint that should be
+    disabled for the given model.
+
+    Should be specified as ``{"param": None | ['val1', 'val2']}`` where the key is the
+    parameter and the value is either None, meaning that parameter should never be
+    used, or it's a list of disabled values for the parameter.
+
+    For example, older models may not support the 'parallel_tool_calls' parameter at
+    all, in which case ``disabled_params={"parallel_tool_calls": None}`` can be passed
+    in.
+
+    If a parameter is disabled then it will not be used by default in any methods, e.g.
+    in :meth:`~langchain_openai.chat_models.azure.AzureChatOpenAI.with_structured_output`.
+    However this does not prevent a user from directly passing in the parameter during
+    invocation.
+
+    By default, unless ``model_name="gpt-4o"`` is specified, 'parallel_tool_calls' will
+    be disabled.
+    """

     @classmethod
     def get_lc_namespace(cls) -> List[str]:
         """Get the namespace of the langchain object."""
@@ -574,6 +573,13 @@ class AzureChatOpenAI(BaseChatOpenAI):
         if self.n > 1 and self.streaming:
             raise ValueError("n must be 1 when streaming.")

+        if self.disabled_params is None:
+            # As of 09-17-2024 'parallel_tool_calls' param is only supported for gpt-4o.
+            if self.model_name and self.model_name == "gpt-4o":
+                pass
+            else:
+                self.disabled_params = {"parallel_tool_calls": None}
+
         # Check OPENAI_ORGANIZATION for backwards compatibility.
         self.openai_organization = (
             self.openai_organization
@@ -634,311 +640,6 @@ class AzureChatOpenAI(BaseChatOpenAI):
         self.async_client = self.root_async_client.chat.completions
         return self

-    def bind_tools(
-        self,
-        tools: Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]],
-        *,
-        tool_choice: Optional[
-            Union[dict, str, Literal["auto", "none", "required", "any"], bool]
-        ] = None,
-        **kwargs: Any,
-    ) -> Runnable[LanguageModelInput, BaseMessage]:
-        # As of 05/2024 Azure OpenAI doesn't support tool_choice="required".
-        # TODO: Update this condition once tool_choice="required" is supported.
-        if tool_choice in ("any", "required", True):
-            if len(tools) > 1:
-                raise ValueError(
-                    f"Azure OpenAI does not currently support {tool_choice=}. Should "
-                    f"be one of 'auto', 'none', or the name of the tool to call."
-                )
-            else:
-                tool_choice = convert_to_openai_tool(tools[0])["function"]["name"]
-        return super().bind_tools(tools, tool_choice=tool_choice, **kwargs)
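
With this override removed, AzureChatOpenAI inherits the base implementation, so tool_choice="required" with more than one tool is simply forwarded instead of raising. A hedged sketch of what now works (the tools and deployment strings are illustrative):

    from langchain_openai import AzureChatOpenAI


    def add(a: int, b: int) -> int:
        """Add two integers."""
        return a + b


    def multiply(a: int, b: int) -> int:
        """Multiply two integers."""
        return a * b


    llm = AzureChatOpenAI(azure_deployment="gpt-4o", api_version="2024-08-01-preview")

    # Previously this raised ValueError on Azure when more than one tool was bound;
    # it now behaves the same as ChatOpenAI.
    bound = llm.bind_tools([add, multiply], tool_choice="required")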
-
-    # TODO: Fix typing.
-    @overload  # type: ignore[override]
-    def with_structured_output(
-        self,
-        schema: Optional[_DictOrPydanticClass] = None,
-        *,
-        method: Literal["function_calling", "json_mode"] = "function_calling",
-        include_raw: Literal[True] = True,
-        **kwargs: Any,
-    ) -> Runnable[LanguageModelInput, _AllReturnType]: ...
-
-    @overload
-    def with_structured_output(
-        self,
-        schema: Optional[_DictOrPydanticClass] = None,
-        *,
-        method: Literal["function_calling", "json_mode"] = "function_calling",
-        include_raw: Literal[False] = False,
-        **kwargs: Any,
-    ) -> Runnable[LanguageModelInput, _DictOrPydantic]: ...
-
-    def with_structured_output(
-        self,
-        schema: Optional[_DictOrPydanticClass] = None,
-        *,
-        method: Literal["function_calling", "json_mode"] = "function_calling",
-        include_raw: bool = False,
-        **kwargs: Any,
-    ) -> Runnable[LanguageModelInput, _DictOrPydantic]:
-        """Model wrapper that returns outputs formatted to match the given schema.
-
-        Args:
-            schema:
-                The output schema. Can be passed in as:
-                    - an OpenAI function/tool schema,
-                    - a JSON Schema,
-                    - a TypedDict class,
-                    - or a Pydantic class.
-                If ``schema`` is a Pydantic class then the model output will be a
-                Pydantic instance of that class, and the model-generated fields will be
-                validated by the Pydantic class. Otherwise the model output will be a
-                dict and will not be validated. See :meth:`langchain_core.utils.function_calling.convert_to_openai_tool`
-                for more on how to properly specify types and descriptions of
-                schema fields when specifying a Pydantic or TypedDict class.
-            method:
-                The method for steering model generation, either "function_calling"
-                or "json_mode". If "function_calling" then the schema will be converted
-                to an OpenAI function and the returned model will make use of the
-                function-calling API. If "json_mode" then OpenAI's JSON mode will be
-                used. Note that if using "json_mode" then you must include instructions
-                for formatting the output into the desired schema into the model call.
-            include_raw:
-                If False then only the parsed structured output is returned. If
-                an error occurs during model output parsing it will be raised. If True
-                then both the raw model response (a BaseMessage) and the parsed model
-                response will be returned. If an error occurs during output parsing it
-                will be caught and returned as well. The final output is always a dict
-                with keys "raw", "parsed", and "parsing_error".
-
-        Returns:
-            A Runnable that takes the same inputs as a :class:`langchain_core.language_models.chat.BaseChatModel`.
-
-            If ``include_raw`` is False and ``schema`` is a Pydantic class, Runnable outputs
-            an instance of ``schema`` (i.e., a Pydantic object).
-
-            Otherwise, if ``include_raw`` is False then Runnable outputs a dict.
-
-            If ``include_raw`` is True, then Runnable outputs a dict with keys:
-                - ``"raw"``: BaseMessage
-                - ``"parsed"``: None if there was a parsing error, otherwise the type depends on the ``schema`` as described above.
-                - ``"parsing_error"``: Optional[BaseException]
-
-        Example: schema=Pydantic class, method="function_calling", include_raw=False:
-            .. code-block:: python
-
-                from typing import Optional
-
-                from langchain_openai import AzureChatOpenAI
-                from pydantic import BaseModel, Field
-
-
-                class AnswerWithJustification(BaseModel):
-                    '''An answer to the user question along with justification for the answer.'''
-
-                    answer: str
-                    # If we provide default values and/or descriptions for fields, these will be passed
-                    # to the model. This is an important part of improving a model's ability to
-                    # correctly return structured outputs.
-                    justification: Optional[str] = Field(
-                        default=None, description="A justification for the answer."
-                    )
-
-
-                llm = AzureChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
-                structured_llm = llm.with_structured_output(AnswerWithJustification)
-
-                structured_llm.invoke(
-                    "What weighs more a pound of bricks or a pound of feathers"
-                )
-
-                # -> AnswerWithJustification(
-                #     answer='They weigh the same',
-                #     justification='Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ.'
-                # )
-
-        Example: schema=Pydantic class, method="function_calling", include_raw=True:
-            .. code-block:: python
-
-                from langchain_openai import AzureChatOpenAI
-                from pydantic import BaseModel
-
-
-                class AnswerWithJustification(BaseModel):
-                    '''An answer to the user question along with justification for the answer.'''
-
-                    answer: str
-                    justification: str
-
-
-                llm = AzureChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
-                structured_llm = llm.with_structured_output(
-                    AnswerWithJustification, include_raw=True
-                )
-
-                structured_llm.invoke(
-                    "What weighs more a pound of bricks or a pound of feathers"
-                )
-                # -> {
-                #     'raw': AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_Ao02pnFYXD6GN1yzc0uXPsvF', 'function': {'arguments': '{"answer":"They weigh the same.","justification":"Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ."}', 'name': 'AnswerWithJustification'}, 'type': 'function'}]}),
-                #     'parsed': AnswerWithJustification(answer='They weigh the same.', justification='Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ.'),
-                #     'parsing_error': None
-                # }
-
-        Example: schema=TypedDict class, method="function_calling", include_raw=False:
-            .. code-block:: python
-
-                # IMPORTANT: If you are using Python <=3.8, you need to import Annotated
-                # from typing_extensions, not from typing.
-                from typing import Optional
-
-                from typing_extensions import Annotated, TypedDict
-
-                from langchain_openai import AzureChatOpenAI
-
-
-                class AnswerWithJustification(TypedDict):
-                    '''An answer to the user question along with justification for the answer.'''
-
-                    answer: str
-                    justification: Annotated[
-                        Optional[str], None, "A justification for the answer."
-                    ]
-
-
-                llm = AzureChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
-                structured_llm = llm.with_structured_output(AnswerWithJustification)
-
-                structured_llm.invoke(
-                    "What weighs more a pound of bricks or a pound of feathers"
-                )
-                # -> {
-                #     'answer': 'They weigh the same',
-                #     'justification': 'Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume and density of the two substances differ.'
-                # }
-
-        Example: schema=OpenAI function schema, method="function_calling", include_raw=False:
-            .. code-block:: python
-
-                from langchain_openai import AzureChatOpenAI
-
-                oai_schema = {
-                    'name': 'AnswerWithJustification',
-                    'description': 'An answer to the user question along with justification for the answer.',
-                    'parameters': {
-                        'type': 'object',
-                        'properties': {
-                            'answer': {'type': 'string'},
-                            'justification': {'description': 'A justification for the answer.', 'type': 'string'}
-                        },
-                        'required': ['answer']
-                    }
-                }
-
-                llm = AzureChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
-                structured_llm = llm.with_structured_output(oai_schema)
-
-                structured_llm.invoke(
-                    "What weighs more a pound of bricks or a pound of feathers"
-                )
-                # -> {
-                #     'answer': 'They weigh the same',
-                #     'justification': 'Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume and density of the two substances differ.'
-                # }
-
-        Example: schema=Pydantic class, method="json_mode", include_raw=True:
-            .. code-block::
-
-                from langchain_openai import AzureChatOpenAI
-                from pydantic import BaseModel
-
-                class AnswerWithJustification(BaseModel):
-                    answer: str
-                    justification: str
-
-                llm = AzureChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
-                structured_llm = llm.with_structured_output(
-                    AnswerWithJustification,
-                    method="json_mode",
-                    include_raw=True
-                )
-
-                structured_llm.invoke(
-                    "Answer the following question. "
-                    "Make sure to return a JSON blob with keys 'answer' and 'justification'.\n\n"
-                    "What's heavier a pound of bricks or a pound of feathers?"
-                )
-                # -> {
-                #     'raw': AIMessage(content='{\n    "answer": "They are both the same weight.",\n    "justification": "Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight." \n}'),
-                #     'parsed': AnswerWithJustification(answer='They are both the same weight.', justification='Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight.'),
-                #     'parsing_error': None
-                # }
-
-        Example: schema=None, method="json_mode", include_raw=True:
-            .. code-block::
-
-                structured_llm = llm.with_structured_output(method="json_mode", include_raw=True)
-
-                structured_llm.invoke(
-                    "Answer the following question. "
-                    "Make sure to return a JSON blob with keys 'answer' and 'justification'.\n\n"
-                    "What's heavier a pound of bricks or a pound of feathers?"
-                )
-                # -> {
-                #     'raw': AIMessage(content='{\n    "answer": "They are both the same weight.",\n    "justification": "Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight." \n}'),
-                #     'parsed': {
-                #         'answer': 'They are both the same weight.',
-                #         'justification': 'Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight.'
-                #     },
-                #     'parsing_error': None
-                # }
-        """  # noqa: E501
-        if kwargs:
-            raise ValueError(f"Received unsupported arguments {kwargs}")
-        is_pydantic_schema = _is_pydantic_class(schema)
-        if method == "function_calling":
-            if schema is None:
-                raise ValueError(
-                    "schema must be specified when method is 'function_calling'. "
-                    "Received None."
-                )
-            tool_name = convert_to_openai_tool(schema)["function"]["name"]
-            llm = self.bind_tools([schema], tool_choice=tool_name)
-            if is_pydantic_schema:
-                output_parser: OutputParserLike = PydanticToolsParser(
-                    tools=[schema],  # type: ignore[list-item]
-                    first_tool_only=True,  # type: ignore[list-item]
-                )
-            else:
-                output_parser = JsonOutputKeyToolsParser(
-                    key_name=tool_name, first_tool_only=True
-                )
-        elif method == "json_mode":
-            llm = self.bind(response_format={"type": "json_object"})
-            output_parser = (
-                PydanticOutputParser(pydantic_object=schema)  # type: ignore[arg-type]
-                if is_pydantic_schema
-                else JsonOutputParser()
-            )
-        else:
-            raise ValueError(
-                f"Unrecognized method argument. Expected one of 'function_calling' or "
-                f"'json_mode'. Received: '{method}'"
-            )
-
-        if include_raw:
-            parser_assign = RunnablePassthrough.assign(
-                parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
-            )
-            parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
-            parser_with_fallback = parser_assign.with_fallbacks(
-                [parser_none], exception_key="parsing_error"
-            )
-            return RunnableMap(raw=llm) | parser_with_fallback
-        else:
-            return llm | output_parser

     @property
     def _identifying_params(self) -> Dict[str, Any]:
         """Get the identifying parameters."""
@@ -980,6 +681,8 @@ class AzureChatOpenAI(BaseChatOpenAI):
         response: Union[dict, openai.BaseModel],
         generation_info: Optional[Dict] = None,
     ) -> ChatResult:
+        chat_result = super()._create_chat_result(response, generation_info)
+
         if not isinstance(response, dict):
             response = response.model_dump()
         for res in response["choices"]:
@@ -988,7 +691,6 @@ class AzureChatOpenAI(BaseChatOpenAI):
                     "Azure has not provided the response due to a content filter "
                     "being triggered"
                 )
-        chat_result = super()._create_chat_result(response, generation_info)

         if "model" in response:
             model = response["model"]
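
Reordering the super() call does not change the caller-facing contract: a completion suppressed by Azure's content filter still surfaces as a ValueError. A defensive calling pattern (the prompt and deployment strings are illustrative):

    from langchain_openai import AzureChatOpenAI

    llm = AzureChatOpenAI(azure_deployment="gpt-4o", api_version="2024-08-01-preview")

    try:
        msg = llm.invoke("a prompt that trips Azure's content filter")
    except ValueError as err:
        # e.g. "Azure has not provided the response due to a content filter being triggered"
        print(f"Blocked by content filtering: {err}")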
langchain_openai/chat_models/base.py:

@@ -454,6 +454,23 @@ class BaseChatOpenAI(BaseChatModel):
     making requests to OpenAI compatible APIs, such as vLLM."""
     include_response_headers: bool = False
     """Whether to include response headers in the output message response_metadata."""
+    disabled_params: Optional[Dict[str, Any]] = Field(default=None)
+    """Parameters of the OpenAI client or chat.completions endpoint that should be
+    disabled for the given model.
+
+    Should be specified as ``{"param": None | ['val1', 'val2']}`` where the key is the
+    parameter and the value is either None, meaning that parameter should never be
+    used, or it's a list of disabled values for the parameter.
+
+    For example, older models may not support the 'parallel_tool_calls' parameter at
+    all, in which case ``disabled_params={"parallel_tool_calls": None}`` can be passed
+    in.
+
+    If a parameter is disabled then it will not be used by default in any methods, e.g.
+    in :meth:`~langchain_openai.chat_models.base.ChatOpenAI.with_structured_output`.
+    However this does not prevent a user from directly passing in the parameter during
+    invocation.
+    """

     model_config = ConfigDict(populate_by_name=True)

@@ -1401,12 +1418,11 @@ class BaseChatOpenAI(BaseChatModel):
                     "Received None."
                 )
             tool_name = convert_to_openai_tool(schema)["function"]["name"]
-            llm = self.bind_tools(
-                [schema],
-                tool_choice=tool_name,
-                parallel_tool_calls=False,
-                strict=strict,
-            )
+            bind_kwargs = self._filter_disabled_params(
+                tool_choice=tool_name, parallel_tool_calls=False, strict=strict
+            )
+
+            llm = self.bind_tools([schema], **bind_kwargs)
             if is_pydantic_schema:
                 output_parser: Runnable = PydanticToolsParser(
                     tools=[schema],  # type: ignore[list-item]
@@ -1456,6 +1472,21 @@ class BaseChatOpenAI(BaseChatModel):
         else:
             return llm | output_parser

+    def _filter_disabled_params(self, **kwargs: Any) -> Dict[str, Any]:
+        if not self.disabled_params:
+            return kwargs
+        filtered = {}
+        for k, v in kwargs.items():
+            # Skip param
+            if k in self.disabled_params and (
+                self.disabled_params[k] is None or v in self.disabled_params[k]
+            ):
+                continue
+            # Keep param
+            else:
+                filtered[k] = v
+        return filtered
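
To make the semantics concrete, here is a standalone sketch that mirrors the method's filtering logic (all names are local to the example):

    disabled = {"parallel_tool_calls": None, "tool_choice": ["required"]}

    def filter_disabled(disabled_params, **kwargs):
        # Mirrors _filter_disabled_params: drop a kwarg when the param is disabled
        # outright (mapped to None) or when its value is in the disabled list.
        if not disabled_params:
            return kwargs
        return {
            k: v
            for k, v in kwargs.items()
            if not (
                k in disabled_params
                and (disabled_params[k] is None or v in disabled_params[k])
            )
        }

    print(filter_disabled(disabled, tool_choice="auto", parallel_tool_calls=False))
    # -> {'tool_choice': 'auto'}   ('parallel_tool_calls' is disabled for all values)
    print(filter_disabled(disabled, tool_choice="required"))
    # -> {}                        (the specific value "required" is disabled)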


 class ChatOpenAI(BaseChatOpenAI):
     """OpenAI chat model integration.

@@ -2114,7 +2145,7 @@ def _oai_structured_outputs_parser(ai_msg: AIMessage) -> PydanticBaseModel:
     else:
         raise ValueError(
             "Structured Output response does not have a 'parsed' field nor a 'refusal' "
-            "field."
+            f"field. Received message:\n\n{ai_msg}"
         )
AzureChatOpenAI standard integration tests:

@@ -9,16 +9,8 @@ from langchain_standard_tests.integration_tests import ChatModelIntegrationTests

 from langchain_openai import AzureChatOpenAI

-OPENAI_API_VERSION = os.environ.get("AZURE_OPENAI_API_VERSION", "")
-OPENAI_API_BASE = os.environ.get("AZURE_OPENAI_API_BASE", "")
-OPENAI_API_KEY = os.environ.get("AZURE_OPENAI_API_KEY", "")
-DEPLOYMENT_NAME = os.environ.get(
-    "AZURE_OPENAI_DEPLOYMENT_NAME",
-    os.environ.get("AZURE_OPENAI_CHAT_DEPLOYMENT_NAME", ""),
-)
-
-class TestOpenAIStandard(ChatModelIntegrationTests):
+class TestAzureOpenAIStandard(ChatModelIntegrationTests):
     @property
     def chat_model_class(self) -> Type[BaseChatModel]:
         return AzureChatOpenAI
@@ -26,10 +18,30 @@ class TestOpenAIStandard(ChatModelIntegrationTests):
     @property
     def chat_model_params(self) -> dict:
         return {
-            "deployment_name": DEPLOYMENT_NAME,
-            "openai_api_version": OPENAI_API_VERSION,
-            "azure_endpoint": OPENAI_API_BASE,
-            "api_key": OPENAI_API_KEY,
+            "deployment_name": os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT_NAME"],
+            "model": "gpt-4o",
         }

+    @property
+    def supports_image_inputs(self) -> bool:
+        return True
+
+    @pytest.mark.xfail(reason="Not yet supported.")
+    def test_usage_metadata_streaming(self, model: BaseChatModel) -> None:
+        super().test_usage_metadata_streaming(model)
+
+
+class TestAzureOpenAIStandardLegacy(ChatModelIntegrationTests):
+    """Test a legacy model."""
+
+    @property
+    def chat_model_class(self) -> Type[BaseChatModel]:
+        return AzureChatOpenAI
+
+    @property
+    def chat_model_params(self) -> dict:
+        return {
+            "deployment_name": os.environ["AZURE_OPENAI_LEGACY_CHAT_DEPLOYMENT_NAME"]
+        }
+
+    @pytest.mark.xfail(reason="Not yet supported.")
Unit-test serialization snapshot:

@@ -10,6 +10,9 @@
     'kwargs': dict({
       'azure_endpoint': 'https://test.azure.com',
       'deployment_name': 'test',
+      'disabled_params': dict({
+        'parallel_tool_calls': None,
+      }),
       'max_retries': 2,
       'max_tokens': 100,
       'n': 1,