groq[major]: upgrade pydantic (#26036)

This commit is contained in:
ccurme
2024-09-04 13:41:40 -04:00
committed by GitHub
parent 163d6fe8ef
commit 51c6899850
6 changed files with 55 additions and 80 deletions

View File

@@ -65,12 +65,6 @@ from langchain_core.output_parsers.openai_tools import (
parse_tool_call,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import (
BaseModel,
Field,
SecretStr,
root_validator,
)
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
from langchain_core.tools import BaseTool
from langchain_core.utils import (
@@ -83,6 +77,14 @@ from langchain_core.utils.function_calling import (
convert_to_openai_tool,
)
from langchain_core.utils.pydantic import is_basemodel_subclass
from pydantic import (
BaseModel,
ConfigDict,
Field,
SecretStr,
model_validator,
)
from typing_extensions import Self
class ChatGroq(BaseChatModel):
@@ -225,7 +227,7 @@ class ChatGroq(BaseChatModel):
Tool calling:
.. code-block:: python
from langchain_core.pydantic_v1 import BaseModel, Field
from pydantic import BaseModel, Field
class GetWeather(BaseModel):
'''Get the current weather in a given location'''
@@ -256,7 +258,7 @@ class ChatGroq(BaseChatModel):
from typing import Optional
from langchain_core.pydantic_v1 import BaseModel, Field
from pydantic import BaseModel, Field
class Joke(BaseModel):
'''Joke to tell user.'''
@@ -343,13 +345,13 @@ class ChatGroq(BaseChatModel):
"""Optional httpx.AsyncClient. Only used for async invocations. Must specify
http_client as well if you'd like a custom client for sync invocations."""
class Config:
"""Configuration for this pydantic object."""
model_config = ConfigDict(
populate_by_name=True,
)
allow_population_by_field_name = True
@root_validator(pre=True)
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
@model_validator(mode="before")
@classmethod
def build_extra(cls, values: Dict[str, Any]) -> Any:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = get_pydantic_field_names(cls)
extra = values.get("model_kwargs", {})
@@ -374,38 +376,38 @@ class ChatGroq(BaseChatModel):
values["model_kwargs"] = extra
return values
@root_validator(pre=False, skip_on_failure=True)
def validate_environment(cls, values: Dict) -> Dict:
@model_validator(mode="after")
def validate_environment(self) -> Self:
"""Validate that api key and python package exists in environment."""
if values["n"] < 1:
if self.n < 1:
raise ValueError("n must be at least 1.")
if values["n"] > 1 and values["streaming"]:
if self.n > 1 and self.streaming:
raise ValueError("n must be 1 when streaming.")
if values["temperature"] == 0:
values["temperature"] = 1e-8
if self.temperature == 0:
self.temperature = 1e-8
client_params = {
"api_key": values["groq_api_key"].get_secret_value()
if values["groq_api_key"]
client_params: Dict[str, Any] = {
"api_key": self.groq_api_key.get_secret_value()
if self.groq_api_key
else None,
"base_url": values["groq_api_base"],
"timeout": values["request_timeout"],
"max_retries": values["max_retries"],
"default_headers": values["default_headers"],
"default_query": values["default_query"],
"base_url": self.groq_api_base,
"timeout": self.request_timeout,
"max_retries": self.max_retries,
"default_headers": self.default_headers,
"default_query": self.default_query,
}
try:
import groq
sync_specific = {"http_client": values["http_client"]}
if not values.get("client"):
values["client"] = groq.Groq(
sync_specific: Dict[str, Any] = {"http_client": self.http_client}
if not self.client:
self.client = groq.Groq(
**client_params, **sync_specific
).chat.completions
if not values.get("async_client"):
async_specific = {"http_client": values["http_async_client"]}
values["async_client"] = groq.AsyncGroq(
if not self.async_client:
async_specific: Dict[str, Any] = {"http_client": self.http_async_client}
self.async_client = groq.AsyncGroq(
**client_params, **async_specific
).chat.completions
except ImportError:
@@ -413,7 +415,7 @@ class ChatGroq(BaseChatModel):
"Could not import groq python package. "
"Please install it with `pip install groq`."
)
return values
return self
#
# Serializable class method overrides
@@ -543,7 +545,7 @@ class ChatGroq(BaseChatModel):
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
for chunk in self.client.create(messages=message_dicts, **params):
if not isinstance(chunk, dict):
chunk = chunk.dict()
chunk = chunk.model_dump()
if len(chunk["choices"]) == 0:
continue
choice = chunk["choices"][0]
@@ -617,7 +619,7 @@ class ChatGroq(BaseChatModel):
messages=message_dicts, **params
):
if not isinstance(chunk, dict):
chunk = chunk.dict()
chunk = chunk.model_dump()
if len(chunk["choices"]) == 0:
continue
choice = chunk["choices"][0]
@@ -662,7 +664,7 @@ class ChatGroq(BaseChatModel):
def _create_chat_result(self, response: Union[dict, BaseModel]) -> ChatResult:
generations = []
if not isinstance(response, dict):
response = response.dict()
response = response.model_dump()
token_usage = response.get("usage", {})
for res in response["choices"]:
message = _convert_dict_to_message(res["message"])
@@ -905,7 +907,7 @@ class ChatGroq(BaseChatModel):
from typing import Optional
from langchain_groq import ChatGroq
from langchain_core.pydantic_v1 import BaseModel, Field
from pydantic import BaseModel, Field
class AnswerWithJustification(BaseModel):
@@ -936,7 +938,7 @@ class ChatGroq(BaseChatModel):
.. code-block:: python
from langchain_groq import ChatGroq
from langchain_core.pydantic_v1 import BaseModel
from pydantic import BaseModel
class AnswerWithJustification(BaseModel):
@@ -1023,7 +1025,7 @@ class ChatGroq(BaseChatModel):
.. code-block::
from langchain_groq import ChatGroq
from langchain_core.pydantic_v1 import BaseModel
from pydantic import BaseModel
class AnswerWithJustification(BaseModel):
answer: str

View File

@@ -11,9 +11,6 @@ files = [
{file = "annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89"},
]
[package.dependencies]
typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.9\""}
[[package]]
name = "anyio"
version = "4.4.0"
@@ -323,10 +320,10 @@ files = [
[[package]]
name = "langchain-core"
version = "0.2.26"
version = "0.2.38"
description = "Building applications with LLMs through composability"
optional = false
python-versions = ">=3.8.1,<4.0"
python-versions = ">=3.9,<4.0"
files = []
develop = true
@@ -941,5 +938,5 @@ watchmedo = ["PyYAML (>=3.10)"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<4.0"
content-hash = "1acd2a13007daf276263c7036959c1a4c973d4fd8dba924105c8fdc035e02a88"
python-versions = ">=3.9,<4.0"
content-hash = "b38cdb706f0dbe01f2f5d9b263904695fd468476459360a0cfbde42c940ea1cb"

View File

@@ -19,9 +19,10 @@ disallow_untyped_defs = "True"
"Release Notes" = "https://github.com/langchain-ai/langchain/releases?q=tag%3A%22langchain-groq%3D%3D0%22&expanded=true"
[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
python = ">=3.9,<4.0"
langchain-core = "^0.2.26"
groq = ">=0.4.1,<1"
pydantic = ">=2,<3"
[tool.ruff.lint]
select = [ "E", "F", "I", "W",]

View File

@@ -1,27 +0,0 @@
#!/bin/bash
#
# This script searches for lines starting with "import pydantic" or "from pydantic"
# in tracked files within a Git repository.
#
# Usage: ./scripts/check_pydantic.sh /path/to/repository
# Check if a path argument is provided
if [ $# -ne 1 ]; then
echo "Usage: $0 /path/to/repository"
exit 1
fi
repository_path="$1"
# Search for lines matching the pattern within the specified repository
result=$(git -C "$repository_path" grep -E '^import pydantic|^from pydantic')
# Check if any matching lines were found
if [ -n "$result" ]; then
echo "ERROR: The following lines need to be updated:"
echo "$result"
echo "Please replace the code with an import from langchain_core.pydantic_v1."
echo "For example, replace 'from pydantic import BaseModel'"
echo "with 'from langchain_core.pydantic_v1 import BaseModel'"
exit 1
fi

View File

@@ -13,8 +13,8 @@ from langchain_core.messages import (
SystemMessage,
)
from langchain_core.outputs import ChatGeneration, LLMResult
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.tools import tool
from pydantic import BaseModel, Field
from langchain_groq import ChatGroq
from tests.unit_tests.fake.callbacks import (

View File

@@ -6,7 +6,7 @@ from uuid import UUID
from langchain_core.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
from langchain_core.messages import BaseMessage
from langchain_core.pydantic_v1 import BaseModel
from pydantic import BaseModel
class BaseFakeCallbackHandler(BaseModel):
@@ -256,7 +256,8 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any:
self.on_retriever_error_common()
def __deepcopy__(self, memo: dict) -> "FakeCallbackHandler":
# Overriding since BaseModel has __deepcopy__ method as well
def __deepcopy__(self, memo: dict) -> "FakeCallbackHandler": # type: ignore
return self
@@ -390,5 +391,6 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
) -> None:
self.on_text_common()
def __deepcopy__(self, memo: dict) -> "FakeAsyncCallbackHandler":
# Overriding since BaseModel has __deepcopy__ method as well
def __deepcopy__(self, memo: dict) -> "FakeAsyncCallbackHandler": # type: ignore
return self