12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859 |
- """Wrapper around the Milvus vector database."""
- from __future__ import annotations
- import logging
- from typing import Any, Iterable, List, Optional, Tuple, Union, Sequence
- from uuid import uuid4
- import numpy as np
- from langchain.docstore.document import Document
- from langchain.embeddings.base import Embeddings
- from langchain.vectorstores.base import VectorStore
- from langchain.vectorstores.utils import maximal_marginal_relevance
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Connection arguments used when the caller does not supply any.
DEFAULT_MILVUS_CONNECTION = dict(
    host="localhost",
    port="19530",
    user="",
    password="",
    secure=False,
)
class Milvus(VectorStore):
    """Wrapper around the Milvus vector database.

    In order to use this you need to have `pymilvus` installed and a
    running Milvus instance.

    See the following documentation for how to run a Milvus instance:
    https://milvus.io/docs/install_standalone-docker.md

    If looking for a hosted Milvus, take a look at this documentation:
    https://zilliz.com/cloud and make use of the Zilliz vectorstore found in
    this project.

    IF USING L2/IP metric IT IS HIGHLY SUGGESTED TO NORMALIZE YOUR DATA.

    Args:
        embedding_function (Embeddings): Function used to embed the text.
        collection_name (str): Which Milvus collection to use. Defaults to
            "LangChainCollection".
        connection_args (Optional[dict[str, any]]): The connection args used for
            this class comes in the form of a dict.
        consistency_level (str): The consistency level to use for a collection.
            Defaults to "Session".
        index_params (Optional[dict]): Which index params to use. Defaults to
            HNSW/AUTOINDEX depending on service.
        search_params (Optional[dict]): Which search params to use. Defaults to
            default of index.
        drop_old (Optional[bool]): Whether to drop the current collection. Defaults
            to False.

    The connection args used for this class comes in the form of a dict,
    here are a few of the options:
        address (str): The actual address of Milvus
            instance. Example address: "localhost:19530"
        uri (str): The uri of Milvus instance. Example uri:
            "http://randomwebsite:19530",
            "tcp:foobarsite:19530",
            "https://ok.s3.south.com:19530".
        host (str): The host of Milvus instance. Default at "localhost",
            PyMilvus will fill in the default host if only port is provided.
        port (str/int): The port of Milvus instance. Default at 19530, PyMilvus
            will fill in the default port if only host is provided.
        user (str): Use which user to connect to Milvus instance. If user and
            password are provided, we will add related header in every RPC call.
        password (str): Required when user is provided. The password
            corresponding to the user.
        secure (bool): Default is false. If set to true, tls will be enabled.
        client_key_path (str): If use tls two-way authentication, need to
            write the client.key path.
        client_pem_path (str): If use tls two-way authentication, need to
            write the client.pem path.
        ca_pem_path (str): If use tls two-way authentication, need to write
            the ca.pem path.
        server_pem_path (str): If use tls one-way authentication, need to
            write the server.pem path.
        server_name (str): If use tls, need to write the common name.

    Example:
        .. code-block:: python

            from langchain import Milvus
            from langchain.embeddings import OpenAIEmbeddings

            embedding = OpenAIEmbeddings()
            # Connect to a milvus instance on localhost
            milvus_store = Milvus(
                embedding_function = embedding,
                collection_name = "LangChainCollection",
                drop_old = True,
            )

    Raises:
        ValueError: If the pymilvus python package is not installed.
    """

    def __init__(
        self,
        embedding_function: Embeddings,
        collection_name: str = "LangChainCollection",
        connection_args: Optional[dict[str, Any]] = None,
        consistency_level: str = "Session",
        index_params: Optional[dict] = None,
        search_params: Optional[dict] = None,
        drop_old: Optional[bool] = False,
    ):
        """Initialize the Milvus vector store.

        Connects to the server (reusing a matching connection when one exists),
        picks up the named collection if it already exists, optionally drops
        it, and then prepares index, search params and loading.

        Raises:
            ValueError: If the pymilvus python package is not installed.
        """
        try:
            from pymilvus import Collection, utility
        except ImportError as err:
            raise ValueError(
                "Could not import pymilvus python package. "
                "Please install it with `pip install pymilvus`."
            ) from err

        # Default search params, keyed by index type, used when the caller
        # does not provide any. The metric type is later overwritten with the
        # metric of the actual index (see _create_search_params).
        self.default_search_params = {
            "IVF_FLAT": {"metric_type": "L2", "params": {"nprobe": 10}},
            "IVF_SQ8": {"metric_type": "L2", "params": {"nprobe": 10}},
            "IVF_PQ": {"metric_type": "L2", "params": {"nprobe": 10}},
            "HNSW": {"metric_type": "L2", "params": {"ef": 10}},
            "RHNSW_FLAT": {"metric_type": "L2", "params": {"ef": 10}},
            "RHNSW_SQ": {"metric_type": "L2", "params": {"ef": 10}},
            "RHNSW_PQ": {"metric_type": "L2", "params": {"ef": 10}},
            "IVF_HNSW": {"metric_type": "L2", "params": {"nprobe": 10, "ef": 10}},
            "ANNOY": {"metric_type": "L2", "params": {"search_k": 10}},
            "AUTOINDEX": {"metric_type": "L2", "params": {}},
        }

        self.embedding_func = embedding_function
        self.collection_name = collection_name
        self.index_params = index_params
        self.search_params = search_params
        self.consistency_level = consistency_level

        # Field names a collection must use to be compatible with this class.
        # The primary key is an auto-id INT64 field.
        self._primary_field = "id"
        # The document text is stored in the "page_content" field.
        self._text_field = "page_content"
        # The embedding is stored in the "vectors" field.
        self._vector_field = "vectors"
        # All metadata of a document is stored in one JSON "metadata" field.
        self._metadata_field = "metadata"
        # Names of the non-primary fields of the collection, in schema order.
        self.fields: list[str] = []

        # Create the connection to the server
        if connection_args is None:
            connection_args = DEFAULT_MILVUS_CONNECTION
        self.alias = self._create_connection_alias(connection_args)
        self.col: Optional[Collection] = None

        # Grab the existing collection if it exists
        if utility.has_collection(self.collection_name, using=self.alias):
            self.col = Collection(
                self.collection_name,
                using=self.alias,
            )
        # If need to drop old, drop it
        if drop_old and isinstance(self.col, Collection):
            self.col.drop()
            self.col = None

        # Initialize the vector store
        self._init()

    @property
    def embeddings(self) -> Embeddings:
        """Return the embedding object this store was constructed with."""
        return self.embedding_func

    def _create_connection_alias(self, connection_args: dict) -> str:
        """Return the alias of a connection to the Milvus server.

        An already registered, connected alias with the same address and user
        is reused; otherwise a new connection is opened under a random alias.

        Raises:
            MilvusException: If a new connection cannot be established.
        """
        from pymilvus import MilvusException, connections

        # Grab the connection arguments that are used for checking existing connection
        host: Optional[str] = connection_args.get("host", None)
        port: Optional[Union[str, int]] = connection_args.get("port", None)
        address: Optional[str] = connection_args.get("address", None)
        uri: Optional[str] = connection_args.get("uri", None)
        user = connection_args.get("user", None)

        # Order of use is host/port, uri, address
        given_address: Optional[str]
        if host is not None and port is not None:
            given_address = str(host) + ":" + str(port)
        elif uri is not None:
            # Drop any "<scheme>://" prefix. Splitting on "https://" only would
            # raise IndexError for the documented "http://..." and "tcp:..."
            # forms; a uri without "://" is kept unchanged.
            given_address = uri.split("://", 1)[-1]
        elif address is not None:
            given_address = address
        else:
            given_address = None
            logger.debug("Missing standard address type for reuse atttempt")

        # User defaults to empty string when getting connection info
        tmp_user = user if user is not None else ""

        # If a valid address was given, then check if a connection exists
        if given_address is not None:
            for con in connections.list_connections():
                addr = connections.get_connection_addr(con[0])
                if (
                    con[1]
                    and ("address" in addr)
                    and (addr["address"] == given_address)
                    and ("user" in addr)
                    and (addr["user"] == tmp_user)
                ):
                    logger.debug("Using previous connection: %s", con[0])
                    return con[0]

        # Generate a new connection if one doesn't exist
        alias = uuid4().hex
        try:
            connections.connect(alias=alias, **connection_args)
        except MilvusException:
            logger.error("Failed to create new connection using: %s", alias)
            raise
        logger.debug("Created new connection using: %s", alias)
        return alias

    def _init(
        self, embeddings: Optional[list] = None, metadatas: Optional[list[dict]] = None
    ) -> None:
        """Prepare the store: create the collection if embeddings are given,
        then read its fields, ensure an index, derive search params and load.
        """
        if embeddings is not None:
            self._create_collection(embeddings, metadatas)
        self._extract_fields()
        self._create_index()
        self._create_search_params()
        self._load()

    def _create_collection(
        self, embeddings: list, metadatas: Optional[list[dict]] = None
    ) -> None:
        """Create the collection, deriving the schema from the first embedding.

        The schema holds, in this order: an optional JSON metadata field (only
        when ``metadatas`` is non-empty), the text field, the auto-id primary
        key, and the vector field.

        Raises:
            MilvusException: If the collection cannot be created.
        """
        from pymilvus import (
            Collection,
            CollectionSchema,
            DataType,
            FieldSchema,
            MilvusException,
        )
        from pymilvus.orm.types import infer_dtype_bydata

        # Determine embedding dim
        dim = len(embeddings[0])
        fields = []

        # All metadata of a document goes into a single JSON field.
        if metadatas:
            fields.append(
                FieldSchema(self._metadata_field, DataType.JSON, max_length=65_535)
            )
        # Create the text field
        fields.append(
            FieldSchema(self._text_field, DataType.VARCHAR, max_length=65_535)
        )
        # Create the primary key field
        fields.append(
            FieldSchema(
                self._primary_field, DataType.INT64, is_primary=True, auto_id=True
            )
        )
        # Create the vector field, supports binary or float vectors
        fields.append(
            FieldSchema(self._vector_field, infer_dtype_bydata(embeddings[0]), dim=dim)
        )

        # Create the schema for the collection
        schema = CollectionSchema(fields)

        # Create the collection
        try:
            self.col = Collection(
                name=self.collection_name,
                schema=schema,
                consistency_level=self.consistency_level,
                using=self.alias,
            )
        except MilvusException as e:
            logger.error(
                "Failed to create collection: %s error: %s", self.collection_name, e
            )
            raise

    def _extract_fields(self) -> None:
        """Record the collection's field names, excluding the primary key."""
        from pymilvus import Collection

        if isinstance(self.col, Collection):
            # Rebuild rather than append so repeated calls never duplicate
            # names. The primary field is auto-id, so it is not tracked.
            self.fields = [
                field.name
                for field in self.col.schema.fields
                if field.name != self._primary_field
            ]

    def _get_index(self) -> Optional[dict[str, Any]]:
        """Return the vector index information if it exists"""
        from pymilvus import Collection

        if isinstance(self.col, Collection):
            for x in self.col.indexes:
                if x.field_name == self._vector_field:
                    return x.to_dict()
        return None

    def _create_index(self) -> None:
        """Create an index on the vector field if the collection has none.

        Raises:
            MilvusException: If the index cannot be created.
        """
        from pymilvus import Collection, MilvusException

        if isinstance(self.col, Collection) and self._get_index() is None:
            try:
                # If no index params, use a default HNSW based one
                if self.index_params is None:
                    self.index_params = {
                        "metric_type": "IP",
                        "index_type": "HNSW",
                        "params": {"M": 8, "efConstruction": 64},
                    }

                try:
                    self.col.create_index(
                        self._vector_field,
                        index_params=self.index_params,
                        using=self.alias,
                    )
                # If default did not work, most likely on Zilliz Cloud
                except MilvusException:
                    # Use AUTOINDEX based index
                    self.index_params = {
                        "metric_type": "L2",
                        "index_type": "AUTOINDEX",
                        "params": {},
                    }
                    self.col.create_index(
                        self._vector_field,
                        index_params=self.index_params,
                        using=self.alias,
                    )
                logger.debug(
                    "Successfully created an index on collection: %s",
                    self.collection_name,
                )
            except MilvusException:
                logger.error(
                    "Failed to create an index on collection: %s", self.collection_name
                )
                raise

    def _create_search_params(self) -> None:
        """Generate search params based on the current index type"""
        from pymilvus import Collection

        if isinstance(self.col, Collection) and self.search_params is None:
            index = self._get_index()
            if index is not None:
                index_type: str = index["index_param"]["index_type"]
                metric_type: str = index["index_param"]["metric_type"]
                defaults = self.default_search_params.get(index_type)
                if defaults is None:
                    # NOTE(review): assumes Milvus accepts empty search params
                    # for index types missing from the table - confirm.
                    logger.warning(
                        "No default search params for index type: %s, "
                        "using empty params",
                        index_type,
                    )
                    defaults = {"params": {}}
                # Copy so that the defaults table itself is never mutated.
                self.search_params = {
                    "metric_type": metric_type,
                    "params": dict(defaults["params"]),
                }

    def _load(self) -> None:
        """Load the collection if available."""
        from pymilvus import Collection

        if isinstance(self.col, Collection) and self._get_index() is not None:
            self.col.load()

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        timeout: Optional[int] = None,
        batch_size: int = 1000,
        **kwargs: Any,
    ) -> List[str]:
        """Insert text data into Milvus.

        Inserting data when the collection has not be made yet will result
        in creating a new Collection. The dim of the vector field is taken
        from the first embedding, and a JSON metadata field is created only
        when ``metadatas`` is non-empty on that first insert. Each metadata
        dict is stored as a whole in that JSON field.

        Args:
            texts (Iterable[str]): The texts to embed, it is assumed
                that they all fit in memory.
            metadatas (Optional[List[dict]]): Metadata dicts attached to each of
                the texts. Defaults to None. When omitted for a collection that
                has the metadata field, empty dicts are stored.
            timeout (Optional[int]): Timeout for each batch insert. Defaults
                to None.
            batch_size (int, optional): Batch size to use for insertion.
                Defaults to 1000.
            kwargs: Passed on to Collection.insert().

        Raises:
            MilvusException: Failure to add texts

        Returns:
            List[str]: The resulting keys for each inserted element.
        """
        from pymilvus import Collection, MilvusException

        texts = list(texts)

        try:
            embeddings = self.embedding_func.embed_documents(texts)
        except NotImplementedError:
            embeddings = [self.embedding_func.embed_query(x) for x in texts]

        if len(embeddings) == 0:
            logger.debug("Nothing to insert, skipping.")
            return []

        # If the collection hasn't been initialized yet, perform all steps to do so
        if not isinstance(self.col, Collection):
            self._init(embeddings, metadatas)

        # Dict to hold all insert columns
        insert_dict: dict[str, list] = {
            self._text_field: texts,
            self._vector_field: embeddings,
        }

        # Collect the metadata column.
        has_metadata_field = self._metadata_field in self.fields
        if metadatas:
            if has_metadata_field:
                insert_dict[self._metadata_field] = list(metadatas)
            else:
                logger.warning(
                    "Collection %s has no %s field, metadata is not stored.",
                    self.collection_name,
                    self._metadata_field,
                )
        elif has_metadata_field:
            # Every field of the schema needs a column; without this the
            # lookup below would fail with KeyError.
            insert_dict[self._metadata_field] = [{} for _ in texts]

        # Total insert count
        total_count = len(insert_dict[self._vector_field])

        pks: list[str] = []

        assert isinstance(self.col, Collection)
        for i in range(0, total_count, batch_size):
            # Grab end index
            end = min(i + batch_size, total_count)
            # Convert dict to list of lists batch for insertion, in field order
            insert_list = [insert_dict[x][i:end] for x in self.fields]
            # Insert into the collection.
            try:
                res = self.col.insert(insert_list, timeout=timeout, **kwargs)
                pks.extend(res.primary_keys)
            except MilvusException:
                logger.error(
                    "Failed to insert batch starting at entity: %s/%s", i, total_count
                )
                raise
        return pks

    def similarity_search(
        self,
        query: str,
        k: int = 4,
        param: Optional[dict] = None,
        expr: Optional[str] = None,
        timeout: Optional[int] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Perform a similarity search against the query string.

        Args:
            query (str): The text to search.
            k (int, optional): How many results to return. Defaults to 4.
            param (dict, optional): The search params for the index type.
                Defaults to None.
            expr (str, optional): Filtering expression. Defaults to None.
            timeout (int, optional): How long to wait before timeout error.
                Defaults to None.
            kwargs: Collection.search() keyword arguments.

        Returns:
            List[Document]: Document results for search.
        """
        if self.col is None:
            logger.debug("No existing collection to search.")
            return []
        res = self.similarity_search_with_score(
            query=query, k=k, param=param, expr=expr, timeout=timeout, **kwargs
        )
        return [doc for doc, _ in res]

    def similarity_search_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        param: Optional[dict] = None,
        expr: Optional[str] = None,
        timeout: Optional[int] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Perform a similarity search against an embedding vector.

        Args:
            embedding (List[float]): The embedding vector to search.
            k (int, optional): How many results to return. Defaults to 4.
            param (dict, optional): The search params for the index type.
                Defaults to None.
            expr (str, optional): Filtering expression. Defaults to None.
            timeout (int, optional): How long to wait before timeout error.
                Defaults to None.
            kwargs: Collection.search() keyword arguments.

        Returns:
            List[Document]: Document results for search.
        """
        if self.col is None:
            logger.debug("No existing collection to search.")
            return []
        res = self.similarity_search_with_score_by_vector(
            embedding=embedding, k=k, param=param, expr=expr, timeout=timeout, **kwargs
        )
        return [doc for doc, _ in res]

    def similarity_search_with_score(
        self,
        query: str,
        k: int = 4,
        param: Optional[dict] = None,
        expr: Optional[str] = None,
        timeout: Optional[int] = None,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Perform a search on a query string and return results with score.

        For more information about the search parameters, take a look at the pymilvus
        documentation found here:
        https://milvus.io/api-reference/pymilvus/v2.2.6/Collection/search().md

        Args:
            query (str): The text being searched.
            k (int, optional): The amount of results to return. Defaults to 4.
            param (dict): The search params for the specified index.
                Defaults to None.
            expr (str, optional): Filtering expression. Defaults to None.
            timeout (int, optional): How long to wait before timeout error.
                Defaults to None.
            kwargs: Collection.search() keyword arguments.

        Returns:
            List[Tuple[Document, float]]: Result doc and score.
        """
        if self.col is None:
            logger.debug("No existing collection to search.")
            return []

        # Embed the query text.
        embedding = self.embedding_func.embed_query(query)

        return self.similarity_search_with_score_by_vector(
            embedding=embedding, k=k, param=param, expr=expr, timeout=timeout, **kwargs
        )

    def _similarity_search_with_relevance_scores(
        self,
        query: str,
        k: int = 4,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Return docs and scores for the query.

        NOTE(review): the scores are the ones returned by
        similarity_search_with_score and are passed through unchanged; nothing
        here rescales them into [0, 1] - confirm against the metric in use.

        Args:
            query: input text
            k: Number of Documents to return. Defaults to 4.
            **kwargs: kwargs to be passed to similarity search.

        Returns:
            List of Tuples of (doc, similarity_score)
        """
        return self.similarity_search_with_score(query, k, **kwargs)

    def _search_output_fields(self) -> list[str]:
        """Return the fields to fetch with a search: all except the vector."""
        return [name for name in self.fields if name != self._vector_field]

    def _result_to_document(self, result: Any, output_fields: list[str]) -> Document:
        """Build a Document from one search hit.

        The text field becomes ``page_content``. When the JSON metadata field
        was fetched, its value is the metadata (an empty dict if it holds
        nothing); otherwise the remaining fetched fields are the metadata.
        """
        data = {name: result.entity.get(name) for name in output_fields}
        page_content = data.pop(self._text_field)
        if self._metadata_field in data:
            metadata = data[self._metadata_field] or {}
        else:
            metadata = data
        return Document(page_content=page_content, metadata=metadata)

    def similarity_search_with_score_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        param: Optional[dict] = None,
        expr: Optional[str] = None,
        timeout: Optional[int] = None,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Perform a search on an embedding and return results with score.

        For more information about the search parameters, take a look at the pymilvus
        documentation found here:
        https://milvus.io/api-reference/pymilvus/v2.2.6/Collection/search().md

        Args:
            embedding (List[float]): The embedding vector being searched.
            k (int, optional): The amount of results to return. Defaults to 4.
            param (dict): The search params for the specified index.
                Defaults to None.
            expr (str, optional): Filtering expression. Defaults to None.
            timeout (int, optional): How long to wait before timeout error.
                Defaults to None.
            kwargs: Collection.search() keyword arguments.

        Returns:
            List[Tuple[Document, float]]: Result doc and score.
        """
        if self.col is None:
            logger.debug("No existing collection to search.")
            return []

        if param is None:
            param = self.search_params

        # Determine result metadata fields.
        output_fields = self._search_output_fields()

        # Perform the search.
        res = self.col.search(
            data=[embedding],
            anns_field=self._vector_field,
            param=param,
            limit=k,
            expr=expr,
            output_fields=output_fields,
            timeout=timeout,
            **kwargs,
        )

        # Organize results: a single query vector was sent, so use res[0].
        return [
            (self._result_to_document(result, output_fields), result.score)
            for result in res[0]
        ]

    def max_marginal_relevance_search(
        self,
        query: str,
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        param: Optional[dict] = None,
        expr: Optional[str] = None,
        timeout: Optional[int] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Perform a search and return results that are reordered by MMR.

        Args:
            query (str): The text being searched.
            k (int, optional): How many results to give. Defaults to 4.
            fetch_k (int, optional): Total results to select k from.
                Defaults to 20.
            lambda_mult: Number between 0 and 1 that determines the degree
                of diversity among the results with 0 corresponding
                to maximum diversity and 1 to minimum diversity.
                Defaults to 0.5
            param (dict, optional): The search params for the specified index.
                Defaults to None.
            expr (str, optional): Filtering expression. Defaults to None.
            timeout (int, optional): How long to wait before timeout error.
                Defaults to None.
            kwargs: Collection.search() keyword arguments.

        Returns:
            List[Document]: Document results for search.
        """
        if self.col is None:
            logger.debug("No existing collection to search.")
            return []

        embedding = self.embedding_func.embed_query(query)

        return self.max_marginal_relevance_search_by_vector(
            embedding=embedding,
            k=k,
            fetch_k=fetch_k,
            lambda_mult=lambda_mult,
            param=param,
            expr=expr,
            timeout=timeout,
            **kwargs,
        )

    def max_marginal_relevance_search_by_vector(
        self,
        embedding: list[float],
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        param: Optional[dict] = None,
        expr: Optional[str] = None,
        timeout: Optional[int] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Perform a search and return results that are reordered by MMR.

        Args:
            embedding (list[float]): The embedding vector being searched.
            k (int, optional): How many results to give. Defaults to 4.
            fetch_k (int, optional): Total results to select k from.
                Defaults to 20.
            lambda_mult: Number between 0 and 1 that determines the degree
                of diversity among the results with 0 corresponding
                to maximum diversity and 1 to minimum diversity.
                Defaults to 0.5
            param (dict, optional): The search params for the specified index.
                Defaults to None.
            expr (str, optional): Filtering expression. Defaults to None.
            timeout (int, optional): How long to wait before timeout error.
                Defaults to None.
            kwargs: Collection.search() keyword arguments.

        Returns:
            List[Document]: Document results for search.
        """
        if self.col is None:
            logger.debug("No existing collection to search.")
            return []

        if param is None:
            param = self.search_params

        # Determine result metadata fields.
        output_fields = self._search_output_fields()

        # Perform the search.
        res = self.col.search(
            data=[embedding],
            anns_field=self._vector_field,
            param=param,
            limit=fetch_k,
            expr=expr,
            output_fields=output_fields,
            timeout=timeout,
            **kwargs,
        )

        # Organize results.
        hits = list(res[0])
        if not hits:
            # Nothing to reorder; also avoids querying with an empty id list.
            return []
        ids = [hit.id for hit in hits]
        documents = [self._result_to_document(hit, output_fields) for hit in hits]

        # The search does not return the stored vectors, fetch them by id.
        rows = self.col.query(
            expr=f"{self._primary_field} in {ids}",
            output_fields=[self._primary_field, self._vector_field],
            timeout=timeout,
        )

        # Reorganize the results from query to match search order.
        vector_by_id = {
            row[self._primary_field]: row[self._vector_field] for row in rows
        }
        ordered_result_embeddings = [vector_by_id[x] for x in ids]

        # Get the new order of results.
        new_ordering = maximal_marginal_relevance(
            np.array(embedding), ordered_result_embeddings, k=k, lambda_mult=lambda_mult
        )

        # Reorder the values and return.
        ret = []
        for x in new_ordering:
            # Function can return -1 index
            if x == -1:
                break
            ret.append(documents[x])
        return ret

    @classmethod
    def from_texts(
        cls,
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        collection_name: str = "LangChainCollection",
        connection_args: dict[str, Any] = DEFAULT_MILVUS_CONNECTION,
        consistency_level: str = "Session",
        index_params: Optional[dict] = None,
        search_params: Optional[dict] = None,
        drop_old: bool = False,
        batch_size: int = 100,
        ids: Optional[Sequence[str]] = None,
        **kwargs: Any,
    ) -> Milvus:
        """Create a Milvus collection, indexes it with HNSW, and insert data.

        Args:
            texts (List[str]): Text data.
            embedding (Embeddings): Embedding function.
            metadatas (Optional[List[dict]]): Metadata for each text if it exists.
                Defaults to None.
            collection_name (str, optional): Collection name to use. Defaults to
                "LangChainCollection".
            connection_args (dict[str, Any], optional): Connection args to use. Defaults
                to DEFAULT_MILVUS_CONNECTION.
            consistency_level (str, optional): Which consistency level to use. Defaults
                to "Session".
            index_params (Optional[dict], optional): Which index_params to use. Defaults
                to None.
            search_params (Optional[dict], optional): Which search params to use.
                Defaults to None.
            drop_old (Optional[bool], optional): Whether to drop the collection with
                that name if it exists. Defaults to False.
            batch_size (int, optional): How many vectors to upload per request.
                Defaults to 100.
            ids (Optional[Sequence[str]]): Accepted but currently not used; the
                primary keys are generated by Milvus. Defaults to None.
            kwargs: Forwarded to the class constructor.

        Returns:
            Milvus: Milvus Vector Store
        """
        vector_db = cls(
            embedding_function=embedding,
            collection_name=collection_name,
            connection_args=connection_args,
            consistency_level=consistency_level,
            index_params=index_params,
            search_params=search_params,
            drop_old=drop_old,
            **kwargs,
        )
        vector_db.add_texts(texts=texts, metadatas=metadatas, batch_size=batch_size)
        return vector_db
|