Ver código fonte

fix: mypy issues

Yeuoly 3 meses atrás
pai
commit
f748d6c7c4
49 arquivos alterados com 157 adições e 133 exclusões
  1. 3 2
      api/core/app/apps/agent_chat/generate_response_converter.py
  2. 1 1
      api/core/app/apps/base_app_generate_response_converter.py
  3. 3 2
      api/core/app/apps/chat/generate_response_converter.py
  4. 3 2
      api/core/app/apps/completion/generate_response_converter.py
  5. 3 3
      api/core/app/apps/workflow/app_generator.py
  6. 3 2
      api/core/app/apps/workflow/generate_response_converter.py
  7. 5 5
      api/core/entities/provider_configuration.py
  8. 1 1
      api/core/file/upload_file_parser.py
  9. 8 8
      api/core/llm_generator/llm_generator.py
  10. 2 0
      api/core/model_runtime/model_providers/__base/large_language_model.py
  11. 6 6
      api/core/model_runtime/model_providers/model_provider_factory.py
  12. 3 3
      api/core/plugin/manager/base.py
  13. 1 1
      api/core/provider_manager.py
  14. 3 0
      api/core/rag/retrieval/dataset_retrieval.py
  15. 5 6
      api/core/rag/splitter/fixed_text_splitter.py
  16. 6 4
      api/core/rag/splitter/text_splitter.py
  17. 2 2
      api/core/tools/__base/tool.py
  18. 2 2
      api/core/tools/builtin_tool/provider.py
  19. 2 1
      api/core/tools/builtin_tool/providers/audio/audio.py
  20. 3 3
      api/core/tools/builtin_tool/providers/time/tools/localtime_to_timestamp.py
  21. 1 1
      api/core/tools/builtin_tool/providers/time/tools/timestamp_to_localtime.py
  22. 1 1
      api/core/tools/builtin_tool/providers/time/tools/timezone_conversion.py
  23. 1 1
      api/core/tools/builtin_tool/providers/webscraper/webscraper.py
  24. 1 1
      api/core/tools/custom_tool/provider.py
  25. 2 2
      api/core/tools/plugin_tool/provider.py
  26. 9 1
      api/core/tools/plugin_tool/tool.py
  27. 23 20
      api/core/tools/tool_manager.py
  28. 1 1
      api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py
  29. 15 3
      api/core/tools/utils/dataset_retriever_tool.py
  30. 1 1
      api/core/tools/utils/message_transformer.py
  31. 1 1
      api/core/tools/utils/workflow_configuration_sync.py
  32. 2 3
      api/core/tools/workflow_as_tool/provider.py
  33. 4 3
      api/core/tools/workflow_as_tool/tool.py
  34. 3 3
      api/core/workflow/nodes/agent/agent_node.py
  35. 2 2
      api/core/workflow/nodes/llm/node.py
  36. 2 2
      api/core/workflow/nodes/tool/tool_node.py
  37. 0 2
      api/core/workflow/workflow_entry.py
  38. 1 1
      api/libs/helper.py
  39. 1 1
      api/libs/login.py
  40. 2 2
      api/models/account.py
  41. 1 1
      api/models/model.py
  42. 2 10
      api/models/tools.py
  43. 1 1
      api/services/agent_service.py
  44. 1 1
      api/services/entities/model_provider_entities.py
  45. 1 2
      api/services/plugin/plugin_migration.py
  46. 1 1
      api/services/tools/api_tools_manage_service.py
  47. 2 2
      api/services/tools/tools_transform_service.py
  48. 9 8
      api/services/tools/workflow_tools_manage_service.py
  49. 1 1
      api/tasks/batch_create_segment_to_index_task.py

+ 3 - 2
api/core/app/apps/agent_chat/generate_response_converter.py

@@ -3,6 +3,7 @@ from typing import cast
 
 from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
 from core.app.entities.task_entities import (
+    AppStreamResponse,
     ChatbotAppBlockingResponse,
     ChatbotAppStreamResponse,
     ErrorStreamResponse,
@@ -51,7 +52,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
 
     @classmethod
     def convert_stream_full_response(
-        cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
+        cls, stream_response: Generator[AppStreamResponse, None, None]
     ) -> Generator[dict | str, None, None]:
         """
         Convert stream full response.
@@ -82,7 +83,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
 
     @classmethod
     def convert_stream_simple_response(
-        cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
+        cls, stream_response: Generator[AppStreamResponse, None, None]
     ) -> Generator[dict | str, None, None]:
         """
         Convert stream simple response.

+ 1 - 1
api/core/app/apps/base_app_generate_response_converter.py

@@ -56,7 +56,7 @@ class AppGenerateResponseConverter(ABC):
     @abstractmethod
     def convert_stream_simple_response(
         cls, stream_response: Generator[AppStreamResponse, None, None]
-    ) -> Generator[str, None, None]:
+    ) -> Generator[dict | str, None, None]:
         raise NotImplementedError
 
     @classmethod

+ 3 - 2
api/core/app/apps/chat/generate_response_converter.py

