From 0910b3d0f7610581e5c40f8b210d2b57e98840c1 Mon Sep 17 00:00:00 2001 From: shenxianpeng Date: Mon, 16 Mar 2026 20:35:21 +0200 Subject: [PATCH 01/20] feat: implement PR commit message retrieval and validation in commit-check --- main.py | 58 +++++++++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 56 insertions(+), 2 deletions(-) diff --git a/main.py b/main.py index 7491710..71ab225 100755 --- a/main.py +++ b/main.py @@ -36,6 +36,30 @@ def log_env_vars(): print(f"PR_COMMENTS = {PR_COMMENTS}\n") +def get_pr_commit_messages() -> list[str]: + """Get all commit messages for the current PR (pull_request event only). + + In a pull_request event, actions/checkout checks out a synthetic merge + commit (HEAD = merge of PR branch into base). HEAD^1 is the base branch + tip, HEAD^2 is the PR branch tip. So HEAD^1..HEAD^2 gives all PR commits. + """ + if os.getenv("GITHUB_EVENT_NAME", "") != "pull_request": + return [] + try: + result = subprocess.run( + ["git", "log", "--pretty=format:%B%x00", "HEAD^1..HEAD^2"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + encoding="utf-8", + check=False, + ) + if result.returncode == 0 and result.stdout: + return [m.strip() for m in result.stdout.split("\x00") if m.strip()] + except Exception: + pass + return [] + + def run_commit_check() -> int: """Runs the commit-check command and logs the result.""" args = [ @@ -58,9 +82,39 @@ def run_commit_check() -> int: if value == "true" ] - command = ["commit-check"] + args - print(" ".join(command)) + total_rc = 0 with open("result.txt", "w") as result_file: + if MESSAGE == "true": + pr_messages = get_pr_commit_messages() + if pr_messages: + # In PR context: check each commit message individually to avoid + # only validating the synthetic merge commit at HEAD. + for msg in pr_messages: + result = subprocess.run( + ["commit-check", "--message"], + input=msg, + stdout=result_file, + stderr=subprocess.PIPE, + text=True, + check=False, + ) + total_rc += result.returncode + + # Run non-message checks (branch, author) once + other_args = [a for a in args if a != "--message"] + if other_args: + command = ["commit-check"] + other_args + print(" ".join(command)) + result = subprocess.run( + command, stdout=result_file, stderr=subprocess.PIPE, check=False + ) + total_rc += result.returncode + + return total_rc + + # Non-PR context or message disabled: run all checks at once + command = ["commit-check"] + args + print(" ".join(command)) result = subprocess.run( command, stdout=result_file, stderr=subprocess.PIPE, check=False ) From 8a8c1f6422eb04506a46f04403b23f8016226b5a Mon Sep 17 00:00:00 2001 From: shenxianpeng Date: Mon, 16 Mar 2026 20:41:57 +0200 Subject: [PATCH 02/20] fix: correct variable names in subprocess calls for clarity --- main.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/main.py b/main.py index 71ab225..846518c 100755 --- a/main.py +++ b/main.py @@ -105,20 +105,20 @@ def run_commit_check() -> int: if other_args: command = ["commit-check"] + other_args print(" ".join(command)) - result = subprocess.run( + other_result = subprocess.run( command, stdout=result_file, stderr=subprocess.PIPE, check=False ) - total_rc += result.returncode + total_rc += other_result.returncode return total_rc # Non-PR context or message disabled: run all checks at once command = ["commit-check"] + args print(" ".join(command)) - result = subprocess.run( + default_result = subprocess.run( command, stdout=result_file, stderr=subprocess.PIPE, check=False ) - return result.returncode + return default_result.returncode def read_result_file() -> str | None: From 235d3975180c298576cd1927b77d44c332554765 Mon Sep 17 00:00:00 2001 From: Copilot <198982749+Copilot@users.noreply.github.com> Date: Tue, 17 Mar 2026 01:26:10 +0200 Subject: [PATCH 03/20] fix: get original commit message content in get_pr_commit_messages() (#190) --- main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/main.py b/main.py index 846518c..aa4bd88 100755 --- a/main.py +++ b/main.py @@ -54,7 +54,7 @@ def get_pr_commit_messages() -> list[str]: check=False, ) if result.returncode == 0 and result.stdout: - return [m.strip() for m in result.stdout.split("\x00") if m.strip()] + return [m.rstrip("\n") for m in result.stdout.split("\x00") if m.rstrip("\n")] except Exception: pass return [] From 44e900c44acb56f8fdd6eee1039f051d4b4419b2 Mon Sep 17 00:00:00 2001 From: Xianpeng Shen Date: Tue, 17 Mar 2026 01:38:30 +0200 Subject: [PATCH 04/20] chore: Update commit-check version to 2.4.3 --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index c1cf58e..2473c51 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ # Install commit-check CLI # For details please see: https://github.com/commit-check/commit-check -commit-check==2.4.2 +commit-check==2.4.3 # Interact with the GitHub API. PyGithub==2.8.1 From cb56efe463648caabf3a64bff3599a9f8ae7d1e4 Mon Sep 17 00:00:00 2001 From: Xianpeng Shen Date: Tue, 17 Mar 2026 02:06:08 +0200 Subject: [PATCH 05/20] feat: Enable autofix for pull requests in pre-commit config --- .pre-commit-config.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 571d5ee..4eee43c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,5 +1,6 @@ # https://pre-commit.com/ ci: + autofix_prs: true autofix_commit_msg: 'ci: auto fixes from pre-commit.com hooks' autoupdate_commit_msg: 'ci: pre-commit autoupdate' From 85f23c7b24cb0b7adf4a9e8872b0e3ccaa66de0f Mon Sep 17 00:00:00 2001 From: shenxianpeng Date: Tue, 17 Mar 2026 02:23:46 +0200 Subject: [PATCH 06/20] feat: refactor commit-check logic and add unit tests for new functionality --- main.py | 102 +++++----- main_test.py | 538 +++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 593 insertions(+), 47 deletions(-) create mode 100644 main_test.py diff --git a/main.py b/main.py index aa4bd88..bbcd7f8 100755 --- a/main.py +++ b/main.py @@ -60,28 +60,59 @@ def get_pr_commit_messages() -> list[str]: return [] -def run_commit_check() -> int: - """Runs the commit-check command and logs the result.""" - args = [ - "--message", - "--branch", - "--author-name", - "--author-email", - ] - args = [ - arg - for arg, value in zip( - args, - [ - MESSAGE, - BRANCH, - AUTHOR_NAME, - AUTHOR_EMAIL, - ], +def build_check_args( + message: str, branch: str, author_name: str, author_email: str +) -> list[str]: + """Maps 'true'/'false' flag values to CLI argument list.""" + flags = ["--message", "--branch", "--author-name", "--author-email"] + values = [message, branch, author_name, author_email] + return [flag for flag, value in zip(flags, values) if value == "true"] + + +def run_pr_message_checks(pr_messages: list[str], result_file) -> int: # type: ignore[type-arg] + """Checks each PR commit message individually via commit-check --message. + + Returns cumulative returncode across all messages. + """ + total_rc = 0 + for msg in pr_messages: + result = subprocess.run( + ["commit-check", "--message"], + input=msg, + stdout=result_file, + stderr=subprocess.PIPE, + text=True, + check=False, ) - if value == "true" - ] + total_rc += result.returncode + return total_rc + +def run_other_checks(args: list[str], result_file) -> int: # type: ignore[type-arg] + """Runs non-message checks (branch, author) once. Returns 0 if args is empty.""" + if not args: + return 0 + command = ["commit-check"] + args + print(" ".join(command)) + result = subprocess.run( + command, stdout=result_file, stderr=subprocess.PIPE, check=False + ) + return result.returncode + + +def run_default_checks(args: list[str], result_file) -> int: # type: ignore[type-arg] + """Runs all checks at once (non-PR context or message disabled).""" + command = ["commit-check"] + args + print(" ".join(command)) + result = subprocess.run( + command, stdout=result_file, stderr=subprocess.PIPE, check=False + ) + return result.returncode + + +def run_commit_check() -> int: + """Runs the commit-check command and logs the result.""" + args = build_check_args(MESSAGE, BRANCH, AUTHOR_NAME, AUTHOR_EMAIL) total_rc = 0 with open("result.txt", "w") as result_file: if MESSAGE == "true": @@ -89,36 +120,13 @@ def run_commit_check() -> int: if pr_messages: # In PR context: check each commit message individually to avoid # only validating the synthetic merge commit at HEAD. - for msg in pr_messages: - result = subprocess.run( - ["commit-check", "--message"], - input=msg, - stdout=result_file, - stderr=subprocess.PIPE, - text=True, - check=False, - ) - total_rc += result.returncode - - # Run non-message checks (branch, author) once + total_rc += run_pr_message_checks(pr_messages, result_file) other_args = [a for a in args if a != "--message"] - if other_args: - command = ["commit-check"] + other_args - print(" ".join(command)) - other_result = subprocess.run( - command, stdout=result_file, stderr=subprocess.PIPE, check=False - ) - total_rc += other_result.returncode - + total_rc += run_other_checks(other_args, result_file) return total_rc - # Non-PR context or message disabled: run all checks at once - command = ["commit-check"] + args - print(" ".join(command)) - default_result = subprocess.run( - command, stdout=result_file, stderr=subprocess.PIPE, check=False - ) - return default_result.returncode + total_rc += run_default_checks(args, result_file) + return total_rc def read_result_file() -> str | None: diff --git a/main_test.py b/main_test.py new file mode 100644 index 0000000..85a108c --- /dev/null +++ b/main_test.py @@ -0,0 +1,538 @@ +"""Unit tests for main.py""" +import io +import json +import os +import sys +import unittest +from unittest.mock import MagicMock, patch + +# GITHUB_STEP_SUMMARY is accessed via os.environ[] (not getenv) at import time, +# so we must set it before importing main. +os.environ.setdefault("GITHUB_STEP_SUMMARY", "/tmp/step_summary.txt") + +import main # noqa: E402 + + +class TestBuildCheckArgs(unittest.TestCase): + def test_all_true(self): + result = main.build_check_args("true", "true", "true", "true") + self.assertEqual(result, ["--message", "--branch", "--author-name", "--author-email"]) + + def test_all_false(self): + result = main.build_check_args("false", "false", "false", "false") + self.assertEqual(result, []) + + def test_message_only(self): + result = main.build_check_args("true", "false", "false", "false") + self.assertEqual(result, ["--message"]) + + def test_branch_only(self): + result = main.build_check_args("false", "true", "false", "false") + self.assertEqual(result, ["--branch"]) + + def test_author_name_and_email(self): + result = main.build_check_args("false", "false", "true", "true") + self.assertEqual(result, ["--author-name", "--author-email"]) + + def test_message_and_branch(self): + result = main.build_check_args("true", "true", "false", "false") + self.assertEqual(result, ["--message", "--branch"]) + + +class TestRunPrMessageChecks(unittest.TestCase): + def _make_file(self): + return io.StringIO() + + def test_single_message_pass(self): + mock_result = MagicMock() + mock_result.returncode = 0 + with patch("main.subprocess.run", return_value=mock_result) as mock_run: + rc = main.run_pr_message_checks(["fix: something"], self._make_file()) + self.assertEqual(rc, 0) + mock_run.assert_called_once() + call_kwargs = mock_run.call_args + self.assertIn("--message", call_kwargs[0][0]) + self.assertEqual(call_kwargs[1]["input"], "fix: something") + + def test_single_message_fail(self): + mock_result = MagicMock() + mock_result.returncode = 1 + with patch("main.subprocess.run", return_value=mock_result): + rc = main.run_pr_message_checks(["bad commit"], self._make_file()) + self.assertEqual(rc, 1) + + def test_multiple_messages_partial_failure(self): + results = [MagicMock(returncode=0), MagicMock(returncode=1), MagicMock(returncode=0)] + with patch("main.subprocess.run", side_effect=results): + rc = main.run_pr_message_checks(["ok", "bad", "ok"], self._make_file()) + self.assertEqual(rc, 1) + + def test_multiple_messages_all_fail(self): + results = [MagicMock(returncode=1), MagicMock(returncode=1)] + with patch("main.subprocess.run", side_effect=results): + rc = main.run_pr_message_checks(["bad1", "bad2"], self._make_file()) + self.assertEqual(rc, 2) + + def test_empty_list(self): + with patch("main.subprocess.run") as mock_run: + rc = main.run_pr_message_checks([], self._make_file()) + self.assertEqual(rc, 0) + mock_run.assert_not_called() + + +class TestRunOtherChecks(unittest.TestCase): + def test_empty_args_returns_zero(self): + with patch("main.subprocess.run") as mock_run: + rc = main.run_other_checks([], io.StringIO()) + self.assertEqual(rc, 0) + mock_run.assert_not_called() + + def test_with_args_calls_subprocess(self): + mock_result = MagicMock(returncode=0) + with patch("main.subprocess.run", return_value=mock_result) as mock_run: + rc = main.run_other_checks(["--branch"], io.StringIO()) + self.assertEqual(rc, 0) + called_cmd = mock_run.call_args[0][0] + self.assertEqual(called_cmd, ["commit-check", "--branch"]) + + def test_with_args_returns_returncode(self): + mock_result = MagicMock(returncode=1) + with patch("main.subprocess.run", return_value=mock_result): + rc = main.run_other_checks(["--branch", "--author-name"], io.StringIO()) + self.assertEqual(rc, 1) + + def test_prints_command(self): + mock_result = MagicMock(returncode=0) + with patch("main.subprocess.run", return_value=mock_result): + with patch("builtins.print") as mock_print: + main.run_other_checks(["--branch"], io.StringIO()) + mock_print.assert_called_once_with("commit-check --branch") + + +class TestRunDefaultChecks(unittest.TestCase): + def test_rc_zero(self): + mock_result = MagicMock(returncode=0) + with patch("main.subprocess.run", return_value=mock_result): + rc = main.run_default_checks(["--message", "--branch"], io.StringIO()) + self.assertEqual(rc, 0) + + def test_rc_one(self): + mock_result = MagicMock(returncode=1) + with patch("main.subprocess.run", return_value=mock_result): + rc = main.run_default_checks(["--message"], io.StringIO()) + self.assertEqual(rc, 1) + + def test_command_contains_all_args(self): + mock_result = MagicMock(returncode=0) + with patch("main.subprocess.run", return_value=mock_result) as mock_run: + main.run_default_checks(["--message", "--branch", "--author-name"], io.StringIO()) + called_cmd = mock_run.call_args[0][0] + self.assertEqual( + called_cmd, + ["commit-check", "--message", "--branch", "--author-name"], + ) + + def test_prints_command(self): + mock_result = MagicMock(returncode=0) + with patch("main.subprocess.run", return_value=mock_result): + with patch("builtins.print") as mock_print: + main.run_default_checks(["--branch"], io.StringIO()) + mock_print.assert_called_once_with("commit-check --branch") + + def test_empty_args(self): + mock_result = MagicMock(returncode=0) + with patch("main.subprocess.run", return_value=mock_result) as mock_run: + main.run_default_checks([], io.StringIO()) + called_cmd = mock_run.call_args[0][0] + self.assertEqual(called_cmd, ["commit-check"]) + + +class TestRunCommitCheck(unittest.TestCase): + def setUp(self): + # Ensure result.txt is written to a temp location + self._orig_dir = os.getcwd() + import tempfile + self._tmpdir = tempfile.mkdtemp() + os.chdir(self._tmpdir) + + def tearDown(self): + os.chdir(self._orig_dir) + + def test_pr_path_calls_pr_message_checks(self): + with ( + patch("main.MESSAGE", "true"), + patch("main.BRANCH", "false"), + patch("main.AUTHOR_NAME", "false"), + patch("main.AUTHOR_EMAIL", "false"), + patch("main.get_pr_commit_messages", return_value=["fix: something"]), + patch("main.run_pr_message_checks", return_value=0) as mock_pr, + patch("main.run_other_checks", return_value=0), + patch("main.run_default_checks") as mock_default, + ): + rc = main.run_commit_check() + mock_pr.assert_called_once() + mock_default.assert_not_called() + self.assertEqual(rc, 0) + + def test_pr_path_rc_accumulation(self): + with ( + patch("main.MESSAGE", "true"), + patch("main.BRANCH", "true"), + patch("main.AUTHOR_NAME", "false"), + patch("main.AUTHOR_EMAIL", "false"), + patch("main.get_pr_commit_messages", return_value=["bad msg"]), + patch("main.run_pr_message_checks", return_value=2), + patch("main.run_other_checks", return_value=1), + ): + rc = main.run_commit_check() + self.assertEqual(rc, 3) + + def test_non_pr_path_uses_default_checks(self): + with ( + patch("main.MESSAGE", "true"), + patch("main.BRANCH", "false"), + patch("main.AUTHOR_NAME", "false"), + patch("main.AUTHOR_EMAIL", "false"), + patch("main.get_pr_commit_messages", return_value=[]), + patch("main.run_pr_message_checks") as mock_pr, + patch("main.run_default_checks", return_value=0) as mock_default, + ): + rc = main.run_commit_check() + mock_pr.assert_not_called() + mock_default.assert_called_once() + self.assertEqual(rc, 0) + + def test_message_false_uses_default_checks(self): + with ( + patch("main.MESSAGE", "false"), + patch("main.BRANCH", "true"), + patch("main.AUTHOR_NAME", "false"), + patch("main.AUTHOR_EMAIL", "false"), + patch("main.run_pr_message_checks") as mock_pr, + patch("main.run_default_checks", return_value=0) as mock_default, + ): + rc = main.run_commit_check() + mock_pr.assert_not_called() + mock_default.assert_called_once() + self.assertEqual(rc, 0) + + def test_result_txt_is_created(self): + with ( + patch("main.MESSAGE", "false"), + patch("main.BRANCH", "false"), + patch("main.AUTHOR_NAME", "false"), + patch("main.AUTHOR_EMAIL", "false"), + patch("main.run_default_checks", return_value=0), + ): + main.run_commit_check() + self.assertTrue(os.path.exists(os.path.join(self._tmpdir, "result.txt"))) + + def test_other_args_excludes_message(self): + """When in PR path, run_other_checks must not receive --message.""" + captured_args = [] + + def fake_other_checks(args, result_file): + captured_args.extend(args) + return 0 + + with ( + patch("main.MESSAGE", "true"), + patch("main.BRANCH", "true"), + patch("main.AUTHOR_NAME", "false"), + patch("main.AUTHOR_EMAIL", "false"), + patch("main.get_pr_commit_messages", return_value=["fix: x"]), + patch("main.run_pr_message_checks", return_value=0), + patch("main.run_other_checks", side_effect=fake_other_checks), + ): + main.run_commit_check() + self.assertNotIn("--message", captured_args) + self.assertIn("--branch", captured_args) + + +class TestGetPrCommitMessages(unittest.TestCase): + def test_non_pr_event_returns_empty(self): + with patch.dict(os.environ, {"GITHUB_EVENT_NAME": "push"}): + result = main.get_pr_commit_messages() + self.assertEqual(result, []) + + def test_pr_event_with_commits(self): + mock_result = MagicMock() + mock_result.returncode = 0 + mock_result.stdout = "fix: first\n\x00feat: second\n\x00" + with ( + patch.dict(os.environ, {"GITHUB_EVENT_NAME": "pull_request"}), + patch("main.subprocess.run", return_value=mock_result), + ): + result = main.get_pr_commit_messages() + self.assertEqual(result, ["fix: first", "feat: second"]) + + def test_pr_event_empty_output(self): + mock_result = MagicMock() + mock_result.returncode = 0 + mock_result.stdout = "" + with ( + patch.dict(os.environ, {"GITHUB_EVENT_NAME": "pull_request"}), + patch("main.subprocess.run", return_value=mock_result), + ): + result = main.get_pr_commit_messages() + self.assertEqual(result, []) + + def test_git_failure_returns_empty(self): + mock_result = MagicMock() + mock_result.returncode = 1 + mock_result.stdout = "" + with ( + patch.dict(os.environ, {"GITHUB_EVENT_NAME": "pull_request"}), + patch("main.subprocess.run", return_value=mock_result), + ): + result = main.get_pr_commit_messages() + self.assertEqual(result, []) + + def test_exception_returns_empty(self): + with ( + patch.dict(os.environ, {"GITHUB_EVENT_NAME": "pull_request"}), + patch("main.subprocess.run", side_effect=Exception("git not found")), + ): + result = main.get_pr_commit_messages() + self.assertEqual(result, []) + + +class TestReadResultFile(unittest.TestCase): + def setUp(self): + import tempfile + self._orig_dir = os.getcwd() + self._tmpdir = tempfile.mkdtemp() + os.chdir(self._tmpdir) + + def tearDown(self): + os.chdir(self._orig_dir) + + def _write_result(self, content: str): + with open("result.txt", "w") as f: + f.write(content) + + def test_empty_file_returns_none(self): + self._write_result("") + result = main.read_result_file() + self.assertIsNone(result) + + def test_file_with_content(self): + self._write_result("some output\n") + result = main.read_result_file() + self.assertEqual(result, "some output") + + def test_ansi_codes_are_stripped(self): + self._write_result("\x1B[31mError\x1B[0m: bad commit") + result = main.read_result_file() + self.assertEqual(result, "Error: bad commit") + + def test_trailing_whitespace_stripped(self): + self._write_result("output\n\n") + result = main.read_result_file() + self.assertEqual(result, "output") + + +class TestAddJobSummary(unittest.TestCase): + def setUp(self): + import tempfile + self._orig_dir = os.getcwd() + self._tmpdir = tempfile.mkdtemp() + os.chdir(self._tmpdir) + # Create an empty result.txt + open("result.txt", "w").close() + + def tearDown(self): + os.chdir(self._orig_dir) + + def test_false_skips(self): + with patch("main.JOB_SUMMARY", "false"): + rc = main.add_job_summary() + self.assertEqual(rc, 0) + + def test_success_writes_success_title(self): + summary_path = os.path.join(self._tmpdir, "summary.txt") + with ( + patch("main.JOB_SUMMARY", "true"), + patch("main.GITHUB_STEP_SUMMARY", summary_path), + patch("main.read_result_file", return_value=None), + ): + rc = main.add_job_summary() + self.assertEqual(rc, 0) + with open(summary_path) as f: + content = f.read() + self.assertIn(main.SUCCESS_TITLE, content) + + def test_failure_writes_failure_title(self): + summary_path = os.path.join(self._tmpdir, "summary.txt") + with ( + patch("main.JOB_SUMMARY", "true"), + patch("main.GITHUB_STEP_SUMMARY", summary_path), + patch("main.read_result_file", return_value="bad commit message"), + ): + rc = main.add_job_summary() + self.assertEqual(rc, 1) + with open(summary_path) as f: + content = f.read() + self.assertIn(main.FAILURE_TITLE, content) + self.assertIn("bad commit message", content) + + +class TestIsForkPr(unittest.TestCase): + def test_no_event_path(self): + with patch.dict(os.environ, {}, clear=True): + # Remove GITHUB_EVENT_PATH if present + os.environ.pop("GITHUB_EVENT_PATH", None) + result = main.is_fork_pr() + self.assertFalse(result) + + def test_same_repo_not_fork(self): + import tempfile + event = { + "pull_request": { + "head": {"repo": {"full_name": "owner/repo"}}, + "base": {"repo": {"full_name": "owner/repo"}}, + } + } + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + json.dump(event, f) + event_path = f.name + with patch.dict(os.environ, {"GITHUB_EVENT_PATH": event_path}): + result = main.is_fork_pr() + self.assertFalse(result) + os.unlink(event_path) + + def test_different_repo_is_fork(self): + import tempfile + event = { + "pull_request": { + "head": {"repo": {"full_name": "fork-owner/repo"}}, + "base": {"repo": {"full_name": "owner/repo"}}, + } + } + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + json.dump(event, f) + event_path = f.name + with patch.dict(os.environ, {"GITHUB_EVENT_PATH": event_path}): + result = main.is_fork_pr() + self.assertTrue(result) + os.unlink(event_path) + + def test_json_parse_failure_returns_false(self): + import tempfile + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + f.write("not valid json{{{") + event_path = f.name + with patch.dict(os.environ, {"GITHUB_EVENT_PATH": event_path}): + result = main.is_fork_pr() + self.assertFalse(result) + os.unlink(event_path) + + +class TestLogErrorAndExit(unittest.TestCase): + def test_exits_with_specified_code(self): + with self.assertRaises(SystemExit) as ctx: + main.log_error_and_exit("# Title", None, 0) + self.assertEqual(ctx.exception.code, 0) + + def test_exits_with_nonzero_code(self): + with self.assertRaises(SystemExit) as ctx: + main.log_error_and_exit("# Title", None, 2) + self.assertEqual(ctx.exception.code, 2) + + def test_with_result_text_prints_error(self): + with ( + patch("builtins.print") as mock_print, + self.assertRaises(SystemExit), + ): + main.log_error_and_exit("# Failure", "bad commit", 1) + mock_print.assert_called_once() + printed = mock_print.call_args[0][0] + self.assertIn("::error::", printed) + self.assertIn("bad commit", printed) + + def test_without_result_text_no_print(self): + with ( + patch("builtins.print") as mock_print, + self.assertRaises(SystemExit), + ): + main.log_error_and_exit("# Failure", None, 1) + mock_print.assert_not_called() + + def test_empty_string_result_text_no_print(self): + with ( + patch("builtins.print") as mock_print, + self.assertRaises(SystemExit), + ): + main.log_error_and_exit("# Failure", "", 1) + mock_print.assert_not_called() + + +class TestMain(unittest.TestCase): + def setUp(self): + import tempfile + self._orig_dir = os.getcwd() + self._tmpdir = tempfile.mkdtemp() + os.chdir(self._tmpdir) + open("result.txt", "w").close() + + def tearDown(self): + os.chdir(self._orig_dir) + + def test_success_path(self): + with ( + patch("main.log_env_vars"), + patch("main.run_commit_check", return_value=0), + patch("main.add_job_summary", return_value=0), + patch("main.add_pr_comments", return_value=0), + patch("main.DRY_RUN", "false"), + patch("main.read_result_file", return_value=None), + self.assertRaises(SystemExit) as ctx, + ): + main.main() + self.assertEqual(ctx.exception.code, 0) + + def test_failure_path(self): + with ( + patch("main.log_env_vars"), + patch("main.run_commit_check", return_value=1), + patch("main.add_job_summary", return_value=0), + patch("main.add_pr_comments", return_value=0), + patch("main.DRY_RUN", "false"), + patch("main.read_result_file", return_value="bad msg"), + self.assertRaises(SystemExit) as ctx, + ): + main.main() + self.assertEqual(ctx.exception.code, 1) + + def test_dry_run_forces_zero(self): + with ( + patch("main.log_env_vars"), + patch("main.run_commit_check", return_value=1), + patch("main.add_job_summary", return_value=1), + patch("main.add_pr_comments", return_value=0), + patch("main.DRY_RUN", "true"), + patch("main.read_result_file", return_value=None), + self.assertRaises(SystemExit) as ctx, + ): + main.main() + self.assertEqual(ctx.exception.code, 0) + + def test_all_subfunctions_called(self): + with ( + patch("main.log_env_vars") as mock_log, + patch("main.run_commit_check", return_value=0) as mock_run, + patch("main.add_job_summary", return_value=0) as mock_summary, + patch("main.add_pr_comments", return_value=0) as mock_comments, + patch("main.DRY_RUN", "false"), + patch("main.read_result_file", return_value=None), + self.assertRaises(SystemExit), + ): + main.main() + mock_log.assert_called_once() + mock_run.assert_called_once() + mock_summary.assert_called_once() + mock_comments.assert_called_once() + + +if __name__ == "__main__": + unittest.main() From 8d8615e8035706434b60600e912c07f74482fa61 Mon Sep 17 00:00:00 2001 From: shenxianpeng Date: Tue, 17 Mar 2026 02:24:30 +0200 Subject: [PATCH 07/20] fix: format list comprehensions for better readability in get_pr_commit_messages() and related tests --- main.py | 4 +++- main_test.py | 22 +++++++++++++++++++--- 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/main.py b/main.py index bbcd7f8..0bd15b7 100755 --- a/main.py +++ b/main.py @@ -54,7 +54,9 @@ def get_pr_commit_messages() -> list[str]: check=False, ) if result.returncode == 0 and result.stdout: - return [m.rstrip("\n") for m in result.stdout.split("\x00") if m.rstrip("\n")] + return [ + m.rstrip("\n") for m in result.stdout.split("\x00") if m.rstrip("\n") + ] except Exception: pass return [] diff --git a/main_test.py b/main_test.py index 85a108c..b2a7525 100644 --- a/main_test.py +++ b/main_test.py @@ -1,4 +1,5 @@ """Unit tests for main.py""" + import io import json import os @@ -16,7 +17,9 @@ class TestBuildCheckArgs(unittest.TestCase): def test_all_true(self): result = main.build_check_args("true", "true", "true", "true") - self.assertEqual(result, ["--message", "--branch", "--author-name", "--author-email"]) + self.assertEqual( + result, ["--message", "--branch", "--author-name", "--author-email"] + ) def test_all_false(self): result = main.build_check_args("false", "false", "false", "false") @@ -62,7 +65,11 @@ def test_single_message_fail(self): self.assertEqual(rc, 1) def test_multiple_messages_partial_failure(self): - results = [MagicMock(returncode=0), MagicMock(returncode=1), MagicMock(returncode=0)] + results = [ + MagicMock(returncode=0), + MagicMock(returncode=1), + MagicMock(returncode=0), + ] with patch("main.subprocess.run", side_effect=results): rc = main.run_pr_message_checks(["ok", "bad", "ok"], self._make_file()) self.assertEqual(rc, 1) @@ -125,7 +132,9 @@ def test_rc_one(self): def test_command_contains_all_args(self): mock_result = MagicMock(returncode=0) with patch("main.subprocess.run", return_value=mock_result) as mock_run: - main.run_default_checks(["--message", "--branch", "--author-name"], io.StringIO()) + main.run_default_checks( + ["--message", "--branch", "--author-name"], io.StringIO() + ) called_cmd = mock_run.call_args[0][0] self.assertEqual( called_cmd, @@ -152,6 +161,7 @@ def setUp(self): # Ensure result.txt is written to a temp location self._orig_dir = os.getcwd() import tempfile + self._tmpdir = tempfile.mkdtemp() os.chdir(self._tmpdir) @@ -300,6 +310,7 @@ def test_exception_returns_empty(self): class TestReadResultFile(unittest.TestCase): def setUp(self): import tempfile + self._orig_dir = os.getcwd() self._tmpdir = tempfile.mkdtemp() os.chdir(self._tmpdir) @@ -335,6 +346,7 @@ def test_trailing_whitespace_stripped(self): class TestAddJobSummary(unittest.TestCase): def setUp(self): import tempfile + self._orig_dir = os.getcwd() self._tmpdir = tempfile.mkdtemp() os.chdir(self._tmpdir) @@ -387,6 +399,7 @@ def test_no_event_path(self): def test_same_repo_not_fork(self): import tempfile + event = { "pull_request": { "head": {"repo": {"full_name": "owner/repo"}}, @@ -403,6 +416,7 @@ def test_same_repo_not_fork(self): def test_different_repo_is_fork(self): import tempfile + event = { "pull_request": { "head": {"repo": {"full_name": "fork-owner/repo"}}, @@ -419,6 +433,7 @@ def test_different_repo_is_fork(self): def test_json_parse_failure_returns_false(self): import tempfile + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: f.write("not valid json{{{") event_path = f.name @@ -470,6 +485,7 @@ def test_empty_string_result_text_no_print(self): class TestMain(unittest.TestCase): def setUp(self): import tempfile + self._orig_dir = os.getcwd() self._tmpdir = tempfile.mkdtemp() os.chdir(self._tmpdir) From 83fd88ba9d295e05f5ebf3dabaad0cd1fc90a905 Mon Sep 17 00:00:00 2001 From: shenxianpeng Date: Tue, 17 Mar 2026 02:25:32 +0200 Subject: [PATCH 08/20] fix: add args to codespell hook to ignore specific words --- .pre-commit-config.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4eee43c..dad2476 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -26,3 +26,4 @@ repos: rev: v2.3.0 hooks: - id: codespell + args: [--ignore-words-list=assertin] From 39d41e3f105a3db17447c6c34f02b02ac0ced50e Mon Sep 17 00:00:00 2001 From: shenxianpeng Date: Tue, 17 Mar 2026 02:39:17 +0200 Subject: [PATCH 09/20] fix: update main_test.py --- main_test.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/main_test.py b/main_test.py index b2a7525..13670ec 100644 --- a/main_test.py +++ b/main_test.py @@ -319,7 +319,7 @@ def tearDown(self): os.chdir(self._orig_dir) def _write_result(self, content: str): - with open("result.txt", "w") as f: + with open("result.txt", "w", encoding="utf-8") as f: f.write(content) def test_empty_file_returns_none(self): @@ -351,7 +351,8 @@ def setUp(self): self._tmpdir = tempfile.mkdtemp() os.chdir(self._tmpdir) # Create an empty result.txt - open("result.txt", "w").close() + with open("result.txt", "w", encoding="utf-8"): + pass def tearDown(self): os.chdir(self._orig_dir) @@ -370,7 +371,7 @@ def test_success_writes_success_title(self): ): rc = main.add_job_summary() self.assertEqual(rc, 0) - with open(summary_path) as f: + with open(summary_path, encoding="utf-8") as f: content = f.read() self.assertIn(main.SUCCESS_TITLE, content) @@ -383,7 +384,7 @@ def test_failure_writes_failure_title(self): ): rc = main.add_job_summary() self.assertEqual(rc, 1) - with open(summary_path) as f: + with open(summary_path, encoding="utf-8") as f: content = f.read() self.assertIn(main.FAILURE_TITLE, content) self.assertIn("bad commit message", content) @@ -489,7 +490,8 @@ def setUp(self): self._orig_dir = os.getcwd() self._tmpdir = tempfile.mkdtemp() os.chdir(self._tmpdir) - open("result.txt", "w").close() + with open("result.txt", "w", encoding="utf-8"): + pass def tearDown(self): os.chdir(self._orig_dir) From 68fd620620a011ed05c56791cad10a4732236da2 Mon Sep 17 00:00:00 2001 From: shenxianpeng Date: Mon, 20 Apr 2026 09:52:47 +0300 Subject: [PATCH 10/20] fix: improve error handling and return codes in commit message checks --- .gitignore | 1 + main.py | 19 +++++++++++-------- main_test.py | 6 +++--- 3 files changed, 15 insertions(+), 11 deletions(-) diff --git a/.gitignore b/.gitignore index 43f4b6f..2233054 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ venv/ .venv/ +__pycache__/ diff --git a/main.py b/main.py index 0bd15b7..2c3c75b 100755 --- a/main.py +++ b/main.py @@ -57,8 +57,11 @@ def get_pr_commit_messages() -> list[str]: return [ m.rstrip("\n") for m in result.stdout.split("\x00") if m.rstrip("\n") ] - except Exception: - pass + except Exception as e: + print( + f"::warning::Failed to retrieve PR commit messages: {e}", + file=sys.stderr, + ) return [] @@ -74,9 +77,9 @@ def build_check_args( def run_pr_message_checks(pr_messages: list[str], result_file) -> int: # type: ignore[type-arg] """Checks each PR commit message individually via commit-check --message. - Returns cumulative returncode across all messages. + Returns 1 if any message fails, 0 if all pass. """ - total_rc = 0 + has_failure = False for msg in pr_messages: result = subprocess.run( ["commit-check", "--message"], @@ -86,8 +89,8 @@ def run_pr_message_checks(pr_messages: list[str], result_file) -> int: # type: text=True, check=False, ) - total_rc += result.returncode - return total_rc + has_failure = has_failure or (result.returncode != 0) + return 1 if has_failure else 0 def run_other_checks(args: list[str], result_file) -> int: # type: ignore[type-arg] @@ -97,7 +100,7 @@ def run_other_checks(args: list[str], result_file) -> int: # type: ignore[type- command = ["commit-check"] + args print(" ".join(command)) result = subprocess.run( - command, stdout=result_file, stderr=subprocess.PIPE, check=False + command, stdout=result_file, stderr=subprocess.PIPE, text=True, check=False ) return result.returncode @@ -107,7 +110,7 @@ def run_default_checks(args: list[str], result_file) -> int: # type: ignore[typ command = ["commit-check"] + args print(" ".join(command)) result = subprocess.run( - command, stdout=result_file, stderr=subprocess.PIPE, check=False + command, stdout=result_file, stderr=subprocess.PIPE, text=True, check=False ) return result.returncode diff --git a/main_test.py b/main_test.py index 13670ec..98086e5 100644 --- a/main_test.py +++ b/main_test.py @@ -78,7 +78,7 @@ def test_multiple_messages_all_fail(self): results = [MagicMock(returncode=1), MagicMock(returncode=1)] with patch("main.subprocess.run", side_effect=results): rc = main.run_pr_message_checks(["bad1", "bad2"], self._make_file()) - self.assertEqual(rc, 2) + self.assertEqual(rc, 1) def test_empty_list(self): with patch("main.subprocess.run") as mock_run: @@ -191,11 +191,11 @@ def test_pr_path_rc_accumulation(self): patch("main.AUTHOR_NAME", "false"), patch("main.AUTHOR_EMAIL", "false"), patch("main.get_pr_commit_messages", return_value=["bad msg"]), - patch("main.run_pr_message_checks", return_value=2), + patch("main.run_pr_message_checks", return_value=1), patch("main.run_other_checks", return_value=1), ): rc = main.run_commit_check() - self.assertEqual(rc, 3) + self.assertEqual(rc, 2) def test_non_pr_path_uses_default_checks(self): with ( From dd6ac2c1efac57ceaa2d3db61a23794e6d657d6d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 20 Apr 2026 06:52:59 +0000 Subject: [PATCH 11/20] fix: auto fixes from pre-commit.com hooks --- main_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/main_test.py b/main_test.py index 98086e5..87b42f1 100644 --- a/main_test.py +++ b/main_test.py @@ -333,7 +333,7 @@ def test_file_with_content(self): self.assertEqual(result, "some output") def test_ansi_codes_are_stripped(self): - self._write_result("\x1B[31mError\x1B[0m: bad commit") + self._write_result("\x1b[31mError\x1b[0m: bad commit") result = main.read_result_file() self.assertEqual(result, "Error: bad commit") From 45c85fa1c82ef53bfec374ef9b121df138e69946 Mon Sep 17 00:00:00 2001 From: shenxianpeng Date: Mon, 20 Apr 2026 19:11:08 +0300 Subject: [PATCH 12/20] refactor: clean up PR commit validation flow --- main.py | 216 ++++++++++++++--------- main_test.py | 473 +++++++++++++++++++++++---------------------------- 2 files changed, 351 insertions(+), 338 deletions(-) diff --git a/main.py b/main.py index 88525f1..74bf03f 100755 --- a/main.py +++ b/main.py @@ -1,14 +1,15 @@ #!/usr/bin/env python3 import json import os -import sys -import subprocess import re -from github import Github, Auth, GithubException # type: ignore +import subprocess +import sys +from typing import TextIO # Constants for message titles SUCCESS_TITLE = "# Commit-Check ✔️" FAILURE_TITLE = "# Commit-Check ❌" +COMMIT_MESSAGE_DELIMITER = "\x00" # Environment variables MESSAGE = os.getenv("MESSAGE", "false") @@ -19,9 +20,20 @@ JOB_SUMMARY = os.getenv("JOB_SUMMARY", "false") PR_COMMENTS = os.getenv("PR_COMMENTS", "false") GITHUB_STEP_SUMMARY = os.environ["GITHUB_STEP_SUMMARY"] -GITHUB_TOKEN = os.getenv("GITHUB_TOKEN") -GITHUB_REPOSITORY = os.getenv("GITHUB_REPOSITORY") -GITHUB_REF = os.getenv("GITHUB_REF") + + +def env_flag(name: str, default: str = "false") -> bool: + """Read a GitHub Action boolean-style environment variable.""" + return os.getenv(name, default).lower() == "true" + + +MESSAGE_ENABLED = env_flag("MESSAGE") +BRANCH_ENABLED = env_flag("BRANCH") +AUTHOR_NAME_ENABLED = env_flag("AUTHOR_NAME") +AUTHOR_EMAIL_ENABLED = env_flag("AUTHOR_EMAIL") +DRY_RUN_ENABLED = env_flag("DRY_RUN") +JOB_SUMMARY_ENABLED = env_flag("JOB_SUMMARY") +PR_COMMENTS_ENABLED = env_flag("PR_COMMENTS") def log_env_vars(): @@ -35,27 +47,74 @@ def log_env_vars(): print(f"PR_COMMENTS = {PR_COMMENTS}\n") +def is_pr_event() -> bool: + """Return whether the workflow was triggered by a PR-style event.""" + return os.getenv("GITHUB_EVENT_NAME", "") in {"pull_request", "pull_request_target"} + + +def parse_commit_messages(output: str) -> list[str]: + """Split git log output into individual commit messages.""" + return [ + message.rstrip("\n") + for message in output.split(COMMIT_MESSAGE_DELIMITER) + if message.rstrip("\n") + ] + + +def get_messages_from_merge_ref() -> list[str]: + """Read PR commit messages from GitHub's synthetic merge commit.""" + result = subprocess.run( + ["git", "log", "--pretty=format:%B%x00", "--reverse", "HEAD^1..HEAD^2"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + encoding="utf-8", + check=False, + ) + if result.returncode == 0 and result.stdout: + return parse_commit_messages(result.stdout) + return [] + + +def get_messages_from_head_ref(base_ref: str) -> list[str]: + """Read PR commit messages when the workflow checks out the head SHA.""" + result = subprocess.run( + [ + "git", + "log", + "--pretty=format:%B%x00", + "--reverse", + f"origin/{base_ref}..HEAD", + ], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + encoding="utf-8", + check=False, + ) + if result.returncode == 0 and result.stdout: + return parse_commit_messages(result.stdout) + return [] + + def get_pr_commit_messages() -> list[str]: - """Get all commit messages for the current PR (pull_request event only). + """Get all commit messages for the current PR workflow. - In a pull_request event, actions/checkout checks out a synthetic merge + In pull_request-style workflows, actions/checkout checks out a synthetic merge commit (HEAD = merge of PR branch into base). HEAD^1 is the base branch tip, HEAD^2 is the PR branch tip. So HEAD^1..HEAD^2 gives all PR commits. + If the workflow explicitly checks out the PR head SHA instead, fall back to + diffing against origin/ when that ref is available locally. """ - if os.getenv("GITHUB_EVENT_NAME", "") != "pull_request": + if not is_pr_event(): return [] + try: - result = subprocess.run( - ["git", "log", "--pretty=format:%B%x00", "HEAD^1..HEAD^2"], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - encoding="utf-8", - check=False, - ) - if result.returncode == 0 and result.stdout: - return [ - m.rstrip("\n") for m in result.stdout.split("\x00") if m.rstrip("\n") - ] + messages = get_messages_from_merge_ref() + if messages: + return messages + + base_ref = os.getenv("GITHUB_BASE_REF", "") + if base_ref: + return get_messages_from_head_ref(base_ref) except Exception as e: print( f"::warning::Failed to retrieve PR commit messages: {e}", @@ -64,73 +123,72 @@ def get_pr_commit_messages() -> list[str]: return [] -def build_check_args( - message: str, branch: str, author_name: str, author_email: str -) -> list[str]: - """Maps 'true'/'false' flag values to CLI argument list.""" - flags = ["--message", "--branch", "--author-name", "--author-email"] - values = [message, branch, author_name, author_email] - return [flag for flag, value in zip(flags, values) if value == "true"] +def run_check_command( + args: list[str], result_file: TextIO, input_text: str | None = None +) -> int: + """Run commit-check and write both stdout and stderr to the result file.""" + command = ["commit-check"] + args + print(" ".join(command)) + result = subprocess.run( + command, + input=input_text, + stdout=result_file, + stderr=subprocess.STDOUT, + text=True, + check=False, + ) + return result.returncode -def run_pr_message_checks(pr_messages: list[str], result_file) -> int: # type: ignore[type-arg] +def run_pr_message_checks(pr_messages: list[str], result_file: TextIO) -> int: """Checks each PR commit message individually via commit-check --message. Returns 1 if any message fails, 0 if all pass. """ has_failure = False - for msg in pr_messages: - result = subprocess.run( - ["commit-check", "--message"], - input=msg, - stdout=result_file, - stderr=subprocess.PIPE, - text=True, - check=False, - ) - has_failure = has_failure or (result.returncode != 0) + total_messages = len(pr_messages) + for index, msg in enumerate(pr_messages, start=1): + subject = msg.splitlines()[0] if msg else "" + result_file.write(f"\n--- Commit {index}/{total_messages}: {subject}\n") + has_failure = run_check_command( + ["--message"], result_file, input_text=msg + ) != 0 or has_failure return 1 if has_failure else 0 -def run_other_checks(args: list[str], result_file) -> int: # type: ignore[type-arg] +def run_other_checks(args: list[str], result_file: TextIO) -> int: """Runs non-message checks (branch, author) once. Returns 0 if args is empty.""" if not args: return 0 - command = ["commit-check"] + args - print(" ".join(command)) - result = subprocess.run( - command, stdout=result_file, stderr=subprocess.PIPE, text=True, check=False - ) - return result.returncode + return run_check_command(args, result_file) -def run_default_checks(args: list[str], result_file) -> int: # type: ignore[type-arg] - """Runs all checks at once (non-PR context or message disabled).""" - command = ["commit-check"] + args - print(" ".join(command)) - result = subprocess.run( - command, stdout=result_file, stderr=subprocess.PIPE, text=True, check=False - ) - return result.returncode +def build_check_args() -> list[str]: + """Map enabled validation switches to commit-check CLI arguments.""" + flags = [ + ("--message", MESSAGE_ENABLED), + ("--branch", BRANCH_ENABLED), + ("--author-name", AUTHOR_NAME_ENABLED), + ("--author-email", AUTHOR_EMAIL_ENABLED), + ] + return [flag for flag, enabled in flags if enabled] def run_commit_check() -> int: """Runs the commit-check command and logs the result.""" - args = build_check_args(MESSAGE, BRANCH, AUTHOR_NAME, AUTHOR_EMAIL) - total_rc = 0 + args = build_check_args() with open("result.txt", "w") as result_file: - if MESSAGE == "true": + if MESSAGE_ENABLED: pr_messages = get_pr_commit_messages() if pr_messages: # In PR context: check each commit message individually to avoid # only validating the synthetic merge commit at HEAD. - total_rc += run_pr_message_checks(pr_messages, result_file) + message_rc = run_pr_message_checks(pr_messages, result_file) other_args = [a for a in args if a != "--message"] - total_rc += run_other_checks(other_args, result_file) - return total_rc + other_rc = run_other_checks(other_args, result_file) + return 1 if message_rc or other_rc else 0 # Non-PR context or message disabled: run all checks at once - total_rc += run_default_checks(args, result_file) - return total_rc + return 1 if run_check_command(args, result_file) else 0 def read_result_file() -> str | None: @@ -144,21 +202,22 @@ def read_result_file() -> str | None: return None +def build_result_body(result_text: str | None) -> str: + """Create the human-readable result body used in summaries and PR comments.""" + if result_text is None: + return SUCCESS_TITLE + return f"{FAILURE_TITLE}\n```\n{result_text}\n```" + + def add_job_summary() -> int: """Adds the commit check result to the GitHub job summary.""" - if JOB_SUMMARY == "false": + if not JOB_SUMMARY_ENABLED: return 0 result_text = read_result_file() - summary_content = ( - SUCCESS_TITLE - if result_text is None - else f"{FAILURE_TITLE}\n```\n{result_text}\n```" - ) - with open(GITHUB_STEP_SUMMARY, "a") as summary_file: - summary_file.write(summary_content) + summary_file.write(build_result_body(result_text)) return 0 if result_text is None else 1 @@ -183,7 +242,7 @@ def is_fork_pr() -> bool: def add_pr_comments() -> int: """Posts the commit check result as a comment on the pull request.""" - if PR_COMMENTS == "false": + if not PR_COMMENTS_ENABLED: return 0 # Fork PRs triggered by the pull_request event receive a read-only token; @@ -199,6 +258,8 @@ def add_pr_comments() -> int: return 0 try: + from github import Auth, Github, GithubException # type: ignore + token = os.getenv("GITHUB_TOKEN") repo_name = os.getenv("GITHUB_REPOSITORY") pr_number = os.getenv("GITHUB_REF") @@ -214,15 +275,9 @@ def add_pr_comments() -> int: repo = g.get_repo(repo_name) pull_request = repo.get_issue(int(pr_number)) - # Prepare comment content result_text = read_result_file() - pr_comment_body = ( - SUCCESS_TITLE - if result_text is None - else f"{FAILURE_TITLE}\n```\n{result_text}\n```" - ) + pr_comment_body = build_result_body(result_text) - # Fetch all existing comments on the PR comments = pull_request.get_comments() matching_comments = [ c @@ -282,12 +337,9 @@ def main(): """Main function to run commit-check, add job summary and post PR comments.""" log_env_vars() - # Combine return codes - ret_code = run_commit_check() - ret_code += add_job_summary() - ret_code += add_pr_comments() + ret_code = max(run_commit_check(), add_job_summary(), add_pr_comments()) - if DRY_RUN == "true": + if DRY_RUN_ENABLED: ret_code = 0 result_text = read_result_file() diff --git a/main_test.py b/main_test.py index 87b42f1..43480b0 100644 --- a/main_test.py +++ b/main_test.py @@ -1,9 +1,8 @@ -"""Unit tests for main.py""" +"""Unit tests for main.py.""" import io import json import os -import sys import unittest from unittest.mock import MagicMock, patch @@ -14,55 +13,93 @@ import main # noqa: E402 +class TestEnvFlag(unittest.TestCase): + def test_true_value(self): + with patch.dict(os.environ, {"FEATURE_FLAG": "true"}): + self.assertTrue(main.env_flag("FEATURE_FLAG")) + + def test_false_value(self): + with patch.dict(os.environ, {"FEATURE_FLAG": "false"}): + self.assertFalse(main.env_flag("FEATURE_FLAG")) + + def test_missing_uses_default(self): + with patch.dict(os.environ, {}, clear=True): + self.assertTrue(main.env_flag("FEATURE_FLAG", default="true")) + + class TestBuildCheckArgs(unittest.TestCase): def test_all_true(self): - result = main.build_check_args("true", "true", "true", "true") + with ( + patch("main.MESSAGE_ENABLED", True), + patch("main.BRANCH_ENABLED", True), + patch("main.AUTHOR_NAME_ENABLED", True), + patch("main.AUTHOR_EMAIL_ENABLED", True), + ): + result = main.build_check_args() self.assertEqual( result, ["--message", "--branch", "--author-name", "--author-email"] ) def test_all_false(self): - result = main.build_check_args("false", "false", "false", "false") + with ( + patch("main.MESSAGE_ENABLED", False), + patch("main.BRANCH_ENABLED", False), + patch("main.AUTHOR_NAME_ENABLED", False), + patch("main.AUTHOR_EMAIL_ENABLED", False), + ): + result = main.build_check_args() self.assertEqual(result, []) - def test_message_only(self): - result = main.build_check_args("true", "false", "false", "false") - self.assertEqual(result, ["--message"]) + def test_message_and_branch(self): + with ( + patch("main.MESSAGE_ENABLED", True), + patch("main.BRANCH_ENABLED", True), + patch("main.AUTHOR_NAME_ENABLED", False), + patch("main.AUTHOR_EMAIL_ENABLED", False), + ): + result = main.build_check_args() + self.assertEqual(result, ["--message", "--branch"]) + - def test_branch_only(self): - result = main.build_check_args("false", "true", "false", "false") - self.assertEqual(result, ["--branch"]) +class TestParseCommitMessages(unittest.TestCase): + def test_splits_messages_and_trims_trailing_newlines(self): + result = main.parse_commit_messages("fix: first\n\x00feat: second\n\n\x00") + self.assertEqual(result, ["fix: first", "feat: second"]) - def test_author_name_and_email(self): - result = main.build_check_args("false", "false", "true", "true") - self.assertEqual(result, ["--author-name", "--author-email"]) - def test_message_and_branch(self): - result = main.build_check_args("true", "true", "false", "false") - self.assertEqual(result, ["--message", "--branch"]) +class TestRunCheckCommand(unittest.TestCase): + def test_with_args_calls_subprocess(self): + mock_result = MagicMock(returncode=0) + with patch("main.subprocess.run", return_value=mock_result) as mock_run: + rc = main.run_check_command(["--branch"], io.StringIO()) + self.assertEqual(rc, 0) + self.assertEqual(mock_run.call_args[0][0], ["commit-check", "--branch"]) + def test_with_input_uses_text_mode(self): + mock_result = MagicMock(returncode=0) + with patch("main.subprocess.run", return_value=mock_result) as mock_run: + main.run_check_command(["--message"], io.StringIO(), input_text="fix: demo") + self.assertEqual(mock_run.call_args[1]["input"], "fix: demo") + self.assertTrue(mock_run.call_args[1]["text"]) + + def test_prints_command(self): + mock_result = MagicMock(returncode=0) + with patch("main.subprocess.run", return_value=mock_result): + with patch("builtins.print") as mock_print: + main.run_check_command(["--branch"], io.StringIO()) + mock_print.assert_called_once_with("commit-check --branch") -class TestRunPrMessageChecks(unittest.TestCase): - def _make_file(self): - return io.StringIO() +class TestRunPrMessageChecks(unittest.TestCase): def test_single_message_pass(self): - mock_result = MagicMock() - mock_result.returncode = 0 + mock_result = MagicMock(returncode=0) + result_file = io.StringIO() with patch("main.subprocess.run", return_value=mock_result) as mock_run: - rc = main.run_pr_message_checks(["fix: something"], self._make_file()) + rc = main.run_pr_message_checks(["fix: something"], result_file) self.assertEqual(rc, 0) - mock_run.assert_called_once() - call_kwargs = mock_run.call_args - self.assertIn("--message", call_kwargs[0][0]) - self.assertEqual(call_kwargs[1]["input"], "fix: something") - - def test_single_message_fail(self): - mock_result = MagicMock() - mock_result.returncode = 1 - with patch("main.subprocess.run", return_value=mock_result): - rc = main.run_pr_message_checks(["bad commit"], self._make_file()) - self.assertEqual(rc, 1) + self.assertEqual(mock_run.call_args[0][0], ["commit-check", "--message"]) + self.assertEqual(mock_run.call_args[1]["input"], "fix: something") + self.assertIn("--- Commit 1/1: fix: something", result_file.getvalue()) def test_multiple_messages_partial_failure(self): results = [ @@ -71,18 +108,12 @@ def test_multiple_messages_partial_failure(self): MagicMock(returncode=0), ] with patch("main.subprocess.run", side_effect=results): - rc = main.run_pr_message_checks(["ok", "bad", "ok"], self._make_file()) - self.assertEqual(rc, 1) - - def test_multiple_messages_all_fail(self): - results = [MagicMock(returncode=1), MagicMock(returncode=1)] - with patch("main.subprocess.run", side_effect=results): - rc = main.run_pr_message_checks(["bad1", "bad2"], self._make_file()) + rc = main.run_pr_message_checks(["ok", "bad", "ok"], io.StringIO()) self.assertEqual(rc, 1) def test_empty_list(self): with patch("main.subprocess.run") as mock_run: - rc = main.run_pr_message_checks([], self._make_file()) + rc = main.run_pr_message_checks([], io.StringIO()) self.assertEqual(rc, 0) mock_run.assert_not_called() @@ -94,71 +125,99 @@ def test_empty_args_returns_zero(self): self.assertEqual(rc, 0) mock_run.assert_not_called() - def test_with_args_calls_subprocess(self): - mock_result = MagicMock(returncode=0) - with patch("main.subprocess.run", return_value=mock_result) as mock_run: - rc = main.run_other_checks(["--branch"], io.StringIO()) - self.assertEqual(rc, 0) - called_cmd = mock_run.call_args[0][0] - self.assertEqual(called_cmd, ["commit-check", "--branch"]) - def test_with_args_returns_returncode(self): mock_result = MagicMock(returncode=1) with patch("main.subprocess.run", return_value=mock_result): rc = main.run_other_checks(["--branch", "--author-name"], io.StringIO()) self.assertEqual(rc, 1) - def test_prints_command(self): - mock_result = MagicMock(returncode=0) - with patch("main.subprocess.run", return_value=mock_result): - with patch("builtins.print") as mock_print: - main.run_other_checks(["--branch"], io.StringIO()) - mock_print.assert_called_once_with("commit-check --branch") +class TestGetPrCommitMessages(unittest.TestCase): + def test_non_pr_event_returns_empty(self): + with patch.dict(os.environ, {"GITHUB_EVENT_NAME": "push"}): + result = main.get_pr_commit_messages() + self.assertEqual(result, []) -class TestRunDefaultChecks(unittest.TestCase): - def test_rc_zero(self): - mock_result = MagicMock(returncode=0) - with patch("main.subprocess.run", return_value=mock_result): - rc = main.run_default_checks(["--message", "--branch"], io.StringIO()) - self.assertEqual(rc, 0) + def test_merge_ref_is_preferred(self): + with ( + patch.dict(os.environ, {"GITHUB_EVENT_NAME": "pull_request"}), + patch( + "main.get_messages_from_merge_ref", + return_value=["fix: first", "feat: second"], + ) as mock_merge, + patch("main.get_messages_from_head_ref") as mock_head, + ): + result = main.get_pr_commit_messages() + self.assertEqual(result, ["fix: first", "feat: second"]) + mock_merge.assert_called_once() + mock_head.assert_not_called() - def test_rc_one(self): - mock_result = MagicMock(returncode=1) - with patch("main.subprocess.run", return_value=mock_result): - rc = main.run_default_checks(["--message"], io.StringIO()) - self.assertEqual(rc, 1) + def test_pull_request_target_is_supported(self): + with ( + patch.dict(os.environ, {"GITHUB_EVENT_NAME": "pull_request_target"}), + patch("main.get_messages_from_merge_ref", return_value=["fix: first"]), + ): + result = main.get_pr_commit_messages() + self.assertEqual(result, ["fix: first"]) - def test_command_contains_all_args(self): - mock_result = MagicMock(returncode=0) + def test_falls_back_to_base_ref_when_merge_ref_is_unavailable(self): + with ( + patch.dict( + os.environ, + { + "GITHUB_EVENT_NAME": "pull_request", + "GITHUB_BASE_REF": "main", + }, + ), + patch("main.get_messages_from_merge_ref", return_value=[]), + patch( + "main.get_messages_from_head_ref", + return_value=["fix: first", "feat: second"], + ) as mock_head, + ): + result = main.get_pr_commit_messages() + self.assertEqual(result, ["fix: first", "feat: second"]) + mock_head.assert_called_once_with("main") + + def test_exception_returns_empty(self): + with ( + patch.dict(os.environ, {"GITHUB_EVENT_NAME": "pull_request"}), + patch("main.get_messages_from_merge_ref", side_effect=Exception("git failed")), + ): + result = main.get_pr_commit_messages() + self.assertEqual(result, []) + + +class TestGitMessageReaders(unittest.TestCase): + def test_get_messages_from_merge_ref(self): + mock_result = MagicMock(returncode=0, stdout="fix: first\n\x00feat: second\n\x00") with patch("main.subprocess.run", return_value=mock_result) as mock_run: - main.run_default_checks( - ["--message", "--branch", "--author-name"], io.StringIO() - ) - called_cmd = mock_run.call_args[0][0] + result = main.get_messages_from_merge_ref() + self.assertEqual(result, ["fix: first", "feat: second"]) self.assertEqual( - called_cmd, - ["commit-check", "--message", "--branch", "--author-name"], + mock_run.call_args[0][0], + ["git", "log", "--pretty=format:%B%x00", "--reverse", "HEAD^1..HEAD^2"], ) - def test_prints_command(self): - mock_result = MagicMock(returncode=0) - with patch("main.subprocess.run", return_value=mock_result): - with patch("builtins.print") as mock_print: - main.run_default_checks(["--branch"], io.StringIO()) - mock_print.assert_called_once_with("commit-check --branch") - - def test_empty_args(self): - mock_result = MagicMock(returncode=0) + def test_get_messages_from_head_ref(self): + mock_result = MagicMock(returncode=0, stdout="fix: first\n\x00") with patch("main.subprocess.run", return_value=mock_result) as mock_run: - main.run_default_checks([], io.StringIO()) - called_cmd = mock_run.call_args[0][0] - self.assertEqual(called_cmd, ["commit-check"]) + result = main.get_messages_from_head_ref("main") + self.assertEqual(result, ["fix: first"]) + self.assertEqual( + mock_run.call_args[0][0], + [ + "git", + "log", + "--pretty=format:%B%x00", + "--reverse", + "origin/main..HEAD", + ], + ) class TestRunCommitCheck(unittest.TestCase): def setUp(self): - # Ensure result.txt is written to a temp location self._orig_dir = os.getcwd() import tempfile @@ -170,75 +229,74 @@ def tearDown(self): def test_pr_path_calls_pr_message_checks(self): with ( - patch("main.MESSAGE", "true"), - patch("main.BRANCH", "false"), - patch("main.AUTHOR_NAME", "false"), - patch("main.AUTHOR_EMAIL", "false"), + patch("main.MESSAGE_ENABLED", True), + patch("main.BRANCH_ENABLED", False), + patch("main.AUTHOR_NAME_ENABLED", False), + patch("main.AUTHOR_EMAIL_ENABLED", False), patch("main.get_pr_commit_messages", return_value=["fix: something"]), patch("main.run_pr_message_checks", return_value=0) as mock_pr, patch("main.run_other_checks", return_value=0), - patch("main.run_default_checks") as mock_default, + patch("main.run_check_command") as mock_command, ): rc = main.run_commit_check() - mock_pr.assert_called_once() - mock_default.assert_not_called() self.assertEqual(rc, 0) + mock_pr.assert_called_once() + mock_command.assert_not_called() - def test_pr_path_rc_accumulation(self): + def test_pr_path_returns_nonzero_when_any_check_fails(self): with ( - patch("main.MESSAGE", "true"), - patch("main.BRANCH", "true"), - patch("main.AUTHOR_NAME", "false"), - patch("main.AUTHOR_EMAIL", "false"), + patch("main.MESSAGE_ENABLED", True), + patch("main.BRANCH_ENABLED", True), + patch("main.AUTHOR_NAME_ENABLED", False), + patch("main.AUTHOR_EMAIL_ENABLED", False), patch("main.get_pr_commit_messages", return_value=["bad msg"]), patch("main.run_pr_message_checks", return_value=1), patch("main.run_other_checks", return_value=1), ): rc = main.run_commit_check() - self.assertEqual(rc, 2) + self.assertEqual(rc, 1) - def test_non_pr_path_uses_default_checks(self): + def test_non_pr_path_uses_direct_command(self): with ( - patch("main.MESSAGE", "true"), - patch("main.BRANCH", "false"), - patch("main.AUTHOR_NAME", "false"), - patch("main.AUTHOR_EMAIL", "false"), + patch("main.MESSAGE_ENABLED", True), + patch("main.BRANCH_ENABLED", False), + patch("main.AUTHOR_NAME_ENABLED", False), + patch("main.AUTHOR_EMAIL_ENABLED", False), patch("main.get_pr_commit_messages", return_value=[]), patch("main.run_pr_message_checks") as mock_pr, - patch("main.run_default_checks", return_value=0) as mock_default, + patch("main.run_check_command", return_value=0) as mock_command, ): rc = main.run_commit_check() - mock_pr.assert_not_called() - mock_default.assert_called_once() self.assertEqual(rc, 0) + mock_pr.assert_not_called() + mock_command.assert_called_once() - def test_message_false_uses_default_checks(self): + def test_message_disabled_uses_direct_command(self): with ( - patch("main.MESSAGE", "false"), - patch("main.BRANCH", "true"), - patch("main.AUTHOR_NAME", "false"), - patch("main.AUTHOR_EMAIL", "false"), + patch("main.MESSAGE_ENABLED", False), + patch("main.BRANCH_ENABLED", True), + patch("main.AUTHOR_NAME_ENABLED", False), + patch("main.AUTHOR_EMAIL_ENABLED", False), patch("main.run_pr_message_checks") as mock_pr, - patch("main.run_default_checks", return_value=0) as mock_default, + patch("main.run_check_command", return_value=0) as mock_command, ): rc = main.run_commit_check() - mock_pr.assert_not_called() - mock_default.assert_called_once() self.assertEqual(rc, 0) + mock_pr.assert_not_called() + mock_command.assert_called_once() def test_result_txt_is_created(self): with ( - patch("main.MESSAGE", "false"), - patch("main.BRANCH", "false"), - patch("main.AUTHOR_NAME", "false"), - patch("main.AUTHOR_EMAIL", "false"), - patch("main.run_default_checks", return_value=0), + patch("main.MESSAGE_ENABLED", False), + patch("main.BRANCH_ENABLED", False), + patch("main.AUTHOR_NAME_ENABLED", False), + patch("main.AUTHOR_EMAIL_ENABLED", False), + patch("main.run_check_command", return_value=0), ): main.run_commit_check() self.assertTrue(os.path.exists(os.path.join(self._tmpdir, "result.txt"))) def test_other_args_excludes_message(self): - """When in PR path, run_other_checks must not receive --message.""" captured_args = [] def fake_other_checks(args, result_file): @@ -246,10 +304,10 @@ def fake_other_checks(args, result_file): return 0 with ( - patch("main.MESSAGE", "true"), - patch("main.BRANCH", "true"), - patch("main.AUTHOR_NAME", "false"), - patch("main.AUTHOR_EMAIL", "false"), + patch("main.MESSAGE_ENABLED", True), + patch("main.BRANCH_ENABLED", True), + patch("main.AUTHOR_NAME_ENABLED", False), + patch("main.AUTHOR_EMAIL_ENABLED", False), patch("main.get_pr_commit_messages", return_value=["fix: x"]), patch("main.run_pr_message_checks", return_value=0), patch("main.run_other_checks", side_effect=fake_other_checks), @@ -259,54 +317,6 @@ def fake_other_checks(args, result_file): self.assertIn("--branch", captured_args) -class TestGetPrCommitMessages(unittest.TestCase): - def test_non_pr_event_returns_empty(self): - with patch.dict(os.environ, {"GITHUB_EVENT_NAME": "push"}): - result = main.get_pr_commit_messages() - self.assertEqual(result, []) - - def test_pr_event_with_commits(self): - mock_result = MagicMock() - mock_result.returncode = 0 - mock_result.stdout = "fix: first\n\x00feat: second\n\x00" - with ( - patch.dict(os.environ, {"GITHUB_EVENT_NAME": "pull_request"}), - patch("main.subprocess.run", return_value=mock_result), - ): - result = main.get_pr_commit_messages() - self.assertEqual(result, ["fix: first", "feat: second"]) - - def test_pr_event_empty_output(self): - mock_result = MagicMock() - mock_result.returncode = 0 - mock_result.stdout = "" - with ( - patch.dict(os.environ, {"GITHUB_EVENT_NAME": "pull_request"}), - patch("main.subprocess.run", return_value=mock_result), - ): - result = main.get_pr_commit_messages() - self.assertEqual(result, []) - - def test_git_failure_returns_empty(self): - mock_result = MagicMock() - mock_result.returncode = 1 - mock_result.stdout = "" - with ( - patch.dict(os.environ, {"GITHUB_EVENT_NAME": "pull_request"}), - patch("main.subprocess.run", return_value=mock_result), - ): - result = main.get_pr_commit_messages() - self.assertEqual(result, []) - - def test_exception_returns_empty(self): - with ( - patch.dict(os.environ, {"GITHUB_EVENT_NAME": "pull_request"}), - patch("main.subprocess.run", side_effect=Exception("git not found")), - ): - result = main.get_pr_commit_messages() - self.assertEqual(result, []) - - class TestReadResultFile(unittest.TestCase): def setUp(self): import tempfile @@ -319,28 +329,30 @@ def tearDown(self): os.chdir(self._orig_dir) def _write_result(self, content: str): - with open("result.txt", "w", encoding="utf-8") as f: - f.write(content) + with open("result.txt", "w", encoding="utf-8") as file_obj: + file_obj.write(content) def test_empty_file_returns_none(self): self._write_result("") - result = main.read_result_file() - self.assertIsNone(result) + self.assertIsNone(main.read_result_file()) def test_file_with_content(self): self._write_result("some output\n") - result = main.read_result_file() - self.assertEqual(result, "some output") + self.assertEqual(main.read_result_file(), "some output") def test_ansi_codes_are_stripped(self): self._write_result("\x1b[31mError\x1b[0m: bad commit") - result = main.read_result_file() - self.assertEqual(result, "Error: bad commit") + self.assertEqual(main.read_result_file(), "Error: bad commit") + - def test_trailing_whitespace_stripped(self): - self._write_result("output\n\n") - result = main.read_result_file() - self.assertEqual(result, "output") +class TestBuildResultBody(unittest.TestCase): + def test_success_body(self): + self.assertEqual(main.build_result_body(None), main.SUCCESS_TITLE) + + def test_failure_body(self): + result = main.build_result_body("bad commit") + self.assertIn(main.FAILURE_TITLE, result) + self.assertIn("bad commit", result) class TestAddJobSummary(unittest.TestCase): @@ -350,7 +362,6 @@ def setUp(self): self._orig_dir = os.getcwd() self._tmpdir = tempfile.mkdtemp() os.chdir(self._tmpdir) - # Create an empty result.txt with open("result.txt", "w", encoding="utf-8"): pass @@ -358,34 +369,34 @@ def tearDown(self): os.chdir(self._orig_dir) def test_false_skips(self): - with patch("main.JOB_SUMMARY", "false"): + with patch("main.JOB_SUMMARY_ENABLED", False): rc = main.add_job_summary() self.assertEqual(rc, 0) def test_success_writes_success_title(self): summary_path = os.path.join(self._tmpdir, "summary.txt") with ( - patch("main.JOB_SUMMARY", "true"), + patch("main.JOB_SUMMARY_ENABLED", True), patch("main.GITHUB_STEP_SUMMARY", summary_path), patch("main.read_result_file", return_value=None), ): rc = main.add_job_summary() self.assertEqual(rc, 0) - with open(summary_path, encoding="utf-8") as f: - content = f.read() + with open(summary_path, encoding="utf-8") as file_obj: + content = file_obj.read() self.assertIn(main.SUCCESS_TITLE, content) def test_failure_writes_failure_title(self): summary_path = os.path.join(self._tmpdir, "summary.txt") with ( - patch("main.JOB_SUMMARY", "true"), + patch("main.JOB_SUMMARY_ENABLED", True), patch("main.GITHUB_STEP_SUMMARY", summary_path), patch("main.read_result_file", return_value="bad commit message"), ): rc = main.add_job_summary() self.assertEqual(rc, 1) - with open(summary_path, encoding="utf-8") as f: - content = f.read() + with open(summary_path, encoding="utf-8") as file_obj: + content = file_obj.read() self.assertIn(main.FAILURE_TITLE, content) self.assertIn("bad commit message", content) @@ -393,7 +404,6 @@ def test_failure_writes_failure_title(self): class TestIsForkPr(unittest.TestCase): def test_no_event_path(self): with patch.dict(os.environ, {}, clear=True): - # Remove GITHUB_EVENT_PATH if present os.environ.pop("GITHUB_EVENT_PATH", None) result = main.is_fork_pr() self.assertFalse(result) @@ -407,9 +417,9 @@ def test_same_repo_not_fork(self): "base": {"repo": {"full_name": "owner/repo"}}, } } - with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: - json.dump(event, f) - event_path = f.name + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as file_obj: + json.dump(event, file_obj) + event_path = file_obj.name with patch.dict(os.environ, {"GITHUB_EVENT_PATH": event_path}): result = main.is_fork_pr() self.assertFalse(result) @@ -424,25 +434,14 @@ def test_different_repo_is_fork(self): "base": {"repo": {"full_name": "owner/repo"}}, } } - with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: - json.dump(event, f) - event_path = f.name + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as file_obj: + json.dump(event, file_obj) + event_path = file_obj.name with patch.dict(os.environ, {"GITHUB_EVENT_PATH": event_path}): result = main.is_fork_pr() self.assertTrue(result) os.unlink(event_path) - def test_json_parse_failure_returns_false(self): - import tempfile - - with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: - f.write("not valid json{{{") - event_path = f.name - with patch.dict(os.environ, {"GITHUB_EVENT_PATH": event_path}): - result = main.is_fork_pr() - self.assertFalse(result) - os.unlink(event_path) - class TestLogErrorAndExit(unittest.TestCase): def test_exits_with_specified_code(self): @@ -450,38 +449,16 @@ def test_exits_with_specified_code(self): main.log_error_and_exit("# Title", None, 0) self.assertEqual(ctx.exception.code, 0) - def test_exits_with_nonzero_code(self): - with self.assertRaises(SystemExit) as ctx: - main.log_error_and_exit("# Title", None, 2) - self.assertEqual(ctx.exception.code, 2) - def test_with_result_text_prints_error(self): with ( patch("builtins.print") as mock_print, self.assertRaises(SystemExit), ): main.log_error_and_exit("# Failure", "bad commit", 1) - mock_print.assert_called_once() printed = mock_print.call_args[0][0] self.assertIn("::error::", printed) self.assertIn("bad commit", printed) - def test_without_result_text_no_print(self): - with ( - patch("builtins.print") as mock_print, - self.assertRaises(SystemExit), - ): - main.log_error_and_exit("# Failure", None, 1) - mock_print.assert_not_called() - - def test_empty_string_result_text_no_print(self): - with ( - patch("builtins.print") as mock_print, - self.assertRaises(SystemExit), - ): - main.log_error_and_exit("# Failure", "", 1) - mock_print.assert_not_called() - class TestMain(unittest.TestCase): def setUp(self): @@ -502,20 +479,20 @@ def test_success_path(self): patch("main.run_commit_check", return_value=0), patch("main.add_job_summary", return_value=0), patch("main.add_pr_comments", return_value=0), - patch("main.DRY_RUN", "false"), + patch("main.DRY_RUN_ENABLED", False), patch("main.read_result_file", return_value=None), self.assertRaises(SystemExit) as ctx, ): main.main() self.assertEqual(ctx.exception.code, 0) - def test_failure_path(self): + def test_multiple_failures_still_exit_with_one(self): with ( patch("main.log_env_vars"), patch("main.run_commit_check", return_value=1), - patch("main.add_job_summary", return_value=0), - patch("main.add_pr_comments", return_value=0), - patch("main.DRY_RUN", "false"), + patch("main.add_job_summary", return_value=1), + patch("main.add_pr_comments", return_value=1), + patch("main.DRY_RUN_ENABLED", False), patch("main.read_result_file", return_value="bad msg"), self.assertRaises(SystemExit) as ctx, ): @@ -528,29 +505,13 @@ def test_dry_run_forces_zero(self): patch("main.run_commit_check", return_value=1), patch("main.add_job_summary", return_value=1), patch("main.add_pr_comments", return_value=0), - patch("main.DRY_RUN", "true"), + patch("main.DRY_RUN_ENABLED", True), patch("main.read_result_file", return_value=None), self.assertRaises(SystemExit) as ctx, ): main.main() self.assertEqual(ctx.exception.code, 0) - def test_all_subfunctions_called(self): - with ( - patch("main.log_env_vars") as mock_log, - patch("main.run_commit_check", return_value=0) as mock_run, - patch("main.add_job_summary", return_value=0) as mock_summary, - patch("main.add_pr_comments", return_value=0) as mock_comments, - patch("main.DRY_RUN", "false"), - patch("main.read_result_file", return_value=None), - self.assertRaises(SystemExit), - ): - main.main() - mock_log.assert_called_once() - mock_run.assert_called_once() - mock_summary.assert_called_once() - mock_comments.assert_called_once() - if __name__ == "__main__": unittest.main() From 54dfe610a3d40dba998ce07c065ff7cecfe0b99c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 20 Apr 2026 16:11:55 +0000 Subject: [PATCH 13/20] chore: auto fixes from pre-commit.com hooks --- main.py | 7 ++++--- main_test.py | 16 ++++++++++++---- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/main.py b/main.py index 74bf03f..f5cc488 100755 --- a/main.py +++ b/main.py @@ -150,9 +150,10 @@ def run_pr_message_checks(pr_messages: list[str], result_file: TextIO) -> int: for index, msg in enumerate(pr_messages, start=1): subject = msg.splitlines()[0] if msg else "" result_file.write(f"\n--- Commit {index}/{total_messages}: {subject}\n") - has_failure = run_check_command( - ["--message"], result_file, input_text=msg - ) != 0 or has_failure + has_failure = ( + run_check_command(["--message"], result_file, input_text=msg) != 0 + or has_failure + ) return 1 if has_failure else 0 diff --git a/main_test.py b/main_test.py index 43480b0..d3eb676 100644 --- a/main_test.py +++ b/main_test.py @@ -182,7 +182,9 @@ def test_falls_back_to_base_ref_when_merge_ref_is_unavailable(self): def test_exception_returns_empty(self): with ( patch.dict(os.environ, {"GITHUB_EVENT_NAME": "pull_request"}), - patch("main.get_messages_from_merge_ref", side_effect=Exception("git failed")), + patch( + "main.get_messages_from_merge_ref", side_effect=Exception("git failed") + ), ): result = main.get_pr_commit_messages() self.assertEqual(result, []) @@ -190,7 +192,9 @@ def test_exception_returns_empty(self): class TestGitMessageReaders(unittest.TestCase): def test_get_messages_from_merge_ref(self): - mock_result = MagicMock(returncode=0, stdout="fix: first\n\x00feat: second\n\x00") + mock_result = MagicMock( + returncode=0, stdout="fix: first\n\x00feat: second\n\x00" + ) with patch("main.subprocess.run", return_value=mock_result) as mock_run: result = main.get_messages_from_merge_ref() self.assertEqual(result, ["fix: first", "feat: second"]) @@ -417,7 +421,9 @@ def test_same_repo_not_fork(self): "base": {"repo": {"full_name": "owner/repo"}}, } } - with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as file_obj: + with tempfile.NamedTemporaryFile( + mode="w", suffix=".json", delete=False + ) as file_obj: json.dump(event, file_obj) event_path = file_obj.name with patch.dict(os.environ, {"GITHUB_EVENT_PATH": event_path}): @@ -434,7 +440,9 @@ def test_different_repo_is_fork(self): "base": {"repo": {"full_name": "owner/repo"}}, } } - with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as file_obj: + with tempfile.NamedTemporaryFile( + mode="w", suffix=".json", delete=False + ) as file_obj: json.dump(event, file_obj) event_path = file_obj.name with patch.dict(os.environ, {"GITHUB_EVENT_PATH": event_path}): From 388b0b100636c87739f519bdc22b4b4272f31ca1 Mon Sep 17 00:00:00 2001 From: shenxianpeng Date: Mon, 20 Apr 2026 19:11:08 +0300 Subject: [PATCH 14/20] refactor: clean up PR commit validation flow --- main.py | 26 +++++++++++++++++--------- main_test.py | 45 +++++++++++++++++++++++---------------------- 2 files changed, 40 insertions(+), 31 deletions(-) diff --git a/main.py b/main.py index f5cc488..d5775a5 100755 --- a/main.py +++ b/main.py @@ -55,9 +55,9 @@ def is_pr_event() -> bool: def parse_commit_messages(output: str) -> list[str]: """Split git log output into individual commit messages.""" return [ - message.rstrip("\n") + message.strip("\n") for message in output.split(COMMIT_MESSAGE_DELIMITER) - if message.rstrip("\n") + if message.strip("\n") ] @@ -124,7 +124,10 @@ def get_pr_commit_messages() -> list[str]: def run_check_command( - args: list[str], result_file: TextIO, input_text: str | None = None + args: list[str], + result_file: TextIO, + input_text: str | None = None, + output_prefix: str | None = None, ) -> int: """Run commit-check and write both stdout and stderr to the result file.""" command = ["commit-check"] + args @@ -132,11 +135,15 @@ def run_check_command( result = subprocess.run( command, input=input_text, - stdout=result_file, + stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, check=False, ) + if result.stdout: + if output_prefix: + result_file.write(output_prefix) + result_file.write(result.stdout) return result.returncode @@ -149,11 +156,12 @@ def run_pr_message_checks(pr_messages: list[str], result_file: TextIO) -> int: total_messages = len(pr_messages) for index, msg in enumerate(pr_messages, start=1): subject = msg.splitlines()[0] if msg else "" - result_file.write(f"\n--- Commit {index}/{total_messages}: {subject}\n") - has_failure = ( - run_check_command(["--message"], result_file, input_text=msg) != 0 - or has_failure - ) + has_failure = run_check_command( + ["--message"], + result_file, + input_text=msg, + output_prefix=f"\n--- Commit {index}/{total_messages}: {subject}\n", + ) != 0 or has_failure return 1 if has_failure else 0 diff --git a/main_test.py b/main_test.py index d3eb676..1fb5a24 100644 --- a/main_test.py +++ b/main_test.py @@ -62,28 +62,28 @@ def test_message_and_branch(self): class TestParseCommitMessages(unittest.TestCase): - def test_splits_messages_and_trims_trailing_newlines(self): - result = main.parse_commit_messages("fix: first\n\x00feat: second\n\n\x00") + def test_splits_messages_and_trims_surrounding_newlines(self): + result = main.parse_commit_messages("\nfix: first\n\x00\nfeat: second\n\n\x00") self.assertEqual(result, ["fix: first", "feat: second"]) class TestRunCheckCommand(unittest.TestCase): def test_with_args_calls_subprocess(self): - mock_result = MagicMock(returncode=0) + mock_result = MagicMock(returncode=0, stdout="") with patch("main.subprocess.run", return_value=mock_result) as mock_run: rc = main.run_check_command(["--branch"], io.StringIO()) self.assertEqual(rc, 0) self.assertEqual(mock_run.call_args[0][0], ["commit-check", "--branch"]) def test_with_input_uses_text_mode(self): - mock_result = MagicMock(returncode=0) + mock_result = MagicMock(returncode=0, stdout="") with patch("main.subprocess.run", return_value=mock_result) as mock_run: main.run_check_command(["--message"], io.StringIO(), input_text="fix: demo") self.assertEqual(mock_run.call_args[1]["input"], "fix: demo") self.assertTrue(mock_run.call_args[1]["text"]) def test_prints_command(self): - mock_result = MagicMock(returncode=0) + mock_result = MagicMock(returncode=0, stdout="") with patch("main.subprocess.run", return_value=mock_result): with patch("builtins.print") as mock_print: main.run_check_command(["--branch"], io.StringIO()) @@ -92,20 +92,29 @@ def test_prints_command(self): class TestRunPrMessageChecks(unittest.TestCase): def test_single_message_pass(self): - mock_result = MagicMock(returncode=0) + mock_result = MagicMock(returncode=0, stdout="") result_file = io.StringIO() with patch("main.subprocess.run", return_value=mock_result) as mock_run: rc = main.run_pr_message_checks(["fix: something"], result_file) self.assertEqual(rc, 0) self.assertEqual(mock_run.call_args[0][0], ["commit-check", "--message"]) self.assertEqual(mock_run.call_args[1]["input"], "fix: something") + self.assertEqual(result_file.getvalue(), "") + + def test_failed_message_writes_header_and_output(self): + mock_result = MagicMock(returncode=1, stdout="Commit rejected.\n") + result_file = io.StringIO() + with patch("main.subprocess.run", return_value=mock_result): + rc = main.run_pr_message_checks(["fix: something"], result_file) + self.assertEqual(rc, 1) self.assertIn("--- Commit 1/1: fix: something", result_file.getvalue()) + self.assertIn("Commit rejected.", result_file.getvalue()) def test_multiple_messages_partial_failure(self): results = [ - MagicMock(returncode=0), - MagicMock(returncode=1), - MagicMock(returncode=0), + MagicMock(returncode=0, stdout=""), + MagicMock(returncode=1, stdout="Commit rejected.\n"), + MagicMock(returncode=0, stdout=""), ] with patch("main.subprocess.run", side_effect=results): rc = main.run_pr_message_checks(["ok", "bad", "ok"], io.StringIO()) @@ -126,7 +135,7 @@ def test_empty_args_returns_zero(self): mock_run.assert_not_called() def test_with_args_returns_returncode(self): - mock_result = MagicMock(returncode=1) + mock_result = MagicMock(returncode=1, stdout="branch check failed\n") with patch("main.subprocess.run", return_value=mock_result): rc = main.run_other_checks(["--branch", "--author-name"], io.StringIO()) self.assertEqual(rc, 1) @@ -182,9 +191,7 @@ def test_falls_back_to_base_ref_when_merge_ref_is_unavailable(self): def test_exception_returns_empty(self): with ( patch.dict(os.environ, {"GITHUB_EVENT_NAME": "pull_request"}), - patch( - "main.get_messages_from_merge_ref", side_effect=Exception("git failed") - ), + patch("main.get_messages_from_merge_ref", side_effect=Exception("git failed")), ): result = main.get_pr_commit_messages() self.assertEqual(result, []) @@ -192,9 +199,7 @@ def test_exception_returns_empty(self): class TestGitMessageReaders(unittest.TestCase): def test_get_messages_from_merge_ref(self): - mock_result = MagicMock( - returncode=0, stdout="fix: first\n\x00feat: second\n\x00" - ) + mock_result = MagicMock(returncode=0, stdout="fix: first\n\x00feat: second\n\x00") with patch("main.subprocess.run", return_value=mock_result) as mock_run: result = main.get_messages_from_merge_ref() self.assertEqual(result, ["fix: first", "feat: second"]) @@ -421,9 +426,7 @@ def test_same_repo_not_fork(self): "base": {"repo": {"full_name": "owner/repo"}}, } } - with tempfile.NamedTemporaryFile( - mode="w", suffix=".json", delete=False - ) as file_obj: + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as file_obj: json.dump(event, file_obj) event_path = file_obj.name with patch.dict(os.environ, {"GITHUB_EVENT_PATH": event_path}): @@ -440,9 +443,7 @@ def test_different_repo_is_fork(self): "base": {"repo": {"full_name": "owner/repo"}}, } } - with tempfile.NamedTemporaryFile( - mode="w", suffix=".json", delete=False - ) as file_obj: + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as file_obj: json.dump(event, file_obj) event_path = file_obj.name with patch.dict(os.environ, {"GITHUB_EVENT_PATH": event_path}): From 0cf3f7742e0ff17ba7fc59a85e736b1ca4df7857 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 20 Apr 2026 16:21:17 +0000 Subject: [PATCH 15/20] chore: auto fixes from pre-commit.com hooks --- main.py | 16 ++++++++++------ main_test.py | 16 ++++++++++++---- 2 files changed, 22 insertions(+), 10 deletions(-) diff --git a/main.py b/main.py index d5775a5..c50d898 100755 --- a/main.py +++ b/main.py @@ -156,12 +156,16 @@ def run_pr_message_checks(pr_messages: list[str], result_file: TextIO) -> int: total_messages = len(pr_messages) for index, msg in enumerate(pr_messages, start=1): subject = msg.splitlines()[0] if msg else "" - has_failure = run_check_command( - ["--message"], - result_file, - input_text=msg, - output_prefix=f"\n--- Commit {index}/{total_messages}: {subject}\n", - ) != 0 or has_failure + has_failure = ( + run_check_command( + ["--message"], + result_file, + input_text=msg, + output_prefix=f"\n--- Commit {index}/{total_messages}: {subject}\n", + ) + != 0 + or has_failure + ) return 1 if has_failure else 0 diff --git a/main_test.py b/main_test.py index 1fb5a24..658d226 100644 --- a/main_test.py +++ b/main_test.py @@ -191,7 +191,9 @@ def test_falls_back_to_base_ref_when_merge_ref_is_unavailable(self): def test_exception_returns_empty(self): with ( patch.dict(os.environ, {"GITHUB_EVENT_NAME": "pull_request"}), - patch("main.get_messages_from_merge_ref", side_effect=Exception("git failed")), + patch( + "main.get_messages_from_merge_ref", side_effect=Exception("git failed") + ), ): result = main.get_pr_commit_messages() self.assertEqual(result, []) @@ -199,7 +201,9 @@ def test_exception_returns_empty(self): class TestGitMessageReaders(unittest.TestCase): def test_get_messages_from_merge_ref(self): - mock_result = MagicMock(returncode=0, stdout="fix: first\n\x00feat: second\n\x00") + mock_result = MagicMock( + returncode=0, stdout="fix: first\n\x00feat: second\n\x00" + ) with patch("main.subprocess.run", return_value=mock_result) as mock_run: result = main.get_messages_from_merge_ref() self.assertEqual(result, ["fix: first", "feat: second"]) @@ -426,7 +430,9 @@ def test_same_repo_not_fork(self): "base": {"repo": {"full_name": "owner/repo"}}, } } - with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as file_obj: + with tempfile.NamedTemporaryFile( + mode="w", suffix=".json", delete=False + ) as file_obj: json.dump(event, file_obj) event_path = file_obj.name with patch.dict(os.environ, {"GITHUB_EVENT_PATH": event_path}): @@ -443,7 +449,9 @@ def test_different_repo_is_fork(self): "base": {"repo": {"full_name": "owner/repo"}}, } } - with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as file_obj: + with tempfile.NamedTemporaryFile( + mode="w", suffix=".json", delete=False + ) as file_obj: json.dump(event, file_obj) event_path = file_obj.name with patch.dict(os.environ, {"GITHUB_EVENT_PATH": event_path}): From b8961dba154ef642b63bbab49596b7f98ed6657a Mon Sep 17 00:00:00 2001 From: shenxianpeng Date: Mon, 20 Apr 2026 23:07:44 +0300 Subject: [PATCH 16/20] feat: enhance commit message checks with section separators and no-banner option --- main.py | 16 ++++++++++++++-- main_test.py | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 2 deletions(-) diff --git a/main.py b/main.py index c50d898..628012d 100755 --- a/main.py +++ b/main.py @@ -10,6 +10,7 @@ SUCCESS_TITLE = "# Commit-Check ✔️" FAILURE_TITLE = "# Commit-Check ❌" COMMIT_MESSAGE_DELIMITER = "\x00" +COMMIT_SECTION_SEPARATOR = "\n" + ("-" * 72) + "\n" # Environment variables MESSAGE = os.getenv("MESSAGE", "false") @@ -156,12 +157,23 @@ def run_pr_message_checks(pr_messages: list[str], result_file: TextIO) -> int: total_messages = len(pr_messages) for index, msg in enumerate(pr_messages, start=1): subject = msg.splitlines()[0] if msg else "" + command_args = ["--message"] + if index > 1: + command_args.append("--no-banner") + + output_prefix = f"\n--- Commit {index}/{total_messages}: {subject}\n" + if index > 1: + output_prefix = ( + f"{COMMIT_SECTION_SEPARATOR}" + f"--- Commit {index}/{total_messages}: {subject}\n" + ) + has_failure = ( run_check_command( - ["--message"], + command_args, result_file, input_text=msg, - output_prefix=f"\n--- Commit {index}/{total_messages}: {subject}\n", + output_prefix=output_prefix, ) != 0 or has_failure diff --git a/main_test.py b/main_test.py index 658d226..a20a8cb 100644 --- a/main_test.py +++ b/main_test.py @@ -126,6 +126,39 @@ def test_empty_list(self): self.assertEqual(rc, 0) mock_run.assert_not_called() + def test_second_and_later_messages_use_no_banner(self): + results = [ + MagicMock(returncode=1, stdout="Commit rejected.\n"), + MagicMock(returncode=1, stdout="Type subject_imperative check failed\n"), + ] + with patch("main.subprocess.run", side_effect=results) as mock_run: + main.run_pr_message_checks(["bad first", "bad second"], io.StringIO()) + + self.assertEqual( + mock_run.call_args_list[0][0][0], ["commit-check", "--message"] + ) + self.assertEqual( + mock_run.call_args_list[1][0][0], + ["commit-check", "--message", "--no-banner"], + ) + + def test_second_message_prefix_uses_separator(self): + results = [ + MagicMock(returncode=1, stdout="Commit rejected.\n"), + MagicMock(returncode=1, stdout="Type subject_imperative check failed\n"), + ] + result_file = io.StringIO() + with patch("main.subprocess.run", side_effect=results): + main.run_pr_message_checks(["bad first", "bad second"], result_file) + + output = result_file.getvalue() + self.assertIn("\n--- Commit 1/2: bad first\nCommit rejected.\n", output) + self.assertIn( + f"{main.COMMIT_SECTION_SEPARATOR}--- Commit 2/2: bad second\n", + output, + ) + self.assertIn("Type subject_imperative check failed\n", output) + class TestRunOtherChecks(unittest.TestCase): def test_empty_args_returns_zero(self): From a37959385b04eff6cdb1de6d71a50f7347d543a9 Mon Sep 17 00:00:00 2001 From: shenxianpeng Date: Mon, 20 Apr 2026 23:10:55 +0300 Subject: [PATCH 17/20] fix: update commit message formatting to improve output clarity and consistency --- main.py | 29 +++++++++++++++-------------- main_test.py | 30 ++++++++++++++++++++++-------- 2 files changed, 37 insertions(+), 22 deletions(-) diff --git a/main.py b/main.py index 628012d..7a54927 100755 --- a/main.py +++ b/main.py @@ -10,7 +10,7 @@ SUCCESS_TITLE = "# Commit-Check ✔️" FAILURE_TITLE = "# Commit-Check ❌" COMMIT_MESSAGE_DELIMITER = "\x00" -COMMIT_SECTION_SEPARATOR = "\n" + ("-" * 72) + "\n" +COMMIT_SECTION_SEPARATOR = "\n---\n" # Environment variables MESSAGE = os.getenv("MESSAGE", "false") @@ -144,7 +144,8 @@ def run_check_command( if result.stdout: if output_prefix: result_file.write(output_prefix) - result_file.write(result.stdout) + result_file.write(result.stdout.rstrip("\n")) + result_file.write("\n") return result.returncode @@ -154,30 +155,30 @@ def run_pr_message_checks(pr_messages: list[str], result_file: TextIO) -> int: Returns 1 if any message fails, 0 if all pass. """ has_failure = False + emitted_failure_output = False total_messages = len(pr_messages) for index, msg in enumerate(pr_messages, start=1): subject = msg.splitlines()[0] if msg else "" command_args = ["--message"] - if index > 1: + if emitted_failure_output: command_args.append("--no-banner") - output_prefix = f"\n--- Commit {index}/{total_messages}: {subject}\n" - if index > 1: + output_prefix = f"--- Commit {index}/{total_messages}: {subject}\n" + if emitted_failure_output: output_prefix = ( f"{COMMIT_SECTION_SEPARATOR}" f"--- Commit {index}/{total_messages}: {subject}\n" ) - has_failure = ( - run_check_command( - command_args, - result_file, - input_text=msg, - output_prefix=output_prefix, - ) - != 0 - or has_failure + return_code = run_check_command( + command_args, + result_file, + input_text=msg, + output_prefix=output_prefix, ) + if return_code != 0: + has_failure = True + emitted_failure_output = True return 1 if has_failure else 0 diff --git a/main_test.py b/main_test.py index a20a8cb..fc15e49 100644 --- a/main_test.py +++ b/main_test.py @@ -126,38 +126,52 @@ def test_empty_list(self): self.assertEqual(rc, 0) mock_run.assert_not_called() - def test_second_and_later_messages_use_no_banner(self): + def test_first_failure_keeps_banner_and_later_failures_use_no_banner(self): results = [ + MagicMock(returncode=0, stdout=""), MagicMock(returncode=1, stdout="Commit rejected.\n"), MagicMock(returncode=1, stdout="Type subject_imperative check failed\n"), ] with patch("main.subprocess.run", side_effect=results) as mock_run: - main.run_pr_message_checks(["bad first", "bad second"], io.StringIO()) + main.run_pr_message_checks(["ok first", "bad second", "bad third"], io.StringIO()) self.assertEqual( mock_run.call_args_list[0][0][0], ["commit-check", "--message"] ) self.assertEqual( mock_run.call_args_list[1][0][0], + ["commit-check", "--message"], + ) + self.assertEqual( + mock_run.call_args_list[2][0][0], ["commit-check", "--message", "--no-banner"], ) - def test_second_message_prefix_uses_separator(self): + def test_later_failure_prefix_uses_short_separator_without_extra_blank_lines(self): results = [ + MagicMock(returncode=0, stdout=""), MagicMock(returncode=1, stdout="Commit rejected.\n"), - MagicMock(returncode=1, stdout="Type subject_imperative check failed\n"), + MagicMock( + returncode=1, + stdout=( + "Type subject_imperative check failed ==> bad third\n" + "Commit message should use imperative mood\n" + "Suggest: Use imperative mood\n\n" + ), + ), ] result_file = io.StringIO() with patch("main.subprocess.run", side_effect=results): - main.run_pr_message_checks(["bad first", "bad second"], result_file) + main.run_pr_message_checks(["ok first", "bad second", "bad third"], result_file) output = result_file.getvalue() - self.assertIn("\n--- Commit 1/2: bad first\nCommit rejected.\n", output) + self.assertIn("--- Commit 2/3: bad second\nCommit rejected.\n", output) self.assertIn( - f"{main.COMMIT_SECTION_SEPARATOR}--- Commit 2/2: bad second\n", + f"{main.COMMIT_SECTION_SEPARATOR}--- Commit 3/3: bad third\n", output, ) - self.assertIn("Type subject_imperative check failed\n", output) + self.assertNotIn("------------------------------------------------------------------------", output) + self.assertNotIn("\n\n\n", output) class TestRunOtherChecks(unittest.TestCase): From dbd630563bd15b8b1c089b60635f047a88829d51 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 20 Apr 2026 20:11:10 +0000 Subject: [PATCH 18/20] chore: auto fixes from pre-commit.com hooks --- main_test.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/main_test.py b/main_test.py index fc15e49..da19166 100644 --- a/main_test.py +++ b/main_test.py @@ -133,7 +133,9 @@ def test_first_failure_keeps_banner_and_later_failures_use_no_banner(self): MagicMock(returncode=1, stdout="Type subject_imperative check failed\n"), ] with patch("main.subprocess.run", side_effect=results) as mock_run: - main.run_pr_message_checks(["ok first", "bad second", "bad third"], io.StringIO()) + main.run_pr_message_checks( + ["ok first", "bad second", "bad third"], io.StringIO() + ) self.assertEqual( mock_run.call_args_list[0][0][0], ["commit-check", "--message"] @@ -162,7 +164,9 @@ def test_later_failure_prefix_uses_short_separator_without_extra_blank_lines(sel ] result_file = io.StringIO() with patch("main.subprocess.run", side_effect=results): - main.run_pr_message_checks(["ok first", "bad second", "bad third"], result_file) + main.run_pr_message_checks( + ["ok first", "bad second", "bad third"], result_file + ) output = result_file.getvalue() self.assertIn("--- Commit 2/3: bad second\nCommit rejected.\n", output) @@ -170,7 +174,10 @@ def test_later_failure_prefix_uses_short_separator_without_extra_blank_lines(sel f"{main.COMMIT_SECTION_SEPARATOR}--- Commit 3/3: bad third\n", output, ) - self.assertNotIn("------------------------------------------------------------------------", output) + self.assertNotIn( + "------------------------------------------------------------------------", + output, + ) self.assertNotIn("\n\n\n", output) From 94942c69f8aa989a346acb5f863bb1b12dc0ac8e Mon Sep 17 00:00:00 2001 From: shenxianpeng Date: Mon, 20 Apr 2026 23:16:42 +0300 Subject: [PATCH 19/20] fix: simplify output handling in PR message checks and update test assertions --- main.py | 11 ++--------- main_test.py | 7 +++---- 2 files changed, 5 insertions(+), 13 deletions(-) diff --git a/main.py b/main.py index 7a54927..cb90e34 100755 --- a/main.py +++ b/main.py @@ -156,19 +156,12 @@ def run_pr_message_checks(pr_messages: list[str], result_file: TextIO) -> int: """ has_failure = False emitted_failure_output = False - total_messages = len(pr_messages) - for index, msg in enumerate(pr_messages, start=1): - subject = msg.splitlines()[0] if msg else "" + for msg in pr_messages: command_args = ["--message"] if emitted_failure_output: command_args.append("--no-banner") - output_prefix = f"--- Commit {index}/{total_messages}: {subject}\n" - if emitted_failure_output: - output_prefix = ( - f"{COMMIT_SECTION_SEPARATOR}" - f"--- Commit {index}/{total_messages}: {subject}\n" - ) + output_prefix = COMMIT_SECTION_SEPARATOR if emitted_failure_output else None return_code = run_check_command( command_args, diff --git a/main_test.py b/main_test.py index da19166..a331831 100644 --- a/main_test.py +++ b/main_test.py @@ -101,13 +101,12 @@ def test_single_message_pass(self): self.assertEqual(mock_run.call_args[1]["input"], "fix: something") self.assertEqual(result_file.getvalue(), "") - def test_failed_message_writes_header_and_output(self): + def test_failed_message_writes_output(self): mock_result = MagicMock(returncode=1, stdout="Commit rejected.\n") result_file = io.StringIO() with patch("main.subprocess.run", return_value=mock_result): rc = main.run_pr_message_checks(["fix: something"], result_file) self.assertEqual(rc, 1) - self.assertIn("--- Commit 1/1: fix: something", result_file.getvalue()) self.assertIn("Commit rejected.", result_file.getvalue()) def test_multiple_messages_partial_failure(self): @@ -169,9 +168,9 @@ def test_later_failure_prefix_uses_short_separator_without_extra_blank_lines(sel ) output = result_file.getvalue() - self.assertIn("--- Commit 2/3: bad second\nCommit rejected.\n", output) + self.assertIn("Commit rejected.\n", output) self.assertIn( - f"{main.COMMIT_SECTION_SEPARATOR}--- Commit 3/3: bad third\n", + f"{main.COMMIT_SECTION_SEPARATOR}Type subject_imperative check failed ==> bad third\n", output, ) self.assertNotIn( From 7216a50b4ad740ab1f4877a7d6edfc1b8f51abcb Mon Sep 17 00:00:00 2001 From: shenxianpeng Date: Mon, 20 Apr 2026 23:20:44 +0300 Subject: [PATCH 20/20] fix: update output formatting in PR message checks to include commit indexing --- main.py | 8 ++++++-- main_test.py | 2 +- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/main.py b/main.py index cb90e34..7f0f1e2 100755 --- a/main.py +++ b/main.py @@ -156,12 +156,16 @@ def run_pr_message_checks(pr_messages: list[str], result_file: TextIO) -> int: """ has_failure = False emitted_failure_output = False - for msg in pr_messages: + total = len(pr_messages) + for index, msg in enumerate(pr_messages, start=1): command_args = ["--message"] if emitted_failure_output: command_args.append("--no-banner") - output_prefix = COMMIT_SECTION_SEPARATOR if emitted_failure_output else None + if emitted_failure_output: + output_prefix = f"\n--- Commit {index}/{total}:\n" + else: + output_prefix = None return_code = run_check_command( command_args, diff --git a/main_test.py b/main_test.py index a331831..3a5d04c 100644 --- a/main_test.py +++ b/main_test.py @@ -170,7 +170,7 @@ def test_later_failure_prefix_uses_short_separator_without_extra_blank_lines(sel output = result_file.getvalue() self.assertIn("Commit rejected.\n", output) self.assertIn( - f"{main.COMMIT_SECTION_SEPARATOR}Type subject_imperative check failed ==> bad third\n", + "\n--- Commit 3/3:\nType subject_imperative check failed ==> bad third\n", output, ) self.assertNotIn(