mirror of
https://github.com/imartinez/privateGPT.git
synced 2025-08-30 22:08:45 +00:00
Hybrid search
This commit is contained in:
parent
759767dc1b
commit
fbd298212f
30
.github/workflows/actions/install_dependencies/action.yml
vendored
Normal file
30
.github/workflows/actions/install_dependencies/action.yml
vendored
Normal file
@ -0,0 +1,30 @@
|
||||
# Composite action: installs Poetry, sets up Python with the Poetry cache,
# then installs the project dependencies (without the project itself).
name: "Install Dependencies"
description: "Action to build the project dependencies from the main versions"
inputs:
  python_version:
    description: "Python version passed to actions/setup-python"
    required: true
    type: string
    default: "3.11.4"
  poetry_version:
    description: "Poetry version passed to snok/install-poetry"
    required: true
    type: string
    default: "1.5.1"

runs:
  using: composite
  steps:
    # Poetry is installed first so that setup-python's `cache: "poetry"` can find it.
    - name: Install Poetry
      uses: snok/install-poetry@v1
      with:
        version: ${{ inputs.poetry_version }}
        virtualenvs-create: true
        virtualenvs-in-project: false
        installer-parallel: true
    - uses: actions/setup-python@v4
      with:
        python-version: ${{ inputs.python_version }}
        cache: "poetry"
    - name: Install Dependencies
      run: poetry install --extras "ui vector-stores-qdrant" --no-root
      shell: bash
|
||||
|
45
.github/workflows/docker.yml
vendored
Normal file
45
.github/workflows/docker.yml
vendored
Normal file
@ -0,0 +1,45 @@
|
||||
# Builds the image from Dockerfile.external and pushes it to the GitHub
# Container Registry on every published release (or on manual dispatch).
name: docker

on:
  release:
    types: [published]
  workflow_dispatch:

env:
  REGISTRY: ghcr.io
  IMAGE_NAME: ${{ github.repository }}

jobs:
  build-and-push-image:
    runs-on: ubuntu-latest
    permissions:
      contents: read
      packages: write
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4
      - name: Log in to the Container registry
        uses: docker/login-action@v3
        with:
          registry: ${{ env.REGISTRY }}
          username: ${{ github.actor }}
          password: ${{ secrets.GITHUB_TOKEN }}
      - name: Extract metadata (tags, labels) for Docker
        id: meta
        uses: docker/metadata-action@v5
        with:
          images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
          tags: |
            type=ref,event=branch
            type=ref,event=pr
            type=semver,pattern={{version}}
            type=semver,pattern={{major}}.{{minor}}
            type=sha
      - name: Build and push Docker image
        uses: docker/build-push-action@v5
        with:
          context: .
          file: Dockerfile.external
          push: true
          tags: ${{ steps.meta.outputs.tags }}
          labels: ${{ steps.meta.outputs.labels }}
|
21
.github/workflows/fern-check.yml
vendored
Normal file
21
.github/workflows/fern-check.yml
vendored
Normal file
@ -0,0 +1,21 @@
|
||||
# Validates the Fern API definition on pull requests that touch fern/.
name: fern check

on:
  pull_request:
    branches:
      - main
    paths:
      - "fern/**"

jobs:
  fern-check:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout repo
        uses: actions/checkout@v4

      - name: Install Fern
        run: npm install -g fern-api

      - name: Check Fern API is valid
        run: fern check
|
50
.github/workflows/preview-docs.yml
vendored
Normal file
50
.github/workflows/preview-docs.yml
vendored
Normal file
@ -0,0 +1,50 @@
|
||||
# Generates a Fern docs preview for pull requests that touch fern/ and
# comments the preview URL on the pull request.
name: deploy preview docs

on:
  # NOTE(review): pull_request_target jobs can read repository secrets, and the
  # checkout below fetches the PR merge ref. Confirm that running Fern on
  # content from forks with FERN_TOKEN in scope is acceptable.
  pull_request_target:
    branches:
      - main
    paths:
      - "fern/**"

