diff --git a/main.py b/main.py index ca1ccb4..7f0f1e2 100755 --- a/main.py +++ b/main.py @@ -1,14 +1,16 @@ #!/usr/bin/env python3 import json import os -import sys -import subprocess import re -from github import Github, Auth, GithubException # type: ignore +import subprocess +import sys +from typing import TextIO # Constants for message titles SUCCESS_TITLE = "# Commit-Check ✔️" FAILURE_TITLE = "# Commit-Check ❌" +COMMIT_MESSAGE_DELIMITER = "\x00" +COMMIT_SECTION_SEPARATOR = "\n---\n" # Environment variables MESSAGE = os.getenv("MESSAGE", "false") @@ -19,9 +21,20 @@ JOB_SUMMARY = os.getenv("JOB_SUMMARY", "false") PR_COMMENTS = os.getenv("PR_COMMENTS", "false") GITHUB_STEP_SUMMARY = os.environ["GITHUB_STEP_SUMMARY"] -GITHUB_TOKEN = os.getenv("GITHUB_TOKEN") -GITHUB_REPOSITORY = os.getenv("GITHUB_REPOSITORY") -GITHUB_REF = os.getenv("GITHUB_REF") + + +def env_flag(name: str, default: str = "false") -> bool: + """Read a GitHub Action boolean-style environment variable.""" + return os.getenv(name, default).lower() == "true" + + +MESSAGE_ENABLED = env_flag("MESSAGE") +BRANCH_ENABLED = env_flag("BRANCH") +AUTHOR_NAME_ENABLED = env_flag("AUTHOR_NAME") +AUTHOR_EMAIL_ENABLED = env_flag("AUTHOR_EMAIL") +DRY_RUN_ENABLED = env_flag("DRY_RUN") +JOB_SUMMARY_ENABLED = env_flag("JOB_SUMMARY") +PR_COMMENTS_ENABLED = env_flag("PR_COMMENTS") def log_env_vars(): @@ -35,35 +48,170 @@ def log_env_vars(): print(f"PR_COMMENTS = {PR_COMMENTS}\n") -def run_commit_check() -> int: - """Runs the commit-check command and logs the result.""" - args = [ - "--message", - "--branch", - "--author-name", - "--author-email", +def is_pr_event() -> bool: + """Return whether the workflow was triggered by a PR-style event.""" + return os.getenv("GITHUB_EVENT_NAME", "") in {"pull_request", "pull_request_target"} + + +def parse_commit_messages(output: str) -> list[str]: + """Split git log output into individual commit messages.""" + return [ + message.strip("\n") + for message in output.split(COMMIT_MESSAGE_DELIMITER) + if message.strip("\n") ] - args = [ - arg - for arg, value in zip( - args, - [ - MESSAGE, - BRANCH, - AUTHOR_NAME, - AUTHOR_EMAIL, - ], + + +def get_messages_from_merge_ref() -> list[str]: + """Read PR commit messages from GitHub's synthetic merge commit.""" + result = subprocess.run( + ["git", "log", "--pretty=format:%B%x00", "--reverse", "HEAD^1..HEAD^2"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + encoding="utf-8", + check=False, + ) + if result.returncode == 0 and result.stdout: + return parse_commit_messages(result.stdout) + return [] + + +def get_messages_from_head_ref(base_ref: str) -> list[str]: + """Read PR commit messages when the workflow checks out the head SHA.""" + result = subprocess.run( + [ + "git", + "log", + "--pretty=format:%B%x00", + "--reverse", + f"origin/{base_ref}..HEAD", + ], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + encoding="utf-8", + check=False, + ) + if result.returncode == 0 and result.stdout: + return parse_commit_messages(result.stdout) + return [] + + +def get_pr_commit_messages() -> list[str]: + """Get all commit messages for the current PR workflow. + + In pull_request-style workflows, actions/checkout checks out a synthetic merge + commit (HEAD = merge of PR branch into base). HEAD^1 is the base branch + tip, HEAD^2 is the PR branch tip. So HEAD^1..HEAD^2 gives all PR commits. + If the workflow explicitly checks out the PR head SHA instead, fall back to + diffing against origin/ when that ref is available locally. + """ + if not is_pr_event(): + return [] + + try: + messages = get_messages_from_merge_ref() + if messages: + return messages + + base_ref = os.getenv("GITHUB_BASE_REF", "") + if base_ref: + return get_messages_from_head_ref(base_ref) + except Exception as e: + print( + f"::warning::Failed to retrieve PR commit messages: {e}", + file=sys.stderr, ) - if value == "true" - ] + return [] + +def run_check_command( + args: list[str], + result_file: TextIO, + input_text: str | None = None, + output_prefix: str | None = None, +) -> int: + """Run commit-check and write both stdout and stderr to the result file.""" command = ["commit-check"] + args print(" ".join(command)) - with open("result.txt", "w") as result_file: - result = subprocess.run( - command, stdout=result_file, stderr=subprocess.PIPE, check=False + result = subprocess.run( + command, + input=input_text, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + check=False, + ) + if result.stdout: + if output_prefix: + result_file.write(output_prefix) + result_file.write(result.stdout.rstrip("\n")) + result_file.write("\n") + return result.returncode + + +def run_pr_message_checks(pr_messages: list[str], result_file: TextIO) -> int: + """Checks each PR commit message individually via commit-check --message. + + Returns 1 if any message fails, 0 if all pass. + """ + has_failure = False + emitted_failure_output = False + total = len(pr_messages) + for index, msg in enumerate(pr_messages, start=1): + command_args = ["--message"] + if emitted_failure_output: + command_args.append("--no-banner") + + if emitted_failure_output: + output_prefix = f"\n--- Commit {index}/{total}:\n" + else: + output_prefix = None + + return_code = run_check_command( + command_args, + result_file, + input_text=msg, + output_prefix=output_prefix, ) - return result.returncode + if return_code != 0: + has_failure = True + emitted_failure_output = True + return 1 if has_failure else 0 + + +def run_other_checks(args: list[str], result_file: TextIO) -> int: + """Runs non-message checks (branch, author) once. Returns 0 if args is empty.""" + if not args: + return 0 + return run_check_command(args, result_file) + + +def build_check_args() -> list[str]: + """Map enabled validation switches to commit-check CLI arguments.""" + flags = [ + ("--message", MESSAGE_ENABLED), + ("--branch", BRANCH_ENABLED), + ("--author-name", AUTHOR_NAME_ENABLED), + ("--author-email", AUTHOR_EMAIL_ENABLED), + ] + return [flag for flag, enabled in flags if enabled] + + +def run_commit_check() -> int: + """Runs the commit-check command and logs the result.""" + args = build_check_args() + with open("result.txt", "w") as result_file: + if MESSAGE_ENABLED: + pr_messages = get_pr_commit_messages() + if pr_messages: + # In PR context: check each commit message individually to avoid + # only validating the synthetic merge commit at HEAD. + message_rc = run_pr_message_checks(pr_messages, result_file) + other_args = [a for a in args if a != "--message"] + other_rc = run_other_checks(other_args, result_file) + return 1 if message_rc or other_rc else 0 + # Non-PR context or message disabled: run all checks at once + return 1 if run_check_command(args, result_file) else 0 def read_result_file() -> str | None: @@ -77,21 +225,22 @@ def read_result_file() -> str | None: return None +def build_result_body(result_text: str | None) -> str: + """Create the human-readable result body used in summaries and PR comments.""" + if result_text is None: + return SUCCESS_TITLE + return f"{FAILURE_TITLE}\n```\n{result_text}\n```" + + def add_job_summary() -> int: """Adds the commit check result to the GitHub job summary.""" - if JOB_SUMMARY == "false": + if not JOB_SUMMARY_ENABLED: return 0 result_text = read_result_file() - summary_content = ( - SUCCESS_TITLE - if result_text is None - else f"{FAILURE_TITLE}\n```\n{result_text}\n```" - ) - with open(GITHUB_STEP_SUMMARY, "a") as summary_file: - summary_file.write(summary_content) + summary_file.write(build_result_body(result_text)) return 0 if result_text is None else 1 @@ -116,7 +265,7 @@ def is_fork_pr() -> bool: def add_pr_comments() -> int: """Posts the commit check result as a comment on the pull request.""" - if PR_COMMENTS == "false": + if not PR_COMMENTS_ENABLED: return 0 # Fork PRs triggered by the pull_request event receive a read-only token; @@ -132,6 +281,8 @@ def add_pr_comments() -> int: return 0 try: + from github import Auth, Github, GithubException # type: ignore + token = os.getenv("GITHUB_TOKEN") repo_name = os.getenv("GITHUB_REPOSITORY") pr_number = os.getenv("GITHUB_REF") @@ -147,15 +298,9 @@ def add_pr_comments() -> int: repo = g.get_repo(repo_name) pull_request = repo.get_issue(int(pr_number)) - # Prepare comment content result_text = read_result_file() - pr_comment_body = ( - SUCCESS_TITLE - if result_text is None - else f"{FAILURE_TITLE}\n```\n{result_text}\n```" - ) + pr_comment_body = build_result_body(result_text) - # Fetch all existing comments on the PR comments = pull_request.get_comments() matching_comments = [ c @@ -215,12 +360,9 @@ def main(): """Main function to run commit-check, add job summary and post PR comments.""" log_env_vars() - # Combine return codes - ret_code = run_commit_check() - ret_code += add_job_summary() - ret_code += add_pr_comments() + ret_code = max(run_commit_check(), add_job_summary(), add_pr_comments()) - if DRY_RUN == "true": + if DRY_RUN_ENABLED: ret_code = 0 result_text = read_result_file() diff --git a/main_test.py b/main_test.py new file mode 100644 index 0000000..3a5d04c --- /dev/null +++ b/main_test.py @@ -0,0 +1,587 @@ +"""Unit tests for main.py.""" + +import io +import json +import os +import unittest +from unittest.mock import MagicMock, patch + +# GITHUB_STEP_SUMMARY is accessed via os.environ[] (not getenv) at import time, +# so we must set it before importing main. +os.environ.setdefault("GITHUB_STEP_SUMMARY", "/tmp/step_summary.txt") + +import main # noqa: E402 + + +class TestEnvFlag(unittest.TestCase): + def test_true_value(self): + with patch.dict(os.environ, {"FEATURE_FLAG": "true"}): + self.assertTrue(main.env_flag("FEATURE_FLAG")) + + def test_false_value(self): + with patch.dict(os.environ, {"FEATURE_FLAG": "false"}): + self.assertFalse(main.env_flag("FEATURE_FLAG")) + + def test_missing_uses_default(self): + with patch.dict(os.environ, {}, clear=True): + self.assertTrue(main.env_flag("FEATURE_FLAG", default="true")) + + +class TestBuildCheckArgs(unittest.TestCase): + def test_all_true(self): + with ( + patch("main.MESSAGE_ENABLED", True), + patch("main.BRANCH_ENABLED", True), + patch("main.AUTHOR_NAME_ENABLED", True), + patch("main.AUTHOR_EMAIL_ENABLED", True), + ): + result = main.build_check_args() + self.assertEqual( + result, ["--message", "--branch", "--author-name", "--author-email"] + ) + + def test_all_false(self): + with ( + patch("main.MESSAGE_ENABLED", False), + patch("main.BRANCH_ENABLED", False), + patch("main.AUTHOR_NAME_ENABLED", False), + patch("main.AUTHOR_EMAIL_ENABLED", False), + ): + result = main.build_check_args() + self.assertEqual(result, []) + + def test_message_and_branch(self): + with ( + patch("main.MESSAGE_ENABLED", True), + patch("main.BRANCH_ENABLED", True), + patch("main.AUTHOR_NAME_ENABLED", False), + patch("main.AUTHOR_EMAIL_ENABLED", False), + ): + result = main.build_check_args() + self.assertEqual(result, ["--message", "--branch"]) + + +class TestParseCommitMessages(unittest.TestCase): + def test_splits_messages_and_trims_surrounding_newlines(self): + result = main.parse_commit_messages("\nfix: first\n\x00\nfeat: second\n\n\x00") + self.assertEqual(result, ["fix: first", "feat: second"]) + + +class TestRunCheckCommand(unittest.TestCase): + def test_with_args_calls_subprocess(self): + mock_result = MagicMock(returncode=0, stdout="") + with patch("main.subprocess.run", return_value=mock_result) as mock_run: + rc = main.run_check_command(["--branch"], io.StringIO()) + self.assertEqual(rc, 0) + self.assertEqual(mock_run.call_args[0][0], ["commit-check", "--branch"]) + + def test_with_input_uses_text_mode(self): + mock_result = MagicMock(returncode=0, stdout="") + with patch("main.subprocess.run", return_value=mock_result) as mock_run: + main.run_check_command(["--message"], io.StringIO(), input_text="fix: demo") + self.assertEqual(mock_run.call_args[1]["input"], "fix: demo") + self.assertTrue(mock_run.call_args[1]["text"]) + + def test_prints_command(self): + mock_result = MagicMock(returncode=0, stdout="") + with patch("main.subprocess.run", return_value=mock_result): + with patch("builtins.print") as mock_print: + main.run_check_command(["--branch"], io.StringIO()) + mock_print.assert_called_once_with("commit-check --branch") + + +class TestRunPrMessageChecks(unittest.TestCase): + def test_single_message_pass(self): + mock_result = MagicMock(returncode=0, stdout="") + result_file = io.StringIO() + with patch("main.subprocess.run", return_value=mock_result) as mock_run: + rc = main.run_pr_message_checks(["fix: something"], result_file) + self.assertEqual(rc, 0) + self.assertEqual(mock_run.call_args[0][0], ["commit-check", "--message"]) + self.assertEqual(mock_run.call_args[1]["input"], "fix: something") + self.assertEqual(result_file.getvalue(), "") + + def test_failed_message_writes_output(self): + mock_result = MagicMock(returncode=1, stdout="Commit rejected.\n") + result_file = io.StringIO() + with patch("main.subprocess.run", return_value=mock_result): + rc = main.run_pr_message_checks(["fix: something"], result_file) + self.assertEqual(rc, 1) + self.assertIn("Commit rejected.", result_file.getvalue()) + + def test_multiple_messages_partial_failure(self): + results = [ + MagicMock(returncode=0, stdout=""), + MagicMock(returncode=1, stdout="Commit rejected.\n"), + MagicMock(returncode=0, stdout=""), + ] + with patch("main.subprocess.run", side_effect=results): + rc = main.run_pr_message_checks(["ok", "bad", "ok"], io.StringIO()) + self.assertEqual(rc, 1) + + def test_empty_list(self): + with patch("main.subprocess.run") as mock_run: + rc = main.run_pr_message_checks([], io.StringIO()) + self.assertEqual(rc, 0) + mock_run.assert_not_called() + + def test_first_failure_keeps_banner_and_later_failures_use_no_banner(self): + results = [ + MagicMock(returncode=0, stdout=""), + MagicMock(returncode=1, stdout="Commit rejected.\n"), + MagicMock(returncode=1, stdout="Type subject_imperative check failed\n"), + ] + with patch("main.subprocess.run", side_effect=results) as mock_run: + main.run_pr_message_checks( + ["ok first", "bad second", "bad third"], io.StringIO() + ) + + self.assertEqual( + mock_run.call_args_list[0][0][0], ["commit-check", "--message"] + ) + self.assertEqual( + mock_run.call_args_list[1][0][0], + ["commit-check", "--message"], + ) + self.assertEqual( + mock_run.call_args_list[2][0][0], + ["commit-check", "--message", "--no-banner"], + ) + + def test_later_failure_prefix_uses_short_separator_without_extra_blank_lines(self): + results = [ + MagicMock(returncode=0, stdout=""), + MagicMock(returncode=1, stdout="Commit rejected.\n"), + MagicMock( + returncode=1, + stdout=( + "Type subject_imperative check failed ==> bad third\n" + "Commit message should use imperative mood\n" + "Suggest: Use imperative mood\n\n" + ), + ), + ] + result_file = io.StringIO() + with patch("main.subprocess.run", side_effect=results): + main.run_pr_message_checks( + ["ok first", "bad second", "bad third"], result_file + ) + + output = result_file.getvalue() + self.assertIn("Commit rejected.\n", output) + self.assertIn( + "\n--- Commit 3/3:\nType subject_imperative check failed ==> bad third\n", + output, + ) + self.assertNotIn( + "------------------------------------------------------------------------", + output, + ) + self.assertNotIn("\n\n\n", output) + + +class TestRunOtherChecks(unittest.TestCase): + def test_empty_args_returns_zero(self): + with patch("main.subprocess.run") as mock_run: + rc = main.run_other_checks([], io.StringIO()) + self.assertEqual(rc, 0) + mock_run.assert_not_called() + + def test_with_args_returns_returncode(self): + mock_result = MagicMock(returncode=1, stdout="branch check failed\n") + with patch("main.subprocess.run", return_value=mock_result): + rc = main.run_other_checks(["--branch", "--author-name"], io.StringIO()) + self.assertEqual(rc, 1) + + +class TestGetPrCommitMessages(unittest.TestCase): + def test_non_pr_event_returns_empty(self): + with patch.dict(os.environ, {"GITHUB_EVENT_NAME": "push"}): + result = main.get_pr_commit_messages() + self.assertEqual(result, []) + + def test_merge_ref_is_preferred(self): + with ( + patch.dict(os.environ, {"GITHUB_EVENT_NAME": "pull_request"}), + patch( + "main.get_messages_from_merge_ref", + return_value=["fix: first", "feat: second"], + ) as mock_merge, + patch("main.get_messages_from_head_ref") as mock_head, + ): + result = main.get_pr_commit_messages() + self.assertEqual(result, ["fix: first", "feat: second"]) + mock_merge.assert_called_once() + mock_head.assert_not_called() + + def test_pull_request_target_is_supported(self): + with ( + patch.dict(os.environ, {"GITHUB_EVENT_NAME": "pull_request_target"}), + patch("main.get_messages_from_merge_ref", return_value=["fix: first"]), + ): + result = main.get_pr_commit_messages() + self.assertEqual(result, ["fix: first"]) + + def test_falls_back_to_base_ref_when_merge_ref_is_unavailable(self): + with ( + patch.dict( + os.environ, + { + "GITHUB_EVENT_NAME": "pull_request", + "GITHUB_BASE_REF": "main", + }, + ), + patch("main.get_messages_from_merge_ref", return_value=[]), + patch( + "main.get_messages_from_head_ref", + return_value=["fix: first", "feat: second"], + ) as mock_head, + ): + result = main.get_pr_commit_messages() + self.assertEqual(result, ["fix: first", "feat: second"]) + mock_head.assert_called_once_with("main") + + def test_exception_returns_empty(self): + with ( + patch.dict(os.environ, {"GITHUB_EVENT_NAME": "pull_request"}), + patch( + "main.get_messages_from_merge_ref", side_effect=Exception("git failed") + ), + ): + result = main.get_pr_commit_messages() + self.assertEqual(result, []) + + +class TestGitMessageReaders(unittest.TestCase): + def test_get_messages_from_merge_ref(self): + mock_result = MagicMock( + returncode=0, stdout="fix: first\n\x00feat: second\n\x00" + ) + with patch("main.subprocess.run", return_value=mock_result) as mock_run: + result = main.get_messages_from_merge_ref() + self.assertEqual(result, ["fix: first", "feat: second"]) + self.assertEqual( + mock_run.call_args[0][0], + ["git", "log", "--pretty=format:%B%x00", "--reverse", "HEAD^1..HEAD^2"], + ) + + def test_get_messages_from_head_ref(self): + mock_result = MagicMock(returncode=0, stdout="fix: first\n\x00") + with patch("main.subprocess.run", return_value=mock_result) as mock_run: + result = main.get_messages_from_head_ref("main") + self.assertEqual(result, ["fix: first"]) + self.assertEqual( + mock_run.call_args[0][0], + [ + "git", + "log", + "--pretty=format:%B%x00", + "--reverse", + "origin/main..HEAD", + ], + ) + + +class TestRunCommitCheck(unittest.TestCase): + def setUp(self): + self._orig_dir = os.getcwd() + import tempfile + + self._tmpdir = tempfile.mkdtemp() + os.chdir(self._tmpdir) + + def tearDown(self): + os.chdir(self._orig_dir) + + def test_pr_path_calls_pr_message_checks(self): + with ( + patch("main.MESSAGE_ENABLED", True), + patch("main.BRANCH_ENABLED", False), + patch("main.AUTHOR_NAME_ENABLED", False), + patch("main.AUTHOR_EMAIL_ENABLED", False), + patch("main.get_pr_commit_messages", return_value=["fix: something"]), + patch("main.run_pr_message_checks", return_value=0) as mock_pr, + patch("main.run_other_checks", return_value=0), + patch("main.run_check_command") as mock_command, + ): + rc = main.run_commit_check() + self.assertEqual(rc, 0) + mock_pr.assert_called_once() + mock_command.assert_not_called() + + def test_pr_path_returns_nonzero_when_any_check_fails(self): + with ( + patch("main.MESSAGE_ENABLED", True), + patch("main.BRANCH_ENABLED", True), + patch("main.AUTHOR_NAME_ENABLED", False), + patch("main.AUTHOR_EMAIL_ENABLED", False), + patch("main.get_pr_commit_messages", return_value=["bad msg"]), + patch("main.run_pr_message_checks", return_value=1), + patch("main.run_other_checks", return_value=1), + ): + rc = main.run_commit_check() + self.assertEqual(rc, 1) + + def test_non_pr_path_uses_direct_command(self): + with ( + patch("main.MESSAGE_ENABLED", True), + patch("main.BRANCH_ENABLED", False), + patch("main.AUTHOR_NAME_ENABLED", False), + patch("main.AUTHOR_EMAIL_ENABLED", False), + patch("main.get_pr_commit_messages", return_value=[]), + patch("main.run_pr_message_checks") as mock_pr, + patch("main.run_check_command", return_value=0) as mock_command, + ): + rc = main.run_commit_check() + self.assertEqual(rc, 0) + mock_pr.assert_not_called() + mock_command.assert_called_once() + + def test_message_disabled_uses_direct_command(self): + with ( + patch("main.MESSAGE_ENABLED", False), + patch("main.BRANCH_ENABLED", True), + patch("main.AUTHOR_NAME_ENABLED", False), + patch("main.AUTHOR_EMAIL_ENABLED", False), + patch("main.run_pr_message_checks") as mock_pr, + patch("main.run_check_command", return_value=0) as mock_command, + ): + rc = main.run_commit_check() + self.assertEqual(rc, 0) + mock_pr.assert_not_called() + mock_command.assert_called_once() + + def test_result_txt_is_created(self): + with ( + patch("main.MESSAGE_ENABLED", False), + patch("main.BRANCH_ENABLED", False), + patch("main.AUTHOR_NAME_ENABLED", False), + patch("main.AUTHOR_EMAIL_ENABLED", False), + patch("main.run_check_command", return_value=0), + ): + main.run_commit_check() + self.assertTrue(os.path.exists(os.path.join(self._tmpdir, "result.txt"))) + + def test_other_args_excludes_message(self): + captured_args = [] + + def fake_other_checks(args, result_file): + captured_args.extend(args) + return 0 + + with ( + patch("main.MESSAGE_ENABLED", True), + patch("main.BRANCH_ENABLED", True), + patch("main.AUTHOR_NAME_ENABLED", False), + patch("main.AUTHOR_EMAIL_ENABLED", False), + patch("main.get_pr_commit_messages", return_value=["fix: x"]), + patch("main.run_pr_message_checks", return_value=0), + patch("main.run_other_checks", side_effect=fake_other_checks), + ): + main.run_commit_check() + self.assertNotIn("--message", captured_args) + self.assertIn("--branch", captured_args) + + +class TestReadResultFile(unittest.TestCase): + def setUp(self): + import tempfile + + self._orig_dir = os.getcwd() + self._tmpdir = tempfile.mkdtemp() + os.chdir(self._tmpdir) + + def tearDown(self): + os.chdir(self._orig_dir) + + def _write_result(self, content: str): + with open("result.txt", "w", encoding="utf-8") as file_obj: + file_obj.write(content) + + def test_empty_file_returns_none(self): + self._write_result("") + self.assertIsNone(main.read_result_file()) + + def test_file_with_content(self): + self._write_result("some output\n") + self.assertEqual(main.read_result_file(), "some output") + + def test_ansi_codes_are_stripped(self): + self._write_result("\x1b[31mError\x1b[0m: bad commit") + self.assertEqual(main.read_result_file(), "Error: bad commit") + + +class TestBuildResultBody(unittest.TestCase): + def test_success_body(self): + self.assertEqual(main.build_result_body(None), main.SUCCESS_TITLE) + + def test_failure_body(self): + result = main.build_result_body("bad commit") + self.assertIn(main.FAILURE_TITLE, result) + self.assertIn("bad commit", result) + + +class TestAddJobSummary(unittest.TestCase): + def setUp(self): + import tempfile + + self._orig_dir = os.getcwd() + self._tmpdir = tempfile.mkdtemp() + os.chdir(self._tmpdir) + with open("result.txt", "w", encoding="utf-8"): + pass + + def tearDown(self): + os.chdir(self._orig_dir) + + def test_false_skips(self): + with patch("main.JOB_SUMMARY_ENABLED", False): + rc = main.add_job_summary() + self.assertEqual(rc, 0) + + def test_success_writes_success_title(self): + summary_path = os.path.join(self._tmpdir, "summary.txt") + with ( + patch("main.JOB_SUMMARY_ENABLED", True), + patch("main.GITHUB_STEP_SUMMARY", summary_path), + patch("main.read_result_file", return_value=None), + ): + rc = main.add_job_summary() + self.assertEqual(rc, 0) + with open(summary_path, encoding="utf-8") as file_obj: + content = file_obj.read() + self.assertIn(main.SUCCESS_TITLE, content) + + def test_failure_writes_failure_title(self): + summary_path = os.path.join(self._tmpdir, "summary.txt") + with ( + patch("main.JOB_SUMMARY_ENABLED", True), + patch("main.GITHUB_STEP_SUMMARY", summary_path), + patch("main.read_result_file", return_value="bad commit message"), + ): + rc = main.add_job_summary() + self.assertEqual(rc, 1) + with open(summary_path, encoding="utf-8") as file_obj: + content = file_obj.read() + self.assertIn(main.FAILURE_TITLE, content) + self.assertIn("bad commit message", content) + + +class TestIsForkPr(unittest.TestCase): + def test_no_event_path(self): + with patch.dict(os.environ, {}, clear=True): + os.environ.pop("GITHUB_EVENT_PATH", None) + result = main.is_fork_pr() + self.assertFalse(result) + + def test_same_repo_not_fork(self): + import tempfile + + event = { + "pull_request": { + "head": {"repo": {"full_name": "owner/repo"}}, + "base": {"repo": {"full_name": "owner/repo"}}, + } + } + with tempfile.NamedTemporaryFile( + mode="w", suffix=".json", delete=False + ) as file_obj: + json.dump(event, file_obj) + event_path = file_obj.name + with patch.dict(os.environ, {"GITHUB_EVENT_PATH": event_path}): + result = main.is_fork_pr() + self.assertFalse(result) + os.unlink(event_path) + + def test_different_repo_is_fork(self): + import tempfile + + event = { + "pull_request": { + "head": {"repo": {"full_name": "fork-owner/repo"}}, + "base": {"repo": {"full_name": "owner/repo"}}, + } + } + with tempfile.NamedTemporaryFile( + mode="w", suffix=".json", delete=False + ) as file_obj: + json.dump(event, file_obj) + event_path = file_obj.name + with patch.dict(os.environ, {"GITHUB_EVENT_PATH": event_path}): + result = main.is_fork_pr() + self.assertTrue(result) + os.unlink(event_path) + + +class TestLogErrorAndExit(unittest.TestCase): + def test_exits_with_specified_code(self): + with self.assertRaises(SystemExit) as ctx: + main.log_error_and_exit("# Title", None, 0) + self.assertEqual(ctx.exception.code, 0) + + def test_with_result_text_prints_error(self): + with ( + patch("builtins.print") as mock_print, + self.assertRaises(SystemExit), + ): + main.log_error_and_exit("# Failure", "bad commit", 1) + printed = mock_print.call_args[0][0] + self.assertIn("::error::", printed) + self.assertIn("bad commit", printed) + + +class TestMain(unittest.TestCase): + def setUp(self): + import tempfile + + self._orig_dir = os.getcwd() + self._tmpdir = tempfile.mkdtemp() + os.chdir(self._tmpdir) + with open("result.txt", "w", encoding="utf-8"): + pass + + def tearDown(self): + os.chdir(self._orig_dir) + + def test_success_path(self): + with ( + patch("main.log_env_vars"), + patch("main.run_commit_check", return_value=0), + patch("main.add_job_summary", return_value=0), + patch("main.add_pr_comments", return_value=0), + patch("main.DRY_RUN_ENABLED", False), + patch("main.read_result_file", return_value=None), + self.assertRaises(SystemExit) as ctx, + ): + main.main() + self.assertEqual(ctx.exception.code, 0) + + def test_multiple_failures_still_exit_with_one(self): + with ( + patch("main.log_env_vars"), + patch("main.run_commit_check", return_value=1), + patch("main.add_job_summary", return_value=1), + patch("main.add_pr_comments", return_value=1), + patch("main.DRY_RUN_ENABLED", False), + patch("main.read_result_file", return_value="bad msg"), + self.assertRaises(SystemExit) as ctx, + ): + main.main() + self.assertEqual(ctx.exception.code, 1) + + def test_dry_run_forces_zero(self): + with ( + patch("main.log_env_vars"), + patch("main.run_commit_check", return_value=1), + patch("main.add_job_summary", return_value=1), + patch("main.add_pr_comments", return_value=0), + patch("main.DRY_RUN_ENABLED", True), + patch("main.read_result_file", return_value=None), + self.assertRaises(SystemExit) as ctx, + ): + main.main() + self.assertEqual(ctx.exception.code, 0) + + +if __name__ == "__main__": + unittest.main()