
Add Oracle23ai as a vector datasource (#5342)

Co-authored-by: walter from vm <walter.jin@oracle.com>
tmuife 10 months ago
Parent
commit 6a09409ec9

+ 4 - 0
.github/workflows/api-tests.yml

@@ -68,6 +68,7 @@ jobs:
             docker/docker-compose.pgvecto-rs.yaml
             docker/docker-compose.pgvector.yaml
             docker/docker-compose.chroma.yaml
+            docker/docker-compose.oracle.yaml
           services: |
             weaviate
             qdrant
@@ -77,6 +78,7 @@ jobs:
             pgvecto-rs
             pgvector
             chroma
+            oracle
 
       - name: Test Vector Stores
         run: dev/pytest/pytest_vdb.sh
@@ -145,6 +147,7 @@ jobs:
             docker/docker-compose.pgvecto-rs.yaml
             docker/docker-compose.pgvector.yaml
             docker/docker-compose.chroma.yaml
+            docker/docker-compose.oracle.yaml
           services: |
             weaviate
             qdrant
@@ -154,6 +157,7 @@ jobs:
             pgvecto-rs
             pgvector
             chroma
+            oracle
 
       - name: Test Vector Stores
         run: poetry run -C api bash dev/pytest/pytest_vdb.sh

+ 2 - 0
api/configs/middleware/__init__.py

@@ -6,6 +6,7 @@ from configs.middleware.redis_configs import RedisConfigs
 from configs.middleware.vdb.chroma_configs import ChromaConfigs
 from configs.middleware.vdb.milvus_configs import MilvusConfigs
 from configs.middleware.vdb.opensearch_configs import OpenSearchConfigs
+from configs.middleware.vdb.oracle_configs import OracleConfigs
 from configs.middleware.vdb.pgvector_configs import PGVectorConfigs
 from configs.middleware.vdb.pgvectors_configs import PGVectoRSConfigs
 from configs.middleware.vdb.qdrant_configs import QdrantConfigs
@@ -61,5 +62,6 @@ class MiddlewareConfigs(
     TencentVectorDBConfigs,
     TiDBVectorConfigs,
     WeaviateConfigs,
+    OracleConfigs,
 ):
     pass

+ 34 - 0
api/configs/middleware/vdb/oracle_configs.py

@@ -0,0 +1,34 @@
+from typing import Optional
+
+from pydantic import BaseModel, Field, PositiveInt
+
+
+class OracleConfigs(BaseModel):
+    """
+    ORACLE configs
+    """
+
+    ORACLE_HOST: Optional[str] = Field(
+        description='ORACLE host',
+        default=None,
+    )
+
+    ORACLE_PORT: Optional[PositiveInt] = Field(
+        description='ORACLE port',
+        default=None,
+    )
+
+    ORACLE_USER: Optional[str] = Field(
+        description='ORACLE user',
+        default=None,
+    )
+
+    ORACLE_PASSWORD: Optional[str] = Field(
+        description='ORACLE password',
+        default=None,
+    )
+
+    ORACLE_DATABASE: Optional[str] = Field(
+        description='ORACLE database',
+        default=None,
+    )
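
These settings map to the ORACLE_* environment variables introduced elsewhere in this commit (docker-compose.yaml); every field defaults to None, so nothing is enforced until the Oracle vector store is actually selected, and OracleVectorConfig in oraclevector.py performs the hard validation. A minimal sketch of populating the model directly (values mirror the compose defaults and are purely illustrative):

from configs.middleware.vdb.oracle_configs import OracleConfigs

# Illustrative values; they match the docker-compose defaults added in this commit.
conf = OracleConfigs(
    ORACLE_HOST="oracle",
    ORACLE_PORT=1521,
    ORACLE_USER="dify",
    ORACLE_PASSWORD="dify",
    ORACLE_DATABASE="FREEPDB1",
)
print(conf.ORACLE_HOST, conf.ORACLE_PORT)  # unset fields would simply be None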

+ 2 - 2
api/controllers/console/datasets/datasets.py

@@ -498,7 +498,7 @@ class DatasetRetrievalSettingApi(Resource):
     def get(self):
         vector_type = current_app.config['VECTOR_STORE']
         match vector_type:
-            case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT:
+            case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT | VectorType.ORACLE:
                 return {
                     'retrieval_method': [
                         RetrievalMethod.SEMANTIC_SEARCH
@@ -522,7 +522,7 @@ class DatasetRetrievalSettingMockApi(Resource):
     @account_initialization_required
     def get(self, vector_type):
         match vector_type:
-            case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCEN:
+            case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT | VectorType.ORACLE:
                 return {
                     'retrieval_method': [
                         RetrievalMethod.SEMANTIC_SEARCH

+ 0 - 0
api/core/rag/datasource/vdb/oracle/__init__.py


+ 239 - 0
api/core/rag/datasource/vdb/oracle/oraclevector.py

@@ -0,0 +1,239 @@
+import array
+import json
+import uuid
+from contextlib import contextmanager
+from typing import Any
+
+import numpy
+import oracledb
+from flask import current_app
+from pydantic import BaseModel, model_validator
+
+from core.rag.datasource.entity.embedding import Embeddings
+from core.rag.datasource.vdb.vector_base import BaseVector
+from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
+from core.rag.datasource.vdb.vector_type import VectorType
+from core.rag.models.document import Document
+from extensions.ext_redis import redis_client
+from models.dataset import Dataset
+
+oracledb.defaults.fetch_lobs = False
+
+
+class OracleVectorConfig(BaseModel):
+    host: str
+    port: int
+    user: str
+    password: str
+    database: str
+
+    @model_validator(mode='before')
+    def validate_config(cls, values: dict) -> dict:
+        if not values["host"]:
+            raise ValueError("config ORACLE_HOST is required")
+        if not values["port"]:
+            raise ValueError("config ORACLE_PORT is required")
+        if not values["user"]:
+            raise ValueError("config ORACLE_USER is required")
+        if not values["password"]:
+            raise ValueError("config ORACLE_PASSWORD is required")
+        if not values["database"]:
+            raise ValueError("config ORACLE_DB is required")
+        return values
+
+
+SQL_CREATE_TABLE = """
+CREATE TABLE IF NOT EXISTS {table_name} (
+    id varchar2(100) 
+    ,text CLOB NOT NULL
+    ,meta JSON
+    ,embedding vector NOT NULL
+) 
+"""
+
+
+class OracleVector(BaseVector):
+    def __init__(self, collection_name: str, config: OracleVectorConfig):
+        super().__init__(collection_name)
+        self.pool = self._create_connection_pool(config)
+        self.table_name = f"embedding_{collection_name}"
+
+    def get_type(self) -> str:
+        return VectorType.ORACLE
+
+    def numpy_converter_in(self, value):
+        if value.dtype == numpy.float64:
+            dtype = "d"
+        elif value.dtype == numpy.float32:
+            dtype = "f"
+        else:
+            dtype = "b"
+        return array.array(dtype, value)
+
+    def input_type_handler(self, cursor, value, arraysize):
+        if isinstance(value, numpy.ndarray):
+            return cursor.var(
+                oracledb.DB_TYPE_VECTOR,
+                arraysize=arraysize,
+                inconverter=self.numpy_converter_in,
+            )
+
+    def numpy_converter_out(self, value):
+        if value.typecode == "b":
+            dtype = numpy.int8
+        elif value.typecode == "f":
+            dtype = numpy.float32
+        else:
+            dtype = numpy.float64
+        return numpy.array(value, copy=False, dtype=dtype)
+
+    def output_type_handler(self, cursor, metadata):
+        if metadata.type_code is oracledb.DB_TYPE_VECTOR:
+            return cursor.var(
+                metadata.type_code,
+                arraysize=cursor.arraysize,
+                outconverter=self.numpy_converter_out,
+            )
+
+    def _create_connection_pool(self, config: OracleVectorConfig):
+        return oracledb.create_pool(
+            user=config.user,
+            password=config.password,
+            dsn="{}:{}/{}".format(config.host, config.port, config.database),
+            min=1,
+            max=50,
+            increment=1,
+        )
+
+    @contextmanager
+    def _get_cursor(self):
+        conn = self.pool.acquire()
+        conn.inputtypehandler = self.input_type_handler
+        conn.outputtypehandler = self.output_type_handler
+        cur = conn.cursor()
+        try:
+            yield cur
+        finally:
+            cur.close()
+            conn.commit()
+            conn.close()
+
+    def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
+        dimension = len(embeddings[0])
+        self._create_collection(dimension)
+        return self.add_texts(texts, embeddings)
+
+    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
+        values = []
+        pks = []
+        for i, doc in enumerate(documents):
+            doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
+            pks.append(doc_id)
+            values.append(
+                (
+                    doc_id,
+                    doc.page_content,
+                    json.dumps(doc.metadata),
+                    #array.array("f", embeddings[i]),
+                    numpy.array(embeddings[i]),
+                )
+            )
+        #print(f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)")
+        with self._get_cursor() as cur:
+            cur.executemany(f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)", values)
+        return pks
+
+    def text_exists(self, id: str) -> bool:
+        with self._get_cursor() as cur:
+            cur.execute(f"SELECT id FROM {self.table_name} WHERE id = '%s'" % (id,))
+            return cur.fetchone() is not None
+
+    def get_by_ids(self, ids: list[str]) -> list[Document]:
+        if not ids:
+            return []
+        # oracledb has no %s paramstyle; build numbered binds (:1, :2, ...) for the IN list.
+        placeholders = ", ".join(f":{i + 1}" for i in range(len(ids)))
+        with self._get_cursor() as cur:
+            cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN ({placeholders})", ids)
+            docs = []
+            for record in cur:
+                docs.append(Document(page_content=record[1], metadata=record[0]))
+        return docs
+    #def get_ids_by_metadata_field(self, key: str, value: str):
+    #    with self._get_cursor() as cur:
+    #        cur.execute(f"SELECT id FROM {self.table_name} d WHERE d.meta.{key}='{value}'" )
+    #        idss = []
+    #        for record in cur:
+    #            idss.append(record[0])
+    #    return idss
+
+    #def delete_by_document_id(self, document_id: str):
+    #    ids = self.get_ids_by_metadata_field('doc_id', document_id)
+    #    if len(ids)>0:
+    #        with self._get_cursor() as cur:
+    #            cur.execute(f"delete FROM {self.table_name} d WHERE d.meta.doc_id in '%s'" % ("','".join(ids),))
+
+
+    def delete_by_ids(self, ids: list[str]) -> None:
+        if not ids:
+            return
+        placeholders = ", ".join(f":{i + 1}" for i in range(len(ids)))
+        with self._get_cursor() as cur:
+            cur.execute(f"DELETE FROM {self.table_name} WHERE id IN ({placeholders})", ids)
+
+    def delete_by_metadata_field(self, key: str, value: str) -> None:
+        with self._get_cursor() as cur:
+            cur.execute(f"DELETE FROM {self.table_name} WHERE meta->>%s = %s", (key, value))
+
+    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
+        """
+        Search the nearest neighbors to a vector.
+
+        :param query_vector: The input vector to search for similar items.
+        :param top_k: The number of nearest neighbors to return, default is 5.
+        :return: List of Documents that are nearest to the query vector.
+        """
+        top_k = kwargs.get("top_k", 5)
+        with self._get_cursor() as cur:
+            cur.execute(
+                "SELECT meta, text, vector_distance(embedding, :1) AS distance "
+                f"FROM {self.table_name} ORDER BY distance FETCH FIRST {top_k} ROWS ONLY",
+                [numpy.array(query_vector)],
+            )
+            docs = []
+            score_threshold = kwargs.get("score_threshold") or 0.0
+            for record in cur:
+                metadata, text, distance = record
+                score = 1 - distance
+                metadata["score"] = score
+                if score > score_threshold:
+                    docs.append(Document(page_content=text, metadata=metadata))
+        return docs
+
+    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
+        # do not support bm25 search
+        return []
+
+    def delete(self) -> None:
+        with self._get_cursor() as cur:
+            cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")
+
+    def _create_collection(self, dimension: int):
+        cache_key = f"vector_indexing_{self._collection_name}"
+        lock_name = f"{cache_key}_lock"
+        with redis_client.lock(lock_name, timeout=20):
+            collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
+            if redis_client.get(collection_exist_cache_key):
+                return
+
+            with self._get_cursor() as cur:
+                cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name))
+            redis_client.set(collection_exist_cache_key, 1, ex=3600)
+
+
+class OracleVectorFactory(AbstractVectorFactory):
+    def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> OracleVector:
+        if dataset.index_struct_dict:
+            class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
+            collection_name = class_prefix
+        else:
+            dataset_id = dataset.id
+            collection_name = Dataset.gen_collection_name_by_id(dataset_id)
+            dataset.index_struct = json.dumps(
+                self.gen_index_struct_dict(VectorType.ORACLE, collection_name))
+
+        config = current_app.config
+        return OracleVector(
+            collection_name=collection_name,
+            config=OracleVectorConfig(
+                host=config.get("ORACLE_HOST"),
+                port=config.get("ORACLE_PORT"),
+                user=config.get("ORACLE_USER"),
+                password=config.get("ORACLE_PASSWORD"),
+                database=config.get("ORACLE_DATABASE"),
+            ),
+        )
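
A minimal usage sketch of the class above, assuming a reachable Oracle 23ai instance with the dify/dify user from docker-compose.oracle.yaml and a Redis instance for the table-creation lock (the integration test mocks Redis instead); the embedding values are illustrative, not real model output:

from core.rag.datasource.vdb.oracle.oraclevector import OracleVector, OracleVectorConfig
from core.rag.models.document import Document

vector = OracleVector(
    collection_name="example_collection",  # data lands in table embedding_example_collection
    config=OracleVectorConfig(host="localhost", port=1521, user="dify", password="dify", database="FREEPDB1"),
)
docs = [Document(page_content="hello oracle", metadata={"doc_id": "doc-1"})]
embeddings = [[0.1, 0.2, 0.3, 0.4]]              # normally produced by the embedding model
vector.create(docs, embeddings)                  # creates the table and inserts the rows
hits = vector.search_by_vector([0.1, 0.2, 0.3, 0.4], top_k=1)
print([d.page_content for d in hits])
vector.delete()                                  # drops the table again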

+ 3 - 0
api/core/rag/datasource/vdb/vector_factory.py

@@ -78,6 +78,9 @@ class Vector:
             case VectorType.TENCENT:
                 from core.rag.datasource.vdb.tencent.tencent_vector import TencentVectorFactory
                 return TencentVectorFactory
+            case VectorType.ORACLE:
+                from core.rag.datasource.vdb.oracle.oraclevector import OracleVectorFactory
+                return OracleVectorFactory
             case VectorType.OPENSEARCH:
                 from core.rag.datasource.vdb.opensearch.opensearch_vector import OpenSearchVectorFactory
                 return OpenSearchVectorFactory

+ 1 - 0
api/core/rag/datasource/vdb/vector_type.py

@@ -12,3 +12,4 @@ class VectorType(str, Enum):
     WEAVIATE = 'weaviate'
     OPENSEARCH = 'opensearch'
     TENCENT = 'tencent'
+    ORACLE = 'oracle'

File diff suppressed because it is too large
+ 346 - 304
api/poetry.lock


+ 1 - 0
api/pyproject.toml

@@ -187,6 +187,7 @@ tenacity = "~8.3.0"
 cos-python-sdk-v5 = "1.9.30"
 novita-client = "^0.5.6"
 opensearch-py = "2.4.0"
+oracledb = "~2.2.1"
 
 [tool.poetry.group.dev]
 optional = true

+ 2 - 1
api/requirements.txt

@@ -92,4 +92,5 @@ chromadb~=0.5.1
 novita_client~=0.5.6
 tenacity~=8.3.0
 opensearch-py==2.4.0
-cos-python-sdk-v5==1.9.30
+cos-python-sdk-v5==1.9.30
+oracledb~=2.2.1
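
Both dependency manifests now pin python-oracledb ~2.2, the driver imported by oraclevector.py. In its default thin mode it needs no Oracle Client installation, so the api image does not change beyond the new wheel. A trivial, illustrative sanity check:

import oracledb

print(oracledb.__version__)          # expected to satisfy ~=2.2.1
print(oracledb.defaults.fetch_lobs)  # True by default; oraclevector.py sets it to False
                                     # so CLOB text columns come back as plain str values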

+ 0 - 0
api/tests/integration_tests/vdb/oracle/__init__.py


+ 30 - 0
api/tests/integration_tests/vdb/oracle/test_oraclevector.py

@@ -0,0 +1,30 @@
+from core.rag.datasource.vdb.oracle.oraclevector import OracleVector, OracleVectorConfig
+from core.rag.models.document import Document
+from tests.integration_tests.vdb.test_vector_store import (
+    AbstractVectorTest,
+    get_example_text,
+    setup_mock_redis,
+)
+
+
+class OracleVectorTest(AbstractVectorTest):
+    def __init__(self):
+        super().__init__()
+        self.vector = OracleVector(
+            collection_name=self.collection_name,
+            config=OracleVectorConfig(
+                host="localhost",
+                port=1521,
+                user="dify",
+                password="dify",
+                database="FREEPDB1",
+            ),
+        )
+
+    def search_by_full_text(self):
+        hits_by_full_text: list[Document] = self.vector.search_by_full_text(query=get_example_text())
+        assert len(hits_by_full_text) == 0
+
+
+def test_oraclevector(setup_mock_redis):
+    OracleVectorTest().run_all_tests()

+ 18 - 0
docker/docker-compose.oracle.yaml

@@ -0,0 +1,18 @@
+version: '3'
+services:
+  # Oracle 23ai vector store.
+  oracle:
+    image: container-registry.oracle.com/database/free:latest
+    restart: always
+    ports:
+      - 1521:1521
+    volumes:
+      - type: volume
+        source: oradata_vector
+        target: /opt/oracle/oradata
+      - ./startupscripts:/opt/oracle/scripts/startup
+    environment:
+      - ORACLE_PWD=Dify123456
+      - ORACLE_CHARACTERSET=AL32UTF8
+volumes:
+  oradata_vector:
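
The free image creates the database on its first start, which can take several minutes before the FREEPDB1 service accepts connections. A small readiness probe, assuming the defaults above (SYSTEM password Dify123456, port 1521 published on localhost); this is a sketch for local use, not part of the compose setup:

import time

import oracledb

DSN = "localhost:1521/FREEPDB1"  # service exposed by the compose file above

for attempt in range(60):
    try:
        with oracledb.connect(user="system", password="Dify123456", dsn=DSN):
            print("Oracle is ready")
            break
    except oracledb.Error:
        time.sleep(10)  # container still starting, or PDB not open yet
else:
    raise SystemExit("Oracle did not become ready in time")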

+ 31 - 0
docker/docker-compose.yaml

@@ -154,6 +154,12 @@ services:
       TIDB_VECTOR_USER: xxx.root
       TIDB_VECTOR_PASSWORD: xxxxxx
       TIDB_VECTOR_DATABASE: dify
+      # oracle configurations
+      ORACLE_HOST: oracle
+      ORACLE_PORT: 1521
+      ORACLE_USER: dify
+      ORACLE_PASSWORD: dify
+      ORACLE_DATABASE: FREEPDB1
       # Chroma configuration
       CHROMA_HOST: 127.0.0.1
       CHROMA_PORT: 8000
@@ -350,6 +356,12 @@ services:
       TIDB_VECTOR_USER: xxx.root
       TIDB_VECTOR_PASSWORD: xxxxxx
       TIDB_VECTOR_DATABASE: dify
+      # oracle configurations
+      ORACLE_HOST: oracle
+      ORACLE_PORT: 1521
+      ORACLE_USER: dify
+      ORACLE_PASSWORD: dify
+      ORACLE_DATABASE: FREEPDB1
       # Chroma configuration
       CHROMA_HOST: 127.0.0.1
       CHROMA_PORT: 8000
@@ -530,6 +542,22 @@ services:
   #     timeout: 3s
   #     retries: 30
 
+  # The Oracle 23ai vector database.
+  # Uncomment to use Oracle 23ai as the vector store. Also uncomment the volumes block at the end of this file.
+  # oracle:
+  #   image: container-registry.oracle.com/database/free:latest
+  #   restart: always
+  #   ports:
+  #     - 1521:1521
+  #   volumes:
+  #     - type: volume
+  #       source: oradata
+  #       target: /opt/oracle/oradata
+  #     - ./startupscripts:/opt/oracle/scripts/startup
+  #   environment:
+  #     - ORACLE_PWD=Dify123456
+  #     - ORACLE_CHARACTERSET=AL32UTF8
+
 
   # The nginx reverse proxy.
   # used for reverse proxying the API service and Web service.
@@ -555,3 +583,6 @@ networks:
   ssrf_proxy_network:
     driver: bridge
     internal: true
+
+#volumes:
+#  oradata:

+ 5 - 0
docker/startupscripts/create_user.sql

@@ -0,0 +1,5 @@
+show pdbs;
+ALTER SYSTEM SET PROCESSES=500 SCOPE=SPFILE;
+alter session set container = freepdb1;
+create user dify identified by dify DEFAULT TABLESPACE users quota unlimited on users;
+grant DB_DEVELOPER_ROLE to dify;
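
After this startup script has run inside the container, the dify account can log in to FREEPDB1 with DB_DEVELOPER_ROLE granted, which is expected to cover the CREATE TABLE that oraclevector.py issues. A quick check, assuming the compose defaults above:

import oracledb

with oracledb.connect(user="dify", password="dify", dsn="localhost:1521/FREEPDB1") as conn:
    with conn.cursor() as cur:
        cur.execute("SELECT role FROM session_roles")
        roles = [row[0] for row in cur]
print(roles)  # DB_DEVELOPER_ROLE is expected to appear here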