@@ -3,6 +3,7 @@ from typing import cast
 
 from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
 from core.app.entities.task_entities import (
+    AppStreamResponse,
     ChatbotAppBlockingResponse,
     ChatbotAppStreamResponse,
     ErrorStreamResponse,
@@ -51,7 +52,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
 
     @classmethod
     def convert_stream_full_response(
-        cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
+        cls, stream_response: Generator[AppStreamResponse, None, None]
     ) -> Generator[dict | str, None, None]:
         """
         Convert stream full response.
@@ -82,7 +83,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
 
     @classmethod
     def convert_stream_simple_response(
-        cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
+        cls, stream_response: Generator[AppStreamResponse, None, None]
     ) -> Generator[dict | str, None, None]:
         """
         Convert stream simple response.

+ 3 - 2
api/core/app/apps/completion/generate_response_converter.py

@@ -3,6 +3,7 @@ from typing import cast
 
 from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
 from core.app.entities.task_entities import (
+    AppStreamResponse,
     CompletionAppBlockingResponse,
     CompletionAppStreamResponse,
     ErrorStreamResponse,
@@ -50,7 +51,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
 
     @classmethod
     def convert_stream_full_response(
-        cls, stream_response: Generator[CompletionAppStreamResponse, None, None]
+        cls, stream_response: Generator[AppStreamResponse, None, None]
     ) -> Generator[dict | str, None, None]:
         """
         Convert stream full response.
@@ -80,7 +81,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
 
     @classmethod
     def convert_stream_simple_response(
-        cls, stream_response: Generator[CompletionAppStreamResponse, None, None]
+        cls, stream_response: Generator[AppStreamResponse, None, None]
     ) -> Generator[dict | str, None, None]:
         """
         Convert stream simple response.

+ 3 - 3
api/core/app/apps/workflow/app_generator.py

@@ -149,7 +149,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
         invoke_from: InvokeFrom,
         streaming: bool = True,
         workflow_thread_pool_id: Optional[str] = None,
-    ) -> Union[dict, Generator[str | dict, None, None]]:
+    ) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
         """
         Generate App response.
 
@@ -200,9 +200,9 @@ class WorkflowAppGenerator(BaseAppGenerator):
         workflow: Workflow,
         node_id: str,
         user: Account | EndUser,
-        args: dict,
+        args: Mapping[str, Any],
         streaming: bool = True,
-    ) -> dict[str, Any] | Generator[str | dict, Any, None]:
+    ) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
         """
         Generate App response.
 

+ 3 - 2
api/core/app/apps/workflow/generate_response_converter.py

