mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-12 21:11:43 +00:00
EdenAI LLM update. Add models name option (#8963)
This PR follows the **Eden AI (LLM + embeddings) integration**. #8633 We added an optional parameter to choose different AI models for providers (like 'text-bison' for provider 'google', 'text-davinci-003' for provider 'openai', etc.). Usage: ```python llm = EdenAI( feature="text", provider="google", params={ "model": "text-bison", # new "temperature": 0.2, "max_tokens": 250, }, ) ``` You can also change the provider + model after initialization ```python llm = EdenAI( feature="text", provider="google", params={ "temperature": 0.2, "max_tokens": 250, }, ) prompt = """ hi """ llm(prompt, providers='openai', model='text-davinci-003') # change provider & model ``` The jupyter notebook as been updated with an example well. Ping: @hwchase17, @baskaryan --------- Co-authored-by: RedhaWassim <rwasssim@gmail.com> Co-authored-by: sam <melaine.samy@gmail.com>
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.pydantic_v1 import BaseModel, Extra, Field, root_validator
|
||||
@@ -14,9 +14,15 @@ class EdenAiEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
edenai_api_key: Optional[str] = Field(None, description="EdenAI API Token")
|
||||
|
||||
provider: Optional[str] = "openai"
|
||||
provider: str = "openai"
|
||||
"""embedding provider to use (eg: openai,google etc.)"""
|
||||
|
||||
model: Optional[str] = None
|
||||
"""
|
||||
model name for above provider (eg: 'text-davinci-003' for openai)
|
||||
available models are shown on https://docs.edenai.co/ under 'available providers'
|
||||
"""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
@@ -30,6 +36,12 @@ class EdenAiEmbeddings(BaseModel, Embeddings):
|
||||
)
|
||||
return values
|
||||
|
||||
@staticmethod
|
||||
def get_user_agent() -> str:
|
||||
from langchain import __version__
|
||||
|
||||
return f"langchain/{__version__}"
|
||||
|
||||
def _generate_embeddings(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Compute embeddings using EdenAi api."""
|
||||
url = "https://api.edenai.run/v2/text/embeddings"
|
||||
@@ -38,9 +50,14 @@ class EdenAiEmbeddings(BaseModel, Embeddings):
|
||||
"accept": "application/json",
|
||||
"content-type": "application/json",
|
||||
"authorization": f"Bearer {self.edenai_api_key}",
|
||||
"User-Agent": self.get_user_agent(),
|
||||
}
|
||||
|
||||
payload = {"texts": texts, "providers": self.provider}
|
||||
payload: Dict[str, Any] = {"texts": texts, "providers": self.provider}
|
||||
|
||||
if self.model is not None:
|
||||
payload["settings"] = {self.provider: self.model}
|
||||
|
||||
request = Requests(headers=headers)
|
||||
response = request.post(url=url, data=payload)
|
||||
if response.status_code >= 500:
|
||||
@@ -55,6 +72,11 @@ class EdenAiEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
temp = response.json()
|
||||
|
||||
provider_response = temp[self.provider]
|
||||
if provider_response.get("status") == "fail":
|
||||
err_msg = provider_response.get("error", {}).get("message")
|
||||
raise Exception(err_msg)
|
||||
|
||||
embeddings = []
|
||||
for embed_item in temp[self.provider]["items"]:
|
||||
embedding = embed_item["embedding"]
|
||||
|
@@ -41,13 +41,24 @@ class EdenAI(LLM):
|
||||
"""Subfeature of above feature, use generation by default"""
|
||||
|
||||
provider: str
|
||||
"""Geneerative provider to use (eg: openai,stabilityai,cohere,google etc.)"""
|
||||
"""Generative provider to use (eg: openai,stabilityai,cohere,google etc.)"""
|
||||
|
||||
params: Dict[str, Any]
|
||||
model: Optional[str] = None
|
||||
"""
|
||||
Parameters to pass to above subfeature (excluding 'providers' & 'text')
|
||||
ref text: https://docs.edenai.co/reference/text_generation_create
|
||||
ref image: https://docs.edenai.co/reference/text_generation_create
|
||||
model name for above provider (eg: 'text-davinci-003' for openai)
|
||||
available models are shown on https://docs.edenai.co/ under 'available providers'
|
||||
"""
|
||||
|
||||
# Optional parameters to add depending of chosen feature
|
||||
# see api reference for more infos
|
||||
temperature: Optional[float] = Field(default=None, ge=0, le=1) # for text
|
||||
max_tokens: Optional[int] = Field(default=None, ge=0) # for text
|
||||
resolution: Optional[Literal["256x256", "512x512", "1024x1024"]] = None # for image
|
||||
|
||||
params: Dict[str, Any] = Field(default_factory=dict)
|
||||
"""
|
||||
DEPRECATED: use temperature, max_tokens, resolution directly
|
||||
optional parameters to pass to api
|
||||
"""
|
||||
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
@@ -98,6 +109,12 @@ class EdenAI(LLM):
|
||||
else:
|
||||
return output[self.provider]["items"][0]["image"]
|
||||
|
||||
@staticmethod
|
||||
def get_user_agent() -> str:
|
||||
from langchain import __version__
|
||||
|
||||
return f"langchain/{__version__}"
|
||||
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
@@ -112,7 +129,6 @@ class EdenAI(LLM):
|
||||
|
||||
Returns:
|
||||
json formatted str response.
|
||||
|
||||
"""
|
||||
stops = None
|
||||
if self.stop_sequences is not None and stop is not None:
|
||||
@@ -125,16 +141,28 @@ class EdenAI(LLM):
|
||||
stops = stop
|
||||
|
||||
url = f"{self.base_url}/{self.feature}/{self.subfeature}"
|
||||
headers = {"Authorization": f"Bearer {self.edenai_api_key}"}
|
||||
payload = {
|
||||
**self.params,
|
||||
"providers": self.provider,
|
||||
"num_images": 1, # always limit to 1 (ignored for text)
|
||||
"text": prompt,
|
||||
**kwargs,
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.edenai_api_key}",
|
||||
"User-Agent": self.get_user_agent(),
|
||||
}
|
||||
payload: Dict[str, Any] = {
|
||||
"providers": self.provider,
|
||||
"text": prompt,
|
||||
"max_tokens": self.max_tokens,
|
||||
"temperature": self.temperature,
|
||||
"resolution": self.resolution,
|
||||
**self.params,
|
||||
**kwargs,
|
||||
"num_images": 1, # always limit to 1 (ignored for text)
|
||||
}
|
||||
request = Requests(headers=headers)
|
||||
|
||||
# filter None values to not pass them to the http payload
|
||||
payload = {k: v for k, v in payload.items() if v is not None}
|
||||
|
||||
if self.model is not None:
|
||||
payload["settings"] = {self.provider: self.model}
|
||||
|
||||
request = Requests(headers=headers)
|
||||
response = request.post(url=url, data=payload)
|
||||
|
||||
if response.status_code >= 500:
|
||||
@@ -147,7 +175,13 @@ class EdenAI(LLM):
|
||||
f"{response.status_code}: {response.text}"
|
||||
)
|
||||
|
||||
output = self._format_output(response.json())
|
||||
data = response.json()
|
||||
provider_response = data[self.provider]
|
||||
if provider_response.get("status") == "fail":
|
||||
err_msg = provider_response.get("error", {}).get("message")
|
||||
raise Exception(err_msg)
|
||||
|
||||
output = self._format_output(data)
|
||||
|
||||
if stops is not None:
|
||||
output = enforce_stop_tokens(output, stops)
|
||||
@@ -182,19 +216,29 @@ class EdenAI(LLM):
|
||||
else:
|
||||
stops = stop
|
||||
|
||||
print("Running the acall")
|
||||
url = f"{self.base_url}/{self.feature}/{self.subfeature}"
|
||||
headers = {"Authorization": f"Bearer {self.edenai_api_key}"}
|
||||
payload = {
|
||||
**self.params,
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.edenai_api_key}",
|
||||
"User-Agent": self.get_user_agent(),
|
||||
}
|
||||
payload: Dict[str, Any] = {
|
||||
"providers": self.provider,
|
||||
"num_images": 1, # always limit to 1 (ignored for text)
|
||||
"text": prompt,
|
||||
"max_tokens": self.max_tokens,
|
||||
"temperature": self.temperature,
|
||||
"resolution": self.resolution,
|
||||
**self.params,
|
||||
**kwargs,
|
||||
"num_images": 1, # always limit to 1 (ignored for text)
|
||||
}
|
||||
|
||||
# filter `None` values to not pass them to the http payload as null
|
||||
payload = {k: v for k, v in payload.items() if v is not None}
|
||||
|
||||
if self.model is not None:
|
||||
payload["settings"] = {self.provider: self.model}
|
||||
|
||||
async with ClientSession() as session:
|
||||
print("Requesting")
|
||||
async with session.post(url, json=payload, headers=headers) as response:
|
||||
if response.status >= 500:
|
||||
raise Exception(f"EdenAI Server: Error {response.status}")
|
||||
@@ -209,6 +253,10 @@ class EdenAI(LLM):
|
||||
)
|
||||
|
||||
response_json = await response.json()
|
||||
provider_response = response_json[self.provider]
|
||||
if provider_response.get("status") == "fail":
|
||||
err_msg = provider_response.get("error", {}).get("message")
|
||||
raise Exception(err_msg)
|
||||
|
||||
output = self._format_output(response_json)
|
||||
if stops is not None:
|
||||
|
@@ -13,7 +13,7 @@ from langchain.llms import EdenAI
|
||||
|
||||
def test_edenai_call() -> None:
|
||||
"""Test simple call to edenai."""
|
||||
llm = EdenAI(provider="openai", params={"temperature": 0.2, "max_tokens": 250})
|
||||
llm = EdenAI(provider="openai", temperature=0.2, max_tokens=250)
|
||||
output = llm("Say foo:")
|
||||
|
||||
assert llm._llm_type == "edenai"
|
||||
@@ -24,9 +24,23 @@ def test_edenai_call() -> None:
|
||||
|
||||
async def test_edenai_acall() -> None:
|
||||
"""Test simple call to edenai."""
|
||||
llm = EdenAI(provider="openai", params={"temperature": 0.2, "max_tokens": 250})
|
||||
llm = EdenAI(provider="openai", temperature=0.2, max_tokens=250)
|
||||
output = await llm.agenerate(["Say foo:"])
|
||||
assert llm._llm_type == "edenai"
|
||||
assert llm.feature == "text"
|
||||
assert llm.subfeature == "generation"
|
||||
assert isinstance(output, str)
|
||||
|
||||
|
||||
def test_edenai_call_with_old_params() -> None:
|
||||
"""
|
||||
Test simple call to edenai with using `params`
|
||||
to pass optional parameters to api
|
||||
"""
|
||||
llm = EdenAI(provider="openai", params={"temperature": 0.2, "max_tokens": 250})
|
||||
output = llm("Say foo:")
|
||||
|
||||
assert llm._llm_type == "edenai"
|
||||
assert llm.feature == "text"
|
||||
assert llm.subfeature == "generation"
|
||||
assert isinstance(output, str)
|
||||
|
Reference in New Issue
Block a user