From 51c6899850df421bb435cafcfcf90f13c67fa69c Mon Sep 17 00:00:00 2001 From: ccurme Date: Wed, 4 Sep 2024 13:41:40 -0400 Subject: [PATCH] groq[major]: upgrade pydantic (#26036) --- .../groq/langchain_groq/chat_models.py | 84 ++++++++++--------- libs/partners/groq/poetry.lock | 11 +-- libs/partners/groq/pyproject.toml | 3 +- libs/partners/groq/scripts/check_pydantic.sh | 27 ------ .../integration_tests/test_chat_models.py | 2 +- .../groq/tests/unit_tests/fake/callbacks.py | 8 +- 6 files changed, 55 insertions(+), 80 deletions(-) delete mode 100755 libs/partners/groq/scripts/check_pydantic.sh diff --git a/libs/partners/groq/langchain_groq/chat_models.py b/libs/partners/groq/langchain_groq/chat_models.py index 23aeae9f11b..05dc4724570 100644 --- a/libs/partners/groq/langchain_groq/chat_models.py +++ b/libs/partners/groq/langchain_groq/chat_models.py @@ -65,12 +65,6 @@ from langchain_core.output_parsers.openai_tools import ( parse_tool_call, ) from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult -from langchain_core.pydantic_v1 import ( - BaseModel, - Field, - SecretStr, - root_validator, -) from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough from langchain_core.tools import BaseTool from langchain_core.utils import ( @@ -83,6 +77,14 @@ from langchain_core.utils.function_calling import ( convert_to_openai_tool, ) from langchain_core.utils.pydantic import is_basemodel_subclass +from pydantic import ( + BaseModel, + ConfigDict, + Field, + SecretStr, + model_validator, +) +from typing_extensions import Self class ChatGroq(BaseChatModel): @@ -225,7 +227,7 @@ class ChatGroq(BaseChatModel): Tool calling: .. code-block:: python - from langchain_core.pydantic_v1 import BaseModel, Field + from pydantic import BaseModel, Field class GetWeather(BaseModel): '''Get the current weather in a given location''' @@ -256,7 +258,7 @@ class ChatGroq(BaseChatModel): from typing import Optional - from langchain_core.pydantic_v1 import BaseModel, Field + from pydantic import BaseModel, Field class Joke(BaseModel): '''Joke to tell user.''' @@ -343,13 +345,13 @@ class ChatGroq(BaseChatModel): """Optional httpx.AsyncClient. Only used for async invocations. Must specify http_client as well if you'd like a custom client for sync invocations.""" - class Config: - """Configuration for this pydantic object.""" + model_config = ConfigDict( + populate_by_name=True, + ) - allow_population_by_field_name = True - - @root_validator(pre=True) - def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]: + @model_validator(mode="before") + @classmethod + def build_extra(cls, values: Dict[str, Any]) -> Any: """Build extra kwargs from additional params that were passed in.""" all_required_field_names = get_pydantic_field_names(cls) extra = values.get("model_kwargs", {}) @@ -374,38 +376,38 @@ class ChatGroq(BaseChatModel): values["model_kwargs"] = extra return values - @root_validator(pre=False, skip_on_failure=True) - def validate_environment(cls, values: Dict) -> Dict: + @model_validator(mode="after") + def validate_environment(self) -> Self: """Validate that api key and python package exists in environment.""" - if values["n"] < 1: + if self.n < 1: raise ValueError("n must be at least 1.") - if values["n"] > 1 and values["streaming"]: + if self.n > 1 and self.streaming: raise ValueError("n must be 1 when streaming.") - if values["temperature"] == 0: - values["temperature"] = 1e-8 + if self.temperature == 0: + self.temperature = 1e-8 - client_params = { - "api_key": values["groq_api_key"].get_secret_value() - if values["groq_api_key"] + client_params: Dict[str, Any] = { + "api_key": self.groq_api_key.get_secret_value() + if self.groq_api_key else None, - "base_url": values["groq_api_base"], - "timeout": values["request_timeout"], - "max_retries": values["max_retries"], - "default_headers": values["default_headers"], - "default_query": values["default_query"], + "base_url": self.groq_api_base, + "timeout": self.request_timeout, + "max_retries": self.max_retries, + "default_headers": self.default_headers, + "default_query": self.default_query, } try: import groq - sync_specific = {"http_client": values["http_client"]} - if not values.get("client"): - values["client"] = groq.Groq( + sync_specific: Dict[str, Any] = {"http_client": self.http_client} + if not self.client: + self.client = groq.Groq( **client_params, **sync_specific ).chat.completions - if not values.get("async_client"): - async_specific = {"http_client": values["http_async_client"]} - values["async_client"] = groq.AsyncGroq( + if not self.async_client: + async_specific: Dict[str, Any] = {"http_client": self.http_async_client} + self.async_client = groq.AsyncGroq( **client_params, **async_specific ).chat.completions except ImportError: @@ -413,7 +415,7 @@ class ChatGroq(BaseChatModel): "Could not import groq python package. " "Please install it with `pip install groq`." ) - return values + return self # # Serializable class method overrides @@ -543,7 +545,7 @@ class ChatGroq(BaseChatModel): default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk for chunk in self.client.create(messages=message_dicts, **params): if not isinstance(chunk, dict): - chunk = chunk.dict() + chunk = chunk.model_dump() if len(chunk["choices"]) == 0: continue choice = chunk["choices"][0] @@ -617,7 +619,7 @@ class ChatGroq(BaseChatModel): messages=message_dicts, **params ): if not isinstance(chunk, dict): - chunk = chunk.dict() + chunk = chunk.model_dump() if len(chunk["choices"]) == 0: continue choice = chunk["choices"][0] @@ -662,7 +664,7 @@ class ChatGroq(BaseChatModel): def _create_chat_result(self, response: Union[dict, BaseModel]) -> ChatResult: generations = [] if not isinstance(response, dict): - response = response.dict() + response = response.model_dump() token_usage = response.get("usage", {}) for res in response["choices"]: message = _convert_dict_to_message(res["message"]) @@ -905,7 +907,7 @@ class ChatGroq(BaseChatModel): from typing import Optional from langchain_groq import ChatGroq - from langchain_core.pydantic_v1 import BaseModel, Field + from pydantic import BaseModel, Field class AnswerWithJustification(BaseModel): @@ -936,7 +938,7 @@ class ChatGroq(BaseChatModel): .. code-block:: python from langchain_groq import ChatGroq - from langchain_core.pydantic_v1 import BaseModel + from pydantic import BaseModel class AnswerWithJustification(BaseModel): @@ -1023,7 +1025,7 @@ class ChatGroq(BaseChatModel): .. code-block:: from langchain_groq import ChatGroq - from langchain_core.pydantic_v1 import BaseModel + from pydantic import BaseModel class AnswerWithJustification(BaseModel): answer: str diff --git a/libs/partners/groq/poetry.lock b/libs/partners/groq/poetry.lock index ffd832ac794..81ccce88856 100644 --- a/libs/partners/groq/poetry.lock +++ b/libs/partners/groq/poetry.lock @@ -11,9 +11,6 @@ files = [ {file = "annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89"}, ] -[package.dependencies] -typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.9\""} - [[package]] name = "anyio" version = "4.4.0" @@ -323,10 +320,10 @@ files = [ [[package]] name = "langchain-core" -version = "0.2.26" +version = "0.2.38" description = "Building applications with LLMs through composability" optional = false -python-versions = ">=3.8.1,<4.0" +python-versions = ">=3.9,<4.0" files = [] develop = true @@ -941,5 +938,5 @@ watchmedo = ["PyYAML (>=3.10)"] [metadata] lock-version = "2.0" -python-versions = ">=3.8.1,<4.0" -content-hash = "1acd2a13007daf276263c7036959c1a4c973d4fd8dba924105c8fdc035e02a88" +python-versions = ">=3.9,<4.0" +content-hash = "b38cdb706f0dbe01f2f5d9b263904695fd468476459360a0cfbde42c940ea1cb" diff --git a/libs/partners/groq/pyproject.toml b/libs/partners/groq/pyproject.toml index e7d262a3581..55fc74006d3 100644 --- a/libs/partners/groq/pyproject.toml +++ b/libs/partners/groq/pyproject.toml @@ -19,9 +19,10 @@ disallow_untyped_defs = "True" "Release Notes" = "https://github.com/langchain-ai/langchain/releases?q=tag%3A%22langchain-groq%3D%3D0%22&expanded=true" [tool.poetry.dependencies] -python = ">=3.8.1,<4.0" +python = ">=3.9,<4.0" langchain-core = "^0.2.26" groq = ">=0.4.1,<1" +pydantic = ">=2,<3" [tool.ruff.lint] select = [ "E", "F", "I", "W",] diff --git a/libs/partners/groq/scripts/check_pydantic.sh b/libs/partners/groq/scripts/check_pydantic.sh deleted file mode 100755 index 06b5bb81ae2..00000000000 --- a/libs/partners/groq/scripts/check_pydantic.sh +++ /dev/null @@ -1,27 +0,0 @@ -#!/bin/bash -# -# This script searches for lines starting with "import pydantic" or "from pydantic" -# in tracked files within a Git repository. -# -# Usage: ./scripts/check_pydantic.sh /path/to/repository - -# Check if a path argument is provided -if [ $# -ne 1 ]; then - echo "Usage: $0 /path/to/repository" - exit 1 -fi - -repository_path="$1" - -# Search for lines matching the pattern within the specified repository -result=$(git -C "$repository_path" grep -E '^import pydantic|^from pydantic') - -# Check if any matching lines were found -if [ -n "$result" ]; then - echo "ERROR: The following lines need to be updated:" - echo "$result" - echo "Please replace the code with an import from langchain_core.pydantic_v1." - echo "For example, replace 'from pydantic import BaseModel'" - echo "with 'from langchain_core.pydantic_v1 import BaseModel'" - exit 1 -fi diff --git a/libs/partners/groq/tests/integration_tests/test_chat_models.py b/libs/partners/groq/tests/integration_tests/test_chat_models.py index 2e5a9620b22..1b2cf05b542 100644 --- a/libs/partners/groq/tests/integration_tests/test_chat_models.py +++ b/libs/partners/groq/tests/integration_tests/test_chat_models.py @@ -13,8 +13,8 @@ from langchain_core.messages import ( SystemMessage, ) from langchain_core.outputs import ChatGeneration, LLMResult -from langchain_core.pydantic_v1 import BaseModel, Field from langchain_core.tools import tool +from pydantic import BaseModel, Field from langchain_groq import ChatGroq from tests.unit_tests.fake.callbacks import ( diff --git a/libs/partners/groq/tests/unit_tests/fake/callbacks.py b/libs/partners/groq/tests/unit_tests/fake/callbacks.py index 71a6dea0cef..34f825e16a5 100644 --- a/libs/partners/groq/tests/unit_tests/fake/callbacks.py +++ b/libs/partners/groq/tests/unit_tests/fake/callbacks.py @@ -6,7 +6,7 @@ from uuid import UUID from langchain_core.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler from langchain_core.messages import BaseMessage -from langchain_core.pydantic_v1 import BaseModel +from pydantic import BaseModel class BaseFakeCallbackHandler(BaseModel): @@ -256,7 +256,8 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin): ) -> Any: self.on_retriever_error_common() - def __deepcopy__(self, memo: dict) -> "FakeCallbackHandler": + # Overriding since BaseModel has __deepcopy__ method as well + def __deepcopy__(self, memo: dict) -> "FakeCallbackHandler": # type: ignore return self @@ -390,5 +391,6 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi ) -> None: self.on_text_common() - def __deepcopy__(self, memo: dict) -> "FakeAsyncCallbackHandler": + # Overriding since BaseModel has __deepcopy__ method as well + def __deepcopy__(self, memo: dict) -> "FakeAsyncCallbackHandler": # type: ignore return self