mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-04 04:07:54 +00:00
feat: Increased compatibility with new and old versions for dalle (#14222)
- **Description:** Increased compatibility with all versions openai for dalle, This pr add support for openai version from 0 ~ 1.3.
This commit is contained in:
parent
7205bfdd00
commit
20d2b4a6ba
@ -1,9 +1,17 @@
|
||||
"""Utility that calls OpenAI's Dall-E Image Generator."""
|
||||
from typing import Any, Dict, Optional
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Dict, Mapping, Optional, Tuple, Union
|
||||
|
||||
from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
|
||||
from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator
|
||||
from langchain_core.utils import (
|
||||
get_pydantic_field_names,
|
||||
)
|
||||
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
from langchain.utils.openai import is_openai_v1
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DallEAPIWrapper(BaseModel):
|
||||
@ -18,52 +26,139 @@ class DallEAPIWrapper(BaseModel):
|
||||
"""
|
||||
|
||||
client: Any #: :meta private:
|
||||
openai_api_key: Optional[str] = None
|
||||
async_client: Any = Field(default=None, exclude=True) #: :meta private:
|
||||
model_name: str = Field(default="dall-e-2", alias="model")
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
openai_api_key: Optional[str] = Field(default=None, alias="api_key")
|
||||
"""Automatically inferred from env var `OPENAI_API_KEY` if not provided."""
|
||||
openai_api_base: Optional[str] = Field(default=None, alias="base_url")
|
||||
"""Base URL path for API requests, leave blank if not using a proxy or service
|
||||
emulator."""
|
||||
openai_organization: Optional[str] = Field(default=None, alias="organization")
|
||||
"""Automatically inferred from env var `OPENAI_ORG_ID` if not provided."""
|
||||
# to support explicit proxy for OpenAI
|
||||
openai_proxy: Optional[str] = None
|
||||
request_timeout: Union[float, Tuple[float, float], Any, None] = Field(
|
||||
default=None, alias="timeout"
|
||||
)
|
||||
n: int = 1
|
||||
"""Number of images to generate"""
|
||||
size: str = "1024x1024"
|
||||
"""Size of image to generate"""
|
||||
separator: str = "\n"
|
||||
"""Separator to use when multiple URLs are returned."""
|
||||
model: Optional[str] = "dall-e-2"
|
||||
"""Model to use for image generation"""
|
||||
quality: Optional[str] = "standard"
|
||||
"""Quality of the image that will be generated"""
|
||||
max_retries: int = 2
|
||||
"""Maximum number of retries to make when generating."""
|
||||
default_headers: Union[Mapping[str, str], None] = None
|
||||
default_query: Union[Mapping[str, object], None] = None
|
||||
# Configure a custom httpx client. See the
|
||||
# [httpx documentation](https://www.python-httpx.org/api/#client) for more details.
|
||||
http_client: Union[Any, None] = None
|
||||
"""Optional httpx.Client."""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
|
||||
@root_validator(pre=True)
|
||||
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Build extra kwargs from additional params that were passed in."""
|
||||
all_required_field_names = get_pydantic_field_names(cls)
|
||||
extra = values.get("model_kwargs", {})
|
||||
for field_name in list(values):
|
||||
if field_name in extra:
|
||||
raise ValueError(f"Found {field_name} supplied twice.")
|
||||
if field_name not in all_required_field_names:
|
||||
logger.warning(
|
||||
f"""WARNING! {field_name} is not default parameter.
|
||||
{field_name} was transferred to model_kwargs.
|
||||
Please confirm that {field_name} is what you intended."""
|
||||
)
|
||||
extra[field_name] = values.pop(field_name)
|
||||
|
||||
invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
|
||||
if invalid_model_kwargs:
|
||||
raise ValueError(
|
||||
f"Parameters {invalid_model_kwargs} should be specified explicitly. "
|
||||
f"Instead they were passed in as part of `model_kwargs` parameter."
|
||||
)
|
||||
|
||||
values["model_kwargs"] = extra
|
||||
return values
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
openai_api_key = get_from_dict_or_env(
|
||||
values["openai_api_key"] = get_from_dict_or_env(
|
||||
values, "openai_api_key", "OPENAI_API_KEY"
|
||||
)
|
||||
try:
|
||||
from openai import OpenAI
|
||||
|
||||
client = OpenAI(
|
||||
api_key=openai_api_key, # this is also the default, it can be omitted
|
||||
# Check OPENAI_ORGANIZATION for backwards compatibility.
|
||||
values["openai_organization"] = (
|
||||
values["openai_organization"]
|
||||
or os.getenv("OPENAI_ORG_ID")
|
||||
or os.getenv("OPENAI_ORGANIZATION")
|
||||
or None
|
||||
)
|
||||
values["openai_api_base"] = values["openai_api_base"] or os.getenv(
|
||||
"OPENAI_API_BASE"
|
||||
)
|
||||
values["openai_proxy"] = get_from_dict_or_env(
|
||||
values,
|
||||
"openai_proxy",
|
||||
"OPENAI_PROXY",
|
||||
default="",
|
||||
)
|
||||
|
||||
values["client"] = client
|
||||
except ImportError as e:
|
||||
try:
|
||||
import openai
|
||||
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import openai python package. "
|
||||
"Please it install it with `pip install openai`."
|
||||
) from e
|
||||
"Please install it with `pip install openai`."
|
||||
)
|
||||
|
||||
if is_openai_v1():
|
||||
client_params = {
|
||||
"api_key": values["openai_api_key"],
|
||||
"organization": values["openai_organization"],
|
||||
"base_url": values["openai_api_base"],
|
||||
"timeout": values["request_timeout"],
|
||||
"max_retries": values["max_retries"],
|
||||
"default_headers": values["default_headers"],
|
||||
"default_query": values["default_query"],
|
||||
"http_client": values["http_client"],
|
||||
}
|
||||
|
||||
if not values.get("client"):
|
||||
values["client"] = openai.OpenAI(**client_params).images
|
||||
if not values.get("async_client"):
|
||||
values["async_client"] = openai.AsyncOpenAI(**client_params).images
|
||||
elif not values.get("client"):
|
||||
values["client"] = openai.Image
|
||||
else:
|
||||
pass
|
||||
return values
|
||||
|
||||
def run(self, query: str) -> str:
|
||||
"""Run query through OpenAI and parse result."""
|
||||
response = self.client.images.generate(
|
||||
|
||||
if is_openai_v1():
|
||||
response = self.client.generate(
|
||||
prompt=query,
|
||||
n=self.n,
|
||||
size=self.size,
|
||||
model=self.model,
|
||||
model=self.model_name,
|
||||
quality=self.quality,
|
||||
)
|
||||
image_urls = self.separator.join([item.url for item in response.data])
|
||||
else:
|
||||
response = self.client.create(
|
||||
prompt=query, n=self.n, size=self.size, model=self.model_name
|
||||
)
|
||||
image_urls = self.separator.join([item["url"] for item in response["data"]])
|
||||
|
||||
return image_urls if image_urls else "No image was generated"
|
||||
|
Loading…
Reference in New Issue
Block a user