tcvectordb.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. import os
  2. from typing import Optional
  3. import pytest
  4. from _pytest.monkeypatch import MonkeyPatch
  5. from requests.adapters import HTTPAdapter
  6. from tcvectordb import VectorDBClient
  7. from tcvectordb.model.database import Collection, Database
  8. from tcvectordb.model.document import Document, Filter
  9. from tcvectordb.model.enum import ReadConsistency
  10. from tcvectordb.model.index import Index
  11. from xinference_client.types import Embedding
  12. class MockTcvectordbClass:
  13. def mock_vector_db_client(
  14. self,
  15. url=None,
  16. username="",
  17. key="",
  18. read_consistency: ReadConsistency = ReadConsistency.EVENTUAL_CONSISTENCY,
  19. timeout=5,
  20. adapter: HTTPAdapter = None,
  21. ):
  22. self._conn = None
  23. self._read_consistency = read_consistency
  24. def list_databases(self) -> list[Database]:
  25. return [
  26. Database(
  27. conn=self._conn,
  28. read_consistency=self._read_consistency,
  29. name="dify",
  30. )
  31. ]
  32. def list_collections(self, timeout: Optional[float] = None) -> list[Collection]:
  33. return []
  34. def drop_collection(self, name: str, timeout: Optional[float] = None):
  35. return {"code": 0, "msg": "operation success"}
  36. def create_collection(
  37. self,
  38. name: str,
  39. shard: int,
  40. replicas: int,
  41. description: str,
  42. index: Index,
  43. embedding: Embedding = None,
  44. timeout: float = None,
  45. ) -> Collection:
  46. return Collection(
  47. self,
  48. name,
  49. shard,
  50. replicas,
  51. description,
  52. index,
  53. embedding=embedding,
  54. read_consistency=self._read_consistency,
  55. timeout=timeout,
  56. )
  57. def describe_collection(self, name: str, timeout: Optional[float] = None) -> Collection:
  58. collection = Collection(self, name, shard=1, replicas=2, description=name, timeout=timeout)
  59. return collection
  60. def collection_upsert(
  61. self, documents: list[Document], timeout: Optional[float] = None, build_index: bool = True, **kwargs
  62. ):
  63. return {"code": 0, "msg": "operation success"}
  64. def collection_search(
  65. self,
  66. vectors: list[list[float]],
  67. filter: Filter = None,
  68. params=None,
  69. retrieve_vector: bool = False,
  70. limit: int = 10,
  71. output_fields: Optional[list[str]] = None,
  72. timeout: Optional[float] = None,
  73. ) -> list[list[dict]]:
  74. return [[{"metadata": '{"doc_id":"foo1"}', "text": "text", "doc_id": "foo1", "score": 0.1}]]
  75. def collection_query(
  76. self,
  77. document_ids: Optional[list] = None,
  78. retrieve_vector: bool = False,
  79. limit: Optional[int] = None,
  80. offset: Optional[int] = None,
  81. filter: Optional[Filter] = None,
  82. output_fields: Optional[list[str]] = None,
  83. timeout: Optional[float] = None,
  84. ) -> list[dict]:
  85. return [{"metadata": '{"doc_id":"foo1"}', "text": "text", "doc_id": "foo1", "score": 0.1}]
  86. def collection_delete(
  87. self,
  88. document_ids: list[str] = None,
  89. filter: Filter = None,
  90. timeout: float = None,
  91. ):
  92. return {"code": 0, "msg": "operation success"}
  93. MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
  94. @pytest.fixture
  95. def setup_tcvectordb_mock(request, monkeypatch: MonkeyPatch):
  96. if MOCK:
  97. monkeypatch.setattr(VectorDBClient, "__init__", MockTcvectordbClass.mock_vector_db_client)
  98. monkeypatch.setattr(VectorDBClient, "list_databases", MockTcvectordbClass.list_databases)
  99. monkeypatch.setattr(Database, "collection", MockTcvectordbClass.describe_collection)
  100. monkeypatch.setattr(Database, "list_collections", MockTcvectordbClass.list_collections)
  101. monkeypatch.setattr(Database, "drop_collection", MockTcvectordbClass.drop_collection)
  102. monkeypatch.setattr(Database, "create_collection", MockTcvectordbClass.create_collection)
  103. monkeypatch.setattr(Collection, "upsert", MockTcvectordbClass.collection_upsert)
  104. monkeypatch.setattr(Collection, "search", MockTcvectordbClass.collection_search)
  105. monkeypatch.setattr(Collection, "query", MockTcvectordbClass.collection_query)
  106. monkeypatch.setattr(Collection, "delete", MockTcvectordbClass.collection_delete)
  107. yield
  108. if MOCK:
  109. monkeypatch.undo()