Fix: (issue #13825) Getting an error with DallEAPIWrapper (#13874)

- **Description:** As of OpenAI's Python package 1.0, the existing
DallEAPIWrapper does not work correctly, so the example in the LangChain
Documentation link below does not work either.

https://python.langchain.com/docs/integrations/tools/dalle_image_generator
Also, since OpenAI only supports DALL-E version 2 or version 3, I
modified the DallEAPIWrapper to support it.

  - **Issue:** #13825 

  - **Twitter handle:** ggeutzzang
This commit is contained in:
ggeutzzang 2023-11-29 12:31:25 +09:00 committed by GitHub
parent 74045bf5c0
commit 981f78f920
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -25,8 +25,10 @@ class DallEAPIWrapper(BaseModel):
"""Size of image to generate"""
separator: str = "\n"
"""Separator to use when multiple URLs are returned."""
model: Optional[str] = None
"""Model to use for image generation."""
model: Optional[str] = "dall-e-2"
"""Model to use for image generation"""
quality: Optional[str] = "standard"
"""Quality of the image that will be generated"""
class Config:
"""Configuration for this pydantic object."""
@ -40,10 +42,13 @@ class DallEAPIWrapper(BaseModel):
values, "openai_api_key", "OPENAI_API_KEY"
)
try:
import openai
from openai import OpenAI
openai.api_key = openai_api_key
values["client"] = openai.Image
client = OpenAI(
api_key=openai_api_key, # this is also the default, it can be omitted
)
values["client"] = client
except ImportError as e:
raise ImportError(
"Could not import openai python package. "
@ -53,8 +58,12 @@ class DallEAPIWrapper(BaseModel):
def run(self, query: str) -> str:
"""Run query through OpenAI and parse result."""
response = self.client.create(
prompt=query, n=self.n, size=self.size, model=self.model
response = self.client.images.generate(
prompt=query,
n=self.n,
size=self.size,
model=self.model,
quality=self.quality,
)
image_urls = self.separator.join([item["url"] for item in response["data"]])
image_urls = self.separator.join([item.url for item in response.data])
return image_urls if image_urls else "No image was generated"