langchain[patch]: Cerebrium model_api_request deprecation (#12704)

- **Description:** As part of my conversation with Cerebrium team,
`model_api_request` will be no longer available in cerebrium lib so it
needs to be replaced.
  - **Issue:** #12705 12705,
  - **Dependencies:** Cerebrium team (agreed)
  - **Tag maintainer:** @eyurtsev 
  - **Twitter handle:** No official Twitter account sorry :D

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
geret1 2023-12-04 18:26:32 +01:00 committed by GitHub
parent ee94ef55ee
commit 50aee687c6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,6 +1,7 @@
import logging import logging
from typing import Any, Dict, List, Mapping, Optional from typing import Any, Dict, List, Mapping, Optional
import requests
from langchain_core.pydantic_v1 import Extra, Field, root_validator from langchain_core.pydantic_v1 import Extra, Field, root_validator
from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.callbacks.manager import CallbackManagerForLLMRun
@ -89,24 +90,21 @@ class CerebriumAI(LLM):
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> str: ) -> str:
"""Call to CerebriumAI endpoint.""" headers: Dict = {
try: "Authorization": self.cerebriumai_api_key,
from cerebrium import model_api_request "Content-Type": "application/json",
except ImportError: }
raise ValueError(
"Could not import cerebrium python package. "
"Please install it with `pip install cerebrium`."
)
params = self.model_kwargs or {} params = self.model_kwargs or {}
response = model_api_request( payload = {"prompt": prompt, **params, **kwargs}
self.endpoint_url, response = requests.post(self.endpoint_url, json=payload, headers=headers)
{"prompt": prompt, **params, **kwargs}, if response.status_code == 200:
self.cerebriumai_api_key, data = response.json()
) text = data["result"]
text = response["data"]["result"]
if stop is not None: if stop is not None:
# I believe this is required since the stop tokens # I believe this is required since the stop tokens
# are not enforced by the model parameters # are not enforced by the model parameters
text = enforce_stop_tokens(text, stop) text = enforce_stop_tokens(text, stop)
return text return text
else:
response.raise_for_status()
return ""