langchain/libs/community/langchain_community/llms/edenai.py
Bagatur a0c2281540
infra: update mypy 1.10, ruff 0.5 (#23721)
```python
"""python scripts/update_mypy_ruff.py"""
import glob
import tomllib
from pathlib import Path

import toml
import subprocess
import re

# Parent of the directory that contains this script; main() searches
# ROOT_DIR / "libs" recursively for pyproject.toml files.
ROOT_DIR = Path(__file__).parents[1]


def main():
    """Bump mypy and ruff in every project under ``libs`` and silence new errors.

    For each ``pyproject.toml`` found recursively under ``ROOT_DIR / "libs"``:

    1. Set ``mypy`` to ``^1.10`` in the ``typing`` group and ``ruff`` to ``^0.5``
       in the ``lint`` group. Projects missing either table are skipped and
       left untouched.
    2. Re-lock, install the ``typing`` group and run mypy from the project
       directory.
    3. Append ``# type: ignore[<codes>]`` to every source line for which mypy
       reported an error.
    4. Run ``ruff format`` and the ruff import-sorting fix.
    """
    # Matches mypy output lines such as ``pkg/mod.py:12: error: text  [code]``.
    # Raw string: the previous non-raw literal relied on invalid string escapes.
    error_pattern = re.compile(r"^(.*):(\d+): error:.*\[(.*)\]")

    for path in glob.glob(str(ROOT_DIR / "libs/**/pyproject.toml"), recursive=True):
        print(path)
        with open(path, "rb") as f:
            pyproject = tomllib.load(f)
        try:
            groups = pyproject["tool"]["poetry"]["group"]
            groups["typing"]["dependencies"]["mypy"] = "^1.10"
            groups["lint"]["dependencies"]["ruff"] = "^0.5"
        except KeyError:
            # Project does not declare both dependency groups; nothing to bump.
            continue
        with open(path, "w") as f:
            toml.dump(pyproject, f)

        project_dir = Path(path).parent
        completed = subprocess.run(
            "poetry lock --no-update; poetry install --with typing; poetry run mypy . --no-color",
            cwd=project_dir,
            shell=True,
            capture_output=True,
            text=True,
        )

        # (file path relative to the project, 1-based line number) -> error codes
        to_ignore = {}
        for log_line in completed.stdout.split("\n"):
            match = error_pattern.match(log_line)
            if match is None:
                continue
            error_path, line_no, error_type = match.groups()
            to_ignore.setdefault((error_path, int(line_no)), []).append(error_type)
        print(len(to_ignore))

        for (error_path, line_no), error_types in to_ignore.items():
            full_path = project_dir / error_path
            try:
                with open(full_path, "r") as f:
                    file_lines = f.readlines()
            except FileNotFoundError:
                continue
            index = line_no - 1
            if not 0 <= index < len(file_lines):
                # The reported line does not exist in the file; skip it instead
                # of crashing or annotating an unrelated line via index -1.
                continue
            all_errors = ", ".join(error_types)
            # rstrip("\n") rather than [:-1]: the last line of a file may have
            # no trailing newline, and slicing would drop a real character.
            file_lines[index] = (
                file_lines[index].rstrip("\n") + f"  # type: ignore[{all_errors}]\n"
            )
            with open(full_path, "w") as f:
                f.writelines(file_lines)

        subprocess.run(
            "poetry run ruff format .; poetry run ruff --select I --fix .",
            cwd=project_dir,
            shell=True,
            capture_output=True,
            text=True,
        )


# Run the update only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()

```
2024-07-03 10:33:27 -07:00

267 lines
9.2 KiB
Python

"""Wrapper around EdenAI's Generation API."""
import logging
from typing import Any, Dict, List, Literal, Optional
from aiohttp import ClientSession
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.llms import LLM
from langchain_core.pydantic_v1 import Extra, Field, root_validator
from langchain_core.utils import get_from_dict_or_env
from langchain_community.llms.utils import enforce_stop_tokens
from langchain_community.utilities.requests import Requests
# Module-level logger named after this module; EdenAI.build_extra uses it to
# warn when an unrecognised keyword is moved into model_kwargs.
logger = logging.getLogger(__name__)
class EdenAI(LLM):
    """EdenAI models.

    To use, you should have the environment variable ``EDENAI_API_KEY`` set with
    your API token (or pass ``edenai_api_key`` directly). You can find your
    token here: https://app.edenai.run/admin/account/settings

    ``provider`` is required; ``feature`` and ``subfeature`` default to text
    generation. Any other model parameters can also be passed in with the
    format params={model_param: value, ...}

    For the API reference check the EdenAI documentation: http://docs.edenai.co.
    """

    base_url: str = "https://api.edenai.run/v2"

    edenai_api_key: Optional[str] = None

    feature: Literal["text", "image"] = "text"
    """Which generative feature to use, use text by default"""

    subfeature: Literal["generation"] = "generation"
    """Subfeature of above feature, use generation by default"""

    provider: str
    """Generative provider to use (eg: openai,stabilityai,cohere,google etc.)"""

    model: Optional[str] = None
    """
    model name for above provider (eg: 'gpt-3.5-turbo-instruct' for openai)
    available models are shown on https://docs.edenai.co/ under 'available providers'
    """

    # Optional parameters to add depending of chosen feature
    # see api reference for more infos
    temperature: Optional[float] = Field(default=None, ge=0, le=1)  # for text
    max_tokens: Optional[int] = Field(default=None, ge=0)  # for text
    resolution: Optional[Literal["256x256", "512x512", "1024x1024"]] = None  # for image

    params: Dict[str, Any] = Field(default_factory=dict)
    """
    DEPRECATED: use temperature, max_tokens, resolution directly
    optional parameters to pass to api
    """

    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """extra parameters"""

    stop_sequences: Optional[List[str]] = None
    """Stop sequences to use."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key exists in environment."""
        values["edenai_api_key"] = get_from_dict_or_env(
            values, "edenai_api_key", "EDENAI_API_KEY"
        )
        return values

    @root_validator(pre=True)
    def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Build extra kwargs from additional params that were passed in.

        Every key of ``values`` that is not a declared field alias is moved into
        ``values["model_kwargs"]`` and a warning is logged for it.

        Raises:
            ValueError: If such a key is already present in ``model_kwargs``.
        """
        all_required_field_names = {field.alias for field in cls.__fields__.values()}

        extra = values.get("model_kwargs", {})
        # Iterate over a snapshot of the keys because entries are popped below.
        for field_name in list(values):
            if field_name not in all_required_field_names:
                if field_name in extra:
                    raise ValueError(f"Found {field_name} supplied twice.")
                logger.warning(
                    f"{field_name} was transferred to model_kwargs.\n"
                    f"Please confirm that {field_name} is what you intended."
                )
                extra[field_name] = values.pop(field_name)
        values["model_kwargs"] = extra
        return values

    @property
    def _llm_type(self) -> str:
        """Return type of model."""
        return "edenai"

    def _format_output(self, output: dict) -> str:
        """Extract the generated result for ``self.provider`` from an API response.

        Returns the ``generated_text`` entry for the text feature, otherwise the
        ``image`` entry of the first returned item.
        """
        if self.feature == "text":
            return output[self.provider]["generated_text"]
        else:
            return output[self.provider]["items"][0]["image"]

    @staticmethod
    def get_user_agent() -> str:
        """Return the ``User-Agent`` value sent with every request."""
        from langchain_community import __version__

        return f"langchain/{__version__}"

    def _select_stops(self, stop: Optional[List[str]]) -> Optional[List[str]]:
        """Choose between the per-call ``stop`` and the default ``stop_sequences``.

        Raises:
            ValueError: If both the call and the instance supply stop sequences.
        """
        if self.stop_sequences is not None and stop is not None:
            raise ValueError(
                "stop sequences found in both the input and default params."
            )
        if self.stop_sequences is not None:
            return self.stop_sequences
        return stop

    def _build_url(self) -> str:
        """Return the endpoint URL for the configured feature and subfeature."""
        return f"{self.base_url}/{self.feature}/{self.subfeature}"

    def _build_headers(self) -> Dict[str, str]:
        """Return the authorization and user-agent headers for a request."""
        return {
            "Authorization": f"Bearer {self.edenai_api_key}",
            "User-Agent": self.get_user_agent(),
        }

    def _build_payload(self, prompt: str, extra: Dict[str, Any]) -> Dict[str, Any]:
        """Assemble the request body shared by the sync and async calls.

        Args:
            prompt: The prompt to pass into the model.
            extra: Per-call keyword arguments; they override ``self.params`` and
                the dedicated fields, but never ``num_images``.
        """
        payload: Dict[str, Any] = {
            "providers": self.provider,
            "text": prompt,
            "max_tokens": self.max_tokens,
            "temperature": self.temperature,
            "resolution": self.resolution,
            **self.params,
            **extra,
            "num_images": 1,  # always limit to 1 (ignored for text)
        }
        # filter `None` values to not pass them to the http payload as null
        payload = {k: v for k, v in payload.items() if v is not None}
        if self.model is not None:
            payload["settings"] = {self.provider: self.model}
        return payload

    def _extract_output(self, data: dict, stops: Optional[List[str]]) -> str:
        """Turn a decoded 200 response into the final output string.

        Raises:
            Exception: If the provider entry reports ``status == "fail"``; the
                message is the provider's error message.
        """
        provider_response = data[self.provider]
        if provider_response.get("status") == "fail":
            err_msg = provider_response.get("error", {}).get("message")
            raise Exception(err_msg)

        output = self._format_output(data)
        # An empty list carries no stop sequence, so there is nothing to cut.
        if stops:
            output = enforce_stop_tokens(output, stops)
        return output

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call out to EdenAI's generation endpoint.

        Args:
            prompt: The prompt to pass into the model.
            stop: A list of stop words (optional).
            run_manager: Accepted for interface compatibility; not used here.
            **kwargs: Extra entries merged into the request payload.

        Returns:
            The string generated by the model (see ``_format_output``).

        Raises:
            ValueError: If stop sequences are given both here and on the
                instance, or if the API answers with a 4xx status.
            Exception: On a 5xx or other non-200 status, or when the provider
                reports a failure.
        """
        stops = self._select_stops(stop)
        payload = self._build_payload(prompt, kwargs)

        request = Requests(headers=self._build_headers())
        response = request.post(url=self._build_url(), data=payload)

        if response.status_code >= 500:
            raise Exception(f"EdenAI Server: Error {response.status_code}")
        elif response.status_code >= 400:
            raise ValueError(f"EdenAI received an invalid payload: {response.text}")
        elif response.status_code != 200:
            raise Exception(
                f"EdenAI returned an unexpected response with status "
                f"{response.status_code}: {response.text}"
            )

        return self._extract_output(response.json(), stops)

    async def _acall(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call EdenAi model to get predictions based on the prompt.

        Args:
            prompt: The prompt to pass into the model.
            stop: A list of stop words (optional).
            run_manager: A callback manager for async interaction with LLMs;
                not used here.
            **kwargs: Extra entries merged into the request payload.

        Returns:
            The string generated by the model.

        Raises:
            ValueError: If stop sequences are given both here and on the
                instance, or if the API answers with a 4xx status.
            Exception: On a 5xx or other non-200 status, or when the provider
                reports a failure.
        """
        stops = self._select_stops(stop)
        url = self._build_url()
        headers = self._build_headers()
        payload = self._build_payload(prompt, kwargs)

        async with ClientSession() as session:
            async with session.post(url, json=payload, headers=headers) as response:
                if response.status >= 500:
                    raise Exception(f"EdenAI Server: Error {response.status}")
                elif response.status >= 400:
                    # ``text`` is a coroutine method on aiohttp responses; it
                    # must be awaited to get the body into the message.
                    error_body = await response.text()
                    raise ValueError(
                        f"EdenAI received an invalid payload: {error_body}"
                    )
                elif response.status != 200:
                    error_body = await response.text()
                    raise Exception(
                        f"EdenAI returned an unexpected response with status "
                        f"{response.status}: {error_body}"
                    )

                response_json = await response.json()
                return self._extract_output(response_json, stops)