mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-14 00:47:27 +00:00
- **Description:** As of OpenAI's Python package 1.0, the existing DallEAPIWrapper does not work correctly, so the example in the LangChain Documentation link below does not work either. https://python.langchain.com/docs/integrations/tools/dalle_image_generator Also, since OpenAI only supports DALL-E version 2 or version 3, I modified the DallEAPIWrapper to support it. - **Issue:** #13825 - **Twitter handle:** ggeutzzang
This commit is contained in:
parent
74045bf5c0
commit
981f78f920
@ -25,8 +25,10 @@ class DallEAPIWrapper(BaseModel):
|
||||
"""Size of image to generate"""
|
||||
separator: str = "\n"
|
||||
"""Separator to use when multiple URLs are returned."""
|
||||
model: Optional[str] = None
|
||||
"""Model to use for image generation."""
|
||||
model: Optional[str] = "dall-e-2"
|
||||
"""Model to use for image generation"""
|
||||
quality: Optional[str] = "standard"
|
||||
"""Quality of the image that will be generated"""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
@ -40,10 +42,13 @@ class DallEAPIWrapper(BaseModel):
|
||||
values, "openai_api_key", "OPENAI_API_KEY"
|
||||
)
|
||||
try:
|
||||
import openai
|
||||
from openai import OpenAI
|
||||
|
||||
openai.api_key = openai_api_key
|
||||
values["client"] = openai.Image
|
||||
client = OpenAI(
|
||||
api_key=openai_api_key, # this is also the default, it can be omitted
|
||||
)
|
||||
|
||||
values["client"] = client
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Could not import openai python package. "
|
||||
@ -53,8 +58,12 @@ class DallEAPIWrapper(BaseModel):
|
||||
|
||||
def run(self, query: str) -> str:
|
||||
"""Run query through OpenAI and parse result."""
|
||||
response = self.client.create(
|
||||
prompt=query, n=self.n, size=self.size, model=self.model
|
||||
response = self.client.images.generate(
|
||||
prompt=query,
|
||||
n=self.n,
|
||||
size=self.size,
|
||||
model=self.model,
|
||||
quality=self.quality,
|
||||
)
|
||||
image_urls = self.separator.join([item["url"] for item in response["data"]])
|
||||
image_urls = self.separator.join([item.url for item in response.data])
|
||||
return image_urls if image_urls else "No image was generated"
|
||||
|
Loading…
Reference in New Issue
Block a user