소스 검색

fix(core): Fix incorrect type hints. (#5427)

-LAN- 11 달 전
부모
커밋
23fa3dedc4

+ 4 - 2
api/core/extension/extensible.py

@@ -1,5 +1,5 @@
 import enum
 import enum
-import importlib
+import importlib.util
 import json
 import json
 import logging
 import logging
 import os
 import os
@@ -74,6 +74,8 @@ class Extensible:
                 # Dynamic loading {subdir_name}.py file and find the subclass of Extensible
                 # Dynamic loading {subdir_name}.py file and find the subclass of Extensible
                 py_path = os.path.join(subdir_path, extension_name + '.py')
                 py_path = os.path.join(subdir_path, extension_name + '.py')
                 spec = importlib.util.spec_from_file_location(extension_name, py_path)
                 spec = importlib.util.spec_from_file_location(extension_name, py_path)
+                if not spec or not spec.loader:
+                    raise Exception(f"Failed to load module {extension_name} from {py_path}")
                 mod = importlib.util.module_from_spec(spec)
                 mod = importlib.util.module_from_spec(spec)
                 spec.loader.exec_module(mod)
                 spec.loader.exec_module(mod)
 
 
@@ -108,6 +110,6 @@ class Extensible:
                     position=position
                     position=position
                 ))
                 ))
 
 
-        sorted_extensions = sort_to_dict_by_position_map(position_map, extensions, lambda x: x.name)
+        sorted_extensions = sort_to_dict_by_position_map(position_map=position_map, data=extensions, name_func=lambda x: x.name)
 
 
         return sorted_extensions
         return sorted_extensions

+ 10 - 11
api/core/helper/module_import_helper.py

@@ -5,11 +5,7 @@ from types import ModuleType
 from typing import AnyStr
 from typing import AnyStr
 
 
 
 
-def import_module_from_source(
-        module_name: str,
-        py_file_path: AnyStr,
-        use_lazy_loader: bool = False
-) -> ModuleType:
+def import_module_from_source(*, module_name: str, py_file_path: AnyStr, use_lazy_loader: bool = False) -> ModuleType:
     """
     """
     Importing a module from the source file directly
     Importing a module from the source file directly
     """
     """
@@ -17,9 +13,13 @@ def import_module_from_source(
         existed_spec = importlib.util.find_spec(module_name)
         existed_spec = importlib.util.find_spec(module_name)
         if existed_spec:
         if existed_spec:
             spec = existed_spec
             spec = existed_spec
+            if not spec.loader:
+                raise Exception(f"Failed to load module {module_name} from {py_file_path}")
         else:
         else:
             # Refer to: https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly
             # Refer to: https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly
             spec = importlib.util.spec_from_file_location(module_name, py_file_path)
             spec = importlib.util.spec_from_file_location(module_name, py_file_path)
+            if not spec or not spec.loader:
+                raise Exception(f"Failed to load module {module_name} from {py_file_path}")
             if use_lazy_loader:
             if use_lazy_loader:
                 # Refer to: https://docs.python.org/3/library/importlib.html#implementing-lazy-imports
                 # Refer to: https://docs.python.org/3/library/importlib.html#implementing-lazy-imports
                 spec.loader = importlib.util.LazyLoader(spec.loader)
                 spec.loader = importlib.util.LazyLoader(spec.loader)
@@ -29,7 +29,7 @@ def import_module_from_source(
         spec.loader.exec_module(module)
         spec.loader.exec_module(module)
         return module
         return module
     except Exception as e:
     except Exception as e:
-        logging.exception(f'Failed to load module {module_name} from {py_file_path}: {str(e)}')
+        logging.exception(f"Failed to load module {module_name} from {py_file_path}: {str(e)}")
         raise e
         raise e
 
 
 
 
@@ -43,15 +43,14 @@ def get_subclasses_from_module(mod: ModuleType, parent_type: type) -> list[type]
 
 
 
 
 def load_single_subclass_from_source(
 def load_single_subclass_from_source(
-        module_name: str,
-        script_path: AnyStr,
-        parent_type: type,
-        use_lazy_loader: bool = False,
+    *, module_name: str, script_path: AnyStr, parent_type: type, use_lazy_loader: bool = False
 ) -> type:
 ) -> type:
     """
     """
     Load a single subclass from the source
     Load a single subclass from the source
     """
     """
-    module = import_module_from_source(module_name, script_path, use_lazy_loader)
+    module = import_module_from_source(
+        module_name=module_name, py_file_path=script_path, use_lazy_loader=use_lazy_loader
+    )
     subclasses = get_subclasses_from_module(module, parent_type)
     subclasses = get_subclasses_from_module(module, parent_type)
     match len(subclasses):
     match len(subclasses):
         case 1:
         case 1:

+ 2 - 5
api/core/helper/position_helper.py

@@ -1,15 +1,12 @@
 import os
 import os
 from collections import OrderedDict
 from collections import OrderedDict
 from collections.abc import Callable
 from collections.abc import Callable
-from typing import Any, AnyStr
+from typing import Any
 
 
 from core.tools.utils.yaml_utils import load_yaml_file
 from core.tools.utils.yaml_utils import load_yaml_file
 
 
 
 
-def get_position_map(
-        folder_path: AnyStr,
-        file_name: str = '_position.yaml',
-) -> dict[str, int]:
+def get_position_map(folder_path: str, *, file_name: str = "_position.yaml") -> dict[str, int]:
     """
     """
     Get the mapping from name to index from a YAML file
     Get the mapping from name to index from a YAML file
     :param folder_path:
     :param folder_path:

+ 9 - 4
api/core/model_manager.py

@@ -1,6 +1,6 @@
 import logging
 import logging
 import os
 import os
-from collections.abc import Generator
+from collections.abc import Callable, Generator
 from typing import IO, Optional, Union, cast
 from typing import IO, Optional, Union, cast
 
 
 from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
 from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
@@ -102,7 +102,7 @@ class ModelInstance:
 
 
     def invoke_llm(self, prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None,
     def invoke_llm(self, prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None,
                    tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                    tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
-                   stream: bool = True, user: Optional[str] = None, callbacks: list[Callback] = None) \
+                   stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) \
             -> Union[LLMResult, Generator]:
             -> Union[LLMResult, Generator]:
         """
         """
         Invoke large language model
         Invoke large language model
@@ -291,7 +291,7 @@ class ModelInstance:
             streaming=streaming
             streaming=streaming
         )
         )
 
 
