community: add standard chat model params to Ollama (#22446)
@@ -6,7 +6,7 @@ from langchain_core.callbacks import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
-from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_core.language_models.chat_models import BaseChatModel, LangSmithParams
 from langchain_core.messages import (
     AIMessage,
     AIMessageChunk,
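
For context, `LangSmithParams` is a `TypedDict` of standard tracing fields defined in `langchain_core.language_models.chat_models`. A minimal sketch of the shape this diff relies on (not the canonical definition; consult langchain_core for the authoritative one):

# Sketch of the LangSmithParams shape assumed by this diff.
from typing import List, Optional
from typing_extensions import TypedDict

class LangSmithParams(TypedDict, total=False):
    ls_provider: str                 # e.g. "ollama"
    ls_model_name: str               # e.g. "llama3"
    ls_model_type: str               # "chat" for chat models
    ls_temperature: Optional[float]
    ls_max_tokens: Optional[int]
    ls_stop: Optional[List[str]]

Because `total=False`, keys like `ls_max_tokens` and `ls_stop` may be omitted entirely, which is what the conditional assignments below take advantage of.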
@@ -69,6 +69,23 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
         """Return whether this model can be serialized by Langchain."""
         return False
 
+    def _get_ls_params(
+        self, stop: Optional[List[str]] = None, **kwargs: Any
+    ) -> LangSmithParams:
+        """Get standard params for tracing."""
+        params = self._get_invocation_params(stop=stop, **kwargs)
+        ls_params = LangSmithParams(
+            ls_provider="ollama",
+            ls_model_name=self.model,
+            ls_model_type="chat",
+            ls_temperature=params.get("temperature", self.temperature),
+        )
+        if ls_max_tokens := params.get("num_predict", self.num_predict):
+            ls_params["ls_max_tokens"] = ls_max_tokens
+        if ls_stop := stop or params.get("stop", None) or self.stop:
+            ls_params["ls_stop"] = ls_stop
+        return ls_params
+
     @deprecated("0.0.3", alternative="_convert_messages_to_ollama_messages")
     def _format_message_as_text(self, message: BaseMessage) -> str:
         if isinstance(message, ChatMessage):
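
To illustrate what the new method produces, here is a hedged sketch (hypothetical parameter values; `_get_ls_params` is private and needs no running Ollama server, since it only reads local config):

from langchain_community.chat_models import ChatOllama

chat = ChatOllama(model="llama3", temperature=0.5, num_predict=256, stop=["\n\n"])
params = chat._get_ls_params()
# Expected shape, given the implementation above:
# {"ls_provider": "ollama", "ls_model_name": "llama3",
#  "ls_model_type": "chat", "ls_temperature": 0.5,
#  "ls_max_tokens": 256, "ls_stop": ["\n\n"]}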
libs/community/tests/unit_tests/chat_models/test_ollama.py (new file, 35 lines)
@@ -0,0 +1,35 @@
+from typing import List, Literal, Optional
+
+import pytest
+from langchain_core.pydantic_v1 import BaseModel, ValidationError
+
+from langchain_community.chat_models import ChatOllama
+
+
+def test_standard_params() -> None:
+    class ExpectedParams(BaseModel):
+        ls_provider: str
+        ls_model_name: str
+        ls_model_type: Literal["chat"]
+        ls_temperature: Optional[float]
+        ls_max_tokens: Optional[int]
+        ls_stop: Optional[List[str]]
+
+    model = ChatOllama(model="llama3")
+    ls_params = model._get_ls_params()
+    try:
+        ExpectedParams(**ls_params)
+    except ValidationError as e:
+        pytest.fail(f"Validation error: {e}")
+    assert ls_params["ls_model_name"] == "llama3"
+
+    # Test optional params
+    model = ChatOllama(num_predict=10, stop=["test"], temperature=0.33)
+    ls_params = model._get_ls_params()
+    try:
+        ExpectedParams(**ls_params)
+    except ValidationError as e:
+        pytest.fail(f"Validation error: {e}")
+    assert ls_params["ls_max_tokens"] == 10
+    assert ls_params["ls_stop"] == ["test"]
+    assert ls_params["ls_temperature"] == 0.33