@@ -3,6 +3,7 @@ from typing import cast
 
 from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
 from core.app.entities.task_entities import (
+    AppStreamResponse,
     ErrorStreamResponse,
     NodeFinishStreamResponse,
     NodeStartStreamResponse,
@@ -35,7 +36,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
 
     @classmethod
     def convert_stream_full_response(
-        cls, stream_response: Generator[WorkflowAppStreamResponse, None, None]
+        cls, stream_response: Generator[AppStreamResponse, None, None]
     ) -> Generator[dict | str, None, None]:
         """
         Convert stream full response.
@@ -64,7 +65,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
 
     @classmethod
     def convert_stream_simple_response(
-        cls, stream_response: Generator[WorkflowAppStreamResponse, None, None]
+        cls, stream_response: Generator[AppStreamResponse, None, None]
     ) -> Generator[dict | str, None, None]:
         """
         Convert stream simple response.

+ 5 - 5
api/core/entities/provider_configuration.py

@@ -157,7 +157,7 @@ class ProviderConfiguration(BaseModel):
         """
         return self.custom_configuration.provider is not None or len(self.custom_configuration.models) > 0
 
-    def get_custom_credentials(self, obfuscated: bool = False):
+    def get_custom_credentials(self, obfuscated: bool = False) -> dict | None:
         """
         Get custom credentials.
 
@@ -741,11 +741,11 @@ class ProviderConfiguration(BaseModel):
         model_provider_factory = ModelProviderFactory(self.tenant_id)
         provider_schema = model_provider_factory.get_provider_schema(self.provider.provider)
 
-        model_types = []
+        model_types: list[ModelType] = []
         if model_type:
             model_types.append(model_type)
         else:
-            model_types = provider_schema.supported_model_types
+            model_types = list(provider_schema.supported_model_types)
 
         # Group model settings by model type and model
         model_setting_map: defaultdict[ModelType, dict[str, ModelSettings]] = defaultdict(dict)
@@ -1065,11 +1065,11 @@ class ProviderConfigurations(BaseModel):
     def values(self) -> Iterator[ProviderConfiguration]:
         return iter(self.configurations.values())
 
-    def get(self, key, default=None):
+    def get(self, key, default=None) -> ProviderConfiguration | None:
         if "/" not in key:
             key = f"{DEFAULT_PLUGIN_ID}/{key}/{key}"
 
-        return self.configurations.get(key, default)
+        return self.configurations.get(key, default)  # type: ignore
 
 
 class ProviderModelBundle(BaseModel):

+ 1 - 1
api/core/file/upload_file_parser.py

@@ -20,7 +20,7 @@ class UploadFileParser:
         if upload_file.extension not in IMAGE_EXTENSIONS:
             return None
 
-        if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url" or force_url:
+        if dify_config.MULTIMODAL_SEND_FORMAT == "url" or force_url:
             return cls.get_signed_temp_image_url(upload_file.id)
         else:
             # get image file base64

+ 8 - 8
api/core/llm_generator/llm_generator.py

@@ -48,7 +48,7 @@ class LLMGenerator:
             response = cast(
                 LLMResult,
                 model_instance.invoke_llm(
-                    prompt_messages=prompts, model_parameters={"max_tokens": 100, "temperature": 1}, stream=False
+                    prompt_messages=list(prompts), model_parameters={"max_tokens": 100, "temperature": 1}, stream=False
                 ),
             )
         answer = cast(str, response.message.content)
@@ -101,7 +101,7 @@ class LLMGenerator:
             response = cast(
                 LLMResult,
                 model_instance.invoke_llm(
-                    prompt_messages=prompt_messages,
+                    prompt_messages=list(prompt_messages),
                     model_parameters={"max_tokens": 256, "temperature": 0},
                     stream=False,
                 ),
@@ -110,7 +110,7 @@ class LLMGenerator:
             questions = output_parser.parse(cast(str, response.message.content))
         except InvokeError:
             questions = []
-        except Exception as e:
+        except Exception:
             logging.exception("Failed to generate suggested questions after answer")
             questions = []
 
@@ -150,7 +150,7 @@ class LLMGenerator:
                 response = cast(
                     LLMResult,
                     model_instance.invoke_llm(
-                        prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False
+                        prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
                     ),
                 )
 
@@ -200,7 +200,7 @@ class LLMGenerator:
                 prompt_content = cast(
                     LLMResult,
                     model_instance.invoke_llm(
-                        prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False
+                        prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
                     ),
                 )
             except InvokeError as e:
@@ -236,7 +236,7 @@ class LLMGenerator:
                 parameter_content = cast(
                     LLMResult,
                     model_instance.invoke_llm(
-                        prompt_messages=parameter_messages, model_parameters=model_parameters, stream=False
+                        prompt_messages=list(parameter_messages), model_parameters=model_parameters, stream=False
                     ),
                 )
                 rule_config["variables"] = re.findall(r'"\s*([^"]+)\s*"', cast(str, parameter_content.message.content))
@@ -248,7 +248,7 @@ class LLMGenerator:
                 statement_content = cast(
                     LLMResult,
                     model_instance.invoke_llm(
-                        prompt_messages=statement_messages, model_parameters=model_parameters, stream=False
+                        prompt_messages=list(statement_messages), model_parameters=model_parameters, stream=False
                     ),
                 )
                 rule_config["opening_statement"] = cast(str, statement_content.message.content)
@@ -301,7 +301,7 @@ class LLMGenerator:
             response = cast(
                 LLMResult,
                 model_instance.invoke_llm(
-                    prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False
+                    prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
                 ),
             )
 

+ 2 - 0
api/core/model_runtime/model_providers/__base/large_language_model.py

@@ -84,6 +84,8 @@ class LargeLanguageModel(AIModel):
             callbacks=callbacks,
         )
 
+        result: Union[LLMResult, Generator[LLMResultChunk, None, None]]
+
         try:
             plugin_model_manager = PluginModelManager()
             result = plugin_model_manager.invoke_llm(

+ 6 - 6
api/core/model_runtime/model_providers/model_provider_factory.py

@@ -285,17 +285,17 @@ class ModelProviderFactory:
         }
 
         if model_type == ModelType.LLM:
-            return LargeLanguageModel(**init_params)
+            return LargeLanguageModel(**init_params)  # type: ignore
         elif model_type == ModelType.TEXT_EMBEDDING:
-            return TextEmbeddingModel(**init_params)
+            return TextEmbeddingModel(**init_params)  # type: ignore
         elif model_type == ModelType.RERANK:
-            return RerankModel(**init_params)
+            return RerankModel(**init_params)  # type: ignore
         elif model_type == ModelType.SPEECH2TEXT:
-            return Speech2TextModel(**init_params)
+            return Speech2TextModel(**init_params)  # type: ignore
         elif model_type == ModelType.MODERATION:
-            return ModerationModel(**init_params)
+            return ModerationModel(**init_params)  # type: ignore
         elif model_type == ModelType.TTS:
-            return TTSModel(**init_params)
+            return TTSModel(**init_params)  # type: ignore
 
     def get_provider_icon(self, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]:
         """

+ 3 - 3
api/core/plugin/manager/base.py

@@ -119,7 +119,7 @@ class BasePluginManager:
         Make a request to the plugin daemon inner API and return the response as a model.
         """
         response = self._request(method, path, headers, data, params, files)
-        return type(**response.json())
+        return type(**response.json())  # type: ignore
 
     def _request_with_plugin_daemon_response(
         self,
@@ -140,7 +140,7 @@ class BasePluginManager:
         if transformer:
             json_response = transformer(json_response)
 
-        rep = PluginDaemonBasicResponse[type](**json_response)
+        rep = PluginDaemonBasicResponse[type](**json_response)  # type: ignore
         if rep.code != 0:
             try:
                 error = PluginDaemonError(**json.loads(rep.message))
@@ -171,7 +171,7 @@ class BasePluginManager:
             line_data = None
             try:
                 line_data = json.loads(line)
-                rep = PluginDaemonBasicResponse[type](**line_data)
+                rep = PluginDaemonBasicResponse[type](**line_data)  # type: ignore
             except Exception:
                 # TODO modify this when line_data has code and message
                 if line_data and "error" in line_data:

+ 1 - 1
api/core/provider_manager.py

@@ -742,7 +742,7 @@ class ProviderManager:
                     try:
                         provider_credentials: dict[str, Any] = json.loads(provider_record.encrypted_config)
                     except JSONDecodeError:
-                        provider_credentials: dict[str, Any] = {}
+                        provider_credentials = {}
 
                     # Get provider credential secret variables
                     provider_credential_secret_variables = self._extract_secret_variables(

+ 3 - 0
api/core/rag/retrieval/dataset_retrieval.py

@@ -601,6 +601,9 @@ class DatasetRetrieval:
         elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
             from core.tools.utils.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
 
+            if retrieve_config.reranking_model is None:
+                raise ValueError("Reranking model is required for multiple retrieval")
+
             tool = DatasetMultiRetrieverTool.from_dataset(
                 dataset_ids=[dataset.id for dataset in available_datasets],
                 tenant_id=tenant_id,

+ 5 - 6
api/core/rag/splitter/fixed_text_splitter.py

@@ -30,14 +30,14 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
         disallowed_special: Union[Literal["all"], Collection[str]] = "all",  # noqa: UP037
         **kwargs: Any,
     ):
-        def _token_encoder(text: str) -> int:
-            if not text:
-                return 0
+        def _token_encoder(texts: list[str]) -> list[int]:
+            if not texts:
+                return []
 
             if embedding_model_instance:
-                return embedding_model_instance.get_text_embedding_num_tokens(texts=[text])
+                return embedding_model_instance.get_text_embedding_num_tokens(texts=texts)
             else:
-                return GPT2Tokenizer.get_num_tokens(text)
+                return [GPT2Tokenizer.get_num_tokens(text) for text in texts]
 
         if issubclass(cls, TokenTextSplitter):
             extra_kwargs = {
@@ -96,7 +96,6 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
         _good_splits_lengths = []  # cache the lengths of the splits
         s_lens = self._length_function(splits)
         for s, s_len in zip(splits, s_lens):
-            s_len = self._length_function(s)
             if s_len < self._chunk_size:
                 _good_splits.append(s)
                 _good_splits_lengths.append(s_len)

+ 6 - 4
api/core/rag/splitter/text_splitter.py

@@ -106,7 +106,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
     def _merge_splits(self, splits: Iterable[str], separator: str, lengths: list[int]) -> list[str]:
         # We now want to combine these smaller pieces into medium size
         # chunks to send to the LLM.
-        separator_len = self._length_function(separator)
+        separator_len = self._length_function([separator])[0]
 
         docs = []
         current_doc: list[str] = []
@@ -129,7 +129,9 @@ class TextSplitter(BaseDocumentTransformer, ABC):
                     while total > self._chunk_overlap or (
                         total + _len + (separator_len if len(current_doc) > 0 else 0) > self._chunk_size and total > 0
                     ):
-                        total -= self._length_function(current_doc[0]) + (separator_len if len(current_doc) > 1 else 0)
+                        total -= self._length_function([current_doc[0]])[0] + (
+                            separator_len if len(current_doc) > 1 else 0
+                        )
                         current_doc = current_doc[1:]
             current_doc.append(d)
             total += _len + (separator_len if len(current_doc) > 1 else 0)
@@ -155,7 +157,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
             raise ValueError(
                 "Could not import transformers python package. Please install it with `pip install transformers`."
             )
-        return cls(length_function=_huggingface_tokenizer_length, **kwargs)
+        return cls(length_function=lambda x: [_huggingface_tokenizer_length(text) for text in x], **kwargs)
 
     @classmethod
     def from_tiktoken_encoder(
@@ -199,7 +201,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
             }
             kwargs = {**kwargs, **extra_kwargs}
 
-        return cls(length_function=_tiktoken_encoder, **kwargs)
+        return cls(length_function=lambda x: [_tiktoken_encoder(text) for text in x], **kwargs)
 
     def transform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]:
         """Transform sequence of documents by splitting them."""

+ 2 - 2
api/core/tools/__base/tool.py

@@ -71,13 +71,13 @@ class Tool(ABC):
 
         if isinstance(result, ToolInvokeMessage):
 
-            def single_generator():
+            def single_generator() -> Generator[ToolInvokeMessage, None, None]:
                 yield result
 
             return single_generator()
         elif isinstance(result, list):
 
-            def generator():
+            def generator() -> Generator[ToolInvokeMessage, None, None]:
                 yield from result
 
             return generator()

+ 2 - 2
api/core/tools/builtin_tool/provider.py

@@ -109,11 +109,11 @@ class BuiltinToolProviderController(ToolProviderController):
         """
         return self._get_builtin_tools()
 
-    def get_tool(self, tool_name: str) -> BuiltinTool | None:
+    def get_tool(self, tool_name: str) -> BuiltinTool | None:  # type: ignore
         """
         returns the tool that the provider can provide
         """
-        return next(filter(lambda x: x.entity.identity.name == tool_name, self.get_tools()), None)
+        return next(filter(lambda x: x.entity.identity.name == tool_name, self.get_tools()), None)  # type: ignore
 
     @property
     def need_credentials(self) -> bool:

+ 2 - 1
api/core/tools/builtin_tool/providers/audio/audio.py

@@ -1,6 +1,7 @@
+from typing import Any
 from core.tools.builtin_tool.provider import BuiltinToolProviderController
 
 
 class AudioToolProvider(BuiltinToolProviderController):
-    def _validate_credentials(self, credentials: dict) -> None:
+    def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
         pass

+ 3 - 3
api/core/tools/builtin_tool/providers/time/tools/localtime_to_timestamp.py

@@ -27,7 +27,7 @@ class LocaltimeToTimestampTool(BuiltinTool):
             timezone = None
         time_format = "%Y-%m-%d %H:%M:%S"
 
-        timestamp = self.localtime_to_timestamp(localtime, time_format, timezone)
+        timestamp = self.localtime_to_timestamp(localtime, time_format, timezone)  # type: ignore
         if not timestamp:
             yield self.create_text_message(f"Invalid localtime: {localtime}")
             return
@@ -42,8 +42,8 @@ class LocaltimeToTimestampTool(BuiltinTool):
             if isinstance(local_tz, str):
                 local_tz = pytz.timezone(local_tz)
             local_time = datetime.strptime(localtime, time_format)
-            localtime = local_tz.localize(local_time)
-            timestamp = int(localtime.timestamp())
+            localtime = local_tz.localize(local_time)  # type: ignore
+            timestamp = int(localtime.timestamp())  # type: ignore
             return timestamp
         except Exception as e:
             raise ToolInvokeError(str(e))

+ 1 - 1
api/core/tools/builtin_tool/providers/time/tools/timestamp_to_localtime.py

@@ -21,7 +21,7 @@ class TimestampToLocaltimeTool(BuiltinTool):
         """
         Convert timestamp to localtime
         """
-        timestamp = tool_parameters.get("timestamp")
+        timestamp: int = tool_parameters.get("timestamp", 0)
         timezone = tool_parameters.get("timezone", "Asia/Shanghai")
         if not timezone:
             timezone = None

+ 1 - 1
api/core/tools/builtin_tool/providers/time/tools/timezone_conversion.py

@@ -24,7 +24,7 @@ class TimezoneConversionTool(BuiltinTool):
         current_time = tool_parameters.get("current_time")
         current_timezone = tool_parameters.get("current_timezone", "Asia/Shanghai")
         target_timezone = tool_parameters.get("target_timezone", "Asia/Tokyo")
-        target_time = self.timezone_convert(current_time, current_timezone, target_timezone)
+        target_time = self.timezone_convert(current_time, current_timezone, target_timezone)  # type: ignore
         if not target_time:
             yield self.create_text_message(
                 f"Invalid datatime and timezone: {current_time},{current_timezone},{target_timezone}"

+ 1 - 1
api/core/tools/builtin_tool/providers/webscraper/webscraper.py

@@ -4,5 +4,5 @@ from core.tools.builtin_tool.provider import BuiltinToolProviderController
 
 
 class WebscraperProvider(BuiltinToolProviderController):
-    def _validate_credentials(self, credentials: dict[str, Any]) -> None:
+    def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
         pass

+ 1 - 1
api/core/tools/custom_tool/provider.py

@@ -31,7 +31,7 @@ class ApiToolProviderController(ToolProviderController):
         self.tools = []
 
     @classmethod
-    def from_db(cls, db_provider: ApiToolProvider, auth_type: ApiProviderAuthType):
+    def from_db(cls, db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> "ApiToolProviderController":
         credentials_schema = [
             ProviderConfig(
                 name="auth_type",

+ 2 - 2
api/core/tools/plugin_tool/provider.py

@@ -44,7 +44,7 @@ class PluginToolProviderController(BuiltinToolProviderController):
         ):
             raise ToolProviderCredentialValidationError("Invalid credentials")
 
-    def get_tool(self, tool_name: str) -> PluginTool:
+    def get_tool(self, tool_name: str) -> PluginTool:  # type: ignore
         """
         return tool with given name
         """
@@ -61,7 +61,7 @@ class PluginToolProviderController(BuiltinToolProviderController):
             plugin_unique_identifier=self.plugin_unique_identifier,
         )
 
-    def get_tools(self) -> list[PluginTool]:
+    def get_tools(self) -> list[PluginTool]:  # type: ignore
         """
         get all tools
         """

+ 9 - 1
api/core/tools/plugin_tool/tool.py

@@ -59,7 +59,12 @@ class PluginTool(Tool):
             plugin_unique_identifier=self.plugin_unique_identifier,
         )
 
-    def get_runtime_parameters(self) -> list[ToolParameter]:
+    def get_runtime_parameters(
+        self,
+        conversation_id: Optional[str] = None,
+        app_id: Optional[str] = None,
+        message_id: Optional[str] = None,
+    ) -> list[ToolParameter]:
         """
         get the runtime parameters
         """
@@ -76,6 +81,9 @@ class PluginTool(Tool):
             provider=self.entity.identity.provider,
             tool=self.entity.identity.name,
             credentials=self.runtime.credentials,
+            conversation_id=conversation_id,
+            app_id=app_id,
+            message_id=message_id,
         )
 
         return self.runtime_parameters

+ 23 - 20
api/core/tools/tool_manager.py

@@ -4,7 +4,7 @@ import mimetypes
 from collections.abc import Generator
 from os import listdir, path
 from threading import Lock
-from typing import TYPE_CHECKING, Any, Optional, Union, cast
+from typing import TYPE_CHECKING, Any, Union, cast
 
 from yarl import URL
 
@@ -57,7 +57,7 @@ logger = logging.getLogger(__name__)
 
 class ToolManager:
     _builtin_provider_lock = Lock()
-    _hardcoded_providers = {}
+    _hardcoded_providers: dict[str, BuiltinToolProviderController] = {}
     _builtin_providers_loaded = False
     _builtin_tools_labels: dict[str, Union[I18nObject, None]] = {}
 
@@ -203,7 +203,7 @@ class ToolManager:
                 if builtin_provider is None:
                     raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")
             else:
-                builtin_provider: BuiltinToolProvider | None = (
+                builtin_provider = (
                     db.session.query(BuiltinToolProvider)
                     .filter(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id))
                     .first()
@@ -270,9 +270,7 @@ class ToolManager:
                 raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
 
             controller = ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider)
-            controller_tools: Optional[list[Tool]] = controller.get_tools(
-                user_id="", tenant_id=workflow_provider.tenant_id
-            )
+            controller_tools: list[WorkflowTool] = controller.get_tools(tenant_id=workflow_provider.tenant_id)
             if controller_tools is None or len(controller_tools) == 0:
                 raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
 
@@ -747,18 +745,21 @@ class ToolManager:
         # add tool labels
         labels = ToolLabelManager.get_tool_labels(controller)
 
-        return jsonable_encoder(
-            {
-                "schema_type": provider_obj.schema_type,
-                "schema": provider_obj.schema,
-                "tools": provider_obj.tools,
-                "icon": icon,
-                "description": provider_obj.description,
-                "credentials": masked_credentials,
-                "privacy_policy": provider_obj.privacy_policy,
-                "custom_disclaimer": provider_obj.custom_disclaimer,
-                "labels": labels,
-            }
+        return cast(
+            dict,
+            jsonable_encoder(
+                {
+                    "schema_type": provider_obj.schema_type,
+                    "schema": provider_obj.schema,
+                    "tools": provider_obj.tools,
+                    "icon": icon,
+                    "description": provider_obj.description,
+                    "credentials": masked_credentials,
+                    "privacy_policy": provider_obj.privacy_policy,
+                    "custom_disclaimer": provider_obj.custom_disclaimer,
+                    "labels": labels,
+                }
+            ),
         )
 
     @classmethod
@@ -795,7 +796,8 @@ class ToolManager:
             if workflow_provider is None:
                 raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
 
-            return json.loads(workflow_provider.icon)
+            icon: dict = json.loads(workflow_provider.icon)
+            return icon
         except Exception:
             return {"background": "#252525", "content": "\ud83d\ude01"}
 
@@ -811,7 +813,8 @@ class ToolManager:
             if api_provider is None:
                 raise ToolProviderNotFoundError(f"api provider {provider_id} not found")
 
-            return json.loads(api_provider.icon)
+            icon: dict = json.loads(api_provider.icon)
+            return icon
         except Exception:
             return {"background": "#252525", "content": "\ud83d\ude01"}
 

+ 1 - 1
api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py

@@ -5,7 +5,7 @@ from pydantic import BaseModel, Field
 from core.rag.datasource.retrieval_service import RetrievalService
 from core.rag.models.document import Document as RetrievalDocument
 from core.rag.retrieval.retrieval_methods import RetrievalMethod
-from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
+from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
 from extensions.ext_database import db
 from models.dataset import Dataset, Document, DocumentSegment
 from services.external_knowledge_service import ExternalDatasetService

+ 15 - 3
api/core/tools/utils/dataset_retriever_tool.py

@@ -1,5 +1,5 @@
 from collections.abc import Generator
-from typing import Any
+from typing import Any, Optional
 
 from core.app.app_config.entities import DatasetRetrieveConfigEntity
 from core.app.entities.app_invoke_entities import InvokeFrom
@@ -83,7 +83,12 @@ class DatasetRetrieverTool(Tool):
 
         return tools
 
-    def get_runtime_parameters(self) -> list[ToolParameter]:
+    def get_runtime_parameters(
+        self,
+        conversation_id: Optional[str] = None,
+        app_id: Optional[str] = None,
+        message_id: Optional[str] = None,
+    ) -> list[ToolParameter]:
         return [
             ToolParameter(
                 name="query",
@@ -101,7 +106,14 @@ class DatasetRetrieverTool(Tool):
     def tool_provider_type(self) -> ToolProviderType:
         return ToolProviderType.DATASET_RETRIEVAL
 
-    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessage, None, None]:
+    def _invoke(
+        self,
+        user_id: str,
+        tool_parameters: dict[str, Any],
+        conversation_id: Optional[str] = None,
+        app_id: Optional[str] = None,
+        message_id: Optional[str] = None,
+    ) -> Generator[ToolInvokeMessage, None, None]:
         """
         invoke dataset retriever tool
         """

+ 1 - 1
api/core/tools/utils/message_transformer.py

@@ -91,7 +91,7 @@ class ToolFileMessageTransformer:
                     )
             elif message.type == ToolInvokeMessage.MessageType.FILE:
                 meta = message.meta or {}
-                file = meta.get("file")
+                file = meta.get("file", None)
                 if isinstance(file, File):
                     if file.transfer_method == FileTransferMethod.TOOL_FILE:
                         assert file.related_id is not None

+ 1 - 1
api/core/tools/utils/workflow_configuration_sync.py

@@ -27,7 +27,7 @@ class WorkflowToolConfigurationUtils:
     @classmethod
     def check_is_synced(
         cls, variables: list[VariableEntity], tool_configurations: list[WorkflowToolParameterConfiguration]
-    ) -> bool:
+    ):
         """
         check is synced
 

+ 2 - 3
api/core/tools/workflow_as_tool/provider.py

@@ -6,7 +6,6 @@ from pydantic import Field
 from core.app.app_config.entities import VariableEntity, VariableEntityType
 from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
 from core.plugin.entities.parameters import PluginParameterOption
-from core.tools.__base.tool import Tool
 from core.tools.__base.tool_provider import ToolProviderController
 from core.tools.__base.tool_runtime import ToolRuntime
 from core.tools.entities.common_entities import I18nObject
@@ -101,7 +100,7 @@ class WorkflowToolProviderController(ToolProviderController):
         variables = WorkflowToolConfigurationUtils.get_workflow_graph_variables(graph)
 
         def fetch_workflow_variable(variable_name: str) -> VariableEntity | None:
-            return next(filter(lambda x: x.variable == variable_name, variables), None)
+            return next(filter(lambda x: x.variable == variable_name, variables), None)  # type: ignore
 
         user = db_provider.user
 
@@ -212,7 +211,7 @@ class WorkflowToolProviderController(ToolProviderController):
 
         return self.tools
 
-    def get_tool(self, tool_name: str) -> Optional[Tool]:
+    def get_tool(self, tool_name: str) -> Optional[WorkflowTool]:  # type: ignore
         """
         get tool by name
 

+ 4 - 3
api/core/tools/workflow_as_tool/tool.py

@@ -106,9 +106,9 @@ class WorkflowTool(Tool):
         if outputs is None:
             outputs = {}
         else:
-            outputs, files = self._extract_files(outputs)
+            outputs, files = self._extract_files(outputs)  # type: ignore
             for file in files:
-                yield self.create_file_message(file)
+                yield self.create_file_message(file)  # type: ignore
 
         yield self.create_text_message(json.dumps(outputs, ensure_ascii=False))
         yield self.create_json_message(outputs)
@@ -217,7 +217,7 @@ class WorkflowTool(Tool):
         :param result: the result
         :return: the result, files
         """
-        files = []
+        files: list[File] = []
         result = {}
         for key, value in outputs.items():
             if isinstance(value, list):
@@ -238,4 +238,5 @@ class WorkflowTool(Tool):
                 files.append(file)
 
             result[key] = value
+
         return result, files

+ 3 - 3
api/core/workflow/nodes/agent/agent_node.py

@@ -27,7 +27,7 @@ class AgentNode(ToolNode):
     Agent Node
     """
 
-    _node_data_cls = AgentNodeData
+    _node_data_cls = AgentNodeData  # type: ignore
     _node_type = NodeType.AGENT
 
     def _run(self) -> Generator:
@@ -125,7 +125,7 @@ class AgentNode(ToolNode):
         """
         agent_parameters_dictionary = {parameter.name: parameter for parameter in agent_parameters}
 
-        result = {}
+        result: dict[str, Any] = {}
         for parameter_name in node_data.agent_parameters:
             parameter = agent_parameters_dictionary.get(parameter_name)
             if not parameter:
@@ -214,7 +214,7 @@ class AgentNode(ToolNode):
         :return:
         """
         node_data = cast(AgentNodeData, node_data)
-        result = {}
+        result: dict[str, Any] = {}
         for parameter_name in node_data.agent_parameters:
             input = node_data.agent_parameters[parameter_name]
             if input.type == "mixed":

+ 2 - 2
api/core/workflow/nodes/llm/node.py

@@ -233,9 +233,9 @@ class LLMNode(BaseNode[LLMNodeData]):
         db.session.close()
 
         invoke_result = model_instance.invoke_llm(
-            prompt_messages=prompt_messages,
+            prompt_messages=list(prompt_messages),
             model_parameters=node_data_model.completion_params,
-            stop=stop,
+            stop=list(stop or []),
             stream=True,
             user=self.user_id,
         )

+ 2 - 2
api/core/workflow/nodes/tool/tool_node.py

@@ -1,5 +1,5 @@
 from collections.abc import Generator, Mapping, Sequence
-from typing import Any, Optional, cast
+from typing import Any, cast
 
 from sqlalchemy import select
 from sqlalchemy.orm import Session
@@ -197,7 +197,7 @@ class ToolNode(BaseNode[ToolNodeData]):
         json: list[dict] = []
 
         agent_logs: list[AgentLogEvent] = []
-        agent_execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = {}
+        agent_execution_metadata: Mapping[NodeRunMetadataKey, Any] = {}
 
         variables: dict[str, Any] = {}
 

+ 0 - 2
api/core/workflow/workflow_entry.py

@@ -284,8 +284,6 @@ class WorkflowEntry:
                 user_inputs=user_inputs,
                 variable_pool=variable_pool,
                 tenant_id=tenant_id,
-                node_type=node_type,
-                node_data=node_instance.node_data,
             )
 
             # run node

+ 1 - 1
api/libs/helper.py

@@ -13,7 +13,7 @@ from typing import TYPE_CHECKING, Any, Optional, Union, cast
 from zoneinfo import available_timezones
 
 from flask import Response, stream_with_context
-from flask_restful import fields
+from flask_restful import fields  # type: ignore
 
 from configs import dify_config
 from core.app.features.rate_limiting.rate_limit import RateLimitGenerator

+ 1 - 1
api/libs/login.py

@@ -102,6 +102,6 @@ def _get_user() -> EndUser | Account | None:
         if "_login_user" not in g:
             current_app.login_manager._load_user()  # type: ignore
 
-        return g._login_user
+        return g._login_user  # type: ignore
 
     return None

+ 2 - 2
api/models/account.py

@@ -1,7 +1,7 @@
 import enum
 import json
 
-from flask_login import UserMixin
+from flask_login import UserMixin  # type: ignore
 from sqlalchemy import func
 from sqlalchemy.orm import Mapped, mapped_column
 
@@ -56,7 +56,7 @@ class Account(UserMixin, Base):
         if ta:
             tenant.current_role = ta.role
         else:
-            tenant = None
+            tenant = None  # type: ignore
 
         self._current_tenant = tenant
 

+ 1 - 1
api/models/model.py

@@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Any, Literal, cast
 
 import sqlalchemy as sa
 from flask import request
-from flask_login import UserMixin
+from flask_login import UserMixin  # type: ignore
 from sqlalchemy import Float, Index, PrimaryKeyConstraint, func, text
 from sqlalchemy.orm import Mapped, Session, mapped_column
 

+ 2 - 10
api/models/tools.py

@@ -1,6 +1,6 @@
 import json
 from datetime import datetime
-from typing import Any, Optional
+from typing import Any, Optional, cast
 
 import sqlalchemy as sa
 from deprecated import deprecated
@@ -48,7 +48,7 @@ class BuiltinToolProvider(Base):
 
     @property
     def credentials(self) -> dict:
-        return json.loads(self.encrypted_credentials)
+        return cast(dict, json.loads(self.encrypted_credentials))
 
 
 class ApiToolProvider(Base):
@@ -302,13 +302,9 @@ class DeprecatedPublishedAppTool(Base):
         db.UniqueConstraint("app_id", "user_id", name="unique_published_app_tool"),
     )
 
-    # id of the tool provider
-    id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
     # id of the app
     app_id = db.Column(StringUUID, ForeignKey("apps.id"), nullable=False)
     # who published this tool
-    user_id = db.Column(StringUUID, nullable=False)
-    # description of the tool, stored in i18n format, for human
     description = db.Column(db.Text, nullable=False)
     # llm_description of the tool, for LLM
     llm_description = db.Column(db.Text, nullable=False)
@@ -328,10 +324,6 @@ class DeprecatedPublishedAppTool(Base):
     def description_i18n(self) -> I18nObject:
         return I18nObject(**json.loads(self.description))
 
-    @property
-    def app(self) -> App:
-        return db.session.query(App).filter(App.id == self.app_id).first()
-
     id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
     user_id: Mapped[str] = db.Column(StringUUID, nullable=False)
     tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False)

+ 1 - 1
api/services/agent_service.py

@@ -23,7 +23,7 @@ class AgentService:
         contexts.plugin_tool_providers.set({})
         contexts.plugin_tool_providers_lock.set(threading.Lock())
 
-        conversation: Conversation = (
+        conversation: Conversation | None = (
             db.session.query(Conversation)
             .filter(
                 Conversation.id == conversation_id,

+ 1 - 1
api/services/entities/model_provider_entities.py

@@ -156,7 +156,7 @@ class DefaultModelResponse(BaseModel):
     model_config = ConfigDict(protected_namespaces=())
 
 
-class ModelWithProviderEntityResponse(ModelWithProviderEntity):
+class ModelWithProviderEntityResponse(ProviderModelWithStatusEntity):
     """
     Model with provider entity.
     """

+ 1 - 2
api/services/plugin/plugin_migration.py

@@ -173,9 +173,8 @@ class PluginMigration:
         """
         Extract model tables.
 
-        NOTE: rename google to gemini
         """
-        models = []
+        models: list[str] = []
         table_pairs = [
             ("providers", "provider_name"),
             ("provider_models", "provider_name"),

+ 1 - 1
api/services/tools/api_tools_manage_service.py

@@ -439,7 +439,7 @@ class ApiToolManageService:
                     tenant_id=tenant_id,
                 )
             )
-            result = runtime_tool.validate_credentials(credentials, parameters)
+            result = tool.validate_credentials(credentials, parameters)
         except Exception as e:
             return {"error": str(e)}
 

+ 2 - 2
api/services/tools/tools_transform_service.py

@@ -1,6 +1,6 @@
 import json
 import logging
-from typing import Optional, Union
+from typing import Optional, Union, cast
 
 from yarl import URL
 
@@ -44,7 +44,7 @@ class ToolTransformService:
         elif provider_type in {ToolProviderType.API.value, ToolProviderType.WORKFLOW.value}:
             try:
                 if isinstance(icon, str):
-                    return json.loads(icon)
+                    return cast(dict, json.loads(icon))
                 return icon
             except Exception:
                 return {"background": "#252525", "content": "\ud83d\ude01"}

+ 9 - 8
api/services/tools/workflow_tools_manage_service.py

@@ -1,7 +1,7 @@
 import json
-from collections.abc import Mapping, Sequence
+from collections.abc import Mapping
 from datetime import datetime
-from typing import Any, Optional
+from typing import Any
 
 from sqlalchemy import or_
 
@@ -11,6 +11,7 @@ from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntit
 from core.tools.tool_label_manager import ToolLabelManager
 from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
 from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
+from core.tools.workflow_as_tool.tool import WorkflowTool
 from extensions.ext_database import db
 from models.model import App
 from models.tools import WorkflowToolProvider
@@ -187,7 +188,7 @@ class WorkflowToolManageService:
         """
         db_tools = db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all()
 
-        tools: Sequence[WorkflowToolProviderController] = []
+        tools: list[WorkflowToolProviderController] = []
         for provider in db_tools:
             try:
                 tools.append(ToolTransformService.workflow_provider_to_controller(provider))
@@ -264,7 +265,7 @@ class WorkflowToolManageService:
         return cls._get_workflow_tool(tenant_id, db_tool)
 
     @classmethod
-    def _get_workflow_tool(cls, tenant_id: str, db_tool: WorkflowToolProvider | None):
+    def _get_workflow_tool(cls, tenant_id: str, db_tool: WorkflowToolProvider | None) -> dict:
         """
         Get a workflow tool.
         :db_tool: the database tool
@@ -285,8 +286,8 @@ class WorkflowToolManageService:
             raise ValueError("Workflow not found")
 
         tool = ToolTransformService.workflow_provider_to_controller(db_tool)
-        to_user_tool: Optional[list[ToolApiEntity]] = tool.get_tools(tenant_id)
-        if to_user_tool is None or len(to_user_tool) == 0:
+        workflow_tools: list[WorkflowTool] = tool.get_tools(tenant_id)
+        if len(workflow_tools) == 0:
             raise ValueError(f"Tool {db_tool.id} not found")
 
         return {
@@ -325,8 +326,8 @@ class WorkflowToolManageService:
             raise ValueError(f"Tool {workflow_tool_id} not found")
 
         tool = ToolTransformService.workflow_provider_to_controller(db_tool)
-        to_user_tool: Optional[list[ToolApiEntity]] = tool.get_tools(user_id, tenant_id)
-        if to_user_tool is None or len(to_user_tool) == 0:
+        workflow_tools: list[WorkflowTool] = tool.get_tools(tenant_id)
+        if len(workflow_tools) == 0:
             raise ValueError(f"Tool {workflow_tool_id} not found")
 
         return [

+ 1 - 1
api/tasks/batch_create_segment_to_index_task.py

@@ -67,7 +67,7 @@ def batch_create_segment_to_index_task(
         for segment, tokens in zip(content, tokens_list):
             content = segment["content"]
             doc_id = str(uuid.uuid4())
-            segment_hash = helper.generate_text_hash(content)
+            segment_hash = helper.generate_text_hash(content)  # type: ignore
             max_position = (
                 db.session.query(func.max(DocumentSegment.position))
                 .filter(DocumentSegment.document_id == dataset_document.id)