-    def _round_robin_invoke(self, function: callable, *args, **kwargs):
+    def _round_robin_invoke(self, function: Callable, *args, **kwargs):
         """
         """
         Round-robin invoke
         Round-robin invoke
         :param function: function to invoke
         :param function: function to invoke
@@ -437,6 +437,7 @@ class LBModelManager:
 
 
         while True:
         while True:
             current_index = redis_client.incr(cache_key)
             current_index = redis_client.incr(cache_key)
+            current_index = cast(int, current_index)
             if current_index >= 10000000:
             if current_index >= 10000000:
                 current_index = 1
                 current_index = 1
                 redis_client.set(cache_key, current_index)
                 redis_client.set(cache_key, current_index)
@@ -499,7 +500,10 @@ class LBModelManager:
             config.id
             config.id
         )
         )
 
 
-        return redis_client.exists(cooldown_cache_key)
+
+        res = redis_client.exists(cooldown_cache_key)
+        res = cast(bool, res)
+        return res
 
 
     @classmethod
     @classmethod
     def get_config_in_cooldown_and_ttl(cls, tenant_id: str,
     def get_config_in_cooldown_and_ttl(cls, tenant_id: str,
@@ -528,4 +532,5 @@ class LBModelManager:
         if ttl == -2:
         if ttl == -2:
             return False, 0
             return False, 0
 
 
+        ttl = cast(int, ttl)
         return True, ttl
         return True, ttl

+ 5 - 4
api/core/model_runtime/entities/provider_entities.py

@@ -1,10 +1,11 @@
+from collections.abc import Sequence
 from enum import Enum
 from enum import Enum
 from typing import Optional
 from typing import Optional
 
 
 from pydantic import BaseModel, ConfigDict
 from pydantic import BaseModel, ConfigDict
 
 
 from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.entities.common_entities import I18nObject
-from core.model_runtime.entities.model_entities import AIModelEntity, ModelType, ProviderModel
+from core.model_runtime.entities.model_entities import ModelType, ProviderModel
 
 
 
 
 class ConfigurateMethod(Enum):
 class ConfigurateMethod(Enum):
@@ -93,8 +94,8 @@ class SimpleProviderEntity(BaseModel):
     label: I18nObject
     label: I18nObject
     icon_small: Optional[I18nObject] = None
     icon_small: Optional[I18nObject] = None
     icon_large: Optional[I18nObject] = None
     icon_large: Optional[I18nObject] = None
-    supported_model_types: list[ModelType]
-    models: list[AIModelEntity] = []
+    supported_model_types: Sequence[ModelType]
+    models: list[ProviderModel] = []
 
 
 
 
 class ProviderHelpEntity(BaseModel):
 class ProviderHelpEntity(BaseModel):
@@ -116,7 +117,7 @@ class ProviderEntity(BaseModel):
     icon_large: Optional[I18nObject] = None
     icon_large: Optional[I18nObject] = None
     background: Optional[str] = None
     background: Optional[str] = None
     help: Optional[ProviderHelpEntity] = None
     help: Optional[ProviderHelpEntity] = None
-    supported_model_types: list[ModelType]
+    supported_model_types: Sequence[ModelType]
     configurate_methods: list[ConfigurateMethod]
     configurate_methods: list[ConfigurateMethod]
     models: list[ProviderModel] = []
     models: list[ProviderModel] = []
     provider_credential_schema: Optional[ProviderCredentialSchema] = None
     provider_credential_schema: Optional[ProviderCredentialSchema] = None

+ 33 - 19
api/core/model_runtime/model_providers/__base/ai_model.py

@@ -1,6 +1,7 @@
 import decimal
 import decimal
 import os
 import os
 from abc import ABC, abstractmethod
 from abc import ABC, abstractmethod
+from collections.abc import Mapping
 from typing import Optional
 from typing import Optional
 
 
 from pydantic import ConfigDict
 from pydantic import ConfigDict
@@ -26,15 +27,16 @@ class AIModel(ABC):
     """
     """
     Base class for all models.
     Base class for all models.
     """
     """
+
     model_type: ModelType
     model_type: ModelType
-    model_schemas: list[AIModelEntity] = None
+    model_schemas: Optional[list[AIModelEntity]] = None
     started_at: float = 0
     started_at: float = 0
 
 
     # pydantic configs
     # pydantic configs
     model_config = ConfigDict(protected_namespaces=())
     model_config = ConfigDict(protected_namespaces=())
 
 
     @abstractmethod
     @abstractmethod
-    def validate_credentials(self, model: str, credentials: dict) -> None:
+    def validate_credentials(self, model: str, credentials: Mapping) -> None:
         """
         """
         Validate model credentials
         Validate model credentials
 
 
@@ -90,8 +92,8 @@ class AIModel(ABC):
 
 
         # get price info from predefined model schema
         # get price info from predefined model schema
         price_config: Optional[PriceConfig] = None
         price_config: Optional[PriceConfig] = None
-        if model_schema:
-            price_config: PriceConfig = model_schema.pricing
+        if model_schema and model_schema.pricing:
+            price_config = model_schema.pricing
 
 
         # get unit price
         # get unit price
         unit_price = None
         unit_price = None
@@ -103,13 +105,15 @@ class AIModel(ABC):
 
 
         if unit_price is None:
         if unit_price is None:
             return PriceInfo(
             return PriceInfo(
-                unit_price=decimal.Decimal('0.0'),
-                unit=decimal.Decimal('0.0'),
-                total_amount=decimal.Decimal('0.0'),
+                unit_price=decimal.Decimal("0.0"),
+                unit=decimal.Decimal("0.0"),
+                total_amount=decimal.Decimal("0.0"),
                 currency="USD",
                 currency="USD",
             )
             )
 
 
         # calculate total amount
         # calculate total amount
+        if not price_config:
+            raise ValueError(f"Price config not found for model {model}")
         total_amount = tokens * unit_price * price_config.unit
         total_amount = tokens * unit_price * price_config.unit
         total_amount = total_amount.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
         total_amount = total_amount.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
 
 
@@ -209,7 +213,7 @@ class AIModel(ABC):
 
 
         return model_schemas
         return model_schemas
 
 
-    def get_model_schema(self, model: str, credentials: Optional[dict] = None) -> Optional[AIModelEntity]:
+    def get_model_schema(self, model: str, credentials: Optional[Mapping] = None) -> Optional[AIModelEntity]:
         """
         """
         Get model schema by model name and credentials
         Get model schema by model name and credentials
 
 
@@ -231,7 +235,7 @@ class AIModel(ABC):
 
 
         return None
         return None
 
 
-    def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
+    def get_customizable_model_schema_from_credentials(self, model: str, credentials: Mapping) -> Optional[AIModelEntity]:
         """
         """
         Get customizable model schema from credentials
         Get customizable model schema from credentials
 
 
@@ -240,8 +244,8 @@ class AIModel(ABC):
         :return: model schema
         :return: model schema
         """
         """
         return self._get_customizable_model_schema(model, credentials)
         return self._get_customizable_model_schema(model, credentials)
-    
-    def _get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
+
+    def _get_customizable_model_schema(self, model: str, credentials: Mapping) -> Optional[AIModelEntity]:
         """
         """
         Get customizable model schema and fill in the template
         Get customizable model schema and fill in the template
         """
         """
@@ -249,7 +253,7 @@ class AIModel(ABC):
 
 
         if not schema:
         if not schema:
             return None
             return None
-        
+
         # fill in the template
         # fill in the template
         new_parameter_rules = []
         new_parameter_rules = []
         for parameter_rule in schema.parameter_rules:
         for parameter_rule in schema.parameter_rules:
@@ -271,10 +275,20 @@ class AIModel(ABC):
                         parameter_rule.help = I18nObject(
                         parameter_rule.help = I18nObject(
                             en_US=default_parameter_rule['help']['en_US'],
                             en_US=default_parameter_rule['help']['en_US'],
                         )
                         )
-                    if not parameter_rule.help.en_US and ('help' in default_parameter_rule and 'en_US' in default_parameter_rule['help']):
-                        parameter_rule.help.en_US = default_parameter_rule['help']['en_US']
-                    if not parameter_rule.help.zh_Hans and ('help' in default_parameter_rule and 'zh_Hans' in default_parameter_rule['help']):
-                        parameter_rule.help.zh_Hans = default_parameter_rule['help'].get('zh_Hans', default_parameter_rule['help']['en_US'])
+                    if (
+                        parameter_rule.help
+                        and not parameter_rule.help.en_US
+                        and ("help" in default_parameter_rule and "en_US" in default_parameter_rule["help"])
+                    ):
+                        parameter_rule.help.en_US = default_parameter_rule["help"]["en_US"]
+                    if (
+                        parameter_rule.help
+                        and not parameter_rule.help.zh_Hans
+                        and ("help" in default_parameter_rule and "zh_Hans" in default_parameter_rule["help"])
+                    ):
+                        parameter_rule.help.zh_Hans = default_parameter_rule["help"].get(
+                            "zh_Hans", default_parameter_rule["help"]["en_US"]
+                        )
                 except ValueError:
                 except ValueError:
                     pass
                     pass
 
 
@@ -284,7 +298,7 @@ class AIModel(ABC):
 
 
         return schema
         return schema
 
 
-    def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
+    def get_customizable_model_schema(self, model: str, credentials: Mapping) -> Optional[AIModelEntity]:
         """
         """
         Get customizable model schema
         Get customizable model schema
 
 
@@ -304,7 +318,7 @@ class AIModel(ABC):
         default_parameter_rule = PARAMETER_RULE_TEMPLATE.get(name)
         default_parameter_rule = PARAMETER_RULE_TEMPLATE.get(name)
 
 
         if not default_parameter_rule:
         if not default_parameter_rule:
-            raise Exception(f'Invalid model parameter rule name {name}')
+            raise Exception(f"Invalid model parameter rule name {name}")
 
 
         return default_parameter_rule
         return default_parameter_rule
 
 
@@ -318,4 +332,4 @@ class AIModel(ABC):
         :param text: plain text of prompt. You need to convert the original message to plain text
         :param text: plain text of prompt. You need to convert the original message to plain text
         :return: number of tokens
         :return: number of tokens
         """
         """
-        return GPT2Tokenizer.get_num_tokens(text)
+        return GPT2Tokenizer.get_num_tokens(text)

+ 17 - 14
api/core/model_runtime/model_providers/__base/large_language_model.py

@@ -3,7 +3,7 @@ import os
 import re
 import re
 import time
 import time
 from abc import abstractmethod
 from abc import abstractmethod
-from collections.abc import Generator
+from collections.abc import Generator, Mapping
 from typing import Optional, Union
 from typing import Optional, Union
 
 
 from pydantic import ConfigDict
 from pydantic import ConfigDict
@@ -43,7 +43,7 @@ class LargeLanguageModel(AIModel):
     def invoke(self, model: str, credentials: dict,
     def invoke(self, model: str, credentials: dict,
                prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None,
                prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None,
                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
-               stream: bool = True, user: Optional[str] = None, callbacks: list[Callback] = None) \
+               stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) \
             -> Union[LLMResult, Generator]:
             -> Union[LLMResult, Generator]:
         """
         """
         Invoke large language model
         Invoke large language model
@@ -129,7 +129,7 @@ class LargeLanguageModel(AIModel):
                 user=user,
                 user=user,
                 callbacks=callbacks
                 callbacks=callbacks
             )
             )
