|
@@ -7,6 +7,7 @@ from core.tools.tool.builtin_tool import BuiltinTool
|
|
|
|
|
|
WIKIPEDIA_MAX_QUERY_LENGTH = 300
|
|
|
|
|
|
+
|
|
|
class WikipediaAPIWrapper:
|
|
|
"""Wrapper around WikipediaAPI.
|
|
|
|
|
@@ -25,7 +26,10 @@ class WikipediaAPIWrapper:
|
|
|
def __init__(self, doc_content_chars_max: int = 4000):
|
|
|
self.doc_content_chars_max = doc_content_chars_max
|
|
|
|
|
|
- def run(self, query: str) -> str:
|
|
|
+ def run(self, query: str, lang: str = "") -> str:
|
|
|
+ if lang in wikipedia.languages().keys():
|
|
|
+ self.lang = lang
|
|
|
+
|
|
|
wikipedia.set_lang(self.lang)
|
|
|
wiki_client = wikipedia
|
|
|
|
|
@@ -53,6 +57,7 @@ class WikipediaAPIWrapper:
|
|
|
):
|
|
|
return None
|
|
|
|
|
|
+
|
|
|
class WikipediaQueryRun:
|
|
|
"""Tool that searches the Wikipedia API."""
|
|
|
|
|
@@ -71,26 +76,31 @@ class WikipediaQueryRun:
|
|
|
def _run(
|
|
|
self,
|
|
|
query: str,
|
|
|
+ lang: str = "",
|
|
|
) -> str:
|
|
|
"""Use the Wikipedia tool."""
|
|
|
- return self.api_wrapper.run(query)
|
|
|
+ return self.api_wrapper.run(query, lang)
|
|
|
+
|
|
|
+
|
|
|
class WikiPediaSearchTool(BuiltinTool):
|
|
|
- def _invoke(self,
|
|
|
- user_id: str,
|
|
|
- tool_parameters: dict[str, Any],
|
|
|
- ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
|
|
+
|
|
|
+ def _invoke(
|
|
|
+ self,
|
|
|
+ user_id: str,
|
|
|
+ tool_parameters: dict[str, Any],
|
|
|
+ ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
|
|
"""
|
|
|
- invoke tools
|
|
|
+ invoke tools
|
|
|
"""
|
|
|
- query = tool_parameters.get('query', '')
|
|
|
+ query = tool_parameters.get("query", "")
|
|
|
+ lang = tool_parameters.get("language", "")
|
|
|
if not query:
|
|
|
- return self.create_text_message('Please input query')
|
|
|
-
|
|
|
+ return self.create_text_message("Please input query")
|
|
|
+
|
|
|
tool = WikipediaQueryRun(
|
|
|
api_wrapper=WikipediaAPIWrapper(doc_content_chars_max=4000),
|
|
|
)
|
|
|
|
|
|
- result = tool._run(query)
|
|
|
+ result = tool._run(query, lang)
|
|
|
|
|
|
- return self.create_text_message(self.summary(user_id=user_id,content=result))
|
|
|
-
|
|
|
+ return self.create_text_message(self.summary(user_id=user_id, content=result))
|