diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 62808a9498..99ad05c6a9 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -898,6 +898,67 @@ def _insert_complex_io_adapters( partitioned_module.recompile() +def _apply_dynamic_shape_bounds( + gm: torch.fx.GraphModule, + sample_arg_inputs: Sequence[Input], + sample_kwarg_inputs: dict[Any, Any], +) -> None: + """Propagate user Input min/max bounds into the FX shape_env. + + This lets explicit torch_tensorrt.Input bounds constrain exported programs + that otherwise carry broad Dim.DYNAMIC ranges. + """ + from torch.utils._sympy.value_ranges import ValueRanges + + placeholders = [n for n in gm.graph.nodes if n.op == "placeholder"] + + sample_by_name: dict[str, Input] = {} + for i, node in enumerate(placeholders): + if i < len(sample_arg_inputs): + inp = sample_arg_inputs[i] + if isinstance(inp, Input) and inp.shape_mode == Input._ShapeMode.DYNAMIC: + sample_by_name[node.target] = inp + + for name, inp in sample_kwarg_inputs.items(): + if isinstance(inp, Input) and inp.shape_mode == Input._ShapeMode.DYNAMIC: + sample_by_name[name] = inp + + if not sample_by_name: + return + + updated_syms: set = set() + for node in placeholders: + if node.target not in sample_by_name: + continue + + sample_input = sample_by_name[node.target] + fake_val = node.meta.get("val") + if not isinstance(fake_val, torch.Tensor): + continue + + min_shape = sample_input.shape["min_shape"] + max_shape = sample_input.shape["max_shape"] + + for d, dim in enumerate(fake_val.size()): + if not isinstance(dim, torch.SymInt) or d >= len(min_shape): + continue + + expr = dim.node.expr + if expr in updated_syms: + continue + + shape_env = dim.node.shape_env + if expr not in shape_env.var_to_range: + continue + + old_range = shape_env.var_to_range[expr] + lower = max(old_range.lower, min_shape[d]) + upper = min(old_range.upper, max_shape[d]) + shape_env.var_to_range[expr] = ValueRanges(lower=lower, upper=upper) + updated_syms.add(expr) + logger.debug("Updated shape_env range for %s: [%s, %s]", expr, lower, upper) + + @fn_supports_debugger # type: ignore[misc] def compile_module( gm: torch.fx.GraphModule, @@ -929,6 +990,8 @@ def compile_module( if sample_kwarg_inputs is None: sample_kwarg_inputs = {} + _apply_dynamic_shape_bounds(gm, sample_arg_inputs, sample_kwarg_inputs) + # Configure user compilation settings to converters. CONVERTERS.set_compilation_settings(settings) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 112b04c187..419ed46715 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -816,6 +816,7 @@ def aten_ops_rsqrt( ) +@dynamo_tensorrt_converter(operator.neg, supports_dynamic_shapes=True) @dynamo_tensorrt_converter(torch.ops.aten.neg.default, supports_dynamic_shapes=True) def aten_ops_neg( ctx: ConversionContext, @@ -2223,6 +2224,7 @@ def aten_ops_maximum( ) +@dynamo_tensorrt_converter(torch.sym_min, supports_dynamic_shapes=True) @dynamo_tensorrt_converter(torch.ops.aten.minimum.default, supports_dynamic_shapes=True) def aten_ops_minimum( ctx: ConversionContext, diff --git a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py index 7b770ab68b..7400b21318 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py @@ -12,6 +12,7 @@ from .complex_graph_rewrite import complex_graph_detection from .constant_folding import constant_fold +from .eliminate_sym_min_int64_max import eliminate_sym_min_int64_max from .force_causal_efficient_attention import force_causal_efficient_attention from .fuse_prims_broadcast import fuse_prims_broadcast from .pass_manager import DynamoPassManager @@ -23,6 +24,7 @@ from .replace_fused_rms_norm import replace_fused_rms_norm from .replace_max_pool_with_indices import replace_max_pool_with_indices from .rule_based_autocast import rule_based_autocast +from .normalize_negative_slice_stop import normalize_negative_slice_stop pre_lowering_pass_list = [ remove_detach, @@ -41,6 +43,8 @@ remove_num_users_is_0_nodes, complex_graph_detection, force_causal_efficient_attention, + eliminate_sym_min_int64_max, + normalize_negative_slice_stop, ] if not is_tegra_platform(): diff --git a/py/torch_tensorrt/dynamo/lowering/passes/eliminate_sym_min_int64_max.py b/py/torch_tensorrt/dynamo/lowering/passes/eliminate_sym_min_int64_max.py new file mode 100644 index 0000000000..b9a6fa97fc --- /dev/null +++ b/py/torch_tensorrt/dynamo/lowering/passes/eliminate_sym_min_int64_max.py @@ -0,0 +1,50 @@ +import sys + +import torch +from torch.fx import GraphModule, Node + +from .pass_utils import clean_up_graph_after_modifications + + +_INT64_MAX = 2**63 - 1 +_SYM_MIN = getattr(torch, "sym_min", None) + + +def _is_int64_max(x: object) -> bool: + return isinstance(x, int) and x in (sys.maxsize, _INT64_MAX) + + +def eliminate_sym_min_int64_max( + gm: GraphModule, settings: object = None +) -> GraphModule: + """Remove no-op sym_min nodes where one operand is INT64_MAX. + + torch.export may emit sym_min(sym, INT64_MAX) for an effectively unbounded + symbolic value. That expression is equivalent to sym, and leaving it in the + graph can produce runtime calls to torch.sym_min with Tensor inputs. + """ + if _SYM_MIN is None: + return gm + + modified = False + for node in list(gm.graph.nodes): + if ( + node.op != "call_function" + or node.target is not _SYM_MIN + or len(node.args) < 2 + ): + continue + + lhs, rhs = node.args[:2] + if _is_int64_max(rhs) and isinstance(lhs, Node): + passthrough = lhs + elif _is_int64_max(lhs) and isinstance(rhs, Node): + passthrough = rhs + else: + continue + + node.replace_all_uses_with(passthrough) + gm.graph.erase_node(node) + modified = True + + return clean_up_graph_after_modifications(gm) if modified else gm diff --git a/py/torch_tensorrt/dynamo/lowering/passes/normalize_negative_slice_stop.py b/py/torch_tensorrt/dynamo/lowering/passes/normalize_negative_slice_stop.py new file mode 100644 index 0000000000..f68584675a --- /dev/null +++ b/py/torch_tensorrt/dynamo/lowering/passes/normalize_negative_slice_stop.py @@ -0,0 +1,88 @@ +import operator +from typing import Optional + +import torch +from torch.fx import GraphModule, Node + +from .pass_utils import clean_up_graph_after_modifications + + +def _negative_symint_operand(x: object) -> Optional[object]: + # Return n for symbolic bounds represented as -n. The caller rewrites + # that bound to dim_size - n, matching Python's negative indexing rules. + if ( + isinstance(x, Node) + and x.op == "call_function" + and x.target in (operator.neg, torch.ops.aten.neg.default) + and len(x.args) == 1 + ): + return x.args[0] + return None + + +def _rank(x: Node) -> Optional[int]: + val = x.meta.get("val") + if isinstance(val, torch.Tensor): + return val.dim() + if hasattr(val, "shape"): + return len(val.shape) + return None + + +def normalize_negative_slice_stop( + gm: GraphModule, settings: object = None +) -> GraphModule: + """Normalize negative symbolic slice bounds to positive dim-relative bounds. + + Python slicing accepts negative bounds such as x[-n:] or x[:-n]. TensorRT + shape expressions need the equivalent positive bound, dim_size - n. + """ + modified = False + + for node in list(gm.graph.nodes): + if node.op != "call_function" or node.target != torch.ops.aten.slice.Tensor: + continue + + args = list(node.args) + if len(args) < 3: + continue + + input_node, dim = args[:2] + if not isinstance(input_node, Node) or not isinstance(dim, int): + continue + + rank = _rank(input_node) + if rank is not None: + # Match PyTorch dim normalization for negative dims. + dim = dim % rank + + rewritten = False + # aten.slice.Tensor can appear as (input, dim, start) or + # (input, dim, start, stop, ...). Normalize either symbolic bound. + for bound_index in (2, 3): + if len(args) <= bound_index: + continue + + bound = args[bound_index] + positive_offset = _negative_symint_operand(bound) + if positive_offset is None: + continue + + with gm.graph.inserting_before(node): + dim_size = gm.graph.call_function( + torch.ops.aten.sym_size.int, args=(input_node, dim) + ) + # A negative symbolic bound -n becomes dim_size - n. + normalized_bound = gm.graph.call_function( + operator.sub, args=(dim_size, positive_offset) + ) + + args[bound_index] = normalized_bound + rewritten = True + + if rewritten: + args[1] = dim + node.args = tuple(args) + modified = True + + return clean_up_graph_after_modifications(gm) if modified else gm diff --git a/tests/py/dynamo/lowering/test_aten_lowering_passes.py b/tests/py/dynamo/lowering/test_aten_lowering_passes.py index b54d05a6cb..51b31c8da1 100644 --- a/tests/py/dynamo/lowering/test_aten_lowering_passes.py +++ b/tests/py/dynamo/lowering/test_aten_lowering_passes.py @@ -1,3 +1,6 @@ +import operator +import sys + import torch import torch_tensorrt from torch.testing._internal.common_utils import TestCase, run_tests @@ -278,6 +281,102 @@ def forward(self, x: torch.Tensor): self.assertTrue(True) +class TestNormalizeNegativeSliceStop(TestCase): + def test_normalizes_negative_symbolic_start_bound(self): + from torch_tensorrt.dynamo.lowering.passes.normalize_negative_slice_stop import ( + normalize_negative_slice_stop, + ) + + graph = torch.fx.Graph() + x = graph.placeholder("x") + x.meta["val"] = torch.empty(2, 5, 3) + n = graph.placeholder("n") + neg = graph.call_function(operator.neg, args=(n,)) + sliced = graph.call_function(torch.ops.aten.slice.Tensor, args=(x, -2, neg)) + graph.output(sliced) + + gm = torch.fx.GraphModule({}, graph) + gm = normalize_negative_slice_stop(gm) + + slice_node = next( + node + for node in gm.graph.nodes + if node.op == "call_function" and node.target == torch.ops.aten.slice.Tensor + ) + self.assertEqual(slice_node.args[1], 1) + + normalized_start = slice_node.args[2] + self.assertEqual(normalized_start.op, "call_function") + self.assertEqual(normalized_start.target, operator.sub) + + dim_size, offset = normalized_start.args + self.assertEqual(dim_size.target, torch.ops.aten.sym_size.int) + self.assertEqual(dim_size.args[0], x) + self.assertEqual(dim_size.args[1], 1) + self.assertEqual(offset, n) + + def test_normalizes_negative_symbolic_stop_bound(self): + from torch_tensorrt.dynamo.lowering.passes.normalize_negative_slice_stop import ( + normalize_negative_slice_stop, + ) + + graph = torch.fx.Graph() + x = graph.placeholder("x") + x.meta["val"] = torch.empty(2, 5, 3) + n = graph.placeholder("n") + neg = graph.call_function(torch.ops.aten.neg.default, args=(n,)) + sliced = graph.call_function(torch.ops.aten.slice.Tensor, args=(x, 1, 0, neg)) + graph.output(sliced) + + gm = torch.fx.GraphModule({}, graph) + gm = normalize_negative_slice_stop(gm) + + slice_node = next( + node + for node in gm.graph.nodes + if node.op == "call_function" and node.target == torch.ops.aten.slice.Tensor + ) + + normalized_stop = slice_node.args[3] + self.assertEqual(normalized_stop.op, "call_function") + self.assertEqual(normalized_stop.target, operator.sub) + + dim_size, offset = normalized_stop.args + self.assertEqual(dim_size.target, torch.ops.aten.sym_size.int) + self.assertEqual(dim_size.args[0], x) + self.assertEqual(dim_size.args[1], 1) + self.assertEqual(offset, n) + + +class TestEliminateSymMinInt64Max(TestCase): + def test_eliminates_noop_sym_min_int64_max(self): + if not hasattr(torch, "sym_min"): + self.skipTest("torch.sym_min is not available") + + from torch_tensorrt.dynamo.lowering.passes.eliminate_sym_min_int64_max import ( + eliminate_sym_min_int64_max, + ) + + graph = torch.fx.Graph() + x = graph.placeholder("x") + rhs_int64_max = graph.call_function(torch.sym_min, args=(x, sys.maxsize)) + lhs_int64_max = graph.call_function(torch.sym_min, args=(2**63 - 1, x)) + graph.output((rhs_int64_max, lhs_int64_max)) + + gm = torch.fx.GraphModule({}, graph) + gm = eliminate_sym_min_int64_max(gm) + + self.assertFalse( + any( + node.op == "call_function" and node.target is torch.sym_min + for node in gm.graph.nodes + ) + ) + + output_node = next(node for node in gm.graph.nodes if node.op == "output") + self.assertEqual(output_node.args[0], (x, x)) + + class TestRewriteEfficientAttention(TestCase): def test_force_causal_efficient_attention(self): class RewriteEfficientAttention(torch.nn.Module):