import json from typing import Any, Dict, Generator, Iterator, List, Optional, Union import requests from langchain_core.callbacks.manager import CallbackManagerForLLMRun from langchain_core.language_models.llms import LLM from langchain_core.outputs import GenerationChunk from langchain_core.pydantic_v1 import Extra, root_validator from langchain_core.utils import get_from_dict_or_env class SVEndpointHandler: """ SambaNova Systems Interface for Sambaverse endpoint. :param str host_url: Base URL of the DaaS API service """ API_BASE_PATH = "/api/predict" def __init__(self, host_url: str): """ Initialize the SVEndpointHandler. :param str host_url: Base URL of the DaaS API service """ self.host_url = host_url self.http_session = requests.Session() @staticmethod def _process_response(response: requests.Response) -> Dict: """ Processes the API response and returns the resulting dict. All resulting dicts, regardless of success or failure, will contain the `status_code` key with the API response status code. If the API returned an error, the resulting dict will contain the key `detail` with the error message. If the API call was successful, the resulting dict will contain the key `data` with the response data. :param requests.Response response: the response object to process :return: the response dict :rtype: dict """ result: Dict[str, Any] = {} try: text_result = response.text.strip().split("\n")[-1] result = {"data": json.loads("".join(text_result.split("data: ")[1:]))} except Exception as e: result["detail"] = str(e) if "status_code" not in result: result["status_code"] = response.status_code return result @staticmethod def _process_streaming_response( response: requests.Response, ) -> Generator[GenerationChunk, None, None]: """Process the streaming response""" try: import sseclient except ImportError: raise ImportError( "could not import sseclient library" "Please install it with `pip install sseclient-py`." ) client = sseclient.SSEClient(response) close_conn = False for event in client.events(): if event.event == "error_event": close_conn = True text = json.dumps({"event": event.event, "data": event.data}) chunk = GenerationChunk(text=text) yield chunk if close_conn: client.close() def _get_full_url(self) -> str: """ Return the full API URL for a given path. :returns: the full API URL for the sub-path :rtype: str """ return f"{self.host_url}{self.API_BASE_PATH}" def nlp_predict( self, key: str, sambaverse_model_name: Optional[str], input: Union[List[str], str], params: Optional[str] = "", stream: bool = False, ) -> Dict: """ NLP predict using inline input string. :param str project: Project ID in which the endpoint exists :param str endpoint: Endpoint ID :param str key: API Key :param str input_str: Input string :param str params: Input params string :returns: Prediction results :rtype: dict """ if isinstance(input, str): input = [input] parsed_input = [] for element in input: parsed_element = { "conversation_id": "sambaverse-conversation-id", "messages": [ { "message_id": 0, "role": "user", "content": element, } ], } parsed_input.append(json.dumps(parsed_element)) if params: data = {"inputs": parsed_input, "params": json.loads(params)} else: data = {"inputs": parsed_input} response = self.http_session.post( self._get_full_url(), headers={ "key": key, "Content-Type": "application/json", "modelName": sambaverse_model_name, }, json=data, ) return SVEndpointHandler._process_response(response) def nlp_predict_stream( self, key: str, sambaverse_model_name: Optional[str], input: Union[List[str], str], params: Optional[str] = "", ) -> Iterator[GenerationChunk]: """ NLP predict using inline input string. :param str project: Project ID in which the endpoint exists :param str endpoint: Endpoint ID :param str key: API Key :param str input_str: Input string :param str params: Input params string :returns: Prediction results :rtype: dict """ if isinstance(input, str): input = [input] parsed_input = [] for element in input: parsed_element = { "conversation_id": "sambaverse-conversation-id", "messages": [ { "message_id": 0, "role": "user", "content": element, } ], } parsed_input.append(json.dumps(parsed_element)) if params: data = {"inputs": parsed_input, "params": json.loads(params)} else: data = {"inputs": parsed_input} # Streaming output response = self.http_session.post( self._get_full_url(), headers={ "key": key, "Content-Type": "application/json", "modelName": sambaverse_model_name, }, json=data, stream=True, ) for chunk in SVEndpointHandler._process_streaming_response(response): yield chunk class Sambaverse(LLM): """ Sambaverse large language models. To use, you should have the environment variable ``SAMBAVERSE_API_KEY`` set with your API key. get one in https://sambaverse.sambanova.ai read extra documentation in https://docs.sambanova.ai/sambaverse/latest/index.html Example: .. code-block:: python from langchain_community.llms.sambanova import Sambaverse Sambaverse( sambaverse_url="https://sambaverse.sambanova.ai", sambaverse_api_key: "your sambaverse api key", sambaverse_model_name: "Meta/llama-2-7b-chat-hf", streaming: = False model_kwargs={ "do_sample": False, "max_tokens_to_generate": 100, "temperature": 0.7, "top_p": 1.0, "repetition_penalty": 1, "top_k": 50, }, ) """ sambaverse_url: str = "https://sambaverse.sambanova.ai" """Sambaverse url to use""" sambaverse_api_key: str = "" """sambaverse api key""" sambaverse_model_name: Optional[str] = None """sambaverse expert model to use""" model_kwargs: Optional[dict] = None """Key word arguments to pass to the model.""" streaming: Optional[bool] = False """Streaming flag to get streamed response.""" class Config: """Configuration for this pydantic object.""" extra = Extra.forbid @classmethod def is_lc_serializable(cls) -> bool: return True @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validate that api key exists in environment.""" values["sambaverse_url"] = get_from_dict_or_env( values, "sambaverse_url", "SAMBAVERSE_URL" ) values["sambaverse_api_key"] = get_from_dict_or_env( values, "sambaverse_api_key", "SAMBAVERSE_API_KEY" ) values["sambaverse_model_name"] = get_from_dict_or_env( values, "sambaverse_model_name", "SAMBAVERSE_MODEL_NAME" ) return values @property def _identifying_params(self) -> Dict[str, Any]: """Get the identifying parameters.""" return {**{"model_kwargs": self.model_kwargs}} @property def _llm_type(self) -> str: """Return type of llm.""" return "Sambaverse LLM" def _get_tuning_params(self, stop: Optional[List[str]]) -> str: """ Get the tuning parameters to use when calling the LLM. Args: stop: Stop words to use when generating. Model output is cut off at the first occurrence of any of the stop substrings. Returns: The tuning parameters as a JSON string. """ _model_kwargs = self.model_kwargs or {} _stop_sequences = _model_kwargs.get("stop_sequences", []) _stop_sequences = stop or _stop_sequences _model_kwargs["stop_sequences"] = ",".join(f'"{x}"' for x in _stop_sequences) tuning_params_dict = { k: {"type": type(v).__name__, "value": str(v)} for k, v in (_model_kwargs.items()) } tuning_params = json.dumps(tuning_params_dict) return tuning_params def _handle_nlp_predict( self, sdk: SVEndpointHandler, prompt: Union[List[str], str], tuning_params: str, ) -> str: """ Perform an NLP prediction using the Sambaverse endpoint handler. Args: sdk: The SVEndpointHandler to use for the prediction. prompt: The prompt to use for the prediction. tuning_params: The tuning parameters to use for the prediction. Returns: The prediction result. Raises: ValueError: If the prediction fails. """ response = sdk.nlp_predict( self.sambaverse_api_key, self.sambaverse_model_name, prompt, tuning_params ) if response["status_code"] != 200: optional_details = response["details"] optional_message = response["message"] raise ValueError( f"Sambanova /complete call failed with status code " f"{response['status_code']}. Details: {optional_details}" f"{response['status_code']}. Message: {optional_message}" ) return response["data"]["completion"] def _handle_completion_requests( self, prompt: Union[List[str], str], stop: Optional[List[str]] ) -> str: """ Perform a prediction using the Sambaverse endpoint handler. Args: prompt: The prompt to use for the prediction. stop: stop sequences. Returns: The prediction result. Raises: ValueError: If the prediction fails. """ ss_endpoint = SVEndpointHandler(self.sambaverse_url) tuning_params = self._get_tuning_params(stop) return self._handle_nlp_predict(ss_endpoint, prompt, tuning_params) def _handle_nlp_predict_stream( self, sdk: SVEndpointHandler, prompt: Union[List[str], str], tuning_params: str ) -> Iterator[GenerationChunk]: """ Perform a streaming request to the LLM. Args: sdk: The SVEndpointHandler to use for the prediction. prompt: The prompt to use for the prediction. tuning_params: The tuning parameters to use for the prediction. Returns: An iterator of GenerationChunks. """ for chunk in sdk.nlp_predict_stream( self.sambaverse_api_key, self.sambaverse_model_name, prompt, tuning_params ): yield chunk def _stream( self, prompt: Union[List[str], str], stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> Iterator[GenerationChunk]: """Stream the Sambaverse's LLM on the given prompt. Args: prompt: The prompt to pass into the model. stop: Optional list of stop words to use when generating. run_manager: Callback manager for the run. **kwargs: Additional keyword arguments. directly passed to the sambaverse model in API call. Returns: An iterator of GenerationChunks. """ ss_endpoint = SVEndpointHandler(self.sambaverse_url) tuning_params = self._get_tuning_params(stop) try: if self.streaming: for chunk in self._handle_nlp_predict_stream( ss_endpoint, prompt, tuning_params ): if run_manager: run_manager.on_llm_new_token(chunk.text) yield chunk else: return except Exception as e: # Handle any errors raised by the inference endpoint raise ValueError(f"Error raised by the inference endpoint: {e}") from e def _handle_stream_request( self, prompt: Union[List[str], str], stop: Optional[List[str]], run_manager: Optional[CallbackManagerForLLMRun], kwargs: Dict[str, Any], ) -> str: """ Perform a streaming request to the LLM. Args: prompt: The prompt to generate from. stop: Stop words to use when generating. Model output is cut off at the first occurrence of any of the stop substrings. run_manager: Callback manager for the run. **kwargs: Additional keyword arguments. directly passed to the sambaverse model in API call. Returns: The model output as a string. """ completion = "" for chunk in self._stream( prompt=prompt, stop=stop, run_manager=run_manager, **kwargs ): completion += chunk.text return completion def _call( self, prompt: Union[List[str], str], stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> str: """Run the LLM on the given input. Args: prompt: The prompt to generate from. stop: Stop words to use when generating. Model output is cut off at the first occurrence of any of the stop substrings. run_manager: Callback manager for the run. **kwargs: Additional keyword arguments. directly passed to the sambaverse model in API call. Returns: The model output as a string. """ try: if self.streaming: return self._handle_stream_request(prompt, stop, run_manager, kwargs) return self._handle_completion_requests(prompt, stop) except Exception as e: # Handle any errors raised by the inference endpoint raise ValueError(f"Error raised by the inference endpoint: {e}") from e class SSEndpointHandler: """ SambaNova Systems Interface for SambaStudio model endpoints. :param str host_url: Base URL of the DaaS API service """ API_BASE_PATH = "/api" def __init__(self, host_url: str): """ Initialize the SSEndpointHandler. :param str host_url: Base URL of the DaaS API service """ self.host_url = host_url self.http_session = requests.Session() @staticmethod def _process_response(response: requests.Response) -> Dict: """ Processes the API response and returns the resulting dict. All resulting dicts, regardless of success or failure, will contain the `status_code` key with the API response status code. If the API returned an error, the resulting dict will contain the key `detail` with the error message. If the API call was successful, the resulting dict will contain the key `data` with the response data. :param requests.Response response: the response object to process :return: the response dict :rtype: dict """ result: Dict[str, Any] = {} try: result = response.json() except Exception as e: result["detail"] = str(e) if "status_code" not in result: result["status_code"] = response.status_code return result @staticmethod def _process_streaming_response( response: requests.Response, ) -> Generator[GenerationChunk, None, None]: """Process the streaming response""" try: import sseclient except ImportError: raise ImportError( "could not import sseclient library" "Please install it with `pip install sseclient-py`." ) client = sseclient.SSEClient(response) close_conn = False for event in client.events(): if event.event == "error_event": close_conn = True text = json.dumps({"event": event.event, "data": event.data}) chunk = GenerationChunk(text=text) yield chunk if close_conn: client.close() def _get_full_url(self, path: str) -> str: """ Return the full API URL for a given path. :param str path: the sub-path :returns: the full API URL for the sub-path :rtype: str """ return f"{self.host_url}{self.API_BASE_PATH}{path}" def nlp_predict( self, project: str, endpoint: str, key: str, input: Union[List[str], str], params: Optional[str] = "", stream: bool = False, ) -> Dict: """ NLP predict using inline input string. :param str project: Project ID in which the endpoint exists :param str endpoint: Endpoint ID :param str key: API Key :param str input_str: Input string :param str params: Input params string :returns: Prediction results :rtype: dict """ if isinstance(input, str): input = [input] if params: data = {"inputs": input, "params": json.loads(params)} else: data = {"inputs": input} response = self.http_session.post( self._get_full_url(f"/predict/nlp/{project}/{endpoint}"), headers={"key": key}, json=data, ) return SSEndpointHandler._process_response(response) def nlp_predict_stream( self, project: str, endpoint: str, key: str, input: Union[List[str], str], params: Optional[str] = "", ) -> Iterator[GenerationChunk]: """ NLP predict using inline input string. :param str project: Project ID in which the endpoint exists :param str endpoint: Endpoint ID :param str key: API Key :param str input_str: Input string :param str params: Input params string :returns: Prediction results :rtype: dict """ if isinstance(input, str): input = [input] if params: data = {"inputs": input, "params": json.loads(params)} else: data = {"inputs": input} # Streaming output response = self.http_session.post( self._get_full_url(f"/predict/nlp/stream/{project}/{endpoint}"), headers={"key": key}, json=data, stream=True, ) for chunk in SSEndpointHandler._process_streaming_response(response): yield chunk class SambaStudio(LLM): """ SambaStudio large language models. To use, you should have the environment variables ``SAMBASTUDIO_BASE_URL`` set with your SambaStudio environment URL. ``SAMBASTUDIO_PROJECT_ID`` set with your SambaStudio project ID. ``SAMBASTUDIO_ENDPOINT_ID`` set with your SambaStudio endpoint ID. ``SAMBASTUDIO_API_KEY`` set with your SambaStudio endpoint API key. https://sambanova.ai/products/enterprise-ai-platform-sambanova-suite read extra documentation in https://docs.sambanova.ai/sambastudio/latest/index.html Example: .. code-block:: python from langchain_community.llms.sambanova import Sambaverse SambaStudio( sambastudio_base_url="your SambaStudio environment URL", sambastudio_project_id=set with your SambaStudio project ID., sambastudio_endpoint_id=set with your SambaStudio endpoint ID., sambastudio_api_key= set with your SambaStudio endpoint API key., streaming=false model_kwargs={ "do_sample": False, "max_tokens_to_generate": 1000, "temperature": 0.7, "top_p": 1.0, "repetition_penalty": 1, "top_k": 50, }, ) """ sambastudio_base_url: str = "" """Base url to use""" sambastudio_project_id: str = "" """Project id on sambastudio for model""" sambastudio_endpoint_id: str = "" """endpoint id on sambastudio for model""" sambastudio_api_key: str = "" """sambastudio api key""" model_kwargs: Optional[dict] = None """Key word arguments to pass to the model.""" streaming: Optional[bool] = False """Streaming flag to get streamed response.""" class Config: """Configuration for this pydantic object.""" extra = Extra.forbid @classmethod def is_lc_serializable(cls) -> bool: return True @property def _identifying_params(self) -> Dict[str, Any]: """Get the identifying parameters.""" return {**{"model_kwargs": self.model_kwargs}} @property def _llm_type(self) -> str: """Return type of llm.""" return "Sambastudio LLM" @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" values["sambastudio_base_url"] = get_from_dict_or_env( values, "sambastudio_base_url", "SAMBASTUDIO_BASE_URL" ) values["sambastudio_project_id"] = get_from_dict_or_env( values, "sambastudio_project_id", "SAMBASTUDIO_PROJECT_ID" ) values["sambastudio_endpoint_id"] = get_from_dict_or_env( values, "sambastudio_endpoint_id", "SAMBASTUDIO_ENDPOINT_ID" ) values["sambastudio_api_key"] = get_from_dict_or_env( values, "sambastudio_api_key", "SAMBASTUDIO_API_KEY" ) return values def _get_tuning_params(self, stop: Optional[List[str]]) -> str: """ Get the tuning parameters to use when calling the LLM. Args: stop: Stop words to use when generating. Model output is cut off at the first occurrence of any of the stop substrings. Returns: The tuning parameters as a JSON string. """ _model_kwargs = self.model_kwargs or {} _stop_sequences = _model_kwargs.get("stop_sequences", []) _stop_sequences = stop or _stop_sequences # _model_kwargs['stop_sequences'] = ','.join( # f"'{x}'" for x in _stop_sequences) tuning_params_dict = { k: {"type": type(v).__name__, "value": str(v)} for k, v in (_model_kwargs.items()) } tuning_params = json.dumps(tuning_params_dict) return tuning_params def _handle_nlp_predict( self, sdk: SSEndpointHandler, prompt: Union[List[str], str], tuning_params: str ) -> str: """ Perform an NLP prediction using the SambaStudio endpoint handler. Args: sdk: The SSEndpointHandler to use for the prediction. prompt: The prompt to use for the prediction. tuning_params: The tuning parameters to use for the prediction. Returns: The prediction result. Raises: ValueError: If the prediction fails. """ response = sdk.nlp_predict( self.sambastudio_project_id, self.sambastudio_endpoint_id, self.sambastudio_api_key, prompt, tuning_params, ) if response["status_code"] != 200: optional_detail = response["detail"] raise ValueError( f"Sambanova /complete call failed with status code " f"{response['status_code']}. Details: {optional_detail}" ) return response["data"][0]["completion"] def _handle_completion_requests( self, prompt: Union[List[str], str], stop: Optional[List[str]] ) -> str: """ Perform a prediction using the SambaStudio endpoint handler. Args: prompt: The prompt to use for the prediction. stop: stop sequences. Returns: The prediction result. Raises: ValueError: If the prediction fails. """ ss_endpoint = SSEndpointHandler(self.sambastudio_base_url) tuning_params = self._get_tuning_params(stop) return self._handle_nlp_predict(ss_endpoint, prompt, tuning_params) def _handle_nlp_predict_stream( self, sdk: SSEndpointHandler, prompt: Union[List[str], str], tuning_params: str ) -> Iterator[GenerationChunk]: """ Perform a streaming request to the LLM. Args: sdk: The SVEndpointHandler to use for the prediction. prompt: The prompt to use for the prediction. tuning_params: The tuning parameters to use for the prediction. Returns: An iterator of GenerationChunks. """ for chunk in sdk.nlp_predict_stream( self.sambastudio_project_id, self.sambastudio_endpoint_id, self.sambastudio_api_key, prompt, tuning_params, ): yield chunk def _stream( self, prompt: Union[List[str], str], stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> Iterator[GenerationChunk]: """Call out to Sambanova's complete endpoint. Args: prompt: The prompt to pass into the model. stop: Optional list of stop words to use when generating. Returns: The string generated by the model. """ ss_endpoint = SSEndpointHandler(self.sambastudio_base_url) tuning_params = self._get_tuning_params(stop) try: if self.streaming: for chunk in self._handle_nlp_predict_stream( ss_endpoint, prompt, tuning_params ): if run_manager: run_manager.on_llm_new_token(chunk.text) yield chunk else: return except Exception as e: # Handle any errors raised by the inference endpoint raise ValueError(f"Error raised by the inference endpoint: {e}") from e def _handle_stream_request( self, prompt: Union[List[str], str], stop: Optional[List[str]], run_manager: Optional[CallbackManagerForLLMRun], kwargs: Dict[str, Any], ) -> str: """ Perform a streaming request to the LLM. Args: prompt: The prompt to generate from. stop: Stop words to use when generating. Model output is cut off at the first occurrence of any of the stop substrings. run_manager: Callback manager for the run. **kwargs: Additional keyword arguments. directly passed to the sambaverse model in API call. Returns: The model output as a string. """ completion = "" for chunk in self._stream( prompt=prompt, stop=stop, run_manager=run_manager, **kwargs ): completion += chunk.text return completion def _call( self, prompt: Union[List[str], str], stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> str: """Call out to Sambanova's complete endpoint. Args: prompt: The prompt to pass into the model. stop: Optional list of stop words to use when generating. Returns: The string generated by the model. """ if stop is not None: raise Exception("stop not implemented") try: if self.streaming: return self._handle_stream_request(prompt, stop, run_manager, kwargs) return self._handle_completion_requests(prompt, stop) except Exception as e: # Handle any errors raised by the inference endpoint raise ValueError(f"Error raised by the inference endpoint: {e}") from e