tcvectordb.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. import os
  2. from typing import Optional
  3. import pytest
  4. from _pytest.monkeypatch import MonkeyPatch
  5. from requests.adapters import HTTPAdapter
  6. from tcvectordb import VectorDBClient
  7. from tcvectordb.model.database import Collection, Database
  8. from tcvectordb.model.document import Document, Filter
  9. from tcvectordb.model.enum import ReadConsistency
  10. from tcvectordb.model.index import Index
  11. from xinference_client.types import Embedding
  12. class MockTcvectordbClass:
  13. def VectorDBClient(self, url=None, username='', key='',
  14. read_consistency: ReadConsistency = ReadConsistency.EVENTUAL_CONSISTENCY,
  15. timeout=5,
  16. adapter: HTTPAdapter = None):
  17. self._conn = None
  18. self._read_consistency = read_consistency
  19. def list_databases(self) -> list[Database]:
  20. return [
  21. Database(
  22. conn=self._conn,
  23. read_consistency=self._read_consistency,
  24. name='dify',
  25. )]
  26. def list_collections(self, timeout: Optional[float] = None) -> list[Collection]:
  27. return []
  28. def drop_collection(self, name: str, timeout: Optional[float] = None):
  29. return {
  30. "code": 0,
  31. "msg": "operation success"
  32. }
  33. def create_collection(
  34. self,
  35. name: str,
  36. shard: int,
  37. replicas: int,
  38. description: str,
  39. index: Index,
  40. embedding: Embedding = None,
  41. timeout: float = None,
  42. ) -> Collection:
  43. return Collection(self, name, shard, replicas, description, index, embedding=embedding,
  44. read_consistency=self._read_consistency, timeout=timeout)
  45. def describe_collection(self, name: str, timeout: Optional[float] = None) -> Collection:
  46. collection = Collection(
  47. self,
  48. name,
  49. shard=1,
  50. replicas=2,
  51. description=name,
  52. timeout=timeout
  53. )
  54. return collection
  55. def collection_upsert(
  56. self,
  57. documents: list[Document],
  58. timeout: Optional[float] = None,
  59. build_index: bool = True,
  60. **kwargs
  61. ):
  62. return {
  63. "code": 0,
  64. "msg": "operation success"
  65. }
  66. def collection_search(
  67. self,
  68. vectors: list[list[float]],
  69. filter: Filter = None,
  70. params=None,
  71. retrieve_vector: bool = False,
  72. limit: int = 10,
  73. output_fields: Optional[list[str]] = None,
  74. timeout: Optional[float] = None,
  75. ) -> list[list[dict]]:
  76. return [[{'metadata': '{"doc_id":"foo1"}', 'text': 'text', 'doc_id': 'foo1', 'score': 0.1}]]
  77. def collection_query(
  78. self,
  79. document_ids: Optional[list] = None,
  80. retrieve_vector: bool = False,
  81. limit: Optional[int] = None,
  82. offset: Optional[int] = None,
  83. filter: Optional[Filter] = None,
  84. output_fields: Optional[list[str]] = None,
  85. timeout: Optional[float] = None,
  86. ) -> list[dict]:
  87. return [{'metadata': '{"doc_id":"foo1"}', 'text': 'text', 'doc_id': 'foo1', 'score': 0.1}]
  88. def collection_delete(
  89. self,
  90. document_ids: list[str] = None,
  91. filter: Filter = None,
  92. timeout: float = None,
  93. ):
  94. return {
  95. "code": 0,
  96. "msg": "operation success"
  97. }
  98. MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'
  99. @pytest.fixture
  100. def setup_tcvectordb_mock(request, monkeypatch: MonkeyPatch):
  101. if MOCK:
  102. monkeypatch.setattr(VectorDBClient, '__init__', MockTcvectordbClass.VectorDBClient)
  103. monkeypatch.setattr(VectorDBClient, 'list_databases', MockTcvectordbClass.list_databases)
  104. monkeypatch.setattr(Database, 'collection', MockTcvectordbClass.describe_collection)
  105. monkeypatch.setattr(Database, 'list_collections', MockTcvectordbClass.list_collections)
  106. monkeypatch.setattr(Database, 'drop_collection', MockTcvectordbClass.drop_collection)
  107. monkeypatch.setattr(Database, 'create_collection', MockTcvectordbClass.create_collection)
  108. monkeypatch.setattr(Collection, 'upsert', MockTcvectordbClass.collection_upsert)
  109. monkeypatch.setattr(Collection, 'search', MockTcvectordbClass.collection_search)
  110. monkeypatch.setattr(Collection, 'query', MockTcvectordbClass.collection_query)
  111. monkeypatch.setattr(Collection, 'delete', MockTcvectordbClass.collection_delete)
  112. yield
  113. if MOCK:
  114. monkeypatch.undo()