diff --git a/libs/partners/together/langchain_together/llms.py b/libs/partners/together/langchain_together/llms.py
index 00c56bc915e..91d38b90801 100644
--- a/libs/partners/together/langchain_together/llms.py
+++ b/libs/partners/together/langchain_together/llms.py
@@ -1,5 +1,6 @@
 """Wrapper around Together AI's Completion API."""
 import logging
+import warnings
 from typing import Any, Dict, List, Optional
 
 import requests
@@ -34,13 +35,14 @@ class Together(LLM):
             model = Together(model_name="mistralai/Mixtral-8x7B-Instruct-v0.1")
     """
-    base_url: str = "https://api.together.xyz/inference"
-    """Base inference API URL."""
+    base_url: str = "https://api.together.xyz/v1/completions"
+    """Base completions API URL."""
     together_api_key: SecretStr
     """Together AI API key. Get it here:
     https://api.together.xyz/settings/api-keys"""
     model: str
     """Model name. Available models listed here:
-    https://docs.together.ai/docs/inference-models
+    Base Models: https://docs.together.ai/docs/inference-models#language-models
+    Chat Models: https://docs.together.ai/docs/inference-models#chat-models
     """
     temperature: Optional[float] = None
     """Model temperature."""
@@ -82,13 +84,28 @@ class Together(LLM):
         )
         return values
 
+    @root_validator()
+    def validate_max_tokens(cls, values: Dict) -> Dict:
+        """
+        The v1 completions endpoint has max_tokens as a required parameter.
+        Set a default value and warn if the parameter is missing.
+        """
+        if values.get("max_tokens") is None:
+            warnings.warn(
+                "The completions endpoint has 'max_tokens' as a required argument. "
+                "The default value is being set to 200. "
+                "Consider setting this value when initializing the LLM."
+            )
+            values["max_tokens"] = 200  # default value
+        return values
+
     @property
     def _llm_type(self) -> str:
         """Return type of model."""
         return "together"
 
     def _format_output(self, output: dict) -> str:
-        return output["output"]["choices"][0]["text"]
+        return output["choices"][0]["text"]
 
     @staticmethod
     def get_user_agent() -> str:
@@ -148,9 +165,6 @@ class Together(LLM):
         )
 
         data = response.json()
-        if data.get("status") != "finished":
-            err_msg = data.get("error", "Undefined Error")
-            raise Exception(err_msg)
 
         output = self._format_output(data)
 
@@ -203,9 +217,5 @@ class Together(LLM):
 
             response_json = await response.json()
 
-            if response_json.get("status") != "finished":
-                err_msg = response_json.get("error", "Undefined Error")
-                raise Exception(err_msg)
-
             output = self._format_output(response_json)
             return output
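
Reviewer note: a minimal usage sketch of the changed behavior, under a couple of assumptions not shown in the hunks above — that `Together` exposes a `max_tokens` field (implied by the validator reading `values.get("max_tokens")`) and accepts `together_api_key` as a plain string coerced to `SecretStr`:

```python
import warnings

from langchain_together import Together

# With max_tokens omitted, the new root_validator should warn and
# fall back to 200 instead of letting the v1 endpoint reject the call.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    llm = Together(
        model="mistralai/Mixtral-8x7B-Instruct-v0.1",
        together_api_key="<key>",  # placeholder; normally read from the env
    )
assert llm.max_tokens == 200
assert any("max_tokens" in str(w.message) for w in caught)

# Setting max_tokens explicitly avoids the warning; requests now go to
# the new default base_url, https://api.together.xyz/v1/completions.
llm = Together(
    model="mistralai/Mixtral-8x7B-Instruct-v0.1",
    together_api_key="<key>",
    max_tokens=256,
)
```

Warn-and-default keeps existing call sites working while still surfacing that the v1 endpoint, unlike the old `/inference` one, treats `max_tokens` as required; raising instead would have been a breaking change.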
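
The `_format_output` change and the removal of the `status != "finished"` checks both follow from the v1 response shape. An abridged body, assumed from the OpenAI-compatible layout of the v1 completions endpoint (only `choices[0]["text"]` is actually read by the wrapper; errors are presumed to surface as HTTP status codes rather than a body-level `status` field):

```python
# Old /inference bodies nested the completion under an "output" key;
# the v1 /completions body exposes "choices" at the top level.
v1_body = {
    "choices": [{"text": " Hello!"}],
    # id, usage, etc. omitted for brevity
}
assert v1_body["choices"][0]["text"] == " Hello!"
```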