EdenAI LLM update: add model name option (#8963)

This PR follows up on the **Eden AI (LLM + embeddings) integration** (#8633).

We added an optional `model` parameter to choose a specific model from the selected provider (e.g. 'text-bison' for provider 'google', 'text-davinci-003' for provider 'openai', etc.).

Usage:

```python
llm = EdenAI(
    feature="text",
    provider="google",
    model="text-bison",  # new
    params={
        "temperature": 0.2,
        "max_tokens": 250,
    },
)
```

You can also change the provider and model after initialization:
```python
llm = EdenAI(
    feature="text",
    provider="google",
    params={
        "temperature": 0.2,
        "max_tokens": 250,
    },
)

prompt = """
hi 
"""

llm(prompt, providers='openai', model='text-davinci-003')  # change provider & model
```
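Call-time keyword arguments are merged into the request payload after `params` (see the `**kwargs` in the payload construction in the diff below), so they override the values set at initialization.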

The Jupyter notebook has been updated with an example as well.


Ping: @hwchase17, @baskaryan

---------

Co-authored-by: RedhaWassim <rwasssim@gmail.com>
Co-authored-by: sam <melaine.samy@gmail.com>
Commit 491089754d (parent b5a74fb973) by KyrianC, 2023-09-01 21:11:33 +02:00, committed by GitHub.
5 changed files with 129 additions and 54 deletions.

@@ -1,4 +1,4 @@
-from typing import Dict, List, Optional
+from typing import Any, Dict, List, Optional
 
 from langchain.embeddings.base import Embeddings
 from langchain.pydantic_v1 import BaseModel, Extra, Field, root_validator
@@ -14,9 +14,15 @@ class EdenAiEmbeddings(BaseModel, Embeddings):
     edenai_api_key: Optional[str] = Field(None, description="EdenAI API Token")
 
-    provider: Optional[str] = "openai"
+    provider: str = "openai"
     """embedding provider to use (eg: openai,google etc.)"""
 
+    model: Optional[str] = None
+    """
+    model name for above provider (eg: 'text-davinci-003' for openai)
+    available models are shown on https://docs.edenai.co/ under 'available providers'
+    """
+
     class Config:
         """Configuration for this pydantic object."""
@@ -30,6 +36,12 @@ class EdenAiEmbeddings(BaseModel, Embeddings):
             )
         return values
 
+    @staticmethod
+    def get_user_agent() -> str:
+        from langchain import __version__
+
+        return f"langchain/{__version__}"
+
     def _generate_embeddings(self, texts: List[str]) -> List[List[float]]:
         """Compute embeddings using EdenAi api."""
         url = "https://api.edenai.run/v2/text/embeddings"
@@ -38,9 +50,14 @@ class EdenAiEmbeddings(BaseModel, Embeddings):
             "accept": "application/json",
             "content-type": "application/json",
             "authorization": f"Bearer {self.edenai_api_key}",
+            "User-Agent": self.get_user_agent(),
         }
 
-        payload = {"texts": texts, "providers": self.provider}
+        payload: Dict[str, Any] = {"texts": texts, "providers": self.provider}
+
+        if self.model is not None:
+            payload["settings"] = {self.provider: self.model}
 
         request = Requests(headers=headers)
         response = request.post(url=url, data=payload)
 
         if response.status_code >= 500:
@@ -55,6 +72,11 @@ class EdenAiEmbeddings(BaseModel, Embeddings):
         temp = response.json()
+        provider_response = temp[self.provider]
+        if provider_response.get("status") == "fail":
+            err_msg = provider_response.get("error", {}).get("message")
+            raise Exception(err_msg)
+
         embeddings = []
         for embed_item in temp[self.provider]["items"]:
             embedding = embed_item["embedding"]
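For completeness, a minimal usage sketch of the embeddings change above. The import path, the API-key handling, and the model name are assumptions based on the diff and the standard `Embeddings` interface, not part of this PR's text:

```python
from langchain.embeddings import EdenAiEmbeddings  # import path assumed

# assumes the API key is picked up from the environment (e.g. EDENAI_API_KEY)
# or passed explicitly via edenai_api_key=...
embeddings = EdenAiEmbeddings(
    provider="openai",
    model="text-davinci-003",  # optional; sent as settings={"openai": "text-davinci-003"}
)

# standard Embeddings interface; both calls go through _generate_embeddings,
# so the model setting applies to documents and queries alike
doc_vectors = embeddings.embed_documents(["hello", "goodbye"])
query_vector = embeddings.embed_query("hello")
```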

@@ -41,13 +41,24 @@ class EdenAI(LLM):
     """Subfeature of above feature, use generation by default"""
 
     provider: str
-    """Geneerative provider to use (eg: openai,stabilityai,cohere,google etc.)"""
+    """Generative provider to use (eg: openai,stabilityai,cohere,google etc.)"""
 
-    params: Dict[str, Any]
+    model: Optional[str] = None
     """
-    Parameters to pass to above subfeature (excluding 'providers' & 'text')
-    ref text: https://docs.edenai.co/reference/text_generation_create
-    ref image: https://docs.edenai.co/reference/text_generation_create
+    model name for above provider (eg: 'text-davinci-003' for openai)
+    available models are shown on https://docs.edenai.co/ under 'available providers'
     """
 
+    # Optional parameters to add depending of chosen feature
+    # see api reference for more infos
+    temperature: Optional[float] = Field(default=None, ge=0, le=1)  # for text
+    max_tokens: Optional[int] = Field(default=None, ge=0)  # for text
+    resolution: Optional[Literal["256x256", "512x512", "1024x1024"]] = None  # for image
+
+    params: Dict[str, Any] = Field(default_factory=dict)
+    """
+    DEPRECATED: use temperature, max_tokens, resolution directly
+    optional parameters to pass to api
+    """
+
     model_kwargs: Dict[str, Any] = Field(default_factory=dict)
@@ -98,6 +109,12 @@ class EdenAI(LLM):
         else:
             return output[self.provider]["items"][0]["image"]
 
+    @staticmethod
+    def get_user_agent() -> str:
+        from langchain import __version__
+
+        return f"langchain/{__version__}"
+
     def _call(
         self,
         prompt: str,
@@ -112,7 +129,6 @@ class EdenAI(LLM):
         Returns:
             json formatted str response.
         """
-
         stops = None
         if self.stop_sequences is not None and stop is not None:
@@ -125,16 +141,28 @@ class EdenAI(LLM):
             stops = stop
 
         url = f"{self.base_url}/{self.feature}/{self.subfeature}"
-        headers = {"Authorization": f"Bearer {self.edenai_api_key}"}
-        payload = {
-            **self.params,
-            "providers": self.provider,
-            "num_images": 1,  # always limit to 1 (ignored for text)
-            "text": prompt,
-            **kwargs,
+        headers = {
+            "Authorization": f"Bearer {self.edenai_api_key}",
+            "User-Agent": self.get_user_agent(),
         }
 
+        payload: Dict[str, Any] = {
+            "providers": self.provider,
+            "text": prompt,
+            "max_tokens": self.max_tokens,
+            "temperature": self.temperature,
+            "resolution": self.resolution,
+            **self.params,
+            **kwargs,
+            "num_images": 1,  # always limit to 1 (ignored for text)
+        }
+
+        # filter None values to not pass them to the http payload
+        payload = {k: v for k, v in payload.items() if v is not None}
+
+        if self.model is not None:
+            payload["settings"] = {self.provider: self.model}
+
         request = Requests(headers=headers)
         response = request.post(url=url, data=payload)
 
         if response.status_code >= 500:
@@ -147,7 +175,13 @@ class EdenAI(LLM):
                 f"{response.status_code}: {response.text}"
             )
 
-        output = self._format_output(response.json())
+        data = response.json()
+        provider_response = data[self.provider]
+        if provider_response.get("status") == "fail":
+            err_msg = provider_response.get("error", {}).get("message")
+            raise Exception(err_msg)
+
+        output = self._format_output(data)
 
         if stops is not None:
             output = enforce_stop_tokens(output, stops)
@@ -182,19 +216,29 @@ class EdenAI(LLM):
         else:
             stops = stop
 
-        print("Running the acall")
         url = f"{self.base_url}/{self.feature}/{self.subfeature}"
-        headers = {"Authorization": f"Bearer {self.edenai_api_key}"}
-        payload = {
-            **self.params,
-            "providers": self.provider,
-            "num_images": 1,  # always limit to 1 (ignored for text)
-            "text": prompt,
-            **kwargs,
+        headers = {
+            "Authorization": f"Bearer {self.edenai_api_key}",
+            "User-Agent": self.get_user_agent(),
+        }
+
+        payload: Dict[str, Any] = {
+            "providers": self.provider,
+            "text": prompt,
+            "max_tokens": self.max_tokens,
+            "temperature": self.temperature,
+            "resolution": self.resolution,
+            **self.params,
+            **kwargs,
+            "num_images": 1,  # always limit to 1 (ignored for text)
         }
 
+        # filter `None` values to not pass them to the http payload as null
+        payload = {k: v for k, v in payload.items() if v is not None}
+
+        if self.model is not None:
+            payload["settings"] = {self.provider: self.model}
+
         async with ClientSession() as session:
-            print("Requesting")
             async with session.post(url, json=payload, headers=headers) as response:
                 if response.status >= 500:
                     raise Exception(f"EdenAI Server: Error {response.status}")
@@ -209,6 +253,10 @@ class EdenAI(LLM):
                     )
 
                 response_json = await response.json()
+                provider_response = response_json[self.provider]
+                if provider_response.get("status") == "fail":
+                    err_msg = provider_response.get("error", {}).get("message")
+                    raise Exception(err_msg)
                 output = self._format_output(response_json)
 
         if stops is not None:
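To make the payload logic concrete, here is a hypothetical trace of what `_call` now builds for an instance using the new fields. The values are illustrative; the shape follows directly from the diff (None values filtered out, `settings` added only when `model` is set):

```python
llm = EdenAI(provider="google", model="text-bison", temperature=0.2, max_tokens=250)

# llm("hi") would POST roughly this payload:
# {
#     "providers": "google",
#     "text": "hi",
#     "max_tokens": 250,
#     "temperature": 0.2,
#     "num_images": 1,                       # always 1, ignored for text
#     "settings": {"google": "text-bison"},  # present only because model is set
# }
# "resolution" was dropped by the None filter since it was never set.
```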

@@ -13,7 +13,7 @@ from langchain.llms import EdenAI
 def test_edenai_call() -> None:
     """Test simple call to edenai."""
-    llm = EdenAI(provider="openai", params={"temperature": 0.2, "max_tokens": 250})
+    llm = EdenAI(provider="openai", temperature=0.2, max_tokens=250)
     output = llm("Say foo:")
 
     assert llm._llm_type == "edenai"
@@ -24,9 +24,23 @@ def test_edenai_call() -> None:
 async def test_edenai_acall() -> None:
     """Test simple call to edenai."""
-    llm = EdenAI(provider="openai", params={"temperature": 0.2, "max_tokens": 250})
+    llm = EdenAI(provider="openai", temperature=0.2, max_tokens=250)
     output = await llm.agenerate(["Say foo:"])
     assert llm._llm_type == "edenai"
     assert llm.feature == "text"
     assert llm.subfeature == "generation"
     assert isinstance(output, str)
+
+
+def test_edenai_call_with_old_params() -> None:
+    """
+    Test simple call to edenai with using `params`
+    to pass optional parameters to api
+    """
+    llm = EdenAI(provider="openai", params={"temperature": 0.2, "max_tokens": 250})
+    output = llm("Say foo:")
+
+    assert llm._llm_type == "edenai"
+    assert llm.feature == "text"
+    assert llm.subfeature == "generation"
+    assert isinstance(output, str)
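Both construction styles therefore remain supported, as the new test confirms. A short side-by-side sketch (parameter values illustrative):

```python
# preferred: dedicated fields (temperature is range-checked, None values filtered)
llm_new = EdenAI(provider="openai", temperature=0.2, max_tokens=250)

# deprecated but still accepted: raw params merged into the request payload
llm_old = EdenAI(provider="openai", params={"temperature": 0.2, "max_tokens": 250})
```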