mirror of https://github.com/hwchase17/langchain.git
synced 2025-07-15 17:33:53 +00:00
add alias for model (#4553)
Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>
This commit is contained in:
parent 7642f2159c
commit c9a362e482
@@ -2,7 +2,7 @@
 from __future__ import annotations
 
 from abc import ABC, abstractmethod
-from typing import List, Optional, Sequence
+from typing import List, Optional, Sequence, Set
 
 from pydantic import BaseModel
 
@@ -68,3 +68,12 @@ class BaseLanguageModel(BaseModel, ABC):
     def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
         """Get the number of tokens in the message."""
         return sum([self.get_num_tokens(get_buffer_string([m])) for m in messages])
+
+    @classmethod
+    def all_required_field_names(cls) -> Set:
+        all_required_field_names = set()
+        for field in cls.__fields__.values():
+            all_required_field_names.add(field.name)
+            if field.has_alias:
+                all_required_field_names.add(field.alias)
+        return all_required_field_names
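The new classmethod walks pydantic's __fields__ and returns both the declared field names and any aliases, so callers no longer need to special-case aliased parameters. A minimal sketch of what it collects, assuming pydantic v1 semantics (the version in use here); the class name below is made up for illustration:

# Sketch (not part of the commit): what all_required_field_names() gathers
# on a pydantic v1 model that declares an alias.
from pydantic import BaseModel, Field

class FakeModel(BaseModel):
    model_name: str = Field(default="gpt-3.5-turbo", alias="model")
    temperature: float = 0.7

    @classmethod
    def all_required_field_names(cls) -> set:
        names = set()
        for field in cls.__fields__.values():
            names.add(field.name)       # declared attribute name, e.g. "model_name"
            if field.has_alias:
                names.add(field.alias)  # alias it can also be set by, e.g. "model"
        return names

print(FakeModel.all_required_field_names())
# -> {"model", "model_name", "temperature"}  (set order may vary)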
@@ -112,7 +112,7 @@ class ChatOpenAI(BaseChatModel):
     """
 
     client: Any  #: :meta private:
-    model_name: str = "gpt-3.5-turbo"
+    model_name: str = Field(default="gpt-3.5-turbo", alias="model")
     """Model name to use."""
     temperature: float = 0.7
     """What sampling temperature to use."""
@@ -138,12 +138,12 @@ class ChatOpenAI(BaseChatModel):
         """Configuration for this pydantic object."""
 
         extra = Extra.ignore
+        allow_population_by_field_name = True
 
     @root_validator(pre=True)
     def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
         """Build extra kwargs from additional params that were passed in."""
-        all_required_field_names = {field.alias for field in cls.__fields__.values()}
-
+        all_required_field_names = cls.all_required_field_names()
         extra = values.get("model_kwargs", {})
         for field_name in list(values):
             if field_name in extra:
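Together, alias="model" on model_name and allow_population_by_field_name = True in the pydantic Config let callers pass either keyword. A hedged sketch of that behavior on a stripped-down pydantic v1 model (the class name is illustrative, not the library class):

# Illustrative only: how the alias plus allow_population_by_field_name behave.
from pydantic import BaseModel, Field

class AliasDemo(BaseModel):
    model_name: str = Field(default="gpt-3.5-turbo", alias="model")

    class Config:
        allow_population_by_field_name = True

print(AliasDemo(model="foo").model_name)       # "foo"  (populated via the alias)
print(AliasDemo(model_name="bar").model_name)  # "bar"  (populated via the field name)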
@@ -156,8 +156,7 @@ class ChatOpenAI(BaseChatModel):
                 )
                 extra[field_name] = values.pop(field_name)
 
-        disallowed_model_kwargs = all_required_field_names | {"model"}
-        invalid_model_kwargs = disallowed_model_kwargs.intersection(extra.keys())
+        invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
         if invalid_model_kwargs:
             raise ValueError(
                 f"Parameters {invalid_model_kwargs} should be specified explicitly. "
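Because all_required_field_names() now already includes both the field names and their aliases (so "model" is covered), the separate union with {"model"} becomes redundant and is dropped. A simplified standalone sketch of the remaining check (not the library code; the helper name and message continuation here are illustrative):

def check_model_kwargs(all_required_field_names: set, model_kwargs: dict) -> None:
    # Declared field names / aliases may not also be passed inside model_kwargs.
    invalid_model_kwargs = all_required_field_names.intersection(model_kwargs.keys())
    if invalid_model_kwargs:
        raise ValueError(
            f"Parameters {invalid_model_kwargs} should be specified explicitly."
        )

check_model_kwargs({"model", "model_name", "temperature"}, {"top_p": 1})   # passes
# check_model_kwargs({"model", "model_name"}, {"model": "foo"})  # raises ValueError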
@@ -124,7 +124,7 @@ class BaseOpenAI(BaseLLM):
     """Wrapper around OpenAI large language models."""
 
     client: Any  #: :meta private:
-    model_name: str = "text-davinci-003"
+    model_name: str = Field("text-davinci-003", alias="model")
     """Model name to use."""
     temperature: float = 0.7
     """What sampling temperature to use."""
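One small difference from the chat wrapper: here the default is passed positionally. In pydantic v1, Field's first positional argument is the default, so this is equivalent to writing default="text-davinci-003"; a tiny illustrative example (field names made up):

# Illustrative: positional vs. keyword default in pydantic v1's Field.
from pydantic import BaseModel, Field

class PositionalDefault(BaseModel):
    a: str = Field("text-davinci-003")
    b: str = Field(default="text-davinci-003")

print(PositionalDefault())  # a='text-davinci-003' b='text-davinci-003'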
@@ -178,12 +178,12 @@ class BaseOpenAI(BaseLLM):
         """Configuration for this pydantic object."""
 
         extra = Extra.ignore
+        allow_population_by_field_name = True
 
     @root_validator(pre=True)
     def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
         """Build extra kwargs from additional params that were passed in."""
-        all_required_field_names = {field.alias for field in cls.__fields__.values()}
-
+        all_required_field_names = cls.all_required_field_names()
         extra = values.get("model_kwargs", {})
         for field_name in list(values):
             if field_name in extra:
@@ -196,8 +196,7 @@ class BaseOpenAI(BaseLLM):
                 )
                 extra[field_name] = values.pop(field_name)
 
-        disallowed_model_kwargs = all_required_field_names | {"model"}
-        invalid_model_kwargs = disallowed_model_kwargs.intersection(extra.keys())
+        invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
         if invalid_model_kwargs:
             raise ValueError(
                 f"Parameters {invalid_model_kwargs} should be specified explicitly. "
@@ -25,6 +25,14 @@ def test_chat_openai() -> None:
     assert isinstance(response.content, str)
 
 
+def test_chat_openai_model() -> None:
+    """Test ChatOpenAI wrapper handles model_name."""
+    chat = ChatOpenAI(model="foo")
+    assert chat.model_name == "foo"
+    chat = ChatOpenAI(model_name="bar")
+    assert chat.model_name == "bar"
+
+
 def test_chat_openai_system_message() -> None:
     """Test ChatOpenAI wrapper with system message."""
     chat = ChatOpenAI(max_tokens=10)
@@ -19,6 +19,13 @@ def test_openai_call() -> None:
     assert isinstance(output, str)
 
 
+def test_openai_model_param() -> None:
+    llm = OpenAI(model="foo")
+    assert llm.model_name == "foo"
+    llm = OpenAI(model_name="foo")
+    assert llm.model_name == "foo"
+
+
 def test_openai_extra_kwargs() -> None:
     """Test extra kwargs to openai."""
     # Check that foo is saved in extra_kwargs.