mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-27 06:18:05 +00:00
groq[major]: upgrade pydantic (#26036)
This commit is contained in:
@@ -65,12 +65,6 @@ from langchain_core.output_parsers.openai_tools import (
|
||||
parse_tool_call,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_core.pydantic_v1 import (
|
||||
BaseModel,
|
||||
Field,
|
||||
SecretStr,
|
||||
root_validator,
|
||||
)
|
||||
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils import (
|
||||
@@ -83,6 +77,14 @@ from langchain_core.utils.function_calling import (
|
||||
convert_to_openai_tool,
|
||||
)
|
||||
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
SecretStr,
|
||||
model_validator,
|
||||
)
|
||||
from typing_extensions import Self
|
||||
|
||||
|
||||
class ChatGroq(BaseChatModel):
|
||||
@@ -225,7 +227,7 @@ class ChatGroq(BaseChatModel):
|
||||
Tool calling:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
class GetWeather(BaseModel):
|
||||
'''Get the current weather in a given location'''
|
||||
@@ -256,7 +258,7 @@ class ChatGroq(BaseChatModel):
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
class Joke(BaseModel):
|
||||
'''Joke to tell user.'''
|
||||
@@ -343,13 +345,13 @@ class ChatGroq(BaseChatModel):
|
||||
"""Optional httpx.AsyncClient. Only used for async invocations. Must specify
|
||||
http_client as well if you'd like a custom client for sync invocations."""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
model_config = ConfigDict(
|
||||
populate_by_name=True,
|
||||
)
|
||||
|
||||
allow_population_by_field_name = True
|
||||
|
||||
@root_validator(pre=True)
|
||||
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def build_extra(cls, values: Dict[str, Any]) -> Any:
|
||||
"""Build extra kwargs from additional params that were passed in."""
|
||||
all_required_field_names = get_pydantic_field_names(cls)
|
||||
extra = values.get("model_kwargs", {})
|
||||
@@ -374,38 +376,38 @@ class ChatGroq(BaseChatModel):
|
||||
values["model_kwargs"] = extra
|
||||
return values
|
||||
|
||||
@root_validator(pre=False, skip_on_failure=True)
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
@model_validator(mode="after")
|
||||
def validate_environment(self) -> Self:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
if values["n"] < 1:
|
||||
if self.n < 1:
|
||||
raise ValueError("n must be at least 1.")
|
||||
if values["n"] > 1 and values["streaming"]:
|
||||
if self.n > 1 and self.streaming:
|
||||
raise ValueError("n must be 1 when streaming.")
|
||||
if values["temperature"] == 0:
|
||||
values["temperature"] = 1e-8
|
||||
if self.temperature == 0:
|
||||
self.temperature = 1e-8
|
||||
|
||||
client_params = {
|
||||
"api_key": values["groq_api_key"].get_secret_value()
|
||||
if values["groq_api_key"]
|
||||
client_params: Dict[str, Any] = {
|
||||
"api_key": self.groq_api_key.get_secret_value()
|
||||
if self.groq_api_key
|
||||
else None,
|
||||
"base_url": values["groq_api_base"],
|
||||
"timeout": values["request_timeout"],
|
||||
"max_retries": values["max_retries"],
|
||||
"default_headers": values["default_headers"],
|
||||
"default_query": values["default_query"],
|
||||
"base_url": self.groq_api_base,
|
||||
"timeout": self.request_timeout,
|
||||
"max_retries": self.max_retries,
|
||||
"default_headers": self.default_headers,
|
||||
"default_query": self.default_query,
|
||||
}
|
||||
|
||||
try:
|
||||
import groq
|
||||
|
||||
sync_specific = {"http_client": values["http_client"]}
|
||||
if not values.get("client"):
|
||||
values["client"] = groq.Groq(
|
||||
sync_specific: Dict[str, Any] = {"http_client": self.http_client}
|
||||
if not self.client:
|
||||
self.client = groq.Groq(
|
||||
**client_params, **sync_specific
|
||||
).chat.completions
|
||||
if not values.get("async_client"):
|
||||
async_specific = {"http_client": values["http_async_client"]}
|
||||
values["async_client"] = groq.AsyncGroq(
|
||||
if not self.async_client:
|
||||
async_specific: Dict[str, Any] = {"http_client": self.http_async_client}
|
||||
self.async_client = groq.AsyncGroq(
|
||||
**client_params, **async_specific
|
||||
).chat.completions
|
||||
except ImportError:
|
||||
@@ -413,7 +415,7 @@ class ChatGroq(BaseChatModel):
|
||||
"Could not import groq python package. "
|
||||
"Please install it with `pip install groq`."
|
||||
)
|
||||
return values
|
||||
return self
|
||||
|
||||
#
|
||||
# Serializable class method overrides
|
||||
@@ -543,7 +545,7 @@ class ChatGroq(BaseChatModel):
|
||||
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
|
||||
for chunk in self.client.create(messages=message_dicts, **params):
|
||||
if not isinstance(chunk, dict):
|
||||
chunk = chunk.dict()
|
||||
chunk = chunk.model_dump()
|
||||
if len(chunk["choices"]) == 0:
|
||||
continue
|
||||
choice = chunk["choices"][0]
|
||||
@@ -617,7 +619,7 @@ class ChatGroq(BaseChatModel):
|
||||
messages=message_dicts, **params
|
||||
):
|
||||
if not isinstance(chunk, dict):
|
||||
chunk = chunk.dict()
|
||||
chunk = chunk.model_dump()
|
||||
if len(chunk["choices"]) == 0:
|
||||
continue
|
||||
choice = chunk["choices"][0]
|
||||
@@ -662,7 +664,7 @@ class ChatGroq(BaseChatModel):
|
||||
def _create_chat_result(self, response: Union[dict, BaseModel]) -> ChatResult:
|
||||
generations = []
|
||||
if not isinstance(response, dict):
|
||||
response = response.dict()
|
||||
response = response.model_dump()
|
||||
token_usage = response.get("usage", {})
|
||||
for res in response["choices"]:
|
||||
message = _convert_dict_to_message(res["message"])
|
||||
@@ -905,7 +907,7 @@ class ChatGroq(BaseChatModel):
|
||||
from typing import Optional
|
||||
|
||||
from langchain_groq import ChatGroq
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class AnswerWithJustification(BaseModel):
|
||||
@@ -936,7 +938,7 @@ class ChatGroq(BaseChatModel):
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_groq import ChatGroq
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class AnswerWithJustification(BaseModel):
|
||||
@@ -1023,7 +1025,7 @@ class ChatGroq(BaseChatModel):
|
||||
.. code-block::
|
||||
|
||||
from langchain_groq import ChatGroq
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from pydantic import BaseModel
|
||||
|
||||
class AnswerWithJustification(BaseModel):
|
||||
answer: str
|
||||
|
11
libs/partners/groq/poetry.lock
generated
11
libs/partners/groq/poetry.lock
generated
@@ -11,9 +11,6 @@ files = [
|
||||
{file = "annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.9\""}
|
||||
|
||||
[[package]]
|
||||
name = "anyio"
|
||||
version = "4.4.0"
|
||||
@@ -323,10 +320,10 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "langchain-core"
|
||||
version = "0.2.26"
|
||||
version = "0.2.38"
|
||||
description = "Building applications with LLMs through composability"
|
||||
optional = false
|
||||
python-versions = ">=3.8.1,<4.0"
|
||||
python-versions = ">=3.9,<4.0"
|
||||
files = []
|
||||
develop = true
|
||||
|
||||
@@ -941,5 +938,5 @@ watchmedo = ["PyYAML (>=3.10)"]
|
||||
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.8.1,<4.0"
|
||||
content-hash = "1acd2a13007daf276263c7036959c1a4c973d4fd8dba924105c8fdc035e02a88"
|
||||
python-versions = ">=3.9,<4.0"
|
||||
content-hash = "b38cdb706f0dbe01f2f5d9b263904695fd468476459360a0cfbde42c940ea1cb"
|
||||
|
@@ -19,9 +19,10 @@ disallow_untyped_defs = "True"
|
||||
"Release Notes" = "https://github.com/langchain-ai/langchain/releases?q=tag%3A%22langchain-groq%3D%3D0%22&expanded=true"
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.8.1,<4.0"
|
||||
python = ">=3.9,<4.0"
|
||||
langchain-core = "^0.2.26"
|
||||
groq = ">=0.4.1,<1"
|
||||
pydantic = ">=2,<3"
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = [ "E", "F", "I", "W",]
|
||||
|
@@ -1,27 +0,0 @@
|
||||
#!/bin/bash
|
||||
#
|
||||
# This script searches for lines starting with "import pydantic" or "from pydantic"
|
||||
# in tracked files within a Git repository.
|
||||
#
|
||||
# Usage: ./scripts/check_pydantic.sh /path/to/repository
|
||||
|
||||
# Check if a path argument is provided
|
||||
if [ $# -ne 1 ]; then
|
||||
echo "Usage: $0 /path/to/repository"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
repository_path="$1"
|
||||
|
||||
# Search for lines matching the pattern within the specified repository
|
||||
result=$(git -C "$repository_path" grep -E '^import pydantic|^from pydantic')
|
||||
|
||||
# Check if any matching lines were found
|
||||
if [ -n "$result" ]; then
|
||||
echo "ERROR: The following lines need to be updated:"
|
||||
echo "$result"
|
||||
echo "Please replace the code with an import from langchain_core.pydantic_v1."
|
||||
echo "For example, replace 'from pydantic import BaseModel'"
|
||||
echo "with 'from langchain_core.pydantic_v1 import BaseModel'"
|
||||
exit 1
|
||||
fi
|
@@ -13,8 +13,8 @@ from langchain_core.messages import (
|
||||
SystemMessage,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, LLMResult
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
from langchain_core.tools import tool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from langchain_groq import ChatGroq
|
||||
from tests.unit_tests.fake.callbacks import (
|
||||
|
@@ -6,7 +6,7 @@ from uuid import UUID
|
||||
|
||||
from langchain_core.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class BaseFakeCallbackHandler(BaseModel):
|
||||
@@ -256,7 +256,8 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
||||
) -> Any:
|
||||
self.on_retriever_error_common()
|
||||
|
||||
def __deepcopy__(self, memo: dict) -> "FakeCallbackHandler":
|
||||
# Overriding since BaseModel has __deepcopy__ method as well
|
||||
def __deepcopy__(self, memo: dict) -> "FakeCallbackHandler": # type: ignore
|
||||
return self
|
||||
|
||||
|
||||
@@ -390,5 +391,6 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
|
||||
) -> None:
|
||||
self.on_text_common()
|
||||
|
||||
def __deepcopy__(self, memo: dict) -> "FakeAsyncCallbackHandler":
|
||||
# Overriding since BaseModel has __deepcopy__ method as well
|
||||
def __deepcopy__(self, memo: dict) -> "FakeAsyncCallbackHandler": # type: ignore
|
||||
return self
|
||||
|
Reference in New Issue
Block a user