-        else:
+        elif isinstance(result, LLMResult):
             self._trigger_after_invoke_callbacks(
             self._trigger_after_invoke_callbacks(
                 model=model,
                 model=model,
                 result=result,
                 result=result,
@@ -148,7 +148,7 @@ class LargeLanguageModel(AIModel):
     def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
     def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                            model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None,
                            model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None,
                            stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None,
                            stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None,
-                           callbacks: list[Callback] = None) -> Union[LLMResult, Generator]:
+                           callbacks: Optional[list[Callback]] = None) -> Union[LLMResult, Generator]:
         """
         """
         Code block mode wrapper, ensure the response is a code block with output markdown quote
         Code block mode wrapper, ensure the response is a code block with output markdown quote
 
 
@@ -196,7 +196,7 @@ if you are not sure about the structure.
             # override the system message
             # override the system message
             prompt_messages[0] = SystemPromptMessage(
             prompt_messages[0] = SystemPromptMessage(
                 content=block_prompts
                 content=block_prompts
-                    .replace("{{instructions}}", prompt_messages[0].content)
+                    .replace("{{instructions}}", str(prompt_messages[0].content))
             )
             )
         else:
         else:
             # insert the system message
             # insert the system message
@@ -274,8 +274,9 @@ if you are not sure about the structure.
             else:
             else:
                 yield piece
                 yield piece
                 continue
                 continue
