mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-18 16:16:33 +00:00
Bagatur/eden llm (#8670)
Co-authored-by: RedhaWassim <rwasssim@gmail.com> Co-authored-by: KyrianC <ckyrian@protonmail.com> Co-authored-by: sam <melaine.samy@gmail.com>
This commit is contained in:
@@ -24,6 +24,7 @@ from langchain.embeddings.clarifai import ClarifaiEmbeddings
|
||||
from langchain.embeddings.cohere import CohereEmbeddings
|
||||
from langchain.embeddings.dashscope import DashScopeEmbeddings
|
||||
from langchain.embeddings.deepinfra import DeepInfraEmbeddings
|
||||
from langchain.embeddings.edenai import EdenAiEmbeddings
|
||||
from langchain.embeddings.elasticsearch import ElasticsearchEmbeddings
|
||||
from langchain.embeddings.embaas import EmbaasEmbeddings
|
||||
from langchain.embeddings.fake import FakeEmbeddings
|
||||
@@ -85,6 +86,7 @@ __all__ = [
|
||||
"VertexAIEmbeddings",
|
||||
"BedrockEmbeddings",
|
||||
"DeepInfraEmbeddings",
|
||||
"EdenAiEmbeddings",
|
||||
"DashScopeEmbeddings",
|
||||
"EmbaasEmbeddings",
|
||||
"OctoAIEmbeddings",
|
||||
|
88
libs/langchain/langchain/embeddings/edenai.py
Normal file
88
libs/langchain/langchain/embeddings/edenai.py
Normal file
@@ -0,0 +1,88 @@
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Extra, Field, root_validator
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.requests import Requests
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
|
||||
class EdenAiEmbeddings(BaseModel, Embeddings):
    """EdenAI embedding model integration.

    To use, you should have the environment variable ``EDENAI_API_KEY``
    set with your API key, or pass it as a named parameter.
    """

    edenai_api_key: Optional[str] = Field(None, description="EdenAI API Token")

    provider: Optional[str] = "openai"
    """embedding provider to use (eg: openai,google etc.)"""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key exists in environment."""
        values["edenai_api_key"] = get_from_dict_or_env(
            values, "edenai_api_key", "EDENAI_API_KEY"
        )
        return values

    def _generate_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Compute embeddings for ``texts`` using the EdenAI API.

        Args:
            texts: The texts to embed.

        Returns:
            One embedding vector per input text.

        Raises:
            ValueError: If the API rejects the payload (4xx status).
            Exception: On server errors (5xx) or any other non-200 status.
        """
        url = "https://api.edenai.run/v2/text/embeddings"

        headers = {
            "accept": "application/json",
            "content-type": "application/json",
            "authorization": f"Bearer {self.edenai_api_key}",
        }

        payload = {"texts": texts, "providers": self.provider}
        request = Requests(headers=headers)
        response = request.post(url=url, data=payload)
        if response.status_code >= 500:
            raise Exception(f"EdenAI Server: Error {response.status_code}")
        elif response.status_code >= 400:
            raise ValueError(f"EdenAI received an invalid payload: {response.text}")
        elif response.status_code != 200:
            raise Exception(
                f"EdenAI returned an unexpected response with status "
                f"{response.status_code}: {response.text}"
            )

        # The response body is keyed by provider; each item carries one
        # embedding vector for the corresponding input text.
        body = response.json()
        return [item["embedding"] for item in body[self.provider]["items"]]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of documents using EdenAI.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        return self._generate_embeddings(texts)

    def embed_query(self, text: str) -> List[float]:
        """Embed a query using EdenAI.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        return self._generate_embeddings([text])[0]
|
@@ -38,6 +38,7 @@ from langchain.llms.cohere import Cohere
|
||||
from langchain.llms.ctransformers import CTransformers
|
||||
from langchain.llms.databricks import Databricks
|
||||
from langchain.llms.deepinfra import DeepInfra
|
||||
from langchain.llms.edenai import EdenAI
|
||||
from langchain.llms.fake import FakeListLLM
|
||||
from langchain.llms.fireworks import Fireworks, FireworksChat
|
||||
from langchain.llms.forefrontai import ForefrontAI
|
||||
@@ -98,6 +99,7 @@ __all__ = [
|
||||
"Cohere",
|
||||
"Databricks",
|
||||
"DeepInfra",
|
||||
"EdenAI",
|
||||
"FakeListLLM",
|
||||
"Fireworks",
|
||||
"FireworksChat",
|
||||
@@ -162,6 +164,7 @@ type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
|
||||
"ctransformers": CTransformers,
|
||||
"databricks": Databricks,
|
||||
"deepinfra": DeepInfra,
|
||||
"edenai": EdenAI,
|
||||
"fake-list": FakeListLLM,
|
||||
"forefrontai": ForefrontAI,
|
||||
"google_palm": GooglePalm,
|
||||
|
217
libs/langchain/langchain/llms/edenai.py
Normal file
217
libs/langchain/langchain/llms/edenai.py
Normal file
@@ -0,0 +1,217 @@
|
||||
"""Wrapper around EdenAI's Generation API."""
|
||||
import logging
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
from aiohttp import ClientSession
|
||||
from pydantic import Extra, Field, root_validator
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
from langchain.requests import Requests
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EdenAI(LLM):
    """Wrapper around edenai models.

    To use, you should have
    the environment variable ``EDENAI_API_KEY`` set with your API token.
    You can find your token here: https://app.edenai.run/admin/account/settings

    `feature` and `subfeature` are required, but any other model parameters can also be
    passed in with the format params={model_param: value, ...}

    for api reference check edenai documentation: http://docs.edenai.co.
    """

    base_url = "https://api.edenai.run/v2"

    edenai_api_key: Optional[str] = None

    feature: Literal["text", "image"] = "text"
    """Which generative feature to use, use text by default"""

    subfeature: Literal["generation"] = "generation"
    """Subfeature of above feature, use generation by default"""

    provider: str
    """Generative provider to use (eg: openai,stabilityai,cohere,google etc.)"""

    params: Dict[str, Any]
    """
    Parameters to pass to above subfeature (excluding 'providers' & 'text')
    ref text: https://docs.edenai.co/reference/text_generation_create
    ref image: https://docs.edenai.co/reference/image_generation_create
    """

    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """extra parameters"""

    stop_sequences: Optional[List[str]] = None
    """Stop sequences to use."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key exists in environment."""
        values["edenai_api_key"] = get_from_dict_or_env(
            values, "edenai_api_key", "EDENAI_API_KEY"
        )
        return values

    @root_validator(pre=True)
    def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Build extra kwargs from additional params that were passed in."""
        all_required_field_names = {field.alias for field in cls.__fields__.values()}

        extra = values.get("model_kwargs", {})
        for field_name in list(values):
            if field_name not in all_required_field_names:
                if field_name in extra:
                    raise ValueError(f"Found {field_name} supplied twice.")
                logger.warning(
                    f"""{field_name} was transferred to model_kwargs.
                    Please confirm that {field_name} is what you intended."""
                )
                extra[field_name] = values.pop(field_name)
        values["model_kwargs"] = extra
        return values

    @property
    def _llm_type(self) -> str:
        """Return type of model."""
        return "edenai"

    def _resolve_stops(self, stop: Optional[List[str]]) -> Optional[List[str]]:
        """Pick the effective stop sequences, rejecting conflicting sources."""
        if self.stop_sequences is not None and stop is not None:
            raise ValueError(
                "stop sequences found in both the input and default params."
            )
        return self.stop_sequences if self.stop_sequences is not None else stop

    def _request_payload(self, prompt: str, **kwargs: Any) -> Dict[str, Any]:
        """Assemble the JSON body shared by the sync and async calls."""
        return {
            **self.params,
            "providers": self.provider,
            "num_images": 1,  # always limit to 1 (ignored for text)
            "text": prompt,
            **kwargs,
        }

    def _format_output(self, output: dict) -> str:
        """Extract the generated text (or first image) from a response body."""
        if self.feature == "text":
            return output[self.provider]["generated_text"]
        else:
            return output[self.provider]["items"][0]["image"]

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call out to EdenAI's text generation endpoint.

        Args:
            prompt: The prompt to pass into the model.

        Returns:
            json formatted str response.
        """
        stops = self._resolve_stops(stop)

        url = f"{self.base_url}/{self.feature}/{self.subfeature}"
        headers = {"Authorization": f"Bearer {self.edenai_api_key}"}
        request = Requests(headers=headers)

        response = request.post(url=url, data=self._request_payload(prompt, **kwargs))

        if response.status_code >= 500:
            raise Exception(f"EdenAI Server: Error {response.status_code}")
        elif response.status_code >= 400:
            raise ValueError(f"EdenAI received an invalid payload: {response.text}")
        elif response.status_code != 200:
            raise Exception(
                f"EdenAI returned an unexpected response with status "
                f"{response.status_code}: {response.text}"
            )

        output = self._format_output(response.json())

        if stops is not None:
            output = enforce_stop_tokens(output, stops)

        return output

    async def _acall(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call EdenAi model to get predictions based on the prompt.

        Args:
            prompt: The prompt to pass into the model.
            stop: A list of stop words (optional).
            run_manager: A callback manager for async interaction with LLMs.

        Returns:
            The string generated by the model.
        """
        stops = self._resolve_stops(stop)

        url = f"{self.base_url}/{self.feature}/{self.subfeature}"
        headers = {"Authorization": f"Bearer {self.edenai_api_key}"}
        payload = self._request_payload(prompt, **kwargs)

        async with ClientSession() as session:
            async with session.post(url, json=payload, headers=headers) as response:
                if response.status >= 500:
                    raise Exception(f"EdenAI Server: Error {response.status}")
                elif response.status >= 400:
                    # aiohttp's ``text`` is a coroutine method and must be
                    # awaited; interpolating it directly would render a bound
                    # method repr instead of the response body.
                    raise ValueError(
                        f"EdenAI received an invalid payload: "
                        f"{await response.text()}"
                    )
                elif response.status != 200:
                    raise Exception(
                        f"EdenAI returned an unexpected response with status "
                        f"{response.status}: {await response.text()}"
                    )

                response_json = await response.json()

        output = self._format_output(response_json)
        if stops is not None:
            output = enforce_stop_tokens(output, stops)

        return output
|
@@ -0,0 +1,21 @@
|
||||
"""Test edenai embeddings."""
|
||||
|
||||
from langchain.embeddings.edenai import EdenAiEmbeddings
|
||||
|
||||
|
||||
def test_edenai_embedding_documents() -> None:
    """Test edenai embeddings with openai."""
    texts = ["foo bar", "test text"]
    embedder = EdenAiEmbeddings(provider="openai")
    vectors = embedder.embed_documents(texts)
    assert len(vectors) == 2
    # OpenAI's default embedding model produces 1536-dimensional vectors.
    assert len(vectors[0]) == 1536
    assert len(vectors[1]) == 1536
|
||||
|
||||
|
||||
def test_edenai_embedding_query() -> None:
    """Test eden ai embeddings with google."""
    query = "foo bar"
    embedder = EdenAiEmbeddings(provider="google")
    vector = embedder.embed_query(query)
    # Google's embedding model produces 768-dimensional vectors.
    assert len(vector) == 768
|
32
libs/langchain/tests/integration_tests/llms/test_edenai.py
Normal file
32
libs/langchain/tests/integration_tests/llms/test_edenai.py
Normal file
@@ -0,0 +1,32 @@
|
||||
"""Test EdenAi API wrapper.
|
||||
|
||||
In order to run this test, you need to have an EdenAI api key.
|
||||
You can get it by registering for free at https://app.edenai.run/user/register.
|
||||
A test key can be found at https://app.edenai.run/admin/account/settings by
|
||||
clicking on the 'sandbox' toggle.
|
||||
(calls will be free, and will return dummy results)
|
||||
|
||||
You'll then need to set EDENAI_API_KEY environment variable to your api key.
|
||||
"""
|
||||
from langchain.llms import EdenAI
|
||||
|
||||
|
||||
def test_edenai_call() -> None:
    """Test simple call to edenai."""
    llm = EdenAI(provider="openai", params={"temperature": 0.2, "max_tokens": 250})
    result = llm("Say foo:")

    # Default configuration: text generation via the chosen provider.
    assert llm._llm_type == "edenai"
    assert llm.feature == "text"
    assert llm.subfeature == "generation"
    assert isinstance(result, str)
|
||||
|
||||
|
||||
async def test_edenai_acall() -> None:
    """Test simple async call to edenai."""
    llm = EdenAI(provider="openai", params={"temperature": 0.2, "max_tokens": 250})
    result = await llm.agenerate(["Say foo:"])
    assert llm._llm_type == "edenai"
    assert llm.feature == "text"
    assert llm.subfeature == "generation"
    # agenerate returns an LLMResult, not a plain string; assert on the
    # generated text itself (the original isinstance(output, str) check
    # could never pass).
    assert isinstance(result.generations[0][0].text, str)
|
Reference in New Issue
Block a user