mirror of https://github.com/hwchase17/langchain.git synced 2025-05-02 13:55:42 +00:00

support sync

Chester Curme 2025-04-23 15:50:19 -04:00
parent d4fc734250
commit b8c04d0ece


@@ -103,6 +103,7 @@ from pydantic.v1 import BaseModel as BaseModelV1
from typing_extensions import Self

if TYPE_CHECKING:
    from openai.types.images_response import ImagesResponse
    from openai.types.responses import Response

logger = logging.getLogger(__name__)
@@ -114,6 +115,10 @@ global_ssl_context = ssl.create_default_context(cafile=certifi.where())
_FUNCTION_CALL_IDS_MAP_KEY = "__openai_function_call_ids__"


def _b64str_to_bytes(base64_str: str) -> bytes:
    return base64.b64decode(base64_str.encode("utf-8"))


def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
    """Convert a dictionary to a LangChain message.
@@ -410,6 +415,52 @@ def _handle_openai_bad_request(e: openai.BadRequestError) -> None:
        raise


_MessageContent = Union[str, list[Union[str, dict]]]


def _get_image_bytes_from_content(
    content: _MessageContent,
) -> list[tuple[str, bytes, str]]:
    images = []
    image_count = 0
    if isinstance(content, list):
        for block in content:
            if isinstance(block, dict) and block.get("type") == "image_url":
                if (
                    (image_url := block.get("image_url"))
                    and isinstance(image_url, dict)
                    and (url := image_url.get("url"))
                ):
                    base64_regex = (
                        r"^data:(?P<media_type>image/.+);base64,(?P<data>.+)$"
                    )
                    base64_match = re.match(base64_regex, url)
                    if base64_match:
                        images.append(
                            (
                                f"image-{image_count:03}",
                                _b64str_to_bytes(base64_match.group("data")),
                                base64_match.group("media_type"),
                            )
                        )
                        image_count += 1
            elif (
                isinstance(block, dict)
                and block.get("type") == "image"
                and is_data_content_block(block)
            ):
                images.append(
                    (
                        f"image-{image_count:03}",
                        _b64str_to_bytes(block["data"]),
                        block["mime_type"],
                    )
                )
                image_count += 1
            else:
                pass
    return images


class _FunctionCall(TypedDict):
    name: str
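
A hedged sketch of the two content-block shapes _get_image_bytes_from_content looks for, and the (name, bytes, mime-type) tuples it would return. The field names follow the hunk above plus LangChain's standard data content blocks; the source_type key is an assumption about what is_data_content_block checks in recent langchain-core, since the hunk itself only reads type, data and mime_type. The payload is a stand-in, not a real image.

import base64

png_bytes = b"\x89PNG\r\n\x1a\n"  # stand-in payload, not a valid image
b64_data = base64.b64encode(png_bytes).decode("utf-8")

content = [
    {"type": "text", "text": "Give the fox a red scarf"},  # ignored by the helper
    # chat-completions style image_url block carrying a base64 data URL
    {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{b64_data}"}},
    # LangChain data content block, matched via is_data_content_block
    {
        "type": "image",
        "source_type": "base64",
        "data": b64_data,
        "mime_type": "image/png",
    },
]

# _get_image_bytes_from_content(content) would return something like:
#   [("image-000", png_bytes, "image/png"), ("image-001", png_bytes, "image/png")]
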
@@ -833,6 +884,23 @@ class BaseChatOpenAI(BaseChatModel):
                return source
        return self.stream_usage

    def _should_stream(
        self,
        *,
        async_api: bool,
        run_manager: Optional[
            Union[CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun]
        ] = None,
        **kwargs: Any,
    ) -> bool:
        """Determine if a given model call should hit the streaming API."""
        if self._use_images_client:
            return False
        else:
            return super()._should_stream(
                async_api=async_api, run_manager=run_manager, **kwargs
            )

    def _stream(
        self,
        messages: list[BaseMessage],
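
The practical effect of this _should_stream override is that image models never take the streaming path; BaseChatModel.stream() then falls back to a single non-streamed chunk. A rough usage sketch, assuming OPENAI_API_KEY is set and a langchain-openai build that includes this change, with OpenAI's gpt-image-1 model name:

from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-image-1")  # name starts with "gpt-image", so _use_images_client is True

# _should_stream returns False here, so stream() yields one chunk produced by a
# regular (non-streaming) call rather than token-by-token deltas.
for chunk in llm.stream("A watercolor fox in the snow"):
    print(type(chunk).__name__, list(chunk.response_metadata))
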
@@ -910,6 +978,10 @@ class BaseChatOpenAI(BaseChatModel):
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        if self._use_images_client:
            return self._generate_images(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
        if self.streaming:
            stream_iter = self._stream(
                messages, stop=stop, run_manager=run_manager, **kwargs
@@ -958,6 +1030,10 @@ class BaseChatOpenAI(BaseChatModel):
        else:
            return _use_responses_api(payload)

    @property
    def _use_images_client(self) -> bool:
        return self.model_name.startswith("gpt-image")

    def _get_request_payload(
        self,
        input_: LanguageModelInput,
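
Routing is purely name-based: any model whose name starts with "gpt-image" is sent to the Images API, everything else keeps the usual chat/Responses path. A small illustration (the second model name is just an example of one that does not match; assumes OPENAI_API_KEY is set so the clients can be constructed):

from langchain_openai import ChatOpenAI

assert ChatOpenAI(model="gpt-image-1")._use_images_client is True
assert ChatOpenAI(model="gpt-4o-mini")._use_images_client is False
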
@@ -1196,6 +1272,28 @@ class BaseChatOpenAI(BaseChatModel):
        """Return type of chat model."""
        return "openai-chat"

    def _generate_images(
        self,
        messages: list[BaseMessage],
        stop: Optional[list[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        prompt = messages[-1].text()
        images = []
        for message in messages:
            images.extend(_get_image_bytes_from_content(message.content))
        if images:
            result: ImagesResponse = self.root_client.images.edit(
                model=self.model_name, image=images, prompt=prompt, **kwargs
            )
        else:
            result = self.root_client.images.generate(
                model=self.model_name, prompt=prompt, **kwargs
            )
        output_message = AIMessage(content="", response_metadata=result.model_dump())
        return ChatResult(generations=[ChatGeneration(message=output_message)])

    def _get_encoding_model(self) -> tuple[str, tiktoken.Encoding]:
        if self.tiktoken_model_name is not None:
            model = self.tiktoken_model_name
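
End to end, the synchronous path added here lets the chat interface front the Images API: a text-only prompt goes through images.generate, while any image blocks found in the message history route the call through images.edit with the decoded bytes attached. A hedged usage sketch, assuming OPENAI_API_KEY is set and that, as with gpt-image-1 today, the response carries the image as base64 under data[0].b64_json:

import base64

from langchain_core.messages import HumanMessage
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-image-1")

# Text-only input -> the images.generate branch of _generate_images.
result = llm.invoke([HumanMessage("A watercolor fox in the snow")])

# _generate_images leaves content empty and stores the ImagesResponse dump in
# response_metadata, so the image comes back as base64 there.
image_b64 = result.response_metadata["data"][0]["b64_json"]
with open("fox.png", "wb") as f:
    f.write(base64.b64decode(image_b64))

Passing a message whose content includes one of the image block shapes shown earlier would take the images.edit branch instead, with the prompt still taken from the last message.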