-            new_piece = ""
+            new_piece: str = ""
             for char in piece:
             for char in piece:
+                char = str(char)
                 if state == "normal":
                 if state == "normal":
                     if char == "`":
                     if char == "`":
                         state = "in_backticks"
                         state = "in_backticks"
@@ -340,7 +341,7 @@ if you are not sure about the structure.
             if state == "done":
             if state == "done":
                 continue
                 continue
 
 
-            new_piece = ""
+            new_piece: str = ""
             for char in piece:
             for char in piece:
                 if state == "search_start":
                 if state == "search_start":
                     if char == "`":
                     if char == "`":
@@ -365,7 +366,7 @@ if you are not sure about the structure.
                             # If backticks were counted but we're still collecting content, it was a false start
                             # If backticks were counted but we're still collecting content, it was a false start
                             new_piece += "`" * backtick_count
                             new_piece += "`" * backtick_count
                             backtick_count = 0
                             backtick_count = 0
-                        new_piece += char
+                        new_piece += str(char)
 
 
                 elif state == "done":
                 elif state == "done":
                     break
                     break
@@ -388,13 +389,14 @@ if you are not sure about the structure.
                                  prompt_messages: list[PromptMessage], model_parameters: dict,
                                  prompt_messages: list[PromptMessage], model_parameters: dict,
                                  tools: Optional[list[PromptMessageTool]] = None,
                                  tools: Optional[list[PromptMessageTool]] = None,
                                  stop: Optional[list[str]] = None, stream: bool = True,
                                  stop: Optional[list[str]] = None, stream: bool = True,
-                                 user: Optional[str] = None, callbacks: list[Callback] = None) -> Generator:
+                                 user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> Generator:
         """
         """
         Invoke result generator
         Invoke result generator
 
 
         :param result: result generator
         :param result: result generator
         :return: result generator
         :return: result generator
         """
         """
+        callbacks = callbacks or []
         prompt_message = AssistantPromptMessage(
         prompt_message = AssistantPromptMessage(
             content=""
             content=""
         )
         )
@@ -489,6 +491,7 @@ if you are not sure about the structure.
 
 
     def _llm_result_to_stream(self, result: LLMResult) -> Generator:
     def _llm_result_to_stream(self, result: LLMResult) -> Generator:
         """
         """
+from typing_extensions import deprecated
         Transform llm result to stream
         Transform llm result to stream
 
 
         :param result: llm result
         :param result: llm result
