소스 검색

fix: mypy linter

Yeuoly 5 달 전
부모
커밋
b7d168ac59

+ 2 - 2
api/controllers/console/auth/forgot_password.py

@@ -2,7 +2,7 @@ import base64
 import secrets
 
 from flask import request
-from flask_restful import Resource, reqparse
+from flask_restful import Resource, reqparse  # type: ignore
 from sqlalchemy import select
 from sqlalchemy.orm import Session
 
@@ -129,7 +129,7 @@ class ForgotPasswordResetApi(Resource):
                 )
             except WorkSpaceNotAllowedCreateError:
                 pass
-            except AccountRegisterError as are:
+            except AccountRegisterError:
                 raise AccountInFreezeError()
 
         return {"result": "success"}

+ 1 - 1
api/controllers/console/auth/oauth.py

@@ -4,7 +4,7 @@ from typing import Optional
 
 import requests
 from flask import current_app, redirect, request
-from flask_restful import Resource
+from flask_restful import Resource  # type: ignore
 from sqlalchemy import select
 from sqlalchemy.orm import Session
 from werkzeug.exceptions import Unauthorized

+ 2 - 2
api/controllers/console/datasets/data_source.py

@@ -2,8 +2,8 @@ import datetime
 import json
 
 from flask import request
-from flask_login import current_user
-from flask_restful import Resource, marshal_with, reqparse
+from flask_login import current_user  # type: ignore
+from flask_restful import Resource, marshal_with, reqparse  # type: ignore
 from sqlalchemy import select
 from sqlalchemy.orm import Session
 from werkzeug.exceptions import NotFound

+ 1 - 1
api/controllers/console/init_validate.py

@@ -1,7 +1,7 @@
 import os
 
 from flask import session
-from flask_restful import Resource, reqparse
+from flask_restful import Resource, reqparse  # type: ignore
 from sqlalchemy import select
 from sqlalchemy.orm import Session
 

+ 1 - 1
api/controllers/console/workspace/__init__.py

@@ -1,6 +1,6 @@
 from functools import wraps
 
-from flask_login import current_user
+from flask_login import current_user  # type: ignore
 from sqlalchemy.orm import Session
 from werkzeug.exceptions import Forbidden
 

+ 2 - 2
api/controllers/console/workspace/agent_providers.py

@@ -1,5 +1,5 @@
-from flask_login import current_user
-from flask_restful import Resource
+from flask_login import current_user  # type: ignore
+from flask_restful import Resource  # type: ignore
 
 from controllers.console import api
 from controllers.console.wraps import account_initialization_required, setup_required

+ 2 - 2
api/controllers/console/workspace/endpoint.py

@@ -1,5 +1,5 @@
-from flask_login import current_user
-from flask_restful import Resource, reqparse
+from flask_login import current_user  # type: ignore
+from flask_restful import Resource, reqparse  # type: ignore
 from werkzeug.exceptions import Forbidden
 
 from controllers.console import api

+ 2 - 2
api/controllers/console/workspace/plugin.py

@@ -1,8 +1,8 @@
 import io
 
 from flask import request, send_file
-from flask_login import current_user
-from flask_restful import Resource, reqparse
+from flask_login import current_user  # type: ignore
+from flask_restful import Resource, reqparse  # type: ignore
 from werkzeug.exceptions import Forbidden
 
 from configs import dify_config

+ 1 - 1
api/controllers/files/upload.py

@@ -1,5 +1,5 @@
 from flask import request
-from flask_restful import Resource, marshal_with
+from flask_restful import Resource, marshal_with  # type: ignore
 from werkzeug.exceptions import Forbidden
 
 import services

+ 1 - 1
api/controllers/inner_api/plugin/plugin.py

@@ -1,4 +1,4 @@
-from flask_restful import Resource
+from flask_restful import Resource  # type: ignore
 
 from controllers.console.wraps import setup_required
 from controllers.inner_api import api

+ 1 - 1
api/controllers/inner_api/plugin/wraps.py

@@ -3,7 +3,7 @@ from functools import wraps
 from typing import Optional
 
 from flask import request
-from flask_restful import reqparse
+from flask_restful import reqparse  # type: ignore
 from pydantic import BaseModel
 from sqlalchemy.orm import Session
 

+ 1 - 1
api/core/agent/cot_agent_runner.py

@@ -119,7 +119,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
                 callbacks=[],
             )
 
