From df1c10260cf127277875f1a011143b2d1ebdf30a Mon Sep 17 00:00:00 2001 From: Jorge Piedrahita Ortiz Date: Mon, 6 May 2024 15:28:35 -0500 Subject: [PATCH] community: minor changes sambanova integration (#21231) - **Description:** fix: variable names in root validator not allowing pass credentials as named parameters in llm instancing, also added sambanova's sambaverse and sambastudio llms to __init__.py for module import --- .../langchain_community/llms/__init__.py | 20 ++++++++++ .../langchain_community/llms/sambanova.py | 40 +++++++++++-------- .../tests/unit_tests/llms/test_imports.py | 2 + 3 files changed, 46 insertions(+), 16 deletions(-) diff --git a/libs/community/langchain_community/llms/__init__.py b/libs/community/langchain_community/llms/__init__.py index f2b115cb9b2..a303cfd1b5e 100644 --- a/libs/community/langchain_community/llms/__init__.py +++ b/libs/community/langchain_community/llms/__init__.py @@ -510,6 +510,18 @@ def _import_sagemaker_endpoint() -> Type[BaseLLM]: return SagemakerEndpoint +def _import_sambaverse() -> Type[BaseLLM]: + from langchain_community.llms.sambanova import Sambaverse + + return Sambaverse + + +def _import_sambastudio() -> Type[BaseLLM]: + from langchain_community.llms.sambanova import SambaStudio + + return SambaStudio + + def _import_self_hosted() -> Type[BaseLLM]: from langchain_community.llms.self_hosted import SelfHostedPipeline @@ -793,6 +805,10 @@ def __getattr__(name: str) -> Any: return _import_rwkv() elif name == "SagemakerEndpoint": return _import_sagemaker_endpoint() + elif name == "Sambaverse": + return _import_sambaverse() + elif name == "SambaStudio": + return _import_sambastudio() elif name == "SelfHostedPipeline": return _import_self_hosted() elif name == "SelfHostedHuggingFaceLLM": @@ -922,6 +938,8 @@ __all__ = [ "RWKV", "Replicate", "SagemakerEndpoint", + "Sambaverse", + "SambaStudio", "SelfHostedHuggingFaceLLM", "SelfHostedPipeline", "SparkLLM", @@ -1015,6 +1033,8 @@ def get_type_to_cls_dict() -> Dict[str, Callable[[], Type[BaseLLM]]]: "replicate": _import_replicate, "rwkv": _import_rwkv, "sagemaker_endpoint": _import_sagemaker_endpoint, + "sambaverse": _import_sambaverse, + "sambastudio": _import_sambastudio, "self_hosted": _import_self_hosted, "self_hosted_hugging_face": _import_self_hosted_hugging_face, "stochasticai": _import_stochasticai, diff --git a/libs/community/langchain_community/llms/sambanova.py b/libs/community/langchain_community/llms/sambanova.py index 74189b934a4..c4e8d27f8b5 100644 --- a/libs/community/langchain_community/llms/sambanova.py +++ b/libs/community/langchain_community/llms/sambanova.py @@ -618,10 +618,10 @@ class SambaStudio(LLM): from langchain_community.llms.sambanova import Sambaverse SambaStudio( - base_url="your SambaStudio environment URL", - project_id=set with your SambaStudio project ID., - endpoint_id=set with your SambaStudio endpoint ID., - api_token= set with your SambaStudio endpoint API key., + sambastudio_base_url="your SambaStudio environment URL", + sambastudio_project_id=set with your SambaStudio project ID., + sambastudio_endpoint_id=set with your SambaStudio endpoint ID., + sambastudio_api_key= set with your SambaStudio endpoint API key., streaming=false model_kwargs={ "do_sample": False, @@ -634,16 +634,16 @@ class SambaStudio(LLM): ) """ - base_url: str = "" + sambastudio_base_url: str = "" """Base url to use""" - project_id: str = "" + sambastudio_project_id: str = "" """Project id on sambastudio for model""" - endpoint_id: str = "" + sambastudio_endpoint_id: str = "" """endpoint id on sambastudio for model""" - api_key: str = "" + sambastudio_api_key: str = "" """sambastudio api key""" model_kwargs: Optional[dict] = None @@ -674,16 +674,16 @@ class SambaStudio(LLM): @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" - values["base_url"] = get_from_dict_or_env( + values["sambastudio_base_url"] = get_from_dict_or_env( values, "sambastudio_base_url", "SAMBASTUDIO_BASE_URL" ) - values["project_id"] = get_from_dict_or_env( + values["sambastudio_project_id"] = get_from_dict_or_env( values, "sambastudio_project_id", "SAMBASTUDIO_PROJECT_ID" ) - values["endpoint_id"] = get_from_dict_or_env( + values["sambastudio_endpoint_id"] = get_from_dict_or_env( values, "sambastudio_endpoint_id", "SAMBASTUDIO_ENDPOINT_ID" ) - values["api_key"] = get_from_dict_or_env( + values["sambastudio_api_key"] = get_from_dict_or_env( values, "sambastudio_api_key", "SAMBASTUDIO_API_KEY" ) return values @@ -729,7 +729,11 @@ class SambaStudio(LLM): ValueError: If the prediction fails. """ response = sdk.nlp_predict( - self.project_id, self.endpoint_id, self.api_key, prompt, tuning_params + self.sambastudio_project_id, + self.sambastudio_endpoint_id, + self.sambastudio_api_key, + prompt, + tuning_params, ) if response["status_code"] != 200: optional_detail = response["detail"] @@ -755,7 +759,7 @@ class SambaStudio(LLM): Raises: ValueError: If the prediction fails. """ - ss_endpoint = SSEndpointHandler(self.base_url) + ss_endpoint = SSEndpointHandler(self.sambastudio_base_url) tuning_params = self._get_tuning_params(stop) return self._handle_nlp_predict(ss_endpoint, prompt, tuning_params) @@ -774,7 +778,11 @@ class SambaStudio(LLM): An iterator of GenerationChunks. """ for chunk in sdk.nlp_predict_stream( - self.project_id, self.endpoint_id, self.api_key, prompt, tuning_params + self.sambastudio_project_id, + self.sambastudio_endpoint_id, + self.sambastudio_api_key, + prompt, + tuning_params, ): yield chunk @@ -794,7 +802,7 @@ class SambaStudio(LLM): Returns: The string generated by the model. """ - ss_endpoint = SSEndpointHandler(self.base_url) + ss_endpoint = SSEndpointHandler(self.sambastudio_base_url) tuning_params = self._get_tuning_params(stop) try: if self.streaming: diff --git a/libs/community/tests/unit_tests/llms/test_imports.py b/libs/community/tests/unit_tests/llms/test_imports.py index 64cfaec9c50..c0eddec93d6 100644 --- a/libs/community/tests/unit_tests/llms/test_imports.py +++ b/libs/community/tests/unit_tests/llms/test_imports.py @@ -77,6 +77,8 @@ EXPECT_ALL = [ "RWKV", "Replicate", "SagemakerEndpoint", + "Sambaverse", + "SambaStudio", "SelfHostedHuggingFaceLLM", "SelfHostedPipeline", "StochasticAI",