Browse Source

fix: errors occrus during rebasing

Yeuoly 4 months ago
parent
commit
e231cf2c48

+ 2 - 2
api/core/model_manager.py

@@ -191,7 +191,7 @@ class ModelInstance:
 
         self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
         return cast(
-            int,
+            list[int],
             self._round_robin_invoke(
                 function=self.model_type_instance.get_num_tokens,
                 model=self.model,
@@ -240,7 +240,7 @@ class ModelInstance:
 
         self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
         return cast(
-            int,
+            list[int],
             self._round_robin_invoke(
                 function=self.model_type_instance.get_num_tokens,
                 model=self.model,

+ 8 - 6
api/core/provider_manager.py

@@ -1,7 +1,7 @@
 import json
 from collections import defaultdict
 from json import JSONDecodeError
-from typing import Optional, cast
+from typing import Any, Optional, cast
 
 from sqlalchemy.exc import IntegrityError
 
@@ -350,7 +350,7 @@ class ProviderManager:
         :param tenant_id: workspace id
         :return:
         """
-        providers = db.session.query(Provider).filter(Provider.tenant_id == tenant_id, Provider.is_valid is True).all()
+        providers = db.session.query(Provider).filter(Provider.tenant_id == tenant_id, Provider.is_valid == True).all()  # noqa
 
         provider_name_to_provider_records_dict = defaultdict(list)
         for provider in providers:
@@ -369,7 +369,7 @@ class ProviderManager:
         # Get all provider model records of the workspace
         provider_models = (
             db.session.query(ProviderModel)
-            .filter(ProviderModel.tenant_id == tenant_id, ProviderModel.is_valid is True)
+            .filter(ProviderModel.tenant_id == tenant_id, ProviderModel.is_valid == True)  # noqa
             .all()
         )
 
@@ -739,9 +739,9 @@ class ProviderManager:
 
                 if not cached_provider_credentials:
                     try:
-                        provider_credentials = json.loads(provider_record.encrypted_config)
+                        provider_credentials: dict[str, Any] = json.loads(provider_record.encrypted_config)
                     except JSONDecodeError:
-                        provider_credentials = {}
+                        provider_credentials: dict[str, Any] = {}
 
                     # Get provider credential secret variables
                     provider_credential_secret_variables = self._extract_secret_variables(
@@ -758,7 +758,9 @@ class ProviderManager:
                         if variable in provider_credentials:
                             try:
                                 provider_credentials[variable] = encrypter.decrypt_token_with_decoding(
-                                    provider_credentials.get(variable), self.decoding_rsa_key, self.decoding_cipher_rsa
+                                    provider_credentials.get(variable, ""),
+                                    self.decoding_rsa_key,
+                                    self.decoding_cipher_rsa,
                                 )
                             except ValueError:
                                 pass

+ 3 - 3
api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py

@@ -88,7 +88,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
             DocumentSegment.dataset_id.in_(self.dataset_ids),
             DocumentSegment.completed_at.isnot(None),
             DocumentSegment.status == "completed",
-            DocumentSegment.enabled is True,
+            DocumentSegment.enabled == True,  # noqa
             DocumentSegment.index_node_id.in_(index_node_ids),
         ).all()
 
@@ -109,8 +109,8 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
                     dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
                     document = Document.query.filter(
                         Document.id == segment.document_id,
-                        Document.enabled is True,
-                        Document.archived is False,
+                        Document.enabled == True,  # noqa
+                        Document.archived == False,  # noqa
                     ).first()
                     if dataset and document:
                         source = {

+ 1 - 2
api/services/entities/model_provider_entities.py

@@ -7,7 +7,6 @@ from configs import dify_config
 from core.entities.model_entities import (
     ModelWithProviderEntity,
     ProviderModelWithStatusEntity,
-    SimpleModelProviderEntity,
 )
 from core.entities.provider_entities import ProviderQuotaType, QuotaConfiguration
 from core.model_runtime.entities.common_entities import I18nObject
@@ -162,7 +161,7 @@ class ModelWithProviderEntityResponse(ModelWithProviderEntity):
     Model with provider entity.
     """
 
-    provider: SimpleModelProviderEntity
+    provider: SimpleProviderEntityResponse
 
     def __init__(self, tenant_id: str, model: ModelWithProviderEntity) -> None:
         dump_model = model.model_dump()