mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-03 22:28:04 +00:00
support sync
This commit is contained in:
parent
d4fc734250
commit
b8c04d0ece
@ -103,6 +103,7 @@ from pydantic.v1 import BaseModel as BaseModelV1
|
|||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from openai.types.images_response import ImagesResponse
|
||||||
from openai.types.responses import Response
|
from openai.types.responses import Response
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -114,6 +115,10 @@ global_ssl_context = ssl.create_default_context(cafile=certifi.where())
|
|||||||
_FUNCTION_CALL_IDS_MAP_KEY = "__openai_function_call_ids__"
|
_FUNCTION_CALL_IDS_MAP_KEY = "__openai_function_call_ids__"
|
||||||
|
|
||||||
|
|
||||||
|
def _b64str_to_bytes(base64_str: str) -> bytes:
|
||||||
|
return base64.b64decode(base64_str.encode("utf-8"))
|
||||||
|
|
||||||
|
|
||||||
def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
|
def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
|
||||||
"""Convert a dictionary to a LangChain message.
|
"""Convert a dictionary to a LangChain message.
|
||||||
|
|
||||||
@ -410,6 +415,52 @@ def _handle_openai_bad_request(e: openai.BadRequestError) -> None:
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
_MessageContent = Union[str, list[Union[str, dict]]]
|
||||||
|
|
||||||
|
|
||||||
|
def _get_image_bytes_from_content(
|
||||||
|
content: _MessageContent,
|
||||||
|
) -> list[tuple[str, bytes, str]]:
|
||||||
|
images = []
|
||||||
|
image_count = 0
|
||||||
|
if isinstance(content, list):
|
||||||
|
for block in content:
|
||||||
|
if isinstance(block, dict) and block.get("type") == "image_url":
|
||||||
|
if (
|
||||||
|
(image_url := block.get("image_url"))
|
||||||
|
and isinstance(image_url, dict)
|
||||||
|
and (url := image_url.get("url"))
|
||||||
|
):
|
||||||
|
base64_regex = (
|
||||||
|
r"^data:(?P<media_type>image/.+);base64,(?P<data>.+)$"
|
||||||
|
)
|
||||||
|
base64_match = re.match(base64_regex, url)
|
||||||
|
if base64_match:
|
||||||
|
images.append(
|
||||||
|
(
|
||||||
|
f"image-{image_count:03}",
|
||||||
|
_b64str_to_bytes(base64_match.group("data")),
|
||||||
|
f'image/{base64_match.group("media_type")}',
|
||||||
|
)
|
||||||
|
)
|
||||||
|
image_count += 1
|
||||||
|
elif (
|
||||||
|
isinstance(block, dict)
|
||||||
|
and block.get("type") == "image"
|
||||||
|
and is_data_content_block(block)
|
||||||
|
):
|
||||||
|
images.append(
|
||||||
|
(
|
||||||
|
f"image-{image_count:03}",
|
||||||
|
_b64str_to_bytes(block["data"]),
|
||||||
|
block["mime_type"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
pass
|
||||||
|
return images
|
||||||
|
|
||||||
|
|
||||||
class _FunctionCall(TypedDict):
|
class _FunctionCall(TypedDict):
|
||||||
name: str
|
name: str
|
||||||
|
|
||||||
@ -833,6 +884,23 @@ class BaseChatOpenAI(BaseChatModel):
|
|||||||
return source
|
return source
|
||||||
return self.stream_usage
|
return self.stream_usage
|
||||||
|
|
||||||
|
def _should_stream(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
async_api: bool,
|
||||||
|
run_manager: Optional[
|
||||||
|
Union[CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun]
|
||||||
|
] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> bool:
|
||||||
|
"""Determine if a given model call should hit the streaming API."""
|
||||||
|
if self._use_images_client:
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
return super()._should_stream(
|
||||||
|
async_api=async_api, run_manager=run_manager, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
def _stream(
|
def _stream(
|
||||||
self,
|
self,
|
||||||
messages: list[BaseMessage],
|
messages: list[BaseMessage],
|
||||||
@ -910,6 +978,10 @@ class BaseChatOpenAI(BaseChatModel):
|
|||||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> ChatResult:
|
) -> ChatResult:
|
||||||
|
if self._use_images_client:
|
||||||
|
return self._generate_images(
|
||||||
|
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||||
|
)
|
||||||
if self.streaming:
|
if self.streaming:
|
||||||
stream_iter = self._stream(
|
stream_iter = self._stream(
|
||||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||||
@ -958,6 +1030,10 @@ class BaseChatOpenAI(BaseChatModel):
|
|||||||
else:
|
else:
|
||||||
return _use_responses_api(payload)
|
return _use_responses_api(payload)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _use_images_client(self) -> bool:
|
||||||
|
return self.model_name.startswith("gpt-image")
|
||||||
|
|
||||||
def _get_request_payload(
|
def _get_request_payload(
|
||||||
self,
|
self,
|
||||||
input_: LanguageModelInput,
|
input_: LanguageModelInput,
|
||||||
@ -1196,6 +1272,28 @@ class BaseChatOpenAI(BaseChatModel):
|
|||||||
"""Return type of chat model."""
|
"""Return type of chat model."""
|
||||||
return "openai-chat"
|
return "openai-chat"
|
||||||
|
|
||||||
|
def _generate_images(
|
||||||
|
self,
|
||||||
|
messages: list[BaseMessage],
|
||||||
|
stop: Optional[list[str]] = None,
|
||||||
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> ChatResult:
|
||||||
|
prompt = messages[-1].text()
|
||||||
|
images = []
|
||||||
|
for message in messages:
|
||||||
|
images.extend(_get_image_bytes_from_content(message.content))
|
||||||
|
if images:
|
||||||
|
result: ImagesResponse = self.root_client.images.edit(
|
||||||
|
model=self.model_name, image=images, prompt=prompt, **kwargs
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
result = self.root_client.images.generate(
|
||||||
|
model=self.model_name, prompt=prompt, **kwargs
|
||||||
|
)
|
||||||
|
output_message = AIMessage(content="", response_metadata=result.model_dump())
|
||||||
|
return ChatResult(generations=[ChatGeneration(message=output_message)])
|
||||||
|
|
||||||
def _get_encoding_model(self) -> tuple[str, tiktoken.Encoding]:
|
def _get_encoding_model(self) -> tuple[str, tiktoken.Encoding]:
|
||||||
if self.tiktoken_model_name is not None:
|
if self.tiktoken_model_name is not None:
|
||||||
model = self.tiktoken_model_name
|
model = self.tiktoken_model_name
|
||||||
|
Loading…
Reference in New Issue
Block a user