mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-13 16:36:06 +00:00
dalle add model parameter (#13201)
- **Description:** dalle_image_generator adding a new model parameter, - **Issue:** N/A, - **Dependencies:** - **Tag maintainer: @hwchase17 - **Twitter handle:** --------- Co-authored-by: dafu <xiangbingze@wenru.wang> Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com> Co-authored-by: Erick Friis <erickfriis@gmail.com>
This commit is contained in:
parent
96b56a4d4f
commit
e1c020dfe1
@ -8,7 +8,10 @@ from langchain.utils import get_from_dict_or_env
|
|||||||
class DallEAPIWrapper(BaseModel):
|
class DallEAPIWrapper(BaseModel):
|
||||||
"""Wrapper for OpenAI's DALL-E Image Generator.
|
"""Wrapper for OpenAI's DALL-E Image Generator.
|
||||||
|
|
||||||
|
https://platform.openai.com/docs/guides/images/generations?context=node
|
||||||
|
|
||||||
Usage instructions:
|
Usage instructions:
|
||||||
|
|
||||||
1. `pip install openai`
|
1. `pip install openai`
|
||||||
2. save your OPENAI_API_KEY in an environment variable
|
2. save your OPENAI_API_KEY in an environment variable
|
||||||
"""
|
"""
|
||||||
@ -21,6 +24,8 @@ class DallEAPIWrapper(BaseModel):
|
|||||||
"""Size of image to generate"""
|
"""Size of image to generate"""
|
||||||
separator: str = "\n"
|
separator: str = "\n"
|
||||||
"""Separator to use when multiple URLs are returned."""
|
"""Separator to use when multiple URLs are returned."""
|
||||||
|
model: Optional[str] = None
|
||||||
|
"""Model to use for image generation."""
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
"""Configuration for this pydantic object."""
|
"""Configuration for this pydantic object."""
|
||||||
@ -47,6 +52,8 @@ class DallEAPIWrapper(BaseModel):
|
|||||||
|
|
||||||
def run(self, query: str) -> str:
|
def run(self, query: str) -> str:
|
||||||
"""Run query through OpenAI and parse result."""
|
"""Run query through OpenAI and parse result."""
|
||||||
response = self.client.create(prompt=query, n=self.n, size=self.size)
|
response = self.client.create(
|
||||||
|
prompt=query, n=self.n, size=self.size, model=self.model
|
||||||
|
)
|
||||||
image_urls = self.separator.join([item["url"] for item in response["data"]])
|
image_urls = self.separator.join([item["url"] for item in response["data"]])
|
||||||
return image_urls if image_urls else "No image was generated"
|
return image_urls if image_urls else "No image was generated"
|
||||||
|
Loading…
Reference in New Issue
Block a user