@@ -531,7 +534,7 @@ if you are not sure about the structure.
 
 
         return []
         return []
 
 
-    def get_model_mode(self, model: str, credentials: Optional[dict] = None) -> LLMMode:
+    def get_model_mode(self, model: str, credentials: Optional[Mapping] = None) -> LLMMode:
         """
         """
         Get model mode
         Get model mode
 
 
@@ -595,7 +598,7 @@ if you are not sure about the structure.
                                          prompt_messages: list[PromptMessage], model_parameters: dict,
                                          prompt_messages: list[PromptMessage], model_parameters: dict,
                                          tools: Optional[list[PromptMessageTool]] = None,
                                          tools: Optional[list[PromptMessageTool]] = None,
                                          stop: Optional[list[str]] = None, stream: bool = True,
                                          stop: Optional[list[str]] = None, stream: bool = True,
-                                         user: Optional[str] = None, callbacks: list[Callback] = None) -> None:
+                                         user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> None:
         """
         """
         Trigger before invoke callbacks
         Trigger before invoke callbacks
 
 
@@ -633,7 +636,7 @@ if you are not sure about the structure.
                                      prompt_messages: list[PromptMessage], model_parameters: dict,
                                      prompt_messages: list[PromptMessage], model_parameters: dict,
                                      tools: Optional[list[PromptMessageTool]] = None,
                                      tools: Optional[list[PromptMessageTool]] = None,
                                      stop: Optional[list[str]] = None, stream: bool = True,
                                      stop: Optional[list[str]] = None, stream: bool = True,
-                                     user: Optional[str] = None, callbacks: list[Callback] = None) -> None:
+                                     user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> None:
         """
         """
         Trigger new chunk callbacks
         Trigger new chunk callbacks
 
 
@@ -672,7 +675,7 @@ if you are not sure about the structure.
                                         prompt_messages: list[PromptMessage], model_parameters: dict,
                                         prompt_messages: list[PromptMessage], model_parameters: dict,
                                         tools: Optional[list[PromptMessageTool]] = None,
                                         tools: Optional[list[PromptMessageTool]] = None,
                                         stop: Optional[list[str]] = None, stream: bool = True,
                                         stop: Optional[list[str]] = None, stream: bool = True,
-                                        user: Optional[str] = None, callbacks: list[Callback] = None) -> None:
+                                        user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> None:
         """
         """
         Trigger after invoke callbacks
         Trigger after invoke callbacks
 
 
@@ -712,7 +715,7 @@ if you are not sure about the structure.
                                         prompt_messages: list[PromptMessage], model_parameters: dict,
                                         prompt_messages: list[PromptMessage], model_parameters: dict,
                                         tools: Optional[list[PromptMessageTool]] = None,
                                         tools: Optional[list[PromptMessageTool]] = None,
                                         stop: Optional[list[str]] = None, stream: bool = True,
                                         stop: Optional[list[str]] = None, stream: bool = True,
-                                        user: Optional[str] = None, callbacks: list[Callback] = None) -> None:
+                                        user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> None:
         """
         """
         Trigger invoke error callbacks
         Trigger invoke error callbacks
 
 

+ 18 - 11
api/core/model_runtime/model_providers/__base/model_provider.py

@@ -1,5 +1,6 @@
 import os
 import os
 from abc import ABC, abstractmethod
 from abc import ABC, abstractmethod
+from typing import Optional
 
 
 from core.helper.module_import_helper import get_subclasses_from_module, import_module_from_source
 from core.helper.module_import_helper import get_subclasses_from_module, import_module_from_source
 from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
 from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
@@ -9,7 +10,7 @@ from core.tools.utils.yaml_utils import load_yaml_file
 
 
 
 
 class ModelProvider(ABC):
 class ModelProvider(ABC):
-    provider_schema: ProviderEntity = None
+    provider_schema: Optional[ProviderEntity] = None
     model_instance_map: dict[str, AIModel] = {}
     model_instance_map: dict[str, AIModel] = {}
 
 
     @abstractmethod
     @abstractmethod
@@ -28,23 +29,23 @@ class ModelProvider(ABC):
     def get_provider_schema(self) -> ProviderEntity:
     def get_provider_schema(self) -> ProviderEntity:
         """
         """
         Get provider schema
         Get provider schema
-
+    
         :return: provider schema
         :return: provider schema
         """
         """
         if self.provider_schema:
         if self.provider_schema:
             return self.provider_schema
             return self.provider_schema
-
+    
         # get dirname of the current path
         # get dirname of the current path
         provider_name = self.__class__.__module__.split('.')[-1]
         provider_name = self.__class__.__module__.split('.')[-1]
 
 
         # get the path of the model_provider classes
         # get the path of the model_provider classes
         base_path = os.path.abspath(__file__)
         base_path = os.path.abspath(__file__)
         current_path = os.path.join(os.path.dirname(os.path.dirname(base_path)), provider_name)
         current_path = os.path.join(os.path.dirname(os.path.dirname(base_path)), provider_name)
-
+    
         # read provider schema from yaml file
         # read provider schema from yaml file
         yaml_path = os.path.join(current_path, f'{provider_name}.yaml')
         yaml_path = os.path.join(current_path, f'{provider_name}.yaml')
         yaml_data = load_yaml_file(yaml_path, ignore_error=True)
         yaml_data = load_yaml_file(yaml_path, ignore_error=True)
-
+    
         try:
         try:
             # yaml_data to entity
             # yaml_data to entity
             provider_schema = ProviderEntity(**yaml_data)
             provider_schema = ProviderEntity(**yaml_data)
