community: minor changes sambanova integration (#21231)

- **Description:** fix: variable names in root validator not allowing
pass credentials as named parameters in llm instancing, also added
sambanova's sambaverse and sambastudio llms to __init__.py for module
import
This commit is contained in:
Jorge Piedrahita Ortiz 2024-05-06 15:28:35 -05:00 committed by GitHub
parent d9a61c0fa9
commit df1c10260c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 46 additions and 16 deletions

View File

@ -510,6 +510,18 @@ def _import_sagemaker_endpoint() -> Type[BaseLLM]:
return SagemakerEndpoint return SagemakerEndpoint
def _import_sambaverse() -> Type[BaseLLM]:
from langchain_community.llms.sambanova import Sambaverse
return Sambaverse
def _import_sambastudio() -> Type[BaseLLM]:
from langchain_community.llms.sambanova import SambaStudio
return SambaStudio
def _import_self_hosted() -> Type[BaseLLM]: def _import_self_hosted() -> Type[BaseLLM]:
from langchain_community.llms.self_hosted import SelfHostedPipeline from langchain_community.llms.self_hosted import SelfHostedPipeline
@ -793,6 +805,10 @@ def __getattr__(name: str) -> Any:
return _import_rwkv() return _import_rwkv()
elif name == "SagemakerEndpoint": elif name == "SagemakerEndpoint":
return _import_sagemaker_endpoint() return _import_sagemaker_endpoint()
elif name == "Sambaverse":
return _import_sambaverse()
elif name == "SambaStudio":
return _import_sambastudio()
elif name == "SelfHostedPipeline": elif name == "SelfHostedPipeline":
return _import_self_hosted() return _import_self_hosted()
elif name == "SelfHostedHuggingFaceLLM": elif name == "SelfHostedHuggingFaceLLM":
@ -922,6 +938,8 @@ __all__ = [
"RWKV", "RWKV",
"Replicate", "Replicate",
"SagemakerEndpoint", "SagemakerEndpoint",
"Sambaverse",
"SambaStudio",
"SelfHostedHuggingFaceLLM", "SelfHostedHuggingFaceLLM",
"SelfHostedPipeline", "SelfHostedPipeline",
"SparkLLM", "SparkLLM",
@ -1015,6 +1033,8 @@ def get_type_to_cls_dict() -> Dict[str, Callable[[], Type[BaseLLM]]]:
"replicate": _import_replicate, "replicate": _import_replicate,
"rwkv": _import_rwkv, "rwkv": _import_rwkv,
"sagemaker_endpoint": _import_sagemaker_endpoint, "sagemaker_endpoint": _import_sagemaker_endpoint,
"sambaverse": _import_sambaverse,
"sambastudio": _import_sambastudio,
"self_hosted": _import_self_hosted, "self_hosted": _import_self_hosted,
"self_hosted_hugging_face": _import_self_hosted_hugging_face, "self_hosted_hugging_face": _import_self_hosted_hugging_face,
"stochasticai": _import_stochasticai, "stochasticai": _import_stochasticai,

View File

@ -618,10 +618,10 @@ class SambaStudio(LLM):
from langchain_community.llms.sambanova import Sambaverse from langchain_community.llms.sambanova import Sambaverse
SambaStudio( SambaStudio(
base_url="your SambaStudio environment URL", sambastudio_base_url="your SambaStudio environment URL",
project_id=set with your SambaStudio project ID., sambastudio_project_id=set with your SambaStudio project ID.,
endpoint_id=set with your SambaStudio endpoint ID., sambastudio_endpoint_id=set with your SambaStudio endpoint ID.,
api_token= set with your SambaStudio endpoint API key., sambastudio_api_key= set with your SambaStudio endpoint API key.,
streaming=false streaming=false
model_kwargs={ model_kwargs={
"do_sample": False, "do_sample": False,
@ -634,16 +634,16 @@ class SambaStudio(LLM):
) )
""" """
base_url: str = "" sambastudio_base_url: str = ""
"""Base url to use""" """Base url to use"""
project_id: str = "" sambastudio_project_id: str = ""
"""Project id on sambastudio for model""" """Project id on sambastudio for model"""
endpoint_id: str = "" sambastudio_endpoint_id: str = ""
"""endpoint id on sambastudio for model""" """endpoint id on sambastudio for model"""
api_key: str = "" sambastudio_api_key: str = ""
"""sambastudio api key""" """sambastudio api key"""
model_kwargs: Optional[dict] = None model_kwargs: Optional[dict] = None
@ -674,16 +674,16 @@ class SambaStudio(LLM):
@root_validator() @root_validator()
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment.""" """Validate that api key and python package exists in environment."""
values["base_url"] = get_from_dict_or_env( values["sambastudio_base_url"] = get_from_dict_or_env(
values, "sambastudio_base_url", "SAMBASTUDIO_BASE_URL" values, "sambastudio_base_url", "SAMBASTUDIO_BASE_URL"
) )
values["project_id"] = get_from_dict_or_env( values["sambastudio_project_id"] = get_from_dict_or_env(
values, "sambastudio_project_id", "SAMBASTUDIO_PROJECT_ID" values, "sambastudio_project_id", "SAMBASTUDIO_PROJECT_ID"
) )
values["endpoint_id"] = get_from_dict_or_env( values["sambastudio_endpoint_id"] = get_from_dict_or_env(
values, "sambastudio_endpoint_id", "SAMBASTUDIO_ENDPOINT_ID" values, "sambastudio_endpoint_id", "SAMBASTUDIO_ENDPOINT_ID"
) )
values["api_key"] = get_from_dict_or_env( values["sambastudio_api_key"] = get_from_dict_or_env(
values, "sambastudio_api_key", "SAMBASTUDIO_API_KEY" values, "sambastudio_api_key", "SAMBASTUDIO_API_KEY"
) )
return values return values
@ -729,7 +729,11 @@ class SambaStudio(LLM):
ValueError: If the prediction fails. ValueError: If the prediction fails.
""" """
response = sdk.nlp_predict( response = sdk.nlp_predict(
self.project_id, self.endpoint_id, self.api_key, prompt, tuning_params self.sambastudio_project_id,
self.sambastudio_endpoint_id,
self.sambastudio_api_key,
prompt,
tuning_params,
) )
if response["status_code"] != 200: if response["status_code"] != 200:
optional_detail = response["detail"] optional_detail = response["detail"]
@ -755,7 +759,7 @@ class SambaStudio(LLM):
Raises: Raises:
ValueError: If the prediction fails. ValueError: If the prediction fails.
""" """
ss_endpoint = SSEndpointHandler(self.base_url) ss_endpoint = SSEndpointHandler(self.sambastudio_base_url)
tuning_params = self._get_tuning_params(stop) tuning_params = self._get_tuning_params(stop)
return self._handle_nlp_predict(ss_endpoint, prompt, tuning_params) return self._handle_nlp_predict(ss_endpoint, prompt, tuning_params)
@ -774,7 +778,11 @@ class SambaStudio(LLM):
An iterator of GenerationChunks. An iterator of GenerationChunks.
""" """
for chunk in sdk.nlp_predict_stream( for chunk in sdk.nlp_predict_stream(
self.project_id, self.endpoint_id, self.api_key, prompt, tuning_params self.sambastudio_project_id,
self.sambastudio_endpoint_id,
self.sambastudio_api_key,
prompt,
tuning_params,
): ):
yield chunk yield chunk
@ -794,7 +802,7 @@ class SambaStudio(LLM):
Returns: Returns:
The string generated by the model. The string generated by the model.
""" """
ss_endpoint = SSEndpointHandler(self.base_url) ss_endpoint = SSEndpointHandler(self.sambastudio_base_url)
tuning_params = self._get_tuning_params(stop) tuning_params = self._get_tuning_params(stop)
try: try:
if self.streaming: if self.streaming:

View File

@ -77,6 +77,8 @@ EXPECT_ALL = [
"RWKV", "RWKV",
"Replicate", "Replicate",
"SagemakerEndpoint", "SagemakerEndpoint",
"Sambaverse",
"SambaStudio",
"SelfHostedHuggingFaceLLM", "SelfHostedHuggingFaceLLM",
"SelfHostedPipeline", "SelfHostedPipeline",
"StochasticAI", "StochasticAI",