diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py
index 3c6a9ba6871..9e6d24c9b6e 100644
--- a/libs/partners/openai/langchain_openai/chat_models/base.py
+++ b/libs/partners/openai/langchain_openai/chat_models/base.py
@@ -103,6 +103,7 @@ from pydantic.v1 import BaseModel as BaseModelV1
 from typing_extensions import Self
 
 if TYPE_CHECKING:
+    from openai.types.images_response import ImagesResponse
     from openai.types.responses import Response
 
 logger = logging.getLogger(__name__)
@@ -114,6 +115,10 @@ global_ssl_context = ssl.create_default_context(cafile=certifi.where())
 _FUNCTION_CALL_IDS_MAP_KEY = "__openai_function_call_ids__"
 
 
+def _b64str_to_bytes(base64_str: str) -> bytes:
+    return base64.b64decode(base64_str.encode("utf-8"))
+
+
 def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
     """Convert a dictionary to a LangChain message.
 
@@ -410,6 +415,52 @@ def _handle_openai_bad_request(e: openai.BadRequestError) -> None:
         raise
 
 
+_MessageContent = Union[str, list[Union[str, dict]]]
+
+
+def _get_image_bytes_from_content(
+    content: _MessageContent,
+) -> list[tuple[str, bytes, str]]:
+    """Collect (name, bytes, mime type) tuples for base64 images in message content."""
+    images = []
+    image_count = 0
+    if isinstance(content, list):
+        for block in content:
+            if isinstance(block, dict) and block.get("type") == "image_url":
+                if (
+                    (image_url := block.get("image_url"))
+                    and isinstance(image_url, dict)
+                    and (url := image_url.get("url"))
+                ):
+                    base64_regex = (
+                        r"^data:(?P<media_type>image/.+);base64,(?P<data>.+)$"
+                    )
+                    base64_match = re.match(base64_regex, url)
+                    if base64_match:
+                        images.append(
+                            (
+                                f"image-{image_count:03}",
+                                _b64str_to_bytes(base64_match.group("data")),
+                                base64_match.group("media_type"),
+                            )
+                        )
+                        image_count += 1
+            elif (
+                isinstance(block, dict)
+                and block.get("type") == "image"
+                and is_data_content_block(block)
+            ):
+                images.append(
+                    (
+                        f"image-{image_count:03}",
+                        _b64str_to_bytes(block["data"]),
+                        block["mime_type"],
+                    )
+                )
+                image_count += 1
+    return images
+
+
 class _FunctionCall(TypedDict):
     name: str
 
@@ -833,6 +884,23 @@ class BaseChatOpenAI(BaseChatModel):
             return source
         return self.stream_usage
 
+    def _should_stream(
+        self,
+        *,
+        async_api: bool,
+        run_manager: Optional[
+            Union[CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun]
+        ] = None,
+        **kwargs: Any,
+    ) -> bool:
+        """Determine if a given model call should hit the streaming API."""
+        if self._use_images_client:
+            # The Images API has no streaming surface, so never stream here.
+            return False
+        return super()._should_stream(
+            async_api=async_api, run_manager=run_manager, **kwargs
+        )
+
     def _stream(
         self,
         messages: list[BaseMessage],
@@ -910,6 +978,10 @@ class BaseChatOpenAI(BaseChatModel):
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> ChatResult:
+        if self._use_images_client:
+            return self._generate_images(
+                messages, stop=stop, run_manager=run_manager, **kwargs
+            )
         if self.streaming:
             stream_iter = self._stream(
                 messages, stop=stop, run_manager=run_manager, **kwargs
@@ -958,6 +1030,10 @@ class BaseChatOpenAI(BaseChatModel):
         else:
             return _use_responses_api(payload)
 
+    @property
+    def _use_images_client(self) -> bool:
+        return self.model_name.startswith("gpt-image")
+
     def _get_request_payload(
         self,
         input_: LanguageModelInput,
@@ -1196,6 +1272,28 @@ class BaseChatOpenAI(BaseChatModel):
         """Return type of chat model."""
         return "openai-chat"
 
+    def _generate_images(
+        self,
+        messages: list[BaseMessage],
+        stop: Optional[list[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        """Route gpt-image models to the Images API instead of Chat Completions."""
+        prompt = messages[-1].text()
+        images = []
+        for message in messages:
+            images.extend(_get_image_bytes_from_content(message.content))
+        if images:
+            # Input images present in the history: treat the call as an edit.
+            result: ImagesResponse = self.root_client.images.edit(
+                model=self.model_name, image=images, prompt=prompt, **kwargs
+            )
+        else:
+            result = self.root_client.images.generate(
+                model=self.model_name, prompt=prompt, **kwargs
+            )
+        output_message = AIMessage(content="", response_metadata=result.model_dump())
+        return ChatResult(generations=[ChatGeneration(message=output_message)])
+
     def _get_encoding_model(self) -> tuple[str, tiktoken.Encoding]:
         if self.tiktoken_model_name is not None:
             model = self.tiktoken_model_name
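For reviewers, a minimal usage sketch of the new code path. This is not part of the diff: it assumes access to a `gpt-image-1` model with `OPENAI_API_KEY` set, and that `ImagesResponse.model_dump()` surfaces the base64 payload under `data[0]["b64_json"]` (the default response format for gpt-image models); file names and prompts are illustrative.

```python
import base64

from langchain_core.messages import HumanMessage
from langchain_openai import ChatOpenAI

# The "gpt-image" model-name prefix flips _use_images_client on,
# which disables streaming and routes _generate to the Images API.
llm = ChatOpenAI(model="gpt-image-1")

# No image blocks in the input -> images.generate
result = llm.invoke("A watercolor painting of a lighthouse at dawn")
b64_image = result.response_metadata["data"][0]["b64_json"]
with open("lighthouse.png", "wb") as f:
    f.write(base64.b64decode(b64_image))

# A base64 image block in the input -> collected by
# _get_image_bytes_from_content and sent to images.edit
edited = llm.invoke(
    [
        HumanMessage(
            content=[
                {"type": "text", "text": "Make the sky stormy"},
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/png;base64,{b64_image}"},
                },
            ]
        )
    ]
)
```

Note the design choice this sketch exercises: because `_generate_images` scans the entire message history for base64 images, the same `ChatOpenAI` interface transparently switches between generation and editing based on whether any input images are present.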