Mirror of https://github.com/hwchase17/langchain.git
Bagatur/litellm model name (#9613)
Co-authored-by: ishaan-jaff <ishaanjaffer0324@gmail.com>
parent 1720e99397
commit e99ef12cb1
@@ -190,9 +190,6 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
 class ChatLiteLLM(BaseChatModel):
     """`LiteLLM` Chat models API.
 
-    To use you must have the google.generativeai Python package installed and
-    either:
-
     1. The ``GOOGLE_API_KEY``` environment variable set with your API key, or
     2. Pass your API key using the google_api_key kwarg to the ChatGoogle
     constructor.
@@ -206,7 +203,8 @@ class ChatLiteLLM(BaseChatModel):
     """
 
     client: Any  #: :meta private:
-    model_name: str = "gpt-3.5-turbo"
+    model: str = "gpt-3.5-turbo"
+    model_name: Optional[str] = None
     """Model name to use."""
     openai_api_key: Optional[str] = None
     azure_api_key: Optional[str] = None
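Note: the hunk above splits the old single field into `model` (the new primary field, keeping the old default) and `model_name` (an optional override retained for backward compatibility). A minimal usage sketch, assuming this patched langchain and the litellm package are installed; the import path matches the langchain layout of this era:

    from langchain.chat_models import ChatLiteLLM

    # `model` is the new primary field and keeps the old default.
    chat = ChatLiteLLM(model="gpt-3.5-turbo")

    # `model_name` survives as an optional alias; when set, the
    # properties patched below resolve it in preference to `model`.
    legacy = ChatLiteLLM(model_name="claude-2")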
@@ -217,8 +215,9 @@ class ChatLiteLLM(BaseChatModel):
     streaming: bool = False
     api_base: Optional[str] = None
     organization: Optional[str] = None
     custom_llm_provider: Optional[str] = None
     request_timeout: Optional[Union[float, Tuple[float, float]]] = None
-    temperature: Optional[float] = None
+    temperature: Optional[float] = 1
+    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
     """Run inference with this temperature. Must by in the closed
        interval [0.0, 1.0]."""
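This hunk also flips the temperature default from None to 1 and introduces a model_kwargs dict. A hedged sketch of the intended call pattern; whether model_kwargs is forwarded to litellm as extra request parameters is an assumption here, not something this diff shows:

    chat = ChatLiteLLM(
        model="gpt-3.5-turbo",
        temperature=0.2,  # default is now 1 rather than None
        # Assumption: provider-specific extras ride along untyped.
        model_kwargs={"user": "example-user"},
    )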
@@ -238,8 +237,11 @@ class ChatLiteLLM(BaseChatModel):
     @property
     def _default_params(self) -> Dict[str, Any]:
         """Get the default parameters for calling OpenAI API."""
+        set_model_value = self.model
+        if self.model_name is not None:
+            set_model_value = self.model_name
         return {
-            "model": self.model_name,
+            "model": set_model_value,
             "force_timeout": self.request_timeout,
             "max_tokens": self.max_tokens,
             "stream": self.streaming,
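The three added lines implement a simple fallback: prefer model_name when the caller set it, otherwise use model. The same pattern repeats verbatim in _client_params, _generate, and _identifying_params below. Extracted as a standalone function for illustration (a hypothetical helper, not part of the commit, which inlines the logic instead):

    from typing import Optional

    def resolve_model(model: str, model_name: Optional[str]) -> str:
        # Prefer the legacy `model_name` override when set; otherwise
        # fall back to the new primary `model` field.
        return model_name if model_name is not None else model

    assert resolve_model("gpt-3.5-turbo", None) == "gpt-3.5-turbo"
    assert resolve_model("gpt-3.5-turbo", "claude-2") == "claude-2"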
@@ -251,10 +253,13 @@ class ChatLiteLLM(BaseChatModel):
     @property
     def _client_params(self) -> Dict[str, Any]:
         """Get the parameters used for the openai client."""
+        set_model_value = self.model
+        if self.model_name is not None:
+            set_model_value = self.model_name
         self.client.api_base = self.api_base
         self.client.organization = self.organization
         creds: Dict[str, Any] = {
-            "model": self.model_name,
+            "model": set_model_value,
             "force_timeout": self.request_timeout,
         }
         return {**self._default_params, **creds}
@@ -347,7 +352,10 @@ class ChatLiteLLM(BaseChatModel):
             )
             generations.append(gen)
         token_usage = response.get("usage", {})
-        llm_output = {"token_usage": token_usage, "model_name": self.model_name}
+        set_model_value = self.model
+        if self.model_name is not None:
+            set_model_value = self.model_name
+        llm_output = {"token_usage": token_usage, "model": set_model_value}
         return ChatResult(generations=generations, llm_output=llm_output)
 
     def _create_message_dicts(
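Because the llm_output key changes from "model_name" to "model", callers that inspect generation metadata need to read the new key. A sketch of that inspection, assuming a ChatResult produced by the patched _generate (the direct private-method call is hypothetical, for illustration only):

    from langchain.schema import HumanMessage

    messages = [HumanMessage(content="Hello")]
    result = chat._generate(messages)  # `chat` as constructed above
    tokens = result.llm_output["token_usage"]
    model_used = result.llm_output["model"]  # key was "model_name" before this commit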
@@ -437,8 +445,11 @@ class ChatLiteLLM(BaseChatModel):
     @property
     def _identifying_params(self) -> Dict[str, Any]:
         """Get the identifying parameters."""
+        set_model_value = self.model
+        if self.model_name is not None:
+            set_model_value = self.model_name
         return {
-            "model_name": self.model_name,
+            "model": set_model_value,
             "temperature": self.temperature,
             "top_p": self.top_p,
             "top_k": self.top_k,
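_identifying_params feeds things like string representations and cache keys elsewhere in langchain, so it now reports the resolved name under "model" as well. A quick sketch of the fallback as seen through this property, under the same installation assumptions as above:

    chat = ChatLiteLLM(model="gpt-3.5-turbo", model_name="claude-2")
    print(chat._identifying_params["model"])  # -> "claude-2"

    chat = ChatLiteLLM(model="gpt-3.5-turbo")
    print(chat._identifying_params["model"])  # -> "gpt-3.5-turbo"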