Harrison/image (#845)

Co-authored-by: Ashutosh Sanzgiri <sanzgiri@gmail.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
Harrison Chase
2023-08-08 13:58:27 -07:00
committed by GitHub
parent ab193338aa
commit 7543a3d70e
5 changed files with 277 additions and 0 deletions

View File

@@ -11,6 +11,7 @@ from langchain.callbacks.manager import Callbacks
from langchain.chains.api import news_docs, open_meteo_docs, podcast_docs, tmdb_docs
from langchain.chains.api.base import APIChain
from langchain.chains.llm_math.base import LLMMathChain
from langchain.utilities.dalle_image_generator import DallEAPIWrapper
from langchain.utilities.requests import TextRequestsWrapper
from langchain.tools.arxiv.tool import ArxivQueryRun
from langchain.tools.golden_query.tool import GoldenQueryRun
@@ -221,6 +222,14 @@ def _get_serpapi(**kwargs: Any) -> BaseTool:
)
def _get_dalle_image_generator(**kwargs: Any) -> Tool:
return Tool(
"Dall-E Image Generator",
DallEAPIWrapper(**kwargs).run,
"A wrapper around OpenAI DALL-E API. Useful for when you need to generate images from a text description. Input should be an image description.",
)
def _get_twilio(**kwargs: Any) -> BaseTool:
return Tool(
name="Text Message",
@@ -305,6 +314,7 @@ _EXTRA_OPTIONAL_TOOLS: Dict[str, Tuple[Callable[[KwArg(Any)], BaseTool], List[st
["serper_api_key", "aiosession"],
),
"serpapi": (_get_serpapi, ["serpapi_api_key", "aiosession"]),
"dalle-image-generator": (_get_dalle_image_generator, ["openai_api_key"]),
"twilio": (_get_twilio, ["account_sid", "auth_token", "from_number"]),
"searx-search": (_get_searx_search, ["searx_host", "engines", "aiosession"]),
"wikipedia": (_get_wikipedia, ["top_k_results", "lang"]),

View File

@@ -0,0 +1,61 @@
"""Util that calls OpenAI's Dall-E Image Generator."""
from typing import Any, Dict, Optional
from pydantic import BaseModel, Extra, root_validator
from langchain.utils import get_from_dict_or_env
class DallEAPIWrapper(BaseModel):
"""Wrapper for OpenAI's DALL-E Image Generator.
Docs for using:
1. pip install openai
2. save your OPENAI_API_KEY in an environment variable
"""
client: Any #: :meta private:
openai_api_key: Optional[str] = None
"""number of images to generate"""
n: int = 1
"""size of image to generate"""
size: str = "1024x1024"
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
def _dalle_image_url(self, prompt: str) -> str:
params = {"prompt": prompt, "n": self.n, "size": self.size}
response = self.client.create(**params)
return response["data"][0]["url"]
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
openai_api_key = get_from_dict_or_env(
values, "openai_api_key", "OPENAI_API_KEY"
)
try:
import openai
openai.api_key = openai_api_key
values["client"] = openai.Image
except ImportError:
raise ValueError(
"Could not import openai python package. "
"Please it install it with `pip install openai`."
)
return values
def run(self, query: str) -> str:
"""Run query through OpenAI and parse result."""
image_url = self._dalle_image_url(query)
if image_url is None or image_url == "":
# We don't want to return the assumption alone if answer is empty
return "No image was generated"
else:
return image_url

View File

@@ -0,0 +1,16 @@
"""Integration test for Dall-E image generator agent."""
from langchain.agents import AgentType, initialize_agent, load_tools
from langchain.llms import OpenAI
def test_call() -> None:
"""Test that the agent runs and returns output."""
llm = OpenAI(temperature=0.9)
tools = load_tools(["dalle-image-generator"])
agent = initialize_agent(
tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
)
output = agent.run("Create an image of a volcano island")
assert output is not None

View File

@@ -0,0 +1,9 @@
"""Integration test for DallE API Wrapper."""
from langchain.utilities.dalle_image_generator import DallEAPIWrapper
def test_call() -> None:
"""Test that call returns a URL in the output."""
search = DallEAPIWrapper()
output = search.run("volcano island")
assert "https://oaidalleapi" in output