mirror of
				https://github.com/hwchase17/langchain.git
				synced 2025-10-25 12:44:04 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			167 lines
		
	
	
		
			5.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			167 lines
		
	
	
		
			5.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| """Hugging Face Chat Wrapper."""
 | |
| from typing import Any, List, Optional, Union
 | |
| 
 | |
| from langchain_core.callbacks.manager import (
 | |
|     AsyncCallbackManagerForLLMRun,
 | |
|     CallbackManagerForLLMRun,
 | |
| )
 | |
| from langchain_core.language_models.chat_models import BaseChatModel
 | |
| from langchain_core.messages import (
 | |
|     AIMessage,
 | |
|     BaseMessage,
 | |
|     HumanMessage,
 | |
|     SystemMessage,
 | |
| )
 | |
| from langchain_core.outputs import (
 | |
|     ChatGeneration,
 | |
|     ChatResult,
 | |
|     LLMResult,
 | |
| )
 | |
| 
 | |
| from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint
 | |
| from langchain_community.llms.huggingface_hub import HuggingFaceHub
 | |
| from langchain_community.llms.huggingface_text_gen_inference import (
 | |
|     HuggingFaceTextGenInference,
 | |
| )
 | |
| 
 | |
# Default system prompt text; used as the content of the default
# `system_message` field on the chat wrapper class defined in this module.
DEFAULT_SYSTEM_PROMPT = "You are a helpful, respectful, and honest assistant."
 | |
| 
 | |
| 
 | |
