mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-16 06:53:16 +00:00
Harrison/image (#845)
Co-authored-by: Ashutosh Sanzgiri <sanzgiri@gmail.com> Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
@@ -11,6 +11,7 @@ from langchain.callbacks.manager import Callbacks
|
||||
from langchain.chains.api import news_docs, open_meteo_docs, podcast_docs, tmdb_docs
|
||||
from langchain.chains.api.base import APIChain
|
||||
from langchain.chains.llm_math.base import LLMMathChain
|
||||
from langchain.utilities.dalle_image_generator import DallEAPIWrapper
|
||||
from langchain.utilities.requests import TextRequestsWrapper
|
||||
from langchain.tools.arxiv.tool import ArxivQueryRun
|
||||
from langchain.tools.golden_query.tool import GoldenQueryRun
|
||||
@@ -221,6 +222,14 @@ def _get_serpapi(**kwargs: Any) -> BaseTool:
|
||||
)
|
||||
|
||||
|
||||
def _get_dalle_image_generator(**kwargs: Any) -> Tool:
|
||||
return Tool(
|
||||
"Dall-E Image Generator",
|
||||
DallEAPIWrapper(**kwargs).run,
|
||||
"A wrapper around OpenAI DALL-E API. Useful for when you need to generate images from a text description. Input should be an image description.",
|
||||
)
|
||||
|
||||
|
||||
def _get_twilio(**kwargs: Any) -> BaseTool:
|
||||
return Tool(
|
||||
name="Text Message",
|
||||
@@ -305,6 +314,7 @@ _EXTRA_OPTIONAL_TOOLS: Dict[str, Tuple[Callable[[KwArg(Any)], BaseTool], List[st
|
||||
["serper_api_key", "aiosession"],
|
||||
),
|
||||
"serpapi": (_get_serpapi, ["serpapi_api_key", "aiosession"]),
|
||||
"dalle-image-generator": (_get_dalle_image_generator, ["openai_api_key"]),
|
||||
"twilio": (_get_twilio, ["account_sid", "auth_token", "from_number"]),
|
||||
"searx-search": (_get_searx_search, ["searx_host", "engines", "aiosession"]),
|
||||
"wikipedia": (_get_wikipedia, ["top_k_results", "lang"]),
|
||||
|
61
libs/langchain/langchain/utilities/dalle_image_generator.py
Normal file
61
libs/langchain/langchain/utilities/dalle_image_generator.py
Normal file
@@ -0,0 +1,61 @@
|
||||
"""Util that calls OpenAI's Dall-E Image Generator."""
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, Extra, root_validator
|
||||
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
|
||||
class DallEAPIWrapper(BaseModel):
|
||||
"""Wrapper for OpenAI's DALL-E Image Generator.
|
||||
|
||||
Docs for using:
|
||||
1. pip install openai
|
||||
2. save your OPENAI_API_KEY in an environment variable
|
||||
|
||||
"""
|
||||
|
||||
client: Any #: :meta private:
|
||||
openai_api_key: Optional[str] = None
|
||||
"""number of images to generate"""
|
||||
n: int = 1
|
||||
"""size of image to generate"""
|
||||
size: str = "1024x1024"
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
|
||||
def _dalle_image_url(self, prompt: str) -> str:
|
||||
params = {"prompt": prompt, "n": self.n, "size": self.size}
|
||||
response = self.client.create(**params)
|
||||
return response["data"][0]["url"]
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
openai_api_key = get_from_dict_or_env(
|
||||
values, "openai_api_key", "OPENAI_API_KEY"
|
||||
)
|
||||
try:
|
||||
import openai
|
||||
|
||||
openai.api_key = openai_api_key
|
||||
values["client"] = openai.Image
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Could not import openai python package. "
|
||||
"Please it install it with `pip install openai`."
|
||||
)
|
||||
return values
|
||||
|
||||
def run(self, query: str) -> str:
|
||||
"""Run query through OpenAI and parse result."""
|
||||
image_url = self._dalle_image_url(query)
|
||||
|
||||
if image_url is None or image_url == "":
|
||||
# We don't want to return the assumption alone if answer is empty
|
||||
return "No image was generated"
|
||||
else:
|
||||
return image_url
|
@@ -0,0 +1,16 @@
|
||||
"""Integration test for Dall-E image generator agent."""
|
||||
from langchain.agents import AgentType, initialize_agent, load_tools
|
||||
from langchain.llms import OpenAI
|
||||
|
||||
|
||||
def test_call() -> None:
|
||||
"""Test that the agent runs and returns output."""
|
||||
llm = OpenAI(temperature=0.9)
|
||||
tools = load_tools(["dalle-image-generator"])
|
||||
|
||||
agent = initialize_agent(
|
||||
tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
|
||||
)
|
||||
|
||||
output = agent.run("Create an image of a volcano island")
|
||||
assert output is not None
|
9
libs/langchain/tests/integration_tests/test_dalle.py
Normal file
9
libs/langchain/tests/integration_tests/test_dalle.py
Normal file
@@ -0,0 +1,9 @@
|
||||
"""Integration test for DallE API Wrapper."""
|
||||
from langchain.utilities.dalle_image_generator import DallEAPIWrapper
|
||||
|
||||
|
||||
def test_call() -> None:
|
||||
"""Test that call returns a URL in the output."""
|
||||
search = DallEAPIWrapper()
|
||||
output = search.run("volcano island")
|
||||
assert "https://oaidalleapi" in output
|
Reference in New Issue
Block a user