cli[patch]: integration template nits (#14691)

Co-authored-by: Erick Friis <erick@langchain.dev>
Bagatur 2024-02-09 17:59:34 -08:00 committed by GitHub
parent 99540d3d75
commit 10c10f2dea
6 changed files with 51 additions and 56 deletions

Makefile

@@ -5,16 +5,11 @@ all: help

 # Define a variable for the test file path.
 TEST_FILE ?= tests/unit_tests/
-
-integration_tests: TEST_FILE = tests/integration_tests/
-
-test integration_tests:
-	poetry run pytest $(TEST_FILE)
-
-tests:
+integration_test integration_tests: TEST_FILE = tests/integration_tests/
+
+test tests integration_test integration_tests:
 	poetry run pytest $(TEST_FILE)

 ######################
 # LINTING AND FORMATTING
 ######################

@@ -32,7 +27,7 @@ lint lint_diff lint_package lint_tests:
 	poetry run ruff .
 	poetry run ruff format $(PYTHON_FILES) --diff
 	poetry run ruff --select I $(PYTHON_FILES)
-	mkdir $(MYPY_CACHE); poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)
+	mkdir -p $(MYPY_CACHE); poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)

 format format_diff:
 	poetry run ruff format $(PYTHON_FILES)
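The consolidated targets rely on a Make target-specific variable: `integration_test` and `integration_tests` override `TEST_FILE` to `tests/integration_tests/`, so all four test targets share the single pytest recipe while `make test` keeps the unit-test default. The `mkdir -p` change makes cache-directory creation idempotent, so repeated lint runs no longer fail once `$(MYPY_CACHE)` exists.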

__module_name__/chat_models.py

@@ -1,3 +1,4 @@
+"""__ModuleName__ chat models."""
 from typing import Any, AsyncIterator, Iterator, List, Optional

 from langchain_core.callbacks import (
@@ -5,7 +6,7 @@ from langchain_core.callbacks import (
     CallbackManagerForLLMRun,
 )
 from langchain_core.language_models.chat_models import BaseChatModel
-from langchain_core.messages import BaseMessage, BaseMessageChunk
+from langchain_core.messages import BaseMessage
 from langchain_core.outputs import ChatGenerationChunk, ChatResult
@@ -15,40 +16,19 @@ class Chat__ModuleName__(BaseChatModel):
     Example:
         .. code-block:: python

+            from langchain_core.messages import HumanMessage
             from __module_name__ import Chat__ModuleName__

             model = Chat__ModuleName__()
-    """
+            model.invoke([HumanMessage(content="Come up with 10 names for a song about parrots.")])
+    """  # noqa: E501

     @property
     def _llm_type(self) -> str:
         """Return type of chat model."""
         return "chat-__package_name_short__"

-    def _stream(
-        self,
-        messages: List[BaseMessage],
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[CallbackManagerForLLMRun] = None,
-        **kwargs: Any,
-    ) -> Iterator[ChatGenerationChunk]:
-        raise NotImplementedError
-
-    async def _astream(
-        self,
-        messages: List[BaseMessage],
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
-        **kwargs: Any,
-    ) -> AsyncIterator[ChatGenerationChunk]:
-        yield ChatGenerationChunk(
-            message=BaseMessageChunk(content="Yield chunks", type="ai"),
-        )
-        yield ChatGenerationChunk(
-            message=BaseMessageChunk(content=" like this!", type="ai"),
-        )
-
     def _generate(
         self,
         messages: List[BaseMessage],
@@ -58,6 +38,29 @@ class Chat__ModuleName__(BaseChatModel):
     ) -> ChatResult:
         raise NotImplementedError

+    # TODO: Implement if __model_name__ supports streaming. Otherwise delete method.
+    def _stream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[ChatGenerationChunk]:
+        raise NotImplementedError
+
+    # TODO: Implement if __model_name__ supports async streaming. Otherwise delete
+    # method.
+    async def _astream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[ChatGenerationChunk]:
+        raise NotImplementedError
+
+    # TODO: Implement if __model_name__ supports async generation. Otherwise delete
+    # method.
     async def _agenerate(
         self,
         messages: List[BaseMessage],
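The stubs are reordered so the one required override, `_generate`, comes first, and every optional method is now a `raise NotImplementedError` stub behind a TODO instead of yielding placeholder chunks. As a rough sketch of what a concrete integration might fill in (the `ChatParrotLink` class and its echo behavior are illustrative assumptions, not part of the template):

from typing import Any, Iterator, List, Optional

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult


class ChatParrotLink(BaseChatModel):
    """Hypothetical integration used only to illustrate the template's hooks."""

    @property
    def _llm_type(self) -> str:
        return "chat-parrot-link"

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        # A real integration would call its provider API here; this one echoes.
        text = str(messages[-1].content)
        return ChatResult(
            generations=[ChatGeneration(message=AIMessage(content=text))]
        )

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        # Streaming yields AIMessageChunk pieces, one per "token" here.
        for token in str(messages[-1].content).split():
            yield ChatGenerationChunk(message=AIMessageChunk(content=token + " "))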

__module_name__/llms.py

@@ -1,5 +1,4 @@
-import asyncio
-from functools import partial
+"""__ModuleName__ large language models."""
 from typing import (
     Any,
     AsyncIterator,
@@ -25,6 +24,7 @@ class __ModuleName__LLM(BaseLLM):
             from __module_name__ import __ModuleName__LLM

             model = __ModuleName__LLM()
+            model.invoke("Come up with 10 names for a song about parrots")
     """

     @property
@@ -41,6 +41,8 @@ class __ModuleName__LLM(BaseLLM):
     ) -> LLMResult:
         raise NotImplementedError

+    # TODO: Implement if __model_name__ supports async generation. Otherwise
+    # delete method.
     async def _agenerate(
         self,
         prompts: List[str],
@@ -48,11 +50,9 @@ class __ModuleName__LLM(BaseLLM):
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> LLMResult:
-        # Change implementation if integration natively supports async generation.
-        return await asyncio.get_running_loop().run_in_executor(
-            None, partial(self._generate, **kwargs), prompts, stop, run_manager
-        )
+        raise NotImplementedError

+    # TODO: Implement if __model_name__ supports streaming. Otherwise delete method.
     def _stream(
         self,
         prompt: str,
@@ -62,6 +62,8 @@ class __ModuleName__LLM(BaseLLM):
     ) -> Iterator[GenerationChunk]:
         raise NotImplementedError

+    # TODO: Implement if __model_name__ supports async streaming. Otherwise delete
+    # method.
     async def _astream(
         self,
         prompt: str,
@@ -69,5 +71,4 @@ class __ModuleName__LLM(BaseLLM):
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> AsyncIterator[GenerationChunk]:
-        yield GenerationChunk(text="Yield chunks")
-        yield GenerationChunk(text=" like this!")
+        raise NotImplementedError
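Dropping the `asyncio.get_running_loop().run_in_executor` shim is possible because, as of the langchain-core versions current at this commit, `BaseLLM` already falls back to running the sync `_generate` in an executor when `_agenerate` is not overridden. A minimal sketch of the one method a subclass must provide (the `ParrotLinkLLM` name and echo behavior are illustrative assumptions):

from typing import Any, List, Optional

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import BaseLLM
from langchain_core.outputs import Generation, LLMResult


class ParrotLinkLLM(BaseLLM):
    """Hypothetical integration used only to illustrate the required override."""

    @property
    def _llm_type(self) -> str:
        return "parrot-link"

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        # LLMResult expects one inner list of Generations per input prompt.
        return LLMResult(
            generations=[[Generation(text=f"You said: {p}")] for p in prompts]
        )

With only `_generate` defined, `model.invoke(...)`, `model.batch(...)`, and `await model.ainvoke(...)` should all work through the base-class plumbing.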

__module_name__/vectorstores.py

@@ -1,3 +1,4 @@
+"""__ModuleName__ vector stores."""
 from __future__ import annotations

 import asyncio
@@ -24,7 +25,7 @@ VST = TypeVar("VST", bound=VectorStore)
 class __ModuleName__VectorStore(VectorStore):
-    """Interface for vector store.
+    """__ModuleName__ vector store.

     Example:
         .. code-block:: python

pyproject.toml

@@ -12,25 +12,21 @@ license = "MIT"
 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0"
-langchain-core = ">=0.0.12"
+langchain-core = "^0.1"

 [tool.poetry.group.test]
 optional = true

 [tool.poetry.group.test.dependencies]
-pytest = "^7.3.0"
-freezegun = "^1.2.2"
-pytest-mock = "^3.10.0"
-syrupy = "^4.0.2"
-pytest-watcher = "^0.3.4"
-pytest-asyncio = "^0.21.1"
+pytest = "^7.4.3"
+pytest-asyncio = "^0.23.2"
 langchain-core = {path = "../../core", develop = true}

 [tool.poetry.group.codespell]
 optional = true

 [tool.poetry.group.codespell.dependencies]
-codespell = "^2.2.0"
+codespell = "^2.2.6"

 [tool.poetry.group.test_integration]
 optional = true
@@ -41,10 +37,10 @@ optional = true
 optional = true

 [tool.poetry.group.lint.dependencies]
-ruff = "^0.1.5"
+ruff = "^0.1.8"

 [tool.poetry.group.typing.dependencies]
-mypy = "^0.991"
+mypy = "^1.7.1"
 langchain-core = {path = "../../core", develop = true}

 [tool.poetry.group.dev]
@@ -87,8 +83,6 @@ addopts = "--snapshot-warn-unused --strict-markers --strict-config --durations=5"
 # Registering custom markers.
 # https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers
 markers = [
-    "requires: mark tests as requiring a specific library",
-    "asyncio: mark tests as requiring asyncio",
     "compile: mark placeholder test used to compile integration tests without running them",
 ]
 asyncio_mode = "auto"
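The slimmed-down test group keeps only what the template's tests import: pytest, pytest-asyncio, and the path dependency on the local langchain-core checkout. Async tests are collected via `asyncio_mode = "auto"`, which is why the per-test `asyncio` marker (and the unused `requires` marker) could be dropped from the pytest configuration.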

scripts/lint_imports.sh

@@ -5,9 +5,10 @@ set -eu
 # Initialize a variable to keep track of errors
 errors=0

-# make sure not importing from langchain or langchain_experimental
+# make sure not importing from langchain, langchain_experimental, or langchain_community
 git --no-pager grep '^from langchain\.' . && errors=$((errors+1))
 git --no-pager grep '^from langchain_experimental\.' . && errors=$((errors+1))
+git --no-pager grep '^from langchain_community\.' . && errors=$((errors+1))

 # Decide on an exit status based on the errors
 if [ "$errors" -gt 0 ]; then