Browse Source

fix: type num of variable converted to str (#3758)

takatost 1 year ago
parent
commit
2ea8c73cd8

+ 11 - 3
api/core/app/apps/base_app_generator.py

@@ -23,20 +23,28 @@ class BaseAppGenerator:
             value = user_inputs[variable]
 
             if value:
-                if not isinstance(value, str):
+                if variable_config.type != VariableEntity.Type.NUMBER and not isinstance(value, str):
                     raise ValueError(f"{variable} in input form must be a string")
+                elif variable_config.type == VariableEntity.Type.NUMBER and isinstance(value, str):
+                    if '.' in value:
+                        value = float(value)
+                    else:
+                        value = int(value)
 
             if variable_config.type == VariableEntity.Type.SELECT:
                 options = variable_config.options if variable_config.options is not None else []
                 if value not in options:
                     raise ValueError(f"{variable} in input form must be one of the following: {options}")
-            else:
+            elif variable_config.type in [VariableEntity.Type.TEXT_INPUT, VariableEntity.Type.PARAGRAPH]:
                 if variable_config.max_length is not None:
                     max_length = variable_config.max_length
                     if len(value) > max_length:
                         raise ValueError(f'{variable} in input form must be less than {max_length} characters')
 
-            filtered_inputs[variable] = value.replace('\x00', '') if value else None
+            if value and isinstance(value, str):
+                filtered_inputs[variable] = value.replace('\x00', '')
+            else:
+                filtered_inputs[variable] = value if value else None
 
         return filtered_inputs
 

+ 1 - 1
api/core/app/entities/app_invoke_entities.py

@@ -72,7 +72,7 @@ class AppGenerateEntity(BaseModel):
     # app config
     app_config: AppConfig
 
-    inputs: dict[str, str]
+    inputs: dict[str, Any]
     files: list[FileVar] = []
     user_id: str
 

+ 2 - 0
api/core/prompt/advanced_prompt_transform.py

@@ -32,6 +32,8 @@ class AdvancedPromptTransform(PromptTransform):
                    memory_config: Optional[MemoryConfig],
                    memory: Optional[TokenBufferMemory],
                    model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]:
+        inputs = {key: str(value) for key, value in inputs.items()}
+
         prompt_messages = []
 
         model_mode = ModelMode.value_of(model_config.mode)

+ 2 - 0
api/core/prompt/simple_prompt_transform.py

@@ -55,6 +55,8 @@ class SimplePromptTransform(PromptTransform):
                    memory: Optional[TokenBufferMemory],
                    model_config: ModelConfigWithCredentialsEntity) -> \
             tuple[list[PromptMessage], Optional[list[str]]]:
+        inputs = {key: str(value) for key, value in inputs.items()}
+
         model_mode = ModelMode.value_of(model_config.mode)
         if model_mode == ModelMode.CHAT:
             prompt_messages, stops = self._get_chat_model_prompt_messages(

+ 1 - 43
api/core/workflow/nodes/start/start_node.py

@@ -1,6 +1,4 @@
-from typing import cast
 
-from core.app.app_config.entities import VariableEntity
 from core.workflow.entities.base_node_data_entities import BaseNodeData
 from core.workflow.entities.node_entities import NodeRunResult, NodeType, SystemVariable
 from core.workflow.entities.variable_pool import VariablePool
@@ -19,12 +17,8 @@ class StartNode(BaseNode):
         :param variable_pool: variable pool
         :return:
         """
-        node_data = self.node_data
-        node_data = cast(self._node_data_cls, node_data)
-        variables = node_data.variables
-
         # Get cleaned inputs
-        cleaned_inputs = self._get_cleaned_inputs(variables, variable_pool.user_inputs)
+        cleaned_inputs = variable_pool.user_inputs
 
         for var in variable_pool.system_variables:
             if var == SystemVariable.CONVERSATION:
@@ -38,42 +32,6 @@ class StartNode(BaseNode):
             outputs=cleaned_inputs
         )
 
-    def _get_cleaned_inputs(self, variables: list[VariableEntity], user_inputs: dict):
-        if user_inputs is None:
-            user_inputs = {}
-
-        filtered_inputs = {}
-
-        for variable_config in variables:
-            variable = variable_config.variable
-
-            if variable not in user_inputs or not user_inputs[variable]:
-                if variable_config.required:
-                    raise ValueError(f"Input form variable {variable} is required")
-                else:
-                    filtered_inputs[variable] = variable_config.default if variable_config.default is not None else ""
-                    continue
-
-            value = user_inputs[variable]
-
-            if value:
-                if not isinstance(value, str):
-                    raise ValueError(f"{variable} in input form must be a string")
-
-            if variable_config.type == VariableEntity.Type.SELECT:
-                options = variable_config.options if variable_config.options is not None else []
-                if value not in options:
-                    raise ValueError(f"{variable} in input form must be one of the following: {options}")
-            else:
-                if variable_config.max_length is not None:
-                    max_length = variable_config.max_length
-                    if len(value) > max_length:
-                        raise ValueError(f'{variable} in input form must be less than {max_length} characters')
-
-            filtered_inputs[variable] = value.replace('\x00', '') if value else None
-
-        return filtered_inputs
-
     @classmethod
     def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
         """