|
@@ -0,0 +1,137 @@
|
|
|
+import json
|
|
|
+from copy import deepcopy
|
|
|
+from typing import Any, Union
|
|
|
+
|
|
|
+from pandas import DataFrame
|
|
|
+from yarl import URL
|
|
|
+
|
|
|
+from core.helper import ssrf_proxy
|
|
|
+from core.tools.entities.tool_entities import ToolInvokeMessage
|
|
|
+from core.tools.errors import ToolProviderCredentialValidationError
|
|
|
+from core.tools.tool.builtin_tool import BuiltinTool
|
|
|
+
|
|
|
+
|
|
|
+class NovitaAiModelQueryTool(BuiltinTool):
|
|
|
+ _model_query_endpoint = 'https://api.novita.ai/v3/model'
|
|
|
+
|
|
|
+ def _invoke(self,
|
|
|
+ user_id: str,
|
|
|
+ tool_parameters: dict[str, Any],
|
|
|
+ ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
|
|
+ """
|
|
|
+ invoke tools
|
|
|
+ """
|
|
|
+ if 'api_key' not in self.runtime.credentials or not self.runtime.credentials.get('api_key'):
|
|
|
+ raise ToolProviderCredentialValidationError("Novita AI API Key is required.")
|
|
|
+
|
|
|
+ api_key = self.runtime.credentials.get('api_key')
|
|
|
+ headers = {
|
|
|
+ 'Content-Type': 'application/json',
|
|
|
+ 'Authorization': "Bearer " + api_key
|
|
|
+ }
|
|
|
+ params = self._process_parameters(tool_parameters)
|
|
|
+ result_type = params.get('result_type')
|
|
|
+ del params['result_type']
|
|
|
+
|
|
|
+ models_data = self._query_models(
|
|
|
+ models_data=[],
|
|
|
+ headers=headers,
|
|
|
+ params=params,
|
|
|
+ recursive=False if result_type == 'first sd_name' or result_type == 'first name sd_name pair' else True
|
|
|
+ )
|
|
|
+
|
|
|
+ result_str = ''
|
|
|
+ if result_type == 'first sd_name':
|
|
|
+ result_str = models_data[0]['sd_name_in_api']
|
|
|
+ elif result_type == 'first name sd_name pair':
|
|
|
+ result_str = json.dumps({'name': models_data[0]['name'], 'sd_name': models_data[0]['sd_name_in_api']})
|
|
|
+ elif result_type == 'sd_name array':
|
|
|
+ sd_name_array = [model['sd_name_in_api'] for model in models_data]
|
|
|
+ result_str = json.dumps(sd_name_array)
|
|
|
+ elif result_type == 'name array':
|
|
|
+ name_array = [model['name'] for model in models_data]
|
|
|
+ result_str = json.dumps(name_array)
|
|
|
+ elif result_type == 'name sd_name pair array':
|
|
|
+ name_sd_name_pair_array = [{'name': model['name'], 'sd_name': model['sd_name_in_api']} for model in models_data]
|
|
|
+ result_str = json.dumps(name_sd_name_pair_array)
|
|
|
+ elif result_type == 'whole info array':
|
|
|
+ result_str = json.dumps(models_data)
|
|
|
+ else:
|
|
|
+ raise NotImplementedError
|
|
|
+
|
|
|
+ return self.create_text_message(result_str)
|
|
|
+
|
|
|
+ def _query_models(self, models_data: list, headers: dict[str, Any],
|
|
|
+ params: dict[str, Any], pagination_cursor: str = '', recursive: bool = True) -> list:
|
|
|
+ """
|
|
|
+ query models
|
|
|
+ """
|
|
|
+ inside_params = deepcopy(params)
|
|
|
+
|
|
|
+ if pagination_cursor != '':
|
|
|
+ inside_params['pagination.cursor'] = pagination_cursor
|
|
|
+
|
|
|
+ response = ssrf_proxy.get(
|
|
|
+ url=str(URL(self._model_query_endpoint)),
|
|
|
+ headers=headers,
|
|
|
+ params=params,
|
|
|
+ timeout=(10, 60)
|
|
|
+ )
|
|
|
+
|
|
|
+ res_data = response.json()
|
|
|
+
|
|
|
+ models_data.extend(res_data['models'])
|
|
|
+
|
|
|
+ res_data_len = len(res_data['models'])
|
|
|
+ if res_data_len == 0 or res_data_len < int(params['pagination.limit']) or recursive is False:
|
|
|
+ # deduplicate
|
|
|
+ df = DataFrame.from_dict(models_data)
|
|
|
+ df_unique = df.drop_duplicates(subset=['id'])
|
|
|
+ models_data = df_unique.to_dict('records')
|
|
|
+ return models_data
|
|
|
+
|
|
|
+ return self._query_models(
|
|
|
+ models_data=models_data,
|
|
|
+ headers=headers,
|
|
|
+ params=inside_params,
|
|
|
+ pagination_cursor=res_data['pagination']['next_cursor']
|
|
|
+ )
|
|
|
+
|
|
|
+ def _process_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
|
|
|
+ """
|
|
|
+ process parameters
|
|
|
+ """
|
|
|
+ process_parameters = deepcopy(parameters)
|
|
|
+ res_parameters = {}
|
|
|
+
|
|
|
+ # delete none or empty
|
|
|
+ keys_to_delete = [k for k, v in process_parameters.items() if v is None or v == '']
|
|
|
+ for k in keys_to_delete:
|
|
|
+ del process_parameters[k]
|
|
|
+
|
|
|
+ if 'query' in process_parameters and process_parameters.get('query') != 'unspecified':
|
|
|
+ res_parameters['filter.query'] = process_parameters['query']
|
|
|
+
|
|
|
+ if 'visibility' in process_parameters and process_parameters.get('visibility') != 'unspecified':
|
|
|
+ res_parameters['filter.visibility'] = process_parameters['visibility']
|
|
|
+
|
|
|
+ if 'source' in process_parameters and process_parameters.get('source') != 'unspecified':
|
|
|
+ res_parameters['filter.source'] = process_parameters['source']
|
|
|
+
|
|
|
+ if 'type' in process_parameters and process_parameters.get('type') != 'unspecified':
|
|
|
+ res_parameters['filter.types'] = process_parameters['type']
|
|
|
+
|
|
|
+ if 'is_sdxl' in process_parameters:
|
|
|
+ if process_parameters['is_sdxl'] == 'true':
|
|
|
+ res_parameters['filter.is_sdxl'] = True
|
|
|
+ elif process_parameters['is_sdxl'] == 'false':
|
|
|
+ res_parameters['filter.is_sdxl'] = False
|
|
|
+
|
|
|
+ res_parameters['result_type'] = process_parameters.get('result_type', 'first sd_name')
|
|
|
+
|
|
|
+ res_parameters['pagination.limit'] = 1 \
|
|
|
+ if res_parameters.get('result_type') == 'first sd_name' \
|
|
|
+ or res_parameters.get('result_type') == 'first name sd_name pair'\
|
|
|
+ else 100
|
|
|
+
|
|
|
+ return res_parameters
|