diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py
index 139ab643e7a..22ef3ff09a3 100644
--- a/libs/partners/openai/langchain_openai/chat_models/base.py
+++ b/libs/partners/openai/langchain_openai/chat_models/base.py
@@ -12,6 +12,7 @@ import sys
 import warnings
 from functools import partial
 from io import BytesIO
+from json import JSONDecodeError
 from math import ceil
 from operator import itemgetter
 from typing import (
@@ -114,6 +115,8 @@ logger = logging.getLogger(__name__)
 # https://www.python-httpx.org/advanced/ssl/#configuring-client-instances
 global_ssl_context = ssl.create_default_context(cafile=certifi.where())
 
+_FUNCTION_CALL_IDS_MAP_KEY = "__openai_function_call_ids__"
+
 
 def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
     """Convert a dictionary to a LangChain message.
@@ -452,18 +455,6 @@ def _handle_openai_bad_request(e: openai.BadRequestError) -> None:
     raise
 
 
-def _is_builtin_tool(tool: dict) -> bool:
-    return "type" in tool and tool["type"] != "function"
-
-
-def _transform_payload_for_responses(payload: dict) -> dict:
-    updated_payload = payload.copy()
-    if messages := updated_payload.pop("messages"):
-        updated_payload["input"] = messages
-
-    return updated_payload
-
-
 class _FunctionCall(TypedDict):
     name: str
 
@@ -793,8 +784,7 @@ class BaseChatOpenAI(BaseChatModel):
     ) -> Iterator[ChatGenerationChunk]:
         kwargs["stream"] = True
         payload = self._get_request_payload(messages, stop=stop, **kwargs)
-        responses_payload = _transform_payload_for_responses(payload)
-        context_manager = self.root_client.responses.create(**responses_payload)
+        context_manager = self.root_client.responses.create(**payload)
 
         with context_manager as response:
             for chunk in response:
@@ -816,10 +806,7 @@ class BaseChatOpenAI(BaseChatModel):
     ) -> AsyncIterator[ChatGenerationChunk]:
         kwargs["stream"] = True
         payload = self._get_request_payload(messages, stop=stop, **kwargs)
-        responses_payload = _transform_payload_for_responses(payload)
-        context_manager = await self.root_async_client.responses.create(
-            **responses_payload
-        )
+        context_manager = await self.root_async_client.responses.create(**payload)
 
         async with context_manager as response:
             async for chunk in response:
@@ -926,15 +913,11 @@ class BaseChatOpenAI(BaseChatModel):
             raw_response = self.client.with_raw_response.create(**payload)
             response = raw_response.parse()
             generation_info = {"headers": dict(raw_response.headers)}
+        elif _use_response_api(payload):
+            response = self.root_client.responses.create(**payload)
+            return _construct_lc_result_from_response_api(response)
         else:
-            if "tools" in payload and any(
-                _is_builtin_tool(tool) for tool in payload["tools"]
-            ):
-                responses_payload = _transform_payload_for_responses(payload)
-                response = self.root_client.responses.create(**responses_payload)
-                return self._create_chat_result_responses(response, generation_info)
-            else:
-                response = self.client.create(**payload)
+            response = self.client.create(**payload)
         return self._create_chat_result(response, generation_info)
 
     def _get_request_payload(
@@ -948,11 +931,12 @@
         if stop is not None:
             kwargs["stop"] = stop
 
-        return {
-            "messages": [_convert_message_to_dict(m) for m in messages],
-            **self._default_params,
-            **kwargs,
-        }
+        payload = {**self._default_params, **kwargs}
+        if _use_response_api(payload):
+            payload["input"] = _construct_response_api_input(messages)
+        else:
+            payload["messages"] = [_convert_message_to_dict(m) for m in messages]
+        return payload
 
     def _create_chat_result(
         self,
@@ -1003,50 +987,6 @@
 
         return ChatResult(generations=generations, llm_output=llm_output)
 
-    def _create_chat_result_responses(
-        self, response: Response, generation_info: Optional[Dict] = None
-    ) -> ChatResult:
-        generations = []
-
-        if error := response.error:
-            raise ValueError(error)
-
-        token_usage = response.usage.model_dump() if response.usage else {}
-        generation_info = {}
-        content_blocks = []
-        for output in response.output:
-            if output.type == "message":
-                for content in output.content:
-                    if content.type == "output_text":
-                        block = {
-                            "type": "text",
-                            "text": content.text,
-                            "annotations": [
-                                annotation.model_dump()
-                                for annotation in content.annotations
-                            ],
-                        }
-                        content_blocks.append(block)
-                usage_metadata = _create_usage_metadata_responses(token_usage)
-                message = AIMessage(
-                    content=content_blocks,  # type: ignore[arg-type]
-                    id=output.id,
-                    usage_metadata=usage_metadata,
-                )
-                if output.status:
-                    generation_info["status"] = output.status
-                gen = ChatGeneration(message=message, generation_info=generation_info)
-                generations.append(gen)
-            else:
-                tool_output = output.model_dump()
-                if "tool_outputs" in generation_info:
-                    generation_info["tool_outputs"].append(tool_output)
-                else:
-                    generation_info["tool_outputs"] = [tool_output]
-        llm_output = {"model_name": response.model}
-
-        return ChatResult(generations=generations, llm_output=llm_output)
-
     async def _astream(
         self,
         messages: List[BaseMessage],
@@ -1147,17 +1087,11 @@ class BaseChatOpenAI(BaseChatModel):
             raw_response = await self.async_client.with_raw_response.create(**payload)
             response = raw_response.parse()
             generation_info = {"headers": dict(raw_response.headers)}
+        elif _use_response_api(payload):
+            response = await self.root_async_client.responses.create(**payload)
+            return _construct_lc_result_from_response_api(response)
         else:
-            if "tools" in payload and any(
-                _is_builtin_tool(tool) for tool in payload["tools"]
-            ):
-                responses_payload = _transform_payload_for_responses(payload)
-                response = await self.root_async_client.responses.create(
-                    **responses_payload
-                )
-                return self._create_chat_result_responses(response, generation_info)
-            else:
-                response = await self.async_client.create(**payload)
+            response = await self.async_client.create(**payload)
         return await run_in_executor(
             None, self._create_chat_result, response, generation_info
         )
@@ -2249,9 +2183,7 @@ class ChatOpenAI(BaseChatOpenAI):  # type: ignore[override]
         self, *args: Any, stream_usage: Optional[bool] = None, **kwargs: Any
     ) -> Iterator[ChatGenerationChunk]:
         """Set default stream_options."""
-        if "tools" in kwargs and any(
-            _is_builtin_tool(tool) for tool in kwargs["tools"]
-        ):
+        if _use_response_api(kwargs):
             return super()._stream_responses(*args, **kwargs)
         else:
             stream_usage = self._should_stream_usage(stream_usage, **kwargs)
@@ -2269,9 +2201,7 @@ class ChatOpenAI(BaseChatOpenAI):  # type: ignore[override]
         self, *args: Any, stream_usage: Optional[bool] = None, **kwargs: Any
     ) -> AsyncIterator[ChatGenerationChunk]:
         """Set default stream_options."""
-        if "tools" in kwargs and any(
-            _is_builtin_tool(tool) for tool in kwargs["tools"]
-        ):
+        if _use_response_api(kwargs):
             async for chunk in super()._astream_responses(*args, **kwargs):
                 yield chunk
         else:
@@ -2818,3 +2748,162 @@ def _create_usage_metadata_responses(oai_token_usage: dict) -> UsageMetadata:
             **{k: v for k, v in output_token_details.items() if v is not None}
         ),
     )
+
+
+def _is_builtin_tool(tool: dict) -> bool:
+    return "type" in tool and tool["type"] != "function"
+
+
+def _use_response_api(payload: dict) -> bool:
+    return "tools" in payload and any(
+        _is_builtin_tool(tool) for tool in payload["tools"]
+    )
+
+
+def _construct_response_api_input(messages: list[BaseMessage]) -> list:
+    input_ = []
+    for lc_msg in messages:
+        msg = _convert_message_to_dict(lc_msg)
+        if msg["role"] == "tool":
+            function_call_output = {
+                "type": "function_call_output",
+                "output": msg["content"],
+                "call_id": msg["tool_call_id"],
+            }
+            input_.append(function_call_output)
+        elif msg["role"] == "assistant":
+            if tool_calls := msg.pop("tool_calls", None):
+                if not lc_msg.additional_kwargs.get(_FUNCTION_CALL_IDS_MAP_KEY):
+                    raise ValueError(...)
+                function_call_ids = lc_msg.additional_kwargs[_FUNCTION_CALL_IDS_MAP_KEY]
+                for tool_call in tool_calls:
+                    function_call = {
+                        "type": "function_call",
+                        "name": tool_call["function"]["name"],
+                        "arguments": tool_call["function"]["arguments"],
+                        "call_id": tool_call["id"],
+                        "id": function_call_ids[tool_call["id"]],
+                    }
+                    input_.append(function_call)
+            if msg.get("content"):
+                input_.append(msg)
+        elif msg["role"] == "user":
+            if isinstance(msg["content"], list):
+                for block in msg["content"]:
+                    # chat api: {"type": "text", "text": "..."}
+                    # response api: {"type": "input_text", "text": "..."}
+                    if block["type"] == "text":
+                        block["type"] = "input_text"
+                    # chat api: {"type": "image_url", "image_url": {"url": "...", "detail": "..."}} # noqa: E501
+                    # response api: {"type": "input_image", "image_url": "...", "detail": "...", "file_id": "..."} # noqa: E501
+                    elif block["type"] == "image_url":
+                        block["type"] = "input_image"
+                        if isinstance(block.get("image_url"), dict):
+                            image_url = block.pop("image_url")
+                            block["image_url"] = image_url["url"]
+                            if image_url.get("detail"):
+                                block["detail"] = image_url["detail"]
+                    else:
+                        pass
+            input_.append(msg)
+        else:
+            input_.append(msg)
+
+    return input_
+
+
+def _construct_lc_result_from_response_api(response: Response) -> ChatResult:
+    """Construct ChatResult from OpenAI Response API response."""
+    if response.error:
+        raise ValueError(response.error)
+
+    response_metadata = {
+        k: v
+        for k, v in response.model_dump(exclude_none=True, mode="json").items()
+        if k
+        in (
+            "created_at",
+            "id",
+            "incomplete_details",
+            "metadata",
+            "object",
+            "status",
+            "user",
+            "model",
+        )
+    }
+    # for compatibility with chat completion calls.
+    response_metadata["model_name"] = response_metadata.get("model")
+    if response.usage:
+        usage_metadata = _create_usage_metadata_responses(response.usage.model_dump())
+    else:
+        usage_metadata = None
+
+    content_blocks = []
+    tool_calls = []
+    invalid_tool_calls = []
+    additional_kwargs: dict = {}
+    msg_id = None
+    for output in response.output:
+        if output.type == "message":
+            for content in output.content:
+                if content.type == "output_text":
+                    block = {
+                        "type": "text",
+                        "text": content.text,
+                        "annotations": [
+                            annotation.model_dump()
+                            for annotation in content.annotations
+                        ],
+                    }
+                    content_blocks.append(block)
+                if content.type == "refusal":
+                    additional_kwargs["refusal"] = content.refusal
+            msg_id = output.id
+        elif output.type == "function_call":
+            try:
+                args = json.loads(output.arguments, strict=False)
+                error = None
+            except JSONDecodeError as e:
+                args = output.arguments
+                error = str(e)
+            if error is None:
+                tool_call = {
+                    "type": "tool_call",
+                    "name": output.name,
+                    "args": args,
+                    "id": output.call_id,
+                }
+                tool_calls.append(tool_call)
+            else:
+                tool_call = {
+                    "type": "invalid_tool_call",
+                    "name": output.name,
+                    "args": args,
+                    "id": output.call_id,
+                    "error": error,
+                }
+                invalid_tool_calls.append(tool_call)
+            if _FUNCTION_CALL_IDS_MAP_KEY not in additional_kwargs:
+                additional_kwargs[_FUNCTION_CALL_IDS_MAP_KEY] = {}
+            additional_kwargs[_FUNCTION_CALL_IDS_MAP_KEY][output.call_id] = output.id
+        elif output.type == "reasoning":
+            additional_kwargs["reasoning"] = output.model_dump(
+                exclude_none=True, mode="json"
+            )
+        else:
+            tool_output = output.model_dump(exclude_none=True, mode="json")
+            if "tool_outputs" in additional_kwargs:
+                additional_kwargs["tool_outputs"].append(tool_output)
+            else:
+                additional_kwargs["tool_outputs"] = [tool_output]
+    message = AIMessage(
+        content=content_blocks or None,  # type: ignore[arg-type]
+        id=msg_id,
+        usage_metadata=usage_metadata,
+        response_metadata=response_metadata,
+        additional_kwargs=additional_kwargs,
+        tool_calls=tool_calls,
+        invalid_tool_calls=invalid_tool_calls,
+    )
+    return ChatResult(generations=[ChatGeneration(message=message)])
diff --git a/libs/partners/openai/tests/integration_tests/chat_models/test_base.py b/libs/partners/openai/tests/integration_tests/chat_models/test_base.py
index 465e37f4fc6..a3418ed761a 100644
--- a/libs/partners/openai/tests/integration_tests/chat_models/test_base.py
+++ b/libs/partners/openai/tests/integration_tests/chat_models/test_base.py
@@ -1258,14 +1258,14 @@ def _check_response(response: Optional[BaseMessage]) -> None:
     assert response.usage_metadata["output_tokens"] > 0
     assert response.usage_metadata["total_tokens"] > 0
     assert response.response_metadata["model_name"]
-    for tool_output in response.response_metadata["tool_outputs"]:
+    for tool_output in response.additional_kwargs["tool_outputs"]:
         assert tool_output["id"]
         assert tool_output["status"]
         assert tool_output["type"]
 
 
 def test_web_search() -> None:
-    llm = ChatOpenAI(model="gpt-4o")
+    llm = ChatOpenAI(model="gpt-4o-mini")
     response = llm.invoke(
         "What was a positive news story from today?",
         tools=[{"type": "web_search_preview"}],
@@ -1283,9 +1283,16 @@ def test_web_search() -> None:
         full = chunk if full is None else full + chunk
     _check_response(full)
 
+    response = llm.invoke(
+        "what about a negative one",
+        tools=[{"type": "web_search_preview"}],
+        response_id=response.response_metadata["id"],
+    )
+    _check_response(response)
+
 
 async def test_web_search_async() -> None:
-    llm = ChatOpenAI(model="gpt-4o")
+    llm = ChatOpenAI(model="gpt-4o-mini")
     response = await llm.ainvoke(
         "What was a positive news story from today?",
         tools=[{"type": "web_search_preview"}],
@@ -1307,7 +1314,7 @@ async def test_web_search_async() -> None:
 
 def test_file_search() -> None:
     pytest.skip()  # TODO: set up infra
-    llm = ChatOpenAI(model="gpt-4o")
+    llm = ChatOpenAI(model="gpt-4o-mini")
     tool = {
         "type": "file_search",
         "vector_store_ids": [os.environ["OPENAI_VECTOR_STORE_ID"]],
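
Not part of the diff: a minimal usage sketch of the Responses API routing introduced above, mirroring the test_web_search integration test. It assumes a valid OPENAI_API_KEY in the environment and uses the gpt-4o-mini model from the tests; passing any built-in (non-"function") tool is what makes _use_response_api send the request through root_client.responses instead of chat completions.

    from langchain_openai import ChatOpenAI

    llm = ChatOpenAI(model="gpt-4o-mini")

    # A built-in tool (type != "function") triggers the Responses API code path.
    response = llm.invoke(
        "What was a positive news story from today?",
        tools=[{"type": "web_search_preview"}],
    )

    # Text comes back as a list of content blocks with annotations, and
    # built-in tool outputs surface under additional_kwargs["tool_outputs"].
    print(response.content)
    print(response.additional_kwargs.get("tool_outputs"))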