Merged
6 changes: 3 additions & 3 deletions docs/arch/fusion.rst
@@ -195,9 +195,9 @@ via ``cls.func_name``:
 @R.function
 def main(x: R.Tensor((10, 20), "float32")):
     with R.dataflow():
-        lv0 = R.call_tir(add, (x, const_1), out_sinfo=R.Tensor((10, 20), "float32"))
-        lv1 = R.call_tir(exp, (lv0,), out_sinfo=R.Tensor((10, 20), "float32"))
-        gv = R.call_tir(squeeze, (lv1,), out_sinfo=R.Tensor((10, 20), "float32"))
+        lv0 = R.call_tir(add, (x, const_1), out_ty=R.Tensor((10, 20), "float32"))
+        lv1 = R.call_tir(exp, (lv0,), out_ty=R.Tensor((10, 20), "float32"))
+        gv = R.call_tir(squeeze, (lv1,), out_ty=R.Tensor((10, 20), "float32"))
         R.output(gv)
     return gv

24 changes: 12 additions & 12 deletions docs/arch/tvmscript.rst
@@ -116,7 +116,7 @@ These can be composed:
          y: R.Tensor((128,), "float32")) -> R.Tensor((128,), "float32"):
     with R.dataflow():
         out = R.call_tir(cls.add_kernel, (x, y),
-                         out_sinfo=R.Tensor((128,), "float32"))
+                         out_ty=R.Tensor((128,), "float32"))
         R.output(out)
     return out

@@ -268,18 +268,18 @@ The Relax builder (``ir_builder/relax/ir.py``) provides:
 **Emit**:

 - ``R.emit(value)`` → emit a binding, returns a ``Var``
-- ``R.emit_match_cast(value, struct_info)`` → emit with type assertion
+- ``R.emit_match_cast(value, ty)`` → emit with type assertion

 **Type annotations**:

-- ``R.Tensor(shape, dtype)`` — tensor struct info
-- ``R.Tuple(*fields)`` — tuple struct info
-- ``R.Shape(values)`` — shape struct info
-- ``R.Object()`` — opaque object struct info
+- ``R.Tensor(shape, dtype)`` — tensor type
+- ``R.Tuple(*fields)`` — tuple type
+- ``R.Shape(values)`` — shape type
+- ``R.Object()`` — opaque object type

 **Calling conventions**:

-- ``R.call_tir(func, args, out_sinfo)`` — call a TIR function
+- ``R.call_tir(func, args, out_ty)`` — call a TIR function
 - ``R.call_packed(name, *args)`` — call a PackedFunc
 - ``R.call_dps_packed(func, *args)`` — call using destination-passing style

@@ -326,7 +326,7 @@ It maintains:
 Each IR dialect registers its own converters:

 - ``src/script/printer/tirx/`` — converts PrimFunc, Buffer, SBlock, loops, expressions.
-- ``src/script/printer/relax/`` — converts relax.Function, bindings, struct info, operators.
+- ``src/script/printer/relax/`` — converts relax.Function, bindings, types, operators.
 - ``src/script/printer/ir/`` — converts IRModule, shared types.

 The final step calls ``DocToPythonScript()`` (``src/script/printer/doc_printer/python_doc_printer.cc``)
@@ -491,8 +491,8 @@ Function definition
         # function body
         return result

-- ``R.Tensor(shape, dtype)`` — tensor type annotation (struct info).
-- ``R.Tuple(...)``, ``R.Shape(...)``, ``R.Object()`` — other struct info types.
+- ``R.Tensor(shape, dtype)`` — tensor type annotation.
+- ``R.Tuple(...)``, ``R.Shape(...)``, ``R.Object()`` — other Relax type annotations.
 - ``R.function(private=True)`` — marks the function as module-private.
 - ``R.function(pure=False)`` — marks the function as having side effects.

@@ -514,10 +514,10 @@ Calling TIR functions

 .. code-block:: python

-    out = R.call_tir(cls.my_kernel, (x, y), out_sinfo=R.Tensor((128,), "float32"))
+    out = R.call_tir(cls.my_kernel, (x, y), out_ty=R.Tensor((128,), "float32"))

 - ``cls.my_kernel`` — references a TIR ``PrimFunc`` in the same module.
-- ``out_sinfo`` — the struct info (shape and dtype) of the output tensor.
+- ``out_ty`` — the type (shape and dtype) of the output tensor.

 Control flow
 ~~~~~~~~~~~~
14 changes: 7 additions & 7 deletions docs/deep_dive/relax/dpl.rst
@@ -123,7 +123,7 @@ Any pattern can be further narrowed by attaching constraints:
 - ``.has_dtype(dtype)`` -- the matched expression must have the given data type.
 - ``.has_shape(shape)`` -- the matched expression must have the given shape.
 - ``.has_attr(attrs)`` -- the matched call must carry the given attributes.
-- ``.has_struct_info(struct_info)`` -- the matched expression must have the given struct info.
+- ``.has_ty(ty)`` -- the matched expression must have the given type.

 .. code:: python

@@ -286,7 +286,7 @@ The callback receives *variables* rather than expressions:
         ...

 - ``matchings[pat]`` returns the **bound variable** (``Var``) whose right-hand
-  side matched ``pat``. The ``Var`` itself carries ``struct_info`` and can be
+  side matched ``pat``. The ``Var`` itself carries ``ty`` and can be
   used directly in new expressions.
 - ``bindings`` maps each ``Var`` to its bound ``Expr`` (the right-hand side),
   useful when you need to inspect the original expression.
@@ -311,7 +311,7 @@ The callback receives *variables* rather than expressions:
 W1 = matchings[w1]
 W2 = matchings[w2]
 W3 = matchings[w3]
-width = W1.struct_info.shape[1]
+width = W1.ty.shape[1]

 concat_w = R.concat([W1, W2, W3], axis=1)
 merged = R.matmul(inp, concat_w)
@@ -361,7 +361,7 @@ object that can be applied directly.
     "my_fast_add",
     A,
     B,
-    sinfo_args=R.Tensor([16], "float32"),
+    ty_args=R.Tensor([16], "float32"),
 )
 return C

@@ -440,7 +440,7 @@ structurally (dtype restrictions, shape compatibility, attribute values, etc.):
 def my_check_fn(ctx: PatternCheckContext) -> bool:
     matmul_expr = ctx.annotated_expr["matmul"]
     # Only accept float16 output
-    if matmul_expr.struct_info.dtype != "float16":
+    if matmul_expr.ty.dtype != "float16":
         return False
     return True

@@ -464,7 +464,7 @@ sub-function.

 def check(ctx):
     transpose_call = ctx.annotated_expr["wT"]
-    ndim = transpose_call.args[0].struct_info.ndim
+    ndim = transpose_call.args[0].ty.ndim
     if ndim == -1:
         return False
     if ndim == 2 and transpose_call.attrs.axes is None:
@@ -513,7 +513,7 @@ Quick Reference
       - Match ``R.call_packed``
     * - ``make_fused_bias_activation_pattern(...)``
       - Build ``op + bias + activation`` chain
-    * - ``.has_dtype()`` / ``.has_shape()`` / ``.has_attr()`` / ``.has_struct_info()``
+    * - ``.has_dtype()`` / ``.has_shape()`` / ``.has_attr()`` / ``.has_ty()``
       - Attach constraints
     * - ``|`` / ``&`` / ``~``
       - Or / And / Not combinators
22 changes: 11 additions & 11 deletions docs/deep_dive/relax/learning.rst
@@ -170,9 +170,9 @@ for the end-to-end model execution. The code block below shows a TVMScript imple
 cls = Module
 n = T.int64()
 with R.dataflow():
-    lv = R.call_tir(cls.linear, (x, w0, b0), out_sinfo=R.Tensor((n, 256), dtype="float32"))
-    lv1 = R.call_tir(cls.relu, (lv,), out_sinfo=R.Tensor((n, 256), dtype="float32"))
-    lv2 = R.call_tir(cls.linear, (lv1, w1, b1), out_sinfo=R.Tensor((n, 10), dtype="float32"))
+    lv = R.call_tir(cls.linear, (x, w0, b0), out_ty=R.Tensor((n, 256), dtype="float32"))
+    lv1 = R.call_tir(cls.relu, (lv,), out_ty=R.Tensor((n, 256), dtype="float32"))
+    lv2 = R.call_tir(cls.linear, (lv1, w1, b1), out_ty=R.Tensor((n, 10), dtype="float32"))
     R.output(lv2)
 return lv2

@@ -194,10 +194,10 @@ Key Elements of Relax
 This section will introduce the key elements of Relax abstraction and how it enables optimization
 in ML compilers.

-Structure Info
-~~~~~~~~~~~~~~
-Structure info is a new concept in Relax that represents the type of relax expressions. It can
-be ``TensorStructInfo``, ``TupleStructInfo``, etc. In the above example, we use ``TensorStructInfo``
+Type
+~~~~
+Type is the Relax representation of expression type information. It can
+be ``TensorType``, ``TupleType``, etc. In the above example, we use ``TensorType``
 (short in ``R.Tensor`` in TVMScript) to represent the shape and dtype of the tensor of the inputs,
 outputs, and intermediate results.

@@ -210,7 +210,7 @@ Taking one line from the above code as an example:

 .. code:: python

-    lv = R.call_tir(cls.linear, (x, w0, b0), out_sinfo=R.Tensor((n, 256), dtype="float32"))
+    lv = R.call_tir(cls.linear, (x, w0, b0), out_ty=R.Tensor((n, 256), dtype="float32"))

 To explain what does ``R.call_tir`` work, let us review an equivalent low-level numpy
 implementation of the operation, as follows:
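
Note on the hunk above: the numpy listing it refers to is outside the diff context. As a rough illustration only, the destination-passing convention behind ``R.call_tir`` could be sketched as below. This is not the tutorial's code: ``linear_dps`` and ``call_tir_like`` are hypothetical names, nested Python lists stand in for numpy arrays, and the weight layout ``(out_features, in_features)`` is an assumption.

```python
# Hypothetical sketch of destination-passing style (DPS); not the TVM implementation.

def linear_dps(x, w, b, out):
    """Write x @ w^T + b into the caller-provided buffer `out` (returns nothing)."""
    for i, row in enumerate(x):
        for j, w_row in enumerate(w):
            out[i][j] = sum(a * c for a, c in zip(row, w_row)) + b[j]


def call_tir_like(prim_func, args, out_shape):
    """Allocate the output from the declared shape, then call `prim_func` in DPS form."""
    rows, cols = out_shape
    out = [[0.0] * cols for _ in range(rows)]  # allocation is done by the caller
    prim_func(*args, out)                      # the kernel only fills the buffer
    return out


x = [[1.0, 2.0]]                              # shape (1, 2)
w = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]      # shape (3, 2), assumed (out, in) layout
b = [0.5, 0.5, 0.5]
lv = call_tir_like(linear_dps, (x, w, b), out_shape=(1, 3))
```

The output annotation (``out_ty`` after this change) plays the role of ``out_shape`` here: it tells the caller what to allocate before the kernel runs.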
@@ -238,9 +238,9 @@ Another important element in a relax function is the R.dataflow() scope annotati
 .. code:: python

     with R.dataflow():
-        lv = R.call_tir(cls.linear, (x, w0, b0), out_sinfo=R.Tensor((n, 256), dtype="float32"))
-        lv1 = R.call_tir(cls.relu, (lv,), out_sinfo=R.Tensor((n, 256), dtype="float32"))
-        lv2 = R.call_tir(cls.linear, (lv1, w1, b1), out_sinfo=R.Tensor((n, 10), dtype="float32"))
+        lv = R.call_tir(cls.linear, (x, w0, b0), out_ty=R.Tensor((n, 256), dtype="float32"))
+        lv1 = R.call_tir(cls.relu, (lv,), out_ty=R.Tensor((n, 256), dtype="float32"))
+        lv2 = R.call_tir(cls.linear, (lv1, w1, b1), out_ty=R.Tensor((n, 10), dtype="float32"))
         R.output(lv2)

 Before we talk about the dataflow block, let us first introduce the concept of **pure** and
4 changes: 2 additions & 2 deletions docs/deep_dive/relax/tutorials/relax_creation.py
@@ -254,7 +254,7 @@ def forward(self, x):
     relax.call_dps_packed(
         "env.linear",
         [x, fc1_weight, fc1_bias],
-        out_sinfo=relax.TensorStructInfo((n, 128), "float32"),
+        out_ty=relax.TensorType((n, 128), "float32"),
     )
 )
 lv1 = bb.emit_te(topi.nn.relu, lv0)
@@ -263,7 +263,7 @@ def forward(self, x):
     relax.call_tir(
         tir_gv,
         [lv1, fc2_weight, fc2_bias],
-        out_sinfo=relax.TensorStructInfo((n, 10), "float32"),
+        out_ty=relax.TensorType((n, 10), "float32"),
     )
 )
 bb.emit_output(gv)
18 changes: 8 additions & 10 deletions docs/how_to/tutorials/mix_python_and_tvm_with_pymodule.py
@@ -99,9 +99,7 @@ def forward(self, x, y):
         """Takes PyTorch tensors, calls TIR, returns PyTorch tensors."""
         x_tvm = self._convert_pytorch_to_tvm(x)
         y_tvm = self._convert_pytorch_to_tvm(y)
-        result = self.call_tir(
-            self.add_tir, [x_tvm, y_tvm], out_sinfo=R.Tensor((4,), "float32")
-        )
+        result = self.call_tir(self.add_tir, [x_tvm, y_tvm], out_ty=R.Tensor((4,), "float32"))
         return self._convert_tvm_to_pytorch(result)

 # TIR functions are JIT-compiled at instantiation
@@ -157,7 +155,7 @@ def forward(self, x, weights):
 out = self.call_tir(
     self.matmul_tir,
     [x_tvm, w_tvm],
-    out_sinfo=R.Tensor((x.shape[0], 3), "float32"),
+    out_ty=R.Tensor((x.shape[0], 3), "float32"),
 )
 logits = self._convert_tvm_to_pytorch(out)

@@ -231,15 +229,15 @@ def forward(self, x, weights, bias):
 h = self.call_tir(
     self.matmul_tir,
     [x_tvm, w_tvm],
-    out_sinfo=R.Tensor((2, 3), "float32"),
+    out_ty=R.Tensor((2, 3), "float32"),
 )
 h_pt = self._convert_tvm_to_pytorch(h)

 # 2. Packed function for bias add (simulating an external library)
 h_biased = self.call_dps_packed(
     "my_bias_add",
     [h_pt, bias],
-    out_sinfo=R.Tensor((2, 3), "float32"),
+    out_ty=R.Tensor((2, 3), "float32"),
 )

 # 3. Python/PyTorch activation
@@ -294,7 +292,7 @@ def main(
 h_bias = R.call_tir(
     cls.bias_add_tir,
     (h, b),
-    out_sinfo=R.Tensor((2, 4), "float32"),
+    out_ty=R.Tensor((2, 4), "float32"),
 )
 return R.nn.relu(h_bias)

@@ -364,8 +362,8 @@ def main(
             x: R.Tensor((4, 8), "float32"),
         ) -> R.Tensor((4, 8), "float32"):
             # The VM calls back into Python for these two ops
-            h = R.call_py_func("layer_norm", (x,), out_sinfo=R.Tensor((4, 8), "float32"))
-            out = R.call_py_func("silu", (h,), out_sinfo=R.Tensor((4, 8), "float32"))
+            h = R.call_py_func("layer_norm", (x,), out_ty=R.Tensor((4, 8), "float32"))
+            out = R.call_py_func("silu", (h,), out_ty=R.Tensor((4, 8), "float32"))
             return out

 mod = HybridVMModule(device=tvm.cpu(0))
@@ -438,7 +436,7 @@ def add_relax(
 # Python → TIR with symbolic output shape
 n = T.int64()
 x7 = torch.randn(7)
-scaled = mod.call_tir("scale_tir", [x7], relax.TensorStructInfo((n,), "float32"))
+scaled = mod.call_tir("scale_tir", [x7], relax.TensorType((n,), "float32"))
 print("scale_tir(len=7):", scaled)
 assert torch.allclose(torch.tensor(scaled.numpy()), x7 * 2.0, atol=1e-5)

6 changes: 3 additions & 3 deletions docs/reference/security.rst
@@ -62,7 +62,7 @@ Subroutine Cache Hash Collision

 ``SubroutineMixin._get_subroutine()`` in ``python/tvm/relax/frontend/nn/subroutine.py``
 used ``tvm_ffi.structural_hash`` as the sole cache lookup key without a subsequent
-``structural_equal`` verification. If two different ``arg_sinfo`` values produced the
+``structural_equal`` verification. If two different ``arg_ty`` values produced the
 same 64-bit hash, the cache would return a previously compiled function with
 mismatched parameter shapes, leading to silently incorrect compiled output.

@@ -73,11 +73,11 @@ The issue is primarily a correctness defect rather than a practically exploitabl
 security vulnerability.

 **Root Cause**: The subroutine cache (``cls._gvar``) was keyed by
-``(structural_hash(arg_sinfo, map_free_vars=True), is_dataflow)``.
+``(structural_hash(arg_ty, map_free_vars=True), is_dataflow)``.
 A hash match was treated as proof of structural equality, skipping the necessary
 ``structural_equal`` check.

-**Fix**: The cache now stores a list of ``(arg_sinfo, result)`` pairs per hash bucket.
+**Fix**: The cache now stores a list of ``(arg_ty, result)`` pairs per hash bucket.
 On lookup, each candidate is verified with ``structural_equal`` before returning.
 This follows the standard hash-table pattern: hash for bucket selection, equality
 for final verification.
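
Note on the fix described in this hunk: the hash-for-bucket, equality-for-verification pattern can be sketched in a few lines of plain Python. This is a minimal sketch under stated assumptions, not the code in ``subroutine.py``: ``BucketedCache``, ``hash_fn``, ``equal_fn``, and ``make_value`` are hypothetical names, and a deliberately weak hash stands in for ``structural_hash`` so that a collision actually occurs.

```python
# Hypothetical sketch of a hash-bucketed cache with equality verification.

class BucketedCache:
    """Select a bucket by hash, then confirm each candidate with an equality check."""

    def __init__(self, hash_fn, equal_fn):
        self._hash_fn = hash_fn
        self._equal_fn = equal_fn
        self._buckets = {}  # hash value -> list of (key, value) pairs

    def get_or_create(self, key, make_value):
        bucket = self._buckets.setdefault(self._hash_fn(key), [])
        for cached_key, cached_value in bucket:
            # A matching hash is only a hint; equality decides the hit.
            if self._equal_fn(cached_key, key):
                return cached_value
        value = make_value(key)
        bucket.append((key, value))
        return value


# A weak hash (tuple length) forces both keys into the same bucket.
cache = BucketedCache(hash_fn=len, equal_fn=lambda a, b: a == b)
first = cache.get_or_create((10, 20), lambda k: f"func_for_{k}")
second = cache.get_or_create((30, 40), lambda k: f"func_for_{k}")
again = cache.get_or_create((10, 20), lambda k: "unused")
```

With a hash-only lookup, ``second`` would have received the function built for ``(10, 20)``. With the per-bucket equality check, the two keys share a bucket but get distinct entries, and the repeated lookup still hits the cache.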