opengauss.py

import json
import uuid
from contextlib import contextmanager
from typing import Any

import psycopg2.extras  # type: ignore
import psycopg2.pool  # type: ignore
from pydantic import BaseModel, model_validator

from configs import dify_config
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset


class OpenGaussConfig(BaseModel):
    host: str
    port: int
    user: str
    password: str
    database: str
    min_connection: int
    max_connection: int
    enable_pq: bool = False  # Enable PQ (product quantization) acceleration

    @model_validator(mode="before")
    @classmethod
    def validate_config(cls, values: dict) -> dict:
        if not values["host"]:
            raise ValueError("config OPENGAUSS_HOST is required")
        if not values["port"]:
            raise ValueError("config OPENGAUSS_PORT is required")
        if not values["user"]:
            raise ValueError("config OPENGAUSS_USER is required")
        if not values["password"]:
            raise ValueError("config OPENGAUSS_PASSWORD is required")
        if not values["database"]:
            raise ValueError("config OPENGAUSS_DATABASE is required")
        if not values["min_connection"]:
            raise ValueError("config OPENGAUSS_MIN_CONNECTION is required")
        if not values["max_connection"]:
            raise ValueError("config OPENGAUSS_MAX_CONNECTION is required")
        if values["min_connection"] > values["max_connection"]:
            raise ValueError("config OPENGAUSS_MIN_CONNECTION must not be greater than OPENGAUSS_MAX_CONNECTION")
        return values
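

# Illustrative only: because validate_config runs in "before" mode, a misconfigured
# pool fails at construction time rather than at first query. The values below are
# hypothetical:
#
#   OpenGaussConfig(host="127.0.0.1", port=6600, user="dify", password="secret",
#                   database="dify", min_connection=5, max_connection=2)
#   # -> ValueError: min_connection must not be greater than max_connection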


SQL_CREATE_TABLE = """
CREATE TABLE IF NOT EXISTS {table_name} (
    id UUID PRIMARY KEY,
    text TEXT NOT NULL,
    meta JSONB NOT NULL,
    embedding vector({dimension}) NOT NULL
);
"""

SQL_CREATE_INDEX_PQ = """
CREATE INDEX IF NOT EXISTS embedding_{table_name}_pq_idx ON {table_name}
USING hnsw (embedding vector_cosine_ops) WITH (m = 16, ef_construction = 64, enable_pq=on, pq_m={pq_m});
"""

SQL_CREATE_INDEX = """
CREATE INDEX IF NOT EXISTS embedding_cosine_{table_name}_idx ON {table_name}
USING hnsw (embedding vector_cosine_ops) WITH (m = 16, ef_construction = 64);
"""


class OpenGauss(BaseVector):
    def __init__(self, collection_name: str, config: OpenGaussConfig):
        super().__init__(collection_name)
        self.pool = self._create_connection_pool(config)
        self.table_name = f"embedding_{collection_name}"
        self.pq_enabled = config.enable_pq

    def get_type(self) -> str:
        return VectorType.OPENGAUSS

    def _create_connection_pool(self, config: OpenGaussConfig):
        return psycopg2.pool.SimpleConnectionPool(
            config.min_connection,
            config.max_connection,
            host=config.host,
            port=config.port,
            user=config.user,
            password=config.password,
            database=config.database,
        )

    @contextmanager
    def _get_cursor(self):
        # Borrow a connection from the pool, commit on success, roll back on
        # failure, and always return the connection to the pool.
        conn = self.pool.getconn()
        cur = conn.cursor()
        try:
            yield cur
            conn.commit()
        except Exception:
            conn.rollback()
            raise
        finally:
            cur.close()
            self.pool.putconn(conn)

    def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
        dimension = len(embeddings[0])
        self._create_collection(dimension)
        self.add_texts(texts, embeddings)
        self._create_index(dimension)

    def _create_index(self, dimension: int):
        index_cache_key = f"vector_index_{self._collection_name}"
        lock_name = f"{index_cache_key}_lock"
        with redis_client.lock(lock_name, timeout=60):
            if redis_client.get(index_cache_key):
                return
            with self._get_cursor() as cur:
                if dimension <= 2000:
                    if self.pq_enabled:
                        cur.execute(SQL_CREATE_INDEX_PQ.format(table_name=self.table_name, pq_m=int(dimension / 4)))
                        cur.execute("SET hnsw_earlystop_threshold = 320")
                    else:
                        cur.execute(SQL_CREATE_INDEX.format(table_name=self.table_name))
            redis_client.set(index_cache_key, 1, ex=3600)

    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
        values = []
        pks = []
        for i, doc in enumerate(documents):
            if doc.metadata is not None:
                doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
                pks.append(doc_id)
                values.append(
                    (
                        doc_id,
                        doc.page_content,
                        json.dumps(doc.metadata),
                        embeddings[i],
                    )
                )
        with self._get_cursor() as cur:
            psycopg2.extras.execute_values(
                cur, f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES %s", values
            )
        return pks

    def text_exists(self, id: str) -> bool:
        with self._get_cursor() as cur:
            cur.execute(f"SELECT id FROM {self.table_name} WHERE id = %s", (id,))
            return cur.fetchone() is not None

    def get_by_ids(self, ids: list[str]) -> list[Document]:
        # Guard against an empty list: "IN ()" is invalid SQL.
        if not ids:
            return []
        with self._get_cursor() as cur:
            cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN %s", (tuple(ids),))
            docs = []
            for record in cur:
                docs.append(Document(page_content=record[1], metadata=record[0]))
        return docs

    def delete_by_ids(self, ids: list[str]) -> None:
        # Avoid crashes caused by performing delete operations on empty lists in certain scenarios.
        # Scenario 1: extracting a document fails, so its table is never created;
        # clicking the retry button then triggers a delete operation on an empty list.
        if not ids:
            return
        with self._get_cursor() as cur:
            cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s", (tuple(ids),))

    def delete_by_metadata_field(self, key: str, value: str) -> None:
        with self._get_cursor() as cur:
            cur.execute(f"DELETE FROM {self.table_name} WHERE meta->>%s = %s", (key, value))

    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
        """
        Search the nearest neighbors to a vector.

        :param query_vector: The input vector to search for similar items.
        :param top_k: The number of nearest neighbors to return, default is 4.
        :return: List of Documents that are nearest to the query vector.
        """
        top_k = kwargs.get("top_k", 4)
        if not isinstance(top_k, int) or top_k <= 0:
            raise ValueError("top_k must be a positive integer")
        with self._get_cursor() as cur:
            cur.execute(
                f"SELECT meta, text, embedding <=> %s AS distance FROM {self.table_name}"
                f" ORDER BY distance LIMIT {top_k}",
                (json.dumps(query_vector),),
            )
            docs = []
            score_threshold = float(kwargs.get("score_threshold") or 0.0)
            for record in cur:
                metadata, text, distance = record
                # <=> yields cosine distance; convert to a similarity score.
                score = 1 - distance
                metadata["score"] = score
                if score > score_threshold:
                    docs.append(Document(page_content=text, metadata=metadata))
        return docs

    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
        top_k = kwargs.get("top_k", 5)
        if not isinstance(top_k, int) or top_k <= 0:
            raise ValueError("top_k must be a positive integer")
        with self._get_cursor() as cur:
            cur.execute(
                f"""SELECT meta, text, ts_rank(to_tsvector(coalesce(text, '')), plainto_tsquery(%s)) AS score
                FROM {self.table_name}
                WHERE to_tsvector(text) @@ plainto_tsquery(%s)
                ORDER BY score DESC
                LIMIT {top_k}""",
                # f"'{query}'" is required in order to account for whitespace in the query
                (f"'{query}'", f"'{query}'"),
            )
            docs = []
            for record in cur:
                metadata, text, score = record
                metadata["score"] = score
                docs.append(Document(page_content=text, metadata=metadata))
        return docs

    def delete(self) -> None:
        with self._get_cursor() as cur:
            cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")

    def _create_collection(self, dimension: int):
        cache_key = f"vector_indexing_{self._collection_name}"
        lock_name = f"{cache_key}_lock"
        with redis_client.lock(lock_name, timeout=20):
            if redis_client.get(cache_key):
                return
            with self._get_cursor() as cur:
                cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name, dimension=dimension))
            redis_client.set(cache_key, 1, ex=3600)


class OpenGaussFactory(AbstractVectorFactory):
    def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> OpenGauss:
        if dataset.index_struct_dict:
            class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
            collection_name = class_prefix
        else:
            dataset_id = dataset.id
            collection_name = Dataset.gen_collection_name_by_id(dataset_id)
            dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.OPENGAUSS, collection_name))

        return OpenGauss(
            collection_name=collection_name,
            config=OpenGaussConfig(
                host=dify_config.OPENGAUSS_HOST or "localhost",
                port=dify_config.OPENGAUSS_PORT,
                user=dify_config.OPENGAUSS_USER or "postgres",
                password=dify_config.OPENGAUSS_PASSWORD or "",
                database=dify_config.OPENGAUSS_DATABASE or "dify",
                min_connection=dify_config.OPENGAUSS_MIN_CONNECTION,
                max_connection=dify_config.OPENGAUSS_MAX_CONNECTION,
                enable_pq=dify_config.OPENGAUSS_ENABLE_PQ or False,
            ),
        )
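

# Illustrative usage sketch with hypothetical values; assumes a reachable
# openGauss instance with vector support and a Redis server for the locks:
#
#   config = OpenGaussConfig(
#       host="127.0.0.1", port=6600, user="dify", password="secret",
#       database="dify", min_connection=1, max_connection=5, enable_pq=False,
#   )
#   store = OpenGauss(collection_name="abc", config=config)
#   docs = [Document(page_content="hello", metadata={"doc_id": str(uuid.uuid4())})]
#   store.create(texts=docs, embeddings=[[0.1] * 768])  # creates table, rows, and index
#   results = store.search_by_vector([0.1] * 768, top_k=4, score_threshold=0.5)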