mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-28 15:00:23 +00:00
groq[major]: upgrade pydantic (#26036)
This commit is contained in:
@@ -65,12 +65,6 @@ from langchain_core.output_parsers.openai_tools import (
|
|||||||
parse_tool_call,
|
parse_tool_call,
|
||||||
)
|
)
|
||||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||||
from langchain_core.pydantic_v1 import (
|
|
||||||
BaseModel,
|
|
||||||
Field,
|
|
||||||
SecretStr,
|
|
||||||
root_validator,
|
|
||||||
)
|
|
||||||
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
|
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
|
||||||
from langchain_core.tools import BaseTool
|
from langchain_core.tools import BaseTool
|
||||||
from langchain_core.utils import (
|
from langchain_core.utils import (
|
||||||
@@ -83,6 +77,14 @@ from langchain_core.utils.function_calling import (
|
|||||||
convert_to_openai_tool,
|
convert_to_openai_tool,
|
||||||
)
|
)
|
||||||
from langchain_core.utils.pydantic import is_basemodel_subclass
|
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||||
|
from pydantic import (
|
||||||
|
BaseModel,
|
||||||
|
ConfigDict,
|
||||||
|
Field,
|
||||||
|
SecretStr,
|
||||||
|
model_validator,
|
||||||
|
)
|
||||||
|
from typing_extensions import Self
|
||||||
|
|
||||||
|
|
||||||
class ChatGroq(BaseChatModel):
|
class ChatGroq(BaseChatModel):
|
||||||
@@ -225,7 +227,7 @@ class ChatGroq(BaseChatModel):
|
|||||||
Tool calling:
|
Tool calling:
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
class GetWeather(BaseModel):
|
class GetWeather(BaseModel):
|
||||||
'''Get the current weather in a given location'''
|
'''Get the current weather in a given location'''
|
||||||
@@ -256,7 +258,7 @@ class ChatGroq(BaseChatModel):
|
|||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
class Joke(BaseModel):
|
class Joke(BaseModel):
|
||||||
'''Joke to tell user.'''
|
'''Joke to tell user.'''
|
||||||
@@ -343,13 +345,13 @@ class ChatGroq(BaseChatModel):
|
|||||||
"""Optional httpx.AsyncClient. Only used for async invocations. Must specify
|
"""Optional httpx.AsyncClient. Only used for async invocations. Must specify
|
||||||
http_client as well if you'd like a custom client for sync invocations."""
|
http_client as well if you'd like a custom client for sync invocations."""
|
||||||
|
|
||||||
class Config:
|
model_config = ConfigDict(
|
||||||
"""Configuration for this pydantic object."""
|
populate_by_name=True,
|
||||||
|
)
|
||||||
|
|
||||||
allow_population_by_field_name = True
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
@root_validator(pre=True)
|
def build_extra(cls, values: Dict[str, Any]) -> Any:
|
||||||
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
|
||||||
"""Build extra kwargs from additional params that were passed in."""
|
"""Build extra kwargs from additional params that were passed in."""
|
||||||
all_required_field_names = get_pydantic_field_names(cls)
|
all_required_field_names = get_pydantic_field_names(cls)
|
||||||
extra = values.get("model_kwargs", {})
|
extra = values.get("model_kwargs", {})
|
||||||
@@ -374,38 +376,38 @@ class ChatGroq(BaseChatModel):
|
|||||||
values["model_kwargs"] = extra
|
values["model_kwargs"] = extra
|
||||||
return values
|
return values
|
||||||
|
|
||||||
@root_validator(pre=False, skip_on_failure=True)
|
@model_validator(mode="after")
|
||||||
def validate_environment(cls, values: Dict) -> Dict:
|
def validate_environment(self) -> Self:
|
||||||
"""Validate that api key and python package exists in environment."""
|
"""Validate that api key and python package exists in environment."""
|
||||||
if values["n"] < 1:
|
if self.n < 1:
|
||||||
raise ValueError("n must be at least 1.")
|
raise ValueError("n must be at least 1.")
|
||||||
if values["n"] > 1 and values["streaming"]:
|
if self.n > 1 and self.streaming:
|
||||||
raise ValueError("n must be 1 when streaming.")
|
raise ValueError("n must be 1 when streaming.")
|
||||||
if values["temperature"] == 0:
|
if self.temperature == 0:
|
||||||
values["temperature"] = 1e-8
|
self.temperature = 1e-8
|
||||||
|
|
||||||
client_params = {
|
client_params: Dict[str, Any] = {
|
||||||
"api_key": values["groq_api_key"].get_secret_value()
|
"api_key": self.groq_api_key.get_secret_value()
|
||||||
if values["groq_api_key"]
|
if self.groq_api_key
|
||||||
else None,
|
else None,
|
||||||
"base_url": values["groq_api_base"],
|
"base_url": self.groq_api_base,
|
||||||
"timeout": values["request_timeout"],
|
"timeout": self.request_timeout,
|
||||||
"max_retries": values["max_retries"],
|
"max_retries": self.max_retries,
|
||||||
"default_headers": values["default_headers"],
|
"default_headers": self.default_headers,
|
||||||
"default_query": values["default_query"],
|
"default_query": self.default_query,
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import groq
|
import groq
|
||||||
|
|
||||||
sync_specific = {"http_client": values["http_client"]}
|
sync_specific: Dict[str, Any] = {"http_client": self.http_client}
|
||||||
if not values.get("client"):
|
if not self.client:
|
||||||
values["client"] = groq.Groq(
|
self.client = groq.Groq(
|
||||||
**client_params, **sync_specific
|
**client_params, **sync_specific
|
||||||
).chat.completions
|
).chat.completions
|
||||||
if not values.get("async_client"):
|
if not self.async_client:
|
||||||
async_specific = {"http_client": values["http_async_client"]}
|
async_specific: Dict[str, Any] = {"http_client": self.http_async_client}
|
||||||
values["async_client"] = groq.AsyncGroq(
|
self.async_client = groq.AsyncGroq(
|
||||||
**client_params, **async_specific
|
**client_params, **async_specific
|
||||||
).chat.completions
|
).chat.completions
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@@ -413,7 +415,7 @@ class ChatGroq(BaseChatModel):
|
|||||||
"Could not import groq python package. "
|
"Could not import groq python package. "
|
||||||
"Please install it with `pip install groq`."
|
"Please install it with `pip install groq`."
|
||||||
)
|
)
|
||||||
return values
|
return self
|
||||||
|
|
||||||
#
|
#
|
||||||
# Serializable class method overrides
|
# Serializable class method overrides
|
||||||
@@ -543,7 +545,7 @@ class ChatGroq(BaseChatModel):
|
|||||||
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
|
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
|
||||||
for chunk in self.client.create(messages=message_dicts, **params):
|
for chunk in self.client.create(messages=message_dicts, **params):
|
||||||
if not isinstance(chunk, dict):
|
if not isinstance(chunk, dict):
|
||||||
chunk = chunk.dict()
|
chunk = chunk.model_dump()
|
||||||
if len(chunk["choices"]) == 0:
|
if len(chunk["choices"]) == 0:
|
||||||
continue
|
continue
|
||||||
choice = chunk["choices"][0]
|
choice = chunk["choices"][0]
|
||||||
@@ -617,7 +619,7 @@ class ChatGroq(BaseChatModel):
|
|||||||
messages=message_dicts, **params
|
messages=message_dicts, **params
|
||||||
):
|
):
|
||||||
if not isinstance(chunk, dict):
|
if not isinstance(chunk, dict):
|
||||||
chunk = chunk.dict()
|
chunk = chunk.model_dump()
|
||||||
if len(chunk["choices"]) == 0:
|
if len(chunk["choices"]) == 0:
|
||||||
continue
|
continue
|
||||||
choice = chunk["choices"][0]
|
choice = chunk["choices"][0]
|
||||||
@@ -662,7 +664,7 @@ class ChatGroq(BaseChatModel):
|
|||||||
def _create_chat_result(self, response: Union[dict, BaseModel]) -> ChatResult:
|
def _create_chat_result(self, response: Union[dict, BaseModel]) -> ChatResult:
|
||||||
generations = []
|
generations = []
|
||||||
if not isinstance(response, dict):
|
if not isinstance(response, dict):
|
||||||
response = response.dict()
|
response = response.model_dump()
|
||||||
token_usage = response.get("usage", {})
|
token_usage = response.get("usage", {})
|
||||||
for res in response["choices"]:
|
for res in response["choices"]:
|
||||||
message = _convert_dict_to_message(res["message"])
|
message = _convert_dict_to_message(res["message"])
|
||||||
@@ -905,7 +907,7 @@ class ChatGroq(BaseChatModel):
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from langchain_groq import ChatGroq
|
from langchain_groq import ChatGroq
|
||||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
class AnswerWithJustification(BaseModel):
|
class AnswerWithJustification(BaseModel):
|
||||||
@@ -936,7 +938,7 @@ class ChatGroq(BaseChatModel):
|
|||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
from langchain_groq import ChatGroq
|
from langchain_groq import ChatGroq
|
||||||
from langchain_core.pydantic_v1 import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
class AnswerWithJustification(BaseModel):
|
class AnswerWithJustification(BaseModel):
|
||||||
@@ -1023,7 +1025,7 @@ class ChatGroq(BaseChatModel):
|
|||||||
.. code-block::
|
.. code-block::
|
||||||
|
|
||||||
from langchain_groq import ChatGroq
|
from langchain_groq import ChatGroq
|
||||||
from langchain_core.pydantic_v1 import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
class AnswerWithJustification(BaseModel):
|
class AnswerWithJustification(BaseModel):
|
||||||
answer: str
|
answer: str
|
||||||
|
11
libs/partners/groq/poetry.lock
generated
11
libs/partners/groq/poetry.lock
generated
@@ -11,9 +11,6 @@ files = [
|
|||||||
{file = "annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89"},
|
{file = "annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
|
||||||
typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.9\""}
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "anyio"
|
name = "anyio"
|
||||||
version = "4.4.0"
|
version = "4.4.0"
|
||||||
@@ -323,10 +320,10 @@ files = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "langchain-core"
|
name = "langchain-core"
|
||||||
version = "0.2.26"
|
version = "0.2.38"
|
||||||
description = "Building applications with LLMs through composability"
|
description = "Building applications with LLMs through composability"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.8.1,<4.0"
|
python-versions = ">=3.9,<4.0"
|
||||||
files = []
|
files = []
|
||||||
develop = true
|
develop = true
|
||||||
|
|
||||||
@@ -941,5 +938,5 @@ watchmedo = ["PyYAML (>=3.10)"]
|
|||||||
|
|
||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = ">=3.8.1,<4.0"
|
python-versions = ">=3.9,<4.0"
|
||||||
content-hash = "1acd2a13007daf276263c7036959c1a4c973d4fd8dba924105c8fdc035e02a88"
|
content-hash = "b38cdb706f0dbe01f2f5d9b263904695fd468476459360a0cfbde42c940ea1cb"
|
||||||
|
@@ -19,9 +19,10 @@ disallow_untyped_defs = "True"
|
|||||||
"Release Notes" = "https://github.com/langchain-ai/langchain/releases?q=tag%3A%22langchain-groq%3D%3D0%22&expanded=true"
|
"Release Notes" = "https://github.com/langchain-ai/langchain/releases?q=tag%3A%22langchain-groq%3D%3D0%22&expanded=true"
|
||||||
|
|
||||||
[tool.poetry.dependencies]
|
[tool.poetry.dependencies]
|
||||||
python = ">=3.8.1,<4.0"
|
python = ">=3.9,<4.0"
|
||||||
langchain-core = "^0.2.26"
|
langchain-core = "^0.2.26"
|
||||||
groq = ">=0.4.1,<1"
|
groq = ">=0.4.1,<1"
|
||||||
|
pydantic = ">=2,<3"
|
||||||
|
|
||||||
[tool.ruff.lint]
|
[tool.ruff.lint]
|
||||||
select = [ "E", "F", "I", "W",]
|
select = [ "E", "F", "I", "W",]
|
||||||
|
@@ -1,27 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
#
|
|
||||||
# This script searches for lines starting with "import pydantic" or "from pydantic"
|
|
||||||
# in tracked files within a Git repository.
|
|
||||||
#
|
|
||||||
# Usage: ./scripts/check_pydantic.sh /path/to/repository
|
|
||||||
|
|
||||||
# Check if a path argument is provided
|
|
||||||
if [ $# -ne 1 ]; then
|
|
||||||
echo "Usage: $0 /path/to/repository"
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
|
|
||||||
repository_path="$1"
|
|
||||||
|
|
||||||
# Search for lines matching the pattern within the specified repository
|
|
||||||
result=$(git -C "$repository_path" grep -E '^import pydantic|^from pydantic')
|
|
||||||
|
|
||||||
# Check if any matching lines were found
|
|
||||||
if [ -n "$result" ]; then
|
|
||||||
echo "ERROR: The following lines need to be updated:"
|
|
||||||
echo "$result"
|
|
||||||
echo "Please replace the code with an import from langchain_core.pydantic_v1."
|
|
||||||
echo "For example, replace 'from pydantic import BaseModel'"
|
|
||||||
echo "with 'from langchain_core.pydantic_v1 import BaseModel'"
|
|
||||||
exit 1
|
|
||||||
fi
|
|
@@ -13,8 +13,8 @@ from langchain_core.messages import (
|
|||||||
SystemMessage,
|
SystemMessage,
|
||||||
)
|
)
|
||||||
from langchain_core.outputs import ChatGeneration, LLMResult
|
from langchain_core.outputs import ChatGeneration, LLMResult
|
||||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from langchain_groq import ChatGroq
|
from langchain_groq import ChatGroq
|
||||||
from tests.unit_tests.fake.callbacks import (
|
from tests.unit_tests.fake.callbacks import (
|
||||||
|
@@ -6,7 +6,7 @@ from uuid import UUID
|
|||||||
|
|
||||||
from langchain_core.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
|
from langchain_core.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
|
||||||
from langchain_core.messages import BaseMessage
|
from langchain_core.messages import BaseMessage
|
||||||
from langchain_core.pydantic_v1 import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
class BaseFakeCallbackHandler(BaseModel):
|
class BaseFakeCallbackHandler(BaseModel):
|
||||||
@@ -256,7 +256,8 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
|||||||
) -> Any:
|
) -> Any:
|
||||||
self.on_retriever_error_common()
|
self.on_retriever_error_common()
|
||||||
|
|
||||||
def __deepcopy__(self, memo: dict) -> "FakeCallbackHandler":
|
# Overriding since BaseModel has __deepcopy__ method as well
|
||||||
|
def __deepcopy__(self, memo: dict) -> "FakeCallbackHandler": # type: ignore
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
@@ -390,5 +391,6 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
|
|||||||
) -> None:
|
) -> None:
|
||||||
self.on_text_common()
|
self.on_text_common()
|
||||||
|
|
||||||
def __deepcopy__(self, memo: dict) -> "FakeAsyncCallbackHandler":
|
# Overriding since BaseModel has __deepcopy__ method as well
|
||||||
|
def __deepcopy__(self, memo: dict) -> "FakeAsyncCallbackHandler": # type: ignore
|
||||||
return self
|
return self
|
||||||
|
Reference in New Issue
Block a user