core, partners: implement standard tracing params for LLMs (#25410)

This commit is contained in:
ccurme
2024-08-16 13:18:09 -04:00
committed by GitHub
parent 9f0c76bf89
commit b83f1eb0d5
17 changed files with 298 additions and 36 deletions

View File

@@ -16,7 +16,7 @@ from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models import BaseLLM
from langchain_core.language_models import BaseLLM, LangSmithParams
from langchain_core.outputs import GenerationChunk, LLMResult
from langchain_core.pydantic_v1 import Field, root_validator
from ollama import AsyncClient, Client, Options
@@ -155,6 +155,15 @@ class OllamaLLM(BaseLLM):
"""Return type of LLM."""
return "ollama-llm"
def _get_ls_params(
    self, stop: Optional[List[str]] = None, **kwargs: Any
) -> LangSmithParams:
    """Build the standard LangSmith tracing parameters for this LLM.

    Delegates to the base implementation for the common fields, then adds
    the max-token budget when one is configured (a per-call ``num_predict``
    kwarg takes precedence over the instance-level setting).
    """
    ls_params = super()._get_ls_params(stop=stop, **kwargs)
    max_tokens = kwargs.get("num_predict", self.num_predict)
    # Only record the budget when it is set to a truthy value.
    if max_tokens:
        ls_params["ls_max_tokens"] = max_tokens
    return ls_params
@root_validator(pre=False, skip_on_failure=True)
def _set_clients(cls, values: dict) -> dict:
"""Set clients to use for ollama."""