@@ -53,7 +54,7 @@ class ModelProvider(ABC):
 
 
         # cache schema
         # cache schema
         self.provider_schema = provider_schema
         self.provider_schema = provider_schema
-
+    
         return provider_schema
         return provider_schema
 
 
     def models(self, model_type: ModelType) -> list[AIModelEntity]:
     def models(self, model_type: ModelType) -> list[AIModelEntity]:
@@ -84,7 +85,7 @@ class ModelProvider(ABC):
         :return:
         :return:
         """
         """
         # get dirname of the current path
         # get dirname of the current path
-        provider_name = self.__class__.__module__.split('.')[-1]
+        provider_name = self.__class__.__module__.split(".")[-1]
 
 
         if f"{provider_name}.{model_type.value}" in self.model_instance_map:
         if f"{provider_name}.{model_type.value}" in self.model_instance_map:
             return self.model_instance_map[f"{provider_name}.{model_type.value}"]
             return self.model_instance_map[f"{provider_name}.{model_type.value}"]
@@ -101,11 +102,17 @@ class ModelProvider(ABC):
         # Dynamic loading {model_type_name}.py file and find the subclass of AIModel
         # Dynamic loading {model_type_name}.py file and find the subclass of AIModel
         parent_module = '.'.join(self.__class__.__module__.split('.')[:-1])
         parent_module = '.'.join(self.__class__.__module__.split('.')[:-1])
         mod = import_module_from_source(
         mod = import_module_from_source(
-            f'{parent_module}.{model_type_name}.{model_type_name}', model_type_py_path)
-        model_class = next(filter(lambda x: x.__module__ == mod.__name__ and not x.__abstractmethods__,
-                                  get_subclasses_from_module(mod, AIModel)), None)
+            module_name=f"{parent_module}.{model_type_name}.{model_type_name}", py_file_path=model_type_py_path
+        )
+        model_class = next(
+            filter(
+                lambda x: x.__module__ == mod.__name__ and not x.__abstractmethods__,
+                get_subclasses_from_module(mod, AIModel),
+            ),
+            None,
+        )
         if not model_class:
         if not model_class:
-            raise Exception(f'Missing AIModel Class for model type {model_type} in {model_type_py_path}')
+            raise Exception(f"Missing AIModel Class for model type {model_type} in {model_type_py_path}")
 
 
         model_instance_map = model_class()
         model_instance_map = model_class()
         self.model_instance_map[f"{provider_name}.{model_type.value}"] = model_instance_map
         self.model_instance_map[f"{provider_name}.{model_type.value}"] = model_instance_map

+ 54 - 26
api/core/model_runtime/model_providers/model_provider_factory.py

@@ -1,5 +1,6 @@
 import logging
 import logging
 import os
 import os
+from collections.abc import Sequence
 from typing import Optional
 from typing import Optional
 
 
 from pydantic import BaseModel, ConfigDict
 from pydantic import BaseModel, ConfigDict
@@ -16,20 +17,21 @@ logger = logging.getLogger(__name__)
 
 
 
 
 class ModelProviderExtension(BaseModel):
 class ModelProviderExtension(BaseModel):
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
     provider_instance: ModelProvider
     provider_instance: ModelProvider
     name: str
     name: str
     position: Optional[int] = None
     position: Optional[int] = None
-    model_config = ConfigDict(arbitrary_types_allowed=True)
 
 
 
 
 class ModelProviderFactory:
 class ModelProviderFactory:
-    model_provider_extensions: dict[str, ModelProviderExtension] = None
+    model_provider_extensions: Optional[dict[str, ModelProviderExtension]] = None
 
 
     def __init__(self) -> None:
     def __init__(self) -> None:
         # for cache in memory
         # for cache in memory
         self.get_providers()
         self.get_providers()
 
 
-    def get_providers(self) -> list[ProviderEntity]:
+    def get_providers(self) -> Sequence[ProviderEntity]:
         """
         """
         Get all providers
         Get all providers
         :return: list of providers
         :return: list of providers
@@ -39,7 +41,7 @@ class ModelProviderFactory:
 
 
         # traverse all model_provider_extensions
         # traverse all model_provider_extensions
         providers = []
         providers = []
-        for name, model_provider_extension in model_provider_extensions.items():
+        for model_provider_extension in model_provider_extensions.values():
             # get model_provider instance
             # get model_provider instance
             model_provider_instance = model_provider_extension.provider_instance
             model_provider_instance = model_provider_extension.provider_instance
 
 
@@ -57,7 +59,7 @@ class ModelProviderFactory:
         # return providers
         # return providers
         return providers
         return providers
 
 
-    def provider_credentials_validate(self, provider: str, credentials: dict) -> dict:
+    def provider_credentials_validate(self, *, provider: str, credentials: dict) -> dict:
         """
         """
         Validate provider credentials
         Validate provider credentials
 
 
@@ -74,6 +76,9 @@ class ModelProviderFactory:
         # get provider_credential_schema and validate credentials according to the rules
         # get provider_credential_schema and validate credentials according to the rules
         provider_credential_schema = provider_schema.provider_credential_schema
         provider_credential_schema = provider_schema.provider_credential_schema
 
 
