langchain: add deepseek provider to init chat model (#29449)

This commit is contained in:
Erick Friis 2025-01-27 23:13:59 -08:00 committed by GitHub
parent dced0ed3fd
commit ecdc881328
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 29 additions and 6 deletions

View File

@ -217,7 +217,8 @@ jobs:
# Replace all dashes in the package name with underscores, # Replace all dashes in the package name with underscores,
# since that's how Python imports packages with dashes in the name. # since that's how Python imports packages with dashes in the name.
IMPORT_NAME="$(echo "$PKG_NAME" | sed s/-/_/g)" # also remove _official suffix
IMPORT_NAME="$(echo "$PKG_NAME" | sed s/-/_/g | sed s/_official//g)"
poetry run python -c "import $IMPORT_NAME; print(dir($IMPORT_NAME))" poetry run python -c "import $IMPORT_NAME; print(dir($IMPORT_NAME))"

View File

@ -416,6 +416,11 @@ def _init_chat_model_helper(
from langchain_google_vertexai.model_garden import ChatAnthropicVertex from langchain_google_vertexai.model_garden import ChatAnthropicVertex
return ChatAnthropicVertex(model=model, **kwargs) return ChatAnthropicVertex(model=model, **kwargs)
elif model_provider == "deepseek":
_check_pkg("langchain_deepseek", pkg_kebab="langchain-deepseek-official")
from langchain_deepseek import ChatDeepSeek
return ChatDeepSeek(model=model, **kwargs)
else: else:
supported = ", ".join(_SUPPORTED_PROVIDERS) supported = ", ".join(_SUPPORTED_PROVIDERS)
raise ValueError( raise ValueError(
@ -440,6 +445,7 @@ _SUPPORTED_PROVIDERS = {
"bedrock", "bedrock",
"bedrock_converse", "bedrock_converse",
"google_anthropic_vertex", "google_anthropic_vertex",
"deepseek",
} }
@ -480,12 +486,11 @@ def _parse_model(model: str, model_provider: Optional[str]) -> Tuple[str, str]:
return model, model_provider return model, model_provider
def _check_pkg(pkg: str) -> None: def _check_pkg(pkg: str, *, pkg_kebab: Optional[str] = None) -> None:
if not util.find_spec(pkg): if not util.find_spec(pkg):
pkg_kebab = pkg.replace("_", "-") pkg_kebab = pkg_kebab if pkg_kebab is not None else pkg.replace("_", "-")
raise ImportError( raise ImportError(
f"Unable to import {pkg_kebab}. Please install with " f"Unable to import {pkg}. Please install with `pip install -U {pkg_kebab}`"
f"`pip install -U {pkg_kebab}`"
) )

View File

@ -1073,6 +1073,20 @@ files = [
[package.dependencies] [package.dependencies]
pytest = ">=6.2.5" pytest = ">=6.2.5"
[[package]]
name = "pytest-timeout"
version = "2.3.1"
description = "pytest plugin to abort hanging tests"
optional = false
python-versions = ">=3.7"
files = [
{file = "pytest-timeout-2.3.1.tar.gz", hash = "sha256:12397729125c6ecbdaca01035b9e5239d4db97352320af155b3f5de1ba5165d9"},
{file = "pytest_timeout-2.3.1-py3-none-any.whl", hash = "sha256:68188cb703edfc6a18fad98dc25a3c61e9f24d644b0b70f33af545219fc7813e"},
]
[package.dependencies]
pytest = ">=7.0.0"
[[package]] [[package]]
name = "pytest-watcher" name = "pytest-watcher"
version = "0.3.5" version = "0.3.5"
@ -1649,4 +1663,4 @@ cffi = ["cffi (>=1.11)"]
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = ">=3.9,<4.0" python-versions = ">=3.9,<4.0"
content-hash = "7c03a9a824974f7e2cd363082dc2f1163488749db15f73e7b6326508c07633ae" content-hash = "9c889754b3cb5e1044c6f7eeab3a5e4857179762c18a6d378fbb1ba72ab5b08b"

View File

@ -61,6 +61,7 @@ pytest-socket = "^0.7.0"
pytest-watcher = "^0.3.4" pytest-watcher = "^0.3.4"
langchain-tests = "^0.3.5" langchain-tests = "^0.3.5"
langchain-openai = { path = "../openai" } langchain-openai = { path = "../openai" }
pytest-timeout = "^2.3.1"
[tool.poetry.group.codespell.dependencies] [tool.poetry.group.codespell.dependencies]
codespell = "^2.2.6" codespell = "^2.2.6"

View File

@ -2,6 +2,7 @@
from typing import Type from typing import Type
import pytest
from langchain_tests.integration_tests import ChatModelIntegrationTests from langchain_tests.integration_tests import ChatModelIntegrationTests
from langchain_deepseek.chat_models import ChatDeepSeek from langchain_deepseek.chat_models import ChatDeepSeek
@ -21,6 +22,7 @@ class TestChatDeepSeek(ChatModelIntegrationTests):
} }
@pytest.mark.xfail(reason="Reasoning API is down")
def test_reasoning_content() -> None: def test_reasoning_content() -> None:
"""Test reasoning content.""" """Test reasoning content."""
chat_model = ChatDeepSeek(model="deepseek-reasoner") chat_model = ChatDeepSeek(model="deepseek-reasoner")