diff --git a/libs/partners/groq/langchain_groq/chat_models.py b/libs/partners/groq/langchain_groq/chat_models.py index 86db80e8e22..8820897af1c 100644 --- a/libs/partners/groq/langchain_groq/chat_models.py +++ b/libs/partners/groq/langchain_groq/chat_models.py @@ -126,6 +126,9 @@ class ChatGroq(BaseChatModel): # [httpx documentation](https://www.python-httpx.org/api/#client) for more details. http_client: Union[Any, None] = None """Optional httpx.Client.""" + http_async_client: Union[Any, None] = None + """Optional httpx.AsyncClient. Only used for async invocations. Must specify + http_client as well if you'd like a custom client for sync invocations.""" class Config: """Configuration for this pydantic object.""" @@ -182,17 +185,20 @@ class ChatGroq(BaseChatModel): "max_retries": values["max_retries"], "default_headers": values["default_headers"], "default_query": values["default_query"], - "http_client": values["http_client"], } try: import groq + sync_specific = {"http_client": values["http_client"]} if not values.get("client"): - values["client"] = groq.Groq(**client_params).chat.completions + values["client"] = groq.Groq( + **client_params, **sync_specific + ).chat.completions if not values.get("async_client"): + async_specific = {"http_client": values["http_async_client"]} values["async_client"] = groq.AsyncGroq( - **client_params + **client_params, **async_specific ).chat.completions except ImportError: raise ImportError(