mirror of
				https://github.com/hwchase17/langchain.git
				synced 2025-11-04 02:03:32 +00:00 
			
		
		
		
	Titan Express model was not supported as a chat model because LangChain messages were not "translated" to a text prompt. Co-authored-by: Guillem Orellana Trullols <guillem.orellana_trullols@siemens.com>
		
			
				
	
	
		
			138 lines
		
	
	
		
			4.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			138 lines
		
	
	
		
			4.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
from typing import Any, Dict, Iterator, List, Optional
 | 
						|
 | 
						|
from langchain_core.callbacks import (
 | 
						|
    CallbackManagerForLLMRun,
 | 
						|
)
 | 
						|
from langchain_core.language_models.chat_models import BaseChatModel
 | 
						|
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
 | 
						|
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 | 
						|
from langchain_core.pydantic_v1 import Extra
 | 
						|
 | 
						|
from langchain_community.chat_models.anthropic import (
 | 
						|
    convert_messages_to_prompt_anthropic,
 | 
						|
)
 | 
						|
from langchain_community.chat_models.meta import convert_messages_to_prompt_llama
 | 
						|
from langchain_community.llms.bedrock import BedrockBase
 | 
						|
from langchain_community.utilities.anthropic import (
 | 
						|
    get_num_tokens_anthropic,
 | 
						|
    get_token_ids_anthropic,
 | 
						|
)
 | 
						|
 | 
						|
 | 
						|
class ChatPromptAdapter:
    """Adapter that turns LangChain messages into the prompt string format
    the chat model expects.
    """

    @classmethod
    def convert_messages_to_prompt(
        cls, provider: str, messages: List[BaseMessage]
    ) -> str:
        """Render ``messages`` as a single prompt string for ``provider``.

        Args:
            provider: Provider key. ``"anthropic"``, ``"meta"`` and
                ``"amazon"`` are handled.
            messages: The chat messages to convert.

        Returns:
            The prompt string produced by the converter chosen for the
            provider.

        Raises:
            NotImplementedError: If ``provider`` is none of the handled keys.
        """
        if provider == "anthropic":
            return convert_messages_to_prompt_anthropic(messages=messages)

        if provider == "meta":
            return convert_messages_to_prompt_llama(messages=messages)

        if provider == "amazon":
            # The "amazon" provider reuses the Anthropic converter, only with
            # User/Bot turn markers passed in place of its defaults.
            return convert_messages_to_prompt_anthropic(
                messages=messages,
                human_prompt="\n\nUser:",
                ai_prompt="\n\nBot:",
            )

        raise NotImplementedError(
            f"Provider {provider} model does not support chat."
        )
 | 
						|
 | 
						|
 | 
						|
class BedrockChat(BaseChatModel, BedrockBase):
    """A chat model that uses the Bedrock API.

    Messages are converted to a provider-specific prompt string with
    ``ChatPromptAdapter`` and then passed to the invoke helpers inherited
    from ``BedrockBase``.
    """

    @property
    def _llm_type(self) -> str:
        """Identifier string for this kind of chat model."""
        return "amazon_bedrock_chat"

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Report that this model can be serialized by LangChain."""
        return True

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        """Return the namespace path of this LangChain object."""
        return ["langchain", "chat_models", "bedrock"]

    @property
    def lc_attributes(self) -> Dict[str, Any]:
        """Return ``{"region_name": ...}`` when a region is set, else ``{}``."""
        if not self.region_name:
            return {}
        return {"region_name": self.region_name}

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Stream the model's reply as chat generation chunks.

        Args:
            messages: Chat messages to convert into the prompt.
            stop: Optional stop sequences, forwarded to the streaming helper.
            run_manager: Optional callback manager, forwarded unchanged.
            **kwargs: Extra keyword arguments forwarded to the streaming
                helper.

        Yields:
            One ``ChatGenerationChunk`` per chunk produced by
            ``_prepare_input_and_invoke_stream``, carrying that chunk's text.
        """
        rendered_prompt = ChatPromptAdapter.convert_messages_to_prompt(
            provider=self._get_provider(), messages=messages
        )
        raw_chunks = self._prepare_input_and_invoke_stream(
            prompt=rendered_prompt, stop=stop, run_manager=run_manager, **kwargs
        )

        for piece in raw_chunks:
            yield ChatGenerationChunk(message=AIMessageChunk(content=piece.text))

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Produce a single chat result for ``messages``.

        When ``self.streaming`` is truthy the reply is assembled from the
        chunks of ``_stream``; otherwise the prompt is sent in one call to
        ``_prepare_input_and_invoke``.

        Args:
            messages: Chat messages to convert into the prompt.
            stop: Optional stop sequences.
            run_manager: Optional callback manager, forwarded unchanged.
            **kwargs: Extra keyword arguments forwarded to the invoke helper.

        Returns:
            A ``ChatResult`` with one generation wrapping the reply text in an
            ``AIMessage``.
        """
        if self.streaming:
            # Concatenate the text of every streamed chunk into one reply.
            reply_text = "".join(
                piece.text
                for piece in self._stream(messages, stop, run_manager, **kwargs)
            )
        else:
            rendered_prompt = ChatPromptAdapter.convert_messages_to_prompt(
                provider=self._get_provider(), messages=messages
            )

            invoke_kwargs: Dict[str, Any] = dict(kwargs)
            # NOTE(review): stop is forwarded both as ``stop`` and as
            # ``stop_sequences`` here, while the streaming path forwards only
            # ``stop`` -- confirm against BedrockBase that this is intended.
            if stop:
                invoke_kwargs["stop_sequences"] = stop

            reply_text = self._prepare_input_and_invoke(
                prompt=rendered_prompt,
                stop=stop,
                run_manager=run_manager,
                **invoke_kwargs,
            )

        return ChatResult(
            generations=[ChatGeneration(message=AIMessage(content=reply_text))]
        )

    def get_num_tokens(self, text: str) -> int:
        """Count tokens in ``text``.

        Uses the Anthropic token counter when ``_model_is_anthropic`` is
        truthy and defers to the parent implementation otherwise.
        """
        return (
            get_num_tokens_anthropic(text)
            if self._model_is_anthropic
            else super().get_num_tokens(text)
        )

    def get_token_ids(self, text: str) -> List[int]:
        """Return the token ids for ``text``.

        Uses the Anthropic tokenizer helper when ``_model_is_anthropic`` is
        truthy and defers to the parent implementation otherwise.
        """
        return (
            get_token_ids_anthropic(text)
            if self._model_is_anthropic
            else super().get_token_ids(text)
        )
 |