-            usage_dict = {}
+            usage_dict: dict[str, Optional[LLMUsage]] = {}
             react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)
             scratchpad = AgentScratchpadUnit(
                 agent_response="",

+ 1 - 1
api/core/app/apps/advanced_chat/app_runner.py

@@ -77,7 +77,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
             graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
                 workflow=workflow,
                 node_id=self.application_generate_entity.single_iteration_run.node_id,
-                user_inputs=self.application_generate_entity.single_iteration_run.inputs,
+                user_inputs=dict(self.application_generate_entity.single_iteration_run.inputs),
             )
         else:
             inputs = self.application_generate_entity.inputs

+ 3 - 1
api/core/app/apps/advanced_chat/generate_task_pipeline.py

@@ -644,7 +644,9 @@ class AdvancedChatAppGenerateTaskPipeline:
 
                 yield self._message_end_to_stream_response()
             elif isinstance(event, QueueAgentLogEvent):
-                yield self._handle_agent_log(task_id=self._application_generate_entity.task_id, event=event)
+                yield self._workflow_cycle_manager._handle_agent_log(
+                    task_id=self._application_generate_entity.task_id, event=event
+                )
             else:
                 continue
 

+ 2 - 2
api/core/app/apps/base_app_generate_response_converter.py

@@ -1,7 +1,7 @@
 import logging
 from abc import ABC, abstractmethod
 from collections.abc import Generator
-from typing import Any, Union
+from typing import Any, Mapping, Union
 
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.app.entities.task_entities import AppBlockingResponse, AppStreamResponse
@@ -15,7 +15,7 @@ class AppGenerateResponseConverter(ABC):
     @classmethod
     def convert(
         cls, response: Union[AppBlockingResponse, Generator[AppStreamResponse, Any, None]], invoke_from: InvokeFrom
-    ) -> dict[str, Any] | Generator[str | dict[str, Any], Any, None]:
+    ) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
         if invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API}:
             if isinstance(response, AppBlockingResponse):
                 return cls.convert_blocking_full_response(response)

+ 4 - 4
api/core/app/apps/completion/app_generator.py

@@ -37,7 +37,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
         args: Mapping[str, Any],
         invoke_from: InvokeFrom,
         streaming: Literal[True],
-    ) -> Generator[str, None, None]: ...
+    ) -> Generator[str | Mapping[str, Any], None, None]: ...
 
     @overload
     def generate(
@@ -57,7 +57,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
         args: Mapping[str, Any],
         invoke_from: InvokeFrom,
         streaming: bool = False,
-    ) -> Union[Mapping[str, Any], Generator[str, None, None]]: ...
+    ) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]: ...
 
     def generate(
         self,
@@ -66,7 +66,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
         args: Mapping[str, Any],
         invoke_from: InvokeFrom,
         streaming: bool = True,
-    ) -> Union[Mapping[str, Any], Generator[str, None, None]]:
+    ) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
         """
         Generate App response.
 
@@ -231,7 +231,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
         user: Union[Account, EndUser],
         invoke_from: InvokeFrom,
         stream: bool = True,
-    ) -> Union[Mapping, Generator[str, None, None]]:
+    ) -> Union[Mapping, Generator[Mapping | str, None, None]]:
         """
         Generate App response.
 

+ 3 - 1
api/core/app/apps/workflow/generate_task_pipeline.py

@@ -531,7 +531,9 @@ class WorkflowAppGenerateTaskPipeline:
                     delta_text, from_variable_selector=event.from_variable_selector
                 )
             elif isinstance(event, QueueAgentLogEvent):
-                yield self._handle_agent_log(task_id=self._application_generate_entity.task_id, event=event)
+                yield self._workflow_cycle_manager._handle_agent_log(
+                    task_id=self._application_generate_entity.task_id, event=event
+                )
             else:
                 continue
 

+ 2 - 2
api/core/model_manager.py

