mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-16 09:48:04 +00:00
add alias for model (#4553)
Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>
This commit is contained in:
parent
7642f2159c
commit
c9a362e482
@ -2,7 +2,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import List, Optional, Sequence
|
from typing import List, Optional, Sequence, Set
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
@ -68,3 +68,12 @@ class BaseLanguageModel(BaseModel, ABC):
|
|||||||
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
||||||
"""Get the number of tokens in the message."""
|
"""Get the number of tokens in the message."""
|
||||||
return sum([self.get_num_tokens(get_buffer_string([m])) for m in messages])
|
return sum([self.get_num_tokens(get_buffer_string([m])) for m in messages])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def all_required_field_names(cls) -> Set:
|
||||||
|
all_required_field_names = set()
|
||||||
|
for field in cls.__fields__.values():
|
||||||
|
all_required_field_names.add(field.name)
|
||||||
|
if field.has_alias:
|
||||||
|
all_required_field_names.add(field.alias)
|
||||||
|
return all_required_field_names
|
||||||
|
@ -112,7 +112,7 @@ class ChatOpenAI(BaseChatModel):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
client: Any #: :meta private:
|
client: Any #: :meta private:
|
||||||
model_name: str = "gpt-3.5-turbo"
|
model_name: str = Field(default="gpt-3.5-turbo", alias="model")
|
||||||
"""Model name to use."""
|
"""Model name to use."""
|
||||||
temperature: float = 0.7
|
temperature: float = 0.7
|
||||||
"""What sampling temperature to use."""
|
"""What sampling temperature to use."""
|
||||||
@ -138,12 +138,12 @@ class ChatOpenAI(BaseChatModel):
|
|||||||
"""Configuration for this pydantic object."""
|
"""Configuration for this pydantic object."""
|
||||||
|
|
||||||
extra = Extra.ignore
|
extra = Extra.ignore
|
||||||
|
allow_population_by_field_name = True
|
||||||
|
|
||||||
@root_validator(pre=True)
|
@root_validator(pre=True)
|
||||||
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
"""Build extra kwargs from additional params that were passed in."""
|
"""Build extra kwargs from additional params that were passed in."""
|
||||||
all_required_field_names = {field.alias for field in cls.__fields__.values()}
|
all_required_field_names = cls.all_required_field_names()
|
||||||
|
|
||||||
extra = values.get("model_kwargs", {})
|
extra = values.get("model_kwargs", {})
|
||||||
for field_name in list(values):
|
for field_name in list(values):
|
||||||
if field_name in extra:
|
if field_name in extra:
|
||||||
@ -156,8 +156,7 @@ class ChatOpenAI(BaseChatModel):
|
|||||||
)
|
)
|
||||||
extra[field_name] = values.pop(field_name)
|
extra[field_name] = values.pop(field_name)
|
||||||
|
|
||||||
disallowed_model_kwargs = all_required_field_names | {"model"}
|
invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
|
||||||
invalid_model_kwargs = disallowed_model_kwargs.intersection(extra.keys())
|
|
||||||
if invalid_model_kwargs:
|
if invalid_model_kwargs:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Parameters {invalid_model_kwargs} should be specified explicitly. "
|
f"Parameters {invalid_model_kwargs} should be specified explicitly. "
|
||||||
|
@ -124,7 +124,7 @@ class BaseOpenAI(BaseLLM):
|
|||||||
"""Wrapper around OpenAI large language models."""
|
"""Wrapper around OpenAI large language models."""
|
||||||
|
|
||||||
client: Any #: :meta private:
|
client: Any #: :meta private:
|
||||||
model_name: str = "text-davinci-003"
|
model_name: str = Field("text-davinci-003", alias="model")
|
||||||
"""Model name to use."""
|
"""Model name to use."""
|
||||||
temperature: float = 0.7
|
temperature: float = 0.7
|
||||||
"""What sampling temperature to use."""
|
"""What sampling temperature to use."""
|
||||||
@ -178,12 +178,12 @@ class BaseOpenAI(BaseLLM):
|
|||||||
"""Configuration for this pydantic object."""
|
"""Configuration for this pydantic object."""
|
||||||
|
|
||||||
extra = Extra.ignore
|
extra = Extra.ignore
|
||||||
|
allow_population_by_field_name = True
|
||||||
|
|
||||||
@root_validator(pre=True)
|
@root_validator(pre=True)
|
||||||
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
"""Build extra kwargs from additional params that were passed in."""
|
"""Build extra kwargs from additional params that were passed in."""
|
||||||
all_required_field_names = {field.alias for field in cls.__fields__.values()}
|
all_required_field_names = cls.all_required_field_names()
|
||||||
|
|
||||||
extra = values.get("model_kwargs", {})
|
extra = values.get("model_kwargs", {})
|
||||||
for field_name in list(values):
|
for field_name in list(values):
|
||||||
if field_name in extra:
|
if field_name in extra:
|
||||||
@ -196,8 +196,7 @@ class BaseOpenAI(BaseLLM):
|
|||||||
)
|
)
|
||||||
extra[field_name] = values.pop(field_name)
|
extra[field_name] = values.pop(field_name)
|
||||||
|
|
||||||
disallowed_model_kwargs = all_required_field_names | {"model"}
|
invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
|
||||||
invalid_model_kwargs = disallowed_model_kwargs.intersection(extra.keys())
|
|
||||||
if invalid_model_kwargs:
|
if invalid_model_kwargs:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Parameters {invalid_model_kwargs} should be specified explicitly. "
|
f"Parameters {invalid_model_kwargs} should be specified explicitly. "
|
||||||
|
@ -25,6 +25,14 @@ def test_chat_openai() -> None:
|
|||||||
assert isinstance(response.content, str)
|
assert isinstance(response.content, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_chat_openai_model() -> None:
|
||||||
|
"""Test ChatOpenAI wrapper handles model_name."""
|
||||||
|
chat = ChatOpenAI(model="foo")
|
||||||
|
assert chat.model_name == "foo"
|
||||||
|
chat = ChatOpenAI(model_name="bar")
|
||||||
|
assert chat.model_name == "bar"
|
||||||
|
|
||||||
|
|
||||||
def test_chat_openai_system_message() -> None:
|
def test_chat_openai_system_message() -> None:
|
||||||
"""Test ChatOpenAI wrapper with system message."""
|
"""Test ChatOpenAI wrapper with system message."""
|
||||||
chat = ChatOpenAI(max_tokens=10)
|
chat = ChatOpenAI(max_tokens=10)
|
||||||
|
@ -19,6 +19,13 @@ def test_openai_call() -> None:
|
|||||||
assert isinstance(output, str)
|
assert isinstance(output, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_openai_model_param() -> None:
|
||||||
|
llm = OpenAI(model="foo")
|
||||||
|
assert llm.model_name == "foo"
|
||||||
|
llm = OpenAI(model_name="foo")
|
||||||
|
assert llm.model_name == "foo"
|
||||||
|
|
||||||
|
|
||||||
def test_openai_extra_kwargs() -> None:
|
def test_openai_extra_kwargs() -> None:
|
||||||
"""Test extra kwargs to openai."""
|
"""Test extra kwargs to openai."""
|
||||||
# Check that foo is saved in extra_kwargs.
|
# Check that foo is saved in extra_kwargs.
|
||||||
|
Loading…
Reference in New Issue
Block a user