Files
langchain/libs/community/langchain_community/utilities/dalle_image_generator.py
Eugene Yurtsev 844955d6e1 community[patch]: assign missed default (#26326)
Assigning missed defaults in various classes. Most clients were being
assigned during the `model_validator(mode="before")` step, so this
change should amount to a no-op in those cases.

---

This PR was autogenerated using gritql

```shell

grit apply 'class_definition(name=$C, $body, superclasses=$S) where {    
    $C <: ! "Config", // Does not work in this scope, but works after class_definition
    $body <: block($statements),
    $statements <: some bubble assignment(left=$x, right=$y, type=$t) as $A where {
        or {
            $y <: `Field($z)`,
            $x <: "model_config"
        }
    },
    // And has either Any or Optional fields without a default
    $statements <: some bubble assignment(left=$x, right=$y, type=$t) as $A where {
        $t <: or {
            r"Optional.*",
            r"Any",
            r"Union[None, .*]",
            r"Union[.*, None, .*]",
            r"Union[.*, None]",
        },
        $y <: ., // Match empty node        
        $t => `$t = None`,
    },    
}
' --language python .

```
2024-09-11 11:13:11 -04:00

158 lines
5.8 KiB
Python

"""Utility that calls OpenAI's Dall-E Image Generator."""
import logging
from typing import Any, Dict, Mapping, Optional, Tuple, Union
from langchain_core.utils import (
from_env,
get_pydantic_field_names,
secret_from_env,
)
from pydantic import BaseModel, ConfigDict, Field, Secret, model_validator
from typing_extensions import Self
from langchain_community.utils.openai import is_openai_v1
logger = logging.getLogger(__name__)
class DallEAPIWrapper(BaseModel):
    """Wrapper for OpenAI's DALL-E Image Generator.

    https://platform.openai.com/docs/guides/images/generations?context=node

    Usage instructions:

    1. `pip install openai`
    2. save your OPENAI_API_KEY in an environment variable
    """

    client: Any = None  #: :meta private:
    async_client: Any = Field(default=None, exclude=True)  #: :meta private:
    model_name: str = Field(default="dall-e-2", alias="model")
    # Extra keyword arguments forwarded to the image-generation request in
    # ``run``. Unrecognised constructor arguments are collected here by
    # ``build_extra``.
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    openai_api_key: Optional[Secret[str]] = Field(
        alias="api_key",
        default_factory=secret_from_env(
            "OPENAI_API_KEY",
            default=None,
        ),
    )
    """Automatically inferred from env var `OPENAI_API_KEY` if not provided."""
    openai_api_base: Optional[str] = Field(
        alias="base_url", default_factory=from_env("OPENAI_API_BASE", default=None)
    )
    """Base URL path for API requests, leave blank if not using a proxy or service
    emulator."""
    openai_organization: Optional[str] = Field(
        alias="organization",
        default_factory=from_env(
            ["OPENAI_ORG_ID", "OPENAI_ORGANIZATION"], default=None
        ),
    )
    """Automatically inferred from env var `OPENAI_ORG_ID` if not provided."""
    # to support explicit proxy for OpenAI
    # NOTE(review): this value is not passed to the openai clients built in
    # ``validate_environment``; a proxy can be configured through
    # ``http_client`` / ``http_async_client`` instead.
    openai_proxy: str = Field(default_factory=from_env("OPENAI_PROXY", default=""))
    request_timeout: Union[float, Tuple[float, float], Any, None] = Field(
        default=None, alias="timeout"
    )
    n: int = 1
    """Number of images to generate"""
    size: str = "1024x1024"
    """Size of image to generate"""
    separator: str = "\n"
    """Separator to use when multiple URLs are returned."""
    quality: Optional[str] = "standard"
    """Quality of the image that will be generated"""
    max_retries: int = 2
    """Maximum number of retries to make when generating."""
    # Extra headers / query parameters for every request (openai>=1 only).
    default_headers: Union[Mapping[str, str], None] = None
    default_query: Union[Mapping[str, object], None] = None
    # Configure a custom httpx client. See the
    # [httpx documentation](https://www.python-httpx.org/api/#client) for more details.
    http_client: Union[Any, None] = None
    """Optional httpx.Client. Only used for the sync client."""
    # Optional httpx.AsyncClient, used only for the async client.
    # ``AsyncOpenAI`` rejects a sync ``httpx.Client``, so it cannot share
    # ``http_client``.
    http_async_client: Union[Any, None] = None

    # ``populate_by_name`` lets aliased fields also be set by their attribute
    # name (e.g. ``model_name=`` as well as ``model=``).
    model_config = ConfigDict(
        extra="forbid", populate_by_name=True, protected_namespaces=()
    )

    @model_validator(mode="before")
    @classmethod
    def build_extra(cls, values: Dict[str, Any]) -> Any:
        """Move unrecognised constructor arguments into ``model_kwargs``.

        Args:
            values: Raw keyword arguments passed to the constructor.

        Returns:
            ``values`` with unknown keys moved into ``values["model_kwargs"]``.

        Raises:
            ValueError: If a key is supplied both directly and inside
                ``model_kwargs``, or if a declared field is passed through
                ``model_kwargs`` instead of explicitly.
        """
        all_required_field_names = get_pydantic_field_names(cls)
        # Copy so the caller's own ``model_kwargs`` dict is never mutated, and
        # treat an explicit ``model_kwargs=None`` as empty.
        extra = dict(values.get("model_kwargs") or {})
        for field_name in list(values):
            if field_name in extra:
                raise ValueError(f"Found {field_name} supplied twice.")
            if field_name not in all_required_field_names:
                logger.warning(
                    f"""WARNING! {field_name} is not default parameter.
                {field_name} was transferred to model_kwargs.
                Please confirm that {field_name} is what you intended."""
                )
                extra[field_name] = values.pop(field_name)

        invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
        if invalid_model_kwargs:
            raise ValueError(
                f"Parameters {invalid_model_kwargs} should be specified explicitly. "
                f"Instead they were passed in as part of `model_kwargs` parameter."
            )

        values["model_kwargs"] = extra
        return values

    @model_validator(mode="after")
    def validate_environment(self) -> Self:
        """Check that ``openai`` is installed and build the missing clients.

        Clients already supplied by the caller are left untouched.

        Raises:
            ImportError: If the ``openai`` package cannot be imported.
        """
        try:
            import openai
        except ImportError as e:
            raise ImportError(
                "Could not import openai python package. "
                "Please install it with `pip install openai`."
            ) from e

        if is_openai_v1():
            # The SDK needs the raw key string: a pydantic ``Secret`` would be
            # stringified to its masked form rather than the real key. ``None``
            # leaves key resolution to the SDK.
            api_key = (
                self.openai_api_key.get_secret_value()
                if self.openai_api_key is not None
                else None
            )
            client_params = {
                "api_key": api_key,
                "organization": self.openai_organization,
                "base_url": self.openai_api_base,
                "timeout": self.request_timeout,
                "max_retries": self.max_retries,
                "default_headers": self.default_headers,
                "default_query": self.default_query,
            }
            if not self.client:
                self.client = openai.OpenAI(
                    **client_params, http_client=self.http_client
                ).images
            if not self.async_client:
                self.async_client = openai.AsyncOpenAI(
                    **client_params, http_client=self.http_async_client
                ).images
        elif not self.client:
            # Pre-1.0 openai has no client objects; use the module-level
            # ``Image`` resource instead.
            self.client = openai.Image
        return self

    def run(self, query: str) -> str:
        """Generate image(s) for ``query`` and return their URLs.

        Args:
            query: Text prompt describing the image(s) to generate.

        Returns:
            The image URLs joined with ``self.separator``, or
            ``"No image was generated"`` if the response carried no URLs.
        """
        # ``model_kwargs`` go first so the explicitly configured settings
        # always take precedence (and no keyword is passed twice).
        params: Dict[str, Any] = {
            **self.model_kwargs,
            "prompt": query,
            "n": self.n,
            "size": self.size,
            "model": self.model_name,
        }
        if is_openai_v1():
            params["quality"] = self.quality
            response = self.client.generate(**params)
            # Items without a URL (e.g. base64 responses) are skipped rather
            # than crashing ``str.join`` on ``None``.
            urls = [item.url for item in response.data if item.url]
        else:
            response = self.client.create(**params)
            urls = [item["url"] for item in response["data"] if item.get("url")]
        image_urls = self.separator.join(urls)
        return image_urls if image_urls else "No image was generated"