Mirror of https://github.com/hwchase17/langchain.git, synced 2025-09-07 22:11:51 +00:00.
Commit: "core, partners: implement standard tracing params for LLMs (#25410)".
This commit is contained in:
@@ -16,7 +16,7 @@ from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models import BaseLLM
|
||||
from langchain_core.language_models import BaseLLM, LangSmithParams
|
||||
from langchain_core.outputs import GenerationChunk, LLMResult
|
||||
from langchain_core.pydantic_v1 import Field, root_validator
|
||||
from ollama import AsyncClient, Client, Options
|
||||
@@ -155,6 +155,15 @@ class OllamaLLM(BaseLLM):
|
||||
"""Return type of LLM."""
|
||||
return "ollama-llm"
|
||||
|
||||
def _get_ls_params(
|
||||
self, stop: Optional[List[str]] = None, **kwargs: Any
|
||||
) -> LangSmithParams:
|
||||
"""Get standard params for tracing."""
|
||||
params = super()._get_ls_params(stop=stop, **kwargs)
|
||||
if max_tokens := kwargs.get("num_predict", self.num_predict):
|
||||
params["ls_max_tokens"] = max_tokens
|
||||
return params
|
||||
|
||||
@root_validator(pre=False, skip_on_failure=True)
|
||||
def _set_clients(cls, values: dict) -> dict:
|
||||
"""Set clients to use for ollama."""
|
||||
|
Reference in New Issue
Block a user