Mirror of https://github.com/hwchase17/langchain.git, synced 2025-06-20 13:54:48 +00:00
community: minor changes sambanova integration (#21231)
- **Description:** Fix the variable names used in the root validator, which prevented passing credentials as named parameters when instantiating the LLMs; also add SambaNova's `Sambaverse` and `SambaStudio` LLMs to `__init__.py` so they can be imported from the module.
This commit is contained in:
parent d9a61c0fa9
commit df1c10260c
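For reference, a minimal usage sketch (not part of the commit) of what this change enables: importing `SambaStudio` from `langchain_community.llms` and passing the credentials as named parameters at instantiation. All credential values below are placeholders.

```python
# Minimal sketch, assuming placeholder credentials.
# With this change, SambaStudio/Sambaverse are importable from
# langchain_community.llms and accept credentials as named parameters.
from langchain_community.llms import SambaStudio

llm = SambaStudio(
    sambastudio_base_url="https://your-sambastudio-environment-url",  # placeholder
    sambastudio_project_id="your-project-id",  # placeholder
    sambastudio_endpoint_id="your-endpoint-id",  # placeholder
    sambastudio_api_key="your-endpoint-api-key",  # placeholder
    streaming=False,
    model_kwargs={"do_sample": False},
)

print(llm.invoke("Tell me a joke."))
```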
libs/community/langchain_community/llms/__init__.py

@@ -510,6 +510,18 @@ def _import_sagemaker_endpoint() -> Type[BaseLLM]:
     return SagemakerEndpoint
 
 
+def _import_sambaverse() -> Type[BaseLLM]:
+    from langchain_community.llms.sambanova import Sambaverse
+
+    return Sambaverse
+
+
+def _import_sambastudio() -> Type[BaseLLM]:
+    from langchain_community.llms.sambanova import SambaStudio
+
+    return SambaStudio
+
+
 def _import_self_hosted() -> Type[BaseLLM]:
     from langchain_community.llms.self_hosted import SelfHostedPipeline
 
@@ -793,6 +805,10 @@ def __getattr__(name: str) -> Any:
         return _import_rwkv()
     elif name == "SagemakerEndpoint":
         return _import_sagemaker_endpoint()
+    elif name == "Sambaverse":
+        return _import_sambaverse()
+    elif name == "SambaStudio":
+        return _import_sambastudio()
     elif name == "SelfHostedPipeline":
         return _import_self_hosted()
     elif name == "SelfHostedHuggingFaceLLM":
@@ -922,6 +938,8 @@ __all__ = [
     "RWKV",
     "Replicate",
     "SagemakerEndpoint",
+    "Sambaverse",
+    "SambaStudio",
     "SelfHostedHuggingFaceLLM",
     "SelfHostedPipeline",
     "SparkLLM",
@@ -1015,6 +1033,8 @@ def get_type_to_cls_dict() -> Dict[str, Callable[[], Type[BaseLLM]]]:
         "replicate": _import_replicate,
         "rwkv": _import_rwkv,
         "sagemaker_endpoint": _import_sagemaker_endpoint,
+        "sambaverse": _import_sambaverse,
+        "sambastudio": _import_sambastudio,
         "self_hosted": _import_self_hosted,
         "self_hosted_hugging_face": _import_self_hosted_hugging_face,
         "stochasticai": _import_stochasticai,
libs/community/langchain_community/llms/sambanova.py

@@ -618,10 +618,10 @@ class SambaStudio(LLM):
 
         from langchain_community.llms.sambanova import Sambaverse
         SambaStudio(
-            base_url="your SambaStudio environment URL",
-            project_id=set with your SambaStudio project ID.,
-            endpoint_id=set with your SambaStudio endpoint ID.,
-            api_token= set with your SambaStudio endpoint API key.,
+            sambastudio_base_url="your SambaStudio environment URL",
+            sambastudio_project_id=set with your SambaStudio project ID.,
+            sambastudio_endpoint_id=set with your SambaStudio endpoint ID.,
+            sambastudio_api_key= set with your SambaStudio endpoint API key.,
             streaming=false
             model_kwargs={
                 "do_sample": False,
@@ -634,16 +634,16 @@ class SambaStudio(LLM):
         )
     """
 
-    base_url: str = ""
+    sambastudio_base_url: str = ""
     """Base url to use"""
 
-    project_id: str = ""
+    sambastudio_project_id: str = ""
     """Project id on sambastudio for model"""
 
-    endpoint_id: str = ""
+    sambastudio_endpoint_id: str = ""
     """endpoint id on sambastudio for model"""
 
-    api_key: str = ""
+    sambastudio_api_key: str = ""
     """sambastudio api key"""
 
     model_kwargs: Optional[dict] = None
@@ -674,16 +674,16 @@ class SambaStudio(LLM):
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that api key and python package exists in environment."""
-        values["base_url"] = get_from_dict_or_env(
+        values["sambastudio_base_url"] = get_from_dict_or_env(
             values, "sambastudio_base_url", "SAMBASTUDIO_BASE_URL"
         )
-        values["project_id"] = get_from_dict_or_env(
+        values["sambastudio_project_id"] = get_from_dict_or_env(
             values, "sambastudio_project_id", "SAMBASTUDIO_PROJECT_ID"
         )
-        values["endpoint_id"] = get_from_dict_or_env(
+        values["sambastudio_endpoint_id"] = get_from_dict_or_env(
             values, "sambastudio_endpoint_id", "SAMBASTUDIO_ENDPOINT_ID"
         )
-        values["api_key"] = get_from_dict_or_env(
+        values["sambastudio_api_key"] = get_from_dict_or_env(
             values, "sambastudio_api_key", "SAMBASTUDIO_API_KEY"
         )
         return values
@@ -729,7 +729,11 @@ class SambaStudio(LLM):
             ValueError: If the prediction fails.
         """
         response = sdk.nlp_predict(
-            self.project_id, self.endpoint_id, self.api_key, prompt, tuning_params
+            self.sambastudio_project_id,
+            self.sambastudio_endpoint_id,
+            self.sambastudio_api_key,
+            prompt,
+            tuning_params,
         )
         if response["status_code"] != 200:
             optional_detail = response["detail"]
@@ -755,7 +759,7 @@ class SambaStudio(LLM):
         Raises:
             ValueError: If the prediction fails.
         """
-        ss_endpoint = SSEndpointHandler(self.base_url)
+        ss_endpoint = SSEndpointHandler(self.sambastudio_base_url)
         tuning_params = self._get_tuning_params(stop)
         return self._handle_nlp_predict(ss_endpoint, prompt, tuning_params)
 
@@ -774,7 +778,11 @@ class SambaStudio(LLM):
             An iterator of GenerationChunks.
         """
         for chunk in sdk.nlp_predict_stream(
-            self.project_id, self.endpoint_id, self.api_key, prompt, tuning_params
+            self.sambastudio_project_id,
+            self.sambastudio_endpoint_id,
+            self.sambastudio_api_key,
+            prompt,
+            tuning_params,
         ):
             yield chunk
 
@@ -794,7 +802,7 @@ class SambaStudio(LLM):
         Returns:
             The string generated by the model.
         """
-        ss_endpoint = SSEndpointHandler(self.base_url)
+        ss_endpoint = SSEndpointHandler(self.sambastudio_base_url)
         tuning_params = self._get_tuning_params(stop)
         try:
             if self.streaming:
libs/community/tests/unit_tests/llms/test_imports.py

@@ -77,6 +77,8 @@ EXPECT_ALL = [
     "RWKV",
     "Replicate",
     "SagemakerEndpoint",
+    "Sambaverse",
+    "SambaStudio",
     "SelfHostedHuggingFaceLLM",
     "SelfHostedPipeline",
     "StochasticAI",
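As a usage note (not part of the diff): because the renamed fields are still resolved through `get_from_dict_or_env` in the root validator, credentials can alternatively be supplied via environment variables instead of named parameters. A rough sketch, with placeholder values:

```python
# Rough sketch, assuming placeholder values. validate_environment uses
# get_from_dict_or_env, so each sambastudio_* field falls back to the matching
# SAMBASTUDIO_* environment variable when it is not passed explicitly.
import os

os.environ["SAMBASTUDIO_BASE_URL"] = "https://your-sambastudio-environment-url"
os.environ["SAMBASTUDIO_PROJECT_ID"] = "your-project-id"
os.environ["SAMBASTUDIO_ENDPOINT_ID"] = "your-endpoint-id"
os.environ["SAMBASTUDIO_API_KEY"] = "your-endpoint-api-key"

from langchain_community.llms import SambaStudio

llm = SambaStudio()  # credentials picked up from the environment
```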