diff --git a/libs/langchain/langchain/chat_models/litellm.py b/libs/langchain/langchain/chat_models/litellm.py index d23549dc51c..f429c322188 100644 --- a/libs/langchain/langchain/chat_models/litellm.py +++ b/libs/langchain/langchain/chat_models/litellm.py @@ -190,9 +190,6 @@ def _convert_message_to_dict(message: BaseMessage) -> dict: class ChatLiteLLM(BaseChatModel): """`LiteLLM` Chat models API. - To use you must have the google.generativeai Python package installed and - either: - 1. The ``GOOGLE_API_KEY``` environment variable set with your API key, or 2. Pass your API key using the google_api_key kwarg to the ChatGoogle constructor. @@ -206,7 +203,8 @@ class ChatLiteLLM(BaseChatModel): """ client: Any #: :meta private: - model_name: str = "gpt-3.5-turbo" + model: str = "gpt-3.5-turbo" + model_name: Optional[str] = None """Model name to use.""" openai_api_key: Optional[str] = None azure_api_key: Optional[str] = None @@ -217,8 +215,9 @@ class ChatLiteLLM(BaseChatModel): streaming: bool = False api_base: Optional[str] = None organization: Optional[str] = None + custom_llm_provider: Optional[str] = None request_timeout: Optional[Union[float, Tuple[float, float]]] = None - temperature: Optional[float] = None + temperature: Optional[float] = 1 model_kwargs: Dict[str, Any] = Field(default_factory=dict) """Run inference with this temperature. Must by in the closed interval [0.0, 1.0].""" @@ -238,8 +237,11 @@ class ChatLiteLLM(BaseChatModel): @property def _default_params(self) -> Dict[str, Any]: """Get the default parameters for calling OpenAI API.""" + set_model_value = self.model + if self.model_name is not None: + set_model_value = self.model_name return { - "model": self.model_name, + "model": set_model_value, "force_timeout": self.request_timeout, "max_tokens": self.max_tokens, "stream": self.streaming, @@ -251,10 +253,13 @@ class ChatLiteLLM(BaseChatModel): @property def _client_params(self) -> Dict[str, Any]: """Get the parameters used for the openai client.""" + set_model_value = self.model + if self.model_name is not None: + set_model_value = self.model_name self.client.api_base = self.api_base self.client.organization = self.organization creds: Dict[str, Any] = { - "model": self.model_name, + "model": set_model_value, "force_timeout": self.request_timeout, } return {**self._default_params, **creds} @@ -347,7 +352,10 @@ class ChatLiteLLM(BaseChatModel): ) generations.append(gen) token_usage = response.get("usage", {}) - llm_output = {"token_usage": token_usage, "model_name": self.model_name} + set_model_value = self.model + if self.model_name is not None: + set_model_value = self.model_name + llm_output = {"token_usage": token_usage, "model": set_model_value} return ChatResult(generations=generations, llm_output=llm_output) def _create_message_dicts( @@ -437,8 +445,11 @@ class ChatLiteLLM(BaseChatModel): @property def _identifying_params(self) -> Dict[str, Any]: """Get the identifying parameters.""" + set_model_value = self.model + if self.model_name is not None: + set_model_value = self.model_name return { - "model_name": self.model_name, + "model": set_model_value, "temperature": self.temperature, "top_p": self.top_p, "top_k": self.top_k,