diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index ab02cd552e..1221460800 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -51,6 +51,7 @@ min_version, optional_import, pprint_edges, + safe_eval, ) validate, _ = optional_import("jsonschema", name="validate") @@ -161,7 +162,7 @@ def _get_fake_spatial_shape(shape: Sequence[str | int], p: int = 1, n: int = 1, for c in _get_var_names(i): if c not in ["p", "n"]: raise ValueError(f"only support variables 'p' and 'n' so far, but got: {c}.") - ret.append(eval(i, {"p": p, "n": n})) + ret.append(safe_eval(i, {"p": p, "n": n})) else: raise ValueError(f"spatial shape items must be int or string, but got: {type(i)} {i}.") return tuple(ret) diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index 3efc9b5e7f..d1a705205c 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -137,6 +137,7 @@ torch_profiler_time_cpu_gpu, torch_profiler_time_end_to_end, ) +from .safeeval import SAFE_TYPES, safe_eval from .state_cacher import StateCacher from .tf32 import detect_default_tf32, has_ampere_or_later from .type_conversion import ( diff --git a/monai/utils/ordering.py b/monai/utils/ordering.py index 1be61f98ab..6daf5d4582 100644 --- a/monai/utils/ordering.py +++ b/monai/utils/ordering.py @@ -148,7 +148,7 @@ def _order_template(self, template: np.ndarray) -> np.ndarray: else: rows, columns, depths = (template.shape[0], template.shape[1], template.shape[2]) - sequence = eval(f"self.{self.ordering_type}_idx")(rows, columns, depths) + sequence = getattr(self, f"{self.ordering_type}_idx")(rows, columns, depths) ordering = np.array([template[tuple(e)] for e in sequence]) diff --git a/monai/utils/safeeval.py b/monai/utils/safeeval.py new file mode 100644 index 0000000000..48fddfa07f --- /dev/null +++ b/monai/utils/safeeval.py @@ -0,0 +1,73 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import ast +from collections.abc import Mapping, Sequence +from typing import Any + +__all__ = ["SAFE_TYPES", "safe_eval"] + +# default set of safe AST node types +SAFE_TYPES: Sequence[type] = ( + ast.Expression, + ast.Name, + ast.Load, + ast.Constant, + ast.BinOp, + ast.UnaryOp, + ast.Add, + ast.Sub, + ast.Mult, + ast.Div, + ast.FloorDiv, + ast.Pow, + ast.Mod, + ast.USub, + ast.UAdd, +) + + +def safe_eval( + expr: str, + globals_vars: Mapping[str, Any] | None = None, + locals_vars: Mapping[str, object] | None = None, + allowed_types: Sequence[type] = SAFE_TYPES, +) -> Any: + """ + Evaluate the Python expression `expr` using `eval`, but only if it is a safe expression in that its parsed AST + contains nodes whose types are given in `allowed_types`. This ensures unsafe node types are excluded, if these + are present in the AST a ValueError is raised. The default set of such types in `SAFE_TYPES` ensures only + expressions with constants and names can be evaluated, so excludes attribute access, indexing, and calls. Code + injection is infeasible through such expressions, so this is a safe and secure way of evaluating simple expressions. + + Args: + expr: expression to evaluate, this will be stripped before parsing to avoid indentation complaints + globals_vars: global variable mapping, this will be treated as read-only for this function, unlike `eval` + locals_vars: local variable mapping + allowed_types: sequence of allowed AST types which can be found in `expr` when parsed + + Raises: + ValueError: raised when any node in the AST parsed from `expr` has a type not in `allowed_types` + + Returns: + The evaluated expression value, using `eval` with `globals_vars` and `locals_vars` + """ + parsed = ast.parse(expr.strip(), mode="eval") + + # collect nodes in the AST which aren't permitted and unparse them for inclusion in the exception message + disallowed = [ast.unparse(n) for n in ast.walk(parsed) if not isinstance(n, tuple(allowed_types))] + + if disallowed: + raise ValueError(f"Unsafe expression `{expr}` not evaluated, contains disallowed components: {disallowed}") + + return eval(expr, dict(globals_vars) if globals_vars else None, locals_vars) diff --git a/tests/utils/test_alias.py b/tests/utils/test_alias.py index e7abff3d89..8ec1f8ae00 100644 --- a/tests/utils/test_alias.py +++ b/tests/utils/test_alias.py @@ -23,7 +23,10 @@ class TestModuleAlias(unittest.TestCase): - """check that 'import monai.xx.file_name' returns a module""" + """ + Check that 'import monai.xx.file_name' returns a module. Note that this test will fail if a module has the same name + as a member of that module (or any other) which is imported in a `__init__.py` file. + """ def test_files(self): src_dir = os.path.dirname(TESTS_PATH) diff --git a/tests/utils/test_safe_eval.py b/tests/utils/test_safe_eval.py new file mode 100644 index 0000000000..bf0b0891c4 --- /dev/null +++ b/tests/utils/test_safe_eval.py @@ -0,0 +1,58 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import ast +import unittest + +from parameterized import parameterized + +from monai.utils import safe_eval + +GOOD_EXPRS = [ + ("1+2", None, None, 3), + (" 1 + 2 ", None, None, 3), + ("1+2+x", {"x": 4}, None, 7), + ("1+2+x", None, {"x": 4}, 7), + ("1*2+x", {"x": 4}, None, 6), + ("(1+2)*3", None, None, 9), + ("foo+bar", {"foo": 1030}, {"bar": 204}, 1234), +] + +BAD_EXPRS = [("foo()",), ("foo.bar",), ("foo[123]",), ("(1,2)",), ("[3,4]",), ("int.__class__.__init__.__globals__",)] + + +class TestSafeEval(unittest.TestCase): + @parameterized.expand(GOOD_EXPRS) + def test_good_exprs(self, expr, globals_vars, locals_vars, expected): + """Test valid expressions with globals/locals evaluate to correct values.""" + result = safe_eval(expr, globals_vars, locals_vars) + self.assertEqual(result, expected) + + @parameterized.expand(BAD_EXPRS) + def test_bad_exprs(self, expr): + """Test bad expressions correctly raise ValueError.""" + with self.assertRaises(ValueError): + safe_eval(expr) + + def test_allowed_types(self): + """Test restricting the allowed list of types.""" + allowed = [ast.Expression, ast.Constant, ast.BinOp, ast.Add] + result = safe_eval("1+2", allowed_types=allowed) + self.assertEqual(result, 3) + + with self.assertRaises(ValueError): + safe_eval("1*2", allowed_types=allowed) + + +if __name__ == "__main__": + unittest.main()