+        if not provider_credential_schema:
+            raise ValueError(f"Provider {provider} does not have provider_credential_schema")
+
         # validate provider credential schema
         # validate provider credential schema
         validator = ProviderCredentialSchemaValidator(provider_credential_schema)
         validator = ProviderCredentialSchemaValidator(provider_credential_schema)
         filtered_credentials = validator.validate_and_filter(credentials)
         filtered_credentials = validator.validate_and_filter(credentials)
@@ -83,8 +88,9 @@ class ModelProviderFactory:
 
 
         return filtered_credentials
         return filtered_credentials
 
 
-    def model_credentials_validate(self, provider: str, model_type: ModelType,
-                                   model: str, credentials: dict) -> dict:
+    def model_credentials_validate(
+        self, *, provider: str, model_type: ModelType, model: str, credentials: dict
+    ) -> dict:
         """
         """
         Validate model credentials
         Validate model credentials
 
 
@@ -103,6 +109,9 @@ class ModelProviderFactory:
         # get model_credential_schema and validate credentials according to the rules
         # get model_credential_schema and validate credentials according to the rules
         model_credential_schema = provider_schema.model_credential_schema
         model_credential_schema = provider_schema.model_credential_schema
 
 
+        if not model_credential_schema:
+            raise ValueError(f"Provider {provider} does not have model_credential_schema")
+
         # validate model credential schema
         # validate model credential schema
         validator = ModelCredentialSchemaValidator(model_type, model_credential_schema)
         validator = ModelCredentialSchemaValidator(model_type, model_credential_schema)
         filtered_credentials = validator.validate_and_filter(credentials)
         filtered_credentials = validator.validate_and_filter(credentials)
@@ -115,11 +124,13 @@ class ModelProviderFactory:
 
 
         return filtered_credentials
         return filtered_credentials
 
 
-    def get_models(self,
-                   provider: Optional[str] = None,
-                   model_type: Optional[ModelType] = None,
-                   provider_configs: Optional[list[ProviderConfig]] = None) \
-            -> list[SimpleProviderEntity]:
+    def get_models(
+        self,
+        *,
+        provider: Optional[str] = None,
+        model_type: Optional[ModelType] = None,
+        provider_configs: Optional[list[ProviderConfig]] = None,
+    ) -> list[SimpleProviderEntity]:
         """
         """
         Get all models for given model type
         Get all models for given model type
 
 
@@ -128,6 +139,8 @@ class ModelProviderFactory:
         :param provider_configs: list of provider configs
         :param provider_configs: list of provider configs
         :return: list of models
         :return: list of models
         """
         """
+        provider_configs = provider_configs or []
+
         # scan all providers
         # scan all providers
         model_provider_extensions = self._get_model_provider_map()
         model_provider_extensions = self._get_model_provider_map()
 
 
@@ -184,7 +197,7 @@ class ModelProviderFactory:
         # get the provider extension
         # get the provider extension
         model_provider_extension = model_provider_extensions.get(provider)
         model_provider_extension = model_provider_extensions.get(provider)
         if not model_provider_extension:
         if not model_provider_extension:
-            raise Exception(f'Invalid provider: {provider}')
+            raise Exception(f"Invalid provider: {provider}")
 
 
         # get the provider instance
         # get the provider instance
         model_provider_instance = model_provider_extension.provider_instance
         model_provider_instance = model_provider_extension.provider_instance
@@ -192,10 +205,22 @@ class ModelProviderFactory:
         return model_provider_instance
         return model_provider_instance
 
 
     def _get_model_provider_map(self) -> dict[str, ModelProviderExtension]:
     def _get_model_provider_map(self) -> dict[str, ModelProviderExtension]:
+        """
+        Retrieves the model provider map.
+
+        This method retrieves the model provider map, which is a dictionary containing the model provider names as keys
+        and instances of `ModelProviderExtension` as values. The model provider map is used to store information about
+        available model providers.
+
+        Returns:
+            A dictionary containing the model provider map.
+
+        Raises:
+            None.
+        """
         if self.model_provider_extensions:
         if self.model_provider_extensions:
             return self.model_provider_extensions
             return self.model_provider_extensions
 
 
-
         # get the path of current classes
         # get the path of current classes
         current_path = os.path.abspath(__file__)
         current_path = os.path.abspath(__file__)
         model_providers_path = os.path.dirname(current_path)
         model_providers_path = os.path.dirname(current_path)
@@ -204,8 +229,8 @@ class ModelProviderFactory:
         model_provider_dir_paths = [
         model_provider_dir_paths = [
             os.path.join(model_providers_path, model_provider_dir)
             os.path.join(model_providers_path, model_provider_dir)
             for model_provider_dir in os.listdir(model_providers_path)
             for model_provider_dir in os.listdir(model_providers_path)
-            if not model_provider_dir.startswith('__')
-               and os.path.isdir(os.path.join(model_providers_path, model_provider_dir))
+            if not model_provider_dir.startswith("__")
+            and os.path.isdir(os.path.join(model_providers_path, model_provider_dir))
         ]
         ]
 
 
         # get _position.yaml file path
         # get _position.yaml file path
@@ -219,30 +244,33 @@ class ModelProviderFactory:
 
 
             file_names = os.listdir(model_provider_dir_path)
             file_names = os.listdir(model_provider_dir_path)
 
 
-            if (model_provider_name + '.py') not in file_names:
+            if (model_provider_name + ".py") not in file_names:
                 logger.warning(f"Missing {model_provider_name}.py file in {model_provider_dir_path}, Skip.")
                 logger.warning(f"Missing {model_provider_name}.py file in {model_provider_dir_path}, Skip.")
                 continue
                 continue
 
 
             # Dynamic loading {model_provider_name}.py file and find the subclass of ModelProvider
             # Dynamic loading {model_provider_name}.py file and find the subclass of ModelProvider
