Browse Source

feat: Enable baiduvector intergration test (#9369)

ice yao 6 months ago
parent
commit
568d5c46ed

+ 41 - 30
api/tests/integration_tests/vdb/__mock/baiduvectordb.py

@@ -1,4 +1,5 @@
 import os
+from unittest.mock import MagicMock
 
 import pytest
 from _pytest.monkeypatch import MonkeyPatch
@@ -10,26 +11,31 @@ from pymochow.model.table import Table
 from requests.adapters import HTTPAdapter
 
 
+class AttrDict(dict):
+    def __getattr__(self, item):
+        return self.get(item)
+
+
 class MockBaiduVectorDBClass:
     def mock_vector_db_client(
         self,
         config=None,
         adapter: HTTPAdapter = None,
     ):
-        self._conn = None
-        self._config = None
+        self.conn = MagicMock()
+        self._config = MagicMock()
 
     def list_databases(self, config=None) -> list[Database]:
         return [
             Database(
-                conn=self._conn,
+                conn=self.conn,
                 database_name="dify",
                 config=self._config,
             )
         ]
 
     def create_database(self, database_name: str, config=None) -> Database:
-        return Database(conn=self._conn, database_name=database_name, config=config)
+        return Database(conn=self.conn, database_name=database_name, config=config)
 
     def list_table(self, config=None) -> list[Table]:
         return []
@@ -88,16 +94,18 @@ class MockBaiduVectorDBClass:
         read_consistency=ReadConsistency.EVENTUAL,
         config=None,
     ):
-        return {
-            "row": {
-                "id": "doc_id_001",
-                "vector": [0.23432432, 0.8923744, 0.89238432],
-                "text": "text",
-                "metadata": {"doc_id": "doc_id_001"},
-            },
-            "code": 0,
-            "msg": "Success",
-        }
+        return AttrDict(
+            {
+                "row": {
+                    "id": primary_key.get("id"),
+                    "vector": [0.23432432, 0.8923744, 0.89238432],
+                    "text": "text",
+                    "metadata": '{"doc_id": "doc_id_001"}',
+                },
+                "code": 0,
+                "msg": "Success",
+            }
+        )
 
     def delete(self, primary_key=None, partition_key=None, filter=None, config=None):
         return {"code": 0, "msg": "Success"}
@@ -111,22 +119,24 @@ class MockBaiduVectorDBClass:
         read_consistency=ReadConsistency.EVENTUAL,
         config=None,
     ):
-        return {
-            "rows": [
-                {
-                    "row": {
-                        "id": "doc_id_001",
-                        "vector": [0.23432432, 0.8923744, 0.89238432],
-                        "text": "text",
-                        "metadata": {"doc_id": "doc_id_001"},
-                    },
-                    "distance": 0.1,
-                    "score": 0.5,
-                }
-            ],
-            "code": 0,
-            "msg": "Success",
-        }
+        return AttrDict(
+            {
+                "rows": [
+                    {
+                        "row": {
+                            "id": "doc_id_001",
+                            "vector": [0.23432432, 0.8923744, 0.89238432],
+                            "text": "text",
+                            "metadata": '{"doc_id": "doc_id_001"}',
+                        },
+                        "distance": 0.1,
+                        "score": 0.5,
+                    }
+                ],
+                "code": 0,
+                "msg": "Success",
+            }
+        )
 
 
 MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
@@ -146,6 +156,7 @@ def setup_baiduvectordb_mock(request, monkeypatch: MonkeyPatch):
         monkeypatch.setattr(Table, "rebuild_index", MockBaiduVectorDBClass.rebuild_index)
         monkeypatch.setattr(Table, "describe_index", MockBaiduVectorDBClass.describe_index)
         monkeypatch.setattr(Table, "delete", MockBaiduVectorDBClass.delete)
+        monkeypatch.setattr(Table, "query", MockBaiduVectorDBClass.query)
         monkeypatch.setattr(Table, "search", MockBaiduVectorDBClass.search)
 
     yield

+ 0 - 3
api/tests/integration_tests/vdb/baidu/test_baidu.py

@@ -4,9 +4,6 @@ from core.rag.datasource.vdb.baidu.baidu_vector import BaiduConfig, BaiduVector
 from tests.integration_tests.vdb.__mock.baiduvectordb import setup_baiduvectordb_mock
 from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text, setup_mock_redis
 
-mock_client = MagicMock()
-mock_client.list_databases.return_value = [{"name": "test"}]
-
 
 class BaiduVectorTest(AbstractVectorTest):
     def __init__(self):

+ 2 - 1
dev/pytest/pytest_vdb.sh

@@ -8,4 +8,5 @@ pytest api/tests/integration_tests/vdb/chroma \
   api/tests/integration_tests/vdb/qdrant \
   api/tests/integration_tests/vdb/weaviate \
   api/tests/integration_tests/vdb/elasticsearch \
-  api/tests/integration_tests/vdb/vikingdb
+  api/tests/integration_tests/vdb/vikingdb \
+  api/tests/integration_tests/vdb/baidu