Bagatur/litellm model name (#9613)

Co-authored-by: ishaan-jaff <ishaanjaffer0324@gmail.com>
This commit is contained in:
Bagatur 2023-08-22 07:44:00 -07:00 committed by GitHub
parent 1720e99397
commit e99ef12cb1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -190,9 +190,6 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
class ChatLiteLLM(BaseChatModel):
"""`LiteLLM` Chat models API.
To use you must have the google.generativeai Python package installed and
either:
1. The ``GOOGLE_API_KEY``` environment variable set with your API key, or
2. Pass your API key using the google_api_key kwarg to the ChatGoogle
constructor.
@ -206,7 +203,8 @@ class ChatLiteLLM(BaseChatModel):
"""
client: Any #: :meta private:
model_name: str = "gpt-3.5-turbo"
model: str = "gpt-3.5-turbo"
model_name: Optional[str] = None
"""Model name to use."""
openai_api_key: Optional[str] = None
azure_api_key: Optional[str] = None
@ -217,8 +215,9 @@ class ChatLiteLLM(BaseChatModel):
streaming: bool = False
api_base: Optional[str] = None
organization: Optional[str] = None
custom_llm_provider: Optional[str] = None
request_timeout: Optional[Union[float, Tuple[float, float]]] = None
temperature: Optional[float] = None
temperature: Optional[float] = 1
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Run inference with this temperature. Must by in the closed
interval [0.0, 1.0]."""
@ -238,8 +237,11 @@ class ChatLiteLLM(BaseChatModel):
@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling OpenAI API."""
set_model_value = self.model
if self.model_name is not None:
set_model_value = self.model_name
return {
"model": self.model_name,
"model": set_model_value,
"force_timeout": self.request_timeout,
"max_tokens": self.max_tokens,
"stream": self.streaming,
@ -251,10 +253,13 @@ class ChatLiteLLM(BaseChatModel):
@property
def _client_params(self) -> Dict[str, Any]:
"""Get the parameters used for the openai client."""
set_model_value = self.model
if self.model_name is not None:
set_model_value = self.model_name
self.client.api_base = self.api_base
self.client.organization = self.organization
creds: Dict[str, Any] = {
"model": self.model_name,
"model": set_model_value,
"force_timeout": self.request_timeout,
}
return {**self._default_params, **creds}
@ -347,7 +352,10 @@ class ChatLiteLLM(BaseChatModel):
)
generations.append(gen)
token_usage = response.get("usage", {})
llm_output = {"token_usage": token_usage, "model_name": self.model_name}
set_model_value = self.model
if self.model_name is not None:
set_model_value = self.model_name
llm_output = {"token_usage": token_usage, "model": set_model_value}
return ChatResult(generations=generations, llm_output=llm_output)
def _create_message_dicts(
@ -437,8 +445,11 @@ class ChatLiteLLM(BaseChatModel):
@property
def _identifying_params(self) -> Dict[str, Any]:
"""Get the identifying parameters."""
set_model_value = self.model
if self.model_name is not None:
set_model_value = self.model_name
return {
"model_name": self.model_name,
"model": set_model_value,
"temperature": self.temperature,
"top_p": self.top_p,
"top_k": self.top_k,