diff --git a/libs/langchain/langchain/utilities/dalle_image_generator.py b/libs/langchain/langchain/utilities/dalle_image_generator.py index f5d37b514ad..e805aabe505 100644 --- a/libs/langchain/langchain/utilities/dalle_image_generator.py +++ b/libs/langchain/langchain/utilities/dalle_image_generator.py @@ -8,7 +8,10 @@ from langchain.utils import get_from_dict_or_env class DallEAPIWrapper(BaseModel): """Wrapper for OpenAI's DALL-E Image Generator. + https://platform.openai.com/docs/guides/images/generations?context=node + Usage instructions: + 1. `pip install openai` 2. save your OPENAI_API_KEY in an environment variable """ @@ -21,6 +24,8 @@ class DallEAPIWrapper(BaseModel): """Size of image to generate""" separator: str = "\n" """Separator to use when multiple URLs are returned.""" + model: Optional[str] = None + """Model to use for image generation.""" class Config: """Configuration for this pydantic object.""" @@ -47,6 +52,8 @@ class DallEAPIWrapper(BaseModel): def run(self, query: str) -> str: """Run query through OpenAI and parse result.""" - response = self.client.create(prompt=query, n=self.n, size=self.size) + response = self.client.create( + prompt=query, n=self.n, size=self.size, model=self.model + ) image_urls = self.separator.join([item["url"] for item in response["data"]]) return image_urls if image_urls else "No image was generated"