@@ -178,7 +178,7 @@ class ModelInstance:
 
     def get_llm_num_tokens(
         self, prompt_messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None
-    ) -> list[int]:
+    ) -> int:
         """
         Get number of tokens for llm
 
@@ -191,7 +191,7 @@ class ModelInstance:
 
         self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
         return cast(
-            list[int],
+            int,
             self._round_robin_invoke(
                 function=self.model_type_instance.get_num_tokens,
                 model=self.model,

+ 2 - 2
api/core/plugin/backwards_invocation/app.py

@@ -119,7 +119,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
         stream: bool,
         inputs: Mapping,
         files: list[dict],
-    ):
+    ) -> Generator[Mapping | str, None, None] | Mapping:
         """
         invoke workflow app
         """
@@ -146,7 +146,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
         stream: bool,
         inputs: Mapping,
         files: list[dict],
-    ):
+    ) -> Generator[Mapping | str, None, None] | Mapping:
         """
         invoke completion app
         """

+ 8 - 8
api/core/plugin/backwards_invocation/model.py

@@ -268,7 +268,7 @@ Here is the extra instruction you need to follow:
             return summary.message.content
 
         lines = content.split("\n")
-        new_lines = []
+        new_lines: list[str] = []
         # split long line into multiple lines
         for i in range(len(lines)):
             line = lines[i]
@@ -286,16 +286,16 @@ Here is the extra instruction you need to follow:
 
         # merge lines into messages with max tokens
         messages: list[str] = []
-        for i in new_lines:
+        for i in new_lines:  # type: ignore
             if len(messages) == 0:
-                messages.append(i)
+                messages.append(i)  # type: ignore
             else:
-                if len(messages[-1]) + len(i) < max_tokens * 0.5:
-                    messages[-1] += i
-                if get_prompt_tokens(messages[-1] + i) > max_tokens * 0.7:
-                    messages.append(i)
+                if len(messages[-1]) + len(i) < max_tokens * 0.5:  # type: ignore
+                    messages[-1] += i  # type: ignore
+                if get_prompt_tokens(messages[-1] + i) > max_tokens * 0.7:  # type: ignore
+                    messages.append(i)  # type: ignore
                 else:
-                    messages[-1] += i
+                    messages[-1] += i  # type: ignore
 
         summaries = []
         for i in range(len(messages)):

+ 1 - 1
api/core/plugin/manager/base.py

@@ -103,7 +103,7 @@ class BasePluginManager:
         Make a stream request to the plugin daemon inner API and yield the response as a model.
         """
         for line in self._stream_request(method, path, params, headers, data, files):
-            yield type(**json.loads(line))
+            yield type(**json.loads(line))  # type: ignore
 
     def _request_with_model(
         self,

+ 6 - 1
api/core/tools/builtin_tool/providers/audio/tools/asr.py

@@ -54,7 +54,12 @@ class ASRTool(BuiltinTool):
                 items.append((provider, model.model))
         return items
 
-    def get_runtime_parameters(self) -> list[ToolParameter]:
+    def get_runtime_parameters(
+        self,
+        conversation_id: Optional[str] = None,
+        app_id: Optional[str] = None,
+        message_id: Optional[str] = None,
+    ) -> list[ToolParameter]:
         parameters = []
 
         options = []

+ 6 - 1
api/core/tools/builtin_tool/providers/audio/tools/tts.py

@@ -62,7 +62,12 @@ class TTSTool(BuiltinTool):
                 items.append((provider, model.model, voices))
         return items
 
-    def get_runtime_parameters(self) -> list[ToolParameter]:
+    def get_runtime_parameters(
+        self,
+        conversation_id: Optional[str] = None,
+        app_id: Optional[str] = None,
+        message_id: Optional[str] = None,
+    ) -> list[ToolParameter]:
         parameters = []
 
         options = []

+ 1 - 1
api/core/tools/entities/tool_entities.py

@@ -147,7 +147,7 @@ class ToolInvokeMessage(BaseModel):
 
         @field_validator("variable_name", mode="before")
         @classmethod
-        def transform_variable_name(cls, value) -> str:
+        def transform_variable_name(cls, value: str) -> str:
             """
             The variable name must be a string.
             """

+ 2 - 2
api/libs/helper.py

@@ -9,7 +9,7 @@ import uuid
 from collections.abc import Generator
 from datetime import datetime
 from hashlib import sha256
-from typing import TYPE_CHECKING, Any, Optional, Union, cast
+from typing import TYPE_CHECKING, Any, Mapping, Optional, Union, cast
 from zoneinfo import available_timezones
 
 from flask import Response, stream_with_context
@@ -182,7 +182,7 @@ def generate_text_hash(text: str) -> str:
     return sha256(hash_text.encode()).hexdigest()
 
 
-def compact_generate_response(response: Union[dict, Generator, RateLimitGenerator]) -> Response:
+def compact_generate_response(response: Union[Mapping, Generator, RateLimitGenerator]) -> Response:
     if isinstance(response, dict):
         return Response(response=json.dumps(response), status=200, mimetype="application/json")
     else:

+ 3 - 0
api/services/account_service.py

@@ -900,6 +900,9 @@ class RegisterService:
     def invite_new_member(
         cls, tenant: Tenant, email: str, language: str, role: str = "normal", inviter: Account | None = None
     ) -> str:
+        if not inviter:
+            raise ValueError("Inviter is required")
+
         """Invite new member"""
         with Session(db.engine) as session:
             account = session.query(Account).filter_by(email=email).first()

+ 1 - 1
api/services/workflow_service.py

@@ -298,7 +298,7 @@ class WorkflowService:
         start_at: float,
         tenant_id: str,
         node_id: str,
-    ):
+    ) -> WorkflowNodeExecution:
         """
         Handle node run result