Adds ChatOllama (#9628)

@rlancemartin

---------

Co-authored-by: Adilkhan Sarsen <54854336+adolkhan@users.noreply.github.com>
Co-authored-by: Kim Minjong <make.dirty.code@gmail.com>
Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
Co-authored-by: Lance Martin <lance@langchain.dev>
Co-authored-by: Bagatur <baskaryan@gmail.com>
Commit 278ef0bdcf by Jacob Lee, committed via GitHub on 2023-08-23 13:02:26 -07:00 (parent fa05e18278).
6 changed files with 550 additions and 16 deletions


@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Union

from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult
from langchain.schema.messages import BaseMessage


class StreamingStdOutCallbackHandler(BaseCallbackHandler):

@@ -14,6 +15,14 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler):
    ) -> None:
        """Run when LLM starts running."""

    def on_chat_model_start(
        self,
        serialized: Dict[str, Any],
        messages: List[List[BaseMessage]],
        **kwargs: Any,
    ) -> None:
        """Run when LLM starts running."""

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Run on new LLM token. Only available when streaming is enabled."""
        sys.stdout.write(token)
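With on_chat_model_start overridden, the stdout streaming handler can be attached to chat models without falling back to the base handler's NotImplementedError for chat-model starts. A minimal sketch of it in use, assuming a local Ollama server with the llama2 model pulled:

```python
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.chat_models import ChatOllama
from langchain.schema.messages import HumanMessage

chat = ChatOllama(model="llama2", callbacks=[StreamingStdOutCallbackHandler()])
# on_chat_model_start fires when the call begins; each token is then
# written to stdout by on_llm_new_token as it streams back.
chat([HumanMessage(content="Say hello in one short sentence.")])
```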


@@ -27,6 +27,7 @@ from langchain.chat_models.human import HumanInputChatModel
from langchain.chat_models.jinachat import JinaChat
from langchain.chat_models.litellm import ChatLiteLLM
from langchain.chat_models.mlflow_ai_gateway import ChatMLflowAIGateway
from langchain.chat_models.ollama import ChatOllama
from langchain.chat_models.openai import ChatOpenAI
from langchain.chat_models.promptlayer_openai import PromptLayerChatOpenAI
from langchain.chat_models.vertexai import ChatVertexAI

@@ -39,6 +40,7 @@ __all__ = [
    "ChatAnthropic",
    "ChatGooglePalm",
    "ChatMLflowAIGateway",
    "ChatOllama",
    "ChatVertexAI",
    "JinaChat",
    "HumanInputChatModel",


@@ -32,7 +32,7 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
        .. code-block:: python

            import anthropic
-           from langchain.llms import Anthropic
+           from langchain.chat_models import ChatAnthropic
            model = ChatAnthropic(model="<model_name>", anthropic_api_key="my-api-key")
    """


@@ -0,0 +1,122 @@
import json
from typing import Any, Iterator, List, Optional

from langchain.callbacks.manager import (
    CallbackManagerForLLMRun,
)
from langchain.chat_models.base import BaseChatModel
from langchain.llms.ollama import _OllamaCommon
from langchain.schema import ChatResult
from langchain.schema.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    ChatMessage,
    HumanMessage,
    SystemMessage,
)
from langchain.schema.output import ChatGeneration, ChatGenerationChunk


def _stream_response_to_chat_generation_chunk(
    stream_response: str,
) -> ChatGenerationChunk:
    """Convert a stream response to a generation chunk."""
    parsed_response = json.loads(stream_response)
    generation_info = parsed_response if parsed_response.get("done") is True else None
    return ChatGenerationChunk(
        message=AIMessageChunk(content=parsed_response.get("response", "")),
        generation_info=generation_info,
    )
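The helper mirrors the JSON lines Ollama streams back: intermediate lines carry a partial "response" with "done": false, and the final line sets "done": true, at which point the whole parsed line becomes the generation_info. A rough illustration (the payloads below are made up for the example):

```python
chunk = _stream_response_to_chat_generation_chunk('{"response": "Hel", "done": false}')
print(chunk.message.content)   # "Hel"
print(chunk.generation_info)   # None: not the final line

final = _stream_response_to_chat_generation_chunk('{"response": "", "done": true}')
print(final.generation_info)   # {"response": "", "done": True}
```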
class ChatOllama(BaseChatModel, _OllamaCommon):
    """Ollama locally runs large language models.

    To use, follow the instructions at https://ollama.ai/.

    Example:
        .. code-block:: python

            from langchain.chat_models import ChatOllama
            ollama = ChatOllama(model="llama2")
    """

    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "ollama-chat"

    @property
    def lc_serializable(self) -> bool:
        return True

    def _format_message_as_text(self, message: BaseMessage) -> str:
        if isinstance(message, ChatMessage):
            message_text = f"\n\n{message.role.capitalize()}: {message.content}"
        elif isinstance(message, HumanMessage):
            message_text = f"[INST] {message.content} [/INST]"
        elif isinstance(message, AIMessage):
            message_text = f"{message.content}"
        elif isinstance(message, SystemMessage):
            message_text = f"<<SYS>> {message.content} <</SYS>>"
        else:
            raise ValueError(f"Got unknown type {message}")
        return message_text

    def _format_messages_as_text(self, messages: List[BaseMessage]) -> str:
        return "\n".join(
            [self._format_message_as_text(message) for message in messages]
        )
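Because Ollama's generate endpoint takes a single prompt string, the chat history is flattened with Llama-2-style markers. A sketch of the resulting prompt for a short exchange (message contents are illustrative):

```python
messages = [
    SystemMessage(content="You are a terse assistant."),
    HumanMessage(content="What is 2 + 2?"),
    AIMessage(content="4"),
    HumanMessage(content="And doubled?"),
]
# ChatOllama()._format_messages_as_text(messages) produces:
# <<SYS>> You are a terse assistant. <</SYS>>
# [INST] What is 2 + 2? [/INST]
# 4
# [INST] And doubled? [/INST]
```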
    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Call out to Ollama's generate endpoint.

        Args:
            messages: The list of base messages to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            Chat generations from the model

        Example:
            .. code-block:: python

                response = ollama([
                    HumanMessage(content="Tell me about the history of AI")
                ])
        """
        prompt = self._format_messages_as_text(messages)
        final_chunk = super()._stream_with_aggregation(
            prompt, stop=stop, run_manager=run_manager, verbose=self.verbose, **kwargs
        )
        chat_generation = ChatGeneration(
            message=AIMessage(content=final_chunk.text),
            generation_info=final_chunk.generation_info,
        )
        return ChatResult(generations=[chat_generation])
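End to end, _generate streams internally and returns one aggregated message. A minimal sketch, assuming a local Ollama server on the default port with llama2 pulled:

```python
from langchain.chat_models import ChatOllama
from langchain.schema.messages import HumanMessage

ollama = ChatOllama(model="llama2")
message = ollama([HumanMessage(content="Tell me about the history of AI")])
print(message.content)  # the full response, assembled from the streamed chunks
```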
    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        prompt = self._format_messages_as_text(messages)
        for stream_resp in self._create_stream(prompt, stop, **kwargs):
            if stream_resp:
                chunk = _stream_response_to_chat_generation_chunk(stream_resp)
                yield chunk
                if run_manager:
                    run_manager.on_llm_new_token(
                        chunk.text,
                        verbose=self.verbose,
                    )
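Callers that want the individual chunks can go through the chat model's stream() method, which drives _stream; a short sketch under the same local-server assumption:

```python
from langchain.chat_models import ChatOllama
from langchain.schema.messages import HumanMessage

chat = ChatOllama(model="llama2")
for chunk in chat.stream([HumanMessage(content="Name three colors.")]):
    # each chunk is a message chunk carrying one streamed piece of text
    print(chunk.content, end="", flush=True)
```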


@@ -144,9 +144,35 @@ class _OllamaCommon(BaseLanguageModel):
        )
        return response.iter_lines(decode_unicode=True)

    def _stream_with_aggregation(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        verbose: bool = False,
        **kwargs: Any,
    ) -> GenerationChunk:
        final_chunk: Optional[GenerationChunk] = None
        for stream_resp in self._create_stream(prompt, stop, **kwargs):
            if stream_resp:
                chunk = _stream_response_to_generation_chunk(stream_resp)
                if final_chunk is None:
                    final_chunk = chunk
                else:
                    final_chunk += chunk
                if run_manager:
                    run_manager.on_llm_new_token(
                        chunk.text,
                        verbose=verbose,
                    )
        if final_chunk is None:
            raise ValueError("No data received from Ollama stream.")
        return final_chunk
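The aggregation relies on GenerationChunk supporting "+", which concatenates the text and merges generation_info; a quick sketch of that behaviour with illustrative values:

```python
from langchain.schema.output import GenerationChunk

first = GenerationChunk(text="Hel")
last = GenerationChunk(text="lo", generation_info={"done": True})
merged = first + last
print(merged.text)             # "Hello"
print(merged.generation_info)  # {"done": True}
```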
class Ollama(BaseLLM, _OllamaCommon):
-    """Ollama locally run large language models.
+    """Ollama locally runs large language models.

    To use, follow the instructions at https://ollama.ai/.

@@ -191,20 +217,13 @@ class Ollama(BaseLLM, _OllamaCommon):
        # TODO: add caching here.
        generations = []
        for prompt in prompts:
-            final_chunk: Optional[GenerationChunk] = None
-            for stream_resp in self._create_stream(prompt, stop, **kwargs):
-                if stream_resp:
-                    chunk = _stream_response_to_generation_chunk(stream_resp)
-                    if final_chunk is None:
-                        final_chunk = chunk
-                    else:
-                        final_chunk += chunk
-                    if run_manager:
-                        run_manager.on_llm_new_token(
-                            chunk.text,
-                            verbose=self.verbose,
-                        )
+            final_chunk = super()._stream_with_aggregation(
+                prompt,
+                stop=stop,
+                run_manager=run_manager,
+                verbose=self.verbose,
+                **kwargs,
+            )
            generations.append([final_chunk])
        return LLMResult(generations=generations)
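With this refactor, Ollama and ChatOllama share one streaming path in _OllamaCommon, so token callbacks and chunk aggregation behave the same for both. A minimal LLM-side sketch, again assuming a local Ollama server with the llama2 model pulled:

```python
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.llms import Ollama

llm = Ollama(model="llama2", callbacks=[StreamingStdOutCallbackHandler()])
# tokens are printed as they stream; the aggregated text is also returned
print(llm("Why is the sky blue?"))
```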