-            py_path = os.path.join(model_provider_dir_path, model_provider_name + '.py')
+            py_path = os.path.join(model_provider_dir_path, model_provider_name + ".py")
             model_provider_class = load_single_subclass_from_source(
             model_provider_class = load_single_subclass_from_source(
-                module_name=f'core.model_runtime.model_providers.{model_provider_name}.{model_provider_name}',
+                module_name=f"core.model_runtime.model_providers.{model_provider_name}.{model_provider_name}",
                 script_path=py_path,
                 script_path=py_path,
-                parent_type=ModelProvider)
+                parent_type=ModelProvider,
+            )
 
 
             if not model_provider_class:
             if not model_provider_class:
                 logger.warning(f"Missing Model Provider Class that extends ModelProvider in {py_path}, Skip.")
                 logger.warning(f"Missing Model Provider Class that extends ModelProvider in {py_path}, Skip.")
                 continue
                 continue
 
 
-            if f'{model_provider_name}.yaml' not in file_names:
+            if f"{model_provider_name}.yaml" not in file_names:
                 logger.warning(f"Missing {model_provider_name}.yaml file in {model_provider_dir_path}, Skip.")
                 logger.warning(f"Missing {model_provider_name}.yaml file in {model_provider_dir_path}, Skip.")
                 continue
                 continue
 
 
-            model_providers.append(ModelProviderExtension(
-                name=model_provider_name,
-                provider_instance=model_provider_class(),
-                position=position_map.get(model_provider_name)
-            ))
+            model_providers.append(
+                ModelProviderExtension(
+                    name=model_provider_name,
+                    provider_instance=model_provider_class(),
+                    position=position_map.get(model_provider_name),
+                )
+            )
 
 
         sorted_extensions = sort_to_dict_by_position_map(position_map, model_providers, lambda x: x.name)
         sorted_extensions = sort_to_dict_by_position_map(position_map, model_providers, lambda x: x.name)
 
 

+ 12 - 20
api/core/model_runtime/model_providers/openai/_common.py

@@ -1,3 +1,5 @@
+from collections.abc import Mapping
+
 import openai
 import openai
 from httpx import Timeout
 from httpx import Timeout
 
 
@@ -12,7 +14,7 @@ from core.model_runtime.errors.invoke import (
 
 
 
 
 class _CommonOpenAI:
 class _CommonOpenAI:
-    def _to_credential_kwargs(self, credentials: dict) -> dict:
+    def _to_credential_kwargs(self, credentials: Mapping) -> dict:
         """
         """
         Transform credentials to kwargs for model instance
         Transform credentials to kwargs for model instance
 
 
@@ -25,9 +27,9 @@ class _CommonOpenAI:
             "max_retries": 1,
             "max_retries": 1,
         }
         }
 
 
-        if credentials.get('openai_api_base'):
-            credentials['openai_api_base'] = credentials['openai_api_base'].rstrip('/')
-            credentials_kwargs['base_url'] = credentials['openai_api_base'] + '/v1'
+        if credentials.get("openai_api_base"):
+            openai_api_base = credentials["openai_api_base"].rstrip("/")
+            credentials_kwargs["base_url"] = openai_api_base + "/v1"
 
 
         if 'openai_organization' in credentials:
         if 'openai_organization' in credentials:
             credentials_kwargs['organization'] = credentials['openai_organization']
             credentials_kwargs['organization'] = credentials['openai_organization']
@@ -45,24 +47,14 @@ class _CommonOpenAI:
         :return: Invoke error mapping
         :return: Invoke error mapping
         """
         """
         return {
         return {
-            InvokeConnectionError: [
-                openai.APIConnectionError,
-                openai.APITimeoutError
-            ],
-            InvokeServerUnavailableError: [
-                openai.InternalServerError
-            ],
-            InvokeRateLimitError: [
-                openai.RateLimitError
-            ],
-            InvokeAuthorizationError: [
-                openai.AuthenticationError,
-                openai.PermissionDeniedError
-            ],
+            InvokeConnectionError: [openai.APIConnectionError, openai.APITimeoutError],
+            InvokeServerUnavailableError: [openai.InternalServerError],
+            InvokeRateLimitError: [openai.RateLimitError],
+            InvokeAuthorizationError: [openai.AuthenticationError, openai.PermissionDeniedError],
             InvokeBadRequestError: [
             InvokeBadRequestError: [
                 openai.BadRequestError,
                 openai.BadRequestError,
                 openai.NotFoundError,
                 openai.NotFoundError,
                 openai.UnprocessableEntityError,
                 openai.UnprocessableEntityError,
-                openai.APIError
-            ]
+                openai.APIError,
+            ],
         }
         }

+ 2 - 1
api/core/model_runtime/model_providers/openai/openai.py

@@ -1,4 +1,5 @@
 import logging
 import logging
+from collections.abc import Mapping
 
 
 from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
@@ -9,7 +10,7 @@ logger = logging.getLogger(__name__)
 
 
 class OpenAIProvider(ModelProvider):
 class OpenAIProvider(ModelProvider):
 
 
-    def validate_provider_credentials(self, credentials: dict) -> None:
+    def validate_provider_credentials(self, credentials: Mapping) -> None:
         """
         """
         Validate provider credentials
         Validate provider credentials
         if validate failed, raise exception
         if validate failed, raise exception