Mirror of https://github.com/hwchase17/langchain.git, synced 2025-06-21 14:18:52 +00:00
community[patch]: sambanova llm integration improvement (#23137)
- **Description:** SambaNova Sambaverse integration improvement: removed the input parsing that altered the raw user input (wrapping it in a conversation envelope), which effectively forced the `process_prompt` parameter to be true.
This commit is contained in:
commit b3e53ffca0 (parent e162893d7f)
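
The gist of the change, sketched below as standalone Python (the envelope fields and the `instance`/`params` keys are taken from the diff; nothing here is a drop-in excerpt of the handler classes):

```python
import json

user_input = "Why should I use open source models?"

# Before: every request wrapped the raw input in a conversation envelope,
# which only works as intended when the endpoint's process_prompt
# parameter is true — effectively making that parameter mandatory.
wrapped = json.dumps(
    {
        "conversation_id": "sambaverse-conversation-id",
        "messages": [{"message_id": 0, "role": "user", "content": user_input}],
    }
)
old_payload = {"instance": wrapped}

# After: the raw user input is passed through unchanged.
new_payload = {"instance": user_input}
```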
@@ -87,7 +87,6 @@
 "        \"do_sample\": True,\n",
 "        \"max_tokens_to_generate\": 1000,\n",
 "        \"temperature\": 0.01,\n",
-"        \"process_prompt\": True,\n",
 "        \"select_expert\": \"llama-2-7b-chat-hf\",\n",
 "        # \"stop_sequences\": '\\\"sequence1\\\",\\\"sequence2\\\"',\n",
 "        # \"repetition_penalty\": 1.0,\n",
@@ -116,7 +115,6 @@
 "        \"do_sample\": True,\n",
 "        \"max_tokens_to_generate\": 1000,\n",
 "        \"temperature\": 0.01,\n",
-"        \"process_prompt\": True,\n",
 "        \"select_expert\": \"llama-2-7b-chat-hf\",\n",
 "        # \"stop_sequences\": '\\\"sequence1\\\",\\\"sequence2\\\"',\n",
 "        # \"repetition_penalty\": 1.0,\n",
@@ -177,14 +175,16 @@
 "import os\n",
 "\n",
 "sambastudio_base_url = \"<Your SambaStudio environment URL>\"\n",
-"# sambastudio_base_uri = \"<Your SambaStudio endpoint base URI>\" # optional, \"api/predict/nlp\" set as default\n",
+"sambastudio_base_uri = (\n",
+"    \"<Your SambaStudio endpoint base URI>\"  # optional, \"api/predict/nlp\" set as default\n",
+")\n",
 "sambastudio_project_id = \"<Your SambaStudio project id>\"\n",
 "sambastudio_endpoint_id = \"<Your SambaStudio endpoint id>\"\n",
 "sambastudio_api_key = \"<Your SambaStudio endpoint API key>\"\n",
 "\n",
 "# Set the environment variables\n",
 "os.environ[\"SAMBASTUDIO_BASE_URL\"] = sambastudio_base_url\n",
-"# os.environ[\"SAMBASTUDIO_BASE_URI\"] = sambastudio_base_uri\n",
+"os.environ[\"SAMBASTUDIO_BASE_URI\"] = sambastudio_base_uri\n",
 "os.environ[\"SAMBASTUDIO_PROJECT_ID\"] = sambastudio_project_id\n",
 "os.environ[\"SAMBASTUDIO_ENDPOINT_ID\"] = sambastudio_endpoint_id\n",
 "os.environ[\"SAMBASTUDIO_API_KEY\"] = sambastudio_api_key"
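
Unescaped for readability, the updated setup cell corresponds to the following Python (content taken verbatim from the `+` lines above):

```python
import os

sambastudio_base_url = "<Your SambaStudio environment URL>"
sambastudio_base_uri = (
    "<Your SambaStudio endpoint base URI>"  # optional, "api/predict/nlp" set as default
)
sambastudio_project_id = "<Your SambaStudio project id>"
sambastudio_endpoint_id = "<Your SambaStudio endpoint id>"
sambastudio_api_key = "<Your SambaStudio endpoint API key>"

# Set the environment variables
os.environ["SAMBASTUDIO_BASE_URL"] = sambastudio_base_url
os.environ["SAMBASTUDIO_BASE_URI"] = sambastudio_base_uri
os.environ["SAMBASTUDIO_PROJECT_ID"] = sambastudio_project_id
os.environ["SAMBASTUDIO_ENDPOINT_ID"] = sambastudio_endpoint_id
os.environ["SAMBASTUDIO_API_KEY"] = sambastudio_api_key
```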
@@ -247,6 +247,40 @@
 "for chunk in llm.stream(\"Why should I use open source models?\"):\n",
 "    print(chunk, end=\"\", flush=True)"
 ]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"You can also call a CoE endpoint expert model "
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"# Using a CoE endpoint\n",
+"\n",
+"from langchain_community.llms.sambanova import SambaStudio\n",
+"\n",
+"llm = SambaStudio(\n",
+"    streaming=False,\n",
+"    model_kwargs={\n",
+"        \"do_sample\": True,\n",
+"        \"max_tokens_to_generate\": 1000,\n",
+"        \"temperature\": 0.01,\n",
+"        \"select_expert\": \"Meta-Llama-3-8B-Instruct\",\n",
+"        # \"repetition_penalty\": 1.0,\n",
+"        # \"top_k\": 50,\n",
+"        # \"top_logprobs\": 0,\n",
+"        # \"top_p\": 1.0\n",
+"    },\n",
+")\n",
+"\n",
+"print(llm.invoke(\"Why should I use open source models?\"))"
+]
 }
 ],
 "metadata": {
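
Rendered as plain Python, the newly added CoE example cell reads:

```python
# Using a CoE endpoint

from langchain_community.llms.sambanova import SambaStudio

llm = SambaStudio(
    streaming=False,
    model_kwargs={
        "do_sample": True,
        "max_tokens_to_generate": 1000,
        "temperature": 0.01,
        "select_expert": "Meta-Llama-3-8B-Instruct",
        # "repetition_penalty": 1.0,
        # "top_k": 50,
        # "top_logprobs": 0,
        # "top_p": 1.0
    },
)

print(llm.invoke("Why should I use open source models?"))
```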
@@ -43,7 +43,7 @@ class SVEndpointHandler:
 
         :param requests.Response response: the response object to process
         :return: the response dict
-        :rtype: dict
+        :type: dict
         """
         result: Dict[str, Any] = {}
         try:
@@ -87,7 +87,7 @@ class SVEndpointHandler:
         """
         Return the full API URL for a given path.
         :returns: the full API URL for the sub-path
-        :rtype: str
+        :type: str
         """
         return f"{self.host_url}{self.API_BASE_PATH}"
 
@@ -108,23 +108,12 @@ class SVEndpointHandler:
         :param str input_str: Input string
         :param str params: Input params string
         :returns: Prediction results
-        :rtype: dict
+        :type: dict
         """
-        parsed_element = {
-            "conversation_id": "sambaverse-conversation-id",
-            "messages": [
-                {
-                    "message_id": 0,
-                    "role": "user",
-                    "content": input,
-                }
-            ],
-        }
-        parsed_input = json.dumps(parsed_element)
         if params:
-            data = {"instance": parsed_input, "params": json.loads(params)}
+            data = {"instance": input, "params": json.loads(params)}
         else:
-            data = {"instance": parsed_input}
+            data = {"instance": input}
         response = self.http_session.post(
             self._get_full_url(),
             headers={
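
The same simplification is applied to both `nlp_predict` (above) and `nlp_predict_stream` (next hunk). A minimal sketch of the new payload logic, with a hypothetical `build_payload` helper factored out purely for illustration:

```python
import json
from typing import Any, Dict, Optional

def build_payload(input: str, params: Optional[str]) -> Dict[str, Any]:
    # Hypothetical helper mirroring the changed lines: the raw input string
    # goes straight into "instance"; no conversation wrapping is applied.
    if params:
        # params arrives as a JSON string and is decoded into the request body
        return {"instance": input, "params": json.loads(params)}
    return {"instance": input}
```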
@@ -152,23 +141,12 @@ class SVEndpointHandler:
         :param str input_str: Input string
         :param str params: Input params string
         :returns: Prediction results
-        :rtype: dict
+        :type: dict
         """
-        parsed_element = {
-            "conversation_id": "sambaverse-conversation-id",
-            "messages": [
-                {
-                    "message_id": 0,
-                    "role": "user",
-                    "content": input,
-                }
-            ],
-        }
-        parsed_input = json.dumps(parsed_element)
         if params:
-            data = {"instance": parsed_input, "params": json.loads(params)}
+            data = {"instance": input, "params": json.loads(params)}
         else:
-            data = {"instance": parsed_input}
+            data = {"instance": input}
         # Streaming output
         response = self.http_session.post(
             self._get_full_url(),
@@ -522,7 +500,7 @@ class SSEndpointHandler:
 
         :param requests.Response response: the response object to process
         :return: the response dict
-        :rtype: dict
+        :type: dict
         """
         result: Dict[str, Any] = {}
         try:
@@ -581,7 +559,7 @@ class SSEndpointHandler:
 
         :param str path: the sub-path
        :returns: the full API URL for the sub-path
-        :rtype: str
+        :type: str
         """
         return f"{self.host_url}/{self.api_base_uri}/{path}"
 
@@ -603,7 +581,7 @@ class SSEndpointHandler:
         :param str input_str: Input string
         :param str params: Input params string
         :returns: Prediction results
-        :rtype: dict
+        :type: dict
         """
         if isinstance(input, str):
             input = [input]
@@ -645,7 +623,7 @@ class SSEndpointHandler:
         :param str input_str: Input string
         :param str params: Input params string
         :returns: Prediction results
-        :rtype: dict
+        :type: dict
         """
         if "nlp" in self.api_base_uri:
             if isinstance(input, str):