diff --git a/libs/langchain/langchain/utilities/dalle_image_generator.py b/libs/langchain/langchain/utilities/dalle_image_generator.py index c81027db700..5eaa9176c1b 100644 --- a/libs/langchain/langchain/utilities/dalle_image_generator.py +++ b/libs/langchain/langchain/utilities/dalle_image_generator.py @@ -1,9 +1,17 @@ """Utility that calls OpenAI's Dall-E Image Generator.""" -from typing import Any, Dict, Optional +import logging +import os +from typing import Any, Dict, Mapping, Optional, Tuple, Union -from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator +from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator +from langchain_core.utils import ( + get_pydantic_field_names, +) from langchain.utils import get_from_dict_or_env +from langchain.utils.openai import is_openai_v1 + +logger = logging.getLogger(__name__) class DallEAPIWrapper(BaseModel): @@ -18,52 +26,139 @@ class DallEAPIWrapper(BaseModel): """ client: Any #: :meta private: - openai_api_key: Optional[str] = None + async_client: Any = Field(default=None, exclude=True) #: :meta private: + model_name: str = Field(default="dall-e-2", alias="model") + model_kwargs: Dict[str, Any] = Field(default_factory=dict) + openai_api_key: Optional[str] = Field(default=None, alias="api_key") + """Automatically inferred from env var `OPENAI_API_KEY` if not provided.""" + openai_api_base: Optional[str] = Field(default=None, alias="base_url") + """Base URL path for API requests, leave blank if not using a proxy or service + emulator.""" + openai_organization: Optional[str] = Field(default=None, alias="organization") + """Automatically inferred from env var `OPENAI_ORG_ID` if not provided.""" + # to support explicit proxy for OpenAI + openai_proxy: Optional[str] = None + request_timeout: Union[float, Tuple[float, float], Any, None] = Field( + default=None, alias="timeout" + ) n: int = 1 """Number of images to generate""" size: str = "1024x1024" """Size of image to generate""" separator: str = "\n" """Separator to use when multiple URLs are returned.""" - model: Optional[str] = "dall-e-2" - """Model to use for image generation""" quality: Optional[str] = "standard" """Quality of the image that will be generated""" + max_retries: int = 2 + """Maximum number of retries to make when generating.""" + default_headers: Union[Mapping[str, str], None] = None + default_query: Union[Mapping[str, object], None] = None + # Configure a custom httpx client. See the + # [httpx documentation](https://www.python-httpx.org/api/#client) for more details. + http_client: Union[Any, None] = None + """Optional httpx.Client.""" class Config: """Configuration for this pydantic object.""" extra = Extra.forbid + @root_validator(pre=True) + def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]: + """Build extra kwargs from additional params that were passed in.""" + all_required_field_names = get_pydantic_field_names(cls) + extra = values.get("model_kwargs", {}) + for field_name in list(values): + if field_name in extra: + raise ValueError(f"Found {field_name} supplied twice.") + if field_name not in all_required_field_names: + logger.warning( + f"""WARNING! {field_name} is not default parameter. + {field_name} was transferred to model_kwargs. + Please confirm that {field_name} is what you intended.""" + ) + extra[field_name] = values.pop(field_name) + + invalid_model_kwargs = all_required_field_names.intersection(extra.keys()) + if invalid_model_kwargs: + raise ValueError( + f"Parameters {invalid_model_kwargs} should be specified explicitly. " + f"Instead they were passed in as part of `model_kwargs` parameter." + ) + + values["model_kwargs"] = extra + return values + @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" - openai_api_key = get_from_dict_or_env( + values["openai_api_key"] = get_from_dict_or_env( values, "openai_api_key", "OPENAI_API_KEY" ) + # Check OPENAI_ORGANIZATION for backwards compatibility. + values["openai_organization"] = ( + values["openai_organization"] + or os.getenv("OPENAI_ORG_ID") + or os.getenv("OPENAI_ORGANIZATION") + or None + ) + values["openai_api_base"] = values["openai_api_base"] or os.getenv( + "OPENAI_API_BASE" + ) + values["openai_proxy"] = get_from_dict_or_env( + values, + "openai_proxy", + "OPENAI_PROXY", + default="", + ) + try: - from openai import OpenAI + import openai - client = OpenAI( - api_key=openai_api_key, # this is also the default, it can be omitted - ) - - values["client"] = client - except ImportError as e: + except ImportError: raise ImportError( "Could not import openai python package. " - "Please it install it with `pip install openai`." - ) from e + "Please install it with `pip install openai`." + ) + + if is_openai_v1(): + client_params = { + "api_key": values["openai_api_key"], + "organization": values["openai_organization"], + "base_url": values["openai_api_base"], + "timeout": values["request_timeout"], + "max_retries": values["max_retries"], + "default_headers": values["default_headers"], + "default_query": values["default_query"], + "http_client": values["http_client"], + } + + if not values.get("client"): + values["client"] = openai.OpenAI(**client_params).images + if not values.get("async_client"): + values["async_client"] = openai.AsyncOpenAI(**client_params).images + elif not values.get("client"): + values["client"] = openai.Image + else: + pass return values def run(self, query: str) -> str: """Run query through OpenAI and parse result.""" - response = self.client.images.generate( - prompt=query, - n=self.n, - size=self.size, - model=self.model, - quality=self.quality, - ) - image_urls = self.separator.join([item.url for item in response.data]) + + if is_openai_v1(): + response = self.client.generate( + prompt=query, + n=self.n, + size=self.size, + model=self.model_name, + quality=self.quality, + ) + image_urls = self.separator.join([item.url for item in response.data]) + else: + response = self.client.create( + prompt=query, n=self.n, size=self.size, model=self.model_name + ) + image_urls = self.separator.join([item["url"] for item in response["data"]]) + return image_urls if image_urls else "No image was generated"