mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-08 14:31:55 +00:00
community[major], core[patch], langchain[patch], experimental[patch]: Create langchain-community (#14463)
Moved the following modules to new package langchain-community in a backwards compatible fashion: ``` mv langchain/langchain/adapters community/langchain_community mv langchain/langchain/callbacks community/langchain_community/callbacks mv langchain/langchain/chat_loaders community/langchain_community mv langchain/langchain/chat_models community/langchain_community mv langchain/langchain/document_loaders community/langchain_community mv langchain/langchain/docstore community/langchain_community mv langchain/langchain/document_transformers community/langchain_community mv langchain/langchain/embeddings community/langchain_community mv langchain/langchain/graphs community/langchain_community mv langchain/langchain/llms community/langchain_community mv langchain/langchain/memory/chat_message_histories community/langchain_community mv langchain/langchain/retrievers community/langchain_community mv langchain/langchain/storage community/langchain_community mv langchain/langchain/tools community/langchain_community mv langchain/langchain/utilities community/langchain_community mv langchain/langchain/vectorstores community/langchain_community mv langchain/langchain/agents/agent_toolkits community/langchain_community mv langchain/langchain/cache.py community/langchain_community mv langchain/langchain/adapters community/langchain_community mv langchain/langchain/callbacks community/langchain_community/callbacks mv langchain/langchain/chat_loaders community/langchain_community mv langchain/langchain/chat_models community/langchain_community mv langchain/langchain/document_loaders community/langchain_community mv langchain/langchain/docstore community/langchain_community mv langchain/langchain/document_transformers community/langchain_community mv langchain/langchain/embeddings community/langchain_community mv langchain/langchain/graphs community/langchain_community mv langchain/langchain/llms community/langchain_community mv langchain/langchain/memory/chat_message_histories community/langchain_community mv langchain/langchain/retrievers community/langchain_community mv langchain/langchain/storage community/langchain_community mv langchain/langchain/tools community/langchain_community mv langchain/langchain/utilities community/langchain_community mv langchain/langchain/vectorstores community/langchain_community mv langchain/langchain/agents/agent_toolkits community/langchain_community mv langchain/langchain/cache.py community/langchain_community ``` Moved the following to core ``` mv langchain/langchain/utils/json_schema.py core/langchain_core/utils mv langchain/langchain/utils/html.py core/langchain_core/utils mv langchain/langchain/utils/strings.py core/langchain_core/utils cat langchain/langchain/utils/env.py >> core/langchain_core/utils/env.py rm langchain/langchain/utils/env.py ``` See .scripts/community_split/script_integrations.sh for all changes
This commit is contained in:
265
libs/community/langchain_community/llms/edenai.py
Normal file
265
libs/community/langchain_community/llms/edenai.py
Normal file
@@ -0,0 +1,265 @@
|
||||
"""Wrapper around EdenAI's Generation API."""
|
||||
import logging
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
from aiohttp import ClientSession
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models.llms import LLM
|
||||
from langchain_core.pydantic_v1 import Extra, Field, root_validator
|
||||
from langchain_core.utils import get_from_dict_or_env
|
||||
|
||||
from langchain_community.llms.utils import enforce_stop_tokens
|
||||
from langchain_community.utilities.requests import Requests
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EdenAI(LLM):
|
||||
"""Wrapper around edenai models.
|
||||
|
||||
To use, you should have
|
||||
the environment variable ``EDENAI_API_KEY`` set with your API token.
|
||||
You can find your token here: https://app.edenai.run/admin/account/settings
|
||||
|
||||
`feature` and `subfeature` are required, but any other model parameters can also be
|
||||
passed in with the format params={model_param: value, ...}
|
||||
|
||||
for api reference check edenai documentation: http://docs.edenai.co.
|
||||
"""
|
||||
|
||||
base_url: str = "https://api.edenai.run/v2"
|
||||
|
||||
edenai_api_key: Optional[str] = None
|
||||
|
||||
feature: Literal["text", "image"] = "text"
|
||||
"""Which generative feature to use, use text by default"""
|
||||
|
||||
subfeature: Literal["generation"] = "generation"
|
||||
"""Subfeature of above feature, use generation by default"""
|
||||
|
||||
provider: str
|
||||
"""Generative provider to use (eg: openai,stabilityai,cohere,google etc.)"""
|
||||
|
||||
model: Optional[str] = None
|
||||
"""
|
||||
model name for above provider (eg: 'text-davinci-003' for openai)
|
||||
available models are shown on https://docs.edenai.co/ under 'available providers'
|
||||
"""
|
||||
|
||||
# Optional parameters to add depending of chosen feature
|
||||
# see api reference for more infos
|
||||
temperature: Optional[float] = Field(default=None, ge=0, le=1) # for text
|
||||
max_tokens: Optional[int] = Field(default=None, ge=0) # for text
|
||||
resolution: Optional[Literal["256x256", "512x512", "1024x1024"]] = None # for image
|
||||
|
||||
params: Dict[str, Any] = Field(default_factory=dict)
|
||||
"""
|
||||
DEPRECATED: use temperature, max_tokens, resolution directly
|
||||
optional parameters to pass to api
|
||||
"""
|
||||
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
"""extra parameters"""
|
||||
|
||||
stop_sequences: Optional[List[str]] = None
|
||||
"""Stop sequences to use."""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key exists in environment."""
|
||||
values["edenai_api_key"] = get_from_dict_or_env(
|
||||
values, "edenai_api_key", "EDENAI_API_KEY"
|
||||
)
|
||||
return values
|
||||
|
||||
@root_validator(pre=True)
|
||||
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Build extra kwargs from additional params that were passed in."""
|
||||
all_required_field_names = {field.alias for field in cls.__fields__.values()}
|
||||
|
||||
extra = values.get("model_kwargs", {})
|
||||
for field_name in list(values):
|
||||
if field_name not in all_required_field_names:
|
||||
if field_name in extra:
|
||||
raise ValueError(f"Found {field_name} supplied twice.")
|
||||
logger.warning(
|
||||
f"""{field_name} was transferred to model_kwargs.
|
||||
Please confirm that {field_name} is what you intended."""
|
||||
)
|
||||
extra[field_name] = values.pop(field_name)
|
||||
values["model_kwargs"] = extra
|
||||
return values
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of model."""
|
||||
return "edenai"
|
||||
|
||||
def _format_output(self, output: dict) -> str:
|
||||
if self.feature == "text":
|
||||
return output[self.provider]["generated_text"]
|
||||
else:
|
||||
return output[self.provider]["items"][0]["image"]
|
||||
|
||||
@staticmethod
|
||||
def get_user_agent() -> str:
|
||||
from langchain_community import __version__
|
||||
|
||||
return f"langchain/{__version__}"
|
||||
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Call out to EdenAI's text generation endpoint.
|
||||
|
||||
Args:
|
||||
prompt: The prompt to pass into the model.
|
||||
|
||||
Returns:
|
||||
json formatted str response.
|
||||
"""
|
||||
stops = None
|
||||
if self.stop_sequences is not None and stop is not None:
|
||||
raise ValueError(
|
||||
"stop sequences found in both the input and default params."
|
||||
)
|
||||
elif self.stop_sequences is not None:
|
||||
stops = self.stop_sequences
|
||||
else:
|
||||
stops = stop
|
||||
|
||||
url = f"{self.base_url}/{self.feature}/{self.subfeature}"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.edenai_api_key}",
|
||||
"User-Agent": self.get_user_agent(),
|
||||
}
|
||||
payload: Dict[str, Any] = {
|
||||
"providers": self.provider,
|
||||
"text": prompt,
|
||||
"max_tokens": self.max_tokens,
|
||||
"temperature": self.temperature,
|
||||
"resolution": self.resolution,
|
||||
**self.params,
|
||||
**kwargs,
|
||||
"num_images": 1, # always limit to 1 (ignored for text)
|
||||
}
|
||||
|
||||
# filter None values to not pass them to the http payload
|
||||
payload = {k: v for k, v in payload.items() if v is not None}
|
||||
|
||||
if self.model is not None:
|
||||
payload["settings"] = {self.provider: self.model}
|
||||
|
||||
request = Requests(headers=headers)
|
||||
response = request.post(url=url, data=payload)
|
||||
|
||||
if response.status_code >= 500:
|
||||
raise Exception(f"EdenAI Server: Error {response.status_code}")
|
||||
elif response.status_code >= 400:
|
||||
raise ValueError(f"EdenAI received an invalid payload: {response.text}")
|
||||
elif response.status_code != 200:
|
||||
raise Exception(
|
||||
f"EdenAI returned an unexpected response with status "
|
||||
f"{response.status_code}: {response.text}"
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
provider_response = data[self.provider]
|
||||
if provider_response.get("status") == "fail":
|
||||
err_msg = provider_response.get("error", {}).get("message")
|
||||
raise Exception(err_msg)
|
||||
|
||||
output = self._format_output(data)
|
||||
|
||||
if stops is not None:
|
||||
output = enforce_stop_tokens(output, stops)
|
||||
|
||||
return output
|
||||
|
||||
async def _acall(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Call EdenAi model to get predictions based on the prompt.
|
||||
|
||||
Args:
|
||||
prompt: The prompt to pass into the model.
|
||||
stop: A list of stop words (optional).
|
||||
run_manager: A callback manager for async interaction with LLMs.
|
||||
|
||||
Returns:
|
||||
The string generated by the model.
|
||||
"""
|
||||
|
||||
stops = None
|
||||
if self.stop_sequences is not None and stop is not None:
|
||||
raise ValueError(
|
||||
"stop sequences found in both the input and default params."
|
||||
)
|
||||
elif self.stop_sequences is not None:
|
||||
stops = self.stop_sequences
|
||||
else:
|
||||
stops = stop
|
||||
|
||||
url = f"{self.base_url}/{self.feature}/{self.subfeature}"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.edenai_api_key}",
|
||||
"User-Agent": self.get_user_agent(),
|
||||
}
|
||||
payload: Dict[str, Any] = {
|
||||
"providers": self.provider,
|
||||
"text": prompt,
|
||||
"max_tokens": self.max_tokens,
|
||||
"temperature": self.temperature,
|
||||
"resolution": self.resolution,
|
||||
**self.params,
|
||||
**kwargs,
|
||||
"num_images": 1, # always limit to 1 (ignored for text)
|
||||
}
|
||||
|
||||
# filter `None` values to not pass them to the http payload as null
|
||||
payload = {k: v for k, v in payload.items() if v is not None}
|
||||
|
||||
if self.model is not None:
|
||||
payload["settings"] = {self.provider: self.model}
|
||||
|
||||
async with ClientSession() as session:
|
||||
async with session.post(url, json=payload, headers=headers) as response:
|
||||
if response.status >= 500:
|
||||
raise Exception(f"EdenAI Server: Error {response.status}")
|
||||
elif response.status >= 400:
|
||||
raise ValueError(
|
||||
f"EdenAI received an invalid payload: {response.text}"
|
||||
)
|
||||
elif response.status != 200:
|
||||
raise Exception(
|
||||
f"EdenAI returned an unexpected response with status "
|
||||
f"{response.status}: {response.text}"
|
||||
)
|
||||
|
||||
response_json = await response.json()
|
||||
provider_response = response_json[self.provider]
|
||||
if provider_response.get("status") == "fail":
|
||||
err_msg = provider_response.get("error", {}).get("message")
|
||||
raise Exception(err_msg)
|
||||
|
||||
output = self._format_output(response_json)
|
||||
if stops is not None:
|
||||
output = enforce_stop_tokens(output, stops)
|
||||
|
||||
return output
|
Reference in New Issue
Block a user