mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-26 16:43:35 +00:00
_dalle_image_url
returns list of urls if n>1 (#11800)
- **Description:** Updated the `_dalle_image_url` method to return a list of URLs if self.n>1, - **Issue:** #10691, - **Dependencies:** unsure, - **Tag maintainer:** @eyurtsev, - **Twitter handle:** @silvhua --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
1815ea2fdb
commit
9dead1034c
@ -8,29 +8,25 @@ from langchain.utils import get_from_dict_or_env
|
||||
class DallEAPIWrapper(BaseModel):
|
||||
"""Wrapper for OpenAI's DALL-E Image Generator.
|
||||
|
||||
Docs for using:
|
||||
1. pip install openai
|
||||
Usage instructions:
|
||||
1. `pip install openai`
|
||||
2. save your OPENAI_API_KEY in an environment variable
|
||||
|
||||
"""
|
||||
|
||||
client: Any #: :meta private:
|
||||
openai_api_key: Optional[str] = None
|
||||
"""number of images to generate"""
|
||||
n: int = 1
|
||||
"""size of image to generate"""
|
||||
"""Number of images to generate"""
|
||||
size: str = "1024x1024"
|
||||
"""Size of image to generate"""
|
||||
separator: str = "\n"
|
||||
"""Separator to use when multiple URLs are returned."""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
|
||||
def _dalle_image_url(self, prompt: str) -> str:
|
||||
params = {"prompt": prompt, "n": self.n, "size": self.size}
|
||||
response = self.client.create(**params)
|
||||
return response["data"][0]["url"]
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
@ -42,19 +38,15 @@ class DallEAPIWrapper(BaseModel):
|
||||
|
||||
openai.api_key = openai_api_key
|
||||
values["client"] = openai.Image
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Could not import openai python package. "
|
||||
"Please it install it with `pip install openai`."
|
||||
)
|
||||
) from e
|
||||
return values
|
||||
|
||||
def run(self, query: str) -> str:
|
||||
"""Run query through OpenAI and parse result."""
|
||||
image_url = self._dalle_image_url(query)
|
||||
|
||||
if image_url is None or image_url == "":
|
||||
# We don't want to return the assumption alone if answer is empty
|
||||
return "No image was generated"
|
||||
else:
|
||||
return image_url
|
||||
response = self.client.create(prompt=query, n=self.n, size=self.size)
|
||||
image_urls = self.separator.join([item["url"] for item in response["data"]])
|
||||
return image_urls if image_urls else "No image was generated"
|
||||
|
Loading…
Reference in New Issue
Block a user