jobs:
  preview-docs:
    runs-on: ubuntu-latest
    # Least privilege: read the code, write the PR comment.
    permissions:
      contents: read
      pull-requests: write

    steps:
      - name: Checkout repository
        uses: actions/checkout@v4
        with:
          ref: refs/pull/${{ github.event.pull_request.number }}/merge

      - name: Setup Node.js
        uses: actions/setup-node@v4
        with:
          node-version: "18"

      - name: Install Fern
        run: npm install -g fern-api

      - name: Generate Documentation Preview with Fern
        id: generate_docs
        env:
          FERN_TOKEN: ${{ secrets.FERN_TOKEN }}
        run: |
          output=$(fern generate --docs --preview --log-level debug)
          echo "$output"
          # Extract the URL. `|| true` keeps the step alive when nothing matches;
          # the comment step below is skipped when the output is empty.
          preview_url=$(echo "$output" | grep -oP '(?<=Published docs to )https://[^\s]*' || true)
          # Set the output for the step (the `::set-output` command is deprecated).
          echo "preview_url=$preview_url" >> "$GITHUB_OUTPUT"

      - name: Comment PR with URL using github-actions bot
        uses: actions/github-script@v7
        if: ${{ steps.generate_docs.outputs.preview_url }}
        env:
          # Passed through the environment instead of being interpolated into
          # the script source, so the value can never be parsed as JavaScript.
          PREVIEW_URL: ${{ steps.generate_docs.outputs.preview_url }}
        with:
          script: |
            const preview_url = process.env.PREVIEW_URL;
            const issue_number = context.issue.number;
            await github.rest.issues.createComment({
              ...context.repo,
              issue_number: issue_number,
              body: `Published docs preview URL: ${preview_url}`
            })
|
26
.github/workflows/publish-docs.yml
vendored
Normal file
26
.github/workflows/publish-docs.yml
vendored
Normal file
@ -0,0 +1,26 @@
|
||||
# Publishes the Fern docs whenever fern/ changes on main.
name: publish docs

on:
  push:
    branches:
      - main
    paths:
      - "fern/**"

jobs:
  publish-docs:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout repo
        uses: actions/checkout@v4

      - name: Setup node
        uses: actions/setup-node@v3

      - name: Download Fern
        run: npm install -g fern-api

      - name: Generate and Publish Docs
        env:
          FERN_TOKEN: ${{ secrets.FERN_TOKEN }}
        run: fern generate --docs --log-level debug
|
19
.github/workflows/release-please.yml
vendored
Normal file
19
.github/workflows/release-please.yml
vendored
Normal file
@ -0,0 +1,19 @@
|
||||
# Runs release-please on every push to main, using version.txt as the
# version file.
name: release-please

on:
  push:
    branches:
      - main

permissions:
  contents: write
  pull-requests: write

jobs:
  release-please:
    runs-on: ubuntu-latest
    steps:
      - uses: google-github-actions/release-please-action@v3
        with:
          release-type: simple
          version-file: version.txt
|
30
.github/workflows/stale.yml
vendored
Normal file
30
.github/workflows/stale.yml
vendored
Normal file
@ -0,0 +1,30 @@
|
||||
# This workflow warns and then closes issues and PRs that have had no activity for a specified amount of time.
#
# You can adjust the behavior by modifying this file.
# For more information, see:
# https://github.com/actions/stale
name: Mark stale issues and pull requests

on:
  schedule:
    - cron: '42 5 * * *'

jobs:
  stale:

    runs-on: ubuntu-latest
    permissions:
      issues: write
      pull-requests: write

    steps:
      - uses: actions/stale@v8
        with:
          repo-token: ${{ secrets.GITHUB_TOKEN }}
          days-before-stale: 15
          stale-issue-message: 'Stale issue'
          stale-pr-message: 'Stale pull request'
          stale-issue-label: 'stale'
          stale-pr-label: 'stale'
          # Items labelled 'autorelease: pending' are never marked stale.
          exempt-issue-labels: 'autorelease: pending'
          exempt-pr-labels: 'autorelease: pending'
