mirror of https://github.com/hwchase17/langchain.git
synced 2025-07-03 19:57:51 +00:00
community: Ollama - Parameter structure to follow official documentation (#16035)
## Feature

Follow the parameter structure from the official Ollama documentation:
- Top-level parameters (e.g. `model`, `system`, `template`) are passed as top-level parameters.
- All other parameters are sent inside `options`, unless an `options` dict is provided explicitly.

## Tests

- Top-level parameters are handled as top-level parameters.
- Parameters that are not top-level parameters are handled as options.
- If `options` is provided, it is passed through as-is.
This commit is contained in:
parent 60d6a416e6
commit 86321a949f
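Before the diff, here is a minimal, self-contained sketch of the routing the commit message describes. The function name `build_request_params` and the trimmed-down `defaults` dict are illustrative stand-ins, not the actual `_OllamaCommon` code: keys present in the defaults (Ollama's top-level generate parameters) stay at the top level, an explicit `options` dict is passed through as-is, and every other keyword argument is folded into `options`.

```python
from typing import Any, Dict, List, Optional


def build_request_params(
    defaults: Dict[str, Any],
    stop: Optional[List[str]] = None,
    **kwargs: Any,
) -> Dict[str, Any]:
    """Hypothetical stand-in for the kwargs handling added by this commit.

    `defaults` plays the role of `_OllamaCommon._default_params`: its keys
    (model, format, system, template, options, ...) are the parameters that
    the Ollama endpoint accepts at the top level of the request body.
    """
    params = dict(defaults)

    # Top-level parameters passed as kwargs stay at the top level.
    for key in defaults:
        if key in kwargs:
            params[key] = kwargs[key]

    if "options" in kwargs:
        # An explicit options dict is passed through as-is.
        params["options"] = kwargs["options"]
    else:
        # Every kwarg that is not a top-level parameter becomes an option.
        params["options"] = {
            **params.get("options", {}),
            "stop": stop,
            **{k: v for k, v in kwargs.items() if k not in defaults},
        }
    return params


# Tiny illustration with a trimmed-down defaults dict:
defaults = {"model": "llama2", "system": None, "options": {"temperature": None}}
print(build_request_params(defaults, model="foo", temperature=0.8, unknown="x"))
# {'model': 'foo', 'system': None,
#  'options': {'temperature': 0.8, 'stop': None, 'unknown': 'x'}}
```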
Changes to the Ollama LLM implementation (`_OllamaCommon`):

@@ -190,8 +190,9 @@ class _OllamaCommon(BaseLanguageModel):
 
         params = self._default_params
 
-        if "model" in kwargs:
-            params["model"] = kwargs["model"]
+        for key in self._default_params:
+            if key in kwargs:
+                params[key] = kwargs[key]
 
         if "options" in kwargs:
             params["options"] = kwargs["options"]
@@ -199,7 +200,7 @@ class _OllamaCommon(BaseLanguageModel):
             params["options"] = {
                 **params["options"],
                 "stop": stop,
-                **kwargs,
+                **{k: v for k, v in kwargs.items() if k not in self._default_params},
             }
 
         if payload.get("messages"):
@@ -253,8 +254,9 @@ class _OllamaCommon(BaseLanguageModel):
 
         params = self._default_params
 
-        if "model" in kwargs:
-            params["model"] = kwargs["model"]
+        for key in self._default_params:
+            if key in kwargs:
+                params[key] = kwargs[key]
 
         if "options" in kwargs:
             params["options"] = kwargs["options"]
@@ -262,7 +264,7 @@ class _OllamaCommon(BaseLanguageModel):
             params["options"] = {
                 **params["options"],
                 "stop": stop,
-                **kwargs,
+                **{k: v for k, v in kwargs.items() if k not in self._default_params},
             }
 
         if payload.get("messages"):
Changes to the Ollama unit tests:

@@ -31,7 +31,7 @@ def test_pass_headers_if_provided(monkeypatch: MonkeyPatch) -> None:
         timeout=300,
     )
 
-    def mockPost(url, headers, json, stream, timeout):
+    def mock_post(url, headers, json, stream, timeout):
         assert url == "https://ollama-hostname:8000/api/generate/"
         assert headers == {
             "Content-Type": "application/json",
@@ -44,7 +44,7 @@ def test_pass_headers_if_provided(monkeypatch: MonkeyPatch) -> None:
 
         return mock_response_stream()
 
-    monkeypatch.setattr(requests, "post", mockPost)
+    monkeypatch.setattr(requests, "post", mock_post)
 
     llm("Test prompt")
 
@@ -52,7 +52,7 @@ def test_pass_headers_if_provided(monkeypatch: MonkeyPatch) -> None:
 def test_handle_if_headers_not_provided(monkeypatch: MonkeyPatch) -> None:
     llm = Ollama(base_url="https://ollama-hostname:8000", model="foo", timeout=300)
 
-    def mockPost(url, headers, json, stream, timeout):
+    def mock_post(url, headers, json, stream, timeout):
         assert url == "https://ollama-hostname:8000/api/generate/"
         assert headers == {
             "Content-Type": "application/json",
@@ -63,6 +63,131 @@ def test_handle_if_headers_not_provided(monkeypatch: MonkeyPatch) -> None:
 
         return mock_response_stream()
 
-    monkeypatch.setattr(requests, "post", mockPost)
+    monkeypatch.setattr(requests, "post", mock_post)
 
     llm("Test prompt")
+
+
+def test_handle_kwargs_top_level_parameters(monkeypatch: MonkeyPatch) -> None:
+    """Test that top level params are sent to the endpoint as top level params"""
+    llm = Ollama(base_url="https://ollama-hostname:8000", model="foo", timeout=300)
+
+    def mock_post(url, headers, json, stream, timeout):
+        assert url == "https://ollama-hostname:8000/api/generate/"
+        assert headers == {
+            "Content-Type": "application/json",
+        }
+        assert json == {
+            "format": None,
+            "images": None,
+            "model": "test-model",
+            "options": {
+                "mirostat": None,
+                "mirostat_eta": None,
+                "mirostat_tau": None,
+                "num_ctx": None,
+                "num_gpu": None,
+                "num_thread": None,
+                "repeat_last_n": None,
+                "repeat_penalty": None,
+                "stop": [],
+                "temperature": None,
+                "tfs_z": None,
+                "top_k": None,
+                "top_p": None,
+            },
+            "prompt": "Test prompt",
+            "system": "Test system prompt",
+            "template": None,
+        }
+        assert stream is True
+        assert timeout == 300
+
+        return mock_response_stream()
+
+    monkeypatch.setattr(requests, "post", mock_post)
+
+    llm("Test prompt", model="test-model", system="Test system prompt")
+
+
+def test_handle_kwargs_with_unknown_param(monkeypatch: MonkeyPatch) -> None:
+    """
+    Test that params that are not top level params will be sent to the endpoint
+    as options
+    """
+    llm = Ollama(base_url="https://ollama-hostname:8000", model="foo", timeout=300)
+
+    def mock_post(url, headers, json, stream, timeout):
+        assert url == "https://ollama-hostname:8000/api/generate/"
+        assert headers == {
+            "Content-Type": "application/json",
+        }
+        assert json == {
+            "format": None,
+            "images": None,
+            "model": "foo",
+            "options": {
+                "mirostat": None,
+                "mirostat_eta": None,
+                "mirostat_tau": None,
+                "num_ctx": None,
+                "num_gpu": None,
+                "num_thread": None,
+                "repeat_last_n": None,
+                "repeat_penalty": None,
+                "stop": [],
+                "temperature": 0.8,
+                "tfs_z": None,
+                "top_k": None,
+                "top_p": None,
+                "unknown": "Unknown parameter value",
+            },
+            "prompt": "Test prompt",
+            "system": None,
+            "template": None,
+        }
+        assert stream is True
+        assert timeout == 300
+
+        return mock_response_stream()
+
+    monkeypatch.setattr(requests, "post", mock_post)
+
+    llm("Test prompt", unknown="Unknown parameter value", temperature=0.8)
+
+
+def test_handle_kwargs_with_options(monkeypatch: MonkeyPatch) -> None:
+    """
+    Test that if options provided it will be sent to the endpoint as options,
+    ignoring other params that are not top level params.
+    """
+    llm = Ollama(base_url="https://ollama-hostname:8000", model="foo", timeout=300)
+
+    def mock_post(url, headers, json, stream, timeout):
+        assert url == "https://ollama-hostname:8000/api/generate/"
+        assert headers == {
+            "Content-Type": "application/json",
+        }
+        assert json == {
+            "format": None,
+            "images": None,
+            "model": "test-another-model",
+            "options": {"unknown_option": "Unknown option value"},
+            "prompt": "Test prompt",
+            "system": None,
+            "template": None,
+        }
+        assert stream is True
+        assert timeout == 300
+
+        return mock_response_stream()
+
+    monkeypatch.setattr(requests, "post", mock_post)
+
+    llm(
+        "Test prompt",
+        model="test-another-model",
+        options={"unknown_option": "Unknown option value"},
+        unknown="Unknown parameter value",
+        temperature=0.8,
+    )
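As a usage sketch (not part of the commit; the base URL shown is just Ollama's default local address), a call like the following now sends `system` as a top-level field of the request body, while `temperature`, which is not a top-level generate parameter, ends up under `options`:

```python
from langchain_community.llms import Ollama

llm = Ollama(base_url="http://localhost:11434", model="llama2")

# "system" is a top-level Ollama generate parameter, so it is sent as a
# top-level field; "temperature" is not, so it is folded into "options".
llm(
    "Why is the sky blue?",
    system="You are a concise assistant.",
    temperature=0.2,
)
```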