mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-02 01:23:07 +00:00
community: remove sambaverse (#26265)
removing Sambaverse llm model and references given is not available after Sep/10/2024 <img width="1781" alt="image" src="https://github.com/user-attachments/assets/4dcdb5f7-5264-4a03-b8e5-95c88304e059">
This commit is contained in:
parent
3fc0ea510e
commit
37b72023fe
@ -6,129 +6,11 @@
|
||||
"source": [
|
||||
"# SambaNova\n",
|
||||
"\n",
|
||||
"**[SambaNova](https://sambanova.ai/)'s** [Sambaverse](https://sambaverse.sambanova.ai/) and [Sambastudio](https://sambanova.ai/technology/full-stack-ai-platform) are platforms for running your own open-source models\n",
|
||||
"**[SambaNova](https://sambanova.ai/)'s** [Sambastudio](https://sambanova.ai/technology/full-stack-ai-platform) is a platform for running your own open-source models\n",
|
||||
"\n",
|
||||
"This example goes over how to use LangChain to interact with SambaNova models"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Sambaverse"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"**Sambaverse** allows you to interact with multiple open-source models. You can view the list of available models and interact with them in the [playground](https://sambaverse.sambanova.ai/playground).\n",
|
||||
" **Please note that Sambaverse's free offering is performance-limited.** Companies that are ready to evaluate the production tokens-per-second performance, volume throughput, and 10x lower total cost of ownership (TCO) of SambaNova should [contact us](https://sambaverse.sambanova.ai/contact-us) for a non-limited evaluation instance."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"An API key is required to access Sambaverse models. To get a key, create an account at [sambaverse.sambanova.ai](https://sambaverse.sambanova.ai/)\n",
|
||||
"\n",
|
||||
"The [sseclient-py](https://pypi.org/project/sseclient-py/) package is required to run streaming predictions "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%pip install --quiet sseclient-py==1.8.0"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Register your API key as an environment variable:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"\n",
|
||||
"sambaverse_api_key = \"<Your sambaverse API key>\"\n",
|
||||
"\n",
|
||||
"# Set the environment variables\n",
|
||||
"os.environ[\"SAMBAVERSE_API_KEY\"] = sambaverse_api_key"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Call Sambaverse models directly from LangChain!"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain_community.llms.sambanova import Sambaverse\n",
|
||||
"\n",
|
||||
"llm = Sambaverse(\n",
|
||||
" sambaverse_model_name=\"Meta/llama-2-7b-chat-hf\",\n",
|
||||
" streaming=False,\n",
|
||||
" model_kwargs={\n",
|
||||
" \"do_sample\": True,\n",
|
||||
" \"max_tokens_to_generate\": 1000,\n",
|
||||
" \"temperature\": 0.01,\n",
|
||||
" \"select_expert\": \"llama-2-7b-chat-hf\",\n",
|
||||
" \"process_prompt\": False,\n",
|
||||
" # \"stop_sequences\": '\\\"sequence1\\\",\\\"sequence2\\\"',\n",
|
||||
" # \"repetition_penalty\": 1.0,\n",
|
||||
" # \"top_k\": 50,\n",
|
||||
" # \"top_p\": 1.0\n",
|
||||
" },\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"print(llm.invoke(\"Why should I use open source models?\"))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Streaming response\n",
|
||||
"\n",
|
||||
"from langchain_community.llms.sambanova import Sambaverse\n",
|
||||
"\n",
|
||||
"llm = Sambaverse(\n",
|
||||
" sambaverse_model_name=\"Meta/llama-2-7b-chat-hf\",\n",
|
||||
" streaming=True,\n",
|
||||
" model_kwargs={\n",
|
||||
" \"do_sample\": True,\n",
|
||||
" \"max_tokens_to_generate\": 1000,\n",
|
||||
" \"temperature\": 0.01,\n",
|
||||
" \"select_expert\": \"llama-2-7b-chat-hf\",\n",
|
||||
" \"process_prompt\": False,\n",
|
||||
" # \"stop_sequences\": '\\\"sequence1\\\",\\\"sequence2\\\"',\n",
|
||||
" # \"repetition_penalty\": 1.0,\n",
|
||||
" # \"top_k\": 50,\n",
|
||||
" # \"top_p\": 1.0\n",
|
||||
" },\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"for chunk in llm.stream(\"Why should I use open source models?\"):\n",
|
||||
" print(chunk, end=\"\", flush=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
|
@ -510,12 +510,6 @@ def _import_sagemaker_endpoint() -> Type[BaseLLM]:
|
||||
return SagemakerEndpoint
|
||||
|
||||
|
||||
def _import_sambaverse() -> Type[BaseLLM]:
|
||||
from langchain_community.llms.sambanova import Sambaverse
|
||||
|
||||
return Sambaverse
|
||||
|
||||
|
||||
def _import_sambastudio() -> Type[BaseLLM]:
|
||||
from langchain_community.llms.sambanova import SambaStudio
|
||||
|
||||
@ -817,8 +811,6 @@ def __getattr__(name: str) -> Any:
|
||||
return _import_rwkv()
|
||||
elif name == "SagemakerEndpoint":
|
||||
return _import_sagemaker_endpoint()
|
||||
elif name == "Sambaverse":
|
||||
return _import_sambaverse()
|
||||
elif name == "SambaStudio":
|
||||
return _import_sambastudio()
|
||||
elif name == "SelfHostedPipeline":
|
||||
@ -954,7 +946,6 @@ __all__ = [
|
||||
"RWKV",
|
||||
"Replicate",
|
||||
"SagemakerEndpoint",
|
||||
"Sambaverse",
|
||||
"SambaStudio",
|
||||
"SelfHostedHuggingFaceLLM",
|
||||
"SelfHostedPipeline",
|
||||
@ -1051,7 +1042,6 @@ def get_type_to_cls_dict() -> Dict[str, Callable[[], Type[BaseLLM]]]:
|
||||
"replicate": _import_replicate,
|
||||
"rwkv": _import_rwkv,
|
||||
"sagemaker_endpoint": _import_sagemaker_endpoint,
|
||||
"sambaverse": _import_sambaverse,
|
||||
"sambastudio": _import_sambastudio,
|
||||
"self_hosted": _import_self_hosted,
|
||||
"self_hosted_hugging_face": _import_self_hosted_hugging_face,
|
||||
|
@ -9,464 +9,6 @@ from langchain_core.utils import get_from_dict_or_env, pre_init
|
||||
from pydantic import ConfigDict
|
||||
|
||||
|
||||
class SVEndpointHandler:
|
||||
"""
|
||||
SambaNova Systems Interface for Sambaverse endpoint.
|
||||
|
||||
:param str host_url: Base URL of the DaaS API service
|
||||
"""
|
||||
|
||||
API_BASE_PATH: str = "/api/predict"
|
||||
|
||||
def __init__(self, host_url: str):
|
||||
"""
|
||||
Initialize the SVEndpointHandler.
|
||||
|
||||
:param str host_url: Base URL of the DaaS API service
|
||||
"""
|
||||
self.host_url = host_url
|
||||
self.http_session = requests.Session()
|
||||
|
||||
@staticmethod
|
||||
def _process_response(response: requests.Response) -> Dict:
|
||||
"""
|
||||
Processes the API response and returns the resulting dict.
|
||||
|
||||
All resulting dicts, regardless of success or failure, will contain the
|
||||
`status_code` key with the API response status code.
|
||||
|
||||
If the API returned an error, the resulting dict will contain the key
|
||||
`detail` with the error message.
|
||||
|
||||
If the API call was successful, the resulting dict will contain the key
|
||||
`data` with the response data.
|
||||
|
||||
:param requests.Response response: the response object to process
|
||||
:return: the response dict
|
||||
:type: dict
|
||||
"""
|
||||
result: Dict[str, Any] = {}
|
||||
try:
|
||||
lines_result = response.text.strip().split("\n")
|
||||
text_result = lines_result[-1]
|
||||
if response.status_code == 200 and json.loads(text_result).get("error"):
|
||||
completion = ""
|
||||
for line in lines_result[:-1]:
|
||||
completion += json.loads(line)["result"]["responses"][0][
|
||||
"stream_token"
|
||||
]
|
||||
text_result = lines_result[-2]
|
||||
result = json.loads(text_result)
|
||||
result["result"]["responses"][0]["completion"] = completion
|
||||
else:
|
||||
result = json.loads(text_result)
|
||||
except Exception as e:
|
||||
result["detail"] = str(e)
|
||||
if "status_code" not in result:
|
||||
result["status_code"] = response.status_code
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _process_streaming_response(
|
||||
response: requests.Response,
|
||||
) -> Generator[Dict, None, None]:
|
||||
"""Process the streaming response"""
|
||||
try:
|
||||
for line in response.iter_lines():
|
||||
chunk = json.loads(line)
|
||||
if "status_code" not in chunk:
|
||||
chunk["status_code"] = response.status_code
|
||||
if chunk["status_code"] == 200 and chunk.get("error"):
|
||||
chunk["result"] = {"responses": [{"stream_token": ""}]}
|
||||
return chunk
|
||||
yield chunk
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Error processing streaming response: {e}")
|
||||
|
||||
def _get_full_url(self) -> str:
|
||||
"""
|
||||
Return the full API URL for a given path.
|
||||
:returns: the full API URL for the sub-path
|
||||
:type: str
|
||||
"""
|
||||
return f"{self.host_url}{self.API_BASE_PATH}"
|
||||
|
||||
def nlp_predict(
|
||||
self,
|
||||
key: str,
|
||||
sambaverse_model_name: Optional[str],
|
||||
input: Union[List[str], str],
|
||||
params: Optional[str] = "",
|
||||
stream: bool = False,
|
||||
) -> Dict:
|
||||
"""
|
||||
NLP predict using inline input string.
|
||||
|
||||
:param str project: Project ID in which the endpoint exists
|
||||
:param str endpoint: Endpoint ID
|
||||
:param str key: API Key
|
||||
:param str input_str: Input string
|
||||
:param str params: Input params string
|
||||
:returns: Prediction results
|
||||
:type: dict
|
||||
"""
|
||||
if params:
|
||||
data = {"instance": input, "params": json.loads(params)}
|
||||
else:
|
||||
data = {"instance": input}
|
||||
response = self.http_session.post(
|
||||
self._get_full_url(),
|
||||
headers={
|
||||
"key": key,
|
||||
"Content-Type": "application/json",
|
||||
"modelName": sambaverse_model_name,
|
||||
},
|
||||
json=data,
|
||||
)
|
||||
return SVEndpointHandler._process_response(response)
|
||||
|
||||
def nlp_predict_stream(
|
||||
self,
|
||||
key: str,
|
||||
sambaverse_model_name: Optional[str],
|
||||
input: Union[List[str], str],
|
||||
params: Optional[str] = "",
|
||||
) -> Iterator[Dict]:
|
||||
"""
|
||||
NLP predict using inline input string.
|
||||
|
||||
:param str project: Project ID in which the endpoint exists
|
||||
:param str endpoint: Endpoint ID
|
||||
:param str key: API Key
|
||||
:param str input_str: Input string
|
||||
:param str params: Input params string
|
||||
:returns: Prediction results
|
||||
:type: dict
|
||||
"""
|
||||
if params:
|
||||
data = {"instance": input, "params": json.loads(params)}
|
||||
else:
|
||||
data = {"instance": input}
|
||||
# Streaming output
|
||||
response = self.http_session.post(
|
||||
self._get_full_url(),
|
||||
headers={
|
||||
"key": key,
|
||||
"Content-Type": "application/json",
|
||||
"modelName": sambaverse_model_name,
|
||||
},
|
||||
json=data,
|
||||
stream=True,
|
||||
)
|
||||
for chunk in SVEndpointHandler._process_streaming_response(response):
|
||||
yield chunk
|
||||
|
||||
|
||||
class Sambaverse(LLM):
|
||||
"""
|
||||
Sambaverse large language models.
|
||||
|
||||
To use, you should have the environment variable ``SAMBAVERSE_API_KEY``
|
||||
set with your API key.
|
||||
|
||||
get one in https://sambaverse.sambanova.ai
|
||||
read extra documentation in https://docs.sambanova.ai/sambaverse/latest/index.html
|
||||
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community.llms.sambanova import Sambaverse
|
||||
Sambaverse(
|
||||
sambaverse_url="https://sambaverse.sambanova.ai",
|
||||
sambaverse_api_key="your-sambaverse-api-key",
|
||||
sambaverse_model_name="Meta/llama-2-7b-chat-hf",
|
||||
streaming: = False
|
||||
model_kwargs={
|
||||
"select_expert": "llama-2-7b-chat-hf",
|
||||
"do_sample": False,
|
||||
"max_tokens_to_generate": 100,
|
||||
"temperature": 0.7,
|
||||
"top_p": 1.0,
|
||||
"repetition_penalty": 1.0,
|
||||
"top_k": 50,
|
||||
"process_prompt": False
|
||||
},
|
||||
)
|
||||
"""
|
||||
|
||||
sambaverse_url: str = ""
|
||||
"""Sambaverse url to use"""
|
||||
|
||||
sambaverse_api_key: str = ""
|
||||
"""sambaverse api key"""
|
||||
|
||||
sambaverse_model_name: Optional[str] = None
|
||||
"""sambaverse expert model to use"""
|
||||
|
||||
model_kwargs: Optional[dict] = None
|
||||
"""Key word arguments to pass to the model."""
|
||||
|
||||
streaming: Optional[bool] = False
|
||||
"""Streaming flag to get streamed response."""
|
||||
|
||||
model_config = ConfigDict(
|
||||
extra="forbid",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
return True
|
||||
|
||||
@pre_init
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key exists in environment."""
|
||||
values["sambaverse_url"] = get_from_dict_or_env(
|
||||
values,
|
||||
"sambaverse_url",
|
||||
"SAMBAVERSE_URL",
|
||||
default="https://sambaverse.sambanova.ai",
|
||||
)
|
||||
values["sambaverse_api_key"] = get_from_dict_or_env(
|
||||
values, "sambaverse_api_key", "SAMBAVERSE_API_KEY"
|
||||
)
|
||||
values["sambaverse_model_name"] = get_from_dict_or_env(
|
||||
values, "sambaverse_model_name", "SAMBAVERSE_MODEL_NAME"
|
||||
)
|
||||
return values
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Dict[str, Any]:
|
||||
"""Get the identifying parameters."""
|
||||
return {**{"model_kwargs": self.model_kwargs}}
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of llm."""
|
||||
return "Sambaverse LLM"
|
||||
|
||||
def _get_tuning_params(self, stop: Optional[List[str]]) -> str:
|
||||
"""
|
||||
Get the tuning parameters to use when calling the LLM.
|
||||
|
||||
Args:
|
||||
stop: Stop words to use when generating. Model output is cut off at the
|
||||
first occurrence of any of the stop substrings.
|
||||
|
||||
Returns:
|
||||
The tuning parameters as a JSON string.
|
||||
"""
|
||||
_model_kwargs = self.model_kwargs or {}
|
||||
_kwarg_stop_sequences = _model_kwargs.get("stop_sequences", [])
|
||||
_stop_sequences = stop or _kwarg_stop_sequences
|
||||
if not _kwarg_stop_sequences:
|
||||
_model_kwargs["stop_sequences"] = ",".join(
|
||||
f'"{x}"' for x in _stop_sequences
|
||||
)
|
||||
tuning_params_dict = {
|
||||
k: {"type": type(v).__name__, "value": str(v)}
|
||||
for k, v in (_model_kwargs.items())
|
||||
}
|
||||
_model_kwargs["stop_sequences"] = _kwarg_stop_sequences
|
||||
tuning_params = json.dumps(tuning_params_dict)
|
||||
return tuning_params
|
||||
|
||||
def _handle_nlp_predict(
|
||||
self,
|
||||
sdk: SVEndpointHandler,
|
||||
prompt: Union[List[str], str],
|
||||
tuning_params: str,
|
||||
) -> str:
|
||||
"""
|
||||
Perform an NLP prediction using the Sambaverse endpoint handler.
|
||||
|
||||
Args:
|
||||
sdk: The SVEndpointHandler to use for the prediction.
|
||||
prompt: The prompt to use for the prediction.
|
||||
tuning_params: The tuning parameters to use for the prediction.
|
||||
|
||||
Returns:
|
||||
The prediction result.
|
||||
|
||||
Raises:
|
||||
ValueError: If the prediction fails.
|
||||
"""
|
||||
response = sdk.nlp_predict(
|
||||
self.sambaverse_api_key, self.sambaverse_model_name, prompt, tuning_params
|
||||
)
|
||||
if response["status_code"] != 200:
|
||||
error = response.get("error")
|
||||
if error:
|
||||
optional_code = error.get("code")
|
||||
optional_details = error.get("details")
|
||||
optional_message = error.get("message")
|
||||
raise RuntimeError(
|
||||
f"Sambanova /complete call failed with status code "
|
||||
f"{response['status_code']}.\n"
|
||||
f"Message: {optional_message}\n"
|
||||
f"Details: {optional_details}\n"
|
||||
f"Code: {optional_code}\n"
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Sambanova /complete call failed with status code "
|
||||
f"{response['status_code']}."
|
||||
f"{response}."
|
||||
)
|
||||
return response["result"]["responses"][0]["completion"]
|
||||
|
||||
def _handle_completion_requests(
|
||||
self, prompt: Union[List[str], str], stop: Optional[List[str]]
|
||||
) -> str:
|
||||
"""
|
||||
Perform a prediction using the Sambaverse endpoint handler.
|
||||
|
||||
Args:
|
||||
prompt: The prompt to use for the prediction.
|
||||
stop: stop sequences.
|
||||
|
||||
Returns:
|
||||
The prediction result.
|
||||
|
||||
Raises:
|
||||
ValueError: If the prediction fails.
|
||||
"""
|
||||
ss_endpoint = SVEndpointHandler(self.sambaverse_url)
|
||||
tuning_params = self._get_tuning_params(stop)
|
||||
return self._handle_nlp_predict(ss_endpoint, prompt, tuning_params)
|
||||
|
||||
def _handle_nlp_predict_stream(
|
||||
self, sdk: SVEndpointHandler, prompt: Union[List[str], str], tuning_params: str
|
||||
) -> Iterator[GenerationChunk]:
|
||||
"""
|
||||
Perform a streaming request to the LLM.
|
||||
|
||||
Args:
|
||||
sdk: The SVEndpointHandler to use for the prediction.
|
||||
prompt: The prompt to use for the prediction.
|
||||
tuning_params: The tuning parameters to use for the prediction.
|
||||
|
||||
Returns:
|
||||
An iterator of GenerationChunks.
|
||||
"""
|
||||
for chunk in sdk.nlp_predict_stream(
|
||||
self.sambaverse_api_key, self.sambaverse_model_name, prompt, tuning_params
|
||||
):
|
||||
if chunk["status_code"] != 200:
|
||||
error = chunk.get("error")
|
||||
if error:
|
||||
optional_code = error.get("code")
|
||||
optional_details = error.get("details")
|
||||
optional_message = error.get("message")
|
||||
raise ValueError(
|
||||
f"Sambanova /complete call failed with status code "
|
||||
f"{chunk['status_code']}.\n"
|
||||
f"Message: {optional_message}\n"
|
||||
f"Details: {optional_details}\n"
|
||||
f"Code: {optional_code}\n"
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Sambanova /complete call failed with status code "
|
||||
f"{chunk['status_code']}."
|
||||
f"{chunk}."
|
||||
)
|
||||
text = chunk["result"]["responses"][0]["stream_token"]
|
||||
generated_chunk = GenerationChunk(text=text)
|
||||
yield generated_chunk
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
prompt: Union[List[str], str],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[GenerationChunk]:
|
||||
"""Stream the Sambaverse's LLM on the given prompt.
|
||||
|
||||
Args:
|
||||
prompt: The prompt to pass into the model.
|
||||
stop: Optional list of stop words to use when generating.
|
||||
run_manager: Callback manager for the run.
|
||||
kwargs: Additional keyword arguments. directly passed
|
||||
to the sambaverse model in API call.
|
||||
|
||||
Returns:
|
||||
An iterator of GenerationChunks.
|
||||
"""
|
||||
ss_endpoint = SVEndpointHandler(self.sambaverse_url)
|
||||
tuning_params = self._get_tuning_params(stop)
|
||||
try:
|
||||
if self.streaming:
|
||||
for chunk in self._handle_nlp_predict_stream(
|
||||
ss_endpoint, prompt, tuning_params
|
||||
):
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(chunk.text)
|
||||
yield chunk
|
||||
else:
|
||||
return
|
||||
except Exception as e:
|
||||
# Handle any errors raised by the inference endpoint
|
||||
raise ValueError(f"Error raised by the inference endpoint: {e}") from e
|
||||
|
||||
def _handle_stream_request(
|
||||
self,
|
||||
prompt: Union[List[str], str],
|
||||
stop: Optional[List[str]],
|
||||
run_manager: Optional[CallbackManagerForLLMRun],
|
||||
kwargs: Dict[str, Any],
|
||||
) -> str:
|
||||
"""
|
||||
Perform a streaming request to the LLM.
|
||||
|
||||
Args:
|
||||
prompt: The prompt to generate from.
|
||||
stop: Stop words to use when generating. Model output is cut off at the
|
||||
first occurrence of any of the stop substrings.
|
||||
run_manager: Callback manager for the run.
|
||||
kwargs: Additional keyword arguments. directly passed
|
||||
to the sambaverse model in API call.
|
||||
|
||||
Returns:
|
||||
The model output as a string.
|
||||
"""
|
||||
completion = ""
|
||||
for chunk in self._stream(
|
||||
prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
|
||||
):
|
||||
completion += chunk.text
|
||||
return completion
|
||||
|
||||
def _call(
|
||||
self,
|
||||
prompt: Union[List[str], str],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Run the LLM on the given input.
|
||||
|
||||
Args:
|
||||
prompt: The prompt to generate from.
|
||||
stop: Stop words to use when generating. Model output is cut off at the
|
||||
first occurrence of any of the stop substrings.
|
||||
run_manager: Callback manager for the run.
|
||||
kwargs: Additional keyword arguments. directly passed
|
||||
to the sambaverse model in API call.
|
||||
|
||||
Returns:
|
||||
The model output as a string.
|
||||
"""
|
||||
try:
|
||||
if self.streaming:
|
||||
return self._handle_stream_request(prompt, stop, run_manager, kwargs)
|
||||
return self._handle_completion_requests(prompt, stop)
|
||||
except Exception as e:
|
||||
# Handle any errors raised by the inference endpoint
|
||||
raise ValueError(f"Error raised by the inference endpoint: {e}") from e
|
||||
|
||||
|
||||
class SSEndpointHandler:
|
||||
"""
|
||||
SambaNova Systems Interface for SambaStudio model endpoints.
|
||||
@ -975,7 +517,7 @@ class SambaStudio(LLM):
|
||||
first occurrence of any of the stop substrings.
|
||||
run_manager: Callback manager for the run.
|
||||
kwargs: Additional keyword arguments. directly passed
|
||||
to the sambaverse model in API call.
|
||||
to the sambastudio model in API call.
|
||||
|
||||
Returns:
|
||||
The model output as a string.
|
||||
|
@ -20,7 +20,7 @@ count=$(git grep -E '(@root_validator)|(@validator)|(@field_validator)|(@pre_ini
|
||||
# PRs that increase the current count will not be accepted.
|
||||
# PRs that decrease update the code in the repository
|
||||
# and allow decreasing the count of are welcome!
|
||||
current_count=129
|
||||
current_count=128
|
||||
|
||||
if [ "$count" -gt "$current_count" ]; then
|
||||
echo "The PR seems to be introducing new usage of @root_validator and/or @field_validator."
|
||||
|
@ -1,28 +1,17 @@
|
||||
"""Test sambanova API wrapper.
|
||||
|
||||
In order to run this test, you need to have an sambaverse api key,
|
||||
and a sambaverse base url, project id, endpoint id, and api key.
|
||||
You'll then need to set SAMBAVERSE_API_KEY, SAMBASTUDIO_BASE_URL,
|
||||
In order to run this test, you need to have a sambastudio base url,
|
||||
project id, endpoint id, and api key.
|
||||
You'll then need to set SAMBASTUDIO_BASE_URL, SAMBASTUDIO_BASE_URI
|
||||
SAMBASTUDIO_PROJECT_ID, SAMBASTUDIO_ENDPOINT_ID, and SAMBASTUDIO_API_KEY
|
||||
environment variables.
|
||||
"""
|
||||
|
||||
from langchain_community.llms.sambanova import SambaStudio, Sambaverse
|
||||
|
||||
|
||||
def test_sambaverse_call() -> None:
|
||||
"""Test simple non-streaming call to sambaverse."""
|
||||
llm = Sambaverse(
|
||||
sambaverse_model_name="Meta/llama-2-7b-chat-hf",
|
||||
model_kwargs={"select_expert": "llama-2-7b-chat-hf"},
|
||||
)
|
||||
output = llm.invoke("What is LangChain")
|
||||
assert output
|
||||
assert isinstance(output, str)
|
||||
from langchain_community.llms.sambanova import SambaStudio
|
||||
|
||||
|
||||
def test_sambastudio_call() -> None:
|
||||
"""Test simple non-streaming call to sambaverse."""
|
||||
"""Test simple non-streaming call to sambastudio."""
|
||||
llm = SambaStudio()
|
||||
output = llm.invoke("What is LangChain")
|
||||
assert output
|
||||
|
@ -77,7 +77,6 @@ EXPECT_ALL = [
|
||||
"RWKV",
|
||||
"Replicate",
|
||||
"SagemakerEndpoint",
|
||||
"Sambaverse",
|
||||
"SambaStudio",
|
||||
"SelfHostedHuggingFaceLLM",
|
||||
"SelfHostedPipeline",
|
||||
|
Loading…
Reference in New Issue
Block a user