together[minor]: Update endpoint to non-deprecated version (#19649)

- **Updating Together.ai Endpoint**: "langchain_together: Updated
Deprecated endpoint for partner package"

- Description: Together's inference API is deprecated, so it was replaced
with the completions API and the corresponding changes were made.
- Twitter handle: @dev_yashmathur
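
For context, a minimal sketch of what the migration means at the raw HTTP
level. This is illustrative only and not part of the commit; the model name
and API key are placeholders, and the request body follows the OpenAI-style
shape that the v1 completions endpoint expects:

```python
import requests

# Old (deprecated): POST https://api.together.xyz/inference
# New:              POST https://api.together.xyz/v1/completions
response = requests.post(
    "https://api.together.xyz/v1/completions",
    headers={"Authorization": "Bearer YOUR_API_KEY"},  # placeholder key
    json={
        "model": "mistralai/Mixtral-8x7B-Instruct-v0.1",  # placeholder model
        "prompt": "Say hello.",
        "max_tokens": 200,  # required by the v1 completions endpoint
    },
)
response.raise_for_status()
print(response.json()["choices"][0]["text"])
```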

---------

Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
Yash Mathur 2024-04-01 02:51:46 +05:30 committed by GitHub
parent 5ab6b39098
commit c42ec58578

@@ -1,5 +1,6 @@
 """Wrapper around Together AI's Completion API."""
 import logging
+import warnings
 from typing import Any, Dict, List, Optional
 
 import requests
@ -34,13 +35,14 @@ class Together(LLM):
model = Together(model_name="mistralai/Mixtral-8x7B-Instruct-v0.1") model = Together(model_name="mistralai/Mixtral-8x7B-Instruct-v0.1")
""" """
base_url: str = "https://api.together.xyz/inference" base_url: str = "https://api.together.xyz/v1/completions"
"""Base inference API URL.""" """Base completions API URL."""
together_api_key: SecretStr together_api_key: SecretStr
"""Together AI API key. Get it here: https://api.together.xyz/settings/api-keys""" """Together AI API key. Get it here: https://api.together.xyz/settings/api-keys"""
model: str model: str
"""Model name. Available models listed here: """Model name. Available models listed here:
https://docs.together.ai/docs/inference-models Base Models: https://docs.together.ai/docs/inference-models#language-models
Chat Models: https://docs.together.ai/docs/inference-models#chat-models
""" """
temperature: Optional[float] = None temperature: Optional[float] = None
"""Model temperature.""" """Model temperature."""
@@ -82,13 +84,28 @@ class Together(LLM):
             )
         return values
 
+    @root_validator()
+    def validate_max_tokens(cls, values: Dict) -> Dict:
+        """
+        The v1 completions endpoint has max_tokens as a required parameter.
+        Set a default value and warn if the parameter is missing.
+        """
+        if values.get("max_tokens") is None:
+            warnings.warn(
+                "The completions endpoint has 'max_tokens' as a required "
+                "argument. The default value is being set to 200. "
+                "Consider setting this value when initializing the LLM."
+            )
+            values["max_tokens"] = 200  # default value
+        return values
+
     @property
     def _llm_type(self) -> str:
         """Return type of model."""
         return "together"
 
     def _format_output(self, output: dict) -> str:
-        return output["output"]["choices"][0]["text"]
+        return output["choices"][0]["text"]
 
     @staticmethod
     def get_user_agent() -> str:
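
The new `validate_max_tokens` root validator makes the required `max_tokens`
parameter fail soft: omitting it produces a warning and a default of 200
rather than a server-side error. A sketch of the expected behavior
(placeholder credentials; the assertions reflect the validator above):

```python
import warnings

from langchain_together import Together  # assumed top-level export

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    llm = Together(
        model="mistralai/Mixtral-8x7B-Instruct-v0.1",
        together_api_key="YOUR_API_KEY",  # placeholder
    )

# max_tokens was omitted, so the validator warned and set the default.
assert llm.max_tokens == 200
assert any("max_tokens" in str(w.message) for w in caught)
```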
@@ -148,9 +165,6 @@ class Together(LLM):
             )
 
         data = response.json()
-        if data.get("status") != "finished":
-            err_msg = data.get("error", "Undefined Error")
-            raise Exception(err_msg)
 
         output = self._format_output(data)
 
@ -203,9 +217,5 @@ class Together(LLM):
response_json = await response.json() response_json = await response.json()
if response_json.get("status") != "finished":
err_msg = response_json.get("error", "Undefined Error")
raise Exception(err_msg)
output = self._format_output(response_json) output = self._format_output(response_json)
return output return output
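
The removed `status` checks and the `_format_output` change reflect the new
response shape: the deprecated inference endpoint wrapped results in an
`output` object alongside a `status` field, while v1 completions returns
`choices` at the top level and signals failures through HTTP status codes.
A sketch with made-up payloads:

```python
# Payload values are invented for illustration.
old_payload = {"status": "finished", "output": {"choices": [{"text": "Hello!"}]}}
new_payload = {"choices": [{"text": "Hello!"}]}

# Old parsing (removed in this commit):
assert old_payload["output"]["choices"][0]["text"] == "Hello!"
# New parsing:
assert new_payload["choices"][0]["text"] == "Hello!"
```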