
feat: support elasticsearch vector database (#3558)

Co-authored-by: miendinh <miendinh@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: crazywoola <427733928@qq.com>
miendinh 8 months ago
commit f104b930cf

+ 2 - 1
.github/workflows/api-tests.yml

@@ -76,7 +76,7 @@ jobs:
      - name: Run Workflow
        run: poetry run -C api bash dev/pytest/pytest_workflow.sh

-      - name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale)
+      - name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale, ElasticSearch)
        uses: hoverkraft-tech/compose-action@v2.0.0
        with:
          compose-file: |
@@ -90,5 +90,6 @@ jobs:
            pgvecto-rs
            pgvector
            chroma
+            elasticsearch
      - name: Test Vector Stores
        run: poetry run -C api bash dev/pytest/pytest_vdb.sh

+ 2 - 1
.github/workflows/expose_service_ports.sh

@@ -6,5 +6,6 @@ yq eval '.services.chroma.ports += ["8000:8000"]' -i docker/docker-compose.yaml
 yq eval '.services["milvus-standalone"].ports += ["19530:19530"]' -i docker/docker-compose.yaml
 yq eval '.services["milvus-standalone"].ports += ["19530:19530"]' -i docker/docker-compose.yaml
 yq eval '.services.pgvector.ports += ["5433:5432"]' -i docker/docker-compose.yaml
 yq eval '.services.pgvector.ports += ["5433:5432"]' -i docker/docker-compose.yaml
 yq eval '.services["pgvecto-rs"].ports += ["5431:5432"]' -i docker/docker-compose.yaml
 yq eval '.services["pgvecto-rs"].ports += ["5431:5432"]' -i docker/docker-compose.yaml
+yq eval '.services["elasticsearch"].ports += ["9200:9200"]' -i docker/docker-compose.yaml
 
 
-echo "Ports exposed for sandbox, weaviate, qdrant, chroma, milvus, pgvector, pgvecto-rs."
+echo "Ports exposed for sandbox, weaviate, qdrant, chroma, milvus, pgvector, pgvecto-rs, elasticsearch"

+ 6 - 0
api/.env.example

@@ -130,6 +130,12 @@ TENCENT_VECTOR_DB_DATABASE=dify
 TENCENT_VECTOR_DB_SHARD=1
 TENCENT_VECTOR_DB_REPLICAS=2

+# ElasticSearch configuration
+ELASTICSEARCH_HOST=127.0.0.1
+ELASTICSEARCH_PORT=9200
+ELASTICSEARCH_USERNAME=elastic
+ELASTICSEARCH_PASSWORD=elastic
+
 # PGVECTO_RS configuration
 PGVECTO_RS_HOST=localhost
 PGVECTO_RS_PORT=5431
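
A quick sketch of how these four settings feed the ElasticSearchConfig model added later in this commit (import path from this diff; reading them via os.environ with these defaults is illustrative, not part of the commit). The model validator rejects any missing value with a ValueError:

    import os

    from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import ElasticSearchConfig

    config = ElasticSearchConfig(
        host=os.environ.get('ELASTICSEARCH_HOST', '127.0.0.1'),
        port=os.environ.get('ELASTICSEARCH_PORT', '9200'),
        username=os.environ.get('ELASTICSEARCH_USERNAME', 'elastic'),
        password=os.environ.get('ELASTICSEARCH_PASSWORD', 'elastic'),
    )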

+ 8 - 0
api/commands.py

@@ -344,6 +344,14 @@ def migrate_knowledge_vector_database():
                         "vector_store": {"class_prefix": collection_name}
                         "vector_store": {"class_prefix": collection_name}
                     }
                     }
                     dataset.index_struct = json.dumps(index_struct_dict)
                     dataset.index_struct = json.dumps(index_struct_dict)
+                elif vector_type == VectorType.ELASTICSEARCH:
+                    dataset_id = dataset.id
+                    index_name = Dataset.gen_collection_name_by_id(dataset_id)
+                    index_struct_dict = {
+                        "type": 'elasticsearch',
+                        "vector_store": {"class_prefix": index_name}
+                    }
+                    dataset.index_struct = json.dumps(index_struct_dict)
                else:
                    raise ValueError(f"Vector store {vector_type} is not supported.")

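For reference, the migration branch above writes an index_struct of this shape; the collection name here is hypothetical (the real one comes from Dataset.gen_collection_name_by_id):

    import json

    index_name = 'Vector_index_example_Node'  # hypothetical value
    index_struct = json.dumps({
        "type": "elasticsearch",
        "vector_store": {"class_prefix": index_name},
    })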

+ 2 - 2
api/controllers/console/datasets/datasets.py

@@ -555,7 +555,7 @@ class DatasetRetrievalSettingApi(Resource):
                        RetrievalMethod.SEMANTIC_SEARCH.value
                    ]
                }
-            case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH | VectorType.ANALYTICDB | VectorType.MYSCALE | VectorType.ORACLE:
+            case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH | VectorType.ANALYTICDB | VectorType.MYSCALE | VectorType.ORACLE | VectorType.ELASTICSEARCH:
                return {
                    'retrieval_method': [
                        RetrievalMethod.SEMANTIC_SEARCH.value,
@@ -579,7 +579,7 @@ class DatasetRetrievalSettingMockApi(Resource):
                        RetrievalMethod.SEMANTIC_SEARCH.value
                    ]
                }
-            case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH| VectorType.ANALYTICDB | VectorType.MYSCALE | VectorType.ORACLE:
+            case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH| VectorType.ANALYTICDB | VectorType.MYSCALE | VectorType.ORACLE | VectorType.ELASTICSEARCH:
                return {
                    'retrieval_method': [
                        RetrievalMethod.SEMANTIC_SEARCH.value,

+ 0 - 0
api/core/rag/datasource/vdb/elasticsearch/__init__.py


+ 191 - 0
api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py

@@ -0,0 +1,191 @@
+import json
+from typing import Any
+
+import requests
+from elasticsearch import Elasticsearch
+from flask import current_app
+from pydantic import BaseModel, model_validator
+
+from core.rag.datasource.entity.embedding import Embeddings
+from core.rag.datasource.vdb.vector_base import BaseVector
+from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
+from core.rag.datasource.vdb.vector_type import VectorType
+from core.rag.models.document import Document
+from models.dataset import Dataset
+
+
+class ElasticSearchConfig(BaseModel):
+    host: str
+    port: str
+    username: str
+    password: str
+
+    @model_validator(mode='before')
+    def validate_config(cls, values: dict) -> dict:
+        if not values['host']:
+            raise ValueError("config HOST is required")
+        if not values['port']:
+            raise ValueError("config PORT is required")
+        if not values['username']:
+            raise ValueError("config USERNAME is required")
+        if not values['password']:
+            raise ValueError("config PASSWORD is required")
+        return values
+
+
+class ElasticSearchVector(BaseVector):
+    def __init__(self, index_name: str, config: ElasticSearchConfig, attributes: list):
+        super().__init__(index_name.lower())
+        self._client = self._init_client(config)
+        self._attributes = attributes
+
+    def _init_client(self, config: ElasticSearchConfig) -> Elasticsearch:
+        try:
+            client = Elasticsearch(
+                hosts=f'{config.host}:{config.port}',
+                basic_auth=(config.username, config.password),
+                request_timeout=100000,
+                retry_on_timeout=True,
+                max_retries=10000,
+            )
+        except requests.exceptions.ConnectionError:
+            raise ConnectionError("Vector database connection error")
+
+        return client
+
+    def get_type(self) -> str:
+        return 'elasticsearch'
+
+    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
+        uuids = self._get_uuids(documents)
+        texts = [d.page_content for d in documents]
+        metadatas = [d.metadata for d in documents]
+
+        if not self._client.indices.exists(index=self._collection_name):
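+            # Create the index lazily on the first insert; "dims" is taken from the
+            # first embedding and "l2_norm" is the declared similarity for the dense_vector field.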
+            dim = len(embeddings[0])
+            mapping = {
+                "properties": {
+                    "text": {
+                        "type": "text"
+                    },
+                    "vector": {
+                        "type": "dense_vector",
+                        "index": True,
+                        "dims": dim,
+                        "similarity": "l2_norm"
+                    },
+                }
+            }
+            self._client.indices.create(index=self._collection_name, mappings=mapping)
+
+        added_ids = []
+        for i, text in enumerate(texts):
+            self._client.index(index=self._collection_name,
+                               id=uuids[i],
+                               document={
+                                   "text": text,
+                                   "vector": embeddings[i] if embeddings[i] else None,
+                                   "metadata": metadatas[i] if metadatas[i] else {},
+                               })
+            added_ids.append(uuids[i])
+
+        self._client.indices.refresh(index=self._collection_name)
+        return uuids
+
+    def text_exists(self, id: str) -> bool:
+        return bool(self._client.exists(index=self._collection_name, id=id))
+
+    def delete_by_ids(self, ids: list[str]) -> None:
+        for id in ids:
+            self._client.delete(index=self._collection_name, id=id)
+
+    def delete_by_metadata_field(self, key: str, value: str) -> None:
+        query_str = {
+            'query': {
+                'match': {
+                    f'metadata.{key}': f'{value}'
+                }
+            }
+        }
+        results = self._client.search(index=self._collection_name, body=query_str)
+        ids = [hit['_id'] for hit in results['hits']['hits']]
+        if ids:
+            self.delete_by_ids(ids)
+
+    def delete(self) -> None:
+        self._client.indices.delete(index=self._collection_name)
+
+    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
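+        # script_score over match_all is a brute-force scan of the index; cosineSimilarity
+        # returns values in [-1, 1], and the +1.0 shift keeps scores non-negative,
+        # which Elasticsearch requires of script scores.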
+        query_str = {
+            "query": {
+                "script_score": {
+                    "query": {
+                        "match_all": {}
+                    },
+                    "script": {
+                        "source": "cosineSimilarity(params.query_vector, 'vector') + 1.0",
+                        "params": {
+                            "query_vector": query_vector
+                        }
+                    }
+                }
+            }
+        }
+
+        results = self._client.search(index=self._collection_name, body=query_str)
+
+        docs_and_scores = []
+        for hit in results['hits']['hits']:
+            docs_and_scores.append(
+                (Document(page_content=hit['_source']['text'], metadata=hit['_source']['metadata']), hit['_score']))
+
+        # Apply the optional score threshold before ranking; previously, documents at or
+        # below the threshold were kept without a 'score' key and crashed the sort below.
+        score_threshold = float(kwargs.get('score_threshold') or 0.0)
+        docs = []
+        for doc, score in docs_and_scores:
+            if score > score_threshold:
+                doc.metadata['score'] = score
+                docs.append(doc)
+
+        # Sort the documents by score in descending order
+        docs = sorted(docs, key=lambda x: x.metadata['score'], reverse=True)
+
+        return docs
+
+    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
+        query_str = {
+            "match": {
+                "text": query
+            }
+        }
+        results = self._client.search(index=self._collection_name, query=query_str)
+        docs = []
+        for hit in results['hits']['hits']:
+            docs.append(Document(page_content=hit['_source']['text'], metadata=hit['_source']['metadata']))
+
+        return docs
+
+    def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
+        return self.add_texts(texts, embeddings, **kwargs)
+
+
+class ElasticSearchVectorFactory(AbstractVectorFactory):
+    def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> ElasticSearchVector:
+        if dataset.index_struct_dict:
+            class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
+            collection_name = class_prefix
+        else:
+            dataset_id = dataset.id
+            collection_name = Dataset.gen_collection_name_by_id(dataset_id)
+            dataset.index_struct = json.dumps(
+                self.gen_index_struct_dict(VectorType.ELASTICSEARCH, collection_name))
+
+        config = current_app.config
+        return ElasticSearchVector(
+            index_name=collection_name,
+            config=ElasticSearchConfig(
+                host=config.get('ELASTICSEARCH_HOST'),
+                port=config.get('ELASTICSEARCH_PORT'),
+                username=config.get('ELASTICSEARCH_USERNAME'),
+                password=config.get('ELASTICSEARCH_PASSWORD'),
+            ),
+            attributes=[]
+        )
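
A minimal end-to-end sketch of the class above, assuming a local Elasticsearch reachable with the elastic/elastic defaults from api/.env.example; the index name, doc_id, and three-dimensional embedding are illustrative only:

    from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import (
        ElasticSearchConfig,
        ElasticSearchVector,
    )
    from core.rag.models.document import Document

    vector = ElasticSearchVector(
        index_name='vector_index_demo',
        config=ElasticSearchConfig(host='http://127.0.0.1', port='9200',
                                   username='elastic', password='elastic'),
        attributes=[],
    )

    docs = [Document(page_content='hello dify', metadata={'doc_id': 'doc-1'})]
    vector.create(docs, embeddings=[[0.1, 0.2, 0.3]])  # first insert creates the index

    hits = vector.search_by_vector([0.1, 0.2, 0.3])  # scores are cosine similarity + 1.0
    print([(d.page_content, d.metadata.get('score')) for d in hits])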

+ 3 - 0
api/core/rag/datasource/vdb/vector_factory.py

@@ -71,6 +71,9 @@ class Vector:
            case VectorType.RELYT:
                from core.rag.datasource.vdb.relyt.relyt_vector import RelytVectorFactory
                return RelytVectorFactory
+            case VectorType.ELASTICSEARCH:
+                from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
+                return ElasticSearchVectorFactory
            case VectorType.TIDB_VECTOR:
                from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVectorFactory
                return TiDBVectorFactory

+ 1 - 0
api/core/rag/datasource/vdb/vector_type.py

@@ -15,3 +15,4 @@ class VectorType(str, Enum):
    OPENSEARCH = 'opensearch'
    TENCENT = 'tencent'
    ORACLE = 'oracle'
+    ELASTICSEARCH = 'elasticsearch'
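
Since VectorType subclasses str, the value configured in VECTOR_STORE maps directly onto the new member; a small illustration:

    from core.rag.datasource.vdb.vector_type import VectorType

    assert VectorType('elasticsearch') is VectorType.ELASTICSEARCH
    assert VectorType.ELASTICSEARCH == 'elasticsearch'  # str-valued enum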

+ 39 - 1
api/poetry.lock

@@ -2101,6 +2101,44 @@ dev = ["mypy (>=1.11.0)", "pytest (>=8.3.1)", "pytest-asyncio (>=0.23.8)", "ruff
 lxml = ["lxml (>=5.2.2)"]

 [[package]]
+name = "elastic-transport"
+version = "8.15.0"
+description = "Transport classes and utilities shared among Python Elastic client libraries"
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "elastic_transport-8.15.0-py3-none-any.whl", hash = "sha256:d7080d1dada2b4eee69e7574f9c17a76b42f2895eff428e562f94b0360e158c0"},
+    {file = "elastic_transport-8.15.0.tar.gz", hash = "sha256:85d62558f9baafb0868c801233a59b235e61d7b4804c28c2fadaa866b6766233"},
+]
+
+[package.dependencies]
+certifi = "*"
+urllib3 = ">=1.26.2,<3"
+
+[package.extras]
+develop = ["aiohttp", "furo", "httpx", "opentelemetry-api", "opentelemetry-sdk", "orjson", "pytest", "pytest-asyncio", "pytest-cov", "pytest-httpserver", "pytest-mock", "requests", "respx", "sphinx (>2)", "sphinx-autodoc-typehints", "trustme"]
+
+[[package]]
+name = "elasticsearch"
+version = "8.14.0"
+description = "Python client for Elasticsearch"
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "elasticsearch-8.14.0-py3-none-any.whl", hash = "sha256:cef8ef70a81af027f3da74a4f7d9296b390c636903088439087b8262a468c130"},
+    {file = "elasticsearch-8.14.0.tar.gz", hash = "sha256:aa2490029dd96f4015b333c1827aa21fd6c0a4d223b00dfb0fe933b8d09a511b"},
+]
+
+[package.dependencies]
+elastic-transport = ">=8.13,<9"
+
+[package.extras]
+async = ["aiohttp (>=3,<4)"]
+orjson = ["orjson (>=3)"]
+requests = ["requests (>=2.4.0,!=2.32.2,<3.0.0)"]
+vectorstore-mmr = ["numpy (>=1)", "simsimd (>=3)"]
+
+[[package]]
 name = "emoji"
 version = "2.12.1"
 description = "Emoji for Python"
@@ -9546,4 +9584,4 @@ cffi = ["cffi (>=1.11)"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.10,<3.13"
-content-hash = "2b822039247a445f72e04e967aef84f841781e2789b70071acad022f36ba26a5"
+content-hash = "05dfa6b9bce9ed8ac21caf58eff1596f146080ab2ab6987924b189be673c22cf"

+ 1 - 0
api/pyproject.toml

@@ -181,6 +181,7 @@ zhipuai = "1.0.7"
 rank-bm25 = "~0.2.2"
 openpyxl = "^3.1.5"
 kaleido = "0.2.1"
+elasticsearch = "8.14.0"

 ############################################################
 # Tool dependencies required by tool implementations

+ 0 - 0
api/tests/integration_tests/vdb/elasticsearch/__init__.py


+ 25 - 0
api/tests/integration_tests/vdb/elasticsearch/test_elasticsearch.py

@@ -0,0 +1,25 @@
+from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import ElasticSearchConfig, ElasticSearchVector
+from tests.integration_tests.vdb.test_vector_store import (
+    AbstractVectorTest,
+    setup_mock_redis,
+)
+
+
+class ElasticSearchVectorTest(AbstractVectorTest):
+    def __init__(self):
+        super().__init__()
+        self.attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash']
+        self.vector = ElasticSearchVector(
+            index_name=self.collection_name.lower(),
+            config=ElasticSearchConfig(
+                host='http://localhost',
+                port='9200',
+                username='elastic',
+                password='elastic'
+            ),
+            attributes=self.attributes
+        )
+
+
+def test_elasticsearch_vector(setup_mock_redis):
+    ElasticSearchVectorTest().run_all_tests()

+ 1 - 0
dev/pytest/pytest_vdb.sh

@@ -7,4 +7,5 @@ pytest api/tests/integration_tests/vdb/chroma \
   api/tests/integration_tests/vdb/pgvector \
   api/tests/integration_tests/vdb/qdrant \
   api/tests/integration_tests/vdb/weaviate \
+  api/tests/integration_tests/vdb/elasticsearch \
   api/tests/integration_tests/vdb/test_vector_store.py

+ 10 - 0
docker-legacy/docker-compose.yaml

@@ -169,6 +169,11 @@ services:
       CHROMA_DATABASE: default_database
       CHROMA_AUTH_PROVIDER: chromadb.auth.token_authn.TokenAuthClientProvider
       CHROMA_AUTH_CREDENTIALS: xxxxxx
+      # ElasticSearch Config
+      ELASTICSEARCH_HOST: 127.0.0.1
+      ELASTICSEARCH_PORT: 9200
+      ELASTICSEARCH_USERNAME: elastic
+      ELASTICSEARCH_PASSWORD: elastic
       # Mail configuration, support: resend, smtp
       MAIL_TYPE: ''
       # default send from email address, if not specified
@@ -371,6 +376,11 @@ services:
       CHROMA_DATABASE: default_database
       CHROMA_AUTH_PROVIDER: chromadb.auth.token_authn.TokenAuthClientProvider
       CHROMA_AUTH_CREDENTIALS: xxxxxx
+      # ElasticSearch Config
+      ELASTICSEARCH_HOST: 127.0.0.1
+      ELASTICSEARCH_PORT: 9200
+      ELASTICSEARCH_USERNAME: elastic
+      ELASTICSEARCH_PASSWORD: elastic
       # Notion import configuration, support public and internal
       NOTION_INTEGRATION_TYPE: public
       NOTION_CLIENT_SECRET: you-client-secret

+ 25 - 0
docker/docker-compose.yaml

@@ -125,6 +125,10 @@ x-shared-env: &shared-api-worker-env
   CHROMA_DATABASE: ${CHROMA_DATABASE:-default_database}
   CHROMA_AUTH_PROVIDER: ${CHROMA_AUTH_PROVIDER:-chromadb.auth.token_authn.TokenAuthClientProvider}
   CHROMA_AUTH_CREDENTIALS: ${CHROMA_AUTH_CREDENTIALS:-}
+  ELASTICSEARCH_HOST: ${ELASTICSEARCH_HOST:-127.0.0.1}
+  ELASTICSEARCH_PORT: ${ELASTICSEARCH_PORT:-9200}
+  ELASTICSEARCH_USERNAME: ${ELASTICSEARCH_USERNAME:-elastic}
+  ELASTICSEARCH_PASSWORD: ${ELASTICSEARCH_PASSWORD:-elastic}
   # AnalyticDB configuration
   ANALYTICDB_KEY_ID: ${ANALYTICDB_KEY_ID:-}
   ANALYTICDB_KEY_SECRET: ${ANALYTICDB_KEY_SECRET:-}
@@ -595,6 +599,27 @@ services:
     ports:
       - "${MYSCALE_PORT:-8123}:${MYSCALE_PORT:-8123}"

+  elasticsearch:
+    image: docker.elastic.co/elasticsearch/elasticsearch:8.14.3
+    container_name: elasticsearch
+    profiles:
+      - elasticsearch
+    restart: always
+    environment:
+      - "ELASTIC_PASSWORD=${ELASTICSEARCH_PASSWORD:-elastic}"
+      - "cluster.name=dify-es-cluster"
+      - "node.name=dify-es0"
+      - "discovery.type=single-node"
+      - "xpack.security.http.ssl.enabled=false"
+      - "xpack.license.self_generated.type=trial"
+    ports:
+      - "${ELASTICSEARCH_PORT:-9200}:${ELASTICSEARCH_PORT:-9200}"
+    healthcheck:
+      test: ["CMD", "curl", "-s", "http://localhost:9200/_cluster/health?pretty"]
+      interval: 30s
+      timeout: 10s
+      retries: 50
+
   # unstructured .
   # (if used, you need to set ETL_TYPE to Unstructured in the api & worker service.)
   unstructured:
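
Once the profile is up (docker compose --profile elasticsearch up -d), a quick connectivity check with the Python client pinned by this commit, using the elastic/elastic defaults above:

    from elasticsearch import Elasticsearch

    client = Elasticsearch('http://127.0.0.1:9200', basic_auth=('elastic', 'elastic'))
    print(client.info()['version']['number'])  # expect "8.14.3" for this image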

File diff suppressed because it is too large
+ 0 - 9372
web/yarn.lock