|
67
.github/workflows/tests.yml
vendored
Normal file
67
.github/workflows/tests.yml
vendored
Normal file
@ -0,0 +1,67 @@
|
||||
# Runs the quality checks (black, ruff, mypy) and the test suite on pushes to
# main and on every pull request.
name: tests

on:
  push:
    branches:
      - main
  pull_request:

# One run per PR/branch; a newer PR run cancels the one still in progress.
concurrency:
  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.head_ref || github.ref }}
  cancel-in-progress: ${{ github.event_name == 'pull_request' }}

jobs:
  setup:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: ./.github/workflows/actions/install_dependencies

  checks:
    needs: setup
    runs-on: ubuntu-latest
    name: ${{ matrix.quality-command }}
    strategy:
      matrix:
        quality-command:
          - black
          - ruff
          - mypy
    steps:
      - uses: actions/checkout@v4
      - uses: ./.github/workflows/actions/install_dependencies
      - name: run ${{ matrix.quality-command }}
        run: make ${{ matrix.quality-command }}

  test:
    needs: setup
    runs-on: ubuntu-latest
    name: test
    steps:
      - uses: actions/checkout@v4
      - uses: ./.github/workflows/actions/install_dependencies
      - name: run test
        run: make test-coverage
      # Run even if make test fails for coverage reports
      # TODO: select a better xml results displayer
      - name: Archive test results coverage results
        # upload-artifact@v3 has been retired by GitHub; v4 is the supported major.
        uses: actions/upload-artifact@v4
        if: always()
        with:
          name: test_results
          path: tests-results.xml
      - name: Archive code coverage results
        uses: actions/upload-artifact@v4
        if: always()
        with:
          name: code-coverage-report
          path: htmlcov/

  all_checks_passed:
    # Used to easily force requirements checks in GitHub
    needs:
      - checks
      - test
    runs-on: ubuntu-latest
    steps:
      - run: echo "All checks passed"
|
@ -0,0 +1,32 @@
|
||||
"""Update is_enabled to false by default
|
||||
|
||||
Revision ID: 14281ff34686
|
||||
Revises: b7b896502e8e
|
||||
Create Date: 2024-03-18 16:33:43.133458
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '14281ff34686'
|
||||
down_revision: Union[str, None] = 'b7b896502e8e'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Apply this revision.

    Currently a no-op: the only auto-generated command (creating the
    ``unique_user_role`` constraint) is commented out.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    # op.create_unique_constraint('unique_user_role', 'user_roles', ['user_id', 'role_id', 'company_id'])
    # ### end Alembic commands ###
    pass
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Revert this revision.

    Currently a no-op: the only auto-generated command (dropping the
    ``unique_user_role`` constraint) is commented out.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    # op.drop_constraint('unique_user_role', 'user_roles', type_='unique')
    # ### end Alembic commands ###
    pass
|
38
alembic/versions/b7b896502e8e_update.py
Normal file
38
alembic/versions/b7b896502e8e_update.py
Normal file
@ -0,0 +1,38 @@
|
||||
"""update
|
||||
|
||||
Revision ID: b7b896502e8e
|
||||
Revises:
|
||||
Create Date: 2024-03-17 15:07:10.795935
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'b7b896502e8e'
|
||||
down_revision: Union[str, None] = None
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Recreate the ``document_department_association`` foreign keys with cascades.

    Drops the two existing foreign keys (to ``document`` and ``departments``)
    and recreates them with ``ON UPDATE CASCADE`` / ``ON DELETE CASCADE``.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_constraint('document_department_association_department_id_fkey', 'document_department_association', type_='foreignkey')
    op.drop_constraint('document_department_association_document_id_fkey', 'document_department_association', type_='foreignkey')
    # The name is None, so the constraint name is left to the database or the
    # configured naming convention.
    op.create_foreign_key(None, 'document_department_association', 'document', ['document_id'], ['id'], onupdate='CASCADE', ondelete='CASCADE')
    op.create_foreign_key(None, 'document_department_association', 'departments', ['department_id'], ['id'], onupdate='CASCADE', ondelete='CASCADE')
    # op.create_unique_constraint('unique_user_role', 'user_roles', ['user_id', 'role_id', 'company_id'])
    # ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Restore the non-cascading ``document_department_association`` foreign keys.

    Drops the cascading foreign keys created by ``upgrade()`` and recreates the
    original ones without ``ON UPDATE`` / ``ON DELETE`` actions.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    # op.drop_constraint('unique_user_role', 'user_roles', type_='unique')
    # The auto-generated code passed None as the constraint name, which cannot
    # work: a constraint can only be dropped by name.
    # NOTE(review): assumes the keys created with name=None in upgrade() received
    # the default `<table>_<column>_fkey` names (the PostgreSQL convention, and
    # the same names the original constraints carried) - confirm on the target DB.
    op.drop_constraint('document_department_association_document_id_fkey', 'document_department_association', type_='foreignkey')
    op.drop_constraint('document_department_association_department_id_fkey', 'document_department_association', type_='foreignkey')
    op.create_foreign_key('document_department_association_document_id_fkey', 'document_department_association', 'document', ['document_id'], ['id'])
    op.create_foreign_key('document_department_association_department_id_fkey', 'document_department_association', 'departments', ['department_id'], ['id'])
    # ### end Alembic commands ###
