Browse Source

feat: allow non-english wikipedias to be searched (#5371)

kurokobo 10 months ago
parent
commit
0e3113b7ce

+ 23 - 13
api/core/tools/provider/builtin/wikipedia/tools/wikipedia_search.py

@@ -7,6 +7,7 @@ from core.tools.tool.builtin_tool import BuiltinTool
 
 WIKIPEDIA_MAX_QUERY_LENGTH = 300
 
+
 class WikipediaAPIWrapper:
     """Wrapper around WikipediaAPI.
 
@@ -25,7 +26,10 @@ class WikipediaAPIWrapper:
     def __init__(self, doc_content_chars_max: int = 4000):
         self.doc_content_chars_max = doc_content_chars_max
 
-    def run(self, query: str) -> str:
+    def run(self, query: str, lang: str = "") -> str:
+        if lang in wikipedia.languages().keys():
+            self.lang = lang
+
         wikipedia.set_lang(self.lang)
         wiki_client = wikipedia
 
@@ -53,6 +57,7 @@ class WikipediaAPIWrapper:
         ):
             return None
 
+
 class WikipediaQueryRun:
     """Tool that searches the Wikipedia API."""
 
@@ -71,26 +76,31 @@ class WikipediaQueryRun:
     def _run(
         self,
         query: str,
+        lang: str = "",
     ) -> str:
         """Use the Wikipedia tool."""
-        return self.api_wrapper.run(query)
+        return self.api_wrapper.run(query, lang)
+
+
 class WikiPediaSearchTool(BuiltinTool):
-    def _invoke(self, 
-                user_id: str, 
-               tool_parameters: dict[str, Any], 
-        ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
+
+    def _invoke(
+        self,
+        user_id: str,
+        tool_parameters: dict[str, Any],
+    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
         """
-            invoke tools
+        invoke tools
         """
-        query = tool_parameters.get('query', '')
+        query = tool_parameters.get("query", "")
+        lang = tool_parameters.get("language", "")
         if not query:
-            return self.create_text_message('Please input query')
-        
+            return self.create_text_message("Please input query")
+
         tool = WikipediaQueryRun(
             api_wrapper=WikipediaAPIWrapper(doc_content_chars_max=4000),
         )
 
-        result = tool._run(query)
+        result = tool._run(query, lang)
 
-        return self.create_text_message(self.summary(user_id=user_id,content=result))
-    
+        return self.create_text_message(self.summary(user_id=user_id, content=result))

+ 74 - 1
api/core/tools/provider/builtin/wikipedia/tools/wikipedia_search.yaml

@@ -24,5 +24,78 @@ parameters:
       en_US: key words for searching
       zh_Hans: 查询关键词
       pt_BR: key words for searching
-    llm_description: key words for searching
+    llm_description: key words for searching, this should be in the language of "language" parameter
     form: llm
+  - name: language
+    type: string
+    required: true
+    label:
+      en_US: Language
+      zh_Hans: 语言
+    human_description:
+      en_US: The language of the Wikipedia to be searched
+      zh_Hans: 要搜索的维基百科语言
+    llm_description: >-
+      language of the wikipedia to be searched,
+      only "de" for German,
+      "en" for English,
+      "fr" for French,
+      "hi" for Hindi,
+      "ja" for Japanese,
+      "ko" for Korean,
+      "pl" for Polish,
+      "pt" for Portuguese,
+      "ro" for Romanian,
+      "uk" for Ukrainian,
+      "vi" for Vietnamese,
+      and "zh" for Chinese are supported
+    form: llm
+    options:
+      - value: de
+        label:
+          en_US: German
+          zh_Hans: 德语
+      - value: en
+        label:
+          en_US: English
+          zh_Hans: 英语
+      - value: fr
+        label:
+          en_US: French
+          zh_Hans: 法语
+      - value: hi
+        label:
+          en_US: Hindi
+          zh_Hans: 印地语
+      - value: ja
+        label:
+          en_US: Japanese
+          zh_Hans: 日语
+      - value: ko
+        label:
+          en_US: Korean
+          zh_Hans: 韩语
+      - value: pl
+        label:
+          en_US: Polish
+          zh_Hans: 波兰语
+      - value: pt
+        label:
+          en_US: Portuguese
+          zh_Hans: 葡萄牙语
+      - value: ro
+        label:
+          en_US: Romanian
+          zh_Hans: 罗马尼亚语
+      - value: uk
+        label:
+          en_US: Ukrainian
+          zh_Hans: 乌克兰语
+      - value: vi
+        label:
+          en_US: Vietnamese
+          zh_Hans: 越南语
+      - value: zh
+        label:
+          en_US: Chinese
+          zh_Hans: 中文