Browse Source

feat: add WEAVIATE_BATCH_SIZE (#349)

John Wang 1 year ago
parent
commit
cd136fb293

+ 1 - 0
api/.env.example

@@ -72,6 +72,7 @@ VECTOR_STORE=weaviate
 WEAVIATE_ENDPOINT=http://localhost:8080
 WEAVIATE_API_KEY=WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih
 WEAVIATE_GRPC_ENABLED=false
+WEAVIATE_BATCH_SIZE=100
 
 # Qdrant configuration, use `path:` prefix for local mode or `https://your-qdrant-cluster-url.qdrant.io` for remote mode
 QDRANT_URL=path:storage/qdrant

+ 2 - 0
api/config.py

@@ -43,6 +43,7 @@ DEFAULTS = {
     'SENTRY_TRACES_SAMPLE_RATE': 1.0,
     'SENTRY_PROFILES_SAMPLE_RATE': 1.0,
     'WEAVIATE_GRPC_ENABLED': 'True',
+    'WEAVIATE_BATCH_SIZE': 100,
     'CELERY_BACKEND': 'database',
     'PDF_PREVIEW': 'True',
     'LOG_LEVEL': 'INFO',
@@ -138,6 +139,7 @@ class Config:
         self.WEAVIATE_ENDPOINT = get_env('WEAVIATE_ENDPOINT')
         self.WEAVIATE_API_KEY = get_env('WEAVIATE_API_KEY')
         self.WEAVIATE_GRPC_ENABLED = get_bool_env('WEAVIATE_GRPC_ENABLED')
+        self.WEAVIATE_BATCH_SIZE = int(get_env('WEAVIATE_BATCH_SIZE'))
 
         # qdrant settings
         self.QDRANT_URL = get_env('QDRANT_URL')

+ 2 - 1
api/core/vector_store/vector_store.py

@@ -27,7 +27,8 @@ class VectorStore:
             self._client = WeaviateVectorStoreClient(
                 endpoint=app.config['WEAVIATE_ENDPOINT'],
                 api_key=app.config['WEAVIATE_API_KEY'],
-                grpc_enabled=app.config['WEAVIATE_GRPC_ENABLED']
+                grpc_enabled=app.config['WEAVIATE_GRPC_ENABLED'],
+                batch_size=app.config['WEAVIATE_BATCH_SIZE']
             )
         elif self._vector_store == 'qdrant':
             self._client = QdrantVectorStoreClient(

+ 4 - 4
api/core/vector_store/weaviate_vector_store_client.py

@@ -18,10 +18,10 @@ from llama_index.readers.weaviate.utils import (
 
 class WeaviateVectorStoreClient(BaseVectorStoreClient):
 
-    def __init__(self, endpoint: str, api_key: str, grpc_enabled: bool):
-        self._client = self.init_from_config(endpoint, api_key, grpc_enabled)
+    def __init__(self, endpoint: str, api_key: str, grpc_enabled: bool, batch_size: int):
+        self._client = self.init_from_config(endpoint, api_key, grpc_enabled, batch_size)
 
-    def init_from_config(self, endpoint: str, api_key: str, grpc_enabled: bool):
+    def init_from_config(self, endpoint: str, api_key: str, grpc_enabled: bool, batch_size: int):
         auth_config = weaviate.auth.AuthApiKey(api_key=api_key)
 
         weaviate.connect.connection.has_grpc = grpc_enabled
@@ -36,7 +36,7 @@ class WeaviateVectorStoreClient(BaseVectorStoreClient):
         client.batch.configure(
             # `batch_size` takes an `int` value to enable auto-batching
             # (`None` is used for manual batching)
-            batch_size=100,
+            batch_size=batch_size,
             # dynamically update the `batch_size` based on import speed
             dynamic=True,
             # `timeout_retries` takes an `int` value to retry on time outs