tools_manage_service.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658
  1. import json
  2. from flask import current_app
  3. from httpx import get
  4. from core.tools.entities.common_entities import I18nObject
  5. from core.tools.entities.tool_bundle import ApiBasedToolBundle
  6. from core.tools.entities.tool_entities import (
  7. ApiProviderAuthType,
  8. ApiProviderSchemaType,
  9. ToolCredentialsOption,
  10. ToolParameter,
  11. ToolProviderCredentials,
  12. )
  13. from core.tools.entities.user_entities import UserTool, UserToolProvider
  14. from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError
  15. from core.tools.provider.api_tool_provider import ApiBasedToolProviderController
  16. from core.tools.provider.tool_provider import ToolProviderController
  17. from core.tools.tool_manager import ToolManager
  18. from core.tools.utils.configuration import ToolConfigurationManager
  19. from core.tools.utils.encoder import serialize_base_model_array, serialize_base_model_dict
  20. from core.tools.utils.parser import ApiBasedToolSchemaParser
  21. from extensions.ext_database import db
  22. from models.tools import ApiToolProvider, BuiltinToolProvider
  23. from services.model_provider_service import ModelProviderService
  24. class ToolManageService:
  25. @staticmethod
  26. def list_tool_providers(user_id: str, tenant_id: str):
  27. """
  28. list tool providers
  29. :return: the list of tool providers
  30. """
  31. result = [provider.to_dict() for provider in ToolManager.user_list_providers(
  32. user_id, tenant_id
  33. )]
  34. # add icon url prefix
  35. for provider in result:
  36. ToolManageService.repack_provider(provider)
  37. return result
  38. @staticmethod
  39. def repack_provider(provider: dict):
  40. """
  41. repack provider
  42. :param provider: the provider dict
  43. """
  44. url_prefix = (current_app.config.get("CONSOLE_API_URL")
  45. + "/console/api/workspaces/current/tool-provider/")
  46. if 'icon' in provider:
  47. if provider['type'] == UserToolProvider.ProviderType.BUILTIN.value:
  48. provider['icon'] = url_prefix + 'builtin/' + provider['name'] + '/icon'
  49. elif provider['type'] == UserToolProvider.ProviderType.MODEL.value:
  50. provider['icon'] = url_prefix + 'model/' + provider['name'] + '/icon'
  51. elif provider['type'] == UserToolProvider.ProviderType.API.value:
  52. try:
  53. provider['icon'] = json.loads(provider['icon'])
  54. except:
  55. provider['icon'] = {
  56. "background": "#252525",
  57. "content": "\ud83d\ude01"
  58. }
  59. @staticmethod
  60. def list_builtin_tool_provider_tools(
  61. user_id: str, tenant_id: str, provider: str
  62. ):
  63. """
  64. list builtin tool provider tools
  65. """
  66. provider_controller: ToolProviderController = ToolManager.get_builtin_provider(provider)
  67. tools = provider_controller.get_tools()
  68. tool_provider_configurations = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
  69. # check if user has added the provider
  70. builtin_provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
  71. BuiltinToolProvider.tenant_id == tenant_id,
  72. BuiltinToolProvider.provider == provider,
  73. ).first()
  74. credentials = {}
  75. if builtin_provider is not None:
  76. # get credentials
  77. credentials = builtin_provider.credentials
  78. credentials = tool_provider_configurations.decrypt_tool_credentials(credentials)
  79. result = []
  80. for tool in tools:
  81. # fork tool runtime
  82. tool = tool.fork_tool_runtime(meta={
  83. 'credentials': credentials,
  84. 'tenant_id': tenant_id,
  85. })
  86. # get tool parameters
  87. parameters = tool.parameters or []
  88. # get tool runtime parameters
  89. runtime_parameters = tool.get_runtime_parameters()
  90. # override parameters
  91. current_parameters = parameters.copy()
  92. for runtime_parameter in runtime_parameters:
  93. found = False
  94. for index, parameter in enumerate(current_parameters):
  95. if parameter.name == runtime_parameter.name and parameter.form == runtime_parameter.form:
  96. current_parameters[index] = runtime_parameter
  97. found = True
  98. break
  99. if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM:
  100. current_parameters.append(runtime_parameter)
  101. user_tool = UserTool(
  102. author=tool.identity.author,
  103. name=tool.identity.name,
  104. label=tool.identity.label,
  105. description=tool.description.human,
  106. parameters=current_parameters
  107. )
  108. result.append(user_tool)
  109. return json.loads(
  110. serialize_base_model_array(result)
  111. )
  112. @staticmethod
  113. def list_builtin_provider_credentials_schema(
  114. provider_name
  115. ):
  116. """
  117. list builtin provider credentials schema
  118. :return: the list of tool providers
  119. """
  120. provider = ToolManager.get_builtin_provider(provider_name)
  121. return [
  122. v.to_dict() for _, v in (provider.credentials_schema or {}).items()
  123. ]
  124. @staticmethod
  125. def parser_api_schema(schema: str) -> list[ApiBasedToolBundle]:
  126. """
  127. parse api schema to tool bundle
  128. """
  129. try:
  130. warnings = {}
  131. try:
  132. tool_bundles, schema_type = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, warning=warnings)
  133. except Exception as e:
  134. raise ValueError(f'invalid schema: {str(e)}')
  135. credentials_schema = [
  136. ToolProviderCredentials(
  137. name='auth_type',
  138. type=ToolProviderCredentials.CredentialsType.SELECT,
  139. required=True,
  140. default='none',
  141. options=[
  142. ToolCredentialsOption(value='none', label=I18nObject(
  143. en_US='None',
  144. zh_Hans='无'
  145. )),
  146. ToolCredentialsOption(value='api_key', label=I18nObject(
  147. en_US='Api Key',
  148. zh_Hans='Api Key'
  149. )),
  150. ],
  151. placeholder=I18nObject(
  152. en_US='Select auth type',
  153. zh_Hans='选择认证方式'
  154. )
  155. ),
  156. ToolProviderCredentials(
  157. name='api_key_header',
  158. type=ToolProviderCredentials.CredentialsType.TEXT_INPUT,
  159. required=False,
  160. placeholder=I18nObject(
  161. en_US='Enter api key header',
  162. zh_Hans='输入 api key header,如:X-API-KEY'
  163. ),
  164. default='api_key',
  165. help=I18nObject(
  166. en_US='HTTP header name for api key',
  167. zh_Hans='HTTP 头部字段名,用于传递 api key'
  168. )
  169. ),
  170. ToolProviderCredentials(
  171. name='api_key_value',
  172. type=ToolProviderCredentials.CredentialsType.TEXT_INPUT,
  173. required=False,
  174. placeholder=I18nObject(
  175. en_US='Enter api key',
  176. zh_Hans='输入 api key'
  177. ),
  178. default=''
  179. ),
  180. ]
  181. return json.loads(serialize_base_model_dict(
  182. {
  183. 'schema_type': schema_type,
  184. 'parameters_schema': tool_bundles,
  185. 'credentials_schema': credentials_schema,
  186. 'warning': warnings
  187. }
  188. ))
  189. except Exception as e:
  190. raise ValueError(f'invalid schema: {str(e)}')
  191. @staticmethod
  192. def convert_schema_to_tool_bundles(schema: str, extra_info: dict = None) -> list[ApiBasedToolBundle]:
  193. """
  194. convert schema to tool bundles
  195. :return: the list of tool bundles, description
  196. """
  197. try:
  198. tool_bundles = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, extra_info=extra_info)
  199. return tool_bundles
  200. except Exception as e:
  201. raise ValueError(f'invalid schema: {str(e)}')
  202. @staticmethod
  203. def create_api_tool_provider(
  204. user_id: str, tenant_id: str, provider_name: str, icon: dict, credentials: dict,
  205. schema_type: str, schema: str, privacy_policy: str
  206. ):
  207. """
  208. create api tool provider
  209. """
  210. if schema_type not in [member.value for member in ApiProviderSchemaType]:
  211. raise ValueError(f'invalid schema type {schema}')
  212. # check if the provider exists
  213. provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
  214. ApiToolProvider.tenant_id == tenant_id,
  215. ApiToolProvider.name == provider_name,
  216. ).first()
  217. if provider is not None:
  218. raise ValueError(f'provider {provider_name} already exists')
  219. # parse openapi to tool bundle
  220. extra_info = {}
  221. # extra info like description will be set here
  222. tool_bundles, schema_type = ToolManageService.convert_schema_to_tool_bundles(schema, extra_info)
  223. if len(tool_bundles) > 100:
  224. raise ValueError('the number of apis should be less than 100')
  225. # create db provider
  226. db_provider = ApiToolProvider(
  227. tenant_id=tenant_id,
  228. user_id=user_id,
  229. name=provider_name,
  230. icon=json.dumps(icon),
  231. schema=schema,
  232. description=extra_info.get('description', ''),
  233. schema_type_str=schema_type,
  234. tools_str=serialize_base_model_array(tool_bundles),
  235. credentials_str={},
  236. privacy_policy=privacy_policy
  237. )
  238. if 'auth_type' not in credentials:
  239. raise ValueError('auth_type is required')
  240. # get auth type, none or api key
  241. auth_type = ApiProviderAuthType.value_of(credentials['auth_type'])
  242. # create provider entity
  243. provider_controller = ApiBasedToolProviderController.from_db(db_provider, auth_type)
  244. # load tools into provider entity
  245. provider_controller.load_bundled_tools(tool_bundles)
  246. # encrypt credentials
  247. tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
  248. encrypted_credentials = tool_configuration.encrypt_tool_credentials(credentials)
  249. db_provider.credentials_str = json.dumps(encrypted_credentials)
  250. db.session.add(db_provider)
  251. db.session.commit()
  252. return { 'result': 'success' }
  253. @staticmethod
  254. def get_api_tool_provider_remote_schema(
  255. user_id: str, tenant_id: str, url: str
  256. ):
  257. """
  258. get api tool provider remote schema
  259. """
  260. headers = {
  261. "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36 Edg/120.0.0.0",
  262. "Accept": "*/*",
  263. }
  264. try:
  265. response = get(url, headers=headers, timeout=10)
  266. if response.status_code != 200:
  267. raise ValueError(f'Got status code {response.status_code}')
  268. schema = response.text
  269. # try to parse schema, avoid SSRF attack
  270. ToolManageService.parser_api_schema(schema)
  271. except Exception as e:
  272. raise ValueError('invalid schema, please check the url you provided')
  273. return {
  274. 'schema': schema
  275. }
  276. @staticmethod
  277. def list_api_tool_provider_tools(
  278. user_id: str, tenant_id: str, provider: str
  279. ):
  280. """
  281. list api tool provider tools
  282. """
  283. provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
  284. ApiToolProvider.tenant_id == tenant_id,
  285. ApiToolProvider.name == provider,
  286. ).first()
  287. if provider is None:
  288. raise ValueError(f'you have not added provider {provider}')
  289. return json.loads(
  290. serialize_base_model_array([
  291. UserTool(
  292. author=tool_bundle.author,
  293. name=tool_bundle.operation_id,
  294. label=I18nObject(
  295. en_US=tool_bundle.operation_id,
  296. zh_Hans=tool_bundle.operation_id
  297. ),
  298. description=I18nObject(
  299. en_US=tool_bundle.summary or '',
  300. zh_Hans=tool_bundle.summary or ''
  301. ),
  302. parameters=tool_bundle.parameters
  303. ) for tool_bundle in provider.tools
  304. ])
  305. )
  306. @staticmethod
  307. def update_builtin_tool_provider(
  308. user_id: str, tenant_id: str, provider_name: str, credentials: dict
  309. ):
  310. """
  311. update builtin tool provider
  312. """
  313. # get if the provider exists
  314. provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
  315. BuiltinToolProvider.tenant_id == tenant_id,
  316. BuiltinToolProvider.provider == provider_name,
  317. ).first()
  318. try:
  319. # get provider
  320. provider_controller = ToolManager.get_builtin_provider(provider_name)
  321. if not provider_controller.need_credentials:
  322. raise ValueError(f'provider {provider_name} does not need credentials')
  323. tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
  324. # get original credentials if exists
  325. if provider is not None:
  326. original_credentials = tool_configuration.decrypt_tool_credentials(provider.credentials)
  327. masked_credentials = tool_configuration.mask_tool_credentials(original_credentials)
  328. # check if the credential has changed, save the original credential
  329. for name, value in credentials.items():
  330. if name in masked_credentials and value == masked_credentials[name]:
  331. credentials[name] = original_credentials[name]
  332. # validate credentials
  333. provider_controller.validate_credentials(credentials)
  334. # encrypt credentials
  335. credentials = tool_configuration.encrypt_tool_credentials(credentials)
  336. except (ToolProviderNotFoundError, ToolNotFoundError, ToolProviderCredentialValidationError) as e:
  337. raise ValueError(str(e))
  338. if provider is None:
  339. # create provider
  340. provider = BuiltinToolProvider(
  341. tenant_id=tenant_id,
  342. user_id=user_id,
  343. provider=provider_name,
  344. encrypted_credentials=json.dumps(credentials),
  345. )
  346. db.session.add(provider)
  347. db.session.commit()
  348. else:
  349. provider.encrypted_credentials = json.dumps(credentials)
  350. db.session.add(provider)
  351. db.session.commit()
  352. # delete cache
  353. tool_configuration.delete_tool_credentials_cache()
  354. return { 'result': 'success' }
  355. @staticmethod
  356. def update_api_tool_provider(
  357. user_id: str, tenant_id: str, provider_name: str, original_provider: str, icon: dict, credentials: dict,
  358. schema_type: str, schema: str, privacy_policy: str
  359. ):
  360. """
  361. update api tool provider
  362. """
  363. if schema_type not in [member.value for member in ApiProviderSchemaType]:
  364. raise ValueError(f'invalid schema type {schema}')
  365. # check if the provider exists
  366. provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
  367. ApiToolProvider.tenant_id == tenant_id,
  368. ApiToolProvider.name == original_provider,
  369. ).first()
  370. if provider is None:
  371. raise ValueError(f'api provider {provider_name} does not exists')
  372. # parse openapi to tool bundle
  373. extra_info = {}
  374. # extra info like description will be set here
  375. tool_bundles, schema_type = ToolManageService.convert_schema_to_tool_bundles(schema, extra_info)
  376. # update db provider
  377. provider.name = provider_name
  378. provider.icon = json.dumps(icon)
  379. provider.schema = schema
  380. provider.description = extra_info.get('description', '')
  381. provider.schema_type_str = ApiProviderSchemaType.OPENAPI.value
  382. provider.tools_str = serialize_base_model_array(tool_bundles)
  383. provider.privacy_policy = privacy_policy
  384. if 'auth_type' not in credentials:
  385. raise ValueError('auth_type is required')
  386. # get auth type, none or api key
  387. auth_type = ApiProviderAuthType.value_of(credentials['auth_type'])
  388. # create provider entity
  389. provider_controller = ApiBasedToolProviderController.from_db(provider, auth_type)
  390. # load tools into provider entity
  391. provider_controller.load_bundled_tools(tool_bundles)
  392. # get original credentials if exists
  393. tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
  394. original_credentials = tool_configuration.decrypt_tool_credentials(provider.credentials)
  395. masked_credentials = tool_configuration.mask_tool_credentials(original_credentials)
  396. # check if the credential has changed, save the original credential
  397. for name, value in credentials.items():
  398. if name in masked_credentials and value == masked_credentials[name]:
  399. credentials[name] = original_credentials[name]
  400. credentials = tool_configuration.encrypt_tool_credentials(credentials)
  401. provider.credentials_str = json.dumps(credentials)
  402. db.session.add(provider)
  403. db.session.commit()
  404. # delete cache
  405. tool_configuration.delete_tool_credentials_cache()
  406. return { 'result': 'success' }
  407. @staticmethod
  408. def delete_builtin_tool_provider(
  409. user_id: str, tenant_id: str, provider_name: str
  410. ):
  411. """
  412. delete tool provider
  413. """
  414. provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
  415. BuiltinToolProvider.tenant_id == tenant_id,
  416. BuiltinToolProvider.provider == provider_name,
  417. ).first()
  418. if provider is None:
  419. raise ValueError(f'you have not added provider {provider_name}')
  420. db.session.delete(provider)
  421. db.session.commit()
  422. # delete cache
  423. provider_controller = ToolManager.get_builtin_provider(provider_name)
  424. tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
  425. tool_configuration.delete_tool_credentials_cache()
  426. return { 'result': 'success' }
  427. @staticmethod
  428. def get_builtin_tool_provider_icon(
  429. provider: str
  430. ):
  431. """
  432. get tool provider icon and it's mimetype
  433. """
  434. icon_path, mime_type = ToolManager.get_builtin_provider_icon(provider)
  435. with open(icon_path, 'rb') as f:
  436. icon_bytes = f.read()
  437. return icon_bytes, mime_type
  438. @staticmethod
  439. def get_model_tool_provider_icon(
  440. provider: str
  441. ):
  442. """
  443. get tool provider icon and it's mimetype
  444. """
  445. service = ModelProviderService()
  446. icon_bytes, mime_type = service.get_model_provider_icon(provider=provider, icon_type='icon_small', lang='en_US')
  447. if icon_bytes is None:
  448. raise ValueError(f'provider {provider} does not exists')
  449. return icon_bytes, mime_type
  450. @staticmethod
  451. def list_model_tool_provider_tools(
  452. user_id: str, tenant_id: str, provider: str
  453. ):
  454. """
  455. list model tool provider tools
  456. """
  457. provider_controller = ToolManager.get_model_provider(tenant_id=tenant_id, provider_name=provider)
  458. tools = provider_controller.get_tools(user_id=user_id, tenant_id=tenant_id)
  459. result = [
  460. UserTool(
  461. author=tool.identity.author,
  462. name=tool.identity.name,
  463. label=tool.identity.label,
  464. description=tool.description.human,
  465. parameters=tool.parameters or []
  466. ) for tool in tools
  467. ]
  468. return json.loads(
  469. serialize_base_model_array(result)
  470. )
  471. @staticmethod
  472. def delete_api_tool_provider(
  473. user_id: str, tenant_id: str, provider_name: str
  474. ):
  475. """
  476. delete tool provider
  477. """
  478. provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
  479. ApiToolProvider.tenant_id == tenant_id,
  480. ApiToolProvider.name == provider_name,
  481. ).first()
  482. if provider is None:
  483. raise ValueError(f'you have not added provider {provider_name}')
  484. db.session.delete(provider)
  485. db.session.commit()
  486. return { 'result': 'success' }
  487. @staticmethod
  488. def get_api_tool_provider(
  489. user_id: str, tenant_id: str, provider: str
  490. ):
  491. """
  492. get api tool provider
  493. """
  494. return ToolManager.user_get_api_provider(provider=provider, tenant_id=tenant_id)
  495. @staticmethod
  496. def test_api_tool_preview(
  497. tenant_id: str,
  498. provider_name: str,
  499. tool_name: str,
  500. credentials: dict,
  501. parameters: dict,
  502. schema_type: str,
  503. schema: str
  504. ):
  505. """
  506. test api tool before adding api tool provider
  507. """
  508. if schema_type not in [member.value for member in ApiProviderSchemaType]:
  509. raise ValueError(f'invalid schema type {schema_type}')
  510. try:
  511. tool_bundles, _ = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema)
  512. except Exception as e:
  513. raise ValueError('invalid schema')
  514. # get tool bundle
  515. tool_bundle = next(filter(lambda tb: tb.operation_id == tool_name, tool_bundles), None)
  516. if tool_bundle is None:
  517. raise ValueError(f'invalid tool name {tool_name}')
  518. db_provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
  519. ApiToolProvider.tenant_id == tenant_id,
  520. ApiToolProvider.name == provider_name,
  521. ).first()
  522. if not db_provider:
  523. # create a fake db provider
  524. db_provider = ApiToolProvider(
  525. tenant_id='', user_id='', name='', icon='',
  526. schema=schema,
  527. description='',
  528. schema_type_str=ApiProviderSchemaType.OPENAPI.value,
  529. tools_str=serialize_base_model_array(tool_bundles),
  530. credentials_str=json.dumps(credentials),
  531. )
  532. if 'auth_type' not in credentials:
  533. raise ValueError('auth_type is required')
  534. # get auth type, none or api key
  535. auth_type = ApiProviderAuthType.value_of(credentials['auth_type'])
  536. # create provider entity
  537. provider_controller = ApiBasedToolProviderController.from_db(db_provider, auth_type)
  538. # load tools into provider entity
  539. provider_controller.load_bundled_tools(tool_bundles)
  540. # decrypt credentials
  541. if db_provider.id:
  542. tool_configuration = ToolConfigurationManager(
  543. tenant_id=tenant_id,
  544. provider_controller=provider_controller
  545. )
  546. decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
  547. # check if the credential has changed, save the original credential
  548. masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
  549. for name, value in credentials.items():
  550. if name in masked_credentials and value == masked_credentials[name]:
  551. credentials[name] = decrypted_credentials[name]
  552. try:
  553. provider_controller.validate_credentials_format(credentials)
  554. # get tool
  555. tool = provider_controller.get_tool(tool_name)
  556. tool = tool.fork_tool_runtime(meta={
  557. 'credentials': credentials,
  558. 'tenant_id': tenant_id,
  559. })
  560. result = tool.validate_credentials(credentials, parameters)
  561. except Exception as e:
  562. return { 'error': str(e) }
  563. return { 'result': result or 'empty response' }