123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241 |
- import json
- import uuid
- from contextlib import contextmanager
- from typing import Any
- import psycopg2.extras # type: ignore
- import psycopg2.pool # type: ignore
- from pydantic import BaseModel, model_validator
- from configs import dify_config
- from core.rag.datasource.vdb.vector_base import BaseVector
- from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
- from core.rag.datasource.vdb.vector_type import VectorType
- from core.rag.embedding.embedding_base import Embeddings
- from core.rag.models.document import Document
- from extensions.ext_redis import redis_client
- from models.dataset import Dataset
- class OpenGaussConfig(BaseModel):
- host: str
- port: int
- user: str
- password: str
- database: str
- min_connection: int
- max_connection: int
- @model_validator(mode="before")
- @classmethod
- def validate_config(cls, values: dict) -> dict:
- if not values["host"]:
- raise ValueError("config OPENGAUSS_HOST is required")
- if not values["port"]:
- raise ValueError("config OPENGAUSS_PORT is required")
- if not values["user"]:
- raise ValueError("config OPENGAUSS_USER is required")
- if not values["password"]:
- raise ValueError("config OPENGAUSS_PASSWORD is required")
- if not values["database"]:
- raise ValueError("config OPENGAUSS_DATABASE is required")
- if not values["min_connection"]:
- raise ValueError("config OPENGAUSS_MIN_CONNECTION is required")
- if not values["max_connection"]:
- raise ValueError("config OPENGAUSS_MAX_CONNECTION is required")
- if values["min_connection"] > values["max_connection"]:
- raise ValueError("config OPENGAUSS_MIN_CONNECTION should less than OPENGAUSS_MAX_CONNECTION")
- return values
- SQL_CREATE_TABLE = """
- CREATE TABLE IF NOT EXISTS {table_name} (
- id UUID PRIMARY KEY,
- text TEXT NOT NULL,
- meta JSONB NOT NULL,
- embedding vector({dimension}) NOT NULL
- );
- """
- SQL_CREATE_INDEX = """
- CREATE INDEX IF NOT EXISTS embedding_cosine_{table_name}_idx ON {table_name}
- USING hnsw (embedding vector_cosine_ops) WITH (m = 16, ef_construction = 64);
- """
- class OpenGauss(BaseVector):
- def __init__(self, collection_name: str, config: OpenGaussConfig):
- super().__init__(collection_name)
- self.pool = self._create_connection_pool(config)
- self.table_name = f"embedding_{collection_name}"
- def get_type(self) -> str:
- return VectorType.OPENGAUSS
- def _create_connection_pool(self, config: OpenGaussConfig):
- return psycopg2.pool.SimpleConnectionPool(
- config.min_connection,
- config.max_connection,
- host=config.host,
- port=config.port,
- user=config.user,
- password=config.password,
- database=config.database,
- )
- @contextmanager
- def _get_cursor(self):
- conn = self.pool.getconn()
- cur = conn.cursor()
- try:
- yield cur
- finally:
- cur.close()
- conn.commit()
- self.pool.putconn(conn)
- def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
- dimension = len(embeddings[0])
- self._create_collection(dimension)
- return self.add_texts(texts, embeddings)
- def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
- values = []
- pks = []
- for i, doc in enumerate(documents):
- if doc.metadata is not None:
- doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
- pks.append(doc_id)
- values.append(
- (
- doc_id,
- doc.page_content,
- json.dumps(doc.metadata),
- embeddings[i],
- )
- )
- with self._get_cursor() as cur:
- psycopg2.extras.execute_values(
- cur, f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES %s", values
- )
- return pks
- def text_exists(self, id: str) -> bool:
- with self._get_cursor() as cur:
- cur.execute(f"SELECT id FROM {self.table_name} WHERE id = %s", (id,))
- return cur.fetchone() is not None
- def get_by_ids(self, ids: list[str]) -> list[Document]:
- with self._get_cursor() as cur:
- cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN %s", (tuple(ids),))
- docs = []
- for record in cur:
- docs.append(Document(page_content=record[1], metadata=record[0]))
- return docs
- def delete_by_ids(self, ids: list[str]) -> None:
- # Avoiding crashes caused by performing delete operations on empty lists in certain scenarios
- # Scenario 1: extract a document fails, resulting in a table not being created.
- # Then clicking the retry button triggers a delete operation on an empty list.
- if not ids:
- return
- with self._get_cursor() as cur:
- cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s", (tuple(ids),))
- def delete_by_metadata_field(self, key: str, value: str) -> None:
- with self._get_cursor() as cur:
- cur.execute(f"DELETE FROM {self.table_name} WHERE meta->>%s = %s", (key, value))
- def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
- """
- Search the nearest neighbors to a vector.
- :param query_vector: The input vector to search for similar items.
- :param top_k: The number of nearest neighbors to return, default is 5.
- :return: List of Documents that are nearest to the query vector.
- """
- top_k = kwargs.get("top_k", 4)
- if not isinstance(top_k, int) or top_k <= 0:
- raise ValueError("top_k must be a positive integer")
- with self._get_cursor() as cur:
- cur.execute(
- f"SELECT meta, text, embedding <=> %s AS distance FROM {self.table_name}"
- f" ORDER BY distance LIMIT {top_k}",
- (json.dumps(query_vector),),
- )
- docs = []
- score_threshold = float(kwargs.get("score_threshold") or 0.0)
- for record in cur:
- metadata, text, distance = record
- score = 1 - distance
- metadata["score"] = score
- if score > score_threshold:
- docs.append(Document(page_content=text, metadata=metadata))
- return docs
- def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
- top_k = kwargs.get("top_k", 5)
- if not isinstance(top_k, int) or top_k <= 0:
- raise ValueError("top_k must be a positive integer")
- with self._get_cursor() as cur:
- cur.execute(
- f"""SELECT meta, text, ts_rank(to_tsvector(coalesce(text, '')), plainto_tsquery(%s)) AS score
- FROM {self.table_name}
- WHERE to_tsvector(text) @@ plainto_tsquery(%s)
- ORDER BY score DESC
- LIMIT {top_k}""",
- # f"'{query}'" is required in order to account for whitespace in query
- (f"'{query}'", f"'{query}'"),
- )
- docs = []
- for record in cur:
- metadata, text, score = record
- metadata["score"] = score
- docs.append(Document(page_content=text, metadata=metadata))
- return docs
- def delete(self) -> None:
- with self._get_cursor() as cur:
- cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")
- def _create_collection(self, dimension: int):
- cache_key = f"vector_indexing_{self._collection_name}"
- lock_name = f"{cache_key}_lock"
- with redis_client.lock(lock_name, timeout=20):
- collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
- if redis_client.get(collection_exist_cache_key):
- return
- with self._get_cursor() as cur:
- cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name, dimension=dimension))
- if dimension <= 2000:
- cur.execute(SQL_CREATE_INDEX.format(table_name=self.table_name))
- redis_client.set(collection_exist_cache_key, 1, ex=3600)
- class OpenGaussFactory(AbstractVectorFactory):
- def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> OpenGauss:
- if dataset.index_struct_dict:
- class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
- collection_name = class_prefix
- else:
- dataset_id = dataset.id
- collection_name = Dataset.gen_collection_name_by_id(dataset_id)
- dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.OPENGAUSS, collection_name))
- return OpenGauss(
- collection_name=collection_name,
- config=OpenGaussConfig(
- host=dify_config.OPENGAUSS_HOST or "localhost",
- port=dify_config.OPENGAUSS_PORT,
- user=dify_config.OPENGAUSS_USER or "postgres",
- password=dify_config.OPENGAUSS_PASSWORD or "",
- database=dify_config.OPENGAUSS_DATABASE or "dify",
- min_connection=dify_config.OPENGAUSS_MIN_CONNECTION,
- max_connection=dify_config.OPENGAUSS_MAX_CONNECTION,
- ),
- )
|