Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion monai/bundle/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
min_version,
optional_import,
pprint_edges,
safe_eval,
)

validate, _ = optional_import("jsonschema", name="validate")
Expand Down Expand Up @@ -161,7 +162,7 @@ def _get_fake_spatial_shape(shape: Sequence[str | int], p: int = 1, n: int = 1,
for c in _get_var_names(i):
if c not in ["p", "n"]:
raise ValueError(f"only support variables 'p' and 'n' so far, but got: {c}.")
ret.append(eval(i, {"p": p, "n": n}))
ret.append(safe_eval(i, {"p": p, "n": n}))
Comment thread
ericspod marked this conversation as resolved.
else:
raise ValueError(f"spatial shape items must be int or string, but got: {type(i)} {i}.")
return tuple(ret)
Expand Down
1 change: 1 addition & 0 deletions monai/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@
torch_profiler_time_cpu_gpu,
torch_profiler_time_end_to_end,
)
from .safeeval import SAFE_TYPES, safe_eval
from .state_cacher import StateCacher
from .tf32 import detect_default_tf32, has_ampere_or_later
from .type_conversion import (
Expand Down
2 changes: 1 addition & 1 deletion monai/utils/ordering.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def _order_template(self, template: np.ndarray) -> np.ndarray:
else:
rows, columns, depths = (template.shape[0], template.shape[1], template.shape[2])

sequence = eval(f"self.{self.ordering_type}_idx")(rows, columns, depths)
sequence = getattr(self, f"{self.ordering_type}_idx")(rows, columns, depths)

ordering = np.array([template[tuple(e)] for e in sequence])

Expand Down
73 changes: 73 additions & 0 deletions monai/utils/safeeval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import ast
from collections.abc import Mapping, Sequence
from typing import Any

__all__ = ["SAFE_TYPES", "safe_eval"]

# default set of safe AST node types
SAFE_TYPES: Sequence[type] = (
ast.Expression,
ast.Name,
ast.Load,
ast.Constant,
ast.BinOp,
ast.UnaryOp,
ast.Add,
ast.Sub,
ast.Mult,
ast.Div,
ast.FloorDiv,
ast.Pow,
ast.Mod,
ast.USub,
ast.UAdd,
)


def safe_eval(
expr: str,
globals_vars: Mapping[str, Any] | None = None,
locals_vars: Mapping[str, object] | None = None,
allowed_types: Sequence[type] = SAFE_TYPES,
) -> Any:
"""
Evaluate the Python expression `expr` using `eval`, but only if it is a safe expression in that its parsed AST
contains nodes whose types are given in `allowed_types`. This ensures unsafe node types are excluded, if these
are present in the AST a ValueError is raised. The default set of such types in `SAFE_TYPES` ensures only
expressions with constants and names can be evaluated, so excludes attribute access, indexing, and calls. Code
injection is infeasible through such expressions, so this is a safe and secure way of evaluating simple expressions.

Args:
expr: expression to evaluate, this will be stripped before parsing to avoid indentation complaints
globals_vars: global variable mapping, this will be treated as read-only for this function, unlike `eval`
locals_vars: local variable mapping
allowed_types: sequence of allowed AST types which can be found in `expr` when parsed

Raises:
ValueError: raised when any node in the AST parsed from `expr` has a type not in `allowed_types`

Returns:
The evaluated expression value, using `eval` with `globals_vars` and `locals_vars`
"""
parsed = ast.parse(expr.strip(), mode="eval")

# collect nodes in the AST which aren't permitted and unparse them for inclusion in the exception message
disallowed = [ast.unparse(n) for n in ast.walk(parsed) if not isinstance(n, tuple(allowed_types))]

if disallowed:
raise ValueError(f"Unsafe expression `{expr}` not evaluated, contains disallowed components: {disallowed}")

return eval(expr, dict(globals_vars) if globals_vars else None, locals_vars)
Comment thread
ericspod marked this conversation as resolved.
5 changes: 4 additions & 1 deletion tests/utils/test_alias.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@


class TestModuleAlias(unittest.TestCase):
"""check that 'import monai.xx.file_name' returns a module"""
"""
Check that 'import monai.xx.file_name' returns a module. Note that this test will fail if a module has the same name
as a member of that module (or any other) which is imported in a `__init__.py` file.
"""

def test_files(self):
src_dir = os.path.dirname(TESTS_PATH)
Expand Down
58 changes: 58 additions & 0 deletions tests/utils/test_safe_eval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import ast
import unittest

from parameterized import parameterized

from monai.utils import safe_eval

GOOD_EXPRS = [
("1+2", None, None, 3),
(" 1 + 2 ", None, None, 3),
("1+2+x", {"x": 4}, None, 7),
("1+2+x", None, {"x": 4}, 7),
("1*2+x", {"x": 4}, None, 6),
("(1+2)*3", None, None, 9),
("foo+bar", {"foo": 1030}, {"bar": 204}, 1234),
]

BAD_EXPRS = [("foo()",), ("foo.bar",), ("foo[123]",), ("(1,2)",), ("[3,4]",), ("int.__class__.__init__.__globals__",)]


class TestSafeEval(unittest.TestCase):
@parameterized.expand(GOOD_EXPRS)
def test_good_exprs(self, expr, globals_vars, locals_vars, expected):
"""Test valid expressions with globals/locals evaluate to correct values."""
result = safe_eval(expr, globals_vars, locals_vars)
self.assertEqual(result, expected)

@parameterized.expand(BAD_EXPRS)
def test_bad_exprs(self, expr):
"""Test bad expressions correctly raise ValueError."""
with self.assertRaises(ValueError):
safe_eval(expr)

def test_allowed_types(self):
"""Test restricting the allowed list of types."""
allowed = [ast.Expression, ast.Constant, ast.BinOp, ast.Add]
result = safe_eval("1+2", allowed_types=allowed)
self.assertEqual(result, 3)

with self.assertRaises(ValueError):
safe_eval("1*2", allowed_types=allowed)


if __name__ == "__main__":
unittest.main()
Loading