mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-20 13:54:48 +00:00
community[patch]: sambanova llm integration improvement (#23137)
- **Description:** sambanova sambaverse integration improvement: removed input parsing that was changing raw user input, and was making to use process prompt parameter as true mandatory
This commit is contained in:
parent
e162893d7f
commit
b3e53ffca0
@ -87,7 +87,6 @@
|
||||
" \"do_sample\": True,\n",
|
||||
" \"max_tokens_to_generate\": 1000,\n",
|
||||
" \"temperature\": 0.01,\n",
|
||||
" \"process_prompt\": True,\n",
|
||||
" \"select_expert\": \"llama-2-7b-chat-hf\",\n",
|
||||
" # \"stop_sequences\": '\\\"sequence1\\\",\\\"sequence2\\\"',\n",
|
||||
" # \"repetition_penalty\": 1.0,\n",
|
||||
@ -116,7 +115,6 @@
|
||||
" \"do_sample\": True,\n",
|
||||
" \"max_tokens_to_generate\": 1000,\n",
|
||||
" \"temperature\": 0.01,\n",
|
||||
" \"process_prompt\": True,\n",
|
||||
" \"select_expert\": \"llama-2-7b-chat-hf\",\n",
|
||||
" # \"stop_sequences\": '\\\"sequence1\\\",\\\"sequence2\\\"',\n",
|
||||
" # \"repetition_penalty\": 1.0,\n",
|
||||
@ -177,14 +175,16 @@
|
||||
"import os\n",
|
||||
"\n",
|
||||
"sambastudio_base_url = \"<Your SambaStudio environment URL>\"\n",
|
||||
"# sambastudio_base_uri = \"<Your SambaStudio endpoint base URI>\" # optional, \"api/predict/nlp\" set as default\n",
|
||||
"sambastudio_base_uri = (\n",
|
||||
" \"<Your SambaStudio endpoint base URI>\" # optional, \"api/predict/nlp\" set as default\n",
|
||||
")\n",
|
||||
"sambastudio_project_id = \"<Your SambaStudio project id>\"\n",
|
||||
"sambastudio_endpoint_id = \"<Your SambaStudio endpoint id>\"\n",
|
||||
"sambastudio_api_key = \"<Your SambaStudio endpoint API key>\"\n",
|
||||
"\n",
|
||||
"# Set the environment variables\n",
|
||||
"os.environ[\"SAMBASTUDIO_BASE_URL\"] = sambastudio_base_url\n",
|
||||
"# os.environ[\"SAMBASTUDIO_BASE_URI\"] = sambastudio_base_uri\n",
|
||||
"os.environ[\"SAMBASTUDIO_BASE_URI\"] = sambastudio_base_uri\n",
|
||||
"os.environ[\"SAMBASTUDIO_PROJECT_ID\"] = sambastudio_project_id\n",
|
||||
"os.environ[\"SAMBASTUDIO_ENDPOINT_ID\"] = sambastudio_endpoint_id\n",
|
||||
"os.environ[\"SAMBASTUDIO_API_KEY\"] = sambastudio_api_key"
|
||||
@ -247,6 +247,40 @@
|
||||
"for chunk in llm.stream(\"Why should I use open source models?\"):\n",
|
||||
" print(chunk, end=\"\", flush=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"You can also call a CoE endpoint expert model "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Using a CoE endpoint\n",
|
||||
"\n",
|
||||
"from langchain_community.llms.sambanova import SambaStudio\n",
|
||||
"\n",
|
||||
"llm = SambaStudio(\n",
|
||||
" streaming=False,\n",
|
||||
" model_kwargs={\n",
|
||||
" \"do_sample\": True,\n",
|
||||
" \"max_tokens_to_generate\": 1000,\n",
|
||||
" \"temperature\": 0.01,\n",
|
||||
" \"select_expert\": \"Meta-Llama-3-8B-Instruct\",\n",
|
||||
" # \"repetition_penalty\": 1.0,\n",
|
||||
" # \"top_k\": 50,\n",
|
||||
" # \"top_logprobs\": 0,\n",
|
||||
" # \"top_p\": 1.0\n",
|
||||
" },\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"print(llm.invoke(\"Why should I use open source models?\"))"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
@ -43,7 +43,7 @@ class SVEndpointHandler:
|
||||
|
||||
:param requests.Response response: the response object to process
|
||||
:return: the response dict
|
||||
:rtype: dict
|
||||
:type: dict
|
||||
"""
|
||||
result: Dict[str, Any] = {}
|
||||
try:
|
||||
@ -87,7 +87,7 @@ class SVEndpointHandler:
|
||||
"""
|
||||
Return the full API URL for a given path.
|
||||
:returns: the full API URL for the sub-path
|
||||
:rtype: str
|
||||
:type: str
|
||||
"""
|
||||
return f"{self.host_url}{self.API_BASE_PATH}"
|
||||
|
||||
@ -108,23 +108,12 @@ class SVEndpointHandler:
|
||||
:param str input_str: Input string
|
||||
:param str params: Input params string
|
||||
:returns: Prediction results
|
||||
:rtype: dict
|
||||
:type: dict
|
||||
"""
|
||||
parsed_element = {
|
||||
"conversation_id": "sambaverse-conversation-id",
|
||||
"messages": [
|
||||
{
|
||||
"message_id": 0,
|
||||
"role": "user",
|
||||
"content": input,
|
||||
}
|
||||
],
|
||||
}
|
||||
parsed_input = json.dumps(parsed_element)
|
||||
if params:
|
||||
data = {"instance": parsed_input, "params": json.loads(params)}
|
||||
data = {"instance": input, "params": json.loads(params)}
|
||||
else:
|
||||
data = {"instance": parsed_input}
|
||||
data = {"instance": input}
|
||||
response = self.http_session.post(
|
||||
self._get_full_url(),
|
||||
headers={
|
||||
@ -152,23 +141,12 @@ class SVEndpointHandler:
|
||||
:param str input_str: Input string
|
||||
:param str params: Input params string
|
||||
:returns: Prediction results
|
||||
:rtype: dict
|
||||
:type: dict
|
||||
"""
|
||||
parsed_element = {
|
||||
"conversation_id": "sambaverse-conversation-id",
|
||||
"messages": [
|
||||
{
|
||||
"message_id": 0,
|
||||
"role": "user",
|
||||
"content": input,
|
||||
}
|
||||
],
|
||||
}
|
||||
parsed_input = json.dumps(parsed_element)
|
||||
if params:
|
||||
data = {"instance": parsed_input, "params": json.loads(params)}
|
||||
data = {"instance": input, "params": json.loads(params)}
|
||||
else:
|
||||
data = {"instance": parsed_input}
|
||||
data = {"instance": input}
|
||||
# Streaming output
|
||||
response = self.http_session.post(
|
||||
self._get_full_url(),
|
||||
@ -522,7 +500,7 @@ class SSEndpointHandler:
|
||||
|
||||
:param requests.Response response: the response object to process
|
||||
:return: the response dict
|
||||
:rtype: dict
|
||||
:type: dict
|
||||
"""
|
||||
result: Dict[str, Any] = {}
|
||||
try:
|
||||
@ -581,7 +559,7 @@ class SSEndpointHandler:
|
||||
|
||||
:param str path: the sub-path
|
||||
:returns: the full API URL for the sub-path
|
||||
:rtype: str
|
||||
:type: str
|
||||
"""
|
||||
return f"{self.host_url}/{self.api_base_uri}/{path}"
|
||||
|
||||
@ -603,7 +581,7 @@ class SSEndpointHandler:
|
||||
:param str input_str: Input string
|
||||
:param str params: Input params string
|
||||
:returns: Prediction results
|
||||
:rtype: dict
|
||||
:type: dict
|
||||
"""
|
||||
if isinstance(input, str):
|
||||
input = [input]
|
||||
@ -645,7 +623,7 @@ class SSEndpointHandler:
|
||||
:param str input_str: Input string
|
||||
:param str params: Input params string
|
||||
:returns: Prediction results
|
||||
:rtype: dict
|
||||
:type: dict
|
||||
"""
|
||||
if "nlp" in self.api_base_uri:
|
||||
if isinstance(input, str):
|
||||
|
Loading…
Reference in New Issue
Block a user