Mirror of https://github.com/hwchase17/langchain.git, synced 2025-07-12 15:59:56 +00:00
langchain[patch]: Cerebrium model_api_request deprecation (#12704)
- **Description:** As discussed with the Cerebrium team, `model_api_request` will no longer be available in the cerebrium lib, so it needs to be replaced.
- **Issue:** #12705
- **Dependencies:** Cerebrium team (agreed)
- **Tag maintainer:** @eyurtsev
- **Twitter handle:** No official Twitter account, sorry :D

Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in: parent ee94ef55ee, commit 50aee687c6
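For context, here is a minimal usage sketch of the `CerebriumAI` wrapper whose `_call` method this PR changes. The endpoint URL, API key, and model kwargs are placeholders, not values from this PR; only the field names (`endpoint_url`, `cerebriumai_api_key`, `model_kwargs`) come from the code below.

```python
# Minimal sketch of calling the updated CerebriumAI wrapper.
# The endpoint URL, API key, and model kwargs are placeholders (hypothetical values).
from langchain.llms import CerebriumAI

llm = CerebriumAI(
    endpoint_url="https://run.cerebrium.ai/your-deployment/predict",  # placeholder
    cerebriumai_api_key="your-cerebrium-api-key",  # placeholder
    model_kwargs={"max_length": 100},  # placeholder kwargs, forwarded in the JSON payload
)

# After this PR, the prompt (plus model_kwargs) is POSTed as JSON to endpoint_url
# with the API key in the Authorization header, instead of going through
# cerebrium.model_api_request.
print(llm("What is a good name for a company that makes colorful socks?"))
```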
```diff
@@ -1,6 +1,7 @@
 import logging
 from typing import Any, Dict, List, Mapping, Optional
 
+import requests
 from langchain_core.pydantic_v1 import Extra, Field, root_validator
 
 from langchain.callbacks.manager import CallbackManagerForLLMRun
```
```diff
@@ -89,24 +90,21 @@ class CerebriumAI(LLM):
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> str:
         """Call to CerebriumAI endpoint."""
-        try:
-            from cerebrium import model_api_request
-        except ImportError:
-            raise ValueError(
-                "Could not import cerebrium python package. "
-                "Please install it with `pip install cerebrium`."
-            )
-
+        headers: Dict = {
+            "Authorization": self.cerebriumai_api_key,
+            "Content-Type": "application/json",
+        }
         params = self.model_kwargs or {}
-        response = model_api_request(
-            self.endpoint_url,
-            {"prompt": prompt, **params, **kwargs},
-            self.cerebriumai_api_key,
-        )
-        text = response["data"]["result"]
-        if stop is not None:
-            # I believe this is required since the stop tokens
-            # are not enforced by the model parameters
-            text = enforce_stop_tokens(text, stop)
-        return text
+        payload = {"prompt": prompt, **params, **kwargs}
+        response = requests.post(self.endpoint_url, json=payload, headers=headers)
+        if response.status_code == 200:
+            data = response.json()
+            text = data["result"]
+            if stop is not None:
+                # I believe this is required since the stop tokens
+                # are not enforced by the model parameters
+                text = enforce_stop_tokens(text, stop)
+            return text
+        else:
+            response.raise_for_status()
+        return ""
```
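For reference, the replacement HTTP call can be reproduced outside langchain roughly as follows. This is a sketch mirroring the new `_call` body above: the URL and key are placeholders, and the response is assumed to be JSON carrying the generated text under `"result"`, which is what the new code reads.

```python
# Standalone sketch of the request the new _call makes, using plain `requests`.
# Endpoint URL and API key are placeholders; the response is assumed to be JSON
# with the generated text under "result", matching what the new code expects.
import requests

endpoint_url = "https://run.cerebrium.ai/your-deployment/predict"  # placeholder
api_key = "your-cerebrium-api-key"  # placeholder

headers = {
    "Authorization": api_key,
    "Content-Type": "application/json",
}
payload = {"prompt": "Hello, Cerebrium!", "max_length": 100}

response = requests.post(endpoint_url, json=payload, headers=headers)
response.raise_for_status()
print(response.json()["result"])
```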