langchain[patch]: Cerebrium model_api_request deprecation (#12704)

- **Description:** As part of my conversation with Cerebrium team,
`model_api_request` will be no longer available in cerebrium lib so it
needs to be replaced.
  - **Issue:** #12705 12705,
  - **Dependencies:** Cerebrium team (agreed)
  - **Tag maintainer:** @eyurtsev 
  - **Twitter handle:** No official Twitter account sorry :D

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
geret1 2023-12-04 18:26:32 +01:00 committed by GitHub
parent ee94ef55ee
commit 50aee687c6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,6 +1,7 @@
import logging
from typing import Any, Dict, List, Mapping, Optional
import requests
from langchain_core.pydantic_v1 import Extra, Field, root_validator
from langchain.callbacks.manager import CallbackManagerForLLMRun
@ -89,24 +90,21 @@ class CerebriumAI(LLM):
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Call to CerebriumAI endpoint."""
try:
from cerebrium import model_api_request
except ImportError:
raise ValueError(
"Could not import cerebrium python package. "
"Please install it with `pip install cerebrium`."
)
headers: Dict = {
"Authorization": self.cerebriumai_api_key,
"Content-Type": "application/json",
}
params = self.model_kwargs or {}
response = model_api_request(
self.endpoint_url,
{"prompt": prompt, **params, **kwargs},
self.cerebriumai_api_key,
)
text = response["data"]["result"]
if stop is not None:
# I believe this is required since the stop tokens
# are not enforced by the model parameters
text = enforce_stop_tokens(text, stop)
return text
payload = {"prompt": prompt, **params, **kwargs}
response = requests.post(self.endpoint_url, json=payload, headers=headers)
if response.status_code == 200:
data = response.json()
text = data["result"]
if stop is not None:
# I believe this is required since the stop tokens
# are not enforced by the model parameters
text = enforce_stop_tokens(text, stop)
return text
else:
response.raise_for_status()
return ""