|
@ -11,7 +11,8 @@ from typing import Any
|
||||
|
||||
from llama_index.core.data_structs import IndexDict
|
||||
from llama_index.core.embeddings.utils import EmbedType
|
||||
from llama_index.core.indices import VectorStoreIndex, load_index_from_storage
|
||||
from llama_index.core.indices import VectorStoreIndex, load_index_from_storage, SimpleKeywordTableIndex
|
||||
from private_gpt.utils.vector_store import VectorStoreIndex1
|
||||
from llama_index.core.indices.base import BaseIndex
|
||||
from llama_index.core.ingestion import run_transformations
|
||||
from llama_index.core.schema import BaseNode, Document, TransformComponent
|
||||
@ -83,7 +84,7 @@ class BaseIngestComponentWithIndex(BaseIngestComponent, abc.ABC):
|
||||
except ValueError:
|
||||
# There are no index in the storage context, creating a new one
|
||||
logger.info("Creating a new vector store index")
|
||||
index = VectorStoreIndex.from_documents(
|
||||
index = VectorStoreIndex1.from_documents(
|
||||
[],
|
||||
storage_context=self.storage_context,
|
||||
store_nodes_override=True, # Force store nodes in index and document stores
|
||||
@ -92,6 +93,17 @@ class BaseIngestComponentWithIndex(BaseIngestComponent, abc.ABC):
|
||||
transformations=self.transformations,
|
||||
)
|
||||
index.storage_context.persist(persist_dir=local_data_path)
|
||||
|
||||
keyword_index = SimpleKeywordTableIndex.from_documents(
|
||||
[],
|
||||
storage_context=self.storage_context,
|
||||
store_nodes_override=True, # Force store nodes in index and document stores
|
||||
show_progress=self.show_progress,
|
||||
transformations=self.transformations,
|
||||
llm=
|
||||
)
|
||||
# Store the keyword index in the vector store
|
||||
index.keyword_index = keyword_index
|
||||
return index
|
||||
|
||||
def _save_index(self) -> None:
|
||||
|
55
private_gpt/components/vector_store/retriever.py
Normal file
55
private_gpt/components/vector_store/retriever.py
Normal file
@ -0,0 +1,55 @@
|
||||
|
||||
|
||||
# import QueryBundle
|
||||
from llama_index.core import QueryBundle
|
||||
|
||||
# import NodeWithScore
|
||||
from llama_index.core.schema import NodeWithScore
|
||||
|
||||
# Retrievers
|
||||
from llama_index.core.retrievers import (
|
||||
BaseRetriever,
|
||||
VectorIndexRetriever,
|
||||
KeywordTableSimpleRetriever,
|
||||
)
|
||||
|
||||
from typing import List
|
||||
|
||||
|
||||
class CustomRetriever(BaseRetriever):
    """Hybrid retriever that combines vector (semantic) and keyword retrieval.

    Both wrapped retrievers are run for every query and their results are
    merged by node id:

    * ``mode="AND"`` keeps only nodes returned by both retrievers.
    * ``mode="OR"`` keeps nodes returned by either retriever.
    """

    def __init__(
        self,
        vector_retriever: VectorIndexRetriever,
        keyword_retriever: KeywordTableSimpleRetriever,
        mode: str = "AND",
    ) -> None:
        """Store the two retrievers and the merge mode.

        Args:
            vector_retriever: Retriever used for the semantic search.
            keyword_retriever: Retriever used for the keyword search.
            mode: ``"AND"`` (intersection) or ``"OR"`` (union).

        Raises:
            ValueError: If ``mode`` is neither ``"AND"`` nor ``"OR"``.
        """
        if mode not in ("AND", "OR"):
            raise ValueError("Invalid mode.")
        self._vector_retriever = vector_retriever
        self._keyword_retriever = keyword_retriever
        self._mode = mode
        super().__init__()

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Retrieve nodes for ``query_bundle`` from both retrievers and merge them.

        The result order is deterministic: nodes appear in vector-retriever
        order, followed by keyword-only nodes in keyword-retriever order.
        When a node is returned by both retrievers, the keyword retriever's
        ``NodeWithScore`` is the one kept.
        """
        vector_nodes = self._vector_retriever.retrieve(query_bundle)
        keyword_nodes = self._keyword_retriever.retrieve(query_bundle)

        vector_ids = {n.node.node_id for n in vector_nodes}
        keyword_ids = {n.node.node_id for n in keyword_nodes}

        # A dict keeps the position of the first insertion, so overwriting a
        # vector hit with its keyword counterpart does not reorder it.
        combined_dict = {n.node.node_id: n for n in vector_nodes}
        combined_dict.update({n.node.node_id: n for n in keyword_nodes})

        if self._mode == "AND":
            retrieve_ids = vector_ids & keyword_ids
        else:
            retrieve_ids = vector_ids | keyword_ids

        # Iterate the ordered dict rather than the id set: set iteration order
        # is arbitrary and would shuffle the results between runs.
        return [
            node for node_id, node in combined_dict.items() if node_id in retrieve_ids
        ]
|
@ -3,12 +3,14 @@ import typing
|
||||
|
||||
from injector import inject, singleton
|
||||
from llama_index.core.indices.vector_store import VectorIndexRetriever, VectorStoreIndex
|
||||
from llama_index.core.indices import SimpleKeywordTableIndex
|
||||
from llama_index.core.vector_stores.types import (
|
||||
FilterCondition,
|
||||
MetadataFilter,
|
||||
MetadataFilters,
|
||||
VectorStore,
|
||||
)
|
||||
from private_gpt.utils.vector_store import VectorStoreIndex1
|
||||
|
||||
from private_gpt.open_ai.extensions.context_filter import ContextFilter
|
||||
from private_gpt.paths import local_data_path
|
||||
@ -130,21 +132,43 @@ class VectorStoreComponent:
|
||||
|
||||
def get_retriever(
    self,
    index: VectorStoreIndex1,
    context_filter: ContextFilter | None = None,
    similarity_top_k: int = 2,
    keyword_index: SimpleKeywordTableIndex | None = None,
) -> VectorIndexRetriever:
    """Build the retriever used to fetch context nodes.

    Args:
        index: Vector index to search.
        context_filter: Optional filter restricting retrieval to given documents.
        similarity_top_k: Number of nodes requested from the vector retriever.
        keyword_index: Optional keyword index for hybrid retrieval. When
            omitted, ``index.keyword_index`` is used if present.

    Returns:
        A ``CustomRetriever`` combining vector and keyword retrieval when a
        keyword index is available, otherwise the plain vector retriever.
    """
    # NOTE(review): the return annotation is kept for interface stability, but
    # the hybrid path returns a CustomRetriever, which is not a
    # VectorIndexRetriever - consider widening it to BaseRetriever.
    from .retriever import CustomRetriever
    from llama_index.core.retrievers import (
        VectorIndexRetriever,
        KeywordTableSimpleRetriever,
    )

    # This way we support qdrant (using doc_ids) and the rest (using filters)
    vector_retriever = VectorIndexRetriever(
        index=index,
        similarity_top_k=similarity_top_k,
        doc_ids=context_filter.docs_ids if context_filter else None,
        filters=(
            _doc_id_metadata_filter(context_filter)
            if self.settings.vectorstore.database != "qdrant"
            else None
        ),
    )

    # Prefer the explicitly passed keyword index; otherwise use the one attached
    # to the index. getattr covers indexes without a `keyword_index` attribute.
    if keyword_index is None:
        keyword_index = getattr(index, "keyword_index", None)
    if keyword_index is None:
        # No keyword index available: fall back to pure vector retrieval.
        return vector_retriever

    keyword_retriever = KeywordTableSimpleRetriever(index=keyword_index)
    return CustomRetriever(vector_retriever, keyword_retriever)
|
||||
|
||||
def close(self) -> None:
|
||||
if hasattr(self.vector_store.client, "close"):
|
||||
|
@ -5,7 +5,7 @@ from llama_index.core.chat_engine import ContextChatEngine, SimpleChatEngine
|
||||
from llama_index.core.chat_engine.types import (
|
||||
BaseChatEngine,
|
||||
)
|
||||
from llama_index.core.indices import VectorStoreIndex
|
||||
from llama_index.core.indices import VectorStoreIndex, SimpleKeywordTableIndex
|
||||
from llama_index.core.indices.postprocessor import MetadataReplacementPostProcessor
|
||||
from llama_index.core.llms import ChatMessage, MessageRole
|
||||
from llama_index.core.postprocessor import (
|
||||
@ -99,6 +99,12 @@ class ChatService:
|
||||
embed_model=embedding_component.embedding_model,
|
||||
show_progress=True,
|
||||
)
|
||||
self.keyword_index = SimpleKeywordTableIndex.from_documents(
|
||||
vector_store_component.vector_store,
|
||||
storage_context=self.storage_context,
|
||||
embed_model=embedding_component.embedding_model,
|
||||
show_progress=True,
|
||||
)
|
||||
|
||||
def _chat_engine(
|
||||
self,
|
||||
@ -110,6 +116,7 @@ class ChatService:
|
||||
if use_context:
|
||||
vector_index_retriever = self.vector_store_component.get_retriever(
|
||||
index=self.index,
|
||||
keyword_index=self.keyword_index,
|
||||
context_filter=context_filter,
|
||||
similarity_top_k=self.settings.rag.similarity_top_k,
|
||||
)
|
||||
|
13
private_gpt/utils/vector_store.py
Normal file
13
private_gpt/utils/vector_store.py
Normal file
@ -0,0 +1,13 @@
|
||||
from llama_index.core.indices.vector_store import VectorStoreIndex
|
||||
from llama_index.core.indices import SimpleKeywordTableIndex
|
||||
|
||||
class VectorStoreIndex1(VectorStoreIndex):
    """``VectorStoreIndex`` that can additionally carry a keyword index.

    The keyword index is not built here; it starts as ``None`` and is attached
    later through ``set_keyword_index`` or by assigning ``keyword_index``.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to ``VectorStoreIndex`` and start without a keyword index."""
        super().__init__(*args, **kwargs)
        self.keyword_index: SimpleKeywordTableIndex | None = None

    def set_keyword_index(self, keyword_index: SimpleKeywordTableIndex) -> None:
        """Attach ``keyword_index`` to this index."""
        self.keyword_index = keyword_index

    def get_keyword_index(self) -> SimpleKeywordTableIndex | None:
        """Return the attached keyword index, or ``None`` if none has been set."""
        return self.keyword_index
|
Loading…
Reference in New Issue
Block a user