mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-02 13:55:42 +00:00
support sync
This commit is contained in:
parent
d4fc734250
commit
b8c04d0ece
@ -103,6 +103,7 @@ from pydantic.v1 import BaseModel as BaseModelV1
|
||||
from typing_extensions import Self
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from openai.types.images_response import ImagesResponse
|
||||
from openai.types.responses import Response
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -114,6 +115,10 @@ global_ssl_context = ssl.create_default_context(cafile=certifi.where())
|
||||
_FUNCTION_CALL_IDS_MAP_KEY = "__openai_function_call_ids__"
|
||||
|
||||
|
||||
def _b64str_to_bytes(base64_str: str) -> bytes:
|
||||
return base64.b64decode(base64_str.encode("utf-8"))
|
||||
|
||||
|
||||
def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
|
||||
"""Convert a dictionary to a LangChain message.
|
||||
|
||||
@ -410,6 +415,52 @@ def _handle_openai_bad_request(e: openai.BadRequestError) -> None:
|
||||
raise
|
||||
|
||||
|
||||
_MessageContent = Union[str, list[Union[str, dict]]]
|
||||
|
||||
|
||||
def _get_image_bytes_from_content(
|
||||
content: _MessageContent,
|
||||
) -> list[tuple[str, bytes, str]]:
|
||||
images = []
|
||||
image_count = 0
|
||||
if isinstance(content, list):
|
||||
for block in content:
|
||||
if isinstance(block, dict) and block.get("type") == "image_url":
|
||||
if (
|
||||
(image_url := block.get("image_url"))
|
||||
and isinstance(image_url, dict)
|
||||
and (url := image_url.get("url"))
|
||||
):
|
||||
base64_regex = (
|
||||
r"^data:(?P<media_type>image/.+);base64,(?P<data>.+)$"
|
||||
)
|
||||
base64_match = re.match(base64_regex, url)
|
||||
if base64_match:
|
||||
images.append(
|
||||
(
|
||||
f"image-{image_count:03}",
|
||||
_b64str_to_bytes(base64_match.group("data")),
|
||||
f'image/{base64_match.group("media_type")}',
|
||||
)
|
||||
)
|
||||
image_count += 1
|
||||
elif (
|
||||
isinstance(block, dict)
|
||||
and block.get("type") == "image"
|
||||
and is_data_content_block(block)
|
||||
):
|
||||
images.append(
|
||||
(
|
||||
f"image-{image_count:03}",
|
||||
_b64str_to_bytes(block["data"]),
|
||||
block["mime_type"],
|
||||
)
|
||||
)
|
||||
else:
|
||||
pass
|
||||
return images
|
||||
|
||||
|
||||
class _FunctionCall(TypedDict):
|
||||
name: str
|
||||
|
||||
@ -833,6 +884,23 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
return source
|
||||
return self.stream_usage
|
||||
|
||||
def _should_stream(
|
||||
self,
|
||||
*,
|
||||
async_api: bool,
|
||||
run_manager: Optional[
|
||||
Union[CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun]
|
||||
] = None,
|
||||
**kwargs: Any,
|
||||
) -> bool:
|
||||
"""Determine if a given model call should hit the streaming API."""
|
||||
if self._use_images_client:
|
||||
return False
|
||||
else:
|
||||
return super()._should_stream(
|
||||
async_api=async_api, run_manager=run_manager, **kwargs
|
||||
)
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
@ -910,6 +978,10 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
if self._use_images_client:
|
||||
return self._generate_images(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
if self.streaming:
|
||||
stream_iter = self._stream(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
@ -958,6 +1030,10 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
else:
|
||||
return _use_responses_api(payload)
|
||||
|
||||
@property
|
||||
def _use_images_client(self) -> bool:
|
||||
return self.model_name.startswith("gpt-image")
|
||||
|
||||
def _get_request_payload(
|
||||
self,
|
||||
input_: LanguageModelInput,
|
||||
@ -1196,6 +1272,28 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
"""Return type of chat model."""
|
||||
return "openai-chat"
|
||||
|
||||
def _generate_images(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
prompt = messages[-1].text()
|
||||
images = []
|
||||
for message in messages:
|
||||
images.extend(_get_image_bytes_from_content(message.content))
|
||||
if images:
|
||||
result: ImagesResponse = self.root_client.images.edit(
|
||||
model=self.model_name, image=images, prompt=prompt, **kwargs
|
||||
)
|
||||
else:
|
||||
result = self.root_client.images.generate(
|
||||
model=self.model_name, prompt=prompt, **kwargs
|
||||
)
|
||||
output_message = AIMessage(content="", response_metadata=result.model_dump())
|
||||
return ChatResult(generations=[ChatGeneration(message=output_message)])
|
||||
|
||||
def _get_encoding_model(self) -> tuple[str, tiktoken.Encoding]:
|
||||
if self.tiktoken_model_name is not None:
|
||||
model = self.tiktoken_model_name
|
||||
|
Loading…
Reference in New Issue
Block a user