Update code to use sentence-transformers through HuggingFaceEmbeddings

Iván Martínez 2023-05-17 00:32:41 +02:00
parent 8a5b2f453b
commit 23d24c88e9
3 changed files with 13 additions and 12 deletions
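
Note (not part of the commit): the change swaps embedding generation from a local llama.cpp GGML model to sentence-transformers via LangChain's HuggingFaceEmbeddings wrapper. A minimal sketch of what the new call does, assuming a langchain build of this era with sentence-transformers installed:

# Before: embeddings came from a local GGML file through llama.cpp.
# llama = LlamaCppEmbeddings(model_path="models/ggml-model-q4_0.bin", n_ctx=1000)

# After: HuggingFaceEmbeddings wraps sentence_transformers.SentenceTransformer,
# resolving the model by name (downloaded on first use) instead of a local file.
from langchain.embeddings import HuggingFaceEmbeddings

embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
vector = embeddings.embed_query("hello world")  # list[float], 384 dims for MiniLM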

example.env

@@ -1,5 +1,5 @@
 PERSIST_DIRECTORY=db
-LLAMA_EMBEDDINGS_MODEL=models/ggml-model-q4_0.bin
 MODEL_TYPE=GPT4All
 MODEL_PATH=models/ggml-gpt4all-j-v1.3-groovy.bin
+EMBEDDINGS_MODEL_NAME=all-MiniLM-L6-v2
 MODEL_N_CTX=1000

ingest.py

@@ -6,7 +6,7 @@ from dotenv import load_dotenv
 from langchain.document_loaders import TextLoader, PDFMinerLoader, CSVLoader
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.vectorstores import Chroma
-from langchain.embeddings import LlamaCppEmbeddings
+from langchain.embeddings import HuggingFaceEmbeddings
 from langchain.docstore.document import Document
 from constants import CHROMA_SETTINGS
@@ -38,22 +38,23 @@ def main():
     # Load environment variables
     persist_directory = os.environ.get('PERSIST_DIRECTORY')
     source_directory = os.environ.get('SOURCE_DIRECTORY', 'source_documents')
-    llama_embeddings_model = os.environ.get('LLAMA_EMBEDDINGS_MODEL')
-    model_n_ctx = os.environ.get('MODEL_N_CTX')
+    embeddings_model_name = os.environ.get('EMBEDDINGS_MODEL_NAME')
 
     # Load documents and split in chunks
     print(f"Loading documents from {source_directory}")
+    chunk_size = 500
+    chunk_overlap = 50
     documents = load_documents(source_directory)
-    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
+    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
     texts = text_splitter.split_documents(documents)
     print(f"Loaded {len(documents)} documents from {source_directory}")
-    print(f"Split into {len(texts)} chunks of text (max. 500 tokens each)")
+    print(f"Split into {len(texts)} chunks of text (max. {chunk_size} characters each)")
 
     # Create embeddings
-    llama = LlamaCppEmbeddings(model_path=llama_embeddings_model, n_ctx=model_n_ctx)
+    embeddings = HuggingFaceEmbeddings(model_name=embeddings_model_name)
 
     # Create and store locally vectorstore
-    db = Chroma.from_documents(texts, llama, persist_directory=persist_directory, client_settings=CHROMA_SETTINGS)
+    db = Chroma.from_documents(texts, embeddings, persist_directory=persist_directory, client_settings=CHROMA_SETTINGS)
     db.persist()
     db = None
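
Note (not part of the commit): vectors produced by LlamaCppEmbeddings live in a different embedding space and dimensionality than all-MiniLM-L6-v2's 384-dim vectors, so an index built before this change cannot be queried with the new model. A hedged sketch of the cleanup before re-ingesting, assuming the default persist directory:

# Hypothetical cleanup step, not in the repo: drop the old index, then re-run ingest.py.
import os, shutil

persist_directory = os.environ.get('PERSIST_DIRECTORY', 'db')
if os.path.exists(persist_directory):
    shutil.rmtree(persist_directory)  # old LlamaCpp vectors are incompatible with MiniLM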

privateGPT.py

@@ -1,6 +1,6 @@
 from dotenv import load_dotenv
 from langchain.chains import RetrievalQA
-from langchain.embeddings import LlamaCppEmbeddings
+from langchain.embeddings import HuggingFaceEmbeddings
 from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 from langchain.vectorstores import Chroma
 from langchain.llms import GPT4All, LlamaCpp
@@ -8,7 +8,7 @@ import os
 
 load_dotenv()
 
-llama_embeddings_model = os.environ.get("LLAMA_EMBEDDINGS_MODEL")
+embeddings_model_name = os.environ.get("EMBEDDINGS_MODEL_NAME")
 persist_directory = os.environ.get('PERSIST_DIRECTORY')
 
 model_type = os.environ.get('MODEL_TYPE')
@@ -18,8 +18,8 @@ model_n_ctx = os.environ.get('MODEL_N_CTX')
 from constants import CHROMA_SETTINGS
 
 def main():
-    llama = LlamaCppEmbeddings(model_path=llama_embeddings_model, n_ctx=model_n_ctx)
-    db = Chroma(persist_directory=persist_directory, embedding_function=llama, client_settings=CHROMA_SETTINGS)
+    embeddings = HuggingFaceEmbeddings(model_name=embeddings_model_name)
+    db = Chroma(persist_directory=persist_directory, embedding_function=embeddings, client_settings=CHROMA_SETTINGS)
     retriever = db.as_retriever()
     # Prepare the LLM
     callbacks = [StreamingStdOutCallbackHandler()]
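
Note (not part of the commit): a quick end-to-end check of the new embedding path, sketched under assumptions (the .env variables above are set and ingest.py has already built the index):

# Hypothetical smoke test: load the persisted store with the new embeddings
# and fetch the nearest chunks for a query.
import os
from dotenv import load_dotenv
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma
from constants import CHROMA_SETTINGS

load_dotenv()
embeddings = HuggingFaceEmbeddings(model_name=os.environ.get("EMBEDDINGS_MODEL_NAME"))
db = Chroma(persist_directory=os.environ.get('PERSIST_DIRECTORY'),
            embedding_function=embeddings, client_settings=CHROMA_SETTINGS)
docs = db.similarity_search("What is this corpus about?", k=4)  # 4 closest chunks
for doc in docs:
    print(doc.metadata.get("source"), doc.page_content[:80])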