mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-19 21:33:51 +00:00
community: minor changes sambanova integration (#21231)
- **Description:** fix: variable names in root validator not allowing pass credentials as named parameters in llm instancing, also added sambanova's sambaverse and sambastudio llms to __init__.py for module import
This commit is contained in:
parent
d9a61c0fa9
commit
df1c10260c
@ -510,6 +510,18 @@ def _import_sagemaker_endpoint() -> Type[BaseLLM]:
|
||||
return SagemakerEndpoint
|
||||
|
||||
|
||||
def _import_sambaverse() -> Type[BaseLLM]:
|
||||
from langchain_community.llms.sambanova import Sambaverse
|
||||
|
||||
return Sambaverse
|
||||
|
||||
|
||||
def _import_sambastudio() -> Type[BaseLLM]:
|
||||
from langchain_community.llms.sambanova import SambaStudio
|
||||
|
||||
return SambaStudio
|
||||
|
||||
|
||||
def _import_self_hosted() -> Type[BaseLLM]:
|
||||
from langchain_community.llms.self_hosted import SelfHostedPipeline
|
||||
|
||||
@ -793,6 +805,10 @@ def __getattr__(name: str) -> Any:
|
||||
return _import_rwkv()
|
||||
elif name == "SagemakerEndpoint":
|
||||
return _import_sagemaker_endpoint()
|
||||
elif name == "Sambaverse":
|
||||
return _import_sambaverse()
|
||||
elif name == "SambaStudio":
|
||||
return _import_sambastudio()
|
||||
elif name == "SelfHostedPipeline":
|
||||
return _import_self_hosted()
|
||||
elif name == "SelfHostedHuggingFaceLLM":
|
||||
@ -922,6 +938,8 @@ __all__ = [
|
||||
"RWKV",
|
||||
"Replicate",
|
||||
"SagemakerEndpoint",
|
||||
"Sambaverse",
|
||||
"SambaStudio",
|
||||
"SelfHostedHuggingFaceLLM",
|
||||
"SelfHostedPipeline",
|
||||
"SparkLLM",
|
||||
@ -1015,6 +1033,8 @@ def get_type_to_cls_dict() -> Dict[str, Callable[[], Type[BaseLLM]]]:
|
||||
"replicate": _import_replicate,
|
||||
"rwkv": _import_rwkv,
|
||||
"sagemaker_endpoint": _import_sagemaker_endpoint,
|
||||
"sambaverse": _import_sambaverse,
|
||||
"sambastudio": _import_sambastudio,
|
||||
"self_hosted": _import_self_hosted,
|
||||
"self_hosted_hugging_face": _import_self_hosted_hugging_face,
|
||||
"stochasticai": _import_stochasticai,
|
||||
|
@ -618,10 +618,10 @@ class SambaStudio(LLM):
|
||||
|
||||
from langchain_community.llms.sambanova import Sambaverse
|
||||
SambaStudio(
|
||||
base_url="your SambaStudio environment URL",
|
||||
project_id=set with your SambaStudio project ID.,
|
||||
endpoint_id=set with your SambaStudio endpoint ID.,
|
||||
api_token= set with your SambaStudio endpoint API key.,
|
||||
sambastudio_base_url="your SambaStudio environment URL",
|
||||
sambastudio_project_id=set with your SambaStudio project ID.,
|
||||
sambastudio_endpoint_id=set with your SambaStudio endpoint ID.,
|
||||
sambastudio_api_key= set with your SambaStudio endpoint API key.,
|
||||
streaming=false
|
||||
model_kwargs={
|
||||
"do_sample": False,
|
||||
@ -634,16 +634,16 @@ class SambaStudio(LLM):
|
||||
)
|
||||
"""
|
||||
|
||||
base_url: str = ""
|
||||
sambastudio_base_url: str = ""
|
||||
"""Base url to use"""
|
||||
|
||||
project_id: str = ""
|
||||
sambastudio_project_id: str = ""
|
||||
"""Project id on sambastudio for model"""
|
||||
|
||||
endpoint_id: str = ""
|
||||
sambastudio_endpoint_id: str = ""
|
||||
"""endpoint id on sambastudio for model"""
|
||||
|
||||
api_key: str = ""
|
||||
sambastudio_api_key: str = ""
|
||||
"""sambastudio api key"""
|
||||
|
||||
model_kwargs: Optional[dict] = None
|
||||
@ -674,16 +674,16 @@ class SambaStudio(LLM):
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
values["base_url"] = get_from_dict_or_env(
|
||||
values["sambastudio_base_url"] = get_from_dict_or_env(
|
||||
values, "sambastudio_base_url", "SAMBASTUDIO_BASE_URL"
|
||||
)
|
||||
values["project_id"] = get_from_dict_or_env(
|
||||
values["sambastudio_project_id"] = get_from_dict_or_env(
|
||||
values, "sambastudio_project_id", "SAMBASTUDIO_PROJECT_ID"
|
||||
)
|
||||
values["endpoint_id"] = get_from_dict_or_env(
|
||||
values["sambastudio_endpoint_id"] = get_from_dict_or_env(
|
||||
values, "sambastudio_endpoint_id", "SAMBASTUDIO_ENDPOINT_ID"
|
||||
)
|
||||
values["api_key"] = get_from_dict_or_env(
|
||||
values["sambastudio_api_key"] = get_from_dict_or_env(
|
||||
values, "sambastudio_api_key", "SAMBASTUDIO_API_KEY"
|
||||
)
|
||||
return values
|
||||
@ -729,7 +729,11 @@ class SambaStudio(LLM):
|
||||
ValueError: If the prediction fails.
|
||||
"""
|
||||
response = sdk.nlp_predict(
|
||||
self.project_id, self.endpoint_id, self.api_key, prompt, tuning_params
|
||||
self.sambastudio_project_id,
|
||||
self.sambastudio_endpoint_id,
|
||||
self.sambastudio_api_key,
|
||||
prompt,
|
||||
tuning_params,
|
||||
)
|
||||
if response["status_code"] != 200:
|
||||
optional_detail = response["detail"]
|
||||
@ -755,7 +759,7 @@ class SambaStudio(LLM):
|
||||
Raises:
|
||||
ValueError: If the prediction fails.
|
||||
"""
|
||||
ss_endpoint = SSEndpointHandler(self.base_url)
|
||||
ss_endpoint = SSEndpointHandler(self.sambastudio_base_url)
|
||||
tuning_params = self._get_tuning_params(stop)
|
||||
return self._handle_nlp_predict(ss_endpoint, prompt, tuning_params)
|
||||
|
||||
@ -774,7 +778,11 @@ class SambaStudio(LLM):
|
||||
An iterator of GenerationChunks.
|
||||
"""
|
||||
for chunk in sdk.nlp_predict_stream(
|
||||
self.project_id, self.endpoint_id, self.api_key, prompt, tuning_params
|
||||
self.sambastudio_project_id,
|
||||
self.sambastudio_endpoint_id,
|
||||
self.sambastudio_api_key,
|
||||
prompt,
|
||||
tuning_params,
|
||||
):
|
||||
yield chunk
|
||||
|
||||
@ -794,7 +802,7 @@ class SambaStudio(LLM):
|
||||
Returns:
|
||||
The string generated by the model.
|
||||
"""
|
||||
ss_endpoint = SSEndpointHandler(self.base_url)
|
||||
ss_endpoint = SSEndpointHandler(self.sambastudio_base_url)
|
||||
tuning_params = self._get_tuning_params(stop)
|
||||
try:
|
||||
if self.streaming:
|
||||
|
@ -77,6 +77,8 @@ EXPECT_ALL = [
|
||||
"RWKV",
|
||||
"Replicate",
|
||||
"SagemakerEndpoint",
|
||||
"Sambaverse",
|
||||
"SambaStudio",
|
||||
"SelfHostedHuggingFaceLLM",
|
||||
"SelfHostedPipeline",
|
||||
"StochasticAI",
|
||||
|
Loading…
Reference in New Issue
Block a user