Compare commits

...

2 Commits

Author SHA1 Message Date
Chester Curme
dffb2b1fc9 update 2025-04-23 16:48:48 -04:00
Chester Curme
b8c04d0ece support sync 2025-04-23 15:50:19 -04:00
3 changed files with 136 additions and 6 deletions

View File

@@ -32,6 +32,7 @@ from urllib.parse import urlparse
import certifi
import openai
import tiktoken
from langchain_core._api.beta_decorator import warn_beta
from langchain_core._api.deprecation import deprecated
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
@@ -103,6 +104,7 @@ from pydantic.v1 import BaseModel as BaseModelV1
from typing_extensions import Self
if TYPE_CHECKING:
from openai.types.images_response import ImagesResponse
from openai.types.responses import Response
logger = logging.getLogger(__name__)
@@ -114,6 +116,10 @@ global_ssl_context = ssl.create_default_context(cafile=certifi.where())
_FUNCTION_CALL_IDS_MAP_KEY = "__openai_function_call_ids__"
def _b64str_to_bytes(base64_str: str) -> bytes:
return base64.b64decode(base64_str.encode("utf-8"))
def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
"""Convert a dictionary to a LangChain message.
@@ -410,6 +416,53 @@ def _handle_openai_bad_request(e: openai.BadRequestError) -> None:
raise
_MessageContent = Union[str, list[Union[str, dict]]]
WARNED_IMAGE_GEN_BETA = False
def _get_image_bytes_from_content(
content: _MessageContent,
) -> list[tuple[str, bytes, str]]:
images = []
image_count = 0
if isinstance(content, list):
for block in content:
if isinstance(block, dict) and block.get("type") == "image_url":
if (
(image_url := block.get("image_url"))
and isinstance(image_url, dict)
and (url := image_url.get("url"))
):
base64_regex = (
r"^data:(?P<media_type>image/.+);base64,(?P<data>.+)$"
)
base64_match = re.match(base64_regex, url)
if base64_match:
images.append(
(
f"image-{image_count:03}",
_b64str_to_bytes(base64_match.group("data")),
f'image/{base64_match.group("media_type")}',
)
)
image_count += 1
elif (
isinstance(block, dict)
and block.get("type") == "image"
and is_data_content_block(block)
):
images.append(
(
f"image-{image_count:03}",
_b64str_to_bytes(block["data"]),
block["mime_type"],
)
)
else:
pass
return images
class _FunctionCall(TypedDict):
name: str
@@ -833,6 +886,23 @@ class BaseChatOpenAI(BaseChatModel):
return source
return self.stream_usage
def _should_stream(
    self,
    *,
    async_api: bool,
    run_manager: Optional[
        Union[CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun]
    ] = None,
    **kwargs: Any,
) -> bool:
    """Decide whether this model call should hit the streaming API.

    Image-generation models are served by the Images client, which has no
    streaming endpoint, so streaming is always disabled for them; every
    other model defers to the base-class heuristic.
    """
    if self._use_images_client:
        return False
    return super()._should_stream(
        async_api=async_api, run_manager=run_manager, **kwargs
    )
def _stream(
self,
messages: list[BaseMessage],
@@ -910,6 +980,10 @@ class BaseChatOpenAI(BaseChatModel):
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
if self._use_images_client:
return self._generate_images(
messages, stop=stop, run_manager=run_manager, **kwargs
)
if self.streaming:
stream_iter = self._stream(
messages, stop=stop, run_manager=run_manager, **kwargs
@@ -958,6 +1032,10 @@ class BaseChatOpenAI(BaseChatModel):
else:
return _use_responses_api(payload)
@property
def _use_images_client(self) -> bool:
return self.model_name.startswith("gpt-image")
def _get_request_payload(
self,
input_: LanguageModelInput,
@@ -1196,6 +1274,58 @@ class BaseChatOpenAI(BaseChatModel):
"""Return type of chat model."""
return "openai-chat"
def _generate_images(
    self,
    messages: list[BaseMessage],
    stop: Optional[list[str]] = None,
    run_manager: Optional[CallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> ChatResult:
    """Generate or edit images through the OpenAI Images API.

    Uses the text of the last message as the prompt. If any message in the
    conversation carries images, the most recent set is submitted to the
    ``images.edit`` endpoint; otherwise ``images.generate`` is called. The
    response images are returned as base64 content blocks on an
    ``AIMessage``.

    ``stop`` and ``run_manager`` are accepted for interface compatibility
    with ``_generate`` but are not used by the Images API.
    """
    global WARNED_IMAGE_GEN_BETA
    # Emit the beta warning at most once per process.
    if not WARNED_IMAGE_GEN_BETA:
        warn_beta(message="Image generation via ChatOpenAI is in beta.")
        WARNED_IMAGE_GEN_BETA = True

    prompt = messages[-1].text()

    # Walk the conversation backwards and keep the most recent set of
    # images, if any message contains them.
    images = []
    for candidate in reversed(messages):
        extracted = _get_image_bytes_from_content(candidate.content)
        if extracted:
            images = extracted
            break

    if images:
        result: ImagesResponse = self.root_client.images.edit(
            model=self.model_name, image=images, prompt=prompt, **kwargs
        )
    else:
        result = self.root_client.images.generate(
            model=self.model_name, prompt=prompt, **kwargs
        )

    image_blocks = [
        {
            "type": "image",
            "source_type": "base64",
            "data": image.b64_json,
            "mime_type": "image/png",
        }
        for image in (result.data or [])
        if image.b64_json
    ]

    usage_metadata = (
        _create_usage_metadata_responses(result.usage.model_dump())
        if result.usage
        else None
    )

    output_message = AIMessage(
        content=image_blocks or "",  # type: ignore[arg-type]
        response_metadata={"created": result.created},
        usage_metadata=usage_metadata,
    )
    return ChatResult(generations=[ChatGeneration(message=output_message)])
def _get_encoding_model(self) -> tuple[str, tiktoken.Encoding]:
if self.tiktoken_model_name is not None:
model = self.tiktoken_model_name

View File

@@ -8,7 +8,7 @@ license = { text = "MIT" }
requires-python = "<4.0,>=3.9"
dependencies = [
"langchain-core<1.0.0,>=0.3.53",
"openai<2.0.0,>=1.68.2",
"openai<2.0.0,>=1.76.0",
"tiktoken<1,>=0.7",
]
name = "langchain-openai"

View File

@@ -463,7 +463,7 @@ wheels = [
[[package]]
name = "langchain-core"
version = "0.3.53"
version = "0.3.55"
source = { editable = "../../core" }
dependencies = [
{ name = "jsonpatch" },
@@ -571,7 +571,7 @@ typing = [
[package.metadata]
requires-dist = [
{ name = "langchain-core", editable = "../../core" },
{ name = "openai", specifier = ">=1.68.2,<2.0.0" },
{ name = "openai", specifier = ">=1.76.0,<2.0.0" },
{ name = "tiktoken", specifier = ">=0.7,<1" },
]
@@ -829,7 +829,7 @@ wheels = [
[[package]]
name = "openai"
version = "1.68.2"
version = "1.76.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "anyio" },
@@ -841,9 +841,9 @@ dependencies = [
{ name = "tqdm" },
{ name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/3f/6b/6b002d5d38794645437ae3ddb42083059d556558493408d39a0fcea608bc/openai-1.68.2.tar.gz", hash = "sha256:b720f0a95a1dbe1429c0d9bb62096a0d98057bcda82516f6e8af10284bdd5b19", size = 413429 }
sdist = { url = "https://files.pythonhosted.org/packages/84/51/817969ec969b73d8ddad085670ecd8a45ef1af1811d8c3b8a177ca4d1309/openai-1.76.0.tar.gz", hash = "sha256:fd2bfaf4608f48102d6b74f9e11c5ecaa058b60dad9c36e409c12477dfd91fb2", size = 434660 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/fd/34/cebce15f64eb4a3d609a83ac3568d43005cc9a1cba9d7fde5590fd415423/openai-1.68.2-py3-none-any.whl", hash = "sha256:24484cb5c9a33b58576fdc5acf0e5f92603024a4e39d0b99793dfa1eb14c2b36", size = 606073 },
{ url = "https://files.pythonhosted.org/packages/59/aa/84e02ab500ca871eb8f62784426963a1c7c17a72fea3c7f268af4bbaafa5/openai-1.76.0-py3-none-any.whl", hash = "sha256:a712b50e78cf78e6d7b2a8f69c4978243517c2c36999756673e07a14ce37dc0a", size = 661201 },
]
[[package]]