diff --git a/bot/extensions/wikipedia_extension.py b/bot/extensions/wikipedia_extension.py new file mode 100644 index 0000000..cda9043 --- /dev/null +++ b/bot/extensions/wikipedia_extension.py @@ -0,0 +1,80 @@ +import asyncio +import requests +from matrix import Extension, Context + +type WikiPayload = tuple[str, list[str], list[str], list[str]] + +_RESULT_LIMIT = 3 +_REQUEST_TIMEOUT = 5 +_USER_AGENT = "AdaBot/1.0 (https://github.com/Code-Society-Lab/ada)" +_API_URL = "https://en.wikipedia.org/w/api.php" + + +extension = Extension("wikipedia") + + +@extension.command( + usage="wiki ", + description="Search Wikipedia and display the top results.", +) +async def wiki(ctx: Context, *args: str) -> None: + if not args: + raise ValueError("Please provide a search query. Usage: !wiki ") + + query = " ".join(args).strip() + if len(query) > 300: + raise ValueError( + "Search query is too long. Please limit it to 300 characters or less." + ) + payload = await asyncio.to_thread(_search_wikipedia, query) + result_message = _format_results(query, payload) + await ctx.reply(result_message) + + +@wiki.error(exception=requests.RequestException) +async def wiki_unreachable(ctx: Context, error: requests.RequestException) -> None: + await ctx.reply("Sorry, something went wrong while contacting Wikipedia") + + +@wiki.error(exception=ValueError) +async def wiki_invalid(ctx: Context, error: ValueError) -> None: + await ctx.reply(str(error)) + + +def _search_wikipedia(query: str) -> WikiPayload: + params: dict[str, str | int] = { + "action": "opensearch", + "format": "json", + "namespace": 0, + "limit": _RESULT_LIMIT, + "search": query, + } + response = requests.get( + _API_URL, + params=params, + headers={"User-Agent": _USER_AGENT}, + timeout=_REQUEST_TIMEOUT, + ) + response.raise_for_status() + return response.json() + + +def _format_results(query: str, payload: WikiPayload) -> str: + if ( + not isinstance(payload, (list, tuple)) + or len(payload) < 4 + or not isinstance(payload[1], list) + or not isinstance(payload[3], list) + ): + raise ValueError("Unexpected response format from Wikipedia API.") + + titles, urls = payload[1], payload[3] + if not titles: + raise ValueError(f"No results found for '{query}'.") + + result_lines = [ + f"> **{i}.** [{title}](<{url}>)" + for i, (title, url) in enumerate(zip(titles, urls), start=1) + ] + header = f'#### Wikipedia results for "_{query}_"' + return f"{header}\n\n" + "\n".join(result_lines) diff --git a/pyproject.toml b/pyproject.toml index 106bc3c..20f1f92 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,8 @@ dependencies = [ "coloredlogs==15.0.1", "click==8.3.2", "sqlmodel-toolkit @ git+https://github.com/Code-Society-Lab/sqlmodel-toolkit.git", - "pelican-migration>=0.1.2a0" + "pelican-migration>=0.1.2a0", + "requests==2.33.1" ] [project.optional-dependencies] @@ -24,7 +25,8 @@ dev = [ "pytest-asyncio>=0.24", "flake8==7.3.0", "mypy==1.20.1", - "types-PyYAML==6.0.12.20260408" + "types-PyYAML==6.0.12.20260408", + "types-requests" ] [project.scripts] diff --git a/tests/extensions/test_wikipedia_extension.py b/tests/extensions/test_wikipedia_extension.py new file mode 100644 index 0000000..5540547 --- /dev/null +++ b/tests/extensions/test_wikipedia_extension.py @@ -0,0 +1,91 @@ +import pytest +import requests +from unittest.mock import patch, MagicMock + +from bot.extensions.wikipedia_extension import ( + _format_results, + _search_wikipedia, +) + + +def test_format_results_renders_header_and_lines() -> None: + payload = ("py", ["Python"], [""], ["https://en.wikipedia.org/wiki/Python"]) + result = _format_results("py", payload) + + assert '#### Wikipedia results for "_py_"' in result + assert "> **1.** [Python]()" in result + + +def test_format_results_raises_on_non_list_payload() -> None: + with pytest.raises(ValueError, match="Unexpected response format"): + _format_results("q", {}) # type: ignore[arg-type] + + +def test_format_results_raises_on_short_payload() -> None: + with pytest.raises(ValueError, match="Unexpected response format"): + _format_results("q", []) # type: ignore[arg-type] + + +def test_format_results_raises_on_wrong_inner_types() -> None: + with pytest.raises(ValueError, match="Unexpected response format"): + _format_results("q", ["query", 2, [""], ["https://test.link"]]) # type: ignore[arg-type] + + +def test_format_results_raises_on_no_results() -> None: + with pytest.raises(ValueError, match="No results found"): + _format_results("q", ("query", [], [""], ["https://test.link"])) + + +def test_search_wikipedia_calls_api_with_correct_params() -> None: + fake_payload = ["query", ["Title"], [""], ["https://test.link"]] + + fake_response = MagicMock() + fake_response.json.return_value = fake_payload + + with patch( + "bot.extensions.wikipedia_extension.requests.get", return_value=fake_response + ) as mock_get: + + result = _search_wikipedia("test query") + + mock_get.assert_called_once_with( + "https://en.wikipedia.org/w/api.php", + params={ + "action": "opensearch", + "format": "json", + "namespace": 0, + "limit": 3, + "search": "test query", + }, + headers={ + "User-Agent": "AdaBot/1.0 (https://github.com/Code-Society-Lab/ada)" + }, + timeout=5, + ) + assert result == fake_payload + + +def test_search_wikipedia_raises_on_http_error() -> None: + fake_response = MagicMock() + fake_response.raise_for_status.side_effect = requests.HTTPError("error") + + with patch( + "bot.extensions.wikipedia_extension.requests.get", return_value=fake_response + ): + with pytest.raises(requests.HTTPError, match="error"): + _search_wikipedia("python") + + +@pytest.mark.parametrize( + "exception", + [ + requests.ConnectionError("connection failed"), + requests.Timeout("request timed out"), + ], +) +def test_search_wikipedia_raises_on_network_error(exception) -> None: + with patch( + "bot.extensions.wikipedia_extension.requests.get", side_effect=exception + ): + with pytest.raises(type(exception), match=str(exception)): + _search_wikipedia("python")