class ChatHuggingFace(BaseChatModel):
    """Wrapper for using Hugging Face LLMs as chat models.

    Works with `HuggingFaceTextGenInference`, `HuggingFaceEndpoint`,
    and `HuggingFaceHub` LLMs.

    Upon instantiating this class, the model_id is resolved from the url
    provided to the LLM (or taken from ``repo_id`` for `HuggingFaceHub`),
    and, unless a tokenizer was passed in, the matching tokenizer is loaded
    with ``AutoTokenizer.from_pretrained``.

    Adapted from: https://python.langchain.com/docs/integrations/chat/llama2_chat
    """

    # The wrapped text-completion LLM that actually produces the output.
    llm: Union[HuggingFaceTextGenInference, HuggingFaceEndpoint, HuggingFaceHub]
    # Default system message. NOTE: `_to_chat_prompt` does not inject it
    # automatically; only messages passed by the caller end up in the prompt.
    system_message: SystemMessage = SystemMessage(content=DEFAULT_SYSTEM_PROMPT)
    # Tokenizer providing `apply_chat_template`; loaded in `__init__` when None.
    tokenizer: Any = None
    # Model repository id; (re)resolved from `llm` in `__init__`.
    model_id: Optional[str] = None

    def __init__(self, **kwargs: Any):
        """Validate fields, resolve ``model_id`` and load a tokenizer if needed.

        Args:
            **kwargs: Field values forwarded to the base class initializer.

        Raises:
            ValueError: If the model id cannot be resolved from ``llm``.
        """
        super().__init__(**kwargs)

        # Imported lazily so the module can be imported without transformers.
        from transformers import AutoTokenizer

        self._resolve_model_id()

        # Respect a caller-supplied tokenizer; otherwise load one for model_id.
        if self.tokenizer is None:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Render ``messages`` to a prompt, run the wrapped LLM, wrap the output.

        Args:
            messages: Conversation so far; the last one must be a HumanMessage.
            stop: Optional stop sequences forwarded to the wrapped LLM.
            run_manager: Optional callback manager forwarded to the wrapped LLM.
            **kwargs: Extra keyword arguments forwarded to the wrapped LLM.

        Returns:
            The wrapped LLM's generations converted to a ``ChatResult``.

        Raises:
            ValueError: If ``messages`` is empty, does not end with a
                HumanMessage, or contains an unsupported message type.
        """
        llm_input = self._to_chat_prompt(messages)
        llm_result = self.llm._generate(
            prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
        )
        return self._to_chat_result(llm_result)

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Async counterpart of ``_generate`` using the LLM's ``_agenerate``.

        Args:
            messages: Conversation so far; the last one must be a HumanMessage.
            stop: Optional stop sequences forwarded to the wrapped LLM.
            run_manager: Optional async callback manager forwarded to the LLM.
            **kwargs: Extra keyword arguments forwarded to the wrapped LLM.

        Returns:
            The wrapped LLM's generations converted to a ``ChatResult``.

        Raises:
            ValueError: If ``messages`` is empty, does not end with a
                HumanMessage, or contains an unsupported message type.
        """
        llm_input = self._to_chat_prompt(messages)
        llm_result = await self.llm._agenerate(
            prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
        )
        return self._to_chat_result(llm_result)

    def _to_chat_prompt(
        self,
        messages: List[BaseMessage],
    ) -> str:
        """Convert a list of messages into a prompt format expected by wrapped LLM.

        Args:
            messages: Non-empty list of messages ending with a HumanMessage.

        Returns:
            The prompt string produced by the tokenizer's chat template, with
            the generation prompt appended.

        Raises:
            ValueError: If ``messages`` is empty, the last message is not a
                HumanMessage, or a message has an unsupported type.
        """
        if not messages:
            raise ValueError("at least one HumanMessage must be provided")

        if not isinstance(messages[-1], HumanMessage):
            raise ValueError("last message must be a HumanMessage")

        messages_dicts = [self._to_chatml_format(m) for m in messages]

        # tokenize=False -> return the rendered text rather than token ids.
        return self.tokenizer.apply_chat_template(
            messages_dicts, tokenize=False, add_generation_prompt=True
        )

    def _to_chatml_format(self, message: BaseMessage) -> dict:
        """Convert LangChain message to ChatML format.

        Args:
            message: A SystemMessage, AIMessage or HumanMessage.

        Returns:
            A dict with ``"role"`` ("system", "assistant" or "user") and
            ``"content"`` keys.

        Raises:
            ValueError: If ``message`` is none of the supported types.
        """
        if isinstance(message, SystemMessage):
            role = "system"
        elif isinstance(message, AIMessage):
            role = "assistant"
        elif isinstance(message, HumanMessage):
            role = "user"
        else:
            raise ValueError(f"Unknown message type: {type(message)}")

        return {"role": role, "content": message.content}

    @staticmethod
    def _to_chat_result(llm_result: LLMResult) -> ChatResult:
        """Convert the LLM result for the single submitted prompt to a ChatResult.

        Args:
            llm_result: Result returned by the wrapped LLM. Only the first
                entry of ``generations`` is read, because exactly one prompt
                is submitted per call.

        Returns:
            A ``ChatResult`` holding one ``ChatGeneration`` per generation,
            carrying over ``generation_info`` and ``llm_output``.
        """
        chat_generations = [
            ChatGeneration(
                message=AIMessage(content=g.text), generation_info=g.generation_info
            )
            for g in llm_result.generations[0]
        ]

        return ChatResult(
            generations=chat_generations, llm_output=llm_result.llm_output
        )

    def _resolve_model_id(self) -> None:
        """Resolve the model_id from the LLM's inference_server_url.

        For `HuggingFaceHub` the ``repo_id`` is used directly. For the other
        supported LLMs the endpoint url is matched against the inference
        endpoints returned by ``huggingface_hub.list_inference_endpoints``.

        Raises:
            ValueError: If ``llm`` is of an unsupported type, or no model id
                is set after the lookup.
        """
        if isinstance(self.llm, HuggingFaceTextGenInference):
            endpoint_url = self.llm.inference_server_url

        elif isinstance(self.llm, HuggingFaceEndpoint):
            endpoint_url = self.llm.endpoint_url

        elif isinstance(self.llm, HuggingFaceHub):
            # no need to look up model_id for HuggingFaceHub LLM
            self.model_id = self.llm.repo_id
            return

        else:
            raise ValueError(f"Unknown LLM type: {type(self.llm)}")

        # Only query the Hub once we know an endpoint lookup is required; this
        # avoids a needless network round-trip for the branches above.
        from huggingface_hub import list_inference_endpoints

        # NOTE(review): "*" presumably lists endpoints across every namespace
        # the token can access -- confirm against the huggingface_hub docs.
        for endpoint in list_inference_endpoints("*"):
            if endpoint.url == endpoint_url:
                self.model_id = endpoint.repository
                break

        if not self.model_id:
            raise ValueError(
                "Failed to resolve model_id: "
                f"Could not find model id for inference server provided: {endpoint_url}. "
                "Make sure that your Hugging Face token has access to the endpoint."
            )

    @property
    def _llm_type(self) -> str:
        """Return the identifier string for this chat model type."""
        return "huggingface-chat-wrapper"
 |