mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-13 08:27:03 +00:00
langchain[patch]: Cerebrium model_api_request deprecation (#12704)
- **Description:** As part of my conversation with Cerebrium team, `model_api_request` will be no longer available in cerebrium lib so it needs to be replaced. - **Issue:** #12705 12705, - **Dependencies:** Cerebrium team (agreed) - **Tag maintainer:** @eyurtsev - **Twitter handle:** No official Twitter account sorry :D --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
ee94ef55ee
commit
50aee687c6
@ -1,6 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, List, Mapping, Optional
|
from typing import Any, Dict, List, Mapping, Optional
|
||||||
|
|
||||||
|
import requests
|
||||||
from langchain_core.pydantic_v1 import Extra, Field, root_validator
|
from langchain_core.pydantic_v1 import Extra, Field, root_validator
|
||||||
|
|
||||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||||
@ -89,24 +90,21 @@ class CerebriumAI(LLM):
|
|||||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Call to CerebriumAI endpoint."""
|
headers: Dict = {
|
||||||
try:
|
"Authorization": self.cerebriumai_api_key,
|
||||||
from cerebrium import model_api_request
|
"Content-Type": "application/json",
|
||||||
except ImportError:
|
}
|
||||||
raise ValueError(
|
|
||||||
"Could not import cerebrium python package. "
|
|
||||||
"Please install it with `pip install cerebrium`."
|
|
||||||
)
|
|
||||||
|
|
||||||
params = self.model_kwargs or {}
|
params = self.model_kwargs or {}
|
||||||
response = model_api_request(
|
payload = {"prompt": prompt, **params, **kwargs}
|
||||||
self.endpoint_url,
|
response = requests.post(self.endpoint_url, json=payload, headers=headers)
|
||||||
{"prompt": prompt, **params, **kwargs},
|
if response.status_code == 200:
|
||||||
self.cerebriumai_api_key,
|
data = response.json()
|
||||||
)
|
text = data["result"]
|
||||||
text = response["data"]["result"]
|
if stop is not None:
|
||||||
if stop is not None:
|
# I believe this is required since the stop tokens
|
||||||
# I believe this is required since the stop tokens
|
# are not enforced by the model parameters
|
||||||
# are not enforced by the model parameters
|
text = enforce_stop_tokens(text, stop)
|
||||||
text = enforce_stop_tokens(text, stop)
|
return text
|
||||||
return text
|
else:
|
||||||
|
response.raise_for_status()
|
||||||
|
return ""
|
||||||
|
Loading…
Reference in New Issue
Block a user