diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index eb924be63d94..5be6e7c08f13 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -30,7 +30,7 @@ #include #include #include -#include +#include #include #include @@ -43,6 +43,48 @@ namespace tvm { // Forward-declare VirtualDevice to avoid circular imports. class VirtualDevice; +/*! + * \brief Type is the base type of all types. + * + * TVM's type system contains following subclasses: + * + * - PrimType: type of primitive type values used in the low-level IR. + * - FuncType: type of a function. + * - TensorType: type of certain Tensor values in the expression. + * + * There are also advanced types to support generic(polymorphic types). + * \sa Type + */ +class TypeNode : public ffi::Object { + public: + /*! + * \brief Span that points to the original source code. + * Reserved debug information. + */ + mutable Span span; + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + // span do not participate in structural equal and hash. + refl::ObjectDef().def_ro("span", &TypeNode::span, refl::DefaultValue(Span()), + refl::AttachFieldFlag::SEqHashIgnore()); + } + + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; + + static constexpr const uint32_t _type_child_slots = 14; + TVM_FFI_DECLARE_OBJECT_INFO("ir.Type", TypeNode, ffi::Object); +}; + +/*! + * \brief Managed reference to TypeNode. + * \sa TypeNode + */ +class Type : public ffi::ObjectRef { + public: + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Type, ffi::ObjectRef, TypeNode); +}; + /*! * \brief Base type of all the expressions. * \sa Expr @@ -55,11 +97,23 @@ class BaseExprNode : public ffi::Object { */ mutable Span span; + /*! + * \brief The deduced or annotated type of the expression. + * + * This field is intentionally nullable because type information may + * be populated by later analysis passes instead of expression + * constructors. + */ + mutable Type ty; + static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - // span do not participate in structural equal and hash. - refl::ObjectDef().def_ro("span", &BaseExprNode::span, refl::DefaultValue(Span()), - refl::AttachFieldFlag::SEqHashIgnore()); + // span and ty do not participate in structural equal and hash. + refl::ObjectDef() + .def_ro("span", &BaseExprNode::span, refl::DefaultValue(Span()), + refl::AttachFieldFlag::SEqHashIgnore()) + .def_ro("ty", &BaseExprNode::ty, refl::DefaultValue(Type()), + refl::AttachFieldFlag::SEqHashIgnore()); } static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; @@ -414,17 +468,9 @@ TVM_DLL PrimExpr operator~(PrimExpr a); */ class RelaxExprNode : public BaseExprNode { public: - /*! - * \brief Stores the result of structure information of the - * expression that encapsulate both static shape and - * runtime information such as shape. - */ - mutable ffi::Optional struct_info_ = ffi::Optional(); - static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef().def_ro("struct_info_", &RelaxExprNode::struct_info_, - refl::AttachFieldFlag::SEqHashIgnore()); + refl::ObjectDef(); } static constexpr const uint32_t _type_child_slots = 22; diff --git a/include/tvm/ir/type.h b/include/tvm/ir/type.h index db6151fda702..044430deecca 100644 --- a/include/tvm/ir/type.h +++ b/include/tvm/ir/type.h @@ -52,6 +52,7 @@ #include #include #include +#include #include #include @@ -59,48 +60,6 @@ namespace tvm { -/*! - * \brief Type is the base type of all types. - * - * TVM's type system contains following subclasses: - * - * - PrimType: type of primitive type values used in the low-level IR. - * - FuncType: type of a function. - * - TensorType: type of certain Tensor values in the expression. - * - * There are also advanced types to support generic(polymorphic types). - * \sa Type - */ -class TypeNode : public ffi::Object { - public: - /*! - * \brief Span that points to the original source code. - * Reserved debug information. - */ - mutable Span span; - - static void RegisterReflection() { - namespace refl = tvm::ffi::reflection; - // span do not participate in structural equal and hash. - refl::ObjectDef().def_ro("span", &TypeNode::span, refl::DefaultValue(Span()), - refl::AttachFieldFlag::SEqHashIgnore()); - } - - static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - - static constexpr const uint32_t _type_child_slots = 14; - TVM_FFI_DECLARE_OBJECT_INFO("ir.Type", TypeNode, ffi::Object); -}; - -/*! - * \brief Managed reference to TypeNode. - * \sa TypeNode - */ -class Type : public ffi::ObjectRef { - public: - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Type, ffi::ObjectRef, TypeNode); -}; - /*! * \brief Primitive data types used in the low-level IR. * diff --git a/include/tvm/relax/analysis.h b/include/tvm/relax/analysis.h index a34a583b2e4c..82b15e2d279a 100644 --- a/include/tvm/relax/analysis.h +++ b/include/tvm/relax/analysis.h @@ -28,7 +28,7 @@ #include #include #include -#include +#include #include #include @@ -70,54 +70,53 @@ TVM_DLL bool CanProveShapeEqual(const ffi::Array& lhs, const ffi::Arra TVM_DLL bool CanProveShapeEqual(const Expr& lhs, const Expr& rhs, const arith::Analyzer& ana); //----------------------------------- -// Foundational StructInfo analysis +// Foundational Type analysis //----------------------------------- /*! - * \brief Get the corresponding static type from a given struct info. - * \param info The struct info. + * \brief Get the corresponding static type from a given type. + * \param info The type. * \return the corresponding static type. */ -TVM_DLL Type GetStaticType(const StructInfo& info); +TVM_DLL Type GetStaticType(const Type& info); /*! - * \brief Get the corresponding struct info from static type. + * \brief Get the corresponding type from static type. * \param type The input type - * \return the corresponding struct info. + * \return the corresponding type. */ -TVM_DLL StructInfo StructInfoFromType(const Type& type); +TVM_DLL Type TypeFromStaticType(const Type& type); /*! - * \return Derive the call's ret value struct info from inputs. - * \param finfo The function struct info. + * \return Derive the call's ret value type from inputs. + * \param finfo The function type. * \param call The call expression to be derived. * \param ctx The builder context. - * \return The derived struct info of the call. + * \return The derived type of the call. * \note call->op field is ignored during derivation and we only rely on information - * presented by func_sinfo. + * presented by func_ty. */ -TVM_DLL StructInfo DeriveCallRetStructInfo(const FuncStructInfo& finfo, const Call& call, - const BlockBuilder& ctx); +TVM_DLL Type DeriveCallRetType(const FuncType& finfo, const Call& call, const BlockBuilder& ctx); /*! - * \brief Derive the call's ret value struct info using a caller-provided analyzer. - * \param finfo The function struct info. + * \brief Derive the call's ret value type using a caller-provided analyzer. + * \param finfo The function type. * \param call The call expression to be derived. * \param ctx The builder context. * \param ana Context analyzer to prove symbolic expression equality. - * \return The derived struct info of the call. + * \return The derived type of the call. */ -TVM_DLL StructInfo DeriveCallRetStructInfo(const FuncStructInfo& finfo, const Call& call, - const BlockBuilder& ctx, const arith::Analyzer& ana); +TVM_DLL Type DeriveCallRetType(const FuncType& finfo, const Call& call, const BlockBuilder& ctx, + const arith::Analyzer& ana); /*! * \brief Erase the info to a corresponding more coarse grained - * struct info that is still well-defined(with all the vars in scope). + * type that is still well-defined(with all the vars in scope). * - * When we are returning a StructInfo to another scope, - * it is important to remember that StructInfo may carry + * When we are returning a Type to another scope, + * it is important to remember that Type may carry * dependencies on var that is not defined the other scope. * * In such cases, it is important to call EraseToWellDefined to get - * another StructInfo that **only** contains the vars that are defined + * another Type that **only** contains the vars that are defined * in the target scope. * * For example, consider the following function @@ -150,10 +149,10 @@ TVM_DLL StructInfo DeriveCallRetStructInfo(const FuncStructInfo& finfo, const Ca * will give us R.Tensor[(3, m)], where n get replaced by 2. * * Use this function in the following scenarios: - * - Decide the struct_info of expr with sub-scopes, such as If, SeqExpr - * - Decide the deduced return struct_info of a function that can be fully decided by params. + * - Decide the ty of expr with sub-scopes, such as If, SeqExpr + * - Decide the deduced return ty of a function that can be fully decided by params. * - * \param info The struct info. + * \param info The type. * \param f_shape_var_map callback function to specify * whether a symbolic shape var is defined and the value it maps to, * return nullopt if var is undefined. @@ -161,15 +160,15 @@ TVM_DLL StructInfo DeriveCallRetStructInfo(const FuncStructInfo& finfo, const Ca * whether a var is defined in the target scope and the value it maps to, * return nullopt if var is undefined. * - * \return the corresponding erased struct info. + * \return the corresponding erased type. */ -TVM_DLL StructInfo EraseToWellDefined( - const StructInfo& info, +TVM_DLL Type EraseToWellDefined( + const Type& info, std::function(const tirx::Var& var)> f_shape_var_map = nullptr, std::function(const Var& var)> f_var_map = nullptr); /*! * \brief EraseToWellDefined overload using a caller-provided analyzer. - * \param info The struct info. + * \param info The type. * \param f_shape_var_map callback function to specify * whether a symbolic shape var is defined and the value it maps to, * return nullopt if var is undefined. @@ -177,16 +176,15 @@ TVM_DLL StructInfo EraseToWellDefined( * whether a var is defined in the target scope and the value it maps to, * return nullopt if var is undefined. * \param ana Context analyzer to prove symbolic expression equality. - * \return the corresponding erased struct info. + * \return the corresponding erased type. */ -TVM_DLL StructInfo EraseToWellDefined( - const StructInfo& info, - std::function(const tirx::Var& var)> f_shape_var_map, +TVM_DLL Type EraseToWellDefined( + const Type& info, std::function(const tirx::Var& var)> f_shape_var_map, std::function(const Var& var)> f_var_map, const arith::Analyzer& ana); /*! * \brief EraseToWellDefined variant with map. - * \param info The struct info. + * \param info The type. * \param shape_var_map map to specify * whether a symbolic shape var is defined and the value it maps to, * return nullopt if var is undefined. @@ -194,14 +192,13 @@ TVM_DLL StructInfo EraseToWellDefined( * whether a var is defined in the target scope and the value it maps to, * return nullopt if var is undefined. * - * \return the corresponding erased struct info. + * \return the corresponding erased type. */ -TVM_DLL StructInfo EraseToWellDefined(const StructInfo& info, - ffi::Map shape_var_map, - ffi::Map var_map); +TVM_DLL Type EraseToWellDefined(const Type& info, ffi::Map shape_var_map, + ffi::Map var_map); /*! * \brief EraseToWellDefined map overload using a caller-provided analyzer. - * \param info The struct info. + * \param info The type. * \param shape_var_map map to specify * whether a symbolic shape var is defined and the value it maps to, * return nullopt if var is undefined. @@ -209,11 +206,10 @@ TVM_DLL StructInfo EraseToWellDefined(const StructInfo& info, * whether a var is defined in the target scope and the value it maps to, * return nullopt if var is undefined. * \param ana Context analyzer to prove symbolic expression equality. - * \return the corresponding erased struct info. + * \return the corresponding erased type. */ -TVM_DLL StructInfo EraseToWellDefined(const StructInfo& info, - ffi::Map shape_var_map, - ffi::Map var_map, const arith::Analyzer& ana); +TVM_DLL Type EraseToWellDefined(const Type& info, ffi::Map shape_var_map, + ffi::Map var_map, const arith::Analyzer& ana); /*! * \brief Fine grained result of base check. @@ -221,10 +217,10 @@ TVM_DLL StructInfo EraseToWellDefined(const StructInfo& info, * This analysis comes with different levels of checking failures * that can help to customize the compilation decisions. * - * For a given pair of lhs_struct_info, rhs_struct_info. We adopt + * For a given pair of lhs_ty, rhs_ty. We adopt * the following terminology: - * - LSet = {value | value matches lhs_struct_info} - * - RSet = {value | value matches rhs_struct_info} + * - LSet = {value | value matches lhs_ty} + * - RSet = {value | value matches rhs_ty} * * See the definition of each level below. */ @@ -242,13 +238,13 @@ enum class BaseCheckResult { /*! * \brief WLSet is not superset of RSet because of mismatch in value information. * - * L1-level mismatches in params of FuncStructInfo is categorized as - * If lhs is FuncStructInfo, then L1-level mismatch in its params + * L1-level mismatches in params of FuncType is categorized as + * If lhs is FuncType, then L1-level mismatch in its params * is categorized as L2-level mismatch for lhs. * * Design considerations for functions: * - (a) We want to be able to erase type/value in function signature - * when we unify function struct info and preserve simpler representations. + * when we unify function type and preserve simpler representations. * - (b) We automatically insert match_cast at function boundary, so * we can erase (int)->int argument as (object)->int. * The input shape/type mismatch will be detected by runtime checks at function boundary. @@ -267,47 +263,46 @@ enum class BaseCheckResult { * * This function returns fine-grained base-check result on reasons of failure. * - * \param base The base struct info. - * \param derived The derived struct info. + * \param base The base type. + * \param derived The derived type. * \return Whether the relation holds. * * \sa BaseCheckResult */ -TVM_DLL BaseCheckResult StructInfoBaseCheck(const StructInfo& base, const StructInfo& derived); +TVM_DLL BaseCheckResult TypeBaseCheck(const Type& base, const Type& derived); /*! * \brief Run a base check using a caller-provided analyzer. - * \param base The base struct info. - * \param derived The derived struct info. + * \param base The base type. + * \param derived The derived type. * \param ana Context analyzer to prove symbolic expression equality. * \return Whether the relation holds. * * \sa BaseCheckResult */ -TVM_DLL BaseCheckResult StructInfoBaseCheck(const StructInfo& base, const StructInfo& derived, - const arith::Analyzer& ana); +TVM_DLL BaseCheckResult TypeBaseCheck(const Type& base, const Type& derived, + const arith::Analyzer& ana); /*! - * \brief Check the relation of two struct info to see if one subsumes another one. + * \brief Check the relation of two type to see if one subsumes another one. * - * \param base The base struct info. - * \param derived The derived struct info. + * \param base The base type. + * \param derived The derived type. * \return Whether the relation holds. */ -TVM_DLL bool IsBaseOf(const StructInfo& base, const StructInfo& derived); +TVM_DLL bool IsBaseOf(const Type& base, const Type& derived); /*! - * \brief Check whether one struct info subsumes another using a caller-provided analyzer. - * \param base The base struct info. - * \param derived The derived struct info. + * \brief Check whether one type subsumes another using a caller-provided analyzer. + * \param base The base type. + * \param derived The derived type. * \param ana Context analyzer to prove symbolic expression equality. * \return Whether the relation holds. */ -TVM_DLL bool IsBaseOf(const StructInfo& base, const StructInfo& derived, - const arith::Analyzer& ana); +TVM_DLL bool IsBaseOf(const Type& base, const Type& derived, const arith::Analyzer& ana); /*! * \brief Return the condition for which base is a superset of derived * - * This function returns finer-grained conditions for kFailL2 than StructInfoBaseCheck + * This function returns finer-grained conditions for kFailL2 than TypeBaseCheck * * If the returned expression is true, or simplifies to true, then * base is a superset of derived. If the returned expression is @@ -318,22 +313,22 @@ TVM_DLL bool IsBaseOf(const StructInfo& base, const StructInfo& derived, * expression in terms of the symbolic variables available in `base` * and `derived`. * - * \param base The base struct info. - * \param derived The derived struct info. + * \param base The base type. + * \param derived The derived type. * \return Whether base is a base of derived. * * \sa BaseCheckResult */ -TVM_DLL PrimExpr StructInfoBaseCheckPrecondition(const StructInfo& base, const StructInfo& derived); +TVM_DLL PrimExpr TypeBaseCheckPrecondition(const Type& base, const Type& derived); /*! - * \brief Unify the two struct info to their least common ancestor. + * \brief Unify the two type to their least common ancestor. * * \param lhs The left operand. * \param rhs The right operand. * \return The unified information. */ -TVM_DLL StructInfo StructInfoLCA(const StructInfo& lhs, const StructInfo& rhs); +TVM_DLL Type TypeLCA(const Type& lhs, const Type& rhs); /*! * \brief Unify two struct infos using a caller-provided analyzer. * \param lhs The left operand. @@ -341,30 +336,29 @@ TVM_DLL StructInfo StructInfoLCA(const StructInfo& lhs, const StructInfo& rhs); * \param ana Context analyzer to prove symbolic expression equality. * \return The unified information. */ -TVM_DLL StructInfo StructInfoLCA(const StructInfo& lhs, const StructInfo& rhs, - const arith::Analyzer& ana); +TVM_DLL Type TypeLCA(const Type& lhs, const Type& rhs, const arith::Analyzer& ana); /*! - * \brief Get the TIR variables that appear in the input struct info. + * \brief Get the TIR variables that appear in the input type. * The returned list is deduplicated - each TIR variable will appear at most once. - * \param sinfo The struct info object to be analyzed. - * \return The list of TIR variables that appear in the input struct info. + * \param ty The type object to be analyzed. + * \return The list of TIR variables that appear in the input type. */ -TVM_DLL ffi::Array TIRVarsInStructInfo(const StructInfo& sinfo); +TVM_DLL ffi::Array TIRVarsInType(const Type& ty); /*! - * \brief Get the TIR variables that appear in the input struct info. + * \brief Get the TIR variables that appear in the input type. * * Returns all symbolic variables that are definable based on, and - * used within, the StructInfo. + * used within, the Type. * - * \param sinfo The struct info object to be analyzed. + * \param ty The type object to be analyzed. * * \return A tuple of (definable,used) TIR variables. Both lists are * deduplicated, each TIR variable will appear at most once, and in * order of occurrence. */ -TVM_DLL ffi::Array DefinableTIRVarsInStructInfo(const StructInfo& sinfo); +TVM_DLL ffi::Array DefinableTIRVarsInType(const Type& ty); /*! \brief Collect expressions whose usage requires them to be non-negative * @@ -373,11 +367,11 @@ TVM_DLL ffi::Array DefinableTIRVarsInStructInfo(const StructInfo& sin * to generate assertions prior to calling a kernel, or to provide * assumptions within a kernel that may be useful for simplification. * - * \param sinfo The struct info to be analyzed + * \param ty The type to be analyzed * * \return A list of non-negative expressions. */ -TVM_DLL ffi::Array CollectNonNegativeExpressions(const StructInfo& sinfo); +TVM_DLL ffi::Array CollectNonNegativeExpressions(const Type& ty); /*! * \brief Get the TIR variables that defined in the input function. @@ -609,7 +603,7 @@ TVM_DLL bool HasReshapePattern(const tirx::PrimFunc& func); * can be ignored in the check (must be a Var or GlobalVar). * \return The impure expression, if one exists within the given * expression. Otherwise, std::nullopt. - * \note Relies on StructInfo annotations, so ensure that the module has been normalized first. + * \note Relies on Type annotations, so ensure that the module has been normalized first. * Also, an impure call in a *nested* function does *not* mean that the outer expression contains * an impure call--it only does if the nested function is *later called*. */ @@ -623,7 +617,7 @@ TVM_DLL ffi::Optional FindImpureCall( * the caller can pass the function's name so recursive calls * can be ignored in the check (must be a Var or GlobalVar). * \return A boolean indicating if the expression contains any impure calls. - * \note Relies on StructInfo annotations, so ensure that the module has been normalized first. + * \note Relies on Type annotations, so ensure that the module has been normalized first. * Also, an impure call in a *nested* function does *not* mean that the outer expression contains * an impure call--it only does if the nested function is *later called*. */ @@ -638,12 +632,12 @@ TVM_DLL bool ContainsImpureCall( * path. Use \ref CheckWellFormed for a boolean answer. * * \param obj The IRModule or relax::Function to check. - * \param check_struct_info If true, verify that every Expr has struct_info populated. - * \note By default the structure info is always checked. It is only in test cases - * where `check_struct_info` might be false, so that other well-formed requirements - * will be well tested and will not be blocked by not having structure info. + * \param check_ty If true, verify that every Expr has ty populated. + * \note By default the type information is always checked. It is only in test cases + * where `check_ty` might be false, so that other well-formed requirements + * will be well tested and will not be blocked by not having type information. */ -TVM_DLL void WellFormed(ffi::Variant obj, bool check_struct_info = true); +TVM_DLL void WellFormed(ffi::Variant obj, bool check_ty = true); /*! * \brief Return whether an IRModule or Function is well-formed. @@ -652,10 +646,10 @@ TVM_DLL void WellFormed(ffi::Variant obj, bool check_struct_ * violation. * * \param obj The IRModule or relax::Function to check. - * \param check_struct_info If true, verify that every Expr has struct_info populated. + * \param check_ty If true, verify that every Expr has ty populated. * \return true if the object is well-formed, false otherwise. */ -TVM_DLL bool CheckWellFormed(ffi::Variant obj, bool check_struct_info = true); +TVM_DLL bool CheckWellFormed(ffi::Variant obj, bool check_ty = true); /*! * \brief Using the layout transforms on the outputs, suggest layout transformation on the blocks diff --git a/include/tvm/relax/attrs/distributed.h b/include/tvm/relax/attrs/distributed.h index 23b698eb3604..78e63b330637 100644 --- a/include/tvm/relax/attrs/distributed.h +++ b/include/tvm/relax/attrs/distributed.h @@ -25,7 +25,7 @@ #define TVM_RELAX_ATTRS_DISTRIBUTED_H_ #include -#include +#include #include namespace tvm { diff --git a/include/tvm/relax/binding_rewrite.h b/include/tvm/relax/binding_rewrite.h index 740e8ed01fda..6adefa9a4372 100644 --- a/include/tvm/relax/binding_rewrite.h +++ b/include/tvm/relax/binding_rewrite.h @@ -47,8 +47,8 @@ class DataflowBlockRewriteNode : public ffi::Object { void Add(Binding binding); /*! \brief Insert an expression as VarBinding with variable name. */ void Add(ffi::String var_name, Expr expr, bool is_dfvar = false) { - auto var = is_dfvar ? DataflowVar(var_name, GetStructInfo(expr)) // - : Var(var_name, GetStructInfo(expr)); + auto var = is_dfvar ? DataflowVar(var_name, GetType(expr)) // + : Var(var_name, GetType(expr)); Add(VarBinding(std::move(var), std::move(expr))); } /*! \brief Insert an expression as VarBinding with automatic variable name. */ diff --git a/include/tvm/relax/block_builder.h b/include/tvm/relax/block_builder.h index 8413686dc9df..ac20e2ad8f9a 100644 --- a/include/tvm/relax/block_builder.h +++ b/include/tvm/relax/block_builder.h @@ -202,11 +202,11 @@ class BlockBuilderNode : public ffi::Object { /*! * \brief Emit a MatchCast. * \param value The input value. - * \param struct_info The struct info to be matched. + * \param ty The type to be matched. * \param name_hint Name hint for the bound variable. * \return The variable bound to the MatchCast. */ - virtual Var EmitMatchCast(Expr value, StructInfo struct_info, ffi::String name_hint = "") = 0; + virtual Var EmitMatchCast(Expr value, Type ty, ffi::String name_hint = "") = 0; /*! * \brief Generate an output for the current dataflow block. @@ -230,7 +230,7 @@ class BlockBuilderNode : public ffi::Object { * \param expr The input expression. * \return The normalized expression. * - * \note Invariant: If any of the sub expr have struct_info field. + * \note Invariant: If any of the sub expr have ty field. * they must have already been normalized. */ virtual Expr Normalize(const Expr& expr) = 0; diff --git a/include/tvm/relax/dataflow_pattern.h b/include/tvm/relax/dataflow_pattern.h index 58d46f04380b..27894da3addd 100644 --- a/include/tvm/relax/dataflow_pattern.h +++ b/include/tvm/relax/dataflow_pattern.h @@ -56,7 +56,7 @@ class OrPattern; class AndPattern; class NotPattern; class ShapePattern; -class StructInfoPattern; +class TypePattern; class DataTypePattern; class AttrPattern; class SameShapeConstraint; @@ -114,8 +114,8 @@ class DFPattern : public ffi::ObjectRef { TVM_DLL NotPattern operator~() const; /*! \brief Syntatic Sugar for creating an AttrPattern */ TVM_DLL AttrPattern HasAttr(const ffi::Map& attrs) const; - /*! \brief Syntatic Sugar for creating a StructInfoPattern */ - TVM_DLL StructInfoPattern HasStructInfo(const StructInfo& struct_info) const; + /*! \brief Syntatic Sugar for creating a TypePattern */ + TVM_DLL TypePattern HasType(const Type& ty) const; /*! \brief Syntatic Sugar for creating a DataTypePattern with a DataType */ TVM_DLL DataTypePattern HasDtype(const DataType& dtype) const; /*! \brief Syntatic Sugar for creating a DataTypePattern with a data type's name */ @@ -484,7 +484,7 @@ class CallPatternNode : public DFPatternNode { */ bool varg_default_wildcard; /*!< N(args) can be < N(real args) by the padding of Wildcard */ - // Todo(relax-team): Dataflow pattern for StructInfo, and match sinfo_args + // Todo(relax-team): Dataflow pattern for Type, and match ty_args static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -770,28 +770,27 @@ class WildcardPattern : public DFPattern { }; /*! - * \brief Pattern for matching a certain struct info. - * \sa StructInfoPattern + * \brief Pattern for matching a certain type. + * \sa TypePattern */ -class StructInfoPatternNode : public DFPatternNode { +class TypePatternNode : public DFPatternNode { public: - DFPattern pattern; /*!< The pattern to match */ - StructInfo struct_info; /*!< The type to match */ + DFPattern pattern; /*!< The pattern to match */ + Type ty; /*!< The type to match */ static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef() - .def_ro("pattern", &StructInfoPatternNode::pattern) - .def_ro("struct_info", &StructInfoPatternNode::struct_info); + refl::ObjectDef() + .def_ro("pattern", &TypePatternNode::pattern) + .def_ro("ty", &TypePatternNode::ty); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.dpl.StructInfoPattern", StructInfoPatternNode, - DFPatternNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.dpl.TypePattern", TypePatternNode, DFPatternNode); }; -class StructInfoPattern : public DFPattern { +class TypePattern : public DFPattern { public: - TVM_DLL StructInfoPattern(DFPattern pattern, StructInfo struct_info); - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(StructInfoPattern, DFPattern, StructInfoPatternNode); + TVM_DLL TypePattern(DFPattern pattern, Type ty); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(TypePattern, DFPattern, TypePatternNode); }; /*! @@ -953,7 +952,7 @@ ExprPattern IsExpr(const Expr& expr); /*! \brief Syntatic Sugar for creating a ExprPattern base on an Op */ ExprPattern IsOp(const ffi::String& op_name); /*! \brief Syntatic Sugar for call_tir (return a tensor) */ -// Todo(relax-team): Dataflow pattern for StructInfo, and match out_sinfo +// Todo(relax-team): Dataflow pattern for Type, and match out_ty CallPattern IsCallTIR(const ffi::String& name, ffi::Optional args = std::nullopt); /*! \brief Syntatic Sugar for call_tir (return a tuple of tensor) */ CallPattern IsCallTIR(const ffi::String& name, TuplePattern var_args); diff --git a/include/tvm/relax/dataflow_pattern_functor.h b/include/tvm/relax/dataflow_pattern_functor.h index 6816dbb92c93..dfe130a09c76 100644 --- a/include/tvm/relax/dataflow_pattern_functor.h +++ b/include/tvm/relax/dataflow_pattern_functor.h @@ -96,8 +96,7 @@ class DFPatternFunctor { virtual R VisitDFPattern_(const TupleGetItemPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; virtual R VisitDFPattern_(const TuplePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; - virtual R VisitDFPattern_(const StructInfoPatternNode* op, - Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const TypePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; virtual R VisitDFPattern_(const WildcardPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; virtual R VisitDFPattern_(const VarPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; @@ -132,7 +131,7 @@ class DFPatternFunctor { RELAX_DFPATTERN_FUNCTOR_DISPATCH(ShapePatternNode); RELAX_DFPATTERN_FUNCTOR_DISPATCH(TupleGetItemPatternNode); RELAX_DFPATTERN_FUNCTOR_DISPATCH(TuplePatternNode); - RELAX_DFPATTERN_FUNCTOR_DISPATCH(StructInfoPatternNode); + RELAX_DFPATTERN_FUNCTOR_DISPATCH(TypePatternNode); RELAX_DFPATTERN_FUNCTOR_DISPATCH(WildcardPatternNode); RELAX_DFPATTERN_FUNCTOR_DISPATCH(VarPatternNode); RELAX_DFPATTERN_FUNCTOR_DISPATCH(DataflowVarPatternNode); @@ -166,7 +165,7 @@ class DFPatternVisitor : public DFPatternFunctor { void VisitDFPattern_(const ShapePatternNode* op) override; void VisitDFPattern_(const TupleGetItemPatternNode* op) override; void VisitDFPattern_(const TuplePatternNode* op) override; - void VisitDFPattern_(const StructInfoPatternNode* op) override; + void VisitDFPattern_(const TypePatternNode* op) override; void VisitDFPattern_(const WildcardPatternNode* op) override; void VisitDFPattern_(const VarPatternNode* op) override; diff --git a/include/tvm/relax/distributed/axis_group_graph.h b/include/tvm/relax/distributed/axis_group_graph.h index 86b34b71352a..0e15b2016489 100644 --- a/include/tvm/relax/distributed/axis_group_graph.h +++ b/include/tvm/relax/distributed/axis_group_graph.h @@ -21,7 +21,7 @@ #define TVM_RELAX_DISTRIBUTED_AXIS_GROUP_GRAPH_H_ #include -#include +#include #include #include #include diff --git a/include/tvm/relax/distributed/struct_info.h b/include/tvm/relax/distributed/type.h similarity index 76% rename from include/tvm/relax/distributed/struct_info.h rename to include/tvm/relax/distributed/type.h index 81fdf0fb3ffc..c99e2aa3db92 100644 --- a/include/tvm/relax/distributed/struct_info.h +++ b/include/tvm/relax/distributed/type.h @@ -18,15 +18,15 @@ */ /*! - * \file tvm/relax/distributed/struct_info.h - * \brief Struct info for DTensor (Distributed Tensor) + * \file tvm/relax/distributed/type.h + * \brief Type definitions for DTensor (Distributed Tensor) */ -#ifndef TVM_RELAX_DISTRIBUTED_STRUCT_INFO_H_ -#define TVM_RELAX_DISTRIBUTED_STRUCT_INFO_H_ +#ifndef TVM_RELAX_DISTRIBUTED_TYPE_H_ +#define TVM_RELAX_DISTRIBUTED_TYPE_H_ #include -#include +#include namespace tvm { namespace relax { namespace distributed { @@ -111,14 +111,14 @@ class Placement : public ffi::ObjectRef { }; /*! - * \brief StructInfo of DTensor (Distributed Tensor). + * \brief Type of DTensor (Distributed Tensor). */ -class DTensorStructInfoNode : public StructInfoNode { +class DTensorTypeNode : public DependentTypeNode { public: /*! - * \brief The struct info inherited from TensorStructInfo + * \brief The tensor type carried by the DTensor type. */ - TensorStructInfo tensor_sinfo; + TensorType tensor_ty; /*! * \brief The device mesh of the tensor. */ @@ -130,36 +130,35 @@ class DTensorStructInfoNode : public StructInfoNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef() - .def_ro("device_mesh", &DTensorStructInfoNode::device_mesh) - .def_ro("placement", &DTensorStructInfoNode::placement) - .def_ro("tensor_sinfo", &DTensorStructInfoNode::tensor_sinfo); + refl::ObjectDef() + .def_ro("device_mesh", &DTensorTypeNode::device_mesh) + .def_ro("placement", &DTensorTypeNode::placement) + .def_ro("tensor_ty", &DTensorTypeNode::tensor_ty); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.DTensorStructInfo", DTensorStructInfoNode, - StructInfoNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.DTensorType", DTensorTypeNode, DependentTypeNode); }; /*! - * \brief Managed reference to DTensorStructInfoNode. - * \sa DTensorStructInfoNode + * \brief Managed reference to DTensorTypeNode. + * \sa DTensorTypeNode */ -class DTensorStructInfo : public StructInfo { +class DTensorType : public Type { public: /*! * \brief Construction with device mesh and placement. - * \param tensor_sinfo The struct info inherited from TensorStructInfo + * \param tensor_ty The tensor type carried by the DTensor type. * \param device_mesh The device mesh of the tensor. * \param placement The placement of the tensor among the device mesh. * \param span The span of the AST. */ - TVM_DLL DTensorStructInfo(TensorStructInfo tensor_sinfo, DeviceMesh device_mesh, - Placement placement, Span span = Span()); + TVM_DLL DTensorType(TensorType tensor_ty, DeviceMesh device_mesh, Placement placement, + Span span = Span()); - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(DTensorStructInfo, StructInfo, DTensorStructInfoNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(DTensorType, Type, DTensorTypeNode); }; } // namespace distributed } // namespace relax } // namespace tvm -#endif // TVM_RELAX_DISTRIBUTED_STRUCT_INFO_H_ +#endif // TVM_RELAX_DISTRIBUTED_TYPE_H_ diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h index 6da9cb1692a3..937091255b6f 100644 --- a/include/tvm/relax/expr.h +++ b/include/tvm/relax/expr.h @@ -36,8 +36,6 @@ namespace tvm { namespace relax { -using Expr = RelaxExpr; -using ExprNode = RelaxExprNode; /*! * \brief The unique identifier of variables. * @@ -76,64 +74,6 @@ class Id : public ffi::ObjectRef { TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Id, ffi::ObjectRef, IdNode); }; -/*! - * \brief Base type of all structure information. - * - * StructInfo stores possible structure information - * deduced during compile-time. It encapsulates - * both static type and runtime information such - * as shape. - * - * StructInfo of each non-primitive Expr can be - * deduced during compilation in a "best-effort" manner. - * - * When struct_info appears in function parameter and return - * signatures. They will imply a runtime check that matches - * the structure information with the value. - * - * When it appears in Expr, they follow "assume-semantics", - * which means the compiler will take the deduced information as it is - * and only do best effort prove and checks. - * - * Each struct info can be uniquely erased to a static-type. - * The compiler will still compile the code(with less information) - * when we erase to the static type. - * - * If an StructInfo contains an Expr field, then that field - * must be normalized already through NormalizeArg. - * This invariant will be checked in constructors - * and help us to simplify our assumption - * during struct info deduction. - */ -class StructInfoNode : public ffi::Object { - public: - /*! - * \brief Span that points to the original source code. - * Reserved debug information. - */ - mutable Span span; - - static void RegisterReflection() { - namespace refl = tvm::ffi::reflection; - refl::ObjectDef().def_ro("span", &StructInfoNode::span, - refl::AttachFieldFlag::SEqHashIgnore()); - } - - static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - - static constexpr const uint32_t _type_child_slots = 7; - TVM_FFI_DECLARE_OBJECT_INFO("ir.StructInfo", StructInfoNode, ffi::Object); -}; - -/*! - * \brief Managed reference to StructInfoNode. - * \sa StructInfoNode - */ -class StructInfo : public ffi::ObjectRef { - public: - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(StructInfo, ffi::ObjectRef, StructInfoNode); -}; - /*! * \brief Call corresponds to callable invocation. * Corresponds to operation in computational graph terminology. @@ -155,16 +95,16 @@ class CallNode : public ExprNode { Attrs attrs; /*! - * \brief The structure info arguments of a CallNode. - * sinfo_args is by default designed to be non-empty only for intrinsic op (e.g., + * \brief The type information arguments of a CallNode. + * ty_args is by default designed to be non-empty only for intrinsic op (e.g., * call_tir, call_builtin_with_ctx, etc.) and calls to ExternFuncs, with the main - * usage of structure info inference. + * usage of type information inference. * - * Regular ops also at times may have sinfo_args defined to specialize partial - * or complete structure info. Like VDevice customization with mixed input memory_scopes. + * Regular ops also at times may have ty_args defined to specialize partial + * or complete type information. Like VDevice customization with mixed input memory_scopes. * The customized pass can set this info and operator specific inference will respect it. */ - ffi::Array sinfo_args; + ffi::Array ty_args; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -172,7 +112,7 @@ class CallNode : public ExprNode { .def_ro("op", &CallNode::op) .def_ro("args", &CallNode::args) .def_ro("attrs", &CallNode::attrs) - .def_ro("sinfo_args", &CallNode::sinfo_args); + .def_ro("ty_args", &CallNode::ty_args); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.Call", CallNode, ExprNode); }; @@ -184,11 +124,11 @@ class Call : public Expr { * \param op The operator to be invoked. * \param args The arguments of the call. * \param attrs The attributes of the call node. - * \param sinfo_args The structure info arguments passed to a function. + * \param ty_args The type information arguments passed to a function. * \param span The source span of the expression. */ TVM_DLL Call(Expr op, ffi::Array args, Attrs attrs = Attrs(), - ffi::Array sinfo_args = ffi::Array(), Span span = Span()); + ffi::Array ty_args = ffi::Array(), Span span = Span()); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Call, Expr, CallNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(CallNode); @@ -199,12 +139,11 @@ class Call : public Expr { * Returns \p call if all properties are unchanged. Otherwise, returns a copy with the new * fields. */ -Call WithFields( - Call call, ffi::Optional opt_op = ffi::Optional(), - ffi::Optional> opt_args = ffi::Optional>(), - ffi::Optional opt_attrs = ffi::Optional(), - ffi::Optional> opt_sinfo_args = ffi::Optional>(), - ffi::Optional opt_span = ffi::Optional()); +Call WithFields(Call call, ffi::Optional opt_op = ffi::Optional(), + ffi::Optional> opt_args = ffi::Optional>(), + ffi::Optional opt_attrs = ffi::Optional(), + ffi::Optional> opt_ty_args = ffi::Optional>(), + ffi::Optional opt_span = ffi::Optional()); /*! \brief Tuple container */ class TupleNode : public ExprNode { @@ -353,7 +292,7 @@ class VarNode : public LeafExprNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("vid", &VarNode::vid); - // customize structural equal and hash to include struct_info_ + // customize structural equal and hash to include ty refl::TypeAttrDef() .def("__s_equal__", &VarNode::SEqual) .def("__s_hash__", &VarNode::SHash); @@ -361,14 +300,13 @@ class VarNode : public LeafExprNode { bool SEqual(const VarNode* other, ffi::TypedFunction equal) const { - return equal(vid, other->vid, false, "vid") && - equal(struct_info_, other->struct_info_, false, "struct_info_"); + return equal(vid, other->vid, false, "vid") && equal(ty, other->ty, false, "ty"); } int64_t SHash(int64_t init_hash, ffi::TypedFunction hash) const { int64_t hash_value = init_hash; hash_value = hash(vid, hash_value, false); - hash_value = hash(struct_info_, hash_value, false); + hash_value = hash(ty, hash_value, false); return hash_value; } @@ -379,12 +317,10 @@ class VarNode : public LeafExprNode { class Var : public LeafExpr { public: - TVM_DLL explicit Var(ffi::String name_hint, ffi::Optional struct_info_annotation, - Span span = Span()) - : Var(Id(name_hint), struct_info_annotation, span) {} + TVM_DLL explicit Var(ffi::String name_hint, ffi::Optional ty_annotation, Span span = Span()) + : Var(Id(name_hint), ty_annotation, span) {} - TVM_DLL explicit Var(Id vid, ffi::Optional struct_info_annotation, - Span span = Span()); + TVM_DLL explicit Var(Id vid, ffi::Optional ty_annotation, Span span = Span()); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Var, LeafExpr, VarNode); VarNode* CopyOnWrite(); @@ -406,12 +342,11 @@ class DataflowVarNode : public VarNode { class DataflowVar : public Var { public: - TVM_DLL explicit DataflowVar(ffi::String name_hint, - ffi::Optional struct_info_annotation, Span span = Span()) - : DataflowVar(Id(name_hint), struct_info_annotation, span) {} + TVM_DLL explicit DataflowVar(ffi::String name_hint, ffi::Optional ty_annotation, + Span span = Span()) + : DataflowVar(Id(name_hint), ty_annotation, span) {} - TVM_DLL explicit DataflowVar(Id vid, ffi::Optional struct_info_annotation, - Span span = Span()); + TVM_DLL explicit DataflowVar(Id vid, ffi::Optional ty_annotation, Span span = Span()); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(DataflowVar, Var, DataflowVarNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(DataflowVarNode); @@ -445,12 +380,11 @@ class Constant : public LeafExpr { /*! * \brief The constructor * \param data The data of the constant tensor. - * \param struct_info_annotation The struct info of the constant tensor. + * \param ty_annotation The type of the constant tensor. * If not specified, infer it from data. * \param span The source span of the expression. */ - TVM_DLL explicit Constant(runtime::Tensor data, - ffi::Optional struct_info_annotation = std::nullopt, + TVM_DLL explicit Constant(runtime::Tensor data, ffi::Optional ty_annotation = std::nullopt, Span span = Span()); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Constant, LeafExpr, ConstantNode); @@ -600,26 +534,25 @@ class Binding : public ffi::ObjectRef { }; /*! - * \brief Runtime-match the value to the struct info. + * \brief Runtime-match the value to the type. * * This operation does runtime check, populates the un-defined symbolic shape vars - * and vars in struct_info in first occurance, and insert equality assertions in + * and vars in ty in first occurance, and insert equality assertions in * other cases. */ class MatchCastNode : public BindingNode { public: /*! \brief The input value to match cast. */ Expr value; - /*! \brief The struct info pattern to match to. */ - StructInfo struct_info; + /*! \brief The type pattern to match to. */ + Type ty; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef() .def_ro("value", &MatchCastNode::value) // TODO(tqchen): use SEqHashDefNonRecursive after the next pypi tvm-ffi release - .def_ro("struct_info", &MatchCastNode::struct_info, - refl::AttachFieldFlag::SEqHashDefRecursive()); + .def_ro("ty", &MatchCastNode::ty, refl::AttachFieldFlag::SEqHashDefRecursive()); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.MatchCast", MatchCastNode, BindingNode); }; @@ -630,7 +563,7 @@ class MatchCastNode : public BindingNode { */ class MatchCast : public Binding { public: - TVM_DLL explicit MatchCast(Var var, Expr value, StructInfo struct_info, Span span = Span()); + TVM_DLL explicit MatchCast(Var var, Expr value, Type ty, Span span = Span()); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(MatchCast, Binding, MatchCastNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(MatchCastNode); @@ -818,7 +751,7 @@ class FunctionNode : public BaseFuncNode { /*! \brief The body of the function. */ SeqExpr body; /*! \brief The return type of the function. */ - StructInfo ret_struct_info; + Type ret_ty; /*! \brief Whether the function is annotated as pure or not. */ bool is_pure; @@ -827,7 +760,7 @@ class FunctionNode : public BaseFuncNode { refl::ObjectDef() .def_ro("params", &FunctionNode::params, refl::AttachFieldFlag::SEqHashDefRecursive()) .def_ro("body", &FunctionNode::body) - .def_ro("ret_struct_info", &FunctionNode::ret_struct_info) + .def_ro("ret_ty", &FunctionNode::ret_ty) .def_ro("is_pure", &FunctionNode::is_pure); } @@ -847,8 +780,8 @@ class Function : public BaseFunc { * Relax IR requirement that all scopes be contained in a * SeqExpr. * - * \param ret_struct_info The StructInfo returned by the function. - * If std::nullopt, will be inferred from the StructInfo of the + * \param ret_ty The Type returned by the function. + * If std::nullopt, will be inferred from the Type of the * function's body. * * \param is_pure The purity of the function. @@ -858,17 +791,15 @@ class Function : public BaseFunc { * * \param span The source span of the expression. */ - TVM_DLL explicit Function(ffi::Array params, Expr body, - ffi::Optional ret_struct_info, bool is_pure = true, - DictAttrs attrs = DictAttrs(), Span span = Span()); + TVM_DLL explicit Function(ffi::Array params, Expr body, ffi::Optional ret_ty, + bool is_pure = true, DictAttrs attrs = DictAttrs(), Span span = Span()); /*! * \brief Mimics the constructor but without body Expr. - * \note ret_struct_info is required, since it can not deduced by the body. + * \note ret_ty is required, since it can not deduced by the body. */ - TVM_DLL static Function CreateEmpty(ffi::Array params, StructInfo ret_struct_info, - bool is_pure = true, DictAttrs attrs = DictAttrs(), - Span span = Span()); + TVM_DLL static Function CreateEmpty(ffi::Array params, Type ret_ty, bool is_pure = true, + DictAttrs attrs = DictAttrs(), Span span = Span()); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Function, BaseFunc, FunctionNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(FunctionNode); @@ -921,7 +852,7 @@ class ExternFuncNode : public BaseFuncNode { class ExternFunc : public BaseFunc { public: TVM_DLL ExternFunc(ffi::String global_symbol, Span span = Span()); - TVM_DLL ExternFunc(ffi::String global_symbol, StructInfo struct_info, Span span = Span()); + TVM_DLL ExternFunc(ffi::String global_symbol, Type ty, Span span = Span()); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ExternFunc, BaseFunc, ExternFuncNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(ExternFuncNode); @@ -933,7 +864,7 @@ class ExternFunc : public BaseFunc { * \return The corresonding shape. * * \note This function requires expr to be normalized. - * The function will report an error if expr's StructInfo is not TensorStructInfo. + * The function will report an error if expr's Type is not TensorType. * It will try to return symbolic function when possible. If the tensor do not * have a compile-time symbolic shape, the function will then choose to return * Call(relax.op.shape_of, [expr]). diff --git a/include/tvm/relax/expr_functor.h b/include/tvm/relax/expr_functor.h index 9189d5d1fe42..586ba5b92744 100644 --- a/include/tvm/relax/expr_functor.h +++ b/include/tvm/relax/expr_functor.h @@ -28,8 +28,8 @@ #include #include #include -#include -#include +#include +#include #include #include @@ -256,20 +256,20 @@ class ExprVisitor : public ExprFunctor { virtual void VisitVarDef(const Var& var); /*! - * \brief Visit struct_info may recursively contain Expr/PrimExpr. + * \brief Visit ty may recursively contain Expr/PrimExpr. * - * By default, this function recurse into struct info such as - * TensorStructInfo and ShapeStructInfo and call VisitExpr/VisitPrimExpr - * accordingly. It does not recurse into FunctionStructInfo as it does + * By default, this function recurse into type such as + * TensorType and ShapeType and call VisitExpr/VisitPrimExpr + * accordingly. It does not recurse into FunctionType as it does * not contain Expr defined in the current scope. * * Pass writers can overload this function to change to other behaviors. - * For example, if we are not interested in Expr in StructInfo, we can + * For example, if we are not interested in Expr in Type, we can * override this function by a no-op. * - * \param struct_info Input struct info field. + * \param ty Input type field. */ - virtual void VisitExprDepStructInfoField(const StructInfo& struct_info); + virtual void VisitExprDepTypeField(const Type& ty); // specific leaf level visitor functions virtual void VisitVarDef_(const VarNode* var); @@ -285,29 +285,29 @@ class ExprVisitor : public ExprFunctor { // initialize the vtable. static VisitBindingVTable InitVisitBindingVTable(); /*! - * \brief Private internal struct info field visitor. + * \brief Private internal type field visitor. * - * Support default visiting of struct info field and recursive into + * Support default visiting of type field and recursive into * their Expr fields. * * We use component instead of sub-classing so there can be other - * joint inheritance between ExprVisitor and StructInfoVisitor. + * joint inheritance between ExprVisitor and TypeVisitor. */ - class DefaultStructInfoFieldVisitor : public StructInfoVisitor { + class DefaultTypeFieldVisitor : public TypeVisitor { public: - explicit DefaultStructInfoFieldVisitor(ExprVisitor* parent); + explicit DefaultTypeFieldVisitor(ExprVisitor* parent); - // Override defaults in struct info visitor. - void VisitStructInfoExprField(const Expr& expr) final; - void VisitStructInfoExprField(const PrimExpr& expr) final; - void VisitStructInfo_(const FuncStructInfoNode* op) final; + // Override defaults in type visitor. + void VisitTypeExprField(const Expr& expr) final; + void VisitTypeExprField(const PrimExpr& expr) final; + void VisitType_(const FuncTypeNode* op) final; private: ExprVisitor* parent_; }; // This visitor is not visible to child classes and only // used to supported default visiting behavior. - DefaultStructInfoFieldVisitor default_struct_info_field_visitor_{this}; + DefaultTypeFieldVisitor default_tyfield_visitor_{this}; }; void PostOrderVisit(const Expr& node, std::function fvisit); @@ -315,7 +315,7 @@ void PostOrderVisit(const Expr& node, std::function fvisit); /*! * \brief A mutator works in unnormalized form. * - * ExprMutatorBase expects input AST to be in the unnormalized form, i.e., struct_info_ + * ExprMutatorBase expects input AST to be in the unnormalized form, i.e., ty * of expressions can be nullptr, and the expressions may nest(and as a result the AST is not in * ANF). */ @@ -355,34 +355,36 @@ class ExprMutatorBase : public ExprFunctor { virtual PrimExpr VisitPrimExpr(const PrimExpr& expr); /*! - * \brief Visit struct_info that may recursively contain Expr/PrimExpr. + * \brief Visit ty that may recursively contain Expr/PrimExpr. * - * By default, this function recurse into struct info such as - * TensorStructInfo and ShapeStructInfo and call VisitExpr/VisitPrimExpr - * accordingly. It does not recurse into FunctionStructInfo as it does + * By default, this function recurse into type such as + * TensorType and ShapeType and call VisitExpr/VisitPrimExpr + * accordingly. It does not recurse into FunctionType as it does * not contain Expr defined in the current scope. * * Pass writers can overload this function to change to other behaviors. - * For example, if in Expr in StructInfo won't change, we can + * For example, if in Expr in Type won't change, we can * override this function by an identity function. * - * \param struct_info Input struct info field. - * \return The updated struct info. + * \param ty Input type field. + * \return The updated type. */ - virtual StructInfo VisitExprDepStructInfoField(const StructInfo& struct_info); + virtual Type VisitExprDepTypeField(const Type& ty); protected: /*! - * \brief Check whether VisitExprDepStructInfoField change struct_info. - * \return Whether struct info changed. + * \brief Check whether VisitExprDepTypeField change ty. + * \return Whether type changed. * \note This function is used by mutator implementations to check if - * previous Expr update will trigger a change in struct_info. + * previous Expr update will trigger a change in ty. * If change is detected, the implementation can generate a fresh - * node without struct_info, and trigger normalizer to re-derive. + * node without ty, and trigger normalizer to re-derive. */ - bool VisitAndCheckStructInfoFieldUnchanged(const ffi::ObjectRef& struct_info) { - if (const StructInfoNode* sinfo = struct_info.as()) { - return this->VisitExprDepStructInfoField(ffi::GetRef(sinfo)).same_as(struct_info); + bool VisitAndCheckTypeFieldUnchanged(const ffi::ObjectRef& ty) { + if (const DependentTypeNode* ty_node = ty.as()) { + return this->VisitExprDepTypeField(ffi::GetRef(ty_node)).same_as(ty); + } else if (const TupleTypeNode* ty_node = ty.as()) { + return this->VisitExprDepTypeField(ffi::GetRef(ty_node)).same_as(ty); } else { return true; } @@ -390,34 +392,34 @@ class ExprMutatorBase : public ExprFunctor { private: /*! - * \brief Private internal struct info field visitor to support - * Default visiting of struct info field and recursive into their Expr fields. + * \brief Private internal type field visitor to support + * Default visiting of type field and recursive into their Expr fields. * * We use component instead of sub-classing so there can be other - * joint inheritance between ExprMutator and StructInfoMutator. + * joint inheritance between ExprMutator and TypeMutator. */ - class DefaultStructInfoFieldMutator : public StructInfoMutator { + class DefaultTypeFieldMutator : public TypeMutator { public: - explicit DefaultStructInfoFieldMutator(ExprMutatorBase* parent); + explicit DefaultTypeFieldMutator(ExprMutatorBase* parent); - // Override defaults in struct info visitor. - Expr VisitStructInfoExprField(const Expr& expr) final; - PrimExpr VisitStructInfoExprField(const PrimExpr& expr) final; - StructInfo VisitStructInfo_(const FuncStructInfoNode* op) final; + // Override defaults in type visitor. + Expr VisitTypeExprField(const Expr& expr) final; + PrimExpr VisitTypeExprField(const PrimExpr& expr) final; + Type VisitType_(const FuncTypeNode* op) final; private: ExprMutatorBase* parent_; }; // This visitor is not visible to child classes and only // used to supported default visiting behavior. - DefaultStructInfoFieldMutator default_struct_info_field_mutator_{this}; + DefaultTypeFieldMutator default_tyfield_mutator_{this}; }; /*! * \brief A mutator works in normal form. * * ExprMutator expects input AST to be in the normal form, i.e., the expressions are normalized(no - * nesting and hence the AST is in ANF), and all struct_info_ of expressions are + * nesting and hence the AST is in ANF), and all ty of expressions are * available. */ class ExprMutator : public ExprMutatorBase { @@ -544,13 +546,13 @@ class ExprMutator : public ExprMutatorBase { } /*! - * \brief Create a new var with specified struct_info if the original var's shape or type does - * not match with the specified ones. + * \brief Create a new var with specified type if the original var's shape or type does not + * match with the specified ones. * \param var The var to be updated. - * \param struct_info The struct info to be updated. - * \return The var filled with struct_info + * \param ty The type to be updated. + * \return The var filled with type information. */ - Var WithStructInfo(Var var, StructInfo struct_info); + Var WithType(Var var, Type ty); /*! \brief Internal block builder to emit bindings during rewriting. */ BlockBuilder builder_; diff --git a/include/tvm/relax/nested_msg.h b/include/tvm/relax/nested_msg.h index 4b11e9d2b043..400f73d1a2f4 100644 --- a/include/tvm/relax/nested_msg.h +++ b/include/tvm/relax/nested_msg.h @@ -31,7 +31,7 @@ #include #include #include -#include +#include #include #include @@ -277,50 +277,50 @@ NestedMsg MapToNestedMsg(Expr expr, FType fmapleaf) { } /*! - * \brief Map structinfo with possible nested-sinfo to nested message. + * \brief Map structinfo with possible nested-ty to nested message. * - * This function will unpack recursive sinfo and run fmapleaf for each leaf, + * This function will unpack recursive ty and run fmapleaf for each leaf, * then recursively combines the results together into a NestedMsg. * * The nesting structure will corresponds to the tuple structure. * - * \param sinfo The input struct info. - * \param fmapleaf The mapping function for each leaf with signature `NestedMsg fmap(StructInfo)` + * \param ty The input type. + * \param fmapleaf The mapping function for each leaf with signature `NestedMsg fmap(Type)` * \tparam T the content type of nested msg * \tparam FType The mapping function type */ template -NestedMsg MapToNestedMsg(StructInfo sinfo, FType fmapleaf) { - if (auto* tuple = sinfo.as()) { +NestedMsg MapToNestedMsg(Type ty, FType fmapleaf) { + if (auto* tuple = ty.as()) { ffi::Array> res; res.reserve(tuple->fields.size()); - for (StructInfo x : tuple->fields) { + for (Type x : tuple->fields) { res.push_back(MapToNestedMsg(x, fmapleaf)); } return res; } else { - return fmapleaf(sinfo); + return fmapleaf(ty); } } /*! * \brief Map expr with possible nested-tuple to nested message. * - * This function will unpack recursive expr by its struct info and + * This function will unpack recursive expr by its type and * run fmapleaf for each leaf, then recursively combines the results * together into a NestedMsg. * - * The nesting structure will corresponds to the struct info of expr. + * The nesting structure will corresponds to the type of expr. * - * \param expr The input expression which should have struct info. + * \param expr The input expression which should have type. * \param fmapleaf The mapping function for each leaf with signature `NestedMsg fmapleaf(Expr)` * \tparam T the content type of nested msg * \tparam FType The mapping function type */ template -NestedMsg MapToNestedMsgBySInfo(Expr expr, FType fmapleaf) { - auto sinfo = GetStructInfo(expr); - if (auto* tuple = sinfo.as()) { +NestedMsg MapToNestedMsgByType(Expr expr, FType fmapleaf) { + auto ty = GetType(expr); + if (auto* tuple = ty.as()) { ffi::Array> res; res.reserve(tuple->fields.size()); for (size_t i = 0; i < tuple->fields.size(); ++i) { @@ -330,7 +330,7 @@ NestedMsg MapToNestedMsgBySInfo(Expr expr, FType fmapleaf) { } else { field = TupleGetItem(expr, i); } - res.push_back(MapToNestedMsgBySInfo(field, fmapleaf)); + res.push_back(MapToNestedMsgByType(field, fmapleaf)); } return res; } else { @@ -520,8 +520,8 @@ void DecomposeNestedMsg(Expr expr, NestedMsg msg, FType fvisitleaf) { */ template Expr TransformTupleLeaf(Expr expr, std::array, N> msgs, FType ftransleaf) { - StructInfo sinfo = GetStructInfo(expr); - if (const auto* tuple = sinfo.as()) { + Type ty = GetType(expr); + if (const auto* tuple = ty.as()) { std::array>, N> msg_arrays; for (size_t i = 0; i < N; ++i) { TVM_FFI_ICHECK(msgs[i].IsNested()) << "Expected nested to match tuple"; @@ -554,33 +554,32 @@ Expr TransformTupleLeaf(Expr expr, std::array, N> msgs, FType ftran } /*! - * \brief Recursively transform the tuple structure in sinfo and msgs along with it. + * \brief Recursively transform the tuple structure in ty and msgs along with it. * - * This function will call ftransleaf for each leaf sinfo in sinfo. + * This function will call ftransleaf for each leaf ty in ty. * This function will throw an error if the nesting structure in msg does not - * match the tuple nesting structure in sinfo. + * match the tuple nesting structure in ty. * - * \param sinfo The input sinfo to be transform.  + * \param ty The input ty to be transform.  * \param msgs The input messages to guide the transformation. - * \param ftransleaf with signature ftransleaf(StructInfo, ffi::Array>)->StructInfo + * \param ftransleaf with signature ftransleaf(Type, ffi::Array>)->Type * \tparam T the content type of nested msg * \tparam N the number of messages * \tparam FType The visit function type. */ template -StructInfo TransformTupleLeaf(StructInfo sinfo, std::array, N> msgs, - FType ftransleaf) { - if (const auto* tuple = sinfo.as()) { +Type TransformTupleLeaf(Type ty, std::array, N> msgs, FType ftransleaf) { + if (const auto* tuple = ty.as()) { std::array>, N> msg_arrays; for (size_t i = 0; i < N; ++i) { TVM_FFI_ICHECK(msgs[i].IsNested()) << "Expected nested to match tuple"; msg_arrays[i] = msgs[i].NestedArray(); } bool same = true; - ffi::Array fields; + ffi::Array fields; fields.reserve(tuple->fields.size()); for (size_t i = 0; i < tuple->fields.size(); ++i) { - StructInfo field = tuple->fields[i]; + Type field = tuple->fields[i]; std::array, N> sub_msgs; for (size_t j = 0; j < N; ++j) { sub_msgs[j] = msg_arrays[j][i]; @@ -588,12 +587,12 @@ StructInfo TransformTupleLeaf(StructInfo sinfo, std::array, N> msgs fields.push_back(TransformTupleLeaf(field, std::move(sub_msgs), ftransleaf)); same &= (fields.back().same_as(field)); } - return same ? sinfo : TupleStructInfo(fields); + return same ? ty : TupleType(fields); } else { for (const auto& msg : msgs) { TVM_FFI_ICHECK(msg.IsLeaf()) << "Expected leaf to match non-tuple"; } - return ftransleaf(sinfo, msgs); + return ftransleaf(ty, msgs); } } diff --git a/include/tvm/relax/op_attr_types.h b/include/tvm/relax/op_attr_types.h index 1fd9b45c323c..034b0dd2fe27 100644 --- a/include/tvm/relax/op_attr_types.h +++ b/include/tvm/relax/op_attr_types.h @@ -24,8 +24,9 @@ #ifndef TVM_RELAX_OP_ATTR_TYPES_H_ #define TVM_RELAX_OP_ATTR_TYPES_H_ +#include #include -#include +#include #include namespace tvm { @@ -60,12 +61,12 @@ enum OpPatternKind { using FCallPacked = ffi::String; /*! - * \brief Infer output struct info given the call + * \brief Infer output type given the call * * \param call The call expression to be derived. * \param ctx The builder context. */ -using FInferStructInfo = ffi::TypedFunction; +using FInferType = ffi::TypedFunction; /*! * \brief The function type of a normalization function. diff --git a/include/tvm/relax/script/builder/frame.h b/include/tvm/relax/script/builder/frame.h index ab87aaf778b8..7a20ceed1a53 100644 --- a/include/tvm/relax/script/builder/frame.h +++ b/include/tvm/relax/script/builder/frame.h @@ -100,15 +100,15 @@ class FunctionFrameNode : public SeqExprFrameNode { /*! \brief The function params. */ ffi::Array params; /*! - * \brief The function return struct info. + * \brief The function return type. * \note Usually the function return type can be deduced by the function body. * But we can use this field to specify a more "accurate" return type. - * i.e. If the `ret_struct_info` is None, try to use the deduced type from body - * If the `ret_struct_info` is not None, we can still take body.struct_info - * if we ret_struct_info is base of body.struct_info. If not, we will - * take the specified `ret_struct_info`. + * i.e. If the `ret_ty` is None, try to use the deduced type from body + * If the `ret_ty` is not None, we can still take body.ty + * if we ret_ty is base of body.ty. If not, we will + * take the specified `ret_ty`. */ - ffi::Optional ret_struct_info; + ffi::Optional ret_ty; /*! \brief Whether the function is annotated as pure */ ffi::Optional is_pure; /*! \brief Whether the function is annotated as private */ @@ -123,7 +123,7 @@ class FunctionFrameNode : public SeqExprFrameNode { refl::ObjectDef() .def_ro("name", &FunctionFrameNode::name) .def_ro("params", &FunctionFrameNode::params) - .def_ro("ret_struct_info", &FunctionFrameNode::ret_struct_info) + .def_ro("ret_ty", &FunctionFrameNode::ret_ty) .def_ro("is_pure", &FunctionFrameNode::is_pure) .def_ro("attrs", &FunctionFrameNode::attrs); // `binding_blocks` and `output` are inherited from SeqExprFrameNode. diff --git a/include/tvm/relax/script/builder/ir.h b/include/tvm/relax/script/builder/ir.h index 48318c891859..6a5df22245d6 100644 --- a/include/tvm/relax/script/builder/ir.h +++ b/include/tvm/relax/script/builder/ir.h @@ -21,7 +21,7 @@ #include #include -#include +#include #include namespace tvm { @@ -42,10 +42,10 @@ TVM_DLL FunctionFrame Function(bool is_pure, bool is_private); /*! * \brief Add a parameter to the last function frame. * \param name The name of the parameter. - * \param struct_info The struct_info of the parameter. + * \param ty The ty of the parameter. * \return The created function parameter var. */ -TVM_DLL tvm::relax::Var Arg(const ffi::String& name, const tvm::relax::StructInfo& struct_info); +TVM_DLL tvm::relax::Var Arg(const ffi::String& name, const tvm::Type& ty); /*! * \brief Specify the name of the last function frame. @@ -60,10 +60,10 @@ TVM_DLL void FuncName(const ffi::String& name); TVM_DLL void FuncAttrs(ffi::Map attrs); /*! - * \brief Specify the return struct info of the last function frame. - * \param ret_sinfo The return struct info. + * \brief Specify the return type of the last function frame. + * \param ret_ty The return type. */ -TVM_DLL void FuncRetStructInfo(const tvm::relax::StructInfo& ret_sinfo); +TVM_DLL void FuncRetType(const tvm::Type& ret_ty); /*! * \brief Specify the return value of the last function frame. @@ -96,21 +96,19 @@ TVM_DLL void DataflowBlockOutput(const ffi::Array& vars); /*! * \brief Emit a binding to the last binding block frame. * \param value The right side value of the bindings to be emitted. - * \param annotate_struct_info The optional struct info annotation for the emitted value. + * \param annotate_ty The optional type annotation for the emitted value. * \return The left side var of the emitted binding. */ -TVM_DLL tvm::relax::Var Emit( - const tvm::relax::Expr& value, - const ffi::Optional& annotate_struct_info = std::nullopt); +TVM_DLL tvm::relax::Var Emit(const tvm::relax::Expr& value, + const ffi::Optional& annotate_ty = std::nullopt); /*! * \brief Emit a match_cast binding to the last binding block frame. * \param value The value of the MatchCast to be emitted. - * \param struct_info The struct info of the MatchCast to be emitted. + * \param ty The type of the MatchCast to be emitted. * \return The left side var of the emitted binding. */ -TVM_DLL tvm::relax::Var EmitMatchCast(const tvm::relax::Expr& value, - const tvm::relax::StructInfo& struct_info); +TVM_DLL tvm::relax::Var EmitMatchCast(const tvm::relax::Expr& value, const tvm::Type& ty); /*! * \brief Emit a binding to the last binding block frame. diff --git a/include/tvm/relax/struct_info.h b/include/tvm/relax/struct_info.h deleted file mode 100644 index 049469027ba2..000000000000 --- a/include/tvm/relax/struct_info.h +++ /dev/null @@ -1,424 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -#ifndef TVM_RELAX_STRUCT_INFO_H_ -#define TVM_RELAX_STRUCT_INFO_H_ - -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -namespace tvm { -namespace relax { - -/*! - * \brief Opaque object. - */ -class ObjectStructInfoNode : public StructInfoNode { - public: - static void RegisterReflection() { - namespace refl = tvm::ffi::reflection; - refl::ObjectDef(); - } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.ObjectStructInfo", ObjectStructInfoNode, StructInfoNode); -}; - -/*! - * \brief Managed reference to ObjectStructInfoNode. - * \sa ObjectStructInfoNode - */ -class ObjectStructInfo : public StructInfo { - public: - TVM_DLL ObjectStructInfo(Span span = Span()); - - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(ObjectStructInfo, StructInfo, ObjectStructInfoNode); -}; - -/*! - * \brief Primitive value. - */ -class PrimStructInfoNode : public StructInfoNode { - public: - /*! \brief Underlying primitive value, if known */ - ffi::Optional value; - - /*! \brief Underlying data type of the primitive value */ - DataType dtype; - - static void RegisterReflection() { - namespace refl = tvm::ffi::reflection; - refl::ObjectDef() - .def_ro("value", &PrimStructInfoNode::value) - .def_ro("dtype", &PrimStructInfoNode::dtype); - } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.PrimStructInfo", PrimStructInfoNode, StructInfoNode); -}; - -/*! - * \brief Managed reference to PrimStructInfoNode. - * \sa PrimStructInfoNode - */ -class PrimStructInfo : public StructInfo { - public: - /* Construct a PrimStructInfo with a known dtype, but unknown value */ - TVM_DLL PrimStructInfo(DataType dtype, Span span = Span()); - - /* Construct a PrimStructInfo with a known value */ - TVM_DLL PrimStructInfo(PrimExpr value, Span span = Span()); - - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(PrimStructInfo, StructInfo, PrimStructInfoNode); -}; - -/*! - * \brief StructInfo of shape value. - */ -class ShapeStructInfoNode : public StructInfoNode { - public: - /*! \brief optionally stores the symbolic value patterns of the shape */ - ffi::Optional> values; - /*! - * \brief The number of dimension of the shape, can be unknown. - * \sa kUnknownNDim - */ - int ndim; - - /*! \return Whether the struct info contains unknown ndim. */ - bool IsUnknownNdim() const { return ndim == kUnknownNDim; } - - static void RegisterReflection() { - namespace refl = tvm::ffi::reflection; - refl::ObjectDef() - .def_ro("values", &ShapeStructInfoNode::values) - .def_ro("ndim", &ShapeStructInfoNode::ndim); - } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.ShapeStructInfo", ShapeStructInfoNode, StructInfoNode); -}; - -/*! - * \brief Managed reference to ShapeStructInfoNode. - * \sa ShapeStructInfoNode - */ -class ShapeStructInfo : public StructInfo { - public: - /*! - * \brief Construction with known symbolic shape patterns - * \param values The symbolic shape values - * \param span The span of the AST. - */ - TVM_DLL ShapeStructInfo(ffi::Array values, Span span = Span()); - /*! - * \brief Construction with known unknown symbolic shape patterns. - * \param ndim Number of dimensions -- can be kUnknownNDim - * \param span The span of the AST. - */ - TVM_DLL ShapeStructInfo(int ndim, Span span = Span()); - - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(ShapeStructInfo, StructInfo, ShapeStructInfoNode); -}; - -/*! - * \brief StructInfo of Tensor. - */ -class TensorStructInfoNode : public StructInfoNode { - public: - /*! - * \brief optionally store the shape expression of the tensor. - * \note shape must be normalized: it can only be std::nullopt or ShapeExpr or Var. - */ - ffi::Optional shape; - /*! \brief The virtual device, indicates where the tensor - * is expected to be executed. - */ - ffi::Optional vdevice; - /*! \brief The content data type, use void to denote the dtype is unknown. */ - DataType dtype; - /*! - * \brief The number of dimension of the tensor, can be unknown. - * \sa kUnknownNDim - */ - int ndim; - - /*! \return Whether the struct info contains unknown ndim. */ - bool IsUnknownNdim() const { return ndim == kUnknownNDim; } - - /*! \return Whether the struct info contains unknown dtype. */ - bool IsUnknownDtype() const { return dtype.is_void(); } - - /*! \return Shape if it is known. */ - ffi::Optional> GetShape() const { - if (!shape.defined()) return {}; - ShapeStructInfo shape_sinfo = Downcast(this->shape.value()->struct_info_); - return shape_sinfo->values; - } - - static void RegisterReflection() { - namespace refl = tvm::ffi::reflection; - refl::ObjectDef() - .def_ro("shape", &TensorStructInfoNode::shape) - .def_ro("dtype", &TensorStructInfoNode::dtype) - .def_ro("vdevice", &TensorStructInfoNode::vdevice) - .def_ro("ndim", &TensorStructInfoNode::ndim); - } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.TensorStructInfo", TensorStructInfoNode, StructInfoNode); -}; - -/*! - * \brief Managed reference to TensorStructInfoNode. - * \sa TensorStructInfoNode - */ -class TensorStructInfo : public StructInfo { - public: - /*! - * \brief Construction with a known shape expression. - * \param shape The shape of the tensor. - * \param dtype The data type of tensor's elements. - * \param vdevice The virtual device. - * \param span The span of the AST. - * - * \note shape must already be normalized. - */ - TVM_DLL TensorStructInfo(Expr shape, DataType dtype, - ffi::Optional vdevice = std::nullopt, Span span = Span()); - - /*! - * \brief Construction with an unknown shape expression. - * \param dtype The data type of tensor's elements. - * \param ndim The number of dimensions - * \param vdevice The virtual device. - * \param span The span of the AST. - */ - TVM_DLL TensorStructInfo(DataType dtype, int ndim, ffi::Optional vdevice = std::nullopt, - Span span = Span()); - - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(TensorStructInfo, StructInfo, TensorStructInfoNode); -}; - -/*! - * \brief StructInfo of Tuple. - */ -class TupleStructInfoNode : public StructInfoNode { - public: - /*! \brief The struct info of tuple fields. */ - ffi::Array fields; - - static void RegisterReflection() { - namespace refl = tvm::ffi::reflection; - refl::ObjectDef().def_ro("fields", &TupleStructInfoNode::fields); - } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.TupleStructInfo", TupleStructInfoNode, StructInfoNode); -}; - -/*! - * \brief Managed reference to TupleStructInfoNode. - * \sa TupleStructInfoNode - */ -class TupleStructInfo : public StructInfo { - public: - /*! - * \brief Constructor - * \param fields Struct info of tuple fields. - * \param span The span of the AST. - */ - TVM_DLL TupleStructInfo(ffi::Array fields, Span span = Span()); - - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(TupleStructInfo, StructInfo, TupleStructInfoNode); -}; - -/*! - * \brief custom-defined StructInfo derivation function. - * \param call The call expression to be derived. - * \param ctx The builder context. - * \return The derived struct info of the call. - */ -using StructInfoDeriveFunc = TypedEnvFunc; - -/*! - * \brief Structure information about function. - * - * This data structure contains enough information for us to - * do best-effort structure information deduction. - */ -class FuncStructInfoNode : public StructInfoNode { - public: - /*! - * \brief The parameter struct info of the function. - * \note When params is std::nullopt means the function can take arbitrary number of arguments. - * We define such functions as Opaque function. - */ - ffi::Optional> params; - /*! - * \brief The struct info of the function's return value. - */ - StructInfo ret; - /*! - * \brief Derivation function of opaque functions that may take any number of parameters. - * \note When derive_func is not empty, then params should be std::nullopt, - * ret should be ObjectStructInfo() - */ - ffi::Optional derive_func; - /*! - * \brief Whether the function is pure. - * \note This parameter should be set to true only if the function is pure on all inputs. - * If the function _may_ have visible side effects, set it to false. - */ - bool purity; - - /*! - * \return Whether the func struct info is opaque. - * \note We define a function as opaque we have no constraints on params. - */ - bool IsOpaque() const { return !params.defined(); } - - static void RegisterReflection() { - namespace refl = tvm::ffi::reflection; - refl::ObjectDef() - .def_ro("params", &FuncStructInfoNode::params, refl::AttachFieldFlag::SEqHashDefRecursive()) - .def_ro("ret", &FuncStructInfoNode::ret) - .def_ro("derive_func", &FuncStructInfoNode::derive_func) - .def_ro("purity", &FuncStructInfoNode::purity); - } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.FuncStructInfo", FuncStructInfoNode, StructInfoNode); -}; - -/*! - * \brief Managed reference to FuncStructInfoNode. - * \sa FuncStructInfoNode - */ -class FuncStructInfo : public StructInfo { - public: - explicit FuncStructInfo(ffi::ObjectPtr data) : StructInfo(ffi::UnsafeInit{}) { - TVM_FFI_ICHECK(data != nullptr); - data_ = std::move(data); - } - /*! - * \brief Constructor from parameter struct info and return value struct info. - * \param params The struct info of function parameters. - * \param ret The return value struct info. - * \param purity The purity of the function (true by default). - * \param span The span of the AST. - * - * \note If the ret contains variables(tirx::Var and relax::Var), they must be deducible from - * params. If you are unsure, you can always erase ret to static. - */ - TVM_DLL FuncStructInfo(ffi::Array params, StructInfo ret, bool purity = true, - Span span = Span()); - - /*! - * \brief Constructing an opaque function struct info using derive_func. - * - * \param derive_func Derivation function. - * \param purity The purity of the function - * (false by default: most external functions are not pure). - * \param span The span of the AST. - * - * \return The FuncStructInfo for opaque packedfunc. - * \note Defaults to an derive func that always return ObjectStructInfo if not specified. - */ - TVM_DLL static FuncStructInfo OpaqueFunc(StructInfoDeriveFunc derive_func, bool purity = false, - Span span = Span()); - - /*! - * \brief Construct an opaque function using from return struct info. - * - * \param ret The struct info of the return value. - * \param purity The purity of the function - * (false by default: most external functions are not pure). - * \param span The span of the AST. - * - * \return The FuncStructInfo for opaque packedfunc. - * \note Defaults to an derive func that always return ObjectStructInfo if not specified. - */ - TVM_DLL static FuncStructInfo OpaqueFunc(StructInfo ret = ObjectStructInfo(), bool purity = false, - Span span = Span()); - - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(FuncStructInfo, StructInfo, FuncStructInfoNode); -}; - -/*! - * \brief Match and check if expr have StructInfo T and return it. - * - * \param expr The input expression. - * \return The result of match. - * \tparam T the underlying structure info type - */ -template -inline ffi::Optional MatchStructInfo(const Expr& expr) { - using TNode = typename T::ContainerType; - if (const TNode* ptr = expr->struct_info_.as()) { - return ffi::GetRef(ptr); - } else { - return std::nullopt; - } -} - -/*! - * \brief Get the structure info of a given expr and try to cast it as const T*. - * - * \param expr The input expression. - * \return The pointer. Returns nullptr if the type does not match - * \tparam T the underlying structure info type - */ -template -inline const T* GetStructInfoAs(const Expr& expr) { - TVM_FFI_ICHECK(expr->struct_info_.defined()) - << "The struct_info is not populated, check if you have normalized the expr"; - return expr->struct_info_.as(); -} - -/*! - * \brief Get the underlying structure info of expr. - * - * \param expr The input expression. - * \return underlying struct info. - */ -inline StructInfo GetStructInfo(const Expr& expr) { - auto* ptr = expr->struct_info_.as(); - TVM_FFI_ICHECK(ptr) << "The struct_info is not populated, check if you have normalized the expr"; - return ffi::GetRef(ptr); -} - -/*! - * \brief Whether the expr has void struct info. - * - * \param expr The input expression. - * \return Whether the expr has void struct info. - */ -inline bool HasVoidStructInfo(const Expr& expr) { - auto* ptr = expr->struct_info_.as(); - return ptr != nullptr && ptr->fields.size() == 0; -} - -/*! - * \brief Update the struct info of an Expr. - * \param expr The Expr whose struct info to be updated. - * \param struct_info The struct_info assigned. - * \note We ensure idempotence, that is we can only update the struct_info of an Expr only - * if the original one is nullptr. - */ -TVM_DLL void UpdateStructInfo(Expr expr, StructInfo struct_info); - -} // namespace relax -} // namespace tvm -#endif // TVM_RELAX_STRUCT_INFO_H_ diff --git a/include/tvm/relax/struct_info_functor.h b/include/tvm/relax/struct_info_functor.h deleted file mode 100644 index c4b3ea31a1c1..000000000000 --- a/include/tvm/relax/struct_info_functor.h +++ /dev/null @@ -1,160 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/relax/struct_info_functor.h - * \brief Functors and visitors for struct info. - */ -#ifndef TVM_RELAX_STRUCT_INFO_FUNCTOR_H_ -#define TVM_RELAX_STRUCT_INFO_FUNCTOR_H_ - -#include -#include -#include - -#include - -namespace tvm { -namespace relax { - -template -class StructInfoFunctor; - -// functions to be overriden. -#define STRUCT_INFO_FUNCTOR_DEFAULT \ - { \ - return VisitStructInfoDefault_(op, std::forward(args)...); \ - } - -#define TVM_STRUCT_INFO_FUNCTOR_DISPATCH(OP) \ - vtable.template set_dispatch([](const ffi::ObjectRef& n, TSelf* self, Args... args) { \ - return self->VisitStructInfo_(static_cast(n.get()), std::forward(args)...); \ - }); - -template -class StructInfoFunctor { - private: - using TSelf = StructInfoFunctor; - using FStructInfo = tvm::NodeFunctor; - - public: - /*! \brief the result type of this functor */ - using result_type = R; - /*! \brief virtual destructor */ - virtual ~StructInfoFunctor() {} - /*! - * \brief Same as call. - * \param n The expression node. - * \param args Additional arguments. - * \return The result of the call - */ - R operator()(const StructInfo& n, Args... args) { - return VisitStructInfo(n, std::forward(args)...); - } - /*! - * \brief The functor call. - * \param n The expression node. - * \param args Additional arguments. - * \return The result of the call - */ - virtual R VisitStructInfo(const StructInfo& n, Args... args) { - TVM_FFI_ICHECK(n.defined()); - static FStructInfo vtable = InitVTable(); - return vtable(n, this, std::forward(args)...); - } - // Functions that can be overriden by subclass - virtual R VisitStructInfo_(const ObjectStructInfoNode* op, - Args... args) STRUCT_INFO_FUNCTOR_DEFAULT; - virtual R VisitStructInfo_(const PrimStructInfoNode* op, - Args... args) STRUCT_INFO_FUNCTOR_DEFAULT; - virtual R VisitStructInfo_(const ShapeStructInfoNode* op, - Args... args) STRUCT_INFO_FUNCTOR_DEFAULT; - virtual R VisitStructInfo_(const TensorStructInfoNode* op, - Args... args) STRUCT_INFO_FUNCTOR_DEFAULT; - virtual R VisitStructInfo_(const distributed::DTensorStructInfoNode* op, - Args... args) STRUCT_INFO_FUNCTOR_DEFAULT; - virtual R VisitStructInfo_(const TupleStructInfoNode* op, - Args... args) STRUCT_INFO_FUNCTOR_DEFAULT; - virtual R VisitStructInfo_(const FuncStructInfoNode* op, - Args... args) STRUCT_INFO_FUNCTOR_DEFAULT; - virtual R VisitStructInfoDefault_(const ffi::Object* op, Args...) { - TVM_FFI_THROW(InternalError) << "Do not have a default for " << op->GetTypeKey(); - throw; // unreachable, written to stop compiler warning - } - - private: - // initialize the vtable. - static FStructInfo InitVTable() { - FStructInfo vtable; - // Set dispatch - TVM_STRUCT_INFO_FUNCTOR_DISPATCH(ObjectStructInfoNode); - TVM_STRUCT_INFO_FUNCTOR_DISPATCH(PrimStructInfoNode); - TVM_STRUCT_INFO_FUNCTOR_DISPATCH(ShapeStructInfoNode); - TVM_STRUCT_INFO_FUNCTOR_DISPATCH(TensorStructInfoNode); - TVM_STRUCT_INFO_FUNCTOR_DISPATCH(distributed::DTensorStructInfoNode); - TVM_STRUCT_INFO_FUNCTOR_DISPATCH(TupleStructInfoNode); - TVM_STRUCT_INFO_FUNCTOR_DISPATCH(FuncStructInfoNode); - vtable.Finalize(); - return vtable; - } -}; - -#undef TVM_STRUCT_INFO_FUNCTOR_DISPATCH - -/*! - * \brief A struct info visitor. - */ -class TVM_DLL StructInfoVisitor : public StructInfoFunctor { - public: - void VisitStructInfo_(const ObjectStructInfoNode* op) override; - void VisitStructInfo_(const PrimStructInfoNode* op) override; - void VisitStructInfo_(const ShapeStructInfoNode* op) override; - void VisitStructInfo_(const TensorStructInfoNode* op) override; - void VisitStructInfo_(const distributed::DTensorStructInfoNode* op) override; - void VisitStructInfo_(const TupleStructInfoNode* op) override; - void VisitStructInfo_(const FuncStructInfoNode* op) override; - - protected: - // two functions to override when visit expr fields in struct info. - virtual void VisitStructInfoExprField(const Expr& expr) {} - virtual void VisitStructInfoExprField(const PrimExpr& expr) {} -}; - -/*! - * \brief StructInfoMutator that mutates struct info. - */ -class TVM_DLL StructInfoMutator : public StructInfoFunctor { - public: - StructInfo VisitStructInfo_(const ObjectStructInfoNode* op) override; - StructInfo VisitStructInfo_(const PrimStructInfoNode* op) override; - StructInfo VisitStructInfo_(const ShapeStructInfoNode* op) override; - StructInfo VisitStructInfo_(const TensorStructInfoNode* op) override; - StructInfo VisitStructInfo_(const distributed::DTensorStructInfoNode* op) override; - StructInfo VisitStructInfo_(const TupleStructInfoNode* op) override; - StructInfo VisitStructInfo_(const FuncStructInfoNode* op) override; - - protected: - // two functions to override when visit expr fields in struct info. - virtual Expr VisitStructInfoExprField(const Expr& expr) { return expr; } - virtual PrimExpr VisitStructInfoExprField(const PrimExpr& expr) { return expr; } -}; - -} // namespace relax -} // namespace tvm -#endif // TVM_RELAX_STRUCT_INFO_FUNCTOR_H_ diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h index 493e51ef50f8..d0d0d1bb5441 100644 --- a/include/tvm/relax/transform.h +++ b/include/tvm/relax/transform.h @@ -153,7 +153,7 @@ TVM_DLL Pass AttachGlobalSymbol(); /*! * \brief Transform Relax IR to normal form: transform AST to A-normal form, and fill the - * struct_info_ of expressions. + * ty of expressions. * * \return The Pass. */ diff --git a/include/tvm/relax/type.h b/include/tvm/relax/type.h index b70a2756b71f..2760c65c5f9f 100644 --- a/include/tvm/relax/type.h +++ b/include/tvm/relax/type.h @@ -19,7 +19,7 @@ /*! * \file tvm/relax/type.h - * \brief Relax Types. + * \brief Relax types, including the richer dependent Relax type nodes. */ #ifndef TVM_RELAX_TYPE_H_ #define TVM_RELAX_TYPE_H_ @@ -27,122 +27,438 @@ #include #include #include +#include #include +#include #include #include #include +#include namespace tvm { namespace relax { +using Expr = RelaxExpr; +using ExprNode = RelaxExprNode; + +class BlockBuilder; +class Call; + /*! \brief Indicates the number of dimensions of a tensor is unknown at compile time. */ static constexpr int kUnknownNDim = -1; -class ShapeTypeNode : public TypeNode { +using tvm::TupleType; +using tvm::TupleTypeNode; + +class PackedFuncTypeNode : public TypeNode { + public: + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); + } + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.PackedFuncType", PackedFuncTypeNode, TypeNode); +}; + +class PackedFuncType : public Type { + public: + TVM_DLL PackedFuncType(Span span = Span()); + + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(PackedFuncType, Type, PackedFuncTypeNode); +}; + +/*! + * \brief Base type of all structure information. + * + * Type stores possible structure information deduced during compile-time. + * It encapsulates both static type and runtime information such as shape. + * + * Type of each non-primitive Expr can be deduced during compilation in a + * "best-effort" manner. + * + * When ty appears in function parameter and return signatures, it + * implies a runtime check that matches the structure information with the value. + * + * When it appears in Expr, it follows "assume-semantics", which means the + * compiler will take the deduced information as it is and only do best effort + * proofs and checks. + * + * Each type can be uniquely erased to a static-type. The compiler will + * still compile the code, with less information, when we erase to the static + * type. + * + * If a Type contains an Expr field, then that field must already be + * normalized through NormalizeArg. This invariant is checked in constructors + * and simplifies assumptions during type deduction. + */ +class DependentTypeNode : public TypeNode { + public: + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); + } + + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; + + static constexpr const uint32_t _type_child_slots = 6; + TVM_FFI_DECLARE_OBJECT_INFO("relax.DependentType", DependentTypeNode, TypeNode); +}; + +/*! + * \brief Opaque object. + */ +class ObjectTypeNode : public DependentTypeNode { + public: + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); + } + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.ObjectType", ObjectTypeNode, DependentTypeNode); +}; + +/*! + * \brief Managed reference to ObjectTypeNode. + * \sa ObjectTypeNode + */ +class ObjectType : public Type { + public: + TVM_DLL ObjectType(Span span = Span()); + + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(ObjectType, Type, ObjectTypeNode); +}; + +/*! + * \brief Primitive value. + */ +class PrimTypeNode : public DependentTypeNode { public: - /*! \brief size of the shape. */ + /*! \brief Underlying primitive value, if known */ + ffi::Optional value; + + /*! \brief Underlying data type of the primitive value */ + DataType dtype; + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("value", &PrimTypeNode::value) + .def_ro("dtype", &PrimTypeNode::dtype); + } + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.PrimType", PrimTypeNode, DependentTypeNode); +}; + +/*! + * \brief Managed reference to PrimTypeNode. + * \sa PrimTypeNode + */ +class PrimType : public Type { + public: + /* Construct a PrimType with a known dtype, but unknown value */ + TVM_DLL PrimType(DataType dtype, Span span = Span()); + + /* Construct a PrimType with a known value */ + TVM_DLL PrimType(PrimExpr value, Span span = Span()); + + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(PrimType, Type, PrimTypeNode); +}; + +/*! + * \brief Type of shape value. + */ +class ShapeTypeNode : public DependentTypeNode { + public: + /*! \brief optionally stores the symbolic value patterns of the shape */ + ffi::Optional> values; + /*! + * \brief The number of dimension of the shape, can be unknown. + * \sa kUnknownNDim + */ int ndim; + /*! \return Whether the type contains unknown ndim. */ + bool IsUnknownNdim() const { return ndim == kUnknownNDim; } + static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef().def_ro("ndim", &ShapeTypeNode::ndim); + refl::ObjectDef() + .def_ro("values", &ShapeTypeNode::values) + .def_ro("ndim", &ShapeTypeNode::ndim); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.ShapeType", ShapeTypeNode, TypeNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.ShapeType", ShapeTypeNode, DependentTypeNode); }; +/*! + * \brief Managed reference to ShapeTypeNode. + * \sa ShapeTypeNode + */ class ShapeType : public Type { public: + /*! + * \brief Construction with known symbolic shape patterns + * \param values The symbolic shape values + * \param span The span of the AST. + */ + TVM_DLL ShapeType(ffi::Array values, Span span = Span()); + /*! + * \brief Construction with known unknown symbolic shape patterns. + * \param ndim Number of dimensions -- can be kUnknownNDim + * \param span The span of the AST. + */ TVM_DLL ShapeType(int ndim, Span span = Span()); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(ShapeType, Type, ShapeTypeNode); }; /*! - * \brief Dynamic version of TensorType - * - * Use relax::TensorStructInfo for more detailed (possibly dynamic) shape constrains + * \brief Type of Tensor. */ -class TensorTypeNode : public TypeNode { +class TensorTypeNode : public DependentTypeNode { public: /*! - * \brief The number of dimensions of the tensor, use -1 to denote tensor with unknown number of - * dimensions. + * \brief optionally store the shape expression of the tensor. + * \note shape must be normalized: it can only be std::nullopt or ShapeExpr or Var. */ - int ndim; + ffi::Optional shape; + /*! \brief The virtual device, indicates where the tensor + * is expected to be executed. + */ + ffi::Optional vdevice; /*! \brief The content data type, use void to denote the dtype is unknown. */ DataType dtype; + /*! + * \brief The number of dimension of the tensor, can be unknown. + * \sa kUnknownNDim + */ + int ndim; + + /*! \return Whether the type contains unknown ndim. */ + bool IsUnknownNdim() const { return ndim == kUnknownNDim; } + + /*! \return Whether the type contains unknown dtype. */ + bool IsUnknownDtype() const { return dtype.is_void(); } + + /*! \return Shape if it is known. */ + ffi::Optional> GetShape() const { + if (!shape.defined()) return {}; + ShapeType shape_ty = Downcast(this->shape.value()->ty); + return shape_ty->values; + } static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef() - .def_ro("ndim", &TensorTypeNode::ndim) - .def_ro("dtype", &TensorTypeNode::dtype); + .def_ro("shape", &TensorTypeNode::shape) + .def_ro("dtype", &TensorTypeNode::dtype) + .def_ro("vdevice", &TensorTypeNode::vdevice) + .def_ro("ndim", &TensorTypeNode::ndim); } - - inline bool IsUnknownNdim() const { return ndim == kUnknownNDim; } - - inline bool IsUnknownDtype() const { return dtype.is_void(); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.DynTensorType", TensorTypeNode, TypeNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.TensorType", TensorTypeNode, DependentTypeNode); }; /*! * \brief Managed reference to TensorTypeNode. - * \sa TensorTypeNode. + * \sa TensorTypeNode */ class TensorType : public Type { public: /*! - * \brief Constructor. - * \param ndim The number of dimensions of the tensor. - * \param dtype The runtime dtype of the tensor's elements. - * \param span The span. + * \brief Construction with a known shape expression. + * \param shape The shape of the tensor. + * \param dtype The data type of tensor's elements. + * \param vdevice The virtual device. + * \param span The span of the AST. + * + * \note shape must already be normalized. */ - TVM_DLL TensorType(int ndim, DataType dtype, Span span = Span()); + TVM_DLL TensorType(Expr shape, DataType dtype, ffi::Optional vdevice = std::nullopt, + Span span = Span()); /*! - * \brief Create a TensorType with unknown ndim. + * \brief Construction with an unknown shape expression. + * \param dtype The data type of tensor's elements. + * \param ndim The number of dimensions + * \param vdevice The virtual device. + * \param span The span of the AST. */ - TVM_DLL static TensorType CreateUnknownNDim(DataType dtype, Span span = Span()); + TVM_DLL TensorType(DataType dtype, int ndim, ffi::Optional vdevice = std::nullopt, + Span span = Span()); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(TensorType, Type, TensorTypeNode); }; -using TensorTypeNode = TensorTypeNode; -using TensorType = TensorType; +/*! + * \brief custom-defined Type derivation function. + * \param call The call expression to be derived. + * \param ctx The builder context. + * \return The derived type of the call. + */ +using TypeDeriveFunc = TypedEnvFunc; -class ObjectTypeNode : public TypeNode { +/*! + * \brief Structure information about function. + * + * This data structure contains enough information for us to do best-effort + * structure information deduction. + */ +class FuncTypeNode : public DependentTypeNode { public: + /*! + * \brief The parameter type of the function. + * \note When params is std::nullopt means the function can take arbitrary number of arguments. + * We define such functions as Opaque function. + */ + ffi::Optional> params; + /*! + * \brief The type of the function's return value. + */ + Type ret; + /*! + * \brief Derivation function of opaque functions that may take any number of parameters. + * \note When derive_func is not empty, then params should be std::nullopt, + * ret should be ObjectType() + */ + ffi::Optional derive_func; + /*! + * \brief Whether the function is pure. + * \note This parameter should be set to true only if the function is pure on all inputs. + * If the function _may_ have visible side effects, set it to false. + */ + bool purity; + + /*! + * \return Whether the func type is opaque. + * \note We define a function as opaque we have no constraints on params. + */ + bool IsOpaque() const { return !params.defined(); } + static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef(); + refl::ObjectDef() + .def_ro("params", &FuncTypeNode::params, refl::AttachFieldFlag::SEqHashDefRecursive()) + .def_ro("ret", &FuncTypeNode::ret) + .def_ro("derive_func", &FuncTypeNode::derive_func) + .def_ro("purity", &FuncTypeNode::purity); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.ObjectType", ObjectTypeNode, TypeNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.FuncType", FuncTypeNode, DependentTypeNode); }; -class ObjectType : public Type { +/*! + * \brief Managed reference to FuncTypeNode. + * \sa FuncTypeNode + */ +class FuncType : public Type { public: - TVM_DLL ObjectType(Span span = Span()); + explicit FuncType(ffi::ObjectPtr data) : Type(ffi::UnsafeInit{}) { + TVM_FFI_ICHECK(data != nullptr); + data_ = std::move(data); + } + /*! + * \brief Constructor from parameter type and return value type. + * \param params The type of function parameters. + * \param ret The return value type. + * \param purity The purity of the function (true by default). + * \param span The span of the AST. + * + * \note If the ret contains variables(tirx::Var and relax::Var), they must be deducible from + * params. If you are unsure, you can always erase ret to static. + */ + TVM_DLL FuncType(ffi::Array params, Type ret, bool purity = true, Span span = Span()); - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(ObjectType, Type, ObjectTypeNode); + /*! + * \brief Constructing an opaque function type using derive_func. + * + * \param derive_func Derivation function. + * \param purity The purity of the function + * (false by default: most external functions are not pure). + * \param span The span of the AST. + * + * \return The FuncType for opaque packedfunc. + * \note Defaults to an derive func that always return ObjectType if not specified. + */ + TVM_DLL static FuncType OpaqueFunc(TypeDeriveFunc derive_func, bool purity = false, + Span span = Span()); + + /*! + * \brief Construct an opaque function using from return type. + * + * \param ret The type of the return value. + * \param purity The purity of the function + * (false by default: most external functions are not pure). + * \param span The span of the AST. + * + * \return The FuncType for opaque packedfunc. + * \note Defaults to an derive func that always return ObjectType if not specified. + */ + TVM_DLL static FuncType OpaqueFunc(Type ret = ObjectType(), bool purity = false, + Span span = Span()); + + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(FuncType, Type, FuncTypeNode); }; -class PackedFuncTypeNode : public TypeNode { - public: - static void RegisterReflection() { - namespace refl = tvm::ffi::reflection; - refl::ObjectDef(); +/*! + * \brief Match and check if expr has Relax type T and return it. + * + * \param expr The input expression. + * \return The result of match. + * \tparam T the underlying Relax type + */ +template +inline ffi::Optional MatchType(const Expr& expr) { + using TNode = typename T::ContainerType; + if (const TNode* ptr = expr->ty.as()) { + return ffi::GetRef(ptr); + } else { + return std::nullopt; } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.PackedFuncType", PackedFuncTypeNode, TypeNode); -}; +} -class PackedFuncType : public Type { - public: - TVM_DLL PackedFuncType(Span span = Span()); +/*! + * \brief Get the type of a given expr and try to cast it as const T*. + * + * \param expr The input expression. + * \return The pointer. Returns nullptr if the type does not match. + * \tparam T the underlying Relax type node + */ +template +inline const T* GetTypeAs(const Expr& expr) { + TVM_FFI_ICHECK(expr->ty.defined()) + << "The type is not populated, check if you have normalized the expr"; + return expr->ty.as(); +} - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(PackedFuncType, Type, PackedFuncTypeNode); -}; +/*! + * \brief Get the underlying Relax type of expr. + * + * \param expr The input expression. + * \return underlying Relax type. + */ +inline Type GetType(const Expr& expr) { + TVM_FFI_ICHECK(expr->ty.defined()) + << "The type is not populated, check if you have normalized the expr"; + return expr->ty; +} + +/*! + * \brief Whether the expr has void type. + * + * \param expr The input expression. + * \return Whether the expr has void type. + */ +inline bool HasVoidType(const Expr& expr) { + auto* ptr = expr->ty.as(); + return ptr != nullptr && ptr->fields.size() == 0; +} + +/*! + * \brief Update the type of an Expr. + * \param expr The Expr whose type to be updated. + * \param ty The type assigned. + * \note We ensure idempotence, that is we can only update the type of an Expr only + * if the original one is nullptr. + */ +TVM_DLL void UpdateType(Expr expr, Type ty); } // namespace relax } // namespace tvm + #endif // TVM_RELAX_TYPE_H_ diff --git a/include/tvm/ir/type_functor.h b/include/tvm/relax/type_functor.h similarity index 55% rename from include/tvm/ir/type_functor.h rename to include/tvm/relax/type_functor.h index a56690e81709..47513161442f 100644 --- a/include/tvm/ir/type_functor.h +++ b/include/tvm/relax/type_functor.h @@ -18,31 +18,31 @@ */ /*! - * \file tvm/ir/type_functor.h - * \brief A way to defined arbitrary function signature with dispatch on types. + * \file tvm/relax/type_functor.h + * \brief Functors and visitors for Relax type nodes. */ -#ifndef TVM_IR_TYPE_FUNCTOR_H_ -#define TVM_IR_TYPE_FUNCTOR_H_ +#ifndef TVM_RELAX_TYPE_FUNCTOR_H_ +#define TVM_RELAX_TYPE_FUNCTOR_H_ #include -#include +#include +#include -#include #include -#include namespace tvm { +namespace relax { template class TypeFunctor; // functions to be overriden. -#define TYPE_FUNCTOR_DEFAULT \ +#define RELAX_TYPE_FUNCTOR_DEFAULT \ { \ return VisitTypeDefault_(op, std::forward(args)...); \ } -#define TVM_TYPE_FUNCTOR_DISPATCH(OP) \ +#define TVM_RELAX_TYPE_FUNCTOR_DISPATCH(OP) \ vtable.template set_dispatch([](const ffi::ObjectRef& n, TSelf* self, Args... args) { \ return self->VisitType_(static_cast(n.get()), std::forward(args)...); \ }); @@ -60,14 +60,14 @@ class TypeFunctor { virtual ~TypeFunctor() {} /*! * \brief Same as call. - * \param n The expression node. + * \param n The type node. * \param args Additional arguments. * \return The result of the call */ R operator()(const Type& n, Args... args) { return VisitType(n, std::forward(args)...); } /*! * \brief The functor call. - * \param n The expression node. + * \param n The type node. * \param args Additional arguments. * \return The result of the call */ @@ -77,10 +77,14 @@ class TypeFunctor { return vtable(n, this, std::forward(args)...); } // Functions that can be overriden by subclass - virtual R VisitType_(const FuncTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; - virtual R VisitType_(const TupleTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; - virtual R VisitType_(const PrimTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; - virtual R VisitType_(const PointerTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; + virtual R VisitType_(const ObjectTypeNode* op, Args... args) RELAX_TYPE_FUNCTOR_DEFAULT; + virtual R VisitType_(const PrimTypeNode* op, Args... args) RELAX_TYPE_FUNCTOR_DEFAULT; + virtual R VisitType_(const ShapeTypeNode* op, Args... args) RELAX_TYPE_FUNCTOR_DEFAULT; + virtual R VisitType_(const TensorTypeNode* op, Args... args) RELAX_TYPE_FUNCTOR_DEFAULT; + virtual R VisitType_(const distributed::DTensorTypeNode* op, + Args... args) RELAX_TYPE_FUNCTOR_DEFAULT; + virtual R VisitType_(const TupleTypeNode* op, Args... args) RELAX_TYPE_FUNCTOR_DEFAULT; + virtual R VisitType_(const FuncTypeNode* op, Args... args) RELAX_TYPE_FUNCTOR_DEFAULT; virtual R VisitTypeDefault_(const ffi::Object* op, Args...) { TVM_FFI_THROW(InternalError) << "Do not have a default for " << op->GetTypeKey(); throw; // unreachable, written to stop compiler warning @@ -91,42 +95,58 @@ class TypeFunctor { static FType InitVTable() { FType vtable; // Set dispatch - TVM_TYPE_FUNCTOR_DISPATCH(FuncTypeNode); - TVM_TYPE_FUNCTOR_DISPATCH(TupleTypeNode); - TVM_TYPE_FUNCTOR_DISPATCH(PrimTypeNode); - TVM_TYPE_FUNCTOR_DISPATCH(PointerTypeNode); + TVM_RELAX_TYPE_FUNCTOR_DISPATCH(ObjectTypeNode); + TVM_RELAX_TYPE_FUNCTOR_DISPATCH(PrimTypeNode); + TVM_RELAX_TYPE_FUNCTOR_DISPATCH(ShapeTypeNode); + TVM_RELAX_TYPE_FUNCTOR_DISPATCH(TensorTypeNode); + TVM_RELAX_TYPE_FUNCTOR_DISPATCH(distributed::DTensorTypeNode); + TVM_RELAX_TYPE_FUNCTOR_DISPATCH(TupleTypeNode); + TVM_RELAX_TYPE_FUNCTOR_DISPATCH(FuncTypeNode); vtable.Finalize(); return vtable; } }; -#undef TVM_TYPE_FUNCTOR_DISPATCH +#undef TVM_RELAX_TYPE_FUNCTOR_DISPATCH /*! - * \brief A type visitor that recursively visit types. + * \brief A type visitor. */ class TVM_DLL TypeVisitor : public TypeFunctor { public: - void VisitType_(const FuncTypeNode* op) override; - void VisitType_(const TupleTypeNode* op) override; + void VisitType_(const ObjectTypeNode* op) override; void VisitType_(const PrimTypeNode* op) override; - void VisitType_(const PointerTypeNode* op) override; + void VisitType_(const ShapeTypeNode* op) override; + void VisitType_(const TensorTypeNode* op) override; + void VisitType_(const distributed::DTensorTypeNode* op) override; + void VisitType_(const TupleTypeNode* op) override; + void VisitType_(const FuncTypeNode* op) override; + + protected: + // two functions to override when visit expr fields in type nodes. + virtual void VisitTypeExprField(const Expr& expr) {} + virtual void VisitTypeExprField(const PrimExpr& expr) {} }; /*! - * \brief TypeMutator that mutates expressions. + * \brief TypeMutator that mutates Relax type nodes. */ class TVM_DLL TypeMutator : public TypeFunctor { public: - Type VisitType(const Type& t) override; - Type VisitType_(const FuncTypeNode* op) override; - Type VisitType_(const TupleTypeNode* op) override; + Type VisitType_(const ObjectTypeNode* op) override; Type VisitType_(const PrimTypeNode* op) override; - Type VisitType_(const PointerTypeNode* op) override; + Type VisitType_(const ShapeTypeNode* op) override; + Type VisitType_(const TensorTypeNode* op) override; + Type VisitType_(const distributed::DTensorTypeNode* op) override; + Type VisitType_(const TupleTypeNode* op) override; + Type VisitType_(const FuncTypeNode* op) override; - private: - ffi::Array MutateArray(ffi::Array arr); + protected: + // two functions to override when visit expr fields in type nodes. + virtual Expr VisitTypeExprField(const Expr& expr) { return expr; } + virtual PrimExpr VisitTypeExprField(const PrimExpr& expr) { return expr; } }; +} // namespace relax } // namespace tvm -#endif // TVM_IR_TYPE_FUNCTOR_H_ +#endif // TVM_RELAX_TYPE_FUNCTOR_H_ diff --git a/include/tvm/relax/utils.h b/include/tvm/relax/utils.h index 77f8bab5553f..f36da70f9c54 100644 --- a/include/tvm/relax/utils.h +++ b/include/tvm/relax/utils.h @@ -51,11 +51,10 @@ TVM_DLL Expr Bind(const Expr& expr, const tvm::ffi::Map& binds, const tvm::ffi::Map& symbolic_var_map = {}); /*! - * \brief Bind the symbolic variables to a StructInfo. This is a helper function usually called by + * \brief Bind the symbolic variables to a Type. This is a helper function usually called by * other pass functions to help optimizations. */ -TVM_DLL StructInfo Bind(const StructInfo& sinfo, - const tvm::ffi::Map& symbolic_var_map); +TVM_DLL Type Bind(const Type& ty, const tvm::ffi::Map& symbolic_var_map); /*! * \brief Infer a binding map for symbolic variables @@ -78,10 +77,10 @@ TVM_DLL tvm::ffi::Map InferSymbolicVarMap( const tvm::ffi::Map& binds, const arith::Analyzer& analyzer); /*! - * \brief Check if the given StructInfo is for a boolean scalar (tensor of rank 0 with a boolean + * \brief Check if the given Type is for a boolean scalar (tensor of rank 0 with a boolean * dtype). * - * \param sinfo The input StructInfo. + * \param ty The input Type. * \param permit_unknown_rank If true, it will permit the input type to have unknown rank * (ndim of -1), which will require a dynamic check. * \param permit_unknown_dtype If true, it will permit the input type to have an unknown dtype @@ -90,8 +89,8 @@ TVM_DLL tvm::ffi::Map InferSymbolicVarMap( * \return True iff the input type is a boolean scalar type (or, depending on options, has unknown * rank or dtype) */ -TVM_DLL bool IsBoolStructInfo(const StructInfo& sinfo, bool permit_unknown_rank = true, - bool permit_unknown_dtype = true); +TVM_DLL bool IsBoolType(const Type& ty, bool permit_unknown_rank = true, + bool permit_unknown_dtype = true); /*! * \brief Check if the given expression is a "leaf" node or tuple node for normalization purposes. @@ -112,7 +111,7 @@ TVM_DLL bool IsLeafOrTuple(const Expr& expr); /*! * \brief Check if the given Call node is an impure operation. If the callee is a general - * expression, this simply requires checking the purity field of the FuncStructInfo. If it is an Op, + * expression, this simply requires checking the purity field of the FuncType. If it is an Op, * then this checks the `fPurity` field. * * \param call The input call diff --git a/include/tvm/script/printer/config.h b/include/tvm/script/printer/config.h index 541d66f63526..6bd774cbf332 100644 --- a/include/tvm/script/printer/config.h +++ b/include/tvm/script/printer/config.h @@ -91,7 +91,7 @@ class PrinterConfigNode : public ffi::Object { * Keys are conventionally namespaced as ".", e.g.: * "tirx.prefix" — the TIR prefix (default "T") * "relax.prefix" — the Relax prefix (default "R") - * "relax.show_all_struct_info" — whether to show all struct info (default true) + * "relax.show_all_ty" — whether to show all struct info (default true) * * Use GetExtraConfig(key, fallback) to read values with a typed fallback. */ @@ -101,7 +101,7 @@ class PrinterConfigNode : public ffi::Object { * \brief Look up a value in extra_config with type cast and fallback. * * Keys are conventionally namespaced as "." - * (e.g. "tirx.prefix", "relax.show_all_struct_info"). + * (e.g. "tirx.prefix", "relax.show_all_ty"). */ template T GetExtraConfig(const ffi::String& key, T fallback) const { diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py index 4ff3f0812a3b..aa07b7d43adb 100644 --- a/python/tvm/contrib/cutlass/build.py +++ b/python/tvm/contrib/cutlass/build.py @@ -379,21 +379,21 @@ def _extract_relax_function_signature(f): signature = {} for i, arg in enumerate(f.params): - sinfo = arg.struct_info - if isinstance(sinfo, relax.TensorStructInfo): - signature[f"arg{i}_shape"] = get_const_tuple(sinfo.shape) - signature[f"arg{i}_dtype"] = sinfo.dtype - elif isinstance(sinfo, relax.ShapeStructInfo): - signature[f"arg{i}_shape"] = get_const_tuple(sinfo.values) + ty = arg.ty + if isinstance(ty, relax.TensorType): + signature[f"arg{i}_shape"] = get_const_tuple(ty.shape) + signature[f"arg{i}_dtype"] = ty.dtype + elif isinstance(ty, relax.ShapeType): + signature[f"arg{i}_shape"] = get_const_tuple(ty.values) else: raise NotImplementedError() - ret_sinfo = f.ret_struct_info - if ret_sinfo.shape is not None: - signature["ret_shape"] = get_const_tuple(ret_sinfo.shape) + ret_ty = f.ret_ty + if ret_ty.shape is not None: + signature["ret_shape"] = get_const_tuple(ret_ty.shape) else: signature["ret_shape"] = None - signature["ret_dtype"] = ret_sinfo.dtype + signature["ret_dtype"] = ret_ty.dtype return signature @@ -714,12 +714,12 @@ def handle_attention(self, f, op_type): if "stacked_attention" in op_type: arg["arg0_dtype"] = signature["arg0_dtype"] - q_shape = get_const_tuple(attention_node.args[0].struct_info.shape) - k_shape = get_const_tuple(attention_node.args[1].struct_info.shape) - v_shape = get_const_tuple(attention_node.args[2].struct_info.shape) + q_shape = get_const_tuple(attention_node.args[0].ty.shape) + k_shape = get_const_tuple(attention_node.args[1].ty.shape) + v_shape = get_const_tuple(attention_node.args[2].ty.shape) if len(attention_node.args) == 4: - arg["bias_shape"] = get_const_tuple(attention_node.args[3].struct_info.shape) - arg["bias_dtype"] = attention_node.args[3].struct_info.dtype + arg["bias_shape"] = get_const_tuple(attention_node.args[3].ty.shape) + arg["bias_dtype"] = attention_node.args[3].ty.dtype qkv_layout = "qkv_stacked" else: @@ -803,7 +803,7 @@ def handle_norm(self, f, op_type): def visit_function_(self, f): if "Composite" not in f.attrs: body = super().visit_expr(f.body) - return relax.Function(f.params, body, f.ret_struct_info, f.is_pure, f.attrs, f.span) + return relax.Function(f.params, body, f.ret_ty, f.is_pure, f.attrs, f.span) op_type = f.attrs["Composite"] diff --git a/python/tvm/contrib/hexagon/generate_take_op.py b/python/tvm/contrib/hexagon/generate_take_op.py index 080a7d6a1953..372d71998913 100644 --- a/python/tvm/contrib/hexagon/generate_take_op.py +++ b/python/tvm/contrib/hexagon/generate_take_op.py @@ -70,7 +70,7 @@ def visit_call_(self, call_node: relax.Call) -> relax.Call: var = call_node.args[0] func = self.mod_[var] - if call_node.args[1][0].struct_info.dtype == "uint8": + if call_node.args[1][0].ty.dtype == "uint8": if op_replace(call_node, func): inp, inp_scale, inp_zp, out_scale, out_zp = [x for x in call_node.args[1]] # LUT node creation @@ -78,7 +78,7 @@ def visit_call_(self, call_node: relax.Call) -> relax.Call: inp_scale, inp_zp, out_scale, out_zp, call_node.args[0].name_hint ) # Take operation node creation - take_func = hexagon_unary_ops.generate_take_primfunc(inp, call_node.struct_info) + take_func = hexagon_unary_ops.generate_take_primfunc(inp, call_node.ty) take_func = take_func.without_attr("global_symbol") take_func_gv = self.builder_.add_func(take_func, "take") take_node = relax.call_tir( @@ -86,7 +86,7 @@ def visit_call_(self, call_node: relax.Call) -> relax.Call: relax.expr.Tuple( [call_node.args[1][0], relax.expr.Constant(tvm.runtime.tensor(LUT))] ), - call_node.struct_info, + call_node.ty, ) return take_node return call_node diff --git a/python/tvm/contrib/hexagon/hexagon_unary_ops.py b/python/tvm/contrib/hexagon/hexagon_unary_ops.py index 92fdeed353c9..2a4aecaa0297 100644 --- a/python/tvm/contrib/hexagon/hexagon_unary_ops.py +++ b/python/tvm/contrib/hexagon/hexagon_unary_ops.py @@ -83,16 +83,16 @@ def LUT_generation(inp_scale, inp_zp, out_scale, out_zp, op_name) -> None: return LUT -def generate_take_primfunc(inp, struct_info): +def generate_take_primfunc(inp, ty): # Generating the take op - N, H, W, C = inp.struct_info.shape - data = te.placeholder((N, H, W, C), dtype=struct_info.dtype, name="data") + N, H, W, C = inp.ty.shape + data = te.placeholder((N, H, W, C), dtype=ty.dtype, name="data") LUT_func = te.placeholder((256,), dtype="uint8", name="LUT") take = te.compute( - struct_info.shape, - lambda *indices: saturate( - (LUT_func[data[indices].astype("uint8")]), struct_info.dtype - ).astype(struct_info.dtype), + ty.shape, + lambda *indices: saturate((LUT_func[data[indices].astype("uint8")]), ty.dtype).astype( + ty.dtype + ), name="take_op", ) mod = te.create_prim_func([data, LUT_func, take]) diff --git a/python/tvm/ir/expr.py b/python/tvm/ir/expr.py index c2107c9a8f01..4fbebeddd0f5 100644 --- a/python/tvm/ir/expr.py +++ b/python/tvm/ir/expr.py @@ -17,7 +17,6 @@ """Common expressions data structures in the IR.""" from numbers import Number -from typing import Optional import tvm_ffi @@ -33,6 +32,7 @@ class BaseExpr(Node): """Base class of all the expressions.""" span: Span | None + ty: "tvm.ir.Type | None" @tvm_ffi.register_object("ir.PrimExpr") @@ -50,17 +50,6 @@ class PrimExpr(BaseExpr): class RelaxExpr(BaseExpr): """Base class of all non-primitive expressions.""" - @property - def struct_info(self) -> Optional["tvm.relax.StructInfo"]: - """Get the struct info field - - Returns - ------- - struct_info : tvm.relax.StructInfo - The struct info if available. - """ - return _ffi_api.ExprStructInfo(self) - @tvm_ffi.register_object("ir.GlobalVar") class GlobalVar(RelaxExpr): diff --git a/python/tvm/ir/type.py b/python/tvm/ir/type.py index 3ade4b80fc0d..567ebafa2d5c 100644 --- a/python/tvm/ir/type.py +++ b/python/tvm/ir/type.py @@ -81,8 +81,8 @@ class TupleType(Type): The fields in the tuple """ - def __init__(self, fields): - self.__init_handle_by_constructor__(_ffi_api.TupleType, fields) + def __init__(self, fields, span=None): + self.__init_handle_by_constructor__(_ffi_api.TupleType, fields, span) @tvm_ffi.register_object("ir.FuncType") diff --git a/python/tvm/relax/__init__.py b/python/tvm/relax/__init__.py index 3784c6a63de4..3eea8b0b023d 100644 --- a/python/tvm/relax/__init__.py +++ b/python/tvm/relax/__init__.py @@ -51,9 +51,10 @@ from .expr import const, extern, get_shape_of # Type -from .ty import ( +from .type import ( Type, ObjectType, + PrimType, ShapeType, TensorType, TupleType, @@ -79,17 +80,6 @@ # ExprFunctor from .expr_functor import ExprFunctor, PyExprVisitor, PyExprMutator -# StructInfo -from .struct_info import ( - StructInfo, - ObjectStructInfo, - PrimStructInfo, - ShapeStructInfo, - TensorStructInfo, - TupleStructInfo, - FuncStructInfo, -) - # pipeline from .pipeline import get_default_pipeline from .pipeline import get_pipeline @@ -105,11 +95,11 @@ from . import exec_builder from . import expr from . import ty +from . import type from . import analysis from . import transform from . import block_builder from . import op -from . import struct_info from . import backend from . import training from . import distributed diff --git a/python/tvm/relax/analysis/__init__.py b/python/tvm/relax/analysis/__init__.py index fa466d93d18d..d9f9fa72e6ad 100644 --- a/python/tvm/relax/analysis/__init__.py +++ b/python/tvm/relax/analysis/__init__.py @@ -26,9 +26,9 @@ collect_non_negative_expressions, computable_at_compile_time, contains_impure_call, - definable_tir_vars_in_struct_info, + definable_tir_vars_in_type, defined_symbolic_vars, - derive_call_ret_struct_info, + derive_call_ret_type, detect_recursion, erase_to_well_defined, free_symbolic_vars, @@ -40,10 +40,10 @@ name_to_binding, post_order_visit, remove_all_unused, - struct_info_base_check, - struct_info_lca, + type_base_check, + type_lca, suggest_layout_transforms, - tir_vars_in_struct_info, + tir_vars_in_type, udchain, well_formed, ) diff --git a/python/tvm/relax/analysis/analysis.py b/python/tvm/relax/analysis/analysis.py index 5ffb2f5ec847..3abd47d52159 100644 --- a/python/tvm/relax/analysis/analysis.py +++ b/python/tvm/relax/analysis/analysis.py @@ -26,44 +26,44 @@ import tvm from tvm import IRModule, tirx +from tvm.ir import Type from tvm.relax.expr import Binding, Call, DataflowBlock, Expr, Function, GlobalVar, Var -from tvm.relax.struct_info import FuncStructInfo, StructInfo -from tvm.relax.ty import Type +from tvm.relax.type import FuncType from tvm.tirx import Buffer, IndexMap, PrimFunc, SBlock from . import _ffi_api -def get_static_type(sinfo: StructInfo) -> Type: - """Get the corresponding static type from a StructInfo. +def get_static_type(ty: Type) -> Type: + """Get the corresponding static type from a Type. Parameters ---------- - sinfo : StructInfo - The input struct info. + ty : Type + The input type. Returns ------- ret : Type The corresponding static type. """ - return _ffi_api.GetStaticType(sinfo) # type: ignore + return _ffi_api.GetStaticType(ty) # type: ignore def erase_to_well_defined( - sinfo: StructInfo, + ty: Type, shape_var_map: dict[tirx.Var, tirx.PrimExpr] | None = None, var_map: dict[Var, Expr] | None = None, -) -> StructInfo: - """Erase sinfo into a well defined form. +) -> Type: + """Erase ty into a well defined form. - This function removes the StructInfo's dependencies on shape and vars that + This function removes the Type's dependencies on shape and vars that are not defined in given maps. Parameters ---------- - sinfo : StructInfo - The input struct info. + ty : Type + The input type. shape_var_map : Dict[tirx.Var, tirx.PrimExpr] Specifies the defined shape vars and the values they should map to. @@ -73,13 +73,13 @@ def erase_to_well_defined( Returns ------- - ret : StructInfo - The corresponding erased struct info. + ret : Type + The corresponding erased type. """ shape_var_map = {} if shape_var_map is None else shape_var_map var_map = {} if var_map is None else var_map - return _ffi_api.EraseToWellDefined(sinfo, shape_var_map, var_map) # type: ignore + return _ffi_api.EraseToWellDefined(ty, shape_var_map, var_map) # type: ignore class BaseCheckResult(IntEnum): @@ -100,33 +100,31 @@ class BaseCheckResult(IntEnum): PASS = 3 -def struct_info_base_check(base: StructInfo, derived: StructInfo) -> BaseCheckResult: +def type_base_check(base: Type, derived: Type) -> BaseCheckResult: """Run a base check to see if base subsumes derived. Parameters ---------- - base: StructInfo - The base struct info. + base: Type + The base type. - derived: StructInfo - The derived struct info. + derived: Type + The derived type. Returns ------- - ret : StructInfo - The derived return value struct info. + ret : Type + The derived return value type. """ - return _ffi_api.StructInfoBaseCheck(base, derived) # type: ignore + return _ffi_api.TypeBaseCheck(base, derived) # type: ignore -def derive_call_ret_struct_info( - func_sinfo: FuncStructInfo, call: Call, ctx: "tvm.relax.BlockBuilder" -) -> StructInfo: - """Derive the call's ret value struct info from inputs. +def derive_call_ret_type(func_ty: FuncType, call: Call, ctx: "tvm.relax.BlockBuilder") -> Type: + """Derive the call's ret value type from inputs. Parameters ---------- - func_sinfo: FuncStructInfo + func_ty: FuncType The call's function signature. call: Call @@ -137,96 +135,96 @@ def derive_call_ret_struct_info( Returns ------- - ret : StructInfo - The derived return value struct info. + ret : Type + The derived return value type. Note ---- This is an internal derivation function, call.op field is - ignored in this case and the derivation only depends on func_sinfo. + ignored in this case and the derivation only depends on func_ty. """ - return _ffi_api.DeriveCallRetStructInfo(func_sinfo, call, ctx) # type: ignore + return _ffi_api.DeriveCallRetType(func_ty, call, ctx) # type: ignore -def struct_info_lca(lhs: StructInfo, rhs: StructInfo) -> StructInfo: - """Unify the two struct info to their least common ancestor. +def type_lca(lhs: Type, rhs: Type) -> Type: + """Unify the two type to their least common ancestor. Parameters ---------- - lhs: StructInfo + lhs: Type The left operand. - rhs: StructInfo + rhs: Type The right operand. Returns ------- - ret : StructInfo + ret : Type The corresponding lca result. """ - return _ffi_api.StructInfoLCA(lhs, rhs) # type: ignore + return _ffi_api.TypeLCA(lhs, rhs) # type: ignore -def tir_vars_in_struct_info(sinfo: StructInfo) -> list[tirx.Var]: - """Get the TIR variables that appear in the input struct info. +def tir_vars_in_type(ty: Type) -> list[tirx.Var]: + """Get the TIR variables that appear in the input type. The returned list is deduplicated - each TIR variable will appear at most once. Parameters ---------- - sinfo : StructInfo - The struct info object to be analyzed. + ty : Type + The type object to be analyzed. Returns ------- ret : List[tirx.Var] - The list of TIR variables that appear in the input struct info. + The list of TIR variables that appear in the input type. """ - return _ffi_api.TIRVarsInStructInfo(sinfo) # type: ignore + return _ffi_api.TIRVarsInType(ty) # type: ignore -def definable_tir_vars_in_struct_info(sinfo: StructInfo) -> list[tirx.Var]: - """Get the TIR variables that may be defined from input struct info. +def definable_tir_vars_in_type(ty: Type) -> list[tirx.Var]: + """Get the TIR variables that may be defined from input type. The returned list is deduplicated - each TIR variable will appear at most once. Parameters ---------- - sinfo : StructInfo - The struct info object to be analyzed. + ty : Type + The type object to be analyzed. Returns ------- ret : List[tirx.Var] - The list of TIR variables that can be defined from the StructInfo + The list of TIR variables that can be defined from the Type """ - return _ffi_api.DefinableTIRVarsInStructInfo(sinfo) # type: ignore + return _ffi_api.DefinableTIRVarsInType(ty) # type: ignore -def collect_non_negative_expressions(sinfo: StructInfo) -> list[tirx.PrimExpr]: +def collect_non_negative_expressions(ty: Type) -> list[tirx.PrimExpr]: """Collect TIR expressions used in non-negative contexts Get TIR variables that are non-negative within the context where - the struct info is used. For example, any expression used as a + the type is used. For example, any expression used as a tensor shape. The returned list is deduplicated - each TIR expression will appear at most once. The order of the list is in the order of - occurrence within the struct info. + occurrence within the type. Parameters ---------- - sinfo : StructInfo - The struct info object to be analyzed. + ty : Type + The type object to be analyzed. Returns ------- ret : List[tirx.Var] - The list of TIR variables that can be defined from the StructInfo + The list of TIR variables that can be defined from the Type """ - return _ffi_api.CollectNonNegativeExpressions(sinfo) # type: ignore + return _ffi_api.CollectNonNegativeExpressions(ty) # type: ignore def defined_symbolic_vars(func: Function) -> list[Var]: @@ -413,7 +411,7 @@ def contains_impure_call(expr: Expr, own_name: Var | GlobalVar | None = None) -> Notes ----- - Relies on StructInfo annotations, so ensure that the module has been normalized first. + Relies on Type annotations, so ensure that the module has been normalized first. Also, an impure call in a *nested* function does *not* mean that the outer expression contains an impure call--it only does if the nested function is *later called*. """ @@ -481,7 +479,7 @@ def remove_all_unused(func: Function) -> Function: return _ffi_api.remove_all_unused(func) # type: ignore -def well_formed(obj: IRModule | Function, check_struct_info: bool = True) -> None: +def well_formed(obj: IRModule | Function, check_ty: bool = True) -> None: """Check if the IRModule is well formed, raising on the first violation. Raises an error (seeded with the offending node so a pass runner can report a @@ -493,20 +491,20 @@ def well_formed(obj: IRModule | Function, check_struct_info: bool = True) -> Non obj : Union[tvm.IRModule, Function] The input IRModule or relax.Function. - check_struct_info : bool + check_ty : bool A boolean flag indicating if the property "every Expr must - have defined structure info" will be checked. + have defined type information" will be checked. Note ---- - By default the structure info is always checked. It is only in test cases - where `check_struct_info` might be false, so that other well-formed requirements - will be well tested and will not be blocked by not having structure info. + By default the type information is always checked. It is only in test cases + where `check_ty` might be false, so that other well-formed requirements + will be well tested and will not be blocked by not having type information. """ - _ffi_api.well_formed(obj, check_struct_info) # type: ignore + _ffi_api.well_formed(obj, check_ty) # type: ignore -def check_well_formed(obj: IRModule | Function, check_struct_info: bool = True) -> bool: +def check_well_formed(obj: IRModule | Function, check_ty: bool = True) -> bool: """Return whether the IRModule or Function is well formed. Wraps :func:`well_formed`, returning False instead of raising on the first violation. @@ -516,16 +514,16 @@ def check_well_formed(obj: IRModule | Function, check_struct_info: bool = True) obj : Union[tvm.IRModule, Function] The input IRModule or relax.Function. - check_struct_info : bool + check_ty : bool A boolean flag indicating if the property "every Expr must - have defined structure info" will be checked. + have defined type information" will be checked. Returns ------- ret: bool True if the IRModule is well formed, False if not. """ - return _ffi_api.check_well_formed(obj, check_struct_info) # type: ignore + return _ffi_api.check_well_formed(obj, check_ty) # type: ignore def _get_prim_func_default_dtype(func: PrimFunc): diff --git a/python/tvm/relax/backend/adreno/clml.py b/python/tvm/relax/backend/adreno/clml.py index 2b3c78d6a605..8797fe2fa987 100644 --- a/python/tvm/relax/backend/adreno/clml.py +++ b/python/tvm/relax/backend/adreno/clml.py @@ -64,7 +64,7 @@ def visit_tuple_getitem_(self, op: TupleGetItem): bn_call = self.bn_vars[tuple_value] if op.index == 0: bn_out = relax.TupleGetItem(bn_call, 0) - input_shape = bn_call.args[0].struct_info.shape + input_shape = bn_call.args[0].ty.shape return relax.Call(reshape_op, [bn_out, input_shape]) return super().visit_tuple_getitem_(op) @@ -137,13 +137,13 @@ def _check_conv2d(context: PatternCheckContext) -> bool: if "data" in context.annotated_expr: input_expr = context.annotated_expr["data"] - input_dtype = input_expr.struct_info.dtype + input_dtype = input_expr.ty.dtype if input_dtype not in ["float32", "float16"]: return False if "weight" in context.annotated_expr: weight_expr = context.annotated_expr["weight"] - weight_dtype = weight_expr.struct_info.dtype + weight_dtype = weight_expr.ty.dtype if weight_dtype not in ["float32", "float16"]: return False @@ -244,7 +244,7 @@ def _check_maxpool2d(context: PatternCheckContext) -> bool: return False data = context.annotated_expr["data"] - input_shape = data.struct_info.shape + input_shape = data.ty.shape if len(input_shape) != 4: return False @@ -311,7 +311,7 @@ def _check_avgpool2d(context: PatternCheckContext) -> bool: return False data = context.annotated_expr["data"] - input_shape = data.struct_info.shape + input_shape = data.ty.shape if len(input_shape) != 4: return False @@ -371,7 +371,7 @@ def _check_global_avgpool(context: PatternCheckContext) -> bool: return False data = context.annotated_expr["data"] - input_shape = data.struct_info.shape + input_shape = data.ty.shape if len(input_shape) != 4: return False @@ -451,8 +451,8 @@ def _check_batchnorm(context: PatternCheckContext) -> bool: base_shape = None for param in params.values(): - shape = param.struct_info.shape - dtype = param.struct_info.dtype + shape = param.ty.shape + dtype = param.ty.dtype if dtype not in {"float32"}: return False @@ -496,8 +496,8 @@ def batch_norm_pattern(): def _check_binary_op(context: PatternCheckContext) -> bool: def _check_arg(input_expr): - input_dtype = input_expr.struct_info.dtype - input_shape = input_expr.struct_info.shape + input_dtype = input_expr.ty.dtype + input_shape = input_expr.ty.shape if len(input_shape) == 0: return False @@ -523,13 +523,13 @@ def compare_shapes(lhs_shape, rhs_shape): rhs_shape = None if "lhs" in context.annotated_expr: lhs = context.annotated_expr["lhs"] - lhs_shape = lhs.struct_info.shape + lhs_shape = lhs.ty.shape if not _check_arg(lhs): return False if "rhs" in context.annotated_expr: rhs = context.annotated_expr["rhs"] - rhs_shape = rhs.struct_info.shape + rhs_shape = rhs.ty.shape if not _check_arg(rhs): return False @@ -630,7 +630,7 @@ def _check_dequantize_matmul(ctx: relax.transform.PatternCheckContext) -> bool: wdq = ctx.annotated_expr["w_decoded"] w_pack = ctx.annotated_expr["w_encoded"] - if ctx.annotated_expr["lhs"].struct_info.dtype != "float16": + if ctx.annotated_expr["lhs"].ty.dtype != "float16": return False if not isinstance(wdq, relax.Call): return False @@ -639,17 +639,17 @@ def _check_dequantize_matmul(ctx: relax.transform.PatternCheckContext) -> bool: return False if not ( - (len(root.struct_info.shape) == 3) - and isinstance(root.struct_info.shape[0], tirx.IntImm) - and (root.struct_info.dtype == "float16") - and (root.struct_info.shape[0] == 1) + (len(root.ty.shape) == 3) + and isinstance(root.ty.shape[0], tirx.IntImm) + and (root.ty.dtype == "float16") + and (root.ty.shape[0] == 1) ): return False if not ( - (len(wdq.struct_info.shape) == 2) - and (w_pack.struct_info.shape[-1] == root.struct_info.shape[-1]) - and (wdq.struct_info.shape[-2] == _input.struct_info.shape[-1]) + (len(wdq.ty.shape) == 2) + and (w_pack.ty.shape[-1] == root.ty.shape[-1]) + and (wdq.ty.shape[-2] == _input.ty.shape[-1]) ): return False diff --git a/python/tvm/relax/backend/contrib/example_npu/patterns.py b/python/tvm/relax/backend/contrib/example_npu/patterns.py index 17d6656ef899..41ae18535798 100644 --- a/python/tvm/relax/backend/contrib/example_npu/patterns.py +++ b/python/tvm/relax/backend/contrib/example_npu/patterns.py @@ -63,7 +63,7 @@ def _check_npu_memory_constraints( Placeholder for NPU memory hierarchy constraint checking. A real implementation would inspect the annotated expression's - TensorStructInfo to verify the tensor fits within the NPU's + TensorType to verify the tensor fits within the NPU's on-chip SRAM (L1) or compute memory (L2/CMX). Tensors that exceed on-chip capacity require tiling before offload. """ diff --git a/python/tvm/relax/backend/cuda/cublas.py b/python/tvm/relax/backend/cuda/cublas.py index a0907a76784c..eef6c70ad96a 100644 --- a/python/tvm/relax/backend/cuda/cublas.py +++ b/python/tvm/relax/backend/cuda/cublas.py @@ -59,20 +59,20 @@ def _check_matmul(context: PatternCheckContext) -> bool: scale = context.annotated_expr["scale"] zero_point = context.annotated_expr["zp"] # Only scalar values for scale and zero_point are supported. - if scale.struct_info.ndim != 0 or zero_point.struct_info.ndim != 0: + if scale.ty.ndim != 0 or zero_point.ty.ndim != 0: return False # Only zero_point == 0.0 is supported. if zero_point.data.numpy()[()].item() != 0.0: return False - lhs_dtype = lhs.struct_info.dtype - rhs_dtype = rhs.struct_info.dtype - out_dtype = matmul_call.struct_info.dtype + lhs_dtype = lhs.ty.dtype + rhs_dtype = rhs.ty.dtype + out_dtype = matmul_call.ty.dtype if not _is_supported_dtype(lhs_dtype, rhs_dtype, out_dtype): return False - lhs_shape = lhs.struct_info.shape.values - rhs_shape = rhs.struct_info.shape.values + lhs_shape = lhs.ty.shape.values + rhs_shape = rhs.ty.shape.values if not isinstance(lhs_shape[-1], tvm.tirx.expr.IntImm | int): # Reduction axis must be constant @@ -120,7 +120,7 @@ def _check_matmul(context: PatternCheckContext) -> bool: # Non-default epilogue not supported for IGEMM return False bias = context.annotated_expr["bias"] - bias_shape = bias.struct_info.shape.values + bias_shape = bias.ty.shape.values bias_batches = reduce(operator.mul, bias_shape[:-1], 1) if not isinstance(bias_batches, tvm.tirx.expr.IntImm | int) or int(bias_batches) > 1: # cuBLAS only supports bias vector diff --git a/python/tvm/relax/backend/cuda/cudnn.py b/python/tvm/relax/backend/cuda/cudnn.py index 2be11fd04d47..2df31ca62f5d 100644 --- a/python/tvm/relax/backend/cuda/cudnn.py +++ b/python/tvm/relax/backend/cuda/cudnn.py @@ -53,8 +53,8 @@ def _check_conv2d(context: PatternCheckContext) -> bool: weight_expr = context.annotated_expr["weight"] # Check if the data types of input and weights are supported by cuDNN BYOC - input_dtype = input_expr.struct_info.dtype - weight_dtype = weight_expr.struct_info.dtype + input_dtype = input_expr.ty.dtype + weight_dtype = weight_expr.ty.dtype if not _is_supported_dtype(input_dtype, weight_dtype): return False @@ -71,14 +71,14 @@ def _check_stacked_attention(context: PatternCheckContext, layout: str) -> bool: if has_leaking_intermediate_variables(context): return False if layout == "BS3NH": - if not context.annotated_expr["stacked_qkv"].struct_info.ndim == 3: + if not context.annotated_expr["stacked_qkv"].ty.ndim == 3: return False if "split" in context.annotated_expr: split_op = context.annotated_expr["split"] if not split_op.attrs.axis == 2: return False elif layout == "SBN3H": - if not context.annotated_expr["stacked_qkv"].struct_info.ndim == 4: + if not context.annotated_expr["stacked_qkv"].ty.ndim == 4: return False if "split" in context.annotated_expr: split_op = context.annotated_expr["split"] @@ -167,7 +167,7 @@ def __init__(self, mod): def visit_function_(self, f): if "Composite" not in f.attrs: body = super().visit_expr(f.body) - new_f = relax.Function(f.params, body, f.ret_struct_info, f.is_pure, f.attrs, f.span) + new_f = relax.Function(f.params, body, f.ret_ty, f.is_pure, f.attrs, f.span) if "global_symbol" in f.attrs and "cudnn" in f.attrs["global_symbol"]: composite_func = body.blocks[0].bindings[0].value @@ -178,8 +178,8 @@ def visit_function_(self, f): if "attention" in f.attrs["Composite"] and "cudnn" in f.attrs["Composite"]: # Workspace is needed only for larger head sizes, but for simplicity we always allocate. - out_dtype = f.ret_struct_info.dtype - out_size_1d = _shape_1d(f.ret_struct_info.shape) + out_dtype = f.ret_ty.dtype + out_size_1d = _shape_1d(f.ret_ty.shape) # This needs to be in sync with the actual value that the kernel expects. workspace_size_bytes = out_size_1d * {"float16": 2, "float32": 4}[out_dtype] if not isinstance(workspace_size_bytes, int | tvm.tirx.expr.IntImm): diff --git a/python/tvm/relax/backend/cuda/cutlass.py b/python/tvm/relax/backend/cuda/cutlass.py index b15b3b698508..3e6e936d8c07 100644 --- a/python/tvm/relax/backend/cuda/cutlass.py +++ b/python/tvm/relax/backend/cuda/cutlass.py @@ -105,8 +105,8 @@ def _check_residual(root_call: Call, context: PatternCheckContext) -> bool: # If residual depends on the result of the root call, this cannot be handled by cutlass. return False - shape1 = root_var.struct_info.shape - shape2 = residual.struct_info.shape + shape1 = root_var.ty.shape + shape2 = residual.ty.shape out_channel = shape1[-1] if not _is_same_shape(shape1, shape2) and not _is_bias_like(shape2, out_channel): @@ -127,7 +127,7 @@ def _check_conv2d(context: PatternCheckContext) -> bool: if ( data_layout != "NHWC" or kernel_layout != "OHWI" - or not _is_supported_dtype(data.struct_info.dtype, weight.struct_info.dtype) + or not _is_supported_dtype(data.ty.dtype, weight.ty.dtype) ): return False @@ -135,13 +135,13 @@ def _check_conv2d(context: PatternCheckContext) -> bool: return False # Check if any dimensions are symbolic. - for dim in data.struct_info.shape.values: + for dim in data.ty.shape.values: if isinstance(dim, tvm.tirx.Var): return False # pylint: disable=invalid-name - IC = data.struct_info.shape.values[3] - OC = weight.struct_info.shape.values[0] + IC = data.ty.shape.values[3] + OC = weight.ty.shape.values[0] # not depthwise conv2d return not IC == OC == conv2d_call.attrs.groups @@ -154,16 +154,16 @@ def _check_matmul(context: PatternCheckContext) -> bool: lhs = context.annotated_expr["lhs"] rhs = context.annotated_expr["rhs"] - lhs_dtype = lhs.struct_info.dtype - rhs_dtype = rhs.struct_info.dtype + lhs_dtype = lhs.ty.dtype + rhs_dtype = rhs.ty.dtype if not _is_supported_dtype(lhs_dtype, rhs_dtype): return False if not _check_residual(context.annotated_expr["root"], context): return False - lhs_shape = lhs.struct_info.shape.values - rhs_shape = rhs.struct_info.shape.values + lhs_shape = lhs.ty.shape.values + rhs_shape = rhs.ty.shape.values return is_shape_valid_for_cutlass_matmul(lhs_shape, rhs_shape) @@ -223,22 +223,22 @@ def _check_decode_matmul(ctx): return False # out_dtype = "float32" not supported unless matmul is followed by cast to fp16. - if root.struct_info.dtype == "float32": + if root.ty.dtype == "float32": return False call_tir_decode = ctx.annotated_expr["w_decoded"] if "decode" not in call_tir_decode.args[0].name_hint: return False - N = root.struct_info.shape[-1] + N = root.ty.shape[-1] - if ctx.annotated_expr["lhs"].struct_info.dtype != "float16": + if ctx.annotated_expr["lhs"].ty.dtype != "float16": return False # weight needs to be packed to int8. packed_weight = ctx.annotated_expr["w_encoded"] - if packed_weight.struct_info.dtype != "int8": + if packed_weight.ty.dtype != "int8": return False # The kernel expects the weight to be preprocessed by this packed function. @@ -251,16 +251,16 @@ def _check_decode_matmul(ctx): scales = ctx.annotated_expr["scales"] - if scales.struct_info.dtype != "float16": + if scales.ty.dtype != "float16": return False # scale shape needs to be (N,) or (1, N) or (K // group_size, N) - if len(scales.struct_info.shape) > 2 or scales.struct_info.shape[-1] != N: + if len(scales.ty.shape) > 2 or scales.ty.shape[-1] != N: return False if "bias" in ctx.annotated_expr: - out_shape = root.struct_info.shape - bias_shape = ctx.annotated_expr["bias"].struct_info.shape + out_shape = root.ty.shape + bias_shape = ctx.annotated_expr["bias"].ty.shape # bias shape needs to be (N,), possibly with additional axes on the front. # It can also have the same shape as the output. @@ -378,7 +378,7 @@ def _check_stacked_attention(context: PatternCheckContext) -> bool: """Check if the given stacked attention workload can be offloaded to CUTLASS.""" if has_leaking_intermediate_variables(context): return False - if not context.annotated_expr["stacked_qkv"].struct_info.ndim == 3: + if not context.annotated_expr["stacked_qkv"].ty.ndim == 3: return False if "split" in context.annotated_expr: split_op = context.annotated_expr["split"] @@ -458,7 +458,7 @@ def _check_layer_norm(context: PatternCheckContext) -> bool: return False axis = int(attrs.axes[0]) - rank = len(context.matched_expr.struct_info.shape) + rank = len(context.matched_expr.ty.shape) if axis < 0: axis += rank @@ -536,7 +536,7 @@ def __init__(self, mod): def visit_function_(self, f): if "Composite" not in f.attrs: body = super().visit_expr(f.body) - new_f = Function(f.params, body, f.ret_struct_info, f.is_pure, f.attrs, f.span) + new_f = Function(f.params, body, f.ret_ty, f.is_pure, f.attrs, f.span) if "global_symbol" in f.attrs and "cutlass" in f.attrs["global_symbol"]: composite_func = body.blocks[0].bindings[0].value @@ -547,8 +547,8 @@ def visit_function_(self, f): if "attention" in f.attrs["Composite"] and "cutlass" in f.attrs["Composite"]: # Workspace is needed only for larger head sizes, but for simplicity we always allocate. - out_dtype = f.ret_struct_info.dtype - out_size_1d = _shape_1d(f.ret_struct_info.shape) + out_dtype = f.ret_ty.dtype + out_size_1d = _shape_1d(f.ret_ty.shape) # This needs to be in sync with the actual value that the kernel expects. workspace_size_bytes = out_size_1d * {"float16": 2, "float32": 4}[out_dtype] if not isinstance(workspace_size_bytes, int | tvm.tirx.expr.IntImm): diff --git a/python/tvm/relax/backend/dispatch_sampling.py b/python/tvm/relax/backend/dispatch_sampling.py index 780fcd0cc02c..a82b8fd83709 100644 --- a/python/tvm/relax/backend/dispatch_sampling.py +++ b/python/tvm/relax/backend/dispatch_sampling.py @@ -41,7 +41,7 @@ def visit_call_(self, call: relax.Call) -> relax.Expr: ) prob, uniform_sample, sample_indices = call.args - tgt = self._get_target(call.struct_info) + tgt = self._get_target(call.ty) dtype = call.attrs.dtype _, prob_dtype = self.get_shape_dtype(prob) sample_shape, sample_dtype = self.get_shape_dtype(uniform_sample) @@ -63,7 +63,7 @@ def visit_call_(self, call: relax.Call) -> relax.Expr: return relax.call_tir( gv, [prob, uniform_sample, sample_indices], - out_sinfo=call.struct_info, + out_ty=call.ty, ) else: cumsum_prob = relax.op.cumsum(prob, axis=1, dtype=prob_dtype, exclusive=False) @@ -74,7 +74,7 @@ def visit_call_(self, call: relax.Call) -> relax.Expr: return relax.call_tir( gv, [cumsum_prob, uniform_sample, sample_indices], - out_sinfo=call.struct_info, + out_ty=call.ty, ) return super().visit_call_(call) diff --git a/python/tvm/relax/backend/dispatch_sort_scan.py b/python/tvm/relax/backend/dispatch_sort_scan.py index c901aceed3e3..0656f65a43bc 100644 --- a/python/tvm/relax/backend/dispatch_sort_scan.py +++ b/python/tvm/relax/backend/dispatch_sort_scan.py @@ -80,16 +80,16 @@ def visit_call_(self, call: relax.Call) -> relax.Expr: input_tensor = call.args[0] boundaries = call.args[1] right = call.attrs.right - tgt = self._get_target(call.struct_info) + tgt = self._get_target(call.ty) te_func = topi.searchsorted with tgt: if self.is_gpu_target(tgt): te_func = topi.gpu.searchsorted return self.builder_.call_te( - te_func, boundaries, input_tensor, right, input_tensor.struct_info.dtype + te_func, boundaries, input_tensor, right, input_tensor.ty.dtype ) if call.op.name == "relax.sort": - tgt = self._get_target(call.struct_info) + tgt = self._get_target(call.ty) te_func = topi.sort kwargs = {} with tgt: @@ -102,7 +102,7 @@ def visit_call_(self, call: relax.Call) -> relax.Expr: te_func, call.args[0], call.attrs.axis, not call.attrs.descending, **kwargs ) if call.op.name == "relax.argsort": - tgt = self._get_target(call.struct_info) + tgt = self._get_target(call.ty) te_func = topi.argsort kwargs = {} with tgt: @@ -120,7 +120,7 @@ def visit_call_(self, call: relax.Call) -> relax.Expr: **kwargs, ) if call.op.name == "relax.topk": - tgt = self._get_target(call.struct_info) + tgt = self._get_target(call.ty) te_func = topi.topk kwargs = {} if can_use_thrust(tgt, "tvm.contrib.thrust.sort"): @@ -141,9 +141,9 @@ def visit_call_(self, call: relax.Call) -> relax.Expr: self._append_calls_to_update(tir_call, tgt) return tir_call if call.op.name in ("relax.cumprod", "relax.cumsum"): - tgt = self._get_target(call.struct_info) + tgt = self._get_target(call.ty) axis = int(call.attrs.axis) if call.attrs.axis is not None else call.attrs.axis - shape = call.struct_info.shape + shape = call.ty.shape # TODO(tvm-team): Support fully dynamic case with `shape=None` if shape is None: raise ValueError("non-symbolic shape is not supported for now") @@ -163,7 +163,7 @@ def visit_call_(self, call: relax.Call) -> relax.Expr: dim = 1 for i in range(len(shape) - 1): dim *= shape[i] - in_dtype = call.args[0].struct_info.dtype + in_dtype = call.args[0].ty.dtype out_dtype = call.attrs.dtype out_dtype = out_dtype or in_dtype cumsum_2d_shape = relax.ShapeExpr([dim, shape[-1]]) @@ -171,7 +171,7 @@ def visit_call_(self, call: relax.Call) -> relax.Expr: "vm.builtin.reshape", call.args[0], cumsum_2d_shape, - sinfo_args=relax.TensorStructInfo(cumsum_2d_shape, out_dtype), + ty_args=relax.TensorType(cumsum_2d_shape, out_dtype), ) gv = self.builder_.add_func( gpu_2d_continuous_cumsum(in_dtype=in_dtype, out_dtype=out_dtype), @@ -180,13 +180,13 @@ def visit_call_(self, call: relax.Call) -> relax.Expr: cumsum = relax.call_tir( gv, reshape, - out_sinfo=relax.TensorStructInfo(cumsum_2d_shape, out_dtype), + out_ty=relax.TensorType(cumsum_2d_shape, out_dtype), ) return relax.call_pure_packed( "vm.builtin.reshape", cumsum, shape, - sinfo_args=call.struct_info, + ty_args=call.ty, ) with tgt: @@ -214,8 +214,8 @@ def estimate_thrust_workspace_size(self, call: relax.Call) -> int: """ Estimate the workspace size for thrust sort/argsort/topk/cumsum """ - input_shape = call.args[0].struct_info.shape - input_byte_per_elem = DataType(call.args[0].struct_info.dtype).bits // 8 + input_shape = call.args[0].ty.shape + input_byte_per_elem = DataType(call.args[0].ty.dtype).bits // 8 int64_byte_per_elem = DataType("int64").bits // 8 int32_byte_per_elem = DataType("int32").bits // 8 num_elem = reduce(mul, input_shape, 1) diff --git a/python/tvm/relax/backend/metal/coreml.py b/python/tvm/relax/backend/metal/coreml.py index dab375478ac2..d0b7ea3fc813 100644 --- a/python/tvm/relax/backend/metal/coreml.py +++ b/python/tvm/relax/backend/metal/coreml.py @@ -36,8 +36,8 @@ Var, VarBinding, ) -from tvm.relax.struct_info import PrimStructInfo, TensorStructInfo from tvm.relax.transform import PatternCheckContext +from tvm.relax.type import PrimType, TensorType from tvm.support.xcode import compile_coreml from ...expr_functor import PyExprVisitor, visitor @@ -355,14 +355,14 @@ def __init__(self, model_name, function): def visit_function_(self, op) -> None: for var in op.params: name = var.name_hint - sinfo = var.struct_info - if isinstance(sinfo, TensorStructInfo): - shape = [int(v) for v in list(sinfo.shape)] - elif isinstance(sinfo, PrimStructInfo): + ty = var.ty + if isinstance(ty, TensorType): + shape = [int(v) for v in list(ty.shape)] + elif isinstance(ty, PrimType): shape = [] else: - raise Exception("Currently not supported: ", type(sinfo)) - dtype = sinfo.dtype + raise Exception("Currently not supported: ", type(ty)) + dtype = ty.dtype self.model_inputs_.append((name, shape, dtype)) self.visit_expr(op.body) @@ -456,12 +456,12 @@ def compile(self, out_dir): input_desc = self.builder.spec.description.input input_desc[i].type.multiArrayType.dataType = FEATURE_TYPE_MAP[dtype] - output_dim = [int(n) for n in self.function.struct_info.ret.shape] + output_dim = [int(n) for n in self.function.ty.ret.shape] last_binding_var = self.function.body.blocks[0].bindings[-1].var self.builder.set_output(self.out_map[last_binding_var], [output_dim]) - for i, dtype in enumerate([self.function.struct_info.ret.dtype]): + for i, dtype in enumerate([self.function.ty.ret.dtype]): assert dtype in FEATURE_TYPE_MAP output_desc = self.builder.spec.description.output output_desc[i].type.multiArrayType.dataType = FEATURE_TYPE_MAP[dtype] diff --git a/python/tvm/relax/backend/patterns.py b/python/tvm/relax/backend/patterns.py index a328684b7c7e..06011d6f6e97 100644 --- a/python/tvm/relax/backend/patterns.py +++ b/python/tvm/relax/backend/patterns.py @@ -489,12 +489,12 @@ def handle_input(tensor, layout, transpose, repeat=False): transposed = is_op("relax.permute_dims")(reshaped) def rewriter(matchings, x): - if matchings[tensor].struct_info.ndim != 4: + if matchings[tensor].ty.ndim != 4: return None if list(matchings[permuted].attrs.axes) != [0, 2, 1, 3]: return None - before_reshape = matchings[permuted].struct_info.shape.values - after_reshape = matchings[shape].struct_info.values + before_reshape = matchings[permuted].ty.shape.values + after_reshape = matchings[shape].ty.values if not ( len(before_reshape) == 4 and len(after_reshape) == 3 @@ -503,7 +503,7 @@ def rewriter(matchings, x): return None if transpose and list(matchings[transposed].attrs.axes) != [0, 2, 1]: return None - return x, x.struct_info.shape + return x, x.ty.shape if transpose: return transposed, rewriter @@ -514,11 +514,11 @@ def rewriter(matchings, x): transposed = is_op("relax.permute_dims")(tensor) def rewriter(matchings, x): - if matchings[tensor].struct_info.ndim != 3: + if matchings[tensor].ty.ndim != 3: return None if transpose and list(matchings[transposed].attrs.axes) != [0, 2, 1]: return None - before_reshape = x.struct_info.shape.values + before_reshape = x.ty.shape.values after_reshape = [before_reshape[0], before_reshape[1], 1, before_reshape[2]] return R.reshape(x, after_reshape), after_reshape @@ -536,10 +536,10 @@ def handle_output(tensor, layout): permuted = is_op("relax.permute_dims")(reshaped) def rewriter(matchings, x): - if matchings[tensor].struct_info.ndim != 3: + if matchings[tensor].ty.ndim != 3: return None - before_reshape = matchings[tensor].struct_info.shape.values - after_reshape = matchings[shape].struct_info.values + before_reshape = matchings[tensor].ty.shape.values + after_reshape = matchings[shape].ty.values if not ( len(before_reshape) == 3 and len(after_reshape) == 4 @@ -554,9 +554,9 @@ def rewriter(matchings, x): elif layout == "BSH": def rewriter(matchings, x): - if matchings[tensor].struct_info.ndim != 3: + if matchings[tensor].ty.ndim != 3: return None - return R.reshape(x, matchings[tensor].struct_info.shape.values) + return R.reshape(x, matchings[tensor].ty.shape.values) return tensor, rewriter else: @@ -602,7 +602,7 @@ def rewriter(original, matchings): if query is None or key is None or value is None: return original softmax_axis = matchings[softmax].attrs.axis - softmax_input_rank = len(matchings[softmax].struct_info.shape) + softmax_input_rank = len(matchings[softmax].ty.shape) if softmax_axis == -1: softmax_axis += softmax_input_rank if softmax_axis != softmax_input_rank - 1: @@ -611,7 +611,7 @@ def rewriter(original, matchings): _, s_kv, _, _ = key_shape if with_bias: bias = matchings[bias_raw] - bias_shape = list(bias.struct_info.shape) + bias_shape = list(bias.ty.shape) if bias_shape == [b * n, s, s_kv]: bias = R.reshape(bias, [b, n, s, s_kv]) elif bias_shape == [b * n, 1, s_kv]: diff --git a/python/tvm/relax/backend/rocm/hipblas.py b/python/tvm/relax/backend/rocm/hipblas.py index 15eaf80ab0c4..1a52102b4822 100644 --- a/python/tvm/relax/backend/rocm/hipblas.py +++ b/python/tvm/relax/backend/rocm/hipblas.py @@ -47,14 +47,14 @@ def _check_matmul(context: PatternCheckContext) -> bool: rhs = context.annotated_expr["rhs"] matmul_call = context.annotated_expr["root"] - lhs_dtype = lhs.struct_info.dtype - rhs_dtype = rhs.struct_info.dtype - out_dtype = matmul_call.struct_info.dtype + lhs_dtype = lhs.ty.dtype + rhs_dtype = rhs.ty.dtype + out_dtype = matmul_call.ty.dtype if not _is_supported_dtype(lhs_dtype, rhs_dtype, out_dtype): return False - lhs_shape = lhs.struct_info.shape.values - rhs_shape = rhs.struct_info.shape.values + lhs_shape = lhs.ty.shape.values + rhs_shape = rhs.ty.shape.values if not isinstance(lhs_shape[-1], tvm.tirx.expr.IntImm | int): # Reduction axis must be constant @@ -73,7 +73,7 @@ def _check_matmul(context: PatternCheckContext) -> bool: # Non-default epilogue not supported for IGEMM return False bias = context.annotated_expr["bias"] - bias_shape = bias.struct_info.shape.values + bias_shape = bias.ty.shape.values bias_batches = reduce(operator.mul, bias_shape[:-1], 1) if not isinstance(bias_batches, tvm.tirx.expr.IntImm | int) or int(bias_batches) > 1: # hipblas only supports bias vector diff --git a/python/tvm/relax/backend/utils.py b/python/tvm/relax/backend/utils.py index 9d4598e669b8..47c01ac431fa 100644 --- a/python/tvm/relax/backend/utils.py +++ b/python/tvm/relax/backend/utils.py @@ -38,13 +38,11 @@ def is_gpu_target(target: Target) -> bool: def get_shape_dtype(expr: relax.Expr) -> tuple[relax.ShapeExpr, str]: """Get shape and dtype from an expression. If the shape and dtype is unknown, raise an error.""" - sinfo = expr.struct_info - if not isinstance(expr.struct_info, relax.TensorStructInfo): - raise ValueError( - f"Expecting a expr with TensorStructInfo, but got {expr} with {expr.struct_info}" - ) + ty = expr.ty + if not isinstance(expr.ty, relax.TensorType): + raise ValueError(f"Expecting a expr with TensorType, but got {expr} with {expr.ty}") - shape, dtype = sinfo.shape, sinfo.dtype + shape, dtype = ty.shape, ty.dtype if shape is None: raise ValueError( f"Expecting a expr with known shape, but got {expr} with unknown shape" @@ -52,14 +50,14 @@ def get_shape_dtype(expr: relax.Expr) -> tuple[relax.ShapeExpr, str]: return shape, dtype - def _get_target(self, sinfo: relax.StructInfo) -> Target: - # Get target information from TensorStructInfo - if isinstance(sinfo, relax.TensorStructInfo): - vdevice = sinfo.vdevice + def _get_target(self, ty: relax.Type) -> Target: + # Get target information from TensorType + if isinstance(ty, relax.TensorType): + vdevice = ty.vdevice if vdevice is not None: return vdevice.target - elif isinstance(sinfo, relax.TupleStructInfo): - for f in sinfo.fields: + elif isinstance(ty, relax.TupleType): + for f in ty.fields: tgt = self._get_target(f) if tgt != Target.current(): return tgt diff --git a/python/tvm/relax/base_py_module.py b/python/tvm/relax/base_py_module.py index 5dd8a107ae08..825b532364fe 100644 --- a/python/tvm/relax/base_py_module.py +++ b/python/tvm/relax/base_py_module.py @@ -218,7 +218,7 @@ def wrapper(*args, **kwargs): wrapped_func = create_py_func_wrapper(func_name, py_func) register_py_func(func_name, wrapped_func) - def call_tir(self, tir_func, args, out_sinfo): + def call_tir(self, tir_func, args, out_ty): """Call a TIR function with PyTorch tensors.""" # Try to get function name from different sources if isinstance(tir_func, str): @@ -244,7 +244,7 @@ def call_tir(self, tir_func, args, out_sinfo): ) func = self.compiled_tir_funcs[func_name] - out = self._create_output_tensors(out_sinfo, args) + out = self._create_output_tensors(out_ty, args) tvm_args = self._convert_pytorch_to_tvm(args) tvm_out = self._convert_pytorch_to_tvm(out) @@ -253,7 +253,7 @@ def call_tir(self, tir_func, args, out_sinfo): result = self._convert_tvm_to_pytorch(tvm_out) return result[0] if len(result) == 1 else result - def call_dps_packed(self, func_name: str, args, out_sinfo): + def call_dps_packed(self, func_name: str, args, out_ty): """Call a packed function with PyTorch tensors, converting TVM Tensors via DLPack.""" if hasattr(self, func_name) and callable(getattr(self, func_name)): return getattr(self, func_name)(*args) @@ -268,7 +268,7 @@ def call_dps_packed(self, func_name: str, args, out_sinfo): ) from error func = self.extern_funcs[func_name] - out = self._create_output_tensors(out_sinfo, args) + out = self._create_output_tensors(out_ty, args) tvm_args = self._convert_pytorch_to_tvm(args) tvm_out = self._convert_pytorch_to_tvm(out) func(*tvm_args, *tvm_out) @@ -281,22 +281,20 @@ def call_py_func(self, func_name: str, args): py_func = self.pyfuncs[func_name] return py_func(self, *args) - def _create_output_tensors(self, out_sinfo, in_args=None): + def _create_output_tensors(self, out_ty, in_args=None): # pylint: disable=import-outside-toplevel import torch - sinfo_list = out_sinfo if isinstance(out_sinfo, list) else [out_sinfo] + ty_list = out_ty if isinstance(out_ty, list) else [out_ty] out_tensors = [] - for sinfo in sinfo_list: - if isinstance(sinfo, tuple | list) and all( - isinstance(x, int | np.integer) for x in sinfo - ): - out_tensors.append(torch.zeros(list(map(int, sinfo)), dtype=torch.float32)) + for ty in ty_list: + if isinstance(ty, tuple | list) and all(isinstance(x, int | np.integer) for x in ty): + out_tensors.append(torch.zeros(list(map(int, ty)), dtype=torch.float32)) continue - if hasattr(sinfo, "shape") and hasattr(sinfo, "dtype"): - concrete_shape = self._infer_concrete_shape_from_args(sinfo.shape, in_args) - torch_dtype = self._convert_tvm_dtype_to_torch(sinfo.dtype) + if hasattr(ty, "shape") and hasattr(ty, "dtype"): + concrete_shape = self._infer_concrete_shape_from_args(ty.shape, in_args) + torch_dtype = self._convert_tvm_dtype_to_torch(ty.dtype) out_tensors.append(torch.zeros(concrete_shape, dtype=torch_dtype)) continue @@ -341,7 +339,7 @@ def _infer_concrete_shape_from_args(self, shape, in_args): raise ValueError( "Cannot infer concrete output shape from symbolic shape and inputs. " - "Please provide a concrete `out_sinfo` (e.g., a tuple/list of ints) " + "Please provide a concrete `out_ty` (e.g., a tuple/list of ints) " "or ensure input tensors carry shapes that determine output extents." ) @@ -510,7 +508,7 @@ def script( num_context_lines: int = -1, syntax_sugar: bool = True, show_object_address: bool = False, - show_all_struct_info: bool = True, + show_all_ty: bool = True, extra_config: dict | None = None, ) -> str: """Print TVM IR into TVMScript text format with Python function support. @@ -532,7 +530,7 @@ def script( num_context_lines=num_context_lines, syntax_sugar=syntax_sugar, show_object_address=show_object_address, - show_all_struct_info=show_all_struct_info, + show_all_ty=show_all_ty, extra_config=extra_config, ) diff --git a/python/tvm/relax/block_builder.py b/python/tvm/relax/block_builder.py index f347f05f1555..25a629a7fc6f 100644 --- a/python/tvm/relax/block_builder.py +++ b/python/tvm/relax/block_builder.py @@ -32,7 +32,7 @@ from . import _ffi_api from .expr import BaseFunc, Binding, BindingBlock, Expr, GlobalVar, Tuple, Var from .op.base import call_tir, call_tir_with_grad -from .struct_info import StructInfo +from .type import Type from .utils import gen_call_tir_inputs @@ -92,7 +92,7 @@ def __init__(self, block_builder, def_vars): else: raise ValueError("def_vars only can take tirx.Var") # setup a dummy var so shape is in scope. - sparam = rx.Var("sparam", rx.ShapeStructInfo(shape_vars)) + sparam = rx.Var("sparam", rx.ShapeType(shape_vars)) self._scope_params = [sparam] def __enter__(self): @@ -114,8 +114,8 @@ class BlockBuilder(Object): m = tirx.Var("m", "int32") n = tirx.Var("n", "int32") - x = rx.Var("x", rx.TensorStructInfo([m, n], "float16")) - y = rx.Var("y", rx.TensorStructInfo([n], "float16") + x = rx.Var("x", rx.TensorType([m, n], "float16")) + y = rx.Var("y", rx.TensorType([n], "float16") bb = rx.BlockBuilder() with bb.function([x, y], "func"): with bb.dataflow() as df: @@ -361,13 +361,13 @@ def call_te(self, func: Callable, *args: Any, **kwargs: Any) -> Expr: """ primfunc_name = kwargs.pop("primfunc_name_hint", None) - tir_func, call_args, output_sinfo, tir_vars = gen_call_tir_inputs(func, *args, **kwargs) + tir_func, call_args, output_ty, tir_vars = gen_call_tir_inputs(func, *args, **kwargs) if not primfunc_name: primfunc_name = func.__name__ gvar = self.add_func(tir_func, primfunc_name) - return call_tir(gvar, call_args, output_sinfo, tir_vars) + return call_tir(gvar, call_args, output_ty, tir_vars) def call_te_with_grad( self, @@ -413,7 +413,7 @@ def call_te_with_grad( """ primfunc_name = kwargs.pop("primfunc_name_hint", None) - tir_func, call_args, output_sinfo, tir_vars = gen_call_tir_inputs(func, *args, **kwargs) + tir_func, call_args, output_ty, tir_vars = gen_call_tir_inputs(func, *args, **kwargs) if te_grad_kwargs is None: te_grad_kwargs = {} @@ -423,7 +423,7 @@ def call_te_with_grad( gvar = self.add_func(tir_func, primfunc_name) return call_tir_with_grad( - gvar, call_args, output_sinfo, te_grad_name, te_grad_kwargs, tir_vars + gvar, call_args, output_ty, te_grad_name, te_grad_kwargs, tir_vars ) def emit_te(self, func: Callable, *args: Any, **kwargs: Any) -> Var: @@ -456,8 +456,8 @@ def emit_te(self, func: Callable, *args: Any, **kwargs: Any) -> Var: bb = rx.BlockBuilder() n, m = tirx.Var("n", "int64"), tirx.Var("m", "int64") - x = rx.Var("x", rx.TensorStructInfo([n, m], "float32")) - y = rx.Var("y", rx.TensorStructInfo([n, m], "float32")) + x = rx.Var("x", rx.TensorType([n, m], "float32")) + y = rx.Var("y", rx.TensorType([n, m], "float32")) def te_func(args, args_dict, msg): A = args[0] @@ -506,8 +506,8 @@ def rx_func(x: Tensor((n, m), "float32"), y: Tensor((n, m), "float32")) -> Tenso bb = relax.BlockBuilder() n = tirx.Var("n", "int64") - x = relax.Var("x", relax.TensorStructInfo([n], "float32")) - y = relax.Var("y", relax.TensorStructInfo([n + 1], "float32")) + x = relax.Var("x", relax.TensorType([n], "float32")) + y = relax.Var("y", relax.TensorType([n + 1], "float32")) def te_func(A): C = te.compute((n + 1), lambda i: A[i]) @@ -547,7 +547,7 @@ def rx_func(x: Tensor((n,), "float32"), y: Tensor(((n + 1),), "float32")) name_hint = kwargs.pop("name_hint", "") return self.emit(self.call_te(func, *args, **kwargs), name_hint=name_hint) - def match_cast(self, value: Expr, struct_info: StructInfo, name_hint: str = "") -> Var: + def match_cast(self, value: Expr, ty: Type, name_hint: str = "") -> Var: """Emit a MatchCast. Parameters @@ -555,8 +555,8 @@ def match_cast(self, value: Expr, struct_info: StructInfo, name_hint: str = "") value : tvm.relax.Expr The value of the MatchCast to be emitted. - struct_info : StructInfo - The struct info to be matched. + ty : Type + The type to be matched. name_hint : str The name of the match cast @@ -569,7 +569,7 @@ def match_cast(self, value: Expr, struct_info: StructInfo, name_hint: str = "") return _ffi_api.BlockBuilderEmitMatchCast( self, value, - struct_info, + ty, name_hint, ) # type: ignore @@ -651,8 +651,8 @@ def emit_func_output( finally: self.end_scope() - # do not specify ret_struct_info and let constructor deduce - # from seqe.struct_info + # do not specify ret_ty and let constructor deduce + # from seqe.ty func = rx.Function(self._func._params, seqe, is_pure=self._func._is_pure) for key, value in self._func._attrs.items(): func = func.with_attr(key, value) diff --git a/python/tvm/relax/distributed/__init__.py b/python/tvm/relax/distributed/__init__.py index 7996b2a25bd9..aa7876fd358e 100644 --- a/python/tvm/relax/distributed/__init__.py +++ b/python/tvm/relax/distributed/__init__.py @@ -19,6 +19,6 @@ """The infrastructure for distributed inference on Relax.""" from .global_info import DeviceMesh, device_mesh -from .struct_info import Placement, DTensorStructInfo, PlacementSpec +from .type import Placement, DTensorType, PlacementSpec from . import transform diff --git a/python/tvm/relax/distributed/struct_info.py b/python/tvm/relax/distributed/type.py similarity index 88% rename from python/tvm/relax/distributed/struct_info.py rename to python/tvm/relax/distributed/type.py index 0c94d3dda9e2..01fb8385a91e 100644 --- a/python/tvm/relax/distributed/struct_info.py +++ b/python/tvm/relax/distributed/type.py @@ -15,14 +15,14 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=redefined-builtin, invalid-name -"""Struct Info for distributed tensor.""" +"""Types for distributed tensor.""" import enum import tvm_ffi from tvm.ir import Span -from tvm.relax.struct_info import StructInfo, TensorStructInfo +from tvm.relax.type import TensorType, Type from tvm.runtime import Object from . import _ffi_api @@ -111,14 +111,14 @@ def from_text(text: str) -> "Placement": return _ffi_api.PlacementFromText(text) -@tvm_ffi.register_object("relax.DTensorStructInfo") -class DTensorStructInfo(StructInfo): - """StructInfo of a Distributed Tensor value. +@tvm_ffi.register_object("relax.DTensorType") +class DTensorType(Type): + """Type of a Distributed Tensor value. Parameters ---------- - tensor_sinfo: TensorStructInfo - The struct info inherited from TensorStructInfo + tensor_ty: TensorType + The tensor type carried by the distributed tensor. device_mesh: DeviceMesh The device mesh of the tensor. placement: Placement @@ -126,20 +126,20 @@ class DTensorStructInfo(StructInfo): """ - tensor_sinfo: TensorStructInfo + tensor_ty: TensorType device_mesh: DeviceMesh placement: Placement def __init__( self, - tensor_sinfo: TensorStructInfo, + tensor_ty: TensorType, device_mesh: DeviceMesh, placement: Placement, span: Span = None, ) -> None: self.__init_handle_by_constructor__( - _ffi_api.DTensorStructInfo, - tensor_sinfo, + _ffi_api.DTensorType, + tensor_ty, device_mesh, placement, span, # type: ignore diff --git a/python/tvm/relax/dpl/pattern.py b/python/tvm/relax/dpl/pattern.py index 89feac1bce18..47b718981e5d 100644 --- a/python/tvm/relax/dpl/pattern.py +++ b/python/tvm/relax/dpl/pattern.py @@ -123,8 +123,8 @@ def has_attr(self, attrs: dict[str, Object]) -> "AttrPattern": attrs = make_node("ir.DictAttrs", **attrs) return AttrPattern(self, attrs) - def has_struct_info(self, struct_info: "StructInfo") -> "StructInfoPattern": - return StructInfoPattern(self, struct_info) + def has_ty(self, ty: "Type") -> "TypePattern": + return TypePattern(self, ty) def has_dtype(self, dtype: str) -> "DataTypePattern": """ @@ -568,23 +568,23 @@ def __init__(self): @register_df_node -class StructInfoPattern(DFPattern): - """A pattern that matches another pattern with a certain StructInfo +class TypePattern(DFPattern): + """A pattern that matches another pattern with a certain Type Parameters ---------- pattern: tvm.relax.dpl.DFPattern The input pattern that needs type annotation. - struct_info: tvm.relax.StructInfo - The struct info to match against + ty: tvm.relax.Type + The type to match against """ - def __init__(self, pattern: "DFPattern", struct_info: "StructInfo"): + def __init__(self, pattern: "DFPattern", ty: "Type"): self.__init_handle_by_constructor__( - ffi.StructInfoPattern, + ffi.TypePattern, pattern, - struct_info, + ty, ) # type: ignore @@ -861,7 +861,7 @@ def is_shape(shape: list[tvm.ir.PrimExpr]) -> "PrimArrPattern": return PrimArrPattern(shape) -# Todo(relax-team): Dataflow pattern for StructInfo, and match out_sinfo +# Todo(relax-team): Dataflow pattern for Type, and match out_ty def _is_call_tir( func_pattern: DFPattern, args: list | tuple | TuplePattern = None, @@ -877,7 +877,7 @@ def _is_call_tir( return is_op("relax.call_tir")(func_pattern, args, tir_vars, add_constraint=False) -# Todo(relax-team): Dataflow pattern for StructInfo, and match out_sinfo +# Todo(relax-team): Dataflow pattern for Type, and match out_ty def is_call_tir( func_name: str, args: list | tuple | TuplePattern = None, diff --git a/python/tvm/relax/dpl/rewrite.py b/python/tvm/relax/dpl/rewrite.py index 5cc928ff2599..a0ac3665a30f 100644 --- a/python/tvm/relax/dpl/rewrite.py +++ b/python/tvm/relax/dpl/rewrite.py @@ -274,7 +274,7 @@ def rewriter(matchings): Q_weight = matchings[Q_weight_pat] K_weight = matchings[K_weight_pat] V_weight = matchings[V_weight_pat] - width = Q_weight.struct_info.shape[1] + width = Q_weight.ty.shape[1] concat = R.concat([Q_weight, K_weight, V_weight], axis=1) matmul = R.matmul(inp, concat) diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py index 734d7bde672f..e02bbb51ca8c 100644 --- a/python/tvm/relax/expr.py +++ b/python/tvm/relax/expr.py @@ -57,41 +57,13 @@ def __init__(self): raise RuntimeError("Cannot directly construct Id") -# NOTE: place base struct info in expr to avoid cyclic dep -# from expr to struct info. -@tvm_ffi.register_object("ir.StructInfo") -class StructInfo(Node, Scriptable): - """The base class of all StructInfo. - - StructInfo contains both the static type - and runtime structural information. - """ - - def __eq__(self, other): - """Compare two struct info for structural equivalence.""" - return tvm_ffi.structural_equal(self, other) +def _relax_type_is_base_of(self: Type, derived: Type) -> bool: + """Check if this Relax type is a base of another Relax type.""" - def __ne__(self, other): - return not self.__eq__(other) + return _ffi_api.TypeIsBaseOf(self, derived) # type: ignore - def same_as(self, other): - """Overload with structural equality.""" - return super().__eq__(other) - def is_base_of(self, derived: "StructInfo") -> bool: - """Check if self is base of another derived struct info. - - Parameters - ---------- - derived : StructInfo - The derived struct info to be checked. - - Returns - ------- - result : bool - The check result. - """ - return _ffi_api.StructInfoIsBaseOf(self, derived) # type: ignore +Type.is_base_of = _relax_type_is_base_of # type: ignore[attr-defined] # will be registered afterwards in python/tvm/relax/op/init.py @@ -155,7 +127,7 @@ def __le__(self, other: Expr) -> "ExprWithOp": # NOTE: Cannot override __eq__ and __ne__, which will influence object equal def __add__(self, other: Expr) -> "ExprWithOp": - if isinstance(self.struct_info_, tvm.relax.TupleStructInfo) and isinstance(other, tuple): + if isinstance(self.ty, tvm.relax.TupleType) and isinstance(other, tuple): return tuple([*self, *other]) return _binary_op_helper(self, other, _op_ffi_api.add) # type: ignore @@ -246,7 +218,7 @@ def __getitem__(self, index: int) -> "ExprWithOp": raise IndexError from err raise - def _check_for_tensor_struct_info(self): + def _check_for_tensor_ty(self): """Raise an error if this is something other than a Tensor Used for early checks in `expr.dtype` and `expr.shape` @@ -254,10 +226,10 @@ def _check_for_tensor_struct_info(self): raised during shape inference, an earlier check makes it easier to find the invalid usage. """ - if self.struct_info_ is None: + if self.ty is None: return - if not isinstance(self.struct_info_, tvm.relax.TensorStructInfo): + if not isinstance(self.ty, tvm.relax.TensorType): raise TypeError( f"Runtime unpacking of DLDataType is only implemented for tensors, " f"but was applied to object {self} of type {type(self)}." @@ -266,32 +238,32 @@ def _check_for_tensor_struct_info(self): @property def dtype(self) -> "_DLTensorDTypeProxy": """Returns a proxy object for accessing DLTensor::dtype""" - self._check_for_tensor_struct_info() + self._check_for_tensor_ty() return _DLTensorDTypeProxy(self) @property def ndim(self) -> "Expr": """Returns the runtime value of DLTensor::ndim""" - self._check_for_tensor_struct_info() + self._check_for_tensor_ty() op = tvm.ir.Op.get("relax.inspect.tensor_ndim") return tvm.relax.Call(op, [self]) @property def shape(self) -> "_DLTensorShapeProxy": """Returns a proxy object for accessing DLTensor::shape""" - self._check_for_tensor_struct_info() + self._check_for_tensor_ty() return _DLTensorShapeProxy(self) @property def strides(self) -> "_DLTensorStrideProxy": """Returns a proxy object for accessing DLTensor::strides""" - self._check_for_tensor_struct_info() + self._check_for_tensor_ty() return _DLTensorStrideProxy(self) @property def byte_offset(self) -> "Expr": """Returns a proxy object for accessing DLTensor::byte_offset""" - self._check_for_tensor_struct_info() + self._check_for_tensor_ty() op = tvm.ir.Op.get("relax.inspect.tensor_byte_offset") return tvm.relax.Call(op, [self]) @@ -305,7 +277,7 @@ def elem_offset(self) -> "Expr": `tirx::BufferNode::elem_offset` field when interacting with TIR buffers. """ - self._check_for_tensor_struct_info() + self._check_for_tensor_ty() op = tvm.ir.Op.get("relax.inspect.tensor_elem_offset") return tvm.relax.Call(op, [self]) @@ -447,13 +419,11 @@ def __getitem__(self, axis: int | PrimExpr | Expr) -> Expr: if not isinstance(axis, tvm.relax.Expr): axis = tvm.relax.PrimValue(axis) - if axis.struct_info_ is not None and not isinstance( - axis.struct_info_, tvm.relax.PrimStructInfo - ): + if axis.ty is not None and not isinstance(axis.ty, tvm.relax.PrimType): raise TypeError( f"The index used to access {self.tensor}.shape " - f'must have struct info R.Prim("int64"), ' - f"but index {axis} had struct info {axis.struct_info_}." + f'must have type R.Prim("int64"), ' + f"but index {axis} had type {axis.ty}." ) op = tvm.ir.Op.get("relax.inspect.tensor_shape_i") @@ -517,13 +487,11 @@ def __getitem__(self, axis: int | PrimExpr | Expr) -> Expr: if not isinstance(axis, tvm.relax.Expr): axis = tvm.relax.PrimValue(axis) - if axis.struct_info_ is not None and not isinstance( - axis.struct_info_, tvm.relax.PrimStructInfo - ): + if axis.ty is not None and not isinstance(axis.ty, tvm.relax.PrimType): raise TypeError( f"The index used to access {self.tensor}.strides " - f'must have struct info R.Prim("int64"), ' - f"but index {axis} had struct info {axis.struct_info_}." + f'must have type R.Prim("int64"), ' + f"but index {axis} had type {axis.ty}." ) op = tvm.ir.Op.get("relax.inspect.tensor_stride_i") @@ -548,11 +516,11 @@ class Call(ExprWithOp): attrs: Optional[tvm.ir.Attrs] Attributes to the call, can be None - sinfo_args: Optional[Union[List[StructInfo], typing.Tuple[StructInfo, ...]]] - The structure info arguments of a CallNode. - sinfo_args is designed to be non-empty only for intrinsic op (e.g., + ty_args: Optional[Union[List[Type], typing.Tuple[Type, ...]]] + The type information arguments of a CallNode. + ty_args is designed to be non-empty only for intrinsic op (e.g., call_tir, call_builtin_with_ctx, etc.) and calls to ExternFuncs, with the main - usage of structure info inference. + usage of type information inference. span: Optional[Span] Span that points to original source code @@ -561,7 +529,7 @@ class Call(ExprWithOp): op: Expr args: list[Expr] attrs: tvm.ir.Attrs - sinfo_args: list[StructInfo] + ty_args: list[Type] span: Span | None def __init__( @@ -569,17 +537,17 @@ def __init__( op: Expr | tvm.ir.Op, args: list[Expr] | tuple[Expr, ...], attrs: tvm.ir.Attrs | None = None, - sinfo_args: list[StructInfo] | tuple[StructInfo, ...] | None = None, + ty_args: list[Type] | tuple[Type, ...] | None = None, span: Span | None = None, ): - if not sinfo_args: - sinfo_args = [] + if not ty_args: + ty_args = [] self.__init_handle_by_constructor__( _ffi_api.Call, op, args, attrs, - sinfo_args, + ty_args, span, # type: ignore ) @@ -637,7 +605,7 @@ class Tuple(ExprWithOp): def __init__(self, fields: list[Expr] | tuple[Expr, ...], span: Span | None = None): if isinstance(fields, tvm.relax.Tuple): fields = fields.fields - elif isinstance(getattr(fields, "struct_info_", None), tvm.relax.TupleStructInfo): + elif isinstance(getattr(fields, "ty", None), tvm.relax.TupleType): fields = [*fields] self.__init_handle_by_constructor__(_ffi_api.Tuple, fields, span) # type: ignore @@ -730,8 +698,8 @@ class Constant(ExprWithOp): data: tvm.runtime.Tensor The data of the constant tensor. - struct_info: Optional[StructInfo] - The struct info of the constant tensor. If not specified, infer it from data. + ty: Optional[Type] + The type of the constant tensor. If not specified, infer it from data. span: Optional[Span] Span that points to original source code @@ -747,13 +715,13 @@ class Constant(ExprWithOp): def __init__( self, data: tvm.runtime.Tensor, - struct_info: StructInfo | None = None, + ty: Type | None = None, span: Span | None = None, ) -> None: self.__init_handle_by_constructor__( _ffi_api.Constant, data, - struct_info, + ty, span, # type: ignore ) @@ -767,8 +735,8 @@ class Var(ExprWithOp): name_hint: str | Id The name hint of the variable. - struct_info: Optional[StructInfo] - The struct info annotation of the variable. + ty: Optional[Type] + The type annotation of the variable. span: Optional[Span] Span that points to original source code @@ -780,21 +748,21 @@ class Var(ExprWithOp): def __init__( self, name_hint: str | Id, - struct_info: StructInfo | None = None, + ty: Type | None = None, span: Span | None = None, ) -> None: - if struct_info is not None: - struct_info = tvm.runtime.convert(struct_info) - if not isinstance(struct_info, StructInfo): + if ty is not None: + ty = tvm.runtime.convert(ty) + if not isinstance(ty, Type): raise TypeError( - "struct_info needs to be an instance of StructInfo. " + "ty needs to be an instance of Type. " "If you attempt to pass in shape, " - "use relax.TensorStructInfo(shape, dtype)." + "use relax.TensorType(shape, dtype)." ) self.__init_handle_by_constructor__( _ffi_api.Var if isinstance(name_hint, str) else _ffi_api.VarFromId, # type: ignore name_hint, - struct_info, + ty, span, ) @@ -816,8 +784,8 @@ class DataflowVar(Var): name_hint: str | Id The name hint of the variable. - struct_info: Optional[StructInfo] - The struct info annotation of the variable. + ty: Optional[Type] + The type annotation of the variable. span: Optional[Span] Span that points to original source code @@ -829,17 +797,17 @@ class DataflowVar(Var): def __init__( self, name_hint: str | Id, - struct_info: StructInfo | None = None, + ty: Type | None = None, span: Span | None = None, ) -> None: # pylint: disable=super-init-not-called - if struct_info is not None: - struct_info = tvm.runtime.convert(struct_info) - if not isinstance(struct_info, StructInfo): + if ty is not None: + ty = tvm.runtime.convert(ty) + if not isinstance(ty, Type): raise TypeError( - "struct_info needs to be an instance of StructInfo. " + "ty needs to be an instance of Type. " "If you attempt to pass in shape, " - "use relax.TensorStructInfo(shape, dtype)." + "use relax.TensorType(shape, dtype)." ) self.__init_handle_by_constructor__( @@ -849,7 +817,7 @@ def __init__( else _ffi_api.DataflowVarFromId ), # type: ignore name_hint, - struct_info, + ty, span, ) @@ -898,10 +866,10 @@ class Binding(Node, Scriptable): @tvm_ffi.register_object("relax.expr.MatchCast") class MatchCast(Binding): - """Runtime-match the value to the struct info. + """Runtime-match the value to the type. This operation does runtime check, populates the un-defined symbolic shape vars - and vars in struct_info in the first occurrence, and insert equality assertions in + and vars in ty in the first occurrence, and insert equality assertions in other cases. Parameters @@ -912,22 +880,20 @@ class MatchCast(Binding): value: Expr The input value expression. - struct_info: tvm.relax.StructInfo - The struct info to match cast to. + ty: tvm.relax.Type + The type to match cast to. """ - struct_info: StructInfo + ty: Type value: Expr span: Span | None - def __init__( - self, var: Var, value: Expr, struct_info: StructInfo, span: Span | None = None - ) -> None: + def __init__(self, var: Var, value: Expr, ty: Type, span: Span | None = None) -> None: self.__init_handle_by_constructor__( _ffi_api.MatchCast, var, value, - struct_info, + ty, span, # type: ignore ) @@ -996,7 +962,7 @@ class Function(BaseFunc, Scriptable): params: list[Var] body: Expr - ret_struct_info: StructInfo + ret_ty: Type is_pure: bool attrs: tvm.ir.DictAttrs span: Span | None @@ -1005,7 +971,7 @@ def __init__( self, params: list[Var], body: Expr, - ret_struct_info: StructInfo | None = None, + ret_ty: Type | None = None, is_pure: bool | None = True, attrs: tvm.ir.DictAttrs | None = None, span: Span | None = None, @@ -1016,7 +982,7 @@ def __init__( _ffi_api.Function, params, body, - ret_struct_info, + ret_ty, is_pure, attrs, span, @@ -1025,7 +991,7 @@ def __init__( @staticmethod def create_empty( params: list[Var], - ret_struct_info: StructInfo, + ret_ty: Type, is_pure: bool | None = True, attrs: tvm.ir.DictAttrs | None = None, span: Span | None = None, @@ -1033,7 +999,7 @@ def create_empty( """Construct a relax.Function but without body""" if attrs is None: attrs = tvm.ir.DictAttrs({}) - return _ffi_api.FunctionCreateEmpty(params, ret_struct_info, is_pure, attrs, span) # type: ignore + return _ffi_api.FunctionCreateEmpty(params, ret_ty, is_pure, attrs, span) # type: ignore def __call__(self, *args): """Invoke the global function. @@ -1139,20 +1105,20 @@ class ExternFunc(BaseFunc, ExprWithOp): def __init__( self, global_symbol: String, - struct_info: StructInfo | None = None, + ty: Type | None = None, span: Span | None = None, ) -> None: self.__init_handle_by_constructor__( _ffi_api.ExternFunc, global_symbol, - struct_info, + ty, span, # type: ignore ) -def extern(name: str, struct_info: StructInfo | None = None, span: Span | None = None): +def extern(name: str, ty: Type | None = None, span: Span | None = None): """Create extern function.""" - return ExternFunc(name, struct_info, span) + return ExternFunc(name, ty, span) def const( @@ -1218,7 +1184,7 @@ def te_tensor( Parameters ---------- value : Expr - The relax expression, which is required to have TensorStructInfo. + The relax expression, which is required to have TensorType. tir_var_map : Dict[tvm.tirx.Var, tvm.tirx.PrimExpr] The mapping to substitute the TIR variables appeared in the @@ -1246,7 +1212,7 @@ def get_shape_of(expr: Expr) -> Expr: Note ---- This function requires expr to be normalized. - The function will report an error if expr's StructInfo is not TensorStructInfo. + The function will report an error if expr's Type is not TensorType. It will try to return symbolic function when possible. If the tensor do not have a compile-time symbolic shape, the function will then choose to return `Call(relax.op.shape_of, [expr])`. @@ -1254,5 +1220,5 @@ def get_shape_of(expr: Expr) -> Expr: return _ffi_api.GetShapeOf(expr) # type: ignore -def _update_struct_info(expr: Expr, struct_info: StructInfo | None) -> None: - _ffi_api.UpdateStructInfo(expr, struct_info) # type: ignore +def _update_type(expr: Expr, ty: Type | None) -> None: + _ffi_api.UpdateType(expr, ty) # type: ignore diff --git a/python/tvm/relax/expr_functor.py b/python/tvm/relax/expr_functor.py index c9ea88d11100..e84e37b02ad5 100644 --- a/python/tvm/relax/expr_functor.py +++ b/python/tvm/relax/expr_functor.py @@ -54,7 +54,7 @@ Var, VarBinding, ) -from .struct_info import StructInfo +from .type import Type visitor = derived_object """ @@ -1521,7 +1521,7 @@ def lookup_binding(self, var: Var) -> Expr | None: # Using self._outer() to ref _PyExprMutator return _ffi_api.PyExprMutatorLookupBinding(self._outer(), var) # type: ignore - def with_struct_info(self, var: Var, struct_info: StructInfo) -> Var: + def with_type(self, var: Var, ty: Type) -> Var: """Create a new var with specified shape and type if the original var's shape or type does not match with the specified ones. @@ -1529,8 +1529,8 @@ def with_struct_info(self, var: Var, struct_info: StructInfo) -> Var: ---------- var : Var The var to be updated. - struct_info : StructInfo - The struct info. + ty : Type + The type. Returns ------- @@ -1538,4 +1538,4 @@ def with_struct_info(self, var: Var, struct_info: StructInfo) -> Var: The var filled with shape and type. """ # Using self._outer() to ref _PyExprMutator - return _ffi_api.PyExprMutatorWithStructInfo(self._outer(), var, struct_info) # type: ignore + return _ffi_api.PyExprMutatorWithType(self._outer(), var, ty) # type: ignore diff --git a/python/tvm/relax/frontend/common.py b/python/tvm/relax/frontend/common.py index 63b58cd8f113..b4aad8f7781d 100644 --- a/python/tvm/relax/frontend/common.py +++ b/python/tvm/relax/frontend/common.py @@ -75,8 +75,8 @@ def autopad( [(kernel - 1) * dilation + 1 for kernel, dilation in zip(kernel_shape, dilations)] ) # get input shape - ndim = data.struct_info.ndim - data_shape = list(data.struct_info.shape) + ndim = data.ty.ndim + data_shape = list(data.ty.shape) shape = data_shape[2:ndim] # set up integer constants diff --git a/python/tvm/relax/frontend/nn/core.py b/python/tvm/relax/frontend/nn/core.py index 1301fa471ff6..c3ef1fb268b6 100644 --- a/python/tvm/relax/frontend/nn/core.py +++ b/python/tvm/relax/frontend/nn/core.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """The core infra for nn.Module, which includes the following pieces: -- Tensor, a wrapper on top of relax.Expr whose struct_info is a TensorStructInfo, +- Tensor, a wrapper on top of relax.Expr whose ty is a TensorType, providing more convenient access shape and dtype information. Tensor is always symbolic and not bound to any concrete values. - Parameter, a special tensor which could be bound or not bound to concrete values. @@ -45,11 +45,11 @@ from .... import relax as rx from ...block_builder import BlockBuilder -from ...struct_info import ( - ObjectStructInfo, - ShapeStructInfo, - TensorStructInfo, - TupleStructInfo, +from ...type import ( + ObjectType, + ShapeType, + TensorType, + TupleType, ) from ._tensor_op import _TensorOp from .subroutine import SubroutineMixin @@ -88,7 +88,7 @@ def set_default_dtype(dtype: str) -> None: class Tensor(_TensorOp): - """A wrapper on top of relax.Expr whose struct_info is a TensorStructInfo, providing more + """A wrapper on top of relax.Expr whose ty is a TensorType, providing more convenient access shape and dtype information. Tensor is always symbolc and not bound to any concrete values. Shape and dtype inference is done eagerly upon tensor creation, i.e. when operators are applied on tensors, the shape and dtype information is already available. @@ -100,13 +100,13 @@ def __init__(self, *, _expr: rx.Expr) -> None: """Private constructor. Tensor is never supposed to be constructed directly by users.""" def _check_tensor(expr: rx.Expr) -> None: - assert expr.struct_info_ is not None - assert isinstance(expr.struct_info, TensorStructInfo) - assert expr.struct_info.ndim != -1 - assert expr.struct_info.shape is not None - assert expr.struct_info.shape.struct_info_ is not None - assert isinstance(expr.struct_info.shape.struct_info, ShapeStructInfo) - assert expr.struct_info.shape.struct_info.values is not None + assert expr.ty is not None + assert isinstance(expr.ty, TensorType) + assert expr.ty.ndim != -1 + assert expr.ty.shape is not None + assert expr.ty.shape.ty is not None + assert isinstance(expr.ty.shape.ty, ShapeType) + assert expr.ty.shape.ty.values is not None _check_tensor(_expr) self._expr = _expr @@ -122,17 +122,17 @@ def from_scalar(data: int | float, dtype: str) -> "Tensor": return Tensor(_expr=rx.const(data, dtype=dtype)) @staticmethod - def from_struct_info(struct_info: rx.TensorStructInfo, name: str = "tensor") -> "Tensor": - """Construct a nn.Tensor from a Relax TensorStructInfo. + def from_ty(ty: rx.TensorType, name: str = "tensor") -> "Tensor": + """Construct a nn.Tensor from a Relax TensorType. - TensorStructInfo is the Relax type-level description of a tensor, carrying its shape + TensorType is the Relax type-level description of a tensor, carrying its shape and dtype without holding actual data. This factory creates an unbound placeholder ``nn.Tensor`` that can be used as a symbolic input when tracing an ``nn.Module``. Parameters ---------- - struct_info : rx.TensorStructInfo - The struct info describing the tensor's shape and dtype. + ty : rx.TensorType + The type describing the tensor's shape and dtype. name : str Name hint for the underlying Relax variable. @@ -140,12 +140,12 @@ def from_struct_info(struct_info: rx.TensorStructInfo, name: str = "tensor") -> Returns ------- tensor : Tensor - A symbolic ``nn.Tensor`` backed by a ``relax.Var`` with the given struct info. + A symbolic ``nn.Tensor`` backed by a ``relax.Var`` with the given type. """ return Tensor( _expr=rx.Var( name_hint=name, - struct_info=struct_info, + ty=ty, ) ) @@ -179,7 +179,7 @@ def placeholder( return Tensor( _expr=rx.Var( name_hint=name, - struct_info=TensorStructInfo( + ty=TensorType( shape=new_shape, # type: ignore[arg-type] dtype=dtype, ), @@ -203,8 +203,8 @@ def shape(self) -> list[int | tirx.PrimExpr]: def _simplify(expr: tirx.PrimExpr): return expr.value if isinstance(expr, tirx.IntImm) else expr - shape_sinfo: ShapeStructInfo = self._expr.struct_info.shape.struct_info - return [_simplify(x) for x in shape_sinfo.values] + shape_ty: ShapeType = self._expr.ty.shape.ty + return [_simplify(x) for x in shape_ty.values] @property def ndim(self) -> int: @@ -215,7 +215,7 @@ def ndim(self) -> int: ndim : int The number of dimensions of the tensor """ - return self._expr.struct_info.ndim + return self._expr.ty.ndim @property def dtype(self) -> str: @@ -226,7 +226,7 @@ def dtype(self) -> str: dtype : str The data type of the tensor """ - return self._expr.struct_info.dtype + return self._expr.ty.dtype def __repr__(self) -> str: return f'Tensor({self.shape}, "{self.dtype}")' @@ -310,8 +310,8 @@ def to(self, dtype: str | None = None) -> None: # pylint: disable=invalid-name class Object: - """A wrapper on top of relax.Expr whose struct_info is the base - ObjectStructInfo (rather than any its subclass). Object effectively + """A wrapper on top of relax.Expr whose ty is the base + ObjectType (rather than any its subclass). Object effectively represents non-tensor frontend components such as KV caches. """ @@ -322,7 +322,7 @@ def __init__(self, *, _expr: rx.Expr, _name: str) -> None: if not isinstance(_expr, rx.Var): _expr = BlockBuilder.current().emit(_expr, _name) self._expr = _expr - assert isinstance(self._expr.struct_info, ObjectStructInfo) + assert isinstance(self._expr.ty, ObjectType) class Effect: @@ -778,17 +778,17 @@ def wrap_nested(expr: rx.Expr, name: str) -> Tensor | Sequence[Tensor]: """ if not isinstance(expr, rx.DataflowVar): expr = BlockBuilder.current().emit(expr, name) - if isinstance(expr.struct_info_, TensorStructInfo): + if isinstance(expr.ty, TensorType): return Tensor(_expr=expr) - if isinstance(expr.struct_info_, TupleStructInfo): + if isinstance(expr.ty, TupleType): return tuple( wrap_nested( # type: ignore rx.TupleGetItem(expr, i), name=f"{name}.{i}", ) - for i in range(len(expr.struct_info_.fields)) + for i in range(len(expr.ty.fields)) ) - raise TypeError(f"Unsupported return type: {expr.struct_info_}") + raise TypeError(f"Unsupported return type: {expr.ty}") def _attribute_finder(root: Module, prefix: str, condition_yield: Callable[[Any], bool]): diff --git a/python/tvm/relax/frontend/nn/exporter.py b/python/tvm/relax/frontend/nn/exporter.py index 8b68bb6e7088..627bbe1506cc 100644 --- a/python/tvm/relax/frontend/nn/exporter.py +++ b/python/tvm/relax/frontend/nn/exporter.py @@ -26,7 +26,7 @@ from .... import relax as rx from ...block_builder import BlockBuilder -from ...struct_info import ObjectStructInfo, ShapeStructInfo, TupleStructInfo +from ...type import ObjectType, ShapeType, TupleType from . import core, extern from . import spec as _spec from .modules import IOEffect @@ -178,15 +178,13 @@ def _unwrap_ret(expr: typing.Any) -> typing.Any: def _convert_input(arg): if isinstance(arg, tirx.Var): - return rx.Var(arg.name, struct_info=ShapeStructInfo(values=[arg])) + return rx.Var(arg.name, ty=ShapeType(values=[arg])) if isinstance(arg, core.Tensor | core.Object): return arg._expr # pylint: disable=protected-access if isinstance(arg, _spec.Tuple): return rx.Var( arg.name, - struct_info=TupleStructInfo( - [_convert_input(arg_i).struct_info for arg_i in arg.elements] - ), + ty=TupleType([_convert_input(arg_i).ty for arg_i in arg.elements]), ) raise TypeError(f"Unsupported input type: {type(arg)}") @@ -215,7 +213,7 @@ def _get_var(shape_var: tirx.Var) -> tirx.Var: if mode == "packed": input_var = rx.Var( "packed_params", - TupleStructInfo(fields=[x.struct_info for x in inputs]), + TupleType(fields=[x.ty for x in inputs]), ) for i, (name, param) in enumerate(params): param._expr = builder.emit(rx.TupleGetItem(input_var, i), name_hint=name) @@ -236,7 +234,7 @@ def _effects(mode: str) -> list[rx.Var]: if mode == "packed": input_var = rx.Var( "packed_effects", - TupleStructInfo(fields=[x.struct_info for x in inputs]), + TupleType(fields=[x.ty for x in inputs]), ) i = 0 for effect_input, (_, effect) in zip(unflat_inputs, effects): @@ -313,7 +311,7 @@ def _convert_input(arg_name, arg_spec): name=arg_name, ) elif isinstance(arg_spec, _spec.Object): - arg = arg_spec.object_type(_expr=rx.Var(arg_name, ObjectStructInfo()), _name=arg_name) + arg = arg_spec.object_type(_expr=rx.Var(arg_name, ObjectType()), _name=arg_name) elif isinstance(arg_spec, _spec.Tuple): elements = type(arg_spec.elements)( [ diff --git a/python/tvm/relax/frontend/nn/extern.py b/python/tvm/relax/frontend/nn/extern.py index 866da6c4ab58..9c8efce690f1 100644 --- a/python/tvm/relax/frontend/nn/extern.py +++ b/python/tvm/relax/frontend/nn/extern.py @@ -67,8 +67,8 @@ def _convert(arg, name: str): raise TypeError(f"Unsupported input type: {type(arg)}") rx_inputs = _convert(input_args, "input") - rx_outputs_sinfo = _convert(_inference_function(*input_args), "dummy").struct_info - return wrap_nested(call_dps_packed(func_name, rx_inputs, rx_outputs_sinfo), func_name) + rx_outputs_ty = _convert(_inference_function(*input_args), "dummy").ty + return wrap_nested(call_dps_packed(func_name, rx_inputs, rx_outputs_ty), func_name) return _call diff --git a/python/tvm/relax/frontend/nn/llm/kv_cache.py b/python/tvm/relax/frontend/nn/llm/kv_cache.py index 0e0828c75db7..1da233a78a65 100644 --- a/python/tvm/relax/frontend/nn/llm/kv_cache.py +++ b/python/tvm/relax/frontend/nn/llm/kv_cache.py @@ -141,7 +141,7 @@ def attention_with_fused_qkv( - The input qkv and output tensor have `head_dim` at the last dim. """ # pylint: disable=protected-access - b, s, _, d = qkv._expr.struct_info.shape + b, s, _, d = qkv._expr.ty.shape qkv = qkv.reshape(b * s, qkv.shape[2], d) return Tensor( _expr=rx.BlockBuilder.current().emit( @@ -153,7 +153,7 @@ def attention_with_fused_qkv( rx.PrimValue(sm_scale), qkv._expr, ], - out_sinfo=rx.TensorStructInfo((b * s, num_qo_heads, d), qkv.dtype), + out_ty=rx.TensorType((b * s, num_qo_heads, d), qkv.dtype), ) ) ).reshape(b, s, num_qo_heads, d) @@ -168,8 +168,8 @@ def self_attention( # pylint: disable=too-many-locals ) -> tuple[Tensor, Tensor]: """Fine-grained API that computes ragged self attention with Q/K/V data.""" # pylint: disable=protected-access - b, s, h_qo, d_qk = q._expr.struct_info.shape - _, _, h_kv, d_v = v._expr.struct_info.shape + b, s, h_qo, d_qk = q._expr.ty.shape + _, _, h_kv, d_v = v._expr.ty.shape q = q.reshape(b * s, h_qo, d_qk) k = k.reshape(b * s, h_kv, d_qk) v = v.reshape(b * s, h_kv, d_v) @@ -185,14 +185,14 @@ def self_attention( # pylint: disable=too-many-locals k._expr, v._expr, ], - out_sinfo=[ - rx.TensorStructInfo((b * s, h_qo, d_v), q.dtype), - rx.TensorStructInfo((b * s, h_qo), "float32"), + out_ty=[ + rx.TensorType((b * s, h_qo, d_v), q.dtype), + rx.TensorType((b * s, h_qo), "float32"), ], ) ) - assert isinstance(attn_results.struct_info, rx.TupleStructInfo) - assert len(attn_results.struct_info.fields) == 2 + assert isinstance(attn_results.ty, rx.TupleType) + assert len(attn_results.ty.fields) == 2 o = Tensor(_expr=bb.emit(rx.TupleGetItem(attn_results, 0))).reshape(b, s, h_qo, d_v) lse = Tensor(_expr=bb.emit(rx.TupleGetItem(attn_results, 1))).reshape(b, s, h_qo) return o, lse @@ -206,7 +206,7 @@ def cross_attention( ) -> tuple[Tensor, Tensor]: """Fine-grained API that computes paged cross attention with Q and in-cache KV data.""" # pylint: disable=protected-access - b, s, h_qo, d_qk = q._expr.struct_info.shape + b, s, h_qo, d_qk = q._expr.ty.shape q = q.reshape(b * s, h_qo, d_qk) bb = rx.BlockBuilder.current() attn_results = bb.emit( @@ -218,14 +218,14 @@ def cross_attention( rx.PrimValue(sm_scale), q._expr, ], - out_sinfo=[ - rx.TensorStructInfo((b * s, h_qo, v_head_dim), q.dtype), - rx.TensorStructInfo((b * s, h_qo), "float32"), + out_ty=[ + rx.TensorType((b * s, h_qo, v_head_dim), q.dtype), + rx.TensorType((b * s, h_qo), "float32"), ], ) ) - assert isinstance(attn_results.struct_info, rx.TupleStructInfo) - assert len(attn_results.struct_info.fields) == 2 + assert isinstance(attn_results.ty, rx.TupleType) + assert len(attn_results.ty.fields) == 2 o = Tensor(_expr=bb.emit(rx.TupleGetItem(attn_results, 0))).reshape(b, s, h_qo, v_head_dim) lse = Tensor(_expr=bb.emit(rx.TupleGetItem(attn_results, 1))).reshape(b, s, h_qo) return o, lse @@ -233,7 +233,7 @@ def cross_attention( def append_mla_kv(self, layer_id: int, kv: Tensor) -> "PagedKVCache": """Fine-grained API that appends the MLA K/V data to KV cache.""" # pylint: disable=protected-access - b, s, _, d_qk = kv._expr.struct_info.shape + b, s, _, d_qk = kv._expr.ty.shape kv = kv.reshape(b * s, d_qk) return PagedKVCache( _expr=rx.call_pure_packed( @@ -241,7 +241,7 @@ def append_mla_kv(self, layer_id: int, kv: Tensor) -> "PagedKVCache": self._expr, rx.PrimValue(layer_id), # type: ignore[arg-type] kv._expr, - sinfo_args=rx.ObjectStructInfo(), + ty_args=rx.ObjectType(), ), _name="paged_kv_cache", ) @@ -257,7 +257,7 @@ def merge_attn_output_inplace( The first two tensors will be inplace updated. """ # pylint: disable=protected-access - b, s, h_qo, d_v = o_self_attn._expr.struct_info.shape + b, s, h_qo, d_v = o_self_attn._expr.ty.shape o_self_attn = o_self_attn.reshape(b * s, h_qo, d_v) lse_self_attn = lse_self_attn.reshape(b * s, h_qo) o_cross_attn = o_cross_attn.reshape(b * s, h_qo, d_v) @@ -271,13 +271,13 @@ def merge_attn_output_inplace( lse_self_attn._expr, o_cross_attn._expr, lse_cross_attn._expr, - sinfo_args=rx.TupleStructInfo( - [o_self_attn._expr.struct_info, lse_self_attn._expr.struct_info] + ty_args=rx.TupleType( + [o_self_attn._expr.ty, lse_self_attn._expr.ty] ), ) ) - assert isinstance(merge_results.struct_info, rx.TupleStructInfo) - assert len(merge_results.struct_info.fields) == 2 + assert isinstance(merge_results.ty, rx.TupleType) + assert len(merge_results.ty.fields) == 2 o_self_attn = Tensor(_expr=bb.emit(rx.TupleGetItem(merge_results, 0))).reshape( b, s, h_qo, d_v ) @@ -304,7 +304,7 @@ def get_query_positions(self, total_length: tirx.PrimExpr) -> Tensor: rx.call_pure_packed( "vm.builtin.attention_kv_cache_get_query_positions", self._expr, - sinfo_args=rx.TensorStructInfo((total_length,), "int32"), + ty_args=rx.TensorType((total_length,), "int32"), ) ) ) @@ -505,7 +505,7 @@ def __init__( # pylint: disable=too-many-locals _expr=rx.call_pure_packed( "vm.builtin.paged_attention_kv_cache_create", *args, - sinfo_args=rx.ObjectStructInfo(), + ty_args=rx.ObjectType(), ), _name=name, ) @@ -680,7 +680,7 @@ def __init__( # pylint: disable=too-many-locals _expr=rx.call_pure_packed( "vm.builtin.paged_attention_kv_cache_create", *args, - sinfo_args=rx.ObjectStructInfo(), + ty_args=rx.ObjectType(), ), _name=name, ) diff --git a/python/tvm/relax/frontend/nn/modules.py b/python/tvm/relax/frontend/nn/modules.py index 8753d37bace1..6bd90c9c4ab9 100644 --- a/python/tvm/relax/frontend/nn/modules.py +++ b/python/tvm/relax/frontend/nn/modules.py @@ -43,7 +43,7 @@ def emit_init(self, name_hint, builder: rx.BlockBuilder) -> list[rx.DataflowVar] def create(self, name_hint: str) -> list[rx.Var]: assert self.effect is None - effect = rx.Var(f"{name_hint}.io", struct_info=rx.ObjectStructInfo()) + effect = rx.Var(f"{name_hint}.io", ty=rx.ObjectType()) return [effect] def set_state(self, state_vars: list[rx.Var]) -> None: @@ -812,7 +812,7 @@ def emit_init(self, name_hint: str, bb: rx.BlockBuilder): # pylint: disable=arg rx.op.zeros(init_shape, self.dtype), init_shape, rx.PrimValue(0), - sinfo_args=rx.ObjectStructInfo(), + ty_args=rx.ObjectType(), ), name_hint=name_hint, ) @@ -832,7 +832,7 @@ def create(self, name_hint: str) -> list[rx.Var]: ret : List[relax.Var] The relax.Var for KVCache. """ - cache = rx.Var(name_hint, struct_info=rx.ObjectStructInfo()) + cache = rx.Var(name_hint, ty=rx.ObjectType()) return [cache] def set_state(self, state_vars: list[rx.Var]) -> None: @@ -884,7 +884,7 @@ def view(self, seq_len: tirx.Var) -> Tensor: "vm.builtin.attention_kv_cache_view", self.cache, shape, - sinfo_args=rx.TensorStructInfo(shape, self.dtype), + ty_args=rx.TensorType(shape, self.dtype), ) ) ) @@ -908,7 +908,7 @@ def append(self, new_element: Tensor) -> None: self.cache, new_element._expr, inplace_indices=[0], - sinfo_args=rx.ObjectStructInfo(), + ty_args=rx.ObjectType(), ) ) diff --git a/python/tvm/relax/frontend/nn/op.py b/python/tvm/relax/frontend/nn/op.py index 80108e317ec5..f31c276ca0d2 100644 --- a/python/tvm/relax/frontend/nn/op.py +++ b/python/tvm/relax/frontend/nn/op.py @@ -1362,7 +1362,7 @@ def layer_norm( normalized_shape = [normalized_shape] dim_num = len(normalized_shape) axes = list(range(-dim_num, 0)) - dtype = x._expr.struct_info.dtype + dtype = x._expr.ty.dtype if weight is not None: weight = weight._expr @@ -1480,7 +1480,7 @@ def group_norm( weight = weight._expr if bias is not None: bias = bias._expr - dim = len(x._expr.struct_info.shape) + dim = len(x._expr.ty.shape) if axes is None: axes = list(range(2, dim)) return wrap_nested( @@ -2084,9 +2084,9 @@ def tensor_ir_op( ) if isinstance(out, Tensor): - out_sinfo = [out._expr.struct_info] + out_ty = [out._expr.ty] else: - out_sinfo = [x._expr.struct_info for x in out] + out_ty = [x._expr.ty for x in out] bb = BlockBuilder.current() global_var = bb.add_func(func, name_hint) @@ -2095,7 +2095,7 @@ def tensor_ir_op( tir_vars = None return wrap_nested( - bb.emit(rx.call_tir(global_var, call_tir_args, out_sinfo, tir_vars=tir_vars)), + bb.emit(rx.call_tir(global_var, call_tir_args, out_ty, tir_vars=tir_vars)), name=name_hint, ) @@ -2154,17 +2154,15 @@ def tensor_ir_inplace_op( ) if isinstance(out, Tensor): - out_sinfo = [out._expr.struct_info] + out_ty = [out._expr.ty] else: - out_sinfo = [x._expr.struct_info for x in out] + out_ty = [x._expr.ty for x in out] bb = BlockBuilder.current() global_var = bb.add_func(func, name_hint) return wrap_nested( - bb.emit( - rx.call_tir_inplace(global_var, call_tir_args, inplace_indices, out_sinfo, tir_vars) - ), + bb.emit(rx.call_tir_inplace(global_var, call_tir_args, inplace_indices, out_ty, tir_vars)), name=name_hint, ) @@ -2211,12 +2209,12 @@ def _convert(arg, name: str): raise TypeError(f"Unsupported input type: {type(arg)}") rx_inputs = _convert(args, "input") - rx_outputs_sinfo = _convert(out, "dummy").struct_info + rx_outputs_ty = _convert(out, "dummy").ty return wrap_nested( _op.call_dps_packed( name, args=rx_inputs, - out_sinfo=rx_outputs_sinfo, + out_ty=rx_outputs_ty, ), name, ) # type: ignore @@ -2282,7 +2280,7 @@ def debug_func(lineno: str, arg_0, arg_1, ...) -> None: rx.StringImm(name), rx.StringImm(_line_info), *converted_args, - sinfo_args=[rx.ObjectStructInfo()], + ty_args=[rx.ObjectType()], ), name_hint=io.effect.name_hint, ) diff --git a/python/tvm/relax/frontend/nn/subroutine.py b/python/tvm/relax/frontend/nn/subroutine.py index e821756e8d81..d197355998ef 100644 --- a/python/tvm/relax/frontend/nn/subroutine.py +++ b/python/tvm/relax/frontend/nn/subroutine.py @@ -41,11 +41,11 @@ def _camel_to_snake(name): def _normalize_expr(block_builder, arg, as_relax_expr=False): - """Ensure that an argument is a relax.Expr with struct info""" + """Ensure that an argument is a relax.Expr with type""" if isinstance(arg, tuple): arg = relax.Tuple([_normalize_expr(block_builder, element) for element in arg]) - if isinstance(arg, relax.Expr) and getattr(arg, "struct_info_", None) is None: + if isinstance(arg, relax.Expr) and getattr(arg, "ty", None) is None: arg = block_builder.emit(arg) if isinstance(arg, nn.Tensor) and as_relax_expr: @@ -54,15 +54,15 @@ def _normalize_expr(block_builder, arg, as_relax_expr=False): return arg -def _get_struct_info(arg): +def _get_ty(arg): if isinstance(arg, relax.Expr): - return arg.struct_info_ + return arg.ty elif isinstance(arg, nn.Tensor): - return arg._expr.struct_info_ + return arg._expr.ty elif isinstance(arg, tuple | list | tvm_ffi.Array): - return relax.TupleStructInfo([_get_struct_info(field) for field in arg]) + return relax.TupleType([_get_ty(field) for field in arg]) else: - raise TypeError(f"Cannot find struct info for {arg} of type {type(arg)}") + raise TypeError(f"Cannot find type for {arg} of type {type(arg)}") class SubroutineMixin: @@ -108,7 +108,7 @@ def new_forward(self, *args, **kwargs): out = subroutine(*subroutine_args) if is_nn_tensor_output: - if out.struct_info_ is None: + if out.ty is None: out = block_builder.emit(out, name_hint=f"{subroutine.name_hint}_output") out = nn.Tensor(_expr=out) return out @@ -141,20 +141,20 @@ def _get_subroutine( param._expr if isinstance(param, nn.Tensor) else param for param in self.parameters() ] - arg_sinfo = _get_struct_info([*func_args.values(), *model_params]) + arg_ty = _get_ty([*func_args.values(), *model_params]) is_dataflow = block_builder.current_block_is_dataflow() lookup_key = ( old_forward, - tvm_ffi.structural_hash(arg_sinfo, map_free_vars=True), + tvm_ffi.structural_hash(arg_ty, map_free_vars=True), is_dataflow, ) - for cached_sinfo, cached_result in cls._gvar.get(lookup_key, []): - if tvm_ffi.structural_equal(cached_sinfo, arg_sinfo, map_free_vars=True): + for cached_ty, cached_result in cls._gvar.get(lookup_key, []): + if tvm_ffi.structural_equal(cached_ty, arg_ty, map_free_vars=True): return cached_result func_name = _camel_to_snake(cls.__name__) - func_params = [relax.Var(name, sinfo) for name, sinfo in zip(func_args, arg_sinfo.fields)] + func_params = [relax.Var(name, ty) for name, ty in zip(func_args, arg_ty.fields)] old_forward_args = [ nn.Tensor(_expr=param) if isinstance(old_arg, nn.Tensor) else param for param, old_arg in zip(func_params, func_args.values()) @@ -175,7 +175,7 @@ def _get_subroutine( gvar = block_builder.emit_func_output(out) # The relax.Var instances in model_params, along with any - # tirx.Var instances in the struct info, appear in both the + # tirx.Var instances in the type, appear in both the # calling scope and as parameters for the subroutine. To # maintain SSA, replace all relax and TIR variables in the # subroutine. @@ -184,5 +184,5 @@ def _get_subroutine( result = (gvar, is_nn_tensor_output) bucket = cls._gvar.setdefault(lookup_key, []) - bucket.append((arg_sinfo, result)) + bucket.append((arg_ty, result)) return result diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index cdb213f10db4..0c36d1c0fb77 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -340,9 +340,9 @@ def _impl_v10(cls, bb, inputs, attr, params): x, scale = inputs[0], inputs[1] zp = inputs[2] if len(inputs) > 2 and inputs[2] is not None else None axis = attr.get("axis", 1) - if hasattr(x.struct_info, "ndim") and x.struct_info.ndim <= 1 and axis == 1: + if hasattr(x.ty, "ndim") and x.ty.ndim <= 1 and axis == 1: axis = 0 - out_dtype = "uint8" if zp is None else zp.struct_info.dtype + out_dtype = "uint8" if zp is None else zp.ty.dtype if zp is None: zp = relax.const(0, out_dtype) return relax.op.quantize(x, scale, zp, axis=axis, out_dtype=out_dtype) @@ -352,9 +352,9 @@ def _impl_v13(cls, bb, inputs, attr, params): x, scale = inputs[0], inputs[1] zp = inputs[2] if len(inputs) > 2 and inputs[2] is not None else None axis = attr.get("axis", 1) - if hasattr(x.struct_info, "ndim") and x.struct_info.ndim <= 1 and axis == 1: + if hasattr(x.ty, "ndim") and x.ty.ndim <= 1 and axis == 1: axis = 0 - out_dtype = "uint8" if zp is None else zp.struct_info.dtype + out_dtype = "uint8" if zp is None else zp.ty.dtype if zp is None: zp = relax.const(0, out_dtype) return relax.op.quantize(x, scale, zp, axis=axis, out_dtype=out_dtype) @@ -366,10 +366,10 @@ def _impl_v10(cls, bb, inputs, attr, params): x, scale = inputs[0], inputs[1] zp = inputs[2] if len(inputs) > 2 and inputs[2] is not None else None axis = attr.get("axis", 1) - if hasattr(x.struct_info, "ndim") and x.struct_info.ndim <= 1 and axis == 1: + if hasattr(x.ty, "ndim") and x.ty.ndim <= 1 and axis == 1: axis = 0 if zp is None: - zp = relax.const(0, x.struct_info.dtype) + zp = relax.const(0, x.ty.dtype) return relax.op.dequantize(x, scale, zp, axis=axis, out_dtype="float32") @classmethod @@ -377,10 +377,10 @@ def _impl_v13(cls, bb, inputs, attr, params): x, scale = inputs[0], inputs[1] zp = inputs[2] if len(inputs) > 2 and inputs[2] is not None else None axis = attr.get("axis", 1) - if hasattr(x.struct_info, "ndim") and x.struct_info.ndim <= 1 and axis == 1: + if hasattr(x.ty, "ndim") and x.ty.ndim <= 1 and axis == 1: axis = 0 if zp is None: - zp = relax.const(0, x.struct_info.dtype) + zp = relax.const(0, x.ty.dtype) return relax.op.dequantize(x, scale, zp, axis=axis, out_dtype="float32") @@ -388,7 +388,7 @@ class DynamicQuantizeLinear(OnnxOpConverter): @classmethod def _impl_v11(cls, bb, inputs, attr, params): x = inputs[0] - x_dtype = x.struct_info.dtype + x_dtype = x.ty.dtype qmin = relax.const(0, x_dtype) qmax = relax.const(255, x_dtype) @@ -420,22 +420,18 @@ def _impl_v1(cls, bb, inputs, attr, params): raise ValueError(f"MatMulInteger16 expects two inputs, but got {len(inputs)}") a, b = inputs valid_types = ["int16", "uint16"] - if a.struct_info.dtype not in valid_types: + if a.ty.dtype not in valid_types: raise ValueError( "MatMulInteger16 expects input A to have int16 or uint16 dtype, " - f"but got {a.struct_info.dtype}" + f"but got {a.ty.dtype}" ) - if b.struct_info.dtype not in valid_types: + if b.ty.dtype not in valid_types: raise ValueError( "MatMulInteger16 expects input B to have int16 or uint16 dtype, " - f"but got {b.struct_info.dtype}" + f"but got {b.ty.dtype}" ) - out_dtype = ( - "uint32" - if a.struct_info.dtype == "uint16" and b.struct_info.dtype == "uint16" - else "int32" - ) + out_dtype = "uint32" if a.ty.dtype == "uint16" and b.ty.dtype == "uint16" else "int32" return relax.op.matmul( relax.op.astype(a, out_dtype), relax.op.astype(b, out_dtype), @@ -535,8 +531,8 @@ class Div(BinaryBase): @classmethod def _impl_v7(cls, bb, inputs, attr, params): try: - lhs_code = DataType(inputs[0].struct_info.dtype).type_code - rhs_code = DataType(inputs[1].struct_info.dtype).type_code + lhs_code = DataType(inputs[0].ty.dtype).type_code + rhs_code = DataType(inputs[1].ty.dtype).type_code except (AttributeError, ValueError, TypeError, RuntimeError): return cls.base_impl(bb, inputs, attr, params) @@ -682,10 +678,10 @@ def base_impl(cls, bb, inputs, attr, params): """Base implementation for bitwise operations.""" valid_types = ["int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64"] for num, inp in enumerate(inputs): - if inp.struct_info.dtype not in valid_types: + if inp.ty.dtype not in valid_types: raise ValueError( f"Bitwise operations expect all inputs to have integer types, " - f"got {inp.struct_info.dtype} for input {num}" + f"got {inp.ty.dtype} for input {num}" ) return super().base_impl(bb, inputs, attr, params) @@ -729,7 +725,7 @@ class BitwiseNot(OnnxOpConverter): @classmethod def _impl_v18(cls, bb, inputs, attr, params): if isinstance(inputs[0], relax.Constant): - return relax.const(_np.bitwise_not(inputs[0].data.numpy()), inputs[0].struct_info.dtype) + return relax.const(_np.bitwise_not(inputs[0].data.numpy()), inputs[0].ty.dtype) return relax.op.bitwise_not(inputs[0]) @@ -798,13 +794,13 @@ def _legacy_softmax_prepare( return None axis = _normalize_legacy_softmax_axis(axis, rank, op_name) - struct_info = data.struct_info - if not isinstance(struct_info, relax.TensorStructInfo): + ty = data.ty + if not isinstance(ty, relax.TensorType): return None - if not isinstance(struct_info.shape, relax.ShapeExpr): + if not isinstance(ty.shape, relax.ShapeExpr): return None - original_shape = list(struct_info.shape.values) + original_shape = list(ty.shape.values) if len(original_shape) != rank: return None @@ -822,11 +818,9 @@ def _get_axis_extent(data: relax.Expr, axis: int, op_name: str) -> tuple[int, in raise ValueError(f"{op_name} requires a statically known input rank.") normalized_axis = _normalize_constant_axes([axis], rank, op_name)[0] - struct_info = data.struct_info - if isinstance(struct_info, relax.TensorStructInfo) and isinstance( - struct_info.shape, relax.ShapeExpr - ): - axis_extent = struct_info.shape.values[normalized_axis] + ty = data.ty + if isinstance(ty, relax.TensorType) and isinstance(ty.shape, relax.ShapeExpr): + axis_extent = ty.shape.values[normalized_axis] if isinstance(axis_extent, tirx.IntImm): axis_extent = int(axis_extent.value) return normalized_axis, axis_extent @@ -908,7 +902,7 @@ def _hardmax_impl(cls, *args): if bb is not None: data = bb.normalize(data) normalized_axis, axis_extent = _get_axis_extent(data, axis, "Hardmax") - dtype = data.struct_info.dtype + dtype = data.ty.dtype argmax = relax.op.argmax(data, axis=normalized_axis) on_value = relax.PrimValue(tvm.tirx.const(1.0, dtype)) off_value = relax.PrimValue(tvm.tirx.const(0.0, dtype)) @@ -949,10 +943,10 @@ def _impl_v13(cls, bb, inputs, attr, params): data = inputs[0] axes = attr.get("perm", None) - if hasattr(data.struct_info, "ndim"): - input_ndim = data.struct_info.ndim - elif hasattr(data.struct_info, "shape") and data.struct_info.shape: - input_ndim = len(data.struct_info.shape) + if hasattr(data.ty, "ndim"): + input_ndim = data.ty.ndim + elif hasattr(data.ty, "shape") and data.ty.shape: + input_ndim = len(data.ty.shape) else: if isinstance(data, relax.Constant): input_ndim = data.data.numpy().ndim @@ -1015,7 +1009,7 @@ def _impl_v13(cls, bb, inputs, attr, params): else: new_shape.append(next(input_dims_iter)) expanded = expanded.reshape(new_shape) - return relax.const(expanded, data.struct_info.dtype) + return relax.const(expanded, data.ty.dtype) if isinstance(axes, relax.Constant): if data_ndim is None: @@ -1056,7 +1050,7 @@ def is_shape_like(x: Any) -> bool: if isinstance(x, relax.ShapeExpr): return True elif isinstance(x, relax.Constant): - return x.struct_info.ndim == 1 and x.struct_info.dtype == "int64" + return x.ty.ndim == 1 and x.ty.dtype == "int64" else: return False @@ -1090,7 +1084,7 @@ def resolve(x): for inp in inputs: const_inputs.append(inp.data.numpy()) out = _np.concatenate(const_inputs, axis=axis) - dtype = inputs[0].struct_info.dtype + dtype = inputs[0].ty.dtype return relax.const(out, dtype) return relax.op.concat(inputs, axis=axis) @@ -1120,7 +1114,7 @@ def _impl_v13(cls, bb, inputs, attr, params): if np_dst.kind in ("i", "u"): src = inputs[0] - src_dtype = getattr(getattr(src, "struct_info", None), "dtype", None) or getattr( + src_dtype = getattr(getattr(src, "ty", None), "dtype", None) or getattr( src, "dtype", None ) if src_dtype is not None and _relax_dtype_is_floating_point(src_dtype): @@ -1199,7 +1193,7 @@ def _impl_v13(cls, bb, inputs, attr, params): shape_val = data[np_index] return relax.PrimValue(shape_val) - indices_dtype = indices.struct_info.dtype + indices_dtype = indices.ty.dtype if not indices_dtype.startswith("uint"): data_shape = bb.normalize(relax.op.shape_of(data)) data_shape_tensor = bb.normalize(relax.op.shape_to_tensor(data_shape)) @@ -1320,13 +1314,13 @@ def _impl_v11(cls, bb, inputs, attr, params): axis = attr.get("axis", None) # Change one hot tensor to indices e.g. [0, 1, 1, 0, 1] -> [1, 2, 4] - if condition.struct_info.dtype != "bool": + if condition.ty.dtype != "bool": raise ValueError("Condition tensor is expected to be a boolean tensor") - if condition.struct_info.ndim != 1: + if condition.ty.ndim != 1: raise ValueError("Condition tensor is expected to be a 1D boolean tensor") indices = relax.op.nonzero(condition) num_nonzero = tirx.Var("num_nonzero", "int64") - indices = bb.match_cast(indices, relax.TensorStructInfo([1, num_nonzero], "int64")) + indices = bb.match_cast(indices, relax.TensorType([1, num_nonzero], "int64")) indices = relax.op.reshape(indices, [-1]) if axis is not None: @@ -1351,7 +1345,7 @@ class EyeLike(OnnxOpConverter): @classmethod def _impl_v9(cls, bb, inputs, attr, params): k = attr.get("k", 0) - input_dtype = inputs[0].struct_info.dtype + input_dtype = inputs[0].ty.dtype if "dtype" in attr and get_type(attr["dtype"]) != input_dtype: raise ValueError( f"dtype mismatch between input ({input_dtype}) and attribute ({attr['dtype']})" @@ -1371,7 +1365,7 @@ def _impl_v13(cls, bb, inputs, attr, params): A = inputs[0] B = inputs[1] C = inputs[2] - dtype = A.struct_info.dtype + dtype = A.ty.dtype # Compute Y = alpha * A X B + beta * C @@ -1444,7 +1438,7 @@ class Clip(OnnxOpConverter): @staticmethod def _sanitize_nan_clip_bound(bb, bound: relax.Expr, *, for_min: bool) -> relax.Expr: """ONNX/ORT treat NaN clip bounds as unbounded; plain max/min with NaN poisons output.""" - dtype = bound.struct_info.dtype + dtype = bound.ty.dtype if not _relax_dtype_is_floating_point(dtype): return bound repl = -_np.inf if for_min else _np.inf @@ -1482,14 +1476,14 @@ class Shape(OnnxOpConverter): @classmethod def _impl_v13(cls, bb, inputs, attr, params): - data_info = inputs[0].struct_info + data_info = inputs[0].ty - if isinstance(data_info, relax.ShapeStructInfo): + if isinstance(data_info, relax.ShapeType): if data_info.ndim == -1: raise ValueError("The ndim of ShapeExpr is expected to a real number, but got -1.") return relax.ShapeExpr([data_info.ndim]) - # If no shape is defined in the struct info, it must be computed at runtime. + # If no shape is defined in the type, it must be computed at runtime. if not data_info.shape: data_shape = bb.normalize(relax.op.shape_of(inputs[0])) return data_shape @@ -1563,7 +1557,7 @@ class Mish(OnnxOpConverter): @classmethod def _impl_v18(cls, bb, inputs, attr, params): - dtype = inputs[0].struct_info.dtype + dtype = inputs[0].ty.dtype return inputs[0] * relax.op.tanh( relax.op.log(relax.const(1.0, dtype) + relax.op.exp(inputs[0])) ) @@ -1580,8 +1574,8 @@ def _impl_v1(cls, bb, inputs, attr, params): x = inputs[0] slope = inputs[1] - x_shape = x.struct_info.shape - slope_shape = slope.struct_info.shape + x_shape = x.ty.shape + slope_shape = slope.ty.shape ndim = len(x_shape) s_ndim = len(slope_shape) @@ -1673,12 +1667,12 @@ def _impl_v1(cls, bb, inputs, attr, params): x = inputs[0] if len(inputs) > 1 and inputs[1] is not None: bias = inputs[1] - bias_shape = bias.struct_info.shape + bias_shape = bias.ty.shape assert len(bias_shape) == 1, "bias term must be a 1D tensor" x = bb.emit(relax.op.add(x, bias)) # Declare consts - const_dtype = x.struct_info.dtype + const_dtype = x.ty.dtype half = relax.const(0.5, dtype=const_dtype) one = relax.const(1.0, dtype=const_dtype) const1 = relax.const(math.sqrt(2 / math.pi), dtype=const_dtype) @@ -1715,7 +1709,7 @@ class Shrink(OnnxOpConverter): @classmethod def _impl_v9(cls, bb, inputs, attr, params): x = inputs[0] - dtype = x.struct_info.dtype + dtype = x.ty.dtype lambd = relax.const(attr.get("lambd", 0.5), dtype) bias = relax.const(attr.get("bias", 0.0), dtype) zeros = relax.op.zeros_like(x) @@ -1730,13 +1724,13 @@ class Conv(OnnxOpConverter): @classmethod def _impl_v11(cls, bb, inputs, attr, params): data = inputs[0] - if hasattr(inputs[0].struct_info, "ndim"): - ndim = inputs[0].struct_info.ndim + if hasattr(inputs[0].ty, "ndim"): + ndim = inputs[0].ty.ndim else: - ndim = len(inputs[0].struct_info.shape) + ndim = len(inputs[0].ty.shape) if "kernel_shape" not in attr: - attr["kernel_shape"] = inputs[1].struct_info.shape.values[2:] + attr["kernel_shape"] = inputs[1].ty.shape.values[2:] if ndim == 3: op = relax.op.nn.conv1d @@ -1800,10 +1794,10 @@ class ConvTranspose(OnnxOpConverter): @classmethod def _impl_v1(cls, bb, inputs, attr, params): - if hasattr(inputs[0].struct_info, "ndim"): - ndim = inputs[0].struct_info.ndim + if hasattr(inputs[0].ty, "ndim"): + ndim = inputs[0].ty.ndim else: - ndim = len(inputs[0].struct_info.shape) + ndim = len(inputs[0].ty.shape) if ndim == 3: op = relax.op.nn.conv1d_transpose @@ -1827,7 +1821,7 @@ def _impl_v1(cls, bb, inputs, attr, params): if "kernel_shape" in attr: kernel_shape = list(attr["kernel_shape"]) else: - kernel_shape = [int(s) for s in inputs[1].struct_info.shape.values[2:]] + kernel_shape = [int(s) for s in inputs[1].ty.shape.values[2:]] # Resolve `auto_pad` per ONNX ConvTranspose spec. Unlike Conv, the spec # derives `pads` from `output_shape`/`strides` when auto_pad is SAME_*, @@ -1917,9 +1911,7 @@ def _impl_v14(cls, bb, inputs, attr, params): f"got shape {axis_data.shape}" ) elif isinstance(axis_input, relax.Var): - axis_shape = ( - axis_input.struct_info.shape if hasattr(axis_input.struct_info, "shape") else None - ) + axis_shape = axis_input.ty.shape if hasattr(axis_input.ty, "shape") else None raise ValueError( "CumSum with non-constant axis input is not supported yet. " "ONNX permits runtime axis tensors, but Relax/TE currently requires a compile-time " @@ -1957,7 +1949,7 @@ def _impl_v13(cls, bb, inputs, attr, params): else: raise NotImplementedError("Squeeze with symbolic axes not supported") - return relax.const(out_data, data.struct_info.dtype) + return relax.const(out_data, data.ty.dtype) if isinstance(data, relax.ShapeExpr): shape_tensor_ndim = 1 @@ -2146,7 +2138,7 @@ class Neg(OnnxOpConverter): def _impl_v13(cls, bb, inputs, attr, params): if isinstance(inputs[0], relax.Constant): data_np = inputs[0].data.numpy() - return relax.const(_np.negative(data_np), inputs[0].struct_info.dtype) + return relax.const(_np.negative(data_np), inputs[0].ty.dtype) if isinstance(inputs[0], relax.PrimValue): return relax.PrimValue(-inputs[0].value) return relax.op.negative(inputs[0]) @@ -2168,7 +2160,7 @@ class Reciprocal(OnnxOpConverter): @classmethod def _impl_v13(cls, bb, inputs, attr, params): - input_dtype = inputs[0].struct_info.dtype + input_dtype = inputs[0].ty.dtype return relax.op.divide(relax.const(1, dtype=input_dtype), inputs[0]) @@ -2251,7 +2243,7 @@ def _impl_v1(cls, bb, inputs, attr, params): output = cls.numpy_op(*np_inputs) # pylint: disable=not-callable return relax.const(output, output.dtype) - input_shapes = [inp.struct_info.shape for inp in inputs] + input_shapes = [inp.ty.shape for inp in inputs] target_shape = functools.reduce(compute_broadcast_shape, input_shapes) # broadcast_to, stack them, then perform minimum over the new axis. @@ -2294,7 +2286,7 @@ class Log(OnnxOpConverter): @classmethod def _impl_v13(cls, bb, inputs, attr, params): if isinstance(inputs[0], relax.Constant): - return relax.const(_np.log(inputs[0].data.numpy()), inputs[0].struct_info.dtype) + return relax.const(_np.log(inputs[0].data.numpy()), inputs[0].ty.dtype) return relax.op.log(inputs[0]) @@ -2309,7 +2301,7 @@ def _check_type(cls, dtype, valid_types): def _impl_v1(cls, bb, inputs, attr, params): data = inputs[0] valid_types = ["float", "float32", "double", "float64", "float16"] - cls._check_type(data.struct_info.dtype, valid_types) + cls._check_type(data.ty.dtype, valid_types) return relax.op.exp(data) @@ -2317,7 +2309,7 @@ def _impl_v1(cls, bb, inputs, attr, params): def _impl_v13(cls, bb, inputs, attr, params): data = inputs[0] valid_types = ["float", "float32", "double", "float64", "float16", "bfloat16"] - cls._check_type(data.struct_info.dtype, valid_types) + cls._check_type(data.ty.dtype, valid_types) return relax.op.exp(data) @@ -2327,7 +2319,7 @@ class Softplus(OnnxOpConverter): @classmethod def _impl_v1(cls, bb, inputs, attr, params): - dtype = inputs[0].struct_info.dtype + dtype = inputs[0].ty.dtype threshold = 10.0 if dtype == "float16" else 20.0 return relax.op.nn.softplus(inputs[0], threshold=threshold) @@ -2337,7 +2329,7 @@ class Softsign(OnnxOpConverter): @classmethod def _impl_v1(cls, bb, inputs, attr, params): - dtype = inputs[0].struct_info.dtype + dtype = inputs[0].ty.dtype return inputs[0] / (relax.op.abs(inputs[0]) + relax.const(1, dtype=dtype)) @@ -2363,7 +2355,7 @@ def _impl_v13(cls, bb, inputs, attr, params): splits = inputs[1] splits_rank = None if splits is not None: - splits_rank = splits.struct_info.ndim + splits_rank = splits.ty.ndim if splits is not None and splits_rank > 0: if isinstance(splits, relax.Constant): splits = splits.data.numpy() @@ -2399,9 +2391,9 @@ def _get_known_tensor_rank(expr: relax.Expr) -> int | None: return 1 if isinstance(expr, relax.PrimValue): return 0 - struct_info = expr.struct_info - if isinstance(struct_info, relax.TensorStructInfo): - return None if struct_info.ndim == -1 else struct_info.ndim + ty = expr.ty + if isinstance(ty, relax.TensorType): + return None if ty.ndim == -1 else ty.ndim return None @@ -2419,15 +2411,15 @@ def _get_known_tensor_length(expr: relax.Expr | None) -> int | None: return len(expr.values) if isinstance(expr, relax.PrimValue): return 1 - struct_info = expr.struct_info - if not isinstance(struct_info, relax.TensorStructInfo): + ty = expr.ty + if not isinstance(ty, relax.TensorType): return None - if struct_info.ndim == -1: + if ty.ndim == -1: return None - if struct_info.ndim != 1: - raise ValueError(f"Expected a 1-D tensor, but got ndim={struct_info.ndim}.") - if isinstance(struct_info.shape, relax.ShapeExpr): - dim = struct_info.shape.values[0] + if ty.ndim != 1: + raise ValueError(f"Expected a 1-D tensor, but got ndim={ty.ndim}.") + if isinstance(ty.shape, relax.ShapeExpr): + dim = ty.shape.values[0] if isinstance(dim, tirx.IntImm): return int(dim.value) if isinstance(dim, int): @@ -2459,10 +2451,10 @@ def _as_int64_tensor(bb: relax.BlockBuilder, expr: relax.Expr) -> relax.Expr: if isinstance(expr, relax.PrimValue): return bb.normalize(relax.op.full((1,), expr, dtype="int64")) if isinstance(expr, relax.Constant): - if expr.struct_info.dtype == "int64": + if expr.ty.dtype == "int64": return expr return bb.normalize(relax.op.astype(expr, "int64")) - if isinstance(expr.struct_info, relax.TensorStructInfo) and expr.struct_info.dtype != "int64": + if isinstance(expr.ty, relax.TensorType) and expr.ty.dtype != "int64": return bb.normalize(relax.op.astype(expr, "int64")) return expr @@ -2472,10 +2464,10 @@ def _tensor_to_shape_expr( ) -> relax.ShapeExpr: """Convert a statically sized int64 tensor into a ShapeExpr.""" - shape_tensor = bb.match_cast(shape_tensor, relax.TensorStructInfo([shape_ndim], "int64")) + shape_tensor = bb.match_cast(shape_tensor, relax.TensorType([shape_ndim], "int64")) shape_dataflow_var = bb.emit(relax.op.tensor_to_shape(shape_tensor)) shape_vars = [tirx.Var(f"{prefix}_{i}", "int64") for i in range(shape_ndim)] - bb.match_cast(shape_dataflow_var, relax.ShapeStructInfo(shape_vars)) + bb.match_cast(shape_dataflow_var, relax.ShapeType(shape_vars)) return relax.ShapeExpr(shape_vars) @@ -2546,7 +2538,7 @@ def _build_squeezed_shape_tensor( keep_mask = bb.normalize(relax.op.equal(remove_mask, relax.const(0, "int64"))) keep_indices = bb.normalize(relax.op.nonzero(keep_mask)) num_keep_dims = tirx.Var("squeeze_num_keep_dims", "int64") - keep_indices = bb.match_cast(keep_indices, relax.TensorStructInfo([1, num_keep_dims], "int64")) + keep_indices = bb.match_cast(keep_indices, relax.TensorType([1, num_keep_dims], "int64")) keep_indices = bb.normalize(relax.op.reshape(keep_indices, [-1])) return bb.normalize(relax.op.take(data_shape_tensor, keep_indices, axis=0)) @@ -2701,7 +2693,7 @@ class Pad(OnnxOpConverter): @classmethod def _impl_v2(cls, bb, inputs, attr, params): pads = attr.get("pads") - pads = relax.const(_np.array(pads), inputs[0].struct_info.shape[0].dtype) + pads = relax.const(_np.array(pads), inputs[0].ty.shape[0].dtype) constant_value = attr.get("value") if constant_value is None: constant_value = 0.0 @@ -2763,7 +2755,7 @@ class Tile(OnnxOpConverter): @staticmethod def _tensor_length(expr): - shape = expr.struct_info.shape + shape = expr.ty.shape if not isinstance(shape, relax.ShapeExpr): return None @@ -2780,12 +2772,12 @@ def _impl_v13(cls, bb, inputs, attr, params): return bb.emit_te(topi.tile, inputs[0], reps) data = inputs[0] - data_ndim = data.struct_info.ndim + data_ndim = data.ty.ndim reps_len = cls._tensor_length(reps) if data_ndim == -1 or reps_len is None: raise ValueError("Dynamic Tile requires known input rank and repeats length.") - if reps.struct_info.dtype != "int64": + if reps.ty.dtype != "int64": reps = bb.normalize(relax.op.astype(reps, "int64")) data_shape = bb.normalize(relax.op.shape_of(data)) @@ -2810,7 +2802,7 @@ def _impl_v13(cls, bb, inputs, attr, params): output_shape_vars = [ tirx.Var(f"tile_dim_{i}", "int64") for i in range(max(data_ndim, reps_len)) ] - bb.match_cast(output_shape, relax.ShapeStructInfo(output_shape_vars)) + bb.match_cast(output_shape, relax.ShapeType(output_shape_vars)) return bb.emit_te(topi.dyn_tile, data, output_shape_vars, reps_len) @@ -2822,7 +2814,7 @@ def _impl_v13(cls, bb, inputs, attr, params): data = inputs[0] shape = inputs[1] if isinstance(shape, relax.ShapeExpr): - data_shape = list(data.struct_info.shape) + data_shape = list(data.ty.shape) target_shape = list(shape.values) original_data_shape = [ dim.value if hasattr(dim, "value") else str(dim) for dim in data_shape @@ -2875,7 +2867,7 @@ def _impl_v13(cls, bb, inputs, attr, params): new_shape = shape.data.numpy().tolist() # ONNX Expand operator requires preserving target rank and broadcasting # according to standard rules. Dimensions are right-aligned. - data_shape = [dim.value for dim in data.struct_info.shape] + data_shape = [dim.value for dim in data.ty.shape] original_data_shape = data_shape.copy() original_new_shape = new_shape.copy() @@ -2916,22 +2908,22 @@ def _impl_v13(cls, bb, inputs, attr, params): return relax.op.broadcast_to(data, relax.ShapeExpr(new_shape)) # Otherwise handle dynamic shapes. - shape_ndim = next(dim.value for dim in shape.struct_info.shape.values) + shape_ndim = next(dim.value for dim in shape.ty.shape.values) shape_dataflow_var = bb.emit( relax.Call( relax.ExternFunc("vm.builtin.tensor_to_shape"), [shape], - sinfo_args=[relax.ShapeStructInfo(ndim=shape_ndim)], + ty_args=[relax.ShapeType(ndim=shape_ndim)], ) ) shape_vars = [] for i in range(shape_ndim): shape_vars.append(tvm.tirx.Var(f"x_{i}", "int64")) - bb.match_cast(shape_dataflow_var, relax.ShapeStructInfo(shape_vars)) + bb.match_cast(shape_dataflow_var, relax.ShapeType(shape_vars)) # Applying broadcasting rules for dynamic shapes - data_shape = list(data.struct_info.shape) + data_shape = list(data.ty.shape) data_ndim = len(data_shape) target_ndim = shape_ndim padded_data = data @@ -2993,10 +2985,8 @@ def optional_input(k: int): assert inputs[6] is None, "past_sequence_length is not currently supported" - (batch_size, seq_len, input_hidden_size) = [ - val.value for val in input_emb.struct_info.shape.values - ] - weight_shape = [val.value for val in weight.struct_info.shape.values] + (batch_size, seq_len, input_hidden_size) = [val.value for val in input_emb.ty.shape.values] + weight_shape = [val.value for val in weight.ty.shape.values] assert weight_shape[0] == input_hidden_size, ( "input and weight should share the same input hiden size" @@ -3017,7 +3007,7 @@ def optional_input(k: int): head_size_v = hidden_size_v // num_heads if mask_index is not None: - mask_index_shape = [val.value for val in mask_index.struct_info.shape.values] + mask_index_shape = [val.value for val in mask_index.ty.shape.values] assert mask_index_shape in ( [batch_size, seq_len], [ @@ -3027,14 +3017,12 @@ def optional_input(k: int): ], ), """mask index should be in shape of (batch_size, seq_len), or (batch_size, seq_len, seq_len)""" - mask_bias = relax.op.subtract( - relax.const(1, dtype=mask_index.struct_info.dtype), mask_index - ) - mask_bias = relax.op.astype(mask_bias, dtype=input_emb.struct_info.dtype) + mask_bias = relax.op.subtract(relax.const(1, dtype=mask_index.ty.dtype), mask_index) + mask_bias = relax.op.astype(mask_bias, dtype=input_emb.ty.dtype) mask_bias = bb.normalize( relax.op.multiply( mask_bias, - relax.const(mask_filter_value, dtype=input_emb.struct_info.dtype), + relax.const(mask_filter_value, dtype=input_emb.ty.dtype), ) ) if qk_bias is None: @@ -3053,7 +3041,7 @@ def optional_input(k: int): QKV = relax.op.matmul(input_emb, weight) if bias: - bias_shape = [val.value for val in bias.struct_info.shape.values] + bias_shape = [val.value for val in bias.ty.shape.values] assert bias_shape[0] == weight_shape[1], ( "bias and weight should share the same hidden size sum" ) @@ -3206,7 +3194,7 @@ def _impl_v18(cls, bb, inputs, attr, params): roi = get_constant(inputs[1], params) if len(inputs) > 1 and inputs[1] is not None else None scales = get_constant(inputs[2], params) if len(inputs) > 2 else None sizes = get_constant(inputs[3], params) if len(inputs) > 3 else None - ndims = len(x.struct_info.shape) + ndims = len(x.ty.shape) assert ndims in (3, 4, 5), "Only resize1d/resize2d/resize3d are supported." assert scales is None or sizes is None, ( @@ -3245,7 +3233,7 @@ def _impl_v18(cls, bb, inputs, attr, params): raise ValueError(f"Type {type(scales)} for scale is currently unsupported.") sizes = [] - for i, dim in enumerate(x.struct_info.shape): + for i, dim in enumerate(x.ty.shape): sizes.append(cast(scales[i] * dim, "int64")) sizes = sizes[2:] else: @@ -3365,7 +3353,7 @@ def _impl(cls, bb, inputs, attr, params, default_coordinate_transformation_mode) data = inputs[0] rois = inputs[1] batch_indices = inputs[2] - rois_dtype = rois.struct_info.dtype + rois_dtype = rois.ty.dtype mode = attr.get("mode", b"avg") if isinstance(mode, bytes): @@ -3449,7 +3437,7 @@ def _impl_v12(cls, bb, inputs, attr, params): start = get_constant(inputs[0], params) limit = get_constant(inputs[1], params) delta = get_constant(inputs[2], params) - out_dtype = start.struct_info.dtype + out_dtype = start.ty.dtype if isinstance(start, relax.Constant): start = start.data.numpy().tolist() @@ -3478,9 +3466,9 @@ def _impl_v6(cls, bb, inputs, attr, params): scale = inputs[1] B = inputs[2] epsilon = attr.get("epsilon", 1e-05) - epsilon = relax.const(epsilon, dtype=data.struct_info.dtype) + epsilon = relax.const(epsilon, dtype=data.ty.dtype) - ndim = len(data.struct_info.shape) + ndim = len(data.ty.shape) redux_axes = list(range(2, ndim)) mean = relax.op.mean(data, axis=redux_axes, keepdims=True) @@ -3551,10 +3539,10 @@ def _impl_v13(cls, bb, inputs, attr, params): beta = attr.get("beta", 0.75) bias = attr.get("bias", 1.0) - if hasattr(data.struct_info, "ndim"): - ndim = data.struct_info.ndim + if hasattr(data.ty, "ndim"): + ndim = data.ty.ndim else: - ndim = len(data.struct_info.shape) + ndim = len(data.ty.shape) if ndim not in [3, 4]: raise ValueError(f"LRN only supports 3D or 4D input, got {ndim}D.") @@ -3619,7 +3607,7 @@ def get_pad_pair(cls, input1d, kernel1d, stride1d, mode): def _impl_v1(cls, bb, inputs, attr, params): # Unpack inputs and attributes. data = inputs[0] - input_shape = data.struct_info.shape + input_shape = data.ty.shape ndim = len(input_shape) auto_pad = attr.get("auto_pad", b"NOTSET").decode("utf-8") @@ -3678,7 +3666,7 @@ def _impl_v1(cls, bb, inputs, attr, params): @classmethod def _get_input_spatial_shape(cls, tensor): # shape is (N x C x D1 x D2 ... Dn) - return _np.array([int(d) for d in tensor.struct_info.shape], dtype="int64")[2:] + return _np.array([int(d) for d in tensor.ty.shape], dtype="int64")[2:] class MaxPool(Pool): @@ -3698,10 +3686,10 @@ class LpPool(OnnxOpConverter): @classmethod def _impl_v1(cls, bb, inputs, attr, params): - dtype = inputs[0].struct_info.dtype + dtype = inputs[0].ty.dtype p = attr.get("p", 2.0) reci_p = relax.const(1.0 / p, dtype=dtype) - # emit for get struct_info + # emit for get ty data = bb.emit(relax.op.power(inputs[0], relax.const(p, dtype=dtype))) attr.update({"count_include_pad": True}) avg_pool = AveragePool._impl_v1(bb, [data], attr, params) @@ -3715,7 +3703,7 @@ class GlobalAveragePool(OnnxOpConverter): @classmethod def _impl_v1(cls, bb, inputs, attr, params): - rank = len(inputs[0].struct_info.shape) + rank = len(inputs[0].ty.shape) axes = list(range(2, rank)) return relax.op.mean(inputs[0], axis=axes, keepdims=True) @@ -3725,7 +3713,7 @@ class GlobalMaxPool(OnnxOpConverter): @classmethod def _impl_v1(cls, bb, inputs, attr, params): - rank = len(inputs[0].struct_info.shape) + rank = len(inputs[0].ty.shape) axes = list(range(2, rank)) return relax.op.max(inputs[0], axis=axes, keepdims=True) @@ -3736,8 +3724,8 @@ class GlobalLpPool(OnnxOpConverter): @classmethod def _impl_v2(cls, bb, inputs, attr, params): p = attr.get("p", 2.0) - dtype = inputs[0].struct_info.dtype - rank = len(inputs[0].struct_info.shape) + dtype = inputs[0].ty.dtype + rank = len(inputs[0].ty.shape) axes = list(range(2, rank)) x_abs = relax.op.abs(inputs[0]) x_p = relax.op.power(x_abs, relax.const(p, dtype=dtype)) @@ -3758,7 +3746,7 @@ def _impl_v9(cls, bb, inputs, attr, params): strides = attr.get("strides", [1] * len(kernel_shape)) multiplier = _np.concatenate([[1, 1], list(strides)]) - shape = [v.value for v in data.struct_info.shape] + shape = [v.value for v in data.ty.shape] total_output_shape = multiplier * shape # Add extra dimensions from kernel size and stride mismatch total_output_shape += _np.concatenate([[0, 0], list(kernel_shape)], axis=0) @@ -3778,7 +3766,7 @@ def _impl_v9(cls, bb, inputs, attr, params): # Create a tensor of zeros then scatter our data through it. relax_shape = relax.ShapeExpr(total_output_shape.tolist()) - zeros_tensor = bb.emit(relax.op.zeros(relax_shape, data.struct_info.dtype)) + zeros_tensor = bb.emit(relax.op.zeros(relax_shape, data.ty.dtype)) # We need to flatten all our tensors before scattering. flat_tensor = relax.op.scatter_elements( relax.op.reshape(zeros_tensor, [-1]), @@ -3797,7 +3785,7 @@ class Flatten(OnnxOpConverter): @classmethod def _impl_v13(cls, bb, inputs, attr, params): axis = attr.get("axis", 1) - data_shape = list(inputs[0].struct_info.shape) + data_shape = list(inputs[0].ty.shape) if axis == 0: new_shape = (1, -1) @@ -3829,12 +3817,12 @@ def _impl_v17(cls, bb, inputs, attr, params): axis = attr.get("axis", -1) epsilon = attr.get("epsilon", 1e-05) - gamma_shape = get_const_tuple(scale.struct_info.shape) + gamma_shape = get_const_tuple(scale.ty.shape) if bias is None: - bias = relax.const(_np.zeros(gamma_shape), dtype=scale.struct_info.dtype) + bias = relax.const(_np.zeros(gamma_shape), dtype=scale.ty.dtype) else: - beta_shape = get_const_tuple(bias.struct_info.shape) + beta_shape = get_const_tuple(bias.ty.shape) if gamma_shape != beta_shape: raise ValueError("gamma and beta shapes do not match") @@ -3868,7 +3856,7 @@ def _impl_v23(cls, bb, inputs, attr, params): axes = list(range(axis, ndim)) # If stash_type requires float32 computation and input is not float32, cast - input_dtype = data.struct_info.dtype + input_dtype = data.ty.dtype if stash_type == 1 and input_dtype != "float32": data_compute = relax.op.astype(data, "float32") scale_compute = relax.op.astype(scale, "float32") @@ -4249,7 +4237,7 @@ def _argreduce_select_last_index(bb, data, axis, keepdims, op): """ data_flipped = relax.op.flip(data, axis=axis) flipped_idx = bb.normalize(op(data_flipped, axis, keepdims)) - axis_size = data.struct_info.shape[axis] + axis_size = data.ty.shape[axis] if isinstance(axis_size, tirx.IntImm): offset = relax.const(int(axis_size) - 1, "int64") else: @@ -4269,7 +4257,7 @@ class ArgMax(OnnxOpConverter): @classmethod def _check_attrs(cls, data, attr, shift_axis=True): - dims_num = len(data.struct_info.shape) + dims_num = len(data.ty.shape) axis = attr.get("axis", 0) if shift_axis and axis < 0: axis += dims_num @@ -4304,7 +4292,7 @@ class ArgMin(OnnxOpConverter): @classmethod def _check_attrs(cls, data, attr, shift_axis=True): - dims_num = len(data.struct_info.shape) + dims_num = len(data.ty.shape) axis = attr.get("axis", 0) if shift_axis and axis < 0: axis += dims_num @@ -4405,7 +4393,7 @@ def _impl_v1(cls, bb, inputs, attr, params): epsilon = attr.get("epsilon", 1e-12) - (batch_size, seq_len) = [dim.value for dim in input_ids.struct_info.shape] + (batch_size, seq_len) = [dim.value for dim in input_ids.ty.shape] if segment_ids: assert segment_emb @@ -4473,8 +4461,8 @@ def _impl_v11(cls, bb, inputs, attr, params): ) unique_numbers = tirx.Var("unique_numbers", "int64") - input_shape = data.struct_info.shape - dtype = data.struct_info.dtype + input_shape = data.ty.shape + dtype = data.ty.dtype if axis is None: output_shape = (unique_numbers,) @@ -4487,15 +4475,15 @@ def _impl_v11(cls, bb, inputs, attr, params): ] if num_outputs == 1: - return bb.match_cast(unique, relax.TensorStructInfo(output_shape, dtype)) + return bb.match_cast(unique, relax.TensorType(output_shape, dtype)) - outputs = [bb.match_cast(unique[0], relax.TensorStructInfo(output_shape, dtype))] + outputs = [bb.match_cast(unique[0], relax.TensorType(output_shape, dtype))] tuple_idx = 1 # Track which index in the tuple we're at if return_index: index_shape = (unique_numbers,) - index_sinfo = relax.TensorStructInfo(index_shape, "int64") - outputs.append(bb.match_cast(unique[tuple_idx], index_sinfo)) + index_ty = relax.TensorType(index_shape, "int64") + outputs.append(bb.match_cast(unique[tuple_idx], index_ty)) tuple_idx += 1 if return_inverse: @@ -4503,14 +4491,14 @@ def _impl_v11(cls, bb, inputs, attr, params): # When axis is None: shape is [X.size] # When axis is specified: shape is [X.shape[axis]] inverse_shape = (tirx.Var("inverse_numbers", "int64"),) - inverse_sinfo = relax.TensorStructInfo(inverse_shape, "int64") - outputs.append(bb.match_cast(unique[tuple_idx], inverse_sinfo)) + inverse_ty = relax.TensorType(inverse_shape, "int64") + outputs.append(bb.match_cast(unique[tuple_idx], inverse_ty)) tuple_idx += 1 if return_counts: count_shape = (unique_numbers,) - count_sinfo = relax.TensorStructInfo(count_shape, "int64") - outputs.append(bb.match_cast(unique[tuple_idx], count_sinfo)) + count_ty = relax.TensorType(count_shape, "int64") + outputs.append(bb.match_cast(unique[tuple_idx], count_ty)) return relax.Tuple(outputs) @@ -4520,11 +4508,11 @@ class NonZero(OnnxOpConverter): @classmethod def _impl_v9(cls, bb, inputs, attr, params): - ndim = inputs[0].struct_info.ndim + ndim = inputs[0].ty.ndim ndim = 1 if ndim == 0 else ndim nonzero_numbers = tirx.Var("nonzero_numbers", "int64") return bb.match_cast( - relax.op.nonzero(inputs[0]), relax.TensorStructInfo((ndim, nonzero_numbers), "int64") + relax.op.nonzero(inputs[0]), relax.TensorType((ndim, nonzero_numbers), "int64") ) @@ -4537,7 +4525,7 @@ def _impl_v9(cls, bb, inputs, attr, params): assert len(scales) == 4 assert scales[0] == scales[1] == 1 - inp_shape = [int(x) for x in inputs[0].struct_info.shape] + inp_shape = [int(x) for x in inputs[0].ty.shape] assert len(inp_shape) == 4 out_shape2d = [int(dim * scale) for dim, scale in zip(inp_shape[2:], scales[2:])] @@ -4563,7 +4551,7 @@ class HardSigmoid(OnnxOpConverter): @classmethod def _impl_v1(cls, bb, inputs, attr, params): x = inputs[0] - dtype = x.struct_info.dtype + dtype = x.ty.dtype alpha = float(attr.get("alpha", 0.2)) alpha = relax.const(alpha, dtype=dtype) beta = float(attr.get("beta", 0.5)) @@ -4577,7 +4565,7 @@ class HardSwish(OnnxOpConverter): @classmethod def _impl_v14(cls, bb, inputs, attr, params): x = inputs[0] - dtype = x.struct_info.dtype + dtype = x.ty.dtype return relax.op.multiply( x, relax.op.divide( @@ -4610,7 +4598,7 @@ class DepthToSpace(OnnxOpConverter): def _impl_v11(cls, bb, inputs, attr, params): block_size = int(attr["blocksize"]) mode = attr.get("mode", b"DCR").decode("utf-8") - b, c, h, w = inputs[0].struct_info.shape + b, c, h, w = inputs[0].ty.shape if mode == "DCR": x = relax.op.reshape(inputs[0], (b, block_size, block_size, c // (block_size**2), h, w)) x = relax.op.permute_dims(x, [0, 3, 4, 1, 5, 2]) @@ -4629,7 +4617,7 @@ class SpaceToDepth(OnnxOpConverter): @classmethod def _impl_v1(cls, bb, inputs, attr, params): block_size = int(attr["blocksize"]) - b, c, h, w = inputs[0].struct_info.shape + b, c, h, w = inputs[0].ty.shape x = relax.op.reshape( inputs[0], (b, c, h // block_size, block_size, w // block_size, block_size) ) @@ -4800,7 +4788,7 @@ def _impl_v11(cls, bb, inputs, attr, params): keepdims = attr.get("keepdims", 1) input_tensor = inputs[0] - input_shape = input_tensor.struct_info.shape + input_shape = input_tensor.ty.shape if len(inputs) == 1: split = _np.array(1) @@ -4831,7 +4819,7 @@ def _impl_v11(cls, bb, inputs, attr, params): # Per ONNX spec: "If input 'split' is specified, this attribute is ignored." if not keepdims and len(inputs) == 1: output = bb.emit(output) - n = len(output.struct_info.fields) + n = len(output.ty.fields) squeezed = [ relax.op.squeeze(bb.emit(relax.TupleGetItem(output, i)), axis=[axis]) for i in range(n) @@ -4920,8 +4908,8 @@ def _impl_v10(cls, bb, inputs, attr, params): yc = split_result[1] w = split_result[2] h = split_result[3] - half_w = w / relax.const(2.0, boxes.struct_info.dtype) - half_h = h / relax.const(2.0, boxes.struct_info.dtype) + half_w = w / relax.const(2.0, boxes.ty.dtype) + half_h = h / relax.const(2.0, boxes.ty.dtype) x1 = xc - half_w x2 = xc + half_w y1 = yc - half_h @@ -5009,8 +4997,8 @@ def _impl_v1(cls, bb, inputs, attr, params): yc = split_result[1] w = split_result[2] h = split_result[3] - half_w = w / relax.const(2.0, boxes.struct_info.dtype) - half_h = h / relax.const(2.0, boxes.struct_info.dtype) + half_w = w / relax.const(2.0, boxes.ty.dtype) + half_h = h / relax.const(2.0, boxes.ty.dtype) x1 = xc - half_w x2 = xc + half_w y1 = yc - half_h @@ -5055,10 +5043,10 @@ def _impl_v16(cls, bb, inputs, attr, params): align_corners = bool(attr.get("align_corners", 0)) - if hasattr(data.struct_info, "ndim"): - ndim = data.struct_info.ndim + if hasattr(data.ty, "ndim"): + ndim = data.ty.ndim else: - ndim = len(data.struct_info.shape) + ndim = len(data.ty.shape) if ndim == 5 and method == "bicubic": raise NotImplementedError( @@ -5118,8 +5106,8 @@ def _impl_v10(cls, bb, inputs, attr, params): a_zp = relax.op.astype( a_zero_point, "int32" ) # Ensure zero point is int32 for subtraction - a_zp = bb.normalize(a_zp) # Normalize the expr so struct_info gets populated - a_zp_ndim = len(a_zp.struct_info.shape) + a_zp = bb.normalize(a_zp) # Normalize the expr so ty gets populated + a_zp_ndim = len(a_zp.ty.shape) # Per-row case: [M] -> [M, 1] so it broadcasts over [M, K] row-wise # N-D case: spec says shape is [D1, D2, M, 1], which already broadcasts correctly (no need to reshape) @@ -5131,7 +5119,7 @@ def _impl_v10(cls, bb, inputs, attr, params): if b_zero_point is not None: b_zp = relax.op.astype(b_zero_point, "int32") b_zp = bb.normalize(b_zp) - b_zp_ndim = len(b_zp.struct_info.shape) + b_zp_ndim = len(b_zp.ty.shape) # Per-col case: [N] -> [1, N] so it broadcasts over [K, N] column-wise # N-D case: [D1, D2, 1, N] already broadcasts correctly @@ -5467,9 +5455,7 @@ def _sanitize_name(self, name: str) -> str: def _new_var(self, var_name: str, shape: list, dtype: str = "float32"): """Creates a new Relax variable.""" - return relax.Var( - name_hint=var_name, struct_info=relax.TensorStructInfo(shape=shape, dtype=dtype) - ) + return relax.Var(name_hint=var_name, ty=relax.TensorType(shape=shape, dtype=dtype)) def _parse_graph_input(self, graph: onnx.onnx_ml_pb2.GraphProto): """Parse model inputs to Relax parameters.""" @@ -5579,7 +5565,7 @@ def _construct_nodes(self, graph: onnx.onnx_ml_pb2.GraphProto): if ( inp is not None and isinstance(inp, relax.Expr) - and isinstance(inp.struct_info, relax.ShapeStructInfo) + and isinstance(inp.ty, relax.ShapeType) and op_name not in shape_compatible_ops ): raise ValueError(f"Node {node.name} cannot handle ShapeExpr inputs.") @@ -5595,11 +5581,11 @@ def _construct_nodes(self, graph: onnx.onnx_ml_pb2.GraphProto): if op_name in return_tuple_ops: outputs_num = 1 elif not isinstance(op, relax.Tuple): - if isinstance(op.struct_info, relax.TupleStructInfo): + if isinstance(op.ty, relax.TupleType): # This is a var bound to a tuple. We need to unpack it and create # a new tuple. tuple_items = [] - for i in range(len(op.struct_info.fields)): + for i in range(len(op.ty.fields)): tuple_items.append(self.bb.emit(relax.TupleGetItem(op, i))) op = relax.Tuple(tuple_items) outputs_num = len(tuple_items) @@ -5735,19 +5721,17 @@ def _convert_subgraph(self, bb, graph): op = self._convert_operator(op_name, inputs, attr, self.opset) try: - _ = op.struct_info - has_struct_info = True + _ = op.ty + has_ty = True except tvm.error.InternalError: - has_struct_info = False + has_ty = False - if not has_struct_info: + if not has_ty: op = bb.normalize(op) if not isinstance(op, relax.Tuple): - if isinstance(op.struct_info, relax.TupleStructInfo): - tuple_items = [ - relax.TupleGetItem(op, i) for i in range(len(op.struct_info.fields)) - ] + if isinstance(op.ty, relax.TupleType): + tuple_items = [relax.TupleGetItem(op, i) for i in range(len(op.ty.fields))] op = relax.Tuple(tuple_items) outputs = node.output diff --git a/python/tvm/relax/frontend/stablehlo/stablehlo_translator.py b/python/tvm/relax/frontend/stablehlo/stablehlo_translator.py index a5196095eee0..51994122d775 100644 --- a/python/tvm/relax/frontend/stablehlo/stablehlo_translator.py +++ b/python/tvm/relax/frontend/stablehlo/stablehlo_translator.py @@ -145,10 +145,10 @@ def _promote_binary_op_args(lhs, rhs): if isinstance(lhs, relax.Expr) and isinstance(rhs, relax.Expr): return lhs, rhs if isinstance(lhs, relax.Expr): - assert isinstance(lhs.struct_info, relax.TensorStructInfo) - return lhs, relax.const(rhs, lhs.struct_info.dtype) - assert isinstance(rhs.struct_info, relax.TensorStructInfo) - return relax.const(lhs, rhs.struct_info.dtype), rhs + assert isinstance(lhs.ty, relax.TensorType) + return lhs, relax.const(rhs, lhs.ty.dtype) + assert isinstance(rhs.ty, relax.TensorType) + return relax.const(lhs, rhs.ty.dtype), rhs def _call_binary_op(self, op, lhs, rhs): lhs, rhs = StableHLOImporter._promote_binary_op_args(lhs, rhs) @@ -381,7 +381,7 @@ def from_stablehlo(self, model, input_info: list[tuple[tuple[int], str]]) -> tvm ipt_shape = self.get_shape(arg_shape) ipt_dtype = self._convert_data_type(arg_shape.element_type) ipt_name = "arg" + str(idx) - ipt_var = relax.Var(f"arg{idx}", relax.TensorStructInfo(ipt_shape, ipt_dtype)) + ipt_var = relax.Var(f"arg{idx}", relax.TensorType(ipt_shape, ipt_dtype)) self._nodes[ipt_name] = ipt_var inputs.append(ipt_var) diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py index 35e75e89bc6a..a087ce5d36fd 100644 --- a/python/tvm/relax/frontend/tflite/tflite_frontend.py +++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py @@ -528,7 +528,7 @@ def unbind(self, data, axis=1): return relax.Tuple(relax.Tuple(ret), selections) def _infer_shape(self, arg): - return self.bb.normalize(arg).struct_info.shape + return self.bb.normalize(arg).ty.shape def convert_op_to_relax(self): """Convert TFLite ops to relax ops""" @@ -1120,13 +1120,11 @@ def _get_shape_expr_from_tensor(self, shape_tensor, prefix): dims_expr = self.get_expr(shape_tensor.tensor_idx) dims_ndim = int(self.get_tensor_shape(shape_tensor)[0]) dims_dtype = self.get_tensor_type_str(shape_tensor.tensor.Type()) - dims_expr = self.bb.match_cast( - dims_expr, relax.TensorStructInfo([dims_ndim], dims_dtype) - ) + dims_expr = self.bb.match_cast(dims_expr, relax.TensorType([dims_ndim], dims_dtype)) dims_expr = self.bb.normalize(relax.op.astype(dims_expr, "int64")) shape_dataflow_var = self.bb.emit(relax.op.tensor_to_shape(dims_expr)) shape_vars = [tirx.Var(f"{prefix}_{i}", "int64") for i in range(dims_ndim)] - self.bb.match_cast(shape_dataflow_var, relax.ShapeStructInfo(shape_vars)) + self.bb.match_cast(shape_dataflow_var, relax.ShapeType(shape_vars)) return relax.ShapeExpr(shape_vars), shape_vars dims = to_int_list(self.get_tensor_value(shape_tensor)) @@ -1134,7 +1132,7 @@ def _get_shape_expr_from_tensor(self, shape_tensor, prefix): def flatten_to_nd(self, x, nd=3): """Flatten input tensor to nd rank""" - shape = x.struct_info.shape + shape = x.ty.shape ndims = len(shape) if ndims == nd: return x @@ -1899,7 +1897,7 @@ def _convert_stablehlo_and(self, op): lhs = self.get_tensor_expr(input_tensors[0]) rhs = self.get_tensor_expr(input_tensors[1]) - dtype = lhs.struct_info.dtype + dtype = lhs.ty.dtype if dtype == "bool": op_fn = _op.logical_and elif dtype.startswith(("int", "uint")): @@ -1917,7 +1915,7 @@ def _convert_stablehlo_or(self, op): lhs = self.get_tensor_expr(input_tensors[0]) rhs = self.get_tensor_expr(input_tensors[1]) - dtype = lhs.struct_info.dtype + dtype = lhs.ty.dtype if dtype == "bool": op_fn = _op.logical_or elif dtype.startswith(("int", "uint")): @@ -1987,7 +1985,7 @@ def _get_stablehlo_i64_vector(self, vector, default): def _ensure_stablehlo_float_dtype(self, expr, op_name): """Return expr dtype if the StableHLO subset supports it.""" - dtype = expr.struct_info.dtype + dtype = expr.ty.dtype if not dtype.startswith("float"): raise tvm.error.OpNotImplemented(f"{op_name} with dtype {dtype} is not supported") return dtype @@ -2473,8 +2471,8 @@ def _convert_stablehlo_rng_bit_generator(self, op): gv, [state_expr], [ - relax.TensorStructInfo(tuple(state_shape), "uint64"), - relax.TensorStructInfo(out_shape, out_dtype), + relax.TensorType(tuple(state_shape), "uint64"), + relax.TensorType(out_shape, out_dtype), ], ) return self.bb.normalize(call) @@ -2617,7 +2615,7 @@ def _get_subgraph_params(self, subgraph): input_name = get_tensor_name(subgraph, int(input_index)) shape = self._get_relax_tensor_shape(tensor) dtype = self._get_relax_tensor_dtype(tensor) - param = relax.Var(input_name, relax.TensorStructInfo(shape=shape, dtype=dtype)) + param = relax.Var(input_name, relax.TensorType(shape=shape, dtype=dtype)) exp_tab.set_expr(input_name, param) params.append(param) return params, exp_tab @@ -2627,7 +2625,7 @@ def _get_tensor_param(self, tensor_wrapper): name = get_tensor_name(self.subgraph, tensor_wrapper.tensor_idx) shape = self._get_relax_tensor_shape(tensor_wrapper) dtype = self._get_relax_tensor_dtype(tensor_wrapper) - return relax.Var(name, relax.TensorStructInfo(shape=shape, dtype=dtype)) + return relax.Var(name, relax.TensorType(shape=shape, dtype=dtype)) def _lower_subgraph_to_function(self, subgraph_index, function_name_hint, op_name="CALL"): """Lower a TFLite subgraph into a private Relax function.""" @@ -3852,7 +3850,7 @@ def _transform_mask(stride_dim, ellipsis_mask): stride = [int(i) for i in stride] axes = list(range(len(begin))) out = relax.op.strided_slice(data_expr, axes=axes, begin=begin, end=end, strides=stride) - out_shape = self.bb.normalize(out).struct_info.shape + out_shape = self.bb.normalize(out).ty.shape if not fshape_indices: fshape_indices = range(len(out_shape)) @@ -3941,7 +3939,7 @@ def convert_random_uniform(self, op): return relax.op.call_dps_packed( "tvm.contrib.random.uniform", (seed, seed2, 0.0, 1.0), - out_sinfo=relax.TensorStructInfo(out_shape, output_dtype), + out_ty=relax.TensorType(out_shape, output_dtype), ) def convert_random_standard_normal(self, op): @@ -3962,7 +3960,7 @@ def convert_random_standard_normal(self, op): return relax.op.call_dps_packed( "tvm.contrib.random.normal", (seed, seed2, 0.0, 1.0), - out_sinfo=relax.TensorStructInfo(out_shape, output_dtype), + out_ty=relax.TensorType(out_shape, output_dtype), ) def convert_multinomial(self, op): @@ -3976,12 +3974,12 @@ def convert_multinomial(self, op): if self.has_expr(num_samples_tensor.tensor_idx): scalar_expr = self.get_expr(num_samples_tensor.tensor_idx) scalar_dtype = self.get_tensor_type_str(num_samples_tensor.tensor.Type()) - scalar_expr = self.bb.match_cast(scalar_expr, relax.TensorStructInfo([], scalar_dtype)) + scalar_expr = self.bb.match_cast(scalar_expr, relax.TensorType([], scalar_dtype)) scalar_expr = self.bb.normalize(relax.op.astype(scalar_expr, "int64")) scalar_expr = self.bb.normalize(relax.op.reshape(scalar_expr, [1])) shape_dataflow_var = self.bb.emit(relax.op.tensor_to_shape(scalar_expr)) num_samples = tirx.Var("multinomial_num_samples", "int64") - self.bb.match_cast(shape_dataflow_var, relax.ShapeStructInfo([num_samples])) + self.bb.match_cast(shape_dataflow_var, relax.ShapeType([num_samples])) else: value = self.get_tensor_value(num_samples_tensor) assert value.size == 1, ( @@ -4001,7 +3999,7 @@ def convert_multinomial(self, op): uniform_sample = relax.op.call_dps_packed( "tvm.contrib.random.uniform", (seed, seed2, 0.0, 1.0), - out_sinfo=relax.TensorStructInfo([output_batch, 1], "float32"), + out_ty=relax.TensorType([output_batch, 1], "float32"), ) sample_indices = relax.op.reshape( relax.op.broadcast_to( @@ -4879,7 +4877,7 @@ def convert_split_v(self, op): # TFLite fixes the tuple arity in the graph, even when the split # sizes themselves are supplied at runtime. num_splits = len(output_tensors) - rank = len(in_expr.struct_info.shape) + rank = len(in_expr.ty.shape) # end_base is the full input shape; only split_axis changes per slice. end_base = relax.op.shape_to_tensor(relax.op.shape_of(in_expr)) @@ -5256,7 +5254,7 @@ def convert_rfft2d(self, op): call = relax.call_tir( gv, [data_expr], - relax.TensorStructInfo(relax_output_shape, "float32"), + relax.TensorType(relax_output_shape, "float32"), ) return self.bb.normalize(call) @@ -6543,7 +6541,7 @@ def convert_batch_to_space_nd(self, op): relax.ShapeExpr(crop_begin), relax.ShapeExpr(crop_end), ), - out_sinfo=relax.TensorStructInfo(output_shape, output_dtype), + out_ty=relax.TensorType(output_shape, output_dtype), ) return out @@ -6711,8 +6709,8 @@ def convert_batch_matmul(self, op): input_a = self.get_expr(input_tensors[0].tensor_idx) input_b = self.get_expr(input_tensors[1].tensor_idx) - shape_a = list(input_a.struct_info.shape) - shape_b = list(input_b.struct_info.shape) + shape_a = list(input_a.ty.shape) + shape_b = list(input_b.ty.shape) rank_a = len(shape_a) rank_b = len(shape_b) @@ -6787,7 +6785,7 @@ def convert_space_to_batch_nd(self, op): relax.ShapeExpr(pad_after), 0.0, ), - out_sinfo=relax.TensorStructInfo(output_shape, output_dtype), + out_ty=relax.TensorType(output_shape, output_dtype), ) return out @@ -6882,7 +6880,7 @@ def convert_sparse_to_dense(self, op): out = relax.op.call_dps_packed( "topi.sparse_to_dense", (indices_expr, output_shape_expr, values_expr, default_value_expr), - out_sinfo=relax.TensorStructInfo(output_shape_val, output_dtype), + out_ty=relax.TensorType(output_shape_val, output_dtype), ) return out @@ -7134,13 +7132,11 @@ def convert_dilate(self, op): # per-axis math. if self.has_expr(dilations_tensor.tensor_idx): dilations_expr = self.get_expr(dilations_tensor.tensor_idx) - dilations_expr = self.bb.match_cast( - dilations_expr, relax.TensorStructInfo([n_dims], "int32") - ) + dilations_expr = self.bb.match_cast(dilations_expr, relax.TensorType([n_dims], "int32")) dilations_int64 = self.bb.normalize(relax.op.astype(dilations_expr, "int64")) shape_var = self.bb.emit(relax.op.tensor_to_shape(dilations_int64)) stride_vars = [tirx.Var(f"dilate_stride_{i}", "int64") for i in range(n_dims)] - self.bb.match_cast(shape_var, relax.ShapeStructInfo(stride_vars)) + self.bb.match_cast(shape_var, relax.ShapeType(stride_vars)) strides = stride_vars else: strides = to_int_list(self.get_tensor_value(dilations_tensor)) @@ -7328,7 +7324,7 @@ def convert_detection_postprocess(self, op): num_detections = self.bb.emit(relax.TupleGetItem(nms_out, 2)) class_id_from_score = relax.op.squeeze(class_id_from_score, axis=[1]) - selected_score_slots = selected_scores.struct_info.shape.values[1] + selected_score_slots = selected_scores.ty.shape.values[1] selected_detection_positions = relax.op.expand_dims( relax.op.arange(selected_score_slots, dtype="int64"), axis=0 ) @@ -7693,7 +7689,7 @@ def convert_matrix_set_diag(self, op): relax.const(False), relax.const(False), ), - out_sinfo=relax.TensorStructInfo(output_shape, output_dtype), + out_ty=relax.TensorType(output_shape, output_dtype), ) return out @@ -7734,7 +7730,7 @@ def convert_matrix_diag(self, op): relax.const(False), relax.const(False), ), - out_sinfo=relax.TensorStructInfo(output_shape, output_dtype), + out_ty=relax.TensorType(output_shape, output_dtype), ) return out @@ -8776,7 +8772,7 @@ def func(self, data): shape = tuple(shape) + (2,) input_var = relax.Var( name_hint=model_input_name, - struct_info=relax.TensorStructInfo(shape=shape, dtype=dtype), + ty=relax.TensorType(shape=shape, dtype=dtype), ) exp_tab.set_expr(model_input_name, input_var) input_list.append(input_var) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 4c3cdd464ff1..f987f48d4251 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -115,9 +115,9 @@ def shape_of(tensor): import torch # type: ignore if isinstance(tensor, relax.Expr): - if not isinstance(tensor.struct_info, relax.TensorStructInfo): + if not isinstance(tensor.ty, relax.TensorType): raise TypeError("The input Expr of shape_of should be a Tensor") - return tensor.struct_info.shape + return tensor.ty.shape elif isinstance(tensor, torch.Tensor): return tensor.shape raise ValueError(f"Unsupported type: {type(tensor)}") @@ -210,7 +210,7 @@ def convert(node: fx.Node) -> relax.Var: def _celu(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] alpha = node.args[1] if len(node.args) > 1 else node.kwargs.get("alpha", 1.0) - dtype = x.struct_info.dtype + dtype = x.ty.dtype if isinstance(alpha, int | float): alpha = relax.const(alpha, dtype) @@ -329,7 +329,7 @@ def _clamp_max(self, node: fx.Node) -> relax.Expr: def _elu(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] alpha = node.args[1] if len(node.args) > 1 else node.kwargs.get("alpha", 1.0) - dtype = x.struct_info.dtype + dtype = x.ty.dtype if isinstance(alpha, int | float): alpha = relax.const(-alpha, dtype) @@ -360,7 +360,7 @@ def _gelu(self, node: fx.Node) -> relax.Expr: def _hardsigmoid(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) x = args[0] - dtype = x.struct_info.dtype + dtype = x.ty.dtype x0 = relax.op.add(x, relax.const(3, dtype)) x1 = relax.op.clip(x0, 0, 6) return self.block_builder.emit(relax.op.divide(x1, relax.const(6, dtype))) @@ -368,7 +368,7 @@ def _hardsigmoid(self, node: fx.Node) -> relax.Var: def _hardswish(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) x = args[0] - dtype = x.struct_info.dtype + dtype = x.ty.dtype x0 = relax.op.add(x, relax.const(3, dtype)) x1 = relax.op.clip(x0, 0, 6) x2 = relax.op.divide(x1, relax.const(6, dtype)) @@ -396,9 +396,9 @@ def _logical_and(self, node: fx.Node) -> relax.Var: rhs = self.env[node.args[1]] # torch.logical_and accepts any dtype (treating nonzero as True) and returns bool, but # relax.op.logical_and requires boolean inputs, so cast non-bool inputs to bool first. - if lhs.struct_info.dtype != "bool": + if lhs.ty.dtype != "bool": lhs = self.block_builder.emit(relax.op.astype(lhs, "bool")) - if rhs.struct_info.dtype != "bool": + if rhs.ty.dtype != "bool": rhs = self.block_builder.emit(relax.op.astype(rhs, "bool")) return self.block_builder.emit(relax.op.logical_and(lhs, rhs)) @@ -406,7 +406,7 @@ def _logical_not(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] # torch.logical_not accepts any dtype (treating nonzero as True) and returns bool, but # relax.op.logical_not requires a boolean input, so cast non-bool inputs to bool first. - if x.struct_info.dtype != "bool": + if x.ty.dtype != "bool": x = self.block_builder.emit(relax.op.astype(x, "bool")) return self.block_builder.emit(relax.op.logical_not(x)) @@ -415,9 +415,9 @@ def _logical_or(self, node: fx.Node) -> relax.Var: rhs = self.env[node.args[1]] # torch.logical_or accepts any dtype (treating nonzero as True) and returns bool, but # relax.op.logical_or requires boolean inputs, so cast non-bool inputs to bool first. - if lhs.struct_info.dtype != "bool": + if lhs.ty.dtype != "bool": lhs = self.block_builder.emit(relax.op.astype(lhs, "bool")) - if rhs.struct_info.dtype != "bool": + if rhs.ty.dtype != "bool": rhs = self.block_builder.emit(relax.op.astype(rhs, "bool")) return self.block_builder.emit(relax.op.logical_or(lhs, rhs)) @@ -426,16 +426,16 @@ def _logical_xor(self, node: fx.Node) -> relax.Var: rhs = self.env[node.args[1]] # torch.logical_xor accepts any dtype (treating nonzero as True) and returns bool, but # relax.op.logical_xor requires boolean inputs, so cast non-bool inputs to bool first. - if lhs.struct_info.dtype != "bool": + if lhs.ty.dtype != "bool": lhs = self.block_builder.emit(relax.op.astype(lhs, "bool")) - if rhs.struct_info.dtype != "bool": + if rhs.ty.dtype != "bool": rhs = self.block_builder.emit(relax.op.astype(rhs, "bool")) return self.block_builder.emit(relax.op.logical_xor(lhs, rhs)) def _prelu(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] alpha = self.env[node.args[1]] - axis = 0 if len(x.struct_info.shape) == 1 else 1 + axis = 0 if len(x.ty.shape) == 1 else 1 return self.block_builder.emit(relax.op.nn.prelu(x, alpha, axis)) def _round(self, node: fx.Node) -> relax.Expr: @@ -446,7 +446,7 @@ def _round(self, node: fx.Node) -> relax.Expr: return self.block_builder.emit(relax.op.round(arg)) # For decimals != 0, use: round(x * 10^decimals) / 10^decimals - dtype = arg.struct_info.dtype + dtype = arg.ty.dtype scale = relax.const(10**decimals, dtype) scaled = relax.op.multiply(arg, scale) rounded = relax.op.round(scaled) @@ -487,17 +487,17 @@ def _softshrink(self, node: fx.Node) -> relax.Var: """ args = self.retrieve_args(node) x = args[0] - lambd = relax.const(args[1] if len(args) > 1 else 0.5, x.struct_info.dtype) + lambd = relax.const(args[1] if len(args) > 1 else 0.5, x.ty.dtype) # Apply Softshrink transformation with masking shrink_pos = relax.op.multiply( relax.op.subtract(x, lambd), - relax.op.astype(relax.op.greater(x, lambd), x.struct_info.dtype), + relax.op.astype(relax.op.greater(x, lambd), x.ty.dtype), ) shrink_neg = relax.op.multiply( relax.op.add(x, lambd), - relax.op.astype(relax.op.less(x, relax.op.negative(lambd)), x.struct_info.dtype), + relax.op.astype(relax.op.less(x, relax.op.negative(lambd)), x.ty.dtype), ) # Combine the positive and negative shrink results @@ -522,10 +522,10 @@ def _binary_op(self, relax_op: Callable, intrinsic_op: Callable) -> Callable: def convert(node: fx.Node) -> relax.Var: def promote_binary_op_args(lhs, rhs): if isinstance(lhs, relax.Expr) and isinstance(rhs, relax.Expr): - lhs_si = getattr(lhs, "struct_info", None) - rhs_si = getattr(rhs, "struct_info", None) - if isinstance(lhs_si, relax.TensorStructInfo) and isinstance( - rhs_si, relax.TensorStructInfo + lhs_si = getattr(lhs, "ty", None) + rhs_si = getattr(rhs, "ty", None) + if isinstance(lhs_si, relax.TensorType) and isinstance( + rhs_si, relax.TensorType ): target_dtype = self._promote_common_dtype(lhs_si.dtype, rhs_si.dtype) if target_dtype is not None: @@ -535,11 +535,11 @@ def promote_binary_op_args(lhs, rhs): rhs = self.block_builder.emit(relax.op.astype(rhs, target_dtype)) return lhs, rhs elif isinstance(lhs, relax.Expr): - assert isinstance(lhs.struct_info, relax.TensorStructInfo) - return lhs, relax.const(rhs, lhs.struct_info.dtype) + assert isinstance(lhs.ty, relax.TensorType) + return lhs, relax.const(rhs, lhs.ty.dtype) elif isinstance(rhs, relax.Expr): - assert isinstance(rhs.struct_info, relax.TensorStructInfo) - return relax.const(lhs, rhs.struct_info.dtype), rhs + assert isinstance(rhs.ty, relax.TensorType) + return relax.const(lhs, rhs.ty.dtype), rhs else: assert False @@ -551,9 +551,9 @@ def call_binary_op(op, lhs, rhs): if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var): return call_binary_op(relax_op, lhs, rhs) elif isinstance(lhs, relax.expr.Constant) and not isinstance(rhs, relax.expr.Constant): - return call_binary_op(relax_op, lhs, relax.const(rhs, dtype=lhs.struct_info.dtype)) + return call_binary_op(relax_op, lhs, relax.const(rhs, dtype=lhs.ty.dtype)) elif isinstance(rhs, relax.expr.Constant) and not isinstance(lhs, relax.expr.Constant): - return call_binary_op(relax_op, relax.const(lhs, dtype=rhs.struct_info.dtype), rhs) + return call_binary_op(relax_op, relax.const(lhs, dtype=rhs.ty.dtype), rhs) return intrinsic_op(lhs, rhs) return convert @@ -565,8 +565,8 @@ def _pow(self, node: fx.Node) -> relax.Var: # a constant non-negative integer exponent into repeated multiplication instead. if ( isinstance(lhs, relax.Expr) - and isinstance(lhs.struct_info, relax.TensorStructInfo) - and "int" in lhs.struct_info.dtype + and isinstance(lhs.ty, relax.TensorType) + and "int" in lhs.ty.dtype and isinstance(rhs, int) and not isinstance(rhs, bool) and rhs >= 0 @@ -612,9 +612,9 @@ def _fmod(self, node: fx.Node): if isinstance(lhs, relax.Expr) and isinstance(rhs, relax.Expr): return self.block_builder.emit(relax.op.mod(lhs, rhs)) elif isinstance(lhs, relax.Expr): - rhs = relax.const(rhs, lhs.struct_info.dtype) + rhs = relax.const(rhs, lhs.ty.dtype) elif isinstance(rhs, relax.Expr): - lhs = relax.const(lhs, rhs.struct_info.dtype) + lhs = relax.const(lhs, rhs.ty.dtype) else: assert False return self.block_builder.emit(relax.op.mod(lhs, rhs)) @@ -639,7 +639,7 @@ def _isin(self, node: fx.Node) -> relax.Var: comparison = relax.op.equal(expanded_elements, flattened_test_elements) summed = relax.op.sum(comparison, axis=-1) - result = relax.op.greater(summed, relax.const(0, dtype=elements.struct_info.dtype)) + result = relax.op.greater(summed, relax.const(0, dtype=elements.ty.dtype)) return self.block_builder.emit(result) @@ -656,7 +656,7 @@ def _linalg_vector_norm(self, node: fx.Node) -> relax.Var: # If ord_val is a Python float/int, wrap it in a Relax const # so that it matches data's dtype. - dtype = data.struct_info.dtype + dtype = data.ty.dtype ord_expr = ( ord_val if isinstance(ord_val, relax.Expr) else relax.const(float(ord_val), dtype) ) @@ -684,7 +684,7 @@ def _adaptive_avg_pool1d(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] output_size = node.args[1] if len(node.args) > 1 else node.kwargs["output_size"] # Expand to 3D by adding batch dim if input is 2D - x_ndim = x.struct_info.ndim + x_ndim = x.ty.ndim if x_ndim == 2: x = relax.op.expand_dims(x, axis=0) @@ -700,7 +700,7 @@ def _adaptive_avg_pool2d(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] output_size = node.args[1] # Expand to 4D by adding batch dim if input is 3D - x_ndim = x.struct_info.ndim + x_ndim = x.ty.ndim if x_ndim == 3: x = relax.op.expand_dims(x, axis=0) @@ -716,7 +716,7 @@ def _adaptive_avg_pool3d(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] output_size = node.args[1] # Expand to 5D by adding batch dim if input is 4D - x_ndim = x.struct_info.ndim + x_ndim = x.ty.ndim if x_ndim == 4: x = relax.op.expand_dims(x, axis=0) @@ -739,10 +739,10 @@ def _addmm(self, node: fx.Node) -> relax.Var: if alpha != 0: res = self.block_builder.emit(relax.op.linear_algebra.matmul(y, z, out_dtype="float32")) if alpha != 1: - dtype = res.struct_info.dtype + dtype = res.ty.dtype res = self.block_builder.emit(relax.op.multiply(res, relax.const(alpha, dtype))) if beta != 0: - dtype = x.struct_info.dtype + dtype = x.ty.dtype if beta != 1: bias = self.block_builder.emit(relax.op.multiply(x, relax.const(beta, dtype))) else: @@ -760,7 +760,7 @@ def _avg_pool1d_impl( count_include_pad: bool | None = True, ) -> relax.Var: # Expand to 3D by adding batch dim if input is 2D - x_ndim = x.struct_info.ndim + x_ndim = x.ty.ndim if x_ndim == 2: x = relax.op.expand_dims(x, axis=0) stride = kernel_size if stride is None or stride == [] else stride @@ -802,7 +802,7 @@ def _avg_pool2d_impl( count_include_pad: bool | None = True, ) -> relax.Var: # Expand to 4D by adding batch dim if input is 3D - x_ndim = x.struct_info.ndim + x_ndim = x.ty.ndim if x_ndim == 3: x = relax.op.expand_dims(x, axis=0) stride = kernel_size if stride is None or stride == [] else stride @@ -843,7 +843,7 @@ def _avg_pool3d_impl( count_include_pad: bool | None = True, ) -> relax.Var: # Expand to 5D by adding batch dim if input is 4D - x_ndim = x.struct_info.ndim + x_ndim = x.ty.ndim if x_ndim == 4: x = relax.op.expand_dims(x, axis=0) stride = kernel_size if stride is None or stride == [] else stride @@ -886,10 +886,10 @@ def _baddbmm(self, node: fx.Node) -> relax.Var: if alpha != 0: res = self.block_builder.emit(relax.op.matmul(batch1, batch2)) if alpha != 1: - dtype = res.struct_info.dtype + dtype = res.ty.dtype res = self.block_builder.emit(relax.op.multiply(res, relax.const(alpha, dtype))) if beta != 0: - dtype = x.struct_info.dtype + dtype = x.ty.dtype if beta != 1: bias = self.block_builder.emit(relax.op.multiply(x, relax.const(beta, dtype))) else: @@ -1259,12 +1259,12 @@ def _embedding_impl( ) -> relax.Var: x = self.block_builder.emit(relax.op.astype(x, "int32")) - ndim = x.struct_info.ndim + ndim = x.ty.ndim if ndim == 1: return self.block_builder.emit(relax.op.take(weight, x, axis=0)) else: - x_shape = x.struct_info.shape.values - emb_size = weight.struct_info.shape.values[-1] + x_shape = x.ty.shape.values + emb_size = weight.ty.shape.values[-1] x = self.block_builder.emit(relax.op.reshape(x, shape=[-1])) embedding = self.block_builder.emit(relax.op.take(weight, x, axis=0)) return self.block_builder.emit(relax.op.reshape(embedding, [*x_shape, emb_size])) @@ -1286,10 +1286,10 @@ def _layer_norm_impl(self, x, gamma, beta, eps, normalized_shape) -> relax.Var: if gamma is None: shape_tuple = [int(s) for s in normalized_shape] - gamma = relax.const(np.ones(shape_tuple), x.struct_info.dtype) + gamma = relax.const(np.ones(shape_tuple), x.ty.dtype) if beta is None: shape_tuple = [int(s) for s in normalized_shape] - beta = relax.const(np.zeros(shape_tuple), x.struct_info.dtype) + beta = relax.const(np.zeros(shape_tuple), x.ty.dtype) return self.block_builder.emit( relax.op.nn.layer_norm( @@ -1319,8 +1319,8 @@ def _layer_norm_module(self, node: fx.Node) -> relax.Var: gamma = self.params[module.weight] beta = self.params[module.bias] else: - gamma = relax.const(torch.ones_like(module.normalized_shape), x.struct_info.dtype) - beta = relax.const(torch.zeros_like(module.normalized_shape), x.struct_info.dtype) + gamma = relax.const(torch.ones_like(module.normalized_shape), x.ty.dtype) + beta = relax.const(torch.zeros_like(module.normalized_shape), x.ty.dtype) eps = module.eps return self._layer_norm_impl(x, gamma, beta, eps, normalized_shape) @@ -1341,7 +1341,7 @@ def _max_pool1d_impl( ceil_mode: bool | None = False, ) -> relax.Var: # Expand to 3D by adding batch dim if input is 2D - x_ndim = x.struct_info.ndim + x_ndim = x.ty.ndim if x_ndim == 2: x = relax.op.expand_dims(x, axis=0) @@ -1385,7 +1385,7 @@ def _max_pool2d_impl( ceil_mode: bool | None = False, ) -> relax.Var: # Expand to 4D by adding batch dim if input is 3D - x_ndim = x.struct_info.ndim + x_ndim = x.ty.ndim if x_ndim == 3: x = relax.op.expand_dims(x, axis=0) @@ -1429,7 +1429,7 @@ def _max_pool3d_impl( ceil_mode: bool | None = False, ) -> relax.Var: # Expand to 5D by adding batch dim if input is 4D - x_ndim = x.struct_info.ndim + x_ndim = x.ty.ndim if x_ndim == 4: x = relax.op.expand_dims(x, axis=0) @@ -1519,7 +1519,7 @@ def _pad(self, node: fx.Node) -> relax.Var: # Calculate symmetric padding width for each dimension # and applying them in reverse order to match the input dimensions. - input_ndim = x.struct_info.ndim + input_ndim = x.ty.ndim pad_width = [0] * (input_ndim * 2) pad_pairs = [pad[i : i + 2] for i in range(0, len(pad), 2)] reversed_pairs = list(reversed(pad_pairs)) @@ -1536,7 +1536,7 @@ def _constant_pad_nd(self, node: fx.Node) -> relax.Var: # Calculate symmetric padding width for each dimension # and applying them in reverse order to match the input dimensions. - input_ndim = x.struct_info.ndim + input_ndim = x.ty.ndim pad_width = [0] * (input_ndim * 2) pad_pairs = [pad[i : i + 2] for i in range(0, len(pad), 2)] reversed_pairs = list(reversed(pad_pairs)) @@ -1560,7 +1560,7 @@ def _scaled_dot_product_attention(self, node: fx.Node) -> relax.Var: value_tensor = self.env[node.args[2]] # Check the dimensionality of the input tensors - query_ndim = len(query_tensor.struct_info.shape) + query_ndim = len(query_tensor.ty.shape) # TVM's nn.attention requires 4D inputs in format (batch, num_heads, seq_len, head_dim) # For 2D inputs (seq_len, head_dim), we need to reshape to 4D first @@ -1607,7 +1607,7 @@ def transpose_and_reshape_back(tensor): if attn_mask is not None: attn_mask = self.env[attn_mask] msg = "Only a float mask is supported for the attn_mask input." - assert "float" in attn_mask.struct_info.dtype, msg + assert "float" in attn_mask.ty.dtype, msg attention_output = self.block_builder.emit( relax.op.nn.attention(query, key, value, bias=attn_mask, causal_mask=causal_mask) @@ -1647,7 +1647,7 @@ def _median(self, node: fx.Node) -> relax.Var: def _norm(self, node: fx.Node) -> relax.Var: data = self.env[node.args[0]] - dtype = data.struct_info.dtype + dtype = data.ty.dtype order = node.args[1] if len(node.args) > 1 else node.kwargs.get("p", 2) axis = node.args[2] if len(node.args) > 2 else None keepdims = node.args[3] if len(node.args) > 3 else False @@ -1714,7 +1714,7 @@ def _sum(self, node: fx.Node) -> relax.Var: else: # Match PyTorch type promotion: summing bool or integer tensors # accumulates in int64 unless an explicit dtype is given. - input_dtype = x.struct_info.dtype + input_dtype = x.ty.dtype if input_dtype == "bool" or ( (input_dtype.startswith("int") or input_dtype.startswith("uint")) and input_dtype != "int64" @@ -1761,14 +1761,12 @@ def _var_correction(self, node: fx.Node) -> relax.Var: scale = float("nan") else: scale = float(n) / float(n - correction) - return self.block_builder.emit( - relax.op.multiply(var, relax.const(scale, x.struct_info.dtype)) - ) + return self.block_builder.emit(relax.op.multiply(var, relax.const(scale, x.ty.dtype))) @staticmethod def _reduction_size(x: relax.Expr, dim) -> int | None: """Static product of reduced-axis sizes; None if any axis is dynamic.""" - shape = x.struct_info.shape + shape = x.ty.shape if shape is None: return None rank = len(shape) @@ -1796,7 +1794,7 @@ def _any(self, node: fx.Node) -> relax.Var: keepdim = args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False) # max doesn't support boolean tensors directly, so we compute it in int8 and cast back - if x.struct_info.dtype == "bool": + if x.ty.dtype == "bool": x = relax.op.astype(x, "int8") ret = relax.op.max(x, dim, keepdims=keepdim) return self.block_builder.emit(relax.op.astype(ret, "bool")) @@ -1977,7 +1975,7 @@ def _index_put(self, node: fx.Node) -> relax.Var: data_shape = self.shape_of(tensor) processed_indices = [] - max_ndim = max((idx.struct_info.ndim for _, idx in non_none_indices), default=1) + max_ndim = max((idx.ty.ndim for _, idx in non_none_indices), default=1) for i, idx in enumerate(indices): if idx is None: @@ -2045,10 +2043,10 @@ def _index_tensor(self, node: fx.Node) -> relax.Var: # Check if all indices can be squeezed to 1D for sequential take def is_squeezable(idx): - if idx.struct_info.ndim == 1: + if idx.ty.ndim == 1: return True - if idx.struct_info.ndim == 2: - shape = idx.struct_info.shape + if idx.ty.ndim == 2: + shape = idx.ty.shape for d in shape: if isinstance(d, int) and d == 1: return True @@ -2061,13 +2059,13 @@ def is_squeezable(idx): if all_squeezable: result = data for axis, idx in reversed(non_none_indices): - if idx.struct_info.ndim > 1: + if idx.ty.ndim > 1: idx = self.block_builder.emit(relax.op.squeeze(idx)) result = self.block_builder.emit(relax.op.take(result, idx, axis=axis)) return result # General case: replace None with arange, reshaped for broadcasting - max_ndim = max((idx.struct_info.ndim for _, idx in non_none_indices), default=1) + max_ndim = max((idx.ty.ndim for _, idx in non_none_indices), default=1) processed_indices = [] data_shape = self.shape_of(data) @@ -2096,9 +2094,9 @@ def _meshgrid(self, node: fx.Node) -> relax.Var: return input_list new_inputs = [] for i, item in enumerate(input_list): - if item.struct_info.ndim == 1: + if item.ty.ndim == 1: new_inputs.append(item) - elif item.struct_info.ndim == 0: # Change scalar value into 1D + elif item.ty.ndim == 0: # Change scalar value into 1D const_tensor = relax.op.reshape(item, (1,)) new_inputs.append(const_tensor) else: @@ -2385,8 +2383,8 @@ def _copy_(self, node: fx.Node) -> relax.Var: src = self.env[node.args[1]] # Match PyTorch semantics: cast to destination dtype and broadcast to destination shape. - if src.struct_info.dtype != dest.struct_info.dtype: - src = self.block_builder.emit(relax.op.astype(src, dest.struct_info.dtype)) + if src.ty.dtype != dest.ty.dtype: + src = self.block_builder.emit(relax.op.astype(src, dest.ty.dtype)) dest_shape = self.shape_of(dest) src_shape = self.shape_of(src) @@ -2476,16 +2474,16 @@ def _eye(self, node: fx.Node) -> relax.Var: def _fill(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) x = args[0] - dtype = x.struct_info.dtype + dtype = x.ty.dtype value = args[1] if isinstance(args[1], relax.Expr) else relax.const(args[1], dtype) - return self.block_builder.emit(relax.op.full(x.struct_info.shape, value, dtype)) + return self.block_builder.emit(relax.op.full(x.ty.shape, value, dtype)) def _inplace_fill(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) x = args[0] - dtype = x.struct_info.dtype + dtype = x.ty.dtype value = args[1] if isinstance(args[1], relax.Expr) else relax.const(args[1], dtype) - filled = self.block_builder.emit(relax.op.full(x.struct_info.shape, value, dtype)) + filled = self.block_builder.emit(relax.op.full(x.ty.shape, value, dtype)) self.env[node.args[0]] = filled return filled @@ -2511,7 +2509,7 @@ def _full_like(self, node: fx.Node) -> relax.Var: value = node.args[1] fill_value = relax.const(value) - x_dtype = x.struct_info.dtype + x_dtype = x.ty.dtype fill_dtype = None if isinstance(value, int | float) and (math.isinf(value) or math.isnan(value)): if not ("float" in x_dtype or "bfloat16" in x_dtype): @@ -2531,7 +2529,7 @@ def _inplace_masked_fill(self, node: fx.Node) -> relax.Var: value = node.args[2] rx_value = relax.const(value) - x_dtype = x.struct_info.dtype + x_dtype = x.ty.dtype fill_dtype = None if isinstance(value, int | float) and (math.isinf(value) or math.isnan(value)): if not ("float" in x_dtype or "bfloat16" in x_dtype): @@ -2576,7 +2574,7 @@ def _masked_fill(self, node: fx.Node) -> relax.Var: value = node.args[2] rx_value = relax.const(value) - x_dtype = x.struct_info.dtype + x_dtype = x.ty.dtype fill_dtype = None if isinstance(value, int | float) and (math.isinf(value) or math.isnan(value)): if not ("float" in x_dtype or "bfloat16" in x_dtype): @@ -2620,8 +2618,8 @@ def _new_ones(self, node: fx.Node) -> relax.Var: return self.block_builder.emit( relax.op.full( size, - relax.const(1, self_var.struct_info.dtype), - self_var.struct_info.dtype, + relax.const(1, self_var.ty.dtype), + self_var.ty.dtype, ) ) @@ -2639,8 +2637,8 @@ def _new_zeros(self, node: fx.Node) -> relax.Var: return self.block_builder.emit( relax.op.full( size, - relax.const(0, input_tensor.struct_info.dtype), - input_tensor.struct_info.dtype, + relax.const(0, input_tensor.ty.dtype), + input_tensor.ty.dtype, ) ) @@ -2678,7 +2676,7 @@ def _to(self, node: fx.Node) -> relax.Var: def _type_as(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] other = self.env[node.args[1]] - dtype = other.struct_info.dtype + dtype = other.ty.dtype return self.block_builder.emit(relax.op.astype(x, dtype)) ########## Others ########## @@ -2690,10 +2688,10 @@ def _getitem(self, node: fx.Node) -> relax.Var: if isinstance(x, list | tuple | relax.ShapeExpr | relax.Tuple): return x[node.args[1]] elif isinstance(x, relax.Var): - if isinstance(x.struct_info, relax.TupleStructInfo): + if isinstance(x.ty, relax.TupleType): return self.block_builder.emit(relax.TupleGetItem(x, node.args[1])) - assert isinstance(x.struct_info, relax.TensorStructInfo) + assert isinstance(x.ty, relax.TensorType) if isinstance(node.args[1], int): return x if not isinstance(node.args[1], list | tuple): @@ -2765,7 +2763,7 @@ def _getitem(self, node: fx.Node) -> relax.Var: sliced_shape.insert(i, 1) return self.block_builder.emit(relax.op.reshape(sliced, sliced_shape)) elif isinstance(x, relax.Constant): - dtype = x.struct_info.dtype + dtype = x.ty.dtype return relax.const(x.data.numpy()[node.args[1]], dtype) else: assert False diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 6c9e3e3f5ef5..2e44d800d596 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -84,27 +84,27 @@ def _hardtanh(self, node: fx.Node) -> relax.Expr: def _log2(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] return self.block_builder.emit( - relax.op.divide(relax.op.log(x), relax.const(0.6931471805599453, x.struct_info.dtype)) + relax.op.divide(relax.op.log(x), relax.const(0.6931471805599453, x.ty.dtype)) ) def _log10(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] return self.block_builder.emit( - relax.op.divide(relax.op.log(x), relax.const(2.302585092994046, x.struct_info.dtype)) + relax.op.divide(relax.op.log(x), relax.const(2.302585092994046, x.ty.dtype)) ) def _log1p(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] - one = relax.const(1, x.struct_info.dtype) + one = relax.const(1, x.ty.dtype) return self.block_builder.emit(relax.op.log(relax.op.add(x, one))) def _reciprocal(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] - return self.block_builder.emit(relax.op.divide(relax.const(1.0, x.struct_info.dtype), x)) + return self.block_builder.emit(relax.op.divide(relax.const(1.0, x.ty.dtype), x)) def _sqrt(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] - dtype = x.struct_info.dtype + dtype = x.ty.dtype # Check if input is integer type and convert to float32 if needed if dtype in ("int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64"): @@ -114,7 +114,7 @@ def _sqrt(self, node: fx.Node) -> relax.Var: def _rsqrt(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] - dtype = x.struct_info.dtype + dtype = x.ty.dtype # Check if input is integer type and convert to float32 if needed if dtype in ("int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64"): @@ -134,7 +134,7 @@ def _batch_norm(self, node: fx.Node, training: bool, return_tuple: bool = False) x = self.env[node.args[0]] channel = int(self.shape_of(x)[1]) - dtype = x.struct_info.dtype + dtype = x.ty.dtype scale = node.args[1] is not None center = node.args[2] is not None weight = self.env.get(node.args[1], relax.const(np.ones(channel), dtype=dtype)) @@ -192,7 +192,7 @@ def _batch_norm_legit_functional(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] channel = int(self.shape_of(x)[1]) - dtype = x.struct_info.dtype + dtype = x.ty.dtype output = self.block_builder.emit(bn_tuple[0]) new_running_mean = self.block_builder.emit(bn_tuple[1]) @@ -210,7 +210,7 @@ def _batch_norm_legit_no_stats(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] channel = int(self.shape_of(x)[1]) - dtype = x.struct_info.dtype + dtype = x.ty.dtype weight = self.env.get(node.args[1], relax.const(np.ones(channel), dtype=dtype)) bias = self.env.get(node.args[2], relax.const(np.zeros(channel), dtype=dtype)) eps = node.args[5] if len(node.args) > 5 else node.kwargs.get("eps", 1e-05) @@ -508,7 +508,7 @@ def _lstm(self, node: fx.Node) -> relax.Var: # o_t = sigmoid(W_io * x_t + b_io + W_ho * h_{t-1} + b_ho) # c_t = f_t * c_{t-1} + i_t * g_t # h_t = o_t * tanh(c_t) - dtype = input_tensor.struct_info.dtype + dtype = input_tensor.ty.dtype params_per_direction = 4 if has_biases else 2 # Extract or create forward direction weights @@ -807,7 +807,7 @@ def _gru(self, node: fx.Node) -> relax.Var: # Fallback to a default hidden size hidden_size = 16 - dtype = input_tensor.struct_info.dtype + dtype = input_tensor.ty.dtype # Extract forward direction weights if params and len(params) >= params_per_direction: @@ -1045,9 +1045,9 @@ def _randn_like(self, node: fx.Node) -> relax.Var: import numpy as np x = self.env[node.args[0]] - x_sinfo = x.struct_info - shape = [int(s) for s in x_sinfo.shape] - dtype = self._convert_data_type(node.kwargs.get("dtype", None) or x_sinfo.dtype, self.env) + x_ty = x.ty + shape = [int(s) for s in x_ty.shape] + dtype = self._convert_data_type(node.kwargs.get("dtype", None) or x_ty.dtype, self.env) data = np.random.randn(*shape).astype(dtype) return self.block_builder.emit(relax.const(data, dtype)) @@ -1088,13 +1088,13 @@ def _sparse_addmm(self, node: fx.Node) -> relax.Var: ) if alpha != 1.0: - alpha_const = relax.const(alpha, matmul_result.struct_info.dtype) + alpha_const = relax.const(alpha, matmul_result.ty.dtype) matmul_result = self.block_builder.emit(relax.op.multiply(matmul_result, alpha_const)) # Compute beta * input + alpha * matmul_result if beta != 0.0: if beta != 1.0: - beta_const = relax.const(beta, input_tensor.struct_info.dtype) + beta_const = relax.const(beta, input_tensor.ty.dtype) input_scaled = self.block_builder.emit(relax.op.multiply(input_tensor, beta_const)) else: input_scaled = input_tensor @@ -1168,7 +1168,7 @@ def _torchvision_roi_align(self, node: fx.Node) -> relax.Var: relax.op.strided_slice(rois, axes=[1], begin=[1], end=[5]) ) boxes = self.block_builder.emit( - relax.op.subtract(boxes, relax.const(0.5, rois.struct_info.dtype)) + relax.op.subtract(boxes, relax.const(0.5, rois.ty.dtype)) ) rois = self.block_builder.emit(relax.op.concat([batch_indices, boxes], axis=1)) @@ -1198,7 +1198,7 @@ def _instance_norm(self, node: fx.Node): x = self.env[node.args[0]] channel = int(self.shape_of(x)[1]) - dtype = x.struct_info.dtype + dtype = x.ty.dtype gamma = self.env.get(node.args[1], relax.const(np.ones(channel), dtype=dtype)) beta = self.env.get(node.args[2], relax.const(np.zeros(channel), dtype=dtype)) eps = node.args[4] if node.args[4] else 1e-05 @@ -1247,7 +1247,7 @@ def _scatter_value(self, node: fx.Node) -> relax.Var: index = self.env[node.args[2]] value = node.args[3] - value_const = relax.const(value, x.struct_info.dtype) + value_const = relax.const(value, x.ty.dtype) src = self.block_builder.emit(relax.op.broadcast_to(value_const, self.shape_of(index))) return self.block_builder.emit(relax.op.scatter_elements(x, index, src, axis=dim)) @@ -1386,7 +1386,7 @@ def _import_branch_subgraph( graph_module : torch.fx.GraphModule The branch subgraph (e.g. true_graph_0 / false_graph_0). operands : list[relax.Expr] - The operands passed to the cond; used to derive parameter struct_info. + The operands passed to the cond; used to derive parameter ty. name_hint : str A hint for the function name (e.g. "cond_true_branch_0"). @@ -1417,10 +1417,8 @@ def _import_branch_subgraph( placeholders = [n for n in nodes if n.op == "placeholder"] params = [] for ph, operand in zip(placeholders, operands): - if hasattr(operand, "struct_info") and isinstance( - operand.struct_info, relax.TensorStructInfo - ): - orig_si = operand.struct_info + if hasattr(operand, "ty") and isinstance(operand.ty, relax.TensorType): + orig_si = operand.ty # Create fresh SizeVars to avoid sharing with the caller function. if orig_si.shape is not None: new_shape = [ @@ -1429,13 +1427,13 @@ def _import_branch_subgraph( else s for s in orig_si.shape ] - si = relax.TensorStructInfo(new_shape, orig_si.dtype) + si = relax.TensorType(new_shape, orig_si.dtype) else: si = orig_si - elif hasattr(operand, "struct_info"): - si = operand.struct_info + elif hasattr(operand, "ty"): + si = operand.ty else: - si = relax.ObjectStructInfo() + si = relax.ObjectType() param = relax.Var(ph.name, si) params.append(param) self.env[ph] = param @@ -1536,7 +1534,7 @@ def create_convert_map( "expm1.default": lambda node: self.block_builder.emit( relax.op.subtract( relax.op.exp(self.env[node.args[0]]), - relax.const(1.0, self.env[node.args[0]].struct_info.dtype), + relax.const(1.0, self.env[node.args[0]].ty.dtype), ) ), "floor.default": self._unary_op(relax.op.floor), @@ -1959,7 +1957,7 @@ def create_input_vars( relax_shape.append(s) dtype = self._convert_data_type(torch_dtype) - relax_var = relax.Var(name_hint, relax.TensorStructInfo(relax_shape, dtype)) + relax_var = relax.Var(name_hint, relax.TensorType(relax_shape, dtype)) if spec.kind is torch.export.graph_signature.InputKind.USER_INPUT: user_inputs[name_hint] = relax_var else: diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 66d17a58283b..31018dbf8fc3 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -66,7 +66,7 @@ def _fetch_attr(self, model, target: str): def _reciprocal(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] - return self.block_builder.emit(relax.op.divide(relax.const(1.0, x.struct_info.dtype), x)) + return self.block_builder.emit(relax.op.divide(relax.const(1.0, x.ty.dtype), x)) def _leakyrelu_module(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] @@ -84,23 +84,23 @@ def _softplus_module(self, node: fx.Node) -> relax.Var: def _log2(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] return self.block_builder.emit( - relax.op.divide(relax.op.log(x), relax.const(0.6931471805599453, x.struct_info.dtype)) + relax.op.divide(relax.op.log(x), relax.const(0.6931471805599453, x.ty.dtype)) ) def _log10(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] return self.block_builder.emit( - relax.op.divide(relax.op.log(x), relax.const(2.302585092994046, x.struct_info.dtype)) + relax.op.divide(relax.op.log(x), relax.const(2.302585092994046, x.ty.dtype)) ) def _log1p(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] - one = relax.const(1, x.struct_info.dtype) + one = relax.const(1, x.ty.dtype) return self.block_builder.emit(relax.op.log(relax.op.add(x, one))) def _sqrt(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] - dtype = x.struct_info.dtype + dtype = x.ty.dtype # Check if input is integer type and convert to float32 if needed if dtype in ["int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64"]: @@ -110,7 +110,7 @@ def _sqrt(self, node: fx.Node) -> relax.Var: def _rsqrt(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] - dtype = x.struct_info.dtype + dtype = x.ty.dtype # Check if input is integer type and convert to float32 if needed if dtype in ["int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64"]: @@ -130,7 +130,7 @@ def _prelu_module(self, node: fx.Node) -> relax.Var: module = self.named_modules[node.target] alpha_tensor = module.weight.numpy() alpha = relax.const(alpha_tensor, dtype="float32") - axis = 0 if len(x.struct_info.shape) == 1 else 1 # Extract Channel size + axis = 0 if len(x.ty.shape) == 1 else 1 # Extract Channel size return self.block_builder.emit(relax.op.nn.prelu(x, alpha, axis)) def _softmax_module(self, node: fx.Node) -> relax.Var: @@ -164,11 +164,11 @@ def promote_binary_op_args(lhs, rhs): if isinstance(lhs, relax.Expr) and isinstance(rhs, relax.Expr): return lhs, rhs elif isinstance(lhs, relax.Expr): - assert isinstance(lhs.struct_info, relax.TensorStructInfo) - return lhs, relax.const(rhs, lhs.struct_info.dtype) + assert isinstance(lhs.ty, relax.TensorType) + return lhs, relax.const(rhs, lhs.ty.dtype) elif isinstance(rhs, relax.Expr): - assert isinstance(rhs.struct_info, relax.TensorStructInfo) - return relax.const(lhs, rhs.struct_info.dtype), rhs + assert isinstance(rhs.ty, relax.TensorType) + return relax.const(lhs, rhs.ty.dtype), rhs else: assert False @@ -183,16 +183,12 @@ def call_binary_op(op, lhs, rhs): return output elif isinstance(lhs, relax.expr.Constant): - output = call_binary_op( - relax_op, lhs, relax.const(rhs, dtype=lhs.struct_info.dtype) - ) + output = call_binary_op(relax_op, lhs, relax.const(rhs, dtype=lhs.ty.dtype)) self.env[node.args[0]] = output return output elif isinstance(rhs, relax.expr.Constant): - output = call_binary_op( - relax_op, relax.const(lhs, dtype=rhs.struct_info.dtype), rhs - ) + output = call_binary_op(relax_op, relax.const(lhs, dtype=rhs.ty.dtype), rhs) self.env[node.args[0]] = output return output @@ -209,7 +205,7 @@ def _adaptive_avg_pool1d_module(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] output_size = module.output_size # Expand to 3D by adding batch dim if input is 2D - x_ndim = x.struct_info.ndim + x_ndim = x.ty.ndim if x_ndim == 2: x = relax.op.expand_dims(x, axis=0) result = self.block_builder.emit( @@ -225,7 +221,7 @@ def _adaptive_avg_pool2d_module(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] output_size = module.output_size # Expand to 4D by adding batch dim if input is 3D - x_ndim = x.struct_info.ndim + x_ndim = x.ty.ndim if x_ndim == 3: x = relax.op.expand_dims(x, axis=0) result = self.block_builder.emit( @@ -241,7 +237,7 @@ def _adaptive_avg_pool3d_module(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] output_size = module.output_size # Expand to 5D by adding batch dim if input is 4D - x_ndim = x.struct_info.ndim + x_ndim = x.ty.ndim if x_ndim == 4: x = relax.op.expand_dims(x, axis=0) result = self.block_builder.emit( @@ -312,7 +308,7 @@ def _instance_norm(self, node: fx.Node) -> relax.Var: else: import numpy as np - dtype = x.struct_info.dtype + dtype = x.ty.dtype channel = int(self.shape_of(x)[1]) weight = relax.const(np.ones(channel), dtype=dtype) bias = relax.const(np.zeros(channel), dtype=dtype) @@ -432,7 +428,7 @@ def _cross_entropy_module(self, node: fx.Node) -> relax.Expr: if weights in self.params: weights = self.params[weights] else: - weights = relax.const(weights.numpy(), preds.struct_info.dtype) + weights = relax.const(weights.numpy(), preds.ty.dtype) reduction = module.reduction ignore_index = module.ignore_index @@ -461,8 +457,8 @@ def _group_norm_module(self, node: fx.Node) -> relax.Var: gamma = self.params[module.weight] beta = self.params[module.bias] else: - gamma = relax.const(torch.ones_like(module.num_channels), x.struct_info.dtype) - beta = relax.const(torch.zeros_like(module.num_channels), x.struct_info.dtype) + gamma = relax.const(torch.ones_like(module.num_channels), x.ty.dtype) + beta = relax.const(torch.zeros_like(module.num_channels), x.ty.dtype) eps = module.eps dim = len(self.shape_of(x)) @@ -559,7 +555,7 @@ def _interpolate(self, node: fx.Node) -> relax.Var: else: coord_trans = "half_pixel" - if data.struct_info.ndim == 5: + if data.ty.ndim == 5: if self.default_image_layout == "NDHWC": layout_3d = "NDHWC" else: @@ -691,8 +687,8 @@ def _inplace_copy(self, node: fx.Node) -> relax.Var: dest = self.env[node.args[0]] src = self.env[node.args[1]] - if src.struct_info.dtype != dest.struct_info.dtype: - src = self.block_builder.emit(relax.op.astype(src, dest.struct_info.dtype)) + if src.ty.dtype != dest.ty.dtype: + src = self.block_builder.emit(relax.op.astype(src, dest.ty.dtype)) dest_shape = self.shape_of(dest) src_shape = self.shape_of(src) @@ -706,7 +702,7 @@ def _masked_scatter(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] mask = self.env[node.args[1]] source = self.env[node.args[2]] - ndim = len(mask.struct_info.shape) + ndim = len(mask.ty.shape) if ndim == 1: index = self.block_builder.emit(relax.op.cumsum(mask, 0, dtype="int32")) index = self.block_builder.emit(relax.op.subtract(index, relax.const(1, "int32"))) @@ -715,16 +711,12 @@ def _masked_scatter(self, node: fx.Node) -> relax.Var: f_mask = self.block_builder.emit(relax.op.reshape(mask, [-1])) index = self.block_builder.emit(relax.op.cumsum(f_mask, 0, dtype="int32")) index = self.block_builder.emit(relax.op.subtract(index, relax.const(1, "int32"))) - source_shape = [-1] + [ - s for idx, s in enumerate(source.struct_info.shape) if idx >= ndim - ] + source_shape = [-1] + [s for idx, s in enumerate(source.ty.shape) if idx >= ndim] f_source = self.block_builder.emit(relax.op.reshape(source, source_shape)) gathered_source = self.block_builder.emit(relax.op.take(f_source, index, axis=0)) - gathered_source = self.block_builder.emit( - relax.op.reshape(gathered_source, x.struct_info.shape) - ) - if ndim != len(x.struct_info.shape): - mask = self.block_builder.emit(relax.op.broadcast_to(mask, x.struct_info.shape)) + gathered_source = self.block_builder.emit(relax.op.reshape(gathered_source, x.ty.shape)) + if ndim != len(x.ty.shape): + mask = self.block_builder.emit(relax.op.broadcast_to(mask, x.ty.shape)) return self.block_builder.emit(relax.op.where(mask, gathered_source, x)) def _one_hot(self, node: fx.Node) -> relax.Var: @@ -757,9 +749,7 @@ def _half(self, node: fx.Node) -> relax.Var: def _is_floating_point(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] - return relax.const( - x.struct_info.dtype in ["float16", "float32", "float64", "bfloat16"], "bool" - ) + return relax.const(x.ty.dtype in ["float16", "float32", "float64", "bfloat16"], "bool") def _type(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] @@ -771,7 +761,7 @@ def _type(self, node: fx.Node) -> relax.Var: def _getattr(self, node: fx.Node) -> relax.Var: if isinstance(self.env[node.args[0]], relax.Expr): if node.args[1] == "dtype": - return self.env[node.args[0]].struct_info.dtype + return self.env[node.args[0]].ty.dtype elif node.args[1] == "shape": return self.shape_of(self.env[node.args[0]]) return getattr(self.env[node.args[0]], node.args[1]) @@ -780,9 +770,7 @@ def create_input_vars(self, input_info: list[tuple[tuple[int], str]]) -> list[re inputs = list() for idx, (shape, dtype) in enumerate(input_info): inputs.append( - relax.Var( - f"inp_{idx}", relax.TensorStructInfo(shape, self._convert_data_type(dtype)) - ) + relax.Var(f"inp_{idx}", relax.TensorType(shape, self._convert_data_type(dtype))) ) return inputs @@ -1092,7 +1080,7 @@ def from_fx( for name, param in sorted(model.named_parameters(), key=lambda x: x[0]): shape = param.data.shape dtype = self._convert_data_type(str(param.data.dtype)) - inputs.append(relax.Var(name, relax.TensorStructInfo(shape, dtype))) + inputs.append(relax.Var(name, relax.TensorType(shape, dtype))) self.params[param] = inputs[-1] params.append(tvm.runtime.tensor(param.data.cpu().numpy())) else: diff --git a/python/tvm/relax/ir/instrument.py b/python/tvm/relax/ir/instrument.py index 6254a75a8e00..711ec10e6ac3 100644 --- a/python/tvm/relax/ir/instrument.py +++ b/python/tvm/relax/ir/instrument.py @@ -27,9 +27,9 @@ class WellFormedInstrument: Parameters ---------- - check_struct_info: bool + check_ty: bool - If True, validate the struct info in the module. If False, + If True, validate the type in the module. If False, skip these checks. validate_before_transform: bool @@ -39,9 +39,9 @@ class WellFormedInstrument: after running a transform. """ - def __init__(self, check_struct_info: bool = True, validate_before_transform: bool = True): + def __init__(self, check_ty: bool = True, validate_before_transform: bool = True): self.skip_pass_name = ["Normalize", "NormalizeGlobalVar", "ResolveGlobals"] - self.check_struct_info = check_struct_info + self.check_ty = check_ty self.validate_before_transform = validate_before_transform def run_before_pass(self, mod, pass_info): @@ -53,7 +53,7 @@ def run_after_pass(self, mod, pass_info): def _check(self, mod, pass_name, name_prefix): if pass_name not in self.skip_pass_name: - is_well_formed = relax.analysis.check_well_formed(mod, self.check_struct_info) + is_well_formed = relax.analysis.check_well_formed(mod, self.check_ty) if not is_well_formed: mod.show(name=f"{name_prefix}{pass_name}") assert is_well_formed diff --git a/python/tvm/relax/op/_op_gradient.py b/python/tvm/relax/op/_op_gradient.py index 2d949ee27643..5aae26b75f20 100644 --- a/python/tvm/relax/op/_op_gradient.py +++ b/python/tvm/relax/op/_op_gradient.py @@ -22,7 +22,7 @@ from tvm import relax from tvm.arith import Analyzer -from tvm.relax.struct_info import ShapeStructInfo +from tvm.relax.type import ShapeType from ...tirx import PrimExpr from ..block_builder import BlockBuilder @@ -65,7 +65,7 @@ def _get_shape(expr: Expr) -> ShapeExpr: """Get the shape from a Tensor expr.""" try: - shape = expr.struct_info.shape + shape = expr.ty.shape except Exception as error: raise RuntimeError( f"Get the shape of {expr} failed. Please normalize it first and ensure it is a Tensor." @@ -76,7 +76,7 @@ def _get_shape(expr: Expr) -> ShapeExpr: def _get_dtype(expr: Expr) -> str: """Get the dtype from a Tensor expr.""" try: - dtype = expr.struct_info.dtype + dtype = expr.ty.dtype except Exception as error: raise RuntimeError( f"Get the dtype of {expr} failed. Please normalize it first and ensure it is a Tensor." @@ -86,17 +86,17 @@ def _get_dtype(expr: Expr) -> str: def _fit_shape(bb: BlockBuilder, input_grad: Expr, input: Expr) -> Expr: """When expr and target has the same shape, return expr; - otherwise return `collapse_sum_to(expr, target.struct_info.shape)`. + otherwise return `collapse_sum_to(expr, target.ty.shape)`. Will use BlockBuilder to normalize expr first. """ target_shape = _get_shape(input) - expr_sinfo = _get_shape(bb.normalize(input_grad)).struct_info - target_sinfo = target_shape.struct_info - assert isinstance(expr_sinfo, ShapeStructInfo) - assert isinstance(target_sinfo, ShapeStructInfo) + expr_ty = _get_shape(bb.normalize(input_grad)).ty + target_ty = target_shape.ty + assert isinstance(expr_ty, ShapeType) + assert isinstance(target_ty, ShapeType) - def _check_shape_equal(lhs: ShapeStructInfo, rhs: ShapeStructInfo): + def _check_shape_equal(lhs: ShapeType, rhs: ShapeType): if len(lhs.values) != len(rhs.values): return False analyzer = Analyzer() @@ -107,7 +107,7 @@ def _check_shape_equal(lhs: ShapeStructInfo, rhs: ShapeStructInfo): return ( input_grad - if _check_shape_equal(expr_sinfo, target_sinfo) + if _check_shape_equal(expr_ty, target_ty) else collapse_sum_to(input_grad, target_shape) ) @@ -736,13 +736,13 @@ def concat_grad( assert axis is not None axis = int(axis) split_indices: list[PrimExpr] = [] - sinfo = orig_call.args[0].struct_info - assert isinstance(sinfo, relax.TupleStructInfo) - for i in range(len(sinfo.fields) - 1): - tensor_sinfo = sinfo.fields[i] - assert isinstance(tensor_sinfo, relax.TensorStructInfo) - assert tensor_sinfo.shape is not None - index = tensor_sinfo.shape[axis] + ty = orig_call.args[0].ty + assert isinstance(ty, relax.TupleType) + for i in range(len(ty.fields) - 1): + tensor_ty = ty.fields[i] + assert isinstance(tensor_ty, relax.TensorType) + assert tensor_ty.shape is not None + index = tensor_ty.shape[axis] if i > 0: index += split_indices[i - 1] split_indices.append(index) @@ -1108,7 +1108,7 @@ def cross_entropy_with_logits_grad( """ x, y = orig_call.args - if x.struct_info.ndim > 1: + if x.ty.ndim > 1: batch_size = int(_get_shape(x)[0]) output_grad = output_grad / relax.const(batch_size, _get_dtype(output_grad)) diff --git a/python/tvm/relax/op/base.py b/python/tvm/relax/op/base.py index 04f12d087f65..4534c3efc43d 100644 --- a/python/tvm/relax/op/base.py +++ b/python/tvm/relax/op/base.py @@ -27,7 +27,7 @@ from ...ir import PrimExpr from ..expr import Call, Expr, ExternFunc, GlobalVar, ShapeExpr, StringImm, Var -from ..struct_info import StructInfo, TensorStructInfo +from ..type import TensorType, Type from ..utils import convert_to_expr from . import _ffi_api @@ -82,10 +82,7 @@ def _wrap_inline_arg_tuple(args) -> Expr: elif ( isinstance(args, Expr) and not isinstance(args, tvm.relax.Tuple) - and ( - args.struct_info_ is None - or not isinstance(args.struct_info_, tvm.relax.TupleStructInfo) - ) + and (args.ty is None or not isinstance(args.ty, tvm.relax.TupleType)) ): return tvm.relax.Tuple([args]) else: @@ -95,7 +92,7 @@ def _wrap_inline_arg_tuple(args) -> Expr: def call_tir( gvar: GlobalVar, args: Expr, - out_sinfo: TensorStructInfo | list[TensorStructInfo], + out_ty: TensorType | list[TensorType], tir_vars: ShapeExpr | tuple[PrimExpr] | list[PrimExpr] | None = None, ) -> Call: """ @@ -109,10 +106,10 @@ def call_tir( args : Expr The input arguments. - out_sinfo : Union[TensorStructInfo, List[TensorStructInfo]] - The structure info of the call_tir output. - It should be a single or a list of TensorStructInfo. Each one denotes the - structure info of a returned tensor. + out_ty : Union[TensorType, List[TensorType]] + The type information of the call_tir output. + It should be a single or a list of TensorType. Each one denotes the + type information of a returned tensor. tir_vars : Optional[Union[ShapeExpr, Tuple[PrimExpr], List[PrimExpr]]] ShapeExpr representing a tuple of integers to unpack when calling func. Is null if not used @@ -124,19 +121,19 @@ def call_tir( """ args = _wrap_inline_arg_tuple(args) - if not isinstance(out_sinfo, list): - out_sinfo = [out_sinfo] + if not isinstance(out_ty, list): + out_ty = [out_ty] if isinstance(tir_vars, list | tuple): tir_vars = ShapeExpr(tir_vars) - return _ffi_api.call_tir(gvar, args, out_sinfo, tir_vars) # type: ignore + return _ffi_api.call_tir(gvar, args, out_ty, tir_vars) # type: ignore def call_tir_with_grad( gvar: GlobalVar, args: Expr, - out_sinfo: TensorStructInfo | list[TensorStructInfo], + out_ty: TensorType | list[TensorType], te_grad_name: str, te_grad_kwargs: dict[str, Object] | None = None, tir_vars: ShapeExpr | tuple[PrimExpr] | list[PrimExpr] | None = None, @@ -154,10 +151,10 @@ def call_tir_with_grad( args : Expr The input arguments. - out_sinfo : Union[TensorStructInfo, List[TensorStructInfo]] - The structure info of the call_tir_with_grad output. - It should be a single or a list of TensorStructInfo. Each one denotes the - structure info of a returned tensor. + out_ty : Union[TensorType, List[TensorType]] + The type information of the call_tir_with_grad output. + It should be a single or a list of TensorType. Each one denotes the + type information of a returned tensor. te_grad_name : str The registered name of the te gradient function associated with the call_tir_with_grad @@ -177,8 +174,8 @@ def call_tir_with_grad( """ args = _wrap_inline_arg_tuple(args) - if not isinstance(out_sinfo, list): - out_sinfo = [out_sinfo] + if not isinstance(out_ty, list): + out_ty = [out_ty] if isinstance(tir_vars, list | tuple): tir_vars = ShapeExpr(tir_vars) @@ -187,7 +184,7 @@ def call_tir_with_grad( te_grad_kwargs = {} return _ffi_api.call_tir_with_grad( # type: ignore - gvar, args, out_sinfo, te_grad_name, te_grad_kwargs, tir_vars + gvar, args, out_ty, te_grad_name, te_grad_kwargs, tir_vars ) @@ -195,7 +192,7 @@ def call_tir_inplace( gvar: GlobalVar, args: Expr, inplace_indices: int | list[int], - out_sinfo: TensorStructInfo | list[TensorStructInfo], + out_ty: TensorType | list[TensorType], tir_vars: ShapeExpr | tuple[PrimExpr] | list[PrimExpr] | None = None, ) -> Call: """ @@ -227,11 +224,11 @@ def call_tir_inplace( If `inplace_indices[i] = -1`, then the `i`th output will be a freshly allocated tensor. At least one member of `inplace_indices` must not be -1. - out_sinfo : Union[TensorStructInfo, List[TensorStructInfo]] - The structure info of the call_tir_inplace output. - It should be a single `TensorStructInfo` or a list of `TensorStructInfo`. - Each one denotes the structure info of a returned tensor. - If a list of `TensorStructInfo` is given, the result will be a tuple of `TensorStructInfo`. + out_ty : Union[TensorType, List[TensorType]] + The type information of the call_tir_inplace output. + It should be a single `TensorType` or a list of `TensorType`. + Each one denotes the type information of a returned tensor. + If a list of `TensorType` is given, the result will be a tuple of `TensorType`. tir_vars : Optional[Union[ShapeExpr, Tuple[PrimExpr], List[PrimExpr]]] ShapeExpr representing a tuple of integers to unpack when calling func. Is null if not used @@ -246,8 +243,8 @@ def call_tir_inplace( if not isinstance(inplace_indices, list): inplace_indices = [inplace_indices] - if not isinstance(out_sinfo, list): - out_sinfo = [out_sinfo] + if not isinstance(out_ty, list): + out_ty = [out_ty] if isinstance(tir_vars, list | tuple): tir_vars = ShapeExpr(tir_vars) @@ -256,7 +253,7 @@ def call_tir_inplace( gvar, args, inplace_indices, - out_sinfo, + out_ty, tir_vars, ) @@ -264,7 +261,7 @@ def call_tir_inplace( def call_dps_packed( func: str | Expr, args: Expr, - out_sinfo: TensorStructInfo | list[TensorStructInfo], + out_ty: TensorType | list[TensorType], ) -> Call: """ Call a destination-passing-style packed function and return the output. @@ -281,10 +278,10 @@ def call_dps_packed( args : Expr The input arguments. - out_sinfo : Union[TensorStructInfo, List[TensorStructInfo]] - The structure info of the call_dps_packed output. - It should be a single or a list of TensorStructInfo. Each one denotes the - structure info of a returned tensor. + out_ty : Union[TensorType, List[TensorType]] + The type information of the call_dps_packed output. + It should be a single or a list of TensorType. Each one denotes the + type information of a returned tensor. Returns ------- @@ -296,16 +293,16 @@ def call_dps_packed( args = _wrap_inline_arg_tuple(args) - if not isinstance(out_sinfo, list): - out_sinfo = [out_sinfo] + if not isinstance(out_ty, list): + out_ty = [out_ty] - return _ffi_api.call_dps_packed(func, args, out_sinfo) # type: ignore + return _ffi_api.call_dps_packed(func, args, out_ty) # type: ignore def call_py_func( func_name: str, args: Expr, - out_sinfo: TensorStructInfo | list[TensorStructInfo], + out_ty: TensorType | list[TensorType], ) -> Call: """ Call a Python function and return the output. @@ -319,10 +316,10 @@ def call_py_func( args : Expr The input arguments. - out_sinfo : Union[TensorStructInfo, List[TensorStructInfo]] - The structure info of the call_py_func output. - It should be a single or a list of TensorStructInfo. Each one denotes the - structure info of a returned tensor. + out_ty : Union[TensorType, List[TensorType]] + The type information of the call_py_func output. + It should be a single or a list of TensorType. Each one denotes the + type information of a returned tensor. Returns ------- @@ -331,17 +328,17 @@ def call_py_func( """ args = _wrap_inline_arg_tuple(args) - if not isinstance(out_sinfo, list): - out_sinfo = [out_sinfo] + if not isinstance(out_ty, list): + out_ty = [out_ty] - return _ffi_api.call_py_func(func_name, args, out_sinfo) # type: ignore + return _ffi_api.call_py_func(func_name, args, out_ty) # type: ignore def call_builtin_with_ctx( func: str | Expr, args: Expr, *, - sinfo_args: StructInfo | list[StructInfo] | None = None, + ty_args: Type | list[Type] | None = None, ) -> Call: """Call a builtin function func. @@ -353,8 +350,8 @@ def call_builtin_with_ctx( args : Expr The input arguments. - sinfo_args: Optional[Union[StructInfo, List[StructInfo]]] - The struct info arguments to the call node. + ty_args: Optional[Union[Type, List[Type]]] + The type arguments to the call node. Returns ------- @@ -366,13 +363,13 @@ def call_builtin_with_ctx( args = _wrap_inline_arg_tuple(args) - if sinfo_args is not None and not isinstance(sinfo_args, list | tuple): - sinfo_args = [sinfo_args] + if ty_args is not None and not isinstance(ty_args, list | tuple): + ty_args = [ty_args] return _ffi_api.call_builtin_with_ctx( # type: ignore func, args, - sinfo_args, # type: ignore + ty_args, # type: ignore ) @@ -406,7 +403,7 @@ def make_closure( def invoke_closure( closure: Expr, args: Expr, - sinfo_args: list[StructInfo] | StructInfo, + ty_args: list[Type] | Type, ) -> Call: """ Invoke a closure. @@ -419,8 +416,8 @@ def invoke_closure( args : Expr The input arguments. - type_args: Union[List[StructInfo], StructInfo] - The structure info arguments of the CallNode + type_args: Union[List[Type], Type] + The type information arguments of the CallNode Returns ------- @@ -429,10 +426,10 @@ def invoke_closure( """ args = _wrap_inline_arg_tuple(args) - if not isinstance(sinfo_args, list | tuple): - sinfo_args = [sinfo_args] + if not isinstance(ty_args, list | tuple): + ty_args = [ty_args] - return _ffi_api.invoke_closure(closure, args, sinfo_args) # type: ignore + return _ffi_api.invoke_closure(closure, args, ty_args) # type: ignore def render_object(val: tvm.Object) -> str: @@ -681,7 +678,7 @@ def call_inplace_packed( func: str | ExternFunc | GlobalVar, *args: Expr, inplace_indices: int | list[int], - sinfo_args: StructInfo | list[StructInfo], + ty_args: Type | list[Type], ) -> Expr: """ Construct a call to a packed function that consumes some of its arguments "in-place" @@ -717,36 +714,36 @@ def call_inplace_packed( If `inplace_indices[i] = -1`, then the `i`th output will be a freshly allocated tensor. At least one member of `inplace_indices` must not be -1. - sinfo_args: Union[StructInfo, List[StructInfo]] - The list of structure info arguments (giving the structural info for the returned value). + ty_args: Union[Type, List[Type]] + The list of type information arguments (giving the type information for the returned value). Returns ------- result : Expr A Relax call, corresponding to - `call_pure_packed(ExternFunc(func), args, DictAttrs(kwargs), sinfo_args)` + `call_pure_packed(ExternFunc(func), args, DictAttrs(kwargs), ty_args)` """ if isinstance(func, ExternFunc): func = func.global_symbol op = ExternFunc(func) args = tuple(convert_to_expr(a) for a in args) - if sinfo_args is None: + if ty_args is None: raise ValueError("R.call_pure_packed is required to have type_args") - if isinstance(sinfo_args, tuple): # type: ignore - sinfo_args = list(sinfo_args) - elif not isinstance(sinfo_args, list): - sinfo_args = [sinfo_args] + if isinstance(ty_args, tuple): # type: ignore + ty_args = list(ty_args) + elif not isinstance(ty_args, list): + ty_args = [ty_args] if not isinstance(inplace_indices, list): inplace_indices = [inplace_indices] - return _ffi_api.call_inplace_packed(op, args, inplace_indices, sinfo_args) # type: ignore # pylint: disable=no-member + return _ffi_api.call_inplace_packed(op, args, inplace_indices, ty_args) # type: ignore # pylint: disable=no-member def call_pure_packed( func: str | ExternFunc | GlobalVar, *args: Expr, - sinfo_args: StructInfo | list[StructInfo], + ty_args: Type | list[Type], ) -> Expr: """ Construct a call to a packed function that should be treated as pure, @@ -768,14 +765,14 @@ def call_pure_packed( args: Expr The arguments for the PackedFunc. - sinfo_args: Union[StructInfo, List[StructInfo]] - The list of structure info arguments (giving the structural info for the returned value). + ty_args: Union[Type, List[Type]] + The list of type information arguments (giving the type information for the returned value). Returns ------- result : Expr A Relax call, corresponding to - `call_pure_packed(ExternFunc(func), args, DictAttrs(kwargs), sinfo_args)` + `call_pure_packed(ExternFunc(func), args, DictAttrs(kwargs), ty_args)` """ if isinstance(func, ExternFunc): func = func.global_symbol @@ -783,34 +780,28 @@ def call_pure_packed( op = ExternFunc(func) args = tuple(convert_to_expr(a) for a in args) - if sinfo_args is None: + if ty_args is None: raise ValueError("R.call_pure_packed is required to have type_args") - if isinstance(sinfo_args, tuple): # type: ignore - sinfo_args = list(sinfo_args) - elif not isinstance(sinfo_args, list): - sinfo_args = [sinfo_args] - - sinfo_args = [ - ( - sinfo() - if callable(sinfo) - else sinfo.asobject() - if isinstance(sinfo, ObjectConvertible) - else sinfo - ) - for sinfo in sinfo_args + if isinstance(ty_args, tuple): # type: ignore + ty_args = list(ty_args) + elif not isinstance(ty_args, list): + ty_args = [ty_args] + + ty_args = [ + (ty() if callable(ty) else ty.asobject() if isinstance(ty, ObjectConvertible) else ty) + for ty in ty_args ] # note: if we need attributes, we can also take them here - return _ffi_api.call_pure_packed(op, args, None, sinfo_args) # type: ignore # pylint: disable=no-member + return _ffi_api.call_pure_packed(op, args, None, ty_args) # type: ignore # pylint: disable=no-member def invoke_pure_closure( closure: Expr, args: Expr, - sinfo_args: list[StructInfo] | StructInfo, + ty_args: list[Type] | Type, ) -> Call: """ Invoke a closure and indicate to the compiler that it is pure. @@ -829,8 +820,8 @@ def invoke_pure_closure( args : Expr The input arguments. - type_args: Union[List[StructInfo], StructInfo] - The structure info arguments of the CallNode + type_args: Union[List[Type], Type] + The type information arguments of the CallNode Returns ------- @@ -839,10 +830,10 @@ def invoke_pure_closure( """ args = _wrap_inline_arg_tuple(args) - if not isinstance(sinfo_args, list | tuple): - sinfo_args = [sinfo_args] + if not isinstance(ty_args, list | tuple): + ty_args = [ty_args] - return _ffi_api.invoke_pure_closure(closure, args, sinfo_args) # type: ignore + return _ffi_api.invoke_pure_closure(closure, args, ty_args) # type: ignore def to_vdevice(data, dst_vdevice) -> Expr: diff --git a/python/tvm/relax/op/binary.py b/python/tvm/relax/op/binary.py index 9480612e6f52..4b26226276a7 100644 --- a/python/tvm/relax/op/binary.py +++ b/python/tvm/relax/op/binary.py @@ -43,9 +43,9 @@ def add(x1: Expr, x2: Expr) -> Expr: .. code:: python bb = relax.BlockBuilder() - a = relax.Var("a", relax.TensorStructInfo(shape=(2, 3), dtype="float32")) - b = relax.Var("b", relax.TensorStructInfo(shape=(2, 1), dtype="float32")) - c = bb.normalize(relax.op.add(a, b)) # c has TensorStructInfo(shape=(2, 3), dtype="float32") + a = relax.Var("a", relax.TensorType(shape=(2, 3), dtype="float32")) + b = relax.Var("b", relax.TensorType(shape=(2, 1), dtype="float32")) + c = bb.normalize(relax.op.add(a, b)) # c has TensorType(shape=(2, 3), dtype="float32") """ return _ffi_api.add(x1, x2) # type: ignore diff --git a/python/tvm/relax/op/distributed/distributed.py b/python/tvm/relax/op/distributed/distributed.py index 07ff674dd09e..aa35125257c0 100644 --- a/python/tvm/relax/op/distributed/distributed.py +++ b/python/tvm/relax/op/distributed/distributed.py @@ -18,8 +18,7 @@ """Operators for distributed Relax.""" from tvm.ir import PrimExpr -from tvm.relax.distributed import DTensorStructInfo -from tvm.relax.distributed.struct_info import DeviceMesh, Placement +from tvm.relax.distributed import DeviceMesh, DTensorType, Placement from ...expr import Call, Expr, GlobalVar, ShapeExpr from ...expr import Tuple as RxTuple @@ -69,7 +68,7 @@ def redistribute(input: Expr, device_mesh: DeviceMesh, placement: Placement) -> def call_tir_local_view( gvar: GlobalVar, args: Expr, - out_sinfo: DTensorStructInfo | list[DTensorStructInfo], + out_ty: DTensorType | list[DTensorType], tir_vars: ShapeExpr | tuple[PrimExpr] | list[PrimExpr] | None = None, ) -> Call: """ @@ -85,10 +84,10 @@ def call_tir_local_view( args : Expr The input arguments. - out_sinfo : Union[DTensorStructInfo, List[DTensorStructInfo]] - The structure info of the call_tir output. - It should be a single or a list of DTensorStructInfo. Each one denotes the - structure info of a returned tensor. + out_ty : Union[DTensorType, List[DTensorType]] + The type information of the call_tir output. + It should be a single or a list of DTensorType. Each one denotes the + type information of a returned tensor. tir_vars : Optional[Union[ShapeExpr, Tuple[PrimExpr], List[PrimExpr]]] ShapeExpr representing a tuple of integers to unpack when calling func. Is null if not used @@ -103,19 +102,19 @@ def call_tir_local_view( elif isinstance(args, Expr) and not isinstance(args, RxTuple): # type: ignore args = RxTuple((args,)) - if not isinstance(out_sinfo, list): - out_sinfo = [out_sinfo] + if not isinstance(out_ty, list): + out_ty = [out_ty] if isinstance(tir_vars, list | tuple): tir_vars = ShapeExpr(tir_vars) - return _ffi_api.call_tir_local_view(gvar, args, out_sinfo, tir_vars) # type: ignore + return _ffi_api.call_tir_local_view(gvar, args, out_ty, tir_vars) # type: ignore def redistribute_replica_to_shard(input: Expr, num_workers: int, axis: int) -> Expr: """Slice tensor into several parts along one axis, and each worker takes one part. - input.struct_info.shape[axis] % num_workers == 0 is required. + input.ty.shape[axis] % num_workers == 0 is required. Each worker must have an identical copy of the input. This is a specialized version of redistribute op. diff --git a/python/tvm/relax/op/grad/grad.py b/python/tvm/relax/op/grad/grad.py index 45ceda6fdbcc..86c0792b54de 100644 --- a/python/tvm/relax/op/grad/grad.py +++ b/python/tvm/relax/op/grad/grad.py @@ -18,8 +18,8 @@ """Operators to implement operaor gradients. Used in `_op_gradient.py`. We are trying to keep grad operators as simple as possible, and hope they are only used for finding -gradients for forward operators. The struct_info inference for grad operators just returns the -struct_info of the input. +gradients for forward operators. The ty inference for grad operators just returns the +ty of the input. """ from ...expr import Expr @@ -52,8 +52,8 @@ def start_checkpoint(input: Expr) -> Expr: For instance, ``` - a = relax.Var("a", relax.TensorStructInfo((2, 2), "float32")) - b = relax.Var("b", relax.TensorStructInfo((2, 2), "float32")) + a = relax.Var("a", relax.TensorType((2, 2), "float32")) + b = relax.Var("b", relax.TensorType((2, 2), "float32")) c = a * 2 d = b * 2 c_cp = start_checkpoint(c) diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py index 21fd7b565c4e..e4814bc62ab6 100644 --- a/python/tvm/relax/op/manipulate.py +++ b/python/tvm/relax/op/manipulate.py @@ -144,7 +144,7 @@ def layout_transform( if callable(index_map): index_map = IndexMap.from_func(index_map, index_dtype=default_index_dtype) - x_dtype = x.struct_info.dtype + x_dtype = x.ty.dtype # Explicitly convert python int/float pad_value to the x's type. If the default behavior # is applied, it would be converted to int32/float32, which may not match the x's type. diff --git a/python/tvm/relax/op/mask.py b/python/tvm/relax/op/mask.py index 3c013768a276..a4f821884927 100644 --- a/python/tvm/relax/op/mask.py +++ b/python/tvm/relax/op/mask.py @@ -35,5 +35,5 @@ def masked_fill(x: Expr, mask: Expr, value: Expr): result : relax.Expr The filled tensor. """ - values = _ffi_api.full_like(x, value, value.struct_info.dtype) # type: ignore + values = _ffi_api.full_like(x, value, value.ty.dtype) # type: ignore return _ffi_api.where(mask, values, x) # type: ignore diff --git a/python/tvm/relax/op/vision/multibox_transform_loc.py b/python/tvm/relax/op/vision/multibox_transform_loc.py index 6830b1dc6321..e4e41a987372 100644 --- a/python/tvm/relax/op/vision/multibox_transform_loc.py +++ b/python/tvm/relax/op/vision/multibox_transform_loc.py @@ -60,7 +60,7 @@ def multibox_transform_loc( Notes ----- - **Shape/dtype (checked in ``FInferStructInfo`` when static):** + **Shape/dtype (checked in ``FInferType`` when static):** - ``cls_pred``: 3-D; ``loc_pred``: 2-D; ``anchor``: 3-D. - ``cls_pred``, ``loc_pred``, ``anchor`` dtypes must match. @@ -69,7 +69,7 @@ def multibox_transform_loc( - ``cls_pred.shape[0]`` must equal ``loc_pred.shape[0]`` (batch). If ``cls_pred`` has **unknown** shape, inference only returns generic rank-3 tensor - struct info for the two outputs; it does **not** verify ``4*N`` vs ``loc_pred`` or + type for the two outputs; it does **not** verify ``4*N`` vs ``loc_pred`` or ``anchor.shape[1]`` vs ``N``, because ``N`` is not available statically. Other checks (ranks, dtypes, ``loc_pred.shape[1] % 4 == 0`` when known, batch match when both batch axes are known, etc.) still run where applicable. diff --git a/python/tvm/relax/relax_to_pyfunc_converter.py b/python/tvm/relax/relax_to_pyfunc_converter.py index bd15a85610f0..970e51360eb7 100644 --- a/python/tvm/relax/relax_to_pyfunc_converter.py +++ b/python/tvm/relax/relax_to_pyfunc_converter.py @@ -426,8 +426,8 @@ def _convert_var(self, var: relax.Var, args: list[Any]) -> Any: return self.variable_map[var_name] # Try to infer shape from var's type annotation - if hasattr(var, "struct_info") and hasattr(var.struct_info, "shape"): - shape = var.struct_info.shape + if hasattr(var, "ty") and hasattr(var.ty, "shape"): + shape = var.ty.shape if shape and len(shape) > 0: # Convert symbolic shapes to concrete values concrete_shape = [] @@ -599,7 +599,7 @@ def _convert_call_tir(self, call: relax.Call, args: list[Any]) -> Any: # Extract TIR function name and arguments tir_func = call.args[0] tir_args = call.args[1] if len(call.args) > 1 else [] - out_sinfo = call.attrs.get("out_sinfo") if call.attrs else None + out_ty = call.attrs.get("out_ty") if call.attrs else None # Get function name if isinstance(tir_func, relax.GlobalVar): @@ -660,8 +660,8 @@ def _convert_call_tir(self, call: relax.Call, args: list[Any]) -> Any: # For call_tir, we need to allocate output tensor output_shape = None - if out_sinfo and hasattr(out_sinfo, "shape"): - output_shape = out_sinfo.shape + if out_ty and hasattr(out_ty, "shape"): + output_shape = out_ty.shape elif converted_args: # Use the shape of the first input tensor first_arg = converted_args[0] @@ -713,7 +713,7 @@ def _convert_call_dps_packed(self, call: relax.Call, args: list[Any]) -> Any: # Extract packed function name and arguments packed_func = call.args[0] packed_args = call.args[1] if len(call.args) > 1 else [] - _out_sinfo = call.attrs.get("out_sinfo") if call.attrs else None + _out_ty = call.attrs.get("out_ty") if call.attrs else None # Get function name if isinstance(packed_func, relax.GlobalVar): diff --git a/python/tvm/relax/script/builder/distributed/ir.py b/python/tvm/relax/script/builder/distributed/ir.py index a74727c99560..798fb99c675b 100644 --- a/python/tvm/relax/script/builder/distributed/ir.py +++ b/python/tvm/relax/script/builder/distributed/ir.py @@ -27,7 +27,7 @@ import tvm from tvm import base as _base from tvm.ir import PrimExpr -from tvm.relax.distributed import DeviceMesh, DTensorStructInfo, Placement +from tvm.relax.distributed import DeviceMesh, DTensorType, Placement from tvm.relax.expr import Call, Constant, Expr, ExternFunc, ShapeExpr from tvm.relax.expr import Tuple as RxTuple from tvm.relax.op.distributed import ( @@ -52,7 +52,7 @@ def call_tir( func: str | Expr, args: Expr, - out_sinfo: DTensorStructInfo | list[DTensorStructInfo], + out_ty: DTensorType | list[DTensorType], tir_vars: ShapeExpr | tuple[PrimExpr] | list[PrimExpr] | None = None, ) -> Call: """Distributed version of call_tir @@ -65,10 +65,10 @@ def call_tir( args : Expr The input arguments. - out_sinfo : Union[DTensorStructInfo, List[DTensorStructInfo]] - The structure info of the call_tir output. - It should be a single or a list of DTensorStructInfo. Each one denotes the - structure info of a returned distributed tensor. + out_ty : Union[DTensorType, List[DTensorType]] + The type information of the call_tir output. + It should be a single or a list of DTensorType. Each one denotes the + type information of a returned distributed tensor. tir_vars : Optional[Union[ShapeExpr, Tuple[PrimExpr], List[PrimExpr]]] ShapeExpr representing a tuple of integers to unpack when calling func. Is null if not used @@ -86,18 +86,18 @@ def call_tir( elif isinstance(args, Expr) and not isinstance(args, RxTuple): # type: ignore args = RxTuple((args,)) - if not isinstance(out_sinfo, list): - out_sinfo = [out_sinfo] + if not isinstance(out_ty, list): + out_ty = [out_ty] if isinstance(tir_vars, list | tuple): tir_vars = ShapeExpr(tir_vars) - return _ffi_api.call_tir_dist(func, args, out_sinfo, tir_vars) # type: ignore + return _ffi_api.call_tir_dist(func, args, out_ty, tir_vars) # type: ignore def const( value: bool | int | float | _np.ndarray | tvm.runtime.Tensor, - struct_info: DTensorStructInfo, + ty: DTensorType, ) -> Constant: """Create a constant value. @@ -118,10 +118,10 @@ def const( - bool maps to "bool" - other using the same default rule as numpy. """ - struct_info = tvm.runtime.convert(struct_info) - if not isinstance(struct_info, DTensorStructInfo): - raise TypeError("struct_info needs to be an instance of DTensorStructInfo. ") - dtype = str(struct_info.tensor_sinfo.dtype) + ty = tvm.runtime.convert(ty) + if not isinstance(ty, DTensorType): + raise TypeError("ty needs to be an instance of DTensorType. ") + dtype = str(ty.tensor_ty.dtype) if isinstance(value, Number | (bool | list)): value = _np.array(value, dtype=dtype) @@ -133,7 +133,7 @@ def const( if not isinstance(value, _tensor.Tensor): raise ValueError("value has to be scalar or Tensor") - return Constant(value, struct_info) + return Constant(value, ty) def _lookup_device_mesh(device_mesh_str: py_str) -> DeviceMesh: diff --git a/python/tvm/relax/script/builder/ir.py b/python/tvm/relax/script/builder/ir.py index 84ad485a33bf..a94d239105ca 100644 --- a/python/tvm/relax/script/builder/ir.py +++ b/python/tvm/relax/script/builder/ir.py @@ -198,7 +198,7 @@ call_py_func as _call_py_func, ) from tvm.relax.op.builtin import stop_lift_params -from tvm.relax.struct_info import StructInfo +from tvm.relax.type import Type from tvm.relax.utils import convert_to_expr, gen_call_tir_inputs from tvm.runtime import Object as tvm_Object from tvm.runtime import ObjectConvertible @@ -278,13 +278,13 @@ def function(is_pure: bool = True, is_private: bool = False) -> frame.FunctionFr ) -def arg(name: py_str, struct_info: StructInfo) -> Var: +def arg(name: py_str, ty: Type) -> Var: """Add a parameter to the last function frame. Parameters ---------- name: str The name of the parameter. - struct_info: StructInfo + ty: Type The Struct Info of the parameter Returns @@ -293,7 +293,7 @@ def arg(name: py_str, struct_info: StructInfo) -> Var: The created function parameter var. """ - return _ffi_api.Arg(name, struct_info) # type: ignore[attr-defined] # pylint: disable=no-member + return _ffi_api.Arg(name, ty) # type: ignore[attr-defined] # pylint: disable=no-member def func_name(name: py_str) -> None: @@ -316,14 +316,19 @@ def func_attr(attrs: dict[py_str, tvm_Object]) -> None: return _ffi_api.FuncAttrs(attrs) # type: ignore[attr-defined] # pylint: disable=no-member -def func_ret_struct_info(ret_sinfo: StructInfo) -> None: - """Specify the return struct info of the last function frame. +def func_ret_type(ret_ty: Type) -> None: + """Specify the return type of the last function frame. Parameters ---------- - ret_type: StructInfo - The function return struct info. + ret_ty: Type + The function return type. """ - return _ffi_api.FuncRetStructInfo(ret_sinfo) # type: ignore[attr-defined] # pylint: disable=no-member + return _ffi_api.FuncRetType(ret_ty) # type: ignore[attr-defined] # pylint: disable=no-member + + +def func_ret_ty(ret_ty: Type) -> None: + """Backward-compatible alias for `func_ret_type`.""" + return func_ret_type(ret_ty) def func_ret_value(value: Expr) -> None: @@ -407,7 +412,7 @@ def output(*vars: tuple[Var]) -> None: def call_packed( func: py_str, *args: Expr, - sinfo_args: StructInfo | list[StructInfo] | None = None, + ty_args: Type | list[Type] | None = None, **kwargs: Any, ) -> Call: """Create a relax Call, which calls a packed function. @@ -417,8 +422,8 @@ def call_packed( The name of extern function. *args : Expr The arguments. - sinfo_args: Optional[Union[StructInfo, List[StructInfo]]] - The list of structure info arguments. + ty_args: Optional[Union[Type, List[Type]]] + The list of type information arguments. kwargs: Expr The keyword arguments. @@ -429,22 +434,16 @@ def call_packed( """ op = ExternFunc(func) args = py_tuple(convert_to_expr(a) for a in args) - if sinfo_args is None: - sinfo_args = [] - if isinstance(sinfo_args, py_tuple): # type: ignore - sinfo_args = list(sinfo_args) - elif not isinstance(sinfo_args, list): - sinfo_args = [sinfo_args] - - sinfo_args = [ - ( - sinfo() - if callable(sinfo) - else sinfo.asobject() - if isinstance(sinfo, ObjectConvertible) - else sinfo - ) - for sinfo in sinfo_args + if ty_args is None: + ty_args = [] + if isinstance(ty_args, py_tuple): # type: ignore + ty_args = list(ty_args) + elif not isinstance(ty_args, list): + ty_args = [ty_args] + + ty_args = [ + (ty() if callable(ty) else ty.asobject() if isinstance(ty, ObjectConvertible) else ty) + for ty in ty_args ] is_default = False @@ -458,13 +457,13 @@ def call_packed( if kwargs or not is_default: attrs = tvm.ir.attrs.make_node(attrs_type_key, **kwargs) - return Call(op, args, attrs=attrs, sinfo_args=sinfo_args) + return Call(op, args, attrs=attrs, ty_args=ty_args) def call_py_func( py_func_name: py_str, *args: Expr, - out_sinfo: StructInfo | list[StructInfo], + out_ty: Type | list[Type], ) -> Call: """Create a relax Call, which calls a Python function. @@ -475,10 +474,10 @@ def call_py_func( in the IRModule's pyfuncs attribute. *args : Expr The arguments. - out_sinfo: Union[StructInfo, List[StructInfo]] - The structure info of the call_py_func output. - It should be a single or a list of TensorStructInfo. Each one denotes the - structure info of a returned tensor. + out_ty: Union[Type, List[Type]] + The type information of the call_py_func output. + It should be a single or a list of TensorType. Each one denotes the + type information of a returned tensor. Returns ------- @@ -486,20 +485,14 @@ def call_py_func( The created Relax Call for call_py_func operator. """ args = py_tuple(convert_to_expr(a) for a in args) - if isinstance(out_sinfo, py_tuple): # type: ignore - out_sinfo = list(out_sinfo) - elif not isinstance(out_sinfo, list): - out_sinfo = [out_sinfo] - - out_sinfo = [ - ( - sinfo() - if callable(sinfo) - else sinfo.asobject() - if isinstance(sinfo, ObjectConvertible) - else sinfo - ) - for sinfo in out_sinfo + if isinstance(out_ty, py_tuple): # type: ignore + out_ty = list(out_ty) + elif not isinstance(out_ty, list): + out_ty = [out_ty] + + out_ty = [ + (ty() if callable(ty) else ty.asobject() if isinstance(ty, ObjectConvertible) else ty) + for ty in out_ty ] # Convert string to StringImm @@ -509,11 +502,11 @@ def call_py_func( ) except (TypeError, ValueError, AttributeError): func_name_imm = StringImm(py_func_name) - return _call_py_func(func_name_imm, args, out_sinfo) + return _call_py_func(func_name_imm, args, out_ty) -def _sinfo_arg_wrapper(func): - """A wrapper to convert StructInfoProxies to StructInfo for builtin operators with sinfo_args""" +def _ty_arg_wrapper(func): + """A wrapper to convert TypeProxies to Type for builtin operators with ty_args""" def _convert_tensor_type(args): if isinstance(args, list | py_tuple): # type: ignore @@ -534,30 +527,30 @@ def wrapped(*args, **kwargs): return wrapped # type: ignore -invoke_closure = _sinfo_arg_wrapper(invoke_closure) # pylint: disable=invalid-name +invoke_closure = _ty_arg_wrapper(invoke_closure) # pylint: disable=invalid-name -call_builtin_with_ctx = _sinfo_arg_wrapper(call_builtin_with_ctx) # pylint: disable=invalid-name +call_builtin_with_ctx = _ty_arg_wrapper(call_builtin_with_ctx) # pylint: disable=invalid-name ############################### Emits ############################### -def emit(value: Expr, annotate_struct_info: StructInfo | None = None) -> Var: +def emit(value: Expr, annotate_ty: Type | None = None) -> Var: """Emit a binding to the last binding block frame. Parameters ---------- value: Expr The right side value of the bindings to be emitted. - annotate_struct_info: Optional[StructInfo] - The optional struct info annotation for the emitted value. + annotate_ty: Optional[Type] + The optional type annotation for the emitted value. Returns ------- var: Var The left side var of the emitted binding. """ - return _ffi_api.Emit(value, annotate_struct_info) # type: ignore[attr-defined] # pylint: disable=no-member + return _ffi_api.Emit(value, annotate_ty) # type: ignore[attr-defined] # pylint: disable=no-member def emit_te(func: Callable, *args: Any, **kwargs: Any) -> Call: @@ -588,28 +581,28 @@ def emit_te(func: Callable, *args: Any, **kwargs: Any) -> Call: A newly created call that calls into a tirx function. """ primfunc_name_hint = kwargs.pop("primfunc_name_hint", None) - tir_func, call_args, out_sinfo, tir_vars = gen_call_tir_inputs(func, *args, **kwargs) + tir_func, call_args, out_ty, tir_vars = gen_call_tir_inputs(func, *args, **kwargs) if not primfunc_name_hint: primfunc_name_hint = func.__name__ gvar = decl_function(primfunc_name_hint, tir_func) # type: ignore - return call_tir(gvar, call_args, out_sinfo, tir_vars) + return call_tir(gvar, call_args, out_ty, tir_vars) -def emit_match_cast(value: Expr, struct_info: StructInfo) -> Var: +def emit_match_cast(value: Expr, ty: Type) -> Var: """Emit a match_cast binding to the last binding block frame. Parameters ---------- value: Expr The value of the MatchCast to be emitted. - struct_info: StructInfo - The struct_info of the MatchCast to be emitted. + ty: Type + The ty of the MatchCast to be emitted. Returns ------- var: Var The left side var of the emitted binding. """ - return _ffi_api.EmitMatchCast(value, struct_info) # type: ignore + return _ffi_api.EmitMatchCast(value, ty) # type: ignore def emit_var_binding(value: VarBinding) -> Var: @@ -626,20 +619,20 @@ def emit_var_binding(value: VarBinding) -> Var: return _ffi_api.EmitVarBinding(value) # type: ignore -def emit_with_sinfo( +def emit_with_type( op: str, args: Expr, - sinfo_args: StructInfo | list[StructInfo] | None = None, + ty_args: Type | list[Type] | None = None, ) -> Call: - """Create a relax Call with sinfo_args. + """Create a Relax Call with type arguments. Parameters ---------- op: Expr - The relax op for which sinfo_args to be appended + The relax op for which type args are to be appended args : Expr The arguments. - sinfo_args: Optional[Union[StructInfo, List[StructInfo]]] - The list of structure info arguments. + ty_args: Optional[Union[Type, List[Type]]] + The list of type arguments. Returns ------- @@ -647,7 +640,16 @@ def emit_with_sinfo( The created Relax Call """ builtin_call = tvm.ir.Op.get(op) - return Call(builtin_call, args, attrs=None, sinfo_args=sinfo_args) + return Call(builtin_call, args, attrs=None, ty_args=ty_args) + + +def emit_with_ty( + op: str, + args: Expr, + ty_args: Type | list[Type] | None = None, +) -> Call: + """Backward-compatible alias for `emit_with_type`.""" + return emit_with_type(op, args, ty_args) ############################### SeqExpr ############################### @@ -856,7 +858,8 @@ def dtype(value: py_str | DataType) -> Expr: "emit_match_cast", "emit_te", "emit_var_binding", - "emit_with_sinfo", + "emit_with_ty", + "emit_with_type", "equal", "erf", "ewise_fma", @@ -874,7 +877,8 @@ def dtype(value: py_str | DataType) -> Expr: "full_like", "func_attr", "func_name", - "func_ret_struct_info", + "func_ret_ty", + "func_ret_type", "func_ret_value", "function", "gather_elements", diff --git a/python/tvm/relax/script/parser/dist.py b/python/tvm/relax/script/parser/dist.py index f4e59f95fe6f..60f382907c32 100644 --- a/python/tvm/relax/script/parser/dist.py +++ b/python/tvm/relax/script/parser/dist.py @@ -20,8 +20,8 @@ from typing import Any, Optional, Union from tvm.ir import Range -from tvm.relax import TensorStructInfo -from tvm.relax.distributed import DeviceMesh, DTensorStructInfo, Placement, device_mesh +from tvm.relax import TensorType +from tvm.relax.distributed import DeviceMesh, DTensorType, Placement, device_mesh from tvm.relax.script.builder.distributed import ( annotate_sharding, call_tir, @@ -34,33 +34,33 @@ from tvm.script.ir_builder.ir import IRModuleFrame from tvm.tirx import PrimExpr -from .entry import StructInfoProxy, TensorProxy +from .entry import TensorProxy, TypeProxy ############################### R.DTensor ############################### -class DTensorProxy(StructInfoProxy): - tensor_sinfo_proxy: TensorProxy +class DTensorProxy(TypeProxy): + tensor_ty_proxy: TensorProxy device_mesh: DeviceMesh placement: Placement def __init__( self, - tensor_sinfo_proxy: TensorProxy, + tensor_ty_proxy: TensorProxy, device_mesh: DeviceMesh, placement: Placement, ) -> None: self.device_mesh = device_mesh self.placement = placement - self.tensor_sinfo_proxy = tensor_sinfo_proxy + self.tensor_ty_proxy = tensor_ty_proxy super().__init__() def get_symbolic_vars(self) -> set[str]: - return self.tensor_sinfo_proxy.get_symbolic_vars() + return self.tensor_ty_proxy.get_symbolic_vars() - def as_struct_info(self, dict_globals: dict[str, Any] | None = None) -> TensorStructInfo: - return DTensorStructInfo( - self.tensor_sinfo_proxy.as_struct_info(dict_globals), + def as_ty(self, dict_globals: dict[str, Any] | None = None) -> TensorType: + return DTensorType( + self.tensor_ty_proxy.as_ty(dict_globals), self.device_mesh, self.placement, ) diff --git a/python/tvm/relax/script/parser/entry.py b/python/tvm/relax/script/parser/entry.py index 5a14fe4ecd88..5c4611263bf0 100644 --- a/python/tvm/relax/script/parser/entry.py +++ b/python/tvm/relax/script/parser/entry.py @@ -22,16 +22,16 @@ import tvm from tvm.relax import ( Expr, - FuncStructInfo, Function, - ObjectStructInfo, - PrimStructInfo, + FuncType, + ObjectType, + PrimType, SeqExpr, ShapeExpr, - ShapeStructInfo, - StructInfo, - TensorStructInfo, - TupleStructInfo, + ShapeType, + TensorType, + TupleType, + Type, ) from tvm.relax.expr import Var from tvm.relax.script import builder as R @@ -145,22 +145,22 @@ def wrapper(*args, **kwargs): ############################# Struct Info ############################## -class StructInfoProxy(ObjectConvertible): - def as_struct_info(self, dict_globals: dict[str, Any] | None = None) -> StructInfo: +class TypeProxy(ObjectConvertible): + def as_ty(self, dict_globals: dict[str, Any] | None = None) -> Type: raise NotImplementedError() def get_symbolic_vars(self) -> set[str]: return {} def asobject(self): - return self.as_struct_info(None) + return self.as_ty(None) ############################### R.Object ################################ -class ObjectProxy(StructInfoProxy): - """The proxy fo ObjectStructInfo. +class ObjectProxy(TypeProxy): + """The proxy fo ObjectType. Parameters ---------- @@ -177,8 +177,8 @@ def __init__(self) -> None: def get_symbolic_vars(self) -> set[str]: return set() - def as_struct_info(self, dict_globals: dict[str, Any] | None = None) -> ShapeStructInfo: - return ObjectStructInfo() + def as_ty(self, dict_globals: dict[str, Any] | None = None) -> ShapeType: + return ObjectType() def Object() -> ObjectProxy: @@ -196,7 +196,7 @@ def _eval_shape(expr: str | PrimExpr, dict_globals: dict[str, Any] | None) -> Pr return expr -class TensorProxy(StructInfoProxy): +class TensorProxy(TypeProxy): shape: list[str | PrimExpr] | None dtype: str vdevice: str | None @@ -215,10 +215,10 @@ def __init__( "When the shape is an Expr, it must be a ShapeExpr or a Var with ShapeExpr " f"value. But got: {shape} with type: {type(shape)}" ) - if isinstance(shape, Var) and not isinstance(shape.struct_info, ShapeStructInfo): + if isinstance(shape, Var) and not isinstance(shape.ty, ShapeType): raise ValueError( - "When the shape is a Var, it must have shape struct_info. But got " - f"{shape} with struct_info: {shape.struct_info}" + "When the shape is a Var, it must have shape ty. But got " + f"{shape} with ty: {shape.ty}" ) self.shape = shape self.dtype = dtype @@ -231,7 +231,7 @@ def get_symbolic_vars(self) -> set[str]: else: return {s for s in self.shape if isinstance(s, str) and s.isidentifier()} - def as_struct_info(self, dict_globals: dict[str, Any] | None = None) -> TensorStructInfo: + def as_ty(self, dict_globals: dict[str, Any] | None = None) -> TensorType: vdev = self.vdevice if isinstance(self.vdevice, str): if ":" in self.vdevice: @@ -241,9 +241,9 @@ def as_struct_info(self, dict_globals: dict[str, Any] | None = None) -> TensorSt vdev = lookup_vdevice(self.vdevice, 0) if self.shape is None: - return TensorStructInfo(None, self.dtype, vdev, self.ndim) + return TensorType(None, self.dtype, vdev, self.ndim) elif isinstance(self.shape, ShapeExpr | Var): - return TensorStructInfo(self.shape, self.dtype, vdev, self.ndim) + return TensorType(self.shape, self.dtype, vdev, self.ndim) else: if dict_globals is None and any([isinstance(s, str) for s in self.shape]): raise ValueError( @@ -251,7 +251,7 @@ def as_struct_info(self, dict_globals: dict[str, Any] | None = None) -> TensorSt "and return annotations for TVMScript." ) shape = [_eval_shape(s, dict_globals) for s in self.shape] - return TensorStructInfo(shape, self.dtype, vdev, self.ndim) + return TensorType(shape, self.dtype, vdev, self.ndim) def Tensor( @@ -275,9 +275,9 @@ def Tensor( ############################## R.Callable ############################## -class CallableProxy(StructInfoProxy): - params: list[StructInfoProxy] - ret: StructInfoProxy +class CallableProxy(TypeProxy): + params: list[TypeProxy] + ret: TypeProxy purity: bool derive_func: str | tvm.ir.EnvFunc | None @@ -290,28 +290,28 @@ class CallableProxy(StructInfoProxy): Parameters ---------- - params : List[StructInfoProxy] - The argument StructInfoProxy + params : List[TypeProxy] + The argument TypeProxy - ret : StructInfoProxy - The return StructInfoProxy. + ret : TypeProxy + The return TypeProxy. purity : bool Whether the callable is pure. derive_func: Optional[Union[str, tvm.ir.EnvFunc]] - The derivation function to determine the output StructInfo, + The derivation function to determine the output Type, based on the arguments provided to the function. The specified function should be accessible using `tvm.get_global_func`, and should have a signature - `Callable[[relax.Call, relax.BlockBuilder], relax.StructInfo]`. + `Callable[[relax.Call, relax.BlockBuilder], relax.Type]`. """ def __init__( self, - params: StructInfoProxy | list[StructInfoProxy] | None = None, - ret: StructInfoProxy | None = None, + params: TypeProxy | list[TypeProxy] | None = None, + ret: TypeProxy | None = None, purity: bool | None = None, derive_func: str | tvm.ir.EnvFunc | None = None, ) -> None: @@ -339,28 +339,26 @@ def get_symbolic_vars(self) -> set[str]: else: return set().union(*[p.get_symbolic_vars() for p in self.params]) - def as_struct_info(self, dict_globals: dict[str, Any] | None = None) -> FuncStructInfo: + def as_ty(self, dict_globals: dict[str, Any] | None = None) -> FuncType: if self.ret is None: ret = None else: - ret = self.ret.as_struct_info(dict_globals) + ret = self.ret.as_ty(dict_globals) if self.params is None: params = None else: - params = [param.as_struct_info(dict_globals) for param in self.params] + params = [param.as_ty(dict_globals) for param in self.params] if params is None: - return FuncStructInfo.opaque_func( - ret=ret, derive_func=self.derive_func, purity=self.purity - ) + return FuncType.opaque_func(ret=ret, derive_func=self.derive_func, purity=self.purity) else: - return FuncStructInfo(params, ret, purity=self.purity) + return FuncType(params, ret, purity=self.purity) def Callable( - params: StructInfoProxy | list[StructInfoProxy] | None = None, - ret: StructInfoProxy | None = None, + params: TypeProxy | list[TypeProxy] | None = None, + ret: TypeProxy | None = None, purity: bool | None = None, derive_func: str | tvm.ir.EnvFunc | None = None, ) -> CallableProxy: @@ -370,19 +368,19 @@ def Callable( ############################### R.Tuple ################################ -class TupleProxy(StructInfoProxy): - fields: list[StructInfoProxy] +class TupleProxy(TypeProxy): + fields: list[TypeProxy] """The type of tuple values. Parameters ---------- - fields : List[StructInfoProxy] + fields : List[TypeProxy] The fields in the tuple """ def __init__( self, - *fields: list[StructInfoProxy], + *fields: list[TypeProxy], ) -> None: if len(fields) == 1 and isinstance(fields[0], tuple | list): fields = fields[0] @@ -392,19 +390,19 @@ def __init__( def get_symbolic_vars(self) -> set[str]: return set().union(*[f.get_symbolic_vars() for f in self.fields]) - def as_struct_info(self, dict_globals: dict[str, Any] | None = None) -> TupleStructInfo: - fields = [field.as_struct_info(dict_globals) for field in self.fields] - return TupleStructInfo(fields) + def as_ty(self, dict_globals: dict[str, Any] | None = None) -> TupleType: + fields = [field.as_ty(dict_globals) for field in self.fields] + return TupleType(fields) -def Tuple(*fields: list[StructInfoProxy]) -> TupleProxy: +def Tuple(*fields: list[TypeProxy]) -> TupleProxy: return TupleProxy(*fields) ############################### R.Shape ################################ -class ShapeProxy(StructInfoProxy): +class ShapeProxy(TypeProxy): values: list[PrimExpr] | None ndim: int """The type of shape values. @@ -432,9 +430,9 @@ def get_symbolic_vars(self) -> set[str]: else: return {v for v in self.values if isinstance(v, str) and v.isidentifier()} - def as_struct_info(self, dict_globals: dict[str, Any] | None = None) -> ShapeStructInfo: + def as_ty(self, dict_globals: dict[str, Any] | None = None) -> ShapeType: values = [_eval_shape(v, dict_globals) for v in self.values] if self.values else None - return ShapeStructInfo(values, self.ndim) + return ShapeType(values, self.ndim) def Shape(values: list[PrimExpr] | None = None, ndim: int = -1) -> ShapeProxy: @@ -444,7 +442,7 @@ def Shape(values: list[PrimExpr] | None = None, ndim: int = -1) -> ShapeProxy: ################################ R.Prim ################################ -class PrimProxy(StructInfoProxy): +class PrimProxy(TypeProxy): dtype: str | None value: int | float | str | PrimExpr | None @@ -478,12 +476,12 @@ def get_symbolic_vars(self) -> set[str]: else: return set() - def as_struct_info(self, dict_globals: dict[str, Any] | None = None) -> ShapeStructInfo: + def as_ty(self, dict_globals: dict[str, Any] | None = None) -> ShapeType: if self.value is None: - return PrimStructInfo(dtype=self.dtype) + return PrimType(dtype=self.dtype) else: value = _eval_shape(self.value, dict_globals) - return PrimStructInfo(dtype=self.dtype, value=value) + return PrimType(dtype=self.dtype, value=value) def Prim( @@ -496,37 +494,37 @@ def Prim( ############################ R.match_cast ############################# class MatchCastPair: value: Expr - struct_info: StructInfo + ty: Type - def __init__(self, value: Expr, struct_info: StructInfo) -> None: + def __init__(self, value: Expr, ty: Type) -> None: self.value = value - self.struct_info = struct_info + self.ty = ty -def match_cast(value: Expr, struct_info: StructInfo): - struct_info = _normalize_struct_info(struct_info) +def match_cast(value: Expr, ty: Type): + ty = _normalize_ty(ty) if value is None: raise ValueError("value of match_cast cannot be None") - if struct_info is None: - raise ValueError("struct_info of match_cast cannot be None") - return MatchCastPair(value, struct_info) + if ty is None: + raise ValueError("ty of match_cast cannot be None") + return MatchCastPair(value, ty) -def _normalize_struct_info_proxy(annotation) -> StructInfoProxy: +def _normalize_ty_proxy(annotation) -> TypeProxy: if annotation is None: return TupleProxy([]) elif callable(annotation): return annotation() - elif isinstance(annotation, StructInfoProxy): + elif isinstance(annotation, TypeProxy): return annotation else: - raise TypeError(f"Expected StructInfoProxy but got {type(annotation)}.") + raise TypeError(f"Expected TypeProxy but got {type(annotation)}.") -def _normalize_struct_info(struct_info, dict_globals: dict[str, Any] | None = None) -> StructInfo: - if isinstance(struct_info, StructInfo): - return struct_info +def _normalize_ty(ty, dict_globals: dict[str, Any] | None = None) -> Type: + if isinstance(ty, Type): + return ty else: - proxy = _normalize_struct_info_proxy(struct_info) - return proxy.as_struct_info(dict_globals) + proxy = _normalize_ty_proxy(ty) + return proxy.as_ty(dict_globals) diff --git a/python/tvm/relax/script/parser/parser.py b/python/tvm/relax/script/parser/parser.py index 256b867a1021..45012f0f59c8 100644 --- a/python/tvm/relax/script/parser/parser.py +++ b/python/tvm/relax/script/parser/parser.py @@ -24,7 +24,7 @@ from tvm import relax, tirx from tvm.ir import GlobalVar -from tvm.relax import Expr, StructInfo +from tvm.relax import Expr, Type from tvm.relax.script import builder as R from tvm.relax.script.builder.frame import BindingBlockFrame from tvm.relax.utils import convert_to_expr @@ -34,9 +34,9 @@ from .entry import ( MatchCastPair, - StructInfoProxy, - _normalize_struct_info, - _normalize_struct_info_proxy, + TypeProxy, + _normalize_ty, + _normalize_ty_proxy, ) @@ -45,7 +45,7 @@ def bind_assign_value( node: doc.expr, var_name: str, value: Any, - anno_sinfo: StructInfo | None = None, + anno_ty: Type | None = None, ) -> Any: var_table = self.var_table.get() @@ -87,13 +87,13 @@ def bind_assign_value( value = R.const(value) if isinstance(value, relax.Expr): - var = R.emit(value, anno_sinfo) + var = R.emit(value, anno_ty) elif isinstance(value, MatchCastPair): - if anno_sinfo is not None and not tvm_ffi.structural_equal(anno_sinfo, value.struct_info): + if anno_ty is not None and not tvm_ffi.structural_equal(anno_ty, value.ty): self.report_error( node, "Cannot specify inconsistent annotation for a match cast pair. " ) - var = R.emit_match_cast(value.value, value.struct_info) + var = R.emit_match_cast(value.value, value.ty) else: return value # raise TypeError(f"Unsupported type {type(value)} in assignment") @@ -102,20 +102,20 @@ def bind_assign_value( return var -def eval_struct_info_proxy(self: Parser, node: doc.expr) -> StructInfoProxy: +def eval_ty_proxy(self: Parser, node: doc.expr) -> TypeProxy: try: annotation = self.eval_expr(node) - return _normalize_struct_info_proxy(annotation) + return _normalize_ty_proxy(annotation) except Exception as err: # pylint: disable=broad-except self.report_error(node, err) raise -def eval_struct_info(self: Parser, node: doc.expr, eval_str: bool = False) -> StructInfo: +def eval_ty(self: Parser, node: doc.expr, eval_str: bool = False) -> Type: var_table = self.var_table.get() if eval_str else None try: - struct_info = self.eval_expr(node) - return _normalize_struct_info(struct_info, var_table) + ty = self.eval_expr(node) + return _normalize_ty(ty, var_table) except Exception as err: # pylint: disable=broad-except self.report_error(node, err) raise @@ -182,9 +182,9 @@ def collect_symbolic_var_from_params(self: Parser, node: doc.FunctionDef) -> Non for arg in node.args.args: if arg.annotation is None: self.report_error(arg, "Type annotation is required for function parameters.") - param_sinfo_proxy = eval_struct_info_proxy(self, arg.annotation) + param_ty_proxy = eval_ty_proxy(self, arg.annotation) - for var_name in param_sinfo_proxy.get_symbolic_vars(): + for var_name in param_ty_proxy.get_symbolic_vars(): if var_name not in symbolic_vars: symbolic_vars[var_name] = tirx.Var(var_name, "int64") @@ -206,17 +206,17 @@ def visit_function_def(self: Parser, node: doc.FunctionDef) -> None: if not func_val and is_recursive(node): collect_symbolic_var_from_params(self, node) if node.returns is None: - ret_sinfo = relax.TupleStructInfo([]) + ret_ty = relax.TupleType([]) else: - ret_sinfo = eval_struct_info(self, node.returns, eval_str=True) - params_sinfo = [] + ret_ty = eval_ty(self, node.returns, eval_str=True) + params_ty = [] for arg in node.args.args: if arg.annotation is None: self.report_error(arg, "Type annotation is required for function parameters.") - param_sinfo = eval_struct_info(self, arg.annotation, eval_str=True) - params_sinfo.append(param_sinfo) + param_ty = eval_ty(self, arg.annotation, eval_str=True) + params_ty.append(param_ty) # created a var for the local function, the same var could be used for recursive call - local_func_var = relax.Var(node.name, relax.FuncStructInfo(params_sinfo, ret_sinfo)) + local_func_var = relax.Var(node.name, relax.FuncType(params_ty, ret_ty)) self.var_table.add(node.name, local_func_var) purity = find_decorator_annotation(node, "pure") @@ -231,8 +231,8 @@ def visit_function_def(self: Parser, node: doc.FunctionDef) -> None: collect_symbolic_var_from_params(self, node) if node.returns is not None: - ann_sinfo = eval_struct_info(self, node.returns, eval_str=True) - R.func_ret_struct_info(ann_sinfo) + ann_ty = eval_ty(self, node.returns, eval_str=True) + R.func_ret_ty(ann_ty) self.visit(node.args) @@ -270,21 +270,21 @@ def visit_tvm_declare_function(self: Parser, node: doc.FunctionDef) -> GlobalVar collect_symbolic_var_from_params(self, node) if node.returns is None: - # Use ObjectStructInfo as unknown return type - # NOTE: Cannot use VoidStructInfo here because the return type can be refined later. - ret_sinfo = relax.ObjectStructInfo() + # Use ObjectType as unknown return type + # NOTE: Cannot use VoidType here because the return type can be refined later. + ret_ty = relax.ObjectType() else: - ret_sinfo = eval_struct_info(self, node.returns, eval_str=True) + ret_ty = eval_ty(self, node.returns, eval_str=True) params = [] for arg in node.args.args: if arg.annotation is None: self.report_error(arg, "Type annotation is required for function parameters.") - param_sinfo = eval_struct_info(self, arg.annotation, eval_str=True) - params.append(relax.Var(arg.arg, param_sinfo)) + param_ty = eval_ty(self, arg.annotation, eval_str=True) + params.append(relax.Var(arg.arg, param_ty)) is_pure = find_decorator_annotation(node, "pure") - func_signature = relax.Function.create_empty(params, ret_sinfo, is_pure=is_pure) + func_signature = relax.Function.create_empty(params, ret_ty, is_pure=is_pure) return I.decl_function(node.name, func_signature) @@ -315,15 +315,13 @@ def visit_expr_stmt(self: Parser, node: doc.Expr) -> None: if isinstance(value, relax.Expr): var = R.emit(value) IRBuilder.name("_", var) - is_void_value = ( - isinstance(var.struct_info, relax.TupleStructInfo) and len(var.struct_info.fields) == 0 - ) + is_void_value = isinstance(var.ty, relax.TupleType) and len(var.ty.fields) == 0 if not is_void_value: self.report_error( node, f"Non-void relax expressions must be bound to a variable, " - f"but expression of type {var.struct_info} was used as a statement.", + f"but expression of type {var.ty} was used as a statement.", ) elif value is not None: @@ -336,15 +334,15 @@ def visit_arguments(self: Parser, node: doc.arguments) -> None: for arg in node.args: if arg.annotation is None: self.report_error(arg, "Type annotation is required for function parameters.") - param_sinfo = eval_struct_info(self, arg.annotation, eval_str=True) - param = R.arg(arg.arg, param_sinfo) + param_ty = eval_ty(self, arg.annotation, eval_str=True) + param = R.arg(arg.arg, param_ty) self.var_table.add(arg.arg, param) @dispatch.register(token="relax", type_name="tvm_annotation") -def visit_tvm_annotation(self: Parser, node: doc.expr) -> StructInfo: - return eval_struct_info(self, node, eval_str=False) +def visit_tvm_annotation(self: Parser, node: doc.expr) -> Type: + return eval_ty(self, node, eval_str=False) @dispatch.register(token="relax", type_name="With") @@ -386,11 +384,11 @@ def visit_assign(self: Parser, node: doc.Assign) -> None: def visit_ann_assign(self: Parser, node: doc.AnnAssign) -> None: lhs = node.target rhs = self.eval_expr(node.value) - anno_sinfo = self.visit_tvm_annotation(node.annotation) + anno_ty = self.visit_tvm_annotation(node.annotation) self.eval_assign( target=lhs, source=rhs, - bind_value=functools.partial(bind_assign_value, anno_sinfo=anno_sinfo), + bind_value=functools.partial(bind_assign_value, anno_ty=anno_ty), allow_shadowing=True, ) diff --git a/python/tvm/relax/testing/ast_printer.py b/python/tvm/relax/testing/ast_printer.py index 4aa930ba6e1e..eb20d8ab5d8f 100644 --- a/python/tvm/relax/testing/ast_printer.py +++ b/python/tvm/relax/testing/ast_printer.py @@ -46,11 +46,11 @@ class ASTPrinter(ExprFunctor): def __init__( self, indent_str=" ", - include_struct_info_annotations=True, + include_ty_annotations=True, include_call_attrs=True, ): self.indent_str = indent_str - self.include_struct_info_annotations = include_struct_info_annotations + self.include_ty_annotations = include_ty_annotations self.include_call_attrs = include_call_attrs def visit_expr(self, expr: relax.Expr) -> str: @@ -88,11 +88,11 @@ def build_ast_node(self, nodename: str, force_newline=False, **kwargs: str) -> s def build_expr(self, node: relax.Expr, nodename: str, force_newline=False, **kwargs: str): """ Renders a Relax expression as a string using `build_ast_node`. - Handles whether to include the struct_info fields. + Handles whether to include the ty fields. """ fields = kwargs.copy() - if node.struct_info_ and self.include_struct_info_annotations: - fields["struct_info"] = self.visit_struct_info_(node.struct_info) + if node.ty and self.include_ty_annotations: + fields["ty"] = self.visit_ty_(node.ty) return self.build_ast_node(nodename, force_newline=force_newline, **fields) def build_list( @@ -134,7 +134,7 @@ def visit_shape_expr_(self, op: relax.ShapeExpr) -> str: def visit_extern_func_(self, op: relax.ExternFunc) -> str: # ExternFunc does not inherit from relax.Expr either, - # so it doesn't have struct_info fields and we don't use build_expr + # so it doesn't have ty fields and we don't use build_expr return self.build_ast_node("ExternFunc", global_symbol=wrap_quotes(op.global_symbol)) def visit_global_var_(self, op: relax.GlobalVar) -> str: @@ -144,7 +144,7 @@ def visit_function_(self, op: relax.Function) -> str: fields = { "params": self.build_list(map(self.visit_expr, op.params)), "body": self.visit_expr(op.body), - "ret_struct_info": self.visit_struct_info_(op.ret_struct_info), + "ret_ty": self.visit_ty_(op.ret_ty), "is_pure": op.is_pure, } if op.attrs: @@ -163,8 +163,8 @@ def visit_call_(self, op: relax.Call) -> str: "op": self.visit_expr(op.op), "args": self.build_list(map(self.visit_expr, op.args)), } - if op.sinfo_args: - fields["sinfo_args"] = self.build_list(map(self.visit_struct_info_, op.sinfo_args)) + if op.ty_args: + fields["ty_args"] = self.build_list(map(self.visit_ty_, op.ty_args)) if op.attrs and self.include_call_attrs: def display_attrs(attr_key): @@ -219,7 +219,7 @@ def visit_data_type_imm_(self, op: relax.DataTypeImm) -> str: def visit_op_(self, op: tvm.ir.Op) -> str: # TODO: List other attributes? # op is not actually a Relax expr and does not have - # struct_info fields, so we don't use build_expr here + # ty fields, so we don't use build_expr here return self.build_ast_node("Op", name=wrap_quotes(op.name)) def visit_prim_expr_(self, prim_expr: PrimExpr) -> str: @@ -258,55 +258,53 @@ def visit_type_(self, type_node: relax.Type) -> str: "TupleType", fields=self.build_list(map(self.visit_type_, type_node.fields)) ) if isinstance(type_node, relax.FuncType): + fields = {} + if type_node.params is not None: + fields["params"] = self.build_list(map(self.visit_type_, type_node.params)) + fields["ret"] = self.visit_type_(type_node.ret) + fields["purity"] = bool(type_node.purity) return self.build_ast_node( "FuncType", - arg_types=self.build_list(map(self.visit_type_, type_node.arg_types)), - ret_type=self.visit_type_(type_node.ret_type), + **fields, ) raise ValueError(f"Invalid Relax Type {type_node} ({type(type_node)})") - def visit_struct_info_(self, struct_info_node: relax.StructInfo) -> str: + def visit_ty_(self, ty_node: relax.Type) -> str: """ - Recurse down struct info and print their ASTs too + Recurse down type and print their ASTs too """ - if isinstance(struct_info_node, relax.ShapeStructInfo): + if isinstance(ty_node, relax.ShapeType): fields = {} - fields["ndim"] = str(struct_info_node.ndim) - if struct_info_node.values is not None: - fields["values"] = self.build_list( - map(self.visit_prim_expr_, struct_info_node.values) - ) - return self.build_ast_node("ShapeStructInfo", **fields) - elif isinstance(struct_info_node, relax.ObjectStructInfo): - return self.build_ast_node("ObjectStructInfo") - elif isinstance(struct_info_node, relax.PrimStructInfo): - return self.build_ast_node("PrimStructInfo", dtype=struct_info_node.dtype) - elif isinstance(struct_info_node, relax.TensorStructInfo): + fields["ndim"] = str(ty_node.ndim) + if ty_node.values is not None: + fields["values"] = self.build_list(map(self.visit_prim_expr_, ty_node.values)) + return self.build_ast_node("ShapeType", **fields) + elif isinstance(ty_node, relax.ObjectType): + return self.build_ast_node("ObjectType") + elif isinstance(ty_node, relax.PrimType): + return self.build_ast_node("PrimType", dtype=ty_node.dtype) + elif isinstance(ty_node, relax.TensorType): fields = {} - fields["dtype"] = struct_info_node.dtype - if struct_info_node.shape: - fields["shape"] = self.visit_expr(struct_info_node.shape) + fields["dtype"] = ty_node.dtype + if ty_node.shape: + fields["shape"] = self.visit_expr(ty_node.shape) else: - fields["ndim"] = str(struct_info_node.ndim) - return self.build_ast_node("TensorStructInfo", **fields) - elif isinstance(struct_info_node, relax.TupleStructInfo): + fields["ndim"] = str(ty_node.ndim) + return self.build_ast_node("TensorType", **fields) + elif isinstance(ty_node, relax.TupleType): return self.build_ast_node( - "TupleStructInfo", - fields=self.build_list(map(self.visit_struct_info_, struct_info_node.fields)), + "TupleType", + fields=self.build_list(map(self.visit_ty_, ty_node.fields)), ) - elif isinstance(struct_info_node, relax.FuncStructInfo): + elif isinstance(ty_node, relax.FuncType): fields = {} - if struct_info_node.params is not None: - fields["params"] = self.build_list( - map(self.visit_struct_info_, struct_info_node.params) - ) - fields["ret"] = self.visit_struct_info_(struct_info_node.ret) - fields["purity"] = bool(struct_info_node.purity) - return self.build_ast_node("FuncStructInfo", **fields) + if ty_node.params is not None: + fields["params"] = self.build_list(map(self.visit_ty_, ty_node.params)) + fields["ret"] = self.visit_ty_(ty_node.ret) + fields["purity"] = bool(ty_node.purity) + return self.build_ast_node("FuncType", **fields) else: - raise ValueError( - f"Invalid Relax StructInfo {struct_info_node} ({type(struct_info_node)})" - ) + raise ValueError(f"Invalid Relax Type {ty_node} ({type(ty_node)})") def visit_binding_block_(self, block: relax.BindingBlock) -> str: """ @@ -343,7 +341,7 @@ def visit_match_cast_(self, match_cast: relax.MatchCast) -> str: fields = { "var": self.visit_expr(match_cast.var), "value": self.visit_expr(match_cast.value), - "struct_info": self.visit_struct_info_(match_cast.struct_info), + "ty": self.visit_ty_(match_cast.ty), } return self.build_ast_node("MatchCast", **fields) @@ -361,7 +359,7 @@ def visit_var_binding_(self, var_binding: relax.VarBinding) -> str: def dump_ast( exp: relax.Expr, indent_str=" ", - include_struct_info_annotations=True, + include_ty_annotations=True, include_call_attrs=True, ) -> str: """ @@ -371,7 +369,7 @@ def dump_ast( """ printer = ASTPrinter( indent_str=indent_str, - include_struct_info_annotations=include_struct_info_annotations, + include_ty_annotations=include_ty_annotations, include_call_attrs=include_call_attrs, ) return printer.visit_expr(exp) diff --git a/python/tvm/relax/testing/attention.py b/python/tvm/relax/testing/attention.py index a4157cabb6e1..449902425dd0 100644 --- a/python/tvm/relax/testing/attention.py +++ b/python/tvm/relax/testing/attention.py @@ -126,7 +126,7 @@ def get_relax_stacked_attention_module( qkv, [split_axis], [split_sections[1]], - [int(qkv.struct_info.shape[split_axis])], + [int(qkv.ty.shape[split_axis])], [1], ) else: diff --git a/python/tvm/relax/testing/nn.py b/python/tvm/relax/testing/nn.py index 97a3b83eb77d..d2297f376663 100644 --- a/python/tvm/relax/testing/nn.py +++ b/python/tvm/relax/testing/nn.py @@ -196,7 +196,7 @@ class Placeholder(relax.Var): def __init__(self, shape: list[Any] | tuple[Any, ...], dtype="float32", name="data"): if not isinstance(shape, list | tuple): raise TypeError("the shape of Placeholder is expected to be a list or a tuple") - super().__init__(_try_unique_name(name), relax.TensorStructInfo(shape, dtype)) + super().__init__(_try_unique_name(name), relax.TensorType(shape, dtype)) class Parameter(relax.Var): @@ -206,7 +206,7 @@ def __init__(self, shape: list[Any] | tuple[Any, ...], dtype="float32", name="pa if not isinstance(shape, list | tuple): raise TypeError("the shape of Parameter is expected to be a list or a tuple") - super().__init__(_try_unique_name(name), relax.TensorStructInfo(shape, dtype)) + super().__init__(_try_unique_name(name), relax.TensorType(shape, dtype)) class Module(tvm.relax.frontend.nn.SubroutineMixin): @@ -279,7 +279,7 @@ def _unpack_params(value: object) -> list[relax.Var]: def init_params(mod: tvm.IRModule) -> list[tvm.runtime.Tensor]: """Utility function to initialize model's parameters.""" - shape_dict = {v.name_hint: v.struct_info.shape for v in mod["main"].params} + shape_dict = {v.name_hint: v.ty.shape for v in mod["main"].params} params = [] for k, v in shape_dict.items(): if k.startswith("data"): diff --git a/python/tvm/relax/training/loss.py b/python/tvm/relax/training/loss.py index ba9f7b5ed0b8..3a77500110a4 100644 --- a/python/tvm/relax/training/loss.py +++ b/python/tvm/relax/training/loss.py @@ -24,20 +24,20 @@ # isort: on from ..block_builder import BlockBuilder -from ..expr import Expr, Function, StructInfo, Var +from ..expr import Expr, Function, Type, Var from ..op import abs, argmax, mean, multiply, reshape, subtract, sum from ..op.nn import log_softmax, nll_loss -def _create_param_var(param: Var | StructInfo, param_name: str) -> Var: - """If param is a StructInfo, create a Var with the given StructInfo and name. +def _create_param_var(param: Var | Type, param_name: str) -> Var: + """If param is a Type, create a Var with the given Type and name. - If param is a Var, create a Var with the same StructInfo and name as the given param Var.""" - if isinstance(param, StructInfo): + If param is a Var, create a Var with the same Type and name as the given param Var.""" + if isinstance(param, Type): param = Var(param_name, param) if not isinstance(param, Var): - raise TypeError("The type of param should be Var or StructInfo, but got " + type(param)) - return Var(param.name_hint, param.struct_info) + raise TypeError("The type of param should be Var or Type, but got " + type(param)) + return Var(param.name_hint, param.ty) class Loss: @@ -138,17 +138,17 @@ def __init__(self, reduction: Literal["mean", "sum", "none"] = "mean") -> None: def __call__( self, - predictions: Var | StructInfo, - targets: Var | StructInfo, + predictions: Var | Type, + targets: Var | Type, ) -> Function: """Get the relax function of L1Loss. If the parameters are - struct info, it will create corresponding variables. + type, it will create corresponding variables. Parameters ---------- - predictions : Union[Var, StructInfo] + predictions : Union[Var, Type] The predictions of the model in the calculation of loss. - targets : Union[Var, StructInfo] + targets : Union[Var, Type] The ground truth in the calculation of loss. Returns @@ -187,17 +187,17 @@ def __init__(self, reduction: Literal["mean", "sum", "none"] = "mean") -> None: def __call__( self, - predictions: Var | StructInfo, - targets: Var | StructInfo, + predictions: Var | Type, + targets: Var | Type, ) -> Function: """Get the relax function of MSELoss. If the parameters are - struct info, it will create corresponding variables. + type, it will create corresponding variables. Parameters ---------- - predictions : Union[Var, StructInfo] + predictions : Union[Var, Type] The predictions of the model in the calculation of loss. - targets : Union[Var, StructInfo] + targets : Union[Var, Type] The ground truth in the calculation of loss. Returns @@ -247,22 +247,22 @@ def __init__( def __call__( self, - predictions: Var | StructInfo, - targets: Var | StructInfo, - weights: Var | StructInfo | None = None, + predictions: Var | Type, + targets: Var | Type, + weights: Var | Type | None = None, ) -> Function: """Get the relax function of CrossEntropyLoss. If the parameters are - struct info, it will create corresponding variables. + type, it will create corresponding variables. Parameters ---------- - predictions : Union[Var, StructInfo] + predictions : Union[Var, Type] The predictions of the model in the calculation of loss. - targets : Union[Var, StructInfo] + targets : Union[Var, Type] The ground truth in the calculation of loss. - weights : Optional[Union[Var, StructInfo]] + weights : Optional[Union[Var, Type]] a manual rescaling weight given to each class. It has to be a Tensor of size C. Returns @@ -320,22 +320,22 @@ def __init__( def __call__( self, - predictions: Var | StructInfo, - targets: Var | StructInfo, - weights: Var | StructInfo | None = None, + predictions: Var | Type, + targets: Var | Type, + weights: Var | Type | None = None, ) -> Function: """Get the relax function of CategoricalCrossEntropyLoss. If the parameters are - struct info, it will create corresponding variables. + type, it will create corresponding variables. Parameters ---------- - predictions : Union[Var, StructInfo] + predictions : Union[Var, Type] The predictions of the model in the calculation of loss. - targets : Union[Var, StructInfo] + targets : Union[Var, Type] The ground truth in the calculation of loss. - weights : Optional[Union[Var, StructInfo]] + weights : Optional[Union[Var, Type]] a manual rescaling weight given to each class. It has to be a Tensor of size C. Returns @@ -367,7 +367,7 @@ def __call__( logits = bb.emit(log_softmax(predictions)) if self.ignore_index >= 0: targets = bb.emit( - reshape(argmax(targets, axis=1), shape=(targets.struct_info.shape[0],)) + reshape(argmax(targets, axis=1), shape=(targets.ty.shape[0],)) ) loss = bb.emit_output( nll_loss(logits, targets, weights, self._reduction, self.ignore_index) diff --git a/python/tvm/relax/training/optimizer.py b/python/tvm/relax/training/optimizer.py index 654317568572..db612a4277cc 100644 --- a/python/tvm/relax/training/optimizer.py +++ b/python/tvm/relax/training/optimizer.py @@ -29,7 +29,7 @@ from ..expr import Function, TupleGetItem, Var, const from ..expr import Tuple as RxTuple from ..op import add, divide, multiply, sqrt, subtract -from ..struct_info import TensorStructInfo, TupleStructInfo +from ..type import TensorType, TupleType # TODO(chaofan, yixin): Migrate key logics to C++ @@ -144,24 +144,24 @@ def _set_params_and_dtype(self, params: list[Var]) -> None: for x in params: if not isinstance(x, Var): raise ValueError(f"Parameter {x} is not a Var") - if not isinstance(x.struct_info, TensorStructInfo): + if not isinstance(x.ty, TensorType): raise ValueError( f"Optimizers only support Tensor parameters, but parameter {x.name_hint} has " - f"struct info {x.struct_info}" + f"type {x.ty}" ) - data_type = tvm.DataType(x.struct_info.dtype) + data_type = tvm.DataType(x.ty.dtype) if data_type.type_code not in (tvm.DataTypeCode.BFLOAT, tvm.DataTypeCode.FLOAT): raise ValueError( f"Optimizers only support Tensor parameters of floating point dtype, but dtype " - f"of {x.name_hint} is {x.struct_info.dtype}" + f"of {x.name_hint} is {x.ty.dtype}" ) if dtype is None: - dtype = x.struct_info.dtype + dtype = x.ty.dtype else: - if dtype != x.struct_info.dtype: + if dtype != x.ty.dtype: raise ValueError( f"All parameters should have the same dtype, but parameter {x.name_hint} " - f"has dtype {x.struct_info.dtype}, which differs from the previous dtype " + f"has dtype {x.ty.dtype}, which differs from the previous dtype " f"{dtype}" ) if x in params_set: @@ -231,7 +231,7 @@ def SGD( # TODO(chaofan, yixin): Support symbolic shapes def _get_shape_as_int_list(var: Var) -> list[int]: - return [int(val) for val in var.struct_info.shape] + return [int(val) for val in var.ty.shape] # We need to subtract on hyperparameters, but do not want to introduce floating point error. @@ -317,9 +317,9 @@ def get_function(self) -> Function: dtype = self.dtype # input variables - param_var = Var("params", TupleStructInfo([p.struct_info for p in plist])) - grad_var = Var("gradients", TupleStructInfo([p.struct_info for p in plist])) - state_var = Var("optim_states", TupleStructInfo([TensorStructInfo((), "int64")])) + param_var = Var("params", TupleType([p.ty for p in plist])) + grad_var = Var("gradients", TupleType([p.ty for p in plist])) + state_var = Var("optim_states", TupleType([TensorType((), "int64")])) # constants lr = const(self.lr, dtype) @@ -443,7 +443,7 @@ def init(self, params: Var | list[Var]) -> "MomentumSGD": tvm.runtime.tensor(np.zeros((), "int64")), # v_{param} is initialized to all zeros *( - tvm.runtime.tensor(np.zeros(_get_shape_as_int_list(p), p.struct_info.dtype)) + tvm.runtime.tensor(np.zeros(_get_shape_as_int_list(p), p.ty.dtype)) for p in self.param_list ), ) @@ -464,11 +464,11 @@ def get_function(self) -> Function: dtype = self.dtype # input variables - param_var = Var("params", TupleStructInfo([p.struct_info for p in plist])) - grad_var = Var("gradients", TupleStructInfo([p.struct_info for p in plist])) + param_var = Var("params", TupleType([p.ty for p in plist])) + grad_var = Var("gradients", TupleType([p.ty for p in plist])) state_var = Var( "optim_states", - TupleStructInfo([TensorStructInfo((), "int64"), *(p.struct_info for p in plist)]), + TupleType([TensorType((), "int64"), *(p.ty for p in plist)]), ) # constants @@ -618,12 +618,12 @@ def init(self, params: Var | list[Var]) -> "Adam": tvm.runtime.tensor(np.ones((), self.dtype)), # first_momentum *( - tvm.runtime.tensor(np.zeros(_get_shape_as_int_list(p), p.struct_info.dtype)) + tvm.runtime.tensor(np.zeros(_get_shape_as_int_list(p), p.ty.dtype)) for p in self.param_list ), # second_momentum *( - tvm.runtime.tensor(np.zeros(_get_shape_as_int_list(p), p.struct_info.dtype)) + tvm.runtime.tensor(np.zeros(_get_shape_as_int_list(p), p.ty.dtype)) for p in self.param_list ), ) @@ -644,17 +644,17 @@ def get_function(self) -> Function: dtype = self.dtype # input variables - param_var = Var("params", TupleStructInfo([p.struct_info for p in plist])) - grad_var = Var("gradients", TupleStructInfo([p.struct_info for p in plist])) + param_var = Var("params", TupleType([p.ty for p in plist])) + grad_var = Var("gradients", TupleType([p.ty for p in plist])) state_var = Var( "optim_states", - TupleStructInfo( + TupleType( [ - TensorStructInfo((), "int64"), - TensorStructInfo((), dtype), - TensorStructInfo((), dtype), - *(p.struct_info for p in plist), - *(p.struct_info for p in plist), + TensorType((), "int64"), + TensorType((), dtype), + TensorType((), dtype), + *(p.ty for p in plist), + *(p.ty for p in plist), ] ), ) diff --git a/python/tvm/relax/training/setup_trainer.py b/python/tvm/relax/training/setup_trainer.py index 03ae52a61fbb..bcbea9640382 100644 --- a/python/tvm/relax/training/setup_trainer.py +++ b/python/tvm/relax/training/setup_trainer.py @@ -23,9 +23,9 @@ from ..analysis import check_well_formed from ..expr import Tuple -from ..struct_info import TensorStructInfo from ..training.utils import AppendLoss from ..transform import DecomposeOpsForInference, DecomposeOpsForTraining, Gradient, LegalizeOps +from ..type import TensorType from .loss import Loss from .optimizer import Optimizer @@ -103,7 +103,7 @@ def optimizer(params, gradients, optim_states): optimizer : Optimizer The optimizer. It will be put as the `optimizer` function of the transformed module. - loss_args : List[TensorStructInfo] + loss_args : List[TensorType] The arguments to call the loss function. legalize : bool @@ -119,7 +119,7 @@ def optimizer(params, gradients, optim_states): STATE_NUM_ATTR_KEY: str = "state_num" def __init__( - self, loss: Loss, optimizer: Optimizer, loss_args: list[TensorStructInfo], legalize=True + self, loss: Loss, optimizer: Optimizer, loss_args: list[TensorType], legalize=True ): self._loss = loss self._optimizer = optimizer diff --git a/python/tvm/relax/training/trainer.py b/python/tvm/relax/training/trainer.py index d7bdea3e1f5d..91fd724f9564 100644 --- a/python/tvm/relax/training/trainer.py +++ b/python/tvm/relax/training/trainer.py @@ -55,7 +55,7 @@ class Trainer: setup_trainer = SetupTrainer( MSELoss(reduction="sum"), SGD(0.001), - [pred_sinfo, target_sinfo], + [pred_ty, target_ty], ) train_mod = setup_trainer(Backbone) ex = tvm.compile(train_mod, target) @@ -116,7 +116,7 @@ def __init__( @staticmethod def _get_shape_list(expr): - return [int(dim) for dim in expr.struct_info.shape] + return [int(dim) for dim in expr.ty.shape] def xaiver_uniform_init_params(self): """Xaiver uniformly initialize parameters using the method described in `Understanding the @@ -127,7 +127,7 @@ def xaiver_uniform_init_params(self): """ self._params = [] for p in self._param_vars: - shape, dtype = self._get_shape_list(p), p.struct_info.dtype + shape, dtype = self._get_shape_list(p), p.ty.dtype self._params.append( tvm.runtime.tensor( (np.sqrt(6.0 / np.sum(shape)) * np.random.uniform(-1.0, 1.0, shape)).astype( @@ -140,14 +140,14 @@ def xaiver_uniform_init_params(self): def zero_init_params(self): """Zero initialize all parameters. Requires all parameters have static shapes.""" self._params = [ - tvm.runtime.tensor(np.zeros(self._get_shape_list(p), p.struct_info.dtype), self.device) + tvm.runtime.tensor(np.zeros(self._get_shape_list(p), p.ty.dtype), self.device) for p in self._param_vars ] def zero_init_states(self): """Zero initialize all states. Requires all states have static shapes.""" self._states = [ - tvm.runtime.tensor(np.zeros(self._get_shape_list(s), s.struct_info.dtype), self.device) + tvm.runtime.tensor(np.zeros(self._get_shape_list(s), s.ty.dtype), self.device) for s in self._state_vars ] diff --git a/python/tvm/relax/transform/__init__.py b/python/tvm/relax/transform/__init__.py index c3188adf5027..1ee65452e7d7 100644 --- a/python/tvm/relax/transform/__init__.py +++ b/python/tvm/relax/transform/__init__.py @@ -79,7 +79,7 @@ ToMixedPrecision, ToNonDataflow, TopologicalSort, - UpdateParamStructInfo, + UpdateParamType, UpdateVDevice, VMBuiltinLower, VMShapeLower, diff --git a/python/tvm/relax/transform/fold_batch_norm_to_conv2d_for_inference.py b/python/tvm/relax/transform/fold_batch_norm_to_conv2d_for_inference.py index 2aaebe527efb..18d75b66bf4d 100644 --- a/python/tvm/relax/transform/fold_batch_norm_to_conv2d_for_inference.py +++ b/python/tvm/relax/transform/fold_batch_norm_to_conv2d_for_inference.py @@ -87,13 +87,13 @@ def rewriter(expr, matches): wt = relax.op.divide(bn_weight, dino) bs = relax.op.subtract(bn_bias, relax.op.multiply(bn_mean, wt)) if conv_attrs["kernel_layout"] == "OIHW": - wt = relax.op.reshape(wt, shape=(bn_weight.struct_info.shape[0], 1, 1, 1)) + wt = relax.op.reshape(wt, shape=(bn_weight.ty.shape[0], 1, 1, 1)) elif conv_attrs["kernel_layout"] == "IOHW": - wt = wt.reshape(1, bn_weight.struct_info.shape[0], 1, 1) + wt = wt.reshape(1, bn_weight.ty.shape[0], 1, 1) else: return expr wt_conv = relax.op.multiply(conv_weight, wt) - bs_args = relax.op.reshape(bs, shape=(1, bn_bias.struct_info.shape[0], 1, 1)) + bs_args = relax.op.reshape(bs, shape=(1, bn_bias.ty.shape[0], 1, 1)) conv_out = relax.Call(conv_op.op, (conv_input, wt_conv), conv_attrs) return relax.op.add(conv_out, bs_args) diff --git a/python/tvm/relax/transform/fuse_transpose_matmul.py b/python/tvm/relax/transform/fuse_transpose_matmul.py index ecefa876120f..dfaf8331399e 100644 --- a/python/tvm/relax/transform/fuse_transpose_matmul.py +++ b/python/tvm/relax/transform/fuse_transpose_matmul.py @@ -64,7 +64,7 @@ def _pattern(): def _check(context: relax.transform.PatternCheckContext) -> bool: transpose_call = context.annotated_expr["wT"] - ndim = transpose_call.args[0].struct_info.ndim + ndim = transpose_call.args[0].ty.ndim if ndim == -1: return False if ndim == 2 and transpose_call.attrs.axes is None: @@ -106,13 +106,11 @@ def te_transposed_matmul(a: te.Tensor, b: te.Tensor) -> te.Tensor: is_a_larger = len(a_shape) > len(b_shape) offset = len(a_shape) - len(b_shape) if is_a_larger else len(b_shape) - len(a_shape) - a_relax = relax.Var("a", relax.TensorStructInfo(a.shape)) + a_relax = relax.Var("a", relax.TensorType(a.shape)) bT_shape = list(b.shape) bT_shape[-1], bT_shape[-2] = bT_shape[-2], bT_shape[-1] - bT_relax = relax.Var("b", relax.TensorStructInfo(bT_shape)) - output_shape = self.builder_.normalize( - relax.op.matmul(a_relax, bT_relax) - ).struct_info.shape + bT_relax = relax.Var("b", relax.TensorType(bT_shape)) + output_shape = self.builder_.normalize(relax.op.matmul(a_relax, bT_relax)).ty.shape def matmul_compute(*idx_spatial): k = te.reduce_axis((0, a_shape[-1]), name="k") @@ -165,7 +163,7 @@ def multiply_compute(idx_reduce): "Composite" in function.attrs and function.attrs["Composite"] == "transpose_matmul_fuse" ): - out_dtype = function.ret_struct_info.dtype + out_dtype = function.ret_ty.dtype return self.builder_.call_te( te_transposed_matmul, call.args[1], diff --git a/python/tvm/relax/transform/lazy_transform_params.py b/python/tvm/relax/transform/lazy_transform_params.py index b254519e110e..432426bf74ad 100644 --- a/python/tvm/relax/transform/lazy_transform_params.py +++ b/python/tvm/relax/transform/lazy_transform_params.py @@ -177,12 +177,12 @@ def transform(self, func: relax.Function) -> relax.Function: if leaf_outputs: new_bindings = [ relax.VarBinding( - relax.Var("_", relax.ObjectStructInfo()), + relax.Var("_", relax.ObjectType()), relax.Call( relax.ExternFunc(self.fset_item), [*self.extra_set_item_params, index, expr], None, - [relax.ObjectStructInfo()], + [relax.ObjectType()], ), ) for expr, indices in leaf_outputs.items() @@ -206,23 +206,23 @@ def transform(self, func: relax.Function) -> relax.Function: symbolic_vars = relax.analysis.defined_symbolic_vars(func) if symbolic_vars: - def unpack_sinfo(sinfo): - if isinstance(sinfo, relax.TupleStructInfo): - for field in sinfo.fields: - yield from unpack_sinfo(field) + def unpack_ty(ty): + if isinstance(ty, relax.TupleType): + for field in ty.fields: + yield from unpack_ty(field) else: - yield sinfo + yield ty - # direct iterate over the struct info annotation + # direct iterate over the type annotation for param in func.params[num_input:]: - for sinfo in unpack_sinfo(param.struct_info): - if isinstance(sinfo, relax.PrimStructInfo | relax.ShapeStructInfo): - params.append(relax.Var("symbolic_var_holder", sinfo)) + for ty in unpack_ty(param.ty): + if isinstance(ty, relax.PrimType | relax.ShapeType): + params.append(relax.Var("symbolic_var_holder", ty)) return relax.Function( params, new_body, - relax.ObjectStructInfo(), + relax.ObjectType(), attrs=func.attrs, is_pure=False, ).without_attr("relax.force_pure") @@ -241,7 +241,7 @@ def visit_function_(self, func: relax.Function) -> relax.Expr: num_input = 0 params = list(func.params)[num_input:] - if len(params) == 1 and isinstance(params[0].struct_info_, relax.TupleStructInfo): + if len(params) == 1 and isinstance(params[0].ty, relax.TupleType): self.tuple_param = params[0] self.params = {} else: @@ -250,7 +250,7 @@ def visit_function_(self, func: relax.Function) -> relax.Expr: func = relax.Function( func.params[:num_input], func.body, - func.ret_struct_info, + func.ret_ty, is_pure=False, attrs=func.attrs, span=func.span, @@ -268,10 +268,10 @@ def visit_var_(self, var: relax.Var) -> relax.Expr: relax.ExternFunc(self.func_creator.fget_item), self.func_creator.extra_get_item_params + [relax.PrimValue(index)], None, - [relax.ObjectStructInfo()], + [relax.ObjectType()], ) ) - match_cast = relax.MatchCast(var, get_item_result, var.struct_info) + match_cast = relax.MatchCast(var, get_item_result, var.ty) self.builder_.emit_normalized(match_cast) del self.params[var] @@ -279,7 +279,7 @@ def visit_var_(self, var: relax.Var) -> relax.Expr: return super().visit_var_(var) def visit_tuple_getitem_(self, node: relax.TupleGetItem) -> relax.Expr: - sinfo = node.struct_info + ty = node.ty node = super().visit_tuple_getitem_(node) @@ -289,10 +289,10 @@ def visit_tuple_getitem_(self, node: relax.TupleGetItem) -> relax.Expr: relax.ExternFunc(self.func_creator.fget_item), self.func_creator.extra_get_item_params + [relax.PrimValue(node.index)], None, - [relax.ObjectStructInfo()], + [relax.ObjectType()], ) ) - return self.builder_.match_cast(get_item_result, sinfo) + return self.builder_.match_cast(get_item_result, ty) else: return node @@ -329,7 +329,7 @@ def visit_var_binding_(self, binding: relax.VarBinding) -> None: self.func_creator.extra_set_item_params + [index, super().visit_var_(var)], None, - [relax.ObjectStructInfo()], + [relax.ObjectType()], ), name_hint="_", ) diff --git a/python/tvm/relax/transform/legalize_ops/adreno/convolution.py b/python/tvm/relax/transform/legalize_ops/adreno/convolution.py index 391e74a9086b..8e7270383827 100644 --- a/python/tvm/relax/transform/legalize_ops/adreno/convolution.py +++ b/python/tvm/relax/transform/legalize_ops/adreno/convolution.py @@ -31,6 +31,6 @@ def conv2d_NCHWc_OIHWo(bb: relax.BlockBuilder, call: relax.Call) -> relax.Expr: layout=call.attrs.data_layout, out_layout=call.attrs.out_layout, # out_dtype=call.attrs.out_dtype, - sinfo_args=call.sinfo_args, + ty_args=call.ty_args, primfunc_name_hint="conv2d_NCHWc_OIHWo", ) diff --git a/python/tvm/relax/transform/legalize_ops/ccl.py b/python/tvm/relax/transform/legalize_ops/ccl.py index 104d155bf1c4..5976943090ca 100644 --- a/python/tvm/relax/transform/legalize_ops/ccl.py +++ b/python/tvm/relax/transform/legalize_ops/ccl.py @@ -23,7 +23,7 @@ from ...block_builder import BlockBuilder from ...expr import Call, Expr, ShapeExpr from ...op import call_dps_packed -from ...struct_info import ShapeStructInfo, TensorStructInfo +from ...type import ShapeType, TensorType from .common import register_legalize @@ -45,19 +45,17 @@ def _allreduce(_bb: BlockBuilder, call: Call) -> Expr: return call_dps_packed( "runtime.disco.allreduce", [call.args[0], ShapeExpr([op_type_map[op_type_str]]), call.attrs.in_group], - out_sinfo=call.args[0].struct_info, + out_ty=call.args[0].ty, ) @register_legalize("relax.ccl.allgather") def _allgather(_bb: BlockBuilder, call: Call) -> Expr: output_shape = [] - arg_sinfo = call.args[0].struct_info - assert isinstance(arg_sinfo, TensorStructInfo), ( - "The input struct info of allgather should be TensorStructInfo." - ) - assert isinstance(arg_sinfo.shape.struct_info, ShapeStructInfo) - arg_shape = arg_sinfo.shape.struct_info + arg_ty = call.args[0].ty + assert isinstance(arg_ty, TensorType), "The input type of allgather should be TensorType." + assert isinstance(arg_ty.shape.ty, ShapeType) + arg_shape = arg_ty.shape.ty for i, shape_value in enumerate(arg_shape.values): if i == 0: output_shape.append(shape_value * call.attrs.num_workers) @@ -66,10 +64,10 @@ def _allgather(_bb: BlockBuilder, call: Call) -> Expr: return call_dps_packed( "runtime.disco.allgather", [call.args[0], call.attrs.in_group], - out_sinfo=TensorStructInfo( + out_ty=TensorType( shape=output_shape, - dtype=arg_sinfo.dtype, - vdevice=arg_sinfo.vdevice, + dtype=arg_ty.dtype, + vdevice=arg_ty.vdevice, ), ) @@ -79,18 +77,16 @@ def _broadcast_from_worker0(_bb: BlockBuilder, call: Call) -> Expr: return call_dps_packed( "runtime.disco.broadcast_from_worker0", [call.args[0], False], - out_sinfo=call.args[0].struct_info, + out_ty=call.args[0].ty, ) # Since collective communication ops are performed on contiguous memory, # we need to reshape and transpose the input tensor to make sharding dimension in the highest order def _transpose_for_ccl(_bb: BlockBuilder, expr: Expr, axis: int, num_workers: int): - assert isinstance(expr.struct_info, TensorStructInfo), ( - "The input struct info should be TensorStructInfo." - ) - assert isinstance(expr.struct_info.shape.struct_info, ShapeStructInfo) - arg_shape = expr.struct_info.shape.struct_info + assert isinstance(expr.ty, TensorType), "The input type should be TensorType." + assert isinstance(expr.ty.shape.ty, ShapeType) + arg_shape = expr.ty.shape.ty new_shape = [] for i, shape_value in enumerate(arg_shape.values): if i == axis: @@ -115,14 +111,14 @@ def _transpose_for_ccl(_bb: BlockBuilder, expr: Expr, axis: int, num_workers: in @register_legalize("relax.ccl.scatter_from_worker0") def _scatter_from_worker0(_bb: BlockBuilder, call: Call) -> Expr: transpose_var = _transpose_for_ccl(_bb, call.args[0], call.attrs.axis, call.attrs.num_workers) - output_shape = transpose_var.struct_info.shape.struct_info.values + output_shape = transpose_var.ty.shape.ty.values output_shape = output_shape[1:] return call_dps_packed( "runtime.disco.scatter_from_worker0", [transpose_var, False], - out_sinfo=TensorStructInfo( + out_ty=TensorType( shape=output_shape, - dtype=call.args[0].struct_info.dtype, - vdevice=call.args[0].struct_info.vdevice, + dtype=call.args[0].ty.dtype, + vdevice=call.args[0].ty.vdevice, ), ) diff --git a/python/tvm/relax/transform/legalize_ops/common.py b/python/tvm/relax/transform/legalize_ops/common.py index 6cb50d70ab74..1b7d1179a521 100644 --- a/python/tvm/relax/transform/legalize_ops/common.py +++ b/python/tvm/relax/transform/legalize_ops/common.py @@ -65,10 +65,10 @@ def _try_convert_to_scalar_const( if the python native flag is True. Or return the input itself if it is not a scalar constant. """ - if isinstance(expr, Constant) and expr.struct_info.ndim == 0: + if isinstance(expr, Constant) and expr.ty.ndim == 0: # get the value of the scalar constant value = expr.data.numpy()[()].item() - dtype = expr.struct_info.dtype + dtype = expr.ty.dtype if python_native: return value # preserve the data type of the constant diff --git a/python/tvm/relax/transform/legalize_ops/create.py b/python/tvm/relax/transform/legalize_ops/create.py index 6708859caaff..3cab08cf26e3 100644 --- a/python/tvm/relax/transform/legalize_ops/create.py +++ b/python/tvm/relax/transform/legalize_ops/create.py @@ -24,7 +24,7 @@ from ...block_builder import BlockBuilder from ...expr import Call, Expr, PrimValue, ShapeExpr, const -from ...struct_info import ShapeStructInfo +from ...type import ShapeType from .common import LegalizeFunc, _try_convert_to_scalar_const, register_legalize @@ -35,22 +35,22 @@ def full_call_te(bb: BlockBuilder, call: Call) -> Expr: if fill_value is None else fill_value ) - shape = call.args[0].struct_info.shape if is_like else call.args[0] + shape = call.args[0].ty.shape if is_like else call.args[0] if isinstance(shape, ShapeExpr): output_shape = shape.values else: - assert isinstance(shape.struct_info, ShapeStructInfo) - assert shape.struct_info.ndim >= 0 + assert isinstance(shape.ty, ShapeType) + assert shape.ty.ndim >= 0 shape = bb.emit(shape) - output_shape = [tirx.Var(f"s{i}", "int64") for i in range(shape.struct_info.ndim)] - bb.match_cast(shape, ShapeStructInfo(output_shape)) + output_shape = [tirx.Var(f"s{i}", "int64") for i in range(shape.ty.ndim)] + bb.match_cast(shape, ShapeType(output_shape)) return bb.call_te( topi.full, output_shape, - call.struct_info.dtype, + call.ty.dtype, _fill_value, primfunc_name_hint=primfunc_name, ) @@ -88,8 +88,8 @@ def eye_call_te(bb: BlockBuilder, call: Call) -> Expr: if is_like: x = call.args[0] k = _convert_to_scalar_const(call.args[1]) if len(call.args) > 1 else 0 - n, m = x.struct_info.shape - dtype = x.struct_info.dtype + n, m = x.ty.shape + dtype = x.ty.dtype else: n = _convert_to_scalar_const(call.args[0]) m = _convert_to_scalar_const(call.args[1]) if len(call.args) > 1 else n diff --git a/python/tvm/relax/transform/legalize_ops/distributed.py b/python/tvm/relax/transform/legalize_ops/distributed.py index acd6bd4a4514..c20dc09a70d4 100644 --- a/python/tvm/relax/transform/legalize_ops/distributed.py +++ b/python/tvm/relax/transform/legalize_ops/distributed.py @@ -22,7 +22,7 @@ from ...block_builder import BlockBuilder from ...expr import Call, Expr from ...op import call_pure_packed -from ...struct_info import ShapeStructInfo +from ...type import ShapeType from .common import register_legalize @@ -31,12 +31,10 @@ def _redistribute_replica_to_shard(_bb: BlockBuilder, call: Call) -> Expr: num_workers = call.attrs.num_workers axis = call.attrs.axis worker_id_symbol = tirx.Var("worker_id", "int64") - worker_id_var = _bb.emit( - call_pure_packed("runtime.disco.worker_id", sinfo_args=[ShapeStructInfo(None)]) - ) - _bb.match_cast(worker_id_var, ShapeStructInfo([worker_id_symbol])) + worker_id_var = _bb.emit(call_pure_packed("runtime.disco.worker_id", ty_args=[ShapeType(None)])) + _bb.match_cast(worker_id_var, ShapeType([worker_id_symbol])) - split_axis_size = call.args[0].struct_info.shape[axis] + split_axis_size = call.args[0].ty.shape[axis] return relax.op.strided_slice( call.args[0], axes=[axis], diff --git a/python/tvm/relax/transform/legalize_ops/index.py b/python/tvm/relax/transform/legalize_ops/index.py index 9bc47bd676e5..b71d8958e0dd 100644 --- a/python/tvm/relax/transform/legalize_ops/index.py +++ b/python/tvm/relax/transform/legalize_ops/index.py @@ -22,7 +22,7 @@ from ...block_builder import BlockBuilder from ...expr import Call, Expr from ...op import tensor_to_shape -from ...struct_info import PrimStructInfo, ShapeStructInfo +from ...type import PrimType, ShapeType from .common import register_legalize @@ -37,15 +37,15 @@ def _take(bb: BlockBuilder, call: Call) -> Expr: def _strided_slice(bb: BlockBuilder, call: Call) -> Expr: def _relax_tuple_to_tir(relax_tuple): output = [] - for field in relax_tuple.struct_info.fields: - assert isinstance(field, PrimStructInfo) + for field in relax_tuple.ty.fields: + assert isinstance(field, PrimType) assert field.value is not None output.append(field.value) return output if len(call.args) == 4: data, axes, begin, end = call.args - strides = [tirx.IntImm("int64", 1)] * len(axes.struct_info.fields) + strides = [tirx.IntImm("int64", 1)] * len(axes.ty.fields) elif len(call.args) == 5: data, axes, begin, end, strides = call.args strides = _relax_tuple_to_tir(strides) @@ -113,10 +113,10 @@ def get_length(begin, end, strides, length): ) # 2. Convert tensor to shape and match cast with new symbolic vars - ndim = int(output_shape.struct_info.shape[0]) + ndim = int(output_shape.ty.shape[0]) output_shape = bb.emit(tensor_to_shape(output_shape)) output_shape_vars = [tirx.Var("s", "int64") for i in range(ndim)] - bb.match_cast(output_shape, ShapeStructInfo(output_shape_vars)) + bb.match_cast(output_shape, ShapeType(output_shape_vars)) # 3. Pass the output shape vars to TOPI return bb.call_te( diff --git a/python/tvm/relax/transform/legalize_ops/linear_algebra.py b/python/tvm/relax/transform/legalize_ops/linear_algebra.py index 2b4d1efd108f..00179964ad0d 100644 --- a/python/tvm/relax/transform/legalize_ops/linear_algebra.py +++ b/python/tvm/relax/transform/legalize_ops/linear_algebra.py @@ -41,14 +41,14 @@ def te_matmul(a: te.Tensor, b: te.Tensor) -> te.Tensor: is_a_larger = len(a_shape) > len(b_shape) offset = len(a_shape) - len(b_shape) if is_a_larger else len(b_shape) - len(a_shape) - a_relax = relax.Var("a", relax.TensorStructInfo(a.shape)) - b_relax = relax.Var("b", relax.TensorStructInfo(b.shape)) - f_infer_sinfo = call.op.get_attr("FInferStructInfo") - output_shape = f_infer_sinfo(relax.op.matmul(a_relax, b_relax), bb).shape + a_relax = relax.Var("a", relax.TensorType(a.shape)) + b_relax = relax.Var("b", relax.TensorType(b.shape)) + f_infer_ty = call.op.get_attr("FInferType") + output_shape = f_infer_ty(relax.op.matmul(a_relax, b_relax), bb).shape if isinstance(a_shape[-1], tirx.IntImm) and a_shape[-1] == 0: return te.compute( output_shape, - lambda *_: tirx.const(0, call.struct_info.dtype), + lambda *_: tirx.const(0, call.ty.dtype), name="matmul", ) @@ -98,12 +98,12 @@ def multiply_compute(idx_reduce): ) lhs, rhs = call.args - lhs_sinfo = call.args[0].struct_info - rhs_sinfo = call.args[1].struct_info - assert lhs_sinfo.dtype and rhs_sinfo.dtype, ( + lhs_ty = call.args[0].ty + rhs_ty = call.args[1].ty + assert lhs_ty.dtype and rhs_ty.dtype, ( f"To legalize R.matmul into R.call_tir, the dtype of both operands must be known. " - f"However, the LHS {lhs} has struct info {lhs_sinfo} (dtype='{lhs_sinfo.dtype}') " - f"and the RHS {rhs} has struct info {rhs_sinfo} (dtype='{rhs_sinfo.dtype}')." + f"However, the LHS {lhs} has type {lhs_ty} (dtype='{lhs_ty.dtype}') " + f"and the RHS {rhs} has type {rhs_ty} (dtype='{rhs_ty.dtype}')." ) return bb.call_te(te_matmul, call.args[0], call.args[1], primfunc_name_hint="matmul") @@ -111,7 +111,7 @@ def multiply_compute(idx_reduce): @register_legalize("relax.einsum") def _einsum(bb: BlockBuilder, call: Call) -> Expr: t = call.args[0] - n_field = len(t.struct_info.fields) + n_field = len(t.ty.fields) while isinstance(t, Var): binding = bb.lookup_binding(t) if not isinstance(binding, Tuple | Var): diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py b/python/tvm/relax/transform/legalize_ops/manipulate.py index ed1349c05f56..1f3abaaf6e48 100644 --- a/python/tvm/relax/transform/legalize_ops/manipulate.py +++ b/python/tvm/relax/transform/legalize_ops/manipulate.py @@ -21,7 +21,7 @@ import tvm from tvm import relax, s_tir, te, tirx, topi from tvm.relax.op.base import call_tir -from tvm.relax.struct_info import TensorStructInfo +from tvm.relax.type import TensorType from tvm.relax.utils import gen_call_tir_inputs from tvm.tirx.expr import IntImm @@ -34,7 +34,7 @@ def _reshape( te_func: TEFunc, primfunc_name: str, is_collapse_sum_like: bool = False ) -> LegalizeFunc: def reshape_call_te(bb: BlockBuilder, call: Call): - tgt_shape = call.args[1].struct_info.shape if is_collapse_sum_like else call.args[1] + tgt_shape = call.args[1].ty.shape if is_collapse_sum_like else call.args[1] # If target shape is Var, pass its bound expr only when it is ShapeExpr if isinstance(tgt_shape, Var): tgt_shape = bb.lookup_binding(tgt_shape) @@ -57,7 +57,7 @@ def reshape_call_te(bb: BlockBuilder, call: Call): @register_legalize("relax.concat") def _concat(bb: BlockBuilder, call: Call) -> Expr: t = call.args[0] - n_field = len(t.struct_info.fields) + n_field = len(t.ty.fields) while isinstance(t, Var): binding = bb.lookup_binding(t) if not isinstance(binding, Tuple | Var): @@ -76,9 +76,9 @@ def _concat(bb: BlockBuilder, call: Call) -> Expr: @register_legalize("relax.expand_dims") def _expand_dims(bb: BlockBuilder, call: Call) -> Expr: def te_expand_dims(data, axis): - data_relax = relax.Var("data", relax.TensorStructInfo(data.shape)) - f_infer_sinfo = call.op.get_attr("FInferStructInfo") - output_shape = f_infer_sinfo(relax.op.expand_dims(data_relax, axis), bb).shape + data_relax = relax.Var("data", relax.TensorType(data.shape)) + f_infer_ty = call.op.get_attr("FInferType") + output_shape = f_infer_ty(relax.op.expand_dims(data_relax, axis), bb).shape output_ndim = len(output_shape) data_dims = [] @@ -98,7 +98,7 @@ def te_expand_dims(data, axis): @register_legalize("relax.flatten") def _flatten(bb: BlockBuilder, call: Call) -> Expr: - return bb.call_te(topi.reshape, call.args[0], call.struct_info.shape.values) + return bb.call_te(topi.reshape, call.args[0], call.ty.shape.values) @register_legalize("relax.permute_dims") @@ -123,7 +123,7 @@ def _squeeze(bb: BlockBuilder, call: Call) -> Expr: @register_legalize("relax.stack") def _stack(bb: BlockBuilder, call: Call) -> Expr: t = call.args[0] - n_field = len(t.struct_info.fields) + n_field = len(t.ty.fields) # Follow bindings to find the actual tuple while isinstance(t, Var): @@ -189,7 +189,7 @@ def te_gather_nd(data, indices, batch_dims): @register_legalize("relax.index_tensor") def _index_tensor(bb: BlockBuilder, call: Call) -> Expr: t = call.args[1] - n_field = len(t.struct_info.fields) + n_field = len(t.ty.fields) fields = [bb.emit(TupleGetItem(t, i)) for i in range(n_field)] return bb.call_te(topi.index_tensor, call.args[0], fields) @@ -219,7 +219,7 @@ def _index_put(bb: BlockBuilder, call: Call) -> Expr: @register_legalize("relax.meshgrid") def _meshgrid(bb: BlockBuilder, call: Call) -> Expr: t = call.args[0] - n_field = len(t.struct_info.fields) + n_field = len(t.ty.fields) while isinstance(t, Var): binding = bb.lookup_binding(t) if not isinstance(binding, Tuple | Var): @@ -325,7 +325,7 @@ def set_axis_sep(axis_sep: list, sch: s_tir.schedule, buffer_type: str): if pad_value is not None: pad_value = pad_value.value else: - if "int" in call.args[0].struct_info.dtype: + if "int" in call.args[0].ty.dtype: pad_value = 0 else: pad_value = 0.0 @@ -336,7 +336,7 @@ def set_axis_sep(axis_sep: list, sch: s_tir.schedule, buffer_type: str): # Convert to list from array axis_separators = [int(sep) for sep in axis_separators] primfunc_name = "te_layout_transform" - _, padding_predicate = index_map.non_surjective_inverse(call.args[0].struct_info.shape) + _, padding_predicate = index_map.non_surjective_inverse(call.args[0].ty.shape) if not isinstance(padding_predicate, tvm.tirx.expr.IntImm): primfunc_name += "_with_pad" if len(axis_separators) != 0: @@ -351,7 +351,7 @@ def set_axis_sep(axis_sep: list, sch: s_tir.schedule, buffer_type: str): if input_axis_separators is not None: set_axis_sep(input_axis_separators, sch, "read") gvar = bb.add_func(sch.mod["main"], primfunc_name) - output_shape = index_map.map_shape(list(call_args[0].struct_info.shape)) - output_dtype = call_args[0].struct_info.dtype - output_sinfo = [TensorStructInfo(output_shape, output_dtype)] - return call_tir(gvar, call_args, output_sinfo, tir_vars) + output_shape = index_map.map_shape(list(call_args[0].ty.shape)) + output_dtype = call_args[0].ty.dtype + output_ty = [TensorType(output_shape, output_dtype)] + return call_tir(gvar, call_args, output_ty, tir_vars) diff --git a/python/tvm/relax/transform/legalize_ops/nn.py b/python/tvm/relax/transform/legalize_ops/nn.py index 35d81f968b37..f87c16aa0a14 100644 --- a/python/tvm/relax/transform/legalize_ops/nn.py +++ b/python/tvm/relax/transform/legalize_ops/nn.py @@ -44,8 +44,8 @@ def _nn_conv1d(bb: BlockBuilder, call: Call) -> Expr: if call.attrs.groups != 1: data_layout = s_tir.slayout(call.attrs.data_layout) kernel_layout = s_tir.slayout(call.attrs.kernel_layout) - ic = call.args[0].struct_info.shape.values[data_layout.index_of("C")] - oc = call.args[1].struct_info.shape.values[kernel_layout.index_of("O")] + ic = call.args[0].ty.shape.values[data_layout.index_of("C")] + oc = call.args[1].ty.shape.values[kernel_layout.index_of("O")] if not isinstance(ic, tirx.IntImm) or not isinstance(oc, tirx.IntImm): logging.info( "Conv1D where number of groups is more than one and input or output " @@ -85,8 +85,8 @@ def _nn_conv2d(bb: BlockBuilder, call: Call) -> Expr: if call.attrs.groups != 1: data_layout = s_tir.slayout(call.attrs.data_layout) kernel_layout = s_tir.slayout(call.attrs.kernel_layout) - ic = call.args[0].struct_info.shape.values[data_layout.index_of("C")] - oc = call.args[1].struct_info.shape.values[kernel_layout.index_of("O")] + ic = call.args[0].ty.shape.values[data_layout.index_of("C")] + oc = call.args[1].ty.shape.values[kernel_layout.index_of("O")] if not isinstance(ic, tirx.IntImm) or not isinstance(oc, tirx.IntImm): logging.info( "Conv2D where number of groups is more than one and input or output " @@ -126,8 +126,8 @@ def _nn_conv3d(bb: BlockBuilder, call: Call) -> Expr: if call.attrs.groups != 1: data_layout = s_tir.slayout(call.attrs.data_layout) kernel_layout = s_tir.slayout(call.attrs.kernel_layout) - ic = call.args[0].struct_info.shape.values[data_layout.index_of("C")] - oc = call.args[1].struct_info.shape.values[kernel_layout.index_of("O")] + ic = call.args[0].ty.shape.values[data_layout.index_of("C")] + oc = call.args[1].ty.shape.values[kernel_layout.index_of("O")] if not isinstance(ic, tirx.IntImm) or not isinstance(oc, tirx.IntImm): logging.info( "Conv3D where number of groups is more than one and input or output " @@ -178,7 +178,7 @@ def _nn_conv1d_transpose(bb: BlockBuilder, call: Call) -> Expr: call.args[1], stride=call.attrs.strides, padding=call.attrs.padding, - out_dtype=call.struct_info.dtype, + out_dtype=call.ty.dtype, output_padding=call.attrs.output_padding, groups=call.attrs.groups, primfunc_name_hint="conv1d_transpose", @@ -213,7 +213,7 @@ def _nn_conv2d_transpose(bb: BlockBuilder, call: Call) -> Expr: call.args[1], stride=call.attrs.strides, padding=call.attrs.padding, - out_dtype=call.struct_info.dtype, + out_dtype=call.ty.dtype, output_padding=call.attrs.output_padding, groups=call.attrs.groups, primfunc_name_hint="conv2d_transpose", @@ -250,7 +250,7 @@ def _nn_conv3d_transpose(bb: BlockBuilder, call: Call) -> Expr: call.args[1], strides=call.attrs.strides, padding=call.attrs.padding, - out_dtype=call.struct_info.dtype, + out_dtype=call.ty.dtype, output_padding=call.attrs.output_padding, groups=call.attrs.groups, primfunc_name_hint="conv3d_transpose", @@ -817,6 +817,6 @@ def nll_loss_without_weight(predictions, targets, reduction, ignore_index): @register_legalize("relax.nn.batch_flatten") def _nn_batch_flatten(bb: BlockBuilder, call: Call) -> Expr: - if call.struct_info.shape is None: + if call.ty.shape is None: return call - return bb.call_te(topi.reshape, call.args[0], call.struct_info.shape.values) + return bb.call_te(topi.reshape, call.args[0], call.ty.shape.values) diff --git a/python/tvm/relax/transform/legalize_ops/search.py b/python/tvm/relax/transform/legalize_ops/search.py index cf87cd71f78c..65dd484c9403 100644 --- a/python/tvm/relax/transform/legalize_ops/search.py +++ b/python/tvm/relax/transform/legalize_ops/search.py @@ -47,6 +47,4 @@ def _bucketize(bb, call): input_tensor = call.args[0] boundaries = call.args[1] right = call.attrs.right - return bb.call_te( - topi.searchsorted, boundaries, input_tensor, right, input_tensor.struct_info.dtype - ) + return bb.call_te(topi.searchsorted, boundaries, input_tensor, right, input_tensor.ty.dtype) diff --git a/python/tvm/relax/transform/legalize_ops/statistical.py b/python/tvm/relax/transform/legalize_ops/statistical.py index 168cd7139997..51a621962413 100644 --- a/python/tvm/relax/transform/legalize_ops/statistical.py +++ b/python/tvm/relax/transform/legalize_ops/statistical.py @@ -40,7 +40,7 @@ def _normalize_reduction_axes(axis: list[int] | None, ndim: int) -> list[int]: def _has_const_zero_reduction_dim(call: Call) -> bool: - input_shape = call.args[0].struct_info.shape + input_shape = call.args[0].ty.shape if not isinstance(input_shape, ShapeExpr): return False @@ -58,14 +58,14 @@ def _statistical( def statistical_call_te(bb: BlockBuilder, call: Call) -> Expr: if zero_dim_identity is not None and _has_const_zero_reduction_dim(call): fill_value = ( - zero_dim_identity(call.struct_info.dtype) + zero_dim_identity(call.ty.dtype) if callable(zero_dim_identity) else zero_dim_identity ) return bb.call_te( topi.full, - call.struct_info.shape.values, - call.struct_info.dtype, + call.ty.shape.values, + call.ty.dtype, fill_value, ) return bb.call_te(te_func, call.args[0], call.attrs.axis, call.attrs.keepdims) diff --git a/python/tvm/relax/transform/legalize_ops/vision.py b/python/tvm/relax/transform/legalize_ops/vision.py index 4419549164cf..b675f2f43390 100644 --- a/python/tvm/relax/transform/legalize_ops/vision.py +++ b/python/tvm/relax/transform/legalize_ops/vision.py @@ -49,7 +49,7 @@ def _all_class_non_max_suppression(block_builder: BlockBuilder, call: Call) -> E score_threshold = call.args[4] output_format = call.attrs.output_format - scores_shape = scores.struct_info.shape + scores_shape = scores.ty.shape if len(scores_shape) == 3: _, _, num_boxes = scores_shape elif len(scores_shape) == 2: diff --git a/python/tvm/relax/transform/lower_gpu_ipc_alloc_storage.py b/python/tvm/relax/transform/lower_gpu_ipc_alloc_storage.py index b73b40d6fc9c..b2b08888d8e0 100644 --- a/python/tvm/relax/transform/lower_gpu_ipc_alloc_storage.py +++ b/python/tvm/relax/transform/lower_gpu_ipc_alloc_storage.py @@ -65,7 +65,7 @@ def rewrite_alloc_storage(self, call: relax.Call) -> relax.Call: return relax.Call( relax.ExternFunc("runtime.disco.cuda_ipc.alloc_storage"), args=[shape, dtype], - sinfo_args=[call.struct_info], + ty_args=[call.ty], ) def rewrite_alloc_tensor(self, call: relax.Call) -> relax.Call: @@ -74,10 +74,10 @@ def rewrite_alloc_tensor(self, call: relax.Call) -> relax.Call: ipc_alloc_storage = relax.Call( relax.ExternFunc("runtime.disco.cuda_ipc.alloc_storage"), args=[shape, dtype], - sinfo_args=[relax.ObjectStructInfo()], + ty_args=[relax.ObjectType()], ) return relax.Call( self.memory_alloc_tensor_op, args=[ipc_alloc_storage, call.args[2], shape, dtype, relax.PrimValue(0)], - sinfo_args=call.sinfo_args, + ty_args=call.ty_args, ) diff --git a/python/tvm/relax/transform/optimize_layout_transform.py b/python/tvm/relax/transform/optimize_layout_transform.py index 7dd071dd7ea8..0dc00b2f4f51 100644 --- a/python/tvm/relax/transform/optimize_layout_transform.py +++ b/python/tvm/relax/transform/optimize_layout_transform.py @@ -78,8 +78,8 @@ def rewriter(expr, matches): arg2 = matches[self.gv_] if "remove_pad" == self.mod[arg2].attrs["operator_name"]: arg2 = matches[self.input] - if hasattr(arg1.struct_info, "shape") and hasattr(arg2.struct_info, "shape"): - if tvm_ffi.structural_equal(arg1.struct_info.shape, arg2.struct_info.shape): + if hasattr(arg1.ty, "shape") and hasattr(arg2.ty, "shape"): + if tvm_ffi.structural_equal(arg1.ty.shape, arg2.ty.shape): return arg2 return expr diff --git a/python/tvm/relax/transform/remove_redundant_reshape.py b/python/tvm/relax/transform/remove_redundant_reshape.py index 11119f5e8ceb..8d3709e1679d 100644 --- a/python/tvm/relax/transform/remove_redundant_reshape.py +++ b/python/tvm/relax/transform/remove_redundant_reshape.py @@ -74,9 +74,7 @@ def rewriter(expr, matches): elif self.no_op_reshape in matches: output_shape = matches[self.no_op_reshape].args[1] - if arg.struct_info.shape and tvm_ffi.structural_equal( - arg.struct_info.shape, output_shape - ): + if arg.ty.shape and tvm_ffi.structural_equal(arg.ty.shape, output_shape): return arg return expr diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index a291fb973730..2ea1bad5f7b1 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -30,7 +30,7 @@ from tvm_ffi import Array import tvm.ir -from tvm.relax import Expr, StructInfo, Var +from tvm.relax import Expr, Type, Var from tvm.relax.dpl import DFPattern from tvm.runtime import Object, Tensor from tvm.tirx import IndexMap, PrimFunc @@ -421,8 +421,7 @@ def CallTIRRewrite() -> tvm.ir.transform.Pass: def Normalize() -> tvm.ir.transform.Pass: """Transforming Relax IR to normal form, i.e., the expressions are normalized(no nesting - and hence the AST is in ANF), and all `struct_info_` of expressions are - available. + and hence the AST is in ANF), and all `ty` fields of expressions are available. Returns ------- @@ -1449,20 +1448,19 @@ def SplitCallTIRByPattern(patterns: list[PrimFunc], fcodegen: Callable) -> tvm.i return _ffi_api.SplitCallTIRByPattern(patterns, fcodegen) # type: ignore -def UpdateParamStructInfo(sinfo_func: Callable[[Var], StructInfo | None]): - """Update struct info of parameters +def UpdateParamType(ty_func: Callable[[Var], Type | None]): + """Update parameter types. - Update struct info of parameters. Internal bindings and function - return type will be updated using relax's struct inference rules. - Errors resulting from struct inference will be propagated to the - user. + Internal bindings and the function return type are updated using Relax's + type inference rules. Errors resulting from type inference are propagated + to the user. Parameters ---------- - sinfo_func: Callable[[Var], Optional[StructInfo]] + ty_func: Callable[[Var], Optional[Type]] A function that is called once for each function parameter, - and returns the updated struct info to be used for it. If the + and returns the updated type to be used for it. If the function returns `None`, the parameter is not modified. Returns @@ -1471,7 +1469,7 @@ def UpdateParamStructInfo(sinfo_func: Callable[[Var], StructInfo | None]): The corresponding pass. """ - return _ffi_api.UpdateParamStructInfo(sinfo_func) # type: ignore + return _ffi_api.UpdateParamType(ty_func) # type: ignore def AdjustMatmulOrder(): @@ -1840,7 +1838,7 @@ class TestReplaceBinding: def __init__(self): # create a new VarBinding m, n = tirx.Var("m", "int64"), tirx.Var("n", "int64") - lv0 = relax.Var("lv1", relax.TensorStructInfo([m, n], "float32")) + lv0 = relax.Var("lv1", relax.TensorType([m, n], "float32")) val = relax.const(np.random.rand(24, 56)) self.new_binding = relax.VarBinding(lv0, val) diff --git a/python/tvm/relax/ty.py b/python/tvm/relax/ty.py index afa25d0dd003..1e909c9382b7 100644 --- a/python/tvm/relax/ty.py +++ b/python/tvm/relax/ty.py @@ -15,58 +15,15 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=invalid-name, unused-import -# ruff: noqa: F401 """The type nodes of the Relax language.""" import tvm_ffi -from tvm.ir import FuncType, Span, TupleType, Type +from tvm.ir import Span, Type from . import _ffi_api -@tvm_ffi.register_object("relax.ShapeType") -class ShapeType(Type): - """The type of shape in Relax. - - Parameters - ---------- - ndim : int - The number of dimensions of the shape. Use -1 for unknown ndim. - """ - - def __init__(self, ndim: int, span: Span = None) -> None: - self.__init_handle_by_constructor__(_ffi_api.ShapeType, ndim, span) # type: ignore - - -@tvm_ffi.register_object("relax.ObjectType") -class ObjectType(Type): - """A type that corresponds to tvm::runtime::Object, is base of all possible object - values in TVM.""" - - def __init__(self, span: Span = None) -> None: - self.__init_handle_by_constructor__(_ffi_api.ObjectType, span) # type: ignore - - -@tvm_ffi.register_object("relax.DynTensorType") -class TensorType(Type): - """A dynamic tensor type in Relax. - - This is the type assigned to tensors with a known dtype and unknown shape. - - Parameters - ---------- - ndim : Optional[int] - The ndim of the Tensor - - dtype : Optional[str] - The content data type. - """ - - def __init__(self, ndim=-1, dtype="float32", span: Span = None) -> None: - self.__init_handle_by_constructor__(_ffi_api.TensorType, ndim, dtype, span) # type: ignore - - @tvm_ffi.register_object("relax.PackedFuncType") class PackedFuncType(Type): """The type of ExternFunc in Relax.""" diff --git a/python/tvm/relax/struct_info.py b/python/tvm/relax/type.py similarity index 68% rename from python/tvm/relax/struct_info.py rename to python/tvm/relax/type.py index e2f9141550f1..3ba5a86b90b9 100644 --- a/python/tvm/relax/struct_info.py +++ b/python/tvm/relax/type.py @@ -16,33 +16,32 @@ # under the License. # pylint: disable=invalid-name, unused-import # ruff: noqa: F401 -"""The struct info nodes of the Relax language.""" - -from typing import Optional, Union +"""The Relax type nodes, including richer dependent type nodes.""" import tvm_ffi from tvm_ffi import Array import tvm -from tvm.ir import EnvFunc, Span, VDevice +from tvm.ir import EnvFunc, Span, TupleType, VDevice from tvm.runtime import DataType from tvm.tirx import PrimExpr -from . import _ffi_api, expr, ty -from .expr import Expr, ShapeExpr, StructInfo +from . import _ffi_api +from .expr import Expr, ShapeExpr, Type +from .ty import PackedFuncType -@tvm_ffi.register_object("relax.ObjectStructInfo") -class ObjectStructInfo(StructInfo): - """StructInfo of an Object.""" +@tvm_ffi.register_object("relax.ObjectType") +class ObjectType(Type): + """Type of an Object.""" def __init__(self, span: Span = None) -> None: - self.__init_handle_by_constructor__(_ffi_api.ObjectStructInfo, span) # type: ignore + self.__init_handle_by_constructor__(_ffi_api.ObjectType, span) # type: ignore -@tvm_ffi.register_object("relax.PrimStructInfo") -class PrimStructInfo(StructInfo): - """StructInfo of a primitive POD value. +@tvm_ffi.register_object("relax.PrimType") +class PrimType(Type): + """Type of a primitive POD value. Parameters ---------- @@ -63,7 +62,7 @@ def __init__( ) -> None: # Guard against incorrect usage. For backwards compatibility, # the dtype and value are in the opposite order from most - # usages. While PrimStructInfo could take a single positional + # usages. While PrimType could take a single positional # argument and check the type, this would require an API # difference from TVMScript's PrimProxy, which cannot. # (PrimProxy uses string arguments for datatype, and also for @@ -72,23 +71,23 @@ def __init__( # the two cases.) if isinstance(dtype, PrimExpr | int | float): raise TypeError( - f"The first positional argument of PrimStructInfo must be the datatype, " + f"The first positional argument of PrimType must be the datatype, " f", but received {type(dtype)}. " f"The value can be specified as a keyword argument " f"without needing specifying the dtype: " - f"PrimStructInfo(value=arg)." + f"PrimType(value=arg)." ) if dtype is None and value is None: raise TypeError( - "PrimStructInfo.__init__ missing required argument. " + "PrimType.__init__ missing required argument. " "Must provide either 'dtype' or 'value'" ) if dtype is not None: if isinstance(value, PrimExpr): assert value.dtype == dtype, ( - "When providing both 'value' and 'dtype' to PrimStructInfo.__init__, " + "When providing both 'value' and 'dtype' to PrimType.__init__, " "they must be consistent with each other. " "However, the value {value} has dtype {value.dtype}, " "but the specified dtype was {dtype}." @@ -101,14 +100,14 @@ def __init__( value = tvm.tirx.IntImm("int64", value) if value is None: - self.__init_handle_by_constructor__(_ffi_api.PrimStructInfoFromDtype, dtype, span) # type: ignore + self.__init_handle_by_constructor__(_ffi_api.PrimTypeFromDtype, dtype, span) # type: ignore else: - self.__init_handle_by_constructor__(_ffi_api.PrimStructInfoFromValue, value, span) # type: ignore + self.__init_handle_by_constructor__(_ffi_api.PrimTypeFromValue, value, span) # type: ignore -@tvm_ffi.register_object("relax.ShapeStructInfo") -class ShapeStructInfo(StructInfo): - """StructInfo of a shape value. +@tvm_ffi.register_object("relax.ShapeType") +class ShapeType(Type): + """Type of a shape value. Parameters ---------- @@ -131,16 +130,16 @@ def __init__( self, values: list[PrimExpr] | None = None, ndim: int = -1, span: Span = None ) -> None: self.__init_handle_by_constructor__( - _ffi_api.ShapeStructInfo, + _ffi_api.ShapeType, values, ndim, span, # type: ignore ) -@tvm_ffi.register_object("relax.TensorStructInfo") -class TensorStructInfo(StructInfo): - """StructInfo of a Tensor value. +@tvm_ffi.register_object("relax.TensorType") +class TensorType(Type): + """Type of a Tensor value. Parameters ---------- @@ -178,7 +177,7 @@ def __init__( if isinstance(shape, list | tuple | Array): shape = ShapeExpr(shape) self.__init_handle_by_constructor__( - _ffi_api.TensorStructInfo, + _ffi_api.TensorType, shape, dtype, ndim, @@ -187,34 +186,17 @@ def __init__( ) -@tvm_ffi.register_object("relax.TupleStructInfo") -class TupleStructInfo(StructInfo): - """StructInfo of a Tuple value. - - Parameters - ---------- - fields: List[StructInfo] - The struct info of the fields. - """ - - fields: list[StructInfo] - span: Span - - def __init__(self, fields: list[StructInfo], span: Span = None) -> None: - self.__init_handle_by_constructor__(_ffi_api.TupleStructInfo, fields, span) # type: ignore - - -@tvm_ffi.register_object("relax.FuncStructInfo") -class FuncStructInfo(StructInfo): - """StructInfo of a function value. +@tvm_ffi.register_object("relax.FuncType") +class FuncType(Type): + """Type of a function value. Parameters ---------- - params: List[StructInfo] - The struct info of the fields. + params: List[Type] + The type of the fields. - ret: StructInfo - The struct info of return value + ret: Type + The type of return value purity: bool Whether the function is pure (has no visible side effects). @@ -223,17 +205,17 @@ class FuncStructInfo(StructInfo): we still consider it impure. """ - params: list[StructInfo] | None - ret: StructInfo + params: list[Type] | None + ret: Type derive_func: EnvFunc | None purity: bool span: Span def __init__( - self, params: list[StructInfo], ret: StructInfo, purity: bool = True, span: Span = None + self, params: list[Type], ret: Type, purity: bool = True, span: Span = None ) -> None: self.__init_handle_by_constructor__( - _ffi_api.FuncStructInfo, + _ffi_api.FuncType, params, ret, purity, @@ -243,22 +225,22 @@ def __init__( @staticmethod def opaque_func( *, - ret: StructInfo | None = None, + ret: Type | None = None, derive_func: str | EnvFunc | None = None, purity: bool = False, span: Span = None, - ) -> "FuncStructInfo": + ) -> "FuncType": """ - Create an opaque FuncStructInfo. + Create an opaque FuncType. The opaque function takes either a ret - that specificies the struct info of the return value + that specificies the type of the return value or a derive_func that provides a customized derivation rule. Parameters ---------- - ret: Optional[StructInfo] - The struct info of the function return value. + ret: Optional[Type] + The type of the function return value. derive_func: Optional[Union[str,EnvFunc]] The environment function used for derivation @@ -271,7 +253,7 @@ def opaque_func( Returns ------- - info: FuncStructInfo + info: FuncType Note ---- @@ -279,5 +261,5 @@ def opaque_func( """ if isinstance(derive_func, str): - derive_func = tvm.ir.EnvFunc.get("tvm.relax.struct_info.infer_view_sinfo") - return _ffi_api.FuncStructInfoOpaqueFunc(ret, derive_func, purity, span) # type: ignore + derive_func = tvm.ir.EnvFunc.get("tvm.relax.type.infer_view_ty") + return _ffi_api.FuncTypeOpaqueFunc(ret, derive_func, purity, span) # type: ignore diff --git a/python/tvm/relax/utils.py b/python/tvm/relax/utils.py index b50a19cae19f..89c9ac82c1fa 100644 --- a/python/tvm/relax/utils.py +++ b/python/tvm/relax/utils.py @@ -38,7 +38,7 @@ from . import _ffi_api from .expr import Expr, Function, PrimValue, ShapeExpr, StringImm, te_tensor from .expr import Tuple as rx_Tuple -from .struct_info import PrimStructInfo, ShapeStructInfo, TensorStructInfo +from .type import PrimType, ShapeType, TensorType def metadata_partitioner(rx_txt: str) -> list[str]: @@ -146,7 +146,7 @@ def copy_with_new_vars(func: Function) -> Function: def gen_call_tir_inputs( func: Callable, *args: Any, **kwargs: Any -) -> tuple[tirx.PrimFunc, Expr, list[TensorStructInfo], ShapeExpr | None]: +) -> tuple[tirx.PrimFunc, Expr, list[TensorType], ShapeExpr | None]: """Generate the inputs for call_tir according to the te function. This function converts arguments from relax expression to te tensor, The callback func should return a te tensor or a list of te tensors. @@ -166,9 +166,9 @@ def gen_call_tir_inputs( Returns ------- - ret : Tuple[tirx.PrimFunc, Expr, List[TensorStructInfo], Optional[ShapeExpr]] + ret : Tuple[tirx.PrimFunc, Expr, List[TensorType], Optional[ShapeExpr]] ret contains the inputs for call_tir, including a tirx prim_func, args, - out_sinfo, and tir_vars. + out_ty, and tir_vars. """ tir_var_map: dict[tirx.Var, tirx.PrimExpr] = {} @@ -197,7 +197,7 @@ def _convert_te_arg(te_args: Any) -> Any: Common values of type int, float, and str are preserved. In dynamic shape cases, the passed in arguments may contain TIR variable. - For example, the argument can be a Relax Var with TensorStructInfo, which + For example, the argument can be a Relax Var with TensorType, which has symbolic shape, or the argument can be a ShapeExpr with symbolic variables. To make the PrimFunc generated has independent variables with the caller Relax function, we will substitute the TIR variables in the input @@ -221,11 +221,11 @@ def _convert_te_arg(te_args: Any) -> Any: def _convert_te_arg_helper(arg): if isinstance(arg, Expr): # type: ignore - if isinstance(arg.struct_info, TensorStructInfo): - assert isinstance(arg.struct_info.shape, ShapeExpr), ( + if isinstance(arg.ty, TensorType): + assert isinstance(arg.ty.shape, ShapeExpr), ( "emit_te now only supports Tensor that has ShapeExpr shape" ) - for shape_value in arg.struct_info.shape.values: + for shape_value in arg.ty.shape.values: _copy_undefined_var(shape_value) n_args = len(create_primfunc_args) @@ -243,14 +243,14 @@ def _convert_te_arg_helper(arg): return te_arg - if isinstance(arg.struct_info, ShapeStructInfo): + if isinstance(arg.ty, ShapeType): assert isinstance(arg, ShapeExpr), ( - "For Expr having ShapeStructInfo, emit_te now only supports ShapeExpr" + "For Expr having ShapeType, emit_te now only supports ShapeExpr" ) return [_convert_te_arg_helper(val) for val in arg.values] - if isinstance(arg.struct_info, PrimStructInfo): - if arg.struct_info.value is None: + if isinstance(arg.ty, PrimType): + if arg.ty.value is None: n_args = len(create_primfunc_args) if isinstance(arg, tvm.relax.Var): name = arg.name_hint @@ -259,14 +259,14 @@ def _convert_te_arg_helper(arg): else: name = f"scalar_input_{n_args}" - tir_param = tirx.Var(name, arg.struct_info.dtype) + tir_param = tirx.Var(name, arg.ty.dtype) call_tir_args.append(arg) create_primfunc_args.append(tir_param) return tir_param else: - return _convert_te_arg_helper(arg.struct_info.value) + return _convert_te_arg_helper(arg.ty.value) elif isinstance(arg, list | Array): return [_convert_te_arg_helper(x) for x in arg] @@ -326,8 +326,8 @@ def _get_vdevice(arg: Any) -> VDevice | None: """get the virtual device from arguments.""" vdevice = None if isinstance(arg, Expr): # type: ignore - if isinstance(arg.struct_info, TensorStructInfo): - vdevice = arg.struct_info.vdevice + if isinstance(arg.ty, TensorType): + vdevice = arg.ty.vdevice elif isinstance(arg, list | Array | tuple): for x in arg: vdevice = _get_vdevice(x) @@ -348,7 +348,7 @@ def _shape_with_old_tir_var( ) primfunc_attrs = kwargs.pop("primfunc_attrs", None) - custom_out_sinfo = kwargs.pop("sinfo_args", []) + custom_out_ty = kwargs.pop("ty_args", []) te_args = _convert_te_arg(args) te_kwargs = _convert_te_arg(kwargs) @@ -373,11 +373,11 @@ def _shape_with_old_tir_var( # with old set of variables. tir_var_inverse_map = {v: k for k, v in tir_var_map.items()} - if len(custom_out_sinfo) == 1: - output_sinfo = custom_out_sinfo[0] + if len(custom_out_ty) == 1: + output_ty = custom_out_ty[0] else: - output_sinfo = [ - TensorStructInfo( + output_ty = [ + TensorType( _shape_with_old_tir_var(out.shape, tir_var_inverse_map), out.dtype, _get_vdevice(args), @@ -389,4 +389,4 @@ def _shape_with_old_tir_var( if len(unbound_tir_vars) > 0: tir_vars = _shape_with_old_tir_var(unbound_tir_vars, tir_var_inverse_map) - return (tir_func, call_tir_args, output_sinfo, tir_vars) + return (tir_func, call_tir_args, output_ty, tir_vars) diff --git a/python/tvm/runtime/script_printer.py b/python/tvm/runtime/script_printer.py index 238973725fbc..3f67e285de53 100644 --- a/python/tvm/runtime/script_printer.py +++ b/python/tvm/runtime/script_printer.py @@ -66,7 +66,7 @@ def __init__( num_context_lines: int | None = None, syntax_sugar: bool = True, show_object_address: bool = False, - show_all_struct_info: bool = True, + show_all_ty: bool = True, extra_config: dict | None = None, path_to_underline: list[AccessPath] | None = None, path_to_annotate: dict[AccessPath, str] | None = None, @@ -93,7 +93,7 @@ def __init__( "obj_to_underline": obj_to_underline, "obj_to_annotate": obj_to_annotate, # Dialect-specific config via dotted keys in extra_config - "relax.show_all_struct_info": show_all_struct_info, + "relax.show_all_ty": show_all_ty, } if name is not None: @@ -133,7 +133,7 @@ def script( num_context_lines: int = -1, syntax_sugar: bool = True, show_object_address: bool = False, - show_all_struct_info: bool = True, + show_all_ty: bool = True, extra_config: dict | None = None, path_to_underline: list[AccessPath] | None = None, path_to_annotate: dict[AccessPath, str] | None = None, @@ -169,7 +169,7 @@ def script( Whether to output with syntax sugar, set false for complete printing. show_object_address: bool = False Whether to include the object's address as part of the TVMScript name - show_all_struct_info: bool = True + show_all_ty: bool = True If True (default), annotate all variable bindings with the struct info of that variable. If False, only add annotations where required for unambiguous round-trip of Relax -> TVMScript -> Relax. @@ -241,7 +241,7 @@ def script( num_context_lines=num_context_lines, syntax_sugar=syntax_sugar, show_object_address=show_object_address, - show_all_struct_info=show_all_struct_info, + show_all_ty=show_all_ty, extra_config=merged_extra if merged_extra else None, path_to_underline=path_to_underline, path_to_annotate=path_to_annotate, @@ -311,7 +311,7 @@ def show( num_context_lines: int = -1, syntax_sugar: bool = True, show_object_address: bool = False, - show_all_struct_info: bool = True, + show_all_ty: bool = True, extra_config: dict | None = None, path_to_underline: list[AccessPath] | None = None, path_to_annotate: dict[AccessPath, str] | None = None, @@ -370,7 +370,7 @@ def show( Whether to output with syntax sugar, set false for complete printing. show_object_address: bool = False Whether to include the object's address as part of the TVMScript name - show_all_struct_info: bool = True + show_all_ty: bool = True If True (default), annotate all variable bindings with the struct info of that variable. If False, only add annotations where required for unambiguous round-trip of Relax -> TVMScript -> Relax. @@ -406,7 +406,7 @@ def show( num_context_lines=num_context_lines, syntax_sugar=syntax_sugar, show_object_address=show_object_address, - show_all_struct_info=show_all_struct_info, + show_all_ty=show_all_ty, extra_config=extra_config, path_to_underline=path_to_underline, path_to_annotate=path_to_annotate, diff --git a/python/tvm/s_tir/dlight/benchmark/bench.py b/python/tvm/s_tir/dlight/benchmark/bench.py index aa7aefc02cb2..9c759fae069b 100644 --- a/python/tvm/s_tir/dlight/benchmark/bench.py +++ b/python/tvm/s_tir/dlight/benchmark/bench.py @@ -43,7 +43,7 @@ def benchmark( mod_or_func: PrimFunc | IRModule, *, dym_var_sample: dict[str, int], - args: list[relax.TensorStructInfo | tuple[tuple[int | str, ...], str]] | None, + args: list[relax.TensorType | tuple[tuple[int | str, ...], str]] | None, target: str | tvm.target.Target | None = None, func_name: str | None = None, evaluator_config: Optional["EvaluatorConfig"] = None, @@ -57,7 +57,7 @@ def benchmark( The PrimFunc or IRModule to be benchmarked. dym_var_sample : Optional[Dict[str, int]] The dynamic shape variable sample, e.g., {"n": 64, "m": 128}. - args : Optional[List[Union[relax.TensorStructInfo, Tuple[Tuple[Union[int, str], ...], str]]]] + args : Optional[List[Union[relax.TensorType, Tuple[Tuple[Union[int, str], ...], str]]]] The input tensor information, including shape and dtype. If none, will use the input information from the PrimFunc or IRModule. target : Optional[Union[str, tvm.target.Target]] @@ -156,7 +156,7 @@ def benchmark_prim_func( mod_or_func: PrimFunc | IRModule, *, dym_var_sample_func: Callable[[dict[str, str]], dict[str, int]] = default_dym_var_sample_func, - args: list[relax.TensorStructInfo | tuple[tuple[int | str, ...], str]] | None = None, + args: list[relax.TensorType | tuple[tuple[int | str, ...], str]] | None = None, dym_var_dict: dict[str, str] | None = None, sample_number: int = 5, target: str | tvm.target.Target | None = None, @@ -179,7 +179,7 @@ def benchmark_prim_func( dym_var_dict : Optional[Dict[str, str]] Dynamic shape variable dictionary, e.g., {"n": "int32", "m": "int32"}. If none, will use the input information from the PrimFunc or IRModule. - args : Optional[List[Union[relax.TensorStructInfo, Tuple[Tuple[Union[int, str], ...], str]]]] + args : Optional[List[Union[relax.TensorType, Tuple[Tuple[Union[int, str], ...], str]]]] The input tensor information, including shape and dtype. If none, will use the input information from the PrimFunc or IRModule. sample_number : int diff --git a/python/tvm/s_tir/dlight/benchmark/extract.py b/python/tvm/s_tir/dlight/benchmark/extract.py index ee7358efcfa6..82052120b158 100644 --- a/python/tvm/s_tir/dlight/benchmark/extract.py +++ b/python/tvm/s_tir/dlight/benchmark/extract.py @@ -66,18 +66,18 @@ def extract_shape( - arg: tuple | list | relax.Tuple | relax.ShapeStructInfo, -) -> list[relax.ShapeStructInfo]: + arg: tuple | list | relax.Tuple | relax.ShapeType, +) -> list[relax.ShapeType]: """Extract shape information from a relax argument. Parameters ---------- - arg : Union[Tuple, List, relax.Tuple, relax.ShapeStructInfo] + arg : Union[Tuple, List, relax.Tuple, relax.ShapeType] The relax argument to be extracted. Returns ------- - result : List[relax.ShapeStructInfo] + result : List[relax.ShapeType] The extracted shape information. """ if isinstance(arg, tuple | list | tvm.relax.Tuple): @@ -85,7 +85,7 @@ def extract_shape( for sub_arg in arg: results.extend(extract_shape(sub_arg)) return results - return [arg.struct_info] + return [arg.ty] def extract_dynamic_var( @@ -122,16 +122,16 @@ def extract_dynamic_var( for arg_list, _ in func_dict[gv][functor]: flattened_arg_list = [] for arg in arg_list: - if isinstance(arg, relax.TupleStructInfo): + if isinstance(arg, relax.TupleType): flattened_arg_list.extend(arg.fields) else: flattened_arg_list.append(arg) for arg in flattened_arg_list: - if isinstance(arg, relax.TensorStructInfo): + if isinstance(arg, relax.TensorType): for val in arg.shape.values: if isinstance(val, tvm.tirx.Var): dym_var_dict[gv][str(val)] = val.dtype - elif isinstance(arg, relax.ShapeStructInfo): + elif isinstance(arg, relax.ShapeType): for val in arg.values: if isinstance(val, tvm.tirx.Var): dym_var_dict[gv][str(val)] = val.dtype @@ -141,15 +141,15 @@ def extract_dynamic_var( def update_records( - records: dict[list[relax.ShapeStructInfo], int], new_args: list[relax.ShapeStructInfo] + records: dict[list[relax.ShapeType], int], new_args: list[relax.ShapeType] ) -> None: """Update the count of a function input argument config. Parameters ---------- - records : Dict[List[relax.ShapeStructInfo], int] + records : Dict[List[relax.ShapeType], int] The dictionary to count how many times a function input argument config appears. - new_args : List[relax.ShapeStructInfo] + new_args : List[relax.ShapeType] The new input argument config. """ for i, (args, count) in enumerate(records): diff --git a/python/tvm/s_tir/dlight/benchmark/utils.py b/python/tvm/s_tir/dlight/benchmark/utils.py index e25a5968499c..4d832756d96b 100644 --- a/python/tvm/s_tir/dlight/benchmark/utils.py +++ b/python/tvm/s_tir/dlight/benchmark/utils.py @@ -57,7 +57,7 @@ def dym_var_sample_str(sample: dict[str | tvm.relax.expr.Call, int]) -> str: def populuate_input_shape( - input_infos: list[relax.TensorStructInfo | tuple[tuple[int | str, ...], str]], + input_infos: list[relax.TensorType | tuple[tuple[int | str, ...], str]], dym_var_sample: dict[str, int], ) -> INPUT_SHAPE_TYPE: """ @@ -65,7 +65,7 @@ def populuate_input_shape( Parameters ---------- - input_infos : List[Union[relax.TensorStructInfo, Tuple[Tuple[Union[int, str], ...], str]]] + input_infos : List[Union[relax.TensorType, Tuple[Tuple[Union[int, str], ...], str]]] Input tensor information, including shape and dtype, e.g., [..., Shape(1, n, 128) with dtype="int32", ...] dym_var_sample : Dict[str, int] @@ -81,11 +81,11 @@ def populuate_input_shape( results: INPUT_SHAPE_TYPE = [] for input_info in input_infos: shape = [] - if isinstance(input_info, relax.struct_info.ShapeStructInfo): + if isinstance(input_info, relax.ShapeType): # scalar input results.append(((dym_var_sample[str(input_info.values[0])],), "scalar")) else: - if isinstance(input_info, relax.TensorStructInfo): + if isinstance(input_info, relax.TensorType): tensor_shape = input_info.shape tensor_dtype = input_info.dtype else: diff --git a/python/tvm/tirx/op.py b/python/tvm/tirx/op.py index aeeb68ca2f19..a7a2889c444b 100644 --- a/python/tvm/tirx/op.py +++ b/python/tvm/tirx/op.py @@ -474,10 +474,10 @@ def call_tir(global_var: tvm.ir.GlobalVar, *args): assert isinstance(global_var, tvm.ir.GlobalVar) dtype = "void" - if global_var.struct_info is not None: - ret_sinfo = global_var.struct_info.ret - if hasattr(ret_sinfo, "dtype"): - dtype = ret_sinfo.dtype + if global_var.ty is not None: + ret_ty = global_var.ty.ret + if hasattr(ret_ty, "dtype"): + dtype = ret_ty.dtype return Call(dtype=dtype, op=global_var, args=args) diff --git a/src/ir/module.cc b/src/ir/module.cc index 156ca17c1255..2d31692788ba 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -29,7 +29,6 @@ #include #include #include -#include #include #include diff --git a/src/ir/type.cc b/src/ir/type.cc index 2de056809a3a..d6d059dba079 100644 --- a/src/ir/type.cc +++ b/src/ir/type.cc @@ -92,7 +92,8 @@ TupleType TupleType::Empty() { return TupleType(ffi::Array()); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def("ir.TupleType", [](ffi::Array fields) { return TupleType(fields); }) + .def("ir.TupleType", + [](ffi::Array fields, Span span) { return TupleType(fields, span); }) .def("ir.TensorMapType", [](Span span) { return TensorMapType(span); }); } diff --git a/src/ir/type_functor.cc b/src/ir/type_functor.cc deleted file mode 100644 index 699a3d97da1c..000000000000 --- a/src/ir/type_functor.cc +++ /dev/null @@ -1,93 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file type_functor.cc - * \brief Implementations of type functors. - */ -#include -#include - -#include - -namespace tvm { - -void TypeVisitor::VisitType_(const FuncTypeNode* op) { - for (auto arg_type : op->arg_types) { - this->VisitType(arg_type); - } - this->VisitType(op->ret_type); -} - -void TypeVisitor::VisitType_(const TupleTypeNode* op) { - for (const Type& t : op->fields) { - this->VisitType(t); - } -} - -void TypeVisitor::VisitType_(const PrimTypeNode* op) {} - -void TypeVisitor::VisitType_(const PointerTypeNode* op) { this->VisitType(op->element_type); } - -Type TypeMutator::VisitType(const Type& t) { - return t.defined() ? TypeFunctor::VisitType(t) : t; -} - -// Type Mutator. -ffi::Array TypeMutator::MutateArray(ffi::Array arr) { - // The array will do copy on write - // If no changes are made, the original array will be returned. - return arr.Map([this](const Type& ty) { return VisitType(ty); }); -} - -Type TypeMutator::VisitType_(const FuncTypeNode* op) { - bool changed = false; - - ffi::Array new_args = MutateArray(op->arg_types); - changed = changed || !new_args.same_as(op->arg_types); - - Type new_ret_type = VisitType(op->ret_type); - changed = changed || !new_ret_type.same_as(op->ret_type); - - if (!changed) return ffi::GetRef(op); - return FuncType(new_args, new_ret_type); -} - -Type TypeMutator::VisitType_(const TupleTypeNode* op) { - ffi::Array new_fields = MutateArray(op->fields); - if (new_fields.same_as(op->fields)) { - return ffi::GetRef(op); - } else { - return TupleType(new_fields); - } -} - -Type TypeMutator::VisitType_(const PrimTypeNode* op) { return ffi::GetRef(op); } - -Type TypeMutator::VisitType_(const PointerTypeNode* op) { - Type element_type = VisitType(op->element_type); - - if (element_type.same_as(op->element_type)) { - return ffi::GetRef(op); - } else { - return PointerType(element_type, op->storage_scope); - } -} - -} // namespace tvm diff --git a/src/relax/analysis/analysis.cc b/src/relax/analysis/analysis.cc index d48260b22917..9d9318992dc4 100644 --- a/src/relax/analysis/analysis.cc +++ b/src/relax/analysis/analysis.cc @@ -111,8 +111,8 @@ class VarVisitor : protected ExprVisitor { VisitSpan(call_node->span); VisitExpr(call_node->op); - for (StructInfo sinfo_arg : call_node->sinfo_args) { - VisitExprDepStructInfoField(sinfo_arg); + for (Type ty_arg : call_node->ty_args) { + VisitExprDepTypeField(ty_arg); } for (Expr arg : call_node->args) { diff --git a/src/relax/analysis/computable_at_compile_time.cc b/src/relax/analysis/computable_at_compile_time.cc index 0d7da4317b82..3610fb49bb85 100644 --- a/src/relax/analysis/computable_at_compile_time.cc +++ b/src/relax/analysis/computable_at_compile_time.cc @@ -79,7 +79,7 @@ class CompileTimeCollector : ExprVisitor { void MarkAsKnown(const Var& var) { known_relax_vars_.insert(var); - for (const auto& tir_var : DefinableTIRVarsInStructInfo(GetStructInfo(var))) { + for (const auto& tir_var : DefinableTIRVarsInType(GetType(var))) { known_tir_vars_.insert(tir_var); } } diff --git a/src/relax/analysis/struct_info_analysis.cc b/src/relax/analysis/type_analysis.cc similarity index 66% rename from src/relax/analysis/struct_info_analysis.cc rename to src/relax/analysis/type_analysis.cc index 932e7efeedfa..b6c272a827b5 100644 --- a/src/relax/analysis/struct_info_analysis.cc +++ b/src/relax/analysis/type_analysis.cc @@ -18,17 +18,17 @@ */ /*! - * \file struct_info_analysis.cc - * \brief Implementations of foundation struct info analysis + * \file type_analysis.cc + * \brief Implementations of foundational Relax type analysis. * - * \note Update this file when you added a new StructInfo. + * \note Update this file when you added a new Type. */ #include #include #include #include #include -#include +#include #include #include #include @@ -39,93 +39,95 @@ namespace relax { //-------------------------- // GetStaticType //-------------------------- -class StaticTypeDeriver : public StructInfoFunctor { +class StaticTypeDeriver : public TypeFunctor { public: - Type VisitStructInfo_(const ObjectStructInfoNode* op) final { return ObjectType(op->span); } + Type VisitType_(const ObjectTypeNode* op) final { return ObjectType(op->span); } - Type VisitStructInfo_(const PrimStructInfoNode* op) final { - return PrimType(op->dtype, op->span); - } + Type VisitType_(const PrimTypeNode* op) final { return PrimType(op->dtype, op->span); } - Type VisitStructInfo_(const ShapeStructInfoNode* op) final { - return ShapeType(op->ndim, op->span); - } + Type VisitType_(const ShapeTypeNode* op) final { return ShapeType(op->ndim, op->span); } - Type VisitStructInfo_(const TensorStructInfoNode* op) final { - return TensorType(op->ndim, op->dtype); + Type VisitType_(const TensorTypeNode* op) final { + return TensorType(op->dtype, op->ndim, op->vdevice, op->span); } // module: distributed - Type VisitStructInfo_(const distributed::DTensorStructInfoNode* op) final { return ObjectType(); } + Type VisitType_(const distributed::DTensorTypeNode* op) final { return ObjectType(); } // end-module: distributed - Type VisitStructInfo_(const TupleStructInfoNode* op) final { + Type VisitType_(const TupleTypeNode* op) final { ffi::Array fields = - op->fields.Map([this](const StructInfo& sinfo) { return this->VisitStructInfo(sinfo); }); + op->fields.Map([this](const Type& ty) { return this->VisitType(ty); }); return TupleType(fields, op->span); } - Type VisitStructInfo_(const FuncStructInfoNode* op) final { + Type VisitType_(const FuncTypeNode* op) final { if (op->IsOpaque()) return PackedFuncType(op->span); - ffi::Array params = op->params.value().Map( - [this](const StructInfo& sinfo) { return this->VisitStructInfo(sinfo); }); - Type ret = this->VisitStructInfo(op->ret); - return FuncType(params, ret, op->span); + ffi::Array params = + op->params.value().Map([this](const Type& ty) { return this->VisitType(ty); }); + Type ret = this->VisitType(op->ret); + return FuncType(params, ret, op->purity, op->span); } }; -Type GetStaticType(const StructInfo& info) { return StaticTypeDeriver()(info); } +Type GetStaticType(const Type& info) { return StaticTypeDeriver()(info); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.analysis.GetStaticType", - [](const StructInfo& info) { return GetStaticType(info); }); + [](const Type& info) { return GetStaticType(info); }); } //-------------------------- -// StructInfoFromType +// TypeFromStaticType //-------------------------- -StructInfo StructInfoFromType(const Type& type) { +Type TypeFromStaticType(const Type& type) { if (type.as()) { - return ObjectStructInfo(type->span); + return ObjectType(type->span); } else if (const PrimTypeNode* prim_type = type.as()) { - return PrimStructInfo(prim_type->dtype, prim_type->span); + return PrimType(prim_type->dtype, prim_type->span); + } else if (const tvm::PrimTypeNode* prim_type = type.as()) { + return PrimType(prim_type->dtype, prim_type->span); } else if (const ShapeTypeNode* shape_type = type.as()) { - return ShapeStructInfo(shape_type->ndim, type->span); + return ShapeType(shape_type->ndim, type->span); } else if (const TensorTypeNode* tensor_type = type.as()) { - return TensorStructInfo(tensor_type->dtype, tensor_type->ndim); + return TensorType(tensor_type->dtype, tensor_type->ndim); } else if (const TupleTypeNode* tuple_type = type.as()) { - ffi::Array fields; + ffi::Array fields; for (const Type& field : tuple_type->fields) { - fields.push_back(StructInfoFromType(field)); + fields.push_back(TypeFromStaticType(field)); } - return TupleStructInfo(fields, type->span); + return TupleType(fields, type->span); } else if (const FuncTypeNode* func_type = type.as()) { - ffi::Array params = - func_type->arg_types.Map([](const Type& param) { return StructInfoFromType(param); }); - StructInfo ret = StructInfoFromType(func_type->ret_type); + if (func_type->IsOpaque()) return FuncType::OpaqueFunc(func_type->ret, func_type->purity); + ffi::Array params = + func_type->params.value().Map([](const Type& param) { return TypeFromStaticType(param); }); + Type ret = TypeFromStaticType(func_type->ret); + return FuncType(params, ret, func_type->purity, func_type->span); + } else if (const tvm::FuncTypeNode* func_type = type.as()) { + ffi::Array params = + func_type->arg_types.Map([](const Type& param) { return TypeFromStaticType(param); }); + Type ret = TypeFromStaticType(func_type->ret_type); // TODO(relax-team): Maybe add purity into the type as well - return FuncStructInfo(params, ret, true, func_type->span); + return FuncType(params, ret, true, func_type->span); } else { TVM_FFI_THROW(InternalError) << "Unsupported type: " << type; - return StructInfo(); + return Type(); } } //-------------------------- // EraseToWellDefined //-------------------------- -class WellDefinedEraser : public StructInfoMutator, - public ExprMutatorBase, - public tirx::ExprMutator { +class WellDefinedEraser : public TypeMutator, public ExprMutatorBase, public tirx::ExprMutator { public: WellDefinedEraser(std::function(const tirx::Var& var)> f_shape_var_map, std::function(const Var& var)> f_var_map, arith::AnalyzerObj* ana) : f_shape_var_map_(f_shape_var_map), f_var_map_(f_var_map), ana_(ana) {} - StructInfo VisitStructInfo_(const PrimStructInfoNode* op) final { + Type VisitType_(const PrimTypeNode* op) final { bool has_undefined = false; ffi::Optional value; @@ -138,16 +140,16 @@ class WellDefinedEraser : public StructInfoMutator, // erase symbolic shape if we have undefined. if (!has_undefined) { if (value.same_as(op->value)) { - return ffi::GetRef(op); + return ffi::GetRef(op); } else { - return PrimStructInfo(value.value(), op->span); + return PrimType(value.value(), op->span); } } else { - return PrimStructInfo(op->dtype, op->span); + return PrimType(op->dtype, op->span); } } - StructInfo VisitStructInfo_(const ShapeStructInfoNode* op) final { + Type VisitType_(const ShapeTypeNode* op) final { bool has_undefined = false; ffi::Optional> values; @@ -159,16 +161,16 @@ class WellDefinedEraser : public StructInfoMutator, // erase symbolic shape if we have undefined. if (!has_undefined) { if (values.same_as(op->values)) { - return ffi::GetRef(op); + return ffi::GetRef(op); } else { - return ShapeStructInfo(values.value(), op->span); + return ShapeType(values.value(), op->span); } } else { - return ShapeStructInfo(op->ndim, op->span); + return ShapeType(op->ndim, op->span); } } - StructInfo VisitStructInfo_(const TensorStructInfoNode* op) final { + Type VisitType_(const TensorTypeNode* op) final { bool has_undefined = false; ffi::Optional shape; @@ -183,25 +185,25 @@ class WellDefinedEraser : public StructInfoMutator, // erase symbolic shape if we have undefined. if (!has_undefined) { if (shape.same_as(op->shape)) { - return ffi::GetRef(op); + return ffi::GetRef(op); } else { if (shape.defined()) { - return TensorStructInfo(shape.value(), op->dtype, vdev, op->span); + return TensorType(shape.value(), op->dtype, vdev, op->span); } else { - return TensorStructInfo(op->dtype, op->ndim, vdev, op->span); + return TensorType(op->dtype, op->ndim, vdev, op->span); } } } else { - return TensorStructInfo(op->dtype, op->ndim, vdev, op->span); + return TensorType(op->dtype, op->ndim, vdev, op->span); } } - StructInfo VisitStructInfo_(const FuncStructInfoNode* op) final { - // NOTE: we always require func struct info to be well-defined. + Type VisitType_(const FuncTypeNode* op) final { + // NOTE: we always require func type to be well-defined. // // All the occuring symbolic variables are defined in parameters' - // struct info annotations. So there is no needed to erase. - return ffi::GetRef(op); + // type annotations. So there is no needed to erase. + return ffi::GetRef(op); } using relax::ExprMutatorBase::VisitExpr_; @@ -226,7 +228,7 @@ class WellDefinedEraser : public StructInfoMutator, has_undefined_ = has_undefined_ || !ret.defined(); if (ret.defined()) { TVM_FFI_ICHECK(ret.as() || ret.as()) - << "Only allow Expr in StructInfo to be ShapeExpr or Var"; + << "Only allow Expr in Type to be ShapeExpr or Var"; } return ret.value_or(ffi::GetRef(var)); } @@ -258,29 +260,27 @@ class WellDefinedEraser : public StructInfoMutator, arith::AnalyzerObj* ana_; }; -StructInfo EraseToWellDefined( - const StructInfo& info, - std::function(const tirx::Var& var)> f_shape_var_map, +Type EraseToWellDefined( + const Type& info, std::function(const tirx::Var& var)> f_shape_var_map, std::function(const Var& var)> f_var_map) { arith::Analyzer analyzer; return EraseToWellDefined(info, f_shape_var_map, f_var_map, analyzer); } -StructInfo EraseToWellDefined( - const StructInfo& info, - std::function(const tirx::Var& var)> f_shape_var_map, +Type EraseToWellDefined( + const Type& info, std::function(const tirx::Var& var)> f_shape_var_map, std::function(const Var& var)> f_var_map, const arith::Analyzer& ana) { - return WellDefinedEraser(f_shape_var_map, f_var_map, ana.get()).VisitStructInfo(info); + return WellDefinedEraser(f_shape_var_map, f_var_map, ana.get()).VisitType(info); } -StructInfo EraseToWellDefined(const StructInfo& info, ffi::Map shape_var_map, - ffi::Map var_map) { +Type EraseToWellDefined(const Type& info, ffi::Map shape_var_map, + ffi::Map var_map) { arith::Analyzer analyzer; return EraseToWellDefined(info, shape_var_map, var_map, analyzer); } -StructInfo EraseToWellDefined(const StructInfo& info, ffi::Map shape_var_map, - ffi::Map var_map, const arith::Analyzer& ana) { +Type EraseToWellDefined(const Type& info, ffi::Map shape_var_map, + ffi::Map var_map, const arith::Analyzer& ana) { std::function(const tirx::Var& var)> f_shape_var_map = nullptr; std::function(const Var& var)> f_var_map = nullptr; @@ -307,34 +307,33 @@ TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def( "relax.analysis.EraseToWellDefined", - [](const StructInfo& info, ffi::Map shape_var_map, + [](const Type& info, ffi::Map shape_var_map, ffi::Map var_map) { return EraseToWellDefined(info, shape_var_map, var_map); }); } //-------------------------- // IsBaseOf //-------------------------- -class StructInfoBaseChecker - : public StructInfoFunctor { +class TypeBaseChecker : public TypeFunctor { public: - explicit StructInfoBaseChecker(arith::AnalyzerObj* ana) : analyzer_(ana) {} + explicit TypeBaseChecker(arith::AnalyzerObj* ana) : analyzer_(ana) {} - BaseCheckResult VisitStructInfo(const StructInfo& lhs, const StructInfo& other) override { + BaseCheckResult VisitType(const Type& lhs, const Type& other) override { // quick path - // Note: subclass may disable this quick path if we need to go over all struct info. + // Note: subclass may disable this quick path if we need to go over all type. if (lhs.same_as(other)) return BaseCheckResult::kPass; - return StructInfoFunctor::VisitStructInfo(lhs, other); + return TypeFunctor::VisitType(lhs, other); } // ffi::Object is base of everything - BaseCheckResult VisitStructInfo_(const ObjectStructInfoNode* lhs, const StructInfo& other) final { + BaseCheckResult VisitType_(const ObjectTypeNode* lhs, const Type& other) final { return BaseCheckResult::kPass; } - BaseCheckResult VisitStructInfo_(const PrimStructInfoNode* lhs, const StructInfo& other) final { - auto* rhs = other.as(); + BaseCheckResult VisitType_(const PrimTypeNode* lhs, const Type& other) final { + auto* rhs = other.as(); if (rhs == nullptr) { - if (other.as()) return BaseCheckResult::kFailL1; + if (other.as()) return BaseCheckResult::kFailL1; return BaseCheckResult::kFailL0; } @@ -348,10 +347,10 @@ class StructInfoBaseChecker return PrimValueMatchCheck(lhs->value.value(), rhs->value.value()); } - BaseCheckResult VisitStructInfo_(const ShapeStructInfoNode* lhs, const StructInfo& other) final { - auto* rhs = other.as(); + BaseCheckResult VisitType_(const ShapeTypeNode* lhs, const Type& other) final { + auto* rhs = other.as(); if (rhs == nullptr) { - if (other.as()) return BaseCheckResult::kFailL1; + if (other.as()) return BaseCheckResult::kFailL1; return BaseCheckResult::kFailL0; } // lhs have unknown ndim @@ -372,10 +371,10 @@ class StructInfoBaseChecker return ShapeMatchCheck(lhs->values.value(), rhs->values.value()); } - BaseCheckResult VisitStructInfo_(const TensorStructInfoNode* lhs, const StructInfo& other) final { - auto* rhs = other.as(); + BaseCheckResult VisitType_(const TensorTypeNode* lhs, const Type& other) final { + auto* rhs = other.as(); if (rhs == nullptr) { - if (other.as()) return BaseCheckResult::kFailL1; + if (other.as()) return BaseCheckResult::kFailL1; return BaseCheckResult::kFailL0; } // dtype mismatch @@ -415,15 +414,13 @@ class StructInfoBaseChecker } // module: distributed - BaseCheckResult VisitStructInfo_(const distributed::DTensorStructInfoNode* lhs, - const StructInfo& other) final { - auto* rhs = other.as(); + BaseCheckResult VisitType_(const distributed::DTensorTypeNode* lhs, const Type& other) final { + auto* rhs = other.as(); if (rhs == nullptr) { - if (other.as()) return BaseCheckResult::kFailL1; + if (other.as()) return BaseCheckResult::kFailL1; return BaseCheckResult::kFailL0; } - BaseCheckResult tensor_sinfo_check_result = - this->VisitStructInfo(lhs->tensor_sinfo, rhs->tensor_sinfo); + BaseCheckResult tensor_ty_check_result = this->VisitType(lhs->tensor_ty, rhs->tensor_ty); BaseCheckResult other_check_result; if (!struct_equal_(lhs->device_mesh, rhs->device_mesh) || !struct_equal_(lhs->placement, rhs->placement)) { @@ -431,24 +428,23 @@ class StructInfoBaseChecker } else { other_check_result = BaseCheckResult::kPass; } - return CombineCheck(tensor_sinfo_check_result, other_check_result); + return CombineCheck(tensor_ty_check_result, other_check_result); } // end-module: distributed - BaseCheckResult VisitStructInfo_(const TupleStructInfoNode* lhs, const StructInfo& other) final { - auto* rhs = other.as(); + BaseCheckResult VisitType_(const TupleTypeNode* lhs, const Type& other) final { + auto* rhs = other.as(); if (rhs == nullptr) { - if (other.as()) return BaseCheckResult::kFailL1; + if (other.as()) return BaseCheckResult::kFailL1; return BaseCheckResult::kFailL0; } return ArrayCheck(lhs->fields, rhs->fields); } - BaseCheckResult VisitStructInfo_(const FuncStructInfoNode* lhs, - const StructInfo& other) override { - auto* rhs = other.as(); + BaseCheckResult VisitType_(const FuncTypeNode* lhs, const Type& other) override { + auto* rhs = other.as(); if (rhs == nullptr) { - if (other.as()) return BaseCheckResult::kFailL1; + if (other.as()) return BaseCheckResult::kFailL1; return BaseCheckResult::kFailL0; } @@ -465,7 +461,7 @@ class StructInfoBaseChecker : BaseCheckResult::kFailL2; } // no derivation function, only depends on ret - return this->VisitStructInfo(lhs->ret, rhs->ret); + return this->VisitType(lhs->ret, rhs->ret); } // Function check is best effort. @@ -486,10 +482,10 @@ class StructInfoBaseChecker // // Given we only do best effort checking in these cases, and such cases // are likely not a primary concern atm, we take this approach here. - if (struct_equal_(ffi::GetRef(lhs), other)) return BaseCheckResult::kPass; + if (struct_equal_(ffi::GetRef(lhs), other)) return BaseCheckResult::kPass; auto param_check = FuncParamsCheck(lhs->params.value(), rhs->params.value()); - auto ret_check = this->VisitStructInfo(lhs->ret, rhs->ret); + auto ret_check = this->VisitType(lhs->ret, rhs->ret); return CombineCheck(param_check, ret_check); } @@ -561,8 +557,8 @@ class StructInfoBaseChecker * \param rhs The right hand params. * \return Check result. */ - virtual BaseCheckResult FuncParamsCheck(const ffi::Array& lhs, - const ffi::Array& rhs) { + virtual BaseCheckResult FuncParamsCheck(const ffi::Array& lhs, + const ffi::Array& rhs) { auto res = ArrayCheck(lhs, rhs); // treat L1 failures in params checking as L2. if (res == BaseCheckResult::kFailL1) res = BaseCheckResult::kFailL2; @@ -593,12 +589,12 @@ class StructInfoBaseChecker * \param lhs The left operand. * \param rhs The right operand. */ - BaseCheckResult ArrayCheck(const ffi::Array& lhs, const ffi::Array& rhs) { + BaseCheckResult ArrayCheck(const ffi::Array& lhs, const ffi::Array& rhs) { if (lhs.size() != rhs.size()) return BaseCheckResult::kFailL0; BaseCheckResult ret = BaseCheckResult::kPass; for (size_t i = 0; i < lhs.size(); ++i) { - auto cmp_ret = this->VisitStructInfo(lhs[i], rhs[i]); + auto cmp_ret = this->VisitType(lhs[i], rhs[i]); if (ret == BaseCheckResult::kFailL0) return ret; ret = CombineCheck(cmp_ret, ret); } @@ -606,60 +602,58 @@ class StructInfoBaseChecker } }; -BaseCheckResult StructInfoBaseCheck(const StructInfo& base, const StructInfo& derived) { +BaseCheckResult TypeBaseCheck(const Type& base, const Type& derived) { arith::Analyzer analyzer; - return StructInfoBaseCheck(base, derived, analyzer); + return TypeBaseCheck(base, derived, analyzer); } -BaseCheckResult StructInfoBaseCheck(const StructInfo& base, const StructInfo& derived, - const arith::Analyzer& ana) { - return StructInfoBaseChecker(ana.get())(base, derived); +BaseCheckResult TypeBaseCheck(const Type& base, const Type& derived, const arith::Analyzer& ana) { + return TypeBaseChecker(ana.get())(base, derived); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("relax.analysis.StructInfoBaseCheck", - [](const StructInfo& base, const StructInfo& derived) -> int { - return static_cast(StructInfoBaseCheck(base, derived)); + refl::GlobalDef().def("relax.analysis.TypeBaseCheck", + [](const Type& base, const Type& derived) -> int { + return static_cast(TypeBaseCheck(base, derived)); }); } -bool IsBaseOf(const StructInfo& base, const StructInfo& derived) { +bool IsBaseOf(const Type& base, const Type& derived) { arith::Analyzer analyzer; return IsBaseOf(base, derived, analyzer); } -bool IsBaseOf(const StructInfo& base, const StructInfo& derived, const arith::Analyzer& ana) { - return StructInfoBaseCheck(base, derived, ana) == BaseCheckResult::kPass; +bool IsBaseOf(const Type& base, const Type& derived, const arith::Analyzer& ana) { + return TypeBaseCheck(base, derived, ana) == BaseCheckResult::kPass; } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def( - "relax.StructInfoIsBaseOf", - [](const StructInfo& base, const StructInfo& derived) { return IsBaseOf(base, derived); }); + refl::GlobalDef().def("relax.TypeIsBaseOf", [](const Type& base, const Type& derived) { + return IsBaseOf(base, derived); + }); } -class StructInfoBasePreconditionCollector - : public StructInfoFunctor { +class TypeBasePreconditionCollector : public TypeFunctor { public: - explicit StructInfoBasePreconditionCollector() {} + explicit TypeBasePreconditionCollector() {} - PrimExpr VisitStructInfo(const StructInfo& lhs, const StructInfo& other) override { + PrimExpr VisitType(const Type& lhs, const Type& other) override { if (lhs.same_as(other)) { - // Early bail-out if the StructInfo has reference equality. + // Early bail-out if the Type has reference equality. return IntImm::Bool(true); } else { - return StructInfoFunctor::VisitStructInfo(lhs, other); + return TypeFunctor::VisitType(lhs, other); } } - PrimExpr VisitStructInfo_(const ObjectStructInfoNode* lhs, const StructInfo& other) final { + PrimExpr VisitType_(const ObjectTypeNode* lhs, const Type& other) final { return IntImm::Bool(true); } - PrimExpr VisitStructInfo_(const PrimStructInfoNode* lhs, const StructInfo& other) final { - auto* rhs = other.as(); + PrimExpr VisitType_(const PrimTypeNode* lhs, const Type& other) final { + auto* rhs = other.as(); if (rhs == nullptr) { return IntImm::Bool(false); } @@ -677,8 +671,8 @@ class StructInfoBasePreconditionCollector } } - PrimExpr VisitStructInfo_(const ShapeStructInfoNode* lhs, const StructInfo& other) final { - auto* rhs = other.as(); + PrimExpr VisitType_(const ShapeTypeNode* lhs, const Type& other) final { + auto* rhs = other.as(); if (rhs == nullptr) { return IntImm::Bool(false); } @@ -701,8 +695,8 @@ class StructInfoBasePreconditionCollector } } - PrimExpr VisitStructInfo_(const TensorStructInfoNode* lhs, const StructInfo& other) final { - auto* rhs = other.as(); + PrimExpr VisitType_(const TensorTypeNode* lhs, const Type& other) final { + auto* rhs = other.as(); if (rhs == nullptr) { return IntImm::Bool(false); } @@ -752,9 +746,8 @@ class StructInfoBasePreconditionCollector return IntImm::Bool(true); } - PrimExpr VisitStructInfo_(const distributed::DTensorStructInfoNode* lhs, - const StructInfo& other) final { - auto* rhs = other.as(); + PrimExpr VisitType_(const distributed::DTensorTypeNode* lhs, const Type& other) final { + auto* rhs = other.as(); if (rhs == nullptr) { return IntImm::Bool(false); } @@ -765,19 +758,19 @@ class StructInfoBasePreconditionCollector return IntImm::Bool(false); } - return this->VisitStructInfo(lhs->tensor_sinfo, rhs->tensor_sinfo); + return this->VisitType(lhs->tensor_ty, rhs->tensor_ty); } - PrimExpr VisitStructInfo_(const TupleStructInfoNode* lhs, const StructInfo& other) final { - auto* rhs = other.as(); + PrimExpr VisitType_(const TupleTypeNode* lhs, const Type& other) final { + auto* rhs = other.as(); if (rhs == nullptr) { return IntImm::Bool(false); } return ArrayCheck(lhs->fields, rhs->fields); } - PrimExpr VisitStructInfo_(const FuncStructInfoNode* lhs, const StructInfo& other) override { - auto* rhs = other.as(); + PrimExpr VisitType_(const FuncTypeNode* lhs, const Type& other) override { + auto* rhs = other.as(); if (rhs == nullptr) { return IntImm::Bool(false); } @@ -794,7 +787,7 @@ class StructInfoBasePreconditionCollector return IntImm::Bool(false); } - PrimExpr all_match = VisitStructInfo(lhs->ret, rhs->ret); + PrimExpr all_match = VisitType(lhs->ret, rhs->ret); PrimExpr param_check; if (lhs->params.defined()) { @@ -803,7 +796,7 @@ class StructInfoBasePreconditionCollector param_check = IntImm::Bool(true); } - PrimExpr ret_check = VisitStructInfo(lhs->ret, rhs->ret); + PrimExpr ret_check = VisitType(lhs->ret, rhs->ret); return param_check && ret_check; } @@ -821,7 +814,7 @@ class StructInfoBasePreconditionCollector return all_equal; } - PrimExpr ArrayCheck(const ffi::Array& lhs, const ffi::Array& rhs) { + PrimExpr ArrayCheck(const ffi::Array& lhs, const ffi::Array& rhs) { if (lhs.size() != rhs.size()) { return IntImm::Bool(false); } @@ -829,33 +822,33 @@ class StructInfoBasePreconditionCollector PrimExpr all_pass = IntImm::Bool(true); for (size_t i = 0; i < lhs.size(); ++i) { - all_pass = all_pass && VisitStructInfo(lhs[i], rhs[i]); + all_pass = all_pass && VisitType(lhs[i], rhs[i]); } return all_pass; } }; -PrimExpr StructInfoBaseCheckPrecondition(const StructInfo& base, const StructInfo& derived) { - StructInfoBasePreconditionCollector visitor; +PrimExpr TypeBaseCheckPrecondition(const Type& base, const Type& derived) { + TypeBasePreconditionCollector visitor; return visitor(base, derived); } //-------------------------- -// DeriveStructInfo +// DeriveType //-------------------------- -// NOTE: we are reusing StructInfoBaseChecker here to populate a mapping +// NOTE: we are reusing TypeBaseChecker here to populate a mapping // from the expressions in arg(rhs) to var in param. -class CallRetStructInfoDeriver : public StructInfoBaseChecker { +class CallRetTypeDeriver : public TypeBaseChecker { public: - explicit CallRetStructInfoDeriver(arith::AnalyzerObj* ana) : StructInfoBaseChecker(ana) {} + explicit CallRetTypeDeriver(arith::AnalyzerObj* ana) : TypeBaseChecker(ana) {} // No short cut, so we can recursively populate all pairs. - BaseCheckResult VisitStructInfo(const StructInfo& lhs, const StructInfo& other) final { - return StructInfoFunctor::VisitStructInfo(lhs, other); + BaseCheckResult VisitType(const Type& lhs, const Type& other) final { + return TypeFunctor::VisitType(lhs, other); } - StructInfo Derive(const FuncStructInfo& finfo, const Call& call, const BlockBuilder& ctx) { + Type Derive(const FuncType& finfo, const Call& call, const BlockBuilder& ctx) { // opaque derivation if (finfo->IsOpaque()) { if (finfo->derive_func.defined()) { @@ -872,22 +865,22 @@ class CallRetStructInfoDeriver : public StructInfoBaseChecker { if (params.size() != call->args.size()) { TVM_FFI_VISIT_THROW(ValueError, call) << "Number of arguments and parameters mismatch:" - << " Function " << call->op << " has struct info " << finfo << " and accepts " - << params.size() << " parameters, but was called with " << call->args.size() - << " arguments (" << call->args << ")"; + << " Function " << call->op << " has type " << finfo << " and accepts " << params.size() + << " parameters, but was called with " << call->args.size() << " arguments (" + << call->args << ")"; } // Visit each param arg pair, check and populate the var map for (size_t i = 0; i < params.size(); ++i) { TVM_FFI_VISIT_BEGIN(); - auto arg_sinfo = GetStructInfo(call->args[i]); - BaseCheckResult res = this->VisitStructInfo(params[i], arg_sinfo); + auto arg_ty = GetType(call->args[i]); + BaseCheckResult res = this->VisitType(params[i], arg_ty); // Report error if we find L1 level failure // L2 level is best effort so we don't report. // The behavior of L2 can be customized later. if (res == BaseCheckResult::kFailL0 || res == BaseCheckResult::kFailL1) { TVM_FFI_VISIT_THROW(ValueError, call->args[i]) << "Argument " << i << " type mismatch:" - << " expected " << params[i] << ", given " << arg_sinfo; + << " expected " << params[i] << ", given " << arg_ty; } TVM_FFI_VISIT_END(call->args[i]); } @@ -902,12 +895,12 @@ class CallRetStructInfoDeriver : public StructInfoBaseChecker { ffi::Map shape_var_map_; ffi::Map var_map_; - using StructInfoBaseChecker::ShapeMatchCheck; + using TypeBaseChecker::ShapeMatchCheck; // Match shape values in between param(lhs) and arg(rhs) BaseCheckResult PrimValueMatchCheck(const PrimExpr& param, const PrimExpr& arg) final { if (!populate_mapping_) { - return StructInfoBaseChecker::PrimValueMatchCheck(param, arg); + return TypeBaseChecker::PrimValueMatchCheck(param, arg); } if (auto* ptr = param.as()) { @@ -928,13 +921,13 @@ class CallRetStructInfoDeriver : public StructInfoBaseChecker { // Do not attempt to do prove when param contains a symbolic expr. // such expression might depends on a later defined var in params created by dyn fusion. // example: f(a: Tensor[(n+1)], s: Shape[(n,)]), the (n+1) case here. - return StructInfoBaseChecker::PrimValueMatchCheck(param, arg); + return TypeBaseChecker::PrimValueMatchCheck(param, arg); } } BaseCheckResult ShapeMatchCheck(const Expr& lhs, const Expr& rhs) final { if (!populate_mapping_) { - return StructInfoBaseChecker::ShapeMatchCheck(lhs, rhs); + return TypeBaseChecker::ShapeMatchCheck(lhs, rhs); } if (auto* ptr = lhs.as()) { @@ -960,8 +953,7 @@ class CallRetStructInfoDeriver : public StructInfoBaseChecker { return ShapeMatchCheck(lhs_shape->values, rhs_shape->values); } - BaseCheckResult FuncParamsCheck(const ffi::Array& lhs, - const ffi::Array& rhs) final { + BaseCheckResult FuncParamsCheck(const ffi::Array& lhs, const ffi::Array& rhs) final { // Set populate mapping to false // so we do not pick up symbolic vars in params with function type. // @@ -974,60 +966,58 @@ class CallRetStructInfoDeriver : public StructInfoBaseChecker { // pick up n in g's signature. bool populate_mapping = false; std::swap(populate_mapping_, populate_mapping); - auto ret = StructInfoBaseChecker::FuncParamsCheck(lhs, rhs); + auto ret = TypeBaseChecker::FuncParamsCheck(lhs, rhs); std::swap(populate_mapping_, populate_mapping); return ret; } }; -StructInfo DeriveCallRetStructInfo(const FuncStructInfo& finfo, const Call& call, - const BlockBuilder& ctx) { +Type DeriveCallRetType(const FuncType& finfo, const Call& call, const BlockBuilder& ctx) { arith::Analyzer analyzer; - return DeriveCallRetStructInfo(finfo, call, ctx, analyzer); + return DeriveCallRetType(finfo, call, ctx, analyzer); } -StructInfo DeriveCallRetStructInfo(const FuncStructInfo& finfo, const Call& call, - const BlockBuilder& ctx, const arith::Analyzer& ana) { +Type DeriveCallRetType(const FuncType& finfo, const Call& call, const BlockBuilder& ctx, + const arith::Analyzer& ana) { // The deriver's TVM_FFI_VISIT_THROW seeds a VisitErrorContext on the error; // the outer pass wrapper catches it and enriches the message with the access // path. Nothing to do here but propagate. - return CallRetStructInfoDeriver(ana.get()).Derive(finfo, call, ctx); + return CallRetTypeDeriver(ana.get()).Derive(finfo, call, ctx); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("relax.analysis.DeriveCallRetStructInfo", - [](const FuncStructInfo& finfo, const Call& call, const BlockBuilder& ctx) { - return DeriveCallRetStructInfo(finfo, call, ctx); + refl::GlobalDef().def("relax.analysis.DeriveCallRetType", + [](const FuncType& finfo, const Call& call, const BlockBuilder& ctx) { + return DeriveCallRetType(finfo, call, ctx); }); } //-------------------------- // UnifyToLCA //-------------------------- -class StructInfoLCAFinder - : public StructInfoFunctor { +class TypeLCAFinder : public TypeFunctor { public: - explicit StructInfoLCAFinder(arith::AnalyzerObj* ana) : analyzer_(ana) {} + explicit TypeLCAFinder(arith::AnalyzerObj* ana) : analyzer_(ana) {} - StructInfo VisitStructInfo(const StructInfo& lhs, const StructInfo& other) final { + Type VisitType(const Type& lhs, const Type& other) final { // quick path if (lhs.same_as(other)) return lhs; - return StructInfoFunctor::VisitStructInfo(lhs, other); + return TypeFunctor::VisitType(lhs, other); } // ffi::Object is based of everything, unify to object. - StructInfo VisitStructInfo_(const ObjectStructInfoNode* lhs, const StructInfo& other) final { - return ffi::GetRef(lhs); + Type VisitType_(const ObjectTypeNode* lhs, const Type& other) final { + return ffi::GetRef(lhs); } - StructInfo VisitStructInfo_(const PrimStructInfoNode* lhs, const StructInfo& other) final { - auto* rhs = other.as(); - if (rhs == nullptr) return ObjectStructInfo(lhs->span); + Type VisitType_(const PrimTypeNode* lhs, const Type& other) final { + auto* rhs = other.as(); + if (rhs == nullptr) return ObjectType(lhs->span); if (lhs->dtype != rhs->dtype) { // PrimType will be treated as their boxed(object) values // as a result we can unify to object. - return ObjectStructInfo(lhs->span); + return ObjectType(lhs->span); } if (!lhs->value.defined() || !rhs->value.defined() || !analyzer_->CanProveEqual(lhs->value.value(), rhs->value.value())) { @@ -1036,18 +1026,18 @@ class StructInfoLCAFinder if (!lhs->value.defined()) { // If the mismatch was due to extra information in the RHS, // prefer to avoid constructing a new object. - return ffi::GetRef(lhs); + return ffi::GetRef(lhs); } else { - return PrimStructInfo(lhs->dtype, lhs->span); + return PrimType(lhs->dtype, lhs->span); } } - return ffi::GetRef(lhs); + return ffi::GetRef(lhs); } - StructInfo VisitStructInfo_(const ShapeStructInfoNode* lhs, const StructInfo& other) final { - auto* rhs = other.as(); - if (rhs == nullptr) return ObjectStructInfo(lhs->span); + Type VisitType_(const ShapeTypeNode* lhs, const Type& other) final { + auto* rhs = other.as(); + if (rhs == nullptr) return ObjectType(lhs->span); int ndim = lhs->ndim == rhs->ndim ? lhs->ndim : kUnknownNDim; if (lhs->ndim != rhs->ndim || !lhs->values.defined() || !rhs->values.defined() || @@ -1055,18 +1045,18 @@ class StructInfoLCAFinder ffi::GetRef(analyzer_))) { // prefers return same when possible if (!lhs->values.defined() && lhs->ndim == ndim) { - return ffi::GetRef(lhs); + return ffi::GetRef(lhs); } else { - return ShapeStructInfo(ndim, lhs->span); + return ShapeType(ndim, lhs->span); } } // equals to each other - return ffi::GetRef(lhs); + return ffi::GetRef(lhs); } - StructInfo VisitStructInfo_(const TensorStructInfoNode* lhs, const StructInfo& other) final { - auto* rhs = other.as(); - if (rhs == nullptr) return ObjectStructInfo(lhs->span); + Type VisitType_(const TensorTypeNode* lhs, const Type& other) final { + auto* rhs = other.as(); + if (rhs == nullptr) return ObjectType(lhs->span); // find the target dtype, ndim, and vdevice. DataType dtype = lhs->dtype == rhs->dtype ? lhs->dtype : DataType::Void(); @@ -1084,37 +1074,37 @@ class StructInfoLCAFinder // reuse lhs when possible if (!lhs->shape.defined() && lhs->dtype == dtype && lhs->ndim == ndim && (!lhs->vdevice.defined() || vdev.defined())) { - return ffi::GetRef(lhs); + return ffi::GetRef(lhs); } else { - return TensorStructInfo(dtype, ndim, vdev, lhs->span); + return TensorType(dtype, ndim, vdev, lhs->span); } } // symbolic shape and vdevice match but dtype mismatch if (lhs->dtype != dtype || (lhs->vdevice.defined() && !vdev.defined())) { - return TensorStructInfo(lhs->shape.value(), dtype, vdev, lhs->span); + return TensorType(lhs->shape.value(), dtype, vdev, lhs->span); } else { - return ffi::GetRef(lhs); + return ffi::GetRef(lhs); } } - StructInfo VisitStructInfo_(const TupleStructInfoNode* lhs, const StructInfo& other) final { - auto* rhs = other.as(); - if (rhs == nullptr) return ObjectStructInfo(lhs->span); - ffi::Optional> fields = UnifyArray(lhs->fields, rhs->fields); + Type VisitType_(const TupleTypeNode* lhs, const Type& other) final { + auto* rhs = other.as(); + if (rhs == nullptr) return ObjectType(lhs->span); + ffi::Optional> fields = UnifyArray(lhs->fields, rhs->fields); // tuple length not the same. - if (!fields.defined()) return ObjectStructInfo(lhs->span); + if (!fields.defined()) return ObjectType(lhs->span); // same length tuple. if (!fields.same_as(lhs->fields)) { - return TupleStructInfo(fields.value(), lhs->span); + return TupleType(fields.value(), lhs->span); } else { - return ffi::GetRef(lhs); + return ffi::GetRef(lhs); } } - StructInfo VisitStructInfo_(const FuncStructInfoNode* lhs, const StructInfo& other) final { - auto* rhs = other.as(); - if (rhs == nullptr) return ObjectStructInfo(lhs->span); + Type VisitType_(const FuncTypeNode* lhs, const Type& other) final { + auto* rhs = other.as(); + if (rhs == nullptr) return ObjectType(lhs->span); // the unified function is pure only if both are pure bool purity = lhs->purity && rhs->purity; @@ -1123,24 +1113,24 @@ class StructInfoLCAFinder if (lhs->IsOpaque()) { if (lhs->derive_func.defined()) { if (lhs->derive_func.same_as(rhs->derive_func)) { - return ffi::GetRef(lhs); + return ffi::GetRef(lhs); } else { // Create a new opaque with object return - return FuncStructInfo::OpaqueFunc(ObjectStructInfo(), purity, lhs->span); + return FuncType::OpaqueFunc(ObjectType(), purity, lhs->span); } } else { // no derivation function, only depends on ret - StructInfo ret = this->VisitStructInfo(lhs->ret, rhs->ret); - if (ret.same_as(lhs->ret)) return ffi::GetRef(lhs); - return FuncStructInfo::OpaqueFunc(ret, purity, lhs->span); + Type ret = this->VisitType(lhs->ret, rhs->ret); + if (ret.same_as(lhs->ret)) return ffi::GetRef(lhs); + return FuncType::OpaqueFunc(ret, purity, lhs->span); } } // rhs is opaque, lhs is not if (rhs->IsOpaque()) { // unify ret value, note that rhs's ret is context free(because it is opaque) // so result of the unify is also context-free. - StructInfo ret = this->VisitStructInfo(lhs->ret, rhs->ret); - return FuncStructInfo::OpaqueFunc(ret, purity, lhs->span); + Type ret = this->VisitType(lhs->ret, rhs->ret); + return FuncType::OpaqueFunc(ret, purity, lhs->span); } // Both lhs and rhs are not opaque @@ -1158,21 +1148,21 @@ class StructInfoLCAFinder // // Given we only do best effort checking in these cases, and such cases // are likely not a primary concern atm, we take this approach here. - if (struct_equal_(ffi::GetRef(lhs), ffi::GetRef(rhs))) { - return ffi::GetRef(lhs); + if (struct_equal_(ffi::GetRef(lhs), ffi::GetRef(rhs))) { + return ffi::GetRef(lhs); } auto params = UnifyArray(lhs->params.value(), rhs->params.value()); - auto ret = this->VisitStructInfo(lhs->ret, rhs->ret); + auto ret = this->VisitType(lhs->ret, rhs->ret); if (params.same_as(lhs->params) && ret.same_as(lhs->ret)) { - return ffi::GetRef(lhs); + return ffi::GetRef(lhs); } else { // fail to unify the params if (!params.defined()) { - return FuncStructInfo::OpaqueFunc(ret, purity, lhs->span); + return FuncType::OpaqueFunc(ret, purity, lhs->span); } else { - return FuncStructInfo(params.value(), ret, purity, lhs->span); + return FuncType(params.value(), ret, purity, lhs->span); } } } @@ -1184,36 +1174,35 @@ class StructInfoLCAFinder ffi::StructuralEqual struct_equal_; // check arrays - ffi::Optional> UnifyArray(const ffi::Array& lhs, - const ffi::Array& rhs) { + ffi::Optional> UnifyArray(const ffi::Array& lhs, + const ffi::Array& rhs) { if (lhs.same_as(rhs)) return lhs; if (lhs.size() != rhs.size()) return std::nullopt; size_t index = 0; - return lhs.Map([&](const StructInfo& a) { return this->VisitStructInfo(a, rhs[index++]); }); + return lhs.Map([&](const Type& a) { return this->VisitType(a, rhs[index++]); }); } }; -StructInfo StructInfoLCA(const StructInfo& lhs, const StructInfo& rhs) { +Type TypeLCA(const Type& lhs, const Type& rhs) { arith::Analyzer analyzer; - return StructInfoLCA(lhs, rhs, analyzer); + return TypeLCA(lhs, rhs, analyzer); } -StructInfo StructInfoLCA(const StructInfo& lhs, const StructInfo& rhs, const arith::Analyzer& ana) { - return StructInfoLCAFinder(ana.get())(lhs, rhs); +Type TypeLCA(const Type& lhs, const Type& rhs, const arith::Analyzer& ana) { + return TypeLCAFinder(ana.get())(lhs, rhs); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def( - "relax.analysis.StructInfoLCA", - [](const StructInfo& lhs, const StructInfo& rhs) { return StructInfoLCA(lhs, rhs); }); + refl::GlobalDef().def("relax.analysis.TypeLCA", + [](const Type& lhs, const Type& rhs) { return TypeLCA(lhs, rhs); }); } //-------------------------- -// TIRVarsInStructInfo +// TIRVarsInType //-------------------------- -class TIRVarsDetector : public StructInfoVisitor { +class TIRVarsDetector : public TypeVisitor { public: enum class VarType { Definition, @@ -1245,21 +1234,21 @@ class TIRVarsDetector : public StructInfoVisitor { } } - void VisitStructInfo_(const PrimStructInfoNode* prim_sinfo) final { - if (prim_sinfo->value.defined()) { - VisitPrimExpr(prim_sinfo->value.value()); + void VisitType_(const PrimTypeNode* prim_ty) final { + if (prim_ty->value.defined()) { + VisitPrimExpr(prim_ty->value.value()); } } - void VisitStructInfo_(const ShapeStructInfoNode* shape_sinfo) final { - if (shape_sinfo->values.defined()) { - VisitShape(shape_sinfo->values.value()); + void VisitType_(const ShapeTypeNode* shape_ty) final { + if (shape_ty->values.defined()) { + VisitShape(shape_ty->values.value()); } } - void VisitStructInfo_(const TensorStructInfoNode* tensor_sinfo) final { - if (tensor_sinfo->shape.defined()) { - VisitStructInfo(GetStructInfo(tensor_sinfo->shape.value())); + void VisitType_(const TensorTypeNode* tensor_ty) final { + if (tensor_ty->shape.defined()) { + VisitType(GetType(tensor_ty->shape.value())); } } @@ -1276,48 +1265,48 @@ class TIRVarsDetector : public StructInfoVisitor { VarType collection_type; }; -ffi::Array TIRVarsInStructInfo(const StructInfo& sinfo) { +ffi::Array TIRVarsInType(const Type& ty) { TIRVarsDetector detector(TIRVarsDetector::VarType::Usage); - detector(sinfo); + detector(ty); return detector.GetTIRVars(); } -ffi::Array DefinableTIRVarsInStructInfo(const StructInfo& sinfo) { +ffi::Array DefinableTIRVarsInType(const Type& ty) { TIRVarsDetector detector(TIRVarsDetector::VarType::Definition); - detector(sinfo); + detector(ty); return detector.GetTIRVars(); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def("relax.analysis.TIRVarsInStructInfo", TIRVarsInStructInfo) - .def("relax.analysis.DefinableTIRVarsInStructInfo", DefinableTIRVarsInStructInfo); + .def("relax.analysis.TIRVarsInType", TIRVarsInType) + .def("relax.analysis.DefinableTIRVarsInType", DefinableTIRVarsInType); } -class NonNegativeExpressionCollector : relax::StructInfoVisitor { +class NonNegativeExpressionCollector : relax::TypeVisitor { public: - static ffi::Array Collect(const StructInfo& sinfo) { + static ffi::Array Collect(const Type& ty) { NonNegativeExpressionCollector visitor; - visitor(sinfo); + visitor(ty); return visitor.expressions_; } private: - void VisitStructInfo_(const TensorStructInfoNode* op) override { + void VisitType_(const TensorTypeNode* op) override { if (op->shape.defined()) { - VisitStructInfo(GetStructInfo(op->shape.value())); + VisitType(GetType(op->shape.value())); } } - void VisitStructInfo_(const PrimStructInfoNode* op) override { - // Unlike the expressions in TensorStructInfo or ShapeStructInfo, - // PrimStructInfo may contain negative values. This override - // prevents calling VisitStructInfoExprField from the default - // StructInfoVisitor implementation. + void VisitType_(const PrimTypeNode* op) override { + // Unlike the expressions in TensorType or ShapeType, + // PrimType may contain negative values. This override + // prevents calling VisitTypeExprField from the default + // TypeVisitor implementation. } - void VisitStructInfoExprField(const PrimExpr& size_expr) override { + void VisitTypeExprField(const PrimExpr& size_expr) override { if (auto size_int = size_expr.as(); size_int && size_int->value >= 0) { // Avoid cluttering the result with non-negative integers return; @@ -1333,8 +1322,8 @@ class NonNegativeExpressionCollector : relax::StructInfoVisitor { std::unordered_set dedup_lookup_; }; -ffi::Array CollectNonNegativeExpressions(const StructInfo& sinfo) { - return NonNegativeExpressionCollector::Collect(sinfo); +ffi::Array CollectNonNegativeExpressions(const Type& ty) { + return NonNegativeExpressionCollector::Collect(ty); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -1344,7 +1333,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { } class SymbolicVarCollector : public relax::ExprVisitor, - public relax::StructInfoVisitor, + public relax::TypeVisitor, public tirx::ExprVisitor { public: static ffi::Array Free(const Expr& expr) { @@ -1392,13 +1381,13 @@ class SymbolicVarCollector : public relax::ExprVisitor, void VisitExpr_(const FunctionNode* op) final { WithMode(VisitMode::kProvideDefinition, [&]() { for (Var param : op->params) { - relax::StructInfoVisitor::VisitStructInfo(GetStructInfo(param)); + relax::TypeVisitor::VisitType(GetType(param)); } }); WithMode(VisitMode::kRequireDefinition, [&]() { for (Var param : op->params) { - relax::StructInfoVisitor::VisitStructInfo(GetStructInfo(param)); + relax::TypeVisitor::VisitType(GetType(param)); } }); @@ -1407,49 +1396,47 @@ class SymbolicVarCollector : public relax::ExprVisitor, void VisitBinding_(const MatchCastNode* binding) final { WithMode(VisitMode(VisitMode::kProvideDefinition | VisitMode::kRequireDefinition), - [&]() { this->VisitStructInfo(binding->struct_info); }); + [&]() { this->VisitType(binding->ty); }); relax::ExprVisitor::VisitBinding_(binding); } - void VisitExprDepStructInfoField(const StructInfo& struct_info) { - return this->VisitStructInfo(struct_info); - } + void VisitExprDepTypeField(const Type& ty) { return this->VisitType(ty); } - void VisitStructInfo_(const FuncStructInfoNode* op) final { + void VisitType_(const FuncTypeNode* op) final { if (op->params.defined()) { // Visit the parameters once to collect bindings, and another // time to collect usages. Otherwise, a symbolic variable // defined by a later parameter may be treated as undefined when // used by an earlier parameter. WithMode(VisitMode::kProvideDefinition, [&]() { - for (StructInfo param : op->params.value()) { - this->VisitStructInfo(param); + for (Type param : op->params.value()) { + this->VisitType(param); } }); WithMode(VisitMode::kRequireDefinition, [&]() { - for (StructInfo param : op->params.value()) { - this->VisitStructInfo(param); + for (Type param : op->params.value()) { + this->VisitType(param); } }); } - this->VisitStructInfo(op->ret); + this->VisitType(op->ret); } - void VisitStructInfoExprField(const Expr& expr) final { + void VisitTypeExprField(const Expr& expr) final { relax::ExprVisitor::VisitExpr(expr); if (auto* shape = expr.as()) { for (const auto& val : shape->values) { - this->VisitStructInfoExprField(val); + this->VisitTypeExprField(val); } } if (auto prim_value = expr.as()) { - this->VisitStructInfoExprField(prim_value.value()->value); + this->VisitTypeExprField(prim_value.value()->value); } } - void VisitStructInfoExprField(const PrimExpr& expr) final { + void VisitTypeExprField(const PrimExpr& expr) final { if (mode_ & VisitMode::kProvideDefinition) { if (auto var = expr.as()) { defined_symbolic_var_.insert(var.value()); diff --git a/src/relax/analysis/well_formed.cc b/src/relax/analysis/well_formed.cc index 19e8a2fbbfc4..cecc85dcb3c6 100644 --- a/src/relax/analysis/well_formed.cc +++ b/src/relax/analysis/well_formed.cc @@ -26,8 +26,8 @@ * with the offending node so the caller can resolve a precise access path. * Use `check_well_formed` for a boolean answer. * This pass will check: - * 1. Each Expr should have `struct_info_` field already populated, when - * `check_struct_info` is true. + * 1. Each Expr should have `ty` field already populated, when + * `check_ty` is true. * 2. GlobalVars are defined before use. And all GlobalVars have different names. * 3. When a Function has a corresponding GlobalVar and a `global_symbol` * attribute, the name of the GlobalVar must equal the value of the @@ -57,12 +57,12 @@ * * The cond field of If nodes * * The op or args fields of Call nodes * * Inside the fields of Tuple nodes - * 13. Expr always has struct_info_ (with the exception of Op). + * 13. Expr always has ty (with the exception of Op). * 14. DataflowBlocks may not contain If nodes. * 15. DataflowBlocks may not contain calls to impure functions or operators - * (only checked if check_struct_info is true). + * (only checked if check_ty is true). * 16. If a function has is_pure set to true and the kForcePure attribute is not set, - * the body may not contain any impure call (only checked if check_struct_info is true). + * the body may not contain any impure call (only checked if check_ty is true). * 17. If the kForcePure attribute is set for a function, * that function's is_pure field must be true. */ @@ -73,7 +73,7 @@ #include #include #include -#include +#include #include #include #include @@ -91,15 +91,14 @@ namespace relax { // /*! \brief Helper to implement well formed check.*/ class WellFormedChecker : public relax::ExprVisitor, - public relax::StructInfoVisitor, + public relax::TypeVisitor, public tirx::ExprVisitor { public: // Throws ffi::Error on the first well-formedness violation, seeded with the // offending node so the caller can resolve an access path. Returns normally // when the object is well-formed. - static void Check(ffi::Variant obj, bool check_struct_info) { - WellFormedChecker well_formed_checker = - WellFormedChecker(obj.as(), check_struct_info); + static void Check(ffi::Variant obj, bool check_ty) { + WellFormedChecker well_formed_checker = WellFormedChecker(obj.as(), check_ty); if (const auto* mod = obj.as()) { for (const auto& it : mod->functions) { @@ -120,8 +119,8 @@ class WellFormedChecker : public relax::ExprVisitor, } private: - WellFormedChecker(ffi::Optional mod, bool check_struct_info) - : mod_(std::move(mod)), check_struct_info_(check_struct_info), cur_visited_func_(nullptr) {} + WellFormedChecker(ffi::Optional mod, bool check_ty) + : mod_(std::move(mod)), check_ty(check_ty), cur_visited_func_(nullptr) {} using relax::ExprVisitor::VisitExpr_; using tirx::ExprVisitor::VisitExpr; @@ -163,8 +162,8 @@ class WellFormedChecker : public relax::ExprVisitor, } void VisitExpr(const Expr& expr) final { - if (!expr.as() && !expr->struct_info_.defined()) { - TVM_FFI_VISIT_THROW(TypeError, expr) << "The struct_info_ of Expr " << expr << " is nullptr."; + if (!expr.as() && !expr->ty.defined()) { + TVM_FFI_VISIT_THROW(TypeError, expr) << "The ty of Expr " << expr << " is nullptr."; } relax::ExprVisitor::VisitExpr(expr); } @@ -179,15 +178,14 @@ class WellFormedChecker : public relax::ExprVisitor, } } - if (op->struct_info_.defined()) { - if (!op->struct_info_->IsInstance()) { + if (op->ty.defined()) { + if (!op->ty->IsInstance()) { TVM_FFI_VISIT_THROW(TypeError, var) - << "The struct_info_ of GlobalVar " << ffi::GetRef(op) - << " must be either FuncStructInfo."; + << "The ty of GlobalVar " << ffi::GetRef(op) << " must be either FuncType."; } } - CheckStructInfo(op); + CheckType(op); } void VisitExpr_(const TupleNode* op) final { @@ -202,7 +200,7 @@ class WellFormedChecker : public relax::ExprVisitor, } } - CheckStructInfo(op); + CheckType(op); TVM_FFI_VISIT_END(ffi::GetRef(op)); } @@ -213,7 +211,7 @@ class WellFormedChecker : public relax::ExprVisitor, TVM_FFI_VISIT_THROW(TypeError, ffi::GetRef(op)) << "The tuple value in a TupleGetItem node must be a leaf expression."; } - CheckStructInfo(op); + CheckType(op); } void VisitExpr_(const VarNode* op) final { @@ -221,7 +219,7 @@ class WellFormedChecker : public relax::ExprVisitor, if (var_set_.count(var) == 0 && recur_vars_.count(var) == 0) { TVM_FFI_VISIT_THROW(ValueError, var) << "Var " << ffi::GetRef(op) << " is not defined."; } - CheckStructInfo(op); + CheckType(op); } void VisitExpr_(const DataflowVarNode* op) final { @@ -234,7 +232,7 @@ class WellFormedChecker : public relax::ExprVisitor, TVM_FFI_VISIT_THROW(ValueError, var) << "DataflowVar " << ffi::GetRef(op) << " is not defined."; } - CheckStructInfo(op); + CheckType(op); } void VisitExpr_(const FunctionNode* op) final { @@ -258,7 +256,7 @@ class WellFormedChecker : public relax::ExprVisitor, WithMode(VisitMode::kMatchVarDef, [&]() { TVM_FFI_ICHECK(mode_ == VisitMode::kMatchVarDef); for (Var param : op->params) { - relax::StructInfoVisitor::VisitStructInfo(GetStructInfo(param)); + relax::TypeVisitor::VisitType(GetType(param)); } }); @@ -282,18 +280,16 @@ class WellFormedChecker : public relax::ExprVisitor, } param_var_func_map_.insert({param, cur_visited_func_}); } - // check function ret_struct_info - if (op->ret_struct_info.defined()) { - this->VisitStructInfo(op->ret_struct_info); + // check function ret_ty + if (op->ret_ty.defined()) { + this->VisitType(op->ret_ty); } else { - TVM_FFI_VISIT_THROW(TypeError, ffi::GetRef(op)) - << "Function must have defined ret_struct_info"; + TVM_FFI_VISIT_THROW(TypeError, ffi::GetRef(op)) << "Function must have defined ret_ty"; } // if we are not forcing purity and the function is annotated as pure, it must not contain an // impure call - if (check_struct_info_ && !op->GetAttr(relax::attr::kForcePure).value_or(false) && - op->is_pure) { + if (check_ty && !op->GetAttr(relax::attr::kForcePure).value_or(false) && op->is_pure) { if (auto impure = FindImpureCall(op->body)) { TVM_FFI_VISIT_THROW(ValueError, ffi::GetRef(op)) << "Function " << op << " is annotated as pure but contains an impure call: " << impure @@ -337,12 +333,12 @@ class WellFormedChecker : public relax::ExprVisitor, } } - for (const StructInfo& sinfo_arg : call->sinfo_args) { - this->VisitStructInfo(sinfo_arg); + for (const Type& ty_arg : call->ty_args) { + this->VisitType(ty_arg); } - CheckStructInfo(call); - if (is_dataflow_ && check_struct_info_) { + CheckType(call); + if (is_dataflow_ && check_ty) { if (auto impure = FindImpureCall(ffi::GetRef(call))) { TVM_FFI_VISIT_THROW(ValueError, ffi::GetRef(call)) << "Impure function call " << impure << " occurs within a dataflow block."; @@ -387,45 +383,44 @@ class WellFormedChecker : public relax::ExprVisitor, } } - if (check_struct_info_ && call->struct_info_.defined()) { - // The `InferStructInfo` method isn't currently exposed by the + if (check_ty && call->ty.defined()) { + // The `InferType` method isn't currently exposed by the // Normalizer, and can only be called indirectly by normalizing - // an expression that does not yet have `StructInfo`. + // an expression that does not yet have `Type`. auto dummy_builder = tvm::relax::BlockBuilder::Create(mod_); - Call copied(call->op, call->args, call->attrs, call->sinfo_args); + Call copied(call->op, call->args, call->attrs, call->ty_args); ffi::Optional normalized = std::nullopt; try { normalized = dummy_builder->Normalize(copied); } catch (std::exception& err) { TVM_FFI_VISIT_THROW(TypeError, ffi::GetRef(call)) - << "Each Relax expression must be able to have its StructInfo inferred. " - << "However, inferring the struct info of expression " << ffi::GetRef(call) + << "Each Relax expression must be able to have its Type inferred. " + << "However, inferring the type of expression " << ffi::GetRef(call) << " resulted in the error: \n" << err.what(); } if (normalized.defined()) { - auto inferred_struct_info = GetStructInfo(normalized.value()); - auto current_struct_info = Downcast(call->struct_info_); + auto inferred_ty = GetType(normalized.value()); + auto current_ty = Downcast(call->ty); - // An error should be raised if the annotated StructInfo is + // An error should be raised if the annotated Type is // provably incorrect. This check is done using - // `StructInfoBaseCheck(...) < kFailL1`, because `kFailL1` + // `TypeBaseCheck(...) < kFailL1`, because `kFailL1` // represents cases that are neither provably correct nor // provably incorrect. If this check were replaced with // `!IsBaseOf(...)`, cases that are correct but not provably // so would raise an exception. // - // For example, if a dynamic size in the inferred StructInfo + // For example, if a dynamic size in the inferred Type // is equivalent to the expression used in the annotated - // StructInfo, but the TIR simplifications are not sufficient + // Type, but the TIR simplifications are not sufficient // to prove that the two expressions are equivalent, we should // not raise an error. - if (StructInfoBaseCheck(current_struct_info, inferred_struct_info) < - BaseCheckResult::kFailL1) { + if (TypeBaseCheck(current_ty, inferred_ty) < BaseCheckResult::kFailL1) { TVM_FFI_VISIT_THROW(TypeError, ffi::GetRef(call)) - << "All information in StructInfo annotations must be correct. " + << "All information in Type annotations must be correct. " << "However, while the expression " << ffi::GetRef(call) << " is annotated as " - << current_struct_info << ", the expression outputs " << inferred_struct_info; + << current_ty << ", the expression outputs " << inferred_ty; } } } @@ -454,7 +449,7 @@ class WellFormedChecker : public relax::ExprVisitor, var_set_ = previous_var_set; symbolic_var_set_ = previous_symbolic_var_set; - CheckStructInfo(op); + CheckType(op); TVM_FFI_VISIT_END(ffi::GetRef(op)); } @@ -467,7 +462,7 @@ class WellFormedChecker : public relax::ExprVisitor, << "Shape expressions must be of integer type, but got " << expr.dtype(); } } - CheckStructInfo(op); + CheckType(op); } void VisitExpr_(const SeqExprNode* op) final { @@ -488,7 +483,7 @@ class WellFormedChecker : public relax::ExprVisitor, << "SeqExpr bodies must be leaf expressions."; } this->VisitExpr(op->body); - CheckStructInfo(op); + CheckType(op); TVM_FFI_VISIT_END(ffi::GetRef(op)); } @@ -507,14 +502,13 @@ class WellFormedChecker : public relax::ExprVisitor, this->VisitVarDef(binding->var); - if (check_struct_info_ && binding->var->struct_info_.defined() && - binding->value->struct_info_.defined()) { - auto expr_sinfo = GetStructInfo(binding->value); - auto var_sinfo = GetStructInfo(binding->var); - if (!IsBaseOf(var_sinfo, expr_sinfo)) { + if (check_ty && binding->var->ty.defined() && binding->value->ty.defined()) { + auto expr_ty = GetType(binding->value); + auto var_ty = GetType(binding->var); + if (!IsBaseOf(var_ty, expr_ty)) { TVM_FFI_VISIT_THROW(TypeError, binding->var) - << "Expression of type " << expr_sinfo << " cannot be assigned to a variable of type " - << var_sinfo; + << "Expression of type " << expr_ty << " cannot be assigned to a variable of type " + << var_ty; } } @@ -526,9 +520,9 @@ class WellFormedChecker : public relax::ExprVisitor, void VisitBinding_(const MatchCastNode* binding) final { this->VisitExpr(binding->value); // define the vars - WithMode(VisitMode::kMatchVarDef, [&]() { this->VisitStructInfo(binding->struct_info); }); + WithMode(VisitMode::kMatchVarDef, [&]() { this->VisitType(binding->ty); }); - this->VisitStructInfo(binding->struct_info); + this->VisitType(binding->ty); this->VisitVarDef(binding->var); } @@ -553,7 +547,7 @@ class WellFormedChecker : public relax::ExprVisitor, } // register DataflowVar dataflow_var_set_.insert(lv); - CheckStructInfo(var); + CheckType(var); } void VisitVarDef_(const VarNode* var) final { @@ -563,7 +557,7 @@ class WellFormedChecker : public relax::ExprVisitor, } // register Var var_set_.insert(gv); - CheckStructInfo(var); + CheckType(var); } void VisitExpr_(const tirx::VarNode* op) final { @@ -588,19 +582,19 @@ class WellFormedChecker : public relax::ExprVisitor, symbolic_var_func_map_.insert({var, cur_visited_func_}); } - void VisitStructInfo_(const FuncStructInfoNode* op) final { + void VisitType_(const FuncTypeNode* op) final { if (op->params.defined()) { WithMode(VisitMode::kMatchVarDef, [&]() { TVM_FFI_ICHECK(mode_ == VisitMode::kMatchVarDef); - for (StructInfo param : op->params.value()) { - this->VisitStructInfo(param); + for (Type param : op->params.value()) { + this->VisitType(param); } }); } - this->VisitStructInfo(op->ret); + this->VisitType(op->ret); } - void VisitStructInfoExprField(const Expr& expr) final { + void VisitTypeExprField(const Expr& expr) final { if (mode_ == VisitMode::kMatchVarDef) { // populate symbolic var in first occurrence if (auto* op = expr.as()) { @@ -611,7 +605,7 @@ class WellFormedChecker : public relax::ExprVisitor, } if (auto* shape = expr.as()) { for (auto val : shape->values) { - this->VisitStructInfoExprField(val); + this->VisitTypeExprField(val); } } } else { @@ -619,7 +613,7 @@ class WellFormedChecker : public relax::ExprVisitor, } } - void VisitStructInfoExprField(const PrimExpr& expr) final { + void VisitTypeExprField(const PrimExpr& expr) final { if (mode_ == VisitMode::kMatchVarDef) { // populate symbolic var in first occurrence if (auto* op = expr.as()) { @@ -633,17 +627,18 @@ class WellFormedChecker : public relax::ExprVisitor, } } - void CheckStructInfo(const ExprNode* op) { - if (!check_struct_info_) { + void CheckType(const ExprNode* op) { + if (!check_ty) { return; } - auto* sinfo = op->struct_info_.as(); - if (sinfo != nullptr) { - this->VisitStructInfo(ffi::GetRef(sinfo)); + if (auto* ty = op->ty.as()) { + this->VisitType(ffi::GetRef(ty)); + } else if (auto* ty = op->ty.as()) { + this->VisitType(ffi::GetRef(ty)); } else { TVM_FFI_VISIT_THROW(TypeError, ffi::GetRef(op)) - << "Expr must have struct_info populated. " + << "Expr must have ty populated. " << " Expr.type_key=" << op->GetTypeKey(); } } @@ -657,7 +652,7 @@ class WellFormedChecker : public relax::ExprVisitor, } ffi::Optional mod_; - const bool check_struct_info_; + const bool check_ty; bool is_dataflow_; // Current visited function. const FunctionNode* cur_visited_func_; @@ -677,13 +672,13 @@ class WellFormedChecker : public relax::ExprVisitor, tvm::OpAttrMap op_map_validate_ = Op::GetAttrMap("FValidate"); }; -void WellFormed(ffi::Variant obj, bool check_struct_info) { - WellFormedChecker::Check(obj, check_struct_info); +void WellFormed(ffi::Variant obj, bool check_ty) { + WellFormedChecker::Check(obj, check_ty); } -bool CheckWellFormed(ffi::Variant obj, bool check_struct_info) { +bool CheckWellFormed(ffi::Variant obj, bool check_ty) { try { - WellFormed(obj, check_struct_info); + WellFormed(obj, check_ty); return true; } catch (const ffi::Error&) { return false; @@ -694,9 +689,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("relax.analysis.well_formed", - [](ffi::Variant obj, bool check_struct_info) { - WellFormed(obj, check_struct_info); - }) + [](ffi::Variant obj, bool check_ty) { WellFormed(obj, check_ty); }) .def("relax.analysis.check_well_formed", CheckWellFormed); } diff --git a/src/relax/backend/adreno/annotate_custom_storage.cc b/src/relax/backend/adreno/annotate_custom_storage.cc index 0931eb88337e..b70997c5d7d4 100644 --- a/src/relax/backend/adreno/annotate_custom_storage.cc +++ b/src/relax/backend/adreno/annotate_custom_storage.cc @@ -22,7 +22,7 @@ * * Texture realization for Adreno GPU targets requires fundamentally follows * Stage 1: Transforming the shapes with inner most dimension being 4 - * Stage 2: Annotate appropriate memory_scope hint in VDevice of StructInfo + * Stage 2: Annotate appropriate memory_scope hint in VDevice of Type * Stage 3: TIR lowering does injects texture load/store builtins looking at this scope * Stage 4: Finally codegen handles appropriate code looking at buffer types and load/store * builtins. @@ -30,7 +30,7 @@ * Stage 1 is generic and straight forward by using convert_layout pass that transforms the * shapes as well as injecting layout_transform ops as needed. * - * Stage 2 This pass is responsible for injeting appropriate VDevice into StructInfo and + * Stage 2 This pass is responsible for injeting appropriate VDevice into Type and * adding any copies if there is a conflict between producer and consuner scopes. * * After convert_layout the mod looks like below @@ -90,10 +90,10 @@ * ) -> R.Tensor((2, 32, 54, 54), dtype="float32"): * with R.dataflow(): * lv = R.call_tir(cls.te_layout_transform, (x,), - * out_sinfo=R.Tensor((2, 16, 56, 56, 4), dtype="float32") + * out_ty=R.Tensor((2, 16, 56, 56, 4), dtype="float32") * ) * lv1 = R.call_tir(cls.te_layout_transform1, (w,), - * out_sinfo=R.Tensor((8, 64, 3, 3, 4), dtype="float32") + * out_ty=R.Tensor((8, 64, 3, 3, 4), dtype="float32") * ) * lv2: R.Tensor((2, 8, 54, 54, 4), dtype="float32") = R.nn.conv2d( * lv, @@ -104,7 +104,7 @@ * out_dtype="float32" * ) * gv = R.call_tir(cls.te_layout_transform2, (lv2,), - * out_sinfo=R.Tensor((2, 32, 54, 54), dtype="float32") + * out_ty=R.Tensor((2, 32, 54, 54), dtype="float32") * ) * R.output(gv) * return gv @@ -120,13 +120,13 @@ * 2: CollectProducerScopeInfo: Visitor does finalizes the scope for each input and output based * on consumer scope information. It does evaluating mutiple consumer cases and conflicts. * 3: DefineVDevice: Pass does injects hint_on_device for each argument. It also tries to update - * out StructInfo containing VDevice information. This update for tirx calls is straight forward - * as sinfo_args in CallNode is meant for this purpose. This sinfo_args for other calls by - * design is invalid as we do this by "FInferStructInfo". - * Another issue we have with "FInferStructInfo" per op is they can't decide this + * out Type containing VDevice information. This update for tirx calls is straight forward + * as ty_args in CallNode is meant for this purpose. This ty_args for other calls by + * design is invalid as we do this by "FInferType". + * Another issue we have with "FInferType" per op is they can't decide this * memory scope information which is done by this pass based on consumer demand. - * Hence, we are going to use the sinfo_args to indicate this information. - * So, this pass attributes sinfo_args for regumar calls too and FInferStructInfo implmentation + * Hence, we are going to use the ty_args to indicate this information. + * So, this pass attributes ty_args for regumar calls too and FInferType implmentation * do take VDevice information fro this hint. This also solves the issue of mixed VDevice * for arguments of an op. * After these steps the mod looks like @@ -142,7 +142,7 @@ * x, R.device(dev_type=4, dev_id=0), "global" * ) * lv_1 = R.call_tir(cls.te_layout_transform, (lv,), - * out_sinfo=R.Tensor((2, 16, 56, 56, 4), dtype="float32", + * out_ty=R.Tensor((2, 16, 56, 56, 4), dtype="float32", * vdevice="opencl:0:global.texture-nhwc" * ) * ) @@ -150,7 +150,7 @@ * w, R.device(dev_type=4, dev_id=0), "global" * ) * lv1_1 = R.call_tir(cls.te_layout_transform1, (lv1,), - * out_sinfo=R.Tensor((8, 64, 3, 3, 4), dtype="float32", + * out_ty=R.Tensor((8, 64, 3, 3, 4), dtype="float32", * vdevice="opencl:2:global.texture-weight" * ) * ) @@ -166,7 +166,7 @@ * lv2, lv3, * data_layout="NCHW4c", kernel_layout="OIHW4o", * out_layout="NCHW4c", out_dtype="float32", - * sinfo_args=(R.Tensor((2, 8, 54, 54, 4), dtype="float32", + * ty_args=(R.Tensor((2, 8, 54, 54, 4), dtype="float32", * vdevice="opencl:1:global"), * ) * ) @@ -174,12 +174,12 @@ * vdevice="opencl:1:global" * ) = R.hint_on_device(lv2_1, R.device(dev_type=4, dev_id=0), "global") * gv = R.call_tir(cls.te_layout_transform2, (lv4,), - * out_sinfo=R.Tensor((2, 32, 54, 54), dtype="float32", vdevice="opencl:1:global") + * out_ty=R.Tensor((2, 32, 54, 54), dtype="float32", vdevice="opencl:1:global") * ) * R.output(gv) * return gv * - * What we have above is hint_on_device injections and out_sinfo for all calls. + * What we have above is hint_on_device injections and out_ty for all calls. * Now, we apply RealizeVDevice to formalize the hints. Follwed by we also call * CanonicalizeBindings that removes redundant assignments like * @@ -199,12 +199,12 @@ * ) -> R.Tensor((2, 32, 54, 54), dtype="float32"): * with R.dataflow(): * lv = R.call_tir(cls.te_layout_transform, (x,), - * out_sinfo=R.Tensor((2, 16, 56, 56, 4), dtype="float32", + * out_ty=R.Tensor((2, 16, 56, 56, 4), dtype="float32", * vdevice="opencl:0:global.texture-nhwc" * ) * ) * lv1 = R.call_tir(cls.te_layout_transform1, (w,), - * out_sinfo=R.Tensor((8, 64, 3, 3, 4), dtype="float32", + * out_ty=R.Tensor((8, 64, 3, 3, 4), dtype="float32", * vdevice="opencl:2:global.texture-weight" * ) * ) @@ -214,18 +214,18 @@ * lv2, lv3, * data_layout="NCHW4c", kernel_layout="OIHW4o", * out_layout="NCHW4c", out_dtype="float32", - * sinfo_args=(R.Tensor((2, 8, 54, 54, 4), dtype="float32", + * ty_args=(R.Tensor((2, 8, 54, 54, 4), dtype="float32", * vdevice="opencl:1:global"), * ) * ) * gv = R.call_tir(cls.te_layout_transform2, (lv4,), - * out_sinfo=R.Tensor((2, 32, 54, 54), dtype="float32", vdevice="opencl:1:global") + * out_ty=R.Tensor((2, 32, 54, 54), dtype="float32", vdevice="opencl:1:global") * ) * R.output(gv) * return gv * * Followed by, the compilation pipeline calls - * - legalization of the remainng ops: This legalization do forwards the annotated out_sinfo + * - legalization of the remainng ops: This legalization do forwards the annotated out_ty * VDevice information to tir_calls * - AnnotateTIROpPattern : TIROp Patterns for newly legalizes ops * - Fusion @@ -258,8 +258,8 @@ namespace adreno { using tvm::tirx::Buffer; -static ffi::Array GetShapeFromTensorStructInfo(const TensorStructInfo& tensor_sinfo) { - auto shape = tensor_sinfo->GetShape(); +static ffi::Array GetShapeFromTensorType(const TensorType& tensor_ty) { + auto shape = tensor_ty->GetShape(); TVM_FFI_ICHECK(shape.defined()); return shape.value(); } @@ -357,12 +357,11 @@ class CollectConsumerScopeInfo : public ExprVisitor { ffi::Array arg_scope; for (uint32_t i = 0; i < func_args->fields.size(); ++i) { - auto sinfo = GetStructInfo(func_args->fields[i]); - if (auto tensor_sinfo = sinfo.as()) { + auto ty = GetType(func_args->fields[i]); + if (auto tensor_ty = ty.as()) { bool is_texture = i < is_texture_supported.size() ? is_texture_supported[i] : is_texture_supported[0]; - auto scope = - is_texture ? Scope(GetShapeFromTensorStructInfo(tensor_sinfo.value())) : "global"; + auto scope = is_texture ? Scope(GetShapeFromTensorType(tensor_ty.value())) : "global"; ffi::Map> ent_call; const VarNode* arg_var = func_args->fields[i].as(); if (scope_info.find(ffi::GetRef(arg_var)) != scope_info.end()) { @@ -470,7 +469,7 @@ class CollectConsumerScopeInfo : public ExprVisitor { /* * \brief producer scope information consolidated based on consumer demands. - * \return producer_info which is a map of each call node and corresponding out StructInfo + * \return producer_info which is a map of each call node and corresponding out Type * This pass considers all consumers and their scope demand. * Any mismatches here introduces copies as needed. */ @@ -478,7 +477,7 @@ class CollectProducerScopeInfo : public ExprVisitor { public: using ExprVisitor::VisitExpr_; - ffi::Map Collect( + ffi::Map Collect( const IRModule& mod, Function func, const ffi::Map>>& scope_info, const Target& target, const BlockBuilder& builder) { @@ -488,26 +487,25 @@ class CollectProducerScopeInfo : public ExprVisitor { builder_ = builder; VisitExpr(func->body); - return producer_sinfo; + return producer_ty; } void VisitBinding_(const VarBindingNode* binding, const CallNode* call) final { ExprVisitor::VisitBinding_(binding, call); static const Op& call_tir_op = Op::Get("relax.call_tir"); - StructInfo out_sinfo; + Type out_ty; if (call->op == call_tir_op) { - out_sinfo = call->sinfo_args[0]; + out_ty = call->ty_args[0]; } else { - tvm::OpAttrMap op_map_infer_struct_info_ = - Op::GetAttrMap("FInferStructInfo"); + tvm::OpAttrMap op_map_infer_ty = Op::GetAttrMap("FInferType"); auto* op_ptr = call->op.as(); Op op = ffi::GetRef(op_ptr); - TVM_FFI_ICHECK(op_map_infer_struct_info_.count(op)) - << " Cannot find the FInferStructInfo attribute registered to op: " << op->name; - out_sinfo = op_map_infer_struct_info_[op](ffi::GetRef(call), builder_); + TVM_FFI_ICHECK(op_map_infer_ty.count(op)) + << " Cannot find the FInferType attribute registered to op: " << op->name; + out_ty = op_map_infer_ty[op](ffi::GetRef(call), builder_); } std::unordered_map scope_count; @@ -534,41 +532,40 @@ class CollectProducerScopeInfo : public ExprVisitor { } } // Applying same scope for outputs - StructInfo updated_ret_sinfo = UpdateStructInfo(out_sinfo, {final_scope}); - producer_sinfo.Set(ffi::GetRef(call), updated_ret_sinfo); + Type updated_ret_ty = UpdateOutputType(out_ty, {final_scope}); + producer_ty.Set(ffi::GetRef(call), updated_ret_ty); } private: - StructInfo UpdateStructInfo(const StructInfo& out_sinfo, ffi::Array scope) { - if (out_sinfo->IsInstance()) { - auto tensor_sinfo = Downcast(out_sinfo); - auto shape_arr = GetShapeFromTensorStructInfo(tensor_sinfo); - return TensorStructInfo(ShapeExpr(shape_arr), tensor_sinfo->dtype, - VDevice(target_, 0, scope[0])); + Type UpdateOutputType(const Type& out_ty, ffi::Array scope) { + if (out_ty->IsInstance()) { + auto tensor_ty = Downcast(out_ty); + auto shape_arr = GetShapeFromTensorType(tensor_ty); + return TensorType(ShapeExpr(shape_arr), tensor_ty->dtype, VDevice(target_, 0, scope[0])); } - TVM_FFI_ICHECK(out_sinfo->IsInstance()) - << "Expect output struct info of call_tir to be either TupleStructInfo or " - "TensorStructInfo, but got " - << out_sinfo; + TVM_FFI_ICHECK(out_ty->IsInstance()) + << "Expect output type of call_tir to be either TupleType or " + "TensorType, but got " + << out_ty; - const auto& tuple_sinfo = Downcast(out_sinfo); - ffi::Array sinfo_fields; - for (const auto& si : tuple_sinfo->fields) { - TVM_FFI_ICHECK(si->IsInstance()) - << "Fields of TupleStructInfo must be TensorStructInfo for call_tir " + const auto& tuple_ty = Downcast(out_ty); + ffi::Array ty_fields; + for (const auto& si : tuple_ty->fields) { + TVM_FFI_ICHECK(si->IsInstance()) + << "Fields of TupleType must be TensorType for call_tir " "output structinfo, but got " << si; - auto sinfo = Downcast(si); - auto shape_arr = GetShapeFromTensorStructInfo(sinfo); - sinfo_fields.push_back( - TensorStructInfo(ShapeExpr(shape_arr), sinfo->dtype, VDevice(target_, 0, scope[0]))); + auto ty = Downcast(si); + auto shape_arr = GetShapeFromTensorType(ty); + ty_fields.push_back( + TensorType(ShapeExpr(shape_arr), ty->dtype, VDevice(target_, 0, scope[0]))); } - return TupleStructInfo(sinfo_fields); + return TupleType(ty_fields); } ffi::Map>> scope_info_; - ffi::Map producer_sinfo; + ffi::Map producer_ty; IRModule mod_; Target target_; BlockBuilder builder_; @@ -576,7 +573,7 @@ class CollectProducerScopeInfo : public ExprVisitor { /* * \brief main pass that injects hint_on_device for each argument based on producer, - * consumer indormations. This also attributes ret StructInfo for each call node. + * consumer indormations. This also attributes ret Type for each call node. * This pass also calls the ReliaseVdevice that formalizes the hints by appropriately injecting * Vdevice copies as needed. */ @@ -597,8 +594,8 @@ class DefineVDevice : ExprMutator { auto info = CollectConsumerScopeInfo().Collect(mod_, Downcast(func), target_); call_scope_info_ = info.first; scope_info_ = info.second; - producer_sinfo_ = CollectProducerScopeInfo().Collect(mod_, Downcast(func), - scope_info_, target_, builder_); + producer_ty_ = CollectProducerScopeInfo().Collect(mod_, Downcast(func), + scope_info_, target_, builder_); relax::Function update_func = Downcast(VisitExpr(func)); updates_->Add(gv, update_func); } @@ -627,7 +624,7 @@ class DefineVDevice : ExprMutator { GlobalVar gv; Tuple func_args; - StructInfo out_sinfo; + Type out_ty; if (call->op == call_tir_op) { gv = Downcast(call->args[0]); @@ -637,51 +634,51 @@ class DefineVDevice : ExprMutator { } ffi::Array new_args; - StructInfo updated_ret_sinfo = producer_sinfo_[ffi::GetRef(call_node)]; - - if (updated_ret_sinfo->IsInstance()) { - auto tensor_sinfo = Downcast(updated_ret_sinfo); - auto shape = tensor_sinfo->shape.value(); - auto dtype = tensor_sinfo->dtype; - if (tensor_sinfo->vdevice.defined()) { - auto vdev = tensor_sinfo->vdevice.value(); + Type updated_ret_ty = producer_ty_[ffi::GetRef(call_node)]; + + if (updated_ret_ty->IsInstance()) { + auto tensor_ty = Downcast(updated_ret_ty); + auto shape = tensor_ty->shape.value(); + auto dtype = tensor_ty->dtype; + if (tensor_ty->vdevice.defined()) { + auto vdev = tensor_ty->vdevice.value(); const VDevice& vdev_global = MakeGlobalVDevice(vdev); - updated_ret_sinfo = TensorStructInfo(shape, dtype, vdev_global); + updated_ret_ty = TensorType(shape, dtype, vdev_global); } } else { - TVM_FFI_ICHECK(updated_ret_sinfo->IsInstance()) - << "Expect output struct info of call_tir to be either TupleStructInfo or " - "TensorStructInfo, but got " - << updated_ret_sinfo; - - const auto& tuple_sinfo = Downcast(updated_ret_sinfo); - ffi::Array sinfo_fields; - for (const auto& si : tuple_sinfo->fields) { - TVM_FFI_ICHECK(si->IsInstance()) - << "Fields of TupleStructInfo must be TensorStructInfo for call_tir " + TVM_FFI_ICHECK(updated_ret_ty->IsInstance()) + << "Expect output type of call_tir to be either TupleType or " + "TensorType, but got " + << updated_ret_ty; + + const auto& tuple_ty = Downcast(updated_ret_ty); + ffi::Array ty_fields; + for (const auto& si : tuple_ty->fields) { + TVM_FFI_ICHECK(si->IsInstance()) + << "Fields of TupleType must be TensorType for call_tir " "output structinfo, but got " << si; - auto sinfo = Downcast(si); + auto ty = Downcast(si); - auto shape_arr = GetShapeFromTensorStructInfo(sinfo); + auto shape_arr = GetShapeFromTensorType(ty); - auto shape = sinfo->shape.value(); - auto dtype = sinfo->dtype; - if (sinfo->vdevice.defined()) { - auto vdev = sinfo->vdevice.value(); + auto shape = ty->shape.value(); + auto dtype = ty->dtype; + if (ty->vdevice.defined()) { + auto vdev = ty->vdevice.value(); const VDevice& vdev_global = MakeGlobalVDevice(vdev); - sinfo_fields.push_back(TensorStructInfo(shape, dtype, vdev_global)); + ty_fields.push_back(TensorType(shape, dtype, vdev_global)); } else { - sinfo_fields.push_back(sinfo); + ty_fields.push_back(ty); } } - updated_ret_sinfo = TupleStructInfo(sinfo_fields); + updated_ret_ty = TupleType(ty_fields); } int arg_idx = 0; for (auto arg : func_args->fields) { - auto sinfo = GetStructInfo(arg); - if (auto tensor_sinfo = sinfo.as()) { + auto ty = GetType(arg); + if (auto tensor_ty = ty.as()) { ffi::String scope = "global"; if (call_scope_info_.find(ffi::GetRef(call_node)) != call_scope_info_.end()) { scope = call_scope_info_[ffi::GetRef(call_node)][arg_idx]; @@ -695,9 +692,9 @@ class DefineVDevice : ExprMutator { if (call->op == call_tir_op) { return builder_->Normalize( - Call(call_tir_op, {gv, Tuple(new_args)}, call->attrs, {updated_ret_sinfo})); + Call(call_tir_op, {gv, Tuple(new_args)}, call->attrs, {updated_ret_ty})); } else { - return builder_->Normalize(Call(call->op, new_args, call->attrs, {updated_ret_sinfo})); + return builder_->Normalize(Call(call->op, new_args, call->attrs, {updated_ret_ty})); } } @@ -717,12 +714,12 @@ class DefineVDevice : ExprMutator { Expr HintArg(const Expr& arg, ffi::String scope) { if (arg->IsInstance()) { - if (auto tsinfo = arg->struct_info_.as()) { - if (!tsinfo->vdevice.defined()) { + if (auto tensor_ty = arg->ty.as()) { + if (!tensor_ty->vdevice.defined()) { const VDevice& vdev = MakeGlobalVDevice(VDevice(target_, 0, scope)); - TVM_FFI_ICHECK(tsinfo->shape.defined()) << "Shape not defined for a constant tensor ..!"; - arg->struct_info_ = - TensorStructInfo(tsinfo->shape.value(), tsinfo->dtype, vdev, tsinfo->span); + TVM_FFI_ICHECK(tensor_ty->shape.defined()) + << "Shape not defined for a constant tensor ..!"; + arg->ty = TensorType(tensor_ty->shape.value(), tensor_ty->dtype, vdev, tensor_ty->span); return arg; } } @@ -738,8 +735,8 @@ class DefineVDevice : ExprMutator { return new_arg; } - ffi::Optional GetTarget(const StructInfo& sinfo) { - auto tinfo = sinfo.as(); + ffi::Optional GetTarget(const Type& ty) { + auto tinfo = ty.as(); if (tinfo->vdevice.defined()) { auto vdevice = tinfo->vdevice.value(); if (vdevice->target.defined()) { @@ -755,7 +752,7 @@ class DefineVDevice : ExprMutator { Target target_; ffi::Array vdevices_; ffi::Map>> scope_info_; - ffi::Map producer_sinfo_; + ffi::Map producer_ty_; ffi::Map> call_scope_info_; }; diff --git a/src/relax/backend/adreno/fold_vdevice_scope_change.cc b/src/relax/backend/adreno/fold_vdevice_scope_change.cc index a39d1f7cbcc0..49e5c81f0145 100644 --- a/src/relax/backend/adreno/fold_vdevice_scope_change.cc +++ b/src/relax/backend/adreno/fold_vdevice_scope_change.cc @@ -69,10 +69,10 @@ std::tuple)>> << "Attributes for relax.to_vdevice operator should be ToVDeviceAttrs, " << "but were instead " << out->attrs << " with type " << out->GetTypeKey(); - const auto* tir_out_sinfo = call_tir->sinfo_args[0].as(); - if (!tir_out_sinfo) return expr; + const auto* tir_out_ty = call_tir->ty_args[0].as(); + if (!tir_out_ty) return expr; - if (!tir_out_sinfo->vdevice.defined()) return expr; + if (!tir_out_ty->vdevice.defined()) return expr; const VarNode* arg_var = out->args[0].as(); if (consumers.find(ffi::GetRef(arg_var)) != consumers.end()) { @@ -82,14 +82,13 @@ std::tuple)>> } } - if ((std::string(tir_out_sinfo->vdevice.value()->memory_scope).find("texture") != + if ((std::string(tir_out_ty->vdevice.value()->memory_scope).find("texture") != std::string::npos) && (vdev_attrs->dst_vdevice->memory_scope == "global")) { - auto shape_arr = tir_out_sinfo->GetShape().value(); - auto new_sinfo = - TensorStructInfo(ShapeExpr(shape_arr), tir_out_sinfo->dtype, vdev_attrs->dst_vdevice); + auto shape_arr = tir_out_ty->GetShape().value(); + auto new_ty = TensorType(ShapeExpr(shape_arr), tir_out_ty->dtype, vdev_attrs->dst_vdevice); - return Call(call_tir->op, call_tir->args, call_tir->attrs, {new_sinfo}); + return Call(call_tir->op, call_tir->args, call_tir->attrs, {new_ty}); } return expr; }; @@ -144,8 +143,8 @@ class CollectConsumerDetails : public ExprVisitor { } for (auto arg : func_args->fields) { - auto sinfo = GetStructInfo(arg); - if (auto tensor_sinfo = sinfo.as()) { + auto ty = GetType(arg); + if (auto tensor_ty = ty.as()) { ffi::Array call_list; const VarNode* arg_var = arg.as(); diff --git a/src/relax/backend/contrib/codegen_c/codegen_c.h b/src/relax/backend/contrib/codegen_c/codegen_c.h index 03eb51463409..1a5fb1dd801e 100644 --- a/src/relax/backend/contrib/codegen_c/codegen_c.h +++ b/src/relax/backend/contrib/codegen_c/codegen_c.h @@ -333,9 +333,9 @@ class CodegenCBase { * \return The dtype string. */ std::string GetDtypeString(const Var& var) { - auto tsinfo = var->struct_info_.as(); - TVM_FFI_ICHECK(tsinfo) << "Expect TensorStructInfoNode"; - return GetDtypeString(tsinfo); + auto tensor_ty = var->ty.as(); + TVM_FFI_ICHECK(tensor_ty) << "Expect TensorTypeNode"; + return GetDtypeString(tensor_ty); } /*! @@ -345,24 +345,24 @@ class CodegenCBase { * * \return The dtype string. */ - std::string GetDtypeString(const TensorStructInfoNode* tsinfo) { + std::string GetDtypeString(const TensorTypeNode* tensor_ty) { std::string dtype; - if (runtime::TypeMatch(tsinfo->dtype, kDLFloat, 32)) { + if (runtime::TypeMatch(tensor_ty->dtype, kDLFloat, 32)) { dtype = "float"; - } else if (runtime::TypeMatch(tsinfo->dtype, kDLFloat, 16)) { + } else if (runtime::TypeMatch(tensor_ty->dtype, kDLFloat, 16)) { dtype = "half"; - } else if (runtime::TypeMatch(tsinfo->dtype, kDLBfloat, 16)) { + } else if (runtime::TypeMatch(tensor_ty->dtype, kDLBfloat, 16)) { dtype = "bfloat"; - } else if (runtime::TypeMatch(tsinfo->dtype, kDLInt, 32)) { + } else if (runtime::TypeMatch(tensor_ty->dtype, kDLInt, 32)) { dtype = "int"; - } else if (runtime::TypeMatch(tsinfo->dtype, kDLInt, 64)) { + } else if (runtime::TypeMatch(tensor_ty->dtype, kDLInt, 64)) { dtype = "int64_t"; - } else if (runtime::TypeMatch(tsinfo->dtype, kDLInt, 8)) { + } else if (runtime::TypeMatch(tensor_ty->dtype, kDLInt, 8)) { dtype = "int8_t"; - } else if (runtime::TypeMatch(tsinfo->dtype, kDLUInt, 8)) { + } else if (runtime::TypeMatch(tensor_ty->dtype, kDLUInt, 8)) { dtype = "uint8_t"; } else { - TVM_FFI_THROW(InternalError) << "Unsupported dtype " << tsinfo->dtype; + TVM_FFI_THROW(InternalError) << "Unsupported dtype " << tensor_ty->dtype; } return dtype; diff --git a/src/relax/backend/contrib/codegen_json/codegen_json.h b/src/relax/backend/contrib/codegen_json/codegen_json.h index 34ebdd8e9ec0..535908274b7b 100644 --- a/src/relax/backend/contrib/codegen_json/codegen_json.h +++ b/src/relax/backend/contrib/codegen_json/codegen_json.h @@ -27,7 +27,7 @@ #include #include #include -#include +#include #include #include @@ -265,7 +265,7 @@ class JSONSerializer : public relax::MemoizedExprTranslator { * will flatten it. */ NodeEntries AddNode(JSONGraphObjectPtr node, const Expr& expr) { - auto struct_info = GetStructInfo(expr); + auto ty = GetType(expr); auto node_id = nodes_.size(); nodes_.push_back(node); NodeEntries ret; @@ -273,27 +273,26 @@ class JSONSerializer : public relax::MemoizedExprTranslator { TypeVector dtype; // Flatten tuple node. - if (const auto* tuple_sinfo = struct_info.as()) { - for (size_t i = 0; i < tuple_sinfo->fields.size(); ++i) { - const auto* tensor_sinfo = tuple_sinfo->fields[i].as(); - TVM_FFI_ICHECK(tensor_sinfo) - << "Expect TensorStructInfo, but received: ." << tuple_sinfo->fields[i]->GetTypeKey(); - TVM_FFI_ICHECK(tensor_sinfo->shape.defined()) << "Expect shape to be defined."; - ShapeExpr output_shape = Downcast(tensor_sinfo->shape.value()); + if (const auto* tuple_ty = ty.as()) { + for (size_t i = 0; i < tuple_ty->fields.size(); ++i) { + const auto* tensor_ty = tuple_ty->fields[i].as(); + TVM_FFI_ICHECK(tensor_ty) << "Expect TensorType, but received: ." + << tuple_ty->fields[i]->GetTypeKey(); + TVM_FFI_ICHECK(tensor_ty->shape.defined()) << "Expect shape to be defined."; + ShapeExpr output_shape = Downcast(tensor_ty->shape.value()); ret.push_back(JSONGraphNodeEntry(node_id, i)); shape.emplace_back(GetIntShape(output_shape->values)); - dtype.emplace_back(DType2String(tensor_sinfo->dtype)); + dtype.emplace_back(DType2String(tensor_ty->dtype)); } - node->SetNumOutput(tuple_sinfo->fields.size()); + node->SetNumOutput(tuple_ty->fields.size()); } else { - const auto* tensor_sinfo = struct_info.as(); - TVM_FFI_ICHECK(tensor_sinfo) - << "Expect TensorStructInfo, but received: " << struct_info->GetTypeKey(); - TVM_FFI_ICHECK(tensor_sinfo->shape.defined()) << "Expect shape to be defined."; - ShapeExpr output_shape = Downcast(tensor_sinfo->shape.value()); + const auto* tensor_ty = ty.as(); + TVM_FFI_ICHECK(tensor_ty) << "Expect TensorType, but received: " << ty->GetTypeKey(); + TVM_FFI_ICHECK(tensor_ty->shape.defined()) << "Expect shape to be defined."; + ShapeExpr output_shape = Downcast(tensor_ty->shape.value()); shape.emplace_back(GetIntShape(output_shape->values)); - dtype.emplace_back(DType2String(tensor_sinfo->dtype)); + dtype.emplace_back(DType2String(tensor_ty->dtype)); ret.push_back(JSONGraphNodeEntry(node_id, 0)); } node->SetShape(shape); diff --git a/src/relax/backend/contrib/cublas/codegen.cc b/src/relax/backend/contrib/cublas/codegen.cc index a3e23fc71a49..714c38bda46e 100644 --- a/src/relax/backend/contrib/cublas/codegen.cc +++ b/src/relax/backend/contrib/cublas/codegen.cc @@ -84,13 +84,13 @@ class CublasJSONSerializer : public JSONSerializer { const CallNode* dequantize_call = backend::GetOpInFunction(fn, "relax.dequantize"); if (dequantize_call->args[1]->IsInstance()) { const auto* const_expr = dequantize_call->args[1].as(); - auto sinfo = Downcast(const_expr->struct_info_); + auto ty = Downcast(const_expr->ty); float alpha = 1.0; - if (sinfo->dtype == DataType::Float(16)) { + if (ty->dtype == DataType::Float(16)) { alpha = __extendXfYf2__( static_cast(const_expr->data->data)[0]); } else { - TVM_FFI_ICHECK(sinfo->dtype == DataType::Float(32)); + TVM_FFI_ICHECK(ty->dtype == DataType::Float(32)); alpha = static_cast(const_expr->data->data)[0]; } diff --git a/src/relax/backend/contrib/cudnn/codegen.cc b/src/relax/backend/contrib/cudnn/codegen.cc index 8526b4d03eed..7af44bf0e74d 100644 --- a/src/relax/backend/contrib/cudnn/codegen.cc +++ b/src/relax/backend/contrib/cudnn/codegen.cc @@ -106,12 +106,9 @@ class cuDNNJSONSerializer : public JSONSerializer { "kernel", /* op_type_ */ inputs, 1 /* num_outputs_ */); const CallNode* root_call = backend::GetOpInFunction(fn, "relax.nn.attention"); - auto q_shape = Downcast( - Downcast(root_call->args[0]->struct_info_.value())->shape.value()); - auto k_shape = Downcast( - Downcast(root_call->args[1]->struct_info_.value())->shape.value()); - auto v_shape = Downcast( - Downcast(root_call->args[2]->struct_info_.value())->shape.value()); + auto q_shape = Downcast(Downcast(root_call->args[0]->ty)->shape.value()); + auto k_shape = Downcast(Downcast(root_call->args[1]->ty)->shape.value()); + auto v_shape = Downcast(Downcast(root_call->args[2]->ty)->shape.value()); int num_heads = q_shape->values[2].as()->value; int num_kv_heads = k_shape->values[2].as()->value; int head_size = q_shape->values[3].as()->value; diff --git a/src/relax/backend/contrib/cutlass/codegen.cc b/src/relax/backend/contrib/cutlass/codegen.cc index 6de72397dc52..c8c131e04671 100644 --- a/src/relax/backend/contrib/cutlass/codegen.cc +++ b/src/relax/backend/contrib/cutlass/codegen.cc @@ -165,11 +165,11 @@ class CodegenCutlass : public relax::MemoizedExprTranslator, std::vector arg_types, arg_names; for (const auto& arg : ext_func_args_) { - auto sinfo = GetStructInfo(arg); - if (const auto* tensor_sinfo = sinfo.as()) { - arg_types.emplace_back(backend::DType2String(tensor_sinfo->dtype)); - } else if (const auto* shape_sinfo = sinfo.as()) { - arg_types.emplace_back(backend::DType2String(shape_sinfo->values.value()[0]->dtype)); + auto ty = GetType(arg); + if (const auto* tensor_ty = ty.as()) { + arg_types.emplace_back(backend::DType2String(tensor_ty->dtype)); + } else if (const auto* shape_ty = ty.as()) { + arg_types.emplace_back(backend::DType2String(shape_ty->values.value()[0]->dtype)); } else { TVM_FFI_THROW(InternalError) << "Unimplemented"; } @@ -298,13 +298,13 @@ class CodegenCutlass : public relax::MemoizedExprTranslator, GenerateBodyOutput GenerateBody(const CallNode* call, const std::string& func_name, const ffi::Map& attrs) { auto func_args = GetArgumentNames(call); - auto struct_info = GetStructInfo(ffi::GetRef(call)); + auto ty = GetType(ffi::GetRef(call)); std::vector out_types; - if (const auto* tensor_sinfo = struct_info.as()) { - out_types.emplace_back(backend::DType2String(tensor_sinfo->dtype)); + if (const auto* tensor_ty = ty.as()) { + out_types.emplace_back(backend::DType2String(tensor_ty->dtype)); } else { - TVM_FFI_THROW(InternalError) << "Unimplemented sinfo type: " << struct_info; + TVM_FFI_THROW(InternalError) << "Unimplemented ty type: " << ty; } return contrib::GenerateBody(func_name, ext_func_id_, out_types, func_args, attrs, &buf_idx_); diff --git a/src/relax/backend/contrib/tensorrt/codegen.cc b/src/relax/backend/contrib/tensorrt/codegen.cc index 78ed6fbc4e63..3dd7bf323149 100644 --- a/src/relax/backend/contrib/tensorrt/codegen.cc +++ b/src/relax/backend/contrib/tensorrt/codegen.cc @@ -143,9 +143,9 @@ class CollectFromCompositeFunctionBody : public ExprVisitor { void MaybeFillReduceAxes(const CallNode* call_node) { const auto* attrs = call_node->attrs.as(); if (attrs == nullptr || attrs->axis.has_value()) return; - const auto* tensor_sinfo = GetStructInfo(call_node->args[0]).as(); - if (tensor_sinfo == nullptr || !tensor_sinfo->shape.defined()) return; - const auto* shape = tensor_sinfo->shape.value().as(); + const auto* tensor_ty = GetType(call_node->args[0]).as(); + if (tensor_ty == nullptr || !tensor_ty->shape.defined()) return; + const auto* shape = tensor_ty->shape.value().as(); if (shape == nullptr) return; ffi::Array all_axes; for (size_t i = 0; i < shape->values.size(); ++i) all_axes.push_back(static_cast(i)); diff --git a/src/relax/backend/vm/codegen_vm.cc b/src/relax/backend/vm/codegen_vm.cc index e9cb175fdcc7..0e4831af2177 100644 --- a/src/relax/backend/vm/codegen_vm.cc +++ b/src/relax/backend/vm/codegen_vm.cc @@ -142,7 +142,7 @@ class CodeGenVM : public ExprFunctor { } // allocate dst register. - RegName dst_reg = HasVoidStructInfo(call) ? Instruction::kVoidRegister : NewRegister(); + RegName dst_reg = HasVoidType(call) ? Instruction::kVoidRegister : NewRegister(); if (call->op.as()) { if (call_node->op == call_builtin_with_ctx_op_) { // TODO(relax-team) migrate most handling of op to @@ -220,9 +220,9 @@ class CodeGenVM : public ExprFunctor { Instruction::Arg VisitExpr_(const ConstantNode* op) final { auto arg = builder_->ConvertConstant(op->data); - if (auto tsinfo = op->struct_info_.as()) { - if (tsinfo->vdevice.defined()) { - VDevice vdev = tsinfo->vdevice.value(); + if (auto tensor_ty = op->ty.as()) { + if (tensor_ty->vdevice.defined()) { + VDevice vdev = tensor_ty->vdevice.value(); builder_->SaveMemoryScope(arg, vdev->memory_scope); } } diff --git a/src/relax/backend/vm/codegen_vm_tir.cc b/src/relax/backend/vm/codegen_vm_tir.cc index e93c2ee199db..508ed1a5fcf3 100644 --- a/src/relax/backend/vm/codegen_vm_tir.cc +++ b/src/relax/backend/vm/codegen_vm_tir.cc @@ -233,7 +233,7 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { if (call_node->op == null_value_op_) { return tirx::Call(DataType::Handle(), tirx::builtin::reinterpret(), {IntImm::Int64(0)}); } - int64_t dst_reg = HasVoidStructInfo(call) ? -1 : NewRegister(); + int64_t dst_reg = HasVoidType(call) ? -1 : NewRegister(); if (call->op.as()) { if (call_node->op == call_builtin_with_ctx_op_) { EmitCallBuiltinWithCtx(call, dst_reg); diff --git a/src/relax/backend/vm/lower_runtime_builtin.cc b/src/relax/backend/vm/lower_runtime_builtin.cc index 093ddc3c9916..c9e531790d0f 100644 --- a/src/relax/backend/vm/lower_runtime_builtin.cc +++ b/src/relax/backend/vm/lower_runtime_builtin.cc @@ -118,39 +118,39 @@ class LowerRuntimeBuiltinMutator : public ExprMutator { for (Expr arg : tir_args->fields) { args.push_back(arg); } - return Call(builtin_call_tir_dyn_, args, Attrs(), {void_sinfo_}); + return Call(builtin_call_tir_dyn_, args, Attrs(), {void_ty_}); } Expr Reshape(const Call& call_node) { TVM_FFI_ICHECK(call_node->args.size() == 2); - TVM_FFI_ICHECK(call_node->struct_info_.defined()); + TVM_FFI_ICHECK(call_node->ty.defined()); auto arg = call_node->args[1]; - TVM_FFI_CHECK(arg->struct_info_->IsInstance(), TypeError) + TVM_FFI_CHECK(arg->ty->IsInstance(), TypeError) << "VMBuiltinLower expects the shape arg of R.reshape " << "to be a ShapeExpr or VarNode bound to a ShapeExpr. " - << "However, in expression " << call_node << ", the shape argument " << arg - << " has struct info " << arg->struct_info_; + << "However, in expression " << call_node << ", the shape argument " << arg << " has type " + << arg->ty; - return Call(builtin_reshape_, call_node->args, Attrs(), {GetStructInfo(call_node)}); + return Call(builtin_reshape_, call_node->args, Attrs(), {GetType(call_node)}); } Expr ShapeOf(const Call& call_node) { TVM_FFI_ICHECK(call_node->args.size() == 1); - TVM_FFI_ICHECK(call_node->struct_info_.defined()); - return Call(builtin_shape_of_, call_node->args, Attrs(), {GetStructInfo(call_node)}); + TVM_FFI_ICHECK(call_node->ty.defined()); + return Call(builtin_shape_of_, call_node->args, Attrs(), {GetType(call_node)}); } Expr TensorToShape(const Call& call_node) { TVM_FFI_ICHECK(call_node->args.size() == 1); - TVM_FFI_ICHECK(call_node->struct_info_.defined()); + TVM_FFI_ICHECK(call_node->ty.defined()); - return Call(builtin_tensor_to_shape_, call_node->args, Attrs(), {GetStructInfo(call_node)}); + return Call(builtin_tensor_to_shape_, call_node->args, Attrs(), {GetType(call_node)}); } Expr CallPyFunc(const Call& call_node) { TVM_FFI_ICHECK(call_node->args.size() == 2); - TVM_FFI_ICHECK(call_node->struct_info_.defined()); + TVM_FFI_ICHECK(call_node->ty.defined()); // Create tuple with function name and arguments tuple ffi::Array tuple_fields; @@ -159,14 +159,14 @@ class LowerRuntimeBuiltinMutator : public ExprMutator { auto combined_tuple = Tuple(tuple_fields); // Direct call to vm.builtin.call_py_func - return Call(builtin_call_py_func_, {combined_tuple}, call_node->attrs, call_node->sinfo_args, + return Call(builtin_call_py_func_, {combined_tuple}, call_node->attrs, call_node->ty_args, call_node->span); } Expr ToDevice(const Call& call_node) { // TODO(yongwww): replace ToVDeviceAttrs with related Expr TVM_FFI_ICHECK(call_node->args.size() == 1); - TVM_FFI_ICHECK(call_node->struct_info_.defined()); + TVM_FFI_ICHECK(call_node->ty.defined()); auto attrs = call_node->attrs.as(); ffi::Array args; args.push_back(call_node->args[0]); @@ -178,7 +178,7 @@ class LowerRuntimeBuiltinMutator : public ExprMutator { args.push_back(PrimValue::Int64(dev_type)); args.push_back(PrimValue::Int64(dev_id)); args.push_back(storage_scope); - return Call(builtin_to_device_, args, call_node->attrs, {GetStructInfo(call_node)}); + return Call(builtin_to_device_, args, call_node->attrs, {GetType(call_node)}); } Expr MakeClosure(const Call& call_node) { @@ -195,7 +195,7 @@ class LowerRuntimeBuiltinMutator : public ExprMutator { args.push_back(arg); } - return Call(builtin_make_closure_, args, Attrs(), {object_sinfo_}); + return Call(builtin_make_closure_, args, Attrs(), {object_ty_}); } Expr InvokeClosure(const Call& call_node) { @@ -213,12 +213,12 @@ class LowerRuntimeBuiltinMutator : public ExprMutator { args.push_back(arg); } return Call(call_builtin_with_ctx_op_, {builtin_invoke_closure_, Tuple(args)}, Attrs(), - {object_sinfo_}); + {object_ty_}); } const Op& call_builtin_with_ctx_op_ = Op::Get("relax.call_builtin_with_ctx"); - const StructInfo object_sinfo_ = ObjectStructInfo(); - const StructInfo void_sinfo_ = TupleStructInfo(ffi::Array({})); + const Type object_ty_ = ObjectType(); + const Type void_ty_ = TupleType(ffi::Array({})); // object to pattern match. const Op& call_tir_dyn_op_ = Op::Get("relax.vm.call_tir_dyn"); const Op& reshape_op_ = Op::Get("relax.reshape"); diff --git a/src/relax/backend/vm/vm_shape_lower.cc b/src/relax/backend/vm/vm_shape_lower.cc index 36da54849045..9ba34d694581 100644 --- a/src/relax/backend/vm/vm_shape_lower.cc +++ b/src/relax/backend/vm/vm_shape_lower.cc @@ -25,8 +25,8 @@ #include #include #include -#include -#include +#include +#include #include #include #include @@ -73,7 +73,7 @@ using PrimExprSlotMap = std::unordered_map; // Collector to collect PrimExprSlotMap -class PrimExprSlotCollector : public ExprVisitor, public StructInfoVisitor { +class PrimExprSlotCollector : public ExprVisitor, public TypeVisitor { public: // collect the PrimExpr slot for a given function static void Collect(Function func, std::vector>* slot_vec, @@ -83,11 +83,11 @@ class PrimExprSlotCollector : public ExprVisitor, public StructInfoVisitor { collector.slot_map_ = slot_map; // collect shape declaration in func params for (auto param : func->params) { - collector.VisitStructInfo(GetStructInfo(param)); + collector.VisitType(GetType(param)); collector.VisitExpr(param); } collector.VisitExpr(func->body); - collector.VisitStructInfo(func->ret_struct_info); + collector.VisitType(func->ret_ty); } private: @@ -103,22 +103,22 @@ class PrimExprSlotCollector : public ExprVisitor, public StructInfoVisitor { } void VisitBinding_(const MatchCastNode* op) final { - // Visit the match cast struct info so we can define + // Visit the match cast type so we can define // the symbolic variables here. - this->VisitStructInfo(op->struct_info); + this->VisitType(op->ty); } void VisitExpr_(const FunctionNode* op) final { // Do not recurse into function node as it is self-contained } - void VisitStructInfo_(const FuncStructInfoNode* op) final { - // Do not recurse into function struct info as it is self-contained + void VisitType_(const FuncTypeNode* op) final { + // Do not recurse into function type as it is self-contained } - void VisitStructInfoExprField(const PrimExpr& expr) final { VisitPrimExpr(expr); } + void VisitTypeExprField(const PrimExpr& expr) final { VisitPrimExpr(expr); } - void VisitStructInfoExprField(const Expr& expr) final { ExprVisitor::VisitExpr(expr); } + void VisitTypeExprField(const Expr& expr) final { ExprVisitor::VisitExpr(expr); } std::vector>* slot_vec_; PrimExprSlotMap* slot_map_; @@ -145,7 +145,7 @@ class PrimExprSlotCollector : public ExprVisitor, public StructInfoVisitor { * * Steps at each matching point: * - Step 0: We call CheckMatchCast, - * which will recursively unpack the StructInfo, and generate static information checks. + * which will recursively unpack the Type, and generate static information checks. * Note that this step only generates functions for checking types and ndim info, but not * the symbolic shape variables. The symbolic shape-matching results will be returned as * vector. This is because symbolic shape matching may not be completed @@ -201,8 +201,8 @@ class PrimExprSlotCollector : public ExprVisitor, public StructInfoVisitor { */ class VMShapeLowerMutator : public ExprMutator, - public StructInfoFunctor*)> { + public TypeFunctor*)> { public: static IRModule Lower(IRModule mod, bool emit_err_ctx) { VMShapeLowerMutator mutator(mod, emit_err_ctx); @@ -252,11 +252,11 @@ class VMShapeLowerMutator num_input = static_cast(opt_num_input.value()); } for (size_t i = 0; i < func->params.size(); ++i) { - StructInfo sinfo = GetStructInfo(func->params[i]); + Type ty = GetType(func->params[i]); std::ostringstream err_ctx; err_ctx << "ErrorContext(fn=" << gvar->name_hint << ", loc=param[" << i - << "], param=" << func->params[i]->name_hint() << ", annotation=" << sinfo << ") "; - this->CheckMatchCast(sinfo, func->params[i], true, i >= num_input, err_ctx.str(), + << "], param=" << func->params[i]->name_hint() << ", annotation=" << ty << ") "; + this->CheckMatchCast(ty, func->params[i], true, i >= num_input, err_ctx.str(), &match_todos); } // insert heap generation logic. @@ -277,11 +277,10 @@ class VMShapeLowerMutator builder_->BeginBindingBlock(); std::ostringstream err_ctx; err_ctx << "ErrorContext(fn=" << gvar->name_hint - << ", loc=return, annotation=" << func->ret_struct_info << ") "; + << ", loc=return, annotation=" << func->ret_ty << ") "; std::vector match_todos; // NOTE: the return value's shape computation must already be defined. - this->CheckMatchCast(func->ret_struct_info, body_seq->body, false, false, err_ctx.str(), - &match_todos); + this->CheckMatchCast(func->ret_ty, body_seq->body, false, false, err_ctx.str(), &match_todos); // NOTE: the return value's shape computation must already be defined. this->RunMatch(match_todos, true); BindingBlock post_block = builder_->EndBlock(); @@ -293,7 +292,7 @@ class VMShapeLowerMutator current_gvar_ = std::nullopt; // create a new function - return Function(func->params, new_body, func->ret_struct_info, func->is_pure, func->attrs); + return Function(func->params, new_body, func->ret_ty, func->is_pure, func->attrs); } //------------------------------------------------------- @@ -330,17 +329,17 @@ class VMShapeLowerMutator VarBinding AllocShapeHeapBinding(IntImm heap_size) { if (heap_size->value > 0) { - TensorStructInfo heap_sinfo(ShapeDType(), 1); - Var var("shape_heap", heap_sinfo); + TensorType heap_ty(ShapeDType(), 1); + Var var("shape_heap", heap_ty); // set up the builtin func. Call call(call_builtin_with_ctx_op_, - {builtin_alloc_shape_heap_, Tuple({PrimValue(heap_size)})}, Attrs(), {heap_sinfo}); - UpdateStructInfo(call, heap_sinfo); + {builtin_alloc_shape_heap_, Tuple({PrimValue(heap_size)})}, Attrs(), {heap_ty}); + UpdateType(call, heap_ty); return VarBinding(var, call); } else { - Var var("shape_heap", ObjectStructInfo()); + Var var("shape_heap", ObjectType()); Call call(null_value_op_, {}); - UpdateStructInfo(call, ObjectStructInfo()); + UpdateType(call, ObjectType()); return VarBinding(var, call); } } @@ -386,7 +385,7 @@ class VMShapeLowerMutator args.push_back(value_or_index); // make_shape(heap, n, c[0], r[0], c[1], r[1] ..., c[n], r[n]) - Call call(builtin_make_prim_value_, args, Attrs(), {Downcast(op->struct_info_)}); + Call call(builtin_make_prim_value_, args, Attrs(), {Downcast(op->ty)}); return call; } @@ -409,8 +408,7 @@ class VMShapeLowerMutator } // make_shape(heap, n, c[0], r[0], c[1], r[1] ..., c[n], r[n]) - Call call(builtin_make_shape_, args, Attrs(), - {ShapeStructInfo(static_cast(op->values.size()))}); + Call call(builtin_make_shape_, args, Attrs(), {ShapeType(static_cast(op->values.size()))}); return call; } @@ -418,9 +416,9 @@ class VMShapeLowerMutator Expr value = ExprMutator::VisitExpr(binding->value); std::vector match_todos; std::ostringstream err_ctx; - err_ctx << "ErrorContext(match_cast, struct_info=" << binding->struct_info << ") "; + err_ctx << "ErrorContext(match_cast, ty=" << binding->ty << ") "; // always_check=false - this->CheckMatchCast(binding->struct_info, value, false, false, err_ctx.str(), &match_todos); + this->CheckMatchCast(binding->ty, value, false, false, err_ctx.str(), &match_todos); match_todos = this->RunMatch(match_todos, false); this->EmitOutstandingPrimExprCompute(); @@ -431,12 +429,12 @@ class VMShapeLowerMutator ExprMutator::VisitBinding_(binding); } - // Do not override shape in struct info fields + // Do not override shape in type fields // We only override the shape that are already part of the normal function values // If future passes lift those values out into the values, // then codegen may not be able to handle symbolic values. // Place this pass as last pass before codegen. - StructInfo VisitExprDepStructInfoField(const StructInfo& sinfo) final { return sinfo; } + Type VisitExprDepTypeField(const Type& ty) final { return ty; } /* \brief Internal utility function used for RunMatch() * @@ -507,7 +505,7 @@ class VMShapeLowerMutator ffi::Array args = {item.input, shape_heap_}; Expr match_op; - if (item.input->struct_info_.as()) { + if (item.input->ty.as()) { match_op = builtin_match_prim_value_; TVM_FFI_ICHECK_EQ(item.pattern.size(), 1); } else { @@ -527,7 +525,7 @@ class VMShapeLowerMutator } args.push_back(GetErrContext(item.err_ctx)); if (!all_nop) { - Call call(match_op, args, Attrs(), {void_sinfo_}); + Call call(match_op, args, Attrs(), {void_ty_}); builder_->Emit(call, "_"); } } @@ -607,53 +605,47 @@ class VMShapeLowerMutator return to_compute.size(); } //------------------------------------------------------- - // StructInfo value match logic + // Type value match logic // // CheckMatchCast is the only function needed by // other code sections //------------------------------------------------------- /*! - * \brief Insert runtime check of the match cast condition(value, struct_info). + * \brief Insert runtime check of the match cast condition(value, ty). * - * \param struct_info The struct info to be matched. + * \param ty The type to be matched. * \param value The input value. * \param always_check Whether we insert runtime check even if we can prove - * that value's struct info already satisfies the condition. + * that value's type already satisfies the condition. * This option is necessary for argument checking per our calling convention. * \param dynamic_only Whether we only check values with dynamic shapes. * \param err_ctx Extra error context to bring more informative error reporting. * \param match_todos List of match shape todo items collected when recursively * visit the match cast. */ - void CheckMatchCast(const StructInfo& struct_info, Expr value, bool always_check, - bool dynamic_only, const ffi::String& err_ctx, - std::vector* match_todos) { - return this->VisitStructInfo(struct_info, value, always_check, dynamic_only, err_ctx, - match_todos); + void CheckMatchCast(const Type& ty, Expr value, bool always_check, bool dynamic_only, + const ffi::String& err_ctx, std::vector* match_todos) { + return this->VisitType(ty, value, always_check, dynamic_only, err_ctx, match_todos); } - void VisitStructInfo(const StructInfo& struct_info, Expr value, bool always_check, - bool dynamic_only, const ffi::String& err_ctx, - std::vector* match_todos) final { - // short-cut, if the struct info already satisfies the + void VisitType(const Type& ty, Expr value, bool always_check, bool dynamic_only, + const ffi::String& err_ctx, std::vector* match_todos) final { + // short-cut, if the type already satisfies the // constraint during match cast, we can skip matching - if (!always_check && IsBaseOf(struct_info, GetStructInfo(value))) return; - return StructInfoFunctor::VisitStructInfo(struct_info, value, always_check, dynamic_only, - err_ctx, match_todos); + if (!always_check && IsBaseOf(ty, GetType(value))) return; + return TypeFunctor::VisitType(ty, value, always_check, dynamic_only, err_ctx, match_todos); } - void VisitStructInfo_(const ObjectStructInfoNode* op, Expr value, bool always_check, - bool dynamic_only, const ffi::String& err_ctx, - std::vector* match_todos) final {} + void VisitType_(const ObjectTypeNode* op, Expr value, bool always_check, bool dynamic_only, + const ffi::String& err_ctx, std::vector* match_todos) final {} - void VisitStructInfo_(const PrimStructInfoNode* op, Expr value, bool always_check, - bool dynamic_only, const ffi::String& err_ctx, - std::vector* match_todos) final { + void VisitType_(const PrimTypeNode* op, Expr value, bool always_check, bool dynamic_only, + const ffi::String& err_ctx, std::vector* match_todos) final { // emit runtime check of shape - if (always_check || !IsBaseOf(PrimStructInfo(op->dtype), GetStructInfo(value))) { + if (always_check || !IsBaseOf(PrimType(op->dtype), GetType(value))) { // check_shape_info(value, ndim, err_ctx) Call call(builtin_check_prim_value_info_, - {value, DataTypeImm(op->dtype), GetErrContext(err_ctx)}, Attrs(), {void_sinfo_}); + {value, DataTypeImm(op->dtype), GetErrContext(err_ctx)}, Attrs(), {void_ty_}); builder_->Emit(call, "_"); } if (op->value.defined()) { @@ -665,15 +657,13 @@ class VMShapeLowerMutator } } - void VisitStructInfo_(const ShapeStructInfoNode* op, Expr value, bool always_check, - bool dynamic_only, const ffi::String& err_ctx, - std::vector* match_todos) final { + void VisitType_(const ShapeTypeNode* op, Expr value, bool always_check, bool dynamic_only, + const ffi::String& err_ctx, std::vector* match_todos) final { // emit runtime check of shape - if (always_check || !IsBaseOf(ShapeStructInfo(op->ndim), GetStructInfo(value))) { + if (always_check || !IsBaseOf(ShapeType(op->ndim), GetType(value))) { // check_shape_info(value, ndim, err_ctx) Call call(builtin_check_shape_info_, - {value, PrimValue::Int64(op->ndim), GetErrContext(err_ctx)}, Attrs(), - {void_sinfo_}); + {value, PrimValue::Int64(op->ndim), GetErrContext(err_ctx)}, Attrs(), {void_ty_}); builder_->Emit(call, "_"); } if (op->values.defined()) { @@ -685,9 +675,8 @@ class VMShapeLowerMutator } } - void VisitStructInfo_(const TensorStructInfoNode* op, Expr value, bool always_check, - bool dynamic_only, const ffi::String& err_ctx, - std::vector* match_todos) final { + void VisitType_(const TensorTypeNode* op, Expr value, bool always_check, bool dynamic_only, + const ffi::String& err_ctx, std::vector* match_todos) final { // emit runtime check of shape auto* shape_expr = op->shape.as(); if (dynamic_only && @@ -696,11 +685,11 @@ class VMShapeLowerMutator // if we only check dynamic shapes, and the shape is static, we can skip. return; } - if (always_check || !IsBaseOf(TensorStructInfo(op->dtype, op->ndim), GetStructInfo(value))) { + if (always_check || !IsBaseOf(TensorType(op->dtype, op->ndim), GetType(value))) { // check_tensor_info(value, ndim, dtype, err_ctx) Call call(builtin_check_tensor_info_, {value, PrimValue::Int64(op->ndim), DataTypeImm(op->dtype), GetErrContext(err_ctx)}, - Attrs(), {void_sinfo_}); + Attrs(), {void_ty_}); builder_->Emit(call, "_"); } @@ -722,25 +711,24 @@ class VMShapeLowerMutator // Internal helper function to make tuple get item. // This function will try to simplify constant tuples - // the return value **always** have struct info. + // the return value **always** have type. Expr MakeTupleGetItem(Expr value, int64_t index) { if (auto* tuple_expr = value.as()) { return tuple_expr->fields[index]; - } else if (GetStructInfoAs(value)) { + } else if (GetTypeAs(value)) { // value is tuple type, it is OK to run tuple get item. return TupleGetItem(value, index); } else { // call runtime tuple get item, and return a object. - Call call(builtin_tuple_getitem_, {value, PrimValue::Int64(index)}, Attrs(), {object_sinfo_}); - UpdateStructInfo(call, ObjectStructInfo()); + Call call(builtin_tuple_getitem_, {value, PrimValue::Int64(index)}, Attrs(), {object_ty_}); + UpdateType(call, ObjectType()); return call; } } - void VisitStructInfo_(const TupleStructInfoNode* op, Expr value, bool always_check, - bool dynamic_only, const ffi::String& err_ctx, - std::vector* match_todos) final { - auto* value_tinfo = GetStructInfoAs(value); + void VisitType_(const TupleTypeNode* op, Expr value, bool always_check, bool dynamic_only, + const ffi::String& err_ctx, std::vector* match_todos) final { + auto* value_tinfo = GetTypeAs(value); if (value_tinfo) { TVM_FFI_CHECK_EQ(value_tinfo->fields.size(), op->fields.size(), TypeError) << err_ctx << " during match-cast we find tuple size mismatch"; @@ -750,23 +738,22 @@ class VMShapeLowerMutator Call call(builtin_check_tuple_info_, {value, PrimValue::Int64(static_cast(op->fields.size())), GetErrContext(err_ctx)}, - Attrs(), {void_sinfo_}); + Attrs(), {void_ty_}); builder_->Emit(call, "_"); } // recursively visit each sub-field and run matching for (size_t i = 0; i < op->fields.size(); ++i) { - this->VisitStructInfo(op->fields[i], MakeTupleGetItem(value, i), always_check, dynamic_only, - err_ctx, match_todos); + this->VisitType(op->fields[i], MakeTupleGetItem(value, i), always_check, dynamic_only, + err_ctx, match_todos); } } - void VisitStructInfo_(const FuncStructInfoNode* op, Expr value, bool always_check, - bool dynamic_only, const ffi::String& err_ctx, - std::vector* match_todos) final { + void VisitType_(const FuncTypeNode* op, Expr value, bool always_check, bool dynamic_only, + const ffi::String& err_ctx, std::vector* match_todos) final { // we only check function is callable. - if (!always_check && MatchStructInfo(value)) return; + if (!always_check && MatchType(value)) return; // check_func_info(value, err_ctx) - Call call(builtin_check_func_info_, {value, GetErrContext(err_ctx)}, Attrs(), {void_sinfo_}); + Call call(builtin_check_func_info_, {value, GetErrContext(err_ctx)}, Attrs(), {void_ty_}); builder_->Emit(call, "_"); } @@ -792,9 +779,9 @@ class VMShapeLowerMutator // call builtin cop const Op& call_builtin_with_ctx_op_ = Op::Get("relax.call_builtin_with_ctx"); const Op& null_value_op_ = Op::Get("relax.null_value"); - // common struct info - const StructInfo object_sinfo_ = ObjectStructInfo(); - const StructInfo void_sinfo_ = TupleStructInfo(ffi::Array({})); + // common type + const Type object_ty_ = ObjectType(); + const Type void_ty_ = TupleType(ffi::Array({})); // check function const ExternFunc builtin_alloc_shape_heap_{"vm.builtin.alloc_shape_heap"}; const ExternFunc builtin_match_shape_{"vm.builtin.match_shape"}; diff --git a/src/relax/distributed/axis_group_graph.cc b/src/relax/distributed/axis_group_graph.cc index 1252c46ee6af..d00515f458cb 100644 --- a/src/relax/distributed/axis_group_graph.cc +++ b/src/relax/distributed/axis_group_graph.cc @@ -64,23 +64,23 @@ namespace tvm { namespace relax { namespace distributed { -const TensorStructInfoNode* GetTensorStructInfo(Expr tensor) { - const auto* tensor_sinfo = GetStructInfoAs(tensor); - if (tensor_sinfo) { - return tensor_sinfo; +const TensorTypeNode* GetTensorType(Expr tensor) { + const auto* tensor_ty = GetTypeAs(tensor); + if (tensor_ty) { + return tensor_ty; } - const auto* dtensor_sinfo = GetStructInfoAs(tensor); - if (dtensor_sinfo) { - return dtensor_sinfo->tensor_sinfo.get(); + const auto* dtensor_ty = GetTypeAs(tensor); + if (dtensor_ty) { + return dtensor_ty->tensor_ty.get(); } TVM_FFI_THROW(InternalError) << tensor << " must be either Tensor or DTesor"; throw; } void UnaryOpHelper(ffi::Array tensor_list, distributed::AxisGroupGraph* axis_group_graph) { - int n_dim = GetTensorStructInfo(tensor_list[0])->ndim; + int n_dim = GetTensorType(tensor_list[0])->ndim; for (const auto& tensor : tensor_list) { - TVM_FFI_ICHECK(GetTensorStructInfo(tensor)->ndim == n_dim); + TVM_FFI_ICHECK(GetTensorType(tensor)->ndim == n_dim); } for (int i = 0; i < n_dim; i++) { TVM_FFI_ICHECK(tensor_list.size() <= 2); @@ -104,12 +104,10 @@ void BuildAxisGraphUnary(const Var& output_var, const Call& call, void BuildAxisGraphBinary(const Var& output_var, const Call& call, distributed::AxisGroupGraph* axis_group_graph) { ffi::Array tensor_list; // vars in param and output - if (call->args[0]->struct_info_.as() || - call->args[0]->struct_info_.as()) { + if (call->args[0]->ty.as() || call->args[0]->ty.as()) { tensor_list.push_back(call->args[0]); } - if (call->args[1]->struct_info_.as() || - call->args[1]->struct_info_.as()) { + if (call->args[1]->ty.as() || call->args[1]->ty.as()) { tensor_list.push_back(call->args[1]); } tensor_list.push_back(output_var); @@ -117,12 +115,12 @@ void BuildAxisGraphBinary(const Var& output_var, const Call& call, UnaryOpHelper(tensor_list, axis_group_graph); return; } - const auto* x1_sinfo = GetTensorStructInfo(tensor_list[0]); - const auto* x2_sinfo = GetTensorStructInfo(tensor_list[1]); - int x1_ndim = x1_sinfo->ndim; - int x2_ndim = x2_sinfo->ndim; - const auto* x1_shape = x1_sinfo->shape.as(); - const auto* x2_shape = x2_sinfo->shape.as(); + const auto* x1_ty = GetTensorType(tensor_list[0]); + const auto* x2_ty = GetTensorType(tensor_list[1]); + int x1_ndim = x1_ty->ndim; + int x2_ndim = x2_ty->ndim; + const auto* x1_shape = x1_ty->shape.as(); + const auto* x2_shape = x2_ty->shape.as(); TVM_FFI_ICHECK(x1_shape && x2_shape); arith::Analyzer analyzer; for (int i = 1; i <= std::min(x1_ndim, x2_ndim); ++i) { @@ -178,7 +176,7 @@ void BuildAxisGraphReduce(const Var& output_var, const Call& call, TVM_FFI_THROW(InternalError) << "Unsupported reduce op: " << call->op; } - int ndim = GetTensorStructInfo(input_tensor)->ndim; + int ndim = GetTensorType(input_tensor)->ndim; std::unordered_set normalized_axes; for (int64_t i : axes) { @@ -212,10 +210,10 @@ void BuildAxisGraphMatmul(const Var& output_var, const Call& call, Expr x1 = call->args[0]; Expr x2 = call->args[1]; Var x3 = output_var; - const auto* x1_sinfo = GetTensorStructInfo(x1); - const auto* x2_sinfo = GetTensorStructInfo(x2); - int x1_ndim = x1_sinfo->ndim; - int x2_ndim = x2_sinfo->ndim; + const auto* x1_ty = GetTensorType(x1); + const auto* x2_ty = GetTensorType(x2); + int x1_ndim = x1_ty->ndim; + int x2_ndim = x2_ty->ndim; TVM_FFI_ICHECK(x1_ndim > 0 && x2_ndim > 0); int x1_prepended = 0; int x2_appended = 0; @@ -227,8 +225,8 @@ void BuildAxisGraphMatmul(const Var& output_var, const Call& call, x2_ndim = 2; x2_appended = 1; } - const auto* x1_shape = x1_sinfo->shape.as(); - const auto* x2_shape = x2_sinfo->shape.as(); + const auto* x1_shape = x1_ty->shape.as(); + const auto* x2_shape = x2_ty->shape.as(); TVM_FFI_ICHECK(x1_shape && x2_shape); ffi::Array x1_shape_prefix{x1_shape->values.begin(), x1_shape->values.end() - 2 + x1_prepended}; @@ -262,7 +260,7 @@ void BuildAxisGraphMatmul(const Var& output_var, const Call& call, } } // join reduction dim - axis_group_graph->JoinAxis({x1.get(), x1_sinfo->ndim - 1}, {x2.get(), x2_ndim - 2}, + axis_group_graph->JoinAxis({x1.get(), x1_ty->ndim - 1}, {x2.get(), x2_ndim - 2}, distributed::AxisGroupGraph::EdgeType::kSimbling); // join lhs_spatial dim and rhs_spatial dim if (!x1_prepended) { @@ -286,7 +284,7 @@ void BuildAxisGraphPermuteDims(const Var& output_var, const Call& call, Expr input_tensor = call->args[0]; const auto* attrs = call->attrs.as(); TVM_FFI_ICHECK(attrs); - int ndim = GetTensorStructInfo(input_tensor)->ndim; + int ndim = GetTensorType(input_tensor)->ndim; std::vector normalized_axes; if (attrs->axes.defined()) { for (int64_t i : attrs->axes.value()) { @@ -309,12 +307,12 @@ void BuildAxisGraphPermuteDims(const Var& output_var, const Call& call, void BuildAxisGraphReshape(const Var& output_var, const Call& call, distributed::AxisGroupGraph* axis_group_graph) { Expr input_tensor = call->args[0]; - const auto* tensor_sinfo = GetTensorStructInfo(input_tensor); - const auto* new_shape_sinfo = GetStructInfoAs(call->args[1]); - const auto* old_shape_sinfo = GetStructInfoAs(tensor_sinfo->shape.value()); - TVM_FFI_ICHECK_NOTNULL(old_shape_sinfo); - ffi::Array old_shape_values = old_shape_sinfo->values.value(); - ffi::Array new_shape_values = new_shape_sinfo->values.value(); + const auto* tensor_ty = GetTensorType(input_tensor); + const auto* new_shape_ty = GetTypeAs(call->args[1]); + const auto* old_shape_ty = GetTypeAs(tensor_ty->shape.value()); + TVM_FFI_ICHECK_NOTNULL(old_shape_ty); + ffi::Array old_shape_values = old_shape_ty->values.value(); + ffi::Array new_shape_values = new_shape_ty->values.value(); int i = old_shape_values.size(); int j = new_shape_values.size(); PrimExpr old_shape_product = 1, new_shape_product = 1; @@ -340,9 +338,9 @@ void BuildAxisGraphReshape(const Var& output_var, const Call& call, } inline int GetNumOutput(Call call) { - StructInfo sinfo = call->sinfo_args[0]; - if (const auto* tuple_sinfo = sinfo.as()) { - return tuple_sinfo->fields.size(); + Type output_ty = call->ty_args[0]; + if (const auto* tuple_ty = output_ty.as()) { + return tuple_ty->fields.size(); } else { return 1; } diff --git a/src/relax/distributed/transform/legalize_redistribute.cc b/src/relax/distributed/transform/legalize_redistribute.cc index 510927883752..b1b81c401e9a 100644 --- a/src/relax/distributed/transform/legalize_redistribute.cc +++ b/src/relax/distributed/transform/legalize_redistribute.cc @@ -68,16 +68,16 @@ class RedistributeLegalizer : public ExprMutator { if (call->op.same_as(redistribute_op)) { const auto* attrs = call->attrs.as(); TVM_FFI_ICHECK(attrs); - const auto* input_sinfo = call->args[0]->struct_info_.as(); - TVM_FFI_ICHECK(input_sinfo); + const auto* input_ty = call->args[0]->ty.as(); + TVM_FFI_ICHECK(input_ty); // As the first step, we only support redistribute in the same device mesh, // and the device mesh must be 1d // todo: extend the ccl ops so that it can support 2d device mesh, and different sharding // dimension - TVM_FFI_ICHECK(ffi::StructuralEqual()(input_sinfo->device_mesh, attrs->device_mesh)); - TVM_FFI_ICHECK(input_sinfo->device_mesh->shape.size() == 1); + TVM_FFI_ICHECK(ffi::StructuralEqual()(input_ty->device_mesh, attrs->device_mesh)); + TVM_FFI_ICHECK(input_ty->device_mesh->shape.size() == 1); // only support "S[x]"-> "R" and "R" -> "S[x]" - PlacementSpec input_spec = input_sinfo->placement->dim_specs[0]; + PlacementSpec input_spec = input_ty->placement->dim_specs[0]; PlacementSpec output_spec = attrs->placement->dim_specs[0]; if (input_spec->kind == PlacementSpecKind::kReplica && output_spec->kind == PlacementSpecKind::kReplica) { diff --git a/src/relax/distributed/transform/lower_distir.cc b/src/relax/distributed/transform/lower_distir.cc index 222736f8d1f1..30eb2466d45f 100644 --- a/src/relax/distributed/transform/lower_distir.cc +++ b/src/relax/distributed/transform/lower_distir.cc @@ -75,54 +75,54 @@ class DistIRSharder : public ExprMutator { return ShapeExpr(new_tensor_shape_value); } - TensorStructInfo ShardDTensorSinfo(DTensorStructInfo orig_sinfo) { - TensorStructInfo tensor_sinfo = orig_sinfo->tensor_sinfo; - TVM_FFI_ICHECK(tensor_sinfo->shape); - const auto* orig_shape = tensor_sinfo->shape.as(); - auto new_tensor_sinfo = ffi::make_object(*tensor_sinfo.get()); - new_tensor_sinfo->shape = ShardShape(ffi::GetRef(orig_shape), - orig_sinfo->device_mesh, orig_sinfo->placement); - return TensorStructInfo(new_tensor_sinfo); + TensorType ShardDTensorType(DTensorType orig_ty) { + TensorType tensor_ty = orig_ty->tensor_ty; + TVM_FFI_ICHECK(tensor_ty->shape); + const auto* orig_shape = tensor_ty->shape.as(); + auto new_tensor_ty = ffi::make_object(*tensor_ty.get()); + new_tensor_ty->shape = + ShardShape(ffi::GetRef(orig_shape), orig_ty->device_mesh, orig_ty->placement); + return TensorType(new_tensor_ty); } - StructInfo ConvertSinfo(StructInfo orig_sinfo, bool shard_shape) { - if (const auto* dtensor_sinfo = orig_sinfo.as()) { + Type ConvertType(Type orig_ty, bool shard_shape) { + if (const auto* dtensor_ty = orig_ty.as()) { if (shard_shape) { - return ShardDTensorSinfo(ffi::GetRef(dtensor_sinfo)); + return ShardDTensorType(ffi::GetRef(dtensor_ty)); } else { - return dtensor_sinfo->tensor_sinfo; + return dtensor_ty->tensor_ty; } - } else if (const auto* tuple_sinfo = orig_sinfo.as()) { - ffi::Array new_fields; - for (const auto& field_sinfo : tuple_sinfo->fields) { - if (const auto* dtensor_sinfo = field_sinfo.as()) { + } else if (const auto* tuple_ty = orig_ty.as()) { + ffi::Array new_fields; + for (const auto& field_ty : tuple_ty->fields) { + if (const auto* dtensor_ty = field_ty.as()) { if (shard_shape) { - new_fields.push_back(ShardDTensorSinfo(ffi::GetRef(dtensor_sinfo))); + new_fields.push_back(ShardDTensorType(ffi::GetRef(dtensor_ty))); } else { - new_fields.push_back(dtensor_sinfo->tensor_sinfo); + new_fields.push_back(dtensor_ty->tensor_ty); } } else { - new_fields.push_back(field_sinfo); + new_fields.push_back(field_ty); } } - return TupleStructInfo(new_fields); + return TupleType(new_fields); } else { - return orig_sinfo; + return orig_ty; } } Expr ShardInputParamTensorAndConstant(Expr input) { - TVM_FFI_ICHECK(input->struct_info_); - StructInfo old_sinfo = GetStructInfo(input); - StructInfo new_sinfo = ConvertSinfo(old_sinfo, false); + TVM_FFI_ICHECK(input->ty.defined()); + Type old_ty = GetType(input); + Type new_ty = ConvertType(old_ty, false); if (const auto* var = input.as()) { - Var new_param(var->name_hint(), new_sinfo); + Var new_param(var->name_hint(), new_ty); return new_param; } else if (const auto* constant = input.as()) { - for (const auto& spec : Downcast(old_sinfo)->placement->dim_specs) { + for (const auto& spec : Downcast(old_ty)->placement->dim_specs) { TVM_FFI_ICHECK(spec->kind == PlacementSpecKind::kReplica); } - Constant new_constant(constant->data, new_sinfo); + Constant new_constant(constant->data, new_ty); return new_constant; } else { TVM_FFI_THROW(InternalError) << "Cannot shard tensor which is not Var or Constant: " << input; @@ -130,10 +130,10 @@ class DistIRSharder : public ExprMutator { } } - void EmitBroadcastOrScatter(Expr old_expr, Expr new_expr, DTensorStructInfo dtensor_sinfo) { + void EmitBroadcastOrScatter(Expr old_expr, Expr new_expr, DTensorType dtensor_ty) { // FIXME: this is a hack that only works for 1d device mesh - TVM_FFI_ICHECK(dtensor_sinfo->device_mesh->shape.size() == 1); - PlacementSpec sharding_spec = dtensor_sinfo->placement->dim_specs[0]; + TVM_FFI_ICHECK(dtensor_ty->device_mesh->shape.size() == 1); + PlacementSpec sharding_spec = dtensor_ty->placement->dim_specs[0]; if (sharding_spec->kind == PlacementSpecKind::kReplica) { Var new_var = builder_->Emit(broadcast_from_worker0(new_expr)); if (const auto* var = old_expr.as()) { @@ -142,8 +142,8 @@ class DistIRSharder : public ExprMutator { tuple_getitem_remap_[Downcast(old_expr)] = new_var; } } else if (sharding_spec->kind == PlacementSpecKind::kSharding) { - Var scatter_var = builder_->Emit(scatter_from_worker0( - new_expr, dtensor_sinfo->device_mesh->shape[0], sharding_spec->axis)); + Var scatter_var = builder_->Emit( + scatter_from_worker0(new_expr, dtensor_ty->device_mesh->shape[0], sharding_spec->axis)); if (const auto* var = old_expr.as()) { var_remap_[var->vid] = scatter_var; } else { @@ -157,14 +157,13 @@ class DistIRSharder : public ExprMutator { void InputPreprocessing() { for (int i = 0; i < static_cast(func_->params.size()); i++) { Var param = func_->params[i]; - if (const auto* dtensor_sinfo = GetStructInfoAs(param)) { - EmitBroadcastOrScatter(param, new_params_[i], - ffi::GetRef(dtensor_sinfo)); - } else if (const auto* tuple_sinfo = GetStructInfoAs(param)) { - for (int j = 0; j < static_cast(tuple_sinfo->fields.size()); j++) { - if (const auto* dtensor_sinfo = tuple_sinfo->fields[j].as()) { + if (const auto* dtensor_ty = GetTypeAs(param)) { + EmitBroadcastOrScatter(param, new_params_[i], ffi::GetRef(dtensor_ty)); + } else if (const auto* tuple_ty = GetTypeAs(param)) { + for (int j = 0; j < static_cast(tuple_ty->fields.size()); j++) { + if (const auto* dtensor_ty = tuple_ty->fields[j].as()) { EmitBroadcastOrScatter(TupleGetItem(param, j), TupleGetItem(new_params_[i], j), - ffi::GetRef(dtensor_sinfo)); + ffi::GetRef(dtensor_ty)); } } } @@ -217,16 +216,16 @@ class DistIRSharder : public ExprMutator { static Op call_tir_local_view_op = Op::Get("relax.dist.call_tir_local_view"); if (call->op.same_as(reshape_op)) { TVM_FFI_ICHECK(call->args[1].as()); - const auto* out_sinfo = GetStructInfoAs(binding_var); - TVM_FFI_ICHECK(out_sinfo); + const auto* out_ty = GetTypeAs(binding_var); + TVM_FFI_ICHECK(out_ty); auto new_call_node = ffi::make_object(*call); - new_call_node->args.Set(1, ShardShape(Downcast(call->args[1]), - out_sinfo->device_mesh, out_sinfo->placement)); + new_call_node->args.Set(1, ShardShape(Downcast(call->args[1]), out_ty->device_mesh, + out_ty->placement)); return Call(new_call_node); } else if (call->op.same_as(call_tir_local_view_op)) { auto new_call_node = ffi::make_object(*call); new_call_node->op = call_tir_op; - new_call_node->sinfo_args = {ConvertSinfo(GetStructInfo(binding_var), true)}; + new_call_node->ty_args = {ConvertType(GetType(binding_var), true)}; return Call(new_call_node); } else if (call->op.same_as(call_tir_op)) { TVM_FFI_THROW(InternalError) @@ -238,11 +237,11 @@ class DistIRSharder : public ExprMutator { } else if (extern_func->global_symbol == "vm.builtin.distributed.attention_kv_cache_view") { new_call_node->op = ExternFunc("vm.builtin.attention_kv_cache_view"); auto orig_shape = Downcast(call->args[1]); - const auto* out_sinfo = GetStructInfoAs(binding_var); - TVM_FFI_ICHECK(out_sinfo); - ShapeExpr new_shape = ShardShape(orig_shape, out_sinfo->device_mesh, out_sinfo->placement); + const auto* out_ty = GetTypeAs(binding_var); + TVM_FFI_ICHECK(out_ty); + ShapeExpr new_shape = ShardShape(orig_shape, out_ty->device_mesh, out_ty->placement); new_call_node->args.Set(1, new_shape); - new_call_node->sinfo_args = {TensorStructInfo(new_shape, out_sinfo->tensor_sinfo->dtype)}; + new_call_node->ty_args = {TensorType(new_shape, out_ty->tensor_ty->dtype)}; } return Call(new_call_node); } diff --git a/src/relax/distributed/transform/lower_global_view_to_local_view.cc b/src/relax/distributed/transform/lower_global_view_to_local_view.cc index 6984b00d8101..4442f5c82a42 100644 --- a/src/relax/distributed/transform/lower_global_view_to_local_view.cc +++ b/src/relax/distributed/transform/lower_global_view_to_local_view.cc @@ -378,18 +378,18 @@ class LowerTIRToLocalView : public ExprMutator { } private: - inline ffi::Array ExtractDTensorStructInfo(Var var) { - if (const auto* dtensor_sinfo = GetStructInfoAs(var)) { - return {ffi::GetRef(dtensor_sinfo)}; - } else if (const auto* tuple_sinfo = GetStructInfoAs(var)) { - ffi::Array ret; - for (const auto& field : tuple_sinfo->fields) { - ret.push_back(Downcast(field)); + inline ffi::Array ExtractDTensorType(Var var) { + if (const auto* dtensor_ty = GetTypeAs(var)) { + return {ffi::GetRef(dtensor_ty)}; + } else if (const auto* tuple_ty = GetTypeAs(var)) { + ffi::Array ret; + for (const auto& field : tuple_ty->fields) { + ret.push_back(Downcast(field)); } return ret; } else { TVM_FFI_THROW(InternalError) - << "The output of a call_tir should be a DTensorStructInfo or TupleStructInfo"; + << "The output of a call_tir should be a DTensorType or TupleType"; } } @@ -402,14 +402,14 @@ class LowerTIRToLocalView : public ExprMutator { std::vector sharding_specs; ffi::Array args = Downcast(val->args[1])->fields; for (const auto& arg : args) { - const auto* sinfo = GetStructInfoAs(arg); - TVM_FFI_ICHECK(sinfo); - sharding_specs.push_back(ShardingSpec(sinfo->device_mesh, sinfo->placement)); + const auto* ty = GetTypeAs(arg); + TVM_FFI_ICHECK(ty); + sharding_specs.push_back(ShardingSpec(ty->device_mesh, ty->placement)); } Var output_var = binding->var; - ffi::Array output_sinfos = ExtractDTensorStructInfo(output_var); - for (const auto& sinfo : output_sinfos) { - sharding_specs.push_back(ShardingSpec(sinfo->device_mesh, sinfo->placement)); + ffi::Array output_tys = ExtractDTensorType(output_var); + for (const auto& ty : output_tys) { + sharding_specs.push_back(ShardingSpec(ty->device_mesh, ty->placement)); } GlobalVar gvar = Downcast(val->args[0]); tirx::PrimFunc prim_func = MatchPrimFunc(builder_->GetContextIRModule(), gvar).value(); diff --git a/src/relax/distributed/transform/propagate_sharding.cc b/src/relax/distributed/transform/propagate_sharding.cc index e9608b7379d1..3f2b21c66588 100644 --- a/src/relax/distributed/transform/propagate_sharding.cc +++ b/src/relax/distributed/transform/propagate_sharding.cc @@ -125,7 +125,7 @@ void CollectAxisGraphForDeviceMesh(const VarBindingNode* binding, const CallNode args = call->args; } for (const auto& arg : args) { - if (arg->struct_info_.as()) { + if (arg->ty.as()) { tensor_list.push_back(arg); } } @@ -171,12 +171,12 @@ class AxisGroupGraphBuilder : public ExprVisitor { void VisitBinding_(const VarBindingNode* binding, const TupleGetItemNode* val) { axis_group_graph_->JoinAxis(Axis(val->tuple.get(), -1, val->index), {binding->var.get(), -1}, distributed::AxisGroupGraph::EdgeType::kDescend); - const auto* tensor_sinfo = GetStructInfoAs(binding->var); - if (!tensor_sinfo) { + const auto* tensor_ty = GetTypeAs(binding->var); + if (!tensor_ty) { ExprVisitor::VisitBinding_(binding, val); return; } - int ndim = tensor_sinfo->ndim; + int ndim = tensor_ty->ndim; for (int i = 0; i < ndim; i++) { axis_group_graph_->JoinAxis(Axis(val->tuple.get(), i, val->index), {binding->var.get(), i}, distributed::AxisGroupGraph::EdgeType::kDescend); @@ -185,20 +185,20 @@ class AxisGroupGraphBuilder : public ExprVisitor { } void VisitBinding_(const VarBindingNode* binding, const VarNode* val) { - ffi::Array tensor_sinfos; - if (const auto* tensor_sinfo = binding->var->struct_info_.as()) { - tensor_sinfos.push_back(ffi::GetRef(tensor_sinfo)); - } else if (const auto* tuple_sinfo = binding->var->struct_info_.as()) { - TVM_FFI_ICHECK(tuple_sinfo); - for (const auto& sinfo : tuple_sinfo->fields) { - tensor_sinfos.push_back(Downcast(sinfo)); + ffi::Array tensor_tys; + if (const auto* tensor_ty = binding->var->ty.as()) { + tensor_tys.push_back(ffi::GetRef(tensor_ty)); + } else if (const auto* tuple_ty = binding->var->ty.as()) { + TVM_FFI_ICHECK(tuple_ty); + for (const auto& field_ty : tuple_ty->fields) { + tensor_tys.push_back(Downcast(field_ty)); } } else { ExprVisitor::VisitBinding_(binding, val); return; } - for (int idx = 0; idx < static_cast(tensor_sinfos.size()); idx++) { - int ndim = tensor_sinfos[idx]->ndim; + for (int idx = 0; idx < static_cast(tensor_tys.size()); idx++) { + int ndim = tensor_tys[idx]->ndim; for (int i = -1; i < ndim; i++) { axis_group_graph_->JoinAxis({val, i, idx}, {binding->var.get(), i, idx}, distributed::AxisGroupGraph::EdgeType::kDescend); @@ -255,7 +255,7 @@ class ShardingConflictHandler : public ExprVisitor { ShardingConflictHandler handler(axis_group_graph); handler.VisitExpr(function); for (const Var& var : function->params) { - if (GetStructInfoAs(var)) { + if (GetTypeAs(var)) { handler.CheckTensorShardingCompatible(var); } } @@ -267,11 +267,11 @@ class ShardingConflictHandler : public ExprVisitor { : axis_group_graph_(axis_group_graph) {} void CheckTensorShardingCompatible(Var var) { - const auto* sinfo = GetStructInfoAs(var); - TVM_FFI_ICHECK(sinfo); - const auto* shape = sinfo->shape.as(); + const auto* tensor_ty = GetTypeAs(var); + TVM_FFI_ICHECK(tensor_ty); + const auto* shape = tensor_ty->shape.as(); TVM_FFI_ICHECK(shape); - int ndim = sinfo->ndim; + int ndim = tensor_ty->ndim; std::unordered_set sharded_mesh_dim; ffi::Optional device_mesh; for (int i = -1; i < ndim; i++) { @@ -308,8 +308,8 @@ class ShardingConflictHandler : public ExprVisitor { } void CheckConstantNoSharding(Constant constant) { - const auto* sinfo = GetStructInfoAs(constant); - for (int i = 0; i < sinfo->ndim; i++) { + const auto* tensor_ty = GetTypeAs(constant); + for (int i = 0; i < tensor_ty->ndim; i++) { AxisShardingSpec sharding_spec; int has_sharding_spec; std::tie(sharding_spec, has_sharding_spec) = @@ -330,7 +330,7 @@ class ShardingConflictHandler : public ExprVisitor { } void VisitBinding_(const VarBindingNode* binding) final { - if (GetStructInfoAs(binding->var)) { + if (GetTypeAs(binding->var)) { CheckTensorShardingCompatible(binding->var); } ExprVisitor::VisitBinding_(binding); @@ -362,9 +362,8 @@ class DistributedIRBuilder : public ExprMutator { private: using ExprMutator::VisitExpr_; - DTensorStructInfo ConvertToDTensorStructInfo(TensorStructInfo sinfo, Expr expr, - int tuple_idx = 0) { - int ndim = sinfo->ndim; + DTensorType ConvertToDTensorType(TensorType tensor_ty, Expr expr, int tuple_idx = 0) { + int ndim = tensor_ty->ndim; DeviceMesh device_mesh = std::get<0>(axis_group_graph_.GetAxisShardingSpec({expr.get(), -1, tuple_idx})).first; TVM_FFI_ICHECK(device_mesh.defined()) @@ -381,32 +380,31 @@ class DistributedIRBuilder : public ExprMutator { placement_specs.Set(sharding_dim, PlacementSpec::Sharding(i)); } } - return DTensorStructInfo(sinfo, device_mesh, Placement(placement_specs)); + return DTensorType(tensor_ty, device_mesh, Placement(placement_specs)); } Expr RewriteInputTensorAndConstant(Expr tensor) { - StructInfo new_sinfo; - if (tensor->struct_info_.as()) { - new_sinfo = - ConvertToDTensorStructInfo(Downcast(tensor->struct_info_), tensor); - } else if (const auto* tuple = tensor->struct_info_.as()) { - ffi::Array tuple_sinfo_fields; + Type new_ty; + if (tensor->ty.as()) { + new_ty = ConvertToDTensorType(Downcast(tensor->ty), tensor); + } else if (const auto* tuple = tensor->ty.as()) { + ffi::Array tuple_ty_fields; for (int i = 0; i < static_cast(tuple->fields.size()); i++) { - if (tuple->fields[i].as()) { - tuple_sinfo_fields.push_back( - ConvertToDTensorStructInfo(Downcast(tuple->fields[i]), tensor, i)); + if (tuple->fields[i].as()) { + tuple_ty_fields.push_back( + ConvertToDTensorType(Downcast(tuple->fields[i]), tensor, i)); } else { - tuple_sinfo_fields.push_back(tuple->fields[i]); + tuple_ty_fields.push_back(tuple->fields[i]); } } - new_sinfo = TupleStructInfo(tuple_sinfo_fields); + new_ty = TupleType(tuple_ty_fields); } if (const auto* var = tensor.as()) { - Var new_param(var->name_hint(), new_sinfo); + Var new_param(var->name_hint(), new_ty); return new_param; } else if (const auto* constant = tensor.as()) { - Constant new_constant(constant->data, new_sinfo); + Constant new_constant(constant->data, new_ty); return new_constant; } else { TVM_FFI_THROW(InternalError) << "Cannot rewrite tensor which is not a Var or Constant"; @@ -424,7 +422,7 @@ class DistributedIRBuilder : public ExprMutator { // Step 4. Rewrite Function ffi::Array new_params; for (const Var& var : func->params) { - if (GetStructInfoAs(var) || GetStructInfoAs(var)) { + if (GetTypeAs(var) || GetTypeAs(var)) { Var new_param = Downcast(RewriteInputTensorAndConstant(var)); input_tensor_remap_.Set(var, new_param); new_params.push_back(new_param); @@ -455,10 +453,10 @@ class DistributedIRBuilder : public ExprMutator { ffi::ObjectPtr n = ffi::make_object(*new_call.get()); if (new_call->op.same_as(call_tir_op)) { - // do not infer output sinfo when arg size is 0 + // do not infer output type when arg size is 0 if (!args.empty()) { n->args.Set(1, Tuple(args)); - n->sinfo_args = {InferShardingSpec(Call(n), this->builder_, new_call->sinfo_args[0], f)}; + n->ty_args = {InferShardingSpec(Call(n), this->builder_, new_call->ty_args[0], f)}; } } else { n->args = args; @@ -487,64 +485,64 @@ class DistributedIRBuilder : public ExprMutator { return redistribute(expr, device_mesh, placement); } - Call RewriteOutSinfo(Call call, DeviceMesh device_mesh, ffi::Array placements) { - // in cases when infer fails (like arg size is 0), we use propagated sinfo for output + Call RewriteOutType(Call call, DeviceMesh device_mesh, ffi::Array placements) { + // In cases when inference fails (for example, arg size is 0), use the propagated output type. Call new_call = call; static Op call_tir_op = Op::Get("relax.call_tir"); if (const auto* extern_func = call->op.as()) { if (extern_func->global_symbol == "vm.builtin.distributed.attention_kv_cache_view") { ffi::ObjectPtr new_call_node = ffi::make_object(*call.get()); - StructInfo new_dtensor_sinfo = DTensorStructInfo( - Downcast(call->sinfo_args[0]), device_mesh, placements[0]); - new_call_node->sinfo_args = {new_dtensor_sinfo}; + Type new_dtensor_ty = + DTensorType(Downcast(call->ty_args[0]), device_mesh, placements[0]); + new_call_node->ty_args = {new_dtensor_ty}; new_call = Call(new_call_node); - new_call->struct_info_ = new_dtensor_sinfo; + new_call->ty = new_dtensor_ty; } } else if (call->op.same_as(call_tir_op)) { - TVM_FFI_ICHECK(call->sinfo_args.size() == 1); - if (!SinfoCompatibleWithDistIR(call->sinfo_args)) { + TVM_FFI_ICHECK(call->ty_args.size() == 1); + if (!TypeCompatibleWithDistIR(call->ty_args)) { ffi::ObjectPtr new_call_node = ffi::make_object(*call.get()); if (placements.size() == 1) { - new_call_node->sinfo_args = {DTensorStructInfo( - Downcast(call->sinfo_args[0]), device_mesh, placements[0])}; + new_call_node->ty_args = { + DTensorType(Downcast(call->ty_args[0]), device_mesh, placements[0])}; } else { - const auto* tuple_sinfo = call->sinfo_args[0].as(); - TVM_FFI_ICHECK(placements.size() == tuple_sinfo->fields.size()); - ffi::Array new_tuple_sinfo_fields; + const auto* tuple_ty = call->ty_args[0].as(); + TVM_FFI_ICHECK(placements.size() == tuple_ty->fields.size()); + ffi::Array new_tuple_ty_fields; for (int i = 0; i < static_cast(placements.size()); i++) { - new_tuple_sinfo_fields.push_back(DTensorStructInfo( - Downcast(tuple_sinfo->fields[i]), device_mesh, placements[i])); + new_tuple_ty_fields.push_back( + DTensorType(Downcast(tuple_ty->fields[i]), device_mesh, placements[i])); } - new_call_node->sinfo_args = {TupleStructInfo(new_tuple_sinfo_fields)}; + new_call_node->ty_args = {TupleType(new_tuple_ty_fields)}; } new_call = Call(new_call_node); - new_call->struct_info_ = new_call_node->sinfo_args[0]; + new_call->ty = new_call_node->ty_args[0]; } } return new_call; } void VisitBinding_(const VarBindingNode* binding, const CallNode* val) { - ffi::Array orig_output_tensor_sinfos; - if (const auto* tensor_sinfo = GetStructInfoAs(binding->var)) { - orig_output_tensor_sinfos.push_back(ffi::GetRef(tensor_sinfo)); - } else if (const auto* tuple_sinfo = GetStructInfoAs(binding->var)) { - for (const auto& sinfo : tuple_sinfo->fields) { - orig_output_tensor_sinfos.push_back(Downcast(sinfo)); + ffi::Array orig_output_tys; + if (const auto* tensor_ty = GetTypeAs(binding->var)) { + orig_output_tys.push_back(ffi::GetRef(tensor_ty)); + } else if (const auto* tuple_ty = GetTypeAs(binding->var)) { + for (const auto& field_ty : tuple_ty->fields) { + orig_output_tys.push_back(Downcast(field_ty)); } } else { ExprMutator::VisitBinding_(binding, val); return; } - // get annotated sinfo from axis group graph + // Get the annotated output type from the axis group graph. DeviceMesh device_mesh = std::get<0>(axis_group_graph_.GetAxisShardingSpec({binding->var.get(), -1})).first; TVM_FFI_ICHECK(device_mesh.defined()); ffi::Array placements; // every tuple element has a placement - for (int idx = 0; idx < static_cast(orig_output_tensor_sinfos.size()); idx++) { + for (int idx = 0; idx < static_cast(orig_output_tys.size()); idx++) { ffi::Array placement_specs( std::vector(device_mesh->shape.size(), PlacementSpec::Replica())); - for (int i = 0; i < orig_output_tensor_sinfos[idx]->ndim; i++) { + for (int i = 0; i < orig_output_tys[idx]->ndim; i++) { AxisShardingSpec sharding_spec; bool has_sharding_spec; std::tie(sharding_spec, has_sharding_spec) = @@ -555,16 +553,16 @@ class DistributedIRBuilder : public ExprMutator { } placements.push_back(Placement(placement_specs)); } - // get inferred sinfo from struct info deduction + // get inferred output type from type deduction Call new_call = Downcast(this->VisitExpr(binding->value)); new_call = - Downcast(builder_->Normalize(RewriteOutSinfo(new_call, device_mesh, placements))); + Downcast(builder_->Normalize(RewriteOutType(new_call, device_mesh, placements))); - if (const auto* inferred_dtensor_sinfo = new_call->struct_info_.as()) { + if (const auto* inferred_dtensor_ty = new_call->ty.as()) { Expr new_value = RemoveAnnotateSharding(new_call); if (!ffi::StructuralEqual()( - DTensorStructInfo(inferred_dtensor_sinfo->tensor_sinfo, device_mesh, placements[0]), - new_call->struct_info_)) { + DTensorType(inferred_dtensor_ty->tensor_ty, device_mesh, placements[0]), + new_call->ty)) { new_value = InsertRedistribute(new_value, device_mesh, placements[0]); } if (const auto* var = new_value.as()) { @@ -573,16 +571,15 @@ class DistributedIRBuilder : public ExprMutator { ReEmitBinding(binding, builder_->Normalize(new_value)); } } else { - const auto* inferred_tuple_sinfo = new_call->struct_info_.as(); - TVM_FFI_ICHECK(inferred_tuple_sinfo) << new_call; + const auto* inferred_tuple_ty = new_call->ty.as(); + TVM_FFI_ICHECK(inferred_tuple_ty) << new_call; Var new_var = builder_->Emit(new_call); var_remap_[binding->var->vid] = new_var; - for (int i = 0; i < static_cast(inferred_tuple_sinfo->fields.size()); i++) { + for (int i = 0; i < static_cast(inferred_tuple_ty->fields.size()); i++) { if (!ffi::StructuralEqual()( - DTensorStructInfo( - Downcast(inferred_tuple_sinfo->fields[i])->tensor_sinfo, - device_mesh, placements[i]), - inferred_tuple_sinfo->fields[i])) { + DTensorType(Downcast(inferred_tuple_ty->fields[i])->tensor_ty, + device_mesh, placements[i]), + inferred_tuple_ty->fields[i])) { Var redistribute_var = builder_->Emit( InsertRedistribute(TupleGetItem(new_var, i), device_mesh, placements[i])); tuple_getitem_remap_[TupleGetItem(binding->var, i)] = redistribute_var; diff --git a/src/relax/distributed/transform/utils.cc b/src/relax/distributed/transform/utils.cc index fb86bc5cae4a..7368a8d03669 100644 --- a/src/relax/distributed/transform/utils.cc +++ b/src/relax/distributed/transform/utils.cc @@ -22,37 +22,37 @@ namespace tvm { namespace relax { namespace distributed { -bool SinfoCompatibleWithDistIR(ffi::Array sinfos) { +bool TypeCompatibleWithDistIR(ffi::Array tys) { bool compatible = true; - for (const auto& sinfo : sinfos) { - if (const auto* tuple_sinfo = sinfo.as()) { - compatible &= SinfoCompatibleWithDistIR(tuple_sinfo->fields); + for (const auto& ty : tys) { + if (const auto* tuple_ty = ty.as()) { + compatible &= TypeCompatibleWithDistIR(tuple_ty->fields); } else { - compatible &= !sinfo->IsInstance(); + compatible &= !ty->IsInstance(); } } return compatible; } -bool SinfoCompatibleWithRelax(ffi::Array sinfos) { +bool TypeCompatibleWithRelax(ffi::Array tys) { bool compatible = true; - for (const auto& sinfo : sinfos) { - if (const auto* tuple_sinfo = sinfo.as()) { - compatible &= SinfoCompatibleWithRelax(tuple_sinfo->fields); + for (const auto& ty : tys) { + if (const auto* tuple_ty = ty.as()) { + compatible &= TypeCompatibleWithRelax(tuple_ty->fields); } else { - compatible &= !sinfo->IsInstance(); + compatible &= !ty->IsInstance(); } } return compatible; } bool IsDistIRFunc(Function func) { - ffi::Array param_sinfos; + ffi::Array param_tys; for (const auto& param : func->params) { - TVM_FFI_ICHECK(param->struct_info_); - param_sinfos.push_back(Downcast(param->struct_info_.value())); + TVM_FFI_ICHECK(param->ty.defined()); + param_tys.push_back(Downcast(param->ty)); } - bool compatible_with_dist_ir = SinfoCompatibleWithDistIR(param_sinfos); - bool compatible_with_relax = SinfoCompatibleWithRelax(param_sinfos); + bool compatible_with_dist_ir = TypeCompatibleWithDistIR(param_tys); + bool compatible_with_relax = TypeCompatibleWithRelax(param_tys); if (compatible_with_relax) { return false; } else if (compatible_with_dist_ir && !compatible_with_relax) { diff --git a/src/relax/distributed/transform/utils.h b/src/relax/distributed/transform/utils.h index f97e17370275..e78dd427e929 100644 --- a/src/relax/distributed/transform/utils.h +++ b/src/relax/distributed/transform/utils.h @@ -24,7 +24,7 @@ #include #include #include -#include +#include #include namespace tvm { namespace relax { @@ -44,10 +44,10 @@ inline ffi::Optional MatchPrimFunc(const IRModule& mod_, const E return std::nullopt; } /*! - * \brief Check whether the given struct infos can appear in DistIR - * \return Whether the given struct infos can appear in DistIR + * \brief Check whether the given types can appear in DistIR + * \return Whether the given types can appear in DistIR */ -bool SinfoCompatibleWithDistIR(ffi::Array sinfos); +bool TypeCompatibleWithDistIR(ffi::Array tys); /*! * \brief Check whether the given function is a DistIR function diff --git a/src/relax/distributed/struct_info.cc b/src/relax/distributed/type.cc similarity index 84% rename from src/relax/distributed/struct_info.cc rename to src/relax/distributed/type.cc index 42ea8f721aec..895446ff29c3 100644 --- a/src/relax/distributed/struct_info.cc +++ b/src/relax/distributed/type.cc @@ -18,18 +18,18 @@ */ /*! - * \file src/relax/distributed/struct_info.cc - * \brief Relax dtensor struct info. + * \file src/relax/distributed/type.cc + * \brief Relax DTensor type. */ #include -#include +#include namespace tvm { namespace relax { namespace distributed { TVM_FFI_STATIC_INIT_BLOCK() { - DTensorStructInfoNode::RegisterReflection(); + DTensorTypeNode::RegisterReflection(); PlacementNode::RegisterReflection(); PlacementSpecNode::RegisterReflection(); } @@ -118,19 +118,19 @@ TVM_FFI_STATIC_INIT_BLOCK() { } // DTensor -DTensorStructInfo::DTensorStructInfo(TensorStructInfo tensor_sinfo, DeviceMesh device_mesh, - Placement placement, Span span) { +DTensorType::DTensorType(TensorType tensor_ty, DeviceMesh device_mesh, Placement placement, + Span span) { TVM_FFI_CHECK_EQ(device_mesh->shape.size(), placement->dim_specs.size(), ValueError) << "The device mesh and placement must have the same dimension size"; for (auto spec : placement->dim_specs) { if (spec->kind == PlacementSpecKind::kReplica) continue; - TVM_FFI_CHECK_LT(spec->axis, tensor_sinfo->ndim, ValueError) + TVM_FFI_CHECK_LT(spec->axis, tensor_ty->ndim, ValueError) << "Sharding dimension should be smaller than tensor ndim"; } - ffi::ObjectPtr n = ffi::make_object(); + ffi::ObjectPtr n = ffi::make_object(); n->device_mesh = std::move(device_mesh); n->placement = std::move(placement); - n->tensor_sinfo = std::move(tensor_sinfo); + n->tensor_ty = std::move(tensor_ty); n->span = span; data_ = std::move(n); } @@ -138,9 +138,9 @@ DTensorStructInfo::DTensorStructInfo(TensorStructInfo tensor_sinfo, DeviceMesh d TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def( - "relax.distributed.DTensorStructInfo", - [](TensorStructInfo tensor_sinfo, DeviceMesh device_mesh, Placement placement, Span span) { - return DTensorStructInfo(tensor_sinfo, device_mesh, placement, span); + "relax.distributed.DTensorType", + [](TensorType tensor_ty, DeviceMesh device_mesh, Placement placement, Span span) { + return DTensorType(tensor_ty, device_mesh, placement, span); }); } diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc index fe6887a97eef..aab0fdf8b453 100644 --- a/src/relax/ir/block_builder.cc +++ b/src/relax/ir/block_builder.cc @@ -29,10 +29,9 @@ #include #include #include -#include -#include #include #include +#include #include #include @@ -88,18 +87,18 @@ class BlockBuilderImpl : public BlockBuilderNode { } GlobalVar gvar(func_name); - StructInfo finfo; - if (func->struct_info_.defined()) { - finfo = GetStructInfo(func); + Type finfo; + if (func->ty.defined()) { + finfo = GetType(func); } else if (auto* prim_func = func.as()) { - // NOTE: use a slightly different struct info than checked type + // NOTE: use a slightly different type than checked type // in PrimFunc so handle can turn into Tensor. - // TODO(relax-team): add fine-grained PrimFunc struct info signature generation. - finfo = FuncStructInfo::OpaqueFunc(StructInfoFromType(prim_func->ret_type)); + // TODO(relax-team): add fine-grained PrimFunc type signature generation. + finfo = FuncType::OpaqueFunc(TypeFromStaticType(prim_func->ret_type)); } else { - TVM_FFI_THROW(RuntimeError) << "Expect struct_info field to be populated"; + TVM_FFI_THROW(RuntimeError) << "Expect ty field to be populated"; } - UpdateStructInfo(gvar, finfo); + UpdateType(gvar, finfo); context_mod_->Add(gvar, func); @@ -164,11 +163,11 @@ class BlockBuilderImpl : public BlockBuilderNode { void BeginScope(ffi::Optional> params) final { // The current implementation handles the collection of shape var - // defined in parameter struct info annotations. The implementation + // defined in parameter type annotations. The implementation // is correct (since we will simply erase all relax Vars in EraseToWellDefined), // but can be further improved. // - // TODO(relax-team): Add support for relax Var in struct info annotations. + // TODO(relax-team): Add support for relax Var in type annotations. scope_stack_.emplace_back(ScopeFrame()); if (params.defined()) { @@ -194,10 +193,10 @@ class BlockBuilderImpl : public BlockBuilderNode { auto& shape_var_map = CurrentScopeFrame()->shape_var_map; // The current implementation handles the collection of shape var - // defined in parameter struct info annotations. The implementation + // defined in parameter type annotations. The implementation // is correct (since we will simply erase all relax Vars in EraseToWellDefined), // but can be further improved. - ffi::Map var_map = StructInfoVarCollector::Collect(GetStructInfo(var)); + ffi::Map var_map = TypeVarCollector::Collect(GetType(var)); for (const auto& kv : var_map) { const tirx::Var& shape_var = kv.first; const PrimExpr& shape_expr = kv.second; @@ -235,21 +234,20 @@ class BlockBuilderImpl : public BlockBuilderNode { return this->Emit(expr, CurrentBindingBlockFrame()->is_dataflow, name_hint); } - Var EmitMatchCast(Expr value, StructInfo struct_info, ffi::String name_hint) final { + Var EmitMatchCast(Expr value, Type ty, ffi::String name_hint) final { value = this->Normalize(value); - TVM_FFI_ICHECK(StructInfoBaseCheck(GetStructInfo(value), struct_info) != - BaseCheckResult::kFailL0) - << "It is impossible to match cast any value into the target struct_info. " - "But got value struct info: " - << GetStructInfo(value) << ", given struct info: " << struct_info; + TVM_FFI_ICHECK(TypeBaseCheck(GetType(value), ty) != BaseCheckResult::kFailL0) + << "It is impossible to match cast any value into the target ty. " + "But got value type: " + << GetType(value) << ", given type: " << ty; // NOTE: do match cast checking later in a pass. BindingBlockFrame* cur_frame = CurrentBindingBlockFrame(); Var var = CreateVar(cur_frame->is_dataflow, name_hint); - UpdateStructInfo(var, struct_info); + UpdateType(var, ty); - MatchCast match_cast(var, value, struct_info); + MatchCast match_cast(var, value, ty); cur_frame->bindings.push_back(match_cast); // NOTE match shape do not follow simple binding rule // as a result should not appear in binding table. @@ -275,8 +273,8 @@ class BlockBuilderImpl : public BlockBuilderNode { << "Cannot emit dataflow var in non-dataflow block"; } // normalized check - TVM_FFI_ICHECK(var_binding->var->struct_info_.defined()); - TVM_FFI_ICHECK(var_binding->value->struct_info_.defined()); + TVM_FFI_ICHECK(var_binding->var->ty.defined()); + TVM_FFI_ICHECK(var_binding->value->ty.defined()); cur_frame->bindings.push_back(binding); binding_table_[var_binding->var->vid] = var_binding->value; } else if (const auto* match_cast = binding.as()) { @@ -285,8 +283,8 @@ class BlockBuilderImpl : public BlockBuilderNode { << "Cannot emit dataflow var in non-dataflow block"; } // normalized check - TVM_FFI_ICHECK(match_cast->var->struct_info_.defined()); - TVM_FFI_ICHECK(match_cast->value->struct_info_.defined()); + TVM_FFI_ICHECK(match_cast->var->ty.defined()); + TVM_FFI_ICHECK(match_cast->value->ty.defined()); // NOTE match shape do not follow simple binding rule // as a result should not appear in binding table. cur_frame->bindings.push_back(binding); @@ -390,7 +388,7 @@ class BlockBuilderImpl : public BlockBuilderNode { Var var = CreateVar(is_dataflow, name_hint); // set the values - UpdateStructInfo(var, Downcast(expr->struct_info_.value())); + UpdateType(var, Downcast(expr->ty)); CurrentBindingBlockFrame()->bindings.push_back(VarBinding(var, expr)); @@ -411,8 +409,8 @@ class BlockBuilderImpl : public BlockBuilderNode { name_hint = is_dataflow ? "lv" : "gv"; } Id vid = Id(GetUniqueName(name_hint)); - return is_dataflow ? DataflowVar(vid, /*struct_info_annotation=*/std::nullopt) - : Var(vid, /*struct_info_annotation=*/std::nullopt); + return is_dataflow ? DataflowVar(vid, /*ty_annotation=*/std::nullopt) + : Var(vid, /*ty_annotation=*/std::nullopt); } private: @@ -471,16 +469,16 @@ class BlockBuilderImpl : public BlockBuilderNode { // Collect all the variables that a parameter var can define. // The collector is used to making sure that we record the // shape vars as defined when calling BeginScope(params) - class StructInfoVarCollector : public StructInfoVisitor { + class TypeVarCollector : public TypeVisitor { public: - static ffi::Map Collect(const StructInfo& struct_info) { - StructInfoVarCollector collector; - collector(struct_info); + static ffi::Map Collect(const Type& ty) { + TypeVarCollector collector; + collector(ty); return collector.shape_var_map_; } private: - void VisitStructInfo_(const TensorStructInfoNode* op) final { + void VisitType_(const TensorTypeNode* op) final { if (const auto* shape_expr = op->shape.as()) { for (const PrimExpr& s : shape_expr->values) { // Only collect single var defined shape. Ignore something like `R.Tensor((m + 1, n + 1)) @@ -491,7 +489,7 @@ class BlockBuilderImpl : public BlockBuilderNode { } } - void VisitStructInfo_(const ShapeStructInfoNode* op) final { + void VisitType_(const ShapeTypeNode* op) final { for (const PrimExpr& s : op->values.value_or(ffi::Array())) { // Only collect single var defined shape. Ignore something like `R.Shape((m + 1, n + 1)) if (const auto* var = s.as()) { @@ -500,7 +498,7 @@ class BlockBuilderImpl : public BlockBuilderNode { } } - void VisitStructInfo_(const PrimStructInfoNode* op) final { + void VisitType_(const PrimTypeNode* op) final { // Only collect single var defined shape. Ignore something like `R.Prim(value=m + 1)` if (op->value.defined()) { if (auto var = op->value.as()) { @@ -520,14 +518,14 @@ class BlockBuilderImpl : public BlockBuilderNode { #define RELAX_EXPR_NORMALIZER_LEAF(OP) \ Expr VisitExpr_(const OP* op) final { return ffi::GetRef(op); } -// TODO(relax-team): Check normalize logic after struct info. +// TODO(relax-team): Check normalize logic after type. -// Normalizer on struct info: +// Normalizer on type: // // We take benefit of the following invariants(that are checked in constructor): -// - If an expr appears in StructInfo, then it is already normalized. -// As a result, we do not need to peek into StructInfo in Normalization. -// - Constant, ShapeExpr, already have their StructInfo populated in constructing time. +// - If an expr appears in Type, then it is already normalized. +// As a result, we do not need to peek into Type in Normalization. +// - Constant, ShapeExpr, already have their Type populated in constructing time. class Normalizer : public BlockBuilderImpl, private ExprFunctor { public: explicit Normalizer(IRModule context_mod) : BlockBuilderImpl(context_mod) {} @@ -539,11 +537,11 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctorVisitExpr(expr); // Invariant: // After Normalize: an Expr always have - // struct_info (with the exception of Op). + // ty (with the exception of Op). if (!normalized->IsInstance()) { - TVM_FFI_ICHECK(normalized->struct_info_.defined()) - << "The struct_info_ of an Expr except OpNode after " - "normalization must not be nullptr. However, this Expr does not have struct_info_: " + TVM_FFI_ICHECK(normalized->ty.defined()) + << "The ty of an Expr except OpNode after " + "normalization must not be nullptr. However, this Expr does not have ty: " << normalized; } @@ -592,16 +590,15 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor Expr VisitVar_(const typename T::ContainerType* var) { - // Parameters and free-vars must be present with struct info + // Parameters and free-vars must be present with type // Other vars must have already been normalized through binding - TVM_FFI_ICHECK(var->struct_info_.defined()) - << "Var " << var->name_hint() << " does not have struct info."; + TVM_FFI_ICHECK(var->ty.defined()) << "Var " << var->name_hint() << " does not have type."; return ffi::GetRef(var); } Expr VisitExpr_(const VarNode* var_ptr) final { auto var = VisitVar_(var_ptr); - if (HasVoidStructInfo(var)) { + if (HasVoidType(var)) { return VisitExpr(Tuple(ffi::Array{})); } else { return var; @@ -634,12 +631,12 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor(op) : Tuple(new_fields, op->span); // Update tuple fields. - if (!tuple->struct_info_.defined()) { - ffi::Array tuple_sinfo; + if (!tuple->ty.defined()) { + ffi::Array tuple_ty; for (Expr field : tuple->fields) { - tuple_sinfo.push_back(GetStructInfo(field)); + tuple_ty.push_back(GetType(field)); } - UpdateStructInfo(tuple, TupleStructInfo(tuple_sinfo, op->span)); + UpdateType(tuple, TupleType(tuple_ty, op->span)); } return tuple; } @@ -650,7 +647,7 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctorbody)) { return ffi::GetRef(op); } else { - return Function(op->params, new_body, op->ret_struct_info, op->is_pure, op->attrs); + return Function(op->params, new_body, op->ret_ty, op->is_pure, op->attrs); } } @@ -664,12 +661,12 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctorop) && new_args.same_as(op->args)) { call = ffi::GetRef(op); } else { - call = Call(new_op, new_args, op->attrs, op->sinfo_args); + call = Call(new_op, new_args, op->attrs, op->ty_args); } - if (!call->struct_info_.defined()) { - auto inferred_sinfo = InferStructInfo(call); - UpdateStructInfo(call, inferred_sinfo); + if (!call->ty.defined()) { + auto inferred_ty = InferType(call); + UpdateType(call, inferred_ty); } // If the operation has defined a custom normalization @@ -730,8 +727,8 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctorstruct_info_.defined()) { - UpdateStructInfo(seq_expr, EraseToWellDefinedInScope(GetStructInfo(seq_expr->body))); + if (!seq_expr->ty.defined()) { + UpdateType(seq_expr, EraseToWellDefinedInScope(GetType(seq_expr->body))); } return seq_expr; } @@ -748,10 +745,10 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctorspan); } - if (!if_node->struct_info_.defined()) { - auto true_info = EraseToWellDefinedInScope(GetStructInfo(new_true)); - auto false_info = EraseToWellDefinedInScope(GetStructInfo(new_false)); - UpdateStructInfo(if_node, StructInfoLCA(true_info, false_info)); + if (!if_node->ty.defined()) { + auto true_info = EraseToWellDefinedInScope(GetType(new_true)); + auto false_info = EraseToWellDefinedInScope(GetType(new_false)); + UpdateType(if_node, TypeLCA(true_info, false_info)); } return if_node; } @@ -762,12 +759,11 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctortuple) ? ffi::GetRef(op) : TupleGetItem(new_tuple, op->index); - if (!node->struct_info_.defined()) { - auto opt = MatchStructInfo(node->tuple); - TVM_FFI_ICHECK(opt) << "The struct info of Tuple must be TupleStructInfo, " - << "but expression " << node->tuple << " has struct info " - << node->tuple->struct_info_; - UpdateStructInfo(node, opt.value()->fields[node->index]); + if (!node->ty.defined()) { + auto opt = MatchType(node->tuple); + TVM_FFI_ICHECK(opt) << "The type of Tuple must be TupleType, " + << "but expression " << node->tuple << " has type " << node->tuple->ty; + UpdateType(node, opt.value()->fields[node->index]); } return node; @@ -788,8 +784,8 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctorvalue)) { binding = VarBinding(binding->var, new_value, binding->span); } - if (!binding->var->struct_info_.defined()) { - UpdateStructInfo(binding->var, GetStructInfo(new_value)); + if (!binding->var->ty.defined()) { + UpdateType(binding->var, GetType(new_value)); } return binding; } @@ -797,10 +793,10 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctorVisitExpr(binding->value); if (!new_value.same_as(binding->value)) { - binding = MatchCast(binding->var, new_value, binding->struct_info, binding->span); + binding = MatchCast(binding->var, new_value, binding->ty, binding->span); } - if (!binding->var->struct_info_.defined()) { - UpdateStructInfo(binding->var, binding->struct_info); + if (!binding->var->ty.defined()) { + UpdateType(binding->var, binding->ty); } return binding; } @@ -829,41 +825,41 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctorop.as()) { - // Case 1: the op field is a primitive op, look up FInferStructInfo attribute + // Case 1: the op field is a primitive op, look up FInferType attribute Op op = ffi::GetRef(op_ptr); bool is_dist_op = false; for (const auto& arg : call->args) { - if (arg->struct_info_.as()) { + if (arg->ty.as()) { is_dist_op = true; break; } } if (is_dist_op) { for (const auto& arg : call->args) { - TVM_FFI_ICHECK(!arg->struct_info_.as()) + TVM_FFI_ICHECK(!arg->ty.as()) << "Distributed operator must take DTensor instead of Tensor as input"; } - TVM_FFI_ICHECK(op_map_dist_infer_struct_info_.count(op)) - << " Cannot find the dist.FInferStructInfo attribute registered to op: " << op->name; - return op_map_dist_infer_struct_info_[op](call, ffi::GetRef(this)); + TVM_FFI_ICHECK(op_map_dist_infer_ty.count(op)) + << " Cannot find the dist.FInferType attribute registered to op: " << op->name; + return op_map_dist_infer_ty[op](call, ffi::GetRef(this)); } - TVM_FFI_ICHECK(op_map_infer_struct_info_.count(op)) - << " Cannot find the FInferStructInfo attribute registered to op: " << op->name; - return op_map_infer_struct_info_[op](call, ffi::GetRef(this)); + TVM_FFI_ICHECK(op_map_infer_ty.count(op)) + << " Cannot find the FInferType attribute registered to op: " << op->name; + return op_map_infer_ty[op](call, ffi::GetRef(this)); } else { // derive using function parameters - TVM_FFI_ICHECK(call->op->struct_info_.defined()); - auto opt = MatchStructInfo(call->op); - TVM_FFI_ICHECK(opt) << "Call->op must contains a function struct info"; - FuncStructInfo finfo = opt.value(); - return DeriveCallRetStructInfo(finfo, call, ffi::GetRef(this), analyzer_); + TVM_FFI_ICHECK(call->op->ty.defined()); + auto opt = MatchType(call->op); + TVM_FFI_ICHECK(opt) << "Call->op must contains a function type"; + FuncType finfo = opt.value(); + return DeriveCallRetType(finfo, call, ffi::GetRef(this), analyzer_); } } // erase to well defined within current scope. - StructInfo EraseToWellDefinedInScope(StructInfo info) { + Type EraseToWellDefinedInScope(Type info) { if (scope_stack_.empty()) { // If no scopes are active, then this fragment does not require // any normalization. @@ -905,7 +901,7 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctorbody))); + UpdateType(seq, EraseToWellDefinedInScope(GetType(seq->body))); ret = seq; } @@ -982,7 +978,7 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor()) { current.push_back(VarBinding(var_binding->var, seq->body)); } else if (const auto* match_cast = binding.as()) { - current.push_back(MatchCast(match_cast->var, seq->body, match_cast->struct_info)); + current.push_back(MatchCast(match_cast->var, seq->body, match_cast->ty)); } else { TVM_FFI_THROW(InternalError) << "Unknown binding type: " << binding->GetTypeKey(); } @@ -1032,11 +1028,9 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor op_map_infer_struct_info_ = - Op::GetAttrMap("FInferStructInfo"); - tvm::OpAttrMap op_map_dist_infer_struct_info_ = - Op::GetAttrMap("dist.FInferStructInfo"); + /*! \brief Operator type inference map. */ + tvm::OpAttrMap op_map_infer_ty = Op::GetAttrMap("FInferType"); + tvm::OpAttrMap op_map_dist_infer_ty = Op::GetAttrMap("dist.FInferType"); /*! \brief Operator normalization function */ tvm::OpAttrMap op_map_normalize_ = Op::GetAttrMap("FNormalize"); @@ -1075,8 +1069,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { return builder->Emit(expr, name_hint); }) .def("relax.BlockBuilderEmitMatchCast", - [](BlockBuilder builder, Expr value, StructInfo struct_info, ffi::String name_hint) { - return builder->EmitMatchCast(value, struct_info, name_hint); + [](BlockBuilder builder, Expr value, Type ty, ffi::String name_hint) { + return builder->EmitMatchCast(value, ty, name_hint); }) .def("relax.BlockBuilderEmitOutput", [](BlockBuilder builder, const Expr& output, ffi::String name_hint) { diff --git a/src/relax/ir/dataflow_block_rewriter.cc b/src/relax/ir/dataflow_block_rewriter.cc index 7344d05ec7e8..1d68fa7adc4e 100644 --- a/src/relax/ir/dataflow_block_rewriter.cc +++ b/src/relax/ir/dataflow_block_rewriter.cc @@ -31,7 +31,7 @@ #include #include #include -#include +#include #include #include diff --git a/src/relax/ir/dataflow_expr_rewriter.cc b/src/relax/ir/dataflow_expr_rewriter.cc index 625ae1e76416..cb59f566e613 100644 --- a/src/relax/ir/dataflow_expr_rewriter.cc +++ b/src/relax/ir/dataflow_expr_rewriter.cc @@ -31,7 +31,7 @@ #include #include #include -#include +#include #include #include @@ -697,19 +697,19 @@ PatternMatchingRewriter PatternMatchingRewriter::FromModule(IRModule mod) { } } - auto sinfo_pattern = GetStructInfo(func_pattern); - auto sinfo_replacement = GetStructInfo(func_replacement); - TVM_FFI_CHECK(ffi::StructuralEqual()(sinfo_pattern, sinfo_replacement), ValueError) + auto ty_pattern = GetType(func_pattern); + auto ty_replacement = GetType(func_replacement); + TVM_FFI_CHECK(ffi::StructuralEqual()(ty_pattern, ty_replacement), ValueError) << "The pattern and replacement must have the same signature, " - << "but the pattern has struct info " << sinfo_pattern - << ", while the replacement has struct info " << sinfo_replacement; + << "but the pattern has type " << ty_pattern << ", while the replacement has type " + << ty_replacement; ffi::Array param_wildcards; ffi::Map pattern_lookup; for (const auto& param : func_pattern->params) { WildcardPattern wildcard; param_wildcards.push_back(wildcard); - pattern_lookup.Set(param, StructInfoPattern(wildcard, GetStructInfo(param))); + pattern_lookup.Set(param, TypePattern(wildcard, GetType(param))); } std::function make_pattern = [&](Expr expr) -> DFPattern { @@ -736,7 +736,7 @@ PatternMatchingRewriter PatternMatchingRewriter::FromModule(IRModule mod) { return ExternFuncPattern(func->global_symbol); } else if (auto prim = expr.as()) { - return StructInfoPattern(WildcardPattern(), PrimStructInfo(prim->value)); + return TypePattern(WildcardPattern(), PrimType(prim->value)); } else { TVM_FFI_THROW(TypeError) << "Cannot convert Relax expression of type " << expr->GetTypeKey() @@ -748,7 +748,7 @@ PatternMatchingRewriter PatternMatchingRewriter::FromModule(IRModule mod) { for (const auto& binding : block->bindings) { auto value_pattern = make_pattern(GetBoundValue(binding)); if (auto match_cast = binding.as()) { - value_pattern = StructInfoPattern(value_pattern, match_cast->struct_info); + value_pattern = TypePattern(value_pattern, match_cast->ty); } pattern_lookup.Set(binding->var, value_pattern); } @@ -772,10 +772,10 @@ PatternMatchingRewriter PatternMatchingRewriter::FromModule(IRModule mod) { // Introduce an intermediate variable, to ensure that the // MatchCast's target will be a Var, even for expressions that // wouldn't normally be normalized into a variable. - Var intermediate_var("intermediate_var", GetStructInfo(matched_expr)); + Var intermediate_var("intermediate_var", GetType(matched_expr)); wildcard_bindings.push_back(VarBinding(intermediate_var, matched_expr)); wildcard_bindings.push_back( - MatchCast(func_replacement->params[i], intermediate_var, GetStructInfo(matched_expr))); + MatchCast(func_replacement->params[i], intermediate_var, GetType(matched_expr))); } new_blocks.push_back(DataflowBlock(wildcard_bindings)); @@ -874,7 +874,7 @@ class PatternMatchingMutator : public ExprMutator { // simplifies the special handling of the SeqExpr's body. ffi::Optional dummy_output_var = std::nullopt; if (!seq->body->IsInstance()) { - dummy_output_var = Var("dummy_output_var", GetStructInfo(seq->body)); + dummy_output_var = Var("dummy_output_var", GetType(seq->body)); VarBinding dummy_binding(dummy_output_var.value(), seq->body); auto last_block = [&]() { @@ -902,7 +902,7 @@ class PatternMatchingMutator : public ExprMutator { auto bindings = orig_bindings.Map([&](Binding binding) -> Binding { if (auto new_expr = rewrites.variable_rewrites.Get(binding->var)) { if (auto match_cast = binding.as()) { - return MatchCast(binding->var, new_expr.value(), match_cast->struct_info); + return MatchCast(binding->var, new_expr.value(), match_cast->ty); } else { return VarBinding(binding->var, new_expr.value()); } @@ -996,7 +996,7 @@ class PatternMatchingMutator : public ExprMutator { if (binding.as()) { builder_->EmitNormalized(VarBinding(binding->var, value)); } else if (auto match_cast = binding.as()) { - builder_->EmitNormalized(MatchCast(binding->var, value, match_cast->struct_info)); + builder_->EmitNormalized(MatchCast(binding->var, value, match_cast->ty)); } else { TVM_FFI_THROW(InternalError) << "Binding must be either VarBinding or MatchCast"; } diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc index d44e6ae42ce1..08689bd10f0b 100644 --- a/src/relax/ir/dataflow_matcher.cc +++ b/src/relax/ir/dataflow_matcher.cc @@ -32,7 +32,7 @@ #include #include #include -#include +#include #include #include @@ -428,15 +428,15 @@ bool DFPatternMatcher::VisitDFPattern_(const UnorderedTuplePatternNode* op, cons return false; } -bool DFPatternMatcher::VisitDFPattern_(const StructInfoPatternNode* op, const Expr& expr0) { +bool DFPatternMatcher::VisitDFPattern_(const TypePatternNode* op, const Expr& expr0) { if (!VisitDFPattern(op->pattern, expr0)) { return false; } auto expr = UnwrapBindings(expr0, var2val_); - auto expr_struct_info = GetStructInfo(expr); + auto expr_ty = GetType(expr); - PrimExpr new_constraint = StructInfoBaseCheckPrecondition(op->struct_info, expr_struct_info); + PrimExpr new_constraint = TypeBaseCheckPrecondition(op->ty, expr_ty); if (auto* as_int = new_constraint.as()) { return as_int->value; } @@ -490,7 +490,7 @@ static bool ShapeEqual(AnalyzerObj* analyzer, const ffi::Array& lhs, bool DFPatternMatcher::VisitDFPattern_(const ShapePatternNode* op, const Expr& expr) { // no need to jump, as var.shape == value.shape - if (const auto* tinfo = GetStructInfoAs(expr)) { + if (const auto* tinfo = GetTypeAs(expr)) { if (const ShapeExprNode* shape_expr = tinfo->shape.as()) { return ShapeEqual(analyzer_.get(), op->shape, shape_expr->values) && VisitDFPattern(op->pattern, expr); @@ -511,10 +511,10 @@ std::tuple SameShapeConstraintNode::AsPrimExpr( if (auto opt_var = match_state(arg.get())) { auto var = opt_var.value(); auto opt_var_shape = [&]() -> ffi::Optional> { - auto sinfo = GetStructInfo(var); - if (auto tensor = sinfo.as()) { + auto ty = GetType(var); + if (auto tensor = ty.as()) { return tensor->GetShape(); - } else if (auto shape_expr = sinfo.as()) { + } else if (auto shape_expr = ty.as()) { return shape_expr->values; } else { return std::nullopt; @@ -571,9 +571,9 @@ bool DFPatternMatcher::VisitDFPattern_(const PrimArrPatternNode* op, const Expr& bool DFPatternMatcher::VisitDFPattern_(const DataTypePatternNode* op, const Expr& expr) { // no need to jump, as var.dtype == value.dtype - auto expr_sinfo = expr.as()->struct_info_; - if (const TensorStructInfoNode* tensor_sinfo = expr_sinfo.as()) { - return (ffi::StructuralEqual()(op->dtype, tensor_sinfo->dtype)) && + auto expr_ty = expr.as()->ty; + if (const TensorTypeNode* tensor_ty = expr_ty.as()) { + return (ffi::StructuralEqual()(op->dtype, tensor_ty->dtype)) && VisitDFPattern(op->pattern, expr); } return false; diff --git a/src/relax/ir/dataflow_matcher.h b/src/relax/ir/dataflow_matcher.h index e9833d9b297b..b02bf82e177c 100644 --- a/src/relax/ir/dataflow_matcher.h +++ b/src/relax/ir/dataflow_matcher.h @@ -63,7 +63,7 @@ class DFPatternMatcher : public DFPatternFunctorstream << "*"; }); -StructInfoPattern::StructInfoPattern(DFPattern pattern, StructInfo struct_info) { - ffi::ObjectPtr n = ffi::make_object(); +TypePattern::TypePattern(DFPattern pattern, Type ty) { + ffi::ObjectPtr n = ffi::make_object(); n->pattern = std::move(pattern); - n->struct_info = std::move(struct_info); + n->ty = std::move(ty); data_ = std::move(n); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("relax.dpl.StructInfoPattern", - [](DFPattern pattern, StructInfo struct_info) { - return StructInfoPattern(pattern, struct_info); - }); + refl::GlobalDef().def("relax.dpl.TypePattern", + [](DFPattern pattern, Type ty) { return TypePattern(pattern, ty); }); } -RELAX_PATTERN_PRINTER_DEF(StructInfoPatternNode, [](auto p, auto node) { - p->stream << "StructInfoPattern(" << node->pattern << " has relax StructInfo " - << node->struct_info << ")"; +RELAX_PATTERN_PRINTER_DEF(TypePatternNode, [](auto p, auto node) { + p->stream << "TypePattern(" << node->pattern << " has relax Type " << node->ty << ")"; }); ShapePattern::ShapePattern(DFPattern pattern, ffi::Array shape) { @@ -448,8 +445,8 @@ class DFPatternDuplicator : public DFPatternFunctor DFPattern VisitDFPattern_(const ShapePatternNode* op) override { return ShapePattern(op->pattern, op->shape); } - DFPattern VisitDFPattern_(const StructInfoPatternNode* op) override { - return StructInfoPattern(op->pattern, op->struct_info); + DFPattern VisitDFPattern_(const TypePatternNode* op) override { + return TypePattern(op->pattern, op->ty); } DFPattern VisitDFPattern_(const DataflowVarPatternNode* op) override { @@ -476,9 +473,7 @@ NotPattern DFPattern::operator~() const { return NotPattern(*this); } AttrPattern DFPattern::HasAttr(const ffi::Map& attrs) const { return AttrPattern(*this, DictAttrs(attrs)); } -StructInfoPattern DFPattern::HasStructInfo(const StructInfo& struct_info) const { - return StructInfoPattern(*this, struct_info); -} +TypePattern DFPattern::HasType(const Type& ty) const { return TypePattern(*this, ty); } DataTypePattern DFPattern::HasDtype(const DataType& dtype) const { return DataTypePattern(*this, dtype); } diff --git a/src/relax/ir/dataflow_pattern_functor.cc b/src/relax/ir/dataflow_pattern_functor.cc index 7179d6bc83a1..d269196dc32f 100644 --- a/src/relax/ir/dataflow_pattern_functor.cc +++ b/src/relax/ir/dataflow_pattern_functor.cc @@ -96,9 +96,7 @@ void DFPatternVisitor::VisitDFPattern_(const UnorderedTuplePatternNode* op) { } } -void DFPatternVisitor::VisitDFPattern_(const StructInfoPatternNode* op) { - VisitDFPattern(op->pattern); -} +void DFPatternVisitor::VisitDFPattern_(const TypePatternNode* op) { VisitDFPattern(op->pattern); } // leaf nodes. void DFPatternVisitor::VisitDFPattern_(const PrimArrPatternNode* op) {} diff --git a/src/relax/ir/dependent_type.cc b/src/relax/ir/dependent_type.cc new file mode 100644 index 000000000000..9a9cccea98a8 --- /dev/null +++ b/src/relax/ir/dependent_type.cc @@ -0,0 +1,220 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/ir/dependent_type.cc + * \brief Relax dependent type nodes. + */ +#include +#include +#include +#include +#include + +namespace tvm { +namespace relax { + +TVM_FFI_STATIC_INIT_BLOCK() { + DependentTypeNode::RegisterReflection(); + ObjectTypeNode::RegisterReflection(); + PrimTypeNode::RegisterReflection(); + ShapeTypeNode::RegisterReflection(); + TensorTypeNode::RegisterReflection(); + FuncTypeNode::RegisterReflection(); +} + +ObjectType::ObjectType(Span span) { + ffi::ObjectPtr n = ffi::make_object(); + n->span = span; + data_ = std::move(n); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.ObjectType", [](Span span) { return ObjectType(span); }); +} + +// Prim +PrimType::PrimType(PrimExpr value, Span span) { + ffi::ObjectPtr n = ffi::make_object(); + n->dtype = value->dtype; + n->value = std::move(value); + n->span = span; + data_ = std::move(n); +} + +PrimType::PrimType(DataType dtype, Span span) { + ffi::ObjectPtr n = ffi::make_object(); + n->dtype = dtype; + n->value = std::nullopt; + n->span = span; + data_ = std::move(n); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("relax.PrimTypeFromDtype", + [](DataType dtype, Span span) { return PrimType(dtype, span); }) + .def("relax.PrimTypeFromValue", + [](PrimExpr value, Span span) { return PrimType(value, span); }); +} + +// Shape +ShapeType::ShapeType(ffi::Array values, Span span) { + ffi::ObjectPtr n = ffi::make_object(); + n->ndim = static_cast(values.size()); + n->values = values.Map([](PrimExpr value) { + if (value->IsInstance()) { + return tvm::cast(DataType::Int(64), value); + } + TVM_FFI_ICHECK(value.dtype() == DataType::Int(64)) + << "the value in ShapeType can only have dtype of int64"; + return value; + }); + n->span = span; + data_ = std::move(n); +} + +ShapeType::ShapeType(int ndim, Span span) { + ffi::ObjectPtr n = ffi::make_object(); + TVM_FFI_ICHECK_GE(ndim, -1) << "ndim of ShapeType must be >= -1, but got " << ndim; + n->ndim = ndim; + n->span = span; + data_ = std::move(n); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def( + "relax.ShapeType", [](ffi::Optional> values, int ndim, Span span) { + if (values.defined()) { + TVM_FFI_CHECK_EQ(ndim, kUnknownNDim, ValueError) << "Cannot both specify values and ndim"; + return ShapeType(values.value(), span); + } else { + return ShapeType(ndim, span); + } + }); +} + +// Tensor +TensorType::TensorType(Expr shape, DataType dtype, ffi::Optional vdevice, Span span) { + ffi::ObjectPtr n = ffi::make_object(); + // assign ndim before move + ffi::Optional shape_ty = MatchType(shape); + TVM_FFI_ICHECK(shape_ty) << "We expect shape to contain pre-set shape type"; + TVM_FFI_ICHECK(shape.defined()) << "Must provide a shape in this constructor"; + TVM_FFI_ICHECK(shape->IsInstance() || shape->IsInstance()) + << "We require shape to be normalized when constructing TensorType"; + n->ndim = shape_ty.value()->ndim; + // assign rest of the fields. + n->shape = std::move(shape); + n->dtype = dtype; + n->vdevice = vdevice; + n->span = span; + data_ = std::move(n); +} + +TensorType::TensorType(DataType dtype, int ndim, ffi::Optional vdevice, Span span) { + ffi::ObjectPtr n = ffi::make_object(); + TVM_FFI_ICHECK_GE(ndim, -1) << "ndim of TensorType must be >= -1, but got " << ndim; + n->ndim = ndim; + n->dtype = dtype; + n->vdevice = vdevice; + n->span = span; + data_ = std::move(n); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def( + "relax.TensorType", [](ffi::Optional shape, ffi::Optional dtype, int ndim, + VDevice vdevice, Span span) { + if (shape.defined()) { + TVM_FFI_CHECK_EQ(ndim, kUnknownNDim, ValueError) << "Cannot both specify shape and ndim"; + return TensorType(shape.value(), dtype.value_or(DataType::Void()), vdevice, span); + } else { + return TensorType(dtype.value_or(DataType::Void()), ndim, vdevice, span); + } + }); +} + +// Func +FuncType::FuncType(ffi::Array params, Type ret, bool purity, Span span) { + ffi::ObjectPtr n = ffi::make_object(); + n->params = std::move(params); + n->ret = std::move(ret); + n->purity = std::move(purity); + n->span = span; + data_ = std::move(n); +} + +FuncType FuncType::OpaqueFunc(TypeDeriveFunc derive_func, bool purity, Span span) { + ffi::ObjectPtr n = ffi::make_object(); + n->derive_func = std::move(derive_func); + n->ret = ObjectType(); + n->purity = std::move(purity); + n->span = span; + return FuncType(n); +} + +FuncType FuncType::OpaqueFunc(Type ret, bool purity, Span span) { + ffi::ObjectPtr n = ffi::make_object(); + n->ret = std::move(ret); + n->purity = std::move(purity); + n->span = span; + return FuncType(n); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("relax.FuncType", [](ffi::Array params, Type ret, bool purity, + Span span) { return FuncType(params, ret, purity, span); }) + .def("relax.FuncTypeOpaqueFunc", [](ffi::Optional ret, + ffi::Optional derive_func, bool purity, + Span span) { + if (derive_func.defined()) { + TVM_FFI_CHECK(!ret.defined(), ValueError) << "Cannot specify both ret and derive_func"; + return FuncType::OpaqueFunc(derive_func.value(), purity, span); + } else { + return FuncType::OpaqueFunc(ret.value_or(ObjectType()), purity, span); + } + }); +} + +// Helper functions +void UpdateType(Expr expr, Type ty) { + TVM_FFI_ICHECK(!expr->ty.defined()) << "To ensure idempotency, " + << "the expression passed to UpdateType " + << "must not have any prior type. " + << "However, expression " << expr << " has type " << expr->ty + << ", which cannot be overwritten with " << ty; + expr->ty = ty; +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("relax.UpdateType", [](Expr expr, Type ty) { UpdateType(expr, ty); }) + .def("ir.ExprType", [](Expr expr) { return GetType(expr); }); +} + +} // namespace relax +} // namespace tvm diff --git a/src/relax/ir/emit_te.cc b/src/relax/ir/emit_te.cc index e8b99a21ddcd..304911c1dca2 100644 --- a/src/relax/ir/emit_te.cc +++ b/src/relax/ir/emit_te.cc @@ -23,7 +23,7 @@ #include "./emit_te.h" #include -#include +#include #include namespace tvm { @@ -54,18 +54,17 @@ te::Tensor TETensor(Expr value, ffi::Map tir_var_map, std:: n->shape = std::move(shape); return te::PlaceholderOp(n).output(0); } - TVM_FFI_ICHECK(value->struct_info_.defined()) - << "value must be normalized and contain StructInfo"; - auto* tensor_sinfo = GetStructInfoAs(value); - TVM_FFI_ICHECK(tensor_sinfo) << "Value must be a tensor"; - auto* shape_expr = tensor_sinfo->shape.as(); + TVM_FFI_ICHECK(value->ty.defined()) << "value must be normalized and contain Type"; + auto* tensor_ty = GetTypeAs(value); + TVM_FFI_ICHECK(tensor_ty) << "Value must be a tensor"; + auto* shape_expr = tensor_ty->shape.as(); TVM_FFI_CHECK(shape_expr, ValueError) << "Expression does not have an known symbolic shape, please consider use " "match_cast " << "to constrain the shape before passing into te_tensor"; n->shape = shape_expr->values.Map( [&tir_var_map](const PrimExpr& e) { return tirx::Substitute(e, tir_var_map); }); - n->dtype = tensor_sinfo->dtype; + n->dtype = tensor_ty->dtype; return te::PlaceholderOp(n).output(0); } diff --git a/src/relax/ir/emit_te.h b/src/relax/ir/emit_te.h index c7bd5061217b..ae292a673dff 100644 --- a/src/relax/ir/emit_te.h +++ b/src/relax/ir/emit_te.h @@ -56,7 +56,7 @@ class RXPlaceholderOpNode : public te::PlaceholderOpNode { /*! * \brief Create a TE tensor from relax expression, with TIR variables in the * tensor shape substituted by the given mapping. - * \param value The relax expression, which is required to have TensorStructInfo. + * \param value The relax expression, which is required to have TensorType. * \param tir_var_map The mapping to substitute the TIR variables appeared in the * shape of the input Expr. * \param name The name of the created tensor. diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc index a6fd7636f15f..d84ddf2bae58 100644 --- a/src/relax/ir/expr.cc +++ b/src/relax/ir/expr.cc @@ -18,8 +18,8 @@ */ #include #include +#include #include -#include #include #include @@ -56,32 +56,29 @@ Id::Id(ffi::String name_hint) { data_ = std::move(n); } -Call::Call(Expr op, ffi::Array args, Attrs attrs, ffi::Array sinfo_args, - Span span) { - TVM_FFI_CHECK(!op->struct_info_.defined() || op->struct_info_->IsInstance(), - ValueError) - << "Call expects its operator to have FuncStructInfo, " - << "but operator " << op << ", which was called with arguments " << args - << ", has struct info " << op->struct_info_; +Call::Call(Expr op, ffi::Array args, Attrs attrs, ffi::Array ty_args, Span span) { + TVM_FFI_CHECK(!op->ty.defined() || op->ty->IsInstance(), ValueError) + << "Call expects its operator to have FuncType, " + << "but operator " << op << ", which was called with arguments " << args << ", has type " + << op->ty; ffi::ObjectPtr n = ffi::make_object(); n->op = std::move(op); n->args = std::move(args); n->attrs = std::move(attrs); - n->sinfo_args = std::move(sinfo_args); + n->ty_args = std::move(ty_args); n->span = std::move(span); data_ = std::move(n); } Call WithFields(Call call, ffi::Optional opt_op, ffi::Optional> opt_args, - ffi::Optional opt_attrs, - ffi::Optional> opt_sinfo_args, + ffi::Optional opt_attrs, ffi::Optional> opt_ty_args, ffi::Optional opt_span) { // Collect new values for fields. Expr op = opt_op.value_or(call->op); ffi::Array args = opt_args.value_or(call->args); Attrs attrs = opt_attrs.value_or(call->attrs); - ffi::Array sinfo_args = opt_sinfo_args.value_or(call->sinfo_args); + ffi::Array ty_args = opt_ty_args.value_or(call->ty_args); Span span = opt_span.value_or(call->span); // Check if anything changed. @@ -96,9 +93,9 @@ Call WithFields(Call call, ffi::Optional opt_op, ffi::Optionalsinfo_args.size()) { - for (size_t i = 0; i < sinfo_args.size(); i++) { - unchanged &= sinfo_args[i].same_as(call->sinfo_args[i]); + if (ty_args.size() == call->ty_args.size()) { + for (size_t i = 0; i < ty_args.size(); i++) { + unchanged &= ty_args[i].same_as(call->ty_args[i]); } } else { unchanged = false; @@ -111,7 +108,7 @@ Call WithFields(Call call, ffi::Optional opt_op, ffi::Optionalop = op; cow_call_node->args = args; cow_call_node->attrs = attrs; - cow_call_node->sinfo_args = sinfo_args; + cow_call_node->ty_args = ty_args; cow_call_node->span = span; } return call; @@ -119,10 +116,9 @@ Call WithFields(Call call, ffi::Optional opt_op, ffi::Optional args, Attrs attrs, - ffi::Array sinfo_args, Span span) { - return Call(op, args, attrs, sinfo_args, span); - }); + refl::GlobalDef().def("relax.Call", + [](Expr op, ffi::Array args, Attrs attrs, ffi::Array ty_args, + Span span) { return Call(op, args, attrs, ty_args, span); }); } If::If(Expr cond, Expr true_branch, Expr false_branch, Span span) { @@ -162,22 +158,24 @@ TVM_FFI_STATIC_INIT_BLOCK() { } Tuple::Tuple(tvm::ffi::Array fields, Span span) { - ffi::Optional tuple_sinfo = [&]() -> ffi::Optional { - ffi::Array field_sinfo; + ffi::Optional tuple_ty = [&]() -> ffi::Optional { + ffi::Array field_ty; for (const auto& field : fields) { - if (field->struct_info_.defined()) { - field_sinfo.push_back(GetStructInfo(field)); + if (field->ty.defined()) { + field_ty.push_back(GetType(field)); } else { return std::nullopt; } } - return TupleStructInfo(field_sinfo); + return TupleType(field_ty); }(); ffi::ObjectPtr n = ffi::make_object(); n->fields = std::move(fields); n->span = std::move(span); - n->struct_info_ = tuple_sinfo; + if (tuple_ty.defined()) { + n->ty = tuple_ty.value(); + } data_ = std::move(n); } @@ -215,12 +213,12 @@ TupleGetItem::TupleGetItem(Expr tuple, int index, Span span) { << " cannot be accessed with negative index " << index; ffi::ObjectPtr n = ffi::make_object(); - if (auto* tuple_info = tuple->struct_info_.as()) { + if (auto* tuple_info = tuple->ty.as()) { TVM_FFI_ICHECK_LT(index, tuple_info->fields.size()) << "Index out of bounds: Tuple " << tuple << " is of size " << tuple_info->fields.size() << ", and cannot be accessed with index " << index; - auto sinfo = tuple_info->fields[index]; - n->struct_info_ = sinfo; + auto ty = tuple_info->fields[index]; + n->ty = ty; } n->tuple = std::move(tuple); n->index = index; @@ -259,11 +257,11 @@ ShapeExpr::ShapeExpr(ffi::Array values, Span span) { return tvm::cast(DataType::Int(64), value); } TVM_FFI_ICHECK(value.dtype() == DataType::Int(64)) - << "the value in ShapeStructInfo can only have dtype of int64"; + << "the value in ShapeType can only have dtype of int64"; return value; }); n->span = span; - n->struct_info_ = ShapeStructInfo(values, span); + n->ty = ShapeType(values, span); data_ = std::move(n); } @@ -274,10 +272,12 @@ TVM_FFI_STATIC_INIT_BLOCK() { }); } -Var::Var(Id vid, ffi::Optional struct_info_annotation, Span span) { +Var::Var(Id vid, ffi::Optional ty_annotation, Span span) { ffi::ObjectPtr n = ffi::make_object(); n->vid = std::move(vid); - n->struct_info_ = std::move(struct_info_annotation); + if (ty_annotation.defined()) { + n->ty = std::move(ty_annotation.value()); + } n->span = std::move(span); data_ = std::move(n); } @@ -304,17 +304,19 @@ VarNode* Var::CopyOnWrite() { TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def("relax.Var", [](ffi::String name_hint, ffi::Optional struct_info_annotation, - Span span) { return Var(name_hint, struct_info_annotation, span); }) - .def("relax.VarFromId", [](Id vid, ffi::Optional struct_info_annotation, - Span span) { return Var(vid, struct_info_annotation, span); }); + .def("relax.Var", [](ffi::String name_hint, ffi::Optional ty_annotation, + Span span) { return Var(name_hint, ty_annotation, span); }) + .def("relax.VarFromId", [](Id vid, ffi::Optional ty_annotation, Span span) { + return Var(vid, ty_annotation, span); + }); } -DataflowVar::DataflowVar(Id vid, ffi::Optional struct_info_annotation, Span span) { +DataflowVar::DataflowVar(Id vid, ffi::Optional ty_annotation, Span span) { ffi::ObjectPtr n = ffi::make_object(); n->vid = std::move(vid); - n->struct_info_ = std::move(struct_info_annotation); - n->span = std::move(span); + if (ty_annotation.defined()) { + n->ty = std::move(ty_annotation.value()); + } n->span = std::move(span); data_ = std::move(n); } @@ -323,32 +325,30 @@ TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("relax.DataflowVar", - [](ffi::String name_hint, ffi::Optional struct_info_annotation, Span span) { - return DataflowVar(name_hint, struct_info_annotation, span); + [](ffi::String name_hint, ffi::Optional ty_annotation, Span span) { + return DataflowVar(name_hint, ty_annotation, span); }) - .def("relax.DataflowVarFromId", - [](Id vid, ffi::Optional struct_info_annotation, Span span) { - return DataflowVar(vid, struct_info_annotation, span); - }); + .def("relax.DataflowVarFromId", [](Id vid, ffi::Optional ty_annotation, Span span) { + return DataflowVar(vid, ty_annotation, span); + }); } -Constant::Constant(runtime::Tensor data, ffi::Optional struct_info_annotation, - Span span) { +Constant::Constant(runtime::Tensor data, ffi::Optional ty_annotation, Span span) { ffi::ObjectPtr n = ffi::make_object(); n->data = std::move(data); n->span = std::move(span); - // set struct info. + // set type. ffi::Array values; auto shape_tuple = n->data.Shape(); for (size_t dim = 0; dim < shape_tuple.size(); ++dim) { values.push_back(IntImm::Int64(shape_tuple[dim])); } - if (struct_info_annotation.defined()) { - n->struct_info_ = struct_info_annotation.value(); + if (ty_annotation.defined()) { + n->ty = ty_annotation.value(); } else { - TensorStructInfo tinfo(ShapeExpr(values), n->data.DataType(), VDevice(), span); - n->struct_info_ = tinfo; + TensorType tinfo(ShapeExpr(values), n->data.DataType(), VDevice(), span); + n->ty = tinfo; } data_ = std::move(n); @@ -356,15 +356,14 @@ Constant::Constant(runtime::Tensor data, ffi::Optional struct_info_a TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def( - "relax.Constant", - [](runtime::Tensor data, ffi::Optional struct_info_annotation = std::nullopt, - Span span = Span()) { return Constant(data, struct_info_annotation, span); }); + refl::GlobalDef().def("relax.Constant", + [](runtime::Tensor data, ffi::Optional ty_annotation = std::nullopt, + Span span = Span()) { return Constant(data, ty_annotation, span); }); } PrimValue::PrimValue(PrimExpr value, Span span) { ffi::ObjectPtr n = ffi::make_object(); - n->struct_info_ = PrimStructInfo(value); + n->ty = PrimType(value); n->value = std::move(value); n->span = std::move(span); data_ = std::move(n); @@ -384,7 +383,7 @@ StringImm::StringImm(ffi::String value, Span span) { ffi::ObjectPtr n = ffi::make_object(); n->value = std::move(value); n->span = std::move(span); - n->struct_info_ = ObjectStructInfo(); + n->ty = ObjectType(); data_ = std::move(n); } @@ -398,7 +397,7 @@ DataTypeImm::DataTypeImm(DataType value, Span span) { ffi::ObjectPtr n = ffi::make_object(); n->value = std::move(value); n->span = std::move(span); - n->struct_info_ = ObjectStructInfo(); + n->ty = ObjectType(); data_ = std::move(n); } @@ -408,22 +407,21 @@ TVM_FFI_STATIC_INIT_BLOCK() { [](DataType value, Span span) { return DataTypeImm(value, span); }); } -MatchCast::MatchCast(Var var, Expr value, StructInfo struct_info, Span span) { +MatchCast::MatchCast(Var var, Expr value, Type ty, Span span) { ffi::ObjectPtr n = ffi::make_object(); TVM_FFI_ICHECK(var.defined()) << "MatchCast requires var to be defined"; n->var = std::move(var); n->value = std::move(value); - n->struct_info = std::move(struct_info); + n->ty = std::move(ty); n->span = span; data_ = std::move(n); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("relax.MatchCast", - [](Var var, Expr value, StructInfo struct_info, Span span) { - return MatchCast(var, value, struct_info, span); - }); + refl::GlobalDef().def("relax.MatchCast", [](Var var, Expr value, Type ty, Span span) { + return MatchCast(var, value, ty, span); + }); } VarBinding::VarBinding(Var var, Expr value, Span span) { @@ -539,44 +537,42 @@ TVM_FFI_STATIC_INIT_BLOCK() { }); } -Function::Function(ffi::Array params, Expr body, ffi::Optional ret_struct_info, - bool is_pure, DictAttrs attrs, Span span) { +Function::Function(ffi::Array params, Expr body, ffi::Optional ret_ty, bool is_pure, + DictAttrs attrs, Span span) { // Set the function type. // For function, we take a conservative approach and require the function type // to be known at construction time. - ffi::Array param_sinfo; + ffi::Array param_ty; for (const Var& param : params) { - TVM_FFI_ICHECK(param->struct_info_.defined()) - << "relax.Function requires params to contain struct_info_"; - param_sinfo.push_back(GetStructInfo(param)); + TVM_FFI_ICHECK(param->ty.defined()) << "relax.Function requires params to contain ty"; + param_ty.push_back(GetType(param)); } - ffi::Optional body_sinfo; + ffi::Optional body_ty; - if (body->struct_info_.defined()) { - body_sinfo = GetStructInfo(body); + if (body->ty.defined()) { + body_ty = GetType(body); } - TVM_FFI_ICHECK(body_sinfo.defined() || ret_struct_info.defined()) + TVM_FFI_ICHECK(body_ty.defined() || ret_ty.defined()) << "Function must be constructed with either " - << "an explicit struct info for the return type, " - << "or a normalized body with struct info."; + << "an explicit type for the return type, " + << "or a normalized body with type."; - // Use the body's struct info if there is no explicit return type, + // Use the body's type if there is no explicit return type, // or if the body may provide a more granular return type. - bool use_body_struct_info = - !ret_struct_info.defined() || - (body_sinfo && ret_struct_info && IsBaseOf(ret_struct_info.value(), body_sinfo.value())); + bool use_body_ty = + !ret_ty.defined() || (body_ty && ret_ty && IsBaseOf(ret_ty.value(), body_ty.value())); - if (use_body_struct_info) { + if (use_body_ty) { // MatchCast nodes within the body may introduce new symbolic // variables. These are in-scope for the function body, but not // for the function's return type. When hoisting the body's type // to the function return type, symbolic variables may only be // used if they were defined by the function's parameters. auto f_shape_var_map = [&] { - auto tir_vars = DefinableTIRVarsInStructInfo(TupleStructInfo(params.Map(GetStructInfo))); + auto tir_vars = DefinableTIRVarsInType(TupleType(params.Map(GetType))); std::unordered_set lookup(tir_vars.begin(), tir_vars.end()); return [lookup = std::move(lookup)](const tirx::Var& var) -> ffi::Optional { if (lookup.count(var)) { @@ -586,18 +582,18 @@ Function::Function(ffi::Array params, Expr body, ffi::Optional } }; }(); - ret_struct_info = EraseToWellDefined(body_sinfo.value(), f_shape_var_map); + ret_ty = EraseToWellDefined(body_ty.value(), f_shape_var_map); } - FuncStructInfo func_sinfo(param_sinfo, ret_struct_info.value(), is_pure); + FuncType func_ty(param_ty, ret_ty.value(), is_pure); // set the fields ffi::ObjectPtr n = ffi::make_object(); n->params = std::move(params); n->body = std::move(body); - n->ret_struct_info = ret_struct_info.value(); + n->ret_ty = ret_ty.value(); n->is_pure = is_pure; - n->struct_info_ = std::move(func_sinfo); + n->ty = std::move(func_ty); n->attrs = std::move(attrs); n->span = std::move(span); data_ = std::move(n); @@ -605,28 +601,27 @@ Function::Function(ffi::Array params, Expr body, ffi::Optional TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("relax.Function", [](ffi::Array params, Expr body, - ffi::Optional ret_struct_info, - bool is_pure, DictAttrs attrs, Span span) { - return Function(params, body, ret_struct_info, is_pure, attrs, span); - }); + refl::GlobalDef().def("relax.Function", + [](ffi::Array params, Expr body, ffi::Optional ret_ty, + bool is_pure, DictAttrs attrs, Span span) { + return Function(params, body, ret_ty, is_pure, attrs, span); + }); } -Function Function::CreateEmpty(ffi::Array params, StructInfo ret_struct_info, bool is_pure, - DictAttrs attrs, Span span) { - ffi::Array param_sinfo; +Function Function::CreateEmpty(ffi::Array params, Type ret_ty, bool is_pure, DictAttrs attrs, + Span span) { + ffi::Array param_ty; for (const Var& param : params) { - TVM_FFI_ICHECK(param->struct_info_.defined()) - << "relax.Function requires params to contain struct_info_."; - param_sinfo.push_back(GetStructInfo(param)); + TVM_FFI_ICHECK(param->ty.defined()) << "relax.Function requires params to contain ty."; + param_ty.push_back(GetType(param)); } - FuncStructInfo finfo(param_sinfo, ret_struct_info, is_pure); + FuncType finfo(param_ty, ret_ty, is_pure); // A dummy body, to ensure that the empty function is still well-formed. Expr body = [&]() -> Expr { - Var output("output", ret_struct_info); - Call expr(ExternFunc("_dummy_function", FuncStructInfo({}, ret_struct_info)), {}); + Var output("output", ret_ty); + Call expr(ExternFunc("_dummy_function", FuncType({}, ret_ty)), {}); return SeqExpr({BindingBlock({VarBinding(output, expr)})}, output); }(); @@ -636,8 +631,8 @@ Function Function::CreateEmpty(ffi::Array params, StructInfo ret_struct_inf n->params = std::move(params); n->body = std::move(body); n->is_pure = is_pure; - n->struct_info_ = std::move(finfo); - n->ret_struct_info = std::move(ret_struct_info); + n->ty = std::move(finfo); + n->ret_ty = std::move(ret_ty); n->attrs = std::move(attrs); n->span = std::move(span); return Function(std::move(n)); @@ -645,79 +640,76 @@ Function Function::CreateEmpty(ffi::Array params, StructInfo ret_struct_inf TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def( - "relax.FunctionCreateEmpty", [](ffi::Array params, StructInfo ret_struct_info, - bool is_pure, DictAttrs attrs, Span span) { - return Function::CreateEmpty(params, ret_struct_info, is_pure, attrs, span); - }); + refl::GlobalDef().def("relax.FunctionCreateEmpty", [](ffi::Array params, Type ret_ty, + bool is_pure, DictAttrs attrs, Span span) { + return Function::CreateEmpty(params, ret_ty, is_pure, attrs, span); + }); } // Special opaque derivation function for ExternFunc -// Take look at sinfo_args to figure out the return StructInfo. +// Take look at ty_args to figure out the return Type. TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tvm.relax.struct_info.infer_by_sinfo_args", - [](const Call& call, const BlockBuilder& ctx) -> StructInfo { - TVM_FFI_ICHECK(call->sinfo_args.defined()) - << "sinfo_args field of CallNode should always be defined"; - if (call->sinfo_args.empty()) { - return ObjectStructInfo(); - } else if (call->sinfo_args.size() == 1) { - return call->sinfo_args[0]; - } else { - return TupleStructInfo(call->sinfo_args); - } - }); + auto infer_by_ty_args = [](const Call& call, const BlockBuilder& ctx) -> Type { + TVM_FFI_ICHECK(call->ty_args.defined()) << "ty_args field of CallNode should always be defined"; + if (call->ty_args.empty()) { + return ObjectType(); + } else if (call->ty_args.size() == 1) { + return call->ty_args[0]; + } else { + return TupleType(call->ty_args); + } + }; + refl::GlobalDef().def("tvm.relax.type.infer_by_ty_args", infer_by_ty_args); } // Get the derive function. -FuncStructInfo GetExternFuncStructInfo() { - EnvFunc fn = EnvFunc::Get("tvm.relax.struct_info.infer_by_sinfo_args"); - StructInfoDeriveFunc derive; +FuncType GetExternFuncType() { + EnvFunc fn = EnvFunc::Get("tvm.relax.type.infer_by_ty_args"); + TypeDeriveFunc derive; derive = fn; - return FuncStructInfo::OpaqueFunc(derive); + return FuncType::OpaqueFunc(derive); } ExternFunc::ExternFunc(ffi::String global_symbol, Span span) - : ExternFunc(global_symbol, GetExternFuncStructInfo(), span) {} + : ExternFunc(global_symbol, GetExternFuncType(), span) {} -ExternFunc::ExternFunc(ffi::String global_symbol, StructInfo struct_info, Span span) { - TVM_FFI_ICHECK(struct_info.as()) - << "ExternFunc must have FuncStructInfo, " - << "but declaration of '" << global_symbol << "' received " << struct_info; +ExternFunc::ExternFunc(ffi::String global_symbol, Type ty, Span span) { + TVM_FFI_ICHECK(ty.as()) + << "ExternFunc must have FuncType, " + << "but declaration of '" << global_symbol << "' received " << ty; ffi::ObjectPtr n = ffi::make_object(); n->global_symbol = std::move(global_symbol); n->span = span; - n->struct_info_ = struct_info; + n->ty = ty; data_ = std::move(n); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("relax.ExternFunc", [](ffi::String global_symbol, - ffi::Optional struct_info, Span span) { - if (struct_info.defined()) { - return ExternFunc(global_symbol, struct_info.value(), span); - } else { - return ExternFunc(global_symbol, span); - } - }); + refl::GlobalDef().def("relax.ExternFunc", + [](ffi::String global_symbol, ffi::Optional ty, Span span) { + if (ty.defined()) { + return ExternFunc(global_symbol, ty.value(), span); + } else { + return ExternFunc(global_symbol, span); + } + }); } Expr GetShapeOf(const Expr& expr) { // default case, to be normalized. - TVM_FFI_ICHECK(expr->struct_info_.defined()) - << "GetShapeOf can only be applied to normalized expr"; - auto* tinfo = GetStructInfoAs(expr); + TVM_FFI_ICHECK(expr->ty.defined()) << "GetShapeOf can only be applied to normalized expr"; + auto* tinfo = GetTypeAs(expr); - TVM_FFI_ICHECK(tinfo != nullptr) << "ShapeOf can only be applied to expr with TensorStructInfo"; + TVM_FFI_ICHECK(tinfo != nullptr) << "ShapeOf can only be applied to expr with TensorType"; if (tinfo->shape.defined()) return tinfo->shape.value(); static const Op& op = Op::Get("relax.shape_of"); // default case, call shape of, eagerly normalize the expr. relax::Call call_shape_of(op, {expr}, {}, {}); - UpdateStructInfo(call_shape_of, ShapeStructInfo(tinfo->ndim)); + UpdateType(call_shape_of, ShapeType(tinfo->ndim)); return call_shape_of; } diff --git a/src/relax/ir/expr_functor.cc b/src/relax/ir/expr_functor.cc index e9995fa31d08..7483feb07f57 100644 --- a/src/relax/ir/expr_functor.cc +++ b/src/relax/ir/expr_functor.cc @@ -26,7 +26,6 @@ */ #include #include -#include #include #include #include @@ -89,38 +88,46 @@ namespace relax { // ================== // ExprVisitor -void ExprVisitor::VisitExprDepStructInfoField(const StructInfo& struct_info) { - // recurse into struct info in case they depend on value +void ExprVisitor::VisitExprDepTypeField(const Type& ty) { + // recurse into type in case they depend on value // under the current scope. - default_struct_info_field_visitor_.VisitStructInfo(struct_info); + default_tyfield_visitor_.VisitType(ty); } -ExprVisitor::DefaultStructInfoFieldVisitor::DefaultStructInfoFieldVisitor(ExprVisitor* parent) +ExprVisitor::DefaultTypeFieldVisitor::DefaultTypeFieldVisitor(ExprVisitor* parent) : parent_(parent) {} -void ExprVisitor::DefaultStructInfoFieldVisitor::VisitStructInfoExprField(const Expr& expr) { +void ExprVisitor::DefaultTypeFieldVisitor::VisitTypeExprField(const Expr& expr) { parent_->VisitExpr(expr); } -void ExprVisitor::DefaultStructInfoFieldVisitor::VisitStructInfoExprField(const PrimExpr& expr) { +void ExprVisitor::DefaultTypeFieldVisitor::VisitTypeExprField(const PrimExpr& expr) { parent_->VisitPrimExpr(expr); } -void ExprVisitor::DefaultStructInfoFieldVisitor::VisitStructInfo_(const FuncStructInfoNode* op) { - // Do not recurse into function struct info +void ExprVisitor::DefaultTypeFieldVisitor::VisitType_(const FuncTypeNode* op) { + // Do not recurse into function type // as they won't contain ref to values in current scope. } +void VisitExprDepTypeFieldIfNeeded(ExprVisitor* visitor, const Type& ty) { + if (auto* ty_node = ty.as()) { + visitor->VisitExprDepTypeField(ffi::GetRef(ty_node)); + } else if (auto* ty_node = ty.as()) { + visitor->VisitExprDepTypeField(ffi::GetRef(ty_node)); + } +} + void ExprVisitor::VisitExpr(const Expr& expr) { ExprFunctor::VisitExpr(expr); } void ExprVisitor::VisitExpr_(const ConstantNode* op) { this->VisitSpan(op->span); - // Constant's StructInfo does not depend on Expr. + // Constant's Type does not depend on Expr. } void ExprVisitor::VisitExpr_(const GlobalVarNode* op) { this->VisitSpan(op->span); - // FuncStructInfo is not value-dep + // FuncType is not value-dep } void ExprVisitor::VisitExpr_(const TupleNode* op) { @@ -128,17 +135,13 @@ void ExprVisitor::VisitExpr_(const TupleNode* op) { for (Expr field : op->fields) { this->VisitExpr(field); } - if (auto* sinfo = op->struct_info_.as()) { - this->VisitExprDepStructInfoField(ffi::GetRef(sinfo)); - } + VisitExprDepTypeFieldIfNeeded(this, op->ty); } // Visit the use-site of a defined Var void ExprVisitor::VisitExpr_(const VarNode* op) { this->VisitSpan(op->span); - if (auto* sinfo = op->struct_info_.as()) { - this->VisitExprDepStructInfoField(ffi::GetRef(sinfo)); - } + VisitExprDepTypeFieldIfNeeded(this, op->ty); } // Visit the use-site of a defined DataflowVar @@ -153,24 +156,22 @@ void ExprVisitor::VisitExpr_(const FunctionNode* op) { } this->VisitExpr(op->body); - // FuncStructInfo does not depend on Expr. + // FuncType does not depend on Expr. } void ExprVisitor::VisitExpr_(const CallNode* op) { this->VisitSpan(op->span); this->VisitExpr(op->op); - for (StructInfo sinfo_arg : op->sinfo_args) { - this->VisitExprDepStructInfoField(sinfo_arg); + for (Type ty_arg : op->ty_args) { + this->VisitExprDepTypeField(ty_arg); } for (Expr arg : op->args) { this->VisitExpr(arg); } - if (auto* sinfo = op->struct_info_.as()) { - this->VisitExprDepStructInfoField(ffi::GetRef(sinfo)); - } + VisitExprDepTypeFieldIfNeeded(this, op->ty); } void ExprVisitor::VisitExpr_(const IfNode* op) { @@ -179,9 +180,7 @@ void ExprVisitor::VisitExpr_(const IfNode* op) { this->VisitExpr(op->true_branch); this->VisitExpr(op->false_branch); - if (auto* sinfo = op->struct_info_.as()) { - this->VisitExprDepStructInfoField(ffi::GetRef(sinfo)); - } + VisitExprDepTypeFieldIfNeeded(this, op->ty); } void ExprVisitor::VisitExpr_(const OpNode* op) { this->VisitSpan(op->span); } @@ -190,9 +189,7 @@ void ExprVisitor::VisitExpr_(const TupleGetItemNode* op) { this->VisitSpan(op->span); this->VisitExpr(op->tuple); - if (auto* sinfo = op->struct_info_.as()) { - this->VisitExprDepStructInfoField(ffi::GetRef(sinfo)); - } + VisitExprDepTypeFieldIfNeeded(this, op->ty); } void ExprVisitor::VisitExpr_(const ShapeExprNode* op) { @@ -201,14 +198,12 @@ void ExprVisitor::VisitExpr_(const ShapeExprNode* op) { } this->VisitSpan(op->span); - if (auto* sinfo = op->struct_info_.as()) { - this->VisitExprDepStructInfoField(ffi::GetRef(sinfo)); - } + VisitExprDepTypeFieldIfNeeded(this, op->ty); } void ExprVisitor::VisitExpr_(const ExternFuncNode* op) { this->VisitSpan(op->span); - // FuncStructInfo does not depend on Expr. + // FuncType does not depend on Expr. } void ExprVisitor::VisitExpr_(const SeqExprNode* op) { @@ -218,16 +213,12 @@ void ExprVisitor::VisitExpr_(const SeqExprNode* op) { } this->VisitExpr(op->body); - if (auto* sinfo = op->struct_info_.as()) { - this->VisitExprDepStructInfoField(ffi::GetRef(sinfo)); - } + VisitExprDepTypeFieldIfNeeded(this, op->ty); } void ExprVisitor::VisitExpr_(const PrimValueNode* op) { this->VisitPrimExpr(op->value); - if (auto* sinfo = op->struct_info_.as()) { - this->VisitExprDepStructInfoField(ffi::GetRef(sinfo)); - } + VisitExprDepTypeFieldIfNeeded(this, op->ty); this->VisitSpan(op->span); } @@ -260,7 +251,7 @@ RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(DataTypeImmNode); void ExprVisitor::VisitBinding_(const MatchCastNode* binding) { this->VisitExpr(binding->value); - this->VisitExprDepStructInfoField(binding->struct_info); + this->VisitExprDepTypeField(binding->ty); this->VisitVarDef(binding->var); } @@ -339,41 +330,38 @@ TVM_FFI_STATIC_INIT_BLOCK() { // ================== // ExprMutatorBase -StructInfo ExprMutatorBase::VisitExprDepStructInfoField(const StructInfo& struct_info) { - // recurse into struct info in case they depend on value +Type ExprMutatorBase::VisitExprDepTypeField(const Type& ty) { + // recurse into type in case they depend on value // under the current scope. - return default_struct_info_field_mutator_.VisitStructInfo(struct_info); + return default_tyfield_mutator_.VisitType(ty); } -ExprMutatorBase::DefaultStructInfoFieldMutator::DefaultStructInfoFieldMutator( - ExprMutatorBase* parent) +ExprMutatorBase::DefaultTypeFieldMutator::DefaultTypeFieldMutator(ExprMutatorBase* parent) : parent_(parent) {} -Expr ExprMutatorBase::DefaultStructInfoFieldMutator::VisitStructInfoExprField(const Expr& expr) { +Expr ExprMutatorBase::DefaultTypeFieldMutator::VisitTypeExprField(const Expr& expr) { return parent_->VisitExpr(expr); } -PrimExpr ExprMutatorBase::DefaultStructInfoFieldMutator::VisitStructInfoExprField( - const PrimExpr& expr) { +PrimExpr ExprMutatorBase::DefaultTypeFieldMutator::VisitTypeExprField(const PrimExpr& expr) { return parent_->VisitPrimExpr(expr); } -StructInfo ExprMutatorBase::DefaultStructInfoFieldMutator::VisitStructInfo_( - const FuncStructInfoNode* op) { - // Do not recurse into function struct info +Type ExprMutatorBase::DefaultTypeFieldMutator::VisitType_(const FuncTypeNode* op) { + // Do not recurse into function type // as they won't contain ref to values in current scope. - return ffi::GetRef(op); + return ffi::GetRef(op); } Expr ExprMutatorBase::VisitExpr(const Expr& expr) { return ExprFunctor::VisitExpr(expr); } Expr ExprMutatorBase::VisitExpr_(const ConstantNode* op) { - // Constant' struct info won't be affected by Expr/PrimExpr change. + // Constant' type won't be affected by Expr/PrimExpr change. return ffi::GetRef(op); } Expr ExprMutatorBase::VisitExpr_(const GlobalVarNode* op) { - // FuncStructInfo won't be affected by Expr/PrimExpr change. + // FuncType won't be affected by Expr/PrimExpr change. return ffi::GetRef(op); } @@ -387,9 +375,9 @@ Expr ExprMutatorBase::VisitExpr_(const TupleNode* op) { } if (unchanged) { - // If tuple's struct info change it means that - // one of its fields' struct info will change - // so un-changed already implies that struct info won't change + // If tuple's type change it means that + // one of its fields' type will change + // so un-changed already implies that type won't change return ffi::GetRef(op); } else { // when there is a change return a new tuple node @@ -399,7 +387,7 @@ Expr ExprMutatorBase::VisitExpr_(const TupleNode* op) { // Visit the use-site of a defined Var Expr ExprMutatorBase::VisitExpr_(const VarNode* op) { - // struct info of var-use should remain stable + // type of var-use should remain stable // or the var itself will get replaced return ffi::GetRef(op); } @@ -410,14 +398,14 @@ Expr ExprMutatorBase::VisitExpr_(const DataflowVarNode* op) { } Expr ExprMutatorBase::VisitExpr_(const FunctionNode* op) { - // struct info of function is not value dependent - // so no need to check struct_info field + // type of function is not value dependent + // so no need to check ty field Expr body = this->VisitExpr(op->body); if (body.same_as(op->body)) { return ffi::GetRef(op); } else { - return Function(op->params, body, op->ret_struct_info, op->is_pure, op->attrs); + return Function(op->params, body, op->ret_ty, op->is_pure, op->attrs); } } @@ -425,11 +413,11 @@ Expr ExprMutatorBase::VisitExpr_(const CallNode* call_node) { Expr new_op = this->VisitExpr(call_node->op); bool unchanged = call_node->op.same_as(new_op); - ffi::Array sinfo_args; - for (StructInfo sinfo_arg : call_node->sinfo_args) { - StructInfo new_sinfo_arg = this->VisitExprDepStructInfoField(sinfo_arg); - sinfo_args.push_back(new_sinfo_arg); - unchanged &= new_sinfo_arg.same_as(sinfo_arg); + ffi::Array ty_args; + for (Type ty_arg : call_node->ty_args) { + Type new_ty_arg = this->VisitExprDepTypeField(ty_arg); + ty_args.push_back(new_ty_arg); + unchanged &= new_ty_arg.same_as(ty_arg); } tvm::ffi::Array call_args; @@ -439,10 +427,10 @@ Expr ExprMutatorBase::VisitExpr_(const CallNode* call_node) { unchanged &= new_arg.same_as(arg); } - if (unchanged && VisitAndCheckStructInfoFieldUnchanged(call_node->struct_info_)) { + if (unchanged && VisitAndCheckTypeFieldUnchanged(call_node->ty)) { return ffi::GetRef(call_node); } else { - return Call(new_op, call_args, call_node->attrs, sinfo_args, call_node->span); + return Call(new_op, call_args, call_node->attrs, ty_args, call_node->span); } } @@ -451,8 +439,7 @@ Expr ExprMutatorBase::VisitExpr_(const IfNode* op) { Expr true_b = this->VisitExpr(op->true_branch); Expr false_b = this->VisitExpr(op->false_branch); if (op->cond.same_as(guard) && op->true_branch.same_as(true_b) && - op->false_branch.same_as(false_b) && - VisitAndCheckStructInfoFieldUnchanged(op->struct_info_)) { + op->false_branch.same_as(false_b) && VisitAndCheckTypeFieldUnchanged(op->ty)) { return ffi::GetRef(op); } else { return If(guard, true_b, false_b, op->span); @@ -464,8 +451,8 @@ Expr ExprMutatorBase::VisitExpr_(const OpNode* op) { return ffi::GetRef(op Expr ExprMutatorBase::VisitExpr_(const TupleGetItemNode* op) { auto t = this->VisitExpr(op->tuple); if (op->tuple.same_as(t)) { - // struct info can be deterministically derived by tuple and index - // if t does not change, then struct info won't change. + // type can be deterministically derived by tuple and index + // if t does not change, then type won't change. return ffi::GetRef(op); } else { return TupleGetItem(t, op->index, op->span); @@ -475,8 +462,8 @@ Expr ExprMutatorBase::VisitExpr_(const TupleGetItemNode* op) { Expr ExprMutatorBase::VisitExpr_(const PrimValueNode* op) { auto value = this->VisitPrimExpr(op->value); if (op->value.same_as(value)) { - // struct info can be deterministically derived by value - // if value does not change, then struct info won't change. + // type can be deterministically derived by value + // if value does not change, then type won't change. return ffi::GetRef(op); } return PrimValue(value, op->span); @@ -490,7 +477,7 @@ Expr ExprMutatorBase::VisitExpr_(const ShapeExprNode* op) { auto values = op->values.Map([this](const PrimExpr& e) { return this->VisitPrimExpr(e); }); if (values.same_as(op->values)) { - // If values does not change, struct info won't change. + // If values does not change, type won't change. return ffi::GetRef(op); } else { return ShapeExpr(values, op->span); @@ -498,7 +485,7 @@ Expr ExprMutatorBase::VisitExpr_(const ShapeExprNode* op) { } Expr ExprMutatorBase::VisitExpr_(const ExternFuncNode* op) { - // StructInfo of function remains value independent. + // Type of function remains value independent. return ffi::GetRef(op); } @@ -515,8 +502,7 @@ Expr ExprMutatorBase::VisitExpr_(const SeqExprNode* op) { Expr body = this->VisitExpr(op->body); - if (all_blocks_unchanged && body.same_as(op->body) && - VisitAndCheckStructInfoFieldUnchanged(op->struct_info_)) { + if (all_blocks_unchanged && body.same_as(op->body) && VisitAndCheckTypeFieldUnchanged(op->ty)) { return ffi::GetRef(op); } return SeqExpr(blocks, body); @@ -531,7 +517,7 @@ BindingBlock ExprMutatorBase::VisitBindingBlock(const BindingBlock& block) { bindings.push_back(VarBinding(var_binding->var, new_value)); } else if (auto match_cast = binding.as()) { Expr new_value = this->VisitExpr(match_cast->value); - bindings.push_back(MatchCast(match_cast->var, new_value, match_cast->struct_info)); + bindings.push_back(MatchCast(match_cast->var, new_value, match_cast->ty)); } else { TVM_FFI_THROW(TypeError) << "Invalid type: " << binding->GetTypeKey(); } @@ -589,23 +575,23 @@ Expr ExprMutator::VisitExpr_(const FunctionNode* op) { if (all_params_unchanged && body.same_as(op->body)) { // No changes to the function, return the original object return ffi::GetRef(op); - } else if (IsBaseOf(GetStructInfo(body), op->ret_struct_info)) { + } else if (IsBaseOf(GetType(body), op->ret_ty)) { // If the function was mutated into a form that can no longer // propagate shape information all the way to the return value, we - // may keep the return struct info. This is only allowed when the + // may keep the return type. This is only allowed when the // body produces a return value that is the same as, or more - // specific than, the pre-mutation struct info. For example, if - // the previous return value was `TensorStructInfo(shape=[16,16])` - // but the body only produced `TensorStructInfo(ndim=2)`, we can + // specific than, the pre-mutation type. For example, if + // the previous return value was `TensorType(shape=[16,16])` + // but the body only produced `TensorType(ndim=2)`, we can // keep the more specific information. - return Function(params, body, op->ret_struct_info, op->is_pure, op->attrs); + return Function(params, body, op->ret_ty, op->is_pure, op->attrs); } else { // If the function was mutated such that the body produces an // output that is incompatible with the original return struct - // info, the original return struct info should not be used. For + // info, the original return type should not be used. For // example, if the previous return value was - // `TensorStructInfo(shape=[16,16])`, but the new return value is - // `TensorStructInfo(shape=[8,8])`. + // `TensorType(shape=[16,16])`, but the new return value is + // `TensorType(shape=[8,8])`. return Function(params, body, std::nullopt, op->is_pure, op->attrs); } } @@ -615,8 +601,7 @@ Expr ExprMutator::VisitExpr_(const IfNode* op) { Expr true_b = this->VisitWithInnerScope(op->true_branch); Expr false_b = this->VisitWithInnerScope(op->false_branch); if (op->cond.same_as(guard) && op->true_branch.same_as(true_b) && - op->false_branch.same_as(false_b) && - VisitAndCheckStructInfoFieldUnchanged(op->struct_info_)) { + op->false_branch.same_as(false_b) && VisitAndCheckTypeFieldUnchanged(op->ty)) { return ffi::GetRef(op); } else { return If(guard, true_b, false_b, op->span); @@ -642,8 +627,7 @@ Expr ExprMutator::VisitExpr_(const SeqExprNode* op) { all_blocks_unchanged = false; } - if (all_blocks_unchanged && body.same_as(op->body) && - VisitAndCheckStructInfoFieldUnchanged(op->struct_info_)) { + if (all_blocks_unchanged && body.same_as(op->body) && VisitAndCheckTypeFieldUnchanged(op->ty)) { return ffi::GetRef(op); } else { return SeqExpr(blocks, body); @@ -677,14 +661,14 @@ void ExprMutator::ReEmitBinding(const VarBindingNode* binding, Expr new_value) { return; } - auto new_sinfo = new_value->struct_info_.as(); + auto new_ty = new_value->ty.as(); - TVM_FFI_CHECK(new_sinfo, InternalError) + TVM_FFI_CHECK(new_ty, InternalError) << "In binding of variable " << binding->var << ", the value " << new_value - << " does not have StructInfo. " + << " does not have Type. " << "This typically occurs when ReEmitBinding is called without first calling Normalize."; - Var temp = WithStructInfo(new_var, new_sinfo.value()); + Var temp = WithType(new_var, new_ty.value()); if (!temp.same_as(new_var)) { new_var = temp; } @@ -697,23 +681,23 @@ void ExprMutator::ReEmitBinding(const VarBindingNode* binding, Expr new_value) { void ExprMutator::VisitBinding_(const MatchCastNode* binding) { Expr new_value = this->VisitExpr(binding->value); - StructInfo new_struct_info = this->VisitExprDepStructInfoField(binding->struct_info); + Type new_ty = this->VisitExprDepTypeField(binding->ty); Var new_var = this->VisitVarDef(binding->var); MatchCast new_binding = [&]() -> MatchCast { if (new_var.same_as(binding->var) && new_value.same_as(binding->value) && - new_struct_info.same_as(binding->struct_info)) { + new_ty.same_as(binding->ty)) { // re-emit old binding if nothing changes return ffi::GetRef(binding); } else { new_value = builder_->NormalizeArgument(new_value); - new_var = WithStructInfo(new_var, new_struct_info); + new_var = WithType(new_var, new_ty); var_remap_[binding->var->vid] = new_var; var_remap_[new_var->vid] = new_var; - return MatchCast(new_var, new_value, new_struct_info, binding->span); + return MatchCast(new_var, new_value, new_ty, binding->span); } }(); @@ -743,18 +727,25 @@ Var ExprMutator::VisitVarDef_(const DataflowVarNode* var) { // provide default behavior in subclasses, we may produce a Var // where we should produce a DataflowVar. if (!output->IsInstance()) { - output = DataflowVar(output->vid, GetStructInfo(output), output->span); + output = DataflowVar(output->vid, GetType(output), output->span); } return output; } Var ExprMutator::VisitVarDef_(const VarNode* var) { - if (auto* sinfo = var->struct_info_.as()) { - StructInfo struct_info = this->VisitExprDepStructInfoField(ffi::GetRef(sinfo)); - if (struct_info.same_as(var->struct_info_)) { + if (auto* ty_node = var->ty.as()) { + Type ty = this->VisitExprDepTypeField(ffi::GetRef(ty_node)); + if (ty.same_as(var->ty)) { + return ffi::GetRef(var); + } else { + return Var(var->vid, ty, var->span); + } + } else if (auto* ty_node = var->ty.as()) { + Type ty = this->VisitExprDepTypeField(ffi::GetRef(ty_node)); + if (ty.same_as(var->ty)) { return ffi::GetRef(var); } else { - return Var(var->vid, struct_info, var->span); + return Var(var->vid, ty, var->span); } } else { return ffi::GetRef(var); @@ -802,7 +793,7 @@ Expr ExprMutator::VisitWithNewScope(const Expr& expr, ffi::Optional= 0); } @@ -819,11 +810,11 @@ Expr ExprMutator::VisitWithNewScope(const Expr& expr, ffi::OptionalEndScope(); - // Normalization (and the resulting StructInfo inference) of the + // Normalization (and the resulting Type inference) of the // expr occurs outside of the body's parameters, but inside the // function signature's scope. This keeps variables that are // inferable based on the function signature, to allow callers to - // propagate StructInfo across the function. + // propagate Type across the function. ret = builder_->Normalize(ret); builder_->EndScope(); return ret; @@ -843,22 +834,21 @@ ffi::Optional ExprMutator::LookupBinding(const Var& var) { return builder_->LookupBinding(var); } -Var ExprMutator::WithStructInfo(Var var, StructInfo struct_info) { - TVM_FFI_ICHECK(struct_info.defined()); +Var ExprMutator::WithType(Var var, Type ty) { + TVM_FFI_ICHECK(ty.defined()); - // TODO(relax-team) add StructInfoEqual check - if (var->struct_info_.defined()) { + // TODO(relax-team) add TypeEqual check + if (var->ty.defined()) { // use same-as as a quick path - if (var->struct_info_.same_as(struct_info) || - ffi::StructuralEqual()(var->struct_info_, struct_info)) { + if (var->ty.same_as(ty) || ffi::StructuralEqual()(var->ty, ty)) { return var; } else { - Var new_var = var.as() ? DataflowVar(var->vid, struct_info, var->span) - : Var(var->vid, struct_info, var->span); + Var new_var = var.as() ? DataflowVar(var->vid, ty, var->span) + : Var(var->vid, ty, var->span); return new_var; } } else { - UpdateStructInfo(var, struct_info); + UpdateType(var, ty); return var; } } diff --git a/src/relax/ir/py_expr_functor.cc b/src/relax/ir/py_expr_functor.cc index 4a36e877c1c0..09d0fd983db4 100644 --- a/src/relax/ir/py_expr_functor.cc +++ b/src/relax/ir/py_expr_functor.cc @@ -399,7 +399,7 @@ class PyExprMutatorNode : public ffi::Object, public ExprMutator { using ExprMutator::LookupBinding; using ExprMutator::var_remap_; using ExprMutator::VisitWithNewScope; - using ExprMutator::WithStructInfo; + using ExprMutator::WithType; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -654,10 +654,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { [](PyExprMutator mutator, const Expr& expr) { return mutator->VisitWithNewScope(expr); }) .def("relax.PyExprMutatorLookupBinding", [](PyExprMutator mutator, const Var& var) { return mutator->LookupBinding(var); }) - .def("relax.PyExprMutatorWithStructInfo", - [](PyExprMutator mutator, Var var, StructInfo sinfo) { - return mutator->WithStructInfo(var, sinfo); - }) + .def("relax.PyExprMutatorWithType", + [](PyExprMutator mutator, Var var, Type ty) { return mutator->WithType(var, ty); }) .def("relax.PyExprMutatorSetVarRemap", [](PyExprMutator mutator, Id id, Var var) { return mutator->var_remap_[id] = var; }) .def("relax.PyExprMutatorGetVarRemap", diff --git a/src/relax/ir/struct_info.cc b/src/relax/ir/struct_info.cc deleted file mode 100644 index 0f9c0366cefe..000000000000 --- a/src/relax/ir/struct_info.cc +++ /dev/null @@ -1,244 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/relax/ir/struct_info.cc - * \brief Relax struct info. - */ -#include -#include -#include -#include -#include - -namespace tvm { -namespace relax { - -TVM_FFI_STATIC_INIT_BLOCK() { - StructInfoNode::RegisterReflection(); - ObjectStructInfoNode::RegisterReflection(); - PrimStructInfoNode::RegisterReflection(); - ShapeStructInfoNode::RegisterReflection(); - TensorStructInfoNode::RegisterReflection(); - TupleStructInfoNode::RegisterReflection(); - FuncStructInfoNode::RegisterReflection(); -} - -ObjectStructInfo::ObjectStructInfo(Span span) { - ffi::ObjectPtr n = ffi::make_object(); - n->span = span; - data_ = std::move(n); -} - -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("relax.ObjectStructInfo", [](Span span) { return ObjectStructInfo(span); }); -} - -// Prim -PrimStructInfo::PrimStructInfo(PrimExpr value, Span span) { - ffi::ObjectPtr n = ffi::make_object(); - n->dtype = value->dtype; - n->value = std::move(value); - n->span = span; - data_ = std::move(n); -} - -PrimStructInfo::PrimStructInfo(DataType dtype, Span span) { - ffi::ObjectPtr n = ffi::make_object(); - n->dtype = dtype; - n->value = std::nullopt; - n->span = span; - data_ = std::move(n); -} - -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - refl::GlobalDef() - .def("relax.PrimStructInfoFromDtype", - [](DataType dtype, Span span) { return PrimStructInfo(dtype, span); }) - .def("relax.PrimStructInfoFromValue", - [](PrimExpr value, Span span) { return PrimStructInfo(value, span); }); -} - -// Shape -ShapeStructInfo::ShapeStructInfo(ffi::Array values, Span span) { - ffi::ObjectPtr n = ffi::make_object(); - n->ndim = static_cast(values.size()); - n->values = values.Map([](PrimExpr value) { - if (value->IsInstance()) { - return tvm::cast(DataType::Int(64), value); - } - TVM_FFI_ICHECK(value.dtype() == DataType::Int(64)) - << "the value in ShapeStructInfo can only have dtype of int64"; - return value; - }); - n->span = span; - data_ = std::move(n); -} - -ShapeStructInfo::ShapeStructInfo(int ndim, Span span) { - ffi::ObjectPtr n = ffi::make_object(); - TVM_FFI_ICHECK_GE(ndim, -1) << "ndim of ShapeStructInfo must be >= -1, but got " << ndim; - n->ndim = ndim; - n->span = span; - data_ = std::move(n); -} - -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def( - "relax.ShapeStructInfo", [](ffi::Optional> values, int ndim, Span span) { - if (values.defined()) { - TVM_FFI_CHECK_EQ(ndim, kUnknownNDim, ValueError) << "Cannot both specify values and ndim"; - return ShapeStructInfo(values.value(), span); - } else { - return ShapeStructInfo(ndim, span); - } - }); -} - -// Tensor -TensorStructInfo::TensorStructInfo(Expr shape, DataType dtype, ffi::Optional vdevice, - Span span) { - ffi::ObjectPtr n = ffi::make_object(); - // assign ndim before move - ffi::Optional sinfo = MatchStructInfo(shape); - TVM_FFI_ICHECK(sinfo) << "We expect shape to contain pre-set shape struct info"; - TVM_FFI_ICHECK(shape.defined()) << "Must provide a shape in this constructor"; - TVM_FFI_ICHECK(shape->IsInstance() || shape->IsInstance()) - << "We require shape to be normalized when constructing TensorStructInfo"; - n->ndim = sinfo.value()->ndim; - // assign rest of the fields. - n->shape = std::move(shape); - n->dtype = dtype; - n->vdevice = vdevice; - n->span = span; - data_ = std::move(n); -} - -TensorStructInfo::TensorStructInfo(DataType dtype, int ndim, ffi::Optional vdevice, - Span span) { - ffi::ObjectPtr n = ffi::make_object(); - TVM_FFI_ICHECK_GE(ndim, -1) << "ndim of TensorStructInfo must be >= -1, but got " << ndim; - n->ndim = ndim; - n->dtype = dtype; - n->vdevice = vdevice; - n->span = span; - data_ = std::move(n); -} - -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def( - "relax.TensorStructInfo", [](ffi::Optional shape, ffi::Optional dtype, - int ndim, VDevice vdevice, Span span) { - if (shape.defined()) { - TVM_FFI_CHECK_EQ(ndim, kUnknownNDim, ValueError) << "Cannot both specify shape and ndim"; - return TensorStructInfo(shape.value(), dtype.value_or(DataType::Void()), vdevice, span); - } else { - return TensorStructInfo(dtype.value_or(DataType::Void()), ndim, vdevice, span); - } - }); -} - -// Tuple -TupleStructInfo::TupleStructInfo(ffi::Array fields, Span span) { - ffi::ObjectPtr n = ffi::make_object(); - n->fields = std::move(fields); - n->span = span; - data_ = std::move(n); -} - -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("relax.TupleStructInfo", [](ffi::Array fields, Span span) { - return TupleStructInfo(fields, span); - }); -} - -// Func -FuncStructInfo::FuncStructInfo(ffi::Array params, StructInfo ret, bool purity, - Span span) { - ffi::ObjectPtr n = ffi::make_object(); - n->params = std::move(params); - n->ret = std::move(ret); - n->purity = std::move(purity); - n->span = span; - data_ = std::move(n); -} - -FuncStructInfo FuncStructInfo::OpaqueFunc(StructInfoDeriveFunc derive_func, bool purity, - Span span) { - ffi::ObjectPtr n = ffi::make_object(); - n->derive_func = std::move(derive_func); - n->ret = ObjectStructInfo(); - n->purity = std::move(purity); - n->span = span; - return FuncStructInfo(n); -} - -FuncStructInfo FuncStructInfo::OpaqueFunc(StructInfo ret, bool purity, Span span) { - ffi::ObjectPtr n = ffi::make_object(); - n->ret = std::move(ret); - n->purity = std::move(purity); - n->span = span; - return FuncStructInfo(n); -} - -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - refl::GlobalDef() - .def("relax.FuncStructInfo", - [](ffi::Array params, StructInfo ret, bool purity, Span span) { - return FuncStructInfo(params, ret, purity, span); - }) - .def("relax.FuncStructInfoOpaqueFunc", [](ffi::Optional ret, - ffi::Optional derive_func, - bool purity, Span span) { - if (derive_func.defined()) { - TVM_FFI_CHECK(!ret.defined(), ValueError) << "Cannot specify both ret and derive_func"; - return FuncStructInfo::OpaqueFunc(derive_func.value(), purity, span); - } else { - return FuncStructInfo::OpaqueFunc(ret.value_or(ObjectStructInfo()), purity, span); - } - }); -} - -// Helper functions -void UpdateStructInfo(Expr expr, StructInfo struct_info) { - TVM_FFI_ICHECK(!expr->struct_info_.defined()) - << "To ensure idempotency, " - << "the expression passed to UpdateStructInfo " - << "must not have any prior StructInfo. " - << "However, expression " << expr << " has struct info " << expr->struct_info_ - << ", which cannot be overwritten with " << struct_info; - expr->struct_info_ = struct_info; -} - -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - refl::GlobalDef() - .def("relax.UpdateStructInfo", - [](Expr expr, StructInfo struct_info) { UpdateStructInfo(expr, struct_info); }) - .def("ir.ExprStructInfo", [](Expr expr) { return GetStructInfo(expr); }); -} - -} // namespace relax -} // namespace tvm diff --git a/src/relax/ir/struct_info_functor.cc b/src/relax/ir/struct_info_functor.cc deleted file mode 100644 index 7dc6843f8402..000000000000 --- a/src/relax/ir/struct_info_functor.cc +++ /dev/null @@ -1,156 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file struct_info_functor.cc - * \brief Implementations of struct info functors. - */ -#include -#include - -namespace tvm { -namespace relax { - -void StructInfoVisitor::VisitStructInfo_(const ObjectStructInfoNode* op) {} - -void StructInfoVisitor::VisitStructInfo_(const PrimStructInfoNode* op) { - if (op->value.defined()) { - this->VisitStructInfoExprField(op->value.value()); - } -} - -void StructInfoVisitor::VisitStructInfo_(const ShapeStructInfoNode* op) { - if (op->values.defined()) { - for (PrimExpr value : op->values.value()) { - this->VisitStructInfoExprField(value); - } - } -} - -void StructInfoVisitor::VisitStructInfo_(const TensorStructInfoNode* op) { - if (op->shape.defined()) { - this->VisitStructInfoExprField(op->shape.value()); - } -} - -void StructInfoVisitor::VisitStructInfo_(const distributed::DTensorStructInfoNode* op) { - this->VisitStructInfo(op->tensor_sinfo); -} - -void StructInfoVisitor::VisitStructInfo_(const TupleStructInfoNode* op) { - for (StructInfo field : op->fields) { - this->VisitStructInfo(field); - } -} - -void StructInfoVisitor::VisitStructInfo_(const FuncStructInfoNode* op) { - if (op->params.defined()) { - for (StructInfo param : op->params.value()) { - this->VisitStructInfo(param); - } - } - this->VisitStructInfo(op->ret); -} - -StructInfo StructInfoMutator::VisitStructInfo_(const ObjectStructInfoNode* op) { - return ffi::GetRef(op); -} - -StructInfo StructInfoMutator::VisitStructInfo_(const PrimStructInfoNode* op) { - if (!op->value.defined()) { - return ffi::GetRef(op); - } - - auto new_expr = VisitStructInfoExprField(op->value.value()); - if (new_expr.same_as(op->value)) { - return ffi::GetRef(op); - } else { - return PrimStructInfo(new_expr); - } -} - -StructInfo StructInfoMutator::VisitStructInfo_(const ShapeStructInfoNode* op) { - ffi::Optional> values; - - if (op->values.defined()) { - // if no changes are made the original array will be returned. - values = op->values.value().Map( - [this](const PrimExpr& expr) { return this->VisitStructInfoExprField(expr); }); - } - - if (values.same_as(op->values)) { - return ffi::GetRef(op); - } else { - return ShapeStructInfo(values.value(), op->span); - } -} - -StructInfo StructInfoMutator::VisitStructInfo_(const TensorStructInfoNode* op) { - ffi::Optional shape; - - if (op->shape.defined()) { - shape = this->VisitStructInfoExprField(op->shape.value()); - } - - VDevice vdev = op->vdevice.value_or(VDevice()); - - if (shape.same_as(op->shape)) { - return ffi::GetRef(op); - } else { - return TensorStructInfo(shape.value(), op->dtype, vdev, op->span); - } -} - -StructInfo StructInfoMutator::VisitStructInfo_(const distributed::DTensorStructInfoNode* op) { - TensorStructInfo tensor_sinfo = - Downcast(this->VisitStructInfo(op->tensor_sinfo)); - return distributed::DTensorStructInfo(tensor_sinfo, op->device_mesh, op->placement); -} - -StructInfo StructInfoMutator::VisitStructInfo_(const TupleStructInfoNode* op) { - ffi::Array fields = - op->fields.Map([this](const StructInfo& sinfo) { return this->VisitStructInfo(sinfo); }); - - if (fields.same_as(op->fields)) { - return ffi::GetRef(op); - } else { - return TupleStructInfo(fields, op->span); - } -} - -StructInfo StructInfoMutator::VisitStructInfo_(const FuncStructInfoNode* op) { - ffi::Optional> params; - - if (op->params.defined()) { - params = op->params.value().Map( - [this](const StructInfo& sinfo) { return this->VisitStructInfo(sinfo); }); - } - - StructInfo ret = this->VisitStructInfo(op->ret); - - if (params.same_as(op->params) && ret.same_as(op->ret)) { - return ffi::GetRef(op); - } else { - TVM_FFI_ICHECK(ret.defined()) << "FuncStructInfo that contains params must contain ret"; - return FuncStructInfo(params.value(), ret, op->purity, op->span); - } -} - -} // namespace relax -} // namespace tvm diff --git a/src/relax/ir/transform.cc b/src/relax/ir/transform.cc index 58e29560c65b..00c4ffc32877 100644 --- a/src/relax/ir/transform.cc +++ b/src/relax/ir/transform.cc @@ -28,8 +28,8 @@ #include #include #include -#include #include +#include #include namespace tvm { @@ -224,7 +224,7 @@ class DataflowBlockMutator : public ExprMutator { for (const Binding& binding : n->bindings) { Var var = binding->var; if (const auto* match_cast = binding.as()) { - auto collected_vars = SymbolicVarCollector::Collect(match_cast->struct_info); + auto collected_vars = SymbolicVarCollector::Collect(match_cast->ty); for (const tirx::VarNode* var : collected_vars) { symbolic_vars.Set(var->name_hint, ffi::GetRef(var)); } @@ -242,7 +242,7 @@ class DataflowBlockMutator : public ExprMutator { for (const Binding& binding : updated_block->bindings) { Var var = binding->var; if (const auto* match_cast = binding.as()) { - auto collected_vars = SymbolicVarCollector::Collect(match_cast->struct_info); + auto collected_vars = SymbolicVarCollector::Collect(match_cast->ty); for (const tirx::VarNode* var : collected_vars) { if (symbolic_vars.count(var->name_hint) > 0) { tirx::Var old_var = symbolic_vars[var->name_hint]; @@ -265,16 +265,16 @@ class DataflowBlockMutator : public ExprMutator { } private: - class SymbolicVarCollector : public StructInfoVisitor { + class SymbolicVarCollector : public TypeVisitor { public: - static std::unordered_set Collect(const StructInfo& info) { + static std::unordered_set Collect(const Type& info) { SymbolicVarCollector collector; - collector.VisitStructInfo(info); + collector.VisitType(info); return std::move(collector.symbolic_vars_); } private: - void VisitStructInfoExprField(const PrimExpr& expr) final { + void VisitTypeExprField(const PrimExpr& expr) final { if (const tirx::VarNode* sym_var = expr.as()) { symbolic_vars_.insert(sym_var); } diff --git a/src/relax/ir/type.cc b/src/relax/ir/type.cc index 37cd541d6887..297e71e30cdc 100644 --- a/src/relax/ir/type.cc +++ b/src/relax/ir/type.cc @@ -28,59 +28,7 @@ namespace tvm { namespace relax { -TVM_FFI_STATIC_INIT_BLOCK() { - ShapeTypeNode::RegisterReflection(); - TensorTypeNode::RegisterReflection(); - ObjectTypeNode::RegisterReflection(); - PackedFuncTypeNode::RegisterReflection(); -} - -ShapeType::ShapeType(int ndim, Span span) { - ffi::ObjectPtr n = ffi::make_object(); - n->ndim = ndim; - n->span = span; - data_ = std::move(n); -} - -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("relax.ShapeType", - [](int ndim, Span span) { return ShapeType(ndim, span); }); -} - -ObjectType::ObjectType(Span span) { - ffi::ObjectPtr n = ffi::make_object(); - n->span = span; - data_ = std::move(n); -} - -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("relax.ObjectType", [](Span span) { return ObjectType(span); }); -} - -TensorType::TensorType(int ndim, DataType dtype, Span span) { - ffi::ObjectPtr n = ffi::make_object(); - n->ndim = std::move(ndim); - n->dtype = std::move(dtype); - n->span = span; - data_ = std::move(n); -} - -TensorType TensorType::CreateUnknownNDim(DataType dtype, Span span) { - ffi::ObjectPtr n = ffi::make_object(); - n->ndim = -1; - n->dtype = std::move(dtype); - n->span = std::move(span); - return TensorType(std::move(n)); -} - -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("relax.TensorType", [](int ndim, DataType dtype, Span span) { - return TensorType(ndim, dtype, span); - }); -} +TVM_FFI_STATIC_INIT_BLOCK() { PackedFuncTypeNode::RegisterReflection(); } PackedFuncType::PackedFuncType(Span span) { ffi::ObjectPtr n = ffi::make_object(); diff --git a/src/relax/ir/type_functor.cc b/src/relax/ir/type_functor.cc new file mode 100644 index 000000000000..4f92d4feb6e2 --- /dev/null +++ b/src/relax/ir/type_functor.cc @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file type_functor.cc + * \brief Implementations of Relax type functors. + */ +#include +#include + +namespace tvm { +namespace relax { + +void TypeVisitor::VisitType_(const ObjectTypeNode* op) {} + +void TypeVisitor::VisitType_(const PrimTypeNode* op) { + if (op->value.defined()) { + this->VisitTypeExprField(op->value.value()); + } +} + +void TypeVisitor::VisitType_(const ShapeTypeNode* op) { + if (op->values.defined()) { + for (PrimExpr value : op->values.value()) { + this->VisitTypeExprField(value); + } + } +} + +void TypeVisitor::VisitType_(const TensorTypeNode* op) { + if (op->shape.defined()) { + this->VisitTypeExprField(op->shape.value()); + } +} + +void TypeVisitor::VisitType_(const distributed::DTensorTypeNode* op) { + this->VisitType(op->tensor_ty); +} + +void TypeVisitor::VisitType_(const TupleTypeNode* op) { + for (Type field : op->fields) { + this->VisitType(field); + } +} + +void TypeVisitor::VisitType_(const FuncTypeNode* op) { + if (op->params.defined()) { + for (Type param : op->params.value()) { + this->VisitType(param); + } + } + this->VisitType(op->ret); +} + +Type TypeMutator::VisitType_(const ObjectTypeNode* op) { return ffi::GetRef(op); } + +Type TypeMutator::VisitType_(const PrimTypeNode* op) { + if (!op->value.defined()) { + return ffi::GetRef(op); + } + + auto new_expr = VisitTypeExprField(op->value.value()); + if (new_expr.same_as(op->value)) { + return ffi::GetRef(op); + } else { + return PrimType(new_expr); + } +} + +Type TypeMutator::VisitType_(const ShapeTypeNode* op) { + ffi::Optional> values; + + if (op->values.defined()) { + // if no changes are made the original array will be returned. + values = op->values.value().Map( + [this](const PrimExpr& expr) { return this->VisitTypeExprField(expr); }); + } + + if (values.same_as(op->values)) { + return ffi::GetRef(op); + } else { + return ShapeType(values.value(), op->span); + } +} + +Type TypeMutator::VisitType_(const TensorTypeNode* op) { + ffi::Optional shape; + + if (op->shape.defined()) { + shape = this->VisitTypeExprField(op->shape.value()); + } + + VDevice vdev = op->vdevice.value_or(VDevice()); + + if (shape.same_as(op->shape)) { + return ffi::GetRef(op); + } else { + return TensorType(shape.value(), op->dtype, vdev, op->span); + } +} + +Type TypeMutator::VisitType_(const distributed::DTensorTypeNode* op) { + TensorType tensor_ty = Downcast(this->VisitType(op->tensor_ty)); + return distributed::DTensorType(tensor_ty, op->device_mesh, op->placement); +} + +Type TypeMutator::VisitType_(const TupleTypeNode* op) { + ffi::Array fields = op->fields.Map([this](const Type& ty) { return this->VisitType(ty); }); + + if (fields.same_as(op->fields)) { + return ffi::GetRef(op); + } else { + return TupleType(fields, op->span); + } +} + +Type TypeMutator::VisitType_(const FuncTypeNode* op) { + ffi::Optional> params; + + if (op->params.defined()) { + params = op->params.value().Map([this](const Type& ty) { return this->VisitType(ty); }); + } + + Type ret = this->VisitType(op->ret); + + if (params.same_as(op->params) && ret.same_as(op->ret)) { + return ffi::GetRef(op); + } else { + TVM_FFI_ICHECK(ret.defined()) << "FuncType that contains params must contain ret"; + return FuncType(params.value(), ret, op->purity, op->span); + } +} + +} // namespace relax +} // namespace tvm diff --git a/src/relax/op/ccl/ccl.cc b/src/relax/op/ccl/ccl.cc index b8faeb9164bd..dd67f65dea09 100644 --- a/src/relax/op/ccl/ccl.cc +++ b/src/relax/op/ccl/ccl.cc @@ -21,6 +21,7 @@ #include #include +#include #include @@ -49,16 +50,16 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.ccl.allreduce", allreduce); } -StructInfo InferStructInfoAllReduce(const Call& call, const BlockBuilder& ctx) { - TensorStructInfo input_sinfo = GetUnaryInputTensorStructInfo(call, ctx); - return input_sinfo; +Type InferTypeAllReduce(const Call& call, const BlockBuilder& ctx) { + TensorType input_ty = GetUnaryInputTensorType(call, ctx); + return input_ty; } TVM_REGISTER_OP("relax.ccl.allreduce") .set_attrs_type() .set_num_inputs(1) .add_argument("x", "Tensor", "Input to which allreduce will be applied.") - .set_attr("FInferStructInfo", InferStructInfoAllReduce) + .set_attr("FInferType", InferTypeAllReduce) .set_attr("FRelaxInferLayout", InferLayoutUnaryEwise) .set_attr("FPurity", true); @@ -78,26 +79,26 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.ccl.allgather", allgather); } -StructInfo InferStructInfoAllGather(const Call& call, const BlockBuilder& ctx) { - TensorStructInfo input_sinfo = GetUnaryInputTensorStructInfo(call, ctx); +Type InferTypeAllGather(const Call& call, const BlockBuilder& ctx) { + TensorType input_ty = GetUnaryInputTensorType(call, ctx); const auto* attrs = call->attrs.as(); int num_workers = attrs->num_workers; - DataType output_dtype = input_sinfo->dtype; - auto input_shape = input_sinfo->GetShape(); + DataType output_dtype = input_ty->dtype; + auto input_shape = input_ty->GetShape(); if (!input_shape.defined()) { - return input_sinfo; + return input_ty; } ffi::Array output_shape = input_shape.value(); output_shape.Set(0, floor(output_shape[0] * num_workers)); - return TensorStructInfo(ShapeExpr(output_shape), output_dtype, input_sinfo->vdevice); + return TensorType(ShapeExpr(output_shape), output_dtype, input_ty->vdevice); } TVM_REGISTER_OP("relax.ccl.allgather") .set_num_inputs(1) .add_argument("x", "Tensor", "Input to which allgather will be applied.") - .set_attr("FInferStructInfo", InferStructInfoAllGather) + .set_attr("FInferType", InferTypeAllGather) .set_attr("FRelaxInferLayout", InferLayoutUnaryEwise) .set_attr("FPurity", true); @@ -112,15 +113,15 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.ccl.broadcast_from_worker0", broadcast_from_worker0); } -StructInfo InferStructInfoBroadcastFromZero(const Call& call, const BlockBuilder& ctx) { - TensorStructInfo input_sinfo = GetUnaryInputTensorStructInfo(call, ctx); - return input_sinfo; +Type InferTypeBroadcastFromZero(const Call& call, const BlockBuilder& ctx) { + TensorType input_ty = GetUnaryInputTensorType(call, ctx); + return input_ty; } TVM_REGISTER_OP("relax.ccl.broadcast_from_worker0") .set_num_inputs(1) .add_argument("x", "Tensor", "Input to be broadcast.") - .set_attr("FInferStructInfo", InferStructInfoBroadcastFromZero) + .set_attr("FInferType", InferTypeBroadcastFromZero) .set_attr("FRelaxInferLayout", InferLayoutUnaryEwise) .set_attr("FPurity", true); @@ -140,15 +141,15 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.ccl.scatter_from_worker0", scatter_from_worker0); } -StructInfo InferStructInfoScatter(const Call& call, const BlockBuilder& ctx) { - TensorStructInfo input_sinfo = GetUnaryInputTensorStructInfo(call, ctx); - DataType output_dtype = input_sinfo->dtype; +Type InferTypeScatter(const Call& call, const BlockBuilder& ctx) { + TensorType input_ty = GetUnaryInputTensorType(call, ctx); + DataType output_dtype = input_ty->dtype; const auto* attrs = call->attrs.as(); int num_workers = attrs->num_workers; arith::Analyzer analyzer = ctx->GetAnalyzer(); - auto input_shape = input_sinfo->GetShape(); + auto input_shape = input_ty->GetShape(); TVM_FFI_ICHECK(input_shape.defined()) << "input tensor of scatter_from_worker0 should have defined shape."; @@ -161,7 +162,7 @@ StructInfo InferStructInfoScatter(const Call& call, const BlockBuilder& ctx) { ffi::Array output_shape = input_shape.value(); output_shape.Set(attrs->axis, div(output_shape[attrs->axis], num_workers)); - return TensorStructInfo(ShapeExpr(output_shape), output_dtype, input_sinfo->vdevice); + return TensorType(ShapeExpr(output_shape), output_dtype, input_ty->vdevice); } TVM_REGISTER_OP("relax.ccl.scatter_from_worker0") @@ -169,7 +170,7 @@ TVM_REGISTER_OP("relax.ccl.scatter_from_worker0") .add_argument("x", "Tensor", "The buffer to be divided into equal parts and sent to each worker accordingly.") .set_attrs_type() - .set_attr("FInferStructInfo", InferStructInfoScatter) + .set_attr("FInferType", InferTypeScatter) .set_attr("FPurity", true); } // namespace relax diff --git a/src/relax/op/distributed/binary.cc b/src/relax/op/distributed/binary.cc index 1e7fa8172718..0a6d283a1b44 100644 --- a/src/relax/op/distributed/binary.cc +++ b/src/relax/op/distributed/binary.cc @@ -23,15 +23,15 @@ namespace tvm { namespace relax { namespace distributed { -StructInfo InferDistStructInfoBroadcastArith(const Call& call, const BlockBuilder& ctx) { - return InferDistStructInfoBroadcast(call, ctx, InferBinaryArithOpOutDtype); +Type InferDistTypeBroadcastArith(const Call& call, const BlockBuilder& ctx) { + return InferDistTypeBroadcast(call, ctx, InferBinaryArithOpOutDtype); } -StructInfo InferDistStructInfoBroadcastCMP(const Call& call, const BlockBuilder& ctx) { - return InferDistStructInfoBroadcast( +Type InferDistTypeBroadcastCMP(const Call& call, const BlockBuilder& ctx) { + return InferDistTypeBroadcast( call, ctx, - [](const Call& call, const BlockBuilder& ctx, const TensorStructInfo& x1_sinfo, - const TensorStructInfo& x2_sinfo) { return DataType::Bool(); }); + [](const Call& call, const BlockBuilder& ctx, const TensorType& x1_ty, + const TensorType& x2_ty) { return DataType::Bool(); }); } /***************** Arithmetic operators *****************/ diff --git a/src/relax/op/distributed/binary.h b/src/relax/op/distributed/binary.h index b5181d770c0a..0bf7e390d72e 100644 --- a/src/relax/op/distributed/binary.h +++ b/src/relax/op/distributed/binary.h @@ -19,7 +19,7 @@ /*! * \file binary.h - * \brief The functions to infer struct info for distributed binary operator + * \brief The functions to infer type for distributed binary operator */ #ifndef TVM_RELAX_OP_DISTRIBUTED_BINARY_H_ @@ -36,53 +36,51 @@ namespace relax { namespace distributed { template -StructInfo InferDistStructInfoBroadcast(const Call& call, const BlockBuilder& ctx, - FType f_compute_out_dtype) { - ffi::Array input_dtensor_sinfos = - GetInputDTensorStructInfo(call, ctx); - TensorStructInfo x1_sinfo, x2_sinfo; - x1_sinfo = input_dtensor_sinfos[0]->tensor_sinfo; - x2_sinfo = input_dtensor_sinfos[1]->tensor_sinfo; +Type InferDistTypeBroadcast(const Call& call, const BlockBuilder& ctx, FType f_compute_out_dtype) { + ffi::Array input_dtensor_tys = GetInputDTensorType(call, ctx); + TensorType x1_ty, x2_ty; + x1_ty = input_dtensor_tys[0]->tensor_ty; + x2_ty = input_dtensor_tys[1]->tensor_ty; // DateType - DataType output_dtype = f_compute_out_dtype(call, ctx, x1_sinfo, x2_sinfo); + DataType output_dtype = f_compute_out_dtype(call, ctx, x1_ty, x2_ty); // ndims - TVM_FFI_ICHECK(!x1_sinfo->IsUnknownNdim() && !x2_sinfo->IsUnknownNdim()) + TVM_FFI_ICHECK(!x1_ty->IsUnknownNdim() && !x2_ty->IsUnknownNdim()) << "Unknown ndim is not supported for distributed operators."; - int output_ndim = std::max(x1_sinfo->ndim, x2_sinfo->ndim); + int output_ndim = std::max(x1_ty->ndim, x2_ty->ndim); - const auto* x1_shape = x1_sinfo->shape.as(); - const auto* x2_shape = x2_sinfo->shape.as(); - TensorStructInfo output_tensor_sinfo; + const auto* x1_shape = x1_ty->shape.as(); + const auto* x2_shape = x2_ty->shape.as(); + TensorType output_tensor_ty; // Shapes and ndims if (x1_shape && x2_shape) { // If all inputs have shapes, directly infer shapes ffi::Optional> output_shape = InferBinaryBroadcastShape(call, ctx, x1_shape->values, x2_shape->values); if (!output_shape.defined()) { - output_tensor_sinfo = TensorStructInfo(output_dtype, /*ndim=*/output_ndim); + output_tensor_ty = TensorType(output_dtype, /*ndim=*/output_ndim); } else { TVM_FFI_ICHECK_EQ(static_cast(output_shape.value().size()), output_ndim); - output_tensor_sinfo = TensorStructInfo(ShapeExpr(output_shape.value()), output_dtype); + output_tensor_ty = TensorType(ShapeExpr(output_shape.value()), output_dtype); } } else { TVM_FFI_VISIT_THROW(InternalError, call) << "Cannot infer shape for binary broadcast operator."; } - return InferShardingSpec(call, ctx, output_tensor_sinfo, distributed::BuildAxisGraphBinary); + return InferShardingSpec(call, ctx, output_tensor_ty, distributed::BuildAxisGraphBinary); } -StructInfo InferDistStructInfoBroadcastArith(const Call& call, const BlockBuilder& ctx); +Type InferDistTypeBroadcastArith(const Call& call, const BlockBuilder& ctx); -StructInfo InferDistStructInfoBroadcastCMP(const Call& call, const BlockBuilder& ctx); +Type InferDistTypeBroadcastCMP(const Call& call, const BlockBuilder& ctx); #define RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(OpName) \ TVM_REGISTER_OP("relax." #OpName) \ - .set_attr("dist.FInferStructInfo", InferDistStructInfoBroadcastArith) + .set_attr("dist.FInferType", InferDistTypeBroadcastArith) #define RELAX_REGISTER_CMP_DIST_INFER_STRUCT_INFO(OpName) \ TVM_REGISTER_OP("relax." #OpName) \ - .set_attr("dist.FInferStructInfo", InferDistStructInfoBroadcastCMP) + .set_attr("dist.FInferType", InferDistTypeBroadcastCMP) } // namespace distributed } // namespace relax diff --git a/src/relax/op/distributed/ccl.cc b/src/relax/op/distributed/ccl.cc index cb48ff38aa4f..4f899cdec4be 100644 --- a/src/relax/op/distributed/ccl.cc +++ b/src/relax/op/distributed/ccl.cc @@ -18,25 +18,27 @@ */ #include "tvm/relax/attrs/ccl.h" +#include + #include "utils.h" namespace tvm { namespace relax { namespace distributed { -StructInfo InferDistStructInfoAllReduce(const Call& call, const BlockBuilder& ctx) { - ffi::Array input_dtensor_sinfos = GetInputDTensorStructInfo(call, ctx); - TVM_FFI_ICHECK(input_dtensor_sinfos.size() == 1); - DTensorStructInfo input_dtensor_sinfo = input_dtensor_sinfos[0]; - TensorStructInfo tensor_sinfo = input_dtensor_sinfo->tensor_sinfo; - DeviceMesh device_mesh = input_dtensor_sinfo->device_mesh; +Type InferDistTypeAllReduce(const Call& call, const BlockBuilder& ctx) { + ffi::Array input_dtensor_tys = GetInputDTensorType(call, ctx); + TVM_FFI_ICHECK(input_dtensor_tys.size() == 1); + DTensorType input_dtensor_ty = input_dtensor_tys[0]; + TensorType tensor_ty = input_dtensor_ty->tensor_ty; + DeviceMesh device_mesh = input_dtensor_ty->device_mesh; // FIXME: this is a hack where there's only 1d mesh - return DTensorStructInfo(tensor_sinfo, device_mesh, - Placement::FromText(std::string(device_mesh->shape.size(), 'R'))); + return DTensorType(tensor_ty, device_mesh, + Placement::FromText(std::string(device_mesh->shape.size(), 'R'))); } TVM_REGISTER_OP("relax.ccl.allreduce") - .set_attr("dist.FInferStructInfo", InferDistStructInfoAllReduce); + .set_attr("dist.FInferType", InferDistTypeAllReduce); } // namespace distributed } // namespace relax diff --git a/src/relax/op/distributed/distributed.cc b/src/relax/op/distributed/distributed.cc index d76fa6213d64..b009630070cd 100644 --- a/src/relax/op/distributed/distributed.cc +++ b/src/relax/op/distributed/distributed.cc @@ -57,15 +57,15 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.dist.annotate_sharding", annotate_sharding); } -StructInfo InferStructInfoAnnotateSharding(const Call& call, const BlockBuilder& ctx) { - return GetStructInfo(call->args[0]); +Type InferTypeAnnotateSharding(const Call& call, const BlockBuilder& ctx) { + return GetType(call->args[0]); } TVM_REGISTER_OP("relax.dist.annotate_sharding") .set_num_inputs(1) .add_argument("input", "Tensor", "The input tensor.") - .set_attr("FInferStructInfo", InferStructInfoAnnotateSharding) - .set_attr("dist.FInferStructInfo", InferStructInfoAnnotateSharding) + .set_attr("FInferType", InferTypeAnnotateSharding) + .set_attr("dist.FInferType", InferTypeAnnotateSharding) .set_attr("FPurity", true); /* relax.dist.redistribute */ @@ -85,29 +85,28 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.dist.redistribute", redistribute); } -StructInfo InferDistStructInfoRedistribute(const Call& call, const BlockBuilder& ctx) { +Type InferDistTypeRedistribute(const Call& call, const BlockBuilder& ctx) { const auto* attrs = call->attrs.as(); - const auto* sinfo = GetStructInfoAs(call->args[0]); - TVM_FFI_ICHECK(sinfo); - return distributed::DTensorStructInfo(sinfo->tensor_sinfo, attrs->device_mesh, attrs->placement); + const auto* ty = GetTypeAs(call->args[0]); + TVM_FFI_ICHECK(ty); + return distributed::DTensorType(ty->tensor_ty, attrs->device_mesh, attrs->placement); } TVM_REGISTER_OP("relax.dist.redistribute") .set_num_inputs(1) .add_argument("input", "Tensor", "The input tensor.") - .set_attr("dist.FInferStructInfo", InferDistStructInfoRedistribute) + .set_attr("dist.FInferType", InferDistTypeRedistribute) .set_attr("FPurity", true); -StructInfo InferStructInfoCallTIRLocalView(const Call& call, const BlockBuilder& ctx) { - if (call->sinfo_args.size() != 1) { - TVM_FFI_VISIT_THROW(InternalError, call) - << "sinfo_args should have exactly 1 output struct info."; +Type InferTypeCallTIRLocalView(const Call& call, const BlockBuilder& ctx) { + if (call->ty_args.size() != 1) { + TVM_FFI_VISIT_THROW(InternalError, call) << "ty_args should have exactly 1 output type."; } TVM_FFI_ICHECK(call->args[0]->IsInstance()) << "call_tir_local_view expects the first argument to be a GlobalVar referring to a TIR " "PrimFunc. " << "However, gets " << call->args[0]; - return call->sinfo_args[0]; + return call->ty_args[0]; } TVM_REGISTER_OP("relax.dist.call_tir_local_view") @@ -117,34 +116,33 @@ TVM_REGISTER_OP("relax.dist.call_tir_local_view") .add_argument("packed_ints", "Expr", "ShapeExpr representing a tuple of ints to unpack during runtime. Omitted from " "args if unused") - .set_attr("FInferStructInfo", InferStructInfoCallTIRLocalView) + .set_attr("FInferType", InferTypeCallTIRLocalView) .set_attr("FPurity", true); -Expr MakeCallTIRLocalView(Expr func, Tuple args, - ffi::Array out_sinfo_list, +Expr MakeCallTIRLocalView(Expr func, Tuple args, ffi::Array out_ty_list, ffi::Optional packed_ints) { - for (const distributed::DTensorStructInfo& sinfo : out_sinfo_list) { - const auto* shape = sinfo->tensor_sinfo->shape.as(); + for (const distributed::DTensorType& ty : out_ty_list) { + const auto* shape = ty->tensor_ty->shape.as(); TVM_FFI_ICHECK(shape != nullptr) - << "out_sinfo of call_tir_local_view should have defined ShapeExpr as shape. " - "However, one given structure info is " - << sinfo; + << "out_ty of call_tir_local_view should have defined ShapeExpr as shape. " + "However, one given type information is " + << ty; } - StructInfo out_sinfo{nullptr}; - if (out_sinfo_list.size() == 1) { - out_sinfo = out_sinfo_list[0]; + Type out_ty{nullptr}; + if (out_ty_list.size() == 1) { + out_ty = out_ty_list[0]; } else { - out_sinfo = TupleStructInfo({out_sinfo_list.begin(), out_sinfo_list.end()}); + out_ty = TupleType({out_ty_list.begin(), out_ty_list.end()}); } static const Op& op = Op::Get("relax.dist.call_tir_local_view"); Call call; if (!packed_ints) { // don't use additional optional argument - call = Call(op, {func, args}, {}, {out_sinfo}); + call = Call(op, {func, args}, {}, {out_ty}); } else { - call = Call(op, {func, args, packed_ints.value()}, {}, {out_sinfo}); + call = Call(op, {func, args, packed_ints.value()}, {}, {out_ty}); } return call; } @@ -154,15 +152,15 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.dist.call_tir_local_view", MakeCallTIRLocalView); } -StructInfo InferStructInfoRtoS(const Call& call, const BlockBuilder& ctx) { - TensorStructInfo input_sinfo = GetUnaryInputTensorStructInfo(call, ctx); - DataType output_dtype = input_sinfo->dtype; +Type InferTypeRtoS(const Call& call, const BlockBuilder& ctx) { + TensorType input_ty = GetUnaryInputTensorType(call, ctx); + DataType output_dtype = input_ty->dtype; const auto* attrs = call->attrs.as(); int num_workers = attrs->num_workers; arith::Analyzer analyzer = ctx->GetAnalyzer(); - auto input_shape = input_sinfo->GetShape(); + auto input_shape = input_ty->GetShape(); TVM_FFI_ICHECK(input_shape.defined()) << "input tensor of redistribute_replica_to_shard should have defined shape."; @@ -178,19 +176,19 @@ StructInfo InferStructInfoRtoS(const Call& call, const BlockBuilder& ctx) { ffi::Array output_shape = input_shape.value(); output_shape.Set(attrs->axis, div(output_shape[attrs->axis], num_workers)); - return TensorStructInfo(ShapeExpr(output_shape), output_dtype, input_sinfo->vdevice); + return TensorType(ShapeExpr(output_shape), output_dtype, input_ty->vdevice); } -StructInfo InferDistStructInfoRtoS(const Call& call, const BlockBuilder& ctx) { +Type InferDistTypeRtoS(const Call& call, const BlockBuilder& ctx) { using namespace distributed; - ffi::Array input_dtensor_sinfos = GetInputDTensorStructInfo(call, ctx); - TVM_FFI_ICHECK(input_dtensor_sinfos.size() == 1); - DTensorStructInfo input_dtensor_sinfo = input_dtensor_sinfos[0]; - TensorStructInfo tensor_sinfo = input_dtensor_sinfo->tensor_sinfo; + ffi::Array input_dtensor_tys = GetInputDTensorType(call, ctx); + TVM_FFI_ICHECK(input_dtensor_tys.size() == 1); + DTensorType input_dtensor_ty = input_dtensor_tys[0]; + TensorType tensor_ty = input_dtensor_ty->tensor_ty; const auto* attrs = call->attrs.as(); int num_workers = attrs->num_workers; arith::Analyzer analyzer = ctx->GetAnalyzer(); - auto input_shape = tensor_sinfo->GetShape(); + auto input_shape = tensor_ty->GetShape(); TVM_FFI_ICHECK(input_shape.defined()) << "input tensor of redistribute_replica_to_shard should have defined shape."; @@ -204,12 +202,12 @@ StructInfo InferDistStructInfoRtoS(const Call& call, const BlockBuilder& ctx) { << " while num_workers is " << num_workers; } - DeviceMesh device_mesh = input_dtensor_sinfo->device_mesh; + DeviceMesh device_mesh = input_dtensor_ty->device_mesh; // FIXME: this is a hack where there's only 1d mesh TVM_FFI_ICHECK(device_mesh->shape.size() == 1); - TVM_FFI_ICHECK(input_dtensor_sinfo->placement->dim_specs[0]->kind == PlacementSpecKind::kReplica); - return DTensorStructInfo(tensor_sinfo, device_mesh, - Placement::FromText("S[" + std::to_string(attrs->axis) + "]")); + TVM_FFI_ICHECK(input_dtensor_ty->placement->dim_specs[0]->kind == PlacementSpecKind::kReplica); + return DTensorType(tensor_ty, device_mesh, + Placement::FromText("S[" + std::to_string(attrs->axis) + "]")); } Expr redistribute_replica_to_shard(Expr input, int num_workers, int axis) { @@ -231,8 +229,8 @@ TVM_REGISTER_OP("relax.dist.redistribute_replica_to_shard") .set_num_inputs(1) .add_argument("input", "Tensor", "The buffer to be sliced.") .set_attrs_type() - .set_attr("FInferStructInfo", InferStructInfoRtoS) - .set_attr("dist.FInferStructInfo", InferDistStructInfoRtoS) + .set_attr("FInferType", InferTypeRtoS) + .set_attr("dist.FInferType", InferDistTypeRtoS) .set_attr("FPurity", true); } // namespace relax diff --git a/src/relax/op/distributed/linear_algebra.cc b/src/relax/op/distributed/linear_algebra.cc index 5bcc5888e711..b1c9903cbe12 100644 --- a/src/relax/op/distributed/linear_algebra.cc +++ b/src/relax/op/distributed/linear_algebra.cc @@ -26,26 +26,25 @@ namespace tvm { namespace relax { namespace distributed { -StructInfo InferDistStructInfoMatmul(const Call& call, const BlockBuilder& ctx) { - ffi::Array input_dtensor_sinfos = - GetInputDTensorStructInfo(call, ctx); - TensorStructInfo x1_sinfo, x2_sinfo; - x1_sinfo = input_dtensor_sinfos[0]->tensor_sinfo; - x2_sinfo = input_dtensor_sinfos[1]->tensor_sinfo; +Type InferDistTypeMatmul(const Call& call, const BlockBuilder& ctx) { + ffi::Array input_dtensor_tys = GetInputDTensorType(call, ctx); + TensorType x1_ty, x2_ty; + x1_ty = input_dtensor_tys[0]->tensor_ty; + x2_ty = input_dtensor_tys[1]->tensor_ty; const auto* attrs = call->attrs.as(); DataType out_dtype = attrs->out_dtype.is_void() - ? InferBinaryArithOpOutDtype(call, ctx, x1_sinfo, x2_sinfo) + ? InferBinaryArithOpOutDtype(call, ctx, x1_ty, x2_ty) : attrs->out_dtype; - if (x1_sinfo->IsUnknownNdim() || x2_sinfo->IsUnknownNdim()) { + if (x1_ty->IsUnknownNdim() || x2_ty->IsUnknownNdim()) { TVM_FFI_VISIT_THROW(ValueError, call) << "Matmul requires both inputs to have known ndim. However, " - << (x1_sinfo->IsUnknownNdim() ? "x1" : "x2") << " has unknown ndim."; + << (x1_ty->IsUnknownNdim() ? "x1" : "x2") << " has unknown ndim."; } - int x1_ndim = x1_sinfo->ndim; - int x2_ndim = x2_sinfo->ndim; + int x1_ndim = x1_ty->ndim; + int x2_ndim = x2_ty->ndim; if (x1_ndim == 0 || x2_ndim == 0) { TVM_FFI_VISIT_THROW(ValueError, call) << "Matmul requires both inputs to have at least 1 dimension. However, " @@ -64,8 +63,8 @@ StructInfo InferDistStructInfoMatmul(const Call& call, const BlockBuilder& ctx) } int output_ndim = std::max(x1_ndim, x2_ndim) - x1_prepended - x2_appended; - const auto* x1_shape = x1_sinfo->shape.as(); - const auto* x2_shape = x2_sinfo->shape.as(); + const auto* x1_shape = x1_ty->shape.as(); + const auto* x2_shape = x2_ty->shape.as(); if (x1_shape == nullptr || x2_shape == nullptr) { TVM_FFI_VISIT_THROW(ValueError, call) << "input of distributed operator must have shape"; } @@ -78,7 +77,7 @@ StructInfo InferDistStructInfoMatmul(const Call& call, const BlockBuilder& ctx) InferBinaryBroadcastShape(call, ctx, x1_shape_prefix, x2_shape_prefix); TVM_FFI_ICHECK(output_shape_prefix.defined()) << "Failed to infer output shape of Matmul"; arith::Analyzer analyzer = ctx->GetAnalyzer(); - PrimExpr x1_reduction_length = x1_shape->values[x1_sinfo->ndim - 1]; + PrimExpr x1_reduction_length = x1_shape->values[x1_ty->ndim - 1]; PrimExpr x2_reduction_length = x2_shape->values[x2_ndim - 2]; if (analyzer->CanProve(x1_reduction_length != x2_reduction_length)) { TVM_FFI_VISIT_THROW(ValueError, call) @@ -95,11 +94,10 @@ StructInfo InferDistStructInfoMatmul(const Call& call, const BlockBuilder& ctx) output_shape.push_back(x2_shape->values[x2_ndim - 1]); } TVM_FFI_ICHECK_EQ(static_cast(output_shape.size()), output_ndim); - TensorStructInfo output_tensor_sinfo(ShapeExpr(output_shape), out_dtype); - return InferShardingSpec(call, ctx, output_tensor_sinfo, distributed::BuildAxisGraphMatmul); + TensorType output_tensor_ty(ShapeExpr(output_shape), out_dtype); + return InferShardingSpec(call, ctx, output_tensor_ty, distributed::BuildAxisGraphMatmul); } -TVM_REGISTER_OP("relax.matmul") - .set_attr("dist.FInferStructInfo", InferDistStructInfoMatmul); +TVM_REGISTER_OP("relax.matmul").set_attr("dist.FInferType", InferDistTypeMatmul); } // namespace distributed } // namespace relax diff --git a/src/relax/op/distributed/linear_algebra.h b/src/relax/op/distributed/linear_algebra.h index b7aace8e2469..046f3dfcc204 100644 --- a/src/relax/op/distributed/linear_algebra.h +++ b/src/relax/op/distributed/linear_algebra.h @@ -19,7 +19,7 @@ /*! * \file linear_algebra.h - * \brief The functions to infer struct info for distributed linear algebra operator + * \brief The functions to infer type for distributed linear algebra operator */ #ifndef TVM_RELAX_OP_DISTRIBUTED_LINEAR_ALGEBRA_H_ @@ -30,7 +30,7 @@ namespace tvm { namespace relax { namespace distributed { -StructInfo InferDistStructInfoMatmul(const Call& call, const BlockBuilder& ctx); +Type InferDistTypeMatmul(const Call& call, const BlockBuilder& ctx); } // namespace distributed } // namespace relax diff --git a/src/relax/op/distributed/manipulate.cc b/src/relax/op/distributed/manipulate.cc index bb98b2fa64e0..bdff1b00ba97 100644 --- a/src/relax/op/distributed/manipulate.cc +++ b/src/relax/op/distributed/manipulate.cc @@ -30,109 +30,105 @@ namespace tvm { namespace relax { namespace distributed { -StructInfo InferDistStructInfoPermuteDims(const Call& call, const BlockBuilder& ctx) { - ffi::Array input_dtensor_sinfos = - GetInputDTensorStructInfo(call, ctx); - TensorStructInfo data_sinfo = input_dtensor_sinfos[0]->tensor_sinfo; +Type InferDistTypePermuteDims(const Call& call, const BlockBuilder& ctx) { + ffi::Array input_dtensor_tys = GetInputDTensorType(call, ctx); + TensorType data_ty = input_dtensor_tys[0]->tensor_ty; const auto* attrs = call->attrs.as(); // Todo(relax-team): revisit here for better check on if the input tensor has // ndim same as the number of input axes. - if (!attrs->axes.defined() && data_sinfo->IsUnknownNdim()) { + if (!attrs->axes.defined() && data_ty->IsUnknownNdim()) { TVM_FFI_VISIT_THROW(ValueError, call) << "Input of distributed operator must have known ndim"; } if (attrs->axes.defined()) { int n_axis = attrs->axes.value().size(); - if (!data_sinfo->IsUnknownNdim() && n_axis != data_sinfo->ndim) { + if (!data_ty->IsUnknownNdim() && n_axis != data_ty->ndim) { TVM_FFI_VISIT_THROW(ValueError, call) << "PermuteDims expects the number of input axes to equal the ndim of the " "input tensor. However, the tensor ndim is " - << data_sinfo->ndim << " while the given number of axes is " << n_axis; + << data_ty->ndim << " while the given number of axes is " << n_axis; } } std::vector axes; if (attrs->axes.defined()) { - axes = NormalizeAxes(call, ctx, data_sinfo->ndim, attrs->axes.value()); + axes = NormalizeAxes(call, ctx, data_ty->ndim, attrs->axes.value()); } else { // Construct the reverse permutation via std::iota - axes.resize(data_sinfo->ndim); + axes.resize(data_ty->ndim); std::iota(axes.rbegin(), axes.rend(), 0); } if (IsIdentityPermutation(axes)) { - return input_dtensor_sinfos[0]; + return input_dtensor_tys[0]; } - const auto* data_shape = data_sinfo->shape.as(); + const auto* data_shape = data_ty->shape.as(); if (data_shape == nullptr) { TVM_FFI_VISIT_THROW(ValueError, call) << "Input of distributed operator must have known shape"; } std::vector new_shape; - new_shape.reserve(data_sinfo->ndim); - for (int i = 0; i < data_sinfo->ndim; ++i) { + new_shape.reserve(data_ty->ndim); + for (int i = 0; i < data_ty->ndim; ++i) { new_shape.push_back(data_shape->values[axes[i]]); } - TensorStructInfo output_tensor_sinfo(ShapeExpr(new_shape), data_sinfo->dtype); - return InferShardingSpec(call, ctx, output_tensor_sinfo, distributed::BuildAxisGraphPermuteDims); + TensorType output_tensor_ty(ShapeExpr(new_shape), data_ty->dtype); + return InferShardingSpec(call, ctx, output_tensor_ty, distributed::BuildAxisGraphPermuteDims); } TVM_REGISTER_OP("relax.permute_dims") - .set_attr("dist.FInferStructInfo", InferDistStructInfoPermuteDims); + .set_attr("dist.FInferType", InferDistTypePermuteDims); -StructInfo InferDistStructInfoReshape(const Call& call, const BlockBuilder& ctx) { +Type InferDistTypeReshape(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 2) { TVM_FFI_VISIT_THROW(ValueError, call) << "Reshape op should take 2 arguments"; } - ffi::Array input_dtensor_sinfos = - GetInputDTensorStructInfo(call, ctx); - TensorStructInfo data_sinfo = input_dtensor_sinfos[0]->tensor_sinfo; + ffi::Array input_dtensor_tys = GetInputDTensorType(call, ctx); + TensorType data_ty = input_dtensor_tys[0]->tensor_ty; - const auto* new_shape_sinfo = GetStructInfoAs(call->args[1]); - if (!data_sinfo.defined()) { + const auto* new_shape_ty = GetTypeAs(call->args[1]); + if (!data_ty.defined()) { TVM_FFI_VISIT_THROW(TypeError, call) << "Reshape requires the input data to be Tensor. However, the given one is " - << call->args[0]->struct_info_->GetTypeKey(); + << call->args[0]->ty->GetTypeKey(); } - if (new_shape_sinfo == nullptr) { + if (new_shape_ty == nullptr) { TVM_FFI_VISIT_THROW(TypeError, call) << "Reshape requires the input new shape to be Shape. However, the given one is " - << call->args[1]->struct_info_->GetTypeKey(); + << call->args[1]->ty->GetTypeKey(); } ffi::Optional> old_shape_values; - if (data_sinfo->shape.defined()) { - const auto* old_shape_sinfo = GetStructInfoAs(data_sinfo->shape.value()); - TVM_FFI_ICHECK_NOTNULL(old_shape_sinfo); - old_shape_values = old_shape_sinfo->values; + if (data_ty->shape.defined()) { + const auto* old_shape_ty = GetTypeAs(data_ty->shape.value()); + TVM_FFI_ICHECK_NOTNULL(old_shape_ty); + old_shape_values = old_shape_ty->values; } - if (new_shape_sinfo->values.defined() && old_shape_values.defined()) { - PrimExpr new_shape_prod = ComputeShapeProduct(new_shape_sinfo->values.value()); + if (new_shape_ty->values.defined() && old_shape_values.defined()) { + PrimExpr new_shape_prod = ComputeShapeProduct(new_shape_ty->values.value()); PrimExpr old_shape_prod = ComputeShapeProduct(old_shape_values.value()); if (ctx->GetAnalyzer()->CanProve(old_shape_prod != new_shape_prod)) { TVM_FFI_VISIT_THROW(ValueError, call) << "Reshape expects the new shape to be convertible from the old shape. " "However, the old shape is " - << data_sinfo->shape << ", with product " << old_shape_prod << ", while the new shape is " + << data_ty->shape << ", with product " << old_shape_prod << ", while the new shape is " << call->args[1] << ", with product " << new_shape_prod; } } Expr target_shape = call->args[1]; - TensorStructInfo output_tensor_sinfo; + TensorType output_tensor_ty; // If shape values are defined, use them - if (target_shape->IsInstance() && new_shape_sinfo->values.defined()) { - output_tensor_sinfo = - TensorStructInfo(ShapeExpr(new_shape_sinfo->values.value()), data_sinfo->dtype); + if (target_shape->IsInstance() && new_shape_ty->values.defined()) { + output_tensor_ty = TensorType(ShapeExpr(new_shape_ty->values.value()), data_ty->dtype); } else { - output_tensor_sinfo = TensorStructInfo(target_shape, data_sinfo->dtype); + output_tensor_ty = TensorType(target_shape, data_ty->dtype); } - return InferShardingSpec(call, ctx, output_tensor_sinfo, distributed::BuildAxisGraphReshape); + return InferShardingSpec(call, ctx, output_tensor_ty, distributed::BuildAxisGraphReshape); } -TVM_REGISTER_OP("relax.reshape") - .set_attr("dist.FInferStructInfo", InferDistStructInfoReshape); +TVM_REGISTER_OP("relax.reshape").set_attr("dist.FInferType", InferDistTypeReshape); } // namespace distributed } // namespace relax diff --git a/src/relax/op/distributed/manipulate.h b/src/relax/op/distributed/manipulate.h index 10b7bdcb6a33..db05a03d44f0 100644 --- a/src/relax/op/distributed/manipulate.h +++ b/src/relax/op/distributed/manipulate.h @@ -19,7 +19,7 @@ /*! * \file manipulate.h - * \brief The functions to infer struct info for distributed manipulate operator + * \brief The functions to infer type for distributed manipulate operator */ #ifndef TVM_RELAX_OP_DISTRIBUTED_MANIPULATE_H_ @@ -30,9 +30,9 @@ namespace tvm { namespace relax { namespace distributed { -StructInfo InferDistStructInfoPermuteDims(const Call& call, const BlockBuilder& ctx); +Type InferDistTypePermuteDims(const Call& call, const BlockBuilder& ctx); -StructInfo InferDistStructInfoReshape(const Call& call, const BlockBuilder& ctx); +Type InferDistTypeReshape(const Call& call, const BlockBuilder& ctx); } // namespace distributed } // namespace relax diff --git a/src/relax/op/distributed/nn.cc b/src/relax/op/distributed/nn.cc index 36b77282a012..0890362e449d 100644 --- a/src/relax/op/distributed/nn.cc +++ b/src/relax/op/distributed/nn.cc @@ -25,28 +25,26 @@ namespace tvm { namespace relax { namespace distributed { -StructInfo InferDistStructInfoSoftmax(const Call& call, const BlockBuilder& ctx) { - ffi::Array input_dtensor_sinfos = - GetInputDTensorStructInfo(call, ctx); - TVM_FFI_ICHECK(input_dtensor_sinfos.size() == 1); - TensorStructInfo input_tensor_sinfo = input_dtensor_sinfos[0]->tensor_sinfo; +Type InferDistTypeSoftmax(const Call& call, const BlockBuilder& ctx) { + ffi::Array input_dtensor_tys = GetInputDTensorType(call, ctx); + TVM_FFI_ICHECK(input_dtensor_tys.size() == 1); + TensorType input_tensor_ty = input_dtensor_tys[0]->tensor_ty; - if (input_tensor_sinfo->IsUnknownNdim()) { + if (input_tensor_ty->IsUnknownNdim()) { TVM_FFI_VISIT_THROW(ValueError, call) << "Input of distributed operator must have known ndim"; } - if (!input_tensor_sinfo->IsUnknownDtype() && !input_tensor_sinfo->dtype.is_float()) { + if (!input_tensor_ty->IsUnknownDtype() && !input_tensor_ty->dtype.is_float()) { TVM_FFI_VISIT_THROW(TypeError, call) << "Softmax requires the input tensor to have float " "dtype. However, the given input dtype is " - << input_tensor_sinfo->dtype; + << input_tensor_ty->dtype; } const auto* attrs = call->attrs.as(); - NormalizeAxis(call, ctx, input_tensor_sinfo->ndim, attrs->axis); + NormalizeAxis(call, ctx, input_tensor_ty->ndim, attrs->axis); - return InferShardingSpec(call, ctx, input_tensor_sinfo, distributed::BuildAxisGraphReduce); + return InferShardingSpec(call, ctx, input_tensor_ty, distributed::BuildAxisGraphReduce); } -TVM_REGISTER_OP("relax.nn.softmax") - .set_attr("dist.FInferStructInfo", InferDistStructInfoSoftmax); +TVM_REGISTER_OP("relax.nn.softmax").set_attr("dist.FInferType", InferDistTypeSoftmax); /* relax.nn.relu */ RELAX_REGISTER_UNARY_ARITH_DIST_INFER_STRUCT_INFO(nn.relu, /*require_float_dtype=*/false); diff --git a/src/relax/op/distributed/nn.h b/src/relax/op/distributed/nn.h index 9ecb378fdb49..1e7dee0d9bd4 100644 --- a/src/relax/op/distributed/nn.h +++ b/src/relax/op/distributed/nn.h @@ -19,7 +19,7 @@ /*! * \file nn.h - * \brief The functions to infer struct info for distributed nn operator + * \brief The functions to infer type for distributed nn operator */ #ifndef TVM_RELAX_OP_DISTRIBUTED_NN_H_ @@ -31,7 +31,7 @@ namespace tvm { namespace relax { namespace distributed { -StructInfo InferDistStructInfoSoftmax(const Call& call, const BlockBuilder& ctx); +Type InferDistTypeSoftmax(const Call& call, const BlockBuilder& ctx); } // namespace distributed } // namespace relax diff --git a/src/relax/op/distributed/op.cc b/src/relax/op/distributed/op.cc index 71000390e523..d4bc049d90ad 100644 --- a/src/relax/op/distributed/op.cc +++ b/src/relax/op/distributed/op.cc @@ -25,29 +25,27 @@ namespace tvm { namespace relax { namespace distributed { -StructInfo InferDistStructInfoCallTIR(const Call& call, const BlockBuilder& ctx) { - if (call->sinfo_args.size() != 1) { - TVM_FFI_VISIT_THROW(InternalError, call) - << "sinfo_args should have exact 1 output struct info."; +Type InferDistTypeCallTIR(const Call& call, const BlockBuilder& ctx) { + if (call->ty_args.size() != 1) { + TVM_FFI_VISIT_THROW(InternalError, call) << "ty_args should have exact 1 output type."; } TVM_FFI_ICHECK(call->args[0]->IsInstance()) << "call_tir expects the first argument to be a GlobalVar referring to a TIR PrimFunc. " << "However, gets " << call->args[0]; - return call->sinfo_args[0]; + return call->ty_args[0]; } -TVM_REGISTER_OP("relax.call_tir") - .set_attr("dist.FInferStructInfo", InferDistStructInfoCallTIR); +TVM_REGISTER_OP("relax.call_tir").set_attr("dist.FInferType", InferDistTypeCallTIR); -StructInfo InferDistStructInfoStopLiftParams(const Call& call, const BlockBuilder& ctx) { +Type InferDistTypeStopLiftParams(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 1) { TVM_FFI_VISIT_THROW(ValueError, call) << "stop_lift_params should have exact 1 arg."; } - return Downcast(call->args[0]->struct_info_.value()); + return Downcast(call->args[0]->ty); } TVM_REGISTER_OP("relax.builtin.stop_lift_params") - .set_attr("dist.FInferStructInfo", InferDistStructInfoStopLiftParams); + .set_attr("dist.FInferType", InferDistTypeStopLiftParams); } // namespace distributed } // namespace relax diff --git a/src/relax/op/distributed/statistical.cc b/src/relax/op/distributed/statistical.cc index 5384219f884b..95d4c1026eea 100644 --- a/src/relax/op/distributed/statistical.cc +++ b/src/relax/op/distributed/statistical.cc @@ -26,27 +26,26 @@ namespace tvm { namespace relax { namespace distributed { -StructInfo InferDistStructInfoStatistical(const Call& call, const BlockBuilder& ctx) { - ffi::Array input_dtensor_sinfos = - GetInputDTensorStructInfo(call, ctx); - TensorStructInfo data_sinfo = input_dtensor_sinfos[0]->tensor_sinfo; +Type InferDistTypeStatistical(const Call& call, const BlockBuilder& ctx) { + ffi::Array input_dtensor_tys = GetInputDTensorType(call, ctx); + TensorType data_ty = input_dtensor_tys[0]->tensor_ty; const auto* attrs = call->attrs.as(); std::vector axes; - if (!data_sinfo->IsUnknownNdim() && attrs->axis.defined()) { - axes = NormalizeAxes(call, ctx, data_sinfo->ndim, attrs->axis.value()); + if (!data_ty->IsUnknownNdim() && attrs->axis.defined()) { + axes = NormalizeAxes(call, ctx, data_ty->ndim, attrs->axis.value()); } int out_ndim = 0; if (attrs->keepdims) { - out_ndim = data_sinfo->ndim; + out_ndim = data_ty->ndim; } else if (!attrs->axis.defined()) { out_ndim = 0; - } else if (data_sinfo->IsUnknownNdim()) { + } else if (data_ty->IsUnknownNdim()) { TVM_FFI_VISIT_THROW(ValueError, call) << "Input of distributed operator must be known ndim"; } else { - out_ndim = data_sinfo->ndim - axes.size(); + out_ndim = data_ty->ndim - axes.size(); TVM_FFI_ICHECK_GE(out_ndim, 0); } @@ -57,14 +56,14 @@ StructInfo InferDistStructInfoStatistical(const Call& call, const BlockBuilder& // - axes is not None, keepdims is false -> the returned shape does not contain the input axes. // - axes is not None, keepdims is true -> the returned shape has value 1 at the positions of the // input axes - const auto* data_shape = data_sinfo->shape.as(); + const auto* data_shape = data_ty->shape.as(); if (data_shape == nullptr) { TVM_FFI_VISIT_THROW(ValueError, call) << "Input of distributed operator must be known shape"; } ffi::Array out_shape; out_shape.reserve(out_ndim); - for (int i = 0; i < data_sinfo->ndim; ++i) { + for (int i = 0; i < data_ty->ndim; ++i) { if (attrs->axis.defined() && std::find(axes.begin(), axes.end(), i) == axes.end()) { out_shape.push_back(data_shape->values[i]); } else if (attrs->keepdims) { @@ -72,9 +71,9 @@ StructInfo InferDistStructInfoStatistical(const Call& call, const BlockBuilder& } } TVM_FFI_ICHECK_EQ(static_cast(out_shape.size()), out_ndim); - TensorStructInfo output_tensor_sinfo = TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype); + TensorType output_tensor_ty = TensorType(ShapeExpr(out_shape), data_ty->dtype); - return InferShardingSpec(call, ctx, output_tensor_sinfo, distributed::BuildAxisGraphReduce); + return InferShardingSpec(call, ctx, output_tensor_ty, distributed::BuildAxisGraphReduce); } RELAX_REGISTER_STATISTICAL_DIST_INFER_STRUCT_INFO(max); RELAX_REGISTER_STATISTICAL_DIST_INFER_STRUCT_INFO(mean); diff --git a/src/relax/op/distributed/statistical.h b/src/relax/op/distributed/statistical.h index c54818335b2b..f993f651dcf0 100644 --- a/src/relax/op/distributed/statistical.h +++ b/src/relax/op/distributed/statistical.h @@ -19,7 +19,7 @@ /*! * \file statistical.h - * \brief The functions to infer struct info for distributed statistical operator + * \brief The functions to infer type for distributed statistical operator */ #ifndef TVM_RELAX_OP_DISTRIBUTED_STATISTICAL_H_ @@ -31,11 +31,11 @@ namespace tvm { namespace relax { namespace distributed { -StructInfo InferDistStructInfoStatistical(const Call& call, const BlockBuilder& ctx); +Type InferDistTypeStatistical(const Call& call, const BlockBuilder& ctx); #define RELAX_REGISTER_STATISTICAL_DIST_INFER_STRUCT_INFO(OpName) \ TVM_REGISTER_OP("relax." #OpName) \ - .set_attr("dist.FInferStructInfo", InferDistStructInfoStatistical) + .set_attr("dist.FInferType", InferDistTypeStatistical) } // namespace distributed } // namespace relax diff --git a/src/relax/op/distributed/unary.cc b/src/relax/op/distributed/unary.cc index 4e62d93eecb4..0934d7aafa5a 100644 --- a/src/relax/op/distributed/unary.cc +++ b/src/relax/op/distributed/unary.cc @@ -23,9 +23,9 @@ namespace tvm { namespace relax { namespace distributed { -StructInfo InferDistStructInfoUnaryCheck(const Call& call, const BlockBuilder& ctx) { - return InferDistStructInfoUnary( - call, ctx, [](const TensorStructInfo& input_sinfo) { return DataType::Bool(); }); +Type InferDistTypeUnaryCheck(const Call& call, const BlockBuilder& ctx) { + return InferDistTypeUnary(call, ctx, + [](const TensorType& input_ty) { return DataType::Bool(); }); } RELAX_REGISTER_UNARY_ARITH_DIST_INFER_STRUCT_INFO(abs, /*require_float_dtype=*/false); diff --git a/src/relax/op/distributed/unary.h b/src/relax/op/distributed/unary.h index acf61f5998d3..b4b831e79b02 100644 --- a/src/relax/op/distributed/unary.h +++ b/src/relax/op/distributed/unary.h @@ -19,7 +19,7 @@ /*! * \file unary.h - * \brief The functions to infer struct info for distributed unary operator + * \brief The functions to infer type for distributed unary operator */ #ifndef TVM_RELAX_OP_DISTRIBUTED_UNARY_H_ @@ -34,44 +34,40 @@ namespace relax { namespace distributed { template -StructInfo InferDistStructInfoUnary(const Call& call, const BlockBuilder& ctx, - FType f_compute_out_dtype) { - ffi::Array input_dtensor_sinfos = - GetInputDTensorStructInfo(call, ctx); - TVM_FFI_ICHECK(input_dtensor_sinfos.size() == 1); - distributed::DTensorStructInfo input_dtensor_sinfo = input_dtensor_sinfos[0]; - TensorStructInfo input_tensor_sinfo = input_dtensor_sinfo->tensor_sinfo; +Type InferDistTypeUnary(const Call& call, const BlockBuilder& ctx, FType f_compute_out_dtype) { + ffi::Array input_dtensor_tys = GetInputDTensorType(call, ctx); + TVM_FFI_ICHECK(input_dtensor_tys.size() == 1); + distributed::DTensorType input_dtensor_ty = input_dtensor_tys[0]; + TensorType input_tensor_ty = input_dtensor_ty->tensor_ty; - if (require_float_dtype && !input_tensor_sinfo->IsUnknownDtype() && - !input_tensor_sinfo->dtype.is_float()) { + if (require_float_dtype && !input_tensor_ty->IsUnknownDtype() && + !input_tensor_ty->dtype.is_float()) { TVM_FFI_VISIT_THROW(TypeError, call) << call->op << " requires the input tensor to have float dtype. However, the given input dtype is " - << input_tensor_sinfo->dtype; + << input_tensor_ty->dtype; } - auto output_sinfo = ffi::make_object(*input_tensor_sinfo.get()); - output_sinfo->dtype = f_compute_out_dtype(input_tensor_sinfo); - TensorStructInfo out_tensor_sinfo(output_sinfo); - return distributed::DTensorStructInfo(out_tensor_sinfo, input_dtensor_sinfo->device_mesh, - input_dtensor_sinfo->placement); + auto output_ty = ffi::make_object(*input_tensor_ty.get()); + output_ty->dtype = f_compute_out_dtype(input_tensor_ty); + TensorType out_tensor_ty(output_ty); + return distributed::DTensorType(out_tensor_ty, input_dtensor_ty->device_mesh, + input_dtensor_ty->placement); } template -StructInfo InferDistStructInfoUnaryArith(const Call& call, const BlockBuilder& ctx) { - return InferDistStructInfoUnary( - call, ctx, [](const TensorStructInfo& input_sinfo) { return input_sinfo->dtype; }); +Type InferDistTypeUnaryArith(const Call& call, const BlockBuilder& ctx) { + return InferDistTypeUnary( + call, ctx, [](const TensorType& input_ty) { return input_ty->dtype; }); } -StructInfo InferDistStructInfoUnaryCheck(const Call& call, const BlockBuilder& ctx); +Type InferDistTypeUnaryCheck(const Call& call, const BlockBuilder& ctx); #define RELAX_REGISTER_UNARY_ARITH_DIST_INFER_STRUCT_INFO(OpName, RequireFloatDtype) \ TVM_REGISTER_OP("relax." #OpName) \ - .set_attr("dist.FInferStructInfo", \ - InferDistStructInfoUnaryArith) + .set_attr("dist.FInferType", InferDistTypeUnaryArith) #define RELAX_REGISTER_UNARY_CHECK_DIST_INFER_STRUCT_INFO(OpName) \ - TVM_REGISTER_OP("relax." #OpName) \ - .set_attr("dist.FInferStructInfo", InferDistStructInfoUnaryCheck) + TVM_REGISTER_OP("relax." #OpName).set_attr("dist.FInferType", InferDistTypeUnaryCheck) } // namespace distributed } // namespace relax diff --git a/src/relax/op/distributed/utils.cc b/src/relax/op/distributed/utils.cc index 7a293711e8b8..a3a773e6ea94 100644 --- a/src/relax/op/distributed/utils.cc +++ b/src/relax/op/distributed/utils.cc @@ -26,65 +26,63 @@ namespace tvm { namespace relax { namespace distributed { -ffi::Array GetInputDTensorStructInfo(const Call& call, - const BlockBuilder& ctx) { +ffi::Array GetInputDTensorType(const Call& call, + const BlockBuilder& ctx) { Op op = Downcast(call->op); ffi::Array args = GetCallArgs(call); - ffi::Array input_tensor_sinfo; - input_tensor_sinfo.reserve(args.size()); + ffi::Array input_tensor_ty; + input_tensor_ty.reserve(args.size()); for (const Expr& arg : args) { - const auto* sinfo = GetStructInfoAs(arg); - if (sinfo != nullptr) { - input_tensor_sinfo.push_back(ffi::GetRef(sinfo)); + const auto* ty = GetTypeAs(arg); + if (ty != nullptr) { + input_tensor_ty.push_back(ffi::GetRef(ty)); } } - return input_tensor_sinfo; + return input_tensor_ty; } -StructInfo InferShardingSpec(const Call& call, const BlockBuilder& ctx, - const StructInfo& orig_output_sinfo, - distributed::FBuildAxisGraph f_build_graph) { - ffi::Array input_dtensor_sinfos = - GetInputDTensorStructInfo(call, ctx); - for (int i = 1; i < static_cast(input_dtensor_sinfos.size()); i++) { - TVM_FFI_ICHECK(ffi::StructuralEqual()(input_dtensor_sinfos[0]->device_mesh, - input_dtensor_sinfos[i]->device_mesh)); +Type InferShardingSpec(const Call& call, const BlockBuilder& ctx, const Type& orig_output_ty, + distributed::FBuildAxisGraph f_build_graph) { + ffi::Array input_dtensor_tys = GetInputDTensorType(call, ctx); + for (int i = 1; i < static_cast(input_dtensor_tys.size()); i++) { + TVM_FFI_ICHECK(ffi::StructuralEqual()(input_dtensor_tys[0]->device_mesh, + input_dtensor_tys[i]->device_mesh)); } - distributed::DeviceMesh device_mesh = input_dtensor_sinfos[0]->device_mesh; - Var output_var("output", orig_output_sinfo); + distributed::DeviceMesh device_mesh = input_dtensor_tys[0]->device_mesh; + Var output_var("output", orig_output_ty); distributed::AxisGroupGraph axis_group_graph; f_build_graph(output_var, call, &axis_group_graph); ffi::Array args = GetCallArgs(call); - int n_input_var = input_dtensor_sinfos.size(); + int n_input_var = input_dtensor_tys.size(); for (int i = 0; i < n_input_var; i++) { - distributed::DTensorStructInfo dtensor_sinfo = input_dtensor_sinfos[i]; + distributed::DTensorType dtensor_ty = input_dtensor_tys[i]; Expr input_tensor = args[i]; for (int j = 0; j < static_cast(device_mesh->shape.size()); j++) { - distributed::PlacementSpec placement_spec = dtensor_sinfo->placement->dim_specs[j]; + distributed::PlacementSpec placement_spec = dtensor_ty->placement->dim_specs[j]; if (placement_spec->kind != distributed::PlacementSpecKind::kSharding) { continue; } axis_group_graph.AddSrcShardingPoint({input_tensor.get(), placement_spec->axis}, - {dtensor_sinfo->device_mesh, j}); + {dtensor_ty->device_mesh, j}); } } axis_group_graph.PropagateShardingSpec(); - ffi::Array orig_output_tensor_sinfos; - if (const auto* tensor_sinfo = orig_output_sinfo.as()) { - orig_output_tensor_sinfos.push_back(ffi::GetRef(tensor_sinfo)); + ffi::Array orig_output_tensor_tys; + if (const auto* tensor_ty = orig_output_ty.as()) { + orig_output_tensor_tys.push_back(ffi::GetRef(tensor_ty)); } else { - const auto* tuple_sinfo = orig_output_sinfo.as(); - TVM_FFI_ICHECK(tuple_sinfo); - for (const auto& sinfo : tuple_sinfo->fields) { - orig_output_tensor_sinfos.push_back(Downcast(sinfo)); + const auto* tuple_ty = orig_output_ty.as(); + TVM_FFI_ICHECK(tuple_ty); + for (const auto& ty : tuple_ty->fields) { + orig_output_tensor_tys.push_back(Downcast(ty)); } } - ffi::Array new_output_dtensor_sinfos; - for (int idx = 0; idx < static_cast(orig_output_tensor_sinfos.size()); idx++) { + ffi::Array new_output_dtensor_tys; + for (int idx = 0; idx < static_cast(orig_output_tensor_tys.size()); idx++) { ffi::Array output_placement_specs( std::vector(device_mesh->shape.size(), distributed::PlacementSpec::Replica())); - for (int i = 0; i < orig_output_tensor_sinfos[idx]->ndim; i++) { + for (int i = 0; i < orig_output_tensor_tys[idx]->ndim; i++) { distributed::AxisShardingSpec sharding_spec; bool has_sharding_spec; std::tie(sharding_spec, has_sharding_spec) = @@ -93,13 +91,12 @@ StructInfo InferShardingSpec(const Call& call, const BlockBuilder& ctx, output_placement_specs.Set(sharding_spec.second, distributed::PlacementSpec::Sharding(i)); } } - new_output_dtensor_sinfos.push_back( - DTensorStructInfo(orig_output_tensor_sinfos[idx], device_mesh, - distributed::Placement(output_placement_specs))); + new_output_dtensor_tys.push_back(DTensorType(orig_output_tensor_tys[idx], device_mesh, + distributed::Placement(output_placement_specs))); } - return new_output_dtensor_sinfos.size() == 1 ? new_output_dtensor_sinfos[0] - : TupleStructInfo(new_output_dtensor_sinfos); + return new_output_dtensor_tys.size() == 1 ? new_output_dtensor_tys[0] + : TupleType(new_output_dtensor_tys); } } // namespace distributed diff --git a/src/relax/op/distributed/utils.h b/src/relax/op/distributed/utils.h index 125a2d242ba5..78ac15755811 100644 --- a/src/relax/op/distributed/utils.h +++ b/src/relax/op/distributed/utils.h @@ -19,7 +19,7 @@ /*! * \file utils.h - * \brief The util function for dtensor infer struct info + * \brief The util function for dtensor infer type */ #ifndef TVM_RELAX_OP_DISTRIBUTED_UTILS_H_ @@ -36,28 +36,26 @@ namespace relax { namespace distributed { /*! - * \brief Get the dtensor struct info of the operator input. + * \brief Get the dtensor type of the operator input. * \param call The context Call to the operator. * \param ctx The error reporting context. - * \return The dtensor struct info of each input. + * \return The dtensor type of each input. * \note This function require every input tensor to be DTensor. */ -ffi::Array GetInputDTensorStructInfo(const Call& call, - const BlockBuilder& ctx); +ffi::Array GetInputDTensorType(const Call& call, const BlockBuilder& ctx); /*! * \brief Perform a local sharding spec propagation to infer the output dtensor - struct info or tuple of dtensor struct info. + type or tuple of dtensor type. * * \param call The context Call to the operator. * \param ctx The error reporting context. - * \param output_sinfo The original output struct info + * \param output_ty The original output type * \param f_build_graph The function to build axis graph - * \return The inferred output struct info + * \return The inferred output type */ -StructInfo InferShardingSpec(const Call& call, const BlockBuilder& ctx, - const StructInfo& orig_output_sinfo, - distributed::FBuildAxisGraph f_build_graph); +Type InferShardingSpec(const Call& call, const BlockBuilder& ctx, const Type& orig_output_ty, + distributed::FBuildAxisGraph f_build_graph); } // namespace distributed } // namespace relax diff --git a/src/relax/op/image/resize.cc b/src/relax/op/image/resize.cc index eab2f6849956..beb89af08777 100644 --- a/src/relax/op/image/resize.cc +++ b/src/relax/op/image/resize.cc @@ -62,30 +62,30 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.image.resize2d", resize2d); } -StructInfo InferStructInfoResize2D(const Call& call, const BlockBuilder& ctx) { +Type InferTypeResize2D(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 2) { TVM_FFI_VISIT_THROW(ValueError, call) << "Resize2D expects 2 arguments, while the given number of arguments is " << call->args.size(); } - const auto* data_sinfo = GetStructInfoAs(call->args[0]); - const auto* size_sinfo = GetStructInfoAs(call->args[1]); + const auto* data_ty = GetTypeAs(call->args[0]); + const auto* size_ty = GetTypeAs(call->args[1]); const auto* size_value = call->args[1].as(); - if (data_sinfo == nullptr) { + if (data_ty == nullptr) { TVM_FFI_VISIT_THROW(TypeError, call) << "Resize2D expects the input data to be a Tensor, while the given data is " << call->args[0]->GetTypeKey(); } - if (size_sinfo == nullptr) { + if (size_ty == nullptr) { TVM_FFI_VISIT_THROW(TypeError, call) << "Resize2D expects the given output image size to be a Shape, while the given one is " << call->args[1]->GetTypeKey(); } - if (size_sinfo->ndim != 2) { + if (size_ty->ndim != 2) { TVM_FFI_VISIT_THROW(ValueError, call) << "Resize2D expects the given output image size to " "be a 2-dim shape, while the given one has ndim " - << size_sinfo->ndim; + << size_ty->ndim; } const auto* attrs = call->attrs.as(); @@ -93,12 +93,12 @@ StructInfo InferStructInfoResize2D(const Call& call, const BlockBuilder& ctx) { /*tgt_layout=*/"NCHW", // /*tensor_name=*/"data"); - DataType out_dtype = attrs->out_dtype.is_void() ? data_sinfo->dtype : attrs->out_dtype; + DataType out_dtype = attrs->out_dtype.is_void() ? data_ty->dtype : attrs->out_dtype; - ffi::Optional data_shape = CheckNdimPerLayoutAndGetShape( - call, ctx, ffi::GetRef(data_sinfo), data_layout); + ffi::Optional data_shape = + CheckNdimPerLayoutAndGetShape(call, ctx, ffi::GetRef(data_ty), data_layout); if (!data_shape.defined() || size_value == nullptr) { - return TensorStructInfo(out_dtype, data_layout.ndim(), data_sinfo->vdevice); + return TensorType(out_dtype, data_layout.ndim(), data_ty->vdevice); } ffi::Array data_NCHW_shape = data2NCHW.ForwardShape(data_shape.value()->values); @@ -107,7 +107,7 @@ StructInfo InferStructInfoResize2D(const Call& call, const BlockBuilder& ctx) { out_NCHW_shape.Set(3, size_value->values[1]); ffi::Array out_shape = data2NCHW.BackwardShape(out_NCHW_shape); - return TensorStructInfo(ShapeExpr(out_shape), out_dtype, data_sinfo->vdevice); + return TensorType(ShapeExpr(out_shape), out_dtype, data_ty->vdevice); } InferLayoutOutput InferLayoutResize2d( @@ -145,7 +145,7 @@ TVM_REGISTER_OP("relax.image.resize2d") .set_num_inputs(2) .add_argument("data", "Tensor", "The input tensor.") .add_argument("size", "Shape", "The output image shape.") - .set_attr("FInferStructInfo", InferStructInfoResize2D) + .set_attr("FInferType", InferTypeResize2D) .set_attr("FRelaxInferLayout", InferLayoutResize2d) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", true); @@ -176,30 +176,30 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.image.resize3d", resize3d); } -StructInfo InferStructInfoResize3D(const Call& call, const BlockBuilder& ctx) { +Type InferTypeResize3D(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 2) { TVM_FFI_VISIT_THROW(ValueError, call) << "Resize3D expects 2 arguments, while the given number of arguments is " << call->args.size(); } - const auto* data_sinfo = GetStructInfoAs(call->args[0]); - const auto* size_sinfo = GetStructInfoAs(call->args[1]); + const auto* data_ty = GetTypeAs(call->args[0]); + const auto* size_ty = GetTypeAs(call->args[1]); const auto* size_value = call->args[1].as(); - if (data_sinfo == nullptr) { + if (data_ty == nullptr) { TVM_FFI_VISIT_THROW(TypeError, call) << "Resize3D expects the input data to be a Tensor, while the given data is " << call->args[0]->GetTypeKey(); } - if (size_sinfo == nullptr) { + if (size_ty == nullptr) { TVM_FFI_VISIT_THROW(TypeError, call) << "Resize3D expects the given output image size to be a Shape, while the given one is " << call->args[1]->GetTypeKey(); } - if (size_sinfo->ndim != 3) { + if (size_ty->ndim != 3) { TVM_FFI_VISIT_THROW(ValueError, call) << "Resize3D expects the given output image size to " "be a 3-dim shape, while the given one has ndim " - << size_sinfo->ndim; + << size_ty->ndim; } const auto* attrs = call->attrs.as(); @@ -207,12 +207,12 @@ StructInfo InferStructInfoResize3D(const Call& call, const BlockBuilder& ctx) { /*tgt_layout=*/"NCDHW", // /*tensor_name=*/"data"); - DataType out_dtype = attrs->out_dtype.is_void() ? data_sinfo->dtype : attrs->out_dtype; + DataType out_dtype = attrs->out_dtype.is_void() ? data_ty->dtype : attrs->out_dtype; - ffi::Optional data_shape = CheckNdimPerLayoutAndGetShape( - call, ctx, ffi::GetRef(data_sinfo), data_layout); + ffi::Optional data_shape = + CheckNdimPerLayoutAndGetShape(call, ctx, ffi::GetRef(data_ty), data_layout); if (!data_shape.defined() || size_value == nullptr) { - return TensorStructInfo(out_dtype, data_layout.ndim(), data_sinfo->vdevice); + return TensorType(out_dtype, data_layout.ndim(), data_ty->vdevice); } ffi::Array data_NCDHW_shape = data2NCDHW.ForwardShape(data_shape.value()->values); @@ -222,7 +222,7 @@ StructInfo InferStructInfoResize3D(const Call& call, const BlockBuilder& ctx) { out_NCDHW_shape.Set(4, size_value->values[2]); ffi::Array out_shape = data2NCDHW.BackwardShape(out_NCDHW_shape); - return TensorStructInfo(ShapeExpr(out_shape), out_dtype, data_sinfo->vdevice); + return TensorType(ShapeExpr(out_shape), out_dtype, data_ty->vdevice); } InferLayoutOutput InferLayoutResize3d( @@ -257,7 +257,7 @@ TVM_REGISTER_OP("relax.image.resize3d") .set_num_inputs(2) .add_argument("data", "Tensor", "The input tensor.") .add_argument("size", "Shape", "The output image shape.") - .set_attr("FInferStructInfo", InferStructInfoResize3D) + .set_attr("FInferType", InferTypeResize3D) .set_attr("FRelaxInferLayout", InferLayoutResize3d) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", true); @@ -283,22 +283,22 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.image.grid_sample", grid_sample); } -StructInfo InferStructInfoGridSample(const Call& call, const BlockBuilder& ctx) { +Type InferTypeGridSample(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 2) { TVM_FFI_VISIT_THROW(ValueError, call) << "GridSample expects two arguments, while the given number of arguments is " << call->args.size(); } - const auto* data_sinfo = GetStructInfoAs(call->args[0]); - const auto* grid_sinfo = GetStructInfoAs(call->args[1]); + const auto* data_ty = GetTypeAs(call->args[0]); + const auto* grid_ty = GetTypeAs(call->args[1]); - if (data_sinfo == nullptr) { + if (data_ty == nullptr) { TVM_FFI_VISIT_THROW(TypeError, call) << "GridSample expects the input data to be a Tensor, while the given data is " << call->args[0]->GetTypeKey(); } - if (grid_sinfo == nullptr) { + if (grid_ty == nullptr) { TVM_FFI_VISIT_THROW(TypeError, call) << "GridSample expects the grid to be a Tensor, while the given grid is " << call->args[1]->GetTypeKey(); @@ -315,14 +315,14 @@ StructInfo InferStructInfoGridSample(const Call& call, const BlockBuilder& ctx) /*tgt_layout=*/is_ncdhw ? "NCDHW" : "NCHW", /*tensor_name=*/"data"); - DataType out_dtype = data_sinfo->dtype; + DataType out_dtype = data_ty->dtype; - ffi::Optional data_shape = CheckNdimPerLayoutAndGetShape( - call, ctx, ffi::GetRef(data_sinfo), data_layout); - const auto* grid_shape = grid_sinfo->shape.as(); + ffi::Optional data_shape = + CheckNdimPerLayoutAndGetShape(call, ctx, ffi::GetRef(data_ty), data_layout); + const auto* grid_shape = grid_ty->shape.as(); if (!data_shape.defined() || grid_shape == nullptr) { - return TensorStructInfo(out_dtype, data_layout.ndim(), data_sinfo->vdevice); + return TensorType(out_dtype, data_layout.ndim(), data_ty->vdevice); } ffi::Array data_tgt_shape = data2tgt.ForwardShape(data_shape.value()->values); @@ -340,7 +340,7 @@ StructInfo InferStructInfoGridSample(const Call& call, const BlockBuilder& ctx) } ffi::Array out_shape = data2tgt.BackwardShape(out_tgt_shape); - return TensorStructInfo(ShapeExpr(out_shape), out_dtype, data_sinfo->vdevice); + return TensorType(ShapeExpr(out_shape), out_dtype, data_ty->vdevice); } TVM_REGISTER_OP("relax.image.grid_sample") @@ -348,7 +348,7 @@ TVM_REGISTER_OP("relax.image.grid_sample") .set_num_inputs(2) .add_argument("data", "Tensor", "The input tensor.") .add_argument("grid", "Tensor", "The grid tensor for sampling.") - .set_attr("FInferStructInfo", InferStructInfoGridSample) + .set_attr("FInferType", InferTypeGridSample) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", true); @@ -364,42 +364,42 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.image.affine_grid", affine_grid); } -StructInfo InferStructInfoAffineGrid(const Call& call, const BlockBuilder& ctx) { +Type InferTypeAffineGrid(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 2) { TVM_FFI_VISIT_THROW(ValueError, call) << "AffineGrid expects two arguments, while the given number of arguments is " << call->args.size(); } - const auto* data_sinfo = GetStructInfoAs(call->args[0]); - const auto* size_sinfo = GetStructInfoAs(call->args[1]); + const auto* data_ty = GetTypeAs(call->args[0]); + const auto* size_ty = GetTypeAs(call->args[1]); const auto* size_value = call->args[1].as(); - if (data_sinfo == nullptr) { + if (data_ty == nullptr) { TVM_FFI_VISIT_THROW(TypeError, call) << "AffineGrid expects the input data to be a Tensor, while the given data is " << call->args[0]->GetTypeKey(); } - if (size_sinfo == nullptr) { + if (size_ty == nullptr) { TVM_FFI_VISIT_THROW(TypeError, call) << "AffineGrid expects the target size to be a Shape, while the given one is " << call->args[1]->GetTypeKey(); } - if (size_sinfo->ndim != 2) { + if (size_ty->ndim != 2) { TVM_FFI_VISIT_THROW(ValueError, call) << "AffineGrid expects the target size to be a 2-dim shape, while the given " "one has ndim " - << size_sinfo->ndim; + << size_ty->ndim; } // data should be 3-D: [batch, 2, 3] - if (data_sinfo->ndim != -1 && data_sinfo->ndim != 3) { + if (data_ty->ndim != -1 && data_ty->ndim != 3) { TVM_FFI_VISIT_THROW(ValueError, call) << "AffineGrid expects the input data to be 3-D (batch, 2, 3), but got ndim " - << data_sinfo->ndim; + << data_ty->ndim; } - const auto* data_shape = data_sinfo->shape.as(); + const auto* data_shape = data_ty->shape.as(); if (data_shape != nullptr) { // Check that the affine matrix has shape [batch, 2, 3] if (data_shape->values.size() >= 2) { @@ -418,10 +418,10 @@ StructInfo InferStructInfoAffineGrid(const Call& call, const BlockBuilder& ctx) } } - DataType out_dtype = data_sinfo->dtype; + DataType out_dtype = data_ty->dtype; if (data_shape == nullptr || size_value == nullptr) { - return TensorStructInfo(out_dtype, /*ndim=*/4, data_sinfo->vdevice); + return TensorType(out_dtype, /*ndim=*/4, data_ty->vdevice); } // Output shape: [batch, 2, target_height, target_width] @@ -431,14 +431,14 @@ StructInfo InferStructInfoAffineGrid(const Call& call, const BlockBuilder& ctx) out_shape.push_back(size_value->values[0]); // target_height out_shape.push_back(size_value->values[1]); // target_width - return TensorStructInfo(ShapeExpr(out_shape), out_dtype, data_sinfo->vdevice); + return TensorType(ShapeExpr(out_shape), out_dtype, data_ty->vdevice); } TVM_REGISTER_OP("relax.image.affine_grid") .set_num_inputs(2) .add_argument("data", "Tensor", "The input affine matrix tensor.") .add_argument("size", "Shape", "The target output shape (H, W).") - .set_attr("FInferStructInfo", InferStructInfoAffineGrid) + .set_attr("FInferType", InferTypeAffineGrid) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", true); diff --git a/src/relax/op/memory/view.cc b/src/relax/op/memory/view.cc index 10b42f8002c2..1b21432b8d7f 100644 --- a/src/relax/op/memory/view.cc +++ b/src/relax/op/memory/view.cc @@ -49,7 +49,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.memory.view", view); } -StructInfo InferStructInfoView(const Call& call, const BlockBuilder& ctx) { +Type InferTypeView(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 4) { TVM_FFI_VISIT_THROW(ValueError, call) << "Operator " << call->op << " should receive 4 arguments, " @@ -60,37 +60,37 @@ StructInfo InferStructInfoView(const Call& call, const BlockBuilder& ctx) { Expr arg_dtype = call->args[2]; Expr arg_relative_byte_offset = call->args[3]; - TensorStructInfo data_sinfo = [&]() -> TensorStructInfo { - StructInfo sinfo = GetStructInfo(arg_data); - if (auto opt = sinfo.as()) { + TensorType data_ty = [&]() -> TensorType { + Type ty = GetType(arg_data); + if (auto opt = ty.as()) { return opt.value(); } else { TVM_FFI_THROW(TypeError) << "Operator " << call->op << " expects first argument to be a tensor, " - << "but received " << arg_data << " with type " << sinfo; + << "but received " << arg_data << " with type " << ty; } }(); - auto view_shape_sinfo = [&]() -> const ShapeStructInfoNode* { - StructInfo sinfo = GetStructInfo(arg_shape); - if (HasVoidStructInfo(arg_shape)) { + auto view_shape_ty = [&]() -> const ShapeTypeNode* { + Type ty = GetType(arg_shape); + if (HasVoidType(arg_shape)) { // No shape change is applied. The input tensor's shape is // kept as-is. return nullptr; - } else if (auto ptr = sinfo.as()) { + } else if (auto ptr = ty.as()) { // The `R.view` operation returns a different shape. return ptr; } else { TVM_FFI_THROW(TypeError) << "Operator " << call->op << " expects second argument to be a ShapeExpr, " << "or a void-type (empty relax tuple), " - << "but received " << arg_shape << " with type " << sinfo; + << "but received " << arg_shape << " with type " << ty; } }(); auto view_dtype = [&]() -> std::optional { - StructInfo sinfo = GetStructInfo(arg_dtype); + Type ty = GetType(arg_dtype); - if (HasVoidStructInfo(arg_dtype)) { + if (HasVoidType(arg_dtype)) { // No datatype change is applied. The input tensor's dtype is // kept as-is. return std::nullopt; @@ -105,40 +105,40 @@ StructInfo InferStructInfoView(const Call& call, const BlockBuilder& ctx) { } } - // In general, StructInfo inference should only depend on the - // StructInfo of the arguments, and not on the arguments + // In general, Type inference should only depend on the + // Type of the arguments, and not on the arguments // themselves. However, `relax::DataTypeImm` uses - // `ObjectStructInfo`, so we need to inspect the argument itself + // `ObjectType`, so we need to inspect the argument itself // in this case. if (auto dtype_imm = arg_value.as()) { // We know the datatype for the view. return dtype_imm->value; - } else if (sinfo.as()) { + } else if (ty.as()) { // The view changes the datatype, but we don't know what it is // being changed into. return DataType::Void(); } else { TVM_FFI_THROW(TypeError) << "Operator " << call->op << " expects the dtype argument to be a relax::DataTypeImm, " - << "but received " << arg_dtype << " with type " << sinfo; + << "but received " << arg_dtype << " with type " << ty; } }(); auto view_relative_byte_offset = [&]() -> ffi::Optional { - StructInfo sinfo = GetStructInfo(arg_relative_byte_offset); + Type ty = GetType(arg_relative_byte_offset); - if (HasVoidStructInfo(arg_relative_byte_offset)) { + if (HasVoidType(arg_relative_byte_offset)) { // No byte offset is specified, so no change is applied. return IntImm::Int64(0); - } else if (auto prim_sinfo = sinfo.as()) { - TVM_FFI_CHECK_EQ(prim_sinfo->dtype, DataType::Int(64), TypeError) + } else if (auto prim_ty = ty.as()) { + TVM_FFI_CHECK_EQ(prim_ty->dtype, DataType::Int(64), TypeError) << "Operator " << call->op << " expects the relative_byte_offset to be a 64-bit integer, but received " - << arg_relative_byte_offset << ", which has type " << sinfo; - if (prim_sinfo->value.defined()) { + << arg_relative_byte_offset << ", which has type " << ty; + if (prim_ty->value.defined()) { // An offset of known value is applied. The known value may // be dynamic. - return prim_sinfo->value.value(); + return prim_ty->value.value(); } else { // An offset of unknown value is applied. return std::nullopt; @@ -149,25 +149,25 @@ StructInfo InferStructInfoView(const Call& call, const BlockBuilder& ctx) { << "to be a Relax PrimValue. " << "However, expression " << call << " provides relative_byte_offset of " << arg_relative_byte_offset - << ", which has type " << sinfo; + << ", which has type " << ty; } }(); - ffi::Optional> input_shape = data_sinfo->GetShape(); + ffi::Optional> input_shape = data_ty->GetShape(); ffi::Optional> output_shape = std::nullopt; int output_ndim = kUnknownNDim; - if (view_shape_sinfo && view_shape_sinfo->values.defined()) { - output_shape = view_shape_sinfo->values.value(); - } else if (view_shape_sinfo) { - output_ndim = view_shape_sinfo->ndim; + if (view_shape_ty && view_shape_ty->values.defined()) { + output_shape = view_shape_ty->values.value(); + } else if (view_shape_ty) { + output_ndim = view_shape_ty->ndim; } else if (input_shape) { output_shape = input_shape; } else { - output_ndim = data_sinfo->ndim; + output_ndim = data_ty->ndim; } - DataType output_dtype = view_dtype.value_or(data_sinfo->dtype); + DataType output_dtype = view_dtype.value_or(data_ty->dtype); // Helper function, returns the number of bytes per vectorized // element. Cannot use `DataType::bytes`, as it returns the @@ -199,7 +199,7 @@ StructInfo InferStructInfoView(const Call& call, const BlockBuilder& ctx) { ffi::Optional input_nelements = get_num_elements(input_shape); ffi::Optional output_nelements = get_num_elements(output_shape); - ffi::Optional input_element_size = get_size_bytes(data_sinfo->dtype); + ffi::Optional input_element_size = get_size_bytes(data_ty->dtype); ffi::Optional output_element_size = get_size_bytes(output_dtype); if (input_nelements && output_nelements && input_element_size && output_element_size && @@ -217,9 +217,9 @@ StructInfo InferStructInfoView(const Call& call, const BlockBuilder& ctx) { TVM_FFI_THROW(ValueError) << "Views into an array must not exceed the bounds of the array being viewed. " << "However, expression " << call << " attempted to create view of type " - << TensorStructInfo(ShapeExpr(output_shape.value()), output_dtype) + << TensorType(ShapeExpr(output_shape.value()), output_dtype) << " with relative byte offset " << view_relative_byte_offset - << ", viewing into the array " << arg_data << " of type " << data_sinfo << ". " + << ", viewing into the array " << arg_data << " of type " << data_ty << ". " << "The end of the view would occur at byte " << view_end << ", relative to the start of array " << arg_data << ", but " << arg_data << " is only " << input_nbytes << " long."; @@ -239,13 +239,13 @@ StructInfo InferStructInfoView(const Call& call, const BlockBuilder& ctx) { TVM_FFI_THROW(ValueError) << "Views into an array must not exceed the bounds of the array being viewed. " << "However, expression " << call << " attempted to create view of type " - << TensorStructInfo(ShapeExpr(output_shape.value()), output_dtype) - << " from input array of type " << data_sinfo << ". " + << TensorType(ShapeExpr(output_shape.value()), output_dtype) + << " from input array of type " << data_ty << ". " << "This view would increase the size from " << output_nbytes << " bytes to " << output_nbytes << " bytes."; } - } else if (input_element_size && output_element_size && !view_shape_sinfo) { + } else if (input_element_size && output_element_size && !view_shape_ty) { // The output view has a known dtype, which is different from the // known dtype of the input array. Because the view's shape is // the same as the original array, when counted in number of @@ -256,9 +256,9 @@ StructInfo InferStructInfoView(const Call& call, const BlockBuilder& ctx) { ValueError) << "Operator " << call->op << " may not produce a view that exceeds the bounds of the original array. " - << "In expression " << call << " the data type is changed from " << data_sinfo->dtype - << " to " << view_dtype.value() << ", increasing the size per element from " - << input_element_size << " bytes to " << output_element_size << " bytes. " + << "In expression " << call << " the data type is changed from " << data_ty->dtype << " to " + << view_dtype.value() << ", increasing the size per element from " << input_element_size + << " bytes to " << output_element_size << " bytes. " << "Consider providing a new shape for the R.view."; } else if (input_nelements && output_nelements && !view_dtype) { // The shape is being updated, while keeping the datatype the @@ -275,7 +275,7 @@ StructInfo InferStructInfoView(const Call& call, const BlockBuilder& ctx) { << " (shape = " << input_shape << ", " << input_nelements << " elements) as shape " << output_shape << " with " << output_nelements << " elements."; } - } else if (view_relative_byte_offset && !view_shape_sinfo && !view_dtype) { + } else if (view_relative_byte_offset && !view_shape_ty && !view_dtype) { // The byte_offset is being updated, but neither the shape nor the // dtype is changing. Any non-zero offset will cause the view to // overrun the bounds of the original array. @@ -290,15 +290,15 @@ StructInfo InferStructInfoView(const Call& call, const BlockBuilder& ctx) { } if (output_shape.defined()) { - return TensorStructInfo(ShapeExpr(output_shape.value()), output_dtype, data_sinfo->vdevice); + return TensorType(ShapeExpr(output_shape.value()), output_dtype, data_ty->vdevice); } else { - return TensorStructInfo(output_dtype, output_ndim, data_sinfo->vdevice); + return TensorType(output_dtype, output_ndim, data_ty->vdevice); } } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tvm.relax.struct_info.infer_view_sinfo", InferStructInfoView); + refl::GlobalDef().def("tvm.relax.type.infer_view_ty", InferTypeView); } Expr LowerBuiltinView(const BlockBuilder& bb, const Call& call) { @@ -307,8 +307,7 @@ Expr LowerBuiltinView(const BlockBuilder& bb, const Call& call) { Expr dtype = call->args[2]; Expr relative_byte_offset = call->args[3]; - if (HasVoidStructInfo(shape) && HasVoidStructInfo(dtype) && - HasVoidStructInfo(relative_byte_offset)) { + if (HasVoidType(shape) && HasVoidType(dtype) && HasVoidType(relative_byte_offset)) { // Special-case, no change is required by the view. return data; } @@ -318,37 +317,37 @@ Expr LowerBuiltinView(const BlockBuilder& bb, const Call& call) { // a pass updates the input `data` tensor. However, when we // legalize the `R.view`, we must provide an explicit parameters. - if (HasVoidStructInfo(shape)) { - auto data_shape = data->struct_info_.as().value()->GetShape(); + if (HasVoidType(shape)) { + auto data_shape = data->ty.as().value()->GetShape(); TVM_FFI_ICHECK(data_shape.defined()) << "Legalization of " << call->op << " requires that either the output shape be explicitly specified, " << "or the input shape is known. " << "However, in expression " << call << ", no output shape is specified, " - << "and the input " << data << " of type " << data->struct_info_ << " has unknown shape."; + << "and the input " << data << " of type " << data->ty << " has unknown shape."; shape = ShapeExpr(data_shape.value()); } - if (HasVoidStructInfo(dtype)) { - auto data_dtype = data->struct_info_.as().value()->dtype; + if (HasVoidType(dtype)) { + auto data_dtype = data->ty.as().value()->dtype; TVM_FFI_ICHECK(!data_dtype.is_void()) << "Legalization of " << call->op << " requires that either the output dtype be explicitly specified, " << "or the input dtype is known. " << "However, in expression " << call << ", no output dtype is specified, " - << "and the input " << data << " of type " << data->struct_info_ << " has unknown dtype."; + << "and the input " << data << " of type " << data->ty << " has unknown dtype."; dtype = relax::DataTypeImm(data_dtype); } - if (HasVoidStructInfo(relative_byte_offset)) { + if (HasVoidType(relative_byte_offset)) { relative_byte_offset = relax::PrimValue::Int64(0); } - StructInfoDeriveFunc infer_sinfo_env_func; - infer_sinfo_env_func = EnvFunc::Get("tvm.relax.struct_info.infer_view_sinfo"); - auto runtime_view_sinfo = FuncStructInfo::OpaqueFunc(infer_sinfo_env_func, true); + TypeDeriveFunc infer_ty_env_func; + infer_ty_env_func = EnvFunc::Get("tvm.relax.type.infer_view_ty"); + auto runtime_view_ty = FuncType::OpaqueFunc(infer_ty_env_func, true); - ExternFunc runtime_view_func("runtime.TVMTensorCreateView", runtime_view_sinfo); + ExternFunc runtime_view_func("runtime.TVMTensorCreateView", runtime_view_ty); return Call(runtime_view_func, {data, shape, dtype, relative_byte_offset}); } @@ -361,7 +360,7 @@ TVM_REGISTER_OP("relax.memory.view") .add_argument("relative_byte_offset", "Prim(\"int64\")", "The view's byte offset, relative to the input tensor's byte offset.") .set_attr("RequiresArgumentShapes", false) - .set_attr("FInferStructInfo", InferStructInfoView) + .set_attr("FInferType", InferTypeView) .set_attr("FPurity", true) .set_attr("FLowerBuiltin", LowerBuiltinView); @@ -375,25 +374,25 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.memory.ensure_zero_offset", ensure_zero_offset); } -StructInfo InferStructInfoEnsureZeroOffset(const Call& call, const BlockBuilder& ctx) { +Type InferTypeEnsureZeroOffset(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 1) { TVM_FFI_VISIT_THROW(ValueError, call) << "Operator " << call->op << " should receive 1 argument, " << "but received " << call->args; } - return GetStructInfo(call->args[0]); + return GetType(call->args[0]); } Expr LowerBuiltinEnsureZeroOffset(const BlockBuilder& bb, const Call& call) { const ExternFunc builtin_ensure_zero_offset_{"vm.builtin.ensure_zero_offset"}; - return Call(builtin_ensure_zero_offset_, call->args, Attrs(), {GetStructInfo(call)}); + return Call(builtin_ensure_zero_offset_, call->args, Attrs(), {GetType(call)}); } TVM_REGISTER_OP("relax.memory.ensure_zero_offset") .set_num_inputs(1) .add_argument("x", "Tensor", "The input tensor.") .set_attr("RequiresArgumentShapes", false) - .set_attr("FInferStructInfo", InferStructInfoEnsureZeroOffset) + .set_attr("FInferType", InferTypeEnsureZeroOffset) .set_attr("FPurity", true) .set_attr("FLowerBuiltin", LowerBuiltinEnsureZeroOffset); diff --git a/src/relax/op/nn/attention.cc b/src/relax/op/nn/attention.cc index 5b11cb517bef..58035c75ac2f 100644 --- a/src/relax/op/nn/attention.cc +++ b/src/relax/op/nn/attention.cc @@ -66,24 +66,24 @@ TVM_FFI_STATIC_INIT_BLOCK() { .def("relax.op.nn.attention_var_len", attention_var_len); } -StructInfo InferStructInfoAttention(const Call& call, const BlockBuilder& ctx) { - ffi::Array input_sinfo = GetInputTensorStructInfo(call, ctx); - TensorStructInfo q_sinfo = input_sinfo[0]; - TensorStructInfo k_sinfo = input_sinfo[1]; - TensorStructInfo v_sinfo = input_sinfo[2]; - auto diag_dim = [&](TensorStructInfo sinfo, ffi::String name) { - if (sinfo->ndim != 4) { +Type InferTypeAttention(const Call& call, const BlockBuilder& ctx) { + ffi::Array input_ty = GetInputTensorType(call, ctx); + TensorType q_ty = input_ty[0]; + TensorType k_ty = input_ty[1]; + TensorType v_ty = input_ty[2]; + auto diag_dim = [&](TensorType ty, ffi::String name) { + if (ty->ndim != 4) { TVM_FFI_VISIT_THROW(ValueError, call) << "The " << name << " should have 4 dimension, namely " << "[batch size, sequence length, number of heads, dimension of heads]."; } }; - diag_dim(q_sinfo, "query"); - diag_dim(k_sinfo, "key"); - diag_dim(v_sinfo, "value"); - const ShapeExprNode* q_shape = q_sinfo->shape.as(); - const ShapeExprNode* k_shape = k_sinfo->shape.as(); - const ShapeExprNode* v_shape = v_sinfo->shape.as(); + diag_dim(q_ty, "query"); + diag_dim(k_ty, "key"); + diag_dim(v_ty, "value"); + const ShapeExprNode* q_shape = q_ty->shape.as(); + const ShapeExprNode* k_shape = k_ty->shape.as(); + const ShapeExprNode* v_shape = v_ty->shape.as(); PrimExpr num_batches = q_shape->values[0]; PrimExpr num_queries = q_shape->values[1]; PrimExpr num_heads = q_shape->values[2]; @@ -116,13 +116,13 @@ StructInfo InferStructInfoAttention(const Call& call, const BlockBuilder& ctx) { diag_equal(num_keys, v_shape->values[1], "key", "value", "sequence length"); diag_equal(head_dim, k_shape->values[3], "query", "key", "dimension of heads"); - if (input_sinfo.size() == 4) { - TensorStructInfo bias_sinfo = input_sinfo[3]; - const ShapeExprNode* bias_shape = bias_sinfo->shape.as(); - if (bias_sinfo->ndim != 4) { + if (input_ty.size() == 4) { + TensorType bias_ty = input_ty[3]; + const ShapeExprNode* bias_shape = bias_ty->shape.as(); + if (bias_ty->ndim != 4) { TVM_FFI_VISIT_THROW(ValueError, call) << "The bias should have 4 dimensions." - << "However, the bias input has " << bias_sinfo->ndim << " dimensions."; + << "However, the bias input has " << bias_ty->ndim << " dimensions."; } auto diag_equal_or_broadcast = [&](PrimExpr v1, PrimExpr v2, ffi::String m1, ffi::String m2, ffi::String dim) { @@ -140,7 +140,7 @@ StructInfo InferStructInfoAttention(const Call& call, const BlockBuilder& ctx) { } ffi::Array output_shape = {num_batches, num_queries, num_heads, head_dim_value}; - return TensorStructInfo(ShapeExpr(output_shape), q_sinfo->dtype, q_sinfo->vdevice); + return TensorType(ShapeExpr(output_shape), q_ty->dtype, q_ty->vdevice); } Call InferMixedPrecisionAttention(const Call& call, const DataType& out_dtype) { @@ -156,7 +156,7 @@ TVM_REGISTER_OP("relax.nn.attention") .add_argument("value", "Tensor", "The input values tensor.") .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kAlways) .set_attr("FInferMixedPrecision", InferMixedPrecisionAttention) - .set_attr("FInferStructInfo", InferStructInfoAttention) + .set_attr("FInferType", InferTypeAttention) .set_attr("FPurity", true); TVM_REGISTER_OP("relax.nn.attention_bias") @@ -168,7 +168,7 @@ TVM_REGISTER_OP("relax.nn.attention_bias") .add_argument("bias", "Tensor", "The input bias tensor.") .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kAlways) .set_attr("FInferMixedPrecision", InferMixedPrecisionAttention) - .set_attr("FInferStructInfo", InferStructInfoAttention) + .set_attr("FInferType", InferTypeAttention) .set_attr("FPurity", true); TVM_REGISTER_OP("relax.nn.attention_var_len") @@ -183,7 +183,7 @@ TVM_REGISTER_OP("relax.nn.attention_var_len") .add_argument("max_seqlen_k", "Tensor", "The maximum key sequence length in the batch.") .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kAlways) .set_attr("FInferMixedPrecision", InferMixedPrecisionAttention) - .set_attr("FInferStructInfo", InferStructInfoAttention) + .set_attr("FInferType", InferTypeAttention) .set_attr("FPurity", true); TVM_FFI_STATIC_INIT_BLOCK() { AttentionAttrs::RegisterReflection(); } diff --git a/src/relax/op/nn/convolution.cc b/src/relax/op/nn/convolution.cc index 9fe4a8e84da1..bc354b6804f3 100644 --- a/src/relax/op/nn/convolution.cc +++ b/src/relax/op/nn/convolution.cc @@ -70,10 +70,10 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.nn.conv1d", conv1d); } -StructInfo InferStructInfoConv1d(const Call& call, const BlockBuilder& ctx) { - ffi::Array input_sinfo = GetInputTensorStructInfo(call, ctx); - TensorStructInfo data_sinfo = input_sinfo[0]; - TensorStructInfo weight_sinfo = input_sinfo[1]; +Type InferTypeConv1d(const Call& call, const BlockBuilder& ctx) { + ffi::Array input_ty = GetInputTensorType(call, ctx); + TensorType data_ty = input_ty[0]; + TensorType weight_ty = input_ty[1]; const auto* attrs = call->attrs.as(); auto [data_layout, data2NCW] = CheckTensorLayout(call, ctx, attrs->data_layout, // @@ -87,17 +87,16 @@ StructInfo InferStructInfoConv1d(const Call& call, const BlockBuilder& ctx) { /*tensor_name=*/"output"); ffi::Optional data_shape = - CheckNdimPerLayoutAndGetShape(call, ctx, data_sinfo, data_layout); + CheckNdimPerLayoutAndGetShape(call, ctx, data_ty, data_layout); ffi::Optional weight_shape = - CheckNdimPerLayoutAndGetShape(call, ctx, weight_sinfo, weight_layout); + CheckNdimPerLayoutAndGetShape(call, ctx, weight_ty, weight_layout); DataType out_dtype = attrs->out_dtype.is_void() - ? InferBinaryArithOpOutDtype(call, ctx, data_sinfo, weight_sinfo) + ? InferBinaryArithOpOutDtype(call, ctx, data_ty, weight_ty) : attrs->out_dtype; - ffi::Optional vdevice = - InferBinaryArithOpOutVDevice(call, ctx, data_sinfo, weight_sinfo); + ffi::Optional vdevice = InferBinaryArithOpOutVDevice(call, ctx, data_ty, weight_ty); if (!data_shape.defined() || !weight_shape.defined()) { - return TensorStructInfo(out_dtype, out_layout.ndim(), vdevice); + return TensorType(out_dtype, out_layout.ndim(), vdevice); } ffi::Array data_NCW_shape = data2NCW.ForwardShape(data_shape.value()->values); @@ -141,7 +140,7 @@ StructInfo InferStructInfoConv1d(const Call& call, const BlockBuilder& ctx) { analyzer->Simplify(floordiv(numerator_w, IntImm::Int32(attrs->strides[0])) + 1); ffi::Array out_shape = out2NCW.BackwardShape(out_NCW_shape); - return TensorStructInfo(ShapeExpr(out_shape), out_dtype, vdevice); + return TensorType(ShapeExpr(out_shape), out_dtype, vdevice); } InferLayoutOutput InferLayoutConv1d( @@ -200,7 +199,7 @@ TVM_REGISTER_OP("relax.nn.conv1d") .add_argument("data", "Tensor", "The input tensor.") .add_argument("weight", "Tensor", "The weight tensor.") .set_attrs_type() - .set_attr("FInferStructInfo", InferStructInfoConv1d) + .set_attr("FInferType", InferTypeConv1d) .set_attr("FRelaxInferLayout", InferLayoutConv1d) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kAlways) .set_attr("FInferMixedPrecision", InferMixedPrecisionConv1d) @@ -240,10 +239,10 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.nn.conv2d", conv2d); } -StructInfo InferStructInfoConv2d(const Call& call, const BlockBuilder& ctx) { - ffi::Array input_sinfo = GetInputTensorStructInfo(call, ctx); - TensorStructInfo data_sinfo = input_sinfo[0]; - TensorStructInfo weight_sinfo = input_sinfo[1]; +Type InferTypeConv2d(const Call& call, const BlockBuilder& ctx) { + ffi::Array input_ty = GetInputTensorType(call, ctx); + TensorType data_ty = input_ty[0]; + TensorType weight_ty = input_ty[1]; const auto* attrs = call->attrs.as(); auto [data_layout, data2NCHW] = CheckTensorLayout(call, ctx, attrs->data_layout, // @@ -257,17 +256,16 @@ StructInfo InferStructInfoConv2d(const Call& call, const BlockBuilder& ctx) { /*tensor_name=*/"output"); ffi::Optional data_shape = - CheckNdimPerLayoutAndGetShape(call, ctx, data_sinfo, data_layout); + CheckNdimPerLayoutAndGetShape(call, ctx, data_ty, data_layout); ffi::Optional weight_shape = - CheckNdimPerLayoutAndGetShape(call, ctx, weight_sinfo, weight_layout); + CheckNdimPerLayoutAndGetShape(call, ctx, weight_ty, weight_layout); DataType out_dtype = attrs->out_dtype.is_void() - ? InferBinaryArithOpOutDtype(call, ctx, data_sinfo, weight_sinfo) + ? InferBinaryArithOpOutDtype(call, ctx, data_ty, weight_ty) : attrs->out_dtype; - ffi::Optional vdevice = - InferBinaryArithOpOutVDevice(call, ctx, data_sinfo, weight_sinfo); + ffi::Optional vdevice = InferBinaryArithOpOutVDevice(call, ctx, data_ty, weight_ty); if (!data_shape.defined() || !weight_shape.defined()) { - return TensorStructInfo(out_dtype, out_layout.ndim(), vdevice); + return TensorType(out_dtype, out_layout.ndim(), vdevice); } ffi::Array data_NCHW_shape = data2NCHW.ForwardShape(data_shape.value()->values); @@ -318,7 +316,7 @@ StructInfo InferStructInfoConv2d(const Call& call, const BlockBuilder& ctx) { analyzer->Simplify(floordiv(numerator_w, IntImm::Int32(attrs->strides[1])) + 1); ffi::Array out_shape = out2NCHW.BackwardShape(out_NCHW_shape); - return TensorStructInfo(ShapeExpr(out_shape), out_dtype, vdevice); + return TensorType(ShapeExpr(out_shape), out_dtype, vdevice); } InferLayoutOutput InferLayoutConv2d( @@ -355,14 +353,14 @@ InferLayoutOutput InferLayoutConv2d( return InferLayoutOutput({data_layout, weight_layout}, {output_layout}, Attrs(new_attrs)); } else { // Layout Transform - auto data_si = GetStructInfo(call->args[0]); - auto kernel_si = GetStructInfo(call->args[1]); - TensorStructInfo data_sinfo = data_si.as().value(); - TensorStructInfo kernel_sinfo = kernel_si.as().value(); + auto data_si = GetType(call->args[0]); + auto kernel_si = GetType(call->args[1]); + TensorType data_ty = data_si.as().value(); + TensorType kernel_ty = kernel_si.as().value(); ffi::Optional data_shape = - ffi::GetRef(data_sinfo->shape.as()); + ffi::GetRef(data_ty->shape.as()); ffi::Optional kernel_shape = - ffi::GetRef(kernel_sinfo->shape.as()); + ffi::GetRef(kernel_ty->shape.as()); bool can_data_proved = CanProveLayoutTransform(input_layout, desired_data_layout, data_shape.value()->values); @@ -411,7 +409,7 @@ TVM_REGISTER_OP("relax.nn.conv2d") .add_argument("data", "Tensor", "The input tensor.") .add_argument("weight", "Tensor", "The weight tensor.") .set_attrs_type() - .set_attr("FInferStructInfo", InferStructInfoConv2d) + .set_attr("FInferType", InferTypeConv2d) .set_attr("FRelaxInferLayout", InferLayoutConv2d) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kAlways) .set_attr("FInferMixedPrecision", InferMixedPrecisionConv2d) @@ -453,10 +451,10 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.nn.conv3d", conv3d); } -StructInfo InferStructInfoConv3d(const Call& call, const BlockBuilder& ctx) { - ffi::Array input_sinfo = GetInputTensorStructInfo(call, ctx); - TensorStructInfo data_sinfo = input_sinfo[0]; - TensorStructInfo weight_sinfo = input_sinfo[1]; +Type InferTypeConv3d(const Call& call, const BlockBuilder& ctx) { + ffi::Array input_ty = GetInputTensorType(call, ctx); + TensorType data_ty = input_ty[0]; + TensorType weight_ty = input_ty[1]; const auto* attrs = call->attrs.as(); auto [data_layout, data2NCDHW] = CheckTensorLayout(call, ctx, attrs->data_layout, // @@ -470,17 +468,16 @@ StructInfo InferStructInfoConv3d(const Call& call, const BlockBuilder& ctx) { /*tensor_name=*/"output"); ffi::Optional data_shape = - CheckNdimPerLayoutAndGetShape(call, ctx, data_sinfo, data_layout); + CheckNdimPerLayoutAndGetShape(call, ctx, data_ty, data_layout); ffi::Optional weight_shape = - CheckNdimPerLayoutAndGetShape(call, ctx, weight_sinfo, weight_layout); + CheckNdimPerLayoutAndGetShape(call, ctx, weight_ty, weight_layout); DataType out_dtype = attrs->out_dtype.is_void() - ? InferBinaryArithOpOutDtype(call, ctx, data_sinfo, weight_sinfo) + ? InferBinaryArithOpOutDtype(call, ctx, data_ty, weight_ty) : attrs->out_dtype; - ffi::Optional vdevice = - InferBinaryArithOpOutVDevice(call, ctx, data_sinfo, weight_sinfo); + ffi::Optional vdevice = InferBinaryArithOpOutVDevice(call, ctx, data_ty, weight_ty); if (!data_shape.defined() || !weight_shape.defined()) { - return TensorStructInfo(out_dtype, out_layout.ndim(), vdevice); + return TensorType(out_dtype, out_layout.ndim(), vdevice); } ffi::Array data_NCDHW_shape = data2NCDHW.ForwardShape(data_shape.value()->values); @@ -538,7 +535,7 @@ StructInfo InferStructInfoConv3d(const Call& call, const BlockBuilder& ctx) { analyzer->Simplify(floordiv(numerator_w, IntImm::Int32(attrs->strides[2])) + 1); ffi::Array out_shape = out2NCDHW.BackwardShape(out_NCDHW_shape); - return TensorStructInfo(ShapeExpr(out_shape), out_dtype, vdevice); + return TensorType(ShapeExpr(out_shape), out_dtype, vdevice); } InferLayoutOutput InferLayoutConv3d( @@ -597,7 +594,7 @@ TVM_REGISTER_OP("relax.nn.conv3d") .add_argument("data", "Tensor", "The input tensor.") .add_argument("weight", "Tensor", "The weight tensor.") .set_attrs_type() - .set_attr("FInferStructInfo", InferStructInfoConv3d) + .set_attr("FInferType", InferTypeConv3d) .set_attr("FRelaxInferLayout", InferLayoutConv3d) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kAlways) .set_attr("FInferMixedPrecision", InferMixedPrecisionConv3d) @@ -643,10 +640,10 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.nn.conv1d_transpose", conv1d_transpose); } -StructInfo InferStructInfoConv1dTranspose(const Call& call, const BlockBuilder& ctx) { - ffi::Array input_sinfo = GetInputTensorStructInfo(call, ctx); - TensorStructInfo data_sinfo = input_sinfo[0]; - TensorStructInfo weight_sinfo = input_sinfo[1]; +Type InferTypeConv1dTranspose(const Call& call, const BlockBuilder& ctx) { + ffi::Array input_ty = GetInputTensorType(call, ctx); + TensorType data_ty = input_ty[0]; + TensorType weight_ty = input_ty[1]; const auto* attrs = call->attrs.as(); auto [data_layout, data2NCW] = CheckTensorLayout(call, ctx, attrs->data_layout, // @@ -659,17 +656,16 @@ StructInfo InferStructInfoConv1dTranspose(const Call& call, const BlockBuilder& /*tgt_layout=*/"NCW", // /*tensor_name=*/"output"); ffi::Optional data_shape = - CheckNdimPerLayoutAndGetShape(call, ctx, data_sinfo, data_layout); + CheckNdimPerLayoutAndGetShape(call, ctx, data_ty, data_layout); ffi::Optional weight_shape = - CheckNdimPerLayoutAndGetShape(call, ctx, weight_sinfo, weight_layout); + CheckNdimPerLayoutAndGetShape(call, ctx, weight_ty, weight_layout); DataType out_dtype = attrs->out_dtype.is_void() - ? InferBinaryArithOpOutDtype(call, ctx, data_sinfo, weight_sinfo) + ? InferBinaryArithOpOutDtype(call, ctx, data_ty, weight_ty) : attrs->out_dtype; - ffi::Optional vdevice = - InferBinaryArithOpOutVDevice(call, ctx, data_sinfo, weight_sinfo); + ffi::Optional vdevice = InferBinaryArithOpOutVDevice(call, ctx, data_ty, weight_ty); if (!data_shape.defined() || !weight_shape.defined()) { - return TensorStructInfo(out_dtype, out_layout.ndim(), vdevice); + return TensorType(out_dtype, out_layout.ndim(), vdevice); } ffi::Array data_NCW_shape = data2NCW.ForwardShape(data_shape.value()->values); @@ -721,7 +717,7 @@ StructInfo InferStructInfoConv1dTranspose(const Call& call, const BlockBuilder& out_NCW_shape[2] = analyzer->Simplify(out_w); ffi::Array out_shape = out2NCW.BackwardShape(out_NCW_shape); - return TensorStructInfo(ShapeExpr(out_shape), out_dtype, vdevice); + return TensorType(ShapeExpr(out_shape), out_dtype, vdevice); } InferLayoutOutput InferLayoutConv1dTranspose( @@ -777,7 +773,7 @@ TVM_REGISTER_OP("relax.nn.conv1d_transpose") .add_argument("data", "Tensor", "The input tensor.") .add_argument("weight", "Tensor", "The weight tensor.") .set_attrs_type() - .set_attr("FInferStructInfo", InferStructInfoConv1dTranspose) + .set_attr("FInferType", InferTypeConv1dTranspose) .set_attr("FRelaxInferLayout", InferLayoutConv1dTranspose) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kAlways) .set_attr("FInferMixedPrecision", InferMixedPrecisionConv1dTranspose) @@ -834,10 +830,10 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.nn.conv2d_transpose", conv2d_transpose); } -StructInfo InferStructInfoConv2dTranspose(const Call& call, const BlockBuilder& ctx) { - ffi::Array input_sinfo = GetInputTensorStructInfo(call, ctx); - TensorStructInfo data_sinfo = input_sinfo[0]; - TensorStructInfo weight_sinfo = input_sinfo[1]; +Type InferTypeConv2dTranspose(const Call& call, const BlockBuilder& ctx) { + ffi::Array input_ty = GetInputTensorType(call, ctx); + TensorType data_ty = input_ty[0]; + TensorType weight_ty = input_ty[1]; const auto* attrs = call->attrs.as(); auto [data_layout, data2NCHW] = CheckTensorLayout(call, ctx, attrs->data_layout, // @@ -851,17 +847,16 @@ StructInfo InferStructInfoConv2dTranspose(const Call& call, const BlockBuilder& /*tensor_name=*/"output"); ffi::Optional data_shape = - CheckNdimPerLayoutAndGetShape(call, ctx, data_sinfo, data_layout); + CheckNdimPerLayoutAndGetShape(call, ctx, data_ty, data_layout); ffi::Optional weight_shape = - CheckNdimPerLayoutAndGetShape(call, ctx, weight_sinfo, weight_layout); + CheckNdimPerLayoutAndGetShape(call, ctx, weight_ty, weight_layout); DataType out_dtype = attrs->out_dtype.is_void() - ? InferBinaryArithOpOutDtype(call, ctx, data_sinfo, weight_sinfo) + ? InferBinaryArithOpOutDtype(call, ctx, data_ty, weight_ty) : attrs->out_dtype; - ffi::Optional vdevice = - InferBinaryArithOpOutVDevice(call, ctx, data_sinfo, weight_sinfo); + ffi::Optional vdevice = InferBinaryArithOpOutVDevice(call, ctx, data_ty, weight_ty); if (!data_shape.defined() || !weight_shape.defined()) { - return TensorStructInfo(out_dtype, out_layout.ndim(), vdevice); + return TensorType(out_dtype, out_layout.ndim(), vdevice); } ffi::Array data_NCHW_shape = data2NCHW.ForwardShape(data_shape.value()->values); @@ -918,7 +913,7 @@ StructInfo InferStructInfoConv2dTranspose(const Call& call, const BlockBuilder& out_NCHW_shape[3] = analyzer->Simplify(out_w); ffi::Array out_shape = out2NCHW.BackwardShape(out_NCHW_shape); - return TensorStructInfo(ShapeExpr(out_shape), out_dtype, vdevice); + return TensorType(ShapeExpr(out_shape), out_dtype, vdevice); } InferLayoutOutput InferLayoutConv2dTranspose( @@ -951,14 +946,14 @@ InferLayoutOutput InferLayoutConv2dTranspose( new_attrs->out_layout = (*it).second.size() == 3 ? (*it).second[2] : (*it).second[0]; return InferLayoutOutput({data_layout, weight_layout}, {output_layout}, Attrs(new_attrs)); } else { - auto data_si = GetStructInfo(call->args[0]); - auto kernel_si = GetStructInfo(call->args[1]); - TensorStructInfo data_sinfo = data_si.as().value(); - TensorStructInfo kernel_sinfo = kernel_si.as().value(); + auto data_si = GetType(call->args[0]); + auto kernel_si = GetType(call->args[1]); + TensorType data_ty = data_si.as().value(); + TensorType kernel_ty = kernel_si.as().value(); ffi::Optional data_shape = - ffi::GetRef(data_sinfo->shape.as()); + ffi::GetRef(data_ty->shape.as()); ffi::Optional kernel_shape = - ffi::GetRef(kernel_sinfo->shape.as()); + ffi::GetRef(kernel_ty->shape.as()); bool can_data_proved = CanProveLayoutTransform(input_layout, desired_data_layout, data_shape.value()->values); @@ -1006,7 +1001,7 @@ TVM_REGISTER_OP("relax.nn.conv2d_transpose") .add_argument("data", "Tensor", "The input tensor.") .add_argument("weight", "Tensor", "The weight tensor.") .set_attrs_type() - .set_attr("FInferStructInfo", InferStructInfoConv2dTranspose) + .set_attr("FInferType", InferTypeConv2dTranspose) .set_attr("FRelaxInferLayout", InferLayoutConv2dTranspose) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kAlways) .set_attr("FInferMixedPrecision", InferMixedPrecisionConv2dTranspose) @@ -1066,10 +1061,10 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.nn.conv3d_transpose", conv3d_transpose); } -StructInfo InferStructInfoConv3dTranspose(const Call& call, const BlockBuilder& ctx) { - ffi::Array input_sinfo = GetInputTensorStructInfo(call, ctx); - TensorStructInfo data_sinfo = input_sinfo[0]; - TensorStructInfo weight_sinfo = input_sinfo[1]; +Type InferTypeConv3dTranspose(const Call& call, const BlockBuilder& ctx) { + ffi::Array input_ty = GetInputTensorType(call, ctx); + TensorType data_ty = input_ty[0]; + TensorType weight_ty = input_ty[1]; const auto* attrs = call->attrs.as(); auto [data_layout, data2NCDHW] = CheckTensorLayout(call, ctx, attrs->data_layout, // @@ -1083,17 +1078,16 @@ StructInfo InferStructInfoConv3dTranspose(const Call& call, const BlockBuilder& /*tensor_name=*/"output"); ffi::Optional data_shape = - CheckNdimPerLayoutAndGetShape(call, ctx, data_sinfo, data_layout); + CheckNdimPerLayoutAndGetShape(call, ctx, data_ty, data_layout); ffi::Optional weight_shape = - CheckNdimPerLayoutAndGetShape(call, ctx, weight_sinfo, weight_layout); + CheckNdimPerLayoutAndGetShape(call, ctx, weight_ty, weight_layout); DataType out_dtype = attrs->out_dtype.is_void() - ? InferBinaryArithOpOutDtype(call, ctx, data_sinfo, weight_sinfo) + ? InferBinaryArithOpOutDtype(call, ctx, data_ty, weight_ty) : attrs->out_dtype; - ffi::Optional vdevice = - InferBinaryArithOpOutVDevice(call, ctx, data_sinfo, weight_sinfo); + ffi::Optional vdevice = InferBinaryArithOpOutVDevice(call, ctx, data_ty, weight_ty); if (!data_shape.defined() || !weight_shape.defined()) { - return TensorStructInfo(out_dtype, out_layout.ndim(), vdevice); + return TensorType(out_dtype, out_layout.ndim(), vdevice); } ffi::Array data_NCDHW_shape = data2NCDHW.ForwardShape(data_shape.value()->values); @@ -1158,7 +1152,7 @@ StructInfo InferStructInfoConv3dTranspose(const Call& call, const BlockBuilder& out_NCDHW_shape[4] = analyzer->Simplify(out_w); ffi::Array out_shape = out2NCDHW.BackwardShape(out_NCDHW_shape); - return TensorStructInfo(ShapeExpr(out_shape), out_dtype, vdevice); + return TensorType(ShapeExpr(out_shape), out_dtype, vdevice); } InferLayoutOutput InferLayoutConv3dTranspose( @@ -1191,14 +1185,14 @@ InferLayoutOutput InferLayoutConv3dTranspose( new_attrs->out_layout = (*it).second.size() == 3 ? (*it).second[2] : (*it).second[0]; return InferLayoutOutput({data_layout, weight_layout}, {output_layout}, Attrs(new_attrs)); } else { - auto data_si = GetStructInfo(call->args[0]); - auto kernel_si = GetStructInfo(call->args[1]); - TensorStructInfo data_sinfo = data_si.as().value(); - TensorStructInfo kernel_sinfo = kernel_si.as().value(); + auto data_si = GetType(call->args[0]); + auto kernel_si = GetType(call->args[1]); + TensorType data_ty = data_si.as().value(); + TensorType kernel_ty = kernel_si.as().value(); ffi::Optional data_shape = - ffi::GetRef(data_sinfo->shape.as()); + ffi::GetRef(data_ty->shape.as()); ffi::Optional kernel_shape = - ffi::GetRef(kernel_sinfo->shape.as()); + ffi::GetRef(kernel_ty->shape.as()); bool can_data_proved = CanProveLayoutTransform(input_layout, desired_data_layout, data_shape.value()->values); @@ -1246,7 +1240,7 @@ TVM_REGISTER_OP("relax.nn.conv3d_transpose") .add_argument("data", "Tensor", "The input tensor.") .add_argument("weight", "Tensor", "The weight tensor.") .set_attrs_type() - .set_attr("FInferStructInfo", InferStructInfoConv3dTranspose) + .set_attr("FInferType", InferTypeConv3dTranspose) .set_attr("FRelaxInferLayout", InferLayoutConv3dTranspose) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kAlways) .set_attr("FInferMixedPrecision", InferMixedPrecisionConv3dTranspose) diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc index 0363da159b03..f1ac80bcddb1 100644 --- a/src/relax/op/nn/nn.cc +++ b/src/relax/op/nn/nn.cc @@ -78,8 +78,7 @@ TVM_REGISTER_OP("relax.nn.leakyrelu") .set_num_inputs(1) .add_argument("data", "Tensor", "The input tensor.") .set_attrs_type() - .set_attr("FInferStructInfo", - InferStructInfoUnaryArith) + .set_attr("FInferType", InferTypeUnaryArith) .set_attr("FPurity", true); /* relax.nn.softplus */ @@ -101,8 +100,7 @@ TVM_REGISTER_OP("relax.nn.softplus") .set_num_inputs(1) .add_argument("data", "Tensor", "The input tensor.") .set_attrs_type() - .set_attr("FInferStructInfo", - InferStructInfoUnaryArith) + .set_attr("FInferType", InferTypeUnaryArith) .set_attr("FPurity", true); /* relax.nn.prelu */ @@ -119,20 +117,20 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.nn.prelu", prelu); } -StructInfo InferStructInfoPRelu(const Call& call, const BlockBuilder& ctx) { - TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); - if (data_sinfo->IsUnknownNdim()) { - return data_sinfo; +Type InferTypePRelu(const Call& call, const BlockBuilder& ctx) { + TensorType data_ty = GetUnaryInputTensorType(call, ctx); + if (data_ty->IsUnknownNdim()) { + return data_ty; } - if (!data_sinfo->IsUnknownDtype() && !data_sinfo->dtype.is_float()) { + if (!data_ty->IsUnknownDtype() && !data_ty->dtype.is_float()) { TVM_FFI_VISIT_THROW(TypeError, call) << "Prelu requires the input tensor to have float " "dtype. However, the given input dtype is " - << data_sinfo->dtype; + << data_ty->dtype; } const auto* attrs = call->attrs.as(); - NormalizeAxis(call, ctx, data_sinfo->ndim, attrs->axis); + NormalizeAxis(call, ctx, data_ty->ndim, attrs->axis); - return data_sinfo; + return data_ty; } InferLayoutOutput InferLayoutPRelu( @@ -146,10 +144,10 @@ InferLayoutOutput InferLayoutPRelu( // TODO(Siva): We could handle if the axis is not the sub indexed one. if (layout->layout.ndim() != layout->layout.ndim_primal()) { - const auto* tensor_sinfo = GetStructInfoAs(call->args[0]); - TVM_FFI_ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; - TVM_FFI_ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support static ndim for now"; - int ndim = tensor_sinfo->ndim; + const auto* tensor_ty = GetTypeAs(call->args[0]); + TVM_FFI_ICHECK(tensor_ty != nullptr) << "Invalid Call"; + TVM_FFI_ICHECK(!tensor_ty->IsUnknownNdim()) << "Only support static ndim for now"; + int ndim = tensor_ty->ndim; layout = LayoutDecision(InitialLayout(ndim)); } @@ -165,7 +163,7 @@ TVM_REGISTER_OP("relax.nn.prelu") .add_argument("data", "Tensor", "The input tensor.") .add_argument("alpha", "Tensor", "The channel-wise learnable slope.") .set_attrs_type() - .set_attr("FInferStructInfo", InferStructInfoPRelu) + .set_attr("FInferType", InferTypePRelu) .set_attr("FRelaxInferLayout", InferLayoutPRelu) .set_attr("FPurity", true); @@ -183,21 +181,20 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.nn.softmax", softmax); } -StructInfo InferStructInfoSoftmax(const Call& call, const BlockBuilder& ctx) { - TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); - if (data_sinfo->IsUnknownNdim()) { - return data_sinfo; +Type InferTypeSoftmax(const Call& call, const BlockBuilder& ctx) { + TensorType data_ty = GetUnaryInputTensorType(call, ctx); + if (data_ty->IsUnknownNdim()) { + return data_ty; } - if (!data_sinfo->IsUnknownDtype() && !data_sinfo->dtype.is_float() && - !data_sinfo->dtype.is_bfloat()) { + if (!data_ty->IsUnknownDtype() && !data_ty->dtype.is_float() && !data_ty->dtype.is_bfloat()) { TVM_FFI_VISIT_THROW(TypeError, call) << "Softmax requires the input tensor to have float " "dtype. However, the given input dtype is " - << data_sinfo->dtype; + << data_ty->dtype; } const auto* attrs = call->attrs.as(); - NormalizeAxis(call, ctx, data_sinfo->ndim, attrs->axis); + NormalizeAxis(call, ctx, data_ty->ndim, attrs->axis); - return data_sinfo; + return data_ty; } InferLayoutOutput InferLayoutSoftmax( @@ -211,10 +208,10 @@ InferLayoutOutput InferLayoutSoftmax( // TODO(Siva): We could handle if the axis is not the sub indexed one. if (layout->layout.ndim() != layout->layout.ndim_primal()) { - const auto* tensor_sinfo = GetStructInfoAs(call->args[0]); - TVM_FFI_ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; - TVM_FFI_ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support static ndim for now"; - int ndim = tensor_sinfo->ndim; + const auto* tensor_ty = GetTypeAs(call->args[0]); + TVM_FFI_ICHECK(tensor_ty != nullptr) << "Invalid Call"; + TVM_FFI_ICHECK(!tensor_ty->IsUnknownNdim()) << "Only support static ndim for now"; + int ndim = tensor_ty->ndim; layout = LayoutDecision(InitialLayout(ndim)); } @@ -227,7 +224,7 @@ TVM_REGISTER_OP("relax.nn.softmax") .set_num_inputs(1) .add_argument("data", "Tensor", "The input tensor.") .set_attrs_type() - .set_attr("FInferStructInfo", InferStructInfoSoftmax) + .set_attr("FInferType", InferTypeSoftmax) .set_attr("FRelaxInferLayout", InferLayoutSoftmax) .set_attr("FPurity", true); @@ -248,7 +245,7 @@ TVM_REGISTER_OP("relax.nn.log_softmax") .set_num_inputs(1) .add_argument("data", "Tensor", "The input tensor.") .set_attrs_type() - .set_attr("FInferStructInfo", InferStructInfoSoftmax) + .set_attr("FInferType", InferTypeSoftmax) .set_attr("FPurity", true); /* relax.nn.pad */ @@ -267,17 +264,17 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.nn.pad", pad); } -StructInfo InferStructInfoPad(const Call& call, const BlockBuilder& ctx) { - ffi::Array input_sinfo = GetInputTensorStructInfo(call, ctx); +Type InferTypePad(const Call& call, const BlockBuilder& ctx) { + ffi::Array input_ty = GetInputTensorType(call, ctx); const auto* attrs = call->attrs.as(); - int ndim = input_sinfo[0]->ndim; + int ndim = input_ty[0]->ndim; ffi::Array pad_width = attrs->pad_width; TVM_FFI_ICHECK(static_cast(pad_width.size()) == 2 * ndim) << "Illegal pad_width"; ffi::Array out_shape; - if (input_sinfo[0]->shape.defined()) { + if (input_ty[0]->shape.defined()) { // Compute output shape by adding corresponding pad width to each axis. - const auto* data_shape = input_sinfo[0]->shape.as(); + const auto* data_shape = input_ty[0]->shape.as(); for (int i = 0; i < ndim; i++) { // Sum pad width for this axis. PrimExpr added_width = IntImm::Int64(pad_width[2 * i] + pad_width[(2 * i) + 1]); @@ -286,16 +283,16 @@ StructInfo InferStructInfoPad(const Call& call, const BlockBuilder& ctx) { } } else { // Shape isnt defined, best we can do is return ndim and dtype. - return TensorStructInfo(input_sinfo[0]->dtype, ndim); + return TensorType(input_ty[0]->dtype, ndim); } - return TensorStructInfo(ShapeExpr(out_shape), input_sinfo[0]->dtype); + return TensorType(ShapeExpr(out_shape), input_ty[0]->dtype); } TVM_REGISTER_OP("relax.nn.pad") .set_num_inputs(1) .add_argument("data", "Tensor", "The input tensor.") .set_attrs_type() - .set_attr("FInferStructInfo", InferStructInfoPad) + .set_attr("FInferType", InferTypePad) .set_attr("FPurity", true); /* relax.nn.pixel_shuffle */ @@ -312,18 +309,18 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.nn.pixel_shuffle", pixel_shuffle); } -StructInfo InferStructInfoPixelShuffle(const Call& call, const BlockBuilder& ctx) { - ffi::Array input_sinfo = GetInputTensorStructInfo(call, ctx); +Type InferTypePixelShuffle(const Call& call, const BlockBuilder& ctx) { + ffi::Array input_ty = GetInputTensorType(call, ctx); const auto* attrs = call->attrs.as(); int r = attrs->upscale_factor; TVM_FFI_ICHECK_GT(r, 0) << "Upscale factor must be positive"; - const TensorStructInfo& input = input_sinfo[0]; + const TensorType& input = input_ty[0]; int ndim = input->ndim; TVM_FFI_ICHECK_GE(ndim, 3) << "PixelShuffle requires at least 3D input tensor"; if (!input->shape.defined()) { - return TensorStructInfo(input->dtype, ndim); + return TensorType(input->dtype, ndim); } const auto* shape = input->shape.as(); @@ -360,54 +357,52 @@ StructInfo InferStructInfoPixelShuffle(const Call& call, const BlockBuilder& ctx } } - return TensorStructInfo(ShapeExpr(out_shape), input->dtype); + return TensorType(ShapeExpr(out_shape), input->dtype); } TVM_REGISTER_OP("relax.nn.pixel_shuffle") .set_num_inputs(1) .add_argument("data", "Tensor", "The input tensor.") .set_attrs_type() - .set_attr("FInferStructInfo", InferStructInfoPixelShuffle) + .set_attr("FInferType", InferTypePixelShuffle) .set_attr("FPurity", true); /* relax.nn.batchnorm */ bool NormCheckDtypeAndShape(const Call& call, const BlockBuilder& ctx, - const ffi::Array& input_sinfo, - ffi::Array axes) { + const ffi::Array& input_ty, ffi::Array axes) { Op op = Downcast(call->op); int n_input = op->arguments.size(); - TensorStructInfo data_sinfo = input_sinfo[0]; + TensorType data_ty = input_ty[0]; std::vector axes_non_neg; - if (!data_sinfo->IsUnknownNdim()) { - axes_non_neg = NormalizeAxes(call, ctx, data_sinfo->ndim, axes); + if (!data_ty->IsUnknownNdim()) { + axes_non_neg = NormalizeAxes(call, ctx, data_ty->ndim, axes); } int n_axis = axes.size(); - if (!data_sinfo->IsUnknownDtype() && - (!data_sinfo->dtype.is_float() && !data_sinfo->dtype.is_bfloat())) { + if (!data_ty->IsUnknownDtype() && (!data_ty->dtype.is_float() && !data_ty->dtype.is_bfloat())) { TVM_FFI_VISIT_THROW(TypeError, call) << op << " requires the input data to have float dtype. However, the given data dtype is " - << data_sinfo->dtype; + << data_ty->dtype; } for (int i = 1; i < n_input; ++i) { - if (input_sinfo[i]->dtype != data_sinfo->dtype) { + if (input_ty[i]->dtype != data_ty->dtype) { TVM_FFI_VISIT_THROW(TypeError, call) << op << " requires all the input tensors to have the same dtype. However, the " - << op->arguments[i]->name << " has dtype " << input_sinfo[i]->dtype - << " which is other than the input data's dtype " << data_sinfo->dtype; - } else if (input_sinfo[i]->ndim != n_axis) { + << op->arguments[i]->name << " has dtype " << input_ty[i]->dtype + << " which is other than the input data's dtype " << data_ty->dtype; + } else if (input_ty[i]->ndim != n_axis) { TVM_FFI_VISIT_THROW(ValueError, call) << op << " requires the input " << op->arguments[i]->name << " to have as many dimensions as the length of input axes. However, the " "given one has ndim " - << input_sinfo[i]->ndim << ", which is other than the length of axes " << n_axis; + << input_ty[i]->ndim << ", which is other than the length of axes " << n_axis; } } std::vector> axis_lengths; axis_lengths.reserve(n_input); - if (const auto* data_shape = data_sinfo->shape.as()) { + if (const auto* data_shape = data_ty->shape.as()) { std::vector lengths; lengths.reserve(n_axis); for (int d = 0; d < n_axis; ++d) { @@ -416,7 +411,7 @@ bool NormCheckDtypeAndShape(const Call& call, const BlockBuilder& ctx, axis_lengths.push_back(lengths); } for (int i = 1; i < n_input; ++i) { - if (const auto* shape = input_sinfo[i]->shape.as()) { + if (const auto* shape = input_ty[i]->shape.as()) { axis_lengths.push_back(shape->values); } } @@ -461,20 +456,19 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.nn.batch_norm", batch_norm); } -StructInfo InferStructInfoBatchNorm(const Call& call, const BlockBuilder& ctx) { - ffi::Array input_sinfo = GetInputTensorStructInfo(call, ctx); +Type InferTypeBatchNorm(const Call& call, const BlockBuilder& ctx) { + ffi::Array input_ty = GetInputTensorType(call, ctx); const auto* attrs = call->attrs.as(); - bool unknown_shape = NormCheckDtypeAndShape(call, ctx, input_sinfo, {attrs->axis}); + bool unknown_shape = NormCheckDtypeAndShape(call, ctx, input_ty, {attrs->axis}); - DataType dtype = input_sinfo[0]->dtype; + DataType dtype = input_ty[0]->dtype; if (unknown_shape) { - auto vdev = input_sinfo[0]->vdevice; - return TupleStructInfo({TensorStructInfo(dtype, input_sinfo[0]->ndim, vdev), - TensorStructInfo(dtype, /*ndim=*/1, vdev), - TensorStructInfo(dtype, /*ndim=*/1, vdev)}); + auto vdev = input_ty[0]->vdevice; + return TupleType({TensorType(dtype, input_ty[0]->ndim, vdev), + TensorType(dtype, /*ndim=*/1, vdev), TensorType(dtype, /*ndim=*/1, vdev)}); } else { - return TupleStructInfo({input_sinfo[0], input_sinfo[3], input_sinfo[4]}); + return TupleType({input_ty[0], input_ty[3], input_ty[4]}); } } @@ -484,10 +478,10 @@ InferLayoutOutput InferLayoutBatchNorm( TVM_FFI_ICHECK(NoDesiredLayout(call, desired_layouts)); std::vector initial_layouts; for (size_t i = 0; i < 5; ++i) { - const auto* tensor_sinfo = GetStructInfoAs(call->args[i]); - TVM_FFI_ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; - TVM_FFI_ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support known ndim"; - initial_layouts.push_back(InitialLayoutDecision(tensor_sinfo->ndim)); + const auto* tensor_ty = GetTypeAs(call->args[i]); + TVM_FFI_ICHECK(tensor_ty != nullptr) << "Invalid Call"; + TVM_FFI_ICHECK(!tensor_ty->IsUnknownNdim()) << "Only support known ndim"; + initial_layouts.push_back(InitialLayoutDecision(tensor_ty->ndim)); } const auto* attrs = call->attrs.as(); TVM_FFI_ICHECK(attrs) << "Invalid Call"; @@ -496,8 +490,8 @@ InferLayoutOutput InferLayoutBatchNorm( // While dealing with sub layouts, its adviced to deal with batchnorm // on other ways like decomposing or fusion methods. // This handling is fail safe fallback. - const auto* input_sinfo = GetStructInfoAs(call->args[0]); - int ndim = input_sinfo->ndim; + const auto* input_ty = GetTypeAs(call->args[0]); + int ndim = input_ty->ndim; if (layout->layout.ndim() != layout->layout.ndim_primal()) { layout = LayoutDecision(InitialLayout(ndim)); } @@ -517,7 +511,7 @@ TVM_REGISTER_OP("relax.nn.batch_norm") .add_argument("beta", "Tensor", "The beta offset factor.") .add_argument("moving_mean", "Tensor", "Running mean of input.") .add_argument("moving_var", "Tensor", "Running variance of input.") - .set_attr("FInferStructInfo", InferStructInfoBatchNorm) + .set_attr("FInferType", InferTypeBatchNorm) .set_attr("FRelaxInferLayout", InferLayoutBatchNorm) .set_attr("FPurity", true); @@ -540,15 +534,14 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.nn.layer_norm", layer_norm); } -StructInfo InferStructInfoLayerNorm(const Call& call, const BlockBuilder& ctx) { - ffi::Array input_sinfo = GetInputTensorStructInfo(call, ctx); +Type InferTypeLayerNorm(const Call& call, const BlockBuilder& ctx) { + ffi::Array input_ty = GetInputTensorType(call, ctx); const auto* attrs = call->attrs.as(); - bool unknown_shape = NormCheckDtypeAndShape(call, ctx, input_sinfo, attrs->axes); + bool unknown_shape = NormCheckDtypeAndShape(call, ctx, input_ty, attrs->axes); - return unknown_shape ? TensorStructInfo(input_sinfo[0]->dtype, input_sinfo[0]->ndim, - input_sinfo[0]->vdevice) - : input_sinfo[0]; + return unknown_shape ? TensorType(input_ty[0]->dtype, input_ty[0]->ndim, input_ty[0]->vdevice) + : input_ty[0]; } InferLayoutOutput InferLayoutLayerNorm( @@ -557,18 +550,18 @@ InferLayoutOutput InferLayoutLayerNorm( TVM_FFI_ICHECK(NoDesiredLayout(call, desired_layouts)); std::vector initial_layouts; for (size_t i = 0; i < 3; ++i) { - const auto* tensor_sinfo = GetStructInfoAs(call->args[i]); - TVM_FFI_ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; - TVM_FFI_ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support known ndim"; - initial_layouts.push_back(InitialLayoutDecision(tensor_sinfo->ndim)); + const auto* tensor_ty = GetTypeAs(call->args[i]); + TVM_FFI_ICHECK(tensor_ty != nullptr) << "Invalid Call"; + TVM_FFI_ICHECK(!tensor_ty->IsUnknownNdim()) << "Only support known ndim"; + initial_layouts.push_back(InitialLayoutDecision(tensor_ty->ndim)); } const auto* attrs = call->attrs.as(); TVM_FFI_ICHECK(attrs) << "Invalid Call"; LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]); ffi::ObjectPtr new_attrs = ffi::make_object(*attrs); - const auto* input_sinfo = GetStructInfoAs(call->args[0]); - int ndim = input_sinfo->ndim; + const auto* input_ty = GetTypeAs(call->args[0]); + int ndim = input_ty->ndim; std::vector new_axis; for (int64_t axis : attrs->axes) { new_axis.push_back(FindAxis(layout->layout, (axis + ndim) % ndim)); @@ -584,7 +577,7 @@ TVM_REGISTER_OP("relax.nn.layer_norm") .add_argument("data", "Tensor", "Input to which layer_norm will be applied.") .add_argument("gamma", "Tensor", "The gamma scale factor.") .add_argument("beta", "Tensor", "The beta offset factor.") - .set_attr("FInferStructInfo", InferStructInfoLayerNorm) + .set_attr("FInferType", InferTypeLayerNorm) .set_attr("FRelaxInferLayout", InferLayoutLayerNorm) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", true); @@ -610,16 +603,16 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.nn.group_norm", group_norm); } -StructInfo InferStructInfoGroupNorm(const Call& call, const BlockBuilder& ctx) { +Type InferTypeGroupNorm(const Call& call, const BlockBuilder& ctx) { Op op = Downcast(call->op); - ffi::Array input_sinfo = GetInputTensorStructInfo(call, ctx); + ffi::Array input_ty = GetInputTensorType(call, ctx); const auto* attrs = call->attrs.as(); - TensorStructInfo data_sinfo = input_sinfo[0]; + TensorType data_ty = input_ty[0]; int channel_axis = -1; - if (!data_sinfo->IsUnknownNdim()) { - channel_axis = NormalizeAxis(call, ctx, data_sinfo->ndim, attrs->channel_axis); - std::vector axes = NormalizeAxes(call, ctx, data_sinfo->ndim, attrs->axes); + if (!data_ty->IsUnknownNdim()) { + channel_axis = NormalizeAxis(call, ctx, data_ty->ndim, attrs->channel_axis); + std::vector axes = NormalizeAxes(call, ctx, data_ty->ndim, attrs->axes); // channel_axis must be in axes. if (std::find(axes.begin(), axes.end(), channel_axis) != axes.end()) { TVM_FFI_VISIT_THROW(ValueError, call) @@ -627,12 +620,12 @@ StructInfo InferStructInfoGroupNorm(const Call& call, const BlockBuilder& ctx) { << channel_axis << ", axes: " << attrs->axes; } } - if (!data_sinfo->IsUnknownDtype() && !data_sinfo->dtype.is_float()) { + if (!data_ty->IsUnknownDtype() && !data_ty->dtype.is_float()) { TVM_FFI_VISIT_THROW(TypeError, call) - << op << " expects that data must be float, but got " << data_sinfo->dtype; + << op << " expects that data must be float, but got " << data_ty->dtype; } arith::Analyzer analyzer = ctx->GetAnalyzer(); - const auto* data_shape = data_sinfo->shape.as(); + const auto* data_shape = data_ty->shape.as(); if (data_shape != nullptr && channel_axis != -1 && analyzer->CanProve(floormod(data_shape->values[channel_axis], attrs->num_groups) != 0)) { TVM_FFI_VISIT_THROW(ValueError, call) @@ -640,15 +633,15 @@ StructInfo InferStructInfoGroupNorm(const Call& call, const BlockBuilder& ctx) { << ", but got " << data_shape->values[channel_axis]; } for (int i = 1; i < static_cast(op->arguments.size()); ++i) { - if (input_sinfo[i]->dtype != data_sinfo->dtype) { + if (input_ty[i]->dtype != data_ty->dtype) { TVM_FFI_VISIT_THROW(TypeError, call) << op << " expects that all inputs must have the same dtype, but got " - << input_sinfo[i]->dtype << " and " << data_sinfo->dtype; - } else if (input_sinfo[i]->ndim != 1) { + << input_ty[i]->dtype << " and " << data_ty->dtype; + } else if (input_ty[i]->ndim != 1) { TVM_FFI_VISIT_THROW(ValueError, call) - << op << " expects that all inputs must have ndim=1, but got " << input_sinfo[i]->ndim; + << op << " expects that all inputs must have ndim=1, but got " << input_ty[i]->ndim; } else if (channel_axis != -1) { - const auto* shape = input_sinfo[i]->shape.as(); + const auto* shape = input_ty[i]->shape.as(); if (shape != nullptr && data_shape != nullptr) { PrimExpr channel_size = data_shape->values[channel_axis]; PrimExpr input_size = shape->values[0]; @@ -661,7 +654,7 @@ StructInfo InferStructInfoGroupNorm(const Call& call, const BlockBuilder& ctx) { } } } - return data_sinfo; + return data_ty; } InferLayoutOutput InferLayoutGroupNorm( @@ -670,10 +663,10 @@ InferLayoutOutput InferLayoutGroupNorm( TVM_FFI_ICHECK(NoDesiredLayout(call, desired_layouts)); std::vector initial_layouts; for (size_t i = 0; i < 3; ++i) { - const auto* tensor_sinfo = GetStructInfoAs(call->args[i]); - TVM_FFI_ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; - TVM_FFI_ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support known ndim"; - initial_layouts.push_back(InitialLayoutDecision(tensor_sinfo->ndim)); + const auto* tensor_ty = GetTypeAs(call->args[i]); + TVM_FFI_ICHECK(tensor_ty != nullptr) << "Invalid Call"; + TVM_FFI_ICHECK(!tensor_ty->IsUnknownNdim()) << "Only support known ndim"; + initial_layouts.push_back(InitialLayoutDecision(tensor_ty->ndim)); } const auto* attrs = call->attrs.as(); TVM_FFI_ICHECK(attrs) << "Invalid Call"; @@ -696,7 +689,7 @@ TVM_REGISTER_OP("relax.nn.group_norm") .add_argument("data", "Tensor", "Input to which group_norm will be applied.") .add_argument("gamma", "Tensor", "The gamma scale factor.") .add_argument("beta", "Tensor", "The beta offset factor.") - .set_attr("FInferStructInfo", InferStructInfoGroupNorm) + .set_attr("FInferType", InferTypeGroupNorm) .set_attr("FRelaxInferLayout", InferLayoutGroupNorm) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", true); @@ -721,17 +714,17 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.nn.instance_norm", instance_norm); } -StructInfo InferStructInfoInstanceNorm(const Call& call, const BlockBuilder& ctx) { +Type InferTypeInstanceNorm(const Call& call, const BlockBuilder& ctx) { Op op = Downcast(call->op); - ffi::Array input_sinfo = GetInputTensorStructInfo(call, ctx); + ffi::Array input_ty = GetInputTensorType(call, ctx); const auto* attrs = call->attrs.as(); TVM_FFI_ICHECK(attrs) << "Invalid Call"; - TensorStructInfo data_sinfo = input_sinfo[0]; + TensorType data_ty = input_ty[0]; int channel_axis = -1; - if (!data_sinfo->IsUnknownNdim()) { - channel_axis = NormalizeAxis(call, ctx, data_sinfo->ndim, attrs->channel_axis); - std::vector axes = NormalizeAxes(call, ctx, data_sinfo->ndim, attrs->axes); + if (!data_ty->IsUnknownNdim()) { + channel_axis = NormalizeAxis(call, ctx, data_ty->ndim, attrs->channel_axis); + std::vector axes = NormalizeAxes(call, ctx, data_ty->ndim, attrs->axes); // channel_axis must not be in axes. if (std::find(axes.begin(), axes.end(), channel_axis) != axes.end()) { TVM_FFI_VISIT_THROW(ValueError, call) @@ -739,18 +732,18 @@ StructInfo InferStructInfoInstanceNorm(const Call& call, const BlockBuilder& ctx << channel_axis << ", axes: " << attrs->axes; } } - const auto* data_shape = data_sinfo->shape.as(); + const auto* data_shape = data_ty->shape.as(); arith::Analyzer analyzer = ctx->GetAnalyzer(); for (int i = 1; i < static_cast(op->arguments.size()); ++i) { - if (input_sinfo[i]->dtype != data_sinfo->dtype) { + if (input_ty[i]->dtype != data_ty->dtype) { TVM_FFI_VISIT_THROW(TypeError, call) << op << " expects that all inputs must have the same dtype, but got " - << input_sinfo[i]->dtype << " and " << data_sinfo->dtype; - } else if (input_sinfo[i]->ndim != 1) { + << input_ty[i]->dtype << " and " << data_ty->dtype; + } else if (input_ty[i]->ndim != 1) { TVM_FFI_VISIT_THROW(ValueError, call) - << op << " expects that all inputs must have ndim=1, but got " << input_sinfo[i]->ndim; + << op << " expects that all inputs must have ndim=1, but got " << input_ty[i]->ndim; } - const auto* shape = input_sinfo[i]->shape.as(); + const auto* shape = input_ty[i]->shape.as(); if (shape != nullptr && data_shape != nullptr) { PrimExpr channel_size = data_shape->values[channel_axis]; PrimExpr input_size = shape->values[0]; @@ -762,7 +755,7 @@ StructInfo InferStructInfoInstanceNorm(const Call& call, const BlockBuilder& ctx } } } - return data_sinfo; + return data_ty; } InferLayoutOutput InferLayoutInstanceNorm( @@ -771,10 +764,10 @@ InferLayoutOutput InferLayoutInstanceNorm( TVM_FFI_ICHECK(NoDesiredLayout(call, desired_layouts)); std::vector initial_layouts; for (size_t i = 0; i < 3; ++i) { - const auto* tensor_sinfo = GetStructInfoAs(call->args[i]); - TVM_FFI_ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; - TVM_FFI_ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support known ndim"; - initial_layouts.push_back(InitialLayoutDecision(tensor_sinfo->ndim)); + const auto* tensor_ty = GetTypeAs(call->args[i]); + TVM_FFI_ICHECK(tensor_ty != nullptr) << "Invalid Call"; + TVM_FFI_ICHECK(!tensor_ty->IsUnknownNdim()) << "Only support known ndim"; + initial_layouts.push_back(InitialLayoutDecision(tensor_ty->ndim)); } const auto* attrs = call->attrs.as(); TVM_FFI_ICHECK(attrs) << "Invalid Call"; @@ -797,7 +790,7 @@ TVM_REGISTER_OP("relax.nn.instance_norm") .add_argument("data", "Tensor", "Input to which instance_norm will be applied.") .add_argument("gamma", "Tensor", "The gamma scale factor.") .add_argument("beta", "Tensor", "The beta offset factor.") - .set_attr("FInferStructInfo", InferStructInfoInstanceNorm) + .set_attr("FInferType", InferTypeInstanceNorm) .set_attr("FRelaxInferLayout", InferLayoutInstanceNorm) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", true); @@ -817,15 +810,14 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.nn.rms_norm", rms_norm); } -StructInfo InferStructInfoRMSNorm(const Call& call, const BlockBuilder& ctx) { - ffi::Array input_sinfo = GetInputTensorStructInfo(call, ctx); +Type InferTypeRMSNorm(const Call& call, const BlockBuilder& ctx) { + ffi::Array input_ty = GetInputTensorType(call, ctx); const auto* attrs = call->attrs.as(); - bool unknown_shape = NormCheckDtypeAndShape(call, ctx, input_sinfo, attrs->axes); + bool unknown_shape = NormCheckDtypeAndShape(call, ctx, input_ty, attrs->axes); - return unknown_shape ? TensorStructInfo(input_sinfo[0]->dtype, input_sinfo[0]->ndim, - input_sinfo[0]->vdevice) - : input_sinfo[0]; + return unknown_shape ? TensorType(input_ty[0]->dtype, input_ty[0]->ndim, input_ty[0]->vdevice) + : input_ty[0]; } InferLayoutOutput InferLayoutRMSNorm( @@ -834,10 +826,10 @@ InferLayoutOutput InferLayoutRMSNorm( TVM_FFI_ICHECK(NoDesiredLayout(call, desired_layouts)); std::vector initial_layouts; for (size_t i = 0; i < 2; ++i) { - const auto* tensor_sinfo = GetStructInfoAs(call->args[i]); - TVM_FFI_ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; - TVM_FFI_ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support known ndim"; - initial_layouts.push_back(InitialLayoutDecision(tensor_sinfo->ndim)); + const auto* tensor_ty = GetTypeAs(call->args[i]); + TVM_FFI_ICHECK(tensor_ty != nullptr) << "Invalid Call"; + TVM_FFI_ICHECK(!tensor_ty->IsUnknownNdim()) << "Only support known ndim"; + initial_layouts.push_back(InitialLayoutDecision(tensor_ty->ndim)); } const auto* attrs = call->attrs.as(); TVM_FFI_ICHECK(attrs) << "Invalid Call"; @@ -857,7 +849,7 @@ TVM_REGISTER_OP("relax.nn.rms_norm") .set_num_inputs(2) .add_argument("data", "Tensor", "Input to which rms_norm will be applied.") .add_argument("weight", "Tensor", "The scale factor.") - .set_attr("FInferStructInfo", InferStructInfoRMSNorm) + .set_attr("FInferType", InferTypeRMSNorm) .set_attr("FRelaxInferLayout", InferLayoutRMSNorm) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", true); @@ -877,49 +869,48 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.nn.dropout", dropout); } -StructInfo InferStructInfoDropout(const Call& call, const BlockBuilder& ctx) { - TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); - return TupleStructInfo({data_sinfo, data_sinfo}); +Type InferTypeDropout(const Call& call, const BlockBuilder& ctx) { + TensorType data_ty = GetUnaryInputTensorType(call, ctx); + return TupleType({data_ty, data_ty}); } TVM_REGISTER_OP("relax.nn.dropout") .set_attrs_type() .set_num_inputs(1) .add_argument("data", "Tensor", "Input to which dropout will be applied.") - .set_attr("FInferStructInfo", InferStructInfoDropout) + .set_attr("FInferType", InferTypeDropout) .set_attr("FRelaxInferLayout", InferLayoutUnaryEwise) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", true); /* relax.nn.cross_entropy_with_logits */ -StructInfo InferStructInfoCrossEntropy(const Call& call, const BlockBuilder& ctx) { - ffi::Array input_sinfo = GetInputTensorStructInfo(call, ctx); - TensorStructInfo pred_sinfo = input_sinfo[0]; - TensorStructInfo label_sinfo = input_sinfo[1]; +Type InferTypeCrossEntropy(const Call& call, const BlockBuilder& ctx) { + ffi::Array input_ty = GetInputTensorType(call, ctx); + TensorType pred_ty = input_ty[0]; + TensorType label_ty = input_ty[1]; // infer dtype - DataType dtype = InferBinaryArithOpOutDtype(call, ctx, pred_sinfo, label_sinfo); + DataType dtype = InferBinaryArithOpOutDtype(call, ctx, pred_ty, label_ty); // infer vdevice - ffi::Optional vdevice = InferBinaryArithOpOutVDevice(call, ctx, pred_sinfo, label_sinfo); + ffi::Optional vdevice = InferBinaryArithOpOutVDevice(call, ctx, pred_ty, label_ty); // infer ndim - if (!pred_sinfo->IsUnknownNdim() && !label_sinfo->IsUnknownNdim() && - pred_sinfo->ndim != label_sinfo->ndim) { + if (!pred_ty->IsUnknownNdim() && !label_ty->IsUnknownNdim() && pred_ty->ndim != label_ty->ndim) { TVM_FFI_VISIT_THROW(ValueError, call) << "CrossEntropy requires predictions and labels to have the same ndim. " "However, the ndim of predictions is " - << pred_sinfo->ndim << " while the ndim of labels is " << label_sinfo->ndim; + << pred_ty->ndim << " while the ndim of labels is " << label_ty->ndim; } ffi::Optional> pred_shape_value; - if (pred_sinfo->shape.defined()) { - pred_shape_value = GetStructInfoAs(pred_sinfo->shape.value())->values; + if (pred_ty->shape.defined()) { + pred_shape_value = GetTypeAs(pred_ty->shape.value())->values; } ffi::Optional> label_shape_value; - if (label_sinfo->shape.defined()) { - label_shape_value = GetStructInfoAs(label_sinfo->shape.value())->values; + if (label_ty->shape.defined()) { + label_shape_value = GetTypeAs(label_ty->shape.value())->values; } if (pred_shape_value.defined() && label_shape_value.defined()) { @@ -934,7 +925,7 @@ StructInfo InferStructInfoCrossEntropy(const Call& call, const BlockBuilder& ctx } } } - return TensorStructInfo(ShapeExpr(ffi::Array()), dtype, vdevice); + return TensorType(ShapeExpr(ffi::Array()), dtype, vdevice); } Expr cross_entropy_with_logits(Expr predictions, Expr labels) { @@ -951,7 +942,7 @@ TVM_REGISTER_OP("relax.nn.cross_entropy_with_logits") .set_num_inputs(2) .add_argument("predictions", "Tensor", "The predictions.") .add_argument("labels", "Tensor", "The labels.") - .set_attr("FInferStructInfo", InferStructInfoCrossEntropy) + .set_attr("FInferType", InferTypeCrossEntropy) .set_attr("FPurity", true); /* relax.nn.nll_loss */ @@ -982,66 +973,66 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.nn.nll_loss", nll_loss); } -StructInfo InferStructInfoNLLLoss(const Call& call, const BlockBuilder& ctx) { +Type InferTypeNLLLoss(const Call& call, const BlockBuilder& ctx) { if (call->args.size() < 2 || call->args.size() > 3) { TVM_FFI_VISIT_THROW(ValueError, call) << "NLLLoss op should take 2 or 3 arguments"; } - const auto* pred_sinfo = GetStructInfoAs(call->args[0]); - const auto* tgt_sinfo = GetStructInfoAs(call->args[1]); - const TensorStructInfoNode* wgt_sinfo = nullptr; + const auto* pred_ty = GetTypeAs(call->args[0]); + const auto* tgt_ty = GetTypeAs(call->args[1]); + const TensorTypeNode* wgt_ty = nullptr; if (call->args.size() == 3) { - wgt_sinfo = GetStructInfoAs(call->args[2]); - if (wgt_sinfo == nullptr) { + wgt_ty = GetTypeAs(call->args[2]); + if (wgt_ty == nullptr) { TVM_FFI_VISIT_THROW(TypeError, call) << "NLLLoss requires the argument weights to be Tensor. However, the given one is " - << call->args[2]->struct_info_->GetTypeKey(); + << call->args[2]->ty->GetTypeKey(); } } - if (pred_sinfo == nullptr) { + if (pred_ty == nullptr) { TVM_FFI_VISIT_THROW(TypeError, call) << "NLLLoss requires the argument preditions to be Tensor. However, the given one is " - << call->args[0]->struct_info_->GetTypeKey(); + << call->args[0]->ty->GetTypeKey(); } - if (tgt_sinfo == nullptr) { + if (tgt_ty == nullptr) { TVM_FFI_VISIT_THROW(TypeError, call) << "NLLLoss requires the argument targets to be Tensor. However, the given one is " - << call->args[1]->struct_info_->GetTypeKey(); + << call->args[1]->ty->GetTypeKey(); } // infer dtype, vdevice DataType output_dtype; ffi::Optional vdevice; - if (wgt_sinfo != nullptr) { - output_dtype = InferBinaryArithOpOutDtype(call, ctx, ffi::GetRef(pred_sinfo), - ffi::GetRef(wgt_sinfo)); - vdevice = InferBinaryArithOpOutVDevice(call, ctx, ffi::GetRef(pred_sinfo), - ffi::GetRef(wgt_sinfo)); + if (wgt_ty != nullptr) { + output_dtype = InferBinaryArithOpOutDtype(call, ctx, ffi::GetRef(pred_ty), + ffi::GetRef(wgt_ty)); + vdevice = InferBinaryArithOpOutVDevice(call, ctx, ffi::GetRef(pred_ty), + ffi::GetRef(wgt_ty)); } else { - output_dtype = pred_sinfo->dtype; - vdevice = pred_sinfo->vdevice; + output_dtype = pred_ty->dtype; + vdevice = pred_ty->vdevice; } // the type of targets must be int/uint. - if (!tgt_sinfo->IsUnknownDtype() && !tgt_sinfo->dtype.is_int() && !tgt_sinfo->dtype.is_uint()) { + if (!tgt_ty->IsUnknownDtype() && !tgt_ty->dtype.is_int() && !tgt_ty->dtype.is_uint()) { TVM_FFI_VISIT_THROW(TypeError, call) << "NLLLoss expects the dtype of targets to be int/uint. However, the dtype of targets is " - << tgt_sinfo->dtype; + << tgt_ty->dtype; } // infer ndim int K = kUnknownNDim; // k dim - if (!pred_sinfo->IsUnknownNdim()) { - if (pred_sinfo->ndim < 1) { + if (!pred_ty->IsUnknownNdim()) { + if (pred_ty->ndim < 1) { TVM_FFI_VISIT_THROW(ValueError, call) << "NLLLoss expects the ndim of predictions >= 1. However, the ndim of predictions is " - << pred_sinfo->ndim; + << pred_ty->ndim; } - K = pred_sinfo->ndim <= 2 ? 0 : pred_sinfo->ndim - 2; + K = pred_ty->ndim <= 2 ? 0 : pred_ty->ndim - 2; } - if (!tgt_sinfo->IsUnknownNdim()) { - int K_tgt = tgt_sinfo->ndim <= 1 ? 0 : tgt_sinfo->ndim - 1; + if (!tgt_ty->IsUnknownNdim()) { + int K_tgt = tgt_ty->ndim <= 1 ? 0 : tgt_ty->ndim - 1; if (K != kUnknownNDim && K != K_tgt) { TVM_FFI_VISIT_THROW(ValueError, call) << "NLLLoss expects number of dimensions K inferred from different " @@ -1049,10 +1040,10 @@ StructInfo InferStructInfoNLLLoss(const Call& call, const BlockBuilder& ctx) { << K << " while K from targets is " << K_tgt; } } - if (wgt_sinfo != nullptr && !wgt_sinfo->IsUnknownNdim() && wgt_sinfo->ndim != 1) { + if (wgt_ty != nullptr && !wgt_ty->IsUnknownNdim() && wgt_ty->ndim != 1) { TVM_FFI_VISIT_THROW(ValueError, call) << "NLLLoss expects the ndim of weights == 1. However, the ndim of weights is " - << wgt_sinfo->ndim; + << wgt_ty->ndim; } arith::Analyzer analyzer = ctx->GetAnalyzer(); @@ -1061,18 +1052,18 @@ StructInfo InferStructInfoNLLLoss(const Call& call, const BlockBuilder& ctx) { ffi::Array output_shape; // N, d1, d2, ..., dk ffi::Optional> pred_shape_value; - if (pred_sinfo->shape.defined()) { - pred_shape_value = GetStructInfoAs(pred_sinfo->shape.value())->values; + if (pred_ty->shape.defined()) { + pred_shape_value = GetTypeAs(pred_ty->shape.value())->values; } if (pred_shape_value.defined()) { if (pred_shape_value.value().size() == 1) { // (C,) - TVM_FFI_ICHECK(pred_sinfo->ndim == 1); + TVM_FFI_ICHECK(pred_ty->ndim == 1); C = pred_shape_value.value()[0]; } else { // (N, C, d1, d2, ..., dk) TVM_FFI_ICHECK(pred_shape_value.value().size() >= 2); - TVM_FFI_ICHECK(pred_sinfo->ndim == static_cast(pred_shape_value.value().size())); + TVM_FFI_ICHECK(pred_ty->ndim == static_cast(pred_shape_value.value().size())); N = pred_shape_value.value()[0]; C = pred_shape_value.value()[1]; output_shape = ffi::Array(); @@ -1084,13 +1075,13 @@ StructInfo InferStructInfoNLLLoss(const Call& call, const BlockBuilder& ctx) { } ffi::Optional> tgt_shape_value; - if (tgt_sinfo->shape.defined()) { - tgt_shape_value = GetStructInfoAs(tgt_sinfo->shape.value())->values; + if (tgt_ty->shape.defined()) { + tgt_shape_value = GetTypeAs(tgt_ty->shape.value())->values; } if (tgt_shape_value.defined()) { if (tgt_shape_value.value().empty()) { // () - TVM_FFI_ICHECK(tgt_sinfo->ndim == 0); + TVM_FFI_ICHECK(tgt_ty->ndim == 0); if (N.defined()) { TVM_FFI_VISIT_THROW(ValueError, call) << "Shape mismatch for NLLLoss. Predictions shape is " "(N, C, ...) while targets is a scalar"; @@ -1113,12 +1104,12 @@ StructInfo InferStructInfoNLLLoss(const Call& call, const BlockBuilder& ctx) { if (tgt_shape_value.value().size() == 1) { // (N,) - TVM_FFI_ICHECK(tgt_sinfo->IsUnknownNdim() || tgt_sinfo->ndim == 1); + TVM_FFI_ICHECK(tgt_ty->IsUnknownNdim() || tgt_ty->ndim == 1); } else { // (N, d1, d2, ..., dk) TVM_FFI_ICHECK(tgt_shape_value.value().size() >= 2); - TVM_FFI_ICHECK(tgt_sinfo->IsUnknownNdim() || - tgt_sinfo->ndim == static_cast(tgt_shape_value.value().size())); + TVM_FFI_ICHECK(tgt_ty->IsUnknownNdim() || + tgt_ty->ndim == static_cast(tgt_shape_value.value().size())); if (pred_shape_value.defined()) { // check (d1, d2, ..., dk) @@ -1135,14 +1126,14 @@ StructInfo InferStructInfoNLLLoss(const Call& call, const BlockBuilder& ctx) { } } - if (wgt_sinfo != nullptr) { + if (wgt_ty != nullptr) { ffi::Optional> wgt_shape_value; - if (wgt_sinfo->shape.defined()) { - wgt_shape_value = GetStructInfoAs(wgt_sinfo->shape.value())->values; + if (wgt_ty->shape.defined()) { + wgt_shape_value = GetTypeAs(wgt_ty->shape.value())->values; } if (wgt_shape_value.defined()) { TVM_FFI_ICHECK(wgt_shape_value.value().size() == 1); - TVM_FFI_ICHECK(wgt_sinfo->IsUnknownNdim() || wgt_sinfo->ndim == 1); + TVM_FFI_ICHECK(wgt_ty->IsUnknownNdim() || wgt_ty->ndim == 1); const PrimExpr& C_wgt = wgt_shape_value.value()[0]; if (C.defined() && analyzer->CanProve(C.value() != C_wgt)) { TVM_FFI_VISIT_THROW(ValueError, call) @@ -1158,15 +1149,15 @@ StructInfo InferStructInfoNLLLoss(const Call& call, const BlockBuilder& ctx) { if (reduction == "none") { // () or (N,) or (N, d1, d2, ..., dk) - if (pred_sinfo->shape.as()) { - return TensorStructInfo(ShapeExpr(output_shape), output_dtype, vdevice); + if (pred_ty->shape.as()) { + return TensorType(ShapeExpr(output_shape), output_dtype, vdevice); } else { - int output_ndim = pred_sinfo->ndim == kUnknownNDim ? kUnknownNDim : pred_sinfo->ndim - 1; - return TensorStructInfo(output_dtype, /*ndim=*/output_ndim, vdevice); + int output_ndim = pred_ty->ndim == kUnknownNDim ? kUnknownNDim : pred_ty->ndim - 1; + return TensorType(output_dtype, /*ndim=*/output_ndim, vdevice); } } else { // sum or mean. output is scalar - return TensorStructInfo(/*shape=*/ShapeExpr(ffi::Array()), output_dtype, vdevice); + return TensorType(/*shape=*/ShapeExpr(ffi::Array()), output_dtype, vdevice); } } @@ -1176,7 +1167,7 @@ TVM_REGISTER_OP("relax.nn.nll_loss") .add_argument("predictions", "Tensor", "The prediction tensor.") .add_argument("targets", "Tensor", "The target tensor.") .add_argument("weights", "ffi::Optional", "The weight of each target values.") - .set_attr("FInferStructInfo", InferStructInfoNLLLoss) + .set_attr("FInferType", InferTypeNLLLoss) .set_attr("FPurity", true); /* relax.nn.batch_flatten */ @@ -1191,26 +1182,26 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.nn.batch_flatten", batch_flatten); } -StructInfo InferStructInfoBatchFlatten(const Call& call, const BlockBuilder& ctx) { - TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); +Type InferTypeBatchFlatten(const Call& call, const BlockBuilder& ctx) { + TensorType data_ty = GetUnaryInputTensorType(call, ctx); - if (data_sinfo->IsUnknownNdim()) { - return TensorStructInfo(data_sinfo->dtype, /*ndim=*/2, data_sinfo->vdevice); + if (data_ty->IsUnknownNdim()) { + return TensorType(data_ty->dtype, /*ndim=*/2, data_ty->vdevice); } - if (data_sinfo->ndim < 2) { + if (data_ty->ndim < 2) { TVM_FFI_VISIT_THROW(ValueError, call) << "batch_flatten expects input tensor to have at least 2 dimensions, " - << "but got " << data_sinfo->ndim; + << "but got " << data_ty->ndim; } - if (data_sinfo->ndim == 2) { - return data_sinfo; + if (data_ty->ndim == 2) { + return data_ty; } - const auto* data_shape = data_sinfo->shape.as(); + const auto* data_shape = data_ty->shape.as(); if (data_shape == nullptr) { - return TensorStructInfo(data_sinfo->dtype, /*ndim=*/2, data_sinfo->vdevice); + return TensorType(data_ty->dtype, /*ndim=*/2, data_ty->vdevice); } PrimExpr batch_dim = data_shape->values[0]; @@ -1219,13 +1210,13 @@ StructInfo InferStructInfoBatchFlatten(const Call& call, const BlockBuilder& ctx flat_dim = flat_dim * data_shape->values[i]; } - return TensorStructInfo(ShapeExpr({batch_dim, flat_dim}), data_sinfo->dtype, data_sinfo->vdevice); + return TensorType(ShapeExpr({batch_dim, flat_dim}), data_ty->dtype, data_ty->vdevice); } TVM_REGISTER_OP("relax.nn.batch_flatten") .set_num_inputs(1) .add_argument("data", "Tensor", "The input tensor.") - .set_attr("FInferStructInfo", InferStructInfoBatchFlatten) + .set_attr("FInferType", InferTypeBatchFlatten) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", true); diff --git a/src/relax/op/nn/nn.h b/src/relax/op/nn/nn.h index 65dac4b15381..068a1462e9ff 100644 --- a/src/relax/op/nn/nn.h +++ b/src/relax/op/nn/nn.h @@ -42,8 +42,8 @@ namespace relax { */ #define RELAX_REGISTER_UNARY_NN_OP_AND_IMPL(OpName, OpRegName, RequireFloatDtype) \ RELAX_UNARY_OP_INTERFACE(OpName, OpRegName) \ - RELAX_REGISTER_UNARY_OP(OpRegName).set_attr( \ - "FInferStructInfo", InferStructInfoUnaryArith) + RELAX_REGISTER_UNARY_OP(OpRegName).set_attr("FInferType", \ + InferTypeUnaryArith) /*! \brief Rectified linear unit. */ Expr relu(Expr data); diff --git a/src/relax/op/nn/pooling.cc b/src/relax/op/nn/pooling.cc index ca963010c3b6..856cd75c5902 100644 --- a/src/relax/op/nn/pooling.cc +++ b/src/relax/op/nn/pooling.cc @@ -79,8 +79,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.nn.max_pool1d", max_pool1d); } -StructInfo InferStructInfoPool1D(const Call& call, const BlockBuilder& ctx) { - TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); +Type InferTypePool1D(const Call& call, const BlockBuilder& ctx) { + TensorType data_ty = GetUnaryInputTensorType(call, ctx); const auto* attrs = call->attrs.as(); auto [data_layout, data2NCW] = CheckTensorLayout(call, ctx, attrs->layout, @@ -91,9 +91,9 @@ StructInfo InferStructInfoPool1D(const Call& call, const BlockBuilder& ctx) { /*tensor_name=*/"output"); ffi::Optional data_shape = - CheckNdimPerLayoutAndGetShape(call, ctx, data_sinfo, data_layout); + CheckNdimPerLayoutAndGetShape(call, ctx, data_ty, data_layout); if (!data_shape.defined()) { - return TensorStructInfo(data_sinfo->dtype, out_layout.ndim(), data_sinfo->vdevice); + return TensorType(data_ty->dtype, out_layout.ndim(), data_ty->vdevice); } ffi::Array data_NCW_shape = data2NCW.ForwardShape(data_shape.value()->values); @@ -123,16 +123,16 @@ StructInfo InferStructInfoPool1D(const Call& call, const BlockBuilder& ctx) { } ffi::Array out_shape = out2NCW.BackwardShape(out_NCW_shape); - return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype, data_sinfo->vdevice); + return TensorType(ShapeExpr(out_shape), data_ty->dtype, data_ty->vdevice); } InferLayoutOutput InferLayoutPool1d( const Call& call, const ffi::Map>& desired_layouts, const VarLayoutMap& var_layout_map) { TVM_FFI_ICHECK(NoDesiredLayout(call, desired_layouts)); - const auto* tensor_sinfo = GetStructInfoAs(call); - TVM_FFI_ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; - TVM_FFI_ICHECK_EQ(tensor_sinfo->ndim, 3) << "Unsupported initial layout"; + const auto* tensor_ty = GetTypeAs(call); + TVM_FFI_ICHECK(tensor_ty != nullptr) << "Invalid Call"; + TVM_FFI_ICHECK_EQ(tensor_ty->ndim, 3) << "Unsupported initial layout"; const auto* attrs = call->attrs.as(); TVM_FFI_ICHECK(attrs) << "Invalid Call"; @@ -147,7 +147,7 @@ TVM_REGISTER_OP("relax.nn.max_pool1d") .set_num_inputs(1) .add_argument("data", "Tensor", "The input tensor") .set_attrs_type() - .set_attr("FInferStructInfo", InferStructInfoPool1D) + .set_attr("FInferType", InferTypePool1D) .set_attr("FRelaxInferLayout", InferLayoutPool1d) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", true); @@ -203,8 +203,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.nn.max_pool2d", max_pool2d); } -StructInfo InferStructInfoPool2D(const Call& call, const BlockBuilder& ctx) { - TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); +Type InferTypePool2D(const Call& call, const BlockBuilder& ctx) { + TensorType data_ty = GetUnaryInputTensorType(call, ctx); const auto* attrs = call->attrs.as(); auto [data_layout, data2NCHW] = CheckTensorLayout(call, ctx, attrs->layout, @@ -215,9 +215,9 @@ StructInfo InferStructInfoPool2D(const Call& call, const BlockBuilder& ctx) { /*tensor_name=*/"output"); ffi::Optional data_shape = - CheckNdimPerLayoutAndGetShape(call, ctx, data_sinfo, data_layout); + CheckNdimPerLayoutAndGetShape(call, ctx, data_ty, data_layout); if (!data_shape.defined()) { - return TensorStructInfo(data_sinfo->dtype, out_layout.ndim(), data_sinfo->vdevice); + return TensorType(data_ty->dtype, out_layout.ndim(), data_ty->vdevice); } ffi::Array data_NCHW_shape = data2NCHW.ForwardShape(data_shape.value()->values); @@ -258,16 +258,16 @@ StructInfo InferStructInfoPool2D(const Call& call, const BlockBuilder& ctx) { } ffi::Array out_shape = out2NCHW.BackwardShape(out_NCHW_shape); - return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype, data_sinfo->vdevice); + return TensorType(ShapeExpr(out_shape), data_ty->dtype, data_ty->vdevice); } InferLayoutOutput InferLayoutPool2d( const Call& call, const ffi::Map>& desired_layouts, const VarLayoutMap& var_layout_map) { TVM_FFI_ICHECK(NoDesiredLayout(call, desired_layouts)); - const auto* tensor_sinfo = GetStructInfoAs(call); - TVM_FFI_ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; - TVM_FFI_ICHECK_EQ(tensor_sinfo->ndim, 4) << "Unsupported initial layout"; + const auto* tensor_ty = GetTypeAs(call); + TVM_FFI_ICHECK(tensor_ty != nullptr) << "Invalid Call"; + TVM_FFI_ICHECK_EQ(tensor_ty->ndim, 4) << "Unsupported initial layout"; const auto* attrs = call->attrs.as(); TVM_FFI_ICHECK(attrs) << "Invalid Call"; @@ -277,10 +277,10 @@ InferLayoutOutput InferLayoutPool2d( if (layout->layout.ndim() != layout->layout.ndim_primal()) { tirx::SLayout in_layout(attrs->layout, DataType::Int(64)); auto desired_layout = TransposeSubLayoutLike(attrs->layout, InitialLayout(4), layout->layout); - auto data_si = GetStructInfo(call->args[0]); - TensorStructInfo data_sinfo = data_si.as().value(); + auto data_si = GetType(call->args[0]); + TensorType data_ty = data_si.as().value(); ffi::Optional data_shape = - ffi::GetRef(data_sinfo->shape.as()); + ffi::GetRef(data_ty->shape.as()); if (CanProveLayoutTransform(in_layout, desired_layout, data_shape.value()->values)) { // Not handling out_layout being different from in_layout now. Any use case ? new_attrs->layout = desired_layout.name(); @@ -300,7 +300,7 @@ TVM_REGISTER_OP("relax.nn.max_pool2d") .set_num_inputs(1) .add_argument("data", "Tensor", "The input tensor") .set_attrs_type() - .set_attr("FInferStructInfo", InferStructInfoPool2D) + .set_attr("FInferType", InferTypePool2D) .set_attr("FRelaxInferLayout", InferLayoutPool2d) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", true); @@ -359,8 +359,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.nn.max_pool3d", max_pool3d); } -StructInfo InferStructInfoPool3D(const Call& call, const BlockBuilder& ctx) { - TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); +Type InferTypePool3D(const Call& call, const BlockBuilder& ctx) { + TensorType data_ty = GetUnaryInputTensorType(call, ctx); const auto* attrs = call->attrs.as(); auto [data_layout, data2NCDHW] = CheckTensorLayout(call, ctx, attrs->layout, @@ -371,9 +371,9 @@ StructInfo InferStructInfoPool3D(const Call& call, const BlockBuilder& ctx) { /*tensor_name=*/"output"); ffi::Optional data_shape = - CheckNdimPerLayoutAndGetShape(call, ctx, data_sinfo, data_layout); + CheckNdimPerLayoutAndGetShape(call, ctx, data_ty, data_layout); if (!data_shape.defined()) { - return TensorStructInfo(data_sinfo->dtype, out_layout.ndim(), data_sinfo->vdevice); + return TensorType(data_ty->dtype, out_layout.ndim(), data_ty->vdevice); } ffi::Array data_NCDHW_shape = data2NCDHW.ForwardShape(data_shape.value()->values); @@ -425,16 +425,16 @@ StructInfo InferStructInfoPool3D(const Call& call, const BlockBuilder& ctx) { } ffi::Array out_shape = out2NCDHW.BackwardShape(out_NCDHW_shape); - return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype, data_sinfo->vdevice); + return TensorType(ShapeExpr(out_shape), data_ty->dtype, data_ty->vdevice); } InferLayoutOutput InferLayoutPool3d( const Call& call, const ffi::Map>& desired_layouts, const VarLayoutMap& var_layout_map) { TVM_FFI_ICHECK(NoDesiredLayout(call, desired_layouts)); - const auto* tensor_sinfo = GetStructInfoAs(call); - TVM_FFI_ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; - TVM_FFI_ICHECK_EQ(tensor_sinfo->ndim, 5) << "Unsupported initial layout"; + const auto* tensor_ty = GetTypeAs(call); + TVM_FFI_ICHECK(tensor_ty != nullptr) << "Invalid Call"; + TVM_FFI_ICHECK_EQ(tensor_ty->ndim, 5) << "Unsupported initial layout"; const auto* attrs = call->attrs.as(); TVM_FFI_ICHECK(attrs) << "Invalid Call"; @@ -449,7 +449,7 @@ TVM_REGISTER_OP("relax.nn.max_pool3d") .set_num_inputs(1) .add_argument("data", "Tensor", "The input tensor") .set_attrs_type() - .set_attr("FInferStructInfo", InferStructInfoPool3D) + .set_attr("FInferType", InferTypePool3D) .set_attr("FRelaxInferLayout", InferLayoutPool3d) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", true); @@ -471,7 +471,7 @@ TVM_REGISTER_OP("relax.nn.avg_pool1d") .set_num_inputs(1) .add_argument("data", "Tensor", "The input tensor") .set_attrs_type() - .set_attr("FInferStructInfo", InferStructInfoPool1D) + .set_attr("FInferType", InferTypePool1D) .set_attr("FRelaxInferLayout", InferLayoutPool1d) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", true); @@ -493,7 +493,7 @@ TVM_REGISTER_OP("relax.nn.avg_pool2d") .set_num_inputs(1) .add_argument("data", "Tensor", "The input tensor") .set_attrs_type() - .set_attr("FInferStructInfo", InferStructInfoPool2D) + .set_attr("FInferType", InferTypePool2D) .set_attr("FRelaxInferLayout", InferLayoutPool2d) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", true); @@ -515,7 +515,7 @@ TVM_REGISTER_OP("relax.nn.avg_pool3d") .set_num_inputs(1) .add_argument("data", "Tensor", "The input tensor") .set_attrs_type() - .set_attr("FInferStructInfo", InferStructInfoPool3D) + .set_attr("FInferType", InferTypePool3D) .set_attr("FRelaxInferLayout", InferLayoutPool3d) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", true); @@ -544,8 +544,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.nn.adaptive_avg_pool1d", adaptive_avg_pool1d); } -StructInfo InferStructInfoAdaptiveAvgPool1D(const Call& call, const BlockBuilder& ctx) { - TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); +Type InferTypeAdaptiveAvgPool1D(const Call& call, const BlockBuilder& ctx) { + TensorType data_ty = GetUnaryInputTensorType(call, ctx); const auto* attrs = call->attrs.as(); auto [data_layout, data2NCW] = CheckTensorLayout(call, ctx, attrs->layout, @@ -556,13 +556,13 @@ StructInfo InferStructInfoAdaptiveAvgPool1D(const Call& call, const BlockBuilder /*tensor_name=*/"output"); ffi::Optional data_shape = - CheckNdimPerLayoutAndGetShape(call, ctx, data_sinfo, data_layout); + CheckNdimPerLayoutAndGetShape(call, ctx, data_ty, data_layout); if (!data_shape.defined()) { - if (data_sinfo->shape.defined() && attrs->out_layout == attrs->layout && + if (data_ty->shape.defined() && attrs->out_layout == attrs->layout && !attrs->output_size.defined()) { - return data_sinfo; + return data_ty; } else { - return TensorStructInfo(data_sinfo->dtype, out_layout.ndim(), data_sinfo->vdevice); + return TensorType(data_ty->dtype, out_layout.ndim(), data_ty->vdevice); } } @@ -573,16 +573,16 @@ StructInfo InferStructInfoAdaptiveAvgPool1D(const Call& call, const BlockBuilder } ffi::Array out_shape = out2NCW.BackwardShape(out_NCW_shape); - return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype, data_sinfo->vdevice); + return TensorType(ShapeExpr(out_shape), data_ty->dtype, data_ty->vdevice); } InferLayoutOutput InferLayoutAdaptiveAvgPool1D( const Call& call, const ffi::Map>& desired_layouts, const VarLayoutMap& var_layout_map) { TVM_FFI_ICHECK(NoDesiredLayout(call, desired_layouts)); - const auto* tensor_sinfo = GetStructInfoAs(call); - TVM_FFI_ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; - TVM_FFI_ICHECK_EQ(tensor_sinfo->ndim, 3) << "Unsupported initial layout"; + const auto* tensor_ty = GetTypeAs(call); + TVM_FFI_ICHECK(tensor_ty != nullptr) << "Invalid Call"; + TVM_FFI_ICHECK_EQ(tensor_ty->ndim, 3) << "Unsupported initial layout"; const auto* attrs = call->attrs.as(); TVM_FFI_ICHECK(attrs) << "Invalid Call"; @@ -597,7 +597,7 @@ TVM_REGISTER_OP("relax.nn.adaptive_avg_pool1d") .set_attrs_type() .set_num_inputs(1) .add_argument("data", "Tensor", "The input tensor") - .set_attr("FInferStructInfo", InferStructInfoAdaptiveAvgPool1D) + .set_attr("FInferType", InferTypeAdaptiveAvgPool1D) .set_attr("FRelaxInferLayout", InferLayoutAdaptiveAvgPool1D) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", true); @@ -629,8 +629,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.nn.adaptive_avg_pool2d", adaptive_avg_pool2d); } -StructInfo InferStructInfoAdaptiveAvgPool2D(const Call& call, const BlockBuilder& ctx) { - TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); +Type InferTypeAdaptiveAvgPool2D(const Call& call, const BlockBuilder& ctx) { + TensorType data_ty = GetUnaryInputTensorType(call, ctx); const auto* attrs = call->attrs.as(); auto [data_layout, data2NCHW] = CheckTensorLayout(call, ctx, attrs->layout, @@ -641,13 +641,13 @@ StructInfo InferStructInfoAdaptiveAvgPool2D(const Call& call, const BlockBuilder /*tensor_name=*/"output"); ffi::Optional data_shape = - CheckNdimPerLayoutAndGetShape(call, ctx, data_sinfo, data_layout); + CheckNdimPerLayoutAndGetShape(call, ctx, data_ty, data_layout); if (!data_shape.defined()) { - if (data_sinfo->shape.defined() && attrs->out_layout == attrs->layout && + if (data_ty->shape.defined() && attrs->out_layout == attrs->layout && !attrs->output_size.defined()) { - return data_sinfo; + return data_ty; } else { - return TensorStructInfo(data_sinfo->dtype, out_layout.ndim(), data_sinfo->vdevice); + return TensorType(data_ty->dtype, out_layout.ndim(), data_ty->vdevice); } } @@ -659,16 +659,16 @@ StructInfo InferStructInfoAdaptiveAvgPool2D(const Call& call, const BlockBuilder } ffi::Array out_shape = out2NCHW.BackwardShape(out_NCHW_shape); - return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype, data_sinfo->vdevice); + return TensorType(ShapeExpr(out_shape), data_ty->dtype, data_ty->vdevice); } InferLayoutOutput InferLayoutAdaptiveAvgPool2D( const Call& call, const ffi::Map>& desired_layouts, const VarLayoutMap& var_layout_map) { TVM_FFI_ICHECK(NoDesiredLayout(call, desired_layouts)); - const auto* tensor_sinfo = GetStructInfoAs(call); - TVM_FFI_ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; - TVM_FFI_ICHECK_EQ(tensor_sinfo->ndim, 4) << "Unsupported initial layout"; + const auto* tensor_ty = GetTypeAs(call); + TVM_FFI_ICHECK(tensor_ty != nullptr) << "Invalid Call"; + TVM_FFI_ICHECK_EQ(tensor_ty->ndim, 4) << "Unsupported initial layout"; const auto* attrs = call->attrs.as(); TVM_FFI_ICHECK(attrs) << "Invalid Call"; @@ -677,10 +677,10 @@ InferLayoutOutput InferLayoutAdaptiveAvgPool2D( if (layout->layout.ndim() != layout->layout.ndim_primal()) { tirx::SLayout in_layout(attrs->layout, DataType::Int(64)); auto desired_layout = TransposeSubLayoutLike(attrs->layout, InitialLayout(4), layout->layout); - auto data_si = GetStructInfo(call->args[0]); - TensorStructInfo data_sinfo = data_si.as().value(); + auto data_si = GetType(call->args[0]); + TensorType data_ty = data_si.as().value(); ffi::Optional data_shape = - ffi::GetRef(data_sinfo->shape.as()); + ffi::GetRef(data_ty->shape.as()); if (CanProveLayoutTransform(in_layout, desired_layout, data_shape.value()->values)) { // Not handling out_layout being different from in_layout now. Any use case ? new_attrs->layout = desired_layout.name(); @@ -699,7 +699,7 @@ TVM_REGISTER_OP("relax.nn.adaptive_avg_pool2d") .set_attrs_type() .set_num_inputs(1) .add_argument("data", "Tensor", "The input tensor") - .set_attr("FInferStructInfo", InferStructInfoAdaptiveAvgPool2D) + .set_attr("FInferType", InferTypeAdaptiveAvgPool2D) .set_attr("FRelaxInferLayout", InferLayoutAdaptiveAvgPool2D) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", true); @@ -731,8 +731,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.nn.adaptive_avg_pool3d", adaptive_avg_pool3d); } -StructInfo InferStructInfoAdaptiveAvgPool3D(const Call& call, const BlockBuilder& ctx) { - TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); +Type InferTypeAdaptiveAvgPool3D(const Call& call, const BlockBuilder& ctx) { + TensorType data_ty = GetUnaryInputTensorType(call, ctx); const auto* attrs = call->attrs.as(); auto [data_layout, data2NCDHW] = CheckTensorLayout(call, ctx, attrs->layout, @@ -743,13 +743,13 @@ StructInfo InferStructInfoAdaptiveAvgPool3D(const Call& call, const BlockBuilder /*tensor_name=*/"output"); ffi::Optional data_shape = - CheckNdimPerLayoutAndGetShape(call, ctx, data_sinfo, data_layout); + CheckNdimPerLayoutAndGetShape(call, ctx, data_ty, data_layout); if (!data_shape.defined()) { - if (data_sinfo->shape.defined() && attrs->out_layout == attrs->layout && + if (data_ty->shape.defined() && attrs->out_layout == attrs->layout && !attrs->output_size.defined()) { - return data_sinfo; + return data_ty; } else { - return TensorStructInfo(data_sinfo->dtype, out_layout.ndim(), data_sinfo->vdevice); + return TensorType(data_ty->dtype, out_layout.ndim(), data_ty->vdevice); } } @@ -762,16 +762,16 @@ StructInfo InferStructInfoAdaptiveAvgPool3D(const Call& call, const BlockBuilder } ffi::Array out_shape = out2NCDHW.BackwardShape(out_NCDHW_shape); - return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype, data_sinfo->vdevice); + return TensorType(ShapeExpr(out_shape), data_ty->dtype, data_ty->vdevice); } InferLayoutOutput InferLayoutAdaptiveAvgPool3D( const Call& call, const ffi::Map>& desired_layouts, const VarLayoutMap& var_layout_map) { TVM_FFI_ICHECK(NoDesiredLayout(call, desired_layouts)); - const auto* tensor_sinfo = GetStructInfoAs(call); - TVM_FFI_ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; - TVM_FFI_ICHECK_EQ(tensor_sinfo->ndim, 5) << "Unsupported initial layout"; + const auto* tensor_ty = GetTypeAs(call); + TVM_FFI_ICHECK(tensor_ty != nullptr) << "Invalid Call"; + TVM_FFI_ICHECK_EQ(tensor_ty->ndim, 5) << "Unsupported initial layout"; const auto* attrs = call->attrs.as(); TVM_FFI_ICHECK(attrs) << "Invalid Call"; @@ -786,7 +786,7 @@ TVM_REGISTER_OP("relax.nn.adaptive_avg_pool3d") .set_attrs_type() .set_num_inputs(1) .add_argument("data", "Tensor", "The input tensor") - .set_attr("FInferStructInfo", InferStructInfoAdaptiveAvgPool3D) + .set_attr("FInferType", InferTypeAdaptiveAvgPool3D) .set_attr("FRelaxInferLayout", InferLayoutAdaptiveAvgPool3D) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", true); diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index 71df381bb72a..2e1fa02591bc 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -21,7 +21,7 @@ #include #include #include -#include +#include #include #include @@ -59,36 +59,34 @@ bool EqualCheck(const PrimExpr& lhs, const PrimExpr& rhs) { return false; } -StructInfo ReturnVoidStructInfo(const Call& call, const BlockBuilder& ctx) { - return TupleStructInfo(ffi::Array()); +Type ReturnVoidType(const Call& call, const BlockBuilder& ctx) { + return TupleType(ffi::Array()); } -StructInfo ReturnObjectStructInfo(const Call& call, const BlockBuilder& ctx) { - return ObjectStructInfo(); -} +Type ReturnObjectType(const Call& call, const BlockBuilder& ctx) { return ObjectType(); } -StructInfo InferStructInfoShapeOf(const Call& call, const BlockBuilder& ctx) { - // use the StructInfo of the argument - auto arg_sinfo = GetStructInfo(call->args[0]); - auto* tensor_sinfo = GetStructInfo(call->args[0]).as(); - TVM_FFI_ICHECK(tensor_sinfo) << "shape_of expects a tensor input, but received " << arg_sinfo - << "; use MatchCast if necessary"; - if (tensor_sinfo->ndim == kUnknownNDim) { - return ShapeStructInfo(kUnknownNDim); +Type InferTypeShapeOf(const Call& call, const BlockBuilder& ctx) { + // use the Type of the argument + auto arg_ty = GetType(call->args[0]); + auto* tensor_ty = GetType(call->args[0]).as(); + TVM_FFI_ICHECK(tensor_ty) << "shape_of expects a tensor input, but received " << arg_ty + << "; use MatchCast if necessary"; + if (tensor_ty->ndim == kUnknownNDim) { + return ShapeType(kUnknownNDim); } // if the tensor shape is a Relax var or omitted, do not try to construct a shape expr from it - if (!tensor_sinfo->shape.defined() || tensor_sinfo->shape.as()) { - return ShapeStructInfo(tensor_sinfo->ndim); + if (!tensor_ty->shape.defined() || tensor_ty->shape.as()) { + return ShapeType(tensor_ty->ndim); } // otherwise, copy over the values from the tensor shape - auto* tensor_shape = tensor_sinfo->shape.as(); + auto* tensor_shape = tensor_ty->shape.as(); TVM_FFI_ICHECK(tensor_shape); - return ShapeStructInfo(tensor_shape->values); + return ShapeType(tensor_shape->values); } // call_pure_packed -StructInfo InferStructInfoCallPurePacked(const Call& call, const BlockBuilder& ctx) { +Type InferTypeCallPurePacked(const Call& call, const BlockBuilder& ctx) { if (call->args.size() < 1) { TVM_FFI_VISIT_THROW(ValueError, call) << "call_pure_packed must be called with at least one argument"; @@ -97,14 +95,14 @@ StructInfo InferStructInfoCallPurePacked(const Call& call, const BlockBuilder& c // the callee must be an opaque function auto callee = call->args[0]; TVM_FFI_ICHECK(!callee.as()) << "call_pure_packed cannot be used with an op node"; - auto opt = MatchStructInfo(callee); - TVM_FFI_ICHECK(opt) << "Callee must have a function struct info"; - FuncStructInfo finfo = opt.value(); + auto opt = MatchType(callee); + TVM_FFI_ICHECK(opt) << "Callee must have a function type"; + FuncType finfo = opt.value(); TVM_FFI_ICHECK(finfo->IsOpaque()) << "call_pure_packed must be called with an opaque function, but " << callee << " is not opaque"; - // same logic as from DeriveCallRetStructInfo for ordinary calls + // same logic as from DeriveCallRetType for ordinary calls if (finfo->derive_func.defined()) { // derive using custom derivation function. return finfo->derive_func.value()(call, ctx); @@ -119,17 +117,17 @@ TVM_REGISTER_OP("relax.call_pure_packed") .add_argument("args", "ffi::Array", "The first argument is the function being called. The rest are the " "arguments to that function.") - .set_attr("FInferStructInfo", InferStructInfoCallPurePacked) + .set_attr("FInferType", InferTypeCallPurePacked) .set_attr("FPurity", true); Expr MakeCallPurePacked(const Expr& callee, ffi::Array args, const Attrs& attrs, - ffi::Array sinfo_args) { + ffi::Array ty_args) { static const Op& op = Op::Get("relax.call_pure_packed"); ffi::Array call_args = {callee}; for (auto arg : args) { call_args.push_back(arg); } - return Call(op, call_args, attrs, sinfo_args); + return Call(op, call_args, attrs, ty_args); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -139,7 +137,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { // call_inplace_packed -StructInfo InferStructInfoCallInplacePacked(const Call& call, const BlockBuilder& ctx) { +Type InferTypeCallInplacePacked(const Call& call, const BlockBuilder& ctx) { if (call->args.size() <= 1) { TVM_FFI_VISIT_THROW(ValueError, call) << "call_inplace_packed must be called with at least two arguments" @@ -150,9 +148,9 @@ StructInfo InferStructInfoCallInplacePacked(const Call& call, const BlockBuilder // the callee must be an opaque function auto callee = call->args[0]; TVM_FFI_ICHECK(!callee.as()) << "call_pure_packed cannot be used with an op node"; - auto opt = MatchStructInfo(callee); - TVM_FFI_ICHECK(opt) << "Callee must have a function struct info"; - FuncStructInfo finfo = opt.value(); + auto opt = MatchType(callee); + TVM_FFI_ICHECK(opt) << "Callee must have a function type"; + FuncType finfo = opt.value(); TVM_FFI_ICHECK(finfo->IsOpaque()) << "call_pure_packed must be called with an opaque function, but " << callee << " is not opaque"; @@ -181,8 +179,8 @@ StructInfo InferStructInfoCallInplacePacked(const Call& call, const BlockBuilder "-1 (or else simply use call_pure_packed)"; } - // same logic as from DeriveCallRetStructInfo for ordinary calls - StructInfo ret; + // same logic as from DeriveCallRetType for ordinary calls + Type ret; if (finfo->derive_func.defined()) { // derive using custom derivation function. ret = finfo->derive_func.value()(call, ctx); @@ -191,35 +189,33 @@ StructInfo InferStructInfoCallInplacePacked(const Call& call, const BlockBuilder ret = finfo->ret; } - // make sure that the derived return struct info matches that of the in-place args + // make sure that the derived return type matches that of the in-place args // (note: arg 0 is the packed func, so we add 1 to the arg index) if (attrs->inplace_indices.size() == 1) { auto arg_idx = attrs->inplace_indices[0] + 1; - auto arg_sinfo = GetStructInfo(call->args[arg_idx]); - if (!IsBaseOf(ret, arg_sinfo, ctx->GetAnalyzer())) { + auto arg_ty = GetType(call->args[arg_idx]); + if (!IsBaseOf(ret, arg_ty, ctx->GetAnalyzer())) { TVM_FFI_VISIT_THROW(ValueError, call) - << "The derived return StructInfo does not match that for " - << "the in-place argument at index " << (arg_idx - 1) << ": " << ret << " vs " - << arg_sinfo; + << "The derived return Type does not match that for " + << "the in-place argument at index " << (arg_idx - 1) << ": " << ret << " vs " << arg_ty; } } else { - auto* tup_info = ret.as(); + auto* tup_info = ret.as(); if (!tup_info) { TVM_FFI_VISIT_THROW(ValueError, call) << "Multiple outputs given via the inplace indices " - "but the derived StructInfo is not a tuple"; + "but the derived Type is not a tuple"; } for (size_t i = 0; i < attrs->inplace_indices.size(); i++) { if (attrs->inplace_indices[i] == -1) { continue; } auto arg_idx = attrs->inplace_indices[i] + 1; - auto arg_sinfo = GetStructInfo(call->args[arg_idx]); - auto ret_sinfo = tup_info->fields[i]; - if (!IsBaseOf(ret_sinfo, arg_sinfo, ctx->GetAnalyzer())) { - TVM_FFI_VISIT_THROW(ValueError, call) - << "The derived return StructInfo does not match that for " - << "the in-place argument at index " << (arg_idx - 1) << ": " << ret_sinfo << " vs " - << arg_sinfo; + auto arg_ty = GetType(call->args[arg_idx]); + auto ret_ty = tup_info->fields[i]; + if (!IsBaseOf(ret_ty, arg_ty, ctx->GetAnalyzer())) { + TVM_FFI_VISIT_THROW(ValueError, call) << "The derived return Type does not match that for " + << "the in-place argument at index " << (arg_idx - 1) + << ": " << ret_ty << " vs " << arg_ty; } } } @@ -233,7 +229,7 @@ TVM_REGISTER_OP("relax.call_inplace_packed") .add_argument("args", "ffi::Array", "The first argument is the function being called. The rest are the " "arguments to that function.") - .set_attr("FInferStructInfo", InferStructInfoCallInplacePacked) + .set_attr("FInferType", InferTypeCallInplacePacked) // Warning: considered pure, but it has the potential to create visible effects! // This should only be used if it has been *checked* that it is safe (no aliases, in-place // arguments will no longer be live) and the user believes the packed func to have no @@ -241,14 +237,14 @@ TVM_REGISTER_OP("relax.call_inplace_packed") .set_attr("FPurity", true); Expr MakeCallInplacePacked(Expr func, ffi::Array args, ffi::Array inplace_indices, - ffi::Array sinfo_args) { + ffi::Array ty_args) { ffi::ObjectPtr attrs = ffi::make_object(); attrs->inplace_indices = ffi::Array(inplace_indices.begin(), inplace_indices.end()); static const Op& op = Op::Get("relax.call_inplace_packed"); ffi::Array call_args = {func}; call_args.insert(call_args.end(), args.begin(), args.end()); - return Call(op, call_args, Attrs(attrs), sinfo_args); + return Call(op, call_args, Attrs(attrs), ty_args); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -258,9 +254,9 @@ TVM_FFI_STATIC_INIT_BLOCK() { // call_tir -/* If possible, infer a legal value of `arg_sinfo` +/* If possible, infer a legal value of `arg_ty` * - * The `R.call_tir` operator and its variants accept an `arg_sinfo` + * The `R.call_tir` operator and its variants accept an `arg_ty` * parameter, which specifies the shape of the tensor or tensors * returned by a PrimFunc. This output shape must be compatible with * the shape defined by the PrimFunc's signature. @@ -274,38 +270,38 @@ TVM_FFI_STATIC_INIT_BLOCK() { * If the arguments provided are not compatible with the PrimFunc's * signature, an error will be raised. If the arguments are * compatible with the PrimFunc's signature, but are not sufficient to - * determine the output's StructInfo, then `std::nullopt` will be returned. + * determine the output's Type, then `std::nullopt` will be returned. * - * \param func_sinfo The StructInfo of the TIR callee. - * \param arg_sinfo The StructInfo of the argument tuple. - * \param packed_ints_sinfo The StructInfo of the ffi::Shape argument, + * \param func_ty The Type of the TIR callee. + * \param arg_ty The Type of the argument tuple. + * \param packed_ints_ty The Type of the ffi::Shape argument, * if present. * \param opt_inplace_indices For `R.call_tir_inplace`, an array of * indices indicating which outputs are constructed from in-place * mutation of the inputs. See * `CallTIRInplaceAttrs::inplace_indices` for more details. * - * \return The `arg_sinfo`, if it can be inferred from the arguments. + * \return The `arg_ty`, if it can be inferred from the arguments. * Otherwise, std::nullopt. */ -static ffi::Optional InferCallTIROutputStructInfoFromArguments( - StructInfo func_sinfo, StructInfo arg_sinfo, ffi::Optional packed_ints_sinfo, +static ffi::Optional InferCallTIROutputTypeFromArguments( + Type func_ty, Type arg_ty, ffi::Optional packed_ints_ty, ffi::Optional> opt_inplace_indices) { - auto opt_callee_sinfo = func_sinfo.as(); - TVM_FFI_CHECK(opt_callee_sinfo, TypeError) + auto opt_callee_ty = func_ty.as(); + TVM_FFI_CHECK(opt_callee_ty, TypeError) << "The first argument to `R.call_tir` must be a function, " - << "but instead received argument of type " << func_sinfo; - auto callee_sinfo = opt_callee_sinfo.value(); + << "but instead received argument of type " << func_ty; + auto callee_ty = opt_callee_ty.value(); - TVM_FFI_CHECK(callee_sinfo->params.defined(), ValueError) + TVM_FFI_CHECK(callee_ty->params.defined(), ValueError) << "The first argument to `R.call_tir` must be a function " << "with known argument types. " - << "However, the first argument was of type " << callee_sinfo; - auto callee_params = callee_sinfo->params.value(); + << "However, the first argument was of type " << callee_ty; + auto callee_params = callee_ty->params.value(); - const TupleStructInfoNode* args = arg_sinfo.as(); + const TupleTypeNode* args = arg_ty.as(); TVM_FFI_CHECK(args, TypeError) << "The second argument to `R.call_tir` must be a tuple, " - << "but instead received expression of type " << arg_sinfo; + << "but instead received expression of type " << arg_ty; // R.call_tir expects the PrimFunc to have three groups of arguments. // @@ -318,15 +314,15 @@ static ffi::Optional InferCallTIROutputStructInfoFromArguments( // identify the PrimFunc arguments that will be in group (2). size_t num_input_arguments = args->fields.size(); size_t num_trailing_int_arguments = 0; - const ShapeStructInfoNode* packed_tuple_sinfo = nullptr; - if (packed_ints_sinfo) { - auto packed_sinfo = packed_ints_sinfo.value(); - packed_tuple_sinfo = packed_sinfo.as(); - TVM_FFI_CHECK(packed_tuple_sinfo && !packed_tuple_sinfo->IsUnknownNdim(), TypeError) + const ShapeTypeNode* packed_tuple_ty = nullptr; + if (packed_ints_ty) { + auto packed_ty = packed_ints_ty.value(); + packed_tuple_ty = packed_ty.as(); + TVM_FFI_CHECK(packed_tuple_ty && !packed_tuple_ty->IsUnknownNdim(), TypeError) << "The third argument to `R.call_tir`, if present, " << "must be a ffi::Shape with known dimensionality. " - << "However, the argument received was of type " << packed_sinfo; - num_trailing_int_arguments = packed_tuple_sinfo->ndim; + << "However, the argument received was of type " << packed_ty; + num_trailing_int_arguments = packed_tuple_ty->ndim; } else { num_trailing_int_arguments = 0; } @@ -341,50 +337,50 @@ static ffi::Optional InferCallTIROutputStructInfoFromArguments( // current implementation does not support determining the output // shape for `R.dist.call_tir` calls, as it depends on the lowering // of DistIR into regular Relax. - std::function contains_dtensor = [&contains_dtensor](StructInfo sinfo) -> bool { - if (sinfo.as()) { + std::function contains_dtensor = [&contains_dtensor](Type ty) -> bool { + if (ty.as()) { return true; - } else if (auto tuple = sinfo.as()) { + } else if (auto tuple = ty.as()) { return std::any_of(tuple->fields.begin(), tuple->fields.end(), contains_dtensor); } else { return false; } }; - if (contains_dtensor(arg_sinfo)) { + if (contains_dtensor(arg_ty)) { return std::nullopt; } // At this point, the return types are known. However, the shapes // in `callee_params` may contain dynamic shape parameters that are - // not present in the caller's scope. The `DeriveCallRetStructInfo` + // not present in the caller's scope. The `DeriveCallRetType` // utility can infer the value of dynamic parameters in - // `FuncStructInfoNode::ret` based on definitions in - // `FuncStructInfoNode::params`, inferring the correct values in the + // `FuncTypeNode::ret` based on definitions in + // `FuncTypeNode::params`, inferring the correct values in the // caller's scope. // // Since the callee of `R.call_tir` is provided with output - // arguments, where `DeriveCallRetStructInfo` requires a callee that + // arguments, where `DeriveCallRetType` requires a callee that // produces its own outputs, a dummy function signature and // arguments are used. - auto dummy_callee_sinfo = [&]() -> FuncStructInfo { - ffi::Array dummy_params(callee_params.begin(), - callee_params.begin() + num_input_arguments); + auto dummy_callee_ty = [&]() -> FuncType { + ffi::Array dummy_params(callee_params.begin(), + callee_params.begin() + num_input_arguments); for (size_t i = callee_params.size() - num_trailing_int_arguments; i < callee_params.size(); i++) { dummy_params.push_back(callee_params[i]); } - ffi::Array dummy_ret(callee_params.begin() + num_input_arguments, - callee_params.end() - num_trailing_int_arguments); + ffi::Array dummy_ret(callee_params.begin() + num_input_arguments, + callee_params.end() - num_trailing_int_arguments); if (opt_inplace_indices) { // For R.call_tir_inplace, the `inplace_indices` are used to - // indicate which elements of the `out_sinfo` will be generated + // indicate which elements of the `out_ty` will be generated // as in-place mutation from an input. For any in-place - // mutation, the parameter's StructInfo must be inserted into - // `out_sinfo`. + // mutation, the parameter's Type must be inserted into + // `out_ty`. auto inplace_indices = opt_inplace_indices.value(); for (size_t i = 0; i < inplace_indices.size(); i++) { int64_t inplace_input_index = inplace_indices[i]; @@ -394,56 +390,55 @@ static ffi::Optional InferCallTIROutputStructInfoFromArguments( } } - auto dummy_out_sinfo = [&]() -> StructInfo { + auto dummy_out_ty = [&]() -> Type { if (dummy_ret.size() == 1) { return dummy_ret[0]; } else { - return TupleStructInfo(dummy_ret); + return TupleType(dummy_ret); } }(); - return FuncStructInfo(dummy_params, dummy_out_sinfo); + return FuncType(dummy_params, dummy_out_ty); }(); auto dummy_args = [&]() -> ffi::Array { - ffi::Array dummy_args = args->fields.Map( - [](const StructInfo& sinfo) -> Expr { return Var("dummy_leading_arg", sinfo); }); + ffi::Array dummy_args = + args->fields.Map([](const Type& ty) -> Expr { return Var("dummy_leading_arg", ty); }); for (size_t i = 0; i < num_trailing_int_arguments; i++) { - TVM_FFI_ICHECK(packed_tuple_sinfo); - PrimStructInfo dummy_arg_sinfo = [&]() { - if (packed_tuple_sinfo->values) { - return PrimStructInfo(packed_tuple_sinfo->values.value()[i]); + TVM_FFI_ICHECK(packed_tuple_ty); + PrimType dummy_arg_ty = [&]() { + if (packed_tuple_ty->values) { + return PrimType(packed_tuple_ty->values.value()[i]); } else { - return PrimStructInfo(DataType::Int(64)); + return PrimType(DataType::Int(64)); } }(); - dummy_args.push_back(Var("dummy_trailing_arg", dummy_arg_sinfo)); + dummy_args.push_back(Var("dummy_trailing_arg", dummy_arg_ty)); } return dummy_args; }(); - auto derived_ret_sinfo = DeriveCallRetStructInfo( - dummy_callee_sinfo, Call(Var("dummy_callee", dummy_callee_sinfo), dummy_args), - BlockBuilder::Create(std::nullopt)); + auto derived_ret_ty = + DeriveCallRetType(dummy_callee_ty, Call(Var("dummy_callee", dummy_callee_ty), dummy_args), + BlockBuilder::Create(std::nullopt)); - return derived_ret_sinfo; + return derived_ret_ty; } -StructInfo InferStructInfoCallTIR(const Call& call, const BlockBuilder& ctx) { - if (call->sinfo_args.size() != 1) { - TVM_FFI_VISIT_THROW(InternalError, call) - << "sinfo_args should have exactly 1 output struct info."; +Type InferTypeCallTIR(const Call& call, const BlockBuilder& ctx) { + if (call->ty_args.size() != 1) { + TVM_FFI_VISIT_THROW(InternalError, call) << "ty_args should have exactly 1 output type."; } TVM_FFI_ICHECK(call->args[0]->IsInstance()) << "R.call_tir expects the first argument to be a GlobalVar referring to a TIR PrimFunc. " << "However, the argument " << call->args[0] << " instead has type " << call->args[0]->GetTypeKey(); - StructInfo explicit_sinfo = call->sinfo_args[0]; + Type explicit_ty = call->ty_args[0]; - return explicit_sinfo; + return explicit_ty; } Expr NormalizeCallTIR(const BlockBuilder& ctx, Call call) { @@ -458,16 +453,15 @@ Expr NormalizeCallTIR(const BlockBuilder& ctx, Call call) { << "but " << call << " has " << call->args.size() << " arguments."; auto callee = call->args[0]; - TVM_FFI_ICHECK(callee->struct_info_.as()) + TVM_FFI_ICHECK(callee->ty.as()) << "Operation " << call->op << " expects the first argument to be a TIR callee. " - << "However, the first argument " << callee << " has struct info " << callee->struct_info_; + << "However, the first argument " << callee << " has type " << callee->ty; Expr arg_tuple = call->args[1]; - TVM_FFI_ICHECK(arg_tuple->struct_info_.as()) + TVM_FFI_ICHECK(arg_tuple->ty.as()) << "Operation " << call->op << " expects the second argument to be a tuple of relax Expr. " - << "However, the second argument " << arg_tuple << " has struct info " - << arg_tuple->struct_info_ << "."; + << "However, the second argument " << arg_tuple << " has type " << arg_tuple->ty << "."; TVM_FFI_ICHECK(arg_tuple.as() || arg_tuple.as()) << "Operation " << call->op << " must hold its arguments as an in-line tuple. " @@ -477,15 +471,14 @@ Expr NormalizeCallTIR(const BlockBuilder& ctx, Call call) { if (call->args.size() > 2) { Expr packed_ints = call->args[2]; - TVM_FFI_ICHECK(packed_ints->struct_info_.as()) + TVM_FFI_ICHECK(packed_ints->ty.as()) << "Operation " << call->op << " expects the optional third argument, " << "if present, to be a ffi::Shape. " - << "However, the third argument " << packed_ints << " has struct info " - << packed_ints->struct_info_; + << "However, the third argument " << packed_ints << " has type " << packed_ints->ty; } - TVM_FFI_ICHECK_EQ(call->sinfo_args.size(), 1) - << "R.call_tir should have exactly one `sinfo_args` parameter, " + TVM_FFI_ICHECK_EQ(call->ty_args.size(), 1) + << "R.call_tir should have exactly one `ty_args` parameter, " << "which defines the output of the PrimFunc."; auto unwrap_binding = [&ctx](Expr expr) -> ffi::Optional { @@ -520,7 +513,7 @@ Expr NormalizeCallTIR(const BlockBuilder& ctx, Call call) { // example, if a relax function accepted a tuple as an parameter, // then provided that same tuple as an argument to call_tir. ffi::Array tuple_elements; - size_t num_fields = Downcast(arg_tuple->struct_info_)->fields.size(); + size_t num_fields = Downcast(arg_tuple->ty)->fields.size(); for (size_t i = 0; i < num_fields; i++) { tuple_elements.push_back(TupleGetItem(arg_tuple, i)); } @@ -546,11 +539,11 @@ void ValidateCallTIR(Call call) { auto callee = call->args[0]; Expr arg_tuple = call->args[1]; - auto packed_int_sinfo = [&]() -> ffi::Optional { + auto packed_int_ty = [&]() -> ffi::Optional { if (call->args.size() <= 2) { return std::nullopt; } else { - return GetStructInfo(call->args[2]); + return GetType(call->args[2]); } }(); @@ -562,14 +555,14 @@ void ValidateCallTIR(Call call) { } }(); - StructInfo explicit_sinfo = call->sinfo_args[0]; - auto inferred_sinfo = InferCallTIROutputStructInfoFromArguments( - GetStructInfo(callee), GetStructInfo(arg_tuple), packed_int_sinfo, opt_inplace_indices); - if (inferred_sinfo.defined()) { - TVM_FFI_CHECK(IsBaseOf(inferred_sinfo.value(), explicit_sinfo), TypeError) - << "The `out_sinfo` argument for R.call_tir must be compatible with the PrimFunc. " - << "However, the PrimFunc's signature implies that the output should be " << inferred_sinfo - << ", but the `out_sinfo` argument was " << explicit_sinfo; + Type explicit_ty = call->ty_args[0]; + auto inferred_ty = InferCallTIROutputTypeFromArguments(GetType(callee), GetType(arg_tuple), + packed_int_ty, opt_inplace_indices); + if (inferred_ty.defined()) { + TVM_FFI_CHECK(IsBaseOf(inferred_ty.value(), explicit_ty), TypeError) + << "The `out_ty` argument for R.call_tir must be compatible with the PrimFunc. " + << "However, the PrimFunc's signature implies that the output should be " << inferred_ty + << ", but the `out_ty` argument was " << explicit_ty; } } @@ -580,35 +573,35 @@ TVM_REGISTER_OP("relax.call_tir") .add_argument("packed_ints", "Expr", "ShapeExpr representing a tuple of ints to unpack during runtime. Omitted from " "args if unused") - .set_attr("FInferStructInfo", InferStructInfoCallTIR) + .set_attr("FInferType", InferTypeCallTIR) .set_attr("FNormalize", NormalizeCallTIR) .set_attr("FValidate", ValidateCallTIR) .set_attr("FPurity", true); -Expr MakeCallTIR(Expr func, Tuple args, ffi::Array out_sinfo_list, +Expr MakeCallTIR(Expr func, Tuple args, ffi::Array out_ty_list, ffi::Optional packed_ints) { - for (const TensorStructInfo& sinfo : out_sinfo_list) { - const auto* shape = sinfo->shape.as(); + for (const TensorType& ty : out_ty_list) { + const auto* shape = ty->shape.as(); TVM_FFI_ICHECK(shape != nullptr) - << "out_sinfo of call_tir should have defined ShapeExpr as shape. " - "However, one given structure info is " - << sinfo; + << "out_ty of call_tir should have defined ShapeExpr as shape. " + "However, one given type information is " + << ty; } - StructInfo out_sinfo{nullptr}; - if (out_sinfo_list.size() == 1) { - out_sinfo = out_sinfo_list[0]; + Type out_ty{nullptr}; + if (out_ty_list.size() == 1) { + out_ty = out_ty_list[0]; } else { - out_sinfo = TupleStructInfo({out_sinfo_list.begin(), out_sinfo_list.end()}); + out_ty = TupleType({out_ty_list.begin(), out_ty_list.end()}); } static const Op& op = Op::Get("relax.call_tir"); Call call; if (!packed_ints) { // don't use additional optional argument - call = Call(op, {func, args}, {}, {out_sinfo}); + call = Call(op, {func, args}, {}, {out_ty}); } else { - call = Call(op, {func, args, packed_ints.value()}, {}, {out_sinfo}); + call = Call(op, {func, args, packed_ints.value()}, {}, {out_ty}); } return call; } @@ -628,27 +621,27 @@ TVM_REGISTER_OP("relax.call_tir_with_grad") .add_argument("packed_ints", "Expr", "ShapeExpr representing a tuple of ints to unpack during runtime. Omitted from " "args if unused") - .set_attr("FInferStructInfo", InferStructInfoCallTIR) + .set_attr("FInferType", InferTypeCallTIR) .set_attr("FNormalize", NormalizeCallTIR) .set_attr("FValidate", ValidateCallTIR) .set_attr("FPurity", true); -Expr MakeCallTIRWithGrad(Expr func, Tuple args, ffi::Array out_sinfo_list, +Expr MakeCallTIRWithGrad(Expr func, Tuple args, ffi::Array out_ty_list, ffi::String te_grad_name, ffi::Map te_grad_kwargs, ffi::Optional packed_ints) { - for (const TensorStructInfo& sinfo : out_sinfo_list) { - const auto* shape = sinfo->shape.as(); + for (const TensorType& ty : out_ty_list) { + const auto* shape = ty->shape.as(); TVM_FFI_ICHECK(shape != nullptr) - << "out_sinfo of call_tir_with_grad should have defined ShapeExpr as shape. " - "However, one given structure info is " - << sinfo; + << "out_ty of call_tir_with_grad should have defined ShapeExpr as shape. " + "However, one given type information is " + << ty; } - StructInfo out_sinfo{nullptr}; - if (out_sinfo_list.size() == 1) { - out_sinfo = out_sinfo_list[0]; + Type out_ty{nullptr}; + if (out_ty_list.size() == 1) { + out_ty = out_ty_list[0]; } else { - out_sinfo = TupleStructInfo({out_sinfo_list.begin(), out_sinfo_list.end()}); + out_ty = TupleType({out_ty_list.begin(), out_ty_list.end()}); } ffi::ObjectPtr attrs = ffi::make_object(); @@ -659,9 +652,9 @@ Expr MakeCallTIRWithGrad(Expr func, Tuple args, ffi::Array out Call call; if (!packed_ints) { // don't use additional optional argument - call = Call(op, {func, args}, Attrs(attrs), {out_sinfo}); + call = Call(op, {func, args}, Attrs(attrs), {out_ty}); } else { - call = Call(op, {func, args, packed_ints.value()}, Attrs(attrs), {out_sinfo}); + call = Call(op, {func, args, packed_ints.value()}, Attrs(attrs), {out_ty}); } return call; } @@ -679,19 +672,19 @@ Expr NormalizeCallTIRInPlace(const BlockBuilder& ctx, Call call) { // may result in an error if performed before normalization. call = Downcast(NormalizeCallTIR(ctx, std::move(call))); - ffi::Array sinfo_outputs = [&]() -> ffi::Array { - auto out_sinfo = call->sinfo_args[0]; - if (auto* tuple_output = out_sinfo.as()) { + ffi::Array ty_outputs = [&]() -> ffi::Array { + auto out_ty = call->ty_args[0]; + if (auto* tuple_output = out_ty.as()) { return tuple_output->fields; } else { - return {out_sinfo}; + return {out_ty}; } }(); // there must be an inplace index for each output const auto* attrs = call->attrs.as(); TVM_FFI_ICHECK(attrs); - if (attrs->inplace_indices.size() != sinfo_outputs.size()) { + if (attrs->inplace_indices.size() != ty_outputs.size()) { TVM_FFI_VISIT_THROW(ValueError, call) << "There must be an in-place index specified for each output"; } @@ -730,18 +723,18 @@ Expr NormalizeCallTIRInPlace(const BlockBuilder& ctx, Call call) { continue; } - auto sinfo_output = sinfo_outputs[i_output]; - auto tinfo_output = sinfo_output.as(); + auto ty_output = ty_outputs[i_output]; + auto tinfo_output = ty_output.as(); if (!tinfo_output || !tinfo_output->shape.defined() || tinfo_output->IsUnknownDtype()) { TVM_FFI_VISIT_THROW(ValueError, call) - << "The output struct info for an in-place mutation must be a tensor " + << "The output type for an in-place mutation must be a tensor " << "with a defined shape and dtype, " - << "but output " << i_output << " has struct info " << sinfo_output; + << "but output " << i_output << " has type " << ty_output; } - auto sinfo_input = GetStructInfo(call_args->fields[i_input]); - auto tinfo_input = sinfo_input.as(); + auto ty_input = GetType(call_args->fields[i_input]); + auto tinfo_input = ty_input.as(); if (!tinfo_input || (tinfo_output->IsUnknownDtype() || tinfo_output->dtype != tinfo_input->dtype) || @@ -751,9 +744,9 @@ Expr NormalizeCallTIRInPlace(const BlockBuilder& ctx, Call call) { TVM_FFI_VISIT_THROW(ValueError, call) << "The input used for an in-place mutation must be " << "a tensor with identical shape and dtype as the output. " - << "However, output " << i_output << " with struct info " << sinfo_output - << " is specified as an in-place mutation of input " << i_input << " with struct info " - << sinfo_input; + << "However, output " << i_output << " with type " << ty_output + << " is specified as an in-place mutation of input " << i_input << " with type " + << ty_input; } } @@ -768,7 +761,7 @@ TVM_REGISTER_OP("relax.call_tir_inplace") .add_argument("packed_ints", "Expr", "ShapeExpr representing a tuple of ints to unpack during runtime. Omitted from " "args if unused") - .set_attr("FInferStructInfo", InferStructInfoCallTIR) + .set_attr("FInferType", InferTypeCallTIR) .set_attr("FNormalize", NormalizeCallTIRInPlace) .set_attr("FValidate", ValidateCallTIR) // Warning: considered pure, but it has the potential to create visible effects! @@ -777,33 +770,32 @@ TVM_REGISTER_OP("relax.call_tir_inplace") .set_attr("FPurity", true); Expr MakeCallTIRInplace(Expr func, Tuple args, ffi::Array inplace_indices, - ffi::Array out_sinfo_list, - ffi::Optional packed_ints) { - for (const TensorStructInfo& sinfo : out_sinfo_list) { - const auto* shape = sinfo->shape.as(); + ffi::Array out_ty_list, ffi::Optional packed_ints) { + for (const TensorType& ty : out_ty_list) { + const auto* shape = ty->shape.as(); TVM_FFI_ICHECK(shape != nullptr) - << "out_sinfo of call_tir should have defined ShapeExpr as shape. " - "However, one given structure info is " - << sinfo; + << "out_ty of call_tir should have defined ShapeExpr as shape. " + "However, one given type information is " + << ty; } ffi::ObjectPtr attrs = ffi::make_object(); attrs->inplace_indices = ffi::Array(inplace_indices.begin(), inplace_indices.end()); - StructInfo out_sinfo{nullptr}; - if (out_sinfo_list.size() == 1) { - out_sinfo = out_sinfo_list[0]; + Type out_ty{nullptr}; + if (out_ty_list.size() == 1) { + out_ty = out_ty_list[0]; } else { - out_sinfo = TupleStructInfo({out_sinfo_list.begin(), out_sinfo_list.end()}); + out_ty = TupleType({out_ty_list.begin(), out_ty_list.end()}); } static const Op& op = Op::Get("relax.call_tir_inplace"); Call call; if (!packed_ints) { // don't use additional optional argument - call = Call(op, {func, args}, Attrs(attrs), {out_sinfo}); + call = Call(op, {func, args}, Attrs(attrs), {out_ty}); } else { - call = Call(op, {func, args, packed_ints.value()}, Attrs(attrs), {out_sinfo}); + call = Call(op, {func, args, packed_ints.value()}, Attrs(attrs), {out_ty}); } return call; } @@ -815,41 +807,40 @@ TVM_FFI_STATIC_INIT_BLOCK() { // call_dps_packed -StructInfo InferStructInfoCallDPSPacked(const Call& call, const BlockBuilder& ctx) { - if (call->sinfo_args.size() != 1) { - TVM_FFI_VISIT_THROW(InternalError, call) - << "sinfo_args should have exact 1 output struct info."; +Type InferTypeCallDPSPacked(const Call& call, const BlockBuilder& ctx) { + if (call->ty_args.size() != 1) { + TVM_FFI_VISIT_THROW(InternalError, call) << "ty_args should have exact 1 output type."; } - return call->sinfo_args[0]; + return call->ty_args[0]; } TVM_REGISTER_OP("relax.call_dps_packed") .set_num_inputs(2) .add_argument("func", "Expr", "The destination-passing-style function.") .add_argument("args", "Tuple", "The input arguments.") - .set_attr("FInferStructInfo", InferStructInfoCallDPSPacked) + .set_attr("FInferType", InferTypeCallDPSPacked) // technically, an impure op could be used with this, but there is // little reason to use DPS with an impure op .set_attr("FPurity", true); -Expr MakeCallDPSPacked(Expr func, Tuple args, ffi::Array out_sinfo_list) { - for (const TensorStructInfo& sinfo : out_sinfo_list) { - const auto* shape = sinfo->shape.as(); +Expr MakeCallDPSPacked(Expr func, Tuple args, ffi::Array out_ty_list) { + for (const TensorType& ty : out_ty_list) { + const auto* shape = ty->shape.as(); TVM_FFI_ICHECK(shape != nullptr) - << "out_sinfo of call_dps_packed should have defined ShapeExpr as shape. " - "However, one given structure info is " - << sinfo; + << "out_ty of call_dps_packed should have defined ShapeExpr as shape. " + "However, one given type information is " + << ty; } - StructInfo out_sinfo{nullptr}; - if (out_sinfo_list.size() == 1) { - out_sinfo = out_sinfo_list[0]; + Type out_ty{nullptr}; + if (out_ty_list.size() == 1) { + out_ty = out_ty_list[0]; } else { - out_sinfo = TupleStructInfo({out_sinfo_list.begin(), out_sinfo_list.end()}); + out_ty = TupleType({out_ty_list.begin(), out_ty_list.end()}); } static const Op& op = Op::Get("relax.call_dps_packed"); - return Call(op, {func, args}, {}, {out_sinfo}); + return Call(op, {func, args}, {}, {out_ty}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -859,12 +850,11 @@ TVM_FFI_STATIC_INIT_BLOCK() { // call_py_func -StructInfo InferStructInfoCallPyFunc(const Call& call, const BlockBuilder& ctx) { - if (call->sinfo_args.size() != 1) { - TVM_FFI_VISIT_THROW(InternalError, call) - << "sinfo_args should have exact 1 output struct info."; +Type InferTypeCallPyFunc(const Call& call, const BlockBuilder& ctx) { + if (call->ty_args.size() != 1) { + TVM_FFI_VISIT_THROW(InternalError, call) << "ty_args should have exact 1 output type."; } - return call->sinfo_args[0]; + return call->ty_args[0]; } void ValidateCallPyFunc(Call call) { @@ -877,10 +867,9 @@ void ValidateCallPyFunc(Call call) { // Validate that args is a tuple Expr arg_tuple = call->args[1]; - TVM_FFI_ICHECK(arg_tuple->struct_info_.as()) + TVM_FFI_ICHECK(arg_tuple->ty.as()) << "Operation " << call->op << " expects the second argument to be a tuple of relax Expr. " - << "However, the second argument " << arg_tuple << " has struct info " - << arg_tuple->struct_info_ << "."; + << "However, the second argument " << arg_tuple << " has type " << arg_tuple->ty << "."; TVM_FFI_ICHECK(arg_tuple.as() || arg_tuple.as()) << "Operation " << call->op << " must hold its arguments as an in-line tuple. " @@ -893,28 +882,28 @@ TVM_REGISTER_OP("relax.call_py_func") .set_num_inputs(2) .add_argument("func_name", "StringImm", "The name of the Python function to call.") .add_argument("args", "Tuple", "The input arguments.") - .set_attr("FInferStructInfo", InferStructInfoCallPyFunc) + .set_attr("FInferType", InferTypeCallPyFunc) .set_attr("FValidate", ValidateCallPyFunc) .set_attr("FPurity", true); -Expr MakeCallPyFunc(StringImm func_name, Tuple args, ffi::Array out_sinfo_list) { - for (const TensorStructInfo& sinfo : out_sinfo_list) { - const auto* shape = sinfo->shape.as(); +Expr MakeCallPyFunc(StringImm func_name, Tuple args, ffi::Array out_ty_list) { + for (const TensorType& ty : out_ty_list) { + const auto* shape = ty->shape.as(); TVM_FFI_ICHECK(shape != nullptr) - << "out_sinfo of call_py_func should have defined ShapeExpr as shape. " - "However, one given structure info is " - << sinfo; + << "out_ty of call_py_func should have defined ShapeExpr as shape. " + "However, one given type information is " + << ty; } - StructInfo out_sinfo{nullptr}; - if (out_sinfo_list.size() == 1) { - out_sinfo = out_sinfo_list[0]; + Type out_ty{nullptr}; + if (out_ty_list.size() == 1) { + out_ty = out_ty_list[0]; } else { - out_sinfo = TupleStructInfo({out_sinfo_list.begin(), out_sinfo_list.end()}); + out_ty = TupleType({out_ty_list.begin(), out_ty_list.end()}); } static const Op& op = Op::Get("relax.call_py_func"); - return Call(op, {func_name, args}, {}, {out_sinfo}); + return Call(op, {func_name, args}, {}, {out_ty}); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -923,13 +912,13 @@ TVM_FFI_STATIC_INIT_BLOCK() { } // call builtin -StructInfo InferStructInfoCallBuiltinWithCtx(const Call& call, const BlockBuilder& ctx) { - if (call->sinfo_args.size() == 0) { +Type InferTypeCallBuiltinWithCtx(const Call& call, const BlockBuilder& ctx) { + if (call->ty_args.size() == 0) { // by default return void. - return TupleStructInfo(ffi::Array()); + return TupleType(ffi::Array()); } else { - TVM_FFI_ICHECK_EQ(call->sinfo_args.size(), 1); - return call->sinfo_args[0]; + TVM_FFI_ICHECK_EQ(call->ty_args.size(), 1); + return call->ty_args[0]; } } @@ -937,13 +926,13 @@ TVM_REGISTER_OP("relax.call_builtin_with_ctx") .set_num_inputs(4) .add_argument("func", "Expr", "The builtin packed func.") .add_argument("args", "Tuple", "The input arguments.") - .set_attr("FInferStructInfo", InferStructInfoCallBuiltinWithCtx) + .set_attr("FInferType", InferTypeCallBuiltinWithCtx) // Most builtins are pure, but some are not, like `vm.builtin.attention_kv_cache_append` .set_attr("FPurity", false); -Expr MakeCallBuiltinWithCtx(Expr func, Tuple args, ffi::Array sinfo_args) { +Expr MakeCallBuiltinWithCtx(Expr func, Tuple args, ffi::Array ty_args) { static const Op& op = Op::Get("relax.call_builtin_with_ctx"); - return Call(op, {func, args}, Attrs(), sinfo_args); + return Call(op, {func, args}, Attrs(), ty_args); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -953,7 +942,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { TVM_REGISTER_OP("relax.null_value") .set_num_inputs(0) - .set_attr("FInferStructInfo", ReturnObjectStructInfo) + .set_attr("FInferType", ReturnObjectType) .set_attr("FPurity", true); Expr MakeCallNullValue() { @@ -973,7 +962,7 @@ TVM_REGISTER_OP("relax.print") .add_argument("vals", "ffi::Array", "The first value is Python-style format string to use to print. The others " "are values to print") - .set_attr("FInferStructInfo", ReturnVoidStructInfo) + .set_attr("FInferType", ReturnVoidType) .set_attr("FCallPacked", "relax.run.print") .set_attr("FPurity", false); @@ -996,7 +985,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { // can't actually name it assert or else Python will consider it a syntax error -StructInfo InferAssertStructInfo(const Call& call, const BlockBuilder& ctx) { +Type InferAssertType(const Call& call, const BlockBuilder& ctx) { // Ensure that the condition argument is a boolean scalar. // Also permitted is a tensor with unknown shape and unknown dtype // (checked dynamically in that case). Returns void. @@ -1004,12 +993,12 @@ StructInfo InferAssertStructInfo(const Call& call, const BlockBuilder& ctx) { TVM_FFI_VISIT_THROW(ValueError, call) << "Assert must have at least one argument (the condition)."; } - StructInfo arg_struct_info = GetStructInfo(call->args[0]); - if (!IsBoolStructInfo(arg_struct_info)) { + Type arg_ty = GetType(call->args[0]); + if (!IsBoolType(arg_ty)) { TVM_FFI_VISIT_THROW(TypeError, call) - << "The argument to assert must be a boolean scalar, but received " << arg_struct_info; + << "The argument to assert must be a boolean scalar, but received " << arg_ty; } - return ReturnVoidStructInfo(call, ctx); + return ReturnVoidType(call, ctx); } TVM_REGISTER_OP("relax.assert_op") @@ -1018,7 +1007,7 @@ TVM_REGISTER_OP("relax.assert_op") "The first value is used as the assertion condition. The second value is " "Python-style format string to use for displaying an error message, if the " "assert fails. The others are used as format arguments if there is an error.") - .set_attr("FInferStructInfo", InferAssertStructInfo) + .set_attr("FInferType", InferAssertType) .set_attr("FCallPacked", "relax.run.assert_op") .set_attr("FPurity", false); @@ -1043,7 +1032,7 @@ TVM_REGISTER_OP("relax.make_closure") .set_num_inputs(2) .add_argument("func", "Expr", "The closure.") .add_argument("args", "Tuple", "The captured variables.") - .set_attr("FInferStructInfo", ReturnObjectStructInfo) + .set_attr("FInferType", ReturnObjectType) .set_attr("FPurity", true); Expr MakeClosure(Expr func, Tuple args) { @@ -1058,13 +1047,13 @@ TVM_FFI_STATIC_INIT_BLOCK() { // invoke_closure -StructInfo InferStructInfoInvokeClosure(const Call& call, const BlockBuilder& ctx) { - if (call->sinfo_args.empty()) { - return ObjectStructInfo(); - } else if (call->sinfo_args.size() == 1) { - return call->sinfo_args[0]; +Type InferTypeInvokeClosure(const Call& call, const BlockBuilder& ctx) { + if (call->ty_args.empty()) { + return ObjectType(); + } else if (call->ty_args.size() == 1) { + return call->ty_args[0]; } else { - return TupleStructInfo(call->sinfo_args); + return TupleType(call->ty_args); } } @@ -1072,13 +1061,13 @@ TVM_REGISTER_OP("relax.invoke_closure") .set_num_inputs(2) .add_argument("closure", "Expr", "The VMClosure.") .add_argument("args", "Tuple", "The captured variables.") - .set_attr("FInferStructInfo", InferStructInfoInvokeClosure) + .set_attr("FInferType", InferTypeInvokeClosure) // Not all closures are pure. Use invoke_pure_closure for specifying purity .set_attr("FPurity", false); -Expr InvokeClosure(Expr closure, Tuple args, ffi::Array sinfo_args) { +Expr InvokeClosure(Expr closure, Tuple args, ffi::Array ty_args) { static const Op& op = Op::Get("relax.invoke_closure"); - return Call(op, {closure, args}, {}, sinfo_args); + return Call(op, {closure, args}, {}, ty_args); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -1092,12 +1081,12 @@ TVM_REGISTER_OP("relax.invoke_pure_closure") .set_num_inputs(2) .add_argument("closure", "Expr", "The VMClosure.") .add_argument("args", "Tuple", "The captured variables.") - .set_attr("FInferStructInfo", InferStructInfoInvokeClosure) + .set_attr("FInferType", InferTypeInvokeClosure) .set_attr("FPurity", true); -Expr InvokePureClosure(Expr closure, Tuple args, ffi::Array sinfo_args) { +Expr InvokePureClosure(Expr closure, Tuple args, ffi::Array ty_args) { static const Op& op = Op::Get("relax.invoke_pure_closure"); - return Call(op, {closure, args}, {}, sinfo_args); + return Call(op, {closure, args}, {}, ty_args); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -1110,7 +1099,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { TVM_REGISTER_OP("relax.shape_of") .set_num_inputs(1) .add_argument("input", "Expr", "The input expression") - .set_attr("FInferStructInfo", InferStructInfoShapeOf) + .set_attr("FInferType", InferTypeShapeOf) .set_attr("FPurity", true); Expr MakeShapeOf(Expr expr) { @@ -1125,18 +1114,18 @@ TVM_FFI_STATIC_INIT_BLOCK() { // size -StructInfo InferStructInfoSize(const Call& call, const BlockBuilder& ctx) { - auto arg_sinfo = GetStructInfo(call->args[0]); - auto* tensor_sinfo = GetStructInfo(call->args[0]).as(); - TVM_FFI_ICHECK(tensor_sinfo) << "size expects a tensor input, but received " << arg_sinfo - << "; use MatchCast if necessary"; - return TensorStructInfo(ShapeExpr(ffi::Array{}), DataType::Int(64)); +Type InferTypeSize(const Call& call, const BlockBuilder& ctx) { + auto arg_ty = GetType(call->args[0]); + auto* tensor_ty = GetType(call->args[0]).as(); + TVM_FFI_ICHECK(tensor_ty) << "size expects a tensor input, but received " << arg_ty + << "; use MatchCast if necessary"; + return TensorType(ShapeExpr(ffi::Array{}), DataType::Int(64)); } TVM_REGISTER_OP("relax.size") .set_num_inputs(1) .add_argument("input", "Expr", "The input tensor") - .set_attr("FInferStructInfo", InferStructInfoSize) + .set_attr("FInferType", InferTypeSize) .set_attr("FPurity", true); Expr MakeSize(Expr expr) { @@ -1151,29 +1140,29 @@ TVM_FFI_STATIC_INIT_BLOCK() { // tensor_to_shape -StructInfo ReturnTensorToShapeStructInfo(const Call& call, const BlockBuilder& ctx) { +Type ReturnTensorToShapeType(const Call& call, const BlockBuilder& ctx) { TVM_FFI_ICHECK(call->args.size() == 1); - TVM_FFI_ICHECK(call->args[0]->struct_info_.defined()); - const auto* tsinfo = GetStructInfoAs(call->args[0]); - TVM_FFI_ICHECK(tsinfo); - TVM_FFI_ICHECK_EQ(tsinfo->ndim, 1) << "relax.tensor_to_shape expected argument to be 1-d, " - << "but " << call << " has argument " << call->args[0] - << " with struct info " << call->args[0]->struct_info_; - - if (tsinfo->shape.defined()) { - ShapeExpr shape_expr = Downcast(tsinfo->shape.value()); + TVM_FFI_ICHECK(call->args[0]->ty.defined()); + const auto* tensor_ty = GetTypeAs(call->args[0]); + TVM_FFI_ICHECK(tensor_ty); + TVM_FFI_ICHECK_EQ(tensor_ty->ndim, 1) + << "relax.tensor_to_shape expected argument to be 1-d, " + << "but " << call << " has argument " << call->args[0] << " with type " << call->args[0]->ty; + + if (tensor_ty->shape.defined()) { + ShapeExpr shape_expr = Downcast(tensor_ty->shape.value()); const IntImmNode* ndim = shape_expr->values[0].as(); if (ndim) { - return ShapeStructInfo(ndim->value); + return ShapeType(ndim->value); } } - return ShapeStructInfo(kUnknownNDim); + return ShapeType(kUnknownNDim); } TVM_REGISTER_OP("relax.tensor_to_shape") .set_num_inputs(1) .add_argument("input", "Expr", "The input expression") - .set_attr("FInferStructInfo", ReturnTensorToShapeStructInfo) + .set_attr("FInferType", ReturnTensorToShapeType) .set_attr("FPurity", true); Expr MakeTensorToShape(Expr expr) { @@ -1187,19 +1176,19 @@ TVM_FFI_STATIC_INIT_BLOCK() { } // shape_to_tensor -StructInfo ReturnShapeToTensorStructInfo(const Call& call, const BlockBuilder& ctx) { +Type ReturnShapeToTensorType(const Call& call, const BlockBuilder& ctx) { TVM_FFI_ICHECK(call->args.size() == 1); - TVM_FFI_ICHECK(call->args[0]->struct_info_.defined()); - const auto* sinfo = GetStructInfoAs(call->args[0]); - TVM_FFI_ICHECK(sinfo); - int32_t ndim = sinfo->ndim; - return TensorStructInfo(ShapeExpr({PrimExpr(ndim)}), DataType::Int(64)); + TVM_FFI_ICHECK(call->args[0]->ty.defined()); + const auto* ty = GetTypeAs(call->args[0]); + TVM_FFI_ICHECK(ty); + int32_t ndim = ty->ndim; + return TensorType(ShapeExpr({PrimExpr(ndim)}), DataType::Int(64)); } TVM_REGISTER_OP("relax.shape_to_tensor") .set_num_inputs(1) .add_argument("input", "Expr", "The input expression") - .set_attr("FInferStructInfo", ReturnShapeToTensorStructInfo) + .set_attr("FInferType", ReturnShapeToTensorType) .set_attr("FCallPacked", "relax.run.shape_to_tensor") .set_attr("FPurity", true); @@ -1215,7 +1204,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { // alloc_tensor -StructInfo InferStructInfoAllocateTensor(const Call& call, const BlockBuilder& ctx) { +Type InferTypeAllocateTensor(const Call& call, const BlockBuilder& ctx) { TVM_FFI_ICHECK(call->args[0].as()) << "must be ShapeExpr, but got " << call->args[0]->GetTypeKey(); TVM_FFI_ICHECK(call->args[1].as()) @@ -1232,9 +1221,9 @@ StructInfo InferStructInfoAllocateTensor(const Call& call, const BlockBuilder& c auto vdevice = GetGlobalVDevice(ctx->GetContextIRModule(), vdevice_index); if (vdevice.defined()) { - return TensorStructInfo(call->args[0], out_dtype, vdevice.value()); + return TensorType(call->args[0], out_dtype, vdevice.value()); } - return TensorStructInfo(call->args[0], out_dtype); + return TensorType(call->args[0], out_dtype); } TVM_REGISTER_OP("relax.builtin.alloc_tensor") @@ -1246,7 +1235,7 @@ TVM_REGISTER_OP("relax.builtin.alloc_tensor") "allocated at runtime. Index -1 is reserved for the host device.") .add_argument("storage_scope", "StringImm", "The storage scope of the storage to allocate. Default is global.") - .set_attr("FInferStructInfo", InferStructInfoAllocateTensor) + .set_attr("FInferType", InferTypeAllocateTensor) // memory allocation isn't considered a "visible effect" as far as purity is concerned .set_attr("FPurity", true) .set_attr("TAllocator", true); @@ -1274,7 +1263,7 @@ TVM_REGISTER_OP("relax.memory.alloc_storage") .add_argument("storage_scope", "StringImm", "The storage scope of the storage to allocate. Default is global.") .add_argument("dtype", "DataTypeImm", "The dtype of the tensor to allocate.") - .set_attr("FInferStructInfo", ReturnObjectStructInfo) + .set_attr("FInferType", ReturnObjectType) // memory allocation isn't considered a "visible effect" as far as purity is concerned .set_attr("FPurity", true) .set_attr("TAllocator", true); @@ -1292,9 +1281,9 @@ TVM_FFI_STATIC_INIT_BLOCK() { // memory planning alloc_tensor -StructInfo InferStructInfoMemAllocTensor(const Call& call, const BlockBuilder& ctx) { - TVM_FFI_ICHECK(GetStructInfoAs(call->args[2])) - << "must be a Expr of ShapeStructInfo, but got " << call->args[1]->GetTypeKey(); +Type InferTypeMemAllocTensor(const Call& call, const BlockBuilder& ctx) { + TVM_FFI_ICHECK(GetTypeAs(call->args[2])) + << "must be a Expr of ShapeType, but got " << call->args[1]->GetTypeKey(); DataType out_dtype; if (const auto* dtype_node = call->args[3].as()) { const DataTypeImm dtype_imm = ffi::GetRef(dtype_node); @@ -1308,11 +1297,11 @@ StructInfo InferStructInfoMemAllocTensor(const Call& call, const BlockBuilder& c } auto vdevice = GetGlobalVDevice(ctx->GetContextIRModule(), vdevice_index); if (vdevice.defined()) { - return TensorStructInfo(call->args[2], out_dtype, vdevice.value()); + return TensorType(call->args[2], out_dtype, vdevice.value()); } } - return TensorStructInfo(call->args[2], out_dtype); + return TensorType(call->args[2], out_dtype); } TVM_REGISTER_OP("relax.memory.alloc_tensor") @@ -1324,7 +1313,7 @@ TVM_REGISTER_OP("relax.memory.alloc_tensor") .add_argument("runtime_device_index", "PrimValue", "The device index indicating on which device the tensor is to be " "allocated at runtime. Index -1 is reserved for the host device.") - .set_attr("FInferStructInfo", InferStructInfoMemAllocTensor) + .set_attr("FInferType", InferTypeMemAllocTensor) // memory allocation isn't considered a "visible effect" as far as purity is concerned .set_attr("FPurity", true) .set_attr("TAllocator", true); @@ -1356,7 +1345,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { TVM_REGISTER_OP("relax.memory.kill_storage") .set_num_inputs(1) .add_argument("storage", "Expr", "The storage to be killed.") - .set_attr("FInferStructInfo", ReturnVoidStructInfo) + .set_attr("FInferType", ReturnVoidType) // We mark this as impure so it wouldn't be removed by "remove_all_unused" .set_attr("FPurity", false); @@ -1375,7 +1364,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { TVM_REGISTER_OP("relax.memory.kill_tensor") .set_num_inputs(1) .add_argument("tensor", "Expr", "The tensor to be killed.") - .set_attr("FInferStructInfo", ReturnVoidStructInfo) + .set_attr("FInferType", ReturnVoidType) // We mark this as impure so it wouldn't be removed by "remove_all_unused" .set_attr("FPurity", false); @@ -1400,7 +1389,7 @@ TVM_REGISTER_OP("relax.vm.alloc_storage") "to be allocated at runtime.") .add_argument("storage_scope", "StringImm", "The storage scope of the storage to allocate. Default is global.") - .set_attr("FInferStructInfo", ReturnObjectStructInfo) + .set_attr("FInferType", ReturnObjectType) // memory allocation isn't considered a "visible effect" as far as purity is concerned .set_attr("FPurity", true) .set_attr("TAllocator", true); @@ -1418,7 +1407,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { // vm alloc_tensor -StructInfo InferStructInfoVMAllocTensor(const Call& call, const BlockBuilder& ctx) { +Type InferTypeVMAllocTensor(const Call& call, const BlockBuilder& ctx) { DataType out_dtype; if (const auto* dtype_node = call->args[3].as()) { const DataTypeImm dtype_imm = ffi::GetRef(dtype_node); @@ -1431,15 +1420,15 @@ StructInfo InferStructInfoVMAllocTensor(const Call& call, const BlockBuilder& ct auto vdevice = GetGlobalVDevice(ctx->GetContextIRModule(), vdevice_index); if (const auto* output_shape = call->args[2].as()) { - return TensorStructInfo(ffi::GetRef(output_shape), out_dtype, vdevice); - } else if (const auto* shape_sinfo = GetStructInfoAs(call->args[2])) { - if (shape_sinfo->values.defined()) { - return TensorStructInfo(ShapeExpr(shape_sinfo->values.value()), out_dtype, vdevice); + return TensorType(ffi::GetRef(output_shape), out_dtype, vdevice); + } else if (const auto* shape_ty = GetTypeAs(call->args[2])) { + if (shape_ty->values.defined()) { + return TensorType(ShapeExpr(shape_ty->values.value()), out_dtype, vdevice); } else { - return TensorStructInfo(out_dtype, shape_sinfo->ndim, vdevice); + return TensorType(out_dtype, shape_ty->ndim, vdevice); } } - return TensorStructInfo(out_dtype, kUnknownNDim, vdevice); + return TensorType(out_dtype, kUnknownNDim, vdevice); } TVM_REGISTER_OP("relax.vm.alloc_tensor") @@ -1451,7 +1440,7 @@ TVM_REGISTER_OP("relax.vm.alloc_tensor") .add_argument("runtime_device_index", "PrimValue", "The device index indicating on which device the tensor is " "to be allocated at runtime.") - .set_attr("FInferStructInfo", InferStructInfoVMAllocTensor) + .set_attr("FInferType", InferTypeVMAllocTensor) // memory allocation isn't considered a "visible effect" as far as purity is concerned .set_attr("FPurity", true) .set_attr("TAllocator", true); @@ -1481,7 +1470,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { TVM_REGISTER_OP("relax.vm.kill_object") .set_num_inputs(1) .add_argument("obj", "Expr", "The object to be killed.") - .set_attr("FInferStructInfo", ReturnVoidStructInfo) + .set_attr("FInferType", ReturnVoidType) // We mark this as impure so it wouldn't be removed by "remove_all_unused" .set_attr("FPurity", false); @@ -1502,7 +1491,7 @@ TVM_REGISTER_OP("relax.vm.call_tir_dyn") .add_argument("func", "Expr", "The destination-passing-style function.") .add_argument("args", "Tuple", "The input arguments (list of tensors and last argument is ShapeExpr)") - .set_attr("FInferStructInfo", ReturnVoidStructInfo) + .set_attr("FInferType", ReturnVoidType) // "relax.vm.call_tir_dyn" works in an in-place way, which is impure. .set_attr("FPurity", false); @@ -1517,14 +1506,14 @@ TVM_FFI_STATIC_INIT_BLOCK() { } // builtin stop_lift_params -StructInfo InferStructInfoStopLiftParams(const Call& call, const BlockBuilder& ctx) { - return InferStructInfoUnaryArith(call, ctx); +Type InferTypeStopLiftParams(const Call& call, const BlockBuilder& ctx) { + return InferTypeUnaryArith(call, ctx); } TVM_REGISTER_OP("relax.builtin.stop_lift_params") .set_num_inputs(1) .add_argument("x", "Expr", "The input data") - .set_attr("FInferStructInfo", InferStructInfoStopLiftParams) + .set_attr("FInferType", InferTypeStopLiftParams) .set_attr("FPurity", true); Expr MakeStopLiftParams(Expr x) { @@ -1539,23 +1528,23 @@ TVM_FFI_STATIC_INIT_BLOCK() { // to_vdevice -StructInfo InferToVDeviceStructInfo(const Call& call, const BlockBuilder& ctx) { +Type InferToVDeviceType(const Call& call, const BlockBuilder& ctx) { TVM_FFI_ICHECK(call->args.size() == 1); - TVM_FFI_ICHECK(call->args[0]->struct_info_.defined()); - TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); + TVM_FFI_ICHECK(call->args[0]->ty.defined()); + TensorType data_ty = GetUnaryInputTensorType(call, ctx); auto attrs = call->attrs.as(); VDevice vdev = attrs->dst_vdevice; - if (data_sinfo->shape.defined()) { - return TensorStructInfo(data_sinfo->shape.value(), data_sinfo->dtype, vdev, data_sinfo->span); + if (data_ty->shape.defined()) { + return TensorType(data_ty->shape.value(), data_ty->dtype, vdev, data_ty->span); } - return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim, vdev, data_sinfo->span); + return TensorType(data_ty->dtype, data_ty->ndim, vdev, data_ty->span); } TVM_REGISTER_OP("relax.to_vdevice") .set_num_inputs(1) .set_attrs_type() .add_argument("data", "Expr", "The input expression to be copied") - .set_attr("FInferStructInfo", InferToVDeviceStructInfo) + .set_attr("FInferType", InferToVDeviceType) .set_attr("FPurity", true); Expr MakeToVDevice(Expr data, VDevice dst_vdev) { @@ -1572,18 +1561,18 @@ TVM_FFI_STATIC_INIT_BLOCK() { // hint_on_device -StructInfo InferHintOnDeviceStructInfo(const Call& call, const BlockBuilder& ctx) { +Type InferHintOnDeviceType(const Call& call, const BlockBuilder& ctx) { TVM_FFI_ICHECK(call->args.size() == 1); - TVM_FFI_ICHECK(call->args[0]->struct_info_.defined()); - TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); - return data_sinfo; + TVM_FFI_ICHECK(call->args[0]->ty.defined()); + TensorType data_ty = GetUnaryInputTensorType(call, ctx); + return data_ty; } TVM_REGISTER_OP("relax.hint_on_device") .set_num_inputs(1) .set_attrs_type() .add_argument("data", "Expr", "The input expression") - .set_attr("FInferStructInfo", InferHintOnDeviceStructInfo) + .set_attr("FInferType", InferHintOnDeviceType) .set_attr("FPurity", true); Expr MakeHintOnDevice(Expr data, Device device, ffi::String memory_scope = "global") { diff --git a/src/relax/op/op_common.cc b/src/relax/op/op_common.cc index 49a58db3bfea..f2d9a5ad986c 100644 --- a/src/relax/op/op_common.cc +++ b/src/relax/op/op_common.cc @@ -49,7 +49,7 @@ void CheckNumArguments(const Call& call, const BlockBuilder& ctx) { } } -TensorStructInfo GetInputTensorStructInfo(const Call& call, size_t i_arg, const BlockBuilder& ctx) { +TensorType GetInputTensorType(const Call& call, size_t i_arg, const BlockBuilder& ctx) { Op op = Downcast(call->op); TVM_FFI_ICHECK_EQ(op->arguments.size(), call->args.size()) @@ -58,54 +58,54 @@ TensorStructInfo GetInputTensorStructInfo(const Call& call, size_t i_arg, const TVM_FFI_ICHECK_LT(i_arg, op->arguments.size()); auto arg = call->args[i_arg]; - auto sinfo = GetStructInfo(arg); + auto ty = GetType(arg); - if (auto tensor_sinfo = sinfo.as()) { - return tensor_sinfo.value(); + if (auto tensor_ty = ty.as()) { + return tensor_ty.value(); } else { TVM_FFI_VISIT_THROW(TypeError, call) << "Operator " << op << " requires argument " << i_arg << " (" << op->arguments[i_arg]->name << ") to be a tensor. " - << "However, the argument " << arg << " is instead of type " << sinfo; + << "However, the argument " << arg << " is instead of type " << ty; // Unreachable, but [[noreturn]] attribute on virtual function // `ReportFatal` is insufficient to silence -Wreturn-type, as // child class might not be [[noreturn]]. - return TensorStructInfo(); + return TensorType(); } } -ffi::Array GetInputTensorStructInfo(const Call& call, const BlockBuilder& ctx) { +ffi::Array GetInputTensorType(const Call& call, const BlockBuilder& ctx) { CheckNumArguments(call, ctx); Op op = Downcast(call->op); - ffi::Array input_tensor_sinfo; + ffi::Array input_tensor_ty; for (size_t i = 0; i < call->args.size(); ++i) { - input_tensor_sinfo.push_back(GetInputTensorStructInfo(call, i, ctx)); + input_tensor_ty.push_back(GetInputTensorType(call, i, ctx)); } - return input_tensor_sinfo; + return input_tensor_ty; } -ffi::Array GetTensorStructInfoFromTuple(const Call& call, const BlockBuilder& ctx, - const Expr& tup) { - const auto* tuple_sinfo = GetStructInfoAs(tup); - if (tuple_sinfo == nullptr) { +ffi::Array GetTensorTypeFromTuple(const Call& call, const BlockBuilder& ctx, + const Expr& tup) { + const auto* tuple_ty = GetTypeAs(tup); + if (tuple_ty == nullptr) { TVM_FFI_VISIT_THROW(TypeError, call) << call->op << " expects the input to be a Tuple of Tensors. However, the given input is " - << tup->struct_info_->GetTypeKey(); + << tup->ty->GetTypeKey(); } - ffi::Array tensor_sinfo; - tensor_sinfo.reserve(tuple_sinfo->fields.size()); - for (StructInfo field_sinfo : tuple_sinfo->fields) { - const auto* field_tensor_sinfo = field_sinfo.as(); - if (field_tensor_sinfo == nullptr) { + ffi::Array tensor_ty; + tensor_ty.reserve(tuple_ty->fields.size()); + for (Type field_ty : tuple_ty->fields) { + const auto* field_tensor_ty = field_ty.as(); + if (field_tensor_ty == nullptr) { TVM_FFI_VISIT_THROW(TypeError, call) << call->op << " expects the input to be a Tuple of Tensors. However, the given input is " - << tup->struct_info_; + << tup->ty; } - tensor_sinfo.push_back(ffi::GetRef(field_tensor_sinfo)); + tensor_ty.push_back(ffi::GetRef(field_tensor_ty)); } - return tensor_sinfo; + return tensor_ty; } BinaryBroadcastShapeInferResult InferBinaryBroadcastShape(arith::AnalyzerObj* analyzer, diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h index 771bdad9a223..3cc389e53812 100644 --- a/src/relax/op/op_common.h +++ b/src/relax/op/op_common.h @@ -42,7 +42,7 @@ namespace tvm { namespace relax { -/************ Op input struct info getter ************/ +/************ Op input type getter ************/ /*! * \brief Check that the operator has @@ -57,76 +57,74 @@ namespace relax { void CheckNumArguments(const Call& call, const BlockBuilder& ctx); /*! - * \brief Get the tensor struct info of the operator input. + * \brief Get the tensor type of the operator input. * \param call The context Call to the operator. * \param i_arg The index of the argument to check * \param ctx The error reporting context. - * \return The tensor struct info of the argument + * \return The tensor type of the argument */ -TensorStructInfo GetInputTensorStructInfo(const Call& call, size_t i_arg, const BlockBuilder& ctx); +TensorType GetInputTensorType(const Call& call, size_t i_arg, const BlockBuilder& ctx); /*! - * \brief Get the tensor struct info of the operator input. + * \brief Get the tensor type of the operator input. * \param call The context Call to the operator. * \param ctx The error reporting context. - * \return The tensor struct info of each input. + * \return The tensor type of each input. * \note This function require every input to be Tensor. The number of call arguments is required * to match the number of inputs of the op being called. */ -ffi::Array GetInputTensorStructInfo(const Call& call, const BlockBuilder& ctx); +ffi::Array GetInputTensorType(const Call& call, const BlockBuilder& ctx); /*! - * \brief Get the tensor struct info of the unary operator input. + * \brief Get the tensor type of the unary operator input. * \param call The context Call to the operator. * \param ctx The error reporting context. - * \return The tensor struct info of the unary operator input. - * \throw Throw exception if the number of input is not one, or the struct info of the input is not - * a tensor struct info. + * \return The tensor type of the unary operator input. + * \throw Throw exception if the number of input is not one, or the type of the input is not + * a tensor type. */ -inline TensorStructInfo GetUnaryInputTensorStructInfo(const Call& call, const BlockBuilder& ctx) { - return GetInputTensorStructInfo(call, ctx)[0]; +inline TensorType GetUnaryInputTensorType(const Call& call, const BlockBuilder& ctx) { + return GetInputTensorType(call, ctx)[0]; } /*! - * \brief Get the tensor struct info of tuple input. + * \brief Get the tensor type of tuple input. * \param call The context Call to the operator. * \param ctx The error reporting context. * \param tup The input tuple. * \return The tensor struct infos of tuple input. * \throw Throw exception if input expression is not a tuple. */ -ffi::Array GetTensorStructInfoFromTuple(const Call& call, const BlockBuilder& ctx, - const Expr& tup); +ffi::Array GetTensorTypeFromTuple(const Call& call, const BlockBuilder& ctx, + const Expr& tup); namespace detail { -/*! \brief Implementation helper for GetArgStructInfo */ +/*! \brief Implementation helper for GetArgType */ template -ArgType GetArgStructInfoByIndex(const Call& call, const Op& op, const BlockBuilder& ctx, - size_t index) { - if (!call->args[index]->struct_info_.defined()) { +ArgType GetArgTypeByIndex(const Call& call, const Op& op, const BlockBuilder& ctx, size_t index) { + if (!call->args[index]->ty.defined()) { TVM_FFI_VISIT_THROW(InternalError, call) - << op << " op should have arguments with defined StructInfo. " - << "However, args[" << index << "] has undefined struct info."; + << op << " op should have arguments with defined Type. " + << "However, args[" << index << "] has undefined type."; } - auto sinfo = GetStructInfo(call->args[index]); - auto typed_sinfo = sinfo.as(); + auto ty = GetType(call->args[index]); + auto typed_ty = ty.as(); - if (!typed_sinfo.has_value()) { + if (!typed_ty.has_value()) { TVM_FFI_VISIT_THROW(TypeError, call) << op << " requires that args[" << index << "] be a " << ArgType::ContainerType::_type_key - << ", but was instead " << sinfo << " of type " << sinfo->GetTypeKey(); + << ", but was instead " << ty << " of type " << ty->GetTypeKey(); } - return typed_sinfo.value(); + return typed_ty.value(); } -/*! \brief Implementation helper for GetArgStructInfo */ +/*! \brief Implementation helper for GetArgType */ template -std::tuple GetArgStructInfoHelper(const Call& call, const Op& op, - const BlockBuilder& ctx, - std::index_sequence) { - return std::tuple{GetArgStructInfoByIndex(call, op, ctx, Indices)...}; +std::tuple GetArgTypeHelper(const Call& call, const Op& op, const BlockBuilder& ctx, + std::index_sequence) { + return std::tuple{GetArgTypeByIndex(call, op, ctx, Indices)...}; } } // namespace detail @@ -140,7 +138,7 @@ std::tuple GetArgStructInfoHelper(const Call& call, const Op& op, * \throw Throw exception if input expression is not a tuple. */ template -std::tuple GetArgStructInfo(const Call& call, const BlockBuilder& ctx) { +std::tuple GetArgType(const Call& call, const BlockBuilder& ctx) { Op op = Downcast(call->op); size_t n_input = op->arguments.size(); @@ -150,10 +148,10 @@ std::tuple GetArgStructInfo(const Call& call, const BlockBuilder& c TVM_FFI_ICHECK_EQ(n_input, sizeof...(ArgTypes)) << "Internal error: " << op << " op defines " << n_input << " arguments in its TVM_REGISTER_OP() call, " - << "but GetArgStructInfo was given " << sizeof...(ArgTypes) << " template arguments."; + << "but GetArgType was given " << sizeof...(ArgTypes) << " template arguments."; - return detail::GetArgStructInfoHelper( - call, op, ctx, std::make_index_sequence()); + return detail::GetArgTypeHelper(call, op, ctx, + std::make_index_sequence()); } /************ Op registration macro ************/ @@ -189,50 +187,48 @@ std::tuple GetArgStructInfo(const Call& call, const BlockBuilder& c /************ Utilities ************/ /*! - * \brief Infer the struct info for unary elementwise ops. + * \brief Infer the type for unary elementwise ops. * \param call The context Call to the operator. * \param ctx The error reporting context. * \param f_compute_out_dtype The function to compute the output dtype, with - * signature DataType f_compute_out_dtype(const TensorStructInfo& input_sinfo). + * signature DataType f_compute_out_dtype(const TensorType& input_ty). * \tparam require_float_dtype whether this op requires the input dtype to be float * \tparam Ftype the type of f_compute_out_dtype - * \return The inferred struct info. + * \return The inferred type. */ template -inline StructInfo InferStructInfoUnary(const Call& call, const BlockBuilder& ctx, - FType f_compute_out_dtype) { - TensorStructInfo input_sinfo = GetUnaryInputTensorStructInfo(call, ctx); - if (require_float_dtype && !input_sinfo->IsUnknownDtype() && - (!input_sinfo->dtype.is_float() && !input_sinfo->dtype.is_bfloat())) { +inline Type InferTypeUnary(const Call& call, const BlockBuilder& ctx, FType f_compute_out_dtype) { + TensorType input_ty = GetUnaryInputTensorType(call, ctx); + if (require_float_dtype && !input_ty->IsUnknownDtype() && + (!input_ty->dtype.is_float() && !input_ty->dtype.is_bfloat())) { TVM_FFI_VISIT_THROW(TypeError, call) << call->op << " requires the input tensor to have float dtype. However, the given input dtype is " - << input_sinfo->dtype; + << input_ty->dtype; } - auto output_sinfo = ffi::make_object(*input_sinfo.get()); - output_sinfo->dtype = f_compute_out_dtype(input_sinfo); - if (call->sinfo_args.size() > 0) { - auto defined_sinfo = call->sinfo_args[0].as(); - TVM_FFI_ICHECK(defined_sinfo); - auto shape = output_sinfo->GetShape(); + auto output_ty = ffi::make_object(*input_ty.get()); + output_ty->dtype = f_compute_out_dtype(input_ty); + if (call->ty_args.size() > 0) { + auto defined_ty = call->ty_args[0].as(); + TVM_FFI_ICHECK(defined_ty); + auto shape = output_ty->GetShape(); TVM_FFI_ICHECK(shape.defined()); - TVM_FFI_ICHECK(defined_sinfo->vdevice.has_value()); - return TensorStructInfo(ShapeExpr(shape.value()), output_sinfo->dtype, - defined_sinfo->vdevice.value()); + TVM_FFI_ICHECK(defined_ty->vdevice.has_value()); + return TensorType(ShapeExpr(shape.value()), output_ty->dtype, defined_ty->vdevice.value()); } else { - return TensorStructInfo(output_sinfo); + return TensorType(output_ty); } } /*! - * \brief Infer the struct info by returning the struct info of the input argument. + * \brief Infer the type by returning the type of the input argument. * \param call The context Call to the operator. * \param ctx The error reporting context. * \tparam arg_index The index of the argument to infer the output dtype from. - * \return The inferred struct info. + * \return The inferred type. */ template -StructInfo ReturnStructInfoFromArg(const Call& call, const BlockBuilder& ctx) { +Type ReturnTypeFromArg(const Call& call, const BlockBuilder& ctx) { Op op = Downcast(call->op); int n_input = op->arguments.size(); if (static_cast(call->args.size()) != n_input) { @@ -243,21 +239,21 @@ StructInfo ReturnStructInfoFromArg(const Call& call, const BlockBuilder& ctx) { << op << " op has only " << n_input << "arguments, but try to get the arg with index " << arg_index; } - return GetStructInfo(call->args[arg_index]); + return GetType(call->args[arg_index]); } /*! - * \brief Infer the struct info for unary arithmetic elementwise ops. It's also + * \brief Infer the type for unary arithmetic elementwise ops. It's also * used in some NN operators. * \param call The context Call to the operator. * \param ctx The error reporting context. * \tparam require_float_dtype whether this op requires the input dtype to be float - * \return The inferred struct info. + * \return The inferred type. */ template -StructInfo InferStructInfoUnaryArith(const Call& call, const BlockBuilder& ctx) { - return InferStructInfoUnary( - call, ctx, [](const TensorStructInfo& input_sinfo) { return input_sinfo->dtype; }); +Type InferTypeUnaryArith(const Call& call, const BlockBuilder& ctx) { + return InferTypeUnary( + call, ctx, [](const TensorType& input_ty) { return input_ty->dtype; }); } /*! @@ -272,22 +268,22 @@ InferLayoutOutput InferLayoutUnaryEwise( const VarLayoutMap& var_layout_map); /*! - * \brief Get the element dtype from StructInfo + * \brief Get the element dtype from Type * - * \param sinfo The StructInfo to expect + * \param ty The Type to expect * \return The inferred element dtype. - * \throw Throw exception if the StructInfo doesn't have an element type. + * \throw Throw exception if the Type doesn't have an element type. */ -inline std::optional GetElementDType(const StructInfo& sinfo) { - if (const auto* prim = sinfo.as()) { +inline std::optional GetElementDType(const Type& ty) { + if (const auto* prim = ty.as()) { return prim->dtype; - } else if (const auto* tensor = sinfo.as()) { + } else if (const auto* tensor = ty.as()) { return tensor->dtype; } else { return std::nullopt; - TVM_FFI_THROW(TypeError) << "Only PrimStructInfo and TensorStructInfo " + TVM_FFI_THROW(TypeError) << "Only PrimType and TensorType " << "have an associated data type. " - << "Cannot determine element type of " << sinfo; + << "Cannot determine element type of " << ty; } } @@ -295,31 +291,30 @@ inline std::optional GetElementDType(const StructInfo& sinfo) { * \brief Infer the output datatype for binary arithmetic operators. * \param call The context Call to the operator. * \param ctx The error reporting context. - * \param lhs_sinfo The struct info of the first operand - * \param rhs_sinfo The struct info of the second operand + * \param lhs_ty The type of the first operand + * \param rhs_ty The type of the second operand * \return The inferred output dtype. - * \throw Throw exception if the dtype of two input TensorStructInfo don’t match + * \throw Throw exception if the dtype of two input TensorType don’t match */ inline DataType InferBinaryArithOpOutDtype(const Call& call, const BlockBuilder& ctx, - const StructInfo& lhs_sinfo, - const StructInfo& rhs_sinfo) { - auto opt_lhs_dtype = GetElementDType(lhs_sinfo); + const Type& lhs_ty, const Type& rhs_ty) { + auto opt_lhs_dtype = GetElementDType(lhs_ty); if (!opt_lhs_dtype) { TVM_FFI_VISIT_THROW(TypeError, call) << "Binary operators must have the same datatype for both operands. " - << "However, " << call << " has argument " << call->args[0] - << " on the LHS, with struct info " << lhs_sinfo << ". This is of type " - << lhs_sinfo->GetTypeKey() << ", which does not have a datatype."; + << "However, " << call << " has argument " << call->args[0] << " on the LHS, with type " + << lhs_ty << ". This is of type " << lhs_ty->GetTypeKey() + << ", which does not have a datatype."; } auto lhs_dtype = opt_lhs_dtype.value(); - auto opt_rhs_dtype = GetElementDType(rhs_sinfo); + auto opt_rhs_dtype = GetElementDType(rhs_ty); if (!opt_rhs_dtype) { TVM_FFI_VISIT_THROW(TypeError, call) << "Binary operators must have the same datatype for both operands. " - << "However, " << call << " has argument " << call->args[1] - << " on the RHS, with struct info " << rhs_sinfo << ". This is of type " - << rhs_sinfo->GetTypeKey() << ", which does not have a datatype."; + << "However, " << call << " has argument " << call->args[1] << " on the RHS, with type " + << rhs_ty << ". This is of type " << rhs_ty->GetTypeKey() + << ", which does not have a datatype."; } auto rhs_dtype = opt_rhs_dtype.value(); @@ -328,9 +323,8 @@ inline DataType InferBinaryArithOpOutDtype(const Call& call, const BlockBuilder& } else if (lhs_dtype != rhs_dtype && !lhs_dtype.is_bool() && !rhs_dtype.is_bool()) { TVM_FFI_VISIT_THROW(TypeError, call) << "Binary operators must have the same datatype for both operands. " - << "However, " << call << " uses datatype " << lhs_dtype << " on the LHS (StructInfo of " - << lhs_sinfo << "), and datatype " << rhs_dtype << " on the RHS (StructInfo of " - << rhs_sinfo << ")."; + << "However, " << call << " uses datatype " << lhs_dtype << " on the LHS (Type of " + << lhs_ty << "), and datatype " << rhs_dtype << " on the RHS (Type of " << rhs_ty << ")."; } return lhs_dtype; } @@ -339,17 +333,16 @@ inline DataType InferBinaryArithOpOutDtype(const Call& call, const BlockBuilder& * \brief Infer the output virtual device for binary arithmetic operators. * \param call The context Call to the operator. * \param ctx The error reporting context. - * \param lhs_sinfo The struct info of the first operand - * \param rhs_sinfo The struct info of the second operand + * \param lhs_ty The type of the first operand + * \param rhs_ty The type of the second operand * \return The inferred output vdevice. - * \throw Throw exception if the vdevice of two input TensorStructInfo don’t match + * \throw Throw exception if the vdevice of two input TensorType don’t match */ inline ffi::Optional InferBinaryArithOpOutVDevice(const Call& call, const BlockBuilder& ctx, - const StructInfo& lhs_sinfo, - const StructInfo& rhs_sinfo) { - auto get_vdevice = [&](const StructInfo& sinfo) -> ffi::Optional { - if (const auto* tensor = sinfo.as()) { + const Type& lhs_ty, const Type& rhs_ty) { + auto get_vdevice = [&](const Type& ty) -> ffi::Optional { + if (const auto* tensor = ty.as()) { return tensor->vdevice; } else { return std::nullopt; @@ -361,12 +354,12 @@ inline ffi::Optional InferBinaryArithOpOutVDevice(const Call& call, * Like targets that supports mixed VDevices (like differed by memory_scope for Adreno) * and have specialized derivation for output VDevice. */ - if (call->sinfo_args.size() > 0) { - return get_vdevice(call->sinfo_args[0]); + if (call->ty_args.size() > 0) { + return get_vdevice(call->ty_args[0]); } - auto lhs_vdevice = get_vdevice(lhs_sinfo); - auto rhs_vdevice = get_vdevice(rhs_sinfo); + auto lhs_vdevice = get_vdevice(lhs_ty); + auto rhs_vdevice = get_vdevice(rhs_ty); if (!lhs_vdevice.defined() || !lhs_vdevice.value()->target.defined()) { return rhs_vdevice; @@ -578,24 +571,24 @@ inline std::pair CheckTensorLayout( } /*! - * \brief Check if the given tensor struct info has expected ndim per the given layout (or the ndim + * \brief Check if the given tensor type has expected ndim per the given layout (or the ndim * is unknown), and try to cast the shape to ShapeExpr. * \param call The context Call to the operator. * \param ctx The error reporting context. - * \param sinfo The input tensor struct info to be checked. + * \param ty The input tensor type to be checked. * \param layout The layout that the given tensor is expected to have. * \return The shape of the input tensor in ShapeExpr, or `std::nullopt` if the shape is unknown. */ inline ffi::Optional CheckNdimPerLayoutAndGetShape(const Call& call, const BlockBuilder& ctx, - const TensorStructInfo& sinfo, + const TensorType& ty, const tirx::SLayout& layout) { - if (!sinfo->IsUnknownNdim() && sinfo->ndim != static_cast(layout.ndim())) { + if (!ty->IsUnknownNdim() && ty->ndim != static_cast(layout.ndim())) { TVM_FFI_VISIT_THROW(ValueError, call) << "In " << call->op << ", layout " << layout << " requires the input to be " - << layout.ndim() << "-dim tensor. However, the given input has ndim " << sinfo->ndim; + << layout.ndim() << "-dim tensor. However, the given input has ndim " << ty->ndim; } - if (const auto* shape_expr = sinfo->shape.as()) { + if (const auto* shape_expr = ty->shape.as()) { return ffi::GetRef(shape_expr); } return std::nullopt; diff --git a/src/relax/op/tensor/binary.cc b/src/relax/op/tensor/binary.cc index 85c71641f4f1..844dd2cfefb6 100644 --- a/src/relax/op/tensor/binary.cc +++ b/src/relax/op/tensor/binary.cc @@ -33,8 +33,7 @@ namespace tvm { namespace relax { template -StructInfo InferStructInfoBroadcast(const Call& call, const BlockBuilder& ctx, - FType f_compute_out_dtype) { +Type InferTypeBroadcast(const Call& call, const BlockBuilder& ctx, FType f_compute_out_dtype) { Op op = Downcast(call->op); size_t n_input = op->arguments.size(); if (call->args.size() != n_input) { @@ -42,34 +41,30 @@ StructInfo InferStructInfoBroadcast(const Call& call, const BlockBuilder& ctx, << call->op << " op should have " << n_input << " arguments"; } - auto lhs_sinfo = GetStructInfo(call->args[0]); - auto rhs_sinfo = GetStructInfo(call->args[1]); + auto lhs_ty = GetType(call->args[0]); + auto rhs_ty = GetType(call->args[1]); - TVM_FFI_CHECK(lhs_sinfo.as() || lhs_sinfo.as(), - TypeError) + TVM_FFI_CHECK(lhs_ty.as() || lhs_ty.as(), TypeError) << "Arguments to binary operators must be either R.Tensor or R.Prim types, " - << "but expression " << call << " has LHS " << call->args[0] << ", which has StructInfo " - << lhs_sinfo; - TVM_FFI_CHECK(rhs_sinfo.as() || rhs_sinfo.as(), - TypeError) + << "but expression " << call << " has LHS " << call->args[0] << ", which has Type " << lhs_ty; + TVM_FFI_CHECK(rhs_ty.as() || rhs_ty.as(), TypeError) << "Arguments to binary operators must be either R.Tensor or R.Prim types, " - << "but expression " << call << " has RHS " << call->args[1] << ", which has StructInfo " - << rhs_sinfo; + << "but expression " << call << " has RHS " << call->args[1] << ", which has Type " << rhs_ty; // DateType - DataType output_dtype = f_compute_out_dtype(call, ctx, lhs_sinfo, rhs_sinfo); + DataType output_dtype = f_compute_out_dtype(call, ctx, lhs_ty, rhs_ty); - if (lhs_sinfo.as() && rhs_sinfo.as()) { - return PrimStructInfo(output_dtype); + if (lhs_ty.as() && rhs_ty.as()) { + return PrimType(output_dtype); } // VDevice - ffi::Optional vdevice = InferBinaryArithOpOutVDevice(call, ctx, lhs_sinfo, rhs_sinfo); + ffi::Optional vdevice = InferBinaryArithOpOutVDevice(call, ctx, lhs_ty, rhs_ty); - auto get_ndim = [&](const StructInfo& sinfo) -> int { - if (sinfo.as()) { + auto get_ndim = [&](const Type& ty) -> int { + if (ty.as()) { return 1; - } else if (const auto* tensor = sinfo.as()) { + } else if (const auto* tensor = ty.as()) { return tensor->ndim; } else { return kUnknownNDim; @@ -78,8 +73,8 @@ StructInfo InferStructInfoBroadcast(const Call& call, const BlockBuilder& ctx, // ndims int output_ndim = [&]() { - int lhs_ndim = get_ndim(lhs_sinfo); - int rhs_ndim = get_ndim(rhs_sinfo); + int lhs_ndim = get_ndim(lhs_ty); + int rhs_ndim = get_ndim(rhs_ty); if (lhs_ndim == kUnknownNDim || rhs_ndim == kUnknownNDim) { return kUnknownNDim; } else { @@ -89,10 +84,10 @@ StructInfo InferStructInfoBroadcast(const Call& call, const BlockBuilder& ctx, // Shapes - auto get_shape = [](const StructInfo& sinfo) -> ffi::Optional> { - if (sinfo.as()) { + auto get_shape = [](const Type& ty) -> ffi::Optional> { + if (ty.as()) { return ffi::Array{IntImm::Int64(1)}; - } else if (const auto* tensor = sinfo.as()) { + } else if (const auto* tensor = ty.as()) { return tensor->GetShape(); } else { return std::nullopt; @@ -101,19 +96,19 @@ StructInfo InferStructInfoBroadcast(const Call& call, const BlockBuilder& ctx, // If both inputs have a known shape, directly infer the shape of // the output. - auto lhs_shape = get_shape(lhs_sinfo); - auto rhs_shape = get_shape(rhs_sinfo); + auto lhs_shape = get_shape(lhs_ty); + auto rhs_shape = get_shape(rhs_ty); if (lhs_shape && rhs_shape) { ffi::Optional> output_shape = InferBinaryBroadcastShape(call, ctx, lhs_shape.value(), rhs_shape.value()); if (output_shape.defined()) { TVM_FFI_ICHECK_EQ(static_cast(output_shape.value().size()), output_ndim); - return TensorStructInfo(ShapeExpr(output_shape.value()), output_dtype, vdevice); + return TensorType(ShapeExpr(output_shape.value()), output_dtype, vdevice); } } - auto get_shape_expr = [](const StructInfo& sinfo) -> ffi::Optional { - if (const auto* tensor = sinfo.as()) { + auto get_shape_expr = [](const Type& ty) -> ffi::Optional { + if (const auto* tensor = ty.as()) { return tensor->shape; } else { return std::nullopt; @@ -121,28 +116,27 @@ StructInfo InferStructInfoBroadcast(const Call& call, const BlockBuilder& ctx, }; // If the input shape is unknown, but both inputs have the same - // `ShapeStructInfo`variable for their shape, then propagate that + // `ShapeType`variable for their shape, then propagate that // variable to the output. - auto lhs_shape_expr = get_shape_expr(lhs_sinfo); - auto rhs_shape_expr = get_shape_expr(rhs_sinfo); + auto lhs_shape_expr = get_shape_expr(lhs_ty); + auto rhs_shape_expr = get_shape_expr(rhs_ty); if (lhs_shape_expr.defined() && lhs_shape_expr.same_as(rhs_shape_expr)) { - return TensorStructInfo(lhs_shape_expr.value(), output_dtype, vdevice); + return TensorType(lhs_shape_expr.value(), output_dtype, vdevice); } // If neither of those cases holds, then fall back to an unknown // shape with `output_ndim` dimensionality. - return TensorStructInfo(output_dtype, output_ndim, vdevice); + return TensorType(output_dtype, output_ndim, vdevice); } -StructInfo InferStructInfoBroadcastArith(const Call& call, const BlockBuilder& ctx) { - return InferStructInfoBroadcast(call, ctx, InferBinaryArithOpOutDtype); +Type InferTypeBroadcastArith(const Call& call, const BlockBuilder& ctx) { + return InferTypeBroadcast(call, ctx, InferBinaryArithOpOutDtype); } -StructInfo InferStructInfoBroadcastCMP(const Call& call, const BlockBuilder& ctx) { - return InferStructInfoBroadcast( - call, ctx, - [](const Call& call, const BlockBuilder& ctx, const StructInfo& lhs_sinfo, - const StructInfo& rhs_sinfo) { return DataType::Bool(); }); +Type InferTypeBroadcastCMP(const Call& call, const BlockBuilder& ctx) { + return InferTypeBroadcast(call, ctx, + [](const Call& call, const BlockBuilder& ctx, const Type& lhs_ty, + const Type& rhs_ty) { return DataType::Bool(); }); } InferLayoutOutput InferLayoutBinaryEwise( @@ -152,14 +146,14 @@ InferLayoutOutput InferLayoutBinaryEwise( LayoutDecision layout1 = GetLayoutDecision(var_layout_map, call->args[0]); LayoutDecision layout2 = GetLayoutDecision(var_layout_map, call->args[1]); - auto* x1_sinfo = GetStructInfoAs(call->args[0]); - auto* x2_sinfo = GetStructInfoAs(call->args[1]); + auto* x1_ty = GetTypeAs(call->args[0]); + auto* x2_ty = GetTypeAs(call->args[1]); - TVM_FFI_ICHECK(!x1_sinfo->IsUnknownNdim() && !x2_sinfo->IsUnknownNdim()) + TVM_FFI_ICHECK(!x1_ty->IsUnknownNdim() && !x2_ty->IsUnknownNdim()) << "Unknown dim tensors should not be handled by this function"; - ffi::Optional shape1 = ffi::GetRef(x1_sinfo->shape.as()); - ffi::Optional shape2 = ffi::GetRef(x2_sinfo->shape.as()); + ffi::Optional shape1 = ffi::GetRef(x1_ty->shape.as()); + ffi::Optional shape2 = ffi::GetRef(x2_ty->shape.as()); // Lets handle sub indexing as long as primal dims are matching if ((layout1->layout.ndim() != layout1->layout.ndim_primal()) || (layout2->layout.ndim() != layout2->layout.ndim_primal())) { @@ -178,19 +172,19 @@ InferLayoutOutput InferLayoutBinaryEwise( } } - if (x1_sinfo->ndim <= x2_sinfo->ndim) { - if (x1_sinfo->ndim == 0) { + if (x1_ty->ndim <= x2_ty->ndim) { + if (x1_ty->ndim == 0) { LayoutDecision out_layout = layout2; return InferLayoutOutput({LayoutDecision(""), layout2}, {out_layout}, Attrs(call->attrs)); } - LayoutDecision out_layout = FollowDecision(layout1, x2_sinfo->ndim); + LayoutDecision out_layout = FollowDecision(layout1, x2_ty->ndim); return InferLayoutOutput({layout1, out_layout}, {out_layout}, Attrs(call->attrs)); } else { - if (x2_sinfo->ndim == 0) { + if (x2_ty->ndim == 0) { LayoutDecision out_layout = layout1; return InferLayoutOutput({layout1, LayoutDecision("")}, {out_layout}, Attrs(call->attrs)); } - LayoutDecision out_layout = FollowDecision(layout2, x1_sinfo->ndim); + LayoutDecision out_layout = FollowDecision(layout2, x1_ty->ndim); return InferLayoutOutput({out_layout, layout2}, {out_layout}, Attrs(call->attrs)); } diff --git a/src/relax/op/tensor/binary.h b/src/relax/op/tensor/binary.h index a234a30bc221..aadbc5c70ad0 100644 --- a/src/relax/op/tensor/binary.h +++ b/src/relax/op/tensor/binary.h @@ -53,13 +53,13 @@ namespace relax { .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) \ .set_attr("FPurity", true) -#define RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(OpName) \ - RELAX_REGISTER_BINARY_OP_AND_IMPL(OpName).set_attr( \ - "FInferStructInfo", InferStructInfoBroadcastArith) +#define RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(OpName) \ + RELAX_REGISTER_BINARY_OP_AND_IMPL(OpName).set_attr("FInferType", \ + InferTypeBroadcastArith) -#define RELAX_REGISTER_CMP_OP_AND_IMPL(OpName) \ - RELAX_REGISTER_BINARY_OP_AND_IMPL(OpName).set_attr( \ - "FInferStructInfo", InferStructInfoBroadcastCMP) +#define RELAX_REGISTER_CMP_OP_AND_IMPL(OpName) \ + RELAX_REGISTER_BINARY_OP_AND_IMPL(OpName).set_attr("FInferType", \ + InferTypeBroadcastCMP) /***************** Arithmetic operators *****************/ diff --git a/src/relax/op/tensor/create.cc b/src/relax/op/tensor/create.cc index a47cf9716b01..e5e56916cbc0 100644 --- a/src/relax/op/tensor/create.cc +++ b/src/relax/op/tensor/create.cc @@ -69,26 +69,26 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.full", full); } -StructInfo InferStructInfoFull(const Call& call, const BlockBuilder& ctx) { +Type InferTypeFull(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 2) { TVM_FFI_VISIT_THROW(ValueError, call) << "Full op should have 2 arguments"; } - const auto* shape_sinfo = GetStructInfoAs(call->args[0]); - const auto* fill_value_sinfo = GetStructInfoAs(call->args[1]); - if (shape_sinfo == nullptr) { + const auto* shape_ty = GetTypeAs(call->args[0]); + const auto* fill_value_ty = GetTypeAs(call->args[1]); + if (shape_ty == nullptr) { TVM_FFI_VISIT_THROW(TypeError, call) << "Full requires the input shape to be a Shape. However, the given one is " - << call->args[0]->struct_info_->GetTypeKey(); + << call->args[0]->ty->GetTypeKey(); } - if (fill_value_sinfo == nullptr || fill_value_sinfo->ndim != 0) { + if (fill_value_ty == nullptr || fill_value_ty->ndim != 0) { TVM_FFI_VISIT_THROW(ValueError, call) << "Full requires the input fill value to be zero rank Tensor. However, the given one is " - << call->args[1]->struct_info_; + << call->args[1]->ty; } const auto* attrs = call->attrs.as(); - DataType out_dtype = attrs->dtype.is_void() ? fill_value_sinfo->dtype : attrs->dtype; - return TensorStructInfo(/*shape=*/call->args[0], out_dtype, fill_value_sinfo->vdevice); + DataType out_dtype = attrs->dtype.is_void() ? fill_value_ty->dtype : attrs->dtype; + return TensorType(/*shape=*/call->args[0], out_dtype, fill_value_ty->vdevice); } TVM_REGISTER_OP("relax.full") @@ -96,7 +96,7 @@ TVM_REGISTER_OP("relax.full") .set_num_inputs(2) .add_argument("shape", "Shape", "The shape of the created tensor.") .add_argument("fill_value", "Tensor", "The scalar tensor, denoting the value to fill.") - .set_attr("FInferStructInfo", InferStructInfoFull) + .set_attr("FInferType", InferTypeFull) .set_attr("RequiresArgumentShapes", false) .set_attr("FDataDependent", true) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) @@ -115,23 +115,23 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.full_like", full_like); } -StructInfo InferStructInfoFullLike(const Call& call, const BlockBuilder& ctx) { - ffi::Array input_sinfo = GetInputTensorStructInfo(call, ctx); - TensorStructInfo data_sinfo = input_sinfo[0]; - TensorStructInfo fill_value_sinfo = input_sinfo[1]; - if (fill_value_sinfo->ndim != 0) { +Type InferTypeFullLike(const Call& call, const BlockBuilder& ctx) { + ffi::Array input_ty = GetInputTensorType(call, ctx); + TensorType data_ty = input_ty[0]; + TensorType fill_value_ty = input_ty[1]; + if (fill_value_ty->ndim != 0) { TVM_FFI_VISIT_THROW(ValueError, call) << "FullLike requires the input fill value to be zero " "rank Tensor. However, the given one has ndim" - << fill_value_sinfo->ndim; + << fill_value_ty->ndim; } const auto* attrs = call->attrs.as(); if (attrs->dtype.is_void()) { - return data_sinfo; + return data_ty; } else { - auto output_sinfo = ffi::make_object(*data_sinfo.get()); - output_sinfo->dtype = attrs->dtype; - return TensorStructInfo(output_sinfo); + auto output_ty = ffi::make_object(*data_ty.get()); + output_ty->dtype = attrs->dtype; + return TensorType(output_ty); } } @@ -140,36 +140,36 @@ TVM_REGISTER_OP("relax.full_like") .set_num_inputs(2) .add_argument("x", "Tensor", "The input tensor.") .add_argument("fill_value", "Tensor", "The scalar value to fill.") - .set_attr("FInferStructInfo", InferStructInfoFullLike) + .set_attr("FInferType", InferTypeFullLike) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", true); // Structure info inference for ones and zeros -StructInfo InferStructInfoOnesZeros(const Call& call, const BlockBuilder& ctx) { +Type InferTypeOnesZeros(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 1) { TVM_FFI_VISIT_THROW(ValueError, call) << "Ones/Zeros should have 1 argument"; } - const auto* shape_sinfo = GetStructInfoAs(call->args[0]); - if (shape_sinfo == nullptr) { + const auto* shape_ty = GetTypeAs(call->args[0]); + if (shape_ty == nullptr) { TVM_FFI_VISIT_THROW(TypeError, call) << "Ones/Zeros requires the input shape to be a Shape. However, the given one is " - << call->args[0]->struct_info_->GetTypeKey(); + << call->args[0]->ty->GetTypeKey(); } const auto* attrs = call->attrs.as(); - return TensorStructInfo(/*shape=*/call->args[0], attrs->dtype); + return TensorType(/*shape=*/call->args[0], attrs->dtype); } // Structure info inference for ones_like and zeros_like -StructInfo InferStructInfoOnesLikeZerosLike(const Call& call, const BlockBuilder& ctx) { - TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); +Type InferTypeOnesLikeZerosLike(const Call& call, const BlockBuilder& ctx) { + TensorType data_ty = GetUnaryInputTensorType(call, ctx); const auto* attrs = call->attrs.as(); if (attrs->dtype.is_void()) { - return data_sinfo; + return data_ty; } else { - auto output_sinfo = ffi::make_object(*data_sinfo.get()); - output_sinfo->dtype = attrs->dtype; - return TensorStructInfo(output_sinfo); + auto output_ty = ffi::make_object(*data_ty.get()); + output_ty->dtype = attrs->dtype; + return TensorType(output_ty); } } @@ -199,7 +199,7 @@ TVM_REGISTER_OP("relax.ones") .set_attrs_type() .set_num_inputs(1) .add_argument("shape", "Shape", "The shape of the created tensor.") - .set_attr("FInferStructInfo", InferStructInfoOnesZeros) + .set_attr("FInferType", InferTypeOnesZeros) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", true); @@ -207,7 +207,7 @@ TVM_REGISTER_OP("relax.ones_like") .set_attrs_type() .set_num_inputs(1) .add_argument("x", "Tensor", "The input tensor.") - .set_attr("FInferStructInfo", InferStructInfoOnesLikeZerosLike) + .set_attr("FInferType", InferTypeOnesLikeZerosLike) .set_attr("FPurity", true); /* relax.zeros & relax.zeros_like */ @@ -236,7 +236,7 @@ TVM_REGISTER_OP("relax.zeros") .set_attrs_type() .set_num_inputs(1) .add_argument("shape", "Shape", "The shape of the created tensor.") - .set_attr("FInferStructInfo", InferStructInfoOnesZeros) + .set_attr("FInferType", InferTypeOnesZeros) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", true); @@ -244,7 +244,7 @@ TVM_REGISTER_OP("relax.zeros_like") .set_attrs_type() .set_num_inputs(1) .add_argument("x", "Tensor", "The input tensor.") - .set_attr("FInferStructInfo", InferStructInfoOnesLikeZerosLike) + .set_attr("FInferType", InferTypeOnesLikeZerosLike) .set_attr("FPurity", true); /* relax.eye & relax.eye_like */ @@ -267,7 +267,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.eye", eye).def("relax.op.eye_like", eye_like); } -StructInfo InferStructInfoEye(const Call& call, const BlockBuilder& ctx) { +Type InferTypeEye(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 3) { TVM_FFI_VISIT_THROW(ValueError, call) << "Eye op should have 3 arguments: n, m, and k, but got " << call->args.size() << " arguments"; @@ -285,32 +285,32 @@ StructInfo InferStructInfoEye(const Call& call, const BlockBuilder& ctx) { PrimExpr m = get_prim_value(call->args[1], "m"); DataType dtype = call->attrs.as()->dtype; - return TensorStructInfo(ShapeExpr({n, m}), dtype); + return TensorType(ShapeExpr({n, m}), dtype); } -StructInfo InferStructInfoEyeLike(const Call& call, const BlockBuilder& ctx) { +Type InferTypeEyeLike(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 2) { TVM_FFI_VISIT_THROW(ValueError, call) << "Eye_like op should have 2 arguments: x and k, but got " << call->args.size() << " arguments"; } - const auto* x_sinfo = GetStructInfoAs(call->args[0]); - if (x_sinfo == nullptr) { + const auto* x_ty = GetTypeAs(call->args[0]); + if (x_ty == nullptr) { TVM_FFI_VISIT_THROW(TypeError, call) << "Eye_like expects the input `x` to be a Tensor, but got " - << call->args[0]->struct_info_->GetTypeKey(); + << call->args[0]->ty->GetTypeKey(); } - if (x_sinfo->ndim != 2 && x_sinfo->ndim != kUnknownNDim) { + if (x_ty->ndim != 2 && x_ty->ndim != kUnknownNDim) { TVM_FFI_VISIT_THROW(ValueError, call) - << "Eye_like expects the input tensor to be 2-dimensional, but got " << x_sinfo->ndim + << "Eye_like expects the input tensor to be 2-dimensional, but got " << x_ty->ndim << " dimensions"; } const auto* attrs = call->attrs.as(); - DataType out_dtype = attrs->dtype.is_void() ? x_sinfo->dtype : attrs->dtype; + DataType out_dtype = attrs->dtype.is_void() ? x_ty->dtype : attrs->dtype; - return TensorStructInfo(x_sinfo->shape.value(), out_dtype, x_sinfo->vdevice); + return TensorType(x_ty->shape.value(), out_dtype, x_ty->vdevice); } TVM_REGISTER_OP("relax.eye") @@ -319,7 +319,7 @@ TVM_REGISTER_OP("relax.eye") .add_argument("n", "PrimValue", "Number of rows in the output.") .add_argument("m", "PrimValue", "Number of columns in the output.") .add_argument("k", "PrimValue", "Index of the diagonal.") - .set_attr("FInferStructInfo", InferStructInfoEye) + .set_attr("FInferType", InferTypeEye) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", true); @@ -328,7 +328,7 @@ TVM_REGISTER_OP("relax.eye_like") .set_num_inputs(2) .add_argument("x", "Tensor", "The input tensor.") .add_argument("k", "PrimValue", "Index of the diagonal.") - .set_attr("FInferStructInfo", InferStructInfoEyeLike) + .set_attr("FInferType", InferTypeEyeLike) .set_attr("FPurity", true); /* relax.arange */ @@ -344,7 +344,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.arange", arange); } -StructInfo InferStructInfoArange(const Call& call, const BlockBuilder& ctx) { +Type InferTypeArange(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 3) { TVM_FFI_VISIT_THROW(ValueError, call) << "Arange should have 3 arguments, which are `start`, `end` and `step`, but got " @@ -371,7 +371,7 @@ StructInfo InferStructInfoArange(const Call& call, const BlockBuilder& ctx) { } arith::Analyzer analyzer; num_elem = analyzer->Simplify(num_elem); - return TensorStructInfo(ShapeExpr({num_elem}), dtype); + return TensorType(ShapeExpr({num_elem}), dtype); } TVM_REGISTER_OP("relax.arange") @@ -380,7 +380,7 @@ TVM_REGISTER_OP("relax.arange") .add_argument("start", "PrimValue", "The starting value for the set of points.") .add_argument("end", "PrimValue", "The ending value for the set of points.") .add_argument("step", "PrimValue", "The gap between each pair of adjacent points.") - .set_attr("FInferStructInfo", InferStructInfoArange) + .set_attr("FInferType", InferTypeArange) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", true); @@ -399,7 +399,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.hamming_window", hamming_window); } -StructInfo InferStructInfoHammingWindow(const Call& call, const BlockBuilder& ctx) { +Type InferTypeHammingWindow(const Call& call, const BlockBuilder& ctx) { DataType dtype = call->attrs.as()->dtype; if (dtype.is_int() || dtype.is_uint() || dtype.is_uint()) { TVM_FFI_VISIT_THROW(TypeError, call) @@ -421,7 +421,7 @@ StructInfo InferStructInfoHammingWindow(const Call& call, const BlockBuilder& ct << window_size; } window_size = analyzer->Simplify(window_size); - return TensorStructInfo(ShapeExpr({window_size}), dtype); + return TensorType(ShapeExpr({window_size}), dtype); } TVM_REGISTER_OP("relax.hamming_window") @@ -433,7 +433,7 @@ TVM_REGISTER_OP("relax.hamming_window") "symmetric window") .add_argument("alpha", "PrimValue", "The coefficient alpha") .add_argument("beta", "PrimValue", "The coefficient beta") - .set_attr("FInferStructInfo", InferStructInfoHammingWindow) + .set_attr("FInferType", InferTypeHammingWindow) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", true); @@ -460,30 +460,30 @@ TVM_FFI_STATIC_INIT_BLOCK() { .def("relax.op.triu", static_cast(triu)); } -StructInfo InferStructInfoTrilTriu(const Call& call, const BlockBuilder& ctx) { - auto [data_sinfo, offset] = GetArgStructInfo(call, ctx); +Type InferTypeTrilTriu(const Call& call, const BlockBuilder& ctx) { + auto [data_ty, offset] = GetArgType(call, ctx); - if (!data_sinfo->IsUnknownNdim() && data_sinfo->ndim < 2) { + if (!data_ty->IsUnknownNdim() && data_ty->ndim < 2) { TVM_FFI_VISIT_THROW(ValueError, call) << call->op << " requires the input tensor to have at least two " "dimensions. However, the given input has " - << data_sinfo->ndim << " dimension(s)."; + << data_ty->ndim << " dimension(s)."; } - return data_sinfo; + return data_ty; } TVM_REGISTER_OP("relax.tril") .set_num_inputs(2) .add_argument("x", "Tensor", "The input tensor.") .add_argument("k", "PrimValue", "The offset of the diagonal.") - .set_attr("FInferStructInfo", InferStructInfoTrilTriu) + .set_attr("FInferType", InferTypeTrilTriu) .set_attr("FPurity", true); TVM_REGISTER_OP("relax.triu") .set_num_inputs(2) .add_argument("x", "Tensor", "The input tensor.") .add_argument("k", "PrimValue", "The offset of the diagonal.") - .set_attr("FInferStructInfo", InferStructInfoTrilTriu) + .set_attr("FInferType", InferTypeTrilTriu) .set_attr("FPurity", true); } // namespace relax diff --git a/src/relax/op/tensor/datatype.cc b/src/relax/op/tensor/datatype.cc index 50624355c8fe..907dffb0b3f3 100644 --- a/src/relax/op/tensor/datatype.cc +++ b/src/relax/op/tensor/datatype.cc @@ -51,20 +51,19 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.astype", astype); } -StructInfo InferStructInfoAstype(const Call& call, const BlockBuilder& ctx) { - TensorStructInfo sinfo = GetUnaryInputTensorStructInfo(call, ctx); +Type InferTypeAstype(const Call& call, const BlockBuilder& ctx) { + TensorType ty = GetUnaryInputTensorType(call, ctx); const auto* attrs = call->attrs.as(); - ffi::ObjectPtr new_sinfo = - ffi::make_object(*sinfo.get()); - new_sinfo->dtype = attrs->dtype; - return TensorStructInfo(new_sinfo); + ffi::ObjectPtr new_ty = ffi::make_object(*ty.get()); + new_ty->dtype = attrs->dtype; + return TensorType(new_ty); } TVM_REGISTER_OP("relax.astype") .set_attrs_type() .set_num_inputs(1) .add_argument("x", "Tensor", "The input tensor") - .set_attr("FInferStructInfo", InferStructInfoAstype) + .set_attr("FInferType", InferTypeAstype) .set_attr("FRelaxInferLayout", InferLayoutUnaryEwise) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", true); @@ -84,20 +83,19 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.wrap_param", MakeWrapParam); } -StructInfo InferStructInfoWrapParam(const Call& call, const BlockBuilder& ctx) { - TensorStructInfo sinfo = GetUnaryInputTensorStructInfo(call, ctx); +Type InferTypeWrapParam(const Call& call, const BlockBuilder& ctx) { + TensorType ty = GetUnaryInputTensorType(call, ctx); const auto* attrs = call->attrs.as(); - ffi::ObjectPtr new_sinfo = - ffi::make_object(*sinfo.get()); - new_sinfo->dtype = attrs->dtype; - return TensorStructInfo(new_sinfo); + ffi::ObjectPtr new_ty = ffi::make_object(*ty.get()); + new_ty->dtype = attrs->dtype; + return TensorType(new_ty); } TVM_REGISTER_OP("relax.wrap_param") .set_attrs_type() .set_num_inputs(1) .add_argument("data", "Tensor", "The input tensor") - .set_attr("FInferStructInfo", InferStructInfoWrapParam) + .set_attr("FInferType", InferTypeWrapParam) .set_attr("FPurity", true); } // namespace relax diff --git a/src/relax/op/tensor/grad.cc b/src/relax/op/tensor/grad.cc index 35504360dfb1..ba788fb5860e 100644 --- a/src/relax/op/tensor/grad.cc +++ b/src/relax/op/tensor/grad.cc @@ -43,14 +43,12 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.grad.no_grad", no_grad); } -StructInfo InferStructInfoNoGrad(const Call& call, const BlockBuilder& ctx) { - return GetStructInfo(call->args[0]); -} +Type InferTypeNoGrad(const Call& call, const BlockBuilder& ctx) { return GetType(call->args[0]); } TVM_REGISTER_OP("relax.grad.no_grad") .set_num_inputs(1) .add_argument("x", "Expr", "The corresponding input tensor.") - .set_attr("FInferStructInfo", InferStructInfoNoGrad) + .set_attr("FInferType", InferTypeNoGrad) .set_attr("FPurity", true); /* relax.grad.start_checkpoint */ @@ -64,18 +62,18 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.grad.start_checkpoint", start_checkpoint); } -StructInfo InferStructInfoStartCheckpoint(const Call& call, const BlockBuilder& ctx) { +Type InferTypeStartCheckpoint(const Call& call, const BlockBuilder& ctx) { if (!call->args[0].as()) { TVM_FFI_VISIT_THROW(TypeError, call) << "The argument of relax.op.grad.start_checkpoint should be a Var."; } - return GetStructInfo(call->args[0]); + return GetType(call->args[0]); } TVM_REGISTER_OP("relax.grad.start_checkpoint") .set_num_inputs(1) .add_argument("x", "Expr", "The tensor marking the input of the checkpoint stage.") - .set_attr("FInferStructInfo", InferStructInfoStartCheckpoint) + .set_attr("FInferType", InferTypeStartCheckpoint) .set_attr("FPurity", true); /* relax.grad.end_checkpoint */ @@ -89,18 +87,18 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.grad.end_checkpoint", end_checkpoint); } -StructInfo InferStructInfoEndCheckpoint(const Call& call, const BlockBuilder& ctx) { +Type InferTypeEndCheckpoint(const Call& call, const BlockBuilder& ctx) { if (!call->args[0].as()) { TVM_FFI_VISIT_THROW(TypeError, call) << "The argument of relax.op.grad.end_checkpoint should be a Var."; } - return GetStructInfo(call->args[0]); + return GetType(call->args[0]); } TVM_REGISTER_OP("relax.grad.end_checkpoint") .set_num_inputs(1) .add_argument("x", "Expr", "The output of the checkpoint stage.") - .set_attr("FInferStructInfo", InferStructInfoEndCheckpoint) + .set_attr("FInferType", InferTypeEndCheckpoint) .set_attr("FPurity", true); /* relax.grad.nll_loss_backward */ @@ -127,8 +125,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.grad.nll_loss_backward", nll_loss_backward); } -StructInfo InferStructInfoNLLLossBackward(const Call& call, const BlockBuilder& ctx) { - return GetStructInfo(call->args[1]); +Type InferTypeNLLLossBackward(const Call& call, const BlockBuilder& ctx) { + return GetType(call->args[1]); } TVM_REGISTER_OP("relax.grad.nll_loss_backward") @@ -138,7 +136,7 @@ TVM_REGISTER_OP("relax.grad.nll_loss_backward") .add_argument("predictions", "Tensor", "The prediction tensor.") .add_argument("targets", "Tensor", "The target tensor.") .add_argument("weights", "ffi::Optional", "The weight of each target values.") - .set_attr("FInferStructInfo", InferStructInfoNLLLossBackward) + .set_attr("FInferType", InferTypeNLLLossBackward) .set_attr("FPurity", true); /* relax.grad.max_pool2d_backward */ @@ -164,8 +162,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.grad.max_pool2d_backward", max_pool2d_backward); } -StructInfo InferStructInfoMaxPool2DBackward(const Call& call, const BlockBuilder& ctx) { - return GetStructInfo(call->args[1]); +Type InferTypeMaxPool2DBackward(const Call& call, const BlockBuilder& ctx) { + return GetType(call->args[1]); } TVM_REGISTER_OP("relax.grad.max_pool2d_backward") @@ -173,7 +171,7 @@ TVM_REGISTER_OP("relax.grad.max_pool2d_backward") .add_argument("output_grad", "Tensor", "The output gradient.") .add_argument("data", "Tensor", "The input tensor") .set_attrs_type() - .set_attr("FInferStructInfo", InferStructInfoMaxPool2DBackward) + .set_attr("FInferType", InferTypeMaxPool2DBackward) .set_attr("FPurity", true); /* relax.grad.avg_pool2d_backward */ @@ -199,8 +197,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.grad.avg_pool2d_backward", avg_pool2d_backward); } -StructInfo InferStructInfoAvgPool2DBackward(const Call& call, const BlockBuilder& ctx) { - return GetStructInfo(call->args[1]); +Type InferTypeAvgPool2DBackward(const Call& call, const BlockBuilder& ctx) { + return GetType(call->args[1]); } TVM_REGISTER_OP("relax.grad.avg_pool2d_backward") @@ -208,7 +206,7 @@ TVM_REGISTER_OP("relax.grad.avg_pool2d_backward") .add_argument("output_grad", "Tensor", "The output gradient.") .add_argument("data", "Tensor", "The input tensor") .set_attrs_type() - .set_attr("FInferStructInfo", InferStructInfoAvgPool2DBackward) + .set_attr("FInferType", InferTypeAvgPool2DBackward) .set_attr("FPurity", true); /* relax.grad.take_backward */ @@ -226,8 +224,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.grad.take_backward", take_backward); } -StructInfo InferStructInfoTakeBackward(const Call& call, const BlockBuilder& ctx) { - return GetStructInfo(call->args[1]); +Type InferTypeTakeBackward(const Call& call, const BlockBuilder& ctx) { + return GetType(call->args[1]); } TVM_REGISTER_OP("relax.grad.take_backward") @@ -236,7 +234,7 @@ TVM_REGISTER_OP("relax.grad.take_backward") .add_argument("output_grad", "Tensor", "The output gradient.") .add_argument("x", "Tensor", "The source tensor.") .add_argument("indices", "Tensor", "The indices of the values to extract.") - .set_attr("FInferStructInfo", InferStructInfoTakeBackward) + .set_attr("FInferType", InferTypeTakeBackward) .set_attr("FPurity", true); } // namespace relax diff --git a/src/relax/op/tensor/index.cc b/src/relax/op/tensor/index.cc index c09dc1050107..ad9648c0d9fd 100644 --- a/src/relax/op/tensor/index.cc +++ b/src/relax/op/tensor/index.cc @@ -60,72 +60,70 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.take", take); } -StructInfo InferStructInfoTake(const Call& call, const BlockBuilder& ctx) { +Type InferTypeTake(const Call& call, const BlockBuilder& ctx) { CheckNumArguments(call, ctx); - TensorStructInfo data_sinfo = GetInputTensorStructInfo(call, 0, ctx); + TensorType data_ty = GetInputTensorType(call, 0, ctx); - // StructInfo inference when the index is a PrimValue is equivalent + // Type inference when the index is a PrimValue is equivalent // to that of a scalar (0-d) tensor. - TensorStructInfo indices_sinfo = [&]() { + TensorType indices_ty = [&]() { auto arg = call->args[1]; - auto sinfo = GetStructInfo(arg); - if (auto tensor_sinfo = sinfo.as()) { - return tensor_sinfo.value(); - } else if (auto prim_sinfo = sinfo.as()) { - return TensorStructInfo(ShapeExpr(ffi::Array{}), prim_sinfo->dtype); + auto ty = GetType(arg); + if (auto tensor_ty = ty.as()) { + return tensor_ty.value(); + } else if (auto prim_ty = ty.as()) { + return TensorType(ShapeExpr(ffi::Array{}), prim_ty->dtype); } else { TVM_FFI_VISIT_THROW(TypeError, call) << "Operator " << call->op << " requires the indices argument to be " << "either a tensor or a scalar value. " - << "However, argument " << arg << " has struct info " << sinfo; + << "However, argument " << arg << " has type " << ty; // Unreachable, but [[noreturn]] attribute on virtual function // `ReportFatal` is insufficient to silence -Wreturn-type, as // child class might not be [[noreturn]]. - return TensorStructInfo(); + return TensorType(); } }(); - if (indices_sinfo->IsUnknownDtype()) { + if (indices_ty->IsUnknownDtype()) { LOG(WARNING) << "Data type of indices has not been specified. Assume it has an integer type."; - } else if (!(indices_sinfo->dtype.is_int() || indices_sinfo->dtype.is_uint())) { + } else if (!(indices_ty->dtype.is_int() || indices_ty->dtype.is_uint())) { TVM_FFI_VISIT_THROW(TypeError, call) << "Take op requires the input indices to have integer dtype. However, the " "given indices dtype is " - << indices_sinfo->dtype; + << indices_ty->dtype; } const auto* attrs = call->attrs.as(); - if (!attrs->axis.has_value() && data_sinfo->ndim != 1) { + if (!attrs->axis.has_value() && data_ty->ndim != 1) { TVM_FFI_VISIT_THROW(ValueError, call) << "Take op expects the input data to be 1-dimensional tensor when the axis " "is not specified. However, the given data tensor has ndim " - << data_sinfo->ndim; + << data_ty->ndim; } - if (data_sinfo->IsUnknownNdim() || indices_sinfo->IsUnknownNdim()) { - return TensorStructInfo(data_sinfo->dtype, kUnknownNDim, data_sinfo->vdevice); + if (data_ty->IsUnknownNdim() || indices_ty->IsUnknownNdim()) { + return TensorType(data_ty->dtype, kUnknownNDim, data_ty->vdevice); } int axis = 0; if (attrs->axis.has_value()) { - axis = NormalizeAxis(call, ctx, data_sinfo->ndim, attrs->axis.value()); + axis = NormalizeAxis(call, ctx, data_ty->ndim, attrs->axis.value()); } - const auto* data_shape = data_sinfo->shape.as(); - const auto* indices_shape = indices_sinfo->shape.as(); + const auto* data_shape = data_ty->shape.as(); + const auto* indices_shape = indices_ty->shape.as(); if (data_shape == nullptr || indices_shape == nullptr) { - return TensorStructInfo(data_sinfo->dtype, indices_sinfo->ndim + data_sinfo->ndim - 1, - data_sinfo->vdevice); + return TensorType(data_ty->dtype, indices_ty->ndim + data_ty->ndim - 1, data_ty->vdevice); } ffi::Array output_shape; - for (int i = 0; i < data_sinfo->ndim; i++) { + for (int i = 0; i < data_ty->ndim; i++) { if (i == axis) { - for (int j = 0; j < indices_sinfo->ndim; j++) - output_shape.push_back(indices_shape->values[j]); + for (int j = 0; j < indices_ty->ndim; j++) output_shape.push_back(indices_shape->values[j]); } else { output_shape.push_back(data_shape->values[i]); } } - return TensorStructInfo(ShapeExpr(output_shape), data_sinfo->dtype, data_sinfo->vdevice); + return TensorType(ShapeExpr(output_shape), data_ty->dtype, data_ty->vdevice); } TVM_REGISTER_OP("relax.take") @@ -133,7 +131,7 @@ TVM_REGISTER_OP("relax.take") .set_num_inputs(2) .add_argument("x", "Tensor", "The source tensor.") .add_argument("indices", "Tensor", "The indices of the values to extract.") - .set_attr("FInferStructInfo", InferStructInfoTake) + .set_attr("FInferType", InferTypeTake) .set_attr("FPurity", true); /* relax.strided_slice */ @@ -141,8 +139,8 @@ TVM_REGISTER_OP("relax.take") Expr strided_slice(Expr x, Expr axes, Expr begin, Expr end, ffi::Optional strides, bool assume_inbound) { // Initial validation of the arguments. A more complete validation - // will be done when inferring the StructInfo, but that requires the - // StructInfo of all arguments to be populated. + // will be done when inferring the Type, but that requires the + // Type of all arguments to be populated. std::optional> known_length; auto check_tuple = [&known_length](const char* name, Expr expr) { @@ -189,35 +187,35 @@ TVM_FFI_STATIC_INIT_BLOCK() { * * A `relax::Tuple` may be provided to an operator as an in-line * expression, as a variable bound to known tuple within the current - * function, as a function argument, etc. The StructInfo of the tuple + * function, as a function argument, etc. The Type of the tuple * tracks the known values of any `PrimValue` elements, but it can be * tedious to extract. This utility extracts the `PrimExpr` contents * of a `relax::Tuple`. * - * If the StructInfo cannot contain a tuple of the type specified, + * If the Type cannot contain a tuple of the type specified, * this function will throw an exception. (e.g. Attempting to extract - * a tuple from a `TensorStructInfo`.) + * a tuple from a `TensorType`.) * * \tparam PrimType The subtype of PrimExpr to extract. For example, * extracting an `ffi::Array` * - * \param sinfo The StructInfo to inspect + * \param ty The Type to inspect * * \returns An array of the `PrimType`, if it can be extracted. * Otherwise, `std::nullopt`. */ template >> -ffi::Optional> UnpackTupleOfPrimValue(ffi::Optional sinfo) { - if (!sinfo) return std::nullopt; +ffi::Optional> UnpackTupleOfPrimValue(ffi::Optional ty) { + if (!ty) return std::nullopt; - // An ObjectStructInfo may contain a tuple of the desired type, but + // An ObjectType may contain a tuple of the desired type, but // it isn't yet known whether it does. Return early, as we cannot // provide a known `ffi::Array` to the caller. - if (sinfo.as()) return std::nullopt; + if (ty.as()) return std::nullopt; - auto tuple = sinfo.as(); - TVM_FFI_CHECK(tuple, TypeError) << "The struct info " << sinfo + auto tuple = ty.as(); + TVM_FFI_CHECK(tuple, TypeError) << "The type " << ty << " cannot contain a tuple whose elements are " << PrimType::ContainerType::_type_key; @@ -225,17 +223,16 @@ ffi::Optional> UnpackTupleOfPrimValue(ffi::Optionalfields.size(); i++) { auto field = tuple->fields[i]; - if (field.as()) return std::nullopt; + if (field.as()) return std::nullopt; - auto prim_sinfo = field.as(); - TVM_FFI_CHECK(prim_sinfo, TypeError) - << "The struct info " << sinfo << " cannot contain a tuple whose elements are " - << PrimType::ContainerType::_type_key << ", because element " << i << " has struct info " - << field; + auto prim_ty = field.as(); + TVM_FFI_CHECK(prim_ty, TypeError) + << "The type " << ty << " cannot contain a tuple whose elements are " + << PrimType::ContainerType::_type_key << ", because element " << i << " has type " << field; - if (!prim_sinfo->value.defined()) return std::nullopt; + if (!prim_ty->value.defined()) return std::nullopt; - ffi::Optional element = prim_sinfo->value.as(); + ffi::Optional element = prim_ty->value.as(); if (!element) return std::nullopt; output.push_back(element.value()); @@ -247,14 +244,14 @@ ffi::Optional> UnpackTupleOfPrimValue(ffi::Optional` @@ -268,13 +265,13 @@ template >> ffi::Optional> UnpackTupleOfPrimValue(ffi::Optional expr) { if (expr) { - return UnpackTupleOfPrimValue(GetStructInfo(expr.value())); + return UnpackTupleOfPrimValue(GetType(expr.value())); } else { return std::nullopt; } } -StructInfo InferStructInfoStridedSlice(const Call& call, const BlockBuilder& ctx) { +Type InferTypeStridedSlice(const Call& call, const BlockBuilder& ctx) { size_t n_args = call->args.size(); TVM_FFI_ICHECK(4 <= n_args && n_args <= 5) << "Operator " << call->op << " accepts either three arguments (data, axes, begin, end) " @@ -293,46 +290,45 @@ StructInfo InferStructInfoStridedSlice(const Call& call, const BlockBuilder& ctx } }(); - auto axes_sinfo = GetStructInfo(call->args[1]); - auto begin_sinfo = GetStructInfo(call->args[2]); - auto end_sinfo = GetStructInfo(call->args[3]); - auto strides_sinfo = [&]() -> ffi::Optional { + auto axes_ty = GetType(call->args[1]); + auto begin_ty = GetType(call->args[2]); + auto end_ty = GetType(call->args[3]); + auto strides_ty = [&]() -> ffi::Optional { if (n_args > 4) { - return GetStructInfo(call->args[4]); + return GetType(call->args[4]); } else { return std::nullopt; } }(); - TVM_FFI_ICHECK( - IsBaseOf(relax::TensorStructInfo(DataType::Void(), kUnknownNDim), GetStructInfo(data))) + TVM_FFI_ICHECK(IsBaseOf(relax::TensorType(DataType::Void(), kUnknownNDim), GetType(data))) << "Operator " << call->op << " requires the first argument to be a tensor. " - << "However, in expression " << call << ", the first argument " << data << " has struct info " - << GetStructInfo(data); + << "However, in expression " << call << ", the first argument " << data << " has type " + << GetType(data); // TODO(Lunderberg): Implement this check using `IsBaseOf`. Doing - // so will require a way to represent a `relax::TupleStructInfo` of - // unknown length, where each element has the same `StructInfo`. - auto is_base_of_tuple_of_int64 = [&](const StructInfo& sinfo) -> bool { - if (sinfo.as()) { + // so will require a way to represent a `relax::TupleType` of + // unknown length, where each element has the same `Type`. + auto is_base_of_tuple_of_int64 = [&](const Type& ty) -> bool { + if (ty.as()) { return true; } - const auto* tuple = sinfo.as(); + const auto* tuple = ty.as(); if (!tuple) return false; - return std::all_of(tuple->fields.begin(), tuple->fields.end(), [](const StructInfo& field) { - return IsBaseOf(relax::PrimStructInfo(DataType::Int(64)), field); + return std::all_of(tuple->fields.begin(), tuple->fields.end(), [](const Type& field) { + return IsBaseOf(relax::PrimType(DataType::Int(64)), field); }); }; auto check_tuple = [&](const char* name, Expr expr) { - auto sinfo = GetStructInfo(expr); + auto ty = GetType(expr); - TVM_FFI_ICHECK(is_base_of_tuple_of_int64(sinfo)) + TVM_FFI_ICHECK(is_base_of_tuple_of_int64(ty)) << "Operator " << call->op << " requires the " << name << " argument to be a tuple of int64 PrimValues. " << "However, in expression " << call << ", the " << name << " argument " << expr - << " has struct info " << sinfo; + << " has type " << ty; }; check_tuple("axes", call->args[1]); check_tuple("begin", call->args[2]); @@ -341,20 +337,20 @@ StructInfo InferStructInfoStridedSlice(const Call& call, const BlockBuilder& ctx check_tuple("strides", call->args[4]); } - const auto* data_sinfo = data->struct_info_.as(); + const auto* data_ty = data->ty.as(); DataType dtype = DataType::Void(); ffi::Optional vdevice = std::nullopt; int ndim = kUnknownNDim; - if (data_sinfo) { - dtype = data_sinfo->dtype; - vdevice = data_sinfo->vdevice; - ndim = data_sinfo->ndim; + if (data_ty) { + dtype = data_ty->dtype; + vdevice = data_ty->vdevice; + ndim = data_ty->ndim; } ffi::Optional shape = [&]() -> ffi::Optional { - if (!data_sinfo) return std::nullopt; - if (!data_sinfo->shape) return std::nullopt; + if (!data_ty) return std::nullopt; + if (!data_ty->shape) return std::nullopt; auto opt_axes_tuple = UnpackTupleOfPrimValue(axes); if (!opt_axes_tuple) return std::nullopt; @@ -397,10 +393,10 @@ StructInfo InferStructInfoStridedSlice(const Call& call, const BlockBuilder& ctx << "However, there are " << axes_tuple.size() << " axes specified (" << axes_tuple << ") and " << strides_tuple.size() << " strides specified (" << strides_tuple << ")"; - auto opt_data_shape = data_sinfo->GetShape(); + auto opt_data_shape = data_ty->GetShape(); if (axes_tuple.empty() && !opt_data_shape.defined()) { - return data_sinfo->shape.value(); + return data_ty->shape.value(); } else if (!opt_data_shape.defined()) { return std::nullopt; } @@ -408,10 +404,10 @@ StructInfo InferStructInfoStridedSlice(const Call& call, const BlockBuilder& ctx ffi::Array axes_tuple_i64; axes_tuple_i64.reserve(axes_tuple.size()); for (const IntImm& v : axes_tuple) axes_tuple_i64.push_back(v->value); - std::vector axes = NormalizeAxes(call, ctx, data_sinfo->ndim, axes_tuple_i64); + std::vector axes = NormalizeAxes(call, ctx, data_ty->ndim, axes_tuple_i64); auto attrs = call->attrs.as(); - ffi::Array output_shape = data_sinfo->GetShape().value(); + ffi::Array output_shape = data_ty->GetShape().value(); for (size_t i = 0; i < axes.size(); i++) { size_t axis = axes[i]; PrimExpr input_dim = output_shape[axis]; @@ -435,9 +431,9 @@ StructInfo InferStructInfoStridedSlice(const Call& call, const BlockBuilder& ctx }(); if (shape.defined()) { - return TensorStructInfo(shape.value(), dtype, vdevice); + return TensorType(shape.value(), dtype, vdevice); } else { - return TensorStructInfo(dtype, ndim, vdevice); + return TensorType(dtype, ndim, vdevice); } } @@ -449,19 +445,19 @@ InferLayoutOutput InferLayoutStridedSlice( const auto* attrs = call->attrs.as(); TVM_FFI_ICHECK(attrs != nullptr) << "Invalid Call"; - const auto* tensor_sinfo = GetStructInfoAs(call->args[0]); - TVM_FFI_ICHECK(tensor_sinfo) << "Invalid Call"; - TVM_FFI_ICHECK(!tensor_sinfo->IsUnknownNdim()) + const auto* tensor_ty = GetTypeAs(call->args[0]); + TVM_FFI_ICHECK(tensor_ty) << "Invalid Call"; + TVM_FFI_ICHECK(!tensor_ty->IsUnknownNdim()) << "Layout inference only supports known dimensionality, " << "but expression " << call << " has argument " << call->args[0] << " of unknown dimensionality."; LayoutDecision existing_layout = GetLayoutDecision(var_layout_map, call->args[0]); // Can't handle sub indexed layouts. if (existing_layout->layout.ndim() != existing_layout->layout.ndim_primal()) { - existing_layout = LayoutDecision(InitialLayout(tensor_sinfo->ndim)); + existing_layout = LayoutDecision(InitialLayout(tensor_ty->ndim)); } - auto opt_axes_tuple = UnpackTupleOfPrimValue(GetStructInfo(call->args[1])); + auto opt_axes_tuple = UnpackTupleOfPrimValue(GetType(call->args[1])); TVM_FFI_ICHECK(opt_axes_tuple) << "Layout inference of " << call->op << " requires slices to be along static axes. " << "However, expression " << call @@ -482,7 +478,7 @@ TVM_REGISTER_OP("relax.strided_slice") .set_attrs_type() .set_num_inputs(1) .add_argument("x", "Tensor", "The source tensor to be sliced.") - .set_attr("FInferStructInfo", InferStructInfoStridedSlice) + .set_attr("FInferType", InferTypeStridedSlice) .set_attr("FRelaxInferLayout", InferLayoutStridedSlice) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", true); @@ -501,32 +497,32 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.dynamic_strided_slice", dynamic_strided_slice); } -StructInfo InferStructInfoDynStridedSlice(const Call& call, const BlockBuilder& ctx) { - const auto* data_sinfo = GetStructInfoAs(call->args[0]); - const auto* begin_sinfo = GetStructInfoAs(call->args[1]); - const auto* end_sinfo = GetStructInfoAs(call->args[2]); - const auto* strides_sinfo = GetStructInfoAs(call->args[3]); +Type InferTypeDynStridedSlice(const Call& call, const BlockBuilder& ctx) { + const auto* data_ty = GetTypeAs(call->args[0]); + const auto* begin_ty = GetTypeAs(call->args[1]); + const auto* end_ty = GetTypeAs(call->args[2]); + const auto* strides_ty = GetTypeAs(call->args[3]); - TVM_FFI_ICHECK(data_sinfo); - if (data_sinfo->IsUnknownNdim()) { + TVM_FFI_ICHECK(data_ty); + if (data_ty->IsUnknownNdim()) { LOG(WARNING) << "When data rank is unknown, dynamic strided slice assumes begin/end/strides " "tensors are well-formed. It could produce runtime error when this assumption " "turns out to be wrong."; - return TensorStructInfo(data_sinfo->dtype, kUnknownNDim, data_sinfo->vdevice); + return TensorType(data_ty->dtype, kUnknownNDim, data_ty->vdevice); } - if (data_sinfo->IsUnknownDtype()) { + if (data_ty->IsUnknownDtype()) { LOG(WARNING) << "When data type is unknown, dynamic strided slice assumes to have a valid " "dtype. It could produce runtime error when this assumption " "turns out to be wrong."; } - int n_axis = data_sinfo->ndim; - auto diag_def = [&](const TensorStructInfoNode* sinfo, ffi::String name) { - TVM_FFI_ICHECK(sinfo) << "Dynamic strided slice requires the input " << name - << " to be have the struct info. Please try normalizing the inputs."; - TVM_FFI_ICHECK_EQ(sinfo->ndim, 1) + int n_axis = data_ty->ndim; + auto diag_def = [&](const TensorTypeNode* ty, ffi::String name) { + TVM_FFI_ICHECK(ty) << "Dynamic strided slice requires the input " << name + << " to be have the type. Please try normalizing the inputs."; + TVM_FFI_ICHECK_EQ(ty->ndim, 1) << "Dynamic strided slice requires " << name << " to be 1d tensor (list of values)."; - const auto* shape = sinfo->shape.as(); + const auto* shape = ty->shape.as(); TVM_FFI_ICHECK(shape) << "Dynamic strided slice requires the input " << name << " to have well-defined shape."; // NOTE(tvm-team): This strong restriction seems necessary for now until we have a generic @@ -537,23 +533,23 @@ StructInfo InferStructInfoDynStridedSlice(const Call& call, const BlockBuilder& TVM_FFI_ICHECK_EQ(num_elem->value, n_axis) << "Dynamic strided slice requires the number of indices in " << name << " to equal the number of axes."; - if (sinfo->IsUnknownDtype()) { + if (ty->IsUnknownDtype()) { LOG(WARNING) << "Dynamic strided slice assumes " << name << " to be int64 when it is not specified."; } else { - TVM_FFI_ICHECK(sinfo->dtype == DataType::Int(64)) + TVM_FFI_ICHECK(ty->dtype == DataType::Int(64)) << "Dynamic strided_slice expects the input " << name - << "values to be all int64. However, " << name << " has dtype " << sinfo->dtype << "."; + << "values to be all int64. However, " << name << " has dtype " << ty->dtype << "."; } }; - diag_def(begin_sinfo, "begin"); - diag_def(end_sinfo, "end"); - diag_def(strides_sinfo, "strides"); + diag_def(begin_ty, "begin"); + diag_def(end_ty, "end"); + diag_def(strides_ty, "strides"); // The output shape will depend on the runtime value in begin/end/strides tensors. // TODO(tvm-team): Currently, it is unable to express partially-static shape. Revisit when // PrimValue lands. - return TensorStructInfo(data_sinfo->dtype, n_axis, data_sinfo->vdevice); + return TensorType(data_ty->dtype, n_axis, data_ty->vdevice); } InferLayoutOutput InferLayoutDynStridedSlice( @@ -561,13 +557,13 @@ InferLayoutOutput InferLayoutDynStridedSlice( const VarLayoutMap& var_layout_map) { TVM_FFI_ICHECK(NoDesiredLayout(call, desired_layouts)); - const auto* tensor_sinfo = GetStructInfoAs(call->args[0]); - TVM_FFI_ICHECK(tensor_sinfo) << "Invalid Call"; - TVM_FFI_ICHECK(!tensor_sinfo->IsUnknownNdim()) + const auto* tensor_ty = GetTypeAs(call->args[0]); + TVM_FFI_ICHECK(tensor_ty) << "Invalid Call"; + TVM_FFI_ICHECK(!tensor_ty->IsUnknownNdim()) << "Layout inference only supports known dimensionality, " << "but expression " << call << " has argument " << call->args[0] << " of unknown dimensionality."; - int ndim = tensor_sinfo->ndim; + int ndim = tensor_ty->ndim; // Since begin/end/strides are dynamic tensors, we cannot transform // them at compile time. Fall back to the initial layout. LayoutDecision initial = LayoutDecision(InitialLayout(ndim)); @@ -580,7 +576,7 @@ TVM_REGISTER_OP("relax.dynamic_strided_slice") .add_argument("begin", "Tensor", "The indices to begin with in the slicing.") .add_argument("end", "Tensor", "Indices indicating end of the slice.") .add_argument("strides", "Tensor", "The stride values.") - .set_attr("FInferStructInfo", InferStructInfoDynStridedSlice) + .set_attr("FInferType", InferTypeDynStridedSlice) .set_attr("FRelaxInferLayout", InferLayoutDynStridedSlice) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", true) diff --git a/src/relax/op/tensor/inspect.cc b/src/relax/op/tensor/inspect.cc index 53ee3b18eafa..ebfbccf11ed2 100644 --- a/src/relax/op/tensor/inspect.cc +++ b/src/relax/op/tensor/inspect.cc @@ -36,54 +36,51 @@ namespace tvm { namespace relax { namespace inspect { -TensorStructInfo GetTensorArgInfo(const Call& call) { +TensorType GetTensorArgInfo(const Call& call) { TVM_FFI_CHECK_EQ(call->args.size(), 1, TypeError) << "Operator " << call->op << " expects one argument, " << "but received " << call->args.size() << " arguments: " << call->args; const auto& arg = call->args[0]; - auto sinfo = GetStructInfo(arg); + auto ty = GetType(arg); - auto tensor_sinfo = sinfo.as(); - TVM_FFI_CHECK(tensor_sinfo, TypeError) - << "Operator " << call->op << " expects a tensor argument, " - << "but argument " << arg << " has struct info " << sinfo; + auto tensor_ty = ty.as(); + TVM_FFI_CHECK(tensor_ty, TypeError) << "Operator " << call->op << " expects a tensor argument, " + << "but argument " << arg << " has type " << ty; - return tensor_sinfo.value(); + return tensor_ty.value(); } -std::tuple GetTensorArgInfoWithIndex(const Call& call) { +std::tuple GetTensorArgInfoWithIndex(const Call& call) { TVM_FFI_CHECK_EQ(call->args.size(), 2, TypeError) << "Operator " << call->op << " expects two arguments, " << "but received " << call->args.size() << " arguments: " << call->args; const auto& arg = call->args[0]; const auto& axis = call->args[1]; - auto tensor_sinfo = arg->struct_info_.as(); - TVM_FFI_CHECK(tensor_sinfo, TypeError) + auto tensor_ty = arg->ty.as(); + TVM_FFI_CHECK(tensor_ty, TypeError) << "Operator " << call->op << " expects arguments (tensor, axis), " - << "but the first argument " << arg << " in expression " << call << " has struct info " - << arg->struct_info_; + << "but the first argument " << arg << " in expression " << call << " has type " << arg->ty; - auto axis_sinfo = axis->struct_info_.as(); - TVM_FFI_CHECK(axis_sinfo, TypeError) + auto axis_ty = axis->ty.as(); + TVM_FFI_CHECK(axis_ty, TypeError) << "Operator " << call->op << " expects arguments (tensor, axis), " - << "but the second argument " << arg << " in expression " << call << " has struct info " - << axis->struct_info_; + << "but the second argument " << arg << " in expression " << call << " has type " << axis->ty; - auto int_imm_axis = axis_sinfo->value.as(); + auto int_imm_axis = axis_ty->value.as(); if (int_imm_axis) { TVM_FFI_ICHECK_GE(int_imm_axis->value, 0); } - if (int_imm_axis && !tensor_sinfo->IsUnknownNdim()) { - TVM_FFI_CHECK_LT(int_imm_axis->value, tensor_sinfo->ndim, ValueError) + if (int_imm_axis && !tensor_ty->IsUnknownNdim()) { + TVM_FFI_CHECK_LT(int_imm_axis->value, tensor_ty->ndim, ValueError) << "Expression " << call << " attempts to access " << arg << ".shape[" << int_imm_axis->value << "]" - << ", but " << arg << ".shape only has " << tensor_sinfo->ndim << " elements"; + << ", but " << arg << ".shape only has " << tensor_ty->ndim << " elements"; } - return {ffi::GetRef(tensor_sinfo), ffi::GetRef(axis_sinfo)}; + return {ffi::GetRef(tensor_ty), ffi::GetRef(axis_ty)}; } DataType GetTensorDataType(const Call& call) { return GetTensorArgInfo(call)->dtype; } @@ -100,19 +97,19 @@ tirx::PrimFunc GetDLTensorField(tirx::builtin::TVMStructFieldKind field, DataTyp DictAttrs attrs({{"tirx.is_scheduled", true}, {"tirx.is_host_func", true}}); - tirx::PrimFunc func(ffi::Array{dlpack_handle}, body, PrimType(field_dtype), {}, attrs); + tirx::PrimFunc func(ffi::Array{dlpack_handle}, body, tvm::PrimType(field_dtype), {}, + attrs); - FuncStructInfo sinfo({TensorStructInfo(DataType::Void(), kUnknownNDim)}, - PrimStructInfo(field_dtype)); - func->struct_info_ = sinfo; + FuncType ty({TensorType(DataType::Void(), kUnknownNDim)}, PrimType(field_dtype)); + func->ty = ty; return func; } Expr NormalizeToKnownPrimValue(const BlockBuilder&, Call call) { - if (auto prim_sinfo = call->struct_info_.as()) { - if (prim_sinfo->value.defined()) { - return PrimValue(prim_sinfo->value.value()); + if (auto prim_ty = call->ty.as()) { + if (prim_ty->value.defined()) { + return PrimValue(prim_ty->value.value()); } } return call; @@ -125,19 +122,19 @@ Expr tensor_dtype_code(Expr expr) { return Call(op, {expr}); } -StructInfo InferStructInfoTensorDtypeCode(const Call& call, const BlockBuilder&) { +Type InferTypeTensorDtypeCode(const Call& call, const BlockBuilder&) { auto dlpack_type = DataType::UInt(8); DataType dtype = GetTensorDataType(call); if (dtype.is_void()) { - return PrimStructInfo(dlpack_type); + return PrimType(dlpack_type); } else { - return PrimStructInfo(IntImm(dlpack_type, dtype.code())); + return PrimType(IntImm(dlpack_type, dtype.code())); } } Expr LegalizeTensorDtypeCode(const BlockBuilder& bb, const Call& call) { - auto field_dtype = Downcast(call->struct_info_)->dtype; + auto field_dtype = Downcast(call->ty)->dtype; Expr arg = call->args[0]; tirx::PrimFunc getter = @@ -150,7 +147,7 @@ Expr LegalizeTensorDtypeCode(const BlockBuilder& bb, const Call& call) { TVM_REGISTER_OP("relax.inspect.tensor_dtype_code") .set_num_inputs(1) .add_argument("tensor", "Tensor", "The tensor to be inspected") - .set_attr("FInferStructInfo", InferStructInfoTensorDtypeCode) + .set_attr("FInferType", InferTypeTensorDtypeCode) .set_attr("FLegalize", LegalizeTensorDtypeCode) .set_attr("RequiresArgumentShapes", false) .set_attr("FNormalize", NormalizeToKnownPrimValue) @@ -163,19 +160,19 @@ Expr tensor_dtype_bits(Expr expr) { return Call(op, {expr}); } -StructInfo InferStructInfoTensorDtypeBits(const Call& call, const BlockBuilder&) { +Type InferTypeTensorDtypeBits(const Call& call, const BlockBuilder&) { auto dlpack_type = DataType::UInt(8); DataType dtype = GetTensorDataType(call); if (dtype.is_void()) { - return PrimStructInfo(dlpack_type); + return PrimType(dlpack_type); } else { - return PrimStructInfo(IntImm(dlpack_type, dtype.bits())); + return PrimType(IntImm(dlpack_type, dtype.bits())); } } Expr LegalizeTensorDtypeBits(const BlockBuilder& bb, const Call& call) { - auto field_dtype = Downcast(call->struct_info_)->dtype; + auto field_dtype = Downcast(call->ty)->dtype; Expr arg = call->args[0]; tirx::PrimFunc getter = @@ -188,7 +185,7 @@ Expr LegalizeTensorDtypeBits(const BlockBuilder& bb, const Call& call) { TVM_REGISTER_OP("relax.inspect.tensor_dtype_bits") .set_num_inputs(1) .add_argument("tensor", "Tensor", "The tensor to be inspected") - .set_attr("FInferStructInfo", InferStructInfoTensorDtypeBits) + .set_attr("FInferType", InferTypeTensorDtypeBits) .set_attr("FLegalize", LegalizeTensorDtypeBits) .set_attr("RequiresArgumentShapes", false) .set_attr("FNormalize", NormalizeToKnownPrimValue) @@ -201,19 +198,19 @@ Expr tensor_dtype_lanes(Expr expr) { return Call(op, {expr}); } -StructInfo InferStructInfoTensorDtypeLanes(const Call& call, const BlockBuilder&) { +Type InferTypeTensorDtypeLanes(const Call& call, const BlockBuilder&) { auto dlpack_type = DataType::UInt(16); DataType dtype = GetTensorDataType(call); if (dtype.is_void()) { - return PrimStructInfo(dlpack_type); + return PrimType(dlpack_type); } else { - return PrimStructInfo(IntImm(dlpack_type, dtype.lanes())); + return PrimType(IntImm(dlpack_type, dtype.lanes())); } } Expr LegalizeTensorDtypeLanes(const BlockBuilder& bb, const Call& call) { - auto field_dtype = Downcast(call->struct_info_)->dtype; + auto field_dtype = Downcast(call->ty)->dtype; Expr arg = call->args[0]; tirx::PrimFunc getter = @@ -226,7 +223,7 @@ Expr LegalizeTensorDtypeLanes(const BlockBuilder& bb, const Call& call) { TVM_REGISTER_OP("relax.inspect.tensor_dtype_lanes") .set_num_inputs(1) .add_argument("tensor", "Tensor", "The tensor to be inspected") - .set_attr("FInferStructInfo", InferStructInfoTensorDtypeLanes) + .set_attr("FInferType", InferTypeTensorDtypeLanes) .set_attr("FLegalize", LegalizeTensorDtypeLanes) .set_attr("RequiresArgumentShapes", false) .set_attr("FNormalize", NormalizeToKnownPrimValue) @@ -239,19 +236,19 @@ Expr tensor_ndim(Expr expr) { return Call(op, {expr}); } -StructInfo InferStructInfoTensorNDim(const Call& call, const BlockBuilder&) { +Type InferTypeTensorNDim(const Call& call, const BlockBuilder&) { auto dlpack_type = DataType::Int(32); - auto sinfo = GetTensorArgInfo(call); - if (sinfo->IsUnknownNdim()) { - return PrimStructInfo(dlpack_type); + auto ty = GetTensorArgInfo(call); + if (ty->IsUnknownNdim()) { + return PrimType(dlpack_type); } else { - return PrimStructInfo(IntImm(dlpack_type, sinfo->ndim)); + return PrimType(IntImm(dlpack_type, ty->ndim)); } } Expr LegalizeTensorNDim(const BlockBuilder& bb, const Call& call) { - auto field_dtype = Downcast(call->struct_info_)->dtype; + auto field_dtype = Downcast(call->ty)->dtype; Expr arg = call->args[0]; tirx::PrimFunc getter = @@ -264,7 +261,7 @@ Expr LegalizeTensorNDim(const BlockBuilder& bb, const Call& call) { TVM_REGISTER_OP("relax.inspect.tensor_ndim") .set_num_inputs(1) .add_argument("tensor", "Tensor", "The tensor to be inspected") - .set_attr("FInferStructInfo", InferStructInfoTensorNDim) + .set_attr("FInferType", InferTypeTensorNDim) .set_attr("FLegalize", LegalizeTensorNDim) .set_attr("RequiresArgumentShapes", false) .set_attr("FNormalize", NormalizeToKnownPrimValue) @@ -277,23 +274,23 @@ Expr tensor_shape_i(Expr expr) { return Call(op, {expr}); } -StructInfo InferStructInfoTensorShape(const Call& call, const BlockBuilder&) { +Type InferTypeTensorShape(const Call& call, const BlockBuilder&) { auto dlpack_type = DataType::Int(64); - auto [tensor_sinfo, axis_sinfo] = GetTensorArgInfoWithIndex(call); + auto [tensor_ty, axis_ty] = GetTensorArgInfoWithIndex(call); - auto tensor_shape = tensor_sinfo->GetShape(); - auto int_imm_axis = axis_sinfo->value.as(); + auto tensor_shape = tensor_ty->GetShape(); + auto int_imm_axis = axis_ty->value.as(); if (int_imm_axis && tensor_shape.defined()) { - return PrimStructInfo(tensor_shape.value()[int_imm_axis->value]); + return PrimType(tensor_shape.value()[int_imm_axis->value]); } else { - return PrimStructInfo(dlpack_type); + return PrimType(dlpack_type); } } Expr LegalizeTensorShape(const BlockBuilder& bb, const Call& call) { - auto field_dtype = Downcast(call->struct_info_)->dtype; + auto field_dtype = Downcast(call->ty)->dtype; tirx::PrimFunc getter = [&]() -> tirx::PrimFunc { tirx::Var dlpack_handle("dlpack_handle", DataType::Handle()); @@ -325,12 +322,11 @@ Expr LegalizeTensorShape(const BlockBuilder& bb, const Call& call) { DictAttrs attrs({{"tirx.is_scheduled", true}, {"tirx.is_host_func", true}}); - tirx::PrimFunc func({dlpack_handle, axis}, body, PrimType(field_dtype), {}, attrs); + tirx::PrimFunc func({dlpack_handle, axis}, body, tvm::PrimType(field_dtype), {}, attrs); - FuncStructInfo sinfo( - {TensorStructInfo(DataType::Void(), kUnknownNDim), PrimStructInfo(axis->dtype)}, - PrimStructInfo(field_dtype)); - func->struct_info_ = sinfo; + FuncType ty({TensorType(DataType::Void(), kUnknownNDim), PrimType(axis->dtype)}, + PrimType(field_dtype)); + func->ty = ty; return func; }(); @@ -342,7 +338,7 @@ TVM_REGISTER_OP("relax.inspect.tensor_shape_i") .set_num_inputs(2) .add_argument("tensor", "Tensor", "The tensor to be inspected") .add_argument("axis", "Prim(int64)", "The axis whose extent should be returned") - .set_attr("FInferStructInfo", InferStructInfoTensorShape) + .set_attr("FInferType", InferTypeTensorShape) .set_attr("FLegalize", LegalizeTensorShape) .set_attr("RequiresArgumentShapes", false) .set_attr("FNormalize", NormalizeToKnownPrimValue) @@ -355,17 +351,17 @@ Expr tensor_stride_i(Expr expr) { return Call(op, {expr}); } -StructInfo InferStructInfoTensorStride(const Call& call, const BlockBuilder&) { +Type InferTypeTensorStride(const Call& call, const BlockBuilder&) { auto dlpack_type = DataType::Int(64); - auto [tensor_sinfo, axis_sinfo] = GetTensorArgInfoWithIndex(call); + auto [tensor_ty, axis_ty] = GetTensorArgInfoWithIndex(call); - auto opt_tensor_shape = tensor_sinfo->GetShape(); - auto int_imm_axis = axis_sinfo->value.as(); + auto opt_tensor_shape = tensor_ty->GetShape(); + auto int_imm_axis = axis_ty->value.as(); if (int_imm_axis && opt_tensor_shape.defined()) { // As of 2024-03-14, Relax does not have an explicit - // representation for striding in `TensorStructInfo`. The + // representation for striding in `TensorType`. The // `FLegalize` function for most operators is implemented in terms // of `topi`, and is then converted from TE to `tirx::PrimFunc` // using `tvm::tirx::CreatePrimFunc`. The `te::Tensor` is @@ -381,9 +377,9 @@ StructInfo InferStructInfoTensorStride(const Call& call, const BlockBuilder&) { for (size_t axis = int_imm_axis->value + 1; axis < tensor_shape.size(); axis++) { stride = stride * tensor_shape[axis]; } - return PrimStructInfo(stride); + return PrimType(stride); } else { - return PrimStructInfo(dlpack_type); + return PrimType(dlpack_type); } } @@ -391,7 +387,7 @@ TVM_REGISTER_OP("relax.inspect.tensor_stride_i") .set_num_inputs(2) .add_argument("tensor", "Tensor", "The tensor to be inspected") .add_argument("axis", "Prim(int64)", "The axis whose extent should be returned") - .set_attr("FInferStructInfo", InferStructInfoTensorStride) + .set_attr("FInferType", InferTypeTensorStride) .set_attr("RequiresArgumentShapes", false) .set_attr("FNormalize", NormalizeToKnownPrimValue) .set_attr("FPurity", true); @@ -403,26 +399,26 @@ Expr tensor_byte_offset(Expr expr) { return Call(op, {expr}); } -StructInfo InferStructInfoTensorByteOffset(const Call& call, const BlockBuilder&) { +Type InferTypeTensorByteOffset(const Call& call, const BlockBuilder&) { auto dlpack_type = DataType::UInt(64); - auto tensor_sinfo = GetTensorArgInfo(call); + auto tensor_ty = GetTensorArgInfo(call); - auto opt_tensor_shape = tensor_sinfo->GetShape(); + auto opt_tensor_shape = tensor_ty->GetShape(); if (opt_tensor_shape.defined()) { // Relax implicitly requires that the byte offset is zero for any - // legalizable tensor. See InferStructInfoTensorStride for full + // legalizable tensor. See InferTypeTensorStride for full // explanation. - return PrimStructInfo(IntImm(dlpack_type, 0)); + return PrimType(IntImm(dlpack_type, 0)); } else { - return PrimStructInfo(dlpack_type); + return PrimType(dlpack_type); } } TVM_REGISTER_OP("relax.inspect.tensor_byte_offset") .set_num_inputs(1) .add_argument("tensor", "Tensor", "The tensor to be inspected") - .set_attr("FInferStructInfo", InferStructInfoTensorByteOffset) + .set_attr("FInferType", InferTypeTensorByteOffset) .set_attr("RequiresArgumentShapes", false) .set_attr("FNormalize", NormalizeToKnownPrimValue) .set_attr("FPurity", true); @@ -434,26 +430,26 @@ Expr tensor_elem_offset(Expr expr) { return Call(op, {expr}); } -StructInfo InferStructInfoTensorElemOffset(const Call& call, const BlockBuilder&) { +Type InferTypeTensorElemOffset(const Call& call, const BlockBuilder&) { auto dlpack_type = DataType::UInt(64); - auto tensor_sinfo = GetTensorArgInfo(call); + auto tensor_ty = GetTensorArgInfo(call); - auto opt_tensor_shape = tensor_sinfo->GetShape(); + auto opt_tensor_shape = tensor_ty->GetShape(); if (opt_tensor_shape.defined()) { // Relax implicitly requires that the element offset is zero for - // any legalizable tensor. See InferStructInfoTensorStride for + // any legalizable tensor. See InferTypeTensorStride for // full explanation. - return PrimStructInfo(IntImm(dlpack_type, 0)); + return PrimType(IntImm(dlpack_type, 0)); } else { - return PrimStructInfo(dlpack_type); + return PrimType(dlpack_type); } } TVM_REGISTER_OP("relax.inspect.tensor_elem_offset") .set_num_inputs(1) .add_argument("tensor", "Tensor", "The tensor to be inspected") - .set_attr("FInferStructInfo", InferStructInfoTensorElemOffset) + .set_attr("FInferType", InferTypeTensorElemOffset) .set_attr("RequiresArgumentShapes", false) .set_attr("FNormalize", NormalizeToKnownPrimValue) .set_attr("FPurity", true); diff --git a/src/relax/op/tensor/inspect.h b/src/relax/op/tensor/inspect.h index 2aa20a13813f..3f820ab58a83 100644 --- a/src/relax/op/tensor/inspect.h +++ b/src/relax/op/tensor/inspect.h @@ -33,20 +33,20 @@ namespace inspect { /* \brief Return the DLTensor::dtype::type_code field * * \param expr The relax expression to be inspected. Must have - * `TensorStructInfo`. + * `TensorType`. * * \returns The uint8_t value of the type_code, with - * `PrimStructInfo(DataType::UInt(8))` + * `PrimType(DataType::UInt(8))` */ Expr tensor_dtype_code(Expr expr); /* \brief Return the DLTensor::dtype::bits field * * \param expr The relax expression to be inspected. Must have - * `TensorStructInfo`. + * `TensorType`. * * \returns The uint8_t value of the number of bits, with - * `PrimStructInfo(DataType::UInt(8))`. For vectorized types, returns + * `PrimType(DataType::UInt(8))`. For vectorized types, returns * the bit width of the underlying scalar type (e.g. 32 for * "float32x4", not 128). */ @@ -55,33 +55,33 @@ Expr tensor_dtype_bits(Expr expr); /* \brief Return the DLTensor::dtype::lanes field * * \param expr The relax expression to be inspected. Must have - * `TensorStructInfo`. + * `TensorType`. * * \returns The uint16_t value of the number of lanes, with - * `PrimStructInfo(DataType::UInt(16))` + * `PrimType(DataType::UInt(16))` */ Expr tensor_dtype_lanes(Expr expr); /* \brief Return the DLTensor::ndim field * * \param expr The relax expression to be inspected. Must have - * `TensorStructInfo`. + * `TensorType`. * * \returns The int32_t value of the dimensionality, with - * `PrimStructInfo(DataType::Int(32))`. + * `PrimType(DataType::Int(32))`. */ Expr tensor_ndim(Expr expr); /* \brief Return the DLTensor::shape[i] field * * \param expr The relax expression to be inspected. Must have - * `TensorStructInfo`. + * `TensorType`. * * \param axis The axis to inspect. Must be within the range `0 <= * axis < tensor_ndim(expr)`, or else the results are undefined. * * \returns The int64_t extent of the specified tensor axis, with - * `PrimStructInfo(DataType::Int(64))`. + * `PrimType(DataType::Int(64))`. */ Expr tensor_shape_i(Expr expr, Expr axis); @@ -92,22 +92,22 @@ Expr tensor_shape_i(Expr expr, Expr axis); * returned stride is computed from the `DLTensor::shape`. * * \param expr The relax expression to be inspected. Must have - * `TensorStructInfo`. + * `TensorType`. * * \param axis The axis to inspect. Must be within the range `0 <= * axis < tensor_ndim(expr)`, or else the results are undefined. * * \returns The int64_t extent of the specified tensor axis, with - * `PrimStructInfo(DataType::Int(64))`. + * `PrimType(DataType::Int(64))`. */ Expr tensor_stride_i(Expr expr, Expr axis); /* \brief Return the DLTensor::byte_offset field * * \param expr The relax expression to be inspected. Must have - * `TensorStructInfo`. + * `TensorType`. * - * \returns The uint64_t byte offset, with `PrimStructInfo(DataType::UInt(64))`. + * \returns The uint64_t byte offset, with `PrimType(DataType::UInt(64))`. */ Expr tensor_byte_offset(Expr expr); @@ -118,9 +118,9 @@ Expr tensor_byte_offset(Expr expr); * `DLTensor::data_type` fields. * * \param expr The relax expression to be inspected. Must have - * `TensorStructInfo`. + * `TensorType`. * - * \returns The uint64_t element offset, with `PrimStructInfo(DataType::UInt(64))`. + * \returns The uint64_t element offset, with `PrimType(DataType::UInt(64))`. */ Expr tensor_elem_offset(Expr expr); diff --git a/src/relax/op/tensor/linear_algebra.cc b/src/relax/op/tensor/linear_algebra.cc index ba6d0aacd817..c6fc5d3778ec 100644 --- a/src/relax/op/tensor/linear_algebra.cc +++ b/src/relax/op/tensor/linear_algebra.cc @@ -55,48 +55,48 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.matmul", matmul); } -StructInfo InferStructInfoMatmul(const Call& call, const BlockBuilder& ctx) { - ffi::Array input_sinfo = GetInputTensorStructInfo(call, ctx); +Type InferTypeMatmul(const Call& call, const BlockBuilder& ctx) { + ffi::Array input_ty = GetInputTensorType(call, ctx); Expr lhs = call->args[0]; Expr rhs = call->args[1]; - TensorStructInfo x1_sinfo = input_sinfo[0]; - TensorStructInfo x2_sinfo = input_sinfo[1]; + TensorType x1_ty = input_ty[0]; + TensorType x2_ty = input_ty[1]; VDevice vdev = VDevice(); - if (x1_sinfo->vdevice.defined() && x2_sinfo->vdevice.defined()) { - if (x1_sinfo->vdevice.value() == x2_sinfo->vdevice.value()) { - vdev = x1_sinfo->vdevice.value(); + if (x1_ty->vdevice.defined() && x2_ty->vdevice.defined()) { + if (x1_ty->vdevice.value() == x2_ty->vdevice.value()) { + vdev = x1_ty->vdevice.value(); } - } else if (x1_sinfo->vdevice.defined()) { - vdev = x1_sinfo->vdevice.value(); - } else if (x2_sinfo->vdevice.defined()) { - vdev = x2_sinfo->vdevice.value(); + } else if (x1_ty->vdevice.defined()) { + vdev = x1_ty->vdevice.value(); + } else if (x2_ty->vdevice.defined()) { + vdev = x2_ty->vdevice.value(); } const auto* attrs = call->attrs.as(); DataType out_dtype = attrs->out_dtype.is_void() - ? InferBinaryArithOpOutDtype(call, ctx, x1_sinfo, x2_sinfo) + ? InferBinaryArithOpOutDtype(call, ctx, x1_ty, x2_ty) : attrs->out_dtype; - if (x1_sinfo->IsUnknownNdim() || x2_sinfo->IsUnknownNdim()) { + if (x1_ty->IsUnknownNdim() || x2_ty->IsUnknownNdim()) { if (vdev.defined()) { - return TensorStructInfo(out_dtype, kUnknownNDim, vdev); + return TensorType(out_dtype, kUnknownNDim, vdev); } - return TensorStructInfo(out_dtype, kUnknownNDim); + return TensorType(out_dtype, kUnknownNDim); } - int x1_ndim = x1_sinfo->ndim; - int x2_ndim = x2_sinfo->ndim; + int x1_ndim = x1_ty->ndim; + int x2_ndim = x2_ty->ndim; if (x1_ndim == 0) { TVM_FFI_VISIT_THROW(ValueError, call) << "Matmul operands must not be scalar. " - << "However, the expression " << call << " has a LHS of " << lhs << " with struct info " - << x1_sinfo << ", which is scalar (zero-dimensional) tensor."; + << "However, the expression " << call << " has a LHS of " << lhs << " with type " << x1_ty + << ", which is scalar (zero-dimensional) tensor."; } if (x2_ndim == 0) { TVM_FFI_VISIT_THROW(ValueError, call) << "Matmul operands must not be scalar. " - << "However, the expression " << call << " has a RHS of " << rhs << " with struct info " - << x2_sinfo << ", which is scalar (zero-dimensional) tensor."; + << "However, the expression " << call << " has a RHS of " << rhs << " with type " << x2_ty + << ", which is scalar (zero-dimensional) tensor."; } int x1_prepended = 0; @@ -111,13 +111,13 @@ StructInfo InferStructInfoMatmul(const Call& call, const BlockBuilder& ctx) { } int output_ndim = std::max(x1_ndim, x2_ndim) - x1_prepended - x2_appended; - const auto* x1_shape = x1_sinfo->shape.as(); - const auto* x2_shape = x2_sinfo->shape.as(); + const auto* x1_shape = x1_ty->shape.as(); + const auto* x2_shape = x2_ty->shape.as(); if (x1_shape == nullptr || x2_shape == nullptr) { if (vdev.defined()) { - return TensorStructInfo(out_dtype, output_ndim, vdev); + return TensorType(out_dtype, output_ndim, vdev); } - return TensorStructInfo(out_dtype, output_ndim); + return TensorType(out_dtype, output_ndim); } ffi::Array x1_shape_prefix{x1_shape->values.begin(), @@ -128,20 +128,20 @@ StructInfo InferStructInfoMatmul(const Call& call, const BlockBuilder& ctx) { InferBinaryBroadcastShape(call, ctx, x1_shape_prefix, x2_shape_prefix); if (!output_shape_prefix.defined()) { if (vdev.defined()) { - return TensorStructInfo(out_dtype, output_ndim, vdev); + return TensorType(out_dtype, output_ndim, vdev); } - return TensorStructInfo(out_dtype, output_ndim); + return TensorType(out_dtype, output_ndim); } arith::Analyzer analyzer = ctx->GetAnalyzer(); - PrimExpr x1_reduction_length = x1_shape->values[x1_sinfo->ndim - 1]; + PrimExpr x1_reduction_length = x1_shape->values[x1_ty->ndim - 1]; PrimExpr x2_reduction_length = x2_shape->values[x2_ndim - 2]; if (analyzer->CanProve(x1_reduction_length != x2_reduction_length)) { TVM_FFI_VISIT_THROW(ValueError, call) << "Matmul requires the reduction length of the operands to be equal. " - << "However, the LHS " << lhs << " has shape " << x1_sinfo->shape << ", while the RHS " - << rhs << " has shape " << x2_sinfo->shape << ". The reduction dimensions of " - << x1_reduction_length << " and " << x2_reduction_length << " are not equal."; + << "However, the LHS " << lhs << " has shape " << x1_ty->shape << ", while the RHS " << rhs + << " has shape " << x2_ty->shape << ". The reduction dimensions of " << x1_reduction_length + << " and " << x2_reduction_length << " are not equal."; } ffi::Array output_shape = output_shape_prefix.value(); @@ -153,9 +153,9 @@ StructInfo InferStructInfoMatmul(const Call& call, const BlockBuilder& ctx) { } TVM_FFI_ICHECK_EQ(static_cast(output_shape.size()), output_ndim); if (vdev.defined()) { - return TensorStructInfo(ShapeExpr(output_shape), out_dtype, vdev); + return TensorType(ShapeExpr(output_shape), out_dtype, vdev); } - return TensorStructInfo(ShapeExpr(output_shape), out_dtype); + return TensorType(ShapeExpr(output_shape), out_dtype); } Call InferMixedPrecisionMatmul(const Call& call, const DataType& out_dtype) { @@ -166,7 +166,7 @@ TVM_REGISTER_OP("relax.matmul") .set_num_inputs(2) .add_argument("x1", "Tensor", "The first input tensor.") .add_argument("x2", "Tensor", "The second input tensor.") - .set_attr("FInferStructInfo", InferStructInfoMatmul) + .set_attr("FInferType", InferTypeMatmul) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kAlways) .set_attr("FInferMixedPrecision", InferMixedPrecisionMatmul) .set_attr("FPurity", true); @@ -186,13 +186,12 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.einsum", einsum); } -StructInfo InferStructInfoEinsum(const Call& call, const BlockBuilder& ctx) { +Type InferTypeEinsum(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 1) { TVM_FFI_VISIT_THROW(ValueError, call) << "Einsum op should take 1 argument"; } - ffi::Array operands_tensor_sinfo = - GetTensorStructInfoFromTuple(call, ctx, call->args[0]); - if (operands_tensor_sinfo.empty()) { + ffi::Array operands_tensor_ty = GetTensorTypeFromTuple(call, ctx, call->args[0]); + if (operands_tensor_ty.empty()) { TVM_FFI_VISIT_THROW(ValueError, call) << "Einsum op expects at least one tensor in the input Tuple. However, the " "given input Tuple is empty."; @@ -202,14 +201,14 @@ StructInfo InferStructInfoEinsum(const Call& call, const BlockBuilder& ctx) { bool vdevice_unknown = false; VDevice vdev = VDevice(); - for (TensorStructInfo sinfo : operands_tensor_sinfo) { + for (TensorType ty : operands_tensor_ty) { if (!vdevice_unknown) { - if (sinfo->vdevice.defined()) { + if (ty->vdevice.defined()) { if (!vdev.defined()) { - vdev = sinfo->vdevice.value(); - } else if (sinfo->vdevice.value()->target.defined()) { + vdev = ty->vdevice.value(); + } else if (ty->vdevice.value()->target.defined()) { // mismatch - if (sinfo->vdevice.value() != vdev) { + if (ty->vdevice.value() != vdev) { vdevice_unknown = true; } } @@ -219,44 +218,44 @@ StructInfo InferStructInfoEinsum(const Call& call, const BlockBuilder& ctx) { ffi::String subscripts = attrs->subscripts; - DataType operand_dtype = operands_tensor_sinfo[0]->dtype; + DataType operand_dtype = operands_tensor_ty[0]->dtype; std::vector> input_shapes; - input_shapes.reserve(operands_tensor_sinfo.size()); + input_shapes.reserve(operands_tensor_ty.size()); - for (TensorStructInfo tensor_sinfo : operands_tensor_sinfo) { + for (TensorType tensor_ty : operands_tensor_ty) { // Check the input tuple consists of tensors with same dtype - if (tensor_sinfo->dtype != operand_dtype) { + if (tensor_ty->dtype != operand_dtype) { TVM_FFI_VISIT_THROW(TypeError, call) << "Einsum expects all input tensors to have the same dtype. However, the " "input contains tensors with dtype " - << operand_dtype << " and " << tensor_sinfo->dtype; + << operand_dtype << " and " << tensor_ty->dtype; } // Get input shapes - const auto* shape_expr = tensor_sinfo->shape.as(); + const auto* shape_expr = tensor_ty->shape.as(); if (shape_expr != nullptr) { input_shapes.push_back(shape_expr->values); } else { if (!vdevice_unknown) { - return TensorStructInfo(operand_dtype, tensor_sinfo->ndim, vdev); + return TensorType(operand_dtype, tensor_ty->ndim, vdev); } - return TensorStructInfo(operand_dtype, tensor_sinfo->ndim); + return TensorType(operand_dtype, tensor_ty->ndim); } } // Calculate output shape using InferEinsumShape in topi ffi::Array oshape = topi::InferEinsumShape(subscripts, input_shapes); if (!vdevice_unknown) { - return TensorStructInfo(ShapeExpr(oshape), operand_dtype, vdev); + return TensorType(ShapeExpr(oshape), operand_dtype, vdev); } - return TensorStructInfo(ShapeExpr(oshape), operand_dtype); + return TensorType(ShapeExpr(oshape), operand_dtype); } TVM_REGISTER_OP("relax.einsum") .set_attrs_type() .set_num_inputs(1) .add_argument("operands", "Tensor", "The input tensors.") - .set_attr("FInferStructInfo", InferStructInfoEinsum) + .set_attr("FInferType", InferTypeEinsum) .set_attr("FPurity", true); /* relax.outer */ @@ -271,31 +270,31 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.outer", outer); } -StructInfo InferStructInfoOuter(const Call& call, const BlockBuilder& ctx) { - auto input_sinfo = GetInputTensorStructInfo(call, ctx); - auto x1_sinfo = input_sinfo[0]; - auto x2_sinfo = input_sinfo[1]; +Type InferTypeOuter(const Call& call, const BlockBuilder& ctx) { + auto input_ty = GetInputTensorType(call, ctx); + auto x1_ty = input_ty[0]; + auto x2_ty = input_ty[1]; // Ensure both inputs are 1D tensors - if (x1_sinfo->ndim != 1 || x2_sinfo->ndim != 1) { + if (x1_ty->ndim != 1 || x2_ty->ndim != 1) { TVM_FFI_VISIT_THROW(ValueError, call) << "torch.outer requires both inputs to be 1D tensors."; } // Determine output shape - auto x1_shape = x1_sinfo->shape.as(); - auto x2_shape = x2_sinfo->shape.as(); + auto x1_shape = x1_ty->shape.as(); + auto x2_shape = x2_ty->shape.as(); if (!x1_shape || !x2_shape) { - return TensorStructInfo(x1_sinfo->dtype, 2); + return TensorType(x1_ty->dtype, 2); } ffi::Array output_shape = {x1_shape->values[0], x2_shape->values[0]}; - return TensorStructInfo(ShapeExpr(output_shape), x1_sinfo->dtype); + return TensorType(ShapeExpr(output_shape), x1_ty->dtype); } TVM_REGISTER_OP("relax.outer") .set_num_inputs(2) .add_argument("x1", "Tensor", "The first input tensor.") .add_argument("x2", "Tensor", "The second input tensor.") - .set_attr("FInferStructInfo", InferStructInfoOuter) + .set_attr("FInferType", InferTypeOuter) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kAlways) .set_attr("FPurity", true); diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index 85936fae3fb2..1303ddbf2f33 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -72,43 +72,43 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.broadcast_to", broadcast_to); } -StructInfo InferStructInfoBroadcastTo(const Call& call, const BlockBuilder& ctx) { +Type InferTypeBroadcastTo(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 2) { TVM_FFI_VISIT_THROW(ValueError, call) << "broadcast_to should take 2 arguments."; } - const auto* data_sinfo = GetStructInfoAs(call->args[0]); - const auto* tgt_shape_sinfo = GetStructInfoAs(call->args[1]); - if (data_sinfo == nullptr) { + const auto* data_ty = GetTypeAs(call->args[0]); + const auto* tgt_shape_ty = GetTypeAs(call->args[1]); + if (data_ty == nullptr) { TVM_FFI_VISIT_THROW(TypeError, call) << "broadcast_to requires the input data to be Tensor. However, the given one is " - << call->args[0]->struct_info_->GetTypeKey(); + << call->args[0]->ty->GetTypeKey(); } - if (tgt_shape_sinfo == nullptr) { + if (tgt_shape_ty == nullptr) { TVM_FFI_VISIT_THROW(TypeError, call) << "broadcast_to requires the input new shape to be Shape. However, the given one is " - << call->args[1]->struct_info_->GetTypeKey(); + << call->args[1]->ty->GetTypeKey(); } - if (!data_sinfo->IsUnknownNdim() && !tgt_shape_sinfo->IsUnknownNdim() && - tgt_shape_sinfo->ndim < data_sinfo->ndim) { + if (!data_ty->IsUnknownNdim() && !tgt_shape_ty->IsUnknownNdim() && + tgt_shape_ty->ndim < data_ty->ndim) { TVM_FFI_VISIT_THROW(ValueError, call) << "broadcast_to expects the input shape to have the number of ndim at least " "as the input tensor's. However, the given tensor has ndim " - << data_sinfo->ndim << " while the target shape has ndim " << tgt_shape_sinfo->ndim; + << data_ty->ndim << " while the target shape has ndim " << tgt_shape_ty->ndim; } // Trust the input target shape when there is no possibility to do any compile-time check. - if (!data_sinfo->shape.defined()) { - return TensorStructInfo(/*shape=*/call->args[1], data_sinfo->dtype, data_sinfo->vdevice); + if (!data_ty->shape.defined()) { + return TensorType(/*shape=*/call->args[1], data_ty->dtype, data_ty->vdevice); } - ShapeStructInfo shape_sinfo = Downcast(data_sinfo->shape.value()->struct_info_); - if (!shape_sinfo->values.defined() || !tgt_shape_sinfo->values.defined()) { - return TensorStructInfo(/*shape=*/call->args[1], data_sinfo->dtype, data_sinfo->vdevice); + ShapeType shape_ty = Downcast(data_ty->shape.value()->ty); + if (!shape_ty->values.defined() || !tgt_shape_ty->values.defined()) { + return TensorType(/*shape=*/call->args[1], data_ty->dtype, data_ty->vdevice); } arith::Analyzer analyzer = ctx->GetAnalyzer(); - ffi::Array old_shape_value = shape_sinfo->values.value(); - ffi::Array tgt_shape_value = tgt_shape_sinfo->values.value(); + ffi::Array old_shape_value = shape_ty->values.value(); + ffi::Array tgt_shape_value = tgt_shape_ty->values.value(); int old_ndim = old_shape_value.size(); int tgt_ndim = tgt_shape_value.size(); for (int i = 0; i < old_ndim; ++i) { @@ -127,14 +127,14 @@ StructInfo InferStructInfoBroadcastTo(const Call& call, const BlockBuilder& ctx) // Todo(relax-team): revisit here for better check on if the tensor length // is consistent with the length in the given shape. } - return TensorStructInfo(/*shape=*/call->args[1], data_sinfo->dtype, data_sinfo->vdevice); + return TensorType(/*shape=*/call->args[1], data_ty->dtype, data_ty->vdevice); } TVM_REGISTER_OP("relax.broadcast_to") .set_num_inputs(2) .add_argument("x", "Tensor", "The input tensor.") .add_argument("shape", "Shape", "The target shape.") - .set_attr("FInferStructInfo", InferStructInfoBroadcastTo) + .set_attr("FInferType", InferTypeBroadcastTo) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", true); @@ -205,13 +205,12 @@ ffi::Optional> CheckConcatOutputShape( return output_shape; } -StructInfo InferStructInfoConcat(const Call& call, const BlockBuilder& ctx) { +Type InferTypeConcat(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 1) { TVM_FFI_VISIT_THROW(ValueError, call) << "Concat op should have 1 argument"; } - ffi::Array tensor_sinfo = - GetTensorStructInfoFromTuple(call, ctx, call->args[0]); - if (tensor_sinfo.empty()) { + ffi::Array tensor_ty = GetTensorTypeFromTuple(call, ctx, call->args[0]); + if (tensor_ty.empty()) { TVM_FFI_VISIT_THROW(ValueError, call) << "Concat op expects at least one tensor in the input Tuple. However, the " "given input Tuple is empty."; @@ -225,41 +224,41 @@ StructInfo InferStructInfoConcat(const Call& call, const BlockBuilder& ctx) { bool is_void_dtype = false; bool vdevice_unknown = false; std::vector> shape_values; - shape_values.reserve(tensor_sinfo.size()); + shape_values.reserve(tensor_ty.size()); - for (TensorStructInfo sinfo : tensor_sinfo) { + for (TensorType ty : tensor_ty) { // Update the output dtype. - if (sinfo->dtype.is_void()) { + if (ty->dtype.is_void()) { is_void_dtype = true; } else if (output_dtype.is_void()) { - output_dtype = sinfo->dtype; - } else if (sinfo->dtype != output_dtype) { + output_dtype = ty->dtype; + } else if (ty->dtype != output_dtype) { TVM_FFI_VISIT_THROW(TypeError, call) << "Concat expects all input tensors to have the same dtype. However, the " "input contains tensors with dtype " - << output_dtype << " and " << sinfo->dtype; + << output_dtype << " and " << ty->dtype; } // Update the output ndim. // Todo(relax-team): revisit here for better check on if the input tensor has // ndim 1 when the input axis is undefined. if (output_ndim == kUnknownNDim) { - output_ndim = sinfo->ndim; - } else if (sinfo->ndim != kUnknownNDim && sinfo->ndim != output_ndim) { + output_ndim = ty->ndim; + } else if (ty->ndim != kUnknownNDim && ty->ndim != output_ndim) { TVM_FFI_VISIT_THROW(ValueError, call) << "Concat expects all input tensors to have same ndim. However, the " "input contains tensors with ndim " - << output_ndim << " and " << sinfo->ndim; + << output_ndim << " and " << ty->ndim; } // Update the virtual device. if (!vdevice_unknown) { - if (sinfo->vdevice.defined()) { + if (ty->vdevice.defined()) { if (!vdev.defined()) { - vdev = sinfo->vdevice.value(); - } else if (sinfo->vdevice.value()->target.defined()) { + vdev = ty->vdevice.value(); + } else if (ty->vdevice.value()->target.defined()) { // mismatch - if (sinfo->vdevice.value() != vdev) { + if (ty->vdevice.value() != vdev) { vdevice_unknown = true; } } @@ -267,20 +266,20 @@ StructInfo InferStructInfoConcat(const Call& call, const BlockBuilder& ctx) { } // Update the shape values for best effort check. - const auto* shape_expr = sinfo->shape.as(); + const auto* shape_expr = ty->shape.as(); if (shape_expr != nullptr) { shape_values.push_back(shape_expr->values); continue; } shape_unknown = true; - if (!sinfo->shape.defined()) { + if (!ty->shape.defined()) { continue; } // Keep the shape value for equality check. - ShapeStructInfo shape_sinfo = Downcast(sinfo->shape.value()->struct_info_); - if (shape_sinfo->values.defined()) { - shape_values.push_back(shape_sinfo->values.value()); + ShapeType shape_ty = Downcast(ty->shape.value()->ty); + if (shape_ty->values.defined()) { + shape_values.push_back(shape_ty->values.value()); } } @@ -292,21 +291,20 @@ StructInfo InferStructInfoConcat(const Call& call, const BlockBuilder& ctx) { } if (output_ndim == kUnknownNDim) { - return tensor_sinfo.size() == 1 ? tensor_sinfo[0] - : TensorStructInfo(output_dtype, output_ndim, vdev); + return tensor_ty.size() == 1 ? tensor_ty[0] : TensorType(output_dtype, output_ndim, vdev); } int axis = attrs->axis.has_value() ? NormalizeAxis(call, ctx, output_ndim, attrs->axis.value()) : 0; // If there is only one input tensor, no action is needed. - if (tensor_sinfo.size() == 1) { - return tensor_sinfo[0]; + if (tensor_ty.size() == 1) { + return tensor_ty[0]; } if (shape_values.empty()) { if (!vdevice_unknown) { - return TensorStructInfo(output_dtype, output_ndim, vdev); + return TensorType(output_dtype, output_ndim, vdev); } - return TensorStructInfo(output_dtype, output_ndim); + return TensorType(output_dtype, output_ndim); } // As long as the there is known shape value, we will do the best effort check to ensure safety. @@ -315,14 +313,14 @@ StructInfo InferStructInfoConcat(const Call& call, const BlockBuilder& ctx) { if (shape_unknown || !output_shape.defined()) { if (!vdevice_unknown) { - return TensorStructInfo(output_dtype, output_ndim, vdev); + return TensorType(output_dtype, output_ndim, vdev); } - return TensorStructInfo(output_dtype, output_ndim); + return TensorType(output_dtype, output_ndim); } else { if (!vdevice_unknown) { - return TensorStructInfo(ShapeExpr(output_shape.value()), output_dtype, vdev); + return TensorType(ShapeExpr(output_shape.value()), output_dtype, vdev); } - return TensorStructInfo(ShapeExpr(output_shape.value()), output_dtype); + return TensorType(ShapeExpr(output_shape.value()), output_dtype); } } @@ -349,20 +347,19 @@ InferLayoutOutput InferLayoutConcat( TVM_FFI_ICHECK(n_layout.IsLeaf()); LayoutDecision in_layout = n_layout.LeafValue(); if (in_layout->layout.ndim() != in_layout->layout.ndim_primal()) { - const auto* tuple_sinfo = GetStructInfoAs(call->args[0]); - TVM_FFI_ICHECK(tuple_sinfo != nullptr) + const auto* tuple_ty = GetTypeAs(call->args[0]); + TVM_FFI_ICHECK(tuple_ty != nullptr) << " expects the input to be a Tuple of Tensors. However, the given input is " - << call->args[0]->struct_info_->GetTypeKey(); - for (size_t i = 0; i < tuple_sinfo->fields.size(); ++i) { - StructInfo field_sinfo = tuple_sinfo->fields[i]; - const auto* field_tensor_sinfo = field_sinfo.as(); - TVM_FFI_ICHECK(field_tensor_sinfo != nullptr) + << call->args[0]->ty->GetTypeKey(); + for (size_t i = 0; i < tuple_ty->fields.size(); ++i) { + Type field_ty = tuple_ty->fields[i]; + const auto* field_tensor_ty = field_ty.as(); + TVM_FFI_ICHECK(field_tensor_ty != nullptr) << call->op << " expects the input to be a Tuple of Tensors. However, the given input is " - << call->args[0]->struct_info_; - auto t_sinfo = ffi::GetRef(field_tensor_sinfo); - ffi::Optional t_shape = - ffi::GetRef(t_sinfo->shape.as()); + << call->args[0]->ty; + auto t_ty = ffi::GetRef(field_tensor_ty); + ffi::Optional t_shape = ffi::GetRef(t_ty->shape.as()); LayoutDecision curr_layout = nlayout_array[i].LeafValue(); if (!CanProveLayoutTransform(curr_layout->layout, in_layout->layout, t_shape.value()->values)) { @@ -396,7 +393,7 @@ TVM_REGISTER_OP("relax.concat") .set_attrs_type() .set_num_inputs(1) .add_argument("tensors", "Tuple of Tensors", "The input list of tensors.") - .set_attr("FInferStructInfo", InferStructInfoConcat) + .set_attr("FInferType", InferTypeConcat) .set_attr("FRelaxInferLayout", InferLayoutConcat) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", true); @@ -416,24 +413,24 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.expand_dims", expand_dims); } -StructInfo InferStructInfoExpandDims(const Call& call, const BlockBuilder& ctx) { - TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); +Type InferTypeExpandDims(const Call& call, const BlockBuilder& ctx) { + TensorType data_ty = GetUnaryInputTensorType(call, ctx); const auto* attrs = call->attrs.as(); if (attrs->axis.empty()) { - return data_sinfo; + return data_ty; } - if (data_sinfo->IsUnknownNdim()) { - return TensorStructInfo(data_sinfo->dtype, kUnknownNDim, data_sinfo->vdevice); + if (data_ty->IsUnknownNdim()) { + return TensorType(data_ty->dtype, kUnknownNDim, data_ty->vdevice); } int n_new_dim = attrs->axis.size(); - int output_ndim = data_sinfo->ndim + n_new_dim; + int output_ndim = data_ty->ndim + n_new_dim; std::vector axes = NormalizeAxes(call, ctx, output_ndim, attrs->axis); - const auto* data_shape = data_sinfo->shape.as(); + const auto* data_shape = data_ty->shape.as(); if (data_shape == nullptr) { - return TensorStructInfo(data_sinfo->dtype, output_ndim, data_sinfo->vdevice); + return TensorType(data_ty->dtype, output_ndim, data_ty->vdevice); } std::vector output_shape; @@ -447,12 +444,12 @@ StructInfo InferStructInfoExpandDims(const Call& call, const BlockBuilder& ctx) if (output_shape[i].defined()) { continue; } - TVM_FFI_ICHECK_LT(i_data_shape, data_sinfo->ndim); + TVM_FFI_ICHECK_LT(i_data_shape, data_ty->ndim); output_shape[i] = data_shape->values[i_data_shape]; ++i_data_shape; } - TVM_FFI_ICHECK_EQ(i_data_shape, data_sinfo->ndim); - return TensorStructInfo(ShapeExpr(output_shape), data_sinfo->dtype, data_sinfo->vdevice); + TVM_FFI_ICHECK_EQ(i_data_shape, data_ty->ndim); + return TensorType(ShapeExpr(output_shape), data_ty->dtype, data_ty->vdevice); } InferLayoutOutput InferLayoutExpandDims( @@ -461,12 +458,12 @@ InferLayoutOutput InferLayoutExpandDims( TVM_FFI_ICHECK(NoDesiredLayout(call, desired_layouts)); const auto* attrs = call->attrs.as(); TVM_FFI_ICHECK(attrs != nullptr) << "Invalid Call"; - const auto* tensor_sinfo = GetStructInfoAs(call->args[0]); - TVM_FFI_ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; - TVM_FFI_ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support static ndim for now"; + const auto* tensor_ty = GetTypeAs(call->args[0]); + TVM_FFI_ICHECK(tensor_ty != nullptr) << "Invalid Call"; + TVM_FFI_ICHECK(!tensor_ty->IsUnknownNdim()) << "Only support static ndim for now"; LayoutDecision existing_layout = GetLayoutDecision(var_layout_map, call->args[0]); - int ndim = tensor_sinfo->ndim; + int ndim = tensor_ty->ndim; // Can't handle sub indexed layouts. if (existing_layout->layout.ndim() != existing_layout->layout.ndim_primal()) { existing_layout = LayoutDecision(InitialLayout(ndim)); @@ -500,7 +497,7 @@ TVM_REGISTER_OP("relax.expand_dims") .set_num_inputs(1) .set_attrs_type() .add_argument("x", "Tensor", "The input tensor.") - .set_attr("FInferStructInfo", InferStructInfoExpandDims) + .set_attr("FInferType", InferTypeExpandDims) .set_attr("FRelaxInferLayout", InferLayoutExpandDims) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", true); @@ -525,29 +522,28 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.flatten", flatten); } -StructInfo InferStructInfoFlatten(const Call& call, const BlockBuilder& ctx) { - TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); - if (data_sinfo->IsUnknownNdim()) { - return TensorStructInfo(data_sinfo->dtype, /*ndim=*/1, data_sinfo->vdevice); - } else if (data_sinfo->ndim == 0) { - return TensorStructInfo(ShapeExpr({1}), data_sinfo->dtype, data_sinfo->vdevice); - } else if (data_sinfo->ndim == 1) { - return data_sinfo; +Type InferTypeFlatten(const Call& call, const BlockBuilder& ctx) { + TensorType data_ty = GetUnaryInputTensorType(call, ctx); + if (data_ty->IsUnknownNdim()) { + return TensorType(data_ty->dtype, /*ndim=*/1, data_ty->vdevice); + } else if (data_ty->ndim == 0) { + return TensorType(ShapeExpr({1}), data_ty->dtype, data_ty->vdevice); + } else if (data_ty->ndim == 1) { + return data_ty; } - const auto* data_shape = data_sinfo->shape.as(); + const auto* data_shape = data_ty->shape.as(); if (data_shape == nullptr) { - return TensorStructInfo(data_sinfo->dtype, /*ndim=*/1, data_sinfo->vdevice); + return TensorType(data_ty->dtype, /*ndim=*/1, data_ty->vdevice); } PrimExpr shape_prod = ComputeShapeProduct(data_shape->values); - return TensorStructInfo(ShapeExpr({std::move(shape_prod)}), data_sinfo->dtype, - data_sinfo->vdevice); + return TensorType(ShapeExpr({std::move(shape_prod)}), data_ty->dtype, data_ty->vdevice); } TVM_REGISTER_OP("relax.flatten") .set_num_inputs(1) .add_argument("x", "Tensor", "The input tensor.") - .set_attr("FInferStructInfo", InferStructInfoFlatten) + .set_attr("FInferType", InferTypeFlatten) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", true); @@ -563,27 +559,26 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.index_tensor", index_tensor); } -StructInfo InferStructInfoIndexTensor(const Call& call, const BlockBuilder& ctx) { +Type InferTypeIndexTensor(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 2) { TVM_FFI_VISIT_THROW(ValueError, call) << "Index.Tensor op should have 2 arguments"; } - TensorStructInfo data_sinfo = GetInputTensorStructInfo(call, 0, ctx); - ffi::Array indices_sinfo = - GetTensorStructInfoFromTuple(call, ctx, call->args[1]); + TensorType data_ty = GetInputTensorType(call, 0, ctx); + ffi::Array indices_ty = GetTensorTypeFromTuple(call, ctx, call->args[1]); - if (indices_sinfo.empty()) { + if (indices_ty.empty()) { TVM_FFI_VISIT_THROW(ValueError, call) << "index_tensor expects a non‑empty tuple of index tensors"; } - DataType output_dtype = data_sinfo->dtype; - int n_indices = static_cast(indices_sinfo.size()); - ffi::Optional vdev = data_sinfo->vdevice; + DataType output_dtype = data_ty->dtype; + int n_indices = static_cast(indices_ty.size()); + ffi::Optional vdev = data_ty->vdevice; // Indices must be integers for (int i = 0; i < n_indices; ++i) { - const auto& s = indices_sinfo[i]; + const auto& s = indices_ty[i]; if (!s->IsUnknownDtype() && !s->dtype.is_int()) { TVM_FFI_VISIT_THROW(TypeError, call) << "index_tensor requires every index tensor to have an integer dtype; " @@ -592,10 +587,10 @@ StructInfo InferStructInfoIndexTensor(const Call& call, const BlockBuilder& ctx) } // Count of indices must be less than or equal to data.ndim - if (!data_sinfo->IsUnknownNdim() && n_indices > data_sinfo->ndim) { + if (!data_ty->IsUnknownNdim() && n_indices > data_ty->ndim) { TVM_FFI_VISIT_THROW(ValueError, call) << "index_tensor received " << n_indices << " index tensors, but data has only " - << data_sinfo->ndim << " dimensions"; + << data_ty->ndim << " dimensions"; } arith::Analyzer analyzer = ctx->GetAnalyzer(); @@ -603,7 +598,7 @@ StructInfo InferStructInfoIndexTensor(const Call& call, const BlockBuilder& ctx) std::vector> index_shapes; int max_index_ndim = 0; - for (const auto& s : indices_sinfo) { + for (const auto& s : indices_ty) { const auto* shp = s->shape.as(); if (!shp) { all_index_have_shape_value = false; @@ -668,8 +663,8 @@ StructInfo InferStructInfoIndexTensor(const Call& call, const BlockBuilder& ctx) // Count of dimensions in output int out_ndim = kUnknownNDim; - if (!data_sinfo->IsUnknownNdim()) { - int tail_ndim = data_sinfo->ndim - n_indices; + if (!data_ty->IsUnknownNdim()) { + int tail_ndim = data_ty->ndim - n_indices; if (broadcast_shape.defined()) { out_ndim = static_cast(broadcast_shape.value().size()) + tail_ndim; } else if (!shape_unknown) { @@ -679,25 +674,25 @@ StructInfo InferStructInfoIndexTensor(const Call& call, const BlockBuilder& ctx) // Derive output shape if (broadcast_shape.defined()) { - const auto* data_shape_expr = data_sinfo->shape.as(); + const auto* data_shape_expr = data_ty->shape.as(); if (data_shape_expr) { ffi::Array result_shape = broadcast_shape.value(); - for (int i = n_indices; i < data_sinfo->ndim; ++i) { + for (int i = n_indices; i < data_ty->ndim; ++i) { result_shape.push_back(data_shape_expr->values[i]); } - return TensorStructInfo(ShapeExpr(result_shape), output_dtype, vdev); + return TensorType(ShapeExpr(result_shape), output_dtype, vdev); } } // Unknown output shape - return TensorStructInfo(output_dtype, out_ndim, vdev); + return TensorType(output_dtype, out_ndim, vdev); } TVM_REGISTER_OP("relax.index_tensor") .set_num_inputs(2) .add_argument("data", "Tensor", "The input data.") .add_argument("indices", "List of Tensors", "The indices used to index.") - .set_attr("FInferStructInfo", InferStructInfoIndexTensor) + .set_attr("FInferType", InferTypeIndexTensor) .set_attr("FPurity", true); /* relax.layout_transform */ @@ -720,8 +715,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.layout_transform", layout_transform); } -StructInfo InferStructInfoLayoutTransform(const Call& call, const BlockBuilder& ctx) { - TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); +Type InferTypeLayoutTransform(const Call& call, const BlockBuilder& ctx) { + TensorType data_ty = GetUnaryInputTensorType(call, ctx); const auto* attrs = call->attrs.as(); tirx::IndexMap index_map = attrs->index_map; ffi::Optional optional_pad_value = attrs->pad_value; @@ -729,48 +724,45 @@ StructInfo InferStructInfoLayoutTransform(const Call& call, const BlockBuilder& // Check pad_value has same dtype as input. if (optional_pad_value.defined()) { PrimExpr padded_value = optional_pad_value.value()->value; - if (padded_value->dtype != data_sinfo->dtype) { + if (padded_value->dtype != data_ty->dtype) { TVM_FFI_VISIT_THROW(TypeError, call) << "layout_transform pad_value dtype (" << padded_value->dtype << ") and input dtype (" - << data_sinfo->dtype << ") must be the same"; + << data_ty->dtype << ") must be the same"; } } - if (data_sinfo->IsUnknownNdim()) { + if (data_ty->IsUnknownNdim()) { // Todo(relax-team): revisit here for better check on if the input tensor has desired ndim. - return TensorStructInfo(data_sinfo->dtype, /*ndim=*/index_map->final_indices.size(), - data_sinfo->vdevice); + return TensorType(data_ty->dtype, /*ndim=*/index_map->final_indices.size(), data_ty->vdevice); } // If rank is known, check that it is compatible with the index_map, i.e., #dims match. - if (index_map->initial_indices.size() != static_cast(data_sinfo->ndim)) { + if (index_map->initial_indices.size() != static_cast(data_ty->ndim)) { TVM_FFI_VISIT_THROW(ValueError, call) << "number of dimensions in input must match the number of source dimensions " "in index map, but got " - << data_sinfo->ndim << " != " << index_map->initial_indices.size(); + << data_ty->ndim << " != " << index_map->initial_indices.size(); } - if (!data_sinfo->shape.defined()) { - return TensorStructInfo(data_sinfo->dtype, /*ndim=*/index_map->final_indices.size(), - data_sinfo->vdevice); + if (!data_ty->shape.defined()) { + return TensorType(data_ty->dtype, /*ndim=*/index_map->final_indices.size(), data_ty->vdevice); } - ShapeStructInfo shape_sinfo = Downcast(data_sinfo->shape.value()->struct_info_); - if (!shape_sinfo->values.defined()) { - return TensorStructInfo(data_sinfo->dtype, /*ndim=*/index_map->final_indices.size(), - data_sinfo->vdevice); + ShapeType shape_ty = Downcast(data_ty->shape.value()->ty); + if (!shape_ty->values.defined()) { + return TensorType(data_ty->dtype, /*ndim=*/index_map->final_indices.size(), data_ty->vdevice); } arith::Analyzer analyzer; - ffi::Array output_shape = index_map->MapShape(shape_sinfo->values.value(), analyzer); - return TensorStructInfo(ShapeExpr(output_shape), data_sinfo->dtype, data_sinfo->vdevice); + ffi::Array output_shape = index_map->MapShape(shape_ty->values.value(), analyzer); + return TensorType(ShapeExpr(output_shape), data_ty->dtype, data_ty->vdevice); } TVM_REGISTER_OP("relax.layout_transform") .set_num_inputs(1) .set_attrs_type() .add_argument("x", "Tensor", "The input tensor.") - .set_attr("FInferStructInfo", InferStructInfoLayoutTransform) + .set_attr("FInferType", InferTypeLayoutTransform) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", true); @@ -798,49 +790,49 @@ bool IsIdentityPermutation(const std::vector& permutation) { return true; } -StructInfo InferStructInfoPermuteDims(const Call& call, const BlockBuilder& ctx) { - TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); +Type InferTypePermuteDims(const Call& call, const BlockBuilder& ctx) { + TensorType data_ty = GetUnaryInputTensorType(call, ctx); const auto* attrs = call->attrs.as(); // Todo(relax-team): revisit here for better check on if the input tensor has // ndim same as the number of input axes. - if (!attrs->axes.defined() && data_sinfo->IsUnknownNdim()) { - return TensorStructInfo(data_sinfo->dtype, kUnknownNDim, data_sinfo->vdevice); + if (!attrs->axes.defined() && data_ty->IsUnknownNdim()) { + return TensorType(data_ty->dtype, kUnknownNDim, data_ty->vdevice); } if (attrs->axes.defined()) { int n_axis = attrs->axes.value().size(); - if (!data_sinfo->IsUnknownNdim() && n_axis != data_sinfo->ndim) { + if (!data_ty->IsUnknownNdim() && n_axis != data_ty->ndim) { TVM_FFI_VISIT_THROW(ValueError, call) << "PermuteDims expects the number of input axes to equal the ndim of the " "input tensor. However, the tensor ndim is " - << data_sinfo->ndim << " while the given number of axes is " << n_axis; + << data_ty->ndim << " while the given number of axes is " << n_axis; } } std::vector axes; if (attrs->axes.defined()) { - axes = NormalizeAxes(call, ctx, data_sinfo->ndim, attrs->axes.value()); + axes = NormalizeAxes(call, ctx, data_ty->ndim, attrs->axes.value()); } else { // Construct the reverse permutation via std::iota - axes.resize(data_sinfo->ndim); + axes.resize(data_ty->ndim); std::iota(axes.rbegin(), axes.rend(), 0); } if (IsIdentityPermutation(axes)) { - return data_sinfo; + return data_ty; } - const auto* data_shape = data_sinfo->shape.as(); + const auto* data_shape = data_ty->shape.as(); if (data_shape == nullptr) { - return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim, data_sinfo->vdevice); + return TensorType(data_ty->dtype, data_ty->ndim, data_ty->vdevice); } std::vector new_shape; - new_shape.reserve(data_sinfo->ndim); - for (int i = 0; i < data_sinfo->ndim; ++i) { + new_shape.reserve(data_ty->ndim); + for (int i = 0; i < data_ty->ndim; ++i) { new_shape.push_back(data_shape->values[axes[i]]); } - return TensorStructInfo(ShapeExpr(new_shape), data_sinfo->dtype, data_sinfo->vdevice); + return TensorType(ShapeExpr(new_shape), data_ty->dtype, data_ty->vdevice); } InferLayoutOutput InferLayoutPermuteDims( @@ -850,10 +842,10 @@ InferLayoutOutput InferLayoutPermuteDims( const auto* attrs = call->attrs.as(); TVM_FFI_ICHECK(attrs != nullptr) << "Invalid Call"; - const auto* tensor_sinfo = GetStructInfoAs(call->args[0]); - TVM_FFI_ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; - TVM_FFI_ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support static ndim for now"; - int ndim = tensor_sinfo->ndim; + const auto* tensor_ty = GetTypeAs(call->args[0]); + TVM_FFI_ICHECK(tensor_ty != nullptr) << "Invalid Call"; + TVM_FFI_ICHECK(!tensor_ty->IsUnknownNdim()) << "Only support static ndim for now"; + int ndim = tensor_ty->ndim; LayoutDecision existing_layout = GetLayoutDecision(var_layout_map, call->args[0]); @@ -890,7 +882,7 @@ TVM_REGISTER_OP("relax.permute_dims") .set_attrs_type() .set_num_inputs(1) .add_argument("x", "Tensor", "The input tensor.") - .set_attr("FInferStructInfo", InferStructInfoPermuteDims) + .set_attr("FInferType", InferTypePermuteDims) .set_attr("FRelaxInferLayout", InferLayoutPermuteDims) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", true); @@ -951,23 +943,23 @@ Expr ConvertNewShapeToExpr(const Expr& data, } // Otherwise, we require the input tensor to have known shape value for inference. - const auto* data_sinfo = GetStructInfoAs(data); - TVM_FFI_ICHECK(data_sinfo != nullptr) + const auto* data_ty = GetTypeAs(data); + TVM_FFI_ICHECK(data_ty != nullptr) << "Reshape expects the input data to be a Tensor. However, the given input is " - << data->struct_info_->GetTypeKey(); - TVM_FFI_ICHECK(data_sinfo->shape.defined()) + << data->ty->GetTypeKey(); + TVM_FFI_ICHECK(data_ty->shape.defined()) << "Reshape expects the input tensor to have known shape when there is some dimension length " "to infer. However, the given input has no shape."; - const auto* shape_sinfo = GetStructInfoAs(data_sinfo->shape.value()); - TVM_FFI_ICHECK(shape_sinfo != nullptr && shape_sinfo->values.defined()) + const auto* shape_ty = GetTypeAs(data_ty->shape.value()); + TVM_FFI_ICHECK(shape_ty != nullptr && shape_ty->values.defined()) << "Reshape expects the input tensor to have known shape when there is some dimension length " "to infer. However, the given input shape is " - << data_sinfo->shape << " whose shape value is unknown."; + << data_ty->shape << " whose shape value is unknown."; // Set any 0 valued dimensions to match the corresponding input shape. if (!zero_dims.empty()) { for (int i : zero_dims) { - array_ref.Set(i, shape_sinfo->values.value()[i]); + array_ref.Set(i, shape_ty->values.value()[i]); } } @@ -987,7 +979,7 @@ Expr ConvertNewShapeToExpr(const Expr& data, // Assign appropriate value to -1 dimension. if (dim_to_infer != -1) { arith::Analyzer analyzer; - PrimExpr old_shape_prod = ComputeShapeProduct(shape_sinfo->values.value()); + PrimExpr old_shape_prod = ComputeShapeProduct(shape_ty->values.value()); array_ref.Set(dim_to_infer, analyzer->Simplify(floordiv(old_shape_prod, new_shape_prod))); } return ShapeExpr(array_ref); @@ -1004,55 +996,54 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.reshape", reshape); } -StructInfo InferStructInfoReshape(const Call& call, const BlockBuilder& ctx) { +Type InferTypeReshape(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 2) { TVM_FFI_VISIT_THROW(ValueError, call) << "Reshape op should take 2 arguments"; } - const auto* data_sinfo = GetStructInfoAs(call->args[0]); - const auto* new_shape_sinfo = GetStructInfoAs(call->args[1]); - if (data_sinfo == nullptr) { + const auto* data_ty = GetTypeAs(call->args[0]); + const auto* new_shape_ty = GetTypeAs(call->args[1]); + if (data_ty == nullptr) { TVM_FFI_VISIT_THROW(TypeError, call) << "Reshape requires the input data to be Tensor. However, the given one is " - << call->args[0]->struct_info_->GetTypeKey(); + << call->args[0]->ty->GetTypeKey(); } - if (new_shape_sinfo == nullptr) { + if (new_shape_ty == nullptr) { TVM_FFI_VISIT_THROW(TypeError, call) << "Reshape requires the input new shape to be Shape. However, the given one is " - << call->args[1]->struct_info_->GetTypeKey(); + << call->args[1]->ty->GetTypeKey(); } ffi::Optional> old_shape_values; - if (data_sinfo->shape.defined()) { - const auto* old_shape_sinfo = GetStructInfoAs(data_sinfo->shape.value()); - TVM_FFI_ICHECK_NOTNULL(old_shape_sinfo); - old_shape_values = old_shape_sinfo->values; + if (data_ty->shape.defined()) { + const auto* old_shape_ty = GetTypeAs(data_ty->shape.value()); + TVM_FFI_ICHECK_NOTNULL(old_shape_ty); + old_shape_values = old_shape_ty->values; } - if (new_shape_sinfo->values.defined() && old_shape_values.defined()) { - PrimExpr new_shape_prod = ComputeShapeProduct(new_shape_sinfo->values.value()); + if (new_shape_ty->values.defined() && old_shape_values.defined()) { + PrimExpr new_shape_prod = ComputeShapeProduct(new_shape_ty->values.value()); PrimExpr old_shape_prod = ComputeShapeProduct(old_shape_values.value()); if (ctx->GetAnalyzer()->CanProve(old_shape_prod != new_shape_prod)) { TVM_FFI_VISIT_THROW(ValueError, call) << "Reshape expects the new shape to be convertible from the old shape. " "However, the old shape is " - << data_sinfo->shape << ", with product " << old_shape_prod << ", while the new shape is " + << data_ty->shape << ", with product " << old_shape_prod << ", while the new shape is " << call->args[1] << ", with product " << new_shape_prod; } } Expr target_shape = call->args[1]; // If shape values are defined, use them - if (target_shape->IsInstance() && new_shape_sinfo->values.defined()) { - return TensorStructInfo(ShapeExpr(new_shape_sinfo->values.value()), data_sinfo->dtype, - data_sinfo->vdevice); + if (target_shape->IsInstance() && new_shape_ty->values.defined()) { + return TensorType(ShapeExpr(new_shape_ty->values.value()), data_ty->dtype, data_ty->vdevice); } - return TensorStructInfo(target_shape, data_sinfo->dtype, data_sinfo->vdevice); + return TensorType(target_shape, data_ty->dtype, data_ty->vdevice); } TVM_REGISTER_OP("relax.reshape") .set_num_inputs(2) .add_argument("x", "Tensor", "The input tensor.") .add_argument("shape", "Shape", "The input new shape.") - .set_attr("FInferStructInfo", InferStructInfoReshape) + .set_attr("FInferType", InferTypeReshape) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", true); @@ -1094,31 +1085,29 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.split", split); } -StructInfo InferStructInfoSplit(const Call& call, const BlockBuilder& ctx) { - TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); +Type InferTypeSplit(const Call& call, const BlockBuilder& ctx) { + TensorType data_ty = GetUnaryInputTensorType(call, ctx); const auto* attrs = call->attrs.as(); - const auto* data_shape = data_sinfo->shape.as(); - int axis = - data_sinfo->IsUnknownNdim() ? -1 : NormalizeAxis(call, ctx, data_sinfo->ndim, attrs->axis); + const auto* data_shape = data_ty->shape.as(); + int axis = data_ty->IsUnknownNdim() ? -1 : NormalizeAxis(call, ctx, data_ty->ndim, attrs->axis); if (auto opt_indices = attrs->indices_or_sections.as>()) { auto p_indices = opt_indices.value(); - // When there is not index, return the input tensor's struct info. + // When there is not index, return the input tensor's type. if (p_indices.size() == 0) { - return data_sinfo; + return data_ty; } // Fall back to unknown shape when the input tensor doesn't have ShapeExpr as shape. if (data_shape == nullptr) { - return TupleStructInfo(ffi::Array( - p_indices.size() + 1, - TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim, data_sinfo->vdevice))); + return TupleType(ffi::Array( + p_indices.size() + 1, TensorType(data_ty->dtype, data_ty->ndim, data_ty->vdevice))); } TVM_FFI_ICHECK_NE(axis, -1); IntImm zero(DataType::Int(64), /*value=*/0); - std::vector output_sinfo; + std::vector output_ty; for (size_t i = 0; i < p_indices.size() + 1; i++) { PrimExpr left; if (i == 0) { @@ -1143,39 +1132,37 @@ StructInfo InferStructInfoSplit(const Call& call, const BlockBuilder& ctx) { ffi::Array shape = data_shape->values; shape.Set(axis, split_dim); - output_sinfo.push_back( - TensorStructInfo(ShapeExpr(shape), data_sinfo->dtype, data_sinfo->vdevice)); + output_ty.push_back(TensorType(ShapeExpr(shape), data_ty->dtype, data_ty->vdevice)); } - return TupleStructInfo(output_sinfo); + return TupleType(output_ty); } else if (const auto* p_n_section = attrs->indices_or_sections.as()) { TVM_FFI_ICHECK_GT(p_n_section->value, 0); int n_section = p_n_section->value; - // When the number of section is one, return the input tensor's struct info. + // When the number of section is one, return the input tensor's type. if (n_section == 1) { - return data_sinfo; + return data_ty; } // Fall back to unknown shape when the input tensor doesn't have ShapeExpr as shape. if (data_shape == nullptr) { - return TupleStructInfo(ffi::Array( - n_section, TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim, data_sinfo->vdevice))); + return TupleType( + ffi::Array(n_section, TensorType(data_ty->dtype, data_ty->ndim, data_ty->vdevice))); } TVM_FFI_ICHECK_NE(axis, -1); PrimExpr split_len = ceildiv(data_shape->values[axis], n_section); split_len = ctx->GetAnalyzer()->Simplify(split_len); - // Construct struct info for tensors except the last one. + // Construct type for tensors except the last one. ffi::Array shape = data_shape->values; shape.Set(axis, split_len); - std::vector output_sinfo( - n_section - 1, TensorStructInfo(ShapeExpr(shape), data_sinfo->dtype, data_sinfo->vdevice)); + std::vector output_ty(n_section - 1, + TensorType(ShapeExpr(shape), data_ty->dtype, data_ty->vdevice)); - // Construct struct info for the last tensor. + // Construct type for the last tensor. PrimExpr last_split_len = data_shape->values[axis] - split_len * (n_section - 1); last_split_len = ctx->GetAnalyzer()->Simplify(last_split_len); shape.Set(axis, last_split_len); - output_sinfo.push_back( - TensorStructInfo(ShapeExpr(shape), data_sinfo->dtype, data_sinfo->vdevice)); - return TupleStructInfo(output_sinfo); + output_ty.push_back(TensorType(ShapeExpr(shape), data_ty->dtype, data_ty->vdevice)); + return TupleType(output_ty); } TVM_FFI_ICHECK(false) << "Cannot reach here."; throw; @@ -1188,13 +1175,13 @@ InferLayoutOutput InferLayoutSplit( const auto* attrs = call->attrs.as(); TVM_FFI_ICHECK(attrs != nullptr) << "Invalid Call"; - const auto* tensor_sinfo = GetStructInfoAs(call->args[0]); - TVM_FFI_ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; - TVM_FFI_ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support known ndim"; + const auto* tensor_ty = GetTypeAs(call->args[0]); + TVM_FFI_ICHECK(tensor_ty != nullptr) << "Invalid Call"; + TVM_FFI_ICHECK(!tensor_ty->IsUnknownNdim()) << "Only support known ndim"; LayoutDecision existing_layout = GetLayoutDecision(var_layout_map, call->args[0]); - StructInfo out_sinfo = InferStructInfoSplit(call, BlockBuilder::Create(IRModule())); - const auto* out_tuple = out_sinfo.as(); + Type out_ty = InferTypeSplit(call, BlockBuilder::Create(IRModule())); + const auto* out_tuple = out_ty.as(); /* * Fallback if the outputs can't be represented in input sub indexed layout @@ -1202,18 +1189,16 @@ InferLayoutOutput InferLayoutSplit( */ if (existing_layout->layout.ndim() != existing_layout->layout.ndim_primal()) { for (const auto& si : out_tuple->fields) { - TVM_FFI_ICHECK(si->IsInstance()) - << "Fields of TupleStructInfo must be TensorStructInfo" - "output structinfo, but got " - << si; - auto sinfo = Downcast(si); - ffi::Optional shape_expr = - ffi::GetRef(sinfo->shape.as()); + TVM_FFI_ICHECK(si->IsInstance()) << "Fields of TupleType must be TensorType" + "output structinfo, but got " + << si; + auto ty = Downcast(si); + ffi::Optional shape_expr = ffi::GetRef(ty->shape.as()); TVM_FFI_ICHECK(shape_expr.defined()); auto shape_arr = shape_expr.value(); - if (!CanProveLayoutTransform(InitialLayout(tensor_sinfo->ndim), existing_layout->layout, + if (!CanProveLayoutTransform(InitialLayout(tensor_ty->ndim), existing_layout->layout, shape_arr->values)) { - existing_layout = InitialLayout(tensor_sinfo->ndim); + existing_layout = InitialLayout(tensor_ty->ndim); break; } } @@ -1230,7 +1215,7 @@ TVM_REGISTER_OP("relax.split") .set_attrs_type() .set_num_inputs(1) .add_argument("x", "Tensor", "The input tensor.") - .set_attr("FInferStructInfo", InferStructInfoSplit) + .set_attr("FInferType", InferTypeSplit) .set_attr("FRelaxInferLayout", InferLayoutSplit) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", true); @@ -1250,31 +1235,30 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.squeeze", squeeze); } -StructInfo InferStructInfoSqueeze(const Call& call, const BlockBuilder& ctx) { - TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); +Type InferTypeSqueeze(const Call& call, const BlockBuilder& ctx) { + TensorType data_ty = GetUnaryInputTensorType(call, ctx); const auto* attrs = call->attrs.as(); if (attrs->axis.defined() && attrs->axis.value().empty()) { - return data_sinfo; + return data_ty; } - if (data_sinfo->IsUnknownNdim()) { - return TensorStructInfo(data_sinfo->dtype, kUnknownNDim, data_sinfo->vdevice); + if (data_ty->IsUnknownNdim()) { + return TensorType(data_ty->dtype, kUnknownNDim, data_ty->vdevice); } ffi::Optional> shape_value; - if (data_sinfo->shape.defined()) { - shape_value = Downcast(data_sinfo->shape.value()->struct_info_)->values; + if (data_ty->shape.defined()) { + shape_value = Downcast(data_ty->shape.value()->ty)->values; } std::vector axis_removal_mask; - axis_removal_mask.resize(data_sinfo->ndim, /*value=*/false); + axis_removal_mask.resize(data_ty->ndim, /*value=*/false); if (attrs->axis.defined()) { - std::vector axes = NormalizeAxes(call, ctx, data_sinfo->ndim, attrs->axis.value()); + std::vector axes = NormalizeAxes(call, ctx, data_ty->ndim, attrs->axis.value()); if (!shape_value.defined()) { - return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim - axes.size(), - data_sinfo->vdevice); + return TensorType(data_ty->dtype, data_ty->ndim - axes.size(), data_ty->vdevice); } for (int i = 0; i < static_cast(axes.size()); ++i) { // Todo(relax-team): revisit here for better check on if the axis being squeezed has length 1. @@ -1291,13 +1275,13 @@ StructInfo InferStructInfoSqueeze(const Call& call, const BlockBuilder& ctx) { // (https://data-apis.org/array-api/latest/API_specification/generated/array_api.squeeze.html). // Consider discourage usage later. if (!shape_value.defined()) { - return TensorStructInfo(data_sinfo->dtype, kUnknownNDim, data_sinfo->vdevice); + return TensorType(data_ty->dtype, kUnknownNDim, data_ty->vdevice); } - for (int i = 0; i < data_sinfo->ndim; ++i) { + for (int i = 0; i < data_ty->ndim; ++i) { // Whenever a dimension length is symbolic, fall back to unknown ndim. const auto* int_len = shape_value.value()[i].as(); if (int_len == nullptr) { - return TensorStructInfo(data_sinfo->dtype, kUnknownNDim, data_sinfo->vdevice); + return TensorType(data_ty->dtype, kUnknownNDim, data_ty->vdevice); } if (int_len->value == 1) { axis_removal_mask[i] = true; @@ -1306,23 +1290,23 @@ StructInfo InferStructInfoSqueeze(const Call& call, const BlockBuilder& ctx) { } std::vector output_shape; - output_shape.reserve(data_sinfo->ndim - axis_removal_mask.size()); - for (int i = 0; i < data_sinfo->ndim; ++i) { + output_shape.reserve(data_ty->ndim - axis_removal_mask.size()); + for (int i = 0; i < data_ty->ndim; ++i) { if (!axis_removal_mask[i]) { output_shape.push_back(shape_value.value()[i]); } } - if (data_sinfo->shape.value()->IsInstance()) { - if (static_cast(output_shape.size()) == data_sinfo->ndim) { - return data_sinfo; + if (data_ty->shape.value()->IsInstance()) { + if (static_cast(output_shape.size()) == data_ty->ndim) { + return data_ty; } else if (attrs->axis.defined()) { - return TensorStructInfo(data_sinfo->dtype, output_shape.size(), data_sinfo->vdevice); + return TensorType(data_ty->dtype, output_shape.size(), data_ty->vdevice); } else { - return TensorStructInfo(data_sinfo->dtype, kUnknownNDim, data_sinfo->vdevice); + return TensorType(data_ty->dtype, kUnknownNDim, data_ty->vdevice); } } else { - return TensorStructInfo(ShapeExpr(output_shape), data_sinfo->dtype, data_sinfo->vdevice); + return TensorType(ShapeExpr(output_shape), data_ty->dtype, data_ty->vdevice); } } @@ -1333,12 +1317,12 @@ InferLayoutOutput InferLayoutSqueeze( const auto* attrs = call->attrs.as(); TVM_FFI_ICHECK(attrs != nullptr) << "Invalid Call"; - const auto* tensor_sinfo = GetStructInfoAs(call->args[0]); - TVM_FFI_ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; - TVM_FFI_ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support static ndim for now"; - TVM_FFI_ICHECK(tensor_sinfo->shape.defined()) << "Only support static shape for now"; - int ndim = tensor_sinfo->ndim; - const auto* shape = tensor_sinfo->shape.as(); + const auto* tensor_ty = GetTypeAs(call->args[0]); + TVM_FFI_ICHECK(tensor_ty != nullptr) << "Invalid Call"; + TVM_FFI_ICHECK(!tensor_ty->IsUnknownNdim()) << "Only support static ndim for now"; + TVM_FFI_ICHECK(tensor_ty->shape.defined()) << "Only support static shape for now"; + int ndim = tensor_ty->ndim; + const auto* shape = tensor_ty->shape.as(); TVM_FFI_ICHECK(shape != nullptr) << "Only support static shape for now"; ffi::Array axis; @@ -1390,7 +1374,7 @@ TVM_REGISTER_OP("relax.squeeze") .set_num_inputs(1) .set_attrs_type() .add_argument("x", "Tensor", "The input tensor.") - .set_attr("FInferStructInfo", InferStructInfoSqueeze) + .set_attr("FInferType", InferTypeSqueeze) .set_attr("FRelaxInferLayout", InferLayoutSqueeze) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", true); @@ -1485,14 +1469,13 @@ ffi::Optional> CheckStackOutputShape( return output_shape; } -StructInfo InferStructInfoStack(const Call& call, const BlockBuilder& ctx) { +Type InferTypeStack(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 1) { TVM_FFI_VISIT_THROW(ValueError, call) << "Stack op should have 1 argument"; } - ffi::Array tensor_sinfo = - GetTensorStructInfoFromTuple(call, ctx, call->args[0]); - if (tensor_sinfo.empty()) { + ffi::Array tensor_ty = GetTensorTypeFromTuple(call, ctx, call->args[0]); + if (tensor_ty.empty()) { TVM_FFI_VISIT_THROW(ValueError, call) << "Stack op expects at least one tensor in the input Tuple. " << "However, the given input Tuple is empty."; @@ -1502,57 +1485,57 @@ StructInfo InferStructInfoStack(const Call& call, const BlockBuilder& ctx) { TVM_FFI_ICHECK(attrs != nullptr) << "Stack must have StackAttrs"; // Default axis is 0 if not specified - int output_ndim = tensor_sinfo[0]->ndim + 1; // Stack adds one dimension + int output_ndim = tensor_ty[0]->ndim + 1; // Stack adds one dimension DataType output_dtype = DataType::Void(); ffi::Optional vdev = std::nullopt; bool shape_unknown = false; bool is_void_dtype = false; bool vdevice_unknown = false; std::vector> shape_values; - shape_values.reserve(tensor_sinfo.size()); + shape_values.reserve(tensor_ty.size()); - for (TensorStructInfo sinfo : tensor_sinfo) { + for (TensorType ty : tensor_ty) { // Check dtype consistency - if (sinfo->dtype.is_void()) { + if (ty->dtype.is_void()) { is_void_dtype = true; } else if (output_dtype.is_void()) { - output_dtype = sinfo->dtype; - } else if (sinfo->dtype != output_dtype) { + output_dtype = ty->dtype; + } else if (ty->dtype != output_dtype) { TVM_FFI_VISIT_THROW(TypeError, call) << "Stack expects all input tensors to have the same dtype. " - << "Found " << output_dtype << " and " << sinfo->dtype; + << "Found " << output_dtype << " and " << ty->dtype; } // Check ndim consistency - if (sinfo->ndim != kUnknownNDim && sinfo->ndim != tensor_sinfo[0]->ndim) { + if (ty->ndim != kUnknownNDim && ty->ndim != tensor_ty[0]->ndim) { TVM_FFI_VISIT_THROW(ValueError, call) << "Stack expects all input tensors to have same ndim. " - << "Found " << tensor_sinfo[0]->ndim << " and " << sinfo->ndim; + << "Found " << tensor_ty[0]->ndim << " and " << ty->ndim; } // Check virtual device consistency if (!vdevice_unknown) { - if (sinfo->vdevice.defined()) { + if (ty->vdevice.defined()) { if (!vdev.defined()) { - vdev = sinfo->vdevice.value(); - } else if (sinfo->vdevice.value() != vdev) { + vdev = ty->vdevice.value(); + } else if (ty->vdevice.value() != vdev) { vdevice_unknown = true; } } } // Collect shape information - const auto* shape_expr = sinfo->shape.as(); + const auto* shape_expr = ty->shape.as(); if (shape_expr != nullptr) { shape_values.push_back(shape_expr->values); continue; } shape_unknown = true; - if (!sinfo->shape.defined()) continue; - ShapeStructInfo shape_sinfo = Downcast(sinfo->shape.value()->struct_info_); - if (shape_sinfo->values.defined()) { - shape_values.push_back(shape_sinfo->values.value()); + if (!ty->shape.defined()) continue; + ShapeType shape_ty = Downcast(ty->shape.value()->ty); + if (shape_ty->values.defined()) { + shape_values.push_back(shape_ty->values.value()); } } @@ -1565,12 +1548,12 @@ StructInfo InferStructInfoStack(const Call& call, const BlockBuilder& ctx) { : 0; // Single tensor case - if (tensor_sinfo.size() == 1) { + if (tensor_ty.size() == 1) { if (shape_values.empty()) { if (!vdevice_unknown) { - return TensorStructInfo(output_dtype, output_ndim, vdev); + return TensorType(output_dtype, output_ndim, vdev); } - return TensorStructInfo(output_dtype, output_ndim); + return TensorType(output_dtype, output_ndim); } ffi::Array output_shape; for (int i = 0; i < axis; ++i) { @@ -1581,31 +1564,31 @@ StructInfo InferStructInfoStack(const Call& call, const BlockBuilder& ctx) { output_shape.push_back(shape_values[0][i]); } if (!vdevice_unknown) { - return TensorStructInfo(ShapeExpr(output_shape), output_dtype, vdev); + return TensorType(ShapeExpr(output_shape), output_dtype, vdev); } - return TensorStructInfo(ShapeExpr(output_shape), output_dtype); + return TensorType(ShapeExpr(output_shape), output_dtype); } // Multiple tensors case if (shape_values.empty()) { if (!vdevice_unknown) { - return TensorStructInfo(output_dtype, output_ndim, vdev); + return TensorType(output_dtype, output_ndim, vdev); } - return TensorStructInfo(output_dtype, output_ndim); + return TensorType(output_dtype, output_ndim); } ffi::Optional> output_shape = CheckStackOutputShape(call, ctx, shape_values, axis); if (shape_unknown || !output_shape.defined()) { if (!vdevice_unknown) { - return TensorStructInfo(output_dtype, output_ndim, vdev); + return TensorType(output_dtype, output_ndim, vdev); } - return TensorStructInfo(output_dtype, output_ndim); + return TensorType(output_dtype, output_ndim); } else { if (!vdevice_unknown) { - return TensorStructInfo(ShapeExpr(output_shape.value()), output_dtype, vdev); + return TensorType(ShapeExpr(output_shape.value()), output_dtype, vdev); } - return TensorStructInfo(ShapeExpr(output_shape.value()), output_dtype); + return TensorType(ShapeExpr(output_shape.value()), output_dtype); } } @@ -1643,7 +1626,7 @@ TVM_REGISTER_OP("relax.stack") .set_attrs_type() .set_num_inputs(1) .add_argument("tensors", "Tuple of Tensors", "The input list of tensors to stack") - .set_attr("FInferStructInfo", InferStructInfoStack) + .set_attr("FInferType", InferTypeStack) .set_attr("FRelaxInferLayout", InferLayoutStack) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", true); @@ -1659,33 +1642,31 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.collapse_sum_like", collapse_sum_like); } -StructInfo InferStructInfoCollapseSumLike(const Call& call, const BlockBuilder& ctx) { - ffi::Array input_sinfo = GetInputTensorStructInfo(call, ctx); - TensorStructInfo data_sinfo = input_sinfo[0]; - TensorStructInfo collapse_target_sinfo = input_sinfo[1]; +Type InferTypeCollapseSumLike(const Call& call, const BlockBuilder& ctx) { + ffi::Array input_ty = GetInputTensorType(call, ctx); + TensorType data_ty = input_ty[0]; + TensorType collapse_target_ty = input_ty[1]; - DataType output_dtype = data_sinfo->dtype; + DataType output_dtype = data_ty->dtype; ffi::Optional> data_shape_value; - if (data_sinfo->shape.defined()) { - data_shape_value = GetStructInfoAs(data_sinfo->shape.value())->values; + if (data_ty->shape.defined()) { + data_shape_value = GetTypeAs(data_ty->shape.value())->values; } ffi::Optional> collapse_target_shape_value; - if (collapse_target_sinfo->shape.defined()) { + if (collapse_target_ty->shape.defined()) { collapse_target_shape_value = - GetStructInfoAs(collapse_target_sinfo->shape.value())->values; + GetTypeAs(collapse_target_ty->shape.value())->values; } if (data_shape_value.defined() && collapse_target_shape_value.defined()) { CheckCollapseShape(call, ctx, data_shape_value.value(), collapse_target_shape_value.value()); } - if (collapse_target_sinfo->shape.defined()) { - return TensorStructInfo(collapse_target_sinfo->shape.value(), output_dtype, - collapse_target_sinfo->vdevice); + if (collapse_target_ty->shape.defined()) { + return TensorType(collapse_target_ty->shape.value(), output_dtype, collapse_target_ty->vdevice); } else { - return TensorStructInfo(output_dtype, collapse_target_sinfo->ndim, - collapse_target_sinfo->vdevice); + return TensorType(output_dtype, collapse_target_ty->ndim, collapse_target_ty->vdevice); } } @@ -1694,7 +1675,7 @@ TVM_REGISTER_OP("relax.collapse_sum_like") .add_argument("data", "Tensor", "The input tensor.") .add_argument("collapse_target", "Tensor", "The tensor whose shape is the shape to collapse to.") - .set_attr("FInferStructInfo", InferStructInfoCollapseSumLike) + .set_attr("FInferType", InferTypeCollapseSumLike) .set_attr("FPurity", true); /* relax.collapse_sum_to */ @@ -1708,43 +1689,43 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.collapse_sum_to", collapse_sum_to); } -StructInfo InferStructInfoCollapseSumTo(const Call& call, const BlockBuilder& ctx) { +Type InferTypeCollapseSumTo(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 2) { TVM_FFI_VISIT_THROW(ValueError, call) << "CollapseSumTo should have 2 arguments"; } - const auto* data_sinfo = GetStructInfoAs(call->args[0]); - const auto* shape_sinfo = GetStructInfoAs(call->args[1]); + const auto* data_ty = GetTypeAs(call->args[0]); + const auto* shape_ty = GetTypeAs(call->args[1]); - if (data_sinfo == nullptr) { + if (data_ty == nullptr) { TVM_FFI_VISIT_THROW(TypeError, call) << "CollapseSumTo requires the input data to be a Tensor. However, the given one is " - << call->args[0]->struct_info_->GetTypeKey(); + << call->args[0]->ty->GetTypeKey(); } - if (shape_sinfo == nullptr) { + if (shape_ty == nullptr) { TVM_FFI_VISIT_THROW(TypeError, call) << "CollapseSumTo requires the input shape to be a Shape. However, the given one is " - << call->args[1]->struct_info_->GetTypeKey(); + << call->args[1]->ty->GetTypeKey(); } - DataType output_dtype = data_sinfo->dtype; + DataType output_dtype = data_ty->dtype; ffi::Optional> data_shape_value; - if (data_sinfo->shape.defined()) { - data_shape_value = GetStructInfoAs(data_sinfo->shape.value())->values; + if (data_ty->shape.defined()) { + data_shape_value = GetTypeAs(data_ty->shape.value())->values; } - if (data_shape_value.defined() && shape_sinfo->values.defined()) { - CheckCollapseShape(call, ctx, data_shape_value.value(), shape_sinfo->values.value()); + if (data_shape_value.defined() && shape_ty->values.defined()) { + CheckCollapseShape(call, ctx, data_shape_value.value(), shape_ty->values.value()); } - return TensorStructInfo(/*shape=*/call->args[1], output_dtype, data_sinfo->vdevice); + return TensorType(/*shape=*/call->args[1], output_dtype, data_ty->vdevice); } TVM_REGISTER_OP("relax.collapse_sum_to") .set_num_inputs(2) .add_argument("data", "Tensor", "The input tensor.") .add_argument("shape", "Shape", "The shape to collapse to.") - .set_attr("FInferStructInfo", InferStructInfoCollapseSumTo) + .set_attr("FInferType", InferTypeCollapseSumTo) .set_attr("FPurity", true); /* relax.repeat */ @@ -1763,19 +1744,19 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.repeat", repeat); } -StructInfo InferStructInfoRepeat(const Call& call, const BlockBuilder& ctx) { +Type InferTypeRepeat(const Call& call, const BlockBuilder& ctx) { arith::Analyzer analyzer = ctx->GetAnalyzer(); - TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); + TensorType data_ty = GetUnaryInputTensorType(call, ctx); const auto* attrs = call->attrs.as(); - const auto* data_shape = data_sinfo->shape.as(); + const auto* data_shape = data_ty->shape.as(); - if (attrs->axis.has_value() && !data_sinfo->IsUnknownNdim()) { + if (attrs->axis.has_value() && !data_ty->IsUnknownNdim()) { int axis = attrs->axis.value(); - int ndim = data_sinfo->ndim; + int ndim = data_ty->ndim; if (axis < -ndim || axis >= ndim) { TVM_FFI_VISIT_THROW(ValueError, call) << "Repeat requires the input axis belongs range " - "[-data.struct_info.ndim, data.struct_info.ndim - 1]. However, the input axis is " + "[-data.ty.ndim, data.ty.ndim - 1]. However, the input axis is " << axis << ", while ndim is " << ndim; } } @@ -1784,26 +1765,26 @@ StructInfo InferStructInfoRepeat(const Call& call, const BlockBuilder& ctx) { if (attrs->axis.has_value()) { if (analyzer->CanProveEqual(attrs->repeats, 1)) { // the shape does not changes - return data_sinfo; + return data_ty; } else { - return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim, data_sinfo->vdevice); + return TensorType(data_ty->dtype, data_ty->ndim, data_ty->vdevice); } } else { - return TensorStructInfo(data_sinfo->dtype, 1, data_sinfo->vdevice); + return TensorType(data_ty->dtype, 1, data_ty->vdevice); } } if (!attrs->axis.has_value()) { PrimExpr new_shape = analyzer->Simplify(ComputeShapeProduct(data_shape->values) * attrs->repeats); - return TensorStructInfo(ShapeExpr(ffi::Array({new_shape})), data_sinfo->dtype, - data_sinfo->vdevice); + return TensorType(ShapeExpr(ffi::Array({new_shape})), data_ty->dtype, + data_ty->vdevice); } - int axis = NormalizeAxis(call, ctx, data_sinfo->ndim, attrs->axis.value()); + int axis = NormalizeAxis(call, ctx, data_ty->ndim, attrs->axis.value()); auto shape_array = data_shape->values; shape_array.Set(axis, analyzer->Simplify(shape_array[axis] * attrs->repeats)); - return TensorStructInfo(ShapeExpr(shape_array), data_sinfo->dtype, data_sinfo->vdevice); + return TensorType(ShapeExpr(shape_array), data_ty->dtype, data_ty->vdevice); } InferLayoutOutput InferLayoutRepeat( @@ -1813,12 +1794,12 @@ InferLayoutOutput InferLayoutRepeat( const auto* attrs = call->attrs.as(); TVM_FFI_ICHECK(attrs != nullptr) << "Invalid Call"; - const auto* tensor_sinfo = GetStructInfoAs(call->args[0]); - TVM_FFI_ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; - TVM_FFI_ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support static ndim for now"; + const auto* tensor_ty = GetTypeAs(call->args[0]); + TVM_FFI_ICHECK(tensor_ty != nullptr) << "Invalid Call"; + TVM_FFI_ICHECK(!tensor_ty->IsUnknownNdim()) << "Only support static ndim for now"; LayoutDecision existing_layout = GetLayoutDecision(var_layout_map, call->args[0]); - int ndim = tensor_sinfo->ndim; + int ndim = tensor_ty->ndim; // Can't handle sub indexed layouts. if (existing_layout->layout.ndim() != existing_layout->layout.ndim_primal()) { @@ -1868,7 +1849,7 @@ TVM_REGISTER_OP("relax.repeat") .set_attrs_type() .set_num_inputs(1) .add_argument("data", "Tensor", "The input tensor.") - .set_attr("FInferStructInfo", InferStructInfoRepeat) + .set_attr("FInferType", InferTypeRepeat) .set_attr("FRelaxInferLayout", InferLayoutRepeat) .set_attr("FPurity", true); @@ -1887,28 +1868,28 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.tile", tile); } -StructInfo InferStructInfoTile(const Call& call, const BlockBuilder& ctx) { +Type InferTypeTile(const Call& call, const BlockBuilder& ctx) { arith::Analyzer analyzer = ctx->GetAnalyzer(); - TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); + TensorType data_ty = GetUnaryInputTensorType(call, ctx); const auto* attrs = call->attrs.as(); - const auto* data_shape = data_sinfo->shape.as(); + const auto* data_shape = data_ty->shape.as(); int l = attrs->repeats.size(); - int ndim = data_sinfo->ndim; + int ndim = data_ty->ndim; if (data_shape == nullptr) { - if (data_sinfo->IsUnknownNdim()) { - return TensorStructInfo(data_sinfo->dtype, kUnknownNDim, data_sinfo->vdevice); + if (data_ty->IsUnknownNdim()) { + return TensorType(data_ty->dtype, kUnknownNDim, data_ty->vdevice); } if (l > ndim) { - return TensorStructInfo(data_sinfo->dtype, l, data_sinfo->vdevice); + return TensorType(data_ty->dtype, l, data_ty->vdevice); } else { for (int64_t i : attrs->repeats) { if (i != 1) { - return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim, data_sinfo->vdevice); + return TensorType(data_ty->dtype, data_ty->ndim, data_ty->vdevice); } } // if control reaches here, the shape should not be changed - return data_sinfo; + return data_ty; } } @@ -1927,7 +1908,7 @@ StructInfo InferStructInfoTile(const Call& call, const BlockBuilder& ctx) { } } - return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype, data_sinfo->vdevice); + return TensorType(ShapeExpr(out_shape), data_ty->dtype, data_ty->vdevice); } InferLayoutOutput InferLayoutTile( @@ -1937,12 +1918,12 @@ InferLayoutOutput InferLayoutTile( const auto* attrs = call->attrs.as(); TVM_FFI_ICHECK(attrs != nullptr) << "Invalid Call"; - const auto* tensor_sinfo = GetStructInfoAs(call->args[0]); - TVM_FFI_ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; - TVM_FFI_ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support static ndim for now"; + const auto* tensor_ty = GetTypeAs(call->args[0]); + TVM_FFI_ICHECK(tensor_ty != nullptr) << "Invalid Call"; + TVM_FFI_ICHECK(!tensor_ty->IsUnknownNdim()) << "Only support static ndim for now"; LayoutDecision existing_layout = GetLayoutDecision(var_layout_map, call->args[0]); - int ndim = tensor_sinfo->ndim; + int ndim = tensor_ty->ndim; int l = attrs->repeats.size(); int out_ndim = std::max(l, ndim); @@ -2012,7 +1993,7 @@ TVM_REGISTER_OP("relax.tile") .set_attrs_type() .set_num_inputs(1) .add_argument("data", "Tensor", "The input tensor.") - .set_attr("FInferStructInfo", InferStructInfoTile) + .set_attr("FInferType", InferTypeTile) .set_attr("FRelaxInferLayout", InferLayoutTile) .set_attr("FPurity", true); @@ -2030,22 +2011,22 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.flip", flip); } -StructInfo InferStructInfoFlip(const Call& call, const BlockBuilder& ctx) { +Type InferTypeFlip(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 1) { TVM_FFI_VISIT_THROW(ValueError, call) << "Flip op should take 1 argument"; } - TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); + TensorType data_ty = GetUnaryInputTensorType(call, ctx); const auto* attrs = call->attrs.as(); int axis = static_cast(attrs->axis); - if (!data_sinfo->IsUnknownNdim()) { - int ndim = data_sinfo->ndim; + if (!data_ty->IsUnknownNdim()) { + int ndim = data_ty->ndim; if (axis < -ndim || axis >= ndim) { TVM_FFI_VISIT_THROW(ValueError, call) << "Flip requires the input axis belongs range " "[-ndim, ndim - 1]. However, the input axis is " << axis << ", while ndim is " << ndim; } } - return data_sinfo; + return data_ty; } InferLayoutOutput InferLayoutFlip( @@ -2055,12 +2036,12 @@ InferLayoutOutput InferLayoutFlip( const auto* attrs = call->attrs.as(); TVM_FFI_ICHECK(attrs != nullptr) << "Invalid Call"; - const auto* tensor_sinfo = GetStructInfoAs(call->args[0]); - TVM_FFI_ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; - TVM_FFI_ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support static ndim for now"; + const auto* tensor_ty = GetTypeAs(call->args[0]); + TVM_FFI_ICHECK(tensor_ty != nullptr) << "Invalid Call"; + TVM_FFI_ICHECK(!tensor_ty->IsUnknownNdim()) << "Only support static ndim for now"; LayoutDecision existing_layout = GetLayoutDecision(var_layout_map, call->args[0]); - int ndim = tensor_sinfo->ndim; + int ndim = tensor_ty->ndim; if (existing_layout->layout.ndim() != existing_layout->layout.ndim_primal()) { existing_layout = LayoutDecision(InitialLayout(ndim)); @@ -2084,7 +2065,7 @@ TVM_REGISTER_OP("relax.flip") .set_attrs_type() .set_num_inputs(1) .add_argument("data", "Tensor", "The input tensor.") - .set_attr("FInferStructInfo", InferStructInfoFlip) + .set_attr("FInferType", InferTypeFlip) .set_attr("FRelaxInferLayout", InferLayoutFlip) .set_attr("FPurity", true); @@ -2102,49 +2083,49 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.gather_elements", gather_elements); } -StructInfo InferStructInfoGatherElements(const Call& call, const BlockBuilder& ctx) { - const auto* data_sinfo = GetStructInfoAs(call->args[0]); - const auto* indices_sinfo = GetStructInfoAs(call->args[1]); +Type InferTypeGatherElements(const Call& call, const BlockBuilder& ctx) { + const auto* data_ty = GetTypeAs(call->args[0]); + const auto* indices_ty = GetTypeAs(call->args[1]); const auto* attrs = call->attrs.as(); - if (data_sinfo == nullptr) { + if (data_ty == nullptr) { TVM_FFI_VISIT_THROW(TypeError, call) << "GatherElements requires the input data to be a Tensor. However, the given one is " - << call->args[0]->struct_info_->GetTypeKey(); + << call->args[0]->ty->GetTypeKey(); } - if (indices_sinfo == nullptr) { + if (indices_ty == nullptr) { TVM_FFI_VISIT_THROW(TypeError, call) << "GatherElements requires the input indices to be a Tensor. However, the given one is " - << call->args[1]->struct_info_->GetTypeKey(); + << call->args[1]->ty->GetTypeKey(); } - if (!indices_sinfo->IsUnknownDtype() && !indices_sinfo->dtype.is_int()) { + if (!indices_ty->IsUnknownDtype() && !indices_ty->dtype.is_int()) { TVM_FFI_VISIT_THROW(TypeError, call) << "GatherElements requires the input indices to have int64 dtype. However, the " - << "given indices dtype is " << indices_sinfo->dtype; + << "given indices dtype is " << indices_ty->dtype; } - if (data_sinfo->IsUnknownNdim() || indices_sinfo->IsUnknownNdim()) { - return TensorStructInfo(data_sinfo->dtype, kUnknownNDim, data_sinfo->vdevice); + if (data_ty->IsUnknownNdim() || indices_ty->IsUnknownNdim()) { + return TensorType(data_ty->dtype, kUnknownNDim, data_ty->vdevice); } int axis = static_cast(attrs->axis); - if (axis < -data_sinfo->ndim || axis >= data_sinfo->ndim) { + if (axis < -data_ty->ndim || axis >= data_ty->ndim) { TVM_FFI_VISIT_THROW(ValueError, call) - << "GatherElements requires axis to be within the input dimension range [" - << -data_sinfo->ndim << ", " << data_sinfo->ndim - 1 << "]. However, the " + << "GatherElements requires axis to be within the input dimension range [" << -data_ty->ndim + << ", " << data_ty->ndim - 1 << "]. However, the " << "given axis is " << axis; } - if (data_sinfo->ndim != indices_sinfo->ndim) { + if (data_ty->ndim != indices_ty->ndim) { TVM_FFI_VISIT_THROW(ValueError, call) << "GatherElements requires data and indices to have the same rank. However, " - << "data rank is " << data_sinfo->ndim << " while indices rank is " << indices_sinfo->ndim; + << "data rank is " << data_ty->ndim << " while indices rank is " << indices_ty->ndim; } - if (indices_sinfo->shape.defined()) { - return TensorStructInfo(indices_sinfo->shape.value(), data_sinfo->dtype, data_sinfo->vdevice); + if (indices_ty->shape.defined()) { + return TensorType(indices_ty->shape.value(), data_ty->dtype, data_ty->vdevice); } - return TensorStructInfo(data_sinfo->dtype, indices_sinfo->ndim, data_sinfo->vdevice); + return TensorType(data_ty->dtype, indices_ty->ndim, data_ty->vdevice); } InferLayoutOutput InferLayoutGatherElements( @@ -2168,10 +2149,10 @@ InferLayoutOutput InferLayoutGatherElements( } if (layout->layout.ndim() != layout->layout.ndim_primal()) { - const auto* tensor_sinfo = GetStructInfoAs(call->args[0]); - TVM_FFI_ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; - TVM_FFI_ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support static ndim for now"; - int ndim = tensor_sinfo->ndim; + const auto* tensor_ty = GetTypeAs(call->args[0]); + TVM_FFI_ICHECK(tensor_ty != nullptr) << "Invalid Call"; + TVM_FFI_ICHECK(!tensor_ty->IsUnknownNdim()) << "Only support static ndim for now"; + int ndim = tensor_ty->ndim; layout = LayoutDecision(InitialLayout(ndim)); } @@ -2185,7 +2166,7 @@ TVM_REGISTER_OP("relax.gather_elements") .set_num_inputs(2) .add_argument("data", "Tensor", "The input tensor.") .add_argument("indices", "Tensor", "The indices tensor.") - .set_attr("FInferStructInfo", InferStructInfoGatherElements) + .set_attr("FInferType", InferTypeGatherElements) .set_attr("FRelaxInferLayout", InferLayoutGatherElements) .set_attr("FPurity", true); @@ -2203,56 +2184,56 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.gather_nd", gather_nd); } -StructInfo InferStructInfoGatherND(const Call& call, const BlockBuilder& ctx) { - const auto* data_sinfo = GetStructInfoAs(call->args[0]); - const auto* indices_sinfo = GetStructInfoAs(call->args[1]); +Type InferTypeGatherND(const Call& call, const BlockBuilder& ctx) { + const auto* data_ty = GetTypeAs(call->args[0]); + const auto* indices_ty = GetTypeAs(call->args[1]); const auto* attrs = call->attrs.as(); - if (data_sinfo == nullptr) { + if (data_ty == nullptr) { TVM_FFI_VISIT_THROW(TypeError, call) << "GatherND requires the input data to be a Tensor. However, the given one is " - << call->args[0]->struct_info_->GetTypeKey(); + << call->args[0]->ty->GetTypeKey(); } - if (indices_sinfo == nullptr) { + if (indices_ty == nullptr) { TVM_FFI_VISIT_THROW(TypeError, call) << "GatherND requires the input indices to be a Tensor. However, the given one is " - << call->args[1]->struct_info_->GetTypeKey(); + << call->args[1]->ty->GetTypeKey(); } TVM_FFI_ICHECK_GE(attrs->batch_dims, 0); int batch_dims = static_cast(attrs->batch_dims); - int input_dims = data_sinfo->ndim; - if (!indices_sinfo->IsUnknownDtype() && indices_sinfo->dtype != DataType::Int(64)) { + int input_dims = data_ty->ndim; + if (!indices_ty->IsUnknownDtype() && indices_ty->dtype != DataType::Int(64)) { TVM_FFI_VISIT_THROW(TypeError, call) << "GatherND requires the input indices to have int64 dtype. However, the " - << "given indices dtype is " << indices_sinfo->dtype; + << "given indices dtype is " << indices_ty->dtype; } - if (data_sinfo->IsUnknownNdim() || indices_sinfo->IsUnknownNdim()) { - return TensorStructInfo(data_sinfo->dtype, kUnknownNDim, data_sinfo->vdevice); + if (data_ty->IsUnknownNdim() || indices_ty->IsUnknownNdim()) { + return TensorType(data_ty->dtype, kUnknownNDim, data_ty->vdevice); } - if (batch_dims < 0 || batch_dims > data_sinfo->ndim) { + if (batch_dims < 0 || batch_dims > data_ty->ndim) { TVM_FFI_VISIT_THROW(ValueError, call) << "GatherND batch_dims must be in range [0, data.ndim]. However, got batch_dims=" << batch_dims << ", data.ndim=" << input_dims; } - if (batch_dims > indices_sinfo->ndim - 1) { + if (batch_dims > indices_ty->ndim - 1) { TVM_FFI_VISIT_THROW(ValueError, call) << "GatherND batch_dims cannot exceed indices.ndim-1. However, got batch_dims=" - << batch_dims << ", indices.ndim=" << indices_sinfo->ndim; + << batch_dims << ", indices.ndim=" << indices_ty->ndim; } // Check if indices shape is known - const auto* indices_shape = indices_sinfo->shape.as(); - const auto* data_shape = data_sinfo->shape.as(); + const auto* indices_shape = indices_ty->shape.as(); + const auto* data_shape = data_ty->shape.as(); if (!indices_shape || !indices_shape->values.back()->IsInstance()) { - return TensorStructInfo(data_sinfo->dtype, kUnknownNDim, data_sinfo->vdevice); + return TensorType(data_ty->dtype, kUnknownNDim, data_ty->vdevice); } int l = indices_shape->values.back().as()->value; - int output_ndim = indices_sinfo->ndim + input_dims - l - 1 - batch_dims; + int output_ndim = indices_ty->ndim + input_dims - l - 1 - batch_dims; if (!data_shape) { - return TensorStructInfo(data_sinfo->dtype, output_ndim, data_sinfo->vdevice); + return TensorType(data_ty->dtype, output_ndim, data_ty->vdevice); } // In this condition, all input shapes are known @@ -2264,14 +2245,14 @@ StructInfo InferStructInfoGatherND(const Call& call, const BlockBuilder& ctx) { << "indices: " << ShapeExpr(indices_shape->values) << ", data: " << ShapeExpr(data_shape->values) << ", with batch_dims=" << batch_dims; } - for (int i = 0; i < indices_sinfo->ndim - 1; ++i) { + for (int i = 0; i < indices_ty->ndim - 1; ++i) { out_shape.push_back(indices_shape->values[i]); } for (int i = batch_dims + l; i < input_dims; ++i) { out_shape.push_back(data_shape->values[i]); } TVM_FFI_ICHECK_EQ(out_shape.size(), output_ndim); - return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype, data_sinfo->vdevice); + return TensorType(ShapeExpr(out_shape), data_ty->dtype, data_ty->vdevice); } TVM_REGISTER_OP("relax.gather_nd") @@ -2279,7 +2260,7 @@ TVM_REGISTER_OP("relax.gather_nd") .set_num_inputs(2) .add_argument("data", "Tensor", "The input tensor.") .add_argument("indices", "Tensor", "The indices tensor.") - .set_attr("FInferStructInfo", InferStructInfoGatherND) + .set_attr("FInferType", InferTypeGatherND) .set_attr("FPurity", true); /* relax.index_put */ @@ -2296,84 +2277,84 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.index_put", index_put); } -StructInfo InferStructInfoIndexPut(const Call& call, const BlockBuilder& ctx) { - const auto* data_sinfo = GetStructInfoAs(call->args[0]); - const auto* values_sinfo = GetStructInfoAs(call->args[2]); +Type InferTypeIndexPut(const Call& call, const BlockBuilder& ctx) { + const auto* data_ty = GetTypeAs(call->args[0]); + const auto* values_ty = GetTypeAs(call->args[2]); - auto diag_def = [&](const TensorStructInfoNode* sinfo, ffi::String name, ffi::String type_key) { - if (sinfo == nullptr) { + auto diag_def = [&](const TensorTypeNode* ty, ffi::String name, ffi::String type_key) { + if (ty == nullptr) { TVM_FFI_VISIT_THROW(TypeError, call) << "IndexPut requires the input " << name << " to be a Tensor. However, the given one is " << type_key; } }; - diag_def(data_sinfo, "data", call->args[0]->struct_info_->GetTypeKey()); - diag_def(values_sinfo, "values", call->args[2]->struct_info_->GetTypeKey()); + diag_def(data_ty, "data", call->args[0]->ty->GetTypeKey()); + diag_def(values_ty, "values", call->args[2]->ty->GetTypeKey()); // Handle indices: either a single tensor or a tuple of tensors - ffi::Array indices_tensors; + ffi::Array indices_tensors; - if (const auto* tuple_sinfo = GetStructInfoAs(call->args[1])) { + if (const auto* tuple_ty = GetTypeAs(call->args[1])) { // Indices is a tuple of tensors - for (size_t i = 0; i < tuple_sinfo->fields.size(); ++i) { - const auto* tensor_sinfo = tuple_sinfo->fields[i].as(); - if (tensor_sinfo == nullptr) { + for (size_t i = 0; i < tuple_ty->fields.size(); ++i) { + const auto* tensor_ty = tuple_ty->fields[i].as(); + if (tensor_ty == nullptr) { TVM_FFI_VISIT_THROW(TypeError, call) << "IndexPut requires each index in the indices tuple to be a Tensor. " - << "However, element " << i << " is " << tuple_sinfo->fields[i]->GetTypeKey(); + << "However, element " << i << " is " << tuple_ty->fields[i]->GetTypeKey(); } - indices_tensors.push_back(ffi::GetRef(tensor_sinfo)); + indices_tensors.push_back(ffi::GetRef(tensor_ty)); } - } else if (const auto* tensor_sinfo = GetStructInfoAs(call->args[1])) { + } else if (const auto* tensor_ty = GetTypeAs(call->args[1])) { // Indices is a single tensor - indices_tensors.push_back(ffi::GetRef(tensor_sinfo)); + indices_tensors.push_back(ffi::GetRef(tensor_ty)); } else { TVM_FFI_VISIT_THROW(TypeError, call) << "IndexPut requires indices to be a Tensor or a tuple of Tensors. " - << "However, the given one is " << call->args[1]->struct_info_->GetTypeKey(); + << "However, the given one is " << call->args[1]->ty->GetTypeKey(); } - if (data_sinfo->IsUnknownNdim()) { - return TensorStructInfo(data_sinfo->dtype, kUnknownNDim, data_sinfo->vdevice); + if (data_ty->IsUnknownNdim()) { + return TensorType(data_ty->dtype, kUnknownNDim, data_ty->vdevice); } // Validate each index tensor // Index tensors can be multi-dimensional for broadcasting int max_index_ndim = -1; for (size_t i = 0; i < indices_tensors.size(); ++i) { - const auto& tensor_sinfo = indices_tensors[i]; - if (!tensor_sinfo->IsUnknownNdim()) { - if (tensor_sinfo->ndim < 1) { + const auto& tensor_ty = indices_tensors[i]; + if (!tensor_ty->IsUnknownNdim()) { + if (tensor_ty->ndim < 1) { TVM_FFI_VISIT_THROW(ValueError, call) << "IndexPut requires each index tensor to have at least 1 dimension. " - << "However, index tensor " << i << " has ndim=" << tensor_sinfo->ndim; + << "However, index tensor " << i << " has ndim=" << tensor_ty->ndim; } - if (max_index_ndim < tensor_sinfo->ndim) { - max_index_ndim = tensor_sinfo->ndim; + if (max_index_ndim < tensor_ty->ndim) { + max_index_ndim = tensor_ty->ndim; } } - if (tensor_sinfo->IsUnknownDtype()) { + if (tensor_ty->IsUnknownDtype()) { LOG(WARNING) << "Data type of index tensor " << i << " has not been specified. Assume it has an integer type."; - } else if (!(tensor_sinfo->dtype.is_int() || tensor_sinfo->dtype.is_uint())) { + } else if (!(tensor_ty->dtype.is_int() || tensor_ty->dtype.is_uint())) { TVM_FFI_VISIT_THROW(TypeError, call) << "IndexPut requires each index tensor to have integer dtype. " - << "However, index tensor " << i << " has dtype=" << tensor_sinfo->dtype; + << "However, index tensor " << i << " has dtype=" << tensor_ty->dtype; } } // Validate that index tensor shapes are broadcastable if (max_index_ndim > 1) { for (size_t i = 0; i < indices_tensors.size(); ++i) { - const auto& tensor_sinfo = indices_tensors[i]; - if (!tensor_sinfo->IsUnknownNdim() && tensor_sinfo->ndim > 1) { + const auto& tensor_ty = indices_tensors[i]; + if (!tensor_ty->IsUnknownNdim() && tensor_ty->ndim > 1) { // Check that multi-dimensional indices are broadcastable - const auto* shape = tensor_sinfo->shape.as(); + const auto* shape = tensor_ty->shape.as(); if (shape) { // Verify trailing dimensions can broadcast // For now, we accept any multi-dimensional index and rely on runtime validation - LOG(INFO) << "IndexPut: index tensor " << i << " has ndim=" << tensor_sinfo->ndim + LOG(INFO) << "IndexPut: index tensor " << i << " has ndim=" << tensor_ty->ndim << " for broadcasting"; } } @@ -2381,44 +2362,43 @@ StructInfo InferStructInfoIndexPut(const Call& call, const BlockBuilder& ctx) { } // Check that the number of index tensors matches data dimensions - if (!data_sinfo->IsUnknownNdim() && - indices_tensors.size() != static_cast(data_sinfo->ndim)) { + if (!data_ty->IsUnknownNdim() && indices_tensors.size() != static_cast(data_ty->ndim)) { TVM_FFI_VISIT_THROW(ValueError, call) << "IndexPut requires the number of index tensors (" << indices_tensors.size() - << ") to match the data tensor dimensions (" << data_sinfo->ndim << ")"; + << ") to match the data tensor dimensions (" << data_ty->ndim << ")"; } // Check data and values dtype compatibility - if (data_sinfo->IsUnknownDtype() || values_sinfo->IsUnknownDtype()) { - auto diag_dtype = [&](const TensorStructInfoNode* sinfo, ffi::String name) { - if (sinfo->IsUnknownDtype()) { + if (data_ty->IsUnknownDtype() || values_ty->IsUnknownDtype()) { + auto diag_dtype = [&](const TensorTypeNode* ty, ffi::String name) { + if (ty->IsUnknownDtype()) { LOG(WARNING) << "Data type of " << name << " has not been specified. Assume it has an integer type."; } }; - diag_dtype(data_sinfo, "data"); - diag_dtype(values_sinfo, "values"); - } else if (data_sinfo->dtype != values_sinfo->dtype) { + diag_dtype(data_ty, "data"); + diag_dtype(values_ty, "values"); + } else if (data_ty->dtype != values_ty->dtype) { TVM_FFI_VISIT_THROW(TypeError, call) << "IndexPut requires the input data to have the same type as values. " - << "However, the given types are data: " << data_sinfo->dtype - << ", values: " << values_sinfo->dtype; + << "However, the given types are data: " << data_ty->dtype + << ", values: " << values_ty->dtype; } // Check values shape compatibility - const auto* values_shape = values_sinfo->shape.as(); + const auto* values_shape = values_ty->shape.as(); if (values_shape) { - if (values_sinfo->ndim != 1) { + if (values_ty->ndim != 1) { LOG(WARNING) << "IndexPut typically expects values to be 1D, but got ndim=" - << values_sinfo->ndim; + << values_ty->ndim; } } - const auto* data_shape = data_sinfo->shape.as(); + const auto* data_shape = data_ty->shape.as(); if (data_shape) { - return TensorStructInfo(ShapeExpr(data_shape->values), data_sinfo->dtype, data_sinfo->vdevice); + return TensorType(ShapeExpr(data_shape->values), data_ty->dtype, data_ty->vdevice); } - return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim, data_sinfo->vdevice); + return TensorType(data_ty->dtype, data_ty->ndim, data_ty->vdevice); } TVM_REGISTER_OP("relax.index_put") @@ -2427,7 +2407,7 @@ TVM_REGISTER_OP("relax.index_put") .add_argument("data", "Tensor", "The input tensor.") .add_argument("indices", "Tensor", "The indices tensor(s).") .add_argument("values", "Tensor", "The values to put.") - .set_attr("FInferStructInfo", InferStructInfoIndexPut) + .set_attr("FInferType", InferTypeIndexPut) .set_attr("FPurity", true); /* relax.meshgrid */ @@ -2444,13 +2424,13 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.meshgrid", meshgrid); } -StructInfo InferStructInfoMeshgrid(const Call& call, const BlockBuilder& ctx) { +Type InferTypeMeshgrid(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 1) { TVM_FFI_VISIT_THROW(ValueError, call) << "meshgrid op expects 1 Tuple input argument."; } - ffi::Array input_sinfo = GetTensorStructInfoFromTuple(call, ctx, call->args[0]); + ffi::Array input_ty = GetTensorTypeFromTuple(call, ctx, call->args[0]); - int n_inputs = input_sinfo.size(); + int n_inputs = input_ty.size(); if (n_inputs == 0) { TVM_FFI_VISIT_THROW(ValueError, call) @@ -2464,25 +2444,25 @@ StructInfo InferStructInfoMeshgrid(const Call& call, const BlockBuilder& ctx) { bool vdevice_unknown = false; for (int i = 0; i < n_inputs; ++i) { - const TensorStructInfo& sinfo = input_sinfo[i]; + const TensorType& ty = input_ty[i]; - if (sinfo->ndim != 1) { + if (ty->ndim != 1) { TVM_FFI_VISIT_THROW(ValueError, call) - << "meshgrid expects each input tensor to be 1D. Got ndim = " << sinfo->ndim - << " at index " << i; + << "meshgrid expects each input tensor to be 1D. Got ndim = " << ty->ndim << " at index " + << i; } - if (sinfo->dtype.is_void()) { + if (ty->dtype.is_void()) { continue; } else if (common_dtype.is_void()) { - common_dtype = sinfo->dtype; - } else if (sinfo->dtype != common_dtype) { + common_dtype = ty->dtype; + } else if (ty->dtype != common_dtype) { TVM_FFI_VISIT_THROW(TypeError, call) - << "meshgrid expects all input tensors to have the same dtype. Found " << sinfo->dtype + << "meshgrid expects all input tensors to have the same dtype. Found " << ty->dtype << " and " << common_dtype; } - const auto* shape_expr = sinfo->shape.as(); + const auto* shape_expr = ty->shape.as(); if (shape_expr && shape_expr->values.size() == 1) { lengths.push_back(shape_expr->values[0]); } else { @@ -2490,10 +2470,10 @@ StructInfo InferStructInfoMeshgrid(const Call& call, const BlockBuilder& ctx) { } if (!vdevice_unknown) { - if (sinfo->vdevice.defined()) { + if (ty->vdevice.defined()) { if (!vdev.defined()) { - vdev = sinfo->vdevice.value(); - } else if (sinfo->vdevice.value() != vdev) { + vdev = ty->vdevice.value(); + } else if (ty->vdevice.value() != vdev) { vdevice_unknown = true; } } @@ -2507,31 +2487,31 @@ StructInfo InferStructInfoMeshgrid(const Call& call, const BlockBuilder& ctx) { } } - ffi::Array out_fields; + ffi::Array out_fields; for (int i = 0; i < n_inputs; ++i) { if (!out_shape.empty()) { if (!vdevice_unknown) { - out_fields.push_back(TensorStructInfo(ShapeExpr(out_shape), common_dtype, vdev)); + out_fields.push_back(TensorType(ShapeExpr(out_shape), common_dtype, vdev)); } else { - out_fields.push_back(TensorStructInfo(ShapeExpr(out_shape), common_dtype)); + out_fields.push_back(TensorType(ShapeExpr(out_shape), common_dtype)); } } else { if (!vdevice_unknown) { - out_fields.push_back(TensorStructInfo(common_dtype, n_inputs, vdev)); + out_fields.push_back(TensorType(common_dtype, n_inputs, vdev)); } else { - out_fields.push_back(TensorStructInfo(common_dtype, n_inputs)); + out_fields.push_back(TensorType(common_dtype, n_inputs)); } } } - return TupleStructInfo(out_fields); + return TupleType(out_fields); } TVM_REGISTER_OP("relax.meshgrid") .set_attrs_type() .set_num_inputs(1) .add_argument("tensors", "Tuple of Tensors", "The input list of tensors.") - .set_attr("FInferStructInfo", InferStructInfoMeshgrid) + .set_attr("FInferType", InferTypeMeshgrid) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", true); @@ -2550,77 +2530,77 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.scatter_elements", scatter_elements); } -StructInfo InferStructInfoScatterElements(const Call& call, const BlockBuilder& ctx) { +Type InferTypeScatterElements(const Call& call, const BlockBuilder& ctx) { arith::Analyzer analyzer = ctx->GetAnalyzer(); - const auto* data_sinfo = GetStructInfoAs(call->args[0]); - const auto* indices_sinfo = GetStructInfoAs(call->args[1]); - const auto* updates_sinfo = GetStructInfoAs(call->args[2]); + const auto* data_ty = GetTypeAs(call->args[0]); + const auto* indices_ty = GetTypeAs(call->args[1]); + const auto* updates_ty = GetTypeAs(call->args[2]); - auto diag_def = [&](const TensorStructInfoNode* sinfo, ffi::String name, ffi::String type_key) { - if (sinfo == nullptr) { + auto diag_def = [&](const TensorTypeNode* ty, ffi::String name, ffi::String type_key) { + if (ty == nullptr) { TVM_FFI_VISIT_THROW(TypeError, call) << "ScatterElements requires the input " << name << " to be a Tensor. However, the given one is " << type_key; } }; - diag_def(data_sinfo, "data", call->args[0]->struct_info_->GetTypeKey()); - diag_def(indices_sinfo, "indices", call->args[1]->struct_info_->GetTypeKey()); - diag_def(updates_sinfo, "updates", call->args[2]->struct_info_->GetTypeKey()); + diag_def(data_ty, "data", call->args[0]->ty->GetTypeKey()); + diag_def(indices_ty, "indices", call->args[1]->ty->GetTypeKey()); + diag_def(updates_ty, "updates", call->args[2]->ty->GetTypeKey()); - if (data_sinfo->IsUnknownNdim()) { + if (data_ty->IsUnknownNdim()) { // When `data` has unknown rank, assume rest of arguments are correct and proceed. // If the assumption turns out to be wrong, runtime error will be triggered. - return TensorStructInfo(data_sinfo->dtype, kUnknownNDim, data_sinfo->vdevice); + return TensorType(data_ty->dtype, kUnknownNDim, data_ty->vdevice); } - if (!indices_sinfo->IsUnknownNdim() && !updates_sinfo->IsUnknownNdim()) { - if (data_sinfo->ndim != indices_sinfo->ndim) { + if (!indices_ty->IsUnknownNdim() && !updates_ty->IsUnknownNdim()) { + if (data_ty->ndim != indices_ty->ndim) { TVM_FFI_VISIT_THROW(ValueError, call) << "ScatterElements op requires the data tensor to have the same rank with " "indices tensor. However, the given dimensions are " - << "indices: " << indices_sinfo->ndim << ", data: " << data_sinfo->ndim; + << "indices: " << indices_ty->ndim << ", data: " << data_ty->ndim; } - if (indices_sinfo->ndim != updates_sinfo->ndim) { + if (indices_ty->ndim != updates_ty->ndim) { TVM_FFI_VISIT_THROW(ValueError, call) << "ScatterElements op requires the indices tensor to have the same rank with " "updates tensor. However, the given dimensions are " - << "indices: " << indices_sinfo->ndim << ", updates: " << updates_sinfo->ndim; + << "indices: " << indices_ty->ndim << ", updates: " << updates_ty->ndim; } } - if (data_sinfo->IsUnknownDtype() || updates_sinfo->IsUnknownDtype()) { - auto diag_dtype = [&](const TensorStructInfoNode* sinfo, ffi::String name) { - if (sinfo->IsUnknownDtype()) { + if (data_ty->IsUnknownDtype() || updates_ty->IsUnknownDtype()) { + auto diag_dtype = [&](const TensorTypeNode* ty, ffi::String name) { + if (ty->IsUnknownDtype()) { LOG(WARNING) << "Data type of " << name << " has not been specified. Assume it has an integer type."; } }; - diag_dtype(data_sinfo, "data"); - diag_dtype(data_sinfo, "updates"); + diag_dtype(data_ty, "data"); + diag_dtype(data_ty, "updates"); } else { - if (data_sinfo->dtype != updates_sinfo->dtype) { + if (data_ty->dtype != updates_ty->dtype) { TVM_FFI_VISIT_THROW(TypeError, call) << "ScatterElements op requires the input data to have same type with " "updates. However, the given types are " - << "data: " << data_sinfo->dtype << ", updates: " << updates_sinfo->dtype; + << "data: " << data_ty->dtype << ", updates: " << updates_ty->dtype; } } - if (indices_sinfo->IsUnknownDtype()) { + if (indices_ty->IsUnknownDtype()) { LOG(WARNING) << "Data type of indices has not been specified. Assume it has an integer type."; - } else if (!(indices_sinfo->dtype.is_int() || indices_sinfo->dtype.is_uint())) { + } else if (!(indices_ty->dtype.is_int() || indices_ty->dtype.is_uint())) { TVM_FFI_VISIT_THROW(TypeError, call) << "ScatterElements op requires the input indices to have integer dtype. However, the " "given indices dtype is " - << indices_sinfo->dtype; + << indices_ty->dtype; } - const auto* indices_shape = indices_sinfo->shape.as(); - const auto* updates_shape = updates_sinfo->shape.as(); + const auto* indices_shape = indices_ty->shape.as(); + const auto* updates_shape = updates_ty->shape.as(); if (indices_shape && updates_shape) { - for (int i = 0; i < indices_sinfo->ndim; i++) { + for (int i = 0; i < indices_ty->ndim; i++) { if (analyzer->CanProve(indices_shape->values[i] != updates_shape->values[i])) { TVM_FFI_VISIT_THROW(ValueError, call) << "ScatterElements op requires the indices tensor to have the same shape with " @@ -2630,11 +2610,11 @@ StructInfo InferStructInfoScatterElements(const Call& call, const BlockBuilder& } } } - const auto* data_shape = data_sinfo->shape.as(); + const auto* data_shape = data_ty->shape.as(); if (data_shape) { - return TensorStructInfo(ShapeExpr(data_shape->values), data_sinfo->dtype, data_sinfo->vdevice); + return TensorType(ShapeExpr(data_shape->values), data_ty->dtype, data_ty->vdevice); } - return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim, data_sinfo->vdevice); + return TensorType(data_ty->dtype, data_ty->ndim, data_ty->vdevice); } InferLayoutOutput InferLayoutScatterElements( @@ -2654,10 +2634,10 @@ InferLayoutOutput InferLayoutScatterElements( } if (layout->layout.ndim() != layout->layout.ndim_primal()) { - const auto* tensor_sinfo = GetStructInfoAs(call->args[0]); - TVM_FFI_ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; - TVM_FFI_ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support static ndim for now"; - int ndim = tensor_sinfo->ndim; + const auto* tensor_ty = GetTypeAs(call->args[0]); + TVM_FFI_ICHECK(tensor_ty != nullptr) << "Invalid Call"; + TVM_FFI_ICHECK(!tensor_ty->IsUnknownNdim()) << "Only support static ndim for now"; + int ndim = tensor_ty->ndim; layout = LayoutDecision(InitialLayout(ndim)); } @@ -2672,7 +2652,7 @@ TVM_REGISTER_OP("relax.scatter_elements") .add_argument("data", "Tensor", "The input tensor.") .add_argument("indices", "Tensor", "The indices tensor.") .add_argument("updates", "Tensor", "The input tensor of updates.") - .set_attr("FInferStructInfo", InferStructInfoScatterElements) + .set_attr("FInferType", InferTypeScatterElements) .set_attr("FRelaxInferLayout", InferLayoutScatterElements) .set_attr("FPurity", true); @@ -2690,67 +2670,67 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.scatter_nd", scatter_nd); } -StructInfo InferStructInfoScatterND(const Call& call, const BlockBuilder& ctx) { +Type InferTypeScatterND(const Call& call, const BlockBuilder& ctx) { // `call->args` contains: [data, indices, updates] arith::Analyzer analyzer = ctx->GetAnalyzer(); TVM_FFI_ICHECK_EQ(call->args.size(), 3); - const auto* data_sinfo = GetStructInfoAs(call->args[0]); - const auto* indices_sinfo = GetStructInfoAs(call->args[1]); - const auto* updates_sinfo = GetStructInfoAs(call->args[2]); + const auto* data_ty = GetTypeAs(call->args[0]); + const auto* indices_ty = GetTypeAs(call->args[1]); + const auto* updates_ty = GetTypeAs(call->args[2]); - if (data_sinfo == nullptr) { + if (data_ty == nullptr) { TVM_FFI_VISIT_THROW(TypeError, call) << "ScatterND op requires the input data to be a tensor. However, the given type is " << call->args[0]->GetTypeKey(); } - if (indices_sinfo == nullptr) { + if (indices_ty == nullptr) { TVM_FFI_VISIT_THROW(TypeError, call) << "ScatterND op requires the input indices to be a tensor. However, the given type is " << call->args[1]->GetTypeKey(); } - if (updates_sinfo == nullptr) { + if (updates_ty == nullptr) { TVM_FFI_VISIT_THROW(TypeError, call) << "ScatterND op requires the input updates to be a tensor. However, the given type is " << call->args[2]->GetTypeKey(); } - if (data_sinfo->IsUnknownDtype() || updates_sinfo->IsUnknownDtype()) { + if (data_ty->IsUnknownDtype() || updates_ty->IsUnknownDtype()) { TVM_FFI_VISIT_THROW(ValueError, call) << "ScatterND op requires the input data and updates to have known dtype. " "However, the given types are " - << "data: " << data_sinfo->dtype << ", updates: " << updates_sinfo->dtype; + << "data: " << data_ty->dtype << ", updates: " << updates_ty->dtype; } - if (data_sinfo->dtype != updates_sinfo->dtype) { + if (data_ty->dtype != updates_ty->dtype) { TVM_FFI_VISIT_THROW(TypeError, call) << "ScatterND op requires the input data to have same type with updates. " "However, the given types are " - << "data: " << data_sinfo->dtype << ", updates: " << updates_sinfo->dtype; + << "data: " << data_ty->dtype << ", updates: " << updates_ty->dtype; } - if (indices_sinfo->IsUnknownDtype()) { + if (indices_ty->IsUnknownDtype()) { LOG(WARNING) << "Data type of indices has not been specified. Assume it has an integer type."; - } else if (!(indices_sinfo->dtype.is_int() || indices_sinfo->dtype.is_uint())) { + } else if (!(indices_ty->dtype.is_int() || indices_ty->dtype.is_uint())) { TVM_FFI_VISIT_THROW(TypeError, call) << "ScatterND op requires the input indices to have integer dtype. However, " "the given indices dtype is " - << indices_sinfo->dtype; + << indices_ty->dtype; } - const auto* data_shape = data_sinfo->shape.as(); - const auto* indices_shape = indices_sinfo->shape.as(); - const auto* updates_shape = updates_sinfo->shape.as(); + const auto* data_shape = data_ty->shape.as(); + const auto* indices_shape = indices_ty->shape.as(); + const auto* updates_shape = updates_ty->shape.as(); if (data_shape && indices_shape && updates_shape) { - const IntImmNode* k_dim = indices_shape->values[indices_sinfo->ndim - 1].as(); + const IntImmNode* k_dim = indices_shape->values[indices_ty->ndim - 1].as(); if (!k_dim) { TVM_FFI_VISIT_THROW(ValueError, call) << "ScatterND needs a static shape for the last axis of indices, got " << indices_shape->values; } - const size_t data_ndim = data_sinfo->ndim; - const size_t indices_ndim = indices_sinfo->ndim; - const size_t updates_ndim = updates_sinfo->ndim; + const size_t data_ndim = data_ty->ndim; + const size_t indices_ndim = indices_ty->ndim; + const size_t updates_ndim = updates_ty->ndim; if (data_ndim + indices_ndim - k_dim->value - 1 != updates_ndim) { TVM_FFI_VISIT_THROW(ValueError, call) << "ScatterND op requires the updates tensor to have the rank of " @@ -2796,9 +2776,9 @@ StructInfo InferStructInfoScatterND(const Call& call, const BlockBuilder& ctx) { } } if (data_shape) { - return TensorStructInfo(ShapeExpr(data_shape->values), data_sinfo->dtype, data_sinfo->vdevice); + return TensorType(ShapeExpr(data_shape->values), data_ty->dtype, data_ty->vdevice); } - return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim, data_sinfo->vdevice); + return TensorType(data_ty->dtype, data_ty->ndim, data_ty->vdevice); } InferLayoutOutput InferLayoutScatterND( @@ -2810,12 +2790,12 @@ InferLayoutOutput InferLayoutScatterND( LayoutDecision indices_layout = GetLayoutDecision(var_layout_map, call->args[1]); LayoutDecision updates_layout = GetLayoutDecision(var_layout_map, call->args[2]); - const auto* data_sinfo = GetStructInfoAs(call->args[0]); - const auto* updates_sinfo = GetStructInfoAs(call->args[2]); - TVM_FFI_ICHECK(data_sinfo != nullptr) << "Invalid Call"; - TVM_FFI_ICHECK(updates_sinfo != nullptr) << "Invalid Call"; - TVM_FFI_ICHECK(!data_sinfo->IsUnknownNdim()) << "Only support static ndim for now"; - TVM_FFI_ICHECK(!updates_sinfo->IsUnknownNdim()) << "Only support static ndim for now"; + const auto* data_ty = GetTypeAs(call->args[0]); + const auto* updates_ty = GetTypeAs(call->args[2]); + TVM_FFI_ICHECK(data_ty != nullptr) << "Invalid Call"; + TVM_FFI_ICHECK(updates_ty != nullptr) << "Invalid Call"; + TVM_FFI_ICHECK(!data_ty->IsUnknownNdim()) << "Only support static ndim for now"; + TVM_FFI_ICHECK(!updates_ty->IsUnknownNdim()) << "Only support static ndim for now"; LayoutDecision layout = data_layout; LayoutDecision out_updates_layout = updates_layout; @@ -2825,15 +2805,15 @@ InferLayoutOutput InferLayoutScatterND( if (has_sub_indexed_layout) { // Fall back to initial layouts for both data and updates - layout = LayoutDecision(InitialLayout(data_sinfo->ndim)); - out_updates_layout = LayoutDecision(InitialLayout(updates_sinfo->ndim)); - } else if (data_sinfo->ndim == updates_sinfo->ndim) { + layout = LayoutDecision(InitialLayout(data_ty->ndim)); + out_updates_layout = LayoutDecision(InitialLayout(updates_ty->ndim)); + } else if (data_ty->ndim == updates_ty->ndim) { // When data and updates have the same rank, apply the same layout to both out_updates_layout = layout; } else { // Different ranks - fall back to initial layouts for both - layout = LayoutDecision(InitialLayout(data_sinfo->ndim)); - out_updates_layout = LayoutDecision(InitialLayout(updates_sinfo->ndim)); + layout = LayoutDecision(InitialLayout(data_ty->ndim)); + out_updates_layout = LayoutDecision(InitialLayout(updates_ty->ndim)); } return InferLayoutOutput({layout, indices_layout, out_updates_layout}, {layout}, @@ -2846,7 +2826,7 @@ TVM_REGISTER_OP("relax.scatter_nd") .add_argument("data", "Tensor", "The input tensor.") .add_argument("indices", "Tensor", "The indices tensor.") .add_argument("updates", "Tensor", "The input tensor of updates.") - .set_attr("FInferStructInfo", InferStructInfoScatterND) + .set_attr("FInferType", InferTypeScatterND) .set_attr("FRelaxInferLayout", InferLayoutScatterND) .set_attr("FPurity", true); @@ -2864,29 +2844,28 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.slice_scatter", slice_scatter); } -StructInfo InferStructInfoSliceScatter(const Call& call, const BlockBuilder& ctx) { +Type InferTypeSliceScatter(const Call& call, const BlockBuilder& ctx) { arith::Analyzer analyzer = ctx->GetAnalyzer(); - const auto* data_sinfo = GetStructInfoAs(call->args[0]); - const auto* src_sinfo = GetStructInfoAs(call->args[1]); + const auto* data_ty = GetTypeAs(call->args[0]); + const auto* src_ty = GetTypeAs(call->args[1]); auto* attrs = call->attrs.as(); - auto diag_tensor_check = [&](const TensorStructInfoNode* sinfo, const Expr& arg_expr, - ffi::String name) { - if (sinfo == nullptr) { + auto diag_tensor_check = [&](const TensorTypeNode* ty, const Expr& arg_expr, ffi::String name) { + if (ty == nullptr) { TVM_FFI_VISIT_THROW(TypeError, call) << "SliceScatter requires the input " << name - << " to be a Tensor. However, the given one is " << arg_expr->struct_info_->GetTypeKey(); + << " to be a Tensor. However, the given one is " << arg_expr->ty->GetTypeKey(); } }; - diag_tensor_check(data_sinfo, call->args[0], "data"); - diag_tensor_check(src_sinfo, call->args[1], "src"); + diag_tensor_check(data_ty, call->args[0], "data"); + diag_tensor_check(src_ty, call->args[1], "src"); - if (data_sinfo->IsUnknownNdim()) { - return TensorStructInfo(data_sinfo->dtype, kUnknownNDim, data_sinfo->vdevice); + if (data_ty->IsUnknownNdim()) { + return TensorType(data_ty->dtype, kUnknownNDim, data_ty->vdevice); } - int ndim = data_sinfo->ndim; + int ndim = data_ty->ndim; int raw_axis = attrs->axis; if (raw_axis < -ndim || raw_axis >= ndim) { TVM_FFI_VISIT_THROW(ValueError, call) @@ -2895,31 +2874,31 @@ StructInfo InferStructInfoSliceScatter(const Call& call, const BlockBuilder& ctx << ", while ndim is " << ndim; } - if (!data_sinfo->IsUnknownNdim() && !src_sinfo->IsUnknownNdim()) { - if (data_sinfo->ndim != src_sinfo->ndim) { + if (!data_ty->IsUnknownNdim() && !src_ty->IsUnknownNdim()) { + if (data_ty->ndim != src_ty->ndim) { TVM_FFI_VISIT_THROW(ValueError, call) << "SliceScatter op requires the data tensor to have the same rank as the " "src tensor. However, the given dimensions are " - << "src: " << src_sinfo->ndim << ", data: " << data_sinfo->ndim; + << "src: " << src_ty->ndim << ", data: " << data_ty->ndim; } } - if (data_sinfo->IsUnknownDtype() || src_sinfo->IsUnknownDtype()) { - auto diag_dtype_warn = [&](const TensorStructInfoNode* sinfo, ffi::String name) { - if (sinfo->IsUnknownDtype()) { + if (data_ty->IsUnknownDtype() || src_ty->IsUnknownDtype()) { + auto diag_dtype_warn = [&](const TensorTypeNode* ty, ffi::String name) { + if (ty->IsUnknownDtype()) { LOG(WARNING) << "SliceScatter: Data type of " << name << " has not been specified for call node " << call << ". Assuming it is compatible."; } }; - diag_dtype_warn(data_sinfo, "data"); - diag_dtype_warn(src_sinfo, "src"); + diag_dtype_warn(data_ty, "data"); + diag_dtype_warn(src_ty, "src"); } else { - if (data_sinfo->dtype != src_sinfo->dtype) { + if (data_ty->dtype != src_ty->dtype) { TVM_FFI_VISIT_THROW(TypeError, call) << "SliceScatter op requires the input data to have the same type as " "src. However, the given types are " - << "data: " << data_sinfo->dtype << ", src: " << src_sinfo->dtype; + << "data: " << data_ty->dtype << ", src: " << src_ty->dtype; } } @@ -2955,14 +2934,14 @@ StructInfo InferStructInfoSliceScatter(const Call& call, const BlockBuilder& ctx int axis = NormalizeAxis(call, ctx, ndim, attrs->axis); - const auto* data_shape_node = data_sinfo->shape.as(); - const auto* src_shape_node = src_sinfo->shape.as(); + const auto* data_shape_node = data_ty->shape.as(); + const auto* src_shape_node = src_ty->shape.as(); - if (data_shape_node && src_shape_node && !src_sinfo->IsUnknownNdim()) { + if (data_shape_node && src_shape_node && !src_ty->IsUnknownNdim()) { TVM_FFI_ICHECK_EQ(data_shape_node->values.size(), static_cast(ndim)) - << "Internal error: data_shape_node rank mismatch with data_sinfo->ndim for call " << call; - TVM_FFI_ICHECK_EQ(src_shape_node->values.size(), static_cast(src_sinfo->ndim)) - << "Internal error: src_shape_node rank mismatch with src_sinfo->ndim for call " << call; + << "Internal error: data_shape_node rank mismatch with data_ty->ndim for call " << call; + TVM_FFI_ICHECK_EQ(src_shape_node->values.size(), static_cast(src_ty->ndim)) + << "Internal error: src_shape_node rank mismatch with src_ty->ndim for call " << call; PrimExpr num_elem = tvm::floordiv((stop_val - start_val + step_val - PrimExpr(1)), step_val); @@ -2973,8 +2952,8 @@ StructInfo InferStructInfoSliceScatter(const Call& call, const BlockBuilder& ctx << "SliceScatter op requires the data tensor to have the same shape as the " "src tensor except at the scatter axis (" << axis << "). Mismatch at dimension " << i << ". " - << "data shape: " << data_sinfo->GetShape().value() - << ", src shape: " << src_sinfo->GetShape().value(); + << "data shape: " << data_ty->GetShape().value() + << ", src shape: " << src_ty->GetShape().value(); } } } @@ -2988,10 +2967,10 @@ StructInfo InferStructInfoSliceScatter(const Call& call, const BlockBuilder& ctx } } - if (data_sinfo->shape.defined()) { - return TensorStructInfo(data_sinfo->shape.value(), data_sinfo->dtype, data_sinfo->vdevice); + if (data_ty->shape.defined()) { + return TensorType(data_ty->shape.value(), data_ty->dtype, data_ty->vdevice); } - return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim, data_sinfo->vdevice); + return TensorType(data_ty->dtype, data_ty->ndim, data_ty->vdevice); } TVM_REGISTER_OP("relax.slice_scatter") @@ -3002,7 +2981,7 @@ TVM_REGISTER_OP("relax.slice_scatter") .add_argument("start", "PrimValue", "The starting index of the slice (inclusive).") .add_argument("end", "PrimValue", "The ending index of the slice (exclusive).") .add_argument("step", "PrimValue", "The step of the slice.") - .set_attr("FInferStructInfo", InferStructInfoSliceScatter) + .set_attr("FInferType", InferTypeSliceScatter) .set_attr("FPurity", true); /* relax.one_hot */ @@ -3030,8 +3009,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.one_hot", one_hot); } -StructInfo InferStructInfoOneHot(const Call& call, const BlockBuilder& ctx) { - TensorStructInfo indices_sinfo = GetInputTensorStructInfo(call, 0, ctx); +Type InferTypeOneHot(const Call& call, const BlockBuilder& ctx) { + TensorType indices_ty = GetInputTensorType(call, 0, ctx); const auto* attrs = call->attrs.as(); PrimValue on_value = Downcast(call->args[1]); PrimValue off_value = Downcast(call->args[2]); @@ -3042,22 +3021,22 @@ StructInfo InferStructInfoOneHot(const Call& call, const BlockBuilder& ctx) { DataType dtype = on_value->value->dtype; // Check if indices has an integer dtype - if (indices_sinfo->IsUnknownDtype()) { + if (indices_ty->IsUnknownDtype()) { LOG(WARNING) << "Data type of indices has not been specified. Assume it has an integer type."; - } else if (!(indices_sinfo->dtype.is_int() || indices_sinfo->dtype.is_uint())) { + } else if (!(indices_ty->dtype.is_int() || indices_ty->dtype.is_uint())) { TVM_FFI_VISIT_THROW(TypeError, call) << "one_hot op requires the input indices to have integer dtype. However, the " "given indices dtype is " - << indices_sinfo->dtype; + << indices_ty->dtype; } // Check if indices has unknown dimension - if (indices_sinfo->IsUnknownNdim()) { - return TensorStructInfo(dtype, kUnknownNDim, indices_sinfo->vdevice); + if (indices_ty->IsUnknownNdim()) { + return TensorType(dtype, kUnknownNDim, indices_ty->vdevice); } // Get the shape of indices - const auto* indices_shape = indices_sinfo->shape.as(); + const auto* indices_shape = indices_ty->shape.as(); if (indices_shape == nullptr) { - return TensorStructInfo(dtype, indices_sinfo->ndim + 1, indices_sinfo->vdevice); + return TensorType(dtype, indices_ty->ndim + 1, indices_ty->vdevice); } ffi::Array output_shape = indices_shape->values; @@ -3070,7 +3049,7 @@ StructInfo InferStructInfoOneHot(const Call& call, const BlockBuilder& ctx) { << "but got " << axis; output_shape.insert(output_shape.begin() + axis, attrs->depth); - return TensorStructInfo(ShapeExpr(output_shape), dtype, indices_sinfo->vdevice); + return TensorType(ShapeExpr(output_shape), dtype, indices_ty->vdevice); } TVM_REGISTER_OP("relax.one_hot") @@ -3079,7 +3058,7 @@ TVM_REGISTER_OP("relax.one_hot") .add_argument("indices", "Tensor", "The indices tensor.") .add_argument("on_value", "PrimValue", "The value to fill at specified indices.") .add_argument("off_value", "PrimValue", "The value to fill at other indices.") - .set_attr("FInferStructInfo", InferStructInfoOneHot) + .set_attr("FInferType", InferTypeOneHot) .set_attr("FPurity", true); } // namespace relax diff --git a/src/relax/op/tensor/qdq.cc b/src/relax/op/tensor/qdq.cc index 3ab3e7513139..974d70e7300a 100644 --- a/src/relax/op/tensor/qdq.cc +++ b/src/relax/op/tensor/qdq.cc @@ -52,7 +52,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.quantize", quantize); } -StructInfo InferStructInfoQuantize(const Call& call, const BlockBuilder& ctx) { +Type InferTypeQuantize(const Call& call, const BlockBuilder& ctx) { const auto* attrs = call->attrs.as(); if (attrs->out_dtype != DataType::Int(8) && attrs->out_dtype != DataType::UInt(8) && attrs->out_dtype != DataType::Int(16) && attrs->out_dtype != DataType::UInt(16) && @@ -61,45 +61,44 @@ StructInfo InferStructInfoQuantize(const Call& call, const BlockBuilder& ctx) { << "Unsupported output datatype attribute for operation: '" << attrs->out_dtype; } - TensorStructInfo input_sinfo = GetInputTensorStructInfo(call, ctx)[0]; - TensorStructInfo scale_sinfo = GetInputTensorStructInfo(call, ctx)[1]; - TensorStructInfo zp_sinfo = GetInputTensorStructInfo(call, ctx)[2]; + TensorType input_ty = GetInputTensorType(call, ctx)[0]; + TensorType scale_ty = GetInputTensorType(call, ctx)[1]; + TensorType zp_ty = GetInputTensorType(call, ctx)[2]; // Check input datatype: - if (input_sinfo->dtype != DataType::Float(16) && input_sinfo->dtype != DataType::Float(32)) { + if (input_ty->dtype != DataType::Float(16) && input_ty->dtype != DataType::Float(32)) { TVM_FFI_VISIT_THROW(TypeError, call) - << "Unsupported input datatype for operation: " << input_sinfo->dtype; + << "Unsupported input datatype for operation: " << input_ty->dtype; } // Check datatype of scale param: - if (scale_sinfo->dtype != DataType::Float(32) && scale_sinfo->dtype != DataType::Float(16)) { + if (scale_ty->dtype != DataType::Float(32) && scale_ty->dtype != DataType::Float(16)) { TVM_FFI_VISIT_THROW(TypeError, call) - << "scale param datatype should be one of [float16, float32], but got " - << scale_sinfo->dtype; + << "scale param datatype should be one of [float16, float32], but got " << scale_ty->dtype; } // Check datatype of zero_point param: - if (zp_sinfo->dtype != DataType::Int(8) && zp_sinfo->dtype != DataType::UInt(8) && - zp_sinfo->dtype != DataType::Int(16) && zp_sinfo->dtype != DataType::UInt(16) && - zp_sinfo->dtype != DataType::Int(32) && zp_sinfo->dtype != DataType::UInt(32) && - zp_sinfo->dtype != DataType::Float(16)) { + if (zp_ty->dtype != DataType::Int(8) && zp_ty->dtype != DataType::UInt(8) && + zp_ty->dtype != DataType::Int(16) && zp_ty->dtype != DataType::UInt(16) && + zp_ty->dtype != DataType::Int(32) && zp_ty->dtype != DataType::UInt(32) && + zp_ty->dtype != DataType::Float(16)) { TVM_FFI_VISIT_THROW(TypeError, call) << "zero_point param datatype should be one of " << "['int8', 'uint8', 'int16', 'uint16', 'int32', 'uint32', 'float16'], " - << "but got " << zp_sinfo->dtype; + << "but got " << zp_ty->dtype; } // Check that "axis" attribute is not out of range: - int axis = (attrs->axis < 0) ? (input_sinfo->ndim + attrs->axis) : attrs->axis; - if (axis < 0 || axis > input_sinfo->ndim - 1) { + int axis = (attrs->axis < 0) ? (input_ty->ndim + attrs->axis) : attrs->axis; + if (axis < 0 || axis > input_ty->ndim - 1) { TVM_FFI_VISIT_THROW(ValueError, call) << "relax.quantize: axis param is out of range (" << attrs->axis << ")"; } - auto check_param_size = [&](const TensorStructInfo& param_sinfo, - const TensorStructInfo& data_sinfo, ffi::String param_name) { - const PrimExpr& param_dim = param_sinfo->GetShape().value()[0]; - const PrimExpr& input_dim = data_sinfo->GetShape().value()[axis]; + auto check_param_size = [&](const TensorType& param_ty, const TensorType& data_ty, + ffi::String param_name) { + const PrimExpr& param_dim = param_ty->GetShape().value()[0]; + const PrimExpr& input_dim = data_ty->GetShape().value()[axis]; if (!ctx->GetAnalyzer()->CanProveEqual(param_dim, input_dim)) { TVM_FFI_VISIT_THROW(ValueError, call) << "Size mismatch: " << call->op << ": the input shape at dim " << attrs->axis << " is '" @@ -107,10 +106,10 @@ StructInfo InferStructInfoQuantize(const Call& call, const BlockBuilder& ctx) { } }; - auto is_scalar_or_singleton_vector = [&](const TensorStructInfo& param_sinfo) { - if (IsScalarTensor(param_sinfo)) return true; - if (param_sinfo->shape.defined() && param_sinfo->shape->IsInstance()) { - const auto& values = param_sinfo->shape.as()->values; + auto is_scalar_or_singleton_vector = [&](const TensorType& param_ty) { + if (IsScalarTensor(param_ty)) return true; + if (param_ty->shape.defined() && param_ty->shape->IsInstance()) { + const auto& values = param_ty->shape.as()->values; if (!values.empty()) { return std::all_of(values.begin(), values.end(), [&](const PrimExpr& dim) { return ctx->GetAnalyzer()->CanProveEqual(dim, 1); @@ -121,14 +120,12 @@ StructInfo InferStructInfoQuantize(const Call& call, const BlockBuilder& ctx) { }; // Check size matching of scale/zp params with input shape at dim = attrs->axis. - if (!is_scalar_or_singleton_vector(scale_sinfo)) - check_param_size(scale_sinfo, input_sinfo, "scale"); - if (!is_scalar_or_singleton_vector(zp_sinfo)) - check_param_size(zp_sinfo, input_sinfo, "zero_point"); - - auto output_sinfo = ffi::make_object(*input_sinfo.get()); - output_sinfo->dtype = attrs->out_dtype; - return TensorStructInfo(output_sinfo); + if (!is_scalar_or_singleton_vector(scale_ty)) check_param_size(scale_ty, input_ty, "scale"); + if (!is_scalar_or_singleton_vector(zp_ty)) check_param_size(zp_ty, input_ty, "zero_point"); + + auto output_ty = ffi::make_object(*input_ty.get()); + output_ty->dtype = attrs->out_dtype; + return TensorType(output_ty); } TVM_REGISTER_OP("relax.quantize") @@ -137,7 +134,7 @@ TVM_REGISTER_OP("relax.quantize") .add_argument("data", "Tensor", "The input tensor.") .add_argument("scale", "Tensor", "The quantization scale of the output tensor.") .add_argument("zero_point", "Tensor", "The quantization zero_point of the output tensor.") - .set_attr("FInferStructInfo", InferStructInfoQuantize) + .set_attr("FInferType", InferTypeQuantize) .set_attr("FPurity", true); /* relax.dequantize */ @@ -155,56 +152,55 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.dequantize", dequantize); } -StructInfo InferStructInfoDequantize(const Call& call, const BlockBuilder& ctx) { +Type InferTypeDequantize(const Call& call, const BlockBuilder& ctx) { const auto* attrs = call->attrs.as(); if (attrs->out_dtype != DataType::Float(16) && attrs->out_dtype != DataType::Float(32)) { TVM_FFI_VISIT_THROW(TypeError, call) << "Unsupported output datatype attribute for operation: " << attrs->out_dtype; } - TensorStructInfo input_sinfo = GetInputTensorStructInfo(call, ctx)[0]; - TensorStructInfo scale_sinfo = GetInputTensorStructInfo(call, ctx)[1]; - TensorStructInfo zp_sinfo = GetInputTensorStructInfo(call, ctx)[2]; + TensorType input_ty = GetInputTensorType(call, ctx)[0]; + TensorType scale_ty = GetInputTensorType(call, ctx)[1]; + TensorType zp_ty = GetInputTensorType(call, ctx)[2]; // Check input datatype: - if (input_sinfo->dtype != DataType::Int(8) && input_sinfo->dtype != DataType::UInt(8) && - input_sinfo->dtype != DataType::Int(16) && input_sinfo->dtype != DataType::UInt(16) && - input_sinfo->dtype != DataType::Int(32) && input_sinfo->dtype != DataType::Float8E4M3FN() && - input_sinfo->dtype != DataType::Float8E5M2() && input_sinfo->dtype != DataType::Float(16) && - input_sinfo->dtype != DataType::Float(32)) { + if (input_ty->dtype != DataType::Int(8) && input_ty->dtype != DataType::UInt(8) && + input_ty->dtype != DataType::Int(16) && input_ty->dtype != DataType::UInt(16) && + input_ty->dtype != DataType::Int(32) && input_ty->dtype != DataType::Float8E4M3FN() && + input_ty->dtype != DataType::Float8E5M2() && input_ty->dtype != DataType::Float(16) && + input_ty->dtype != DataType::Float(32)) { TVM_FFI_VISIT_THROW(TypeError, call) << "Unsupported input datatype for operation: " << attrs->out_dtype; } // Check datatype of scale param: - if (scale_sinfo->dtype != DataType::Float(32) && scale_sinfo->dtype != DataType::Float(16)) { + if (scale_ty->dtype != DataType::Float(32) && scale_ty->dtype != DataType::Float(16)) { TVM_FFI_VISIT_THROW(TypeError, call) - << "scale param datatype should be one of [float16, float32], but got " - << scale_sinfo->dtype; + << "scale param datatype should be one of [float16, float32], but got " << scale_ty->dtype; } // Check datatype of zero_point param: - if (zp_sinfo->dtype != DataType::Int(8) && zp_sinfo->dtype != DataType::UInt(8) && - zp_sinfo->dtype != DataType::Int(16) && zp_sinfo->dtype != DataType::UInt(16) && - zp_sinfo->dtype != DataType::Int(32) && zp_sinfo->dtype != DataType::UInt(32) && - zp_sinfo->dtype != DataType::Float(16)) { + if (zp_ty->dtype != DataType::Int(8) && zp_ty->dtype != DataType::UInt(8) && + zp_ty->dtype != DataType::Int(16) && zp_ty->dtype != DataType::UInt(16) && + zp_ty->dtype != DataType::Int(32) && zp_ty->dtype != DataType::UInt(32) && + zp_ty->dtype != DataType::Float(16)) { TVM_FFI_VISIT_THROW(TypeError, call) << "zero_point param datatype should be one of " << "['int8', 'uint8', 'int16', 'uint16', 'int32', 'uint32', 'float16'], " - << "but got " << zp_sinfo->dtype; + << "but got " << zp_ty->dtype; } // Check that "axis" attribute is not out of range: - int axis = (attrs->axis < 0) ? (input_sinfo->ndim + attrs->axis) : attrs->axis; - if (axis < 0 || axis > input_sinfo->ndim - 1) { + int axis = (attrs->axis < 0) ? (input_ty->ndim + attrs->axis) : attrs->axis; + if (axis < 0 || axis > input_ty->ndim - 1) { TVM_FFI_VISIT_THROW(ValueError, call) << "relax.dequantize: axis param is out of range (" << attrs->axis << ")"; } - auto check_param_size = [&](const TensorStructInfo& param_sinfo, - const TensorStructInfo& data_sinfo, ffi::String param_name) { - const PrimExpr& param_dim = param_sinfo->GetShape().value()[0]; - const PrimExpr& input_dim = data_sinfo->GetShape().value()[axis]; + auto check_param_size = [&](const TensorType& param_ty, const TensorType& data_ty, + ffi::String param_name) { + const PrimExpr& param_dim = param_ty->GetShape().value()[0]; + const PrimExpr& input_dim = data_ty->GetShape().value()[axis]; if (!ctx->GetAnalyzer()->CanProveEqual(param_dim, input_dim)) { TVM_FFI_VISIT_THROW(ValueError, call) << "Size mismatch: " << call->op << ": the input shape at dim " << attrs->axis << " is '" @@ -212,10 +208,10 @@ StructInfo InferStructInfoDequantize(const Call& call, const BlockBuilder& ctx) } }; - auto is_scalar_or_singleton_vector = [&](const TensorStructInfo& param_sinfo) { - if (IsScalarTensor(param_sinfo)) return true; - if (param_sinfo->shape.defined() && param_sinfo->shape->IsInstance()) { - const auto& values = param_sinfo->shape.as()->values; + auto is_scalar_or_singleton_vector = [&](const TensorType& param_ty) { + if (IsScalarTensor(param_ty)) return true; + if (param_ty->shape.defined() && param_ty->shape->IsInstance()) { + const auto& values = param_ty->shape.as()->values; if (!values.empty()) { return std::all_of(values.begin(), values.end(), [&](const PrimExpr& dim) { return ctx->GetAnalyzer()->CanProveEqual(dim, 1); @@ -226,14 +222,12 @@ StructInfo InferStructInfoDequantize(const Call& call, const BlockBuilder& ctx) }; // Check size matching of scale/zp params with input shape at dim = attrs->axis. - if (!is_scalar_or_singleton_vector(scale_sinfo)) - check_param_size(scale_sinfo, input_sinfo, "scale"); - if (!is_scalar_or_singleton_vector(zp_sinfo)) - check_param_size(zp_sinfo, input_sinfo, "zero_point"); - - auto output_sinfo = ffi::make_object(*input_sinfo.get()); - output_sinfo->dtype = attrs->out_dtype; - return TensorStructInfo(output_sinfo); + if (!is_scalar_or_singleton_vector(scale_ty)) check_param_size(scale_ty, input_ty, "scale"); + if (!is_scalar_or_singleton_vector(zp_ty)) check_param_size(zp_ty, input_ty, "zero_point"); + + auto output_ty = ffi::make_object(*input_ty.get()); + output_ty->dtype = attrs->out_dtype; + return TensorType(output_ty); } TVM_REGISTER_OP("relax.dequantize") @@ -242,7 +236,7 @@ TVM_REGISTER_OP("relax.dequantize") .add_argument("data", "Tensor", "The input tensor.") .add_argument("scale", "Tensor", "The quantization scale of the input tensor.") .add_argument("zero_point", "Tensor", "The quantization zero_point of the input tensor.") - .set_attr("FInferStructInfo", InferStructInfoDequantize) + .set_attr("FInferType", InferTypeDequantize) .set_attr("FPurity", true); } // namespace relax diff --git a/src/relax/op/tensor/sampling.cc b/src/relax/op/tensor/sampling.cc index 7996965024e1..27f9241e2c29 100644 --- a/src/relax/op/tensor/sampling.cc +++ b/src/relax/op/tensor/sampling.cc @@ -52,64 +52,64 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.multinomial_from_uniform", multinomial_from_uniform); } -StructInfo InferStructInfoMultinomialFromUniform(const Call& call, const BlockBuilder& ctx) { +Type InferTypeMultinomialFromUniform(const Call& call, const BlockBuilder& ctx) { CheckNumArguments(call, ctx); - TensorStructInfo prob_sinfo = GetInputTensorStructInfo(call, 0, ctx); - TensorStructInfo uniform_sample_sinfo = GetInputTensorStructInfo(call, 1, ctx); - TensorStructInfo sample_indices_sinfo = GetInputTensorStructInfo(call, 2, ctx); + TensorType prob_ty = GetInputTensorType(call, 0, ctx); + TensorType uniform_sample_ty = GetInputTensorType(call, 1, ctx); + TensorType sample_indices_ty = GetInputTensorType(call, 2, ctx); const auto* attrs = call->attrs.as(); - if (!prob_sinfo->dtype.is_float()) { + if (!prob_ty->dtype.is_float()) { TVM_FFI_VISIT_THROW(TypeError, call) << "Multinomial_from_uniform op requires the input prob to have float dtype. " "However, the given prob dtype is " - << prob_sinfo->dtype; + << prob_ty->dtype; } - if (!uniform_sample_sinfo->dtype.is_float()) { + if (!uniform_sample_ty->dtype.is_float()) { TVM_FFI_VISIT_THROW(TypeError, call) << "Multinomial_from_uniform op requires the input uniform_sample to have float " "dtype. However, the given uniform_sample dtype is " - << uniform_sample_sinfo->dtype; + << uniform_sample_ty->dtype; } - if (!sample_indices_sinfo->dtype.is_int()) { + if (!sample_indices_ty->dtype.is_int()) { TVM_FFI_VISIT_THROW(TypeError, call) << "Multinomial from uniform op requires the input sample_indices to have int " "dtype. However, the given sample_indices dtype is " - << sample_indices_sinfo->dtype; + << sample_indices_ty->dtype; } - if (prob_sinfo->IsUnknownNdim() || uniform_sample_sinfo->IsUnknownNdim() || - sample_indices_sinfo->IsUnknownNdim()) { - return TensorStructInfo(attrs->dtype, kUnknownNDim, prob_sinfo->vdevice); + if (prob_ty->IsUnknownNdim() || uniform_sample_ty->IsUnknownNdim() || + sample_indices_ty->IsUnknownNdim()) { + return TensorType(attrs->dtype, kUnknownNDim, prob_ty->vdevice); } - if (prob_sinfo->ndim != 2) { + if (prob_ty->ndim != 2) { TVM_FFI_VISIT_THROW(ValueError, call) << "Multinomial_from_uniform op requires the input prob to be a 2D tensor. " "However, the given prob tensor has ndim " - << prob_sinfo->ndim; + << prob_ty->ndim; } - if (uniform_sample_sinfo->ndim != 2) { + if (uniform_sample_ty->ndim != 2) { TVM_FFI_VISIT_THROW(ValueError, call) << "Multinomial_from_uniform op requires the input uniform_sample to be a 2D " "tensor. However, the given uniform_sample tensor has ndim " - << uniform_sample_sinfo->ndim; + << uniform_sample_ty->ndim; } - if (sample_indices_sinfo->ndim != 2) { + if (sample_indices_ty->ndim != 2) { TVM_FFI_VISIT_THROW(ValueError, call) << "Multinomial_from_uniform op requires the input sample_indices to be a 2D " "tensor. However, the given sample_indices tensor has ndim " - << sample_indices_sinfo->ndim; + << sample_indices_ty->ndim; } // Expected to be `(batch, vocab_size)` - const auto* prob_shape = prob_sinfo->shape.as(); + const auto* prob_shape = prob_ty->shape.as(); // Expected to be `(n, 1)` - const auto* uniform_sample_shape = uniform_sample_sinfo->shape.as(); + const auto* uniform_sample_shape = uniform_sample_ty->shape.as(); // Expected to be `(n, 1)` - const auto* sample_indices_shape = sample_indices_sinfo->shape.as(); + const auto* sample_indices_shape = sample_indices_ty->shape.as(); // The output shape is expected to be `(n, 1)` if (prob_shape == nullptr || uniform_sample_shape == nullptr || sample_indices_shape == nullptr) { - return TensorStructInfo(attrs->dtype, 2, prob_sinfo->vdevice); + return TensorType(attrs->dtype, 2, prob_ty->vdevice); } PrimExpr batch = prob_shape->values[0]; @@ -129,10 +129,10 @@ StructInfo InferStructInfoMultinomialFromUniform(const Call& call, const BlockBu << "Multinomial_from_uniform op requires the input uniform_sample and " "sample_indices to be 2D tensors with the second dimension being 1. " "However, the given uniform_sample tensor has shape " - << uniform_sample_sinfo->shape << " and the given sample_indices tensor has shape " - << sample_indices_sinfo->shape; + << uniform_sample_ty->shape << " and the given sample_indices tensor has shape " + << sample_indices_ty->shape; } - return TensorStructInfo(ShapeExpr({n, 1}), attrs->dtype, prob_sinfo->vdevice); + return TensorType(ShapeExpr({n, 1}), attrs->dtype, prob_ty->vdevice); } TVM_REGISTER_OP("relax.multinomial_from_uniform") @@ -141,7 +141,7 @@ TVM_REGISTER_OP("relax.multinomial_from_uniform") .add_argument("prob", "Tensor", "The probability tensor.") .add_argument("uniform_sample", "Tensor", "The uniform sample tensor.") .add_argument("sample_indices", "Tensor", "The sample indices tensor.") - .set_attr("FInferStructInfo", InferStructInfoMultinomialFromUniform) + .set_attr("FInferType", InferTypeMultinomialFromUniform) .set_attr("FPurity", true); } // namespace relax diff --git a/src/relax/op/tensor/search.cc b/src/relax/op/tensor/search.cc index 4c1bd5cbb74e..d80f484ebcf5 100644 --- a/src/relax/op/tensor/search.cc +++ b/src/relax/op/tensor/search.cc @@ -53,10 +53,10 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.bucketize", bucketize); } -StructInfo InferStructInfoBucketize(const Call& call, const BlockBuilder& ctx) { - ffi::Array input_sinfo = GetInputTensorStructInfo(call, ctx); - TensorStructInfo input_tensor_info = input_sinfo[0]; - TensorStructInfo boundaries_info = input_sinfo[1]; +Type InferTypeBucketize(const Call& call, const BlockBuilder& ctx) { + ffi::Array input_ty = GetInputTensorType(call, ctx); + TensorType input_tensor_info = input_ty[0]; + TensorType boundaries_info = input_ty[1]; if (!boundaries_info->IsUnknownNdim() && boundaries_info->ndim != 1) { TVM_FFI_VISIT_THROW(ValueError, call) @@ -72,9 +72,9 @@ StructInfo InferStructInfoBucketize(const Call& call, const BlockBuilder& ctx) { const auto* data_shape = input_tensor_info->shape.as(); if (data_shape) { - return TensorStructInfo(ShapeExpr(data_shape->values), out_dtype, input_tensor_info->vdevice); + return TensorType(ShapeExpr(data_shape->values), out_dtype, input_tensor_info->vdevice); } - return TensorStructInfo(out_dtype, input_tensor_info->ndim, input_tensor_info->vdevice); + return TensorType(out_dtype, input_tensor_info->ndim, input_tensor_info->vdevice); } TVM_REGISTER_OP("relax.bucketize") @@ -84,7 +84,7 @@ TVM_REGISTER_OP("relax.bucketize") .add_argument("boundaries", "Tensor", "1-D tensor, must contain a strictly increasing sequence, or the return value is " "undefined.") - .set_attr("FInferStructInfo", InferStructInfoBucketize) + .set_attr("FInferType", InferTypeBucketize) .set_attr("FPurity", true); /* relax.where */ @@ -98,20 +98,20 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.where", where); } -StructInfo InferStructInfoWhere(const Call& call, const BlockBuilder& ctx) { - ffi::Array input_sinfo = GetInputTensorStructInfo(call, ctx); - TensorStructInfo cond_sinfo = input_sinfo[0]; - TensorStructInfo x1_sinfo = input_sinfo[1]; - TensorStructInfo x2_sinfo = input_sinfo[2]; +Type InferTypeWhere(const Call& call, const BlockBuilder& ctx) { + ffi::Array input_ty = GetInputTensorType(call, ctx); + TensorType cond_ty = input_ty[0]; + TensorType x1_ty = input_ty[1]; + TensorType x2_ty = input_ty[2]; VDevice vdev = VDevice(); for (int i = 0; i < 3; ++i) { - if (input_sinfo[i]->vdevice.defined()) { + if (input_ty[i]->vdevice.defined()) { if (!vdev.defined()) { - vdev = input_sinfo[i]->vdevice.value(); - } else if (input_sinfo[i]->vdevice.value()->target.defined()) { + vdev = input_ty[i]->vdevice.value(); + } else if (input_ty[i]->vdevice.value()->target.defined()) { // mismatch - if (input_sinfo[i]->vdevice.value() != vdev) { + if (input_ty[i]->vdevice.value() != vdev) { vdev = VDevice(); break; } @@ -119,62 +119,62 @@ StructInfo InferStructInfoWhere(const Call& call, const BlockBuilder& ctx) { } } - if (!cond_sinfo->dtype.is_bool()) { + if (!cond_ty->dtype.is_bool()) { TVM_FFI_VISIT_THROW(TypeError, call) << "Where requires the input condition tensor to have boolean dtype. However, " "the given condition dtype is " - << cond_sinfo->dtype; + << cond_ty->dtype; } - DataType output_dtype = InferBinaryArithOpOutDtype(call, ctx, x1_sinfo, x2_sinfo); + DataType output_dtype = InferBinaryArithOpOutDtype(call, ctx, x1_ty, x2_ty); int output_ndim; - if (cond_sinfo->IsUnknownNdim() || x1_sinfo->IsUnknownNdim() || x2_sinfo->IsUnknownNdim()) { + if (cond_ty->IsUnknownNdim() || x1_ty->IsUnknownNdim() || x2_ty->IsUnknownNdim()) { output_ndim = kUnknownNDim; } else { - output_ndim = std::max(cond_sinfo->ndim, std::max(x1_sinfo->ndim, x2_sinfo->ndim)); + output_ndim = std::max(cond_ty->ndim, std::max(x1_ty->ndim, x2_ty->ndim)); } - const auto* cond_shape = cond_sinfo->shape.as(); - const auto* x1_shape = x1_sinfo->shape.as(); - const auto* x2_shape = x2_sinfo->shape.as(); + const auto* cond_shape = cond_ty->shape.as(); + const auto* x1_shape = x1_ty->shape.as(); + const auto* x2_shape = x2_ty->shape.as(); if (cond_shape && x1_shape && x2_shape) { // Step 1. Compute the broadcasted shape of x1's and x2's ffi::Optional> broadcasted_shape = InferBinaryBroadcastShape(call, ctx, x1_shape->values, x2_shape->values); if (!broadcasted_shape.defined()) { if (vdev.defined()) { - return TensorStructInfo(output_dtype, output_ndim, vdev); + return TensorType(output_dtype, output_ndim, vdev); } - return TensorStructInfo(output_dtype, output_ndim); + return TensorType(output_dtype, output_ndim); } // Step 2. Compute the broadcasted shape of cond's and the previous broadcasted shape. broadcasted_shape = InferBinaryBroadcastShape(call, ctx, cond_shape->values, broadcasted_shape.value()); if (!broadcasted_shape.defined()) { if (vdev.defined()) { - return TensorStructInfo(output_dtype, output_ndim, vdev); + return TensorType(output_dtype, output_ndim, vdev); } - return TensorStructInfo(output_dtype, output_ndim); + return TensorType(output_dtype, output_ndim); } TVM_FFI_ICHECK_EQ(static_cast(broadcasted_shape.value().size()), output_ndim); if (vdev.defined()) { - return TensorStructInfo(ShapeExpr(broadcasted_shape.value()), output_dtype, vdev); + return TensorType(ShapeExpr(broadcasted_shape.value()), output_dtype, vdev); } - return TensorStructInfo(ShapeExpr(broadcasted_shape.value()), output_dtype); - } else if (cond_sinfo->shape.defined() && // - x1_sinfo->shape.defined() && // - x2_sinfo->shape.defined() && // - cond_sinfo->shape.same_as(x1_sinfo->shape) && // - cond_sinfo->shape.same_as(x2_sinfo->shape)) { + return TensorType(ShapeExpr(broadcasted_shape.value()), output_dtype); + } else if (cond_ty->shape.defined() && // + x1_ty->shape.defined() && // + x2_ty->shape.defined() && // + cond_ty->shape.same_as(x1_ty->shape) && // + cond_ty->shape.same_as(x2_ty->shape)) { if (vdev.defined()) { - return TensorStructInfo(cond_sinfo->shape.value(), output_dtype, vdev); + return TensorType(cond_ty->shape.value(), output_dtype, vdev); } - return TensorStructInfo(cond_sinfo->shape.value(), output_dtype); + return TensorType(cond_ty->shape.value(), output_dtype); } else { if (vdev.defined()) { - return TensorStructInfo(output_dtype, output_ndim, vdev); + return TensorType(output_dtype, output_ndim, vdev); } - return TensorStructInfo(output_dtype, output_ndim); + return TensorType(output_dtype, output_ndim); } } @@ -183,29 +183,29 @@ TVM_REGISTER_OP("relax.where") .add_argument("condition", "Tensor", "When True, yield `x1`; otherwise, yield `x2`.") .add_argument("x1", "Tensor", "The first input tensor.") .add_argument("x2", "Tensor", "The second input tensor.") - .set_attr("FInferStructInfo", InferStructInfoWhere) + .set_attr("FInferType", InferTypeWhere) .set_attr("FPurity", true); /* relax.argmax & relax.argmin */ -StructInfo InferStructInfoArgmaxArgmin(const Call& call, const BlockBuilder& ctx) { - TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); +Type InferTypeArgmaxArgmin(const Call& call, const BlockBuilder& ctx) { + TensorType data_ty = GetUnaryInputTensorType(call, ctx); const auto* attrs = call->attrs.as(); int axis = -1; - if (!data_sinfo->IsUnknownNdim() && attrs->axis.has_value()) { - axis = NormalizeAxis(call, ctx, data_sinfo->ndim, attrs->axis.value()); + if (!data_ty->IsUnknownNdim() && attrs->axis.has_value()) { + axis = NormalizeAxis(call, ctx, data_ty->ndim, attrs->axis.value()); } int out_ndim; if (attrs->keepdims) { - out_ndim = data_sinfo->ndim; + out_ndim = data_ty->ndim; } else if (!attrs->axis.has_value()) { out_ndim = 0; - } else if (data_sinfo->IsUnknownNdim()) { + } else if (data_ty->IsUnknownNdim()) { out_ndim = kUnknownNDim; } else { - out_ndim = data_sinfo->ndim - 1; + out_ndim = data_ty->ndim - 1; TVM_FFI_ICHECK_GE(out_ndim, 0); } @@ -217,26 +217,25 @@ StructInfo InferStructInfoArgmaxArgmin(const Call& call, const BlockBuilder& ctx // - axes is not None, keepdims is false -> the returned shape does not contain the input axes. // - axes is not None, keepdims is true -> the returned shape has value 1 at the positions of the // input axes - const auto* data_shape = data_sinfo->shape.as(); + const auto* data_shape = data_ty->shape.as(); if (data_shape == nullptr) { if (!attrs->axis.has_value() && attrs->keepdims && out_ndim != kUnknownNDim) { - return TensorStructInfo( - ShapeExpr(ffi::Array(out_ndim, IntImm(out_dtype, /*value=*/1))), out_dtype, - data_sinfo->vdevice); + return TensorType(ShapeExpr(ffi::Array(out_ndim, IntImm(out_dtype, /*value=*/1))), + out_dtype, data_ty->vdevice); } else { - return out_ndim == 0 ? TensorStructInfo(ShapeExpr(ffi::Array()), out_dtype, - data_sinfo->vdevice) - : TensorStructInfo(out_dtype, out_ndim, data_sinfo->vdevice); + return out_ndim == 0 + ? TensorType(ShapeExpr(ffi::Array()), out_dtype, data_ty->vdevice) + : TensorType(out_dtype, out_ndim, data_ty->vdevice); } } - if (data_sinfo->ndim > 0) { + if (data_ty->ndim > 0) { out_dtype = data_shape->values[0]->dtype; } ffi::Array out_shape; out_shape.reserve(out_ndim); - for (int i = 0; i < data_sinfo->ndim; ++i) { + for (int i = 0; i < data_ty->ndim; ++i) { if (attrs->axis.has_value() && i != axis) { out_shape.push_back(data_shape->values[i]); } else if (attrs->keepdims) { @@ -244,7 +243,7 @@ StructInfo InferStructInfoArgmaxArgmin(const Call& call, const BlockBuilder& ctx } } TVM_FFI_ICHECK_EQ(static_cast(out_shape.size()), out_ndim); - return TensorStructInfo(ShapeExpr(out_shape), out_dtype, data_sinfo->vdevice); + return TensorType(ShapeExpr(out_shape), out_dtype, data_ty->vdevice); } #define RELAX_REGISTER_ARGMAX_ARGMIN_OP(OpName) \ @@ -261,7 +260,7 @@ StructInfo InferStructInfoArgmaxArgmin(const Call& call, const BlockBuilder& ctx TVM_REGISTER_OP("relax." #OpName) \ .set_num_inputs(1) \ .add_argument("x", "Tensor", "The input data tensor") \ - .set_attr("FInferStructInfo", InferStructInfoArgmaxArgmin) \ + .set_attr("FInferType", InferTypeArgmaxArgmin) \ .set_attr("FPurity", true); RELAX_REGISTER_ARGMAX_ARGMIN_OP(argmax); diff --git a/src/relax/op/tensor/set.cc b/src/relax/op/tensor/set.cc index edf2a385b429..5c6bfb807fa4 100644 --- a/src/relax/op/tensor/set.cc +++ b/src/relax/op/tensor/set.cc @@ -54,18 +54,18 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.unique", unique); } -StructInfo InferStructInfoUnique(const Call& call, const BlockBuilder& ctx) { - TensorStructInfo data_sinfo = Downcast(call->args[0]->struct_info_); +Type InferTypeUnique(const Call& call, const BlockBuilder& ctx) { + TensorType data_ty = Downcast(call->args[0]->ty); PrimValue axis, return_index, return_inverse, return_counts; if (call->args.size() == 6) { if (auto* prim_value_node = call->args[5].as()) { axis = ffi::GetRef(prim_value_node); } } - if (!data_sinfo->IsUnknownNdim() && axis.defined()) { + if (!data_ty->IsUnknownNdim() && axis.defined()) { // Normalize the axis for sanity check purpose. if (const auto* axis_int = axis->value.as()) { - NormalizeAxis(call, ctx, data_sinfo->ndim, axis_int->value); + NormalizeAxis(call, ctx, data_ty->ndim, axis_int->value); } } TVM_FFI_ICHECK(call->args[2]->IsInstance()); @@ -88,61 +88,60 @@ StructInfo InferStructInfoUnique(const Call& call, const BlockBuilder& ctx) { f_convert_to_int64(return_inverse->value) + f_convert_to_int64(return_counts->value); - std::vector output_sinfo; - output_sinfo.reserve(1 + n_int_return); + std::vector output_ty; + output_ty.reserve(1 + n_int_return); // unique values - if (data_sinfo->ndim == 0) { - output_sinfo.push_back(TensorStructInfo(ShapeExpr({IntImm::Int64(/*value=*/1)}), - data_sinfo->dtype, data_sinfo->vdevice)); + if (data_ty->ndim == 0) { + output_ty.push_back( + TensorType(ShapeExpr({IntImm::Int64(/*value=*/1)}), data_ty->dtype, data_ty->vdevice)); } else if (axis.defined()) { - output_sinfo.push_back( - TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim, data_sinfo->vdevice)); + output_ty.push_back(TensorType(data_ty->dtype, data_ty->ndim, data_ty->vdevice)); } else { - output_sinfo.push_back(TensorStructInfo(data_sinfo->dtype, /*ndim=*/1, data_sinfo->vdevice)); + output_ty.push_back(TensorType(data_ty->dtype, /*ndim=*/1, data_ty->vdevice)); } // index, inverse_indices, and counts // index: always 1D if (f_convert_to_int64(return_index->value)) { - TensorStructInfo index_sinfo{nullptr}; - if (data_sinfo->ndim == 0) { - index_sinfo = TensorStructInfo(ShapeExpr({IntImm::Int64(/*value=*/1)}), DataType::Int(64), - data_sinfo->vdevice); + TensorType index_ty{nullptr}; + if (data_ty->ndim == 0) { + index_ty = + TensorType(ShapeExpr({IntImm::Int64(/*value=*/1)}), DataType::Int(64), data_ty->vdevice); } else { - index_sinfo = TensorStructInfo(DataType::Int(64), /*ndim=*/1, data_sinfo->vdevice); + index_ty = TensorType(DataType::Int(64), /*ndim=*/1, data_ty->vdevice); } - output_sinfo.push_back(index_sinfo); + output_ty.push_back(index_ty); } // inverse_indices: always 1D per ONNX spec if (f_convert_to_int64(return_inverse->value)) { - TensorStructInfo inverse_sinfo{nullptr}; - if (data_sinfo->ndim == 0) { - inverse_sinfo = TensorStructInfo(ShapeExpr({IntImm::Int64(/*value=*/1)}), DataType::Int(64), - data_sinfo->vdevice); + TensorType inverse_ty{nullptr}; + if (data_ty->ndim == 0) { + inverse_ty = + TensorType(ShapeExpr({IntImm::Int64(/*value=*/1)}), DataType::Int(64), data_ty->vdevice); } else { - inverse_sinfo = TensorStructInfo(DataType::Int(64), /*ndim=*/1, data_sinfo->vdevice); + inverse_ty = TensorType(DataType::Int(64), /*ndim=*/1, data_ty->vdevice); } - output_sinfo.push_back(inverse_sinfo); + output_ty.push_back(inverse_ty); } // counts: always 1D if (f_convert_to_int64(return_counts->value)) { - TensorStructInfo counts_sinfo{nullptr}; - if (data_sinfo->ndim == 0) { - counts_sinfo = TensorStructInfo(ShapeExpr({IntImm::Int64(/*value=*/1)}), DataType::Int(64), - data_sinfo->vdevice); + TensorType counts_ty{nullptr}; + if (data_ty->ndim == 0) { + counts_ty = + TensorType(ShapeExpr({IntImm::Int64(/*value=*/1)}), DataType::Int(64), data_ty->vdevice); } else { - counts_sinfo = TensorStructInfo(DataType::Int(64), /*ndim=*/1, data_sinfo->vdevice); + counts_ty = TensorType(DataType::Int(64), /*ndim=*/1, data_ty->vdevice); } - output_sinfo.push_back(counts_sinfo); + output_ty.push_back(counts_ty); } - if (output_sinfo.size() == 1) { - return output_sinfo[0]; + if (output_ty.size() == 1) { + return output_ty[0]; } else { - return TupleStructInfo(output_sinfo); + return TupleType(output_ty); } } @@ -165,7 +164,7 @@ TVM_REGISTER_OP("relax.unique") "The dimension to apply unique. If it is std::nullopt, the unique values of the " "flattened input " "are returned.") - .set_attr("FInferStructInfo", InferStructInfoUnique) + .set_attr("FInferType", InferTypeUnique) .set_attr("FCallPacked", "relax.run.unique") .set_attr("FPurity", true); @@ -180,15 +179,15 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.nonzero", nonzero); } -StructInfo InferStructInfoNonzero(const Call& call, const BlockBuilder& ctx) { - TensorStructInfo data_sinfo = GetInputTensorStructInfo(call, 0, ctx); - return TensorStructInfo(DataType::Int(64), 2, data_sinfo->vdevice); +Type InferTypeNonzero(const Call& call, const BlockBuilder& ctx) { + TensorType data_ty = GetInputTensorType(call, 0, ctx); + return TensorType(DataType::Int(64), 2, data_ty->vdevice); } TVM_REGISTER_OP("relax.nonzero") .set_num_inputs(1) .add_argument("x", "Tensor", "The input tensor") - .set_attr("FInferStructInfo", InferStructInfoNonzero) + .set_attr("FInferType", InferTypeNonzero) .set_attr("FCallPacked", "relax.run.nonzero") .set_attr("FPurity", true); diff --git a/src/relax/op/tensor/sorting.cc b/src/relax/op/tensor/sorting.cc index 7b8a310c65d9..2d014cded4ec 100644 --- a/src/relax/op/tensor/sorting.cc +++ b/src/relax/op/tensor/sorting.cc @@ -53,15 +53,15 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.sort", sort); } -StructInfo InferStructInfoSort(const Call& call, const BlockBuilder& ctx) { - return GetUnaryInputTensorStructInfo(call, ctx); +Type InferTypeSort(const Call& call, const BlockBuilder& ctx) { + return GetUnaryInputTensorType(call, ctx); } TVM_REGISTER_OP("relax.sort") .set_attrs_type() .set_num_inputs(1) .add_argument("data", "Tensor", "The input tensor.") - .set_attr("FInferStructInfo", InferStructInfoSort) + .set_attr("FInferType", InferTypeSort) .set_attr("FPurity", true); /* relax.argsort */ @@ -81,21 +81,21 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.argsort", argsort); } -StructInfo InferStructInfoArgsort(const Call& call, const BlockBuilder& ctx) { - TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); +Type InferTypeArgsort(const Call& call, const BlockBuilder& ctx) { + TensorType data_ty = GetUnaryInputTensorType(call, ctx); const auto* attrs = call->attrs.as(); - DataType out_type = attrs->dtype.is_void() ? data_sinfo->dtype : attrs->dtype; - if (data_sinfo->shape.defined()) { - return TensorStructInfo(data_sinfo->shape.value(), out_type, data_sinfo->vdevice); + DataType out_type = attrs->dtype.is_void() ? data_ty->dtype : attrs->dtype; + if (data_ty->shape.defined()) { + return TensorType(data_ty->shape.value(), out_type, data_ty->vdevice); } - return TensorStructInfo(out_type, data_sinfo->ndim, data_sinfo->vdevice); + return TensorType(out_type, data_ty->ndim, data_ty->vdevice); } TVM_REGISTER_OP("relax.argsort") .set_attrs_type() .set_num_inputs(1) .add_argument("data", "Tensor", "The input tensor.") - .set_attr("FInferStructInfo", InferStructInfoArgsort) + .set_attr("FInferType", InferTypeArgsort) .set_attr("FPurity", true); /* relax.topk */ @@ -117,12 +117,12 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.topk", topk); } -StructInfo InferStructInfoTopK(const Call& call, const BlockBuilder& ctx) { - TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); - const auto* data_shape = data_sinfo->shape.as(); +Type InferTypeTopK(const Call& call, const BlockBuilder& ctx) { + TensorType data_ty = GetUnaryInputTensorType(call, ctx); + const auto* data_shape = data_ty->shape.as(); const auto* attrs = call->attrs.as(); - DataType indices_type = attrs->dtype.is_void() ? data_sinfo->dtype : attrs->dtype; - int ndim = data_sinfo->ndim; + DataType indices_type = attrs->dtype.is_void() ? data_ty->dtype : attrs->dtype; + int ndim = data_ty->ndim; int k = attrs->k; ffi::String ret_type = attrs->ret_type; int axis = attrs->axis; @@ -130,30 +130,27 @@ StructInfo InferStructInfoTopK(const Call& call, const BlockBuilder& ctx) { axis += ndim; } - std::vector output_sinfos; - output_sinfos.reserve(2); + std::vector output_tys; + output_tys.reserve(2); if (data_shape == nullptr) { - output_sinfos.push_back( - TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim, data_sinfo->vdevice)); - output_sinfos.push_back(TensorStructInfo(indices_type, data_sinfo->ndim, data_sinfo->vdevice)); + output_tys.push_back(TensorType(data_ty->dtype, data_ty->ndim, data_ty->vdevice)); + output_tys.push_back(TensorType(indices_type, data_ty->ndim, data_ty->vdevice)); } else { ffi::Array out_shape = data_shape->values; const auto* int_dim = out_shape[axis].as(); if (k > 0 && (int_dim == nullptr || k < int_dim->value)) { out_shape.Set(axis, k); } - output_sinfos.push_back( - TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype, data_sinfo->vdevice)); - output_sinfos.push_back( - TensorStructInfo(ShapeExpr(out_shape), indices_type, data_sinfo->vdevice)); + output_tys.push_back(TensorType(ShapeExpr(out_shape), data_ty->dtype, data_ty->vdevice)); + output_tys.push_back(TensorType(ShapeExpr(out_shape), indices_type, data_ty->vdevice)); } if (ret_type == "both") { - return TupleStructInfo(output_sinfos); + return TupleType(output_tys); } else if (ret_type == "values") { - return output_sinfos[0]; + return output_tys[0]; } else if (ret_type == "indices") { - return output_sinfos[1]; + return output_tys[1]; } TVM_FFI_THROW(InternalError) << "Unsupported ret type: " << ret_type; TVM_FFI_UNREACHABLE(); @@ -163,7 +160,7 @@ TVM_REGISTER_OP("relax.topk") .set_attrs_type() .set_num_inputs(1) .add_argument("data", "Tensor", "The input tensor.") - .set_attr("FInferStructInfo", InferStructInfoTopK) + .set_attr("FInferType", InferTypeTopK) .set_attr("FPurity", true); } // namespace relax diff --git a/src/relax/op/tensor/statistical.cc b/src/relax/op/tensor/statistical.cc index 1da75b71309b..9fe68afe2901 100644 --- a/src/relax/op/tensor/statistical.cc +++ b/src/relax/op/tensor/statistical.cc @@ -37,24 +37,24 @@ TVM_FFI_STATIC_INIT_BLOCK() { ScanopAttrs::RegisterReflection(); } -StructInfo InferStructInfoStatistical(const Call& call, const BlockBuilder& ctx) { - TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); +Type InferTypeStatistical(const Call& call, const BlockBuilder& ctx) { + TensorType data_ty = GetUnaryInputTensorType(call, ctx); const auto* attrs = call->attrs.as(); std::vector axes; - if (!data_sinfo->IsUnknownNdim() && attrs->axis.defined()) { - axes = NormalizeAxes(call, ctx, data_sinfo->ndim, attrs->axis.value()); + if (!data_ty->IsUnknownNdim() && attrs->axis.defined()) { + axes = NormalizeAxes(call, ctx, data_ty->ndim, attrs->axis.value()); } int out_ndim; if (attrs->keepdims) { - out_ndim = data_sinfo->ndim; + out_ndim = data_ty->ndim; } else if (!attrs->axis.defined()) { out_ndim = 0; - } else if (data_sinfo->IsUnknownNdim()) { + } else if (data_ty->IsUnknownNdim()) { out_ndim = kUnknownNDim; } else { - out_ndim = data_sinfo->ndim - axes.size(); + out_ndim = data_ty->ndim - axes.size(); TVM_FFI_ICHECK_GE(out_ndim, 0); } @@ -65,21 +65,21 @@ StructInfo InferStructInfoStatistical(const Call& call, const BlockBuilder& ctx) // - axes is not None, keepdims is false -> the returned shape does not contain the input axes. // - axes is not None, keepdims is true -> the returned shape has value 1 at the positions of the // input axes - const auto* data_shape = data_sinfo->shape.as(); + const auto* data_shape = data_ty->shape.as(); if (data_shape == nullptr) { if (!attrs->axis.defined() && attrs->keepdims && out_ndim != kUnknownNDim) { - return TensorStructInfo(ShapeExpr(ffi::Array(out_ndim, IntImm::Int64(/*value=*/1))), - data_sinfo->dtype, data_sinfo->vdevice); + return TensorType(ShapeExpr(ffi::Array(out_ndim, IntImm::Int64(/*value=*/1))), + data_ty->dtype, data_ty->vdevice); } else { - return out_ndim == 0 ? TensorStructInfo(ShapeExpr(ffi::Array()), data_sinfo->dtype, - data_sinfo->vdevice) - : TensorStructInfo(data_sinfo->dtype, out_ndim, data_sinfo->vdevice); + return out_ndim == 0 + ? TensorType(ShapeExpr(ffi::Array()), data_ty->dtype, data_ty->vdevice) + : TensorType(data_ty->dtype, out_ndim, data_ty->vdevice); } } ffi::Array out_shape; out_shape.reserve(out_ndim); - for (int i = 0; i < data_sinfo->ndim; ++i) { + for (int i = 0; i < data_ty->ndim; ++i) { if (attrs->axis.defined() && std::find(axes.begin(), axes.end(), i) == axes.end()) { out_shape.push_back(data_shape->values[i]); } else if (attrs->keepdims) { @@ -87,7 +87,7 @@ StructInfo InferStructInfoStatistical(const Call& call, const BlockBuilder& ctx) } } TVM_FFI_ICHECK_EQ(static_cast(out_shape.size()), out_ndim); - return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype, data_sinfo->vdevice); + return TensorType(ShapeExpr(out_shape), data_ty->dtype, data_ty->vdevice); } InferLayoutOutput InferLayoutStatistical( @@ -97,10 +97,10 @@ InferLayoutOutput InferLayoutStatistical( const auto* attrs = call->attrs.as(); TVM_FFI_ICHECK(attrs != nullptr) << "Invalid Call"; - const auto* tensor_sinfo = GetStructInfoAs(call->args[0]); - TVM_FFI_ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; - TVM_FFI_ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support known ndim"; - int ndim = tensor_sinfo->ndim; + const auto* tensor_ty = GetTypeAs(call->args[0]); + TVM_FFI_ICHECK(tensor_ty != nullptr) << "Invalid Call"; + TVM_FFI_ICHECK(!tensor_ty->IsUnknownNdim()) << "Only support known ndim"; + int ndim = tensor_ty->ndim; ffi::Array axis; if (attrs->axis.defined()) { @@ -151,52 +151,51 @@ InferLayoutOutput InferLayoutStatistical( Attrs(new_attrs)); } -StructInfo InferStructInfoScan(const Call& call, const BlockBuilder& ctx) { - TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); +Type InferTypeScan(const Call& call, const BlockBuilder& ctx) { + TensorType data_ty = GetUnaryInputTensorType(call, ctx); const auto* attrs = call->attrs.as(); - DataType out_type = attrs->dtype.is_void() ? data_sinfo->dtype : attrs->dtype; + DataType out_type = attrs->dtype.is_void() ? data_ty->dtype : attrs->dtype; if (!attrs->axis.has_value()) { // flattened - const auto* data_shape = data_sinfo->shape.as(); + const auto* data_shape = data_ty->shape.as(); if (data_shape == nullptr) { - return TensorStructInfo(out_type, data_sinfo->ndim, data_sinfo->vdevice); + return TensorType(out_type, data_ty->ndim, data_ty->vdevice); } else { PrimExpr flattened_d = 1; for (const auto v : data_shape->values) { flattened_d *= v; } - return TensorStructInfo(ShapeExpr(ffi::Array({flattened_d})), out_type, - data_sinfo->vdevice); + return TensorType(ShapeExpr(ffi::Array({flattened_d})), out_type, data_ty->vdevice); } } - if (data_sinfo->shape.defined()) { - return TensorStructInfo(data_sinfo->shape.value(), out_type, data_sinfo->vdevice); + if (data_ty->shape.defined()) { + return TensorType(data_ty->shape.value(), out_type, data_ty->vdevice); } else { - return TensorStructInfo(out_type, data_sinfo->ndim, data_sinfo->vdevice); + return TensorType(out_type, data_ty->ndim, data_ty->vdevice); } } -StructInfo InferStructInfoStatisticalExtension(const Call& call, const BlockBuilder& ctx) { - TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); +Type InferTypeStatisticalExtension(const Call& call, const BlockBuilder& ctx) { + TensorType data_ty = GetUnaryInputTensorType(call, ctx); const auto* attrs = call->attrs.as(); std::vector axes; - if (!data_sinfo->IsUnknownNdim() && attrs->axis.defined()) { - axes = NormalizeAxes(call, ctx, data_sinfo->ndim, attrs->axis.value()); + if (!data_ty->IsUnknownNdim() && attrs->axis.defined()) { + axes = NormalizeAxes(call, ctx, data_ty->ndim, attrs->axis.value()); } int out_ndim; if (attrs->keepdims) { - out_ndim = data_sinfo->ndim; + out_ndim = data_ty->ndim; } else if (!attrs->axis.defined()) { out_ndim = 0; - } else if (data_sinfo->IsUnknownNdim()) { + } else if (data_ty->IsUnknownNdim()) { out_ndim = kUnknownNDim; } else { - out_ndim = data_sinfo->ndim - axes.size(); + out_ndim = data_ty->ndim - axes.size(); TVM_FFI_ICHECK_GE(out_ndim, 0); } @@ -207,23 +206,22 @@ StructInfo InferStructInfoStatisticalExtension(const Call& call, const BlockBuil // - len(axes) == 1, keepdims is false -> the returned shape does not contain the input axis. // - len(axes) == 1, keepdims is true -> the returned shape has value 1 at the positions of the // input axis - const auto* data_shape = data_sinfo->shape.as(); + const auto* data_shape = data_ty->shape.as(); if (data_shape == nullptr) { if (!attrs->axis.defined() && attrs->keepdims && out_ndim != kUnknownNDim) { - return TensorStructInfo(ShapeExpr(ffi::Array(out_ndim, IntImm::Int64(/*value=*/1))), - data_sinfo->dtype, data_sinfo->vdevice); + return TensorType(ShapeExpr(ffi::Array(out_ndim, IntImm::Int64(/*value=*/1))), + data_ty->dtype, data_ty->vdevice); } if (out_ndim == 0) { - return TensorStructInfo(ShapeExpr(ffi::Array()), data_sinfo->dtype, - data_sinfo->vdevice); + return TensorType(ShapeExpr(ffi::Array()), data_ty->dtype, data_ty->vdevice); } - return TupleStructInfo({TensorStructInfo(data_sinfo->dtype, out_ndim, data_sinfo->vdevice), - TensorStructInfo(DataType::Int(64), out_ndim, data_sinfo->vdevice)}); + return TupleType({TensorType(data_ty->dtype, out_ndim, data_ty->vdevice), + TensorType(DataType::Int(64), out_ndim, data_ty->vdevice)}); } ffi::Array out_shape; out_shape.reserve(out_ndim); - for (int i = 0; i < data_sinfo->ndim; ++i) { + for (int i = 0; i < data_ty->ndim; ++i) { if (attrs->axis.defined() && std::find(axes.begin(), axes.end(), i) == axes.end()) { out_shape.push_back(data_shape->values[i]); } else if (attrs->keepdims) { @@ -233,11 +231,10 @@ StructInfo InferStructInfoStatisticalExtension(const Call& call, const BlockBuil TVM_FFI_ICHECK_EQ(static_cast(out_shape.size()), out_ndim); if (!attrs->axis.defined() || axes.size() > 1) - return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype, data_sinfo->vdevice); + return TensorType(ShapeExpr(out_shape), data_ty->dtype, data_ty->vdevice); else - return TupleStructInfo( - {TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype, data_sinfo->vdevice), - TensorStructInfo(ShapeExpr(out_shape), DataType::Int(64), data_sinfo->vdevice)}); + return TupleType({TensorType(ShapeExpr(out_shape), data_ty->dtype, data_ty->vdevice), + TensorType(ShapeExpr(out_shape), DataType::Int(64), data_ty->vdevice)}); } /* relax.cumprod */ @@ -261,7 +258,7 @@ TVM_REGISTER_OP("relax.cumprod") .set_attrs_type() .set_num_inputs(1) .add_argument("data", "Tensor", "The input tensor.") - .set_attr("FInferStructInfo", InferStructInfoScan) + .set_attr("FInferType", InferTypeScan) .set_attr("FPurity", true); /* relax.cumsum */ @@ -284,7 +281,7 @@ TVM_REGISTER_OP("relax.cumsum") .set_attrs_type() .set_num_inputs(1) .add_argument("data", "Tensor", "The input tensor.") - .set_attr("FInferStructInfo", InferStructInfoScan) + .set_attr("FInferType", InferTypeScan) .set_attr("FPurity", true); /* relax.median */ @@ -304,7 +301,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { TVM_REGISTER_OP("relax.median") .set_num_inputs(1) .add_argument("data", "Tensor", "The input tensor.") - .set_attr("FInferStructInfo", InferStructInfoStatisticalExtension) + .set_attr("FInferType", InferTypeStatisticalExtension) .set_attr("FPurity", true); RELAX_REGISTER_STATISTICAL_OP_INTERFACE(max); diff --git a/src/relax/op/tensor/statistical.h b/src/relax/op/tensor/statistical.h index ee4138f133b1..2d80790926ed 100644 --- a/src/relax/op/tensor/statistical.h +++ b/src/relax/op/tensor/statistical.h @@ -56,7 +56,7 @@ namespace relax { TVM_REGISTER_OP("relax." #OpName) \ .set_num_inputs(1) \ .add_argument("x", "Tensor", "The input data tensor") \ - .set_attr("FInferStructInfo", InferStructInfoStatistical) \ + .set_attr("FInferType", InferTypeStatistical) \ .set_attr("FRelaxInferLayout", InferLayoutStatistical) \ .set_attr("FPurity", true) diff --git a/src/relax/op/tensor/ternary.cc b/src/relax/op/tensor/ternary.cc index f6eeb25f8ff5..6daacfe16578 100644 --- a/src/relax/op/tensor/ternary.cc +++ b/src/relax/op/tensor/ternary.cc @@ -30,11 +30,11 @@ namespace tvm { namespace relax { -StructInfo InferStructInfoEwiseFMA(const Call& call, const BlockBuilder& ctx) { - ffi::Array input_sinfo = GetInputTensorStructInfo(call, ctx); - TensorStructInfo t1 = input_sinfo[0]; - TensorStructInfo t2 = input_sinfo[1]; - TensorStructInfo t3 = input_sinfo[2]; +Type InferTypeEwiseFMA(const Call& call, const BlockBuilder& ctx) { + ffi::Array input_ty = GetInputTensorType(call, ctx); + TensorType t1 = input_ty[0]; + TensorType t2 = input_ty[1]; + TensorType t3 = input_ty[2]; int ndim = kUnknownNDim; if (!t1->IsUnknownNdim()) { @@ -69,12 +69,12 @@ StructInfo InferStructInfoEwiseFMA(const Call& call, const BlockBuilder& ctx) { VDevice vdev = VDevice(); for (int i = 0; i < 3; ++i) { - if (input_sinfo[i]->vdevice.defined()) { + if (input_ty[i]->vdevice.defined()) { if (!vdev.defined()) { - vdev = input_sinfo[i]->vdevice.value(); - } else if (input_sinfo[i]->vdevice.value()->target.defined()) { + vdev = input_ty[i]->vdevice.value(); + } else if (input_ty[i]->vdevice.value()->target.defined()) { // mismatch - if (input_sinfo[i]->vdevice.value() != vdev) { + if (input_ty[i]->vdevice.value() != vdev) { vdev = VDevice(); break; } @@ -100,19 +100,19 @@ StructInfo InferStructInfoEwiseFMA(const Call& call, const BlockBuilder& ctx) { } } if (vdev.defined()) { - return TensorStructInfo(ShapeExpr(output_shape), output_dtype, vdev); + return TensorType(ShapeExpr(output_shape), output_dtype, vdev); } - return TensorStructInfo(ShapeExpr(output_shape), output_dtype); + return TensorType(ShapeExpr(output_shape), output_dtype); } else if (t1->shape.defined() && t1->shape.same_as(t2->shape) && t1->shape.same_as(t3->shape)) { if (vdev.defined()) { - return TensorStructInfo(t1->shape.value(), output_dtype, vdev); + return TensorType(t1->shape.value(), output_dtype, vdev); } - return TensorStructInfo(t1->shape.value(), output_dtype); + return TensorType(t1->shape.value(), output_dtype); } if (vdev.defined()) { - return TensorStructInfo(output_dtype, ndim, vdev); + return TensorType(output_dtype, ndim, vdev); } - return TensorStructInfo(output_dtype, ndim); + return TensorType(output_dtype, ndim); } InferLayoutOutput InferLayoutEwiseFMA( @@ -135,7 +135,7 @@ TVM_REGISTER_OP("relax.ewise_fma") .add_argument("x1", "Tensor", "The left hand operand of the multiplication") .add_argument("x2", "Tensor", "The right hand operand of the multiplication") .add_argument("x3", "Tensor", "The operand of the addition") - .set_attr("FInferStructInfo", InferStructInfoEwiseFMA) + .set_attr("FInferType", InferTypeEwiseFMA) .set_attr("FRelaxInferLayout", InferLayoutEwiseFMA) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", true); diff --git a/src/relax/op/tensor/unary.cc b/src/relax/op/tensor/unary.cc index 16a0bc305f17..598ec78aacda 100644 --- a/src/relax/op/tensor/unary.cc +++ b/src/relax/op/tensor/unary.cc @@ -31,9 +31,9 @@ namespace tvm { namespace relax { -StructInfo InferStructInfoUnaryCheck(const Call& call, const BlockBuilder& ctx) { - return InferStructInfoUnary( - call, ctx, [](const TensorStructInfo& input_sinfo) { return DataType::Bool(); }); +Type InferTypeUnaryCheck(const Call& call, const BlockBuilder& ctx) { + return InferTypeUnary(call, ctx, + [](const TensorType& input_ty) { return DataType::Bool(); }); } /***************** Arithmetic operators *****************/ @@ -73,7 +73,7 @@ TVM_REGISTER_OP("relax.clip") .add_argument("x", "Tensor", "The input tensor.") .add_argument("min", "PrimValue", "The lower-bound of the range to be clipped to") .add_argument("max", "PrimValue", "The upper-bound of the range to be clipped to") - .set_attr("FInferStructInfo", ReturnStructInfoFromArg<0>) + .set_attr("FInferType", ReturnTypeFromArg<0>) .set_attr("FPurity", true); Expr clip(Expr x, Expr min, Expr max) { diff --git a/src/relax/op/tensor/unary.h b/src/relax/op/tensor/unary.h index 1847ba3c365a..4ae1388ba1cc 100644 --- a/src/relax/op/tensor/unary.h +++ b/src/relax/op/tensor/unary.h @@ -42,12 +42,12 @@ namespace relax { RELAX_REGISTER_UNARY_OP(#OpName) #define RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(OpName, RequireFloatDtype) \ - RELAX_REGISTER_UNARY_OP_AND_IMPL(OpName).set_attr( \ - "FInferStructInfo", InferStructInfoUnaryArith) + RELAX_REGISTER_UNARY_OP_AND_IMPL(OpName).set_attr( \ + "FInferType", InferTypeUnaryArith) -#define RELAX_REGISTER_UNARY_CHECK_OP_AND_IMPL(OpName) \ - RELAX_REGISTER_UNARY_OP_AND_IMPL(OpName).set_attr( \ - "FInferStructInfo", InferStructInfoUnaryCheck) // require_float_dtype=false for check op +#define RELAX_REGISTER_UNARY_CHECK_OP_AND_IMPL(OpName) \ + RELAX_REGISTER_UNARY_OP_AND_IMPL(OpName).set_attr( \ + "FInferType", InferTypeUnaryCheck) // require_float_dtype=false for check op /***************** Arithmetic operators *****************/ diff --git a/src/relax/op/vision/multibox_transform_loc.cc b/src/relax/op/vision/multibox_transform_loc.cc index e87c2d439caf..bc4da7382351 100644 --- a/src/relax/op/vision/multibox_transform_loc.cc +++ b/src/relax/op/vision/multibox_transform_loc.cc @@ -26,7 +26,7 @@ #include #include -#include +#include #include @@ -56,7 +56,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { } /*! - * \brief Infer struct info for relax.vision.multibox_transform_loc. + * \brief Infer type for relax.vision.multibox_transform_loc. * * \note Shape cross-checks that need the anchor count N (e.g. loc_pred.shape[1] == 4*N, * anchor.shape[1] == N with N = cls_pred.shape[2]) run only when cls_pred has a known @@ -64,7 +64,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { * skips those N-based relations; other checks (ndim, dtype, loc dim divisible by 4, etc.) * still apply when their inputs are known. */ -StructInfo InferStructInfoMultiboxTransformLoc(const Call& call, const BlockBuilder& ctx) { +Type InferTypeMultiboxTransformLoc(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 3) { TVM_FFI_VISIT_THROW(ValueError, call) << "multibox_transform_loc: expected 3 inputs (cls_pred, loc_pred, anchor), " @@ -72,44 +72,43 @@ StructInfo InferStructInfoMultiboxTransformLoc(const Call& call, const BlockBuil << call->args.size(); } - ffi::Array input_sinfo = GetInputTensorStructInfo(call, ctx); - const auto cls_sinfo = input_sinfo[0]; - const auto loc_sinfo = input_sinfo[1]; - const auto anchor_sinfo = input_sinfo[2]; + ffi::Array input_ty = GetInputTensorType(call, ctx); + const auto cls_ty = input_ty[0]; + const auto loc_ty = input_ty[1]; + const auto anchor_ty = input_ty[2]; - if (!cls_sinfo->IsUnknownNdim() && cls_sinfo->ndim != 3) { + if (!cls_ty->IsUnknownNdim() && cls_ty->ndim != 3) { TVM_FFI_VISIT_THROW(ValueError, call) << "multibox_transform_loc: cls_pred must be 3-D [B, num_classes, N], got " "ndim " - << cls_sinfo->ndim; + << cls_ty->ndim; } - if (!loc_sinfo->IsUnknownNdim() && loc_sinfo->ndim != 2) { + if (!loc_ty->IsUnknownNdim() && loc_ty->ndim != 2) { TVM_FFI_VISIT_THROW(ValueError, call) - << "multibox_transform_loc: loc_pred must be 2-D [B, 4*N], got ndim " << loc_sinfo->ndim; + << "multibox_transform_loc: loc_pred must be 2-D [B, 4*N], got ndim " << loc_ty->ndim; } - if (!anchor_sinfo->IsUnknownNdim() && anchor_sinfo->ndim != 3) { + if (!anchor_ty->IsUnknownNdim() && anchor_ty->ndim != 3) { TVM_FFI_VISIT_THROW(ValueError, call) << "multibox_transform_loc: anchor must be 3-D [1, N, 4] ltrb, got ndim " - << anchor_sinfo->ndim; + << anchor_ty->ndim; } - if (!cls_sinfo->IsUnknownDtype() && !loc_sinfo->IsUnknownDtype() && - cls_sinfo->dtype != loc_sinfo->dtype) { + if (!cls_ty->IsUnknownDtype() && !loc_ty->IsUnknownDtype() && cls_ty->dtype != loc_ty->dtype) { TVM_FFI_VISIT_THROW(TypeError, call) - << "multibox_transform_loc: cls_pred and loc_pred dtype must match, got " - << cls_sinfo->dtype << " vs " << loc_sinfo->dtype; + << "multibox_transform_loc: cls_pred and loc_pred dtype must match, got " << cls_ty->dtype + << " vs " << loc_ty->dtype; } - if (!cls_sinfo->IsUnknownDtype() && !anchor_sinfo->IsUnknownDtype() && - cls_sinfo->dtype != anchor_sinfo->dtype) { + if (!cls_ty->IsUnknownDtype() && !anchor_ty->IsUnknownDtype() && + cls_ty->dtype != anchor_ty->dtype) { TVM_FFI_VISIT_THROW(TypeError, call) - << "multibox_transform_loc: cls_pred and anchor dtype must match, got " << cls_sinfo->dtype - << " vs " << anchor_sinfo->dtype; + << "multibox_transform_loc: cls_pred and anchor dtype must match, got " << cls_ty->dtype + << " vs " << anchor_ty->dtype; } - auto vdev = cls_sinfo->vdevice; - const auto* cls_shape = cls_sinfo->shape.as(); - const auto* loc_shape = loc_sinfo->shape.as(); - const auto* anchor_shape = anchor_sinfo->shape.as(); + auto vdev = cls_ty->vdevice; + const auto* cls_shape = cls_ty->shape.as(); + const auto* loc_shape = loc_ty->shape.as(); + const auto* anchor_shape = anchor_ty->shape.as(); if (loc_shape != nullptr) { const auto* loc_dim1 = loc_shape->values[1].as(); @@ -145,9 +144,9 @@ StructInfo InferStructInfoMultiboxTransformLoc(const Call& call, const BlockBuil } if (cls_shape == nullptr) { - ffi::Array fields = {TensorStructInfo(cls_sinfo->dtype, 3, vdev), - TensorStructInfo(cls_sinfo->dtype, 3, vdev)}; - return TupleStructInfo(fields); + ffi::Array fields = {TensorType(cls_ty->dtype, 3, vdev), + TensorType(cls_ty->dtype, 3, vdev)}; + return TupleType(fields); } const auto& batch = cls_shape->values[0]; @@ -179,10 +178,9 @@ StructInfo InferStructInfoMultiboxTransformLoc(const Call& call, const BlockBuil ffi::Array boxes_shape = {batch, num_anchors, IntImm::Int32(4)}; ffi::Array scores_shape = {batch, num_classes, num_anchors}; - ffi::Array fields = { - TensorStructInfo(ShapeExpr(boxes_shape), cls_sinfo->dtype, vdev), - TensorStructInfo(ShapeExpr(scores_shape), cls_sinfo->dtype, vdev)}; - return TupleStructInfo(fields); + ffi::Array fields = {TensorType(ShapeExpr(boxes_shape), cls_ty->dtype, vdev), + TensorType(ShapeExpr(scores_shape), cls_ty->dtype, vdev)}; + return TupleType(fields); } TVM_REGISTER_OP("relax.vision.multibox_transform_loc") @@ -196,7 +194,7 @@ TVM_REGISTER_OP("relax.vision.multibox_transform_loc") .add_argument("loc_pred", "Tensor", "[B,4*N] box encodings (x,y,w,h); TFLite yxhw order remapped to xywh.") .add_argument("anchor", "Tensor", "[1,N,4] priors as ltrb (left,top,right,bottom).") - .set_attr("FInferStructInfo", InferStructInfoMultiboxTransformLoc) + .set_attr("FInferType", InferTypeMultiboxTransformLoc) .set_attr("FPurity", true); } // namespace relax diff --git a/src/relax/op/vision/nms.cc b/src/relax/op/vision/nms.cc index 88139c62fc5b..bde579f0ed5a 100644 --- a/src/relax/op/vision/nms.cc +++ b/src/relax/op/vision/nms.cc @@ -27,7 +27,7 @@ #include #include #include -#include +#include #include #include @@ -62,31 +62,31 @@ TVM_FFI_STATIC_INIT_BLOCK() { all_class_non_max_suppression); } -StructInfo InferStructInfoAllClassNMS(const Call& call, const BlockBuilder& ctx) { - tvm::ffi::Array input_sinfo = GetInputTensorStructInfo(call, ctx); - const auto boxes_sinfo = input_sinfo[0]; - const auto scores_sinfo = input_sinfo[1]; - TVM_FFI_ICHECK(!boxes_sinfo->IsUnknownNdim()) << "Only support known ndim"; - TVM_FFI_ICHECK(!scores_sinfo->IsUnknownNdim()) << "Only support known ndim"; - TVM_FFI_ICHECK_EQ(boxes_sinfo->ndim, 3) << "AllClassNMS input boxes should be 3-D."; - TVM_FFI_ICHECK_EQ(scores_sinfo->ndim, 3) << "AllClassNMS input scores count should be 3-D."; +Type InferTypeAllClassNMS(const Call& call, const BlockBuilder& ctx) { + tvm::ffi::Array input_ty = GetInputTensorType(call, ctx); + const auto boxes_ty = input_ty[0]; + const auto scores_ty = input_ty[1]; + TVM_FFI_ICHECK(!boxes_ty->IsUnknownNdim()) << "Only support known ndim"; + TVM_FFI_ICHECK(!scores_ty->IsUnknownNdim()) << "Only support known ndim"; + TVM_FFI_ICHECK_EQ(boxes_ty->ndim, 3) << "AllClassNMS input boxes should be 3-D."; + TVM_FFI_ICHECK_EQ(scores_ty->ndim, 3) << "AllClassNMS input scores count should be 3-D."; - const auto batch = boxes_sinfo->shape.as()->values[0]; - const auto num_classes = scores_sinfo->shape.as()->values[1]; - const auto num_boxes = boxes_sinfo->shape.as()->values[1]; + const auto batch = boxes_ty->shape.as()->values[0]; + const auto num_classes = scores_ty->shape.as()->values[1]; + const auto num_boxes = boxes_ty->shape.as()->values[1]; - auto vdev = input_sinfo[0]->vdevice; + auto vdev = input_ty[0]->vdevice; const auto* attrs = call->attrs.as(); if (attrs->output_format == "onnx") { - auto vdev = input_sinfo[0]->vdevice; + auto vdev = input_ty[0]->vdevice; auto num_total_boxes = batch * num_classes * num_boxes; tvm::ffi::Array oshape_values = {num_total_boxes, 3}; ShapeExpr oshape(oshape_values); tvm::ffi::Array counts_values = {1}; ShapeExpr counts_shape(counts_values); - tvm::ffi::Array fields = {TensorStructInfo(oshape, DataType::Int(64), vdev), - TensorStructInfo(counts_shape, DataType::Int(64), vdev)}; - return TupleStructInfo(fields); + tvm::ffi::Array fields = {TensorType(oshape, DataType::Int(64), vdev), + TensorType(counts_shape, DataType::Int(64), vdev)}; + return TupleType(fields); } auto num_total_boxes_per_batch = num_classes * num_boxes; @@ -96,10 +96,10 @@ StructInfo InferStructInfoAllClassNMS(const Call& call, const BlockBuilder& ctx) ShapeExpr scores_shape(scores_values); tvm::ffi::Array counts_values = {batch}; ShapeExpr counts_shape(counts_values); - tvm::ffi::Array fields = {TensorStructInfo(indices_shape, DataType::Int(64), vdev), - TensorStructInfo(scores_shape, DataType::Float(32), vdev), - TensorStructInfo(counts_shape, DataType::Int(64), vdev)}; - return TupleStructInfo(fields); + tvm::ffi::Array fields = {TensorType(indices_shape, DataType::Int(64), vdev), + TensorType(scores_shape, DataType::Float(32), vdev), + TensorType(counts_shape, DataType::Int(64), vdev)}; + return TupleType(fields); } TVM_REGISTER_OP("relax.vision.all_class_non_max_suppression") @@ -113,7 +113,7 @@ TVM_REGISTER_OP("relax.vision.all_class_non_max_suppression") .add_argument("iou_threshold", "Tensor", "The IoU threshold for box the overlap test.") .add_argument("score_threshold", "Tensor", "The score threshold to filter out low score boxes early.") - .set_attr("FInferStructInfo", InferStructInfoAllClassNMS) + .set_attr("FInferType", InferTypeAllClassNMS) .set_attr("FPurity", true); /* relax.vision.get_valid_counts */ @@ -133,30 +133,30 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.vision.get_valid_counts", get_valid_counts); } -StructInfo InferStructInfoGetValidCounts(const Call& call, const BlockBuilder& ctx) { +Type InferTypeGetValidCounts(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 1) { TVM_FFI_VISIT_THROW(ValueError, call) << "get_valid_counts expects 1 argument, got " << call->args.size(); } - const auto* data_sinfo = GetStructInfoAs(call->args[0]); - if (data_sinfo == nullptr) { + const auto* data_ty = GetTypeAs(call->args[0]); + if (data_ty == nullptr) { TVM_FFI_VISIT_THROW(TypeError, call) << "get_valid_counts expects input data to be a Tensor."; } - if (data_sinfo->ndim != -1 && data_sinfo->ndim != 3) { + if (data_ty->ndim != -1 && data_ty->ndim != 3) { TVM_FFI_VISIT_THROW(ValueError, call) - << "get_valid_counts expects 3-D input, got ndim " << data_sinfo->ndim; + << "get_valid_counts expects 3-D input, got ndim " << data_ty->ndim; } const auto* attrs = call->attrs.as(); TVM_FFI_ICHECK(attrs != nullptr) << "Invalid get_valid_counts attrs"; - auto vdev = data_sinfo->vdevice; - const auto* data_shape = data_sinfo->shape.as(); + auto vdev = data_ty->vdevice; + const auto* data_shape = data_ty->shape.as(); if (data_shape == nullptr) { - tvm::ffi::Array fields = {TensorStructInfo(DataType::Int(32), /*ndim=*/1, vdev), - TensorStructInfo(data_sinfo->dtype, /*ndim=*/3, vdev), - TensorStructInfo(DataType::Int(32), /*ndim=*/2, vdev)}; - return TupleStructInfo(fields); + tvm::ffi::Array fields = {TensorType(DataType::Int(32), /*ndim=*/1, vdev), + TensorType(data_ty->dtype, /*ndim=*/3, vdev), + TensorType(DataType::Int(32), /*ndim=*/2, vdev)}; + return TupleType(fields); } auto batch = data_shape->values[0]; @@ -176,11 +176,11 @@ StructInfo InferStructInfoGetValidCounts(const Call& call, const BlockBuilder& c } } - tvm::ffi::Array fields = { - TensorStructInfo(ShapeExpr({batch}), DataType::Int(32), vdev), - TensorStructInfo(ShapeExpr({batch, num_anchors, elem_length}), data_sinfo->dtype, vdev), - TensorStructInfo(ShapeExpr({batch, num_anchors}), DataType::Int(32), vdev)}; - return TupleStructInfo(fields); + tvm::ffi::Array fields = { + TensorType(ShapeExpr({batch}), DataType::Int(32), vdev), + TensorType(ShapeExpr({batch, num_anchors, elem_length}), data_ty->dtype, vdev), + TensorType(ShapeExpr({batch, num_anchors}), DataType::Int(32), vdev)}; + return TupleType(fields); } TVM_REGISTER_OP("relax.vision.get_valid_counts") @@ -188,7 +188,7 @@ TVM_REGISTER_OP("relax.vision.get_valid_counts") .set_num_inputs(1) .add_argument("data", "Tensor", "Input data, 3-D tensor [batch_size, num_anchors, elem_length].") - .set_attr("FInferStructInfo", InferStructInfoGetValidCounts) + .set_attr("FInferType", InferTypeGetValidCounts) .set_attr("FPurity", true); /* relax.vision.non_max_suppression */ @@ -219,52 +219,51 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.vision.non_max_suppression", non_max_suppression); } -StructInfo InferStructInfoNMS(const Call& call, const BlockBuilder& ctx) { +Type InferTypeNMS(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 3) { TVM_FFI_VISIT_THROW(ValueError, call) << "non_max_suppression expects 3 arguments, got " << call->args.size(); } - const auto* data_sinfo = GetStructInfoAs(call->args[0]); - const auto* valid_count_sinfo = GetStructInfoAs(call->args[1]); - const auto* indices_sinfo = GetStructInfoAs(call->args[2]); - if (data_sinfo == nullptr) { + const auto* data_ty = GetTypeAs(call->args[0]); + const auto* valid_count_ty = GetTypeAs(call->args[1]); + const auto* indices_ty = GetTypeAs(call->args[2]); + if (data_ty == nullptr) { TVM_FFI_VISIT_THROW(TypeError, call) << "non_max_suppression expects input data to be a Tensor."; } - if (valid_count_sinfo == nullptr) { + if (valid_count_ty == nullptr) { TVM_FFI_VISIT_THROW(TypeError, call) << "non_max_suppression expects valid_count to be a Tensor."; } - if (indices_sinfo == nullptr) { + if (indices_ty == nullptr) { TVM_FFI_VISIT_THROW(TypeError, call) << "non_max_suppression expects indices to be a Tensor."; } - if (data_sinfo->ndim != -1 && data_sinfo->ndim != 3) { + if (data_ty->ndim != -1 && data_ty->ndim != 3) { TVM_FFI_VISIT_THROW(ValueError, call) - << "non_max_suppression expects 3-D input, got ndim " << data_sinfo->ndim; + << "non_max_suppression expects 3-D input, got ndim " << data_ty->ndim; } - if (valid_count_sinfo->ndim != -1 && valid_count_sinfo->ndim != 1) { + if (valid_count_ty->ndim != -1 && valid_count_ty->ndim != 1) { TVM_FFI_VISIT_THROW(ValueError, call) - << "non_max_suppression expects valid_count to be 1-D, got ndim " - << valid_count_sinfo->ndim; + << "non_max_suppression expects valid_count to be 1-D, got ndim " << valid_count_ty->ndim; } - if (indices_sinfo->ndim != -1 && indices_sinfo->ndim != 2) { + if (indices_ty->ndim != -1 && indices_ty->ndim != 2) { TVM_FFI_VISIT_THROW(ValueError, call) - << "non_max_suppression expects indices to be 2-D, got ndim " << indices_sinfo->ndim; + << "non_max_suppression expects indices to be 2-D, got ndim " << indices_ty->ndim; } - if (!valid_count_sinfo->IsUnknownDtype() && valid_count_sinfo->dtype != DataType::Int(32)) { + if (!valid_count_ty->IsUnknownDtype() && valid_count_ty->dtype != DataType::Int(32)) { TVM_FFI_VISIT_THROW(TypeError, call) << "non_max_suppression expects valid_count to have dtype int32, got " - << valid_count_sinfo->dtype; + << valid_count_ty->dtype; } - if (!indices_sinfo->IsUnknownDtype() && indices_sinfo->dtype != DataType::Int(32)) { + if (!indices_ty->IsUnknownDtype() && indices_ty->dtype != DataType::Int(32)) { TVM_FFI_VISIT_THROW(TypeError, call) - << "non_max_suppression expects indices to have dtype int32, got " << indices_sinfo->dtype; + << "non_max_suppression expects indices to have dtype int32, got " << indices_ty->dtype; } - const auto* data_shape = data_sinfo->shape.as(); - const auto* valid_count_shape = valid_count_sinfo->shape.as(); - const auto* indices_shape = indices_sinfo->shape.as(); + const auto* data_shape = data_ty->shape.as(); + const auto* valid_count_shape = valid_count_ty->shape.as(); + const auto* indices_shape = indices_ty->shape.as(); if (data_shape != nullptr) { arith::Analyzer analyzer = ctx->GetAnalyzer(); PrimExpr batch = data_shape->values[0]; @@ -274,7 +273,7 @@ StructInfo InferStructInfoNMS(const Call& call, const BlockBuilder& ctx) { TVM_FFI_VISIT_THROW(ValueError, call) << "non_max_suppression expects valid_count to have shape [batch_size]. " "However, the given data tensor has batch size `" - << batch << "` and the given valid_count tensor has shape " << valid_count_sinfo->shape; + << batch << "` and the given valid_count tensor has shape " << valid_count_ty->shape; } if (indices_shape != nullptr) { if (!analyzer->CanProveEqual(indices_shape->values[0], batch) || @@ -282,15 +281,14 @@ StructInfo InferStructInfoNMS(const Call& call, const BlockBuilder& ctx) { TVM_FFI_VISIT_THROW(ValueError, call) << "non_max_suppression expects indices to have shape [batch_size, num_anchors]. " "However, the given data tensor has shape " - << data_sinfo->shape << " and the given indices tensor has shape " - << indices_sinfo->shape; + << data_ty->shape << " and the given indices tensor has shape " << indices_ty->shape; } } } const auto* attrs = call->attrs.as(); TVM_FFI_ICHECK(attrs != nullptr) << "Invalid non_max_suppression attrs"; - auto vdev = data_sinfo->vdevice; + auto vdev = data_ty->vdevice; if (data_shape != nullptr) { const auto* elem_length_imm = data_shape->values[2].as(); if (elem_length_imm != nullptr) { @@ -320,40 +318,39 @@ StructInfo InferStructInfoNMS(const Call& call, const BlockBuilder& ctx) { // box_indices[batch, num_anchors], // valid_box_count[batch, 1]) if (data_shape == nullptr) { - tvm::ffi::Array fields = { - TensorStructInfo(data_sinfo->dtype, /*ndim=*/3, vdev), - TensorStructInfo(DataType::Int(32), /*ndim=*/2, vdev), - TensorStructInfo(DataType::Int(32), /*ndim=*/2, vdev)}; - return TupleStructInfo(fields); + tvm::ffi::Array fields = {TensorType(data_ty->dtype, /*ndim=*/3, vdev), + TensorType(DataType::Int(32), /*ndim=*/2, vdev), + TensorType(DataType::Int(32), /*ndim=*/2, vdev)}; + return TupleType(fields); } auto batch = data_shape->values[0]; auto num_anchors = data_shape->values[1]; - tvm::ffi::Array fields = { - TensorStructInfo(ffi::GetRef(data_shape), data_sinfo->dtype, vdev), - TensorStructInfo(ShapeExpr({batch, num_anchors}), DataType::Int(32), vdev), - TensorStructInfo(ShapeExpr({batch, IntImm::Int64(1)}), DataType::Int(32), vdev)}; - return TupleStructInfo(fields); + tvm::ffi::Array fields = { + TensorType(ffi::GetRef(data_shape), data_ty->dtype, vdev), + TensorType(ShapeExpr({batch, num_anchors}), DataType::Int(32), vdev), + TensorType(ShapeExpr({batch, IntImm::Int64(1)}), DataType::Int(32), vdev)}; + return TupleType(fields); } // Hard NMS returns (box_indices[batch, num_anchors], valid_box_count[batch, 1]) if (data_shape == nullptr) { - tvm::ffi::Array fields = {TensorStructInfo(DataType::Int(32), /*ndim=*/2, vdev), - TensorStructInfo(DataType::Int(32), /*ndim=*/2, vdev)}; - return TupleStructInfo(fields); + tvm::ffi::Array fields = {TensorType(DataType::Int(32), /*ndim=*/2, vdev), + TensorType(DataType::Int(32), /*ndim=*/2, vdev)}; + return TupleType(fields); } auto batch = data_shape->values[0]; auto num_anchors = data_shape->values[1]; - tvm::ffi::Array fields = { - TensorStructInfo(ShapeExpr({batch, num_anchors}), DataType::Int(32), vdev), - TensorStructInfo(ShapeExpr({batch, IntImm::Int64(1)}), DataType::Int(32), vdev)}; - return TupleStructInfo(fields); + tvm::ffi::Array fields = { + TensorType(ShapeExpr({batch, num_anchors}), DataType::Int(32), vdev), + TensorType(ShapeExpr({batch, IntImm::Int64(1)}), DataType::Int(32), vdev)}; + return TupleType(fields); } // Returns modified data tensor with the same shape as input. - if (const auto* data_shape = data_sinfo->shape.as()) { - return TensorStructInfo(ffi::GetRef(data_shape), data_sinfo->dtype, vdev); + if (const auto* data_shape = data_ty->shape.as()) { + return TensorType(ffi::GetRef(data_shape), data_ty->dtype, vdev); } - return TensorStructInfo(data_sinfo->dtype, /*ndim=*/3, vdev); + return TensorType(data_ty->dtype, /*ndim=*/3, vdev); } TVM_REGISTER_OP("relax.vision.non_max_suppression") @@ -363,7 +360,7 @@ TVM_REGISTER_OP("relax.vision.non_max_suppression") "Input data, 3-D tensor [batch_size, num_anchors, elem_length].") .add_argument("valid_count", "Tensor", "1-D tensor for valid number of boxes.") .add_argument("indices", "Tensor", "2-D tensor with shape [batch_size, num_anchors].") - .set_attr("FInferStructInfo", InferStructInfoNMS) + .set_attr("FInferType", InferTypeNMS) .set_attr("FPurity", true); } // namespace relax diff --git a/src/relax/op/vision/roi_align.cc b/src/relax/op/vision/roi_align.cc index 131c634bc46a..b959073cee67 100644 --- a/src/relax/op/vision/roi_align.cc +++ b/src/relax/op/vision/roi_align.cc @@ -60,34 +60,34 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.vision.roi_align", roi_align); } -StructInfo InferStructInfoROIAlign(const Call& call, const BlockBuilder& ctx) { +Type InferTypeROIAlign(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 2) { TVM_FFI_VISIT_THROW(ValueError, call) << "ROIAlign expects two arguments, while the given number of arguments is " << call->args.size(); } - const auto* data_sinfo = GetStructInfoAs(call->args[0]); - const auto* rois_sinfo = GetStructInfoAs(call->args[1]); - if (data_sinfo == nullptr) { + const auto* data_ty = GetTypeAs(call->args[0]); + const auto* rois_ty = GetTypeAs(call->args[1]); + if (data_ty == nullptr) { TVM_FFI_VISIT_THROW(TypeError, call) << "ROIAlign expects the input data to be a Tensor, while the given data is " << call->args[0]->GetTypeKey(); } - if (rois_sinfo == nullptr) { + if (rois_ty == nullptr) { TVM_FFI_VISIT_THROW(TypeError, call) << "ROIAlign expects the rois to be a Tensor, while the given rois is " << call->args[1]->GetTypeKey(); } - if (!data_sinfo->IsUnknownNdim() && data_sinfo->ndim != 4) { + if (!data_ty->IsUnknownNdim() && data_ty->ndim != 4) { TVM_FFI_VISIT_THROW(ValueError, call) << "ROIAlign expects the input data to be 4-D, while the given data has ndim " - << data_sinfo->ndim; + << data_ty->ndim; } - if (!rois_sinfo->IsUnknownNdim() && rois_sinfo->ndim != 2) { + if (!rois_ty->IsUnknownNdim() && rois_ty->ndim != 2) { TVM_FFI_VISIT_THROW(ValueError, call) << "ROIAlign expects the rois tensor to be 2-D, while the given rois has ndim " - << rois_sinfo->ndim; + << rois_ty->ndim; } const auto* attrs = call->attrs.as(); @@ -101,7 +101,7 @@ StructInfo InferStructInfoROIAlign(const Call& call, const BlockBuilder& ctx) { << "ROIAlign only supports avg and max mode, but got " << attrs->mode; } - const auto* rois_shape = rois_sinfo->shape.as(); + const auto* rois_shape = rois_ty->shape.as(); if (rois_shape != nullptr) { const auto* last_dim = rois_shape->values[1].as(); if (last_dim != nullptr && last_dim->value != 5) { @@ -112,11 +112,11 @@ StructInfo InferStructInfoROIAlign(const Call& call, const BlockBuilder& ctx) { } } - if (data_sinfo->shape.as() == nullptr || rois_shape == nullptr) { - return TensorStructInfo(data_sinfo->dtype, 4, data_sinfo->vdevice); + if (data_ty->shape.as() == nullptr || rois_shape == nullptr) { + return TensorType(data_ty->dtype, 4, data_ty->vdevice); } - ffi::Array data_shape = data_sinfo->shape.as()->values; + ffi::Array data_shape = data_ty->shape.as()->values; ffi::Array out_shape; if (attrs->layout == "NCHW") { out_shape = {rois_shape->values[0], data_shape[1], IntImm::Int32(attrs->pooled_size[0]), @@ -125,7 +125,7 @@ StructInfo InferStructInfoROIAlign(const Call& call, const BlockBuilder& ctx) { out_shape = {rois_shape->values[0], IntImm::Int32(attrs->pooled_size[0]), IntImm::Int32(attrs->pooled_size[1]), data_shape[3]}; } - return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype, data_sinfo->vdevice); + return TensorType(ShapeExpr(out_shape), data_ty->dtype, data_ty->vdevice); } TVM_REGISTER_OP("relax.vision.roi_align") @@ -134,7 +134,7 @@ TVM_REGISTER_OP("relax.vision.roi_align") .add_argument("data", "Tensor", "The input tensor.") .add_argument("rois", "Tensor", "The input rois with shape (num_roi, 5) in [batch_idx, x1, y1, x2, y2] format.") - .set_attr("FInferStructInfo", InferStructInfoROIAlign) + .set_attr("FInferType", InferTypeROIAlign) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", true); diff --git a/src/relax/op/vision/roi_pool.cc b/src/relax/op/vision/roi_pool.cc index ae8fc5d57bbb..f0554155e020 100644 --- a/src/relax/op/vision/roi_pool.cc +++ b/src/relax/op/vision/roi_pool.cc @@ -57,34 +57,34 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.vision.roi_pool", roi_pool); } -StructInfo InferStructInfoROIPool(const Call& call, const BlockBuilder& ctx) { +Type InferTypeROIPool(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 2) { TVM_FFI_VISIT_THROW(ValueError, call) << "ROIPool expects two arguments, while the given number of arguments is " << call->args.size(); } - const auto* data_sinfo = GetStructInfoAs(call->args[0]); - const auto* rois_sinfo = GetStructInfoAs(call->args[1]); - if (data_sinfo == nullptr) { + const auto* data_ty = GetTypeAs(call->args[0]); + const auto* rois_ty = GetTypeAs(call->args[1]); + if (data_ty == nullptr) { TVM_FFI_VISIT_THROW(TypeError, call) << "ROIPool expects the input data to be a Tensor, while the given data is " << call->args[0]->GetTypeKey(); } - if (rois_sinfo == nullptr) { + if (rois_ty == nullptr) { TVM_FFI_VISIT_THROW(TypeError, call) << "ROIPool expects the rois to be a Tensor, while the given rois is " << call->args[1]->GetTypeKey(); } - if (!data_sinfo->IsUnknownNdim() && data_sinfo->ndim != 4) { + if (!data_ty->IsUnknownNdim() && data_ty->ndim != 4) { TVM_FFI_VISIT_THROW(ValueError, call) << "ROIPool expects the input data to be 4-D, while the given data has ndim " - << data_sinfo->ndim; + << data_ty->ndim; } - if (!rois_sinfo->IsUnknownNdim() && rois_sinfo->ndim != 2) { + if (!rois_ty->IsUnknownNdim() && rois_ty->ndim != 2) { TVM_FFI_VISIT_THROW(ValueError, call) << "ROIPool expects the rois tensor to be 2-D, while the given rois has ndim " - << rois_sinfo->ndim; + << rois_ty->ndim; } const auto* attrs = call->attrs.as(); @@ -94,7 +94,7 @@ StructInfo InferStructInfoROIPool(const Call& call, const BlockBuilder& ctx) { << "ROIPool only supports NCHW layout, but got " << attrs->layout; } - const auto* rois_shape = rois_sinfo->shape.as(); + const auto* rois_shape = rois_ty->shape.as(); if (rois_shape != nullptr) { const auto* last_dim = rois_shape->values[1].as(); if (last_dim != nullptr && last_dim->value != 5) { @@ -105,15 +105,15 @@ StructInfo InferStructInfoROIPool(const Call& call, const BlockBuilder& ctx) { } } - if (data_sinfo->shape.as() == nullptr || rois_shape == nullptr) { - return TensorStructInfo(data_sinfo->dtype, 4, data_sinfo->vdevice); + if (data_ty->shape.as() == nullptr || rois_shape == nullptr) { + return TensorType(data_ty->dtype, 4, data_ty->vdevice); } - ffi::Array data_shape = data_sinfo->shape.as()->values; + ffi::Array data_shape = data_ty->shape.as()->values; ffi::Array out_shape = {rois_shape->values[0], data_shape[1], IntImm::Int32(attrs->pooled_size[0]), IntImm::Int32(attrs->pooled_size[1])}; - return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype, data_sinfo->vdevice); + return TensorType(ShapeExpr(out_shape), data_ty->dtype, data_ty->vdevice); } TVM_REGISTER_OP("relax.vision.roi_pool") @@ -122,7 +122,7 @@ TVM_REGISTER_OP("relax.vision.roi_pool") .add_argument("data", "Tensor", "The input tensor.") .add_argument("rois", "Tensor", "The input rois with shape (num_roi, 5) in [batch_idx, x1, y1, x2, y2] format.") - .set_attr("FInferStructInfo", InferStructInfoROIPool) + .set_attr("FInferType", InferTypeROIPool) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", true); diff --git a/src/relax/script/builder/distributed.cc b/src/relax/script/builder/distributed.cc index 0561e582dd3c..c5ffc3a4eb6d 100644 --- a/src/relax/script/builder/distributed.cc +++ b/src/relax/script/builder/distributed.cc @@ -19,40 +19,39 @@ #include #include #include -#include +#include #include -#include +#include #include #include "./utils.h" namespace tvm { namespace relax { -Expr MakeCallTIRDist(Expr func, Tuple args, - ffi::Array out_sinfo_list, +Expr MakeCallTIRDist(Expr func, Tuple args, ffi::Array out_ty_list, ffi::Optional packed_ints) { - for (const distributed::DTensorStructInfo& sinfo : out_sinfo_list) { - const auto* shape = sinfo->tensor_sinfo->shape.as(); + for (const distributed::DTensorType& ty : out_ty_list) { + const auto* shape = ty->tensor_ty->shape.as(); TVM_FFI_ICHECK(shape != nullptr) - << "out_sinfo of call_tir should have defined ShapeExpr as shape. " - "However, one given structure info is " - << sinfo; + << "out_ty of call_tir should have defined ShapeExpr as shape. " + "However, one given type information is " + << ty; } - StructInfo out_sinfo{nullptr}; - if (out_sinfo_list.size() == 1) { - out_sinfo = out_sinfo_list[0]; + Type out_ty{nullptr}; + if (out_ty_list.size() == 1) { + out_ty = out_ty_list[0]; } else { - out_sinfo = TupleStructInfo({out_sinfo_list.begin(), out_sinfo_list.end()}); + out_ty = TupleType({out_ty_list.begin(), out_ty_list.end()}); } static const Op& op = Op::Get("relax.call_tir"); Call call; if (!packed_ints) { // don't use additional optional argument - call = Call(op, {func, args}, {}, {out_sinfo}); + call = Call(op, {func, args}, {}, {out_ty}); } else { - call = Call(op, {func, args, packed_ints.value()}, {}, {out_sinfo}); + call = Call(op, {func, args, packed_ints.value()}, {}, {out_ty}); } return call; } diff --git a/src/relax/script/builder/frame.cc b/src/relax/script/builder/frame.cc index f710611076dc..8a84db30a3b5 100644 --- a/src/relax/script/builder/frame.cc +++ b/src/relax/script/builder/frame.cc @@ -81,7 +81,7 @@ void FunctionFrameNode::ExitWithScope() { this->block_builder->EndScope(); tvm::relax::Function func(/*params=*/params, /*body=*/body, - /*ret_struct_info=*/ret_struct_info, + /*ret_ty=*/ret_ty, /*is_pure=*/is_pure.value_or(true), /*attrs=*/DictAttrs(attrs)); // Step 2: Update IRModule. @@ -170,7 +170,7 @@ void BindingBlockFrameNode::ExitWithScope() { std::unordered_map var_remap; for (const auto& output_var : output_vars) { - tvm::relax::Var new_output_var(output_var->name_hint(), GetStructInfo(output_var)); + tvm::relax::Var new_output_var(output_var->name_hint(), GetType(output_var)); new_output_vars.push_back(new_output_var); var_remap[output_var->vid] = new_output_var; } diff --git a/src/relax/script/builder/ir.cc b/src/relax/script/builder/ir.cc index 48bba2e592f1..df2aa1e9ea60 100644 --- a/src/relax/script/builder/ir.cc +++ b/src/relax/script/builder/ir.cc @@ -19,7 +19,7 @@ #include #include #include -#include +#include #include #include "./utils.h" @@ -69,9 +69,9 @@ FunctionFrame Function(bool is_pure, bool is_private) { return FunctionFrame(n); } -tvm::relax::Var Arg(const ffi::String& name, const tvm::relax::StructInfo& struct_info) { +tvm::relax::Var Arg(const ffi::String& name, const tvm::Type& ty) { FunctionFrame frame = FindFunctionFrame("R.Arg"); - tvm::relax::Var var(name, struct_info); + tvm::relax::Var var(name, ty); frame->params.push_back(var); frame->block_builder->AddDefinitionToScope(var); @@ -106,13 +106,13 @@ void FuncAttrs(ffi::Map attrs) { } } -void FuncRetStructInfo(const tvm::relax::StructInfo& ret_sinfo) { - FunctionFrame frame = FindFunctionFrame("R.func_ret_struct_info"); - if (frame->ret_struct_info.defined()) { - TVM_FFI_THROW(ValueError) << "Duplicate function return struct info, previous one is:\n " - << frame->ret_struct_info.value(); +void FuncRetType(const tvm::Type& ret_ty) { + FunctionFrame frame = FindFunctionFrame("R.func_ret_type"); + if (frame->ret_ty.defined()) { + TVM_FFI_THROW(ValueError) << "Duplicate function return type, previous one is:\n " + << frame->ret_ty.value(); } - frame->ret_struct_info = ret_sinfo; + frame->ret_ty = ret_ty; } void FuncRetValue(const tvm::relax::Expr& value) { @@ -153,7 +153,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { .def("script.ir_builder.relax.Arg", Arg) .def("script.ir_builder.relax.FuncName", FuncName) .def("script.ir_builder.relax.FuncAttrs", FuncAttrs) - .def("script.ir_builder.relax.FuncRetStructInfo", FuncRetStructInfo) + .def("script.ir_builder.relax.FuncRetType", FuncRetType) .def("script.ir_builder.relax.FuncRetValue", FuncRetValue); } @@ -209,20 +209,18 @@ TVM_FFI_STATIC_INIT_BLOCK() { /////////////////////////////// Bindings /////////////////////////////// -tvm::relax::Var Emit(const tvm::relax::Expr& expr, - const ffi::Optional& annotate_struct_info) { - using tvm::relax::GetStructInfo; +tvm::relax::Var Emit(const tvm::relax::Expr& expr, const ffi::Optional& annotate_ty) { + using tvm::relax::GetType; BindingBlockFrame block_frame = CheckBindingBlockFrameExistAndUnended(); const tvm::relax::BlockBuilder& block_builder = GetBlockBuilder(); - if (annotate_struct_info.defined()) { - const auto& sinfo = annotate_struct_info.value(); - if (!expr->struct_info_.defined()) { - UpdateStructInfo(expr, sinfo); + if (annotate_ty.defined()) { + const auto& ty = annotate_ty.value(); + if (!expr->ty.defined()) { + tvm::relax::UpdateType(expr, ty); } else { - TVM_FFI_ICHECK(StructInfoBaseCheck(sinfo, GetStructInfo(expr)) != + TVM_FFI_ICHECK(tvm::relax::TypeBaseCheck(ty, GetType(expr)) != tvm::relax::BaseCheckResult::kFailL0) - << "Invalid annotation. Got rhs value struct info: " << GetStructInfo(expr) - << ", given struct info: " << sinfo; + << "Invalid annotation. Got rhs value type: " << GetType(expr) << ", given type: " << ty; } } tvm::relax::Var var = block_builder->Emit(expr); @@ -230,12 +228,11 @@ tvm::relax::Var Emit(const tvm::relax::Expr& expr, return var; } -tvm::relax::Var EmitMatchCast(const tvm::relax::Expr& value, - const tvm::relax::StructInfo& struct_info) { +tvm::relax::Var EmitMatchCast(const tvm::relax::Expr& value, const tvm::Type& ty) { BindingBlockFrame block_frame = CheckBindingBlockFrameExistAndUnended(); const tvm::relax::BlockBuilder& block_builder = GetBlockBuilder(); - tvm::relax::Var var = block_builder->EmitMatchCast(value, struct_info); + tvm::relax::Var var = block_builder->EmitMatchCast(value, ty); block_frame->emitted_vars.push_back(var); return var; } diff --git a/src/relax/script/builder/utils.h b/src/relax/script/builder/utils.h index 14e762064f42..3090969e235f 100644 --- a/src/relax/script/builder/utils.h +++ b/src/relax/script/builder/utils.h @@ -21,7 +21,7 @@ #include #include -#include +#include #include #include @@ -108,7 +108,7 @@ inline tvm::relax::SeqExpr GetSeqExprForBranch(const SeqExprFrame& frame, ffi::S last_block->bindings.end() - 1); tvm::relax::Var new_var = tvm::relax::Var(last_binding->var->name_hint() + output_var_suffix, - GetStructInfo(last_binding->var)); + GetType(last_binding->var)); tvm::relax::Expr body; const auto* var_binding = last_binding.as(); @@ -120,7 +120,7 @@ inline tvm::relax::SeqExpr GetSeqExprForBranch(const SeqExprFrame& frame, ffi::S body = new_var; } else if (const auto* match_cast = last_binding.as()) { last_block_bindings.push_back( - tvm::relax::MatchCast(new_var, match_cast->value, match_cast->struct_info)); + tvm::relax::MatchCast(new_var, match_cast->value, match_cast->ty)); body = new_var; } else { TVM_FFI_CHECK(false, TypeError) << "Unsupported binding type: " << last_binding->GetTypeKey(); diff --git a/src/relax/script/printer/binding.cc b/src/relax/script/printer/binding.cc index d756a82a0e18..ee947ad30240 100644 --- a/src/relax/script/printer/binding.cc +++ b/src/relax/script/printer/binding.cc @@ -45,15 +45,15 @@ IfDoc PrintIfExpr(const relax::If& n, const AccessPath& n_p, TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( "", [](relax::MatchCast n, AccessPath n_p, IRDocsifier d) -> Doc { - using relax::StructInfo; - using relax::MatchStructInfo; + using tvm::Type; + using relax::MatchType; ffi::Optional ann = std::nullopt; - if (d->cfg->GetExtraConfig("relax.show_all_struct_info", true)) { - ann = StructInfoAsAnn(n->var, n_p->Attr("var"), d, n->value); + if (d->cfg->GetExtraConfig("relax.show_all_ty", true)) { + ann = TypeAsAnn(n->var, n_p->Attr("var"), d, n->value); } ExprDoc rhs = Relax(d, "match_cast") ->Call({d->AsDoc(n->value, n_p->Attr("value")), - d->AsDoc(n->struct_info, n_p->Attr("struct_info_"))}); + d->AsDoc(n->ty, n_p->Attr("ty"))}); ExprDoc lhs = DefineVar(n->var, d->frames.back(), d); return AssignDoc(lhs, rhs, ann); }); @@ -62,7 +62,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "", [](relax::VarBinding n, AccessPath n_p, IRDocsifier d) -> Doc { if (const auto if_ = n->value.as()) { - ffi::Optional ann = StructInfoAsAnn(n->var, n_p->Attr("var"), d, n->value); + ffi::Optional ann = TypeAsAnn(n->var, n_p->Attr("var"), d, n->value); ExprDoc lhs = DefineVar(n->var, d->frames.back(), d); return PrintIfExpr(ffi::GetRef(if_), n_p->Attr("value"), d, lhs, ann); } else if (n->value->IsInstance() && @@ -72,13 +72,13 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) Doc ret = d->AsDoc(n->value, n_p->Attr("value")); d->cfg->binding_names.pop_back(); return ret; - } else if (d->cfg->syntax_sugar && relax::HasVoidStructInfo(n->value) && - relax::HasVoidStructInfo(n->var)) { + } else if (d->cfg->syntax_sugar && relax::HasVoidType(n->value) && + relax::HasVoidType(n->var)) { ExprDoc rhs = d->AsDoc(n->value, n_p->Attr("value")); return ExprStmtDoc(rhs); } else { ExprDoc rhs = d->AsDoc(n->value, n_p->Attr("value")); - ffi::Optional ann = StructInfoAsAnn(n->var, n_p->Attr("var"), d, n->value); + ffi::Optional ann = TypeAsAnn(n->var, n_p->Attr("var"), d, n->value); ExprDoc lhs = DefineVar(n->var, d->frames.back(), d); return AssignDoc(lhs, rhs, ann); } diff --git a/src/relax/script/printer/call.cc b/src/relax/script/printer/call.cc index 10f5a228f495..e2225af0df55 100644 --- a/src/relax/script/printer/call.cc +++ b/src/relax/script/printer/call.cc @@ -19,7 +19,7 @@ #include #include #include -#include +#include #include "./utils.h" @@ -82,7 +82,7 @@ ffi::Optional PrintCallTIRDPSPacked(const relax::Call& n, const AccessP return std::nullopt; } TVM_FFI_ICHECK(n->args.size() == 2 || n->args.size() == 3); - TVM_FFI_ICHECK(n->sinfo_args.size() == 1); + TVM_FFI_ICHECK(n->ty_args.size() == 1); ffi::Array args; ffi::Array kwargs_keys; ffi::Array kwargs_values; @@ -90,26 +90,26 @@ ffi::Optional PrintCallTIRDPSPacked(const relax::Call& n, const AccessP args.push_back(PrintCallee(n->args[0], n_p->Attr("args")->ArrayItem(0), d)); // Step 2. Print n->args[1], the input arguments args.push_back(d->AsDoc(n->args[1], n_p->Attr("args")->ArrayItem(1))); - // Step 3. Print n->sinfo_args, the output struct info - relax::StructInfo o_sinfo = n->sinfo_args[0]; - AccessPath o_sinfo_p = n_p->Attr("sinfo_args")->ArrayItem(0); + // Step 3. Print n->ty_args, the output type + tvm::Type out_ty = n->ty_args[0]; + AccessPath out_ty_p = n_p->Attr("ty_args")->ArrayItem(0); bool is_dtensor = false; - kwargs_keys.push_back("out_sinfo"); - if (const auto* o = o_sinfo.as()) { + kwargs_keys.push_back("out_ty"); + if (const auto* o = out_ty.as()) { ffi::Array fields; - AccessPath fields_p = o_sinfo_p->Attr("fields"); + AccessPath fields_p = out_ty_p->Attr("fields"); for (int i = 0, l = o->fields.size(); i < l; ++i) { - if (o->fields[i].as()) { + if (o->fields[i].as()) { is_dtensor = true; } fields.push_back(d->AsDoc(o->fields[i], fields_p->ArrayItem(i))); } kwargs_values.push_back(ListDoc(fields)); } else { - if (o_sinfo.as()) { + if (out_ty.as()) { is_dtensor = true; } - kwargs_values.push_back(d->AsDoc(o_sinfo, o_sinfo_p)); + kwargs_values.push_back(d->AsDoc(out_ty, out_ty_p)); } // for call_tir_inplace, we also need to include the inplace args @@ -321,14 +321,14 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) } } // Step 4. Print type_args - if (n->sinfo_args.size() > 0) { - AccessPath sinfo_args_p = n_p->Attr("sinfo_args"); - ffi::Array sinfo_args; - for (int i = 0, l = n->sinfo_args.size(); i < l; ++i) { - sinfo_args.push_back(d->AsDoc(n->sinfo_args[i], sinfo_args_p->ArrayItem(i))); + if (n->ty_args.size() > 0) { + AccessPath ty_args_p = n_p->Attr("ty_args"); + ffi::Array ty_args; + for (int i = 0, l = n->ty_args.size(); i < l; ++i) { + ty_args.push_back(d->AsDoc(n->ty_args[i], ty_args_p->ArrayItem(i))); } - kwargs_keys.push_back("sinfo_args"); - kwargs_values.push_back(TupleDoc(sinfo_args)); + kwargs_keys.push_back("ty_args"); + kwargs_values.push_back(TupleDoc(ty_args)); } return prefix->Call(args, kwargs_keys, kwargs_values); }); diff --git a/src/relax/script/printer/struct_info.cc b/src/relax/script/printer/dependent_type.cc similarity index 71% rename from src/relax/script/printer/struct_info.cc rename to src/relax/script/printer/dependent_type.cc index 1019cfa7e9bb..ee3aa6663cc5 100644 --- a/src/relax/script/printer/struct_info.cc +++ b/src/relax/script/printer/dependent_type.cc @@ -26,8 +26,8 @@ namespace script { namespace printer { TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch( // - "", [](relax::ObjectStructInfo n, AccessPath n_p, IRDocsifier d) -> Doc { + .set_dispatch( // + "", [](relax::ObjectType n, AccessPath n_p, IRDocsifier d) -> Doc { return Relax(d, "Object"); }); @@ -62,25 +62,24 @@ ExprDoc PrintShapeVar(const PrimExpr& e, const AccessPath& e_p, const IRDocsifie } TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch( - "", [](relax::PrimStructInfo n, AccessPath n_p, IRDocsifier d) -> Doc { - ffi::Array args; - ffi::Array kwargs_keys; - ffi::Array kwargs_values; - - if (n->value.defined()) { - kwargs_keys.push_back("value"); - kwargs_values.push_back(PrintShapeVar(n->value.value(), n_p->Attr("value"), d)); - } else { - args.push_back(LiteralDoc::DataType(n->dtype, n_p->Attr("dtype"))); - } + .set_dispatch("", [](relax::PrimType n, AccessPath n_p, IRDocsifier d) -> Doc { + ffi::Array args; + ffi::Array kwargs_keys; + ffi::Array kwargs_values; + + if (n->value.defined()) { + kwargs_keys.push_back("value"); + kwargs_values.push_back(PrintShapeVar(n->value.value(), n_p->Attr("value"), d)); + } else { + args.push_back(LiteralDoc::DataType(n->dtype, n_p->Attr("dtype"))); + } - return Relax(d, "Prim")->Call(args, kwargs_keys, kwargs_values); - }); + return Relax(d, "Prim")->Call(args, kwargs_keys, kwargs_values); + }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch( - "", [](relax::ShapeStructInfo n, AccessPath n_p, IRDocsifier d) -> Doc { + .set_dispatch( + "", [](relax::ShapeType n, AccessPath n_p, IRDocsifier d) -> Doc { if (n->values.defined()) { ffi::Array shape = n->values.value(); AccessPath shape_p = n_p->Attr("values"); @@ -95,8 +94,8 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch( // - "", [](relax::TensorStructInfo n, AccessPath n_p, IRDocsifier d) -> Doc { + .set_dispatch( // + "", [](relax::TensorType n, AccessPath n_p, IRDocsifier d) -> Doc { ffi::Array args; ffi::Array kwargs_keys; ffi::Array kwargs_values; @@ -138,22 +137,8 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch( // - "", [](relax::TupleStructInfo n, AccessPath n_p, IRDocsifier d) -> Doc { - if (n->fields.empty()) { - return Relax(d, "Tuple"); - } - ffi::Array fields_doc; - AccessPath fields_p = n_p->Attr("fields"); - for (int i = 0, l = n->fields.size(); i < l; ++i) { - fields_doc.push_back(d->AsDoc(n->fields[i], fields_p->ArrayItem(i))); - } - return Relax(d, "Tuple")->Call(fields_doc); - }); - -TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch( // - "", [](relax::FuncStructInfo n, AccessPath n_p, IRDocsifier d) -> Doc { + .set_dispatch( // + "", [](relax::FuncType n, AccessPath n_p, IRDocsifier d) -> Doc { auto ret_doc = d->AsDoc(n->ret, n_p->Attr("ret")); auto purity_doc = LiteralDoc::Boolean(n->purity, n_p->Attr("purity")); @@ -161,7 +146,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) ffi::Array keys; ffi::Array values; - if (!n->ret->IsInstance()) { + if (!n->ret->IsInstance()) { keys.push_back("ret"); values.push_back(ret_doc); } @@ -178,7 +163,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) } // TODO(@junrushao): track symbolic shape relation ffi::Array params_doc; - ffi::Array params = n->params.value(); + ffi::Array params = n->params.value(); AccessPath params_p = n_p->Attr("params"); for (int i = 0, n_params = params.size(); i < n_params; ++i) { params_doc.push_back(d->AsDoc(params[i], params_p->ArrayItem(i))); @@ -186,12 +171,11 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) return Relax(d, "Callable")->Call({TupleDoc(params_doc), ret_doc, purity_doc}); }); -TVM_REGISTER_SCRIPT_AS_REPR(relax::ObjectStructInfoNode, ReprPrintRelax); -TVM_REGISTER_SCRIPT_AS_REPR(relax::PrimStructInfoNode, ReprPrintRelax); -TVM_REGISTER_SCRIPT_AS_REPR(relax::ShapeStructInfoNode, ReprPrintRelax); -TVM_REGISTER_SCRIPT_AS_REPR(relax::TensorStructInfoNode, ReprPrintRelax); -TVM_REGISTER_SCRIPT_AS_REPR(relax::TupleStructInfoNode, ReprPrintRelax); -TVM_REGISTER_SCRIPT_AS_REPR(relax::FuncStructInfoNode, ReprPrintRelax); +TVM_REGISTER_SCRIPT_AS_REPR(relax::ObjectTypeNode, ReprPrintRelax); +TVM_REGISTER_SCRIPT_AS_REPR(relax::PrimTypeNode, ReprPrintRelax); +TVM_REGISTER_SCRIPT_AS_REPR(relax::ShapeTypeNode, ReprPrintRelax); +TVM_REGISTER_SCRIPT_AS_REPR(relax::TensorTypeNode, ReprPrintRelax); +TVM_REGISTER_SCRIPT_AS_REPR(relax::FuncTypeNode, ReprPrintRelax); } // namespace printer } // namespace script diff --git a/src/relax/script/printer/distributed.cc b/src/relax/script/printer/distributed.cc index 0a67b55af89f..f05ec8fe714a 100644 --- a/src/relax/script/printer/distributed.cc +++ b/src/relax/script/printer/distributed.cc @@ -18,7 +18,7 @@ */ #include #include -#include +#include #include "../../../script/printer/ir/utils.h" #include "./utils.h" @@ -36,15 +36,15 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch( - "", [](relax::distributed::DTensorStructInfo n, AccessPath n_p, IRDocsifier d) -> Doc { + .set_dispatch( + "", [](relax::distributed::DTensorType n, AccessPath n_p, IRDocsifier d) -> Doc { ffi::Array args; ffi::Array kwargs_keys; ffi::Array kwargs_values; bool require_kwargs = false; - if (n->tensor_sinfo->shape.defined()) { + if (n->tensor_ty->shape.defined()) { // Need to dig into ShapeExpr to preserve the `R.shape` prefix - if (const auto* shape = n->tensor_sinfo->shape.value().as()) { + if (const auto* shape = n->tensor_ty->shape.value().as()) { auto shape_expr = ffi::GetRef(shape); AccessPath shape_p = n_p->Attr("shape")->Attr("values"); ffi::Array shape_docs; @@ -54,18 +54,18 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) } args.push_back(TupleDoc(shape_docs)); } else { - args.push_back(d->AsDoc(n->tensor_sinfo->shape.value(), n_p->Attr("shape"))); + args.push_back(d->AsDoc(n->tensor_ty->shape.value(), n_p->Attr("shape"))); } } else { require_kwargs = true; } - if (!n->tensor_sinfo->IsUnknownDtype()) { + if (!n->tensor_ty->IsUnknownDtype()) { if (!require_kwargs) { - args.push_back(LiteralDoc::DataType(n->tensor_sinfo->dtype, n_p->Attr("dtype"))); + args.push_back(LiteralDoc::DataType(n->tensor_ty->dtype, n_p->Attr("dtype"))); } else { kwargs_keys.push_back("dtype"); kwargs_values.push_back( - LiteralDoc::DataType(n->tensor_sinfo->dtype, n_p->Attr("dtype"))); + LiteralDoc::DataType(n->tensor_ty->dtype, n_p->Attr("dtype"))); } } else { require_kwargs = true; @@ -82,9 +82,9 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) kwargs_keys.push_back("placement"); kwargs_values.push_back(d->AsDoc(n->placement, n_p->Attr("placement"))); } - if (!n->tensor_sinfo->shape.defined() && !n->tensor_sinfo->IsUnknownNdim()) { + if (!n->tensor_ty->shape.defined() && !n->tensor_ty->IsUnknownNdim()) { kwargs_keys.push_back("ndim"); - kwargs_values.push_back(LiteralDoc::Int(n->tensor_sinfo->ndim, n_p->Attr("ndim"))); + kwargs_values.push_back(LiteralDoc::Int(n->tensor_ty->ndim, n_p->Attr("ndim"))); } return Relax(d, "DTensor")->Call(args, kwargs_keys, kwargs_values); }); @@ -128,7 +128,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_REGISTER_SCRIPT_AS_REPR(relax::distributed::DeviceMeshNode, ReprPrintRelax); TVM_REGISTER_SCRIPT_AS_REPR(relax::distributed::PlacementNode, ReprPrintRelax); -TVM_REGISTER_SCRIPT_AS_REPR(relax::distributed::DTensorStructInfoNode, ReprPrintRelax); +TVM_REGISTER_SCRIPT_AS_REPR(relax::distributed::DTensorTypeNode, ReprPrintRelax); } // namespace printer } // namespace script } // namespace tvm diff --git a/src/relax/script/printer/expr.cc b/src/relax/script/printer/expr.cc index c8a813b8d5ab..dfce2b40b1f9 100644 --- a/src/relax/script/printer/expr.cc +++ b/src/relax/script/printer/expr.cc @@ -17,7 +17,7 @@ * under the License. */ -#include +#include #include #include @@ -137,8 +137,8 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "", [](relax::Constant n, AccessPath n_p, IRDocsifier d) -> Doc { if (ffi::Optional s = SpecialScalar(n->data, n_p->Attr("data"))) { - if (n->struct_info_.as()) { - ExprDoc ann = d->AsDoc(n->struct_info_, n_p->Attr("struct_info_")); + if (n->ty.as()) { + ExprDoc ann = d->AsDoc(n->ty, n_p->Attr("ty")); return Relax(d, "dist.const")->Call({s.value(), ann}); } return Relax(d, "const") @@ -152,7 +152,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) Doc PrintRelaxVar(relax::Var n, AccessPath p, IRDocsifier d) { if (!d->IsVarDefined(n)) { - ExprDoc ann = d->AsDoc(n->struct_info_, p->Attr("struct_info_")); + ExprDoc ann = d->AsDoc(n->ty, p->Attr("ty")); Frame f = d->frames.back(); ExprDoc var = DefineVar(n, f, d); f->stmts.push_back(AssignDoc(var, std::nullopt, ann)); diff --git a/src/relax/script/printer/function.cc b/src/relax/script/printer/function.cc index 4c0d84f9f6af..e600e5755213 100644 --- a/src/relax/script/printer/function.cc +++ b/src/relax/script/printer/function.cc @@ -22,10 +22,10 @@ namespace tvm { namespace script { namespace printer { -static bool HasDefaultExternFuncStructInfo(const relax::ExternFunc& n) { - const auto* sinfo = n->struct_info_.as(); - if (sinfo == nullptr || sinfo->params.defined() || sinfo->purity || - !sinfo->ret->IsInstance()) { +static bool HasDefaultExternFuncType(const relax::ExternFunc& n) { + const auto* ty = n->ty.as(); + if (ty == nullptr || ty->params.defined() || ty->purity || + !ty->ret->IsInstance()) { return false; } return true; @@ -66,9 +66,9 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) (*f)->func_vars = &func_vars; // Step 1. Print the return type ffi::Optional ret_type = std::nullopt; - if (const auto& func_sinfo = relax::MatchStructInfo(n)) { - ret_type = d->AsDoc(func_sinfo.value()->ret, // - n_p->Attr("struct_info_")->Attr("ret")); + if (const auto& func_ty = relax::MatchType(n)) { + ret_type = d->AsDoc(func_ty.value()->ret, // + n_p->Attr("ty")->Attr("ret")); } // Step 2. Print params ffi::Array params; @@ -78,7 +78,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) params.push_back(AssignDoc( /*lhs=*/DefineVar(n->params[i], *f, d), /*rhs=*/std::nullopt, - StructInfoAsAnn(n->params[i], params_p->ArrayItem(i), d, std::nullopt))); + TypeAsAnn(n->params[i], params_p->ArrayItem(i), d, std::nullopt))); } } // Step 3. Clean up func variables @@ -138,8 +138,8 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) "", [](relax::ExternFunc n, AccessPath n_p, IRDocsifier d) -> Doc { ffi::Array args; args.push_back(LiteralDoc::Str(n->global_symbol, n_p->Attr("global_symbol"))); - if (!HasDefaultExternFuncStructInfo(n)) { - args.push_back(d->AsDoc(n->struct_info_, n_p->Attr("struct_info_"))); + if (!HasDefaultExternFuncType(n)) { + args.push_back(d->AsDoc(n->ty, n_p->Attr("ty"))); } return Relax(d, "ExternFunc")->Call(args); }); diff --git a/src/relax/script/printer/type.cc b/src/relax/script/printer/type.cc index f5cbfcb16615..b2fd7fabe8c7 100644 --- a/src/relax/script/printer/type.cc +++ b/src/relax/script/printer/type.cc @@ -24,28 +24,6 @@ namespace tvm { namespace script { namespace printer { -TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch( // - "", [](relax::ShapeType n, AccessPath n_p, IRDocsifier d) -> Doc { - return Relax(d, "Shape") - ->Call({}, {"ndim"}, {LiteralDoc::Int(n->ndim, n_p->Attr("ndim"))}); - }); - -TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch( // - "", [](relax::ObjectType n, AccessPath n_p, IRDocsifier d) -> Doc { - return Relax(d, "Object"); - }); - -TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch( // - "", [](relax::TensorType n, AccessPath n_p, IRDocsifier d) -> Doc { - return Relax(d, "Tensor") - ->Call({}, {"ndim", "dtype"}, - {LiteralDoc::Int(n->ndim, n_p->Attr("ndim")), - LiteralDoc::DataType(n->dtype, n_p->Attr("dtype"))}); - }); - TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "", [](relax::PackedFuncType n, AccessPath n_p, IRDocsifier d) -> Doc { @@ -80,9 +58,6 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) d->AsDoc(n->ret_type, n_p->Attr("ret_type"))}); }); -TVM_REGISTER_SCRIPT_AS_REPR(relax::ShapeTypeNode, ReprPrintRelax); -TVM_REGISTER_SCRIPT_AS_REPR(relax::ObjectTypeNode, ReprPrintRelax); -TVM_REGISTER_SCRIPT_AS_REPR(relax::TensorTypeNode, ReprPrintRelax); TVM_REGISTER_SCRIPT_AS_REPR(relax::PackedFuncTypeNode, ReprPrintRelax); TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; diff --git a/src/relax/script/printer/utils.h b/src/relax/script/printer/utils.h index 607728cb5b69..ee40f4f3fe9b 100644 --- a/src/relax/script/printer/utils.h +++ b/src/relax/script/printer/utils.h @@ -22,7 +22,6 @@ #include #include #include -#include #include #include #include @@ -79,58 +78,56 @@ inline IdDoc DefineVar(const relax::Var& var, const Frame& frame, const IRDocsif return d->Define(var, frame, var->name_hint().empty() ? "v" : var->name_hint()); } -inline ffi::Optional StructInfoAsAnn(const relax::Var& v, const AccessPath& v_p, - const IRDocsifier& d, - const ffi::Optional& rhs) { - if (!v->struct_info_.defined()) { +inline ffi::Optional TypeAsAnn(const relax::Var& v, const AccessPath& v_p, + const IRDocsifier& d, + const ffi::Optional& rhs) { + if (!v->ty.defined()) { return std::nullopt; } - bool attempt_to_hide_struct_info = - !d->cfg->GetExtraConfig("relax.show_all_struct_info", true); + bool attempt_to_hide_ty = !d->cfg->GetExtraConfig("relax.show_all_ty", true); if (const auto* call = rhs.as()) { static const Op& call_tir_op = Op::Get("relax.call_tir"); static const Op& call_dps_packed_op = Op::Get("relax.call_dps_packed"); if (call->op.same_as(call_tir_op) || call->op.same_as(call_dps_packed_op)) { - attempt_to_hide_struct_info = true; + attempt_to_hide_ty = true; } } - if (attempt_to_hide_struct_info) { - ffi::Optional inferred_sinfo = std::nullopt; + if (attempt_to_hide_ty) { + ffi::Optional inferred_ty = std::nullopt; if (auto opt = rhs.as()) { auto call = opt.value(); if (auto opt = call->op.as()) { auto op = opt.value(); - static auto op_map_infer_struct_info = - Op::GetAttrMap("FInferStructInfo"); + static auto op_map_infer_ty = Op::GetAttrMap("FInferType"); auto temp_builder = relax::BlockBuilder::Create(std::nullopt); - inferred_sinfo = op_map_infer_struct_info[op](call, temp_builder); - } else if (auto opt = call->op.as()) { + inferred_ty = op_map_infer_ty[op](call, temp_builder); + } else if (auto opt = call->op.as()) { auto temp_builder = relax::BlockBuilder::Create(std::nullopt); - inferred_sinfo = - DeriveCallRetStructInfo(opt.value(), call, temp_builder, temp_builder->GetAnalyzer()); + inferred_ty = + DeriveCallRetType(opt.value(), call, temp_builder, temp_builder->GetAnalyzer()); } } else if (const auto* tuple = rhs.as()) { - inferred_sinfo = relax::TupleStructInfo(tuple->fields.Map(relax::GetStructInfo)); + inferred_ty = relax::TupleType(tuple->fields.Map(relax::GetType)); } else if (const auto* get_item = rhs.as()) { - if (auto ptr = get_item->tuple->struct_info_.as(); + if (auto ptr = get_item->tuple->ty.as(); ptr && get_item->index < static_cast(ptr->fields.size())) { - inferred_sinfo = ptr->fields[get_item->index]; + inferred_ty = ptr->fields[get_item->index]; } } else if (const auto* trivial_binding = rhs.as()) { - inferred_sinfo = trivial_binding->struct_info_.as(); + inferred_ty = trivial_binding->ty.as(); } - if (inferred_sinfo && ffi::StructuralEqual()(inferred_sinfo, v->struct_info_)) { + if (inferred_ty && ffi::StructuralEqual()(inferred_ty, v->ty)) { return std::nullopt; } } - return d->AsDoc(v->struct_info_, v_p->Attr("struct_info_")); + return d->AsDoc(v->ty, v_p->Attr("ty")); } ffi::Array PrintSeqExpr(const relax::SeqExpr& n, const AccessPath& n_p, diff --git a/src/relax/training/utils.cc b/src/relax/training/utils.cc index 294abf97d720..3f5992bbea8e 100644 --- a/src/relax/training/utils.cc +++ b/src/relax/training/utils.cc @@ -147,7 +147,7 @@ class AppendLossMutator : private ExprMutator { /*! * \brief Check the number of elements in loss_func_params is no less than num_backbone_outputs, - * and the elements in backbone_return_arr_ and loss_func_params have matched struct_info. Also + * and the elements in backbone_return_arr_ and loss_func_params have matched ty. Also * sets up var_remap_ from loss parameter Vars to backbone returned Vars. */ void CheckAndRemapLossParams(const ffi::Array& loss_func_params) { @@ -158,14 +158,14 @@ class AppendLossMutator : private ExprMutator { for (int i = 0; i < num_backbone_outputs_; ++i) { Var loss_param = loss_func_params[i]; Var backbone_ret = backbone_return_arr_[i]; - auto loss_param_sinfo = GetStructInfo(loss_param); - auto backbone_ret_sinfo = GetStructInfo(backbone_ret); + auto loss_param_ty = GetType(loss_param); + auto backbone_ret_ty = GetType(backbone_ret); - TVM_FFI_ICHECK(checker(backbone_ret_sinfo, loss_param_sinfo)) - << "The struct info of the " << i - << "-th return value of backbone function is: " << backbone_ret_sinfo - << " while the corresponding struct info of parameter of loss function is " - << loss_param_sinfo << ", which is different."; + TVM_FFI_ICHECK(checker(backbone_ret_ty, loss_param_ty)) + << "The type of the " << i + << "-th return value of backbone function is: " << backbone_ret_ty + << " while the corresponding type of parameter of loss function is " << loss_param_ty + << ", which is different."; this->var_remap_[loss_param->vid] = backbone_ret; } @@ -190,7 +190,7 @@ class AppendLossMutator : private ExprMutator { for (int i = 0; i < num_backbone_outputs_; ++i) { auto var = backbone_return_arr_[i]; if (other_outputs_var.count(var) == 0) { - auto new_var = DataflowVar(var->vid, GetStructInfo(var), var->span); + auto new_var = DataflowVar(var->vid, GetType(var), var->span); this->var_remap_[var->vid] = new_var; backbone_return_arr_.Set(i, new_var); } diff --git a/src/relax/transform/adjust_matmul_order.cc b/src/relax/transform/adjust_matmul_order.cc index 5d13dc6773d4..8e5211c3743e 100644 --- a/src/relax/transform/adjust_matmul_order.cc +++ b/src/relax/transform/adjust_matmul_order.cc @@ -101,7 +101,7 @@ std::tuple)>> if (upper_bounds || lower_bounds) { ffi::Map name_lookup; - for (const auto& tir_var : TIRVarsInStructInfo(GetStructInfo(func))) { + for (const auto& tir_var : TIRVarsInType(GetType(func))) { name_lookup.Set(tir_var->name_hint, tir_var); symbolic_var_constraints = symbolic_var_constraints && (0 <= tir_var); } @@ -144,9 +144,9 @@ std::tuple)>> } auto get_shape = [](Expr expr) -> ffi::Optional> { - auto sinfo = expr->struct_info_.as(); - if (sinfo) { - return sinfo->GetShape(); + auto ty = expr->ty.as(); + if (ty) { + return ty->GetShape(); } else { return std::nullopt; } diff --git a/src/relax/transform/allocate_workspace.cc b/src/relax/transform/allocate_workspace.cc index 13dd506a3fb5..14442156e434 100644 --- a/src/relax/transform/allocate_workspace.cc +++ b/src/relax/transform/allocate_workspace.cc @@ -61,9 +61,8 @@ class ExternFunctionRewriter : ExprMutator { // Append the workspace parameter to this function. ffi::Array new_params = func_node->params; - auto sinfo = - TensorStructInfo(ShapeExpr({IntImm::Int32(max_workspace_size_)}), DataType::UInt(8)); - Var workspace_param(name_sup_->FreshName("workspace"), sinfo); + auto ty = TensorType(ShapeExpr({IntImm::Int32(max_workspace_size_)}), DataType::UInt(8)); + Var workspace_param(name_sup_->FreshName("workspace"), ty); if (func_node->GetAttr(attr::kCodegen)) { workspace_var_param_ = workspace_param; @@ -72,8 +71,8 @@ class ExternFunctionRewriter : ExprMutator { new_params.push_back(workspace_param); auto new_attrs = func_node->attrs; new_attrs.CopyOnWrite()->dict.erase(attr::kWorkspaceSize); - return Function(new_params, VisitExpr(func_node->body), func_node->ret_struct_info, - func_node->is_pure, new_attrs); + return Function(new_params, VisitExpr(func_node->body), func_node->ret_ty, func_node->is_pure, + new_attrs); } return ExprMutator::VisitExpr_(func_node); } @@ -89,7 +88,7 @@ class ExternFunctionRewriter : ExprMutator { auto new_args = call_node->args; TVM_FFI_ICHECK(workspace_var_param_.defined()); new_args.push_back(workspace_var_param_); - return Call(new_op, new_args, call_node->attrs, call_node->sinfo_args, call_node->span); + return Call(new_op, new_args, call_node->attrs, call_node->ty_args, call_node->span); } } return ExprMutator::VisitExpr_(call_node); @@ -139,8 +138,8 @@ class WorkspaceProvider : ExprMutator { continue; } auto func = Downcast(mod_->Lookup(gvar)); - auto new_func = Function(func->params, VisitExpr(func->body), func->ret_struct_info, - func->is_pure, func->attrs); + auto new_func = + Function(func->params, VisitExpr(func->body), func->ret_ty, func->is_pure, func->attrs); builder_->UpdateFunction(gvar, new_func); } return builder_->GetContextIRModule(); @@ -175,7 +174,7 @@ class WorkspaceProvider : ExprMutator { auto new_args = call_node->args; TVM_FFI_ICHECK(workspace_var_main_.defined()); new_args.push_back(workspace_var_main_); - return Call(new_op, new_args, call_node->attrs, call_node->sinfo_args, call_node->span); + return Call(new_op, new_args, call_node->attrs, call_node->ty_args, call_node->span); } } @@ -189,7 +188,7 @@ class WorkspaceProvider : ExprMutator { size_t max_workspace_size_ = 0; /*! \brief A map from old global variables representing a function with workspace requirement to * the new ones that are transformed to take an additional workspace parameter. This is only - * needed since the struct info of the global variables changes between transformation. */ + * needed since the type of the global variables changes between transformation. */ std::unordered_map gvar_map_; std::unordered_set new_gvars_; }; diff --git a/src/relax/transform/alter_op_impl.cc b/src/relax/transform/alter_op_impl.cc index fa9db81f3aca..4bc0b457607b 100644 --- a/src/relax/transform/alter_op_impl.cc +++ b/src/relax/transform/alter_op_impl.cc @@ -48,15 +48,15 @@ static ffi::Array ConstructRangeFromShape(const ffi::Array& sha return shape.Map([](const PrimExpr& dim) { return Range(IntImm(dim.dtype(), 0), dim); }); } -static ffi::Array GetShapeFromTensorStructInfo(const TensorStructInfo& tensor_sinfo) { - auto shape = tensor_sinfo->GetShape(); +static ffi::Array GetShapeFromTensorType(const TensorType& tensor_ty) { + auto shape = tensor_ty->GetShape(); TVM_FFI_ICHECK(shape.defined()); return shape.value(); } static ffi::Array GetShapeFromTensor(const Expr& expr) { - const auto& tensor_sinfo = Downcast(expr->struct_info_); - return GetShapeFromTensorStructInfo(tensor_sinfo); + const auto& tensor_ty = Downcast(expr->ty); + return GetShapeFromTensorType(tensor_ty); } static IndexMap DeepCopyIndexMap(const IndexMap& index_map) { @@ -153,31 +153,30 @@ class AlterOpImplMutator : public ExprMutator { Tuple updated_inputs = UpdateInputs(call_tir_inputs_tuple, buffer_transforms, axis_separators, input_axis_separators); - TVM_FFI_ICHECK_EQ(call->sinfo_args.size(), 1) - << "call_tir sinfo_args.size() is expected to be 1"; - StructInfo updated_ret_sinfo = UpdateStructInfo(call->sinfo_args[0], buffer_transforms); + TVM_FFI_ICHECK_EQ(call->ty_args.size(), 1) << "call_tir ty_args.size() is expected to be 1"; + Type updated_ret_ty = UpdateOutputType(call->ty_args[0], buffer_transforms); auto updated_call = builder_->Normalize( - Call(call_tir_op_, {replacement_gv, updated_inputs}, call->attrs, {updated_ret_sinfo})); + Call(call_tir_op_, {replacement_gv, updated_inputs}, call->attrs, {updated_ret_ty})); // Now transform each of the outputs to previous layout. - return TransformOutputs(updated_call, buffer_transforms, call->sinfo_args[0], axis_separators, + return TransformOutputs(updated_call, buffer_transforms, call->ty_args[0], axis_separators, input_axis_separators); } - ffi::Array GetTensorStructInfoPerOutput(const StructInfo& output_sinfo) { - if (const auto* tensor_sinfo = output_sinfo.as()) - return {ffi::GetRef(tensor_sinfo)}; - const auto* tuple_sinfo = output_sinfo.as(); - TVM_FFI_ICHECK(tuple_sinfo); - - ffi::Array arr_tensor_sinfo; - arr_tensor_sinfo.reserve(tuple_sinfo->fields.size()); - for (const auto& sinfo : tuple_sinfo->fields) { - const auto* tensor_sinfo = sinfo.as(); - TVM_FFI_ICHECK(tensor_sinfo) << "Nested tuples in output of call_tir is not supported yet"; - arr_tensor_sinfo.push_back(ffi::GetRef(tensor_sinfo)); + ffi::Array GetTensorTypePerOutput(const Type& output_ty) { + if (const auto* tensor_ty = output_ty.as()) + return {ffi::GetRef(tensor_ty)}; + const auto* tuple_ty = output_ty.as(); + TVM_FFI_ICHECK(tuple_ty); + + ffi::Array tensor_tys; + tensor_tys.reserve(tuple_ty->fields.size()); + for (const auto& ty : tuple_ty->fields) { + const auto* tensor_ty = ty.as(); + TVM_FFI_ICHECK(tensor_ty) << "Nested tuples in output of call_tir is not supported yet"; + tensor_tys.push_back(ffi::GetRef(tensor_ty)); } - return arr_tensor_sinfo; + return tensor_tys; } bool IsScalarConstant(const Expr& expr) { @@ -246,13 +245,13 @@ class AlterOpImplMutator : public ExprMutator { } Expr TransformLayoutInverse(const Expr& expr, const IndexMap& index_map, - const TensorStructInfo& old_tensor_sinfo, + const TensorType& old_tensor_ty, const ffi::Array& axis_separator, const ffi::Array& input_axis_separator) { if (IsScalarConstant(expr) || index_map.get() == nullptr) { return expr; } - ffi::Array old_shape = GetShapeFromTensorStructInfo(old_tensor_sinfo); + ffi::Array old_shape = GetShapeFromTensorType(old_tensor_ty); ffi::Array initial_ranges = ConstructRangeFromShape(old_shape); arith::Analyzer analyzer; auto [inverse_index_map, padding_predicate] = @@ -263,10 +262,10 @@ class AlterOpImplMutator : public ExprMutator { } else { auto padded_expr = builder_->Normalize( TransformLayout(expr, inverse_index_map, axis_separator, input_axis_separator)); - const auto& tensor_sinfo = Downcast(padded_expr->struct_info_); + const auto& tensor_ty = Downcast(padded_expr->ty); - GlobalVar gv_remove_pad = GetOrCreateRemovePadOp(old_shape, tensor_sinfo->dtype); - return Call(call_tir_op_, {gv_remove_pad, Tuple({padded_expr})}, {}, {old_tensor_sinfo}); + GlobalVar gv_remove_pad = GetOrCreateRemovePadOp(old_shape, tensor_ty->dtype); + return Call(call_tir_op_, {gv_remove_pad, Tuple({padded_expr})}, {}, {old_tensor_ty}); } } @@ -318,59 +317,56 @@ class AlterOpImplMutator : public ExprMutator { return Tuple(updated_inputs); } - /*! \brief Updates output struct info */ - StructInfo UpdateStructInfo(const StructInfo& out_sinfo, - const ffi::Array& buffer_transforms) { - if (buffer_transforms.empty()) return out_sinfo; + /*! \brief Updates the call_tir output type after applying buffer transforms. */ + Type UpdateOutputType(const Type& out_ty, const ffi::Array& buffer_transforms) { + if (buffer_transforms.empty()) return out_ty; - if (out_sinfo->IsInstance()) - return UpdateStructInfo(Downcast(out_sinfo), + if (out_ty->IsInstance()) + return UpdateOutputType(Downcast(out_ty), buffer_transforms[buffer_transforms.size() - 1]); - TVM_FFI_ICHECK(out_sinfo->IsInstance()) - << "Expect output struct info of call_tir to be either TupleStructInfo or " - "TensorStructInfo, but got " - << out_sinfo; + TVM_FFI_ICHECK(out_ty->IsInstance()) + << "Expect output type of call_tir to be either TupleType or " + "TensorType, but got " + << out_ty; - const auto& tuple_sinfo = Downcast(out_sinfo); - ffi::Array sinfo_fields; - size_t first_output_index = buffer_transforms.size() - tuple_sinfo->fields.size(); + const auto& tuple_ty = Downcast(out_ty); + ffi::Array ty_fields; + size_t first_output_index = buffer_transforms.size() - tuple_ty->fields.size(); size_t i = 0; - for (const auto& si : tuple_sinfo->fields) { - TVM_FFI_ICHECK(si->IsInstance()) - << "Fields of TupleStructInfo must be TensorStructInfo for call_tir " + for (const auto& si : tuple_ty->fields) { + TVM_FFI_ICHECK(si->IsInstance()) + << "Fields of TupleType must be TensorType for call_tir " "output structinfo, but got " << si; - sinfo_fields.push_back(UpdateStructInfo(Downcast(si), - buffer_transforms[first_output_index + i++])); + ty_fields.push_back( + UpdateOutputType(Downcast(si), buffer_transforms[first_output_index + i++])); } - return TupleStructInfo(sinfo_fields); + return TupleType(ty_fields); } - /*! \brief Returns the TensorStructInfo after applying the \p transform on its shape */ - StructInfo UpdateStructInfo(const TensorStructInfo& tensor_sinfo, const IndexMap& transform) { - if (transform.get() == nullptr) return tensor_sinfo; - auto shape = GetShapeFromTensorStructInfo(tensor_sinfo); + /*! \brief Returns the TensorType after applying the \p transform on its shape */ + Type UpdateOutputType(const TensorType& tensor_ty, const IndexMap& transform) { + if (transform.get() == nullptr) return tensor_ty; + auto shape = GetShapeFromTensorType(tensor_ty); arith::Analyzer analyzer; auto new_shape = transform->MapShape(shape, analyzer); - if (tensor_sinfo->vdevice.defined()) { - return TensorStructInfo(ShapeExpr(new_shape), tensor_sinfo->dtype, - tensor_sinfo->vdevice.value()); + if (tensor_ty->vdevice.defined()) { + return TensorType(ShapeExpr(new_shape), tensor_ty->dtype, tensor_ty->vdevice.value()); } - return TensorStructInfo(ShapeExpr(new_shape), tensor_sinfo->dtype); + return TensorType(ShapeExpr(new_shape), tensor_ty->dtype); } Expr TransformOutputs( - const Expr& expr, const ffi::Array& buffer_transforms, - const StructInfo& old_struct_info, + const Expr& expr, const ffi::Array& buffer_transforms, const Type& old_ty, const ffi::Optional>>& axis_separators, const ffi::Optional>>& input_axis_separators) { if (buffer_transforms.empty()) return expr; - ffi::Array old_output_sinfo = GetTensorStructInfoPerOutput(old_struct_info); + ffi::Array old_output_ty = GetTensorTypePerOutput(old_ty); ffi::Array axis_sep, input_axis_sep; - size_t num_outputs = old_output_sinfo.size(); + size_t num_outputs = old_output_ty.size(); if (num_outputs == 0) return expr; size_t first_output_index = buffer_transforms.size() - num_outputs; @@ -385,8 +381,7 @@ class AlterOpImplMutator : public ExprMutator { ffi::Array> input_axis_separators_value = input_axis_separators.value(); input_axis_sep = input_axis_separators_value[first_output_index]; } - return TransformLayoutInverse(expr, output_map, old_output_sinfo[0], axis_sep, - input_axis_sep); + return TransformLayoutInverse(expr, output_map, old_output_ty[0], axis_sep, input_axis_sep); } // In case of more than one output, we would have to get each item of the output tuple, @@ -403,8 +398,8 @@ class AlterOpImplMutator : public ExprMutator { input_axis_sep = input_axis_separators_value[i + first_output_index]; } auto output = builder_->Normalize(TupleGetItem(expr, static_cast(i))); - transformed_outputs.push_back(TransformLayoutInverse(output, output_map, old_output_sinfo[i], - axis_sep, input_axis_sep)); + transformed_outputs.push_back( + TransformLayoutInverse(output, output_map, old_output_ty[i], axis_sep, input_axis_sep)); } return Tuple(transformed_outputs); } diff --git a/src/relax/transform/attach_global_symbol.cc b/src/relax/transform/attach_global_symbol.cc index 0e8cd722c12d..e1a4659422df 100644 --- a/src/relax/transform/attach_global_symbol.cc +++ b/src/relax/transform/attach_global_symbol.cc @@ -25,8 +25,8 @@ #include #include #include -#include #include +#include #include #include @@ -163,8 +163,8 @@ Pass AttachGlobalSymbol() { updates->Add(gvar, new_func); if (new_name.value() != gvar->name_hint) { GlobalVar new_gvar(new_name.value()); - if (auto sinfo = gvar->struct_info_.as()) { - UpdateStructInfo(new_gvar, sinfo.value()); + if (auto ty = gvar->ty.as()) { + UpdateType(new_gvar, ty.value()); } gvar_updates.Set(gvar, new_gvar); diff --git a/src/relax/transform/bind_params.cc b/src/relax/transform/bind_params.cc index c7b4cc5e9ba0..ca57295bf6e2 100644 --- a/src/relax/transform/bind_params.cc +++ b/src/relax/transform/bind_params.cc @@ -35,40 +35,39 @@ namespace relax { void MatchSymbolicVar(const Expr& arg, const Expr& constant, ffi::Map* symbolic_var_map, arith::AnalyzerObj* analyzer_) { - auto opt_arg_sinfo = MatchStructInfo(arg); - TVM_FFI_ICHECK(opt_arg_sinfo) - << "The struct info of the bound parameter is expected to be TensorStructInfo, but got: " - << GetStructInfo(arg); - auto opt_const_sinfo = MatchStructInfo(constant); + auto opt_arg_ty = MatchType(arg); + TVM_FFI_ICHECK(opt_arg_ty) + << "The type of the bound parameter is expected to be TensorType, but got: " << GetType(arg); + auto opt_const_ty = MatchType(constant); // As the constant is generated by internal codes, we use TVM_FFI_ICHECK here. - TVM_FFI_ICHECK(opt_const_sinfo) - << "The struct info of the bound weight is expected to be TensorStructInfo, but got: " - << GetStructInfo(constant); + TVM_FFI_ICHECK(opt_const_ty) + << "The type of the bound weight is expected to be TensorType, but got: " + << GetType(constant); - TensorStructInfo arg_sinfo = opt_arg_sinfo.value(); - TensorStructInfo const_sinfo = opt_const_sinfo.value(); - TVM_FFI_ICHECK(!const_sinfo->IsUnknownDtype()); - TVM_FFI_ICHECK(!const_sinfo->IsUnknownNdim()); - TVM_FFI_ICHECK(const_sinfo->shape.defined()); + TensorType arg_ty = opt_arg_ty.value(); + TensorType const_ty = opt_const_ty.value(); + TVM_FFI_ICHECK(!const_ty->IsUnknownDtype()); + TVM_FFI_ICHECK(!const_ty->IsUnknownNdim()); + TVM_FFI_ICHECK(const_ty->shape.defined()); // dtype mismatch - if (!arg_sinfo->IsUnknownDtype() && arg_sinfo->dtype != const_sinfo->dtype) { + if (!arg_ty->IsUnknownDtype() && arg_ty->dtype != const_ty->dtype) { TVM_FFI_THROW(InternalError) << "The dtype of the bound parameter is expected to be " - << arg_sinfo->dtype << ", but got: " << const_sinfo->dtype; + << arg_ty->dtype << ", but got: " << const_ty->dtype; } // ndim mismatch - if (!arg_sinfo->IsUnknownNdim() && arg_sinfo->ndim != const_sinfo->ndim) { + if (!arg_ty->IsUnknownNdim() && arg_ty->ndim != const_ty->ndim) { TVM_FFI_THROW(InternalError) << "The ndim of the bound parameter is expected to be " - << arg_sinfo->ndim << ", but got: " << const_sinfo->ndim; + << arg_ty->ndim << ", but got: " << const_ty->ndim; } - if (!arg_sinfo->shape.defined()) return; - const auto* arg_shape = arg_sinfo->shape.value().as(); - const auto* const_shape = const_sinfo->shape.value().as(); + if (!arg_ty->shape.defined()) return; + const auto* arg_shape = arg_ty->shape.value().as(); + const auto* const_shape = const_ty->shape.value().as(); TVM_FFI_ICHECK(arg_shape && const_shape) << "The shape of the bound parameter and weight is expected to be ShapeExprNode for now"; - for (int i = 0; i < arg_sinfo->ndim; ++i) { + for (int i = 0; i < arg_ty->ndim; ++i) { const PrimExpr& const_dim = const_shape->values[i]; TVM_FFI_ICHECK(tirx::is_const_int(const_dim)); if (const auto* shape_var = arg_shape->values[i].as()) { diff --git a/src/relax/transform/bind_symbolic_vars.cc b/src/relax/transform/bind_symbolic_vars.cc index 92823a690d9d..f828915004ee 100644 --- a/src/relax/transform/bind_symbolic_vars.cc +++ b/src/relax/transform/bind_symbolic_vars.cc @@ -19,7 +19,6 @@ #include #include #include -#include #include #include diff --git a/src/relax/transform/bundle_model_params.cc b/src/relax/transform/bundle_model_params.cc index 0ff22aef5e38..1d14d8f60c92 100644 --- a/src/relax/transform/bundle_model_params.cc +++ b/src/relax/transform/bundle_model_params.cc @@ -57,12 +57,12 @@ class ModelParamBundler : public ExprMutator { params.push_back(func->params[i]); } - ffi::Array param_tuple; + ffi::Array param_tuple; for (size_t i = num_input; i < func->params.size(); i++) { - param_tuple.push_back(GetStructInfo(func->params[i])); + param_tuple.push_back(GetType(func->params[i])); } - Var var_param_tuple(param_tuple_name_.value_or("model_params"), TupleStructInfo(param_tuple)); + Var var_param_tuple(param_tuple_name_.value_or("model_params"), TupleType(param_tuple)); params.push_back(var_param_tuple); for (size_t i = num_input; i < func->params.size(); i++) { diff --git a/src/relax/transform/call_tir_rewrite.cc b/src/relax/transform/call_tir_rewrite.cc index beb396462a13..0bb6c04e10e2 100644 --- a/src/relax/transform/call_tir_rewrite.cc +++ b/src/relax/transform/call_tir_rewrite.cc @@ -25,7 +25,6 @@ #include #include #include -#include #include #include #include @@ -76,27 +75,27 @@ class CallTIRMutator : public ExprMutator { bool is_inplace = (call->op == call_tir_inplace_op); const auto* inplace_attrs = call->attrs.as(); ffi::Array outs; - if (const auto& _tensor_sinfo = MatchStructInfo(expr)) { + if (const auto& tensor_ty = MatchType(expr)) { // single output case - const TensorStructInfo& tensor_sinfo = _tensor_sinfo.value(); - TVM_FFI_ICHECK(tensor_sinfo->shape.defined()) - << "the TensorStructInfo shape of call_tir has not populated"; + const TensorType& output_ty = tensor_ty.value(); + TVM_FFI_ICHECK(output_ty->shape.defined()) + << "the TensorType shape of call_tir has not populated"; int dev_index = 0; ffi::String scope = "global"; - if (tensor_sinfo->vdevice.defined()) { - dev_index = GetDeviceIndex(mod_, tensor_sinfo->vdevice.value()); - scope = tensor_sinfo->vdevice.value()->memory_scope; + if (output_ty->vdevice.defined()) { + dev_index = GetDeviceIndex(mod_, output_ty->vdevice.value()); + scope = output_ty->vdevice.value()->memory_scope; } else { dev_index = GetDeviceIndexByScope(mod_, scope); } if (!is_inplace) { - outs.push_back(builder_->Emit(Call(alloc_tensor_op, - {Downcast(tensor_sinfo->shape.value()), - DataTypeImm(tensor_sinfo->dtype), - PrimValue::Int64(dev_index), StringImm(scope)}, - Attrs(), {tensor_sinfo}), - "alloc")); + outs.push_back(builder_->Emit( + Call(alloc_tensor_op, + {Downcast(output_ty->shape.value()), DataTypeImm(output_ty->dtype), + PrimValue::Int64(dev_index), StringImm(scope)}, + Attrs(), {output_ty}), + "alloc")); } else { // if there is only one output, it must be an in-place argument, but check anyway TVM_FFI_ICHECK(inplace_attrs->inplace_indices[0] != -1) @@ -104,19 +103,19 @@ class CallTIRMutator : public ExprMutator { " be -1."; outs.push_back(Downcast(call->args[1])->fields[inplace_attrs->inplace_indices[0]]); } - } else if (const auto& _tuple_sinfo = MatchStructInfo(expr)) { + } else if (const auto& tuple_ty = MatchType(expr)) { // multiple output case - const TupleStructInfo& tuple_sinfo = _tuple_sinfo.value(); - for (size_t i = 0; i < tuple_sinfo->fields.size(); ++i) { - const auto& field = tuple_sinfo->fields[i]; - - TVM_FFI_ICHECK(field->IsInstance()) - << "call_tir expects Tuple of TensorStructInfo, but got " << field - << " as an element of TupleStructInfo"; - const auto& field_tensor = Downcast(field); + const TupleType& output_ty = tuple_ty.value(); + for (size_t i = 0; i < output_ty->fields.size(); ++i) { + const auto& field = output_ty->fields[i]; + + TVM_FFI_ICHECK(field->IsInstance()) + << "call_tir expects Tuple of TensorType, but got " << field + << " as an element of TupleType"; + const auto& field_tensor = Downcast(field); TVM_FFI_ICHECK(field_tensor->shape.defined()) - << "call_tir expects all TensorStructInfo has shape, but got " << field_tensor - << " as an element of TupleStructInfo"; + << "call_tir expects all TensorType has shape, but got " << field_tensor + << " as an element of TupleType"; int dev_index = 0; ffi::String scope = "global"; @@ -138,9 +137,9 @@ class CallTIRMutator : public ExprMutator { } } } else { - TVM_FFI_THROW(TypeError) << "The struct info of call_tir expects to be TensorStructInfo or " - "TupleStructInfo, but got" - << expr->struct_info_; + TVM_FFI_THROW(TypeError) << "The type of call_tir expects to be TensorType or " + "TupleType, but got" + << expr->ty; } ffi::Array args; diff --git a/src/relax/transform/canonicalize_bindings.cc b/src/relax/transform/canonicalize_bindings.cc index 6fb7e195b722..174576819cdc 100644 --- a/src/relax/transform/canonicalize_bindings.cc +++ b/src/relax/transform/canonicalize_bindings.cc @@ -29,8 +29,8 @@ #include #include #include -#include #include +#include #include namespace tvm { @@ -87,27 +87,27 @@ class SymbolicVarCanonicalizer : public ExprMutator { // correctly return `R.Tensor(ndim=2)`, removing all shape // information. // - // Since we know the StructInfo prior to replacing TIR variables, - // this pass can provide a better StructInfo than the generic + // Since we know the Type prior to replacing TIR variables, + // this pass can provide a better Type than the generic // handling in ExprMutator, by restoring the symbolic variables // within each branch. - auto new_sinfo = VisitExprDepStructInfoField(Downcast(op->struct_info_)); + auto new_ty = VisitExprDepTypeField(Downcast(op->ty)); ffi::StructuralEqual struct_equal; - if (!struct_equal(new_sinfo, GetStructInfo(true_b))) { - auto output_var = Var("then_branch_with_dyn", new_sinfo); + if (!struct_equal(new_ty, GetType(true_b))) { + auto output_var = Var("then_branch_with_dyn", new_ty); true_b = SeqExpr({BindingBlock({ - MatchCast(output_var, true_b, new_sinfo), + MatchCast(output_var, true_b, new_ty), })}, output_var); } - if (!struct_equal(new_sinfo, GetStructInfo(false_b))) { - auto output_var = Var("else_branch_with_dyn", new_sinfo); + if (!struct_equal(new_ty, GetType(false_b))) { + auto output_var = Var("else_branch_with_dyn", new_ty); false_b = SeqExpr({BindingBlock({ - MatchCast(output_var, false_b, new_sinfo), + MatchCast(output_var, false_b, new_ty), })}, output_var); } @@ -172,7 +172,7 @@ class CanonicalizePlanner : public ExprVisitor { // of trivial bindings, then we can replace it with a DataflowVar. for (auto var : visitor.defined_inside_dataflow_) { if (!var.as() && !visitor.used_outside_home_dataflow_.count(var)) { - DataflowVar new_var(var->name_hint(), GetStructInfo(var)); + DataflowVar new_var(var->name_hint(), GetType(var)); plan.replace_binding.Set(var->vid, new_var); plan.replace_usage.Set(var->vid, new_var); @@ -315,8 +315,7 @@ class CanonicalizePlanner : public ExprVisitor { return std::nullopt; } - auto earlier_tuple_size = - Downcast(GetStructInfo(first_element->tuple))->fields.size(); + auto earlier_tuple_size = Downcast(GetType(first_element->tuple))->fields.size(); if (earlier_tuple_size != expr_tuple->fields.size()) { return std::nullopt; } @@ -349,12 +348,11 @@ class CanonicalizePlanner : public ExprVisitor { } void VisitBinding(const Binding& binding) override { - bool has_same_struct_info = [&]() { + bool has_same_ty = [&]() { if (binding.as()) { return true; } else if (auto match_cast = binding.as()) { - return ffi::StructuralEqual()(GetStructInfo(binding->var), - GetStructInfo(match_cast->value)); + return ffi::StructuralEqual()(GetType(binding->var), GetType(match_cast->value)); } else { TVM_FFI_THROW(InternalError) << "Invalid binding type: " << binding->GetTypeKey(); } @@ -366,7 +364,7 @@ class CanonicalizePlanner : public ExprVisitor { value = unwrapped.value(); } - if (auto parent = value.as(); parent && has_same_struct_info) { + if (auto parent = value.as(); parent && has_same_ty) { trivial_bindings_.Set(binding->var, parent.value()); } @@ -534,7 +532,7 @@ class BindingCanonicalizer : public ExprMutator { if (auto* match_binding = binding.as()) { auto new_binding = MatchCast(binding->var, candidates.at(Downcast(match_binding->value)), - match_binding->struct_info); + match_binding->ty); new_bindings.push_back(new_binding); } else if (auto* var_binding = binding.as()) { auto new_binding = diff --git a/src/relax/transform/combine_parallel_matmul.cc b/src/relax/transform/combine_parallel_matmul.cc index 8e2591c0dea6..ed31bce2b564 100644 --- a/src/relax/transform/combine_parallel_matmul.cc +++ b/src/relax/transform/combine_parallel_matmul.cc @@ -23,8 +23,8 @@ #include #include #include -#include #include +#include #include #include @@ -53,9 +53,7 @@ std::unordered_map> GroupShapes( return indices_map; } -inline TensorStructInfo GetTensorSInfo(Expr e) { - return Downcast(GetStructInfo(e)); -} +inline TensorType GetTensorType(Expr e) { return Downcast(GetType(e)); } struct BranchInfo { int num_branches; @@ -136,7 +134,7 @@ ffi::TypedFunction(ffi::Map, ffi::Map matchings, ffi::Map bindings) { std::vector> rhs_shapes; for (const auto& rhs_pat : patterns.rhs) { - auto rhs_shape_opt = GetTensorSInfo(matchings[rhs_pat])->GetShape(); + auto rhs_shape_opt = GetTensorType(matchings[rhs_pat])->GetShape(); if (!rhs_shape_opt) { return ffi::Map{}; } @@ -163,7 +161,7 @@ ffi::TypedFunction(ffi::Map, ffi::MapGetShape().value()[rhs_dim - 1]; + PrimExpr split_size = GetTensorType(rhs)->GetShape().value()[rhs_dim - 1]; DFPattern pattern_to_replace = patterns_to_replace[index]; splits.push_back(SplitInfo{rhs, bias, split_size, pattern_to_replace}); } @@ -204,11 +202,11 @@ ffi::TypedFunction(ffi::Map, ffi::Mapdtype; + auto out_dtype = GetTensorType(matchings[patterns.matmul[indices[0]]])->dtype; auto matmul_combined = matmul(lhs, concat_rhs, out_dtype); if (branch_info.bias_dim) { - auto bias_dim = GetTensorSInfo(bias[0])->ndim; + auto bias_dim = GetTensorType(bias[0])->ndim; auto concat_bias = concat(Tuple(bias), bias_dim - 1); matmul_combined = add(matmul_combined, concat_bias); } @@ -237,7 +235,7 @@ ffi::TypedFunction(ffi::Map, ffi::Mapndim; + int lhs_dim = GetTensorType(lhs)->ndim; int split_axis = std::max(lhs_dim, rhs_dim) - 1; auto chunks = split(matmul_combined, sections, split_axis); @@ -294,7 +292,7 @@ std::vector GetBranchInfo(Function f) { std::optional activation = std::nullopt; if (match.value().count(bias_pat)) { - bias_dim = GetTensorSInfo(match.value()[bias_pat])->ndim; + bias_dim = GetTensorType(match.value()[bias_pat])->ndim; } for (size_t i = 0; i < activations.size(); ++i) { diff --git a/src/relax/transform/compute_prim_value.cc b/src/relax/transform/compute_prim_value.cc index 7ee6606e6e9d..70c73a09f447 100644 --- a/src/relax/transform/compute_prim_value.cc +++ b/src/relax/transform/compute_prim_value.cc @@ -47,7 +47,7 @@ class PrimValueComputeInjector : public ExprMutator { auto param_vars = tirx::UndefinedVars(node->value); tirx::Stmt body = tirx::Evaluate(tirx::Call(ret_dtype, tirx::builtin::ret(), {node->value})); - tirx::PrimFunc func(param_vars, body, PrimType(ret_dtype), {}, + tirx::PrimFunc func(param_vars, body, tvm::PrimType(ret_dtype), {}, DictAttrs({{tirx::attr::kIsHostFunc, true}, {tvm::attr::kSTir, true}})); func = s_tir::RenewDefs(func); diff --git a/src/relax/transform/convert_layout.cc b/src/relax/transform/convert_layout.cc index 2f47727301cc..d1b9d859e955 100644 --- a/src/relax/transform/convert_layout.cc +++ b/src/relax/transform/convert_layout.cc @@ -118,7 +118,7 @@ class LayoutConvertMutator : public ExprMutator { TVM_FFI_ICHECK(!NLayoutEqual()(from, LayoutDecision::InitUnknownDim()) && !NLayoutEqual()(to, LayoutDecision::InitUnknownDim())) << "Cannot convert when exactly one of the layouts is unknown"; - const auto* tensor = GetStructInfoAs(expr); + const auto* tensor = GetTypeAs(expr); TVM_FFI_ICHECK(tensor != nullptr) << "Expect a tensor, but got: " << expr; if (from.LeafValue()->layout.ndim() == to.LeafValue()->layout.ndim()) { @@ -228,7 +228,7 @@ class LayoutConvertMutator : public ExprMutator { ffi::Optional res = GetInferLayoutInfo(call_node, desired_layouts_, layout_cb_, var_layout_map_); ffi::ObjectPtr new_call = ffi::make_object(*call_node); - new_call->struct_info_ = std::nullopt; + new_call->ty = Type(); if (!res.defined() || (!IsNestedTensor(binding->var) && !binding->var->IsInstance())) { // Default policy: use the initial layout. @@ -307,35 +307,34 @@ class LayoutConvertMutator : public ExprMutator { } NLayout from_layout = InitialNLayout(binding->value); NLayout input_layout = GetNLayout(var_layout_map_, binding->value); - auto fvisitleaf = [&](const StructInfo& sinfo, std::array layouts) -> StructInfo { + auto fvisitleaf = [&](const Type& ty, std::array layouts) -> Type { NLayout from = layouts[0], to = layouts[1]; - if (NLayoutEqual()(from, to)) return sinfo; + if (NLayoutEqual()(from, to)) return ty; // If not both from and to are unknown, then none of them can be unknown. TVM_FFI_ICHECK(!NLayoutEqual()(from, LayoutDecision::InitUnknownDim()) && !NLayoutEqual()(to, LayoutDecision::InitUnknownDim())) << "Cannot convert when exactly one of the layouts is unknown"; - const TensorStructInfoNode* tsinfo = sinfo.as(); - TVM_FFI_ICHECK(tsinfo != nullptr) << "We can not set layout for non-tensor struct"; - if (!tsinfo->shape.defined()) return sinfo; - const ShapeExprNode* shape = tsinfo->shape.value().as(); - if (shape == nullptr) return sinfo; + const TensorTypeNode* tensor_ty = ty.as(); + TVM_FFI_ICHECK(tensor_ty != nullptr) << "We can not set layout for non-tensor struct"; + if (!tensor_ty->shape.defined()) return ty; + const ShapeExprNode* shape = tensor_ty->shape.value().as(); + if (shape == nullptr) return ty; TVM_FFI_ICHECK_EQ(shape->values.size(), to.LeafValue()->layout.ndim()); std::vector new_shape; for (size_t i = 0; i < shape->values.size(); ++i) { new_shape.push_back( shape->values[from.LeafValue()->layout.IndexOf(to.LeafValue()->layout[i])]); } - VDevice vdev = tsinfo->vdevice.value_or(VDevice()); - return TensorStructInfo(ShapeExpr(new_shape), tsinfo->dtype, vdev, tsinfo->span); + VDevice vdev = tensor_ty->vdevice.value_or(VDevice()); + return TensorType(ShapeExpr(new_shape), tensor_ty->dtype, vdev, tensor_ty->span); }; - StructInfo new_struct_info = TransformTupleLeaf( - binding->struct_info, std::array({from_layout, input_layout}), fvisitleaf); + Type new_ty = TransformTupleLeaf( + binding->ty, std::array({from_layout, input_layout}), fvisitleaf); // re-emit old binding if nothing changes - if (new_struct_info.same_as(binding->struct_info)) { + if (new_ty.same_as(binding->ty)) { builder_->EmitNormalized(ffi::GetRef(binding)); } else { - Var new_var = - builder_->EmitMatchCast(RewriteExpr(binding->value, input_layout), new_struct_info); + Var new_var = builder_->EmitMatchCast(RewriteExpr(binding->value, input_layout), new_ty); var_layout_map_[binding->var] = input_layout; this->var_remap_[binding->var->vid] = new_var; } diff --git a/src/relax/transform/dataflow_inplace.cc b/src/relax/transform/dataflow_inplace.cc index 271eda0d499c..062029a54294 100644 --- a/src/relax/transform/dataflow_inplace.cc +++ b/src/relax/transform/dataflow_inplace.cc @@ -171,7 +171,7 @@ class AliasAnalyzer { for (auto input : inputs) { int curr_idx = get_fresh_idx(); alias_map_[input] = {curr_idx}; - if (auto* tup_info = GetStructInfoAs(input)) { + if (auto* tup_info = GetTypeAs(input)) { InsertFreshTuple(curr_idx, tup_info); } } @@ -193,12 +193,12 @@ class AliasAnalyzer { } // Fresh tuple = each element is assumed to be a unique allocation - void InsertFreshTuple(int tup_idx, const TupleStructInfoNode* tup_info) { + void InsertFreshTuple(int tup_idx, const TupleTypeNode* tup_info) { std::vector> tuple_set; for (int i = 0; i < static_cast(tup_info->fields.size()); i++) { int curr_field = get_fresh_idx(); tuple_set.push_back({curr_field}); - if (auto* nested_tup_info = tup_info->fields[i].as()) { + if (auto* nested_tup_info = tup_info->fields[i].as()) { InsertFreshTuple(curr_field, nested_tup_info); } } @@ -251,7 +251,7 @@ class AliasAnalyzer { std::unordered_set ret; int res_idx = get_fresh_idx(); // the result may be a tuple - if (auto* tup_info_node = GetStructInfoAs(bound_var)) { + if (auto* tup_info_node = GetTypeAs(bound_var)) { InsertFreshTuple(res_idx, tup_info_node); } AddCapturedIndices(&ret, res_idx); @@ -270,7 +270,7 @@ class AliasAnalyzer { } // given the expression value, return the set of memory locations corresponding to it - // (the var the expression is being bound to is needed for struct info) + // (the var the expression is being bound to is needed for type) std::unordered_set GetAliasSet(const Expr& value, const Var& bound_var) { std::unordered_set ret; @@ -328,10 +328,10 @@ class AliasAnalyzer { return HandleMysteryCall(call_node, bound_var, true); } else if (op_node->name == "relax.call_tir") { // call_tir: can potentially return a tuple - if (auto* tuple_struct_info = call_node->sinfo_args[0].as()) { + if (auto* tuple_ty = call_node->ty_args[0].as()) { int tup_idx = get_fresh_idx(); ret.insert(tup_idx); - InsertFreshTuple(tup_idx, tuple_struct_info); + InsertFreshTuple(tup_idx, tuple_ty); } else { ret.insert(get_fresh_idx()); } @@ -344,7 +344,7 @@ class AliasAnalyzer { // If the returned value is a tuple, we'll assume it's a fresh tuple // (there may be exceptions to this too) - if (auto* tup_info = GetStructInfoAs(bound_var)) { + if (auto* tup_info = GetTypeAs(bound_var)) { int tup_idx = get_fresh_idx(); ret.insert(tup_idx); InsertFreshTuple(tup_idx, tup_info); @@ -375,13 +375,13 @@ PrimExpr NumElements(const ShapeExpr& shape) { return ret; } -// Given the struct info of the result, return any struct info nested in it +// Given the type of the result, return any type nested in it // that is eleigible to be used for in-place computations (tensors are eligible // only if all their dimensions are integer constants, tuples are eligible if // all members are eligible though we can consider only individual members separately) -std::unordered_set GatherCandidateSinfo( - const StructInfo& result_sinfo) { - if (auto* tensor_info = result_sinfo.as()) { +std::unordered_set GatherCandidateType( + const Type& result_ty) { + if (auto* tensor_info = result_ty.as()) { // don't consider void dtype (don't know the size at compile time) if (tensor_info->dtype.is_void()) { return {}; @@ -389,20 +389,20 @@ std::unordered_set GatherCa // don't consider cases where we don't know the shape at compile time // (we will use the analyzer to do best-effort analysis where there are vars) if (tensor_info->shape.as()) { - return {ffi::GetRef(tensor_info)}; + return {ffi::GetRef(tensor_info)}; } else { return {}; } - } else if (auto* tuple_info = result_sinfo.as()) { + } else if (auto* tuple_info = result_ty.as()) { // we can see if the whole tuple matches or go for any of the components - std::unordered_set ret; + std::unordered_set ret; for (auto field : tuple_info->fields) { - auto field_candidates = GatherCandidateSinfo(field); + auto field_candidates = GatherCandidateType(field); ret.insert(field_candidates.begin(), field_candidates.end()); } // at least one field should be eligible to be done in-place if (!ret.empty()) { - ret.insert(ffi::GetRef(tuple_info)); + ret.insert(ffi::GetRef(tuple_info)); } return ret; } else { @@ -411,16 +411,16 @@ std::unordered_set GatherCa } } -// Given the two struct info, return a pair of bools where the first element is true if -// the two struct info have the same number of elements and dtype and the second element is true +// Given the two type, return a pair of bools where the first element is true if +// the two type have the same number of elements and dtype and the second element is true // if the shapes match _exactly_. Performs this check recursively and ensures the -// stated condition is true for all tensor members of the struct info (return false +// stated condition is true for all tensor members of the type (return false // if a single pair of corresponding tensors does not meet the condition). -std::pair SizeMatches(const StructInfo& target_info, const StructInfo& arg_info, +std::pair SizeMatches(const Type& target_info, const Type& arg_info, const BlockBuilder& ctx) { - if (target_info.as() && arg_info.as()) { - auto target_tensor = Downcast(target_info); - auto arg_tensor = Downcast(arg_info); + if (target_info.as() && arg_info.as()) { + auto target_tensor = Downcast(target_info); + auto arg_tensor = Downcast(arg_info); if (target_tensor->shape.defined() && target_tensor->shape.as() && arg_tensor->shape.defined() && arg_tensor->shape.as()) { if (target_tensor->dtype != arg_tensor->dtype) { @@ -446,9 +446,9 @@ std::pair SizeMatches(const StructInfo& target_info, const StructInf } else { return {false, false}; } - } else if (target_info.as() && arg_info.as()) { - auto target_tup = Downcast(target_info); - auto arg_tup = Downcast(arg_info); + } else if (target_info.as() && arg_info.as()) { + auto target_tup = Downcast(target_info); + auto arg_tup = Downcast(arg_info); if (target_tup->fields.size() != arg_tup->fields.size()) { return {false, false}; } @@ -456,10 +456,9 @@ std::pair SizeMatches(const StructInfo& target_info, const StructInf for (size_t i = 0; i < target_tup->fields.size(); i++) { // if members aren't either tuples or tensors, simply skip them, // since they don't matter for in-place computations - if (!(target_tup->fields[i].as() || - target_tup->fields[i].as()) && - !(arg_tup->fields[i].as() || - arg_tup->fields[i].as())) { + if (!(target_tup->fields[i].as() || + target_tup->fields[i].as()) && + !(arg_tup->fields[i].as() || arg_tup->fields[i].as())) { continue; } auto [field_size_match, field_exact_match] = @@ -690,17 +689,17 @@ FindInplaceOpportunities(const DataflowBlock& block, const ffi::Array& inpu std::unordered_set candidates; std::unordered_set exact_match_candidates; - auto target_sinfo = GatherCandidateSinfo(GetStructInfo(defined_var)); + auto target_ty = GatherCandidateType(GetType(defined_var)); // can't be done in-place, ignore - if (target_sinfo.empty()) { + if (target_ty.empty()) { continue; } // Check that at least one argument matches size with the result for (size_t j = 0; j < call_node->args.size(); j++) { auto arg = call_node->args[j]; - for (auto target : target_sinfo) { - auto [matches_size, matches_exactly] = SizeMatches(target, GetStructInfo(arg), ctx); + for (auto target : target_ty) { + auto [matches_size, matches_exactly] = SizeMatches(target, GetType(arg), ctx); if (matches_size) { candidates.insert(static_cast(j)); if (matches_exactly) { @@ -921,8 +920,7 @@ class ModuleInplaceTransformer : public ExprMutator { return; } Expr new_value = ReplaceBoundCall(binding_ref); - builder_->EmitNormalized( - MatchCast(binding->var, new_value, binding->struct_info, binding->span)); + builder_->EmitNormalized(MatchCast(binding->var, new_value, binding->ty, binding->span)); } // Given the call and indices of arguments that could be done in-place, diff --git a/src/relax/transform/decompose_ops.cc b/src/relax/transform/decompose_ops.cc index 0175b4a6aa1a..2d9030defe01 100644 --- a/src/relax/transform/decompose_ops.cc +++ b/src/relax/transform/decompose_ops.cc @@ -23,8 +23,8 @@ #include #include #include -#include #include +#include #include #include @@ -34,10 +34,10 @@ namespace tvm { namespace relax { -TensorStructInfo MatchTensorStructInfo(Expr data) { - auto _sinfo = MatchStructInfo(data); - TVM_FFI_ICHECK(_sinfo.defined()) << "Expect data to be a tensor, but get " << GetStructInfo(data); - return _sinfo.value(); +TensorType MatchTensorType(Expr data) { + auto _ty = MatchType(data); + TVM_FFI_ICHECK(_ty.defined()) << "Expect data to be a tensor, but get " << GetType(data); + return _ty.value(); } Expr ExpandToMatchInput(Expr data, int ndim, ffi::Array axes) { @@ -58,23 +58,23 @@ Tuple DecomposeBatchNorm(const Call& call) { TVM_FFI_ICHECK_NOTNULL(attrs); Expr data = call->args[0]; - TensorStructInfo sinfo = MatchTensorStructInfo(data); + TensorType ty = MatchTensorType(data); Expr gamma = call->args[1]; Expr beta = call->args[2]; - Expr moving_mean = ExpandToMatchInput(call->args[3], sinfo->ndim, {attrs->axis}); - Expr moving_var = ExpandToMatchInput(call->args[4], sinfo->ndim, {attrs->axis}); + Expr moving_mean = ExpandToMatchInput(call->args[3], ty->ndim, {attrs->axis}); + Expr moving_var = ExpandToMatchInput(call->args[4], ty->ndim, {attrs->axis}); // output = (x - mean) / sqrt(var + epsilon) * gamma + beta - Expr epsilon = MakeConstantScalar(attrs->epsilon, sinfo->dtype); + Expr epsilon = MakeConstantScalar(attrs->epsilon, ty->dtype); Expr sqrt_var = sqrt(add(moving_var, epsilon)); Expr out = divide(subtract(data, moving_mean), sqrt_var); if (attrs->scale) { - out = multiply(out, ExpandToMatchInput(gamma, sinfo->ndim, {attrs->axis})); + out = multiply(out, ExpandToMatchInput(gamma, ty->ndim, {attrs->axis})); } if (attrs->center) { - out = add(out, ExpandToMatchInput(beta, sinfo->ndim, {attrs->axis})); + out = add(out, ExpandToMatchInput(beta, ty->ndim, {attrs->axis})); } return Tuple({out, call->args[3], call->args[4]}); @@ -91,10 +91,10 @@ Expr MutateBatchNormForTraining(Call call) { Expr moving_mean = call->args[3]; Expr moving_var = call->args[4]; - TensorStructInfo sinfo = MatchTensorStructInfo(data); + TensorType ty = MatchTensorType(data); ffi::Array reduce_axes; - for (int i = 0; i < sinfo->ndim; ++i) { + for (int i = 0; i < ty->ndim; ++i) { if (i != attrs->axis) { reduce_axes.push_back(i); } @@ -103,8 +103,8 @@ Expr MutateBatchNormForTraining(Call call) { Expr data_mean = mean(data, reduce_axes, false); Expr data_var = variance(data, reduce_axes, false); - Expr momentum = MakeConstantScalar(attrs->momentum, sinfo->dtype); - Expr one_minus_mom = MakeConstantScalar(1 - attrs->momentum, sinfo->dtype); + Expr momentum = MakeConstantScalar(attrs->momentum, ty->dtype); + Expr one_minus_mom = MakeConstantScalar(1 - attrs->momentum, ty->dtype); Expr new_moving_mean = add(multiply(one_minus_mom, moving_mean), multiply(momentum, data_mean)); Expr new_moving_var = add(multiply(one_minus_mom, moving_var), multiply(momentum, data_var)); @@ -120,7 +120,7 @@ Expr DecomposeLayerNorm(const Call& call) { TVM_FFI_ICHECK_NOTNULL(attrs); Expr data = call->args[0]; - TensorStructInfo sinfo = MatchTensorStructInfo(data); + TensorType ty = MatchTensorType(data); Expr gamma = call->args[1]; Expr beta = call->args[2]; @@ -128,7 +128,7 @@ Expr DecomposeLayerNorm(const Call& call) { Expr data_var = variance(data, attrs->axes, true); // output = (x - mean) / sqrt(var + epsilon) * gamma + beta - Expr epsilon = MakeConstantScalar(attrs->epsilon, sinfo->dtype); + Expr epsilon = MakeConstantScalar(attrs->epsilon, ty->dtype); Expr sqrt_var = sqrt(add(data_var, epsilon)); Expr out = divide(subtract(data, data_mean), sqrt_var); @@ -143,27 +143,27 @@ Expr DecomposeLayerNorm(const Call& call) { } Expr TensorToShape(const Call& call_node, const BlockBuilder& builder) { - TVM_FFI_ICHECK(call_node->struct_info_.defined()); + TVM_FFI_ICHECK(call_node->ty.defined()); Expr expr = call_node->args[0]; - const ShapeStructInfoNode* sinfo = GetStructInfoAs(call_node); - TVM_FFI_ICHECK(sinfo); + const ShapeTypeNode* ty = GetTypeAs(call_node); + TVM_FFI_ICHECK(ty); // call builtin function that converts tensor to shape tuple // TODO(@sunggg): Register operator for "vm.builtin.tensor_to_shape" static const Op& call_pure_packed_op = Op::Get("relax.call_pure_packed"); Var call = builder->Emit(Call(call_pure_packed_op, {ExternFunc("vm.builtin.tensor_to_shape"), expr}, {}, - {ffi::GetRef(sinfo)})); + {ffi::GetRef(ty)})); // Operators like reshape take the output of `TensorToShape` as their output shape. // Because TOPI expects to have such output shape in symbolic shape at least (i.e., // ffi::Array), we define symbolic variables and returns them as a ShapeExpr. ffi::Array shape_var; - for (int i = 0; i < sinfo->ndim; i++) { + for (int i = 0; i < ty->ndim; i++) { shape_var.push_back(tirx::Var("x", DataType::Int(64))); } // bind symbolic variables to the shape tuple - relax::Var var("y", ShapeStructInfo(shape_var)); - builder->EmitNormalized(MatchCast(var, call, ShapeStructInfo(shape_var))); + relax::Var var("y", ShapeType(shape_var)); + builder->EmitNormalized(MatchCast(var, call, ShapeType(shape_var))); return ShapeExpr(shape_var); } @@ -264,7 +264,7 @@ Pass ApplyDecomposeToFunction(Pass pass, ffi::String func_name) { // Replace non-target functions with stubs to keep references intact. keep_original_version.insert(gvar->name_hint); func = relax::ExternFunc("dummy_" + std::string(gvar->name_hint)); - func->struct_info_ = gvar->struct_info_; + func->ty = gvar->ty; } subset->Add(gvar, func); } diff --git a/src/relax/transform/eliminate_common_subexpr.cc b/src/relax/transform/eliminate_common_subexpr.cc index 4b54dd44224f..3364175ace56 100644 --- a/src/relax/transform/eliminate_common_subexpr.cc +++ b/src/relax/transform/eliminate_common_subexpr.cc @@ -42,22 +42,22 @@ namespace { /* \brief Lookup key for subexpression replacements * * The lookup key must contain the expression being bound, along with - * the struct info used for a match cast, if applicable. Using + * the type used for a match cast, if applicable. Using * `MatchCast` with StructuralEqual and StructuralHash would be almost * correct, but acts as a point of definition for symbolic variables - * within the output struct info. As a result, it would erroneously + * within the output type. As a result, it would erroneously * de-duplicate `R.match_cast(A, R.Tensor([m,n]))` and * `R.match_cast(A, R.Tensor([p,q]))`, even though they define * different symbolic variables. */ struct ReplacementKey { tvm::relax::Expr bound_value; - tvm::ffi::Optional match_cast = std::nullopt; + tvm::ffi::Optional match_cast = std::nullopt; explicit ReplacementKey(const tvm::relax::Binding& binding) : bound_value(GetBoundValue(binding)) { if (const auto* ptr = binding.as()) { - match_cast = ptr->struct_info; + match_cast = ptr->ty; } } @@ -116,7 +116,7 @@ class CommonSubexprEliminator : public ExprMutator { if (binding.as()) { return VarBinding(binding->var, bound_value); } else if (auto match_cast = binding.as()) { - return MatchCast(binding->var, bound_value, match_cast->struct_info); + return MatchCast(binding->var, bound_value, match_cast->ty); } else { TVM_FFI_THROW(InternalError) << "Binding must be either VarBinding or MatchCast, " << "but was " << binding->GetTypeKey(); @@ -170,8 +170,7 @@ class CommonSubexprEliminator : public ExprMutator { Expr true_branch = VisitWithInnerScope(op->true_branch); Expr false_branch = VisitWithInnerScope(op->false_branch); if (op->cond.same_as(cond) && op->true_branch.same_as(true_branch) && - op->false_branch.same_as(false_branch) && - VisitAndCheckStructInfoFieldUnchanged(op->struct_info_)) { + op->false_branch.same_as(false_branch) && VisitAndCheckTypeFieldUnchanged(op->ty)) { return ffi::GetRef(op); } else { return If(cond, true_branch, false_branch, op->span); diff --git a/src/relax/transform/expand_tuple_arguments.cc b/src/relax/transform/expand_tuple_arguments.cc index 2db4a6dba3dd..4eec2e794707 100644 --- a/src/relax/transform/expand_tuple_arguments.cc +++ b/src/relax/transform/expand_tuple_arguments.cc @@ -36,9 +36,9 @@ ffi::Optional ExpandParams(Function func) { bool is_exposed = func->attrs.GetAttr(tvm::attr::kGlobalSymbol).has_value(); if (is_exposed) return std::nullopt; - bool has_tuple_param = std::any_of( - func->params.begin(), func->params.end(), - [](const Var& param) -> bool { return param->struct_info_.as(); }); + bool has_tuple_param = + std::any_of(func->params.begin(), func->params.end(), + [](const Var& param) -> bool { return param->ty.as(); }); if (!has_tuple_param) return std::nullopt; @@ -46,13 +46,13 @@ ffi::Optional ExpandParams(Function func) { ffi::Array bindings; std::function expand_param = [&](const Var& param) { - if (auto sinfo = param->struct_info_.as()) { + if (auto ty = param->ty.as()) { ffi::Array internal_tuple; - for (size_t i = 0; i < sinfo->fields.size(); i++) { + for (size_t i = 0; i < ty->fields.size(); i++) { auto name = static_cast(std::stringstream() << param->name_hint() << "_" << i) .str(); - Var new_param(name, sinfo->fields[i]); + Var new_param(name, ty->fields[i]); internal_tuple.push_back(new_param); expand_param(new_param); } @@ -66,14 +66,13 @@ ffi::Optional ExpandParams(Function func) { expand_param(param); } - FuncStructInfo new_sinfo(params.Map([](const auto& var) { return GetStructInfo(var); }), - func->ret_struct_info, - Downcast(func->struct_info_)->purity); + FuncType new_ty(params.Map([](const auto& var) { return GetType(var); }), func->ret_ty, + Downcast(func->ty)->purity); auto write_ptr = func.CopyOnWrite(); write_ptr->params = params; write_ptr->body = SeqExpr({BindingBlock(bindings)}, func->body); - write_ptr->struct_info_ = new_sinfo; + write_ptr->ty = new_ty; return func; } @@ -92,8 +91,8 @@ class TupleExpander : public ExprMutator { ffi::Array new_args; std::function expand_arg = [&](const Expr& arg) { - if (auto sinfo = arg->struct_info_.as()) { - for (size_t i = 0; i < sinfo->fields.size(); i++) { + if (auto ty = arg->ty.as()) { + for (size_t i = 0; i < ty->fields.size(); i++) { expand_arg(TupleGetItem(arg, i)); } } else { @@ -133,7 +132,7 @@ Pass ExpandTupleArguments() { if (auto opt = ExpandParams(func.value())) { auto new_func = opt.value(); GlobalVar new_gvar(gvar->name_hint); - new_gvar->struct_info_ = new_func->struct_info_; + new_gvar->ty = new_func->ty; gvar_replacements[gvar] = new_gvar; new_callees[new_gvar] = new_func; } diff --git a/src/relax/transform/fold_constant.cc b/src/relax/transform/fold_constant.cc index 75b1b09bd48b..bdf56d7f3416 100644 --- a/src/relax/transform/fold_constant.cc +++ b/src/relax/transform/fold_constant.cc @@ -45,21 +45,21 @@ class ConstantFolder : public ExprMutator { explicit ConstantFolder(IRModule ctx_module) : ExprMutator(ctx_module) {} /*! - * \brief Pattern match the shape inside the given struct info to a + * \brief Pattern match the shape inside the given type to a * constant shape and get runtime shape tuple from it. - * \param struct_info The given struct info whose shape inside is to be casted. + * \param ty The given type whose shape inside is to be casted. * \return The runtime shape tuple, or nullopt if it is not a constant shape. - * \note Only TensorStructInfo is supported. Returns std::nullopt - * if the input struct info is not TensorStructInfo. + * \note Only TensorType is supported. Returns std::nullopt + * if the input type is not TensorType. */ - static ffi::Optional MatchConstShape(const StructInfo& struct_info) { - const auto* tensor_sinfo = struct_info.as(); - if (tensor_sinfo == nullptr) { + static ffi::Optional MatchConstShape(const Type& ty) { + const auto* tensor_ty = ty.as(); + if (tensor_ty == nullptr) { return std::nullopt; } - const auto* shape = tensor_sinfo->shape.as(); - TVM_FFI_ICHECK(shape != nullptr) << "struct info given by call_tir should have ShapeExpr shape"; + const auto* shape = tensor_ty->shape.as(); + TVM_FFI_ICHECK(shape != nullptr) << "type given by call_tir should have ShapeExpr shape"; std::vector shape_values; for (const auto v : shape->values) { @@ -139,7 +139,7 @@ class ConstantFolder : public ExprMutator { * of the program. */ static bool ExprContainsTensor(const Expr& expr) { - if (GetStructInfo(expr).as()) { + if (GetType(expr).as()) { return true; } if (const auto* tuple = expr.as()) { @@ -161,10 +161,10 @@ class ConstantFolder : public ExprMutator { const auto* call = expr.as(); if (!call) return true; - const auto* tensor_sinfo = call->struct_info_.as(); - if (!tensor_sinfo) return true; + const auto* tensor_ty = call->ty.as(); + if (!tensor_ty) return true; - auto opt_shape = tensor_sinfo->GetShape(); + auto opt_shape = tensor_ty->GetShape(); if (!opt_shape) return true; int64_t num_elements = 1; @@ -229,21 +229,21 @@ class ConstantFolder : public ExprMutator { // Returns std::nullopt on failure. ffi::Optional ConstEvaluateCallTIRTuple(tirx::PrimFunc tir_func, ffi::Array arr_args, - const TupleStructInfoNode* tuple_sinfo) { + const TupleTypeNode* tuple_ty) { ffi::Optional func = GetCachedBuild(tir_func); if (!func) return std::nullopt; DLDevice cpu_dev = {DLDeviceType::kDLCPU, 0}; - size_t num_outputs = tuple_sinfo->fields.size(); + size_t num_outputs = tuple_ty->fields.size(); // Match shapes and dtypes for all output fields. std::vector ret_tensors; for (size_t i = 0; i < num_outputs; ++i) { - ffi::Optional shape = MatchConstShape(tuple_sinfo->fields[i]); + ffi::Optional shape = MatchConstShape(tuple_ty->fields[i]); if (!shape) return std::nullopt; - auto tensor_sinfo = Downcast(tuple_sinfo->fields[i]); - if (tensor_sinfo->IsUnknownDtype()) return std::nullopt; - ret_tensors.push_back(runtime::Tensor::Empty(shape.value(), tensor_sinfo->dtype, cpu_dev)); + auto tensor_ty = Downcast(tuple_ty->fields[i]); + if (tensor_ty->IsUnknownDtype()) return std::nullopt; + ret_tensors.push_back(runtime::Tensor::Empty(shape.value(), tensor_ty->dtype, cpu_dev)); } // Pack input args + all output tensors. @@ -275,20 +275,20 @@ class ConstantFolder : public ExprMutator { TVM_FFI_ICHECK(call->args[1].as()) << "call_tir.args[1] must be Tuple"; ffi::Optional> arr_args = MatchConstArrayArgs(call->args[1].as()->fields); - TVM_FFI_ICHECK_EQ(call->sinfo_args.size(), 1) << "call_tir should have exactly one sinfo arg"; + TVM_FFI_ICHECK_EQ(call->ty_args.size(), 1) << "call_tir should have exactly one ty arg"; if (!func || !arr_args) return {}; - // Handle tuple output: sinfo_args[0] is a TupleStructInfo. - if (const auto* tuple_sinfo = call->sinfo_args[0].as()) { - return ConstEvaluateCallTIRTuple(func.value(), arr_args.value(), tuple_sinfo); + // Handle tuple output: ty_args[0] is a TupleType. + if (const auto* tuple_ty = call->ty_args[0].as()) { + return ConstEvaluateCallTIRTuple(func.value(), arr_args.value(), tuple_ty); } // Handle single tensor output. - ffi::Optional shape = MatchConstShape(call->sinfo_args[0]); + ffi::Optional shape = MatchConstShape(call->ty_args[0]); if (shape) { - TensorStructInfo ret_sinfo = Downcast(call->struct_info_); - return ConstEvaluateCallTIR(func.value(), arr_args.value(), shape.value(), ret_sinfo->dtype) + TensorType ret_ty = Downcast(call->ty); + return ConstEvaluateCallTIR(func.value(), arr_args.value(), shape.value(), ret_ty->dtype) .value_or({}); } return {}; @@ -341,7 +341,7 @@ class ConstantFolder : public ExprMutator { new_args.push_back(arg); } post_call = - Call(post_call->op, new_args, post_call->attrs, post_call->sinfo_args, post_call->span); + Call(post_call->op, new_args, post_call->attrs, post_call->ty_args, post_call->span); // If we are in a dataflow block, we can fold ops. if (builder_->CurrentBlockIsDataFlow()) { diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc index b86a2110c3a6..840e9be0000a 100644 --- a/src/relax/transform/fuse_ops.cc +++ b/src/relax/transform/fuse_ops.cc @@ -33,8 +33,8 @@ #include #include #include -#include #include +#include #include #include #include @@ -414,7 +414,7 @@ class FunctionCreator : public ExprMutator { const Tuple& args = Downcast(call->args[1]); for (const Expr& arg : args->fields) { CheckDefAndUpdateParam(arg); - TVM_FFI_ICHECK(GetStructInfoAs(arg) == nullptr); + TVM_FFI_ICHECK(GetTypeAs(arg) == nullptr); } // TODO(tvm-team): handle shape expr } else { @@ -433,12 +433,12 @@ class FunctionCreator : public ExprMutator { if (auto tuple = arg.as()) { for (const Expr& tup_arg : tuple->fields) { CheckDefAndUpdateParam(tup_arg); - TVM_FFI_ICHECK(GetStructInfoAs(tup_arg) == nullptr); + TVM_FFI_ICHECK(GetTypeAs(tup_arg) == nullptr); } } else { CheckDefAndUpdateParam(arg); } - if (GetStructInfoAs(arg) != nullptr) { + if (GetTypeAs(arg) != nullptr) { // The argument is fully referenced. Thus we remove it from the mapping. partially_used_tuple_params_.erase(arg.get()); } @@ -497,14 +497,14 @@ class FunctionCreator : public ExprMutator { int param_idx = tuple_param_idx_[tuple_arg]; Var param = params_[param_idx]; ffi::String param_name = params_[param_idx]->name_hint(); - TupleStructInfo param_sinfo = Downcast(tuple_arg->struct_info_); + TupleType param_ty = Downcast(tuple_arg->ty); ffi::Array item_args; ffi::Array item_params; item_args.reserve(item_indices.size()); item_params.reserve(item_indices.size()); for (int item_idx : item_indices) { - Var item_param(param_name + "_" + std::to_string(item_idx), param_sinfo->fields[item_idx]); + Var item_param(param_name + "_" + std::to_string(item_idx), param_ty->fields[item_idx]); item_args.push_back(TupleGetItem(ffi::GetRef(tuple_arg), item_idx)); item_params.push_back(item_param); tuple_get_item_remap[tuple_arg][item_idx] = item_param; @@ -559,20 +559,20 @@ class FunctionCreator : public ExprMutator { body = builder_->Normalize(body); body = builder_->Normalize(SeqExpr({new_block}, body)); group_attrs.Set(tvm::relax::attr::kPrimitive, true); - Function function = Function(/*params=*/params_, // - /*body=*/body, // - /*ret_struct_info=*/std::nullopt, // - /*is_pure=*/true, // + Function function = Function(/*params=*/params_, // + /*body=*/body, // + /*ret_ty=*/std::nullopt, // + /*is_pure=*/true, // /*attrs=*/DictAttrs(group_attrs)); ffi::Array free_vars = FreeSymbolicVars(function).Map([](const tirx::Var& var) -> PrimExpr { return var; }); if (!free_vars.empty()) { - params_.push_back(Var("tir_vars", ShapeStructInfo(free_vars))); + params_.push_back(Var("tir_vars", ShapeType(free_vars))); arguments_.push_back(ShapeExpr(free_vars)); - function = Function(/*params=*/params_, // - /*body=*/body, // - /*ret_struct_info=*/std::nullopt, // - /*is_pure=*/true, // + function = Function(/*params=*/params_, // + /*body=*/body, // + /*ret_ty=*/std::nullopt, // + /*is_pure=*/true, // /*attrs=*/DictAttrs(group_attrs)); } function_ = SymbolicVarRenewMutator::Renew(function); @@ -618,16 +618,16 @@ class FunctionCreator : public ExprMutator { ffi::String name = var != nullptr ? var->name_hint() : ffi::String("param_" + std::to_string(n_param_for_const_++)); - StructInfo param_sinfo = GetStructInfo(expr); + Type param_ty = GetType(expr); if (!IsInlinableConstants(expr)) { - Var param(std::move(name), GetStructInfo(expr)); + Var param(std::move(name), GetType(expr)); arguments_.push_back(expr); params_.push_back(param); } // Mark the tuple parameter is partially referenced in the beginning. // We will remove it from the mapping once we find it is fully referenced. - if (param_sinfo->IsInstance()) { + if (param_ty->IsInstance()) { partially_used_tuple_params_[expr.get()] = {}; tuple_param_idx_[expr.get()] = static_cast(arguments_.size()) - 1; } @@ -759,17 +759,17 @@ class OperatorFusor : public ExprMutator { } bool IsTupleOutput(Function f) { - auto sinfo = GetStructInfo(f).as(); - TVM_FFI_ICHECK(sinfo); - return sinfo->ret->IsInstance(); + auto ty = GetType(f).as(); + TVM_FFI_ICHECK(ty); + return ty->ret->IsInstance(); } bool IsNestedTupleOutput(Function f) { if (!IsTupleOutput(f)) return false; - auto tup = GetStructInfo(f).as()->ret.as(); + auto tup = GetType(f).as()->ret.as(); for (const auto& field : tup->fields) { - if (field->IsInstance()) return true; + if (field->IsInstance()) return true; } return false; } @@ -829,8 +829,7 @@ class OperatorFusor : public ExprMutator { // needs to be remapped to the output of TupleGetItem after the corresponding tuple is // emitted. if (IsTupleOutput(func) && tuple_get_indices_.count(binding->var.get())) { - if (!GetStructInfo(binding->var)->IsInstance() || - IsNestedTupleOutput(func)) { + if (!GetType(binding->var)->IsInstance() || IsNestedTupleOutput(func)) { // When binding->var itself is a tuple, we do not need to remap this variable to the // output of TupleGetItem unless the output is a nested tuple. pending_tuple_get[group].push_back(binding->var); @@ -1276,8 +1275,8 @@ class CompositeFunctionAnnotator : public ExprMutator { auto new_body = VisitWithNewScope(func->body, func->params); if (!new_body.same_as(func->body)) { - auto new_func = Function(func->params, new_body, func->ret_struct_info, func->is_pure, - func->attrs, func->span); + auto new_func = Function(func->params, new_body, func->ret_ty, func->is_pure, func->attrs, + func->span); builder_->UpdateFunction(gv, new_func); } } @@ -1322,7 +1321,7 @@ class CompositeFunctionAnnotator : public ExprMutator { ffi::Array params; for (auto v : func_node->params) { - Var new_v(v->name_hint(), GetStructInfo(v)); + Var new_v(v->name_hint(), GetType(v)); param_vars.push_back(new_v); params.push_back(new_v); } @@ -1330,8 +1329,8 @@ class CompositeFunctionAnnotator : public ExprMutator { // We cannot delegate to `ExprMutator::VisitExpr_(const FunctionNode*)` at this point, as it // would recursively visit the Call node. However, we are still required to generate // well-formed Relax IR. As a result, we need to build the SeqExpr ourselves. - Var local_func_var("local_func", GetStructInfo(f_inner)); - Var output_var("output", f_inner->ret_struct_info); + Var local_func_var("local_func", GetType(f_inner)); + Var output_var("output", f_inner->ret_ty); SeqExpr new_body({BindingBlock({ VarBinding(local_func_var, f_inner), VarBinding(output_var, Call(local_func_var, params)), @@ -1339,7 +1338,7 @@ class CompositeFunctionAnnotator : public ExprMutator { output_var); // pure if the inner func is pure (no need to force purity if it's forced for the inner func) - return Function(param_vars, new_body, func_node->ret_struct_info, f_inner->is_pure); + return Function(param_vars, new_body, func_node->ret_ty, f_inner->is_pure); } private: diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc index 2b9320dcfd29..308c9da4c353 100644 --- a/src/relax/transform/fuse_tir.cc +++ b/src/relax/transform/fuse_tir.cc @@ -21,8 +21,8 @@ #include #include #include -#include #include +#include #include #include #include @@ -592,7 +592,7 @@ class FusedTIRConstructor : public ExprVisitor { // printed, it's more readable when done explicitly. Since // Buffer is used more than param it gets the name with better // readability. - tirx::Var param = tirx::Var("p_" + buffer->name, PrimType(DataType::Handle())); + tirx::Var param = tirx::Var("p_" + buffer->name, tvm::PrimType(DataType::Handle())); func_info_.params.push_back(param); func_info_.buffer_map.Set(param, buffer); } @@ -637,7 +637,7 @@ class FusedTIRConstructor : public ExprVisitor { } tirx::Var param = - tirx::Var("p_output" + std::to_string(out_idx), PrimType(DataType::Handle())); + tirx::Var("p_output" + std::to_string(out_idx), tvm::PrimType(DataType::Handle())); out_idx++; func_info_.buffer_map.Set(param, buffers[i]); func_info_.params.push_back(param); @@ -733,12 +733,11 @@ class FusedTIRConstructor : public ExprVisitor { if (it != func_info_.expr2buffers.end()) { int begin_buf_idx = 0; int end_buf_idx = 0; - const TupleStructInfo& tuple_sinfo = - Downcast(tuple_get_item->tuple->struct_info_); + const TupleType& tuple_ty = Downcast(tuple_get_item->tuple->ty); for (int i = 0; i < tuple_get_item->index; ++i) { - begin_buf_idx += GetTotalTensorSize(tuple_sinfo->fields[i]); + begin_buf_idx += GetTotalTensorSize(tuple_ty->fields[i]); } - end_buf_idx = begin_buf_idx + GetTotalTensorSize(tuple_sinfo->fields[tuple_get_item->index]); + end_buf_idx = begin_buf_idx + GetTotalTensorSize(tuple_ty->fields[tuple_get_item->index]); func_info_.expr2buffers.Set( ffi::GetRef(tuple_get_item), {(*it).second.begin() + begin_buf_idx, (*it).second.begin() + end_buf_idx}); @@ -771,32 +770,30 @@ class FusedTIRConstructor : public ExprVisitor { static const Op& call_tir_op_ = Op::Get("relax.call_tir"); static const Op& call_tir_inplace_op_ = Op::Get("relax.call_tir_inplace"); TVM_FFI_ICHECK(call->op.same_as(call_tir_op_) || call->op.same_as(call_tir_inplace_op_)); - TVM_FFI_ICHECK_EQ(call->sinfo_args.size(), 1); + TVM_FFI_ICHECK_EQ(call->ty_args.size(), 1); auto get_tensor_shape = - [](const TensorStructInfoNode* sinfo) { - const auto* shape_expr = sinfo->shape.as(); + [](const TensorTypeNode* ty) { + const auto* shape_expr = ty->shape.as(); TVM_FFI_ICHECK(shape_expr) << "FuseTIR expects all parameters are Tensors with symbolic shape."; return shape_expr->values; }; - if (const auto* tuple_sinfo = call->sinfo_args[0].as()) { + if (const auto* tuple_ty = call->ty_args[0].as()) { ffi::Array> shapes; - for (const StructInfo& field : tuple_sinfo->fields) { - const auto* tensor_sinfo = field.as(); - TVM_FFI_ICHECK(tensor_sinfo) - << "CallTIR sinfo_args are expected to be TensorStructInfo or Tuple of " - "TensorStructInfo, but got " - << call->sinfo_args[0]; - shapes.push_back(get_tensor_shape(tensor_sinfo)); + for (const Type& field : tuple_ty->fields) { + const auto* tensor_ty = field.as(); + TVM_FFI_ICHECK(tensor_ty) << "CallTIR ty_args are expected to be TensorType or Tuple of " + "TensorType, but got " + << call->ty_args[0]; + shapes.push_back(get_tensor_shape(tensor_ty)); } return shapes; - } else if (const auto* tensor_sinfo = call->sinfo_args[0].as()) { - return {get_tensor_shape(tensor_sinfo)}; + } else if (const auto* tensor_ty = call->ty_args[0].as()) { + return {get_tensor_shape(tensor_ty)}; } else { - TVM_FFI_ICHECK(tensor_sinfo) - << "CallTIR sinfo_args are expected to be TensorStructInfo or Tuple of " - "TensorStructInfo, but got " - << call->sinfo_args[0]; + TVM_FFI_ICHECK(tensor_ty) << "CallTIR ty_args are expected to be TensorType or Tuple of " + "TensorType, but got " + << call->ty_args[0]; throw; } } @@ -933,7 +930,7 @@ class FusedTIRConstructor : public ExprVisitor { } return unique_name; }; - // Update buffer with new symbolic shape according to the sinfo + // Update buffer with new symbolic shape according to the ty auto n = ffi::make_object(*buffer.get()); n->shape = output_shapes[i]; n->name = unify_name_hints(); @@ -951,22 +948,22 @@ class FusedTIRConstructor : public ExprVisitor { /*! * \brief Collect TIR func params and buffers with specified relax type and shape - * \param struct_info The struct info + * \param ty The type * \param name_hint The name hint for params and buffers * \param out The vector into which to collect the params/buffers */ static void CollectPrimFuncParams(const Var& relax_param, std::vector>* out, const ffi::Optional& tir_buffer_param) { - auto struct_info = GetStructInfo(relax_param); + auto ty = GetType(relax_param); - TVM_FFI_CHECK(!struct_info.as(), InternalError) + TVM_FFI_CHECK(!ty.as(), InternalError) << "All tuple parameters should be expanded before this point in FuseTIR. " - << "However, parameter " << relax_param << " has struct info " << struct_info; + << "However, parameter " << relax_param << " has type " << ty; auto name_hint = relax_param->name_hint(); - if (const auto* tensor = struct_info.as()) { + if (const auto* tensor = ty.as()) { // Case 1. The relax param is a Tensor, we directly create a tirx var and buffer const auto* shape_expr = tensor->shape.as(); TVM_FFI_ICHECK(shape_expr) << "FuseTIR expects all Tensor parameters have a known shape."; @@ -981,12 +978,12 @@ class FusedTIRConstructor : public ExprVisitor { } out->push_back(std::move(buffer)); - } else if (const auto* prim_value = struct_info.as()) { + } else if (const auto* prim_value = ty.as()) { // Case 2. The relax param is a scalar, we directly create a tirx var TVM_FFI_ICHECK(prim_value->value->IsInstance()); out->push_back(Downcast(prim_value->value)); - } else if (const auto* shape_expr = struct_info.as()) { + } else if (const auto* shape_expr = ty.as()) { // Case 3. The relax param is a tuple of scalars, each represented as a tirx var for (const auto& var : shape_expr->values.value()) { TVM_FFI_ICHECK(var->IsInstance()); @@ -995,7 +992,7 @@ class FusedTIRConstructor : public ExprVisitor { } else { TVM_FFI_THROW(TypeError) << "The param type of PrimFunc is expected to be " << "Tensor, PrimValue, or ShapeExpr, " - << "but got " << struct_info->GetTypeKey(); + << "but got " << ty->GetTypeKey(); } } @@ -1029,17 +1026,17 @@ class FusedTIRConstructor : public ExprVisitor { } /*! \brief Get DynTensor numbers from recursive Tuples. */ - static size_t GetTotalTensorSize(const StructInfo& sinfo) { - if (sinfo.as()) { + static size_t GetTotalTensorSize(const Type& ty) { + if (ty.as()) { return 1; - } else if (const auto* tuple_sinfo = sinfo.as()) { + } else if (const auto* tuple_ty = ty.as()) { size_t num = 0; - for (const StructInfo& sinfo : tuple_sinfo->fields) { - num += GetTotalTensorSize(sinfo); + for (const Type& ty : tuple_ty->fields) { + num += GetTotalTensorSize(ty); } return num; } else { - TVM_FFI_THROW(InternalError) << "TensorType and TupleType are expect, but got: " << sinfo; + TVM_FFI_THROW(InternalError) << "TensorType and TupleType are expect, but got: " << ty; return 0; } } @@ -1153,7 +1150,7 @@ class TIRFuseMutator : public ExprMutator { const auto& [prim_func, indices] = FusedTIRConstructor::GetFusedTIR(mod, old_gvar); GlobalVar new_gvar(old_gvar->name_hint); - UpdateStructInfo(new_gvar, GetStructInfo(prim_func)); + UpdateType(new_gvar, GetType(prim_func)); mod->Remove(old_gvar); updates->Add(new_gvar, prim_func); @@ -1194,12 +1191,12 @@ class TIRFuseMutator : public ExprMutator { using ExprMutator::VisitExpr_; // Get shape from call tirx - static Expr GetCallTIRShape(StructInfo sinfo) { - if (auto* tuple = sinfo.as()) { - ffi::Array fields = tuple->fields.Map([&](StructInfo x) { return GetCallTIRShape(x); }); + static Expr GetCallTIRShape(Type ty) { + if (auto* tuple = ty.as()) { + ffi::Array fields = tuple->fields.Map([&](Type x) { return GetCallTIRShape(x); }); return Tuple(fields); } else { - auto* tensor = sinfo.as(); + auto* tensor = ty.as(); TVM_FFI_ICHECK(tensor) << "FuseTIR can only take tensor or tuple type"; auto* shape_expr = tensor->shape.as(); TVM_FFI_ICHECK(shape_expr) << "FuseTIR requires all intermediate values have shape"; @@ -1240,26 +1237,24 @@ class TIRFuseMutator : public ExprMutator { ffi::Array tir_vars; for (size_t i = 0; i < call->args.size(); ++i) { auto arg = call->args[i]; - auto sinfo = GetStructInfo(arg); + auto ty = GetType(arg); - TVM_FFI_CHECK(!relax_func->params[i]->struct_info_->IsInstance() && - !sinfo.as(), - InternalError) + TVM_FFI_CHECK( + !relax_func->params[i]->ty->IsInstance() && !ty.as(), + InternalError) << "All tuple parameters should be expanded before this point in FuseTIR. " - << "However, argument " << arg << " with struct info " << arg->struct_info_ - << " is passed as argument " << i << " to Primitive Relax function " << old_gvar - << ", which expects parameter " << relax_func->params[i] << " to have struct info " - << relax_func->params[i]->struct_info_; - - if (const auto* shape = sinfo.as()) { - TVM_FFI_ICHECK(shape->values.defined()) - << "FuseTIR requires all shape input has struct_info value."; + << "However, argument " << arg << " with type " << arg->ty << " is passed as argument " + << i << " to Primitive Relax function " << old_gvar << ", which expects parameter " + << relax_func->params[i] << " to have type " << relax_func->params[i]->ty; + + if (const auto* shape = ty.as()) { + TVM_FFI_ICHECK(shape->values.defined()) << "FuseTIR requires all shape input has ty value."; for (const PrimExpr& prim_value : shape->values.value()) { TVM_FFI_ICHECK(prim_value->IsInstance()) << "All shape inputs are expected to be single tirx var."; tir_vars.push_back(prim_value); } - } else if (const auto* prim_value = sinfo.as()) { + } else if (const auto* prim_value = ty.as()) { TVM_FFI_ICHECK(prim_value->value.defined()) << "FuseTIR requires all R.Prim arguments to have a known value."; PrimExpr expr = prim_value->value.value(); @@ -1286,7 +1281,7 @@ class TIRFuseMutator : public ExprMutator { inplace_attrs->inplace_indices = replacement.inplace_indices; call_attrs = Attrs(inplace_attrs); } - return Call(call_op, call_args, call_attrs, {GetStructInfo(call)}); + return Call(call_op, call_args, call_attrs, {GetType(call)}); } private: diff --git a/src/relax/transform/gradient.cc b/src/relax/transform/gradient.cc index 7e134cb1d396..25ea737ae726 100644 --- a/src/relax/transform/gradient.cc +++ b/src/relax/transform/gradient.cc @@ -76,7 +76,7 @@ class CallTIRWithGradEliminator : private ExprMutator { if (call_node->op != Op::Get("relax.call_tir_with_grad")) { return ExprMutator::VisitExpr_(call_node); } - return Call(Op::Get("relax.call_tir"), call_node->args, {}, call_node->sinfo_args, + return Call(Op::Get("relax.call_tir"), call_node->args, {}, call_node->ty_args, call_node->span); } }; @@ -264,7 +264,7 @@ class CheckpointGenerator : private ExprMutator { Expr new_arg = this->VisitExpr(arg); call_args.push_back(new_arg); } - return Call(new_op, call_args, call_node->attrs, call_node->sinfo_args); + return Call(new_op, call_args, call_node->attrs, call_node->ty_args); } BlockBuilder builder_; @@ -303,8 +303,8 @@ class BackwardBindingGenerator : private ExprVisitor { BackwardBindingGenerator generator(builder, cp_collector, checkpoint_generator); // Initialize the adjoint of target_var as ones op. We have already checked the target. - auto* target_sinfo = GetStructInfoAs(target_var); - generator.UpdateAdjoint(target_var, ones(target_sinfo->shape.value(), target_sinfo->dtype)); + auto* target_ty = GetTypeAs(target_var); + generator.UpdateAdjoint(target_var, ones(target_ty->shape.value(), target_ty->dtype)); // Do reverse-mode ad, so visit bindings backwards for (auto it = forward_block->bindings.rbegin(); it != forward_block->bindings.rend(); ++it) { @@ -374,13 +374,13 @@ class BackwardBindingGenerator : private ExprVisitor { grad_func(checkpoint_var, Downcast(checkpoint_call), adjoint_var, builder_) .cast(); Tuple args = Downcast(call->args[1]); - auto* tuple_sinfo = GetStructInfoAs(partials); - if (!tuple_sinfo) { + auto* tuple_ty = GetTypeAs(partials); + if (!tuple_ty) { // result_var is a tensor TVM_FFI_ICHECK(args->fields.size() == 1); UpdateAdjoint(args->fields[0], partials); } else { - TVM_FFI_ICHECK(args->fields.size() == tuple_sinfo->fields.size()); + TVM_FFI_ICHECK(args->fields.size() == tuple_ty->fields.size()); for (int i = 0; i < static_cast(args->fields.size()); ++i) { UpdateAdjoint(args->fields[i], TupleGetItem(partials, i)); } @@ -419,12 +419,12 @@ class BackwardBindingGenerator : private ExprVisitor { void VisitBinding_(const VarBindingNode* binding, const TupleGetItemNode* tuple_get_item) final { TVM_FFI_ICHECK(tuple_get_item->tuple->IsInstance()) << "The tuple field of a TupleGetItem is not bound to a Var"; - auto* tuple_sinfo = GetStructInfoAs(tuple_get_item->tuple); - TVM_FFI_ICHECK(tuple_sinfo) << "The tuple field of a TupleGetItem must has a TupleStructInfo"; + auto* tuple_ty = GetTypeAs(tuple_get_item->tuple); + TVM_FFI_ICHECK(tuple_ty) << "The tuple field of a TupleGetItem must has a TupleType"; const Var& tuple_var = Downcast(tuple_get_item->tuple); if (adjoint_var_map_.count(tuple_var) == 0) { - auto nested_zeros = Downcast(NestedZeros(ffi::GetRef(tuple_sinfo))); + auto nested_zeros = Downcast(NestedZeros(ffi::GetRef(tuple_ty))); auto tuple_fields = nested_zeros->fields; tuple_fields.Set(tuple_get_item->index, adjoint_var_map_[binding->var]); EmitAdjoint(tuple_var, Tuple(tuple_fields), false); @@ -496,7 +496,7 @@ class BackwardBindingGenerator : private ExprVisitor { // zeros auto it = adjoint_var_map_.find(var); if (it == adjoint_var_map_.end()) { - UpdateAdjoint(var, NestedZeros(GetStructInfo(var))); + UpdateAdjoint(var, NestedZeros(GetType(var))); } Var adjoint_output_var = EmitAdjoint(var, adjoint_var_map_[var], true); out_adjoints.push_back(adjoint_output_var); @@ -532,37 +532,37 @@ class BackwardBindingGenerator : private ExprVisitor { } static AdjointMsg ExprToAdjointMsg(Expr expr) { - return MapToNestedMsgBySInfo(expr, [](Expr leaf) { - TVM_FFI_ICHECK(GetStructInfoAs(leaf)) - << "The leaf of adjoint: " << leaf << " should have StructInfo and be a Tensor."; + return MapToNestedMsgByType(expr, [](Expr leaf) { + TVM_FFI_ICHECK(GetTypeAs(leaf)) + << "The leaf of adjoint: " << leaf << " should have Type and be a Tensor."; return AdjointMsg(leaf); }); } - // Create a zeros Expr with specified struct info - // When sinfo is TupleStructInfo, we would create a (nested) Tuple containing zeros - static Expr NestedZeros(const StructInfo& sinfo) { - AdjointMsg msg = MapToNestedMsg(sinfo, [](StructInfo sinfo) { - auto* tensor_sinfo = sinfo.as(); - TVM_FFI_ICHECK(tensor_sinfo) << "The leaf of adjoint should be a Tensor."; - TVM_FFI_ICHECK(tensor_sinfo->shape.defined()) << "Missing shape when building zeros tuple."; - const Expr& init = zeros(tensor_sinfo->shape.value(), tensor_sinfo->dtype); + // Create a zeros Expr with specified type + // When ty is TupleType, we would create a (nested) Tuple containing zeros + static Expr NestedZeros(const Type& ty) { + AdjointMsg msg = MapToNestedMsg(ty, [](Type ty) { + auto* tensor_ty = ty.as(); + TVM_FFI_ICHECK(tensor_ty) << "The leaf of adjoint should be a Tensor."; + TVM_FFI_ICHECK(tensor_ty->shape.defined()) << "Missing shape when building zeros tuple."; + const Expr& init = zeros(tensor_ty->shape.value(), tensor_ty->dtype); return init; }); return AdjointMsgToExpr(msg); } - // Return lhs + rhs. Requires lhs and rhs has the same StructInfo. + // Return lhs + rhs. Requires lhs and rhs has the same Type. // Use NestedMsg to handle cases when lhs and rhs are tuples. static Expr TupleAwareAdd(const Expr& lhs, const Expr& rhs) { AdjointMsg res = CombineNestedMsg( ExprToAdjointMsg(lhs), ExprToAdjointMsg(rhs), [](Expr l_leaf, Expr r_leaf) { - auto* sinfo = GetStructInfoAs(l_leaf); - TVM_FFI_ICHECK(sinfo) << "The leaf of adjoint should have StructInfo and be a Tensor."; - TVM_FFI_ICHECK(GetStructInfoAs(r_leaf)) - << "The leaf of adjoint should have StructInfo and be a Tensor."; + auto* ty = GetTypeAs(l_leaf); + TVM_FFI_ICHECK(ty) << "The leaf of adjoint should have Type and be a Tensor."; + TVM_FFI_ICHECK(GetTypeAs(r_leaf)) + << "The leaf of adjoint should have Type and be a Tensor."; Expr res = add(l_leaf, r_leaf); - UpdateStructInfo(res, ffi::GetRef(sinfo)); + UpdateType(res, ffi::GetRef(ty)); return res; }); return AdjointMsgToExpr(res); @@ -575,11 +575,11 @@ class BackwardBindingGenerator : private ExprVisitor { // Step 2)t2_new = t2 + increment (TupleAwareAdd) // Step 3) tuple_new = (t1, t2_new, t3) static Expr AddInTuple(const Expr& tuple, int index, const Expr& increment) { - auto* sinfo = GetStructInfoAs(tuple); - TVM_FFI_ICHECK(sinfo) << "The first argument of AddInTuple should have tuple struct info."; - TVM_FFI_ICHECK(index >= 0 && index < static_cast(sinfo->fields.size())); + auto* ty = GetTypeAs(tuple); + TVM_FFI_ICHECK(ty) << "The first argument of AddInTuple should have tuple type."; + TVM_FFI_ICHECK(index >= 0 && index < static_cast(ty->fields.size())); ffi::Array res; - for (size_t i = 0; i < sinfo->fields.size(); ++i) { + for (size_t i = 0; i < ty->fields.size(); ++i) { Expr field; if (const auto* expr_tuple = tuple.as()) { field = expr_tuple->fields[i]; @@ -705,14 +705,14 @@ class GradientMutator : private ExprMutator { return builder_->EndBlock(); } - static bool IsFloatTensorSInfo(const StructInfo& sinfo) { - auto* tensor_sinfo = sinfo.as(); - return tensor_sinfo && tensor_sinfo->dtype.is_float(); + static bool IsFloatTensorType(const Type& ty) { + auto* tensor_ty = ty.as(); + return tensor_ty && tensor_ty->dtype.is_float(); } // When the return value is a Var, it is the target; // when the return value is a Tuple, the target is the target_index-th field of the return value - // Check that the target should be a Var of scalar tensor struct_info + // Check that the target should be a Var of scalar tensor ty void CheckAndSetTarget(const Expr& e, int target_index) { if (auto* var = e.as()) { TVM_FFI_ICHECK_EQ(target_index, 0) @@ -736,11 +736,11 @@ class GradientMutator : private ExprMutator { "value of the given function is " << e; } - auto target_sinfo = GetStructInfo(target_var_); - TVM_FFI_ICHECK(IsScalarTensor(target_sinfo) && IsFloatTensorSInfo(target_sinfo)) - << "The differentiation target must be a float scalar (0-dim Tensor), but the StructInfo " + auto target_ty = GetType(target_var_); + TVM_FFI_ICHECK(IsScalarTensor(target_ty) && IsFloatTensorType(target_ty)) + << "The differentiation target must be a float scalar (0-dim Tensor), but the Type " "of the given target " - << target_var_ << " is " << GetStructInfo(target_var_); + << target_var_ << " is " << GetType(target_var_); } // Check every Var in require_grads: @@ -761,10 +761,10 @@ class GradientMutator : private ExprMutator { var_set.emplace(var->vid); mapped_vars.push_back((*it).second); - TVM_FFI_ICHECK(IsNestedTensorConditioned(GetStructInfo(var), IsFloatTensorSInfo)) + TVM_FFI_ICHECK(IsNestedTensorConditioned(GetType(var), IsFloatTensorType)) << "Only Tensors of floating point dtype or Tuples of float " - "Tensors can require gradients, but the StructInfo of Var " - << var->name_hint() << " is " << GetStructInfo(var); + "Tensors can require gradients, but the Type of Var " + << var->name_hint() << " is " << GetType(var); } return mapped_vars; } diff --git a/src/relax/transform/gradient_simplifier.cc b/src/relax/transform/gradient_simplifier.cc index ba1302129a1e..0b59b220c959 100644 --- a/src/relax/transform/gradient_simplifier.cc +++ b/src/relax/transform/gradient_simplifier.cc @@ -84,11 +84,11 @@ class GradientSimplifier : private ExprMutator { if (call_node->op != Op::Get("relax.permute_dims")) { return false; } - auto sinfo = MatchStructInfo(call_node->args[0]); - if (!sinfo) { + auto ty = MatchType(call_node->args[0]); + if (!ty) { return false; } - auto ndim = sinfo.value()->ndim; + auto ndim = ty.value()->ndim; if (ndim == kUnknownNDim || ndim == 1) { return false; } @@ -107,9 +107,9 @@ class GradientSimplifier : private ExprMutator { // Return permute_dims(expr). Generate the axes needed. static Expr GetTransposeOf(const Expr& expr) { - auto sinfo = MatchStructInfo(expr); - TVM_FFI_ICHECK(sinfo); - auto ndim = sinfo.value()->ndim; + auto ty = MatchType(expr); + TVM_FFI_ICHECK(ty); + auto ndim = ty.value()->ndim; if (ndim == 1) { return expr; } @@ -177,8 +177,8 @@ class GradientSimplifier : private ExprMutator { // operation should be eliminated // Skip matmuls with 1-dim input because in these cases we cannot simply transpose the input - auto a_dim = MatchStructInfo(prev_call_node->args[0]).value()->ndim; - auto b_dim = MatchStructInfo(prev_call_node->args[1]).value()->ndim; + auto a_dim = MatchType(prev_call_node->args[0]).value()->ndim; + auto b_dim = MatchType(prev_call_node->args[1]).value()->ndim; if (a_dim == 1 || b_dim == 1) { return reemit_and_return(); } diff --git a/src/relax/transform/infer_amp_utils.cc b/src/relax/transform/infer_amp_utils.cc index 01bd47d96073..41c6cfe5ae42 100644 --- a/src/relax/transform/infer_amp_utils.cc +++ b/src/relax/transform/infer_amp_utils.cc @@ -22,19 +22,19 @@ namespace tvm { namespace relax { -NType NTypeFrom(const StructInfo& sinfo, DataType dtype) { - auto fmapleaf = [&](const StructInfo& sinfo) -> NType { - const auto* tensor = sinfo.as(); - TVM_FFI_ICHECK(tensor) << "Expected TensorStructInfo, but got " << sinfo; +NType NTypeFrom(const Type& ty, DataType dtype) { + auto fmapleaf = [&](const Type& ty) -> NType { + const auto* tensor = ty.as(); + TVM_FFI_ICHECK(tensor) << "Expected TensorType, but got " << ty; if (dtype == DataType::Void()) return NType(DLDataTypeToString(tensor->dtype)); else return NType(DLDataTypeToString(dtype)); }; - return MapToNestedMsg(sinfo, fmapleaf); + return MapToNestedMsg(ty, fmapleaf); } -NType NTypeFrom(const Expr& expr, DataType dtype) { return NTypeFrom(GetStructInfo(expr), dtype); } +NType NTypeFrom(const Expr& expr, DataType dtype) { return NTypeFrom(GetType(expr), dtype); } NType NTypeMerge(const NType& a, const NType& b) { auto fcombine = [&](const ffi::String& a_str, const ffi::String& b_str) -> ffi::String { diff --git a/src/relax/transform/infer_amp_utils.h b/src/relax/transform/infer_amp_utils.h index 0acae0981cbd..faa33edd4a18 100644 --- a/src/relax/transform/infer_amp_utils.h +++ b/src/relax/transform/infer_amp_utils.h @@ -57,8 +57,8 @@ struct NTypeEqual { } }; -// Construct a NType from an StructInfo -NType NTypeFrom(const StructInfo& sinfo, DataType dtype = DataType::Void()); +// Construct a NType from an Type +NType NTypeFrom(const Type& ty, DataType dtype = DataType::Void()); // Construct a NType from an Expr NType NTypeFrom(const Expr& expr, DataType dtype = DataType::Void()); diff --git a/src/relax/transform/infer_layout_utils.cc b/src/relax/transform/infer_layout_utils.cc index 16e6b901e295..4c85c545eb14 100644 --- a/src/relax/transform/infer_layout_utils.cc +++ b/src/relax/transform/infer_layout_utils.cc @@ -102,17 +102,17 @@ LayoutDecision InitialLayoutDecision(int ndim) { return SLayout("ABCDEFGHIJKLMNOPQRSTUVWXYZ").SubLayout(0, ndim); } -NLayout InitialNLayout(const StructInfo& sinfo) { - auto fmapleaf = [&](const StructInfo& sinfo) -> NLayout { - if (const auto* tensor_sinfo = sinfo.as()) { - return NLayout(InitialLayoutDecision(tensor_sinfo->ndim)); +NLayout InitialNLayout(const Type& ty) { + auto fmapleaf = [&](const Type& ty) -> NLayout { + if (const auto* tensor_ty = ty.as()) { + return NLayout(InitialLayoutDecision(tensor_ty->ndim)); } return LayoutDecision::InitUnknownDim(); }; - return MapToNestedMsg(sinfo, fmapleaf); + return MapToNestedMsg(ty, fmapleaf); } -NLayout InitialNLayout(const Expr& expr) { return InitialNLayout(GetStructInfo(expr)); } +NLayout InitialNLayout(const Expr& expr) { return InitialNLayout(GetType(expr)); } LayoutDecision GetLayoutDecision(const VarLayoutMap& var_layout_map, const Expr& arg) { NLayout nlayout = GetNLayout(var_layout_map, arg); diff --git a/src/relax/transform/infer_layout_utils.h b/src/relax/transform/infer_layout_utils.h index 724464a945c9..03579063df1d 100644 --- a/src/relax/transform/infer_layout_utils.h +++ b/src/relax/transform/infer_layout_utils.h @@ -175,15 +175,15 @@ SLayout InitialLayout(int ndim); LayoutDecision InitialLayoutDecision(int ndim); /*! - * \brief Initialize a nested layout decision given the struct info. - * \param sinfo The sinfo. + * \brief Initialize a nested layout decision given the type. + * \param ty The ty. * \return The initialized nested layout decision. */ -NLayout InitialNLayout(const StructInfo& sinfo); +NLayout InitialNLayout(const Type& ty); /*! * \brief Initialize a nested layout decision given expression - * \param sinfo The expr + * \param ty The expr * \return The initialized nested layout decision. */ NLayout InitialNLayout(const Expr& expr); diff --git a/src/relax/transform/inline_functions.cc b/src/relax/transform/inline_functions.cc index 8189f4fac8b2..f7863cb83a3b 100644 --- a/src/relax/transform/inline_functions.cc +++ b/src/relax/transform/inline_functions.cc @@ -128,7 +128,7 @@ class FunctionInliner : public ExprMutator { // // This implementation uses Option 4. - Var param_var(func->params[i]->name_hint(), args[i]->struct_info_.as()); + Var param_var(func->params[i]->name_hint(), args[i]->ty.as()); param_bindings.push_back(VarBinding(param_var, args[i])); param_map.Set(func->params[i], param_var); } diff --git a/src/relax/transform/kill_after_last_use.cc b/src/relax/transform/kill_after_last_use.cc index f96616e5d10f..660c20b34110 100644 --- a/src/relax/transform/kill_after_last_use.cc +++ b/src/relax/transform/kill_after_last_use.cc @@ -53,7 +53,7 @@ class UnusedTrivialBindingRemover : public ExprMutator { } void VisitBinding_(const MatchCastNode* binding) override { if (binding->value.as() && - ffi::StructuralEqual()(GetStructInfo(binding->var), GetStructInfo(binding->value))) { + ffi::StructuralEqual()(GetType(binding->var), GetType(binding->value))) { has_trivial_binding.insert(binding->var.get()); } ExprVisitor::VisitBinding_(binding); @@ -118,14 +118,13 @@ class CollectLastUsage : public ExprVisitor { // In the future, this may be handled more easily at the // CodeGenVM level. bool stored_in_vm_register = - !(visitor.constant_tensors_.count(var) || var->struct_info_.as() || - var->struct_info_.as() || - var->struct_info_.as()); + !(visitor.constant_tensors_.count(var) || var->ty.as() || + var->ty.as() || var->ty.as()); if (!is_output && !already_killed) { if (visitor.storage_objects_.count(var)) { output[last_usage_point].storage.push_back(var); - } else if (var->struct_info_.as() && stored_in_vm_register) { + } else if (var->ty.as() && stored_in_vm_register) { output[last_usage_point].tensors.push_back(var); } else if (stored_in_vm_register) { output[last_usage_point].objects.push_back(var); @@ -197,8 +196,8 @@ class CollectLastUsage : public ExprVisitor { std::unordered_map last_usage_of_; // Storage objects, eligible for R.vm.kill_object. This cannot be - // determined solely from the StructInfo, because the - // `R.*.alloc_storage` operators return ObjectStructInfo + // determined solely from the Type, because the + // `R.*.alloc_storage` operators return ObjectType std::unordered_set storage_objects_; // Constants, which do not have a VM register, and may *not* have diff --git a/src/relax/transform/lambda_lift.cc b/src/relax/transform/lambda_lift.cc index 251d5f2d7237..0e27da5f9053 100644 --- a/src/relax/transform/lambda_lift.cc +++ b/src/relax/transform/lambda_lift.cc @@ -285,7 +285,7 @@ class LambdaLifter : public ExprMutator { ffi::Array typed_captured_vars; ffi::Map rebinding_map; for (auto free_var : captured_vars) { - Var var = Var(free_var->name_hint(), GetStructInfo(free_var), free_var->span); + Var var = Var(free_var->name_hint(), GetType(free_var), free_var->span); typed_captured_vars.push_back(var); rebinding_map.Set(free_var, var); } @@ -298,12 +298,11 @@ class LambdaLifter : public ExprMutator { auto gvar_lifted_func = GlobalVar(lift_func_name); { - auto func_sinfo = Downcast(func_node->struct_info_); + auto func_ty = Downcast(func_node->ty); if (is_closure) { - func_sinfo = FuncStructInfo(lifted_func_params.Map(GetStructInfo), func_sinfo->ret, - func_sinfo->purity); + func_ty = FuncType(lifted_func_params.Map(GetType), func_ty->ret, func_ty->purity); } - UpdateStructInfo(gvar_lifted_func, func_sinfo); + UpdateType(gvar_lifted_func, func_ty); } Expr body = func_node->body; @@ -321,16 +320,16 @@ class LambdaLifter : public ExprMutator { } body = this->VisitWithNewScope(body, lifted_func_params); - StructInfo ret_struct_info = GetStructInfo(body); + Type ret_ty = GetType(body); body = Bind(body, rebinding_map); Function lifted_func; if (lifted_func_params.same_as(func_node->params) && body.same_as(func_node->body) && - ret_struct_info.same_as(func_node->ret_struct_info)) { + ret_ty.same_as(func_node->ret_ty)) { lifted_func = ffi::GetRef(func_node); } else { lifted_func = - Function(lifted_func_params, body, ret_struct_info, func_node->is_pure, func_node->attrs); + Function(lifted_func_params, body, ret_ty, func_node->is_pure, func_node->attrs); } TVM_FFI_ICHECK(lifted_func.defined()); @@ -341,7 +340,7 @@ class LambdaLifter : public ExprMutator { // Add the lifted function to the module. lifted_func = CopyWithNewVars(lifted_func); - gvar_lifted_func->struct_info_ = GetStructInfo(lifted_func); + gvar_lifted_func->ty = GetType(lifted_func); builder_->UpdateFunction(gvar_lifted_func, lifted_func); @@ -360,7 +359,7 @@ class LambdaLifter : public ExprMutator { Expr VisitExpr_(const CallNode* call_node) final { auto call = ffi::GetRef(call_node); - auto orig_sinfo = Downcast(call->struct_info_); + auto orig_ty = Downcast(call->ty); if (auto opt_var = call->op.as()) { auto var = opt_var.value(); @@ -374,22 +373,20 @@ class LambdaLifter : public ExprMutator { if (auto op = orig_call->op.as()) { static const auto& purity_map = Op::GetAttrMap("FPurity"); return purity_map.get(op.value(), false); - } else if (const auto* func_sinfo = - orig_call->op->struct_info_.as()) { - return func_sinfo->purity; + } else if (const auto* func_ty = orig_call->op->ty.as()) { + return func_ty->purity; } else { TVM_FFI_THROW(InternalError) << "Could not determine purity of call to " << orig_call->op << ", as it is neither a tvm::Op (type = \"" << orig_call->op->GetTypeKey() << "\"), " - << "nor is is annotated with FuncStructInfo (sinfo = " - << orig_call->op->struct_info_ << ")"; + << "nor is is annotated with FuncType (ty = " << orig_call->op->ty << ")"; } }(); auto prev = call; call = Call(is_pure ? invoke_pure_closure_op_ : invoke_closure_op_, - {var, Tuple(call->args)}, {}, {orig_sinfo}); + {var, Tuple(call->args)}, {}, {orig_ty}); } } @@ -404,7 +401,7 @@ class LambdaLifter : public ExprMutator { } auto prev = call; - call = Call(nested_call->op, new_args, call->attrs, call->sinfo_args); + call = Call(nested_call->op, new_args, call->attrs, call->ty_args); } } diff --git a/src/relax/transform/lazy_transform_params.cc b/src/relax/transform/lazy_transform_params.cc index fb3b014b03df..5735f0f74227 100644 --- a/src/relax/transform/lazy_transform_params.cc +++ b/src/relax/transform/lazy_transform_params.cc @@ -66,18 +66,16 @@ class LazyInputMutator : public ExprMutator { } Var fget_param("fget_param", - FuncStructInfo({PrimStructInfo(DataType::Int(64)), ObjectStructInfo()}, - ObjectStructInfo())); + FuncType({PrimType(DataType::Int(64)), ObjectType()}, ObjectType())); ffi::Array new_params(func->params.begin(), func->params.begin() + num_input_params); new_params.push_back(fget_param); - auto array_externally_visible_vars = - DefinableTIRVarsInStructInfo(TupleStructInfo(new_params.Map(GetStructInfo))); + auto array_externally_visible_vars = DefinableTIRVarsInType(TupleType(new_params.Map(GetType))); std::unordered_set externally_visible_vars(array_externally_visible_vars.begin(), array_externally_visible_vars.end()); - StructInfo new_ret_struct_info = EraseToWellDefined( - func->ret_struct_info, [&](const tirx::Var& var) -> ffi::Optional { + Type new_ret_ty = + EraseToWellDefined(func->ret_ty, [&](const tirx::Var& var) -> ffi::Optional { if (externally_visible_vars.count(var)) { return var; } else { @@ -87,7 +85,7 @@ class LazyInputMutator : public ExprMutator { auto node = ffi::GetRef(func); node.CopyOnWrite()->params = new_params; - node.CopyOnWrite()->ret_struct_info = new_ret_struct_info; + node.CopyOnWrite()->ret_ty = new_ret_ty; node = WithAttr(node, attr::kNumInput, num_input_params + 1); plan_ = FunctionPlan{std::move(param_lookup), fget_param}; @@ -106,7 +104,7 @@ class LazyInputMutator : public ExprMutator { StringImm(var->name_hint()), }), var->name_hint() + "_untyped"); - return builder_->EmitMatchCast(untyped, GetStructInfo(var), var->name_hint()); + return builder_->EmitMatchCast(untyped, GetType(var), var->name_hint()); } } @@ -147,10 +145,8 @@ class LazyOutputMutator : public ExprMutator { define_lookup(0, func_body->body); } - Var fset_output( - "fset_output", - FuncStructInfo({PrimStructInfo(DataType::Int(64)), ObjectStructInfo()}, - TupleStructInfo(ffi::Array{}), /* purity = */ false)); + Var fset_output("fset_output", FuncType({PrimType(DataType::Int(64)), ObjectType()}, + TupleType(ffi::Array{}), /* purity = */ false)); plan_ = FunctionPlan{std::move(output_lookup), fset_output}; std::optional num_input_params = GetNumInputParams(func); @@ -163,7 +159,7 @@ class LazyOutputMutator : public ExprMutator { ffi::Array propagated_params; for (auto param : func->params) { GenerateSetOutputCalls(param, [&](const auto& fset_output_call) { - Var void_output("_void", TupleStructInfo(ffi::Array{})); + Var void_output("_void", TupleType(ffi::Array{})); propagated_params.push_back(VarBinding(void_output, fset_output_call)); }); } @@ -173,7 +169,7 @@ class LazyOutputMutator : public ExprMutator { ffi::Array propagated_params; for (const auto& [output_index, expr] : inline_outputs) { Call fset_output_call(fset_output, {PrimValue(IntImm::Int64(output_index)), expr}); - Var void_output("_void", TupleStructInfo(ffi::Array{})); + Var void_output("_void", TupleType(ffi::Array{})); propagated_params.push_back(VarBinding(void_output, fset_output_call)); } return BindingBlock(propagated_params); diff --git a/src/relax/transform/legalize_ops.cc b/src/relax/transform/legalize_ops.cc index 1b6ef25750a4..9198a7c9e597 100644 --- a/src/relax/transform/legalize_ops.cc +++ b/src/relax/transform/legalize_ops.cc @@ -28,8 +28,8 @@ #include #include #include -#include #include +#include #include #include @@ -41,21 +41,20 @@ namespace relax { TVM_REGISTER_PASS_CONFIG_OPTION("relax.transform.apply_legalize_ops", bool); /*! - * \brief Check if a given Tensor/Shape/TupleStructInfo contains shapes whose + * \brief Check if a given Tensor/Shape/TupleType contains shapes whose * values are all known. - * \param sinfo The StructInfo to be checked. - * \return A boolean indicating the given struct info contains shape values that are all known. + * \param ty The Type to be checked. + * \return A boolean indicating the given type contains shape values that are all known. */ -bool KnowAllShapeValues(const StructInfo& sinfo) { - if (const auto* tensor_sinfo = sinfo.as()) { - return tensor_sinfo->shape.defined() && - tensor_sinfo->shape.value()->IsInstance(); - } else if (const auto* shape_sinfo = sinfo.as()) { - return shape_sinfo->values.defined(); - } else if (const auto* tuple_sinfo = sinfo.as()) { - return std::all_of(tuple_sinfo->fields.begin(), tuple_sinfo->fields.end(), - [](StructInfo field_sinfo) { return KnowAllShapeValues(field_sinfo); }); - } else if (sinfo.as()) { +bool KnowAllShapeValues(const Type& ty) { + if (const auto* tensor_ty = ty.as()) { + return tensor_ty->shape.defined() && tensor_ty->shape.value()->IsInstance(); + } else if (const auto* shape_ty = ty.as()) { + return shape_ty->values.defined(); + } else if (const auto* tuple_ty = ty.as()) { + return std::all_of(tuple_ty->fields.begin(), tuple_ty->fields.end(), + [](Type field_ty) { return KnowAllShapeValues(field_ty); }); + } else if (ty.as()) { return true; } else { return false; @@ -124,8 +123,8 @@ class LegalizeMutator : public ExprMutator { bool pure_legalized_op = [&]() -> bool { if (auto legalized_op = call->op.as()) { return purity_map.get(legalized_op.value(), false); - } else if (auto func_sinfo = call->op->struct_info_.as()) { - return func_sinfo->purity; + } else if (auto func_ty = call->op->ty.as()) { + return func_ty->purity; } else { return false; } @@ -145,20 +144,20 @@ class LegalizeMutator : public ExprMutator { for (auto arg : ret->args) { ret_args.push_back(arg); } - return Call(call_pure_packed_op, ret_args, ret->attrs, ret->sinfo_args); + return Call(call_pure_packed_op, ret_args, ret->attrs, ret->ty_args); } - ffi::Optional GetTarget(const ffi::Array& sinfos) { - for (auto sinfo : sinfos) { - if (const auto* tinfo = sinfo.as()) { + ffi::Optional GetTarget(const ffi::Array& types) { + for (auto ty : types) { + if (const auto* tinfo = ty.as()) { if (tinfo->vdevice.defined()) { auto vdevice = tinfo->vdevice.value(); if (vdevice->target.defined()) { return vdevice->target; } } - } else if (const auto* tup_sinfo = sinfo.as()) { - return GetTarget(tup_sinfo->fields); + } else if (const auto* tup_ty = ty.as()) { + return GetTarget(tup_ty->fields); } } return std::nullopt; @@ -177,7 +176,7 @@ class LegalizeMutator : public ExprMutator { auto call = Downcast(expr); - auto vdevice_target = GetTarget(call->sinfo_args); + auto vdevice_target = GetTarget(call->ty_args); if (!vdevice_target.defined()) { // No vdevice annotation is present, so we don't need to apply // any updates. @@ -214,7 +213,7 @@ class LegalizeMutator : public ExprMutator { // The FLegalize function generated a PrimFunc, but that PrimFunc // doesn't have annotations compatible with the vdevice required - // by the Relax StructInfo. Update the call to instead call a + // by the Relax Type. Update the call to instead call a // `PrimFunc` with the appropriate target annotation. In the // future, this may be treated as a bug in the FLegalize // implementation, rather than expected output from it. @@ -266,7 +265,7 @@ class LegalizeMutator : public ExprMutator { bool arg_shapes_defined = std::all_of(visited_call->args.begin(), visited_call->args.end(), - [](Expr arg) { return KnowAllShapeValues(GetStructInfo(arg)); }); + [](Expr arg) { return KnowAllShapeValues(GetType(arg)); }); if (!arg_shapes_defined) { // This operator cannot be legalized, because legalization // requires the argument shapes to be known. @@ -283,7 +282,7 @@ class LegalizeMutator : public ExprMutator { // This fallback would only be applicable for cases where // both the dtype and the dimensionality are known. While // Relax can express a tensor with unknown dtype and - // dimensionality as `TensorStructInfo(DataType::Void(), + // dimensionality as `TensorType(DataType::Void(), // kUnknownNDim)`, TIR cannot express unknown dtype or // unknown dimensionality. return false; @@ -298,7 +297,7 @@ class LegalizeMutator : public ExprMutator { } return false; }(); - bool ret_shape_defined = KnowAllShapeValues(GetStructInfo(visited_call)); + bool ret_shape_defined = KnowAllShapeValues(GetType(visited_call)); if (!is_data_dependent_op && !ret_shape_defined) { // This operator cannot be legalized, because legalization by // default requires the output shape. The exception is @@ -332,7 +331,7 @@ class LegalizeMutator : public ExprMutator { // Third choice, use an explicit ffi::String replacement. This does not require the shape ffi::String packed_func_name = call_packed_map[op]; legalization_func = [packed_func_name](const BlockBuilder& bb, const Call& call) -> Expr { - return Call(ExternFunc(packed_func_name), call->args, Attrs(), {GetStructInfo(call)}); + return Call(ExternFunc(packed_func_name), call->args, Attrs(), {GetType(call)}); }; } else { // No legalization. diff --git a/src/relax/transform/lift_transform_params.cc b/src/relax/transform/lift_transform_params.cc index 7e9b526b9a71..05bd3e08154c 100644 --- a/src/relax/transform/lift_transform_params.cc +++ b/src/relax/transform/lift_transform_params.cc @@ -106,12 +106,12 @@ struct BaseCollectInfo { } for (const auto& var : outputs) { - Var out_var(var->name_hint() + "_output", GetStructInfo(var)); + Var out_var(var->name_hint() + "_output", GetType(var)); output_var_binding.push_back(VarBinding(out_var, var)); output_exprs.push_back(out_var); } - Var tuple_var("output_tuple", TupleStructInfo(output_exprs.Map(GetStructInfo))); + Var tuple_var("output_tuple", TupleType(output_exprs.Map(GetType))); output_var_binding.push_back(VarBinding(tuple_var, Tuple(output_exprs))); SeqExpr body( @@ -120,7 +120,7 @@ struct BaseCollectInfo { DataflowBlock(output_var_binding), }, tuple_var); - Function func(params, body, GetStructInfo(tuple_var)); + Function func(params, body, GetType(tuple_var)); func = WithAttr(func, attr::kNumInput, 0); func = CopyWithNewVars(func); func = BundleModelParams(func); @@ -141,11 +141,9 @@ struct GlobalCollectInfo : public BaseCollectInfo { // The cross-function between between TIR variables. ffi::Map tir_var_remap; ffi::Array GetPropagatedSymbolicVariables() const { - auto vars_from_original_params = - DefinableTIRVarsInStructInfo(TupleStructInfo(params.Map(GetStructInfo))); + auto vars_from_original_params = DefinableTIRVarsInType(TupleType(params.Map(GetType))); auto vars_from_transformed_params = [&]() -> std::unordered_set { - auto tir_vars = - DefinableTIRVarsInStructInfo(TupleStructInfo(GetCompileTimeOutputs().Map(GetStructInfo))); + auto tir_vars = DefinableTIRVarsInType(TupleType(GetCompileTimeOutputs().Map(GetType))); return {tir_vars.begin(), tir_vars.end()}; }(); @@ -183,18 +181,15 @@ struct LocalCollectInfo : public BaseCollectInfo { } ffi::Array GetPropagatedSymbolicVariables() const { - auto vars_from_any_param = - DefinableTIRVarsInStructInfo(TupleStructInfo(orig_func->params.Map(GetStructInfo))); + auto vars_from_any_param = DefinableTIRVarsInType(TupleType(orig_func->params.Map(GetType))); auto vars_from_runtime_params = [&]() -> std::unordered_set { - auto tir_var_vec = - DefinableTIRVarsInStructInfo(TupleStructInfo(GetRuntimeInputs().Map(GetStructInfo))); + auto tir_var_vec = DefinableTIRVarsInType(TupleType(GetRuntimeInputs().Map(GetType))); return {tir_var_vec.begin(), tir_var_vec.end()}; }(); auto vars_from_transformed_params = [&]() -> std::unordered_set { - auto tir_var_vec = - DefinableTIRVarsInStructInfo(TupleStructInfo(GetCompileTimeOutputs().Map(GetStructInfo))); + auto tir_var_vec = DefinableTIRVarsInType(TupleType(GetCompileTimeOutputs().Map(GetType))); return {tir_var_vec.begin(), tir_var_vec.end()}; }(); @@ -254,9 +249,8 @@ struct LocalCollectInfo : public BaseCollectInfo { return global_tir_vars; }(); if (propagated_tir_vars.size()) { - ShapeStructInfo shape_sinfo( - propagated_tir_vars.Map([](tirx::Var var) -> PrimExpr { return var; })); - Var shape_expr("vars_from_compile_time_params", shape_sinfo); + ShapeType shape_ty(propagated_tir_vars.Map([](tirx::Var var) -> PrimExpr { return var; })); + Var shape_expr("vars_from_compile_time_params", shape_ty); params.push_back(shape_expr); } ffi::Array compile_time_outputs = [&]() { @@ -285,7 +279,7 @@ struct LocalCollectInfo : public BaseCollectInfo { return global_outputs; }(); for (const auto& var : compile_time_outputs) { - Var param_var(var->name_hint(), GetStructInfo(var)); + Var param_var(var->name_hint(), GetType(var)); bindings.push_back(VarBinding(var, param_var)); params.push_back(param_var); } @@ -326,7 +320,7 @@ struct LocalCollectInfo : public BaseCollectInfo { Expr body = SuppressCompileTime(to_suppress)(orig_func->body); body = SeqExpr({DataflowBlock(bindings)}, body); - Function func(params, body, orig_func->ret_struct_info, orig_func->is_pure, orig_func->attrs); + Function func(params, body, orig_func->ret_ty, orig_func->is_pure, orig_func->attrs); func = CopyWithNewVars(func); func = Downcast(CanonicalizeBindings(func)); return func; @@ -365,9 +359,9 @@ class BaseLiftableBindingCollector : public ExprVisitor { } } - // Cond 4. Do not lift when its struct info contains symbolic variables that do not appear in + // Cond 4. Do not lift when its type contains symbolic variables that do not appear in // params. - for (const auto& var : TIRVarsInStructInfo(GetStructInfo(binding->var))) { + for (const auto& var : TIRVarsInType(GetType(binding->var))) { if (!liftable_vars_.count(var)) { return false; } @@ -445,7 +439,7 @@ class LocalLiftableBindingCollector : public BaseLiftableBindingCollector { for (size_t i = num_runtime_params; i < func->params.size(); i++) { liftable_vars_.insert(func->params[i]); info_.requires_compile_time_param.insert(func->params[i]); - for (const auto& tir_var : DefinableTIRVarsInStructInfo(GetStructInfo(func->params[i]))) { + for (const auto& tir_var : DefinableTIRVarsInType(GetType(func->params[i]))) { liftable_vars_.insert(tir_var); } } @@ -530,8 +524,7 @@ class ParamRemapper : private ExprFunctor { int index_0 = j + num_inputs_0; mapper.VisitExpr(functions[i]->params[index_i], functions[0]->params[index_0]); ffi::StructuralEqual eq; - eq(functions[i]->params[index_i]->struct_info_, - functions[0]->params[index_0]->struct_info_); + eq(functions[i]->params[index_i]->ty, functions[0]->params[index_0]->ty); } } } @@ -546,11 +539,11 @@ class ParamRemapper : private ExprFunctor { } else { var_remap_.Set(ffi::GetRef(lhs_var), rhs_var); } - TVM_FFI_ICHECK(tvm::ffi::StructuralEqual::Equal(lhs_var->struct_info_, rhs_var->struct_info_, + TVM_FFI_ICHECK(tvm::ffi::StructuralEqual::Equal(lhs_var->ty, rhs_var->ty, /*map_free_vars=*/true)) - << "The struct info of the parameters should be the same for all target functions"; - auto lhs_tir_vars = DefinableTIRVarsInStructInfo(GetStructInfo(ffi::GetRef(lhs_var))); - auto rhs_tir_vars = DefinableTIRVarsInStructInfo(GetStructInfo(rhs_expr)); + << "The type of the parameters should be the same for all target functions"; + auto lhs_tir_vars = DefinableTIRVarsInType(GetType(ffi::GetRef(lhs_var))); + auto rhs_tir_vars = DefinableTIRVarsInType(GetType(rhs_expr)); TVM_FFI_ICHECK_EQ(lhs_tir_vars.size(), rhs_tir_vars.size()); for (size_t i = 0; i < lhs_tir_vars.size(); i++) { if (auto it = tir_var_remap_.find(lhs_tir_vars[i]); it != tir_var_remap_.end()) { @@ -684,7 +677,7 @@ class ConsumeBundledParams : public ExprMutator { builder_->Emit( Call(call_pure_packed, {builtin_tuple_reset_item, tuple_get_item->tuple, PrimValue(tuple_get_item->index)}, - tvm::Attrs(), {TupleStructInfo(ffi::Array{})})); + tvm::Attrs(), {TupleType(ffi::Array{})})); } else { ExprMutator::VisitBinding_(binding, tuple_get_item); } @@ -696,7 +689,7 @@ class ConsumeBundledParams : public ExprMutator { auto num_input = opt_num_input.value(); TVM_FFI_ICHECK_EQ(func->params.size(), num_input + 1); params_ = func->params.back(); - TVM_FFI_ICHECK(params_->struct_info_.as()); + TVM_FFI_ICHECK(params_->ty.as()); return ExprMutator::VisitExpr_(func); } @@ -815,7 +808,7 @@ Pass PartitionTransformParams(ffi::Variant> shared transform = ComposeFunctions(old_transform, transform); } GlobalVar new_gvar(name); - UpdateStructInfo(new_gvar, GetStructInfo(transform)); + UpdateType(new_gvar, GetType(transform)); write_ptr->Add(new_gvar, transform); } } diff --git a/src/relax/transform/lower_alloc_tensor.cc b/src/relax/transform/lower_alloc_tensor.cc index 793dbd3f3f43..1b6078f0ead3 100644 --- a/src/relax/transform/lower_alloc_tensor.cc +++ b/src/relax/transform/lower_alloc_tensor.cc @@ -57,8 +57,8 @@ class Mutator : public ExprMutator { return ptr->values; } - auto sinfo = GetStructInfo(shape_arg); - if (auto ptr = sinfo.as()) { + auto ty = GetType(shape_arg); + if (auto ptr = ty.as()) { if (ptr->values) { return ptr->values.value(); } @@ -67,7 +67,7 @@ class Mutator : public ExprMutator { TVM_FFI_THROW(InternalError) << "Shape argument for " << alloc_tensor_op << " should be a ShapeExpr, " << "or a variable that holds a ShapeExpr. " - << "However, received argument " << shape_arg << " with struct info " << sinfo; + << "However, received argument " << shape_arg << " with type " << ty; TVM_FFI_UNREACHABLE(); }(); diff --git a/src/relax/transform/merge_composite_functions.cc b/src/relax/transform/merge_composite_functions.cc index ae07faa22359..70d4276937b7 100644 --- a/src/relax/transform/merge_composite_functions.cc +++ b/src/relax/transform/merge_composite_functions.cc @@ -57,8 +57,8 @@ #include #include #include -#include #include +#include #include #include "../../support/arena.h" @@ -311,8 +311,8 @@ class CompositeInliner : public ExprMutator { Function Run(Function func) { inlined_functions_ = ffi::Map(); auto new_body = VisitExpr(ToNonDataflow(func->body)); - auto new_func = Function(func->params, new_body, func->ret_struct_info, func->is_pure, - func->attrs, func->span); + auto new_func = + Function(func->params, new_body, func->ret_ty, func->is_pure, func->attrs, func->span); return new_func; } diff --git a/src/relax/transform/normalize.cc b/src/relax/transform/normalize.cc index ac3f0611db48..0b2b7b13bf1e 100644 --- a/src/relax/transform/normalize.cc +++ b/src/relax/transform/normalize.cc @@ -20,7 +20,7 @@ /*! * \file tvm/relax/transform/normalize.cc * \brief Pass for transforming Relax IR to normal form, i.e., the expressions are normalized(no - * nesting and hence the AST is in ANF), and all struct_info_ of expressions are + * nesting and hence the AST is in ANF), and all ty of expressions are * available. */ @@ -28,8 +28,8 @@ #include #include #include -#include #include +#include namespace tvm { namespace relax { @@ -49,7 +49,7 @@ class NormalizeMutator : public ExprMutatorBase { if (body.same_as(op->body)) { return ffi::GetRef(op); } else { - return Function(op->params, body, op->ret_struct_info, op->is_pure, op->attrs); + return Function(op->params, body, op->ret_ty, op->is_pure, op->attrs); } } @@ -147,8 +147,8 @@ class NormalizeMutator : public ExprMutatorBase { void VisitBinding_(const VarBindingNode* binding) { Expr new_value = this->VisitExpr(binding->value); - if (!binding->var->struct_info_.defined()) { - UpdateStructInfo(binding->var, GetStructInfo(new_value)); + if (!binding->var->ty.defined()) { + UpdateType(binding->var, GetType(new_value)); } if (new_value.same_as(binding->value)) { @@ -165,7 +165,7 @@ class NormalizeMutator : public ExprMutatorBase { builder_->EmitNormalized(ffi::GetRef(binding)); } else { builder_->EmitNormalized( - MatchCast(binding->var, builder_->NormalizeArgument(new_value), binding->struct_info)); + MatchCast(binding->var, builder_->NormalizeArgument(new_value), binding->ty)); } } diff --git a/src/relax/transform/realize_vdevice.cc b/src/relax/transform/realize_vdevice.cc index 2c78dbca257c..ca3a229c6b11 100644 --- a/src/relax/transform/realize_vdevice.cc +++ b/src/relax/transform/realize_vdevice.cc @@ -103,12 +103,11 @@ class DeviceHintCollector : ExprVisitor { void VisitExpr_(const FunctionNode* func) override { ExprVisitor::VisitExpr_(func); - std::function check_ret_sinfo = [this, &check_ret_sinfo]( - Expr expr, StructInfo sinfo) { + std::function check_ret_ty = [this, &check_ret_ty](Expr expr, Type ty) { // If the function is annotated as returning a tensor on a // specific device, then that annotation may be propagated into // the returned variable. - if (auto tensor_info = sinfo.as(); + if (auto tensor_info = ty.as(); tensor_info && tensor_info->vdevice.defined()) { if (auto opt_var = expr.as()) { auto var = opt_var.value(); @@ -122,7 +121,7 @@ class DeviceHintCollector : ExprVisitor { // where some elements of the tuple are tensors that exist on a // specific device, then those annotations may be propagated // into the corresponding tensor annotations. - if (auto tuple_info = sinfo.as()) { + if (auto tuple_info = ty.as()) { // The returned tuple is not necessarily an in-line tuple. In // order to find the variables that are bound to the // individual tuple elements, we may need to unwrap the @@ -145,18 +144,17 @@ class DeviceHintCollector : ExprVisitor { << "but is annotated as returning a tuple with " << tuple_info->fields.size() << " elements"; for (size_t i = 0; i < tuple->fields.size(); i++) { - check_ret_sinfo(tuple->fields[i], tuple_info->fields[i]); + check_ret_ty(tuple->fields[i], tuple_info->fields[i]); } } } }; - check_ret_sinfo(func->body->body, func->ret_struct_info); + check_ret_ty(func->body->body, func->ret_ty); } void VisitVarDef(const Var& var) override { - if (auto tinfo = var->struct_info_.as(); - tinfo && tinfo->vdevice.defined()) { + if (auto tinfo = var->ty.as(); tinfo && tinfo->vdevice.defined()) { known_vdevice_.Set(var, tinfo->vdevice.value()); } ExprVisitor::VisitVarDef(var); @@ -201,7 +199,7 @@ class DeviceHintCollector : ExprVisitor { // A map from Var to the VDevice they are known to occur on. This // only contains variables whose location is explicitly known // (e.g. output of `R.hint_on_device`, variables with explicit - // `VDevice` in their struct info), and does not include variables + // `VDevice` in their type), and does not include variables // whose location is (e.g. input of `R.hint_on_device`). ffi::Map known_vdevice_; @@ -324,10 +322,10 @@ ffi::Map InferVDevice(IRModule mod) { } // Update the module to include the inferred VDevice annotations. -class VDeviceStructInfoUpdater : ExprMutator { +class VDeviceTypeUpdater : ExprMutator { public: static IRModule Apply(IRModule mod, ffi::Map vdevice_map) { - VDeviceStructInfoUpdater mutator(VDeviceLookup(mod), vdevice_map); + VDeviceTypeUpdater mutator(VDeviceLookup(mod), vdevice_map); IRModule updates; @@ -348,26 +346,26 @@ class VDeviceStructInfoUpdater : ExprMutator { } private: - VDeviceStructInfoUpdater(VDeviceLookup vdevice_lookup, ffi::Map vdevice_map) + VDeviceTypeUpdater(VDeviceLookup vdevice_lookup, ffi::Map vdevice_map) : vdevice_lookup_(vdevice_lookup), vdevice_map_(vdevice_map) {} Var VisitVarDef(const Var& old_var) override { auto var = ExprMutator::VisitVarDef(old_var); - if (auto tinfo = var->struct_info_.as()) { + if (auto tinfo = var->ty.as()) { if (auto opt = vdevice_map_.Get(old_var)) { auto vdevice = opt.value(); - TensorStructInfo new_sinfo = [&]() { + TensorType new_ty = [&]() { if (tinfo->shape.defined()) { - return TensorStructInfo(tinfo->shape.value(), tinfo->dtype, vdevice, tinfo->span); + return TensorType(tinfo->shape.value(), tinfo->dtype, vdevice, tinfo->span); } else { - return TensorStructInfo(tinfo->dtype, tinfo->ndim, vdevice, tinfo->span); + return TensorType(tinfo->dtype, tinfo->ndim, vdevice, tinfo->span); } }(); if (var->IsInstance()) { - var = DataflowVar(var->vid, new_sinfo, var->span); + var = DataflowVar(var->vid, new_ty, var->span); } else { - var = Var(var->vid, new_sinfo, var->span); + var = Var(var->vid, new_ty, var->span); } } } @@ -386,7 +384,7 @@ class VDeviceStructInfoUpdater : ExprMutator { TVM_FFI_ICHECK_EQ(call->args.size(), 1); auto arg = call->args[0]; - auto input_vdevice = Downcast(arg->struct_info_)->vdevice; + auto input_vdevice = Downcast(arg->ty)->vdevice; auto output_vdevice = vdevice_lookup_(call->attrs); if (input_vdevice.defined() && input_vdevice.value() == output_vdevice) { @@ -410,7 +408,7 @@ namespace transform { Pass RealizeVDevice() { auto pass_func = [=](IRModule mod, PassContext pc) { auto known_vdevices = InferVDevice(mod); - return VDeviceStructInfoUpdater::Apply(mod, known_vdevices); + return VDeviceTypeUpdater::Apply(mod, known_vdevices); }; return CreateModulePass(/*pass_function=*/pass_func, /*opt_level=*/0, diff --git a/src/relax/transform/remove_purity_checking.cc b/src/relax/transform/remove_purity_checking.cc index 11bdbade6a1e..4fd7f73b74d3 100644 --- a/src/relax/transform/remove_purity_checking.cc +++ b/src/relax/transform/remove_purity_checking.cc @@ -23,8 +23,8 @@ #include #include #include -#include #include +#include #include namespace tvm { @@ -42,8 +42,7 @@ class PurityRemover : public ExprMutator { } auto new_body = VisitExpr(ret->body); if (!new_body.same_as(ret->body)) { - return Function(ret->params, new_body, ret->ret_struct_info, ret->is_pure, ret->attrs, - ret->span); + return Function(ret->params, new_body, ret->ret_ty, ret->is_pure, ret->attrs, ret->span); } return ret; } @@ -51,17 +50,17 @@ class PurityRemover : public ExprMutator { Expr VisitExpr_(const CallNode* call) override { if (call->op == call_pure_packed_op_) { auto ret = Call(call->args[0], ffi::Array(call->args.begin() + 1, call->args.end()), - call->attrs, call->sinfo_args); + call->attrs, call->ty_args); return VisitExpr(ret); } if (call->op == call_inplace_packed_op_) { // call_inplace_packed has its own attrs so we don't pass those down auto ret = Call(call->args[0], ffi::Array(call->args.begin() + 1, call->args.end()), - tvm::Attrs(), call->sinfo_args); + tvm::Attrs(), call->ty_args); return VisitExpr(ret); } if (call->op == invoke_pure_closure_op_) { - auto ret = Call(invoke_closure_op_, call->args, call->attrs, call->sinfo_args); + auto ret = Call(invoke_closure_op_, call->args, call->attrs, call->ty_args); return VisitExpr(ret); } return ExprMutator::VisitExpr_(call); diff --git a/src/relax/transform/remove_unused_outputs.cc b/src/relax/transform/remove_unused_outputs.cc index 26d669d2d6b2..9962309f9fd4 100644 --- a/src/relax/transform/remove_unused_outputs.cc +++ b/src/relax/transform/remove_unused_outputs.cc @@ -50,7 +50,7 @@ class PartialTupleUsageCollector : ExprVisitor { if (!is_exposed) { if (auto relax_func = base_func.as()) { - if (auto out_tuple = relax_func->ret_struct_info.as()) { + if (auto out_tuple = relax_func->ret_ty.as()) { num_outputs[gvar] = out_tuple->fields.size(); } } @@ -121,7 +121,7 @@ class PartialTupleUsageCollector : ExprVisitor { } std::vector* GetCalleeUsageMask(Expr expr) { - if (!expr->struct_info_.as()) { + if (!expr->ty.as()) { return nullptr; } @@ -158,17 +158,17 @@ class PartialTupleUsageCollector : ExprVisitor { }; Function UpdateCallee(Function func, const std::vector& usage_mask) { - auto old_func_sinfo = func->struct_info_.as(); + auto old_func_ty = func->ty.as(); - auto old_ret_sinfo = func->ret_struct_info.as(); - TVM_FFI_ICHECK(old_ret_sinfo) << "All functions returning non-tuple outputs " - << "should have been pruned already by PartialTupleUsageCollector"; + auto old_ret_ty = func->ret_ty.as(); + TVM_FFI_ICHECK(old_ret_ty) << "All functions returning non-tuple outputs " + << "should have been pruned already by PartialTupleUsageCollector"; ffi::Array outputs; // This helper variable will be removed by the post-proc of // CanonicalizeBindings and DeadCodeElimination. - Var previous_outputs("previous_outputs", func->ret_struct_info); + Var previous_outputs("previous_outputs", func->ret_ty); for (size_t i = 0; i < usage_mask.size(); i++) { if (usage_mask[i]) { @@ -177,19 +177,17 @@ Function UpdateCallee(Function func, const std::vector& usage_mask) { } Expr new_output = outputs.size() == 1 ? outputs[0] : Tuple(outputs); - StructInfo new_return_sinfo = - outputs.size() == 1 ? GetStructInfo(outputs[0]) : TupleStructInfo(outputs.Map(GetStructInfo)); + Type new_return_ty = outputs.size() == 1 ? GetType(outputs[0]) : TupleType(outputs.Map(GetType)); VarBinding binding(previous_outputs, func->body); BindingBlock binding_block({binding}); SeqExpr new_body({binding_block}, new_output); - auto old_sinfo = Downcast(func->struct_info_); - FuncStructInfo new_sinfo(old_func_sinfo->params.value(), new_return_sinfo, - old_func_sinfo->purity); + auto old_ty = Downcast(func->ty); + FuncType new_ty(old_func_ty->params.value(), new_return_ty, old_func_ty->purity); auto write_ptr = func.CopyOnWrite(); - write_ptr->struct_info_ = new_sinfo; + write_ptr->ty = new_ty; write_ptr->body = new_body; return func; @@ -242,7 +240,7 @@ Pass RemoveUnusedOutputs() { auto new_func = UpdateCallee(func.value(), usage_mask); GlobalVar new_gvar(gvar->name_hint); - new_gvar->struct_info_ = new_func->struct_info_; + new_gvar->ty = new_func->ty; new_callees->Add(new_gvar, new_func); callsite_updaters[gvar] = [old_gvar = gvar, new_gvar, usage_mask](Call call) -> Expr { @@ -250,14 +248,14 @@ Pass RemoveUnusedOutputs() { << "Updater should be applied to " << old_gvar << ", but was applied to " << call->op; - auto old_call_sinfo = call->struct_info_.as(); - TVM_FFI_CHECK(old_call_sinfo, InternalError) + auto old_call_ty = call->ty.as(); + TVM_FFI_CHECK(old_call_ty, InternalError) << "Updater should be applied to Call producing an output tuple, " - << "but " << call << " has struct info " << call->struct_info_; - TVM_FFI_ICHECK_EQ(usage_mask.size(), old_call_sinfo->fields.size()) + << "but " << call << " has type " << call->ty; + TVM_FFI_ICHECK_EQ(usage_mask.size(), old_call_ty->fields.size()) << "Function " << call->op << " produces " << usage_mask.size() << " outputs, " << "but " << call << " was used in a context expecting " - << old_call_sinfo->fields.size() << " outputs."; + << old_call_ty->fields.size() << " outputs."; Call new_call(new_gvar, call->args); diff --git a/src/relax/transform/remove_unused_parameters.cc b/src/relax/transform/remove_unused_parameters.cc index daf4c6e2fd9c..598478c9c2d7 100644 --- a/src/relax/transform/remove_unused_parameters.cc +++ b/src/relax/transform/remove_unused_parameters.cc @@ -86,9 +86,8 @@ std::optional AnalyzeCallee(Function func) { // to reduce computational steps in the parent, but we need to // provide the symbolic variables the other steps. auto defined_tir_params = [&]() -> PSet { - auto param_sinfo = - TupleStructInfo(params.Map([](const auto& var) { return GetStructInfo(var); })); - auto arr = DefinableTIRVarsInStructInfo(param_sinfo); + auto param_ty = TupleType(params.Map([](const auto& var) { return GetType(var); })); + auto arr = DefinableTIRVarsInType(param_ty); return {arr.begin(), arr.end()}; }(); @@ -101,13 +100,12 @@ std::optional AnalyzeCallee(Function func) { } for (const auto& tir_var : free_tir_vars) { - Var relax_var("param_" + tir_var->name_hint, PrimStructInfo(tir_var)); + Var relax_var("param_" + tir_var->name_hint, PrimType(tir_var)); params.push_back(relax_var); } - FuncStructInfo new_sinfo(params.Map([](const auto& var) { return GetStructInfo(var); }), - func->ret_struct_info, - Downcast(func->struct_info_)->purity); + FuncType new_ty(params.Map([](const auto& var) { return GetType(var); }), func->ret_ty, + Downcast(func->ty)->purity); auto arg_updater = [parameter_mask, old_relax_params = func->params, free_tir_vars](ffi::Array old_args) -> ffi::Array { @@ -140,7 +138,7 @@ std::optional AnalyzeCallee(Function func) { auto write_ptr = func.CopyOnWrite(); write_ptr->params = params; - write_ptr->struct_info_ = new_sinfo; + write_ptr->ty = new_ty; return CalleeAnalysis{func, arg_updater}; } @@ -196,7 +194,7 @@ Pass RemoveUnusedParameters() { if (auto callee_res = AnalyzeCallee(func.value())) { auto new_func = callee_res->func; GlobalVar new_gvar(gvar->name_hint); - new_gvar->struct_info_ = new_func->struct_info_; + new_gvar->ty = new_func->ty; new_callees->Add(new_gvar, new_func); callsite_updaters[gvar] = [old_gvar = gvar, new_gvar, @@ -221,7 +219,7 @@ Pass RemoveUnusedParameters() { // Remove any private subroutines that have unused parameters, // then add the updated versions. The new private functions // have the same name, but require a new GlobalVar to hold the - // updated StructInfo. As a result, calling `Update()` without + // updated Type. As a result, calling `Update()` without // first calling `Remove()` introduce a duplicate name and // produce an error. for (const auto& it : callsite_updaters) { diff --git a/src/relax/transform/reorder_permute_dims_after_concat.cc b/src/relax/transform/reorder_permute_dims_after_concat.cc index 88c64521b047..8d9a1cac9b06 100644 --- a/src/relax/transform/reorder_permute_dims_after_concat.cc +++ b/src/relax/transform/reorder_permute_dims_after_concat.cc @@ -98,12 +98,12 @@ std::tuple)>> } else { auto call = Downcast(expr); ffi::Array permutation; - auto arg_sinfo = call->args[0]->struct_info_.as(); - TVM_FFI_ICHECK(arg_sinfo) << "Expected permute_dims to have a single tensor argument, " - << "but argument " << call->args[0] << " has struct info " - << call->args[0]->struct_info_; - TVM_FFI_ICHECK_GE(arg_sinfo->ndim, 0); - size_t ndim = arg_sinfo->ndim; + auto arg_ty = call->args[0]->ty.as(); + TVM_FFI_ICHECK(arg_ty) << "Expected permute_dims to have a single tensor argument, " + << "but argument " << call->args[0] << " has type " + << call->args[0]->ty; + TVM_FFI_ICHECK_GE(arg_ty->ndim, 0); + size_t ndim = arg_ty->ndim; for (size_t i = 0; i < ndim; i++) { permutation.push_back(static_cast(ndim - i - 1)); } diff --git a/src/relax/transform/reorder_take_after_matmul.cc b/src/relax/transform/reorder_take_after_matmul.cc index 96c41bea8ef0..bd36c5cb89c5 100644 --- a/src/relax/transform/reorder_take_after_matmul.cc +++ b/src/relax/transform/reorder_take_after_matmul.cc @@ -64,26 +64,26 @@ std::tuple)>> << "Attributes for relax.take operator should be TakeAttrs, " << "but were instead " << take_call->attrs << " with type " << take_call->GetTypeKey(); - const auto* lhs_sinfo = lhs->struct_info_.as(); - if (!lhs_sinfo) return expr; + const auto* lhs_ty = lhs->ty.as(); + if (!lhs_ty) return expr; - const auto* weights_sinfo = weights->struct_info_.as(); - if (!weights_sinfo) return expr; + const auto* weights_ty = weights->ty.as(); + if (!weights_ty) return expr; - const auto* indices_sinfo = indices->struct_info_.as(); - if (!indices_sinfo) return expr; + const auto* indices_ty = indices->ty.as(); + if (!indices_ty) return expr; - const auto* matmul_sinfo = expr->struct_info_.as(); - if (!matmul_sinfo) return expr; + const auto* matmul_ty = expr->ty.as(); + if (!matmul_ty) return expr; if (!attrs->axis.has_value()) return expr; auto axis = attrs->axis.value(); - if (lhs_sinfo->IsUnknownNdim() || indices_sinfo->IsUnknownNdim() || - matmul_sinfo->IsUnknownNdim() || weights_sinfo->IsUnknownNdim()) + if (lhs_ty->IsUnknownNdim() || indices_ty->IsUnknownNdim() || matmul_ty->IsUnknownNdim() || + weights_ty->IsUnknownNdim()) return expr; - if (indices_sinfo->ndim == 1 && axis + 1 == weights_sinfo->ndim) { + if (indices_ty->ndim == 1 && axis + 1 == weights_ty->ndim) { // Simpler case. The activations may have batch dimensions, but // the weights do not. @@ -94,18 +94,17 @@ std::tuple)>> // out_table.shape = [*batch, table_size] auto out_table = matmul(lhs, weights, DataType::Void()); // new_output.shape = [*batch, outfeatures] - auto new_output = take(out_table, indices, matmul_sinfo->ndim - 1); + auto new_output = take(out_table, indices, matmul_ty->ndim - 1); return new_output; - } else if (lhs_sinfo->ndim == 3 && weights_sinfo->ndim == 3 && indices_sinfo->ndim == 1 && - axis == 0 && weights_sinfo->GetShape().defined() && - lhs_sinfo->GetShape().defined()) { + } else if (lhs_ty->ndim == 3 && weights_ty->ndim == 3 && indices_ty->ndim == 1 && axis == 0 && + weights_ty->GetShape().defined() && lhs_ty->GetShape().defined()) { // More complicated case, used for batched LoRA. The conditions // on the argument dimensions can probably be relaxed, but would // probably need to remove the use of the einsum operator. - auto lhs_shape = lhs_sinfo->GetShape().value(); - auto weight_shape = weights_sinfo->GetShape().value(); + auto lhs_shape = lhs_ty->GetShape().value(); + auto weight_shape = weights_ty->GetShape().value(); // lhs.shape = [batch1, batch2, infeatures] // weights.shape = [table_size, infeatures, outfeatures] diff --git a/src/relax/transform/rewrite_cuda_graph.cc b/src/relax/transform/rewrite_cuda_graph.cc index a7ec2a587615..af41746f2fac 100644 --- a/src/relax/transform/rewrite_cuda_graph.cc +++ b/src/relax/transform/rewrite_cuda_graph.cc @@ -133,13 +133,11 @@ class FuncBuilder : public ExprMutator { tir_var_remap_.Set(ffi::GetRef(var), new_var); tir_vars.push_back(new_var); } - shape_expr = Var("shape_expr", ShapeStructInfo(tir_vars)); + shape_expr = Var("shape_expr", ShapeType(tir_vars)); } // Set up the parameters for (const auto* input : inputs_) { - auto new_var = Var(input->name_hint(), - VisitExprDepStructInfoField( - Downcast>(input->struct_info_).value())); + auto new_var = Var(input->name_hint(), VisitExprDepTypeField(Downcast(input->ty))); var_remap_[input->vid] = new_var; params.push_back(new_var); } @@ -161,7 +159,7 @@ class FuncBuilder : public ExprMutator { auto body = builder_->Normalize(SeqExpr({block}, output)); ffi::Map attrs; attrs.Set(relax::attr::kForcePure, true); - auto func = Function(params, body, Downcast(output->struct_info_.value()), + auto func = Function(params, body, Downcast(output->ty), /*is_pure=*/true, /*attrs=*/DictAttrs(attrs)); return func; } @@ -241,18 +239,18 @@ class CUDAGraphRewritePlanner : public ExprVisitor { if (pair.second->IsInstance()) { // If a function has the num_input attribute, the last func->params.size() - num_inputs // inputs are assumed to be fixed and thus they can be captured into a cuda graph. - // The symbolic variables in the struct info of the fixed inputs (weights) are also allowed + // The symbolic variables in the type of the fixed inputs (weights) are also allowed // to be captured. // If the hints for capturing symbolic variables via // 'relax.rewrite_cuda_graph.capture_symbolic_vars' annotation, the actual variables with - // these names are extracted from the struct info for the capturing. + // these names are extracted from the type for the capturing. const auto& func = Downcast(pair.second); int64_t num_inputs = func->attrs.GetAttr(attr::kNumInput).value_or(func->params.size()); auto capture_symbolic_var_name_hints = ExtractSymbolicVarHints(func); for (int i = 0; i < static_cast(func->params.size()); ++i) { - ffi::Array symbolic_vars = DefinableTIRVarsInStructInfo( - Downcast(func->params[i]->struct_info_.value())); + ffi::Array symbolic_vars = + DefinableTIRVarsInType(Downcast(func->params[i]->ty)); if (i < num_inputs) { for (const auto& symbolic_var : symbolic_vars) { if (capture_symbolic_var_name_hints.count(symbolic_var->name_hint)) { @@ -513,9 +511,9 @@ class CUDAGraphRewritePlanner : public ExprVisitor { if (vars_collector != nullptr) { vars_collector->push_back(var); } - // recursively check the struct info to collect the symbolic TIR vars - return static_vars_.count(var) && IsStatic(Downcast(var->struct_info_.value()), - vars_collector, tir_vars_collector); + // recursively check the type to collect the symbolic TIR vars + return static_vars_.count(var) && + IsStatic(Downcast(var->ty), vars_collector, tir_vars_collector); } if (const auto* shape = expr.as()) { @@ -542,19 +540,19 @@ class CUDAGraphRewritePlanner : public ExprVisitor { return result; } - bool IsStatic(const StructInfo& sinfo, std::vector* vars_collector = nullptr, + bool IsStatic(const Type& ty, std::vector* vars_collector = nullptr, std::vector* tir_vars_collector = nullptr) { - if (const auto* tensor_sinfo = sinfo.as()) { - if (auto shape = tensor_sinfo->GetShape()) { + if (const auto* tensor_ty = ty.as()) { + if (auto shape = tensor_ty->GetShape()) { return IsStatic(shape.value(), vars_collector, tir_vars_collector); } - } else if (const auto* shape_sinfo = sinfo.as()) { - if (shape_sinfo->values) { - return IsStatic(shape_sinfo->values.value(), vars_collector, tir_vars_collector); + } else if (const auto* shape_ty = ty.as()) { + if (shape_ty->values) { + return IsStatic(shape_ty->values.value(), vars_collector, tir_vars_collector); } - } else if (const auto* tuple_sinfo = sinfo.as()) { - return IsStatic(tuple_sinfo->fields, vars_collector, tir_vars_collector); - } else if (sinfo.as() || sinfo.as()) { + } else if (const auto* tuple_ty = ty.as()) { + return IsStatic(tuple_ty->fields, vars_collector, tir_vars_collector); + } else if (ty.as() || ty.as()) { return true; } return false; @@ -784,15 +782,15 @@ class CUDAGraphRewriter : public ExprMutator { TVM_FFI_ICHECK(!plan->propogated_tir_vars.defined()); TVM_FFI_ICHECK(plan->inputs.empty()); auto gv_alloc = gv_global_alloc_.value(); - auto ret_struct_info = Downcast(gv_alloc->struct_info_.value())->ret; + auto ret_ty = Downcast(gv_alloc->ty)->ret; launch_subgraph = Call(call_builtin_with_ctx_op, {builtin_get_cached_alloc, Tuple({gv_alloc, PrimValue(IntImm::Int64(0))})}, Attrs(), - {ret_struct_info}); + {ret_ty}); } else { auto gv_func = builder_->AddFunction( plan->func, current_func_.value()->name_hint + "_cuda_graph_capture"); - StructInfo call_sinfo = plan->func->ret_struct_info; + Type call_ty = plan->func->ret_ty; // Arguments of the lifted function ffi::Array args; for (const auto& arg : plan->inputs) { @@ -801,18 +799,17 @@ class CUDAGraphRewriter : public ExprMutator { if (plan->propogated_tir_vars.defined()) { ShapeExpr propogated_tir_vars = plan->propogated_tir_vars.value(); args.push_back(propogated_tir_vars); - // The ret_struct_info of the lifted function can contain symbolic variables. We need to + // The ret_ty of the lifted function can contain symbolic variables. We need to // bind the symbolic parameters to the actual values. const auto& shape_expr = plan->func->params.back(); - auto symbolic_params = - Downcast(shape_expr->struct_info_.value())->values.value(); + auto symbolic_params = Downcast(shape_expr->ty)->values.value(); ffi::Map tir_var_remap; TVM_FFI_ICHECK_EQ(symbolic_params.size(), propogated_tir_vars->values.size()); for (int i = 0; i < static_cast(symbolic_params.size()); ++i) { tir_var_remap.Set(Downcast(symbolic_params[i]), propogated_tir_vars->values[i]); } - call_sinfo = Bind(call_sinfo, tir_var_remap); + call_ty = Bind(call_ty, tir_var_remap); } // Arguments of builtin_run_or_capture ffi::Array tuple_arg_fields{gv_func, Tuple(args), @@ -823,9 +820,8 @@ class CUDAGraphRewriter : public ExprMutator { // passing it twice simplifies the handling during the capture phase. tuple_arg_fields.push_back(plan->propogated_tir_vars.value()); } - launch_subgraph = - Call(call_builtin_with_ctx_op, {builtin_run_or_capture, Tuple(tuple_arg_fields)}, Attrs(), - {call_sinfo}); + launch_subgraph = Call(call_builtin_with_ctx_op, + {builtin_run_or_capture, Tuple(tuple_arg_fields)}, Attrs(), {call_ty}); } Expr ret_value = builder_->Emit(launch_subgraph); for (const auto& [var, tuple_index] : plan->outputs) { diff --git a/src/relax/transform/rewrite_dataflow_reshape.cc b/src/relax/transform/rewrite_dataflow_reshape.cc index 46ccdd82dfa4..9dcbfa94839b 100644 --- a/src/relax/transform/rewrite_dataflow_reshape.cc +++ b/src/relax/transform/rewrite_dataflow_reshape.cc @@ -104,8 +104,8 @@ class DataflowReshapeRewriter : public ExprMutator { return ffi::GetRef(call); } - TensorStructInfo res_sinfo = Downcast(call->struct_info_.value()); - return reshape(arg, res_sinfo->shape.value()); + TensorType res_ty = Downcast(call->ty); + return reshape(arg, res_ty->shape.value()); } bool IsCallingTIRReshape(const CallNode* call, Expr inp) { @@ -120,15 +120,15 @@ class DataflowReshapeRewriter : public ExprMutator { // as the number of elements in the result. There are operators that could have a reshape // pattern that don't meet this requirement (e.g. strided_slice), and they should not be // converted to reshape. - TVM_FFI_ICHECK(inp->struct_info_.defined() && call->struct_info_.defined()); - TensorStructInfo inp_sinfo = Downcast(inp->struct_info_.value()); - TensorStructInfo res_sinfo = Downcast(call->struct_info_.value()); + TVM_FFI_ICHECK(inp->ty.defined() && call->ty.defined()); + TensorType inp_ty = Downcast(inp->ty); + TensorType res_ty = Downcast(call->ty); - if (inp_sinfo->IsUnknownDtype() || inp_sinfo->dtype != res_sinfo->dtype) { + if (inp_ty->IsUnknownDtype() || inp_ty->dtype != res_ty->dtype) { return false; } - TVM_FFI_ICHECK(inp_sinfo->shape.defined() && res_sinfo->shape.defined()); - if (inp_sinfo->IsUnknownNdim() || res_sinfo->IsUnknownNdim()) { + TVM_FFI_ICHECK(inp_ty->shape.defined() && res_ty->shape.defined()); + if (inp_ty->IsUnknownNdim() || res_ty->IsUnknownNdim()) { return false; } auto product = [](ffi::Array args) -> PrimExpr { @@ -142,8 +142,8 @@ class DataflowReshapeRewriter : public ExprMutator { for (int i = 1, e = args.size(); i < e; ++i) p *= args[i]; return p; }; - auto inp_count = product(inp_sinfo->GetShape().value()); - auto res_count = product(res_sinfo->GetShape().value()); + auto inp_count = product(inp_ty->GetShape().value()); + auto res_count = product(res_ty->GetShape().value()); if (!arith::Analyzer()->CanProveEqual(inp_count, res_count)) { return false; } diff --git a/src/relax/transform/run_codegen.cc b/src/relax/transform/run_codegen.cc index efd90d6696d7..d5d42f004ef1 100644 --- a/src/relax/transform/run_codegen.cc +++ b/src/relax/transform/run_codegen.cc @@ -110,19 +110,18 @@ class CodeGenRunner : ExprMutator { if (auto const* gvar_node = call_node->op.as()) { const GlobalVar gvar = ffi::GetRef(gvar_node); - auto create_call_dps_packed = [call_node, this](Expr extern_func, - StructInfo ret_struct_info) { + auto create_call_dps_packed = [call_node, this](Expr extern_func, Type ret_ty) { ffi::Array new_args({extern_func}); new_args.push_back(Tuple(call_node->args.Map([this](Expr arg) { return VisitExpr(arg); }))); static const Op& call_op = Op::Get("relax.call_dps_packed"); - return Call(call_op, new_args, tvm::Attrs(), {ret_struct_info}); + return Call(call_op, new_args, tvm::Attrs(), {ret_ty}); }; - auto ret_sinfo = GetStructInfo(call); + auto ret_ty = GetType(call); if (auto it = extern_funcs_.find(gvar_node); it != extern_funcs_.end()) { - return create_call_dps_packed(it->second, ret_sinfo); + return create_call_dps_packed(it->second, ret_ty); } else if (auto opt_func = builder_->GetContextIRModule()->Lookup(gvar).as()) { // TODO(@sunggg): Is there any better way to get this func? Function func = opt_func.value(); @@ -137,7 +136,7 @@ class CodeGenRunner : ExprMutator { func = (*RemoveFuncAttrFunc)(func, tvm::attr::kGlobalSymbol).cast(); func = (*RemoveFuncAttrFunc)(func, attr::kCodegen).cast(); builder_->UpdateFunction(gvar, func); - return create_call_dps_packed(new_func, ret_sinfo); + return create_call_dps_packed(new_func, ret_ty); } } } @@ -146,7 +145,7 @@ class CodeGenRunner : ExprMutator { new_args.push_back(VisitExpr(arg)); } - return Call(call_node->op, new_args, call_node->attrs, call_node->sinfo_args, call_node->span); + return Call(call_node->op, new_args, call_node->attrs, call_node->ty_args, call_node->span); } Expr VisitExpr_(const FunctionNode* func_node) override { diff --git a/src/relax/transform/specialize_primfunc_based_on_callsite.cc b/src/relax/transform/specialize_primfunc_based_on_callsite.cc index 456391b033d6..e540b9bcd634 100644 --- a/src/relax/transform/specialize_primfunc_based_on_callsite.cc +++ b/src/relax/transform/specialize_primfunc_based_on_callsite.cc @@ -39,8 +39,8 @@ namespace relax { using tvm::tirx::Buffer; -static ffi::Array GetShapeFromTensorStructInfo(const TensorStructInfo& tensor_sinfo) { - auto shape = tensor_sinfo->GetShape(); +static ffi::Array GetShapeFromTensorType(const TensorType& tensor_ty) { + auto shape = tensor_ty->GetShape(); TVM_FFI_ICHECK(shape.defined()); return shape.value(); } @@ -83,14 +83,14 @@ class SpecializeTIRCallArgs : ExprMutator { ffi::Map> param_map; for (size_t i = 0; i < args.size(); ++i) { - auto sinfo = GetStructInfo(args[i]); - TVM_FFI_ICHECK(sinfo->IsInstance()) + auto ty = GetType(args[i]); + TVM_FFI_ICHECK(ty->IsInstance()) << "Expected Tensor struct Info for call :" << call->op; - auto tensor_sinfo = Downcast(sinfo); - TVM_FFI_ICHECK(tensor_sinfo->shape.defined()) << "Shape undefined for call:" << call->args[0]; + auto tensor_ty = Downcast(ty); + TVM_FFI_ICHECK(tensor_ty->shape.defined()) << "Shape undefined for call:" << call->args[0]; ffi::String scope = "global"; - if (tensor_sinfo->vdevice.defined()) { - scope = tensor_sinfo->vdevice.value()->memory_scope; + if (tensor_ty->vdevice.defined()) { + scope = tensor_ty->vdevice.value()->memory_scope; } ffi::String name; if (args[i]->IsInstance()) { @@ -99,40 +99,40 @@ class SpecializeTIRCallArgs : ExprMutator { name = std::string({static_cast('A' + i)}); } - const Buffer& buffer = tirx::decl_buffer(GetShapeFromTensorStructInfo(tensor_sinfo), - tensor_sinfo->dtype, name, scope); + const Buffer& buffer = + tirx::decl_buffer(GetShapeFromTensorType(tensor_ty), tensor_ty->dtype, name, scope); param_map.Set(pfunc->params[i], buffer); } ffi::String scope = "global"; - auto out_sinfo = call->sinfo_args[0]; - if (out_sinfo->IsInstance()) { - auto sinfo = Downcast(out_sinfo); - if (sinfo->vdevice.defined()) { - scope = sinfo->vdevice.value()->memory_scope; + auto out_ty = call->ty_args[0]; + if (out_ty->IsInstance()) { + auto ty = Downcast(out_ty); + if (ty->vdevice.defined()) { + scope = ty->vdevice.value()->memory_scope; } const Buffer& buffer = - tirx::decl_buffer(GetShapeFromTensorStructInfo(sinfo), sinfo->dtype, "ret_val", scope); + tirx::decl_buffer(GetShapeFromTensorType(ty), ty->dtype, "ret_val", scope); param_map.Set(pfunc->params[pfunc->params.size() - 1], buffer); } else { - TVM_FFI_ICHECK(out_sinfo->IsInstance()) - << "Expect output struct info of call_tir to be either TupleStructInfo or " - "TensorStructInfo, but got " - << out_sinfo; + TVM_FFI_ICHECK(out_ty->IsInstance()) + << "Expect output type of call_tir to be either TupleType or " + "TensorType, but got " + << out_ty; - const auto& tuple_sinfo = Downcast(out_sinfo); - ffi::Array sinfo_fields; + const auto& tuple_ty = Downcast(out_ty); + ffi::Array ty_fields; int index = 0; - for (const auto& si : tuple_sinfo->fields) { - TVM_FFI_ICHECK(si->IsInstance()) - << "Fields of TupleStructInfo must be TensorStructInfo for call_tir " + for (const auto& si : tuple_ty->fields) { + TVM_FFI_ICHECK(si->IsInstance()) + << "Fields of TupleType must be TensorType for call_tir " "output structinfo, but got " << si; - auto sinfo = Downcast(si); - if (sinfo->vdevice.defined()) { - scope = sinfo->vdevice.value()->memory_scope; + auto ty = Downcast(si); + if (ty->vdevice.defined()) { + scope = ty->vdevice.value()->memory_scope; } - const Buffer& buffer = tirx::decl_buffer(GetShapeFromTensorStructInfo(sinfo), sinfo->dtype, + const Buffer& buffer = tirx::decl_buffer(GetShapeFromTensorType(ty), ty->dtype, "ret_val_" + std::to_string(index), scope); param_map.Set(pfunc->params[args.size() + index], buffer); index++; diff --git a/src/relax/transform/split_call_tir_by_pattern.cc b/src/relax/transform/split_call_tir_by_pattern.cc index b73faa39007e..7a80ebc7b191 100644 --- a/src/relax/transform/split_call_tir_by_pattern.cc +++ b/src/relax/transform/split_call_tir_by_pattern.cc @@ -754,7 +754,7 @@ class SplitMutator : public ExprMutator { tirx::Buffer intermediate_buffer = func1->buffer_map.at(func1->params.back()); DataType dtype = intermediate_buffer->dtype; Call call1(call_dps_packed_, {lib_func, Tuple(args1)}, call->attrs, - {TensorStructInfo(ShapeExpr(intermediate_buffer->shape), dtype)}); + {TensorType(ShapeExpr(intermediate_buffer->shape), dtype)}); Var call_var1 = builder_->Emit(call1); // emit the second call to the rest of the function ffi::Array args2; @@ -763,7 +763,7 @@ class SplitMutator : public ExprMutator { args2.push_back(GetCallTIRArgs(call->args[1])[p]); } GlobalVar gv2 = builder_->AddFunction(func2, "unfused_epilogue"); - Call call2(call_tir_op_, {gv2, Tuple(args2)}, call->attrs, call->sinfo_args); + Call call2(call_tir_op_, {gv2, Tuple(args2)}, call->attrs, call->ty_args); builder_->UpdateFunction(gv, WithoutAttr(func, "global_symbol")); return call2; } diff --git a/src/relax/transform/split_layout_rewrite_preproc.cc b/src/relax/transform/split_layout_rewrite_preproc.cc index c4103429c30b..c04aecdfeefc 100644 --- a/src/relax/transform/split_layout_rewrite_preproc.cc +++ b/src/relax/transform/split_layout_rewrite_preproc.cc @@ -299,7 +299,7 @@ class SplitLayoutRewritePreproc : public ExprMutator { // Step 5: Emit the preproc call ffi::Array call_tir_args = Downcast(call->args[1])->fields; ffi::Array preproc_args; - ffi::Array preproc_sinfo_list; + ffi::Array preproc_ty_list; for (const auto& info : rewrite_infos) { preproc_args.push_back(call_tir_args[info.buffer_index]); tirx::Buffer rewritten_buffer = info.post_rewrite_buffer; @@ -308,16 +308,16 @@ class SplitLayoutRewritePreproc : public ExprMutator { << "Currently does not support rewrite buffer with " "dynamic shape."; } - preproc_sinfo_list.push_back( - TensorStructInfo(ShapeExpr(rewritten_buffer->shape), rewritten_buffer->dtype)); + preproc_ty_list.push_back( + TensorType(ShapeExpr(rewritten_buffer->shape), rewritten_buffer->dtype)); } - StructInfo preproc_sinfo = preproc_sinfo_list.size() > 1 // - ? TupleStructInfo(preproc_sinfo_list) // - : preproc_sinfo_list[0]; + Type preproc_ty = preproc_ty_list.size() > 1 // + ? TupleType(preproc_ty_list) // + : preproc_ty_list[0]; // Step 6: Call the preproc function Expr preproc_call = - builder_->Emit(Call(call_tir_op, {preproc_gv, Tuple(preproc_args)}, {}, {preproc_sinfo})); + builder_->Emit(Call(call_tir_op, {preproc_gv, Tuple(preproc_args)}, {}, {preproc_ty})); if (rewrite_infos.size() == 1) { call_tir_args.Set(rewrite_infos[0].buffer_index, preproc_call); } else { @@ -326,7 +326,7 @@ class SplitLayoutRewritePreproc : public ExprMutator { } } Expr main_call = - builder_->Emit(Call(call_tir_op, {compute_gv, Tuple(call_tir_args)}, {}, call->sinfo_args)); + builder_->Emit(Call(call_tir_op, {compute_gv, Tuple(call_tir_args)}, {}, call->ty_args)); return main_call; } diff --git a/src/relax/transform/static_plan_block_memory.cc b/src/relax/transform/static_plan_block_memory.cc index b69dce0155f0..95cf570f9681 100644 --- a/src/relax/transform/static_plan_block_memory.cc +++ b/src/relax/transform/static_plan_block_memory.cc @@ -436,7 +436,7 @@ void SetTIRVarRangeConstraints(Function func, arith::AnalyzerObj* ana, for (const ffi::String& var_name : non_negative_var_attr_raw) { non_negative_var_attr.insert(var_name); } - ffi::Array var_in_signature = TIRVarsInStructInfo(GetStructInfo(func)); + ffi::Array var_in_signature = TIRVarsInType(GetType(func)); for (const tirx::Var& tir_var : var_in_signature) { auto it_upper = var_upper_bound_attr.find(tir_var->name_hint); auto it_lower = var_lower_bound_attr.find(tir_var->name_hint); @@ -631,12 +631,12 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor { // - the shape of the tensor is known, in the form of ShapeExpr; // - the tensor has known dtype; // - no storage token was created for this call before. - const auto* sinfo = call->struct_info_.as(); - TVM_FFI_ICHECK_NOTNULL(sinfo); - const auto* shape = sinfo->shape.as(); + const auto* ty = call->ty.as(); + TVM_FFI_ICHECK_NOTNULL(ty); + const auto* shape = ty->shape.as(); TVM_FFI_ICHECK_NOTNULL(shape); - TVM_FFI_ICHECK(!sinfo->IsUnknownDtype()); - TVM_FFI_ICHECK(sinfo->dtype == Downcast(call->args[1])->value); + TVM_FFI_ICHECK(!ty->IsUnknownDtype()); + TVM_FFI_ICHECK(ty->dtype == Downcast(call->args[1])->value); TVM_FFI_ICHECK(!token_map_.count(call)); // Use the upper bounds of TIR vars as their values. The upper bound shape can still be dynamic @@ -653,7 +653,7 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor { } ffi::Optional vdevice = GetGlobalVDevice(ctx_mod_, vdevice_index); - StorageToken token(upper_bounded_shape, sinfo->dtype, storage_scope->value, vdevice); + StorageToken token(upper_bounded_shape, ty->dtype, storage_scope->value, vdevice); Tokens tokens(token); SetTokens(call, tokens); @@ -925,9 +925,9 @@ class StorageAllocationRewriter : public ExprMutator { if (it != alloc_tensor2token_.end()) { // Case 1. This `alloc_tensor` is planned for memory reuse. TVM_FFI_ICHECK_EQ(call->op, alloc_tensor_op); - const auto* sinfo = call->struct_info_.as(); - TVM_FFI_ICHECK_NOTNULL(sinfo); - TVM_FFI_ICHECK_NOTNULL(sinfo->shape.as()); + const auto* ty = call->ty.as(); + TVM_FFI_ICHECK_NOTNULL(ty); + TVM_FFI_ICHECK_NOTNULL(ty->shape.as()); PrimValue runtime_device_index = Downcast(call->args[2]); // If the token is visited for the first time, create a storage variable using @@ -951,9 +951,9 @@ class StorageAllocationRewriter : public ExprMutator { // And always create a `memory.alloc_tensor` for the old `builtin.alloc_tensor`. PrimValue offset = PrimValue::Int64(0); - DataType dtype = sinfo->dtype; + DataType dtype = ty->dtype; return Call(mem_alloc_tensor, - {storage_var, offset, sinfo->shape.value(), DataTypeImm(dtype), call->args[2]}, + {storage_var, offset, ty->shape.value(), DataTypeImm(dtype), call->args[2]}, Attrs()); } else if (plan_dynamic_output_ && call->op == alloc_tensor_op) { // Case 2. For a `alloc_tensor` that is not planned for memory reuse, @@ -962,30 +962,30 @@ class StorageAllocationRewriter : public ExprMutator { // estimation, we allocate a storage to its upper bound size, and // allocate a tensor out from it with the actual symbolic shape. - const auto* sinfo = call->struct_info_.as(); - TVM_FFI_ICHECK_NOTNULL(sinfo); - const auto* shape = sinfo->shape.as(); + const auto* ty = call->ty.as(); + TVM_FFI_ICHECK_NOTNULL(ty); + const auto* shape = ty->shape.as(); TVM_FFI_ICHECK_NOTNULL(shape); ffi::Array upper_bounded_shape = GetUpperBoundShape(shape->values, ana_.get(), dom_map_); if (!IsStaticShape(shape->values)) { - TVM_FFI_ICHECK(!sinfo->IsUnknownDtype()); - TVM_FFI_ICHECK_EQ(sinfo->dtype, Downcast(call->args[1])->value); + TVM_FFI_ICHECK(!ty->IsUnknownDtype()); + TVM_FFI_ICHECK_EQ(ty->dtype, Downcast(call->args[1])->value); PrimExpr bytes = upper_bounded_shape[0]; for (int i = 1; i < static_cast(upper_bounded_shape.size()); ++i) { bytes *= upper_bounded_shape[i]; } - bytes *= sinfo->dtype.bytes() * sinfo->dtype.lanes(); + bytes *= ty->dtype.bytes() * ty->dtype.lanes(); Call alloc_storage(mem_alloc_storage, {/*size=*/ShapeExpr({bytes}), /*virtual_device_index=*/Downcast(call->args[2]), /*storage_scope=*/Downcast(call->args[3]), // - /*dtype=*/DataTypeImm(sinfo->dtype)}); + /*dtype=*/DataTypeImm(ty->dtype)}); Var storage = builder_->Emit(alloc_storage, "storage"); return Call(mem_alloc_tensor, {storage, // /*offset=*/PrimValue::Int64(0), /*shape=*/ffi::GetRef(shape), // - /*dtype=*/DataTypeImm(sinfo->dtype), + /*dtype=*/DataTypeImm(ty->dtype), /*vdevice_index=*/call->args[2]}); } } diff --git a/src/relax/transform/to_mixed_precision.cc b/src/relax/transform/to_mixed_precision.cc index d9992cfe5901..ba49d42b308c 100644 --- a/src/relax/transform/to_mixed_precision.cc +++ b/src/relax/transform/to_mixed_precision.cc @@ -218,14 +218,13 @@ class DTypeDecisionCollector : public ExprVisitor { // require the i-th field rhs tuple to be the type of the lhs NType lhs_type = GetDType(binding->var); std::vector require_rhs; - const TupleStructInfoNode* sinfo = - tuple_get_item_node->tuple->struct_info_.as(); - TVM_FFI_ICHECK(sinfo != nullptr) << "TupleGetItemNode must have TupleStructInfo"; - for (size_t i = 0; i < sinfo->fields.size(); ++i) { + const TupleTypeNode* ty = tuple_get_item_node->tuple->ty.as(); + TVM_FFI_ICHECK(ty != nullptr) << "TupleGetItemNode must have TupleType"; + for (size_t i = 0; i < ty->fields.size(); ++i) { if (i == static_cast(tuple_get_item_node->index)) { require_rhs.push_back(lhs_type); } else { - require_rhs.push_back(NTypeFrom(sinfo->fields[i], unknown_)); + require_rhs.push_back(NTypeFrom(ty->fields[i], unknown_)); } } RequireArgsToType({tuple_get_item_node->tuple}, {NType(require_rhs)}); @@ -239,8 +238,8 @@ class DTypeDecisionCollector : public ExprVisitor { this->VisitBindingBlock(*it); } - if (auto* sinfo = op->struct_info_.as()) { - this->VisitExprDepStructInfoField(ffi::GetRef(sinfo)); + if (auto* ty = op->ty.as()) { + this->VisitExprDepTypeField(ffi::GetRef(ty)); } } @@ -258,8 +257,8 @@ class DTypeDecisionCollector : public ExprVisitor { this->VisitExpr(op->false_branch); this->VisitExpr(op->cond); - if (auto* sinfo = op->struct_info_.as()) { - this->VisitExprDepStructInfoField(ffi::GetRef(sinfo)); + if (auto* ty = op->ty.as()) { + this->VisitExprDepTypeField(ffi::GetRef(ty)); } } @@ -285,15 +284,14 @@ class ToMixedPrecisionRewriter : public ExprMutator { return it->second; } else { if (fp16_input_names_.count(var->name_hint())) { - auto sinfo = GetStructInfo(var); - if (auto tensor_sinfo = sinfo.as()) { + auto ty = GetType(var); + if (auto tensor_ty = ty.as()) { VDevice vdev = VDevice(); - if (tensor_sinfo->vdevice.defined()) { - vdev = tensor_sinfo->vdevice.value(); + if (tensor_ty->vdevice.defined()) { + vdev = tensor_ty->vdevice.value(); } - TensorStructInfo fp16_sinfo(tensor_sinfo->shape.value(), DataType::Float(16), vdev, - tensor_sinfo->span); - Var fp16_var(var->vid, fp16_sinfo, var->span); + TensorType fp16_ty(tensor_ty->shape.value(), DataType::Float(16), vdev, tensor_ty->span); + Var fp16_var(var->vid, fp16_ty, var->span); var_remap_[var->vid] = fp16_var; return fp16_var; } @@ -311,7 +309,7 @@ class ToMixedPrecisionRewriter : public ExprMutator { // Note that this function only accepts expr with nested tensor type Expr RewriteExpr(const Expr& expr, const NType& to) { auto fvisitleaf = [&](const Expr& expr, std::array to) -> Expr { - const auto* tensor = GetStructInfoAs(expr); + const auto* tensor = GetTypeAs(expr); TVM_FFI_ICHECK(tensor != nullptr) << "Only support rewriting tensor expr"; // We only rewrite the expr if the dtype is not the same as the given dtype if (NTypeEqual()(to[0], NTypeFrom(expr))) return expr; @@ -346,9 +344,9 @@ class ToMixedPrecisionRewriter : public ExprMutator { } bool AllFP16Castable(const ffi::Array& args) { - auto is_fp16 = [](StructInfo sinfo) { - if (auto tensor_sinfo = sinfo.as(); - tensor_sinfo && tensor_sinfo->dtype == DataType::Float(16)) { + auto is_fp16 = [](Type ty) { + if (auto tensor_ty = ty.as(); + tensor_ty && tensor_ty->dtype == DataType::Float(16)) { return true; } return false; @@ -391,11 +389,11 @@ class ToMixedPrecisionRewriter : public ExprMutator { }; for (const Expr& arg : args) { - auto sinfo = GetStructInfo(arg); + auto ty = GetType(arg); auto constant = arg.as(); auto tuple = arg.as(); - if (!IsNestedTensor(arg) || is_fp16(sinfo) || (constant && is_in_fp16_range(constant)) || + if (!IsNestedTensor(arg) || is_fp16(ty) || (constant && is_in_fp16_range(constant)) || (tuple && AllFP16Castable(tuple->fields))) { continue; } else { @@ -511,7 +509,7 @@ class ToMixedPrecisionRewriter : public ExprMutator { if (opt_new_dtype) { auto new_dtype = opt_new_dtype.value(); new_call.CopyOnWrite()->args = RewriteArgs(new_call->args, new_dtype); - new_call.CopyOnWrite()->struct_info_ = std::nullopt; + new_call.CopyOnWrite()->ty = Type(); new_value = builder_->Normalize(Call(new_call)); @@ -535,7 +533,7 @@ class ToMixedPrecisionRewriter : public ExprMutator { } ffi::ObjectPtr new_tuple = ffi::make_object(*tuple_node); new_tuple->fields = RemapArgs(tuple_node->fields); - new_tuple->struct_info_ = std::nullopt; + new_tuple->ty = Type(); Expr new_value = builder_->Normalize(Tuple(new_tuple)); if (!binding->var->IsInstance()) { // Global var: store the tensors to the original dtype @@ -555,7 +553,7 @@ class ToMixedPrecisionRewriter : public ExprMutator { ffi::ObjectPtr new_tuple_get_item = ffi::make_object(*tuple_get_item_node); new_tuple_get_item->tuple = RemapArgs({tuple_get_item_node->tuple})[0]; - new_tuple_get_item->struct_info_ = std::nullopt; + new_tuple_get_item->ty = Type(); Expr new_value = TupleGetItem(new_tuple_get_item); if (!binding->var->IsInstance()) { // Global var: store the tensors to the original dtype diff --git a/src/relax/transform/to_non_dataflow.cc b/src/relax/transform/to_non_dataflow.cc index f9fd6c12232a..0e341509a339 100644 --- a/src/relax/transform/to_non_dataflow.cc +++ b/src/relax/transform/to_non_dataflow.cc @@ -22,7 +22,6 @@ */ #include #include -#include #include #include #include @@ -35,7 +34,7 @@ class ToNonDFMutator : public ExprMutator { public: Var VisitVarDef(const Var& var) final { if (var.as()) { - Var new_var = Var(var->vid, GetStructInfo(var), var->span); + Var new_var = Var(var->vid, GetType(var), var->span); this->var_remap_[var->vid] = new_var; return new_var; } diff --git a/src/relax/transform/topological_sort.cc b/src/relax/transform/topological_sort.cc index 1c5f7461aa75..cf5154a0a61f 100644 --- a/src/relax/transform/topological_sort.cc +++ b/src/relax/transform/topological_sort.cc @@ -24,8 +24,8 @@ #include #include #include -#include #include +#include #include #include diff --git a/src/relax/transform/update_param_struct_info.cc b/src/relax/transform/update_param_type.cc similarity index 78% rename from src/relax/transform/update_param_struct_info.cc rename to src/relax/transform/update_param_type.cc index 031a552a00e7..7d8105de0e7a 100644 --- a/src/relax/transform/update_param_struct_info.cc +++ b/src/relax/transform/update_param_type.cc @@ -18,7 +18,7 @@ */ /*! - * \file tvm/relax/transform/update_param_struct_info.cc + * \file tvm/relax/transform/update_param_type.cc * \brief Mutate IRModule to accept new parameters */ @@ -38,10 +38,10 @@ namespace tvm { namespace relax { namespace { -class ParamStructInfoMutator : public ExprMutator { +class ParamTypeMutator : public ExprMutator { public: - explicit ParamStructInfoMutator(ffi::TypedFunction(Var)> sinfo_func) - : sinfo_func_(sinfo_func) {} + explicit ParamTypeMutator(ffi::TypedFunction(Var)> ty_func) + : ty_func_(ty_func) {} using ExprMutator::VisitExpr_; using ExprMutator::VisitVarDef_; @@ -50,8 +50,8 @@ class ParamStructInfoMutator : public ExprMutator { auto func = ffi::GetRef(op); auto params = op->params.Map([this](Var param) { - if (auto new_sinfo = sinfo_func_(param)) { - auto new_param = WithStructInfo(param, new_sinfo.value()); + if (auto new_ty = ty_func_(param)) { + auto new_param = WithType(param, new_ty.value()); var_remap_[param->vid] = new_param; return new_param; } else { @@ -65,14 +65,14 @@ class ParamStructInfoMutator : public ExprMutator { return ExprMutator::VisitExpr_(func.get()); } - ffi::TypedFunction(Var)> sinfo_func_; + ffi::TypedFunction(Var)> ty_func_; }; } // namespace namespace transform { -Pass UpdateParamStructInfo(ffi::TypedFunction(Var)> sinfo_func) { +Pass UpdateParamType(ffi::TypedFunction(Var)> ty_func) { auto pass_func = [=](IRModule mod, PassContext pc) { - ParamStructInfoMutator mutator(sinfo_func); + ParamTypeMutator mutator(ty_func); std::unordered_set to_remove; std::unordered_map to_add; @@ -82,7 +82,7 @@ Pass UpdateParamStructInfo(ffi::TypedFunction(Var)> si auto updated = Downcast(mutator(func.value())); if (!updated.same_as(base_func)) { GlobalVar new_gvar(gvar->name_hint); - UpdateStructInfo(new_gvar, GetStructInfo(updated)); + UpdateType(new_gvar, GetType(updated)); to_add.insert({new_gvar, updated}); to_remove.insert(gvar); } @@ -102,12 +102,12 @@ Pass UpdateParamStructInfo(ffi::TypedFunction(Var)> si return mod; }; - return tvm::transform::CreateModulePass(pass_func, 1, "UpdateParamStructInfo", {}); + return tvm::transform::CreateModulePass(pass_func, 1, "UpdateParamType", {}); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("relax.transform.UpdateParamStructInfo", UpdateParamStructInfo); + refl::GlobalDef().def("relax.transform.UpdateParamType", UpdateParamType); } } // namespace transform diff --git a/src/relax/transform/update_vdevice.cc b/src/relax/transform/update_vdevice.cc index a6cbb83b8c73..d0684e34f5f4 100644 --- a/src/relax/transform/update_vdevice.cc +++ b/src/relax/transform/update_vdevice.cc @@ -43,8 +43,8 @@ class VDeviceMutator : public ExprMutator { Expr VisitExpr(const Expr& expr) final { auto visited_expr = ExprMutator::VisitExpr(expr); - if (visited_expr->struct_info_.defined()) { - auto* tinfo = GetStructInfoAs(visited_expr); + if (visited_expr->ty.defined()) { + auto* tinfo = GetTypeAs(visited_expr); bool unchanged = true; if (tinfo != nullptr) { if (tinfo->vdevice.defined()) { @@ -56,11 +56,10 @@ class VDeviceMutator : public ExprMutator { } if (!unchanged) { if (tinfo->shape.defined()) { - visited_expr->struct_info_ = - TensorStructInfo(tinfo->shape.value(), tinfo->dtype, new_vdevice_, tinfo->span); + visited_expr->ty = + TensorType(tinfo->shape.value(), tinfo->dtype, new_vdevice_, tinfo->span); } else { - visited_expr->struct_info_ = - TensorStructInfo(tinfo->dtype, tinfo->ndim, new_vdevice_, tinfo->span); + visited_expr->ty = TensorType(tinfo->dtype, tinfo->ndim, new_vdevice_, tinfo->span); } } } diff --git a/src/relax/transform/utils.cc b/src/relax/transform/utils.cc index 32f080cb5124..db082c766962 100644 --- a/src/relax/transform/utils.cc +++ b/src/relax/transform/utils.cc @@ -24,34 +24,34 @@ namespace tvm { namespace relax { -bool IsScalarTensor(const StructInfo& sinfo) { - if (!sinfo->IsInstance()) { +bool IsScalarTensor(const Type& ty) { + if (!ty->IsInstance()) { return false; } - TensorStructInfo tensor_sinfo = Downcast(sinfo); - if (!tensor_sinfo->shape.defined() || !tensor_sinfo->shape->IsInstance()) { + TensorType tensor_ty = Downcast(ty); + if (!tensor_ty->shape.defined() || !tensor_ty->shape->IsInstance()) { return false; } - return tensor_sinfo->shape.as()->values.size() == 0; + return tensor_ty->shape.as()->values.size() == 0; } -bool IsScalarTensor(const Expr& expr) { return IsScalarTensor(GetStructInfo(expr)); } +bool IsScalarTensor(const Expr& expr) { return IsScalarTensor(GetType(expr)); } -bool IsNestedTensor(const StructInfo& sinfo) { - return IsNestedTensorConditioned(sinfo, [](const TensorStructInfo& sinfo) { return true; }); +bool IsNestedTensor(const Type& ty) { + return IsNestedTensorConditioned(ty, [](const TensorType& ty) { return true; }); } -bool IsNestedTensor(const Expr& expr) { return IsNestedTensor(GetStructInfo(expr)); } +bool IsNestedTensor(const Expr& expr) { return IsNestedTensor(GetType(expr)); } Function ComposeFunctions(Function func_a, Function func_b) { ffi::Array bindings; - Var func_a_output("func_a_output", func_a->ret_struct_info); + Var func_a_output("func_a_output", func_a->ret_ty); bindings.push_back(VarBinding(func_a_output, func_a->body)); auto func_a_outputs = [&]() -> ffi::Array { - if (auto func_a_output_tuple = func_a->ret_struct_info.as()) { + if (auto func_a_output_tuple = func_a->ret_ty.as()) { ffi::Array outputs; for (size_t i = 0; i < func_a_output_tuple->fields.size(); i++) { outputs.push_back(TupleGetItem(func_a_output, i)); @@ -62,12 +62,12 @@ Function ComposeFunctions(Function func_a, Function func_b) { } }(); - if (func_b->params.size() == 1 && func_b->params[0]->struct_info_.as()) { + if (func_b->params.size() == 1 && func_b->params[0]->ty.as()) { // Special case where the output of the first function is a tuple // that should be provided as-is to the second function, and // should not be unpacked into individual elements. auto param = func_b->params[0]; - bindings.push_back(MatchCast(param, func_a_output, GetStructInfo(param))); + bindings.push_back(MatchCast(param, func_a_output, GetType(param))); } else { TVM_FFI_CHECK_EQ(func_a_outputs.size(), func_b->params.size(), ValueError) << "Cannot compose functions together. " @@ -75,13 +75,13 @@ Function ComposeFunctions(Function func_a, Function func_b) { << "but second function expects " << func_b->params.size() << " parameters as input"; for (size_t i = 0; i < func_a_outputs.size(); i++) { auto param = func_b->params[i]; - bindings.push_back(MatchCast(param, func_a_outputs[i], GetStructInfo(param))); + bindings.push_back(MatchCast(param, func_a_outputs[i], GetType(param))); } } auto new_body = SeqExpr({BindingBlock(bindings)}, func_b->body); - auto new_function = Function(func_a->params, new_body, func_b->ret_struct_info, + auto new_function = Function(func_a->params, new_body, func_b->ret_ty, func_a->is_pure && func_b->is_pure, func_a->attrs); new_function = CopyWithNewVars(new_function); diff --git a/src/relax/transform/utils.h b/src/relax/transform/utils.h index dcd61df174d1..95c3fcd6f0b1 100644 --- a/src/relax/transform/utils.h +++ b/src/relax/transform/utils.h @@ -146,12 +146,12 @@ IRModule MakeGroupedFunctions( bool lift_constants = true, const ffi::Array& entry_function_names = {}); /*! - * \brief Check if the given StructInfo is a scalar tensor. The sinfo should be an instance of - * TensorStructInfo; its shape must be ShapeExpr. - * \param sinfo The StructInfo to be checked. - * \return true if the given StructInfo is a scalar tensor. + * \brief Check if the given Type is a scalar tensor. The ty should be an instance of + * TensorType; its shape must be ShapeExpr. + * \param ty The Type to be checked. + * \return true if the given Type is a scalar tensor. */ -bool IsScalarTensor(const StructInfo& sinfo); +bool IsScalarTensor(const Type& ty); /*! * \brief Check if the given expr is a scalar tensor. Now the shape of the tensor expr must be @@ -162,32 +162,32 @@ bool IsScalarTensor(const StructInfo& sinfo); bool IsScalarTensor(const Expr& expr); /*! - * \brief Check if the given StructInfo is a nested tensor StructInfo satisfying the given + * \brief Check if the given Type is a nested tensor Type satisfying the given * condition f_condition. - * \param sinfo The StructInfo to be checked. - * \param f_condition The condition function for each leaf StructInfo with signature - * `bool f_condition(TensorStructInfo)`. + * \param ty The Type to be checked. + * \param f_condition The condition function for each leaf Type with signature + * `bool f_condition(TensorType)`. * \tparam FType The condition function type. - * \return true if the given StructInfo is a nested tensor satisfying the given f_condition. + * \return true if the given Type is a nested tensor satisfying the given f_condition. */ template -bool IsNestedTensorConditioned(const StructInfo& sinfo, FType f_condition) { - if (const auto* tensor_sinfo = sinfo.as()) { - return f_condition(ffi::GetRef(tensor_sinfo)); - } else if (const auto* tuple_sinfo = sinfo.as()) { - return !std::any_of( - tuple_sinfo->fields.begin(), tuple_sinfo->fields.end(), - [&](const StructInfo& field) { return !IsNestedTensorConditioned(field, f_condition); }); +bool IsNestedTensorConditioned(const Type& ty, FType f_condition) { + if (const auto* tensor_ty = ty.as()) { + return f_condition(ffi::GetRef(tensor_ty)); + } else if (const auto* tuple_ty = ty.as()) { + return !std::any_of(tuple_ty->fields.begin(), tuple_ty->fields.end(), [&](const Type& field) { + return !IsNestedTensorConditioned(field, f_condition); + }); } return false; } /*! - * \brief Check if the given StructInfo is a nested tensor. - * \param sinfo The StructInfo to be checked. - * \return true if the given StructInfo is a nested tensor. + * \brief Check if the given Type is a nested tensor. + * \param ty The Type to be checked. + * \return true if the given Type is a nested tensor. */ -bool IsNestedTensor(const StructInfo& sinfo); +bool IsNestedTensor(const Type& ty); /*! * \brief Check if the given expr is a nested tensor. @@ -270,8 +270,8 @@ class SymbolicVarRenewMutator : public ExprMutator, tirx::ExprMutator { if (all_params_unchanged && body.same_as(op->body)) { return ffi::GetRef(op); } else { - auto new_ret_sinfo = this->VisitExprDepStructInfoField(op->ret_struct_info); - return Function(params, body, new_ret_sinfo, op->is_pure, op->attrs); + auto new_ret_ty = this->VisitExprDepTypeField(op->ret_ty); + return Function(params, body, new_ret_ty, op->is_pure, op->attrs); } } @@ -294,7 +294,7 @@ class FunctionCopier : public SymbolicVarRenewMutator { Var VisitVarDef_(const DataflowVarNode* var) override { Var new_var = SymbolicVarRenewMutator::VisitVarDef_(var); - Var copied_var = DataflowVar(new_var->name_hint(), GetStructInfo(new_var), new_var->span); + Var copied_var = DataflowVar(new_var->name_hint(), GetType(new_var), new_var->span); var_remap_[var->vid] = copied_var; relax_var_map_.Set(ffi::GetRef(var), copied_var); return copied_var; @@ -302,7 +302,7 @@ class FunctionCopier : public SymbolicVarRenewMutator { Var VisitVarDef_(const VarNode* var) override { Var new_var = SymbolicVarRenewMutator::VisitVarDef_(var); - Var copied_var = Var(new_var->name_hint(), GetStructInfo(new_var), new_var->span); + Var copied_var = Var(new_var->name_hint(), GetType(new_var), new_var->span); var_remap_[var->vid] = copied_var; relax_var_map_.Set(ffi::GetRef(var), copied_var); return copied_var; diff --git a/src/relax/utils.cc b/src/relax/utils.cc index 2155824bda38..d35b32ac58e9 100644 --- a/src/relax/utils.cc +++ b/src/relax/utils.cc @@ -57,13 +57,12 @@ class ExprBinder : public ExprMutator { Expr body = this->VisitWithNewScope(op->body, params); - // FuncStructInfo does not depend on Expr + // FuncType does not depend on Expr if (all_params_unchanged && body.same_as(op->body)) { return ffi::GetRef(op); } else { // purity won't be affected, no need to update annotation - return Function(params, body, VisitExprDepStructInfoField(op->ret_struct_info), op->is_pure, - op->attrs); + return Function(params, body, VisitExprDepTypeField(op->ret_ty), op->is_pure, op->attrs); } } @@ -103,9 +102,8 @@ Expr Bind(const Expr& expr, const tvm::ffi::Map& binds, return ExprBinder(binds, symbolic_var_map).VisitExpr(expr); } -StructInfo Bind(const StructInfo& sinfo, - const tvm::ffi::Map& symbolic_var_map) { - return ExprBinder({}, symbolic_var_map).VisitExprDepStructInfoField(sinfo); +Type Bind(const Type& ty, const tvm::ffi::Map& symbolic_var_map) { + return ExprBinder({}, symbolic_var_map).VisitExprDepTypeField(ty); } tvm::ffi::Map InferSymbolicVarMap( @@ -121,25 +119,24 @@ tvm::ffi::Map InferSymbolicVarMap( } }; - auto bind_from_prim_value = [&bind_from_prim_expr](const StructInfo& var, - const StructInfo& expr) { - auto var_sinfo = var.as(); - if (!var_sinfo) return; + auto bind_from_prim_value = [&bind_from_prim_expr](const Type& var, const Type& expr) { + auto var_ty = var.as(); + if (!var_ty) return; - auto expr_sinfo = expr.as(); - if (!expr_sinfo) return; + auto expr_ty = expr.as(); + if (!expr_ty) return; - if (!var_sinfo->value.defined() || !expr_sinfo->value.defined()) return; + if (!var_ty->value.defined() || !expr_ty->value.defined()) return; - bind_from_prim_expr(var_sinfo->value.value(), expr_sinfo->value.value()); + bind_from_prim_expr(var_ty->value.value(), expr_ty->value.value()); }; - auto bind_from_shape = [&bind_from_prim_expr](const StructInfo& var, const StructInfo& expr) { - auto var_shape = var.as(); + auto bind_from_shape = [&bind_from_prim_expr](const Type& var, const Type& expr) { + auto var_shape = var.as(); if (!var_shape) return; if (!var_shape->values.defined()) return; - auto expr_shape = expr.as(); + auto expr_shape = expr.as(); if (!expr_shape) return; if (!expr_shape->values.defined()) return; @@ -151,35 +148,34 @@ tvm::ffi::Map InferSymbolicVarMap( } }; - auto bind_from_tensor = [&bind_from_shape](const StructInfo& var, const StructInfo& expr) { - auto var_tensor = var.as(); + auto bind_from_tensor = [&bind_from_shape](const Type& var, const Type& expr) { + auto var_tensor = var.as(); if (!var_tensor) return; if (!var_tensor->shape.defined()) return; - auto expr_tensor = expr.as(); + auto expr_tensor = expr.as(); if (!expr_tensor) return; if (!expr_tensor->shape.defined()) return; - bind_from_shape(GetStructInfo(var_tensor->shape.value()), - GetStructInfo(expr_tensor->shape.value())); + bind_from_shape(GetType(var_tensor->shape.value()), GetType(expr_tensor->shape.value())); }; - std::function bind_from_struct_info = nullptr; - auto bind_from_tuple = [&bind_from_struct_info](const StructInfo& var, const StructInfo& expr) { - auto var_tuple = var.as(); + std::function bind_from_ty = nullptr; + auto bind_from_tuple = [&bind_from_ty](const Type& var, const Type& expr) { + auto var_tuple = var.as(); if (!var_tuple) return; - auto expr_tuple = expr.as(); + auto expr_tuple = expr.as(); if (!expr_tuple) return; if (var_tuple->fields.size() != expr_tuple->fields.size()) return; for (size_t i = 0; i < var_tuple->fields.size(); i++) { - bind_from_struct_info(var_tuple->fields[i], expr_tuple->fields[i]); + bind_from_ty(var_tuple->fields[i], expr_tuple->fields[i]); } }; - bind_from_struct_info = [&](const StructInfo& var, const StructInfo& expr) { + bind_from_ty = [&](const Type& var, const Type& expr) { bind_from_tensor(var, expr); bind_from_shape(var, expr); bind_from_prim_value(var, expr); @@ -187,23 +183,22 @@ tvm::ffi::Map InferSymbolicVarMap( }; for (const auto& [relax_var, relax_expr] : relax_var_remap) { - auto var_sinfo = GetStructInfo(relax_var); - auto expr_sinfo = GetStructInfo(relax_expr); - bind_from_struct_info(var_sinfo, expr_sinfo); + auto var_ty = GetType(relax_var); + auto expr_ty = GetType(relax_expr); + bind_from_ty(var_ty, expr_ty); } return tir_var_remap; } -bool IsBoolStructInfo(const StructInfo& sinfo, bool permit_unknown_rank, - bool permit_unknown_dtype) { +bool IsBoolType(const Type& ty, bool permit_unknown_rank, bool permit_unknown_dtype) { DataType dtype; int ndim; - if (const auto* tensor = sinfo.as()) { + if (const auto* tensor = ty.as()) { dtype = tensor->dtype; ndim = tensor->ndim; - } else if (const auto* prim = sinfo.as()) { + } else if (const auto* prim = ty.as()) { dtype = prim->dtype; ndim = 0; } else { @@ -228,9 +223,9 @@ bool IsImpureCall(const Call& call) { << "Cannot find the registered purity of this op: " << op->name; return !(purity_map[op]); } - // the StructInfo must be FuncStructInfo - auto func_struct_info = GetStructInfoAs(call->op); - return !func_struct_info->purity; + // the Type must be FuncType + auto func_ty = GetTypeAs(call->op); + return !func_ty->purity; } Expr GetBoundValue(const Binding& b) { diff --git a/src/runtime/vm/builtin.cc b/src/runtime/vm/builtin.cc index 322a0a137c17..8fc18c5c0722 100644 --- a/src/runtime/vm/builtin.cc +++ b/src/runtime/vm/builtin.cc @@ -43,7 +43,7 @@ namespace vm { using tvm::runtime::Tensor; //------------------------------------------------- -// Shape/StructInfo handling. +// Shape/Type handling. //------------------------------------------------- /*! * \brief Builtin function to allocate shape heap. diff --git a/src/script/ir_builder/ir/ir.cc b/src/script/ir_builder/ir/ir.cc index 6183630da465..1edae6f874c2 100644 --- a/src/script/ir_builder/ir/ir.cc +++ b/src/script/ir_builder/ir/ir.cc @@ -38,20 +38,23 @@ IRModuleFrame IRModule() { // DeclFunction lives at the IR layer because an IRModule may host // heterogeneous function kinds (e.g. relax::Function, tirx::PrimFunc). -// To derive the GlobalVar's struct_info_ without coupling the IR layer to +// To derive the GlobalVar's ty without coupling the IR layer to // any specific dialect, dispatch is keyed by the function's type-key: // each dialect registers its own handler that maps a function of that -// type to the appropriate struct_info. -inline ffi::Optional GetGlobalVarStructInfo(const BaseFunc& func) { - if (func->struct_info_.defined()) { - return func->struct_info_; +// type to the appropriate ty. +inline ffi::Optional GetGlobalVarType(const BaseFunc& func) { + if (func->ty.defined()) { + return func->ty; } // Registry: "script.ir_builder.decl_function." — per-function-kind - // handler that derives the GlobalVar struct_info from the function signature. + // handler that derives the GlobalVar ty from the function signature. // Grep hint: grep -rn 'script.ir_builder.decl_function.' src/ const std::string key = "script.ir_builder.decl_function." + func->GetTypeKey(); if (auto fn = tvm::ffi::Function::GetGlobal(key)) { - return (*fn)(func).cast>(); + ffi::Optional result = (*fn)(func).cast>(); + if (result.defined()) { + return Downcast(result.value()); + } } return std::nullopt; } @@ -62,8 +65,8 @@ GlobalVar DeclFunction(const ffi::String& func_name, const BaseFunc& func_signat << "function " << func_name << " already exists"; GlobalVar gv = GlobalVar(func_name); - if (auto sinfo = GetGlobalVarStructInfo(func_signature)) { - gv->struct_info_ = sinfo.value(); + if (auto ty = GetGlobalVarType(func_signature)) { + gv->ty = ty.value(); } else { TVM_FFI_THROW(InternalError) << "Unsupported function type: " << func_signature->GetTypeKey(); } @@ -81,8 +84,8 @@ void DefFunction(const ffi::String& func_name, const BaseFunc& func) { << "function " << func_name << " does not exist, please declare it first."; const GlobalVar& gv = (*it).second; frame->functions.Set(gv, func); - if (auto sinfo = GetGlobalVarStructInfo(func)) { - gv->struct_info_ = sinfo.value(); + if (auto ty = GetGlobalVarType(func)) { + gv->ty = ty.value(); } else { TVM_FFI_THROW(InternalError) << "Unsupported function type: " << func->GetTypeKey(); } diff --git a/src/script/printer/script_printer.cc b/src/script/printer/script_printer.cc index d595898c919e..b8a7b6cefc10 100644 --- a/src/script/printer/script_printer.cc +++ b/src/script/printer/script_printer.cc @@ -123,8 +123,8 @@ PrinterConfig::PrinterConfig(ffi::Map config_dict) { } } // Boolean dialect keys. - if (auto v = config_dict.Get("relax.show_all_struct_info")) { - n->extra_config.Set(ffi::String("relax.show_all_struct_info"), v.value()); + if (auto v = config_dict.Get("relax.show_all_ty")) { + n->extra_config.Set(ffi::String("relax.show_all_ty"), v.value()); } if (auto v = config_dict.Get("extra_config")) { auto extra = Downcast>(v.value()); diff --git a/src/tirx/ir/function.cc b/src/tirx/ir/function.cc index 273ed1ae3c99..c44c2799801c 100644 --- a/src/tirx/ir/function.cc +++ b/src/tirx/ir/function.cc @@ -23,7 +23,8 @@ */ #include #include -#include +#include +#include #include #include #include @@ -37,40 +38,40 @@ TVM_FFI_STATIC_INIT_BLOCK() { } namespace { -relax::StructInfo InferStructInfo(const PrimFunc& prim_func) { - ffi::Array params; +tvm::Type InferType(const PrimFunc& prim_func) { + ffi::Array params; for (const auto& param : prim_func->params) { - relax::StructInfo param_sinfo = [&]() -> relax::StructInfo { + tvm::Type param_ty = [&]() -> tvm::Type { if (auto opt_buf = prim_func->buffer_map.Get(param)) { auto buf = opt_buf.value(); relax::ShapeExpr shape( buf->shape.Map([](PrimExpr dim) { return cast(DataType::Int(64), dim); })); - return relax::TensorStructInfo(shape, buf->dtype); + return relax::TensorType(shape, buf->dtype); } if (auto prim_type = param->type_annotation.as(); prim_type && prim_type->dtype.is_handle()) { - return relax::ObjectStructInfo(); + return relax::ObjectType(); } - return relax::PrimStructInfo(param->dtype); + return relax::PrimType(param->dtype); }(); - params.push_back(param_sinfo); + params.push_back(param_ty); } - relax::StructInfo ret = [&]() -> relax::StructInfo { + tvm::Type ret = [&]() -> tvm::Type { if (const auto* prim = prim_func->ret_type.as()) { - return relax::PrimStructInfo(prim->dtype); + return relax::PrimType(prim->dtype); } else if (IsVoidType(prim_func->ret_type)) { - return relax::TupleStructInfo(ffi::Array{}); + return relax::TupleType(ffi::Array{}); } else { - return relax::ObjectStructInfo(); + return relax::ObjectType(); } }(); bool purity = prim_func->body.defined() ? s_tir::IsPureFunction(prim_func) : false; - return relax::FuncStructInfo(params, ret, purity); + return relax::FuncType(params, ret, purity); } } // namespace @@ -87,11 +88,11 @@ PrimFunc::PrimFunc(ffi::Array params, Stmt body, Type ret_type, n->ret_type = std::move(ret_type); n->buffer_map = std::move(buffer_map); n->attrs = std::move(attrs); - n->struct_info_ = relax::FuncStructInfo::OpaqueFunc(); + n->ty = relax::FuncType::OpaqueFunc(); n->span = std::move(span); data_ = std::move(n); - (*this)->struct_info_ = InferStructInfo(*this); + (*this)->ty = InferType(*this); } FuncType PrimFuncNode::func_type_annotation() const { diff --git a/src/tirx/script/builder/ir.cc b/src/tirx/script/builder/ir.cc index e70cdf09dabb..a75025a0ddd1 100644 --- a/src/tirx/script/builder/ir.cc +++ b/src/tirx/script/builder/ir.cc @@ -23,7 +23,7 @@ #include #include #include -#include +#include #include #include #include diff --git a/tests/cpp/nested_msg_test.cc b/tests/cpp/nested_msg_test.cc index 0d5fb48ad94f..31a833b465a7 100644 --- a/tests/cpp/nested_msg_test.cc +++ b/tests/cpp/nested_msg_test.cc @@ -20,7 +20,7 @@ #include #include #include -#include +#include #include #include #include @@ -144,9 +144,9 @@ TEST(NestedMsg, Equal) { } TEST(NestedMsg, MapAndDecompose) { - relax::Var x("x", PrimStructInfo(runtime::DataType::Int(16))); - relax::Var y("y", PrimStructInfo(runtime::DataType::Int(32))); - relax::Var z("z", PrimStructInfo(runtime::DataType::Int(64))); + relax::Var x("x", PrimType(runtime::DataType::Int(16))); + relax::Var y("y", PrimType(runtime::DataType::Int(32))); + relax::Var z("z", PrimType(runtime::DataType::Int(64))); BlockBuilder bb = BlockBuilder::Create(std::nullopt); relax::Expr t0 = bb->Normalize(Tuple({x, y})); @@ -167,16 +167,15 @@ TEST(NestedMsg, MapAndDecompose) { EXPECT_TRUE(Equal(output, expected, [](IntImm lhs, IntImm rhs) -> bool { return lhs->value == rhs->value; })); - auto output2 = - MapToNestedMsg(GetStructInfo(t1), [&](StructInfo sinfo) -> NestedMsg { - const auto* prim_sinfo = sinfo.as(); - if (prim_sinfo == nullptr) return std::nullopt; - int bits = prim_sinfo->dtype.bits(); - if (bits == 16) return c0; - if (bits == 32) return c1; - if (bits == 64) return c2; - return std::nullopt; - }); + auto output2 = MapToNestedMsg(GetType(t1), [&](Type ty) -> NestedMsg { + const auto* prim_ty = ty.as(); + if (prim_ty == nullptr) return std::nullopt; + int bits = prim_ty->dtype.bits(); + if (bits == 16) return c0; + if (bits == 32) return c1; + if (bits == 64) return c2; + return std::nullopt; + }); EXPECT_TRUE(Equal(output2, expected, [](IntImm lhs, IntImm rhs) -> bool { return lhs->value == rhs->value; })); @@ -200,13 +199,13 @@ TEST(NestedMsg, MapAndDecompose) { EXPECT_EQ(z_count, 1); } -TEST(NestedMsg, MapToNestedMsgBySInfo) { - auto sf0 = TensorStructInfo(DataType::Float(32), /*ndim=*/0); - auto sf1 = TupleStructInfo({sf0, sf0}); - auto sf2 = TupleStructInfo({sf0, sf0}); - auto x = relax::Var("x", TupleStructInfo({sf1, sf2, sf0})); +TEST(NestedMsg, MapToNestedMsgByType) { + auto sf0 = TensorType(DataType::Float(32), /*ndim=*/0); + auto sf1 = TupleType({sf0, sf0}); + auto sf2 = TupleType({sf0, sf0}); + auto x = relax::Var("x", TupleType({sf1, sf2, sf0})); - auto msg = MapToNestedMsgBySInfo(x, [](Expr value) { return value; }); + auto msg = MapToNestedMsgByType(x, [](Expr value) { return value; }); EXPECT_TRUE(msg.IsNested()); auto arr = msg.NestedArray(); @@ -223,8 +222,8 @@ TEST(NestedMsg, MapToNestedMsgBySInfo) { } TEST(NestedMsg, NestedMsgToExpr) { - auto sf0 = TensorStructInfo(DataType::Float(32), /*ndim=*/0); - auto sf1 = TupleStructInfo({sf0, sf0}); + auto sf0 = TensorType(DataType::Float(32), /*ndim=*/0); + auto sf1 = TupleType({sf0, sf0}); auto c0 = IntImm::Int32(0); auto c1 = IntImm::Int32(1); @@ -306,7 +305,7 @@ TEST(NestedMsg, TransformTupleLeaf) { NInt msg1 = {c0, {c0, c1}, c2, {c0, {c1, c2}}}; NInt msg2 = {c1, {c2, c0}, c2, {c1, {c2, c0}}}; - PrimStructInfo s = PrimStructInfo(runtime::DataType::Int(32)); + PrimType s = PrimType(runtime::DataType::Int(32)); relax::Var x("x", s), y("y", s), z("z", s); BlockBuilder bb = BlockBuilder::Create(std::nullopt); Expr expr = bb->Normalize(Tuple({x, Tuple({x, x}), x, Tuple({x, Tuple({x, x})})})); diff --git a/tests/python/codegen/test_target_codegen_cuda_fp8.py b/tests/python/codegen/test_target_codegen_cuda_fp8.py index cd762569af1f..368833f9db53 100644 --- a/tests/python/codegen/test_target_codegen_cuda_fp8.py +++ b/tests/python/codegen/test_target_codegen_cuda_fp8.py @@ -356,7 +356,7 @@ def create_quantize_func( assert NotImplementedError() bb = relax.BlockBuilder() # pylint: disable=invalid-name - weight_var = relax.Var("weight", relax.TensorStructInfo(weight_shape, model_dtype)) + weight_var = relax.Var("weight", relax.TensorType(weight_shape, model_dtype)) compute_scale, compute_quantize, compute_transpose = quantize_func( weight_shape, model_dtype, @@ -401,9 +401,9 @@ def create_dequantize_func( bb = relax.BlockBuilder() # pylint: disable=invalid-name packed_weight_var = relax.Var( - "weight", relax.TensorStructInfo(packed_weight_shape, storage_dtype) + "weight", relax.TensorType(packed_weight_shape, storage_dtype) ) - scale_var = relax.Var("scale", relax.TensorStructInfo(scale_shape, model_dtype)) + scale_var = relax.Var("scale", relax.TensorType(scale_shape, model_dtype)) compute_dequantize = dequantize_func( packed_weight_shape, scale_shape, @@ -488,9 +488,7 @@ def compute_quantize_weight(bb: relax.BlockBuilder, args: relax.expr.Expr): global_var = bb.add_func(quant, "quantized_weight") lv_quantized_weight = bb.emit( - relax.call_tir( - global_var, args, relax.TensorStructInfo(packed_shape, storage_dtype) - ) + relax.call_tir(global_var, args, relax.TensorType(packed_shape, storage_dtype)) ) return lv_quantized_weight @@ -539,7 +537,7 @@ def compute_dequantize_weight(bb: relax.BlockBuilder, args: relax.expr.Expr): global_var = bb.add_func(dequant, "dequantize_weight") lv_dequantized_weight = bb.emit( - relax.call_tir(global_var, args, relax.TensorStructInfo(dequant_shape, model_dtype)) + relax.call_tir(global_var, args, relax.TensorType(dequant_shape, model_dtype)) ) return lv_dequantized_weight @@ -927,7 +925,7 @@ def main( lv = R.call_tir( cls.moe_dequantize_gemv, (x, weight, astype, indptr), - out_sinfo=R.Tensor((2, spatial_size), dtype="float16"), + out_ty=R.Tensor((2, spatial_size), dtype="float16"), ) gv: R.Tensor((2, spatial_size), dtype="float16") = lv R.output(gv) diff --git a/tests/python/contrib/test_tir_triton_integration.py b/tests/python/contrib/test_tir_triton_integration.py index c424fe8b591a..91794989bb65 100644 --- a/tests/python/contrib/test_tir_triton_integration.py +++ b/tests/python/contrib/test_tir_triton_integration.py @@ -89,7 +89,7 @@ def add(x_handle: T.handle, y_handle: T.handle, output_handle: T.handle) -> None def main(x: R.Tensor(("m",), "float32"), y: R.Tensor(("m",), "float32")): m = T.int64() with R.dataflow(): - output = R.call_tir(Module.add, [x, y], relax.TensorStructInfo((m,), "float32")) + output = R.call_tir(Module.add, [x, y], relax.TensorType((m,), "float32")) R.output(output) return output diff --git a/tests/python/disco/test_loader.py b/tests/python/disco/test_loader.py index 2b379dc71137..62ffb4ce5817 100644 --- a/tests/python/disco/test_loader.py +++ b/tests/python/disco/test_loader.py @@ -268,13 +268,13 @@ def main( "runtime.disco.ShardLoaderLoad", loader, R.shape([0]), - sinfo_args=R.Tensor((64, 64), "float32"), + ty_args=R.Tensor((64, 64), "float32"), ) lv1: R.Tensor((16, 128), "float32") = R.call_pure_packed( "runtime.disco.ShardLoaderLoad", loader, R.shape([1]), - sinfo_args=R.Tensor((16, 128), "float32"), + ty_args=R.Tensor((16, 128), "float32"), ) lv2 = R.tuple(lv0, lv1) R.output(lv2) diff --git a/tests/python/disco/test_nvshmem.py b/tests/python/disco/test_nvshmem.py index cfa7755915b0..64ab378a79e1 100644 --- a/tests/python/disco/test_nvshmem.py +++ b/tests/python/disco/test_nvshmem.py @@ -322,7 +322,7 @@ def main() -> R.Tuple(R.Tensor((1,), "int32"), R.Tensor((1,), "int32")): my_pe = R.call_tir( cls.query_pe, (), - out_sinfo=[ + out_ty=[ R.Tensor((1,), "int32"), R.Tensor((1,), "int32"), ], diff --git a/tests/python/disco/test_session.py b/tests/python/disco/test_session.py index 1f482b9ee5fc..2e84170940e1 100644 --- a/tests/python/disco/test_session.py +++ b/tests/python/disco/test_session.py @@ -241,7 +241,7 @@ def transpose(A: T.Buffer((8, 16), "float32"), B: T.Buffer((16, 8), "float32")): def main(A: R.Tensor((8, 16), dtype="float32")) -> R.Tensor((16, 8), dtype="float32"): cls = TestMod with R.dataflow(): - B = R.call_tir(cls.transpose, (A,), out_sinfo=R.Tensor((16, 8), dtype="float32")) + B = R.call_tir(cls.transpose, (A,), out_ty=R.Tensor((16, 8), dtype="float32")) R.output(B) return B @@ -295,7 +295,7 @@ def transpose_1(A: R.Tensor((8, 16), dtype="float32")) -> R.Tensor( R.func_attr({"global_symbol": "transpose_1"}) cls = TestMod with R.dataflow(): - B = R.call_tir(cls.t1, (A,), out_sinfo=R.Tensor((16, 8), dtype="float32")) + B = R.call_tir(cls.t1, (A,), out_ty=R.Tensor((16, 8), dtype="float32")) R.output(B) return B @@ -306,7 +306,7 @@ def transpose_2(A: R.Tensor((16, 8), dtype="float32")) -> R.Tensor( R.func_attr({"global_symbol": "transpose_2"}) cls = TestMod with R.dataflow(): - B = R.call_tir(cls.t2, (A,), out_sinfo=R.Tensor((8, 16), dtype="float32")) + B = R.call_tir(cls.t2, (A,), out_ty=R.Tensor((8, 16), dtype="float32")) R.output(B) return B diff --git a/tests/python/nightly/test_nnapi/infrastructure.py b/tests/python/nightly/test_nnapi/infrastructure.py index bf4f07431ad2..f52c6552a058 100644 --- a/tests/python/nightly/test_nnapi/infrastructure.py +++ b/tests/python/nightly/test_nnapi/infrastructure.py @@ -37,10 +37,10 @@ def reshape_matmul(mod: tvm.IRModule): def _rewriter(expr: Expr, matches: dict[DFPattern, Expr]): i0 = matches[input0] i1 = matches[input1] - if len(i0.struct_info.shape) == 2 and len(i1.struct_info.shape) == 2: - i0_shape = [1] + [*i0.struct_info.shape.values] - i1_shape = [1] + [*i1.struct_info.shape.values] - oshape = matches[pattern].struct_info.shape + if len(i0.ty.shape) == 2 and len(i1.ty.shape) == 2: + i0_shape = [1] + [*i0.ty.shape.values] + i1_shape = [1] + [*i1.ty.shape.values] + oshape = matches[pattern].ty.shape return R.reshape(R.matmul(R.reshape(i0, i0_shape), R.reshape(i1, i1_shape)), oshape) return expr @@ -59,7 +59,7 @@ def decompose_clip(mod: tvm.IRModule) -> tvm.IRModule: pattern = is_op("relax.clip")(input_pattern, min_pattern, max_pattern) def _rewriter(expr: Expr, matches: dict[DFPattern, Expr]) -> Expr: # pylint: disable=unused-argument - dtype = matches[input_pattern].struct_info.dtype + dtype = matches[input_pattern].ty.dtype return R.minimum( R.maximum( matches[input_pattern], diff --git a/tests/python/relax/backend/adreno/mod_utils.py b/tests/python/relax/backend/adreno/mod_utils.py index 3568abf3a265..ab08434c29d2 100644 --- a/tests/python/relax/backend/adreno/mod_utils.py +++ b/tests/python/relax/backend/adreno/mod_utils.py @@ -740,7 +740,7 @@ def main( lv2 = relax.call_tir( cls.dequantize, (weight, scale), - out_sinfo=R.Tensor((K, N), dtype="float16"), + out_ty=R.Tensor((K, N), dtype="float16"), ) gv: R.Tensor((1, seq_len, N), dtype="float16") = relax.op.matmul( input, lv2, out_dtype="float16" @@ -798,7 +798,7 @@ def main( lv2 = relax.call_tir( cls.dequantize, (weight, scale), - out_sinfo=R.Tensor((K, vocab_size), dtype="float16"), + out_ty=R.Tensor((K, vocab_size), dtype="float16"), ) gv: R.Tensor((1, 1, vocab_size), dtype="float16") = relax.op.matmul( input, lv2, out_dtype="float16" diff --git a/tests/python/relax/backend/adreno/test_transform_annotate_custom_scope.py b/tests/python/relax/backend/adreno/test_transform_annotate_custom_scope.py index 8d364628fbf7..f3a2b629ee6e 100644 --- a/tests/python/relax/backend/adreno/test_transform_annotate_custom_scope.py +++ b/tests/python/relax/backend/adreno/test_transform_annotate_custom_scope.py @@ -44,31 +44,29 @@ def visit_call_(self, call: relax.Call) -> None: # pylint: disable=arguments-re if call.op.name == "relax.call_tir": # if call.args[0].name_hint in self.scope_info: for idx, arg in enumerate(call.args[1]): - arg_sinfo = arg.struct_info - assert isinstance(arg_sinfo, relax.TensorStructInfo), ( - f"Expected TensorStructInfo but git {type(arg_sinfo)}" - ) - call_mem_scope = ( - "global" if not arg_sinfo.vdevice else arg_sinfo.vdevice.memory_scope + arg_ty = arg.ty + assert isinstance(arg_ty, relax.TensorType), ( + f"Expected TensorType but git {type(arg_ty)}" ) + call_mem_scope = "global" if not arg_ty.vdevice else arg_ty.vdevice.memory_scope assert call_mem_scope == self.scope_info[call.args[0].name_hint][0][idx], ( f"Scope mismatched for argument {idx} in {call.args[0].name_hint}" ) - if isinstance(call.sinfo_args[0], relax.TensorStructInfo): + if isinstance(call.ty_args[0], relax.TensorType): call_mem_scope = ( "global" - if not call.sinfo_args[0].vdevice - else call.sinfo_args[0].vdevice.memory_scope + if not call.ty_args[0].vdevice + else call.ty_args[0].vdevice.memory_scope ) assert call_mem_scope == self.scope_info[call.args[0].name_hint][1][0], ( f"Scope mismatched for return scope: {call.args[0].name_hint}" ) else: - assert isinstance(call.sinfo_args[0], relax.TupleStructInfo), ( - f"Expected TupleStructInfo but git {type(call.sinfo_args[0])}" + assert isinstance(call.ty_args[0], relax.TupleType), ( + f"Expected TupleType but git {type(call.ty_args[0])}" ) - for idx, sinfo in enumerate(call.sinfo_args[0].fields): - call_mem_scope = "global" if not sinfo.vdevice else sinfo.vdevice.memory_scope + for idx, ty in enumerate(call.ty_args[0].fields): + call_mem_scope = "global" if not ty.vdevice else ty.vdevice.memory_scope assert call_mem_scope == self.scope_info[call.args[0].name_hint][1][idx], ( f"Scope mismatched for return scope for {idx} in {call.args[0].name_hint}" ) diff --git a/tests/python/relax/backend/adreno/test_transform_fold_vdevice_scope_change.py b/tests/python/relax/backend/adreno/test_transform_fold_vdevice_scope_change.py index 58bcdb58d0ba..929b29ff49ae 100644 --- a/tests/python/relax/backend/adreno/test_transform_fold_vdevice_scope_change.py +++ b/tests/python/relax/backend/adreno/test_transform_fold_vdevice_scope_change.py @@ -134,14 +134,14 @@ def main( lv = R.call_tir( cls.te_layout_transform, (x,), - out_sinfo=R.Tensor( + out_ty=R.Tensor( (2, 1, 26, 26, 4), dtype="float32", vdevice="opencl:0:global.texture-weight" ), ) lv2 = R.call_tir( cls.max_pool2d_opencl, (lv,), - out_sinfo=R.Tensor( + out_ty=R.Tensor( (2, 1, 13, 13, 4), dtype="float32", vdevice="opencl:0:global.texture-weight" ), ) @@ -151,7 +151,7 @@ def main( gv2 = R.call_tir( cls.te_layout_transform2, (lv5,), - out_sinfo=R.Tensor((2, 4, 13, 13), dtype="float32", vdevice="opencl:1:global"), + out_ty=R.Tensor((2, 4, 13, 13), dtype="float32", vdevice="opencl:1:global"), ) R.output(gv2) return gv2 @@ -259,21 +259,19 @@ def main( lv = R.call_tir( cls.te_layout_transform, (x,), - out_sinfo=R.Tensor( + out_ty=R.Tensor( (2, 1, 26, 26, 4), dtype="float32", vdevice="opencl:0:global.texture-weight" ), ) lv5 = R.call_tir( cls.max_pool2d_opencl, (lv,), - out_sinfo=R.Tensor( - (2, 1, 13, 13, 4), dtype="float32", vdevice="opencl:1:global" - ), + out_ty=R.Tensor((2, 1, 13, 13, 4), dtype="float32", vdevice="opencl:1:global"), ) gv2 = R.call_tir( cls.te_layout_transform2, (lv5,), - out_sinfo=R.Tensor((2, 4, 13, 13), dtype="float32", vdevice="opencl:1:global"), + out_ty=R.Tensor((2, 4, 13, 13), dtype="float32", vdevice="opencl:1:global"), ) R.output(gv2) return gv2 diff --git a/tests/python/relax/backend/adreno/utils.py b/tests/python/relax/backend/adreno/utils.py index f576c202cd65..e865339ae3a5 100644 --- a/tests/python/relax/backend/adreno/utils.py +++ b/tests/python/relax/backend/adreno/utils.py @@ -201,8 +201,8 @@ def verify_results(mod, target, ref_target): inputs = [] for arg in mod["main"].params: - shape = tuple(shape_val.value for shape_val in arg.struct_info.shape.values) - inputs.append(np.random.uniform(0, 1, size=shape).astype(arg.struct_info.dtype)) + shape = tuple(shape_val.value for shape_val in arg.ty.shape.values) + inputs.append(np.random.uniform(0, 1, size=shape).astype(arg.ty.dtype)) mod_org, mod_ref = mod, mod.clone() diff --git a/tests/python/relax/distributed/test_distributed_dtensor_sinfo.py b/tests/python/relax/distributed/test_distributed_dtensor_type.py similarity index 69% rename from tests/python/relax/distributed/test_distributed_dtensor_sinfo.py rename to tests/python/relax/distributed/test_distributed_dtensor_type.py index 4fdbdd82a138..1e67ea16c0ef 100644 --- a/tests/python/relax/distributed/test_distributed_dtensor_sinfo.py +++ b/tests/python/relax/distributed/test_distributed_dtensor_type.py @@ -41,12 +41,12 @@ def _check_json_roundtrip(x): return xret -def test_dtensor_struct_info(): +def test_dtensor_type(): n, m = tirx.Var("n", "int64"), tirx.Var("m", "int64") - tensor_s0 = rx.TensorStructInfo([1, n + 1, m], "float32") - tensor_s1 = rx.TensorStructInfo([1, n + 1, m], "float32") - assert tensor_s0 == tensor_s1 + tensor_ty0 = rx.TensorType([1, n + 1, m], "float32") + tensor_ty1 = rx.TensorType([1, n + 1, m], "float32") + assert tensor_ty0 == tensor_ty1 device_mesh0 = rx.distributed.DeviceMesh((2, 2), Range(0, 4)) device_mesh1 = rx.distributed.DeviceMesh((2, 2), Range(0, 4)) @@ -59,36 +59,36 @@ def test_dtensor_struct_info(): placement1 = rx.distributed.Placement([shard0, replica]) tvm.ir.assert_structural_equal(placement0, placement1) - s0 = rx.distributed.DTensorStructInfo(tensor_s0, device_mesh0, placement0) - s1 = rx.distributed.DTensorStructInfo(tensor_s1, device_mesh1, placement1) - _check_equal(s0, s1) - _check_json_roundtrip(s0) - _check_json_roundtrip(s1) - - assert s0 == s1 - tvm.ir.assert_structural_equal(s0.device_mesh, device_mesh0) - assert s0.device_mesh.shape == (2, 2) - tvm.ir.assert_structural_equal(s0.device_mesh.device_range, Range(0, 4)) - tvm.ir.assert_structural_equal(s0.placement, placement0) - assert len(s0.placement.dim_specs) == 2 - assert s0.placement.dim_specs[0] == shard0 - assert s0.placement.dim_specs[1] == replica - assert s0.tensor_sinfo == tensor_s0 + ty0 = rx.distributed.DTensorType(tensor_ty0, device_mesh0, placement0) + ty1 = rx.distributed.DTensorType(tensor_ty1, device_mesh1, placement1) + _check_equal(ty0, ty1) + _check_json_roundtrip(ty0) + _check_json_roundtrip(ty1) + + assert ty0 == ty1 + tvm.ir.assert_structural_equal(ty0.device_mesh, device_mesh0) + assert ty0.device_mesh.shape == (2, 2) + tvm.ir.assert_structural_equal(ty0.device_mesh.device_range, Range(0, 4)) + tvm.ir.assert_structural_equal(ty0.placement, placement0) + assert len(ty0.placement.dim_specs) == 2 + assert ty0.placement.dim_specs[0] == shard0 + assert ty0.placement.dim_specs[1] == replica + assert ty0.tensor_ty == tensor_ty0 # can turn into str - # str(s0) + # str(ty0) # dimension of device mesh and placement should be the same shard1 = rx.distributed.PlacementSpec.sharding(1) placement2 = rx.distributed.Placement([shard0, replica, shard1]) with pytest.raises(ValueError): - rx.distributed.DTensorStructInfo(tensor_s0, device_mesh0, placement2) + rx.distributed.DTensorType(tensor_ty0, device_mesh0, placement2) # Sharding dimension should be smaller than tensor ndim shard3 = rx.distributed.PlacementSpec.sharding(3) placement3 = rx.distributed.Placement([shard3, replica]) with pytest.raises(ValueError): - rx.distributed.DTensorStructInfo(tensor_s0, device_mesh0, placement3) + rx.distributed.DTensorType(tensor_ty0, device_mesh0, placement3) if __name__ == "__main__": diff --git a/tests/python/relax/distributed/test_distributed_transform_lower_distir.py b/tests/python/relax/distributed/test_distributed_transform_lower_distir.py index f0b1cc1539b4..b7a89a121b61 100644 --- a/tests/python/relax/distributed/test_distributed_transform_lower_distir.py +++ b/tests/python/relax/distributed/test_distributed_transform_lower_distir.py @@ -121,16 +121,16 @@ def foo( lv0: R.DTensor((128, 128), "float32", "mesh[0]", "S[1]") = R.dist.call_tir_local_view( cls.matmul1, (x, weight1), - out_sinfo=R.DTensor((128, 128), "float32", "mesh[0]", "S[1]"), + out_ty=R.DTensor((128, 128), "float32", "mesh[0]", "S[1]"), ) lv1: R.DTensor((128, 128), "float32", "mesh[0]", "S[1]") = R.dist.call_tir_local_view( - cls.gelu1, (lv0,), out_sinfo=R.DTensor((128, 128), "float32", "mesh[0]", "S[1]") + cls.gelu1, (lv0,), out_ty=R.DTensor((128, 128), "float32", "mesh[0]", "S[1]") ) lv2: R.DTensor((128, 128), "float32", "mesh[0]", "S[1]") = lv1 gv: R.DTensor((128, 128), "float32", "mesh[0]", "R") = R.dist.call_tir_local_view( cls.matmul2, (lv2, weight2), - out_sinfo=R.DTensor((128, 128), "float32", "mesh[0]", "R"), + out_ty=R.DTensor((128, 128), "float32", "mesh[0]", "R"), ) lv3: R.DTensor((128, 128), "float32", "mesh[0]", "R") = R.ccl.allreduce( gv, op_type="sum" @@ -162,16 +162,16 @@ def foo( lv0 = R.call_tir( MLP.get_global_var("matmul1"), (gv, gv1), - out_sinfo=R.Tensor((128, 64), dtype="float32"), + out_ty=R.Tensor((128, 64), dtype="float32"), ) lv1 = R.call_tir( - MLP.get_global_var("gelu1"), (lv0,), out_sinfo=R.Tensor((128, 64), dtype="float32") + MLP.get_global_var("gelu1"), (lv0,), out_ty=R.Tensor((128, 64), dtype="float32") ) lv2: R.Tensor((128, 64), dtype="float32") = lv1 gv_1 = R.call_tir( MLP.get_global_var("matmul2"), (lv2, gv2), - out_sinfo=R.Tensor((128, 128), dtype="float32"), + out_ty=R.Tensor((128, 128), dtype="float32"), ) lv3: R.Tensor((128, 128), dtype="float32") = R.ccl.allreduce(gv_1, op_type="sum") return lv3 @@ -303,10 +303,10 @@ def foo( lv0: R.DTensor((128, 128), "float32", "mesh[0]", "S[1]") = R.dist.call_tir_local_view( cls.matmul2, (x, weight1), - out_sinfo=R.DTensor((128, 128), "float32", "mesh[0]", "S[1]"), + out_ty=R.DTensor((128, 128), "float32", "mesh[0]", "S[1]"), ) lv1: R.DTensor((128, 128), "float32", "mesh[0]", "S[1]") = R.dist.call_tir_local_view( - cls.gelu1, (lv0,), out_sinfo=R.DTensor((128, 128), "float32", "mesh[0]", "S[1]") + cls.gelu1, (lv0,), out_ty=R.DTensor((128, 128), "float32", "mesh[0]", "S[1]") ) gv: R.Tuple( R.DTensor((64, 128), "float32", "mesh[0]", "S[1]"), @@ -314,7 +314,7 @@ def foo( ) = R.dist.call_tir_local_view( cls.split11, (lv1,), - out_sinfo=[ + out_ty=[ R.DTensor((64, 128), "float32", "mesh[0]", "S[1]"), R.DTensor((64, 128), "float32", "mesh[0]", "S[1]"), ], @@ -325,7 +325,7 @@ def foo( gv_1: R.DTensor((64, 128), "float32", "mesh[0]", "R") = R.dist.call_tir_local_view( cls.matmul11, (lv3, weight2), - out_sinfo=R.DTensor((64, 128), "float32", "mesh[0]", "R"), + out_ty=R.DTensor((64, 128), "float32", "mesh[0]", "R"), ) lv4: R.DTensor((64, 128), "float32", "mesh[0]", "R") = R.ccl.allreduce( gv_1, op_type="sum" @@ -359,17 +359,17 @@ def foo( lv0 = R.call_tir( MLPWithTuple.get_global_var("matmul2"), (gv, gv2), - out_sinfo=R.Tensor((128, 64), dtype="float32"), + out_ty=R.Tensor((128, 64), dtype="float32"), ) lv1 = R.call_tir( MLPWithTuple.get_global_var("gelu1"), (lv0,), - out_sinfo=R.Tensor((128, 64), dtype="float32"), + out_ty=R.Tensor((128, 64), dtype="float32"), ) gv_1 = R.call_tir( MLPWithTuple.get_global_var("split11"), (lv1,), - out_sinfo=[ + out_ty=[ R.Tensor((64, 64), dtype="float32"), R.Tensor((64, 64), dtype="float32"), ], @@ -379,7 +379,7 @@ def foo( gv_1_1 = R.call_tir( MLPWithTuple.get_global_var("matmul11"), (lv3, gv4), - out_sinfo=R.Tensor((64, 128), dtype="float32"), + out_ty=R.Tensor((64, 128), dtype="float32"), ) lv4: R.Tensor((64, 128), dtype="float32") = R.ccl.allreduce(gv_1_1, op_type="sum") return lv4 diff --git a/tests/python/relax/distributed/test_distributed_transform_lower_global_to_local_view.py b/tests/python/relax/distributed/test_distributed_transform_lower_global_to_local_view.py index 5e4169b01695..67de2259db3e 100644 --- a/tests/python/relax/distributed/test_distributed_transform_lower_global_to_local_view.py +++ b/tests/python/relax/distributed/test_distributed_transform_lower_global_to_local_view.py @@ -103,16 +103,16 @@ def foo( lv0 = R.dist.call_tir( cls.matmul, (x, weight1), - out_sinfo=R.DTensor((128, 128), "float32", "mesh[0]", "S[1]"), + out_ty=R.DTensor((128, 128), "float32", "mesh[0]", "S[1]"), ) lv1 = R.dist.call_tir( - cls.gelu, (lv0,), out_sinfo=R.DTensor((128, 128), "float32", "mesh[0]", "S[1]") + cls.gelu, (lv0,), out_ty=R.DTensor((128, 128), "float32", "mesh[0]", "S[1]") ) lv2: R.DTensor((128, 128), "float32", "mesh[0]", "S[1]") = lv1 lv3 = R.dist.call_tir( cls.matmul, (lv2, weight2), - out_sinfo=R.DTensor((128, 128), "float32", "mesh[0]", "R"), + out_ty=R.DTensor((128, 128), "float32", "mesh[0]", "R"), ) return lv3 @@ -209,16 +209,16 @@ def foo( lv0: R.DTensor((128, 128), "float32", "mesh[0]", "S[1]") = R.dist.call_tir_local_view( cls.matmul1, (x, weight1), - out_sinfo=R.DTensor((128, 128), "float32", "mesh[0]", "S[1]"), + out_ty=R.DTensor((128, 128), "float32", "mesh[0]", "S[1]"), ) lv1: R.DTensor((128, 128), "float32", "mesh[0]", "S[1]") = R.dist.call_tir_local_view( - cls.gelu1, (lv0,), out_sinfo=R.DTensor((128, 128), "float32", "mesh[0]", "S[1]") + cls.gelu1, (lv0,), out_ty=R.DTensor((128, 128), "float32", "mesh[0]", "S[1]") ) lv2: R.DTensor((128, 128), "float32", "mesh[0]", "S[1]") = lv1 gv: R.DTensor((128, 128), "float32", "mesh[0]", "R") = R.dist.call_tir_local_view( cls.matmul2, (lv2, weight2), - out_sinfo=R.DTensor((128, 128), "float32", "mesh[0]", "R"), + out_ty=R.DTensor((128, 128), "float32", "mesh[0]", "R"), ) lv3: R.DTensor((128, 128), "float32", "mesh[0]", "R") = R.ccl.allreduce( gv, op_type="sum", in_group=False @@ -664,185 +664,185 @@ def foo( lv6 = R.dist.call_tir( cls.rms_norm, (input_tokens, rms_norm_weight), - out_sinfo=R.DTensor((1, 256, 4096), "float16", "mesh[0]", "R"), + out_ty=R.DTensor((1, 256, 4096), "float16", "mesh[0]", "R"), ) lv7 = R.dist.call_tir( cls.transpose, (linear_weight,), - out_sinfo=R.DTensor((4096, 4096), "float16", "mesh[0]", "S[1]"), + out_ty=R.DTensor((4096, 4096), "float16", "mesh[0]", "S[1]"), ) lv8 = R.dist.call_tir( cls.matmul, (lv6, lv7), - out_sinfo=R.DTensor((1, 256, 4096), "float16", "mesh[0]", "S[2]"), + out_ty=R.DTensor((1, 256, 4096), "float16", "mesh[0]", "S[2]"), ) lv9 = R.dist.call_tir( cls.reshape, (lv8,), - out_sinfo=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]"), + out_ty=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]"), ) lv10 = R.dist.call_tir( cls.transpose, (linear_weight1,), - out_sinfo=R.DTensor((4096, 4096), "float16", "mesh[0]", "S[1]"), + out_ty=R.DTensor((4096, 4096), "float16", "mesh[0]", "S[1]"), ) lv11 = R.dist.call_tir( cls.matmul, (lv6, lv10), - out_sinfo=R.DTensor((1, 256, 4096), "float16", "mesh[0]", "S[2]"), + out_ty=R.DTensor((1, 256, 4096), "float16", "mesh[0]", "S[2]"), ) lv12 = R.dist.call_tir( cls.reshape, (lv11,), - out_sinfo=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]"), + out_ty=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]"), ) lv13 = R.dist.call_tir( cls.transpose, (linear_weight2,), - out_sinfo=R.DTensor((4096, 4096), "float16", "mesh[0]", "S[1]"), + out_ty=R.DTensor((4096, 4096), "float16", "mesh[0]", "S[1]"), ) lv14 = R.dist.call_tir( cls.matmul, (lv6, lv13), - out_sinfo=R.DTensor((1, 256, 4096), "float16", "mesh[0]", "S[2]"), + out_ty=R.DTensor((1, 256, 4096), "float16", "mesh[0]", "S[2]"), ) lv15 = R.dist.call_tir( cls.reshape, (lv14,), - out_sinfo=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]"), + out_ty=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]"), ) lv16 = R.dist.call_tir( cls.rotary_embedding, (lv9, cos_cached, sin_cached), - out_sinfo=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]"), + out_ty=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]"), tir_vars=R.shape([256]), ) lv17 = R.dist.call_tir( cls.rotary_embedding, (lv12, cos_cached, sin_cached), - out_sinfo=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]"), + out_ty=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]"), tir_vars=R.shape([256]), ) lv18 = R.dist.call_tir( cls.reshape1, (lv17,), - out_sinfo=R.DTensor((256, 32, 128), "float16", "mesh[0]", "S[1]"), + out_ty=R.DTensor((256, 32, 128), "float16", "mesh[0]", "S[1]"), ) lv19 = R.dist.call_tir( cls.reshape1, (lv15,), - out_sinfo=R.DTensor((256, 32, 128), "float16", "mesh[0]", "S[1]"), + out_ty=R.DTensor((256, 32, 128), "float16", "mesh[0]", "S[1]"), ) lv20: R.Object = kv_cache[0] lv21: R.Object = R.call_packed( "vm.builtin.distributed.attention_kv_cache_append", lv20, lv18, - sinfo_args=(R.Object,), + ty_args=(R.Object,), ) lv22: R.Object = kv_cache[1] lv23: R.Object = R.call_packed( "vm.builtin.distributed.attention_kv_cache_append", lv22, lv19, - sinfo_args=(R.Object,), + ty_args=(R.Object,), ) lv24: R.DTensor((256, 32, 128), "float16", "mesh[0]", "S[1]") = R.call_packed( "vm.builtin.distributed.attention_kv_cache_view", lv21, R.shape([256, 32, 128]), - sinfo_args=(R.DTensor((256, 32, 128), "float16", "mesh[0]", "S[1]"),), + ty_args=(R.DTensor((256, 32, 128), "float16", "mesh[0]", "S[1]"),), ) lv25: R.DTensor((256, 32, 128), "float16", "mesh[0]", "S[1]") = R.call_packed( "vm.builtin.distributed.attention_kv_cache_view", lv23, R.shape([256, 32, 128]), - sinfo_args=(R.DTensor((256, 32, 128), "float16", "mesh[0]", "S[1]"),), + ty_args=(R.DTensor((256, 32, 128), "float16", "mesh[0]", "S[1]"),), ) lv26 = R.dist.call_tir( cls.reshape2, (lv24,), - out_sinfo=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]"), + out_ty=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]"), ) lv27 = R.dist.call_tir( cls.reshape2, (lv25,), - out_sinfo=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]"), + out_ty=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]"), ) lv28 = R.dist.call_tir( cls.transpose1, (lv16,), - out_sinfo=R.DTensor((1, 32, 256, 128), "float16", "mesh[0]", "S[1]"), + out_ty=R.DTensor((1, 32, 256, 128), "float16", "mesh[0]", "S[1]"), ) lv29 = R.dist.call_tir( cls.transpose1, (lv26,), - out_sinfo=R.DTensor((1, 32, 256, 128), "float16", "mesh[0]", "S[1]"), + out_ty=R.DTensor((1, 32, 256, 128), "float16", "mesh[0]", "S[1]"), ) lv30 = R.dist.call_tir( cls.transpose1, (lv27,), - out_sinfo=R.DTensor((1, 32, 256, 128), "float16", "mesh[0]", "S[1]"), + out_ty=R.DTensor((1, 32, 256, 128), "float16", "mesh[0]", "S[1]"), ) lv31 = R.dist.call_tir( cls.transpose2, (lv29,), - out_sinfo=R.DTensor((1, 32, 128, 256), "float16", "mesh[0]", "S[1]"), + out_ty=R.DTensor((1, 32, 128, 256), "float16", "mesh[0]", "S[1]"), ) lv32 = R.dist.call_tir( cls.matmul1, (lv28, lv31), - out_sinfo=R.DTensor((1, 32, 256, 256), "float16", "mesh[0]", "S[1]"), + out_ty=R.DTensor((1, 32, 256, 256), "float16", "mesh[0]", "S[1]"), ) lv33 = R.dist.call_tir( cls.divide, (lv32, div_const), - out_sinfo=R.DTensor((1, 32, 256, 256), "float16", "mesh[0]", "S[1]"), + out_ty=R.DTensor((1, 32, 256, 256), "float16", "mesh[0]", "S[1]"), ) lv34 = R.dist.call_tir( cls.maximum, (lv33, maximum_const), - out_sinfo=R.DTensor((1, 32, 256, 256), "float16", "mesh[0]", "S[1]"), + out_ty=R.DTensor((1, 32, 256, 256), "float16", "mesh[0]", "S[1]"), ) lv35 = R.dist.call_tir( cls.minimum, (lv34, mask), - out_sinfo=R.DTensor((1, 32, 256, 256), "float16", "mesh[0]", "S[1]"), + out_ty=R.DTensor((1, 32, 256, 256), "float16", "mesh[0]", "S[1]"), ) lv37 = R.dist.call_tir( cls.softmax, (lv35,), - out_sinfo=R.DTensor((1, 32, 256, 256), "float16", "mesh[0]", "S[1]"), + out_ty=R.DTensor((1, 32, 256, 256), "float16", "mesh[0]", "S[1]"), ) lv39 = R.dist.call_tir( cls.matmul2, (lv37, lv30), - out_sinfo=R.DTensor((1, 32, 256, 128), "float16", "mesh[0]", "S[1]"), + out_ty=R.DTensor((1, 32, 256, 128), "float16", "mesh[0]", "S[1]"), ) lv40 = R.dist.call_tir( cls.transpose3, (lv39,), - out_sinfo=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]"), + out_ty=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]"), ) lv41 = R.dist.call_tir( cls.reshape3, (lv40,), - out_sinfo=R.DTensor((1, 256, 4096), "float16", "mesh[0]", "S[2]"), + out_ty=R.DTensor((1, 256, 4096), "float16", "mesh[0]", "S[2]"), ) lv42 = R.dist.call_tir( cls.transpose, (linear_weight3,), - out_sinfo=R.DTensor((4096, 4096), "float16", "mesh[0]", "S[0]"), + out_ty=R.DTensor((4096, 4096), "float16", "mesh[0]", "S[0]"), ) lv43 = R.dist.call_tir( cls.matmul, (lv41, lv42), - out_sinfo=R.DTensor((1, 256, 4096), "float16", "mesh[0]", "R"), + out_ty=R.DTensor((1, 256, 4096), "float16", "mesh[0]", "R"), ) lv44 = R.dist.call_tir( cls.add, (input_tokens, lv43), - out_sinfo=R.DTensor((1, 256, 4096), "float16", "mesh[0]", "R"), + out_ty=R.DTensor((1, 256, 4096), "float16", "mesh[0]", "R"), ) gv: R.DTensor((1, 256, 4096), "float16", "mesh[0]", "R") = lv44 return gv @@ -1338,74 +1338,74 @@ def foo( lv6: R.DTensor((1, 256, 4096), "float16", "mesh[0]", "R") = R.dist.call_tir_local_view( cls.rms_norm, (input_tokens, rms_norm_weight), - out_sinfo=R.DTensor((1, 256, 4096), "float16", "mesh[0]", "R"), + out_ty=R.DTensor((1, 256, 4096), "float16", "mesh[0]", "R"), ) lv7: R.DTensor((4096, 4096), "float16", "mesh[0]", "S[1]") = R.dist.call_tir_local_view( cls.transpose4, (linear_weight,), - out_sinfo=R.DTensor((4096, 4096), "float16", "mesh[0]", "S[1]"), + out_ty=R.DTensor((4096, 4096), "float16", "mesh[0]", "S[1]"), ) lv8: R.DTensor((1, 256, 4096), "float16", "mesh[0]", "S[2]") = ( R.dist.call_tir_local_view( cls.matmul3, (lv6, lv7), - out_sinfo=R.DTensor((1, 256, 4096), "float16", "mesh[0]", "S[2]"), + out_ty=R.DTensor((1, 256, 4096), "float16", "mesh[0]", "S[2]"), ) ) lv9: R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]") = ( R.dist.call_tir_local_view( cls.reshape4, (lv8,), - out_sinfo=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]"), + out_ty=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]"), ) ) lv10: R.DTensor((4096, 4096), "float16", "mesh[0]", "S[1]") = ( R.dist.call_tir_local_view( cls.transpose4, (linear_weight1,), - out_sinfo=R.DTensor((4096, 4096), "float16", "mesh[0]", "S[1]"), + out_ty=R.DTensor((4096, 4096), "float16", "mesh[0]", "S[1]"), ) ) lv11: R.DTensor((1, 256, 4096), "float16", "mesh[0]", "S[2]") = ( R.dist.call_tir_local_view( cls.matmul3, (lv6, lv10), - out_sinfo=R.DTensor((1, 256, 4096), "float16", "mesh[0]", "S[2]"), + out_ty=R.DTensor((1, 256, 4096), "float16", "mesh[0]", "S[2]"), ) ) lv12: R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]") = ( R.dist.call_tir_local_view( cls.reshape4, (lv11,), - out_sinfo=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]"), + out_ty=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]"), ) ) lv13: R.DTensor((4096, 4096), "float16", "mesh[0]", "S[1]") = ( R.dist.call_tir_local_view( cls.transpose4, (linear_weight2,), - out_sinfo=R.DTensor((4096, 4096), "float16", "mesh[0]", "S[1]"), + out_ty=R.DTensor((4096, 4096), "float16", "mesh[0]", "S[1]"), ) ) lv14: R.DTensor((1, 256, 4096), "float16", "mesh[0]", "S[2]") = ( R.dist.call_tir_local_view( cls.matmul3, (lv6, lv13), - out_sinfo=R.DTensor((1, 256, 4096), "float16", "mesh[0]", "S[2]"), + out_ty=R.DTensor((1, 256, 4096), "float16", "mesh[0]", "S[2]"), ) ) lv15: R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]") = ( R.dist.call_tir_local_view( cls.reshape4, (lv14,), - out_sinfo=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]"), + out_ty=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]"), ) ) lv16: R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]") = ( R.dist.call_tir_local_view( cls.rotary_embedding1, (lv9, cos_cached, sin_cached), - out_sinfo=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]"), + out_ty=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]"), tir_vars=R.shape([256]), ) ) @@ -1413,7 +1413,7 @@ def foo( R.dist.call_tir_local_view( cls.rotary_embedding1, (lv12, cos_cached, sin_cached), - out_sinfo=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]"), + out_ty=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]"), tir_vars=R.shape([256]), ) ) @@ -1421,14 +1421,14 @@ def foo( R.dist.call_tir_local_view( cls.reshape11, (lv17,), - out_sinfo=R.DTensor((256, 32, 128), "float16", "mesh[0]", "S[1]"), + out_ty=R.DTensor((256, 32, 128), "float16", "mesh[0]", "S[1]"), ) ) lv19: R.DTensor((256, 32, 128), "float16", "mesh[0]", "S[1]") = ( R.dist.call_tir_local_view( cls.reshape11, (lv15,), - out_sinfo=R.DTensor((256, 32, 128), "float16", "mesh[0]", "S[1]"), + out_ty=R.DTensor((256, 32, 128), "float16", "mesh[0]", "S[1]"), ) ) lv20: R.Object = kv_cache[0] @@ -1436,136 +1436,136 @@ def foo( "vm.builtin.distributed.attention_kv_cache_append", lv20, lv18, - sinfo_args=(R.Object,), + ty_args=(R.Object,), ) lv22: R.Object = kv_cache[1] lv23: R.Object = R.call_packed( "vm.builtin.distributed.attention_kv_cache_append", lv22, lv19, - sinfo_args=(R.Object,), + ty_args=(R.Object,), ) lv24: R.DTensor((256, 32, 128), "float16", "mesh[0]", "S[1]") = R.call_packed( "vm.builtin.distributed.attention_kv_cache_view", lv21, R.shape([256, 32, 128]), - sinfo_args=(R.DTensor((256, 32, 128), "float16", "mesh[0]", "S[1]"),), + ty_args=(R.DTensor((256, 32, 128), "float16", "mesh[0]", "S[1]"),), ) lv25: R.DTensor((256, 32, 128), "float16", "mesh[0]", "S[1]") = R.call_packed( "vm.builtin.distributed.attention_kv_cache_view", lv23, R.shape([256, 32, 128]), - sinfo_args=(R.DTensor((256, 32, 128), "float16", "mesh[0]", "S[1]"),), + ty_args=(R.DTensor((256, 32, 128), "float16", "mesh[0]", "S[1]"),), ) lv26: R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]") = ( R.dist.call_tir_local_view( cls.reshape21, (lv24,), - out_sinfo=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]"), + out_ty=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]"), ) ) lv27: R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]") = ( R.dist.call_tir_local_view( cls.reshape21, (lv25,), - out_sinfo=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]"), + out_ty=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]"), ) ) lv28: R.DTensor((1, 32, 256, 128), "float16", "mesh[0]", "S[1]") = ( R.dist.call_tir_local_view( cls.transpose11, (lv16,), - out_sinfo=R.DTensor((1, 32, 256, 128), "float16", "mesh[0]", "S[1]"), + out_ty=R.DTensor((1, 32, 256, 128), "float16", "mesh[0]", "S[1]"), ) ) lv29: R.DTensor((1, 32, 256, 128), "float16", "mesh[0]", "S[1]") = ( R.dist.call_tir_local_view( cls.transpose11, (lv26,), - out_sinfo=R.DTensor((1, 32, 256, 128), "float16", "mesh[0]", "S[1]"), + out_ty=R.DTensor((1, 32, 256, 128), "float16", "mesh[0]", "S[1]"), ) ) lv30: R.DTensor((1, 32, 256, 128), "float16", "mesh[0]", "S[1]") = ( R.dist.call_tir_local_view( cls.transpose11, (lv27,), - out_sinfo=R.DTensor((1, 32, 256, 128), "float16", "mesh[0]", "S[1]"), + out_ty=R.DTensor((1, 32, 256, 128), "float16", "mesh[0]", "S[1]"), ) ) lv31: R.DTensor((1, 32, 128, 256), "float16", "mesh[0]", "S[1]") = ( R.dist.call_tir_local_view( cls.transpose21, (lv29,), - out_sinfo=R.DTensor((1, 32, 128, 256), "float16", "mesh[0]", "S[1]"), + out_ty=R.DTensor((1, 32, 128, 256), "float16", "mesh[0]", "S[1]"), ) ) lv32: R.DTensor((1, 32, 256, 256), "float16", "mesh[0]", "S[1]") = ( R.dist.call_tir_local_view( cls.matmul11, (lv28, lv31), - out_sinfo=R.DTensor((1, 32, 256, 256), "float16", "mesh[0]", "S[1]"), + out_ty=R.DTensor((1, 32, 256, 256), "float16", "mesh[0]", "S[1]"), ) ) lv33: R.DTensor((1, 32, 256, 256), "float16", "mesh[0]", "S[1]") = ( R.dist.call_tir_local_view( cls.divide1, (lv32, div_const), - out_sinfo=R.DTensor((1, 32, 256, 256), "float16", "mesh[0]", "S[1]"), + out_ty=R.DTensor((1, 32, 256, 256), "float16", "mesh[0]", "S[1]"), ) ) lv34: R.DTensor((1, 32, 256, 256), "float16", "mesh[0]", "S[1]") = ( R.dist.call_tir_local_view( cls.maximum1, (lv33, maximum_const), - out_sinfo=R.DTensor((1, 32, 256, 256), "float16", "mesh[0]", "S[1]"), + out_ty=R.DTensor((1, 32, 256, 256), "float16", "mesh[0]", "S[1]"), ) ) lv35: R.DTensor((1, 32, 256, 256), "float16", "mesh[0]", "S[1]") = ( R.dist.call_tir_local_view( cls.minimum1, (lv34, mask), - out_sinfo=R.DTensor((1, 32, 256, 256), "float16", "mesh[0]", "S[1]"), + out_ty=R.DTensor((1, 32, 256, 256), "float16", "mesh[0]", "S[1]"), ) ) lv37: R.DTensor((1, 32, 256, 256), "float16", "mesh[0]", "S[1]") = ( R.dist.call_tir_local_view( cls.softmax1, (lv35,), - out_sinfo=R.DTensor((1, 32, 256, 256), "float16", "mesh[0]", "S[1]"), + out_ty=R.DTensor((1, 32, 256, 256), "float16", "mesh[0]", "S[1]"), ) ) lv39: R.DTensor((1, 32, 256, 128), "float16", "mesh[0]", "S[1]") = ( R.dist.call_tir_local_view( cls.matmul21, (lv37, lv30), - out_sinfo=R.DTensor((1, 32, 256, 128), "float16", "mesh[0]", "S[1]"), + out_ty=R.DTensor((1, 32, 256, 128), "float16", "mesh[0]", "S[1]"), ) ) lv40: R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]") = ( R.dist.call_tir_local_view( cls.transpose31, (lv39,), - out_sinfo=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]"), + out_ty=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]"), ) ) lv41: R.DTensor((1, 256, 4096), "float16", "mesh[0]", "S[2]") = ( R.dist.call_tir_local_view( cls.reshape31, (lv40,), - out_sinfo=R.DTensor((1, 256, 4096), "float16", "mesh[0]", "S[2]"), + out_ty=R.DTensor((1, 256, 4096), "float16", "mesh[0]", "S[2]"), ) ) lv42: R.DTensor((4096, 4096), "float16", "mesh[0]", "S[0]") = ( R.dist.call_tir_local_view( cls.transpose5, (linear_weight3,), - out_sinfo=R.DTensor((4096, 4096), "float16", "mesh[0]", "S[0]"), + out_ty=R.DTensor((4096, 4096), "float16", "mesh[0]", "S[0]"), ) ) gv: R.DTensor((1, 256, 4096), "float16", "mesh[0]", "R") = R.dist.call_tir_local_view( cls.matmul4, (lv41, lv42), - out_sinfo=R.DTensor((1, 256, 4096), "float16", "mesh[0]", "R"), + out_ty=R.DTensor((1, 256, 4096), "float16", "mesh[0]", "R"), ) lv43: R.DTensor((1, 256, 4096), "float16", "mesh[0]", "R") = R.ccl.allreduce( gv, op_type="sum", in_group=False @@ -1573,7 +1573,7 @@ def foo( lv44: R.DTensor((1, 256, 4096), "float16", "mesh[0]", "R") = R.dist.call_tir_local_view( cls.add, (input_tokens, lv43), - out_sinfo=R.DTensor((1, 256, 4096), "float16", "mesh[0]", "R"), + out_ty=R.DTensor((1, 256, 4096), "float16", "mesh[0]", "R"), ) gv_1: R.DTensor((1, 256, 4096), "float16", "mesh[0]", "R") = lv44 return gv_1 diff --git a/tests/python/relax/distributed/test_distributed_transform_propagate_sharding.py b/tests/python/relax/distributed/test_distributed_transform_propagate_sharding.py index 68ce5500cb6c..7cf3a5f0d14c 100644 --- a/tests/python/relax/distributed/test_distributed_transform_propagate_sharding.py +++ b/tests/python/relax/distributed/test_distributed_transform_propagate_sharding.py @@ -121,7 +121,7 @@ def foo( lv_tuple = R.call_tir( cls.split1, [lv1], - out_sinfo=[R.Tensor((64, 128), "float32"), R.Tensor((64, 128), "float32")], + out_ty=[R.Tensor((64, 128), "float32"), R.Tensor((64, 128), "float32")], ) lv2 = lv_tuple[0] lv3 = R.dist.annotate_sharding(lv2, device_mesh="mesh[0]", placement="S[1]") @@ -174,7 +174,7 @@ def foo( gv = R.dist.call_tir( cls.split1, (lv1,), - out_sinfo=[ + out_ty=[ R.DTensor((64, 128), "float32", "mesh[0]", "S[1]"), R.DTensor((64, 128), "float32", "mesh[0]", "S[1]"), ], @@ -473,7 +473,7 @@ def foo( lv6 = R.call_tir( cls.rms_norm, (input_tokens, rms_norm_weight), - out_sinfo=R.Tensor((1, 256, 4096), dtype="float16"), + out_ty=R.Tensor((1, 256, 4096), dtype="float16"), ) lv7: R.Tensor((4096, 4096), dtype="float16") = R.permute_dims(linear_weight, axes=None) lv7_copy: R.Tensor((4096, 4096), dtype="float16") = R.dist.annotate_sharding( @@ -512,12 +512,12 @@ def foo( lv16 = R.call_tir( cls.rotary_embedding, (lv9, cos_cached, sin_cached), - out_sinfo=R.Tensor((1, 256, 32, 128), dtype="float16"), + out_ty=R.Tensor((1, 256, 32, 128), dtype="float16"), ) lv17 = R.call_tir( cls.rotary_embedding, (lv12, cos_cached, sin_cached), - out_sinfo=R.Tensor((1, 256, 32, 128), dtype="float16"), + out_ty=R.Tensor((1, 256, 32, 128), dtype="float16"), ) lv18: R.Tensor((256, 32, 128), dtype="float16") = R.reshape( lv17, R.shape([256, 32, 128]) @@ -527,23 +527,23 @@ def foo( ) lv20: R.Object = kv_cache[0] lv21: R.Object = R.call_packed( - "vm.builtin.attention_kv_cache_append", lv20, lv18, sinfo_args=(R.Object,) + "vm.builtin.attention_kv_cache_append", lv20, lv18, ty_args=(R.Object,) ) lv22: R.Object = kv_cache[1] lv23: R.Object = R.call_packed( - "vm.builtin.attention_kv_cache_append", lv22, lv19, sinfo_args=(R.Object,) + "vm.builtin.attention_kv_cache_append", lv22, lv19, ty_args=(R.Object,) ) lv24: R.Tensor((256, 32, 128), dtype="float16") = R.call_packed( "vm.builtin.attention_kv_cache_view", lv21, R.shape([256, 32, 128]), - sinfo_args=(R.Tensor((256, 32, 128), dtype="float16"),), + ty_args=(R.Tensor((256, 32, 128), dtype="float16"),), ) lv25: R.Tensor((256, 32, 128), dtype="float16") = R.call_packed( "vm.builtin.attention_kv_cache_view", lv23, R.shape([256, 32, 128]), - sinfo_args=(R.Tensor((256, 32, 128), dtype="float16"),), + ty_args=(R.Tensor((256, 32, 128), dtype="float16"),), ) lv26: R.Tensor((1, 256, 32, 128), dtype="float16") = R.reshape( lv24, R.shape([1, 256, 32, 128]) @@ -678,7 +678,7 @@ def foo( lv6 = R.dist.call_tir( cls.rms_norm, (input_tokens, rms_norm_weight), - out_sinfo=R.DTensor((1, 256, 4096), "float16", "mesh[0]", "R"), + out_ty=R.DTensor((1, 256, 4096), "float16", "mesh[0]", "R"), ) lv7: R.DTensor((4096, 4096), "float16", "mesh[0]", "S[1]") = R.permute_dims( linear_weight, axes=None @@ -710,12 +710,12 @@ def foo( lv16 = R.dist.call_tir( cls.rotary_embedding, (lv9, cos_cached, sin_cached), - out_sinfo=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]"), + out_ty=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]"), ) lv17 = R.dist.call_tir( cls.rotary_embedding, (lv12, cos_cached, sin_cached), - out_sinfo=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]"), + out_ty=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]"), ) lv18: R.DTensor((256, 32, 128), "float16", "mesh[0]", "S[1]") = R.reshape( lv17, R.shape([256, 32, 128]) @@ -728,26 +728,26 @@ def foo( "vm.builtin.distributed.attention_kv_cache_append", lv20, lv18, - sinfo_args=(R.Object,), + ty_args=(R.Object,), ) lv22: R.Object = kv_cache[1] lv23: R.Object = R.call_packed( "vm.builtin.distributed.attention_kv_cache_append", lv22, lv19, - sinfo_args=(R.Object,), + ty_args=(R.Object,), ) lv24: R.DTensor((256, 32, 128), "float16", "mesh[0]", "S[1]") = R.call_packed( "vm.builtin.distributed.attention_kv_cache_view", lv21, R.shape([256, 32, 128]), - sinfo_args=(R.DTensor((256, 32, 128), "float16", "mesh[0]", "S[1]"),), + ty_args=(R.DTensor((256, 32, 128), "float16", "mesh[0]", "S[1]"),), ) lv25: R.DTensor((256, 32, 128), "float16", "mesh[0]", "S[1]") = R.call_packed( "vm.builtin.distributed.attention_kv_cache_view", lv23, R.shape([256, 32, 128]), - sinfo_args=(R.DTensor((256, 32, 128), "float16", "mesh[0]", "S[1]"),), + ty_args=(R.DTensor((256, 32, 128), "float16", "mesh[0]", "S[1]"),), ) lv26: R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]") = R.reshape( lv24, R.shape([1, 256, 32, 128]) @@ -1237,134 +1237,134 @@ def foo( lv6 = R.call_tir( cls.rms_norm, (input_tokens, rms_norm_weight), - out_sinfo=R.Tensor((1, 256, 4096), dtype="float16"), + out_ty=R.Tensor((1, 256, 4096), dtype="float16"), ) lv7 = R.call_tir( - cls.transpose, (linear_weight,), out_sinfo=R.Tensor((4096, 4096), dtype="float16") + cls.transpose, (linear_weight,), out_ty=R.Tensor((4096, 4096), dtype="float16") ) lv7_copy: R.Tensor((4096, 4096), dtype="float16") = R.dist.annotate_sharding( lv7, device_mesh="mesh[0]", placement="S[1]" ) lv8 = R.call_tir( - cls.matmul, (lv6, lv7_copy), out_sinfo=R.Tensor((1, 256, 4096), dtype="float16") + cls.matmul, (lv6, lv7_copy), out_ty=R.Tensor((1, 256, 4096), dtype="float16") ) lv9 = R.call_tir( - cls.reshape, (lv8,), out_sinfo=R.Tensor((1, 256, 32, 128), dtype="float16") + cls.reshape, (lv8,), out_ty=R.Tensor((1, 256, 32, 128), dtype="float16") ) lv10 = R.call_tir( - cls.transpose, (linear_weight1,), out_sinfo=R.Tensor((4096, 4096), dtype="float16") + cls.transpose, (linear_weight1,), out_ty=R.Tensor((4096, 4096), dtype="float16") ) lv10_copy: R.Tensor((4096, 4096), dtype="float16") = R.dist.annotate_sharding( lv10, device_mesh="mesh[0]", placement="S[1]" ) lv11 = R.call_tir( - cls.matmul, (lv6, lv10_copy), out_sinfo=R.Tensor((1, 256, 4096), dtype="float16") + cls.matmul, (lv6, lv10_copy), out_ty=R.Tensor((1, 256, 4096), dtype="float16") ) lv12 = R.call_tir( - cls.reshape, (lv11,), out_sinfo=R.Tensor((1, 256, 32, 128), dtype="float16") + cls.reshape, (lv11,), out_ty=R.Tensor((1, 256, 32, 128), dtype="float16") ) lv13 = R.call_tir( - cls.transpose, (linear_weight2,), out_sinfo=R.Tensor((4096, 4096), dtype="float16") + cls.transpose, (linear_weight2,), out_ty=R.Tensor((4096, 4096), dtype="float16") ) lv13_copy: R.Tensor((4096, 4096), dtype="float16") = R.dist.annotate_sharding( lv13, device_mesh="mesh[0]", placement="S[1]" ) lv14 = R.call_tir( - cls.matmul, (lv6, lv13_copy), out_sinfo=R.Tensor((1, 256, 4096), dtype="float16") + cls.matmul, (lv6, lv13_copy), out_ty=R.Tensor((1, 256, 4096), dtype="float16") ) lv15 = R.call_tir( - cls.reshape, (lv14,), out_sinfo=R.Tensor((1, 256, 32, 128), dtype="float16") + cls.reshape, (lv14,), out_ty=R.Tensor((1, 256, 32, 128), dtype="float16") ) lv16 = R.call_tir( cls.rotary_embedding, (lv9, cos_cached, sin_cached), - out_sinfo=R.Tensor((1, 256, 32, 128), dtype="float16"), + out_ty=R.Tensor((1, 256, 32, 128), dtype="float16"), ) lv17 = R.call_tir( cls.rotary_embedding, (lv12, cos_cached, sin_cached), - out_sinfo=R.Tensor((1, 256, 32, 128), dtype="float16"), + out_ty=R.Tensor((1, 256, 32, 128), dtype="float16"), ) lv18 = R.call_tir( - cls.reshape1, (lv17,), out_sinfo=R.Tensor((256, 32, 128), dtype="float16") + cls.reshape1, (lv17,), out_ty=R.Tensor((256, 32, 128), dtype="float16") ) lv19 = R.call_tir( - cls.reshape1, (lv15,), out_sinfo=R.Tensor((256, 32, 128), dtype="float16") + cls.reshape1, (lv15,), out_ty=R.Tensor((256, 32, 128), dtype="float16") ) lv20: R.Object = kv_cache[0] lv21: R.Object = R.call_packed( - "vm.builtin.attention_kv_cache_append", lv20, lv18, sinfo_args=(R.Object,) + "vm.builtin.attention_kv_cache_append", lv20, lv18, ty_args=(R.Object,) ) lv22: R.Object = kv_cache[1] lv23: R.Object = R.call_packed( - "vm.builtin.attention_kv_cache_append", lv22, lv19, sinfo_args=(R.Object,) + "vm.builtin.attention_kv_cache_append", lv22, lv19, ty_args=(R.Object,) ) lv24: R.Tensor((256, 32, 128), dtype="float16") = R.call_packed( "vm.builtin.attention_kv_cache_view", lv21, R.shape([256, 32, 128]), - sinfo_args=(R.Tensor((256, 32, 128), dtype="float16"),), + ty_args=(R.Tensor((256, 32, 128), dtype="float16"),), ) lv25: R.Tensor((256, 32, 128), dtype="float16") = R.call_packed( "vm.builtin.attention_kv_cache_view", lv23, R.shape([256, 32, 128]), - sinfo_args=(R.Tensor((256, 32, 128), dtype="float16"),), + ty_args=(R.Tensor((256, 32, 128), dtype="float16"),), ) lv26 = R.call_tir( - cls.reshape2, (lv24,), out_sinfo=R.Tensor((1, 256, 32, 128), dtype="float16") + cls.reshape2, (lv24,), out_ty=R.Tensor((1, 256, 32, 128), dtype="float16") ) lv27 = R.call_tir( - cls.reshape2, (lv25,), out_sinfo=R.Tensor((1, 256, 32, 128), dtype="float16") + cls.reshape2, (lv25,), out_ty=R.Tensor((1, 256, 32, 128), dtype="float16") ) lv28 = R.call_tir( - cls.transpose1, (lv16,), out_sinfo=R.Tensor((1, 32, 256, 128), dtype="float16") + cls.transpose1, (lv16,), out_ty=R.Tensor((1, 32, 256, 128), dtype="float16") ) lv29 = R.call_tir( - cls.transpose1, (lv26,), out_sinfo=R.Tensor((1, 32, 256, 128), dtype="float16") + cls.transpose1, (lv26,), out_ty=R.Tensor((1, 32, 256, 128), dtype="float16") ) lv30 = R.call_tir( - cls.transpose1, (lv27,), out_sinfo=R.Tensor((1, 32, 256, 128), dtype="float16") + cls.transpose1, (lv27,), out_ty=R.Tensor((1, 32, 256, 128), dtype="float16") ) lv31 = R.call_tir( - cls.transpose2, (lv29,), out_sinfo=R.Tensor((1, 32, 128, 256), dtype="float16") + cls.transpose2, (lv29,), out_ty=R.Tensor((1, 32, 128, 256), dtype="float16") ) lv32 = R.call_tir( - cls.matmul1, (lv28, lv31), out_sinfo=R.Tensor((1, 32, 256, 256), dtype="float16") + cls.matmul1, (lv28, lv31), out_ty=R.Tensor((1, 32, 256, 256), dtype="float16") ) lv33 = R.call_tir( cls.divide, (lv32, div_const), - out_sinfo=R.Tensor((1, 32, 256, 256), dtype="float16"), + out_ty=R.Tensor((1, 32, 256, 256), dtype="float16"), ) lv34 = R.call_tir( cls.maximum, (lv33, maximum_const), - out_sinfo=R.Tensor((1, 32, 256, 256), dtype="float16"), + out_ty=R.Tensor((1, 32, 256, 256), dtype="float16"), ) lv35 = R.call_tir( - cls.minimum, (lv34, mask), out_sinfo=R.Tensor((1, 32, 256, 256), dtype="float16") + cls.minimum, (lv34, mask), out_ty=R.Tensor((1, 32, 256, 256), dtype="float16") ) lv37 = R.call_tir( - cls.softmax, (lv35,), out_sinfo=R.Tensor((1, 32, 256, 256), dtype="float16") + cls.softmax, (lv35,), out_ty=R.Tensor((1, 32, 256, 256), dtype="float16") ) lv39 = R.call_tir( - cls.matmul2, (lv37, lv30), out_sinfo=R.Tensor((1, 32, 256, 128), dtype="float16") + cls.matmul2, (lv37, lv30), out_ty=R.Tensor((1, 32, 256, 128), dtype="float16") ) lv40 = R.call_tir( - cls.transpose3, (lv39,), out_sinfo=R.Tensor((1, 256, 32, 128), dtype="float16") + cls.transpose3, (lv39,), out_ty=R.Tensor((1, 256, 32, 128), dtype="float16") ) lv41 = R.call_tir( - cls.reshape3, (lv40,), out_sinfo=R.Tensor((1, 256, 4096), dtype="float16") + cls.reshape3, (lv40,), out_ty=R.Tensor((1, 256, 4096), dtype="float16") ) lv42 = R.call_tir( - cls.transpose, (linear_weight3,), out_sinfo=R.Tensor((4096, 4096), dtype="float16") + cls.transpose, (linear_weight3,), out_ty=R.Tensor((4096, 4096), dtype="float16") ) lv43 = R.call_tir( - cls.matmul, (lv41, lv42), out_sinfo=R.Tensor((1, 256, 4096), dtype="float16") + cls.matmul, (lv41, lv42), out_ty=R.Tensor((1, 256, 4096), dtype="float16") ) lv44 = R.call_tir( - cls.add, (input_tokens, lv43), out_sinfo=R.Tensor((1, 256, 4096), dtype="float16") + cls.add, (input_tokens, lv43), out_ty=R.Tensor((1, 256, 4096), dtype="float16") ) gv: R.Tensor((1, 256, 4096), dtype="float16") = lv44 return gv @@ -1397,183 +1397,183 @@ def foo( lv6 = R.dist.call_tir( LlamaAttentionLayerTIR.get_global_var("rms_norm"), (input_tokens, rms_norm_weight), - out_sinfo=R.DTensor((1, 256, 4096), "float16", "mesh[0]", "R"), + out_ty=R.DTensor((1, 256, 4096), "float16", "mesh[0]", "R"), ) lv7 = R.dist.call_tir( LlamaAttentionLayerTIR.get_global_var("transpose"), (linear_weight,), - out_sinfo=R.DTensor((4096, 4096), "float16", "mesh[0]", "S[1]"), + out_ty=R.DTensor((4096, 4096), "float16", "mesh[0]", "S[1]"), ) lv8 = R.dist.call_tir( LlamaAttentionLayerTIR.get_global_var("matmul"), (lv6, lv7), - out_sinfo=R.DTensor((1, 256, 4096), "float16", "mesh[0]", "S[2]"), + out_ty=R.DTensor((1, 256, 4096), "float16", "mesh[0]", "S[2]"), ) lv9 = R.dist.call_tir( LlamaAttentionLayerTIR.get_global_var("reshape"), (lv8,), - out_sinfo=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]"), + out_ty=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]"), ) lv10 = R.dist.call_tir( LlamaAttentionLayerTIR.get_global_var("transpose"), (linear_weight1,), - out_sinfo=R.DTensor((4096, 4096), "float16", "mesh[0]", "S[1]"), + out_ty=R.DTensor((4096, 4096), "float16", "mesh[0]", "S[1]"), ) lv11 = R.dist.call_tir( LlamaAttentionLayerTIR.get_global_var("matmul"), (lv6, lv10), - out_sinfo=R.DTensor((1, 256, 4096), "float16", "mesh[0]", "S[2]"), + out_ty=R.DTensor((1, 256, 4096), "float16", "mesh[0]", "S[2]"), ) lv12 = R.dist.call_tir( LlamaAttentionLayerTIR.get_global_var("reshape"), (lv11,), - out_sinfo=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]"), + out_ty=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]"), ) lv13 = R.dist.call_tir( LlamaAttentionLayerTIR.get_global_var("transpose"), (linear_weight2,), - out_sinfo=R.DTensor((4096, 4096), "float16", "mesh[0]", "S[1]"), + out_ty=R.DTensor((4096, 4096), "float16", "mesh[0]", "S[1]"), ) lv14 = R.dist.call_tir( LlamaAttentionLayerTIR.get_global_var("matmul"), (lv6, lv13), - out_sinfo=R.DTensor((1, 256, 4096), "float16", "mesh[0]", "S[2]"), + out_ty=R.DTensor((1, 256, 4096), "float16", "mesh[0]", "S[2]"), ) lv15 = R.dist.call_tir( LlamaAttentionLayerTIR.get_global_var("reshape"), (lv14,), - out_sinfo=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]"), + out_ty=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]"), ) lv16 = R.dist.call_tir( LlamaAttentionLayerTIR.get_global_var("rotary_embedding"), (lv9, cos_cached, sin_cached), - out_sinfo=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]"), + out_ty=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]"), ) lv17 = R.dist.call_tir( LlamaAttentionLayerTIR.get_global_var("rotary_embedding"), (lv12, cos_cached, sin_cached), - out_sinfo=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]"), + out_ty=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]"), ) lv18 = R.dist.call_tir( LlamaAttentionLayerTIR.get_global_var("reshape1"), (lv17,), - out_sinfo=R.DTensor((256, 32, 128), "float16", "mesh[0]", "S[1]"), + out_ty=R.DTensor((256, 32, 128), "float16", "mesh[0]", "S[1]"), ) lv19 = R.dist.call_tir( LlamaAttentionLayerTIR.get_global_var("reshape1"), (lv15,), - out_sinfo=R.DTensor((256, 32, 128), "float16", "mesh[0]", "S[1]"), + out_ty=R.DTensor((256, 32, 128), "float16", "mesh[0]", "S[1]"), ) lv20: R.Object = kv_cache[0] lv21: R.Object = R.call_packed( "vm.builtin.distributed.attention_kv_cache_append", lv20, lv18, - sinfo_args=(R.Object,), + ty_args=(R.Object,), ) lv22: R.Object = kv_cache[1] lv23: R.Object = R.call_packed( "vm.builtin.distributed.attention_kv_cache_append", lv22, lv19, - sinfo_args=(R.Object,), + ty_args=(R.Object,), ) lv24: R.DTensor((256, 32, 128), "float16", "mesh[0]", "S[1]") = R.call_packed( "vm.builtin.distributed.attention_kv_cache_view", lv21, R.shape([256, 32, 128]), - sinfo_args=(R.DTensor((256, 32, 128), "float16", "mesh[0]", "S[1]"),), + ty_args=(R.DTensor((256, 32, 128), "float16", "mesh[0]", "S[1]"),), ) lv25: R.DTensor((256, 32, 128), "float16", "mesh[0]", "S[1]") = R.call_packed( "vm.builtin.distributed.attention_kv_cache_view", lv23, R.shape([256, 32, 128]), - sinfo_args=(R.DTensor((256, 32, 128), "float16", "mesh[0]", "S[1]"),), + ty_args=(R.DTensor((256, 32, 128), "float16", "mesh[0]", "S[1]"),), ) lv26 = R.dist.call_tir( LlamaAttentionLayerTIR.get_global_var("reshape2"), (lv24,), - out_sinfo=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]"), + out_ty=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]"), ) lv27 = R.dist.call_tir( LlamaAttentionLayerTIR.get_global_var("reshape2"), (lv25,), - out_sinfo=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]"), + out_ty=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]"), ) lv28 = R.dist.call_tir( LlamaAttentionLayerTIR.get_global_var("transpose1"), (lv16,), - out_sinfo=R.DTensor((1, 32, 256, 128), "float16", "mesh[0]", "S[1]"), + out_ty=R.DTensor((1, 32, 256, 128), "float16", "mesh[0]", "S[1]"), ) lv29 = R.dist.call_tir( LlamaAttentionLayerTIR.get_global_var("transpose1"), (lv26,), - out_sinfo=R.DTensor((1, 32, 256, 128), "float16", "mesh[0]", "S[1]"), + out_ty=R.DTensor((1, 32, 256, 128), "float16", "mesh[0]", "S[1]"), ) lv30 = R.dist.call_tir( LlamaAttentionLayerTIR.get_global_var("transpose1"), (lv27,), - out_sinfo=R.DTensor((1, 32, 256, 128), "float16", "mesh[0]", "S[1]"), + out_ty=R.DTensor((1, 32, 256, 128), "float16", "mesh[0]", "S[1]"), ) lv31 = R.dist.call_tir( LlamaAttentionLayerTIR.get_global_var("transpose2"), (lv29,), - out_sinfo=R.DTensor((1, 32, 128, 256), "float16", "mesh[0]", "S[1]"), + out_ty=R.DTensor((1, 32, 128, 256), "float16", "mesh[0]", "S[1]"), ) lv32 = R.dist.call_tir( LlamaAttentionLayerTIR.get_global_var("matmul1"), (lv28, lv31), - out_sinfo=R.DTensor((1, 32, 256, 256), "float16", "mesh[0]", "S[1]"), + out_ty=R.DTensor((1, 32, 256, 256), "float16", "mesh[0]", "S[1]"), ) lv33 = R.dist.call_tir( LlamaAttentionLayerTIR.get_global_var("divide"), (lv32, div_const), - out_sinfo=R.DTensor((1, 32, 256, 256), "float16", "mesh[0]", "S[1]"), + out_ty=R.DTensor((1, 32, 256, 256), "float16", "mesh[0]", "S[1]"), ) lv34 = R.dist.call_tir( LlamaAttentionLayerTIR.get_global_var("maximum"), (lv33, maximum_const), - out_sinfo=R.DTensor((1, 32, 256, 256), "float16", "mesh[0]", "S[1]"), + out_ty=R.DTensor((1, 32, 256, 256), "float16", "mesh[0]", "S[1]"), ) lv35 = R.dist.call_tir( LlamaAttentionLayerTIR.get_global_var("minimum"), (lv34, mask), - out_sinfo=R.DTensor((1, 32, 256, 256), "float16", "mesh[0]", "S[1]"), + out_ty=R.DTensor((1, 32, 256, 256), "float16", "mesh[0]", "S[1]"), ) lv37 = R.dist.call_tir( LlamaAttentionLayerTIR.get_global_var("softmax"), (lv35,), - out_sinfo=R.DTensor((1, 32, 256, 256), "float16", "mesh[0]", "S[1]"), + out_ty=R.DTensor((1, 32, 256, 256), "float16", "mesh[0]", "S[1]"), ) lv39 = R.dist.call_tir( LlamaAttentionLayerTIR.get_global_var("matmul2"), (lv37, lv30), - out_sinfo=R.DTensor((1, 32, 256, 128), "float16", "mesh[0]", "S[1]"), + out_ty=R.DTensor((1, 32, 256, 128), "float16", "mesh[0]", "S[1]"), ) lv40 = R.dist.call_tir( LlamaAttentionLayerTIR.get_global_var("transpose3"), (lv39,), - out_sinfo=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]"), + out_ty=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]"), ) lv41 = R.dist.call_tir( LlamaAttentionLayerTIR.get_global_var("reshape3"), (lv40,), - out_sinfo=R.DTensor((1, 256, 4096), "float16", "mesh[0]", "S[2]"), + out_ty=R.DTensor((1, 256, 4096), "float16", "mesh[0]", "S[2]"), ) lv42 = R.dist.call_tir( LlamaAttentionLayerTIR.get_global_var("transpose"), (linear_weight3,), - out_sinfo=R.DTensor((4096, 4096), "float16", "mesh[0]", "S[0]"), + out_ty=R.DTensor((4096, 4096), "float16", "mesh[0]", "S[0]"), ) lv43 = R.dist.call_tir( LlamaAttentionLayerTIR.get_global_var("matmul"), (lv41, lv42), - out_sinfo=R.DTensor((1, 256, 4096), "float16", "mesh[0]", "R"), + out_ty=R.DTensor((1, 256, 4096), "float16", "mesh[0]", "R"), ) lv44 = R.dist.call_tir( LlamaAttentionLayerTIR.get_global_var("add"), (input_tokens, lv43), - out_sinfo=R.DTensor((1, 256, 4096), "float16", "mesh[0]", "R"), + out_ty=R.DTensor((1, 256, 4096), "float16", "mesh[0]", "R"), ) gv: R.DTensor((1, 256, 4096), "float16", "mesh[0]", "R") = lv44 return gv @@ -1687,7 +1687,7 @@ def foo( lv6 = R.call_tir( cls.rms_norm, (input_tokens, rms_norm_weight), - out_sinfo=R.Tensor((1, n, 4096), dtype="float16"), + out_ty=R.Tensor((1, n, 4096), dtype="float16"), ) lv7: R.Tensor((4096, 4096), dtype="float16") = R.permute_dims(linear_weight, axes=None) lv7_copy: R.Tensor((4096, 4096), dtype="float16") = R.dist.annotate_sharding( @@ -1724,36 +1724,36 @@ def foo( lv16 = R.call_tir( cls.rotary_embedding, (lv9, cos_cached, sin_cached), - out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), + out_ty=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]), ) lv17 = R.call_tir( cls.rotary_embedding, (lv12, cos_cached, sin_cached), - out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), + out_ty=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]), ) lv18: R.Tensor((n, 32, 128), dtype="float16") = R.reshape(lv17, R.shape([n, 32, 128])) lv19: R.Tensor((n, 32, 128), dtype="float16") = R.reshape(lv15, R.shape([n, 32, 128])) lv20: R.Object = kv_cache[0] lv21: R.Object = R.call_packed( - "vm.builtin.attention_kv_cache_append", lv20, lv18, sinfo_args=(R.Object,) + "vm.builtin.attention_kv_cache_append", lv20, lv18, ty_args=(R.Object,) ) lv22: R.Object = kv_cache[1] lv23: R.Object = R.call_packed( - "vm.builtin.attention_kv_cache_append", lv22, lv19, sinfo_args=(R.Object,) + "vm.builtin.attention_kv_cache_append", lv22, lv19, ty_args=(R.Object,) ) lv24: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed( "vm.builtin.attention_kv_cache_view", lv21, R.shape([m, 32, 128]), - sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),), + ty_args=(R.Tensor((m, 32, 128), dtype="float16"),), ) lv25: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed( "vm.builtin.attention_kv_cache_view", lv23, R.shape([m, 32, 128]), - sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),), + ty_args=(R.Tensor((m, 32, 128), dtype="float16"),), ) lv26: R.Tensor((1, m, 32, 128), dtype="float16") = R.reshape( lv24, R.shape([1, m, 32, 128]) @@ -1895,7 +1895,7 @@ def foo( lv6 = R.dist.call_tir( cls.rms_norm, (input_tokens, rms_norm_weight), - out_sinfo=R.DTensor((1, n, 4096), "float16", "mesh[0]", "R"), + out_ty=R.DTensor((1, n, 4096), "float16", "mesh[0]", "R"), ) lv7: R.DTensor((4096, 4096), "float16", "mesh[0]", "S[1]") = R.permute_dims( linear_weight, axes=None @@ -1927,13 +1927,13 @@ def foo( lv16 = R.dist.call_tir( cls.rotary_embedding, (lv9, cos_cached, sin_cached), - out_sinfo=R.DTensor((1, n, 32, 128), "float16", "mesh[0]", "S[2]"), + out_ty=R.DTensor((1, n, 32, 128), "float16", "mesh[0]", "S[2]"), tir_vars=R.shape([m]), ) lv17 = R.dist.call_tir( cls.rotary_embedding, (lv12, cos_cached, sin_cached), - out_sinfo=R.DTensor((1, n, 32, 128), "float16", "mesh[0]", "S[2]"), + out_ty=R.DTensor((1, n, 32, 128), "float16", "mesh[0]", "S[2]"), tir_vars=R.shape([m]), ) lv18: R.DTensor((n, 32, 128), "float16", "mesh[0]", "S[1]") = R.reshape( @@ -1947,26 +1947,26 @@ def foo( "vm.builtin.distributed.attention_kv_cache_append", lv20, lv18, - sinfo_args=(R.Object,), + ty_args=(R.Object,), ) lv22: R.Object = kv_cache[1] lv23: R.Object = R.call_packed( "vm.builtin.distributed.attention_kv_cache_append", lv22, lv19, - sinfo_args=(R.Object,), + ty_args=(R.Object,), ) lv24: R.DTensor((m, 32, 128), "float16", "mesh[0]", "S[1]") = R.call_packed( "vm.builtin.distributed.attention_kv_cache_view", lv21, R.shape([m, 32, 128]), - sinfo_args=(R.DTensor((m, 32, 128), "float16", "mesh[0]", "S[1]"),), + ty_args=(R.DTensor((m, 32, 128), "float16", "mesh[0]", "S[1]"),), ) lv25: R.DTensor((m, 32, 128), "float16", "mesh[0]", "S[1]") = R.call_packed( "vm.builtin.distributed.attention_kv_cache_view", lv23, R.shape([m, 32, 128]), - sinfo_args=(R.DTensor((m, 32, 128), "float16", "mesh[0]", "S[1]"),), + ty_args=(R.DTensor((m, 32, 128), "float16", "mesh[0]", "S[1]"),), ) lv26: R.DTensor((1, m, 32, 128), "float16", "mesh[0]", "S[2]") = R.reshape( lv24, R.shape([1, m, 32, 128]) diff --git a/tests/python/relax/distributed/test_distributed_tvmscript_parser.py b/tests/python/relax/distributed/test_distributed_tvmscript_parser.py index d80ad73c4d59..4a9a4e70d7b8 100644 --- a/tests/python/relax/distributed/test_distributed_tvmscript_parser.py +++ b/tests/python/relax/distributed/test_distributed_tvmscript_parser.py @@ -84,10 +84,10 @@ def foo( foo_func = TestModule["foo"] params = foo_func.params assert len(params) == 1 - assert params[0].struct_info == R.DTensor( + assert params[0].ty == R.DTensor( (128, 128), "float32", device_mesh_list[0], placement="S[0], R" ) - assert foo_func.ret_struct_info == R.DTensor( + assert foo_func.ret_ty == R.DTensor( (128, 128), "float32", device_mesh_list[0], placement="S[0], R" ) assert isinstance(foo_func.body, SeqExpr) @@ -95,7 +95,7 @@ def foo( assert isinstance(foo_func.body.blocks[0].bindings[0], VarBinding) value = foo_func.body.blocks[0].bindings[0].value assert isinstance(value, Call) - assert value.sinfo_args[0] == R.DTensor( + assert value.ty_args[0] == R.DTensor( (128, 128), "float32", device_mesh_list[0], placement="S[0], R" ) _check(TestModule) @@ -181,9 +181,7 @@ def foo( shape=(128, 128), dtype="float32", device_mesh="mesh[0]", placement="S[0], R" ), ) - gv1 = R.add( - gv0, R.dist.const(1.0, struct_info=R.DTensor((), "float32", "mesh[0]", "R, R")) - ) + gv1 = R.add(gv0, R.dist.const(1.0, ty=R.DTensor((), "float32", "mesh[0]", "R, R"))) return gv1 _check(TestModule) diff --git a/tests/python/relax/distributed/test_distributed_tvmscript_printer.py b/tests/python/relax/distributed/test_distributed_tvmscript_printer.py index 5a6c2a5802d4..e4f42646d46d 100644 --- a/tests/python/relax/distributed/test_distributed_tvmscript_printer.py +++ b/tests/python/relax/distributed/test_distributed_tvmscript_printer.py @@ -18,8 +18,8 @@ import tvm.testing from tvm.ir import Range -from tvm.relax import TensorStructInfo -from tvm.relax.distributed import DeviceMesh, DTensorStructInfo, Placement +from tvm.relax import TensorType +from tvm.relax.distributed import DeviceMesh, DTensorType, Placement from tvm.script.parser import ir as I from tvm.script.parser import relax as R from tvm.script.parser import tirx as T @@ -35,9 +35,7 @@ def _assert_print(obj, expected): def test_constant(): constant = R.dist.const( 1, - struct_info=R.DTensor( - (), "float32", device_mesh=DeviceMesh((2, 2), Range(0, 4)), placement="R, R" - ), + ty=R.DTensor((), "float32", device_mesh=DeviceMesh((2, 2), Range(0, 4)), placement="R, R"), ) assert ( constant.__str__() @@ -45,28 +43,22 @@ def test_constant(): ) -def test_dtensor_struct_info(): - tensor_sinfo1 = TensorStructInfo((32, 32), "float32") - tensor_sinfo2 = TensorStructInfo((32, 32), "void") - obj0 = DTensorStructInfo( - tensor_sinfo1, DeviceMesh((2, 2), Range(0, 4)), Placement.from_text("S[1], R") - ) +def test_dtensor_type(): + tensor_ty1 = TensorType((32, 32), "float32") + tensor_ty2 = TensorType((32, 32), "void") + obj0 = DTensorType(tensor_ty1, DeviceMesh((2, 2), Range(0, 4)), Placement.from_text("S[1], R")) assert ( obj0.__str__() == """R.DTensor((32, 32), "float32", R.device_mesh((2, 2), R.Range(0, 4)), "S[1], R")""" ) - obj1 = DTensorStructInfo( - tensor_sinfo2, DeviceMesh((2, 2), Range(0, 4)), Placement.from_text("S[1], R") - ) + obj1 = DTensorType(tensor_ty2, DeviceMesh((2, 2), Range(0, 4)), Placement.from_text("S[1], R")) assert ( obj1.__str__() == """R.DTensor((32, 32), device_mesh=R.device_mesh((2, 2), R.Range(0, 4)), placement="S[1], R")""" ) - obj2 = DTensorStructInfo( - tensor_sinfo2, DeviceMesh((2, 2), [0, 1, 2, 3]), Placement.from_text("S[1], R") - ) + obj2 = DTensorType(tensor_ty2, DeviceMesh((2, 2), [0, 1, 2, 3]), Placement.from_text("S[1], R")) assert ( obj2.__str__() == """R.DTensor((32, 32), device_mesh=R.device_mesh((2, 2), [0, 1, 2, 3]), placement="S[1], R")""" @@ -118,7 +110,7 @@ def test_func(): @R.function def foo(x: R.DTensor((128, 128), "float32", R.device_mesh((2, 2), R.Range(0, 4)), "S[0], R")) -> R.DTensor((128, 128), "float32", R.device_mesh((2, 2), R.Range(0, 4)), "S[0], R"): - gv0 = R.dist.call_tir(tir_func, (x,), out_sinfo=R.DTensor((128, 128), "float32", R.device_mesh((2, 2), R.Range(0, 4)), "S[0], R")) + gv0 = R.dist.call_tir(tir_func, (x,), out_ty=R.DTensor((128, 128), "float32", R.device_mesh((2, 2), R.Range(0, 4)), "S[0], R")) return gv0 """, ) @@ -151,7 +143,7 @@ def tir_func(x: T.Buffer((T.int64(128), T.int64(128)), "float32"), y: T.Buffer(( @R.function def foo(x: R.DTensor((128, 128), "float32", "mesh[0]", "S[0], R")) -> R.DTensor((128, 128), "float32", "mesh[0]", "S[0], R"): cls = Module - gv0 = R.dist.call_tir(cls.tir_func, (x,), out_sinfo=R.DTensor((128, 128), "float32", "mesh[0]", "S[0], R")) + gv0 = R.dist.call_tir(cls.tir_func, (x,), out_ty=R.DTensor((128, 128), "float32", "mesh[0]", "S[0], R")) return gv0 """, ) diff --git a/tests/python/relax/test_analysis.py b/tests/python/relax/test_analysis.py index 56776323fc87..7be99442aced 100644 --- a/tests/python/relax/test_analysis.py +++ b/tests/python/relax/test_analysis.py @@ -132,7 +132,7 @@ def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: "my_dps_func", (unused0,), R.Tensor((32, 32), dtype="float32") ) R.output(lv0) - z = R.call_packed("vm.builtin.copy", lv0, sinfo_args=(R.Tensor((32, 32), "float32"))) + z = R.call_packed("vm.builtin.copy", lv0, ty_args=(R.Tensor((32, 32), "float32"))) return z optimized = remove_all_unused(IdentityUnused["main"]) @@ -144,7 +144,7 @@ def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: with R.dataflow(): lv0 = x R.output(lv0) - z = R.call_packed("vm.builtin.copy", lv0, sinfo_args=(R.Tensor((32, 32), "float32"))) + z = R.call_packed("vm.builtin.copy", lv0, ty_args=(R.Tensor((32, 32), "float32"))) return z tvm.ir.assert_structural_equal(optimized, GroundTruth["main"]) @@ -186,7 +186,7 @@ def test_binding_block_keep_impure_without_dataflow(): @R.function(private=True, pure=False) def before(x: R.Tensor((32, 32), "float32")) -> R.Tensor: lv0 = x - y = R.call_packed("vm.builtin.copy", lv0, sinfo_args=(R.Tensor((32, 32), "float32"))) + y = R.call_packed("vm.builtin.copy", lv0, ty_args=(R.Tensor((32, 32), "float32"))) return y expected = before @@ -215,7 +215,7 @@ def test_binding_block_keep_pure_func_used_only_for_impure(): def before(x: R.Tensor((32, 32), "int32")): y = x * R.const(2) z = R.call_packed( - "function_maybe_with_side_effects", y, sinfo_args=(R.Tensor((32, 32), "int32")) + "function_maybe_with_side_effects", y, ty_args=(R.Tensor((32, 32), "int32")) ) return R.tuple() @@ -236,7 +236,7 @@ def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: def internal_unused_func(A: R.Tensor((32, 32), "float32")) -> R.Tensor: return A - z = R.call_packed("vm.builtin.copy", lv0, sinfo_args=(R.Tensor((32, 32), "float32"))) + z = R.call_packed("vm.builtin.copy", lv0, ty_args=(R.Tensor((32, 32), "float32"))) return z optimized = remove_all_unused(IdentityUnused["main"]) @@ -246,7 +246,7 @@ class GroundTruth: @R.function(pure=False) def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: lv0 = x - z = R.call_packed("vm.builtin.copy", lv0, sinfo_args=(R.Tensor((32, 32), "float32"))) + z = R.call_packed("vm.builtin.copy", lv0, ty_args=(R.Tensor((32, 32), "float32"))) return z tvm.ir.assert_structural_equal(optimized, GroundTruth["main"]) @@ -260,7 +260,7 @@ def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: with R.dataflow(): lv0 = x R.output(lv0) - z = R.call_packed("vm.builtin.copy", lv0, sinfo_args=(R.Tensor((32, 32), "float32"))) + z = R.call_packed("vm.builtin.copy", lv0, ty_args=(R.Tensor((32, 32), "float32"))) return lv0 optimized = remove_all_unused(IdentityUnused["main"]) @@ -273,7 +273,7 @@ def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: lv0 = x R.output(lv0) # This might bring side effect so cannot be removed. - z = R.call_packed("vm.builtin.copy", lv0, sinfo_args=(R.Tensor((32, 32), "float32"))) + z = R.call_packed("vm.builtin.copy", lv0, ty_args=(R.Tensor((32, 32), "float32"))) return lv0 tvm.ir.assert_structural_equal(optimized, GroundTruth["main"]) @@ -284,7 +284,7 @@ def test_edge_binding_block_fake_unused_remove_all_unused(): class IdentityUnused: @R.function(pure=False) def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor((32, 32), "float32"): - z = R.call_packed("vm.builtin.copy", x, sinfo_args=(R.Tensor((32, 32), "float32"))) + z = R.call_packed("vm.builtin.copy", x, ty_args=(R.Tensor((32, 32), "float32"))) return x optimized = remove_all_unused(IdentityUnused["main"]) @@ -301,7 +301,7 @@ def main(x: R.Tensor((3,), dtype="int64")) -> R.Tensor(dtype="int32", ndim=3): k = T.int64() with R.dataflow(): lv: R.Shape(ndim=3) = R.call_pure_packed( - "vm.builtin.tensor_to_shape", x, sinfo_args=(R.Shape(ndim=3),) + "vm.builtin.tensor_to_shape", x, ty_args=(R.Shape(ndim=3),) ) lv1: R.Shape([m, n, k]) = R.match_cast(lv, R.Shape([m, n, k])) gv: R.Tensor((m, n, k), dtype="int32") = R.full( @@ -364,14 +364,14 @@ def test_retain_impure_calls_unused_in_binding_block(): @R.function(pure=False) def before(x: R.Tensor((32, 32), "float32")) -> R.Tensor: lv0 = x - unused0 = R.call_packed("my_impure_call", x, sinfo_args=R.Tensor((32, 32), dtype="float32")) + unused0 = R.call_packed("my_impure_call", x, ty_args=R.Tensor((32, 32), dtype="float32")) unused1 = R.call_dps_packed("my_unused_call", (lv0,), R.Tensor((32, 32), dtype="float32")) return lv0 @R.function(pure=False) def expected(x: R.Tensor((32, 32), "float32")) -> R.Tensor: lv0 = x - unused0 = R.call_packed("my_impure_call", x, sinfo_args=R.Tensor((32, 32), dtype="float32")) + unused0 = R.call_packed("my_impure_call", x, ty_args=R.Tensor((32, 32), dtype="float32")) return lv0 after = remove_all_unused(before.body) @@ -508,12 +508,12 @@ def test_free_vars(): inner = rx.Function( [z], rx.op.add(x, rx.op.add(y, z)), - ret_struct_info=R.Tensor(ndim=-1), + ret_ty=R.Tensor(ndim=-1), ) outer = rx.Function( [x, y], rx.Call(inner, [y]), - ret_struct_info=R.Tensor(ndim=-1), + ret_ty=R.Tensor(ndim=-1), ) assert len(free_vars(outer)) == 0 assert var_name_set(free_vars(inner)) == {"x", "y"} diff --git a/tests/python/relax/test_analysis_contains_impure_call.py b/tests/python/relax/test_analysis_contains_impure_call.py index 51de60544d42..c345991e0c6f 100644 --- a/tests/python/relax/test_analysis_contains_impure_call.py +++ b/tests/python/relax/test_analysis_contains_impure_call.py @@ -91,12 +91,12 @@ def recursive_impure() -> R.Object: body.blocks[0].bindings[-1], ] # Note: we construct the function in this way so that we keep the old vars - # with their current StructInfo. That would get fixed during normalization. + # with their current Type. That would get fixed during normalization. # However, this situation is meant to correspond to an intermediate state # that might arise within a pass. new_body = rx.SeqExpr([rx.BindingBlock(new_bindings)], body.body) - # if we didn't ignore the recursive call, the fact the var's StructInfo + # if we didn't ignore the recursive call, the fact the var's Type # calls it impure would throw it off assert not contains_impure_call(new_body, own_name=own_name) assert contains_impure_call(new_body) diff --git a/tests/python/relax/test_analysis_estimate_memory_usage.py b/tests/python/relax/test_analysis_estimate_memory_usage.py index 977644ff8af7..5bdbbc5f0b60 100644 --- a/tests/python/relax/test_analysis_estimate_memory_usage.py +++ b/tests/python/relax/test_analysis_estimate_memory_usage.py @@ -80,7 +80,7 @@ def main(x: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((10,), dtype="float32 _: R.Tuple() = cls.exp(x, alloc) lv: R.Tensor((2, 4), dtype="float32") = alloc lv1: R.Tensor((8,), dtype="float32") = R.call_packed( - "vm.builtin.reshape", lv, R.shape([8]), sinfo_args=[R.Tensor((8,), dtype="float32")] + "vm.builtin.reshape", lv, R.shape([8]), ty_args=[R.Tensor((8,), dtype="float32")] ) storage1: R.Object = R.memory.alloc_storage( R.shape([40]), virtual_device_index=0, storage_scope="global", dtype="float32" diff --git a/tests/python/relax/test_analysis_struct_info_analysis.py b/tests/python/relax/test_analysis_type_analysis.py similarity index 58% rename from tests/python/relax/test_analysis_struct_info_analysis.py rename to tests/python/relax/test_analysis_type_analysis.py index 4d25f9f504da..20ccb4a0e7fa 100644 --- a/tests/python/relax/test_analysis_struct_info_analysis.py +++ b/tests/python/relax/test_analysis_type_analysis.py @@ -16,7 +16,7 @@ # under the License. # ruff: noqa: E731, F401, F841 -"""Tests analysis functions of struct info""" +"""Tests Relax dependent type analysis functions.""" import pytest import tvm_ffi @@ -31,20 +31,20 @@ def test_get_static_type_basic(): # object - s0 = rx.ObjectStructInfo() + s0 = rx.ObjectType() tvm.ir.assert_structural_equal(rx.analysis.get_static_type(s0), rx.ObjectType()) # prim - s1 = rx.PrimStructInfo("float32") - tvm.ir.assert_structural_equal(rx.analysis.get_static_type(s1), tvm.ir.PrimType("float32")) + s1 = rx.PrimType("float32") + tvm.ir.assert_structural_equal(rx.analysis.get_static_type(s1), rx.PrimType("float32")) def test_get_static_type_shape(): # shape n, m = tirx.Var("n", "int64"), tirx.Var("m", "int64") - s2 = rx.ShapeStructInfo([1, n + 1, m]) - s3 = rx.ShapeStructInfo(ndim=2) + s2 = rx.ShapeType([1, n + 1, m]) + s3 = rx.ShapeType(ndim=2) tvm.ir.assert_structural_equal(rx.analysis.get_static_type(s2), rx.ShapeType(ndim=3)) @@ -53,7 +53,7 @@ def test_get_static_type_shape(): def test_get_static_type_tensor(): n, m = tirx.Var("n", "int64"), tirx.Var("m", "int64") - s4 = rx.TensorStructInfo([1, n + 1, m], "int64") + s4 = rx.TensorType([1, n + 1, m], "int64") tvm.ir.assert_structural_equal( rx.analysis.get_static_type(s4), rx.TensorType(ndim=3, dtype="int64") @@ -63,11 +63,11 @@ def test_get_static_type_tensor(): def test_get_static_type_tuple(): # tuple n, m = tirx.Var("n", "int64"), tirx.Var("m", "int64") - s0 = rx.ObjectStructInfo() - s2 = rx.ShapeStructInfo([1, n + 1, m]) - s4 = rx.TensorStructInfo([1, n + 1, m], "int64") - t0 = rx.TupleStructInfo([s4, s0]) - t1 = rx.TupleStructInfo([t0, s2]) + s0 = rx.ObjectType() + s2 = rx.ShapeType([1, n + 1, m]) + s4 = rx.TensorType([1, n + 1, m], "int64") + t0 = rx.TupleType([s4, s0]) + t1 = rx.TupleType([t0, s2]) tvm.ir.assert_structural_equal( rx.analysis.get_static_type(t1), @@ -84,10 +84,10 @@ def test_get_static_type_func(): # tuple def fn_info(c): n, m = tirx.Var("n", "int64"), tirx.Var("m", "int64") - x = rx.TensorStructInfo([c, n, m], "float32") - y = rx.TensorStructInfo([c, n, 1], "float32") - z = rx.TensorStructInfo([c, n], "float32") - return rx.FuncStructInfo([x, y], z) + x = rx.TensorType([c, n, m], "float32") + y = rx.TensorType([c, n, 1], "float32") + z = rx.TensorType([c, n], "float32") + return rx.FuncType([x, y], z) def fn_type(): x = rx.TensorType(ndim=3, dtype="float32") @@ -101,46 +101,44 @@ def fn_type(): def test_erase_to_well_defined_basic(): - s0 = rx.ObjectStructInfo() + s0 = rx.ObjectType() tvm.ir.assert_structural_equal(rx.analysis.erase_to_well_defined(s0), s0) # prim - s1 = rx.PrimStructInfo("float32") + s1 = rx.PrimType("float32") tvm.ir.assert_structural_equal(rx.analysis.erase_to_well_defined(s1), s1) def test_erase_to_well_defined_shape(): n, m = tirx.Var("n", "int64"), tirx.Var("m", "int64") - s2 = rx.ShapeStructInfo([1, n + 1, m]) - s3 = rx.ShapeStructInfo(ndim=2) + s2 = rx.ShapeType([1, n + 1, m]) + s3 = rx.ShapeType(ndim=2) # have undefined - tvm.ir.assert_structural_equal( - rx.analysis.erase_to_well_defined(s2), rx.ShapeStructInfo(ndim=3) - ) + tvm.ir.assert_structural_equal(rx.analysis.erase_to_well_defined(s2), rx.ShapeType(ndim=3)) # all defined tvm.ir.assert_structural_equal(rx.analysis.erase_to_well_defined(s2, {n: n, m: m}), s2) # replacement tvm.ir.assert_structural_equal( - rx.analysis.erase_to_well_defined(s2, {n: 2, m: m + 1}), rx.ShapeStructInfo([1, 3, m + 1]) + rx.analysis.erase_to_well_defined(s2, {n: 2, m: m + 1}), rx.ShapeType([1, 3, m + 1]) ) # partial defined tvm.ir.assert_structural_equal( - rx.analysis.erase_to_well_defined(s2, {n: n}), rx.ShapeStructInfo(ndim=3) + rx.analysis.erase_to_well_defined(s2, {n: n}), rx.ShapeType(ndim=3) ) def test_erase_to_well_defined_tensor(): n, m = tirx.Var("n", "int64"), tirx.Var("m", "int64") - rshape = rx.Var("shape", rx.ShapeStructInfo(ndim=2)) - s0 = rx.TensorStructInfo(rshape, dtype="int32") + rshape = rx.Var("shape", rx.ShapeType(ndim=2)) + s0 = rx.TensorType(rshape, dtype="int32") # undefined tvm.ir.assert_structural_equal( rx.analysis.erase_to_well_defined(s0, None, None), - rx.TensorStructInfo(ndim=2, dtype="int32"), + rx.TensorType(ndim=2, dtype="int32"), ) # defined @@ -150,44 +148,42 @@ def test_erase_to_well_defined_tensor(): tvm.ir.assert_structural_equal( rx.analysis.erase_to_well_defined(s0, None, {rshape: rx.ShapeExpr([1, 2])}), - rx.TensorStructInfo([1, 2], dtype="int32"), + rx.TensorType([1, 2], dtype="int32"), ) - s1 = rx.TensorStructInfo([m + 1, n], dtype="float32") + s1 = rx.TensorType([m + 1, n], dtype="float32") tvm.ir.assert_structural_equal(rx.analysis.erase_to_well_defined(s1, {n: n, m: m}), s1) tvm.ir.assert_structural_equal( rx.analysis.erase_to_well_defined(s1, {n: 2, m: 3}), - rx.TensorStructInfo([4, 2], dtype="float32"), + rx.TensorType([4, 2], dtype="float32"), ) tvm.ir.assert_structural_equal( rx.analysis.erase_to_well_defined(s1, {m: m}, {rshape: rshape}), - rx.TensorStructInfo(ndim=2, dtype="float32"), + rx.TensorType(ndim=2, dtype="float32"), ) - s2 = rx.TensorStructInfo([1, 2], dtype="float32") + s2 = rx.TensorType([1, 2], dtype="float32") tvm.ir.assert_structural_equal(rx.analysis.erase_to_well_defined(s2), s2) def test_erase_to_well_defined_tuple(): n, m = tirx.Var("n", "int64"), tirx.Var("m", "int64") - s0 = rx.ObjectStructInfo() - s2 = rx.ShapeStructInfo([1, m]) - s4 = rx.TensorStructInfo([1, n + 1, m], "int64") - t0 = rx.TupleStructInfo([s4, s0]) - t1 = rx.TupleStructInfo([t0, s2]) + s0 = rx.ObjectType() + s2 = rx.ShapeType([1, m]) + s4 = rx.TensorType([1, n + 1, m], "int64") + t0 = rx.TupleType([s4, s0]) + t1 = rx.TupleType([t0, s2]) tvm.ir.assert_structural_equal( rx.analysis.erase_to_well_defined(t1, {m: m + 1}), - rx.TupleStructInfo( + rx.TupleType( [ - rx.TupleStructInfo( - [rx.TensorStructInfo(ndim=3, dtype="int64"), rx.ObjectStructInfo()] - ), - rx.ShapeStructInfo([1, m + 1]), + rx.TupleType([rx.TensorType(ndim=3, dtype="int64"), rx.ObjectType()]), + rx.ShapeType([1, m + 1]), ] ), ) @@ -196,10 +192,10 @@ def test_erase_to_well_defined_tuple(): def test_erase_to_well_defined_func(): def fn_info(c): n, m = tirx.Var("n", "int64"), tirx.Var("m", "int64") - x = rx.TensorStructInfo([c, n, m], "float32") - y = rx.TensorStructInfo([c, n, 1], "float32") - z = rx.TensorStructInfo([c, n], "float32") - return rx.FuncStructInfo([x, y], z) + x = rx.TensorType([c, n, m], "float32") + y = rx.TensorType([c, n, 1], "float32") + z = rx.TensorType([c, n], "float32") + return rx.FuncType([x, y], z) f0 = fn_info(1) @@ -208,18 +204,18 @@ def fn_info(c): def test_base_check(): BR = rx.analysis.BaseCheckResult - bcheck = rx.analysis.struct_info_base_check + bcheck = rx.analysis.type_base_check n, m = tirx.Var("n", "int64"), tirx.Var("m", "int64") - obj0 = rx.ObjectStructInfo() - prim0 = rx.PrimStructInfo("int32") - prim1 = rx.PrimStructInfo("float32") + obj0 = rx.ObjectType() + prim0 = rx.PrimType("int32") + prim1 = rx.PrimType("float32") - shape0 = rx.ShapeStructInfo(ndim=-1) - shape1 = rx.ShapeStructInfo(ndim=2) - shape2 = rx.ShapeStructInfo(ndim=3) - shape3 = rx.ShapeStructInfo([1, 2, 3]) - shape4 = rx.ShapeStructInfo([1, n, 3]) + shape0 = rx.ShapeType(ndim=-1) + shape1 = rx.ShapeType(ndim=2) + shape2 = rx.ShapeType(ndim=3) + shape3 = rx.ShapeType([1, 2, 3]) + shape4 = rx.ShapeType([1, n, 3]) vdevice0 = ir.VDevice() vdevice1 = ir.VDevice("llvm") @@ -227,23 +223,23 @@ def test_base_check(): vdevice3 = ir.VDevice("cuda", 2) vdevice4 = ir.VDevice("cuda", 0, "") - tensor0 = rx.TensorStructInfo(ndim=-1, dtype="int32") - tensor1 = rx.TensorStructInfo(ndim=-1, dtype="float32") - tensor2 = rx.TensorStructInfo(ndim=2, dtype="int32") - tensor3 = rx.TensorStructInfo(ndim=2, dtype="float32") - tensor4 = rx.TensorStructInfo([n, m], "int32") - tensor5 = rx.TensorStructInfo([n, m, 1], "int32") - tensor6 = rx.TensorStructInfo([n, m, 2], "int32") - tensor7 = rx.TensorStructInfo(ndim=2, dtype="float32", vdevice=vdevice0) - tensor8 = rx.TensorStructInfo(ndim=2, dtype="float32", vdevice=vdevice1) - tensor9 = rx.TensorStructInfo(ndim=2, dtype="float32", vdevice=vdevice2) - tensor10 = rx.TensorStructInfo(ndim=2, dtype="float32", vdevice=vdevice3) - tensor11 = rx.TensorStructInfo(ndim=2, dtype="float32", vdevice=vdevice4) - tensor12 = rx.TensorStructInfo([n, m, 2], "int32", vdevice0) - tensor13 = rx.TensorStructInfo([n, m, 2], "int32", vdevice1) - tensor14 = rx.TensorStructInfo([n, m, 2], "int32", vdevice2) - tensor15 = rx.TensorStructInfo([n, m, 2], "int32", vdevice3) - tensor16 = rx.TensorStructInfo([n, m, 2], "int32", vdevice4) + tensor0 = rx.TensorType(ndim=-1, dtype="int32") + tensor1 = rx.TensorType(ndim=-1, dtype="float32") + tensor2 = rx.TensorType(ndim=2, dtype="int32") + tensor3 = rx.TensorType(ndim=2, dtype="float32") + tensor4 = rx.TensorType([n, m], "int32") + tensor5 = rx.TensorType([n, m, 1], "int32") + tensor6 = rx.TensorType([n, m, 2], "int32") + tensor7 = rx.TensorType(ndim=2, dtype="float32", vdevice=vdevice0) + tensor8 = rx.TensorType(ndim=2, dtype="float32", vdevice=vdevice1) + tensor9 = rx.TensorType(ndim=2, dtype="float32", vdevice=vdevice2) + tensor10 = rx.TensorType(ndim=2, dtype="float32", vdevice=vdevice3) + tensor11 = rx.TensorType(ndim=2, dtype="float32", vdevice=vdevice4) + tensor12 = rx.TensorType([n, m, 2], "int32", vdevice0) + tensor13 = rx.TensorType([n, m, 2], "int32", vdevice1) + tensor14 = rx.TensorType([n, m, 2], "int32", vdevice2) + tensor15 = rx.TensorType([n, m, 2], "int32", vdevice3) + tensor16 = rx.TensorType([n, m, 2], "int32", vdevice4) # obj assert bcheck(obj0, prim0) == BR.PASS @@ -277,7 +273,7 @@ def test_base_check(): # shape mismatch assert bcheck(shape3, shape4) == BR.FAIL_L2 - assert shape4.is_base_of(rx.ShapeStructInfo([1, n, 3])) + assert shape4.is_base_of(rx.ShapeType([1, n, 3])) # tensor assert bcheck(tensor0, obj0) == BR.FAIL_L1 @@ -305,7 +301,7 @@ def test_base_check(): assert bcheck(tensor5, tensor6) == BR.FAIL_L0 # match - assert tensor0.is_base_of(rx.TensorStructInfo(ndim=-1, dtype="int32")) + assert tensor0.is_base_of(rx.TensorType(ndim=-1, dtype="int32")) assert tensor0.is_base_of(tensor2) assert tensor0.is_base_of(tensor4) assert tensor0.is_base_of(tensor5) @@ -315,58 +311,58 @@ def test_base_check(): assert tensor3.is_base_of(tensor8) assert tensor6.is_base_of(tensor12) assert tensor6.is_base_of(tensor13) - assert tensor4.is_base_of(rx.TensorStructInfo([n, m], dtype="int32")) + assert tensor4.is_base_of(rx.TensorType([n, m], dtype="int32")) # tuple - t0 = rx.TupleStructInfo([obj0, tensor0]) - t1 = rx.TupleStructInfo([prim0, tensor4]) - t2 = rx.TupleStructInfo([obj0, tensor0, obj0]) - t3 = rx.TupleStructInfo([tensor0, obj0]) + t0 = rx.TupleType([obj0, tensor0]) + t1 = rx.TupleType([prim0, tensor4]) + t2 = rx.TupleType([obj0, tensor0, obj0]) + t3 = rx.TupleType([tensor0, obj0]) assert t0.is_base_of(t1) assert bcheck(t0, t2) == BR.FAIL_L0 assert bcheck(t0, t3) == BR.FAIL_L1 - assert rx.TupleStructInfo([t0, t1]).is_base_of(rx.TupleStructInfo([t1, t1])) - assert bcheck(rx.TupleStructInfo([t0, t1]), rx.TupleStructInfo([t1, t0])) == BR.FAIL_L1 + assert rx.TupleType([t0, t1]).is_base_of(rx.TupleType([t1, t1])) + assert bcheck(rx.TupleType([t0, t1]), rx.TupleType([t1, t0])) == BR.FAIL_L1 def fn_info_shape(c): n, m = tirx.Var("n", "int64"), tirx.Var("m", "int64") - x = rx.TensorStructInfo([c, n, m], "float32") - y = rx.TensorStructInfo([c, n, 1], "float32") - z = rx.TensorStructInfo([c, n], "float32") - return rx.FuncStructInfo([x, y], z) + x = rx.TensorType([c, n, m], "float32") + y = rx.TensorType([c, n, 1], "float32") + z = rx.TensorType([c, n], "float32") + return rx.FuncType([x, y], z) def fn_info_erased(): - x = rx.TensorStructInfo(ndim=3, dtype="float32") - y = rx.TensorStructInfo(ndim=3, dtype="float32") - z = rx.TensorStructInfo(ndim=2, dtype="float32") - return rx.FuncStructInfo([x, y], z) + x = rx.TensorType(ndim=3, dtype="float32") + y = rx.TensorType(ndim=3, dtype="float32") + z = rx.TensorType(ndim=2, dtype="float32") + return rx.FuncType([x, y], z) assert fn_info_shape(1).is_base_of(fn_info_shape(1)) assert fn_info_erased().is_base_of(fn_info_shape(1)) assert bcheck(fn_info_shape(1), fn_info_erased()) == BR.FAIL_L2 - fopaque = rx.FuncStructInfo.opaque_func() + fopaque = rx.FuncType.opaque_func() assert fopaque.is_base_of(fn_info_shape(1)) -def _check_derive(ctx, finfo, args_sinfo, ret): +def _check_derive(ctx, finfo, args_ty, ret): gv = rx.GlobalVar("test") - rx.expr._update_struct_info(gv, finfo) + rx.expr._update_type(gv, finfo) args = [] - for i, sinfo in enumerate(args_sinfo): - arg = rx.Var(f"arg{i}", sinfo) + for i, ty in enumerate(args_ty): + arg = rx.Var(f"arg{i}", ty) args.append(arg) call = rx.Call(gv, args) - derived_ret = rx.analysis.derive_call_ret_struct_info(finfo, call, ctx) + derived_ret = rx.analysis.derive_call_ret_type(finfo, call, ctx) tvm.ir.assert_structural_equal(ret, derived_ret) -def test_derive_call_ret_struct_info(): - obj0 = rx.ObjectStructInfo() - prim0 = rx.PrimStructInfo("float32") +def test_derive_call_ret_type(): + obj0 = rx.ObjectType() + prim0 = rx.PrimType("float32") n, m = tirx.Var("n0", "int64"), tirx.Var("m0", "int64") bb = rx.BlockBuilder() @@ -375,23 +371,23 @@ def test_derive_call_ret_struct_info(): def func0(c): n, m = tirx.Var("n", "int64"), tirx.Var("m", "int64") - x = rx.TensorStructInfo([n, m], "float32") - z = rx.TensorStructInfo([m + c, n], "float32") - return rx.FuncStructInfo([x], z) + x = rx.TensorType([n, m], "float32") + z = rx.TensorType([m + c, n], "float32") + return rx.FuncType([x], z) # Tensor => Tensor _check_derive( bb, func0(1), - [rx.TensorStructInfo([10, 11], "float32")], - rx.TensorStructInfo([12, 10], "float32"), + [rx.TensorType([10, 11], "float32")], + rx.TensorType([12, 10], "float32"), ) _check_derive( bb, func0(2), - [rx.TensorStructInfo([n, m], "float32")], - rx.TensorStructInfo([m + 2, n], "float32"), + [rx.TensorType([n, m], "float32")], + rx.TensorType([m + 2, n], "float32"), ) # passing in information that cannot deduce n, m @@ -400,8 +396,8 @@ def func0(c): _check_derive( bb, func0(2), - [rx.TensorStructInfo(ndim=2, dtype="float32")], - rx.TensorStructInfo(ndim=2, dtype="float32"), + [rx.TensorType(ndim=2, dtype="float32")], + rx.TensorType(ndim=2, dtype="float32"), ) # Error: wrong number of arguments @@ -409,8 +405,8 @@ def func0(c): _check_derive( bb, func0(2), - [rx.TensorStructInfo(ndim=2, dtype="float32"), obj0], - rx.TensorStructInfo(ndim=2, dtype="float32"), + [rx.TensorType(ndim=2, dtype="float32"), obj0], + rx.TensorType(ndim=2, dtype="float32"), ) # Error:type mismatch @@ -422,65 +418,65 @@ def func0(c): def func1(c): n, m = tirx.Var("n", "int64"), tirx.Var("m", "int64") - x = rx.TensorStructInfo([n, m], "float32", vdev) - z = rx.TensorStructInfo([m + c, n], "float32", vdev) - return rx.FuncStructInfo([x], z) + x = rx.TensorType([n, m], "float32", vdev) + z = rx.TensorType([m + c, n], "float32", vdev) + return rx.FuncType([x], z) _check_derive( bb, func1(1), - [rx.TensorStructInfo([10, 11], "float32", vdev)], - rx.TensorStructInfo([12, 10], "float32", vdev), + [rx.TensorType([10, 11], "float32", vdev)], + rx.TensorType([12, 10], "float32", vdev), ) # opaque derivation - fopaque0 = lambda: rx.FuncStructInfo.opaque_func() - fopaque1 = lambda: rx.FuncStructInfo.opaque_func(ret=prim0) + fopaque0 = lambda: rx.FuncType.opaque_func() + fopaque1 = lambda: rx.FuncType.opaque_func(ret=prim0) _check_derive(bb, fopaque0(), [obj0, prim0], obj0) _check_derive(bb, fopaque1(), [obj0, prim0], prim0) # recursive tuple derivation def func_tuple0(c): n, m = tirx.Var("n", "int64"), tirx.Var("m", "int64") - x0 = rx.TensorStructInfo([n, c], "float32") - x1 = rx.TensorStructInfo([n + c, m], "float32") - z = rx.TupleStructInfo([rx.TensorStructInfo([m, n], "float32")]) - return rx.FuncStructInfo([rx.TupleStructInfo([x0, x1])], z) + x0 = rx.TensorType([n, c], "float32") + x1 = rx.TensorType([n + c, m], "float32") + z = rx.TupleType([rx.TensorType([m, n], "float32")]) + return rx.FuncType([rx.TupleType([x0, x1])], z) _check_derive( bb, func_tuple0(2), [ - rx.TupleStructInfo( + rx.TupleType( [ - rx.TensorStructInfo([n, 2], "float32"), - rx.TensorStructInfo([n + 2, 10], "float32"), + rx.TensorType([n, 2], "float32"), + rx.TensorType([n + 2, 10], "float32"), ] ) ], - rx.TupleStructInfo([rx.TensorStructInfo([10, n], "float32")]), + rx.TupleType([rx.TensorType([10, n], "float32")]), ) def func_tuple1(c): n, m = tirx.Var("n", "int64"), tirx.Var("m", "int64") - x0 = rx.TensorStructInfo([n, m], "float32") - x1 = rx.TensorStructInfo([n + c, c], "float32") - z = rx.TupleStructInfo([rx.TensorStructInfo([m, n], "float32")]) - return rx.FuncStructInfo([rx.TupleStructInfo([x0, x1])], z) + x0 = rx.TensorType([n, m], "float32") + x1 = rx.TensorType([n + c, c], "float32") + z = rx.TupleType([rx.TensorType([m, n], "float32")]) + return rx.FuncType([rx.TupleType([x0, x1])], z) # Still OK, to pass erased tensor into n+2, n is captured by other argument. _check_derive( bb, func_tuple1(4), [ - rx.TupleStructInfo( + rx.TupleType( [ - rx.TensorStructInfo([n, 4], "float32"), - rx.TensorStructInfo(ndim=2, dtype="float32"), + rx.TensorType([n, 4], "float32"), + rx.TensorType(ndim=2, dtype="float32"), ] ) ], - rx.TupleStructInfo([rx.TensorStructInfo([4, n], "float32")]), + rx.TupleType([rx.TensorType([4, n], "float32")]), ) # tuple length mismatch is not causes an error @@ -488,62 +484,62 @@ def func_tuple1(c): _check_derive( bb, func_tuple0(4), - [rx.TupleStructInfo([rx.TensorStructInfo([n, 4], "float32")])], - rx.TupleStructInfo([rx.TensorStructInfo([10, n], "float32")]), + [rx.TupleType([rx.TensorType([n, 4], "float32")])], + rx.TupleType([rx.TensorType([10, n], "float32")]), ) # mixed shape types def func_shape_mixed(c): n, m = tirx.Var("n", "int64"), tirx.Var("m", "int64") - x0 = rx.ShapeStructInfo([n, m]) + x0 = rx.ShapeType([n, m]) f0 = func_tuple0(c) - z = rx.ShapeStructInfo([m + n, c]) - return rx.FuncStructInfo([x0, f0], z) + z = rx.ShapeType([m + n, c]) + return rx.FuncType([x0, f0], z) _check_derive( bb, func_shape_mixed(3), [ - rx.ShapeStructInfo([10, 20]), + rx.ShapeType([10, 20]), # have to specify purity because an impure function cannot be passed # where a pure one is expected - rx.FuncStructInfo.opaque_func(ret=rx.ShapeStructInfo(ndim=2), purity=True), + rx.FuncType.opaque_func(ret=rx.ShapeType(ndim=2), purity=True), ], - rx.ShapeStructInfo([30, 3]), + rx.ShapeType([30, 3]), ) def _check_lca(lhs, rhs, target): - tvm.ir.assert_structural_equal(rx.analysis.struct_info_lca(lhs, rhs), target) - tvm.ir.assert_structural_equal(rx.analysis.struct_info_lca(rhs, lhs), target) + tvm.ir.assert_structural_equal(rx.analysis.type_lca(lhs, rhs), target) + tvm.ir.assert_structural_equal(rx.analysis.type_lca(rhs, lhs), target) -def test_struct_info_lca(): +def test_type_lca(): n, m = tirx.Var("n", "int64"), tirx.Var("m", "int64") - obj0 = rx.ObjectStructInfo() - prim0 = rx.PrimStructInfo("int32") - prim1 = rx.PrimStructInfo("float32") + obj0 = rx.ObjectType() + prim0 = rx.PrimType("int32") + prim1 = rx.PrimType("float32") vdevice0 = ir.VDevice("llvm") vdevice1 = ir.VDevice("cuda", 0) - shape0 = rx.ShapeStructInfo(ndim=-1) - shape1 = rx.ShapeStructInfo(ndim=2) - shape2 = rx.ShapeStructInfo(ndim=3) - shape3 = rx.ShapeStructInfo([1, 2, 3]) - shape4 = rx.ShapeStructInfo([1, n, 3]) - - tensor0 = rx.TensorStructInfo(ndim=-1, dtype="int32") - tensor1 = rx.TensorStructInfo(ndim=-1, dtype="float32") - tensor2 = rx.TensorStructInfo(ndim=2, dtype="int32") - tensor3 = rx.TensorStructInfo(ndim=2, dtype="float32") - tensor4 = rx.TensorStructInfo([n, m], "int32") - tensor5 = rx.TensorStructInfo([n, m, 1], "int32") - tensor6 = rx.TensorStructInfo([n, m, 2], "int32") - tensor7 = rx.TensorStructInfo(ndim=2, dtype="float32", vdevice=vdevice0) - tensor8 = rx.TensorStructInfo(ndim=2, dtype="float32", vdevice=vdevice1) - tensor9 = rx.TensorStructInfo([n, m, 2], "int32", vdevice0) - tensor10 = rx.TensorStructInfo([n, m, 2], "int32", vdevice1) + shape0 = rx.ShapeType(ndim=-1) + shape1 = rx.ShapeType(ndim=2) + shape2 = rx.ShapeType(ndim=3) + shape3 = rx.ShapeType([1, 2, 3]) + shape4 = rx.ShapeType([1, n, 3]) + + tensor0 = rx.TensorType(ndim=-1, dtype="int32") + tensor1 = rx.TensorType(ndim=-1, dtype="float32") + tensor2 = rx.TensorType(ndim=2, dtype="int32") + tensor3 = rx.TensorType(ndim=2, dtype="float32") + tensor4 = rx.TensorType([n, m], "int32") + tensor5 = rx.TensorType([n, m, 1], "int32") + tensor6 = rx.TensorType([n, m, 2], "int32") + tensor7 = rx.TensorType(ndim=2, dtype="float32", vdevice=vdevice0) + tensor8 = rx.TensorType(ndim=2, dtype="float32", vdevice=vdevice1) + tensor9 = rx.TensorType([n, m, 2], "int32", vdevice0) + tensor10 = rx.TensorType([n, m, 2], "int32", vdevice1) # obj _check_lca(obj0, prim0, obj0) @@ -557,11 +553,11 @@ def test_struct_info_lca(): _check_lca(shape2, shape3, shape2) _check_lca(shape3, shape4, shape2) - _check_lca(shape4, rx.ShapeStructInfo([1, n, 3]), shape4) + _check_lca(shape4, rx.ShapeType([1, n, 3]), shape4) # tensor _check_lca(tensor0, prim0, obj0) - _check_lca(tensor0, tensor1, rx.TensorStructInfo(ndim=-1, dtype=None)) + _check_lca(tensor0, tensor1, rx.TensorType(ndim=-1, dtype=None)) _check_lca(tensor0, tensor2, tensor0) _check_lca(tensor0, tensor4, tensor0) _check_lca(tensor0, tensor4, tensor0) @@ -573,46 +569,44 @@ def test_struct_info_lca(): _check_lca(tensor6, tensor10, tensor6) _check_lca(tensor2, tensor4, tensor2) - _check_lca(tensor5, tensor6, rx.TensorStructInfo(ndim=3, dtype="int32")) - _check_lca(tensor4, tensor5, rx.TensorStructInfo(ndim=-1, dtype="int32")) - _check_lca(tensor4, rx.TensorStructInfo([n, m], dtype="int32"), tensor4) + _check_lca(tensor5, tensor6, rx.TensorType(ndim=3, dtype="int32")) + _check_lca(tensor4, tensor5, rx.TensorType(ndim=-1, dtype="int32")) + _check_lca(tensor4, rx.TensorType([n, m], dtype="int32"), tensor4) # tuple - t0 = rx.TupleStructInfo([obj0, tensor0]) - t1 = rx.TupleStructInfo([prim0, tensor4]) - t2 = rx.TupleStructInfo([obj0, tensor0, obj0]) - t3 = rx.TupleStructInfo([tensor0, obj0]) + t0 = rx.TupleType([obj0, tensor0]) + t1 = rx.TupleType([prim0, tensor4]) + t2 = rx.TupleType([obj0, tensor0, obj0]) + t3 = rx.TupleType([tensor0, obj0]) _check_lca(t0, t1, t0) _check_lca(t0, t2, obj0) - _check_lca(t0, t3, rx.TupleStructInfo([obj0, obj0])) + _check_lca(t0, t3, rx.TupleType([obj0, obj0])) - t5 = rx.TupleStructInfo([t0, t1]) - t6 = rx.TupleStructInfo([t1, t2]) + t5 = rx.TupleType([t0, t1]) + t6 = rx.TupleType([t1, t2]) - _check_lca(t5, t6, rx.TupleStructInfo([t0, obj0])) + _check_lca(t5, t6, rx.TupleType([t0, obj0])) - t7 = rx.TupleStructInfo([]) - _check_lca(t7, rx.TupleStructInfo([]), t7) + t7 = rx.TupleType([]) + _check_lca(t7, rx.TupleType([]), t7) def fn_info_shape(c): n, m = tirx.Var("n", "int64"), tirx.Var("m", "int64") - x = rx.TensorStructInfo([c, n, m], "float32") - y = rx.TensorStructInfo([c, n, 1], "float32") - z = rx.TensorStructInfo([c, n], "float32") - return rx.FuncStructInfo([x, y], z) + x = rx.TensorType([c, n, m], "float32") + y = rx.TensorType([c, n, 1], "float32") + z = rx.TensorType([c, n], "float32") + return rx.FuncType([x, y], z) def fn_info_erased(): - x = rx.TensorStructInfo(ndim=3, dtype="float32") - y = rx.TensorStructInfo(ndim=3, dtype="float32") - z = rx.TensorStructInfo(ndim=2, dtype="float32") - return rx.FuncStructInfo([x, y], z) - - fopaque0 = lambda: rx.FuncStructInfo.opaque_func() - fopaque1 = lambda: rx.FuncStructInfo.opaque_func(ret=prim0) - fopaque2 = lambda: rx.FuncStructInfo.opaque_func( - ret=rx.TensorStructInfo(ndim=2, dtype="float32") - ) + x = rx.TensorType(ndim=3, dtype="float32") + y = rx.TensorType(ndim=3, dtype="float32") + z = rx.TensorType(ndim=2, dtype="float32") + return rx.FuncType([x, y], z) + + fopaque0 = lambda: rx.FuncType.opaque_func() + fopaque1 = lambda: rx.FuncType.opaque_func(ret=prim0) + fopaque2 = lambda: rx.FuncType.opaque_func(ret=rx.TensorType(ndim=2, dtype="float32")) _check_lca(fn_info_shape(1), fn_info_shape(2), fn_info_erased()) _check_lca(fn_info_shape(2), fn_info_shape(2), fn_info_shape(2)) @@ -639,7 +633,7 @@ def _generate_prim_test_cases(): ] for dtype in dtypes: - # LCA of a PrimStructInfo with itself yields itself + # LCA of a PrimType with itself yields itself yield (R.Prim(dtype), R.Prim(dtype), R.Prim(dtype)) # The LCA of two values, each statically known to be the same @@ -695,20 +689,20 @@ def _generate_prim_test_cases(): @pytest.mark.parametrize("test_case", list(_generate_prim_test_cases())) -def test_prim_struct_info_lca(test_case): - def _normalize_sinfo(sinfo): - if isinstance(sinfo, tvm.relax.StructInfo): - return sinfo - elif isinstance(sinfo, tvm.script.parser.relax.entry.StructInfoProxy): - return sinfo.as_struct_info() - elif callable(sinfo): - return sinfo() +def test_prim_type_lca(test_case): + def _normalize_ty(ty): + if isinstance(ty, tvm.relax.Type): + return ty + elif isinstance(ty, tvm.script.parser.relax.entry.TypeProxy): + return ty.as_ty() + elif callable(ty): + return ty() else: - raise TypeError(f"Cannot normalize {type(sinfo)} to StructInfo") + raise TypeError(f"Cannot normalize {type(ty)} to Type") - lhs, rhs, expected = map(_normalize_sinfo, test_case) + lhs, rhs, expected = map(_normalize_ty, test_case) - lca = rx.analysis.struct_info_lca(lhs, rhs) + lca = rx.analysis.type_lca(lhs, rhs) assert tvm_ffi.structural_equal(lca, expected), ( f"Expected {lhs} and {rhs} to have LCA of {expected}, but instead found {lca}" ) @@ -716,14 +710,14 @@ def _normalize_sinfo(sinfo): def _generate_tir_var_test_cases(): n, m = tirx.Var("n", "int64"), tirx.Var("m", "int64") - shape0 = rx.ShapeStructInfo([1, n, 3]) - shape1 = rx.ShapeStructInfo([1, 2 * n, n, m]) - shape2 = rx.ShapeStructInfo([1, 2 * n, m]) - tensor0 = rx.TensorStructInfo([1, n, 3], "int32") - tensor1 = rx.TensorStructInfo([1, 2 * n, n, m], "int32") - tensor2 = rx.TensorStructInfo([1, 2 * n, m], "int32") - func = rx.FuncStructInfo( - [rx.TensorStructInfo([1, 2 * n, n, m], "int32")], rx.TensorStructInfo([1, n, 3], "int32") + shape0 = rx.ShapeType([1, n, 3]) + shape1 = rx.ShapeType([1, 2 * n, n, m]) + shape2 = rx.ShapeType([1, 2 * n, m]) + tensor0 = rx.TensorType([1, n, 3], "int32") + tensor1 = rx.TensorType([1, 2 * n, n, m], "int32") + tensor2 = rx.TensorType([1, 2 * n, m], "int32") + func = rx.FuncType( + [rx.TensorType([1, 2 * n, n, m], "int32")], rx.TensorType([1, n, 3], "int32") ) yield shape0, [n], [n] @@ -738,16 +732,14 @@ def _generate_tir_var_test_cases(): tir_var_test_case = tvm.testing.parameter(*_generate_tir_var_test_cases()) -def test_tir_vars_in_struct_info(tir_var_test_case): - sinfo, _vars_definable, vars_used = tir_var_test_case - tvm.ir.assert_structural_equal(rx.analysis.tir_vars_in_struct_info(sinfo), vars_used) +def test_tir_vars_in_type(tir_var_test_case): + ty, _vars_definable, vars_used = tir_var_test_case + tvm.ir.assert_structural_equal(rx.analysis.tir_vars_in_type(ty), vars_used) -def test_definable_tir_vars_in_struct_info(tir_var_test_case): - sinfo, vars_definable, _vars_used = tir_var_test_case - tvm.ir.assert_structural_equal( - rx.analysis.definable_tir_vars_in_struct_info(sinfo), vars_definable - ) +def test_definable_tir_vars_in_type(tir_var_test_case): + ty, vars_definable, _vars_used = tir_var_test_case + tvm.ir.assert_structural_equal(rx.analysis.definable_tir_vars_in_type(ty), vars_definable) def test_collect_symbolic_var_from_tensor_shape(): @@ -759,10 +751,10 @@ def test_collect_symbolic_var_from_tensor_shape(): tirx.Var("p", "int64"), ) bb = rx.BlockBuilder() - x = rx.Var("x", rx.TensorStructInfo([m, m + n], "float32")) + x = rx.Var("x", rx.TensorType([m, m + n], "float32")) with bb.function("main", [x]): - v0 = bb.match_cast(x, rx.TensorStructInfo([m, k], "float32")) - v1 = bb.emit(rx.call_dps_packed("test", x, rx.TensorStructInfo([p, q], "float32"))) + v0 = bb.match_cast(x, rx.TensorType([m, k], "float32")) + v1 = bb.emit(rx.call_dps_packed("test", x, rx.TensorType([p, q], "float32"))) bb.emit_func_output(rx.const(1)) func = bb.get()["main"] @@ -781,16 +773,16 @@ def test_collect_symbolic_var_from_non_tensor_params(param_type, param_order): tir_m = tirx.Var("m", "int64") bb = rx.BlockBuilder() - arg = rx.Var("arg", rx.TensorStructInfo([tir_n * tir_m])) + arg = rx.Var("arg", rx.TensorType([tir_n * tir_m])) if param_type == "shape_expr": extra_params = [ - rx.Var("shape_expr", rx.ShapeStructInfo([tir_n, tir_m])), + rx.Var("shape_expr", rx.ShapeType([tir_n, tir_m])), ] elif param_type == "prim_value": extra_params = [ - rx.Var("n", rx.PrimStructInfo(value=tir_n)), - rx.Var("m", rx.PrimStructInfo(value=tir_m)), + rx.Var("n", rx.PrimType(value=tir_n)), + rx.Var("m", rx.PrimType(value=tir_m)), ] else: raise ValueError(f"Unknown param_type: {param_type}") @@ -823,34 +815,34 @@ def func( ): return R.tuple() - M, N = list(func.params[2].struct_info.values) + M, N = list(func.params[2].ty.values) # Expressions are de-duplicated, in order of their first appearance tvm.ir.assert_structural_equal( - rx.analysis.collect_non_negative_expressions(func.struct_info), + rx.analysis.collect_non_negative_expressions(func.ty), [M, N - 2, N, M + 2], ) # Tensor shapes can imply that their shapes are non-negative tvm.ir.assert_structural_equal( - rx.analysis.collect_non_negative_expressions(func.params[0].struct_info), + rx.analysis.collect_non_negative_expressions(func.params[0].ty), [M, N - 2], ) tvm.ir.assert_structural_equal( - rx.analysis.collect_non_negative_expressions(func.params[1].struct_info), + rx.analysis.collect_non_negative_expressions(func.params[1].ty), [N, M + 2], ) # ShapeExpr values can imply that their contents are non-negative tvm.ir.assert_structural_equal( - rx.analysis.collect_non_negative_expressions(func.params[2].struct_info), + rx.analysis.collect_non_negative_expressions(func.params[2].ty), [M, N], ) # PrimValue instances may contain negative values, and do not # imply that their contents are non-negative. tvm.ir.assert_structural_equal( - rx.analysis.collect_non_negative_expressions(func.params[3].struct_info), + rx.analysis.collect_non_negative_expressions(func.params[3].ty), [], ) diff --git a/tests/python/relax/test_analysis_well_formed.py b/tests/python/relax/test_analysis_well_formed.py index 06af93869cdb..a123dfe75e29 100644 --- a/tests/python/relax/test_analysis_well_formed.py +++ b/tests/python/relax/test_analysis_well_formed.py @@ -50,7 +50,7 @@ def test_var(): blocks = [rx.BindingBlock(bindings)] func = build_function(blocks) mod = tvm.IRModule({rx.GlobalVar("foo"): func}) - assert not rx.analysis.check_well_formed(mod, check_struct_info=False) + assert not rx.analysis.check_well_formed(mod, check_ty=False) # Error: Var gv0 is defined more than once gv0 = rx.Var("gv0", R.Tensor([m, n], "float32")) @@ -60,7 +60,7 @@ def test_var(): blocks = [rx.BindingBlock(bindings)] func = build_function(blocks) mod = tvm.IRModule({rx.GlobalVar("foo"): func}) - assert not rx.analysis.check_well_formed(mod, check_struct_info=False) + assert not rx.analysis.check_well_formed(mod, check_ty=False) def test_dataflow_var(): @@ -72,7 +72,7 @@ def test_dataflow_var(): blocks = [rx.DataflowBlock(bindings)] func = build_function(blocks) mod = tvm.IRModule({rx.GlobalVar("foo"): func}) - assert not rx.analysis.check_well_formed(mod, check_struct_info=False) + assert not rx.analysis.check_well_formed(mod, check_ty=False) # Error: DataflowVar gv0 is defined more than once lv0 = rx.DataflowVar("lv0", R.Tensor([m, n], "float32")) @@ -82,7 +82,7 @@ def test_dataflow_var(): blocks = [rx.DataflowBlock(bindings)] func = build_function(blocks) mod = tvm.IRModule({rx.GlobalVar("foo"): func}) - assert not rx.analysis.check_well_formed(mod, check_struct_info=False) + assert not rx.analysis.check_well_formed(mod, check_ty=False) # Error: DataflowVar lv0 is defined outside DataflowBlock lv0 = rx.DataflowVar("lv0", R.Tensor([m, n], "float32")) @@ -91,7 +91,7 @@ def test_dataflow_var(): blocks = [rx.BindingBlock(bindings)] func = build_function(blocks) mod = tvm.IRModule({rx.GlobalVar("foo"): func}) - assert not rx.analysis.check_well_formed(mod, check_struct_info=False) + assert not rx.analysis.check_well_formed(mod, check_ty=False) # Error: DataflowVar lv0 is used outside DataflowBlock lv0 = rx.DataflowVar("lv0", R.Tensor([m, n], "float32")) @@ -101,7 +101,7 @@ def test_dataflow_var(): blocks = [rx.BindingBlock(bindings)] func = build_function(blocks) mod = tvm.IRModule({rx.GlobalVar("foo"): func}) - assert not rx.analysis.check_well_formed(mod, check_struct_info=False) + assert not rx.analysis.check_well_formed(mod, check_ty=False) def test_param_var(): @@ -116,7 +116,7 @@ def test_param_var(): gv0 = bb.emit(rx.op.add(v2, v1)) bb.emit_func_output(gv0) mod = bb.get() - assert not rx.analysis.check_well_formed(mod, check_struct_info=False) + assert not rx.analysis.check_well_formed(mod, check_ty=False) def test_global_var(): @@ -131,7 +131,7 @@ def test_global_var(): blocks = [rx.BindingBlock(bindings)] func = build_function(blocks) mod = tvm.IRModule({rx.GlobalVar("foo"): func}) - assert not rx.analysis.check_well_formed(mod, check_struct_info=False) + assert not rx.analysis.check_well_formed(mod, check_ty=False) def test_symbolic_var(): @@ -143,7 +143,7 @@ def test_symbolic_var(): blocks = [rx.BindingBlock(bindings)] func = build_function(blocks) mod = tvm.IRModule({rx.GlobalVar("foo"): func}) - assert not rx.analysis.check_well_formed(mod, check_struct_info=False) + assert not rx.analysis.check_well_formed(mod, check_ty=False) def test_symbolic_var_across_functions(): @@ -157,13 +157,11 @@ def test_symbolic_var_across_functions(): with bb.function("func2", [v1]): bb.emit_func_output(v1) mod = bb.get() - assert not rx.analysis.check_well_formed(mod, check_struct_info=False) + assert not rx.analysis.check_well_formed(mod, check_ty=False) def test_symbolic_var_invalid_type(): - with pytest.raises( - RuntimeError, match="the value in ShapeStructInfo can only have dtype of int64" - ): + with pytest.raises(RuntimeError, match="the value in ShapeType can only have dtype of int64"): dim = tirx.Var("dim", "float32") y = rx.Var("y", R.Tensor([dim], "float32")) gv0 = rx.Var("gv0", R.Tensor([dim], "float32")) @@ -172,7 +170,7 @@ def test_symbolic_var_invalid_type(): blocks = [rx.BindingBlock(bindings)] func = build_function(blocks, [y]) mod = tvm.IRModule({rx.GlobalVar("foo"): func}) - assert not rx.analysis.check_well_formed(mod, check_struct_info=False) + assert not rx.analysis.check_well_formed(mod, check_ty=False) def test_seq_expr(): @@ -189,22 +187,22 @@ def test_seq_expr(): blocks = [rx.BindingBlock(bindings)] func = build_function(blocks) mod = tvm.IRModule({rx.GlobalVar("foo"): func}) - assert not rx.analysis.check_well_formed(mod, check_struct_info=False) + assert not rx.analysis.check_well_formed(mod, check_ty=False) def test_recursive(): - scalar_struct_info = rx.TensorStructInfo(shape=[], dtype="int32") - gv0 = rx.Var("gv0", scalar_struct_info) - f = rx.Var("f", rx.FuncStructInfo([scalar_struct_info], scalar_struct_info)) - ipt = rx.Var("ipt", scalar_struct_info) - x0 = rx.Var("x0", scalar_struct_info) - x1 = rx.Var("x1", scalar_struct_info) - x2 = rx.Var("x2", scalar_struct_info) - y = rx.Var("y", scalar_struct_info) + scalar_ty = rx.TensorType(shape=[], dtype="int32") + gv0 = rx.Var("gv0", scalar_ty) + f = rx.Var("f", rx.FuncType([scalar_ty], scalar_ty)) + ipt = rx.Var("ipt", scalar_ty) + x0 = rx.Var("x0", scalar_ty) + x1 = rx.Var("x1", scalar_ty) + x2 = rx.Var("x2", scalar_ty) + y = rx.Var("y", scalar_ty) inner_block = rx.BindingBlock( [rx.VarBinding(x0, rx.const(2, "int32")), rx.VarBinding(y, rx.Call(f, [x0]))] ) - inner_func = rx.Function([ipt], rx.SeqExpr([inner_block], y), scalar_struct_info) + inner_func = rx.Function([ipt], rx.SeqExpr([inner_block], y), scalar_ty) outer_block = rx.BindingBlock( [ rx.VarBinding(f, inner_func), @@ -213,7 +211,7 @@ def test_recursive(): rx.VarBinding(gv0, x2), ] ) - func = rx.Function([], rx.SeqExpr([outer_block], gv0), scalar_struct_info) + func = rx.Function([], rx.SeqExpr([outer_block], gv0), scalar_ty) mod = tvm.IRModule.from_expr(func) normalized = rx.transform.Normalize()(mod) rx.analysis.well_formed(normalized) @@ -248,7 +246,7 @@ def test_if(): blocks = [rx.BindingBlock(bindings)] func = build_function(blocks) mod = tvm.IRModule({rx.GlobalVar("foo"): func}) - assert not rx.analysis.check_well_formed(mod, check_struct_info=True) + assert not rx.analysis.check_well_formed(mod, check_ty=True) def test_if_non_seq_body(): @@ -266,7 +264,7 @@ def test_if_non_seq_body(): ] func = build_function(blocks) mod = tvm.IRModule.from_expr(func) - assert not rx.analysis.check_well_formed(mod, check_struct_info=False) + assert not rx.analysis.check_well_formed(mod, check_ty=False) # on the other hand, if they're wrapped in a seq node, it's fine seq = rx.SeqExpr([], x) @@ -283,9 +281,9 @@ def test_if_non_seq_body(): ] new_func = build_function(new_blocks) new_mod = tvm.IRModule.from_expr(new_func) - # apply normalization to fill in struct_info_ + # apply normalization to fill in ty normalized = rx.transform.Normalize()(new_mod) - rx.analysis.well_formed(normalized, check_struct_info=True) + rx.analysis.well_formed(normalized, check_ty=True) def test_if_complex_condition(): @@ -305,7 +303,7 @@ def test_if_complex_condition(): ] func = build_function(blocks) mod = tvm.IRModule.from_expr(func) - assert not rx.analysis.check_well_formed(mod, check_struct_info=False) + assert not rx.analysis.check_well_formed(mod, check_ty=False) cond_var = rx.Var("q", R.Tensor([], "bool")) new_if = rx.If(cond_var, rx.SeqExpr([], x), rx.SeqExpr([], x)) @@ -322,30 +320,28 @@ def test_if_complex_condition(): ] func = build_function(blocks) mod = tvm.IRModule.from_expr(func) - # apply normalization to fill in struct_info_ + # apply normalization to fill in ty normalized = rx.transform.Normalize()(mod) - rx.analysis.well_formed(normalized, check_struct_info=True) + rx.analysis.well_formed(normalized, check_ty=True) def test_tuple_get_item_nested(): # Error: The tuple value in tuple get item must be a leaf expression - nested_tup = rx.Var( - "t", rx.TupleStructInfo([rx.TupleStructInfo([rx.TensorStructInfo([], "int32")])]) - ) + nested_tup = rx.Var("t", rx.TupleType([rx.TupleType([rx.TensorType([], "int32")])])) double_idx = rx.TupleGetItem(rx.TupleGetItem(nested_tup, 0), 0) ret_var = rx.Var("r", R.Tensor([], "int32")) f = rx.Function( [nested_tup], rx.SeqExpr([rx.BindingBlock([rx.VarBinding(ret_var, double_idx)])], ret_var), - ret_struct_info=R.Tensor(ndim=0, dtype="int32"), + ret_ty=R.Tensor(ndim=0, dtype="int32"), ) f = f.with_attr("global_symbol", "f") mod = tvm.IRModule.from_expr(f) - assert not rx.analysis.check_well_formed(mod, check_struct_info=False) + assert not rx.analysis.check_well_formed(mod, check_ty=False) # okay with an intermediate binding first_idx = rx.TupleGetItem(nested_tup, 0) - idx_var = rx.Var("v", rx.TupleStructInfo([rx.TensorStructInfo([], "int32")])) + idx_var = rx.Var("v", rx.TupleType([rx.TensorType([], "int32")])) second_idx = rx.TupleGetItem(idx_var, 0) new_f = rx.Function( [nested_tup], @@ -357,13 +353,13 @@ def test_tuple_get_item_nested(): ], ret_var, ), - ret_struct_info=R.Tensor(ndim=0, dtype="int32"), + ret_ty=R.Tensor(ndim=0, dtype="int32"), ) new_f = new_f.with_attr("global_symbol", "new_f") mod = tvm.IRModule.from_expr(new_f) # normalize in order to fill in checked type normalized = rx.transform.Normalize()(mod) - rx.analysis.well_formed(normalized, check_struct_info=True) + rx.analysis.well_formed(normalized, check_ty=True) def test_complex_seq_body(): @@ -376,7 +372,7 @@ def test_complex_seq_body(): R.Tensor(ndim=0, dtype="int32"), ).with_attr("global_symbol", "foo") mod = tvm.IRModule.from_expr(func) - assert not rx.analysis.check_well_formed(mod, check_struct_info=False) + assert not rx.analysis.check_well_formed(mod, check_ty=False) # but if the result is bound, then it's okay z = rx.Var("z", R.Tensor([], "int32")) @@ -400,7 +396,7 @@ def test_complex_seq_body(): new_mod = tvm.IRModule.from_expr(new_func) # normalize in order to fill in checked type normalized = rx.transform.Normalize()(new_mod) - rx.analysis.well_formed(normalized, check_struct_info=True) + rx.analysis.well_formed(normalized, check_ty=True) def test_inline_prim_func(): @@ -436,7 +432,7 @@ def test_inline_prim_func(): R.Tensor(ndim=0, dtype="int32"), ).with_attr("global_symbol", "foo") new_mod = tvm.IRModule.from_expr(new_func) - assert not rx.analysis.check_well_formed(new_mod, check_struct_info=False) + assert not rx.analysis.check_well_formed(new_mod, check_ty=False) def test_ANF(): @@ -447,7 +443,7 @@ def test_ANF(): blocks = [rx.BindingBlock(bindings)] func = build_function(blocks) mod = tvm.IRModule({rx.GlobalVar("foo"): func}) - assert not rx.analysis.check_well_formed(mod, check_struct_info=False) + assert not rx.analysis.check_well_formed(mod, check_ty=False) # Error: Call Node in Tuple gv0 = rx.Var("gv0", R.Tensor([m, n], "float32")) @@ -455,7 +451,7 @@ def test_ANF(): blocks = [rx.BindingBlock(bindings)] func = build_function(blocks) mod = tvm.IRModule({rx.GlobalVar("foo"): func}) - assert not rx.analysis.check_well_formed(mod, check_struct_info=False) + assert not rx.analysis.check_well_formed(mod, check_ty=False) def test_global_var_vs_gsymbol(): @@ -469,19 +465,19 @@ def test_global_var_vs_gsymbol(): R.Tensor(ndim=2, dtype="float32"), ).with_attr("global_symbol", "main1") mod = tvm.IRModule({rx.GlobalVar("main"): func}) - assert not rx.analysis.check_well_formed(mod, check_struct_info=False) + assert not rx.analysis.check_well_formed(mod, check_ty=False) def test_nested_dataflow(): - scalar_struct_info = rx.TensorStructInfo(shape=[], dtype="int32") - gv0 = rx.Var("gv0", scalar_struct_info) - f = rx.DataflowVar("f", rx.FuncStructInfo([], scalar_struct_info)) - x0 = rx.DataflowVar("x0", scalar_struct_info) - x1 = rx.DataflowVar("x1", scalar_struct_info) - x2 = rx.DataflowVar("x2", scalar_struct_info) - y = rx.Var("y", scalar_struct_info) + scalar_ty = rx.TensorType(shape=[], dtype="int32") + gv0 = rx.Var("gv0", scalar_ty) + f = rx.DataflowVar("f", rx.FuncType([], scalar_ty)) + x0 = rx.DataflowVar("x0", scalar_ty) + x1 = rx.DataflowVar("x1", scalar_ty) + x2 = rx.DataflowVar("x2", scalar_ty) + y = rx.Var("y", scalar_ty) inner_block = rx.DataflowBlock([rx.VarBinding(x0, rx.const(2, "int32")), rx.VarBinding(y, x0)]) - inner_func = rx.Function([], rx.SeqExpr([inner_block], y), scalar_struct_info) + inner_func = rx.Function([], rx.SeqExpr([inner_block], y), scalar_ty) outer_block = rx.DataflowBlock( [ rx.VarBinding(x1, rx.const(1, "int32")), @@ -490,45 +486,45 @@ def test_nested_dataflow(): rx.VarBinding(gv0, x2), ] ) - func = rx.Function([], rx.SeqExpr([outer_block], gv0), scalar_struct_info) + func = rx.Function([], rx.SeqExpr([outer_block], gv0), scalar_ty) mod = tvm.IRModule.from_expr(func) normalized = rx.transform.Normalize()(mod) rx.analysis.well_formed(normalized) -def test_sinfo_args_tir_var_used_before_define_call_packed(): +def test_ty_args_tir_var_used_before_define_call_packed(): # Error: Symbolic Var m1, n1 are not defined m1 = tirx.Var("m1", "int64") n1 = tirx.Var("n1", "int64") - call = R.call_packed("my_func", x, sinfo_args=R.Tensor((m1, n1), "float32")) + call = R.call_packed("my_func", x, ty_args=R.Tensor((m1, n1), "float32")) func = build_function([rx.BindingBlock([rx.VarBinding(rx.Var("gv"), call)])]) mod = rx.transform.Normalize()(tvm.IRModule.from_expr(func)) - assert not rx.analysis.check_well_formed(mod, check_struct_info=False) + assert not rx.analysis.check_well_formed(mod, check_ty=False) -def test_sinfo_args_tir_var_used_before_define_call_tir(): +def test_ty_args_tir_var_used_before_define_call_tir(): # Error: Symbolic Var m1, n1 are not defined m1 = tirx.Var("m1", "int64") n1 = tirx.Var("n1", "int64") - call = R.call_dps_packed("my_func", x, out_sinfo=R.Tensor((m1, n1), "float32")) + call = R.call_dps_packed("my_func", x, out_ty=R.Tensor((m1, n1), "float32")) func = build_function([rx.BindingBlock([rx.VarBinding(rx.Var("gv"), call)])]) mod = rx.transform.Normalize()(tvm.IRModule.from_expr(func)) - assert not rx.analysis.check_well_formed(mod, check_struct_info=False) + assert not rx.analysis.check_well_formed(mod, check_ty=False) -def test_sinfo_erase_to_well_formed(): - # Error: The return sinfo contains undefined symbolic vars +def test_ty_erase_to_well_formed(): + # Error: The return ty contains undefined symbolic vars """ @R.function def foo(x: R.Tensor(("m", "n"), dtype="float32")) -> R.Tensor(("m1", "n1"), dtype="float32"): m = T.int64() n = T.int64() - gv = R.call_dps_packed("my_func", (x,), out_sinfo=R.Tensor((m, n), dtype="float32")) + gv = R.call_dps_packed("my_func", (x,), out_ty=R.Tensor((m, n), dtype="float32")) return gv """ m1 = tirx.Var("m1", "int64") n1 = tirx.Var("n1", "int64") - call = R.call_dps_packed("my_func", x, out_sinfo=R.Tensor((m, n), "float32")) + call = R.call_dps_packed("my_func", x, out_ty=R.Tensor((m, n), "float32")) blocks = [rx.BindingBlock([rx.VarBinding(rx.Var("gv"), call)])] seq_expr = rx.SeqExpr(blocks, blocks[-1].bindings[-1].var) func = rx.Function([x], seq_expr, R.Tensor((m1, n1), "float32")).with_attr( @@ -538,7 +534,7 @@ def foo(x: R.Tensor(("m", "n"), dtype="float32")) -> R.Tensor(("m1", "n1"), dtyp assert not rx.analysis.check_well_formed(mod) -def test_func_sinfo_well_formed(): +def test_func_ty_well_formed(): @R.function def foo(): @R.function @@ -553,8 +549,8 @@ def local(x: R.Tensor(["m", "n"], "float32")): def test_conditional_in_dataflow_block(): # error: not allowed to have a conditional inside a dataflow block - x = rx.Var("x", rx.TensorStructInfo([], dtype="int32")) - y = rx.Var("y", rx.TensorStructInfo([], dtype="int32")) + x = rx.Var("x", rx.TensorType([], dtype="int32")) + y = rx.Var("y", rx.TensorType([], dtype="int32")) block = rx.DataflowBlock([rx.VarBinding(y, rx.If(rx.const(True, dtype="bool"), x, x))]) func = rx.Function([x], rx.SeqExpr([block], y), R.Tensor((), dtype="int32")).with_attr( "global_symbol", "foo" @@ -713,7 +709,7 @@ def test_call_tir_with_matching_arguments(): class Module: @R.function def main(A: R.Tensor([16], "float16")): - B = R.call_tir(Module.add_one, A, out_sinfo=R.Tensor([16], "float16")) + B = R.call_tir(Module.add_one, A, out_ty=R.Tensor([16], "float16")) return B @T.prim_func(s_tir=True) @@ -738,7 +734,7 @@ def test_call_tir_input_ndim(): class Module: @R.function def main(A: R.Tensor([4, 4], "float16")): - B = R.call_tir(Module.add_one, A, out_sinfo=R.Tensor([16], "float16")) + B = R.call_tir(Module.add_one, A, out_ty=R.Tensor([16], "float16")) return B @T.prim_func(s_tir=True) @@ -762,7 +758,7 @@ def test_call_tir_output_ndim(): class Module: @R.function def main(A: R.Tensor([16], "float16")): - B = R.call_tir(Module.add_one, A, out_sinfo=R.Tensor([4, 4], "float16")) + B = R.call_tir(Module.add_one, A, out_ty=R.Tensor([4, 4], "float16")) return B @T.prim_func(s_tir=True) @@ -787,7 +783,7 @@ def test_call_tir_input_shape(): class Module: @R.function def main(A: R.Tensor([32], "float16")): - B = R.call_tir(Module.add_one, A, out_sinfo=R.Tensor([16], "float16")) + B = R.call_tir(Module.add_one, A, out_ty=R.Tensor([16], "float16")) return B @T.prim_func(s_tir=True) @@ -811,7 +807,7 @@ def test_call_tir_output_shape(): class Module: @R.function def main(A: R.Tensor([16], "float16")): - B = R.call_tir(Module.add_one, A, out_sinfo=R.Tensor([32], "float16")) + B = R.call_tir(Module.add_one, A, out_ty=R.Tensor([32], "float16")) return B @T.prim_func(s_tir=True) @@ -837,7 +833,7 @@ def test_call_tir_input_dtype(): class Module: @R.function def main(A: R.Tensor([16], "float32")): - B = R.call_tir(Module.add_one, A, out_sinfo=R.Tensor([16], "float16")) + B = R.call_tir(Module.add_one, A, out_ty=R.Tensor([16], "float16")) return B @T.prim_func(s_tir=True) @@ -863,7 +859,7 @@ def test_call_tir_output_dtype(): class Module: @R.function def main(A: R.Tensor([16], "float16")): - B = R.call_tir(Module.add_one, A, out_sinfo=R.Tensor([16], "float32")) + B = R.call_tir(Module.add_one, A, out_ty=R.Tensor([16], "float32")) return B @T.prim_func(s_tir=True) @@ -881,7 +877,7 @@ def test_call_tir_with_correct_dynamic_output_shape(): Here, the input arguments to the `reshape` function are not sufficient to infer the shape of the outputs. This is legal, - since the output shape is determined by the `out_sinfo` parameter. + since the output shape is determined by the `out_ty` parameter. Inability to verify the output shape does not mean that the output shape is invalid. @@ -892,7 +888,7 @@ def test_call_tir_with_correct_dynamic_output_shape(): class Module: @R.function def main(A: R.Tensor([16], "float16")): - B = R.call_tir(Module.reshape, A, out_sinfo=R.Tensor([2, 8], "float16")) + B = R.call_tir(Module.reshape, A, out_ty=R.Tensor([2, 8], "float16")) return B @T.prim_func(s_tir=True) @@ -925,7 +921,7 @@ def test_call_tir_with_incorrect_dynamic_output_shape(): class Module: @R.function def main(A: R.Tensor([16], "float16")): - B = R.call_tir(Module.reshape, A, out_sinfo=R.Tensor([16, 16], "float16")) + B = R.call_tir(Module.reshape, A, out_ty=R.Tensor([16, 16], "float16")) return B @T.prim_func(s_tir=True) @@ -960,7 +956,7 @@ def test_call_tir_incorrect_dimensionality_of_output_shape(): class Module: @R.function def main(A: R.Tensor([16], "float16")): - B = R.call_tir(Module.reshape, A, out_sinfo=R.Tensor([2, 4, 2], "float16")) + B = R.call_tir(Module.reshape, A, out_ty=R.Tensor([2, 4, 2], "float16")) return B @T.prim_func(s_tir=True) @@ -983,11 +979,11 @@ def test_call_tir_output_shape_with_mixed_static_and_dynamic(): Here, the input arguments to the `reshape` function are not sufficient to infer the shape of the outputs. This is legal, - since the output shape is taken from the `out_sinfo` parameter. + since the output shape is taken from the `out_ty` parameter. Identifying this failure mode is not yet supported in the current implementation. This is because the output is inferred as - `R.Tensor(ndim=3, dtype="float16")`, and the explicit `out_sinfo` + `R.Tensor(ndim=3, dtype="float16")`, and the explicit `out_ty` is a 3-d tensor. The mismatch in the first dimension is not yet counted, because the entire tensor shape is removed by `EraseToWellDefined`. @@ -998,7 +994,7 @@ def test_call_tir_output_shape_with_mixed_static_and_dynamic(): class Module: @R.function def main(A: R.Tensor([256], "float16")): - B = R.call_tir(Module.reshape, A, out_sinfo=R.Tensor([8, 16, 2], "float16")) + B = R.call_tir(Module.reshape, A, out_ty=R.Tensor([8, 16, 2], "float16")) return B @T.prim_func(s_tir=True) @@ -1022,7 +1018,7 @@ def test_call_tir_with_correct_inferred_dynamic_output_shape(): TIR buffer. Even though it is dynamic, the input shapes are sufficient to infer that `M==8` and `N==4`. As a result, the output shape of `[M*N]` can be inferred to be `[32]`, and the - shape specified in `out_sinfo` can be validated. + shape specified in `out_ty` can be validated. """ @@ -1030,7 +1026,7 @@ def test_call_tir_with_correct_inferred_dynamic_output_shape(): class Module: @R.function def main(A: R.Tensor([8, 4], "float16")): - B = R.call_tir(Module.flatten, A, out_sinfo=R.Tensor([32], "float16")) + B = R.call_tir(Module.flatten, A, out_ty=R.Tensor([32], "float16")) return B @T.prim_func(s_tir=True) @@ -1055,7 +1051,7 @@ def test_call_tir_with_incorrect_inferred_dynamic_output_shape(): TIR buffer. Even though it is dynamic, the input shapes are sufficient to infer that `M==8` and `N==4`. As a result, the output shape of `[M*N]` can be inferred to be `[32]`, and the - shape specified in `out_sinfo` can be validated. + shape specified in `out_ty` can be validated. This unit test is identical to the above test `test_call_tir_with_correct_inferred_dynamic_output_shape`, except @@ -1068,7 +1064,7 @@ def test_call_tir_with_incorrect_inferred_dynamic_output_shape(): class Module: @R.function def main(A: R.Tensor([8, 4], "float16")): - B = R.call_tir(Module.flatten, A, out_sinfo=R.Tensor([64], "float16")) + B = R.call_tir(Module.flatten, A, out_ty=R.Tensor([64], "float16")) return B @T.prim_func(s_tir=True) @@ -1090,7 +1086,7 @@ def test_call_tir_with_dtensor_arguments(): """R.call_tir and R.dist.call_tir share the same operation Both `R.call_tir` and `R.dist.call_tir` produce the same - "relax.call_tir" operation, differing only in the StructInfo of + "relax.call_tir" operation, differing only in the Type of their arguments. Normalization of "relax.call_tir" must handle `R.DTensor` arguments. @@ -1106,7 +1102,7 @@ class Module: @R.function def main(A: R.dist.DTensor([8, 4], "float16", "mesh[0]", "S[0]")): B = R.dist.call_tir( - Module.flatten, A, out_sinfo=R.dist.DTensor([64], "float16", "mesh[0]", "S[0]") + Module.flatten, A, out_ty=R.dist.DTensor([64], "float16", "mesh[0]", "S[0]") ) return B @@ -1136,7 +1132,7 @@ def main(A: R.Tensor([16], "float16")): Module.add_one, A, inplace_indices=[0], - out_sinfo=R.Tensor([16], "float16"), + out_ty=R.Tensor([16], "float16"), ) return B @@ -1161,7 +1157,7 @@ def main(A: R.Tensor([16], "float16")): Module.add_one, A, inplace_indices=[0], - out_sinfo=R.Tensor([32], "float16"), + out_ty=R.Tensor([32], "float16"), ) return B @@ -1186,7 +1182,7 @@ def main(A: R.Tensor([16], "float16"), B: R.Tensor([32], "float16")): Module.add_one, (A, B), inplace_indices=[-1, 1], - out_sinfo=[ + out_ty=[ R.Tensor([16], "float16"), R.Tensor([32], "float16"), ], @@ -1212,11 +1208,11 @@ def add_one( rx.analysis.well_formed(Module) -def test_var_binding_must_have_compatible_struct_info(): +def test_var_binding_must_have_compatible_ty(): """Variables must accurately describe their contents - To be well-formed, the inferred struct info must not conflict with - the StructInfo annotations. + To be well-formed, the inferred type must not conflict with + the Type annotations. """ @@ -1237,17 +1233,17 @@ def test_var_binding_must_have_compatible_struct_info(): var = tvm.relax.Var("B", R.Tensor(shape=[128, 32], dtype="int32")) binding = tvm.relax.VarBinding(var, param) body = tvm.relax.SeqExpr([tvm.relax.BindingBlock([binding])], var) - tvm.relax.expr._update_struct_info(body, var.struct_info) + tvm.relax.expr._update_type(body, var.ty) main = tvm.relax.Function([param], body) assert not rx.analysis.check_well_formed(main) -def test_var_binding_may_have_less_constrained_struct_info(): - """StructInfo of variable may be less specific than expression +def test_var_binding_may_have_less_constrained_ty(): + """Type of variable may be less specific than expression - The StructInfo annotation of a variable is not required to be an - exact match to the expression's StructInfo, and may provide less + The Type annotation of a variable is not required to be an + exact match to the expression's Type, and may provide less specific information than the inference would provide. """ @@ -1261,17 +1257,17 @@ def main( B: R.Object = R.add(A, A) return B - assert isinstance( - Module["main"].body.blocks[0].bindings[0].var.struct_info, tvm.relax.ObjectStructInfo - ), "Validity of this test requires a variable with R.Object struct info" + assert isinstance(Module["main"].body.blocks[0].bindings[0].var.ty, tvm.relax.ObjectType), ( + "Validity of this test requires a variable with R.Object type" + ) rx.analysis.well_formed(Module) -def test_var_binding_with_incomplete_struct_info_must_be_consistent(): - """StructInfo of variable must be accurate +def test_var_binding_with_incomplete_ty_must_be_consistent(): + """Type of variable must be accurate - Even though StructInfo annotation may be less specific, the + Even though Type annotation may be less specific, the information that they do contain must be correct. """ @@ -1293,16 +1289,16 @@ def test_var_binding_with_incomplete_struct_info_must_be_consistent(): var = tvm.relax.Var("B", R.Tensor(ndim=3, dtype="int32")) binding = tvm.relax.VarBinding(var, param) body = tvm.relax.SeqExpr([tvm.relax.BindingBlock([binding])], var) - tvm.relax.expr._update_struct_info(body, var.struct_info) + tvm.relax.expr._update_type(body, var.ty) main = tvm.relax.Function([param], body) assert not rx.analysis.check_well_formed(main) -def test_incomplete_struct_info_must_be_consistent(): - """StructInfo annotations must be accurate +def test_incomplete_ty_must_be_consistent(): + """Type annotations must be accurate - Even though StructInfo annotation may be less specific, the + Even though Type annotation may be less specific, the information that they do contain must be correct. """ @@ -1320,11 +1316,11 @@ def main( assert not rx.analysis.check_well_formed(Module) -def test_struct_info_annotations_must_be_correct(): - """StructInfo annotations must be correct +def test_ty_annotations_must_be_correct(): + """Type annotations must be correct - To be well-formed, the inferred struct info must not conflict with - the StructInfo annotations. + To be well-formed, the inferred type must not conflict with + the Type annotations. """ @@ -1341,11 +1337,11 @@ def main( assert not rx.analysis.check_well_formed(Module) -def test_struct_info_may_be_incomplete(): - """StructInfo annotations may be less specific +def test_ty_may_be_incomplete(): + """Type annotations may be less specific - The StructInfo annotations are not required to be an exact match - to the inferred StructInfo, and may provide less specific + The Type annotations are not required to be an exact match + to the inferred Type, and may provide less specific information than the inference would provide. """ @@ -1363,10 +1359,10 @@ def main( rx.analysis.well_formed(Module) -def test_incomplete_struct_info_must_be_consistent(): - """StructInfo annotations must be accurate +def test_incomplete_ty_must_be_consistent(): + """Type annotations must be accurate - Even though StructInfo annotation may be less specific, the + Even though Type annotation may be less specific, the information that they do contain must be correct. """ diff --git a/tests/python/relax/test_ast_printer.py b/tests/python/relax/test_ast_printer.py index 25a0d8ec55d0..49c52ef888d8 100644 --- a/tests/python/relax/test_ast_printer.py +++ b/tests/python/relax/test_ast_printer.py @@ -29,8 +29,8 @@ from tvm.script import relax as R from tvm.script import tirx as T -# Overload dump_ast to test both struct info and type annotations -dump_ast = partial(dump_ast, include_struct_info_annotations=True) +# Overload dump_ast to test both type and type annotations +dump_ast = partial(dump_ast, include_ty_annotations=True) def strip_whitespace(text: str) -> str: @@ -42,7 +42,7 @@ def strip_whitespace(text: str) -> str: def normalize(func: rx.Function) -> rx.Function: """ - Normalize the expr to fill in the struct_info fields everywhere + Normalize the expr to fill in the ty fields everywhere """ # using a default mutator to use the BlockBuilder's normalizer, @@ -80,12 +80,12 @@ def test_var() -> None: assert v0_str == 'Var(name_hint="v0")' v1 = rx.Var("v1", R.Tensor([54, 96], "float32")) - v1_no_annos = dump_ast(v1, include_struct_info_annotations=False) + v1_no_annos = dump_ast(v1, include_ty_annotations=False) assert v1_no_annos == 'Var(name_hint="v1")' v1_annos = dump_ast(v1) assert v1_annos != v1_no_annos assert "PrimExpr" in v1_annos - assert "struct_info" in v1_annos + assert "ty" in v1_annos def test_dataflow_var() -> None: @@ -94,12 +94,12 @@ def test_dataflow_var() -> None: assert v0_str == 'DataflowVar(name_hint="v0")' v1 = rx.DataflowVar("v1", R.Tensor([54, 96], "float16")) - v1_no_annos = dump_ast(v1, include_struct_info_annotations=False) + v1_no_annos = dump_ast(v1, include_ty_annotations=False) assert v1_no_annos == 'DataflowVar(name_hint="v1")' v1_annos = dump_ast(v1) assert v1_annos != v1_no_annos assert "PrimExpr" in v1_annos - assert "struct_info" in v1_annos + assert "ty" in v1_annos def test_match_cast() -> None: @@ -126,14 +126,14 @@ def test_match_cast() -> None: assert b1_str.startswith("MatchCast(") assert "PrimExpr(value=`m" in b1_str assert "PrimExpr(value=`n" in b1_str - assert b1_str != dump_ast(b1, include_struct_info_annotations=False) + assert b1_str != dump_ast(b1, include_ty_annotations=False) def test_var_binding() -> None: v0 = rx.Var("v0") val = rx.const(np.random.rand(24, 56)) b0 = rx.VarBinding(v0, val) - b0_str = dump_ast(b0, include_struct_info_annotations=False) + b0_str = dump_ast(b0, include_ty_annotations=False) assert b0_str.startswith("VarBinding(") assert 'var=Var(name_hint="v0")' in b0_str assert "value=" in b0_str @@ -217,7 +217,7 @@ def test_func(): assert func_str.startswith("Function(") assert "params=" in func_str assert "body=" in func_str - assert "ret_struct_info=" in func_str + assert "ret_ty=" in func_str assert "is_pure=" in func_str assert "attrs=" in func_str assert '"global_symbol": "func"' in func_str @@ -275,27 +275,31 @@ def test_types(): func_type = rx.FuncType([tensor_type], unit_type) assert_fields( "FuncType", - {"arg_types": "[TensorType(ndim=2, dtype=int32)]", "ret_type": "TupleType(fields=[])"}, + { + "params": "[TensorType(ndim=2, dtype=int32)]", + "ret": "TupleType(fields=[])", + "purity": "True", + }, printer.visit_type_(func_type), ) -def test_struct_info(): +def test_ty(): printer = ASTPrinter() - assert printer.visit_struct_info_(rx.ObjectStructInfo()) == "ObjectStructInfo()" + assert printer.visit_ty_(rx.ObjectType()) == "ObjectType()" - assert printer.visit_struct_info_(rx.PrimStructInfo("int32")) == "PrimStructInfo(dtype=int32)" + assert printer.visit_ty_(rx.PrimType("int32")) == "PrimType(dtype=int32)" # empty shape - empty_ssi = rx.ShapeStructInfo() - assert printer.visit_struct_info_(empty_ssi) == "ShapeStructInfo(ndim=-1)" + empty_ssi = rx.ShapeType() + assert printer.visit_ty_(empty_ssi) == "ShapeType(ndim=-1)" # include some dimensions - shape_info = rx.ShapeStructInfo([tirx.IntImm("int64", 1), tirx.IntImm("int64", 2)]) - assert strip_whitespace(printer.visit_struct_info_(shape_info)) == strip_whitespace( + shape_info = rx.ShapeType([tirx.IntImm("int64", 1), tirx.IntImm("int64", 2)]) + assert strip_whitespace(printer.visit_ty_(shape_info)) == strip_whitespace( """ - ShapeStructInfo( + ShapeType( ndim=2, values=[ PrimExpr(value=`T.int64(1)`), @@ -305,44 +309,41 @@ def test_struct_info(): """ ) - # tensor struct info - default_tsi = rx.TensorStructInfo() - assert ( - strip_whitespace(printer.visit_struct_info_(default_tsi)) - == "TensorStructInfo(dtype=float32,ndim=-1)" - ) + # tensor type + default_tsi = rx.TensorType() + assert strip_whitespace(printer.visit_ty_(default_tsi)) == "TensorType(dtype=float32,ndim=-1)" # use a var as the shape - x = rx.Var("x", struct_info=rx.ShapeStructInfo(values=[])) - var_tsi = rx.TensorStructInfo(shape=x, dtype="int32") - assert strip_whitespace(printer.visit_struct_info_(var_tsi)) == strip_whitespace( + x = rx.Var("x", ty=rx.ShapeType(values=[])) + var_tsi = rx.TensorType(shape=x, dtype="int32") + assert strip_whitespace(printer.visit_ty_(var_tsi)) == strip_whitespace( """ - TensorStructInfo( + TensorType( dtype=int32, shape=Var( name_hint="x", - struct_info=ShapeStructInfo(ndim=0, values=[]) + ty=ShapeType(ndim=0, values=[]) ) ) """ ) - empty_tuple = rx.TupleStructInfo([]) - assert printer.visit_struct_info_(empty_tuple) == "TupleStructInfo(fields=[])" + empty_tuple = rx.TupleType([]) + assert printer.visit_ty_(empty_tuple) == "TupleType(fields=[])" - tuple_of_shape = rx.TupleStructInfo([empty_ssi]) - assert strip_whitespace(printer.visit_struct_info_(tuple_of_shape)) == strip_whitespace( + tuple_of_shape = rx.TupleType([empty_ssi]) + assert strip_whitespace(printer.visit_ty_(tuple_of_shape)) == strip_whitespace( """ - TupleStructInfo(fields=[ - ShapeStructInfo(ndim=-1) + TupleType(fields=[ + ShapeType(ndim=-1) ]) """ ) - simple_func = rx.FuncStructInfo([], rx.ObjectStructInfo()) + simple_func = rx.FuncType([], rx.ObjectType()) assert ( - strip_whitespace(printer.visit_struct_info_(simple_func)) - == "FuncStructInfo(params=[],ret=ObjectStructInfo(),purity=True)" + strip_whitespace(printer.visit_ty_(simple_func)) + == "FuncType(params=[],ret=ObjectType(),purity=True)" ) @@ -361,7 +362,7 @@ def f( t = R.add(w, z) sh: R.Shape = R.shape_of(t) o: R.Object = R.call_packed( - "contrib.tensor_array_stack", x, y, sinfo_args=R.Object(), test_attr=True + "contrib.tensor_array_stack", x, y, ty_args=R.Object(), test_attr=True ) return o @@ -369,13 +370,13 @@ def f( f_str = strip_whitespace( dump_ast( f, - include_struct_info_annotations=False, + include_ty_annotations=False, include_call_attrs=True, ) ) # the function has an annotated return type - assert "ret_struct_info=ObjectStructInfo()" in f_str + assert "ret_ty=ObjectType()" in f_str # the purity attribute is set to false assert "is_pure=False" @@ -383,7 +384,7 @@ def f( extern_call = f.body.blocks[0].bindings[-1].value extern_call_text = dump_ast( extern_call, - include_struct_info_annotations=False, + include_ty_annotations=False, include_call_attrs=True, ) assert strip_whitespace(extern_call_text) in f_str @@ -392,7 +393,7 @@ def f( { "op": 'ExternFunc(global_symbol="contrib.tensor_array_stack")', "args": '[Var(name_hint="x"), Var(name_hint="y")]', - "sinfo_args": "[ObjectStructInfo()]", + "ty_args": "[ObjectType()]", "attrs": '{"test_attr": True}', }, extern_call_text, @@ -402,7 +403,7 @@ def f( op_call = f.body.blocks[0].bindings[0].value op_call_text = dump_ast( op_call, - include_struct_info_annotations=False, + include_ty_annotations=False, include_call_attrs=True, ) assert strip_whitespace(op_call_text) in f_str @@ -462,7 +463,7 @@ def foo(x: R.Tensor(("m", "n"), "float32")): foo_str = strip_whitespace( dump_ast( foo, - include_struct_info_annotations=False, + include_ty_annotations=False, include_call_attrs=False, ) ) @@ -473,7 +474,7 @@ def foo(x: R.Tensor(("m", "n"), "float32")): tir_call = foo.body.blocks[0].bindings[0].value tir_call_text = dump_ast( tir_call, - include_struct_info_annotations=False, + include_ty_annotations=False, include_call_attrs=False, ) assert_fields( @@ -484,8 +485,8 @@ def foo(x: R.Tensor(("m", "n"), "float32")): GlobalVar(name_hint="addone"), Tuple(fields=[Var(name_hint="x")]) ]""", - "sinfo_args": """[ - TensorStructInfo( + "ty_args": """[ + TensorType( dtype=float32, shape=ShapeExpr( values=[ @@ -511,7 +512,7 @@ def foo(x: R.Tensor(("m", "n"), "float32")): foo_str = strip_whitespace( dump_ast( foo, - include_struct_info_annotations=False, + include_ty_annotations=False, include_call_attrs=False, ) ) @@ -522,7 +523,7 @@ def foo(x: R.Tensor(("m", "n"), "float32")): tir_call = foo.body.blocks[0].bindings[0].value tir_call_text = dump_ast( tir_call, - include_struct_info_annotations=False, + include_ty_annotations=False, include_call_attrs=False, ) assert_fields( @@ -533,8 +534,8 @@ def foo(x: R.Tensor(("m", "n"), "float32")): ExternFunc(global_symbol="test.op.identity"), Tuple(fields=[Var(name_hint="x")]) ]""", - "sinfo_args": """[ - TensorStructInfo( + "ty_args": """[ + TensorType( dtype=float32, shape=ShapeExpr( values=[ @@ -558,7 +559,7 @@ def foo(x: R.Tensor): foo_str = strip_whitespace( dump_ast( foo, - include_struct_info_annotations=False, + include_ty_annotations=False, ) ) assert 'Op(name="relax.unique")' in foo_str @@ -574,14 +575,14 @@ def bar(x: R.Tensor): bar_str = strip_whitespace( dump_ast( bar, - include_struct_info_annotations=False, + include_ty_annotations=False, ) ) # the format string is a StringImm argument assert 'StringImm(value="{}")' in bar_str -def test_print_struct_info_annotation_non_var(): +def test_print_ty_annotation_non_var(): @R.function def f() -> R.Tensor: return R.const([1, 2]) @@ -589,13 +590,13 @@ def f() -> R.Tensor: body = normalize(f).body body_str = strip_whitespace(dump_ast(body)) # the constant has a shape of (2,) - struct_info = strip_whitespace( + ty = strip_whitespace( """ - struct_info=TensorStructInfo( + ty=TensorType( dtype=int32, shape=ShapeExpr( values=[PrimExpr(value=`T.int64(2)`)], - struct_info=ShapeStructInfo( + ty=ShapeType( ndim=1, values=[PrimExpr(value=`T.int64(2)`)] ) @@ -603,7 +604,7 @@ def f() -> R.Tensor: ) """ ) - assert struct_info in body_str + assert ty in body_str def test_print_type_annotation_non_var(): @@ -656,7 +657,7 @@ def test_prim_value(): """ PrimValue( value=PrimExpr(value=`T.int64(1)`), - struct_info=PrimStructInfo(dtype=int64) + ty=PrimType(dtype=int64) ) """ ) @@ -669,7 +670,7 @@ def test_string_imm(): """ StringImm( value="test", - struct_info=ObjectStructInfo() + ty=ObjectType() ) """ ) @@ -682,7 +683,7 @@ def test_datatype_imm(): """ DataTypeImm( value=int32, - struct_info=ObjectStructInfo() + ty=ObjectType() ) """ ) diff --git a/tests/python/relax/test_backend_dispatch_sampling.py b/tests/python/relax/test_backend_dispatch_sampling.py index c1fe0dbd0c12..28778e3c52a2 100644 --- a/tests/python/relax/test_backend_dispatch_sampling.py +++ b/tests/python/relax/test_backend_dispatch_sampling.py @@ -70,7 +70,7 @@ def foo(prob: R.Tensor((3, 5), dtype="float32"), uniform_sample: R.Tensor((6, 1) cls = Expected with R.dataflow(): lv: R.Tensor((3, 5), dtype="float32") = R.cumsum(prob, axis=1, dtype="float32", exclusive=0) - gv = R.call_tir(cls.get_sample_index, (lv, uniform_sample, sample_indices), out_sinfo=R.Tensor((6, 1), dtype="int64")) + gv = R.call_tir(cls.get_sample_index, (lv, uniform_sample, sample_indices), out_ty=R.Tensor((6, 1), dtype="int64")) R.output(gv) return gv # fmt: on @@ -188,7 +188,7 @@ def parallel_sampling_from_prob(var_prob: T.handle, var_uniform_samples: T.handl def foo(prob: R.Tensor((3, 5), dtype="float32"), uniform_sample: R.Tensor((6, 1), dtype="float32"), sample_indices: R.Tensor((6, 1), dtype="int64")) -> R.Tensor((6, 1), dtype="int64"): cls = Expected with R.dataflow(): - gv = R.call_tir(cls.parallel_sampling_from_prob, (prob, uniform_sample, sample_indices), out_sinfo=R.Tensor((6, 1), dtype="int64")) + gv = R.call_tir(cls.parallel_sampling_from_prob, (prob, uniform_sample, sample_indices), out_ty=R.Tensor((6, 1), dtype="int64")) R.output(gv) return gv # fmt: on diff --git a/tests/python/relax/test_backend_transform_shape_lower.py b/tests/python/relax/test_backend_transform_shape_lower.py index ce89852b9040..045468e57a9e 100644 --- a/tests/python/relax/test_backend_transform_shape_lower.py +++ b/tests/python/relax/test_backend_transform_shape_lower.py @@ -49,8 +49,8 @@ class Expected: def main(x: R.Shape([1, 2]), y: R.Shape): R.func_attr({"relax.force_pure": True}) shape_heap = R.null_value() - _ = R.call_packed("vm.builtin.check_shape_info", x, 2, "", sinfo_args=[R.Tuple()]) - _ = R.call_packed("vm.builtin.check_shape_info", y, -1, "", sinfo_args=[R.Tuple()]) + _ = R.call_packed("vm.builtin.check_shape_info", x, 2, "", ty_args=[R.Tuple()]) + _ = R.call_packed("vm.builtin.check_shape_info", y, -1, "", ty_args=[R.Tuple()]) _ = R.call_packed( "vm.builtin.match_shape", x, @@ -61,7 +61,7 @@ def main(x: R.Shape([1, 2]), y: R.Shape): MS.ASSERT_EQUAL_TO_IMM, 2, "", - sinfo_args=[R.Tuple()], + ty_args=[R.Tuple()], ) return x @@ -92,8 +92,8 @@ class Expected: def main(f: R.Callable([R.Object], R.Object), y: R.Shape([1, 2])): R.func_attr({"relax.force_pure": True}) shape_heap = R.null_value() - _ = R.call_packed("vm.builtin.check_func_info", f, "", sinfo_args=[R.Tuple()]) - _ = R.call_packed("vm.builtin.check_shape_info", y, 2, "", sinfo_args=[R.Tuple()]) + _ = R.call_packed("vm.builtin.check_func_info", f, "", ty_args=[R.Tuple()]) + _ = R.call_packed("vm.builtin.check_shape_info", y, 2, "", ty_args=[R.Tuple()]) _ = R.call_packed( "vm.builtin.match_shape", y, @@ -104,7 +104,7 @@ def main(f: R.Callable([R.Object], R.Object), y: R.Shape([1, 2])): MS.ASSERT_EQUAL_TO_IMM, 2, "", - sinfo_args=[R.Tuple()], + ty_args=[R.Tuple()], ) return y @@ -137,7 +137,7 @@ def main(x: R.Tensor(["n", 2, "m"], "float32")): shape_heap = R.call_builtin_with_ctx( "vm.builtin.alloc_shape_heap", [R.prim_value(2)], - sinfo_args=[R.Tensor(ndim=1, dtype="int64")], + ty_args=[R.Tensor(ndim=1, dtype="int64")], ) _ = R.call_packed( "vm.builtin.check_tensor_info", @@ -145,7 +145,7 @@ def main(x: R.Tensor(["n", 2, "m"], "float32")): 3, R.dtype("float32"), "", - sinfo_args=[R.Tuple()], + ty_args=[R.Tuple()], ) _ = R.call_packed( "vm.builtin.match_shape", @@ -159,7 +159,7 @@ def main(x: R.Tensor(["n", 2, "m"], "float32")): MS.STORE_TO_HEAP, sindex["m"], "", - sinfo_args=[R.Tuple()], + ty_args=[R.Tuple()], ) return x @@ -208,7 +208,7 @@ def main(x: R.Tensor(["n", "m"], "float32"), y: R.Tensor(ndim=3, dtype=None)) -> shape_heap = R.call_builtin_with_ctx( "vm.builtin.alloc_shape_heap", [R.prim_value(4)], - sinfo_args=[R.Tensor(ndim=1, dtype="int64")], + ty_args=[R.Tensor(ndim=1, dtype="int64")], ) _ = R.call_packed( "vm.builtin.check_tensor_info", @@ -216,10 +216,10 @@ def main(x: R.Tensor(["n", "m"], "float32"), y: R.Tensor(ndim=3, dtype=None)) -> 2, R.dtype("float32"), "", - sinfo_args=[R.Tuple()], + ty_args=[R.Tuple()], ) _ = R.call_packed( - "vm.builtin.check_tensor_info", y, 3, R.dtype(""), "", sinfo_args=[R.Tuple()] + "vm.builtin.check_tensor_info", y, 3, R.dtype(""), "", ty_args=[R.Tuple()] ) _ = R.call_packed( "vm.builtin.match_shape", @@ -231,7 +231,7 @@ def main(x: R.Tensor(["n", "m"], "float32"), y: R.Tensor(ndim=3, dtype=None)) -> MS.STORE_TO_HEAP, sindex["m"], "", - sinfo_args=[R.Tuple()], + ty_args=[R.Tuple()], ) _ = R.call_packed( "vm.builtin.match_shape", @@ -245,7 +245,7 @@ def main(x: R.Tensor(["n", "m"], "float32"), y: R.Tensor(ndim=3, dtype=None)) -> MS.NO_OP, 0, "", - sinfo_args=[R.Tuple()], + ty_args=[R.Tuple()], ) _ = cls.shape_func(shape_heap) # extra assertion on y's shape after shape computation @@ -261,7 +261,7 @@ def main(x: R.Tensor(["n", "m"], "float32"), y: R.Tensor(ndim=3, dtype=None)) -> MS.ASSERT_EQUAL_TO_LOAD, sindex["k+1"], "", - sinfo_args=[R.Tuple()], + ty_args=[R.Tuple()], ) z = R.match_cast(y, R.Tensor([k, m, k + 1], dtype=None)) # construct shape value for return @@ -275,7 +275,7 @@ def main(x: R.Tensor(["n", "m"], "float32"), y: R.Tensor(ndim=3, dtype=None)) -> sindex["m"], MK.USE_IMM, 2, - sinfo_args=[R.Shape(ndim=3)], + ty_args=[R.Shape(ndim=3)], ) return s @@ -314,10 +314,10 @@ def main( shape_heap = R.call_builtin_with_ctx( "vm.builtin.alloc_shape_heap", [R.prim_value(3)], - sinfo_args=[R.Tensor(ndim=1, dtype="int64")], + ty_args=[R.Tensor(ndim=1, dtype="int64")], ) # recursively unpack tuple for static info check - _ = R.call_packed("vm.builtin.check_tuple_info", x, 2, "", sinfo_args=[R.Tuple()]) + _ = R.call_packed("vm.builtin.check_tuple_info", x, 2, "", ty_args=[R.Tuple()]) t0 = x[0] _ = R.call_packed( "vm.builtin.check_tensor_info", @@ -325,12 +325,12 @@ def main( 2, R.dtype("float32"), "", - sinfo_args=[R.Tuple()], + ty_args=[R.Tuple()], ) t1 = x[1] - _ = R.call_packed("vm.builtin.check_tuple_info", t1, 2, "", sinfo_args=[R.Tuple()]) + _ = R.call_packed("vm.builtin.check_tuple_info", t1, 2, "", ty_args=[R.Tuple()]) t1x0 = t1[0] - _ = R.call_packed("vm.builtin.check_shape_info", t1x0, -1, "", sinfo_args=[R.Tuple()]) + _ = R.call_packed("vm.builtin.check_shape_info", t1x0, -1, "", ty_args=[R.Tuple()]) t1x1 = t1[1] _ = R.call_packed( "vm.builtin.check_tensor_info", @@ -338,7 +338,7 @@ def main( 2, R.dtype("int32"), "", - sinfo_args=[R.Tuple()], + ty_args=[R.Tuple()], ) # match shape checks. _ = R.call_packed( @@ -351,7 +351,7 @@ def main( MS.STORE_TO_HEAP, sindex["m"], "", - sinfo_args=[R.Tuple()], + ty_args=[R.Tuple()], ) _ = R.call_packed( "vm.builtin.match_shape", @@ -363,7 +363,7 @@ def main( MS.STORE_TO_HEAP, sindex["k"], "", - sinfo_args=[R.Tuple()], + ty_args=[R.Tuple()], ) return x @@ -374,7 +374,7 @@ def main( def test_return_match_check(): - """Test when return body is not same as ret_struct_info, runtime match check needed.""" + """Test when return body is not same as ret_ty, runtime match check needed.""" MS = MatchShapeCode @tvm.script.ir_module @@ -402,10 +402,10 @@ def main(x: R.Tensor(["n", "m"], "float32"), y: R.Object) -> R.Tuple( shape_heap = R.call_builtin_with_ctx( "vm.builtin.alloc_shape_heap", [R.prim_value(2)], - sinfo_args=[R.Tensor(ndim=1, dtype="int64")], + ty_args=[R.Tensor(ndim=1, dtype="int64")], ) _ = R.call_packed( - "vm.builtin.check_tensor_info", x, 2, R.dtype("float32"), "", sinfo_args=[R.Tuple()] + "vm.builtin.check_tensor_info", x, 2, R.dtype("float32"), "", ty_args=[R.Tuple()] ) _ = R.call_packed( "vm.builtin.match_shape", @@ -417,11 +417,11 @@ def main(x: R.Tensor(["n", "m"], "float32"), y: R.Object) -> R.Tuple( MS.STORE_TO_HEAP, sindex["m"], "", - sinfo_args=[R.Tuple()], + ty_args=[R.Tuple()], ) - _ = R.call_packed("vm.builtin.check_tuple_info", y, 1, "", sinfo_args=[R.Tuple()]) + _ = R.call_packed("vm.builtin.check_tuple_info", y, 1, "", ty_args=[R.Tuple()]) # emit runtime function call since y do not have the right type. - y1 = R.call_packed("vm.builtin.tuple_getitem", y, 0, sinfo_args=[R.Object]) + y1 = R.call_packed("vm.builtin.tuple_getitem", y, 0, ty_args=[R.Object]) # run check _ = R.call_packed( "vm.builtin.check_tensor_info", @@ -429,7 +429,7 @@ def main(x: R.Tensor(["n", "m"], "float32"), y: R.Object) -> R.Tuple( 2, R.dtype("float32"), "", - sinfo_args=[R.Tuple()], + ty_args=[R.Tuple()], ) # shape check _ = R.call_packed( @@ -442,7 +442,7 @@ def main(x: R.Tensor(["n", "m"], "float32"), y: R.Object) -> R.Tuple( MS.ASSERT_EQUAL_TO_LOAD, sindex["m"], "", - sinfo_args=[R.Tuple()], + ty_args=[R.Tuple()], ) return y @@ -456,7 +456,7 @@ def main(x: R.Tensor(["n", "m"], "float32"), y: R.Object) -> R.Tuple( def test_return_match_check_with_new_expr(): """Like test_return_match_check, but requires a computation - When return body is not same as ret_struct_info, a runtime match + When return body is not same as ret_ty, a runtime match check is required. This match check may require a symbolic expression to be computed. """ @@ -467,7 +467,7 @@ class Before: @R.function def main(x: R.Tensor(["n", "n"], "float32")) -> R.Tensor(["n * n"], "float32"): R.func_attr({"relax.force_pure": True}) - out = R.call_packed("flatten_matrix", x, sinfo_args=R.Object) + out = R.call_packed("flatten_matrix", x, ty_args=R.Object) return out # slot assignment: @@ -484,10 +484,10 @@ def main(x: R.Tensor(["n", "n"], "float32")) -> R.Tensor(["n * n"], "float32"): shape_heap = R.call_builtin_with_ctx( "vm.builtin.alloc_shape_heap", [R.prim_value(2)], - sinfo_args=[R.Tensor(ndim=1, dtype="int64")], + ty_args=[R.Tensor(ndim=1, dtype="int64")], ) _ = R.call_packed( - "vm.builtin.check_tensor_info", x, 2, R.dtype("float32"), "", sinfo_args=[R.Tuple()] + "vm.builtin.check_tensor_info", x, 2, R.dtype("float32"), "", ty_args=[R.Tuple()] ) _ = R.call_packed( "vm.builtin.match_shape", @@ -499,19 +499,19 @@ def main(x: R.Tensor(["n", "n"], "float32")) -> R.Tensor(["n * n"], "float32"): MS.ASSERT_EQUAL_TO_LOAD, sindex["n"], "", - sinfo_args=[R.Tuple()], + ty_args=[R.Tuple()], ) _ = Expected.shape_func(shape_heap) - out = R.call_packed("flatten_matrix", x, sinfo_args=R.Object) + out = R.call_packed("flatten_matrix", x, ty_args=R.Object) _ = R.call_packed( "vm.builtin.check_tensor_info", out, 1, R.dtype("float32"), "", - sinfo_args=[R.Tuple()], + ty_args=[R.Tuple()], ) _ = R.call_packed( "vm.builtin.match_shape", @@ -521,7 +521,7 @@ def main(x: R.Tensor(["n", "n"], "float32")) -> R.Tensor(["n * n"], "float32"): MS.ASSERT_EQUAL_TO_LOAD, sindex["n * n"], "", - sinfo_args=[R.Tuple()], + ty_args=[R.Tuple()], ) return out @@ -577,7 +577,7 @@ def fn1(A: R.Tensor(("m", "n"), dtype="float32")) -> R.Tensor(("m", "n"), dtype= shape_heap: R.Tensor(dtype="int64", ndim=1) = R.call_builtin_with_ctx( "vm.builtin.alloc_shape_heap", (R.prim_value(2),), - sinfo_args=(R.Tensor(dtype="int64", ndim=1),), + ty_args=(R.Tensor(dtype="int64", ndim=1),), ) _: R.Tuple = R.call_packed( "vm.builtin.check_tensor_info", @@ -585,7 +585,7 @@ def fn1(A: R.Tensor(("m", "n"), dtype="float32")) -> R.Tensor(("m", "n"), dtype= R.prim_value(2), R.dtype("float32"), R.str(""), - sinfo_args=(R.Tuple,), + ty_args=(R.Tuple,), ) _1: R.Tuple = R.call_packed( "vm.builtin.match_shape", @@ -597,7 +597,7 @@ def fn1(A: R.Tensor(("m", "n"), dtype="float32")) -> R.Tensor(("m", "n"), dtype= MS.STORE_TO_HEAP, sindex_fn1["n"], R.str(""), - sinfo_args=(R.Tuple,), + ty_args=(R.Tuple,), ) return A @@ -609,7 +609,7 @@ def fn2(A: R.Tensor(("n", "m"), dtype="float32")) -> R.Tensor(("n", "m"), dtype= shape_heap: R.Tensor(dtype="int64", ndim=1) = R.call_builtin_with_ctx( "vm.builtin.alloc_shape_heap", (R.prim_value(2),), - sinfo_args=(R.Tensor(dtype="int64", ndim=1),), + ty_args=(R.Tensor(dtype="int64", ndim=1),), ) _2: R.Tuple = R.call_packed( "vm.builtin.check_tensor_info", @@ -617,7 +617,7 @@ def fn2(A: R.Tensor(("n", "m"), dtype="float32")) -> R.Tensor(("n", "m"), dtype= R.prim_value(2), R.dtype("float32"), R.str(""), - sinfo_args=(R.Tuple,), + ty_args=(R.Tuple,), ) _3: R.Tuple = R.call_packed( "vm.builtin.match_shape", @@ -629,7 +629,7 @@ def fn2(A: R.Tensor(("n", "m"), dtype="float32")) -> R.Tensor(("n", "m"), dtype= MS.STORE_TO_HEAP, sindex_fn2["m"], R.str(""), - sinfo_args=(R.Tuple,), + ty_args=(R.Tuple,), ) return A @@ -669,7 +669,7 @@ def main_transform_params(params: R.Tuple(R.Tensor((16, 16), dtype="float32"))) params, R.prim_value(1), R.str(""), - sinfo_args=(R.Tuple,), + ty_args=(R.Tuple,), ) gv: R.Tensor((16, 16), dtype="float32") = params[0] _1: R.Tuple = R.call_packed( @@ -678,7 +678,7 @@ def main_transform_params(params: R.Tuple(R.Tensor((16, 16), dtype="float32"))) R.prim_value(2), R.dtype("float32"), R.str(""), - sinfo_args=(R.Tuple,), + ty_args=(R.Tuple,), ) _2: R.Tuple = R.call_packed( "vm.builtin.match_shape", @@ -690,7 +690,7 @@ def main_transform_params(params: R.Tuple(R.Tensor((16, 16), dtype="float32"))) MS.ASSERT_EQUAL_TO_IMM, R.prim_value(16), R.str(""), - sinfo_args=(R.Tuple,), + ty_args=(R.Tuple,), ) return params @@ -706,7 +706,7 @@ def main( R.prim_value(2), R.dtype("float32"), R.str(""), - sinfo_args=(R.Tuple,), + ty_args=(R.Tuple,), ) _1: R.Tuple = R.call_packed( "vm.builtin.match_shape", @@ -718,7 +718,7 @@ def main( MS.ASSERT_EQUAL_TO_IMM, R.prim_value(16), R.str(""), - sinfo_args=(R.Tuple,), + ty_args=(R.Tuple,), ) return (x, param_0) @@ -756,7 +756,7 @@ def main( shape_heap: R.Tensor(dtype="int64", ndim=1) = R.call_builtin_with_ctx( "vm.builtin.alloc_shape_heap", (R.prim_value(1),), - sinfo_args=(R.Tensor(dtype="int64", ndim=1),), + ty_args=(R.Tensor(dtype="int64", ndim=1),), ) _: R.Tuple = R.call_packed( "vm.builtin.check_tensor_info", @@ -764,14 +764,14 @@ def main( R.prim_value(2), R.dtype("float32"), R.str(""), - sinfo_args=(R.Tuple,), + ty_args=(R.Tuple,), ) _1: R.Tuple = R.call_packed( "vm.builtin.check_tuple_info", params, R.prim_value(2), R.str(""), - sinfo_args=(R.Tuple,), + ty_args=(R.Tuple,), ) _param_1: R.Tensor((n,), dtype="float32") = params[1] _2: R.Tuple = R.call_packed( @@ -780,7 +780,7 @@ def main( R.prim_value(1), R.dtype("float32"), R.str(""), - sinfo_args=(R.Tuple,), + ty_args=(R.Tuple,), ) _3: R.Tuple = R.call_packed( "vm.builtin.match_shape", @@ -792,7 +792,7 @@ def main( MS.ASSERT_EQUAL_TO_IMM, R.prim_value(16), R.str(""), - sinfo_args=(R.Tuple,), + ty_args=(R.Tuple,), ) _4: R.Tuple = R.call_packed( "vm.builtin.match_shape", @@ -802,7 +802,7 @@ def main( R.prim_value(1), R.prim_value(0), R.str(""), - sinfo_args=(R.Tuple,), + ty_args=(R.Tuple,), ) param_0 = params[0] @@ -842,14 +842,14 @@ def main(arg_prim_value: R.Prim(value="n")) -> R.Prim("int64"): shape_heap = R.call_builtin_with_ctx( "vm.builtin.alloc_shape_heap", [2], - sinfo_args=(R.Tensor(dtype="int64", ndim=1),), + ty_args=(R.Tensor(dtype="int64", ndim=1),), ) _ = R.call_packed( "vm.builtin.check_prim_value_info", arg_prim_value, R.dtype("int64"), "", - sinfo_args=[R.Tuple], + ty_args=[R.Tuple], ) _ = R.call_packed( "vm.builtin.match_prim_value", @@ -858,7 +858,7 @@ def main(arg_prim_value: R.Prim(value="n")) -> R.Prim("int64"): MatchShapeCode.STORE_TO_HEAP, 0, "", - sinfo_args=[R.Tuple], + ty_args=[R.Tuple], ) shape = R.call_packed( "vm.builtin.make_shape", @@ -866,7 +866,7 @@ def main(arg_prim_value: R.Prim(value="n")) -> R.Prim("int64"): 1, MakeShapeCode.LOAD_SHAPE, 0, - sinfo_args=[R.Shape(ndim=1)], + ty_args=[R.Shape(ndim=1)], ) _ = R.call_packed( "vm.builtin.match_shape", @@ -876,7 +876,7 @@ def main(arg_prim_value: R.Prim(value="n")) -> R.Prim("int64"): MatchShapeCode.STORE_TO_HEAP, 1, "", - sinfo_args=[R.Tuple], + ty_args=[R.Tuple], ) m = T.int64() @@ -886,7 +886,7 @@ def main(arg_prim_value: R.Prim(value="n")) -> R.Prim("int64"): shape_heap, MakeShapeCode.LOAD_SHAPE, 1, - sinfo_args=[R.Prim(value=m)], + ty_args=[R.Prim(value=m)], ) return gv diff --git a/tests/python/relax/test_base_py_module_printer.py b/tests/python/relax/test_base_py_module_printer.py index 2b34980a24f0..e65befd70027 100644 --- a/tests/python/relax/test_base_py_module_printer.py +++ b/tests/python/relax/test_base_py_module_printer.py @@ -35,7 +35,7 @@ def add(self, x, y): """Simple addition function.""" x_tvm = self._convert_pytorch_to_tvm(x) y_tvm = self._convert_pytorch_to_tvm(y) - result = self.call_tir(self.add_tir, [x_tvm, y_tvm], out_sinfo=R.Tensor((5,), "float32")) + result = self.call_tir(self.add_tir, [x_tvm, y_tvm], out_ty=R.Tensor((5,), "float32")) return self._convert_tvm_to_pytorch(result) @I.pyfunc @@ -43,9 +43,7 @@ def multiply(self, x, y): """Simple multiplication function.""" x_tvm = self._convert_pytorch_to_tvm(x) y_tvm = self._convert_pytorch_to_tvm(y) - result = self.call_tir( - self.multiply_tir, [x_tvm, y_tvm], out_sinfo=R.Tensor((5,), "float32") - ) + result = self.call_tir(self.multiply_tir, [x_tvm, y_tvm], out_ty=R.Tensor((5,), "float32")) return self._convert_tvm_to_pytorch(result) @T.prim_func(s_tir=True) @@ -91,16 +89,16 @@ def ml_pipeline(self, input_data, model_params): # Run ML inference features = self.call_tir( - self.extract_features, [tvm_data], out_sinfo=R.Tensor((10,), "float32") + self.extract_features, [tvm_data], out_ty=R.Tensor((10,), "float32") ) predictions = self.call_tir( - self.ml_inference, [features, tvm_params], out_sinfo=R.Tensor((5,), "float32") + self.ml_inference, [features, tvm_params], out_ty=R.Tensor((5,), "float32") ) # Post-process results final_result = self.call_tir( - self.post_process, [predictions], out_sinfo=R.Tensor((5,), "float32") + self.post_process, [predictions], out_ty=R.Tensor((5,), "float32") ) return self._convert_tvm_to_pytorch(final_result) @@ -123,7 +121,7 @@ def data_preprocessing(self, raw_data): # Convert and return tvm_processed = self._convert_pytorch_to_tvm(processed) result = self.call_tir( - self.normalize_data, [tvm_processed], out_sinfo=R.Tensor((10,), "float32") + self.normalize_data, [tvm_processed], out_ty=R.Tensor((10,), "float32") ) return self._convert_tvm_to_pytorch(result) @@ -240,7 +238,7 @@ def vectorized_operation(self, x, y): x_tvm = self._convert_pytorch_to_tvm(x) y_tvm = self._convert_pytorch_to_tvm(y) result = self.call_tir( - self.vectorized_add, [x_tvm, y_tvm], out_sinfo=R.Tensor((10,), "float32") + self.vectorized_add, [x_tvm, y_tvm], out_ty=R.Tensor((10,), "float32") ) return self._convert_tvm_to_pytorch(result) @@ -317,7 +315,7 @@ def sklearn_integration(self, input_data, scaler_params): result = self.call_tir( self.final_transform, [tvm_data], - out_sinfo=R.Tensor((reduced_data.shape[0], 10), "float32"), + out_ty=R.Tensor((reduced_data.shape[0], 10), "float32"), ) return self._convert_tvm_to_pytorch(result) @@ -719,15 +717,15 @@ def test_call_py_func_with_base_py_module(): import numpy as np import torch - from tvm.relax import TensorStructInfo, Var + from tvm.relax import TensorType, Var from tvm.relax.expr import StringImm from tvm.relax.op import call_py_func # Test 1: Operator creation and basic properties - x = Var("x", TensorStructInfo((5,), "float32")) - y = Var("y", TensorStructInfo((5,), "float32")) + x = Var("x", TensorType((5,), "float32")) + y = Var("y", TensorType((5,), "float32")) - call_expr = call_py_func(StringImm("test_func"), (x, y), out_sinfo=R.Tensor((5,), "float32")) + call_expr = call_py_func(StringImm("test_func"), (x, y), out_ty=R.Tensor((5,), "float32")) assert call_expr.op.name == "relax.call_py_func" assert call_expr.args[0].value == "test_func" @@ -737,8 +735,8 @@ def test_call_py_func_with_base_py_module(): try: call_py_func( "invalid", - (Var("x", TensorStructInfo((5,), "float32")),), - out_sinfo=R.Tensor((5,), "float32"), + (Var("x", TensorType((5,), "float32")),), + out_ty=R.Tensor((5,), "float32"), ) assert False, "Should raise type error" except Exception as e: @@ -749,7 +747,7 @@ def test_call_py_func_with_base_py_module(): class ValidationTestModule(BasePyModule): @R.function def test_invalid_call(x: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"): - result = R.call_py_func("non_existent_func", (x,), out_sinfo=R.Tensor((5,), "float32")) + result = R.call_py_func("non_existent_func", (x,), out_ty=R.Tensor((5,), "float32")) return result device = tvm.cpu() @@ -775,9 +773,9 @@ def torch_softmax(self, x, dim=0): @R.function def mixed_computation(x: R.Tensor((10,), "float32")) -> R.Tensor((10,), "float32"): - relu_result = R.call_py_func("torch_relu", (x,), out_sinfo=R.Tensor((10,), "float32")) + relu_result = R.call_py_func("torch_relu", (x,), out_ty=R.Tensor((10,), "float32")) final_result = R.call_py_func( - "torch_softmax", (relu_result,), out_sinfo=R.Tensor((10,), "float32") + "torch_softmax", (relu_result,), out_ty=R.Tensor((10,), "float32") ) return final_result diff --git a/tests/python/relax/test_base_py_module_symbolic_shape.py b/tests/python/relax/test_base_py_module_symbolic_shape.py index cb16083c6e8d..635804645f99 100644 --- a/tests/python/relax/test_base_py_module_symbolic_shape.py +++ b/tests/python/relax/test_base_py_module_symbolic_shape.py @@ -107,9 +107,9 @@ def test_base_py_module_tir_symbolic_end_to_end(): b = np.random.randn(5).astype("float32") n = tirx.Var("n", "int64") - out_sinfo = relax.TensorStructInfo((n,), "float32") + out_ty = relax.TensorType((n,), "float32") - out = bpm.call_tir("add_tir", [a, b], out_sinfo) + out = bpm.call_tir("add_tir", [a, b], out_ty) out_np = out if isinstance(out, np.ndarray) else out.numpy() tvm.testing.assert_allclose(out_np, a + b, rtol=1e-6, atol=1e-6) @@ -232,8 +232,8 @@ def test_base_py_module_multiple_symbolic_dims(): # Test TIR function with multiple symbolic dims # Use concrete shapes for TIR function to avoid constraint issues - out_sinfo = relax.TensorStructInfo((2, 4), "float32") - out_tir = bpm.call_tir("matmul_tir", [a, b], out_sinfo) + out_ty = relax.TensorType((2, 4), "float32") + out_tir = bpm.call_tir("matmul_tir", [a, b], out_ty) out_tir_np = out_tir if isinstance(out_tir, np.ndarray) else out_tir.numpy() tvm.testing.assert_allclose(out_tir_np, expected, rtol=1e-6, atol=1e-6) @@ -257,9 +257,9 @@ def test_add_packed(a, b, out): b = np.random.randn(5).astype("float32") n = tirx.Var("n", "int64") - out_sinfo = relax.TensorStructInfo((n,), "float32") + out_ty = relax.TensorType((n,), "float32") - out = bpm.call_dps_packed("test_add_packed", [a, b], out_sinfo) + out = bpm.call_dps_packed("test_add_packed", [a, b], out_ty) out_np = out if isinstance(out, np.ndarray) else out.numpy() tvm.testing.assert_allclose(out_np, a + b, rtol=1e-6, atol=1e-6) @@ -285,9 +285,9 @@ def test_matmul_packed(a, b, out): a = np.random.randn(2, 3).astype("float32") b = np.random.randn(3, 4).astype("float32") - out_sinfo = relax.TensorStructInfo((2, 4), "float32") + out_ty = relax.TensorType((2, 4), "float32") - out = bpm.call_dps_packed("test_matmul_packed", [a, b], out_sinfo) + out = bpm.call_dps_packed("test_matmul_packed", [a, b], out_ty) out_np = out if isinstance(out, np.ndarray) else out.numpy() expected = np.matmul(a, b) tvm.testing.assert_allclose(out_np, expected, rtol=1e-6, atol=1e-6) @@ -318,9 +318,9 @@ def test_add_scalar_packed(x, scalar, out): scalar = 2.5 n = tirx.Var("n", "int64") - out_sinfo = relax.TensorStructInfo((n,), "float32") + out_ty = relax.TensorType((n,), "float32") - out = bpm.call_dps_packed("test_add_scalar_packed", [x, scalar], out_sinfo) + out = bpm.call_dps_packed("test_add_scalar_packed", [x, scalar], out_ty) out_np = out if isinstance(out, np.ndarray) else out.numpy() expected = x + scalar tvm.testing.assert_allclose(out_np, expected, rtol=1e-6, atol=1e-6) diff --git a/tests/python/relax/test_bind_symbolic_vars.py b/tests/python/relax/test_bind_symbolic_vars.py index 9f36e46a38e9..90fe4864c774 100644 --- a/tests/python/relax/test_bind_symbolic_vars.py +++ b/tests/python/relax/test_bind_symbolic_vars.py @@ -43,8 +43,8 @@ def expected(A: R.Tensor((128, 64)), B: R.Tensor((64, 32))) -> R.Tensor((128, 32 return R.matmul(A, B) if replace_by_tir_var: - M, K = before.params[0].struct_info.shape - _, N = before.params[1].struct_info.shape + M, K = before.params[0].ty.shape + _, N = before.params[1].ty.shape symbolic_var_map = {M: 128, K: 64, N: 32} else: symbolic_var_map = {"M": 128, "K": 64, "N": 32} @@ -128,7 +128,7 @@ def test_error_with_multiple_definitions(): def func(A: R.Tensor(["M", "N"])): return A - tir_var = func.params[0].struct_info.shape[0] + tir_var = func.params[0].ty.shape[0] symbolic_var_map = {tir_var: 0, "M": 0} with pytest.raises(RuntimeError): @@ -166,19 +166,19 @@ def expected(A: R.Tensor(["outside_var * 2", "outside_var"])): def test_bind_symbolic_vars_in_tensor_shape(): - """The bound variable should be replaced when appearing in struct info""" + """The bound variable should be replaced when appearing in type""" @R.function(private=True) def before(A: R.Tensor(["M", "N"])): M = T.int64() N = T.int64() - B = R.call_dps_packed("dummy_func", [A], out_sinfo=R.Tensor([2 * M * N])) + B = R.call_dps_packed("dummy_func", [A], out_ty=R.Tensor([2 * M * N])) return B @R.function(private=True) def expected(A: R.Tensor(["M", 16])): M = T.int64() - B = R.call_dps_packed("dummy_func", [A], out_sinfo=R.Tensor([M * 32])) + B = R.call_dps_packed("dummy_func", [A], out_ty=R.Tensor([M * 32])) return B after = before.bind_symbolic_vars({"N": 16}) @@ -192,12 +192,12 @@ def test_bind_symbolic_vars_in_shape_expr(): def before(A: R.Tensor(["M * N"]), x: R.Shape(["M", "N"])): M = T.int64() N = T.int64() - B = R.call_dps_packed("dummy_func", [A], out_sinfo=R.Tensor([2 * M * N])) + B = R.call_dps_packed("dummy_func", [A], out_ty=R.Tensor([2 * M * N])) return B @R.function(private=True) def expected(A: R.Tensor(["M * 16"]), x: R.Shape(["M", 16])): - B = R.call_dps_packed("dummy_func", [A], out_sinfo=R.Tensor([M * 32])) + B = R.call_dps_packed("dummy_func", [A], out_ty=R.Tensor([M * 32])) return B after = before.bind_symbolic_vars({"N": 16}) @@ -228,12 +228,12 @@ def test_bind_defining_of_symbolic_vars_in_prim_value(): def before(A: R.Tensor(["M * N"]), x: R.Prim(value="M"), y: R.Prim(value="N")): M = T.int64() N = T.int64() - B = R.call_dps_packed("dummy_func", [A], out_sinfo=R.Tensor([2 * M * N])) + B = R.call_dps_packed("dummy_func", [A], out_ty=R.Tensor([2 * M * N])) return B @R.function(private=True) def expected(A: R.Tensor(["M * 16"]), x: R.Prim(value="M"), y: R.Prim(value=16)): - B = R.call_dps_packed("dummy_func", [A], out_sinfo=R.Tensor([M * 32])) + B = R.call_dps_packed("dummy_func", [A], out_ty=R.Tensor([M * 32])) return B after = before.bind_symbolic_vars({"N": 16}) @@ -258,12 +258,12 @@ def test_bind_usage_of_symbolic_vars_in_prim_value(): def before(A: R.Tensor(["M", "N"]), x: R.Prim(value="M*N")): M = T.int64() N = T.int64() - B = R.call_dps_packed("dummy_func", [A], out_sinfo=R.Tensor([2 * M * N])) + B = R.call_dps_packed("dummy_func", [A], out_ty=R.Tensor([2 * M * N])) return B @R.function(private=True) def expected(A: R.Tensor([16, 16]), x: R.Prim(value=256)): - B = R.call_dps_packed("dummy_func", [A], out_sinfo=R.Tensor([512])) + B = R.call_dps_packed("dummy_func", [A], out_ty=R.Tensor([512])) return B after = before.bind_symbolic_vars({"M": 16, "N": 16}) diff --git a/tests/python/relax/test_blockbuilder_core.py b/tests/python/relax/test_blockbuilder_core.py index c89597bbefac..2d2eb95ec859 100644 --- a/tests/python/relax/test_blockbuilder_core.py +++ b/tests/python/relax/test_blockbuilder_core.py @@ -43,8 +43,8 @@ def nop(): def test_block_builder(): m = tirx.Var("m", "int64") n = tirx.Var("n", "int64") - x = rx.Var("x", rx.TensorStructInfo([m, n], "float16")) - y = rx.Var("y", rx.TensorStructInfo([n], "float16")) + x = rx.Var("x", rx.TensorType([m, n], "float16")) + y = rx.Var("y", rx.TensorType([n], "float16")) bb = rx.BlockBuilder() bb._begin_binding_block() @@ -68,8 +68,8 @@ def test_block_builder(): def test_emit_with_name(): m = tirx.Var("m", "int64") n = tirx.Var("n", "int64") - x = rx.Var("x", rx.TensorStructInfo([m, n], "float16")) - y = rx.Var("y", rx.TensorStructInfo([n], "float16")) + x = rx.Var("x", rx.TensorType([m, n], "float16")) + y = rx.Var("y", rx.TensorType([n], "float16")) bb = rx.BlockBuilder() bb._begin_dataflow_block() @@ -84,8 +84,8 @@ def test_emit_with_name(): def test_function_single_block(): m = tirx.Var("m", "int64") n = tirx.Var("n", "int64") - x = rx.Var("x", rx.TensorStructInfo([m, n], "float16")) - y = rx.Var("y", rx.TensorStructInfo([n], "float16")) + x = rx.Var("x", rx.TensorType([m, n], "float16")) + y = rx.Var("y", rx.TensorType([n], "float16")) bb = rx.BlockBuilder() with bb.function("func", [x, y]): @@ -102,7 +102,7 @@ def test_function_single_block(): assert func.params[0] == x assert func.params[1] == y assert func.body.body == gv0 - assert_structural_equal(gv0.struct_info, rx.TensorStructInfo([m, n], "float16")) + assert_structural_equal(gv0.ty, rx.TensorType([m, n], "float16")) assert len(func.body.blocks) == 1 assert len(func.body.blocks[0].bindings) == 3 @@ -110,8 +110,8 @@ def test_function_single_block(): def test_function_multi_blocks(): m = tirx.Var("m", "int64") n = tirx.Var("n", "int64") - x = rx.Var("x", rx.TensorStructInfo([m, n], "float16")) - y = rx.Var("y", rx.TensorStructInfo([n], "float16")) + x = rx.Var("x", rx.TensorType([m, n], "float16")) + y = rx.Var("y", rx.TensorType([n], "float16")) bb = rx.BlockBuilder() with bb.function("func", [x, y]): @@ -130,7 +130,7 @@ def test_function_multi_blocks(): func = bb.finalize()["func"] - assert_structural_equal(gv2.struct_info, rx.TensorStructInfo([m, n], "float16")) + assert_structural_equal(gv2.ty, rx.TensorType([m, n], "float16")) assert func.params[0] == x assert func.params[1] == y assert func.body.body == gv2 @@ -145,8 +145,8 @@ def test_multi_functions(): m_1 = tirx.Var("m", "int64") n_1 = tirx.Var("n", "int64") - x_1 = rx.Var("x", rx.TensorStructInfo([m_1, n_1], "float16")) - y_1 = rx.Var("y", rx.TensorStructInfo([n_1], "float16")) + x_1 = rx.Var("x", rx.TensorType([m_1, n_1], "float16")) + y_1 = rx.Var("y", rx.TensorType([n_1], "float16")) with bb.function("func1", [x_1, y_1]): with bb.dataflow(): @@ -157,8 +157,8 @@ def test_multi_functions(): m_2 = tirx.Var("m", "int64") n_2 = tirx.Var("n", "int64") - x_2 = rx.Var("x", rx.TensorStructInfo([m_2, n_2], "float16")) - y_2 = rx.Var("y", rx.TensorStructInfo([n_2], "float16")) + x_2 = rx.Var("x", rx.TensorType([m_2, n_2], "float16")) + y_2 = rx.Var("y", rx.TensorType([n_2], "float16")) with bb.function("func2", [x_2, y_2]): with bb.dataflow(): @@ -183,56 +183,56 @@ def test_binary_shape_type_deduction(): m = tirx.Var("m", "int64") n = tirx.Var("n", "int64") k = tirx.Var("k", "int64") - x = rx.Var("x", rx.TensorStructInfo([m, 1], "float16")) - y = rx.Var("y", rx.TensorStructInfo([n], "float16")) - z = rx.Var("z", rx.TensorStructInfo([5], "float16")) - w = rx.Var("w", rx.TensorStructInfo([k], "float16")) + x = rx.Var("x", rx.TensorType([m, 1], "float16")) + y = rx.Var("y", rx.TensorType([n], "float16")) + z = rx.Var("z", rx.TensorType([5], "float16")) + w = rx.Var("w", rx.TensorType([k], "float16")) bb = rx.BlockBuilder() with bb.function("func", [x, y, z, w]): with bb.dataflow(): lv0 = bb.emit(rx.op.add(x, y)) - assert_structural_equal(lv0.struct_info, rx.TensorStructInfo([m, n], "float16")) + assert_structural_equal(lv0.ty, rx.TensorType([m, n], "float16")) lv1 = bb.emit(rx.op.multiply(x, z)) - assert_structural_equal(lv1.struct_info, rx.TensorStructInfo([m, 5], "float16")) + assert_structural_equal(lv1.ty, rx.TensorType([m, 5], "float16")) lv2 = bb.emit(rx.op.multiply(z, w)) - assert isinstance(lv2.struct_info, rx.TensorStructInfo) - assert lv2.struct_info.ndim == 1 - assert lv2.struct_info.dtype == "float16" + assert isinstance(lv2.ty, rx.TensorType) + assert lv2.ty.ndim == 1 + assert lv2.ty.dtype == "float16" lv3 = bb.emit(rx.op.multiply(y, w)) - assert isinstance(lv3.struct_info, rx.TensorStructInfo) - assert lv3.struct_info.ndim == 1 - assert lv3.struct_info.dtype == "float16" + assert isinstance(lv3.ty, rx.TensorType) + assert lv3.ty.ndim == 1 + assert lv3.ty.dtype == "float16" gv0 = bb.emit_output(lv3) bb.emit_func_output(gv0) - assert isinstance(gv0.struct_info, rx.TensorStructInfo) - assert gv0.struct_info.ndim == 1 - assert gv0.struct_info.dtype == "float16" + assert isinstance(gv0.ty, rx.TensorType) + assert gv0.ty.ndim == 1 + assert gv0.ty.dtype == "float16" def test_emit_match_cast(): m = tirx.Var("m", dtype="int64") n = tirx.Var("n", dtype="int64") - x = rx.Var("tensor_value", rx.TensorStructInfo(dtype="float32", ndim=-1)) - y = rx.Var("shape_value", rx.ShapeStructInfo([16, 8])) + x = rx.Var("tensor_value", rx.TensorType(dtype="float32", ndim=-1)) + y = rx.Var("shape_value", rx.ShapeType([16, 8])) bb = rx.BlockBuilder() with bb.function("func", [x, y]): with bb.dataflow(): # lv0: Tensor((m, n), "float32") = # match_cast(x: Tensor(_, "float32"], [m, n)) - lv0 = bb.match_cast(x, rx.TensorStructInfo([m, n], "float32")) + lv0 = bb.match_cast(x, rx.TensorType([m, n], "float32")) assert isinstance(lv0, rx.DataflowVar) - assert_structural_equal(lv0.struct_info, rx.TensorStructInfo([m, n], "float32")) + assert_structural_equal(lv0.ty, rx.TensorType([m, n], "float32")) - # lv1: Shape = match_cast(shape, rx.ShapeStructInfo([m, n])) - lv1 = bb.match_cast(y, rx.ShapeStructInfo([m, n]), "var_name") - assert lv1.struct_info == rx.ShapeStructInfo([m, n]) + # lv1: Shape = match_cast(shape, rx.ShapeType([m, n])) + lv1 = bb.match_cast(y, rx.ShapeType([m, n]), "var_name") + assert lv1.ty == rx.ShapeType([m, n]) gv0 = bb.emit_output(lv1) bb.emit_func_output(gv0) @@ -243,11 +243,11 @@ def test_emit_match_cast(): assert isinstance(b1, rx.MatchCast) assert b0.value == x - assert b0.struct_info == rx.TensorStructInfo([m, n], "float32") + assert b0.ty == rx.TensorType([m, n], "float32") assert b0.var == lv0 assert b1.value == y - assert b1.struct_info == rx.ShapeStructInfo([m, n]) + assert b1.ty == rx.ShapeType([m, n]) assert b1.var == lv1 assert b1.var.name_hint == "var_name" @@ -255,10 +255,10 @@ def test_emit_match_cast(): def test_emit_match_cast_binding_in_dataflow_block(): bb = rx.BlockBuilder() - x = rx.Var("x", rx.TensorStructInfo(dtype="float32", ndim=-1)) + x = rx.Var("x", rx.TensorType(dtype="float32", ndim=-1)) m = tirx.Var("m", dtype="int64") - gv = rx.Var("gv", rx.TensorStructInfo(dtype="float32", ndim=-1)) - match_cast = rx.MatchCast(gv, x, rx.TensorStructInfo((m,), "float32")) + gv = rx.Var("gv", rx.TensorType(dtype="float32", ndim=-1)) + match_cast = rx.MatchCast(gv, x, rx.TensorType((m,), "float32")) with bb.function("main", [x]): with bb.dataflow(): @@ -272,8 +272,8 @@ def test_emit_match_cast_binding_in_dataflow_block(): assert isinstance(b0, rx.MatchCast) assert b0.value == x - assert isinstance(b0.struct_info, rx.TensorStructInfo) - assert b0.struct_info.shape[0] == m + assert isinstance(b0.ty, rx.TensorType) + assert b0.ty.shape[0] == m assert b0.var == gv @@ -281,8 +281,8 @@ def test_normalize(): m = tirx.Var("m", "int64") n = tirx.Var("n", "int64") - x = rx.Var("x", rx.TensorStructInfo([m, n], "float16")) - y = rx.Var("y", rx.TensorStructInfo([n], "float16")) + x = rx.Var("x", rx.TensorType([m, n], "float16")) + y = rx.Var("y", rx.TensorType([n], "float16")) bb = rx.BlockBuilder() # Call node @@ -298,47 +298,47 @@ def test_normalize(): # Tuple node tuple_1 = rx.Tuple([x, y]) bb.normalize(tuple_1) - assert isinstance(tuple_1.struct_info, rx.TupleStructInfo) - assert isinstance(tuple_1.struct_info.fields[0], rx.TensorStructInfo) - assert isinstance(tuple_1.struct_info.fields[1], rx.TensorStructInfo) + assert isinstance(tuple_1.ty, rx.TupleType) + assert isinstance(tuple_1.ty.fields[0], rx.TensorType) + assert isinstance(tuple_1.ty.fields[1], rx.TensorType) # Nested Tuple tuple_2 = rx.Tuple([x, rx.Tuple([x, y])]) bb.normalize(tuple_2) - assert isinstance(tuple_2.struct_info, rx.TupleStructInfo) - assert isinstance(tuple_2.struct_info.fields[0], rx.TensorStructInfo) - assert isinstance(tuple_2.struct_info.fields[1], rx.TupleStructInfo) - assert isinstance(tuple_2.struct_info.fields[1].fields[0], rx.TensorStructInfo) - assert isinstance(tuple_2.struct_info.fields[1].fields[1], rx.TensorStructInfo) + assert isinstance(tuple_2.ty, rx.TupleType) + assert isinstance(tuple_2.ty.fields[0], rx.TensorType) + assert isinstance(tuple_2.ty.fields[1], rx.TupleType) + assert isinstance(tuple_2.ty.fields[1].fields[0], rx.TensorType) + assert isinstance(tuple_2.ty.fields[1].fields[1], rx.TensorType) def test_tuple_indexing(): m = tirx.Var("m", "int64") n = tirx.Var("n", "int64") - shape_x = rx.TensorStructInfo([m, n], "float16") - shape_y = rx.TensorStructInfo([n], "float16") - relax_tuple = rx.Var("relax_tuple", rx.TupleStructInfo([shape_x, shape_y])) + shape_x = rx.TensorType([m, n], "float16") + shape_y = rx.TensorType([n], "float16") + relax_tuple = rx.Var("relax_tuple", rx.TupleType([shape_x, shape_y])) - assert isinstance(relax_tuple.struct_info, rx.TupleStructInfo) - assert isinstance(relax_tuple.struct_info.fields[0], rx.TensorStructInfo) - assert isinstance(relax_tuple.struct_info.fields[1], rx.TensorStructInfo) + assert isinstance(relax_tuple.ty, rx.TupleType) + assert isinstance(relax_tuple.ty.fields[0], rx.TensorType) + assert isinstance(relax_tuple.ty.fields[1], rx.TensorType) - # TupleGetItem will initialize struct info from the - # TupleStructInfo, if present. + # TupleGetItem will initialize type from the + # TupleType, if present. x = relax_tuple[0] - tvm.ir.assert_structural_equal(x.struct_info, shape_x) + tvm.ir.assert_structural_equal(x.ty, shape_x) y = relax_tuple[1] - tvm.ir.assert_structural_equal(y.struct_info, shape_y) + tvm.ir.assert_structural_equal(y.ty, shape_y) # Tuple unpacking produces TupleGetItem structs x_unpack, y_unpack = relax_tuple tvm.ir.assert_structural_equal(x, x_unpack) tvm.ir.assert_structural_equal(y, y_unpack) - # When TupleStructInfo is available, tuple unpacking fails immediately + # When TupleType is available, tuple unpacking fails immediately # for incorrect number of arguments. with pytest.raises(ValueError): x_unpack, y_unpack, z_unpack = relax_tuple @@ -347,9 +347,9 @@ def test_tuple_indexing(): def test_call_te(): bb = rx.BlockBuilder() n, m = tirx.Var("n", "int64"), tirx.Var("m", "int64") - x = rx.Var("x", rx.TensorStructInfo([n, m], "float32")) - y = rx.Var("y", rx.TensorStructInfo([n, m], "float32")) - z = rx.Var("z", rx.TensorStructInfo([n, m], "float32")) + x = rx.Var("x", rx.TensorType([n, m], "float32")) + y = rx.Var("y", rx.TensorType([n, m], "float32")) + z = rx.Var("z", rx.TensorType([n, m], "float32")) def te_func(args, args_dict, msg): A, B = args @@ -394,8 +394,8 @@ def test_call_te_unique_tensor_name(): def test_call_te_with_unsupported_shape_arg(): bb = rx.BlockBuilder() - x = rx.Var("x", rx.TensorStructInfo((200,), "float32")) - s = rx.Var("s", rx.ShapeStructInfo((200,))) + x = rx.Var("x", rx.TensorType((200,), "float32")) + s = rx.Var("s", rx.ShapeType((200,))) with pytest.raises(AssertionError): with bb.function("rx_func", [x]): @@ -406,9 +406,9 @@ def test_call_te_with_unsupported_shape_arg(): def test_emit_te(): bb = rx.BlockBuilder() n, m = tirx.Var("n", "int64"), tirx.Var("m", "int64") - x = rx.Var("x", rx.TensorStructInfo([n, m], "float32")) - y = rx.Var("y", rx.TensorStructInfo([n, m], "float32")) - z = rx.Var("z", rx.TensorStructInfo([n, m], "float32")) + x = rx.Var("x", rx.TensorType([n, m], "float32")) + y = rx.Var("y", rx.TensorType([n, m], "float32")) + z = rx.Var("z", rx.TensorType([n, m], "float32")) def te_func(args, args_dict, msg): A, B = args @@ -454,9 +454,9 @@ def get_tir_func(): def test_emit_te_multiple(): bb = rx.BlockBuilder() n, m = tirx.Var("n", "int64"), tirx.Var("m", "int64") - x = rx.Var("x", rx.TensorStructInfo([n, m], "float32")) - y = rx.Var("y", rx.TensorStructInfo([n, m], "float32")) - z = rx.Var("z", rx.TensorStructInfo([128, m], "float32")) + x = rx.Var("x", rx.TensorType([n, m], "float32")) + y = rx.Var("y", rx.TensorType([n, m], "float32")) + z = rx.Var("z", rx.TensorType([128, m], "float32")) def te_func(A): B = te.compute((128, 128), lambda i, j: A[i, j] + 1) @@ -486,7 +486,7 @@ def te_func(A): def test_emit_te_multiple_output(): bb = rx.BlockBuilder() n, m = tirx.Var("n", "int64"), tirx.Var("m", "int64") - x = rx.Var("x", rx.TensorStructInfo([n, m], "float32")) + x = rx.Var("x", rx.TensorType([n, m], "float32")) def te_func(A): B0, B1 = te.compute((n, m), lambda i, j: (A[i, j] + 1, A[i, j] * 2), name="B") @@ -503,17 +503,17 @@ def te_func(A): assert rx_func.params[0] == x call_node = rx_func.body.blocks[0].bindings[0].value assert call_node.args[0].name_hint == "te_func" - assert isinstance(call_node.sinfo_args[0], rx.TupleStructInfo) - assert len(call_node.sinfo_args[0].fields) == 2 - assert isinstance(call_node.sinfo_args[0].fields[0].shape, rx.ShapeExpr) - assert isinstance(call_node.sinfo_args[0].fields[1].shape, rx.ShapeExpr) + assert isinstance(call_node.ty_args[0], rx.TupleType) + assert len(call_node.ty_args[0].fields) == 2 + assert isinstance(call_node.ty_args[0].fields[0].shape, rx.ShapeExpr) + assert isinstance(call_node.ty_args[0].fields[1].shape, rx.ShapeExpr) def test_emit_te_extern(): bb = rx.BlockBuilder() n, m = tirx.Var("n", "int64"), tirx.Var("m", "int64") - x = rx.Var("x", rx.TensorStructInfo([n, m], "float32")) - y = rx.Var("y", rx.TensorStructInfo([m, n], "float32")) + x = rx.Var("x", rx.TensorType([n, m], "float32")) + y = rx.Var("y", rx.TensorType([m, n], "float32")) with bb.function("rx_cblas_matmul", [x, y]): out = bb.emit_te(tvm.contrib.cblas.matmul, x, y, transa=False, transb=False) @@ -532,8 +532,8 @@ def test_emit_te_extern(): assert call_node.args[0].name_hint == "matmul" assert call_node.args[1][0] == x assert call_node.args[1][1] == y - assert call_node.sinfo_args[0].shape[0] == n - assert call_node.sinfo_args[0].shape[1] == n + assert call_node.ty_args[0].shape[0] == n + assert call_node.ty_args[0].shape[1] == n def test_emit_te_prim_value(): @@ -561,8 +561,8 @@ def test_emit_te_prim_value(): def test_nested_function_fail(): m = tirx.Var("m", "int64") n = tirx.Var("n", "int64") - x = rx.Var("x", rx.TensorStructInfo([m, n], "float16")) - y = rx.Var("y", rx.TensorStructInfo([n], "float16")) + x = rx.Var("x", rx.TensorType([m, n], "float16")) + y = rx.Var("y", rx.TensorType([n], "float16")) bb = rx.BlockBuilder() with pytest.raises(RuntimeError): @@ -576,8 +576,8 @@ def test_nested_function_fail(): def test_emit_func_output_twice_fail(): m = tirx.Var("m", "int64") n = tirx.Var("n", "int64") - x = rx.Var("x", rx.TensorStructInfo([m, n], "float16")) - y = rx.Var("y", rx.TensorStructInfo([n], "float16")) + x = rx.Var("x", rx.TensorType([m, n], "float16")) + y = rx.Var("y", rx.TensorType([n], "float16")) bb = rx.BlockBuilder() with pytest.raises(RuntimeError): @@ -590,8 +590,8 @@ def test_emit_func_output_twice_fail(): def test_func_params_twice_fail(): m = tirx.Var("m", "int64") n = tirx.Var("n", "int64") - x = rx.Var("x", rx.TensorStructInfo([m, n], "float16")) - y = rx.Var("y", rx.TensorStructInfo([n], "float16")) + x = rx.Var("x", rx.TensorType([m, n], "float16")) + y = rx.Var("y", rx.TensorType([n], "float16")) bb = rx.BlockBuilder() with pytest.raises(RuntimeError): @@ -603,8 +603,8 @@ def test_func_params_twice_fail(): def test_no_func_params_fail(): m = tirx.Var("m", "int64") n = tirx.Var("n", "int64") - x = rx.Var("x", rx.TensorStructInfo([m, n], "float16")) - y = rx.Var("y", rx.TensorStructInfo([n], "float16")) + x = rx.Var("x", rx.TensorType([m, n], "float16")) + y = rx.Var("y", rx.TensorType([n], "float16")) bb = rx.BlockBuilder() with pytest.raises(RuntimeError): @@ -617,8 +617,8 @@ def test_block_builder_scope_recovery(): bb = rx.BlockBuilder() n, m = tirx.Var("n", "int64"), tirx.Var("m", "int64") - x = rx.Var("x", rx.TensorStructInfo([n, m], "float32")) - y = rx.Var("y", rx.TensorStructInfo([m, n], "float32")) + x = rx.Var("x", rx.TensorType([n, m], "float32")) + y = rx.Var("y", rx.TensorType([m, n], "float32")) with pytest.raises(RuntimeError): # this line fails @@ -643,10 +643,10 @@ def make_function(emit_nested_tuple: bool): n_sym = tirx.Var("n", "int64") m_sym = tirx.Var("m", "int64") - n = rx.Var("n", rx.PrimStructInfo(value=n_sym)) - m = rx.Var("m", rx.PrimStructInfo(value=m_sym)) - x = rx.Var("x", rx.TensorStructInfo([n_sym, m_sym], "float32")) - y = rx.Var("y", rx.TensorStructInfo([m_sym, n_sym], "float32")) + n = rx.Var("n", rx.PrimType(value=n_sym)) + m = rx.Var("m", rx.PrimType(value=m_sym)) + x = rx.Var("x", rx.TensorType([n_sym, m_sym], "float32")) + y = rx.Var("y", rx.TensorType([m_sym, n_sym], "float32")) with bb.function("func", [n, m, x, y]): scalars = (n, m) @@ -902,7 +902,7 @@ def test_error_when_unwrapping_dataflowvar(): """ bb = rx.BlockBuilder() - lhs = rx.Var("a", rx.TensorStructInfo(shape=[], dtype="int64")) + lhs = rx.Var("a", rx.TensorType(shape=[], dtype="int64")) with bb.function("func", [lhs]): rhs = rx.const(2, "int64") diff --git a/tests/python/relax/test_blockbuilder_emit_te.py b/tests/python/relax/test_blockbuilder_emit_te.py index f314f45aaf62..0ca90e5a8b4d 100644 --- a/tests/python/relax/test_blockbuilder_emit_te.py +++ b/tests/python/relax/test_blockbuilder_emit_te.py @@ -66,7 +66,7 @@ def main(x: R.Tensor((10,), dtype="float32"), y: R.Shape(["m"])) -> R.Tensor( gv = R.call_tir( cls.te_func, (x,), - out_sinfo=R.Tensor((10,), dtype="float32"), + out_ty=R.Tensor((10,), dtype="float32"), tir_vars=R.shape([m]), ) return gv @@ -120,7 +120,7 @@ def main( cls.te_slice, A, tir_vars=[row_index], - out_sinfo=R.Tensor([16], "float32"), + out_ty=R.Tensor([16], "float32"), ) return gv diff --git a/tests/python/relax/test_codegen_coreml.py b/tests/python/relax/test_codegen_coreml.py index 95a3b32c8c24..4c07751bc76a 100644 --- a/tests/python/relax/test_codegen_coreml.py +++ b/tests/python/relax/test_codegen_coreml.py @@ -46,8 +46,8 @@ def _has_xcode(): def test_partition_for_coreml_uses_current_relax_passes(): from tvm.relax.backend.metal.coreml import partition_for_coreml - x = relax.Var("x", relax.TensorStructInfo([10, 10], "float32")) - y = relax.Var("y", relax.TensorStructInfo([10, 10], "float32")) + x = relax.Var("x", relax.TensorType([10, 10], "float32")) + y = relax.Var("y", relax.TensorType([10, 10], "float32")) bb = relax.BlockBuilder() with bb.function("main", [x, y]): with bb.dataflow(): @@ -91,8 +91,8 @@ def verify(mod, inputs): @requires_coreml_runtime def test_add(): - x = relax.Var("x", relax.TensorStructInfo([10, 10], "float32")) - y = relax.Var("y", relax.TensorStructInfo([10, 10], "float32")) + x = relax.Var("x", relax.TensorType([10, 10], "float32")) + y = relax.Var("y", relax.TensorType([10, 10], "float32")) bb = relax.BlockBuilder() with bb.function("main", [x, y]): with bb.dataflow(): @@ -107,7 +107,7 @@ def test_add(): @requires_coreml_runtime def test_add_const(): - x = relax.Var("x", relax.TensorStructInfo([10, 10], "float32")) + x = relax.Var("x", relax.TensorType([10, 10], "float32")) y = relax.const(np.ones([10, 10]), "float32") bb = relax.BlockBuilder() with bb.function("main", [x]): @@ -122,8 +122,8 @@ def test_add_const(): @requires_coreml_runtime def test_multiply(): - x = relax.Var("x", relax.TensorStructInfo([10, 10], "float32")) - y = relax.Var("y", relax.TensorStructInfo([10, 10], "float32")) + x = relax.Var("x", relax.TensorType([10, 10], "float32")) + y = relax.Var("y", relax.TensorType([10, 10], "float32")) bb = relax.BlockBuilder() with bb.function("main", [x, y]): with bb.dataflow(): @@ -139,7 +139,7 @@ def test_multiply(): @requires_coreml_runtime def test_matmul(): - x = relax.Var("x", relax.TensorStructInfo([8, 10], "float32")) + x = relax.Var("x", relax.TensorType([8, 10], "float32")) y = relax.Constant(tvm.runtime.tensor(np.random.rand(10, 8).astype("float32"), dev)) bb = relax.BlockBuilder() with bb.function("main", [x]): @@ -152,8 +152,8 @@ def test_matmul(): x_data = tvm.runtime.tensor(np.random.rand(8, 10).astype("float32"), dev) verify(mod, [x_data]) - x = relax.Var("x", relax.TensorStructInfo([8, 10], "float32")) - y = relax.Var("y", relax.TensorStructInfo([10, 8], "float32")) + x = relax.Var("x", relax.TensorType([8, 10], "float32")) + y = relax.Var("y", relax.TensorType([10, 8], "float32")) bb = relax.BlockBuilder() with bb.function("main", [x, y]): with bb.dataflow(): @@ -169,7 +169,7 @@ def test_matmul(): @requires_coreml_runtime def test_clip(): - x = relax.Var("x", relax.TensorStructInfo([10, 10], "float32")) + x = relax.Var("x", relax.TensorType([10, 10], "float32")) bb = relax.BlockBuilder() with bb.function("main", [x]): @@ -182,7 +182,7 @@ def test_clip(): x_data = tvm.runtime.tensor(np.random.rand(10, 10).astype("float32"), dev) verify(mod, [x_data]) - x = relax.Var("x", relax.TensorStructInfo([10, 10], "float32")) + x = relax.Var("x", relax.TensorType([10, 10], "float32")) bb = relax.BlockBuilder() with bb.function("main", [x]): @@ -200,7 +200,7 @@ def test_clip(): @requires_coreml_runtime def test_expand_dims(): def get_mod(axis): - x = relax.Var("x", relax.TensorStructInfo([10, 10], "float32")) + x = relax.Var("x", relax.TensorType([10, 10], "float32")) bb = relax.BlockBuilder() with bb.function("main", [x]): with bb.dataflow(): @@ -216,7 +216,7 @@ def get_mod(axis): @requires_coreml_runtime def test_relu(): - x = relax.Var("x", relax.TensorStructInfo([10, 10], "float32")) + x = relax.Var("x", relax.TensorType([10, 10], "float32")) bb = relax.BlockBuilder() with bb.function("main", [x]): with bb.dataflow(): @@ -231,7 +231,7 @@ def test_relu(): @requires_coreml_runtime def test_batch_flatten(): - x = relax.Var("x", relax.TensorStructInfo([10, 10, 10], "float32")) + x = relax.Var("x", relax.TensorType([10, 10, 10], "float32")) bb = relax.BlockBuilder() with bb.function("main", [x]): with bb.dataflow(): @@ -246,7 +246,7 @@ def test_batch_flatten(): @requires_coreml_runtime def test_softmax(): - x = relax.Var("x", relax.TensorStructInfo([10, 10], "float32")) + x = relax.Var("x", relax.TensorType([10, 10], "float32")) bb = relax.BlockBuilder() with bb.function("main", [x]): with bb.dataflow(): @@ -261,7 +261,7 @@ def test_softmax(): @requires_coreml_runtime def test_conv2d(): - x = relax.Var("x", relax.TensorStructInfo([1, 3, 224, 224], "float32")) + x = relax.Var("x", relax.TensorType([1, 3, 224, 224], "float32")) w = relax.const(np.zeros((16, 3, 3, 3), dtype="float32")) bb = relax.BlockBuilder() with bb.function("main", [x]): @@ -276,7 +276,7 @@ def test_conv2d(): @requires_coreml_runtime def test_global_avg_pool2d(): - x = relax.Var("x", relax.TensorStructInfo([1, 1, 10, 10], "float32")) + x = relax.Var("x", relax.TensorType([1, 1, 10, 10], "float32")) bb = relax.BlockBuilder() with bb.function("main", [x]): with bb.dataflow(): @@ -290,8 +290,8 @@ def test_global_avg_pool2d(): @requires_coreml_runtime def test_subgraph1(): - x = relax.Var("x", relax.TensorStructInfo([10, 10], "float32")) - y = relax.Var("y", relax.TensorStructInfo([10, 10], "float32")) + x = relax.Var("x", relax.TensorType([10, 10], "float32")) + y = relax.Var("y", relax.TensorType([10, 10], "float32")) bb = relax.BlockBuilder() with bb.function("main", [x, y]): with bb.dataflow(): @@ -307,8 +307,8 @@ def test_subgraph1(): @requires_coreml_runtime def test_subgraph2(): - x = relax.Var("x", relax.TensorStructInfo([10, 10], "float32")) - y = relax.Var("y", relax.TensorStructInfo([10, 10], "float32")) + x = relax.Var("x", relax.TensorType([10, 10], "float32")) + y = relax.Var("y", relax.TensorType([10, 10], "float32")) bb = relax.BlockBuilder() with bb.function("main", [x, y]): with bb.dataflow(): diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py index cc6fd499fa14..865461d0c953 100644 --- a/tests/python/relax/test_codegen_cutlass.py +++ b/tests/python/relax/test_codegen_cutlass.py @@ -1376,7 +1376,7 @@ def main_bias( lv = R.call_tir( cls.encode, (y,), - out_sinfo=[R.Tensor((64, 64), dtype="int8"), R.Tensor((128,), dtype="float16")], + out_ty=[R.Tensor((64, 64), dtype="int8"), R.Tensor((128,), dtype="float16")], ) lv1 = lv[0] lv2 = R.call_pure_packed( @@ -1384,11 +1384,11 @@ def main_bias( lv1, 80, True, - sinfo_args=(R.Tensor((64, 64), dtype="int8"),), + ty_args=(R.Tensor((64, 64), dtype="int8"),), ) lv3: R.Tensor((128,), dtype="float16") = lv[1] lv6 = R.call_tir( - cls.decode, (lv2, lv3), out_sinfo=R.Tensor((64, 128), dtype="float16") + cls.decode, (lv2, lv3), out_ty=R.Tensor((64, 128), dtype="float16") ) lv1_1: R.Tensor((64, 128), dtype="float16") = R.matmul(x, lv6, out_dtype="float16") lv2_1: R.Tensor((64, 128), dtype="float16") = R.add(lv1_1, bias) @@ -1407,7 +1407,7 @@ def main_cast_bias( lv = R.call_tir( cls.encode, (y,), - out_sinfo=[R.Tensor((64, 64), dtype="int8"), R.Tensor((128,), dtype="float16")], + out_ty=[R.Tensor((64, 64), dtype="int8"), R.Tensor((128,), dtype="float16")], ) lv1 = lv[0] lv2 = R.call_pure_packed( @@ -1415,11 +1415,11 @@ def main_cast_bias( lv1, 80, True, - sinfo_args=(R.Tensor((64, 64), dtype="int8"),), + ty_args=(R.Tensor((64, 64), dtype="int8"),), ) lv3: R.Tensor((128,), dtype="float16") = lv[1] lv6 = R.call_tir( - cls.decode, (lv2, lv3), out_sinfo=R.Tensor((64, 128), dtype="float16") + cls.decode, (lv2, lv3), out_ty=R.Tensor((64, 128), dtype="float16") ) lv1_1: R.Tensor((64, 128), dtype="float32") = R.matmul(x, lv6, out_dtype="float32") cast: R.Tensor((64, 128), dtype="float16") = R.astype(lv1_1, dtype="float16") @@ -1440,7 +1440,7 @@ def main_residual( lv = R.call_tir( cls.encode, (y,), - out_sinfo=[R.Tensor((64, 64), dtype="int8"), R.Tensor((128,), dtype="float16")], + out_ty=[R.Tensor((64, 64), dtype="int8"), R.Tensor((128,), dtype="float16")], ) lv1 = lv[0] lv2 = R.call_pure_packed( @@ -1448,11 +1448,11 @@ def main_residual( lv1, 80, True, - sinfo_args=(R.Tensor((64, 64), dtype="int8"),), + ty_args=(R.Tensor((64, 64), dtype="int8"),), ) lv3: R.Tensor((128,), dtype="float16") = lv[1] lv6 = R.call_tir( - cls.decode, (lv2, lv3), out_sinfo=R.Tensor((64, 128), dtype="float16") + cls.decode, (lv2, lv3), out_ty=R.Tensor((64, 128), dtype="float16") ) lv1_1: R.Tensor((64, 128), dtype="float16") = R.matmul(x, lv6, out_dtype="float16") lv2_1: R.Tensor((64, 128), dtype="float16") = R.add(lv1_1, bias) @@ -1595,7 +1595,7 @@ def main( lv = R.call_tir( cls.encode, (y,), - out_sinfo=[R.Tensor((64, 64), dtype="int8"), R.Tensor((64,), dtype="float16")], + out_ty=[R.Tensor((64, 64), dtype="int8"), R.Tensor((64,), dtype="float16")], ) lv1: R.Tensor((64, 64), dtype="int8") = lv[0] lv2: R.Tensor((64, 64), dtype="int8") = R.call_pure_packed( @@ -1603,14 +1603,12 @@ def main( lv1, R.prim_value(80), R.prim_value(0), - sinfo_args=(R.Tensor((64, 64), dtype="int8"),), + ty_args=(R.Tensor((64, 64), dtype="int8"),), ) lv3: R.Tensor((64,), dtype="float16") = lv[1] lv4: R.Tensor((64, 64), dtype="int8") = R.builtin.stop_lift_params(lv2) lv5: R.Tensor((64,), dtype="float16") = R.builtin.stop_lift_params(lv3) - lv6 = R.call_tir( - cls.decode, (lv4, lv5), out_sinfo=R.Tensor((64, 64), dtype="float16") - ) + lv6 = R.call_tir(cls.decode, (lv4, lv5), out_ty=R.Tensor((64, 64), dtype="float16")) lv1_1: R.Tensor((64, 64), dtype="float16") = R.matmul(x, lv6, out_dtype="float16") lv2_1: R.Tensor((64, 128), dtype="float16") = R.add(lv1_1, bias) lv2_2: R.Tensor((64, 128), dtype="float16") = R.nn.gelu(lv2_1) @@ -1710,7 +1708,7 @@ def main( cls = Module with R.dataflow(): lv = R.call_tir( - cls.rms_norm, (input, weight), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16") + cls.rms_norm, (input, weight), out_ty=R.Tensor((1, 1, 4096), dtype="float16") ) R.output(lv) return lv @@ -1874,7 +1872,7 @@ def main( lv = R.call_tir( cls.encode, (y,), - out_sinfo=[R.Tensor((64, 64), dtype="int8"), R.Tensor((64,), dtype="float16")], + out_ty=[R.Tensor((64, 64), dtype="int8"), R.Tensor((64,), dtype="float16")], ) lv1: R.Tensor((64, 64), dtype="int8") = lv[0] lv2: R.Tensor((64, 64), dtype="int8") = R.call_pure_packed( @@ -1882,14 +1880,12 @@ def main( lv1, R.prim_value(80), R.prim_value(0), - sinfo_args=(R.Tensor((64, 64), dtype="int8"),), + ty_args=(R.Tensor((64, 64), dtype="int8"),), ) lv3: R.Tensor((64,), dtype="float16") = lv[1] lv4: R.Tensor((64, 64), dtype="int8") = R.builtin.stop_lift_params(lv2) lv5: R.Tensor((64,), dtype="float16") = R.builtin.stop_lift_params(lv3) - lv6 = R.call_tir( - cls.decode, (lv4, lv5), out_sinfo=R.Tensor((64, 64), dtype="float16") - ) + lv6 = R.call_tir(cls.decode, (lv4, lv5), out_ty=R.Tensor((64, 64), dtype="float16")) lv1_1: R.Tensor((b, 64, 64), dtype="float16") = R.matmul( x, lv6, out_dtype="float16" ) @@ -2026,7 +2022,7 @@ def main( lv = R.call_tir( cls.encode, (y,), - out_sinfo=[ + out_ty=[ R.Tensor((128, 128), dtype="int8"), R.Tensor((2, 128), dtype="float16"), ], @@ -2037,13 +2033,13 @@ def main( lv1, R.prim_value(80), R.prim_value(0), - sinfo_args=(R.Tensor((128, 128), dtype="int8"),), + ty_args=(R.Tensor((128, 128), dtype="int8"),), ) lv3: R.Tensor((2, 128), dtype="float16") = lv[1] lv4: R.Tensor((128, 128), dtype="int8") = R.builtin.stop_lift_params(lv2) lv5: R.Tensor((2, 128), dtype="float16") = R.builtin.stop_lift_params(lv3) lv6 = R.call_tir( - cls.decode, (lv4, lv5), out_sinfo=R.Tensor((128, 128), dtype="float16") + cls.decode, (lv4, lv5), out_ty=R.Tensor((128, 128), dtype="float16") ) lv1_1: R.Tensor((b, 128, 128), dtype="float16") = R.matmul( x, lv6, out_dtype="float16" @@ -2231,7 +2227,7 @@ def main( # TODO(masahi): Workaround for the broken Relax cumsum op on GPU. # https://github.com/apache/tvm/issues/15851 cumsum = R.call_dps_packed( - "tvm.contrib.thrust.sum_scan", seq_lens, out_sinfo=seq_lens.struct_info + "tvm.contrib.thrust.sum_scan", seq_lens, out_ty=seq_lens.ty ) max_seqlen_q = R.to_vdevice(R.max(seq_lens), "llvm:0") seqstart_q = R.concat([R.zeros((1,), "int32"), cumsum]) @@ -2284,7 +2280,7 @@ def main( # TODO(masahi): Workaround for the broken Relax cumsum op on GPU. # https://github.com/apache/tvm/issues/15851 cumsum = R.call_dps_packed( - "tvm.contrib.thrust.sum_scan", seq_lens, out_sinfo=seq_lens.struct_info + "tvm.contrib.thrust.sum_scan", seq_lens, out_ty=seq_lens.ty ) max_seqlen_q = R.to_vdevice(R.max(seq_lens), "llvm:0") seqstart_q = R.concat([R.zeros((1,), "int32"), cumsum]) @@ -2379,7 +2375,7 @@ def main( # TODO(masahi): Workaround for the broken Relax cumsum op on GPU. # https://github.com/apache/tvm/issues/15851 cumsum = R.call_dps_packed( - "tvm.contrib.thrust.sum_scan", seq_lens, out_sinfo=seq_lens.struct_info + "tvm.contrib.thrust.sum_scan", seq_lens, out_ty=seq_lens.ty ) max_seqlen_q = R.to_vdevice(R.max(seq_lens), "llvm:0") seqstart_q = R.concat([R.zeros((1,), "int32"), cumsum]) diff --git a/tests/python/relax/test_codegen_tensorrt.py b/tests/python/relax/test_codegen_tensorrt.py index 1afe4cc17438..d51fae0f5a5e 100644 --- a/tests/python/relax/test_codegen_tensorrt.py +++ b/tests/python/relax/test_codegen_tensorrt.py @@ -483,7 +483,7 @@ def test_tensorrt_layout_transform(): # strings); the codegen translates a pure-permutation index map into a transpose. Built with the # BlockBuilder because the index_map lambda cannot be expressed in TVMScript. bb = relax.BlockBuilder() - data = relax.Var("data", relax.TensorStructInfo((1, 4, 8, 8), "float32")) + data = relax.Var("data", relax.TensorType((1, 4, 8, 8), "float32")) with bb.function("main", [data]): with bb.dataflow(): out = bb.emit( diff --git a/tests/python/relax/test_contrib_vllm.py b/tests/python/relax/test_contrib_vllm.py index fc97859ee908..d0ced83764c5 100644 --- a/tests/python/relax/test_contrib_vllm.py +++ b/tests/python/relax/test_contrib_vllm.py @@ -95,7 +95,7 @@ def main( 16, max_len, ], - out_sinfo=query.struct_info, + out_ty=query.ty, ) R.output(out) return out @@ -140,7 +140,7 @@ def main( max_logits, tmp_out, ], - out_sinfo=query.struct_info, + out_ty=query.ty, ) R.output(out) return out @@ -367,7 +367,7 @@ def main( key_cache, value_cache, slot_mapping, - sinfo_args=[key_cache.struct_info, value_cache.struct_info], + ty_args=[key_cache.ty, value_cache.ty], ) out = (kv[0], kv[1]) R.output(out) diff --git a/tests/python/relax/test_dataflow_inplace.py b/tests/python/relax/test_dataflow_inplace.py index 1b23e1448242..1184cd5b402f 100644 --- a/tests/python/relax/test_dataflow_inplace.py +++ b/tests/python/relax/test_dataflow_inplace.py @@ -205,11 +205,11 @@ def tir_id2(x: T.handle, y: T.handle, z: T.handle) -> None: def main(x: R.Tensor((10, 10), "int32")) -> R.Tensor((10, 10), "int32"): with R.dataflow(): cls = AliasCallTir - y = R.call_tir(cls.tir_id, (x,), out_sinfo=R.Tensor((10, 10), "int32")) + y = R.call_tir(cls.tir_id, (x,), out_ty=R.Tensor((10, 10), "int32")) t = R.call_tir( cls.tir_id2, (y,), - out_sinfo=[R.Tensor((10, 10), "int32"), R.Tensor((10, 10), "int32")], + out_ty=[R.Tensor((10, 10), "int32"), R.Tensor((10, 10), "int32")], ) z = y p = t[0] @@ -260,7 +260,7 @@ def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): n = R.const(2, dtype="int32") t = (m, n) a = R.call_pure_packed( - "chaos", t, sinfo_args=R.Tuple(R.Tensor((), "int32"), R.Tensor((), "int32")) + "chaos", t, ty_args=R.Tuple(R.Tensor((), "int32"), R.Tensor((), "int32")) ) b = a[0] c = a[1] @@ -496,7 +496,7 @@ def main( cls.add_inplace, (z, y), inplace_indices=[0], - out_sinfo=[ + out_ty=[ R.Tensor((2, 3), dtype="float32"), ], ) @@ -504,7 +504,7 @@ def main( cls.multiply_inplace, (a, y), inplace_indices=[0], - out_sinfo=[ + out_ty=[ R.Tensor((2, 3), dtype="float32"), ], ) @@ -513,7 +513,7 @@ def main( cls.subtract_inplace, (r, r), inplace_indices=[1], - out_sinfo=[ + out_ty=[ R.Tensor((1, 3), dtype="float32"), ], ) @@ -521,7 +521,7 @@ def main( cls.multiply_inplace, (q, s), inplace_indices=[0], - out_sinfo=[ + out_ty=[ R.Tensor((2, 3), dtype="float32"), ], ) @@ -602,13 +602,13 @@ def main( a_1 = R.call_tir_inplace( cls.add_inplace, (z, y), - out_sinfo=R.Tensor((a, b), dtype="float32"), + out_ty=R.Tensor((a, b), dtype="float32"), inplace_indices=[0], ) s = R.call_tir_inplace( cls.subtract_inplace, (a_1, a_1), - out_sinfo=R.Tensor((a, b), dtype="float32"), + out_ty=R.Tensor((a, b), dtype="float32"), inplace_indices=[1], ) R.output(s) @@ -911,24 +911,24 @@ def _concat_axis_for_view_op(op): @classmethod def _build_module(cls, op): if op == "relax.expand_dims": - x_sinfo = relax.TensorStructInfo((4,), "float32") + x_ty = relax.TensorType((4,), "float32") elif op == "relax.squeeze": - x_sinfo = relax.TensorStructInfo((1, 4, 1), "float32") + x_ty = relax.TensorType((1, 4, 1), "float32") elif op == "relax.reshape": - x_sinfo = relax.TensorStructInfo((4,), "float32") + x_ty = relax.TensorType((4,), "float32") elif op == "relax.permute_dims": - x_sinfo = relax.TensorStructInfo((1, 4), "float32") + x_ty = relax.TensorType((1, 4), "float32") elif op == "relax.memory.view": - x_sinfo = relax.TensorStructInfo((4,), "float32") + x_ty = relax.TensorType((4,), "float32") elif op == "relax.memory.ensure_zero_offset": - x_sinfo = relax.TensorStructInfo((4, 1), "float32") + x_ty = relax.TensorType((4, 1), "float32") elif op in ("relax.flatten", "relax.nn.batch_flatten"): - x_sinfo = relax.TensorStructInfo((1, 4), "float32") + x_ty = relax.TensorType((1, 4), "float32") else: raise ValueError(op) bb = relax.BlockBuilder() - x = relax.Var("x", x_sinfo) + x = relax.Var("x", x_ty) concat_axis = cls._concat_axis_for_view_op(op) with bb.function("main", [x]): with bb.dataflow(): diff --git a/tests/python/relax/test_dataflow_pattern.py b/tests/python/relax/test_dataflow_pattern.py index 303557e81a63..80acc6c65efc 100644 --- a/tests/python/relax/test_dataflow_pattern.py +++ b/tests/python/relax/test_dataflow_pattern.py @@ -155,10 +155,8 @@ def test_function_pattern(): assert isinstance(f.body.args[1], WildcardPattern) x = rx.Var("x", R.Tensor("float32")) y = rx.Var("y", R.Tensor("float32")) - assert f.match(rx.Function([x, y], rx.op.add(x, y), ret_struct_info=R.Tensor("float32"))) - assert not f.match( - rx.Function([x, y], rx.op.multiply(x, y), ret_struct_info=R.Tensor("float32")) - ) + assert f.match(rx.Function([x, y], rx.op.add(x, y), ret_ty=R.Tensor("float32"))) + assert not f.match(rx.Function([x, y], rx.op.multiply(x, y), ret_ty=R.Tensor("float32"))) def test_tuple_pattern(): @@ -286,7 +284,7 @@ def test_op_attr(): def test_match_call_attr(): x = rx.Var("x", R.Tensor("float32")) y = rx.Var("y", R.Tensor("float32")) - fn = rx.Function([x, y], rx.op.add(x, y), ret_struct_info=R.Tensor("float32")) + fn = rx.Function([x, y], rx.op.add(x, y), ret_ty=R.Tensor("float32")) annotated_fn = fn.with_attr({"Codegen": "test-codegen", "global_symbol": "test-symbol"}) xp = is_var("x") yp = is_var("y") @@ -314,7 +312,7 @@ def test_is_call_tir(): def simple_call_packed( x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32") ) -> R.Tensor: - gv0 = R.call_packed("test.vm.mul", x, w, sinfo_args=(R.Tensor(ndim=2, dtype="float32"))) + gv0 = R.call_packed("test.vm.mul", x, w, ty_args=(R.Tensor(ndim=2, dtype="float32"))) return gv0 @@ -1065,7 +1063,7 @@ def qkv_proj_rewriter(matchings, _): Q_weight = matchings[Q_weight_pat] K_weight = matchings[K_weight_pat] V_weight = matchings[V_weight_pat] - width = Q_weight.struct_info.shape[1] + width = Q_weight.ty.shape[1] concat = R.concat([Q_weight, K_weight, V_weight], axis=1) matmul = R.matmul(inp, concat) @@ -1315,7 +1313,7 @@ def rewriter(matchings, _): concat = R.concat([w1, w2], axis=0) matmul = R.matmul(inp, R.permute_dims(concat)) - sections = [w1.struct_info.shape[0]] + sections = [w1.ty.shape[0]] chunks = R.split(matmul, sections, -1) @@ -1479,7 +1477,7 @@ def rewriter(expr, matches): arg = matches[pattern_arg] shape_expr = matches[pattern_shape_expr] - if tvm_ffi.structural_equal(arg.struct_info.shape, shape_expr): + if tvm_ffi.structural_equal(arg.ty.shape, shape_expr): return arg else: return expr @@ -1614,7 +1612,7 @@ def rewriter(expr, matches): strides = matches[pattern_strides] strided_slice = matches[pattern] - if arg.struct_info.shape is None: + if arg.ty.shape is None: return expr if len(axes) != 1: @@ -1628,7 +1626,7 @@ def rewriter(expr, matches): if stride != 1: return expr - size = arg.struct_info.shape[0] + size = arg.ty.shape[0] if ( isinstance(size, tirx.IntImm) and isinstance(begin, tirx.IntImm) @@ -1756,9 +1754,7 @@ def rewriter(expr, matches): if pat_unwrap_concat_split in matches: args = matches[pat_args] - if len(args) == 2 and tvm_ffi.structural_equal( - args[0].struct_info, args[1].struct_info - ): + if len(args) == 2 and tvm_ffi.structural_equal(args[0].ty, args[1].ty): return args elif pat_add_self in matches: @@ -1771,11 +1767,11 @@ def rewriter(expr, matches): tvm.ir.assert_structural_equal(expected, after) -def test_wildcard_with_struct_info_updates_when_matching(): - """A DFPattern may be restricted to a specific StructInfo""" +def test_wildcard_with_ty_updates_when_matching(): + """A DFPattern may be restricted to a specific Type""" - pat_lhs = wildcard().has_struct_info(R.Tensor([2, 3])) - pat_rhs = wildcard().has_struct_info(R.Tensor([2, 3])) + pat_lhs = wildcard().has_ty(R.Tensor([2, 3])) + pat_rhs = wildcard().has_ty(R.Tensor([2, 3])) pat = is_op("relax.add")(pat_lhs, pat_rhs) def rewriter(expr, matches): @@ -1807,15 +1803,15 @@ def expected(): tvm.ir.assert_structural_equal(expected, after) -def test_wildcard_with_struct_info_is_no_op_when_not_matching(): - """StructInfoPattern requires the StructInfo provided +def test_wildcard_with_ty_is_no_op_when_not_matching(): + """TypePattern requires the Type provided Here, the pattern would match, expect that the function has `R.Tensor([16,32])`, and the pattern requires `R.Tensor([2,3])`. """ - pat_lhs = wildcard().has_struct_info(R.Tensor([2, 3])) - pat_rhs = wildcard().has_struct_info(R.Tensor([2, 3])) + pat_lhs = wildcard().has_ty(R.Tensor([2, 3])) + pat_rhs = wildcard().has_ty(R.Tensor([2, 3])) pat = is_op("relax.add")(pat_lhs, pat_rhs) def rewriter(expr, matches): @@ -1841,11 +1837,11 @@ def before(): tvm.ir.assert_structural_equal(expected, after) -def test_wildcard_struct_info_for_unknown_dtype(): - """TensorStructInfo with unknown dtype allows any dtype""" +def test_wildcard_ty_for_unknown_dtype(): + """TensorType with unknown dtype allows any dtype""" - pat_lhs = wildcard().has_struct_info(R.Tensor([2, 3])) - pat_rhs = wildcard().has_struct_info(R.Tensor([2, 3])) + pat_lhs = wildcard().has_ty(R.Tensor([2, 3])) + pat_rhs = wildcard().has_ty(R.Tensor([2, 3])) pat = is_op("relax.add")(pat_lhs, pat_rhs) def rewriter(expr, matches): @@ -1887,8 +1883,8 @@ def expected(): tvm.ir.assert_structural_equal(expected, after) -def test_wildcard_struct_info_with_symbolic_vars(): - """StructInfoPattern may define symbolic vars +def test_wildcard_ty_with_symbolic_vars(): + """TypePattern may define symbolic vars This test finds an elementwise `R.add`, while ignoring a broadcasted `R.add`. @@ -1897,8 +1893,8 @@ def test_wildcard_struct_info_with_symbolic_vars(): m = tirx.Var("m", "int64") n = tirx.Var("n", "int64") - pat_lhs = wildcard().has_struct_info(R.Tensor([m, n])) - pat_rhs = wildcard().has_struct_info(R.Tensor([m, n])) + pat_lhs = wildcard().has_ty(R.Tensor([m, n])) + pat_rhs = wildcard().has_ty(R.Tensor([m, n])) pat = is_op("relax.add")(pat_lhs, pat_rhs) def rewriter(expr, matches): diff --git a/tests/python/relax/test_dataflow_rewriter.py b/tests/python/relax/test_dataflow_rewriter.py index 15d270ad8c2c..5264913909ef 100644 --- a/tests/python/relax/test_dataflow_rewriter.py +++ b/tests/python/relax/test_dataflow_rewriter.py @@ -35,9 +35,7 @@ def pattern(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")): @R.function def replacement(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")): - C = R.call_pure_packed( - "my_optimized_add_impl", A, B, sinfo_args=R.Tensor([16], "float32") - ) + C = R.call_pure_packed("my_optimized_add_impl", A, B, ty_args=R.Tensor([16], "float32")) return C @R.function @@ -56,7 +54,7 @@ def expected(x: R.Tensor([32], "float32")): lhs = split[0] rhs = split[1] out = R.call_pure_packed( - "my_optimized_add_impl", lhs, rhs, sinfo_args=R.Tensor([16], "float32") + "my_optimized_add_impl", lhs, rhs, ty_args=R.Tensor([16], "float32") ) return out @@ -155,9 +153,7 @@ def pattern(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")): @R.function def replacement(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")): - C = R.call_pure_packed( - "my_optimized_add_impl", A, B, sinfo_args=R.Tensor([16], "float32") - ) + C = R.call_pure_packed("my_optimized_add_impl", A, B, ty_args=R.Tensor([16], "float32")) return C @I.ir_module @@ -183,14 +179,14 @@ def func_a(x: R.Tensor([32], "float32")): lhs = split[0] rhs = split[1] out = R.call_pure_packed( - "my_optimized_add_impl", lhs, rhs, sinfo_args=R.Tensor([16], "float32") + "my_optimized_add_impl", lhs, rhs, ty_args=R.Tensor([16], "float32") ) return out @R.function def func_b(x: R.Tensor([16], "float32")): out = R.call_pure_packed( - "my_optimized_add_impl", x, x, sinfo_args=R.Tensor([16], "float32") + "my_optimized_add_impl", x, x, ty_args=R.Tensor([16], "float32") ) return out @@ -210,9 +206,7 @@ def pattern(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")): @R.function def replacement(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")): - C = R.call_pure_packed( - "my_optimized_add_impl", A, B, sinfo_args=R.Tensor([16], "float32") - ) + C = R.call_pure_packed("my_optimized_add_impl", A, B, ty_args=R.Tensor([16], "float32")) return C @I.ir_module @@ -227,7 +221,7 @@ class Expected: @R.function def main(x: R.Tensor([16], "float32")): out = R.call_pure_packed( - "my_optimized_add_impl", x, x, sinfo_args=R.Tensor([16], "float32") + "my_optimized_add_impl", x, x, ty_args=R.Tensor([16], "float32") ) return out @@ -247,9 +241,7 @@ def pattern(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")): @R.function def replacement(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")): - C = R.call_pure_packed( - "my_optimized_add_impl", A, B, sinfo_args=R.Tensor([16], "float32") - ) + C = R.call_pure_packed("my_optimized_add_impl", A, B, ty_args=R.Tensor([16], "float32")) return C @R.function(private=True) @@ -260,8 +252,8 @@ def before(x: R.Tensor([16], "float32")): @R.function(private=True) def expected(x: R.Tensor([16], "float32")): - y = R.call_pure_packed("my_optimized_add_impl", x, x, sinfo_args=R.Tensor([16], "float32")) - z = R.call_pure_packed("my_optimized_add_impl", y, y, sinfo_args=R.Tensor([16], "float32")) + y = R.call_pure_packed("my_optimized_add_impl", x, x, ty_args=R.Tensor([16], "float32")) + z = R.call_pure_packed("my_optimized_add_impl", y, y, ty_args=R.Tensor([16], "float32")) return z after = Rewriter(before) @@ -280,9 +272,7 @@ def pattern(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")): @R.function def replacement(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")): - C = R.call_pure_packed( - "my_optimized_add_impl", A, B, sinfo_args=R.Tensor([16], "float32") - ) + C = R.call_pure_packed("my_optimized_add_impl", A, B, ty_args=R.Tensor([16], "float32")) return C @R.rewriter @@ -294,9 +284,7 @@ def pattern(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")): @R.function def replacement(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")): - C = R.call_pure_packed( - "my_optimized_mul_impl", A, B, sinfo_args=R.Tensor([16], "float32") - ) + C = R.call_pure_packed("my_optimized_mul_impl", A, B, ty_args=R.Tensor([16], "float32")) return C @R.function(private=True) @@ -315,8 +303,8 @@ def expected( B: R.Tensor([16], "float32"), C: R.Tensor([16], "float32"), ): - D = R.call_pure_packed("my_optimized_add_impl", A, B, sinfo_args=R.Tensor([16], "float32")) - E = R.call_pure_packed("my_optimized_mul_impl", C, D, sinfo_args=R.Tensor([16], "float32")) + D = R.call_pure_packed("my_optimized_add_impl", A, B, ty_args=R.Tensor([16], "float32")) + E = R.call_pure_packed("my_optimized_mul_impl", C, D, ty_args=R.Tensor([16], "float32")) return E rewriter = RewriteAdd | RewriteMultiply @@ -354,9 +342,7 @@ def pattern(A: R.Tensor([16], "float32"), B: R.Tensor([], "float32")): @R.function def replacement(A: R.Tensor([16], "float32"), B: R.Tensor([], "float32")): - C = R.call_pure_packed( - "my_optimized_mul_impl", A, B, sinfo_args=R.Tensor([16], "float32") - ) + C = R.call_pure_packed("my_optimized_mul_impl", A, B, ty_args=R.Tensor([16], "float32")) return C @R.function(private=True) @@ -370,7 +356,7 @@ def expected(A: R.Tensor([16], "float32")): "my_optimized_mul_impl", A, R.const(2.0, "float32"), - sinfo_args=R.Tensor([16], "float32"), + ty_args=R.Tensor([16], "float32"), ) return B @@ -594,7 +580,7 @@ def pattern(A: R.Tensor([16], "float32")): @R.function def replacement(A: R.Tensor([16], "float32")): - return R.call_tir(RewriteMul.subroutine_mul, [A], out_sinfo=R.Tensor([16], "float32")) + return R.call_tir(RewriteMul.subroutine_mul, [A], out_ty=R.Tensor([16], "float32")) @T.prim_func(private=True, s_tir=True) def subroutine_mul(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")): @@ -672,7 +658,7 @@ def pattern(A: R.Tensor([16], "float32")): @R.function def replacement(A: R.Tensor([16], "float32")): - return R.call_tir(RewriteMul.subroutine, [A], out_sinfo=R.Tensor([16], "float32")) + return R.call_tir(RewriteMul.subroutine, [A], out_ty=R.Tensor([16], "float32")) @T.prim_func(private=True, s_tir=True) def subroutine(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")): @@ -692,7 +678,7 @@ class Expected: @R.function def main(A: R.Tensor([16], "float32")): B = Expected.subroutine(A) - C = R.call_tir(Expected.subroutine_1, [B], out_sinfo=R.Tensor([16], "float32")) + C = R.call_tir(Expected.subroutine_1, [B], out_ty=R.Tensor([16], "float32")) return C @R.function(private=True) @@ -929,7 +915,7 @@ def replacement( x, y, z, - sinfo_args=R.Tuple( + ty_args=R.Tuple( [ R.Tensor([16], "float32"), R.Tensor([16], "float32"), @@ -959,7 +945,7 @@ def expected( A, B, C, - sinfo_args=R.Tuple( + ty_args=R.Tuple( [ R.Tensor([16], "float32"), R.Tensor([16], "float32"), @@ -1006,7 +992,7 @@ def replacement( A, B, C, - sinfo_args=R.Tuple( + ty_args=R.Tuple( [ R.Tensor([16], "float32"), R.Tensor([16], "float32"), @@ -1109,10 +1095,10 @@ def pattern(qkv: R.Tensor([12288], "float32")): k = qkv_tuple[1] v = qkv_tuple[2] q_embed = R.call_pure_packed( - "rotary_embedding", [q], sinfo_args=R.Tensor([4096], "float32") + "rotary_embedding", [q], ty_args=R.Tensor([4096], "float32") ) k_embed = R.call_pure_packed( - "rotary_embedding", [k], sinfo_args=R.Tensor([4096], "float32") + "rotary_embedding", [k], ty_args=R.Tensor([4096], "float32") ) return (q_embed, k_embed, v) @@ -1122,7 +1108,7 @@ def replacement(qkv: R.Tensor([12288], "float32")): return R.call_pure_packed( "split_rotary_embedding", [qkv], - sinfo_args=[ + ty_args=[ R.Tensor([4096], "float32"), R.Tensor([4096], "float32"), R.Tensor([4096], "float32"), @@ -1140,17 +1126,13 @@ def before( q = qkv_tuple[0] k = qkv_tuple[1] v = qkv_tuple[2] - q_embed = R.call_pure_packed( - "rotary_embedding", [q], sinfo_args=R.Tensor([4096], "float32") - ) - k_embed = R.call_pure_packed( - "rotary_embedding", [k], sinfo_args=R.Tensor([4096], "float32") - ) + q_embed = R.call_pure_packed("rotary_embedding", [q], ty_args=R.Tensor([4096], "float32")) + k_embed = R.call_pure_packed("rotary_embedding", [k], ty_args=R.Tensor([4096], "float32")) attention = R.call_pure_packed( "compute_self_attention", [q_embed, k_embed, v, kv_cache], - sinfo_args=R.Tensor([4096]), + ty_args=R.Tensor([4096]), ) return attention @@ -1165,7 +1147,7 @@ def expected( embedded_qkv_tuple = R.call_pure_packed( "split_rotary_embedding", [qkv], - sinfo_args=[ + ty_args=[ R.Tensor([4096], "float32"), R.Tensor([4096], "float32"), R.Tensor([4096], "float32"), @@ -1179,7 +1161,7 @@ def expected( attention = R.call_pure_packed( "compute_self_attention", [q_embed, k_embed, v, kv_cache], - sinfo_args=R.Tensor([4096]), + ty_args=R.Tensor([4096]), ) return attention @@ -1226,7 +1208,7 @@ def replacement( state, weights, bias, - sinfo_args=R.Tensor([16], "float32"), + ty_args=R.Tensor([16], "float32"), ) @R.function(private=True, pure=False) @@ -1286,7 +1268,7 @@ def replacement( state, weights, bias, - sinfo_args=R.Tensor([16], "float32"), + ty_args=R.Tensor([16], "float32"), ) @R.function(private=True, pure=False) @@ -1313,7 +1295,7 @@ def expected( state, weights, bias, - sinfo_args=R.Tensor([16], "float32"), + ty_args=R.Tensor([16], "float32"), ) R.print(format="End of function") return state @@ -1339,7 +1321,7 @@ def pattern(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")): @R.function def replacement(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")): return R.call_pure_packed( - "my_optimized_add_impl", A, B, sinfo_args=R.Tensor([16], "float32") + "my_optimized_add_impl", A, B, ty_args=R.Tensor([16], "float32") ) @R.function(private=True) @@ -1355,14 +1337,12 @@ def before(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32"), cond: R.P def expected(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32"), cond: R.Prim("bool")): if cond: out = R.call_pure_packed( - "my_optimized_add_impl", A, B, sinfo_args=R.Tensor([16], "float32") + "my_optimized_add_impl", A, B, ty_args=R.Tensor([16], "float32") ) else: - C = R.call_pure_packed( - "my_optimized_add_impl", A, B, sinfo_args=R.Tensor([16], "float32") - ) + C = R.call_pure_packed("my_optimized_add_impl", A, B, ty_args=R.Tensor([16], "float32")) out = R.call_pure_packed( - "my_optimized_add_impl", C, B, sinfo_args=R.Tensor([16], "float32") + "my_optimized_add_impl", C, B, ty_args=R.Tensor([16], "float32") ) return out @@ -1469,7 +1449,7 @@ def replacement( "my_optimized_square_matmul", A, B, - sinfo_args=R.Tensor([M, N], "float32"), + ty_args=R.Tensor([M, N], "float32"), ) @R.function(private=True) @@ -1496,14 +1476,14 @@ def expected( "my_optimized_square_matmul", A, B, - sinfo_args=R.Tensor([N, N * 2], "float32"), + ty_args=R.Tensor([N, N * 2], "float32"), ) E: R.Tensor([N * 2, N], "float32") = R.permute_dims(D) F: R.Tensor([N * 2, N], "float32") = R.call_pure_packed( "my_optimized_square_matmul", E, C, - sinfo_args=R.Tensor([N * 2, N], "float32"), + ty_args=R.Tensor([N * 2, N], "float32"), ) return F diff --git a/tests/python/relax/test_eliminate_pad_branch_using_buffer_assumption.py b/tests/python/relax/test_eliminate_pad_branch_using_buffer_assumption.py index 0666919799ac..13ca39cef1ef 100644 --- a/tests/python/relax/test_eliminate_pad_branch_using_buffer_assumption.py +++ b/tests/python/relax/test_eliminate_pad_branch_using_buffer_assumption.py @@ -121,7 +121,7 @@ def main( out = R.call_tir( AddBefore.add, (a, b), - out_sinfo=R.Tensor((1, 4, 4, 16, 8, 8, 32), dtype="uint8"), + out_ty=R.Tensor((1, 4, 4, 16, 8, 8, 32), dtype="uint8"), ) return out @@ -223,7 +223,7 @@ def main( out = R.call_tir( AddExpected.add, (a, b), - out_sinfo=R.Tensor((1, 4, 4, 16, 8, 8, 32), dtype="uint8"), + out_ty=R.Tensor((1, 4, 4, 16, 8, 8, 32), dtype="uint8"), ) return out @@ -320,7 +320,7 @@ def main( out = R.call_tir( SubBefore.sub, (a, b), - out_sinfo=R.Tensor((1, 4, 4, 16, 8, 8, 32), dtype="uint8"), + out_ty=R.Tensor((1, 4, 4, 16, 8, 8, 32), dtype="uint8"), ) return out @@ -422,7 +422,7 @@ def main( out = R.call_tir( SubExpected.sub, (a, b), - out_sinfo=R.Tensor((1, 4, 4, 16, 8, 8, 32), dtype="uint8"), + out_ty=R.Tensor((1, 4, 4, 16, 8, 8, 32), dtype="uint8"), ) return out @@ -519,7 +519,7 @@ def main( out = R.call_tir( MulBefore.mul, (a, b), - out_sinfo=R.Tensor((1, 4, 4, 16, 8, 8, 32), dtype="uint8"), + out_ty=R.Tensor((1, 4, 4, 16, 8, 8, 32), dtype="uint8"), ) return out @@ -621,7 +621,7 @@ def main( out = R.call_tir( MulExpected.mul, (a, b), - out_sinfo=R.Tensor((1, 4, 4, 16, 8, 8, 32), dtype="uint8"), + out_ty=R.Tensor((1, 4, 4, 16, 8, 8, 32), dtype="uint8"), ) return out diff --git a/tests/python/relax/test_expr.py b/tests/python/relax/test_expr.py index 31bd012b449f..855361712e6c 100644 --- a/tests/python/relax/test_expr.py +++ b/tests/python/relax/test_expr.py @@ -45,26 +45,44 @@ def _check_json_roundtrip(x): def test_var() -> None: v0 = rx.Var("v0") assert v0.name_hint == "v0" - assert v0.struct_info_ is None + assert v0.ty is None shape = [54, 96] v1 = rx.Var("v1", R.Tensor(shape, "float32")) assert v1.name_hint == "v1" - for s0, s1 in zip(v1.struct_info.shape, shape): + for s0, s1 in zip(v1.ty.shape, shape): assert s0 == s1 - tvm.ir.assert_structural_equal(v1.struct_info, rx.TensorStructInfo(shape, "float32")) + tvm.ir.assert_structural_equal(v1.ty, rx.TensorType(shape, "float32")) + + +def test_relax_expr_ty_running_example() -> None: + m = tirx.Var("m", "int64") + x = rx.Var("x", R.Tensor([m, 16], "float32")) + + assert isinstance(x.ty, tvm.ir.Type) + assert x.ty.dtype == "float32" + assert x.ty.ndim == 2 + + call = rx.op.add(x, x) + assert call.ty is None + + bb = rx.BlockBuilder() + normalized = bb.normalize(call) + + assert isinstance(normalized.ty, tvm.ir.Type) + tvm.ir.assert_structural_equal(normalized.ty, x.ty) def test_dataflow_var() -> None: v0 = rx.DataflowVar("v0") assert v0.name_hint == "v0" - assert v0.struct_info_ is None + assert v0.ty is None shape = [54, 96] v1 = rx.DataflowVar("v1", R.Tensor(shape, "float16")) assert v1.name_hint == "v1" assert isinstance(v1, rx.DataflowVar) - tvm.ir.assert_structural_equal(v1.struct_info, rx.TensorStructInfo(shape, "float16")) + tvm.ir.assert_structural_equal(v1.ty, rx.TensorType(shape, "float16")) def test_tuple() -> None: @@ -86,23 +104,21 @@ def test_tuple() -> None: t[-3] -def test_tuple_sinfo_inferred_on_construction(): - v0 = rx.Var("v0", rx.ObjectStructInfo()) - v1 = rx.Var("v1", rx.ObjectStructInfo()) +def test_tuple_ty_inferred_on_construction(): + v0 = rx.Var("v0", rx.ObjectType()) + v1 = rx.Var("v1", rx.ObjectType()) tup = rx.Tuple((v0, v1)) - assert tup.struct_info_ is not None - tvm.ir.assert_structural_equal( - tup.struct_info, rx.TupleStructInfo([rx.ObjectStructInfo(), rx.ObjectStructInfo()]) - ) + assert tup.ty is not None + tvm.ir.assert_structural_equal(tup.ty, rx.TupleType([rx.ObjectType(), rx.ObjectType()])) -def test_tuple_sinfo_requires_fields_with_known_sinfo(): - v0 = rx.Var("v0", rx.ObjectStructInfo()) +def test_tuple_ty_requires_fields_with_known_ty(): + v0 = rx.Var("v0", rx.ObjectType()) v1 = rx.Var("v1") tup = rx.Tuple((v0, v1)) - assert tup.struct_info_ is None + assert tup.ty is None def test_match_cast() -> None: @@ -133,10 +149,10 @@ def test_match_cast() -> None: m = tirx.Var("m", dtype="int64") n = tirx.Var("n", dtype="int64") ivalue = rx.Var("input_value") - sinfo = rx.TensorStructInfo([n, m], "float32") - b0 = rx.MatchCast(rx.Var("v"), ivalue, sinfo) + ty = rx.TensorType([n, m], "float32") + b0 = rx.MatchCast(rx.Var("v"), ivalue, ty) assert b0.value.same_as(ivalue) - assert b0.struct_info == sinfo + assert b0.ty == ty _check_json_roundtrip(b0) @@ -194,12 +210,12 @@ def test_func(): blocks = [rx.BindingBlock(bindings)] seqe = rx.SeqExpr(blocks, x) - ret_struct_info = R.Tensor(dtype="float32", ndim=-1) - func = rx.Function([x], seqe, ret_struct_info) + ret_ty = R.Tensor(dtype="float32", ndim=-1) + func = rx.Function([x], seqe, ret_ty) func = func.with_attr("global_symbol", "func") assert func.params[0] == x assert func.body == seqe - assert func.ret_struct_info == ret_struct_info + assert func.ret_ty == ret_ty assert func.attrs["global_symbol"] == "func" @@ -221,7 +237,7 @@ def test_shape_expr(): assert s[1] == n assert s[-1] == n assert s[-2] == m - assert isinstance(s.struct_info, rx.ShapeStructInfo) + assert isinstance(s.ty, rx.ShapeType) with pytest.raises(IndexError, match="ShapeExpr index out of range"): s[2] @@ -232,17 +248,15 @@ def test_shape_expr(): shape_expr = rx.ShapeExpr([10, 20]) assert shape_expr.values[0] == 10 assert shape_expr.values[1] == 20 - tvm.ir.assert_structural_equal(shape_expr.struct_info, R.Shape((10, 20))) + tvm.ir.assert_structural_equal(shape_expr.ty, R.Shape((10, 20))) x = rx.Var("v0", R.Tensor((10, 20), "float32")) - assert x.struct_info.shape[0] == 10 - assert x.struct_info.shape[1] == 20 - tvm.ir.assert_structural_equal(x.struct_info.shape.struct_info, R.Shape((10, 20))) + assert x.ty.shape[0] == 10 + assert x.ty.shape[1] == 20 + tvm.ir.assert_structural_equal(x.ty.shape.ty, R.Shape((10, 20))) m = tirx.Var("m", "int32") - with pytest.raises( - RuntimeError, match="the value in ShapeStructInfo can only have dtype of int64" - ): + with pytest.raises(RuntimeError, match="the value in ShapeType can only have dtype of int64"): rx.ShapeExpr([m, 3]) @@ -257,7 +271,7 @@ def test_prim_value_with_var(): n = tirx.Var("n", "int64") pv = rx.PrimValue(n) assert pv.value.same_as(n) - tvm.ir.assert_structural_equal(pv.struct_info, rx.PrimStructInfo(value=n)) + tvm.ir.assert_structural_equal(pv.ty, rx.PrimType(value=n)) _check_equal(pv, rx.PrimValue(n)) _check_json_roundtrip(pv) @@ -265,7 +279,7 @@ def test_prim_value_with_var(): def test_prim_value_with_expr(): n = tirx.Var("n", "int64") pv = rx.PrimValue(n + 1) - tvm.ir.assert_structural_equal(pv.struct_info, rx.PrimStructInfo(value=n + 1)) + tvm.ir.assert_structural_equal(pv.ty, rx.PrimType(value=n + 1)) _check_equal(pv, rx.PrimValue(n + 1)) _check_json_roundtrip(pv) @@ -287,8 +301,8 @@ def test_datatype_imm(): def test_call(): - dtype = rx.PrimStructInfo("int32") - func = rx.Var("func", rx.FuncStructInfo([dtype], dtype)) + dtype = rx.PrimType("int32") + func = rx.Var("func", rx.FuncType([dtype], dtype)) arg = rx.Var("arg", dtype) call = rx.Call(func, [arg]) assert call.op.same_as(func) @@ -297,8 +311,8 @@ def test_call(): def test_call_raises_error_for_invalid_function(): - """relax::Call requires the function to have FuncStructInfo""" - dtype = rx.PrimStructInfo("int32") + """relax::Call requires the function to have FuncType""" + dtype = rx.PrimType("int32") func = rx.Var("func", dtype) arg = rx.Var("arg", dtype) diff --git a/tests/python/relax/test_expr_functor.py b/tests/python/relax/test_expr_functor.py index 61b8696ee0bd..d3fc8bc9fb8f 100644 --- a/tests/python/relax/test_expr_functor.py +++ b/tests/python/relax/test_expr_functor.py @@ -302,7 +302,7 @@ def visit_var_binding_(self, binding: VarBinding) -> None: self.builder_.emit_normalized(binding) return - temp = self.with_struct_info(new_var, new_value.struct_info) + temp = self.with_type(new_var, new_value.ty) if not temp.same_as(new_var): new_var = temp self.set_var_remap(binding.var.vid, new_var) @@ -314,13 +314,13 @@ def visit_match_cast_(self, binding: MatchCast) -> None: new_var = self.visit_var_def(binding.var) new_value = self.visit_expr(binding.value) - temp = self.with_struct_info(new_var, binding.struct_info) + temp = self.with_type(new_var, binding.ty) if not temp.same_as(new_var): new_var = temp self.set_var_remap(binding.var.vid, new_var) self.log.add("MatchCast") - self.builder_.emit_normalized(MatchCast(new_var, new_value, binding.struct_info)) + self.builder_.emit_normalized(MatchCast(new_var, new_value, binding.ty)) def visit_binding_block_(self, block: BindingBlock) -> BindingBlock: """Identical with ExprMutator::VisitBindingBlock_(const BindingBlockNode* block) on the C++ side.""" @@ -367,7 +367,7 @@ def visit(f, expr): # check no overloading case basic_mutator = BasicMutator() - # skip normalize GlobalVar since it requires context IRModule to get the struct_info_ + # skip normalize GlobalVar since it requires context IRModule to get ty if isinstance(expr, relax.Expr) and not isinstance(expr, relax.GlobalVar): expr = bb.normalize(expr) assert_structural_equal(visit(basic_mutator, expr), expr) @@ -844,8 +844,8 @@ def __init__(self, shape_replacements): def visit_var_def_(self, var): if var.name_hint in self.shape_replacements: new_shape = self.shape_replacements[var.name_hint] - new_sinfo = relax.TensorStructInfo(new_shape, dtype=var.struct_info.dtype) - return relax.Var(f"{var.name_hint}_with_new_shape", new_sinfo) + new_ty = relax.TensorType(new_shape, dtype=var.ty.dtype) + return relax.Var(f"{var.name_hint}_with_new_shape", new_ty) else: return var diff --git a/tests/python/relax/test_frontend_common.py b/tests/python/relax/test_frontend_common.py index 0829a498da17..024eb65d1965 100644 --- a/tests/python/relax/test_frontend_common.py +++ b/tests/python/relax/test_frontend_common.py @@ -45,7 +45,7 @@ class TestAutopad: def _test_autopad(self, pad_type, expected): bb = relax.BlockBuilder() input_shape = (1, 1, 4, 4) - x = relax.Var("x", relax.TensorStructInfo(input_shape, "float32")) + x = relax.Var("x", relax.TensorType(input_shape, "float32")) with bb.function("main", [x]): with bb.dataflow(): @@ -94,9 +94,7 @@ def main(x: R.Tensor((1, 1, 4, 4), dtype="float32")) -> R.Tensor( ): cls = expected with R.dataflow(): - lv = R.call_tir( - cls.pad, (x,), out_sinfo=R.Tensor((1, 1, 5, 5), dtype="float32") - ) + lv = R.call_tir(cls.pad, (x,), out_ty=R.Tensor((1, 1, 5, 5), dtype="float32")) gv: R.Tensor((1, 1, 5, 5), dtype="float32") = lv R.output(gv) return gv @@ -156,7 +154,7 @@ def main(x: R.Tensor((1, 1, 4, 4), dtype="float32")) -> R.Tensor( cls = expected with R.dataflow(): lv = R.call_tir( - cls.replicate_pad, (x,), out_sinfo=R.Tensor((1, 1, 5, 5), dtype="float32") + cls.replicate_pad, (x,), out_ty=R.Tensor((1, 1, 5, 5), dtype="float32") ) gv: R.Tensor((1, 1, 5, 5), dtype="float32") = lv R.output(gv) @@ -202,7 +200,7 @@ def main(x: R.Tensor((1, 1, 4, 4), dtype="float32")) -> R.Tensor( cls = expected with R.dataflow(): lv = R.call_tir( - cls.mirror_pad, (x,), out_sinfo=R.Tensor((1, 1, 5, 5), dtype="float32") + cls.mirror_pad, (x,), out_ty=R.Tensor((1, 1, 5, 5), dtype="float32") ) gv: R.Tensor((1, 1, 5, 5), dtype="float32") = lv R.output(gv) diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index ee2f4a8f8df6..e4adefc1b1fd 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -6637,8 +6637,8 @@ def main( assert len(params) == len(func.params) - 1 for param_var, param_tensor in zip(func.params[1:], params): - assert tuple(x.value for x in param_var.struct_info.shape.values) == param_tensor.shape - assert param_var.struct_info.dtype == param_tensor.dtype + assert tuple(x.value for x in param_var.ty.shape.values) == param_tensor.shape + assert param_var.ty.dtype == param_tensor.dtype tvm.testing.assert_allclose(params[0].numpy(), model.conv.weight.detach().detach().numpy()) tvm.testing.assert_allclose(params[1].numpy(), model.conv.bias.detach().detach().numpy()) @@ -6897,10 +6897,10 @@ def forward(self, input): exported_program = export(Randn(), args=example_args) mod = from_exported_program(exported_program) func = mod["main"] - ret_sinfo = func.ret_struct_info - assert ret_sinfo.fields[0].shape[0] == 5 - assert ret_sinfo.fields[0].shape[1] == 3 - assert ret_sinfo.fields[0].dtype == "float32" + ret_ty = func.ret_ty + assert ret_ty.fields[0].shape[0] == 5 + assert ret_ty.fields[0].shape[1] == 3 + assert ret_ty.fields[0].dtype == "float32" def test_randn_like(): @@ -6912,10 +6912,10 @@ def forward(self, input): exported_program = export(RandnLike(), args=example_args) mod = from_exported_program(exported_program) func = mod["main"] - ret_sinfo = func.ret_struct_info - assert ret_sinfo.fields[0].shape[0] == 4 - assert ret_sinfo.fields[0].shape[1] == 6 - assert ret_sinfo.fields[0].dtype == "float32" + ret_ty = func.ret_ty + assert ret_ty.fields[0].shape[0] == 4 + assert ret_ty.fields[0].shape[1] == 6 + assert ret_ty.fields[0].dtype == "float32" def test_type_as(): @@ -7413,10 +7413,10 @@ def forward(self, x, buf, idx): exported_program = export(IndexPutTupleOutput(), args=example_args) mod = from_exported_program(exported_program) - ret_sinfo = mod["main"].ret_struct_info - assert isinstance(ret_sinfo, relax.TupleStructInfo) + ret_ty = mod["main"].ret_ty + assert isinstance(ret_ty, relax.TupleType) - tensor_fields = [f for f in ret_sinfo.fields if isinstance(f, relax.TensorStructInfo)] + tensor_fields = [f for f in ret_ty.fields if isinstance(f, relax.TensorType)] assert len(tensor_fields) >= 2 assert any( @@ -7445,10 +7445,10 @@ def forward(self, x): # Regression focus: importing this graph should not segfault at Tuple construction. mod = from_exported_program(exported_program) - ret_sinfo = mod["main"].ret_struct_info - assert isinstance(ret_sinfo, relax.TupleStructInfo) + ret_ty = mod["main"].ret_ty + assert isinstance(ret_ty, relax.TupleType) - tensor_fields = [f for f in ret_sinfo.fields if isinstance(f, relax.TensorStructInfo)] + tensor_fields = [f for f in ret_ty.fields if isinstance(f, relax.TensorType)] assert len(tensor_fields) >= 2 # x: (2, 3, 5) → x[..., :1]: (2, 3, 1) assert any(len(f.shape) == 3 and int(f.shape[-1]) == 1 for f in tensor_fields) diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index 34da69d5f061..7cd6744d346c 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -5036,8 +5036,8 @@ def main( assert len(params) == len(func.params) - 1 for param_var, param_tensor in zip(func.params[1:], params): - assert tuple(x.value for x in param_var.struct_info.shape.values) == param_tensor.shape - assert param_var.struct_info.dtype == param_tensor.dtype + assert tuple(x.value for x in param_var.ty.shape.values) == param_tensor.shape + assert param_var.ty.dtype == param_tensor.dtype tvm.testing.assert_allclose(params[0].numpy(), model.conv.bias.detach().detach().numpy()) tvm.testing.assert_allclose(params[1].numpy(), model.conv.weight.detach().detach().numpy()) diff --git a/tests/python/relax/test_frontend_nn_extern_module.py b/tests/python/relax/test_frontend_nn_extern_module.py index b837eb6866ad..f0c0d91e904f 100644 --- a/tests/python/relax/test_frontend_nn_extern_module.py +++ b/tests/python/relax/test_frontend_nn_extern_module.py @@ -95,7 +95,7 @@ def scalar_add( R.func_attr({"num_input": 2}) with R.dataflow(): ext_scalar_add = R.call_dps_packed( - "ext_scalar_add", (a, b), out_sinfo=R.Tensor((), dtype="float32") + "ext_scalar_add", (a, b), out_ty=R.Tensor((), dtype="float32") ) gv: R.Tensor((), dtype="float32") = ext_scalar_add R.output(gv) @@ -111,7 +111,7 @@ def test_sym( R.func_attr({"num_input": 2}) with R.dataflow(): ext_test_sym = R.call_dps_packed( - "ext_test_sym", (a, b), out_sinfo=R.Tensor((x, y, z, 9), dtype="float32") + "ext_test_sym", (a, b), out_ty=R.Tensor((x, y, z, 9), dtype="float32") ) gv1: R.Tensor((x, y, z, 9), dtype="float32") = ext_test_sym R.output(gv1) diff --git a/tests/python/relax/test_frontend_nn_modules.py b/tests/python/relax/test_frontend_nn_modules.py index 601c4891ead2..014f6d6a9d45 100644 --- a/tests/python/relax/test_frontend_nn_modules.py +++ b/tests/python/relax/test_frontend_nn_modules.py @@ -514,7 +514,7 @@ def _initialize_effect() -> R.Tuple(R.Object, R.Object): lv, R.shape([8, 2, 4]), R.prim_value(0), - sinfo_args=[R.Object()], + ty_args=[R.Object()], ) lv1 = _io, cache gv = lv1 @@ -532,13 +532,13 @@ def forward( cache, x, inplace_indices=[0], - sinfo_args=[R.Object()], + ty_args=[R.Object()], ) lv3: R.Tensor((4, 2, 4), dtype="float32") = R.call_pure_packed( "vm.builtin.attention_kv_cache_view", lv2, R.shape([4, 2, 4]), - sinfo_args=(R.Tensor((4, 2, 4), dtype="float32"),), + ty_args=(R.Tensor((4, 2, 4), dtype="float32"),), ) gv1: R.Tuple(R.Tensor((4, 2, 4), dtype="float32"), R.Tuple(R.Object, R.Object)) = ( lv3, diff --git a/tests/python/relax/test_frontend_nn_op.py b/tests/python/relax/test_frontend_nn_op.py index 51a0d1e9f0f0..e7a31db0547f 100644 --- a/tests/python/relax/test_frontend_nn_op.py +++ b/tests/python/relax/test_frontend_nn_op.py @@ -618,7 +618,7 @@ def test(x: R.Tensor((10, 10), dtype="float32"), _io: R.Object) -> R.Tuple(R.Ten cls = Expected R.func_attr({"num_input": 2}) with R.dataflow(): - lv1 = R.call_tir(cls.add_one, (x,), out_sinfo=R.Tensor((10, 10), dtype="float32")) + lv1 = R.call_tir(cls.add_one, (x,), out_ty=R.Tensor((10, 10), dtype="float32")) gv1: R.Tuple(R.Tensor((10, 10), dtype="float32"), R.Tuple(R.Object)) = lv1, (_io,) R.output(gv1) return gv1 @@ -700,7 +700,7 @@ def test(qkv: R.Tensor((1, 1, 24, 16), dtype="float16"), offset: R.Shape(["offse R.func_attr({"num_input": 3}) cls = Expected with R.dataflow(): - lv1 = R.call_tir(cls.llama_fused_rope, (qkv,), out_sinfo=[R.Tensor((1, 1, 8, 16), dtype="float16"), R.Tensor((1, 1, 8, 16), dtype="float16"), R.Tensor((1, 1, 8, 16), dtype="float16")], tir_vars=R.shape([offset_1])) + lv1 = R.call_tir(cls.llama_fused_rope, (qkv,), out_ty=[R.Tensor((1, 1, 8, 16), dtype="float16"), R.Tensor((1, 1, 8, 16), dtype="float16"), R.Tensor((1, 1, 8, 16), dtype="float16")], tir_vars=R.shape([offset_1])) llama_fused_rope_0: R.Tensor((1, 1, 8, 16), dtype="float16") = lv1[0] llama_fused_rope_1: R.Tensor((1, 1, 8, 16), dtype="float16") = lv1[1] llama_fused_rope_2: R.Tensor((1, 1, 8, 16), dtype="float16") = lv1[2] @@ -799,7 +799,7 @@ def test( lv1 = R.call_tir_inplace( cls.inplace_take, (embedding_table, input_ids, embedding_dst), - out_sinfo=R.Tensor((total_seq_len, hidden_size), dtype), + out_ty=R.Tensor((total_seq_len, hidden_size), dtype), inplace_indices=[2], tir_vars=R.shape([offset_1]), ) @@ -852,7 +852,7 @@ def test(A: R.Tensor((16, 16), dtype="float32")) -> R.Tensor((16, 16), dtype="fl R.func_attr({"num_input": 1}) cls = Expected with R.dataflow(): - lv = R.call_tir(cls.tir_func, (A,), out_sinfo=R.Tensor((16, 16), dtype="float32")) + lv = R.call_tir(cls.tir_func, (A,), out_ty=R.Tensor((16, 16), dtype="float32")) gv: R.Tensor((16, 16), dtype="float32") = lv R.output(gv) return gv @@ -889,7 +889,7 @@ def _initialize_effect() -> R.Tuple(R.Object): def test(q: R.Tensor((1, 1, 16, 8), dtype="float32"), k: R.Tensor((64, 16, 8), dtype="float32"), v: R.Tensor((64, 16, 8), dtype="float32"), _io: R.Object) -> R.Tuple(R.Tensor((1, 1, 128), dtype="float16"), R.Tuple(R.Object)): R.func_attr({"num_input": 4}) with R.dataflow(): - flashinfer_single_decode = R.call_dps_packed("flashinfer.single_decode", (q, k, v, R.prim_value(0), R.prim_value(0), R.prim_value(T.float64(1)), R.prim_value(T.float64(10000))), out_sinfo=R.Tensor((1, 1, 128), dtype="float16")) + flashinfer_single_decode = R.call_dps_packed("flashinfer.single_decode", (q, k, v, R.prim_value(0), R.prim_value(0), R.prim_value(T.float64(1)), R.prim_value(T.float64(10000))), out_ty=R.Tensor((1, 1, 128), dtype="float16")) gv1: R.Tuple(R.Tensor((1, 1, 128), dtype="float16"), R.Tuple(R.Object)) = flashinfer_single_decode, (_io,) R.output(gv1) return gv1 @@ -1087,8 +1087,8 @@ def foo(prob: R.Tensor((2, 3), dtype="float32"), index: R.Tensor((2, 3), dtype=" cls = Expected with R.dataflow(): cumsum: R.Tensor((2, 3), dtype="float32") = R.cumsum(prob, axis=1, dtype="void", exclusive=None) - lv1 = R.call_tir(cls.get_renorm_prob, (cumsum, top_p, top_k), out_sinfo=R.Tensor((2, 1), dtype="float32")) - lv2 = R.call_tir(cls.get_index_from_sorted, (cumsum, index, lv1, uniform_sample, sample_indices), out_sinfo=R.Tensor((3, 1), dtype="int64")) + lv1 = R.call_tir(cls.get_renorm_prob, (cumsum, top_p, top_k), out_ty=R.Tensor((2, 1), dtype="float32")) + lv2 = R.call_tir(cls.get_index_from_sorted, (cumsum, index, lv1, uniform_sample, sample_indices), out_ty=R.Tensor((3, 1), dtype="int64")) gv1: R.Tuple(R.Tensor((3, 1), dtype="int64"), R.Tuple(R.Object)) = lv2, (_io,) R.output(gv1) return gv1 @@ -1205,8 +1205,8 @@ def foo(prob: R.Tensor((2, 3), dtype="float32"), sorted_prob: R.Tensor((2, 3), d cls = Expected with R.dataflow(): cumsum: R.Tensor((2, 3), dtype="float32") = R.cumsum(sorted_prob, axis=1, dtype="void", exclusive=None) - lv1 = R.call_tir(cls.get_renorm_cutoff, (sorted_prob, cumsum, top_p, top_k), out_sinfo=R.Tensor((2, 1), dtype="float32")) - lv2 = R.call_tir(cls.filter_with_top_p_top_k, (prob, lv1), out_sinfo=R.Tensor((2, 3), dtype="float32")) + lv1 = R.call_tir(cls.get_renorm_cutoff, (sorted_prob, cumsum, top_p, top_k), out_ty=R.Tensor((2, 1), dtype="float32")) + lv2 = R.call_tir(cls.filter_with_top_p_top_k, (prob, lv1), out_ty=R.Tensor((2, 3), dtype="float32")) sum: R.Tensor((2, 1), dtype="float32") = R.sum(lv2, axis=[1], keepdims=True) divide: R.Tensor((2, 3), dtype="float32") = R.divide(lv2, sum) gv1: R.Tuple(R.Tensor((2, 3), dtype="float32"), R.Tuple(R.Object)) = divide, (_io,) diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index 6e9d4c9d95a1..a2e7bc3b8b13 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -508,7 +508,7 @@ def test_matmulinteger16_ir(): call_ops = collect_relax_call_ops(tvm_model["main"]) assert call_ops.count("relax.astype") == 2 assert "relax.matmul" in call_ops - assert tvm_model["main"].ret_struct_info.dtype == "uint32" + assert tvm_model["main"].ret_ty.dtype == "uint32" def test_matmulinteger16_invalid_dtype_raises(): @@ -4636,7 +4636,7 @@ def test_optional_get_element_tensor_ir(): tvm_model = from_onnx(model, opset=18, keep_params_in_input=True) assert collect_relax_call_ops(tvm_model["main"]) == [] - assert tvm_model["main"].ret_struct_info.dtype == "float32" + assert tvm_model["main"].ret_ty.dtype == "float32" def test_optional_get_element_sequence(): @@ -5698,7 +5698,7 @@ def test_grid_sample_4d_non_square_output_shape(): model = helper.make_model(graph, producer_name="grid_sample_4d_non_square_output_shape_test") tvm_model = from_onnx(model, opset=16, keep_params_in_input=True) - inferred_shape = tuple(dim.value for dim in tvm_model["main"].ret_struct_info.shape.values) + inferred_shape = tuple(dim.value for dim in tvm_model["main"].ret_ty.shape.values) assert inferred_shape == tuple(out_shape) diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py index 5a4460d0b711..de53735a76e2 100644 --- a/tests/python/relax/test_frontend_tflite.py +++ b/tests/python/relax/test_frontend_tflite.py @@ -76,8 +76,8 @@ def verify(TestClass, expected=None): tf_inputs = [] tvm_inputs = [] for arg in mod["main"].params: - shape = tuple(shape_val.value for shape_val in arg.struct_info.shape.values) - data = np.random.uniform(0, 1, size=shape).astype(arg.struct_info.dtype) + shape = tuple(shape_val.value for shape_val in arg.ty.shape.values) + data = np.random.uniform(0, 1, size=shape).astype(arg.ty.dtype) tvm_inputs.append(data) tf_inputs.append(tf.constant(data)) @@ -2622,9 +2622,9 @@ def _convert_detection_postprocess_with_options( build_module=True, ): input_num_classes = num_classes if input_num_classes is None else input_num_classes - loc = relax.Var("loc", relax.TensorStructInfo((batch_size, num_anchors, 4), "float32")) + loc = relax.Var("loc", relax.TensorType((batch_size, num_anchors, 4), "float32")) cls = relax.Var( - "cls", relax.TensorStructInfo((batch_size, num_anchors, input_num_classes), "float32") + "cls", relax.TensorType((batch_size, num_anchors, input_num_classes), "float32") ) inputs = [ _make_detection_postprocess_tensor_wrapper(0, (batch_size, num_anchors, 4), "loc"), @@ -3073,13 +3073,13 @@ def test_detection_postprocess_smoke(build_kwargs, expected_topk_count, expected expected_batch = build_kwargs["batch_size"] expected_max_detections = build_kwargs["max_detections"] tvm.ir.assert_structural_equal( - mod["main"].ret_struct_info, - relax.TupleStructInfo( + mod["main"].ret_ty, + relax.TupleType( [ - relax.TensorStructInfo((expected_batch, expected_max_detections, 4), "float32"), - relax.TensorStructInfo((expected_batch, expected_max_detections), "float32"), - relax.TensorStructInfo((expected_batch, expected_max_detections), "float32"), - relax.TensorStructInfo((expected_batch,), "float32"), + relax.TensorType((expected_batch, expected_max_detections, 4), "float32"), + relax.TensorType((expected_batch, expected_max_detections), "float32"), + relax.TensorType((expected_batch, expected_max_detections), "float32"), + relax.TensorType((expected_batch,), "float32"), ] ), ) @@ -3088,7 +3088,7 @@ def test_detection_postprocess_smoke(build_kwargs, expected_topk_count, expected legalized_ir = legalized.script() assert "R.vision.all_class_non_max_suppression(" not in legalized_ir assert "R.call_tir(" in legalized_ir - tvm.ir.assert_structural_equal(legalized["main"].ret_struct_info, mod["main"].ret_struct_info) + tvm.ir.assert_structural_equal(legalized["main"].ret_ty, mod["main"].ret_ty) @pytest.mark.parametrize("build_kwargs", _DETECTION_POSTPROCESS_SHAPE_CASES) @@ -3100,17 +3100,17 @@ def test_detection_postprocess_shape_variations(build_kwargs): max_detections = build_kwargs["max_detections"] tvm.ir.assert_structural_equal( - mod["main"].params[1].struct_info, - relax.TensorStructInfo((batch_size, num_anchors, input_num_classes), "float32"), + mod["main"].params[1].ty, + relax.TensorType((batch_size, num_anchors, input_num_classes), "float32"), ) tvm.ir.assert_structural_equal( - mod["main"].ret_struct_info, - relax.TupleStructInfo( + mod["main"].ret_ty, + relax.TupleType( [ - relax.TensorStructInfo((batch_size, max_detections, 4), "float32"), - relax.TensorStructInfo((batch_size, max_detections), "float32"), - relax.TensorStructInfo((batch_size, max_detections), "float32"), - relax.TensorStructInfo((batch_size,), "float32"), + relax.TensorType((batch_size, max_detections, 4), "float32"), + relax.TensorType((batch_size, max_detections), "float32"), + relax.TensorType((batch_size, max_detections), "float32"), + relax.TensorType((batch_size,), "float32"), ] ), ) @@ -3121,7 +3121,7 @@ def _make_resize_expected( ): """Build an Expected IRModule programmatically to avoid TVMScript variable scope limitations.""" bb = relax.BlockBuilder() - x = relax.Var("x", relax.TensorStructInfo(input_shape, "float32")) + x = relax.Var("x", relax.TensorType(input_shape, "float32")) with bb.function("main", [x]): with bb.dataflow(): gv = bb.emit_output( @@ -3273,7 +3273,7 @@ def _make_reduce_expected(relax_op, input_shape, axes, keepdims, dtype): if axes is None: axes = list(range(len(input_shape))) bb = relax.BlockBuilder() - x = relax.Var("x", relax.TensorStructInfo(input_shape, dtype)) + x = relax.Var("x", relax.TensorType(input_shape, dtype)) with bb.function("main", [x]): with bb.dataflow(): gv = bb.emit_output(relax_op(x, axis=axes, keepdims=keepdims)) @@ -3321,7 +3321,7 @@ def _make_reduce_bool_expected(relax_op, input_shape, axes, keepdims): if axes is None: axes = list(range(len(input_shape))) bb = relax.BlockBuilder() - x = relax.Var("x", relax.TensorStructInfo(input_shape, "bool")) + x = relax.Var("x", relax.TensorType(input_shape, "bool")) with bb.function("main", [x]): with bb.dataflow(): cast_in = bb.emit(relax.op.astype(x, "int8")) @@ -3584,8 +3584,8 @@ def func(self, x): assert "space_to_batch_nd" in ir assert len(mod["main"].params) == 1 tvm.ir.assert_structural_equal( - mod["main"].ret_struct_info, - relax.TensorStructInfo(expected_out_shape, "float32"), + mod["main"].ret_ty, + relax.TensorType(expected_out_shape, "float32"), ) if "CI_ENV_NIGHTLY" in os.environ: @@ -3618,8 +3618,8 @@ def func(self, x): assert "batch_to_space_nd" in ir assert len(mod["main"].params) == 1 tvm.ir.assert_structural_equal( - mod["main"].ret_struct_info, - relax.TensorStructInfo(expected_out_shape, "float32"), + mod["main"].ret_ty, + relax.TensorType(expected_out_shape, "float32"), ) if "CI_ENV_NIGHTLY" in os.environ: @@ -3882,7 +3882,7 @@ def main(diagonal: R.Tensor((3,), dtype="float32")) -> R.Tensor((3, 3), dtype="f R.const(False, "bool"), R.const(False, "bool"), ), - out_sinfo=R.Tensor((3, 3), dtype="float32"), + out_ty=R.Tensor((3, 3), dtype="float32"), ) R.output(gv) return gv @@ -3922,7 +3922,7 @@ def main( R.const(False, "bool"), R.const(False, "bool"), ), - out_sinfo=R.Tensor((3, 3), dtype="float32"), + out_ty=R.Tensor((3, 3), dtype="float32"), ) R.output(gv) return gv @@ -3963,7 +3963,7 @@ def main( gv = R.call_dps_packed( "topi.sparse_to_dense", (indices, R.const([3], "int32"), values, default_value), - out_sinfo=R.Tensor((3,), dtype="float32"), + out_ty=R.Tensor((3,), dtype="float32"), ) R.output(gv) return gv @@ -12340,11 +12340,11 @@ def test_svdf_none_activation(): fn = mod["main"] assert len(fn.params) == 2, f"expected 2 params (input, state), got {len(fn.params)}" - in_shape = fn.params[0].struct_info.shape + in_shape = fn.params[0].ty.shape assert tuple(int(d) for d in in_shape) == (batch, input_size) - state_shape = fn.params[1].struct_info.shape + state_shape = fn.params[1].ty.shape assert tuple(int(d) for d in state_shape) == (batch, num_filters * memory_size) - out_shape = fn.ret_struct_info.shape + out_shape = fn.ret_ty.shape assert tuple(int(d) for d in out_shape) == (batch, num_units) @@ -12798,8 +12798,8 @@ def test_unidirectional_sequence_lstm_time_major(): ) fn = mod["main"] - assert tuple(int(d) for d in fn.params[0].struct_info.shape) == (time, batch, input_size) - assert tuple(int(d) for d in fn.ret_struct_info.shape) == (time, batch, num_units) + assert tuple(int(d) for d in fn.params[0].ty.shape) == (time, batch, input_size) + assert tuple(int(d) for d in fn.ret_ty.shape) == (time, batch, num_units) def test_unidirectional_sequence_lstm_rejects_projection(): @@ -13070,8 +13070,8 @@ def test_bidirectional_sequence_rnn_time_major(): ) fn = mod["main"] - assert tuple(int(d) for d in fn.params[0].struct_info.shape) == (time, batch, input_size) - assert tuple(int(d) for d in fn.ret_struct_info.shape) == (time, batch, num_units * 2) + assert tuple(int(d) for d in fn.params[0].ty.shape) == (time, batch, input_size) + assert tuple(int(d) for d in fn.ret_ty.shape) == (time, batch, num_units * 2) def test_bidirectional_sequence_rnn_rejects_aux_input(): @@ -13368,8 +13368,8 @@ def test_bidirectional_sequence_lstm_time_major(): ) fn = mod["main"] - assert tuple(int(d) for d in fn.params[0].struct_info.shape) == (time, batch, input_size) - assert tuple(int(d) for d in fn.ret_struct_info.shape) == (time, batch, num_units * 2) + assert tuple(int(d) for d in fn.params[0].ty.shape) == (time, batch, input_size) + assert tuple(int(d) for d in fn.ret_ty.shape) == (time, batch, num_units * 2) def test_bidirectional_sequence_lstm_rejects_aux_input(): @@ -13578,9 +13578,9 @@ def test_unidirectional_sequence_rnn_relu_activation(): fn = mod["main"] assert len(fn.params) == 1, "only the sequence input should be a graph input" - in_shape = fn.params[0].struct_info.shape + in_shape = fn.params[0].ty.shape assert tuple(int(d) for d in in_shape) == (batch, time, input_size) - out_shape = fn.ret_struct_info.shape + out_shape = fn.ret_ty.shape assert tuple(int(d) for d in out_shape) == (batch, time, num_units) @@ -13610,10 +13610,10 @@ def test_unidirectional_sequence_rnn_time_major(): fn = mod["main"] # Input to the graph is the raw time-major tensor [time, batch, input_size]. - in_shape = fn.params[0].struct_info.shape + in_shape = fn.params[0].ty.shape assert tuple(int(d) for d in in_shape) == (time, batch, input_size) # Output is always batch-major [batch, time, num_units]. - out_shape = fn.ret_struct_info.shape + out_shape = fn.ret_ty.shape assert tuple(int(d) for d in out_shape) == (batch, time, num_units) diff --git a/tests/python/relax/test_inline_functions.py b/tests/python/relax/test_inline_functions.py index e4efd077d972..b50bfca60994 100644 --- a/tests/python/relax/test_inline_functions.py +++ b/tests/python/relax/test_inline_functions.py @@ -350,7 +350,7 @@ def main(): @R.function(private=True) def subroutine() -> R.Tensor([], "int64"): R.func_attr({"relax.force_pure": True}) - cond = R.call_packed("dummy_function", sinfo_args=R.Tensor([], "bool")) + cond = R.call_packed("dummy_function", ty_args=R.Tensor([], "bool")) if cond: Out = Before.subroutine() else: @@ -375,7 +375,7 @@ def main(): @R.function(private=True) def subroutine_a() -> R.Tensor([], "int64"): R.func_attr({"relax.force_pure": True}) - cond = R.call_packed("dummy_function", sinfo_args=R.Tensor([], "bool")) + cond = R.call_packed("dummy_function", ty_args=R.Tensor([], "bool")) if cond: Out = Before.subroutine_b() else: @@ -386,7 +386,7 @@ def subroutine_a() -> R.Tensor([], "int64"): @R.function(private=True) def subroutine_b() -> R.Tensor([], "int64"): R.func_attr({"relax.force_pure": True}) - cond = R.call_packed("dummy_function", sinfo_args=R.Tensor([], "bool")) + cond = R.call_packed("dummy_function", ty_args=R.Tensor([], "bool")) if cond: Out = Before.subroutine_a() else: diff --git a/tests/python/relax/test_kill_after_last_use.py b/tests/python/relax/test_kill_after_last_use.py index 7da3b5fe53f6..c69263977f79 100644 --- a/tests/python/relax/test_kill_after_last_use.py +++ b/tests/python/relax/test_kill_after_last_use.py @@ -30,7 +30,7 @@ class Before: def main(x: R.Tensor([16, 32], "float32")): storage = R.memory.alloc_storage(R.shape([2048]), 0, "global", "uint8") y = R.memory.alloc_tensor(storage, 0, R.shape([16, 32]), "float32") - _dummy = R.call_packed("add_tensors", [x, y], sinfo_args=(R.Tuple,)) + _dummy = R.call_packed("add_tensors", [x, y], ty_args=(R.Tuple,)) z = R.add(x, y) return z @@ -41,7 +41,7 @@ def main(x: R.Tensor([16, 32], "float32")): storage = R.memory.alloc_storage(R.shape([2048]), 0, "global", "uint8") y = R.memory.alloc_tensor(storage, 0, R.shape([16, 32]), "float32") _ = R.memory.kill_storage(storage) - _dummy = R.call_packed("add_tensors", [x, y], sinfo_args=(R.Tuple,)) + _dummy = R.call_packed("add_tensors", [x, y], ty_args=(R.Tuple,)) z = R.add(x, y) _ = R.memory.kill_tensor(y) return z diff --git a/tests/python/relax/test_op_binary.py b/tests/python/relax/test_op_binary.py index 5823044d8595..953e744fb7fd 100644 --- a/tests/python/relax/test_op_binary.py +++ b/tests/python/relax/test_op_binary.py @@ -61,9 +61,9 @@ def test_op_correctness(): assert relax.op.logical_xor(x, y).op == Op.get("relax.logical_xor") -def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): +def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_ty: relax.Type): ret = bb.normalize(call) - tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) + tvm.ir.assert_structural_equal(ret.ty, expected_ty) binary_arith_ops = [ @@ -82,7 +82,7 @@ def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: r @pytest.mark.parametrize("binary_arith_op", [row[0] for row in binary_arith_ops]) -def test_binary_arith_infer_struct_info(binary_arith_op: Callable): +def test_binary_arith_infer_ty(binary_arith_op: Callable): bb = relax.BlockBuilder() vdevice0 = VDevice("llvm") vdevice1 = VDevice("cuda", 0) @@ -101,54 +101,52 @@ def test_binary_arith_infer_struct_info(binary_arith_op: Callable): y4 = relax.Var("y", R.Tensor((2, 3), "float32", vdevice0)) y5 = relax.Var("y", R.Tensor("float32", ndim=2, vdevice=vdevice0)) - _check_inference(bb, binary_arith_op(x0, y0), relax.TensorStructInfo((2, 3), "float32")) - _check_inference(bb, binary_arith_op(x1, y0), relax.TensorStructInfo((2, 3), "float32")) - _check_inference(bb, binary_arith_op(x1, y1), relax.TensorStructInfo((4, 3, 2, 3), "float32")) - _check_inference(bb, binary_arith_op(x2, y2), relax.TensorStructInfo(dtype="float32", ndim=3)) - _check_inference(bb, binary_arith_op(x3, y2), relax.TensorStructInfo(dtype="float32", ndim=3)) - _check_inference(bb, binary_arith_op(x4, y0), relax.TensorStructInfo(dtype="float32", ndim=2)) - _check_inference(bb, binary_arith_op(x4, y1), relax.TensorStructInfo(dtype="float32", ndim=4)) - _check_inference(bb, binary_arith_op(x4, y2), relax.TensorStructInfo(dtype="float32", ndim=2)) - _check_inference(bb, binary_arith_op(x4, y3), relax.TensorStructInfo(dtype="float32", ndim=-1)) - _check_inference(bb, binary_arith_op(x5, y0), relax.TensorStructInfo(dtype="", ndim=-1)) + _check_inference(bb, binary_arith_op(x0, y0), relax.TensorType((2, 3), "float32")) + _check_inference(bb, binary_arith_op(x1, y0), relax.TensorType((2, 3), "float32")) + _check_inference(bb, binary_arith_op(x1, y1), relax.TensorType((4, 3, 2, 3), "float32")) + _check_inference(bb, binary_arith_op(x2, y2), relax.TensorType(dtype="float32", ndim=3)) + _check_inference(bb, binary_arith_op(x3, y2), relax.TensorType(dtype="float32", ndim=3)) + _check_inference(bb, binary_arith_op(x4, y0), relax.TensorType(dtype="float32", ndim=2)) + _check_inference(bb, binary_arith_op(x4, y1), relax.TensorType(dtype="float32", ndim=4)) + _check_inference(bb, binary_arith_op(x4, y2), relax.TensorType(dtype="float32", ndim=2)) + _check_inference(bb, binary_arith_op(x4, y3), relax.TensorType(dtype="float32", ndim=-1)) + _check_inference(bb, binary_arith_op(x5, y0), relax.TensorType(dtype="", ndim=-1)) _check_inference( bb, binary_arith_op(x6, y5), - relax.TensorStructInfo(dtype="float32", ndim=2, vdevice=vdevice0), + relax.TensorType(dtype="float32", ndim=2, vdevice=vdevice0), ) _check_inference( bb, binary_arith_op(x6, y2), - relax.TensorStructInfo(dtype="float32", ndim=2, vdevice=vdevice0), - ) - _check_inference( - bb, binary_arith_op(x7, y4), relax.TensorStructInfo((2, 3), "float32", vdevice0) + relax.TensorType(dtype="float32", ndim=2, vdevice=vdevice0), ) + _check_inference(bb, binary_arith_op(x7, y4), relax.TensorType((2, 3), "float32", vdevice0)) @pytest.mark.parametrize("binary_arith_op", [row[0] for row in binary_arith_ops]) -def test_infer_struct_info_binary_arith_prim_value_with_tensor(binary_arith_op: Callable): +def test_infer_ty_binary_arith_prim_value_with_tensor(binary_arith_op: Callable): bb = relax.BlockBuilder() x = relax.Var("x", R.Tensor((2, 3), "float32")) y = relax.Var("y", R.Prim("float32")) - _check_inference(bb, binary_arith_op(x, y), relax.TensorStructInfo((2, 3), "float32")) + _check_inference(bb, binary_arith_op(x, y), relax.TensorType((2, 3), "float32")) @pytest.mark.parametrize("binary_arith_op", [row[0] for row in binary_arith_ops]) -def test_infer_struct_info_binary_arith_prim_value_with_prim_value(binary_arith_op: Callable): +def test_infer_ty_binary_arith_prim_value_with_prim_value(binary_arith_op: Callable): bb = relax.BlockBuilder() x = relax.Var("x", R.Prim("float32")) y = relax.Var("y", R.Prim("float32")) - _check_inference(bb, binary_arith_op(x, y), relax.PrimStructInfo("float32")) + _check_inference(bb, binary_arith_op(x, y), relax.PrimType("float32")) @pytest.mark.parametrize("binary_arith_op,tir_arith_op", binary_arith_ops) @pytest.mark.xfail(reason="Not yet implemented") -def test_infer_struct_info_binary_arith_known_prim_value_with_prim_value( +def test_infer_ty_binary_arith_known_prim_value_with_prim_value( binary_arith_op: Callable, tir_arith_op ): bb = relax.BlockBuilder() @@ -159,8 +157,8 @@ def test_infer_struct_info_binary_arith_known_prim_value_with_prim_value( x = relax.Var("x", R.Prim(value=tir_x)) y = relax.Var("y", R.Prim(value=tir_y)) - _check_inference(bb, binary_arith_op(x, y), relax.PrimStructInfo(value=tir_x + tir_y)) - _check_inference(bb, binary_arith_op(y, x), relax.PrimStructInfo(value=tir_y + tir_x)) + _check_inference(bb, binary_arith_op(x, y), relax.PrimType(value=tir_x + tir_y)) + _check_inference(bb, binary_arith_op(y, x), relax.PrimType(value=tir_y + tir_x)) binary_cmp_ops = [ @@ -174,45 +172,43 @@ def test_infer_struct_info_binary_arith_known_prim_value_with_prim_value( @pytest.mark.parametrize("binary_cmp_op", [row[0] for row in binary_cmp_ops]) -def test_binary_cmp_infer_struct_info(binary_cmp_op: Callable): +def test_binary_cmp_infer_ty(binary_cmp_op: Callable): bb = relax.BlockBuilder() vdev0 = VDevice("llvm") x = relax.Var("x", R.Tensor((2, 3), "float32")) y0 = relax.Var("y", R.Tensor((2, 3), "float32")) y1 = relax.Var("y", R.Tensor((2, 3), "int32")) y2 = relax.Var("y", R.Tensor((2, 3), "float32", vdev0)) - _check_inference(bb, binary_cmp_op(x, y0), relax.TensorStructInfo((2, 3), "bool")) - _check_inference(bb, binary_cmp_op(x, y1), relax.TensorStructInfo((2, 3), "bool")) - _check_inference(bb, binary_cmp_op(x, y0), relax.TensorStructInfo((2, 3), "bool")) - _check_inference(bb, binary_cmp_op(x, y1), relax.TensorStructInfo((2, 3), "bool")) - _check_inference(bb, binary_cmp_op(x, y0), relax.TensorStructInfo((2, 3), "bool")) - _check_inference(bb, binary_cmp_op(x, y1), relax.TensorStructInfo((2, 3), "bool")) - _check_inference(bb, binary_cmp_op(x, y2), relax.TensorStructInfo((2, 3), "bool", vdev0)) + _check_inference(bb, binary_cmp_op(x, y0), relax.TensorType((2, 3), "bool")) + _check_inference(bb, binary_cmp_op(x, y1), relax.TensorType((2, 3), "bool")) + _check_inference(bb, binary_cmp_op(x, y0), relax.TensorType((2, 3), "bool")) + _check_inference(bb, binary_cmp_op(x, y1), relax.TensorType((2, 3), "bool")) + _check_inference(bb, binary_cmp_op(x, y0), relax.TensorType((2, 3), "bool")) + _check_inference(bb, binary_cmp_op(x, y1), relax.TensorType((2, 3), "bool")) + _check_inference(bb, binary_cmp_op(x, y2), relax.TensorType((2, 3), "bool", vdev0)) @pytest.mark.parametrize("binary_cmp_op", [row[0] for row in binary_cmp_ops]) -def test_infer_struct_info_binary_cmp_prim_value_to_tensor(binary_cmp_op: Callable): +def test_infer_ty_binary_cmp_prim_value_to_tensor(binary_cmp_op: Callable): bb = relax.BlockBuilder() x = relax.Var("x", R.Tensor((2, 3), "float32")) y = relax.Var("y", R.Prim("float32")) - _check_inference(bb, binary_cmp_op(x, y), relax.TensorStructInfo((2, 3), "bool")) - _check_inference(bb, binary_cmp_op(y, x), relax.TensorStructInfo((2, 3), "bool")) + _check_inference(bb, binary_cmp_op(x, y), relax.TensorType((2, 3), "bool")) + _check_inference(bb, binary_cmp_op(y, x), relax.TensorType((2, 3), "bool")) @pytest.mark.parametrize("binary_cmp_op", [row[0] for row in binary_cmp_ops]) -def test_infer_struct_info_binary_cmp_prim_value_to_prim_value(binary_cmp_op: Callable): +def test_infer_ty_binary_cmp_prim_value_to_prim_value(binary_cmp_op: Callable): bb = relax.BlockBuilder() x = relax.Var("x", R.Prim("float32")) y = relax.Var("y", R.Prim("float32")) - _check_inference(bb, binary_cmp_op(x, y), relax.PrimStructInfo("bool")) - _check_inference(bb, binary_cmp_op(y, x), relax.PrimStructInfo("bool")) + _check_inference(bb, binary_cmp_op(x, y), relax.PrimType("bool")) + _check_inference(bb, binary_cmp_op(y, x), relax.PrimType("bool")) @pytest.mark.parametrize("binary_cmp_op,tir_cmp_op", binary_cmp_ops) @pytest.mark.xfail(reason="Not yet implemented") -def test_infer_struct_info_binary_cmp_known_prim_value_to_prim_value( - binary_cmp_op: Callable, tir_cmp_op -): +def test_infer_ty_binary_cmp_known_prim_value_to_prim_value(binary_cmp_op: Callable, tir_cmp_op): bb = relax.BlockBuilder() tir_x = tirx.Var("tir_x", "float32") @@ -221,12 +217,12 @@ def test_infer_struct_info_binary_cmp_known_prim_value_to_prim_value( x = relax.Var("x", R.Prim(value=tir_x)) y = relax.Var("y", R.Prim(value=tir_y)) - _check_inference(bb, binary_cmp_op(x, y), relax.PrimStructInfo(value=tir_cmp_op(tir_x, tir_y))) - _check_inference(bb, binary_cmp_op(y, x), relax.PrimStructInfo(value=tir_cmp_op(tir_y, tir_x))) + _check_inference(bb, binary_cmp_op(x, y), relax.PrimType(value=tir_cmp_op(tir_x, tir_y))) + _check_inference(bb, binary_cmp_op(y, x), relax.PrimType(value=tir_cmp_op(tir_y, tir_x))) @pytest.mark.parametrize("binary_arith_op", [row[0] for row in binary_arith_ops]) -def test_binary_infer_struct_info_shape_symbolic(binary_arith_op: Callable): +def test_binary_infer_ty_shape_symbolic(binary_arith_op: Callable): bb = relax.BlockBuilder() m = tirx.Var("m", "int64") n = tirx.Var("n", "int64") @@ -241,43 +237,43 @@ def test_binary_infer_struct_info_shape_symbolic(binary_arith_op: Callable): y2 = relax.Var("y", R.Tensor((4, k, m, 1), "float32")) y3 = relax.Var("y", R.Tensor("float32", ndim=2)) y4 = relax.Var("y", R.Tensor("float32", ndim=-1)) - _check_inference(bb, binary_arith_op(x0, y0), relax.TensorStructInfo((m, n), "float32")) - _check_inference(bb, binary_arith_op(x0, y1), relax.TensorStructInfo(dtype="float32", ndim=2)) - _check_inference(bb, binary_arith_op(x1, y0), relax.TensorStructInfo((m, n), "float32")) - _check_inference(bb, binary_arith_op(x1, y2), relax.TensorStructInfo((4, k, m, n), "float32")) - _check_inference(bb, binary_arith_op(x2, y2), relax.TensorStructInfo(dtype="float32", ndim=4)) - _check_inference(bb, binary_arith_op(x2, y3), relax.TensorStructInfo(dtype="float32", ndim=3)) - _check_inference(bb, binary_arith_op(x3, y3), relax.TensorStructInfo(dtype="float32", ndim=3)) - _check_inference(bb, binary_arith_op(x4, y0), relax.TensorStructInfo(dtype="float32", ndim=2)) - _check_inference(bb, binary_arith_op(x4, y2), relax.TensorStructInfo(dtype="float32", ndim=4)) - _check_inference(bb, binary_arith_op(x4, y3), relax.TensorStructInfo(dtype="float32", ndim=2)) - _check_inference(bb, binary_arith_op(x4, y4), relax.TensorStructInfo(dtype="float32", ndim=-1)) + _check_inference(bb, binary_arith_op(x0, y0), relax.TensorType((m, n), "float32")) + _check_inference(bb, binary_arith_op(x0, y1), relax.TensorType(dtype="float32", ndim=2)) + _check_inference(bb, binary_arith_op(x1, y0), relax.TensorType((m, n), "float32")) + _check_inference(bb, binary_arith_op(x1, y2), relax.TensorType((4, k, m, n), "float32")) + _check_inference(bb, binary_arith_op(x2, y2), relax.TensorType(dtype="float32", ndim=4)) + _check_inference(bb, binary_arith_op(x2, y3), relax.TensorType(dtype="float32", ndim=3)) + _check_inference(bb, binary_arith_op(x3, y3), relax.TensorType(dtype="float32", ndim=3)) + _check_inference(bb, binary_arith_op(x4, y0), relax.TensorType(dtype="float32", ndim=2)) + _check_inference(bb, binary_arith_op(x4, y2), relax.TensorType(dtype="float32", ndim=4)) + _check_inference(bb, binary_arith_op(x4, y3), relax.TensorType(dtype="float32", ndim=2)) + _check_inference(bb, binary_arith_op(x4, y4), relax.TensorType(dtype="float32", ndim=-1)) @pytest.mark.parametrize("binary_arith_op", [row[0] for row in binary_arith_ops]) -def test_binary_infer_struct_info_shape_var(binary_arith_op: Callable): +def test_binary_infer_ty_shape_var(binary_arith_op: Callable): bb = relax.BlockBuilder() - s0 = relax.Var("s0", relax.ShapeStructInfo(ndim=2)) - s1 = relax.Var("s1", relax.ShapeStructInfo(ndim=2)) - s2 = relax.Var("s2", relax.ShapeStructInfo(ndim=4)) - s3 = relax.Var("s3", relax.ShapeStructInfo(ndim=1)) - s4 = relax.Var("s4", relax.ShapeStructInfo()) - x = relax.Var("x", relax.TensorStructInfo(s0, "float32")) - y0 = relax.Var("y", relax.TensorStructInfo(s0, "float32")) - y1 = relax.Var("y", relax.TensorStructInfo(s1, "float32")) - y2 = relax.Var("y", relax.TensorStructInfo(s2, "float32")) - y3 = relax.Var("y", relax.TensorStructInfo(s3, "float32")) - y4 = relax.Var("y", relax.TensorStructInfo(s4, "float32")) - - _check_inference(bb, binary_arith_op(x, y0), relax.TensorStructInfo(s0, "float32")) - _check_inference(bb, binary_arith_op(x, y1), relax.TensorStructInfo(dtype="float32", ndim=2)) - _check_inference(bb, binary_arith_op(x, y2), relax.TensorStructInfo(dtype="float32", ndim=4)) - _check_inference(bb, binary_arith_op(x, y3), relax.TensorStructInfo(dtype="float32", ndim=2)) - _check_inference(bb, binary_arith_op(x, y4), relax.TensorStructInfo(dtype="float32")) + s0 = relax.Var("s0", relax.ShapeType(ndim=2)) + s1 = relax.Var("s1", relax.ShapeType(ndim=2)) + s2 = relax.Var("s2", relax.ShapeType(ndim=4)) + s3 = relax.Var("s3", relax.ShapeType(ndim=1)) + s4 = relax.Var("s4", relax.ShapeType()) + x = relax.Var("x", relax.TensorType(s0, "float32")) + y0 = relax.Var("y", relax.TensorType(s0, "float32")) + y1 = relax.Var("y", relax.TensorType(s1, "float32")) + y2 = relax.Var("y", relax.TensorType(s2, "float32")) + y3 = relax.Var("y", relax.TensorType(s3, "float32")) + y4 = relax.Var("y", relax.TensorType(s4, "float32")) + + _check_inference(bb, binary_arith_op(x, y0), relax.TensorType(s0, "float32")) + _check_inference(bb, binary_arith_op(x, y1), relax.TensorType(dtype="float32", ndim=2)) + _check_inference(bb, binary_arith_op(x, y2), relax.TensorType(dtype="float32", ndim=4)) + _check_inference(bb, binary_arith_op(x, y3), relax.TensorType(dtype="float32", ndim=2)) + _check_inference(bb, binary_arith_op(x, y4), relax.TensorType(dtype="float32")) @pytest.mark.parametrize("binary_arith_op", [row[0] for row in binary_arith_ops]) -def test_binary_arith_infer_struct_info_more_input_dtype(binary_arith_op: Callable): +def test_binary_arith_infer_ty_more_input_dtype(binary_arith_op: Callable): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3), "float64")) y0 = relax.Var("y", R.Tensor((2, 3), "float64")) @@ -286,13 +282,13 @@ def test_binary_arith_infer_struct_info_more_input_dtype(binary_arith_op: Callab x2 = relax.Var("x", R.Tensor((2, 3), "int64")) y2 = relax.Var("y", R.Tensor((2, 3), "int64")) - _check_inference(bb, binary_arith_op(x0, y0), relax.TensorStructInfo((2, 3), "float64")) - _check_inference(bb, binary_arith_op(x1, y1), relax.TensorStructInfo((2, 3), "int8")) - _check_inference(bb, binary_arith_op(x2, y2), relax.TensorStructInfo((2, 3), "int64")) + _check_inference(bb, binary_arith_op(x0, y0), relax.TensorType((2, 3), "float64")) + _check_inference(bb, binary_arith_op(x1, y1), relax.TensorType((2, 3), "int8")) + _check_inference(bb, binary_arith_op(x2, y2), relax.TensorType((2, 3), "int64")) @pytest.mark.parametrize("binary_arith_op", [row[0] for row in binary_arith_ops]) -def test_binary_infer_struct_info_shape_unequal_const_int(binary_arith_op: Callable): +def test_binary_infer_ty_shape_unequal_const_int(binary_arith_op: Callable): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3), "float32")) y0 = relax.Var("y", R.Tensor((2, 4), "float32")) @@ -301,7 +297,7 @@ def test_binary_infer_struct_info_shape_unequal_const_int(binary_arith_op: Calla @pytest.mark.parametrize("binary_arith_op", [row[0] for row in binary_arith_ops]) -def test_binary_arith_infer_struct_info_dtype_mismatch(binary_arith_op: Callable): +def test_binary_arith_infer_ty_dtype_mismatch(binary_arith_op: Callable): bb = relax.BlockBuilder() x = relax.Var("x", R.Tensor((2, 3), "float32")) y = relax.Var("y", R.Tensor((2, 3), "int32")) @@ -310,7 +306,7 @@ def test_binary_arith_infer_struct_info_dtype_mismatch(binary_arith_op: Callable @pytest.mark.parametrize("binary_arith_op", [row[0] for row in binary_arith_ops]) -def test_binary_arith_infer_struct_info_vdevice_mismatch(binary_arith_op: Callable): +def test_binary_arith_infer_ty_vdevice_mismatch(binary_arith_op: Callable): bb = relax.BlockBuilder() x = relax.Var("x", R.Tensor((2, 3), "float32", VDevice("llvm"))) y = relax.Var("y", R.Tensor((2, 3), "int32", VDevice("cuda"))) @@ -331,10 +327,10 @@ def test_binary_wrong_input_number(binary_arith_op: Callable): @pytest.mark.parametrize("binary_arith_op", [row[0] for row in binary_arith_ops]) -def test_binary_infer_struct_info_wrong_input_type(binary_arith_op: Callable): +def test_binary_infer_ty_wrong_input_type(binary_arith_op: Callable): bb = relax.BlockBuilder() - x0 = relax.Var("x", relax.ShapeStructInfo((2, 3))) - x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3), "float32"))) + x0 = relax.Var("x", relax.ShapeType((2, 3))) + x1 = relax.Var("x", relax.FuncType([], R.Tensor((2, 3), "float32"))) y = relax.Var("y", R.Tensor((2, 3), "float32")) with pytest.raises(TypeError): diff --git a/tests/python/relax/test_op_ccl.py b/tests/python/relax/test_op_ccl.py index 2e29803a49ae..f9ce60de1bd2 100644 --- a/tests/python/relax/test_op_ccl.py +++ b/tests/python/relax/test_op_ccl.py @@ -31,12 +31,12 @@ def test_op_correctness(): assert relax.op.ccl.allgather(x, 2).op == Op.get("relax.ccl.allgather") -def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): +def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_ty: relax.Type): ret = bb.normalize(call) - tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) + tvm.ir.assert_structural_equal(ret.ty, expected_ty) -def test_allreduce_infer_struct_info(): +def test_allreduce_infer_ty(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3), "float32")) x1 = relax.Var("x", R.Tensor("float32", ndim=3)) @@ -45,50 +45,48 @@ def test_allreduce_infer_struct_info(): x4 = relax.Var("x", R.Tensor()) x5 = relax.Var("x", R.Tensor((3, 4))) - _check_inference(bb, relax.op.ccl.allreduce(x0), relax.TensorStructInfo((2, 3), "float32")) - _check_inference( - bb, relax.op.ccl.allreduce(x1), relax.TensorStructInfo(dtype="float32", ndim=3) - ) - _check_inference(bb, relax.op.ccl.allreduce(x2), relax.TensorStructInfo(dtype="float32")) - _check_inference(bb, relax.op.ccl.allreduce(x3), relax.TensorStructInfo((2, 3), dtype="")) - _check_inference(bb, relax.op.ccl.allreduce(x4), relax.TensorStructInfo(dtype="")) - _check_inference(bb, relax.op.ccl.allreduce(x5), relax.TensorStructInfo((3, 4), dtype="")) + _check_inference(bb, relax.op.ccl.allreduce(x0), relax.TensorType((2, 3), "float32")) + _check_inference(bb, relax.op.ccl.allreduce(x1), relax.TensorType(dtype="float32", ndim=3)) + _check_inference(bb, relax.op.ccl.allreduce(x2), relax.TensorType(dtype="float32")) + _check_inference(bb, relax.op.ccl.allreduce(x3), relax.TensorType((2, 3), dtype="")) + _check_inference(bb, relax.op.ccl.allreduce(x4), relax.TensorType(dtype="")) + _check_inference(bb, relax.op.ccl.allreduce(x5), relax.TensorType((3, 4), dtype="")) -def test_allreduce_infer_struct_info_shape_symbolic(): +def test_allreduce_infer_ty_shape_symbolic(): bb = relax.BlockBuilder() m = tirx.Var("m", "int64") n = tirx.Var("n", "int64") x0 = relax.Var("x", R.Tensor((m, n), "float32")) x1 = relax.Var("x", R.Tensor((4, n), "float32")) - _check_inference(bb, relax.op.ccl.allreduce(x0), relax.TensorStructInfo((m, n), "float32")) - _check_inference(bb, relax.op.ccl.allreduce(x1), relax.TensorStructInfo((4, n), "float32")) + _check_inference(bb, relax.op.ccl.allreduce(x0), relax.TensorType((m, n), "float32")) + _check_inference(bb, relax.op.ccl.allreduce(x1), relax.TensorType((4, n), "float32")) -def test_allreduce_infer_struct_info_shape_var(): +def test_allreduce_infer_ty_shape_var(): bb = relax.BlockBuilder() - s0 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) - s1 = relax.Var("s", relax.ShapeStructInfo()) - x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) - x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + s0 = relax.Var("s", relax.ShapeType(ndim=2)) + s1 = relax.Var("s", relax.ShapeType()) + x0 = relax.Var("x", relax.TensorType(s0, "float32")) + x1 = relax.Var("x", relax.TensorType(s1, "float32")) - _check_inference(bb, relax.op.ccl.allreduce(x0), relax.TensorStructInfo(s0, "float32")) - _check_inference(bb, relax.op.ccl.allreduce(x1), relax.TensorStructInfo(s1, "float32")) + _check_inference(bb, relax.op.ccl.allreduce(x0), relax.TensorType(s0, "float32")) + _check_inference(bb, relax.op.ccl.allreduce(x1), relax.TensorType(s1, "float32")) -def test_allreduce_infer_struct_info_more_input_dtype(): +def test_allreduce_infer_ty_more_input_dtype(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3), "float64")) x1 = relax.Var("x", R.Tensor((2, 3), "int8")) x2 = relax.Var("x", R.Tensor((2, 3), "int64")) - _check_inference(bb, relax.op.ccl.allreduce(x0), relax.TensorStructInfo((2, 3), "float64")) - _check_inference(bb, relax.op.ccl.allreduce(x1), relax.TensorStructInfo((2, 3), "int8")) - _check_inference(bb, relax.op.ccl.allreduce(x2), relax.TensorStructInfo((2, 3), "int64")) + _check_inference(bb, relax.op.ccl.allreduce(x0), relax.TensorType((2, 3), "float64")) + _check_inference(bb, relax.op.ccl.allreduce(x1), relax.TensorType((2, 3), "int8")) + _check_inference(bb, relax.op.ccl.allreduce(x2), relax.TensorType((2, 3), "int64")) -def test_allgather_infer_struct_info(): +def test_allgather_infer_ty(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3), "float32")) x1 = relax.Var("x", R.Tensor("float32", ndim=3)) @@ -97,52 +95,48 @@ def test_allgather_infer_struct_info(): x4 = relax.Var("x", R.Tensor()) x5 = relax.Var("x", R.Tensor((3, 4))) - _check_inference(bb, relax.op.ccl.allgather(x0, 2), relax.TensorStructInfo((4, 3), "float32")) - _check_inference( - bb, relax.op.ccl.allgather(x1, 2), relax.TensorStructInfo(dtype="float32", ndim=3) - ) - _check_inference(bb, relax.op.ccl.allgather(x2, 2), relax.TensorStructInfo(dtype="float32")) - _check_inference(bb, relax.op.ccl.allgather(x3, 2), relax.TensorStructInfo((4, 3), dtype="")) - _check_inference(bb, relax.op.ccl.allgather(x4, 2), relax.TensorStructInfo(dtype="")) - _check_inference(bb, relax.op.ccl.allgather(x5, 2), relax.TensorStructInfo((6, 4), dtype="")) + _check_inference(bb, relax.op.ccl.allgather(x0, 2), relax.TensorType((4, 3), "float32")) + _check_inference(bb, relax.op.ccl.allgather(x1, 2), relax.TensorType(dtype="float32", ndim=3)) + _check_inference(bb, relax.op.ccl.allgather(x2, 2), relax.TensorType(dtype="float32")) + _check_inference(bb, relax.op.ccl.allgather(x3, 2), relax.TensorType((4, 3), dtype="")) + _check_inference(bb, relax.op.ccl.allgather(x4, 2), relax.TensorType(dtype="")) + _check_inference(bb, relax.op.ccl.allgather(x5, 2), relax.TensorType((6, 4), dtype="")) -def test_allgather_infer_struct_info_shape_symbolic(): +def test_allgather_infer_ty_shape_symbolic(): bb = relax.BlockBuilder() m = tirx.Var("m", "int64") n = tirx.Var("n", "int64") x0 = relax.Var("x", R.Tensor((m, n), "float32")) x1 = relax.Var("x", R.Tensor((4, n), "float32")) - _check_inference( - bb, relax.op.ccl.allgather(x0, 2), relax.TensorStructInfo((m * 2, n), "float32") - ) - _check_inference(bb, relax.op.ccl.allgather(x1, 2), relax.TensorStructInfo((8, n), "float32")) + _check_inference(bb, relax.op.ccl.allgather(x0, 2), relax.TensorType((m * 2, n), "float32")) + _check_inference(bb, relax.op.ccl.allgather(x1, 2), relax.TensorType((8, n), "float32")) -def test_allgather_infer_struct_info_shape_var(): +def test_allgather_infer_ty_shape_var(): bb = relax.BlockBuilder() - s0 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) - s1 = relax.Var("s", relax.ShapeStructInfo()) - x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) - x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + s0 = relax.Var("s", relax.ShapeType(ndim=2)) + s1 = relax.Var("s", relax.ShapeType()) + x0 = relax.Var("x", relax.TensorType(s0, "float32")) + x1 = relax.Var("x", relax.TensorType(s1, "float32")) - _check_inference(bb, relax.op.ccl.allgather(x0, 2), relax.TensorStructInfo(s0, "float32")) - _check_inference(bb, relax.op.ccl.allgather(x1, 2), relax.TensorStructInfo(s1, "float32")) + _check_inference(bb, relax.op.ccl.allgather(x0, 2), relax.TensorType(s0, "float32")) + _check_inference(bb, relax.op.ccl.allgather(x1, 2), relax.TensorType(s1, "float32")) -def test_allgather_infer_struct_info_more_input_dtype(): +def test_allgather_infer_ty_more_input_dtype(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3), "float64")) x1 = relax.Var("x", R.Tensor((2, 3), "int8")) x2 = relax.Var("x", R.Tensor((2, 3), "int64")) - _check_inference(bb, relax.op.ccl.allgather(x0, 2), relax.TensorStructInfo((4, 3), "float64")) - _check_inference(bb, relax.op.ccl.allgather(x1, 2), relax.TensorStructInfo((4, 3), "int8")) - _check_inference(bb, relax.op.ccl.allgather(x2, 2), relax.TensorStructInfo((4, 3), "int64")) + _check_inference(bb, relax.op.ccl.allgather(x0, 2), relax.TensorType((4, 3), "float64")) + _check_inference(bb, relax.op.ccl.allgather(x1, 2), relax.TensorType((4, 3), "int8")) + _check_inference(bb, relax.op.ccl.allgather(x2, 2), relax.TensorType((4, 3), "int64")) -def test_broadcast_from_worker0_infer_struct_info(): +def test_broadcast_from_worker0_infer_ty(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3), "float32")) x1 = relax.Var("x", R.Tensor("float32", ndim=3)) @@ -152,24 +146,22 @@ def test_broadcast_from_worker0_infer_struct_info(): x5 = relax.Var("x", R.Tensor((3, 4))) _check_inference( - bb, relax.op.ccl.broadcast_from_worker0(x0), relax.TensorStructInfo((2, 3), "float32") - ) - _check_inference( - bb, relax.op.ccl.broadcast_from_worker0(x1), relax.TensorStructInfo(dtype="float32", ndim=3) + bb, relax.op.ccl.broadcast_from_worker0(x0), relax.TensorType((2, 3), "float32") ) _check_inference( - bb, relax.op.ccl.broadcast_from_worker0(x2), relax.TensorStructInfo(dtype="float32") + bb, relax.op.ccl.broadcast_from_worker0(x1), relax.TensorType(dtype="float32", ndim=3) ) + _check_inference(bb, relax.op.ccl.broadcast_from_worker0(x2), relax.TensorType(dtype="float32")) _check_inference( - bb, relax.op.ccl.broadcast_from_worker0(x3), relax.TensorStructInfo((2, 3), dtype="") + bb, relax.op.ccl.broadcast_from_worker0(x3), relax.TensorType((2, 3), dtype="") ) - _check_inference(bb, relax.op.ccl.broadcast_from_worker0(x4), relax.TensorStructInfo(dtype="")) + _check_inference(bb, relax.op.ccl.broadcast_from_worker0(x4), relax.TensorType(dtype="")) _check_inference( - bb, relax.op.ccl.broadcast_from_worker0(x5), relax.TensorStructInfo((3, 4), dtype="") + bb, relax.op.ccl.broadcast_from_worker0(x5), relax.TensorType((3, 4), dtype="") ) -def test_broadcast_from_worker0_infer_struct_info_shape_symbolic(): +def test_broadcast_from_worker0_infer_ty_shape_symbolic(): bb = relax.BlockBuilder() m = tirx.Var("m", "int64") n = tirx.Var("n", "int64") @@ -177,59 +169,51 @@ def test_broadcast_from_worker0_infer_struct_info_shape_symbolic(): x1 = relax.Var("x", R.Tensor((4, n), "float32")) _check_inference( - bb, relax.op.ccl.broadcast_from_worker0(x0), relax.TensorStructInfo((m, n), "float32") + bb, relax.op.ccl.broadcast_from_worker0(x0), relax.TensorType((m, n), "float32") ) _check_inference( - bb, relax.op.ccl.broadcast_from_worker0(x1), relax.TensorStructInfo((4, n), "float32") + bb, relax.op.ccl.broadcast_from_worker0(x1), relax.TensorType((4, n), "float32") ) -def test_broadcast_from_worker0_infer_struct_info_shape_var(): +def test_broadcast_from_worker0_infer_ty_shape_var(): bb = relax.BlockBuilder() - s0 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) - s1 = relax.Var("s", relax.ShapeStructInfo()) - x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) - x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + s0 = relax.Var("s", relax.ShapeType(ndim=2)) + s1 = relax.Var("s", relax.ShapeType()) + x0 = relax.Var("x", relax.TensorType(s0, "float32")) + x1 = relax.Var("x", relax.TensorType(s1, "float32")) - _check_inference( - bb, relax.op.ccl.broadcast_from_worker0(x0), relax.TensorStructInfo(s0, "float32") - ) - _check_inference( - bb, relax.op.ccl.broadcast_from_worker0(x1), relax.TensorStructInfo(s1, "float32") - ) + _check_inference(bb, relax.op.ccl.broadcast_from_worker0(x0), relax.TensorType(s0, "float32")) + _check_inference(bb, relax.op.ccl.broadcast_from_worker0(x1), relax.TensorType(s1, "float32")) -def test_broadcast_from_worker0_infer_struct_info_more_input_dtype(): +def test_broadcast_from_worker0_infer_ty_more_input_dtype(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3), "float64")) x1 = relax.Var("x", R.Tensor((2, 3), "int8")) x2 = relax.Var("x", R.Tensor((2, 3), "int64")) _check_inference( - bb, relax.op.ccl.broadcast_from_worker0(x0), relax.TensorStructInfo((2, 3), "float64") - ) - _check_inference( - bb, relax.op.ccl.broadcast_from_worker0(x1), relax.TensorStructInfo((2, 3), "int8") - ) - _check_inference( - bb, relax.op.ccl.broadcast_from_worker0(x2), relax.TensorStructInfo((2, 3), "int64") + bb, relax.op.ccl.broadcast_from_worker0(x0), relax.TensorType((2, 3), "float64") ) + _check_inference(bb, relax.op.ccl.broadcast_from_worker0(x1), relax.TensorType((2, 3), "int8")) + _check_inference(bb, relax.op.ccl.broadcast_from_worker0(x2), relax.TensorType((2, 3), "int64")) -def test_scatter_from_worker0_infer_struct_info(): +def test_scatter_from_worker0_infer_ty(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3), "float32")) x1 = relax.Var("x", R.Tensor((3, 4, 5))) _check_inference( - bb, relax.op.ccl.scatter_from_worker0(x0, 2), relax.TensorStructInfo((1, 3), "float32") + bb, relax.op.ccl.scatter_from_worker0(x0, 2), relax.TensorType((1, 3), "float32") ) _check_inference( - bb, relax.op.ccl.scatter_from_worker0(x1, 3), relax.TensorStructInfo((1, 4, 5), dtype="") + bb, relax.op.ccl.scatter_from_worker0(x1, 3), relax.TensorType((1, 4, 5), dtype="") ) -def test_scatter_from_worker0_infer_struct_info_shape_symbolic(): +def test_scatter_from_worker0_infer_ty_shape_symbolic(): bb = relax.BlockBuilder() m = tirx.Var("m", "int64") n = tirx.Var("n", "int64") @@ -239,37 +223,35 @@ def test_scatter_from_worker0_infer_struct_info_shape_symbolic(): _check_inference( bb, relax.op.ccl.scatter_from_worker0(x0, 2), - relax.TensorStructInfo((tirx.div(m, 2), n), "float32"), + relax.TensorType((tirx.div(m, 2), n), "float32"), ) _check_inference( - bb, relax.op.ccl.scatter_from_worker0(x1, 2), relax.TensorStructInfo((2, n), "float32") + bb, relax.op.ccl.scatter_from_worker0(x1, 2), relax.TensorType((2, n), "float32") ) -def test_scatter_from_worker0_infer_struct_info_shape_var(): +def test_scatter_from_worker0_infer_ty_shape_var(): bb = relax.BlockBuilder() - s0 = relax.Var("s", relax.ShapeStructInfo((2, 4, 8))) - x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + s0 = relax.Var("s", relax.ShapeType((2, 4, 8))) + x0 = relax.Var("x", relax.TensorType(s0, "float32")) _check_inference( - bb, relax.op.ccl.scatter_from_worker0(x0, 2), relax.TensorStructInfo((1, 4, 8), "float32") + bb, relax.op.ccl.scatter_from_worker0(x0, 2), relax.TensorType((1, 4, 8), "float32") ) -def test_scatter_from_worker0_infer_struct_info_more_input_dtype(): +def test_scatter_from_worker0_infer_ty_more_input_dtype(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3), "float64")) x1 = relax.Var("x", R.Tensor((2, 3), "int8")) x2 = relax.Var("x", R.Tensor((2, 3), "int64")) _check_inference( - bb, relax.op.ccl.scatter_from_worker0(x0, 2), relax.TensorStructInfo((1, 3), "float64") - ) - _check_inference( - bb, relax.op.ccl.scatter_from_worker0(x1, 2), relax.TensorStructInfo((1, 3), "int8") + bb, relax.op.ccl.scatter_from_worker0(x0, 2), relax.TensorType((1, 3), "float64") ) + _check_inference(bb, relax.op.ccl.scatter_from_worker0(x1, 2), relax.TensorType((1, 3), "int8")) _check_inference( - bb, relax.op.ccl.scatter_from_worker0(x2, 2), relax.TensorStructInfo((1, 3), "int64") + bb, relax.op.ccl.scatter_from_worker0(x2, 2), relax.TensorType((1, 3), "int64") ) diff --git a/tests/python/relax/test_op_create.py b/tests/python/relax/test_op_create.py index b5b75a719888..15a51022d6ee 100644 --- a/tests/python/relax/test_op_create.py +++ b/tests/python/relax/test_op_create.py @@ -38,12 +38,12 @@ def test_op_correctness(): assert relax.op.triu(x).op == Op.get("relax.triu") -def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): +def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_ty: relax.Type): ret = bb.normalize(call) - tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) + tvm.ir.assert_structural_equal(ret.ty, expected_ty) -def test_full_infer_struct_info(): +def test_full_infer_ty(): bb = relax.BlockBuilder() vdev0 = VDevice("llvm") v0 = relax.Var("v", R.Tensor((), "float32")) @@ -52,144 +52,118 @@ def test_full_infer_struct_info(): v3 = relax.Var("v", R.Tensor(ndim=0)) v4 = relax.Var("v", R.Tensor((), "float32", vdev0)) s0 = relax.ShapeExpr((2, 3)) - s1 = relax.Var("s", relax.ShapeStructInfo((2, 3))) - s2 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) - s3 = relax.Var("s", relax.ShapeStructInfo()) - - _check_inference( - bb, relax.op.full((2, 3), v0, "float16"), relax.TensorStructInfo((2, 3), "float16") - ) - _check_inference(bb, relax.op.full((2, 3), v0), relax.TensorStructInfo((2, 3), "float32")) - _check_inference( - bb, relax.op.full(s0, v0, "float16"), relax.TensorStructInfo((2, 3), "float16") - ) - _check_inference(bb, relax.op.full(s0, v0), relax.TensorStructInfo((2, 3), "float32")) - _check_inference(bb, relax.op.full(s0, v4), relax.TensorStructInfo((2, 3), "float32", vdev0)) - _check_inference(bb, relax.op.full(s1, v0, "float16"), relax.TensorStructInfo(s1, "float16")) - _check_inference(bb, relax.op.full(s1, v0), relax.TensorStructInfo(s1, "float32")) - _check_inference(bb, relax.op.full(s2, v0, "float16"), relax.TensorStructInfo(s2, "float16")) - _check_inference(bb, relax.op.full(s2, v0), relax.TensorStructInfo(s2, "float32")) - _check_inference(bb, relax.op.full(s3, v0, "float16"), relax.TensorStructInfo(s3, "float16")) - _check_inference(bb, relax.op.full(s3, v0), relax.TensorStructInfo(s3, "float32")) - _check_inference( - bb, relax.op.full((2, 3), v1, "float16"), relax.TensorStructInfo((2, 3), "float16") - ) - _check_inference(bb, relax.op.full((2, 3), v1), relax.TensorStructInfo((2, 3), "float32")) - _check_inference( - bb, relax.op.full(s0, v1, "float16"), relax.TensorStructInfo((2, 3), "float16") - ) - _check_inference(bb, relax.op.full(s0, v1), relax.TensorStructInfo((2, 3), "float32")) - _check_inference(bb, relax.op.full(s1, v1, "float16"), relax.TensorStructInfo(s1, "float16")) - _check_inference(bb, relax.op.full(s1, v1), relax.TensorStructInfo(s1, "float32")) - _check_inference(bb, relax.op.full(s2, v1, "float16"), relax.TensorStructInfo(s2, "float16")) - _check_inference(bb, relax.op.full(s2, v1), relax.TensorStructInfo(s2, "float32")) - _check_inference(bb, relax.op.full(s3, v1, "float16"), relax.TensorStructInfo(s3, "float16")) - _check_inference(bb, relax.op.full(s3, v1), relax.TensorStructInfo(s3, "float32")) - _check_inference( - bb, relax.op.full((2, 3), v2, "float16"), relax.TensorStructInfo((2, 3), "float16") - ) - _check_inference(bb, relax.op.full((2, 3), v2), relax.TensorStructInfo((2, 3), dtype="")) - _check_inference( - bb, relax.op.full(s0, v2, "float16"), relax.TensorStructInfo((2, 3), "float16") - ) - _check_inference(bb, relax.op.full(s0, v2), relax.TensorStructInfo((2, 3), dtype="")) - _check_inference(bb, relax.op.full(s1, v2, "float16"), relax.TensorStructInfo(s1, "float16")) - _check_inference(bb, relax.op.full(s1, v2), relax.TensorStructInfo(s1, dtype="")) - _check_inference(bb, relax.op.full(s2, v2, "float16"), relax.TensorStructInfo(s2, "float16")) - _check_inference(bb, relax.op.full(s2, v2), relax.TensorStructInfo(s2, dtype="")) - _check_inference(bb, relax.op.full(s3, v2, "float16"), relax.TensorStructInfo(s3, "float16")) - _check_inference(bb, relax.op.full(s3, v2), relax.TensorStructInfo(s3, dtype="")) - _check_inference( - bb, relax.op.full((2, 3), v3, "float16"), relax.TensorStructInfo((2, 3), "float16") - ) - _check_inference(bb, relax.op.full((2, 3), v3), relax.TensorStructInfo((2, 3), dtype="")) - _check_inference( - bb, relax.op.full(s0, v3, "float16"), relax.TensorStructInfo((2, 3), "float16") - ) - _check_inference(bb, relax.op.full(s0, v3), relax.TensorStructInfo((2, 3), dtype="")) - _check_inference(bb, relax.op.full(s1, v3, "float16"), relax.TensorStructInfo(s1, "float16")) + s1 = relax.Var("s", relax.ShapeType((2, 3))) + s2 = relax.Var("s", relax.ShapeType(ndim=2)) + s3 = relax.Var("s", relax.ShapeType()) + + _check_inference(bb, relax.op.full((2, 3), v0, "float16"), relax.TensorType((2, 3), "float16")) + _check_inference(bb, relax.op.full((2, 3), v0), relax.TensorType((2, 3), "float32")) + _check_inference(bb, relax.op.full(s0, v0, "float16"), relax.TensorType((2, 3), "float16")) + _check_inference(bb, relax.op.full(s0, v0), relax.TensorType((2, 3), "float32")) + _check_inference(bb, relax.op.full(s0, v4), relax.TensorType((2, 3), "float32", vdev0)) + _check_inference(bb, relax.op.full(s1, v0, "float16"), relax.TensorType(s1, "float16")) + _check_inference(bb, relax.op.full(s1, v0), relax.TensorType(s1, "float32")) + _check_inference(bb, relax.op.full(s2, v0, "float16"), relax.TensorType(s2, "float16")) + _check_inference(bb, relax.op.full(s2, v0), relax.TensorType(s2, "float32")) + _check_inference(bb, relax.op.full(s3, v0, "float16"), relax.TensorType(s3, "float16")) + _check_inference(bb, relax.op.full(s3, v0), relax.TensorType(s3, "float32")) + _check_inference(bb, relax.op.full((2, 3), v1, "float16"), relax.TensorType((2, 3), "float16")) + _check_inference(bb, relax.op.full((2, 3), v1), relax.TensorType((2, 3), "float32")) + _check_inference(bb, relax.op.full(s0, v1, "float16"), relax.TensorType((2, 3), "float16")) + _check_inference(bb, relax.op.full(s0, v1), relax.TensorType((2, 3), "float32")) + _check_inference(bb, relax.op.full(s1, v1, "float16"), relax.TensorType(s1, "float16")) + _check_inference(bb, relax.op.full(s1, v1), relax.TensorType(s1, "float32")) + _check_inference(bb, relax.op.full(s2, v1, "float16"), relax.TensorType(s2, "float16")) + _check_inference(bb, relax.op.full(s2, v1), relax.TensorType(s2, "float32")) + _check_inference(bb, relax.op.full(s3, v1, "float16"), relax.TensorType(s3, "float16")) + _check_inference(bb, relax.op.full(s3, v1), relax.TensorType(s3, "float32")) + _check_inference(bb, relax.op.full((2, 3), v2, "float16"), relax.TensorType((2, 3), "float16")) + _check_inference(bb, relax.op.full((2, 3), v2), relax.TensorType((2, 3), dtype="")) + _check_inference(bb, relax.op.full(s0, v2, "float16"), relax.TensorType((2, 3), "float16")) + _check_inference(bb, relax.op.full(s0, v2), relax.TensorType((2, 3), dtype="")) + _check_inference(bb, relax.op.full(s1, v2, "float16"), relax.TensorType(s1, "float16")) + _check_inference(bb, relax.op.full(s1, v2), relax.TensorType(s1, dtype="")) + _check_inference(bb, relax.op.full(s2, v2, "float16"), relax.TensorType(s2, "float16")) + _check_inference(bb, relax.op.full(s2, v2), relax.TensorType(s2, dtype="")) + _check_inference(bb, relax.op.full(s3, v2, "float16"), relax.TensorType(s3, "float16")) + _check_inference(bb, relax.op.full(s3, v2), relax.TensorType(s3, dtype="")) + _check_inference(bb, relax.op.full((2, 3), v3, "float16"), relax.TensorType((2, 3), "float16")) + _check_inference(bb, relax.op.full((2, 3), v3), relax.TensorType((2, 3), dtype="")) + _check_inference(bb, relax.op.full(s0, v3, "float16"), relax.TensorType((2, 3), "float16")) + _check_inference(bb, relax.op.full(s0, v3), relax.TensorType((2, 3), dtype="")) + _check_inference(bb, relax.op.full(s1, v3, "float16"), relax.TensorType(s1, "float16")) _check_inference( bb, relax.op.full( s1, v3, ), - relax.TensorStructInfo(s1, dtype=""), + relax.TensorType(s1, dtype=""), ) - _check_inference(bb, relax.op.full(s2, v3, "float16"), relax.TensorStructInfo(s2, "float16")) + _check_inference(bb, relax.op.full(s2, v3, "float16"), relax.TensorType(s2, "float16")) _check_inference( bb, relax.op.full( s2, v3, ), - relax.TensorStructInfo(s2, dtype=""), + relax.TensorType(s2, dtype=""), ) - _check_inference(bb, relax.op.full(s3, v3, "float16"), relax.TensorStructInfo(s3, "float16")) - _check_inference(bb, relax.op.full(s3, v3), relax.TensorStructInfo(s3, dtype="")) + _check_inference(bb, relax.op.full(s3, v3, "float16"), relax.TensorType(s3, "float16")) + _check_inference(bb, relax.op.full(s3, v3), relax.TensorType(s3, dtype="")) -def test_full_infer_struct_info_shape_symbolic(): +def test_full_infer_ty_shape_symbolic(): bb = relax.BlockBuilder() a = tirx.Var("a", "int64") v = relax.Var("v", R.Tensor((), "float32")) s0 = relax.ShapeExpr((a, 3)) - s1 = relax.Var("s", relax.ShapeStructInfo((a, 3))) + s1 = relax.Var("s", relax.ShapeType((a, 3))) - _check_inference( - bb, relax.op.full((a, 3), v, "float16"), relax.TensorStructInfo((a, 3), "float16") - ) - _check_inference(bb, relax.op.full((a, 3), v), relax.TensorStructInfo((a, 3), "float32")) - _check_inference(bb, relax.op.full(s0, v, "float16"), relax.TensorStructInfo((a, 3), "float16")) - _check_inference(bb, relax.op.full(s0, v), relax.TensorStructInfo((a, 3), "float32")) - _check_inference(bb, relax.op.full(s1, v, "float16"), relax.TensorStructInfo(s1, "float16")) - _check_inference(bb, relax.op.full(s1, v), relax.TensorStructInfo(s1, "float32")) + _check_inference(bb, relax.op.full((a, 3), v, "float16"), relax.TensorType((a, 3), "float16")) + _check_inference(bb, relax.op.full((a, 3), v), relax.TensorType((a, 3), "float32")) + _check_inference(bb, relax.op.full(s0, v, "float16"), relax.TensorType((a, 3), "float16")) + _check_inference(bb, relax.op.full(s0, v), relax.TensorType((a, 3), "float32")) + _check_inference(bb, relax.op.full(s1, v, "float16"), relax.TensorType(s1, "float16")) + _check_inference(bb, relax.op.full(s1, v), relax.TensorType(s1, "float32")) -def test_full_infer_struct_info_shape_var(): +def test_full_infer_ty_shape_var(): bb = relax.BlockBuilder() - s0 = relax.Var("s", relax.ShapeStructInfo(())) - s1 = relax.Var("s", relax.ShapeStructInfo(ndim=0)) - v0 = relax.Var("v", relax.TensorStructInfo(s0, "float32")) - v1 = relax.Var("v", relax.TensorStructInfo(s1, "float32")) + s0 = relax.Var("s", relax.ShapeType(())) + s1 = relax.Var("s", relax.ShapeType(ndim=0)) + v0 = relax.Var("v", relax.TensorType(s0, "float32")) + v1 = relax.Var("v", relax.TensorType(s1, "float32")) - _check_inference( - bb, relax.op.full((2, 3), v0, "float16"), relax.TensorStructInfo((2, 3), "float16") - ) - _check_inference( - bb, relax.op.full((2, 3), v1, "float16"), relax.TensorStructInfo((2, 3), "float16") - ) + _check_inference(bb, relax.op.full((2, 3), v0, "float16"), relax.TensorType((2, 3), "float16")) + _check_inference(bb, relax.op.full((2, 3), v1, "float16"), relax.TensorType((2, 3), "float16")) -def test_full_infer_struct_info_more_input_dtype(): +def test_full_infer_ty_more_input_dtype(): bb = relax.BlockBuilder() v0 = relax.Var("v", R.Tensor((), "float16")) v1 = relax.Var("v", R.Tensor((), "int8")) v2 = relax.Var("v", R.Tensor((), "int32")) - _check_inference( - bb, relax.op.full((2, 3), v0, "float32"), relax.TensorStructInfo((2, 3), "float32") - ) - _check_inference(bb, relax.op.full((2, 3), v0), relax.TensorStructInfo((2, 3), "float16")) - _check_inference( - bb, relax.op.full((2, 3), v1, "int32"), relax.TensorStructInfo((2, 3), "int32") - ) - _check_inference(bb, relax.op.full((2, 3), v1), relax.TensorStructInfo((2, 3), "int8")) - _check_inference(bb, relax.op.full((2, 3), v2, "int8"), relax.TensorStructInfo((2, 3), "int8")) - _check_inference(bb, relax.op.full((2, 3), v2), relax.TensorStructInfo((2, 3), "int32")) + _check_inference(bb, relax.op.full((2, 3), v0, "float32"), relax.TensorType((2, 3), "float32")) + _check_inference(bb, relax.op.full((2, 3), v0), relax.TensorType((2, 3), "float16")) + _check_inference(bb, relax.op.full((2, 3), v1, "int32"), relax.TensorType((2, 3), "int32")) + _check_inference(bb, relax.op.full((2, 3), v1), relax.TensorType((2, 3), "int8")) + _check_inference(bb, relax.op.full((2, 3), v2, "int8"), relax.TensorType((2, 3), "int8")) + _check_inference(bb, relax.op.full((2, 3), v2), relax.TensorType((2, 3), "int32")) -def test_full_infer_struct_info_fill_value_not_scalar_tensor(): +def test_full_infer_ty_fill_value_not_scalar_tensor(): bb = relax.BlockBuilder() - s0 = relax.Var("s", relax.ShapeStructInfo((1,))) - s1 = relax.Var("s", relax.ShapeStructInfo(ndim=1)) - s2 = relax.Var("s", relax.ShapeStructInfo()) + s0 = relax.Var("s", relax.ShapeType((1,))) + s1 = relax.Var("s", relax.ShapeType(ndim=1)) + s2 = relax.Var("s", relax.ShapeType()) v0 = relax.Var("v", R.Tensor((1,), "float32")) v1 = relax.Var("v", R.Tensor("float32", ndim=1)) v2 = relax.Var("v", R.Tensor("float32")) - v3 = relax.Var("v", relax.TensorStructInfo(s0, "float32")) - v4 = relax.Var("v", relax.TensorStructInfo(s1, "float32")) - v5 = relax.Var("v", relax.TensorStructInfo(s2, "float32")) + v3 = relax.Var("v", relax.TensorType(s0, "float32")) + v4 = relax.Var("v", relax.TensorType(s1, "float32")) + v5 = relax.Var("v", relax.TensorType(s2, "float32")) with pytest.raises(ValueError): bb.normalize(relax.op.full((2, 3), v0)) @@ -215,11 +189,11 @@ def test_full_shape_not_tuple(): relax.op.full(m, v) -def test_full_infer_struct_info_wrong_input_type(): +def test_full_infer_ty_wrong_input_type(): bb = relax.BlockBuilder() v0 = relax.Var("v", R.Tensor((), "float32")) - v1 = relax.Var("v", relax.ShapeStructInfo(())) - v2 = relax.Var("v", relax.FuncStructInfo([], R.Tensor((), "float32"))) + v1 = relax.Var("v", relax.ShapeType(())) + v2 = relax.Var("v", relax.FuncType([], R.Tensor((), "float32"))) s = relax.Var("s", R.Tensor((2, 3))) with pytest.raises(TypeError): @@ -230,7 +204,7 @@ def test_full_infer_struct_info_wrong_input_type(): bb.normalize(relax.op.full((2, 3), v2)) -def test_full_like_infer_struct_info(): +def test_full_like_infer_ty(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3), "float32")) x1 = relax.Var("x", R.Tensor("float32", ndim=2)) @@ -243,53 +217,45 @@ def test_full_like_infer_struct_info(): v2 = relax.Var("v", R.Tensor(())) v3 = relax.Var("v", R.Tensor(ndim=0)) - _check_inference(bb, relax.op.full_like(x0, v0), relax.TensorStructInfo((2, 3), "float32")) - _check_inference(bb, relax.op.full_like(x0, v1), relax.TensorStructInfo((2, 3), "float32")) - _check_inference(bb, relax.op.full_like(x0, v2), relax.TensorStructInfo((2, 3), "float32")) - _check_inference(bb, relax.op.full_like(x0, v3), relax.TensorStructInfo((2, 3), "float32")) - _check_inference( - bb, relax.op.full_like(x1, v0), relax.TensorStructInfo(dtype="float32", ndim=2) - ) - _check_inference( - bb, relax.op.full_like(x1, v1), relax.TensorStructInfo(dtype="float32", ndim=2) - ) - _check_inference( - bb, relax.op.full_like(x1, v2), relax.TensorStructInfo(dtype="float32", ndim=2) - ) + _check_inference(bb, relax.op.full_like(x0, v0), relax.TensorType((2, 3), "float32")) + _check_inference(bb, relax.op.full_like(x0, v1), relax.TensorType((2, 3), "float32")) + _check_inference(bb, relax.op.full_like(x0, v2), relax.TensorType((2, 3), "float32")) + _check_inference(bb, relax.op.full_like(x0, v3), relax.TensorType((2, 3), "float32")) + _check_inference(bb, relax.op.full_like(x1, v0), relax.TensorType(dtype="float32", ndim=2)) + _check_inference(bb, relax.op.full_like(x1, v1), relax.TensorType(dtype="float32", ndim=2)) + _check_inference(bb, relax.op.full_like(x1, v2), relax.TensorType(dtype="float32", ndim=2)) + _check_inference(bb, relax.op.full_like(x1, v3), relax.TensorType(dtype="float32", ndim=2)) + _check_inference(bb, relax.op.full_like(x2, v0), relax.TensorType(dtype="float32")) + _check_inference(bb, relax.op.full_like(x2, v1), relax.TensorType(dtype="float32")) + _check_inference(bb, relax.op.full_like(x2, v2), relax.TensorType(dtype="float32")) + _check_inference(bb, relax.op.full_like(x2, v3), relax.TensorType(dtype="float32")) + _check_inference(bb, relax.op.full_like(x3, v0), relax.TensorType((2, 3), dtype="")) + _check_inference(bb, relax.op.full_like(x3, v1), relax.TensorType((2, 3), dtype="")) + _check_inference(bb, relax.op.full_like(x3, v2), relax.TensorType((2, 3), dtype="")) + _check_inference(bb, relax.op.full_like(x3, v3), relax.TensorType((2, 3), dtype="")) + _check_inference(bb, relax.op.full_like(x4, v0), relax.TensorType(dtype="", ndim=2)) + _check_inference(bb, relax.op.full_like(x4, v1), relax.TensorType(dtype="", ndim=2)) + _check_inference(bb, relax.op.full_like(x4, v2), relax.TensorType(dtype="", ndim=2)) + _check_inference(bb, relax.op.full_like(x4, v3), relax.TensorType(dtype="", ndim=2)) + _check_inference(bb, relax.op.full_like(x5, v0), relax.TensorType(dtype="")) + _check_inference(bb, relax.op.full_like(x5, v1), relax.TensorType(dtype="")) + _check_inference(bb, relax.op.full_like(x5, v2), relax.TensorType(dtype="")) + _check_inference(bb, relax.op.full_like(x5, v3), relax.TensorType(dtype="")) _check_inference( - bb, relax.op.full_like(x1, v3), relax.TensorStructInfo(dtype="float32", ndim=2) + bb, relax.op.full_like(x0, v0, dtype="float16"), relax.TensorType((2, 3), "float16") ) - _check_inference(bb, relax.op.full_like(x2, v0), relax.TensorStructInfo(dtype="float32")) - _check_inference(bb, relax.op.full_like(x2, v1), relax.TensorStructInfo(dtype="float32")) - _check_inference(bb, relax.op.full_like(x2, v2), relax.TensorStructInfo(dtype="float32")) - _check_inference(bb, relax.op.full_like(x2, v3), relax.TensorStructInfo(dtype="float32")) - _check_inference(bb, relax.op.full_like(x3, v0), relax.TensorStructInfo((2, 3), dtype="")) - _check_inference(bb, relax.op.full_like(x3, v1), relax.TensorStructInfo((2, 3), dtype="")) - _check_inference(bb, relax.op.full_like(x3, v2), relax.TensorStructInfo((2, 3), dtype="")) - _check_inference(bb, relax.op.full_like(x3, v3), relax.TensorStructInfo((2, 3), dtype="")) - _check_inference(bb, relax.op.full_like(x4, v0), relax.TensorStructInfo(dtype="", ndim=2)) - _check_inference(bb, relax.op.full_like(x4, v1), relax.TensorStructInfo(dtype="", ndim=2)) - _check_inference(bb, relax.op.full_like(x4, v2), relax.TensorStructInfo(dtype="", ndim=2)) - _check_inference(bb, relax.op.full_like(x4, v3), relax.TensorStructInfo(dtype="", ndim=2)) - _check_inference(bb, relax.op.full_like(x5, v0), relax.TensorStructInfo(dtype="")) - _check_inference(bb, relax.op.full_like(x5, v1), relax.TensorStructInfo(dtype="")) - _check_inference(bb, relax.op.full_like(x5, v2), relax.TensorStructInfo(dtype="")) - _check_inference(bb, relax.op.full_like(x5, v3), relax.TensorStructInfo(dtype="")) _check_inference( - bb, relax.op.full_like(x0, v0, dtype="float16"), relax.TensorStructInfo((2, 3), "float16") + bb, relax.op.full_like(x0, v2, dtype="float16"), relax.TensorType((2, 3), "float16") ) _check_inference( - bb, relax.op.full_like(x0, v2, dtype="float16"), relax.TensorStructInfo((2, 3), "float16") + bb, relax.op.full_like(x3, v0, dtype="float16"), relax.TensorType((2, 3), "float16") ) _check_inference( - bb, relax.op.full_like(x3, v0, dtype="float16"), relax.TensorStructInfo((2, 3), "float16") - ) - _check_inference( - bb, relax.op.full_like(x3, v2, dtype="float16"), relax.TensorStructInfo((2, 3), "float16") + bb, relax.op.full_like(x3, v2, dtype="float16"), relax.TensorType((2, 3), "float16") ) -def test_full_like_infer_struct_info_shape_symbolic(): +def test_full_like_infer_ty_shape_symbolic(): bb = relax.BlockBuilder() m = tirx.Var("m", "int64") n = tirx.Var("n", "int64") @@ -297,69 +263,67 @@ def test_full_like_infer_struct_info_shape_symbolic(): x1 = relax.Var("x", R.Tensor((m, n))) v = relax.Var("v", R.Tensor((), "float16")) - _check_inference(bb, relax.op.full_like(x0, v), relax.TensorStructInfo((m, n), "float32")) - _check_inference(bb, relax.op.full_like(x1, v), relax.TensorStructInfo((m, n), dtype="")) + _check_inference(bb, relax.op.full_like(x0, v), relax.TensorType((m, n), "float32")) + _check_inference(bb, relax.op.full_like(x1, v), relax.TensorType((m, n), dtype="")) -def test_full_like_infer_struct_info_shape_var(): +def test_full_like_infer_ty_shape_var(): bb = relax.BlockBuilder() vdev0 = VDevice("llvm") - s0 = relax.Var("s", relax.ShapeStructInfo((2, 3))) - s1 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) - s2 = relax.Var("s", relax.ShapeStructInfo()) - x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) - x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) - x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + s0 = relax.Var("s", relax.ShapeType((2, 3))) + s1 = relax.Var("s", relax.ShapeType(ndim=2)) + s2 = relax.Var("s", relax.ShapeType()) + x0 = relax.Var("x", relax.TensorType(s0, "float32")) + x1 = relax.Var("x", relax.TensorType(s1, "float32")) + x2 = relax.Var("x", relax.TensorType(s2, "float32")) x3 = relax.Var("x", R.Tensor((2, 3), "float32")) x4 = relax.Var("x", R.Tensor((2, 3), "float32", vdev0)) - sv0 = relax.Var("sv", relax.ShapeStructInfo(())) - sv1 = relax.Var("sv", relax.ShapeStructInfo(ndim=0)) - v0 = relax.Var("v", relax.TensorStructInfo(sv0, "float16")) - v1 = relax.Var("v", relax.TensorStructInfo(sv1, "float16")) + sv0 = relax.Var("sv", relax.ShapeType(())) + sv1 = relax.Var("sv", relax.ShapeType(ndim=0)) + v0 = relax.Var("v", relax.TensorType(sv0, "float16")) + v1 = relax.Var("v", relax.TensorType(sv1, "float16")) v2 = relax.Var("v", R.Tensor((), "float16")) - v3 = relax.Var("v", relax.TensorStructInfo(sv1, "float16", vdev0)) - - _check_inference(bb, relax.op.full_like(x0, v0), relax.TensorStructInfo(s0, "float32")) - _check_inference(bb, relax.op.full_like(x0, v1), relax.TensorStructInfo(s0, "float32")) - _check_inference(bb, relax.op.full_like(x0, v2), relax.TensorStructInfo(s0, "float32")) - _check_inference(bb, relax.op.full_like(x1, v0), relax.TensorStructInfo(s1, "float32")) - _check_inference(bb, relax.op.full_like(x1, v1), relax.TensorStructInfo(s1, "float32")) - _check_inference(bb, relax.op.full_like(x1, v2), relax.TensorStructInfo(s1, "float32")) - _check_inference(bb, relax.op.full_like(x2, v0), relax.TensorStructInfo(s2, "float32")) - _check_inference(bb, relax.op.full_like(x2, v1), relax.TensorStructInfo(s2, "float32")) - _check_inference(bb, relax.op.full_like(x2, v2), relax.TensorStructInfo(s2, "float32")) - _check_inference(bb, relax.op.full_like(x3, v0), relax.TensorStructInfo((2, 3), "float32")) - _check_inference(bb, relax.op.full_like(x3, v1), relax.TensorStructInfo((2, 3), "float32")) - _check_inference( - bb, relax.op.full_like(x4, v3), relax.TensorStructInfo((2, 3), "float32", vdev0) - ) + v3 = relax.Var("v", relax.TensorType(sv1, "float16", vdev0)) + + _check_inference(bb, relax.op.full_like(x0, v0), relax.TensorType(s0, "float32")) + _check_inference(bb, relax.op.full_like(x0, v1), relax.TensorType(s0, "float32")) + _check_inference(bb, relax.op.full_like(x0, v2), relax.TensorType(s0, "float32")) + _check_inference(bb, relax.op.full_like(x1, v0), relax.TensorType(s1, "float32")) + _check_inference(bb, relax.op.full_like(x1, v1), relax.TensorType(s1, "float32")) + _check_inference(bb, relax.op.full_like(x1, v2), relax.TensorType(s1, "float32")) + _check_inference(bb, relax.op.full_like(x2, v0), relax.TensorType(s2, "float32")) + _check_inference(bb, relax.op.full_like(x2, v1), relax.TensorType(s2, "float32")) + _check_inference(bb, relax.op.full_like(x2, v2), relax.TensorType(s2, "float32")) + _check_inference(bb, relax.op.full_like(x3, v0), relax.TensorType((2, 3), "float32")) + _check_inference(bb, relax.op.full_like(x3, v1), relax.TensorType((2, 3), "float32")) + _check_inference(bb, relax.op.full_like(x4, v3), relax.TensorType((2, 3), "float32", vdev0)) -def test_full_like_infer_struct_info_more_input_dtype(): +def test_full_like_infer_ty_more_input_dtype(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3), "float16")) x1 = relax.Var("x", R.Tensor((2, 3), "int8")) v0 = relax.Var("v", R.Tensor((), "int32")) v1 = relax.Var("v", R.Tensor((), "float64")) - _check_inference(bb, relax.op.full_like(x0, v0), relax.TensorStructInfo((2, 3), "float16")) - _check_inference(bb, relax.op.full_like(x0, v1), relax.TensorStructInfo((2, 3), "float16")) - _check_inference(bb, relax.op.full_like(x1, v0), relax.TensorStructInfo((2, 3), "int8")) - _check_inference(bb, relax.op.full_like(x1, v1), relax.TensorStructInfo((2, 3), "int8")) + _check_inference(bb, relax.op.full_like(x0, v0), relax.TensorType((2, 3), "float16")) + _check_inference(bb, relax.op.full_like(x0, v1), relax.TensorType((2, 3), "float16")) + _check_inference(bb, relax.op.full_like(x1, v0), relax.TensorType((2, 3), "int8")) + _check_inference(bb, relax.op.full_like(x1, v1), relax.TensorType((2, 3), "int8")) -def test_full_like_infer_struct_info_fill_value_not_scalar_tensor(): +def test_full_like_infer_ty_fill_value_not_scalar_tensor(): bb = relax.BlockBuilder() x = relax.Var("x", R.Tensor((2, 3), "float32")) - s0 = relax.Var("s", relax.ShapeStructInfo((1,))) - s1 = relax.Var("s", relax.ShapeStructInfo(ndim=1)) - s2 = relax.Var("s", relax.ShapeStructInfo()) + s0 = relax.Var("s", relax.ShapeType((1,))) + s1 = relax.Var("s", relax.ShapeType(ndim=1)) + s2 = relax.Var("s", relax.ShapeType()) v0 = relax.Var("v", R.Tensor((1,), "float32")) v1 = relax.Var("v", R.Tensor("float32", ndim=1)) v2 = relax.Var("v", R.Tensor("float32")) - v3 = relax.Var("v", relax.TensorStructInfo(s0, "float32")) - v4 = relax.Var("v", relax.TensorStructInfo(s1, "float32")) - v5 = relax.Var("v", relax.TensorStructInfo(s2, "float32")) + v3 = relax.Var("v", relax.TensorType(s0, "float32")) + v4 = relax.Var("v", relax.TensorType(s1, "float32")) + v5 = relax.Var("v", relax.TensorType(s2, "float32")) with pytest.raises(ValueError): bb.normalize(relax.op.full_like(x, v0)) @@ -375,13 +339,13 @@ def test_full_like_infer_struct_info_fill_value_not_scalar_tensor(): bb.normalize(relax.op.full_like(x, v5)) -def test_full_like_infer_struct_info_wrong_input_type(): +def test_full_like_infer_ty_wrong_input_type(): bb = relax.BlockBuilder() - x0 = relax.Var("x", relax.ShapeStructInfo((2, 3))) - x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((), "float32"))) + x0 = relax.Var("x", relax.ShapeType((2, 3))) + x1 = relax.Var("x", relax.FuncType([], R.Tensor((), "float32"))) x2 = relax.Var("x", R.Tensor((2, 3))) v0 = relax.Var("v", R.Tensor(())) - v1 = relax.Var("v", relax.ShapeStructInfo(())) + v1 = relax.Var("v", relax.ShapeType(())) with pytest.raises(TypeError): bb.normalize(relax.op.full_like(x0, v0)) @@ -391,59 +355,51 @@ def test_full_like_infer_struct_info_wrong_input_type(): bb.normalize(relax.op.full_like(x2, v1)) -def test_ones_zeros_infer_struct_info(): +def test_ones_zeros_infer_ty(): bb = relax.BlockBuilder() s0 = relax.ShapeExpr((2, 3)) - s1 = relax.Var("s", relax.ShapeStructInfo((2, 3))) - s2 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) - s3 = relax.Var("s", relax.ShapeStructInfo()) + s1 = relax.Var("s", relax.ShapeType((2, 3))) + s2 = relax.Var("s", relax.ShapeType(ndim=2)) + s3 = relax.Var("s", relax.ShapeType()) - _check_inference( - bb, relax.op.ones((2, 3), "float32"), relax.TensorStructInfo((2, 3), "float32") - ) - _check_inference(bb, relax.op.ones(s0, "float32"), relax.TensorStructInfo((2, 3), "float32")) - _check_inference(bb, relax.op.ones(s1, "float32"), relax.TensorStructInfo(s1, "float32")) - _check_inference(bb, relax.op.ones(s2, "float32"), relax.TensorStructInfo(s2, "float32")) - _check_inference(bb, relax.op.ones(s3, "float32"), relax.TensorStructInfo(s3, "float32")) - _check_inference( - bb, relax.op.zeros((2, 3), "float32"), relax.TensorStructInfo((2, 3), "float32") - ) - _check_inference(bb, relax.op.zeros(s0, "float32"), relax.TensorStructInfo((2, 3), "float32")) - _check_inference(bb, relax.op.zeros(s1, "float32"), relax.TensorStructInfo(s1, "float32")) - _check_inference(bb, relax.op.zeros(s2, "float32"), relax.TensorStructInfo(s2, "float32")) - _check_inference(bb, relax.op.zeros(s3, "float32"), relax.TensorStructInfo(s3, "float32")) + _check_inference(bb, relax.op.ones((2, 3), "float32"), relax.TensorType((2, 3), "float32")) + _check_inference(bb, relax.op.ones(s0, "float32"), relax.TensorType((2, 3), "float32")) + _check_inference(bb, relax.op.ones(s1, "float32"), relax.TensorType(s1, "float32")) + _check_inference(bb, relax.op.ones(s2, "float32"), relax.TensorType(s2, "float32")) + _check_inference(bb, relax.op.ones(s3, "float32"), relax.TensorType(s3, "float32")) + _check_inference(bb, relax.op.zeros((2, 3), "float32"), relax.TensorType((2, 3), "float32")) + _check_inference(bb, relax.op.zeros(s0, "float32"), relax.TensorType((2, 3), "float32")) + _check_inference(bb, relax.op.zeros(s1, "float32"), relax.TensorType(s1, "float32")) + _check_inference(bb, relax.op.zeros(s2, "float32"), relax.TensorType(s2, "float32")) + _check_inference(bb, relax.op.zeros(s3, "float32"), relax.TensorType(s3, "float32")) -def test_ones_zeros_infer_struct_info_shape_symbolic(): +def test_ones_zeros_infer_ty_shape_symbolic(): bb = relax.BlockBuilder() m = tirx.Var("m", "int64") n = tirx.Var("n", "int64") s0 = relax.ShapeExpr((m, n)) - s1 = relax.Var("s", relax.ShapeStructInfo((m, n))) + s1 = relax.Var("s", relax.ShapeType((m, n))) - _check_inference( - bb, relax.op.ones((m, n), "float32"), relax.TensorStructInfo((m, n), "float32") - ) - _check_inference(bb, relax.op.ones(s0, "float32"), relax.TensorStructInfo((m, n), "float32")) - _check_inference(bb, relax.op.ones(s1, "float32"), relax.TensorStructInfo(s1, "float32")) - _check_inference( - bb, relax.op.zeros((m, n), "float32"), relax.TensorStructInfo((m, n), "float32") - ) - _check_inference(bb, relax.op.zeros(s0, "float32"), relax.TensorStructInfo((m, n), "float32")) - _check_inference(bb, relax.op.zeros(s1, "float32"), relax.TensorStructInfo(s1, "float32")) + _check_inference(bb, relax.op.ones((m, n), "float32"), relax.TensorType((m, n), "float32")) + _check_inference(bb, relax.op.ones(s0, "float32"), relax.TensorType((m, n), "float32")) + _check_inference(bb, relax.op.ones(s1, "float32"), relax.TensorType(s1, "float32")) + _check_inference(bb, relax.op.zeros((m, n), "float32"), relax.TensorType((m, n), "float32")) + _check_inference(bb, relax.op.zeros(s0, "float32"), relax.TensorType((m, n), "float32")) + _check_inference(bb, relax.op.zeros(s1, "float32"), relax.TensorType(s1, "float32")) -def test_ones_zeros_infer_struct_info_more_input_dtype(): +def test_ones_zeros_infer_ty_more_input_dtype(): bb = relax.BlockBuilder() s0 = relax.ShapeExpr((2, 3)) - s1 = relax.Var("s", relax.ShapeStructInfo((2, 3))) - s2 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) - s3 = relax.Var("s", relax.ShapeStructInfo()) + s1 = relax.Var("s", relax.ShapeType((2, 3))) + s2 = relax.Var("s", relax.ShapeType(ndim=2)) + s3 = relax.Var("s", relax.ShapeType()) - _check_inference(bb, relax.op.ones(s0, "float16"), relax.TensorStructInfo((2, 3), "float16")) - _check_inference(bb, relax.op.ones(s1, "int8"), relax.TensorStructInfo(s1, "int8")) - _check_inference(bb, relax.op.zeros(s2, "int32"), relax.TensorStructInfo(s2, "int32")) - _check_inference(bb, relax.op.zeros(s3, "float64"), relax.TensorStructInfo(s3, "float64")) + _check_inference(bb, relax.op.ones(s0, "float16"), relax.TensorType((2, 3), "float16")) + _check_inference(bb, relax.op.ones(s1, "int8"), relax.TensorType(s1, "int8")) + _check_inference(bb, relax.op.zeros(s2, "int32"), relax.TensorType(s2, "int32")) + _check_inference(bb, relax.op.zeros(s3, "float64"), relax.TensorType(s3, "float64")) def test_ones_zeros_shape_not_tuple(): @@ -466,10 +422,10 @@ def test_ones_zeros_wrong_dtype(): relax.op.zeros((2, 3), "") -def test_ones_zeros_infer_struct_info_wrong_input_type(): +def test_ones_zeros_infer_ty_wrong_input_type(): bb = relax.BlockBuilder() s0 = relax.Var("s", R.Tensor((2, 3))) - s1 = relax.Var("s", relax.FuncStructInfo([], R.Tensor((2, 3)))) + s1 = relax.Var("s", relax.FuncType([], R.Tensor((2, 3)))) with pytest.raises(TypeError): bb.normalize(relax.op.ones(s0, "float32")) @@ -477,7 +433,7 @@ def test_ones_zeros_infer_struct_info_wrong_input_type(): bb.normalize(relax.op.zeros(s1, "float32")) -def test_ones_like_zeros_like_infer_struct_info(): +def test_ones_like_zeros_like_infer_ty(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3), "float32")) x1 = relax.Var("x", R.Tensor("float32", ndim=2)) @@ -486,58 +442,58 @@ def test_ones_like_zeros_like_infer_struct_info(): x4 = relax.Var("x", R.Tensor(ndim=2)) x5 = relax.Var("x", R.Tensor()) - _check_inference(bb, relax.op.ones_like(x0), relax.TensorStructInfo((2, 3), "float32")) - _check_inference(bb, relax.op.zeros_like(x1), relax.TensorStructInfo(dtype="float32", ndim=2)) - _check_inference(bb, relax.op.ones_like(x2), relax.TensorStructInfo(dtype="float32")) - _check_inference(bb, relax.op.zeros_like(x3), relax.TensorStructInfo((2, 3), dtype="")) - _check_inference(bb, relax.op.ones_like(x4), relax.TensorStructInfo(dtype="", ndim=2)) - _check_inference(bb, relax.op.zeros_like(x5), relax.TensorStructInfo(dtype="")) + _check_inference(bb, relax.op.ones_like(x0), relax.TensorType((2, 3), "float32")) + _check_inference(bb, relax.op.zeros_like(x1), relax.TensorType(dtype="float32", ndim=2)) + _check_inference(bb, relax.op.ones_like(x2), relax.TensorType(dtype="float32")) + _check_inference(bb, relax.op.zeros_like(x3), relax.TensorType((2, 3), dtype="")) + _check_inference(bb, relax.op.ones_like(x4), relax.TensorType(dtype="", ndim=2)) + _check_inference(bb, relax.op.zeros_like(x5), relax.TensorType(dtype="")) _check_inference( - bb, relax.op.ones_like(x0, dtype="float16"), relax.TensorStructInfo((2, 3), "float16") + bb, relax.op.ones_like(x0, dtype="float16"), relax.TensorType((2, 3), "float16") ) _check_inference( - bb, relax.op.zeros_like(x3, dtype="float16"), relax.TensorStructInfo((2, 3), "float16") + bb, relax.op.zeros_like(x3, dtype="float16"), relax.TensorType((2, 3), "float16") ) -def test_ones_like_zeros_like_infer_struct_info_shape_symbolic(): +def test_ones_like_zeros_like_infer_ty_shape_symbolic(): bb = relax.BlockBuilder() m = tirx.Var("m", "int64") n = tirx.Var("n", "int64") x0 = relax.Var("x", R.Tensor((m, n), "float32")) x1 = relax.Var("x", R.Tensor((m, n))) - _check_inference(bb, relax.op.ones_like(x0), relax.TensorStructInfo((m, n), "float32")) - _check_inference(bb, relax.op.zeros_like(x1), relax.TensorStructInfo((m, n), dtype="")) + _check_inference(bb, relax.op.ones_like(x0), relax.TensorType((m, n), "float32")) + _check_inference(bb, relax.op.zeros_like(x1), relax.TensorType((m, n), dtype="")) -def test_ones_like_zeros_like_infer_struct_info_shape_var(): +def test_ones_like_zeros_like_infer_ty_shape_var(): bb = relax.BlockBuilder() - s0 = relax.Var("s", relax.ShapeStructInfo((2, 3))) - s1 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) - s2 = relax.Var("s", relax.ShapeStructInfo()) - x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) - x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) - x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + s0 = relax.Var("s", relax.ShapeType((2, 3))) + s1 = relax.Var("s", relax.ShapeType(ndim=2)) + s2 = relax.Var("s", relax.ShapeType()) + x0 = relax.Var("x", relax.TensorType(s0, "float32")) + x1 = relax.Var("x", relax.TensorType(s1, "float32")) + x2 = relax.Var("x", relax.TensorType(s2, "float32")) - _check_inference(bb, relax.op.ones_like(x0), relax.TensorStructInfo(s0, "float32")) - _check_inference(bb, relax.op.zeros_like(x1), relax.TensorStructInfo(s1, "float32")) - _check_inference(bb, relax.op.zeros_like(x2), relax.TensorStructInfo(s2, "float32")) + _check_inference(bb, relax.op.ones_like(x0), relax.TensorType(s0, "float32")) + _check_inference(bb, relax.op.zeros_like(x1), relax.TensorType(s1, "float32")) + _check_inference(bb, relax.op.zeros_like(x2), relax.TensorType(s2, "float32")) -def test_ones_like_zeros_like_infer_struct_info_more_input_dtype(): +def test_ones_like_zeros_like_infer_ty_more_input_dtype(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3), "float64")) x1 = relax.Var("x", R.Tensor((2, 3), "int8")) - _check_inference(bb, relax.op.ones_like(x0), relax.TensorStructInfo((2, 3), "float64")) - _check_inference(bb, relax.op.zeros_like(x1), relax.TensorStructInfo((2, 3), "int8")) + _check_inference(bb, relax.op.ones_like(x0), relax.TensorType((2, 3), "float64")) + _check_inference(bb, relax.op.zeros_like(x1), relax.TensorType((2, 3), "int8")) -def test_ones_like_zeros_like_infer_struct_info_wrong_input_type(): +def test_ones_like_zeros_like_infer_ty_wrong_input_type(): bb = relax.BlockBuilder() - x0 = relax.Var("x", relax.ShapeStructInfo((2, 3))) - x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3), "float32"))) + x0 = relax.Var("x", relax.ShapeType((2, 3))) + x1 = relax.Var("x", relax.FuncType([], R.Tensor((2, 3), "float32"))) with pytest.raises(TypeError): bb.normalize(relax.op.ones_like(x0)) @@ -545,57 +501,57 @@ def test_ones_like_zeros_like_infer_struct_info_wrong_input_type(): bb.normalize(relax.op.zeros_like(x1)) -def test_eye_infer_struct_info(): +def test_eye_infer_ty(): bb = relax.BlockBuilder() - _check_inference(bb, relax.op.eye(3), relax.TensorStructInfo((3, 3), "float32")) - _check_inference(bb, relax.op.eye(2, 4), relax.TensorStructInfo((2, 4), "float32")) - _check_inference(bb, relax.op.eye(3, dtype="int64"), relax.TensorStructInfo((3, 3), "int64")) - _check_inference(bb, relax.op.eye(3, 5, k=1), relax.TensorStructInfo((3, 5), "float32")) - _check_inference(bb, relax.op.eye(3, 5, k=-2), relax.TensorStructInfo((3, 5), "float32")) + _check_inference(bb, relax.op.eye(3), relax.TensorType((3, 3), "float32")) + _check_inference(bb, relax.op.eye(2, 4), relax.TensorType((2, 4), "float32")) + _check_inference(bb, relax.op.eye(3, dtype="int64"), relax.TensorType((3, 3), "int64")) + _check_inference(bb, relax.op.eye(3, 5, k=1), relax.TensorType((3, 5), "float32")) + _check_inference(bb, relax.op.eye(3, 5, k=-2), relax.TensorType((3, 5), "float32")) -def test_eye_infer_struct_info_symbolic(): +def test_eye_infer_ty_symbolic(): bb = relax.BlockBuilder() n = tirx.Var("n", "int64") m = tirx.Var("m", "int64") k = tirx.Var("k", "int64") - _check_inference(bb, relax.op.eye(n), relax.TensorStructInfo((n, n), "float32")) - _check_inference(bb, relax.op.eye(n, m), relax.TensorStructInfo((n, m), "float32")) - _check_inference(bb, relax.op.eye(n, k=k), relax.TensorStructInfo((n, n), "float32")) + _check_inference(bb, relax.op.eye(n), relax.TensorType((n, n), "float32")) + _check_inference(bb, relax.op.eye(n, m), relax.TensorType((n, m), "float32")) + _check_inference(bb, relax.op.eye(n, k=k), relax.TensorType((n, n), "float32")) -def test_eye_like_infer_struct_info(): +def test_eye_like_infer_ty(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((3, 4), "float32")) x1 = relax.Var("x", R.Tensor((2, 5), "int64")) x2 = relax.Var("x", R.Tensor((3, 3))) - _check_inference(bb, relax.op.eye_like(x0), relax.TensorStructInfo((3, 4), "float32")) - _check_inference(bb, relax.op.eye_like(x1), relax.TensorStructInfo((2, 5), "int64")) - _check_inference(bb, relax.op.eye_like(x2), relax.TensorStructInfo((3, 3), dtype="")) - _check_inference(bb, relax.op.eye_like(x0, k=1), relax.TensorStructInfo((3, 4), "float32")) + _check_inference(bb, relax.op.eye_like(x0), relax.TensorType((3, 4), "float32")) + _check_inference(bb, relax.op.eye_like(x1), relax.TensorType((2, 5), "int64")) + _check_inference(bb, relax.op.eye_like(x2), relax.TensorType((3, 3), dtype="")) + _check_inference(bb, relax.op.eye_like(x0, k=1), relax.TensorType((3, 4), "float32")) _check_inference( - bb, relax.op.eye_like(x1, dtype="float32"), relax.TensorStructInfo((2, 5), "float32") + bb, relax.op.eye_like(x1, dtype="float32"), relax.TensorType((2, 5), "float32") ) -def test_eye_like_infer_struct_info_symbolic(): +def test_eye_like_infer_ty_symbolic(): bb = relax.BlockBuilder() n = tirx.Var("n", "int64") m = tirx.Var("m", "int64") x = relax.Var("x", R.Tensor((n, m), "float32")) k = tirx.Var("k", "int64") - _check_inference(bb, relax.op.eye_like(x), relax.TensorStructInfo((n, m), "float32")) - _check_inference(bb, relax.op.eye_like(x, k=k), relax.TensorStructInfo((n, m), "float32")) + _check_inference(bb, relax.op.eye_like(x), relax.TensorType((n, m), "float32")) + _check_inference(bb, relax.op.eye_like(x, k=k), relax.TensorType((n, m), "float32")) -def test_eye_like_infer_struct_info_wrong_input_type(): +def test_eye_like_infer_ty_wrong_input_type(): bb = relax.BlockBuilder() - x0 = relax.Var("x", relax.ShapeStructInfo((2, 3))) - x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3), "float32"))) + x0 = relax.Var("x", relax.ShapeType((2, 3))) + x1 = relax.Var("x", relax.FuncType([], R.Tensor((2, 3), "float32"))) with pytest.raises(TypeError): bb.normalize(relax.op.eye_like(x0)) @@ -603,40 +559,38 @@ def test_eye_like_infer_struct_info_wrong_input_type(): bb.normalize(relax.op.eye_like(x1)) -def test_arange_infer_struct_info(): +def test_arange_infer_ty(): bb = relax.BlockBuilder() - _check_inference(bb, relax.op.arange(10), relax.TensorStructInfo((10,), "int64")) - _check_inference(bb, relax.op.arange(1, 10), relax.TensorStructInfo((9,), "int64")) - _check_inference(bb, relax.op.arange(0, 10, 2), relax.TensorStructInfo((5,), "int64")) - _check_inference(bb, relax.op.arange(1, 10, 2), relax.TensorStructInfo((5,), "int64")) + _check_inference(bb, relax.op.arange(10), relax.TensorType((10,), "int64")) + _check_inference(bb, relax.op.arange(1, 10), relax.TensorType((9,), "int64")) + _check_inference(bb, relax.op.arange(0, 10, 2), relax.TensorType((5,), "int64")) + _check_inference(bb, relax.op.arange(1, 10, 2), relax.TensorType((5,), "int64")) - _check_inference(bb, relax.op.arange(10.0), relax.TensorStructInfo((10,), "float32")) - _check_inference(bb, relax.op.arange(1.0, 10), relax.TensorStructInfo((9,), "float32")) - _check_inference(bb, relax.op.arange(0, 20, 2.5), relax.TensorStructInfo((8,), "float32")) - _check_inference(bb, relax.op.arange(1, 10, 2.3), relax.TensorStructInfo((4,), "float32")) + _check_inference(bb, relax.op.arange(10.0), relax.TensorType((10,), "float32")) + _check_inference(bb, relax.op.arange(1.0, 10), relax.TensorType((9,), "float32")) + _check_inference(bb, relax.op.arange(0, 20, 2.5), relax.TensorType((8,), "float32")) + _check_inference(bb, relax.op.arange(1, 10, 2.3), relax.TensorType((4,), "float32")) -def test_arange_infer_struct_info_shape_var(): +def test_arange_infer_ty_shape_var(): bb = relax.BlockBuilder() start = tirx.Var("start", "int64") stop = tirx.Var("stop", "int64") step = tirx.Var("step", "int64") - _check_inference(bb, relax.op.arange(stop), relax.TensorStructInfo((stop,), "int64")) - _check_inference(bb, relax.op.arange(1, stop), relax.TensorStructInfo((stop - 1,), "int64")) - _check_inference( - bb, relax.op.arange(start, stop), relax.TensorStructInfo((stop - start,), "int64") - ) + _check_inference(bb, relax.op.arange(stop), relax.TensorType((stop,), "int64")) + _check_inference(bb, relax.op.arange(1, stop), relax.TensorType((stop - 1,), "int64")) + _check_inference(bb, relax.op.arange(start, stop), relax.TensorType((stop - start,), "int64")) _check_inference( bb, relax.op.arange(start, stop, 2), - relax.TensorStructInfo(((stop + 1 - start) // 2,), "int64"), + relax.TensorType(((stop + 1 - start) // 2,), "int64"), ) _check_inference( bb, relax.op.arange(start, stop, step), - relax.TensorStructInfo(((stop + step - start - 1) // step,), "int64"), + relax.TensorType(((stop + step - start - 1) // step,), "int64"), ) start = tirx.Var("start", "float32") @@ -646,31 +600,31 @@ def test_arange_infer_struct_info_shape_var(): _check_inference( bb, relax.op.arange(stop), - relax.TensorStructInfo((T.cast(T.ceil(stop), "int64"),), "float32"), + relax.TensorType((T.cast(T.ceil(stop), "int64"),), "float32"), ) _check_inference( bb, relax.op.arange(1, stop), - relax.TensorStructInfo((T.cast(T.ceil(stop - 1.0), "int64"),), "float32"), + relax.TensorType((T.cast(T.ceil(stop - 1.0), "int64"),), "float32"), ) _check_inference( bb, relax.op.arange(start, stop), - relax.TensorStructInfo((T.cast(T.ceil(stop - start), "int64"),), "float32"), + relax.TensorType((T.cast(T.ceil(stop - start), "int64"),), "float32"), ) _check_inference( bb, relax.op.arange(start, stop, 2), - relax.TensorStructInfo((T.cast(T.ceil((stop - start) / 2), "int64"),), "float32"), + relax.TensorType((T.cast(T.ceil((stop - start) / 2), "int64"),), "float32"), ) _check_inference( bb, relax.op.arange(start, stop, step), - relax.TensorStructInfo((T.cast(T.ceil((stop - start) / step), "int64"),), "float32"), + relax.TensorType((T.cast(T.ceil((stop - start) / step), "int64"),), "float32"), ) -def test_tril_triu_infer_struct_info(): +def test_tril_triu_infer_ty(): bb = relax.BlockBuilder() vdev0 = VDevice("llvm") x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32")) @@ -681,18 +635,18 @@ def test_tril_triu_infer_struct_info(): x5 = relax.Var("x", R.Tensor()) x6 = relax.Var("x", R.Tensor((2, 3, 4), "float32", vdev0)) - _check_inference(bb, relax.op.tril(x0, k=1), relax.TensorStructInfo((2, 3, 4), "float32")) - _check_inference(bb, relax.op.triu(x0, k=0), relax.TensorStructInfo((2, 3, 4), "float32")) - _check_inference(bb, relax.op.tril(x0), relax.TensorStructInfo((2, 3, 4), "float32")) - _check_inference(bb, relax.op.triu(x1), relax.TensorStructInfo(dtype="float32", ndim=3)) - _check_inference(bb, relax.op.tril(x2), relax.TensorStructInfo(dtype="float32")) - _check_inference(bb, relax.op.triu(x3), relax.TensorStructInfo((2, 3, 4), dtype="")) - _check_inference(bb, relax.op.tril(x4), relax.TensorStructInfo(dtype="", ndim=3)) - _check_inference(bb, relax.op.triu(x5), relax.TensorStructInfo(dtype="")) - _check_inference(bb, relax.op.tril(x6), relax.TensorStructInfo((2, 3, 4), "float32", vdev0)) + _check_inference(bb, relax.op.tril(x0, k=1), relax.TensorType((2, 3, 4), "float32")) + _check_inference(bb, relax.op.triu(x0, k=0), relax.TensorType((2, 3, 4), "float32")) + _check_inference(bb, relax.op.tril(x0), relax.TensorType((2, 3, 4), "float32")) + _check_inference(bb, relax.op.triu(x1), relax.TensorType(dtype="float32", ndim=3)) + _check_inference(bb, relax.op.tril(x2), relax.TensorType(dtype="float32")) + _check_inference(bb, relax.op.triu(x3), relax.TensorType((2, 3, 4), dtype="")) + _check_inference(bb, relax.op.tril(x4), relax.TensorType(dtype="", ndim=3)) + _check_inference(bb, relax.op.triu(x5), relax.TensorType(dtype="")) + _check_inference(bb, relax.op.tril(x6), relax.TensorType((2, 3, 4), "float32", vdev0)) -def test_tril_triu_infer_struct_info_shape_symbolic(): +def test_tril_triu_infer_ty_shape_symbolic(): bb = relax.BlockBuilder() vdev0 = VDevice("llvm") a = tirx.Var("a", "int64") @@ -704,56 +658,56 @@ def test_tril_triu_infer_struct_info_shape_symbolic(): x3 = relax.Var("x", R.Tensor((16, 32, 64))) # Dynamic tensor, static offset - _check_inference(bb, relax.op.tril(x0), relax.TensorStructInfo((a, b, c), "float32")) - _check_inference(bb, relax.op.triu(x1), relax.TensorStructInfo((a, b, c), dtype="")) - _check_inference(bb, relax.op.tril(x2), relax.TensorStructInfo((a, b, c), "float32", vdev0)) + _check_inference(bb, relax.op.tril(x0), relax.TensorType((a, b, c), "float32")) + _check_inference(bb, relax.op.triu(x1), relax.TensorType((a, b, c), dtype="")) + _check_inference(bb, relax.op.tril(x2), relax.TensorType((a, b, c), "float32", vdev0)) # Static tensor, dynamic offset - _check_inference(bb, relax.op.tril(x3, a), relax.TensorStructInfo((16, 32, 64), dtype="")) + _check_inference(bb, relax.op.tril(x3, a), relax.TensorType((16, 32, 64), dtype="")) # Dynamic tensor, dynamic offset - _check_inference(bb, relax.op.tril(x0, a), relax.TensorStructInfo((a, b, c), "float32")) + _check_inference(bb, relax.op.tril(x0, a), relax.TensorType((a, b, c), "float32")) -def test_tril_triu_infer_struct_info_shape_var(): +def test_tril_triu_infer_ty_shape_var(): bb = relax.BlockBuilder() - s0 = relax.Var("s", relax.ShapeStructInfo((2, 3, 4))) - s1 = relax.Var("s", relax.ShapeStructInfo(ndim=3)) - s2 = relax.Var("s", relax.ShapeStructInfo()) - x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) - x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) - x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + s0 = relax.Var("s", relax.ShapeType((2, 3, 4))) + s1 = relax.Var("s", relax.ShapeType(ndim=3)) + s2 = relax.Var("s", relax.ShapeType()) + x0 = relax.Var("x", relax.TensorType(s0, "float32")) + x1 = relax.Var("x", relax.TensorType(s1, "float32")) + x2 = relax.Var("x", relax.TensorType(s2, "float32")) - _check_inference(bb, relax.op.tril(x0), relax.TensorStructInfo(s0, "float32")) - _check_inference(bb, relax.op.triu(x1), relax.TensorStructInfo(s1, "float32")) - _check_inference(bb, relax.op.tril(x2), relax.TensorStructInfo(s2, "float32")) + _check_inference(bb, relax.op.tril(x0), relax.TensorType(s0, "float32")) + _check_inference(bb, relax.op.triu(x1), relax.TensorType(s1, "float32")) + _check_inference(bb, relax.op.tril(x2), relax.TensorType(s2, "float32")) -def test_tril_triu_infer_struct_info_more_input_dtype(): +def test_tril_triu_infer_ty_more_input_dtype(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3, 4), "float16")) x1 = relax.Var("x", R.Tensor((2, 3, 4), "int8")) x2 = relax.Var("x", R.Tensor((2, 3, 4), "int32")) - _check_inference(bb, relax.op.triu(x0), relax.TensorStructInfo((2, 3, 4), "float16")) - _check_inference(bb, relax.op.tril(x1), relax.TensorStructInfo((2, 3, 4), "int8")) - _check_inference(bb, relax.op.triu(x2), relax.TensorStructInfo((2, 3, 4), "int32")) + _check_inference(bb, relax.op.triu(x0), relax.TensorType((2, 3, 4), "float16")) + _check_inference(bb, relax.op.tril(x1), relax.TensorType((2, 3, 4), "int8")) + _check_inference(bb, relax.op.triu(x2), relax.TensorType((2, 3, 4), "int32")) -def test_tril_triu_infer_struct_info_less_than_two_ndim(): +def test_tril_triu_infer_ty_less_than_two_ndim(): bb = relax.BlockBuilder() - s0 = relax.Var("s", relax.ShapeStructInfo((2,))) - s1 = relax.Var("s", relax.ShapeStructInfo(())) - s2 = relax.Var("s", relax.ShapeStructInfo(ndim=1)) - s3 = relax.Var("s", relax.ShapeStructInfo(ndim=0)) + s0 = relax.Var("s", relax.ShapeType((2,))) + s1 = relax.Var("s", relax.ShapeType(())) + s2 = relax.Var("s", relax.ShapeType(ndim=1)) + s3 = relax.Var("s", relax.ShapeType(ndim=0)) x0 = relax.Var("x", R.Tensor((2,), "float32")) x1 = relax.Var("x", R.Tensor((), "float32")) x2 = relax.Var("x", R.Tensor("float32", ndim=1)) x3 = relax.Var("x", R.Tensor("float32", ndim=0)) - x4 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) - x5 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) - x6 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) - x7 = relax.Var("x", relax.TensorStructInfo(s3, "float32")) + x4 = relax.Var("x", relax.TensorType(s0, "float32")) + x5 = relax.Var("x", relax.TensorType(s1, "float32")) + x6 = relax.Var("x", relax.TensorType(s2, "float32")) + x7 = relax.Var("x", relax.TensorType(s3, "float32")) with pytest.raises(ValueError): bb.normalize(relax.op.tril(x0)) @@ -773,10 +727,10 @@ def test_tril_triu_infer_struct_info_less_than_two_ndim(): bb.normalize(relax.op.triu(x7)) -def test_tril_triu_infer_struct_info_wrong_input_type(): +def test_tril_triu_infer_ty_wrong_input_type(): bb = relax.BlockBuilder() - x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 4))) - x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 4), "float32"))) + x0 = relax.Var("x", relax.ShapeType((2, 3, 4))) + x1 = relax.Var("x", relax.FuncType([], R.Tensor((2, 3, 4), "float32"))) with pytest.raises(TypeError): bb.normalize(relax.op.tril(x0)) diff --git a/tests/python/relax/test_op_datatype.py b/tests/python/relax/test_op_datatype.py index 9a0e809f8e62..553075f7c050 100644 --- a/tests/python/relax/test_op_datatype.py +++ b/tests/python/relax/test_op_datatype.py @@ -31,12 +31,12 @@ def test_op_correctness(): assert relax.op.wrap_param(c, "float32").op == Op.get("relax.wrap_param") -def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): +def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_ty: relax.Type): ret = bb.normalize(call) - tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) + tvm.ir.assert_structural_equal(ret.ty, expected_ty) -def test_astype_infer_struct_info(): +def test_astype_infer_ty(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3), "float32")) x1 = relax.Var("x", R.Tensor("float32", ndim=2)) @@ -45,58 +45,54 @@ def test_astype_infer_struct_info(): x4 = relax.Var("x", R.Tensor(ndim=2)) x5 = relax.Var("x", R.Tensor()) - _check_inference(bb, relax.op.astype(x0, "float16"), relax.TensorStructInfo((2, 3), "float16")) - _check_inference( - bb, relax.op.astype(x1, "float16"), relax.TensorStructInfo(dtype="float16", ndim=2) - ) - _check_inference(bb, relax.op.astype(x2, "float16"), relax.TensorStructInfo(dtype="float16")) - _check_inference(bb, relax.op.astype(x3, "float16"), relax.TensorStructInfo((2, 3), "float16")) - _check_inference( - bb, relax.op.astype(x4, "float16"), relax.TensorStructInfo(dtype="float16", ndim=2) - ) - _check_inference(bb, relax.op.astype(x5, "float16"), relax.TensorStructInfo(dtype="float16")) + _check_inference(bb, relax.op.astype(x0, "float16"), relax.TensorType((2, 3), "float16")) + _check_inference(bb, relax.op.astype(x1, "float16"), relax.TensorType(dtype="float16", ndim=2)) + _check_inference(bb, relax.op.astype(x2, "float16"), relax.TensorType(dtype="float16")) + _check_inference(bb, relax.op.astype(x3, "float16"), relax.TensorType((2, 3), "float16")) + _check_inference(bb, relax.op.astype(x4, "float16"), relax.TensorType(dtype="float16", ndim=2)) + _check_inference(bb, relax.op.astype(x5, "float16"), relax.TensorType(dtype="float16")) -def test_astype_infer_struct_info_shape_symbolic(): +def test_astype_infer_ty_shape_symbolic(): bb = relax.BlockBuilder() m = tirx.Var("m", "int64") n = tirx.Var("n", "int64") x0 = relax.Var("x", R.Tensor((m, n), "float32")) x1 = relax.Var("x", R.Tensor((m, n))) - _check_inference(bb, relax.op.astype(x0, "float16"), relax.TensorStructInfo((m, n), "float16")) - _check_inference(bb, relax.op.astype(x1, "float16"), relax.TensorStructInfo((m, n), "float16")) + _check_inference(bb, relax.op.astype(x0, "float16"), relax.TensorType((m, n), "float16")) + _check_inference(bb, relax.op.astype(x1, "float16"), relax.TensorType((m, n), "float16")) -def test_astype_infer_struct_info_shape_var(): +def test_astype_infer_ty_shape_var(): bb = relax.BlockBuilder() - s0 = relax.Var("s", relax.ShapeStructInfo((2, 3))) - s1 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) - s2 = relax.Var("s", relax.ShapeStructInfo()) - x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) - x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) - x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + s0 = relax.Var("s", relax.ShapeType((2, 3))) + s1 = relax.Var("s", relax.ShapeType(ndim=2)) + s2 = relax.Var("s", relax.ShapeType()) + x0 = relax.Var("x", relax.TensorType(s0, "float32")) + x1 = relax.Var("x", relax.TensorType(s1, "float32")) + x2 = relax.Var("x", relax.TensorType(s2, "float32")) - _check_inference(bb, relax.op.astype(x0, "float16"), relax.TensorStructInfo(s0, "float16")) - _check_inference(bb, relax.op.astype(x1, "float16"), relax.TensorStructInfo(s1, "float16")) - _check_inference(bb, relax.op.astype(x2, "float16"), relax.TensorStructInfo(s2, "float16")) + _check_inference(bb, relax.op.astype(x0, "float16"), relax.TensorType(s0, "float16")) + _check_inference(bb, relax.op.astype(x1, "float16"), relax.TensorType(s1, "float16")) + _check_inference(bb, relax.op.astype(x2, "float16"), relax.TensorType(s2, "float16")) -def test_astype_infer_struct_info_more_input_dtype(): +def test_astype_infer_ty_more_input_dtype(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3), "float16")) x1 = relax.Var("x", R.Tensor((2, 3), "int8")) x2 = relax.Var("x", R.Tensor((2, 3), "int32")) - _check_inference(bb, relax.op.astype(x0, "float32"), relax.TensorStructInfo((2, 3), "float32")) - _check_inference(bb, relax.op.astype(x1, "int32"), relax.TensorStructInfo((2, 3), "int32")) - _check_inference(bb, relax.op.astype(x2, "int8"), relax.TensorStructInfo((2, 3), "int8")) + _check_inference(bb, relax.op.astype(x0, "float32"), relax.TensorType((2, 3), "float32")) + _check_inference(bb, relax.op.astype(x1, "int32"), relax.TensorType((2, 3), "int32")) + _check_inference(bb, relax.op.astype(x2, "int8"), relax.TensorType((2, 3), "int8")) -def test_astype_infer_struct_info_wrong_input_type(): +def test_astype_infer_ty_wrong_input_type(): bb = relax.BlockBuilder() - x0 = relax.Var("x", relax.ShapeStructInfo((2, 3))) - x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3), "float32"))) + x0 = relax.Var("x", relax.ShapeType((2, 3))) + x1 = relax.Var("x", relax.FuncType([], R.Tensor((2, 3), "float32"))) with pytest.raises(TypeError): bb.normalize(relax.op.astype(x0, "float16")) @@ -104,16 +100,12 @@ def test_astype_infer_struct_info_wrong_input_type(): bb.normalize(relax.op.astype(x1, "float16")) -def test_wrap_param_infer_struct_info(): +def test_wrap_param_infer_ty(): bb = relax.BlockBuilder() x0 = relax.Constant(tvm.runtime.tensor(np.zeros([1, 2, 3], dtype="float16"))) x1 = relax.Constant(tvm.runtime.tensor(np.zeros([1, 2, 3], dtype="int8"))) - _check_inference( - bb, relax.op.wrap_param(x0, "float32"), relax.TensorStructInfo((1, 2, 3), "float32") - ) - _check_inference( - bb, relax.op.wrap_param(x1, "int32"), relax.TensorStructInfo((1, 2, 3), "int32") - ) + _check_inference(bb, relax.op.wrap_param(x0, "float32"), relax.TensorType((1, 2, 3), "float32")) + _check_inference(bb, relax.op.wrap_param(x1, "int32"), relax.TensorType((1, 2, 3), "int32")) if __name__ == "__main__": diff --git a/tests/python/relax/test_op_distributed.py b/tests/python/relax/test_op_distributed.py index 5b5330dfedb9..380659af4984 100644 --- a/tests/python/relax/test_op_distributed.py +++ b/tests/python/relax/test_op_distributed.py @@ -22,9 +22,9 @@ from tvm.script.parser import relax as R -def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): +def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_ty: relax.Type): ret = bb.normalize(call) - tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) + tvm.ir.assert_structural_equal(ret.ty, expected_ty) def test_redistribute_R_to_S(): diff --git a/tests/python/relax/test_op_grad.py b/tests/python/relax/test_op_grad.py index 22bafa4d8fe5..8db16ff1c05f 100644 --- a/tests/python/relax/test_op_grad.py +++ b/tests/python/relax/test_op_grad.py @@ -52,9 +52,9 @@ def test_op_correctness(): assert relax.op.grad.end_checkpoint(x).args[0] == x -def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): +def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_ty: relax.Type): ret = bb.normalize(call) - tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) + tvm.ir.assert_structural_equal(ret.ty, expected_ty) def test_start_checkpoint_input_not_var(): @@ -95,7 +95,7 @@ def test_end_checkpoint_input_not_var(): bb.normalize(relax.op.grad.end_checkpoint(relax.const(1, "float32"))) -def test_nll_loss_backward_infer_struct_info(): +def test_nll_loss_backward_infer_ty(): bb = relax.BlockBuilder() g = relax.Var("g", R.Tensor((3, 10, 10))) @@ -103,38 +103,38 @@ def test_nll_loss_backward_infer_struct_info(): y = relax.Var("y", R.Tensor((3, 10, 10), "int64")) w = relax.Var("w", R.Tensor((5,), "float32")) - _check_inference(bb, relax.op.grad.nll_loss_backward(g, x, y), x.struct_info) - _check_inference(bb, relax.op.grad.nll_loss_backward(g, x, y, w), x.struct_info) + _check_inference(bb, relax.op.grad.nll_loss_backward(g, x, y), x.ty) + _check_inference(bb, relax.op.grad.nll_loss_backward(g, x, y, w), x.ty) -def test_max_pool2d_backward_infer_struct_info(): +def test_max_pool2d_backward_infer_ty(): bb = relax.BlockBuilder() g = relax.Var("g", R.Tensor((3, 3, 8, 8), "float32")) x = relax.Var("x", R.Tensor((3, 2, 10, 10), "float32")) - _check_inference(bb, relax.op.grad.max_pool2d_backward(g, x, (2, 2)), x.struct_info) - _check_inference(bb, relax.op.grad.max_pool2d_backward(g, x, (3, 3)), x.struct_info) + _check_inference(bb, relax.op.grad.max_pool2d_backward(g, x, (2, 2)), x.ty) + _check_inference(bb, relax.op.grad.max_pool2d_backward(g, x, (3, 3)), x.ty) -def test_avg_pool2d_backward_infer_struct_info(): +def test_avg_pool2d_backward_infer_ty(): bb = relax.BlockBuilder() g = relax.Var("g", R.Tensor((3, 3, 8, 8), "float32")) x = relax.Var("x", R.Tensor((3, 2, 10, 10), "float32")) - _check_inference(bb, relax.op.grad.avg_pool2d_backward(g, x, (2, 2)), x.struct_info) - _check_inference(bb, relax.op.grad.avg_pool2d_backward(g, x, (3, 3)), x.struct_info) + _check_inference(bb, relax.op.grad.avg_pool2d_backward(g, x, (2, 2)), x.ty) + _check_inference(bb, relax.op.grad.avg_pool2d_backward(g, x, (3, 3)), x.ty) -def test_take_backward_infer_struct_info(): +def test_take_backward_infer_ty(): bb = relax.BlockBuilder() g = relax.Var("g", R.Tensor((3, 2, 5), "float32")) x = relax.Var("x", R.Tensor((3, 4, 5), "float32")) indices = relax.Var("indices", R.Tensor((2,), "float32")) - _check_inference(bb, relax.op.grad.take_backward(g, x, indices, axis=1), x.struct_info) + _check_inference(bb, relax.op.grad.take_backward(g, x, indices, axis=1), x.ty) if __name__ == "__main__": diff --git a/tests/python/relax/test_op_gradient_numeric.py b/tests/python/relax/test_op_gradient_numeric.py index a4e556d1d4c5..1b223ecc996b 100644 --- a/tests/python/relax/test_op_gradient_numeric.py +++ b/tests/python/relax/test_op_gradient_numeric.py @@ -27,8 +27,8 @@ from tvm import relax from tvm.ir.op import Op from tvm.relax.expr import Call -from tvm.relax.struct_info import TensorStructInfo, TupleStructInfo from tvm.relax.transform import LegalizeOps +from tvm.relax.type import TensorType, TupleType from tvm.testing.utils import check_numerical_grads @@ -80,10 +80,10 @@ def relax_check_gradients( func_name = "main" # Helper functions - def _numpy_to_sinfo(data): + def _numpy_to_ty(data): if isinstance(data, list): - return relax.TupleStructInfo([_numpy_to_sinfo(d) for d in data]) - return relax.TensorStructInfo(data.shape, str(data.dtype)) + return relax.TupleType([_numpy_to_ty(d) for d in data]) + return relax.TensorType(data.shape, str(data.dtype)) def _numpy_to_tvm(data): if isinstance(data, list): @@ -97,19 +97,19 @@ def _tvm_to_numpy(data, ignore_idx=[]): return data.numpy() return data - def _gen_weights(out_sinfo): - if isinstance(out_sinfo, TupleStructInfo): - return [_gen_weights(sinfo) for sinfo in out_sinfo.fields] + def _gen_weights(out_ty): + if isinstance(out_ty, TupleType): + return [_gen_weights(ty) for ty in out_ty.fields] else: - assert isinstance(out_sinfo, TensorStructInfo) - return np.random.uniform(size=[int(i) for i in out_sinfo.shape]).astype(out_sinfo.dtype) + assert isinstance(out_ty, TensorType) + return np.random.uniform(size=[int(i) for i in out_ty.shape]).astype(out_ty.dtype) def _is_call_no_grad(expr): return isinstance(expr, Call) and expr.op == Op.get("relax.grad.no_grad") # Generate parameter relax Vars param_vars = [ - relax.Var("x_" + str(i), _numpy_to_sinfo(data)) for i, data in enumerate(inputs_numpy) + relax.Var("x_" + str(i), _numpy_to_ty(data)) for i, data in enumerate(inputs_numpy) ] # Generate the forward call @@ -135,8 +135,8 @@ def _is_call_no_grad(expr): # If the result is a tuple, weights will be a list, and the weighted result will be # sum(i * j for i, j in zip(weights, result)) # In the gradient process, weights is the output gradient, i.e. the gradient w.r.t. the result. - out_sinfo = forward_mod[func_name].body.body.struct_info - weights = _gen_weights(out_sinfo) + out_ty = forward_mod[func_name].body.body.ty + weights = _gen_weights(out_ty) # The inputs of the forward function are inputs_filtered below. def forward(*inputs): @@ -163,7 +163,7 @@ def forward(*inputs): op_grad_func = call.op.get_attr("FPrimalGradient") # The parameter Var for gradient - grad_var = relax.Var("grad", _numpy_to_sinfo(weights)) + grad_var = relax.Var("grad", _numpy_to_ty(weights)) # Gradient mod grad_bb = relax.BlockBuilder() diff --git a/tests/python/relax/test_op_image.py b/tests/python/relax/test_op_image.py index 38c27592fd1e..6fc40a5d5554 100644 --- a/tests/python/relax/test_op_image.py +++ b/tests/python/relax/test_op_image.py @@ -37,12 +37,12 @@ def test_op_correctness(): assert relax.op.image.resize3d(y, (4, 8, 12)).op == Op.get("relax.image.resize3d") -def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): +def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_ty: relax.Type): ret = bb.normalize(call) - tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) + tvm.ir.assert_structural_equal(ret.ty, expected_ty) -def test_resize2d_infer_struct_info(): +def test_resize2d_infer_ty(): bb = relax.BlockBuilder() vdev0 = VDevice("llvm") x0 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32")) @@ -56,63 +56,59 @@ def test_resize2d_infer_struct_info(): x8 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32", vdev0)) _check_inference( - bb, relax.op.image.resize2d(x0, (28, 28)), relax.TensorStructInfo((2, 3, 28, 28), "float32") + bb, relax.op.image.resize2d(x0, (28, 28)), relax.TensorType((2, 3, 28, 28), "float32") ) _check_inference( bb, relax.op.image.resize2d(x8, (28, 28)), - relax.TensorStructInfo((2, 3, 28, 28), "float32", vdev0), + relax.TensorType((2, 3, 28, 28), "float32", vdev0), ) _check_inference( bb, relax.op.image.resize2d(x0, size=28), - relax.TensorStructInfo((2, 3, 28, 28), "float32"), + relax.TensorType((2, 3, 28, 28), "float32"), ) _check_inference( bb, relax.op.image.resize2d(x0, size=(28, 30)), - relax.TensorStructInfo((2, 3, 28, 30), "float32"), + relax.TensorType((2, 3, 28, 30), "float32"), ) _check_inference( bb, relax.op.image.resize2d(x1, size=28, layout="NHWC"), - relax.TensorStructInfo((2, 28, 28, 3), "float32"), + relax.TensorType((2, 28, 28, 3), "float32"), ) _check_inference( bb, relax.op.image.resize2d(x0, size=28, out_dtype="float16"), - relax.TensorStructInfo((2, 3, 28, 28), "float16"), + relax.TensorType((2, 3, 28, 28), "float16"), ) _check_inference( bb, relax.op.image.resize2d(x2, size=28, layout="NCHW16c"), - relax.TensorStructInfo((2, 4, 28, 28, 16), "float32"), + relax.TensorType((2, 4, 28, 28, 16), "float32"), ) _check_inference( - bb, relax.op.image.resize2d(x3, size=28), relax.TensorStructInfo(dtype="float32", ndim=4) + bb, relax.op.image.resize2d(x3, size=28), relax.TensorType(dtype="float32", ndim=4) ) _check_inference( bb, relax.op.image.resize2d(x4, size=28, layout="NCHW16c"), - relax.TensorStructInfo(dtype="float32", ndim=5), + relax.TensorType(dtype="float32", ndim=5), ) _check_inference( - bb, relax.op.image.resize2d(x5, size=28), relax.TensorStructInfo(dtype="float32", ndim=4) - ) - _check_inference( - bb, relax.op.image.resize2d(x6, size=28), relax.TensorStructInfo(dtype="", ndim=4) + bb, relax.op.image.resize2d(x5, size=28), relax.TensorType(dtype="float32", ndim=4) ) + _check_inference(bb, relax.op.image.resize2d(x6, size=28), relax.TensorType(dtype="", ndim=4)) _check_inference( bb, relax.op.image.resize2d(x6, size=28, out_dtype="float32"), - relax.TensorStructInfo(dtype="float32", ndim=4), - ) - _check_inference( - bb, relax.op.image.resize2d(x7, size=28), relax.TensorStructInfo(dtype="", ndim=4) + relax.TensorType(dtype="float32", ndim=4), ) + _check_inference(bb, relax.op.image.resize2d(x7, size=28), relax.TensorType(dtype="", ndim=4)) -def test_resize2d_infer_struct_info_shape_symbolic(): +def test_resize2d_infer_ty_shape_symbolic(): bb = relax.BlockBuilder() n = tirx.Var("n", "int64") c = tirx.Var("c", "int64") @@ -124,77 +120,75 @@ def test_resize2d_infer_struct_info_shape_symbolic(): x1 = relax.Var("x", R.Tensor((n, c, ih, iw, 16), "float32")) _check_inference( - bb, relax.op.image.resize2d(x0, size=oh), relax.TensorStructInfo((n, c, oh, oh), "float32") + bb, relax.op.image.resize2d(x0, size=oh), relax.TensorType((n, c, oh, oh), "float32") ) _check_inference( bb, relax.op.image.resize2d(x0, size=(oh, ow)), - relax.TensorStructInfo((n, c, oh, ow), "float32"), + relax.TensorType((n, c, oh, ow), "float32"), ) _check_inference( bb, relax.op.image.resize2d(x1, size=(oh, ow), layout="NCHW16c"), - relax.TensorStructInfo((n, c, oh, ow, 16), "float32"), + relax.TensorType((n, c, oh, ow, 16), "float32"), ) -def test_resize2d_infer_struct_info_shape_var(): +def test_resize2d_infer_ty_shape_var(): bb = relax.BlockBuilder() - s0 = relax.Var("s", relax.ShapeStructInfo(ndim=4)) - s1 = relax.Var("s", relax.ShapeStructInfo(ndim=5)) - s2 = relax.Var("s", relax.ShapeStructInfo()) - x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) - x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) - x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + s0 = relax.Var("s", relax.ShapeType(ndim=4)) + s1 = relax.Var("s", relax.ShapeType(ndim=5)) + s2 = relax.Var("s", relax.ShapeType()) + x0 = relax.Var("x", relax.TensorType(s0, "float32")) + x1 = relax.Var("x", relax.TensorType(s1, "float32")) + x2 = relax.Var("x", relax.TensorType(s2, "float32")) _check_inference( - bb, relax.op.image.resize2d(x0, size=32), relax.TensorStructInfo(dtype="float32", ndim=4) + bb, relax.op.image.resize2d(x0, size=32), relax.TensorType(dtype="float32", ndim=4) ) _check_inference( bb, relax.op.image.resize2d(x1, size=32, layout="NCHW16c"), - relax.TensorStructInfo(dtype="float32", ndim=5), + relax.TensorType(dtype="float32", ndim=5), ) _check_inference( bb, relax.op.image.resize2d(x2, size=32, layout="NCHW16c"), - relax.TensorStructInfo(dtype="float32", ndim=5), + relax.TensorType(dtype="float32", ndim=5), ) -def test_resize2d_infer_struct_info_pool_size_var(): +def test_resize2d_infer_ty_pool_size_var(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32")) - s0 = relax.Var("s", relax.ShapeStructInfo((30, 30))) - s1 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) + s0 = relax.Var("s", relax.ShapeType((30, 30))) + s1 = relax.Var("s", relax.ShapeType(ndim=2)) _check_inference( bb, relax.op.image.resize2d(x0, s0), - relax.TensorStructInfo(dtype="float32", ndim=4), - ) - _check_inference( - bb, relax.op.image.resize2d(x0, s1), relax.TensorStructInfo(dtype="float32", ndim=4) + relax.TensorType(dtype="float32", ndim=4), ) + _check_inference(bb, relax.op.image.resize2d(x0, s1), relax.TensorType(dtype="float32", ndim=4)) -def test_resize2d_infer_struct_info_more_input_dtype(): +def test_resize2d_infer_ty_more_input_dtype(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float16")) x1 = relax.Var("x", R.Tensor((2, 3, 32, 32), "int8")) x2 = relax.Var("x", R.Tensor((2, 3, 32, 32), "int64")) _check_inference( - bb, relax.op.image.resize2d(x0, size=28), relax.TensorStructInfo((2, 3, 28, 28), "float16") + bb, relax.op.image.resize2d(x0, size=28), relax.TensorType((2, 3, 28, 28), "float16") ) _check_inference( - bb, relax.op.image.resize2d(x1, size=28), relax.TensorStructInfo((2, 3, 28, 28), "int8") + bb, relax.op.image.resize2d(x1, size=28), relax.TensorType((2, 3, 28, 28), "int8") ) _check_inference( - bb, relax.op.image.resize2d(x2, size=28), relax.TensorStructInfo((2, 3, 28, 28), "int64") + bb, relax.op.image.resize2d(x2, size=28), relax.TensorType((2, 3, 28, 28), "int64") ) -def test_resize3d_infer_struct_info(): +def test_resize3d_infer_ty(): bb = relax.BlockBuilder() vdev0 = VDevice("llvm") x0 = relax.Var("x", R.Tensor((2, 3, 8, 16, 32), "float32")) @@ -206,39 +200,39 @@ def test_resize3d_infer_struct_info(): _check_inference( bb, relax.op.image.resize3d(x0, (4, 8, 12)), - relax.TensorStructInfo((2, 3, 4, 8, 12), "float32"), + relax.TensorType((2, 3, 4, 8, 12), "float32"), ) _check_inference( bb, relax.op.image.resize3d(x4, (4, 8, 12)), - relax.TensorStructInfo((2, 3, 4, 8, 12), "float32", vdev0), + relax.TensorType((2, 3, 4, 8, 12), "float32", vdev0), ) _check_inference( bb, relax.op.image.resize3d(x0, 7), - relax.TensorStructInfo((2, 3, 7, 7, 7), "float32"), + relax.TensorType((2, 3, 7, 7, 7), "float32"), ) _check_inference( bb, relax.op.image.resize3d(x1, (4, 8, 12), layout="NDHWC"), - relax.TensorStructInfo((2, 4, 8, 12, 3), "float32"), + relax.TensorType((2, 4, 8, 12, 3), "float32"), ) _check_inference( bb, relax.op.image.resize3d(x2, (4, 8, 12), layout="NCDHW8c"), - relax.TensorStructInfo((2, 4, 4, 8, 12, 8), "float32"), + relax.TensorType((2, 4, 4, 8, 12, 8), "float32"), ) _check_inference( bb, relax.op.image.resize3d(x0, (4, 8, 12), out_dtype="float16"), - relax.TensorStructInfo((2, 3, 4, 8, 12), "float16"), + relax.TensorType((2, 3, 4, 8, 12), "float16"), ) _check_inference( - bb, relax.op.image.resize3d(x3, (4, 8, 12)), relax.TensorStructInfo(dtype="float32", ndim=5) + bb, relax.op.image.resize3d(x3, (4, 8, 12)), relax.TensorType(dtype="float32", ndim=5) ) -def test_resize3d_infer_struct_info_wrong_layout_string(): +def test_resize3d_infer_ty_wrong_layout_string(): bb = relax.BlockBuilder() x = relax.Var("x", R.Tensor((2, 3, 8, 16, 32), "float32")) with pytest.raises(ValueError): @@ -262,12 +256,12 @@ def test_resize3d_wrong_size_ndim(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3, 8, 16, 32), "float16")) s0 = relax.ShapeExpr((3, 3)) - s1 = relax.Var("s", relax.ShapeStructInfo((30, 30, 30, 30))) - s2 = relax.Var("s", relax.ShapeStructInfo(ndim=4)) - s3 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) - s4 = relax.Var("s", relax.ShapeStructInfo(ndim=1)) - s5 = relax.Var("s", relax.ShapeStructInfo(ndim=0)) - s6 = relax.Var("s", relax.ShapeStructInfo()) + s1 = relax.Var("s", relax.ShapeType((30, 30, 30, 30))) + s2 = relax.Var("s", relax.ShapeType(ndim=4)) + s3 = relax.Var("s", relax.ShapeType(ndim=2)) + s4 = relax.Var("s", relax.ShapeType(ndim=1)) + s5 = relax.Var("s", relax.ShapeType(ndim=0)) + s6 = relax.Var("s", relax.ShapeType()) with pytest.raises(ValueError): bb.normalize(relax.op.image.resize3d(x0, (3, 3))) @@ -287,10 +281,10 @@ def test_resize3d_wrong_size_ndim(): bb.normalize(relax.op.image.resize3d(x0, s6)) -def test_resize3d_infer_struct_info_wrong_input_type(): +def test_resize3d_infer_ty_wrong_input_type(): bb = relax.BlockBuilder() - x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 8, 16, 32))) - x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 8, 16, 32), "float32"))) + x0 = relax.Var("x", relax.ShapeType((2, 3, 8, 16, 32))) + x1 = relax.Var("x", relax.FuncType([], R.Tensor((2, 3, 8, 16, 32), "float32"))) x2 = relax.Var("x", R.Tensor((2, 3, 8, 16, 32), "float32")) s0 = relax.Var("s", R.Tensor((3, 3, 3))) @@ -302,7 +296,7 @@ def test_resize3d_infer_struct_info_wrong_input_type(): bb.normalize(relax.op.image.resize3d(x2, s0)) -def test_resize2d_infer_struct_info_wrong_layout_string(): +def test_resize2d_infer_ty_wrong_layout_string(): bb = relax.BlockBuilder() x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) with pytest.raises(ValueError): @@ -326,11 +320,11 @@ def test_resize2d_wrong_pool_size_ndim(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float16")) s0 = relax.ShapeExpr((3,)) - s1 = relax.Var("s", relax.ShapeStructInfo((30, 30, 30))) - s2 = relax.Var("s", relax.ShapeStructInfo(ndim=3)) - s3 = relax.Var("s", relax.ShapeStructInfo(ndim=1)) - s4 = relax.Var("s", relax.ShapeStructInfo(ndim=0)) - s5 = relax.Var("s", relax.ShapeStructInfo()) + s1 = relax.Var("s", relax.ShapeType((30, 30, 30))) + s2 = relax.Var("s", relax.ShapeType(ndim=3)) + s3 = relax.Var("s", relax.ShapeType(ndim=1)) + s4 = relax.Var("s", relax.ShapeType(ndim=0)) + s5 = relax.Var("s", relax.ShapeType()) with pytest.raises(ValueError): bb.normalize(relax.op.image.resize2d(x0, (3, 3, 3))) @@ -348,10 +342,10 @@ def test_resize2d_wrong_pool_size_ndim(): bb.normalize(relax.op.image.resize2d(x0, s5)) -def test_resize2d_infer_struct_info_wrong_input_type(): +def test_resize2d_infer_ty_wrong_input_type(): bb = relax.BlockBuilder() - x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 28, 28))) - x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 28, 28), "float32"))) + x0 = relax.Var("x", relax.ShapeType((2, 3, 28, 28))) + x1 = relax.Var("x", relax.FuncType([], R.Tensor((2, 3, 28, 28), "float32"))) x2 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32")) s0 = relax.Var("s", R.Tensor((3, 3))) @@ -363,7 +357,7 @@ def test_resize2d_infer_struct_info_wrong_input_type(): bb.normalize(relax.op.image.resize2d(x2, s0)) -def test_affine_grid_infer_struct_info(): +def test_affine_grid_infer_ty(): bb = relax.BlockBuilder() vdev0 = VDevice("llvm") x0 = relax.Var("x", R.Tensor((2, 2, 3), "float32")) @@ -375,41 +369,41 @@ def test_affine_grid_infer_struct_info(): _check_inference( bb, relax.op.image.affine_grid(x0, (16, 16)), - relax.TensorStructInfo((2, 2, 16, 16), "float32"), + relax.TensorType((2, 2, 16, 16), "float32"), ) _check_inference( bb, relax.op.image.affine_grid(x1, (16, 16)), - relax.TensorStructInfo((2, 2, 16, 16), "float32", vdev0), + relax.TensorType((2, 2, 16, 16), "float32", vdev0), ) _check_inference( bb, relax.op.image.affine_grid(x0, size=16), - relax.TensorStructInfo((2, 2, 16, 16), "float32"), + relax.TensorType((2, 2, 16, 16), "float32"), ) _check_inference( bb, relax.op.image.affine_grid(x0, size=(16, 20)), - relax.TensorStructInfo((2, 2, 16, 20), "float32"), + relax.TensorType((2, 2, 16, 20), "float32"), ) _check_inference( bb, relax.op.image.affine_grid(x2, size=(16, 16)), - relax.TensorStructInfo(dtype="float32", ndim=4), + relax.TensorType(dtype="float32", ndim=4), ) _check_inference( bb, relax.op.image.affine_grid(x3, size=(16, 16)), - relax.TensorStructInfo(dtype="float32", ndim=4), + relax.TensorType(dtype="float32", ndim=4), ) _check_inference( bb, relax.op.image.affine_grid(x4, size=(16, 16)), - relax.TensorStructInfo(dtype="", ndim=4), + relax.TensorType(dtype="", ndim=4), ) -def test_affine_grid_infer_struct_info_shape_symbolic(): +def test_affine_grid_infer_ty_shape_symbolic(): bb = relax.BlockBuilder() n = tirx.Var("n", "int64") oh = tirx.Var("oh", "int64") @@ -419,13 +413,13 @@ def test_affine_grid_infer_struct_info_shape_symbolic(): _check_inference( bb, relax.op.image.affine_grid(x0, size=(oh, ow)), - relax.TensorStructInfo((n, 2, oh, ow), "float32"), + relax.TensorType((n, 2, oh, ow), "float32"), ) -def test_affine_grid_infer_struct_info_wrong_input_type(): +def test_affine_grid_infer_ty_wrong_input_type(): bb = relax.BlockBuilder() - x0 = relax.Var("x", relax.ShapeStructInfo((2, 2, 3))) + x0 = relax.Var("x", relax.ShapeType((2, 2, 3))) x1 = relax.Var("x", R.Tensor((2, 2, 3), "float32")) s0 = relax.Var("s", R.Tensor((3, 3))) diff --git a/tests/python/relax/test_op_index.py b/tests/python/relax/test_op_index.py index 997f8e0eb9ab..49478476475d 100644 --- a/tests/python/relax/test_op_index.py +++ b/tests/python/relax/test_op_index.py @@ -37,12 +37,12 @@ def test_op_correctness(): assert relax.op.dynamic_strided_slice(x, x, x, x).op == Op.get("relax.dynamic_strided_slice") -def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): +def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_ty: relax.Type): ret = bb.normalize(call) - tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) + tvm.ir.assert_structural_equal(ret.ty, expected_ty) -def test_take_infer_struct_info(): +def test_take_infer_ty(): bb = relax.BlockBuilder() vdev0 = VDevice("llvm") x0 = relax.Var("x", R.Tensor((4, 10), "float32")) @@ -66,157 +66,123 @@ def test_take_infer_struct_info(): idx7 = relax.Var("idx", R.Tensor(ndim=2)) idx8 = relax.Var("idx", R.Tensor((6,), "int64", vdev0)) - _check_inference(bb, relax.op.take(x0, idx0, axis=1), relax.TensorStructInfo((4, 6), "float32")) - _check_inference( - bb, relax.op.take(x6, idx8, axis=1), relax.TensorStructInfo((4, 6), "float32", vdev0) - ) - _check_inference( - bb, relax.op.take(x0, idx0, axis=-1), relax.TensorStructInfo((4, 6), "float32") - ) - _check_inference( - bb, relax.op.take(x1, idx0, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2) - ) - _check_inference(bb, relax.op.take(x2, idx0, axis=1), relax.TensorStructInfo(dtype="float32")) - _check_inference(bb, relax.op.take(x3, idx0, axis=1), relax.TensorStructInfo((4, 6), dtype="")) - _check_inference(bb, relax.op.take(x4, idx0, axis=1), relax.TensorStructInfo(dtype="", ndim=2)) - _check_inference(bb, relax.op.take(x5, idx0, axis=1), relax.TensorStructInfo(dtype="")) - _check_inference( - bb, relax.op.take(x0, idx1, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2) - ) - _check_inference( - bb, relax.op.take(x1, idx1, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2) - ) - _check_inference(bb, relax.op.take(x2, idx1, axis=1), relax.TensorStructInfo(dtype="float32")) - _check_inference(bb, relax.op.take(x3, idx1, axis=1), relax.TensorStructInfo(dtype="", ndim=2)) - _check_inference(bb, relax.op.take(x4, idx1, axis=1), relax.TensorStructInfo(dtype="", ndim=2)) - _check_inference(bb, relax.op.take(x5, idx1, axis=1), relax.TensorStructInfo(dtype="")) - _check_inference(bb, relax.op.take(x0, idx2, axis=1), relax.TensorStructInfo((4, 6), "float32")) - _check_inference( - bb, relax.op.take(x1, idx2, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2) - ) - _check_inference(bb, relax.op.take(x2, idx2, axis=1), relax.TensorStructInfo(dtype="float32")) - _check_inference(bb, relax.op.take(x3, idx2, axis=1), relax.TensorStructInfo((4, 6), dtype="")) - _check_inference(bb, relax.op.take(x4, idx2, axis=1), relax.TensorStructInfo(dtype="", ndim=2)) - _check_inference(bb, relax.op.take(x5, idx2, axis=1), relax.TensorStructInfo(dtype="")) - _check_inference( - bb, relax.op.take(x0, idx3, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2) - ) - _check_inference( - bb, relax.op.take(x1, idx3, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2) - ) - _check_inference(bb, relax.op.take(x2, idx3, axis=1), relax.TensorStructInfo(dtype="float32")) - _check_inference(bb, relax.op.take(x3, idx3, axis=1), relax.TensorStructInfo(dtype="", ndim=2)) - _check_inference(bb, relax.op.take(x4, idx3, axis=1), relax.TensorStructInfo(dtype="", ndim=2)) - _check_inference(bb, relax.op.take(x5, idx3, axis=1), relax.TensorStructInfo(dtype="")) - _check_inference( - bb, relax.op.take(x0, idx4, axis=0), relax.TensorStructInfo((6, 4, 10), dtype="float32") - ) - _check_inference( - bb, relax.op.take(x0, idx4, axis=1), relax.TensorStructInfo((4, 6, 4), dtype="float32") - ) - _check_inference( - bb, relax.op.take(x1, idx4, axis=1), relax.TensorStructInfo(dtype="float32", ndim=3) - ) - _check_inference(bb, relax.op.take(x2, idx4, axis=1), relax.TensorStructInfo(dtype="float32")) - _check_inference( - bb, relax.op.take(x3, idx4, axis=1), relax.TensorStructInfo((4, 6, 4), dtype="") - ) - _check_inference(bb, relax.op.take(x4, idx4, axis=1), relax.TensorStructInfo(dtype="", ndim=3)) - _check_inference(bb, relax.op.take(x5, idx4, axis=1), relax.TensorStructInfo(dtype="")) - _check_inference( - bb, relax.op.take(x0, idx5, axis=0), relax.TensorStructInfo(dtype="float32", ndim=3) - ) - _check_inference( - bb, relax.op.take(x0, idx5, axis=1), relax.TensorStructInfo(dtype="float32", ndim=3) - ) - _check_inference( - bb, relax.op.take(x1, idx5, axis=1), relax.TensorStructInfo(dtype="float32", ndim=3) - ) - _check_inference(bb, relax.op.take(x2, idx5, axis=1), relax.TensorStructInfo(dtype="float32")) - _check_inference(bb, relax.op.take(x3, idx5, axis=1), relax.TensorStructInfo(dtype="", ndim=3)) - _check_inference(bb, relax.op.take(x4, idx5, axis=1), relax.TensorStructInfo(dtype="", ndim=3)) - _check_inference(bb, relax.op.take(x5, idx5, axis=1), relax.TensorStructInfo(dtype="")) - _check_inference( - bb, relax.op.take(x0, idx6, axis=0), relax.TensorStructInfo((6, 4, 10), dtype="float32") - ) - _check_inference( - bb, relax.op.take(x0, idx6, axis=1), relax.TensorStructInfo((4, 6, 4), dtype="float32") - ) - _check_inference( - bb, relax.op.take(x1, idx6, axis=1), relax.TensorStructInfo(dtype="float32", ndim=3) - ) - _check_inference(bb, relax.op.take(x2, idx6, axis=1), relax.TensorStructInfo(dtype="float32")) - _check_inference( - bb, relax.op.take(x3, idx6, axis=1), relax.TensorStructInfo((4, 6, 4), dtype="") - ) - _check_inference(bb, relax.op.take(x4, idx6, axis=1), relax.TensorStructInfo(dtype="", ndim=3)) - _check_inference(bb, relax.op.take(x5, idx6, axis=1), relax.TensorStructInfo(dtype="")) - _check_inference( - bb, relax.op.take(x0, idx7, axis=0), relax.TensorStructInfo(dtype="float32", ndim=3) - ) - _check_inference( - bb, relax.op.take(x0, idx7, axis=1), relax.TensorStructInfo(dtype="float32", ndim=3) - ) - _check_inference( - bb, relax.op.take(x1, idx7, axis=1), relax.TensorStructInfo(dtype="float32", ndim=3) - ) - _check_inference(bb, relax.op.take(x2, idx7, axis=1), relax.TensorStructInfo(dtype="float32")) - _check_inference(bb, relax.op.take(x3, idx7, axis=1), relax.TensorStructInfo(dtype="", ndim=3)) - _check_inference(bb, relax.op.take(x4, idx7, axis=1), relax.TensorStructInfo(dtype="", ndim=3)) - _check_inference(bb, relax.op.take(x5, idx7, axis=1), relax.TensorStructInfo(dtype="")) - _check_inference(bb, relax.op.take(y0, idx0), relax.TensorStructInfo((6,), "float32")) - _check_inference(bb, relax.op.take(y1, idx0), relax.TensorStructInfo(dtype="float32", ndim=1)) - _check_inference(bb, relax.op.take(y2, idx0), relax.TensorStructInfo((6,), dtype="")) - _check_inference(bb, relax.op.take(y3, idx0), relax.TensorStructInfo(dtype="", ndim=1)) - _check_inference(bb, relax.op.take(y0, idx1), relax.TensorStructInfo(dtype="float32", ndim=1)) - _check_inference(bb, relax.op.take(y1, idx1), relax.TensorStructInfo(dtype="float32", ndim=1)) - _check_inference(bb, relax.op.take(y2, idx1), relax.TensorStructInfo(dtype="", ndim=1)) - _check_inference(bb, relax.op.take(y3, idx1), relax.TensorStructInfo(dtype="", ndim=1)) - _check_inference(bb, relax.op.take(y0, idx2), relax.TensorStructInfo((6,), "float32")) - _check_inference(bb, relax.op.take(y1, idx2), relax.TensorStructInfo(dtype="float32", ndim=1)) - _check_inference(bb, relax.op.take(y2, idx2), relax.TensorStructInfo((6,), dtype="")) - _check_inference(bb, relax.op.take(y3, idx2), relax.TensorStructInfo(dtype="", ndim=1)) - _check_inference(bb, relax.op.take(y0, idx3), relax.TensorStructInfo(dtype="float32", ndim=1)) - _check_inference(bb, relax.op.take(y1, idx3), relax.TensorStructInfo(dtype="float32", ndim=1)) - _check_inference(bb, relax.op.take(y2, idx3), relax.TensorStructInfo(dtype="", ndim=1)) - _check_inference(bb, relax.op.take(y3, idx3), relax.TensorStructInfo(dtype="", ndim=1)) - _check_inference(bb, relax.op.take(y0, idx4), relax.TensorStructInfo((6, 4), "float32")) - _check_inference(bb, relax.op.take(y1, idx4), relax.TensorStructInfo(dtype="float32", ndim=2)) - _check_inference(bb, relax.op.take(y2, idx4), relax.TensorStructInfo((6, 4), dtype="")) - _check_inference(bb, relax.op.take(y3, idx4), relax.TensorStructInfo(dtype="", ndim=2)) - _check_inference(bb, relax.op.take(y0, idx5), relax.TensorStructInfo(dtype="float32", ndim=2)) - _check_inference(bb, relax.op.take(y1, idx5), relax.TensorStructInfo(dtype="float32", ndim=2)) - _check_inference(bb, relax.op.take(y2, idx5), relax.TensorStructInfo(dtype="", ndim=2)) - _check_inference(bb, relax.op.take(y3, idx5), relax.TensorStructInfo(dtype="", ndim=2)) - _check_inference(bb, relax.op.take(y0, idx6), relax.TensorStructInfo((6, 4), "float32")) - _check_inference(bb, relax.op.take(y1, idx6), relax.TensorStructInfo(dtype="float32", ndim=2)) - _check_inference(bb, relax.op.take(y2, idx6), relax.TensorStructInfo((6, 4), dtype="")) - _check_inference(bb, relax.op.take(y3, idx6), relax.TensorStructInfo(dtype="", ndim=2)) - _check_inference(bb, relax.op.take(y0, idx7), relax.TensorStructInfo(dtype="float32", ndim=2)) - _check_inference(bb, relax.op.take(y1, idx7), relax.TensorStructInfo(dtype="float32", ndim=2)) - _check_inference(bb, relax.op.take(y2, idx7), relax.TensorStructInfo(dtype="", ndim=2)) - _check_inference(bb, relax.op.take(y3, idx7), relax.TensorStructInfo(dtype="", ndim=2)) - - -def test_take_infer_struct_info_scalar_tensor_index(): + _check_inference(bb, relax.op.take(x0, idx0, axis=1), relax.TensorType((4, 6), "float32")) + _check_inference( + bb, relax.op.take(x6, idx8, axis=1), relax.TensorType((4, 6), "float32", vdev0) + ) + _check_inference(bb, relax.op.take(x0, idx0, axis=-1), relax.TensorType((4, 6), "float32")) + _check_inference(bb, relax.op.take(x1, idx0, axis=1), relax.TensorType(dtype="float32", ndim=2)) + _check_inference(bb, relax.op.take(x2, idx0, axis=1), relax.TensorType(dtype="float32")) + _check_inference(bb, relax.op.take(x3, idx0, axis=1), relax.TensorType((4, 6), dtype="")) + _check_inference(bb, relax.op.take(x4, idx0, axis=1), relax.TensorType(dtype="", ndim=2)) + _check_inference(bb, relax.op.take(x5, idx0, axis=1), relax.TensorType(dtype="")) + _check_inference(bb, relax.op.take(x0, idx1, axis=1), relax.TensorType(dtype="float32", ndim=2)) + _check_inference(bb, relax.op.take(x1, idx1, axis=1), relax.TensorType(dtype="float32", ndim=2)) + _check_inference(bb, relax.op.take(x2, idx1, axis=1), relax.TensorType(dtype="float32")) + _check_inference(bb, relax.op.take(x3, idx1, axis=1), relax.TensorType(dtype="", ndim=2)) + _check_inference(bb, relax.op.take(x4, idx1, axis=1), relax.TensorType(dtype="", ndim=2)) + _check_inference(bb, relax.op.take(x5, idx1, axis=1), relax.TensorType(dtype="")) + _check_inference(bb, relax.op.take(x0, idx2, axis=1), relax.TensorType((4, 6), "float32")) + _check_inference(bb, relax.op.take(x1, idx2, axis=1), relax.TensorType(dtype="float32", ndim=2)) + _check_inference(bb, relax.op.take(x2, idx2, axis=1), relax.TensorType(dtype="float32")) + _check_inference(bb, relax.op.take(x3, idx2, axis=1), relax.TensorType((4, 6), dtype="")) + _check_inference(bb, relax.op.take(x4, idx2, axis=1), relax.TensorType(dtype="", ndim=2)) + _check_inference(bb, relax.op.take(x5, idx2, axis=1), relax.TensorType(dtype="")) + _check_inference(bb, relax.op.take(x0, idx3, axis=1), relax.TensorType(dtype="float32", ndim=2)) + _check_inference(bb, relax.op.take(x1, idx3, axis=1), relax.TensorType(dtype="float32", ndim=2)) + _check_inference(bb, relax.op.take(x2, idx3, axis=1), relax.TensorType(dtype="float32")) + _check_inference(bb, relax.op.take(x3, idx3, axis=1), relax.TensorType(dtype="", ndim=2)) + _check_inference(bb, relax.op.take(x4, idx3, axis=1), relax.TensorType(dtype="", ndim=2)) + _check_inference(bb, relax.op.take(x5, idx3, axis=1), relax.TensorType(dtype="")) + _check_inference( + bb, relax.op.take(x0, idx4, axis=0), relax.TensorType((6, 4, 10), dtype="float32") + ) + _check_inference( + bb, relax.op.take(x0, idx4, axis=1), relax.TensorType((4, 6, 4), dtype="float32") + ) + _check_inference(bb, relax.op.take(x1, idx4, axis=1), relax.TensorType(dtype="float32", ndim=3)) + _check_inference(bb, relax.op.take(x2, idx4, axis=1), relax.TensorType(dtype="float32")) + _check_inference(bb, relax.op.take(x3, idx4, axis=1), relax.TensorType((4, 6, 4), dtype="")) + _check_inference(bb, relax.op.take(x4, idx4, axis=1), relax.TensorType(dtype="", ndim=3)) + _check_inference(bb, relax.op.take(x5, idx4, axis=1), relax.TensorType(dtype="")) + _check_inference(bb, relax.op.take(x0, idx5, axis=0), relax.TensorType(dtype="float32", ndim=3)) + _check_inference(bb, relax.op.take(x0, idx5, axis=1), relax.TensorType(dtype="float32", ndim=3)) + _check_inference(bb, relax.op.take(x1, idx5, axis=1), relax.TensorType(dtype="float32", ndim=3)) + _check_inference(bb, relax.op.take(x2, idx5, axis=1), relax.TensorType(dtype="float32")) + _check_inference(bb, relax.op.take(x3, idx5, axis=1), relax.TensorType(dtype="", ndim=3)) + _check_inference(bb, relax.op.take(x4, idx5, axis=1), relax.TensorType(dtype="", ndim=3)) + _check_inference(bb, relax.op.take(x5, idx5, axis=1), relax.TensorType(dtype="")) + _check_inference( + bb, relax.op.take(x0, idx6, axis=0), relax.TensorType((6, 4, 10), dtype="float32") + ) + _check_inference( + bb, relax.op.take(x0, idx6, axis=1), relax.TensorType((4, 6, 4), dtype="float32") + ) + _check_inference(bb, relax.op.take(x1, idx6, axis=1), relax.TensorType(dtype="float32", ndim=3)) + _check_inference(bb, relax.op.take(x2, idx6, axis=1), relax.TensorType(dtype="float32")) + _check_inference(bb, relax.op.take(x3, idx6, axis=1), relax.TensorType((4, 6, 4), dtype="")) + _check_inference(bb, relax.op.take(x4, idx6, axis=1), relax.TensorType(dtype="", ndim=3)) + _check_inference(bb, relax.op.take(x5, idx6, axis=1), relax.TensorType(dtype="")) + _check_inference(bb, relax.op.take(x0, idx7, axis=0), relax.TensorType(dtype="float32", ndim=3)) + _check_inference(bb, relax.op.take(x0, idx7, axis=1), relax.TensorType(dtype="float32", ndim=3)) + _check_inference(bb, relax.op.take(x1, idx7, axis=1), relax.TensorType(dtype="float32", ndim=3)) + _check_inference(bb, relax.op.take(x2, idx7, axis=1), relax.TensorType(dtype="float32")) + _check_inference(bb, relax.op.take(x3, idx7, axis=1), relax.TensorType(dtype="", ndim=3)) + _check_inference(bb, relax.op.take(x4, idx7, axis=1), relax.TensorType(dtype="", ndim=3)) + _check_inference(bb, relax.op.take(x5, idx7, axis=1), relax.TensorType(dtype="")) + _check_inference(bb, relax.op.take(y0, idx0), relax.TensorType((6,), "float32")) + _check_inference(bb, relax.op.take(y1, idx0), relax.TensorType(dtype="float32", ndim=1)) + _check_inference(bb, relax.op.take(y2, idx0), relax.TensorType((6,), dtype="")) + _check_inference(bb, relax.op.take(y3, idx0), relax.TensorType(dtype="", ndim=1)) + _check_inference(bb, relax.op.take(y0, idx1), relax.TensorType(dtype="float32", ndim=1)) + _check_inference(bb, relax.op.take(y1, idx1), relax.TensorType(dtype="float32", ndim=1)) + _check_inference(bb, relax.op.take(y2, idx1), relax.TensorType(dtype="", ndim=1)) + _check_inference(bb, relax.op.take(y3, idx1), relax.TensorType(dtype="", ndim=1)) + _check_inference(bb, relax.op.take(y0, idx2), relax.TensorType((6,), "float32")) + _check_inference(bb, relax.op.take(y1, idx2), relax.TensorType(dtype="float32", ndim=1)) + _check_inference(bb, relax.op.take(y2, idx2), relax.TensorType((6,), dtype="")) + _check_inference(bb, relax.op.take(y3, idx2), relax.TensorType(dtype="", ndim=1)) + _check_inference(bb, relax.op.take(y0, idx3), relax.TensorType(dtype="float32", ndim=1)) + _check_inference(bb, relax.op.take(y1, idx3), relax.TensorType(dtype="float32", ndim=1)) + _check_inference(bb, relax.op.take(y2, idx3), relax.TensorType(dtype="", ndim=1)) + _check_inference(bb, relax.op.take(y3, idx3), relax.TensorType(dtype="", ndim=1)) + _check_inference(bb, relax.op.take(y0, idx4), relax.TensorType((6, 4), "float32")) + _check_inference(bb, relax.op.take(y1, idx4), relax.TensorType(dtype="float32", ndim=2)) + _check_inference(bb, relax.op.take(y2, idx4), relax.TensorType((6, 4), dtype="")) + _check_inference(bb, relax.op.take(y3, idx4), relax.TensorType(dtype="", ndim=2)) + _check_inference(bb, relax.op.take(y0, idx5), relax.TensorType(dtype="float32", ndim=2)) + _check_inference(bb, relax.op.take(y1, idx5), relax.TensorType(dtype="float32", ndim=2)) + _check_inference(bb, relax.op.take(y2, idx5), relax.TensorType(dtype="", ndim=2)) + _check_inference(bb, relax.op.take(y3, idx5), relax.TensorType(dtype="", ndim=2)) + _check_inference(bb, relax.op.take(y0, idx6), relax.TensorType((6, 4), "float32")) + _check_inference(bb, relax.op.take(y1, idx6), relax.TensorType(dtype="float32", ndim=2)) + _check_inference(bb, relax.op.take(y2, idx6), relax.TensorType((6, 4), dtype="")) + _check_inference(bb, relax.op.take(y3, idx6), relax.TensorType(dtype="", ndim=2)) + _check_inference(bb, relax.op.take(y0, idx7), relax.TensorType(dtype="float32", ndim=2)) + _check_inference(bb, relax.op.take(y1, idx7), relax.TensorType(dtype="float32", ndim=2)) + _check_inference(bb, relax.op.take(y2, idx7), relax.TensorType(dtype="", ndim=2)) + _check_inference(bb, relax.op.take(y3, idx7), relax.TensorType(dtype="", ndim=2)) + + +def test_take_infer_ty_scalar_tensor_index(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((4, 10), "float32")) idx = relax.Var("idx", R.Tensor([], "int64")) - _check_inference(bb, relax.op.take(x0, idx, axis=0), relax.TensorStructInfo([10], "float32")) - _check_inference(bb, relax.op.take(x0, idx, axis=1), relax.TensorStructInfo([4], "float32")) + _check_inference(bb, relax.op.take(x0, idx, axis=0), relax.TensorType([10], "float32")) + _check_inference(bb, relax.op.take(x0, idx, axis=1), relax.TensorType([4], "float32")) -def test_take_infer_struct_info_prim_value_index(): +def test_take_infer_ty_prim_value_index(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((4, 10), "float32")) idx = relax.Var("idx", R.Prim("int64")) - _check_inference(bb, relax.op.take(x0, idx, axis=0), relax.TensorStructInfo([10], "float32")) - _check_inference(bb, relax.op.take(x0, idx, axis=1), relax.TensorStructInfo([4], "float32")) + _check_inference(bb, relax.op.take(x0, idx, axis=0), relax.TensorType([10], "float32")) + _check_inference(bb, relax.op.take(x0, idx, axis=1), relax.TensorType([4], "float32")) -def test_take_infer_struct_info_shape_symbolic(): +def test_take_infer_ty_shape_symbolic(): bb = relax.BlockBuilder() m = tirx.Var("m", "int64") n = tirx.Var("n", "int64") @@ -241,69 +207,49 @@ def test_take_infer_struct_info_shape_symbolic(): ), ) - _check_inference(bb, relax.op.take(x0, idx0, axis=1), relax.TensorStructInfo((m, i), "float32")) - _check_inference(bb, relax.op.take(x1, idx0, axis=1), relax.TensorStructInfo((m, i), dtype="")) - _check_inference(bb, relax.op.take(x0, idx1, axis=1), relax.TensorStructInfo((m, i), "float32")) - _check_inference(bb, relax.op.take(x1, idx1, axis=1), relax.TensorStructInfo((m, i), dtype="")) - _check_inference( - bb, relax.op.take(x1, idx2, axis=1), relax.TensorStructInfo((m, i, j, k), dtype="") - ) - _check_inference( - bb, relax.op.take(x1, idx2, axis=1), relax.TensorStructInfo((m, i, j, k), dtype="") - ) - _check_inference(bb, relax.op.take(y0, idx0), relax.TensorStructInfo((i,), "float32")) - _check_inference(bb, relax.op.take(y1, idx0), relax.TensorStructInfo((i,), dtype="")) - _check_inference(bb, relax.op.take(y0, idx1), relax.TensorStructInfo((i,), "float32")) - _check_inference(bb, relax.op.take(y1, idx1), relax.TensorStructInfo((i,), dtype="")) - _check_inference(bb, relax.op.take(y0, idx2), relax.TensorStructInfo((i, j, k), "float32")) - _check_inference(bb, relax.op.take(y1, idx2), relax.TensorStructInfo((i, j, k), dtype="")) + _check_inference(bb, relax.op.take(x0, idx0, axis=1), relax.TensorType((m, i), "float32")) + _check_inference(bb, relax.op.take(x1, idx0, axis=1), relax.TensorType((m, i), dtype="")) + _check_inference(bb, relax.op.take(x0, idx1, axis=1), relax.TensorType((m, i), "float32")) + _check_inference(bb, relax.op.take(x1, idx1, axis=1), relax.TensorType((m, i), dtype="")) + _check_inference(bb, relax.op.take(x1, idx2, axis=1), relax.TensorType((m, i, j, k), dtype="")) + _check_inference(bb, relax.op.take(x1, idx2, axis=1), relax.TensorType((m, i, j, k), dtype="")) + _check_inference(bb, relax.op.take(y0, idx0), relax.TensorType((i,), "float32")) + _check_inference(bb, relax.op.take(y1, idx0), relax.TensorType((i,), dtype="")) + _check_inference(bb, relax.op.take(y0, idx1), relax.TensorType((i,), "float32")) + _check_inference(bb, relax.op.take(y1, idx1), relax.TensorType((i,), dtype="")) + _check_inference(bb, relax.op.take(y0, idx2), relax.TensorType((i, j, k), "float32")) + _check_inference(bb, relax.op.take(y1, idx2), relax.TensorType((i, j, k), dtype="")) -def test_take_infer_struct_info_shape_var(): +def test_take_infer_ty_shape_var(): bb = relax.BlockBuilder() - sx0 = relax.Var("sx", relax.ShapeStructInfo((4, 10))) - sx1 = relax.Var("sx", relax.ShapeStructInfo(ndim=2)) - sx2 = relax.Var("sx", relax.ShapeStructInfo()) - sidx0 = relax.Var("sidx", relax.ShapeStructInfo((6,))) - sidx1 = relax.Var("sidx", relax.ShapeStructInfo(ndim=1)) - x0 = relax.Var("x", relax.TensorStructInfo(sx0, "float32")) - x1 = relax.Var("x", relax.TensorStructInfo(sx1, "float32")) - x2 = relax.Var("x", relax.TensorStructInfo(sx2, "float32")) + sx0 = relax.Var("sx", relax.ShapeType((4, 10))) + sx1 = relax.Var("sx", relax.ShapeType(ndim=2)) + sx2 = relax.Var("sx", relax.ShapeType()) + sidx0 = relax.Var("sidx", relax.ShapeType((6,))) + sidx1 = relax.Var("sidx", relax.ShapeType(ndim=1)) + x0 = relax.Var("x", relax.TensorType(sx0, "float32")) + x1 = relax.Var("x", relax.TensorType(sx1, "float32")) + x2 = relax.Var("x", relax.TensorType(sx2, "float32")) x3 = relax.Var("x", R.Tensor((4, 10), "float32")) - idx0 = relax.Var("idx", relax.TensorStructInfo(sidx0, "int64")) - idx1 = relax.Var("idx", relax.TensorStructInfo(sidx1, "int64")) + idx0 = relax.Var("idx", relax.TensorType(sidx0, "int64")) + idx1 = relax.Var("idx", relax.TensorType(sidx1, "int64")) idx2 = relax.Var("idx", R.Tensor((6,), "int64")) - _check_inference( - bb, relax.op.take(x0, idx0, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2) - ) - _check_inference( - bb, relax.op.take(x0, idx1, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2) - ) - _check_inference( - bb, relax.op.take(x0, idx2, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2) - ) - _check_inference( - bb, relax.op.take(x1, idx0, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2) - ) - _check_inference( - bb, relax.op.take(x1, idx1, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2) - ) - _check_inference( - bb, relax.op.take(x1, idx2, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2) - ) - _check_inference(bb, relax.op.take(x2, idx0, axis=1), relax.TensorStructInfo(dtype="float32")) - _check_inference(bb, relax.op.take(x2, idx1, axis=1), relax.TensorStructInfo(dtype="float32")) - _check_inference(bb, relax.op.take(x2, idx2, axis=1), relax.TensorStructInfo(dtype="float32")) - _check_inference( - bb, relax.op.take(x3, idx0, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2) - ) - _check_inference( - bb, relax.op.take(x3, idx1, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2) - ) + _check_inference(bb, relax.op.take(x0, idx0, axis=1), relax.TensorType(dtype="float32", ndim=2)) + _check_inference(bb, relax.op.take(x0, idx1, axis=1), relax.TensorType(dtype="float32", ndim=2)) + _check_inference(bb, relax.op.take(x0, idx2, axis=1), relax.TensorType(dtype="float32", ndim=2)) + _check_inference(bb, relax.op.take(x1, idx0, axis=1), relax.TensorType(dtype="float32", ndim=2)) + _check_inference(bb, relax.op.take(x1, idx1, axis=1), relax.TensorType(dtype="float32", ndim=2)) + _check_inference(bb, relax.op.take(x1, idx2, axis=1), relax.TensorType(dtype="float32", ndim=2)) + _check_inference(bb, relax.op.take(x2, idx0, axis=1), relax.TensorType(dtype="float32")) + _check_inference(bb, relax.op.take(x2, idx1, axis=1), relax.TensorType(dtype="float32")) + _check_inference(bb, relax.op.take(x2, idx2, axis=1), relax.TensorType(dtype="float32")) + _check_inference(bb, relax.op.take(x3, idx0, axis=1), relax.TensorType(dtype="float32", ndim=2)) + _check_inference(bb, relax.op.take(x3, idx1, axis=1), relax.TensorType(dtype="float32", ndim=2)) -def test_take_infer_struct_info_more_input_dtype(): +def test_take_infer_ty_more_input_dtype(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((4, 10), "float16")) x1 = relax.Var("x", R.Tensor((4, 10), "int16")) @@ -312,18 +258,18 @@ def test_take_infer_struct_info_more_input_dtype(): idx1 = relax.Var("idx", R.Tensor((6,), "int8")) idx2 = relax.Var("idx", R.Tensor((6,), "uint32")) - _check_inference(bb, relax.op.take(x0, idx0, axis=1), relax.TensorStructInfo((4, 6), "float16")) - _check_inference(bb, relax.op.take(x1, idx0, axis=1), relax.TensorStructInfo((4, 6), "int16")) - _check_inference(bb, relax.op.take(x2, idx0, axis=1), relax.TensorStructInfo((4, 6), "int32")) - _check_inference(bb, relax.op.take(x0, idx1, axis=1), relax.TensorStructInfo((4, 6), "float16")) - _check_inference(bb, relax.op.take(x1, idx1, axis=1), relax.TensorStructInfo((4, 6), "int16")) - _check_inference(bb, relax.op.take(x2, idx1, axis=1), relax.TensorStructInfo((4, 6), "int32")) - _check_inference(bb, relax.op.take(x0, idx2, axis=1), relax.TensorStructInfo((4, 6), "float16")) - _check_inference(bb, relax.op.take(x1, idx2, axis=1), relax.TensorStructInfo((4, 6), "int16")) - _check_inference(bb, relax.op.take(x2, idx2, axis=1), relax.TensorStructInfo((4, 6), "int32")) + _check_inference(bb, relax.op.take(x0, idx0, axis=1), relax.TensorType((4, 6), "float16")) + _check_inference(bb, relax.op.take(x1, idx0, axis=1), relax.TensorType((4, 6), "int16")) + _check_inference(bb, relax.op.take(x2, idx0, axis=1), relax.TensorType((4, 6), "int32")) + _check_inference(bb, relax.op.take(x0, idx1, axis=1), relax.TensorType((4, 6), "float16")) + _check_inference(bb, relax.op.take(x1, idx1, axis=1), relax.TensorType((4, 6), "int16")) + _check_inference(bb, relax.op.take(x2, idx1, axis=1), relax.TensorType((4, 6), "int32")) + _check_inference(bb, relax.op.take(x0, idx2, axis=1), relax.TensorType((4, 6), "float16")) + _check_inference(bb, relax.op.take(x1, idx2, axis=1), relax.TensorType((4, 6), "int16")) + _check_inference(bb, relax.op.take(x2, idx2, axis=1), relax.TensorType((4, 6), "int32")) -def test_take_infer_struct_info_indices_not_integer_dtype(): +def test_take_infer_ty_indices_not_integer_dtype(): bb = relax.BlockBuilder() x = relax.Var("x", R.Tensor((4, 10), "float32")) idx0 = relax.Var("idx", R.Tensor((6, 6), "float32")) @@ -335,7 +281,7 @@ def test_take_infer_struct_info_indices_not_integer_dtype(): bb.normalize(relax.op.take(x, idx1, axis=1)) -def test_take_infer_struct_info_multi_dimensional_without_axis(): +def test_take_infer_ty_multi_dimensional_without_axis(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((4, 10), "float32")) x1 = relax.Var("x", R.Tensor("float32", ndim=2)) @@ -357,7 +303,7 @@ def test_take_infer_struct_info_multi_dimensional_without_axis(): bb.normalize(relax.op.take(x2, idx1)) -def test_take_infer_struct_info_axis_out_of_range(): +def test_take_infer_ty_axis_out_of_range(): bb = relax.BlockBuilder() x = relax.Var("x", R.Tensor((4, 10), "float32")) idx = relax.Var("idx", R.Tensor((6,), "int64")) @@ -368,11 +314,11 @@ def test_take_infer_struct_info_axis_out_of_range(): bb.normalize(relax.op.take(x, idx, axis=2)) -def test_take_infer_struct_info_wrong_input_type(): +def test_take_infer_ty_wrong_input_type(): bb = relax.BlockBuilder() - x0 = relax.Var("x", relax.ShapeStructInfo((4, 10))) + x0 = relax.Var("x", relax.ShapeType((4, 10))) x1 = relax.Var("x", R.Tensor((4, 10), "float32")) - idx0 = relax.Var("idx", relax.ShapeStructInfo((6,))) + idx0 = relax.Var("idx", relax.ShapeType((6,))) idx1 = relax.Var("idx", R.Tensor((6,), "int64")) with pytest.raises(TypeError): @@ -381,7 +327,7 @@ def test_take_infer_struct_info_wrong_input_type(): bb.normalize(relax.op.take(x1, idx0, axis=1)) -def test_strided_slice_infer_struct_info(): +def test_strided_slice_infer_ty(): bb = relax.BlockBuilder() vdev0 = VDevice("llvm") x0 = relax.Var("x", R.Tensor((8, 9, 10, 10), "float32")) @@ -397,65 +343,65 @@ def test_strided_slice_infer_struct_info(): relax.op.strided_slice( x0, axes=[0, 1, 3], begin=[1, 0, 8], end=[8, 9, 0], strides=[2, 1, -3] ), - relax.TensorStructInfo((4, 9, 10, 3), "float32"), + relax.TensorType((4, 9, 10, 3), "float32"), ) _check_inference( bb, relax.op.strided_slice( x6, axes=[0, 1, 3], begin=[1, 0, 8], end=[8, 9, 0], strides=[2, 1, -3] ), - relax.TensorStructInfo((4, 9, 10, 3), "float32", vdev0), + relax.TensorType((4, 9, 10, 3), "float32", vdev0), ) _check_inference( bb, relax.op.strided_slice( x1, axes=[0, 1, 3], begin=[1, 0, 8], end=[8, 9, 0], strides=[2, 1, -3] ), - relax.TensorStructInfo(dtype="float32", ndim=4), + relax.TensorType(dtype="float32", ndim=4), ) _check_inference( bb, relax.op.strided_slice( x2, axes=[0, 1, 3], begin=[1, 0, 8], end=[8, 9, 0], strides=[2, 1, -3] ), - relax.TensorStructInfo(dtype="float32"), + relax.TensorType(dtype="float32"), ) _check_inference( bb, relax.op.strided_slice( x3, axes=[0, 1, 3], begin=[1, 0, 8], end=[8, 9, 0], strides=[2, 1, -3] ), - relax.TensorStructInfo((4, 9, 10, 3), dtype=""), + relax.TensorType((4, 9, 10, 3), dtype=""), ) _check_inference( bb, relax.op.strided_slice( x4, axes=[0, 1, 3], begin=[1, 0, 8], end=[8, 9, 0], strides=[2, 1, -3] ), - relax.TensorStructInfo(dtype="", ndim=4), + relax.TensorType(dtype="", ndim=4), ) _check_inference( bb, relax.op.strided_slice( x5, axes=[0, 1, 3], begin=[1, 0, 8], end=[8, 9, 0], strides=[2, 1, -3] ), - relax.TensorStructInfo(dtype=""), + relax.TensorType(dtype=""), ) _check_inference( bb, relax.op.strided_slice( x0, axes=[-1, -3, -4], begin=[8, 0, 1], end=[0, 9, 8], strides=[-3, 1, 2] ), - relax.TensorStructInfo((4, 9, 10, 3), "float32"), + relax.TensorType((4, 9, 10, 3), "float32"), ) _check_inference( bb, relax.op.strided_slice(x0, axes=[1, 2], begin=[1, 0], end=[8, 9]), - relax.TensorStructInfo((8, 7, 9, 10), "float32"), + relax.TensorType((8, 7, 9, 10), "float32"), ) -def test_strided_slice_infer_struct_info_shape_out_of_range(): +def test_strided_slice_infer_ty_shape_out_of_range(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((20, 10, 5), "float32")) _check_inference( @@ -463,32 +409,32 @@ def test_strided_slice_infer_struct_info_shape_out_of_range(): relax.op.strided_slice( x0, axes=[0, 1, 2], begin=[20, 10, 4], end=[0, 0, 1], strides=[-1, -3, -2] ), - relax.TensorStructInfo((19, 3, 2), "float32"), + relax.TensorType((19, 3, 2), "float32"), ) _check_inference( bb, relax.op.strided_slice( x0, axes=[0, 1, 2], begin=[200, 10, 4], end=[0, 0, 1], strides=[-1, -3, -2] ), - relax.TensorStructInfo((19, 3, 2), "float32"), + relax.TensorType((19, 3, 2), "float32"), ) _check_inference( bb, relax.op.strided_slice( x0, axes=[0, 1, 2], begin=[200, 10, 100], end=[0, 0, 1], strides=[-1, -3, -5] ), - relax.TensorStructInfo((19, 3, 1), "float32"), + relax.TensorType((19, 3, 1), "float32"), ) _check_inference( bb, relax.op.strided_slice( x0, axes=[0, 1, 2], begin=[-21, -11, -6], end=[1, 1, 1], strides=[1000, 1000, 1000] ), - relax.TensorStructInfo((1, 1, 1), "float32"), + relax.TensorType((1, 1, 1), "float32"), ) -def test_strided_slice_infer_struct_info_shape_symbolic(): +def test_strided_slice_infer_ty_shape_symbolic(): bb = relax.BlockBuilder() m = tirx.Var("m", "int64") n = tirx.Var("n", "int64") @@ -498,70 +444,70 @@ def test_strided_slice_infer_struct_info_shape_symbolic(): _check_inference( bb, relax.op.strided_slice(x0, axes=[0], begin=[1], end=[3]), - relax.TensorStructInfo((tirx.min(3, m) - tirx.min(1, m), n), "float32"), + relax.TensorType((tirx.min(3, m) - tirx.min(1, m), n), "float32"), ) _check_inference( bb, relax.op.strided_slice(x0, axes=[0], begin=[1], end=[8], strides=[3]), - relax.TensorStructInfo(((tirx.min(8, m) + 2 - tirx.min(1, m)) // 3, n), "float32"), + relax.TensorType(((tirx.min(8, m) + 2 - tirx.min(1, m)) // 3, n), "float32"), ) _check_inference( bb, relax.op.strided_slice(x1, axes=[0], begin=[1], end=[3]), - relax.TensorStructInfo((tirx.min(3, m) - tirx.min(1, m), n), dtype=""), + relax.TensorType((tirx.min(3, m) - tirx.min(1, m), n), dtype=""), ) _check_inference( bb, relax.op.strided_slice(x1, axes=[0], begin=[1], end=[8], strides=[3]), - relax.TensorStructInfo(((tirx.min(8, m) + 2 - tirx.min(1, m)) // 3, n), dtype=""), + relax.TensorType(((tirx.min(8, m) + 2 - tirx.min(1, m)) // 3, n), dtype=""), ) -def test_strided_slice_infer_struct_info_shape_var(): +def test_strided_slice_infer_ty_shape_var(): bb = relax.BlockBuilder() - s0 = relax.Var("s", relax.ShapeStructInfo((8, 10))) - s1 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) - s2 = relax.Var("s", relax.ShapeStructInfo()) - x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) - x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) - x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) - x3 = relax.Var("x", relax.TensorStructInfo(s0, dtype="")) - x4 = relax.Var("x", relax.TensorStructInfo(s1, dtype="")) - x5 = relax.Var("x", relax.TensorStructInfo(s2, dtype="")) + s0 = relax.Var("s", relax.ShapeType((8, 10))) + s1 = relax.Var("s", relax.ShapeType(ndim=2)) + s2 = relax.Var("s", relax.ShapeType()) + x0 = relax.Var("x", relax.TensorType(s0, "float32")) + x1 = relax.Var("x", relax.TensorType(s1, "float32")) + x2 = relax.Var("x", relax.TensorType(s2, "float32")) + x3 = relax.Var("x", relax.TensorType(s0, dtype="")) + x4 = relax.Var("x", relax.TensorType(s1, dtype="")) + x5 = relax.Var("x", relax.TensorType(s2, dtype="")) _check_inference( bb, relax.op.strided_slice(x0, axes=[0], begin=[0], end=[8]), - relax.TensorStructInfo(shape=[8, 10], dtype="float32"), + relax.TensorType(shape=[8, 10], dtype="float32"), ) _check_inference( bb, relax.op.strided_slice(x1, axes=[0], begin=[0], end=[8]), - relax.TensorStructInfo(dtype="float32", ndim=2), + relax.TensorType(dtype="float32", ndim=2), ) _check_inference( bb, relax.op.strided_slice(x2, axes=[0], begin=[0], end=[8]), - relax.TensorStructInfo(dtype="float32"), + relax.TensorType(dtype="float32"), ) _check_inference( bb, relax.op.strided_slice(x3, axes=[0], begin=[0], end=[8]), - relax.TensorStructInfo(shape=[8, 10], dtype=""), + relax.TensorType(shape=[8, 10], dtype=""), ) _check_inference( bb, relax.op.strided_slice(x4, axes=[0], begin=[0], end=[8]), - relax.TensorStructInfo(dtype="", ndim=2), + relax.TensorType(dtype="", ndim=2), ) _check_inference( bb, relax.op.strided_slice(x5, axes=[0], begin=[0], end=[8]), - relax.TensorStructInfo(dtype=""), + relax.TensorType(dtype=""), ) -def test_strided_slice_infer_struct_info_more_input_dtype(): +def test_strided_slice_infer_ty_more_input_dtype(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((8, 9), "float16")) x1 = relax.Var("x", R.Tensor((8, 9), "int32")) @@ -570,21 +516,21 @@ def test_strided_slice_infer_struct_info_more_input_dtype(): _check_inference( bb, relax.op.strided_slice(x0, axes=[0], begin=[0], end=[8]), - relax.TensorStructInfo((8, 9), "float16"), + relax.TensorType((8, 9), "float16"), ) _check_inference( bb, relax.op.strided_slice(x1, axes=[0], begin=[0], end=[8]), - relax.TensorStructInfo((8, 9), "int32"), + relax.TensorType((8, 9), "int32"), ) _check_inference( bb, relax.op.strided_slice(x2, axes=[0], begin=[0], end=[8]), - relax.TensorStructInfo((8, 9), "int64"), + relax.TensorType((8, 9), "int64"), ) -def test_strided_slice_infer_struct_info_symbolic_begin_end_strides(): +def test_strided_slice_infer_ty_symbolic_begin_end_strides(): bb = relax.BlockBuilder() var = tirx.Var("var", "int64") size_var = tirx.SizeVar("size_var", "int64") @@ -593,7 +539,7 @@ def test_strided_slice_infer_struct_info_symbolic_begin_end_strides(): _check_inference( bb, relax.op.strided_slice(x, axes=[0], begin=[var], end=[8]), - relax.TensorStructInfo( + relax.TensorType( (tirx.max(8 - tirx.max(tirx.if_then_else(var < 0, var + 8, var), 0), 0), 9), dtype="float32", ), @@ -601,24 +547,24 @@ def test_strided_slice_infer_struct_info_symbolic_begin_end_strides(): _check_inference( bb, relax.op.strided_slice(x, axes=[0], begin=[size_var], end=[8]), - relax.TensorStructInfo((tirx.max(8 - size_var, 0), 9), dtype="float32"), + relax.TensorType((tirx.max(8 - size_var, 0), 9), dtype="float32"), ) _check_inference( bb, relax.op.strided_slice(x, axes=[0], begin=[0], end=[var]), - relax.TensorStructInfo( + relax.TensorType( (tirx.min(tirx.max(tirx.if_then_else(var < 0, var + 8, var), 0), 8), 9), dtype="float32" ), ) _check_inference( bb, relax.op.strided_slice(x, axes=[0], begin=[0], end=[size_var]), - relax.TensorStructInfo((tirx.min(size_var, 8), 9), dtype="float32"), + relax.TensorType((tirx.min(size_var, 8), 9), dtype="float32"), ) _check_inference( bb, relax.op.strided_slice(x, axes=[0], begin=[0], end=[8], strides=[var]), - relax.TensorStructInfo( + relax.TensorType( [tirx.if_then_else(var < 0, -8 // (0 - var) + 1, (var + 7) // var), 9], dtype="float32", ), @@ -626,11 +572,11 @@ def test_strided_slice_infer_struct_info_symbolic_begin_end_strides(): _check_inference( bb, relax.op.strided_slice(x, axes=[0], begin=[0], end=[8], strides=[size_var]), - relax.TensorStructInfo([7 // size_var + 1, 9], dtype="float32"), + relax.TensorType([7 // size_var + 1, 9], dtype="float32"), ) -def test_strided_slice_infer_struct_info_symbolic_begin_end_strides_inbound(): +def test_strided_slice_infer_ty_symbolic_begin_end_strides_inbound(): bb = relax.BlockBuilder() var = tirx.Var("var", "int64") size_var = tirx.SizeVar("size_var", "int64") @@ -639,7 +585,7 @@ def test_strided_slice_infer_struct_info_symbolic_begin_end_strides_inbound(): _check_inference( bb, relax.op.strided_slice(x, axes=[0], begin=[var], end=[8], assume_inbound=True), - relax.TensorStructInfo( + relax.TensorType( (8 - var, 9), dtype="float32", ), @@ -647,73 +593,73 @@ def test_strided_slice_infer_struct_info_symbolic_begin_end_strides_inbound(): _check_inference( bb, relax.op.strided_slice(x, axes=[0], begin=[size_var], end=[8], assume_inbound=True), - relax.TensorStructInfo((8 - size_var, 9), dtype="float32"), + relax.TensorType((8 - size_var, 9), dtype="float32"), ) _check_inference( bb, relax.op.strided_slice(x, axes=[0], begin=[0], end=[var], assume_inbound=True), - relax.TensorStructInfo((var, 9), dtype="float32"), + relax.TensorType((var, 9), dtype="float32"), ) _check_inference( bb, relax.op.strided_slice(x, axes=[0], begin=[0], end=[size_var], assume_inbound=True), - relax.TensorStructInfo((size_var, 9), dtype="float32"), + relax.TensorType((size_var, 9), dtype="float32"), ) _check_inference( bb, relax.op.strided_slice(x, axes=[0], begin=[0], end=[8], strides=[var], assume_inbound=True), - relax.TensorStructInfo([(var + 7) // var, 9], dtype="float32"), + relax.TensorType([(var + 7) // var, 9], dtype="float32"), ) _check_inference( bb, relax.op.strided_slice(x, axes=[0], begin=[0], end=[8], strides=[var], assume_inbound=True), - relax.TensorStructInfo([(var + 7) // var, 9], dtype="float32"), + relax.TensorType([(var + 7) // var, 9], dtype="float32"), ) -def test_strided_slice_infer_struct_info_no_axis(): +def test_strided_slice_infer_ty_no_axis(): bb = relax.BlockBuilder() m = tirx.Var("m", "int64") n = tirx.Var("n", "int64") - s0 = relax.Var("s", relax.ShapeStructInfo((m, n))) - s1 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) - s2 = relax.Var("s", relax.ShapeStructInfo()) + s0 = relax.Var("s", relax.ShapeType((m, n))) + s1 = relax.Var("s", relax.ShapeType(ndim=2)) + s2 = relax.Var("s", relax.ShapeType()) x0 = relax.Var("x", R.Tensor((m, n), "float32")) x1 = relax.Var("x", R.Tensor(dtype="float32", ndim=2)) x2 = relax.Var("x", R.Tensor(dtype="float32")) - x3 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) - x4 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) - x5 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + x3 = relax.Var("x", relax.TensorType(s0, "float32")) + x4 = relax.Var("x", relax.TensorType(s1, "float32")) + x5 = relax.Var("x", relax.TensorType(s2, "float32")) _check_inference( bb, relax.op.strided_slice(x0, axes=[], begin=[], end=[]), - relax.TensorStructInfo((m, n), "float32"), + relax.TensorType((m, n), "float32"), ) _check_inference( bb, relax.op.strided_slice(x1, axes=[], begin=[], end=[]), - relax.TensorStructInfo(dtype="float32", ndim=2), + relax.TensorType(dtype="float32", ndim=2), ) _check_inference( bb, relax.op.strided_slice(x2, axes=[], begin=[], end=[]), - relax.TensorStructInfo(dtype="float32"), + relax.TensorType(dtype="float32"), ) _check_inference( bb, relax.op.strided_slice(x3, axes=[], begin=[], end=[]), - relax.TensorStructInfo([m, n], "float32"), + relax.TensorType([m, n], "float32"), ) _check_inference( bb, relax.op.strided_slice(x4, axes=[], begin=[], end=[]), - relax.TensorStructInfo(s1, "float32"), + relax.TensorType(s1, "float32"), ) _check_inference( bb, relax.op.strided_slice(x5, axes=[], begin=[], end=[]), - relax.TensorStructInfo(s2, "float32"), + relax.TensorType(s2, "float32"), ) @@ -727,15 +673,15 @@ def test_strided_slice_begin_end_strides_int64(): ends = strided_slice.args[2] strides = strided_slice.args[3] - assert begins[0].struct_info.dtype == "int64" - assert begins[1].struct_info.dtype == "int64" - assert begins[2].struct_info.dtype == "int64" - assert ends[0].struct_info.dtype == "int64" - assert ends[1].struct_info.dtype == "int64" - assert ends[2].struct_info.dtype == "int64" - assert strides[0].struct_info.dtype == "int64" - assert strides[1].struct_info.dtype == "int64" - assert strides[2].struct_info.dtype == "int64" + assert begins[0].ty.dtype == "int64" + assert begins[1].ty.dtype == "int64" + assert begins[2].ty.dtype == "int64" + assert ends[0].ty.dtype == "int64" + assert ends[1].ty.dtype == "int64" + assert ends[2].ty.dtype == "int64" + assert strides[0].ty.dtype == "int64" + assert strides[1].ty.dtype == "int64" + assert strides[2].ty.dtype == "int64" def test_strided_slice_inconsistent_axes_begin_end_strides_length(): @@ -749,7 +695,7 @@ def test_strided_slice_inconsistent_axes_begin_end_strides_length(): relax.op.strided_slice(x, axes=[1], begin=[0], end=[9], strides=[]) -def test_strided_slice_infer_struct_info_repetitive_axes(): +def test_strided_slice_infer_ty_repetitive_axes(): bb = relax.BlockBuilder() x = relax.Var("x", R.Tensor((8, 9), "float32")) @@ -759,7 +705,7 @@ def test_strided_slice_infer_struct_info_repetitive_axes(): bb.normalize(relax.op.strided_slice(x, axes=[0, -2], begin=[0, 0], end=[8, 8])) -def test_strided_slice_infer_struct_info_axis_out_of_range(): +def test_strided_slice_infer_ty_axis_out_of_range(): bb = relax.BlockBuilder() x = relax.Var("x", R.Tensor((8, 9), "float32")) @@ -769,10 +715,10 @@ def test_strided_slice_infer_struct_info_axis_out_of_range(): bb.normalize(relax.op.strided_slice(x, axes=[-3], begin=[0], end=[8])) -def test_strided_slice_infer_struct_info_wrong_input_type(): +def test_strided_slice_infer_ty_wrong_input_type(): bb = relax.BlockBuilder() - x0 = relax.Var("x", relax.ShapeStructInfo((8, 9))) - x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((8, 9), "float32"))) + x0 = relax.Var("x", relax.ShapeType((8, 9))) + x1 = relax.Var("x", relax.FuncType([], R.Tensor((8, 9), "float32"))) with pytest.raises(tvm.error.InternalError): bb.normalize(relax.op.strided_slice(x0, axes=[0], begin=[0], end=[8])) @@ -780,7 +726,7 @@ def test_strided_slice_infer_struct_info_wrong_input_type(): bb.normalize(relax.op.strided_slice(x1, axes=[0], begin=[0], end=[8])) -def test_dynamic_strided_slice_infer_struct_info(): +def test_dynamic_strided_slice_infer_ty(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((8, 9, 10, 10), "float32")) x1 = relax.Var("x", R.Tensor("float32", ndim=4)) @@ -859,7 +805,7 @@ def test_dynamic_strided_slice_infer_struct_info(): ) -def test_dynamic_strided_slice_infer_struct_info_symbolic(): +def test_dynamic_strided_slice_infer_ty_symbolic(): bb = relax.BlockBuilder() i = tirx.Var("i", "int64") j = tirx.Var("j", "int64") @@ -942,7 +888,7 @@ def test_dynamic_strided_slice_infer_struct_info_symbolic(): ) -def test_dynamic_strided_slice_infer_struct_info_arg_wrong_dtype(): +def test_dynamic_strided_slice_infer_ty_arg_wrong_dtype(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((8, 9, 10, 10), "float32")) b0 = relax.Var("begin", R.Tensor((4,), "float32")) @@ -953,7 +899,7 @@ def test_dynamic_strided_slice_infer_struct_info_arg_wrong_dtype(): bb.normalize(relax.op.strided_slice(x0, b0, e0, s0)) -def test_dynamic_strided_slice_infer_struct_info_arg_wrong_shape_info(): +def test_dynamic_strided_slice_infer_ty_arg_wrong_shape_info(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((8, 9, 10, 10), "float32")) m = tirx.Var("m", "int64") @@ -994,7 +940,7 @@ def main(A: R.Tensor((16, 16), "float32"), B: R.Shape(["index"])) -> R.Tensor((1 return R.call_tir( expected.strided_slice, (A,), - out_sinfo=R.Tensor((1, 16), "float32"), + out_ty=R.Tensor((1, 16), "float32"), tir_vars=R.shape([index]), ) @@ -1045,7 +991,7 @@ def strided_slice(A: T.Buffer((T.int64(16), T.int64(16)), "float32"), var_T_dyna def main(A: R.Tensor((16, 16), dtype="float32"), B: R.Shape(["index"])) -> R.Tensor(("T.max(16 - T.max(T.if_then_else(index < 0, index + 16, index), 0), 0)", 16), dtype="float32"): index = T.int64() cls = expected - gv = R.call_tir(cls.strided_slice, (A,), out_sinfo=R.Tensor((T.max(16 - T.max(T.if_then_else(index < 0, index + 16, index), 0), 0), 16), dtype="float32"), tir_vars=R.shape([index])) + gv = R.call_tir(cls.strided_slice, (A,), out_ty=R.Tensor((T.max(16 - T.max(T.if_then_else(index < 0, index + 16, index), 0), 0), 16), dtype="float32"), tir_vars=R.shape([index])) return gv # fmt: on diff --git a/tests/python/relax/test_op_linear_algebra.py b/tests/python/relax/test_op_linear_algebra.py index 63a51dd9aaf1..4b7b50854d03 100644 --- a/tests/python/relax/test_op_linear_algebra.py +++ b/tests/python/relax/test_op_linear_algebra.py @@ -30,12 +30,12 @@ def test_op_correctness(): assert relax.op.matmul(x, y).op == Op.get("relax.matmul") -def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): +def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_ty: relax.Type): ret = bb.normalize(call) - tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) + tvm.ir.assert_structural_equal(ret.ty, expected_ty) -def test_matmul_infer_struct_info(): +def test_matmul_infer_ty(): bb = relax.BlockBuilder() vdev0 = VDevice("llvm") x0 = relax.Var("x", R.Tensor((3, 4), "float32")) @@ -54,31 +54,29 @@ def test_matmul_infer_struct_info(): y5 = relax.Var("y", R.Tensor()) y6 = relax.Var("y", R.Tensor((4, 5), "float32", vdev0)) - _check_inference(bb, relax.op.matmul(x0, y0), relax.TensorStructInfo((3, 5), "float32")) - _check_inference(bb, relax.op.matmul(x7, y6), relax.TensorStructInfo((3, 5), "float32", vdev0)) - _check_inference(bb, relax.op.matmul(x1, y1), relax.TensorStructInfo((), "float32")) - _check_inference(bb, relax.op.matmul(x1, y2), relax.TensorStructInfo((2, 3, 5), "float32")) - _check_inference(bb, relax.op.matmul(x2, y1), relax.TensorStructInfo((2, 3, 5), "float32")) - _check_inference( - bb, relax.op.matmul(x3, y3), relax.TensorStructInfo((6, 2, 3, 4, 7), "float32") - ) - _check_inference(bb, relax.op.matmul(x4, y3), relax.TensorStructInfo((6, 2, 3, 4, 7), "")) - _check_inference(bb, relax.op.matmul(x3, y4), relax.TensorStructInfo(dtype="float32", ndim=5)) - _check_inference(bb, relax.op.matmul(x5, y3), relax.TensorStructInfo(dtype="float32")) - _check_inference(bb, relax.op.matmul(x3, y5), relax.TensorStructInfo(dtype="")) + _check_inference(bb, relax.op.matmul(x0, y0), relax.TensorType((3, 5), "float32")) + _check_inference(bb, relax.op.matmul(x7, y6), relax.TensorType((3, 5), "float32", vdev0)) + _check_inference(bb, relax.op.matmul(x1, y1), relax.TensorType((), "float32")) + _check_inference(bb, relax.op.matmul(x1, y2), relax.TensorType((2, 3, 5), "float32")) + _check_inference(bb, relax.op.matmul(x2, y1), relax.TensorType((2, 3, 5), "float32")) + _check_inference(bb, relax.op.matmul(x3, y3), relax.TensorType((6, 2, 3, 4, 7), "float32")) + _check_inference(bb, relax.op.matmul(x4, y3), relax.TensorType((6, 2, 3, 4, 7), "")) + _check_inference(bb, relax.op.matmul(x3, y4), relax.TensorType(dtype="float32", ndim=5)) + _check_inference(bb, relax.op.matmul(x5, y3), relax.TensorType(dtype="float32")) + _check_inference(bb, relax.op.matmul(x3, y5), relax.TensorType(dtype="")) _check_inference( bb, relax.op.matmul(x3, y3, out_dtype="float16"), - relax.TensorStructInfo((6, 2, 3, 4, 7), "float16"), + relax.TensorType((6, 2, 3, 4, 7), "float16"), ) _check_inference( bb, relax.op.matmul(x6, y3, out_dtype="float16"), - relax.TensorStructInfo((6, 2, 3, 4, 7), "float16"), + relax.TensorType((6, 2, 3, 4, 7), "float16"), ) -def test_matmul_infer_struct_info_shape_symbolic(): +def test_matmul_infer_ty_shape_symbolic(): bb = relax.BlockBuilder() m = tirx.Var("m", "int64") n = tirx.Var("n", "int64") @@ -99,43 +97,39 @@ def test_matmul_infer_struct_info_shape_symbolic(): y3 = relax.Var("y", R.Tensor((a, 1, c, k0, n), "float32")) y4 = relax.Var("y", R.Tensor((a, b1, c, k0, n), "float32")) - _check_inference(bb, relax.op.matmul(x0, y0), relax.TensorStructInfo((m, n), "float32")) - _check_inference(bb, relax.op.matmul(x1, y1), relax.TensorStructInfo((), "float32")) - _check_inference(bb, relax.op.matmul(x1, y2), relax.TensorStructInfo((a, b, n), "float32")) - _check_inference(bb, relax.op.matmul(x2, y1), relax.TensorStructInfo((a, b, m), "float32")) - _check_inference( - bb, relax.op.matmul(x3, y3), relax.TensorStructInfo((a, b, c, m, n), "float32") - ) - _check_inference( - bb, relax.op.matmul(x4, y3), relax.TensorStructInfo((a, b, c, m, n), "float32") - ) - _check_inference(bb, relax.op.matmul(x3, y4), relax.TensorStructInfo(dtype="float32", ndim=5)) + _check_inference(bb, relax.op.matmul(x0, y0), relax.TensorType((m, n), "float32")) + _check_inference(bb, relax.op.matmul(x1, y1), relax.TensorType((), "float32")) + _check_inference(bb, relax.op.matmul(x1, y2), relax.TensorType((a, b, n), "float32")) + _check_inference(bb, relax.op.matmul(x2, y1), relax.TensorType((a, b, m), "float32")) + _check_inference(bb, relax.op.matmul(x3, y3), relax.TensorType((a, b, c, m, n), "float32")) + _check_inference(bb, relax.op.matmul(x4, y3), relax.TensorType((a, b, c, m, n), "float32")) + _check_inference(bb, relax.op.matmul(x3, y4), relax.TensorType(dtype="float32", ndim=5)) -def test_matmul_infer_struct_info_shape_var(): +def test_matmul_infer_ty_shape_var(): bb = relax.BlockBuilder() - s0 = relax.Var("s0", relax.ShapeStructInfo(ndim=4)) - s1 = relax.Var("s1", relax.ShapeStructInfo(ndim=3)) - s2 = relax.Var("s3", relax.ShapeStructInfo(ndim=1)) - s3 = relax.Var("s4", relax.ShapeStructInfo(ndim=1)) - s5 = relax.Var("s5", relax.ShapeStructInfo()) - x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) - x1 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) - x2 = relax.Var("x", relax.TensorStructInfo(s5, "float32")) - y0 = relax.Var("y", relax.TensorStructInfo(s1, "float32")) - y1 = relax.Var("y", relax.TensorStructInfo(s2, "float32")) - y2 = relax.Var("y", relax.TensorStructInfo(s3, "float32")) - - _check_inference(bb, relax.op.matmul(x0, y0), relax.TensorStructInfo(dtype="float32", ndim=4)) - _check_inference(bb, relax.op.matmul(x1, y0), relax.TensorStructInfo(dtype="float32", ndim=2)) - _check_inference(bb, relax.op.matmul(x2, y0), relax.TensorStructInfo(dtype="float32")) - _check_inference(bb, relax.op.matmul(x0, y1), relax.TensorStructInfo(dtype="float32", ndim=3)) - _check_inference(bb, relax.op.matmul(x1, y1), relax.TensorStructInfo(dtype="float32", ndim=0)) - _check_inference(bb, relax.op.matmul(x2, y1), relax.TensorStructInfo(dtype="float32")) - _check_inference(bb, relax.op.matmul(x1, y2), relax.TensorStructInfo(dtype="float32", ndim=0)) - - -def test_matmul_infer_struct_info_more_input_dtype(): + s0 = relax.Var("s0", relax.ShapeType(ndim=4)) + s1 = relax.Var("s1", relax.ShapeType(ndim=3)) + s2 = relax.Var("s3", relax.ShapeType(ndim=1)) + s3 = relax.Var("s4", relax.ShapeType(ndim=1)) + s5 = relax.Var("s5", relax.ShapeType()) + x0 = relax.Var("x", relax.TensorType(s0, "float32")) + x1 = relax.Var("x", relax.TensorType(s2, "float32")) + x2 = relax.Var("x", relax.TensorType(s5, "float32")) + y0 = relax.Var("y", relax.TensorType(s1, "float32")) + y1 = relax.Var("y", relax.TensorType(s2, "float32")) + y2 = relax.Var("y", relax.TensorType(s3, "float32")) + + _check_inference(bb, relax.op.matmul(x0, y0), relax.TensorType(dtype="float32", ndim=4)) + _check_inference(bb, relax.op.matmul(x1, y0), relax.TensorType(dtype="float32", ndim=2)) + _check_inference(bb, relax.op.matmul(x2, y0), relax.TensorType(dtype="float32")) + _check_inference(bb, relax.op.matmul(x0, y1), relax.TensorType(dtype="float32", ndim=3)) + _check_inference(bb, relax.op.matmul(x1, y1), relax.TensorType(dtype="float32", ndim=0)) + _check_inference(bb, relax.op.matmul(x2, y1), relax.TensorType(dtype="float32")) + _check_inference(bb, relax.op.matmul(x1, y2), relax.TensorType(dtype="float32", ndim=0)) + + +def test_matmul_infer_ty_more_input_dtype(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((3, 4), "float16")) y0 = relax.Var("y", R.Tensor((4, 5), "float16")) @@ -144,12 +138,12 @@ def test_matmul_infer_struct_info_more_input_dtype(): x2 = relax.Var("x", R.Tensor((3, 4), "int64")) y2 = relax.Var("y", R.Tensor((4, 5), "int64")) - _check_inference(bb, relax.op.matmul(x0, y0), relax.TensorStructInfo((3, 5), "float16")) - _check_inference(bb, relax.op.matmul(x1, y1), relax.TensorStructInfo((3, 5), "int8")) - _check_inference(bb, relax.op.matmul(x2, y2), relax.TensorStructInfo((3, 5), "int64")) + _check_inference(bb, relax.op.matmul(x0, y0), relax.TensorType((3, 5), "float16")) + _check_inference(bb, relax.op.matmul(x1, y1), relax.TensorType((3, 5), "int8")) + _check_inference(bb, relax.op.matmul(x2, y2), relax.TensorType((3, 5), "int64")) -def test_matmul_infer_struct_info_mixed_precision(): +def test_matmul_infer_ty_mixed_precision(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((3, 4), "float16")) y0 = relax.Var("y", R.Tensor((4, 5), "float16")) @@ -161,19 +155,19 @@ def test_matmul_infer_struct_info_mixed_precision(): _check_inference( bb, relax.op.matmul(x0, y0, out_dtype="float32"), - relax.TensorStructInfo((3, 5), "float32"), + relax.TensorType((3, 5), "float32"), ) _check_inference( - bb, relax.op.matmul(x1, y1, out_dtype="int32"), relax.TensorStructInfo((3, 5), "int32") + bb, relax.op.matmul(x1, y1, out_dtype="int32"), relax.TensorType((3, 5), "int32") ) _check_inference( bb, relax.op.matmul(x2, y2, out_dtype="float32"), - relax.TensorStructInfo((3, 5), "float32"), + relax.TensorType((3, 5), "float32"), ) -def test_matmul_infer_struct_info_zero_rank_input(): +def test_matmul_infer_ty_zero_rank_input(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((3, 4), "float32")) x1 = relax.Var("x", R.Tensor((), "float32")) @@ -186,7 +180,7 @@ def test_matmul_infer_struct_info_zero_rank_input(): bb.normalize(relax.op.matmul(x1, y0)) -def test_matmul_infer_struct_info_not_broadcastable(): +def test_matmul_infer_ty_not_broadcastable(): bb = relax.BlockBuilder() x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) y = relax.Var("y", R.Tensor((2, 8, 3, 5, 6), "float32")) @@ -195,7 +189,7 @@ def test_matmul_infer_struct_info_not_broadcastable(): bb.normalize(relax.op.matmul(x, y)) -def test_matmul_infer_struct_info_unequal_reduction_length(): +def test_matmul_infer_ty_unequal_reduction_length(): bb = relax.BlockBuilder() k = tirx.Var("k", "int64") x0 = relax.Var("x", R.Tensor((3, 4), "float32")) @@ -227,32 +221,28 @@ def test_linear(): # Need a scope to normalize non-leaf nodes with bb.function("func", [x1]): + _check_inference(bb, relax.op.linear(x1, w1, b1), relax.TensorType((2, 3, 5), "float32")) _check_inference( - bb, relax.op.linear(x1, w1, b1), relax.TensorStructInfo((2, 3, 5), "float32") - ) - _check_inference( - bb, relax.op.linear(x3, w4, b3), relax.TensorStructInfo((2, 3, 5), "float32", vdev0) - ) - _check_inference( - bb, relax.op.linear(x1, w1, b2), relax.TensorStructInfo((2, 3, 5), "float32") + bb, relax.op.linear(x3, w4, b3), relax.TensorType((2, 3, 5), "float32", vdev0) ) + _check_inference(bb, relax.op.linear(x1, w1, b2), relax.TensorType((2, 3, 5), "float32")) with pytest.raises(ValueError): bb.normalize(relax.op.linear(x1, w2, b1)) # error on Add with shape (2, 3, 5) and (4,) - _check_inference(bb, relax.op.linear(x1, w2, b2), relax.TensorStructInfo((2, 3), "float32")) - _check_inference(bb, relax.op.linear(x1, w3, b1), relax.TensorStructInfo(dtype="float32")) - _check_inference(bb, relax.op.linear(x1, w3, b2), relax.TensorStructInfo(dtype="float32")) - _check_inference(bb, relax.op.linear(x2, w1, b1), relax.TensorStructInfo(dtype="float32")) - _check_inference(bb, relax.op.linear(x2, w1, b2), relax.TensorStructInfo(dtype="float32")) - _check_inference(bb, relax.op.linear(x2, w2, b1), relax.TensorStructInfo(dtype="float32")) - _check_inference(bb, relax.op.linear(x2, w2, b2), relax.TensorStructInfo(dtype="float32")) - _check_inference(bb, relax.op.linear(x2, w3, b1), relax.TensorStructInfo(dtype="float32")) - _check_inference(bb, relax.op.linear(x2, w3, b2), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.linear(x1, w2, b2), relax.TensorType((2, 3), "float32")) + _check_inference(bb, relax.op.linear(x1, w3, b1), relax.TensorType(dtype="float32")) + _check_inference(bb, relax.op.linear(x1, w3, b2), relax.TensorType(dtype="float32")) + _check_inference(bb, relax.op.linear(x2, w1, b1), relax.TensorType(dtype="float32")) + _check_inference(bb, relax.op.linear(x2, w1, b2), relax.TensorType(dtype="float32")) + _check_inference(bb, relax.op.linear(x2, w2, b1), relax.TensorType(dtype="float32")) + _check_inference(bb, relax.op.linear(x2, w2, b2), relax.TensorType(dtype="float32")) + _check_inference(bb, relax.op.linear(x2, w3, b1), relax.TensorType(dtype="float32")) + _check_inference(bb, relax.op.linear(x2, w3, b2), relax.TensorType(dtype="float32")) # Fake output gv = bb.emit_func_output(relax.Tuple([])) -def test_einsum_infer_struct_info(): +def test_einsum_infer_ty(): bb = relax.BlockBuilder() vdev0 = VDevice("llvm") x0 = relax.Var("x0", R.Tensor((), "float32")) @@ -273,46 +263,42 @@ def test_einsum_infer_struct_info(): x15 = relax.Var("x15", R.Tensor((2, 5, 3, 6, 4), "float32")) x16 = relax.Var("x16", R.Tensor((5, 5), "int32", vdev0)) - _check_inference(bb, relax.op.einsum((x2,), "ii"), relax.TensorStructInfo((), "int32")) - _check_inference(bb, relax.op.einsum((x16,), "ii"), relax.TensorStructInfo((), "int32", vdev0)) - _check_inference(bb, relax.op.einsum((x2,), "ii->i"), relax.TensorStructInfo((5,), "int32")) - _check_inference(bb, relax.op.einsum([x2], "...j->..."), relax.TensorStructInfo((5,), "int32")) - _check_inference( - bb, relax.op.einsum((x2, x1), "...j, j"), relax.TensorStructInfo((5,), "int32") - ) + _check_inference(bb, relax.op.einsum((x2,), "ii"), relax.TensorType((), "int32")) + _check_inference(bb, relax.op.einsum((x16,), "ii"), relax.TensorType((), "int32", vdev0)) + _check_inference(bb, relax.op.einsum((x2,), "ii->i"), relax.TensorType((5,), "int32")) + _check_inference(bb, relax.op.einsum([x2], "...j->..."), relax.TensorType((5,), "int32")) + _check_inference(bb, relax.op.einsum((x2, x1), "...j, j"), relax.TensorType((5,), "int32")) + _check_inference(bb, relax.op.einsum((x0, x5), "..., ..."), relax.TensorType((2, 3), "float32")) _check_inference( - bb, relax.op.einsum((x0, x5), "..., ..."), relax.TensorStructInfo((2, 3), "float32") + bb, relax.op.einsum((x5, x6), "ij,jk->ik"), relax.TensorType((2, 4), "float32") ) _check_inference( - bb, relax.op.einsum((x5, x6), "ij,jk->ik"), relax.TensorStructInfo((2, 4), "float32") + bb, relax.op.einsum((x5, x6, x8), "ij,jk,km->im"), relax.TensorType((2, 5), "float32") ) _check_inference( - bb, relax.op.einsum((x5, x6, x8), "ij,jk,km->im"), relax.TensorStructInfo((2, 5), "float32") + bb, relax.op.einsum((x9, x10), "ijk, jil->kl"), relax.TensorType((5, 2), "float32") ) _check_inference( - bb, relax.op.einsum((x9, x10), "ijk, jil->kl"), relax.TensorStructInfo((5, 2), "float32") - ) - _check_inference( - bb, relax.op.einsum((x3, x4), "ij, ij -> i"), relax.TensorStructInfo((2,), "float32") + bb, relax.op.einsum((x3, x4), "ij, ij -> i"), relax.TensorType((2,), "float32") ) _check_inference( bb, relax.op.einsum((x3, x7), "...ij, ...jk -> ...ik"), - relax.TensorStructInfo((1, 2), "float32"), + relax.TensorType((1, 2), "float32"), ) _check_inference( bb, relax.op.einsum((x12, x13), "...ij, ...ik -> ...jk"), - relax.TensorStructInfo((1, 1, 4, 3), "float16"), + relax.TensorType((1, 1, 4, 3), "float16"), ) _check_inference( bb, relax.op.einsum((x11, x14, x15), "...ik, ...jk, ...hk -> i...jh"), - relax.TensorStructInfo((4, 2, 5, 3, 8, 6), "float32"), + relax.TensorType((4, 2, 5, 3, 8, 6), "float32"), ) -def test_einsum_infer_struct_info_shape_symbolic(): +def test_einsum_infer_ty_shape_symbolic(): bb = relax.BlockBuilder() a = tirx.Var("a", "int64") b = tirx.Var("b", "int64") @@ -321,15 +307,13 @@ def test_einsum_infer_struct_info_shape_symbolic(): y = relax.Var("y", R.Tensor((b, c), "float32")) z = relax.Var("z", R.Tensor((a, a), "float32")) - _check_inference(bb, relax.op.einsum((z,), "ii->i"), relax.TensorStructInfo((a,), "float32")) - _check_inference( - bb, relax.op.einsum((x, y), "ij,jk->ik"), relax.TensorStructInfo((a, c), "float32") - ) + _check_inference(bb, relax.op.einsum((z,), "ii->i"), relax.TensorType((a,), "float32")) + _check_inference(bb, relax.op.einsum((x, y), "ij,jk->ik"), relax.TensorType((a, c), "float32")) -def test_einsum_infer_struct_info_wrong_inputs(): +def test_einsum_infer_ty_wrong_inputs(): bb = relax.BlockBuilder() - x0 = relax.Var("x0", relax.ShapeStructInfo((2, 3, 4, 5))) + x0 = relax.Var("x0", relax.ShapeType((2, 3, 4, 5))) x1 = relax.Var("x1", R.Tensor((5, 5), "int32")) with pytest.raises(TypeError): diff --git a/tests/python/relax/test_op_manipulate.py b/tests/python/relax/test_op_manipulate.py index 4cecd3ed3e5c..c09a04893f2e 100644 --- a/tests/python/relax/test_op_manipulate.py +++ b/tests/python/relax/test_op_manipulate.py @@ -50,12 +50,12 @@ def test_op_correctness(): assert relax.op.scatter_nd(x, x, x).op == Op.get("relax.scatter_nd") -def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): +def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_ty: relax.Type): ret = bb.normalize(call) - tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) + tvm.ir.assert_structural_equal(ret.ty, expected_ty) -def test_reshape_infer_struct_info(): +def test_reshape_infer_ty(): bb = relax.BlockBuilder() vdev0 = VDevice("llvm") x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) @@ -70,65 +70,49 @@ def test_reshape_infer_struct_info(): s2 = relax.Var("s", R.Shape()) s3 = relax.ShapeExpr((3, 8, 5)) - _check_inference( - bb, relax.op.reshape(x0, (3, 8, 5)), relax.TensorStructInfo((3, 8, 5), "float32") - ) - _check_inference( - bb, relax.op.reshape(x6, (3, 8, 5)), relax.TensorStructInfo((3, 8, 5), "float32", vdev0) - ) - _check_inference( - bb, relax.op.reshape(x0, (3, -1, 5)), relax.TensorStructInfo((3, 8, 5), "float32") - ) - _check_inference(bb, relax.op.reshape(x0, (-1,)), relax.TensorStructInfo((120,), "float32")) - _check_inference( - bb, relax.op.reshape(x0, relax.ShapeExpr([-1])), relax.TensorStructInfo((120,), "float32") - ) - _check_inference( - bb, relax.op.reshape(x1, (3, 8, 5)), relax.TensorStructInfo((3, 8, 5), "float32") - ) - _check_inference( - bb, relax.op.reshape(x2, (3, 8, 5)), relax.TensorStructInfo((3, 8, 5), "float32") - ) - _check_inference( - bb, relax.op.reshape(x3, (3, 8, 5)), relax.TensorStructInfo((3, 8, 5), dtype="") - ) - _check_inference( - bb, relax.op.reshape(x3, (3, -1, 5)), relax.TensorStructInfo((3, 8, 5), dtype="") - ) - _check_inference( - bb, relax.op.reshape(x4, (3, 8, 5)), relax.TensorStructInfo((3, 8, 5), dtype="") - ) - _check_inference( - bb, relax.op.reshape(x5, (3, 8, 5)), relax.TensorStructInfo((3, 8, 5), dtype="") - ) - # Remove Var from StructInfo when we can - _check_inference(bb, relax.op.reshape(x0, s0), relax.TensorStructInfo((3, 8, 5), "float32")) - _check_inference(bb, relax.op.reshape(x1, s0), relax.TensorStructInfo((3, 8, 5), "float32")) - _check_inference(bb, relax.op.reshape(x2, s0), relax.TensorStructInfo((3, 8, 5), "float32")) - _check_inference(bb, relax.op.reshape(x3, s0), relax.TensorStructInfo((3, 8, 5), dtype="")) - _check_inference(bb, relax.op.reshape(x4, s0), relax.TensorStructInfo((3, 8, 5), dtype="")) - _check_inference(bb, relax.op.reshape(x5, s0), relax.TensorStructInfo((3, 8, 5), dtype="")) - _check_inference(bb, relax.op.reshape(x0, s1), relax.TensorStructInfo(s1, "float32")) - _check_inference(bb, relax.op.reshape(x1, s1), relax.TensorStructInfo(s1, "float32")) - _check_inference(bb, relax.op.reshape(x2, s1), relax.TensorStructInfo(s1, "float32")) - _check_inference(bb, relax.op.reshape(x3, s1), relax.TensorStructInfo(s1, dtype="")) - _check_inference(bb, relax.op.reshape(x4, s1), relax.TensorStructInfo(s1, dtype="")) - _check_inference(bb, relax.op.reshape(x5, s1), relax.TensorStructInfo(s1, dtype="")) - _check_inference(bb, relax.op.reshape(x0, s2), relax.TensorStructInfo(s2, "float32")) - _check_inference(bb, relax.op.reshape(x1, s2), relax.TensorStructInfo(s2, "float32")) - _check_inference(bb, relax.op.reshape(x2, s2), relax.TensorStructInfo(s2, "float32")) - _check_inference(bb, relax.op.reshape(x3, s2), relax.TensorStructInfo(s2, dtype="")) - _check_inference(bb, relax.op.reshape(x4, s2), relax.TensorStructInfo(s2, dtype="")) - _check_inference(bb, relax.op.reshape(x5, s2), relax.TensorStructInfo(s2, dtype="")) - _check_inference(bb, relax.op.reshape(x0, s3), relax.TensorStructInfo(s3, "float32")) - _check_inference(bb, relax.op.reshape(x1, s3), relax.TensorStructInfo(s3, "float32")) - _check_inference(bb, relax.op.reshape(x2, s3), relax.TensorStructInfo(s3, "float32")) - _check_inference(bb, relax.op.reshape(x3, s3), relax.TensorStructInfo(s3, dtype="")) - _check_inference(bb, relax.op.reshape(x4, s3), relax.TensorStructInfo(s3, dtype="")) - _check_inference(bb, relax.op.reshape(x5, s3), relax.TensorStructInfo(s3, dtype="")) - - -def test_reshape_infer_struct_info_shape_symbolic(): + _check_inference(bb, relax.op.reshape(x0, (3, 8, 5)), relax.TensorType((3, 8, 5), "float32")) + _check_inference( + bb, relax.op.reshape(x6, (3, 8, 5)), relax.TensorType((3, 8, 5), "float32", vdev0) + ) + _check_inference(bb, relax.op.reshape(x0, (3, -1, 5)), relax.TensorType((3, 8, 5), "float32")) + _check_inference(bb, relax.op.reshape(x0, (-1,)), relax.TensorType((120,), "float32")) + _check_inference( + bb, relax.op.reshape(x0, relax.ShapeExpr([-1])), relax.TensorType((120,), "float32") + ) + _check_inference(bb, relax.op.reshape(x1, (3, 8, 5)), relax.TensorType((3, 8, 5), "float32")) + _check_inference(bb, relax.op.reshape(x2, (3, 8, 5)), relax.TensorType((3, 8, 5), "float32")) + _check_inference(bb, relax.op.reshape(x3, (3, 8, 5)), relax.TensorType((3, 8, 5), dtype="")) + _check_inference(bb, relax.op.reshape(x3, (3, -1, 5)), relax.TensorType((3, 8, 5), dtype="")) + _check_inference(bb, relax.op.reshape(x4, (3, 8, 5)), relax.TensorType((3, 8, 5), dtype="")) + _check_inference(bb, relax.op.reshape(x5, (3, 8, 5)), relax.TensorType((3, 8, 5), dtype="")) + # Remove Var from Type when we can + _check_inference(bb, relax.op.reshape(x0, s0), relax.TensorType((3, 8, 5), "float32")) + _check_inference(bb, relax.op.reshape(x1, s0), relax.TensorType((3, 8, 5), "float32")) + _check_inference(bb, relax.op.reshape(x2, s0), relax.TensorType((3, 8, 5), "float32")) + _check_inference(bb, relax.op.reshape(x3, s0), relax.TensorType((3, 8, 5), dtype="")) + _check_inference(bb, relax.op.reshape(x4, s0), relax.TensorType((3, 8, 5), dtype="")) + _check_inference(bb, relax.op.reshape(x5, s0), relax.TensorType((3, 8, 5), dtype="")) + _check_inference(bb, relax.op.reshape(x0, s1), relax.TensorType(s1, "float32")) + _check_inference(bb, relax.op.reshape(x1, s1), relax.TensorType(s1, "float32")) + _check_inference(bb, relax.op.reshape(x2, s1), relax.TensorType(s1, "float32")) + _check_inference(bb, relax.op.reshape(x3, s1), relax.TensorType(s1, dtype="")) + _check_inference(bb, relax.op.reshape(x4, s1), relax.TensorType(s1, dtype="")) + _check_inference(bb, relax.op.reshape(x5, s1), relax.TensorType(s1, dtype="")) + _check_inference(bb, relax.op.reshape(x0, s2), relax.TensorType(s2, "float32")) + _check_inference(bb, relax.op.reshape(x1, s2), relax.TensorType(s2, "float32")) + _check_inference(bb, relax.op.reshape(x2, s2), relax.TensorType(s2, "float32")) + _check_inference(bb, relax.op.reshape(x3, s2), relax.TensorType(s2, dtype="")) + _check_inference(bb, relax.op.reshape(x4, s2), relax.TensorType(s2, dtype="")) + _check_inference(bb, relax.op.reshape(x5, s2), relax.TensorType(s2, dtype="")) + _check_inference(bb, relax.op.reshape(x0, s3), relax.TensorType(s3, "float32")) + _check_inference(bb, relax.op.reshape(x1, s3), relax.TensorType(s3, "float32")) + _check_inference(bb, relax.op.reshape(x2, s3), relax.TensorType(s3, "float32")) + _check_inference(bb, relax.op.reshape(x3, s3), relax.TensorType(s3, dtype="")) + _check_inference(bb, relax.op.reshape(x4, s3), relax.TensorType(s3, dtype="")) + _check_inference(bb, relax.op.reshape(x5, s3), relax.TensorType(s3, dtype="")) + + +def test_reshape_infer_ty_shape_symbolic(): bb = relax.BlockBuilder() a = tirx.Var("a", "int64") b = tirx.Var("b", "int64") @@ -140,99 +124,91 @@ def test_reshape_infer_struct_info_shape_symbolic(): s2 = relax.ShapeExpr((c, a, d, b)) _check_inference( - bb, relax.op.reshape(x, (c, a, d, b)), relax.TensorStructInfo((c, a, d, b), "float32") + bb, relax.op.reshape(x, (c, a, d, b)), relax.TensorType((c, a, d, b), "float32") ) _check_inference( bb, relax.op.reshape(x, (d, c, b, -1)), - relax.TensorStructInfo((d, c, b, a), "float32"), + relax.TensorType((d, c, b, a), "float32"), ) _check_inference( bb, relax.op.reshape(x, (1, -1, 1)), - relax.TensorStructInfo((1, a * b * c * d, 1), "float32"), + relax.TensorType((1, a * b * c * d, 1), "float32"), ) _check_inference( bb, relax.op.reshape(x, (2, -1, a)), - relax.TensorStructInfo((2, tirx.floordiv(b * c * d, 2), a), "float32"), + relax.TensorType((2, tirx.floordiv(b * c * d, 2), a), "float32"), ) _check_inference( bb, relax.op.reshape(x, (c, -1, d, b)), - relax.TensorStructInfo((c, a, d, b), "float32"), + relax.TensorType((c, a, d, b), "float32"), ) _check_inference( bb, relax.op.reshape(x, (c, a * d, b)), - relax.TensorStructInfo((c, a * d, b), "float32"), + relax.TensorType((c, a * d, b), "float32"), ) _check_inference( bb, relax.op.reshape(x, (c, a * b * d, -1)), - relax.TensorStructInfo((c, a * b * d, 1), "float32"), + relax.TensorType((c, a * b * d, 1), "float32"), ) - # Remove Var from StructInfo when we can - _check_inference(bb, relax.op.reshape(x, s0), relax.TensorStructInfo((c, a, d, b), "float32")) - _check_inference(bb, relax.op.reshape(x, s1), relax.TensorStructInfo(s1, "float32")) - _check_inference(bb, relax.op.reshape(x, s2), relax.TensorStructInfo(s2, "float32")) + # Remove Var from Type when we can + _check_inference(bb, relax.op.reshape(x, s0), relax.TensorType((c, a, d, b), "float32")) + _check_inference(bb, relax.op.reshape(x, s1), relax.TensorType(s1, "float32")) + _check_inference(bb, relax.op.reshape(x, s2), relax.TensorType(s2, "float32")) -def test_reshape_infer_struct_info_shape_var(): +def test_reshape_infer_ty_shape_var(): bb = relax.BlockBuilder() - s0 = relax.Var("s", relax.ShapeStructInfo((2, 3, 4, 5))) - s1 = relax.Var("s", relax.ShapeStructInfo(ndim=4)) - s2 = relax.Var("s", relax.ShapeStructInfo()) - x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) - x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) - x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) - ns0 = relax.Var("ns", relax.ShapeStructInfo((3, 8, 5))) - ns1 = relax.Var("ns", relax.ShapeStructInfo()) + s0 = relax.Var("s", relax.ShapeType((2, 3, 4, 5))) + s1 = relax.Var("s", relax.ShapeType(ndim=4)) + s2 = relax.Var("s", relax.ShapeType()) + x0 = relax.Var("x", relax.TensorType(s0, "float32")) + x1 = relax.Var("x", relax.TensorType(s1, "float32")) + x2 = relax.Var("x", relax.TensorType(s2, "float32")) + ns0 = relax.Var("ns", relax.ShapeType((3, 8, 5))) + ns1 = relax.Var("ns", relax.ShapeType()) + _check_inference(bb, relax.op.reshape(x0, (3, 8, 5)), relax.TensorType((3, 8, 5), "float32")) _check_inference( - bb, relax.op.reshape(x0, (3, 8, 5)), relax.TensorStructInfo((3, 8, 5), "float32") - ) - _check_inference( - bb, relax.op.reshape(x0, (2, 3, 0, 5)), relax.TensorStructInfo((2, 3, 4, 5), "float32") - ) - _check_inference( - bb, relax.op.reshape(x0, (1, 3, 0, -1)), relax.TensorStructInfo((1, 3, 4, 10), "float32") - ) - _check_inference( - bb, relax.op.reshape(x0, (3, -1, 5)), relax.TensorStructInfo((3, 8, 5), "float32") + bb, relax.op.reshape(x0, (2, 3, 0, 5)), relax.TensorType((2, 3, 4, 5), "float32") ) - # Remove Var from StructInfo when we can - _check_inference(bb, relax.op.reshape(x0, ns0), relax.TensorStructInfo((3, 8, 5), "float32")) - _check_inference(bb, relax.op.reshape(x0, ns1), relax.TensorStructInfo(ns1, "float32")) _check_inference( - bb, relax.op.reshape(x1, (3, 8, 5)), relax.TensorStructInfo((3, 8, 5), "float32") + bb, relax.op.reshape(x0, (1, 3, 0, -1)), relax.TensorType((1, 3, 4, 10), "float32") ) - # Remove Var from StructInfo when we can - _check_inference(bb, relax.op.reshape(x1, ns0), relax.TensorStructInfo((3, 8, 5), "float32")) - _check_inference(bb, relax.op.reshape(x1, ns1), relax.TensorStructInfo(ns1, "float32")) - _check_inference( - bb, relax.op.reshape(x2, (3, 8, 5)), relax.TensorStructInfo((3, 8, 5), "float32") - ) - # Remove Var from StructInfo when we can - _check_inference(bb, relax.op.reshape(x2, ns0), relax.TensorStructInfo((3, 8, 5), "float32")) - _check_inference(bb, relax.op.reshape(x2, ns1), relax.TensorStructInfo(ns1, "float32")) + _check_inference(bb, relax.op.reshape(x0, (3, -1, 5)), relax.TensorType((3, 8, 5), "float32")) + # Remove Var from Type when we can + _check_inference(bb, relax.op.reshape(x0, ns0), relax.TensorType((3, 8, 5), "float32")) + _check_inference(bb, relax.op.reshape(x0, ns1), relax.TensorType(ns1, "float32")) + _check_inference(bb, relax.op.reshape(x1, (3, 8, 5)), relax.TensorType((3, 8, 5), "float32")) + # Remove Var from Type when we can + _check_inference(bb, relax.op.reshape(x1, ns0), relax.TensorType((3, 8, 5), "float32")) + _check_inference(bb, relax.op.reshape(x1, ns1), relax.TensorType(ns1, "float32")) + _check_inference(bb, relax.op.reshape(x2, (3, 8, 5)), relax.TensorType((3, 8, 5), "float32")) + # Remove Var from Type when we can + _check_inference(bb, relax.op.reshape(x2, ns0), relax.TensorType((3, 8, 5), "float32")) + _check_inference(bb, relax.op.reshape(x2, ns1), relax.TensorType(ns1, "float32")) -def test_reshape_infer_struct_info_more_input_dtype(): +def test_reshape_infer_ty_more_input_dtype(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float16")) x1 = relax.Var("x", R.Tensor((2, 3, 4, 5), "int8")) - _check_inference(bb, relax.op.reshape(x0, (120,)), relax.TensorStructInfo((120,), "float16")) - _check_inference(bb, relax.op.reshape(x1, (120,)), relax.TensorStructInfo((120,), "int8")) + _check_inference(bb, relax.op.reshape(x0, (120,)), relax.TensorType((120,), "float16")) + _check_inference(bb, relax.op.reshape(x1, (120,)), relax.TensorType((120,), "int8")) -def test_reshape_infer_struct_info_unequal_shape_prod(): +def test_reshape_infer_ty_unequal_shape_prod(): bb = relax.BlockBuilder() - s = relax.Var("s", relax.ShapeStructInfo((2, 3, 4, 5))) + s = relax.Var("s", relax.ShapeType((2, 3, 4, 5))) x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) - x1 = relax.Var("x", relax.TensorStructInfo(s, "float32")) - ns = relax.Var("ns", relax.ShapeStructInfo((4, 4, 1, 5))) + x1 = relax.Var("x", relax.TensorType(s, "float32")) + ns = relax.Var("ns", relax.ShapeType((4, 4, 1, 5))) with pytest.raises(ValueError): bb.normalize(relax.op.reshape(x0, (4, 4, 1, 5))) @@ -248,14 +224,14 @@ def test_reshape_infer_struct_info_unequal_shape_prod(): bb.normalize(relax.op.reshape(x1, ns)) -def test_reshape_infer_struct_info_inference_not_deducible(): +def test_reshape_infer_ty_inference_not_deducible(): bb = relax.BlockBuilder() - s0 = relax.Var("s", relax.ShapeStructInfo(ndim=4)) - s1 = relax.Var("s", relax.ShapeStructInfo()) + s0 = relax.Var("s", relax.ShapeType(ndim=4)) + s1 = relax.Var("s", relax.ShapeType()) x0 = relax.Var("x", R.Tensor("float32", ndim=4)) x1 = relax.Var("x", R.Tensor("float32")) - x2 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) - x3 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorType(s0, "float32")) + x3 = relax.Var("x", relax.TensorType(s1, "float32")) with pytest.raises(tvm.error.InternalError): bb.normalize(relax.op.reshape(x0, (2, 3, -1))) @@ -277,7 +253,7 @@ def test_reshape_new_shape_not_tuple(): relax.op.reshape(x, m) -def test_reshape_infer_struct_info_new_shape_not_integer(): +def test_reshape_infer_ty_new_shape_not_integer(): bb = relax.BlockBuilder() x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) @@ -289,7 +265,7 @@ def test_reshape_infer_struct_info_new_shape_not_integer(): bb.normalize(relax.op.reshape(x, (2, 3, 4.0, -1))) -def test_reshape_infer_struct_info_multiple_dim_inference(): +def test_reshape_infer_ty_multiple_dim_inference(): bb = relax.BlockBuilder() x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) @@ -299,7 +275,7 @@ def test_reshape_infer_struct_info_multiple_dim_inference(): bb.normalize(relax.op.reshape(x, (-1, -1, -1, -1))) -def test_reshape_infer_struct_info_non_positive_new_shape(): +def test_reshape_infer_ty_non_positive_new_shape(): bb = relax.BlockBuilder() x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) @@ -307,13 +283,13 @@ def test_reshape_infer_struct_info_non_positive_new_shape(): bb.normalize(relax.op.reshape(x, (-2, -3, -4, -5))) -def test_reshape_infer_struct_info_wrong_input_type(): +def test_reshape_infer_ty_wrong_input_type(): bb = relax.BlockBuilder() - x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 4, 5))) - x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 4, 5), "float32"))) + x0 = relax.Var("x", relax.ShapeType((2, 3, 4, 5))) + x1 = relax.Var("x", relax.FuncType([], R.Tensor((2, 3, 4, 5), "float32"))) x2 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) - ns = relax.Var("ns", relax.TensorStructInfo((120,), "float32")) - pv = relax.Var("pv", relax.PrimStructInfo("int64")) + ns = relax.Var("ns", relax.TensorType((120,), "float32")) + pv = relax.Var("pv", relax.PrimType("int64")) with pytest.raises(TypeError): bb.normalize(relax.op.reshape(x0, (2, 3, 4, 5))) @@ -325,7 +301,7 @@ def test_reshape_infer_struct_info_wrong_input_type(): bb.normalize(relax.op.reshape(x2, [pv])) -def test_permute_dims_infer_struct_info(): +def test_permute_dims_infer_ty(): bb = relax.BlockBuilder() vdev0 = VDevice("llvm") x0 = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32")) @@ -339,57 +315,49 @@ def test_permute_dims_infer_struct_info(): x8 = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32", vdev0)) _check_inference( - bb, relax.op.permute_dims(x0, [2, 3, 1, 0]), relax.TensorStructInfo((3, 4, 2, 1), "float32") + bb, relax.op.permute_dims(x0, [2, 3, 1, 0]), relax.TensorType((3, 4, 2, 1), "float32") ) _check_inference( bb, relax.op.permute_dims(x8, [2, 3, 1, 0]), - relax.TensorStructInfo((3, 4, 2, 1), "float32", vdev0), + relax.TensorType((3, 4, 2, 1), "float32", vdev0), ) _check_inference( - bb, relax.op.permute_dims(x0, axes=None), relax.TensorStructInfo((4, 3, 2, 1), "float32") + bb, relax.op.permute_dims(x0, axes=None), relax.TensorType((4, 3, 2, 1), "float32") ) _check_inference( bb, relax.op.permute_dims(x0, [-2, -3, 3, -4]), - relax.TensorStructInfo((3, 2, 4, 1), "float32"), - ) - _check_inference( - bb, relax.op.permute_dims(x1, [2, 3, 1, 0]), relax.TensorStructInfo(dtype="float32", ndim=4) + relax.TensorType((3, 2, 4, 1), "float32"), ) _check_inference( - bb, relax.op.permute_dims(x1, axes=None), relax.TensorStructInfo(dtype="float32", ndim=4) + bb, relax.op.permute_dims(x1, [2, 3, 1, 0]), relax.TensorType(dtype="float32", ndim=4) ) _check_inference( - bb, relax.op.permute_dims(x2, axes=None), relax.TensorStructInfo(dtype="float32") + bb, relax.op.permute_dims(x1, axes=None), relax.TensorType(dtype="float32", ndim=4) ) + _check_inference(bb, relax.op.permute_dims(x2, axes=None), relax.TensorType(dtype="float32")) _check_inference( - bb, relax.op.permute_dims(x3, [2, 3, 1, 0]), relax.TensorStructInfo((3, 4, 2, 1), dtype="") + bb, relax.op.permute_dims(x3, [2, 3, 1, 0]), relax.TensorType((3, 4, 2, 1), dtype="") ) _check_inference( - bb, relax.op.permute_dims(x3, axes=None), relax.TensorStructInfo((4, 3, 2, 1), dtype="") + bb, relax.op.permute_dims(x3, axes=None), relax.TensorType((4, 3, 2, 1), dtype="") ) _check_inference( bb, relax.op.permute_dims(x3, [-2, -3, 3, -4]), - relax.TensorStructInfo((3, 2, 4, 1), dtype=""), - ) - _check_inference( - bb, relax.op.permute_dims(x4, [2, 3, 1, 0]), relax.TensorStructInfo(dtype="", ndim=4) - ) - _check_inference( - bb, relax.op.permute_dims(x4, axes=None), relax.TensorStructInfo(dtype="", ndim=4) - ) - _check_inference(bb, relax.op.permute_dims(x5, axes=None), relax.TensorStructInfo(dtype="")) - _check_inference( - bb, relax.op.permute_dims(x6, axes=None), relax.TensorStructInfo((1,), "float32") + relax.TensorType((3, 2, 4, 1), dtype=""), ) _check_inference( - bb, relax.op.permute_dims(x7, axes=None), relax.TensorStructInfo((), "float32") + bb, relax.op.permute_dims(x4, [2, 3, 1, 0]), relax.TensorType(dtype="", ndim=4) ) + _check_inference(bb, relax.op.permute_dims(x4, axes=None), relax.TensorType(dtype="", ndim=4)) + _check_inference(bb, relax.op.permute_dims(x5, axes=None), relax.TensorType(dtype="")) + _check_inference(bb, relax.op.permute_dims(x6, axes=None), relax.TensorType((1,), "float32")) + _check_inference(bb, relax.op.permute_dims(x7, axes=None), relax.TensorType((), "float32")) -def test_permute_dims_infer_struct_info_shape_symbolic(): +def test_permute_dims_infer_ty_shape_symbolic(): bb = relax.BlockBuilder() a = tirx.Var("a", "int64") b = tirx.Var("b", "int64") @@ -398,75 +366,69 @@ def test_permute_dims_infer_struct_info_shape_symbolic(): x = relax.Var("x", R.Tensor((a, b, c, d), "float32")) _check_inference( - bb, relax.op.permute_dims(x, [2, 3, 1, 0]), relax.TensorStructInfo((c, d, b, a), "float32") + bb, relax.op.permute_dims(x, [2, 3, 1, 0]), relax.TensorType((c, d, b, a), "float32") ) _check_inference( - bb, relax.op.permute_dims(x, axes=None), relax.TensorStructInfo((d, c, b, a), "float32") + bb, relax.op.permute_dims(x, axes=None), relax.TensorType((d, c, b, a), "float32") ) _check_inference( bb, relax.op.permute_dims(x, [-2, -3, 3, -4]), - relax.TensorStructInfo((c, b, d, a), "float32"), + relax.TensorType((c, b, d, a), "float32"), ) -def test_permute_dims_infer_struct_info_shape_var(): +def test_permute_dims_infer_ty_shape_var(): bb = relax.BlockBuilder() - s0 = relax.Var("s", relax.ShapeStructInfo((1, 2, 3, 4))) - s1 = relax.Var("s", relax.ShapeStructInfo(ndim=4)) - s2 = relax.Var("s", relax.ShapeStructInfo()) - x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) - x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) - x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + s0 = relax.Var("s", relax.ShapeType((1, 2, 3, 4))) + s1 = relax.Var("s", relax.ShapeType(ndim=4)) + s2 = relax.Var("s", relax.ShapeType()) + x0 = relax.Var("x", relax.TensorType(s0, "float32")) + x1 = relax.Var("x", relax.TensorType(s1, "float32")) + x2 = relax.Var("x", relax.TensorType(s2, "float32")) + _check_inference(bb, relax.op.permute_dims(x0, [0, 1, 2, 3]), relax.TensorType(s0, "float32")) _check_inference( - bb, relax.op.permute_dims(x0, [0, 1, 2, 3]), relax.TensorStructInfo(s0, "float32") - ) - _check_inference( - bb, relax.op.permute_dims(x0, [-4, -3, -2, -1]), relax.TensorStructInfo(s0, "float32") - ) - _check_inference( - bb, relax.op.permute_dims(x0, [2, 3, 0, 1]), relax.TensorStructInfo(dtype="float32", ndim=4) + bb, relax.op.permute_dims(x0, [-4, -3, -2, -1]), relax.TensorType(s0, "float32") ) _check_inference( - bb, relax.op.permute_dims(x0, axes=None), relax.TensorStructInfo(dtype="float32", ndim=4) + bb, relax.op.permute_dims(x0, [2, 3, 0, 1]), relax.TensorType(dtype="float32", ndim=4) ) _check_inference( - bb, relax.op.permute_dims(x1, [0, 1, 2, 3]), relax.TensorStructInfo(s1, "float32") + bb, relax.op.permute_dims(x0, axes=None), relax.TensorType(dtype="float32", ndim=4) ) + _check_inference(bb, relax.op.permute_dims(x1, [0, 1, 2, 3]), relax.TensorType(s1, "float32")) _check_inference( - bb, relax.op.permute_dims(x1, [2, 3, 0, 1]), relax.TensorStructInfo(dtype="float32", ndim=4) + bb, relax.op.permute_dims(x1, [2, 3, 0, 1]), relax.TensorType(dtype="float32", ndim=4) ) _check_inference( - bb, relax.op.permute_dims(x1, axes=None), relax.TensorStructInfo(dtype="float32", ndim=4) - ) - _check_inference( - bb, relax.op.permute_dims(x2, axes=None), relax.TensorStructInfo(dtype="float32") + bb, relax.op.permute_dims(x1, axes=None), relax.TensorType(dtype="float32", ndim=4) ) + _check_inference(bb, relax.op.permute_dims(x2, axes=None), relax.TensorType(dtype="float32")) -def test_permute_dims_infer_struct_info_more_input_dtype(): +def test_permute_dims_infer_ty_more_input_dtype(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((1, 2, 3, 4), "float16")) x1 = relax.Var("x", R.Tensor((1, 2, 3, 4), "int8")) x2 = relax.Var("x", R.Tensor((1, 2, 3, 4), "int32")) _check_inference( - bb, relax.op.permute_dims(x0, [2, 3, 1, 0]), relax.TensorStructInfo((3, 4, 2, 1), "float16") + bb, relax.op.permute_dims(x0, [2, 3, 1, 0]), relax.TensorType((3, 4, 2, 1), "float16") ) _check_inference( - bb, relax.op.permute_dims(x1, [2, 3, 1, 0]), relax.TensorStructInfo((3, 4, 2, 1), "int8") + bb, relax.op.permute_dims(x1, [2, 3, 1, 0]), relax.TensorType((3, 4, 2, 1), "int8") ) _check_inference( - bb, relax.op.permute_dims(x2, [2, 3, 1, 0]), relax.TensorStructInfo((3, 4, 2, 1), "int32") + bb, relax.op.permute_dims(x2, [2, 3, 1, 0]), relax.TensorType((3, 4, 2, 1), "int32") ) -def test_permute_dims_infer_struct_info_unknown_ndim_with_axes(): +def test_permute_dims_infer_ty_unknown_ndim_with_axes(): bb = relax.BlockBuilder() - s = relax.Var("s", relax.ShapeStructInfo()) + s = relax.Var("s", relax.ShapeType()) x0 = relax.Var("x", R.Tensor("float32")) - x1 = relax.Var("x", relax.TensorStructInfo(s, "float32")) + x1 = relax.Var("x", relax.TensorType(s, "float32")) with pytest.raises(tvm.error.InternalError): bb.normalize(relax.op.permute_dims(x0, [2, 3, 1, 0])) @@ -474,14 +436,14 @@ def test_permute_dims_infer_struct_info_unknown_ndim_with_axes(): bb.normalize(relax.op.permute_dims(x1, [2, 3, 1, 0])) -def test_permute_dims_infer_struct_info_wrong_number_axes(): +def test_permute_dims_infer_ty_wrong_number_axes(): bb = relax.BlockBuilder() - s0 = relax.Var("s", relax.ShapeStructInfo((1, 2, 3, 4))) - s1 = relax.Var("s", relax.ShapeStructInfo(ndim=4)) + s0 = relax.Var("s", relax.ShapeType((1, 2, 3, 4))) + s1 = relax.Var("s", relax.ShapeType(ndim=4)) x0 = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32")) x1 = relax.Var("x", R.Tensor("float32", ndim=4)) - x2 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) - x3 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorType(s0, "float32")) + x3 = relax.Var("x", relax.TensorType(s1, "float32")) with pytest.raises(ValueError): bb.normalize(relax.op.permute_dims(x0, [0, 2, 1])) @@ -501,7 +463,7 @@ def test_permute_dims_infer_struct_info_wrong_number_axes(): bb.normalize(relax.op.permute_dims(x3, [1, 2, 4, 0, 3])) -def test_permute_dims_infer_struct_info_axis_out_of_range(): +def test_permute_dims_infer_ty_axis_out_of_range(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32")) x1 = relax.Var("x", R.Tensor("float32", ndim=4)) @@ -516,7 +478,7 @@ def test_permute_dims_infer_struct_info_axis_out_of_range(): bb.normalize(relax.op.permute_dims(x1, [0, -5, 1, 3])) -def test_permute_dims_infer_struct_info_repetitive_axes(): +def test_permute_dims_infer_ty_repetitive_axes(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32")) x1 = relax.Var("x", R.Tensor("float32", ndim=4)) @@ -531,10 +493,10 @@ def test_permute_dims_infer_struct_info_repetitive_axes(): bb.normalize(relax.op.permute_dims(x1, [0, 2, -2, 1])) -def test_permute_dims_infer_struct_info_wrong_input_type(): +def test_permute_dims_infer_ty_wrong_input_type(): bb = relax.BlockBuilder() - x0 = relax.Var("x", relax.ShapeStructInfo((1, 2, 3, 4))) - x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((1, 2, 3, 4), "float32"))) + x0 = relax.Var("x", relax.ShapeType((1, 2, 3, 4))) + x1 = relax.Var("x", relax.FuncType([], R.Tensor((1, 2, 3, 4), "float32"))) with pytest.raises(TypeError): bb.normalize(relax.op.permute_dims(x0)) @@ -542,7 +504,7 @@ def test_permute_dims_infer_struct_info_wrong_input_type(): bb.normalize(relax.op.permute_dims(x1)) -def test_expand_dims_infer_struct_info(): +def test_expand_dims_infer_ty(): bb = relax.BlockBuilder() vdev0 = VDevice("llvm") x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32")) @@ -554,105 +516,103 @@ def test_expand_dims_infer_struct_info(): x6 = relax.Var("x", R.Tensor((2, 3, 4), "float32", vdev0)) _check_inference( - bb, relax.op.expand_dims(x0, [1, 3]), relax.TensorStructInfo((2, 1, 3, 1, 4), "float32") + bb, relax.op.expand_dims(x0, [1, 3]), relax.TensorType((2, 1, 3, 1, 4), "float32") ) _check_inference( bb, relax.op.expand_dims(x6, [1, 3]), - relax.TensorStructInfo((2, 1, 3, 1, 4), "float32", vdev0), + relax.TensorType((2, 1, 3, 1, 4), "float32", vdev0), ) _check_inference( bb, relax.op.expand_dims(x0, [-1, 1, -6, 3, 5]), - relax.TensorStructInfo((2, 1, 1, 1, 3, 1, 4, 1), "float32"), - ) - _check_inference(bb, relax.op.expand_dims(x0, []), relax.TensorStructInfo((2, 3, 4), "float32")) - _check_inference( - bb, relax.op.expand_dims(x1, [1, 3]), relax.TensorStructInfo(dtype="float32", ndim=5) + relax.TensorType((2, 1, 1, 1, 3, 1, 4, 1), "float32"), ) + _check_inference(bb, relax.op.expand_dims(x0, []), relax.TensorType((2, 3, 4), "float32")) _check_inference( - bb, relax.op.expand_dims(x1, []), relax.TensorStructInfo(dtype="float32", ndim=3) + bb, relax.op.expand_dims(x1, [1, 3]), relax.TensorType(dtype="float32", ndim=5) ) - _check_inference(bb, relax.op.expand_dims(x2, [1, 3]), relax.TensorStructInfo(dtype="float32")) - _check_inference(bb, relax.op.expand_dims(x2, []), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.expand_dims(x1, []), relax.TensorType(dtype="float32", ndim=3)) + _check_inference(bb, relax.op.expand_dims(x2, [1, 3]), relax.TensorType(dtype="float32")) + _check_inference(bb, relax.op.expand_dims(x2, []), relax.TensorType(dtype="float32")) _check_inference( - bb, relax.op.expand_dims(x3, [1, 3]), relax.TensorStructInfo((2, 1, 3, 1, 4), dtype="") + bb, relax.op.expand_dims(x3, [1, 3]), relax.TensorType((2, 1, 3, 1, 4), dtype="") ) _check_inference( bb, relax.op.expand_dims(x3, [-1, 1, -6, 3, 5]), - relax.TensorStructInfo((2, 1, 1, 1, 3, 1, 4, 1), dtype=""), + relax.TensorType((2, 1, 1, 1, 3, 1, 4, 1), dtype=""), ) - _check_inference(bb, relax.op.expand_dims(x3, []), relax.TensorStructInfo((2, 3, 4), dtype="")) - _check_inference(bb, relax.op.expand_dims(x4, [1, 3]), relax.TensorStructInfo(dtype="", ndim=5)) - _check_inference(bb, relax.op.expand_dims(x4, []), relax.TensorStructInfo(dtype="", ndim=3)) - _check_inference(bb, relax.op.expand_dims(x5, [1, 3]), relax.TensorStructInfo(dtype="")) - _check_inference(bb, relax.op.expand_dims(x5, []), relax.TensorStructInfo(dtype="")) + _check_inference(bb, relax.op.expand_dims(x3, []), relax.TensorType((2, 3, 4), dtype="")) + _check_inference(bb, relax.op.expand_dims(x4, [1, 3]), relax.TensorType(dtype="", ndim=5)) + _check_inference(bb, relax.op.expand_dims(x4, []), relax.TensorType(dtype="", ndim=3)) + _check_inference(bb, relax.op.expand_dims(x5, [1, 3]), relax.TensorType(dtype="")) + _check_inference(bb, relax.op.expand_dims(x5, []), relax.TensorType(dtype="")) -def test_expand_dims_infer_struct_info_shape_symbolic(): +def test_expand_dims_infer_ty_shape_symbolic(): bb = relax.BlockBuilder() a = tirx.Var("a", "int64") b = tirx.Var("b", "int64") x = relax.Var("x", R.Tensor((a, 4, b), "float32")) _check_inference( - bb, relax.op.expand_dims(x, [1, 3]), relax.TensorStructInfo((a, 1, 4, 1, b), "float32") + bb, relax.op.expand_dims(x, [1, 3]), relax.TensorType((a, 1, 4, 1, b), "float32") ) _check_inference( bb, relax.op.expand_dims(x, [-1, 1, -6, 3, 5]), - relax.TensorStructInfo((a, 1, 1, 1, 4, 1, b, 1), "float32"), + relax.TensorType((a, 1, 1, 1, 4, 1, b, 1), "float32"), ) - _check_inference(bb, relax.op.expand_dims(x, []), relax.TensorStructInfo((a, 4, b), "float32")) + _check_inference(bb, relax.op.expand_dims(x, []), relax.TensorType((a, 4, b), "float32")) -def test_expand_dims_infer_struct_info_shape_var(): +def test_expand_dims_infer_ty_shape_var(): bb = relax.BlockBuilder() - s0 = relax.Var("s", relax.ShapeStructInfo((2, 3, 4))) - s1 = relax.Var("s", relax.ShapeStructInfo(ndim=3)) - s2 = relax.Var("s", relax.ShapeStructInfo()) - x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) - x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) - x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + s0 = relax.Var("s", relax.ShapeType((2, 3, 4))) + s1 = relax.Var("s", relax.ShapeType(ndim=3)) + s2 = relax.Var("s", relax.ShapeType()) + x0 = relax.Var("x", relax.TensorType(s0, "float32")) + x1 = relax.Var("x", relax.TensorType(s1, "float32")) + x2 = relax.Var("x", relax.TensorType(s2, "float32")) _check_inference( - bb, relax.op.expand_dims(x0, [1, 3]), relax.TensorStructInfo(dtype="float32", ndim=5) + bb, relax.op.expand_dims(x0, [1, 3]), relax.TensorType(dtype="float32", ndim=5) ) - _check_inference(bb, relax.op.expand_dims(x0, []), relax.TensorStructInfo(s0, "float32")) + _check_inference(bb, relax.op.expand_dims(x0, []), relax.TensorType(s0, "float32")) _check_inference( - bb, relax.op.expand_dims(x1, [1, 3]), relax.TensorStructInfo(dtype="float32", ndim=5) + bb, relax.op.expand_dims(x1, [1, 3]), relax.TensorType(dtype="float32", ndim=5) ) - _check_inference(bb, relax.op.expand_dims(x1, []), relax.TensorStructInfo(s1, "float32")) - _check_inference(bb, relax.op.expand_dims(x2, [1, 3]), relax.TensorStructInfo(dtype="float32")) - _check_inference(bb, relax.op.expand_dims(x2, []), relax.TensorStructInfo(s2, "float32")) + _check_inference(bb, relax.op.expand_dims(x1, []), relax.TensorType(s1, "float32")) + _check_inference(bb, relax.op.expand_dims(x2, [1, 3]), relax.TensorType(dtype="float32")) + _check_inference(bb, relax.op.expand_dims(x2, []), relax.TensorType(s2, "float32")) -def test_expand_dims_infer_struct_info_more_input_dtype(): +def test_expand_dims_infer_ty_more_input_dtype(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3, 4), "float16")) x1 = relax.Var("x", R.Tensor((2, 3, 4), "int8")) x2 = relax.Var("x", R.Tensor((2, 3, 4), "int32")) _check_inference( - bb, relax.op.expand_dims(x0, [1, 3]), relax.TensorStructInfo((2, 1, 3, 1, 4), "float16") + bb, relax.op.expand_dims(x0, [1, 3]), relax.TensorType((2, 1, 3, 1, 4), "float16") ) _check_inference( - bb, relax.op.expand_dims(x1, [1, 3]), relax.TensorStructInfo((2, 1, 3, 1, 4), "int8") + bb, relax.op.expand_dims(x1, [1, 3]), relax.TensorType((2, 1, 3, 1, 4), "int8") ) _check_inference( - bb, relax.op.expand_dims(x2, [1, 3]), relax.TensorStructInfo((2, 1, 3, 1, 4), "int32") + bb, relax.op.expand_dims(x2, [1, 3]), relax.TensorType((2, 1, 3, 1, 4), "int32") ) -def test_expand_dims_infer_struct_info_axis_out_of_range(): +def test_expand_dims_infer_ty_axis_out_of_range(): bb = relax.BlockBuilder() - s0 = relax.Var("s", relax.ShapeStructInfo((2, 3, 4))) - s1 = relax.Var("s", relax.ShapeStructInfo(ndim=3)) + s0 = relax.Var("s", relax.ShapeType((2, 3, 4))) + s1 = relax.Var("s", relax.ShapeType(ndim=3)) x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32")) x1 = relax.Var("x", R.Tensor("float32", ndim=3)) - x2 = relax.Var("x", relax.TensorStructInfo(s0)) - x3 = relax.Var("x", relax.TensorStructInfo(s1)) + x2 = relax.Var("x", relax.TensorType(s0)) + x3 = relax.Var("x", relax.TensorType(s1)) with pytest.raises(ValueError): bb.normalize(relax.op.expand_dims(x0, [1, 5])) @@ -672,14 +632,14 @@ def test_expand_dims_infer_struct_info_axis_out_of_range(): bb.normalize(relax.op.expand_dims(x3, [-6, 1])) -def test_expand_dims_infer_struct_info_repetitive_axes(): +def test_expand_dims_infer_ty_repetitive_axes(): bb = relax.BlockBuilder() - s0 = relax.Var("s", relax.ShapeStructInfo((2, 3, 4))) - s1 = relax.Var("s", relax.ShapeStructInfo(ndim=3)) + s0 = relax.Var("s", relax.ShapeType((2, 3, 4))) + s1 = relax.Var("s", relax.ShapeType(ndim=3)) x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32")) x1 = relax.Var("x", R.Tensor("float32", ndim=3)) - x2 = relax.Var("x", relax.TensorStructInfo(s0)) - x3 = relax.Var("x", relax.TensorStructInfo(s1)) + x2 = relax.Var("x", relax.TensorType(s0)) + x3 = relax.Var("x", relax.TensorType(s1)) with pytest.raises(ValueError): bb.normalize(relax.op.expand_dims(x0, [1, 1])) @@ -699,10 +659,10 @@ def test_expand_dims_infer_struct_info_repetitive_axes(): bb.normalize(relax.op.expand_dims(x3, [1, -4])) -def test_expand_dims_infer_struct_info_wrong_input_type(): +def test_expand_dims_infer_ty_wrong_input_type(): bb = relax.BlockBuilder() - x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 4))) - x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 4), "float32"))) + x0 = relax.Var("x", relax.ShapeType((2, 3, 4))) + x1 = relax.Var("x", relax.FuncType([], R.Tensor((2, 3, 4), "float32"))) with pytest.raises(TypeError): bb.normalize(relax.op.expand_dims(x0, axis=[])) @@ -710,7 +670,7 @@ def test_expand_dims_infer_struct_info_wrong_input_type(): bb.normalize(relax.op.expand_dims(x1, axis=[])) -def test_layout_transform_infer_struct_info(): +def test_layout_transform_infer_ty(): bb = relax.BlockBuilder() vdev0 = VDevice("llvm") x = relax.Var("x", R.Tensor((10, 20, 30), "float32")) @@ -720,37 +680,37 @@ def test_layout_transform_infer_struct_info(): _check_inference( bb, relax.op.layout_transform(x, index_map=transpose_transform), - relax.TensorStructInfo((10, 30, 20), "float32"), + relax.TensorType((10, 30, 20), "float32"), ) _check_inference( bb, relax.op.layout_transform(x1, index_map=transpose_transform), - relax.TensorStructInfo((10, 30, 20), "float32", vdev0), + relax.TensorType((10, 30, 20), "float32", vdev0), ) tiling_transform = lambda a, b, c: (a, b // 2, c, b % 2) _check_inference( bb, relax.op.layout_transform(x, index_map=tiling_transform), - relax.TensorStructInfo((10, 10, 30, 2), "float32"), + relax.TensorType((10, 10, 30, 2), "float32"), ) implicit_padding_transform = lambda a, b, c: (a, c, b // 3, b % 3) _check_inference( bb, relax.op.layout_transform(x, index_map=implicit_padding_transform, pad_value=2), - relax.TensorStructInfo((10, 30, 7, 3), "float32"), + relax.TensorType((10, 30, 7, 3), "float32"), ) flatten_transform = lambda a, b, c: a * 600 + b * 30 + c _check_inference( bb, relax.op.layout_transform(x, index_map=flatten_transform), - relax.TensorStructInfo((6000,), "float32"), + relax.TensorType((6000,), "float32"), ) -def test_layout_transform_infer_struct_info_mismatch_dtype(): +def test_layout_transform_infer_ty_mismatch_dtype(): bb = relax.BlockBuilder() x = relax.Var("x", R.Tensor((10, 20, 30), "int32")) @@ -759,7 +719,7 @@ def test_layout_transform_infer_struct_info_mismatch_dtype(): bb.normalize(relax.op.layout_transform(x, index_map=transpose_transform, pad_value=2.2)) -def test_layout_transform_infer_struct_info_unknown_shape(): +def test_layout_transform_infer_ty_unknown_shape(): bb = relax.BlockBuilder() tiling_transform = lambda a, b: (a, b // 2, b % 2) @@ -767,18 +727,18 @@ def test_layout_transform_infer_struct_info_unknown_shape(): _check_inference( bb, relax.op.layout_transform(x_unknown_shape, index_map=tiling_transform), - relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorType(dtype="float32", ndim=3), ) x_unknown_rank_dtype = relax.Var("x", R.Tensor()) _check_inference( bb, relax.op.layout_transform(x_unknown_rank_dtype, index_map=tiling_transform), - relax.TensorStructInfo(dtype="", ndim=3), + relax.TensorType(dtype="", ndim=3), ) -def test_layout_transform_infer_struct_info_symbolic_shape(): +def test_layout_transform_infer_ty_symbolic_shape(): bb = relax.BlockBuilder() a = tirx.Var("a", "int64") b = tirx.Var("b", "int64") @@ -788,50 +748,50 @@ def test_layout_transform_infer_struct_info_symbolic_shape(): _check_inference( bb, relax.op.layout_transform(x0, index_map=tiling_transform), - relax.TensorStructInfo((a, (b - b % (-3)) // 3, 3), "float32"), + relax.TensorType((a, (b - b % (-3)) // 3, 3), "float32"), ) -def test_layout_transform_infer_struct_info_shape_var(): +def test_layout_transform_infer_ty_shape_var(): bb = relax.BlockBuilder() - s = relax.Var("s", relax.ShapeStructInfo((30, 20))) - x = relax.Var("x", relax.TensorStructInfo(s, "float32")) + s = relax.Var("s", relax.ShapeType((30, 20))) + x = relax.Var("x", relax.TensorType(s, "float32")) tiling_padding_transform = lambda a, b: (a, b // 3, b % 3) _check_inference( bb, relax.op.layout_transform(x, index_map=tiling_padding_transform), - relax.TensorStructInfo((30, 7, 3), "float32"), + relax.TensorType((30, 7, 3), "float32"), ) - s_unknown_shape = relax.Var("s", relax.ShapeStructInfo(ndim=2)) - x_unknown_shape = relax.Var("x", relax.TensorStructInfo(s_unknown_shape, "float32")) + s_unknown_shape = relax.Var("s", relax.ShapeType(ndim=2)) + x_unknown_shape = relax.Var("x", relax.TensorType(s_unknown_shape, "float32")) _check_inference( bb, relax.op.layout_transform(x_unknown_shape, index_map=tiling_padding_transform), - relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorType(dtype="float32", ndim=3), ) - s_unknown_rank = relax.Var("s", relax.ShapeStructInfo()) - x_unknown_rank = relax.Var("x", relax.TensorStructInfo(s_unknown_rank, "float32")) + s_unknown_rank = relax.Var("s", relax.ShapeType()) + x_unknown_rank = relax.Var("x", relax.TensorType(s_unknown_rank, "float32")) _check_inference( bb, relax.op.layout_transform(x_unknown_rank, index_map=tiling_padding_transform), - relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorType(dtype="float32", ndim=3), ) a = tirx.Var("a", "int64") b = tirx.Var("b", "int64") - s_symbolic_shape = relax.Var("s", relax.ShapeStructInfo((a, b))) - x_symbolic_shape = relax.Var("x", relax.TensorStructInfo(s_symbolic_shape, "float32")) + s_symbolic_shape = relax.Var("s", relax.ShapeType((a, b))) + x_symbolic_shape = relax.Var("x", relax.TensorType(s_symbolic_shape, "float32")) _check_inference( bb, relax.op.layout_transform(x_symbolic_shape, index_map=tiling_padding_transform), - relax.TensorStructInfo((a, (b - b % (-3)) // 3, 3), "float32"), + relax.TensorType((a, (b - b % (-3)) // 3, 3), "float32"), ) -def test_layout_transform_infer_struct_info_invalid_index_map(): +def test_layout_transform_infer_ty_invalid_index_map(): bb = relax.BlockBuilder() x = relax.Var("x", R.Tensor((10, 20, 30), "float32")) @@ -839,7 +799,7 @@ def test_layout_transform_infer_struct_info_invalid_index_map(): bb.normalize(relax.op.layout_transform(x, index_map=lambda a, b: (b, a))) -def test_squeeze_infer_struct_info(): +def test_squeeze_infer_ty(): bb = relax.BlockBuilder() vdev0 = VDevice("llvm") x0 = relax.Var("x", R.Tensor((2, 1, 3, 1, 1, 4), "float32")) @@ -850,96 +810,86 @@ def test_squeeze_infer_struct_info(): x5 = relax.Var("x", R.Tensor()) x6 = relax.Var("x", R.Tensor((2, 1, 3, 1, 1, 4), "float32", vdev0)) + _check_inference(bb, relax.op.squeeze(x0, [1, 4]), relax.TensorType((2, 3, 1, 4), "float32")) _check_inference( - bb, relax.op.squeeze(x0, [1, 4]), relax.TensorStructInfo((2, 3, 1, 4), "float32") + bb, relax.op.squeeze(x6, [1, 4]), relax.TensorType((2, 3, 1, 4), "float32", vdev0) ) - _check_inference( - bb, relax.op.squeeze(x6, [1, 4]), relax.TensorStructInfo((2, 3, 1, 4), "float32", vdev0) - ) - _check_inference(bb, relax.op.squeeze(x0), relax.TensorStructInfo((2, 3, 4), "float32")) - _check_inference( - bb, relax.op.squeeze(x1, [1, 4]), relax.TensorStructInfo(dtype="float32", ndim=4) - ) - _check_inference(bb, relax.op.squeeze(x1), relax.TensorStructInfo(dtype="float32")) - _check_inference(bb, relax.op.squeeze(x2, [1, 4]), relax.TensorStructInfo(dtype="float32")) - _check_inference(bb, relax.op.squeeze(x2), relax.TensorStructInfo(dtype="float32")) - _check_inference( - bb, relax.op.squeeze(x3, [1, 4]), relax.TensorStructInfo((2, 3, 1, 4), dtype="") - ) - _check_inference(bb, relax.op.squeeze(x3), relax.TensorStructInfo((2, 3, 4), dtype="")) - _check_inference(bb, relax.op.squeeze(x4, [1, 4]), relax.TensorStructInfo(dtype="", ndim=4)) - _check_inference(bb, relax.op.squeeze(x4), relax.TensorStructInfo(dtype="")) - _check_inference(bb, relax.op.squeeze(x5, [1, 4]), relax.TensorStructInfo(dtype="")) - _check_inference(bb, relax.op.squeeze(x5), relax.TensorStructInfo(dtype="")) + _check_inference(bb, relax.op.squeeze(x0), relax.TensorType((2, 3, 4), "float32")) + _check_inference(bb, relax.op.squeeze(x1, [1, 4]), relax.TensorType(dtype="float32", ndim=4)) + _check_inference(bb, relax.op.squeeze(x1), relax.TensorType(dtype="float32")) + _check_inference(bb, relax.op.squeeze(x2, [1, 4]), relax.TensorType(dtype="float32")) + _check_inference(bb, relax.op.squeeze(x2), relax.TensorType(dtype="float32")) + _check_inference(bb, relax.op.squeeze(x3, [1, 4]), relax.TensorType((2, 3, 1, 4), dtype="")) + _check_inference(bb, relax.op.squeeze(x3), relax.TensorType((2, 3, 4), dtype="")) + _check_inference(bb, relax.op.squeeze(x4, [1, 4]), relax.TensorType(dtype="", ndim=4)) + _check_inference(bb, relax.op.squeeze(x4), relax.TensorType(dtype="")) + _check_inference(bb, relax.op.squeeze(x5, [1, 4]), relax.TensorType(dtype="")) + _check_inference(bb, relax.op.squeeze(x5), relax.TensorType(dtype="")) -def test_squeeze_infer_struct_info_shape_symbolic(): +def test_squeeze_infer_ty_shape_symbolic(): bb = relax.BlockBuilder() a = tirx.Var("a", "int64") b = tirx.Var("b", "int64") x0 = relax.Var("x", R.Tensor((a, 1, b), "float32")) x1 = relax.Var("x", R.Tensor((a, 1, b))) - _check_inference(bb, relax.op.squeeze(x0, [1]), relax.TensorStructInfo((a, b), "float32")) - _check_inference(bb, relax.op.squeeze(x0), relax.TensorStructInfo(dtype="float32")) - _check_inference(bb, relax.op.squeeze(x1, [1]), relax.TensorStructInfo((a, b), dtype="")) - _check_inference(bb, relax.op.squeeze(x1), relax.TensorStructInfo(dtype="")) + _check_inference(bb, relax.op.squeeze(x0, [1]), relax.TensorType((a, b), "float32")) + _check_inference(bb, relax.op.squeeze(x0), relax.TensorType(dtype="float32")) + _check_inference(bb, relax.op.squeeze(x1, [1]), relax.TensorType((a, b), dtype="")) + _check_inference(bb, relax.op.squeeze(x1), relax.TensorType(dtype="")) -def test_squeeze_infer_struct_info_shape_var(): +def test_squeeze_infer_ty_shape_var(): bb = relax.BlockBuilder() a = tirx.Var("a", "int64") b = tirx.Var("b", "int64") - s0 = relax.Var("s", relax.ShapeStructInfo((2, 1, 3, 1, 1, 4))) - s1 = relax.Var("s", relax.ShapeStructInfo((2, 3, 4))) - s2 = relax.Var("s", relax.ShapeStructInfo((a, 1, b))) - s3 = relax.Var("s", relax.ShapeStructInfo(ndim=6)) - s4 = relax.Var("s", relax.ShapeStructInfo()) - x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) - x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) - x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) - x3 = relax.Var("x", relax.TensorStructInfo(s3, "float32")) - x4 = relax.Var("x", relax.TensorStructInfo(s4, "float32")) - - _check_inference( - bb, relax.op.squeeze(x0, [1, 4]), relax.TensorStructInfo(dtype="float32", ndim=4) - ) - _check_inference(bb, relax.op.squeeze(x0, []), relax.TensorStructInfo(s0, "float32")) - _check_inference(bb, relax.op.squeeze(x0), relax.TensorStructInfo(dtype="float32")) - _check_inference(bb, relax.op.squeeze(x1, []), relax.TensorStructInfo(s1, "float32")) - _check_inference(bb, relax.op.squeeze(x1), relax.TensorStructInfo(s1, dtype="float32")) - _check_inference(bb, relax.op.squeeze(x2, [1]), relax.TensorStructInfo(dtype="float32", ndim=2)) - _check_inference(bb, relax.op.squeeze(x2, []), relax.TensorStructInfo(s2, "float32")) - _check_inference(bb, relax.op.squeeze(x2), relax.TensorStructInfo(dtype="float32")) - _check_inference( - bb, relax.op.squeeze(x3, [1, 4]), relax.TensorStructInfo(dtype="float32", ndim=4) - ) - _check_inference(bb, relax.op.squeeze(x3, []), relax.TensorStructInfo(s3, "float32")) - _check_inference(bb, relax.op.squeeze(x3), relax.TensorStructInfo(dtype="float32")) - _check_inference(bb, relax.op.squeeze(x4, [1, 4]), relax.TensorStructInfo(dtype="float32")) - _check_inference(bb, relax.op.squeeze(x4, []), relax.TensorStructInfo(s4, "float32")) - _check_inference(bb, relax.op.squeeze(x4), relax.TensorStructInfo(dtype="float32")) - - -def test_squeeze_infer_struct_info_more_input_dtype(): + s0 = relax.Var("s", relax.ShapeType((2, 1, 3, 1, 1, 4))) + s1 = relax.Var("s", relax.ShapeType((2, 3, 4))) + s2 = relax.Var("s", relax.ShapeType((a, 1, b))) + s3 = relax.Var("s", relax.ShapeType(ndim=6)) + s4 = relax.Var("s", relax.ShapeType()) + x0 = relax.Var("x", relax.TensorType(s0, "float32")) + x1 = relax.Var("x", relax.TensorType(s1, "float32")) + x2 = relax.Var("x", relax.TensorType(s2, "float32")) + x3 = relax.Var("x", relax.TensorType(s3, "float32")) + x4 = relax.Var("x", relax.TensorType(s4, "float32")) + + _check_inference(bb, relax.op.squeeze(x0, [1, 4]), relax.TensorType(dtype="float32", ndim=4)) + _check_inference(bb, relax.op.squeeze(x0, []), relax.TensorType(s0, "float32")) + _check_inference(bb, relax.op.squeeze(x0), relax.TensorType(dtype="float32")) + _check_inference(bb, relax.op.squeeze(x1, []), relax.TensorType(s1, "float32")) + _check_inference(bb, relax.op.squeeze(x1), relax.TensorType(s1, dtype="float32")) + _check_inference(bb, relax.op.squeeze(x2, [1]), relax.TensorType(dtype="float32", ndim=2)) + _check_inference(bb, relax.op.squeeze(x2, []), relax.TensorType(s2, "float32")) + _check_inference(bb, relax.op.squeeze(x2), relax.TensorType(dtype="float32")) + _check_inference(bb, relax.op.squeeze(x3, [1, 4]), relax.TensorType(dtype="float32", ndim=4)) + _check_inference(bb, relax.op.squeeze(x3, []), relax.TensorType(s3, "float32")) + _check_inference(bb, relax.op.squeeze(x3), relax.TensorType(dtype="float32")) + _check_inference(bb, relax.op.squeeze(x4, [1, 4]), relax.TensorType(dtype="float32")) + _check_inference(bb, relax.op.squeeze(x4, []), relax.TensorType(s4, "float32")) + _check_inference(bb, relax.op.squeeze(x4), relax.TensorType(dtype="float32")) + + +def test_squeeze_infer_ty_more_input_dtype(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 1, 3, 1, 1, 4), "float16")) x1 = relax.Var("x", R.Tensor((2, 1, 3, 1, 1, 4), "int8")) x2 = relax.Var("x", R.Tensor((2, 1, 3, 1, 1, 4), "int32")) - _check_inference(bb, relax.op.squeeze(x0), relax.TensorStructInfo((2, 3, 4), "float16")) - _check_inference(bb, relax.op.squeeze(x1), relax.TensorStructInfo((2, 3, 4), "int8")) - _check_inference(bb, relax.op.squeeze(x2), relax.TensorStructInfo((2, 3, 4), "int32")) + _check_inference(bb, relax.op.squeeze(x0), relax.TensorType((2, 3, 4), "float16")) + _check_inference(bb, relax.op.squeeze(x1), relax.TensorType((2, 3, 4), "int8")) + _check_inference(bb, relax.op.squeeze(x2), relax.TensorType((2, 3, 4), "int32")) -def test_squeeze_infer_struct_info_axis_out_of_range(): +def test_squeeze_infer_ty_axis_out_of_range(): bb = relax.BlockBuilder() - s0 = relax.Var("s", relax.ShapeStructInfo((2, 1, 3, 1, 1, 4))) - s1 = relax.Var("s", relax.ShapeStructInfo(ndim=6)) + s0 = relax.Var("s", relax.ShapeType((2, 1, 3, 1, 1, 4))) + s1 = relax.Var("s", relax.ShapeType(ndim=6)) x0 = relax.Var("x", R.Tensor((2, 1, 3, 1, 1, 4), "float32")) x1 = relax.Var("x", R.Tensor("float32", ndim=6)) - x2 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) - x3 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorType(s0, "float32")) + x3 = relax.Var("x", relax.TensorType(s1, "float32")) with pytest.raises(ValueError): bb.normalize(relax.op.squeeze(x0, [6])) @@ -959,14 +909,14 @@ def test_squeeze_infer_struct_info_axis_out_of_range(): bb.normalize(relax.op.squeeze(x3, [-7])) -def test_squeeze_infer_struct_info_repetitive_axes(): +def test_squeeze_infer_ty_repetitive_axes(): bb = relax.BlockBuilder() - s0 = relax.Var("s", relax.ShapeStructInfo((2, 1, 3, 1, 1, 4))) - s1 = relax.Var("s", relax.ShapeStructInfo(ndim=6)) + s0 = relax.Var("s", relax.ShapeType((2, 1, 3, 1, 1, 4))) + s1 = relax.Var("s", relax.ShapeType(ndim=6)) x0 = relax.Var("x", R.Tensor((2, 1, 3, 1, 1, 4), "float32")) x1 = relax.Var("x", R.Tensor("float32", ndim=6)) - x2 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) - x3 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorType(s0, "float32")) + x3 = relax.Var("x", relax.TensorType(s1, "float32")) with pytest.raises(ValueError): bb.normalize(relax.op.squeeze(x0, [3, -3])) @@ -986,36 +936,32 @@ def test_squeeze_infer_struct_info_repetitive_axes(): bb.normalize(relax.op.squeeze(x3, [1, 1])) -def test_squeeze_infer_struct_info_axis_length_not_one(): +def test_squeeze_infer_ty_axis_length_not_one(): bb = relax.BlockBuilder() a = tirx.Var("a", "int64") - s0 = relax.Var("s", relax.ShapeStructInfo((2, 3, 4))) - s1 = relax.Var("s", relax.ShapeStructInfo((a, 3, 4))) + s0 = relax.Var("s", relax.ShapeType((2, 3, 4))) + s1 = relax.Var("s", relax.ShapeType((a, 3, 4))) x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32")) x1 = relax.Var("x", R.Tensor((a, 3, 4), "float32")) - x2 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) - x3 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorType(s0, "float32")) + x3 = relax.Var("x", relax.TensorType(s1, "float32")) # Squeeze concrete shape (2,3,4) at axis=0, but axis length 2 != 1, squeeze is no-op. _check_inference( - bb, relax.op.squeeze(x0, [0]), relax.TensorStructInfo(shape=(2, 3, 4), dtype="float32") + bb, relax.op.squeeze(x0, [0]), relax.TensorType(shape=(2, 3, 4), dtype="float32") ) # Squeeze symbolic shape (a,3,4) at axis=0, assuming a can achieve successful squeeze. - _check_inference( - bb, relax.op.squeeze(x1, [0]), relax.TensorStructInfo(shape=(3, 4), dtype="float32") - ) + _check_inference(bb, relax.op.squeeze(x1, [0]), relax.TensorType(shape=(3, 4), dtype="float32")) # Squeeze shape variable s0 (corresponding to (2,3,4)) at axis=0. - _check_inference( - bb, relax.op.squeeze(x2, [0]), relax.TensorStructInfo(shape=s0, dtype="float32") - ) + _check_inference(bb, relax.op.squeeze(x2, [0]), relax.TensorType(shape=s0, dtype="float32")) # Squeeze shape variable s1 (a,3,4) at axis=0, assuming a can achieve successful squeeze. - _check_inference(bb, relax.op.squeeze(x3, [0]), relax.TensorStructInfo(dtype="float32", ndim=2)) + _check_inference(bb, relax.op.squeeze(x3, [0]), relax.TensorType(dtype="float32", ndim=2)) -def test_squeeze_infer_struct_info_wrong_input_type(): +def test_squeeze_infer_ty_wrong_input_type(): bb = relax.BlockBuilder() - x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 4))) - x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 4), "float32"))) + x0 = relax.Var("x", relax.ShapeType((2, 3, 4))) + x1 = relax.Var("x", relax.FuncType([], R.Tensor((2, 3, 4), "float32"))) with pytest.raises(TypeError): bb.normalize(relax.op.squeeze(x0)) @@ -1023,7 +969,7 @@ def test_squeeze_infer_struct_info_wrong_input_type(): bb.normalize(relax.op.squeeze(x1)) -def test_flatten_infer_struct_info(): +def test_flatten_infer_ty(): bb = relax.BlockBuilder() vdev0 = VDevice("llvm") x0 = relax.Var("x", R.Tensor((3, 4, 5), "float32")) @@ -1042,75 +988,75 @@ def test_flatten_infer_struct_info(): x13 = relax.Var("x", R.Tensor()) x14 = relax.Var("x", R.Tensor((3, 4, 5), "float32", vdev0)) - _check_inference(bb, relax.op.flatten(x0), relax.TensorStructInfo((60,), "float32")) - _check_inference(bb, relax.op.flatten(x14), relax.TensorStructInfo((60,), "float32", vdev0)) - _check_inference(bb, relax.op.flatten(x1), relax.TensorStructInfo((3,), "float32")) - _check_inference(bb, relax.op.flatten(x2), relax.TensorStructInfo((1,), "float32")) - _check_inference(bb, relax.op.flatten(x3), relax.TensorStructInfo(dtype="float32", ndim=1)) - _check_inference(bb, relax.op.flatten(x4), relax.TensorStructInfo(dtype="float32", ndim=1)) - _check_inference(bb, relax.op.flatten(x5), relax.TensorStructInfo((1,), "float32")) - _check_inference(bb, relax.op.flatten(x6), relax.TensorStructInfo(dtype="float32", ndim=1)) - _check_inference(bb, relax.op.flatten(x7), relax.TensorStructInfo((60,), dtype="")) - _check_inference(bb, relax.op.flatten(x8), relax.TensorStructInfo((3,), dtype="")) - _check_inference(bb, relax.op.flatten(x9), relax.TensorStructInfo((1,), dtype="")) - _check_inference(bb, relax.op.flatten(x10), relax.TensorStructInfo(dtype="", ndim=1)) - _check_inference(bb, relax.op.flatten(x11), relax.TensorStructInfo(dtype="", ndim=1)) - _check_inference(bb, relax.op.flatten(x12), relax.TensorStructInfo((1,), dtype="")) - _check_inference(bb, relax.op.flatten(x13), relax.TensorStructInfo(dtype="", ndim=1)) - - -def test_flatten_infer_struct_info_shape_symbolic(): + _check_inference(bb, relax.op.flatten(x0), relax.TensorType((60,), "float32")) + _check_inference(bb, relax.op.flatten(x14), relax.TensorType((60,), "float32", vdev0)) + _check_inference(bb, relax.op.flatten(x1), relax.TensorType((3,), "float32")) + _check_inference(bb, relax.op.flatten(x2), relax.TensorType((1,), "float32")) + _check_inference(bb, relax.op.flatten(x3), relax.TensorType(dtype="float32", ndim=1)) + _check_inference(bb, relax.op.flatten(x4), relax.TensorType(dtype="float32", ndim=1)) + _check_inference(bb, relax.op.flatten(x5), relax.TensorType((1,), "float32")) + _check_inference(bb, relax.op.flatten(x6), relax.TensorType(dtype="float32", ndim=1)) + _check_inference(bb, relax.op.flatten(x7), relax.TensorType((60,), dtype="")) + _check_inference(bb, relax.op.flatten(x8), relax.TensorType((3,), dtype="")) + _check_inference(bb, relax.op.flatten(x9), relax.TensorType((1,), dtype="")) + _check_inference(bb, relax.op.flatten(x10), relax.TensorType(dtype="", ndim=1)) + _check_inference(bb, relax.op.flatten(x11), relax.TensorType(dtype="", ndim=1)) + _check_inference(bb, relax.op.flatten(x12), relax.TensorType((1,), dtype="")) + _check_inference(bb, relax.op.flatten(x13), relax.TensorType(dtype="", ndim=1)) + + +def test_flatten_infer_ty_shape_symbolic(): bb = relax.BlockBuilder() a = tirx.Var("a", "int64") b = tirx.Var("b", "int64") x0 = relax.Var("x", R.Tensor((a, b), "float32")) x1 = relax.Var("x", R.Tensor((a, b))) - _check_inference(bb, relax.op.flatten(x0), relax.TensorStructInfo((a * b,), "float32")) - _check_inference(bb, relax.op.flatten(x1), relax.TensorStructInfo((a * b,), dtype="")) + _check_inference(bb, relax.op.flatten(x0), relax.TensorType((a * b,), "float32")) + _check_inference(bb, relax.op.flatten(x1), relax.TensorType((a * b,), dtype="")) -def test_flatten_infer_struct_info_shape_var(): +def test_flatten_infer_ty_shape_var(): bb = relax.BlockBuilder() - s0 = relax.Var("s", relax.ShapeStructInfo((3, 4, 5))) - s1 = relax.Var("s", relax.ShapeStructInfo((3,))) - s2 = relax.Var("s", relax.ShapeStructInfo(())) - s3 = relax.Var("s", relax.ShapeStructInfo(ndim=3)) - s4 = relax.Var("s", relax.ShapeStructInfo(ndim=1)) - s5 = relax.Var("s", relax.ShapeStructInfo(ndim=0)) - s6 = relax.Var("s", relax.ShapeStructInfo()) - x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) - x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) - x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) - x3 = relax.Var("x", relax.TensorStructInfo(s3, "float32")) - x4 = relax.Var("x", relax.TensorStructInfo(s4, "float32")) - x5 = relax.Var("x", relax.TensorStructInfo(s5, "float32")) - x6 = relax.Var("x", relax.TensorStructInfo(s6, "float32")) + s0 = relax.Var("s", relax.ShapeType((3, 4, 5))) + s1 = relax.Var("s", relax.ShapeType((3,))) + s2 = relax.Var("s", relax.ShapeType(())) + s3 = relax.Var("s", relax.ShapeType(ndim=3)) + s4 = relax.Var("s", relax.ShapeType(ndim=1)) + s5 = relax.Var("s", relax.ShapeType(ndim=0)) + s6 = relax.Var("s", relax.ShapeType()) + x0 = relax.Var("x", relax.TensorType(s0, "float32")) + x1 = relax.Var("x", relax.TensorType(s1, "float32")) + x2 = relax.Var("x", relax.TensorType(s2, "float32")) + x3 = relax.Var("x", relax.TensorType(s3, "float32")) + x4 = relax.Var("x", relax.TensorType(s4, "float32")) + x5 = relax.Var("x", relax.TensorType(s5, "float32")) + x6 = relax.Var("x", relax.TensorType(s6, "float32")) - _check_inference(bb, relax.op.flatten(x0), relax.TensorStructInfo(dtype="float32", ndim=1)) - _check_inference(bb, relax.op.flatten(x1), relax.TensorStructInfo(s1, "float32")) - _check_inference(bb, relax.op.flatten(x2), relax.TensorStructInfo((1,), "float32")) - _check_inference(bb, relax.op.flatten(x3), relax.TensorStructInfo(dtype="float32", ndim=1)) - _check_inference(bb, relax.op.flatten(x4), relax.TensorStructInfo(s4, "float32")) - _check_inference(bb, relax.op.flatten(x5), relax.TensorStructInfo((1,), "float32")) - _check_inference(bb, relax.op.flatten(x6), relax.TensorStructInfo(dtype="float32", ndim=1)) + _check_inference(bb, relax.op.flatten(x0), relax.TensorType(dtype="float32", ndim=1)) + _check_inference(bb, relax.op.flatten(x1), relax.TensorType(s1, "float32")) + _check_inference(bb, relax.op.flatten(x2), relax.TensorType((1,), "float32")) + _check_inference(bb, relax.op.flatten(x3), relax.TensorType(dtype="float32", ndim=1)) + _check_inference(bb, relax.op.flatten(x4), relax.TensorType(s4, "float32")) + _check_inference(bb, relax.op.flatten(x5), relax.TensorType((1,), "float32")) + _check_inference(bb, relax.op.flatten(x6), relax.TensorType(dtype="float32", ndim=1)) -def test_flatten_infer_struct_info_more_input_dtype(): +def test_flatten_infer_ty_more_input_dtype(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((3, 4, 5), "float16")) x1 = relax.Var("x", R.Tensor((3, 4, 5), "int8")) x2 = relax.Var("x", R.Tensor((3, 4, 5), "int32")) - _check_inference(bb, relax.op.flatten(x0), relax.TensorStructInfo((60,), "float16")) - _check_inference(bb, relax.op.flatten(x1), relax.TensorStructInfo((60,), "int8")) - _check_inference(bb, relax.op.flatten(x2), relax.TensorStructInfo((60,), "int32")) + _check_inference(bb, relax.op.flatten(x0), relax.TensorType((60,), "float16")) + _check_inference(bb, relax.op.flatten(x1), relax.TensorType((60,), "int8")) + _check_inference(bb, relax.op.flatten(x2), relax.TensorType((60,), "int32")) -def test_flatten_infer_struct_info_wrong_input_type(): +def test_flatten_infer_ty_wrong_input_type(): bb = relax.BlockBuilder() - x0 = relax.Var("x", relax.ShapeStructInfo((3, 4, 5))) - x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((3, 4, 5), "float32"))) + x0 = relax.Var("x", relax.ShapeType((3, 4, 5))) + x1 = relax.Var("x", relax.FuncType([], R.Tensor((3, 4, 5), "float32"))) with pytest.raises(TypeError): bb.normalize(relax.op.flatten(x0)) @@ -1126,7 +1072,7 @@ def test_flatten_wrong_input_number(): relax.op.flatten(x, y) -def test_concat_infer_struct_info_with_axis(): +def test_concat_infer_ty_with_axis(): bb = relax.BlockBuilder() vdev0 = VDevice("llvm") x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32")) @@ -1152,108 +1098,80 @@ def test_concat_infer_struct_info_with_axis(): z6 = relax.Var("z", R.Tensor((2, 5, 4), "float32", vdev0)) _check_inference( - bb, relax.op.concat([x0, y0, z0], axis=1), relax.TensorStructInfo((2, 12, 4), "float32") + bb, relax.op.concat([x0, y0, z0], axis=1), relax.TensorType((2, 12, 4), "float32") ) _check_inference( bb, relax.op.concat([x6, y6, z6], axis=1), - relax.TensorStructInfo((2, 12, 4), "float32", vdev0), + relax.TensorType((2, 12, 4), "float32", vdev0), ) _check_inference( bb, relax.op.concat([x6, y0, z0], axis=1), - relax.TensorStructInfo((2, 12, 4), "float32", vdev0), - ) - _check_inference( - bb, relax.op.concat([x0, y0, z0], axis=-2), relax.TensorStructInfo((2, 12, 4), "float32") - ) - _check_inference( - bb, relax.op.concat([x1, y0, z0], axis=1), relax.TensorStructInfo(dtype="float32", ndim=3) - ) - _check_inference( - bb, relax.op.concat([x2, y0, z0], axis=1), relax.TensorStructInfo(dtype="float32", ndim=3) - ) - _check_inference( - bb, relax.op.concat([x3, y0, z0], axis=1), relax.TensorStructInfo((2, 12, 4), dtype="") - ) - _check_inference( - bb, relax.op.concat([x3, y0, z0], axis=-2), relax.TensorStructInfo((2, 12, 4), dtype="") - ) - _check_inference( - bb, relax.op.concat([x4, y0, z0], axis=1), relax.TensorStructInfo(dtype="", ndim=3) - ) - _check_inference( - bb, relax.op.concat([x5, y0, z0], axis=1), relax.TensorStructInfo(dtype="", ndim=3) - ) - _check_inference( - bb, relax.op.concat([x1, y1, z0], axis=1), relax.TensorStructInfo(dtype="float32", ndim=3) - ) - _check_inference( - bb, relax.op.concat([x2, y1, z0], axis=1), relax.TensorStructInfo(dtype="float32", ndim=3) - ) - _check_inference( - bb, relax.op.concat([x3, y1, z0], axis=1), relax.TensorStructInfo(dtype="", ndim=3) - ) - _check_inference( - bb, relax.op.concat([x5, y1, z0], axis=1), relax.TensorStructInfo(dtype="", ndim=3) - ) - _check_inference( - bb, relax.op.concat([x2, y2, z0], axis=1), relax.TensorStructInfo(dtype="float32", ndim=3) - ) - _check_inference( - bb, relax.op.concat([x3, y2, z0], axis=1), relax.TensorStructInfo(dtype="", ndim=3) + relax.TensorType((2, 12, 4), "float32", vdev0), ) _check_inference( - bb, relax.op.concat([x5, y5, z0], axis=1), relax.TensorStructInfo(dtype="", ndim=3) + bb, relax.op.concat([x0, y0, z0], axis=-2), relax.TensorType((2, 12, 4), "float32") ) _check_inference( - bb, relax.op.concat([x1, y1, z1], axis=1), relax.TensorStructInfo(dtype="float32", ndim=3) + bb, relax.op.concat([x1, y0, z0], axis=1), relax.TensorType(dtype="float32", ndim=3) ) _check_inference( - bb, relax.op.concat([x2, y2, z1], axis=1), relax.TensorStructInfo(dtype="float32", ndim=3) + bb, relax.op.concat([x2, y0, z0], axis=1), relax.TensorType(dtype="float32", ndim=3) ) _check_inference( - bb, relax.op.concat([x3, y1, z1], axis=1), relax.TensorStructInfo(dtype="", ndim=3) + bb, relax.op.concat([x3, y0, z0], axis=1), relax.TensorType((2, 12, 4), dtype="") ) _check_inference( - bb, relax.op.concat([x2, y2, z2], axis=1), relax.TensorStructInfo(dtype="float32", ndim=-1) + bb, relax.op.concat([x3, y0, z0], axis=-2), relax.TensorType((2, 12, 4), dtype="") ) + _check_inference(bb, relax.op.concat([x4, y0, z0], axis=1), relax.TensorType(dtype="", ndim=3)) + _check_inference(bb, relax.op.concat([x5, y0, z0], axis=1), relax.TensorType(dtype="", ndim=3)) _check_inference( - bb, relax.op.concat([x3, y2, z2], axis=1), relax.TensorStructInfo(dtype="", ndim=3) + bb, relax.op.concat([x1, y1, z0], axis=1), relax.TensorType(dtype="float32", ndim=3) ) _check_inference( - bb, relax.op.concat([x4, y4, z2], axis=1), relax.TensorStructInfo(dtype="", ndim=3) + bb, relax.op.concat([x2, y1, z0], axis=1), relax.TensorType(dtype="float32", ndim=3) ) + _check_inference(bb, relax.op.concat([x3, y1, z0], axis=1), relax.TensorType(dtype="", ndim=3)) + _check_inference(bb, relax.op.concat([x5, y1, z0], axis=1), relax.TensorType(dtype="", ndim=3)) _check_inference( - bb, relax.op.concat([x5, y5, z2], axis=1), relax.TensorStructInfo(dtype="", ndim=-1) + bb, relax.op.concat([x2, y2, z0], axis=1), relax.TensorType(dtype="float32", ndim=3) ) + _check_inference(bb, relax.op.concat([x3, y2, z0], axis=1), relax.TensorType(dtype="", ndim=3)) + _check_inference(bb, relax.op.concat([x5, y5, z0], axis=1), relax.TensorType(dtype="", ndim=3)) _check_inference( - bb, relax.op.concat([x3, y3, z3], axis=1), relax.TensorStructInfo((2, 12, 4), dtype="") + bb, relax.op.concat([x1, y1, z1], axis=1), relax.TensorType(dtype="float32", ndim=3) ) _check_inference( - bb, relax.op.concat([x3, y3, z3], axis=-2), relax.TensorStructInfo((2, 12, 4), dtype="") + bb, relax.op.concat([x2, y2, z1], axis=1), relax.TensorType(dtype="float32", ndim=3) ) + _check_inference(bb, relax.op.concat([x3, y1, z1], axis=1), relax.TensorType(dtype="", ndim=3)) _check_inference( - bb, relax.op.concat([x4, y3, z3], axis=1), relax.TensorStructInfo(dtype="", ndim=3) + bb, relax.op.concat([x2, y2, z2], axis=1), relax.TensorType(dtype="float32", ndim=-1) ) + _check_inference(bb, relax.op.concat([x3, y2, z2], axis=1), relax.TensorType(dtype="", ndim=3)) + _check_inference(bb, relax.op.concat([x4, y4, z2], axis=1), relax.TensorType(dtype="", ndim=3)) + _check_inference(bb, relax.op.concat([x5, y5, z2], axis=1), relax.TensorType(dtype="", ndim=-1)) _check_inference( - bb, relax.op.concat([x5, y5, z3], axis=1), relax.TensorStructInfo(dtype="", ndim=3) + bb, relax.op.concat([x3, y3, z3], axis=1), relax.TensorType((2, 12, 4), dtype="") ) _check_inference( - bb, relax.op.concat([x4, y4, z4], axis=1), relax.TensorStructInfo(dtype="", ndim=3) + bb, relax.op.concat([x3, y3, z3], axis=-2), relax.TensorType((2, 12, 4), dtype="") ) - _check_inference( - bb, relax.op.concat([x5, y5, z4], axis=1), relax.TensorStructInfo(dtype="", ndim=3) - ) - _check_inference(bb, relax.op.concat([x5, y5, z5], axis=1), relax.TensorStructInfo(dtype="")) + _check_inference(bb, relax.op.concat([x4, y3, z3], axis=1), relax.TensorType(dtype="", ndim=3)) + _check_inference(bb, relax.op.concat([x5, y5, z3], axis=1), relax.TensorType(dtype="", ndim=3)) + _check_inference(bb, relax.op.concat([x4, y4, z4], axis=1), relax.TensorType(dtype="", ndim=3)) + _check_inference(bb, relax.op.concat([x5, y5, z4], axis=1), relax.TensorType(dtype="", ndim=3)) + _check_inference(bb, relax.op.concat([x5, y5, z5], axis=1), relax.TensorType(dtype="")) _check_inference( bb, relax.op.concat(relax.Tuple([x0, y0, z0]), axis=1), - relax.TensorStructInfo((2, 12, 4), "float32"), + relax.TensorType((2, 12, 4), "float32"), ) -def test_concat_infer_struct_info_with_axis_shape_symbolic(): +def test_concat_infer_ty_with_axis_shape_symbolic(): bb = relax.BlockBuilder() a0 = tirx.Var("a0", "int64") a1 = tirx.Var("a1", "int64") @@ -1270,29 +1188,29 @@ def test_concat_infer_struct_info_with_axis_shape_symbolic(): _check_inference( bb, relax.op.concat([x0, y, z], axis=1), - relax.TensorStructInfo((a0, b0 + b1 + b2, c), "float32"), + relax.TensorType((a0, b0 + b1 + b2, c), "float32"), ) _check_inference( bb, relax.op.concat([x0, y, z], axis=-2), - relax.TensorStructInfo((a0, b0 + b1 + b2, c), "float32"), + relax.TensorType((a0, b0 + b1 + b2, c), "float32"), ) _check_inference( - bb, relax.op.concat([x1, y, z], axis=1), relax.TensorStructInfo(dtype="float32", ndim=3) + bb, relax.op.concat([x1, y, z], axis=1), relax.TensorType(dtype="float32", ndim=3) ) _check_inference( bb, relax.op.concat(relax.Tuple([x0, y, z]), axis=1), - relax.TensorStructInfo((a0, b0 + b1 + b2, c), "float32"), + relax.TensorType((a0, b0 + b1 + b2, c), "float32"), ) _check_inference( bb, relax.op.concat(relax.Tuple([x0, x2]), axis=1), - relax.TensorStructInfo((a0, b0 * 2, c), "float32"), + relax.TensorType((a0, b0 * 2, c), "float32"), ) -def test_concat_infer_struct_info_with_axis_shape_var(): +def test_concat_infer_ty_with_axis_shape_var(): bb = relax.BlockBuilder() a0 = tirx.Var("a0", "int64") a1 = tirx.Var("a1", "int64") @@ -1300,44 +1218,44 @@ def test_concat_infer_struct_info_with_axis_shape_var(): b1 = tirx.Var("b1", "int64") b2 = tirx.Var("b2", "int64") c = tirx.Var("c", "int64") - sx0 = relax.Var("sx", relax.ShapeStructInfo((2, 3, 4))) - sx1 = relax.Var("sx", relax.ShapeStructInfo((a0, b0, c))) - sx2 = relax.Var("sx", relax.ShapeStructInfo((a1, b0, c))) - sx3 = relax.Var("sx", relax.ShapeStructInfo(ndim=3)) - sx4 = relax.Var("sx", relax.ShapeStructInfo()) - x0 = relax.Var("x", relax.TensorStructInfo(sx0, "float32")) - x1 = relax.Var("x", relax.TensorStructInfo(sx1, "float32")) - x2 = relax.Var("x", relax.TensorStructInfo(sx2, "float32")) - x3 = relax.Var("x", relax.TensorStructInfo(sx3, "float32")) - x4 = relax.Var("x", relax.TensorStructInfo(sx4, "float32")) + sx0 = relax.Var("sx", relax.ShapeType((2, 3, 4))) + sx1 = relax.Var("sx", relax.ShapeType((a0, b0, c))) + sx2 = relax.Var("sx", relax.ShapeType((a1, b0, c))) + sx3 = relax.Var("sx", relax.ShapeType(ndim=3)) + sx4 = relax.Var("sx", relax.ShapeType()) + x0 = relax.Var("x", relax.TensorType(sx0, "float32")) + x1 = relax.Var("x", relax.TensorType(sx1, "float32")) + x2 = relax.Var("x", relax.TensorType(sx2, "float32")) + x3 = relax.Var("x", relax.TensorType(sx3, "float32")) + x4 = relax.Var("x", relax.TensorType(sx4, "float32")) y0 = relax.Var("y", R.Tensor((2, 4, 4), "float32")) y1 = relax.Var("y", R.Tensor((a0, b1, c), "float32")) z0 = relax.Var("z", R.Tensor((2, 5, 4), "float32")) z1 = relax.Var("z", R.Tensor((a0, b2, c), "float32")) _check_inference( - bb, relax.op.concat([x0, y0, z0], axis=1), relax.TensorStructInfo(dtype="float32", ndim=3) + bb, relax.op.concat([x0, y0, z0], axis=1), relax.TensorType(dtype="float32", ndim=3) ) _check_inference( - bb, relax.op.concat([x1, y1, z1], axis=1), relax.TensorStructInfo(dtype="float32", ndim=3) + bb, relax.op.concat([x1, y1, z1], axis=1), relax.TensorType(dtype="float32", ndim=3) ) _check_inference( - bb, relax.op.concat([x2, y1, z1], axis=1), relax.TensorStructInfo(dtype="float32", ndim=3) + bb, relax.op.concat([x2, y1, z1], axis=1), relax.TensorType(dtype="float32", ndim=3) ) _check_inference( - bb, relax.op.concat([x3, y0, z0], axis=1), relax.TensorStructInfo(dtype="float32", ndim=3) + bb, relax.op.concat([x3, y0, z0], axis=1), relax.TensorType(dtype="float32", ndim=3) ) _check_inference( - bb, relax.op.concat([x4, y0, z0], axis=1), relax.TensorStructInfo(dtype="float32", ndim=3) + bb, relax.op.concat([x4, y0, z0], axis=1), relax.TensorType(dtype="float32", ndim=3) ) _check_inference( bb, relax.op.concat(relax.Tuple([x0, y0, z0]), axis=1), - relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorType(dtype="float32", ndim=3), ) -def test_concat_infer_struct_info_without_axis(): +def test_concat_infer_ty_without_axis(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((3,), "float32")) x1 = relax.Var("x", R.Tensor("float32", ndim=1)) @@ -1349,40 +1267,40 @@ def test_concat_infer_struct_info_without_axis(): z1 = relax.Var("z", R.Tensor("float32", ndim=1)) _check_inference( - bb, relax.op.concat([x0, y0, z0], axis=None), relax.TensorStructInfo((12,), "float32") + bb, relax.op.concat([x0, y0, z0], axis=None), relax.TensorType((12,), "float32") ) _check_inference( bb, relax.op.concat([x1, y0, z0], axis=None), - relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorType(dtype="float32", ndim=1), ) _check_inference( - bb, relax.op.concat([x2, y0, z0], axis=None), relax.TensorStructInfo((12,), dtype="") + bb, relax.op.concat([x2, y0, z0], axis=None), relax.TensorType((12,), dtype="") ) _check_inference( - bb, relax.op.concat([x3, y0, z0], axis=None), relax.TensorStructInfo(dtype="", ndim=1) + bb, relax.op.concat([x3, y0, z0], axis=None), relax.TensorType(dtype="", ndim=1) ) _check_inference( bb, relax.op.concat([x1, y1, z0], axis=None), - relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorType(dtype="float32", ndim=1), ) _check_inference( - bb, relax.op.concat([x2, y1, z0], axis=None), relax.TensorStructInfo(dtype="", ndim=1) + bb, relax.op.concat([x2, y1, z0], axis=None), relax.TensorType(dtype="", ndim=1) ) _check_inference( bb, relax.op.concat([x1, y1, z1], axis=None), - relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorType(dtype="float32", ndim=1), ) _check_inference( bb, relax.op.concat(relax.Tuple([x0, y0, z0]), axis=None), - relax.TensorStructInfo((12,), "float32"), + relax.TensorType((12,), "float32"), ) -def test_concat_infer_struct_info_without_axis_shape_symbolic(): +def test_concat_infer_ty_without_axis_shape_symbolic(): bb = relax.BlockBuilder() a0 = tirx.Var("a0", "int64") a1 = tirx.Var("a1", "int64") @@ -1392,47 +1310,47 @@ def test_concat_infer_struct_info_without_axis_shape_symbolic(): y1 = relax.Var("y", R.Tensor((a1,), "")) _check_inference( - bb, relax.op.concat([x0, y0], axis=None), relax.TensorStructInfo((a0 + a1,), "float32") + bb, relax.op.concat([x0, y0], axis=None), relax.TensorType((a0 + a1,), "float32") ) _check_inference( - bb, relax.op.concat([x0, y1], axis=None), relax.TensorStructInfo((a0 + a1,), dtype="") + bb, relax.op.concat([x0, y1], axis=None), relax.TensorType((a0 + a1,), dtype="") ) _check_inference( - bb, relax.op.concat([x1, y0], axis=None), relax.TensorStructInfo((a0 + a1,), dtype="") + bb, relax.op.concat([x1, y0], axis=None), relax.TensorType((a0 + a1,), dtype="") ) _check_inference( - bb, relax.op.concat([x1, y1], axis=None), relax.TensorStructInfo((a0 + a1,), dtype="") + bb, relax.op.concat([x1, y1], axis=None), relax.TensorType((a0 + a1,), dtype="") ) _check_inference( bb, relax.op.concat(relax.Tuple([x0, y0]), axis=None), - relax.TensorStructInfo((a0 + a1,), "float32"), + relax.TensorType((a0 + a1,), "float32"), ) -def test_concat_infer_struct_info_without_axis_shape_var(): +def test_concat_infer_ty_without_axis_shape_var(): bb = relax.BlockBuilder() - sx0 = relax.Var("sx", relax.ShapeStructInfo((3,))) - sx1 = relax.Var("sx", relax.ShapeStructInfo(ndim=1)) - sy0 = relax.Var("sy", relax.ShapeStructInfo((4,))) - x0 = relax.Var("x", relax.TensorStructInfo(sx0, "float32")) - x1 = relax.Var("x", relax.TensorStructInfo(sx1, "float32")) - y0 = relax.Var("y", relax.TensorStructInfo(sy0, "float32")) + sx0 = relax.Var("sx", relax.ShapeType((3,))) + sx1 = relax.Var("sx", relax.ShapeType(ndim=1)) + sy0 = relax.Var("sy", relax.ShapeType((4,))) + x0 = relax.Var("x", relax.TensorType(sx0, "float32")) + x1 = relax.Var("x", relax.TensorType(sx1, "float32")) + y0 = relax.Var("y", relax.TensorType(sy0, "float32")) _check_inference( - bb, relax.op.concat([x0, y0], axis=None), relax.TensorStructInfo(dtype="float32", ndim=1) + bb, relax.op.concat([x0, y0], axis=None), relax.TensorType(dtype="float32", ndim=1) ) _check_inference( - bb, relax.op.concat([x1, y0], axis=None), relax.TensorStructInfo(dtype="float32", ndim=1) + bb, relax.op.concat([x1, y0], axis=None), relax.TensorType(dtype="float32", ndim=1) ) _check_inference( bb, relax.op.concat(relax.Tuple([x0, y0]), axis=None), - relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorType(dtype="float32", ndim=1), ) -def test_concat_infer_struct_info_more_input_dtype(): +def test_concat_infer_ty_more_input_dtype(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((3,), "float16")) y0 = relax.Var("y", R.Tensor((4,), "float16")) @@ -1441,146 +1359,116 @@ def test_concat_infer_struct_info_more_input_dtype(): x2 = relax.Var("x", R.Tensor((3,), "int32")) y2 = relax.Var("y", R.Tensor((4,), "int32")) - _check_inference( - bb, relax.op.concat([x0, y0], axis=None), relax.TensorStructInfo((7,), "float16") - ) - _check_inference(bb, relax.op.concat([x1, y1], axis=None), relax.TensorStructInfo((7,), "int8")) - _check_inference( - bb, relax.op.concat([x2, y2], axis=None), relax.TensorStructInfo((7,), "int32") - ) + _check_inference(bb, relax.op.concat([x0, y0], axis=None), relax.TensorType((7,), "float16")) + _check_inference(bb, relax.op.concat([x1, y1], axis=None), relax.TensorType((7,), "int8")) + _check_inference(bb, relax.op.concat([x2, y2], axis=None), relax.TensorType((7,), "int32")) -def test_concat_infer_struct_info_tuple_var(): +def test_concat_infer_ty_tuple_var(): bb = relax.BlockBuilder() a = tirx.Var("a0", "int64") b0 = tirx.Var("b0", "int64") b1 = tirx.Var("b1", "int64") t0 = relax.Var( "t", - relax.TupleStructInfo( - [relax.TensorStructInfo((a, b0), "float32"), relax.TensorStructInfo((a, b1), "float32")] + relax.TupleType( + [relax.TensorType((a, b0), "float32"), relax.TensorType((a, b1), "float32")] ), ) t1 = relax.Var( "t", - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo((a, b0), "float32"), - relax.TensorStructInfo(dtype="float32", ndim=2), + relax.TensorType((a, b0), "float32"), + relax.TensorType(dtype="float32", ndim=2), ] ), ) t2 = relax.Var( "t", - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="float32"), - relax.TensorStructInfo(dtype="float32", ndim=2), + relax.TensorType(dtype="float32"), + relax.TensorType(dtype="float32", ndim=2), ] ), ) t3 = relax.Var( "t", - relax.TupleStructInfo( - [relax.TensorStructInfo(dtype="float32"), relax.TensorStructInfo(dtype="float32")] - ), + relax.TupleType([relax.TensorType(dtype="float32"), relax.TensorType(dtype="float32")]), ) t4 = relax.Var( "t", - relax.TupleStructInfo( - [relax.TensorStructInfo((a, b0), "float32"), relax.TensorStructInfo((a, b1))] - ), + relax.TupleType([relax.TensorType((a, b0), "float32"), relax.TensorType((a, b1))]), ) t5 = relax.Var( "t", - relax.TupleStructInfo( - [relax.TensorStructInfo((a, b0), dtype=""), relax.TensorStructInfo((a, b1), dtype="")] - ), + relax.TupleType([relax.TensorType((a, b0), dtype=""), relax.TensorType((a, b1), dtype="")]), ) t6 = relax.Var( "t", - relax.TupleStructInfo( - [relax.TensorStructInfo(dtype="", ndim=2), relax.TensorStructInfo(dtype="")] - ), + relax.TupleType([relax.TensorType(dtype="", ndim=2), relax.TensorType(dtype="")]), ) t7 = relax.Var( "t", - relax.TupleStructInfo([relax.TensorStructInfo(dtype=""), relax.TensorStructInfo(dtype="")]), + relax.TupleType([relax.TensorType(dtype=""), relax.TensorType(dtype="")]), ) - _check_inference( - bb, relax.op.concat(t0, axis=1), relax.TensorStructInfo((a, b0 + b1), "float32") - ) - _check_inference( - bb, relax.op.concat(t1, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2) - ) - _check_inference( - bb, relax.op.concat(t2, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2) - ) - _check_inference(bb, relax.op.concat(t3, axis=1), relax.TensorStructInfo(dtype="float32")) - _check_inference( - bb, relax.op.concat(t4, axis=1), relax.TensorStructInfo((a, b0 + b1), "float32") - ) - _check_inference( - bb, relax.op.concat(t5, axis=1), relax.TensorStructInfo((a, b0 + b1), dtype="") - ) - _check_inference(bb, relax.op.concat(t6, axis=1), relax.TensorStructInfo(dtype="", ndim=2)) - _check_inference(bb, relax.op.concat(t7, axis=1), relax.TensorStructInfo(dtype="")) + _check_inference(bb, relax.op.concat(t0, axis=1), relax.TensorType((a, b0 + b1), "float32")) + _check_inference(bb, relax.op.concat(t1, axis=1), relax.TensorType(dtype="float32", ndim=2)) + _check_inference(bb, relax.op.concat(t2, axis=1), relax.TensorType(dtype="float32", ndim=2)) + _check_inference(bb, relax.op.concat(t3, axis=1), relax.TensorType(dtype="float32")) + _check_inference(bb, relax.op.concat(t4, axis=1), relax.TensorType((a, b0 + b1), "float32")) + _check_inference(bb, relax.op.concat(t5, axis=1), relax.TensorType((a, b0 + b1), dtype="")) + _check_inference(bb, relax.op.concat(t6, axis=1), relax.TensorType(dtype="", ndim=2)) + _check_inference(bb, relax.op.concat(t7, axis=1), relax.TensorType(dtype="")) -def test_concat_infer_struct_info_single_input_tensor(): +def test_concat_infer_ty_single_input_tensor(): bb = relax.BlockBuilder() a = tirx.Var("a", "int64") - s0 = relax.Var("s", relax.ShapeStructInfo((3, a))) - s1 = relax.Var("s", relax.ShapeStructInfo((a,))) - s2 = relax.Var("s", relax.ShapeStructInfo(ndim=3)) - s3 = relax.Var("s", relax.ShapeStructInfo(ndim=1)) - s4 = relax.Var("s", relax.ShapeStructInfo()) + s0 = relax.Var("s", relax.ShapeType((3, a))) + s1 = relax.Var("s", relax.ShapeType((a,))) + s2 = relax.Var("s", relax.ShapeType(ndim=3)) + s3 = relax.Var("s", relax.ShapeType(ndim=1)) + s4 = relax.Var("s", relax.ShapeType()) x0 = relax.Var("x", R.Tensor((3, a), "float32")) x1 = relax.Var("x", R.Tensor((a,), "float32")) x2 = relax.Var("x", R.Tensor("float32", ndim=3)) x3 = relax.Var("x", R.Tensor("float32", ndim=1)) x4 = relax.Var("x", R.Tensor("float32")) - x5 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) - x6 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) - x7 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) - x8 = relax.Var("x", relax.TensorStructInfo(s3, "float32")) - x9 = relax.Var("x", relax.TensorStructInfo(s4, "float32")) - - _check_inference(bb, relax.op.concat([x0], axis=1), relax.TensorStructInfo((3, a), "float32")) - _check_inference(bb, relax.op.concat([x1], axis=0), relax.TensorStructInfo((a,), "float32")) - _check_inference(bb, relax.op.concat([x1], axis=None), relax.TensorStructInfo((a,), "float32")) - _check_inference( - bb, relax.op.concat([x2], axis=1), relax.TensorStructInfo(dtype="float32", ndim=3) - ) - _check_inference( - bb, relax.op.concat([x3], axis=0), relax.TensorStructInfo(dtype="float32", ndim=1) - ) - _check_inference( - bb, relax.op.concat([x3], axis=None), relax.TensorStructInfo(dtype="float32", ndim=1) - ) - _check_inference(bb, relax.op.concat([x4], axis=1), relax.TensorStructInfo(dtype="float32")) - _check_inference(bb, relax.op.concat([x5], axis=1), relax.TensorStructInfo(s0, dtype="float32")) - _check_inference(bb, relax.op.concat([x6], axis=0), relax.TensorStructInfo(s1, dtype="float32")) - _check_inference( - bb, relax.op.concat([x6], axis=None), relax.TensorStructInfo(s1, dtype="float32") - ) - _check_inference(bb, relax.op.concat([x7], axis=1), relax.TensorStructInfo(s2, dtype="float32")) - _check_inference(bb, relax.op.concat([x8], axis=0), relax.TensorStructInfo(s3, dtype="float32")) - _check_inference( - bb, relax.op.concat([x8], axis=None), relax.TensorStructInfo(s3, dtype="float32") - ) - _check_inference(bb, relax.op.concat([x9], axis=1), relax.TensorStructInfo(s4, dtype="float32")) - - -def test_concat_infer_struct_info_zero_rank_input_tensor(): - bb = relax.BlockBuilder() - s0 = relax.Var("s", relax.ShapeStructInfo(())) - s1 = relax.Var("s", relax.ShapeStructInfo(ndim=0)) + x5 = relax.Var("x", relax.TensorType(s0, "float32")) + x6 = relax.Var("x", relax.TensorType(s1, "float32")) + x7 = relax.Var("x", relax.TensorType(s2, "float32")) + x8 = relax.Var("x", relax.TensorType(s3, "float32")) + x9 = relax.Var("x", relax.TensorType(s4, "float32")) + + _check_inference(bb, relax.op.concat([x0], axis=1), relax.TensorType((3, a), "float32")) + _check_inference(bb, relax.op.concat([x1], axis=0), relax.TensorType((a,), "float32")) + _check_inference(bb, relax.op.concat([x1], axis=None), relax.TensorType((a,), "float32")) + _check_inference(bb, relax.op.concat([x2], axis=1), relax.TensorType(dtype="float32", ndim=3)) + _check_inference(bb, relax.op.concat([x3], axis=0), relax.TensorType(dtype="float32", ndim=1)) + _check_inference( + bb, relax.op.concat([x3], axis=None), relax.TensorType(dtype="float32", ndim=1) + ) + _check_inference(bb, relax.op.concat([x4], axis=1), relax.TensorType(dtype="float32")) + _check_inference(bb, relax.op.concat([x5], axis=1), relax.TensorType(s0, dtype="float32")) + _check_inference(bb, relax.op.concat([x6], axis=0), relax.TensorType(s1, dtype="float32")) + _check_inference(bb, relax.op.concat([x6], axis=None), relax.TensorType(s1, dtype="float32")) + _check_inference(bb, relax.op.concat([x7], axis=1), relax.TensorType(s2, dtype="float32")) + _check_inference(bb, relax.op.concat([x8], axis=0), relax.TensorType(s3, dtype="float32")) + _check_inference(bb, relax.op.concat([x8], axis=None), relax.TensorType(s3, dtype="float32")) + _check_inference(bb, relax.op.concat([x9], axis=1), relax.TensorType(s4, dtype="float32")) + + +def test_concat_infer_ty_zero_rank_input_tensor(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeType(())) + s1 = relax.Var("s", relax.ShapeType(ndim=0)) x0 = relax.Var("x", R.Tensor((), "float32")) x1 = relax.Var("x", R.Tensor("float32", ndim=0)) - x2 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) - x3 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorType(s0, "float32")) + x3 = relax.Var("x", relax.TensorType(s1, "float32")) with pytest.raises(ValueError): bb.normalize(relax.op.concat([x0], axis=0)) @@ -1592,7 +1480,7 @@ def test_concat_infer_struct_info_zero_rank_input_tensor(): bb.normalize(relax.op.concat([x3], axis=None)) -def test_concat_infer_struct_info_no_input_tensor(): +def test_concat_infer_ty_no_input_tensor(): bb = relax.BlockBuilder() with pytest.raises(ValueError): bb.normalize(relax.op.concat([], axis=1)) @@ -1600,31 +1488,31 @@ def test_concat_infer_struct_info_no_input_tensor(): bb.normalize(relax.op.concat([], axis=None)) -def test_concat_infer_struct_info_without_axis_but_tensor_not_one_dimensional(): +def test_concat_infer_ty_without_axis_but_tensor_not_one_dimensional(): bb = relax.BlockBuilder() - s0 = relax.Var("s", relax.ShapeStructInfo((3, 4))) - s1 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) - s2 = relax.Var("s", relax.ShapeStructInfo()) + s0 = relax.Var("s", relax.ShapeType((3, 4))) + s1 = relax.Var("s", relax.ShapeType(ndim=2)) + s2 = relax.Var("s", relax.ShapeType()) x0 = relax.Var("x", R.Tensor((3, 4), "float32")) x1 = relax.Var("x", R.Tensor("float32", ndim=2)) x2 = relax.Var("x", R.Tensor("float32")) - x3 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) - x4 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) - x5 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + x3 = relax.Var("x", relax.TensorType(s0, "float32")) + x4 = relax.Var("x", relax.TensorType(s1, "float32")) + x5 = relax.Var("x", relax.TensorType(s2, "float32")) with pytest.raises(ValueError): bb.normalize(relax.op.concat([x0], axis=None)) with pytest.raises(ValueError): bb.normalize(relax.op.concat([x1], axis=None)) - _check_inference(bb, relax.op.concat([x2], axis=None), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.concat([x2], axis=None), relax.TensorType(dtype="float32")) with pytest.raises(ValueError): bb.normalize(relax.op.concat([x3], axis=None)) with pytest.raises(ValueError): bb.normalize(relax.op.concat([x4], axis=None)) - _check_inference(bb, relax.op.concat([x5], axis=None), relax.TensorStructInfo(s2, "float32")) + _check_inference(bb, relax.op.concat([x5], axis=None), relax.TensorType(s2, "float32")) -def test_concat_infer_struct_info_inconsistent_dtype(): +def test_concat_infer_ty_inconsistent_dtype(): bb = relax.BlockBuilder() x = relax.Var("x", R.Tensor((3,))) y = relax.Var("y", R.Tensor((4,), "float32")) @@ -1634,15 +1522,15 @@ def test_concat_infer_struct_info_inconsistent_dtype(): bb.normalize(relax.op.concat([x, y, z], axis=0)) -def test_concat_infer_struct_info_inconsistent_ndim(): +def test_concat_infer_ty_inconsistent_ndim(): bb = relax.BlockBuilder() - s0 = relax.Var("s", relax.ShapeStructInfo((4, 5))) - s1 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) + s0 = relax.Var("s", relax.ShapeType((4, 5))) + s1 = relax.Var("s", relax.ShapeType(ndim=2)) x = relax.Var("x", R.Tensor((3,), "float32")) y0 = relax.Var("y", R.Tensor((4, 5), "float32")) y1 = relax.Var("y", R.Tensor("float32", ndim=2)) - y2 = relax.Var("y", relax.TensorStructInfo(s0, "float32")) - y3 = relax.Var("y", relax.TensorStructInfo(s1, "float32")) + y2 = relax.Var("y", relax.TensorType(s0, "float32")) + y3 = relax.Var("y", relax.TensorType(s1, "float32")) z = relax.Var("z", R.Tensor((5,), "float32")) with pytest.raises(ValueError): @@ -1655,14 +1543,14 @@ def test_concat_infer_struct_info_inconsistent_ndim(): bb.normalize(relax.op.concat([x, y3, z], axis=0)) -def test_concat_infer_struct_info_axis_out_of_range(): +def test_concat_infer_ty_axis_out_of_range(): bb = relax.BlockBuilder() - s0 = relax.Var("s", relax.ShapeStructInfo((3,))) - s1 = relax.Var("s", relax.ShapeStructInfo(ndim=1)) + s0 = relax.Var("s", relax.ShapeType((3,))) + s1 = relax.Var("s", relax.ShapeType(ndim=1)) x0 = relax.Var("x", R.Tensor((3,), "float32")) x1 = relax.Var("x", R.Tensor("float32", ndim=1)) - x2 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) - x3 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorType(s0, "float32")) + x3 = relax.Var("x", relax.TensorType(s1, "float32")) with pytest.raises(ValueError): bb.normalize(relax.op.concat([x0], axis=1)) @@ -1674,15 +1562,15 @@ def test_concat_infer_struct_info_axis_out_of_range(): bb.normalize(relax.op.concat([x3], axis=1)) -def test_concat_infer_struct_info_unequal_shape(): +def test_concat_infer_ty_unequal_shape(): bb = relax.BlockBuilder() a = tirx.Var("a", "int64") - s0 = relax.Var("s", relax.ShapeStructInfo((3, 4))) - s1 = relax.Var("s", relax.ShapeStructInfo((3, a + 2))) + s0 = relax.Var("s", relax.ShapeType((3, 4))) + s1 = relax.Var("s", relax.ShapeType((3, a + 2))) x0 = relax.Var("x", R.Tensor((3, 4), "float32")) x1 = relax.Var("x", R.Tensor((3, a + 2), "float32")) - x2 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) - x3 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorType(s0, "float32")) + x3 = relax.Var("x", relax.TensorType(s1, "float32")) y0 = relax.Var("y", R.Tensor((3, 3), "float32")) y1 = relax.Var("y", R.Tensor((3, a), "float32")) @@ -1696,10 +1584,10 @@ def test_concat_infer_struct_info_unequal_shape(): bb.normalize(relax.op.concat([x3, y1])) -def test_concat_infer_struct_info_input_not_tuple(): +def test_concat_infer_ty_input_not_tuple(): bb = relax.BlockBuilder() x = relax.Var("x", R.Tensor((3,), "float32")) - s = relax.Var("s", relax.ShapeStructInfo((3,))) + s = relax.Var("s", relax.ShapeType((3,))) with pytest.raises(TypeError): bb.normalize(relax.op.concat(x)) @@ -1707,15 +1595,15 @@ def test_concat_infer_struct_info_input_not_tuple(): bb.normalize(relax.op.concat(s)) -def test_concat_infer_struct_info_input_tuple_field_not_tensor(): +def test_concat_infer_ty_input_tuple_field_not_tensor(): bb = relax.BlockBuilder() - s = relax.Var("s", relax.ShapeStructInfo((3,))) + s = relax.Var("s", relax.ShapeType((3,))) with pytest.raises(TypeError): bb.normalize(relax.op.concat([s])) -def test_split_infer_struct_info_by_indices(): +def test_split_infer_ty_by_indices(): bb = relax.BlockBuilder() vdev0 = VDevice("llvm") x0 = relax.Var("x", R.Tensor((2, 10, 4), "float32")) @@ -1729,110 +1617,110 @@ def test_split_infer_struct_info_by_indices(): _check_inference( bb, relax.op.split(x0, [3, 7], axis=1), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo((2, 3, 4), "float32"), - relax.TensorStructInfo((2, 4, 4), "float32"), - relax.TensorStructInfo((2, 3, 4), "float32"), + relax.TensorType((2, 3, 4), "float32"), + relax.TensorType((2, 4, 4), "float32"), + relax.TensorType((2, 3, 4), "float32"), ] ), ) _check_inference( bb, relax.op.split(x6, [3, 7], axis=1), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo((2, 3, 4), "float32", vdev0), - relax.TensorStructInfo((2, 4, 4), "float32", vdev0), - relax.TensorStructInfo((2, 3, 4), "float32", vdev0), + relax.TensorType((2, 3, 4), "float32", vdev0), + relax.TensorType((2, 4, 4), "float32", vdev0), + relax.TensorType((2, 3, 4), "float32", vdev0), ] ), ) _check_inference( bb, relax.op.split(x0, [3, 7], axis=-2), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo((2, 3, 4), "float32"), - relax.TensorStructInfo((2, 4, 4), "float32"), - relax.TensorStructInfo((2, 3, 4), "float32"), + relax.TensorType((2, 3, 4), "float32"), + relax.TensorType((2, 4, 4), "float32"), + relax.TensorType((2, 3, 4), "float32"), ] ), ) _check_inference( bb, relax.op.split(x1, [3, 7], axis=1), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="float32", ndim=3), - relax.TensorStructInfo(dtype="float32", ndim=3), - relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorType(dtype="float32", ndim=3), + relax.TensorType(dtype="float32", ndim=3), + relax.TensorType(dtype="float32", ndim=3), ] ), ) _check_inference( bb, relax.op.split(x2, [3, 7], axis=1), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="float32"), - relax.TensorStructInfo(dtype="float32"), - relax.TensorStructInfo(dtype="float32"), + relax.TensorType(dtype="float32"), + relax.TensorType(dtype="float32"), + relax.TensorType(dtype="float32"), ] ), ) _check_inference( bb, relax.op.split(x3, [3, 7], axis=1), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo((2, 3, 4), dtype=""), - relax.TensorStructInfo((2, 4, 4), dtype=""), - relax.TensorStructInfo((2, 3, 4), dtype=""), + relax.TensorType((2, 3, 4), dtype=""), + relax.TensorType((2, 4, 4), dtype=""), + relax.TensorType((2, 3, 4), dtype=""), ] ), ) _check_inference( bb, relax.op.split(x4, [3, 7], axis=1), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="", ndim=3), - relax.TensorStructInfo(dtype="", ndim=3), - relax.TensorStructInfo(dtype="", ndim=3), + relax.TensorType(dtype="", ndim=3), + relax.TensorType(dtype="", ndim=3), + relax.TensorType(dtype="", ndim=3), ] ), ) _check_inference( bb, relax.op.split(x5, [3, 7], axis=1), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype=""), - relax.TensorStructInfo(dtype=""), - relax.TensorStructInfo(dtype=""), + relax.TensorType(dtype=""), + relax.TensorType(dtype=""), + relax.TensorType(dtype=""), ] ), ) _check_inference( bb, relax.op.split(x0, [-2, 2, 6, 4, 8, 12, 9], axis=1), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo((2, 0, 4), "float32"), - relax.TensorStructInfo((2, 2, 4), "float32"), - relax.TensorStructInfo((2, 4, 4), "float32"), - relax.TensorStructInfo((2, 0, 4), "float32"), - relax.TensorStructInfo((2, 4, 4), "float32"), - relax.TensorStructInfo((2, 2, 4), "float32"), - relax.TensorStructInfo((2, 0, 4), "float32"), - relax.TensorStructInfo((2, 1, 4), "float32"), + relax.TensorType((2, 0, 4), "float32"), + relax.TensorType((2, 2, 4), "float32"), + relax.TensorType((2, 4, 4), "float32"), + relax.TensorType((2, 0, 4), "float32"), + relax.TensorType((2, 4, 4), "float32"), + relax.TensorType((2, 2, 4), "float32"), + relax.TensorType((2, 0, 4), "float32"), + relax.TensorType((2, 1, 4), "float32"), ] ), ) -def test_split_infer_struct_info_by_indices_shape_symbolic(): +def test_split_infer_ty_by_indices_shape_symbolic(): bb = relax.BlockBuilder() a = tirx.Var("a", "int64") b = tirx.Var("b", "int64") @@ -1841,55 +1729,53 @@ def test_split_infer_struct_info_by_indices_shape_symbolic(): _check_inference( bb, relax.op.split(x, [10, 20], axis=1), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo([a, T.max(T.min(10, b) - T.min(0, b), 0)], dtype="float32"), - relax.TensorStructInfo([a, T.max(T.min(20, b) - T.min(10, b), 0)], dtype="float32"), - relax.TensorStructInfo([a, T.max(b - 20, 0)], dtype="float32"), + relax.TensorType([a, T.max(T.min(10, b) - T.min(0, b), 0)], dtype="float32"), + relax.TensorType([a, T.max(T.min(20, b) - T.min(10, b), 0)], dtype="float32"), + relax.TensorType([a, T.max(b - 20, 0)], dtype="float32"), ] ), ) -def test_split_infer_struct_info_by_indices_shape_var(): +def test_split_infer_ty_by_indices_shape_var(): bb = relax.BlockBuilder() - s0 = relax.Var("s", relax.ShapeStructInfo((2, 10, 4))) - s1 = relax.Var("s", relax.ShapeStructInfo(ndim=3)) - s2 = relax.Var("s", relax.ShapeStructInfo()) - x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) - x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) - x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + s0 = relax.Var("s", relax.ShapeType((2, 10, 4))) + s1 = relax.Var("s", relax.ShapeType(ndim=3)) + s2 = relax.Var("s", relax.ShapeType()) + x0 = relax.Var("x", relax.TensorType(s0, "float32")) + x1 = relax.Var("x", relax.TensorType(s1, "float32")) + x2 = relax.Var("x", relax.TensorType(s2, "float32")) _check_inference( bb, relax.op.split(x0, [3], axis=1), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="float32", ndim=3), - relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorType(dtype="float32", ndim=3), + relax.TensorType(dtype="float32", ndim=3), ] ), ) _check_inference( bb, relax.op.split(x1, [3], axis=1), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="float32", ndim=3), - relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorType(dtype="float32", ndim=3), + relax.TensorType(dtype="float32", ndim=3), ] ), ) _check_inference( bb, relax.op.split(x2, [3], axis=1), - relax.TupleStructInfo( - [relax.TensorStructInfo(dtype="float32"), relax.TensorStructInfo(dtype="float32")] - ), + relax.TupleType([relax.TensorType(dtype="float32"), relax.TensorType(dtype="float32")]), ) -def test_split_infer_struct_info_by_n_section(): +def test_split_infer_ty_by_n_section(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 10, 4), "float32")) x1 = relax.Var("x", R.Tensor("float32", ndim=3)) @@ -1901,93 +1787,93 @@ def test_split_infer_struct_info_by_n_section(): _check_inference( bb, relax.op.split(x0, 3, axis=1), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo((2, 4, 4), "float32"), - relax.TensorStructInfo((2, 4, 4), "float32"), - relax.TensorStructInfo((2, 2, 4), "float32"), + relax.TensorType((2, 4, 4), "float32"), + relax.TensorType((2, 4, 4), "float32"), + relax.TensorType((2, 2, 4), "float32"), ] ), ) _check_inference( bb, relax.op.split(x0, 2, axis=1), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo((2, 5, 4), "float32"), - relax.TensorStructInfo((2, 5, 4), "float32"), + relax.TensorType((2, 5, 4), "float32"), + relax.TensorType((2, 5, 4), "float32"), ] ), ) _check_inference( bb, relax.op.split(x0, 3, axis=-2), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo((2, 4, 4), "float32"), - relax.TensorStructInfo((2, 4, 4), "float32"), - relax.TensorStructInfo((2, 2, 4), "float32"), + relax.TensorType((2, 4, 4), "float32"), + relax.TensorType((2, 4, 4), "float32"), + relax.TensorType((2, 2, 4), "float32"), ] ), ) _check_inference( bb, relax.op.split(x1, 3, axis=1), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="float32", ndim=3), - relax.TensorStructInfo(dtype="float32", ndim=3), - relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorType(dtype="float32", ndim=3), + relax.TensorType(dtype="float32", ndim=3), + relax.TensorType(dtype="float32", ndim=3), ] ), ) _check_inference( bb, relax.op.split(x2, 3, axis=1), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="float32"), - relax.TensorStructInfo(dtype="float32"), - relax.TensorStructInfo(dtype="float32"), + relax.TensorType(dtype="float32"), + relax.TensorType(dtype="float32"), + relax.TensorType(dtype="float32"), ] ), ) _check_inference( bb, relax.op.split(x3, 3, axis=1), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo((2, 4, 4), dtype=""), - relax.TensorStructInfo((2, 4, 4), dtype=""), - relax.TensorStructInfo((2, 2, 4), dtype=""), + relax.TensorType((2, 4, 4), dtype=""), + relax.TensorType((2, 4, 4), dtype=""), + relax.TensorType((2, 2, 4), dtype=""), ] ), ) _check_inference( bb, relax.op.split(x4, 3, axis=1), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="", ndim=3), - relax.TensorStructInfo(dtype="", ndim=3), - relax.TensorStructInfo(dtype="", ndim=3), + relax.TensorType(dtype="", ndim=3), + relax.TensorType(dtype="", ndim=3), + relax.TensorType(dtype="", ndim=3), ] ), ) _check_inference( bb, relax.op.split(x5, 3, axis=1), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype=""), - relax.TensorStructInfo(dtype=""), - relax.TensorStructInfo(dtype=""), + relax.TensorType(dtype=""), + relax.TensorType(dtype=""), + relax.TensorType(dtype=""), ] ), ) -def test_split_infer_struct_info_by_n_section_shape_symbolic(): +def test_split_infer_ty_by_n_section_shape_symbolic(): bb = relax.BlockBuilder() a = tirx.Var("a", "int64") b = tirx.Var("b", "int64") @@ -1996,61 +1882,61 @@ def test_split_infer_struct_info_by_n_section_shape_symbolic(): _check_inference( bb, relax.op.split(x, 3, axis=1), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo((a, (b + 2) // 3), "float32"), - relax.TensorStructInfo((a, (b + 2) // 3), "float32"), - relax.TensorStructInfo((a, b - (b + 2) // 3 * 2), "float32"), + relax.TensorType((a, (b + 2) // 3), "float32"), + relax.TensorType((a, (b + 2) // 3), "float32"), + relax.TensorType((a, b - (b + 2) // 3 * 2), "float32"), ] ), ) -def test_split_infer_struct_info_by_n_section_shape_var(): +def test_split_infer_ty_by_n_section_shape_var(): bb = relax.BlockBuilder() - s0 = relax.Var("s", relax.ShapeStructInfo((2, 10, 4))) - s1 = relax.Var("s", relax.ShapeStructInfo(ndim=3)) - s2 = relax.Var("s", relax.ShapeStructInfo()) - x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) - x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) - x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + s0 = relax.Var("s", relax.ShapeType((2, 10, 4))) + s1 = relax.Var("s", relax.ShapeType(ndim=3)) + s2 = relax.Var("s", relax.ShapeType()) + x0 = relax.Var("x", relax.TensorType(s0, "float32")) + x1 = relax.Var("x", relax.TensorType(s1, "float32")) + x2 = relax.Var("x", relax.TensorType(s2, "float32")) _check_inference( bb, relax.op.split(x0, 3, axis=1), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="float32", ndim=3), - relax.TensorStructInfo(dtype="float32", ndim=3), - relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorType(dtype="float32", ndim=3), + relax.TensorType(dtype="float32", ndim=3), + relax.TensorType(dtype="float32", ndim=3), ] ), ) _check_inference( bb, relax.op.split(x1, 3, axis=1), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="float32", ndim=3), - relax.TensorStructInfo(dtype="float32", ndim=3), - relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorType(dtype="float32", ndim=3), + relax.TensorType(dtype="float32", ndim=3), + relax.TensorType(dtype="float32", ndim=3), ] ), ) _check_inference( bb, relax.op.split(x2, 3, axis=1), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="float32"), - relax.TensorStructInfo(dtype="float32"), - relax.TensorStructInfo(dtype="float32"), + relax.TensorType(dtype="float32"), + relax.TensorType(dtype="float32"), + relax.TensorType(dtype="float32"), ] ), ) -def test_split_infer_struct_info_more_input_dtype(): +def test_split_infer_ty_more_input_dtype(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 10, 4), "float16")) x1 = relax.Var("x", R.Tensor((2, 10, 4), "int8")) @@ -2058,122 +1944,122 @@ def test_split_infer_struct_info_more_input_dtype(): _check_inference( bb, relax.op.split(x0, [3, 7], axis=1), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo((2, 3, 4), "float16"), - relax.TensorStructInfo((2, 4, 4), "float16"), - relax.TensorStructInfo((2, 3, 4), "float16"), + relax.TensorType((2, 3, 4), "float16"), + relax.TensorType((2, 4, 4), "float16"), + relax.TensorType((2, 3, 4), "float16"), ] ), ) _check_inference( bb, relax.op.split(x1, [3, 7], axis=1), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo((2, 3, 4), "int8"), - relax.TensorStructInfo((2, 4, 4), "int8"), - relax.TensorStructInfo((2, 3, 4), "int8"), + relax.TensorType((2, 3, 4), "int8"), + relax.TensorType((2, 4, 4), "int8"), + relax.TensorType((2, 3, 4), "int8"), ] ), ) _check_inference( bb, relax.op.split(x0, 3, axis=1), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo((2, 4, 4), "float16"), - relax.TensorStructInfo((2, 4, 4), "float16"), - relax.TensorStructInfo((2, 2, 4), "float16"), + relax.TensorType((2, 4, 4), "float16"), + relax.TensorType((2, 4, 4), "float16"), + relax.TensorType((2, 2, 4), "float16"), ] ), ) _check_inference( bb, relax.op.split(x1, 3, axis=1), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo((2, 4, 4), "int8"), - relax.TensorStructInfo((2, 4, 4), "int8"), - relax.TensorStructInfo((2, 2, 4), "int8"), + relax.TensorType((2, 4, 4), "int8"), + relax.TensorType((2, 4, 4), "int8"), + relax.TensorType((2, 2, 4), "int8"), ] ), ) -def test_split_infer_struct_info_single_output(): +def test_split_infer_ty_single_output(): bb = relax.BlockBuilder() a = tirx.Var("a", "int64") b = tirx.Var("b", "int64") - s0 = relax.Var("s", relax.ShapeStructInfo((a, b))) - s1 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) - s2 = relax.Var("s", relax.ShapeStructInfo()) + s0 = relax.Var("s", relax.ShapeType((a, b))) + s1 = relax.Var("s", relax.ShapeType(ndim=2)) + s2 = relax.Var("s", relax.ShapeType()) x0 = relax.Var("x", R.Tensor((a, b), "float32")) x1 = relax.Var("x", R.Tensor("float32", ndim=2)) x2 = relax.Var("x", R.Tensor("float32")) - x3 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) - x4 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) - x5 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + x3 = relax.Var("x", relax.TensorType(s0, "float32")) + x4 = relax.Var("x", relax.TensorType(s1, "float32")) + x5 = relax.Var("x", relax.TensorType(s2, "float32")) _check_inference( bb, relax.op.split(x0, [], axis=1), - relax.TensorStructInfo((a, b), "float32"), + relax.TensorType((a, b), "float32"), ) _check_inference( bb, relax.op.split(x1, [], axis=1), - relax.TensorStructInfo(dtype="float32", ndim=2), + relax.TensorType(dtype="float32", ndim=2), ) _check_inference( bb, relax.op.split(x2, [], axis=1), - relax.TensorStructInfo(dtype="float32"), + relax.TensorType(dtype="float32"), ) _check_inference( bb, relax.op.split(x3, [], axis=1), - relax.TensorStructInfo(s0, "float32"), + relax.TensorType(s0, "float32"), ) _check_inference( bb, relax.op.split(x4, [], axis=1), - relax.TensorStructInfo(s1, "float32"), + relax.TensorType(s1, "float32"), ) _check_inference( bb, relax.op.split(x5, [], axis=1), - relax.TensorStructInfo(s2, "float32"), + relax.TensorType(s2, "float32"), ) _check_inference( bb, relax.op.split(x0, 1, axis=1), - relax.TensorStructInfo((a, b), "float32"), + relax.TensorType((a, b), "float32"), ) _check_inference( bb, relax.op.split(x1, 1, axis=1), - relax.TensorStructInfo(dtype="float32", ndim=2), + relax.TensorType(dtype="float32", ndim=2), ) _check_inference( bb, relax.op.split(x2, 1, axis=1), - relax.TensorStructInfo(dtype="float32"), + relax.TensorType(dtype="float32"), ) _check_inference( bb, relax.op.split(x3, 1, axis=1), - relax.TensorStructInfo(s0, "float32"), + relax.TensorType(s0, "float32"), ) _check_inference( bb, relax.op.split(x4, 1, axis=1), - relax.TensorStructInfo(s1, "float32"), + relax.TensorType(s1, "float32"), ) _check_inference( bb, relax.op.split(x5, 1, axis=1), - relax.TensorStructInfo(s2, "float32"), + relax.TensorType(s2, "float32"), ) @@ -2187,7 +2073,7 @@ def test_split_indices_or_sections_int64(): assert split1.attrs.indices_or_sections.dtype == "int64" -def test_split_infer_struct_info(): +def test_split_infer_ty(): bb = relax.BlockBuilder() n = tirx.Var("n", "int64") x = relax.Var("x", R.Tensor((16, 4))) @@ -2296,7 +2182,7 @@ def test_split_infer_struct_info(): ) -def test_split_infer_struct_info_non_integer_indices(): +def test_split_infer_ty_non_integer_indices(): bb = relax.BlockBuilder() a = tirx.Var("c", "int64") b = tirx.Var("d", "int64") @@ -2318,7 +2204,7 @@ def test_split_invalid_n_section(): relax.op.split(x, n, axis=1) -def test_split_infer_struct_info_axis_out_of_range(): +def test_split_infer_ty_axis_out_of_range(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3), "float32")) x1 = relax.Var("x", R.Tensor("float32", ndim=2)) @@ -2333,10 +2219,10 @@ def test_split_infer_struct_info_axis_out_of_range(): bb.normalize(relax.op.split(x1, 1, axis=-3)) -def test_split_infer_invalid_struct_info_indices(): +def test_split_infer_invalid_ty_indices(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3), "float32")) - v = relax.Var("v", relax.PrimStructInfo("int64")) + v = relax.Var("v", relax.PrimType("int64")) with pytest.raises(TypeError): bb.normalize(relax.op.split(x0, [v], axis=1)) @@ -2344,10 +2230,10 @@ def test_split_infer_invalid_struct_info_indices(): bb.normalize(relax.op.split(x0, v, axis=1)) -def test_split_infer_struct_info_wrong_input_type(): +def test_split_infer_ty_wrong_input_type(): bb = relax.BlockBuilder() - x0 = relax.Var("x", relax.ShapeStructInfo((2, 3))) - x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3), "float32"))) + x0 = relax.Var("x", relax.ShapeType((2, 3))) + x1 = relax.Var("x", relax.FuncType([], R.Tensor((2, 3), "float32"))) with pytest.raises(TypeError): bb.normalize(relax.op.split(x0, 1, axis=1)) @@ -2355,7 +2241,7 @@ def test_split_infer_struct_info_wrong_input_type(): bb.normalize(relax.op.split(x1, 1, axis=1)) -def test_broadcast_to_infer_struct_info(): +def test_broadcast_to_infer_ty(): bb = relax.BlockBuilder() vdev0 = VDevice("llvm") x0 = relax.Var("x", R.Tensor((2, 1, 3), "float32")) @@ -2367,31 +2253,31 @@ def test_broadcast_to_infer_struct_info(): x6 = relax.Var("x", R.Tensor((2, 1, 3), "float32", vdev0)) _check_inference( - bb, relax.op.broadcast_to(x0, (4, 2, 5, 3)), relax.TensorStructInfo((4, 2, 5, 3), "float32") + bb, relax.op.broadcast_to(x0, (4, 2, 5, 3)), relax.TensorType((4, 2, 5, 3), "float32") ) _check_inference( bb, relax.op.broadcast_to(x6, (4, 2, 5, 3)), - relax.TensorStructInfo((4, 2, 5, 3), "float32", vdev0), + relax.TensorType((4, 2, 5, 3), "float32", vdev0), ) _check_inference( - bb, relax.op.broadcast_to(x1, (4, 2, 5, 3)), relax.TensorStructInfo((4, 2, 5, 3), "float32") + bb, relax.op.broadcast_to(x1, (4, 2, 5, 3)), relax.TensorType((4, 2, 5, 3), "float32") ) _check_inference( - bb, relax.op.broadcast_to(x2, (4, 2, 5, 3)), relax.TensorStructInfo((4, 2, 5, 3), "float32") + bb, relax.op.broadcast_to(x2, (4, 2, 5, 3)), relax.TensorType((4, 2, 5, 3), "float32") ) _check_inference( - bb, relax.op.broadcast_to(x3, (4, 2, 5, 3)), relax.TensorStructInfo((4, 2, 5, 3), dtype="") + bb, relax.op.broadcast_to(x3, (4, 2, 5, 3)), relax.TensorType((4, 2, 5, 3), dtype="") ) _check_inference( - bb, relax.op.broadcast_to(x4, (4, 2, 5, 3)), relax.TensorStructInfo((4, 2, 5, 3), dtype="") + bb, relax.op.broadcast_to(x4, (4, 2, 5, 3)), relax.TensorType((4, 2, 5, 3), dtype="") ) _check_inference( - bb, relax.op.broadcast_to(x5, (4, 2, 5, 3)), relax.TensorStructInfo((4, 2, 5, 3), dtype="") + bb, relax.op.broadcast_to(x5, (4, 2, 5, 3)), relax.TensorType((4, 2, 5, 3), dtype="") ) -def test_broadcast_to_infer_struct_info_shape_symbolic(): +def test_broadcast_to_infer_ty_shape_symbolic(): bb = relax.BlockBuilder() a = tirx.Var("a", "int64") b = tirx.Var("b", "int64") @@ -2403,101 +2289,101 @@ def test_broadcast_to_infer_struct_info_shape_symbolic(): _check_inference( bb, relax.op.broadcast_to(x0, (a, b, 1, c, d)), - relax.TensorStructInfo((a, b, 1, c, d), "float32"), + relax.TensorType((a, b, 1, c, d), "float32"), ) _check_inference( bb, relax.op.broadcast_to(x1, (a, b, 1, c, d)), - relax.TensorStructInfo((a, b, 1, c, d), dtype=""), + relax.TensorType((a, b, 1, c, d), dtype=""), ) -def test_broadcast_to_infer_struct_info_shape_var(): +def test_broadcast_to_infer_ty_shape_var(): bb = relax.BlockBuilder() - s0 = relax.Var("s", relax.ShapeStructInfo((2, 1, 3))) - s1 = relax.Var("s", relax.ShapeStructInfo(ndim=3)) - s2 = relax.Var("s", relax.ShapeStructInfo()) - x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) - x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) - x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + s0 = relax.Var("s", relax.ShapeType((2, 1, 3))) + s1 = relax.Var("s", relax.ShapeType(ndim=3)) + s2 = relax.Var("s", relax.ShapeType()) + x0 = relax.Var("x", relax.TensorType(s0, "float32")) + x1 = relax.Var("x", relax.TensorType(s1, "float32")) + x2 = relax.Var("x", relax.TensorType(s2, "float32")) _check_inference( - bb, relax.op.broadcast_to(x0, (4, 2, 5, 3)), relax.TensorStructInfo((4, 2, 5, 3), "float32") + bb, relax.op.broadcast_to(x0, (4, 2, 5, 3)), relax.TensorType((4, 2, 5, 3), "float32") ) _check_inference( - bb, relax.op.broadcast_to(x1, (4, 2, 5, 3)), relax.TensorStructInfo((4, 2, 5, 3), "float32") + bb, relax.op.broadcast_to(x1, (4, 2, 5, 3)), relax.TensorType((4, 2, 5, 3), "float32") ) _check_inference( - bb, relax.op.broadcast_to(x2, (4, 2, 5, 3)), relax.TensorStructInfo((4, 2, 5, 3), "float32") + bb, relax.op.broadcast_to(x2, (4, 2, 5, 3)), relax.TensorType((4, 2, 5, 3), "float32") ) -def test_broadcast_to_infer_struct_info_tgt_shape_var(): +def test_broadcast_to_infer_ty_tgt_shape_var(): bb = relax.BlockBuilder() a = tirx.Var("a", "int64") b = tirx.Var("b", "int64") c = tirx.Var("c", "int64") d = tirx.Var("d", "int64") - s0 = relax.Var("s", relax.ShapeStructInfo((b, 1, 1, d))) - s1 = relax.Var("s", relax.ShapeStructInfo(ndim=4)) - s2 = relax.Var("s", relax.ShapeStructInfo()) + s0 = relax.Var("s", relax.ShapeType((b, 1, 1, d))) + s1 = relax.Var("s", relax.ShapeType(ndim=4)) + s2 = relax.Var("s", relax.ShapeType()) x0 = relax.Var("x", R.Tensor((b, 1, 1, d), "float32")) x1 = relax.Var("x", R.Tensor("float32", ndim=4)) x2 = relax.Var("x", R.Tensor("float32")) - x3 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) - x4 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) - x5 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) - stgt0 = relax.Var("stgt", relax.ShapeStructInfo((a, b, 1, c, d))) - stgt1 = relax.Var("stgt", relax.ShapeStructInfo(ndim=5)) - stgt2 = relax.Var("stgt", relax.ShapeStructInfo()) - - _check_inference(bb, relax.op.broadcast_to(x0, stgt0), relax.TensorStructInfo(stgt0, "float32")) - _check_inference(bb, relax.op.broadcast_to(x1, stgt0), relax.TensorStructInfo(stgt0, "float32")) - _check_inference(bb, relax.op.broadcast_to(x2, stgt0), relax.TensorStructInfo(stgt0, "float32")) - _check_inference(bb, relax.op.broadcast_to(x3, stgt0), relax.TensorStructInfo(stgt0, "float32")) - _check_inference(bb, relax.op.broadcast_to(x4, stgt0), relax.TensorStructInfo(stgt0, "float32")) - _check_inference(bb, relax.op.broadcast_to(x5, stgt0), relax.TensorStructInfo(stgt0, "float32")) - _check_inference(bb, relax.op.broadcast_to(x0, stgt1), relax.TensorStructInfo(stgt1, "float32")) - _check_inference(bb, relax.op.broadcast_to(x1, stgt1), relax.TensorStructInfo(stgt1, "float32")) - _check_inference(bb, relax.op.broadcast_to(x2, stgt1), relax.TensorStructInfo(stgt1, "float32")) - _check_inference(bb, relax.op.broadcast_to(x3, stgt1), relax.TensorStructInfo(stgt1, "float32")) - _check_inference(bb, relax.op.broadcast_to(x4, stgt1), relax.TensorStructInfo(stgt1, "float32")) - _check_inference(bb, relax.op.broadcast_to(x5, stgt1), relax.TensorStructInfo(stgt1, "float32")) - _check_inference(bb, relax.op.broadcast_to(x0, stgt2), relax.TensorStructInfo(stgt2, "float32")) - _check_inference(bb, relax.op.broadcast_to(x1, stgt2), relax.TensorStructInfo(stgt2, "float32")) - _check_inference(bb, relax.op.broadcast_to(x2, stgt2), relax.TensorStructInfo(stgt2, "float32")) - _check_inference(bb, relax.op.broadcast_to(x3, stgt2), relax.TensorStructInfo(stgt2, "float32")) - _check_inference(bb, relax.op.broadcast_to(x4, stgt2), relax.TensorStructInfo(stgt2, "float32")) - _check_inference(bb, relax.op.broadcast_to(x5, stgt2), relax.TensorStructInfo(stgt2, "float32")) - - -def test_broadcast_to_infer_struct_info_more_input_dtype(): + x3 = relax.Var("x", relax.TensorType(s0, "float32")) + x4 = relax.Var("x", relax.TensorType(s1, "float32")) + x5 = relax.Var("x", relax.TensorType(s2, "float32")) + stgt0 = relax.Var("stgt", relax.ShapeType((a, b, 1, c, d))) + stgt1 = relax.Var("stgt", relax.ShapeType(ndim=5)) + stgt2 = relax.Var("stgt", relax.ShapeType()) + + _check_inference(bb, relax.op.broadcast_to(x0, stgt0), relax.TensorType(stgt0, "float32")) + _check_inference(bb, relax.op.broadcast_to(x1, stgt0), relax.TensorType(stgt0, "float32")) + _check_inference(bb, relax.op.broadcast_to(x2, stgt0), relax.TensorType(stgt0, "float32")) + _check_inference(bb, relax.op.broadcast_to(x3, stgt0), relax.TensorType(stgt0, "float32")) + _check_inference(bb, relax.op.broadcast_to(x4, stgt0), relax.TensorType(stgt0, "float32")) + _check_inference(bb, relax.op.broadcast_to(x5, stgt0), relax.TensorType(stgt0, "float32")) + _check_inference(bb, relax.op.broadcast_to(x0, stgt1), relax.TensorType(stgt1, "float32")) + _check_inference(bb, relax.op.broadcast_to(x1, stgt1), relax.TensorType(stgt1, "float32")) + _check_inference(bb, relax.op.broadcast_to(x2, stgt1), relax.TensorType(stgt1, "float32")) + _check_inference(bb, relax.op.broadcast_to(x3, stgt1), relax.TensorType(stgt1, "float32")) + _check_inference(bb, relax.op.broadcast_to(x4, stgt1), relax.TensorType(stgt1, "float32")) + _check_inference(bb, relax.op.broadcast_to(x5, stgt1), relax.TensorType(stgt1, "float32")) + _check_inference(bb, relax.op.broadcast_to(x0, stgt2), relax.TensorType(stgt2, "float32")) + _check_inference(bb, relax.op.broadcast_to(x1, stgt2), relax.TensorType(stgt2, "float32")) + _check_inference(bb, relax.op.broadcast_to(x2, stgt2), relax.TensorType(stgt2, "float32")) + _check_inference(bb, relax.op.broadcast_to(x3, stgt2), relax.TensorType(stgt2, "float32")) + _check_inference(bb, relax.op.broadcast_to(x4, stgt2), relax.TensorType(stgt2, "float32")) + _check_inference(bb, relax.op.broadcast_to(x5, stgt2), relax.TensorType(stgt2, "float32")) + + +def test_broadcast_to_infer_ty_more_input_dtype(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 1, 3), "float16")) x1 = relax.Var("x", R.Tensor((2, 1, 3), "int8")) x2 = relax.Var("x", R.Tensor((2, 1, 3), "int32")) _check_inference( - bb, relax.op.broadcast_to(x0, (4, 2, 5, 3)), relax.TensorStructInfo((4, 2, 5, 3), "float16") + bb, relax.op.broadcast_to(x0, (4, 2, 5, 3)), relax.TensorType((4, 2, 5, 3), "float16") ) _check_inference( - bb, relax.op.broadcast_to(x1, (4, 2, 5, 3)), relax.TensorStructInfo((4, 2, 5, 3), "int8") + bb, relax.op.broadcast_to(x1, (4, 2, 5, 3)), relax.TensorType((4, 2, 5, 3), "int8") ) _check_inference( - bb, relax.op.broadcast_to(x2, (4, 2, 5, 3)), relax.TensorStructInfo((4, 2, 5, 3), "int32") + bb, relax.op.broadcast_to(x2, (4, 2, 5, 3)), relax.TensorType((4, 2, 5, 3), "int32") ) -def test_broadcast_to_infer_struct_info_tgt_ndim_less_than_old_ndim(): +def test_broadcast_to_infer_ty_tgt_ndim_less_than_old_ndim(): bb = relax.BlockBuilder() - s0 = relax.Var("s", relax.ShapeStructInfo((2, 1))) - s1 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) + s0 = relax.Var("s", relax.ShapeType((2, 1))) + s1 = relax.Var("s", relax.ShapeType(ndim=2)) x0 = relax.Var("x", R.Tensor((2, 1), "float32")) x1 = relax.Var("x", R.Tensor("float32", ndim=2)) - x2 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) - x3 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) - stgt0 = relax.Var("stgt", relax.ShapeStructInfo((2,))) - stgt1 = relax.Var("stgt", relax.ShapeStructInfo(ndim=1)) + x2 = relax.Var("x", relax.TensorType(s0, "float32")) + x3 = relax.Var("x", relax.TensorType(s1, "float32")) + stgt0 = relax.Var("stgt", relax.ShapeType((2,))) + stgt1 = relax.Var("stgt", relax.ShapeType(ndim=1)) with pytest.raises(ValueError): bb.normalize(relax.op.broadcast_to(x0, (2,))) @@ -2525,12 +2411,12 @@ def test_broadcast_to_infer_struct_info_tgt_ndim_less_than_old_ndim(): bb.normalize(relax.op.broadcast_to(x3, stgt1)) -def test_broadcast_to_infer_struct_info_not_broadcastable_static(): +def test_broadcast_to_infer_ty_not_broadcastable_static(): bb = relax.BlockBuilder() - s = relax.Var("s", relax.ShapeStructInfo((2, 1, 3))) + s = relax.Var("s", relax.ShapeType((2, 1, 3))) x0 = relax.Var("x", R.Tensor((2, 1, 3), "float32")) - x1 = relax.Var("x", relax.TensorStructInfo(s, "float32")) - stgt = relax.Var("stgt", relax.ShapeStructInfo((2, 1, 6))) + x1 = relax.Var("x", relax.TensorType(s, "float32")) + stgt = relax.Var("stgt", relax.ShapeType((2, 1, 6))) with pytest.raises(ValueError): bb.normalize(relax.op.broadcast_to(x0, (2, 1, 6))) @@ -2542,48 +2428,36 @@ def test_broadcast_to_infer_struct_info_not_broadcastable_static(): bb.normalize(relax.op.broadcast_to(x1, stgt)) -def test_broadcast_to_infer_struct_info_not_broadcastable_symbolic(): +def test_broadcast_to_infer_ty_not_broadcastable_symbolic(): bb = relax.BlockBuilder() a = tirx.Var("a", "int64") b = tirx.Var("b", "int64") - s = relax.Var("s", relax.ShapeStructInfo((2, a))) + s = relax.Var("s", relax.ShapeType((2, a))) x0 = relax.Var("x", R.Tensor((2, a), "float32")) - x1 = relax.Var("x", relax.TensorStructInfo(s, "float32")) - stgt0 = relax.Var("stgt", relax.ShapeStructInfo((2, b))) - stgt1 = relax.Var("stgt", relax.ShapeStructInfo((2, 1))) - stgt2 = relax.Var("stgt", relax.ShapeStructInfo((b, a))) - - _check_inference( - bb, relax.op.broadcast_to(x0, (2, b)), relax.TensorStructInfo((2, b), "float32") - ) - _check_inference( - bb, relax.op.broadcast_to(x0, (2, 1)), relax.TensorStructInfo((2, 1), "float32") - ) - _check_inference( - bb, relax.op.broadcast_to(x0, (b, a)), relax.TensorStructInfo((b, a), "float32") - ) - _check_inference(bb, relax.op.broadcast_to(x0, stgt0), relax.TensorStructInfo(stgt0, "float32")) - _check_inference(bb, relax.op.broadcast_to(x0, stgt1), relax.TensorStructInfo(stgt1, "float32")) - _check_inference(bb, relax.op.broadcast_to(x0, stgt2), relax.TensorStructInfo(stgt2, "float32")) - _check_inference( - bb, relax.op.broadcast_to(x1, (2, b)), relax.TensorStructInfo((2, b), "float32") - ) - _check_inference( - bb, relax.op.broadcast_to(x1, (2, 1)), relax.TensorStructInfo((2, 1), "float32") - ) - _check_inference( - bb, relax.op.broadcast_to(x1, (b, a)), relax.TensorStructInfo((b, a), "float32") - ) - _check_inference(bb, relax.op.broadcast_to(x1, stgt0), relax.TensorStructInfo(stgt0, "float32")) - _check_inference(bb, relax.op.broadcast_to(x1, stgt1), relax.TensorStructInfo(stgt1, "float32")) - _check_inference(bb, relax.op.broadcast_to(x1, stgt2), relax.TensorStructInfo(stgt2, "float32")) - - -def test_broadcast_to_infer_struct_info_wrong_input_type(): - bb = relax.BlockBuilder() - x0 = relax.Var("x", relax.ShapeStructInfo((2, 1, 3))) + x1 = relax.Var("x", relax.TensorType(s, "float32")) + stgt0 = relax.Var("stgt", relax.ShapeType((2, b))) + stgt1 = relax.Var("stgt", relax.ShapeType((2, 1))) + stgt2 = relax.Var("stgt", relax.ShapeType((b, a))) + + _check_inference(bb, relax.op.broadcast_to(x0, (2, b)), relax.TensorType((2, b), "float32")) + _check_inference(bb, relax.op.broadcast_to(x0, (2, 1)), relax.TensorType((2, 1), "float32")) + _check_inference(bb, relax.op.broadcast_to(x0, (b, a)), relax.TensorType((b, a), "float32")) + _check_inference(bb, relax.op.broadcast_to(x0, stgt0), relax.TensorType(stgt0, "float32")) + _check_inference(bb, relax.op.broadcast_to(x0, stgt1), relax.TensorType(stgt1, "float32")) + _check_inference(bb, relax.op.broadcast_to(x0, stgt2), relax.TensorType(stgt2, "float32")) + _check_inference(bb, relax.op.broadcast_to(x1, (2, b)), relax.TensorType((2, b), "float32")) + _check_inference(bb, relax.op.broadcast_to(x1, (2, 1)), relax.TensorType((2, 1), "float32")) + _check_inference(bb, relax.op.broadcast_to(x1, (b, a)), relax.TensorType((b, a), "float32")) + _check_inference(bb, relax.op.broadcast_to(x1, stgt0), relax.TensorType(stgt0, "float32")) + _check_inference(bb, relax.op.broadcast_to(x1, stgt1), relax.TensorType(stgt1, "float32")) + _check_inference(bb, relax.op.broadcast_to(x1, stgt2), relax.TensorType(stgt2, "float32")) + + +def test_broadcast_to_infer_ty_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeType((2, 1, 3))) x1 = relax.Var("x", R.Tensor((2, 1, 3), "float32")) - stgt = relax.Var("stgt", relax.TensorStructInfo((4, 2, 5, 3), dtype="")) + stgt = relax.Var("stgt", relax.TensorType((4, 2, 5, 3), dtype="")) with pytest.raises(TypeError): bb.normalize(relax.op.broadcast_to(x0, (4, 2, 5, 3))) @@ -2591,7 +2465,7 @@ def test_broadcast_to_infer_struct_info_wrong_input_type(): bb.normalize(relax.op.broadcast_to(x1, stgt)) -def test_collapse_sum_like_infer_struct_info(): +def test_collapse_sum_like_infer_ty(): bb = relax.BlockBuilder() vdev0 = VDevice("llvm") x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32")) @@ -2608,42 +2482,30 @@ def test_collapse_sum_like_infer_struct_info(): y5 = relax.Var("y", R.Tensor((1, 4))) y6 = relax.Var("y", R.Tensor((3, 4), "float32", vdev0)) + _check_inference(bb, relax.op.collapse_sum_like(x0, y0), relax.TensorType((3, 4), "float32")) _check_inference( - bb, relax.op.collapse_sum_like(x0, y0), relax.TensorStructInfo((3, 4), "float32") - ) - _check_inference( - bb, relax.op.collapse_sum_like(x3, y6), relax.TensorStructInfo((3, 4), "float32", vdev0) - ) - _check_inference( - bb, relax.op.collapse_sum_like(x1, y1), relax.TensorStructInfo(dtype="float32", ndim=2) + bb, relax.op.collapse_sum_like(x3, y6), relax.TensorType((3, 4), "float32", vdev0) ) _check_inference( - bb, relax.op.collapse_sum_like(x0, y1), relax.TensorStructInfo(dtype="float32", ndim=2) + bb, relax.op.collapse_sum_like(x1, y1), relax.TensorType(dtype="float32", ndim=2) ) _check_inference( - bb, relax.op.collapse_sum_like(x0, y2), relax.TensorStructInfo(dtype="float32", ndim=-1) + bb, relax.op.collapse_sum_like(x0, y1), relax.TensorType(dtype="float32", ndim=2) ) _check_inference( - bb, relax.op.collapse_sum_like(x0, y3), relax.TensorStructInfo((3, 4), "float32") + bb, relax.op.collapse_sum_like(x0, y2), relax.TensorType(dtype="float32", ndim=-1) ) + _check_inference(bb, relax.op.collapse_sum_like(x0, y3), relax.TensorType((3, 4), "float32")) + _check_inference(bb, relax.op.collapse_sum_like(x2, y0), relax.TensorType((3, 4), "float32")) _check_inference( - bb, relax.op.collapse_sum_like(x2, y0), relax.TensorStructInfo((3, 4), "float32") - ) - _check_inference( - bb, relax.op.collapse_sum_like(x2, y4), relax.TensorStructInfo(dtype="float32", ndim=2) - ) - _check_inference( - bb, relax.op.collapse_sum_like(x4, y1), relax.TensorStructInfo(dtype="", ndim=2) - ) - _check_inference( - bb, relax.op.collapse_sum_like(x5, y3), relax.TensorStructInfo((3, 4), dtype="") - ) - _check_inference( - bb, relax.op.collapse_sum_like(x0, y5), relax.TensorStructInfo((1, 4), "float32") + bb, relax.op.collapse_sum_like(x2, y4), relax.TensorType(dtype="float32", ndim=2) ) + _check_inference(bb, relax.op.collapse_sum_like(x4, y1), relax.TensorType(dtype="", ndim=2)) + _check_inference(bb, relax.op.collapse_sum_like(x5, y3), relax.TensorType((3, 4), dtype="")) + _check_inference(bb, relax.op.collapse_sum_like(x0, y5), relax.TensorType((1, 4), "float32")) -def test_collapse_sum_like_infer_struct_info_shape_symbolic(): +def test_collapse_sum_like_infer_ty_shape_symbolic(): bb = relax.BlockBuilder() a = tirx.Var("a", "int64") b = tirx.Var("b", "int64") @@ -2652,52 +2514,48 @@ def test_collapse_sum_like_infer_struct_info_shape_symbolic(): x1 = relax.Var("x", R.Tensor((3, 4, b + a), "float32")) y1 = relax.Var("x", R.Tensor((1, a + b), "float32")) + _check_inference(bb, relax.op.collapse_sum_like(x0, y0), relax.TensorType((4, a), "float32")) _check_inference( - bb, relax.op.collapse_sum_like(x0, y0), relax.TensorStructInfo((4, a), "float32") - ) - _check_inference( - bb, relax.op.collapse_sum_like(x1, y1), relax.TensorStructInfo((1, a + b), "float32") + bb, relax.op.collapse_sum_like(x1, y1), relax.TensorType((1, a + b), "float32") ) -def test_collapse_sum_like_infer_struct_info_shape_var(): +def test_collapse_sum_like_infer_ty_shape_var(): bb = relax.BlockBuilder() - s0 = relax.Var("s0", relax.ShapeStructInfo((2, 3, 4))) - s1 = relax.Var("s1", relax.ShapeStructInfo(ndim=3)) - s2 = relax.Var("s2", relax.ShapeStructInfo()) - s3 = relax.Var("s3", relax.ShapeStructInfo((3, 4))) - s4 = relax.Var("s4", relax.ShapeStructInfo(ndim=2)) - s5 = relax.Var("s5", relax.ShapeStructInfo()) - x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) - x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) - x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) - y0 = relax.Var("y", relax.TensorStructInfo(s3, "float32")) - y1 = relax.Var("y", relax.TensorStructInfo(s4, "float32")) - y2 = relax.Var("y", relax.TensorStructInfo(s5, "float32")) + s0 = relax.Var("s0", relax.ShapeType((2, 3, 4))) + s1 = relax.Var("s1", relax.ShapeType(ndim=3)) + s2 = relax.Var("s2", relax.ShapeType()) + s3 = relax.Var("s3", relax.ShapeType((3, 4))) + s4 = relax.Var("s4", relax.ShapeType(ndim=2)) + s5 = relax.Var("s5", relax.ShapeType()) + x0 = relax.Var("x", relax.TensorType(s0, "float32")) + x1 = relax.Var("x", relax.TensorType(s1, "float32")) + x2 = relax.Var("x", relax.TensorType(s2, "float32")) + y0 = relax.Var("y", relax.TensorType(s3, "float32")) + y1 = relax.Var("y", relax.TensorType(s4, "float32")) + y2 = relax.Var("y", relax.TensorType(s5, "float32")) - _check_inference(bb, relax.op.collapse_sum_like(x0, y0), relax.TensorStructInfo(s3, "float32")) - _check_inference(bb, relax.op.collapse_sum_like(x1, y1), relax.TensorStructInfo(s4, "float32")) - _check_inference(bb, relax.op.collapse_sum_like(x2, y2), relax.TensorStructInfo(s5, "float32")) + _check_inference(bb, relax.op.collapse_sum_like(x0, y0), relax.TensorType(s3, "float32")) + _check_inference(bb, relax.op.collapse_sum_like(x1, y1), relax.TensorType(s4, "float32")) + _check_inference(bb, relax.op.collapse_sum_like(x2, y2), relax.TensorType(s5, "float32")) -def test_collapse_sum_like_infer_struct_info_more_input_dtype(): +def test_collapse_sum_like_infer_ty_more_input_dtype(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3, 4), "float16")) x1 = relax.Var("x", R.Tensor((2, 3, 4), "int8")) y0 = relax.Var("y", R.Tensor((3, 4), "float16")) y1 = relax.Var("y", R.Tensor((3, 4), "int8")) - _check_inference( - bb, relax.op.collapse_sum_like(x0, y0), relax.TensorStructInfo((3, 4), "float16") - ) - _check_inference(bb, relax.op.collapse_sum_like(x1, y1), relax.TensorStructInfo((3, 4), "int8")) + _check_inference(bb, relax.op.collapse_sum_like(x0, y0), relax.TensorType((3, 4), "float16")) + _check_inference(bb, relax.op.collapse_sum_like(x1, y1), relax.TensorType((3, 4), "int8")) -def test_collapse_sum_like_infer_struct_info_wrong_input_type(): +def test_collapse_sum_like_infer_ty_wrong_input_type(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((3, 4, 5), "float32")) - x1 = relax.Var("x", relax.ShapeStructInfo((4, 5))) - x2 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 4), "float32"))) + x1 = relax.Var("x", relax.ShapeType((4, 5))) + x2 = relax.Var("x", relax.FuncType([], R.Tensor((2, 3, 4), "float32"))) with pytest.raises(TypeError): bb.normalize(relax.op.collapse_sum_like(x0, x1)) @@ -2706,7 +2564,7 @@ def test_collapse_sum_like_infer_struct_info_wrong_input_type(): bb.normalize(relax.op.collapse_sum_like(x2, x0)) -def test_collapse_sum_like_infer_struct_info_shape_mismatch(): +def test_collapse_sum_like_infer_ty_shape_mismatch(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((3, 4, 5), "float32")) y0 = relax.Var("y", R.Tensor((3, 6, 5), "float32")) @@ -2715,15 +2573,15 @@ def test_collapse_sum_like_infer_struct_info_shape_mismatch(): x1 = relax.Var("z", R.Tensor((3, a, 5), "float32")) y1 = relax.Var("w", R.Tensor((3, b, 5), "float32")) - s0 = relax.Var("s0", relax.ShapeStructInfo((3, 4, 5))) - s1 = relax.Var("s1", relax.ShapeStructInfo((3, 6, 5))) - x2 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) - y2 = relax.Var("y", relax.TensorStructInfo(s1, "float32")) + s0 = relax.Var("s0", relax.ShapeType((3, 4, 5))) + s1 = relax.Var("s1", relax.ShapeType((3, 6, 5))) + x2 = relax.Var("x", relax.TensorType(s0, "float32")) + y2 = relax.Var("y", relax.TensorType(s1, "float32")) - s2 = relax.Var("s2", relax.ShapeStructInfo((3, a, 5))) - s3 = relax.Var("s3", relax.ShapeStructInfo((3, b, 5))) - x3 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) - y3 = relax.Var("y", relax.TensorStructInfo(s3, "float32")) + s2 = relax.Var("s2", relax.ShapeType((3, a, 5))) + s3 = relax.Var("s3", relax.ShapeType((3, b, 5))) + x3 = relax.Var("x", relax.TensorType(s2, "float32")) + y3 = relax.Var("y", relax.TensorType(s3, "float32")) with pytest.raises(ValueError): bb.normalize(relax.op.collapse_sum_like(x0, y0)) @@ -2738,7 +2596,7 @@ def test_collapse_sum_like_infer_struct_info_shape_mismatch(): bb.normalize(relax.op.collapse_sum_like(x3, y3)) -def test_collapse_sum_to_infer_struct_info(): +def test_collapse_sum_to_infer_ty(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32")) x1 = relax.Var("x", R.Tensor("float32", ndim=3)) @@ -2747,72 +2605,54 @@ def test_collapse_sum_to_infer_struct_info(): x4 = relax.Var("x", R.Tensor(ndim=3)) x5 = relax.Var("x", R.Tensor()) - _check_inference( - bb, relax.op.collapse_sum_to(x0, (3, 4)), relax.TensorStructInfo((3, 4), "float32") - ) - _check_inference( - bb, relax.op.collapse_sum_to(x1, (3, 4)), relax.TensorStructInfo((3, 4), "float32") - ) - _check_inference( - bb, relax.op.collapse_sum_to(x2, (3, 4)), relax.TensorStructInfo((3, 4), "float32") - ) - _check_inference(bb, relax.op.collapse_sum_to(x3, (3, 4)), relax.TensorStructInfo((3, 4), "")) - _check_inference(bb, relax.op.collapse_sum_to(x4, (3, 4)), relax.TensorStructInfo((3, 4), "")) - _check_inference(bb, relax.op.collapse_sum_to(x5, (3, 4)), relax.TensorStructInfo((3, 4), "")) + _check_inference(bb, relax.op.collapse_sum_to(x0, (3, 4)), relax.TensorType((3, 4), "float32")) + _check_inference(bb, relax.op.collapse_sum_to(x1, (3, 4)), relax.TensorType((3, 4), "float32")) + _check_inference(bb, relax.op.collapse_sum_to(x2, (3, 4)), relax.TensorType((3, 4), "float32")) + _check_inference(bb, relax.op.collapse_sum_to(x3, (3, 4)), relax.TensorType((3, 4), "")) + _check_inference(bb, relax.op.collapse_sum_to(x4, (3, 4)), relax.TensorType((3, 4), "")) + _check_inference(bb, relax.op.collapse_sum_to(x5, (3, 4)), relax.TensorType((3, 4), "")) -def test_collapse_sum_to_infer_struct_info_shape_symbolic(): +def test_collapse_sum_to_infer_ty_shape_symbolic(): bb = relax.BlockBuilder() a = tirx.Var("a", "int64") b = tirx.Var("b", "int64") x0 = relax.Var("x", R.Tensor((3, 4, a), "float32")) x1 = relax.Var("x", R.Tensor((3, 4, b + a), "float32")) + _check_inference(bb, relax.op.collapse_sum_to(x0, (4, a)), relax.TensorType((4, a), "float32")) _check_inference( - bb, relax.op.collapse_sum_to(x0, (4, a)), relax.TensorStructInfo((4, a), "float32") - ) - _check_inference( - bb, relax.op.collapse_sum_to(x1, (1, a + b)), relax.TensorStructInfo((1, a + b), "float32") + bb, relax.op.collapse_sum_to(x1, (1, a + b)), relax.TensorType((1, a + b), "float32") ) -def test_collapse_sum_to_infer_struct_info_shape_var(): +def test_collapse_sum_to_infer_ty_shape_var(): bb = relax.BlockBuilder() - s0 = relax.Var("s0", relax.ShapeStructInfo((2, 3, 4))) - s1 = relax.Var("s1", relax.ShapeStructInfo(ndim=3)) - s2 = relax.Var("s2", relax.ShapeStructInfo()) - x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) - x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) - x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) - _check_inference( - bb, relax.op.collapse_sum_to(x0, (3, 4)), relax.TensorStructInfo((3, 4), "float32") - ) - _check_inference( - bb, relax.op.collapse_sum_to(x1, (3, 4)), relax.TensorStructInfo((3, 4), "float32") - ) - _check_inference( - bb, relax.op.collapse_sum_to(x1, (3, 4)), relax.TensorStructInfo((3, 4), "float32") - ) + s0 = relax.Var("s0", relax.ShapeType((2, 3, 4))) + s1 = relax.Var("s1", relax.ShapeType(ndim=3)) + s2 = relax.Var("s2", relax.ShapeType()) + x0 = relax.Var("x", relax.TensorType(s0, "float32")) + x1 = relax.Var("x", relax.TensorType(s1, "float32")) + x2 = relax.Var("x", relax.TensorType(s2, "float32")) + _check_inference(bb, relax.op.collapse_sum_to(x0, (3, 4)), relax.TensorType((3, 4), "float32")) + _check_inference(bb, relax.op.collapse_sum_to(x1, (3, 4)), relax.TensorType((3, 4), "float32")) + _check_inference(bb, relax.op.collapse_sum_to(x1, (3, 4)), relax.TensorType((3, 4), "float32")) -def test_collapse_sum_to_infer_struct_info_more_input_dtype(): +def test_collapse_sum_to_infer_ty_more_input_dtype(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3, 4), "float16")) x1 = relax.Var("x", R.Tensor((2, 3, 4), "int8")) - _check_inference( - bb, relax.op.collapse_sum_to(x0, (3, 4)), relax.TensorStructInfo((3, 4), "float16") - ) - _check_inference( - bb, relax.op.collapse_sum_to(x1, (3, 4)), relax.TensorStructInfo((3, 4), "int8") - ) + _check_inference(bb, relax.op.collapse_sum_to(x0, (3, 4)), relax.TensorType((3, 4), "float16")) + _check_inference(bb, relax.op.collapse_sum_to(x1, (3, 4)), relax.TensorType((3, 4), "int8")) -def test_collapse_sum_to_infer_struct_info_wrong_input_type(): +def test_collapse_sum_to_infer_ty_wrong_input_type(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((3, 4, 5), "float32")) - x1 = relax.Var("x", relax.ShapeStructInfo((4, 5))) - x2 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 4), "float32"))) + x1 = relax.Var("x", relax.ShapeType((4, 5))) + x2 = relax.Var("x", relax.FuncType([], R.Tensor((2, 3, 4), "float32"))) with pytest.raises(TypeError): bb.normalize(relax.op.collapse_sum_to(x0, x0)) @@ -2824,18 +2664,18 @@ def test_collapse_sum_to_infer_struct_info_wrong_input_type(): bb.normalize(relax.op.collapse_sum_to(x1, x1)) -def test_collapse_sum_to_infer_struct_info_shape_mismatch(): +def test_collapse_sum_to_infer_ty_shape_mismatch(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((3, 4, 5), "float32")) a = tirx.Var("a", "int64") b = tirx.Var("b", "int64") x1 = relax.Var("x", R.Tensor((3, a, 5), "float32")) - s0 = relax.Var("s0", relax.ShapeStructInfo((3, 4, 5))) - x2 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + s0 = relax.Var("s0", relax.ShapeType((3, 4, 5))) + x2 = relax.Var("x", relax.TensorType(s0, "float32")) - s1 = relax.Var("s1", relax.ShapeStructInfo((3, a, 5))) - x3 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + s1 = relax.Var("s1", relax.ShapeType((3, a, 5))) + x3 = relax.Var("x", relax.TensorType(s1, "float32")) with pytest.raises(ValueError): bb.normalize(relax.op.collapse_sum_to(x0, (4, 4, 5))) @@ -2850,76 +2690,46 @@ def test_collapse_sum_to_infer_struct_info_shape_mismatch(): bb.normalize(relax.op.collapse_sum_to(x3, (3, b, 5))) -def test_collapse_sum_to_infer_struct_info_struct_info_tgt_shape_var(): +def test_collapse_sum_to_infer_type_tgt_shape_var(): bb = relax.BlockBuilder() a = tirx.Var("a", "int64") b = tirx.Var("b", "int64") c = tirx.Var("c", "int64") d = tirx.Var("d", "int64") - s0 = relax.Var("s0", relax.ShapeStructInfo((3, a, b))) - s1 = relax.Var("s1", relax.ShapeStructInfo(ndim=3)) - s2 = relax.Var("s2", relax.ShapeStructInfo()) + s0 = relax.Var("s0", relax.ShapeType((3, a, b))) + s1 = relax.Var("s1", relax.ShapeType(ndim=3)) + s2 = relax.Var("s2", relax.ShapeType()) x0 = relax.Var("x", R.Tensor((3, a, b), "float32")) x1 = relax.Var("x", R.Tensor("float32", ndim=3)) x2 = relax.Var("x", R.Tensor("")) - x3 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) - x4 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) - x5 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) - stgt0 = relax.Var("stgt0", relax.ShapeStructInfo((a, b))) - stgt1 = relax.Var("stgt1", relax.ShapeStructInfo(ndim=2)) - stgt2 = relax.Var("stgt2", relax.ShapeStructInfo()) - - _check_inference( - bb, relax.op.collapse_sum_to(x0, stgt0), relax.TensorStructInfo(stgt0, "float32") - ) - _check_inference( - bb, relax.op.collapse_sum_to(x1, stgt0), relax.TensorStructInfo(stgt0, "float32") - ) - _check_inference(bb, relax.op.collapse_sum_to(x2, stgt0), relax.TensorStructInfo(stgt0, "")) - _check_inference( - bb, relax.op.collapse_sum_to(x3, stgt0), relax.TensorStructInfo(stgt0, "float32") - ) - _check_inference( - bb, relax.op.collapse_sum_to(x4, stgt0), relax.TensorStructInfo(stgt0, "float32") - ) - _check_inference( - bb, relax.op.collapse_sum_to(x5, stgt0), relax.TensorStructInfo(stgt0, "float32") - ) - _check_inference( - bb, relax.op.collapse_sum_to(x0, stgt1), relax.TensorStructInfo(stgt1, "float32") - ) - _check_inference( - bb, relax.op.collapse_sum_to(x1, stgt1), relax.TensorStructInfo(stgt1, "float32") - ) - _check_inference(bb, relax.op.collapse_sum_to(x2, stgt1), relax.TensorStructInfo(stgt1, "")) - _check_inference( - bb, relax.op.collapse_sum_to(x3, stgt1), relax.TensorStructInfo(stgt1, "float32") - ) - _check_inference( - bb, relax.op.collapse_sum_to(x4, stgt1), relax.TensorStructInfo(stgt1, "float32") - ) - _check_inference( - bb, relax.op.collapse_sum_to(x5, stgt1), relax.TensorStructInfo(stgt1, "float32") - ) - _check_inference( - bb, relax.op.collapse_sum_to(x0, stgt2), relax.TensorStructInfo(stgt2, "float32") - ) - _check_inference( - bb, relax.op.collapse_sum_to(x1, stgt2), relax.TensorStructInfo(stgt2, "float32") - ) - _check_inference(bb, relax.op.collapse_sum_to(x2, stgt2), relax.TensorStructInfo(stgt2, "")) - _check_inference( - bb, relax.op.collapse_sum_to(x3, stgt2), relax.TensorStructInfo(stgt2, "float32") - ) - _check_inference( - bb, relax.op.collapse_sum_to(x4, stgt2), relax.TensorStructInfo(stgt2, "float32") - ) - _check_inference( - bb, relax.op.collapse_sum_to(x5, stgt2), relax.TensorStructInfo(stgt2, "float32") - ) - - -def test_repeat_infer_struct_info(): + x3 = relax.Var("x", relax.TensorType(s0, "float32")) + x4 = relax.Var("x", relax.TensorType(s1, "float32")) + x5 = relax.Var("x", relax.TensorType(s2, "float32")) + stgt0 = relax.Var("stgt0", relax.ShapeType((a, b))) + stgt1 = relax.Var("stgt1", relax.ShapeType(ndim=2)) + stgt2 = relax.Var("stgt2", relax.ShapeType()) + + _check_inference(bb, relax.op.collapse_sum_to(x0, stgt0), relax.TensorType(stgt0, "float32")) + _check_inference(bb, relax.op.collapse_sum_to(x1, stgt0), relax.TensorType(stgt0, "float32")) + _check_inference(bb, relax.op.collapse_sum_to(x2, stgt0), relax.TensorType(stgt0, "")) + _check_inference(bb, relax.op.collapse_sum_to(x3, stgt0), relax.TensorType(stgt0, "float32")) + _check_inference(bb, relax.op.collapse_sum_to(x4, stgt0), relax.TensorType(stgt0, "float32")) + _check_inference(bb, relax.op.collapse_sum_to(x5, stgt0), relax.TensorType(stgt0, "float32")) + _check_inference(bb, relax.op.collapse_sum_to(x0, stgt1), relax.TensorType(stgt1, "float32")) + _check_inference(bb, relax.op.collapse_sum_to(x1, stgt1), relax.TensorType(stgt1, "float32")) + _check_inference(bb, relax.op.collapse_sum_to(x2, stgt1), relax.TensorType(stgt1, "")) + _check_inference(bb, relax.op.collapse_sum_to(x3, stgt1), relax.TensorType(stgt1, "float32")) + _check_inference(bb, relax.op.collapse_sum_to(x4, stgt1), relax.TensorType(stgt1, "float32")) + _check_inference(bb, relax.op.collapse_sum_to(x5, stgt1), relax.TensorType(stgt1, "float32")) + _check_inference(bb, relax.op.collapse_sum_to(x0, stgt2), relax.TensorType(stgt2, "float32")) + _check_inference(bb, relax.op.collapse_sum_to(x1, stgt2), relax.TensorType(stgt2, "float32")) + _check_inference(bb, relax.op.collapse_sum_to(x2, stgt2), relax.TensorType(stgt2, "")) + _check_inference(bb, relax.op.collapse_sum_to(x3, stgt2), relax.TensorType(stgt2, "float32")) + _check_inference(bb, relax.op.collapse_sum_to(x4, stgt2), relax.TensorType(stgt2, "float32")) + _check_inference(bb, relax.op.collapse_sum_to(x5, stgt2), relax.TensorType(stgt2, "float32")) + + +def test_repeat_infer_ty(): bb = relax.BlockBuilder() vdev0 = VDevice("llvm") x0 = relax.Var("x", R.Tensor((2, 10, 4), "float32")) @@ -2933,74 +2743,74 @@ def test_repeat_infer_struct_info(): _check_inference( bb, relax.op.repeat(x0, 2, axis=0), - relax.TensorStructInfo((4, 10, 4), "float32"), + relax.TensorType((4, 10, 4), "float32"), ) _check_inference( bb, relax.op.repeat(x6, 2, axis=0), - relax.TensorStructInfo((4, 10, 4), "float32", vdev0), + relax.TensorType((4, 10, 4), "float32", vdev0), ) _check_inference( bb, relax.op.repeat(x0, 2, axis=-2), - relax.TensorStructInfo((2, 20, 4), "float32"), + relax.TensorType((2, 20, 4), "float32"), ) _check_inference( bb, relax.op.repeat(x0, 2), - relax.TensorStructInfo((160,), "float32"), + relax.TensorType((160,), "float32"), ) _check_inference( bb, relax.op.repeat(x1, 2, axis=0), - relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorType(dtype="float32", ndim=3), ) _check_inference( bb, relax.op.repeat(x1, 2), - relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorType(dtype="float32", ndim=1), ) - _check_inference(bb, relax.op.repeat(x2, 2, axis=0), relax.TensorStructInfo(dtype="float32")) - _check_inference(bb, relax.op.repeat(x2, 2), relax.TensorStructInfo(dtype="float32", ndim=1)) + _check_inference(bb, relax.op.repeat(x2, 2, axis=0), relax.TensorType(dtype="float32")) + _check_inference(bb, relax.op.repeat(x2, 2), relax.TensorType(dtype="float32", ndim=1)) _check_inference( bb, relax.op.repeat(x3, 2, axis=0), - relax.TensorStructInfo((4, 10, 4), dtype=""), + relax.TensorType((4, 10, 4), dtype=""), ) - _check_inference(bb, relax.op.repeat(x4, 2, axis=0), relax.TensorStructInfo(dtype="", ndim=3)) - _check_inference(bb, relax.op.repeat(x5, 2, axis=0), relax.TensorStructInfo(dtype="")) + _check_inference(bb, relax.op.repeat(x4, 2, axis=0), relax.TensorType(dtype="", ndim=3)) + _check_inference(bb, relax.op.repeat(x5, 2, axis=0), relax.TensorType(dtype="")) -def test_repeat_infer_struct_info_shape_symbolic(): +def test_repeat_infer_ty_shape_symbolic(): bb = relax.BlockBuilder() a = tirx.Var("a", "int64") b = tirx.Var("b", "int64") c = tirx.Var("c", "int64") x = relax.Var("x", R.Tensor((a, b, c), "float32")) - _check_inference(bb, relax.op.repeat(x, 2, 0), relax.TensorStructInfo((a * 2, b, c), "float32")) + _check_inference(bb, relax.op.repeat(x, 2, 0), relax.TensorType((a * 2, b, c), "float32")) _check_inference( bb, relax.op.repeat(x, 2, -1), - relax.TensorStructInfo((a, b, c * 2), "float32"), + relax.TensorType((a, b, c * 2), "float32"), ) _check_inference( bb, relax.op.repeat(x, 2), - relax.TensorStructInfo((a * b * c * 2,), "float32"), + relax.TensorType((a * b * c * 2,), "float32"), ) -def test_repeat_infer_struct_info_more_input_dtype(): +def test_repeat_infer_ty_more_input_dtype(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3, 4), "float16")) x1 = relax.Var("x", R.Tensor((2, 3, 4), "int8")) - _check_inference(bb, relax.op.repeat(x0, 2, 0), relax.TensorStructInfo((4, 3, 4), "float16")) - _check_inference(bb, relax.op.repeat(x1, 2, 0), relax.TensorStructInfo((4, 3, 4), "int8")) + _check_inference(bb, relax.op.repeat(x0, 2, 0), relax.TensorType((4, 3, 4), "float16")) + _check_inference(bb, relax.op.repeat(x1, 2, 0), relax.TensorType((4, 3, 4), "int8")) -def test_repeat_infer_struct_info_axis_out_of_range(): +def test_repeat_infer_ty_axis_out_of_range(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 10, 4), "float32")) x1 = relax.Var("x", R.Tensor("float32", ndim=3)) @@ -3019,22 +2829,22 @@ def test_repeat_infer_struct_info_axis_out_of_range(): bb.normalize(relax.op.repeat(x2, 2, -4)) -def test_repeat_return_data_sinfo(): +def test_repeat_return_data_ty(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 10, 4), "float32")) x1 = relax.Var("x", R.Tensor("float32", ndim=3)) x2 = relax.Var("x", R.Tensor("float32")) - _check_inference(bb, relax.op.repeat(x0, 1, 0), x0.struct_info) - _check_inference(bb, relax.op.repeat(x0, 1, -1), x0.struct_info) - _check_inference(bb, relax.op.repeat(x1, 1, 0), x1.struct_info) - _check_inference(bb, relax.op.repeat(x2, 1, 0), x2.struct_info) + _check_inference(bb, relax.op.repeat(x0, 1, 0), x0.ty) + _check_inference(bb, relax.op.repeat(x0, 1, -1), x0.ty) + _check_inference(bb, relax.op.repeat(x1, 1, 0), x1.ty) + _check_inference(bb, relax.op.repeat(x2, 1, 0), x2.ty) -def test_repeat_infer_struct_info_wrong_input_type(): +def test_repeat_infer_ty_wrong_input_type(): bb = relax.BlockBuilder() - x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 4, 5))) - x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 4, 5), "float32"))) + x0 = relax.Var("x", relax.ShapeType((2, 3, 4, 5))) + x1 = relax.Var("x", relax.FuncType([], R.Tensor((2, 3, 4, 5), "float32"))) x2 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) r1 = tirx.Var("r", "float32") r2 = tirx.StringImm("abc") @@ -3051,7 +2861,7 @@ def test_repeat_infer_struct_info_wrong_input_type(): bb.normalize(relax.op.repeat(x2, r2)) -def test_tile_infer_struct_info(): +def test_tile_infer_ty(): bb = relax.BlockBuilder() vdev0 = VDevice("llvm") x0 = relax.Var("x", R.Tensor((2, 10, 4), "float32")) @@ -3065,101 +2875,99 @@ def test_tile_infer_struct_info(): _check_inference( bb, relax.op.tile(x0, 2), - relax.TensorStructInfo((2, 10, 8), "float32"), + relax.TensorType((2, 10, 8), "float32"), ) _check_inference( bb, relax.op.tile(x6, 2), - relax.TensorStructInfo((2, 10, 8), "float32", vdev0), + relax.TensorType((2, 10, 8), "float32", vdev0), ) _check_inference( bb, relax.op.tile(x0, (3, 2)), - relax.TensorStructInfo((2, 30, 8), "float32"), + relax.TensorType((2, 30, 8), "float32"), ) _check_inference( bb, relax.op.tile(x0, (4, 3, 2)), - relax.TensorStructInfo((8, 30, 8), "float32"), + relax.TensorType((8, 30, 8), "float32"), ) _check_inference( bb, relax.op.tile(x0, (5, 4, 3, 2)), - relax.TensorStructInfo((5, 8, 30, 8), "float32"), + relax.TensorType((5, 8, 30, 8), "float32"), ) _check_inference( bb, relax.op.tile(x1, 2), - relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorType(dtype="float32", ndim=3), ) _check_inference( bb, relax.op.tile(x1, (5, 4, 3, 2)), - relax.TensorStructInfo(dtype="float32", ndim=4), + relax.TensorType(dtype="float32", ndim=4), ) - _check_inference(bb, relax.op.tile(x2, (5, 4, 3, 2)), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.tile(x2, (5, 4, 3, 2)), relax.TensorType(dtype="float32")) _check_inference( bb, relax.op.tile(x3, 2), - relax.TensorStructInfo((2, 10, 8), dtype=""), + relax.TensorType((2, 10, 8), dtype=""), ) _check_inference( bb, relax.op.tile(x3, (5, 4, 3, 2)), - relax.TensorStructInfo((5, 8, 30, 8), dtype=""), + relax.TensorType((5, 8, 30, 8), dtype=""), ) - _check_inference(bb, relax.op.tile(x4, 2), relax.TensorStructInfo(dtype="", ndim=3)) - _check_inference(bb, relax.op.tile(x4, (5, 4, 3, 2)), relax.TensorStructInfo(dtype="", ndim=4)) - _check_inference(bb, relax.op.tile(x5, (5, 4, 3, 2)), relax.TensorStructInfo(dtype="")) + _check_inference(bb, relax.op.tile(x4, 2), relax.TensorType(dtype="", ndim=3)) + _check_inference(bb, relax.op.tile(x4, (5, 4, 3, 2)), relax.TensorType(dtype="", ndim=4)) + _check_inference(bb, relax.op.tile(x5, (5, 4, 3, 2)), relax.TensorType(dtype="")) -def test_tile_infer_struct_info_shape_symbolic(): +def test_tile_infer_ty_shape_symbolic(): bb = relax.BlockBuilder() a = tirx.Var("a", "int64") b = tirx.Var("b", "int64") c = tirx.Var("c", "int64") x = relax.Var("x", R.Tensor((a, b, c), "float32")) - _check_inference(bb, relax.op.tile(x, 2), relax.TensorStructInfo((a, b, c * 2), "float32")) - _check_inference( - bb, relax.op.tile(x, (3, 2)), relax.TensorStructInfo((a, b * 3, c * 2), "float32") - ) + _check_inference(bb, relax.op.tile(x, 2), relax.TensorType((a, b, c * 2), "float32")) + _check_inference(bb, relax.op.tile(x, (3, 2)), relax.TensorType((a, b * 3, c * 2), "float32")) _check_inference( - bb, relax.op.tile(x, (4, 3, 2)), relax.TensorStructInfo((a * 4, b * 3, c * 2), "float32") + bb, relax.op.tile(x, (4, 3, 2)), relax.TensorType((a * 4, b * 3, c * 2), "float32") ) _check_inference( bb, relax.op.tile(x, (5, 4, 3, 2)), - relax.TensorStructInfo((5, a * 4, b * 3, c * 2), "float32"), + relax.TensorType((5, a * 4, b * 3, c * 2), "float32"), ) -def test_tile_infer_struct_info_more_input_dtype(): +def test_tile_infer_ty_more_input_dtype(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3, 4), "float16")) x1 = relax.Var("x", R.Tensor((2, 3, 4), "int8")) - _check_inference(bb, relax.op.tile(x0, (3, 2)), relax.TensorStructInfo((2, 9, 8), "float16")) - _check_inference(bb, relax.op.tile(x1, (3, 2)), relax.TensorStructInfo((2, 9, 8), "int8")) + _check_inference(bb, relax.op.tile(x0, (3, 2)), relax.TensorType((2, 9, 8), "float16")) + _check_inference(bb, relax.op.tile(x1, (3, 2)), relax.TensorType((2, 9, 8), "int8")) -def test_tile_return_data_sinfo(): +def test_tile_return_data_ty(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 10, 4), "float32")) x1 = relax.Var("x", R.Tensor("float32", ndim=3)) x2 = relax.Var("x", R.Tensor("float32")) - _check_inference(bb, relax.op.tile(x0, 1), x0.struct_info) - _check_inference(bb, relax.op.tile(x0, (1, 1)), x0.struct_info) - _check_inference(bb, relax.op.tile(x0, (1, 1, 1)), x0.struct_info) - _check_inference(bb, relax.op.tile(x1, 1), x1.struct_info) - _check_inference(bb, relax.op.tile(x2, 1), x2.struct_info) + _check_inference(bb, relax.op.tile(x0, 1), x0.ty) + _check_inference(bb, relax.op.tile(x0, (1, 1)), x0.ty) + _check_inference(bb, relax.op.tile(x0, (1, 1, 1)), x0.ty) + _check_inference(bb, relax.op.tile(x1, 1), x1.ty) + _check_inference(bb, relax.op.tile(x2, 1), x2.ty) -def test_tile_infer_struct_info_wrong_input_type(): +def test_tile_infer_ty_wrong_input_type(): bb = relax.BlockBuilder() - x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 4, 5))) - x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 4, 5), "float32"))) + x0 = relax.Var("x", relax.ShapeType((2, 3, 4, 5))) + x1 = relax.Var("x", relax.FuncType([], R.Tensor((2, 3, 4, 5), "float32"))) x2 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) r1 = tirx.Var("a", "float32") r2 = tirx.StringImm("abc") @@ -3176,7 +2984,7 @@ def test_tile_infer_struct_info_wrong_input_type(): bb.normalize(relax.op.tile(x2, r2)) -def test_flip_infer_struct_info(): +def test_flip_infer_ty(): bb = relax.BlockBuilder() vdev0 = VDevice("llvm") x0 = relax.Var("x", R.Tensor((2, 10, 4), "float32")) @@ -3186,26 +2994,24 @@ def test_flip_infer_struct_info(): x4 = relax.Var("x", R.Tensor(ndim=3)) x5 = relax.Var("x", R.Tensor((2, 10, 4), "float32", vdev0)) - _check_inference(bb, relax.op.flip(x0, axis=1), relax.TensorStructInfo((2, 10, 4), "float32")) - _check_inference( - bb, relax.op.flip(x5, axis=1), relax.TensorStructInfo((2, 10, 4), "float32", vdev0) - ) + _check_inference(bb, relax.op.flip(x0, axis=1), relax.TensorType((2, 10, 4), "float32")) + _check_inference(bb, relax.op.flip(x5, axis=1), relax.TensorType((2, 10, 4), "float32", vdev0)) _check_inference(bb, relax.op.flip(x1, axis=0), R.Tensor("float16", ndim=3)) _check_inference(bb, relax.op.flip(x2, axis=0), R.Tensor("int32")) _check_inference(bb, relax.op.flip(x3, axis=2), R.Tensor((2, 10, 4))) _check_inference(bb, relax.op.flip(x4, axis=2), R.Tensor(ndim=3)) -def test_flip_infer_struct_info_shape_symbolic(): +def test_flip_infer_ty_shape_symbolic(): bb = relax.BlockBuilder() a = tirx.Var("a", "int64") b = tirx.Var("b", "int64") x = relax.Var("x", R.Tensor((a, b), "float32")) - _check_inference(bb, relax.op.flip(x, axis=0), relax.TensorStructInfo((a, b), "float32")) + _check_inference(bb, relax.op.flip(x, axis=0), relax.TensorType((a, b), "float32")) -def test_flip_infer_struct_info_wrong_inputs(): +def test_flip_infer_ty_wrong_inputs(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 10, 4), "float32")) @@ -3213,7 +3019,7 @@ def test_flip_infer_struct_info_wrong_inputs(): bb.normalize(relax.op.flip(x0, axis=3)) -def test_gather_elements_infer_struct_info(): +def test_gather_elements_infer_ty(): bb = relax.BlockBuilder() vdev0 = VDevice("llvm") x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32")) @@ -3227,39 +3033,39 @@ def test_gather_elements_infer_struct_info(): i4 = relax.Var("i", R.Tensor((2, 3, 4), "int64", vdev0)) _check_inference( - bb, relax.op.gather_elements(x0, i0, axis=1), relax.TensorStructInfo((2, 3, 4), "float32") + bb, relax.op.gather_elements(x0, i0, axis=1), relax.TensorType((2, 3, 4), "float32") ) _check_inference( bb, relax.op.gather_elements(x3, i4, axis=1), - relax.TensorStructInfo((2, 3, 4), "float32", vdev0), + relax.TensorType((2, 3, 4), "float32", vdev0), ) _check_inference( bb, relax.op.gather_elements(x1, i0, axis=1), - relax.TensorStructInfo((2, 3, 4), dtype="float32"), + relax.TensorType((2, 3, 4), dtype="float32"), ) _check_inference( bb, relax.op.gather_elements(x2, i0, axis=0), - relax.TensorStructInfo(dtype="float32", ndim=-1), + relax.TensorType(dtype="float32", ndim=-1), ) _check_inference( - bb, relax.op.gather_elements(x0, i1, axis=1), relax.TensorStructInfo((2, 3, 4), "float32") + bb, relax.op.gather_elements(x0, i1, axis=1), relax.TensorType((2, 3, 4), "float32") ) _check_inference( bb, relax.op.gather_elements(x1, i2, axis=1), - relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorType(dtype="float32", ndim=3), ) _check_inference( bb, relax.op.gather_elements(x2, i3, axis=0), - relax.TensorStructInfo(dtype="float32", ndim=-1), + relax.TensorType(dtype="float32", ndim=-1), ) -def test_gather_elements_infer_struct_info_shape_symbolic(): +def test_gather_elements_infer_ty_shape_symbolic(): bb = relax.BlockBuilder() a = tirx.Var("a", "int64") b = tirx.Var("b", "int64") @@ -3267,11 +3073,11 @@ def test_gather_elements_infer_struct_info_shape_symbolic(): i = relax.Var("i", R.Tensor((a, b), "int64")) _check_inference( - bb, relax.op.gather_elements(x, i, axis=1), relax.TensorStructInfo((a, b), "float32") + bb, relax.op.gather_elements(x, i, axis=1), relax.TensorType((a, b), "float32") ) -def test_gather_elements_infer_struct_info_wrong_inputs(): +def test_gather_elements_infer_ty_wrong_inputs(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32")) x1 = relax.Var("x", R.Tensor((2, 3), "float32")) @@ -3289,7 +3095,7 @@ def test_gather_elements_infer_struct_info_wrong_inputs(): bb.normalize(relax.op.gather_elements(x0, i2)) -def test_gather_nd_infer_struct_info(): +def test_gather_nd_infer_ty(): bb = relax.BlockBuilder() vdev0 = VDevice("llvm") x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32")) @@ -3302,22 +3108,16 @@ def test_gather_nd_infer_struct_info(): i3 = relax.Var("i", R.Tensor(ndim=2)) i4 = relax.Var("i", R.Tensor((2, 2), "int64", vdev0)) - _check_inference(bb, relax.op.gather_nd(x0, i0), relax.TensorStructInfo((2, 4), "float32")) - _check_inference( - bb, relax.op.gather_nd(x3, i4), relax.TensorStructInfo((2, 4), "float32", vdev0) - ) - _check_inference( - bb, relax.op.gather_nd(x1, i0), relax.TensorStructInfo(dtype="float32", ndim=2) - ) - _check_inference( - bb, relax.op.gather_nd(x2, i0), relax.TensorStructInfo(dtype="float32", ndim=-1) - ) - _check_inference(bb, relax.op.gather_nd(x0, i1), relax.TensorStructInfo((2, 4), "float32")) - _check_inference(bb, relax.op.gather_nd(x1, i2), relax.TensorStructInfo(dtype="float32")) - _check_inference(bb, relax.op.gather_nd(x2, i3), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.gather_nd(x0, i0), relax.TensorType((2, 4), "float32")) + _check_inference(bb, relax.op.gather_nd(x3, i4), relax.TensorType((2, 4), "float32", vdev0)) + _check_inference(bb, relax.op.gather_nd(x1, i0), relax.TensorType(dtype="float32", ndim=2)) + _check_inference(bb, relax.op.gather_nd(x2, i0), relax.TensorType(dtype="float32", ndim=-1)) + _check_inference(bb, relax.op.gather_nd(x0, i1), relax.TensorType((2, 4), "float32")) + _check_inference(bb, relax.op.gather_nd(x1, i2), relax.TensorType(dtype="float32")) + _check_inference(bb, relax.op.gather_nd(x2, i3), relax.TensorType(dtype="float32")) -def test_gather_nd_infer_struct_info_shape_symbolic(): +def test_gather_nd_infer_ty_shape_symbolic(): bb = relax.BlockBuilder() a = tirx.Var("a", "int64") b = tirx.Var("b", "int64") @@ -3325,10 +3125,10 @@ def test_gather_nd_infer_struct_info_shape_symbolic(): x = relax.Var("x", R.Tensor((a, b, c), "float32")) i = relax.Var("i", R.Tensor((2, 2), "int64")) - _check_inference(bb, relax.op.gather_nd(x, i), relax.TensorStructInfo((2, c), "float32")) + _check_inference(bb, relax.op.gather_nd(x, i), relax.TensorType((2, c), "float32")) -def test_gather_nd_infer_struct_info_wrong_inputs(): +def test_gather_nd_infer_ty_wrong_inputs(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32")) i0 = relax.Var("i", R.Tensor((2, 4), "int64")) # indices too long @@ -3340,7 +3140,7 @@ def test_gather_nd_infer_struct_info_wrong_inputs(): bb.normalize(relax.op.gather_nd(x0, i1)) -def test_scatter_elements_infer_struct_info(): +def test_scatter_elements_infer_ty(): bb = relax.BlockBuilder() vdev0 = VDevice("llvm") d0 = relax.Var("data", R.Tensor((4, 4), "float32")) @@ -3357,85 +3157,85 @@ def test_scatter_elements_infer_struct_info(): _check_inference( bb, relax.op.scatter_elements(d0, i0, u0, 0, "updates"), - relax.TensorStructInfo((4, 4), dtype="float32"), + relax.TensorType((4, 4), dtype="float32"), ) _check_inference( bb, relax.op.scatter_elements(d3, i4, u1, 0, "updates"), - relax.TensorStructInfo((4, 4), dtype="float32", vdevice=vdev0), + relax.TensorType((4, 4), dtype="float32", vdevice=vdev0), ) _check_inference( bb, relax.op.scatter_elements(d1, i0, u0, 0, "updates"), - relax.TensorStructInfo(dtype="float32", ndim=2), + relax.TensorType(dtype="float32", ndim=2), ) _check_inference( bb, relax.op.scatter_elements(d2, i0, u0, 0, "updates"), - relax.TensorStructInfo(dtype="float32", ndim=-1), + relax.TensorType(dtype="float32", ndim=-1), ) _check_inference( bb, relax.op.scatter_elements(d0, i1, u0, 0, "updates"), - relax.TensorStructInfo((4, 4), dtype="float32"), + relax.TensorType((4, 4), dtype="float32"), ) _check_inference( bb, relax.op.scatter_elements(d1, i1, u0, 0, "updates"), - relax.TensorStructInfo(dtype="float32", ndim=2), + relax.TensorType(dtype="float32", ndim=2), ) _check_inference( bb, relax.op.scatter_elements(d2, i1, u0, 0, "updates"), - relax.TensorStructInfo(dtype="float32", ndim=-1), + relax.TensorType(dtype="float32", ndim=-1), ) _check_inference( bb, relax.op.scatter_elements(d0, i2, u0, 0, "updates"), - relax.TensorStructInfo((4, 4), dtype="float32"), + relax.TensorType((4, 4), dtype="float32"), ) _check_inference( bb, relax.op.scatter_elements(d1, i2, u0, 0, "updates"), - relax.TensorStructInfo(dtype="float32", ndim=2), + relax.TensorType(dtype="float32", ndim=2), ) _check_inference( bb, relax.op.scatter_elements(d2, i2, u0, 0, "updates"), - relax.TensorStructInfo(dtype="float32", ndim=-1), + relax.TensorType(dtype="float32", ndim=-1), ) _check_inference( bb, relax.op.scatter_elements(d0, i3, u0, 0, "updates"), - relax.TensorStructInfo((4, 4), dtype="float32"), + relax.TensorType((4, 4), dtype="float32"), ) _check_inference( bb, relax.op.scatter_elements(d1, i3, u0, 0, "updates"), - relax.TensorStructInfo(dtype="float32", ndim=2), + relax.TensorType(dtype="float32", ndim=2), ) _check_inference( bb, relax.op.scatter_elements(d2, i3, u0, 0, "updates"), - relax.TensorStructInfo(dtype="float32", ndim=-1), + relax.TensorType(dtype="float32", ndim=-1), ) # Test with unknown dtype for data d_unknown = relax.Var("data", R.Tensor((4, 4))) _check_inference( bb, relax.op.scatter_elements(d_unknown, i0, u0, 0, "updates"), - relax.TensorStructInfo((4, 4), dtype=""), + relax.TensorType((4, 4), dtype=""), ) # Test with unknown dtype for updates u_unknown = relax.Var("updates", R.Tensor((2, 2))) _check_inference( bb, relax.op.scatter_elements(d0, i0, u_unknown, 0, "updates"), - relax.TensorStructInfo((4, 4), dtype="float32"), + relax.TensorType((4, 4), dtype="float32"), ) -def test_scatter_elements_infer_struct_info_symbolic_shape(): +def test_scatter_elements_infer_ty_symbolic_shape(): bb = relax.BlockBuilder() a = tirx.Var("a", "int64") b = tirx.Var("b", "int64") @@ -3452,16 +3252,16 @@ def test_scatter_elements_infer_struct_info_symbolic_shape(): _check_inference( bb, relax.op.scatter_elements(d0, i0, u0, 0, "updates"), - relax.TensorStructInfo((a, b), dtype="float32"), + relax.TensorType((a, b), dtype="float32"), ) _check_inference( bb, relax.op.scatter_elements(d0, i0, u1, 0, "updates"), - relax.TensorStructInfo((a, b), dtype="float32"), + relax.TensorType((a, b), dtype="float32"), ) -def test_scatter_elements_infer_struct_info_wrong_indices_type(): +def test_scatter_elements_infer_ty_wrong_indices_type(): bb = relax.BlockBuilder() d0 = relax.Var("data", R.Tensor((4, 4), "float32")) i0 = relax.Var("indices", R.Tensor((2, 2), "float32")) @@ -3471,7 +3271,7 @@ def test_scatter_elements_infer_struct_info_wrong_indices_type(): bb.normalize(relax.op.scatter_elements(d0, i0, u0)) -def test_scatter_elements_infer_struct_info_rank_shape_mismatch(): +def test_scatter_elements_infer_ty_rank_shape_mismatch(): a = tirx.Var("a", "int64") b = tirx.Var("b", "int64") @@ -3502,7 +3302,7 @@ def test_scatter_elements_infer_struct_info_rank_shape_mismatch(): bb.normalize(relax.op.scatter_elements(d0, i0, u4)) -def test_scatter_nd_infer_struct_info(): +def test_scatter_nd_infer_ty(): bb = relax.BlockBuilder() d0 = relax.Var("data", R.Tensor((8,), "float32")) @@ -3512,7 +3312,7 @@ def test_scatter_nd_infer_struct_info(): _check_inference( bb, relax.op.scatter_nd(d0, i0, u0, "update"), - relax.TensorStructInfo((8,), dtype="float32"), + relax.TensorType((8,), dtype="float32"), ) d1 = relax.Var("data", R.Tensor((4, 4, 4), "float32")) @@ -3522,11 +3322,11 @@ def test_scatter_nd_infer_struct_info(): _check_inference( bb, relax.op.scatter_nd(d1, i1, u1, "update"), - relax.TensorStructInfo((4, 4, 4), dtype="float32"), + relax.TensorType((4, 4, 4), dtype="float32"), ) -def test_meshgrid_infer_struct_info(): +def test_meshgrid_infer_ty(): bb = relax.BlockBuilder() vdev0 = VDevice("llvm") t0 = relax.Var("t0", R.Tensor((3,), "float32")) @@ -3537,18 +3337,16 @@ def test_meshgrid_infer_struct_info(): _check_inference( bb, relax.op.meshgrid((t0, t1), indexing="ij"), - relax.TupleStructInfo( - [relax.TensorStructInfo((3, 4), "float32"), relax.TensorStructInfo((3, 4), "float32")] - ), + relax.TupleType([relax.TensorType((3, 4), "float32"), relax.TensorType((3, 4), "float32")]), ) _check_inference( bb, relax.op.meshgrid((t3, t1), indexing="ij"), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo((5, 4), "float32", vdev0), - relax.TensorStructInfo((5, 4), "float32", vdev0), + relax.TensorType((5, 4), "float32", vdev0), + relax.TensorType((5, 4), "float32", vdev0), ] ), ) @@ -3556,10 +3354,10 @@ def test_meshgrid_infer_struct_info(): _check_inference( bb, relax.op.meshgrid((t2, t1), indexing="xy"), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="float32", ndim=2), - relax.TensorStructInfo(dtype="float32", ndim=2), + relax.TensorType(dtype="float32", ndim=2), + relax.TensorType(dtype="float32", ndim=2), ] ), ) @@ -3567,11 +3365,11 @@ def test_meshgrid_infer_struct_info(): _check_inference( bb, relax.op.meshgrid((t0,), indexing="ij"), - relax.TupleStructInfo([relax.TensorStructInfo((3,), "float32")]), + relax.TupleType([relax.TensorType((3,), "float32")]), ) -def test_one_hot_infer_struct_info(): +def test_one_hot_infer_ty(): bb = relax.BlockBuilder() # Test case 1: Basic usage @@ -3579,7 +3377,7 @@ def test_one_hot_infer_struct_info(): _check_inference( bb, relax.op.one_hot(i0, relax.PrimValue(1.0), relax.PrimValue(0.0), 5), - relax.TensorStructInfo((3, 5), "float32"), + relax.TensorType((3, 5), "float32"), ) # Test case 2: With specified axis @@ -3587,7 +3385,7 @@ def test_one_hot_infer_struct_info(): _check_inference( bb, relax.op.one_hot(i1, relax.PrimValue(1), relax.PrimValue(0), 3, axis=1), - relax.TensorStructInfo((2, 3, 2), "int64"), + relax.TensorType((2, 3, 2), "int64"), ) # Test case 3: With symbolic shape @@ -3596,7 +3394,7 @@ def test_one_hot_infer_struct_info(): _check_inference( bb, relax.op.one_hot(i2, relax.PrimValue(1.0), relax.PrimValue(0.0), 4), - relax.TensorStructInfo((n, 4), "float32"), + relax.TensorType((n, 4), "float32"), ) # Test case 4: With unknown shape @@ -3604,7 +3402,7 @@ def test_one_hot_infer_struct_info(): _check_inference( bb, relax.op.one_hot(i3, relax.PrimValue(1.0), relax.PrimValue(0.0), 6), - relax.TensorStructInfo(dtype="float32"), + relax.TensorType(dtype="float32"), ) # Test case 5: With different on_value and off_value dtypes diff --git a/tests/python/relax/test_op_misc.py b/tests/python/relax/test_op_misc.py index baa63797481c..db2cf3a66249 100644 --- a/tests/python/relax/test_op_misc.py +++ b/tests/python/relax/test_op_misc.py @@ -129,35 +129,35 @@ def _check_call(expr, op_name: str): def test_vm_alloc_tensor(): bb = rx.BlockBuilder() - storage = rx.Var("storage", rx.TensorStructInfo(dtype="float32")) + storage = rx.Var("storage", rx.TensorType(dtype="float32")) alloc = rx.op.vm.alloc_tensor(storage, offset=0, shape=rx.ShapeExpr([4, 5]), dtype="float32") alloc = bb.normalize(alloc) - tvm.ir.assert_structural_equal(alloc.struct_info, R.Tensor([4, 5], "float32")) + tvm.ir.assert_structural_equal(alloc.ty, R.Tensor([4, 5], "float32")) -def test_vm_alloc_tensor_infer_struct_info(): +def test_vm_alloc_tensor_infer_ty(): bb = rx.BlockBuilder() s1 = rx.Var("s", R.Shape(ndim=3)) - storage = rx.Var("storage", rx.TensorStructInfo(dtype="float32")) + storage = rx.Var("storage", rx.TensorType(dtype="float32")) alloc = rx.op.vm.alloc_tensor(storage, offset=0, shape=s1, dtype="float32") ret = bb.normalize(alloc) - tvm.ir.assert_structural_equal(ret.struct_info, R.Tensor(dtype="float32", ndim=3)) + tvm.ir.assert_structural_equal(ret.ty, R.Tensor(dtype="float32", ndim=3)) def test_vm_kill_object(): bb = rx.BlockBuilder() - storage = rx.Var("storage", rx.TensorStructInfo(dtype="float32")) + storage = rx.Var("storage", rx.TensorType(dtype="float32")) kill = rx.op.vm.kill_object(storage) ret = bb.normalize(kill) - tvm.ir.assert_structural_equal(ret.struct_info, R.Tuple([])) + tvm.ir.assert_structural_equal(ret.ty, R.Tuple([])) def test_builtin_stop_lift_params(): bb = rx.BlockBuilder() - x = rx.Var("x", rx.TensorStructInfo(shape=[4, 5], dtype="float32")) + x = rx.Var("x", rx.TensorType(shape=[4, 5], dtype="float32")) x1 = rx.op.builtin.stop_lift_params(x) x1 = bb.normalize(x1) - tvm.ir.assert_structural_equal(x1.struct_info, R.Tensor([4, 5], "float32")) + tvm.ir.assert_structural_equal(x1.ty, R.Tensor([4, 5], "float32")) if __name__ == "__main__": diff --git a/tests/python/relax/test_op_nn.py b/tests/python/relax/test_op_nn.py index f362c6bb5f77..af4d2c25fc5e 100644 --- a/tests/python/relax/test_op_nn.py +++ b/tests/python/relax/test_op_nn.py @@ -57,12 +57,12 @@ def test_op_correctness(): ) -def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): +def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_ty: relax.Type): ret = bb.normalize(call) - tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) + tvm.ir.assert_structural_equal(ret.ty, expected_ty) -def test_linear_unit_infer_struct_info(): +def test_linear_unit_infer_ty(): bb = relax.BlockBuilder() vdev0 = VDevice("llvm") x0 = relax.Var("x", R.Tensor((2, 3), "float32")) @@ -73,61 +73,61 @@ def test_linear_unit_infer_struct_info(): x5 = relax.Var("x", R.Tensor((3, 4))) x6 = relax.Var("x", R.Tensor((2, 3), "float32", vdev0)) - _check_inference(bb, relax.op.nn.relu(x0), relax.TensorStructInfo((2, 3), "float32")) - _check_inference(bb, relax.op.nn.relu(x6), relax.TensorStructInfo((2, 3), "float32", vdev0)) - _check_inference(bb, relax.op.nn.relu6(x0), relax.TensorStructInfo((2, 3), "float32")) - _check_inference(bb, relax.op.nn.relu6(x6), relax.TensorStructInfo((2, 3), "float32", vdev0)) - _check_inference(bb, relax.op.nn.silu(x1), relax.TensorStructInfo(dtype="float32", ndim=3)) - _check_inference(bb, relax.op.nn.gelu(x2), relax.TensorStructInfo(dtype="float32")) - _check_inference(bb, relax.op.nn.relu(x3), relax.TensorStructInfo((2, 3), dtype="")) - _check_inference(bb, relax.op.nn.relu6(x3), relax.TensorStructInfo((2, 3), dtype="")) - _check_inference(bb, relax.op.nn.gelu(x4), relax.TensorStructInfo(dtype="")) - _check_inference(bb, relax.op.nn.leakyrelu(x0), relax.TensorStructInfo((2, 3), "float32")) - _check_inference(bb, relax.op.nn.leakyrelu(x5), relax.TensorStructInfo((3, 4), dtype="")) - _check_inference(bb, relax.op.nn.softplus(x0), relax.TensorStructInfo((2, 3), "float32")) - _check_inference(bb, relax.op.nn.softplus(x5), relax.TensorStructInfo((3, 4), dtype="")) + _check_inference(bb, relax.op.nn.relu(x0), relax.TensorType((2, 3), "float32")) + _check_inference(bb, relax.op.nn.relu(x6), relax.TensorType((2, 3), "float32", vdev0)) + _check_inference(bb, relax.op.nn.relu6(x0), relax.TensorType((2, 3), "float32")) + _check_inference(bb, relax.op.nn.relu6(x6), relax.TensorType((2, 3), "float32", vdev0)) + _check_inference(bb, relax.op.nn.silu(x1), relax.TensorType(dtype="float32", ndim=3)) + _check_inference(bb, relax.op.nn.gelu(x2), relax.TensorType(dtype="float32")) + _check_inference(bb, relax.op.nn.relu(x3), relax.TensorType((2, 3), dtype="")) + _check_inference(bb, relax.op.nn.relu6(x3), relax.TensorType((2, 3), dtype="")) + _check_inference(bb, relax.op.nn.gelu(x4), relax.TensorType(dtype="")) + _check_inference(bb, relax.op.nn.leakyrelu(x0), relax.TensorType((2, 3), "float32")) + _check_inference(bb, relax.op.nn.leakyrelu(x5), relax.TensorType((3, 4), dtype="")) + _check_inference(bb, relax.op.nn.softplus(x0), relax.TensorType((2, 3), "float32")) + _check_inference(bb, relax.op.nn.softplus(x5), relax.TensorType((3, 4), dtype="")) -def test_linear_unit_infer_struct_info_shape_symbolic(): +def test_linear_unit_infer_ty_shape_symbolic(): bb = relax.BlockBuilder() m = tirx.Var("m", "int64") n = tirx.Var("n", "int64") x0 = relax.Var("x", R.Tensor((m, n), "float32")) x1 = relax.Var("x", R.Tensor((4, n), "float32")) - _check_inference(bb, relax.op.nn.silu(x0), relax.TensorStructInfo((m, n), "float32")) - _check_inference(bb, relax.op.nn.relu(x1), relax.TensorStructInfo((4, n), "float32")) - _check_inference(bb, relax.op.nn.relu6(x1), relax.TensorStructInfo((4, n), "float32")) - _check_inference(bb, relax.op.nn.leakyrelu(x1), relax.TensorStructInfo((4, n), "float32")) - _check_inference(bb, relax.op.nn.softplus(x1), relax.TensorStructInfo((4, n), "float32")) + _check_inference(bb, relax.op.nn.silu(x0), relax.TensorType((m, n), "float32")) + _check_inference(bb, relax.op.nn.relu(x1), relax.TensorType((4, n), "float32")) + _check_inference(bb, relax.op.nn.relu6(x1), relax.TensorType((4, n), "float32")) + _check_inference(bb, relax.op.nn.leakyrelu(x1), relax.TensorType((4, n), "float32")) + _check_inference(bb, relax.op.nn.softplus(x1), relax.TensorType((4, n), "float32")) -def test_linear_unit_infer_struct_info_shape_var(): +def test_linear_unit_infer_ty_shape_var(): bb = relax.BlockBuilder() - s0 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) - s1 = relax.Var("s", relax.ShapeStructInfo()) - x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) - x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + s0 = relax.Var("s", relax.ShapeType(ndim=2)) + s1 = relax.Var("s", relax.ShapeType()) + x0 = relax.Var("x", relax.TensorType(s0, "float32")) + x1 = relax.Var("x", relax.TensorType(s1, "float32")) - _check_inference(bb, relax.op.nn.gelu(x0), relax.TensorStructInfo(s0, "float32")) - _check_inference(bb, relax.op.nn.relu(x1), relax.TensorStructInfo(s1, "float32")) - _check_inference(bb, relax.op.nn.relu6(x1), relax.TensorStructInfo(s1, "float32")) - _check_inference(bb, relax.op.nn.leakyrelu(x1), relax.TensorStructInfo(s1, "float32")) - _check_inference(bb, relax.op.nn.softplus(x1), relax.TensorStructInfo(s1, "float32")) + _check_inference(bb, relax.op.nn.gelu(x0), relax.TensorType(s0, "float32")) + _check_inference(bb, relax.op.nn.relu(x1), relax.TensorType(s1, "float32")) + _check_inference(bb, relax.op.nn.relu6(x1), relax.TensorType(s1, "float32")) + _check_inference(bb, relax.op.nn.leakyrelu(x1), relax.TensorType(s1, "float32")) + _check_inference(bb, relax.op.nn.softplus(x1), relax.TensorType(s1, "float32")) -def test_linear_unit_infer_struct_info_more_input_dtype(): +def test_linear_unit_infer_ty_more_input_dtype(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3), "float64")) x1 = relax.Var("x", R.Tensor((2, 3), "int8")) x2 = relax.Var("x", R.Tensor((2, 3), "int64")) - _check_inference(bb, relax.op.nn.relu(x0), relax.TensorStructInfo((2, 3), "float64")) - _check_inference(bb, relax.op.nn.relu(x1), relax.TensorStructInfo((2, 3), "int8")) - _check_inference(bb, relax.op.nn.relu(x2), relax.TensorStructInfo((2, 3), "int64")) + _check_inference(bb, relax.op.nn.relu(x0), relax.TensorType((2, 3), "float64")) + _check_inference(bb, relax.op.nn.relu(x1), relax.TensorType((2, 3), "int8")) + _check_inference(bb, relax.op.nn.relu(x2), relax.TensorType((2, 3), "int64")) -def test_linear_unit_infer_struct_info_invalid_input_dtype(): +def test_linear_unit_infer_ty_invalid_input_dtype(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3), "int8")) x1 = relax.Var("x", R.Tensor((2, 3), "int64")) @@ -138,10 +138,10 @@ def test_linear_unit_infer_struct_info_invalid_input_dtype(): bb.normalize(relax.op.nn.silu(x1)) -def test_linear_unit_infer_struct_info_wrong_input_type(): +def test_linear_unit_infer_ty_wrong_input_type(): bb = relax.BlockBuilder() - x0 = relax.Var("x", relax.ShapeStructInfo((2, 3))) - x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3), "float32"))) + x0 = relax.Var("x", relax.ShapeType((2, 3))) + x1 = relax.Var("x", relax.FuncType([], R.Tensor((2, 3), "float32"))) with pytest.raises(TypeError): bb.normalize(relax.op.nn.gelu(x0)) @@ -149,7 +149,7 @@ def test_linear_unit_infer_struct_info_wrong_input_type(): bb.normalize(relax.op.nn.silu(x1)) -def test_softmax_log_softmax_infer_struct_info(): +def test_softmax_log_softmax_infer_ty(): bb = relax.BlockBuilder() vdev0 = VDevice("llvm") x0 = relax.Var("x", R.Tensor((2, 3), "float32")) @@ -160,75 +160,65 @@ def test_softmax_log_softmax_infer_struct_info(): x5 = relax.Var("x", R.Tensor((2, 3), "float32", vdev0)) x6 = relax.Var("x", R.Tensor((2, 3), "bfloat16")) - _check_inference(bb, relax.op.nn.softmax(x0), relax.TensorStructInfo((2, 3), "float32")) - _check_inference(bb, relax.op.nn.softmax(x5), relax.TensorStructInfo((2, 3), "float32", vdev0)) - _check_inference( - bb, relax.op.nn.softmax(x1, axis=0), relax.TensorStructInfo(dtype="float32", ndim=3) - ) - _check_inference(bb, relax.op.nn.softmax(x2, axis=1), relax.TensorStructInfo(dtype="float32")) - _check_inference(bb, relax.op.nn.softmax(x3, axis=-1), relax.TensorStructInfo((2, 3), dtype="")) - _check_inference(bb, relax.op.nn.softmax(x4, axis=-2), relax.TensorStructInfo(dtype="")) + _check_inference(bb, relax.op.nn.softmax(x0), relax.TensorType((2, 3), "float32")) + _check_inference(bb, relax.op.nn.softmax(x5), relax.TensorType((2, 3), "float32", vdev0)) + _check_inference(bb, relax.op.nn.softmax(x1, axis=0), relax.TensorType(dtype="float32", ndim=3)) + _check_inference(bb, relax.op.nn.softmax(x2, axis=1), relax.TensorType(dtype="float32")) + _check_inference(bb, relax.op.nn.softmax(x3, axis=-1), relax.TensorType((2, 3), dtype="")) + _check_inference(bb, relax.op.nn.softmax(x4, axis=-2), relax.TensorType(dtype="")) - _check_inference(bb, relax.op.nn.log_softmax(x0), relax.TensorStructInfo((2, 3), "float32")) - _check_inference( - bb, relax.op.nn.log_softmax(x1, axis=0), relax.TensorStructInfo(dtype="float32", ndim=3) - ) - _check_inference( - bb, relax.op.nn.log_softmax(x2, axis=1), relax.TensorStructInfo(dtype="float32") - ) - _check_inference( - bb, relax.op.nn.log_softmax(x3, axis=-1), relax.TensorStructInfo((2, 3), dtype="") - ) - _check_inference(bb, relax.op.nn.log_softmax(x4, axis=-2), relax.TensorStructInfo(dtype="")) - _check_inference(bb, relax.op.nn.softmax(x6), relax.TensorStructInfo((2, 3), dtype="bfloat16")) + _check_inference(bb, relax.op.nn.log_softmax(x0), relax.TensorType((2, 3), "float32")) _check_inference( - bb, relax.op.nn.log_softmax(x6), relax.TensorStructInfo((2, 3), dtype="bfloat16") + bb, relax.op.nn.log_softmax(x1, axis=0), relax.TensorType(dtype="float32", ndim=3) ) + _check_inference(bb, relax.op.nn.log_softmax(x2, axis=1), relax.TensorType(dtype="float32")) + _check_inference(bb, relax.op.nn.log_softmax(x3, axis=-1), relax.TensorType((2, 3), dtype="")) + _check_inference(bb, relax.op.nn.log_softmax(x4, axis=-2), relax.TensorType(dtype="")) + _check_inference(bb, relax.op.nn.softmax(x6), relax.TensorType((2, 3), dtype="bfloat16")) + _check_inference(bb, relax.op.nn.log_softmax(x6), relax.TensorType((2, 3), dtype="bfloat16")) -def test_softmax_log_softmax_infer_struct_info_shape_symbolic(): +def test_softmax_log_softmax_infer_ty_shape_symbolic(): bb = relax.BlockBuilder() m = tirx.Var("m", "int64") n = tirx.Var("n", "int64") x0 = relax.Var("x", R.Tensor((m, n), "float32")) x1 = relax.Var("x", R.Tensor((4, n), "float32")) - _check_inference(bb, relax.op.nn.softmax(x0), relax.TensorStructInfo((m, n), "float32")) - _check_inference(bb, relax.op.nn.softmax(x1, axis=0), relax.TensorStructInfo((4, n), "float32")) + _check_inference(bb, relax.op.nn.softmax(x0), relax.TensorType((m, n), "float32")) + _check_inference(bb, relax.op.nn.softmax(x1, axis=0), relax.TensorType((4, n), "float32")) - _check_inference(bb, relax.op.nn.log_softmax(x0), relax.TensorStructInfo((m, n), "float32")) - _check_inference( - bb, relax.op.nn.log_softmax(x1, axis=0), relax.TensorStructInfo((4, n), "float32") - ) + _check_inference(bb, relax.op.nn.log_softmax(x0), relax.TensorType((m, n), "float32")) + _check_inference(bb, relax.op.nn.log_softmax(x1, axis=0), relax.TensorType((4, n), "float32")) -def test_softmax_log_softmax_infer_struct_info_shape_var(): +def test_softmax_log_softmax_infer_ty_shape_var(): bb = relax.BlockBuilder() - s0 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) - s1 = relax.Var("s", relax.ShapeStructInfo()) - x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) - x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + s0 = relax.Var("s", relax.ShapeType(ndim=2)) + s1 = relax.Var("s", relax.ShapeType()) + x0 = relax.Var("x", relax.TensorType(s0, "float32")) + x1 = relax.Var("x", relax.TensorType(s1, "float32")) - _check_inference(bb, relax.op.nn.softmax(x0), relax.TensorStructInfo(s0, "float32")) - _check_inference(bb, relax.op.nn.softmax(x1), relax.TensorStructInfo(s1, "float32")) + _check_inference(bb, relax.op.nn.softmax(x0), relax.TensorType(s0, "float32")) + _check_inference(bb, relax.op.nn.softmax(x1), relax.TensorType(s1, "float32")) - _check_inference(bb, relax.op.nn.log_softmax(x0), relax.TensorStructInfo(s0, "float32")) - _check_inference(bb, relax.op.nn.log_softmax(x1), relax.TensorStructInfo(s1, "float32")) + _check_inference(bb, relax.op.nn.log_softmax(x0), relax.TensorType(s0, "float32")) + _check_inference(bb, relax.op.nn.log_softmax(x1), relax.TensorType(s1, "float32")) -def test_softmax_log_softmax_infer_struct_info_more_input_dtype(): +def test_softmax_log_softmax_infer_ty_more_input_dtype(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3), "float16")) x1 = relax.Var("x", R.Tensor((2, 3), "float64")) - _check_inference(bb, relax.op.nn.softmax(x0), relax.TensorStructInfo((2, 3), "float16")) - _check_inference(bb, relax.op.nn.softmax(x1), relax.TensorStructInfo((2, 3), "float64")) + _check_inference(bb, relax.op.nn.softmax(x0), relax.TensorType((2, 3), "float16")) + _check_inference(bb, relax.op.nn.softmax(x1), relax.TensorType((2, 3), "float64")) - _check_inference(bb, relax.op.nn.log_softmax(x0), relax.TensorStructInfo((2, 3), "float16")) - _check_inference(bb, relax.op.nn.log_softmax(x1), relax.TensorStructInfo((2, 3), "float64")) + _check_inference(bb, relax.op.nn.log_softmax(x0), relax.TensorType((2, 3), "float16")) + _check_inference(bb, relax.op.nn.log_softmax(x1), relax.TensorType((2, 3), "float64")) -def test_softmax_log_softmax_infer_struct_info_invalid_input_dtype(): +def test_softmax_log_softmax_infer_ty_invalid_input_dtype(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3), "int8")) x1 = relax.Var("x", R.Tensor((2, 3), "int64")) @@ -243,7 +233,7 @@ def test_softmax_log_softmax_infer_struct_info_invalid_input_dtype(): bb.normalize(relax.op.nn.log_softmax(x1)) -def test_softmax_log_softmax_infer_struct_info_axis_out_of_range(): +def test_softmax_log_softmax_infer_ty_axis_out_of_range(): bb = relax.BlockBuilder() x = relax.Var("x", R.Tensor((2, 3, 4), "float32")) @@ -270,10 +260,10 @@ def test_softmax_log_softmax_wrong_with_multiple_axes(): relax.op.nn.log_softmax(x, axis=[-1, -2, -3]) -def test_softmax_log_softmax_infer_struct_info_wrong_input_type(): +def test_softmax_log_softmax_infer_ty_wrong_input_type(): bb = relax.BlockBuilder() - x0 = relax.Var("x", relax.ShapeStructInfo((2, 3))) - x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3), "float32"))) + x0 = relax.Var("x", relax.ShapeType((2, 3))) + x1 = relax.Var("x", relax.FuncType([], R.Tensor((2, 3), "float32"))) with pytest.raises(TypeError): bb.normalize(relax.op.nn.softmax(x0)) @@ -285,7 +275,7 @@ def test_softmax_log_softmax_infer_struct_info_wrong_input_type(): bb.normalize(relax.op.nn.log_softmax(x1)) -def test_batch_norm_infer_struct_info(): +def test_batch_norm_infer_ty(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) x1 = relax.Var("x", R.Tensor("float32", ndim=4)) @@ -306,105 +296,105 @@ def test_batch_norm_infer_struct_info(): _check_inference( bb, relax.op.nn.batch_norm(x0, gamma0, beta0, moving_mean0, moving_var0, axis=1), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo((2, 3, 28, 28), "float32"), - relax.TensorStructInfo((3,), "float32"), - relax.TensorStructInfo((3,), "float32"), + relax.TensorType((2, 3, 28, 28), "float32"), + relax.TensorType((3,), "float32"), + relax.TensorType((3,), "float32"), ] ), ) _check_inference( bb, relax.op.nn.batch_norm(x0, gamma0, beta0, moving_mean0, moving_var0, axis=-3), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo((2, 3, 28, 28), "float32"), - relax.TensorStructInfo((3,), "float32"), - relax.TensorStructInfo((3,), "float32"), + relax.TensorType((2, 3, 28, 28), "float32"), + relax.TensorType((3,), "float32"), + relax.TensorType((3,), "float32"), ] ), ) _check_inference( bb, relax.op.nn.batch_norm(x1, gamma0, beta0, moving_mean0, moving_var0, axis=1), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="float32", ndim=4), - relax.TensorStructInfo((3,), "float32"), - relax.TensorStructInfo((3,), "float32"), + relax.TensorType(dtype="float32", ndim=4), + relax.TensorType((3,), "float32"), + relax.TensorType((3,), "float32"), ] ), ) _check_inference( bb, relax.op.nn.batch_norm(x0, gamma1, beta0, moving_mean0, moving_var0, axis=1), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo((2, 3, 28, 28), "float32"), - relax.TensorStructInfo((3,), "float32"), - relax.TensorStructInfo((3,), "float32"), + relax.TensorType((2, 3, 28, 28), "float32"), + relax.TensorType((3,), "float32"), + relax.TensorType((3,), "float32"), ] ), ) _check_inference( bb, relax.op.nn.batch_norm(x0, gamma0, beta0, moving_mean0, moving_var1, axis=1), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo((2, 3, 28, 28), "float32"), - relax.TensorStructInfo((3,), "float32"), - relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorType((2, 3, 28, 28), "float32"), + relax.TensorType((3,), "float32"), + relax.TensorType(dtype="float32", ndim=1), ] ), ) _check_inference( bb, relax.op.nn.batch_norm(x1, gamma1, beta0, moving_mean0, moving_var1, axis=1), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="float32", ndim=4), - relax.TensorStructInfo((3,), "float32"), - relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorType(dtype="float32", ndim=4), + relax.TensorType((3,), "float32"), + relax.TensorType(dtype="float32", ndim=1), ] ), ) _check_inference( bb, relax.op.nn.batch_norm(x2, gamma1, beta0, moving_mean0, moving_var1, axis=1, momentum=0.1), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="float32"), - relax.TensorStructInfo((3,), "float32"), - relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorType(dtype="float32"), + relax.TensorType((3,), "float32"), + relax.TensorType(dtype="float32", ndim=1), ] ), ) _check_inference( bb, relax.op.nn.batch_norm(x3, gamma2, beta1, moving_mean1, moving_var2, axis=1, momentum=0.1), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(ndim=4, dtype=""), - relax.TensorStructInfo((3,), dtype=""), - relax.TensorStructInfo(dtype="", ndim=1), + relax.TensorType(ndim=4, dtype=""), + relax.TensorType((3,), dtype=""), + relax.TensorType(dtype="", ndim=1), ] ), ) _check_inference( bb, relax.op.nn.batch_norm(x4, gamma2, beta1, moving_mean1, moving_var2, axis=1, momentum=0.1), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype=""), - relax.TensorStructInfo((3,), dtype=""), - relax.TensorStructInfo(dtype="", ndim=1), + relax.TensorType(dtype=""), + relax.TensorType((3,), dtype=""), + relax.TensorType(dtype="", ndim=1), ] ), ) -def test_batch_norm_infer_struct_info_shape_symbolic(): +def test_batch_norm_infer_ty_shape_symbolic(): bb = relax.BlockBuilder() n = tirx.Var("n", "int64") c0 = tirx.Var("c", "int64") @@ -426,120 +416,120 @@ def test_batch_norm_infer_struct_info_shape_symbolic(): _check_inference( bb, relax.op.nn.batch_norm(x0, gamma0, beta, moving_mean, moving_var0, axis=1), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo((n, c0, h, w), "float32"), - relax.TensorStructInfo((c0,), "float32"), - relax.TensorStructInfo((c0,), "float32"), + relax.TensorType((n, c0, h, w), "float32"), + relax.TensorType((c0,), "float32"), + relax.TensorType((c0,), "float32"), ] ), ) _check_inference( bb, relax.op.nn.batch_norm(x1, gamma0, beta, moving_mean, moving_var0, axis=1), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="float32", ndim=4), - relax.TensorStructInfo(dtype="float32", ndim=1), - relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorType(dtype="float32", ndim=4), + relax.TensorType(dtype="float32", ndim=1), + relax.TensorType(dtype="float32", ndim=1), ] ), ) _check_inference( bb, relax.op.nn.batch_norm(x2, gamma0, beta, moving_mean, moving_var0, axis=1), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="float32", ndim=4), - relax.TensorStructInfo((c0,), "float32"), - relax.TensorStructInfo((c0,), "float32"), + relax.TensorType(dtype="float32", ndim=4), + relax.TensorType((c0,), "float32"), + relax.TensorType((c0,), "float32"), ] ), ) _check_inference( bb, relax.op.nn.batch_norm(x0, gamma1, beta, moving_mean, moving_var0, axis=1), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="float32", ndim=4), - relax.TensorStructInfo(dtype="float32", ndim=1), - relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorType(dtype="float32", ndim=4), + relax.TensorType(dtype="float32", ndim=1), + relax.TensorType(dtype="float32", ndim=1), ] ), ) _check_inference( bb, relax.op.nn.batch_norm(x0, gamma0, beta, moving_mean, moving_var1, axis=1), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="float32", ndim=4), - relax.TensorStructInfo(dtype="float32", ndim=1), - relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorType(dtype="float32", ndim=4), + relax.TensorType(dtype="float32", ndim=1), + relax.TensorType(dtype="float32", ndim=1), ] ), ) _check_inference( bb, relax.op.nn.batch_norm(x0, gamma2, beta, moving_mean, moving_var0, axis=1), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo((n, c0, h, w), "float32"), - relax.TensorStructInfo((c0,), "float32"), - relax.TensorStructInfo((c0,), "float32"), + relax.TensorType((n, c0, h, w), "float32"), + relax.TensorType((c0,), "float32"), + relax.TensorType((c0,), "float32"), ] ), ) _check_inference( bb, relax.op.nn.batch_norm(x0, gamma0, beta, moving_mean, moving_var2, axis=1), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo((n, c0, h, w), "float32"), - relax.TensorStructInfo((c0,), "float32"), - relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorType((n, c0, h, w), "float32"), + relax.TensorType((c0,), "float32"), + relax.TensorType(dtype="float32", ndim=1), ] ), ) -def test_batch_norm_infer_struct_info_shape_var(): +def test_batch_norm_infer_ty_shape_var(): bb = relax.BlockBuilder() - s0 = relax.Var("s0", relax.ShapeStructInfo(ndim=4)) - s1 = relax.Var("s1", relax.ShapeStructInfo()) - s2 = relax.Var("s2", relax.ShapeStructInfo(ndim=1)) - s3 = relax.Var("s3", relax.ShapeStructInfo(ndim=1)) - x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) - x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) - gamma = relax.Var("gamma", relax.TensorStructInfo(s2, "float32")) - beta = relax.Var("beta", relax.TensorStructInfo(s3, "float32")) - moving_mean = relax.Var("moving_mean", relax.TensorStructInfo(s2, "float32")) - moving_var = relax.Var("moving_var", relax.TensorStructInfo(s3, "float32")) + s0 = relax.Var("s0", relax.ShapeType(ndim=4)) + s1 = relax.Var("s1", relax.ShapeType()) + s2 = relax.Var("s2", relax.ShapeType(ndim=1)) + s3 = relax.Var("s3", relax.ShapeType(ndim=1)) + x0 = relax.Var("x", relax.TensorType(s0, "float32")) + x1 = relax.Var("x", relax.TensorType(s1, "float32")) + gamma = relax.Var("gamma", relax.TensorType(s2, "float32")) + beta = relax.Var("beta", relax.TensorType(s3, "float32")) + moving_mean = relax.Var("moving_mean", relax.TensorType(s2, "float32")) + moving_var = relax.Var("moving_var", relax.TensorType(s3, "float32")) _check_inference( bb, relax.op.nn.batch_norm(x0, gamma, beta, moving_mean, moving_var, axis=1), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(s0, "float32"), - relax.TensorStructInfo(s2, "float32"), - relax.TensorStructInfo(s3, "float32"), + relax.TensorType(s0, "float32"), + relax.TensorType(s2, "float32"), + relax.TensorType(s3, "float32"), ] ), ) _check_inference( bb, relax.op.nn.batch_norm(x1, gamma, beta, moving_mean, moving_var, axis=1), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(s1, "float32"), - relax.TensorStructInfo(s2, "float32"), - relax.TensorStructInfo(s3, "float32"), + relax.TensorType(s1, "float32"), + relax.TensorType(s2, "float32"), + relax.TensorType(s3, "float32"), ] ), ) -def test_batch_norm_infer_struct_info_more_input_dtype(): +def test_batch_norm_infer_ty_more_input_dtype(): bb = relax.BlockBuilder() x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float16")) gamma = relax.Var("gamma", R.Tensor((3,), "float16")) @@ -550,17 +540,17 @@ def test_batch_norm_infer_struct_info_more_input_dtype(): _check_inference( bb, relax.op.nn.batch_norm(x, gamma, beta, moving_mean, moving_var, axis=1), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo((2, 3, 28, 28), "float16"), - relax.TensorStructInfo((3,), "float16"), - relax.TensorStructInfo((3,), "float16"), + relax.TensorType((2, 3, 28, 28), "float16"), + relax.TensorType((3,), "float16"), + relax.TensorType((3,), "float16"), ] ), ) -def test_batch_norm_infer_struct_info_invalid_input_dtype(): +def test_batch_norm_infer_ty_invalid_input_dtype(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "int8")) gamma0 = relax.Var("gamma", R.Tensor((3,), "int8")) @@ -579,7 +569,7 @@ def test_batch_norm_infer_struct_info_invalid_input_dtype(): bb.normalize(relax.op.nn.batch_norm(x1, gamma1, beta1, moving_mean1, moving_var1, axis=1)) -def test_batch_norm_infer_struct_info_axis_out_of_range(): +def test_batch_norm_infer_ty_axis_out_of_range(): bb = relax.BlockBuilder() x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) gamma = relax.Var("gamma", R.Tensor((3,), "float32")) @@ -593,7 +583,7 @@ def test_batch_norm_infer_struct_info_axis_out_of_range(): bb.normalize(relax.op.nn.batch_norm(x, gamma, beta, moving_mean, moving_var, axis=-5)) -def test_batch_norm_infer_struct_info_dtype_mismatch(): +def test_batch_norm_infer_ty_dtype_mismatch(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) x1 = relax.Var("x", R.Tensor((2, 3, 28, 28), "int8")) @@ -612,7 +602,7 @@ def test_batch_norm_infer_struct_info_dtype_mismatch(): bb.normalize(relax.op.nn.batch_norm(x0, gamma0, beta, moving_mean, moving_var1, axis=1)) -def test_batch_norm_infer_struct_info_ndim_mismatch(): +def test_batch_norm_infer_ty_ndim_mismatch(): bb = relax.BlockBuilder() x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) gamma0 = relax.Var("gamma", R.Tensor((3,), "float32")) @@ -628,7 +618,7 @@ def test_batch_norm_infer_struct_info_ndim_mismatch(): bb.normalize(relax.op.nn.batch_norm(x, gamma0, beta, moving_mean, moving_var1, axis=1)) -def test_batch_norm_infer_struct_info_shape_mismatch(): +def test_batch_norm_infer_ty_shape_mismatch(): bb = relax.BlockBuilder() c = tirx.Var("c", "int64") x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) @@ -652,12 +642,12 @@ def test_batch_norm_infer_struct_info_shape_mismatch(): bb.normalize(relax.op.nn.batch_norm(x1, gamma2, beta1, moving_mean1, moving_var2, axis=1)) -def test_batch_norm_infer_struct_info_wrong_input_type(): +def test_batch_norm_infer_ty_wrong_input_type(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) - x1 = relax.Var("x", relax.ShapeStructInfo((2, 3, 28, 28))) + x1 = relax.Var("x", relax.ShapeType((2, 3, 28, 28))) gamma0 = relax.Var("gamma", R.Tensor((3,), "float32")) - gamma1 = relax.Var("gamma", relax.FuncStructInfo([], R.Tensor((3,), "float32"))) + gamma1 = relax.Var("gamma", relax.FuncType([], R.Tensor((3,), "float32"))) beta = relax.Var("beta", R.Tensor((3,), "float32")) moving_mean = relax.Var("moving_mean", R.Tensor((3,), "float32")) moving_var = relax.Var("moving_var", R.Tensor((3,), "float32")) @@ -668,7 +658,7 @@ def test_batch_norm_infer_struct_info_wrong_input_type(): bb.normalize(relax.op.nn.batch_norm(x0, gamma1, beta, moving_mean, moving_var, axis=1)) -def test_layer_norm_infer_struct_info(): +def test_layer_norm_infer_ty(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) x1 = relax.Var("x", R.Tensor("float32", ndim=4)) @@ -683,36 +673,36 @@ def test_layer_norm_infer_struct_info(): _check_inference( bb, relax.op.nn.layer_norm(x0, gamma0, beta0, axes=[-2, -1]), - relax.TensorStructInfo((2, 3, 4, 5), "float32"), + relax.TensorType((2, 3, 4, 5), "float32"), ) _check_inference( bb, relax.op.nn.layer_norm(x0, gamma0, beta0, axes=[-2, 3]), - relax.TensorStructInfo((2, 3, 4, 5), "float32"), + relax.TensorType((2, 3, 4, 5), "float32"), ) _check_inference( bb, relax.op.nn.layer_norm(x1, gamma0, beta0, axes=[-2, -1]), - relax.TensorStructInfo(dtype="float32", ndim=4), + relax.TensorType(dtype="float32", ndim=4), ) _check_inference( bb, relax.op.nn.layer_norm(x2, gamma0, beta0, axes=[-2, -1]), - relax.TensorStructInfo(dtype="float32"), + relax.TensorType(dtype="float32"), ) _check_inference( bb, relax.op.nn.layer_norm(x0, gamma1, beta0, axes=[-2, -1]), - relax.TensorStructInfo((2, 3, 4, 5), dtype="float32"), + relax.TensorType((2, 3, 4, 5), dtype="float32"), ) _check_inference( bb, relax.op.nn.layer_norm(x3, gamma2, beta1, axes=[-2, -1]), - relax.TensorStructInfo((2, 3, 4, 5), dtype=""), + relax.TensorType((2, 3, 4, 5), dtype=""), ) -def test_layer_norm_infer_struct_info_shape_symbolic(): +def test_layer_norm_infer_ty_shape_symbolic(): bb = relax.BlockBuilder() n = tirx.Var("n", "int64") a = tirx.Var("a", "int64") @@ -729,54 +719,54 @@ def test_layer_norm_infer_struct_info_shape_symbolic(): _check_inference( bb, relax.op.nn.layer_norm(x0, gamma0, beta, axes=[-2, -1]), - relax.TensorStructInfo((n, a, b, c0), "float32"), + relax.TensorType((n, a, b, c0), "float32"), ) _check_inference( bb, relax.op.nn.layer_norm(x1, gamma0, beta, axes=[-2, -1]), - relax.TensorStructInfo(dtype="float32", ndim=4), + relax.TensorType(dtype="float32", ndim=4), ) _check_inference( bb, relax.op.nn.layer_norm(x0, gamma1, beta, axes=[-2, -1]), - relax.TensorStructInfo(dtype="float32", ndim=4), + relax.TensorType(dtype="float32", ndim=4), ) _check_inference( bb, relax.op.nn.layer_norm(x2, gamma0, beta, axes=[-2, -1]), - relax.TensorStructInfo(dtype="float32", ndim=4), + relax.TensorType(dtype="float32", ndim=4), ) _check_inference( bb, relax.op.nn.layer_norm(x2, gamma1, beta, axes=[-2, -1]), - relax.TensorStructInfo(dtype="float32", ndim=4), + relax.TensorType(dtype="float32", ndim=4), ) -def test_layer_norm_infer_struct_info_shape_var(): +def test_layer_norm_infer_ty_shape_var(): bb = relax.BlockBuilder() - s0 = relax.Var("s0", relax.ShapeStructInfo(ndim=4)) - s1 = relax.Var("s1", relax.ShapeStructInfo()) - s2 = relax.Var("s2", relax.ShapeStructInfo(ndim=2)) - s3 = relax.Var("s3", relax.ShapeStructInfo(ndim=2)) - x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) - x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) - gamma = relax.Var("gamma", relax.TensorStructInfo(s2, "float32")) - beta = relax.Var("beta", relax.TensorStructInfo(s3, "float32")) + s0 = relax.Var("s0", relax.ShapeType(ndim=4)) + s1 = relax.Var("s1", relax.ShapeType()) + s2 = relax.Var("s2", relax.ShapeType(ndim=2)) + s3 = relax.Var("s3", relax.ShapeType(ndim=2)) + x0 = relax.Var("x", relax.TensorType(s0, "float32")) + x1 = relax.Var("x", relax.TensorType(s1, "float32")) + gamma = relax.Var("gamma", relax.TensorType(s2, "float32")) + beta = relax.Var("beta", relax.TensorType(s3, "float32")) _check_inference( bb, relax.op.nn.layer_norm(x0, gamma, beta, axes=[2, 3]), - relax.TensorStructInfo(s0, "float32"), + relax.TensorType(s0, "float32"), ) _check_inference( bb, relax.op.nn.layer_norm(x1, gamma, beta, axes=[2, 3]), - relax.TensorStructInfo(s1, "float32"), + relax.TensorType(s1, "float32"), ) -def test_layer_norm_infer_struct_info_more_input_dtype(): +def test_layer_norm_infer_ty_more_input_dtype(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float16")) gamma0 = relax.Var("gamma", R.Tensor((4, 5), "float16")) @@ -788,16 +778,16 @@ def test_layer_norm_infer_struct_info_more_input_dtype(): _check_inference( bb, relax.op.nn.layer_norm(x0, gamma0, beta0, axes=[-2, -1]), - relax.TensorStructInfo((2, 3, 4, 5), "float16"), + relax.TensorType((2, 3, 4, 5), "float16"), ) _check_inference( bb, relax.op.nn.layer_norm(x1, gamma1, beta1, axes=[-2, -1]), - relax.TensorStructInfo((2, 3, 4, 5), "float64"), + relax.TensorType((2, 3, 4, 5), "float64"), ) -def test_layer_norm_infer_struct_info_invalid_input_dtype(): +def test_layer_norm_infer_ty_invalid_input_dtype(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "int8")) gamma0 = relax.Var("gamma", R.Tensor((4, 5), "int8")) @@ -812,7 +802,7 @@ def test_layer_norm_infer_struct_info_invalid_input_dtype(): bb.normalize(relax.op.nn.layer_norm(x1, gamma1, beta1, axes=[-2, -1])) -def test_layer_norm_infer_struct_info_axis_out_of_range_and_repetitive(): +def test_layer_norm_infer_ty_axis_out_of_range_and_repetitive(): bb = relax.BlockBuilder() x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) gamma = relax.Var("gamma", R.Tensor((4, 5), "float32")) @@ -824,7 +814,7 @@ def test_layer_norm_infer_struct_info_axis_out_of_range_and_repetitive(): bb.normalize(relax.op.nn.layer_norm(x, gamma, beta, axes=[3, -1])) -def test_layer_norm_infer_struct_info_dtype_mismatch(): +def test_layer_norm_infer_ty_dtype_mismatch(): bb = relax.BlockBuilder() x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) gamma0 = relax.Var("gamma", R.Tensor((4, 5), "float32")) @@ -838,7 +828,7 @@ def test_layer_norm_infer_struct_info_dtype_mismatch(): bb.normalize(relax.op.nn.layer_norm(x, gamma0, beta1, axes=[-2, -1])) -def test_layer_norm_infer_struct_info_ndim_mismatch(): +def test_layer_norm_infer_ty_ndim_mismatch(): bb = relax.BlockBuilder() x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) gamma0 = relax.Var("gamma", R.Tensor((4, 5), "float32")) @@ -852,7 +842,7 @@ def test_layer_norm_infer_struct_info_ndim_mismatch(): bb.normalize(relax.op.nn.layer_norm(x, gamma0, beta1, axes=[-2, -1])) -def test_layer_norm_infer_struct_info_shape_mismatch(): +def test_layer_norm_infer_ty_shape_mismatch(): bb = relax.BlockBuilder() c0 = tirx.Var("c", "int64") x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) @@ -868,12 +858,12 @@ def test_layer_norm_infer_struct_info_shape_mismatch(): bb.normalize(relax.op.nn.layer_norm(x1, gamma1, beta1, axes=[-2, -1])) -def test_layer_norm_infer_struct_info_wrong_input_type(): +def test_layer_norm_infer_ty_wrong_input_type(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) - x1 = relax.Var("x", relax.ShapeStructInfo((2, 3, 4, 5))) + x1 = relax.Var("x", relax.ShapeType((2, 3, 4, 5))) gamma0 = relax.Var("gamma", R.Tensor((4, 5), "float32")) - gamma1 = relax.Var("gamma", relax.FuncStructInfo([], R.Tensor((4, 5), "float32"))) + gamma1 = relax.Var("gamma", relax.FuncType([], R.Tensor((4, 5), "float32"))) beta = relax.Var("beta", R.Tensor((4, 5), "float32")) with pytest.raises(TypeError): @@ -882,7 +872,7 @@ def test_layer_norm_infer_struct_info_wrong_input_type(): bb.normalize(relax.op.nn.layer_norm(x0, gamma1, beta, axes=[-2, -1])) -def test_group_norm_infer_struct_info(): +def test_group_norm_infer_ty(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) x1 = relax.Var("x", R.Tensor("float32", ndim=4)) @@ -897,36 +887,36 @@ def test_group_norm_infer_struct_info(): _check_inference( bb, relax.op.nn.group_norm(x0, gamma0, beta0, num_groups=2, channel_axis=-2, axes=[-1]), - relax.TensorStructInfo((2, 3, 4, 5), "float32"), + relax.TensorType((2, 3, 4, 5), "float32"), ) _check_inference( bb, relax.op.nn.group_norm(x0, gamma0, beta0, num_groups=2, channel_axis=-2, axes=[-1]), - relax.TensorStructInfo((2, 3, 4, 5), "float32"), + relax.TensorType((2, 3, 4, 5), "float32"), ) _check_inference( bb, relax.op.nn.group_norm(x1, gamma0, beta0, num_groups=2, channel_axis=-2, axes=[-1]), - relax.TensorStructInfo(dtype="float32", ndim=4), + relax.TensorType(dtype="float32", ndim=4), ) _check_inference( bb, relax.op.nn.group_norm(x2, gamma0, beta0, num_groups=2, channel_axis=-2, axes=[-1]), - relax.TensorStructInfo(dtype="float32"), + relax.TensorType(dtype="float32"), ) _check_inference( bb, relax.op.nn.group_norm(x0, gamma1, beta0, num_groups=2, channel_axis=-2, axes=[-1]), - relax.TensorStructInfo((2, 3, 4, 5), dtype="float32"), + relax.TensorType((2, 3, 4, 5), dtype="float32"), ) _check_inference( bb, relax.op.nn.group_norm(x3, gamma2, beta1, num_groups=2, channel_axis=-2, axes=[-1]), - relax.TensorStructInfo((2, 3, 4, 5), dtype=""), + relax.TensorType((2, 3, 4, 5), dtype=""), ) -def test_group_norm_infer_struct_info_shape_symbolic(): +def test_group_norm_infer_ty_shape_symbolic(): bb = relax.BlockBuilder() n = tirx.Var("n", "int64") a = tirx.Var("a", "int64") @@ -943,54 +933,54 @@ def test_group_norm_infer_struct_info_shape_symbolic(): _check_inference( bb, relax.op.nn.group_norm(x0, gamma0, beta, num_groups=2, channel_axis=-3, axes=[-2, -1]), - relax.TensorStructInfo((n, a, b, c0), "float32"), + relax.TensorType((n, a, b, c0), "float32"), ) _check_inference( bb, relax.op.nn.group_norm(x1, gamma0, beta, num_groups=2, channel_axis=-3, axes=[-2, -1]), - relax.TensorStructInfo((n, a, b, c1), "float32"), + relax.TensorType((n, a, b, c1), "float32"), ) _check_inference( bb, relax.op.nn.group_norm(x0, gamma1, beta, num_groups=2, channel_axis=-3, axes=[-2, -1]), - relax.TensorStructInfo((n, a, b, c0), "float32"), + relax.TensorType((n, a, b, c0), "float32"), ) _check_inference( bb, relax.op.nn.group_norm(x2, gamma0, beta, num_groups=2, channel_axis=-3, axes=[-2, -1]), - relax.TensorStructInfo(dtype="float32", ndim=4), + relax.TensorType(dtype="float32", ndim=4), ) _check_inference( bb, relax.op.nn.group_norm(x2, gamma1, beta, num_groups=2, channel_axis=-3, axes=[-2, -1]), - relax.TensorStructInfo(dtype="float32", ndim=4), + relax.TensorType(dtype="float32", ndim=4), ) -def test_group_norm_infer_struct_info_shape_var(): +def test_group_norm_infer_ty_shape_var(): bb = relax.BlockBuilder() - s0 = relax.Var("s0", relax.ShapeStructInfo(ndim=4)) - s1 = relax.Var("s1", relax.ShapeStructInfo()) - s2 = relax.Var("s2", relax.ShapeStructInfo(ndim=1)) - s3 = relax.Var("s3", relax.ShapeStructInfo(ndim=1)) - x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) - x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) - gamma = relax.Var("gamma", relax.TensorStructInfo(s2, "float32")) - beta = relax.Var("beta", relax.TensorStructInfo(s3, "float32")) + s0 = relax.Var("s0", relax.ShapeType(ndim=4)) + s1 = relax.Var("s1", relax.ShapeType()) + s2 = relax.Var("s2", relax.ShapeType(ndim=1)) + s3 = relax.Var("s3", relax.ShapeType(ndim=1)) + x0 = relax.Var("x", relax.TensorType(s0, "float32")) + x1 = relax.Var("x", relax.TensorType(s1, "float32")) + gamma = relax.Var("gamma", relax.TensorType(s2, "float32")) + beta = relax.Var("beta", relax.TensorType(s3, "float32")) _check_inference( bb, relax.op.nn.group_norm(x0, gamma, beta, num_groups=2, channel_axis=-2, axes=[1, 3]), - relax.TensorStructInfo(s0, "float32"), + relax.TensorType(s0, "float32"), ) _check_inference( bb, relax.op.nn.group_norm(x1, gamma, beta, num_groups=2, channel_axis=-2, axes=[1, 3]), - relax.TensorStructInfo(s1, "float32"), + relax.TensorType(s1, "float32"), ) -def test_group_norm_infer_struct_info_more_input_dtype(): +def test_group_norm_infer_ty_more_input_dtype(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float16")) gamma0 = relax.Var("gamma", R.Tensor((3,), "float16")) @@ -1002,16 +992,16 @@ def test_group_norm_infer_struct_info_more_input_dtype(): _check_inference( bb, relax.op.nn.group_norm(x0, gamma0, beta0, num_groups=3, channel_axis=1, axes=[-2, -1]), - relax.TensorStructInfo((2, 3, 4, 5), "float16"), + relax.TensorType((2, 3, 4, 5), "float16"), ) _check_inference( bb, relax.op.nn.group_norm(x1, gamma1, beta1, num_groups=3, channel_axis=1, axes=[-2, -1]), - relax.TensorStructInfo((2, 3, 4, 5), "float64"), + relax.TensorType((2, 3, 4, 5), "float64"), ) -def test_group_norm_infer_struct_info_invalid_input_dtype(): +def test_group_norm_infer_ty_invalid_input_dtype(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "int8")) gamma0 = relax.Var("gamma", R.Tensor((4,), "int8")) @@ -1030,7 +1020,7 @@ def test_group_norm_infer_struct_info_invalid_input_dtype(): ) -def test_group_norm_infer_struct_info_axis_out_of_range_and_repetitive(): +def test_group_norm_infer_ty_axis_out_of_range_and_repetitive(): bb = relax.BlockBuilder() x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) gamma = relax.Var("gamma", R.Tensor((4,), "float32")) @@ -1046,7 +1036,7 @@ def test_group_norm_infer_struct_info_axis_out_of_range_and_repetitive(): ) -def test_group_norm_infer_struct_info_dtype_mismatch(): +def test_group_norm_infer_ty_dtype_mismatch(): bb = relax.BlockBuilder() x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) gamma0 = relax.Var("gamma", R.Tensor((4,), "float32")) @@ -1064,7 +1054,7 @@ def test_group_norm_infer_struct_info_dtype_mismatch(): ) -def test_group_norm_infer_struct_info_ndim_mismatch(): +def test_group_norm_infer_ty_ndim_mismatch(): bb = relax.BlockBuilder() x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) gamma0 = relax.Var("gamma", R.Tensor((4, 5), "float32")) @@ -1082,7 +1072,7 @@ def test_group_norm_infer_struct_info_ndim_mismatch(): ) -def test_group_norm_infer_struct_info_shape_mismatch(): +def test_group_norm_infer_ty_shape_mismatch(): bb = relax.BlockBuilder() c0 = tirx.Var("c", "int64") x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) @@ -1102,12 +1092,12 @@ def test_group_norm_infer_struct_info_shape_mismatch(): ) -def test_group_norm_infer_struct_info_wrong_input_type(): +def test_group_norm_infer_ty_wrong_input_type(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) - x1 = relax.Var("x", relax.ShapeStructInfo((2, 3, 4, 5))) + x1 = relax.Var("x", relax.ShapeType((2, 3, 4, 5))) gamma0 = relax.Var("gamma", R.Tensor((4, 5), "float32")) - gamma1 = relax.Var("gamma", relax.FuncStructInfo([], R.Tensor((4, 5), "float32"))) + gamma1 = relax.Var("gamma", relax.FuncType([], R.Tensor((4, 5), "float32"))) beta = relax.Var("beta", R.Tensor((4, 5), "float32")) with pytest.raises(TypeError): @@ -1120,7 +1110,7 @@ def test_group_norm_infer_struct_info_wrong_input_type(): ) -def test_dropout_infer_struct_info(): +def test_dropout_infer_ty(): bb = relax.BlockBuilder() vdev0 = VDevice("llvm") x0 = relax.Var("x", R.Tensor((2, 3), "float32")) @@ -1133,52 +1123,46 @@ def test_dropout_infer_struct_info(): _check_inference( bb, relax.op.nn.dropout(x0), - relax.TupleStructInfo( - [relax.TensorStructInfo((2, 3), "float32"), relax.TensorStructInfo((2, 3), "float32")] - ), + relax.TupleType([relax.TensorType((2, 3), "float32"), relax.TensorType((2, 3), "float32")]), ) _check_inference( bb, relax.op.nn.dropout(x5), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo((2, 3), "float32", vdev0), - relax.TensorStructInfo((2, 3), "float32", vdev0), + relax.TensorType((2, 3), "float32", vdev0), + relax.TensorType((2, 3), "float32", vdev0), ] ), ) _check_inference( bb, relax.op.nn.dropout(x1), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="float32", ndim=3), - relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorType(dtype="float32", ndim=3), + relax.TensorType(dtype="float32", ndim=3), ] ), ) _check_inference( bb, relax.op.nn.dropout(x2), - relax.TupleStructInfo( - [relax.TensorStructInfo(dtype="float32"), relax.TensorStructInfo(dtype="float32")] - ), + relax.TupleType([relax.TensorType(dtype="float32"), relax.TensorType(dtype="float32")]), ) _check_inference( bb, relax.op.nn.dropout(x3), - relax.TupleStructInfo( - [relax.TensorStructInfo((2, 3), dtype=""), relax.TensorStructInfo((2, 3), dtype="")] - ), + relax.TupleType([relax.TensorType((2, 3), dtype=""), relax.TensorType((2, 3), dtype="")]), ) _check_inference( bb, relax.op.nn.dropout(x4), - relax.TupleStructInfo([relax.TensorStructInfo(dtype=""), relax.TensorStructInfo(dtype="")]), + relax.TupleType([relax.TensorType(dtype=""), relax.TensorType(dtype="")]), ) -def test_dropout_infer_struct_info_shape_symbolic(): +def test_dropout_infer_ty_shape_symbolic(): bb = relax.BlockBuilder() m = tirx.Var("m", "int64") n = tirx.Var("n", "int64") @@ -1187,36 +1171,30 @@ def test_dropout_infer_struct_info_shape_symbolic(): _check_inference( bb, relax.op.nn.dropout(x), - relax.TupleStructInfo( - [relax.TensorStructInfo((m, n), "float32"), relax.TensorStructInfo((m, n), "float32")] - ), + relax.TupleType([relax.TensorType((m, n), "float32"), relax.TensorType((m, n), "float32")]), ) -def test_dropout_infer_struct_info_shape_var(): +def test_dropout_infer_ty_shape_var(): bb = relax.BlockBuilder() - s0 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) - s1 = relax.Var("s", relax.ShapeStructInfo()) - x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) - x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + s0 = relax.Var("s", relax.ShapeType(ndim=2)) + s1 = relax.Var("s", relax.ShapeType()) + x0 = relax.Var("x", relax.TensorType(s0, "float32")) + x1 = relax.Var("x", relax.TensorType(s1, "float32")) _check_inference( bb, relax.op.nn.dropout(x0), - relax.TupleStructInfo( - [relax.TensorStructInfo(s0, "float32"), relax.TensorStructInfo(s0, "float32")] - ), + relax.TupleType([relax.TensorType(s0, "float32"), relax.TensorType(s0, "float32")]), ) _check_inference( bb, relax.op.nn.dropout(x1), - relax.TupleStructInfo( - [relax.TensorStructInfo(s1, "float32"), relax.TensorStructInfo(s1, "float32")] - ), + relax.TupleType([relax.TensorType(s1, "float32"), relax.TensorType(s1, "float32")]), ) -def test_dropout_infer_struct_info_more_input_dtype(): +def test_dropout_infer_ty_more_input_dtype(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3), "float64")) x1 = relax.Var("x", R.Tensor((2, 3), "int8")) @@ -1225,30 +1203,24 @@ def test_dropout_infer_struct_info_more_input_dtype(): _check_inference( bb, relax.op.nn.dropout(x0), - relax.TupleStructInfo( - [relax.TensorStructInfo((2, 3), "float64"), relax.TensorStructInfo((2, 3), "float64")] - ), + relax.TupleType([relax.TensorType((2, 3), "float64"), relax.TensorType((2, 3), "float64")]), ) _check_inference( bb, relax.op.nn.dropout(x1), - relax.TupleStructInfo( - [relax.TensorStructInfo((2, 3), "int8"), relax.TensorStructInfo((2, 3), "int8")] - ), + relax.TupleType([relax.TensorType((2, 3), "int8"), relax.TensorType((2, 3), "int8")]), ) _check_inference( bb, relax.op.nn.dropout(x2), - relax.TupleStructInfo( - [relax.TensorStructInfo((2, 3), "int64"), relax.TensorStructInfo((2, 3), "int64")] - ), + relax.TupleType([relax.TensorType((2, 3), "int64"), relax.TensorType((2, 3), "int64")]), ) -def test_dropout_infer_struct_info_wrong_input_type(): +def test_dropout_infer_ty_wrong_input_type(): bb = relax.BlockBuilder() - x0 = relax.Var("x", relax.ShapeStructInfo((2, 3))) - x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3), "float32"))) + x0 = relax.Var("x", relax.ShapeType((2, 3))) + x1 = relax.Var("x", relax.FuncType([], R.Tensor((2, 3), "float32"))) with pytest.raises(TypeError): bb.normalize(relax.op.nn.dropout(x0)) @@ -1256,7 +1228,7 @@ def test_dropout_infer_struct_info_wrong_input_type(): bb.normalize(relax.op.nn.dropout(x1)) -def test_cross_entropy_infer_struct_info(): +def test_cross_entropy_infer_ty(): bb = relax.BlockBuilder() vdev0 = VDevice("llvm") x = relax.Var("x", R.Tensor((2, 3), "float32")) @@ -1266,22 +1238,22 @@ def test_cross_entropy_infer_struct_info(): y3 = relax.Var("y", R.Tensor(ndim=2)) _check_inference( - bb, relax.op.nn.cross_entropy_with_logits(x, y0), relax.TensorStructInfo((), "float32") + bb, relax.op.nn.cross_entropy_with_logits(x, y0), relax.TensorType((), "float32") ) _check_inference( bb, relax.op.nn.cross_entropy_with_logits(x, y1), - relax.TensorStructInfo((), dtype="float32"), + relax.TensorType((), dtype="float32"), ) _check_inference( - bb, relax.op.nn.cross_entropy_with_logits(x, y2), relax.TensorStructInfo((), dtype="") + bb, relax.op.nn.cross_entropy_with_logits(x, y2), relax.TensorType((), dtype="") ) _check_inference( - bb, relax.op.nn.cross_entropy_with_logits(x, y3), relax.TensorStructInfo((), dtype="") + bb, relax.op.nn.cross_entropy_with_logits(x, y3), relax.TensorType((), dtype="") ) -def test_cross_entropy_infer_struct_info_shape_symbolic(): +def test_cross_entropy_infer_ty_shape_symbolic(): bb = relax.BlockBuilder() m0 = tirx.Var("m", "int64") m1 = tirx.Var("m", "int64") @@ -1291,30 +1263,30 @@ def test_cross_entropy_infer_struct_info_shape_symbolic(): y = relax.Var("y", R.Tensor((m0, n), "float32")) _check_inference( - bb, relax.op.nn.cross_entropy_with_logits(x0, y), relax.TensorStructInfo((), "float32") + bb, relax.op.nn.cross_entropy_with_logits(x0, y), relax.TensorType((), "float32") ) _check_inference( - bb, relax.op.nn.cross_entropy_with_logits(x1, y), relax.TensorStructInfo((), "float32") + bb, relax.op.nn.cross_entropy_with_logits(x1, y), relax.TensorType((), "float32") ) -def test_cross_entropy_infer_struct_info_shape_var(): +def test_cross_entropy_infer_ty_shape_var(): bb = relax.BlockBuilder() - s0 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) - s1 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) - x = relax.Var("x", relax.TensorStructInfo(s0, "float32")) - y0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) - y1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + s0 = relax.Var("s", relax.ShapeType(ndim=2)) + s1 = relax.Var("s", relax.ShapeType(ndim=2)) + x = relax.Var("x", relax.TensorType(s0, "float32")) + y0 = relax.Var("x", relax.TensorType(s0, "float32")) + y1 = relax.Var("x", relax.TensorType(s1, "float32")) _check_inference( - bb, relax.op.nn.cross_entropy_with_logits(x, y0), relax.TensorStructInfo((), "float32") + bb, relax.op.nn.cross_entropy_with_logits(x, y0), relax.TensorType((), "float32") ) _check_inference( - bb, relax.op.nn.cross_entropy_with_logits(x, y1), relax.TensorStructInfo((), "float32") + bb, relax.op.nn.cross_entropy_with_logits(x, y1), relax.TensorType((), "float32") ) -def test_cross_entropy_infer_struct_info_more_input_dtype(): +def test_cross_entropy_infer_ty_more_input_dtype(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3), "float16")) y0 = relax.Var("y", R.Tensor((2, 3), "float16")) @@ -1324,14 +1296,14 @@ def test_cross_entropy_infer_struct_info_more_input_dtype(): y2 = relax.Var("y", R.Tensor((2, 3), "int32")) _check_inference( - bb, relax.op.nn.cross_entropy_with_logits(x0, y0), relax.TensorStructInfo((), "float16") + bb, relax.op.nn.cross_entropy_with_logits(x0, y0), relax.TensorType((), "float16") ) _check_inference( - bb, relax.op.nn.cross_entropy_with_logits(x1, y1), relax.TensorStructInfo((), "int8") + bb, relax.op.nn.cross_entropy_with_logits(x1, y1), relax.TensorType((), "int8") ) -def test_cross_entropy_infer_struct_info_wrong_ndim(): +def test_cross_entropy_infer_ty_wrong_ndim(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3), "float32")) x1 = relax.Var("x", R.Tensor((2, 3, 4), "float32")) @@ -1344,7 +1316,7 @@ def test_cross_entropy_infer_struct_info_wrong_ndim(): bb.normalize(relax.op.nn.cross_entropy_with_logits(x0, y1)) -def test_cross_entropy_infer_struct_info_shape_mismatch(): +def test_cross_entropy_infer_ty_shape_mismatch(): bb = relax.BlockBuilder() m = tirx.Var("m", "int64") x0 = relax.Var("x", R.Tensor((2, 3), "float32")) @@ -1354,10 +1326,10 @@ def test_cross_entropy_infer_struct_info_shape_mismatch(): bb.normalize(relax.op.nn.cross_entropy_with_logits(x0, y0)) -def test_cross_entropy_infer_struct_info_wrong_input_type(): +def test_cross_entropy_infer_ty_wrong_input_type(): bb = relax.BlockBuilder() - x0 = relax.Var("x", relax.ShapeStructInfo((2, 3))) - x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3), "float32"))) + x0 = relax.Var("x", relax.ShapeType((2, 3))) + x1 = relax.Var("x", relax.FuncType([], R.Tensor((2, 3), "float32"))) y = relax.Var("y", R.Tensor((2, 3), "float32")) with pytest.raises(TypeError): @@ -1366,7 +1338,7 @@ def test_cross_entropy_infer_struct_info_wrong_input_type(): bb.normalize(relax.op.nn.cross_entropy_with_logits(x1, y)) -def test_nll_loss_infer_struct_info(): +def test_nll_loss_infer_ty(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((3, 5, 10, 10), "float32")) @@ -1392,133 +1364,133 @@ def test_nll_loss_infer_struct_info(): _check_inference( bb, relax.op.nn.nll_loss(x0, y0, w0, reduction="mean"), - relax.TensorStructInfo((), "float32"), + relax.TensorType((), "float32"), ) _check_inference( bb, relax.op.nn.nll_loss(x1, y0, w0, reduction="mean"), - relax.TensorStructInfo((), "float32"), + relax.TensorType((), "float32"), ) _check_inference( bb, relax.op.nn.nll_loss(x2, y0, w0, reduction="mean"), - relax.TensorStructInfo((), "float32"), + relax.TensorType((), "float32"), ) _check_inference( bb, relax.op.nn.nll_loss(x3, y0, w0, reduction="mean"), - relax.TensorStructInfo((), ""), + relax.TensorType((), ""), ) _check_inference( bb, relax.op.nn.nll_loss(x0, y1, w0, reduction="mean"), - relax.TensorStructInfo((), "float32"), + relax.TensorType((), "float32"), ) _check_inference( bb, relax.op.nn.nll_loss(x0, y2, w0, reduction="mean"), - relax.TensorStructInfo((), "float32"), + relax.TensorType((), "float32"), ) _check_inference( bb, relax.op.nn.nll_loss(x0, y3, w0, reduction="mean"), - relax.TensorStructInfo((), "float32"), + relax.TensorType((), "float32"), ) _check_inference( bb, relax.op.nn.nll_loss(x0, y0, w1, reduction="mean"), - relax.TensorStructInfo((), "float32"), + relax.TensorType((), "float32"), ) _check_inference( bb, relax.op.nn.nll_loss(x0, y0, w2, reduction="mean"), - relax.TensorStructInfo((), "float32"), + relax.TensorType((), "float32"), ) _check_inference( bb, relax.op.nn.nll_loss(x0, y0, w3, reduction="mean"), - relax.TensorStructInfo((), ""), + relax.TensorType((), ""), ) _check_inference( bb, relax.op.nn.nll_loss(x4, y4, w0, reduction="mean"), - relax.TensorStructInfo((), "float32"), + relax.TensorType((), "float32"), ) _check_inference( bb, relax.op.nn.nll_loss(x5, y5, w0, reduction="mean"), - relax.TensorStructInfo((), "float32"), + relax.TensorType((), "float32"), ) # reduction=sum is totally the same as mean. Just need one test to ensure they behave the same _check_inference( - bb, relax.op.nn.nll_loss(x0, y0, w0, reduction="sum"), relax.TensorStructInfo((), "float32") + bb, relax.op.nn.nll_loss(x0, y0, w0, reduction="sum"), relax.TensorType((), "float32") ) # reduction=none _check_inference( bb, relax.op.nn.nll_loss(x0, y0, w0, reduction="none"), - relax.TensorStructInfo((3, 10, 10), "float32"), + relax.TensorType((3, 10, 10), "float32"), ) _check_inference( bb, relax.op.nn.nll_loss(x1, y0, w0, reduction="none"), - relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorType(dtype="float32", ndim=3), ) _check_inference( bb, relax.op.nn.nll_loss(x2, y0, w0, reduction="none"), - relax.TensorStructInfo(dtype="float32"), + relax.TensorType(dtype="float32"), ) _check_inference( bb, relax.op.nn.nll_loss(x3, y0, w0, reduction="none"), - relax.TensorStructInfo((3, 10, 10), ""), + relax.TensorType((3, 10, 10), ""), ) _check_inference( bb, relax.op.nn.nll_loss(x0, y1, w0, reduction="none"), - relax.TensorStructInfo((3, 10, 10), "float32"), + relax.TensorType((3, 10, 10), "float32"), ) _check_inference( bb, relax.op.nn.nll_loss(x0, y2, w0, reduction="none"), - relax.TensorStructInfo((3, 10, 10), "float32"), + relax.TensorType((3, 10, 10), "float32"), ) _check_inference( bb, relax.op.nn.nll_loss(x0, y3, w0, reduction="none"), - relax.TensorStructInfo((3, 10, 10), "float32"), + relax.TensorType((3, 10, 10), "float32"), ) _check_inference( bb, relax.op.nn.nll_loss(x0, y0, w1, reduction="none"), - relax.TensorStructInfo((3, 10, 10), "float32"), + relax.TensorType((3, 10, 10), "float32"), ) _check_inference( bb, relax.op.nn.nll_loss(x0, y0, w2, reduction="none"), - relax.TensorStructInfo((3, 10, 10), "float32"), + relax.TensorType((3, 10, 10), "float32"), ) _check_inference( bb, relax.op.nn.nll_loss(x0, y0, w3, reduction="none"), - relax.TensorStructInfo((3, 10, 10), ""), + relax.TensorType((3, 10, 10), ""), ) _check_inference( bb, relax.op.nn.nll_loss(x4, y4, w0, reduction="none"), - relax.TensorStructInfo((3,), "float32"), # (N,) + relax.TensorType((3,), "float32"), # (N,) ) _check_inference( bb, relax.op.nn.nll_loss(x5, y5, w0, reduction="none"), - relax.TensorStructInfo((), "float32"), # () + relax.TensorType((), "float32"), # () ) -def test_nll_loss_infer_struct_info_shape_symbolic(): +def test_nll_loss_infer_ty_shape_symbolic(): bb = relax.BlockBuilder() N = tirx.Var("N", "int64") C = tirx.Var("C", "int64") @@ -1538,87 +1510,87 @@ def test_nll_loss_infer_struct_info_shape_symbolic(): _check_inference( bb, relax.op.nn.nll_loss(x0, y0, w0, reduction="mean"), - relax.TensorStructInfo((), "float32"), + relax.TensorType((), "float32"), ) _check_inference( bb, relax.op.nn.nll_loss(x0, y0, w0, reduction="none"), - relax.TensorStructInfo((N, d1, d2), "float32"), + relax.TensorType((N, d1, d2), "float32"), ) _check_inference( bb, relax.op.nn.nll_loss(x1, y1, w0, reduction="none"), - relax.TensorStructInfo((N,), "float32"), + relax.TensorType((N,), "float32"), ) _check_inference( bb, relax.op.nn.nll_loss(x2, y2, w0, reduction="none"), - relax.TensorStructInfo((), "float32"), + relax.TensorType((), "float32"), ) _check_inference( bb, relax.op.nn.nll_loss(x3, y3, w0, reduction="none"), - relax.TensorStructInfo((3, d1, 2), "float32"), + relax.TensorType((3, d1, 2), "float32"), ) _check_inference( bb, relax.op.nn.nll_loss(x3, y3, w1, reduction="none"), - relax.TensorStructInfo((3, d1, 2), "float32"), + relax.TensorType((3, d1, 2), "float32"), ) -def test_nll_loss_infer_struct_info_shape_var(): +def test_nll_loss_infer_ty_shape_var(): bb = relax.BlockBuilder() - s0 = relax.Var("s0", relax.ShapeStructInfo((3, 5, 10, 10))) - s1 = relax.Var("s1", relax.ShapeStructInfo(ndim=4)) - s2 = relax.Var("s2", relax.ShapeStructInfo()) - s3 = relax.Var("s3", relax.ShapeStructInfo((3, 10, 10))) - s4 = relax.Var("s4", relax.ShapeStructInfo(ndim=3)) - s5 = relax.Var("s5", relax.ShapeStructInfo((5,))) - s6 = relax.Var("s6", relax.ShapeStructInfo(ndim=1)) + s0 = relax.Var("s0", relax.ShapeType((3, 5, 10, 10))) + s1 = relax.Var("s1", relax.ShapeType(ndim=4)) + s2 = relax.Var("s2", relax.ShapeType()) + s3 = relax.Var("s3", relax.ShapeType((3, 10, 10))) + s4 = relax.Var("s4", relax.ShapeType(ndim=3)) + s5 = relax.Var("s5", relax.ShapeType((5,))) + s6 = relax.Var("s6", relax.ShapeType(ndim=1)) - x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) - x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) - x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) - y0 = relax.Var("y", relax.TensorStructInfo(s3, "int64")) - y1 = relax.Var("y", relax.TensorStructInfo(s4, "int64")) - w0 = relax.Var("w", relax.TensorStructInfo(s5, "float32")) - w1 = relax.Var("w", relax.TensorStructInfo(s6, "float32")) + x0 = relax.Var("x", relax.TensorType(s0, "float32")) + x1 = relax.Var("x", relax.TensorType(s1, "float32")) + x2 = relax.Var("x", relax.TensorType(s2, "float32")) + y0 = relax.Var("y", relax.TensorType(s3, "int64")) + y1 = relax.Var("y", relax.TensorType(s4, "int64")) + w0 = relax.Var("w", relax.TensorType(s5, "float32")) + w1 = relax.Var("w", relax.TensorType(s6, "float32")) _check_inference( bb, relax.op.nn.nll_loss(x0, y0, w0, reduction="mean"), - relax.TensorStructInfo((), "float32"), + relax.TensorType((), "float32"), ) _check_inference( bb, relax.op.nn.nll_loss(x0, y0, w0, reduction="none"), - relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorType(dtype="float32", ndim=3), ) _check_inference( bb, relax.op.nn.nll_loss(x1, y0, w0, reduction="none"), - relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorType(dtype="float32", ndim=3), ) _check_inference( bb, relax.op.nn.nll_loss(x2, y0, w0, reduction="none"), - relax.TensorStructInfo(dtype="float32"), + relax.TensorType(dtype="float32"), ) _check_inference( bb, relax.op.nn.nll_loss(x0, y1, w0, reduction="none"), - relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorType(dtype="float32", ndim=3), ) _check_inference( bb, relax.op.nn.nll_loss(x0, y0, w1, reduction="none"), - relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorType(dtype="float32", ndim=3), ) -def test_nll_loss_infer_struct_info_no_weights(): +def test_nll_loss_infer_ty_no_weights(): bb = relax.BlockBuilder() x = relax.Var("x", R.Tensor((3, 5, 10, 10), "float32")) y = relax.Var("x", R.Tensor((3, 10, 10), "int64")) @@ -1626,16 +1598,16 @@ def test_nll_loss_infer_struct_info_no_weights(): _check_inference( bb, relax.op.nn.nll_loss(x, y, reduction="mean"), - relax.TensorStructInfo((), "float32"), + relax.TensorType((), "float32"), ) _check_inference( bb, relax.op.nn.nll_loss(x, y, reduction="none"), - relax.TensorStructInfo((3, 10, 10), "float32"), + relax.TensorType((3, 10, 10), "float32"), ) -def test_nll_loss_infer_struct_info_no_weights_symbolic(): +def test_nll_loss_infer_ty_no_weights_symbolic(): N = tirx.Var("N", "int64") C = tirx.Var("C", "int64") d1 = tirx.Var("d", "int64") @@ -1647,26 +1619,26 @@ def test_nll_loss_infer_struct_info_no_weights_symbolic(): _check_inference( bb, relax.op.nn.nll_loss(x, y, reduction="mean"), - relax.TensorStructInfo((), "float32"), + relax.TensorType((), "float32"), ) _check_inference( bb, relax.op.nn.nll_loss(x, y, reduction="none"), - relax.TensorStructInfo((N, d1, d2), "float32"), + relax.TensorType((N, d1, d2), "float32"), ) -def test_nll_loss_infer_struct_info_wrong_input_type(): +def test_nll_loss_infer_ty_wrong_input_type(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((3, 5, 10, 10), "float32")) - x1 = relax.Var("x", relax.ShapeStructInfo((2, 3))) - x2 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3), "float32"))) + x1 = relax.Var("x", relax.ShapeType((2, 3))) + x2 = relax.Var("x", relax.FuncType([], R.Tensor((2, 3), "float32"))) y0 = relax.Var("y", R.Tensor((3, 10, 10), "int64")) - y1 = relax.Var("y", relax.ShapeStructInfo((2, 3))) - y2 = relax.Var("y", relax.FuncStructInfo([], R.Tensor((2, 3), "float32"))) + y1 = relax.Var("y", relax.ShapeType((2, 3))) + y2 = relax.Var("y", relax.FuncType([], R.Tensor((2, 3), "float32"))) w0 = relax.Var("w", R.Tensor((5,), "float32")) - w1 = relax.Var("w", relax.ShapeStructInfo((2, 3))) - w2 = relax.Var("w", relax.FuncStructInfo([], R.Tensor((2, 3), "float32"))) + w1 = relax.Var("w", relax.ShapeType((2, 3))) + w2 = relax.Var("w", relax.FuncType([], R.Tensor((2, 3), "float32"))) with pytest.raises(TypeError): bb.normalize(relax.op.nn.nll_loss(x1, y0, w0)) @@ -1682,7 +1654,7 @@ def test_nll_loss_infer_struct_info_wrong_input_type(): bb.normalize(relax.op.nn.nll_loss(x0, y0, w2)) -def test_nll_loss_infer_struct_info_more_input_dtype(): +def test_nll_loss_infer_ty_more_input_dtype(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((3, 5, 10, 10), "float16")) x1 = relax.Var("x", R.Tensor((3, 5, 10, 10), "int8")) @@ -1697,26 +1669,26 @@ def test_nll_loss_infer_struct_info_more_input_dtype(): _check_inference( bb, relax.op.nn.nll_loss(x0, y0, w0, reduction="mean"), - relax.TensorStructInfo((), "float16"), + relax.TensorType((), "float16"), ) _check_inference( bb, relax.op.nn.nll_loss(x1, y0, w1, reduction="mean"), - relax.TensorStructInfo((), "int8"), + relax.TensorType((), "int8"), ) _check_inference( bb, relax.op.nn.nll_loss(x2, y0, w2, reduction="mean"), - relax.TensorStructInfo((), "int32"), + relax.TensorType((), "int32"), ) _check_inference( bb, relax.op.nn.nll_loss(x3, y0, w3, reduction="mean"), - relax.TensorStructInfo((), "float64"), + relax.TensorType((), "float64"), ) -def test_nll_loss_infer_struct_info_targets_dtype(): +def test_nll_loss_infer_ty_targets_dtype(): bb = relax.BlockBuilder() x = relax.Var("x", R.Tensor((3, 5, 10, 10), "float32")) w = relax.Var("w", R.Tensor((5,), "float32")) @@ -1739,7 +1711,7 @@ def test_nll_loss_infer_struct_info_targets_dtype(): bb.normalize(relax.op.nn.nll_loss(x, targets6, w)) # unknwon dtype -def test_nll_loss_infer_struct_info_ndim_mismatch(): +def test_nll_loss_infer_ty_ndim_mismatch(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((3, 5, 10, 10), "float32")) x1 = relax.Var("x", R.Tensor((3, 5, 10, 10, 10), "float32")) @@ -1765,7 +1737,7 @@ def test_nll_loss_infer_struct_info_ndim_mismatch(): bb.normalize(relax.op.nn.nll_loss(x0, y0, w2)) -def test_nll_loss_infer_struct_info_shape_mismatch(): +def test_nll_loss_infer_ty_shape_mismatch(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((3, 5, 10, 10), "float32")) x1 = relax.Var("x", R.Tensor((3, 6, 10, 10), "float32")) @@ -1791,7 +1763,7 @@ def test_nll_loss_infer_struct_info_shape_mismatch(): bb.normalize(relax.op.nn.nll_loss(x0, y0, w1)) -def test_nll_loss_infer_struct_info_wrong_reduction(): +def test_nll_loss_infer_ty_wrong_reduction(): bb = relax.BlockBuilder() x = relax.Var("x", R.Tensor((3, 5, 10, 10), "float32")) y = relax.Var("x", R.Tensor((3, 10, 10), "int64")) @@ -1801,7 +1773,7 @@ def test_nll_loss_infer_struct_info_wrong_reduction(): bb.normalize(relax.op.nn.nll_loss(x, y, w, reduction="foo")) -def test_pad_infer_struct_info(): +def test_pad_infer_ty(): bb = relax.BlockBuilder() x = relax.Var("x", R.Tensor((2, 3), "float32")) x1 = relax.Var("x", R.Tensor("float32", ndim=2)) @@ -1810,23 +1782,21 @@ def test_pad_infer_struct_info(): pad_width1 = (1, 1, 1, 1) pad_width2 = (0, 1, 1, 0) - _check_inference(bb, relax.op.nn.pad(x, pad_width0), relax.TensorStructInfo((2, 3), "float32")) + _check_inference(bb, relax.op.nn.pad(x, pad_width0), relax.TensorType((2, 3), "float32")) _check_inference( bb, relax.op.nn.pad(x, pad_width1), - relax.TensorStructInfo((4, 5), dtype="float32"), + relax.TensorType((4, 5), dtype="float32"), ) _check_inference( bb, relax.op.nn.pad(x, pad_width2), - relax.TensorStructInfo((3, 4), dtype="float32"), - ) - _check_inference( - bb, relax.op.nn.pad(x1, pad_width1), relax.TensorStructInfo(dtype="float32", ndim=2) + relax.TensorType((3, 4), dtype="float32"), ) + _check_inference(bb, relax.op.nn.pad(x1, pad_width1), relax.TensorType(dtype="float32", ndim=2)) -def test_pixel_shuffle_infer_struct_info(): +def test_pixel_shuffle_infer_ty(): bb = relax.BlockBuilder() x1 = relax.Var("x1", R.Tensor((1, 8, 10, 15), "float32")) x2 = relax.Var("x2", R.Tensor((2, 6, 18, 5, 4), "float32")) @@ -1835,14 +1805,14 @@ def test_pixel_shuffle_infer_struct_info(): _check_inference( bb, relax.op.nn.pixel_shuffle(x1, upscale_factor1), - relax.TensorStructInfo((1, 2, 20, 30), dtype="float32"), + relax.TensorType((1, 2, 20, 30), dtype="float32"), ) upscale_factor2 = 3 _check_inference( bb, relax.op.nn.pixel_shuffle(x2, upscale_factor2), - relax.TensorStructInfo((2, 6, 2, 15, 12), dtype="float32"), + relax.TensorType((2, 6, 2, 15, 12), dtype="float32"), ) @@ -1851,7 +1821,7 @@ def test_batch_flatten_op_correctness(): assert relax.op.nn.batch_flatten(x).op == Op.get("relax.nn.batch_flatten") -def test_batch_flatten_infer_struct_info(): +def test_batch_flatten_infer_ty(): bb = relax.BlockBuilder() vdev0 = VDevice("llvm") x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) @@ -1861,21 +1831,15 @@ def test_batch_flatten_infer_struct_info(): x4 = relax.Var("x", R.Tensor((10, 20), "float32")) x5 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32", vdev0)) - _check_inference(bb, relax.op.nn.batch_flatten(x0), relax.TensorStructInfo((2, 60), "float32")) - _check_inference( - bb, relax.op.nn.batch_flatten(x5), relax.TensorStructInfo((2, 60), "float32", vdev0) - ) - _check_inference( - bb, relax.op.nn.batch_flatten(x1), relax.TensorStructInfo(dtype="float32", ndim=2) - ) - _check_inference( - bb, relax.op.nn.batch_flatten(x2), relax.TensorStructInfo(dtype="float32", ndim=2) - ) - _check_inference(bb, relax.op.nn.batch_flatten(x3), relax.TensorStructInfo((2, 60), dtype="")) - _check_inference(bb, relax.op.nn.batch_flatten(x4), relax.TensorStructInfo((10, 20), "float32")) + _check_inference(bb, relax.op.nn.batch_flatten(x0), relax.TensorType((2, 60), "float32")) + _check_inference(bb, relax.op.nn.batch_flatten(x5), relax.TensorType((2, 60), "float32", vdev0)) + _check_inference(bb, relax.op.nn.batch_flatten(x1), relax.TensorType(dtype="float32", ndim=2)) + _check_inference(bb, relax.op.nn.batch_flatten(x2), relax.TensorType(dtype="float32", ndim=2)) + _check_inference(bb, relax.op.nn.batch_flatten(x3), relax.TensorType((2, 60), dtype="")) + _check_inference(bb, relax.op.nn.batch_flatten(x4), relax.TensorType((10, 20), "float32")) -def test_batch_flatten_infer_struct_info_shape_symbolic(): +def test_batch_flatten_infer_ty_shape_symbolic(): bb = relax.BlockBuilder() m = tirx.Var("m", "int64") n = tirx.Var("n", "int64") @@ -1884,15 +1848,11 @@ def test_batch_flatten_infer_struct_info_shape_symbolic(): x0 = relax.Var("x", R.Tensor((m, n, h, w), "float32")) x1 = relax.Var("x", R.Tensor((4, n, 8, 8), "float32")) - _check_inference( - bb, relax.op.nn.batch_flatten(x0), relax.TensorStructInfo((m, n * h * w), "float32") - ) - _check_inference( - bb, relax.op.nn.batch_flatten(x1), relax.TensorStructInfo((4, n * 8 * 8), "float32") - ) + _check_inference(bb, relax.op.nn.batch_flatten(x0), relax.TensorType((m, n * h * w), "float32")) + _check_inference(bb, relax.op.nn.batch_flatten(x1), relax.TensorType((4, n * 8 * 8), "float32")) -def test_batch_flatten_infer_struct_info_wrong_ndim(): +def test_batch_flatten_infer_ty_wrong_ndim(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((3,), "float32")) diff --git a/tests/python/relax/test_op_nn_convolution.py b/tests/python/relax/test_op_nn_convolution.py index 2729b9d2f802..31f5a15059bd 100644 --- a/tests/python/relax/test_op_nn_convolution.py +++ b/tests/python/relax/test_op_nn_convolution.py @@ -45,12 +45,12 @@ def test_conv3d_op_correctness(): assert relax.op.nn.conv3d_transpose(x, wt).op == Op.get("relax.nn.conv3d_transpose") -def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): +def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_ty: relax.Type): ret = bb.normalize(call) - tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) + tvm.ir.assert_structural_equal(ret.ty, expected_ty) -def test_conv1d_infer_struct_info(): +def test_conv1d_infer_ty(): bb = relax.BlockBuilder() vdev0 = VDevice("llvm") x0 = relax.Var("x", R.Tensor((2, 3, 28), "float32")) @@ -67,81 +67,71 @@ def test_conv1d_infer_struct_info(): w4 = relax.Var("w", R.Tensor((48, 4, 3, 16), "float32")) w5 = relax.Var("w", R.Tensor((4, 3, 3), "float32", vdev0)) - _check_inference(bb, relax.op.nn.conv1d(x0, w0), relax.TensorStructInfo((2, 4, 26), "float32")) - _check_inference( - bb, relax.op.nn.conv1d(x6, w5), relax.TensorStructInfo((2, 4, 26), "float32", vdev0) - ) + _check_inference(bb, relax.op.nn.conv1d(x0, w0), relax.TensorType((2, 4, 26), "float32")) + _check_inference(bb, relax.op.nn.conv1d(x6, w5), relax.TensorType((2, 4, 26), "float32", vdev0)) _check_inference( bb, relax.op.nn.conv1d(x0, w0, out_dtype="float16"), - relax.TensorStructInfo((2, 4, 26), "float16"), + relax.TensorType((2, 4, 26), "float16"), ) _check_inference( - bb, relax.op.nn.conv1d(x0, w0, padding=1), relax.TensorStructInfo((2, 4, 28), "float32") + bb, relax.op.nn.conv1d(x0, w0, padding=1), relax.TensorType((2, 4, 28), "float32") ) _check_inference( bb, relax.op.nn.conv1d(x0, w0, padding=[1, 3]), - relax.TensorStructInfo((2, 4, 30), "float32"), + relax.TensorType((2, 4, 30), "float32"), ) _check_inference( bb, relax.op.nn.conv1d(x0, w0, strides=2), - relax.TensorStructInfo((2, 4, 13), "float32"), + relax.TensorType((2, 4, 13), "float32"), ) _check_inference( bb, relax.op.nn.conv1d(x0, w0, strides=(2,)), - relax.TensorStructInfo((2, 4, 13), "float32"), + relax.TensorType((2, 4, 13), "float32"), ) _check_inference( bb, relax.op.nn.conv1d(x0, w0, dilation=2), - relax.TensorStructInfo((2, 4, 24), "float32"), + relax.TensorType((2, 4, 24), "float32"), ) _check_inference( bb, relax.op.nn.conv1d(x0, w0, dilation=(2,)), - relax.TensorStructInfo((2, 4, 24), "float32"), + relax.TensorType((2, 4, 24), "float32"), ) _check_inference( bb, relax.op.nn.conv1d(x1, w0, data_layout="NWC"), - relax.TensorStructInfo((2, 26, 4), "float32"), + relax.TensorType((2, 26, 4), "float32"), ) _check_inference( bb, relax.op.nn.conv1d(x0, w0, out_layout="NWC"), - relax.TensorStructInfo((2, 26, 4), "float32"), + relax.TensorType((2, 26, 4), "float32"), ) _check_inference( bb, relax.op.nn.conv1d(x0, w1, kernel_layout="IOW"), - relax.TensorStructInfo((2, 4, 26), "float32"), + relax.TensorType((2, 4, 26), "float32"), ) _check_inference( bb, relax.op.nn.conv1d( x5, w4, data_layout="NCW16c", kernel_layout="OIW16i", out_layout="NWC16c" ), - relax.TensorStructInfo((2, 26, 3, 16), "float32"), - ) - _check_inference( - bb, relax.op.nn.conv1d(x2, w0), relax.TensorStructInfo(dtype="float32", ndim=3) - ) - _check_inference( - bb, relax.op.nn.conv1d(x3, w0), relax.TensorStructInfo(dtype="float32", ndim=3) - ) - _check_inference( - bb, relax.op.nn.conv1d(x0, w2), relax.TensorStructInfo(dtype="float32", ndim=3) - ) - _check_inference( - bb, relax.op.nn.conv1d(x0, w3), relax.TensorStructInfo(dtype="float32", ndim=3) + relax.TensorType((2, 26, 3, 16), "float32"), ) - _check_inference(bb, relax.op.nn.conv1d(x4, w0), relax.TensorStructInfo(dtype="", ndim=3)) + _check_inference(bb, relax.op.nn.conv1d(x2, w0), relax.TensorType(dtype="float32", ndim=3)) + _check_inference(bb, relax.op.nn.conv1d(x3, w0), relax.TensorType(dtype="float32", ndim=3)) + _check_inference(bb, relax.op.nn.conv1d(x0, w2), relax.TensorType(dtype="float32", ndim=3)) + _check_inference(bb, relax.op.nn.conv1d(x0, w3), relax.TensorType(dtype="float32", ndim=3)) + _check_inference(bb, relax.op.nn.conv1d(x4, w0), relax.TensorType(dtype="", ndim=3)) -def test_conv1d_infer_struct_info_shape_symbolic(): +def test_conv1d_infer_ty_shape_symbolic(): bb = relax.BlockBuilder() n = tirx.Var("n", "int64") c = tirx.Var("c", "int64") @@ -159,58 +149,58 @@ def test_conv1d_infer_struct_info_shape_symbolic(): _check_inference( bb, relax.op.nn.conv1d(x0, w0), - relax.TensorStructInfo((n, ko, iw + 1 - kw), "float32"), + relax.TensorType((n, ko, iw + 1 - kw), "float32"), ) _check_inference( bb, relax.op.nn.conv1d(x0, w1), - relax.TensorStructInfo((n, ko, iw + 1 - kw), "float32"), + relax.TensorType((n, ko, iw + 1 - kw), "float32"), ) _check_inference( bb, relax.op.nn.conv1d(x1, w2, data_layout="NCW16c", kernel_layout="OIW16i", out_layout="NCW"), - relax.TensorStructInfo((n, ko, iw + 1 - kw), "float32"), + relax.TensorType((n, ko, iw + 1 - kw), "float32"), ) _check_inference( bb, relax.op.nn.conv1d(x0, w0, strides=2, padding=1, dilation=2), - relax.TensorStructInfo( + relax.TensorType( (n, ko, tvm.tirx.floordiv(iw + 3, 2) + 1 - kw), "float32", ), ) -def test_conv1d_infer_struct_info_shape_var(): +def test_conv1d_infer_ty_shape_var(): bb = relax.BlockBuilder() - s0 = relax.Var("s", relax.ShapeStructInfo(ndim=3)) - s1 = relax.Var("s", relax.ShapeStructInfo(ndim=4)) - s2 = relax.Var("s", relax.ShapeStructInfo(ndim=3)) - s3 = relax.Var("s", relax.ShapeStructInfo()) - x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) - x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) - x2 = relax.Var("x", relax.TensorStructInfo(s3, "float32")) - w = relax.Var("w", relax.TensorStructInfo(s2, "float32")) + s0 = relax.Var("s", relax.ShapeType(ndim=3)) + s1 = relax.Var("s", relax.ShapeType(ndim=4)) + s2 = relax.Var("s", relax.ShapeType(ndim=3)) + s3 = relax.Var("s", relax.ShapeType()) + x0 = relax.Var("x", relax.TensorType(s0, "float32")) + x1 = relax.Var("x", relax.TensorType(s1, "float32")) + x2 = relax.Var("x", relax.TensorType(s3, "float32")) + w = relax.Var("w", relax.TensorType(s2, "float32")) - _check_inference(bb, relax.op.nn.conv1d(x0, w), relax.TensorStructInfo(dtype="float32", ndim=3)) + _check_inference(bb, relax.op.nn.conv1d(x0, w), relax.TensorType(dtype="float32", ndim=3)) _check_inference( bb, relax.op.nn.conv1d(x1, w, data_layout="NCW16c"), - relax.TensorStructInfo(dtype="float32", ndim=4), + relax.TensorType(dtype="float32", ndim=4), ) _check_inference( bb, relax.op.nn.conv1d(x0, w, out_layout="NCW16c"), - relax.TensorStructInfo(dtype="float32", ndim=4), + relax.TensorType(dtype="float32", ndim=4), ) _check_inference( bb, relax.op.nn.conv1d(x2, w), - relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorType(dtype="float32", ndim=3), ) -def test_conv1d_infer_struct_info_groups(): +def test_conv1d_infer_ty_groups(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 128, 28), "float32")) x1 = relax.Var("x", R.Tensor((2, 8, 28, 16), "float32")) @@ -218,21 +208,21 @@ def test_conv1d_infer_struct_info_groups(): w1 = relax.Var("w", R.Tensor((48, 2, 3, 8), "float32")) _check_inference( - bb, relax.op.nn.conv1d(x0, w0, groups=8), relax.TensorStructInfo((2, 48, 26), "float32") + bb, relax.op.nn.conv1d(x0, w0, groups=8), relax.TensorType((2, 48, 26), "float32") ) _check_inference( bb, relax.op.nn.conv1d(x0, w1, kernel_layout="OIW8i", groups=8), - relax.TensorStructInfo((2, 48, 26), "float32"), + relax.TensorType((2, 48, 26), "float32"), ) _check_inference( bb, relax.op.nn.conv1d(x1, w0, data_layout="NCW16c", groups=8), - relax.TensorStructInfo((2, 3, 26, 16), "float32"), + relax.TensorType((2, 3, 26, 16), "float32"), ) -def test_conv1d_infer_struct_info_symbolic_groups(): +def test_conv1d_infer_ty_symbolic_groups(): bb = relax.BlockBuilder() n = tirx.Var("n", "int64") ic = tirx.Var("c", "int64") @@ -244,14 +234,14 @@ def test_conv1d_infer_struct_info_symbolic_groups(): _check_inference( bb, relax.op.nn.conv1d(x, w0, groups=4), - relax.TensorStructInfo((n, oc * 4, 26), "float32"), + relax.TensorType((n, oc * 4, 26), "float32"), ) _check_inference( - bb, relax.op.nn.conv1d(x, w1, groups=4), relax.TensorStructInfo((n, oc, 26), "float32") + bb, relax.op.nn.conv1d(x, w1, groups=4), relax.TensorType((n, oc, 26), "float32") ) -def test_conv1d_infer_struct_info_input_channel_group_incompatible(): +def test_conv1d_infer_ty_input_channel_group_incompatible(): bb = relax.BlockBuilder() n = tirx.Var("n", "int64") ic = tirx.Var("c", "int64") @@ -267,7 +257,7 @@ def test_conv1d_infer_struct_info_input_channel_group_incompatible(): bb.normalize(relax.op.nn.conv1d(x1, w1, groups=6)) -def test_conv1d_infer_struct_info_output_channel_group_incompatible(): +def test_conv1d_infer_ty_output_channel_group_incompatible(): bb = relax.BlockBuilder() n = tirx.Var("n", "int64") ic = tirx.Var("c", "int64") @@ -293,7 +283,7 @@ def test_conv1d_non_positive_group(): relax.op.nn.conv1d(x, w, groups=-2) -def test_conv1d_infer_struct_info_more_input_dtype(): +def test_conv1d_infer_ty_more_input_dtype(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3, 28), "float16")) w0 = relax.Var("w", R.Tensor((4, 3, 3), "float16")) @@ -304,13 +294,13 @@ def test_conv1d_infer_struct_info_more_input_dtype(): x3 = relax.Var("x", R.Tensor((2, 3, 28), "int32")) w3 = relax.Var("w", R.Tensor((4, 3, 3), "int32")) - _check_inference(bb, relax.op.nn.conv1d(x0, w0), relax.TensorStructInfo((2, 4, 26), "float16")) - _check_inference(bb, relax.op.nn.conv1d(x1, w1), relax.TensorStructInfo((2, 4, 26), "float64")) - _check_inference(bb, relax.op.nn.conv1d(x2, w2), relax.TensorStructInfo((2, 4, 26), "int8")) - _check_inference(bb, relax.op.nn.conv1d(x3, w3), relax.TensorStructInfo((2, 4, 26), "int32")) + _check_inference(bb, relax.op.nn.conv1d(x0, w0), relax.TensorType((2, 4, 26), "float16")) + _check_inference(bb, relax.op.nn.conv1d(x1, w1), relax.TensorType((2, 4, 26), "float64")) + _check_inference(bb, relax.op.nn.conv1d(x2, w2), relax.TensorType((2, 4, 26), "int8")) + _check_inference(bb, relax.op.nn.conv1d(x3, w3), relax.TensorType((2, 4, 26), "int32")) -def test_conv1d_infer_struct_info_mixed_precision(): +def test_conv1d_infer_ty_mixed_precision(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3, 28), "float16")) w0 = relax.Var("w", R.Tensor((4, 3, 3), "float16")) @@ -322,17 +312,17 @@ def test_conv1d_infer_struct_info_mixed_precision(): _check_inference( bb, relax.op.nn.conv1d(x0, w0, out_dtype="float32"), - relax.TensorStructInfo((2, 4, 26), "float32"), + relax.TensorType((2, 4, 26), "float32"), ) _check_inference( bb, relax.op.nn.conv1d(x1, w1, out_dtype="int32"), - relax.TensorStructInfo((2, 4, 26), "int32"), + relax.TensorType((2, 4, 26), "int32"), ) _check_inference( bb, relax.op.nn.conv1d(x2, w2, out_dtype="float32"), - relax.TensorStructInfo((2, 4, 26), "float32"), + relax.TensorType((2, 4, 26), "float32"), ) @@ -371,7 +361,7 @@ def test_conv1d_wrong_strides_padding_dilation_length(): relax.op.nn.conv1d(x, w, dilation=(1, 2)) -def test_conv1d_infer_struct_info_wrong_layout_string(): +def test_conv1d_infer_ty_wrong_layout_string(): bb = relax.BlockBuilder() x = relax.Var("x", R.Tensor((2, 3, 28), "float32")) w = relax.Var("w", R.Tensor((4, 3, 3), "float32")) @@ -413,12 +403,12 @@ def test_conv1d_wrong_input_ndim(): bb.normalize(relax.op.nn.conv1d(x2, w0)) -def test_conv1d_infer_struct_info_wrong_input_type(): +def test_conv1d_infer_ty_wrong_input_type(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3, 28), "float32")) - x1 = relax.Var("x", relax.ShapeStructInfo((2, 3, 28))) + x1 = relax.Var("x", relax.ShapeType((2, 3, 28))) w0 = relax.Var("w", R.Tensor((4, 3, 3), "float32")) - w1 = relax.Var("w", relax.FuncStructInfo([], R.Tensor((4, 3, 3), "float32"))) + w1 = relax.Var("w", relax.FuncType([], R.Tensor((4, 3, 3), "float32"))) with pytest.raises(TypeError): bb.normalize(relax.op.nn.conv1d(x0, w1)) @@ -426,7 +416,7 @@ def test_conv1d_infer_struct_info_wrong_input_type(): bb.normalize(relax.op.nn.conv1d(x1, w0)) -def test_conv1d_transpose_infer_struct_info(): +def test_conv1d_transpose_infer_ty(): bb = relax.BlockBuilder() vdev0 = VDevice("llvm") x0 = relax.Var("x", R.Tensor((2, 3, 28), "float32")) @@ -444,88 +434,86 @@ def test_conv1d_transpose_infer_struct_info(): w5 = relax.Var("w", R.Tensor((3, 4, 3), "float32", vdev0)) _check_inference( - bb, relax.op.nn.conv1d_transpose(x0, w0), relax.TensorStructInfo((2, 4, 30), "float32") + bb, relax.op.nn.conv1d_transpose(x0, w0), relax.TensorType((2, 4, 30), "float32") ) _check_inference( bb, relax.op.nn.conv1d_transpose(x6, w5), - relax.TensorStructInfo((2, 4, 30), "float32", vdev0), + relax.TensorType((2, 4, 30), "float32", vdev0), ) _check_inference( bb, relax.op.nn.conv1d_transpose(x0, w0, out_dtype="float16"), - relax.TensorStructInfo((2, 4, 30), "float16"), + relax.TensorType((2, 4, 30), "float16"), ) _check_inference( bb, relax.op.nn.conv1d_transpose(x0, w0, padding=1), - relax.TensorStructInfo((2, 4, 28), "float32"), + relax.TensorType((2, 4, 28), "float32"), ) _check_inference( bb, relax.op.nn.conv1d_transpose(x0, w0, padding=[1, 3]), - relax.TensorStructInfo((2, 4, 26), "float32"), + relax.TensorType((2, 4, 26), "float32"), ) _check_inference( bb, relax.op.nn.conv1d_transpose(x0, w0, strides=3, output_padding=1), - relax.TensorStructInfo((2, 4, 85), "float32"), + relax.TensorType((2, 4, 85), "float32"), ) _check_inference( bb, relax.op.nn.conv1d_transpose(x0, w0, strides=2), - relax.TensorStructInfo((2, 4, 57), "float32"), + relax.TensorType((2, 4, 57), "float32"), ) _check_inference( bb, relax.op.nn.conv1d_transpose(x0, w0, dilation=2), - relax.TensorStructInfo((2, 4, 32), "float32"), + relax.TensorType((2, 4, 32), "float32"), ) _check_inference( bb, relax.op.nn.conv1d_transpose(x0, w0, dilation=(2,)), - relax.TensorStructInfo((2, 4, 32), "float32"), + relax.TensorType((2, 4, 32), "float32"), ) _check_inference( bb, relax.op.nn.conv1d_transpose(x1, w0, data_layout="NWC"), - relax.TensorStructInfo((2, 30, 4), "float32"), + relax.TensorType((2, 30, 4), "float32"), ) _check_inference( bb, relax.op.nn.conv1d_transpose(x0, w0, out_layout="NWC"), - relax.TensorStructInfo((2, 30, 4), "float32"), + relax.TensorType((2, 30, 4), "float32"), ) _check_inference( bb, relax.op.nn.conv1d_transpose(x0, w1, kernel_layout="OIW"), - relax.TensorStructInfo((2, 4, 30), "float32"), + relax.TensorType((2, 4, 30), "float32"), ) _check_inference( bb, relax.op.nn.conv1d_transpose( x5, w4, data_layout="NCW16c", kernel_layout="IOW16i", out_layout="NWC16c" ), - relax.TensorStructInfo((2, 30, 3, 16), "float32"), + relax.TensorType((2, 30, 3, 16), "float32"), ) _check_inference( - bb, relax.op.nn.conv1d_transpose(x2, w0), relax.TensorStructInfo(dtype="float32", ndim=3) + bb, relax.op.nn.conv1d_transpose(x2, w0), relax.TensorType(dtype="float32", ndim=3) ) _check_inference( - bb, relax.op.nn.conv1d_transpose(x3, w0), relax.TensorStructInfo(dtype="float32", ndim=3) + bb, relax.op.nn.conv1d_transpose(x3, w0), relax.TensorType(dtype="float32", ndim=3) ) _check_inference( - bb, relax.op.nn.conv1d_transpose(x0, w2), relax.TensorStructInfo(dtype="float32", ndim=3) + bb, relax.op.nn.conv1d_transpose(x0, w2), relax.TensorType(dtype="float32", ndim=3) ) _check_inference( - bb, relax.op.nn.conv1d_transpose(x0, w3), relax.TensorStructInfo(dtype="float32", ndim=3) - ) - _check_inference( - bb, relax.op.nn.conv1d_transpose(x4, w0), relax.TensorStructInfo(dtype="", ndim=3) + bb, relax.op.nn.conv1d_transpose(x0, w3), relax.TensorType(dtype="float32", ndim=3) ) + _check_inference(bb, relax.op.nn.conv1d_transpose(x4, w0), relax.TensorType(dtype="", ndim=3)) -def test_conv1d_transpose_infer_struct_info_shape_symbolic(): +def test_conv1d_transpose_infer_ty_shape_symbolic(): bb = relax.BlockBuilder() n = tirx.Var("n", "int64") c = tirx.Var("c", "int64") @@ -543,60 +531,60 @@ def test_conv1d_transpose_infer_struct_info_shape_symbolic(): _check_inference( bb, relax.op.nn.conv1d_transpose(x0, w0), - relax.TensorStructInfo((n, ko, iw + kw - 1), "float32"), + relax.TensorType((n, ko, iw + kw - 1), "float32"), ) _check_inference( bb, relax.op.nn.conv1d_transpose(x0, w1), - relax.TensorStructInfo((n, ko, iw + kw - 1), "float32"), + relax.TensorType((n, ko, iw + kw - 1), "float32"), ) _check_inference( bb, relax.op.nn.conv1d_transpose( x1, w2, data_layout="NCW16c", kernel_layout="IOW16i", out_layout="NCW" ), - relax.TensorStructInfo((n, ko, iw + kw - 1), "float32"), + relax.TensorType((n, ko, iw + kw - 1), "float32"), ) _check_inference( bb, relax.op.nn.conv1d_transpose(x0, w0, strides=2, padding=1, dilation=2, output_padding=1), - relax.TensorStructInfo( + relax.TensorType( (n, ko, iw * 2 + kw * 2 - 4), "float32", ), ) -def test_conv1d_transpose_infer_struct_info_shape_var(): +def test_conv1d_transpose_infer_ty_shape_var(): bb = relax.BlockBuilder() - s0 = relax.Var("s", relax.ShapeStructInfo(ndim=3)) - s1 = relax.Var("s", relax.ShapeStructInfo(ndim=4)) - s2 = relax.Var("s", relax.ShapeStructInfo(ndim=3)) - s3 = relax.Var("s", relax.ShapeStructInfo()) - x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) - x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) - x2 = relax.Var("x", relax.TensorStructInfo(s3, "float32")) - w = relax.Var("w", relax.TensorStructInfo(s2, "float32")) + s0 = relax.Var("s", relax.ShapeType(ndim=3)) + s1 = relax.Var("s", relax.ShapeType(ndim=4)) + s2 = relax.Var("s", relax.ShapeType(ndim=3)) + s3 = relax.Var("s", relax.ShapeType()) + x0 = relax.Var("x", relax.TensorType(s0, "float32")) + x1 = relax.Var("x", relax.TensorType(s1, "float32")) + x2 = relax.Var("x", relax.TensorType(s3, "float32")) + w = relax.Var("w", relax.TensorType(s2, "float32")) - _check_inference(bb, relax.op.nn.conv1d(x0, w), relax.TensorStructInfo(dtype="float32", ndim=3)) + _check_inference(bb, relax.op.nn.conv1d(x0, w), relax.TensorType(dtype="float32", ndim=3)) _check_inference( bb, relax.op.nn.conv1d(x1, w, data_layout="NCW16c"), - relax.TensorStructInfo(dtype="float32", ndim=4), + relax.TensorType(dtype="float32", ndim=4), ) _check_inference( bb, relax.op.nn.conv1d(x0, w, out_layout="NCW16c"), - relax.TensorStructInfo(dtype="float32", ndim=4), + relax.TensorType(dtype="float32", ndim=4), ) _check_inference( bb, relax.op.nn.conv1d(x2, w), - relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorType(dtype="float32", ndim=3), ) -def test_conv1d_transpose_infer_struct_info_groups(): +def test_conv1d_transpose_infer_ty_groups(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 128, 28), "float32")) x1 = relax.Var("x", R.Tensor((2, 8, 28, 16), "float32")) @@ -606,21 +594,21 @@ def test_conv1d_transpose_infer_struct_info_groups(): _check_inference( bb, relax.op.nn.conv1d_transpose(x0, w0, groups=8), - relax.TensorStructInfo((2, 48, 30), "float32"), + relax.TensorType((2, 48, 30), "float32"), ) _check_inference( bb, relax.op.nn.conv1d_transpose(x0, w1, kernel_layout="IOW8i", groups=8), - relax.TensorStructInfo((2, 48, 30), "float32"), + relax.TensorType((2, 48, 30), "float32"), ) _check_inference( bb, relax.op.nn.conv1d_transpose(x1, w0, data_layout="NCW16c", groups=8), - relax.TensorStructInfo((2, 3, 30, 16), "float32"), + relax.TensorType((2, 3, 30, 16), "float32"), ) -def test_conv1d_transpose_infer_struct_info_symbolic_groups(): +def test_conv1d_transpose_infer_ty_symbolic_groups(): bb = relax.BlockBuilder() n = tirx.Var("n", "int64") ic = tirx.Var("c", "int64") @@ -631,11 +619,11 @@ def test_conv1d_transpose_infer_struct_info_symbolic_groups(): _check_inference( bb, relax.op.nn.conv1d_transpose(x, w0, groups=4), - relax.TensorStructInfo((n, oc * 4, 30), "float32"), + relax.TensorType((n, oc * 4, 30), "float32"), ) -def test_conv1d_transpose_infer_struct_info_input_channel_group_incompatible(): +def test_conv1d_transpose_infer_ty_input_channel_group_incompatible(): bb = relax.BlockBuilder() n = tirx.Var("n", "int64") ic = tirx.Var("c", "int64") @@ -661,7 +649,7 @@ def test_conv1d_transpose_non_positive_group(): relax.op.nn.conv1d_transpose(x, w, groups=-2) -def test_conv1d_transpose_infer_struct_info_more_input_dtype(): +def test_conv1d_transpose_infer_ty_more_input_dtype(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3, 28), "float16")) w0 = relax.Var("w", R.Tensor((3, 4, 3), "float16")) @@ -673,16 +661,14 @@ def test_conv1d_transpose_infer_struct_info_more_input_dtype(): w3 = relax.Var("w", R.Tensor((3, 4, 3), "int32")) _check_inference( - bb, relax.op.nn.conv1d_transpose(x0, w0), relax.TensorStructInfo((2, 4, 30), "float16") + bb, relax.op.nn.conv1d_transpose(x0, w0), relax.TensorType((2, 4, 30), "float16") ) _check_inference( - bb, relax.op.nn.conv1d_transpose(x1, w1), relax.TensorStructInfo((2, 4, 30), "float64") + bb, relax.op.nn.conv1d_transpose(x1, w1), relax.TensorType((2, 4, 30), "float64") ) + _check_inference(bb, relax.op.nn.conv1d_transpose(x2, w2), relax.TensorType((2, 4, 30), "int8")) _check_inference( - bb, relax.op.nn.conv1d_transpose(x2, w2), relax.TensorStructInfo((2, 4, 30), "int8") - ) - _check_inference( - bb, relax.op.nn.conv1d_transpose(x3, w3), relax.TensorStructInfo((2, 4, 30), "int32") + bb, relax.op.nn.conv1d_transpose(x3, w3), relax.TensorType((2, 4, 30), "int32") ) @@ -729,7 +715,7 @@ def test_conv1d_transpose_wrong_strides_padding_dilation_length(): relax.op.nn.conv1d_transpose(x, w, dilation=(1, 2)) -def test_conv1d_transpose_infer_struct_info_wrong_layout_string(): +def test_conv1d_transpose_infer_ty_wrong_layout_string(): bb = relax.BlockBuilder() x = relax.Var("x", R.Tensor((2, 3, 28), "float32")) w = relax.Var("w", R.Tensor((3, 4, 3), "float32")) @@ -771,12 +757,12 @@ def test_conv1d_transpose_wrong_input_ndim(): bb.normalize(relax.op.nn.conv1d_transpose(x2, w0)) -def test_conv1d_transpose_infer_struct_info_wrong_input_type(): +def test_conv1d_transpose_infer_ty_wrong_input_type(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3, 28), "float32")) - x1 = relax.Var("x", relax.ShapeStructInfo((2, 3, 28))) + x1 = relax.Var("x", relax.ShapeType((2, 3, 28))) w0 = relax.Var("w", R.Tensor((3, 4, 3), "float32")) - w1 = relax.Var("w", relax.FuncStructInfo([], R.Tensor((3, 4, 3), "float32"))) + w1 = relax.Var("w", relax.FuncType([], R.Tensor((3, 4, 3), "float32"))) with pytest.raises(TypeError): bb.normalize(relax.op.nn.conv1d_transpose(x0, w1)) @@ -784,7 +770,7 @@ def test_conv1d_transpose_infer_struct_info_wrong_input_type(): bb.normalize(relax.op.nn.conv1d_transpose(x1, w0)) -def test_conv1d_transpose_infer_struct_info_mixed_precision(): +def test_conv1d_transpose_infer_ty_mixed_precision(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3, 28), "float16")) w0 = relax.Var("w", R.Tensor((3, 4, 3), "float16")) @@ -794,16 +780,16 @@ def test_conv1d_transpose_infer_struct_info_mixed_precision(): _check_inference( bb, relax.op.nn.conv1d_transpose(x0, w0, out_dtype="float32"), - relax.TensorStructInfo((2, 4, 30), "float32"), + relax.TensorType((2, 4, 30), "float32"), ) _check_inference( bb, relax.op.nn.conv1d_transpose(x1, w1, out_dtype="int32"), - relax.TensorStructInfo((2, 4, 30), "int32"), + relax.TensorType((2, 4, 30), "int32"), ) -def test_conv2d_infer_struct_info(): +def test_conv2d_infer_ty(): bb = relax.BlockBuilder() vdev0 = VDevice("llvm") x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) @@ -820,88 +806,78 @@ def test_conv2d_infer_struct_info(): w4 = relax.Var("w", R.Tensor((48, 4, 3, 3, 16), "float32")) w5 = relax.Var("w", R.Tensor((4, 3, 3, 3), "float32", vdev0)) + _check_inference(bb, relax.op.nn.conv2d(x0, w0), relax.TensorType((2, 4, 26, 26), "float32")) _check_inference( - bb, relax.op.nn.conv2d(x0, w0), relax.TensorStructInfo((2, 4, 26, 26), "float32") - ) - _check_inference( - bb, relax.op.nn.conv2d(x6, w5), relax.TensorStructInfo((2, 4, 26, 26), "float32", vdev0) + bb, relax.op.nn.conv2d(x6, w5), relax.TensorType((2, 4, 26, 26), "float32", vdev0) ) _check_inference( bb, relax.op.nn.conv2d(x0, w0, out_dtype="float16"), - relax.TensorStructInfo((2, 4, 26, 26), "float16"), + relax.TensorType((2, 4, 26, 26), "float16"), ) _check_inference( - bb, relax.op.nn.conv2d(x0, w0, padding=1), relax.TensorStructInfo((2, 4, 28, 28), "float32") + bb, relax.op.nn.conv2d(x0, w0, padding=1), relax.TensorType((2, 4, 28, 28), "float32") ) _check_inference( bb, relax.op.nn.conv2d(x0, w0, padding=[1, 2]), - relax.TensorStructInfo((2, 4, 28, 30), "float32"), + relax.TensorType((2, 4, 28, 30), "float32"), ) _check_inference( bb, relax.op.nn.conv2d(x0, w0, padding=[1, 2, 3, 4]), - relax.TensorStructInfo((2, 4, 30, 32), "float32"), + relax.TensorType((2, 4, 30, 32), "float32"), ) _check_inference( bb, relax.op.nn.conv2d(x0, w0, strides=2), - relax.TensorStructInfo((2, 4, 13, 13), "float32"), + relax.TensorType((2, 4, 13, 13), "float32"), ) _check_inference( bb, relax.op.nn.conv2d(x0, w0, strides=(2, 3)), - relax.TensorStructInfo((2, 4, 13, 9), "float32"), + relax.TensorType((2, 4, 13, 9), "float32"), ) _check_inference( bb, relax.op.nn.conv2d(x0, w0, dilation=2), - relax.TensorStructInfo((2, 4, 24, 24), "float32"), + relax.TensorType((2, 4, 24, 24), "float32"), ) _check_inference( bb, relax.op.nn.conv2d(x0, w0, dilation=(2, 1)), - relax.TensorStructInfo((2, 4, 24, 26), "float32"), + relax.TensorType((2, 4, 24, 26), "float32"), ) _check_inference( bb, relax.op.nn.conv2d(x1, w0, data_layout="NHWC"), - relax.TensorStructInfo((2, 26, 26, 4), "float32"), + relax.TensorType((2, 26, 26, 4), "float32"), ) _check_inference( bb, relax.op.nn.conv2d(x0, w0, out_layout="NHWC"), - relax.TensorStructInfo((2, 26, 26, 4), "float32"), + relax.TensorType((2, 26, 26, 4), "float32"), ) _check_inference( bb, relax.op.nn.conv2d(x0, w1, kernel_layout="IOHW"), - relax.TensorStructInfo((2, 4, 26, 26), "float32"), + relax.TensorType((2, 4, 26, 26), "float32"), ) _check_inference( bb, relax.op.nn.conv2d( x5, w4, data_layout="NCHW16c", kernel_layout="OIHW16i", out_layout="NHWC16c" ), - relax.TensorStructInfo((2, 26, 26, 3, 16), "float32"), - ) - _check_inference( - bb, relax.op.nn.conv2d(x2, w0), relax.TensorStructInfo(dtype="float32", ndim=4) - ) - _check_inference( - bb, relax.op.nn.conv2d(x3, w0), relax.TensorStructInfo(dtype="float32", ndim=4) + relax.TensorType((2, 26, 26, 3, 16), "float32"), ) - _check_inference( - bb, relax.op.nn.conv2d(x0, w2), relax.TensorStructInfo(dtype="float32", ndim=4) - ) - _check_inference( - bb, relax.op.nn.conv2d(x0, w3), relax.TensorStructInfo(dtype="float32", ndim=4) - ) - _check_inference(bb, relax.op.nn.conv2d(x4, w0), relax.TensorStructInfo(dtype="", ndim=4)) + _check_inference(bb, relax.op.nn.conv2d(x2, w0), relax.TensorType(dtype="float32", ndim=4)) + _check_inference(bb, relax.op.nn.conv2d(x3, w0), relax.TensorType(dtype="float32", ndim=4)) + _check_inference(bb, relax.op.nn.conv2d(x0, w2), relax.TensorType(dtype="float32", ndim=4)) + _check_inference(bb, relax.op.nn.conv2d(x0, w3), relax.TensorType(dtype="float32", ndim=4)) + _check_inference(bb, relax.op.nn.conv2d(x4, w0), relax.TensorType(dtype="", ndim=4)) -def test_conv2d_infer_struct_info_shape_symbolic(): +def test_conv2d_infer_ty_shape_symbolic(): bb = relax.BlockBuilder() n = tirx.Var("n", "int64") c = tirx.Var("c", "int64") @@ -921,60 +897,60 @@ def test_conv2d_infer_struct_info_shape_symbolic(): _check_inference( bb, relax.op.nn.conv2d(x0, w0), - relax.TensorStructInfo((n, ko, ih + 1 - kh, iw + 1 - kw), "float32"), + relax.TensorType((n, ko, ih + 1 - kh, iw + 1 - kw), "float32"), ) _check_inference( bb, relax.op.nn.conv2d(x0, w1), - relax.TensorStructInfo((n, ko, ih + 1 - kh, iw + 1 - kw), "float32"), + relax.TensorType((n, ko, ih + 1 - kh, iw + 1 - kw), "float32"), ) _check_inference( bb, relax.op.nn.conv2d( x1, w2, data_layout="NCHW16c", kernel_layout="OIHW16i", out_layout="NCHW" ), - relax.TensorStructInfo((n, ko, ih + 1 - kh, iw + 1 - kw), "float32"), + relax.TensorType((n, ko, ih + 1 - kh, iw + 1 - kw), "float32"), ) _check_inference( bb, relax.op.nn.conv2d(x0, w0, strides=(2, 2), padding=(1, 1), dilation=(2, 2)), - relax.TensorStructInfo( + relax.TensorType( (n, ko, tvm.tirx.floordiv(ih + 3, 2) + 1 - kh, tvm.tirx.floordiv(iw + 3, 2) + 1 - kw), "float32", ), ) -def test_conv2d_infer_struct_info_shape_var(): +def test_conv2d_infer_ty_shape_var(): bb = relax.BlockBuilder() - s0 = relax.Var("s", relax.ShapeStructInfo(ndim=4)) - s1 = relax.Var("s", relax.ShapeStructInfo(ndim=5)) - s2 = relax.Var("s", relax.ShapeStructInfo(ndim=4)) - s3 = relax.Var("s", relax.ShapeStructInfo()) - x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) - x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) - x2 = relax.Var("x", relax.TensorStructInfo(s3, "float32")) - w = relax.Var("w", relax.TensorStructInfo(s2, "float32")) + s0 = relax.Var("s", relax.ShapeType(ndim=4)) + s1 = relax.Var("s", relax.ShapeType(ndim=5)) + s2 = relax.Var("s", relax.ShapeType(ndim=4)) + s3 = relax.Var("s", relax.ShapeType()) + x0 = relax.Var("x", relax.TensorType(s0, "float32")) + x1 = relax.Var("x", relax.TensorType(s1, "float32")) + x2 = relax.Var("x", relax.TensorType(s3, "float32")) + w = relax.Var("w", relax.TensorType(s2, "float32")) - _check_inference(bb, relax.op.nn.conv2d(x0, w), relax.TensorStructInfo(dtype="float32", ndim=4)) + _check_inference(bb, relax.op.nn.conv2d(x0, w), relax.TensorType(dtype="float32", ndim=4)) _check_inference( bb, relax.op.nn.conv2d(x1, w, data_layout="NCHW16c"), - relax.TensorStructInfo(dtype="float32", ndim=5), + relax.TensorType(dtype="float32", ndim=5), ) _check_inference( bb, relax.op.nn.conv2d(x0, w, out_layout="NCHW16c"), - relax.TensorStructInfo(dtype="float32", ndim=5), + relax.TensorType(dtype="float32", ndim=5), ) _check_inference( bb, relax.op.nn.conv2d(x2, w), - relax.TensorStructInfo(dtype="float32", ndim=4), + relax.TensorType(dtype="float32", ndim=4), ) -def test_conv2d_infer_struct_info_groups(): +def test_conv2d_infer_ty_groups(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 128, 28, 28), "float32")) x1 = relax.Var("x", R.Tensor((2, 8, 28, 28, 16), "float32")) @@ -982,21 +958,21 @@ def test_conv2d_infer_struct_info_groups(): w1 = relax.Var("w", R.Tensor((48, 2, 3, 3, 8), "float32")) _check_inference( - bb, relax.op.nn.conv2d(x0, w0, groups=8), relax.TensorStructInfo((2, 48, 26, 26), "float32") + bb, relax.op.nn.conv2d(x0, w0, groups=8), relax.TensorType((2, 48, 26, 26), "float32") ) _check_inference( bb, relax.op.nn.conv2d(x0, w1, kernel_layout="OIHW8i", groups=8), - relax.TensorStructInfo((2, 48, 26, 26), "float32"), + relax.TensorType((2, 48, 26, 26), "float32"), ) _check_inference( bb, relax.op.nn.conv2d(x1, w0, data_layout="NCHW16c", groups=8), - relax.TensorStructInfo((2, 3, 26, 26, 16), "float32"), + relax.TensorType((2, 3, 26, 26, 16), "float32"), ) -def test_conv2d_infer_struct_info_symbolic_groups(): +def test_conv2d_infer_ty_symbolic_groups(): bb = relax.BlockBuilder() n = tirx.Var("n", "int64") ic = tirx.Var("c", "int64") @@ -1008,14 +984,14 @@ def test_conv2d_infer_struct_info_symbolic_groups(): _check_inference( bb, relax.op.nn.conv2d(x, w0, groups=4), - relax.TensorStructInfo((n, oc * 4, 26, 26), "float32"), + relax.TensorType((n, oc * 4, 26, 26), "float32"), ) _check_inference( - bb, relax.op.nn.conv2d(x, w1, groups=4), relax.TensorStructInfo((n, oc, 26, 26), "float32") + bb, relax.op.nn.conv2d(x, w1, groups=4), relax.TensorType((n, oc, 26, 26), "float32") ) -def test_conv2d_infer_struct_info_input_channel_group_incompatible(): +def test_conv2d_infer_ty_input_channel_group_incompatible(): bb = relax.BlockBuilder() n = tirx.Var("n", "int64") ic = tirx.Var("c", "int64") @@ -1031,7 +1007,7 @@ def test_conv2d_infer_struct_info_input_channel_group_incompatible(): bb.normalize(relax.op.nn.conv2d(x1, w1, groups=6)) -def test_conv2d_infer_struct_info_output_channel_group_incompatible(): +def test_conv2d_infer_ty_output_channel_group_incompatible(): bb = relax.BlockBuilder() n = tirx.Var("n", "int64") ic = tirx.Var("c", "int64") @@ -1057,7 +1033,7 @@ def test_conv2d_non_positive_group(): relax.op.nn.conv2d(x, w, groups=-2) -def test_conv2d_infer_struct_info_more_input_dtype(): +def test_conv2d_infer_ty_more_input_dtype(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float16")) w0 = relax.Var("w", R.Tensor((4, 3, 3, 3), "float16")) @@ -1068,19 +1044,13 @@ def test_conv2d_infer_struct_info_more_input_dtype(): x3 = relax.Var("x", R.Tensor((2, 3, 28, 28), "int32")) w3 = relax.Var("w", R.Tensor((4, 3, 3, 3), "int32")) - _check_inference( - bb, relax.op.nn.conv2d(x0, w0), relax.TensorStructInfo((2, 4, 26, 26), "float16") - ) - _check_inference( - bb, relax.op.nn.conv2d(x1, w1), relax.TensorStructInfo((2, 4, 26, 26), "float64") - ) - _check_inference(bb, relax.op.nn.conv2d(x2, w2), relax.TensorStructInfo((2, 4, 26, 26), "int8")) - _check_inference( - bb, relax.op.nn.conv2d(x3, w3), relax.TensorStructInfo((2, 4, 26, 26), "int32") - ) + _check_inference(bb, relax.op.nn.conv2d(x0, w0), relax.TensorType((2, 4, 26, 26), "float16")) + _check_inference(bb, relax.op.nn.conv2d(x1, w1), relax.TensorType((2, 4, 26, 26), "float64")) + _check_inference(bb, relax.op.nn.conv2d(x2, w2), relax.TensorType((2, 4, 26, 26), "int8")) + _check_inference(bb, relax.op.nn.conv2d(x3, w3), relax.TensorType((2, 4, 26, 26), "int32")) -def test_conv2d_infer_struct_info_mixed_precision(): +def test_conv2d_infer_ty_mixed_precision(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float16")) w0 = relax.Var("w", R.Tensor((4, 3, 3, 3), "float16")) @@ -1092,17 +1062,17 @@ def test_conv2d_infer_struct_info_mixed_precision(): _check_inference( bb, relax.op.nn.conv2d(x0, w0, out_dtype="float32"), - relax.TensorStructInfo((2, 4, 26, 26), "float32"), + relax.TensorType((2, 4, 26, 26), "float32"), ) _check_inference( bb, relax.op.nn.conv2d(x1, w1, out_dtype="int32"), - relax.TensorStructInfo((2, 4, 26, 26), "int32"), + relax.TensorType((2, 4, 26, 26), "int32"), ) _check_inference( bb, relax.op.nn.conv2d(x2, w2, out_dtype="float32"), - relax.TensorStructInfo((2, 4, 26, 26), "float32"), + relax.TensorType((2, 4, 26, 26), "float32"), ) @@ -1145,7 +1115,7 @@ def test_conv2d_wrong_strides_padding_dilation_length(): relax.op.nn.conv2d(x, w, dilation=(1, 2, 3)) -def test_conv2d_infer_struct_info_wrong_layout_string(): +def test_conv2d_infer_ty_wrong_layout_string(): bb = relax.BlockBuilder() x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) w = relax.Var("w", R.Tensor((4, 3, 3, 3), "float32")) @@ -1187,12 +1157,12 @@ def test_conv2d_wrong_input_ndim(): bb.normalize(relax.op.nn.conv2d(x2, w0)) -def test_conv2d_infer_struct_info_wrong_input_type(): +def test_conv2d_infer_ty_wrong_input_type(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) - x1 = relax.Var("x", relax.ShapeStructInfo((2, 3, 28, 28))) + x1 = relax.Var("x", relax.ShapeType((2, 3, 28, 28))) w0 = relax.Var("w", R.Tensor((4, 3, 3, 3), "float32")) - w1 = relax.Var("w", relax.FuncStructInfo([], R.Tensor((4, 3, 3, 3), "float32"))) + w1 = relax.Var("w", relax.FuncType([], R.Tensor((4, 3, 3, 3), "float32"))) with pytest.raises(TypeError): bb.normalize(relax.op.nn.conv2d(x0, w1)) @@ -1200,7 +1170,7 @@ def test_conv2d_infer_struct_info_wrong_input_type(): bb.normalize(relax.op.nn.conv2d(x1, w0)) -def test_conv2d_transpose_infer_struct_info(): +def test_conv2d_transpose_infer_ty(): bb = relax.BlockBuilder() vdev0 = VDevice("llvm") x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) @@ -1218,103 +1188,101 @@ def test_conv2d_transpose_infer_struct_info(): w5 = relax.Var("w", R.Tensor((3, 4, 3, 3), "float32", vdev0)) _check_inference( - bb, relax.op.nn.conv2d_transpose(x0, w0), relax.TensorStructInfo((2, 4, 30, 30), "float32") + bb, relax.op.nn.conv2d_transpose(x0, w0), relax.TensorType((2, 4, 30, 30), "float32") ) _check_inference( bb, relax.op.nn.conv2d_transpose(x6, w5), - relax.TensorStructInfo((2, 4, 30, 30), "float32", vdev0), + relax.TensorType((2, 4, 30, 30), "float32", vdev0), ) _check_inference( bb, relax.op.nn.conv2d_transpose(x0, w0, out_dtype="float16"), - relax.TensorStructInfo((2, 4, 30, 30), "float16"), + relax.TensorType((2, 4, 30, 30), "float16"), ) _check_inference( bb, relax.op.nn.conv2d_transpose(x0, w0, padding=1), - relax.TensorStructInfo((2, 4, 28, 28), "float32"), + relax.TensorType((2, 4, 28, 28), "float32"), ) _check_inference( bb, relax.op.nn.conv2d_transpose(x0, w0, padding=[1, 2]), - relax.TensorStructInfo((2, 4, 28, 26), "float32"), + relax.TensorType((2, 4, 28, 26), "float32"), ) _check_inference( bb, relax.op.nn.conv2d_transpose(x0, w0, padding=[1, 2, 3, 4]), - relax.TensorStructInfo((2, 4, 26, 24), "float32"), + relax.TensorType((2, 4, 26, 24), "float32"), ) _check_inference( bb, relax.op.nn.conv2d_transpose(x0, w0, strides=3, output_padding=1), - relax.TensorStructInfo((2, 4, 85, 85), "float32"), + relax.TensorType((2, 4, 85, 85), "float32"), ) _check_inference( bb, relax.op.nn.conv2d_transpose(x0, w0, strides=3, output_padding=[2, 1]), - relax.TensorStructInfo((2, 4, 86, 85), "float32"), + relax.TensorType((2, 4, 86, 85), "float32"), ) _check_inference( bb, relax.op.nn.conv2d_transpose(x0, w0, strides=2), - relax.TensorStructInfo((2, 4, 57, 57), "float32"), + relax.TensorType((2, 4, 57, 57), "float32"), ) _check_inference( bb, relax.op.nn.conv2d_transpose(x0, w0, strides=(2, 3)), - relax.TensorStructInfo((2, 4, 57, 84), "float32"), + relax.TensorType((2, 4, 57, 84), "float32"), ) _check_inference( bb, relax.op.nn.conv2d_transpose(x0, w0, dilation=2), - relax.TensorStructInfo((2, 4, 32, 32), "float32"), + relax.TensorType((2, 4, 32, 32), "float32"), ) _check_inference( bb, relax.op.nn.conv2d_transpose(x0, w0, dilation=(2, 1)), - relax.TensorStructInfo((2, 4, 32, 30), "float32"), + relax.TensorType((2, 4, 32, 30), "float32"), ) _check_inference( bb, relax.op.nn.conv2d_transpose(x1, w0, data_layout="NHWC"), - relax.TensorStructInfo((2, 30, 30, 4), "float32"), + relax.TensorType((2, 30, 30, 4), "float32"), ) _check_inference( bb, relax.op.nn.conv2d_transpose(x0, w0, out_layout="NHWC"), - relax.TensorStructInfo((2, 30, 30, 4), "float32"), + relax.TensorType((2, 30, 30, 4), "float32"), ) _check_inference( bb, relax.op.nn.conv2d_transpose(x0, w1, kernel_layout="OIHW"), - relax.TensorStructInfo((2, 4, 30, 30), "float32"), + relax.TensorType((2, 4, 30, 30), "float32"), ) _check_inference( bb, relax.op.nn.conv2d_transpose( x5, w4, data_layout="NCHW16c", kernel_layout="IOHW16i", out_layout="NHWC16c" ), - relax.TensorStructInfo((2, 30, 30, 3, 16), "float32"), + relax.TensorType((2, 30, 30, 3, 16), "float32"), ) _check_inference( - bb, relax.op.nn.conv2d_transpose(x2, w0), relax.TensorStructInfo(dtype="float32", ndim=4) + bb, relax.op.nn.conv2d_transpose(x2, w0), relax.TensorType(dtype="float32", ndim=4) ) _check_inference( - bb, relax.op.nn.conv2d_transpose(x3, w0), relax.TensorStructInfo(dtype="float32", ndim=4) + bb, relax.op.nn.conv2d_transpose(x3, w0), relax.TensorType(dtype="float32", ndim=4) ) _check_inference( - bb, relax.op.nn.conv2d_transpose(x0, w2), relax.TensorStructInfo(dtype="float32", ndim=4) + bb, relax.op.nn.conv2d_transpose(x0, w2), relax.TensorType(dtype="float32", ndim=4) ) _check_inference( - bb, relax.op.nn.conv2d_transpose(x0, w3), relax.TensorStructInfo(dtype="float32", ndim=4) - ) - _check_inference( - bb, relax.op.nn.conv2d_transpose(x4, w0), relax.TensorStructInfo(dtype="", ndim=4) + bb, relax.op.nn.conv2d_transpose(x0, w3), relax.TensorType(dtype="float32", ndim=4) ) + _check_inference(bb, relax.op.nn.conv2d_transpose(x4, w0), relax.TensorType(dtype="", ndim=4)) -def test_conv2d_transpose_infer_struct_info_shape_symbolic(): +def test_conv2d_transpose_infer_ty_shape_symbolic(): bb = relax.BlockBuilder() n = tirx.Var("n", "int64") c = tirx.Var("c", "int64") @@ -1334,64 +1302,64 @@ def test_conv2d_transpose_infer_struct_info_shape_symbolic(): _check_inference( bb, relax.op.nn.conv2d_transpose(x0, w0), - relax.TensorStructInfo((n, ko, ih + kh - 1, iw + kw - 1), "float32"), + relax.TensorType((n, ko, ih + kh - 1, iw + kw - 1), "float32"), ) _check_inference( bb, relax.op.nn.conv2d_transpose(x0, w1), - relax.TensorStructInfo((n, ko, ih + kh - 1, iw + kw - 1), "float32"), + relax.TensorType((n, ko, ih + kh - 1, iw + kw - 1), "float32"), ) _check_inference( bb, relax.op.nn.conv2d_transpose( x1, w2, data_layout="NCHW16c", kernel_layout="IOHW16i", out_layout="NCHW" ), - relax.TensorStructInfo((n, ko, ih + kh - 1, iw + kw - 1), "float32"), + relax.TensorType((n, ko, ih + kh - 1, iw + kw - 1), "float32"), ) _check_inference( bb, relax.op.nn.conv2d_transpose( x0, w0, strides=(2, 2), padding=(1, 1), output_padding=(1, 0), dilation=(2, 2) ), - relax.TensorStructInfo( + relax.TensorType( (n, ko, ih * 2 + kh * 2 - 4, iw * 2 + kw * 2 - 5), "float32", ), ) -def test_conv2d_transpose_infer_struct_info_shape_var(): +def test_conv2d_transpose_infer_ty_shape_var(): bb = relax.BlockBuilder() - s0 = relax.Var("s", relax.ShapeStructInfo(ndim=4)) - s1 = relax.Var("s", relax.ShapeStructInfo(ndim=5)) - s2 = relax.Var("s", relax.ShapeStructInfo(ndim=4)) - s3 = relax.Var("s", relax.ShapeStructInfo()) - x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) - x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) - x2 = relax.Var("x", relax.TensorStructInfo(s3, "float32")) - w = relax.Var("w", relax.TensorStructInfo(s2, "float32")) + s0 = relax.Var("s", relax.ShapeType(ndim=4)) + s1 = relax.Var("s", relax.ShapeType(ndim=5)) + s2 = relax.Var("s", relax.ShapeType(ndim=4)) + s3 = relax.Var("s", relax.ShapeType()) + x0 = relax.Var("x", relax.TensorType(s0, "float32")) + x1 = relax.Var("x", relax.TensorType(s1, "float32")) + x2 = relax.Var("x", relax.TensorType(s3, "float32")) + w = relax.Var("w", relax.TensorType(s2, "float32")) _check_inference( - bb, relax.op.nn.conv2d_transpose(x0, w), relax.TensorStructInfo(dtype="float32", ndim=4) + bb, relax.op.nn.conv2d_transpose(x0, w), relax.TensorType(dtype="float32", ndim=4) ) _check_inference( bb, relax.op.nn.conv2d_transpose(x1, w, data_layout="NCHW16c"), - relax.TensorStructInfo(dtype="float32", ndim=5), + relax.TensorType(dtype="float32", ndim=5), ) _check_inference( bb, relax.op.nn.conv2d_transpose(x0, w, out_layout="NCHW16c"), - relax.TensorStructInfo(dtype="float32", ndim=5), + relax.TensorType(dtype="float32", ndim=5), ) _check_inference( bb, relax.op.nn.conv2d_transpose(x2, w), - relax.TensorStructInfo(dtype="float32", ndim=4), + relax.TensorType(dtype="float32", ndim=4), ) -def test_conv2d_transpose_infer_struct_info_groups(): +def test_conv2d_transpose_infer_ty_groups(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 128, 28, 28), "float32")) x1 = relax.Var("x", R.Tensor((2, 8, 28, 28, 16), "float32")) @@ -1401,21 +1369,21 @@ def test_conv2d_transpose_infer_struct_info_groups(): _check_inference( bb, relax.op.nn.conv2d_transpose(x0, w0, groups=8), - relax.TensorStructInfo((2, 48, 30, 30), "float32"), + relax.TensorType((2, 48, 30, 30), "float32"), ) _check_inference( bb, relax.op.nn.conv2d_transpose(x0, w1, kernel_layout="IOHW8i", groups=8), - relax.TensorStructInfo((2, 48, 30, 30), "float32"), + relax.TensorType((2, 48, 30, 30), "float32"), ) _check_inference( bb, relax.op.nn.conv2d_transpose(x1, w0, data_layout="NCHW16c", groups=8), - relax.TensorStructInfo((2, 3, 30, 30, 16), "float32"), + relax.TensorType((2, 3, 30, 30, 16), "float32"), ) -def test_conv2d_transpose_infer_struct_info_symbolic_groups(): +def test_conv2d_transpose_infer_ty_symbolic_groups(): bb = relax.BlockBuilder() n = tirx.Var("n", "int64") ic = tirx.Var("c", "int64") @@ -1426,11 +1394,11 @@ def test_conv2d_transpose_infer_struct_info_symbolic_groups(): _check_inference( bb, relax.op.nn.conv2d_transpose(x, w0, groups=4), - relax.TensorStructInfo((n, oc * 4, 30, 30), "float32"), + relax.TensorType((n, oc * 4, 30, 30), "float32"), ) -def test_conv2d_transpose_infer_struct_info_input_channel_group_incompatible(): +def test_conv2d_transpose_infer_ty_input_channel_group_incompatible(): bb = relax.BlockBuilder() n = tirx.Var("n", "int64") ic = tirx.Var("c", "int64") @@ -1456,7 +1424,7 @@ def test_conv2d_transpose_non_positive_group(): relax.op.nn.conv2d_transpose(x, w, groups=-2) -def test_conv2d_transpose_infer_struct_info_more_input_dtype(): +def test_conv2d_transpose_infer_ty_more_input_dtype(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float16")) w0 = relax.Var("w", R.Tensor((3, 4, 3, 3), "float16")) @@ -1468,16 +1436,16 @@ def test_conv2d_transpose_infer_struct_info_more_input_dtype(): w3 = relax.Var("w", R.Tensor((3, 4, 3, 3), "int32")) _check_inference( - bb, relax.op.nn.conv2d_transpose(x0, w0), relax.TensorStructInfo((2, 4, 30, 30), "float16") + bb, relax.op.nn.conv2d_transpose(x0, w0), relax.TensorType((2, 4, 30, 30), "float16") ) _check_inference( - bb, relax.op.nn.conv2d_transpose(x1, w1), relax.TensorStructInfo((2, 4, 30, 30), "float64") + bb, relax.op.nn.conv2d_transpose(x1, w1), relax.TensorType((2, 4, 30, 30), "float64") ) _check_inference( - bb, relax.op.nn.conv2d_transpose(x2, w2), relax.TensorStructInfo((2, 4, 30, 30), "int8") + bb, relax.op.nn.conv2d_transpose(x2, w2), relax.TensorType((2, 4, 30, 30), "int8") ) _check_inference( - bb, relax.op.nn.conv2d_transpose(x3, w3), relax.TensorStructInfo((2, 4, 30, 30), "int32") + bb, relax.op.nn.conv2d_transpose(x3, w3), relax.TensorType((2, 4, 30, 30), "int32") ) @@ -1537,7 +1505,7 @@ def test_conv2d_transpose_wrong_strides_padding_dilation_length(): relax.op.nn.conv2d_transpose(x, w, dilation=(1, 2, 3)) -def test_conv2d_transpose_infer_struct_info_wrong_layout_string(): +def test_conv2d_transpose_infer_ty_wrong_layout_string(): bb = relax.BlockBuilder() x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) w = relax.Var("w", R.Tensor((3, 4, 3, 3), "float32")) @@ -1579,12 +1547,12 @@ def test_conv2d_transpose_wrong_input_ndim(): bb.normalize(relax.op.nn.conv2d_transpose(x2, w0)) -def test_conv2d_transpose_infer_struct_info_wrong_input_type(): +def test_conv2d_transpose_infer_ty_wrong_input_type(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) - x1 = relax.Var("x", relax.ShapeStructInfo((2, 3, 28, 28))) + x1 = relax.Var("x", relax.ShapeType((2, 3, 28, 28))) w0 = relax.Var("w", R.Tensor((3, 4, 3, 3), "float32")) - w1 = relax.Var("w", relax.FuncStructInfo([], R.Tensor((3, 4, 3, 3), "float32"))) + w1 = relax.Var("w", relax.FuncType([], R.Tensor((3, 4, 3, 3), "float32"))) with pytest.raises(TypeError): bb.normalize(relax.op.nn.conv2d_transpose(x0, w1)) @@ -1592,7 +1560,7 @@ def test_conv2d_transpose_infer_struct_info_wrong_input_type(): bb.normalize(relax.op.nn.conv2d_transpose(x1, w0)) -def test_conv2d_transpose_infer_struct_info_mixed_precision(): +def test_conv2d_transpose_infer_ty_mixed_precision(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float16")) w0 = relax.Var("w", R.Tensor((3, 4, 3, 3), "float16")) @@ -1602,37 +1570,37 @@ def test_conv2d_transpose_infer_struct_info_mixed_precision(): _check_inference( bb, relax.op.nn.conv2d_transpose(x0, w0, out_dtype="float32"), - relax.TensorStructInfo((2, 4, 30, 30), "float32"), + relax.TensorType((2, 4, 30, 30), "float32"), ) _check_inference( bb, relax.op.nn.conv2d_transpose(x1, w1, out_dtype="int32"), - relax.TensorStructInfo((2, 4, 30, 30), "int32"), + relax.TensorType((2, 4, 30, 30), "int32"), ) -def test_conv3d_transpose_infer_struct_info(): +def test_conv3d_transpose_infer_ty(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3, 28, 28, 28), "float32")) w0 = relax.Var("w", R.Tensor((3, 4, 3, 3, 3), "float32")) _check_inference( bb, relax.op.nn.conv3d_transpose(x0, w0), - relax.TensorStructInfo((2, 4, 30, 30, 30), "float32"), + relax.TensorType((2, 4, 30, 30, 30), "float32"), ) _check_inference( bb, relax.op.nn.conv3d_transpose(x0, w0, padding=1), - relax.TensorStructInfo((2, 4, 28, 28, 28), "float32"), + relax.TensorType((2, 4, 28, 28, 28), "float32"), ) _check_inference( bb, relax.op.nn.conv3d_transpose(x0, w0, strides=2, output_padding=1), - relax.TensorStructInfo((2, 4, 58, 58, 58), "float32"), + relax.TensorType((2, 4, 58, 58, 58), "float32"), ) -def test_conv3d_transpose_infer_struct_info_ndhwc_out_layout(): +def test_conv3d_transpose_infer_ty_ndhwc_out_layout(): bb = relax.BlockBuilder() x_ndhwc = relax.Var("x_nd", R.Tensor((2, 28, 28, 28, 3), "float32")) x_ncdhw = relax.Var("x_nc", R.Tensor((2, 3, 28, 28, 28), "float32")) @@ -1640,24 +1608,24 @@ def test_conv3d_transpose_infer_struct_info_ndhwc_out_layout(): _check_inference( bb, relax.op.nn.conv3d_transpose(x_ndhwc, w0, data_layout="NDHWC"), - relax.TensorStructInfo((2, 30, 30, 30, 4), "float32"), + relax.TensorType((2, 30, 30, 30, 4), "float32"), ) # Default data_layout is NCDHW; use NCDHW-shaped input when only out_layout is NDHWC. _check_inference( bb, relax.op.nn.conv3d_transpose(x_ncdhw, w0, out_layout="NDHWC"), - relax.TensorStructInfo((2, 30, 30, 30, 4), "float32"), + relax.TensorType((2, 30, 30, 30, 4), "float32"), ) -def test_conv3d_transpose_infer_struct_info_groups(): +def test_conv3d_transpose_infer_ty_groups(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 128, 28, 28, 28), "float32")) w0 = relax.Var("w", R.Tensor((128, 16, 3, 3, 3), "float32")) _check_inference( bb, relax.op.nn.conv3d_transpose(x0, w0, groups=8), - relax.TensorStructInfo((2, 128, 30, 30, 30), "float32"), + relax.TensorType((2, 128, 30, 30, 30), "float32"), ) @@ -1681,7 +1649,7 @@ def test_conv3d_transpose_unequal_input_channel(): bb.normalize(relax.op.nn.conv3d_transpose(x0, w0)) -def test_conv3d_infer_struct_info(): +def test_conv3d_infer_ty(): bb = relax.BlockBuilder() vdev0 = VDevice("llvm") x0 = relax.Var("x", R.Tensor((2, 3, 28, 28, 28), "float32")) @@ -1699,89 +1667,81 @@ def test_conv3d_infer_struct_info(): w5 = relax.Var("w", R.Tensor((4, 3, 3, 3, 3), "float32", vdev0)) _check_inference( - bb, relax.op.nn.conv3d(x0, w0), relax.TensorStructInfo((2, 4, 26, 26, 26), "float32") + bb, relax.op.nn.conv3d(x0, w0), relax.TensorType((2, 4, 26, 26, 26), "float32") ) _check_inference( - bb, relax.op.nn.conv3d(x6, w5), relax.TensorStructInfo((2, 4, 26, 26, 26), "float32", vdev0) + bb, relax.op.nn.conv3d(x6, w5), relax.TensorType((2, 4, 26, 26, 26), "float32", vdev0) ) _check_inference( bb, relax.op.nn.conv3d(x0, w0, out_dtype="float16"), - relax.TensorStructInfo((2, 4, 26, 26, 26), "float16"), + relax.TensorType((2, 4, 26, 26, 26), "float16"), ) _check_inference( bb, relax.op.nn.conv3d(x0, w0, padding=1), - relax.TensorStructInfo((2, 4, 28, 28, 28), "float32"), + relax.TensorType((2, 4, 28, 28, 28), "float32"), ) _check_inference( bb, relax.op.nn.conv3d(x0, w0, padding=[1, 2, 3]), - relax.TensorStructInfo((2, 4, 28, 30, 32), "float32"), + relax.TensorType((2, 4, 28, 30, 32), "float32"), ) _check_inference( bb, relax.op.nn.conv3d(x0, w0, padding=[1, 2, 3, 4, 5, 6]), - relax.TensorStructInfo((2, 4, 31, 33, 35), "float32"), + relax.TensorType((2, 4, 31, 33, 35), "float32"), ) _check_inference( bb, relax.op.nn.conv3d(x0, w0, strides=2), - relax.TensorStructInfo((2, 4, 13, 13, 13), "float32"), + relax.TensorType((2, 4, 13, 13, 13), "float32"), ) _check_inference( bb, relax.op.nn.conv3d(x0, w0, strides=(2, 3, 4)), - relax.TensorStructInfo((2, 4, 13, 9, 7), "float32"), + relax.TensorType((2, 4, 13, 9, 7), "float32"), ) _check_inference( bb, relax.op.nn.conv3d(x0, w0, dilation=2), - relax.TensorStructInfo((2, 4, 24, 24, 24), "float32"), + relax.TensorType((2, 4, 24, 24, 24), "float32"), ) _check_inference( bb, relax.op.nn.conv3d(x0, w0, dilation=(3, 2, 1)), - relax.TensorStructInfo((2, 4, 22, 24, 26), "float32"), + relax.TensorType((2, 4, 22, 24, 26), "float32"), ) _check_inference( bb, relax.op.nn.conv3d(x1, w0, data_layout="NDHWC"), - relax.TensorStructInfo((2, 26, 26, 26, 4), "float32"), + relax.TensorType((2, 26, 26, 26, 4), "float32"), ) _check_inference( bb, relax.op.nn.conv3d(x0, w0, out_layout="NDHWC"), - relax.TensorStructInfo((2, 26, 26, 26, 4), "float32"), + relax.TensorType((2, 26, 26, 26, 4), "float32"), ) _check_inference( bb, relax.op.nn.conv3d(x0, w1, kernel_layout="IODHW"), - relax.TensorStructInfo((2, 4, 26, 26, 26), "float32"), + relax.TensorType((2, 4, 26, 26, 26), "float32"), ) _check_inference( bb, relax.op.nn.conv3d( x5, w4, data_layout="NCDHW16c", kernel_layout="OIDHW16i", out_layout="NDHWC16c" ), - relax.TensorStructInfo((2, 26, 26, 26, 3, 16), "float32"), - ) - _check_inference( - bb, relax.op.nn.conv3d(x2, w0), relax.TensorStructInfo(dtype="float32", ndim=5) - ) - _check_inference( - bb, relax.op.nn.conv3d(x3, w0), relax.TensorStructInfo(dtype="float32", ndim=5) - ) - _check_inference( - bb, relax.op.nn.conv3d(x0, w2), relax.TensorStructInfo(dtype="float32", ndim=5) - ) - _check_inference( - bb, relax.op.nn.conv3d(x0, w3), relax.TensorStructInfo(dtype="float32", ndim=5) + relax.TensorType((2, 26, 26, 26, 3, 16), "float32"), ) - _check_inference(bb, relax.op.nn.conv3d(x4, w0), relax.TensorStructInfo(dtype="", ndim=5)) + _check_inference(bb, relax.op.nn.conv3d(x2, w0), relax.TensorType(dtype="float32", ndim=5)) + _check_inference(bb, relax.op.nn.conv3d(x3, w0), relax.TensorType(dtype="float32", ndim=5)) + _check_inference(bb, relax.op.nn.conv3d(x0, w2), relax.TensorType(dtype="float32", ndim=5)) + _check_inference(bb, relax.op.nn.conv3d(x0, w3), relax.TensorType(dtype="float32", ndim=5)) + _check_inference(bb, relax.op.nn.conv3d(x4, w0), relax.TensorType(dtype="", ndim=5)) -def test_conv3d_infer_struct_info_shape_symbolic(): +def test_conv3d_infer_ty_shape_symbolic(): bb = relax.BlockBuilder() n = tirx.Var("n", "int64") c = tirx.Var("c", "int64") @@ -1803,24 +1763,24 @@ def test_conv3d_infer_struct_info_shape_symbolic(): _check_inference( bb, relax.op.nn.conv3d(x0, w0), - relax.TensorStructInfo((n, ko, id + 1 - kd, ih + 1 - kh, iw + 1 - kw), "float32"), + relax.TensorType((n, ko, id + 1 - kd, ih + 1 - kh, iw + 1 - kw), "float32"), ) _check_inference( bb, relax.op.nn.conv3d(x0, w1), - relax.TensorStructInfo((n, ko, id + 1 - kd, ih + 1 - kh, iw + 1 - kw), "float32"), + relax.TensorType((n, ko, id + 1 - kd, ih + 1 - kh, iw + 1 - kw), "float32"), ) _check_inference( bb, relax.op.nn.conv3d( x1, w2, data_layout="NCDHW16c", kernel_layout="OIDHW16i", out_layout="NCDHW" ), - relax.TensorStructInfo((n, ko, id + 1 - kd, ih + 1 - kh, iw + 1 - kw), "float32"), + relax.TensorType((n, ko, id + 1 - kd, ih + 1 - kh, iw + 1 - kw), "float32"), ) _check_inference( bb, relax.op.nn.conv3d(x0, w0, strides=(2, 2, 2), padding=(1, 1, 1), dilation=(2, 2, 2)), - relax.TensorStructInfo( + relax.TensorType( ( n, ko, @@ -1833,32 +1793,32 @@ def test_conv3d_infer_struct_info_shape_symbolic(): ) -def test_conv3d_infer_struct_info_shape_var(): +def test_conv3d_infer_ty_shape_var(): bb = relax.BlockBuilder() - s0 = relax.Var("s", relax.ShapeStructInfo(ndim=5)) - s1 = relax.Var("s", relax.ShapeStructInfo(ndim=6)) - s2 = relax.Var("s", relax.ShapeStructInfo(ndim=5)) - s3 = relax.Var("s", relax.ShapeStructInfo()) - x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) - x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) - x2 = relax.Var("x", relax.TensorStructInfo(s3, "float32")) - w = relax.Var("w", relax.TensorStructInfo(s2, "float32")) + s0 = relax.Var("s", relax.ShapeType(ndim=5)) + s1 = relax.Var("s", relax.ShapeType(ndim=6)) + s2 = relax.Var("s", relax.ShapeType(ndim=5)) + s3 = relax.Var("s", relax.ShapeType()) + x0 = relax.Var("x", relax.TensorType(s0, "float32")) + x1 = relax.Var("x", relax.TensorType(s1, "float32")) + x2 = relax.Var("x", relax.TensorType(s3, "float32")) + w = relax.Var("w", relax.TensorType(s2, "float32")) - _check_inference(bb, relax.op.nn.conv3d(x0, w), relax.TensorStructInfo(dtype="float32", ndim=5)) + _check_inference(bb, relax.op.nn.conv3d(x0, w), relax.TensorType(dtype="float32", ndim=5)) _check_inference( bb, relax.op.nn.conv3d(x1, w, data_layout="NCDHW16c"), - relax.TensorStructInfo(dtype="float32", ndim=6), + relax.TensorType(dtype="float32", ndim=6), ) _check_inference( bb, relax.op.nn.conv3d(x0, w, out_layout="NCDHW16c"), - relax.TensorStructInfo(dtype="float32", ndim=6), + relax.TensorType(dtype="float32", ndim=6), ) _check_inference( bb, relax.op.nn.conv3d(x2, w), - relax.TensorStructInfo(dtype="float32", ndim=5), + relax.TensorType(dtype="float32", ndim=5), ) diff --git a/tests/python/relax/test_op_nn_pooling.py b/tests/python/relax/test_op_nn_pooling.py index bfa04b634a2d..159a2802834e 100644 --- a/tests/python/relax/test_op_nn_pooling.py +++ b/tests/python/relax/test_op_nn_pooling.py @@ -39,12 +39,12 @@ def test_op_correctness(): assert relax.op.nn.adaptive_avg_pool3d(x).op == Op.get("relax.nn.adaptive_avg_pool3d") -def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): +def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_ty: relax.Type): ret = bb.normalize(call) - tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) + tvm.ir.assert_structural_equal(ret.ty, expected_ty) -def test_max_pool1d_infer_struct_info(): +def test_max_pool1d_infer_ty(): bb = relax.BlockBuilder() vdev0 = VDevice("llvm") x0 = relax.Var("x", R.Tensor((2, 3, 32), "float32")) @@ -54,38 +54,32 @@ def test_max_pool1d_infer_struct_info(): x4 = relax.Var("x", R.Tensor()) x5 = relax.Var("x", R.Tensor((2, 3, 32), "float32", vdev0)) - _check_inference(bb, relax.op.nn.max_pool1d(x0), relax.TensorStructInfo((2, 3, 32), "float32")) + _check_inference(bb, relax.op.nn.max_pool1d(x0), relax.TensorType((2, 3, 32), "float32")) + _check_inference(bb, relax.op.nn.max_pool1d(x5), relax.TensorType((2, 3, 32), "float32", vdev0)) _check_inference( - bb, relax.op.nn.max_pool1d(x5), relax.TensorStructInfo((2, 3, 32), "float32", vdev0) + bb, relax.op.nn.max_pool1d(x0, pool_size=3), relax.TensorType((2, 3, 30), "float32") ) _check_inference( - bb, relax.op.nn.max_pool1d(x0, pool_size=3), relax.TensorStructInfo((2, 3, 30), "float32") + bb, relax.op.nn.max_pool1d(x0, strides=2), relax.TensorType((2, 3, 16), "float32") ) _check_inference( - bb, relax.op.nn.max_pool1d(x0, strides=2), relax.TensorStructInfo((2, 3, 16), "float32") + bb, relax.op.nn.max_pool1d(x0, padding=1), relax.TensorType((2, 3, 34), "float32") ) _check_inference( - bb, relax.op.nn.max_pool1d(x0, padding=1), relax.TensorStructInfo((2, 3, 34), "float32") - ) - _check_inference( - bb, relax.op.nn.max_pool1d(x0, dilation=2), relax.TensorStructInfo((2, 3, 32), "float32") + bb, relax.op.nn.max_pool1d(x0, dilation=2), relax.TensorType((2, 3, 32), "float32") ) _check_inference( bb, relax.op.nn.max_pool1d(x0, layout="NCW", out_layout="NWC"), - relax.TensorStructInfo((2, 32, 3), "float32"), - ) - _check_inference( - bb, relax.op.nn.max_pool1d(x1), relax.TensorStructInfo(dtype="float32", ndim=3) - ) - _check_inference(bb, relax.op.nn.max_pool1d(x2), relax.TensorStructInfo(dtype="", ndim=3)) - _check_inference( - bb, relax.op.nn.max_pool1d(x3), relax.TensorStructInfo(dtype="float32", ndim=3) + relax.TensorType((2, 32, 3), "float32"), ) - _check_inference(bb, relax.op.nn.max_pool1d(x4), relax.TensorStructInfo(dtype="", ndim=3)) + _check_inference(bb, relax.op.nn.max_pool1d(x1), relax.TensorType(dtype="float32", ndim=3)) + _check_inference(bb, relax.op.nn.max_pool1d(x2), relax.TensorType(dtype="", ndim=3)) + _check_inference(bb, relax.op.nn.max_pool1d(x3), relax.TensorType(dtype="float32", ndim=3)) + _check_inference(bb, relax.op.nn.max_pool1d(x4), relax.TensorType(dtype="", ndim=3)) -def test_max_pool1d_infer_struct_info_shape_symbolic(): +def test_max_pool1d_infer_ty_shape_symbolic(): bb = relax.BlockBuilder() n = tirx.Var("n", "int64") c = tirx.Var("c", "int64") @@ -98,7 +92,7 @@ def test_max_pool1d_infer_struct_info_shape_symbolic(): _check_inference( bb, relax.op.nn.max_pool1d(x0, pool_size=3, strides=3, padding=2, dilation=2), - relax.TensorStructInfo( + relax.TensorType( ( n, c, @@ -110,52 +104,50 @@ def test_max_pool1d_infer_struct_info_shape_symbolic(): _check_inference( bb, relax.op.nn.max_pool1d(x1, layout="NCW16c", out_layout="NWC"), - relax.TensorStructInfo((n, w, c * 16), "float32"), + relax.TensorType((n, w, c * 16), "float32"), ) -def test_max_pool1d_infer_struct_info_shape_var(): +def test_max_pool1d_infer_ty_shape_var(): bb = relax.BlockBuilder() - s0 = relax.Var("s", relax.ShapeStructInfo(ndim=3)) - s1 = relax.Var("s", relax.ShapeStructInfo(ndim=4)) - s2 = relax.Var("s", relax.ShapeStructInfo()) + s0 = relax.Var("s", relax.ShapeType(ndim=3)) + s1 = relax.Var("s", relax.ShapeType(ndim=4)) + s2 = relax.Var("s", relax.ShapeType()) - x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) - x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) - x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + x0 = relax.Var("x", relax.TensorType(s0, "float32")) + x1 = relax.Var("x", relax.TensorType(s1, "float32")) + x2 = relax.Var("x", relax.TensorType(s2, "float32")) - _check_inference( - bb, relax.op.nn.max_pool1d(x0), relax.TensorStructInfo(dtype="float32", ndim=3) - ) + _check_inference(bb, relax.op.nn.max_pool1d(x0), relax.TensorType(dtype="float32", ndim=3)) _check_inference( bb, relax.op.nn.max_pool1d(x1, layout="NCW16c"), - relax.TensorStructInfo(dtype="float32", ndim=4), + relax.TensorType(dtype="float32", ndim=4), ) _check_inference( bb, relax.op.nn.max_pool1d(x2), - relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorType(dtype="float32", ndim=3), ) -def test_max_pool1d_infer_struct_info_ceil_mode(): +def test_max_pool1d_infer_ty_ceil_mode(): bb = relax.BlockBuilder() x = relax.Var("x", R.Tensor((2, 3, 32), "float32")) _check_inference( bb, relax.op.nn.max_pool1d(x, pool_size=3, strides=2, ceil_mode=True), - relax.TensorStructInfo((2, 3, 16), "float32"), + relax.TensorType((2, 3, 16), "float32"), ) _check_inference( bb, relax.op.nn.max_pool1d(x, pool_size=5, strides=2, ceil_mode=True), - relax.TensorStructInfo((2, 3, 15), "float32"), + relax.TensorType((2, 3, 15), "float32"), ) -def test_max_pool1d_infer_struct_info_ceil_mode_symbolic(): +def test_max_pool1d_infer_ty_ceil_mode_symbolic(): bb = relax.BlockBuilder() n = tirx.Var("n", "int64") c = tirx.Var("c", "int64") @@ -165,19 +157,19 @@ def test_max_pool1d_infer_struct_info_ceil_mode_symbolic(): _check_inference( bb, relax.op.nn.max_pool1d(x, pool_size=3, strides=2, padding=1, dilation=2, ceil_mode=True), - relax.TensorStructInfo((n, c, tvm.tirx.floordiv(w, 2)), "float32"), + relax.TensorType((n, c, tvm.tirx.floordiv(w, 2)), "float32"), ) -def test_max_pool1d_infer_struct_info_more_input_dtype(): +def test_max_pool1d_infer_ty_more_input_dtype(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3, 32), "float16")) x1 = relax.Var("x", R.Tensor((2, 3, 32), "int8")) x2 = relax.Var("x", R.Tensor((2, 3, 32), "int64")) - _check_inference(bb, relax.op.nn.max_pool1d(x0), relax.TensorStructInfo((2, 3, 32), "float16")) - _check_inference(bb, relax.op.nn.max_pool1d(x1), relax.TensorStructInfo((2, 3, 32), "int8")) - _check_inference(bb, relax.op.nn.max_pool1d(x2), relax.TensorStructInfo((2, 3, 32), "int64")) + _check_inference(bb, relax.op.nn.max_pool1d(x0), relax.TensorType((2, 3, 32), "float16")) + _check_inference(bb, relax.op.nn.max_pool1d(x1), relax.TensorType((2, 3, 32), "int8")) + _check_inference(bb, relax.op.nn.max_pool1d(x2), relax.TensorType((2, 3, 32), "int64")) def test_max_pool1d_stride_padding_dilation_int64(): @@ -202,7 +194,7 @@ def test_max_pool1d_wrong_pool_size_strides_padding_dilation_length(): relax.op.nn.max_pool1d(x, dilation=(1, 2)) -def test_max_pool1d_infer_struct_info_wrong_layout_string(): +def test_max_pool1d_infer_ty_wrong_layout_string(): bb = relax.BlockBuilder() x = relax.Var("x", R.Tensor((2, 3, 28), "float32")) with pytest.raises(ValueError): @@ -223,10 +215,10 @@ def test_max_pool1d_wrong_input_ndim(): bb.normalize(relax.op.nn.max_pool1d(x1)) -def test_max_pool1d_infer_struct_info_wrong_input_type(): +def test_max_pool1d_infer_ty_wrong_input_type(): bb = relax.BlockBuilder() - x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 28))) - x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 28), "float32"))) + x0 = relax.Var("x", relax.ShapeType((2, 3, 28))) + x1 = relax.Var("x", relax.FuncType([], R.Tensor((2, 3, 28), "float32"))) with pytest.raises(TypeError): bb.normalize(relax.op.nn.max_pool1d(x0)) @@ -235,7 +227,7 @@ def test_max_pool1d_infer_struct_info_wrong_input_type(): bb.normalize(relax.op.nn.max_pool1d(x1)) -def test_max_pool2d_infer_struct_info(): +def test_max_pool2d_infer_ty(): bb = relax.BlockBuilder() vdev0 = VDevice("llvm") x0 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32")) @@ -247,66 +239,60 @@ def test_max_pool2d_infer_struct_info(): x6 = relax.Var("x", R.Tensor((2, 4, 32, 32, 16), "float32")) x7 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32", vdev0)) + _check_inference(bb, relax.op.nn.max_pool2d(x0), relax.TensorType((2, 3, 32, 32), "float32")) _check_inference( - bb, relax.op.nn.max_pool2d(x0), relax.TensorStructInfo((2, 3, 32, 32), "float32") - ) - _check_inference( - bb, relax.op.nn.max_pool2d(x7), relax.TensorStructInfo((2, 3, 32, 32), "float32", vdev0) + bb, relax.op.nn.max_pool2d(x7), relax.TensorType((2, 3, 32, 32), "float32", vdev0) ) _check_inference( bb, relax.op.nn.max_pool2d(x0, pool_size=3), - relax.TensorStructInfo((2, 3, 30, 30), "float32"), + relax.TensorType((2, 3, 30, 30), "float32"), ) _check_inference( bb, relax.op.nn.max_pool2d(x0, pool_size=(5, 3)), - relax.TensorStructInfo((2, 3, 28, 30), "float32"), + relax.TensorType((2, 3, 28, 30), "float32"), ) _check_inference( - bb, relax.op.nn.max_pool2d(x0, padding=1), relax.TensorStructInfo((2, 3, 34, 34), "float32") + bb, relax.op.nn.max_pool2d(x0, padding=1), relax.TensorType((2, 3, 34, 34), "float32") ) _check_inference( bb, relax.op.nn.max_pool2d(x0, padding=[1, 2]), - relax.TensorStructInfo((2, 3, 34, 36), "float32"), + relax.TensorType((2, 3, 34, 36), "float32"), ) _check_inference( bb, relax.op.nn.max_pool2d(x0, strides=2), - relax.TensorStructInfo((2, 3, 16, 16), "float32"), + relax.TensorType((2, 3, 16, 16), "float32"), ) _check_inference( bb, relax.op.nn.max_pool2d(x0, dilation=2), - relax.TensorStructInfo((2, 3, 32, 32), "float32"), + relax.TensorType((2, 3, 32, 32), "float32"), ) _check_inference( bb, relax.op.nn.max_pool2d(x1, layout="NHWC"), - relax.TensorStructInfo((2, 32, 32, 3), "float32"), + relax.TensorType((2, 32, 32, 3), "float32"), ) _check_inference( bb, relax.op.nn.max_pool2d(x0, out_layout="NHWC"), - relax.TensorStructInfo((2, 32, 32, 3), "float32"), + relax.TensorType((2, 32, 32, 3), "float32"), ) _check_inference( bb, relax.op.nn.max_pool2d(x6, layout="NCHW16c", out_layout="NHWC16c"), - relax.TensorStructInfo((2, 32, 32, 4, 16), "float32"), + relax.TensorType((2, 32, 32, 4, 16), "float32"), ) - _check_inference( - bb, relax.op.nn.max_pool2d(x2), relax.TensorStructInfo(dtype="float32", ndim=4) - ) - _check_inference( - bb, relax.op.nn.max_pool2d(x3), relax.TensorStructInfo(dtype="float32", ndim=4) - ) - _check_inference(bb, relax.op.nn.max_pool2d(x4), relax.TensorStructInfo(dtype="", ndim=4)) - _check_inference(bb, relax.op.nn.max_pool2d(x5), relax.TensorStructInfo(dtype="", ndim=4)) + _check_inference(bb, relax.op.nn.max_pool2d(x2), relax.TensorType(dtype="float32", ndim=4)) + _check_inference(bb, relax.op.nn.max_pool2d(x3), relax.TensorType(dtype="float32", ndim=4)) + _check_inference(bb, relax.op.nn.max_pool2d(x4), relax.TensorType(dtype="", ndim=4)) + _check_inference(bb, relax.op.nn.max_pool2d(x5), relax.TensorType(dtype="", ndim=4)) -def test_max_pool2d_infer_struct_info_shape_symbolic(): +def test_max_pool2d_infer_ty_shape_symbolic(): bb = relax.BlockBuilder() n = tirx.Var("n", "int64") c = tirx.Var("c", "int64") @@ -321,7 +307,7 @@ def test_max_pool2d_infer_struct_info_shape_symbolic(): relax.op.nn.max_pool2d( x0, pool_size=(3, 3), strides=(3, 3), padding=(2, 2), dilation=(2, 2) ), - relax.TensorStructInfo( + relax.TensorType( ( n, c, @@ -334,51 +320,49 @@ def test_max_pool2d_infer_struct_info_shape_symbolic(): _check_inference( bb, relax.op.nn.max_pool2d(x1, layout="NCHW16c", out_layout="NHWC"), - relax.TensorStructInfo((n, ih, iw, c * 16), "float32"), + relax.TensorType((n, ih, iw, c * 16), "float32"), ) -def test_max_pool2d_infer_struct_info_shape_var(): +def test_max_pool2d_infer_ty_shape_var(): bb = relax.BlockBuilder() - s0 = relax.Var("s", relax.ShapeStructInfo(ndim=4)) - s1 = relax.Var("s", relax.ShapeStructInfo(ndim=5)) - s2 = relax.Var("s", relax.ShapeStructInfo()) - x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) - x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) - x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + s0 = relax.Var("s", relax.ShapeType(ndim=4)) + s1 = relax.Var("s", relax.ShapeType(ndim=5)) + s2 = relax.Var("s", relax.ShapeType()) + x0 = relax.Var("x", relax.TensorType(s0, "float32")) + x1 = relax.Var("x", relax.TensorType(s1, "float32")) + x2 = relax.Var("x", relax.TensorType(s2, "float32")) - _check_inference( - bb, relax.op.nn.max_pool2d(x0), relax.TensorStructInfo(dtype="float32", ndim=4) - ) + _check_inference(bb, relax.op.nn.max_pool2d(x0), relax.TensorType(dtype="float32", ndim=4)) _check_inference( bb, relax.op.nn.max_pool2d(x1, layout="NCHW16c"), - relax.TensorStructInfo(dtype="float32", ndim=5), + relax.TensorType(dtype="float32", ndim=5), ) _check_inference( bb, relax.op.nn.max_pool2d(x2), - relax.TensorStructInfo(dtype="float32", ndim=4), + relax.TensorType(dtype="float32", ndim=4), ) -def test_max_pool2d_infer_struct_info_ceil_mode(): +def test_max_pool2d_infer_ty_ceil_mode(): bb = relax.BlockBuilder() x = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32")) _check_inference( bb, relax.op.nn.max_pool2d(x, pool_size=3, strides=2, ceil_mode=True), - relax.TensorStructInfo((2, 3, 16, 16), "float32"), + relax.TensorType((2, 3, 16, 16), "float32"), ) _check_inference( bb, relax.op.nn.max_pool2d(x, pool_size=(5, 3), strides=2, ceil_mode=True), - relax.TensorStructInfo((2, 3, 15, 16), "float32"), + relax.TensorType((2, 3, 15, 16), "float32"), ) -def test_max_pool2d_infer_struct_info_ceil_mode_symbolic(): +def test_max_pool2d_infer_ty_ceil_mode_symbolic(): bb = relax.BlockBuilder() n = tirx.Var("n", "int64") c = tirx.Var("c", "int64") @@ -391,24 +375,18 @@ def test_max_pool2d_infer_struct_info_ceil_mode_symbolic(): relax.op.nn.max_pool2d( x, pool_size=(3, 3), strides=(2, 2), padding=(1, 1), dilation=(2, 2), ceil_mode=True ), - relax.TensorStructInfo( - (n, c, tvm.tirx.floordiv(ih, 2), tvm.tirx.floordiv(iw, 2)), "float32" - ), + relax.TensorType((n, c, tvm.tirx.floordiv(ih, 2), tvm.tirx.floordiv(iw, 2)), "float32"), ) -def test_max_pool2d_infer_struct_info_more_input_dtype(): +def test_max_pool2d_infer_ty_more_input_dtype(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float16")) x1 = relax.Var("x", R.Tensor((2, 3, 32, 32), "int8")) x2 = relax.Var("x", R.Tensor((2, 3, 32, 32), "int64")) - _check_inference( - bb, relax.op.nn.max_pool2d(x0), relax.TensorStructInfo((2, 3, 32, 32), "float16") - ) - _check_inference(bb, relax.op.nn.max_pool2d(x1), relax.TensorStructInfo((2, 3, 32, 32), "int8")) - _check_inference( - bb, relax.op.nn.max_pool2d(x2), relax.TensorStructInfo((2, 3, 32, 32), "int64") - ) + _check_inference(bb, relax.op.nn.max_pool2d(x0), relax.TensorType((2, 3, 32, 32), "float16")) + _check_inference(bb, relax.op.nn.max_pool2d(x1), relax.TensorType((2, 3, 32, 32), "int8")) + _check_inference(bb, relax.op.nn.max_pool2d(x2), relax.TensorType((2, 3, 32, 32), "int64")) def test_max_pool2d_stride_padding_dilation_int64(): @@ -437,7 +415,7 @@ def test_max_pool2d_wrong_pool_size_strides_padding_dilation_length(): relax.op.nn.max_pool2d(x, dilation=(1, 2, 3)) -def test_max_pool2d_infer_struct_info_wrong_layout_string(): +def test_max_pool2d_infer_ty_wrong_layout_string(): bb = relax.BlockBuilder() x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) with pytest.raises(ValueError): @@ -456,10 +434,10 @@ def test_max_pool2d_wrong_input_ndim(): bb.normalize(relax.op.nn.max_pool2d(x1)) -def test_max_pool2d_infer_struct_info_wrong_input_type(): +def test_max_pool2d_infer_ty_wrong_input_type(): bb = relax.BlockBuilder() - x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 28, 28))) - x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 28, 28), "float32"))) + x0 = relax.Var("x", relax.ShapeType((2, 3, 28, 28))) + x1 = relax.Var("x", relax.FuncType([], R.Tensor((2, 3, 28, 28), "float32"))) with pytest.raises(TypeError): bb.normalize(relax.op.nn.max_pool2d(x0)) @@ -467,7 +445,7 @@ def test_max_pool2d_infer_struct_info_wrong_input_type(): bb.normalize(relax.op.nn.max_pool2d(x1)) -def test_max_pool3d_infer_struct_info(): +def test_max_pool3d_infer_ty(): bb = relax.BlockBuilder() vdev0 = VDevice("llvm") x0 = relax.Var("x", R.Tensor((2, 3, 16, 32, 32), "float32")) @@ -480,67 +458,63 @@ def test_max_pool3d_infer_struct_info(): x7 = relax.Var("x", R.Tensor((2, 3, 16, 32, 32), "float32", vdev0)) _check_inference( - bb, relax.op.nn.max_pool3d(x0), relax.TensorStructInfo((2, 3, 16, 32, 32), "float32") + bb, relax.op.nn.max_pool3d(x0), relax.TensorType((2, 3, 16, 32, 32), "float32") ) _check_inference( - bb, relax.op.nn.max_pool3d(x7), relax.TensorStructInfo((2, 3, 16, 32, 32), "float32", vdev0) + bb, relax.op.nn.max_pool3d(x7), relax.TensorType((2, 3, 16, 32, 32), "float32", vdev0) ) _check_inference( bb, relax.op.nn.max_pool3d(x0, pool_size=3), - relax.TensorStructInfo((2, 3, 14, 30, 30), "float32"), + relax.TensorType((2, 3, 14, 30, 30), "float32"), ) _check_inference( bb, relax.op.nn.max_pool3d(x0, pool_size=(3, 5, 3)), - relax.TensorStructInfo((2, 3, 14, 28, 30), "float32"), + relax.TensorType((2, 3, 14, 28, 30), "float32"), ) _check_inference( bb, relax.op.nn.max_pool3d(x0, padding=1), - relax.TensorStructInfo((2, 3, 18, 34, 34), "float32"), + relax.TensorType((2, 3, 18, 34, 34), "float32"), ) _check_inference( bb, relax.op.nn.max_pool3d(x0, padding=[1, 2, 3]), - relax.TensorStructInfo((2, 3, 18, 36, 38), "float32"), + relax.TensorType((2, 3, 18, 36, 38), "float32"), ) _check_inference( bb, relax.op.nn.max_pool3d(x0, strides=2), - relax.TensorStructInfo((2, 3, 8, 16, 16), "float32"), + relax.TensorType((2, 3, 8, 16, 16), "float32"), ) _check_inference( bb, relax.op.nn.max_pool3d(x0, dilation=2), - relax.TensorStructInfo((2, 3, 16, 32, 32), "float32"), + relax.TensorType((2, 3, 16, 32, 32), "float32"), ) _check_inference( bb, relax.op.nn.max_pool3d(x1, layout="NDHWC"), - relax.TensorStructInfo((2, 16, 32, 32, 3), "float32"), + relax.TensorType((2, 16, 32, 32, 3), "float32"), ) _check_inference( bb, relax.op.nn.max_pool3d(x0, out_layout="NDHWC"), - relax.TensorStructInfo((2, 16, 32, 32, 3), "float32"), + relax.TensorType((2, 16, 32, 32, 3), "float32"), ) _check_inference( bb, relax.op.nn.max_pool3d(x6, layout="NCDHW16c", out_layout="NDHWC16c"), - relax.TensorStructInfo((2, 16, 32, 32, 4, 16), "float32"), + relax.TensorType((2, 16, 32, 32, 4, 16), "float32"), ) - _check_inference( - bb, relax.op.nn.max_pool3d(x2), relax.TensorStructInfo(dtype="float32", ndim=5) - ) - _check_inference( - bb, relax.op.nn.max_pool3d(x3), relax.TensorStructInfo(dtype="float32", ndim=5) - ) - _check_inference(bb, relax.op.nn.max_pool3d(x4), relax.TensorStructInfo(dtype="", ndim=5)) - _check_inference(bb, relax.op.nn.max_pool3d(x5), relax.TensorStructInfo(dtype="", ndim=5)) + _check_inference(bb, relax.op.nn.max_pool3d(x2), relax.TensorType(dtype="float32", ndim=5)) + _check_inference(bb, relax.op.nn.max_pool3d(x3), relax.TensorType(dtype="float32", ndim=5)) + _check_inference(bb, relax.op.nn.max_pool3d(x4), relax.TensorType(dtype="", ndim=5)) + _check_inference(bb, relax.op.nn.max_pool3d(x5), relax.TensorType(dtype="", ndim=5)) -def test_max_pool3d_infer_struct_info_shape_symbolic(): +def test_max_pool3d_infer_ty_shape_symbolic(): bb = relax.BlockBuilder() n = tirx.Var("n", "int64") c = tirx.Var("c", "int64") @@ -556,7 +530,7 @@ def test_max_pool3d_infer_struct_info_shape_symbolic(): relax.op.nn.max_pool3d( x0, pool_size=(3, 3, 3), strides=(3, 3, 3), padding=(2, 2, 2), dilation=(2, 2, 2) ), - relax.TensorStructInfo( + relax.TensorType( ( n, c, @@ -571,51 +545,49 @@ def test_max_pool3d_infer_struct_info_shape_symbolic(): _check_inference( bb, relax.op.nn.max_pool3d(x1, layout="NCDHW16c", out_layout="NDHWC"), - relax.TensorStructInfo((n, id, ih, iw, c * 16), "float32"), + relax.TensorType((n, id, ih, iw, c * 16), "float32"), ) -def test_max_pool3d_infer_struct_info_shape_var(): +def test_max_pool3d_infer_ty_shape_var(): bb = relax.BlockBuilder() - s0 = relax.Var("s", relax.ShapeStructInfo(ndim=5)) - s1 = relax.Var("s", relax.ShapeStructInfo(ndim=6)) - s2 = relax.Var("s", relax.ShapeStructInfo()) - x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) - x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) - x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + s0 = relax.Var("s", relax.ShapeType(ndim=5)) + s1 = relax.Var("s", relax.ShapeType(ndim=6)) + s2 = relax.Var("s", relax.ShapeType()) + x0 = relax.Var("x", relax.TensorType(s0, "float32")) + x1 = relax.Var("x", relax.TensorType(s1, "float32")) + x2 = relax.Var("x", relax.TensorType(s2, "float32")) - _check_inference( - bb, relax.op.nn.max_pool3d(x0), relax.TensorStructInfo(dtype="float32", ndim=5) - ) + _check_inference(bb, relax.op.nn.max_pool3d(x0), relax.TensorType(dtype="float32", ndim=5)) _check_inference( bb, relax.op.nn.max_pool3d(x1, layout="NCDHW16c"), - relax.TensorStructInfo(dtype="float32", ndim=6), + relax.TensorType(dtype="float32", ndim=6), ) _check_inference( bb, relax.op.nn.max_pool3d(x2), - relax.TensorStructInfo(dtype="float32", ndim=5), + relax.TensorType(dtype="float32", ndim=5), ) -def test_max_pool3d_infer_struct_info_ceil_mode(): +def test_max_pool3d_infer_ty_ceil_mode(): bb = relax.BlockBuilder() x = relax.Var("x", R.Tensor((2, 3, 32, 32, 32), "float32")) _check_inference( bb, relax.op.nn.max_pool3d(x, pool_size=3, strides=2, ceil_mode=True), - relax.TensorStructInfo((2, 3, 16, 16, 16), "float32"), + relax.TensorType((2, 3, 16, 16, 16), "float32"), ) _check_inference( bb, relax.op.nn.max_pool3d(x, pool_size=(5, 3, 3), strides=2, ceil_mode=True), - relax.TensorStructInfo((2, 3, 15, 16, 16), "float32"), + relax.TensorType((2, 3, 15, 16, 16), "float32"), ) -def test_max_pool3d_infer_struct_info_ceil_mode_symbolic(): +def test_max_pool3d_infer_ty_ceil_mode_symbolic(): bb = relax.BlockBuilder() n = tirx.Var("n", "int64") c = tirx.Var("c", "int64") @@ -634,27 +606,23 @@ def test_max_pool3d_infer_struct_info_ceil_mode_symbolic(): dilation=(2, 2, 2), ceil_mode=True, ), - relax.TensorStructInfo( + relax.TensorType( (n, c, tvm.tirx.floordiv(id_, 2), tvm.tirx.floordiv(ih, 2), tvm.tirx.floordiv(iw, 2)), "float32", ), ) -def test_max_pool3d_infer_struct_info_more_input_dtype(): +def test_max_pool3d_infer_ty_more_input_dtype(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3, 32, 32, 32), "float16")) x1 = relax.Var("x", R.Tensor((2, 3, 32, 32, 32), "int8")) x2 = relax.Var("x", R.Tensor((2, 3, 32, 32, 32), "int64")) _check_inference( - bb, relax.op.nn.max_pool3d(x0), relax.TensorStructInfo((2, 3, 32, 32, 32), "float16") - ) - _check_inference( - bb, relax.op.nn.max_pool3d(x1), relax.TensorStructInfo((2, 3, 32, 32, 32), "int8") - ) - _check_inference( - bb, relax.op.nn.max_pool3d(x2), relax.TensorStructInfo((2, 3, 32, 32, 32), "int64") + bb, relax.op.nn.max_pool3d(x0), relax.TensorType((2, 3, 32, 32, 32), "float16") ) + _check_inference(bb, relax.op.nn.max_pool3d(x1), relax.TensorType((2, 3, 32, 32, 32), "int8")) + _check_inference(bb, relax.op.nn.max_pool3d(x2), relax.TensorType((2, 3, 32, 32, 32), "int64")) def test_max_pool3d_stride_padding_dilation_int64(): @@ -688,7 +656,7 @@ def test_max_pool3d_wrong_pool_size_strides_padding_dilation_length(): relax.op.nn.max_pool3d(x, dilation=(1, 2, 3, 4)) -def test_max_pool3d_infer_struct_info_wrong_layout_string(): +def test_max_pool3d_infer_ty_wrong_layout_string(): bb = relax.BlockBuilder() x = relax.Var("x", R.Tensor((2, 3, 28, 28, 28), "float32")) with pytest.raises(ValueError): @@ -707,10 +675,10 @@ def test_max_pool3d_wrong_input_ndim(): bb.normalize(relax.op.nn.max_pool3d(x1)) -def test_max_pool3d_infer_struct_info_wrong_input_type(): +def test_max_pool3d_infer_ty_wrong_input_type(): bb = relax.BlockBuilder() - x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 28, 28, 28))) - x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 28, 28, 28), "float32"))) + x0 = relax.Var("x", relax.ShapeType((2, 3, 28, 28, 28))) + x1 = relax.Var("x", relax.FuncType([], R.Tensor((2, 3, 28, 28, 28), "float32"))) with pytest.raises(TypeError): bb.normalize(relax.op.nn.max_pool3d(x0)) @@ -718,7 +686,7 @@ def test_max_pool3d_infer_struct_info_wrong_input_type(): bb.normalize(relax.op.nn.max_pool3d(x1)) -def test_avg_pool1d_infer_struct_info(): +def test_avg_pool1d_infer_ty(): bb = relax.BlockBuilder() vdev0 = VDevice("llvm") x0 = relax.Var("x", R.Tensor((2, 3, 32), "float32")) @@ -730,56 +698,50 @@ def test_avg_pool1d_infer_struct_info(): x6 = relax.Var("x", R.Tensor((2, 4, 32, 16), "float32")) x7 = relax.Var("x", R.Tensor((2, 3, 32), "float32", vdev0)) - _check_inference(bb, relax.op.nn.avg_pool1d(x0), relax.TensorStructInfo((2, 3, 32), "float32")) - _check_inference( - bb, relax.op.nn.avg_pool1d(x7), relax.TensorStructInfo((2, 3, 32), "float32", vdev0) - ) + _check_inference(bb, relax.op.nn.avg_pool1d(x0), relax.TensorType((2, 3, 32), "float32")) + _check_inference(bb, relax.op.nn.avg_pool1d(x7), relax.TensorType((2, 3, 32), "float32", vdev0)) _check_inference( bb, relax.op.nn.avg_pool1d(x0, pool_size=3), - relax.TensorStructInfo((2, 3, 30), "float32"), + relax.TensorType((2, 3, 30), "float32"), ) _check_inference( bb, relax.op.nn.avg_pool1d(x0, padding=1), - relax.TensorStructInfo((2, 3, 34), "float32"), + relax.TensorType((2, 3, 34), "float32"), ) _check_inference( bb, relax.op.nn.avg_pool1d(x0, padding=[1, 2]), - relax.TensorStructInfo((2, 3, 35), "float32"), + relax.TensorType((2, 3, 35), "float32"), ) _check_inference( bb, relax.op.nn.avg_pool1d(x0, strides=2), - relax.TensorStructInfo((2, 3, 16), "float32"), + relax.TensorType((2, 3, 16), "float32"), ) _check_inference( bb, relax.op.nn.avg_pool1d(x0, dilation=2), - relax.TensorStructInfo((2, 3, 32), "float32"), + relax.TensorType((2, 3, 32), "float32"), ) _check_inference( bb, relax.op.nn.avg_pool1d(x1, layout="NWC"), - relax.TensorStructInfo((2, 32, 3), "float32"), + relax.TensorType((2, 32, 3), "float32"), ) _check_inference( bb, relax.op.nn.avg_pool1d(x0, out_layout="NWC"), - relax.TensorStructInfo((2, 32, 3), "float32"), - ) - _check_inference( - bb, relax.op.nn.avg_pool1d(x2), relax.TensorStructInfo(dtype="float32", ndim=3) - ) - _check_inference( - bb, relax.op.nn.avg_pool1d(x3), relax.TensorStructInfo(dtype="float32", ndim=3) + relax.TensorType((2, 32, 3), "float32"), ) - _check_inference(bb, relax.op.nn.avg_pool1d(x4), relax.TensorStructInfo(dtype="", ndim=3)) - _check_inference(bb, relax.op.nn.avg_pool1d(x5), relax.TensorStructInfo(dtype="", ndim=3)) + _check_inference(bb, relax.op.nn.avg_pool1d(x2), relax.TensorType(dtype="float32", ndim=3)) + _check_inference(bb, relax.op.nn.avg_pool1d(x3), relax.TensorType(dtype="float32", ndim=3)) + _check_inference(bb, relax.op.nn.avg_pool1d(x4), relax.TensorType(dtype="", ndim=3)) + _check_inference(bb, relax.op.nn.avg_pool1d(x5), relax.TensorType(dtype="", ndim=3)) -def test_avg_pool1d_infer_struct_info_shape_symbolic(): +def test_avg_pool1d_infer_ty_shape_symbolic(): bb = relax.BlockBuilder() n = tirx.Var("n", "int64") c = tirx.Var("c", "int64") @@ -791,7 +753,7 @@ def test_avg_pool1d_infer_struct_info_shape_symbolic(): _check_inference( bb, relax.op.nn.avg_pool1d(x0, pool_size=3, strides=3, padding=2, dilation=2), - relax.TensorStructInfo( + relax.TensorType( ( n, c, @@ -803,51 +765,49 @@ def test_avg_pool1d_infer_struct_info_shape_symbolic(): _check_inference( bb, relax.op.nn.avg_pool1d(x1, layout="NCW16c", out_layout="NWC"), - relax.TensorStructInfo((n, iw, c * 16), "float32"), + relax.TensorType((n, iw, c * 16), "float32"), ) -def test_avg_pool1d_infer_struct_info_shape_var(): +def test_avg_pool1d_infer_ty_shape_var(): bb = relax.BlockBuilder() - s0 = relax.Var("s", relax.ShapeStructInfo(ndim=3)) - s1 = relax.Var("s", relax.ShapeStructInfo(ndim=4)) - s2 = relax.Var("s", relax.ShapeStructInfo()) - x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) - x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) - x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + s0 = relax.Var("s", relax.ShapeType(ndim=3)) + s1 = relax.Var("s", relax.ShapeType(ndim=4)) + s2 = relax.Var("s", relax.ShapeType()) + x0 = relax.Var("x", relax.TensorType(s0, "float32")) + x1 = relax.Var("x", relax.TensorType(s1, "float32")) + x2 = relax.Var("x", relax.TensorType(s2, "float32")) - _check_inference( - bb, relax.op.nn.avg_pool1d(x0), relax.TensorStructInfo(dtype="float32", ndim=3) - ) + _check_inference(bb, relax.op.nn.avg_pool1d(x0), relax.TensorType(dtype="float32", ndim=3)) _check_inference( bb, relax.op.nn.avg_pool1d(x1, layout="NCW16c"), - relax.TensorStructInfo(dtype="float32", ndim=4), + relax.TensorType(dtype="float32", ndim=4), ) _check_inference( bb, relax.op.nn.avg_pool1d(x2), - relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorType(dtype="float32", ndim=3), ) -def test_avg_pool1d_infer_struct_info_ceil_mode(): +def test_avg_pool1d_infer_ty_ceil_mode(): bb = relax.BlockBuilder() x = relax.Var("x", R.Tensor((2, 3, 32), "float32")) _check_inference( bb, relax.op.nn.avg_pool1d(x, pool_size=3, strides=2, ceil_mode=True), - relax.TensorStructInfo((2, 3, 16), "float32"), + relax.TensorType((2, 3, 16), "float32"), ) _check_inference( bb, relax.op.nn.avg_pool1d(x, pool_size=5, strides=2, ceil_mode=True), - relax.TensorStructInfo((2, 3, 15), "float32"), + relax.TensorType((2, 3, 15), "float32"), ) -def test_avg_pool1d_infer_struct_info_ceil_mode_symbolic(): +def test_avg_pool1d_infer_ty_ceil_mode_symbolic(): bb = relax.BlockBuilder() n = tirx.Var("n", "int64") c = tirx.Var("c", "int64") @@ -857,21 +817,21 @@ def test_avg_pool1d_infer_struct_info_ceil_mode_symbolic(): _check_inference( bb, relax.op.nn.avg_pool1d(x, pool_size=3, strides=2, padding=1, dilation=2, ceil_mode=True), - relax.TensorStructInfo( + relax.TensorType( (n, c, tvm.tirx.floordiv(iw, 2)), "float32", ), ) -def test_avg_pool1d_infer_struct_info_more_input_dtype(): +def test_avg_pool1d_infer_ty_more_input_dtype(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3, 32), "float16")) x1 = relax.Var("x", R.Tensor((2, 3, 32), "int8")) x2 = relax.Var("x", R.Tensor((2, 3, 32), "int64")) - _check_inference(bb, relax.op.nn.avg_pool1d(x0), relax.TensorStructInfo((2, 3, 32), "float16")) - _check_inference(bb, relax.op.nn.avg_pool1d(x1), relax.TensorStructInfo((2, 3, 32), "int8")) - _check_inference(bb, relax.op.nn.avg_pool1d(x2), relax.TensorStructInfo((2, 3, 32), "int64")) + _check_inference(bb, relax.op.nn.avg_pool1d(x0), relax.TensorType((2, 3, 32), "float16")) + _check_inference(bb, relax.op.nn.avg_pool1d(x1), relax.TensorType((2, 3, 32), "int8")) + _check_inference(bb, relax.op.nn.avg_pool1d(x2), relax.TensorType((2, 3, 32), "int64")) def test_avg_pool1d_stride_padding_dilation_int64(): @@ -896,7 +856,7 @@ def test_avg_pool1d_wrong_pool_size_strides_padding_dilation_length(): relax.op.nn.avg_pool1d(x, dilation=(1, 2)) -def test_avg_pool1d_infer_struct_info_wrong_layout_string(): +def test_avg_pool1d_infer_ty_wrong_layout_string(): bb = relax.BlockBuilder() x = relax.Var("x", R.Tensor((2, 3, 28), "float32")) with pytest.raises(ValueError): @@ -915,10 +875,10 @@ def test_avg_pool1d_wrong_input_ndim(): bb.normalize(relax.op.nn.avg_pool1d(x1)) -def test_avg_pool1d_infer_struct_info_wrong_input_type(): +def test_avg_pool1d_infer_ty_wrong_input_type(): bb = relax.BlockBuilder() - x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 28))) - x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 28), "float32"))) + x0 = relax.Var("x", relax.ShapeType((2, 3, 28))) + x1 = relax.Var("x", relax.FuncType([], R.Tensor((2, 3, 28), "float32"))) with pytest.raises(TypeError): bb.normalize(relax.op.nn.avg_pool1d(x0)) @@ -926,7 +886,7 @@ def test_avg_pool1d_infer_struct_info_wrong_input_type(): bb.normalize(relax.op.nn.avg_pool1d(x1)) -def test_avg_pool2d_infer_struct_info(): +def test_avg_pool2d_infer_ty(): bb = relax.BlockBuilder() vdev0 = VDevice("llvm") x0 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32")) @@ -938,66 +898,60 @@ def test_avg_pool2d_infer_struct_info(): x6 = relax.Var("x", R.Tensor((2, 4, 32, 32, 16), "float32")) x7 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32", vdev0)) + _check_inference(bb, relax.op.nn.avg_pool2d(x0), relax.TensorType((2, 3, 32, 32), "float32")) _check_inference( - bb, relax.op.nn.avg_pool2d(x0), relax.TensorStructInfo((2, 3, 32, 32), "float32") - ) - _check_inference( - bb, relax.op.nn.avg_pool2d(x7), relax.TensorStructInfo((2, 3, 32, 32), "float32", vdev0) + bb, relax.op.nn.avg_pool2d(x7), relax.TensorType((2, 3, 32, 32), "float32", vdev0) ) _check_inference( bb, relax.op.nn.avg_pool2d(x0, pool_size=3), - relax.TensorStructInfo((2, 3, 30, 30), "float32"), + relax.TensorType((2, 3, 30, 30), "float32"), ) _check_inference( bb, relax.op.nn.avg_pool2d(x0, pool_size=(5, 3)), - relax.TensorStructInfo((2, 3, 28, 30), "float32"), + relax.TensorType((2, 3, 28, 30), "float32"), ) _check_inference( - bb, relax.op.nn.avg_pool2d(x0, padding=1), relax.TensorStructInfo((2, 3, 34, 34), "float32") + bb, relax.op.nn.avg_pool2d(x0, padding=1), relax.TensorType((2, 3, 34, 34), "float32") ) _check_inference( bb, relax.op.nn.avg_pool2d(x0, padding=[1, 2]), - relax.TensorStructInfo((2, 3, 34, 36), "float32"), + relax.TensorType((2, 3, 34, 36), "float32"), ) _check_inference( bb, relax.op.nn.avg_pool2d(x0, strides=2), - relax.TensorStructInfo((2, 3, 16, 16), "float32"), + relax.TensorType((2, 3, 16, 16), "float32"), ) _check_inference( bb, relax.op.nn.avg_pool2d(x0, dilation=2), - relax.TensorStructInfo((2, 3, 32, 32), "float32"), + relax.TensorType((2, 3, 32, 32), "float32"), ) _check_inference( bb, relax.op.nn.avg_pool2d(x1, layout="NHWC"), - relax.TensorStructInfo((2, 32, 32, 3), "float32"), + relax.TensorType((2, 32, 32, 3), "float32"), ) _check_inference( bb, relax.op.nn.avg_pool2d(x0, out_layout="NHWC"), - relax.TensorStructInfo((2, 32, 32, 3), "float32"), + relax.TensorType((2, 32, 32, 3), "float32"), ) _check_inference( bb, relax.op.nn.avg_pool2d(x6, layout="NCHW16c", out_layout="NHWC16c"), - relax.TensorStructInfo((2, 32, 32, 4, 16), "float32"), - ) - _check_inference( - bb, relax.op.nn.avg_pool2d(x2), relax.TensorStructInfo(dtype="float32", ndim=4) - ) - _check_inference( - bb, relax.op.nn.avg_pool2d(x3), relax.TensorStructInfo(dtype="float32", ndim=4) + relax.TensorType((2, 32, 32, 4, 16), "float32"), ) - _check_inference(bb, relax.op.nn.avg_pool2d(x4), relax.TensorStructInfo(dtype="", ndim=4)) - _check_inference(bb, relax.op.nn.avg_pool2d(x5), relax.TensorStructInfo(dtype="", ndim=4)) + _check_inference(bb, relax.op.nn.avg_pool2d(x2), relax.TensorType(dtype="float32", ndim=4)) + _check_inference(bb, relax.op.nn.avg_pool2d(x3), relax.TensorType(dtype="float32", ndim=4)) + _check_inference(bb, relax.op.nn.avg_pool2d(x4), relax.TensorType(dtype="", ndim=4)) + _check_inference(bb, relax.op.nn.avg_pool2d(x5), relax.TensorType(dtype="", ndim=4)) -def test_avg_pool2d_infer_struct_info_shape_symbolic(): +def test_avg_pool2d_infer_ty_shape_symbolic(): bb = relax.BlockBuilder() n = tirx.Var("n", "int64") c = tirx.Var("c", "int64") @@ -1012,7 +966,7 @@ def test_avg_pool2d_infer_struct_info_shape_symbolic(): relax.op.nn.avg_pool2d( x0, pool_size=(3, 3), strides=(3, 3), padding=(2, 2), dilation=(2, 2) ), - relax.TensorStructInfo( + relax.TensorType( ( n, c, @@ -1025,51 +979,49 @@ def test_avg_pool2d_infer_struct_info_shape_symbolic(): _check_inference( bb, relax.op.nn.avg_pool2d(x1, layout="NCHW16c", out_layout="NHWC"), - relax.TensorStructInfo((n, ih, iw, c * 16), "float32"), + relax.TensorType((n, ih, iw, c * 16), "float32"), ) -def test_avg_pool2d_infer_struct_info_shape_var(): +def test_avg_pool2d_infer_ty_shape_var(): bb = relax.BlockBuilder() - s0 = relax.Var("s", relax.ShapeStructInfo(ndim=4)) - s1 = relax.Var("s", relax.ShapeStructInfo(ndim=5)) - s2 = relax.Var("s", relax.ShapeStructInfo()) - x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) - x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) - x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + s0 = relax.Var("s", relax.ShapeType(ndim=4)) + s1 = relax.Var("s", relax.ShapeType(ndim=5)) + s2 = relax.Var("s", relax.ShapeType()) + x0 = relax.Var("x", relax.TensorType(s0, "float32")) + x1 = relax.Var("x", relax.TensorType(s1, "float32")) + x2 = relax.Var("x", relax.TensorType(s2, "float32")) - _check_inference( - bb, relax.op.nn.avg_pool2d(x0), relax.TensorStructInfo(dtype="float32", ndim=4) - ) + _check_inference(bb, relax.op.nn.avg_pool2d(x0), relax.TensorType(dtype="float32", ndim=4)) _check_inference( bb, relax.op.nn.avg_pool2d(x1, layout="NCHW16c"), - relax.TensorStructInfo(dtype="float32", ndim=5), + relax.TensorType(dtype="float32", ndim=5), ) _check_inference( bb, relax.op.nn.avg_pool2d(x2), - relax.TensorStructInfo(dtype="float32", ndim=4), + relax.TensorType(dtype="float32", ndim=4), ) -def test_avg_pool2d_infer_struct_info_ceil_mode(): +def test_avg_pool2d_infer_ty_ceil_mode(): bb = relax.BlockBuilder() x = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32")) _check_inference( bb, relax.op.nn.avg_pool2d(x, pool_size=3, strides=2, ceil_mode=True), - relax.TensorStructInfo((2, 3, 16, 16), "float32"), + relax.TensorType((2, 3, 16, 16), "float32"), ) _check_inference( bb, relax.op.nn.avg_pool2d(x, pool_size=(5, 3), strides=2, ceil_mode=True), - relax.TensorStructInfo((2, 3, 15, 16), "float32"), + relax.TensorType((2, 3, 15, 16), "float32"), ) -def test_avg_pool2d_infer_struct_info_ceil_mode_symbolic(): +def test_avg_pool2d_infer_ty_ceil_mode_symbolic(): bb = relax.BlockBuilder() n = tirx.Var("n", "int64") c = tirx.Var("c", "int64") @@ -1082,24 +1034,18 @@ def test_avg_pool2d_infer_struct_info_ceil_mode_symbolic(): relax.op.nn.avg_pool2d( x, pool_size=(3, 3), strides=(2, 2), padding=(1, 1), dilation=(2, 2), ceil_mode=True ), - relax.TensorStructInfo( - (n, c, tvm.tirx.floordiv(ih, 2), tvm.tirx.floordiv(iw, 2)), "float32" - ), + relax.TensorType((n, c, tvm.tirx.floordiv(ih, 2), tvm.tirx.floordiv(iw, 2)), "float32"), ) -def test_avg_pool2d_infer_struct_info_more_input_dtype(): +def test_avg_pool2d_infer_ty_more_input_dtype(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float16")) x1 = relax.Var("x", R.Tensor((2, 3, 32, 32), "int8")) x2 = relax.Var("x", R.Tensor((2, 3, 32, 32), "int64")) - _check_inference( - bb, relax.op.nn.avg_pool2d(x0), relax.TensorStructInfo((2, 3, 32, 32), "float16") - ) - _check_inference(bb, relax.op.nn.avg_pool2d(x1), relax.TensorStructInfo((2, 3, 32, 32), "int8")) - _check_inference( - bb, relax.op.nn.avg_pool2d(x2), relax.TensorStructInfo((2, 3, 32, 32), "int64") - ) + _check_inference(bb, relax.op.nn.avg_pool2d(x0), relax.TensorType((2, 3, 32, 32), "float16")) + _check_inference(bb, relax.op.nn.avg_pool2d(x1), relax.TensorType((2, 3, 32, 32), "int8")) + _check_inference(bb, relax.op.nn.avg_pool2d(x2), relax.TensorType((2, 3, 32, 32), "int64")) def test_avg_pool2d_stride_padding_dilation_int64(): @@ -1128,7 +1074,7 @@ def test_avg_pool2d_wrong_pool_size_strides_padding_dilation_length(): relax.op.nn.avg_pool2d(x, dilation=(1, 2, 3)) -def test_avg_pool2d_infer_struct_info_wrong_layout_string(): +def test_avg_pool2d_infer_ty_wrong_layout_string(): bb = relax.BlockBuilder() x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) with pytest.raises(ValueError): @@ -1147,10 +1093,10 @@ def test_avg_pool2d_wrong_input_ndim(): bb.normalize(relax.op.nn.avg_pool2d(x1)) -def test_avg_pool2d_infer_struct_info_wrong_input_type(): +def test_avg_pool2d_infer_ty_wrong_input_type(): bb = relax.BlockBuilder() - x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 28, 28))) - x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 28, 28), "float32"))) + x0 = relax.Var("x", relax.ShapeType((2, 3, 28, 28))) + x1 = relax.Var("x", relax.FuncType([], R.Tensor((2, 3, 28, 28), "float32"))) with pytest.raises(TypeError): bb.normalize(relax.op.nn.avg_pool2d(x0)) @@ -1158,7 +1104,7 @@ def test_avg_pool2d_infer_struct_info_wrong_input_type(): bb.normalize(relax.op.nn.avg_pool2d(x1)) -def test_avg_pool3d_infer_struct_info(): +def test_avg_pool3d_infer_ty(): bb = relax.BlockBuilder() vdev0 = VDevice("llvm") @@ -1172,67 +1118,63 @@ def test_avg_pool3d_infer_struct_info(): x7 = relax.Var("x", R.Tensor((2, 3, 32, 32, 32), "float32", vdev0)) _check_inference( - bb, relax.op.nn.avg_pool3d(x0), relax.TensorStructInfo((2, 3, 32, 32, 32), "float32") + bb, relax.op.nn.avg_pool3d(x0), relax.TensorType((2, 3, 32, 32, 32), "float32") ) _check_inference( - bb, relax.op.nn.avg_pool3d(x7), relax.TensorStructInfo((2, 3, 32, 32, 32), "float32", vdev0) + bb, relax.op.nn.avg_pool3d(x7), relax.TensorType((2, 3, 32, 32, 32), "float32", vdev0) ) _check_inference( bb, relax.op.nn.avg_pool3d(x0, pool_size=3), - relax.TensorStructInfo((2, 3, 30, 30, 30), "float32"), + relax.TensorType((2, 3, 30, 30, 30), "float32"), ) _check_inference( bb, relax.op.nn.avg_pool3d(x0, pool_size=(5, 3, 3)), - relax.TensorStructInfo((2, 3, 28, 30, 30), "float32"), + relax.TensorType((2, 3, 28, 30, 30), "float32"), ) _check_inference( bb, relax.op.nn.avg_pool3d(x0, padding=1), - relax.TensorStructInfo((2, 3, 34, 34, 34), "float32"), + relax.TensorType((2, 3, 34, 34, 34), "float32"), ) _check_inference( bb, relax.op.nn.avg_pool3d(x0, padding=[1, 2, 3]), - relax.TensorStructInfo((2, 3, 34, 36, 38), "float32"), + relax.TensorType((2, 3, 34, 36, 38), "float32"), ) _check_inference( bb, relax.op.nn.avg_pool3d(x0, strides=2), - relax.TensorStructInfo((2, 3, 16, 16, 16), "float32"), + relax.TensorType((2, 3, 16, 16, 16), "float32"), ) _check_inference( bb, relax.op.nn.avg_pool3d(x0, dilation=2), - relax.TensorStructInfo((2, 3, 32, 32, 32), "float32"), + relax.TensorType((2, 3, 32, 32, 32), "float32"), ) _check_inference( bb, relax.op.nn.avg_pool3d(x1, layout="NCDHW"), - relax.TensorStructInfo((2, 32, 32, 32, 3), "float32"), + relax.TensorType((2, 32, 32, 32, 3), "float32"), ) _check_inference( bb, relax.op.nn.avg_pool3d(x0, out_layout="NCDHW"), - relax.TensorStructInfo((2, 3, 32, 32, 32), "float32"), + relax.TensorType((2, 3, 32, 32, 32), "float32"), ) _check_inference( bb, relax.op.nn.avg_pool3d(x6, layout="NCDHW16c", out_layout="NDHWC16c"), - relax.TensorStructInfo((2, 32, 32, 32, 4, 16), "float32"), - ) - _check_inference( - bb, relax.op.nn.avg_pool3d(x2), relax.TensorStructInfo(dtype="float32", ndim=5) - ) - _check_inference( - bb, relax.op.nn.avg_pool3d(x3), relax.TensorStructInfo(dtype="float32", ndim=5) + relax.TensorType((2, 32, 32, 32, 4, 16), "float32"), ) - _check_inference(bb, relax.op.nn.avg_pool3d(x4), relax.TensorStructInfo(dtype="", ndim=5)) - _check_inference(bb, relax.op.nn.avg_pool3d(x5), relax.TensorStructInfo(dtype="", ndim=5)) + _check_inference(bb, relax.op.nn.avg_pool3d(x2), relax.TensorType(dtype="float32", ndim=5)) + _check_inference(bb, relax.op.nn.avg_pool3d(x3), relax.TensorType(dtype="float32", ndim=5)) + _check_inference(bb, relax.op.nn.avg_pool3d(x4), relax.TensorType(dtype="", ndim=5)) + _check_inference(bb, relax.op.nn.avg_pool3d(x5), relax.TensorType(dtype="", ndim=5)) -def test_avg_pool3d_infer_struct_info_shape_symbolic(): +def test_avg_pool3d_infer_ty_shape_symbolic(): bb = relax.BlockBuilder() n = tirx.Var("n", "int64") c = tirx.Var("c", "int64") @@ -1248,7 +1190,7 @@ def test_avg_pool3d_infer_struct_info_shape_symbolic(): relax.op.nn.avg_pool3d( x0, pool_size=(3, 3, 3), strides=(3, 3, 3), padding=(2, 2, 2), dilation=(2, 2, 2) ), - relax.TensorStructInfo( + relax.TensorType( ( n, c, @@ -1262,51 +1204,49 @@ def test_avg_pool3d_infer_struct_info_shape_symbolic(): _check_inference( bb, relax.op.nn.avg_pool3d(x1, layout="NCDHW16c", out_layout="NDHWC"), - relax.TensorStructInfo((n, id_, ih, iw, c * 16), "float32"), + relax.TensorType((n, id_, ih, iw, c * 16), "float32"), ) -def test_avg_pool3d_infer_struct_info_shape_var(): +def test_avg_pool3d_infer_ty_shape_var(): bb = relax.BlockBuilder() - s0 = relax.Var("s", relax.ShapeStructInfo(ndim=5)) - s1 = relax.Var("s", relax.ShapeStructInfo(ndim=6)) - s2 = relax.Var("s", relax.ShapeStructInfo()) - x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) - x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) - x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + s0 = relax.Var("s", relax.ShapeType(ndim=5)) + s1 = relax.Var("s", relax.ShapeType(ndim=6)) + s2 = relax.Var("s", relax.ShapeType()) + x0 = relax.Var("x", relax.TensorType(s0, "float32")) + x1 = relax.Var("x", relax.TensorType(s1, "float32")) + x2 = relax.Var("x", relax.TensorType(s2, "float32")) - _check_inference( - bb, relax.op.nn.avg_pool3d(x0), relax.TensorStructInfo(dtype="float32", ndim=5) - ) + _check_inference(bb, relax.op.nn.avg_pool3d(x0), relax.TensorType(dtype="float32", ndim=5)) _check_inference( bb, relax.op.nn.avg_pool3d(x1, layout="NCDHW16c"), - relax.TensorStructInfo(dtype="float32", ndim=6), + relax.TensorType(dtype="float32", ndim=6), ) _check_inference( bb, relax.op.nn.avg_pool3d(x2), - relax.TensorStructInfo(dtype="float32", ndim=5), + relax.TensorType(dtype="float32", ndim=5), ) -def test_avg_pool3d_infer_struct_info_ceil_mode(): +def test_avg_pool3d_infer_ty_ceil_mode(): bb = relax.BlockBuilder() x = relax.Var("x", R.Tensor((2, 3, 32, 32, 32), "float32")) _check_inference( bb, relax.op.nn.avg_pool3d(x, pool_size=3, strides=2, ceil_mode=True), - relax.TensorStructInfo((2, 3, 16, 16, 16), "float32"), + relax.TensorType((2, 3, 16, 16, 16), "float32"), ) _check_inference( bb, relax.op.nn.avg_pool3d(x, pool_size=(5, 3, 3), strides=2, ceil_mode=True), - relax.TensorStructInfo((2, 3, 15, 16, 16), "float32"), + relax.TensorType((2, 3, 15, 16, 16), "float32"), ) -def test_avg_pool3d_infer_struct_info_ceil_mode_symbolic(): +def test_avg_pool3d_infer_ty_ceil_mode_symbolic(): bb = relax.BlockBuilder() n = tirx.Var("n", "int64") c = tirx.Var("c", "int64") @@ -1325,7 +1265,7 @@ def test_avg_pool3d_infer_struct_info_ceil_mode_symbolic(): dilation=(2, 2, 2), ceil_mode=True, ), - relax.TensorStructInfo( + relax.TensorType( ( n, c, @@ -1338,21 +1278,17 @@ def test_avg_pool3d_infer_struct_info_ceil_mode_symbolic(): ) -def test_avg_pool3d_infer_struct_info_more_input_dtype(): +def test_avg_pool3d_infer_ty_more_input_dtype(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3, 32, 32, 32), "float16")) x1 = relax.Var("x", R.Tensor((2, 3, 32, 32, 32), "int8")) x2 = relax.Var("x", R.Tensor((2, 3, 32, 32, 32), "int64")) _check_inference( - bb, relax.op.nn.avg_pool3d(x0), relax.TensorStructInfo((2, 3, 32, 32, 32), "float16") - ) - _check_inference( - bb, relax.op.nn.avg_pool3d(x1), relax.TensorStructInfo((2, 3, 32, 32, 32), "int8") - ) - _check_inference( - bb, relax.op.nn.avg_pool3d(x2), relax.TensorStructInfo((2, 3, 32, 32, 32), "int64") + bb, relax.op.nn.avg_pool3d(x0), relax.TensorType((2, 3, 32, 32, 32), "float16") ) + _check_inference(bb, relax.op.nn.avg_pool3d(x1), relax.TensorType((2, 3, 32, 32, 32), "int8")) + _check_inference(bb, relax.op.nn.avg_pool3d(x2), relax.TensorType((2, 3, 32, 32, 32), "int64")) def test_avg_pool3d_stride_padding_dilation_int64(): @@ -1384,7 +1320,7 @@ def test_avg_pool3d_wrong_pool_size_strides_padding_dilation_length(): relax.op.nn.avg_pool3d(x, dilation=(1, 2, 3, 4)) -def test_avg_pool3d_infer_struct_info_wrong_layout_string(): +def test_avg_pool3d_infer_ty_wrong_layout_string(): bb = relax.BlockBuilder() x = relax.Var("x", R.Tensor((2, 3, 28, 28, 28), "float32")) with pytest.raises(ValueError): @@ -1403,10 +1339,10 @@ def test_avg_pool3d_wrong_input_ndim(): bb.normalize(relax.op.nn.avg_pool3d(x1)) -def test_avg_pool3d_infer_struct_info_wrong_input_type(): +def test_avg_pool3d_infer_ty_wrong_input_type(): bb = relax.BlockBuilder() - x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 28, 28, 28))) - x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 28, 28, 28), "float32"))) + x0 = relax.Var("x", relax.ShapeType((2, 3, 28, 28, 28))) + x1 = relax.Var("x", relax.FuncType([], R.Tensor((2, 3, 28, 28, 28), "float32"))) with pytest.raises(TypeError): bb.normalize(relax.op.nn.avg_pool3d(x0)) @@ -1414,7 +1350,7 @@ def test_avg_pool3d_infer_struct_info_wrong_input_type(): bb.normalize(relax.op.nn.avg_pool3d(x1)) -def test_adaptive_avg_pool1d_infer_struct_info(): +def test_adaptive_avg_pool1d_infer_ty(): bb = relax.BlockBuilder() vdev0 = VDevice("llvm") @@ -1429,41 +1365,41 @@ def test_adaptive_avg_pool1d_infer_struct_info(): _check_inference( bb, relax.op.nn.adaptive_avg_pool1d(x0), - relax.TensorStructInfo((2, 3, 32), "float32"), + relax.TensorType((2, 3, 32), "float32"), ) _check_inference( bb, relax.op.nn.adaptive_avg_pool1d(x5), - relax.TensorStructInfo((2, 3, 32), "float32", vdev0), + relax.TensorType((2, 3, 32), "float32", vdev0), ) _check_inference( bb, relax.op.nn.adaptive_avg_pool1d(x0, output_size=16), - relax.TensorStructInfo((2, 3, 16), "float32"), + relax.TensorType((2, 3, 16), "float32"), ) _check_inference( bb, relax.op.nn.adaptive_avg_pool1d(x1), - relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorType(dtype="float32", ndim=3), ) _check_inference( bb, relax.op.nn.adaptive_avg_pool1d(x2), - relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorType(dtype="float32", ndim=3), ) _check_inference( bb, relax.op.nn.adaptive_avg_pool1d(x3), - relax.TensorStructInfo(dtype="", ndim=3), + relax.TensorType(dtype="", ndim=3), ) _check_inference( bb, relax.op.nn.adaptive_avg_pool1d(x4), - relax.TensorStructInfo(dtype="", ndim=3), + relax.TensorType(dtype="", ndim=3), ) -def test_adaptive_avg_pool1d_infer_struct_info_shape_symbolic(): +def test_adaptive_avg_pool1d_infer_ty_shape_symbolic(): bb = relax.BlockBuilder() n = tirx.Var("n", "int64") c = tirx.Var("c", "int64") @@ -1474,55 +1410,51 @@ def test_adaptive_avg_pool1d_infer_struct_info_shape_symbolic(): _check_inference( bb, relax.op.nn.adaptive_avg_pool1d(x0), - relax.TensorStructInfo((n, c, l), "float32"), + relax.TensorType((n, c, l), "float32"), ) _check_inference( bb, relax.op.nn.adaptive_avg_pool1d(x0, output_size=64), - relax.TensorStructInfo((n, c, 64), "float32"), + relax.TensorType((n, c, 64), "float32"), ) -def test_adaptive_avg_pool1d_infer_struct_info_shape_var(): +def test_adaptive_avg_pool1d_infer_ty_shape_var(): bb = relax.BlockBuilder() - s0 = relax.Var("s", relax.ShapeStructInfo(ndim=3)) - s1 = relax.Var("s", relax.ShapeStructInfo()) + s0 = relax.Var("s", relax.ShapeType(ndim=3)) + s1 = relax.Var("s", relax.ShapeType()) - x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) - x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x0 = relax.Var("x", relax.TensorType(s0, "float32")) + x1 = relax.Var("x", relax.TensorType(s1, "float32")) _check_inference( bb, relax.op.nn.adaptive_avg_pool1d(x0), - relax.TensorStructInfo(s0, "float32"), + relax.TensorType(s0, "float32"), ) _check_inference( bb, relax.op.nn.adaptive_avg_pool1d(x0, output_size=20), - relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorType(dtype="float32", ndim=3), ) _check_inference( bb, relax.op.nn.adaptive_avg_pool1d(x1), - relax.TensorStructInfo(s1, dtype="float32"), + relax.TensorType(s1, dtype="float32"), ) -def test_adaptive_avg_pool1d_infer_struct_info_more_input_dtype(): +def test_adaptive_avg_pool1d_infer_ty_more_input_dtype(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3, 64), "float16")) x1 = relax.Var("x", R.Tensor((2, 3, 64), "int8")) x2 = relax.Var("x", R.Tensor((2, 3, 64), "int64")) _check_inference( - bb, relax.op.nn.adaptive_avg_pool1d(x0), relax.TensorStructInfo((2, 3, 64), "float16") - ) - _check_inference( - bb, relax.op.nn.adaptive_avg_pool1d(x1), relax.TensorStructInfo((2, 3, 64), "int8") - ) - _check_inference( - bb, relax.op.nn.adaptive_avg_pool1d(x2), relax.TensorStructInfo((2, 3, 64), "int64") + bb, relax.op.nn.adaptive_avg_pool1d(x0), relax.TensorType((2, 3, 64), "float16") ) + _check_inference(bb, relax.op.nn.adaptive_avg_pool1d(x1), relax.TensorType((2, 3, 64), "int8")) + _check_inference(bb, relax.op.nn.adaptive_avg_pool1d(x2), relax.TensorType((2, 3, 64), "int64")) def test_adaptive_avg_pool1d_wrong_output_size_ndim(): @@ -1531,7 +1463,7 @@ def test_adaptive_avg_pool1d_wrong_output_size_ndim(): relax.op.nn.adaptive_avg_pool1d(x, output_size=(32, 32)) -def test_adaptive_avg_pool1d_infer_struct_info_wrong_layout_string(): +def test_adaptive_avg_pool1d_infer_ty_wrong_layout_string(): bb = relax.BlockBuilder() x = relax.Var("x", R.Tensor((2, 3, 64), "float32")) with pytest.raises(ValueError): @@ -1551,10 +1483,10 @@ def test_adaptive_avg_pool1d_wrong_input_ndim(): bb.normalize(relax.op.nn.adaptive_avg_pool1d(x1)) -def test_adaptive_avg_pool1d_infer_struct_info_wrong_input_type(): +def test_adaptive_avg_pool1d_infer_ty_wrong_input_type(): bb = relax.BlockBuilder() - x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 64))) - x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 64), "float32"))) + x0 = relax.Var("x", relax.ShapeType((2, 3, 64))) + x1 = relax.Var("x", relax.FuncType([], R.Tensor((2, 3, 64), "float32"))) with pytest.raises(TypeError): bb.normalize(relax.op.nn.adaptive_avg_pool1d(x0)) @@ -1562,7 +1494,7 @@ def test_adaptive_avg_pool1d_infer_struct_info_wrong_input_type(): bb.normalize(relax.op.nn.adaptive_avg_pool1d(x1)) -def test_adaptive_avg_pool2d_infer_struct_info(): +def test_adaptive_avg_pool2d_infer_ty(): bb = relax.BlockBuilder() vdev0 = VDevice("llvm") x0 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32")) @@ -1575,53 +1507,49 @@ def test_adaptive_avg_pool2d_infer_struct_info(): x7 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32", vdev0)) _check_inference( - bb, relax.op.nn.adaptive_avg_pool2d(x0), relax.TensorStructInfo((2, 3, 32, 32), "float32") + bb, relax.op.nn.adaptive_avg_pool2d(x0), relax.TensorType((2, 3, 32, 32), "float32") ) _check_inference( bb, relax.op.nn.adaptive_avg_pool2d(x7), - relax.TensorStructInfo((2, 3, 32, 32), "float32", vdev0), + relax.TensorType((2, 3, 32, 32), "float32", vdev0), ) _check_inference( bb, relax.op.nn.adaptive_avg_pool2d(x0, output_size=30), - relax.TensorStructInfo((2, 3, 30, 30), "float32"), + relax.TensorType((2, 3, 30, 30), "float32"), ) _check_inference( bb, relax.op.nn.adaptive_avg_pool2d(x0, output_size=(28, 30)), - relax.TensorStructInfo((2, 3, 28, 30), "float32"), + relax.TensorType((2, 3, 28, 30), "float32"), ) _check_inference( bb, relax.op.nn.adaptive_avg_pool2d(x1, layout="NHWC"), - relax.TensorStructInfo((2, 32, 32, 3), "float32"), + relax.TensorType((2, 32, 32, 3), "float32"), ) _check_inference( bb, relax.op.nn.adaptive_avg_pool2d(x0, out_layout="NHWC"), - relax.TensorStructInfo((2, 32, 32, 3), "float32"), + relax.TensorType((2, 32, 32, 3), "float32"), ) _check_inference( bb, relax.op.nn.adaptive_avg_pool2d(x6, layout="NCHW16c", out_layout="NHWC16c"), - relax.TensorStructInfo((2, 32, 32, 4, 16), "float32"), - ) - _check_inference( - bb, relax.op.nn.adaptive_avg_pool2d(x2), relax.TensorStructInfo(dtype="float32", ndim=4) - ) - _check_inference( - bb, relax.op.nn.adaptive_avg_pool2d(x3), relax.TensorStructInfo(dtype="float32", ndim=4) + relax.TensorType((2, 32, 32, 4, 16), "float32"), ) _check_inference( - bb, relax.op.nn.adaptive_avg_pool2d(x4), relax.TensorStructInfo(dtype="", ndim=4) + bb, relax.op.nn.adaptive_avg_pool2d(x2), relax.TensorType(dtype="float32", ndim=4) ) _check_inference( - bb, relax.op.nn.adaptive_avg_pool2d(x5), relax.TensorStructInfo(dtype="", ndim=4) + bb, relax.op.nn.adaptive_avg_pool2d(x3), relax.TensorType(dtype="float32", ndim=4) ) + _check_inference(bb, relax.op.nn.adaptive_avg_pool2d(x4), relax.TensorType(dtype="", ndim=4)) + _check_inference(bb, relax.op.nn.adaptive_avg_pool2d(x5), relax.TensorType(dtype="", ndim=4)) -def test_adaptive_avg_pool2d_infer_struct_info_shape_symbolic(): +def test_adaptive_avg_pool2d_infer_ty_shape_symbolic(): bb = relax.BlockBuilder() n = tirx.Var("n", "int64") c = tirx.Var("c", "int64") @@ -1632,70 +1560,70 @@ def test_adaptive_avg_pool2d_infer_struct_info_shape_symbolic(): x1 = relax.Var("x", R.Tensor((n, c, ih, iw, c16), "float32")) _check_inference( - bb, relax.op.nn.adaptive_avg_pool2d(x0), relax.TensorStructInfo((n, c, ih, iw), "float32") + bb, relax.op.nn.adaptive_avg_pool2d(x0), relax.TensorType((n, c, ih, iw), "float32") ) _check_inference( bb, relax.op.nn.adaptive_avg_pool2d(x0, output_size=256), - relax.TensorStructInfo((n, c, 256, 256), "float32"), + relax.TensorType((n, c, 256, 256), "float32"), ) _check_inference( bb, relax.op.nn.adaptive_avg_pool2d(x0, output_size=(256, 128)), - relax.TensorStructInfo((n, c, 256, 128), "float32"), + relax.TensorType((n, c, 256, 128), "float32"), ) _check_inference( bb, relax.op.nn.adaptive_avg_pool2d(x1, layout="NCHW16c", out_layout="NHWC"), - relax.TensorStructInfo((n, ih, iw, c * 16), "float32"), + relax.TensorType((n, ih, iw, c * 16), "float32"), ) -def test_adaptive_avg_pool2d_infer_struct_info_shape_var(): +def test_adaptive_avg_pool2d_infer_ty_shape_var(): bb = relax.BlockBuilder() - s0 = relax.Var("s", relax.ShapeStructInfo(ndim=4)) - s1 = relax.Var("s", relax.ShapeStructInfo(ndim=5)) - s2 = relax.Var("s", relax.ShapeStructInfo()) - x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) - x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) - x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + s0 = relax.Var("s", relax.ShapeType(ndim=4)) + s1 = relax.Var("s", relax.ShapeType(ndim=5)) + s2 = relax.Var("s", relax.ShapeType()) + x0 = relax.Var("x", relax.TensorType(s0, "float32")) + x1 = relax.Var("x", relax.TensorType(s1, "float32")) + x2 = relax.Var("x", relax.TensorType(s2, "float32")) - _check_inference(bb, relax.op.nn.adaptive_avg_pool2d(x0), relax.TensorStructInfo(s0, "float32")) + _check_inference(bb, relax.op.nn.adaptive_avg_pool2d(x0), relax.TensorType(s0, "float32")) _check_inference( bb, relax.op.nn.adaptive_avg_pool2d(x0, output_size=32), - relax.TensorStructInfo(dtype="float32", ndim=4), + relax.TensorType(dtype="float32", ndim=4), ) _check_inference( bb, relax.op.nn.adaptive_avg_pool2d(x1, layout="NCHW16c"), - relax.TensorStructInfo(s1, "float32"), + relax.TensorType(s1, "float32"), ) _check_inference( bb, relax.op.nn.adaptive_avg_pool2d(x0, out_layout="NCHW16c"), - relax.TensorStructInfo(dtype="float32", ndim=5), + relax.TensorType(dtype="float32", ndim=5), ) _check_inference( bb, relax.op.nn.adaptive_avg_pool2d(x2, out_layout="NCHW16c"), - relax.TensorStructInfo(dtype="float32", ndim=5), + relax.TensorType(dtype="float32", ndim=5), ) -def test_adaptive_avg_pool2d_infer_struct_info_more_input_dtype(): +def test_adaptive_avg_pool2d_infer_ty_more_input_dtype(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float16")) x1 = relax.Var("x", R.Tensor((2, 3, 32, 32), "int8")) x2 = relax.Var("x", R.Tensor((2, 3, 32, 32), "int64")) _check_inference( - bb, relax.op.nn.adaptive_avg_pool2d(x0), relax.TensorStructInfo((2, 3, 32, 32), "float16") + bb, relax.op.nn.adaptive_avg_pool2d(x0), relax.TensorType((2, 3, 32, 32), "float16") ) _check_inference( - bb, relax.op.nn.adaptive_avg_pool2d(x1), relax.TensorStructInfo((2, 3, 32, 32), "int8") + bb, relax.op.nn.adaptive_avg_pool2d(x1), relax.TensorType((2, 3, 32, 32), "int8") ) _check_inference( - bb, relax.op.nn.adaptive_avg_pool2d(x2), relax.TensorStructInfo((2, 3, 32, 32), "int64") + bb, relax.op.nn.adaptive_avg_pool2d(x2), relax.TensorType((2, 3, 32, 32), "int64") ) @@ -1705,7 +1633,7 @@ def test_adaptive_avg_pool2d_wrong_output_size_ndim(): relax.op.nn.adaptive_avg_pool2d(x, (32, 32, 32)) -def test_adaptive_avg_pool2d_infer_struct_info_wrong_layout_string(): +def test_adaptive_avg_pool2d_infer_ty_wrong_layout_string(): bb = relax.BlockBuilder() x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) with pytest.raises(ValueError): @@ -1724,10 +1652,10 @@ def test_adaptive_avg_pool2d_wrong_input_ndim(): bb.normalize(relax.op.nn.adaptive_avg_pool2d(x1)) -def test_adaptive_avg_pool2d_infer_struct_info_wrong_input_type(): +def test_adaptive_avg_pool2d_infer_ty_wrong_input_type(): bb = relax.BlockBuilder() - x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 28, 28))) - x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 28, 28), "float32"))) + x0 = relax.Var("x", relax.ShapeType((2, 3, 28, 28))) + x1 = relax.Var("x", relax.FuncType([], R.Tensor((2, 3, 28, 28), "float32"))) with pytest.raises(TypeError): bb.normalize(relax.op.nn.adaptive_avg_pool2d(x0)) @@ -1735,7 +1663,7 @@ def test_adaptive_avg_pool2d_infer_struct_info_wrong_input_type(): bb.normalize(relax.op.nn.adaptive_avg_pool2d(x1)) -def test_adaptive_avg_pool3d_infer_struct_info(): +def test_adaptive_avg_pool3d_infer_ty(): bb = relax.BlockBuilder() vdev0 = VDevice("llvm") @@ -1751,55 +1679,51 @@ def test_adaptive_avg_pool3d_infer_struct_info(): _check_inference( bb, relax.op.nn.adaptive_avg_pool3d(x0), - relax.TensorStructInfo((2, 3, 32, 32, 32), "float32"), + relax.TensorType((2, 3, 32, 32, 32), "float32"), ) _check_inference( bb, relax.op.nn.adaptive_avg_pool3d(x7), - relax.TensorStructInfo((2, 3, 32, 32, 32), "float32", vdev0), + relax.TensorType((2, 3, 32, 32, 32), "float32", vdev0), ) _check_inference( bb, relax.op.nn.adaptive_avg_pool3d(x0, output_size=30), - relax.TensorStructInfo((2, 3, 30, 30, 30), "float32"), + relax.TensorType((2, 3, 30, 30, 30), "float32"), ) _check_inference( bb, relax.op.nn.adaptive_avg_pool3d(x0, output_size=(28, 30, 32)), - relax.TensorStructInfo((2, 3, 28, 30, 32), "float32"), + relax.TensorType((2, 3, 28, 30, 32), "float32"), ) _check_inference( bb, relax.op.nn.adaptive_avg_pool3d(x1, layout="NCDHW"), - relax.TensorStructInfo((2, 32, 32, 32, 3), "float32"), + relax.TensorType((2, 32, 32, 32, 3), "float32"), ) _check_inference( bb, relax.op.nn.adaptive_avg_pool3d(x0, out_layout="NCDHW"), - relax.TensorStructInfo((2, 3, 32, 32, 32), "float32"), + relax.TensorType((2, 3, 32, 32, 32), "float32"), ) _check_inference( bb, relax.op.nn.adaptive_avg_pool3d(x6, layout="NCDHW16c", out_layout="NDHWC16c"), - relax.TensorStructInfo((2, 32, 32, 32, 4, 16), "float32"), + relax.TensorType((2, 32, 32, 32, 4, 16), "float32"), ) _check_inference( - bb, relax.op.nn.adaptive_avg_pool3d(x2), relax.TensorStructInfo(dtype="float32", ndim=5) - ) - _check_inference( - bb, relax.op.nn.adaptive_avg_pool3d(x3), relax.TensorStructInfo(dtype="float32", ndim=5) - ) - _check_inference( - bb, relax.op.nn.adaptive_avg_pool3d(x4), relax.TensorStructInfo(dtype="", ndim=5) + bb, relax.op.nn.adaptive_avg_pool3d(x2), relax.TensorType(dtype="float32", ndim=5) ) _check_inference( - bb, relax.op.nn.adaptive_avg_pool3d(x5), relax.TensorStructInfo(dtype="", ndim=5) + bb, relax.op.nn.adaptive_avg_pool3d(x3), relax.TensorType(dtype="float32", ndim=5) ) + _check_inference(bb, relax.op.nn.adaptive_avg_pool3d(x4), relax.TensorType(dtype="", ndim=5)) + _check_inference(bb, relax.op.nn.adaptive_avg_pool3d(x5), relax.TensorType(dtype="", ndim=5)) -def test_adaptive_avg_pool3d_infer_struct_info_shape_symbolic(): +def test_adaptive_avg_pool3d_infer_ty_shape_symbolic(): bb = relax.BlockBuilder() n = tirx.Var("n", "int64") @@ -1815,60 +1739,60 @@ def test_adaptive_avg_pool3d_infer_struct_info_shape_symbolic(): _check_inference( bb, relax.op.nn.adaptive_avg_pool3d(x0), - relax.TensorStructInfo((n, c, d, ih, iw), "float32"), + relax.TensorType((n, c, d, ih, iw), "float32"), ) _check_inference( bb, relax.op.nn.adaptive_avg_pool3d(x0, output_size=256), - relax.TensorStructInfo((n, c, 256, 256, 256), "float32"), + relax.TensorType((n, c, 256, 256, 256), "float32"), ) _check_inference( bb, relax.op.nn.adaptive_avg_pool3d(x0, output_size=(256, 128, 64)), - relax.TensorStructInfo((n, c, 256, 128, 64), "float32"), + relax.TensorType((n, c, 256, 128, 64), "float32"), ) _check_inference( bb, relax.op.nn.adaptive_avg_pool3d(x1, layout="NCDHW16c", out_layout="NDHWC"), - relax.TensorStructInfo((n, d, ih, iw, c * 16), "float32"), + relax.TensorType((n, d, ih, iw, c * 16), "float32"), ) -def test_adaptive_avg_pool3d_infer_struct_info_shape_var(): +def test_adaptive_avg_pool3d_infer_ty_shape_var(): bb = relax.BlockBuilder() - s0 = relax.Var("s", relax.ShapeStructInfo(ndim=5)) - s1 = relax.Var("s", relax.ShapeStructInfo(ndim=6)) - s2 = relax.Var("s", relax.ShapeStructInfo()) + s0 = relax.Var("s", relax.ShapeType(ndim=5)) + s1 = relax.Var("s", relax.ShapeType(ndim=6)) + s2 = relax.Var("s", relax.ShapeType()) - x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) - x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) - x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + x0 = relax.Var("x", relax.TensorType(s0, "float32")) + x1 = relax.Var("x", relax.TensorType(s1, "float32")) + x2 = relax.Var("x", relax.TensorType(s2, "float32")) - _check_inference(bb, relax.op.nn.adaptive_avg_pool3d(x0), relax.TensorStructInfo(s0, "float32")) + _check_inference(bb, relax.op.nn.adaptive_avg_pool3d(x0), relax.TensorType(s0, "float32")) _check_inference( bb, relax.op.nn.adaptive_avg_pool3d(x0, output_size=32), - relax.TensorStructInfo(dtype="float32", ndim=5), + relax.TensorType(dtype="float32", ndim=5), ) _check_inference( bb, relax.op.nn.adaptive_avg_pool3d(x1, layout="NCDHW16c"), - relax.TensorStructInfo(s1, "float32"), + relax.TensorType(s1, "float32"), ) _check_inference( bb, relax.op.nn.adaptive_avg_pool3d(x0, out_layout="NCDHW16c"), - relax.TensorStructInfo(dtype="float32", ndim=6), + relax.TensorType(dtype="float32", ndim=6), ) _check_inference( bb, relax.op.nn.adaptive_avg_pool3d(x2, out_layout="NCDHW16c"), - relax.TensorStructInfo(dtype="float32", ndim=6), + relax.TensorType(dtype="float32", ndim=6), ) -def test_adaptive_avg_pool3d_infer_struct_info_more_input_dtype(): +def test_adaptive_avg_pool3d_infer_ty_more_input_dtype(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3, 32, 32, 32), "float16")) x1 = relax.Var("x", R.Tensor((2, 3, 32, 32, 32), "int8")) @@ -1877,13 +1801,13 @@ def test_adaptive_avg_pool3d_infer_struct_info_more_input_dtype(): _check_inference( bb, relax.op.nn.adaptive_avg_pool3d(x0), - relax.TensorStructInfo((2, 3, 32, 32, 32), "float16"), + relax.TensorType((2, 3, 32, 32, 32), "float16"), ) _check_inference( - bb, relax.op.nn.adaptive_avg_pool3d(x1), relax.TensorStructInfo((2, 3, 32, 32, 32), "int8") + bb, relax.op.nn.adaptive_avg_pool3d(x1), relax.TensorType((2, 3, 32, 32, 32), "int8") ) _check_inference( - bb, relax.op.nn.adaptive_avg_pool3d(x2), relax.TensorStructInfo((2, 3, 32, 32, 32), "int64") + bb, relax.op.nn.adaptive_avg_pool3d(x2), relax.TensorType((2, 3, 32, 32, 32), "int64") ) @@ -1894,7 +1818,7 @@ def test_adaptive_avg_pool3d_wrong_output_size_ndim(): relax.op.nn.adaptive_avg_pool3d(x, (32, 32, 32, 32)) -def test_adaptive_avg_pool3d_infer_struct_info_wrong_layout_string(): +def test_adaptive_avg_pool3d_infer_ty_wrong_layout_string(): bb = relax.BlockBuilder() x = relax.Var("x", R.Tensor((2, 3, 28, 28, 28), "float32")) @@ -1916,10 +1840,10 @@ def test_adaptive_avg_pool3d_wrong_input_ndim(): bb.normalize(relax.op.nn.adaptive_avg_pool3d(x1)) -def test_adaptive_avg_pool3d_infer_struct_info_wrong_input_type(): +def test_adaptive_avg_pool3d_infer_ty_wrong_input_type(): bb = relax.BlockBuilder() - x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 28, 28, 28))) - x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 28, 28, 28), "float32"))) + x0 = relax.Var("x", relax.ShapeType((2, 3, 28, 28, 28))) + x1 = relax.Var("x", relax.FuncType([], R.Tensor((2, 3, 28, 28, 28), "float32"))) with pytest.raises(TypeError): bb.normalize(relax.op.nn.adaptive_avg_pool3d(x0)) diff --git a/tests/python/relax/test_op_qdq.py b/tests/python/relax/test_op_qdq.py index 2c876eb4a34b..22724dadaa3a 100644 --- a/tests/python/relax/test_op_qdq.py +++ b/tests/python/relax/test_op_qdq.py @@ -30,45 +30,41 @@ def test_op_correctness(): assert relax.op.dequantize(dx, s, zp, 1, "float32").op == Op.get("relax.dequantize") -def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): +def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_ty: relax.Type): ret = bb.normalize(call) - tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) + tvm.ir.assert_structural_equal(ret.ty, expected_ty) -def test_qdq_op_infer_struct_info(): +def test_qdq_op_infer_ty(): bb = relax.BlockBuilder() x = relax.Var("x", R.Tensor((2, 3), "float32")) dx = relax.Var("dx", R.Tensor((2, 3), "uint8")) s = relax.Var("s", R.Tensor([3], "float32")) zp = relax.Var("zp", R.Tensor([3], "int8")) - _check_inference( - bb, relax.op.quantize(x, s, zp, 1, "int8"), relax.TensorStructInfo((2, 3), "int8") - ) + _check_inference(bb, relax.op.quantize(x, s, zp, 1, "int8"), relax.TensorType((2, 3), "int8")) _check_inference( bb, relax.op.dequantize(dx, s, zp, 1, "float32"), - relax.TensorStructInfo((2, 3), "float32"), + relax.TensorType((2, 3), "float32"), ) -def test_qdq_op_infer_struct_info_symbolic(): +def test_qdq_op_infer_ty_symbolic(): bb = relax.BlockBuilder() n = tirx.Var("n", "int64") x = relax.Var("x", R.Tensor((n, 3), "float32")) dx = relax.Var("dx", R.Tensor((n, 3), "int8")) s = relax.Var("s", R.Tensor([3], "float32")) zp = relax.Var("zp", R.Tensor([3], "int8")) - _check_inference( - bb, relax.op.quantize(x, s, zp, 1, "int8"), relax.TensorStructInfo((n, 3), "int8") - ) + _check_inference(bb, relax.op.quantize(x, s, zp, 1, "int8"), relax.TensorType((n, 3), "int8")) _check_inference( bb, relax.op.dequantize(dx, s, zp, 1, "float32"), - relax.TensorStructInfo((n, 3), "float32"), + relax.TensorType((n, 3), "float32"), ) -def test_qdq_float8_e4m3fn_op_infer_struct_info_symbolic(): +def test_qdq_float8_e4m3fn_op_infer_ty_symbolic(): bb = relax.BlockBuilder() n = tirx.Var("n", "int64") x = relax.Var("x", R.Tensor((n, 3), "float32")) @@ -78,16 +74,16 @@ def test_qdq_float8_e4m3fn_op_infer_struct_info_symbolic(): _check_inference( bb, relax.op.quantize(x, s, zp, 1, "float8_e4m3fn"), - relax.TensorStructInfo((n, 3), "float8_e4m3fn"), + relax.TensorType((n, 3), "float8_e4m3fn"), ) _check_inference( bb, relax.op.dequantize(dx, s, zp, 1, "float32"), - relax.TensorStructInfo((n, 3), "float32"), + relax.TensorType((n, 3), "float32"), ) -def test_qdq_float8_e5m2_op_infer_struct_info_symbolic(): +def test_qdq_float8_e5m2_op_infer_ty_symbolic(): dtype = "float8_e5m2" bb = relax.BlockBuilder() n = tirx.Var("n", "int64") @@ -95,13 +91,11 @@ def test_qdq_float8_e5m2_op_infer_struct_info_symbolic(): dx = relax.Var("dx", R.Tensor((n, 3), dtype)) s = relax.Var("s", R.Tensor([3], "float32")) zp = relax.Var("zp", R.Tensor([3], "float16")) - _check_inference( - bb, relax.op.quantize(x, s, zp, 1, dtype), relax.TensorStructInfo((n, 3), dtype) - ) + _check_inference(bb, relax.op.quantize(x, s, zp, 1, dtype), relax.TensorType((n, 3), dtype)) _check_inference( bb, relax.op.dequantize(dx, s, zp, 1, "float32"), - relax.TensorStructInfo((n, 3), "float32"), + relax.TensorType((n, 3), "float32"), ) diff --git a/tests/python/relax/test_op_sampling.py b/tests/python/relax/test_op_sampling.py index d8806cf62500..2b1de74976d4 100644 --- a/tests/python/relax/test_op_sampling.py +++ b/tests/python/relax/test_op_sampling.py @@ -20,9 +20,9 @@ from tvm.script import relax as R -def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): +def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_ty: relax.Type): ret = bb.normalize(call) - tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) + tvm.ir.assert_structural_equal(ret.ty, expected_ty) def test_multinomial_from_uniform(): diff --git a/tests/python/relax/test_op_search.py b/tests/python/relax/test_op_search.py index 252f5db24178..6da0aed702ea 100644 --- a/tests/python/relax/test_op_search.py +++ b/tests/python/relax/test_op_search.py @@ -34,12 +34,12 @@ def test_op_correctness(): assert relax.op.argmin(x).op == Op.get("relax.argmin") -def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): +def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_ty: relax.Type): ret = bb.normalize(call) - tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) + tvm.ir.assert_structural_equal(ret.ty, expected_ty) -def test_where_infer_struct_info(): +def test_where_infer_ty(): bb = relax.BlockBuilder() vdev0 = VDevice("llvm") cond0 = relax.Var("cond", R.Tensor((6, 5, 1, 3, 1), "bool")) @@ -62,47 +62,37 @@ def test_where_infer_struct_info(): y6 = relax.Var("y", R.Tensor((4, 3, 1), "float32", vdev0)) _check_inference( - bb, relax.op.where(cond0, x0, y0), relax.TensorStructInfo((6, 5, 4, 3, 2), "float32") - ) - _check_inference( - bb, relax.op.where(cond3, x6, y6), relax.TensorStructInfo((6, 5, 4, 3, 2), "float32", vdev0) - ) - _check_inference( - bb, relax.op.where(cond0, x1, y0), relax.TensorStructInfo(dtype="float32", ndim=5) - ) - _check_inference(bb, relax.op.where(cond0, x2, y0), relax.TensorStructInfo(dtype="float32")) - _check_inference( - bb, relax.op.where(cond0, x3, y0), relax.TensorStructInfo((6, 5, 4, 3, 2), dtype="") - ) - _check_inference(bb, relax.op.where(cond0, x4, y0), relax.TensorStructInfo(dtype="", ndim=5)) - _check_inference(bb, relax.op.where(cond0, x5, y0), relax.TensorStructInfo(dtype="")) - _check_inference( - bb, relax.op.where(cond0, x1, y1), relax.TensorStructInfo(dtype="float32", ndim=5) - ) - _check_inference(bb, relax.op.where(cond0, x2, y1), relax.TensorStructInfo(dtype="float32")) - _check_inference(bb, relax.op.where(cond0, x3, y1), relax.TensorStructInfo(dtype="", ndim=5)) - _check_inference(bb, relax.op.where(cond0, x4, y1), relax.TensorStructInfo(dtype="", ndim=5)) - _check_inference(bb, relax.op.where(cond0, x5, y1), relax.TensorStructInfo(dtype="")) - _check_inference(bb, relax.op.where(cond0, x2, y2), relax.TensorStructInfo(dtype="float32")) - _check_inference(bb, relax.op.where(cond0, x3, y2), relax.TensorStructInfo(dtype="")) - _check_inference(bb, relax.op.where(cond0, x4, y2), relax.TensorStructInfo(dtype="")) - _check_inference(bb, relax.op.where(cond0, x5, y2), relax.TensorStructInfo(dtype="")) - _check_inference( - bb, relax.op.where(cond0, x3, y3), relax.TensorStructInfo((6, 5, 4, 3, 2), dtype="") - ) - _check_inference(bb, relax.op.where(cond0, x4, y3), relax.TensorStructInfo(dtype="", ndim=5)) - _check_inference(bb, relax.op.where(cond0, x5, y3), relax.TensorStructInfo(dtype="")) - _check_inference(bb, relax.op.where(cond0, x4, y4), relax.TensorStructInfo(dtype="", ndim=5)) - _check_inference(bb, relax.op.where(cond0, x5, y4), relax.TensorStructInfo(dtype="")) - _check_inference(bb, relax.op.where(cond0, x5, y5), relax.TensorStructInfo(dtype="")) - _check_inference( - bb, relax.op.where(cond1, x0, y0), relax.TensorStructInfo(dtype="float32", ndim=5) - ) - _check_inference(bb, relax.op.where(cond1, x2, y0), relax.TensorStructInfo(dtype="float32")) - _check_inference(bb, relax.op.where(cond2, x0, y0), relax.TensorStructInfo(dtype="float32")) - - -def test_where_infer_struct_info_shape_symbolic(): + bb, relax.op.where(cond0, x0, y0), relax.TensorType((6, 5, 4, 3, 2), "float32") + ) + _check_inference( + bb, relax.op.where(cond3, x6, y6), relax.TensorType((6, 5, 4, 3, 2), "float32", vdev0) + ) + _check_inference(bb, relax.op.where(cond0, x1, y0), relax.TensorType(dtype="float32", ndim=5)) + _check_inference(bb, relax.op.where(cond0, x2, y0), relax.TensorType(dtype="float32")) + _check_inference(bb, relax.op.where(cond0, x3, y0), relax.TensorType((6, 5, 4, 3, 2), dtype="")) + _check_inference(bb, relax.op.where(cond0, x4, y0), relax.TensorType(dtype="", ndim=5)) + _check_inference(bb, relax.op.where(cond0, x5, y0), relax.TensorType(dtype="")) + _check_inference(bb, relax.op.where(cond0, x1, y1), relax.TensorType(dtype="float32", ndim=5)) + _check_inference(bb, relax.op.where(cond0, x2, y1), relax.TensorType(dtype="float32")) + _check_inference(bb, relax.op.where(cond0, x3, y1), relax.TensorType(dtype="", ndim=5)) + _check_inference(bb, relax.op.where(cond0, x4, y1), relax.TensorType(dtype="", ndim=5)) + _check_inference(bb, relax.op.where(cond0, x5, y1), relax.TensorType(dtype="")) + _check_inference(bb, relax.op.where(cond0, x2, y2), relax.TensorType(dtype="float32")) + _check_inference(bb, relax.op.where(cond0, x3, y2), relax.TensorType(dtype="")) + _check_inference(bb, relax.op.where(cond0, x4, y2), relax.TensorType(dtype="")) + _check_inference(bb, relax.op.where(cond0, x5, y2), relax.TensorType(dtype="")) + _check_inference(bb, relax.op.where(cond0, x3, y3), relax.TensorType((6, 5, 4, 3, 2), dtype="")) + _check_inference(bb, relax.op.where(cond0, x4, y3), relax.TensorType(dtype="", ndim=5)) + _check_inference(bb, relax.op.where(cond0, x5, y3), relax.TensorType(dtype="")) + _check_inference(bb, relax.op.where(cond0, x4, y4), relax.TensorType(dtype="", ndim=5)) + _check_inference(bb, relax.op.where(cond0, x5, y4), relax.TensorType(dtype="")) + _check_inference(bb, relax.op.where(cond0, x5, y5), relax.TensorType(dtype="")) + _check_inference(bb, relax.op.where(cond1, x0, y0), relax.TensorType(dtype="float32", ndim=5)) + _check_inference(bb, relax.op.where(cond1, x2, y0), relax.TensorType(dtype="float32")) + _check_inference(bb, relax.op.where(cond2, x0, y0), relax.TensorType(dtype="float32")) + + +def test_where_infer_ty_shape_symbolic(): bb = relax.BlockBuilder() a = tirx.Var("a", "int64") b = tirx.Var("b", "int64") @@ -118,92 +108,70 @@ def test_where_infer_struct_info_shape_symbolic(): y1 = relax.Var("y", R.Tensor((c, d0, 1))) _check_inference( - bb, relax.op.where(cond, x0, y0), relax.TensorStructInfo((a, b, c, d0, e), "float32") - ) - _check_inference( - bb, relax.op.where(cond, x1, y0), relax.TensorStructInfo(dtype="float32", ndim=5) - ) - _check_inference( - bb, relax.op.where(cond, x2, y0), relax.TensorStructInfo((a, b, c, d0, e), dtype="") - ) - _check_inference( - bb, relax.op.where(cond, x0, y1), relax.TensorStructInfo((a, b, c, d0, e), dtype="") - ) - _check_inference(bb, relax.op.where(cond, x1, y1), relax.TensorStructInfo(dtype="", ndim=5)) - _check_inference( - bb, relax.op.where(cond, x2, y1), relax.TensorStructInfo((a, b, c, d0, e), dtype="") + bb, relax.op.where(cond, x0, y0), relax.TensorType((a, b, c, d0, e), "float32") ) + _check_inference(bb, relax.op.where(cond, x1, y0), relax.TensorType(dtype="float32", ndim=5)) + _check_inference(bb, relax.op.where(cond, x2, y0), relax.TensorType((a, b, c, d0, e), dtype="")) + _check_inference(bb, relax.op.where(cond, x0, y1), relax.TensorType((a, b, c, d0, e), dtype="")) + _check_inference(bb, relax.op.where(cond, x1, y1), relax.TensorType(dtype="", ndim=5)) + _check_inference(bb, relax.op.where(cond, x2, y1), relax.TensorType((a, b, c, d0, e), dtype="")) -def test_where_infer_struct_info_shape_var(): +def test_where_infer_ty_shape_var(): bb = relax.BlockBuilder() - scond0 = relax.Var("scond", relax.ShapeStructInfo((6, 5, 1, 3, 1))) - scond1 = relax.Var("scond", relax.ShapeStructInfo(ndim=5)) - scond2 = relax.Var("scond", relax.ShapeStructInfo()) - sx0 = relax.Var("sx", relax.ShapeStructInfo((5, 1, 3, 2))) - sx1 = relax.Var("sx", relax.ShapeStructInfo(ndim=4)) - sx2 = relax.Var("sx", relax.ShapeStructInfo()) - sy0 = relax.Var("sy", relax.ShapeStructInfo((4, 3, 1))) - sy1 = relax.Var("sy", relax.ShapeStructInfo(ndim=3)) - sy2 = relax.Var("sy", relax.ShapeStructInfo()) - s0 = relax.Var("s", relax.ShapeStructInfo((6, 5, 4, 3, 2))) - s1 = relax.Var("s", relax.ShapeStructInfo(ndim=5)) - s2 = relax.Var("s", relax.ShapeStructInfo()) - cond0 = relax.Var("cond", relax.TensorStructInfo(scond0, "bool")) - cond1 = relax.Var("cond", relax.TensorStructInfo(scond1, "bool")) - cond2 = relax.Var("cond", relax.TensorStructInfo(scond2, "bool")) - cond3 = relax.Var("cond", relax.TensorStructInfo(s0, "bool")) - cond4 = relax.Var("cond", relax.TensorStructInfo(s1, "bool")) - cond5 = relax.Var("cond", relax.TensorStructInfo(s2, "bool")) - x0 = relax.Var("x", relax.TensorStructInfo(sx0, "float32")) - x1 = relax.Var("x", relax.TensorStructInfo(sx1, "float32")) - x2 = relax.Var("x", relax.TensorStructInfo(sx2, "float32")) - x3 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) - x4 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) - x5 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) - y0 = relax.Var("y", relax.TensorStructInfo(sy0, "float32")) - y1 = relax.Var("y", relax.TensorStructInfo(sy1, "float32")) - y2 = relax.Var("y", relax.TensorStructInfo(sy2, "float32")) - y3 = relax.Var("y", relax.TensorStructInfo(s0, "float32")) - y4 = relax.Var("y", relax.TensorStructInfo(s1, "float32")) - y5 = relax.Var("y", relax.TensorStructInfo(s2, "float32")) - - _check_inference( - bb, relax.op.where(cond0, x0, y0), relax.TensorStructInfo(dtype="float32", ndim=5) - ) - _check_inference( - bb, relax.op.where(cond0, x0, y1), relax.TensorStructInfo(dtype="float32", ndim=5) - ) - _check_inference(bb, relax.op.where(cond0, x0, y2), relax.TensorStructInfo(dtype="float32")) - _check_inference( - bb, relax.op.where(cond0, x1, y1), relax.TensorStructInfo(dtype="float32", ndim=5) - ) - _check_inference(bb, relax.op.where(cond0, x1, y2), relax.TensorStructInfo(dtype="float32")) - _check_inference(bb, relax.op.where(cond0, x2, y2), relax.TensorStructInfo(dtype="float32")) - _check_inference( - bb, relax.op.where(cond1, x1, y1), relax.TensorStructInfo(dtype="float32", ndim=5) - ) - _check_inference(bb, relax.op.where(cond1, x1, y2), relax.TensorStructInfo(dtype="float32")) - _check_inference(bb, relax.op.where(cond1, x2, y2), relax.TensorStructInfo(dtype="float32")) - _check_inference(bb, relax.op.where(cond2, x2, y2), relax.TensorStructInfo(dtype="float32")) - _check_inference(bb, relax.op.where(cond3, x3, y3), relax.TensorStructInfo(s0, "float32")) - _check_inference( - bb, relax.op.where(cond3, x3, y4), relax.TensorStructInfo(dtype="float32", ndim=5) - ) - _check_inference( - bb, relax.op.where(cond3, x4, y3), relax.TensorStructInfo(dtype="float32", ndim=5) - ) - _check_inference( - bb, relax.op.where(cond4, x3, y3), relax.TensorStructInfo(dtype="float32", ndim=5) - ) - _check_inference(bb, relax.op.where(cond4, x4, y4), relax.TensorStructInfo(s1, "float32")) - _check_inference(bb, relax.op.where(cond4, x4, y5), relax.TensorStructInfo(dtype="float32")) - _check_inference(bb, relax.op.where(cond4, x5, y4), relax.TensorStructInfo(dtype="float32")) - _check_inference(bb, relax.op.where(cond5, x4, y4), relax.TensorStructInfo(dtype="float32")) - _check_inference(bb, relax.op.where(cond5, x5, y5), relax.TensorStructInfo(s2, "float32")) - - -def test_where_infer_struct_info_more_input_dtype(): + scond0 = relax.Var("scond", relax.ShapeType((6, 5, 1, 3, 1))) + scond1 = relax.Var("scond", relax.ShapeType(ndim=5)) + scond2 = relax.Var("scond", relax.ShapeType()) + sx0 = relax.Var("sx", relax.ShapeType((5, 1, 3, 2))) + sx1 = relax.Var("sx", relax.ShapeType(ndim=4)) + sx2 = relax.Var("sx", relax.ShapeType()) + sy0 = relax.Var("sy", relax.ShapeType((4, 3, 1))) + sy1 = relax.Var("sy", relax.ShapeType(ndim=3)) + sy2 = relax.Var("sy", relax.ShapeType()) + s0 = relax.Var("s", relax.ShapeType((6, 5, 4, 3, 2))) + s1 = relax.Var("s", relax.ShapeType(ndim=5)) + s2 = relax.Var("s", relax.ShapeType()) + cond0 = relax.Var("cond", relax.TensorType(scond0, "bool")) + cond1 = relax.Var("cond", relax.TensorType(scond1, "bool")) + cond2 = relax.Var("cond", relax.TensorType(scond2, "bool")) + cond3 = relax.Var("cond", relax.TensorType(s0, "bool")) + cond4 = relax.Var("cond", relax.TensorType(s1, "bool")) + cond5 = relax.Var("cond", relax.TensorType(s2, "bool")) + x0 = relax.Var("x", relax.TensorType(sx0, "float32")) + x1 = relax.Var("x", relax.TensorType(sx1, "float32")) + x2 = relax.Var("x", relax.TensorType(sx2, "float32")) + x3 = relax.Var("x", relax.TensorType(s0, "float32")) + x4 = relax.Var("x", relax.TensorType(s1, "float32")) + x5 = relax.Var("x", relax.TensorType(s2, "float32")) + y0 = relax.Var("y", relax.TensorType(sy0, "float32")) + y1 = relax.Var("y", relax.TensorType(sy1, "float32")) + y2 = relax.Var("y", relax.TensorType(sy2, "float32")) + y3 = relax.Var("y", relax.TensorType(s0, "float32")) + y4 = relax.Var("y", relax.TensorType(s1, "float32")) + y5 = relax.Var("y", relax.TensorType(s2, "float32")) + + _check_inference(bb, relax.op.where(cond0, x0, y0), relax.TensorType(dtype="float32", ndim=5)) + _check_inference(bb, relax.op.where(cond0, x0, y1), relax.TensorType(dtype="float32", ndim=5)) + _check_inference(bb, relax.op.where(cond0, x0, y2), relax.TensorType(dtype="float32")) + _check_inference(bb, relax.op.where(cond0, x1, y1), relax.TensorType(dtype="float32", ndim=5)) + _check_inference(bb, relax.op.where(cond0, x1, y2), relax.TensorType(dtype="float32")) + _check_inference(bb, relax.op.where(cond0, x2, y2), relax.TensorType(dtype="float32")) + _check_inference(bb, relax.op.where(cond1, x1, y1), relax.TensorType(dtype="float32", ndim=5)) + _check_inference(bb, relax.op.where(cond1, x1, y2), relax.TensorType(dtype="float32")) + _check_inference(bb, relax.op.where(cond1, x2, y2), relax.TensorType(dtype="float32")) + _check_inference(bb, relax.op.where(cond2, x2, y2), relax.TensorType(dtype="float32")) + _check_inference(bb, relax.op.where(cond3, x3, y3), relax.TensorType(s0, "float32")) + _check_inference(bb, relax.op.where(cond3, x3, y4), relax.TensorType(dtype="float32", ndim=5)) + _check_inference(bb, relax.op.where(cond3, x4, y3), relax.TensorType(dtype="float32", ndim=5)) + _check_inference(bb, relax.op.where(cond4, x3, y3), relax.TensorType(dtype="float32", ndim=5)) + _check_inference(bb, relax.op.where(cond4, x4, y4), relax.TensorType(s1, "float32")) + _check_inference(bb, relax.op.where(cond4, x4, y5), relax.TensorType(dtype="float32")) + _check_inference(bb, relax.op.where(cond4, x5, y4), relax.TensorType(dtype="float32")) + _check_inference(bb, relax.op.where(cond5, x4, y4), relax.TensorType(dtype="float32")) + _check_inference(bb, relax.op.where(cond5, x5, y5), relax.TensorType(s2, "float32")) + + +def test_where_infer_ty_more_input_dtype(): bb = relax.BlockBuilder() cond = relax.Var("cond", R.Tensor((6, 5, 1, 3, 1), "bool")) x0 = relax.Var("x", R.Tensor((5, 1, 3, 2), "float16")) @@ -213,18 +181,12 @@ def test_where_infer_struct_info_more_input_dtype(): x2 = relax.Var("x", R.Tensor((5, 1, 3, 2), "int32")) y2 = relax.Var("y", R.Tensor((4, 3, 1), "int32")) - _check_inference( - bb, relax.op.where(cond, x0, y0), relax.TensorStructInfo((6, 5, 4, 3, 2), "float16") - ) - _check_inference( - bb, relax.op.where(cond, x1, y1), relax.TensorStructInfo((6, 5, 4, 3, 2), "int8") - ) - _check_inference( - bb, relax.op.where(cond, x2, y2), relax.TensorStructInfo((6, 5, 4, 3, 2), "int32") - ) + _check_inference(bb, relax.op.where(cond, x0, y0), relax.TensorType((6, 5, 4, 3, 2), "float16")) + _check_inference(bb, relax.op.where(cond, x1, y1), relax.TensorType((6, 5, 4, 3, 2), "int8")) + _check_inference(bb, relax.op.where(cond, x2, y2), relax.TensorType((6, 5, 4, 3, 2), "int32")) -def test_where_infer_struct_info_cond_not_boolean(): +def test_where_infer_ty_cond_not_boolean(): bb = relax.BlockBuilder() cond0 = relax.Var("cond", R.Tensor((2, 3), "float32")) cond1 = relax.Var("cond", R.Tensor((2, 3))) @@ -237,7 +199,7 @@ def test_where_infer_struct_info_cond_not_boolean(): bb.normalize(relax.op.where(cond1, x, y)) -def test_where_infer_struct_info_shape_unequal_const_int(): +def test_where_infer_ty_shape_unequal_const_int(): bb = relax.BlockBuilder() cond0 = relax.Var("cond", R.Tensor((6, 5, 1, 4, 1), "bool")) cond1 = relax.Var("cond", R.Tensor((6, 5, 1, 3, 1), "bool")) @@ -254,7 +216,7 @@ def test_where_infer_struct_info_shape_unequal_const_int(): bb.normalize(relax.op.where(cond1, x1, y0)) -def test_where_infer_struct_info_dtype_mismatch(): +def test_where_infer_ty_dtype_mismatch(): bb = relax.BlockBuilder() cond = relax.Var("cond", R.Tensor((2, 3), "bool")) x0 = relax.Var("x", R.Tensor((2, 3), "float32")) @@ -268,13 +230,13 @@ def test_where_infer_struct_info_dtype_mismatch(): bb.normalize(relax.op.where(cond, x1, y1)) -def test_where_infer_struct_info_wrong_input_type(): +def test_where_infer_ty_wrong_input_type(): bb = relax.BlockBuilder() - cond0 = relax.Var("cond", relax.ShapeStructInfo((2, 3))) + cond0 = relax.Var("cond", relax.ShapeType((2, 3))) cond1 = relax.Var("cond", R.Tensor((2, 3), "bool")) - x0 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3), "float32"))) + x0 = relax.Var("x", relax.FuncType([], R.Tensor((2, 3), "float32"))) x1 = relax.Var("x", R.Tensor((2, 3), "float32")) - y0 = relax.Var("y", relax.TupleStructInfo([R.Tensor((2, 3), "float32")])) + y0 = relax.Var("y", relax.TupleType([R.Tensor((2, 3), "float32")])) y1 = relax.Var("y", R.Tensor((2, 3), "float32")) with pytest.raises(TypeError): @@ -289,7 +251,7 @@ def test_where_infer_struct_info_wrong_input_type(): @pytest.mark.parametrize("argmax_argmin_op", argmax_argmin_ops) -def test_argmax_argmin_infer_struct_info(argmax_argmin_op: Callable): +def test_argmax_argmin_infer_ty(argmax_argmin_op: Callable): bb = relax.BlockBuilder() vdev0 = VDevice("llvm") x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) @@ -298,71 +260,65 @@ def test_argmax_argmin_infer_struct_info(argmax_argmin_op: Callable): x3 = relax.Var("x", R.Tensor((2, 3, 4, 5))) x4 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32", vdev0)) - _check_inference(bb, argmax_argmin_op(x0, axis=1), relax.TensorStructInfo((2, 4, 5), "int64")) - _check_inference( - bb, argmax_argmin_op(x4, axis=1), relax.TensorStructInfo((2, 4, 5), "int64", vdev0) - ) + _check_inference(bb, argmax_argmin_op(x0, axis=1), relax.TensorType((2, 4, 5), "int64")) + _check_inference(bb, argmax_argmin_op(x4, axis=1), relax.TensorType((2, 4, 5), "int64", vdev0)) _check_inference( bb, argmax_argmin_op(x0, axis=1, keepdims=True), - relax.TensorStructInfo((2, 1, 4, 5), "int64"), + relax.TensorType((2, 1, 4, 5), "int64"), ) - _check_inference(bb, argmax_argmin_op(x0, axis=None), relax.TensorStructInfo((), "int64")) + _check_inference(bb, argmax_argmin_op(x0, axis=None), relax.TensorType((), "int64")) _check_inference( bb, argmax_argmin_op(x0, axis=None, keepdims=True), - relax.TensorStructInfo((1, 1, 1, 1), "int64"), - ) - _check_inference( - bb, argmax_argmin_op(x1, axis=1), relax.TensorStructInfo(dtype="int64", ndim=3) + relax.TensorType((1, 1, 1, 1), "int64"), ) + _check_inference(bb, argmax_argmin_op(x1, axis=1), relax.TensorType(dtype="int64", ndim=3)) _check_inference( bb, argmax_argmin_op(x1, axis=1, keepdims=True), - relax.TensorStructInfo(dtype="int64", ndim=4), + relax.TensorType(dtype="int64", ndim=4), ) - _check_inference(bb, argmax_argmin_op(x1, axis=None), relax.TensorStructInfo((), "int64")) + _check_inference(bb, argmax_argmin_op(x1, axis=None), relax.TensorType((), "int64")) _check_inference( bb, argmax_argmin_op(x1, axis=None, keepdims=True), - relax.TensorStructInfo((1, 1, 1, 1), "int64"), + relax.TensorType((1, 1, 1, 1), "int64"), ) - _check_inference(bb, argmax_argmin_op(x2, axis=1), relax.TensorStructInfo(dtype="int64")) + _check_inference(bb, argmax_argmin_op(x2, axis=1), relax.TensorType(dtype="int64")) _check_inference( bb, argmax_argmin_op(x2, axis=1, keepdims=True), - relax.TensorStructInfo(dtype="int64"), + relax.TensorType(dtype="int64"), ) - _check_inference(bb, argmax_argmin_op(x2, axis=None), relax.TensorStructInfo((), "int64")) + _check_inference(bb, argmax_argmin_op(x2, axis=None), relax.TensorType((), "int64")) _check_inference( bb, argmax_argmin_op(x2, axis=None, keepdims=True), - relax.TensorStructInfo(dtype="int64"), - ) - _check_inference( - bb, argmax_argmin_op(x3, axis=1), relax.TensorStructInfo((2, 4, 5), dtype="int64") + relax.TensorType(dtype="int64"), ) + _check_inference(bb, argmax_argmin_op(x3, axis=1), relax.TensorType((2, 4, 5), dtype="int64")) _check_inference( bb, argmax_argmin_op(x3, axis=1, keepdims=True), - relax.TensorStructInfo((2, 1, 4, 5), dtype="int64"), + relax.TensorType((2, 1, 4, 5), dtype="int64"), ) - _check_inference(bb, argmax_argmin_op(x3, axis=None), relax.TensorStructInfo((), dtype="int64")) + _check_inference(bb, argmax_argmin_op(x3, axis=None), relax.TensorType((), dtype="int64")) _check_inference( bb, argmax_argmin_op(x3, axis=None, keepdims=True), - relax.TensorStructInfo((1, 1, 1, 1), dtype="int64"), + relax.TensorType((1, 1, 1, 1), dtype="int64"), ) _check_inference( bb, argmax_argmin_op(x0, axis=1, keepdims=True), - relax.TensorStructInfo((2, 1, 4, 5), "int64"), + relax.TensorType((2, 1, 4, 5), "int64"), ) - _check_inference(bb, argmax_argmin_op(x0, axis=-1), relax.TensorStructInfo((2, 3, 4), "int64")) + _check_inference(bb, argmax_argmin_op(x0, axis=-1), relax.TensorType((2, 3, 4), "int64")) @pytest.mark.parametrize("argmax_argmin_op", argmax_argmin_ops) -def test_argmax_argmin_infer_struct_info_shape_symbolic(argmax_argmin_op: Callable): +def test_argmax_argmin_infer_ty_shape_symbolic(argmax_argmin_op: Callable): bb = relax.BlockBuilder() a = tirx.Var("a", "int64") b = tirx.Var("b", "int64") @@ -370,60 +326,58 @@ def test_argmax_argmin_infer_struct_info_shape_symbolic(argmax_argmin_op: Callab d = tirx.Var("d", "int64") x = relax.Var("x", R.Tensor((a, b, c, d), "int64")) - _check_inference(bb, argmax_argmin_op(x, axis=1), relax.TensorStructInfo((a, c, d), "int64")) + _check_inference(bb, argmax_argmin_op(x, axis=1), relax.TensorType((a, c, d), "int64")) _check_inference( bb, argmax_argmin_op(x, axis=1, keepdims=True), - relax.TensorStructInfo((a, 1, c, d), "int64"), + relax.TensorType((a, 1, c, d), "int64"), ) - _check_inference(bb, argmax_argmin_op(x, axis=None), relax.TensorStructInfo((), "int64")) + _check_inference(bb, argmax_argmin_op(x, axis=None), relax.TensorType((), "int64")) _check_inference( bb, argmax_argmin_op(x, axis=None, keepdims=True), - relax.TensorStructInfo((1, 1, 1, 1), "int64"), + relax.TensorType((1, 1, 1, 1), "int64"), ) @pytest.mark.parametrize("argmax_argmin_op", argmax_argmin_ops) -def test_argmax_argmin_infer_struct_info_shape_var(argmax_argmin_op: Callable): +def test_argmax_argmin_infer_ty_shape_var(argmax_argmin_op: Callable): bb = relax.BlockBuilder() - s0 = relax.Var("s", relax.ShapeStructInfo(ndim=4)) - s1 = relax.Var("s", relax.ShapeStructInfo()) - x0 = relax.Var("x", relax.TensorStructInfo(s0, "int64")) - x1 = relax.Var("x", relax.TensorStructInfo(s1, "int64")) + s0 = relax.Var("s", relax.ShapeType(ndim=4)) + s1 = relax.Var("s", relax.ShapeType()) + x0 = relax.Var("x", relax.TensorType(s0, "int64")) + x1 = relax.Var("x", relax.TensorType(s1, "int64")) - _check_inference(bb, argmax_argmin_op(x0), relax.TensorStructInfo((), dtype="int64")) - _check_inference( - bb, argmax_argmin_op(x0, keepdims=True), relax.TensorStructInfo((1, 1, 1, 1), dtype="int64") - ) + _check_inference(bb, argmax_argmin_op(x0), relax.TensorType((), dtype="int64")) _check_inference( - bb, argmax_argmin_op(x0, axis=2), relax.TensorStructInfo(dtype="int64", ndim=3) + bb, argmax_argmin_op(x0, keepdims=True), relax.TensorType((1, 1, 1, 1), dtype="int64") ) + _check_inference(bb, argmax_argmin_op(x0, axis=2), relax.TensorType(dtype="int64", ndim=3)) _check_inference( bb, argmax_argmin_op(x0, axis=2, keepdims=True), - relax.TensorStructInfo(dtype="int64", ndim=4), + relax.TensorType(dtype="int64", ndim=4), ) - _check_inference(bb, argmax_argmin_op(x1), relax.TensorStructInfo((), dtype="int64")) - _check_inference(bb, argmax_argmin_op(x1, keepdims=True), relax.TensorStructInfo(dtype="int64")) - _check_inference(bb, argmax_argmin_op(x1, axis=2), relax.TensorStructInfo(dtype="int64")) + _check_inference(bb, argmax_argmin_op(x1), relax.TensorType((), dtype="int64")) + _check_inference(bb, argmax_argmin_op(x1, keepdims=True), relax.TensorType(dtype="int64")) + _check_inference(bb, argmax_argmin_op(x1, axis=2), relax.TensorType(dtype="int64")) _check_inference( - bb, argmax_argmin_op(x1, axis=2, keepdims=True), relax.TensorStructInfo(dtype="int64") + bb, argmax_argmin_op(x1, axis=2, keepdims=True), relax.TensorType(dtype="int64") ) @pytest.mark.parametrize("argmax_argmin_op", argmax_argmin_ops) -def test_argmax_argmin_infer_struct_info_more_input_dtype(argmax_argmin_op: Callable): +def test_argmax_argmin_infer_ty_more_input_dtype(argmax_argmin_op: Callable): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float16")) x1 = relax.Var("x", R.Tensor((2, 3, 4, 5), "int8")) - _check_inference(bb, argmax_argmin_op(x0), relax.TensorStructInfo((), "int64")) - _check_inference(bb, argmax_argmin_op(x1), relax.TensorStructInfo((), "int64")) + _check_inference(bb, argmax_argmin_op(x0), relax.TensorType((), "int64")) + _check_inference(bb, argmax_argmin_op(x1), relax.TensorType((), "int64")) @pytest.mark.parametrize("argmax_argmin_op", argmax_argmin_ops) -def test_argmax_argmin_infer_struct_info_axis_out_of_range(argmax_argmin_op: Callable): +def test_argmax_argmin_infer_ty_axis_out_of_range(argmax_argmin_op: Callable): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "int64")) x1 = relax.Var("x", R.Tensor("int64", ndim=4)) @@ -439,10 +393,10 @@ def test_argmax_argmin_infer_struct_info_axis_out_of_range(argmax_argmin_op: Cal @pytest.mark.parametrize("argmax_argmin_op", argmax_argmin_ops) -def test_argmax_argmin_infer_struct_info_wrong_input_type(argmax_argmin_op: Callable): +def test_argmax_argmin_infer_ty_wrong_input_type(argmax_argmin_op: Callable): bb = relax.BlockBuilder() - x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 4, 5))) - x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 4, 5), "int64"))) + x0 = relax.Var("x", relax.ShapeType((2, 3, 4, 5))) + x1 = relax.Var("x", relax.FuncType([], R.Tensor((2, 3, 4, 5), "int64"))) with pytest.raises(TypeError): bb.normalize(argmax_argmin_op(x0)) diff --git a/tests/python/relax/test_op_set.py b/tests/python/relax/test_op_set.py index 83769c627f4f..8239669103d3 100644 --- a/tests/python/relax/test_op_set.py +++ b/tests/python/relax/test_op_set.py @@ -28,12 +28,12 @@ def test_op_correctness(): assert relax.op.unique(x).op == Op.get("relax.unique") -def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): +def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_ty: relax.Type): ret = bb.normalize(call) - tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) + tvm.ir.assert_structural_equal(ret.ty, expected_ty) -def test_unique_infer_struct_info(): +def test_unique_infer_ty(): bb = relax.BlockBuilder() vdev0 = VDevice("llvm") x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32")) @@ -47,39 +47,39 @@ def test_unique_infer_struct_info(): relax.op.unique( x0, return_index=False, return_inverse=False, return_counts=False, axis=None ), - relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorType(dtype="float32", ndim=1), ) _check_inference( bb, relax.op.unique( x4, return_index=False, return_inverse=False, return_counts=False, axis=None ), - relax.TensorStructInfo(dtype="float32", ndim=1, vdevice=vdev0), + relax.TensorType(dtype="float32", ndim=1, vdevice=vdev0), ) _check_inference( bb, relax.op.unique(x0, return_index=False, return_inverse=False, return_counts=False, axis=1), - relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorType(dtype="float32", ndim=3), ) _check_inference( bb, relax.op.unique( x0, return_index=False, return_inverse=False, return_counts=True, axis=None ), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="float32", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorType(dtype="float32", ndim=1), + relax.TensorType(dtype="int64", ndim=1), ] ), ) _check_inference( bb, relax.op.unique(x0, return_index=False, return_inverse=False, return_counts=True, axis=1), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="float32", ndim=3), - relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorType(dtype="float32", ndim=3), + relax.TensorType(dtype="int64", ndim=1), ] ), ) @@ -88,42 +88,42 @@ def test_unique_infer_struct_info(): relax.op.unique( x0, return_index=False, return_inverse=True, return_counts=False, axis=None ), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="float32", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorType(dtype="float32", ndim=1), + relax.TensorType(dtype="int64", ndim=1), ] ), ) _check_inference( bb, relax.op.unique(x0, return_index=False, return_inverse=True, return_counts=False, axis=1), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="float32", ndim=3), - relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorType(dtype="float32", ndim=3), + relax.TensorType(dtype="int64", ndim=1), ] ), ) _check_inference( bb, relax.op.unique(x0, return_index=False, return_inverse=True, return_counts=True, axis=None), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="float32", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorType(dtype="float32", ndim=1), + relax.TensorType(dtype="int64", ndim=1), + relax.TensorType(dtype="int64", ndim=1), ] ), ) _check_inference( bb, relax.op.unique(x0, return_index=False, return_inverse=True, return_counts=True, axis=1), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="float32", ndim=3), - relax.TensorStructInfo(dtype="int64", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorType(dtype="float32", ndim=3), + relax.TensorType(dtype="int64", ndim=1), + relax.TensorType(dtype="int64", ndim=1), ] ), ) @@ -132,100 +132,100 @@ def test_unique_infer_struct_info(): relax.op.unique( x0, return_index=True, return_inverse=False, return_counts=False, axis=None ), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="float32", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorType(dtype="float32", ndim=1), + relax.TensorType(dtype="int64", ndim=1), ] ), ) _check_inference( bb, relax.op.unique(x0, return_index=True, return_inverse=False, return_counts=False, axis=1), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="float32", ndim=3), - relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorType(dtype="float32", ndim=3), + relax.TensorType(dtype="int64", ndim=1), ] ), ) _check_inference( bb, relax.op.unique(x0, return_index=True, return_inverse=False, return_counts=True, axis=None), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="float32", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorType(dtype="float32", ndim=1), + relax.TensorType(dtype="int64", ndim=1), + relax.TensorType(dtype="int64", ndim=1), ] ), ) _check_inference( bb, relax.op.unique(x0, return_index=True, return_inverse=False, return_counts=True, axis=1), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="float32", ndim=3), - relax.TensorStructInfo(dtype="int64", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorType(dtype="float32", ndim=3), + relax.TensorType(dtype="int64", ndim=1), + relax.TensorType(dtype="int64", ndim=1), ] ), ) _check_inference( bb, relax.op.unique(x0, return_index=True, return_inverse=True, return_counts=False, axis=None), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="float32", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorType(dtype="float32", ndim=1), + relax.TensorType(dtype="int64", ndim=1), + relax.TensorType(dtype="int64", ndim=1), ] ), ) _check_inference( bb, relax.op.unique(x0, return_index=True, return_inverse=True, return_counts=False, axis=1), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="float32", ndim=3), - relax.TensorStructInfo(dtype="int64", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorType(dtype="float32", ndim=3), + relax.TensorType(dtype="int64", ndim=1), + relax.TensorType(dtype="int64", ndim=1), ] ), ) _check_inference( bb, relax.op.unique(x0, return_index=True, return_inverse=True, return_counts=True, axis=None), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="float32", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorType(dtype="float32", ndim=1), + relax.TensorType(dtype="int64", ndim=1), + relax.TensorType(dtype="int64", ndim=1), + relax.TensorType(dtype="int64", ndim=1), ] ), ) _check_inference( bb, relax.op.unique(x0, return_index=True, return_inverse=True, return_counts=True, axis=1), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="float32", ndim=3), - relax.TensorStructInfo(dtype="int64", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorType(dtype="float32", ndim=3), + relax.TensorType(dtype="int64", ndim=1), + relax.TensorType(dtype="int64", ndim=1), + relax.TensorType(dtype="int64", ndim=1), ] ), ) _check_inference( bb, relax.op.unique(x0, return_index=True, return_inverse=True, return_counts=True, axis=-2), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="float32", ndim=3), - relax.TensorStructInfo(dtype="int64", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorType(dtype="float32", ndim=3), + relax.TensorType(dtype="int64", ndim=1), + relax.TensorType(dtype="int64", ndim=1), + relax.TensorType(dtype="int64", ndim=1), ] ), ) @@ -234,12 +234,12 @@ def test_unique_infer_struct_info(): relax.op.unique( x0, sorted=True, return_index=True, return_inverse=True, return_counts=True, axis=None ), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="float32", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorType(dtype="float32", ndim=1), + relax.TensorType(dtype="int64", ndim=1), + relax.TensorType(dtype="int64", ndim=1), + relax.TensorType(dtype="int64", ndim=1), ] ), ) @@ -248,12 +248,12 @@ def test_unique_infer_struct_info(): relax.op.unique( x0, sorted=True, return_index=True, return_inverse=True, return_counts=True, axis=1 ), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="float32", ndim=3), - relax.TensorStructInfo(dtype="int64", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorType(dtype="float32", ndim=3), + relax.TensorType(dtype="int64", ndim=1), + relax.TensorType(dtype="int64", ndim=1), + relax.TensorType(dtype="int64", ndim=1), ] ), ) @@ -262,78 +262,78 @@ def test_unique_infer_struct_info(): relax.op.unique( x1, return_index=False, return_inverse=False, return_counts=False, axis=None ), - relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorType(dtype="float32", ndim=1), ) _check_inference( bb, relax.op.unique(x1, return_index=False, return_inverse=False, return_counts=False, axis=1), - relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorType(dtype="float32", ndim=3), ) _check_inference( bb, relax.op.unique( x1, return_index=False, return_inverse=True, return_counts=False, axis=None ), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="float32", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorType(dtype="float32", ndim=1), + relax.TensorType(dtype="int64", ndim=1), ] ), ) _check_inference( bb, relax.op.unique(x1, return_index=False, return_inverse=True, return_counts=False, axis=1), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="float32", ndim=3), - relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorType(dtype="float32", ndim=3), + relax.TensorType(dtype="int64", ndim=1), ] ), ) _check_inference( bb, relax.op.unique(x1, return_index=True, return_inverse=False, return_counts=True, axis=None), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="float32", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorType(dtype="float32", ndim=1), + relax.TensorType(dtype="int64", ndim=1), + relax.TensorType(dtype="int64", ndim=1), ] ), ) _check_inference( bb, relax.op.unique(x1, return_index=True, return_inverse=False, return_counts=True, axis=1), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="float32", ndim=3), - relax.TensorStructInfo(dtype="int64", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorType(dtype="float32", ndim=3), + relax.TensorType(dtype="int64", ndim=1), + relax.TensorType(dtype="int64", ndim=1), ] ), ) _check_inference( bb, relax.op.unique(x1, return_index=True, return_inverse=True, return_counts=True, axis=None), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="float32", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorType(dtype="float32", ndim=1), + relax.TensorType(dtype="int64", ndim=1), + relax.TensorType(dtype="int64", ndim=1), + relax.TensorType(dtype="int64", ndim=1), ] ), ) _check_inference( bb, relax.op.unique(x1, return_index=True, return_inverse=True, return_counts=True, axis=1), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="float32", ndim=3), - relax.TensorStructInfo(dtype="int64", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorType(dtype="float32", ndim=3), + relax.TensorType(dtype="int64", ndim=1), + relax.TensorType(dtype="int64", ndim=1), + relax.TensorType(dtype="int64", ndim=1), ] ), ) @@ -342,75 +342,75 @@ def test_unique_infer_struct_info(): relax.op.unique( x2, return_index=False, return_inverse=False, return_counts=False, axis=None ), - relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorType(dtype="float32", ndim=1), ) _check_inference( bb, relax.op.unique(x2, return_index=False, return_inverse=False, return_counts=False, axis=1), - relax.TensorStructInfo(dtype="float32"), + relax.TensorType(dtype="float32"), ) _check_inference( bb, relax.op.unique( x2, return_index=True, return_inverse=False, return_counts=False, axis=None ), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="float32", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorType(dtype="float32", ndim=1), + relax.TensorType(dtype="int64", ndim=1), ] ), ) _check_inference( bb, relax.op.unique(x2, return_index=True, return_inverse=False, return_counts=False, axis=1), - relax.TupleStructInfo( - [relax.TensorStructInfo(dtype="float32"), relax.TensorStructInfo(dtype="int64", ndim=1)] + relax.TupleType( + [relax.TensorType(dtype="float32"), relax.TensorType(dtype="int64", ndim=1)] ), ) _check_inference( bb, relax.op.unique(x2, return_index=True, return_inverse=True, return_counts=False, axis=None), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="float32", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorType(dtype="float32", ndim=1), + relax.TensorType(dtype="int64", ndim=1), + relax.TensorType(dtype="int64", ndim=1), ] ), ) _check_inference( bb, relax.op.unique(x2, return_index=True, return_inverse=True, return_counts=False, axis=1), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="float32"), - relax.TensorStructInfo(dtype="int64", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorType(dtype="float32"), + relax.TensorType(dtype="int64", ndim=1), + relax.TensorType(dtype="int64", ndim=1), ] ), ) _check_inference( bb, relax.op.unique(x2, return_index=True, return_inverse=True, return_counts=True, axis=None), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="float32", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorType(dtype="float32", ndim=1), + relax.TensorType(dtype="int64", ndim=1), + relax.TensorType(dtype="int64", ndim=1), + relax.TensorType(dtype="int64", ndim=1), ] ), ) _check_inference( bb, relax.op.unique(x2, return_index=True, return_inverse=True, return_counts=True, axis=1), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="float32"), - relax.TensorStructInfo(dtype="int64", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorType(dtype="float32"), + relax.TensorType(dtype="int64", ndim=1), + relax.TensorType(dtype="int64", ndim=1), + relax.TensorType(dtype="int64", ndim=1), ] ), ) @@ -419,84 +419,84 @@ def test_unique_infer_struct_info(): relax.op.unique( x3, return_index=False, return_inverse=False, return_counts=False, axis=None ), - relax.TensorStructInfo(dtype="", ndim=1), + relax.TensorType(dtype="", ndim=1), ) _check_inference( bb, relax.op.unique(x3, return_index=False, return_inverse=False, return_counts=False, axis=1), - relax.TensorStructInfo(dtype="", ndim=3), + relax.TensorType(dtype="", ndim=3), ) _check_inference( bb, relax.op.unique( x3, return_index=False, return_inverse=False, return_counts=True, axis=None ), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorType(dtype="", ndim=1), + relax.TensorType(dtype="int64", ndim=1), ] ), ) _check_inference( bb, relax.op.unique(x3, return_index=False, return_inverse=False, return_counts=True, axis=1), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="", ndim=3), - relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorType(dtype="", ndim=3), + relax.TensorType(dtype="int64", ndim=1), ] ), ) _check_inference( bb, relax.op.unique(x3, return_index=False, return_inverse=True, return_counts=True, axis=None), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorType(dtype="", ndim=1), + relax.TensorType(dtype="int64", ndim=1), + relax.TensorType(dtype="int64", ndim=1), ] ), ) _check_inference( bb, relax.op.unique(x3, return_index=False, return_inverse=True, return_counts=True, axis=1), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="", ndim=3), - relax.TensorStructInfo(dtype="int64", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorType(dtype="", ndim=3), + relax.TensorType(dtype="int64", ndim=1), + relax.TensorType(dtype="int64", ndim=1), ] ), ) _check_inference( bb, relax.op.unique(x3, return_index=True, return_inverse=True, return_counts=True, axis=None), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorType(dtype="", ndim=1), + relax.TensorType(dtype="int64", ndim=1), + relax.TensorType(dtype="int64", ndim=1), + relax.TensorType(dtype="int64", ndim=1), ] ), ) _check_inference( bb, relax.op.unique(x3, return_index=True, return_inverse=True, return_counts=True, axis=1), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="", ndim=3), - relax.TensorStructInfo(dtype="int64", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorType(dtype="", ndim=3), + relax.TensorType(dtype="int64", ndim=1), + relax.TensorType(dtype="int64", ndim=1), + relax.TensorType(dtype="int64", ndim=1), ] ), ) -def test_unique_infer_struct_info_shape_symbolic(): +def test_unique_infer_ty_shape_symbolic(): bb = relax.BlockBuilder() a = tirx.Var("a", "int64") b = tirx.Var("b", "int64") @@ -508,165 +508,165 @@ def test_unique_infer_struct_info_shape_symbolic(): relax.op.unique( x, return_index=False, return_inverse=False, return_counts=False, axis=None ), - relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorType(dtype="float32", ndim=1), ) _check_inference( bb, relax.op.unique(x, return_index=False, return_inverse=False, return_counts=False, axis=1), - relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorType(dtype="float32", ndim=3), ) _check_inference( bb, relax.op.unique(x, return_index=False, return_inverse=False, return_counts=True, axis=None), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="float32", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorType(dtype="float32", ndim=1), + relax.TensorType(dtype="int64", ndim=1), ] ), ) _check_inference( bb, relax.op.unique(x, return_index=False, return_inverse=False, return_counts=True, axis=1), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="float32", ndim=3), - relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorType(dtype="float32", ndim=3), + relax.TensorType(dtype="int64", ndim=1), ] ), ) _check_inference( bb, relax.op.unique(x, return_index=False, return_inverse=True, return_counts=True, axis=None), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="float32", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorType(dtype="float32", ndim=1), + relax.TensorType(dtype="int64", ndim=1), + relax.TensorType(dtype="int64", ndim=1), ] ), ) _check_inference( bb, relax.op.unique(x, return_index=False, return_inverse=True, return_counts=True, axis=1), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="float32", ndim=3), - relax.TensorStructInfo(dtype="int64", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorType(dtype="float32", ndim=3), + relax.TensorType(dtype="int64", ndim=1), + relax.TensorType(dtype="int64", ndim=1), ] ), ) _check_inference( bb, relax.op.unique(x, return_index=True, return_inverse=True, return_counts=True, axis=None), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="float32", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorType(dtype="float32", ndim=1), + relax.TensorType(dtype="int64", ndim=1), + relax.TensorType(dtype="int64", ndim=1), + relax.TensorType(dtype="int64", ndim=1), ] ), ) _check_inference( bb, relax.op.unique(x, return_index=True, return_inverse=True, return_counts=True, axis=1), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="float32", ndim=3), - relax.TensorStructInfo(dtype="int64", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorType(dtype="float32", ndim=3), + relax.TensorType(dtype="int64", ndim=1), + relax.TensorType(dtype="int64", ndim=1), + relax.TensorType(dtype="int64", ndim=1), ] ), ) -def test_unique_infer_struct_info_shape_var(): +def test_unique_infer_ty_shape_var(): bb = relax.BlockBuilder() - s0 = relax.Var("s", relax.ShapeStructInfo((2, 3, 4))) - s1 = relax.Var("s", relax.ShapeStructInfo()) - x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) - x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + s0 = relax.Var("s", relax.ShapeType((2, 3, 4))) + s1 = relax.Var("s", relax.ShapeType()) + x0 = relax.Var("x", relax.TensorType(s0, "float32")) + x1 = relax.Var("x", relax.TensorType(s1, "float32")) _check_inference( bb, relax.op.unique( x0, return_index=False, return_inverse=False, return_counts=False, axis=None ), - relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorType(dtype="float32", ndim=1), ) _check_inference( bb, relax.op.unique(x0, return_index=False, return_inverse=False, return_counts=False, axis=1), - relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorType(dtype="float32", ndim=3), ) _check_inference( bb, relax.op.unique( x0, return_index=False, return_inverse=False, return_counts=True, axis=None ), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="float32", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorType(dtype="float32", ndim=1), + relax.TensorType(dtype="int64", ndim=1), ] ), ) _check_inference( bb, relax.op.unique(x0, return_index=False, return_inverse=False, return_counts=True, axis=1), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="float32", ndim=3), - relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorType(dtype="float32", ndim=3), + relax.TensorType(dtype="int64", ndim=1), ] ), ) _check_inference( bb, relax.op.unique(x0, return_index=False, return_inverse=True, return_counts=True, axis=None), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="float32", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorType(dtype="float32", ndim=1), + relax.TensorType(dtype="int64", ndim=1), + relax.TensorType(dtype="int64", ndim=1), ] ), ) _check_inference( bb, relax.op.unique(x0, return_index=False, return_inverse=True, return_counts=True, axis=1), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="float32", ndim=3), - relax.TensorStructInfo(dtype="int64", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorType(dtype="float32", ndim=3), + relax.TensorType(dtype="int64", ndim=1), + relax.TensorType(dtype="int64", ndim=1), ] ), ) _check_inference( bb, relax.op.unique(x0, return_index=True, return_inverse=True, return_counts=True, axis=None), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="float32", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorType(dtype="float32", ndim=1), + relax.TensorType(dtype="int64", ndim=1), + relax.TensorType(dtype="int64", ndim=1), + relax.TensorType(dtype="int64", ndim=1), ] ), ) _check_inference( bb, relax.op.unique(x0, return_index=True, return_inverse=True, return_counts=True, axis=1), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="float32", ndim=3), - relax.TensorStructInfo(dtype="int64", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorType(dtype="float32", ndim=3), + relax.TensorType(dtype="int64", ndim=1), + relax.TensorType(dtype="int64", ndim=1), + relax.TensorType(dtype="int64", ndim=1), ] ), ) @@ -675,81 +675,81 @@ def test_unique_infer_struct_info_shape_var(): relax.op.unique( x1, return_index=False, return_inverse=False, return_counts=False, axis=None ), - relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorType(dtype="float32", ndim=1), ) _check_inference( bb, relax.op.unique(x1, return_index=False, return_inverse=False, return_counts=False, axis=1), - relax.TensorStructInfo(dtype="float32"), + relax.TensorType(dtype="float32"), ) _check_inference( bb, relax.op.unique( x1, return_index=False, return_inverse=False, return_counts=True, axis=None ), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="float32", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorType(dtype="float32", ndim=1), + relax.TensorType(dtype="int64", ndim=1), ] ), ) _check_inference( bb, relax.op.unique(x1, return_index=False, return_inverse=False, return_counts=True, axis=1), - relax.TupleStructInfo( - [relax.TensorStructInfo(dtype="float32"), relax.TensorStructInfo(dtype="int64", ndim=1)] + relax.TupleType( + [relax.TensorType(dtype="float32"), relax.TensorType(dtype="int64", ndim=1)] ), ) _check_inference( bb, relax.op.unique(x1, return_index=False, return_inverse=True, return_counts=True, axis=None), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="float32", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorType(dtype="float32", ndim=1), + relax.TensorType(dtype="int64", ndim=1), + relax.TensorType(dtype="int64", ndim=1), ] ), ) _check_inference( bb, relax.op.unique(x1, return_index=False, return_inverse=True, return_counts=True, axis=1), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="float32"), - relax.TensorStructInfo(dtype="int64", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorType(dtype="float32"), + relax.TensorType(dtype="int64", ndim=1), + relax.TensorType(dtype="int64", ndim=1), ] ), ) _check_inference( bb, relax.op.unique(x1, return_index=True, return_inverse=True, return_counts=True, axis=None), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="float32", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorType(dtype="float32", ndim=1), + relax.TensorType(dtype="int64", ndim=1), + relax.TensorType(dtype="int64", ndim=1), + relax.TensorType(dtype="int64", ndim=1), ] ), ) _check_inference( bb, relax.op.unique(x1, return_index=True, return_inverse=True, return_counts=True, axis=1), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="float32"), - relax.TensorStructInfo(dtype="int64", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorType(dtype="float32"), + relax.TensorType(dtype="int64", ndim=1), + relax.TensorType(dtype="int64", ndim=1), + relax.TensorType(dtype="int64", ndim=1), ] ), ) -def test_unique_infer_struct_info_more_input_dtype(): +def test_unique_infer_ty_more_input_dtype(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3, 4), "float16")) x1 = relax.Var("x", R.Tensor((2, 3, 4), "int8")) @@ -758,70 +758,70 @@ def test_unique_infer_struct_info_more_input_dtype(): _check_inference( bb, relax.op.unique(x0, return_index=True, return_inverse=True, return_counts=True, axis=None), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="float16", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorType(dtype="float16", ndim=1), + relax.TensorType(dtype="int64", ndim=1), + relax.TensorType(dtype="int64", ndim=1), + relax.TensorType(dtype="int64", ndim=1), ] ), ) _check_inference( bb, relax.op.unique(x1, return_index=True, return_inverse=True, return_counts=True, axis=None), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="int8", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorType(dtype="int8", ndim=1), + relax.TensorType(dtype="int64", ndim=1), + relax.TensorType(dtype="int64", ndim=1), + relax.TensorType(dtype="int64", ndim=1), ] ), ) _check_inference( bb, relax.op.unique(x2, return_index=True, return_inverse=True, return_counts=True, axis=None), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="int32", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), - relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorType(dtype="int32", ndim=1), + relax.TensorType(dtype="int64", ndim=1), + relax.TensorType(dtype="int64", ndim=1), + relax.TensorType(dtype="int64", ndim=1), ] ), ) -def test_unique_infer_struct_info_input_zero_rank(): +def test_unique_infer_ty_input_zero_rank(): bb = relax.BlockBuilder() - s0 = relax.Var("s", relax.ShapeStructInfo(())) - s1 = relax.Var("s", relax.ShapeStructInfo(ndim=0)) + s0 = relax.Var("s", relax.ShapeType(())) + s1 = relax.Var("s", relax.ShapeType(ndim=0)) x0 = relax.Var("x", R.Tensor((), "float32")) x1 = relax.Var("x", R.Tensor("float32", ndim=0)) - x2 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) - x3 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorType(s0, "float32")) + x3 = relax.Var("x", relax.TensorType(s1, "float32")) _check_inference( bb, relax.op.unique(x0, return_index=True, return_inverse=True, return_counts=True, axis=None), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo((1,), "float32"), - relax.TensorStructInfo((1,), "int64"), - relax.TensorStructInfo((1,), "int64"), - relax.TensorStructInfo((1,), "int64"), + relax.TensorType((1,), "float32"), + relax.TensorType((1,), "int64"), + relax.TensorType((1,), "int64"), + relax.TensorType((1,), "int64"), ] ), ) _check_inference( bb, relax.op.unique(x1, return_index=True, return_inverse=True, return_counts=False, axis=None), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo((1,), "float32"), - relax.TensorStructInfo((1,), "int64"), - relax.TensorStructInfo((1,), "int64"), + relax.TensorType((1,), "float32"), + relax.TensorType((1,), "int64"), + relax.TensorType((1,), "int64"), ] ), ) @@ -830,20 +830,18 @@ def test_unique_infer_struct_info_input_zero_rank(): relax.op.unique( x2, return_index=True, return_inverse=False, return_counts=False, axis=None ), - relax.TupleStructInfo( - [relax.TensorStructInfo((1,), "float32"), relax.TensorStructInfo((1,), "int64")] - ), + relax.TupleType([relax.TensorType((1,), "float32"), relax.TensorType((1,), "int64")]), ) _check_inference( bb, relax.op.unique( x3, return_index=False, return_inverse=False, return_counts=False, axis=None ), - relax.TensorStructInfo((1,), "float32"), + relax.TensorType((1,), "float32"), ) -def test_unique_infer_struct_info_axis_out_of_range(): +def test_unique_infer_ty_axis_out_of_range(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32")) x1 = relax.Var("x", R.Tensor((), "float32")) @@ -856,10 +854,10 @@ def test_unique_infer_struct_info_axis_out_of_range(): bb.normalize(relax.op.unique(x1, axis=0)) -def test_unique_infer_struct_info_wrong_input_dtype(): +def test_unique_infer_ty_wrong_input_dtype(): bb = relax.BlockBuilder() - x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 4))) - x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 4), "float32"))) + x0 = relax.Var("x", relax.ShapeType((2, 3, 4))) + x1 = relax.Var("x", relax.FuncType([], R.Tensor((2, 3, 4), "float32"))) with pytest.raises(TypeError): bb.normalize(relax.op.unique(x0)) @@ -868,32 +866,32 @@ def test_unique_infer_struct_info_wrong_input_dtype(): @pytest.mark.parametrize("shape", [(1,), (2, 3), (4, 5, 6)]) -def test_nonzero_infer_struct_info(shape): +def test_nonzero_infer_ty(shape): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor(shape, "bool")) _check_inference( bb, relax.op.nonzero(x0), - relax.TensorStructInfo(ndim=2, dtype="int64"), + relax.TensorType(ndim=2, dtype="int64"), ) -def test_nonzero_infer_struct_info_ndim_zero(): +def test_nonzero_infer_ty_ndim_zero(): bb = relax.BlockBuilder() x = relax.Var("x", R.Tensor((), "bool")) _check_inference( bb, relax.op.nonzero(x), - relax.TensorStructInfo(ndim=2, dtype="int64"), + relax.TensorType(ndim=2, dtype="int64"), ) -def test_nonzero_infer_struct_info_wrong_input_dtype(): +def test_nonzero_infer_ty_wrong_input_dtype(): bb = relax.BlockBuilder() - x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 4))) - x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 4), "float32"))) + x0 = relax.Var("x", relax.ShapeType((2, 3, 4))) + x1 = relax.Var("x", relax.FuncType([], R.Tensor((2, 3, 4), "float32"))) with pytest.raises(TypeError): bb.normalize(relax.op.nonzero(x0)) diff --git a/tests/python/relax/test_op_sort.py b/tests/python/relax/test_op_sort.py index 6a3f8cb437f9..c40b58b89941 100644 --- a/tests/python/relax/test_op_sort.py +++ b/tests/python/relax/test_op_sort.py @@ -30,12 +30,12 @@ def test_op_correctness(): assert relax.op.topk(x, k=1, axis=1).op == Op.get("relax.topk") -def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): +def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_ty: relax.Type): ret = bb.normalize(call) - tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) + tvm.ir.assert_structural_equal(ret.ty, expected_ty) -def test_sort_infer_struct_info(): +def test_sort_infer_ty(): bb = relax.BlockBuilder() vdev0 = VDevice("llvm") x0 = relax.Var("x", R.Tensor((2, 10, 4), "float32")) @@ -46,47 +46,45 @@ def test_sort_infer_struct_info(): x5 = relax.Var("x", R.Tensor()) x6 = relax.Var("x", R.Tensor((2, 10, 4), "float32", vdev0)) - _check_inference(bb, relax.op.sort(x0, axis=1), relax.TensorStructInfo((2, 10, 4), "float32")) - _check_inference( - bb, relax.op.sort(x6, axis=1), relax.TensorStructInfo((2, 10, 4), "float32", vdev0) - ) - _check_inference(bb, relax.op.sort(x1, axis=1), relax.TensorStructInfo(dtype="float32", ndim=3)) - _check_inference(bb, relax.op.sort(x2, axis=1), relax.TensorStructInfo(dtype="float32")) - _check_inference(bb, relax.op.sort(x3, axis=1), relax.TensorStructInfo((2, 10, 4), dtype="")) - _check_inference(bb, relax.op.sort(x4, axis=1), relax.TensorStructInfo(dtype="", ndim=3)) - _check_inference(bb, relax.op.sort(x5, axis=1), relax.TensorStructInfo(dtype="")) - _check_inference(bb, relax.op.sort(x0), relax.TensorStructInfo((2, 10, 4), "float32")) + _check_inference(bb, relax.op.sort(x0, axis=1), relax.TensorType((2, 10, 4), "float32")) + _check_inference(bb, relax.op.sort(x6, axis=1), relax.TensorType((2, 10, 4), "float32", vdev0)) + _check_inference(bb, relax.op.sort(x1, axis=1), relax.TensorType(dtype="float32", ndim=3)) + _check_inference(bb, relax.op.sort(x2, axis=1), relax.TensorType(dtype="float32")) + _check_inference(bb, relax.op.sort(x3, axis=1), relax.TensorType((2, 10, 4), dtype="")) + _check_inference(bb, relax.op.sort(x4, axis=1), relax.TensorType(dtype="", ndim=3)) + _check_inference(bb, relax.op.sort(x5, axis=1), relax.TensorType(dtype="")) + _check_inference(bb, relax.op.sort(x0), relax.TensorType((2, 10, 4), "float32")) _check_inference( bb, relax.op.sort(x0, axis=1, descending=False), - relax.TensorStructInfo((2, 10, 4), "float32"), + relax.TensorType((2, 10, 4), "float32"), ) -def test_sort_infer_struct_info_shape_symbolic(): +def test_sort_infer_ty_shape_symbolic(): bb = relax.BlockBuilder() a = tirx.Var("a", "int64") b = tirx.Var("b", "int64") c = tirx.Var("c", "int64") x = relax.Var("x", R.Tensor((a, b, c), "float32")) - _check_inference(bb, relax.op.sort(x, axis=1), relax.TensorStructInfo((a, b, c), "float32")) - _check_inference(bb, relax.op.sort(x), relax.TensorStructInfo((a, b, c), "float32")) + _check_inference(bb, relax.op.sort(x, axis=1), relax.TensorType((a, b, c), "float32")) + _check_inference(bb, relax.op.sort(x), relax.TensorType((a, b, c), "float32")) -def test_sort_infer_struct_info_more_input_dtype(): +def test_sort_infer_ty_more_input_dtype(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3, 4), "float16")) x1 = relax.Var("x", R.Tensor((2, 3, 4), "int8")) - _check_inference(bb, relax.op.sort(x0, axis=1), relax.TensorStructInfo((2, 3, 4), "float16")) - _check_inference(bb, relax.op.sort(x1, axis=1), relax.TensorStructInfo((2, 3, 4), "int8")) + _check_inference(bb, relax.op.sort(x0, axis=1), relax.TensorType((2, 3, 4), "float16")) + _check_inference(bb, relax.op.sort(x1, axis=1), relax.TensorType((2, 3, 4), "int8")) def test_sort_wrong_input(): bb = relax.BlockBuilder() - x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 4, 5))) - x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 4, 5), "float32"))) + x0 = relax.Var("x", relax.ShapeType((2, 3, 4, 5))) + x1 = relax.Var("x", relax.FuncType([], R.Tensor((2, 3, 4, 5), "float32"))) x = relax.Var("x", R.Tensor((3, 4, 5), "float32")) y = relax.Var("y", R.Tensor((2, 3, 4), "float32")) @@ -100,7 +98,7 @@ def test_sort_wrong_input(): bb.normalize(relax.op.sort(x1, axis=1)) -def test_argsort_infer_struct_info(): +def test_argsort_infer_ty(): bb = relax.BlockBuilder() vdev0 = VDevice("llvm") x0 = relax.Var("x", R.Tensor((2, 10, 4), "float32")) @@ -114,44 +112,36 @@ def test_argsort_infer_struct_info(): _check_inference( bb, relax.op.argsort(x0, axis=1, descending=False, dtype="int64"), - relax.TensorStructInfo((2, 10, 4), "int64"), - ) - _check_inference( - bb, relax.op.argsort(x6, axis=1), relax.TensorStructInfo((2, 10, 4), "int32", vdev0) - ) - _check_inference( - bb, relax.op.argsort(x1, axis=1), relax.TensorStructInfo(dtype="int32", ndim=3) + relax.TensorType((2, 10, 4), "int64"), ) + _check_inference(bb, relax.op.argsort(x6, axis=1), relax.TensorType((2, 10, 4), "int32", vdev0)) + _check_inference(bb, relax.op.argsort(x1, axis=1), relax.TensorType(dtype="int32", ndim=3)) _check_inference( - bb, relax.op.argsort(x2, axis=1, dtype="float16"), relax.TensorStructInfo(dtype="float16") + bb, relax.op.argsort(x2, axis=1, dtype="float16"), relax.TensorType(dtype="float16") ) - _check_inference( - bb, relax.op.argsort(x3, axis=1), relax.TensorStructInfo((2, 10, 4), dtype="int32") - ) - _check_inference( - bb, relax.op.argsort(x4, axis=1), relax.TensorStructInfo(dtype="int32", ndim=3) - ) - _check_inference(bb, relax.op.argsort(x5, axis=1), relax.TensorStructInfo(dtype="int32")) - _check_inference(bb, relax.op.argsort(x0), relax.TensorStructInfo((2, 10, 4), "int32")) + _check_inference(bb, relax.op.argsort(x3, axis=1), relax.TensorType((2, 10, 4), dtype="int32")) + _check_inference(bb, relax.op.argsort(x4, axis=1), relax.TensorType(dtype="int32", ndim=3)) + _check_inference(bb, relax.op.argsort(x5, axis=1), relax.TensorType(dtype="int32")) + _check_inference(bb, relax.op.argsort(x0), relax.TensorType((2, 10, 4), "int32")) _check_inference( bb, relax.op.argsort(x0, axis=1, descending=False), - relax.TensorStructInfo((2, 10, 4), "int32"), + relax.TensorType((2, 10, 4), "int32"), ) -def test_argsort_infer_struct_info_shape_symbolic(): +def test_argsort_infer_ty_shape_symbolic(): bb = relax.BlockBuilder() a = tirx.Var("a", "int64") b = tirx.Var("b", "int64") c = tirx.Var("c", "int64") x = relax.Var("x", R.Tensor((a, b, c), "float32")) - _check_inference(bb, relax.op.argsort(x, axis=1), relax.TensorStructInfo((a, b, c), "int32")) - _check_inference(bb, relax.op.argsort(x), relax.TensorStructInfo((a, b, c), "int32")) + _check_inference(bb, relax.op.argsort(x, axis=1), relax.TensorType((a, b, c), "int32")) + _check_inference(bb, relax.op.argsort(x), relax.TensorType((a, b, c), "int32")) -def test_topk_infer_struct_info(): +def test_topk_infer_ty(): bb = relax.BlockBuilder() vdev0 = VDevice("llvm") x0 = relax.Var("x", R.Tensor((2, 10, 4), "float32")) @@ -165,103 +155,101 @@ def test_topk_infer_struct_info(): _check_inference( bb, relax.op.topk(x0, k=5, axis=1, ret_type="both", largest=False, dtype="int64"), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo((2, 5, 4), "float32"), - relax.TensorStructInfo((2, 5, 4), "int64"), + relax.TensorType((2, 5, 4), "float32"), + relax.TensorType((2, 5, 4), "int64"), ] ), ) _check_inference( bb, relax.op.topk(x6), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo((2, 10, 1), "float32", vdev0), - relax.TensorStructInfo((2, 10, 1), "int32", vdev0), + relax.TensorType((2, 10, 1), "float32", vdev0), + relax.TensorType((2, 10, 1), "int32", vdev0), ] ), ) _check_inference( bb, relax.op.topk(x1, k=3, axis=1), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="float32", ndim=3), - relax.TensorStructInfo(dtype="int32", ndim=3), + relax.TensorType(dtype="float32", ndim=3), + relax.TensorType(dtype="int32", ndim=3), ] ), ) _check_inference( bb, relax.op.topk(x2), - relax.TupleStructInfo( - [relax.TensorStructInfo(dtype="float32"), relax.TensorStructInfo(dtype="int32")] - ), + relax.TupleType([relax.TensorType(dtype="float32"), relax.TensorType(dtype="int32")]), ) _check_inference( bb, relax.op.topk(x3, axis=0), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo((1, 10, 4), None), - relax.TensorStructInfo((1, 10, 4), dtype="int32"), + relax.TensorType((1, 10, 4), None), + relax.TensorType((1, 10, 4), dtype="int32"), ] ), ) _check_inference( bb, relax.op.topk(x4, axis=1), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(ndim=3, dtype=None), - relax.TensorStructInfo(dtype="int32", ndim=3), + relax.TensorType(ndim=3, dtype=None), + relax.TensorType(dtype="int32", ndim=3), ] ), ) _check_inference( bb, relax.op.topk(x5, axis=1), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype=None), - relax.TensorStructInfo(dtype="int32"), + relax.TensorType(dtype=None), + relax.TensorType(dtype="int32"), ] ), ) _check_inference( bb, relax.op.topk(x0), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo((2, 10, 1), "float32"), - relax.TensorStructInfo((2, 10, 1), "int32"), + relax.TensorType((2, 10, 1), "float32"), + relax.TensorType((2, 10, 1), "int32"), ] ), ) _check_inference( bb, relax.op.topk(x0, k=-1), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo((2, 10, 4), "float32"), - relax.TensorStructInfo((2, 10, 4), "int32"), + relax.TensorType((2, 10, 4), "float32"), + relax.TensorType((2, 10, 4), "int32"), ] ), ) _check_inference( bb, relax.op.topk(x0, k=6), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo((2, 10, 4), "float32"), - relax.TensorStructInfo((2, 10, 4), "int32"), + relax.TensorType((2, 10, 4), "float32"), + relax.TensorType((2, 10, 4), "int32"), ] ), ) -def test_topk_infer_struct_info_shape_symbolic(): +def test_topk_infer_ty_shape_symbolic(): bb = relax.BlockBuilder() a = tirx.Var("a", "int64") b = tirx.Var("b", "int64") @@ -271,20 +259,20 @@ def test_topk_infer_struct_info_shape_symbolic(): _check_inference( bb, relax.op.topk(x, axis=1), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo((a, 1, c), "float32"), - relax.TensorStructInfo((a, 1, c), "int32"), + relax.TensorType((a, 1, c), "float32"), + relax.TensorType((a, 1, c), "int32"), ] ), ) _check_inference( bb, relax.op.topk(x, k=3), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo((a, b, 3), "float32"), - relax.TensorStructInfo((a, b, 3), "int32"), + relax.TensorType((a, b, 3), "float32"), + relax.TensorType((a, b, 3), "int32"), ] ), ) diff --git a/tests/python/relax/test_op_statistical.py b/tests/python/relax/test_op_statistical.py index 590584326739..06c2dbe05535 100644 --- a/tests/python/relax/test_op_statistical.py +++ b/tests/python/relax/test_op_statistical.py @@ -37,12 +37,12 @@ def test_op_correctness(): assert relax.op.median(x).op == Op.get("relax.median") -def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): +def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_ty: relax.Type): ret = bb.normalize(call) - tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) + tvm.ir.assert_structural_equal(ret.ty, expected_ty) -def test_statistical_infer_struct_info(): +def test_statistical_infer_ty(): bb = relax.BlockBuilder() vdev0 = VDevice("llvm") x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) @@ -51,78 +51,72 @@ def test_statistical_infer_struct_info(): x3 = relax.Var("x", R.Tensor((2, 3, 4, 5))) x4 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32", vdev0)) - _check_inference(bb, relax.op.sum(x0, axis=[1, 2]), relax.TensorStructInfo((2, 5), "float32")) - _check_inference( - bb, relax.op.sum(x4, axis=[1, 2]), relax.TensorStructInfo((2, 5), "float32", vdev0) - ) + _check_inference(bb, relax.op.sum(x0, axis=[1, 2]), relax.TensorType((2, 5), "float32")) + _check_inference(bb, relax.op.sum(x4, axis=[1, 2]), relax.TensorType((2, 5), "float32", vdev0)) _check_inference( bb, relax.op.sum(x0, axis=[1, 2], keepdims=True), - relax.TensorStructInfo((2, 1, 1, 5), "float32"), + relax.TensorType((2, 1, 1, 5), "float32"), ) - _check_inference(bb, relax.op.sum(x0, axis=None), relax.TensorStructInfo((), "float32")) + _check_inference(bb, relax.op.sum(x0, axis=None), relax.TensorType((), "float32")) _check_inference( bb, relax.op.sum(x0, axis=None, keepdims=True), - relax.TensorStructInfo((1, 1, 1, 1), "float32"), - ) - _check_inference( - bb, relax.op.mean(x1, axis=[1, 2]), relax.TensorStructInfo(dtype="float32", ndim=2) + relax.TensorType((1, 1, 1, 1), "float32"), ) + _check_inference(bb, relax.op.mean(x1, axis=[1, 2]), relax.TensorType(dtype="float32", ndim=2)) _check_inference( bb, relax.op.mean(x1, axis=[1, 2], keepdims=True), - relax.TensorStructInfo(dtype="float32", ndim=4), + relax.TensorType(dtype="float32", ndim=4), ) - _check_inference(bb, relax.op.mean(x1, axis=None), relax.TensorStructInfo((), "float32")) + _check_inference(bb, relax.op.mean(x1, axis=None), relax.TensorType((), "float32")) _check_inference( bb, relax.op.mean(x1, axis=None, keepdims=True), - relax.TensorStructInfo((1, 1, 1, 1), "float32"), - ) - _check_inference( - bb, relax.op.variance(x2, axis=[1, 2]), relax.TensorStructInfo(dtype="float32") + relax.TensorType((1, 1, 1, 1), "float32"), ) + _check_inference(bb, relax.op.variance(x2, axis=[1, 2]), relax.TensorType(dtype="float32")) _check_inference( bb, relax.op.variance(x2, axis=[1, 2], keepdims=True), - relax.TensorStructInfo(dtype="float32"), + relax.TensorType(dtype="float32"), ) - _check_inference(bb, relax.op.variance(x2, axis=None), relax.TensorStructInfo((), "float32")) + _check_inference(bb, relax.op.variance(x2, axis=None), relax.TensorType((), "float32")) _check_inference( bb, relax.op.variance(x2, axis=None, keepdims=True), - relax.TensorStructInfo(dtype="float32"), + relax.TensorType(dtype="float32"), ) - _check_inference(bb, relax.op.max(x3, axis=[1, 2]), relax.TensorStructInfo((2, 5), dtype="")) + _check_inference(bb, relax.op.max(x3, axis=[1, 2]), relax.TensorType((2, 5), dtype="")) _check_inference( bb, relax.op.max(x3, axis=[1, 2], keepdims=True), - relax.TensorStructInfo((2, 1, 1, 5), dtype=""), + relax.TensorType((2, 1, 1, 5), dtype=""), ) - _check_inference(bb, relax.op.max(x3, axis=None), relax.TensorStructInfo((), dtype="")) + _check_inference(bb, relax.op.max(x3, axis=None), relax.TensorType((), dtype="")) _check_inference( bb, relax.op.max(x3, axis=None, keepdims=True), - relax.TensorStructInfo((1, 1, 1, 1), dtype=""), + relax.TensorType((1, 1, 1, 1), dtype=""), ) - _check_inference(bb, relax.op.prod(x0, axis=[1, 2]), relax.TensorStructInfo((2, 5), "float32")) + _check_inference(bb, relax.op.prod(x0, axis=[1, 2]), relax.TensorType((2, 5), "float32")) _check_inference( bb, relax.op.prod(x0, axis=[1, 2], keepdims=True), - relax.TensorStructInfo((2, 1, 1, 5), "float32"), + relax.TensorType((2, 1, 1, 5), "float32"), ) - _check_inference(bb, relax.op.std(x0, axis=[1, 2]), relax.TensorStructInfo((2, 5), "float32")) + _check_inference(bb, relax.op.std(x0, axis=[1, 2]), relax.TensorType((2, 5), "float32")) _check_inference( bb, relax.op.std(x0, axis=[1, 2], keepdims=True), - relax.TensorStructInfo((2, 1, 1, 5), "float32"), + relax.TensorType((2, 1, 1, 5), "float32"), ) - _check_inference(bb, relax.op.sum(x0, axis=[-1, -4]), relax.TensorStructInfo((3, 4), "float32")) - _check_inference(bb, relax.op.sum(x0, axis=[]), relax.TensorStructInfo((2, 3, 4, 5), "float32")) + _check_inference(bb, relax.op.sum(x0, axis=[-1, -4]), relax.TensorType((3, 4), "float32")) + _check_inference(bb, relax.op.sum(x0, axis=[]), relax.TensorType((2, 3, 4, 5), "float32")) -def test_statistical_infer_struct_info_shape_symbolic(): +def test_statistical_infer_ty_shape_symbolic(): bb = relax.BlockBuilder() a = tirx.Var("a", "int64") b = tirx.Var("b", "int64") @@ -130,57 +124,55 @@ def test_statistical_infer_struct_info_shape_symbolic(): d = tirx.Var("d", "int64") x = relax.Var("x", R.Tensor((a, b, c, d), "float32")) - _check_inference(bb, relax.op.min(x, axis=[1, 2]), relax.TensorStructInfo((a, d), "float32")) + _check_inference(bb, relax.op.min(x, axis=[1, 2]), relax.TensorType((a, d), "float32")) _check_inference( bb, relax.op.min(x, axis=[1, 2], keepdims=True), - relax.TensorStructInfo((a, 1, 1, d), "float32"), + relax.TensorType((a, 1, 1, d), "float32"), ) - _check_inference(bb, relax.op.min(x, axis=None), relax.TensorStructInfo((), "float32")) + _check_inference(bb, relax.op.min(x, axis=None), relax.TensorType((), "float32")) _check_inference( bb, relax.op.min(x, axis=None, keepdims=True), - relax.TensorStructInfo((1, 1, 1, 1), "float32"), + relax.TensorType((1, 1, 1, 1), "float32"), ) -def test_statistical_infer_struct_info_shape_var(): +def test_statistical_infer_ty_shape_var(): bb = relax.BlockBuilder() - s0 = relax.Var("s", relax.ShapeStructInfo(ndim=4)) - s1 = relax.Var("s", relax.ShapeStructInfo()) - x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) - x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + s0 = relax.Var("s", relax.ShapeType(ndim=4)) + s1 = relax.Var("s", relax.ShapeType()) + x0 = relax.Var("x", relax.TensorType(s0, "float32")) + x1 = relax.Var("x", relax.TensorType(s1, "float32")) - _check_inference(bb, relax.op.max(x0), relax.TensorStructInfo((), dtype="float32")) - _check_inference( - bb, relax.op.max(x0, keepdims=True), relax.TensorStructInfo((1, 1, 1, 1), dtype="float32") - ) + _check_inference(bb, relax.op.max(x0), relax.TensorType((), dtype="float32")) _check_inference( - bb, relax.op.max(x0, axis=[2, 3]), relax.TensorStructInfo(dtype="float32", ndim=2) + bb, relax.op.max(x0, keepdims=True), relax.TensorType((1, 1, 1, 1), dtype="float32") ) + _check_inference(bb, relax.op.max(x0, axis=[2, 3]), relax.TensorType(dtype="float32", ndim=2)) _check_inference( bb, relax.op.max(x0, axis=[2, 3], keepdims=True), - relax.TensorStructInfo(dtype="float32", ndim=4), + relax.TensorType(dtype="float32", ndim=4), ) - _check_inference(bb, relax.op.max(x1), relax.TensorStructInfo((), dtype="float32")) - _check_inference(bb, relax.op.max(x1, keepdims=True), relax.TensorStructInfo(dtype="float32")) - _check_inference(bb, relax.op.max(x1, axis=[2, 3]), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.max(x1), relax.TensorType((), dtype="float32")) + _check_inference(bb, relax.op.max(x1, keepdims=True), relax.TensorType(dtype="float32")) + _check_inference(bb, relax.op.max(x1, axis=[2, 3]), relax.TensorType(dtype="float32")) _check_inference( - bb, relax.op.max(x1, axis=[2, 3], keepdims=True), relax.TensorStructInfo(dtype="float32") + bb, relax.op.max(x1, axis=[2, 3], keepdims=True), relax.TensorType(dtype="float32") ) -def test_statistical_infer_struct_info_more_input_dtype(): +def test_statistical_infer_ty_more_input_dtype(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float16")) x1 = relax.Var("x", R.Tensor((2, 3, 4, 5), "int8")) - _check_inference(bb, relax.op.sum(x0), relax.TensorStructInfo((), "float16")) - _check_inference(bb, relax.op.sum(x1), relax.TensorStructInfo((), "int8")) + _check_inference(bb, relax.op.sum(x0), relax.TensorType((), "float16")) + _check_inference(bb, relax.op.sum(x1), relax.TensorType((), "int8")) -def test_statistical_infer_struct_info_axis_out_of_range_repetitive(): +def test_statistical_infer_ty_axis_out_of_range_repetitive(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) x1 = relax.Var("x", R.Tensor("float32", ndim=4)) @@ -197,10 +189,10 @@ def test_statistical_infer_struct_info_axis_out_of_range_repetitive(): bb.normalize(relax.op.mean(x0, axis=[-5])) -def test_statistical_infer_struct_info_wrong_input_type(): +def test_statistical_infer_ty_wrong_input_type(): bb = relax.BlockBuilder() - x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 4, 5))) - x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 4, 5), "float32"))) + x0 = relax.Var("x", relax.ShapeType((2, 3, 4, 5))) + x1 = relax.Var("x", relax.FuncType([], R.Tensor((2, 3, 4, 5), "float32"))) with pytest.raises(TypeError): bb.normalize(relax.op.variance(x0)) @@ -215,7 +207,7 @@ def test_statistical_infer_struct_info_wrong_input_type(): @pytest.mark.parametrize("scan_op", scan_ops) -def test_scan_op_infer_struct_info(scan_op: Callable): +def test_scan_op_infer_ty(scan_op: Callable): bb = relax.BlockBuilder() vdev0 = VDevice("llvm") x0 = relax.Var("x", R.Tensor((2, 10, 4), "float32")) @@ -226,39 +218,37 @@ def test_scan_op_infer_struct_info(scan_op: Callable): x5 = relax.Var("x", R.Tensor()) x6 = relax.Var("x", R.Tensor((2, 10, 4), "float32", vdev0)) - _check_inference(bb, scan_op(x0, axis=1), relax.TensorStructInfo((2, 10, 4), "float32")) - _check_inference(bb, scan_op(x6, axis=1), relax.TensorStructInfo((2, 10, 4), "float32", vdev0)) - _check_inference(bb, scan_op(x1, axis=1), relax.TensorStructInfo(dtype="float32", ndim=3)) - _check_inference(bb, scan_op(x2, axis=1), relax.TensorStructInfo(dtype="float32")) - _check_inference(bb, scan_op(x3, axis=1), relax.TensorStructInfo((2, 10, 4), dtype="")) - _check_inference(bb, scan_op(x4, axis=1), relax.TensorStructInfo(dtype="", ndim=3)) - _check_inference(bb, scan_op(x5, axis=1), relax.TensorStructInfo(dtype="")) - _check_inference(bb, scan_op(x0), relax.TensorStructInfo((80,), "float32")) - _check_inference( - bb, scan_op(x0, axis=1, dtype="int32"), relax.TensorStructInfo((2, 10, 4), "int32") - ) + _check_inference(bb, scan_op(x0, axis=1), relax.TensorType((2, 10, 4), "float32")) + _check_inference(bb, scan_op(x6, axis=1), relax.TensorType((2, 10, 4), "float32", vdev0)) + _check_inference(bb, scan_op(x1, axis=1), relax.TensorType(dtype="float32", ndim=3)) + _check_inference(bb, scan_op(x2, axis=1), relax.TensorType(dtype="float32")) + _check_inference(bb, scan_op(x3, axis=1), relax.TensorType((2, 10, 4), dtype="")) + _check_inference(bb, scan_op(x4, axis=1), relax.TensorType(dtype="", ndim=3)) + _check_inference(bb, scan_op(x5, axis=1), relax.TensorType(dtype="")) + _check_inference(bb, scan_op(x0), relax.TensorType((80,), "float32")) + _check_inference(bb, scan_op(x0, axis=1, dtype="int32"), relax.TensorType((2, 10, 4), "int32")) @pytest.mark.parametrize("scan_op", scan_ops) -def test_scan_op_infer_struct_info_shape_symbolic(scan_op: Callable): +def test_scan_op_infer_ty_shape_symbolic(scan_op: Callable): bb = relax.BlockBuilder() a = tirx.Var("a", "int64") b = tirx.Var("b", "int64") c = tirx.Var("c", "int64") x = relax.Var("x", R.Tensor((a, b, c), "float32")) - _check_inference(bb, scan_op(x, axis=1), relax.TensorStructInfo((a, b, c), "float32")) - _check_inference(bb, scan_op(x), relax.TensorStructInfo((a * b * c,), "float32")) + _check_inference(bb, scan_op(x, axis=1), relax.TensorType((a, b, c), "float32")) + _check_inference(bb, scan_op(x), relax.TensorType((a * b * c,), "float32")) @pytest.mark.parametrize("scan_op", scan_ops) -def test_scan_op_infer_struct_info_more_input_dtype(scan_op: Callable): +def test_scan_op_infer_ty_more_input_dtype(scan_op: Callable): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3, 4), "float16")) x1 = relax.Var("x", R.Tensor((2, 3, 4), "int8")) - _check_inference(bb, scan_op(x0, axis=1), relax.TensorStructInfo((2, 3, 4), "float16")) - _check_inference(bb, scan_op(x1, axis=1), relax.TensorStructInfo((2, 3, 4), "int8")) + _check_inference(bb, scan_op(x0, axis=1), relax.TensorType((2, 3, 4), "float16")) + _check_inference(bb, scan_op(x1, axis=1), relax.TensorType((2, 3, 4), "int8")) @pytest.mark.parametrize("scan_op", scan_ops) @@ -271,10 +261,10 @@ def test_scan_op_wrong_input_number(scan_op: Callable): @pytest.mark.parametrize("scan_op", scan_ops) -def test_scan_opinfer_struct_info_wrong_input_type(scan_op: Callable): +def test_scan_opinfer_ty_wrong_input_type(scan_op: Callable): bb = relax.BlockBuilder() - x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 4, 5))) - x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 4, 5), "float32"))) + x0 = relax.Var("x", relax.ShapeType((2, 3, 4, 5))) + x1 = relax.Var("x", relax.FuncType([], R.Tensor((2, 3, 4, 5), "float32"))) with pytest.raises(TypeError): bb.normalize(scan_op(x0, axis=1)) @@ -282,7 +272,7 @@ def test_scan_opinfer_struct_info_wrong_input_type(scan_op: Callable): bb.normalize(scan_op(x1, axis=1)) -def test_statistical_ext_infer_struct_info(): +def test_statistical_ext_infer_ty(): bb = relax.BlockBuilder() vdev0 = VDevice("llvm") x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) @@ -294,98 +284,98 @@ def test_statistical_ext_infer_struct_info(): _check_inference( bb, relax.op.median(x0, axis=[1]), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo((2, 4, 5), "float32"), - relax.TensorStructInfo((2, 4, 5), "int64"), + relax.TensorType((2, 4, 5), "float32"), + relax.TensorType((2, 4, 5), "int64"), ] ), ) _check_inference( bb, relax.op.median(x0, axis=[1], keepdims=True), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo((2, 1, 4, 5), "float32"), - relax.TensorStructInfo((2, 1, 4, 5), "int64"), + relax.TensorType((2, 1, 4, 5), "float32"), + relax.TensorType((2, 1, 4, 5), "int64"), ] ), ) _check_inference( bb, relax.op.median(x1, axis=[1]), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="float32", ndim=3), - relax.TensorStructInfo(dtype="int64", ndim=3), + relax.TensorType(dtype="float32", ndim=3), + relax.TensorType(dtype="int64", ndim=3), ] ), ) _check_inference( bb, relax.op.median(x1, axis=[1], keepdims=True), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="float32", ndim=4), - relax.TensorStructInfo(dtype="int64", ndim=4), + relax.TensorType(dtype="float32", ndim=4), + relax.TensorType(dtype="int64", ndim=4), ] ), ) _check_inference( bb, relax.op.median(x1, axis=None, keepdims=True), - relax.TensorStructInfo((1, 1, 1, 1), "float32"), + relax.TensorType((1, 1, 1, 1), "float32"), ) _check_inference( bb, relax.op.median(x2, axis=[1]), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="float32"), - relax.TensorStructInfo(dtype="int64"), + relax.TensorType(dtype="float32"), + relax.TensorType(dtype="int64"), ] ), ) _check_inference( bb, relax.op.median(x2, axis=[1], keepdims=True), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="float32"), - relax.TensorStructInfo(dtype="int64"), + relax.TensorType(dtype="float32"), + relax.TensorType(dtype="int64"), ] ), ) - _check_inference(bb, relax.op.median(x2, axis=None), relax.TensorStructInfo((), "float32")) + _check_inference(bb, relax.op.median(x2, axis=None), relax.TensorType((), "float32")) _check_inference( bb, relax.op.median(x3, axis=[1], keepdims=True), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo((2, 1, 4, 5), dtype=""), - relax.TensorStructInfo((2, 1, 4, 5), dtype="int64"), + relax.TensorType((2, 1, 4, 5), dtype=""), + relax.TensorType((2, 1, 4, 5), dtype="int64"), ] ), ) - _check_inference(bb, relax.op.median(x3, axis=None), relax.TensorStructInfo((), dtype="")) + _check_inference(bb, relax.op.median(x3, axis=None), relax.TensorType((), dtype="")) _check_inference( bb, relax.op.median(x3, axis=None, keepdims=True), - relax.TensorStructInfo((1, 1, 1, 1), dtype=""), + relax.TensorType((1, 1, 1, 1), dtype=""), ) _check_inference( bb, relax.op.median(x4, axis=[1]), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo((2, 4, 5), "float32", vdev0), - relax.TensorStructInfo((2, 4, 5), "int64", vdev0), + relax.TensorType((2, 4, 5), "float32", vdev0), + relax.TensorType((2, 4, 5), "int64", vdev0), ] ), ) -def test_statistical_ext_infer_struct_info_shape_symbolic(): +def test_statistical_ext_infer_ty_shape_symbolic(): bb = relax.BlockBuilder() a = tirx.Var("a", "int64") b = tirx.Var("b", "int64") @@ -396,110 +386,110 @@ def test_statistical_ext_infer_struct_info_shape_symbolic(): _check_inference( bb, relax.op.median(x, axis=[1]), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo((a, c, d), "float32"), - relax.TensorStructInfo((a, c, d), "int64"), + relax.TensorType((a, c, d), "float32"), + relax.TensorType((a, c, d), "int64"), ] ), ) _check_inference( bb, relax.op.median(x, axis=[1], keepdims=True), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo((a, 1, c, d), "float32"), - relax.TensorStructInfo((a, 1, c, d), "int64"), + relax.TensorType((a, 1, c, d), "float32"), + relax.TensorType((a, 1, c, d), "int64"), ] ), ) - _check_inference(bb, relax.op.median(x, axis=None), relax.TensorStructInfo((), "float32")) + _check_inference(bb, relax.op.median(x, axis=None), relax.TensorType((), "float32")) _check_inference( bb, relax.op.median(x, axis=None, keepdims=True), - relax.TensorStructInfo((1, 1, 1, 1), "float32"), + relax.TensorType((1, 1, 1, 1), "float32"), ) -def test_statistical_ext_infer_struct_info_shape_var(): +def test_statistical_ext_infer_ty_shape_var(): bb = relax.BlockBuilder() - s0 = relax.Var("s", relax.ShapeStructInfo(ndim=4)) - s1 = relax.Var("s", relax.ShapeStructInfo()) - x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) - x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + s0 = relax.Var("s", relax.ShapeType(ndim=4)) + s1 = relax.Var("s", relax.ShapeType()) + x0 = relax.Var("x", relax.TensorType(s0, "float32")) + x1 = relax.Var("x", relax.TensorType(s1, "float32")) - _check_inference(bb, relax.op.median(x0), relax.TensorStructInfo((), dtype="float32")) + _check_inference(bb, relax.op.median(x0), relax.TensorType((), dtype="float32")) _check_inference( bb, relax.op.median(x0, keepdims=True), - relax.TensorStructInfo((1, 1, 1, 1), dtype="float32"), + relax.TensorType((1, 1, 1, 1), dtype="float32"), ) _check_inference( bb, relax.op.median(x0, axis=[2]), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="float32", ndim=3), - relax.TensorStructInfo(dtype="int64", ndim=3), + relax.TensorType(dtype="float32", ndim=3), + relax.TensorType(dtype="int64", ndim=3), ] ), ) _check_inference( bb, relax.op.median(x0, axis=[2], keepdims=True), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="float32", ndim=4), - relax.TensorStructInfo(dtype="int64", ndim=4), + relax.TensorType(dtype="float32", ndim=4), + relax.TensorType(dtype="int64", ndim=4), ] ), ) - _check_inference(bb, relax.op.median(x1), relax.TensorStructInfo((), dtype="float32")) + _check_inference(bb, relax.op.median(x1), relax.TensorType((), dtype="float32")) _check_inference( bb, relax.op.median(x1, keepdims=True), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="float32"), - relax.TensorStructInfo(dtype="int64"), + relax.TensorType(dtype="float32"), + relax.TensorType(dtype="int64"), ] ), ) _check_inference( bb, relax.op.median(x1, axis=[2]), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="float32"), - relax.TensorStructInfo(dtype="int64"), + relax.TensorType(dtype="float32"), + relax.TensorType(dtype="int64"), ] ), ) _check_inference( bb, relax.op.median(x1, axis=[2], keepdims=True), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo(dtype="float32"), - relax.TensorStructInfo(dtype="int64"), + relax.TensorType(dtype="float32"), + relax.TensorType(dtype="int64"), ] ), ) -def test_statistical_ext_infer_struct_info_more_input_dtype(): +def test_statistical_ext_infer_ty_more_input_dtype(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float16")) x1 = relax.Var("x", R.Tensor((2, 3, 4, 5), "int8")) - _check_inference(bb, relax.op.median(x0), relax.TensorStructInfo((), "float16")) - _check_inference(bb, relax.op.median(x1), relax.TensorStructInfo((), "int8")) + _check_inference(bb, relax.op.median(x0), relax.TensorType((), "float16")) + _check_inference(bb, relax.op.median(x1), relax.TensorType((), "int8")) -def test_statistical_ext_infer_struct_info_wrong_input_type(): +def test_statistical_ext_infer_ty_wrong_input_type(): bb = relax.BlockBuilder() - x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 4, 5))) - x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 4, 5), "float32"))) + x0 = relax.Var("x", relax.ShapeType((2, 3, 4, 5))) + x1 = relax.Var("x", relax.FuncType([], R.Tensor((2, 3, 4, 5), "float32"))) with pytest.raises(TypeError): bb.normalize(relax.op.median(x0)) diff --git a/tests/python/relax/test_op_ternary.py b/tests/python/relax/test_op_ternary.py index ba7f5e18c402..82b6df52f13e 100644 --- a/tests/python/relax/test_op_ternary.py +++ b/tests/python/relax/test_op_ternary.py @@ -30,12 +30,12 @@ def test_op_correctness(): assert relax.op.ewise_fma(x, y, z).op == Op.get("relax.ewise_fma") -def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): +def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_ty: relax.Type): ret = bb.normalize(call) - tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) + tvm.ir.assert_structural_equal(ret.ty, expected_ty) -def test_ewise_fma_infer_struct_info(): +def test_ewise_fma_infer_ty(): bb = relax.BlockBuilder() vdev0 = VDevice("llvm") x0 = relax.Var("x", R.Tensor((2, 3), "float32")) @@ -48,20 +48,14 @@ def test_ewise_fma_infer_struct_info(): z1 = relax.Var("z", R.Tensor("float32")) z2 = relax.Var("z", R.Tensor((2, 3), "float32", vdev0)) - _check_inference(bb, relax.op.ewise_fma(x0, y0, z0), relax.TensorStructInfo((2, 3), "float32")) - _check_inference( - bb, relax.op.ewise_fma(x2, y2, z2), relax.TensorStructInfo((2, 3), "float32", vdev0) - ) - _check_inference( - bb, relax.op.ewise_fma(x0, y1, z0), relax.TensorStructInfo(dtype="float32", ndim=2) - ) - _check_inference( - bb, relax.op.ewise_fma(x0, y1, z1), relax.TensorStructInfo(dtype="float32", ndim=2) - ) - _check_inference(bb, relax.op.ewise_fma(x1, y0, z0), relax.TensorStructInfo((2, 3), dtype="")) + _check_inference(bb, relax.op.ewise_fma(x0, y0, z0), relax.TensorType((2, 3), "float32")) + _check_inference(bb, relax.op.ewise_fma(x2, y2, z2), relax.TensorType((2, 3), "float32", vdev0)) + _check_inference(bb, relax.op.ewise_fma(x0, y1, z0), relax.TensorType(dtype="float32", ndim=2)) + _check_inference(bb, relax.op.ewise_fma(x0, y1, z1), relax.TensorType(dtype="float32", ndim=2)) + _check_inference(bb, relax.op.ewise_fma(x1, y0, z0), relax.TensorType((2, 3), dtype="")) -def test_ewise_fma_infer_struct_info_shape_symbolic(): +def test_ewise_fma_infer_ty_shape_symbolic(): bb = relax.BlockBuilder() m = tirx.Var("m", "int64") n = tirx.Var("n", "int64") @@ -70,33 +64,27 @@ def test_ewise_fma_infer_struct_info_shape_symbolic(): y1 = relax.Var("y", R.Tensor(dtype="float32", ndim=2)) z0 = relax.Var("z", R.Tensor((m, n), "float32")) - _check_inference(bb, relax.op.ewise_fma(x0, y0, z0), relax.TensorStructInfo((m, n), "float32")) - _check_inference( - bb, relax.op.ewise_fma(x0, y1, z0), relax.TensorStructInfo(dtype="float32", ndim=2) - ) + _check_inference(bb, relax.op.ewise_fma(x0, y0, z0), relax.TensorType((m, n), "float32")) + _check_inference(bb, relax.op.ewise_fma(x0, y1, z0), relax.TensorType(dtype="float32", ndim=2)) -def test_ewise_fma_infer_struct_info_shape_var(): +def test_ewise_fma_infer_ty_shape_var(): bb = relax.BlockBuilder() - s0 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) - s1 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) - s2 = relax.Var("s", relax.ShapeStructInfo()) - x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) - x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) - x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) - y = relax.Var("y", relax.TensorStructInfo(s0, "float32")) - z = relax.Var("z", relax.TensorStructInfo(s0, "float32")) - - _check_inference(bb, relax.op.ewise_fma(x0, y, z), relax.TensorStructInfo(s0, "float32")) - _check_inference( - bb, relax.op.ewise_fma(x1, y, z), relax.TensorStructInfo(dtype="float32", ndim=2) - ) - _check_inference( - bb, relax.op.ewise_fma(x2, y, z), relax.TensorStructInfo(dtype="float32", ndim=2) - ) - - -def test_ewise_fma_infer_struct_info_more_input_dtype(): + s0 = relax.Var("s", relax.ShapeType(ndim=2)) + s1 = relax.Var("s", relax.ShapeType(ndim=2)) + s2 = relax.Var("s", relax.ShapeType()) + x0 = relax.Var("x", relax.TensorType(s0, "float32")) + x1 = relax.Var("x", relax.TensorType(s1, "float32")) + x2 = relax.Var("x", relax.TensorType(s2, "float32")) + y = relax.Var("y", relax.TensorType(s0, "float32")) + z = relax.Var("z", relax.TensorType(s0, "float32")) + + _check_inference(bb, relax.op.ewise_fma(x0, y, z), relax.TensorType(s0, "float32")) + _check_inference(bb, relax.op.ewise_fma(x1, y, z), relax.TensorType(dtype="float32", ndim=2)) + _check_inference(bb, relax.op.ewise_fma(x2, y, z), relax.TensorType(dtype="float32", ndim=2)) + + +def test_ewise_fma_infer_ty_more_input_dtype(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3), "float64")) y0 = relax.Var("y", R.Tensor((2, 3), "float64")) @@ -108,12 +96,12 @@ def test_ewise_fma_infer_struct_info_more_input_dtype(): y2 = relax.Var("y", R.Tensor((2, 3), "int64")) z2 = relax.Var("z", R.Tensor((2, 3), "int64")) - _check_inference(bb, relax.op.ewise_fma(x0, y0, z0), relax.TensorStructInfo((2, 3), "float64")) - _check_inference(bb, relax.op.ewise_fma(x1, y1, z1), relax.TensorStructInfo((2, 3), "int8")) - _check_inference(bb, relax.op.ewise_fma(x2, y2, z2), relax.TensorStructInfo((2, 3), "int64")) + _check_inference(bb, relax.op.ewise_fma(x0, y0, z0), relax.TensorType((2, 3), "float64")) + _check_inference(bb, relax.op.ewise_fma(x1, y1, z1), relax.TensorType((2, 3), "int8")) + _check_inference(bb, relax.op.ewise_fma(x2, y2, z2), relax.TensorType((2, 3), "int64")) -def test_ewise_fma_infer_struct_info_dtype_mismatch(): +def test_ewise_fma_infer_ty_dtype_mismatch(): bb = relax.BlockBuilder() x = relax.Var("x", R.Tensor((2, 3), "float32")) y0 = relax.Var("y", R.Tensor((2, 3), "int32")) @@ -127,7 +115,7 @@ def test_ewise_fma_infer_struct_info_dtype_mismatch(): bb.normalize(relax.op.ewise_fma(x, y1, z1)) -def test_ewise_fma_infer_struct_info_ndim_mismatch(): +def test_ewise_fma_infer_ty_ndim_mismatch(): bb = relax.BlockBuilder() x = relax.Var("x", R.Tensor((2, 3), "float32")) y0 = relax.Var("y", R.Tensor((2, 3), "float32")) @@ -152,11 +140,11 @@ def test_ewise_fma_wrong_input_number(): relax.op.ewise_fma(x, x, x, x) -def test_ewise_fma_infer_struct_info_wrong_input_type(): +def test_ewise_fma_infer_ty_wrong_input_type(): bb = relax.BlockBuilder() x = relax.Var("x", R.Tensor((2, 3), "float32")) - y0 = relax.Var("y", relax.ShapeStructInfo((2, 3))) - y1 = relax.Var("y", relax.FuncStructInfo([], R.Tensor((2, 3), "float32"))) + y0 = relax.Var("y", relax.ShapeType((2, 3))) + y1 = relax.Var("y", relax.FuncType([], R.Tensor((2, 3), "float32"))) z = relax.Var("z", R.Tensor((2, 3), "float32")) with pytest.raises(TypeError): diff --git a/tests/python/relax/test_op_unary.py b/tests/python/relax/test_op_unary.py index 0527c2e9de44..9d2658b32e9a 100644 --- a/tests/python/relax/test_op_unary.py +++ b/tests/python/relax/test_op_unary.py @@ -63,9 +63,9 @@ def test_op_correctness(): assert relax.op.logical_not(x).op == Op.get("relax.logical_not") -def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): +def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_ty: relax.Type): ret = bb.normalize(call) - tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) + tvm.ir.assert_structural_equal(ret.ty, expected_ty) unary_arith_ops = [ @@ -97,7 +97,7 @@ def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: r @pytest.mark.parametrize("unary_arith_op", [row[0] for row in unary_arith_ops]) -def test_unary_arith_infer_struct_info(unary_arith_op: Callable): +def test_unary_arith_infer_ty(unary_arith_op: Callable): bb = relax.BlockBuilder() vdev0 = VDevice("llvm") x0 = relax.Var("x", R.Tensor((2, 3), "float32")) @@ -107,42 +107,40 @@ def test_unary_arith_infer_struct_info(unary_arith_op: Callable): x4 = relax.Var("x", R.Tensor()) x5 = relax.Var("x", R.Tensor((2, 3), "float32", vdev0)) - _check_inference(bb, unary_arith_op(x0), relax.TensorStructInfo((2, 3), "float32")) - _check_inference(bb, unary_arith_op(x5), relax.TensorStructInfo((2, 3), "float32", vdev0)) - _check_inference(bb, unary_arith_op(x1), relax.TensorStructInfo(dtype="float32", ndim=3)) - _check_inference(bb, unary_arith_op(x2), relax.TensorStructInfo(dtype="float32")) - _check_inference(bb, unary_arith_op(x3), relax.TensorStructInfo((2, 3), dtype="")) - _check_inference(bb, unary_arith_op(x4), relax.TensorStructInfo(dtype="")) + _check_inference(bb, unary_arith_op(x0), relax.TensorType((2, 3), "float32")) + _check_inference(bb, unary_arith_op(x5), relax.TensorType((2, 3), "float32", vdev0)) + _check_inference(bb, unary_arith_op(x1), relax.TensorType(dtype="float32", ndim=3)) + _check_inference(bb, unary_arith_op(x2), relax.TensorType(dtype="float32")) + _check_inference(bb, unary_arith_op(x3), relax.TensorType((2, 3), dtype="")) + _check_inference(bb, unary_arith_op(x4), relax.TensorType(dtype="")) @pytest.mark.parametrize("unary_arith_op", [row[0] for row in unary_arith_ops]) -def test_unary_arith_infer_struct_info_shape_symbolic(unary_arith_op: Callable): +def test_unary_arith_infer_ty_shape_symbolic(unary_arith_op: Callable): bb = relax.BlockBuilder() m = tirx.Var("m", "int64") n = tirx.Var("n", "int64") x0 = relax.Var("x", R.Tensor((m, n), "float32")) x1 = relax.Var("x", R.Tensor((4, n), "float32")) - _check_inference(bb, unary_arith_op(x0), relax.TensorStructInfo((m, n), "float32")) - _check_inference(bb, unary_arith_op(x1), relax.TensorStructInfo((4, n), "float32")) + _check_inference(bb, unary_arith_op(x0), relax.TensorType((m, n), "float32")) + _check_inference(bb, unary_arith_op(x1), relax.TensorType((4, n), "float32")) @pytest.mark.parametrize("unary_arith_op", [row[0] for row in unary_arith_ops]) -def test_unary_arith_infer_struct_info_shape_var(unary_arith_op: Callable): +def test_unary_arith_infer_ty_shape_var(unary_arith_op: Callable): bb = relax.BlockBuilder() - s0 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) - s1 = relax.Var("s", relax.ShapeStructInfo()) - x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) - x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + s0 = relax.Var("s", relax.ShapeType(ndim=2)) + s1 = relax.Var("s", relax.ShapeType()) + x0 = relax.Var("x", relax.TensorType(s0, "float32")) + x1 = relax.Var("x", relax.TensorType(s1, "float32")) - _check_inference(bb, unary_arith_op(x0), relax.TensorStructInfo(s0, "float32")) - _check_inference(bb, unary_arith_op(x1), relax.TensorStructInfo(s1, "float32")) + _check_inference(bb, unary_arith_op(x0), relax.TensorType(s0, "float32")) + _check_inference(bb, unary_arith_op(x1), relax.TensorType(s1, "float32")) @pytest.mark.parametrize("unary_arith_op,require_float_dtype", unary_arith_ops) -def test_unary_arith_infer_struct_info_more_input_dtype( - unary_arith_op: Callable, require_float_dtype: bool -): +def test_unary_arith_infer_ty_more_input_dtype(unary_arith_op: Callable, require_float_dtype: bool): if require_float_dtype: return @@ -151,13 +149,13 @@ def test_unary_arith_infer_struct_info_more_input_dtype( x1 = relax.Var("x", R.Tensor((2, 3), "int8")) x2 = relax.Var("x", R.Tensor((2, 3), "int64")) - _check_inference(bb, unary_arith_op(x0), relax.TensorStructInfo((2, 3), "float64")) - _check_inference(bb, unary_arith_op(x1), relax.TensorStructInfo((2, 3), "int8")) - _check_inference(bb, unary_arith_op(x2), relax.TensorStructInfo((2, 3), "int64")) + _check_inference(bb, unary_arith_op(x0), relax.TensorType((2, 3), "float64")) + _check_inference(bb, unary_arith_op(x1), relax.TensorType((2, 3), "int8")) + _check_inference(bb, unary_arith_op(x2), relax.TensorType((2, 3), "int64")) @pytest.mark.parametrize("unary_arith_op,require_float_dtype", unary_arith_ops) -def test_unary_arith_infer_struct_info_invalid_input_dtype( +def test_unary_arith_infer_ty_invalid_input_dtype( unary_arith_op: Callable, require_float_dtype: bool ): if not require_float_dtype: @@ -184,10 +182,10 @@ def test_unary_arith_wrong_input_number(unary_arith_op: Callable): @pytest.mark.parametrize("unary_arith_op", [row[0] for row in unary_arith_ops]) -def test_unary_arith_infer_struct_info_wrong_input_type(unary_arith_op: Callable): +def test_unary_arith_infer_ty_wrong_input_type(unary_arith_op: Callable): bb = relax.BlockBuilder() - x0 = relax.Var("x", relax.ShapeStructInfo((2, 3))) - x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3), "float32"))) + x0 = relax.Var("x", relax.ShapeType((2, 3))) + x1 = relax.Var("x", relax.FuncType([], R.Tensor((2, 3), "float32"))) with pytest.raises(TypeError): bb.normalize(unary_arith_op(x0)) @@ -195,7 +193,7 @@ def test_unary_arith_infer_struct_info_wrong_input_type(unary_arith_op: Callable bb.normalize(unary_arith_op(x1)) -def test_clip_infer_struct_info(): +def test_clip_infer_ty(): bb = relax.BlockBuilder() vdev0 = VDevice("llvm") x0 = relax.Var("x", R.Tensor((2, 3), "float32")) @@ -205,12 +203,12 @@ def test_clip_infer_struct_info(): x4 = relax.Var("x", R.Tensor()) x5 = relax.Var("x", R.Tensor((2, 3), "float32", vdev0)) - _check_inference(bb, relax.op.clip(x0, 0, 6), relax.TensorStructInfo((2, 3), "float32")) - _check_inference(bb, relax.op.clip(x5, 0, 6), relax.TensorStructInfo((2, 3), "float32", vdev0)) - _check_inference(bb, relax.op.clip(x1, 0, 6), relax.TensorStructInfo(dtype="float32", ndim=3)) - _check_inference(bb, relax.op.clip(x2, 0, 6), relax.TensorStructInfo(dtype="float32")) - _check_inference(bb, relax.op.clip(x3, 0, 6), relax.TensorStructInfo((2, 3), dtype="")) - _check_inference(bb, relax.op.clip(x4, 0, 6), relax.TensorStructInfo(dtype="")) + _check_inference(bb, relax.op.clip(x0, 0, 6), relax.TensorType((2, 3), "float32")) + _check_inference(bb, relax.op.clip(x5, 0, 6), relax.TensorType((2, 3), "float32", vdev0)) + _check_inference(bb, relax.op.clip(x1, 0, 6), relax.TensorType(dtype="float32", ndim=3)) + _check_inference(bb, relax.op.clip(x2, 0, 6), relax.TensorType(dtype="float32")) + _check_inference(bb, relax.op.clip(x3, 0, 6), relax.TensorType((2, 3), dtype="")) + _check_inference(bb, relax.op.clip(x4, 0, 6), relax.TensorType(dtype="")) # Symbolic m = tirx.Var("m", "int64") @@ -218,8 +216,8 @@ def test_clip_infer_struct_info(): x5 = relax.Var("x", R.Tensor((m, n), "float32")) x6 = relax.Var("x", R.Tensor((4, n), "float32")) - _check_inference(bb, relax.op.clip(x5, 0, 6), relax.TensorStructInfo((m, n), "float32")) - _check_inference(bb, relax.op.clip(x6, 0, 6), relax.TensorStructInfo((4, n), "float32")) + _check_inference(bb, relax.op.clip(x5, 0, 6), relax.TensorType((m, n), "float32")) + _check_inference(bb, relax.op.clip(x6, 0, 6), relax.TensorType((4, n), "float32")) if __name__ == "__main__": diff --git a/tests/python/relax/test_op_view.py b/tests/python/relax/test_op_view.py index 454db01e352e..6654990c6f88 100644 --- a/tests/python/relax/test_op_view.py +++ b/tests/python/relax/test_op_view.py @@ -27,30 +27,30 @@ def test_infer_shape_of_1d_static_view(): @R.function(private=True) - def explicit_sinfo(A: R.Tensor) -> R.Tensor([4096]): + def explicit_ty(A: R.Tensor) -> R.Tensor([4096]): B: R.Tensor([4096]) = R.memory.view(A, R.shape([4096])) return B @R.function(private=True) - def inferred_sinfo(A: R.Tensor): + def inferred_ty(A: R.Tensor): B = R.memory.view(A, R.shape([4096])) return B - tvm.ir.assert_structural_equal(explicit_sinfo, inferred_sinfo) + tvm.ir.assert_structural_equal(explicit_ty, inferred_ty) def test_infer_shape_of_2d_static_view(): @R.function(private=True) - def explicit_sinfo(A: R.Tensor) -> R.Tensor([64, 64]): + def explicit_ty(A: R.Tensor) -> R.Tensor([64, 64]): B: R.Tensor([64, 64]) = R.memory.view(A, R.shape([64, 64])) return B @R.function(private=True) - def inferred_sinfo(A: R.Tensor): + def inferred_ty(A: R.Tensor): B = R.memory.view(A, R.shape([64, 64])) return B - tvm.ir.assert_structural_equal(explicit_sinfo, inferred_sinfo) + tvm.ir.assert_structural_equal(explicit_ty, inferred_ty) def test_error_if_shape_argument_is_not_shape(): @@ -64,44 +64,44 @@ def func(A: R.Tensor([16])): def test_infer_shape_of_1d_static_view_smaller_than_1d_source(): @R.function(private=True) - def explicit_sinfo(A: R.Tensor([4096])) -> R.Tensor([16]): + def explicit_ty(A: R.Tensor([4096])) -> R.Tensor([16]): B: R.Tensor([16]) = R.memory.view(A, R.shape([16])) return B @R.function(private=True) - def inferred_sinfo(A: R.Tensor([4096])): + def inferred_ty(A: R.Tensor([4096])): B = R.memory.view(A, R.shape([16])) return B - tvm.ir.assert_structural_equal(explicit_sinfo, inferred_sinfo) + tvm.ir.assert_structural_equal(explicit_ty, inferred_ty) def test_infer_shape_of_2d_static_view_smaller_than_1d_source(): @R.function(private=True) - def explicit_sinfo(A: R.Tensor([4096])) -> R.Tensor([4, 4]): + def explicit_ty(A: R.Tensor([4096])) -> R.Tensor([4, 4]): B: R.Tensor([4, 4]) = R.memory.view(A, R.shape([4, 4])) return B @R.function(private=True) - def inferred_sinfo(A: R.Tensor([4096])): + def inferred_ty(A: R.Tensor([4096])): B = R.memory.view(A, R.shape([4, 4])) return B - tvm.ir.assert_structural_equal(explicit_sinfo, inferred_sinfo) + tvm.ir.assert_structural_equal(explicit_ty, inferred_ty) def test_infer_shape_of_2d_static_view_same_size_as_2d_source(): @R.function(private=True) - def explicit_sinfo(A: R.Tensor([64, 64])) -> R.Tensor([16, 256]): + def explicit_ty(A: R.Tensor([64, 64])) -> R.Tensor([16, 256]): B: R.Tensor([16, 256]) = R.memory.view(A, R.shape([16, 256])) return B @R.function(private=True) - def inferred_sinfo(A: R.Tensor([64, 64])): + def inferred_ty(A: R.Tensor([64, 64])): B = R.memory.view(A, R.shape([16, 256])) return B - tvm.ir.assert_structural_equal(explicit_sinfo, inferred_sinfo) + tvm.ir.assert_structural_equal(explicit_ty, inferred_ty) def test_error_if_1d_static_view_larger_than_1d_source(): @@ -124,50 +124,50 @@ def func(A: R.Tensor([16])): def test_infer_shape_of_1d_dynamic_view(): @R.function(private=True) - def explicit_sinfo(A: R.Tensor(["N"])) -> R.Tensor(["N // 2"]): + def explicit_ty(A: R.Tensor(["N"])) -> R.Tensor(["N // 2"]): N = T.int64() B: R.Tensor([N // 2]) = R.memory.view(A, R.shape([N // 2])) return B @R.function(private=True) - def inferred_sinfo(A: R.Tensor(["N"])): + def inferred_ty(A: R.Tensor(["N"])): N = T.int64() B = R.memory.view(A, R.shape([N // 2])) return B - tvm.ir.assert_structural_equal(explicit_sinfo, inferred_sinfo) + tvm.ir.assert_structural_equal(explicit_ty, inferred_ty) def test_infer_shape_of_2d_dynamic_view_of_1d_source(): @R.function(private=True) - def explicit_sinfo(A: R.Tensor(["N"])) -> R.Tensor(["N // 8", 8]): + def explicit_ty(A: R.Tensor(["N"])) -> R.Tensor(["N // 8", 8]): N = T.int64() B: R.Tensor([N // 8, 8]) = R.memory.view(A, R.shape([N // 8, 8])) return B @R.function(private=True) - def inferred_sinfo(A: R.Tensor(["N"])): + def inferred_ty(A: R.Tensor(["N"])): N = T.int64() B = R.memory.view(A, R.shape([N // 8, 8])) return B - tvm.ir.assert_structural_equal(explicit_sinfo, inferred_sinfo) + tvm.ir.assert_structural_equal(explicit_ty, inferred_ty) def test_infer_shape_of_2d_dynamic_view(): @R.function(private=True) - def explicit_sinfo(A: R.Tensor(["N"])) -> R.Tensor(["N // 2"]): + def explicit_ty(A: R.Tensor(["N"])) -> R.Tensor(["N // 2"]): N = T.int64() B: R.Tensor([N // 2]) = R.memory.view(A, R.shape([N // 2])) return B @R.function(private=True) - def inferred_sinfo(A: R.Tensor(["N"])): + def inferred_ty(A: R.Tensor(["N"])): N = T.int64() B = R.memory.view(A, R.shape([N // 2])) return B - tvm.ir.assert_structural_equal(explicit_sinfo, inferred_sinfo) + tvm.ir.assert_structural_equal(explicit_ty, inferred_ty) def test_error_if_1d_dynamic_view_larger_than_1d_source(): @@ -232,32 +232,32 @@ def test_infer_dtype_of_float32_view(): """ @R.function(private=True) - def explicit_sinfo(A: R.Tensor) -> R.Tensor("float32"): + def explicit_ty(A: R.Tensor) -> R.Tensor("float32"): B: R.Tensor("float32") = R.memory.view(A, dtype=R.dtype("float32")) return B @R.function(private=True) - def inferred_sinfo(A: R.Tensor): + def inferred_ty(A: R.Tensor): B = R.memory.view(A, dtype=R.dtype("float32")) return B - tvm.ir.assert_structural_equal(explicit_sinfo, inferred_sinfo) + tvm.ir.assert_structural_equal(explicit_ty, inferred_ty) def test_view_without_explicit_dtype_keeps_input_dtype(): """If R.memory.view only specifies the shape, the dtype is unchanged""" @R.function(private=True) - def explicit_sinfo(A: R.Tensor([16], "float32")) -> R.Tensor([4, 4], "float32"): + def explicit_ty(A: R.Tensor([16], "float32")) -> R.Tensor([4, 4], "float32"): B: R.Tensor([4, 4], "float32") = R.memory.view(A, R.shape([4, 4])) return B @R.function(private=True) - def inferred_sinfo(A: R.Tensor([16], "float32")): + def inferred_ty(A: R.Tensor([16], "float32")): B = R.memory.view(A, R.shape([4, 4])) return B - tvm.ir.assert_structural_equal(explicit_sinfo, inferred_sinfo) + tvm.ir.assert_structural_equal(explicit_ty, inferred_ty) def test_infer_dtype_of_float32_view_from_relax_var(): @@ -270,18 +270,18 @@ def test_infer_dtype_of_float32_view_from_relax_var(): """ @R.function(private=True) - def explicit_sinfo(A: R.Tensor) -> R.Tensor("float32"): + def explicit_ty(A: R.Tensor) -> R.Tensor("float32"): dtype = R.dtype("float32") B: R.Tensor("float32") = R.memory.view(A, dtype=dtype) return B @R.function(private=True) - def inferred_sinfo(A: R.Tensor): + def inferred_ty(A: R.Tensor): dtype = R.dtype("float32") B = R.memory.view(A, dtype=dtype) return B - tvm.ir.assert_structural_equal(explicit_sinfo, inferred_sinfo) + tvm.ir.assert_structural_equal(explicit_ty, inferred_ty) def test_infer_dtype_of_view_with_unknown_dtype(): @@ -293,16 +293,16 @@ def test_infer_dtype_of_view_with_unknown_dtype(): """ @R.function(private=True) - def explicit_sinfo(A: R.Tensor("float32"), dtype: R.Object) -> R.Tensor: + def explicit_ty(A: R.Tensor("float32"), dtype: R.Object) -> R.Tensor: B: R.Tensor = R.memory.view(A, dtype=dtype) return B @R.function(private=True) - def inferred_sinfo(A: R.Tensor("float32"), dtype: R.Object): + def inferred_ty(A: R.Tensor("float32"), dtype: R.Object): B = R.memory.view(A, dtype=dtype) return B - tvm.ir.assert_structural_equal(explicit_sinfo, inferred_sinfo) + tvm.ir.assert_structural_equal(explicit_ty, inferred_ty) def test_view_dtype_may_be_smaller_than_input_dtype(): @@ -315,16 +315,16 @@ def test_view_dtype_may_be_smaller_than_input_dtype(): """ @R.function(private=True) - def explicit_sinfo(A: R.Tensor("uint32")) -> R.Tensor("float8"): + def explicit_ty(A: R.Tensor("uint32")) -> R.Tensor("float8"): B: R.Tensor("float8") = R.memory.view(A, dtype=R.dtype("float8")) return B @R.function(private=True) - def inferred_sinfo(A: R.Tensor("uint32")): + def inferred_ty(A: R.Tensor("uint32")): B = R.memory.view(A, dtype=R.dtype("float8")) return B - tvm.ir.assert_structural_equal(explicit_sinfo, inferred_sinfo) + tvm.ir.assert_structural_equal(explicit_ty, inferred_ty) def test_error_if_view_dtype_is_larger_than_input_dtype(): @@ -348,32 +348,32 @@ def test_increase_dtype_size_while_decreasing_number_of_elements(): """ @R.function(private=True) - def explicit_sinfo(A: R.Tensor([16], "uint8")) -> R.Tensor([8], "float16"): + def explicit_ty(A: R.Tensor([16], "uint8")) -> R.Tensor([8], "float16"): B: R.Tensor([8], "float16") = R.memory.view(A, shape=R.shape([8]), dtype=R.dtype("float16")) return B @R.function(private=True) - def inferred_sinfo(A: R.Tensor([16], "uint8")): + def inferred_ty(A: R.Tensor([16], "uint8")): B = R.memory.view(A, shape=R.shape([8]), dtype=R.dtype("float16")) return B - tvm.ir.assert_structural_equal(explicit_sinfo, inferred_sinfo) + tvm.ir.assert_structural_equal(explicit_ty, inferred_ty) def test_decrease_dtype_size_while_increasing_number_of_elements(): """R.memory.view may update both dtype and shape simultaneously""" @R.function(private=True) - def explicit_sinfo(A: R.Tensor([8], "float16")) -> R.Tensor([16], "uint8"): + def explicit_ty(A: R.Tensor([8], "float16")) -> R.Tensor([16], "uint8"): B: R.Tensor([16], "uint8") = R.memory.view(A, shape=R.shape([16]), dtype=R.dtype("uint8")) return B @R.function(private=True) - def inferred_sinfo(A: R.Tensor([8], "float16")): + def inferred_ty(A: R.Tensor([8], "float16")): B = R.memory.view(A, shape=R.shape([16]), dtype=R.dtype("uint8")) return B - tvm.ir.assert_structural_equal(explicit_sinfo, inferred_sinfo) + tvm.ir.assert_structural_equal(explicit_ty, inferred_ty) def test_error_if_number_of_bytes_of_view_is_larger_than_original(): @@ -419,16 +419,16 @@ def test_applying_relative_byte_offset_of_zero_is_legal(): """ @R.function(private=True) - def explicit_sinfo(A: R.Tensor) -> R.Tensor: + def explicit_ty(A: R.Tensor) -> R.Tensor: B: R.Tensor = R.memory.view(A, relative_byte_offset=R.prim_value(0)) return B @R.function(private=True) - def inferred_sinfo(A: R.Tensor): + def inferred_ty(A: R.Tensor): B = R.memory.view(A, relative_byte_offset=R.prim_value(0)) return B - tvm.ir.assert_structural_equal(explicit_sinfo, inferred_sinfo) + tvm.ir.assert_structural_equal(explicit_ty, inferred_ty) def test_applying_unknown_relative_byte_offset_is_legal(): @@ -442,16 +442,16 @@ def test_applying_unknown_relative_byte_offset_is_legal(): """ @R.function(private=True) - def explicit_sinfo(A: R.Tensor, relative_byte_offset: R.Prim("int64")) -> R.Tensor: + def explicit_ty(A: R.Tensor, relative_byte_offset: R.Prim("int64")) -> R.Tensor: B: R.Tensor = R.memory.view(A, relative_byte_offset=relative_byte_offset) return B @R.function(private=True) - def inferred_sinfo(A: R.Tensor, relative_byte_offset: R.Prim("int64")): + def inferred_ty(A: R.Tensor, relative_byte_offset: R.Prim("int64")): B = R.memory.view(A, relative_byte_offset=relative_byte_offset) return B - tvm.ir.assert_structural_equal(explicit_sinfo, inferred_sinfo) + tvm.ir.assert_structural_equal(explicit_ty, inferred_ty) def test_legalize_is_no_op(): @@ -485,7 +485,7 @@ def main(A: R.Tensor([4096], "float32")): B = R.ExternFunc( "runtime.TVMTensorCreateView", R.Callable( - derive_func="tvm.relax.struct_info.infer_view_sinfo", + derive_func="tvm.relax.type.infer_view_ty", purity=True, ), )( @@ -517,7 +517,7 @@ def main(A: R.Tensor(dtype="float32")): B = R.ExternFunc( "runtime.TVMTensorCreateView", R.Callable( - derive_func="tvm.relax.struct_info.infer_view_sinfo", + derive_func="tvm.relax.type.infer_view_ty", purity=True, ), )( @@ -547,7 +547,7 @@ def main(A: R.Tensor([4096], "float32")): B = R.ExternFunc( "runtime.TVMTensorCreateView", R.Callable( - derive_func="tvm.relax.struct_info.infer_view_sinfo", + derive_func="tvm.relax.type.infer_view_ty", purity=True, ), )( @@ -577,7 +577,7 @@ def main(A: R.Tensor([4096], "float32")): B = R.ExternFunc( "runtime.TVMTensorCreateView", R.Callable( - derive_func="tvm.relax.struct_info.infer_view_sinfo", + derive_func="tvm.relax.type.infer_view_ty", purity=True, ), )( @@ -626,7 +626,7 @@ def main(A: R.Tensor([4096], "uint8")): B = R.ExternFunc( "runtime.TVMTensorCreateView", R.Callable( - derive_func="tvm.relax.struct_info.infer_view_sinfo", + derive_func="tvm.relax.type.infer_view_ty", purity=True, ), )( @@ -638,7 +638,7 @@ def main(A: R.Tensor([4096], "uint8")): C = R.ExternFunc( "runtime.TVMTensorCreateView", R.Callable( - derive_func="tvm.relax.struct_info.infer_view_sinfo", + derive_func="tvm.relax.type.infer_view_ty", purity=True, ), )( diff --git a/tests/python/relax/test_op_vision.py b/tests/python/relax/test_op_vision.py index 075f49e9ca3e..9010f7460684 100644 --- a/tests/python/relax/test_op_vision.py +++ b/tests/python/relax/test_op_vision.py @@ -31,9 +31,9 @@ from tvm.script import relax as R -def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): +def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_ty: relax.Type): ret = bb.normalize(call) - tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) + tvm.ir.assert_structural_equal(ret.ty, expected_ty) def _assert_relax_op_legalized(mod: tvm.IRModule, op_name: str) -> None: @@ -59,7 +59,7 @@ def test_roi_align_op_correctness(): assert relax.op.vision.roi_align(x, rois, (7, 7), 1.0).op == Op.get("relax.vision.roi_align") -def test_roi_align_infer_struct_info(): +def test_roi_align_infer_ty(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32")) x1 = relax.Var("x", R.Tensor((2, 32, 32, 3), "float32")) @@ -68,16 +68,16 @@ def test_roi_align_infer_struct_info(): _check_inference( bb, relax.op.vision.roi_align(x0, rois, (7, 7), 0.25), - relax.TensorStructInfo((5, 3, 7, 7), "float32"), + relax.TensorType((5, 3, 7, 7), "float32"), ) _check_inference( bb, relax.op.vision.roi_align(x1, rois, (5, 7), 1.0, layout="NHWC"), - relax.TensorStructInfo((5, 5, 7, 3), "float32"), + relax.TensorType((5, 5, 7, 3), "float32"), ) -def test_roi_align_infer_struct_info_aligned(): +def test_roi_align_infer_ty_aligned(): bb = relax.BlockBuilder() x = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32")) rois = relax.Var("rois", R.Tensor((5, 5), "float32")) @@ -85,11 +85,11 @@ def test_roi_align_infer_struct_info_aligned(): _check_inference( bb, relax.op.vision.roi_align(x, rois, (7, 7), 1.0, aligned=True), - relax.TensorStructInfo((5, 3, 7, 7), "float32"), + relax.TensorType((5, 3, 7, 7), "float32"), ) -def test_roi_align_infer_struct_info_shape_var(): +def test_roi_align_infer_ty_shape_var(): bb = relax.BlockBuilder() n = tirx.Var("n", "int64") c = tirx.Var("c", "int64") @@ -103,7 +103,7 @@ def test_roi_align_infer_struct_info_shape_var(): _check_inference( bb, relax.op.vision.roi_align(x, rois, (7, 7), 0.5), - relax.TensorStructInfo((num_roi, c, 7, 7), "float32"), + relax.TensorType((num_roi, c, 7, 7), "float32"), ) @@ -160,8 +160,8 @@ def main( mod = LegalizeOps()(ROIAlign) assert "call_tir" in str(mod) tvm.ir.assert_structural_equal( - mod["main"].ret_struct_info, - relax.TensorStructInfo((2, 2, 3, 3), "float32"), + mod["main"].ret_ty, + relax.TensorType((2, 2, 3, 3), "float32"), ) @@ -188,8 +188,8 @@ def main( mod = LegalizeOps()(ROIAlign) assert "call_tir" in str(mod) tvm.ir.assert_structural_equal( - mod["main"].ret_struct_info, - relax.TensorStructInfo((1, 1, 1, 1), "float32"), + mod["main"].ret_ty, + relax.TensorType((1, 1, 1, 1), "float32"), ) @@ -215,8 +215,8 @@ def main( mod = LegalizeOps()(ROIAlign) assert "call_tir" in str(mod) tvm.ir.assert_structural_equal( - mod["main"].ret_struct_info, - relax.TensorStructInfo((1, 2, 2, 2), "float32"), + mod["main"].ret_ty, + relax.TensorType((1, 2, 2, 2), "float32"), ) @@ -225,23 +225,23 @@ def test_get_valid_counts_op_correctness(): assert relax.op.vision.get_valid_counts(data, 0.5).op == Op.get("relax.vision.get_valid_counts") -def test_get_valid_counts_infer_struct_info(): +def test_get_valid_counts_infer_ty(): bb = relax.BlockBuilder() data = relax.Var("data", R.Tensor((2, 10, 6), "float32")) _check_inference( bb, relax.op.vision.get_valid_counts(data, score_threshold=0.5, id_index=0, score_index=1), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo((2,), "int32"), - relax.TensorStructInfo((2, 10, 6), "float32"), - relax.TensorStructInfo((2, 10), "int32"), + relax.TensorType((2,), "int32"), + relax.TensorType((2, 10, 6), "float32"), + relax.TensorType((2, 10), "int32"), ] ), ) -def test_get_valid_counts_infer_struct_info_shape_var(): +def test_get_valid_counts_infer_ty_shape_var(): bb = relax.BlockBuilder() n = tirx.Var("n", "int64") m = tirx.Var("m", "int64") @@ -250,11 +250,11 @@ def test_get_valid_counts_infer_struct_info_shape_var(): _check_inference( bb, relax.op.vision.get_valid_counts(data, score_threshold=0.0), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo((n,), "int32"), - relax.TensorStructInfo((n, m, k), "float32"), - relax.TensorStructInfo((n, m), "int32"), + relax.TensorType((n,), "int32"), + relax.TensorType((n, m, k), "float32"), + relax.TensorType((n, m), "int32"), ] ), ) @@ -287,7 +287,7 @@ def test_nms_op_correctness(): ) -def test_nms_infer_struct_info_return_indices(): +def test_nms_infer_ty_return_indices(): bb = relax.BlockBuilder() data = relax.Var("data", R.Tensor((2, 10, 6), "float32")) valid_count = relax.Var("valid_count", R.Tensor((2,), "int32")) @@ -295,16 +295,16 @@ def test_nms_infer_struct_info_return_indices(): _check_inference( bb, relax.op.vision.non_max_suppression(data, valid_count, indices, return_indices=True), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo((2, 10), "int32"), - relax.TensorStructInfo((2, 1), "int32"), + relax.TensorType((2, 10), "int32"), + relax.TensorType((2, 1), "int32"), ] ), ) -def test_nms_infer_struct_info_return_indices_soft_nms(): +def test_nms_infer_ty_return_indices_soft_nms(): bb = relax.BlockBuilder() data = relax.Var("data", R.Tensor((2, 10, 6), "float32")) valid_count = relax.Var("valid_count", R.Tensor((2,), "int32")) @@ -314,17 +314,17 @@ def test_nms_infer_struct_info_return_indices_soft_nms(): relax.op.vision.non_max_suppression( data, valid_count, indices, return_indices=True, soft_nms_sigma=0.5 ), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo((2, 10, 6), "float32"), - relax.TensorStructInfo((2, 10), "int32"), - relax.TensorStructInfo((2, 1), "int32"), + relax.TensorType((2, 10, 6), "float32"), + relax.TensorType((2, 10), "int32"), + relax.TensorType((2, 1), "int32"), ] ), ) -def test_nms_infer_struct_info_return_data(): +def test_nms_infer_ty_return_data(): bb = relax.BlockBuilder() data = relax.Var("data", R.Tensor((2, 10, 6), "float32")) valid_count = relax.Var("valid_count", R.Tensor((2,), "int32")) @@ -332,11 +332,11 @@ def test_nms_infer_struct_info_return_data(): _check_inference( bb, relax.op.vision.non_max_suppression(data, valid_count, indices, return_indices=False), - relax.TensorStructInfo((2, 10, 6), "float32"), + relax.TensorType((2, 10, 6), "float32"), ) -def test_nms_infer_struct_info_return_data_shape_var(): +def test_nms_infer_ty_return_data_shape_var(): bb = relax.BlockBuilder() batch_size = tirx.Var("batch_size", "int64") num_anchors = tirx.Var("num_anchors", "int64") @@ -347,7 +347,7 @@ def test_nms_infer_struct_info_return_data_shape_var(): _check_inference( bb, relax.op.vision.non_max_suppression(data, valid_count, indices, return_indices=False), - relax.TensorStructInfo((batch_size, num_anchors, elem_length), "float32"), + relax.TensorType((batch_size, num_anchors, elem_length), "float32"), ) @@ -440,12 +440,12 @@ def main( mod = LegalizeOps()(GVC) _assert_relax_op_legalized(mod, "relax.vision.get_valid_counts") tvm.ir.assert_structural_equal( - mod["main"].ret_struct_info, - relax.TupleStructInfo( + mod["main"].ret_ty, + relax.TupleType( [ - relax.TensorStructInfo((1,), "int32"), - relax.TensorStructInfo((1, 5, 6), "float32"), - relax.TensorStructInfo((1, 5), "int32"), + relax.TensorType((1,), "int32"), + relax.TensorType((1, 5, 6), "float32"), + relax.TensorType((1, 5), "int32"), ] ), ) @@ -481,11 +481,11 @@ def main( mod = LegalizeOps()(NMS) _assert_relax_op_legalized(mod, "relax.vision.non_max_suppression") tvm.ir.assert_structural_equal( - mod["main"].ret_struct_info, - relax.TupleStructInfo( + mod["main"].ret_ty, + relax.TupleType( [ - relax.TensorStructInfo((1, 5), "int32"), - relax.TensorStructInfo((1, 1), "int32"), + relax.TensorType((1, 5), "int32"), + relax.TensorType((1, 1), "int32"), ] ), ) @@ -525,12 +525,12 @@ def main( mod = LegalizeOps()(NMS) _assert_relax_op_legalized(mod, "relax.vision.non_max_suppression") tvm.ir.assert_structural_equal( - mod["main"].ret_struct_info, - relax.TupleStructInfo( + mod["main"].ret_ty, + relax.TupleType( [ - relax.TensorStructInfo((1, 5, 6), "float32"), - relax.TensorStructInfo((1, 5), "int32"), - relax.TensorStructInfo((1, 1), "int32"), + relax.TensorType((1, 5, 6), "float32"), + relax.TensorType((1, 5), "int32"), + relax.TensorType((1, 1), "int32"), ] ), ) @@ -566,8 +566,8 @@ def main( mod = LegalizeOps()(NMS) _assert_relax_op_legalized(mod, "relax.vision.non_max_suppression") tvm.ir.assert_structural_equal( - mod["main"].ret_struct_info, - relax.TensorStructInfo((1, 5, 6), "float32"), + mod["main"].ret_ty, + relax.TensorType((1, 5, 6), "float32"), ) @@ -650,9 +650,9 @@ def _run_nms_e2e( data_shape = tuple(int(dim) for dim in data_np.shape) valid_count_shape = tuple(int(dim) for dim in valid_count_np.shape) indices_shape = tuple(int(dim) for dim in indices_np.shape) - data = relax.Var("data", relax.TensorStructInfo(data_shape, "float32")) - valid_count = relax.Var("valid_count", relax.TensorStructInfo(valid_count_shape, "int32")) - indices = relax.Var("indices", relax.TensorStructInfo(indices_shape, "int32")) + data = relax.Var("data", relax.TensorType(data_shape, "float32")) + valid_count = relax.Var("valid_count", relax.TensorType(valid_count_shape, "int32")) + indices = relax.Var("indices", relax.TensorType(indices_shape, "int32")) bb = relax.BlockBuilder() with bb.function("main", (data, valid_count, indices)): @@ -1176,7 +1176,7 @@ def test_roi_pool_op_correctness(): assert relax.op.vision.roi_pool(x, rois, (7, 7), 1.0).op == Op.get("relax.vision.roi_pool") -def test_roi_pool_infer_struct_info(): +def test_roi_pool_infer_ty(): bb = relax.BlockBuilder() x = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32")) rois = relax.Var("rois", R.Tensor((5, 5), "float32")) @@ -1184,11 +1184,11 @@ def test_roi_pool_infer_struct_info(): _check_inference( bb, relax.op.vision.roi_pool(x, rois, (7, 5), 0.25), - relax.TensorStructInfo((5, 3, 7, 5), "float32"), + relax.TensorType((5, 3, 7, 5), "float32"), ) -def test_roi_pool_infer_struct_info_shape_var(): +def test_roi_pool_infer_ty_shape_var(): bb = relax.BlockBuilder() n = tirx.Var("n", "int64") c = tirx.Var("c", "int64") @@ -1202,7 +1202,7 @@ def test_roi_pool_infer_struct_info_shape_var(): _check_inference( bb, relax.op.vision.roi_pool(x, rois, (7, 7), 0.5), - relax.TensorStructInfo((num_roi, c, 7, 7), "float32"), + relax.TensorType((num_roi, c, 7, 7), "float32"), ) @@ -1257,12 +1257,12 @@ def main( mod = LegalizeOps()(ROIPool) assert "call_tir" in str(mod) tvm.ir.assert_structural_equal( - mod["main"].ret_struct_info, - relax.TensorStructInfo((2, 2, 3, 2), "float32"), + mod["main"].ret_ty, + relax.TensorType((2, 2, 3, 2), "float32"), ) -def test_all_class_non_max_suppression_infer_struct_info(): +def test_all_class_non_max_suppression_infer_ty(): bb = relax.BlockBuilder() batch_size, num_classes, num_boxes = 10, 8, 5 boxes = relax.Var("boxes", R.Tensor((batch_size, num_boxes, 4), "float32")) @@ -1276,10 +1276,10 @@ def test_all_class_non_max_suppression_infer_struct_info(): relax.op.vision.all_class_non_max_suppression( boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, "onnx" ), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo((batch_size * num_classes * num_boxes, 3), "int64"), - relax.TensorStructInfo((1,), "int64"), + relax.TensorType((batch_size * num_classes * num_boxes, 3), "int64"), + relax.TensorType((1,), "int64"), ] ), ) @@ -1293,7 +1293,7 @@ def test_all_class_non_max_suppression_wrong_input_number(): relax.op.vision.all_class_non_max_suppression(boxes, scores) -def test_all_class_non_max_suppression_infer_struct_info_shape_var(): +def test_all_class_non_max_suppression_infer_ty_shape_var(): bb = relax.BlockBuilder() batch_size = tirx.Var("batch_size", "int64") num_classes = tirx.Var("num_classes", "int64") @@ -1309,10 +1309,10 @@ def test_all_class_non_max_suppression_infer_struct_info_shape_var(): relax.op.vision.all_class_non_max_suppression( boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, "onnx" ), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo((batch_size * num_classes * num_boxes, 3), "int64"), - relax.TensorStructInfo((1,), "int64"), + relax.TensorType((batch_size * num_classes * num_boxes, 3), "int64"), + relax.TensorType((1,), "int64"), ] ), ) @@ -1338,13 +1338,13 @@ def main( # Check legalized function has dynamic output (uses dynamic_strided_slice) assert "dynamic_strided_slice" in str(mod) - ret_sinfo = mod["main"].ret_struct_info + ret_ty = mod["main"].ret_ty tvm.ir.assert_structural_equal( - ret_sinfo, - relax.TupleStructInfo( + ret_ty, + relax.TupleType( [ - relax.TensorStructInfo(ndim=2, dtype="int64"), - relax.TensorStructInfo((1,), "int64"), + relax.TensorType(ndim=2, dtype="int64"), + relax.TensorType((1,), "int64"), ] ), ) @@ -1385,13 +1385,13 @@ def main( mod = LegalizeOps()(NMSModule) - # Check struct info + # Check type tvm.ir.assert_structural_equal( - mod["main"].ret_struct_info, - relax.TupleStructInfo( + mod["main"].ret_ty, + relax.TupleType( [ - relax.TensorStructInfo(ndim=2, dtype="int64"), - relax.TensorStructInfo((1,), "int64"), + relax.TensorType(ndim=2, dtype="int64"), + relax.TensorType((1,), "int64"), ] ), ) @@ -1418,7 +1418,7 @@ def test_multibox_transform_loc_op_correctness(): ).op == Op.get("relax.vision.multibox_transform_loc") -def test_multibox_transform_loc_infer_struct_info(): +def test_multibox_transform_loc_infer_ty(): bb = relax.BlockBuilder() cls = relax.Var("cls", R.Tensor((2, 3, 5), "float32")) loc = relax.Var("loc", R.Tensor((2, 20), "float32")) @@ -1428,10 +1428,10 @@ def test_multibox_transform_loc_infer_struct_info(): relax.op.vision.multibox_transform_loc( cls, loc, anc, False, 0.0, (0.1, 0.1, 0.2, 0.2), True ), - relax.TupleStructInfo( + relax.TupleType( [ - relax.TensorStructInfo((2, 5, 4), "float32"), - relax.TensorStructInfo((2, 3, 5), "float32"), + relax.TensorType((2, 5, 4), "float32"), + relax.TensorType((2, 3, 5), "float32"), ] ), ) diff --git a/tests/python/relax/test_optimize_layout_transform.py b/tests/python/relax/test_optimize_layout_transform.py index 2303afe89bb0..0ef21c2daab2 100644 --- a/tests/python/relax/test_optimize_layout_transform.py +++ b/tests/python/relax/test_optimize_layout_transform.py @@ -73,7 +73,7 @@ def main( lv2 = R.call_tir( Before.relax_add_replacement, (lv, lv1), - out_sinfo=R.Tensor((4, 4), dtype="float32"), + out_ty=R.Tensor((4, 4), dtype="float32"), ) lv0: R.Tensor((16,), dtype="float32") = R.layout_transform( lv2, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), pad_value=None @@ -87,7 +87,7 @@ def main( lv5 = R.call_tir( Before.relax_add_replacement, (lv4, lv3), - out_sinfo=R.Tensor((4, 4), dtype="float32"), + out_ty=R.Tensor((4, 4), dtype="float32"), ) lv2_1: R.Tensor((16,), dtype="float32") = R.layout_transform( lv5, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), pad_value=None @@ -127,12 +127,12 @@ def main( lv2 = R.call_tir( Expected.relax_add_replacement, (lv, lv1), - out_sinfo=R.Tensor((4, 4), dtype="float32"), + out_ty=R.Tensor((4, 4), dtype="float32"), ) lv5 = R.call_tir( Expected.relax_add_replacement, (lv1, lv2), - out_sinfo=R.Tensor((4, 4), dtype="float32"), + out_ty=R.Tensor((4, 4), dtype="float32"), ) gv: R.Tensor((16,), dtype="float32") = R.layout_transform( lv5, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), pad_value=None @@ -180,12 +180,12 @@ def main( lv3 = R.call_tir( Before.relax_add_replacement, (lv, lv1), - out_sinfo=R.Tensor((4, 4), dtype="float32"), + out_ty=R.Tensor((4, 4), dtype="float32"), ) lv4 = R.call_tir( Before.relax_add_replacement, (lv, lv2), - out_sinfo=R.Tensor((4, 4), dtype="float32"), + out_ty=R.Tensor((4, 4), dtype="float32"), ) lv5: R.Tensor((16,), dtype="float32") = R.layout_transform( lv3, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), pad_value=None @@ -202,7 +202,7 @@ def main( lv9 = R.call_tir( Before.relax_add_replacement, (lv7, lv8), - out_sinfo=R.Tensor((4, 4), dtype="float32"), + out_ty=R.Tensor((4, 4), dtype="float32"), ) lv10: R.Tensor((16,), dtype="float32") = R.layout_transform( lv9, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), pad_value=None @@ -247,17 +247,17 @@ def main( lv3 = R.call_tir( Expected.relax_add_replacement, (lv, lv1), - out_sinfo=R.Tensor((4, 4), dtype="float32"), + out_ty=R.Tensor((4, 4), dtype="float32"), ) lv4 = R.call_tir( Expected.relax_add_replacement, (lv, lv2), - out_sinfo=R.Tensor((4, 4), dtype="float32"), + out_ty=R.Tensor((4, 4), dtype="float32"), ) lv5 = R.call_tir( Expected.relax_add_replacement, (lv3, lv4), - out_sinfo=R.Tensor((4, 4), dtype="float32"), + out_ty=R.Tensor((4, 4), dtype="float32"), ) gv: R.Tensor((16,), dtype="float32") = R.layout_transform( lv5, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), pad_value=None @@ -311,7 +311,7 @@ def main(x: R.Tensor((14,), dtype="float32")) -> R.Tensor((14,), dtype="float32" lv1 = R.call_tir( Before.relax_relu_replacement, (lv,), - out_sinfo=R.Tensor((16,), dtype="float32"), + out_ty=R.Tensor((16,), dtype="float32"), ) lv2: R.Tensor((16,), dtype="float32") = R.layout_transform( lv1, @@ -320,7 +320,7 @@ def main(x: R.Tensor((14,), dtype="float32")) -> R.Tensor((14,), dtype="float32" axis_separators=[], ) lv_1 = R.call_tir( - Before.remove_pad, (lv2,), out_sinfo=R.Tensor((14,), dtype="float32") + Before.remove_pad, (lv2,), out_ty=R.Tensor((14,), dtype="float32") ) lv3: R.Tensor((16,), dtype="float32") = R.layout_transform( lv_1, @@ -331,7 +331,7 @@ def main(x: R.Tensor((14,), dtype="float32")) -> R.Tensor((14,), dtype="float32" lv4 = R.call_tir( Before.relax_relu_replacement, (lv3,), - out_sinfo=R.Tensor((16,), dtype="float32"), + out_ty=R.Tensor((16,), dtype="float32"), ) lv5: R.Tensor((16,), dtype="float32") = R.layout_transform( lv4, @@ -340,7 +340,7 @@ def main(x: R.Tensor((14,), dtype="float32")) -> R.Tensor((14,), dtype="float32" axis_separators=[], ) lv_2 = R.call_tir( - Before.remove_pad, (lv5,), out_sinfo=R.Tensor((14,), dtype="float32") + Before.remove_pad, (lv5,), out_ty=R.Tensor((14,), dtype="float32") ) gv: R.Tensor((14,), dtype="float32") = lv_2 R.output(gv) @@ -388,12 +388,12 @@ def main(x: R.Tensor((14,), dtype="float32")) -> R.Tensor((14,), dtype="float32" lv1 = R.call_tir( Expected.relax_relu_replacement, (lv,), - out_sinfo=R.Tensor((16,), dtype="float32"), + out_ty=R.Tensor((16,), dtype="float32"), ) lv4 = R.call_tir( Expected.relax_relu_replacement, (lv1,), - out_sinfo=R.Tensor((16,), dtype="float32"), + out_ty=R.Tensor((16,), dtype="float32"), ) lv5: R.Tensor((16,), dtype="float32") = R.layout_transform( lv4, @@ -402,7 +402,7 @@ def main(x: R.Tensor((14,), dtype="float32")) -> R.Tensor((14,), dtype="float32" axis_separators=[], ) gv = R.call_tir( - Expected.remove_pad, (lv5,), out_sinfo=R.Tensor((14,), dtype="float32") + Expected.remove_pad, (lv5,), out_ty=R.Tensor((14,), dtype="float32") ) R.output(gv) return gv diff --git a/tests/python/relax/test_pipeline.py b/tests/python/relax/test_pipeline.py index a85cbc43563a..4229a89291f4 100644 --- a/tests/python/relax/test_pipeline.py +++ b/tests/python/relax/test_pipeline.py @@ -67,7 +67,7 @@ def create_kv_cache(reserve_slots: R.Shape(["m"])): init_data, R.shape([m, 4]), 0, - sinfo_args=[R.Object()], + ty_args=[R.Object()], ) return kv_cache @@ -83,14 +83,14 @@ def main( curr_value = R.add(x, y) # update cache kv_cache = R.call_packed( - "vm.builtin.attention_kv_cache_append", kv_cache, curr_value, sinfo_args=[R.Object] + "vm.builtin.attention_kv_cache_append", kv_cache, curr_value, ty_args=[R.Object] ) # return the updated cache view kv = R.call_packed( "vm.builtin.attention_kv_cache_view", kv_cache, shape, - sinfo_args=[R.Tensor((L, 4), "float32")], + ty_args=[R.Tensor((L, 4), "float32")], ) return (kv, kv_cache) diff --git a/tests/python/relax/test_pytorch_integration.py b/tests/python/relax/test_pytorch_integration.py index f8255ed96306..8ea98306feaa 100644 --- a/tests/python/relax/test_pytorch_integration.py +++ b/tests/python/relax/test_pytorch_integration.py @@ -49,13 +49,13 @@ def main(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor: n = x.shape[0] # Call TIR function - lv = self.call_tir(self.matmul, [x, w], out_sinfo=R.Tensor((n, 20), "float32")) + lv = self.call_tir(self.matmul, [x, w], out_ty=R.Tensor((n, 20), "float32")) # Apply ReLU lv1 = F.relu(lv) # Call packed function (will be added dynamically) - lv2 = self.call_dps_packed("my_softmax", [lv1, 1], out_sinfo=R.Tensor((n, 20), "float32")) + lv2 = self.call_dps_packed("my_softmax", [lv1, 1], out_ty=R.Tensor((n, 20), "float32")) # Call Python function lv3 = self.my_identity_func(lv2) diff --git a/tests/python/relax/test_relax_operators.py b/tests/python/relax/test_relax_operators.py index 739243547688..ea9953d2d822 100644 --- a/tests/python/relax/test_relax_operators.py +++ b/tests/python/relax/test_relax_operators.py @@ -279,11 +279,11 @@ def symbolic_shape(shape: R.Shape(("m", "n"))) -> R.Tensor(ndim=-1): def test_op_shape_to_tensor(exec_mode): - # Check struct info - isinstance(ShapeToTensorTest["const_shape"].body.struct_info, tvm.relax.TensorStructInfo) - assert ShapeToTensorTest["const_shape"].body.struct_info.ndim == 1 - isinstance(ShapeToTensorTest["symbolic_shape"].body.struct_info, tvm.relax.TensorStructInfo) - assert ShapeToTensorTest["symbolic_shape"].body.struct_info.ndim == 1 + # Check type + isinstance(ShapeToTensorTest["const_shape"].body.ty, tvm.relax.TensorType) + assert ShapeToTensorTest["const_shape"].body.ty.ndim == 1 + isinstance(ShapeToTensorTest["symbolic_shape"].body.ty, tvm.relax.TensorType) + assert ShapeToTensorTest["symbolic_shape"].body.ty.ndim == 1 # Check its functionality out2d = run_cpu(ShapeToTensorTest, "const_shape", tvm_ffi.Shape([3, 2]), exec_mode=exec_mode) @@ -311,7 +311,7 @@ class CallPureTest: @R.function def pure_copy(x: R.Tensor((3, 4), "float32")): z = R.call_pure_packed( - "vm.builtin.copy", x, sinfo_args=(R.Tensor((3, 4), dtype="float32")) + "vm.builtin.copy", x, ty_args=(R.Tensor((3, 4), dtype="float32")) ) return z @@ -331,7 +331,7 @@ def pure_copy(x: R.Tensor((3, 4), "float32")): "vm.builtin.copy", x, inplace_indices=0, - sinfo_args=(R.Tensor((3, 4), dtype="float32")), + ty_args=(R.Tensor((3, 4), dtype="float32")), ) return z @@ -354,7 +354,7 @@ def inplace_add(x: R.Tensor((3, 4), "float32"), y: R.Tensor((3, 4), "float32")): x, y, inplace_indices=0, - sinfo_args=(R.Tensor((3, 4), dtype="float32")), + ty_args=(R.Tensor((3, 4), dtype="float32")), ) return z @@ -389,7 +389,7 @@ def inplace_tuple(x: R.Tensor((3, 4), "float32"), y: R.Tensor((3, 4), "float32") x, y, inplace_indices=[0, -1], - sinfo_args=(R.Tensor((3, 4), dtype="float32"), R.Tensor((3, 4), dtype="float32")), + ty_args=(R.Tensor((3, 4), dtype="float32"), R.Tensor((3, 4), dtype="float32")), ) return z @@ -454,13 +454,13 @@ def torch_sigmoid(x): class CallPyFuncTest: @R.function def simple_call(x: R.Tensor((3,), "float32")): - result = R.call_py_func(R.str("torch_relu"), (x,), out_sinfo=R.Tensor((3,), "float32")) + result = R.call_py_func(R.str("torch_relu"), (x,), out_ty=R.Tensor((3,), "float32")) return result @R.function def multiple_calls(x: R.Tensor((2,), "float32")): - y = R.call_py_func(R.str("torch_relu"), (x,), out_sinfo=R.Tensor((2,), "float32")) - z = R.call_py_func(R.str("torch_sigmoid"), (y,), out_sinfo=R.Tensor((2,), "float32")) + y = R.call_py_func(R.str("torch_relu"), (x,), out_ty=R.Tensor((2,), "float32")) + z = R.call_py_func(R.str("torch_sigmoid"), (y,), out_ty=R.Tensor((2,), "float32")) return z np.random.seed(0) @@ -492,7 +492,7 @@ def to_dev(x: R.Tensor((3, 4), "float32")): x, 1, 0, - sinfo_args=(R.Tensor((3, 4), dtype="float32")), + ty_args=(R.Tensor((3, 4), dtype="float32")), ) return z diff --git a/tests/python/relax/test_relax_to_pyfunc_converter.py b/tests/python/relax/test_relax_to_pyfunc_converter.py index 0f41ec93eb8b..cbefedabbe31 100644 --- a/tests/python/relax/test_relax_to_pyfunc_converter.py +++ b/tests/python/relax/test_relax_to_pyfunc_converter.py @@ -71,12 +71,12 @@ def with_call_tir(x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32")) -> (5,), "float32" ): cls = ComprehensiveTestModule - return R.call_tir(cls.add_tir, (x, y), out_sinfo=R.Tensor((5,), "float32")) + return R.call_tir(cls.add_tir, (x, y), out_ty=R.Tensor((5,), "float32")) @R.function def with_call_dps_packed(x: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"): return R.call_dps_packed( - "my_softmax", (x, R.prim_value(1)), out_sinfo=R.Tensor((5,), "float32") + "my_softmax", (x, R.prim_value(1)), out_ty=R.Tensor((5,), "float32") ) @R.function @@ -86,7 +86,7 @@ def complex_function(x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32")) added = R.add(x, y) relued = R.nn.relu(added) cls = ComprehensiveTestModule - tir_result = R.call_tir(cls.add_tir, (relued, y), out_sinfo=R.Tensor((5,), "float32")) + tir_result = R.call_tir(cls.add_tir, (relued, y), out_ty=R.Tensor((5,), "float32")) return R.nn.relu(tir_result) @R.function @@ -882,7 +882,7 @@ def test_func(x: R.Tensor((4,), "float32"), y: R.Tensor((4,), "float32")) -> R.T (4,), "float32" ): return R.call_tir( - DLPackTestModule.test_tir, (x, y), out_sinfo=R.Tensor((4,), "float32") + DLPackTestModule.test_tir, (x, y), out_ty=R.Tensor((4,), "float32") ) converter = RelaxToPyFuncConverter(DLPackTestModule) @@ -935,7 +935,7 @@ def test_func(x: R.Tensor((3,), "float32"), y: R.Tensor((3,), "float32")) -> R.T (3,), "float32" ): return R.call_tir( - RuntimeAPITestModule.test_tir, (x, y), out_sinfo=R.Tensor((3,), "float32") + RuntimeAPITestModule.test_tir, (x, y), out_ty=R.Tensor((3,), "float32") ) converter = RelaxToPyFuncConverter(RuntimeAPITestModule) @@ -963,7 +963,7 @@ class PackedFuncTestModule: @R.function def test_dps(x: R.Tensor((4,), "float32")) -> R.Tensor((4,), "float32"): return R.call_dps_packed( - "test_packed_func", (x, R.const(0)), out_sinfo=R.Tensor((4,), "float32") + "test_packed_func", (x, R.const(0)), out_ty=R.Tensor((4,), "float32") ) converter = RelaxToPyFuncConverter(PackedFuncTestModule) @@ -994,7 +994,7 @@ def test_mixed(x: R.Tensor((4,), "float32"), y: R.Tensor((4,), "float32")) -> R. ): # TIR operation tir_result = R.call_tir( - MixedOpsTestModule.add_tir, (x, y), out_sinfo=R.Tensor((4,), "float32") + MixedOpsTestModule.add_tir, (x, y), out_ty=R.Tensor((4,), "float32") ) # Relax operations relued = R.nn.relu(tir_result) diff --git a/tests/python/relax/test_tir_call_source_kernel.py b/tests/python/relax/test_tir_call_source_kernel.py index 9e406dec7258..13cd34531191 100644 --- a/tests/python/relax/test_tir_call_source_kernel.py +++ b/tests/python/relax/test_tir_call_source_kernel.py @@ -66,7 +66,7 @@ def add(x_handle: T.handle, y_handle: T.handle, output_handle: T.handle) -> None def main(x: R.Tensor(("m",), "float32"), y: R.Tensor(("m",), "float32")): m = T.int64() with R.dataflow(): - output = R.call_tir(Module.add, [x, y], relax.TensorStructInfo((m,), "float32")) + output = R.call_tir(Module.add, [x, y], relax.TensorType((m,), "float32")) R.output(output) return output diff --git a/tests/python/relax/test_training_append_loss.py b/tests/python/relax/test_training_append_loss.py index abf5e3ac8633..4ef2ce6e3025 100644 --- a/tests/python/relax/test_training_append_loss.py +++ b/tests/python/relax/test_training_append_loss.py @@ -160,7 +160,7 @@ def main_loss(x: R.Tensor((3, 3), "float32"), arg3: R.Tensor((3, 3), "float32")) def test_error_return_value_vs_parameter(): - # StructInfo not match + # Type not match # fmt: off @I.ir_module class Module1: diff --git a/tests/python/relax/test_training_loss.py b/tests/python/relax/test_training_loss.py index c651607436c1..7505cb796397 100644 --- a/tests/python/relax/test_training_loss.py +++ b/tests/python/relax/test_training_loss.py @@ -39,8 +39,8 @@ def forward( def test_l1_loss(): N = 3 C = 5 - predictions = relax.TensorStructInfo((N, C), "float32") - targets = relax.TensorStructInfo((N, C), "float32") + predictions = relax.TensorType((N, C), "float32") + targets = relax.TensorType((N, C), "float32") l1_loss = relax.training.loss.L1Loss() @R.function @@ -59,7 +59,7 @@ def expected( def test_l1_loss_append(): - s = Module["forward"].ret_struct_info + s = Module["forward"].ret_ty l1_loss = relax.training.loss.L1Loss(reduction="sum") After = relax.training.AppendLoss("forward", l1_loss(s, s), l1_loss.num_backbone_outputs)( Module @@ -88,8 +88,8 @@ def expected( def test_mse_loss(): N = 3 C = 5 - predictions = relax.TensorStructInfo((N, C), "float32") - targets = relax.TensorStructInfo((N, C), "float32") + predictions = relax.TensorType((N, C), "float32") + targets = relax.TensorType((N, C), "float32") mse_loss = relax.training.loss.MSELoss() @R.function @@ -108,7 +108,7 @@ def expected( def test_mse_loss_append(): - s = Module["forward"].ret_struct_info + s = Module["forward"].ret_ty mse_loss = relax.training.loss.MSELoss(reduction="sum") After = relax.training.AppendLoss("forward", mse_loss(s, s), mse_loss.num_backbone_outputs)( Module @@ -137,9 +137,9 @@ def expected( def test_cross_entropy_loss(): N = 3 C = 5 - predictions = relax.TensorStructInfo((N, C), "float32") - targets = relax.TensorStructInfo((N,), "int64") - weights = relax.TensorStructInfo((C,), "float32") + predictions = relax.TensorType((N, C), "float32") + targets = relax.TensorType((N,), "int64") + weights = relax.TensorType((C,), "float32") cross_entropy_loss = relax.training.loss.CrossEntropyLoss(reduction="sum", ignore_index=1) @R.function @@ -163,8 +163,8 @@ def expected( def test_cross_entropy_loss_without_weights(): N = 3 C = 5 - predictions = relax.TensorStructInfo((N, C), "float32") - targets = relax.TensorStructInfo((N,), "int64") + predictions = relax.TensorType((N, C), "float32") + targets = relax.TensorType((N,), "int64") cross_entropy_loss = relax.training.loss.CrossEntropyLoss() @R.function @@ -184,11 +184,11 @@ def expected( def test_cross_entropy_loss_append(): - s = Module["forward"].ret_struct_info + s = Module["forward"].ret_ty N = s.shape[0] C = s.shape[1] - targets = relax.TensorStructInfo((N,), "int64") - weights = relax.TensorStructInfo((C,), "float32") + targets = relax.TensorType((N,), "int64") + weights = relax.TensorType((C,), "float32") cross_entropy_loss = relax.training.loss.CrossEntropyLoss(reduction="sum", ignore_index=1) After = relax.training.AppendLoss( "forward", cross_entropy_loss(s, targets, weights), cross_entropy_loss.num_backbone_outputs @@ -219,9 +219,9 @@ def expected( def test_categorical_cross_entropy_loss(): N = 3 C = 5 - predictions = relax.TensorStructInfo((N, C), "float32") - targets = relax.TensorStructInfo((N, C), "int64") - weights = relax.TensorStructInfo((C,), "float32") + predictions = relax.TensorType((N, C), "float32") + targets = relax.TensorType((N, C), "int64") + weights = relax.TensorType((C,), "float32") categorical_cross_entropy_loss = relax.training.loss.CategoricalCrossEntropyLoss( reduction="sum" ) @@ -246,8 +246,8 @@ def expected( def test_categorical_cross_entropy_loss_without_weights(): N = 3 C = 5 - predictions = relax.TensorStructInfo((N, C), "float32") - targets = relax.TensorStructInfo((N, C), "int64") + predictions = relax.TensorType((N, C), "float32") + targets = relax.TensorType((N, C), "int64") categorical_cross_entropy_loss = relax.training.loss.CategoricalCrossEntropyLoss() @R.function @@ -267,9 +267,9 @@ def expected( def test_categorical_cross_entropy_loss_with_ignore_index(): N = 3 C = 5 - predictions = relax.TensorStructInfo((N, C), "float32") - targets = relax.TensorStructInfo((N, C), "int64") - weights = relax.TensorStructInfo((C,), "float32") + predictions = relax.TensorType((N, C), "float32") + targets = relax.TensorType((N, C), "int64") + weights = relax.TensorType((C,), "float32") categorical_cross_entropy_loss = relax.training.loss.CategoricalCrossEntropyLoss( reduction="sum", ignore_index=1 ) @@ -284,7 +284,7 @@ def expected( with R.dataflow(): lv: R.Tensor((3, 5), "float32") = R.nn.log_softmax(predictions, axis=-1) targets = relax.op.reshape( - relax.op.argmax(targets, axis=1), shape=(targets.struct_info.shape[0],) + relax.op.argmax(targets, axis=1), shape=(targets.ty.shape[0],) ) gv: R.Tensor((), "float32") = R.nn.nll_loss( lv, targets, weights, reduction="sum", ignore_index=1 diff --git a/tests/python/relax/test_training_setup_trainer.py b/tests/python/relax/test_training_setup_trainer.py index 441dd7011dcb..4b8fde27e539 100644 --- a/tests/python/relax/test_training_setup_trainer.py +++ b/tests/python/relax/test_training_setup_trainer.py @@ -93,8 +93,8 @@ def optimizer(params: R.Tuple(R.Tensor((2, 2), dtype="float64")), gradients: R.T return (params_new, optim_states_new) # fmt: on - sinfo = relax.TensorStructInfo((2, 2), "float64") - setup_trainer = SetupTrainer(MSELoss(reduction="sum"), SGD(0.1), [sinfo, sinfo], legalize=False) + ty = relax.TensorType((2, 2), "float64") + setup_trainer = SetupTrainer(MSELoss(reduction="sum"), SGD(0.1), [ty, ty], legalize=False) train_mod = setup_trainer(Backbone) assert_structural_equal(train_mod.without_attr("optim_state"), Expected) @@ -172,9 +172,9 @@ def optimizer(params: R.Tuple(R.Tensor((2, 2), dtype="float64")), gradients: R.T # fmt: on - sinfo = relax.TensorStructInfo((2, 2), "float64") + ty = relax.TensorType((2, 2), "float64") setup_trainer = SetupTrainer( - MSELoss(reduction="sum"), MomentumSGD(0.1, 0.1), [sinfo, sinfo], legalize=False + MSELoss(reduction="sum"), MomentumSGD(0.1, 0.1), [ty, ty], legalize=False ) train_mod = setup_trainer(Backbone) assert_structural_equal(train_mod.without_attr("optim_state"), Expected) @@ -196,18 +196,18 @@ def backbone( R.output(gv, out) return gv, out - pred_sinfo = relax.TensorStructInfo((1, 5), "float32") + pred_ty = relax.TensorType((1, 5), "float32") setup_trainer = SetupTrainer( MSELoss(reduction="sum"), SGD(0.001), - [pred_sinfo, pred_sinfo], + [pred_ty, pred_ty], ) with pytest.raises((RuntimeError, ValueError)): SetupTrainer( MSELoss(reduction="sum"), SGD(0.001), - [pred_sinfo, pred_sinfo], + [pred_ty, pred_ty], )(NoAttr) @I.ir_module diff --git a/tests/python/relax/test_training_trainer_numeric.py b/tests/python/relax/test_training_trainer_numeric.py index fb3f06e67bfa..1164e704f09c 100644 --- a/tests/python/relax/test_training_trainer_numeric.py +++ b/tests/python/relax/test_training_trainer_numeric.py @@ -58,12 +58,12 @@ def test_execute(): target = "llvm" dev = tvm.device(target) backbone = _get_backbone() - pred_sinfo = relax.TensorStructInfo((1, 5), "float32") + pred_ty = relax.TensorType((1, 5), "float32") setup_trainer = SetupTrainer( MSELoss(reduction="sum"), Adam(0.01), - [pred_sinfo, pred_sinfo], + [pred_ty, pred_ty], ) train_mod = setup_trainer(backbone) @@ -84,12 +84,12 @@ def test_execute_numeric(): target = "llvm" dev = tvm.device(target) backbone = _get_backbone() - pred_sinfo = relax.TensorStructInfo((1, 5), "float32") + pred_ty = relax.TensorType((1, 5), "float32") setup_trainer = SetupTrainer( MSELoss(reduction="sum"), SGD(0.01), - [pred_sinfo, pred_sinfo], + [pred_ty, pred_ty], ) train_mod = setup_trainer(backbone) @@ -115,12 +115,12 @@ def test_load_export_params(): target = "llvm" dev = tvm.device(target) backbone = _get_backbone() - pred_sinfo = relax.TensorStructInfo((1, 5), "float32") + pred_ty = relax.TensorType((1, 5), "float32") setup_trainer = SetupTrainer( MSELoss(reduction="sum"), SGD(0.01), - [pred_sinfo, pred_sinfo], + [pred_ty, pred_ty], ) train_mod = setup_trainer(backbone) @@ -152,12 +152,12 @@ def test_setting_error(): target = "llvm" dev = tvm.device(target) backbone = _get_backbone() - pred_sinfo = relax.TensorStructInfo((1, 5), "float32") + pred_ty = relax.TensorType((1, 5), "float32") setup_trainer = SetupTrainer( MSELoss(reduction="sum"), SGD(0.01), - [pred_sinfo, pred_sinfo], + [pred_ty, pred_ty], ) train_mod = setup_trainer(backbone) diff --git a/tests/python/relax/test_transform.py b/tests/python/relax/test_transform.py index a3358c770eb4..ef52c895d7b8 100644 --- a/tests/python/relax/test_transform.py +++ b/tests/python/relax/test_transform.py @@ -124,7 +124,7 @@ def foo(x: R.Tensor(("m", "n"), "float32")): assert isinstance(s1, relax.Call) assert s1.op.name == "relax.builtin.alloc_tensor" assert isinstance(s1.args[0], relax.ShapeExpr) - tvm.ir.assert_structural_equal(s1.args[0], s0.sinfo_args[0].shape) + tvm.ir.assert_structural_equal(s1.args[0], s0.ty_args[0].shape) s2 = block.bindings[1].value tvm.ir.expr.GlobalVar assert s2.op.name_hint == "exp" @@ -142,13 +142,13 @@ def base(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): @R.function def use_call_pure_packed(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): y = R.add(x, x) - z = R.call_pure_packed("vm.builtin.copy", y, sinfo_args=(R.Tensor((), dtype="int32"))) + z = R.call_pure_packed("vm.builtin.copy", y, ty_args=(R.Tensor((), dtype="int32"))) return z @R.function def use_invoke_pure_closure(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): closure = R.make_closure(Before.base, ()) - res = R.invoke_pure_closure(closure, (x,), sinfo_args=R.Tensor((), "int32")) + res = R.invoke_pure_closure(closure, (x,), ty_args=R.Tensor((), "int32")) return res @R.function(pure=False) @@ -161,9 +161,7 @@ def nested_pure_func() -> R.Tensor((), "int32"): @R.function def nested(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): y = R.add(x, x) - q = R.call_pure_packed( - "vm.builtin.copy", y, sinfo_args=(R.Tensor((), dtype="int32")) - ) + q = R.call_pure_packed("vm.builtin.copy", y, ty_args=(R.Tensor((), dtype="int32"))) return q z = R.const(1, dtype="int32") @@ -194,14 +192,14 @@ def base(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): def use_call_pure_packed(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): R.func_attr({"relax.force_pure": True}) y = R.add(x, x) - z = R.call_packed("vm.builtin.copy", y, sinfo_args=(R.Tensor((), dtype="int32"))) + z = R.call_packed("vm.builtin.copy", y, ty_args=(R.Tensor((), dtype="int32"))) return z @R.function def use_invoke_pure_closure(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): R.func_attr({"relax.force_pure": True}) closure = R.make_closure(Expected.base, ()) - res = R.invoke_closure(closure, (x,), sinfo_args=R.Tensor((), "int32")) + res = R.invoke_closure(closure, (x,), ty_args=R.Tensor((), "int32")) return res @R.function(pure=False) @@ -217,7 +215,7 @@ def nested_pure_func() -> R.Tensor((), "int32"): def nested(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): R.func_attr({"relax.force_pure": True}) y = R.add(x, x) - q = R.call_packed("vm.builtin.copy", y, sinfo_args=(R.Tensor((), dtype="int32"))) + q = R.call_packed("vm.builtin.copy", y, ty_args=(R.Tensor((), dtype="int32"))) return q z = R.const(1, dtype="int32") @@ -269,7 +267,7 @@ def foo(x: R.Tensor(("m", "n"), "float32")): assert isinstance(s1, relax.Call) assert s1.op.name == "relax.builtin.alloc_tensor" assert isinstance(s1.args[0], relax.ShapeExpr) - tvm.ir.assert_structural_equal(s1.args[0], s0.sinfo_args[0].shape) + tvm.ir.assert_structural_equal(s1.args[0], s0.ty_args[0].shape) s2 = block.bindings[1].value assert s2.op.global_symbol == "test.op.identity" @@ -444,12 +442,12 @@ def foo( R.Tensor((2, 3), "int32"), R.Tensor((2, 3), "int32"), R.Tensor((2, 3), dtype="int32") ): R.func_attr({"relax.force_pure": True}) - gv0: R.Tensor((2, 3), dtype="int32") = R.emit_with_sinfo( + gv0: R.Tensor((2, 3), dtype="int32") = R.emit_with_ty( "relax.builtin.alloc_tensor", (R.shape([2, 3]), R.dtype("int32"), R.prim_value(0), R.str("global")), (R.Tensor((2, 3), dtype="int32"),), ) - gv1: R.Tensor((2, 3), dtype="int32") = R.emit_with_sinfo( + gv1: R.Tensor((2, 3), dtype="int32") = R.emit_with_ty( "relax.builtin.alloc_tensor", (R.shape([2, 3]), R.dtype("int32"), R.prim_value(0), R.str("global")), (R.Tensor((2, 3), dtype="int32"),), @@ -532,7 +530,7 @@ def main(A: R.Tensor((16,), dtype="float32")) -> R.Tensor((16,), dtype="float32" gv1 = R.call_tir_inplace( cls.multiply_by_two, [[A]], - out_sinfo=R.Tensor((16,), dtype="float32"), + out_ty=R.Tensor((16,), dtype="float32"), inplace_indices=[0], ) return gv1 @@ -563,7 +561,7 @@ def main(A: R.Object): gv1 = R.call_tir_inplace( Module.multiply_by_two, [A], - out_sinfo=R.Tensor((16,), dtype="float32"), + out_ty=R.Tensor((16,), dtype="float32"), inplace_indices=[0], ) return gv1 @@ -592,7 +590,7 @@ def main(A: R.Tensor([32], dtype="float32")): gv1 = R.call_tir_inplace( Module.multiply_by_two, [A], - out_sinfo=R.Tensor((16,), dtype="float32"), + out_ty=R.Tensor((16,), dtype="float32"), inplace_indices=[0], ) return gv1 @@ -621,7 +619,7 @@ def main(A: R.Tensor([16], dtype="int32")): gv1 = R.call_tir_inplace( Module.multiply_by_two, [A], - out_sinfo=R.Tensor((16,), dtype="float32"), + out_ty=R.Tensor((16,), dtype="float32"), inplace_indices=[0], ) return gv1 diff --git a/tests/python/relax/test_transform_adjust_matmul_order.py b/tests/python/relax/test_transform_adjust_matmul_order.py index 9600c97bdaac..1c00ef362dc3 100644 --- a/tests/python/relax/test_transform_adjust_matmul_order.py +++ b/tests/python/relax/test_transform_adjust_matmul_order.py @@ -811,11 +811,11 @@ class TestAdjustMatmulOrderAttentionBlock: def _build_attention_module(self, batch, seq, dim): """Minimal batched attention block exercising ND permute_dims + matmul.""" bb = relax.BlockBuilder() - x = relax.Var("x", relax.TensorStructInfo((batch, seq, dim), "float32")) - wq = relax.Var("wq", relax.TensorStructInfo((dim, dim), "float32")) - wk = relax.Var("wk", relax.TensorStructInfo((dim, dim), "float32")) - wv = relax.Var("wv", relax.TensorStructInfo((dim, dim), "float32")) - wo = relax.Var("wo", relax.TensorStructInfo((dim, dim), "float32")) + x = relax.Var("x", relax.TensorType((batch, seq, dim), "float32")) + wq = relax.Var("wq", relax.TensorType((dim, dim), "float32")) + wk = relax.Var("wk", relax.TensorType((dim, dim), "float32")) + wv = relax.Var("wv", relax.TensorType((dim, dim), "float32")) + wo = relax.Var("wo", relax.TensorType((dim, dim), "float32")) with bb.function("main", [x, wq, wk, wv, wo]): with bb.dataflow(): q = bb.emit(relax.op.matmul(x, wq)) diff --git a/tests/python/relax/test_transform_alter_op_impl.py b/tests/python/relax/test_transform_alter_op_impl.py index 3e5d4889d3a1..cd433d0a30c3 100644 --- a/tests/python/relax/test_transform_alter_op_impl.py +++ b/tests/python/relax/test_transform_alter_op_impl.py @@ -63,7 +63,7 @@ def add(arg0: T.Buffer((16,), "float32"), arg1: T.Buffer((16,), "float32"), outp @R.function def main(x: R.Tensor((16,), dtype="float32"), y: R.Tensor((16,), dtype="float32")) -> R.Tensor((16,), dtype="float32"): with R.dataflow(): - lv = R.call_tir(Before.add, (x, y), out_sinfo=R.Tensor((16,), dtype="float32")) + lv = R.call_tir(Before.add, (x, y), out_ty=R.Tensor((16,), dtype="float32")) gv: R.Tensor((16,), dtype="float32") = lv R.output(gv) return gv @@ -84,7 +84,7 @@ def main(x: R.Tensor((16,), dtype="float32"), y: R.Tensor((16,), dtype="float32" with R.dataflow(): lv: R.Tensor((4, 4), dtype="float32") = R.layout_transform(x, index_map=lambda i: (i // 4, i % 4), pad_value=None) lv1: R.Tensor((4, 4), dtype="float32") = R.layout_transform(y, index_map=lambda i: (i // 4, i % 4), pad_value=None) - lv2 = R.call_tir(Expected.relax_add_replacement, (lv, lv1), out_sinfo=R.Tensor((4, 4), dtype="float32")) + lv2 = R.call_tir(Expected.relax_add_replacement, (lv, lv1), out_ty=R.Tensor((4, 4), dtype="float32")) lv_1: R.Tensor((16,), dtype="float32") = R.layout_transform(lv2, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), pad_value=None) gv: R.Tensor((16,), dtype="float32") = lv_1 R.output(gv) @@ -126,7 +126,7 @@ def mul_by_2(arg0: T.Buffer((16,), "float32"), output: T.Buffer((16,), "float32" @R.function def main(x: R.Tensor((16,), dtype="float32")) -> R.Tensor((16,), dtype="float32"): with R.dataflow(): - lv = R.call_tir(Before.mul_by_2, (x,), out_sinfo=R.Tensor((16,), dtype="float32")) + lv = R.call_tir(Before.mul_by_2, (x,), out_ty=R.Tensor((16,), dtype="float32")) gv: R.Tensor((16,), dtype="float32") = lv R.output(gv) return gv @@ -145,7 +145,7 @@ def relax_mul_by_2_replacement(arg0: T.Buffer((16,), "float32"), output: T.Buffe @R.function def main(x: R.Tensor((16,), dtype="float32")) -> R.Tensor((16,), dtype="float32"): with R.dataflow(): - lv = R.call_tir(Expected.relax_mul_by_2_replacement, (x,), out_sinfo=R.Tensor((16,), dtype="float32")) + lv = R.call_tir(Expected.relax_mul_by_2_replacement, (x,), out_ty=R.Tensor((16,), dtype="float32")) gv: R.Tensor((16,), dtype="float32") = lv R.output(gv) return gv @@ -187,7 +187,7 @@ def some_op(arg0: T.Buffer((16,), "float32"), arg1: T.Buffer((16,), "float32"), @R.function def main(x: R.Tensor((16,), dtype="float32"), y: R.Tensor((16,), dtype="float32")) -> R.Tuple(R.Tensor((16,), dtype="float32"), R.Tensor((16,), dtype="float32")): with R.dataflow(): - gv = R.call_tir(Before.some_op, (x, y), out_sinfo=[R.Tensor((16,), dtype="float32"), R.Tensor((16,), dtype="float32")]) + gv = R.call_tir(Before.some_op, (x, y), out_ty=[R.Tensor((16,), dtype="float32"), R.Tensor((16,), dtype="float32")]) R.output(gv) return gv @@ -209,7 +209,7 @@ def main(x: R.Tensor((16,), dtype="float32"), y: R.Tensor((16,), dtype="float32" with R.dataflow(): lv: R.Tensor((4, 4), dtype="float32") = R.layout_transform(x, index_map=lambda i: (i // 4, i % 4), pad_value=None) lv1: R.Tensor((4, 4), dtype="float32") = R.layout_transform(y, index_map=lambda i: (i // 4, i % 4), pad_value=None) - lv2 = R.call_tir(Expected.relax_some_op_replacement, (lv, lv1), out_sinfo=[R.Tensor((4, 4), dtype="float32"), R.Tensor((4, 4), dtype="float32")]) + lv2 = R.call_tir(Expected.relax_some_op_replacement, (lv, lv1), out_ty=[R.Tensor((4, 4), dtype="float32"), R.Tensor((4, 4), dtype="float32")]) lv3: R.Tensor((4, 4), dtype="float32") = lv2[0] lv4: R.Tensor((16,), dtype="float32") = R.layout_transform(lv3, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), pad_value=None) lv5: R.Tensor((4, 4), dtype="float32") = lv2[1] @@ -257,7 +257,7 @@ def some_op(arg0: T.Buffer((16,), "float32"), arg1: T.Buffer((16,), "float32"), @R.function def main(x: R.Tensor((16,), dtype="float32"), y: R.Tensor((16,), dtype="float32")) -> R.Tuple(R.Tensor((16,), dtype="float32"), R.Tensor((16,), dtype="float32")): with R.dataflow(): - gv = R.call_tir(Before.some_op, (x, y), out_sinfo=[R.Tensor((16,), dtype="float32"), R.Tensor((16,), dtype="float32")]) + gv = R.call_tir(Before.some_op, (x, y), out_ty=[R.Tensor((16,), dtype="float32"), R.Tensor((16,), dtype="float32")]) R.output(gv) return gv @@ -279,7 +279,7 @@ def main(x: R.Tensor((16,), dtype="float32"), y: R.Tensor((16,), dtype="float32" with R.dataflow(): lv: R.Tensor((4, 4), dtype="float32") = R.layout_transform(x, index_map=lambda i: (i // 4, i % 4), pad_value=None, axis_separators=[1]) lv1: R.Tensor((4, 4), dtype="float32") = R.layout_transform(y, index_map=lambda i: (i // 4, i % 4), pad_value=None, axis_separators=[1]) - lv2 = R.call_tir(Expected.relax_some_op_replacement, (lv, lv1), out_sinfo=[R.Tensor((4, 4), dtype="float32"), R.Tensor((4, 4), dtype="float32")]) + lv2 = R.call_tir(Expected.relax_some_op_replacement, (lv, lv1), out_ty=[R.Tensor((4, 4), dtype="float32"), R.Tensor((4, 4), dtype="float32")]) lv3: R.Tensor((4, 4), dtype="float32") = lv2[0] lv4: R.Tensor((16,), dtype="float32") = R.layout_transform(lv3, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), pad_value=None, axis_separators=[1]) lv5: R.Tensor((4, 4), dtype="float32") = lv2[1] @@ -318,7 +318,7 @@ class Before: @R.function def foo(x: R.Tensor((14,), dtype="float32")) -> R.Tensor((14,), dtype="float32"): with R.dataflow(): - lv = R.call_tir(Before.relu, (x,), out_sinfo=R.Tensor((14,), dtype="float32")) + lv = R.call_tir(Before.relu, (x,), out_ty=R.Tensor((14,), dtype="float32")) gv: R.Tensor((14,), dtype="float32") = lv R.output(gv) return gv @@ -347,7 +347,7 @@ def foo(x: R.Tensor((14,), dtype="float32")) -> R.Tensor((14,), dtype="float32") lv1 = R.call_tir( Expected.relax_relu_replacement, (lv,), - out_sinfo=R.Tensor((16,), dtype="float32"), + out_ty=R.Tensor((16,), dtype="float32"), ) lv2: R.Tensor((16,), dtype="float32") = R.layout_transform( lv1, @@ -356,7 +356,7 @@ def foo(x: R.Tensor((14,), dtype="float32")) -> R.Tensor((14,), dtype="float32") axis_separators=[], ) lv_1 = R.call_tir( - Expected.remove_pad, (lv2,), out_sinfo=R.Tensor((14,), dtype="float32") + Expected.remove_pad, (lv2,), out_ty=R.Tensor((14,), dtype="float32") ) gv: R.Tensor((14,), dtype="float32") = lv_1 R.output(gv) @@ -428,9 +428,9 @@ def add(arg0: T.Buffer((16,), "float32"), arg1: T.Buffer((16,), "float32"), outp @R.function def main(x: R.Tensor((16,), dtype="float32"), y: R.Tensor((16,), dtype="float32")) -> R.Tensor((16,), dtype="float32"): with R.dataflow(): - lv0 = R.call_tir(Before.add, (x, y), out_sinfo=R.Tensor((16,), dtype="float32")) + lv0 = R.call_tir(Before.add, (x, y), out_ty=R.Tensor((16,), dtype="float32")) lv1 = R.nn.relu(lv0) - lv2 = R.call_tir(Before.add, (lv0, lv1), out_sinfo=R.Tensor((16,), dtype="float32")) + lv2 = R.call_tir(Before.add, (lv0, lv1), out_ty=R.Tensor((16,), dtype="float32")) gv: R.Tensor((16,), dtype="float32") = lv2 R.output(gv) return gv @@ -452,12 +452,12 @@ def main(x: R.Tensor((16,), dtype="float32"), y: R.Tensor((16,), dtype="float32" with R.dataflow(): lv: R.Tensor((4, 4), dtype="float32") = R.layout_transform(x, index_map=lambda i: (i // 4, i % 4), pad_value=None) lv1: R.Tensor((4, 4), dtype="float32") = R.layout_transform(y, index_map=lambda i: (i // 4, i % 4), pad_value=None) - lv2 = R.call_tir(Expected.relax_add_replacement, (lv, lv1), out_sinfo=R.Tensor((4, 4), dtype="float32")) + lv2 = R.call_tir(Expected.relax_add_replacement, (lv, lv1), out_ty=R.Tensor((4, 4), dtype="float32")) lv0: R.Tensor((16,), dtype="float32") = R.layout_transform(lv2, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), pad_value=None) lv1_1: R.Tensor((16,), dtype="float32") = R.nn.relu(lv0) lv3: R.Tensor((4, 4), dtype="float32") = R.layout_transform(lv0, index_map=lambda i: (i // 4, i % 4), pad_value=None) lv4: R.Tensor((4, 4), dtype="float32") = R.layout_transform(lv1_1, index_map=lambda i: (i // 4, i % 4), pad_value=None) - lv5 = R.call_tir(Expected.relax_add_replacement, (lv3, lv4), out_sinfo=R.Tensor((4, 4), dtype="float32")) + lv5 = R.call_tir(Expected.relax_add_replacement, (lv3, lv4), out_ty=R.Tensor((4, 4), dtype="float32")) lv2_1: R.Tensor((16,), dtype="float32") = R.layout_transform(lv5, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), pad_value=None) gv: R.Tensor((16,), dtype="float32") = lv2_1 R.output(gv) @@ -511,9 +511,7 @@ def main(x: R.Tensor((850, 2048), dtype="float16")) -> R.Tensor( ): cls = Before with R.dataflow(): - lv = R.call_tir( - cls.reshape, (x,), out_sinfo=R.Tensor((850, 1, 2048), dtype="float16") - ) + lv = R.call_tir(cls.reshape, (x,), out_ty=R.Tensor((850, 1, 2048), dtype="float16")) gv: R.Tensor((850, 1, 2048), dtype="float16") = lv R.output(gv) return gv @@ -550,7 +548,7 @@ def main(x: R.Tensor((850, 2048), dtype="float16")) -> R.Tensor( lv_1 = R.call_tir( cls.relax_reshape_replacement, (lv,), - out_sinfo=R.Tensor((850, 1, 2048), dtype="float16"), + out_ty=R.Tensor((850, 1, 2048), dtype="float16"), ) gv: R.Tensor((850, 1, 2048), dtype="float16") = lv_1 R.output(gv) @@ -599,7 +597,7 @@ def some_op(arg0: T.Buffer((16,), "float32"), arg1: T.Buffer((16,), "float32"), @R.function def main(x: R.Tensor((16,), dtype="float32"), y: R.Tensor((16,), dtype="float32")) -> R.Tuple(R.Tensor((16,), dtype="float32"), R.Tensor((16,), dtype="float32")): with R.dataflow(): - gv = R.call_tir(Before.some_op, (x, y), out_sinfo=[R.Tensor((16,), dtype="float32"), R.Tensor((16,), dtype="float32")]) + gv = R.call_tir(Before.some_op, (x, y), out_ty=[R.Tensor((16,), dtype="float32"), R.Tensor((16,), dtype="float32")]) R.output(gv) return gv @@ -619,7 +617,7 @@ def main(x: R.Tensor((16,), dtype="float32"), y: R.Tensor((16,), dtype="float32" with R.dataflow(): lv: R.Tensor((4, 4), dtype="float32") = R.layout_transform(x, index_map=lambda i: (i // 4, i % 4), pad_value=None, axis_separators=[1]) lv1: R.Tensor((4, 4), dtype="float32") = R.layout_transform(y, index_map=lambda i: (i // 4, i % 4), pad_value=None, axis_separators=[1]) - lv2 = R.call_tir(Expected.relax_some_op_replacement, (lv, lv1), out_sinfo=[R.Tensor((4, 4), dtype="float32"), R.Tensor((4, 4), dtype="float32")]) + lv2 = R.call_tir(Expected.relax_some_op_replacement, (lv, lv1), out_ty=[R.Tensor((4, 4), dtype="float32"), R.Tensor((4, 4), dtype="float32")]) lv3: R.Tensor((4, 4), dtype="float32") = lv2[0] lv4: R.Tensor((16,), dtype="float32") = R.layout_transform(lv3, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), pad_value=None, axis_separators=[], input_axis_separators=[1]) lv5: R.Tensor((4, 4), dtype="float32") = lv2[1] diff --git a/tests/python/relax/test_transform_attach_attr_layout_free_buffers.py b/tests/python/relax/test_transform_attach_attr_layout_free_buffers.py index 3690af03d6c0..c21b9a647534 100644 --- a/tests/python/relax/test_transform_attach_attr_layout_free_buffers.py +++ b/tests/python/relax/test_transform_attach_attr_layout_free_buffers.py @@ -47,7 +47,7 @@ def main(x: R.Tensor((32, 32), "float32"), y: R.Tensor((32, 32), "float32")): R.func_attr({"num_input": 1}) cls = Before with R.dataflow(): - gv = R.call_tir(cls.matmul, (x, y), out_sinfo=R.Tensor((32, 32), "float32")) + gv = R.call_tir(cls.matmul, (x, y), out_ty=R.Tensor((32, 32), "float32")) R.output(gv) return gv @@ -71,7 +71,7 @@ def main(x: R.Tensor((32, 32), "float32"), y: R.Tensor((32, 32), "float32")): R.func_attr({"num_input": 1}) cls = Expected with R.dataflow(): - gv = R.call_tir(cls.matmul1, (x, y), out_sinfo=R.Tensor((32, 32), "float32")) + gv = R.call_tir(cls.matmul1, (x, y), out_ty=R.Tensor((32, 32), "float32")) R.output(gv) return gv @@ -104,7 +104,7 @@ def main(x: R.Tensor((32, 32), "float32")): gv = R.call_tir( cls.matmul, (x, relax.const(const_value)), - out_sinfo=R.Tensor((32, 32), "float32"), + out_ty=R.Tensor((32, 32), "float32"), ) R.output(gv) return gv @@ -132,7 +132,7 @@ def main(x: R.Tensor((32, 32), "float32")): gv = R.call_tir( cls.matmul1, (x, relax.const(const_value)), - out_sinfo=R.Tensor((32, 32), "float32"), + out_ty=R.Tensor((32, 32), "float32"), ) R.output(gv) return gv @@ -168,12 +168,12 @@ def main( lv1 = R.call_tir( cls.matmul, (x, w1), - out_sinfo=R.Tensor((32, 32), "float32"), + out_ty=R.Tensor((32, 32), "float32"), ) gv = R.call_tir( cls.matmul, (lv1, w2), - out_sinfo=R.Tensor((32, 32), "float32"), + out_ty=R.Tensor((32, 32), "float32"), ) R.output(gv) return gv @@ -205,12 +205,12 @@ def main( lv1 = R.call_tir( cls.matmul1, (x, w1), - out_sinfo=R.Tensor((32, 32), "float32"), + out_ty=R.Tensor((32, 32), "float32"), ) gv = R.call_tir( cls.matmul1, (lv1, w2), - out_sinfo=R.Tensor((32, 32), "float32"), + out_ty=R.Tensor((32, 32), "float32"), ) R.output(gv) return gv @@ -246,12 +246,12 @@ def main( lv1 = R.call_tir( cls.matmul, (x, w1), - out_sinfo=R.Tensor((32, 32), "float32"), + out_ty=R.Tensor((32, 32), "float32"), ) gv = R.call_tir( cls.matmul, (w2, lv1), - out_sinfo=R.Tensor((32, 32), "float32"), + out_ty=R.Tensor((32, 32), "float32"), ) R.output(gv) return gv @@ -296,12 +296,12 @@ def main( lv1 = R.call_tir( cls.matmul1, (x, w1), - out_sinfo=R.Tensor((32, 32), "float32"), + out_ty=R.Tensor((32, 32), "float32"), ) gv = R.call_tir( cls.matmul2, (w2, lv1), - out_sinfo=R.Tensor((32, 32), "float32"), + out_ty=R.Tensor((32, 32), "float32"), ) R.output(gv) return gv diff --git a/tests/python/relax/test_transform_bind_params.py b/tests/python/relax/test_transform_bind_params.py index 59c4a60087e0..c7711d68d362 100644 --- a/tests/python/relax/test_transform_bind_params.py +++ b/tests/python/relax/test_transform_bind_params.py @@ -90,10 +90,10 @@ def main( n = T.Var("n", "int64") with R.dataflow(): lv0 = R.call_dps_packed( - "linear0", (x, w0, b0), out_sinfo=R.Tensor((batch, n), dtype="float32") + "linear0", (x, w0, b0), out_ty=R.Tensor((batch, n), dtype="float32") ) out = R.call_dps_packed( - "linear1", (lv0, w1, b1), out_sinfo=R.Tensor((batch, k), dtype="float32") + "linear1", (lv0, w1, b1), out_ty=R.Tensor((batch, k), dtype="float32") ) R.output(out) return out @@ -109,20 +109,12 @@ def main( # Since it contains ConstantNode, it's hard to check with structural equality. func = mod["main"] assert len(func.params) == 1 - batch = func.params[0].struct_info.shape[0] - tvm.ir.assert_structural_equal( - func.params[0].struct_info, relax.TensorStructInfo((batch, 4), "float32") - ) - tvm.ir.assert_structural_equal( - func.ret_struct_info, relax.TensorStructInfo((batch, 8), "float32") - ) + batch = func.params[0].ty.shape[0] + tvm.ir.assert_structural_equal(func.params[0].ty, relax.TensorType((batch, 4), "float32")) + tvm.ir.assert_structural_equal(func.ret_ty, relax.TensorType((batch, 8), "float32")) bindings = func.body.blocks[0].bindings - tvm.ir.assert_structural_equal( - bindings[0].var.struct_info, relax.TensorStructInfo((batch, 6), "float32") - ) - tvm.ir.assert_structural_equal( - bindings[1].var.struct_info, relax.TensorStructInfo((batch, 8), "float32") - ) + tvm.ir.assert_structural_equal(bindings[0].var.ty, relax.TensorType((batch, 6), "float32")) + tvm.ir.assert_structural_equal(bindings[1].var.ty, relax.TensorType((batch, 8), "float32")) param_specification = tvm.testing.parameter("by_string", "by_var") diff --git a/tests/python/relax/test_transform_bind_symbolic_vars.py b/tests/python/relax/test_transform_bind_symbolic_vars.py index 6b0f3a075a79..643b72b42408 100644 --- a/tests/python/relax/test_transform_bind_symbolic_vars.py +++ b/tests/python/relax/test_transform_bind_symbolic_vars.py @@ -42,10 +42,10 @@ def main( k = T.Var("k", "int64") with R.dataflow(): lv0 = R.call_dps_packed( - "test0", (x, w0), out_sinfo=R.Tensor((batch, n), dtype="float32") + "test0", (x, w0), out_ty=R.Tensor((batch, n), dtype="float32") ) out = R.call_dps_packed( - "test1", (lv0, w1), out_sinfo=R.Tensor((batch, k), dtype="float32") + "test1", (lv0, w1), out_ty=R.Tensor((batch, k), dtype="float32") ) R.output(out) return out @@ -64,11 +64,9 @@ def main( ) -> R.Tensor((1, 3), dtype="float32"): n = T.int64() with R.dataflow(): - lv0 = R.call_dps_packed( - "test0", (x, w0), out_sinfo=R.Tensor((1, n), dtype="float32") - ) + lv0 = R.call_dps_packed("test0", (x, w0), out_ty=R.Tensor((1, n), dtype="float32")) out = R.call_dps_packed( - "test1", (lv0, w1), out_sinfo=R.Tensor((1, 3), dtype="float32") + "test1", (lv0, w1), out_ty=R.Tensor((1, 3), dtype="float32") ) R.output(out) return out @@ -91,8 +89,8 @@ def main( n = T.Var("n", "int64") k = T.Var("k", "int64") with R.dataflow(): - lv0 = R.call_dps_packed("test0", (x, w0), out_sinfo=R.Tensor((batch, n))) - out = R.call_dps_packed("test1", (lv0, w1), out_sinfo=R.Tensor((batch, k))) + lv0 = R.call_dps_packed("test0", (x, w0), out_ty=R.Tensor((batch, n))) + out = R.call_dps_packed("test1", (lv0, w1), out_ty=R.Tensor((batch, k))) R.output(out) return out @@ -108,8 +106,8 @@ def main(x: R.Shape([1, "m"]), w0: R.Shape(["m", "n"]), w1: R.Shape([3, 10])) -> ): n = T.int64() with R.dataflow(): - lv0 = R.call_dps_packed("test0", (x, w0), out_sinfo=R.Tensor((1, n))) - out = R.call_dps_packed("test1", (lv0, w1), out_sinfo=R.Tensor((1, 3))) + lv0 = R.call_dps_packed("test0", (x, w0), out_ty=R.Tensor((1, n))) + out = R.call_dps_packed("test1", (lv0, w1), out_ty=R.Tensor((1, 3))) R.output(out) return out @@ -135,12 +133,12 @@ def main( lv0 = R.call_dps_packed( "test0", (x, w0), - out_sinfo=R.Tensor((batch, m + n), dtype="float32"), + out_ty=R.Tensor((batch, m + n), dtype="float32"), ) out = R.call_dps_packed( "test1", (lv0, w1), - out_sinfo=R.Tensor((batch, k + n), dtype="float32"), + out_ty=R.Tensor((batch, k + n), dtype="float32"), ) R.output(out) return out @@ -160,10 +158,10 @@ def main( n = T.int64() with R.dataflow(): lv0 = R.call_dps_packed( - "test0", (x, w0), out_sinfo=R.Tensor((1, n + 3), dtype="float32") + "test0", (x, w0), out_ty=R.Tensor((1, n + 3), dtype="float32") ) out = R.call_dps_packed( - "test1", (lv0, w1), out_sinfo=R.Tensor((1, n + 2), dtype="float32") + "test1", (lv0, w1), out_ty=R.Tensor((1, n + 2), dtype="float32") ) R.output(out) return out @@ -221,7 +219,7 @@ def main_1(x: R.Tensor(("m", 16), dtype="float32")): def main_2(x: R.Tensor(("m", "n"), dtype="float32")): return x - main_1_n = Before["main_1"].params[0].struct_info.shape[1] + main_1_n = Before["main_1"].params[0].ty.shape[1] After = relax.transform.BindSymbolicVars({main_1_n: 16})(Before) tvm.ir.assert_structural_equal(Expected, After) diff --git a/tests/python/relax/test_transform_canonicalize_bindings.py b/tests/python/relax/test_transform_canonicalize_bindings.py index ccbb011bb61b..f70e41d9805e 100644 --- a/tests/python/relax/test_transform_canonicalize_bindings.py +++ b/tests/python/relax/test_transform_canonicalize_bindings.py @@ -128,7 +128,7 @@ def main(x: R.Tensor, y: R.Tensor): verify(TestOps, Expected) -@pytest.mark.xfail(reason="The lhs and rhs of an assignment should have the same struct info.") +@pytest.mark.xfail(reason="The lhs and rhs of an assignment should have the same type.") def test_casting(): @I.ir_module class TestCasting: @@ -165,7 +165,7 @@ def main(x: R.Tensor): class Expected: @R.function def main(x: R.Tensor): - # can't get rid of z because its struct_info is different from x's + # can't get rid of z because its ty is different from x's m, n = T.int64(), T.int64() z = R.match_cast(x, R.Tensor((m, n))) return z @@ -221,7 +221,7 @@ class Expected: def main(x: R.Tensor(ndim=2)): o, p = T.int64(), T.int64() z = R.match_cast(x, R.Tensor((o, p))) - # the struct_info field on q will need to be updated + # the ty field on q will need to be updated q = R.add(z, x) return R.add(q, z) @@ -850,13 +850,13 @@ def main(x: R.Tensor): assert_structural_equal(Expected, after) -def test_canonicalize_with_updated_struct_info(): +def test_canonicalize_with_updated_ty(): """CanonicalizeBindings and Normalizer may both replace a Var If the CanonicalizeBindings pass has no replacements to make for a variable, it must still delegate to the ExprMutator. This is because a variable replacement may have occurred as part of the IRNormalizer, - in order to provide better struct info. + in order to provide better type. """ @I.ir_module @@ -1164,20 +1164,20 @@ def main(x: R.Tensor, y: R.Tensor) -> R.Tensor: assert_structural_equal(Expected, after) -def test_canonicalization_causes_struct_info_update(): +def test_canonicalization_causes_ty_update(): """Regression test for failure mode causing undefined variable - The ExprMutator is only allowed to update a variable's struct info - if the value bound to it has new struct info. When + The ExprMutator is only allowed to update a variable's type + if the value bound to it has new type. When CanonicalizeBindings replaces a trivial binding, this may provide - better struct info as a result. If this happens, the + better type as a result. If this happens, the In previous implementations, ExprMutator::ReEmitBinding defined a remap for `binding->var->vid`, even if the derived class defined a replacement by overriding `VisitVarDef`. If the derived class defines a new variable binding by overriding `VisitVarDef`, and also causes a variable replacement by overriding `VisitExpr` and - returning a type with different struct info, then `ExprMutator` + returning a type with different type, then `ExprMutator` must check for both `binding->var->vid` *AND* `new_var->vid`. The former may be present in the unmodified graph, and the latter may be produced by the derived class before delegating to the base @@ -1199,7 +1199,7 @@ def transform_params( # RHS contains `(A,C)`, which CanonicalizeBindings # replaces with `(A,B)`. Because this changes the - # RHS, a new LHS (and new struct info!) will be + # RHS, a new LHS (and new type!) will be # generated. D: R.Tuple( R.Tensor(dtype="float16", ndim=2), diff --git a/tests/python/relax/test_transform_codegen_pass.py b/tests/python/relax/test_transform_codegen_pass.py index c690eac5603b..306185937c34 100644 --- a/tests/python/relax/test_transform_codegen_pass.py +++ b/tests/python/relax/test_transform_codegen_pass.py @@ -253,12 +253,12 @@ def main( lv = R.call_dps_packed( "fused_relax_nn_conv2d_tensorrt", (data, weight1), - out_sinfo=R.Tensor((16, 32, 32, 16), dtype="float16"), + out_ty=R.Tensor((16, 32, 32, 16), dtype="float16"), ) gv = R.call_dps_packed( "fused_relax_nn_conv2d_tensorrt", (lv, weight2), - out_sinfo=R.Tensor((16, 32, 32, 16), dtype="float16"), + out_ty=R.Tensor((16, 32, 32, 16), dtype="float16"), ) R.output(gv) return gv @@ -349,12 +349,12 @@ def main( lv = R.call_dps_packed( "fused_relax_matmul_cublas", (x, w1), - out_sinfo=R.Tensor((1, r1), dtype="float16"), + out_ty=R.Tensor((1, r1), dtype="float16"), ) lv1 = R.call_dps_packed( "fused_relax_matmul_cublas", (x, w2), - out_sinfo=R.Tensor((1, r2), dtype="float16"), + out_ty=R.Tensor((1, r2), dtype="float16"), ) gv: R.Tuple( R.Tensor((1, r1), dtype="float16"), R.Tensor((1, r2), dtype="float16") diff --git a/tests/python/relax/test_transform_compute_prim_value.py b/tests/python/relax/test_transform_compute_prim_value.py index 6be87a357c98..733dbc295a9a 100644 --- a/tests/python/relax/test_transform_compute_prim_value.py +++ b/tests/python/relax/test_transform_compute_prim_value.py @@ -56,9 +56,9 @@ class Before: def main(A: R.Tensor(["N"])): N = T.int64() if R.prim_value(N % 16 == 0): - out = R.call_packed("fast_vectorized_impl", A, sinfo_args=[A.struct_info]) + out = R.call_packed("fast_vectorized_impl", A, ty_args=[A.ty]) else: - out = R.call_packed("slow_non_vectorized_impl", A, sinfo_args=[A.struct_info]) + out = R.call_packed("slow_non_vectorized_impl", A, ty_args=[A.ty]) return out @I.ir_module @@ -68,9 +68,9 @@ def main(A: R.Tensor(["N"])): N = T.int64() condition: R.Prim("bool") = Expected.compute_symbolic_expr(R.prim_value(N)) if condition: - out = R.call_packed("fast_vectorized_impl", A, sinfo_args=[A.struct_info]) + out = R.call_packed("fast_vectorized_impl", A, ty_args=[A.ty]) else: - out = R.call_packed("slow_non_vectorized_impl", A, sinfo_args=[A.struct_info]) + out = R.call_packed("slow_non_vectorized_impl", A, ty_args=[A.ty]) return out @T.prim_func(private=True, s_tir=True) diff --git a/tests/python/relax/test_transform_cse.py b/tests/python/relax/test_transform_cse.py index e9a2cb767f9c..4526146fe7c7 100644 --- a/tests/python/relax/test_transform_cse.py +++ b/tests/python/relax/test_transform_cse.py @@ -380,8 +380,8 @@ def test_do_not_eliminate_extern_func(): class Before: @R.function(pure=False) def foo(x: R.Tensor((2, 3), dtype="float32")): - y = R.call_packed("extern_func_name", x, sinfo_args=R.Tensor([2, 3])) - z = R.call_packed("extern_func_name", y, sinfo_args=R.Tensor([2, 3])) + y = R.call_packed("extern_func_name", x, ty_args=R.Tensor([2, 3])) + z = R.call_packed("extern_func_name", y, ty_args=R.Tensor([2, 3])) return z Expected = Before @@ -395,8 +395,8 @@ class Before: @R.function def main(A: R.Tensor([16, 16], "int32"), B: R.Tensor([16, 16], "int32")): cls = Before - Prod = R.call_tir(cls.product, [A, B], out_sinfo=R.Tensor([16, 16], "int32")) - Sum = R.call_tir(cls.sum, [A, B], out_sinfo=R.Tensor([16, 16], "int32")) + Prod = R.call_tir(cls.product, [A, B], out_ty=R.Tensor([16, 16], "int32")) + Sum = R.call_tir(cls.sum, [A, B], out_ty=R.Tensor([16, 16], "int32")) return (Prod, Sum) @T.prim_func(private=True, s_tir=True) diff --git a/tests/python/relax/test_transform_dead_code_elimination.py b/tests/python/relax/test_transform_dead_code_elimination.py index 87366137b1e7..4963a7f1dc05 100644 --- a/tests/python/relax/test_transform_dead_code_elimination.py +++ b/tests/python/relax/test_transform_dead_code_elimination.py @@ -367,7 +367,7 @@ def main(x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32")) -> gv0 = R.call_tir( InputModule.tir_add_tensors, [x, w], - out_sinfo=R.Tensor((16, 16), "float32"), + out_ty=R.Tensor((16, 16), "float32"), ) return gv0 @@ -595,7 +595,7 @@ def while_loop(i: R.Tensor((), "int32"), s: R.Tensor((2, 3), "float32")) -> R.Te (2, 3), "float32" ): cond = R.call_pure_packed( - "test.vm.less", i, R.const(10), sinfo_args=R.Tensor((), dtype="bool") + "test.vm.less", i, R.const(10), ty_args=R.Tensor((), dtype="bool") ) c = R.const(1, dtype="int32") if cond: @@ -634,7 +634,7 @@ def while_loop(i: R.Tensor((), "int32"), s: R.Tensor((2, 3), "float32")) -> R.Te (2, 3), "float32" ): cond = R.call_pure_packed( - "test.vm.less", i, threshold, sinfo_args=R.Tensor((), dtype="bool") + "test.vm.less", i, threshold, ty_args=R.Tensor((), dtype="bool") ) c = R.const(1, dtype="int32") if cond: diff --git a/tests/python/relax/test_transform_decompose_ops.py b/tests/python/relax/test_transform_decompose_ops.py index 137ea0750205..bf5f84fc71f8 100644 --- a/tests/python/relax/test_transform_decompose_ops.py +++ b/tests/python/relax/test_transform_decompose_ops.py @@ -375,7 +375,7 @@ def main(t: R.Tensor([3], dtype="int64")) -> R.Shape(ndim=3): x_1 = T.int64() x_2 = T.int64() gv: R.Shape(ndim=3) = R.call_pure_packed( - "vm.builtin.tensor_to_shape", t, sinfo_args=(R.Shape(ndim=3),) + "vm.builtin.tensor_to_shape", t, ty_args=(R.Shape(ndim=3),) ) y: R.Shape([x, x_1, x_2]) = R.match_cast(gv, R.Shape([x, x_1, x_2])) gv_1: R.Shape([x, x_1, x_2]) = R.shape([x, x_1, x_2]) diff --git a/tests/python/relax/test_transform_error_enrichment.py b/tests/python/relax/test_transform_error_enrichment.py index 621d38059a63..4ec19368236c 100644 --- a/tests/python/relax/test_transform_error_enrichment.py +++ b/tests/python/relax/test_transform_error_enrichment.py @@ -32,15 +32,13 @@ def _bad_matmul_module(): """Build (programmatically, no TVMScript parse) a module whose `main` binds a matmul of incompatible shapes [3, 4] x [5, 6]. The function carries a - placeholder return struct info so it constructs; Normalize re-infers and the + placeholder return type so it constructs; Normalize re-infers and the matmul validator fires during the pass.""" - x = relax.Var("x", relax.TensorStructInfo([3, 4], "float32")) - y = relax.Var("y", relax.TensorStructInfo([5, 6], "float32")) + x = relax.Var("x", relax.TensorType([3, 4], "float32")) + y = relax.Var("y", relax.TensorType([5, 6], "float32")) lv = relax.Var("lv") body = relax.SeqExpr([relax.BindingBlock([relax.VarBinding(lv, relax.op.matmul(x, y))])], lv) - func = relax.Function( - [x, y], body, ret_struct_info=relax.TensorStructInfo([3, 6], "float32"), is_pure=True - ) + func = relax.Function([x, y], body, ret_ty=relax.TensorType([3, 6], "float32"), is_pure=True) func = func.with_attr("global_symbol", "main") return IRModule({relax.GlobalVar("main"): func}) diff --git a/tests/python/relax/test_transform_fold_constant.py b/tests/python/relax/test_transform_fold_constant.py index 175d54f57213..0867a9d63058 100644 --- a/tests/python/relax/test_transform_fold_constant.py +++ b/tests/python/relax/test_transform_fold_constant.py @@ -48,7 +48,7 @@ def gen_mod(mod, name, binding): if k.name_hint == name: # rename to main gv = tvm.ir.GlobalVar("main") - funcs[gv] = tvm.relax.Function(v.params, v.body, v.ret_struct_info).with_attr( + funcs[gv] = tvm.relax.Function(v.params, v.body, v.ret_ty).with_attr( "global_symbol", "main" ) else: @@ -469,7 +469,7 @@ def before(c0: R.Tensor((4, 4), "float32")): lv0 = relax.call_tir( cls.split, (c0,), - out_sinfo=[ + out_ty=[ R.Tensor((2, 4), dtype="float32"), R.Tensor((2, 4), dtype="float32"), ], diff --git a/tests/python/relax/test_transform_fuse_ops.py b/tests/python/relax/test_transform_fuse_ops.py index d8173c9ed24e..8c0785334fec 100644 --- a/tests/python/relax/test_transform_fuse_ops.py +++ b/tests/python/relax/test_transform_fuse_ops.py @@ -861,9 +861,9 @@ class Module: def main(x: R.Tensor((2, 3), "float32")): cls = Module with R.dataflow(): - a = R.call_tir(cls.exp, (x,), out_sinfo=R.Tensor((2, 3), "float32")) - b = R.call_tir(cls.exp, (a,), out_sinfo=R.Tensor((2, 3), "float32")) - c = R.call_dps_packed("packed_dps", (a,), out_sinfo=R.Tensor((2, 3), "float32")) + a = R.call_tir(cls.exp, (x,), out_ty=R.Tensor((2, 3), "float32")) + b = R.call_tir(cls.exp, (a,), out_ty=R.Tensor((2, 3), "float32")) + c = R.call_dps_packed("packed_dps", (a,), out_ty=R.Tensor((2, 3), "float32")) R.output(b, c) return R.tuple(b, c) @@ -883,8 +883,8 @@ class Module: def main(x: R.Tensor((1, 512, 64, 64), "float32"), mean: R.Tensor((64, 64), "float32"), var: R.Tensor((64, 64), "float32")): cls = Module with R.dataflow(): - gv0 = R.call_tir(cls.layer_norm, (x, mean, var), out_sinfo=R.Tensor((1, 512, 64, 64), 'float32')) - gv1 = R.call_tir(cls.relu, gv0, out_sinfo=R.Tensor((1, 512, 64, 64), "float32")) + gv0 = R.call_tir(cls.layer_norm, (x, mean, var), out_ty=R.Tensor((1, 512, 64, 64), 'float32')) + gv1 = R.call_tir(cls.relu, gv0, out_ty=R.Tensor((1, 512, 64, 64), "float32")) R.output(gv1) return gv1 @@ -963,8 +963,8 @@ def fused_layer_norm_relu(x: R.Tensor((1, 512, 64, 64), dtype="float32"), mean: R.func_attr({"Primitive": True}) cls = Expected with R.dataflow(): - gv0 = R.call_tir(cls.layer_norm, (x, mean, var), out_sinfo=R.Tensor((1, 512, 64, 64), 'float32')) - gv = R.call_tir(cls.relu, (gv0,), out_sinfo=R.Tensor((1, 512, 64, 64), dtype="float32")) + gv0 = R.call_tir(cls.layer_norm, (x, mean, var), out_ty=R.Tensor((1, 512, 64, 64), 'float32')) + gv = R.call_tir(cls.relu, (gv0,), out_ty=R.Tensor((1, 512, 64, 64), dtype="float32")) R.output(gv) return gv @@ -1105,9 +1105,9 @@ def fused_conv2d_add_add2(inp_0: R.Tensor((2, 320, 64, 64), dtype="float32"), w1 R.func_attr({"Primitive": True}) cls = Expected with R.dataflow(): - lv27 = R.call_tir(cls.conv2d, (inp_0, w1), out_sinfo=R.Tensor((2, 320, 64, 64), dtype="float32")) - lv29 = R.call_tir(cls.add, (lv27, lv28), out_sinfo=R.Tensor((2, 320, 64, 64), dtype="float32")) - gv = R.call_tir(cls.add2, (lv29, lv35), out_sinfo=R.Tensor((2, 320, 64, 64), dtype="float32")) + lv27 = R.call_tir(cls.conv2d, (inp_0, w1), out_ty=R.Tensor((2, 320, 64, 64), dtype="float32")) + lv29 = R.call_tir(cls.add, (lv27, lv28), out_ty=R.Tensor((2, 320, 64, 64), dtype="float32")) + gv = R.call_tir(cls.add2, (lv29, lv35), out_ty=R.Tensor((2, 320, 64, 64), dtype="float32")) R.output(gv) return gv @@ -1116,8 +1116,8 @@ def fused_matmul_add1(inp_1: R.Tensor((2, 1280), dtype="float32"), lv31: R.Tenso cls = Expected R.func_attr({"Primitive": True}) with R.dataflow(): - lv32 = R.call_tir(cls.matmul, (inp_1, lv31), out_sinfo=R.Tensor((2, 320), dtype="float32")) - gv = R.call_tir(cls.add1, (lv32, b2), out_sinfo=R.Tensor((2, 320), dtype="float32")) + lv32 = R.call_tir(cls.matmul, (inp_1, lv31), out_ty=R.Tensor((2, 320), dtype="float32")) + gv = R.call_tir(cls.add1, (lv32, b2), out_ty=R.Tensor((2, 320), dtype="float32")) R.output(gv) return gv @@ -1126,10 +1126,10 @@ def main(inp_0: R.Tensor((2, 320, 64, 64), dtype="float32"), inp_1: R.Tensor((2, R.func_attr({"num_input": 2}) cls = Expected with R.dataflow(): - lv28 = R.call_tir(cls.reshape, (b1,), out_sinfo=R.Tensor((1, 320, 1, 1), dtype="float32")) - lv31 = R.call_tir(cls.transpose, (w2,), out_sinfo=R.Tensor((1280, 320), dtype="float32")) + lv28 = R.call_tir(cls.reshape, (b1,), out_ty=R.Tensor((1, 320, 1, 1), dtype="float32")) + lv31 = R.call_tir(cls.transpose, (w2,), out_ty=R.Tensor((1280, 320), dtype="float32")) lv: R.Tensor((2, 320), dtype="float32") = cls.fused_matmul_add1(inp_1, lv31, b2) - lv35 = R.call_tir(cls.reshape1, (lv,), out_sinfo=R.Tensor((2, 320, 1, 1), dtype="float32")) + lv35 = R.call_tir(cls.reshape1, (lv,), out_ty=R.Tensor((2, 320, 1, 1), dtype="float32")) lv1: R.Tensor((2, 320, 64, 64), dtype="float32") = cls.fused_conv2d_add_add2(inp_0, w1, lv28, lv35) gv: R.Tensor((2, 320, 64, 64), dtype="float32") = lv1 R.output(gv) @@ -1250,8 +1250,8 @@ def fused_matmul1_add1(inp_1: R.Tensor((1, 128), dtype="float32"), lv4: R.Tensor R.func_attr({"Primitive": True}) cls = Expected with R.dataflow(): - lv5 = R.call_tir(cls.matmul1, (inp_1, lv4), out_sinfo=R.Tensor((1, 10), dtype="float32")) - gv = R.call_tir(cls.add1, (lv5, linear2_bias), out_sinfo=R.Tensor((1, 10), dtype="float32")) + lv5 = R.call_tir(cls.matmul1, (inp_1, lv4), out_ty=R.Tensor((1, 10), dtype="float32")) + gv = R.call_tir(cls.add1, (lv5, linear2_bias), out_ty=R.Tensor((1, 10), dtype="float32")) R.output(gv) return gv @@ -1260,8 +1260,8 @@ def main(inp_0: R.Tensor((1, 784), dtype="float32"), inp_1: R.Tensor((1, 128), d R.func_attr({"num_input": 1}) cls = Expected with R.dataflow(): - lv = R.call_tir(cls.transpose, (linear1_weight,), out_sinfo=R.Tensor((784, 128), dtype="float32")) - lv4 = R.call_tir(cls.transpose1, (linear2_weight,), out_sinfo=R.Tensor((128, 10), dtype="float32")) + lv = R.call_tir(cls.transpose, (linear1_weight,), out_ty=R.Tensor((784, 128), dtype="float32")) + lv4 = R.call_tir(cls.transpose1, (linear2_weight,), out_ty=R.Tensor((128, 10), dtype="float32")) lv_1: R.Tensor((1, 10), dtype="float32") = cls.fused_matmul1_add1(inp_1, lv4, linear2_bias) gv: R.Tensor((1, 10), dtype="float32") = lv_1 R.output(gv) @@ -1366,7 +1366,7 @@ def main(s: R.Shape(["n"]), kv_cache: R.Object): "vm.builtin.attention_kv_cache_view", kv_cache, R.shape([1 + n, 32, 128]), - sinfo_args=(R.Tensor((1 + n, 32, 128), dtype="float32"),), + ty_args=(R.Tensor((1 + n, 32, 128), dtype="float32"),), ) R.output(gv, lv2) return gv, lv2 @@ -1398,7 +1398,7 @@ def main(s: R.Shape(["n"]), kv_cache: R.Object): "vm.builtin.attention_kv_cache_view", kv_cache, R.shape([1 + n, 32, 128]), - sinfo_args=(R.Tensor((1 + n, 32, 128), dtype="float32"),), + ty_args=(R.Tensor((1 + n, 32, 128), dtype="float32"),), ) R.output(gv, lv) return gv, lv @@ -1431,16 +1431,16 @@ class Module: def main(inp: R.Tensor((2, 2), dtype="float32")) -> R.Tensor((2, 2), dtype="float32"): with R.dataflow(): lv = R.call_pure_packed( - "my_func1", inp, R.prim_value(0), sinfo_args=[R.Tensor((2, 2), dtype="float32")] + "my_func1", inp, R.prim_value(0), ty_args=[R.Tensor((2, 2), dtype="float32")] ) lv1 = R.call_pure_packed( - "my_func2", lv, R.str("str"), sinfo_args=[R.Tensor((2, 2), dtype="float32")] + "my_func2", lv, R.str("str"), ty_args=[R.Tensor((2, 2), dtype="float32")] ) gv = R.call_pure_packed( "my_func3", lv1, R.dtype("float32"), - sinfo_args=[R.Tensor((2, 2), dtype="float32")], + ty_args=[R.Tensor((2, 2), dtype="float32")], ) R.output(gv) return gv @@ -1555,19 +1555,19 @@ def main( lv = R.call_tir( cls.add, (x, p0), - out_sinfo=R.Tensor((10, 20), dtype="float32"), + out_ty=R.Tensor((10, 20), dtype="float32"), ) lv1 = R.call_tir_inplace( cls.exp_inplace, (lv,), inplace_indices=[0], - out_sinfo=R.Tensor((10, 20), dtype="float32"), + out_ty=R.Tensor((10, 20), dtype="float32"), ) gv = R.call_tir_inplace( cls.squeeze_inplace, (lv1,), inplace_indices=[0], - out_sinfo=R.Tensor((10, 20), dtype="float32"), + out_ty=R.Tensor((10, 20), dtype="float32"), ) R.output(gv) return gv @@ -1618,19 +1618,19 @@ def fused_add_exp_inplace_squeeze_inplace( lv = R.call_tir( cls.add, (x, p0), - out_sinfo=R.Tensor((10, 20), dtype="float32"), + out_ty=R.Tensor((10, 20), dtype="float32"), ) lv1 = R.call_tir_inplace( cls.exp_inplace, (lv,), inplace_indices=[0], - out_sinfo=R.Tensor((10, 20), dtype="float32"), + out_ty=R.Tensor((10, 20), dtype="float32"), ) gv = R.call_tir_inplace( cls.squeeze_inplace, (lv1,), inplace_indices=[0], - out_sinfo=R.Tensor((10, 20), dtype="float32"), + out_ty=R.Tensor((10, 20), dtype="float32"), ) R.output(gv) return gv @@ -1685,10 +1685,10 @@ def main(x: R.Tensor((16, 16), dtype="float32"), packed_params: R.Tuple(R.Tensor with R.dataflow(): lv: R.Tensor((16, 16), dtype="float16") = packed_params[0] lv1: R.Tensor((16, 16), dtype="float16") = packed_params[1] - lv2 = R.call_tir(cls.cast, (lv,), out_sinfo=R.Tensor((16, 16), dtype="float32")) - lv3 = R.call_tir(cls.matmul, (x, lv2), out_sinfo=R.Tensor((16, 16), dtype="float32")) - lv4 = R.call_tir(cls.cast, (lv1,), out_sinfo=R.Tensor((16, 16), dtype="float32")) - lv5 = R.call_tir(cls.matmul, (lv3, lv4), out_sinfo=R.Tensor((16, 16), dtype="float32")) + lv2 = R.call_tir(cls.cast, (lv,), out_ty=R.Tensor((16, 16), dtype="float32")) + lv3 = R.call_tir(cls.matmul, (x, lv2), out_ty=R.Tensor((16, 16), dtype="float32")) + lv4 = R.call_tir(cls.cast, (lv1,), out_ty=R.Tensor((16, 16), dtype="float32")) + lv5 = R.call_tir(cls.matmul, (lv3, lv4), out_ty=R.Tensor((16, 16), dtype="float32")) gv: R.Tensor((16, 16), dtype="float32") = lv5 R.output(gv) return gv diff --git a/tests/python/relax/test_transform_fuse_ops_by_pattern.py b/tests/python/relax/test_transform_fuse_ops_by_pattern.py index 1362f9375847..16c102321ff3 100644 --- a/tests/python/relax/test_transform_fuse_ops_by_pattern.py +++ b/tests/python/relax/test_transform_fuse_ops_by_pattern.py @@ -591,8 +591,8 @@ def inner_func(B: R.Tensor([16, 16], dtype="float16")): def test_compare_with_merge_composite_path(): - x = relax.Var("x", relax.TensorStructInfo([10, 10], "float32")) - y = relax.Var("y", relax.TensorStructInfo([10, 10], "float32")) + x = relax.Var("x", relax.TensorType([10, 10], "float32")) + y = relax.Var("y", relax.TensorType([10, 10], "float32")) bb = relax.BlockBuilder() with bb.function("main", [x, y]): with bb.dataflow(): @@ -769,7 +769,7 @@ def main( relu1 = R.call_tir( cls.relu, (lv,), - out_sinfo=R.Tensor((1, 64, 56, 56), dtype="float32"), + out_ty=R.Tensor((1, 64, 56, 56), dtype="float32"), ) R.output(relu1) return relu1 diff --git a/tests/python/relax/test_transform_fuse_tir.py b/tests/python/relax/test_transform_fuse_tir.py index 57b91b9448fd..8fc34070db89 100644 --- a/tests/python/relax/test_transform_fuse_tir.py +++ b/tests/python/relax/test_transform_fuse_tir.py @@ -297,13 +297,13 @@ def fused_exp_add(x1, x2): def test_fuse_with_nested_tuple_as_param(): - tuple_struct_info = R.Tuple( + tuple_ty = R.Tuple( [R.Tensor([10], "float32"), R.Tuple([R.Tensor([10], "float32"), R.Tensor([10], "float32")])] ) def before(): bb = relax.BlockBuilder() - x = relax.Var("x", tuple_struct_info) + x = relax.Var("x", tuple_ty) with bb.function("fused_exp_add_add", [x], attrs={"Primitive": True}, private=True): with bb.dataflow(): lv0 = bb.emit(relax.TupleGetItem(x, 0)) @@ -317,7 +317,7 @@ def before(): mod = bb.get() func_gv = mod.get_global_var("fused_exp_add_add") - x = relax.Var("x", tuple_struct_info) + x = relax.Var("x", tuple_ty) with bb.function("main", [x]): with bb.dataflow(): gv = bb.emit_output(relax.Call(func_gv, [x])) @@ -331,7 +331,7 @@ def fused_exp_add_add(x1, x2, x3): return topi.add(exp, add) bb = relax.BlockBuilder() - x = relax.Var("x", tuple_struct_info) + x = relax.Var("x", tuple_ty) with bb.function("main", [x]): with bb.dataflow(): lv0 = bb.emit(relax.TupleGetItem(x, 0)) @@ -615,7 +615,7 @@ def func1(x: R.Tensor((10, 20), dtype="float32")) -> R.Tensor((10, 20), dtype="f gv2 = R.call_tir( Expected.fused_add_exp_squeeze, (x, R.const(1, "float32")), - out_sinfo=R.Tensor((10, 20), dtype="float32"), + out_ty=R.Tensor((10, 20), dtype="float32"), ) R.output(gv2) return gv2 @@ -626,7 +626,7 @@ def func2(x: R.Tensor((20, 10), dtype="float32")) -> R.Tensor((20, 10), dtype="f gv3 = R.call_tir( Expected.fused_add1_exp1_squeeze1, (x, R.const(1, "float32")), - out_sinfo=R.Tensor((20, 10), dtype="float32"), + out_ty=R.Tensor((20, 10), dtype="float32"), ) R.output(gv3) return gv3 @@ -762,8 +762,8 @@ def fused_function(x: R.Tensor([16, 32], "float32")) -> R.Tensor([16, 32], dtype R.func_attr({"Primitive": True}) cls = Before with R.dataflow(): - y = R.call_tir(cls.dynamic_tir_kernel, [x], out_sinfo=R.Tensor([16, 32], "float32")) - z = R.call_tir(cls.dynamic_tir_kernel, [y], out_sinfo=R.Tensor([16, 32], "float32")) + y = R.call_tir(cls.dynamic_tir_kernel, [x], out_ty=R.Tensor([16, 32], "float32")) + z = R.call_tir(cls.dynamic_tir_kernel, [y], out_ty=R.Tensor([16, 32], "float32")) R.output(z) return z @@ -798,7 +798,7 @@ def fused_function( def main(x: R.Tensor([16, 32], "float32")) -> R.Tensor([16, 32], dtype="float32"): cls = Expected with R.dataflow(): - gv = R.call_tir(cls.fused_function, [x], out_sinfo=R.Tensor([16, 32], "float32")) + gv = R.call_tir(cls.fused_function, [x], out_ty=R.Tensor([16, 32], "float32")) R.output(gv) return gv @@ -837,10 +837,10 @@ def fused_function( cls = Before with R.dataflow(): y = R.call_tir( - cls.dynamic_tir_kernel, [x, B, C], out_sinfo=R.Tensor([16 * 32], "float32") + cls.dynamic_tir_kernel, [x, B, C], out_ty=R.Tensor([16 * 32], "float32") ) z = R.call_tir( - cls.dynamic_tir_kernel, [y, B, C], out_sinfo=R.Tensor([16 * 32], "float32") + cls.dynamic_tir_kernel, [y, B, C], out_ty=R.Tensor([16 * 32], "float32") ) R.output(z) return z @@ -887,7 +887,7 @@ def main( cls = Expected with R.dataflow(): gv = R.call_tir( - cls.fused_function, (x, B, C), out_sinfo=R.Tensor((512,), dtype="float32") + cls.fused_function, (x, B, C), out_ty=R.Tensor((512,), dtype="float32") ) R.output(gv) return gv @@ -981,7 +981,7 @@ def fused( gv = R.call_tir( cls.foo, [lv1, y], - out_sinfo=R.Tensor((1, 1, 32, 128), dtype="float32"), + out_ty=R.Tensor((1, 1, 32, 128), dtype="float32"), tir_vars=R.shape([m]), ) R.output(gv) @@ -1033,7 +1033,7 @@ def main( gv = R.call_tir( cls.fused, (x, y), - out_sinfo=R.Tensor([1, 1, 32, 128], "float32"), + out_ty=R.Tensor([1, 1, 32, 128], "float32"), tir_vars=R.shape([m]), ) R.output(gv) @@ -1093,10 +1093,10 @@ def fused_concatenate_transpose2( lv = R.call_tir( cls.concatenate, (inp_0, inp_0), - out_sinfo=R.Tensor((2, 4, 64, 64), dtype="float32"), + out_ty=R.Tensor((2, 4, 64, 64), dtype="float32"), ) gv = R.call_tir( - cls.transpose2, (lv,), out_sinfo=R.Tensor((2, 64, 64, 4), dtype="float32") + cls.transpose2, (lv,), out_ty=R.Tensor((2, 64, 64, 4), dtype="float32") ) R.output(gv) return gv @@ -1154,7 +1154,7 @@ def main(inp_0: R.Tensor((1, 4, 64, 64), dtype="float32")) -> R.Tensor( lv = R.call_tir( cls.fused_concatenate_transpose2, (inp_0,), - out_sinfo=R.Tensor((2, 64, 64, 4), dtype="float32"), + out_ty=R.Tensor((2, 64, 64, 4), dtype="float32"), ) R.output(lv) return lv @@ -1229,7 +1229,7 @@ def main( lv = R.call_tir( cls.fused_transpose_matmul, (x, y), - out_sinfo=R.Tensor((n - 1, 3), dtype="float32"), + out_ty=R.Tensor((n - 1, 3), dtype="float32"), tir_vars=R.shape([n]), ) R.output(lv) @@ -1285,7 +1285,7 @@ def fused_reshape( with R.dataflow(): lv1: R.Tensor((4, 8, 2048), dtype="float32") = lv[0] gv = R.call_tir( - cls.reshape, (lv1,), out_sinfo=R.Tensor((4, 8, 32, 64), dtype="float32") + cls.reshape, (lv1,), out_ty=R.Tensor((4, 8, 32, 64), dtype="float32") ) R.output(gv) return gv @@ -1349,7 +1349,7 @@ def main( with R.dataflow(): lv: R.Tensor((4, 8, 2048), dtype="float32") = tup[0] lv_1 = R.call_tir( - cls.fused_reshape, (lv,), out_sinfo=R.Tensor((4, 8, 32, 64), dtype="float32") + cls.fused_reshape, (lv,), out_ty=R.Tensor((4, 8, 32, 64), dtype="float32") ) R.output(lv_1) return lv_1 @@ -1398,9 +1398,9 @@ def fused_func( cls = Module with R.dataflow(): lv = R.call_tir( - cls.add, (input_embeds,), out_sinfo=R.Tensor((4096, 4096), dtype="float16") + cls.add, (input_embeds,), out_ty=R.Tensor((4096, 4096), dtype="float16") ) - gv = R.call_tir(cls.add1, (lv,), out_sinfo=R.Tensor((4096, 4096), dtype="float16")) + gv = R.call_tir(cls.add1, (lv,), out_ty=R.Tensor((4096, 4096), dtype="float16")) R.output(gv) return gv @@ -1435,7 +1435,7 @@ def main(input_embeds: R.Tensor((4096, 4096), dtype="float16")) -> R.Tensor( gv = R.call_tir( cls.fused_func, (input_embeds,), - out_sinfo=R.Tensor((4096, 4096), dtype="float16"), + out_ty=R.Tensor((4096, 4096), dtype="float16"), ) R.output(gv) return gv @@ -1498,7 +1498,7 @@ def fused( gv = R.call_tir( cls.foo, [lv1, y], - out_sinfo=R.Tensor((1, sequence_length, 32, 128), dtype="float32"), + out_ty=R.Tensor((1, sequence_length, 32, 128), dtype="float32"), tir_vars=R.shape([m]), ) R.output(gv) @@ -1563,7 +1563,7 @@ def main( gv = R.call_tir( cls.fused, (x, y), - out_sinfo=R.Tensor([1, sequence_length, 32, 128], "float32"), + out_ty=R.Tensor([1, sequence_length, 32, 128], "float32"), tir_vars=R.shape([m]), ) R.output(gv) @@ -1603,7 +1603,7 @@ def fused( gv = R.call_tir( cls.sum_1d, [x], - out_sinfo=R.Tensor([1], dtype="float32"), + out_ty=R.Tensor([1], dtype="float32"), ) R.output(gv) return gv @@ -1640,7 +1640,7 @@ def main( ) -> R.Tensor([1], dtype="float32"): cls = Expected with R.dataflow(): - gv = R.call_tir(cls.fused, (x,), out_sinfo=R.Tensor((1,), dtype="float32")) + gv = R.call_tir(cls.fused, (x,), out_ty=R.Tensor((1,), dtype="float32")) R.output(gv) return gv @@ -1690,17 +1690,17 @@ def fused( x_sum = R.call_tir( cls.sum_1d, [x], - out_sinfo=R.Tensor([1], dtype="float32"), + out_ty=R.Tensor([1], dtype="float32"), ) y_sum = R.call_tir( cls.sum_1d, [y], - out_sinfo=R.Tensor([1], dtype="float32"), + out_ty=R.Tensor([1], dtype="float32"), ) gv = R.call_tir( cls.sum_scalar, [x_sum, y_sum], - out_sinfo=R.Tensor([1], dtype="float32"), + out_ty=R.Tensor([1], dtype="float32"), ) R.output(gv) return gv @@ -1755,7 +1755,7 @@ def main( ) -> R.Tensor([1], dtype="float32"): cls = Expected with R.dataflow(): - gv = R.call_tir(cls.fused, (x, y), out_sinfo=R.Tensor((1,), dtype="float32")) + gv = R.call_tir(cls.fused, (x, y), out_ty=R.Tensor((1,), dtype="float32")) R.output(gv) return gv @@ -1802,7 +1802,7 @@ def fused( gv = R.call_tir( cls.sum_1d, [x], - out_sinfo=R.Tensor([1], dtype="float32"), + out_ty=R.Tensor([1], dtype="float32"), tir_vars=R.shape([64]), ) R.output(gv) @@ -1840,7 +1840,7 @@ def main( ) -> R.Tensor([1], dtype="float32"): cls = Expected with R.dataflow(): - gv = R.call_tir(cls.fused, (x,), out_sinfo=R.Tensor((1,), dtype="float32")) + gv = R.call_tir(cls.fused, (x,), out_ty=R.Tensor((1,), dtype="float32")) R.output(gv) return gv @@ -1891,10 +1891,10 @@ def fused_func( cls = Before with R.dataflow(): lv = R.call_tir( - cls.add, (input_embeds,), out_sinfo=R.Tensor((4096, 4096), dtype="float16") + cls.add, (input_embeds,), out_ty=R.Tensor((4096, 4096), dtype="float16") ) gv = R.call_tir( - cls.take, (lv, input_ids), out_sinfo=R.Tensor((1, 4096), dtype="float16") + cls.take, (lv, input_ids), out_ty=R.Tensor((1, 4096), dtype="float16") ) R.output(gv) return gv @@ -1930,7 +1930,7 @@ def main( gv = R.call_tir( cls.fused_func, (input_ids, input_embeds), - out_sinfo=R.Tensor((1, 4096), dtype="float16"), + out_ty=R.Tensor((1, 4096), dtype="float16"), ) R.output(gv) return gv @@ -1991,19 +1991,19 @@ def fused_add_exp_squeeze( cls.add_inplace, (x, p0), inplace_indices=[0], - out_sinfo=R.Tensor((10, 20), dtype="float32"), + out_ty=R.Tensor((10, 20), dtype="float32"), ) lv1 = R.call_tir_inplace( cls.exp_inplace, (lv,), inplace_indices=[0], - out_sinfo=R.Tensor((10, 20), dtype="float32"), + out_ty=R.Tensor((10, 20), dtype="float32"), ) gv = R.call_tir_inplace( cls.squeeze_inplace, (lv1,), inplace_indices=[0], - out_sinfo=R.Tensor((10, 20), dtype="float32"), + out_ty=R.Tensor((10, 20), dtype="float32"), ) R.output(gv) return gv @@ -2050,7 +2050,7 @@ def main( gv1: R.Tensor((10, 20), dtype="float32") = R.call_tir_inplace( cls.fused_add_exp_squeeze, (x, p0), - out_sinfo=R.Tensor((10, 20), dtype="float32"), + out_ty=R.Tensor((10, 20), dtype="float32"), inplace_indices=[0], ) R.output(gv1) @@ -2102,19 +2102,19 @@ def fused_add_exp_squeeze( lv = R.call_tir( cls.add, (x, p0), - out_sinfo=R.Tensor((10, 20), dtype="float32"), + out_ty=R.Tensor((10, 20), dtype="float32"), ) lv1 = R.call_tir_inplace( cls.exp_inplace, (lv,), inplace_indices=[0], - out_sinfo=R.Tensor((10, 20), dtype="float32"), + out_ty=R.Tensor((10, 20), dtype="float32"), ) gv = R.call_tir_inplace( cls.squeeze_inplace, (lv1,), inplace_indices=[0], - out_sinfo=R.Tensor((10, 20), dtype="float32"), + out_ty=R.Tensor((10, 20), dtype="float32"), ) R.output(gv) return gv @@ -2162,7 +2162,7 @@ def main( gv1: R.Tensor((10, 20), dtype="float32") = R.call_tir( cls.fused_add_exp_squeeze, (x, p0), - out_sinfo=R.Tensor((10, 20), dtype="float32"), + out_ty=R.Tensor((10, 20), dtype="float32"), ) R.output(gv1) return gv1 @@ -2196,19 +2196,19 @@ def fused_sums( lv = R.call_tir( cls.add, (x, p0), - out_sinfo=R.Tensor((10, 20), dtype="float32"), + out_ty=R.Tensor((10, 20), dtype="float32"), ) lv1 = R.call_tir_inplace( cls.add, (x, p0, lv), inplace_indices=[2], - out_sinfo=R.Tensor((10, 20), dtype="float32"), + out_ty=R.Tensor((10, 20), dtype="float32"), ) lv2 = R.call_tir_inplace( cls.add, (x, p0, lv1), inplace_indices=[2], - out_sinfo=R.Tensor((10, 20), dtype="float32"), + out_ty=R.Tensor((10, 20), dtype="float32"), ) R.output(lv2) return lv2 @@ -2254,7 +2254,7 @@ def main( gv1: R.Tensor((10, 20), dtype="float32") = R.call_tir( cls.fused_sums, (x, p0), - out_sinfo=R.Tensor((10, 20), dtype="float32"), + out_ty=R.Tensor((10, 20), dtype="float32"), ) R.output(gv1) return gv1 @@ -2290,10 +2290,10 @@ def fused_func( cls = Before with R.dataflow(): lv = R.call_tir( - cls.add, (input_embeds,), out_sinfo=R.Tensor((4096, 4096), dtype="float16") + cls.add, (input_embeds,), out_ty=R.Tensor((4096, 4096), dtype="float16") ) gv = R.call_tir( - cls.take, (lv, input_ids), out_sinfo=R.Tensor((1, 4096), dtype="float16") + cls.take, (lv, input_ids), out_ty=R.Tensor((1, 4096), dtype="float16") ) R.output(gv) return gv @@ -2346,10 +2346,10 @@ def fused_function( cls = Before with R.dataflow(): w = R.call_tir( - cls.add, [x, y], out_sinfo=R.Tensor([T.int64(16), T.int64(32)], "float32") + cls.add, [x, y], out_ty=R.Tensor([T.int64(16), T.int64(32)], "float32") ) out = R.call_tir( - cls.add, [w, z], out_sinfo=R.Tensor([T.int64(16), T.int64(32)], "float32") + cls.add, [w, z], out_ty=R.Tensor([T.int64(16), T.int64(32)], "float32") ) R.output(out) return out @@ -2397,7 +2397,7 @@ def main( gv = R.call_tir( cls.fused_function, [x, y, z], - out_sinfo=R.Tensor([T.int64(16), T.int64(32)], "float32"), + out_ty=R.Tensor([T.int64(16), T.int64(32)], "float32"), ) R.output(gv) return gv @@ -2427,7 +2427,7 @@ def fused_function( cls = Before with R.dataflow(): out = R.call_tir( - cls.mul, [x, x], out_sinfo=R.Tensor([T.int64(16), T.int64(32)], "float32") + cls.mul, [x, x], out_ty=R.Tensor([T.int64(16), T.int64(32)], "float32") ) R.output(out) return out @@ -2472,8 +2472,8 @@ def fused_add_mul(x: R.Tensor((10,), "float32")) -> R.Tensor((10,), dtype="float R.func_attr({"Primitive": True}) cls = Before with R.dataflow(): - lv1 = R.call_tir(cls.add1, (x,), out_sinfo=R.Tensor((10,), dtype="float32")) - lv2 = R.call_tir(cls.mul1, (lv1,), out_sinfo=R.Tensor((10,), dtype="float32")) + lv1 = R.call_tir(cls.add1, (x,), out_ty=R.Tensor((10,), dtype="float32")) + lv2 = R.call_tir(cls.mul1, (lv1,), out_ty=R.Tensor((10,), dtype="float32")) R.output(lv2) return lv2 @@ -2513,7 +2513,7 @@ def fused_add_mul(p_x: T.handle, p_output0: T.handle): def main(x: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), dtype="float32"): cls = Expected with R.dataflow(): - gv = R.call_tir(cls.fused_add_mul, (x,), out_sinfo=R.Tensor((10,), dtype="float32")) + gv = R.call_tir(cls.fused_add_mul, (x,), out_ty=R.Tensor((10,), dtype="float32")) R.output(gv) return gv diff --git a/tests/python/relax/test_transform_fuse_transpose_matmul.py b/tests/python/relax/test_transform_fuse_transpose_matmul.py index 9382c4892496..43ff415a1946 100644 --- a/tests/python/relax/test_transform_fuse_transpose_matmul.py +++ b/tests/python/relax/test_transform_fuse_transpose_matmul.py @@ -65,9 +65,7 @@ def main( ) -> R.Tensor((128, 128), dtype="float32"): cls = Expected with R.dataflow(): - gv = R.call_tir( - cls.NT_matmul, (x, w), out_sinfo=R.Tensor((128, 128), dtype="float32") - ) + gv = R.call_tir(cls.NT_matmul, (x, w), out_ty=R.Tensor((128, 128), dtype="float32")) R.output(gv) return gv @@ -118,9 +116,7 @@ def NT_matmul( def main(x: R.Tensor((128, 256), dtype="float32")) -> R.Tensor((128, 128), dtype="float32"): cls = Expected with R.dataflow(): - gv = R.call_tir( - cls.NT_matmul, (x, w), out_sinfo=R.Tensor((128, 128), dtype="float32") - ) + gv = R.call_tir(cls.NT_matmul, (x, w), out_ty=R.Tensor((128, 128), dtype="float32")) R.output(gv) return gv diff --git a/tests/python/relax/test_transform_gradient_checkpoint.py b/tests/python/relax/test_transform_gradient_checkpoint.py index 7797f6d8f679..5ab76f8b410b 100644 --- a/tests/python/relax/test_transform_gradient_checkpoint.py +++ b/tests/python/relax/test_transform_gradient_checkpoint.py @@ -479,7 +479,7 @@ def func2(x): return relax.op.sum(y) bb = BlockBuilder() - x = relax.Var("x", relax.TensorStructInfo((3, 3), "float32")) + x = relax.Var("x", relax.TensorType((3, 3), "float32")) with bb.function("main", [x]): with bb.dataflow(): lv1 = bb.emit(nn.checkpoint(func1, x)) @@ -516,11 +516,11 @@ def func(x, y, z, w): return x * y, z * w bb = BlockBuilder() - x = relax.Var("x", relax.TensorStructInfo((3, 3), "float32")) - y = relax.Var("y", relax.TensorStructInfo((3, 3), "float32")) - z = relax.Var("z", relax.TensorStructInfo((3, 3), "float32")) - u = relax.Var("u", relax.TensorStructInfo((3, 3), "float32")) - v = relax.Var("v", relax.TensorStructInfo((3, 3), "float32")) + x = relax.Var("x", relax.TensorType((3, 3), "float32")) + y = relax.Var("y", relax.TensorType((3, 3), "float32")) + z = relax.Var("z", relax.TensorType((3, 3), "float32")) + u = relax.Var("u", relax.TensorType((3, 3), "float32")) + v = relax.Var("v", relax.TensorType((3, 3), "float32")) with bb.function("main", [x, y, z, u, v]): with bb.dataflow(): lv1 = bb.emit(x * y) @@ -564,7 +564,7 @@ def func(x): return x * relax.const(2, "float32") * relax.const(2, "float32") bb = BlockBuilder() - x = relax.Var("x", relax.TensorStructInfo((3, 3), "float32")) + x = relax.Var("x", relax.TensorType((3, 3), "float32")) with bb.function("main", [x]): with bb.dataflow(): lv1 = bb.emit(nn.checkpoint(func, x)) @@ -610,7 +610,7 @@ def func(x): return x + x bb = BlockBuilder() - x = relax.Var("x", relax.TensorStructInfo((3, 3), "float32")) + x = relax.Var("x", relax.TensorType((3, 3), "float32")) with bb.function("main", [x]): with bb.dataflow(): lv1 = nn.emit_checkpoint_sequential([func] * 5, 2, x) @@ -652,7 +652,7 @@ def func(x): return x + x bb = BlockBuilder() - x = relax.Var("x", relax.TensorStructInfo((3, 3), "float32")) + x = relax.Var("x", relax.TensorType((3, 3), "float32")) with bb.function("main", [x]): with bb.dataflow(): lv1 = nn.emit_checkpoint_sequential([func] * 5, 2, x, checkpoint_last=True) @@ -702,7 +702,7 @@ def func(x): return x * relax.const(2, "float32") * relax.const(2, "float32") bb = BlockBuilder() - x = relax.Var("x", relax.TensorStructInfo((3, 3), "float32")) + x = relax.Var("x", relax.TensorType((3, 3), "float32")) with bb.function("main", [x]): with bb.dataflow(): lv1 = bb.emit(nn.checkpoint(func, x)) @@ -720,7 +720,7 @@ def func(x): return x * x * x bb = BlockBuilder() - x = relax.Var("x", relax.TensorStructInfo((3, 3), "float32")) + x = relax.Var("x", relax.TensorType((3, 3), "float32")) with bb.function("main", [x]): with bb.dataflow(): lv1 = nn.emit_checkpoint(func, x) diff --git a/tests/python/relax/test_transform_gradient_numeric.py b/tests/python/relax/test_transform_gradient_numeric.py index 1865528e3b83..f3d9b357b24b 100644 --- a/tests/python/relax/test_transform_gradient_numeric.py +++ b/tests/python/relax/test_transform_gradient_numeric.py @@ -125,8 +125,8 @@ def test_mlp_blockbuilder(): # Check numerical gradients equal args = [] for arg in After["MLP_adjoint"].params: - shape = [int(l) for l in arg.struct_info.shape] - if arg.struct_info.dtype == "int64": + shape = [int(l) for l in arg.ty.shape] + if arg.ty.dtype == "int64": args.append( tvm.runtime.tensor(np.random.randint(0, out_size, size=shape).astype(np.int64)) ) @@ -188,7 +188,7 @@ def main(x: R.Tensor((6,), "float32"), y: R.Tensor((6, 3, 4), "float32")): After = relax.transform.Gradient("main")(Before) args = [] for arg in After["main_adjoint"].params: - shape = [int(l) for l in arg.struct_info.shape] + shape = [int(l) for l in arg.ty.shape] args.append(rand("float32", *shape)) vm_before = _legalize_and_build(Before, target, dev) @@ -228,7 +228,7 @@ def main(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), "float32")): After = relax.transform.Gradient("main")(Before) args = [] for arg in After["main_adjoint"].params: - shape = [int(l) for l in arg.struct_info.shape] + shape = [int(l) for l in arg.ty.shape] args.append(rand("float32", *shape)) vm_before = _legalize_and_build(Before, target, dev) diff --git a/tests/python/relax/test_transform_gradient_te_register.py b/tests/python/relax/test_transform_gradient_te_register.py index 4706d7e8cbbc..47fe713d29bf 100644 --- a/tests/python/relax/test_transform_gradient_te_register.py +++ b/tests/python/relax/test_transform_gradient_te_register.py @@ -94,11 +94,11 @@ def f_mul_grad(A: T.Buffer((T.int64(5), T.int64(5)), "float32"), B: T.Buffer((T. def main_adjoint(a: R.Tensor((5, 5), dtype="float32"), b: R.Tensor((5, 5), dtype="float32")) -> R.Tuple(R.Tensor((), dtype="float32"), R.Tuple(R.Tensor((5, 5), dtype="float32"), R.Tensor((5, 5), dtype="float32"))): cls = Expected with R.dataflow(): - lv = R.call_tir(cls.f_mul, (a, b), out_sinfo=R.Tensor((5, 5), dtype="float32")) + lv = R.call_tir(cls.f_mul, (a, b), out_ty=R.Tensor((5, 5), dtype="float32")) gv: R.Tensor((), dtype="float32") = R.sum(lv, axis=None, keepdims=False) gv_adjoint: R.Tensor((), dtype="float32") = R.ones(R.shape([]), dtype="float32") lv_adjoint: R.Tensor((5, 5), dtype="float32") = R.broadcast_to(gv_adjoint, R.shape([5, 5])) - lv_1 = R.call_tir(cls.f_mul_grad, (lv_adjoint, a, b), out_sinfo=[R.Tensor((5, 5), dtype="float32"), R.Tensor((5, 5), dtype="float32")]) + lv_1 = R.call_tir(cls.f_mul_grad, (lv_adjoint, a, b), out_ty=[R.Tensor((5, 5), dtype="float32"), R.Tensor((5, 5), dtype="float32")]) a_adjoint: R.Tensor((5, 5), dtype="float32") = lv_1[0] b_adjoint: R.Tensor((5, 5), dtype="float32") = lv_1[1] a_adjoint_out: R.Tensor((5, 5), dtype="float32") = a_adjoint @@ -110,7 +110,7 @@ def main_adjoint(a: R.Tensor((5, 5), dtype="float32"), b: R.Tensor((5, 5), dtype def main(a: R.Tensor((5, 5), dtype="float32"), b: R.Tensor((5, 5), dtype="float32")) -> R.Tensor((), dtype="float32"): cls = Expected with R.dataflow(): - lv = R.call_tir_with_grad(cls.f_mul, (a, b), out_sinfo=R.Tensor((5, 5), dtype="float32"), te_grad_name="f_mul_grad") + lv = R.call_tir_with_grad(cls.f_mul, (a, b), out_ty=R.Tensor((5, 5), dtype="float32"), te_grad_name="f_mul_grad") gv: R.Tensor((), dtype="float32") = R.sum(lv, axis=None, keepdims=False) R.output(gv) return gv @@ -126,8 +126,8 @@ def mul(*idx): return tvm.te.compute(src1.shape, mul, name="f_mul") - a = relax.Var("a", relax.TensorStructInfo([5, 5], "float32")) - b = relax.Var("b", relax.TensorStructInfo([5, 5], "float32")) + a = relax.Var("a", relax.TensorType([5, 5], "float32")) + b = relax.Var("b", relax.TensorType([5, 5], "float32")) bb = relax.BlockBuilder() with bb.function("main", [a, b]): @@ -164,7 +164,7 @@ def f_mul(A: T.Buffer((T.int64(5), T.int64(5)), "float32"), B: T.Buffer((T.int64 def main(a: R.Tensor((5, 5), dtype="float32"), b: R.Tensor((5, 5), dtype="float32")) -> R.Tensor((), dtype="float32"): cls = Before with R.dataflow(): - lv = R.call_tir_with_grad(cls.f_mul, (a, b), out_sinfo=R.Tensor((5, 5), dtype="float32"), te_grad_name="f_mul_grad") + lv = R.call_tir_with_grad(cls.f_mul, (a, b), out_ty=R.Tensor((5, 5), dtype="float32"), te_grad_name="f_mul_grad") gv: R.Tensor((), dtype="float32") = R.sum(lv, axis=None, keepdims=False) R.output(gv) return gv @@ -204,11 +204,11 @@ def f_mulk_grad(A: T.Buffer((T.int64(5), T.int64(5)), "float32"), B: T.Buffer((T def main_adjoint(a: R.Tensor((5, 5), dtype="float32")) -> R.Tuple(R.Tensor((), dtype="float32"), R.Tuple(R.Tensor((5, 5), dtype="float32"))): cls = Expected with R.dataflow(): - lv = R.call_tir(cls.f_mul, (a,), out_sinfo=R.Tensor((5, 5), dtype="float32")) + lv = R.call_tir(cls.f_mul, (a,), out_ty=R.Tensor((5, 5), dtype="float32")) gv: R.Tensor((), dtype="float32") = R.sum(lv, axis=None, keepdims=False) gv_adjoint: R.Tensor((), dtype="float32") = R.ones(R.shape([]), dtype="float32") lv_adjoint: R.Tensor((5, 5), dtype="float32") = R.broadcast_to(gv_adjoint, R.shape([5, 5])) - lv_1 = R.call_tir(cls.f_mulk_grad, (lv_adjoint, a), out_sinfo=R.Tensor((5, 5), dtype="float32")) + lv_1 = R.call_tir(cls.f_mulk_grad, (lv_adjoint, a), out_ty=R.Tensor((5, 5), dtype="float32")) a_adjoint: R.Tensor((5, 5), dtype="float32") = lv_1 a_adjoint_out: R.Tensor((5, 5), dtype="float32") = a_adjoint R.output(gv, a_adjoint_out) @@ -218,7 +218,7 @@ def main_adjoint(a: R.Tensor((5, 5), dtype="float32")) -> R.Tuple(R.Tensor((), d def main(a: R.Tensor((5, 5), dtype="float32")) -> R.Tensor((), dtype="float32"): cls = Expected with R.dataflow(): - lv = R.call_tir_with_grad(cls.f_mul, (a,), out_sinfo=R.Tensor((5, 5), dtype="float32"), te_grad_name="f_mulk_grad", te_grad_kwargs={"k": T.float32(2)}) + lv = R.call_tir_with_grad(cls.f_mul, (a,), out_ty=R.Tensor((5, 5), dtype="float32"), te_grad_name="f_mulk_grad", te_grad_kwargs={"k": T.float32(2)}) gv: R.Tensor((), dtype="float32") = R.sum(lv, axis=None, keepdims=False) R.output(gv) return gv @@ -231,7 +231,7 @@ def test_emit_te_kwargs(register_te_grads): def f_mul2(src): return tvm.te.compute(src.shape, lambda *idx: src[idx] * T.float32(2), name="f_mul2") - a = relax.Var("a", relax.TensorStructInfo([5, 5], "float32")) + a = relax.Var("a", relax.TensorType([5, 5], "float32")) bb = relax.BlockBuilder() with bb.function("main", [a]): @@ -273,7 +273,7 @@ def f_mul(A: T.Buffer((T.int64(5), T.int64(5)), "float32"), f_mul2: T.Buffer((T. def main(a: R.Tensor((5, 5), dtype="float32")) -> R.Tensor((), dtype="float32"): cls = Before with R.dataflow(): - lv = R.call_tir_with_grad(cls.f_mul, (a,), out_sinfo=R.Tensor((5, 5), dtype="float32"), te_grad_name="f_mulk_grad", te_grad_kwargs={"k": T.float32(2)}) + lv = R.call_tir_with_grad(cls.f_mul, (a,), out_ty=R.Tensor((5, 5), dtype="float32"), te_grad_name="f_mulk_grad", te_grad_kwargs={"k": T.float32(2)}) gv: R.Tensor((), dtype="float32") = R.sum(lv, axis=None, keepdims=False) R.output(gv) return gv @@ -330,11 +330,11 @@ def main_adjoint(a: R.Tensor(("n", "n"), dtype="float32"), b: R.Tensor(("n", "n" n = T.int64() cls = Expected with R.dataflow(): - lv = R.call_tir(cls.f_mul, (a, b), out_sinfo=R.Tensor((n, n), dtype="float32")) + lv = R.call_tir(cls.f_mul, (a, b), out_ty=R.Tensor((n, n), dtype="float32")) gv: R.Tensor((), dtype="float32") = R.sum(lv, axis=None, keepdims=False) gv_adjoint: R.Tensor((), dtype="float32") = R.ones(R.shape([]), dtype="float32") lv_adjoint: R.Tensor((n, n), dtype="float32") = R.broadcast_to(gv_adjoint, R.shape([n, n])) - lv_1 = R.call_tir(cls.f_mul_grad, (lv_adjoint, a, b), out_sinfo=[R.Tensor((n, n), dtype="float32"), R.Tensor((n, n), dtype="float32")]) + lv_1 = R.call_tir(cls.f_mul_grad, (lv_adjoint, a, b), out_ty=[R.Tensor((n, n), dtype="float32"), R.Tensor((n, n), dtype="float32")]) a_adjoint: R.Tensor((n, n), dtype="float32") = lv_1[0] b_adjoint: R.Tensor((n, n), dtype="float32") = lv_1[1] a_adjoint_out: R.Tensor((n, n), dtype="float32") = a_adjoint @@ -347,7 +347,7 @@ def main(a: R.Tensor(("n", "n"), dtype="float32"), b: R.Tensor(("n", "n"), dtype n = T.int64() cls = Expected with R.dataflow(): - lv = R.call_tir_with_grad(cls.f_mul, (a, b), out_sinfo=R.Tensor((n, n), dtype="float32"), te_grad_name="f_mul_grad") + lv = R.call_tir_with_grad(cls.f_mul, (a, b), out_ty=R.Tensor((n, n), dtype="float32"), te_grad_name="f_mul_grad") gv: R.Tensor((), dtype="float32") = R.sum(lv, axis=None, keepdims=False) R.output(gv) return gv @@ -363,8 +363,8 @@ def mul(*idx): return tvm.te.compute(src1.shape, mul, name="f_mul") n = tirx.Var("n", "int64") - a = relax.Var("a", relax.TensorStructInfo([n, n], "float32")) - b = relax.Var("b", relax.TensorStructInfo([n, n], "float32")) + a = relax.Var("a", relax.TensorType([n, n], "float32")) + b = relax.Var("b", relax.TensorType([n, n], "float32")) bb = relax.BlockBuilder() with bb.function("main", [a, b]): diff --git a/tests/python/relax/test_transform_inline_private_functions.py b/tests/python/relax/test_transform_inline_private_functions.py index 61d09cb3af3e..349c776b4022 100644 --- a/tests/python/relax/test_transform_inline_private_functions.py +++ b/tests/python/relax/test_transform_inline_private_functions.py @@ -88,7 +88,7 @@ def main(): @R.function(private=True) def subroutine() -> R.Tensor([], "int64"): R.func_attr({"relax.force_pure": True}) - cond = R.call_packed("dummy_function", sinfo_args=R.Tensor([], "bool")) + cond = R.call_packed("dummy_function", ty_args=R.Tensor([], "bool")) if cond: Out = Before.subroutine() else: diff --git a/tests/python/relax/test_transform_lambda_lift.py b/tests/python/relax/test_transform_lambda_lift.py index 70a7b4a79143..44d333783de5 100644 --- a/tests/python/relax/test_transform_lambda_lift.py +++ b/tests/python/relax/test_transform_lambda_lift.py @@ -91,8 +91,8 @@ def inner( def test_input_module_is_unmodified(): """The input module may not be modified - If the output requires new StructInfo, it must create a new relax - variable. It must not update the struct info of an existing relax + If the output requires new Type, it must create a new relax + variable. It must not update the type of an existing relax variable, as that variable may be used by another IRModule. """ @@ -136,9 +136,7 @@ def main(x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32")) -> R.Te (2, 3), "float32" ): in_call = Expected.main_outer_func(x) - res = R.invoke_pure_closure( - in_call, (y,), sinfo_args=(R.Tensor((2, 3), dtype="float32")) - ) + res = R.invoke_pure_closure(in_call, (y,), ty_args=(R.Tensor((2, 3), dtype="float32"))) return res @R.function(private=True) @@ -191,7 +189,7 @@ def main_while_loop( i: R.Tensor((), "int32"), s: R.Tensor((2, 3), "float32"), x: R.Tensor((2, 3), "float32") ) -> R.Tensor((2, 3), "float32"): cond: R.Tensor((), "bool") = R.call_pure_packed( - "test.vm.less", i, R.const(10), sinfo_args=(R.Tensor((), dtype="bool")) + "test.vm.less", i, R.const(10), ty_args=(R.Tensor((), dtype="bool")) ) c: R.Tensor((), "int32") = R.const(1, dtype="int32") if cond: @@ -209,7 +207,7 @@ def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), dtype="float32"): gv: R.Tensor((2, 3), dtype="float32") = R.invoke_pure_closure( while_loop, (R.const(0), x), - sinfo_args=(R.Tensor((2, 3), dtype="float32")), + ty_args=(R.Tensor((2, 3), dtype="float32")), ) return gv @@ -223,7 +221,7 @@ def while_loop(i: R.Tensor((), "int32"), s: R.Tensor((2, 3), "float32")) -> R.Te (2, 3), "float32" ): cond: R.Tensor((), "bool") = R.call_pure_packed( - "test.vm.less", i, R.const(10), sinfo_args=(R.Tensor((), dtype="bool")) + "test.vm.less", i, R.const(10), ty_args=(R.Tensor((), dtype="bool")) ) c: R.Tensor((), "int32") = R.const(1, dtype="int32") if cond: diff --git a/tests/python/relax/test_transform_lazy_transform_params.py b/tests/python/relax/test_transform_lazy_transform_params.py index 4a8d91df0990..5642f72a09e2 100644 --- a/tests/python/relax/test_transform_lazy_transform_params.py +++ b/tests/python/relax/test_transform_lazy_transform_params.py @@ -56,7 +56,7 @@ def main_transform_params( lv2 = R.call_tir( cls.transform_layout_IOHW_to_OIHW, (lv1,), - out_sinfo=R.Tensor((16, 3, 3, 3), dtype="float32"), + out_ty=R.Tensor((16, 3, 3, 3), dtype="float32"), ) gv: R.Tuple( R.Tensor((16, 16, 3, 3), dtype="float32"), @@ -81,14 +81,14 @@ def transform_layout_IOHW_to_OIHW( @R.function(pure=False) def main_transform_params() -> R.Tuple: cls = Expected - lv: R.Object = R.call_packed("get_item", R.prim_value(1), sinfo_args=(R.Object,)) + lv: R.Object = R.call_packed("get_item", R.prim_value(1), ty_args=(R.Object,)) gv1: R.Tensor((16, 16, 3, 3), dtype="float32") = R.match_cast( lv, R.Tensor((16, 16, 3, 3), dtype="float32") ) lv_m: R.Tensor((16, 16, 3, 3), dtype="float32") = gv1 - _: R.Object = R.call_packed("set_item", R.prim_value(0), lv_m, sinfo_args=(R.Object,)) + _: R.Object = R.call_packed("set_item", R.prim_value(0), lv_m, ty_args=(R.Object,)) _1: R.Tuple = R.vm.kill_object(lv_m) - lv1: R.Object = R.call_packed("get_item", R.prim_value(0), sinfo_args=(R.Object,)) + lv1: R.Object = R.call_packed("get_item", R.prim_value(0), ty_args=(R.Object,)) gv3: R.Tensor((3, 16, 3, 3), dtype="float32") = R.match_cast( lv1, R.Tensor((3, 16, 3, 3), dtype="float32") ) @@ -96,10 +96,10 @@ def main_transform_params() -> R.Tuple: lv2 = R.call_tir( cls.transform_layout_IOHW_to_OIHW, (lv1_m,), - out_sinfo=R.Tensor((16, 3, 3, 3), dtype="float32"), + out_ty=R.Tensor((16, 3, 3, 3), dtype="float32"), ) _2: R.Tuple = R.vm.kill_object(lv1_m) - _3: R.Object = R.call_packed("set_item", R.prim_value(1), lv2, sinfo_args=(R.Object,)) + _3: R.Object = R.call_packed("set_item", R.prim_value(1), lv2, ty_args=(R.Object,)) gv: R.Tuple = R.tuple() return gv @@ -137,7 +137,7 @@ def main_transform_params( lv2 = R.call_tir( cls.transform_layout_IOHW_to_OIHW, (lv1,), - out_sinfo=R.Tensor((16, 3, 3, 3), dtype="float32"), + out_ty=R.Tensor((16, 3, 3, 3), dtype="float32"), ) lv3 = R.add(lv2, R.const(1, "float32")) gv: R.Tuple( @@ -165,12 +165,12 @@ def main_transform_params() -> R.Tuple( R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 3, 3, 3), dtype="float32") ): cls = Expected - gv: R.Object = R.call_packed("get_item_0", R.prim_value(1), sinfo_args=(R.Object,)) + gv: R.Object = R.call_packed("get_item_0", R.prim_value(1), ty_args=(R.Object,)) gv1: R.Tensor((16, 16, 3, 3), dtype="float32") = R.match_cast( gv, R.Tensor((16, 16, 3, 3), dtype="float32") ) lv: R.Tensor((16, 16, 3, 3), dtype="float32") = gv1 - gv2: R.Object = R.call_packed("get_item_0", R.prim_value(0), sinfo_args=(R.Object,)) + gv2: R.Object = R.call_packed("get_item_0", R.prim_value(0), ty_args=(R.Object,)) gv3: R.Tensor((3, 16, 3, 3), dtype="float32") = R.match_cast( gv2, R.Tensor((3, 16, 3, 3), dtype="float32") ) @@ -178,7 +178,7 @@ def main_transform_params() -> R.Tuple( lv2 = R.call_tir( cls.transform_layout_IOHW_to_OIHW, (lv1,), - out_sinfo=R.Tensor((16, 3, 3, 3), dtype="float32"), + out_ty=R.Tensor((16, 3, 3, 3), dtype="float32"), ) lv3: R.Tensor((16, 3, 3, 3), dtype="float32") = R.add(lv2, R.const(1, "float32")) gv_1: R.Tuple( @@ -220,7 +220,7 @@ def main_transform_params( lv2 = R.call_tir( cls.transform_layout_IOHW_to_OIHW, (lv1,), - out_sinfo=R.Tensor((16, 3, 3, 3), dtype="float32"), + out_ty=R.Tensor((16, 3, 3, 3), dtype="float32"), ) lv3 = R.add(lv2, R.const(1, "float32")) gv: R.Tuple( @@ -246,18 +246,14 @@ def transform_layout_IOHW_to_OIHW( @R.function(pure=False) def main_transform_params(loader: R.Object) -> R.Tuple: cls = Expected - gv: R.Object = R.call_packed( - "get_item", loader, R.prim_value(1), sinfo_args=(R.Object,) - ) + gv: R.Object = R.call_packed("get_item", loader, R.prim_value(1), ty_args=(R.Object,)) gv1: R.Tensor((16, 16, 3, 3), dtype="float32") = R.match_cast( gv, R.Tensor((16, 16, 3, 3), dtype="float32") ) lv: R.Tensor((16, 16, 3, 3), dtype="float32") = gv1 - _: R.Object = R.call_packed("set_item", R.prim_value(0), lv, sinfo_args=(R.Object,)) + _: R.Object = R.call_packed("set_item", R.prim_value(0), lv, ty_args=(R.Object,)) _1: R.Tuple = R.vm.kill_object(lv) - gv2: R.Object = R.call_packed( - "get_item", loader, R.prim_value(0), sinfo_args=(R.Object,) - ) + gv2: R.Object = R.call_packed("get_item", loader, R.prim_value(0), ty_args=(R.Object,)) gv3: R.Tensor((3, 16, 3, 3), dtype="float32") = R.match_cast( gv2, R.Tensor((3, 16, 3, 3), dtype="float32") ) @@ -265,17 +261,17 @@ def main_transform_params(loader: R.Object) -> R.Tuple: lv2 = R.call_tir( cls.transform_layout_IOHW_to_OIHW, (lv1,), - out_sinfo=R.Tensor((16, 3, 3, 3), dtype="float32"), + out_ty=R.Tensor((16, 3, 3, 3), dtype="float32"), ) _2: R.Tuple = R.vm.kill_object(lv1) lv3: R.Tensor((16, 3, 3, 3), dtype="float32") = R.add(lv2, R.const(1, "float32")) - _3: R.Object = R.call_packed("set_item", R.prim_value(1), lv3, sinfo_args=(R.Object,)) + _3: R.Object = R.call_packed("set_item", R.prim_value(1), lv3, ty_args=(R.Object,)) gv_1: R.Tuple = R.tuple() return gv_1 - after = LazyTransformParams( - extra_get_item_params=[relax.Var("loader", relax.ObjectStructInfo())] - )(Before) + after = LazyTransformParams(extra_get_item_params=[relax.Var("loader", relax.ObjectType())])( + Before + ) tvm.ir.assert_structural_equal(after, Expected, map_free_vars=True) @@ -309,7 +305,7 @@ def main_transform_params( lv2 = R.call_tir( cls.transform_layout_IOHW_to_OIHW, (lv1,), - out_sinfo=R.Tensor((16, 3, 3, 3), dtype="float32"), + out_ty=R.Tensor((16, 3, 3, 3), dtype="float32"), ) lv3 = R.add(lv2, R.const(1, "float32")) gv: R.Tuple( @@ -335,16 +331,16 @@ def transform_layout_IOHW_to_OIHW( @R.function(pure=False) def main_transform_params(setter: R.Object) -> R.Tuple: cls = Expected - gv: R.Object = R.call_packed("get_item", R.prim_value(1), sinfo_args=(R.Object,)) + gv: R.Object = R.call_packed("get_item", R.prim_value(1), ty_args=(R.Object,)) gv1: R.Tensor((16, 16, 3, 3), dtype="float32") = R.match_cast( gv, R.Tensor((16, 16, 3, 3), dtype="float32") ) lv: R.Tensor((16, 16, 3, 3), dtype="float32") = gv1 _: R.Object = R.call_packed( - "set_item", setter, R.prim_value(0), lv, sinfo_args=(R.Object,) + "set_item", setter, R.prim_value(0), lv, ty_args=(R.Object,) ) _1: R.Tuple = R.vm.kill_object(lv) - gv2: R.Object = R.call_packed("get_item", R.prim_value(0), sinfo_args=(R.Object,)) + gv2: R.Object = R.call_packed("get_item", R.prim_value(0), ty_args=(R.Object,)) gv3: R.Tensor((3, 16, 3, 3), dtype="float32") = R.match_cast( gv2, R.Tensor((3, 16, 3, 3), dtype="float32") ) @@ -352,19 +348,19 @@ def main_transform_params(setter: R.Object) -> R.Tuple: lv2 = R.call_tir( cls.transform_layout_IOHW_to_OIHW, (lv1,), - out_sinfo=R.Tensor((16, 3, 3, 3), dtype="float32"), + out_ty=R.Tensor((16, 3, 3, 3), dtype="float32"), ) _2: R.Tuple = R.vm.kill_object(lv1) lv3: R.Tensor((16, 3, 3, 3), dtype="float32") = R.add(lv2, R.const(1, "float32")) _3: R.Object = R.call_packed( - "set_item", setter, R.prim_value(1), lv3, sinfo_args=(R.Object,) + "set_item", setter, R.prim_value(1), lv3, ty_args=(R.Object,) ) gv_1: R.Tuple = R.tuple() return gv_1 - after = LazyTransformParams( - extra_set_item_params=[relax.Var("setter", relax.ObjectStructInfo())] - )(Before) + after = LazyTransformParams(extra_set_item_params=[relax.Var("setter", relax.ObjectType())])( + Before + ) tvm.ir.assert_structural_equal(after, Expected, map_free_vars=True) @@ -392,20 +388,20 @@ def main_transform_params(setter: R.Object) -> R.Tuple: setter, R.prim_value(0), R.const(np.array([1, 2]).astype("float32")), - sinfo_args=(R.Object,), + ty_args=(R.Object,), ) _ = R.call_packed( "set_item", setter, R.prim_value(1), R.const(np.array([3, 4]).astype("float32")), - sinfo_args=(R.Object,), + ty_args=(R.Object,), ) return output - after = LazyTransformParams( - extra_set_item_params=[relax.Var("setter", relax.ObjectStructInfo())] - )(Before) + after = LazyTransformParams(extra_set_item_params=[relax.Var("setter", relax.ObjectType())])( + Before + ) tvm.ir.assert_structural_equal(after, Expected) @@ -432,7 +428,7 @@ def main_transform_params( cls.slice_buffer, (param,), tir_vars=[slice_index], - out_sinfo=R.Tensor((16,), dtype="float32"), + out_ty=R.Tensor((16,), dtype="float32"), ) output = (transformed,) return output @@ -456,7 +452,7 @@ def main_transform_params(slice_shape_expr: R.Shape(["slice_index"])): slice_index = T.int64() - param = R.call_packed("get_item", R.prim_value(0), sinfo_args=(R.Object,)) + param = R.call_packed("get_item", R.prim_value(0), ty_args=(R.Object,)) gv: R.Tensor((16, 16), dtype="float32") = R.match_cast( param, R.Tensor((16, 16), dtype="float32") ) @@ -465,12 +461,10 @@ def main_transform_params(slice_shape_expr: R.Shape(["slice_index"])): cls.slice_buffer, (param_m,), tir_vars=[slice_index], - out_sinfo=R.Tensor((16,), dtype="float32"), + out_ty=R.Tensor((16,), dtype="float32"), ) unused_1_ = R.vm.kill_object(param_m) - unused_2_ = R.call_packed( - "set_item", R.prim_value(0), transformed, sinfo_args=(R.Object,) - ) + unused_2_ = R.call_packed("set_item", R.prim_value(0), transformed, ty_args=(R.Object,)) output = R.tuple() return output @@ -523,7 +517,7 @@ def main_transform_params( lv2 = R.call_tir( cls.transform_layout_IOHW_to_OIHW, (lv1,), - out_sinfo=R.Tensor((ic, 3, 3, 3), dtype="float32"), + out_ty=R.Tensor((ic, 3, 3, 3), dtype="float32"), ) gv: R.Tuple( R.Tensor((16, 16, 3, 3), dtype="float32"), @@ -549,14 +543,14 @@ def transform_layout_IOHW_to_OIHW(var_w1: T.handle, var_out: T.handle): def main_transform_params() -> R.Tuple: ic = T.int64() cls = Expected - gv: R.Object = R.call_packed("get_item", R.prim_value(1), sinfo_args=(R.Object,)) + gv: R.Object = R.call_packed("get_item", R.prim_value(1), ty_args=(R.Object,)) gv1: R.Tensor((16, 16, 3, 3), dtype="float32") = R.match_cast( gv, R.Tensor((16, 16, 3, 3), dtype="float32") ) lv: R.Tensor((16, 16, 3, 3), dtype="float32") = gv1 - _: R.Object = R.call_packed("set_item", R.prim_value(0), lv, sinfo_args=(R.Object,)) + _: R.Object = R.call_packed("set_item", R.prim_value(0), lv, ty_args=(R.Object,)) _1: R.Tuple = R.vm.kill_object(lv) - gv2: R.Object = R.call_packed("get_item", R.prim_value(0), sinfo_args=(R.Object,)) + gv2: R.Object = R.call_packed("get_item", R.prim_value(0), ty_args=(R.Object,)) gv3: R.Tensor((3, ic, 3, 3), dtype="float32") = R.match_cast( gv2, R.Tensor((3, ic, 3, 3), dtype="float32") ) @@ -564,10 +558,10 @@ def main_transform_params() -> R.Tuple: lv2 = R.call_tir( cls.transform_layout_IOHW_to_OIHW, (lv1,), - out_sinfo=R.Tensor((ic, 3, 3, 3), dtype="float32"), + out_ty=R.Tensor((ic, 3, 3, 3), dtype="float32"), ) _2: R.Tuple = R.vm.kill_object(lv1) - _3: R.Object = R.call_packed("set_item", R.prim_value(1), lv2, sinfo_args=(R.Object,)) + _3: R.Object = R.call_packed("set_item", R.prim_value(1), lv2, ty_args=(R.Object,)) gv4: R.Tuple = R.tuple() return gv4 @@ -593,8 +587,8 @@ def main_transform_params(params: R.Tuple(R.Tensor((), dtype="float32"))) -> R.T R.func_attr({"relax.force_pure": True}) cls = Module x: R.Tensor((), dtype="float32") = params[0] - y = R.call_tir(cls.copy, (x,), out_sinfo=R.Tensor((), dtype="float32")) - z = R.call_tir(cls.copy, (y,), out_sinfo=R.Tensor((), dtype="float32")) + y = R.call_tir(cls.copy, (x,), out_ty=R.Tensor((), dtype="float32")) + z = R.call_tir(cls.copy, (y,), out_ty=R.Tensor((), dtype="float32")) gv: R.Tuple(R.Tensor((), dtype="float32"), R.Tensor((), dtype="float32")) = (y, z) return gv @@ -610,14 +604,14 @@ def copy(x: T.Buffer((), "float32"), y: T.Buffer((), "float32")): @R.function(pure=False) def main_transform_params() -> R.Tuple: cls = Expected - x: R.Object = R.call_packed("get_item", R.prim_value(0), sinfo_args=(R.Object,)) + x: R.Object = R.call_packed("get_item", R.prim_value(0), ty_args=(R.Object,)) gv: R.Tensor((), dtype="float32") = R.match_cast(x, R.Tensor((), dtype="float32")) x_m: R.Tensor((), dtype="float32") = gv - y = R.call_tir(cls.copy, (x_m,), out_sinfo=R.Tensor((), dtype="float32")) + y = R.call_tir(cls.copy, (x_m,), out_ty=R.Tensor((), dtype="float32")) _: R.Tuple = R.vm.kill_object(x_m) - z = R.call_tir(cls.copy, (y,), out_sinfo=R.Tensor((), dtype="float32")) - _1: R.Object = R.call_packed("set_item", R.prim_value(0), y, sinfo_args=(R.Object,)) - _2: R.Object = R.call_packed("set_item", R.prim_value(1), z, sinfo_args=(R.Object,)) + z = R.call_tir(cls.copy, (y,), out_ty=R.Tensor((), dtype="float32")) + _1: R.Object = R.call_packed("set_item", R.prim_value(0), y, ty_args=(R.Object,)) + _2: R.Object = R.call_packed("set_item", R.prim_value(1), z, ty_args=(R.Object,)) gv: R.Tuple = R.tuple() return gv @@ -704,26 +698,26 @@ def main_transform_params( class Expected: @R.function(pure=False) def main_transform_params() -> R.Tuple: - gv: R.Object = R.call_packed("get_item", R.prim_value(0), sinfo_args=(R.Object,)) + gv: R.Object = R.call_packed("get_item", R.prim_value(0), ty_args=(R.Object,)) gv1: R.Tensor((16,), dtype="int32") = R.match_cast(gv, R.Tensor((16,), dtype="int32")) param0: R.Tensor((16,), dtype="int32") = gv1 - gv2: R.Object = R.call_packed("get_item", R.prim_value(1), sinfo_args=(R.Object,)) + gv2: R.Object = R.call_packed("get_item", R.prim_value(1), ty_args=(R.Object,)) gv3: R.Tensor((16,), dtype="int32") = R.match_cast(gv2, R.Tensor((16,), dtype="int32")) param1: R.Tensor((16,), dtype="int32") = gv3 transformed0: R.Tensor((16,), dtype="int32") = R.add(param0, R.const(1, "int32")) _: R.Tuple = R.vm.kill_object(param0) _: R.Object = R.call_packed( - "set_item", R.prim_value(0), transformed0, sinfo_args=(R.Object,) + "set_item", R.prim_value(0), transformed0, ty_args=(R.Object,) ) _: R.Object = R.call_packed( - "set_item", R.prim_value(2), transformed0, sinfo_args=(R.Object,) + "set_item", R.prim_value(2), transformed0, ty_args=(R.Object,) ) transformed1: R.Tensor((16,), dtype="int32") = R.add(param1, R.const(2, "int32")) _ = R.vm.kill_object(param1) - _ = R.call_packed("set_item", R.prim_value(1), transformed1, sinfo_args=(R.Object,)) + _ = R.call_packed("set_item", R.prim_value(1), transformed1, ty_args=(R.Object,)) output = R.tuple() return output @@ -744,11 +738,11 @@ def transform_params(A: R.Tensor([16, 16], "float32"), B: R.Tensor([16, 16], "fl class Expected: @R.function(pure=False) def transform_params(): - A = R.call_packed("get_item", R.prim_value(0), sinfo_args=[R.Object]) + A = R.call_packed("get_item", R.prim_value(0), ty_args=[R.Object]) A = R.match_cast(A, R.Tensor([16, 16], "float32")) C = R.multiply(A, R.const(2, "float32")) - B = R.call_packed("get_item", R.prim_value(1), sinfo_args=[R.Object]) + B = R.call_packed("get_item", R.prim_value(1), ty_args=[R.Object]) B = R.match_cast(B, R.Tensor([16, 16], "float32")) D = R.add(C, B) return (D, B) @@ -785,13 +779,13 @@ def transform_params(relax_rank: R.Prim(value="rank")): R.func_attr({"num_input": 1}) rank = T.int64() - A = R.call_packed("get_item", R.prim_value(0), sinfo_args=[R.Object]) + A = R.call_packed("get_item", R.prim_value(0), ty_args=[R.Object]) A = R.match_cast(A, R.Tensor([16, 16], "float32")) A_sharded = R.strided_slice( A, axes=[0], begin=[rank * 8], end=[(rank + 1) * 8], assume_inbound=True ) - B = R.call_packed("get_item", R.prim_value(1), sinfo_args=[R.Object]) + B = R.call_packed("get_item", R.prim_value(1), ty_args=[R.Object]) B = R.match_cast(B, R.Tensor([16, 16], "float32")) B_sharded = R.strided_slice( B, axes=[1], begin=[rank * 8], end=[(rank + 1) * 8], assume_inbound=True @@ -814,7 +808,7 @@ def transform_params(A: R.Object): class Expected: @R.function(pure=False) def transform_params(): - A = R.call_packed("get_item", R.prim_value(0), sinfo_args=[R.Object]) + A = R.call_packed("get_item", R.prim_value(0), ty_args=[R.Object]) A = R.match_cast(A, R.Object) return (A,) diff --git a/tests/python/relax/test_transform_legalize_ops.py b/tests/python/relax/test_transform_legalize_ops.py index 63a9b4cbac79..17e03f76706b 100644 --- a/tests/python/relax/test_transform_legalize_ops.py +++ b/tests/python/relax/test_transform_legalize_ops.py @@ -152,9 +152,9 @@ def main(x: R.Tensor((3, 3), "float32")): register_legalize("relax.add", add_legalize) # case 2: don't know all shape - s = relax.Var("s", relax.ShapeStructInfo((3, 3))) - x = relax.Var("x", relax.TensorStructInfo((3, 3), "float32")) - y = relax.Var("y", relax.TensorStructInfo(s, "float32")) + s = relax.Var("s", relax.ShapeType((3, 3))) + x = relax.Var("x", relax.TensorType((3, 3), "float32")) + y = relax.Var("y", relax.TensorType(s, "float32")) bb = relax.BlockBuilder() with bb.function("main", [x, y]): with bb.dataflow(): @@ -209,7 +209,7 @@ def multiply( @R.function def main(x: R.Tensor((3, 3), dtype="float16")) -> R.Tensor((3, 3), dtype="float16"): cls = Expected0 - gv = R.call_tir(cls.multiply, (x,), out_sinfo=R.Tensor((3, 3), dtype="float16")) + gv = R.call_tir(cls.multiply, (x,), out_ty=R.Tensor((3, 3), dtype="float16")) return gv @tvm.script.ir_module @@ -231,7 +231,7 @@ def multiply( @R.function def main(x: R.Tensor((3, 3), dtype="uint8")) -> R.Tensor((3, 3), dtype="uint8"): cls = Expected1 - gv = R.call_tir(cls.multiply, (x,), out_sinfo=R.Tensor((3, 3), dtype="uint8")) + gv = R.call_tir(cls.multiply, (x,), out_ty=R.Tensor((3, 3), dtype="uint8")) return gv @tvm.script.ir_module @@ -253,7 +253,7 @@ def equal( @R.function def main(x: R.Tensor((3, 3), dtype="bool")) -> R.Tensor((3, 3), dtype="bool"): cls = Expected2 - gv = R.call_tir(cls.equal, (x,), out_sinfo=R.Tensor((3, 3), dtype="bool")) + gv = R.call_tir(cls.equal, (x,), out_ty=R.Tensor((3, 3), dtype="bool")) return gv # fmt: on @@ -297,19 +297,17 @@ def main(A: R.Tensor([16, 32]), B: R.Tensor([32, 8])) -> R.Tensor([16, 8]): def custom_op(emit_legalization_through_builder): op_name = "custom_op.matmul_bias_add" - def infer_struct_info(call: relax.Call, context): + def infer_ty(call: relax.Call, context): activations, weight, bias = call.args matmul_call = relax.op.matmul(activations, weight) - matmul_sinfo = tvm.ir.Op.get("relax.matmul").get_attr("FInferStructInfo")( - matmul_call, context - ) + matmul_ty = tvm.ir.Op.get("relax.matmul").get_attr("FInferType")(matmul_call, context) - matmul_var = relax.Var("dummy_var", matmul_sinfo) + matmul_var = relax.Var("dummy_var", matmul_ty) add_call = matmul_var + bias - add_sinfo = tvm.ir.Op.get("relax.add").get_attr("FInferStructInfo")(add_call, context) + add_ty = tvm.ir.Op.get("relax.add").get_attr("FInferType")(add_call, context) - return add_sinfo + return add_ty def legalize(bb: relax.BlockBuilder, call: relax.Call): activations, weight, bias = call.args @@ -319,7 +317,7 @@ def legalize(bb: relax.BlockBuilder, call: relax.Call): return legalized op_attrs = { - "FInferStructInfo": infer_struct_info, + "FInferType": infer_ty, "FLegalize": legalize, "FPurity": True, } @@ -392,7 +390,7 @@ def func_cuda( B: R.Tensor((32, 32), dtype="float32"), ): cls = Expected - C = R.call_tir(cls.add, (A, B), out_sinfo=R.Tensor((32, 32), dtype="float32")) + C = R.call_tir(cls.add, (A, B), out_ty=R.Tensor((32, 32), dtype="float32")) return C @T.prim_func(private=True, s_tir=True) @@ -416,7 +414,7 @@ def func_llvm( C = R.call_tir( cls.add_llvm, (A, B), - out_sinfo=R.Tensor((32, 32), dtype="float32", vdevice="llvm"), + out_ty=R.Tensor((32, 32), dtype="float32", vdevice="llvm"), ) return C diff --git a/tests/python/relax/test_transform_legalize_ops_binary.py b/tests/python/relax/test_transform_legalize_ops_binary.py index 964c704d4cbf..0161eafc2c82 100644 --- a/tests/python/relax/test_transform_legalize_ops_binary.py +++ b/tests/python/relax/test_transform_legalize_ops_binary.py @@ -697,7 +697,7 @@ def power(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32 @R.function def main(x: R.Tensor((1, 2, 3), dtype="float32"), y: R.Tensor((4, 3, 2, 1), dtype="float32")) -> R.Tensor((4, 3, 2, 3), dtype="float32"): - gv = R.call_tir(Expected.power, (x, y), out_sinfo=R.Tensor((4, 3, 2, 3), dtype="float32")) + gv = R.call_tir(Expected.power, (x, y), out_ty=R.Tensor((4, 3, 2, 3), dtype="float32")) return gv # fmt: on @@ -745,7 +745,7 @@ def main(x: R.Tensor((1, "c", "d"), dtype="float32"), y: R.Tensor(("a", "b", "c" b = T.int64() c = T.int64() d = T.int64() - gv = R.call_tir(Expected.power, (x, y), out_sinfo=R.Tensor((a, b, c, d), dtype="float32")) + gv = R.call_tir(Expected.power, (x, y), out_ty=R.Tensor((a, b, c, d), dtype="float32")) return gv # fmt: on @@ -815,7 +815,7 @@ def atan2(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32 @R.function def main(x: R.Tensor((1, 2, 3), dtype="float32"), y: R.Tensor((4, 3, 2, 1), dtype="float32")) -> R.Tensor((4, 3, 2, 3), dtype="float32"): - gv = R.call_tir(Expected.atan2, (x, y), out_sinfo=R.Tensor((4, 3, 2, 3), dtype="float32")) + gv = R.call_tir(Expected.atan2, (x, y), out_ty=R.Tensor((4, 3, 2, 3), dtype="float32")) return gv # fmt: on @@ -862,7 +862,7 @@ def main(x: R.Tensor((1, "c", "d"), dtype="float32"), y: R.Tensor(("a", "b", "c" b = T.int64() c = T.int64() d = T.int64() - gv = R.call_tir(Expected.atan2, (x, y), out_sinfo=R.Tensor((a, b, c, d), dtype="float32")) + gv = R.call_tir(Expected.atan2, (x, y), out_ty=R.Tensor((a, b, c, d), dtype="float32")) return gv # fmt: on diff --git a/tests/python/relax/test_transform_legalize_ops_ccl.py b/tests/python/relax/test_transform_legalize_ops_ccl.py index 2ab48b64cf43..b6459e3ae694 100644 --- a/tests/python/relax/test_transform_legalize_ops_ccl.py +++ b/tests/python/relax/test_transform_legalize_ops_ccl.py @@ -41,11 +41,11 @@ def main(x: R.Tensor((10, 10), "float32")) -> R.Tensor((10, 10), "float32"): class Expected: @R.function def main(x: R.Tensor((10, 10), dtype="float32")) -> R.Tensor((10, 10), dtype="float32"): - gv0: R.Tensor((10, 10), dtype="float32") = R.call_dps_packed("runtime.disco.allreduce", [x, R.shape([0]), True], out_sinfo=R.Tensor((10, 10), dtype="float32")) - gv1: R.Tensor((10, 10), dtype="float32") = R.call_dps_packed("runtime.disco.allreduce", [x, R.shape([1]), True], out_sinfo=R.Tensor((10, 10), dtype="float32")) - gv2: R.Tensor((10, 10), dtype="float32") = R.call_dps_packed("runtime.disco.allreduce", [x, R.shape([2]), True], out_sinfo=R.Tensor((10, 10), dtype="float32")) - gv3: R.Tensor((10, 10), dtype="float32") = R.call_dps_packed("runtime.disco.allreduce", [x, R.shape([3]), True], out_sinfo=R.Tensor((10, 10), dtype="float32")) - gv4: R.Tensor((10, 10), dtype="float32") = R.call_dps_packed("runtime.disco.allreduce", [x, R.shape([4]), True], out_sinfo=R.Tensor((10, 10), dtype="float32")) + gv0: R.Tensor((10, 10), dtype="float32") = R.call_dps_packed("runtime.disco.allreduce", [x, R.shape([0]), True], out_ty=R.Tensor((10, 10), dtype="float32")) + gv1: R.Tensor((10, 10), dtype="float32") = R.call_dps_packed("runtime.disco.allreduce", [x, R.shape([1]), True], out_ty=R.Tensor((10, 10), dtype="float32")) + gv2: R.Tensor((10, 10), dtype="float32") = R.call_dps_packed("runtime.disco.allreduce", [x, R.shape([2]), True], out_ty=R.Tensor((10, 10), dtype="float32")) + gv3: R.Tensor((10, 10), dtype="float32") = R.call_dps_packed("runtime.disco.allreduce", [x, R.shape([3]), True], out_ty=R.Tensor((10, 10), dtype="float32")) + gv4: R.Tensor((10, 10), dtype="float32") = R.call_dps_packed("runtime.disco.allreduce", [x, R.shape([4]), True], out_ty=R.Tensor((10, 10), dtype="float32")) return x # fmt: on @@ -67,8 +67,8 @@ def main(x: R.Tensor((10, 10), "float32")) -> R.Tensor((10, 10), "float32"): class Expected: @R.function def main(x: R.Tensor((10, 10), dtype="float32")) -> R.Tensor((10, 10), dtype="float32"): - gv0: R.Tensor((20, 10), dtype="float32") = R.call_dps_packed("runtime.disco.allgather", [x, True], out_sinfo=R.Tensor((20, 10), dtype="float32")) - gv1: R.Tensor((20, 10), dtype="float32") = R.call_dps_packed("runtime.disco.allgather", [x, True], out_sinfo=R.Tensor((20, 10), dtype="float32")) + gv0: R.Tensor((20, 10), dtype="float32") = R.call_dps_packed("runtime.disco.allgather", [x, True], out_ty=R.Tensor((20, 10), dtype="float32")) + gv1: R.Tensor((20, 10), dtype="float32") = R.call_dps_packed("runtime.disco.allgather", [x, True], out_ty=R.Tensor((20, 10), dtype="float32")) return x # fmt: on @@ -89,7 +89,7 @@ def main(x: R.Tensor((10, 10), "float32")) -> R.Tensor((10, 10), "float32"): class Expected: @R.function def main(x: R.Tensor((10, 10), dtype="float32")) -> R.Tensor((10, 10), dtype="float32"): - gv0: R.Tensor((10, 10), dtype="float32") = R.call_dps_packed("runtime.disco.broadcast_from_worker0", [x, False], out_sinfo=R.Tensor((10, 10), dtype="float32")) + gv0: R.Tensor((10, 10), dtype="float32") = R.call_dps_packed("runtime.disco.broadcast_from_worker0", [x, False], out_ty=R.Tensor((10, 10), dtype="float32")) return x # fmt: on @@ -133,9 +133,9 @@ def transpose(A: T.Buffer((T.int64(10), T.int64(2), T.int64(5)), "float32"), T_t @R.function def main(x: R.Tensor((10, 10), dtype="float32")) -> R.Tensor((10, 5), dtype="float32"): cls = Expected - gv = R.call_tir(cls.reshape, (x,), out_sinfo=R.Tensor((10, 2, 5), dtype="float32")) - gv1 = R.call_tir(cls.transpose, (gv,), out_sinfo=R.Tensor((2, 10, 5), dtype="float32")) - gv0 = R.call_dps_packed("runtime.disco.scatter_from_worker0", (gv1, False), out_sinfo=R.Tensor((10, 5), dtype="float32")) + gv = R.call_tir(cls.reshape, (x,), out_ty=R.Tensor((10, 2, 5), dtype="float32")) + gv1 = R.call_tir(cls.transpose, (gv,), out_ty=R.Tensor((2, 10, 5), dtype="float32")) + gv0 = R.call_dps_packed("runtime.disco.scatter_from_worker0", (gv1, False), out_ty=R.Tensor((10, 5), dtype="float32")) return gv0 # fmt: on diff --git a/tests/python/relax/test_transform_legalize_ops_create_datatype.py b/tests/python/relax/test_transform_legalize_ops_create_datatype.py index 55f9ac799eee..dda27def445d 100644 --- a/tests/python/relax/test_transform_legalize_ops_create_datatype.py +++ b/tests/python/relax/test_transform_legalize_ops_create_datatype.py @@ -600,7 +600,7 @@ class Expected: def main(x: R.Tensor(["n"], "float32")): cls = Expected n = T.int64() - gv = R.call_tir(cls.arange, R.tuple(), out_sinfo=R.Tensor((n // 2,), dtype="int64"), tir_vars=R.shape([n])) + gv = R.call_tir(cls.arange, R.tuple(), out_ty=R.Tensor((n // 2,), dtype="int64"), tir_vars=R.shape([n])) return gv @T.prim_func(private=True, s_tir=True) diff --git a/tests/python/relax/test_transform_legalize_ops_distributed.py b/tests/python/relax/test_transform_legalize_ops_distributed.py index 30b1adb2f7a3..83338570c306 100644 --- a/tests/python/relax/test_transform_legalize_ops_distributed.py +++ b/tests/python/relax/test_transform_legalize_ops_distributed.py @@ -51,9 +51,9 @@ def strided_slice(A: T.Buffer((T.int64(10), T.int64(10)), "float32"), redistribu def main(x: R.Tensor((10, 10), dtype="float32")) -> R.Tensor((10, 5), dtype="float32"): worker_id = T.int64() cls = Expected - gv: R.Shape(ndim=-1) = R.call_pure_packed("runtime.disco.worker_id", sinfo_args=(R.Shape(ndim=-1),)) + gv: R.Shape(ndim=-1) = R.call_pure_packed("runtime.disco.worker_id", ty_args=(R.Shape(ndim=-1),)) gv1: R.Shape([worker_id]) = R.match_cast(gv, R.Shape([worker_id])) - gv0 = R.call_tir(cls.strided_slice, (x,), out_sinfo=R.Tensor((10, 5), dtype="float32"), tir_vars=R.shape([worker_id])) + gv0 = R.call_tir(cls.strided_slice, (x,), out_ty=R.Tensor((10, 5), dtype="float32"), tir_vars=R.shape([worker_id])) return gv0 # fmt: on diff --git a/tests/python/relax/test_transform_legalize_ops_grad.py b/tests/python/relax/test_transform_legalize_ops_grad.py index be2603cdc3ac..8f12bb0b58f6 100644 --- a/tests/python/relax/test_transform_legalize_ops_grad.py +++ b/tests/python/relax/test_transform_legalize_ops_grad.py @@ -79,7 +79,7 @@ def nll_loss_backward(rxplaceholder: T.Buffer((), "float32"), rxplaceholder_1: T @R.function def main(output_grad: R.Tensor((), dtype="float32"), predictions: R.Tensor((2, 3, 4, 5), dtype="float32"), targets: R.Tensor((2, 4, 5), dtype="int64"), weights: R.Tensor((4,), dtype="float32")) -> R.Tensor((2, 3, 4, 5), dtype="float32"): cls = Expected - gv = R.call_tir(cls.nll_loss_backward, (output_grad, predictions, targets, weights), out_sinfo=R.Tensor((2, 3, 4, 5), dtype="float32")) + gv = R.call_tir(cls.nll_loss_backward, (output_grad, predictions, targets, weights), out_ty=R.Tensor((2, 3, 4, 5), dtype="float32")) return gv # fmt: on @@ -149,7 +149,7 @@ def te_nll_loss_backward_no_weight(rxplaceholder: T.Buffer((), "float32"), rxpla @R.function def main(output_grad: R.Tensor((), dtype="float32"), predictions: R.Tensor((2, 3, 4, 5), dtype="float32"), targets: R.Tensor((2, 4, 5), dtype="int64")) -> R.Tensor((2, 3, 4, 5), dtype="float32"): cls = Expected - gv = R.call_tir(cls.te_nll_loss_backward_no_weight, (output_grad, predictions, targets), out_sinfo=R.Tensor((2, 3, 4, 5), dtype="float32")) + gv = R.call_tir(cls.te_nll_loss_backward_no_weight, (output_grad, predictions, targets), out_ty=R.Tensor((2, 3, 4, 5), dtype="float32")) return gv # fmt: on @@ -171,7 +171,7 @@ class Expected: @R.function def main(output_grad: R.Tensor((), dtype="float32"), predictions: R.Tensor((4,), dtype="float32"), targets: R.Tensor((), dtype="int64"), weights: R.Tensor((4,), dtype="float32")) -> R.Tensor((4,), dtype="float32"): cls = Expected - gv = R.call_tir(cls.nll_loss_backward, (output_grad, predictions, targets, weights), out_sinfo=R.Tensor((4,), dtype="float32")) + gv = R.call_tir(cls.nll_loss_backward, (output_grad, predictions, targets, weights), out_ty=R.Tensor((4,), dtype="float32")) return gv @T.prim_func(private=True, s_tir=True) @@ -256,7 +256,7 @@ def max_pool2d_backward(A: T.Buffer((T.int64(3), T.int64(2), T.int64(6), T.int64 @R.function def main(output_grad: R.Tensor((3, 2, 6, 5), dtype="float32"), data: R.Tensor((3, 2, 10, 10), dtype="float32")) -> R.Tensor((3, 2, 10, 10), dtype="float32"): cls = Expected - gv = R.call_tir(cls.max_pool2d_backward, (output_grad, data), out_sinfo=R.Tensor((3, 2, 10, 10), dtype="float32")) + gv = R.call_tir(cls.max_pool2d_backward, (output_grad, data), out_ty=R.Tensor((3, 2, 10, 10), dtype="float32")) return gv # fmt: on @@ -291,7 +291,7 @@ def avg_pool2d_backward(output_grad: T.Buffer((T.int64(3), T.int64(2), T.int64(6 @R.function def main(output_grad: R.Tensor((3, 2, 6, 5), dtype="float32"), data: R.Tensor((3, 2, 10, 10), dtype="float32")) -> R.Tensor((3, 2, 10, 10), dtype="float32"): cls = Expected - gv = R.call_tir(cls.avg_pool2d_backward, (output_grad, data), out_sinfo=R.Tensor((3, 2, 10, 10), dtype="float32")) + gv = R.call_tir(cls.avg_pool2d_backward, (output_grad, data), out_ty=R.Tensor((3, 2, 10, 10), dtype="float32")) return gv # fmt: on @@ -326,7 +326,7 @@ def take_backward(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, va @R.function def main(output_grad: R.Tensor((3, 2, 5), dtype="float32"), x: R.Tensor((3, 4, 5), dtype="float32"), indices: R.Tensor((2,), dtype="int32")) -> R.Tensor((3, 4, 5), dtype="float32"): cls = Expected - gv = R.call_tir(cls.take_backward, (output_grad, x, indices), out_sinfo=R.Tensor((3, 4, 5), dtype="float32")) + gv = R.call_tir(cls.take_backward, (output_grad, x, indices), out_ty=R.Tensor((3, 4, 5), dtype="float32")) return gv # fmt: on @@ -369,7 +369,7 @@ def main(output_grad: R.Tensor(("m", "i"), dtype="float32"), x: R.Tensor(("m", " n = T.int64() i = T.int64() cls = Expected - gv = R.call_tir(cls.take_backward, (output_grad, x, indices), out_sinfo=R.Tensor((m, n), dtype="float32")) + gv = R.call_tir(cls.take_backward, (output_grad, x, indices), out_ty=R.Tensor((m, n), dtype="float32")) return gv # fmt: on diff --git a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py index 9b905dd3da30..db4440242cac 100644 --- a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py +++ b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py @@ -240,7 +240,7 @@ def main(x: R.Tensor((8, 9, 10, 10), "float32")) : class Expected: @R.function def main(x: R.Tensor((8, 9, 10, 10), dtype="float32")): - gv = R.call_tir(Expected.strided_slice, (x,), out_sinfo=R.Tensor((7, 9, 10, 2), dtype="float32")) + gv = R.call_tir(Expected.strided_slice, (x,), out_ty=R.Tensor((7, 9, 10, 2), dtype="float32")) return gv @T.prim_func(private=True, s_tir=True) @@ -272,7 +272,7 @@ def main(x: R.Tensor((8, 9, 10), "float32")) -> R.Tensor((8, 9, 3), "float32"): class Expected: @R.function def main(x: R.Tensor((8, 9, 10), dtype="float32")) -> R.Tensor((8, 9, 3), dtype="float32"): - gv = R.call_tir(Expected.strided_slice, (x,), out_sinfo=R.Tensor((8, 9, 3), dtype="float32")) + gv = R.call_tir(Expected.strided_slice, (x,), out_ty=R.Tensor((8, 9, 3), dtype="float32")) return gv @T.prim_func(private=True, s_tir=True) @@ -321,7 +321,7 @@ def main(x: R.Tensor(("m", "n"), dtype="float32")) -> R.Tensor((3, "n"), dtype=" n = T.int64() m = T.int64() cls = Expected - gv = R.call_tir(cls.strided_slice, (x,), out_sinfo=R.Tensor((3, n), dtype="float32")) + gv = R.call_tir(cls.strided_slice, (x,), out_ty=R.Tensor((3, n), dtype="float32")) return gv # fmt: on @@ -701,7 +701,7 @@ def main( gv = R.call_tir( Expected.shape_func, (x, begin, end, strides), - out_sinfo=R.Tensor((4,), dtype="int64"), + out_ty=R.Tensor((4,), dtype="int64"), ) gv1: R.Shape(ndim=4) = R.tensor_to_shape(gv) gv2: R.Shape([s, s_1, s_2, s_3]) = R.match_cast( @@ -710,7 +710,7 @@ def main( gv_1 = R.call_tir( Expected.dynamic_strided_slice, (x, begin, end, strides), - out_sinfo=R.Tensor((s, s_1, s_2, s_3), dtype="float32"), + out_ty=R.Tensor((s, s_1, s_2, s_3), dtype="float32"), ) return gv_1 # fmt: on @@ -898,14 +898,14 @@ def main( gv = R.call_tir( Expected.shape_func, (x, begin, end, strides), - out_sinfo=R.Tensor((2,), dtype="int64"), + out_ty=R.Tensor((2,), dtype="int64"), ) gv1: R.Shape(ndim=2) = R.tensor_to_shape(gv) gv2: R.Shape([s, s_1]) = R.match_cast(gv1, R.Shape([s, s_1])) gv_1 = R.call_tir( Expected.dynamic_strided_slice, (x, begin, end, strides), - out_sinfo=R.Tensor((s, s_1), dtype="float32"), + out_ty=R.Tensor((s, s_1), dtype="float32"), ) return gv_1 # fmt: on @@ -1128,7 +1128,7 @@ def matmul(A: T.Buffer((T.int64(1), T.int64(1), T.int64(4), T.int64(5)), "float3 @R.function def main(x: R.Tensor((1, 1, 4, 5), dtype="float32"), y: R.Tensor((1, 1, 5, 7), dtype="float32")) -> R.Tensor((1, 1, 4, 7), dtype="float32"): cls = Expected - gv = R.call_tir(cls.matmul, (x, y), out_sinfo=R.Tensor((1, 1, 4, 7), dtype="float32")) + gv = R.call_tir(cls.matmul, (x, y), out_ty=R.Tensor((1, 1, 4, 7), dtype="float32")) return gv # fmt: on @@ -1168,7 +1168,7 @@ def main( x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((3, 4), dtype="float32") ) -> R.Tensor((2, 4), dtype="float32"): cls = Expected - gv = R.call_tir(cls.einsum, (x, y), out_sinfo=R.Tensor((2, 4), dtype="float32")) + gv = R.call_tir(cls.einsum, (x, y), out_ty=R.Tensor((2, 4), dtype="float32")) return gv @T.prim_func(private=True, s_tir=True) @@ -1215,7 +1215,7 @@ def main( c = T.int64() b = T.int64() cls = Expected - gv = R.call_tir(cls.einsum, (x, y), out_sinfo=R.Tensor((a, c), dtype="float32")) + gv = R.call_tir(cls.einsum, (x, y), out_ty=R.Tensor((a, c), dtype="float32")) return gv @T.prim_func(private=True, s_tir=True) diff --git a/tests/python/relax/test_transform_legalize_ops_manipulate.py b/tests/python/relax/test_transform_legalize_ops_manipulate.py index c3686213c398..92668ab60e20 100644 --- a/tests/python/relax/test_transform_legalize_ops_manipulate.py +++ b/tests/python/relax/test_transform_legalize_ops_manipulate.py @@ -543,7 +543,7 @@ def reshape( @R.function def main(x: R.Tensor((1, 2, 3, 4), dtype="float32")) -> R.Tensor((8, 3), dtype="float32"): lv: R.Shape((8, 3)) = R.shape((8, 3)) - gv = R.call_tir(Expected2.reshape, (x,), out_sinfo=R.Tensor((8, 3), dtype="float32")) + gv = R.call_tir(Expected2.reshape, (x,), out_ty=R.Tensor((8, 3), dtype="float32")) return gv # fmt: on @@ -679,9 +679,7 @@ def main( ) -> R.Tensor((5, "b * 2"), dtype="float32"): b = T.int64() lv: R.Shape([5, b * 2]) = R.shape([5, b * 2]) - gv = R.call_tir( - Expected3.reshape, (x,), out_sinfo=R.Tensor((5, b * 2), dtype="float32") - ) + gv = R.call_tir(Expected3.reshape, (x,), out_ty=R.Tensor((5, b * 2), dtype="float32")) return gv mod3 = LegalizeOps()(Reshape3) @@ -716,10 +714,10 @@ def main( ) -> R.Tensor(ndim=2, dtype="float32"): M = T.int64() N = T.int64() - gv = R.call_pure_packed("vm.builtin.tensor_to_shape", x, sinfo_args=(R.Shape(ndim=2),)) + gv = R.call_pure_packed("vm.builtin.tensor_to_shape", x, ty_args=(R.Shape(ndim=2),)) _ = R.match_cast(gv, R.Shape([M,N])) _ = R.shape([M,N]) - gv_1 = R.call_tir(Expected.reshape, (y,), out_sinfo=R.Tensor([M,N], dtype="float32")) + gv_1 = R.call_tir(Expected.reshape, (y,), out_ty=R.Tensor([M,N], dtype="float32")) return gv_1 @T.prim_func(private=True, s_tir=True) @@ -1101,7 +1099,7 @@ def main(x: R.Tensor((3, 2, 3), "float32")): class Expected: @R.function def main(x: R.Tensor((3, 2, 3), dtype="float32")) -> R.Tensor((6, 2, 3), dtype="float32"): - gv = R.call_tir(Expected.repeat, (x,), out_sinfo=R.Tensor((6, 2, 3), dtype="float32")) + gv = R.call_tir(Expected.repeat, (x,), out_ty=R.Tensor((6, 2, 3), dtype="float32")) return gv @T.prim_func(private=True, s_tir=True) @@ -1135,7 +1133,7 @@ class Expected: def main( x: R.Tensor((3, 2, 3), dtype="float32") ) -> R.Tensor((36,), dtype="float32"): - gv = R.call_tir(Expected.repeat, (x,), out_sinfo=R.Tensor((36,), dtype="float32")) + gv = R.call_tir(Expected.repeat, (x,), out_ty=R.Tensor((36,), dtype="float32")) return gv @T.prim_func(private=True, s_tir=True) @@ -1206,7 +1204,7 @@ def main(x: R.Tensor(("a", "b", "c"), dtype="float32")) -> R.Tensor(("2 * a", "b a = T.Var("a", "int64") b = T.Var("b", "int64") c = T.Var("c", "int64") - gv = R.call_tir(Expected.repeat, (x,), out_sinfo=R.Tensor((2 * a, b, c), dtype="float32")) + gv = R.call_tir(Expected.repeat, (x,), out_ty=R.Tensor((2 * a, b, c), dtype="float32")) return gv # fmt: on @@ -1238,7 +1236,7 @@ def tile(rxplaceholder: T.Buffer((T.int64(3), T.int64(2), T.int64(3)), "float32" @R.function def main(x: R.Tensor((3, 2, 3), dtype="float32")) -> R.Tensor((2, 3, 4, 9), dtype="float32"): - gv = R.call_tir(Expected.tile, (x,), out_sinfo=R.Tensor((2, 3, 4, 9), dtype="float32")) + gv = R.call_tir(Expected.tile, (x,), out_ty=R.Tensor((2, 3, 4, 9), dtype="float32")) return gv # fmt: on @@ -1278,7 +1276,7 @@ def main(x: R.Tensor(("a", "b", "c"), dtype="float32")) -> R.Tensor((2, "a", "b a = T.Var("a", "int64") b = T.Var("b", "int64") c = T.Var("c", "int64") - gv = R.call_tir(Expected.tile, (x,), out_sinfo=R.Tensor((2, a, b * 2, c * 3), dtype="float32")) + gv = R.call_tir(Expected.tile, (x,), out_ty=R.Tensor((2, a, b * 2, c * 3), dtype="float32")) return gv # fmt: on mod = LegalizeOps()(Tile) @@ -1299,7 +1297,7 @@ class Expected: @R.function def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"): cls = Expected - gv = R.call_tir(cls.flip, (x,), out_sinfo=R.Tensor((2, 3), dtype="float32")) + gv = R.call_tir(cls.flip, (x,), out_ty=R.Tensor((2, 3), dtype="float32")) return gv @T.prim_func(private=True, s_tir=True) @@ -1341,7 +1339,7 @@ def main( a = T.int64() b = T.int64() cls = Expected - gv = R.call_tir(cls.flip, (x,), out_sinfo=R.Tensor((a, b), dtype="float32")) + gv = R.call_tir(cls.flip, (x,), out_ty=R.Tensor((a, b), dtype="float32")) return gv @T.prim_func(private=True, s_tir=True) @@ -1453,7 +1451,7 @@ def main( gv = R.call_tir( Expected.scatter_elements, (x, indices, updates), - out_sinfo=R.Tensor((4, 4), dtype="float32"), + out_ty=R.Tensor((4, 4), dtype="float32"), ) return gv @@ -1544,7 +1542,7 @@ def main( gv = R.call_tir( Expected.scatter_elements, (x, indices, updates), - out_sinfo=R.Tensor((a, b), dtype="float32"), + out_ty=R.Tensor((a, b), dtype="float32"), ) return gv # fmt: on @@ -1608,7 +1606,7 @@ def te_layout_transform(A: T.Buffer((T.int64(10), T.int64(21), T.int64(30)), "fl @R.function def main(x: R.Tensor((10, 21, 30), dtype="float32")) -> R.Tensor((10, 30, 7, 3), dtype="float32"): cls = Expected - gv = R.call_tir(cls.te_layout_transform, (x,), out_sinfo=R.Tensor((10, 30, 7, 3), dtype="float32")) + gv = R.call_tir(cls.te_layout_transform, (x,), out_ty=R.Tensor((10, 30, 7, 3), dtype="float32")) return gv # fmt: on @@ -1646,7 +1644,7 @@ def te_layout_transform_with_pad(A: T.Buffer((T.int64(10), T.int64(20), T.int64( @R.function def main(x: R.Tensor((10, 20, 30), dtype="float32")) -> R.Tensor((10, 30, 7, 3), dtype="float32"): cls = Expected - gv = R.call_tir(cls.te_layout_transform_with_pad, (x,), out_sinfo=R.Tensor((10, 30, 7, 3), dtype="float32")) + gv = R.call_tir(cls.te_layout_transform_with_pad, (x,), out_ty=R.Tensor((10, 30, 7, 3), dtype="float32")) return gv # fmt: on @@ -1690,7 +1688,7 @@ def main(x: R.Tensor(("a", "b", "c"), dtype="float32")) -> R.Tensor(("a", "c", " c = T.int64() b = T.int64() cls = Expected - gv = R.call_tir(cls.te_layout_transform_with_pad, (x,), out_sinfo=R.Tensor((a, c, (b - b % -3) // 3, 3), dtype="float32")) + gv = R.call_tir(cls.te_layout_transform_with_pad, (x,), out_ty=R.Tensor((a, c, (b - b % -3) // 3, 3), dtype="float32")) return gv # fmt: on @@ -1730,7 +1728,7 @@ def te_layout_transform_with_pad_axis_separator(A: T.Buffer((T.int64(10), T.int6 @R.function def main(x: R.Tensor((10, 20, 30), dtype="float32")) -> R.Tensor((10, 30, 7, 3), dtype="float32"): cls = Expected - gv = R.call_tir(cls.te_layout_transform_with_pad_axis_separator, (x,), out_sinfo=R.Tensor((10, 30, 7, 3), dtype="float32")) + gv = R.call_tir(cls.te_layout_transform_with_pad_axis_separator, (x,), out_ty=R.Tensor((10, 30, 7, 3), dtype="float32")) return gv # fmt: on @@ -1738,13 +1736,13 @@ def main(x: R.Tensor((10, 20, 30), dtype="float32")) -> R.Tensor((10, 30, 7, 3), tvm.ir.assert_structural_equal(mod, Expected) -def test_func_struct_info_of_legalized_layout_transform(): +def test_func_ty_of_legalized_layout_transform(): """PrimFunc shape information must be correct This is a regression test. Previously, the legalization of - `R.layout_transform` produced a PrimFunc with `FuncStructInfo` + `R.layout_transform` produced a PrimFunc with `FuncType` different than its actual signature. This resulted in errors - when later passes attempted to infer the StructInfo. + when later passes attempted to infer the Type. """ @I.ir_module(s_tir=True) @@ -1780,7 +1778,7 @@ def main( ): R.func_attr({"relax.force_pure": True}) cls = Expected - alloc: R.Tensor((4, 4), dtype="float32") = R.emit_with_sinfo( + alloc: R.Tensor((4, 4), dtype="float32") = R.emit_with_ty( "relax.builtin.alloc_tensor", (R.shape([4, 4]), R.dtype("float32"), R.prim_value(0), R.str("global")), (R.Tensor((4, 4), dtype="float32"),), diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py b/tests/python/relax/test_transform_legalize_ops_nn.py index 8136997cf66c..81648e91b4f9 100644 --- a/tests/python/relax/test_transform_legalize_ops_nn.py +++ b/tests/python/relax/test_transform_legalize_ops_nn.py @@ -42,7 +42,7 @@ def main(x: R.Tensor((2, 128, 28), "float32"), w: R.Tensor((64, 16, 3), "float32 class Expected: @R.function def main(x: R.Tensor((2, 128, 28), dtype="float32"), w: R.Tensor((64, 16, 3), dtype="float32")) -> R.Tensor((2, 64, 13), dtype="float32"): - gv = R.call_tir(Expected.conv1d, (x, w), out_sinfo=R.Tensor((2, 64, 13), dtype="float32")) + gv = R.call_tir(Expected.conv1d, (x, w), out_ty=R.Tensor((2, 64, 13), dtype="float32")) return gv @T.prim_func(private=True, s_tir=True) @@ -82,7 +82,7 @@ def main(x: R.Tensor((2, 3, 28), "float32"), w: R.Tensor((4, 3, 3), "float32")) class Expected: @R.function def main(x: R.Tensor((2, 3, 28), dtype="float32"), w: R.Tensor((4, 3, 3), dtype="float32")) -> R.Tensor((2, 4, 26), dtype="float16"): - gv = R.call_tir(Expected.conv1d, (x, w), out_sinfo=R.Tensor((2, 4, 26), dtype="float16")) + gv = R.call_tir(Expected.conv1d, (x, w), out_ty=R.Tensor((2, 4, 26), dtype="float16")) return gv @T.prim_func(private=True, s_tir=True) @@ -123,7 +123,7 @@ def main(x: R.Tensor((2, 28, 128), "float32"), w: R.Tensor((64, 128, 3), "float3 class Expected: @R.function def main(x: R.Tensor((2, 28, 128), dtype="float32"), w: R.Tensor((64, 128, 3), dtype="float32")) -> R.Tensor((2, 26, 64), dtype="float32"): - gv = R.call_tir(Expected.conv1d, (x, w), out_sinfo=R.Tensor((2, 26, 64), dtype="float32")) + gv = R.call_tir(Expected.conv1d, (x, w), out_ty=R.Tensor((2, 26, 64), dtype="float32")) return gv @T.prim_func(private=True, s_tir=True) @@ -173,7 +173,7 @@ def main(x: R.Tensor(("n", "c", "w"), dtype="float32"), kernel: R.Tensor(("f", " w = T.int64() kw = T.int64() c = T.int64() - gv = R.call_tir(Expected.conv1d, (x, kernel), out_sinfo=R.Tensor((n, f, w + 1 - kw), dtype="float32")) + gv = R.call_tir(Expected.conv1d, (x, kernel), out_ty=R.Tensor((n, f, w + 1 - kw), dtype="float32")) return gv @T.prim_func(private=True, s_tir=True) @@ -245,7 +245,7 @@ def conv1d_transpose(x: T.Buffer((T.int64(2), T.int64(128), T.int64(28)), "float @R.function def main(x: R.Tensor((2, 128, 28), dtype="float32"), w: R.Tensor((128, 16, 3), dtype="float32")) -> R.Tensor((2, 128, 56), dtype="float32"): cls = Expected - gv = R.call_tir(cls.conv1d_transpose, (x, w), out_sinfo=R.Tensor((2, 128, 56), dtype="float32")) + gv = R.call_tir(cls.conv1d_transpose, (x, w), out_ty=R.Tensor((2, 128, 56), dtype="float32")) return gv # fmt: on @@ -448,7 +448,7 @@ def main(x: R.Tensor((2, 128, 28, 28), "float32"), w: R.Tensor((128, 16, 3, 3), class Expected: @R.function def main(x: R.Tensor((2, 128, 28, 28), dtype="float32"), w: R.Tensor((128, 16, 3, 3), dtype="float32")) -> R.Tensor((2, 128, 56, 84), dtype="float32"): - gv = R.call_tir(Expected.conv2d_transpose, (x, w), out_sinfo=R.Tensor((2, 128, 56, 84), dtype="float32")) + gv = R.call_tir(Expected.conv2d_transpose, (x, w), out_ty=R.Tensor((2, 128, 56, 84), dtype="float32")) return gv @T.prim_func(private=True, s_tir=True) @@ -503,7 +503,7 @@ def main(x: R.Tensor((2, 3, 4, 4, 4), "float32"), w: R.Tensor((3, 4, 3, 3, 3), " class Expected: @R.function def main(x: R.Tensor((2, 3, 4, 4, 4), dtype="float32"), w: R.Tensor((3, 4, 3, 3, 3), dtype="float32")) -> R.Tensor((2, 4, 6, 6, 6), dtype="float32"): - gv = R.call_tir(Expected.conv3d_transpose, (x, w), out_sinfo=R.Tensor((2, 4, 6, 6, 6), dtype="float32")) + gv = R.call_tir(Expected.conv3d_transpose, (x, w), out_ty=R.Tensor((2, 4, 6, 6, 6), dtype="float32")) return gv @T.prim_func(private=True, s_tir=True) @@ -557,7 +557,7 @@ def main(x: R.Tensor((2, 3, 4, 4, 4), "float32"), w: R.Tensor((3, 4, 3, 3, 3), " class Expected: @R.function def main(x: R.Tensor((2, 3, 4, 4, 4), dtype="float32"), w: R.Tensor((3, 4, 3, 3, 3), dtype="float32")) -> R.Tensor((2, 4, 6, 6, 6), dtype="float16"): - gv = R.call_tir(Expected.conv3d_transpose, (x, w), out_sinfo=R.Tensor((2, 4, 6, 6, 6), dtype="float16")) + gv = R.call_tir(Expected.conv3d_transpose, (x, w), out_ty=R.Tensor((2, 4, 6, 6, 6), dtype="float16")) return gv @T.prim_func(private=True, s_tir=True) @@ -611,7 +611,7 @@ def main(x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((3, 4, 3, 3), "floa class Expected: @R.function def main(x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((3, 4, 3, 3), dtype="float32")) -> R.Tensor((2, 4, 30, 30), dtype="float16"): - gv = R.call_tir(Expected.conv2d_transpose, (x, w), out_sinfo=R.Tensor((2, 4, 30, 30), dtype="float16")) + gv = R.call_tir(Expected.conv2d_transpose, (x, w), out_ty=R.Tensor((2, 4, 30, 30), dtype="float16")) return gv @T.prim_func(private=True, s_tir=True) @@ -673,7 +673,7 @@ def main(x: R.Tensor(("n", "c", "h", "w"), dtype="float32"), kernel: R.Tensor((" w = T.int64() kw = T.int64() f = T.int64() - gv = R.call_tir(Expected.conv2d_transpose, (x, kernel), out_sinfo=R.Tensor((n, c, h * 3 + kh - 3, w * 3 + kw - 3), dtype="float32")) + gv = R.call_tir(Expected.conv2d_transpose, (x, kernel), out_ty=R.Tensor((n, c, h * 3 + kh - 3, w * 3 + kw - 3), dtype="float32")) return gv @T.prim_func(private=True, s_tir=True) @@ -904,7 +904,7 @@ def avg_pool2d(rxplaceholder: T.Buffer((T.int64(4), T.int64(112), T.int64(112), @R.function def main(x: R.Tensor((4, 112, 112, 6), dtype="float32")) -> R.Tensor((4, 56, 56, 6), dtype="float32"): - gv = R.call_tir(Expected.avg_pool2d, (x,), out_sinfo=R.Tensor((4, 56, 56, 6), dtype="float32")) + gv = R.call_tir(Expected.avg_pool2d, (x,), out_ty=R.Tensor((4, 56, 56, 6), dtype="float32")) return gv # fmt: on @@ -945,7 +945,7 @@ def avg_pool2d(rxplaceholder: T.Buffer((T.int64(4), T.int64(4), T.int64(112), T. pool_avg[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = pool_sum[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] / T.Cast("float32", T.max((T.min(T.int64(2), T.int64(111) - v_ax2) + T.int64(1) - T.max(T.int64(0) - v_ax2, T.int64(0))) * (T.min(T.int64(2), T.int64(111) - v_ax3) + T.int64(1) - T.max(T.int64(0) - v_ax3, T.int64(0))), T.int64(1))) @R.function def main(x: R.Tensor((4, 4, 112, 112, 16), dtype="float32")) -> R.Tensor((4, 4, 110, 110, 16), dtype="float32"): - gv = R.call_tir(Expected.avg_pool2d, (x,), out_sinfo=R.Tensor((4, 4, 110, 110, 16), dtype="float32")) + gv = R.call_tir(Expected.avg_pool2d, (x,), out_ty=R.Tensor((4, 4, 110, 110, 16), dtype="float32")) return gv # fmt: on @@ -994,7 +994,7 @@ def avg_pool2d(rxplaceholder: T.Buffer((T.int64(4), T.int64(6), T.int64(112), T. @R.function def main(x: R.Tensor((4, 6, 112, 112), dtype="float32")) -> R.Tensor((4, 6, 38, 38), dtype="float32"): - gv = R.call_tir(Expected.avg_pool2d, (x,), out_sinfo=R.Tensor((4, 6, 38, 38), dtype="float32")) + gv = R.call_tir(Expected.avg_pool2d, (x,), out_ty=R.Tensor((4, 6, 38, 38), dtype="float32")) return gv # fmt: on @@ -1279,7 +1279,7 @@ def main(x: R.Tensor((2, 3), "float32"), y: R.Tensor((1,), "float32")) -> R.Tens class Expected: @R.function def main(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((1,), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"): - gv = R.call_tir(Expected.prelu, (x, y), out_sinfo=R.Tensor((2, 3), dtype="float32")) + gv = R.call_tir(Expected.prelu, (x, y), out_ty=R.Tensor((2, 3), dtype="float32")) return gv @T.prim_func(private=True, s_tir=True) @@ -1320,7 +1320,7 @@ class Expected: @R.function def main(x: R.Tensor(("m", 7), dtype="float32"), y: R.Tensor((1,), dtype="float32")) -> R.Tensor(("m", 7), dtype="float32"): m = T.int64() - gv = R.call_tir(Expected.prelu, (x, y), out_sinfo=R.Tensor((m, 7), dtype="float32")) + gv = R.call_tir(Expected.prelu, (x, y), out_ty=R.Tensor((m, 7), dtype="float32")) return gv @T.prim_func(private=True, s_tir=True) @@ -1487,7 +1487,7 @@ def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): class Expected: @R.function def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"): - gv = R.call_tir(Expected.gelu_tanh, (x,), out_sinfo=R.Tensor((2, 3), dtype="float32")) + gv = R.call_tir(Expected.gelu_tanh, (x,), out_ty=R.Tensor((2, 3), dtype="float32")) return gv @T.prim_func(private=True, s_tir=True) @@ -1577,7 +1577,7 @@ class Expected: def main(x: R.Tensor(("m", "n"), dtype="float32")) -> R.Tensor(("m", "n"), dtype="float32"): m = T.int64() n = T.int64() - gv = R.call_tir(Expected.gelu_tanh, (x,), out_sinfo=R.Tensor((m, n), dtype="float32")) + gv = R.call_tir(Expected.gelu_tanh, (x,), out_ty=R.Tensor((m, n), dtype="float32")) return gv @T.prim_func(private=True, s_tir=True) @@ -2412,7 +2412,7 @@ def batch_norm(var_x: T.handle, var_gamma: T.handle, var_beta: T.handle, var_mov @R.function def main(x: R.Tensor((2, 3, 28, 28), dtype="float32"), gamma: R.Tensor((3,), dtype="float32"), beta: R.Tensor((3,), dtype="float32"), moving_mean: R.Tensor((3,), dtype="float32"), moving_var: R.Tensor((3,), dtype="float32")) -> R.Tuple(R.Tensor((2, 3, 28, 28), dtype="float32"), R.Tensor((3,), dtype="float32"), R.Tensor((3,), dtype="float32")): cls = Expected - gv = R.call_tir(cls.batch_norm, (x, gamma, beta, moving_mean, moving_var), out_sinfo=[R.Tensor((2, 3, 28, 28), dtype="float32"), R.Tensor((3,), dtype="float32"), R.Tensor((3,), dtype="float32")]) + gv = R.call_tir(cls.batch_norm, (x, gamma, beta, moving_mean, moving_var), out_ty=[R.Tensor((2, 3, 28, 28), dtype="float32"), R.Tensor((3,), dtype="float32"), R.Tensor((3,), dtype="float32")]) return gv # fmt: on @@ -2710,7 +2710,7 @@ def main(x: R.Tensor(("n", "h", "w", "c"), dtype="float32"), gamma: R.Tensor(("c w = T.int64() c = T.int64() cls = Expected - gv = R.call_tir(cls.batch_norm, (x, gamma, beta, moving_mean, moving_var), out_sinfo=[R.Tensor((n, h, w, c), dtype="float32"), R.Tensor((T.max(c, h),), dtype="float32"), R.Tensor((T.max(c, h),), dtype="float32")]) + gv = R.call_tir(cls.batch_norm, (x, gamma, beta, moving_mean, moving_var), out_ty=[R.Tensor((n, h, w, c), dtype="float32"), R.Tensor((T.max(c, h),), dtype="float32"), R.Tensor((T.max(c, h),), dtype="float32")]) return gv mod = LegalizeOps()(BatchNorm) @@ -2828,7 +2828,7 @@ def forward(x: R.Tensor((3,), dtype="float32"), layer_norm_weight: R.Tensor((3,) R.func_attr({"num_input": 1}) cls = LayerNorm_1D_Expected with R.dataflow(): - layer_norm = R.call_tir(cls.layer_norm, (x, layer_norm_weight, layer_norm_bias), out_sinfo=R.Tensor((3,), dtype="float32")) + layer_norm = R.call_tir(cls.layer_norm, (x, layer_norm_weight, layer_norm_bias), out_ty=R.Tensor((3,), dtype="float32")) gv: R.Tensor((3,), dtype="float32") = layer_norm R.output(gv) return gv @@ -2891,7 +2891,7 @@ def layer_norm( @R.function def main(x: R.Tensor((2, 3, 4, 5), dtype="float16"), gamma: R.Tensor((4, 5), dtype="float16"), beta: R.Tensor((4, 5), dtype="float16")) -> R.Tensor((2, 3, 4, 5), dtype="float16"): - gv = R.call_tir(Expected.layer_norm, (x, gamma, beta), out_sinfo=R.Tensor((2, 3, 4, 5), dtype="float16")) + gv = R.call_tir(Expected.layer_norm, (x, gamma, beta), out_ty=R.Tensor((2, 3, 4, 5), dtype="float16")) return gv # fmt: on mod = LegalizeOps()(LayerNorm) @@ -3030,7 +3030,7 @@ def group_norm(rxplaceholder: T.Buffer((T.int64(2), T.int64(4), T.int64(4), T.in @R.function def main(x: R.Tensor((2, 4, 4, 5), dtype="float32"), gamma: R.Tensor((4,), dtype="float32"), beta: R.Tensor((4,), dtype="float32")) -> R.Tensor((2, 4, 4, 5), dtype="float32"): - gv = R.call_tir(Expected.group_norm, (x, gamma, beta), out_sinfo=R.Tensor((2, 4, 4, 5), dtype="float32")) + gv = R.call_tir(Expected.group_norm, (x, gamma, beta), out_ty=R.Tensor((2, 4, 4, 5), dtype="float32")) return gv # fmt: on mod = LegalizeOps()(GroupNorm) @@ -3050,7 +3050,7 @@ def main(x: R.Tensor((2, 4, 4, 5), "float16"), gamma: R.Tensor((4,), "float16"), class Expected: @R.function def main(x: R.Tensor((2, 4, 4, 5), dtype="float16"), gamma: R.Tensor((4,), dtype="float16"), beta: R.Tensor((4,), dtype="float16")) -> R.Tensor((2, 4, 4, 5), dtype="float16"): - gv = R.call_tir(Expected.group_norm, (x, gamma, beta), out_sinfo=R.Tensor((2, 4, 4, 5), dtype="float16")) + gv = R.call_tir(Expected.group_norm, (x, gamma, beta), out_ty=R.Tensor((2, 4, 4, 5), dtype="float16")) return gv @T.prim_func(private=True, s_tir=True) @@ -3199,7 +3199,7 @@ def main(s: R.Shape(["c"]), x: R.Tensor(("n", "4 * c", "h", "w"), dtype="float32 c = T.int64() h = T.int64() w = T.int64() - gv = R.call_tir(Expected.group_norm, (x, gamma, beta), out_sinfo=R.Tensor((n, 4 * c, h, w), dtype="float32"), tir_vars=R.shape([c])) + gv = R.call_tir(Expected.group_norm, (x, gamma, beta), out_ty=R.Tensor((n, 4 * c, h, w), dtype="float32"), tir_vars=R.shape([c])) return gv # fmt: on mod = LegalizeOps()(GroupNorm) @@ -3275,7 +3275,7 @@ def rms_norm(A: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "floa @R.function def main(x: R.Tensor((2, 3, 4, 5), dtype="float32"), weight: R.Tensor((4, 5), dtype="float32")) -> R.Tensor((2, 3, 4, 5), dtype="float32"): cls = Expected - gv = R.call_tir(cls.rms_norm, (x, weight), out_sinfo=R.Tensor((2, 3, 4, 5), dtype="float32")) + gv = R.call_tir(cls.rms_norm, (x, weight), out_ty=R.Tensor((2, 3, 4, 5), dtype="float32")) return gv # fmt: on mod = LegalizeOps()(RMSNorm) @@ -3351,7 +3351,7 @@ def rms_norm(A: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "floa @R.function def main(x: R.Tensor((2, 3, 4, 5), dtype="float16"), weight: R.Tensor((4, 5), dtype="float16")) -> R.Tensor((2, 3, 4, 5), dtype="float16"): cls = Expected - gv = R.call_tir(cls.rms_norm, (x, weight), out_sinfo=R.Tensor((2, 3, 4, 5), dtype="float16")) + gv = R.call_tir(cls.rms_norm, (x, weight), out_ty=R.Tensor((2, 3, 4, 5), dtype="float16")) return gv # fmt: on mod = LegalizeOps()(RMSNorm) @@ -3437,7 +3437,7 @@ def main(x: R.Tensor(("n", "s", "f"), dtype="float32"), weight: R.Tensor(("s", " s = T.int64() f = T.int64() cls = Expected - gv = R.call_tir(cls.rms_norm, (x, weight), out_sinfo=R.Tensor((n, s, f), dtype="float32")) + gv = R.call_tir(cls.rms_norm, (x, weight), out_ty=R.Tensor((n, s, f), dtype="float32")) return gv # fmt: on mod = LegalizeOps()(RMSNorm) @@ -3513,7 +3513,7 @@ def rms_norm(A: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "floa @R.function def main(x: R.Tensor((2, 3, 4, 5), dtype="float32"), weight: R.Tensor((4, 5), dtype="float32")) -> R.Tensor((2, 3, 4, 5), dtype="float32"): cls = Expected - gv = R.call_tir(cls.rms_norm, (x, weight), out_sinfo=R.Tensor((2, 3, 4, 5), dtype="float32")) + gv = R.call_tir(cls.rms_norm, (x, weight), out_ty=R.Tensor((2, 3, 4, 5), dtype="float32")) return gv # fmt: on mod = LegalizeOps()(RMSNorm) @@ -3696,7 +3696,7 @@ def attention_bias(q: T.Buffer((T.int64(4), T.int64(16), T.int64(32), T.int64(8) @R.function def main(q: R.Tensor((4, 16, 32, 8), dtype="float32"), k: R.Tensor((4, 8, 32, 8), dtype="float32"), v: R.Tensor((4, 8, 32, 16), dtype="float32"), bias: R.Tensor((4, 32, 16, 8), dtype="float32")) -> R.Tensor((4, 16, 32, 16), dtype="float32"): cls = Expected - gv = R.call_tir(cls.attention_bias, (q, k, v, bias), out_sinfo=R.Tensor((4, 16, 32, 16), dtype="float32")) + gv = R.call_tir(cls.attention_bias, (q, k, v, bias), out_ty=R.Tensor((4, 16, 32, 16), dtype="float32")) return gv # fmt: on mod = LegalizeOps()(Attention) @@ -3926,7 +3926,7 @@ class Expected: @R.function def main(predictions: R.Tensor(("C",), dtype="float32"), targets: R.Tensor((), dtype="int64"), weights: R.Tensor(("C",), dtype="float32")) -> R.Tensor((), dtype="float32"): C = T.int64() - gv = R.call_tir(Expected.nll_loss, (predictions, targets, weights), out_sinfo=R.Tensor((), dtype="float32")) + gv = R.call_tir(Expected.nll_loss, (predictions, targets, weights), out_ty=R.Tensor((), dtype="float32")) return gv @T.prim_func(private=True, s_tir=True) @@ -4045,7 +4045,7 @@ class Expected: def main( x: R.Tensor((2, 128, 28), dtype="float32"), ) -> R.Tensor((2, 130, 30), dtype="float32"): - gv = R.call_tir(Expected.pad, (x), out_sinfo=R.Tensor((2, 130, 30), dtype="float32")) + gv = R.call_tir(Expected.pad, (x), out_ty=R.Tensor((2, 130, 30), dtype="float32")) return gv @T.prim_func(private=True, s_tir=True) @@ -4086,7 +4086,7 @@ def main(x: R.Tensor((2, 3, 4, 5), "float32")) -> R.Tensor((2, 60), "float32"): class Expected: @R.function def main(x: R.Tensor((2, 3, 4, 5), dtype="float32")) -> R.Tensor((2, 60), dtype="float32"): - gv = R.call_tir(Expected.reshape, (x,), out_sinfo=R.Tensor((2, 60), dtype="float32")) + gv = R.call_tir(Expected.reshape, (x,), out_ty=R.Tensor((2, 60), dtype="float32")) return gv @T.prim_func(private=True, s_tir=True) diff --git a/tests/python/relax/test_transform_legalize_ops_qdq.py b/tests/python/relax/test_transform_legalize_ops_qdq.py index 51d18017ff6a..4caa942ce2af 100644 --- a/tests/python/relax/test_transform_legalize_ops_qdq.py +++ b/tests/python/relax/test_transform_legalize_ops_qdq.py @@ -68,7 +68,7 @@ def main( zp: R.Tensor((2,), dtype="int8"), ) -> R.Tensor((2, 4), dtype="int8"): out = R.call_tir( - Expected.quantize, (data, scale, zp), out_sinfo=R.Tensor((2, 4), dtype="int8") + Expected.quantize, (data, scale, zp), out_ty=R.Tensor((2, 4), dtype="int8") ) return out @@ -122,7 +122,7 @@ def main( zp: R.Tensor((2,), dtype="int8"), ) -> R.Tensor((2, 4), dtype="uint8"): out = R.call_tir( - Expected.quantize, (data, scale, zp), out_sinfo=R.Tensor((2, 4), dtype="uint8") + Expected.quantize, (data, scale, zp), out_ty=R.Tensor((2, 4), dtype="uint8") ) return out @@ -176,9 +176,7 @@ def main( zp: R.Tensor(("n",), dtype="int8"), ) -> R.Tensor((4, "n"), dtype="int8"): n = T.int64() - out = R.call_tir( - Expected.quantize, (data, scale, zp), out_sinfo=R.Tensor((4, n), "int8") - ) + out = R.call_tir(Expected.quantize, (data, scale, zp), out_ty=R.Tensor((4, n), "int8")) return out mod = LegalizeOps()(Quantize) @@ -222,7 +220,7 @@ def quantize( @R.function def main(data: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((2, 4), dtype="int8"): - out = R.call_tir(Expected.quantize, (data,), out_sinfo=R.Tensor((2, 4), dtype="int8")) + out = R.call_tir(Expected.quantize, (data,), out_ty=R.Tensor((2, 4), dtype="int8")) return out mod = LegalizeOps()(Quantize) @@ -276,7 +274,7 @@ def main(data: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((2, 4), dtype="int out = R.call_tir( cls.quantize, (data, R.const([2.0, 1.0], "float32"), R.const([4, 5], "int8")), - out_sinfo=R.Tensor((2, 4), dtype="int8"), + out_ty=R.Tensor((2, 4), dtype="int8"), ) return out @@ -321,7 +319,7 @@ def quantize( @R.function def main(data: R.Tensor((2, 4), dtype="float16")) -> R.Tensor((2, 4), dtype="int8"): - out = R.call_tir(Expected.quantize, (data,), out_sinfo=R.Tensor((2, 4), dtype="int8")) + out = R.call_tir(Expected.quantize, (data,), out_ty=R.Tensor((2, 4), dtype="int8")) return out mod = LegalizeOps()(Quantize) @@ -368,7 +366,7 @@ def main( zp: R.Tensor((2,), dtype="int8"), ) -> R.Tensor((2, 4), dtype="float32"): out = R.call_tir( - Expected.dequantize, (data, scale, zp), out_sinfo=R.Tensor((2, 4), dtype="float32") + Expected.dequantize, (data, scale, zp), out_ty=R.Tensor((2, 4), dtype="float32") ) return out @@ -407,7 +405,7 @@ def dequantize( @R.function def main(data: R.Tensor((2, 4), dtype="int8")) -> R.Tensor((2, 4), dtype="float32"): cls = Expected - out = R.call_tir(cls.dequantize, (data,), out_sinfo=R.Tensor((2, 4), dtype="float32")) + out = R.call_tir(cls.dequantize, (data,), out_ty=R.Tensor((2, 4), dtype="float32")) return out mod = LegalizeOps()(Dequantize) @@ -457,7 +455,7 @@ def main( ) -> R.Tensor((2, "n"), dtype="float32"): n = T.int64() out = R.call_tir( - Expected.dequantize, (data, scale, zp), out_sinfo=R.Tensor((2, n), dtype="float32") + Expected.dequantize, (data, scale, zp), out_ty=R.Tensor((2, n), dtype="float32") ) return out @@ -515,7 +513,7 @@ def main( zp: R.Tensor((2,), dtype="int8"), ) -> R.Tensor((2, 4), dtype="float16"): out = R.call_tir( - Expected.dequantize, (data, scale, zp), out_sinfo=R.Tensor((2, 4), dtype="float16") + Expected.dequantize, (data, scale, zp), out_ty=R.Tensor((2, 4), dtype="float16") ) return out @@ -562,7 +560,7 @@ def dequantize( @R.function def main(data: R.Tensor((2, 4), dtype="int8")) -> R.Tensor((2, 4), dtype="float16"): cls = Expected - out = R.call_tir(cls.dequantize, (data,), out_sinfo=R.Tensor((2, 4), dtype="float16")) + out = R.call_tir(cls.dequantize, (data,), out_ty=R.Tensor((2, 4), dtype="float16")) return out mod = LegalizeOps()(Dequantize) diff --git a/tests/python/relax/test_transform_legalize_ops_search_statistical.py b/tests/python/relax/test_transform_legalize_ops_search_statistical.py index 4a707352d352..a24c259cf44e 100644 --- a/tests/python/relax/test_transform_legalize_ops_search_statistical.py +++ b/tests/python/relax/test_transform_legalize_ops_search_statistical.py @@ -115,7 +115,7 @@ def main(x: R.Tensor((2, 3, 4, 5), "float32")) -> R.Tensor((2, 4, 5), "int64"): class Expected: @R.function def main(x: R.Tensor((2, 3, 4, 5), dtype="float32")) -> R.Tensor((2, 4, 5), dtype="int64"): - gv = R.call_tir(Expected.argmax, (x,), out_sinfo=R.Tensor((2, 4, 5), dtype="int64")) + gv = R.call_tir(Expected.argmax, (x,), out_ty=R.Tensor((2, 4, 5), dtype="int64")) return gv @T.prim_func(private=True, s_tir=True) @@ -166,7 +166,7 @@ def main(x: R.Tensor(("a", "b", "c", "d"), dtype="float32")) -> R.Tensor(("a", 1 a = T.int64() c = T.int64() d = T.int64() - gv = R.call_tir(Expected.argmax, (x,), out_sinfo=R.Tensor((a, 1, c, d), dtype="int64")) + gv = R.call_tir(Expected.argmax, (x,), out_ty=R.Tensor((a, 1, c, d), dtype="int64")) return gv @T.prim_func(private=True, s_tir=True) @@ -241,7 +241,7 @@ def argmin(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64( @R.function def main(x: R.Tensor((2, 3, 4, 5), dtype="float32")) -> R.Tensor((), dtype="int64"): - gv = R.call_tir(Expected.argmin, (x,), out_sinfo=R.Tensor((), dtype="int64")) + gv = R.call_tir(Expected.argmin, (x,), out_ty=R.Tensor((), dtype="int64")) return gv # fmt: on @@ -291,7 +291,7 @@ def argmin(var_rxplaceholder: T.handle, rxplaceholder_red: T.Buffer((T.int64(1), @R.function def main(x: R.Tensor(("a", "b", "c", "d"), dtype="float32")) -> R.Tensor((1, 1, 1, 1), dtype="int64"): - gv = R.call_tir(Expected.argmin, (x,), out_sinfo=R.Tensor((1, 1, 1, 1), dtype="int64")) + gv = R.call_tir(Expected.argmin, (x,), out_ty=R.Tensor((1, 1, 1, 1), dtype="int64")) return gv # fmt: on @@ -796,7 +796,7 @@ def main(x: R.Tensor((2, 3, 4, 5), dtype="float32")) -> R.Tuple(R.Tensor((3, 4, class Expected: @R.function def main(x: R.Tensor((2, 3, 4, 5), dtype="float32")) -> R.Tuple(R.Tensor((3, 4, 5), dtype="float32"), R.Tensor((3, 4, 5), dtype="int64")): - gv = R.call_tir(Expected.median, (x,), out_sinfo=[R.Tensor((3, 4, 5), dtype="float32"), R.Tensor((3, 4, 5), dtype="int64")]) + gv = R.call_tir(Expected.median, (x,), out_ty=[R.Tensor((3, 4, 5), dtype="float32"), R.Tensor((3, 4, 5), dtype="int64")]) return gv @T.prim_func(private=True, s_tir=True) @@ -930,7 +930,7 @@ def std(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)) @R.function def main(x: R.Tensor((2, 3, 4, 5), dtype="float32")) -> R.Tensor((), dtype="float32"): cls = Expected - gv = R.call_tir(cls.std, (x,), out_sinfo=R.Tensor((), dtype="float32")) + gv = R.call_tir(cls.std, (x,), out_ty=R.Tensor((), dtype="float32")) return gv # fmt: on @@ -1013,7 +1013,7 @@ def main(x: R.Tensor(("a", "b", "c", "d"), dtype="float32")) -> R.Tensor((), dty c = T.int64() d = T.int64() cls = Expected - gv = R.call_tir(cls.std, (x,), out_sinfo=R.Tensor((), dtype="float32")) + gv = R.call_tir(cls.std, (x,), out_ty=R.Tensor((), dtype="float32")) return gv # fmt: on @@ -1235,7 +1235,7 @@ def variance(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int6 @R.function def main(x: R.Tensor((2, 3, 4, 5), dtype="float32")) -> R.Tensor((3, 4), dtype="float32"): cls = Expected - gv = R.call_tir(cls.variance, (x,), out_sinfo=R.Tensor((3, 4), dtype="float32")) + gv = R.call_tir(cls.variance, (x,), out_ty=R.Tensor((3, 4), dtype="float32")) return gv # fmt: on diff --git a/tests/python/relax/test_transform_lift_transform_params.py b/tests/python/relax/test_transform_lift_transform_params.py index 56299fc4fe65..caae267aef93 100644 --- a/tests/python/relax/test_transform_lift_transform_params.py +++ b/tests/python/relax/test_transform_lift_transform_params.py @@ -128,7 +128,7 @@ def main_transform_params( lv2 = R.call_tir( cls.transform_layout_IOHW_to_OIHW, (lv1,), - out_sinfo=R.Tensor((16, 3, 3, 3), dtype="float32"), + out_ty=R.Tensor((16, 3, 3, 3), dtype="float32"), ) lv: R.Tensor((16, 16, 3, 3), dtype="float32") = params[1] gv: R.Tuple( @@ -202,19 +202,19 @@ def main_transform_params( "vm.builtin.tuple_reset_item", params, R.prim_value(T.int32(0)), - sinfo_args=(R.Tuple,), + ty_args=(R.Tuple,), ) lv2 = R.call_tir( cls.transform_layout_IOHW_to_OIHW, (lv1,), - out_sinfo=R.Tensor((16, 3, 3, 3), dtype="float32"), + out_ty=R.Tensor((16, 3, 3, 3), dtype="float32"), ) lv: R.Tensor((16, 16, 3, 3), dtype="float32") = params[1] _2: R.Tuple = R.call_pure_packed( "vm.builtin.tuple_reset_item", params, R.prim_value(T.int32(1)), - sinfo_args=(R.Tuple,), + ty_args=(R.Tuple,), ) gv: R.Tuple( R.Tensor((16, 16, 3, 3), dtype="float32"), @@ -1457,9 +1457,7 @@ def main(shape: R.Shape(["n"])) -> R.Shape(["n"]): n = T.int64() cls = Before with R.dataflow(): - zeros = R.call_tir( - cls.zeros, R.tuple(), out_sinfo=R.Tensor((n, n), dtype="float32") - ) + zeros = R.call_tir(cls.zeros, R.tuple(), out_ty=R.Tensor((n, n), dtype="float32")) R.output() return shape @@ -1489,9 +1487,7 @@ def main(shape: R.Shape(["n"])) -> R.Shape(["n"]): n = T.int64() cls = Expected with R.dataflow(): - zeros = R.call_tir( - cls.zeros, R.tuple(), out_sinfo=R.Tensor((n, n), dtype="float32") - ) + zeros = R.call_tir(cls.zeros, R.tuple(), out_ty=R.Tensor((n, n), dtype="float32")) R.output() return shape @@ -1517,13 +1513,13 @@ def main( cls.slice, [B], tir_vars=R.ShapeExpr([slice_index]), - out_sinfo=R.Tensor([16], dtype="int32"), + out_ty=R.Tensor([16], dtype="int32"), ) A_slice = R.call_tir( cls.slice, [A], tir_vars=R.ShapeExpr([slice_index]), - out_sinfo=R.Tensor([16], dtype="int32"), + out_ty=R.Tensor([16], dtype="int32"), ) A_scale = R.multiply(A_slice, B_slice) R.output(A_scale) @@ -1557,7 +1553,7 @@ def main( cls.slice, [A], tir_vars=R.ShapeExpr([slice_index]), - out_sinfo=R.Tensor([16], dtype="int32"), + out_ty=R.Tensor([16], dtype="int32"), ) A_scale = R.multiply(A_slice, B_slice) R.output(A_scale) @@ -1577,7 +1573,7 @@ def main_transform_params( cls.slice, [B], tir_vars=R.ShapeExpr([slice_index]), - out_sinfo=R.Tensor([16], dtype="int32"), + out_ty=R.Tensor([16], dtype="int32"), ) output = (R.ShapeExpr([slice_index]), B_slice) R.output(output) diff --git a/tests/python/relax/test_transform_lower_gpu_ipc_alloc_storage.py b/tests/python/relax/test_transform_lower_gpu_ipc_alloc_storage.py index 379d52f262e5..03aa2f01abf6 100644 --- a/tests/python/relax/test_transform_lower_gpu_ipc_alloc_storage.py +++ b/tests/python/relax/test_transform_lower_gpu_ipc_alloc_storage.py @@ -48,7 +48,7 @@ def main(shape: R.Shape(["m", "n"])): # type: ignore "runtime.disco.cuda_ipc.alloc_storage", R.shape([m, n]), R.dtype("float16"), - sinfo_args=(R.Object,), + ty_args=(R.Object,), ) alloc: R.Tensor((m, n), dtype="float16") = R.memory.alloc_tensor( # type: ignore storage, R.prim_value(0), R.shape([m, n]), R.dtype("float16") @@ -81,7 +81,7 @@ def main(shape: R.Shape(["m", "n"])): # type: ignore "runtime.disco.cuda_ipc.alloc_storage", R.shape([m, n]), R.dtype("float16"), - sinfo_args=(R.Object,), + ty_args=(R.Object,), ) tensor: R.Tensor((m, n), dtype="float16") = R.memory.alloc_tensor( # type: ignore gv, R.prim_value(0), R.shape([m, n]), R.dtype("float16") diff --git a/tests/python/relax/test_transform_merge_composite_functions.py b/tests/python/relax/test_transform_merge_composite_functions.py index 00ff74bbaac0..32567cba6366 100644 --- a/tests/python/relax/test_transform_merge_composite_functions.py +++ b/tests/python/relax/test_transform_merge_composite_functions.py @@ -1120,7 +1120,7 @@ def main(A: R.Tensor([10], dtype="float32")) -> R.Tensor([10], dtype="float32"): cls = Before with R.dataflow(): B = cls.fused_relax_nn_relu(A) - C = R.call_tir(cls.relu, (B,), out_sinfo=R.Tensor([10], dtype="float32")) + C = R.call_tir(cls.relu, (B,), out_ty=R.Tensor([10], dtype="float32")) D = cls.fused_relax_nn_gelu(C) R.output(D) return D @@ -1163,7 +1163,7 @@ def main(A: R.Tensor([10], dtype="float32")) -> R.Tensor([10], dtype="float32"): cls = Expected with R.dataflow(): B = cls.fused_relax_nn_relu1_compiler_A(A) - C = R.call_tir(cls.relu, (B,), out_sinfo=R.Tensor([10], dtype="float32")) + C = R.call_tir(cls.relu, (B,), out_ty=R.Tensor([10], dtype="float32")) D = cls.fused_relax_nn_gelu1_compiler_A(C) R.output(D) return D diff --git a/tests/python/relax/test_transform_meta_schedule_tuning.py b/tests/python/relax/test_transform_meta_schedule_tuning.py index 65f04f2dc755..9863f0e7027e 100644 --- a/tests/python/relax/test_transform_meta_schedule_tuning.py +++ b/tests/python/relax/test_transform_meta_schedule_tuning.py @@ -210,12 +210,12 @@ def main( lv0 = R.call_tir( DefaultScheduledModule.tir_matmul, (x, w), - out_sinfo=R.Tensor((32, 32), dtype="float32"), + out_ty=R.Tensor((32, 32), dtype="float32"), ) lv1 = R.call_tir( DefaultScheduledModule.tir_relu, (lv0,), - out_sinfo=R.Tensor((32, 32), dtype="float32"), + out_ty=R.Tensor((32, 32), dtype="float32"), ) R.output(lv1) return lv1 diff --git a/tests/python/relax/test_transform_normalize.py b/tests/python/relax/test_transform_normalize.py index 468d0c3381a4..d136c80df38e 100644 --- a/tests/python/relax/test_transform_normalize.py +++ b/tests/python/relax/test_transform_normalize.py @@ -36,7 +36,7 @@ def test_normalize_function(): mul_add = relax.Function( [x], relax.op.multiply(relax.op.add(x, x), relax.op.add(x, x)), - ret_struct_info=R.Tensor("float16", ndim=2), + ret_ty=R.Tensor("float16", ndim=2), ) # Note: from_expr api names private function (function without global_symbol) as "main" @@ -80,7 +80,7 @@ def test_normalize_if(): ], y, ), - ret_struct_info=R.Tensor("float32", ndim=1), + ret_ty=R.Tensor("float32", ndim=1), ) before_mod = tvm.IRModule.from_expr(f) @@ -145,7 +145,7 @@ def test_normalize_seq_body(): f = relax.Function( [x, y], seq, - ret_struct_info=R.Tensor([], "int32"), + ret_ty=R.Tensor([], "int32"), ) before_mod = tvm.IRModule.from_expr(f) @@ -169,7 +169,7 @@ def test_normalize_func_body(): f = relax.Function( [x, y], relax.op.add(x, y), - ret_struct_info=R.Tensor([], "int32"), + ret_ty=R.Tensor([], "int32"), ) before_mod = tvm.IRModule.from_expr(f) @@ -201,7 +201,7 @@ def test_normalize_if_branches(): f = relax.Function( [cond, x, y], seq, - ret_struct_info=R.Tensor([], "int32"), + ret_ty=R.Tensor([], "int32"), ) before_mod = tvm.IRModule.from_expr(f) @@ -251,7 +251,7 @@ def test_normalize_if_condition(): ], y, ), - ret_struct_info=R.Tensor("float32", ndim=1), + ret_ty=R.Tensor("float32", ndim=1), ) before_mod = tvm.IRModule.from_expr(f) @@ -284,7 +284,7 @@ def test_normalize_tuple_get_item(): ), 0, ), - ret_struct_info=R.Tensor([], "int32"), + ret_ty=R.Tensor([], "int32"), ) before_mod = tvm.IRModule.from_expr(f) @@ -310,7 +310,7 @@ def test_normalize_tuple_get_item(): ], ret_var, ), - ret_struct_info=R.Tensor([], "int32"), + ret_ty=R.Tensor([], "int32"), ) expected_mod = tvm.IRModule.from_expr(expected_f) # apply normalization to fill in type and shape annotations (tedious otherwise) @@ -336,7 +336,7 @@ def test_normalize_combine_nearby_blocks(): ], v3, ), - ret_struct_info=R.Tensor([], "int32"), + ret_ty=R.Tensor([], "int32"), ) after_mod = relax.transform.Normalize()(tvm.IRModule.from_expr(f)) @@ -379,7 +379,7 @@ def test_normalize_nested_seq(): f = relax.Function( [], seq, - ret_struct_info=R.Tensor([], "int32"), + ret_ty=R.Tensor([], "int32"), ) after_mod = relax.transform.Normalize()(tvm.IRModule.from_expr(f)) @@ -430,7 +430,7 @@ def test_normalize_nested_seq_dataflow(): f = relax.Function( [], seq, - ret_struct_info=R.Tensor([], "int32"), + ret_ty=R.Tensor([], "int32"), ) after_mod = relax.transform.Normalize()(tvm.IRModule.from_expr(f)) @@ -503,7 +503,7 @@ def test_normalize_deeply_nested_seq(): f = relax.Function( [], seq, - ret_struct_info=R.Tensor([], "int32"), + ret_ty=R.Tensor([], "int32"), ) after_mod = relax.transform.Normalize()(tvm.IRModule.from_expr(f)) @@ -546,7 +546,7 @@ def test_nesting_non_dataflow_in_dataflow_error(): f = relax.Function( [], seq, - ret_struct_info=R.Tensor([], "int32"), + ret_ty=R.Tensor([], "int32"), ) relax.transform.Normalize()(tvm.IRModule.from_expr(f)) # should fail due to a normal binding block being inside a dataflowblock @@ -571,7 +571,7 @@ def test_remove_usage_of_void_type_variables(): relax.VarBinding(x, R.assert_op(R.const(True, "bool"))), ] seq = relax.SeqExpr([relax.BindingBlock(bindings)], x) - before = relax.Function([], seq, ret_struct_info=R.Tuple([]), is_pure=False) + before = relax.Function([], seq, ret_ty=R.Tuple([]), is_pure=False) after = relax.transform.Normalize()(tvm.IRModule({"main": before}))["main"] diff --git a/tests/python/relax/test_transform_normalize_global_var.py b/tests/python/relax/test_transform_normalize_global_var.py index c5d4e6083f15..13c547fc0717 100644 --- a/tests/python/relax/test_transform_normalize_global_var.py +++ b/tests/python/relax/test_transform_normalize_global_var.py @@ -85,7 +85,7 @@ def f1(x: T.Buffer((1,), "int32")): @R.function def f() -> R.Tensor((1,), dtype="int32"): cls = Expected - gv = R.call_tir(cls.f1, R.tuple(), out_sinfo=R.Tensor((1,), dtype="int32")) + gv = R.call_tir(cls.f1, R.tuple(), out_ty=R.Tensor((1,), dtype="int32")) return gv After = relax.transform.NormalizeGlobalVar()(Before) diff --git a/tests/python/relax/test_transform_operator_specific_normalization.py b/tests/python/relax/test_transform_operator_specific_normalization.py index daff1fa846bc..f777f1d5b564 100644 --- a/tests/python/relax/test_transform_operator_specific_normalization.py +++ b/tests/python/relax/test_transform_operator_specific_normalization.py @@ -43,8 +43,8 @@ def custom_op(define_normalization): op_name = "custom_op.ignore_second_argument" - def infer_struct_info(call: relax.Call, context: relax.BlockBuilder): - return call.args[0].struct_info + def infer_ty(call: relax.Call, context: relax.BlockBuilder): + return call.args[0].ty def normalize(context: relax.BlockBuilder, call: relax.Call): if len(call.args) == 1: @@ -56,7 +56,7 @@ def legalize(context: relax.BlockBuilder, call: relax.Call): return call.args[0] op_attrs = { - "FInferStructInfo": infer_struct_info, + "FInferType": infer_ty, "FLegalize": legalize, "FPurity": True, } @@ -184,7 +184,7 @@ def main(A: R.Tensor([16], "float32")): return relax.Call( tvm.ir.Op.get("relax.call_tir"), [cls.multiply_by_two, args], - sinfo_args=[A.struct_info], + ty_args=[A.ty], ) @T.prim_func(private=True, s_tir=True) @@ -201,7 +201,7 @@ def main(A: R.Tensor([16], "float32")): return relax.Call( tvm.ir.Op.get("relax.call_tir"), [cls.multiply_by_two, relax.Tuple([A])], - sinfo_args=[A.struct_info], + ty_args=[A.ty], ) @T.prim_func(private=True, s_tir=True) @@ -231,7 +231,7 @@ def main(args: R.Tuple([R.Tensor([16], "float32")])): return relax.Call( tvm.ir.Op.get("relax.call_tir"), [cls.multiply_by_two, args], - sinfo_args=[args[0].struct_info], + ty_args=[args[0].ty], ) @T.prim_func(private=True, s_tir=True) @@ -247,7 +247,7 @@ def main(args: R.Tuple([R.Tensor([16], "float32")])): return relax.Call( tvm.ir.Op.get("relax.call_tir"), [cls.multiply_by_two, relax.Tuple([args[0]])], - sinfo_args=[args[0].struct_info], + ty_args=[args[0].ty], ) @T.prim_func(private=True, s_tir=True) @@ -278,7 +278,7 @@ def main(A: R.Tensor([16], "float32")): cls.multiply_by_two, A, inplace_indices=[0], - out_sinfo=[A.struct_info], + out_ty=[A.ty], ) @T.prim_func(private=True, s_tir=True) @@ -298,7 +298,7 @@ def main(A: R.Tensor([16], "float32")): tvm.ir.Op.get("relax.call_tir_inplace"), [cls.multiply_by_two, args], attrs=inplace_attrs, - sinfo_args=[A.struct_info], + ty_args=[A.ty], ) @T.prim_func(private=True, s_tir=True) @@ -328,7 +328,7 @@ def main(A: R.Tensor([16], "float32")): return R.call_tir_with_grad( cls.multiply_by_two, A, - out_sinfo=[A.struct_info], + out_ty=[A.ty], te_grad_name="f_grad", ) @@ -356,7 +356,7 @@ def main(A: R.Tensor([16], "float32")): tvm.ir.Op.get("relax.call_tir_with_grad"), [cls.multiply_by_two, args], attrs=with_grad_attrs, - sinfo_args=[A.struct_info], + ty_args=[A.ty], ) @T.prim_func(private=True, s_tir=True) diff --git a/tests/python/relax/test_transform_rewrite_cuda_graph.py b/tests/python/relax/test_transform_rewrite_cuda_graph.py index 3897e444bc93..8d0d895813e1 100644 --- a/tests/python/relax/test_transform_rewrite_cuda_graph.py +++ b/tests/python/relax/test_transform_rewrite_cuda_graph.py @@ -124,14 +124,14 @@ def main(x: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((2,4), dtype="float32 # this comes after RemovePurityChecking, so we expect purity to be forced R.func_attr({"relax.force_pure": True}) cls = Expected - gv: R.Tuple(R.Object, R.Object, R.Object) = R.call_builtin_with_ctx("vm.builtin.cuda_graph.get_cached_alloc", (cls.cuda_graph_alloc, R.prim_value(0)), sinfo_args=(R.Tuple(R.Object, R.Object, R.Object),)) + gv: R.Tuple(R.Object, R.Object, R.Object) = R.call_builtin_with_ctx("vm.builtin.cuda_graph.get_cached_alloc", (cls.cuda_graph_alloc, R.prim_value(0)), ty_args=(R.Tuple(R.Object, R.Object, R.Object),)) storage: R.Object = gv[0] alloc: R.Tensor((2, 4), dtype="float32") = R.memory.alloc_tensor(storage, R.prim_value(0), R.shape([2, 4]), R.dtype("float32")) _1: R.Tuple = cls.exp(x, alloc) storage1: R.Object = gv[1] alloc1: R.Tensor((2, 4), dtype="float32") = R.memory.alloc_tensor(storage1, R.prim_value(0), R.shape([2, 4]), R.dtype("float32")) storage2: R.Object = gv[2] - gv1: R.Tuple(R.Tensor((2, 4), dtype="float32")) = R.call_builtin_with_ctx("vm.builtin.cuda_graph.run_or_capture", (cls.main_cuda_graph_capture, (alloc, alloc1, storage, storage2), R.prim_value(0)), sinfo_args=(R.Tuple(R.Tensor((2, 4), dtype="float32")),)) + gv1: R.Tuple(R.Tensor((2, 4), dtype="float32")) = R.call_builtin_with_ctx("vm.builtin.cuda_graph.run_or_capture", (cls.main_cuda_graph_capture, (alloc, alloc1, storage, storage2), R.prim_value(0)), ty_args=(R.Tuple(R.Tensor((2, 4), dtype="float32")),)) alloc3: R.Tensor((2, 4), dtype="float32") = gv1[0] alloc4: R.Tensor((2, 4), dtype="float32") = R.builtin.alloc_tensor(R.shape([2, 4]), R.dtype("float32"), R.prim_value(0)) _6: R.Tuple = cls.exp(alloc3, alloc4) @@ -234,13 +234,13 @@ def main_cuda_graph_capture(alloc: R.Tensor((2, 4), dtype="float32"), alloc1: R. def main(x: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((2, 4), dtype="float32"): R.func_attr({"relax.force_pure": True}) cls = Expected - gv: R.Tuple(R.Object, R.Object) = R.call_builtin_with_ctx("vm.builtin.cuda_graph.get_cached_alloc", (cls.cuda_graph_alloc, R.prim_value(0)), sinfo_args=(R.Tuple(R.Object, R.Object),)) + gv: R.Tuple(R.Object, R.Object) = R.call_builtin_with_ctx("vm.builtin.cuda_graph.get_cached_alloc", (cls.cuda_graph_alloc, R.prim_value(0)), ty_args=(R.Tuple(R.Object, R.Object),)) storage: R.Object = gv[0] alloc: R.Tensor((2, 4), dtype="float32") = R.memory.alloc_tensor(storage, R.prim_value(0), R.shape([2, 4]), R.dtype("float32")) _: R.Tuple = cls.exp(x, alloc) storage1: R.Object = gv[1] alloc1: R.Tensor((2, 4), dtype="float32") = R.memory.alloc_tensor(storage1, R.prim_value(0), R.shape([2, 4]), R.dtype("float32")) - gv1: R.Tuple(R.Tensor((2, 4), dtype="float32")) = R.call_builtin_with_ctx("vm.builtin.cuda_graph.run_or_capture", (cls.main_cuda_graph_capture, (alloc, alloc1, storage), R.prim_value(0)), sinfo_args=(R.Tuple(R.Tensor((2, 4), dtype="float32")),)) + gv1: R.Tuple(R.Tensor((2, 4), dtype="float32")) = R.call_builtin_with_ctx("vm.builtin.cuda_graph.run_or_capture", (cls.main_cuda_graph_capture, (alloc, alloc1, storage), R.prim_value(0)), ty_args=(R.Tuple(R.Tensor((2, 4), dtype="float32")),)) alloc2: R.Tensor((2, 4), dtype="float32") = gv1[0] alloc3: R.Tensor((2, 4), dtype="float32") = R.builtin.alloc_tensor(R.shape([2, 4]), R.dtype("float32"), R.prim_value(0)) _4: R.Tuple = cls.exp(alloc2, alloc3) @@ -284,7 +284,7 @@ def main(x: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((2,4), dtype="float32 _3: R.Tuple = R.memory.kill_tensor(alloc) alloc2: R.Tensor((2, 4), dtype="float32") = R.memory.alloc_tensor(storage, 0, R.shape([2, 4]), "float32") lv: R.Tensor((2, 4), dtype="float32") = alloc2 - _4: R.Tuple = R.call_packed("vm.builtin.dummy", (x, lv), sinfo_args=R.Tuple()) + _4: R.Tuple = R.call_packed("vm.builtin.dummy", (x, lv), ty_args=R.Tuple()) _5: R.Tuple = R.memory.kill_tensor(alloc1) alloc3: R.Tensor((2, 4), dtype="float32") = R.builtin.alloc_tensor(R.shape([2, 4]), "float32", 0) _6 = cls.exp(alloc2, alloc3) @@ -330,16 +330,16 @@ def main_cuda_graph_capture(alloc: R.Tensor((2, 4), dtype="float32"), alloc1: R. def main(x: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((2,4), dtype="float32"): R.func_attr({"relax.force_pure": True}) cls = Expected - gv: R.Tuple(R.Object, R.Object) = R.call_builtin_with_ctx("vm.builtin.cuda_graph.get_cached_alloc", (cls.cuda_graph_alloc, R.prim_value(0)), sinfo_args=(R.Tuple(R.Object, R.Object),)) + gv: R.Tuple(R.Object, R.Object) = R.call_builtin_with_ctx("vm.builtin.cuda_graph.get_cached_alloc", (cls.cuda_graph_alloc, R.prim_value(0)), ty_args=(R.Tuple(R.Object, R.Object),)) storage: R.Object = gv[0] alloc: R.Tensor((2, 4), dtype="float32") = R.memory.alloc_tensor(storage, R.prim_value(0), R.shape([2, 4]), R.dtype("float32")) _1: R.Tuple = cls.exp(x, alloc) storage1: R.Object = gv[1] alloc1: R.Tensor((2, 4), dtype="float32") = R.memory.alloc_tensor(storage1, R.prim_value(0), R.shape([2, 4]), R.dtype("float32")) - gv1: R.Tuple(R.Tensor((2, 4), dtype="float32"), R.Tensor((2, 4), dtype="float32")) = R.call_builtin_with_ctx("vm.builtin.cuda_graph.run_or_capture", (cls.main_cuda_graph_capture, (alloc, alloc1, storage), R.prim_value(0)), sinfo_args=(R.Tuple(R.Tensor((2, 4), dtype="float32"), R.Tensor((2, 4), dtype="float32")),)) + gv1: R.Tuple(R.Tensor((2, 4), dtype="float32"), R.Tensor((2, 4), dtype="float32")) = R.call_builtin_with_ctx("vm.builtin.cuda_graph.run_or_capture", (cls.main_cuda_graph_capture, (alloc, alloc1, storage), R.prim_value(0)), ty_args=(R.Tuple(R.Tensor((2, 4), dtype="float32"), R.Tensor((2, 4), dtype="float32")),)) alloc2: R.Tensor((2, 4), dtype="float32") = gv1[1] lv: R.Tensor((2, 4), dtype="float32") = gv1[0] - _4: R.Tuple = R.call_packed("vm.builtin.dummy", (x, lv), sinfo_args=(R.Tuple,)) + _4: R.Tuple = R.call_packed("vm.builtin.dummy", (x, lv), ty_args=(R.Tuple,)) _5: R.Tuple = R.memory.kill_tensor(alloc1) alloc3: R.Tensor((2, 4), dtype="float32") = R.builtin.alloc_tensor(R.shape([2, 4]), R.dtype("float32"), R.prim_value(0)) _6: R.Tuple = cls.exp(alloc2, alloc3) @@ -613,7 +613,7 @@ def main( gv: R.Tuple(R.Object, R.Object) = R.call_builtin_with_ctx( "vm.builtin.cuda_graph.get_cached_alloc", (cls.cuda_graph_alloc, R.prim_value(0)), - sinfo_args=(R.Tuple(R.Object, R.Object),), + ty_args=(R.Tuple(R.Object, R.Object),), ) storage: R.Object = gv[0] alloc: R.Tensor((16, 32, 32, 16), dtype="float16") = R.memory.alloc_tensor( @@ -637,7 +637,7 @@ def main( (lv_1, lv1, alloc1, alloc, params, storage), R.prim_value(0), ), - sinfo_args=( + ty_args=( R.Tuple( R.Tensor((16, 32, 32, 16), dtype="float16"), R.Tensor((16, 3, 3, 16), dtype="float16"), @@ -742,7 +742,7 @@ def main() -> R.Tuple: gv: R.Tuple(R.Object) = R.call_builtin_with_ctx( "vm.builtin.cuda_graph.get_cached_alloc", (cls.cuda_graph_alloc, R.prim_value(0)), - sinfo_args=(R.Tuple(R.Object),), + ty_args=(R.Tuple(R.Object),), ) storage0: R.Object = gv[0] alloc0: R.Tensor((8,), dtype="float32") = R.memory.alloc_tensor( @@ -751,7 +751,7 @@ def main() -> R.Tuple: gv1: R.Tuple = R.call_builtin_with_ctx( "vm.builtin.cuda_graph.run_or_capture", (cls.main_cuda_graph_capture, (alloc0,), R.prim_value(0)), - sinfo_args=(R.Tuple,), + ty_args=(R.Tuple,), ) return R.tuple() @@ -848,7 +848,7 @@ def main(x: R.Tensor(("m",), dtype="float32")) -> R.Tensor(("m",), dtype="float3 gv: R.Tuple(R.Object, R.Object) = R.call_builtin_with_ctx( "vm.builtin.cuda_graph.get_cached_alloc", (cls.cuda_graph_alloc, R.prim_value(0)), - sinfo_args=(R.Tuple(R.Object, R.Object),), + ty_args=(R.Tuple(R.Object, R.Object),), ) storage: R.Object = gv[0] alloc1: R.Tensor((m,), dtype="float32") = R.memory.alloc_tensor( @@ -867,7 +867,7 @@ def main(x: R.Tensor(("m",), dtype="float32")) -> R.Tensor(("m",), dtype="float3 R.prim_value(0), R.shape([m]), ), - sinfo_args=(R.Tuple,), + ty_args=(R.Tuple,), ) alloc3: R.Tensor((m,), dtype="float32") = R.builtin.alloc_tensor( R.shape([m]), R.dtype("float32"), R.prim_value(0), R.str("global") @@ -891,7 +891,7 @@ def func1(): alloc1 = R.memory.alloc_tensor(storage1, 0, R.shape([128]), "float32") alloc2 = R.memory.alloc_tensor(storage2, 0, R.shape([256]), "float32") alloc3 = R.memory.alloc_tensor(storage3, 0, R.shape([512]), "float32") - R.call_packed("dummy", alloc1, alloc2, alloc3, sinfo_args=(R.Tuple,)) + R.call_packed("dummy", alloc1, alloc2, alloc3, ty_args=(R.Tuple,)) return R.tuple() @R.function @@ -905,7 +905,7 @@ def func2(): alloc2 = R.memory.alloc_tensor(storage2, 0, R.shape([64]), "float32") alloc3 = R.memory.alloc_tensor(storage3, 0, R.shape([1024]), "float32") alloc4 = R.memory.alloc_tensor(storage4, 0, R.shape([512]), "float32") - R.call_packed("dummy", alloc1, alloc2, alloc3, alloc4, sinfo_args=(R.Tuple,)) + R.call_packed("dummy", alloc1, alloc2, alloc3, alloc4, ty_args=(R.Tuple,)) return R.tuple() @I.ir_module(s_tir=True) @@ -940,7 +940,7 @@ def func1() -> R.Tuple: gv: R.Tuple(R.Object, R.Object, R.Object, R.Object) = R.call_builtin_with_ctx( "vm.builtin.cuda_graph.get_cached_alloc", (cls.cuda_graph_alloc, R.prim_value(0)), - sinfo_args=(R.Tuple(R.Object, R.Object, R.Object, R.Object),), + ty_args=(R.Tuple(R.Object, R.Object, R.Object, R.Object),), ) storage1: R.Object = gv[1] storage2: R.Object = gv[0] @@ -957,7 +957,7 @@ def func1() -> R.Tuple: R.call_builtin_with_ctx( "vm.builtin.cuda_graph.run_or_capture", (cls.func1_cuda_graph_capture, (alloc1, alloc2, alloc3), R.prim_value(0)), - sinfo_args=(R.Tuple,), + ty_args=(R.Tuple,), ) return R.tuple() @@ -968,7 +968,7 @@ def func1_cuda_graph_capture( alloc3: R.Tensor((512,), dtype="float32"), ) -> R.Tuple: R.func_attr({"relax.force_pure": True}) - R.call_packed("dummy", alloc1, alloc2, alloc3, sinfo_args=(R.Tuple,)) + R.call_packed("dummy", alloc1, alloc2, alloc3, ty_args=(R.Tuple,)) R.tuple() return R.tuple() @@ -979,7 +979,7 @@ def func2() -> R.Tuple: gv2: R.Tuple(R.Object, R.Object, R.Object, R.Object) = R.call_builtin_with_ctx( "vm.builtin.cuda_graph.get_cached_alloc", (cls.cuda_graph_alloc, R.prim_value(0)), - sinfo_args=(R.Tuple(R.Object, R.Object, R.Object, R.Object),), + ty_args=(R.Tuple(R.Object, R.Object, R.Object, R.Object),), ) storage11: R.Object = gv2[1] storage21: R.Object = gv2[2] @@ -1000,7 +1000,7 @@ def func2() -> R.Tuple: R.call_builtin_with_ctx( "vm.builtin.cuda_graph.run_or_capture", (cls.func2_cuda_graph_capture, (alloc1, alloc2, alloc3, alloc4), R.prim_value(1)), - sinfo_args=(R.Tuple,), + ty_args=(R.Tuple,), ) return R.tuple() @@ -1012,7 +1012,7 @@ def func2_cuda_graph_capture( alloc4: R.Tensor((512,), dtype="float32"), ) -> R.Tuple: R.func_attr({"relax.force_pure": True}) - R.call_packed("dummy", alloc1, alloc2, alloc3, alloc4, sinfo_args=(R.Tuple,)) + R.call_packed("dummy", alloc1, alloc2, alloc3, alloc4, ty_args=(R.Tuple,)) R.tuple() return R.tuple() @@ -1028,13 +1028,13 @@ def main(x: R.Tensor((8,), "float32")) -> R.Tuple(R.Tensor((8,), "float32")): R.func_attr({"relax.force_pure": True}) storage1 = R.memory.alloc_storage(R.shape([8]), 0, "global", "float32") alloc1 = R.memory.alloc_tensor(storage1, 0, R.shape([8]), "float32") - _ = R.call_packed("dummy", x, alloc1, sinfo_args=(R.Tuple,)) + _ = R.call_packed("dummy", x, alloc1, ty_args=(R.Tuple,)) storage2 = R.memory.alloc_storage(R.shape([8]), 0, "global", "float32") alloc2 = R.memory.alloc_tensor(storage2, 0, R.shape([8]), "float32") - _1 = R.call_packed("dummy", alloc1, alloc2, sinfo_args=(R.Tuple,)) + _1 = R.call_packed("dummy", alloc1, alloc2, ty_args=(R.Tuple,)) storage3 = R.memory.alloc_storage(R.shape([8]), 0, "global", "float32") alloc3 = R.memory.alloc_tensor(storage3, 0, R.shape([8]), "float32") - _2 = R.call_packed("dummy", alloc2, alloc3, sinfo_args=(R.Tuple,)) + _2 = R.call_packed("dummy", alloc2, alloc3, ty_args=(R.Tuple,)) gv = (alloc3,) return gv @@ -1057,7 +1057,7 @@ def main_cuda_graph_capture( alloc1: R.Tensor((8,), dtype="float32"), alloc2: R.Tensor((8,), dtype="float32") ) -> R.Tuple: R.func_attr({"relax.force_pure": True}) - R.call_packed("dummy", alloc1, alloc2, sinfo_args=(R.Tuple,)) + R.call_packed("dummy", alloc1, alloc2, ty_args=(R.Tuple,)) R.tuple() return R.tuple() @@ -1068,13 +1068,13 @@ def main(x: R.Tensor((8,), dtype="float32")) -> R.Tuple(R.Tensor((8,), dtype="fl gv: R.Tuple(R.Object, R.Object) = R.call_builtin_with_ctx( "vm.builtin.cuda_graph.get_cached_alloc", (cls.cuda_graph_alloc, R.prim_value(0)), - sinfo_args=(R.Tuple(R.Object, R.Object),), + ty_args=(R.Tuple(R.Object, R.Object),), ) storage1: R.Object = gv[0] alloc1: R.Tensor((8,), dtype="float32") = R.memory.alloc_tensor( storage1, R.prim_value(0), R.shape([8]), R.dtype("float32") ) - R.call_packed("dummy", x, alloc1, sinfo_args=(R.Tuple,)) + R.call_packed("dummy", x, alloc1, ty_args=(R.Tuple,)) storage2: R.Object = gv[1] alloc2: R.Tensor((8,), dtype="float32") = R.memory.alloc_tensor( storage2, R.prim_value(0), R.shape([8]), R.dtype("float32") @@ -1082,7 +1082,7 @@ def main(x: R.Tensor((8,), dtype="float32")) -> R.Tuple(R.Tensor((8,), dtype="fl R.call_builtin_with_ctx( "vm.builtin.cuda_graph.run_or_capture", (cls.main_cuda_graph_capture, (alloc1, alloc2), R.prim_value(0)), - sinfo_args=(R.Tuple,), + ty_args=(R.Tuple,), ) storage3: R.Object = R.memory.alloc_storage( R.shape([8]), R.prim_value(0), R.str("global"), R.dtype("float32") @@ -1090,7 +1090,7 @@ def main(x: R.Tensor((8,), dtype="float32")) -> R.Tuple(R.Tensor((8,), dtype="fl alloc3: R.Tensor((8,), dtype="float32") = R.memory.alloc_tensor( storage3, R.prim_value(0), R.shape([8]), R.dtype("float32") ) - R.call_packed("dummy", alloc2, alloc3, sinfo_args=(R.Tuple,)) + R.call_packed("dummy", alloc2, alloc3, ty_args=(R.Tuple,)) gv = (alloc3,) return gv @@ -1107,13 +1107,13 @@ def main(x: R.Tensor((8,), "float16"), w: R.Tensor(("m",))): R.func_attr({"relax.force_pure": True, "num_input": 1}) storage1 = R.memory.alloc_storage(R.shape([8]), 0, "global", "float16") alloc1 = R.memory.alloc_tensor(storage1, 0, R.shape([8]), "float16") - _ = R.call_packed("dummy", x, w, alloc1, sinfo_args=(R.Tuple,)) + _ = R.call_packed("dummy", x, w, alloc1, ty_args=(R.Tuple,)) storage2 = R.memory.alloc_storage(R.shape([8]), 0, "global", "float16") alloc2 = R.memory.alloc_tensor(storage2, 0, R.shape([8]), "float16") - _1 = R.call_packed("dummy", alloc1, w, alloc2, sinfo_args=(R.Tuple,)) + _1 = R.call_packed("dummy", alloc1, w, alloc2, ty_args=(R.Tuple,)) storage3 = R.memory.alloc_storage(R.shape([8]), 0, "global", "float16") alloc3 = R.memory.alloc_tensor(storage3, 0, R.shape([8]), "float16") - _2 = R.call_packed("dummy", alloc2, w, alloc3, sinfo_args=(R.Tuple,)) + _2 = R.call_packed("dummy", alloc2, w, alloc3, ty_args=(R.Tuple,)) gv = (alloc3,) return gv @@ -1140,7 +1140,7 @@ def main_cuda_graph_capture( ) -> R.Tuple: m = T.int64() R.func_attr({"relax.force_pure": True}) - R.call_packed("dummy", alloc1, w, alloc2, sinfo_args=(R.Tuple,)) + R.call_packed("dummy", alloc1, w, alloc2, ty_args=(R.Tuple,)) R.tuple() return R.tuple() @@ -1154,13 +1154,13 @@ def main(x: R.Tensor((8,), dtype="float16"), w: R.Tensor(("m",))) -> R.Tuple( gv: R.Tuple(R.Object, R.Object) = R.call_builtin_with_ctx( "vm.builtin.cuda_graph.get_cached_alloc", (cls.cuda_graph_alloc, R.prim_value(0)), - sinfo_args=(R.Tuple(R.Object, R.Object),), + ty_args=(R.Tuple(R.Object, R.Object),), ) storage1: R.Object = gv[0] alloc1: R.Tensor((8,), dtype="float16") = R.memory.alloc_tensor( storage1, R.prim_value(0), R.shape([8]), R.dtype("float16") ) - R.call_packed("dummy", x, w, alloc1, sinfo_args=(R.Tuple,)) + R.call_packed("dummy", x, w, alloc1, ty_args=(R.Tuple,)) storage2: R.Object = gv[1] alloc2: R.Tensor((8,), dtype="float16") = R.memory.alloc_tensor( storage2, R.prim_value(0), R.shape([8]), R.dtype("float16") @@ -1173,7 +1173,7 @@ def main(x: R.Tensor((8,), dtype="float16"), w: R.Tensor(("m",))) -> R.Tuple( R.prim_value(0), R.shape([m]), ), - sinfo_args=(R.Tuple,), + ty_args=(R.Tuple,), ) storage3: R.Object = R.memory.alloc_storage( R.shape([8]), R.prim_value(0), R.str("global"), R.dtype("float16") @@ -1181,7 +1181,7 @@ def main(x: R.Tensor((8,), dtype="float16"), w: R.Tensor(("m",))) -> R.Tuple( alloc3: R.Tensor((8,), dtype="float16") = R.memory.alloc_tensor( storage3, R.prim_value(0), R.shape([8]), R.dtype("float16") ) - R.call_packed("dummy", alloc2, w, alloc3, sinfo_args=(R.Tuple,)) + R.call_packed("dummy", alloc2, w, alloc3, ty_args=(R.Tuple,)) gv_1: R.Tuple(R.Tensor((8,), dtype="float16")) = (alloc3,) return gv_1 diff --git a/tests/python/relax/test_transform_rewrite_dataflow_reshape.py b/tests/python/relax/test_transform_rewrite_dataflow_reshape.py index c96eec052f06..df774656e14b 100644 --- a/tests/python/relax/test_transform_rewrite_dataflow_reshape.py +++ b/tests/python/relax/test_transform_rewrite_dataflow_reshape.py @@ -69,10 +69,8 @@ def main(x: R.Tensor((8, 3), dtype="float32")) -> R.Tensor( ): cls = Module with R.dataflow(): - y = R.call_tir(cls.reshape, (x,), out_sinfo=R.Tensor((2, 4, 3), dtype="float32")) - z = R.call_tir( - cls.expand_dims, (y,), out_sinfo=R.Tensor((2, 1, 4, 1, 3), "float32") - ) + y = R.call_tir(cls.reshape, (x,), out_ty=R.Tensor((2, 4, 3), dtype="float32")) + z = R.call_tir(cls.expand_dims, (y,), out_ty=R.Tensor((2, 1, 4, 1, 3), "float32")) R.output(z) return z @@ -124,7 +122,7 @@ def main(x: R.Tensor((8, 3), dtype="float32")) -> R.Tensor( # Note: `z` is the output var of the dataflow block, and is thus # not expected to be rewritten. z = R.call_tir( - cls.expand_dims, (y,), out_sinfo=R.Tensor((2, 1, 4, 1, 3), dtype="float32") + cls.expand_dims, (y,), out_ty=R.Tensor((2, 1, 4, 1, 3), dtype="float32") ) R.output(z) return z @@ -175,9 +173,9 @@ def main( ) -> R.Tensor((2, 1, 4096, 1, 5, 64), dtype="float32"): cls = Module with R.dataflow(): - y = R.call_tir(cls.reshape, (x,), out_sinfo=R.Tensor((2, 4096, 5, 64), dtype="float32")) + y = R.call_tir(cls.reshape, (x,), out_ty=R.Tensor((2, 4096, 5, 64), dtype="float32")) z = R.call_tir( - cls.expand_dims, (y,), out_sinfo=R.Tensor((2, 1, 4096, 1, 5, 64), "float32") + cls.expand_dims, (y,), out_ty=R.Tensor((2, 1, 4096, 1, 5, 64), "float32") ) R.output(z) return z @@ -214,7 +212,7 @@ def main(x: R.Tensor((2, 4096, 320), dtype="float32")) -> R.Tensor((2, 1, 4096, cls = Expected with R.dataflow(): y: R.Tensor((2, 4096, 5, 64), dtype="float32") = R.reshape(x, R.shape([2, 4096, 5, 64])) - z = R.call_tir(cls.expand_dims, (y,), out_sinfo=R.Tensor((2, 1, 4096, 1, 5, 64), dtype="float32")) + z = R.call_tir(cls.expand_dims, (y,), out_ty=R.Tensor((2, 1, 4096, 1, 5, 64), dtype="float32")) R.output(z) return z # fmt: on @@ -260,9 +258,7 @@ def main(x: R.Tensor((8, 16, 128), dtype="float16")) -> R.Tensor( ): cls = Module with R.dataflow(): - y = R.call_tir( - cls.reshape, (x,), out_sinfo=R.Tensor((1, 8, 16, 128), dtype="float16") - ) + y = R.call_tir(cls.reshape, (x,), out_ty=R.Tensor((1, 8, 16, 128), dtype="float16")) z = R.add(y, R.const(1, "float16")) R.output(z) return z @@ -339,7 +335,7 @@ def reshape( @R.function def main(x: R.Tensor((8, 3), dtype="float32")) -> R.Tensor((2, 4, 3), dtype="float32"): cls = Module - y = R.call_tir(cls.reshape, (x,), out_sinfo=R.Tensor((2, 4, 3), dtype="float32")) + y = R.call_tir(cls.reshape, (x,), out_ty=R.Tensor((2, 4, 3), dtype="float32")) return y assert relax.analysis.has_reshape_pattern(Module["reshape"]) @@ -404,7 +400,7 @@ def main( lv645 = R.call_tir( cls.fused_reshape5, (lv, lv1, lv2), - out_sinfo=R.Tensor((2, 4096, 8, 40), dtype="float16"), + out_ty=R.Tensor((2, 4096, 8, 40), dtype="float16"), ) out: R.Tensor((2, 4096, 8, 40), dtype="float16") = R.add(lv645, lv645) R.output(out) @@ -507,10 +503,8 @@ def add_one( def main(A: R.Tensor((1, 1024), dtype="int32")) -> R.Tensor((1, 1000), dtype="int32"): with R.dataflow(): cls = Module - S = R.call_tir( - cls.strided_slice, (A,), out_sinfo=R.Tensor((1, 1000), dtype="int32") - ) - A = R.call_tir(cls.add_one, (S,), out_sinfo=R.Tensor((1, 1000), dtype="int32")) + S = R.call_tir(cls.strided_slice, (A,), out_ty=R.Tensor((1, 1000), dtype="int32")) + A = R.call_tir(cls.add_one, (S,), out_ty=R.Tensor((1, 1000), dtype="int32")) R.output(A) return A @@ -525,11 +519,9 @@ class Module: @R.function def main(x: R.Tensor((8, 8), dtype="float16")) -> R.Tensor((8, 8), dtype="float16"): with R.dataflow(): - gv = R.call_pure_packed( - "foo", x, x, sinfo_args=(R.Tensor((8, 8), dtype="float16"),) - ) + gv = R.call_pure_packed("foo", x, x, ty_args=(R.Tensor((8, 8), dtype="float16"),)) out = R.call_pure_packed( - "foo", gv, gv, sinfo_args=(R.Tensor((8, 8), dtype="float16"),) + "foo", gv, gv, ty_args=(R.Tensor((8, 8), dtype="float16"),) ) R.output(out) return out @@ -582,7 +574,7 @@ def main(x: R.Tensor((), dtype="float32")) -> R.Tensor((1,), dtype="float32"): cls = Expected with R.dataflow(): lv1: R.Tensor((1,), dtype="float32") = R.reshape(x, R.shape([1])) - lv2 = R.call_tir(cls.add, (lv1, lv1), out_sinfo=R.Tensor((1,), dtype="float32")) + lv2 = R.call_tir(cls.add, (lv1, lv1), out_ty=R.Tensor((1,), dtype="float32")) R.output(lv2) return lv2 @@ -611,7 +603,7 @@ def main(x: R.Tensor((256,), dtype="float32")): with R.dataflow(): y = R.reshape(x, R.shape([64, 4])) - z = R.call_tir(cls.add, (y, y), out_sinfo=R.Tensor((64, 4), dtype="float32")) + z = R.call_tir(cls.add, (y, y), out_ty=R.Tensor((64, 4), dtype="float32")) R.output(z) return z @@ -669,7 +661,7 @@ def add( # cls.add, # (y, y), # tir_vars=[N], -# out_sinfo=R.Tensor((N // 4, 4), dtype="float32"), +# out_ty=R.Tensor((N // 4, 4), dtype="float32"), # ) # R.output(z) # return z @@ -734,7 +726,7 @@ def main(x: R.Tensor(["N*16"], dtype="float32"), _: R.Prim(value="N")): cls.add, (y, y), tir_vars=[N], - out_sinfo=R.Tensor((N * 4, 4), dtype="float32"), + out_ty=R.Tensor((N * 4, 4), dtype="float32"), ) R.output(z) return z diff --git a/tests/python/relax/test_transform_specialize_primfunc_based_on_callsite.py b/tests/python/relax/test_transform_specialize_primfunc_based_on_callsite.py index d61bf465d7f1..f99c41503492 100644 --- a/tests/python/relax/test_transform_specialize_primfunc_based_on_callsite.py +++ b/tests/python/relax/test_transform_specialize_primfunc_based_on_callsite.py @@ -50,34 +50,30 @@ def visit_call_(self, call: relax.Call) -> None: # pylint: disable=arguments-re ) else: for idx, arg in enumerate(call.args[1]): - arg_sinfo = arg.struct_info - assert isinstance(arg_sinfo, relax.TensorStructInfo), ( - f"Expected TensorStructInfo but git {type(arg_sinfo)}" + arg_ty = arg.ty + assert isinstance(arg_ty, relax.TensorType), ( + f"Expected TensorType but git {type(arg_ty)}" ) buf = pfunc.buffer_map[pfunc.params[idx]] - assert ( - arg_sinfo.vdevice.memory_scope == buf.data.type_annotation.storage_scope - ), ( - f"scope mismatched after specialization {arg_sinfo.vdevice.memory_scope} vs {buf.data.type_annotation.storage_scope}" + assert arg_ty.vdevice.memory_scope == buf.data.type_annotation.storage_scope, ( + f"scope mismatched after specialization {arg_ty.vdevice.memory_scope} vs {buf.data.type_annotation.storage_scope}" ) - if isinstance(call.sinfo_args[0], relax.TensorStructInfo): + if isinstance(call.ty_args[0], relax.TensorType): buf = pfunc.buffer_map[pfunc.params[-1]] assert ( - call.sinfo_args[0].vdevice.memory_scope + call.ty_args[0].vdevice.memory_scope == buf.data.type_annotation.storage_scope ), ( - f"scope mismatched after specialization {call.sinfo_args[0].vdevice.memory_scope} vs {buf.data.type_annotation.storage_scope}" + f"scope mismatched after specialization {call.ty_args[0].vdevice.memory_scope} vs {buf.data.type_annotation.storage_scope}" ) else: - assert isinstance(call.sinfo_args[0], relax.TupleStructInfo), ( - f"Expected TupleStructInfo but git {type(call.sinfo_args[0])}" + assert isinstance(call.ty_args[0], relax.TupleType), ( + f"Expected TupleType but git {type(call.ty_args[0])}" ) - for idx, sinfo in enumerate(call.sinfo_args[0].fields): + for idx, ty in enumerate(call.ty_args[0].fields): buf = pfunc.buffer_map[pfunc.params[len(call.args[1]) + idx]] - assert ( - sinfo.vdevice.memory_scope == buf.data.type_annotation.storage_scope - ), ( - f"scope mismatched after specialization {sinfo.vdevice.memory_scope} vs {buf.data.type_annotation.storage_scope}" + assert ty.vdevice.memory_scope == buf.data.type_annotation.storage_scope, ( + f"scope mismatched after specialization {ty.vdevice.memory_scope} vs {buf.data.type_annotation.storage_scope}" ) @@ -191,14 +187,14 @@ def main( lv = R.call_tir( cls.te_layout_transform, (x,), - out_sinfo=R.Tensor( + out_ty=R.Tensor( (2, 1, 26, 26, 4), dtype="float32", vdevice="opencl:0:global.texture-weight" ), ) lv2 = R.call_tir( cls.max_pool2d_opencl, (lv,), - out_sinfo=R.Tensor( + out_ty=R.Tensor( (2, 1, 13, 13, 4), dtype="float32", vdevice="opencl:0:global.texture-weight" ), ) @@ -208,7 +204,7 @@ def main( gv2 = R.call_tir( cls.te_layout_transform2, (lv5,), - out_sinfo=R.Tensor((2, 4, 13, 13), dtype="float32", vdevice="opencl:1:global"), + out_ty=R.Tensor((2, 4, 13, 13), dtype="float32", vdevice="opencl:1:global"), ) R.output(gv2) return gv2 @@ -293,28 +289,28 @@ def main( lv = R.call_tir( cls.te_layout_transform, (x,), - out_sinfo=R.Tensor( + out_ty=R.Tensor( (2, 4, 28, 28, 4), dtype="float32", vdevice="opencl:0:global.texture-weight" ), ) lv1 = R.call_tir( cls.te_layout_transform1, (w,), - out_sinfo=R.Tensor( + out_ty=R.Tensor( (1, 16, 3, 3, 4), dtype="float32", vdevice="opencl:0:global.texture-weight" ), ) gv = R.call_tir( cls.conv2d_NCHWc_OIHWo_opencl, (lv, lv1), - out_sinfo=R.Tensor( + out_ty=R.Tensor( (2, 1, 26, 26, 4), dtype="float32", vdevice="opencl:0:global.texture-weight" ), ) lv_1 = R.call_tir( cls.fused_relu_concatenate_split, (gv,), - out_sinfo=[ + out_ty=[ R.Tensor((2, 1, 26, 26, 4), dtype="float32", vdevice="opencl:1:global"), R.Tensor((2, 1, 26, 26, 4), dtype="float32", vdevice="opencl:1:global"), ], @@ -325,7 +321,7 @@ def main( lv4 = R.call_tir( cls.te_layout_transform2, (lv3,), - out_sinfo=R.Tensor((2, 4, 26, 26), dtype="float32", vdevice="opencl:1:global"), + out_ty=R.Tensor((2, 4, 26, 26), dtype="float32", vdevice="opencl:1:global"), ) lv5: R.Tensor((2, 1, 26, 26, 4), dtype="float32", vdevice="opencl:1:global") = lv_1[ 1 @@ -333,7 +329,7 @@ def main( lv6 = R.call_tir( cls.te_layout_transform2, (lv5,), - out_sinfo=R.Tensor((2, 4, 26, 26), dtype="float32", vdevice="opencl:1:global"), + out_ty=R.Tensor((2, 4, 26, 26), dtype="float32", vdevice="opencl:1:global"), ) gv4: R.Tuple( R.Tensor((2, 4, 26, 26), dtype="float32", vdevice="opencl:1:global"), diff --git a/tests/python/relax/test_transform_split_layout_rewrite_preproc.py b/tests/python/relax/test_transform_split_layout_rewrite_preproc.py index 5325ee2b1e81..f52d74d36cf1 100644 --- a/tests/python/relax/test_transform_split_layout_rewrite_preproc.py +++ b/tests/python/relax/test_transform_split_layout_rewrite_preproc.py @@ -52,9 +52,7 @@ def forward( R.func_attr({"num_input": 1}) cls = Before with R.dataflow(): - gv = R.call_tir( - cls.tir_func, (x, w), out_sinfo=R.Tensor((224, 224), dtype="float32") - ) + gv = R.call_tir(cls.tir_func, (x, w), out_ty=R.Tensor((224, 224), dtype="float32")) R.output(gv) return gv @@ -91,10 +89,10 @@ def forward( cls = After with R.dataflow(): lv = R.call_tir( - cls.tir_func_weight_prepack, (w,), out_sinfo=R.Tensor((4, 4, 56, 56), "float32") + cls.tir_func_weight_prepack, (w,), out_ty=R.Tensor((4, 4, 56, 56), "float32") ) lv1 = R.call_tir( - cls.tir_func_prepacked, (x, lv), out_sinfo=R.Tensor((224, 224), "float32") + cls.tir_func_prepacked, (x, lv), out_ty=R.Tensor((224, 224), "float32") ) gv: R.Tensor((224, 224), dtype="float32") = lv1 R.output(gv) @@ -146,7 +144,7 @@ def forward( cls = Before with R.dataflow(): gv = R.call_tir( - cls.tir_func, (x, w1, w2), out_sinfo=R.Tensor((224, 224), dtype="float32") + cls.tir_func, (x, w1, w2), out_ty=R.Tensor((224, 224), dtype="float32") ) R.output(gv) return gv @@ -198,7 +196,7 @@ def forward( lv0 = R.call_tir( cls.tir_func_weight_prepack, (w1, w2), - out_sinfo=[ + out_ty=[ R.Tensor((4, 4, 56, 56), "float32"), R.Tensor((4, 4, 56, 56), "float32"), ], @@ -206,7 +204,7 @@ def forward( lv1 = R.call_tir( cls.tir_func_prepacked, (x, lv0[0], lv0[1]), - out_sinfo=R.Tensor((224, 224), "float32"), + out_ty=R.Tensor((224, 224), "float32"), ) gv: R.Tensor((224, 224), dtype="float32") = lv1 R.output(gv) @@ -246,9 +244,7 @@ def forward( R.func_attr({"num_input": 1}) cls = Before with R.dataflow(): - gv = R.call_tir( - cls.tir_func, (x, w), out_sinfo=R.Tensor((224, 224), dtype="float32") - ) + gv = R.call_tir(cls.tir_func, (x, w), out_ty=R.Tensor((224, 224), dtype="float32")) R.output(gv) return gv @@ -287,10 +283,10 @@ def forward( cls = After with R.dataflow(): lv = R.call_tir( - cls.tir_func_weight_prepack, (w,), out_sinfo=R.Tensor((4, 4, 56, 56), "float32") + cls.tir_func_weight_prepack, (w,), out_ty=R.Tensor((4, 4, 56, 56), "float32") ) lv1 = R.call_tir( - cls.tir_func_prepacked, (x, lv), out_sinfo=R.Tensor((224, 224), "float32") + cls.tir_func_prepacked, (x, lv), out_ty=R.Tensor((224, 224), "float32") ) gv: R.Tensor((224, 224), dtype="float32") = lv1 R.output(gv) diff --git a/tests/python/relax/test_transform_static_plan_block_memory.py b/tests/python/relax/test_transform_static_plan_block_memory.py index 7a658c0a355f..e181f715d7f4 100644 --- a/tests/python/relax/test_transform_static_plan_block_memory.py +++ b/tests/python/relax/test_transform_static_plan_block_memory.py @@ -160,7 +160,7 @@ def main(x: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((10,), dtype="float32 storage: R.Object = R.vm.alloc_storage(R.shape([32]), R.prim_value(0), R.dtype("uint8")) alloc: R.Tensor((2, 4), dtype="float32") = R.vm.alloc_tensor(storage, R.prim_value(0), R.shape([2, 4]), R.dtype("float32")) _: R.Tuple = cls.exp(x, alloc) - lv1: R.Tensor((8,), dtype="float32") = R.call_packed("vm.builtin.reshape", alloc, R.shape([8]), sinfo_args=(R.Tensor((8,), dtype="float32"),)) + lv1: R.Tensor((8,), dtype="float32") = R.call_packed("vm.builtin.reshape", alloc, R.shape([8]), ty_args=(R.Tensor((8,), dtype="float32"),)) _ = R.vm.kill_object(alloc) storage1: R.Object = R.vm.alloc_storage(R.shape([40]), R.prim_value(0), R.dtype("uint8")) alloc1: R.Tensor((8,), dtype="float32") = R.vm.alloc_tensor(storage1, R.prim_value(0), R.shape([8]), R.dtype("float32")) @@ -647,12 +647,12 @@ def main(x: R.Tensor((2, 3), "float32")): alloc: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor( R.shape([2, 3]), dtype="float32", runtime_device_index=0 ) - _ = R.call_packed("extern_func", x, alloc, sinfo_args=[R.Tuple()]) + _ = R.call_packed("extern_func", x, alloc, ty_args=[R.Tuple()]) y: R.Tensor((2, 3), dtype="float32") = alloc alloc1: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor( R.shape([2, 3]), dtype="float32", runtime_device_index=0 ) - _1 = R.call_packed("extern_func", y, alloc1, sinfo_args=[R.Tuple()]) + _1 = R.call_packed("extern_func", y, alloc1, ty_args=[R.Tuple()]) z: R.Tensor((2, 3), dtype="float32") = alloc1 return z @@ -666,12 +666,12 @@ def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float3 alloc: R.Tensor((2, 3), dtype="float32") = R.memory.alloc_tensor( storage, R.prim_value(0), R.shape([2, 3]), R.dtype("float32") ) - _: R.Tuple = R.call_packed("extern_func", x, alloc, sinfo_args=(R.Tuple(),)) + _: R.Tuple = R.call_packed("extern_func", x, alloc, ty_args=(R.Tuple(),)) y: R.Tensor((2, 3), dtype="float32") = alloc alloc1: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor( R.shape([2, 3]), R.dtype("float32"), R.prim_value(0) ) - _1: R.Tuple = R.call_packed("extern_func", y, alloc1, sinfo_args=(R.Tuple(),)) + _1: R.Tuple = R.call_packed("extern_func", y, alloc1, ty_args=(R.Tuple(),)) z: R.Tensor((2, 3), dtype="float32") = alloc1 return z @@ -1618,7 +1618,7 @@ def main(probs: R.Tensor(("batch_size", "vocab_size"), dtype="float32")) -> R.Te "vm.builtin.reshape", cumsum, R.shape([batch_size, vocab_size]), - sinfo_args=(R.Tensor((batch_size, vocab_size), dtype="float32"),), + ty_args=(R.Tensor((batch_size, vocab_size), dtype="float32"),), ) return lv1_1 @@ -1669,7 +1669,7 @@ def main(probs: R.Tensor(("batch_size", "vocab_size"), dtype="float32")) -> R.Te "vm.builtin.reshape", cumsum, R.shape([batch_size, vocab_size]), - sinfo_args=(R.Tensor((batch_size, vocab_size), dtype="float32"),), + ty_args=(R.Tensor((batch_size, vocab_size), dtype="float32"),), ) return lv1_1 diff --git a/tests/python/relax/test_transform_to_mixed_precision.py b/tests/python/relax/test_transform_to_mixed_precision.py index f2480d103150..226cbe048f04 100644 --- a/tests/python/relax/test_transform_to_mixed_precision.py +++ b/tests/python/relax/test_transform_to_mixed_precision.py @@ -1046,8 +1046,8 @@ class Before: def main(A: R.Tensor([64], "float16")): cls = Before with R.dataflow(): - B = R.call_tir(cls.tir_identity, [A], out_sinfo=R.Tensor([64], "float16")) - C = R.call_tir(cls.tir_identity, [B], out_sinfo=R.Tensor([64], "float16")) + B = R.call_tir(cls.tir_identity, [A], out_ty=R.Tensor([64], "float16")) + C = R.call_tir(cls.tir_identity, [B], out_ty=R.Tensor([64], "float16")) R.output(C) return C diff --git a/tests/python/relax/test_transform_update_param_struct_info.py b/tests/python/relax/test_transform_update_param_type.py similarity index 85% rename from tests/python/relax/test_transform_update_param_struct_info.py rename to tests/python/relax/test_transform_update_param_type.py index 30200955322d..0c7de2d5252c 100644 --- a/tests/python/relax/test_transform_update_param_struct_info.py +++ b/tests/python/relax/test_transform_update_param_type.py @@ -27,7 +27,7 @@ class Base: def test_compare(self): - transform = relax.transform.UpdateParamStructInfo(self.update_sinfo) + transform = relax.transform.UpdateParamType(self.update_ty) if inspect.isclass(self.Expected) and issubclass(self.Expected, Exception): with pytest.raises(self.Expected): @@ -36,15 +36,15 @@ def test_compare(self): after = transform(self.Before) tvm.ir.assert_structural_equal(self.Expected, after) - def update_sinfo(self, var: relax.Var) -> relax.StructInfo | None: - """The struct info update function provided to the transform""" + def update_ty(self, var: relax.Var) -> relax.Type | None: + """The parameter type update function provided to the transform""" raise NotImplementedError("Should be implemented in derived class") class TestSimple(Base): - def update_sinfo(self, var: relax.Var) -> relax.StructInfo | None: + def update_ty(self, var: relax.Var) -> relax.Type | None: if var.name_hint == "weight": - return relax.TensorStructInfo([64, 16], "float32") + return relax.TensorType([64, 16], "float32") @I.ir_module class Before: diff --git a/tests/python/relax/test_tvmscript_ir_builder.py b/tests/python/relax/test_tvmscript_ir_builder.py index ade2378b7937..ca3cabf445d1 100644 --- a/tests/python/relax/test_tvmscript_ir_builder.py +++ b/tests/python/relax/test_tvmscript_ir_builder.py @@ -34,33 +34,29 @@ def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2) with R.function(): R.func_name("foo") R.func_attr({"Primitive": True}) - x = R.arg("x", relax.TensorStructInfo((128, 128), "float32")) - R.func_ret_struct_info(relax.TensorStructInfo(dtype="float32", ndim=2)) + x = R.arg("x", relax.TensorType((128, 128), "float32")) + R.func_ret_ty(relax.TensorType(dtype="float32", ndim=2)) y = R.emit( - R.call_dps_packed( - "extern_func", x, relax.TensorStructInfo((128, 128), dtype="float32") - ) + R.call_dps_packed("extern_func", x, relax.TensorType((128, 128), dtype="float32")) ) out = R.emit( R.call_dps_packed( - "extern_dps_func", y, relax.TensorStructInfo((128, 128), dtype="float32") + "extern_dps_func", y, relax.TensorType((128, 128), dtype="float32") ) ) IRBuilder.name("out", out) R.func_ret_value(out) func = ir_builder.get() # create with BlockBuilder - x = relax.Var("x", relax.TensorStructInfo((128, 128), "float32")) + x = relax.Var("x", relax.TensorType((128, 128), "float32")) bb = relax.BlockBuilder() with bb.function("foo", (x,), attrs={"Primitive": True}): y = bb.emit( - relax.call_dps_packed( - "extern_func", x, relax.TensorStructInfo((128, 128), dtype="float32") - ) + relax.call_dps_packed("extern_func", x, relax.TensorType((128, 128), dtype="float32")) ) out = bb.emit( relax.call_dps_packed( - "extern_dps_func", y, relax.TensorStructInfo((128, 128), dtype="float32") + "extern_dps_func", y, relax.TensorType((128, 128), dtype="float32") ) ) bb.emit_func_output(out) @@ -88,13 +84,13 @@ def foo(x: R.Tensor(dtype="float32"), y: R.Tensor(dtype="float32")) -> R.Shape(n with IRBuilder() as ir_builder: with R.function(): R.func_name("foo") - x = R.arg("x", relax.TensorStructInfo(ndim=-1, dtype="float32")) - y = R.arg("y", relax.TensorStructInfo(ndim=-1, dtype="float32")) + x = R.arg("x", relax.TensorType(ndim=-1, dtype="float32")) + y = R.arg("y", relax.TensorType(ndim=-1, dtype="float32")) m = tirx.Var("m", dtype="int64") n = tirx.Var("n", dtype="int64") - _ = R.emit_match_cast(x, relax.TensorStructInfo((m,), "float32")) - y1 = R.emit_match_cast(y, relax.TensorStructInfo((n,), "float32")) - v = relax.Var("v", relax.TensorStructInfo((n,), "float32")) + _ = R.emit_match_cast(x, relax.TensorType((m,), "float32")) + y1 = R.emit_match_cast(y, relax.TensorType((n,), "float32")) + v = relax.Var("v", relax.TensorType((n,), "float32")) vb = relax.VarBinding(v, y1) v = R.emit_var_binding(vb) R.emit(v) @@ -106,13 +102,13 @@ def foo(x: R.Tensor(dtype="float32"), y: R.Tensor(dtype="float32")) -> R.Shape(n # create with BlockBuilder m = tirx.Var("m", dtype="int64") n = tirx.Var("n", dtype="int64") - x = relax.Var("x", relax.TensorStructInfo(dtype="float32", ndim=-1)) - y = relax.Var("y", relax.TensorStructInfo(dtype="float32", ndim=-1)) - v = relax.Var("v", relax.TensorStructInfo((n,), "float32")) + x = relax.Var("x", relax.TensorType(dtype="float32", ndim=-1)) + y = relax.Var("y", relax.TensorType(dtype="float32", ndim=-1)) + v = relax.Var("v", relax.TensorType((n,), "float32")) bb = relax.BlockBuilder() with bb.function("foo", (x, y)): - _ = bb.match_cast(x, relax.TensorStructInfo((m,), "float32")) - y1 = bb.match_cast(y, relax.TensorStructInfo((n,), "float32")) + _ = bb.match_cast(x, relax.TensorType((m,), "float32")) + y1 = bb.match_cast(y, relax.TensorType((n,), "float32")) bb.emit_normalized(relax.VarBinding(v, y1)) bb.emit(v) bb.emit_func_output(relax.ShapeExpr([m, n * 2])) @@ -136,11 +132,11 @@ def foo(x: Tensor((128, 128), "float32")) -> Tensor(None, "float32", ndim = 2): with IRBuilder() as ir_builder: with R.function(): R.func_name("foo") - x = R.arg("x", relax.TensorStructInfo((128, 128), "float32")) + x = R.arg("x", relax.TensorType((128, 128), "float32")) with R.dataflow() as df: lv0 = R.emit( R.call_dps_packed( - "extern_func", x, relax.TensorStructInfo((128, 128), dtype="float32") + "extern_func", x, relax.TensorType((128, 128), dtype="float32") ) ) IRBuilder.name("lv0", lv0) @@ -152,13 +148,13 @@ def foo(x: Tensor((128, 128), "float32")) -> Tensor(None, "float32", ndim = 2): func = ir_builder.get() # create with BlockBuilder - x = relax.Var("x", relax.TensorStructInfo((128, 128), "float32")) + x = relax.Var("x", relax.TensorType((128, 128), "float32")) bb = relax.BlockBuilder() with bb.function("foo", (x,)): with bb.dataflow(): lv0 = bb.emit( relax.call_dps_packed( - "extern_func", x, relax.TensorStructInfo((128, 128), dtype="float32") + "extern_func", x, relax.TensorType((128, 128), dtype="float32") ) ) gv = bb.emit_output(lv0) @@ -200,14 +196,14 @@ def subroutine( # create with BlockBuilder bb = relax.BlockBuilder() - A_sub = relax.Var("A", relax.TensorStructInfo((128, 128), "float32")) - B_sub = relax.Var("B", relax.TensorStructInfo((128, 128), "float32")) + A_sub = relax.Var("A", relax.TensorType((128, 128), "float32")) + B_sub = relax.Var("B", relax.TensorType((128, 128), "float32")) with bb.function("subroutine", (A_sub, B_sub)): out = bb.emit(R.add(A_sub, B_sub)) subroutine = bb.emit_func_output(out) - A = relax.Var("A", relax.TensorStructInfo((128, 128), "float32")) - B = relax.Var("B", relax.TensorStructInfo((128, 128), "float32")) + A = relax.Var("A", relax.TensorType((128, 128), "float32")) + B = relax.Var("B", relax.TensorType((128, 128), "float32")) with bb.function("main", (A, B)): out = bb.emit(subroutine(A, B)) bb.emit_func_output(out) @@ -242,11 +238,11 @@ def subroutine( # create with BlockBuilder bb = relax.BlockBuilder() - A = relax.Var("A", relax.TensorStructInfo((128, 128), "float32")) - B = relax.Var("B", relax.TensorStructInfo((128, 128), "float32")) + A = relax.Var("A", relax.TensorType((128, 128), "float32")) + B = relax.Var("B", relax.TensorType((128, 128), "float32")) with bb.function("main", (A, B)): - A_sub = relax.Var("A", relax.TensorStructInfo((128, 128), "float32")) - B_sub = relax.Var("B", relax.TensorStructInfo((128, 128), "float32")) + A_sub = relax.Var("A", relax.TensorType((128, 128), "float32")) + B_sub = relax.Var("B", relax.TensorType((128, 128), "float32")) with bb.function("subroutine", (A_sub, B_sub)): out = bb.emit(R.add(A_sub, B_sub)) subroutine = bb.emit_func_output(out) diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py index 458b9625639d..0f251940ddb8 100644 --- a/tests/python/relax/test_tvmscript_parser.py +++ b/tests/python/relax/test_tvmscript_parser.py @@ -262,7 +262,7 @@ def main(x: R.Tensor((10, 20), "float32")) -> R.Tensor((10, 20), dtype="float32" return out bb = relax.BlockBuilder() - x = relax.Var("x", relax.TensorStructInfo([10, 20], "float32")) + x = relax.Var("x", relax.TensorType([10, 20], "float32")) with bb.function("main", [x], {"global_symbol": "main"}): lv1 = bb.emit_te(topi.add, x, x) out = bb.emit_te(topi.multiply, lv1, lv1) @@ -832,7 +832,7 @@ def foo(x: R.Tensor((32, 32), "float32")) -> R.Tensor((32, 32), "float32"): def test_call_packed(): @R.function(pure=False) def foo(x: R.Tensor((32, 32), "float32")) -> R.Tensor: - z = R.call_packed("vm.builtin.copy", x, sinfo_args=R.Tensor((32, 32), "float32")) + z = R.call_packed("vm.builtin.copy", x, ty_args=R.Tensor((32, 32), "float32")) return z x = relax.Var("x", R.Tensor((32, 32), "float32")) @@ -843,7 +843,7 @@ def foo(x: R.Tensor((32, 32), "float32")) -> R.Tensor: relax.ExternFunc("vm.builtin.copy"), (x,), None, - sinfo_args=[R.Tensor((32, 32), "float32")], + ty_args=[R.Tensor((32, 32), "float32")], ) ) bb.emit_func_output(z) @@ -851,7 +851,7 @@ def foo(x: R.Tensor((32, 32), "float32")) -> R.Tensor: _check(foo, bb.get()["foo"]) -def test_call_packed_without_sinfo_args(): +def test_call_packed_without_ty_args(): @R.function(pure=False) def foo(x: R.Object) -> R.Object: z = R.call_packed("test", x) @@ -865,7 +865,7 @@ def foo(x: R.Object) -> R.Object: relax.ExternFunc("test"), (x,), None, - sinfo_args=[], + ty_args=[], ) ) bb.emit_func_output(z) @@ -885,29 +885,29 @@ def foo( w: R.Tensor(ndim=2) = R.multiply(z, z) q: R.Tensor = R.add(w, w) t = R.add(w, z) - sh: R.Shape = R.call_packed("shape_of", x, sinfo_args=R.Shape) + sh: R.Shape = R.call_packed("shape_of", x, ty_args=R.Shape) lv: R.Tensor(sh, dtype="float32") = R.reshape(x, sh) - o: R.Object = R.call_packed("contrib.tensor_array_stack", x, y, sinfo_args=R.Object) + o: R.Object = R.call_packed("contrib.tensor_array_stack", x, y, ty_args=R.Object) return o - def _check_struct_info(binding, expected_sinfo): - tvm.ir.assert_structural_equal(binding.var.struct_info, expected_sinfo) - tvm.ir.assert_structural_equal(binding.value.struct_info, expected_sinfo) + def _check_ty(binding, expected_ty): + tvm.ir.assert_structural_equal(binding.var.ty, expected_ty) + tvm.ir.assert_structural_equal(binding.value.ty, expected_ty) # Cannot use block builder here because we need to check the annotated type, # which may be inconsistent with deduced type. - assert isinstance(foo.ret_struct_info, relax.ObjectStructInfo) + assert isinstance(foo.ret_ty, relax.ObjectType) m = relax.get_shape_of(foo.params[0])[1] bindings = foo.body.blocks[0].bindings sh = bindings[4].var - _check_struct_info(bindings[0], relax.TensorStructInfo([32, m], "float32")) - _check_struct_info(bindings[1], relax.TensorStructInfo(dtype="", ndim=2)) - _check_struct_info(bindings[2], relax.TensorStructInfo(dtype="", ndim=-1)) - _check_struct_info(bindings[3], relax.TensorStructInfo(dtype="", ndim=2)) - _check_struct_info(bindings[4], relax.ShapeStructInfo(ndim=-1)) - _check_struct_info(bindings[5], relax.TensorStructInfo(sh)) - _check_struct_info(bindings[6], relax.ObjectStructInfo()) + _check_ty(bindings[0], relax.TensorType([32, m], "float32")) + _check_ty(bindings[1], relax.TensorType(dtype="", ndim=2)) + _check_ty(bindings[2], relax.TensorType(dtype="", ndim=-1)) + _check_ty(bindings[3], relax.TensorType(dtype="", ndim=2)) + _check_ty(bindings[4], relax.ShapeType(ndim=-1)) + _check_ty(bindings[5], relax.TensorType(sh)) + _check_ty(bindings[6], relax.ObjectType()) def test_annotate_override(): @@ -918,28 +918,28 @@ def foo(x: R.Tensor): z: R.Object = R.add(x, y) return z - assert isinstance(foo.ret_struct_info, relax.ObjectStructInfo) + assert isinstance(foo.ret_ty, relax.ObjectType) y_bind, z_bind = foo.body.blocks[0].bindings - assert isinstance(y_bind.var.struct_info, relax.TensorStructInfo) - assert isinstance(z_bind.var.struct_info, relax.ObjectStructInfo) + assert isinstance(y_bind.var.ty, relax.TensorType) + assert isinstance(z_bind.var.ty, relax.ObjectType) with pytest.raises(tvm.error.DiagnosticError): @R.function def test(x: R.Tensor): - # Error: x is of Tensor StructInfo, which can not annotate to R.Shape. + # Error: x is of Tensor Type, which can not annotate to R.Shape. z: R.Shape = x return z @R.function def bar(x: R.Tensor): - # x is of Tensor StructInfo, the annotation of `z` is ignored. + # x is of Tensor Type, the annotation of `z` is ignored. z: R.Object = x return z - assert isinstance(bar.ret_struct_info, relax.TensorStructInfo) + assert isinstance(bar.ret_ty, relax.TensorType) (z_bind,) = bar.body.blocks[0].bindings - assert isinstance(z_bind.var.struct_info, relax.TensorStructInfo) + assert isinstance(z_bind.var.ty, relax.TensorType) def test_call_dps_packed_empty_shape(): @@ -949,7 +949,7 @@ def foo(x: R.Tensor((), "float32")): return z (z_bind,) = foo.body.blocks[0].bindings - shape_expr = z_bind.value.sinfo_args[0].shape + shape_expr = z_bind.value.ty_args[0].shape assert isinstance(shape_expr, relax.ShapeExpr) assert len(shape_expr.values) == 0 @@ -1067,7 +1067,7 @@ def main(x: R.Tensor((2, 3), "int32"), y: R.Tensor((2, 3), "int32")): # caught and raised during parsing. args, inplace_indices=[0, -1], - out_sinfo=[R.Tensor((2, 3), "int32"), R.Tensor((2, 3), "int32")], + out_ty=[R.Tensor((2, 3), "int32"), R.Tensor((2, 3), "int32")], ) return res @@ -1256,7 +1256,7 @@ def func(A: R.Tensor([16, 16]), is_bfloat16: R.Prim("bool")): # If the `R.match_cast` were removed, the function would infer the # return value as `R.Tensor([16,16])`, with an unknown dtype. # With the `R.match_cast` retained, the output dtype is known. - tvm.ir.assert_structural_equal(func.ret_struct_info, R.Tensor([16, 16], "float16")) + tvm.ir.assert_structural_equal(func.ret_ty, R.Tensor([16, 16], "float16")) def test_if_inside_dataflow(): @@ -1302,7 +1302,7 @@ def func(cond: R.Tensor([], "bool"), x: R.Tensor((1,), "float32")): if_else = func.body.blocks[0].bindings[0].value assert isinstance(if_else.cond, relax.Var) - tvm.ir.assert_structural_equal(if_else.cond.struct_info, R.Tensor([], "bool")) + tvm.ir.assert_structural_equal(if_else.cond.ty, R.Tensor([], "bool")) def test_prim_value_as_branch_condition(): @@ -1318,7 +1318,7 @@ def func(cond: R.Prim("bool"), x: R.Tensor((1,), "float32")): if_else = func.body.blocks[0].bindings[0].value assert isinstance(if_else.cond, relax.Var) - tvm.ir.assert_structural_equal(if_else.cond.struct_info, R.Prim("bool")) + tvm.ir.assert_structural_equal(if_else.cond.ty, R.Prim("bool")) def test_computed_prim_value_as_branch_condition(): @@ -1328,16 +1328,16 @@ def test_computed_prim_value_as_branch_condition(): def func(x: R.Tensor(["N"], "float32")): N = T.int64() if R.prim_value(N % 16 == 0): - out = R.call_pure_packed("fast_vectorized_impl", x, sinfo_args=[x.struct_info]) + out = R.call_pure_packed("fast_vectorized_impl", x, ty_args=[x.ty]) else: - out = R.call_pure_packed("slow_non_vectorized_impl", x, sinfo_args=[x.struct_info]) + out = R.call_pure_packed("slow_non_vectorized_impl", x, ty_args=[x.ty]) return out - N = func.params[0].struct_info.shape[0] + N = func.params[0].ty.shape[0] if_else = func.body.blocks[0].bindings[0].value assert isinstance(if_else.cond, relax.PrimValue) tvm.ir.assert_structural_equal(N % 16 == 0, if_else.cond.value) - tvm.ir.assert_structural_equal(if_else.cond.struct_info, R.Prim(value=N % 16 == 0)) + tvm.ir.assert_structural_equal(if_else.cond.ty, R.Prim(value=N % 16 == 0)) def test_tir_expr_as_branch_condition(): @@ -1347,18 +1347,18 @@ def test_tir_expr_as_branch_condition(): def sugared(x: R.Tensor(["N"], "float32")): N = T.int64() if N % 16 == 0: - out = R.call_pure_packed("fast_vectorized_impl", x, sinfo_args=[x.struct_info]) + out = R.call_pure_packed("fast_vectorized_impl", x, ty_args=[x.ty]) else: - out = R.call_pure_packed("slow_non_vectorized_impl", x, sinfo_args=[x.struct_info]) + out = R.call_pure_packed("slow_non_vectorized_impl", x, ty_args=[x.ty]) return out @R.function(private=True) def unsugared(x: R.Tensor(["N"], "float32")): N = T.int64() if R.prim_value(N % 16 == 0): - out = R.call_pure_packed("fast_vectorized_impl", x, sinfo_args=[x.struct_info]) + out = R.call_pure_packed("fast_vectorized_impl", x, ty_args=[x.ty]) else: - out = R.call_pure_packed("slow_non_vectorized_impl", x, sinfo_args=[x.struct_info]) + out = R.call_pure_packed("slow_non_vectorized_impl", x, ty_args=[x.ty]) return out tvm.ir.assert_structural_equal(unsugared, sugared) @@ -1376,7 +1376,7 @@ def func(cond: R.Tensor([], "bool"), x: R.Tensor((1,), "float32")): assert_op = func.body.blocks[0].bindings[0].value condition = assert_op.args[0] assert isinstance(condition, relax.Var) - tvm.ir.assert_structural_equal(condition.struct_info, R.Tensor([], "bool")) + tvm.ir.assert_structural_equal(condition.ty, R.Tensor([], "bool")) def test_prim_value_as_assert_condition(): @@ -1391,7 +1391,7 @@ def func(cond: R.Prim("bool"), x: R.Tensor((1,), "float32")): assert_op = func.body.blocks[0].bindings[0].value condition = assert_op.args[0] assert isinstance(condition, relax.Var) - tvm.ir.assert_structural_equal(condition.struct_info, R.Prim("bool")) + tvm.ir.assert_structural_equal(condition.ty, R.Prim("bool")) def test_computed_prim_value_as_assert_condition(): @@ -1401,15 +1401,15 @@ def test_computed_prim_value_as_assert_condition(): def func(x: R.Tensor(["N"], "float32")): N = T.int64() _ = R.assert_op(R.prim_value(N % 16 == 0)) - out = R.call_packed("fast_vectorized_impl", x, sinfo_args=[x.struct_info]) + out = R.call_packed("fast_vectorized_impl", x, ty_args=[x.ty]) return out - N = func.params[0].struct_info.shape[0] + N = func.params[0].ty.shape[0] assert_op = func.body.blocks[0].bindings[0].value condition = assert_op.args[0] assert isinstance(condition, relax.PrimValue) tvm.ir.assert_structural_equal(N % 16 == 0, condition.value) - tvm.ir.assert_structural_equal(condition.struct_info, R.Prim(value=N % 16 == 0)) + tvm.ir.assert_structural_equal(condition.ty, R.Prim(value=N % 16 == 0)) def test_tir_expr_as_assert_condition(): @@ -1419,14 +1419,14 @@ def test_tir_expr_as_assert_condition(): def sugared(x: R.Tensor(["N"], "float32")): N = T.int64() _ = R.assert_op(N % 16 == 0) - out = R.call_packed("fast_vectorized_impl", x, sinfo_args=[x.struct_info]) + out = R.call_packed("fast_vectorized_impl", x, ty_args=[x.ty]) return out @R.function(pure=False, private=True) def unsugared(x: R.Tensor(["N"], "float32")): N = T.int64() _ = R.assert_op(R.prim_value(N % 16 == 0)) - out = R.call_packed("fast_vectorized_impl", x, sinfo_args=[x.struct_info]) + out = R.call_packed("fast_vectorized_impl", x, ty_args=[x.ty]) return out tvm.ir.assert_structural_equal(unsugared, sugared) @@ -1441,8 +1441,8 @@ def foo(x: R.Tensor): w = z return w - tvm.ir.assert_structural_equal(foo.ret_struct_info, R.Tensor(ndim=2)) - assert foo.ret_struct_info.shape is None + tvm.ir.assert_structural_equal(foo.ret_ty, R.Tensor(ndim=2)) + assert foo.ret_ty.shape is None _check(foo) @@ -1455,7 +1455,7 @@ def foo(x: R.Tensor(["m", "n"])): w = z return w - assert foo.ret_struct_info.shape is not None + assert foo.ret_ty.shape is not None _check(foo) @@ -1468,7 +1468,7 @@ def foo(x: R.Tensor, _: R.Shape(["m", "n"])): w = z return w - assert foo.ret_struct_info.shape is not None + assert foo.ret_ty.shape is not None _check(foo) @@ -1481,7 +1481,7 @@ def foo(x: R.Tensor, _m: R.Prim(value="m"), _n: R.Prim(value="n")): w = z return w - assert foo.ret_struct_info.shape is not None + assert foo.ret_ty.shape is not None _check(foo) @@ -1506,7 +1506,7 @@ def main(x: R.Tensor, shape: R.Shape(["m", "n"])): output = Module.subroutine(x, shape) return output - assert Module["main"].ret_struct_info.shape is not None + assert Module["main"].ret_ty.shape is not None _check(Module) @@ -1533,7 +1533,7 @@ def main(x: R.Tensor, relax_m: R.Prim(value="m"), relax_n: R.Prim(value="n")): output = Module.subroutine(x, relax_m, relax_n) return output - assert Module["main"].ret_struct_info.shape is not None + assert Module["main"].ret_ty.shape is not None _check(Module) @@ -1543,7 +1543,7 @@ def foo(x: R.Tuple()): y: R.Tuple() = R.tuple() return y - x = relax.Var("x", relax.TupleStructInfo([])) + x = relax.Var("x", relax.TupleType([])) bb = relax.BlockBuilder() with bb.function("foo", (x,)): y = bb.emit(relax.Tuple([])) @@ -1561,8 +1561,8 @@ def foo(x: R.Tensor(("m + 1",), "float32"), y: R.Tensor(("m", 1), "float32")): return z m = tirx.Var("m", "int64") - x = relax.Var("x", relax.TensorStructInfo([m + 1], "float32")) - y = relax.Var("y", relax.TensorStructInfo([m, 1], "float32")) + x = relax.Var("x", relax.TensorType([m + 1], "float32")) + y = relax.Var("y", relax.TensorType([m, 1], "float32")) bb = relax.BlockBuilder() with bb.function("foo", (x, y)): z = bb.emit(relax.op.add(x, y)) @@ -1583,8 +1583,8 @@ def bar(x: R.Tensor(("m",), "float32"), y: R.Tensor(("T.max(m, 20)",), "float32" return z m = tirx.Var("m", "int64") - x = relax.Var("x", relax.TensorStructInfo([m], "float32")) - y = relax.Var("y", relax.TensorStructInfo([tirx.max(m, 20)], "float32")) + x = relax.Var("x", relax.TensorType([m], "float32")) + y = relax.Var("y", relax.TensorType([tirx.max(m, 20)], "float32")) bb = relax.BlockBuilder() with bb.function("bar", (x, y)): z = bb.emit( @@ -1607,8 +1607,8 @@ def baz(x: R.Shape(("m",)), y: R.Tensor(("m * 2",), "float32")): return z m = tirx.Var("m", "int64") - x = relax.Var("x", relax.ShapeStructInfo([m])) - y = relax.Var("y", relax.TensorStructInfo([m * 2], "float32")) + x = relax.Var("x", relax.ShapeType([m])) + y = relax.Var("y", relax.TensorType([m * 2], "float32")) bb = relax.BlockBuilder() with bb.function("baz", (x, y)): z = bb.emit(relax.call_dps_packed("test_intrin", (y), R.Tensor((m * 2,), dtype="float32"))) @@ -1627,8 +1627,8 @@ def baz(x: R.Prim(value="m"), y: R.Tensor(("m * 2",), "float32")): return z m = tirx.Var("m", "int64") - x = relax.Var("x", relax.PrimStructInfo(value=m)) - y = relax.Var("y", relax.TensorStructInfo([m * 2], "float32")) + x = relax.Var("x", relax.PrimType(value=m)) + y = relax.Var("y", relax.TensorType([m * 2], "float32")) bb = relax.BlockBuilder() with bb.function("baz", (x, y)): z = bb.emit(relax.call_dps_packed("test_intrin", (y), R.Tensor((m * 2,), dtype="float32"))) @@ -1678,8 +1678,8 @@ def foo(x: R.Tensor(("m", "n"), "float32"), y: R.Tensor(("m", "n"), "float32")): m = tirx.Var("m", "int64") n = tirx.Var("n", "int64") - x = relax.Var("x", relax.TensorStructInfo([m, n], "float32")) - y = relax.Var("y", relax.TensorStructInfo([m, n], "float32")) + x = relax.Var("x", relax.TensorType([m, n], "float32")) + y = relax.Var("y", relax.TensorType([m, n], "float32")) bb = relax.BlockBuilder() with bb.function("foo", (x, y)): a0 = bb.emit(relax.op.negative(x)) @@ -1748,7 +1748,7 @@ def foo(x: R.Tensor(("m", "n"), dtype="float32")): def test_prim_value(): @R.function(pure=False) def foo(): - gv = R.call_packed("test", 1, sinfo_args=R.Tensor((32, 32), "float32")) + gv = R.call_packed("test", 1, ty_args=R.Tensor((32, 32), "float32")) return gv _check(foo) @@ -1757,7 +1757,7 @@ def foo(): def test_string_imm(): @R.function(pure=False) def foo(): - gv = R.call_packed("test", "hello", sinfo_args=R.Tensor((32, 32), "float32")) + gv = R.call_packed("test", "hello", ty_args=R.Tensor((32, 32), "float32")) return gv _check(foo) @@ -1766,7 +1766,7 @@ def foo(): def test_datatype_imm(): @R.function(pure=False) def foo(): - gv = R.call_packed("test", R.dtype("float32"), sinfo_args=R.Tensor((32, 32), "float32")) + gv = R.call_packed("test", R.dtype("float32"), ty_args=R.Tensor((32, 32), "float32")) return gv _check(foo) @@ -1788,8 +1788,8 @@ def mul(x: R.Tensor((3, 3), dtype="float32")): _check(Foo) # Since the return type of function `mul` is not annotated, # the function `main` regards it as a generic return type. - assert isinstance(Foo["main"].ret_struct_info, relax.ObjectStructInfo) - assert isinstance(Foo["mul"].ret_struct_info, relax.TensorStructInfo) + assert isinstance(Foo["main"].ret_ty, relax.ObjectType) + assert isinstance(Foo["mul"].ret_ty, relax.TensorType) @tvm.script.ir_module class Bar: @@ -1806,8 +1806,8 @@ def mul(x: R.Tensor((3, 3), dtype="float32")) -> None: # Since the return type of function `mul` is not annotated, # the function `main` regards it as a generic return type. _check(Bar) - tvm.ir.assert_structural_equal(Bar["main"].ret_struct_info, relax.TupleStructInfo([])) - tvm.ir.assert_structural_equal(Bar["mul"].ret_struct_info, relax.TupleStructInfo([])) + tvm.ir.assert_structural_equal(Bar["main"].ret_ty, relax.TupleType([])) + tvm.ir.assert_structural_equal(Bar["mul"].ret_ty, relax.TupleType([])) def test_class_normalize(): @@ -1881,7 +1881,7 @@ def main(input: R.Tensor((5, 5))) -> R.Tuple(): _check(Module) -def test_global_var_sinfo(): +def test_global_var_ty(): @I.ir_module(s_tir=True) class Module: @R.function @@ -1889,12 +1889,12 @@ def foo(x: R.Tensor((128, 128), "float32")): gv0 = R.emit_te(topi.add, x, x) return gv0 - target_sinfo = R.Callable( + target_ty = R.Callable( (R.Tensor((128, 128), dtype="float32"),), R.Tensor((128, 128), dtype="float32") ) gv = Module.get_global_var("foo") - tvm.ir.assert_structural_equal(gv.struct_info, target_sinfo) - tvm.ir.assert_structural_equal(Module["foo"].struct_info, target_sinfo) + tvm.ir.assert_structural_equal(gv.ty, target_ty) + tvm.ir.assert_structural_equal(Module["foo"].ty, target_ty) _check(Module) @@ -2068,14 +2068,14 @@ def conditional(x: R.Tensor((), "int32"), condition: R.Tensor((), "bool")) -> R. def test_call_pure_packed(): @R.function def foo(x: R.Tensor((32, 32), "float32")) -> R.Tensor: - z = R.call_pure_packed("vm.builtin.copy", x, sinfo_args=R.Tensor((32, 32), "float32")) + z = R.call_pure_packed("vm.builtin.copy", x, ty_args=R.Tensor((32, 32), "float32")) return z x = relax.Var("x", R.Tensor((32, 32), "float32")) bb = relax.BlockBuilder() with bb.function("foo", (x)): z = bb.emit( - R.call_pure_packed("vm.builtin.copy", x, sinfo_args=[R.Tensor((32, 32), "float32")]) + R.call_pure_packed("vm.builtin.copy", x, ty_args=[R.Tensor((32, 32), "float32")]) ) bb.emit_func_output(z) @@ -2085,12 +2085,12 @@ def foo(x: R.Tensor((32, 32), "float32")) -> R.Tensor: def test_call_pure_packed_returning_object(): @R.function def foo() -> R.Object: - z = R.call_pure_packed("dummy_func", sinfo_args=R.Object) + z = R.call_pure_packed("dummy_func", ty_args=R.Object) return z bb = relax.BlockBuilder() with bb.function("foo", params=[]): - z = bb.emit(R.call_pure_packed("dummy_func", sinfo_args=[relax.ObjectStructInfo()])) + z = bb.emit(R.call_pure_packed("dummy_func", ty_args=[relax.ObjectType()])) bb.emit_func_output(z) _check(foo, bb.get()["foo"]) @@ -2236,8 +2236,8 @@ def parsed(x: R.Tensor((128, 128), "float32")) -> R.Tensor((128, 128), "float32" bb = relax.BlockBuilder() with bb.function("main", [x], private=True): func = bb.emit(relax.ExternFunc("extern_func")) - y = bb.emit(relax.call_dps_packed(func, x, out_sinfo=R.Tensor((128, 128), "float32"))) - z = bb.emit(relax.call_dps_packed(func, y, out_sinfo=R.Tensor((128, 128), "float32"))) + y = bb.emit(relax.call_dps_packed(func, x, out_ty=R.Tensor((128, 128), "float32"))) + z = bb.emit(relax.call_dps_packed(func, y, out_ty=R.Tensor((128, 128), "float32"))) bb.emit_func_output(z) expected = bb.get()["main"] @@ -2334,7 +2334,7 @@ def test_function_symbolic_variables_are_annotated(): """ @R.function(private=True) - def inferred_sinfo(A: R.Tensor(["extent"])): + def inferred_ty(A: R.Tensor(["extent"])): extent = T.int64() output = R.strided_slice(A, [0], [0], [extent - 1]) return output @@ -2345,7 +2345,7 @@ def expected(A: R.Tensor(["extent"])) -> R.Tensor(["extent-1"]): output: R.Tensor([extent - 1]) = R.strided_slice(A, [0], [0], [extent - 1]) return output - tvm.ir.assert_structural_equal(inferred_sinfo, expected) + tvm.ir.assert_structural_equal(inferred_ty, expected) def test_conditional_may_use_symbolic_variables_from_function_scope(): @@ -2362,7 +2362,7 @@ def test_conditional_may_use_symbolic_variables_from_function_scope(): """ @R.function(private=True) - def explicit_sinfo( + def explicit_ty( A: R.Tensor(["N"], "float32"), B: R.Tensor(["N"], "float32"), cond: R.Prim("bool"), @@ -2377,7 +2377,7 @@ def explicit_sinfo( return out @R.function(private=True) - def inferred_sinfo( + def inferred_ty( A: R.Tensor(["N"], "float32"), B: R.Tensor(["N"], "float32"), cond: R.Prim("bool"), @@ -2390,7 +2390,7 @@ def inferred_sinfo( return out - tvm.ir.assert_structural_equal(explicit_sinfo, inferred_sinfo) + tvm.ir.assert_structural_equal(explicit_ty, inferred_ty) def test_return_from_dataflow_block(): diff --git a/tests/python/relax/test_tvmscript_printer_relax.py b/tests/python/relax/test_tvmscript_printer_relax.py index 425426a6b1da..012aac8c5567 100644 --- a/tests/python/relax/test_tvmscript_printer_relax.py +++ b/tests/python/relax/test_tvmscript_printer_relax.py @@ -99,14 +99,12 @@ def func(a: R.Tensor((10, 10))) -> R.Tensor((10, 10)): ) -def test_extern_func_with_struct_info(): +def test_extern_func_with_ty(): obj = IRModule( { "my_ext": relax.ExternFunc( "my_ext", - relax.FuncStructInfo( - [], relax.TensorStructInfo(dtype="float32", ndim=2), purity=True - ), + relax.FuncType([], relax.TensorType(dtype="float32", ndim=2), purity=True), ), } ) @@ -123,14 +121,12 @@ class Module: ) -def test_extern_func_with_struct_info_roundtrip(): +def test_extern_func_with_ty_roundtrip(): mod = IRModule( { "my_ext": relax.ExternFunc( "my_ext", - relax.FuncStructInfo( - [], relax.TensorStructInfo(dtype="float32", ndim=2), purity=True - ), + relax.FuncType([], relax.TensorType(dtype="float32", ndim=2), purity=True), ), } ) @@ -172,31 +168,31 @@ def nested(y: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"): ) -def test_object_struct_info(): - obj = relax.ObjectStructInfo() +def test_object_ty(): + obj = relax.ObjectType() _assert_print( obj, "R.Object", ) -def test_prim_struct_info(): - obj = relax.PrimStructInfo("float32") +def test_prim_ty(): + obj = relax.PrimType("float32") _assert_print(obj, 'R.Prim("float32")') -def test_shape_struct_info_0(): - obj = relax.ShapeStructInfo(ndim=-1) +def test_shape_ty_0(): + obj = relax.ShapeType(ndim=-1) _assert_print(obj, "R.Shape(ndim=-1)") -def test_shape_struct_info_1(): - obj = relax.ShapeStructInfo([1, 2, 3]) +def test_shape_ty_1(): + obj = relax.ShapeType([1, 2, 3]) _assert_print(obj, "R.Shape([1, 2, 3])") -def test_shape_struct_info_2(): - obj = relax.ShapeStructInfo([1, tirx.Var("a", "int64"), 3]) +def test_shape_ty_2(): + obj = relax.ShapeType([1, tirx.Var("a", "int64"), 3]) _assert_print( obj, """ @@ -205,8 +201,8 @@ def test_shape_struct_info_2(): ) -def test_tensor_struct_info(): - obj = relax.TensorStructInfo( +def test_tensor_ty(): + obj = relax.TensorType( shape=relax.ShapeExpr([1, tirx.Var("a", "int64"), 3]), dtype="float32", ) @@ -219,37 +215,36 @@ def test_tensor_struct_info(): ) -def test_tuple_struct_info_empty(): - obj = relax.TupleStructInfo([]) - _assert_print(obj, "R.Tuple") +def test_tuple_ty_empty(): + obj = relax.TupleType([]) + _assert_print(obj._relax_script(), "R.Tuple") # pylint: disable=protected-access -def test_tuple_struct_info(): - obj = relax.TupleStructInfo( +def test_tuple_ty(): + obj = relax.TupleType( [ - relax.PrimStructInfo("float32"), - relax.ObjectStructInfo(), - relax.ShapeStructInfo([1, tirx.Var("a", "int64"), 3]), + relax.PrimType("float32"), + relax.ObjectType(), + relax.ShapeType([1, tirx.Var("a", "int64"), 3]), ] ) _assert_print( - obj, + obj._relax_script(), # pylint: disable=protected-access """ -a = T.int64() R.Tuple(R.Prim("float32"), R.Object, R.Shape([1, a, 3])) """, ) -def test_func_struct_info(): - obj = relax.FuncStructInfo( +def test_func_ty(): + obj = relax.FuncType( params=[ - relax.PrimStructInfo("float32"), - relax.ObjectStructInfo(), - relax.ShapeStructInfo([1, tirx.Var("a", "int64"), 3]), - relax.PrimStructInfo(value=tirx.Var("b", "int64")), + relax.PrimType("float32"), + relax.ObjectType(), + relax.ShapeType([1, tirx.Var("a", "int64"), 3]), + relax.PrimType(value=tirx.Var("b", "int64")), ], - ret=relax.TensorStructInfo( + ret=relax.TensorType( shape=relax.ShapeExpr([1, 2, 3]), dtype="float32", ), @@ -275,7 +270,7 @@ def test_object_type(): def test_dyn_tensor_type(): obj = relax.TensorType() - _assert_print(obj, 'R.Tensor(ndim=-1, dtype="float32")') + _assert_print(obj, 'R.Tensor(dtype="float32")') def test_packed_func_type(): @@ -293,18 +288,18 @@ def test_tuple_type(): def test_func_type(): obj = relax.FuncType( - arg_types=[ + params=[ relax.ObjectType(), relax.ShapeType(ndim=3), ], - ret_type=relax.TensorType( + ret=relax.TensorType( ndim=3, dtype="float32", ), ) _assert_print( obj._relax_script(), # pylint: disable=protected-access - 'R.Callable((R.Object, R.Shape(ndim=3)), R.Tensor(ndim=3, dtype="float32"))', + 'R.Callable((R.Object, R.Shape(ndim=3)), R.Tensor(dtype="float32", ndim=3), True)', ) @@ -324,7 +319,7 @@ def test_data_type_imm(): def test_var(): - obj = relax.Var("a", relax.TensorStructInfo([1, tirx.Var("x", "int64"), 3], "float32")) + obj = relax.Var("a", relax.TensorType([1, tirx.Var("x", "int64"), 3], "float32")) _assert_print( obj, """ @@ -335,7 +330,7 @@ def test_var(): def test_dataflow_var(): - obj = relax.DataflowVar("a", relax.TensorStructInfo([1, tirx.Var("x", "int64"), 3], "float32")) + obj = relax.DataflowVar("a", relax.TensorType([1, tirx.Var("x", "int64"), 3], "float32")) _assert_print( obj, """ @@ -348,9 +343,9 @@ def test_dataflow_var(): def test_tuple(): obj = relax.Tuple( [ - relax.Var("a", relax.TensorStructInfo([1, tirx.Var("x", "int64"), 3], "float32")), - relax.Var("b", relax.TensorStructInfo([1, tirx.Var("y", "int64"), 3], "float32")), - relax.Var("c", relax.TensorStructInfo([1, tirx.Var("z", "int64"), 3], "float32")), + relax.Var("a", relax.TensorType([1, tirx.Var("x", "int64"), 3], "float32")), + relax.Var("b", relax.TensorType([1, tirx.Var("y", "int64"), 3], "float32")), + relax.Var("c", relax.TensorType([1, tirx.Var("z", "int64"), 3], "float32")), ] ) _assert_print( @@ -371,9 +366,9 @@ def test_tuple_get_item(): obj = relax.TupleGetItem( relax.Tuple( [ - relax.Var("a", relax.TensorStructInfo([1, tirx.Var("x", "int64"), 3], "float32")), - relax.Var("b", relax.TensorStructInfo([1, tirx.Var("y", "int64"), 3], "float32")), - relax.Var("c", relax.TensorStructInfo([1, tirx.Var("z", "int64"), 3], "float32")), + relax.Var("a", relax.TensorType([1, tirx.Var("x", "int64"), 3], "float32")), + relax.Var("b", relax.TensorType([1, tirx.Var("y", "int64"), 3], "float32")), + relax.Var("c", relax.TensorType([1, tirx.Var("z", "int64"), 3], "float32")), ] ), 0, @@ -399,15 +394,15 @@ def test_shape_expr(): def test_call(): x = tirx.Var("x", "int64") - a = relax.Var("a", relax.TensorStructInfo([1, x, 3], "float32")) - o0 = relax.call_tir(relax.GlobalVar("tir_func"), args=a, out_sinfo=a.struct_info, tir_vars=[x]) - o1 = relax.call_dps_packed("my_dps_func", args=a, out_sinfo=a.struct_info) + a = relax.Var("a", relax.TensorType([1, x, 3], "float32")) + o0 = relax.call_tir(relax.GlobalVar("tir_func"), args=a, out_ty=a.ty, tir_vars=[x]) + o1 = relax.call_dps_packed("my_dps_func", args=a, out_ty=a.ty) _assert_print( o0, """ x = T.int64() a: R.Tensor((1, x, 3), dtype="float32") -R.call_tir(tir_func, (a,), out_sinfo=R.Tensor((1, x, 3), dtype="float32"), tir_vars=R.shape([x])) +R.call_tir(tir_func, (a,), out_ty=R.Tensor((1, x, 3), dtype="float32"), tir_vars=R.shape([x])) """, ) _assert_print( @@ -415,7 +410,7 @@ def test_call(): """ x = T.int64() a: R.Tensor((1, x, 3), dtype="float32") -R.call_dps_packed("my_dps_func", (a,), out_sinfo=R.Tensor((1, x, 3), dtype="float32")) +R.call_dps_packed("my_dps_func", (a,), out_ty=R.Tensor((1, x, 3), dtype="float32")) """, ) @@ -436,7 +431,7 @@ def test_call_tir_with_grad(): """ v0: R.Tensor((54, 96), dtype="float32") x = T.int64() -R.call_tir_with_grad(tir_func, (v0,), out_sinfo=R.Tensor((54, 96), dtype="float32"), te_grad_name="grad_func", te_grad_kwargs={"k": 1.0, "x": x}) +R.call_tir_with_grad(tir_func, (v0,), out_ty=R.Tensor((54, 96), dtype="float32"), te_grad_name="grad_func", te_grad_kwargs={"k": 1.0, "x": x}) """, ) @@ -452,7 +447,7 @@ def test_call_tir_inplace(): y, ), inplace_indices=[-1, 0], - out_sinfo=[R.Tensor((32, 32), dtype="int32"), R.Tensor((32, 32), dtype="int32")], + out_ty=[R.Tensor((32, 32), dtype="int32"), R.Tensor((32, 32), dtype="int32")], tir_vars=[t], ) _assert_print( @@ -461,16 +456,16 @@ def test_call_tir_inplace(): x: R.Tensor((32, 32), dtype="int32") y: R.Tensor((32, 32), dtype="int32") t = T.int64() -R.call_tir_inplace(tir_func, (x, y), out_sinfo=[R.Tensor((32, 32), dtype="int32"), R.Tensor((32, 32), dtype="int32")], inplace_indices=[-1, 0], tir_vars=R.shape([t])) +R.call_tir_inplace(tir_func, (x, y), out_ty=[R.Tensor((32, 32), dtype="int32"), R.Tensor((32, 32), dtype="int32")], inplace_indices=[-1, 0], tir_vars=R.shape([t])) """, ) def test_seq_expr(): x = tirx.Var("x", "int64") - a = relax.Var("a", relax.TensorStructInfo([1, x, 3], "float32")) - b = relax.DataflowVar("b", relax.TensorStructInfo([1, x, 3], "float32")) - c = relax.Var("c", relax.TensorStructInfo([1, x, 3], "float32")) + a = relax.Var("a", relax.TensorType([1, x, 3], "float32")) + b = relax.DataflowVar("b", relax.TensorType([1, x, 3], "float32")) + c = relax.Var("c", relax.TensorType([1, x, 3], "float32")) obj = relax.SeqExpr( blocks=[ @@ -499,9 +494,9 @@ def test_seq_expr(): def test_binding_block(): x = tirx.Var("x", "int64") - a = relax.Var("a", relax.TensorStructInfo([1, x, 3], "float32")) - b = relax.Var("b", relax.TensorStructInfo([1, x, 3], "float32")) - c = relax.Var("c", relax.TensorStructInfo([1, x, 3], "float32")) + a = relax.Var("a", relax.TensorType([1, x, 3], "float32")) + b = relax.Var("b", relax.TensorType([1, x, 3], "float32")) + c = relax.Var("c", relax.TensorType([1, x, 3], "float32")) obj = relax.BindingBlock( bindings=[ relax.VarBinding(b, relax.op.sin(a)), @@ -521,9 +516,9 @@ def test_binding_block(): def test_dataflow_block(): x = tirx.Var("x", "int64") - a = relax.Var("a", relax.TensorStructInfo([1, x, 3], "float32")) - b = relax.DataflowVar("b", relax.TensorStructInfo([1, x, 3], "float32")) - c = relax.Var("c", relax.TensorStructInfo([1, x, 3], "float32")) + a = relax.Var("a", relax.TensorType([1, x, 3], "float32")) + b = relax.DataflowVar("b", relax.TensorType([1, x, 3], "float32")) + c = relax.Var("c", relax.TensorType([1, x, 3], "float32")) obj = relax.DataflowBlock( bindings=[ relax.VarBinding(b, relax.op.sin(a)), @@ -545,12 +540,12 @@ def test_dataflow_block(): def test_match_cast(): x = tirx.Var("x", "int64") - a = relax.Var("a", relax.TensorStructInfo([1, x, 3])) - b = relax.Var("b", relax.TensorStructInfo([1, 5, 3])) + a = relax.Var("a", relax.TensorType([1, x, 3])) + b = relax.Var("b", relax.TensorType([1, 5, 3])) obj = relax.MatchCast( var=b, value=a, - struct_info=b.struct_info, + ty=b.ty, ) _assert_print( obj, @@ -564,8 +559,8 @@ def test_match_cast(): def test_var_binding(): x = tirx.Var("x", "int64") - a = relax.Var("a", relax.TensorStructInfo([1, x, 3], "float32")) - b = relax.Var("b", relax.TensorStructInfo([1, x, 3], "float32")) + a = relax.Var("a", relax.TensorType([1, x, 3], "float32")) + b = relax.Var("b", relax.TensorType([1, x, 3], "float32")) obj = relax.VarBinding(b, relax.op.sin(a)) _assert_print( obj, @@ -578,9 +573,9 @@ def test_var_binding(): def test_if(): - a = relax.Var("a", relax.TensorStructInfo([], "bool")) - b = relax.Var("b", relax.TensorStructInfo([1, 2, 3], "float32")) - c = relax.Var("c", relax.TensorStructInfo([1, 2, 3], "float32")) + a = relax.Var("a", relax.TensorType([], "bool")) + b = relax.Var("b", relax.TensorType([1, 2, 3], "float32")) + c = relax.Var("c", relax.TensorType([1, 2, 3], "float32")) obj = relax.If( a, relax.SeqExpr([], b), @@ -602,8 +597,8 @@ def test_if(): def test_builtin_keywords(): x = tirx.Var("x", "int64") - a = relax.Var("R", relax.TensorStructInfo([1, x, 3], "float32")) - b = relax.Var("T", relax.TensorStructInfo([1, x, 3], "float32")) + a = relax.Var("R", relax.TensorType([1, x, 3], "float32")) + b = relax.Var("T", relax.TensorType([1, x, 3], "float32")) obj = relax.VarBinding(b, relax.op.sin(a)) _assert_print( obj, @@ -648,7 +643,7 @@ def tir_func(x: T.Buffer((T.int64(128),), "float32"), y: T.Buffer((T.int64(128), @R.function def foo(x: R.Tensor((128,), dtype="float32")) -> R.Tensor((128,), dtype="float32"): cls = Module - gv0 = R.call_tir(cls.tir_func, (x,), out_sinfo=R.Tensor((128,), dtype="float32")) + gv0 = R.call_tir(cls.tir_func, (x,), out_ty=R.Tensor((128,), dtype="float32")) return gv0 """, ) @@ -671,7 +666,7 @@ def tir_func(x: T.Buffer((T.int64(128),), "float32"), y: T.Buffer((T.int64(128), @R.function def foo(x: R.Tensor((128,), dtype="float32")) -> R.Tensor((128,), dtype="float32"): - gv0 = R.call_tir(Module.tir_func, (x,), out_sinfo=R.Tensor((128,), dtype="float32")) + gv0 = R.call_tir(Module.tir_func, (x,), out_ty=R.Tensor((128,), dtype="float32")) return gv0 """, ) @@ -827,8 +822,8 @@ def test_reused_extern_func(): @R.function def func(x: R.Tensor((128, 128), dtype="float32")) -> R.Tensor((128, 128), dtype="float32"): extern_func = R.ExternFunc("extern_func") - y = R.call_dps_packed(extern_func, (x,), out_sinfo=R.Tensor((128, 128), dtype="float32")) - z = R.call_dps_packed(extern_func, (y,), out_sinfo=R.Tensor((128, 128), dtype="float32")) + y = R.call_dps_packed(extern_func, (x,), out_ty=R.Tensor((128, 128), dtype="float32")) + z = R.call_dps_packed(extern_func, (y,), out_ty=R.Tensor((128, 128), dtype="float32")) return z _assert_print( @@ -839,8 +834,8 @@ def func(x: R.Tensor((128, 128), dtype="float32")) -> R.Tensor((128, 128), dtype @R.function def func(x: R.Tensor((128, 128), dtype="float32")) -> R.Tensor((128, 128), dtype="float32"): extern_func: R.Callable = R.ExternFunc("extern_func") - y = R.call_dps_packed(extern_func, (x,), out_sinfo=R.Tensor((128, 128), dtype="float32")) - z = R.call_dps_packed(extern_func, (y,), out_sinfo=R.Tensor((128, 128), dtype="float32")) + y = R.call_dps_packed(extern_func, (x,), out_ty=R.Tensor((128, 128), dtype="float32")) + z = R.call_dps_packed(extern_func, (y,), out_ty=R.Tensor((128, 128), dtype="float32")) return z """, ) @@ -852,10 +847,10 @@ def test_inline_extern_func(): @R.function def func(x: R.Tensor((128, 128), dtype="float32")) -> R.Tensor((128, 128), dtype="float32"): y = R.call_dps_packed( - R.ExternFunc("extern_func"), (x,), out_sinfo=R.Tensor((128, 128), dtype="float32") + R.ExternFunc("extern_func"), (x,), out_ty=R.Tensor((128, 128), dtype="float32") ) z = R.call_dps_packed( - R.ExternFunc("extern_func"), (y,), out_sinfo=R.Tensor((128, 128), dtype="float32") + R.ExternFunc("extern_func"), (y,), out_ty=R.Tensor((128, 128), dtype="float32") ) return z @@ -866,23 +861,23 @@ def func(x: R.Tensor((128, 128), dtype="float32")) -> R.Tensor((128, 128), dtype @R.function def func(x: R.Tensor((128, 128), dtype="float32")) -> R.Tensor((128, 128), dtype="float32"): - y = R.call_dps_packed("extern_func", (x,), out_sinfo=R.Tensor((128, 128), dtype="float32")) - z = R.call_dps_packed("extern_func", (y,), out_sinfo=R.Tensor((128, 128), dtype="float32")) + y = R.call_dps_packed("extern_func", (x,), out_ty=R.Tensor((128, 128), dtype="float32")) + z = R.call_dps_packed("extern_func", (y,), out_ty=R.Tensor((128, 128), dtype="float32")) return z """, ) -def test_hide_inferable_struct_info(): +def test_hide_inferable_ty(): """Redundant type annotations can be omitted - When `show_all_struct_info=False`, TVMScript type annotations that - provide redundant struct info can be omitted. + When `show_all_ty=False`, TVMScript type annotations that + provide redundant type can be omitted. """ @R.function def func(A: R.Tensor([10, 20], "float32"), B: R.Tensor(ndim=2, dtype="float32")): - # R.match_cast has the struct info as an argument, so it can + # R.match_cast has the type as an argument, so it can # be omitted from the variable annotation. B2 = R.match_cast(B, R.Tensor([10, 20], "float32")) @@ -893,8 +888,8 @@ def func(A: R.Tensor([10, 20], "float32"), B: R.Tensor(ndim=2, dtype="float32")) # info as the RHS. D = C - # Here, the struct info cannot be omitted. `R.add(D,B)` has - # struct info `R.Tensor(ndim=2)`, but the variable has a shape + # Here, the type cannot be omitted. `R.add(D,B)` has + # type `R.Tensor(ndim=2)`, but the variable has a shape # `R.Tensor([10,20])`. This is compatible, so it is not an # error to have this annotation, but it is not inferrable from # the RHS. Therefore, it must still be printed. @@ -908,7 +903,7 @@ def func(A: R.Tensor([10, 20], "float32"), B: R.Tensor(ndim=2, dtype="float32")) return E _assert_print( - func.script(show_all_struct_info=False), + func.script(show_all_ty=False), """ # from tvm.script import relax as R diff --git a/tests/python/relax/test_struct_info.py b/tests/python/relax/test_type.py similarity index 65% rename from tests/python/relax/test_struct_info.py rename to tests/python/relax/test_type.py index dd2e415f9a7e..1679048df4b8 100644 --- a/tests/python/relax/test_struct_info.py +++ b/tests/python/relax/test_type.py @@ -40,15 +40,15 @@ def _check_json_roundtrip(x): return xret -def test_object_struct_info(): - s0 = rx.ObjectStructInfo() - s1 = rx.ObjectStructInfo() +def test_object_ty(): + s0 = rx.ObjectType() + s1 = rx.ObjectType() # can turn into str str(s0) _check_equal(s0, s1) - assert isinstance(s0, rx.ObjectStructInfo) + assert isinstance(s0, rx.ObjectType) _check_json_roundtrip(s0) @@ -61,15 +61,15 @@ def test_shape_type(): def test_dyn_tensor_type(): t0 = rx.TensorType() assert t0.ndim == -1 - t1 = rx.TensorType(3, "int32") + t1 = rx.TensorType(ndim=3, dtype="int32") assert t1.ndim == 3 assert t1.dtype == "int32" -def test_prim_struct_info(): - s0 = rx.PrimStructInfo("float32") - s1 = rx.PrimStructInfo("float32") - s2 = rx.PrimStructInfo("int32") +def test_prim_ty(): + s0 = rx.PrimType("float32") + s1 = rx.PrimType("float32") + s2 = rx.PrimType("int32") _check_equal(s0, s1) @@ -79,7 +79,7 @@ def test_prim_struct_info(): assert s0 == s1 assert s0 != s2 - assert isinstance(s0, rx.PrimStructInfo) + assert isinstance(s0, rx.PrimType) _check_json_roundtrip(s0) _check_json_roundtrip(s1) @@ -88,30 +88,30 @@ def test_prim_struct_info(): # wrong API constructors with pytest.raises((RuntimeError, TypeError)): - rx.PrimStructInfo([1]) + rx.PrimType([1]) -def test_prim_struct_info_with_expr(): +def test_prim_ty_with_expr(): n = tirx.Var("n", "int64") - sinfo = rx.PrimStructInfo(value=n + 1) + ty = rx.PrimType(value=n + 1) - _check_equal(sinfo, rx.PrimStructInfo(value=n + 1)) - assert not tvm_ffi.structural_equal(sinfo, rx.PrimStructInfo(dtype=n.dtype)) + _check_equal(ty, rx.PrimType(value=n + 1)) + assert not tvm_ffi.structural_equal(ty, rx.PrimType(dtype=n.dtype)) # can turn into str - str(sinfo) + str(ty) - assert isinstance(sinfo, rx.PrimStructInfo) - _check_json_roundtrip(sinfo) + assert isinstance(ty, rx.PrimType) + _check_json_roundtrip(ty) - assert sinfo.dtype == "int64" + assert ty.dtype == "int64" -def test_shape_struct_info(): +def test_shape_ty(): n, m = tirx.Var("n", "int64"), tirx.Var("m", "int64") - s0 = rx.ShapeStructInfo([1, n + 1, m]) - s1 = rx.ShapeStructInfo([1, n + 1, m]) + s0 = rx.ShapeType([1, n + 1, m]) + s1 = rx.ShapeType([1, n + 1, m]) _check_equal(s0, s1) @@ -121,11 +121,11 @@ def test_shape_struct_info(): assert s0.values[2] == m - assert isinstance(s0, rx.ShapeStructInfo) + assert isinstance(s0, rx.ShapeType) _check_json_roundtrip(s0) _check_json_roundtrip(s1) - s2 = rx.ShapeStructInfo(ndim=2) + s2 = rx.ShapeType(ndim=2) assert s2.ndim == 2 assert s2.values is None @@ -137,22 +137,22 @@ def test_shape_struct_info(): # wrong argument type with pytest.raises((RuntimeError, TypeError)): - rx.ShapeStructInfo(1) + rx.ShapeType(1) # cannot pass both ndim and values with pytest.raises(ValueError): - rx.ShapeStructInfo([1, 2], ndim=3) + rx.ShapeType([1, 2], ndim=3) # cannot pass both ndim and values even if they are consistent with pytest.raises(ValueError): - rx.ShapeStructInfo([1, 2], ndim=2) + rx.ShapeType([1, 2], ndim=2) -def test_tensor_struct_info(): +def test_tensor_ty(): n, m = tirx.Var("n", "int64"), tirx.Var("m", "int64") - s0 = rx.TensorStructInfo([1, n + 1, m], "float32") - s1 = rx.TensorStructInfo(rx.ShapeExpr([1, n + 1, m]), "float32") + s0 = rx.TensorType([1, n + 1, m], "float32") + s1 = rx.TensorType(rx.ShapeExpr([1, n + 1, m]), "float32") _check_equal(s0, s1) @@ -160,11 +160,11 @@ def test_tensor_struct_info(): assert s0.ndim == 3 assert s1.ndim == 3 - assert isinstance(s0, rx.TensorStructInfo) + assert isinstance(s0, rx.TensorType) _check_json_roundtrip(s0) _check_json_roundtrip(s1) - s2 = rx.TensorStructInfo(ndim=2, dtype="int32") + s2 = rx.TensorType(ndim=2, dtype="int32") assert s2.ndim == 2 assert s2.dtype == "int32" @@ -173,9 +173,9 @@ def test_tensor_struct_info(): assert s0 != s2 # take in opaque var - rshape = rx.Var("shape", rx.ShapeStructInfo(ndim=2)) + rshape = rx.Var("shape", rx.ShapeType(ndim=2)) - s3 = rx.TensorStructInfo(rshape, dtype="int32") + s3 = rx.TensorType(rshape, dtype="int32") assert s3.dtype == "int32" assert s3.shape == rshape assert s3.ndim == 2 @@ -186,28 +186,30 @@ def test_tensor_struct_info(): # cannot pass both ndim and values with pytest.raises(ValueError): - rx.TensorStructInfo([1, 2], ndim=3) + rx.TensorType([1, 2], ndim=3) # cannot pass both ndim and values even if they are consistent with pytest.raises(ValueError): - rx.TensorStructInfo([1, 2], ndim=2) + rx.TensorType([1, 2], ndim=2) -def test_tuple_struct_info(): +def test_tuple_ty(): n, m = tirx.Var("n", "int64"), tirx.Var("m", "int64") - s0 = rx.TensorStructInfo([1, 2, m + n], "float32") - s1 = rx.ObjectStructInfo() + s0 = rx.TensorType([1, 2, m + n], "float32") + s1 = rx.ObjectType() - t0 = rx.TupleStructInfo([s0, s1]) - t1 = rx.TupleStructInfo([s0, rx.ObjectStructInfo()]) - t2 = rx.TupleStructInfo([s0, s0]) + t0 = rx.TupleType([s0, s1]) + t1 = rx.TupleType([s0, rx.ObjectType()]) + t2 = rx.TupleType([s0, s0]) _check_equal(t0, t1) assert t0 == t1 - assert isinstance(t0, rx.TupleStructInfo) + assert rx.TupleType is tvm.ir.TupleType + assert isinstance(t0, tvm.ir.TupleType) + assert t0.__class__ is tvm.ir.TupleType t0 = _check_json_roundtrip(t0) t1 = _check_json_roundtrip(t1) t2 = _check_json_roundtrip(t2) @@ -217,21 +219,21 @@ def test_tuple_struct_info(): # wrong argument type with pytest.raises(TypeError): - rx.TupleStructInfo(1) + rx.TupleType(1) -def test_func_struct_info(): +def test_func_ty(): def fn_info(c): n, m = tirx.Var("n", "int64"), tirx.Var("m", "int64") - x = rx.TensorStructInfo([c, n, m], "float32") - y = rx.TensorStructInfo([c, n, 1], "float32") - z = rx.TensorStructInfo([c, n, m], "float32") - return rx.FuncStructInfo([x, y], z) + x = rx.TensorType([c, n, m], "float32") + y = rx.TensorType([c, n, 1], "float32") + z = rx.TensorType([c, n, m], "float32") + return rx.FuncType([x, y], z) f0 = fn_info(1) f1 = fn_info(1) f2 = fn_info(2) - f3 = rx.FuncStructInfo.opaque_func() + f3 = rx.FuncType.opaque_func() _check_equal(f0, f1) @@ -239,13 +241,13 @@ def fn_info(c): assert f0 != f2 assert len(f0.params) == 2 - assert isinstance(f0.ret, rx.TensorStructInfo) + assert isinstance(f0.ret, rx.TensorType) assert f2.derive_func is None assert f3.params is None assert f3.derive_func is None - _check_equal(f3.ret, rx.ObjectStructInfo()) + _check_equal(f3.ret, rx.ObjectType()) - assert isinstance(f0, rx.FuncStructInfo) + assert isinstance(f0, rx.FuncType) f0 = _check_json_roundtrip(f0) f1 = _check_json_roundtrip(f1) f2 = _check_json_roundtrip(f2) diff --git a/tests/python/relax/test_utils.py b/tests/python/relax/test_utils.py index 8eb3961ef8dd..c2afcaf21b09 100644 --- a/tests/python/relax/test_utils.py +++ b/tests/python/relax/test_utils.py @@ -53,7 +53,7 @@ def before(x: R.Tensor(("m",), "float32"), y: R.Tensor(("m",), "float32")): assert len(after.params) == len(before.params) for before_var, after_var in zip(before.params, after.params): assert before_var != after_var - assert before_var.struct_info.shape[0] != after_var.struct_info.shape[0] + assert before_var.ty.shape[0] != after_var.ty.shape[0] def test_copy_with_new_vars_on_ir_module(): diff --git a/tests/python/relax/test_vm_build.py b/tests/python/relax/test_vm_build.py index c85d4d8d744e..d555f0d5a9b3 100644 --- a/tests/python/relax/test_vm_build.py +++ b/tests/python/relax/test_vm_build.py @@ -50,7 +50,7 @@ class TestVMCompileStage0: @R.function def foo(x: R.Tensor((3, 4), "float32"), y: R.Tensor((3, 4), "float32")): z = R.call_pure_packed( - "test.vm.identity", x, y, sinfo_args=(R.Tensor(ndim=2, dtype="float32")) + "test.vm.identity", x, y, ty_args=(R.Tensor(ndim=2, dtype="float32")) ) return y @@ -72,7 +72,7 @@ class mod: @R.function def foo(x: R.Tensor((3, 4), "float32"), y: R.Tensor((3, 4), "float32")): z = R.call_pure_packed( - "test.vm.identity", x, y, sinfo_args=(R.Tensor(ndim=2, dtype="float32")) + "test.vm.identity", x, y, ty_args=(R.Tensor(ndim=2, dtype="float32")) ) return y @@ -719,9 +719,7 @@ def tuple_get_item( t = (x, y) a = t[0] b = t[1] - c = R.call_pure_packed( - "test.vm.add", a, b, sinfo_args=(R.Tensor(ndim=2, dtype="float32")) - ) + c = R.call_pure_packed("test.vm.add", a, b, ty_args=(R.Tensor(ndim=2, dtype="float32"))) return c mod = TestVMTupleGetItem @@ -800,7 +798,7 @@ def relax_matmul_packed( x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32") ) -> R.Object: gv0 = R.call_pure_packed( - "test.vm.mul", x, w, sinfo_args=(R.Tensor(ndim=2, dtype="float32")) + "test.vm.mul", x, w, ty_args=(R.Tensor(ndim=2, dtype="float32")) ) return gv0 @@ -828,17 +826,17 @@ class TestVMRecursion: @R.function def recursion(n: R.Tensor((1,), "float32")) -> R.Tensor: cond = R.call_pure_packed( - "test.vm.equal_zero", n, sinfo_args=(R.Tensor(ndim=1, dtype="float32")) + "test.vm.equal_zero", n, ty_args=(R.Tensor(ndim=1, dtype="float32")) ) if cond: res = R.const(1.0) else: gv0 = R.call_pure_packed( - "test.vm.subtract_one", n, sinfo_args=(R.Tensor(ndim=1, dtype="float32")) + "test.vm.subtract_one", n, ty_args=(R.Tensor(ndim=1, dtype="float32")) ) tmp = TestVMRecursion.recursion(gv0) res = R.call_pure_packed( - "test.vm.add", tmp, tmp, sinfo_args=(R.Tensor(ndim=1, dtype="float32")) + "test.vm.add", tmp, tmp, ty_args=(R.Tensor(ndim=1, dtype="float32")) ) return res @@ -894,7 +892,7 @@ def test_vm_closure(exec_mode): class TestClosure: @R.function def lifted_func_1(x: R.Tensor((2, 3), "float32"), env: R.Tensor((2, 3), "float32")): - return R.call_pure_packed("test.vm.add", x, env, sinfo_args=(R.Tensor())) + return R.call_pure_packed("test.vm.add", x, env, ty_args=(R.Tensor())) @R.function def main( @@ -903,7 +901,7 @@ def main( ): cls = TestClosure clo = R.make_closure(cls.lifted_func_1, (x,)) - res = R.invoke_pure_closure(clo, (y,), sinfo_args=(R.Tensor())) + res = R.invoke_pure_closure(clo, (y,), ty_args=(R.Tensor())) return res mod = TestClosure @@ -922,7 +920,7 @@ class TestTimeEvaluator: @R.function def main(x: R.Tensor((1,), "float32"), y: R.Tensor((1,), "float32")): return R.call_pure_packed( - "test.vm.add", x, y, sinfo_args=(R.Tensor(ndim=1, dtype="float32")) + "test.vm.add", x, y, ty_args=(R.Tensor(ndim=1, dtype="float32")) ) target = tvm.target.Target("llvm", host="llvm") diff --git a/tests/python/relax/test_vm_builtin.py b/tests/python/relax/test_vm_builtin.py index 9fab28b8655d..f818e0ed5d85 100644 --- a/tests/python/relax/test_vm_builtin.py +++ b/tests/python/relax/test_vm_builtin.py @@ -35,7 +35,7 @@ def foo(x: R.Tensor((3, 5), "float32"), y: R.Tensor((3, 1), "float32")): "vm.builtin.multinomial_from_uniform", x, y, - sinfo_args=(R.Tensor((3, 1), dtype="int64")), + ty_args=(R.Tensor((3, 1), dtype="int64")), ) return z diff --git a/tests/python/relax/test_vm_builtin_lower.py b/tests/python/relax/test_vm_builtin_lower.py index 59ac5c3f12d0..8cdfe2addc24 100644 --- a/tests/python/relax/test_vm_builtin_lower.py +++ b/tests/python/relax/test_vm_builtin_lower.py @@ -36,7 +36,7 @@ def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor: storage = R.memory.alloc_storage(R.shape([m * n * 4]), 0, "global", "uint8") alloc = R.memory.alloc_tensor(storage, 0, R.shape([m, n]), "float32") _ = R.call_packed( - "test.op.identity", x, alloc, sinfo_args=(R.Tensor(ndim=2, dtype="float32")) + "test.op.identity", x, alloc, ty_args=(R.Tensor(ndim=2, dtype="float32")) ) gv0 = alloc return gv0 @@ -53,7 +53,7 @@ def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor: alloc = R.vm.alloc_tensor(storage, R.prim_value(0), R.shape([m, n]), "float32") _ = R.call_packed( - "test.op.identity", x, alloc, sinfo_args=(R.Tensor(ndim=2, dtype="float32")) + "test.op.identity", x, alloc, ty_args=(R.Tensor(ndim=2, dtype="float32")) ) gv0 = alloc return gv0 @@ -74,7 +74,7 @@ def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor: alloc = R.builtin.alloc_tensor(R.shape([m, n]), runtime_device_index=0, dtype="float32") _ = R.call_packed( - "test.op.identity", x, alloc, sinfo_args=(R.Tensor(ndim=2, dtype="float32")) + "test.op.identity", x, alloc, ty_args=(R.Tensor(ndim=2, dtype="float32")) ) gv0 = alloc return gv0 @@ -103,7 +103,7 @@ def main(A: R.Tensor([16], "float32"), shape: R.Shape): "vm.builtin.reshape", A, shape, - sinfo_args=R.Tensor(shape, dtype="float32"), + ty_args=R.Tensor(shape, dtype="float32"), ) return reshape @@ -133,13 +133,13 @@ def main(A: R.Tensor([16], "float32"), shape_tensor: R.Tensor([2], "int64")): shape = R.call_packed( "vm.builtin.tensor_to_shape", shape_tensor, - sinfo_args=R.Shape(ndim=2), + ty_args=R.Shape(ndim=2), ) reshape = R.call_packed( "vm.builtin.reshape", A, shape, - sinfo_args=R.Tensor(shape, dtype="float32"), + ty_args=R.Tensor(shape, dtype="float32"), ) return reshape diff --git a/tests/python/relax/test_vm_codegen_only.py b/tests/python/relax/test_vm_codegen_only.py index 0585a2a844d4..16221e18ed27 100644 --- a/tests/python/relax/test_vm_codegen_only.py +++ b/tests/python/relax/test_vm_codegen_only.py @@ -49,7 +49,7 @@ class TestVMMove: @R.function(pure=False) def foo(x: R.Tensor((3, 4), "float32")): R.func_attr({"global_symbol": "foo"}) - z = R.call_packed("vm.builtin.copy", x, sinfo_args=(R.Tensor((3, 4), dtype="float32"))) + z = R.call_packed("vm.builtin.copy", x, ty_args=(R.Tensor((3, 4), dtype="float32"))) return z mod = TestVMMove @@ -70,7 +70,7 @@ def foo(x: R.Tensor((3, 4), "float32")): R.func_attr({"global_symbol": "foo"}) # Copy x to the first cpu: device_type=1 and device_id=0. z = R.call_packed( - "vm.builtin.to_device", x, 1, 0, sinfo_args=(R.Tensor((3, 4), dtype="float32")) + "vm.builtin.to_device", x, 1, 0, ty_args=(R.Tensor((3, 4), dtype="float32")) ) return z @@ -116,7 +116,7 @@ class TestVMMove: @R.function(pure=False) def foo(x: R.Tensor((3, 4), "float32")): R.func_attr({"global_symbol": "foo"}) - z = R.call_packed("vm.builtin.copy", x, sinfo_args=(R.Tensor((3, 4), dtype="float32"))) + z = R.call_packed("vm.builtin.copy", x, ty_args=(R.Tensor((3, 4), dtype="float32"))) return z mod = TestVMMove @@ -140,9 +140,9 @@ class TestVMCompileIf: def ife(cond: R.Tensor((), "bool"), x: R.Tensor((3, 4), "float32")) -> R.Tensor: R.func_attr({"global_symbol": "ife"}) if cond: - w = R.call_packed("test.vm.add", x, x, sinfo_args=(R.Tensor)) + w = R.call_packed("test.vm.add", x, x, ty_args=(R.Tensor)) else: - w = R.call_packed("test.vm.mul", x, x, sinfo_args=(R.Tensor)) + w = R.call_packed("test.vm.mul", x, x, ty_args=(R.Tensor)) return w mod = TestVMCompileIf @@ -193,13 +193,13 @@ def main(x: R.Tensor(ndim=2, dtype="float32")): "test.vm.add", relax.const([1, 2]), relax.const([3, 4]), - sinfo_args=(R.Tensor(ndim=2, dtype="float32")), + ty_args=(R.Tensor(ndim=2, dtype="float32")), ) b = R.call_packed( "test.vm.add", a, x, - sinfo_args=(R.Tensor(ndim=2, dtype="float32")), + ty_args=(R.Tensor(ndim=2, dtype="float32")), ) return b @@ -230,10 +230,10 @@ def main(x: R.Tensor(["n", "m"], "float32")) -> R.Shape(ndim=3): shape_heap = R.call_builtin_with_ctx( "vm.builtin.alloc_shape_heap", [R.prim_value(3)], - sinfo_args=[R.Tensor(ndim=1, dtype="int64")], + ty_args=[R.Tensor(ndim=1, dtype="int64")], ) _ = R.call_packed( - "vm.builtin.check_tensor_info", x, 2, R.dtype("float32"), "", sinfo_args=[R.Tuple()] + "vm.builtin.check_tensor_info", x, 2, R.dtype("float32"), "", ty_args=[R.Tuple()] ) _ = R.call_packed( "vm.builtin.match_shape", @@ -245,7 +245,7 @@ def main(x: R.Tensor(["n", "m"], "float32")) -> R.Shape(ndim=3): MS.STORE_TO_HEAP, sindex["m"], "", - sinfo_args=[R.Tuple()], + ty_args=[R.Tuple()], ) # construct shape value for return s = R.call_packed( @@ -258,7 +258,7 @@ def main(x: R.Tensor(["n", "m"], "float32")) -> R.Shape(ndim=3): sindex["n"], MK.USE_IMM, 2, - sinfo_args=[R.Shape(ndim=3)], + ty_args=[R.Shape(ndim=3)], ) return s @@ -345,7 +345,7 @@ class TestVMBuiltinReshape: def main(x: R.Tensor((3, 4), "float32")): R.func_attr({"global_symbol": "main"}) y = R.call_packed( - "vm.builtin.reshape", x, R.shape([6, 2]), sinfo_args=R.Tensor((6, 2), "float32") + "vm.builtin.reshape", x, R.shape([6, 2]), ty_args=R.Tensor((6, 2), "float32") ) return y diff --git a/tests/python/relax/test_vm_codegen_tir.py b/tests/python/relax/test_vm_codegen_tir.py index 0eb7f62a3b22..729d4d3d6562 100644 --- a/tests/python/relax/test_vm_codegen_tir.py +++ b/tests/python/relax/test_vm_codegen_tir.py @@ -38,7 +38,7 @@ class Before: @R.function(pure=False) def foo(x: R.Tensor): R.func_attr({"global_symbol": "foo"}) - z = R.call_packed("test.vm.add", x, x, sinfo_args=(R.Tensor)) + z = R.call_packed("test.vm.add", x, x, ty_args=(R.Tensor)) return z @tvm.script.ir_module @@ -107,9 +107,9 @@ class Before: def ife(cond: R.Tensor((), "bool"), x: R.Tensor) -> R.Tensor: R.func_attr({"global_symbol": "ife"}) if cond: - w = R.call_packed("test.vm.add", x, x, sinfo_args=(R.Tensor)) + w = R.call_packed("test.vm.add", x, x, ty_args=(R.Tensor)) else: - w = R.call_packed("test.vm.mul", x, x, sinfo_args=(R.Tensor)) + w = R.call_packed("test.vm.mul", x, x, ty_args=(R.Tensor)) return w @tvm.script.ir_module @@ -195,7 +195,7 @@ class Before: def main(x: R.Tensor): R.func_attr({"global_symbol": "main"}) y = R.const([1, 2]) - z = R.call_packed("test.vm.add", x, y, sinfo_args=(R.Tensor)) + z = R.call_packed("test.vm.add", x, y, ty_args=(R.Tensor)) return z @tvm.script.ir_module diff --git a/tests/python/relax/test_vm_cuda_graph.py b/tests/python/relax/test_vm_cuda_graph.py index 53450a6fdf67..349a0a6c635c 100644 --- a/tests/python/relax/test_vm_cuda_graph.py +++ b/tests/python/relax/test_vm_cuda_graph.py @@ -36,13 +36,13 @@ class Module: def main(x: R.Tensor((16, 16), dtype="float32")) -> R.Tensor((16, 16), dtype="float32"): cls = Module R.func_attr({"global_symbol": "main"}) - gv: R.Tuple(R.Object, R.Object) = R.call_builtin_with_ctx("vm.builtin.cuda_graph.get_cached_alloc", (cls.cuda_graph_alloc, R.prim_value(0)), sinfo_args=(R.Tuple(R.Object, R.Object),)) + gv: R.Tuple(R.Object, R.Object) = R.call_builtin_with_ctx("vm.builtin.cuda_graph.get_cached_alloc", (cls.cuda_graph_alloc, R.prim_value(0)), ty_args=(R.Tuple(R.Object, R.Object),)) storage: R.Object = gv[0] alloc = R.vm.alloc_tensor(storage, R.prim_value(0), R.shape((16, 16)), R.dtype("float32")) _: R.Tuple = cls.add(x, alloc) storage1: R.Object = gv[1] gv1: R.Tuple(R.Tensor(dtype="float32"), R.Object, R.Object) = (alloc, storage1, storage) - gv2: R.Tuple(R.Tensor((16, 16), dtype="float32")) = R.call_builtin_with_ctx("vm.builtin.cuda_graph.run_or_capture", (cls.cuda_graph_capture, gv1, R.prim_value(0)), sinfo_args=(R.Tuple(R.Tensor((16, 16), dtype="float32")),)) + gv2: R.Tuple(R.Tensor((16, 16), dtype="float32")) = R.call_builtin_with_ctx("vm.builtin.cuda_graph.run_or_capture", (cls.cuda_graph_capture, gv1, R.prim_value(0)), ty_args=(R.Tuple(R.Tensor((16, 16), dtype="float32")),)) storage2: R.Object = R.vm.alloc_storage(R.shape((1024,)), R.prim_value(0), R.dtype("uint8")) alloc3 = R.vm.alloc_tensor(storage2, R.prim_value(0), R.shape((16, 16)), R.dtype("float32")) lv4: R.Tensor((16, 16), dtype="float32") = gv2[0] @@ -150,7 +150,7 @@ def main(A: R.Tensor([16], "float16")): C = R.call_pure_packed( "test_vm_cuda_graph.invalid_impl_for_cudagraph", B, - sinfo_args=R.Tensor([16], "float16"), + ty_args=R.Tensor([16], "float16"), ) D = R.add(C, C) return D diff --git a/tests/python/s_tir/dlight/test_benchmark.py b/tests/python/s_tir/dlight/test_benchmark.py index 32ff97335968..7514d24cfcc7 100644 --- a/tests/python/s_tir/dlight/test_benchmark.py +++ b/tests/python/s_tir/dlight/test_benchmark.py @@ -93,12 +93,12 @@ def test(): R.func_attr({"tir_var_upper_bound": {"n": 2048}}) cls = Module with R.dataflow(): - lv1 = R.call_tir(cls.full1,(), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16")) - lv1_1 = R.call_tir(cls.full1,(), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16")) - lv1_2 = R.call_tir(cls.full1,(), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16")) - lv2 = R.call_tir(cls.full2,(), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) - lv2_1 = R.call_tir(cls.full2,(), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) - lv3 = R.call_tir(cls.matmul1, (lv1, lv2), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16")) + lv1 = R.call_tir(cls.full1,(), out_ty=R.Tensor((1, 32, 1, n), dtype="float16")) + lv1_1 = R.call_tir(cls.full1,(), out_ty=R.Tensor((1, 32, 1, n), dtype="float16")) + lv1_2 = R.call_tir(cls.full1,(), out_ty=R.Tensor((1, 32, 1, n), dtype="float16")) + lv2 = R.call_tir(cls.full2,(), out_ty=R.Tensor((1, 32, n, 128), dtype="float16")) + lv2_1 = R.call_tir(cls.full2,(), out_ty=R.Tensor((1, 32, n, 128), dtype="float16")) + lv3 = R.call_tir(cls.matmul1, (lv1, lv2), out_ty=R.Tensor((1, 32, 1, 128), dtype="float16")) R.output(lv3) return lv3 diff --git a/tests/python/tirx-base/test_tir_specialize.py b/tests/python/tirx-base/test_tir_specialize.py index 125ede32d6c4..cecaf07ab85a 100644 --- a/tests/python/tirx-base/test_tir_specialize.py +++ b/tests/python/tirx-base/test_tir_specialize.py @@ -332,11 +332,11 @@ def expected(A_data: T.handle("float32")): tvm.ir.assert_structural_equal(expected, after) -def test_specialization_updates_struct_info(): - """Update struct info in specialization +def test_specialization_updates_ty(): + """Update type in specialization - A PrimFunc may have a `relax.StructInfo`. If that PrimFunc is - specialized, the struct info should be updated. + A PrimFunc may have a `relax.Type`. If that PrimFunc is + specialized, the type should be updated. """ @T.prim_func(private=True, s_tir=True) @@ -347,20 +347,18 @@ def before(n: T.int32) -> T.int32: def expected() -> T.int32: T.ret(50) - sinfo_before = tvm.relax.FuncStructInfo( - [tvm.relax.PrimStructInfo("int32")], tvm.relax.PrimStructInfo("int32") - ) - tvm.ir.assert_structural_equal(before.struct_info, sinfo_before) + ty_before = tvm.relax.FuncType([tvm.relax.PrimType("int32")], tvm.relax.PrimType("int32")) + tvm.ir.assert_structural_equal(before.ty, ty_before) - sinfo_expected = tvm.relax.FuncStructInfo([], tvm.relax.PrimStructInfo("int32")) - tvm.ir.assert_structural_equal(expected.struct_info, sinfo_expected) + ty_expected = tvm.relax.FuncType([], tvm.relax.PrimType("int32")) + tvm.ir.assert_structural_equal(expected.ty, ty_expected) n = before.params[0] param_map = {n: 5} after = before.specialize(param_map) tvm.ir.assert_structural_equal(after, expected) - tvm.ir.assert_structural_equal(after.struct_info, sinfo_expected) + tvm.ir.assert_structural_equal(after.ty, ty_expected) if __name__ == "__main__": diff --git a/tests/python/tvmscript/test_tvmscript_parser_tir.py b/tests/python/tvmscript/test_tvmscript_parser_tir.py index a26f74b4af4d..5972decaa5fd 100644 --- a/tests/python/tvmscript/test_tvmscript_parser_tir.py +++ b/tests/python/tvmscript/test_tvmscript_parser_tir.py @@ -404,43 +404,43 @@ def func(A: T.Buffer((128, 128)), B: T.Buffer((128, 128))): assert loop_j.thread_binding.var.dtype == "int32" -def test_inferred_sinfo_with_prim_args(): - """A PrimFunc may have inferred StructInfo""" +def test_inferred_ty_with_prim_args(): + """A PrimFunc may have inferred Type""" @T.prim_func(s_tir=True) def func(M: T.int32, N: T.int32) -> T.int32: T.ret(M * N) - expected = tvm.relax.FuncStructInfo( + expected = tvm.relax.FuncType( [ - tvm.relax.PrimStructInfo("int32"), - tvm.relax.PrimStructInfo("int32"), + tvm.relax.PrimType("int32"), + tvm.relax.PrimType("int32"), ], - tvm.relax.PrimStructInfo("int32"), + tvm.relax.PrimType("int32"), purity=True, ) - tvm.ir.assert_structural_equal(func.struct_info, expected) + tvm.ir.assert_structural_equal(func.ty, expected) -def test_inferred_sinfo_with_buffer_args(): +def test_inferred_ty_with_buffer_args(): """PrimFunc buffer arguments are inferred as R.Tensor""" @T.prim_func(s_tir=True) def func(A: T.Buffer([16, 16], "float32"), B: T.Buffer([256], "int32")) -> T.float32: T.ret(T.float32(42.0)) - expected = tvm.relax.FuncStructInfo( + expected = tvm.relax.FuncType( [ - tvm.relax.TensorStructInfo([16, 16], "float32"), - tvm.relax.TensorStructInfo([256], "int32"), + tvm.relax.TensorType([16, 16], "float32"), + tvm.relax.TensorType([256], "int32"), ], - tvm.relax.PrimStructInfo("float32"), + tvm.relax.PrimType("float32"), purity=True, ) - tvm.ir.assert_structural_equal(func.struct_info, expected) + tvm.ir.assert_structural_equal(func.ty, expected) -def test_inferred_sinfo_with_internal_allocation(): +def test_inferred_ty_with_internal_allocation(): """A pure function may still write to internal allocations. Whether a function writes to internal allocations is not a visible @@ -456,17 +456,17 @@ def func(A: T.Buffer([16, 16], "float32")) -> T.float32: T.ret(Sum[()]) - expected = tvm.relax.FuncStructInfo( + expected = tvm.relax.FuncType( [ - tvm.relax.TensorStructInfo([16, 16], "float32"), + tvm.relax.TensorType([16, 16], "float32"), ], - tvm.relax.PrimStructInfo("float32"), + tvm.relax.PrimType("float32"), purity=True, ) - tvm.ir.assert_structural_equal(func.struct_info, expected) + tvm.ir.assert_structural_equal(func.ty, expected) -def test_inferred_sinfo_with_output_buffer(): +def test_inferred_ty_with_output_buffer(): """A pure function may not write to an argument buffer If an argument buffer is written to, the function must be impure. @@ -477,19 +477,19 @@ def func(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")): for i in range(16): B[i] = A[i] - expected = tvm.relax.FuncStructInfo( + expected = tvm.relax.FuncType( [ - tvm.relax.TensorStructInfo([16], "float32"), - tvm.relax.TensorStructInfo([16], "float32"), + tvm.relax.TensorType([16], "float32"), + tvm.relax.TensorType([16], "float32"), ], - tvm.relax.TupleStructInfo([]), + tvm.relax.TupleType([]), purity=False, ) - tvm.ir.assert_structural_equal(func.struct_info, expected) + tvm.ir.assert_structural_equal(func.ty, expected) -def test_inferred_sinfo_with_dynamic_buffer(): - """The inferred StructInfo may contain dynamic shapes""" +def test_inferred_ty_with_dynamic_buffer(): + """The inferred Type may contain dynamic shapes""" @T.prim_func(s_tir=True) def func(a_handle: T.handle, b_handle: T.handle): @@ -502,15 +502,15 @@ def func(a_handle: T.handle, b_handle: T.handle): M = tvm.tirx.Var("M", "int64") N = tvm.tirx.Var("N", "int64") - expected = tvm.relax.FuncStructInfo( + expected = tvm.relax.FuncType( [ - tvm.relax.TensorStructInfo([M, N], "float32"), - tvm.relax.TensorStructInfo([M * N], "float32"), + tvm.relax.TensorType([M, N], "float32"), + tvm.relax.TensorType([M * N], "float32"), ], - tvm.relax.TupleStructInfo([]), + tvm.relax.TupleType([]), purity=False, ) - tvm.ir.assert_structural_equal(func.struct_info, expected) + tvm.ir.assert_structural_equal(func.ty, expected) def test_reinterpret_nop(): diff --git a/tests/python/tvmscript/test_tvmscript_roundtrip.py b/tests/python/tvmscript/test_tvmscript_roundtrip.py index 9dac2b3f0a97..bdcaf668718e 100644 --- a/tests/python/tvmscript/test_tvmscript_roundtrip.py +++ b/tests/python/tvmscript/test_tvmscript_roundtrip.py @@ -3201,11 +3201,11 @@ def func(A: R.Tensor([10, 20], "float32")): func = R.ExternFunc("dummy_func") B: R.Tensor([10, 20], "float32") = R.call_dps_packed( - func, [A], out_sinfo=R.Tensor([10, 20], "float32") + func, [A], out_ty=R.Tensor([10, 20], "float32") ) C: R.Tensor(ndim=2, dtype="float32") = R.call_dps_packed( - func, [B], out_sinfo=R.Tensor([10, 20], "float32") + func, [B], out_ty=R.Tensor([10, 20], "float32") ) return C @@ -3213,15 +3213,15 @@ def func(A: R.Tensor([10, 20], "float32")): return func -def relax_match_cast_struct_info_proxy(): - """StructInfoProxy subclasses may be used as expressions +def relax_match_cast_ty_proxy(): + """TypeProxy subclasses may be used as expressions - This is a regression test. The TVMScript parser allows StructInfo + This is a regression test. The TVMScript parser allows Type to be specified using a default-constructible class (e.g. `R.Tensor` or `R.Shape`) rather than an instance of that class (e.g. `R.Tensor()` or `R.Shape()`). In previous - implementations, this was only handled when the `StructInfo` was - used in an annotation context. However, a `StructInfo` may also + implementations, this was only handled when the `Type` was + used in an annotation context. However, a `Type` may also appear as an argument, which is passed to `R.match_cast`. Use of a default-constructible class must be handled in this context as well. @@ -3239,8 +3239,8 @@ def func(A: R.Object): inner.__name__ = subclass.__name__ return inner - # Not all subclasses of StructInfoProxy are default-constructible. - # This list is a subset of `StructInfoProxy.__subclasses__()`, + # Not all subclasses of TypeProxy are default-constructible. + # This list is a subset of `TypeProxy.__subclasses__()`, # excluding `PrimProxy` and `DTensorProxy`. subclasses = [ tvm.script.parser.relax.entry.ObjectProxy, @@ -3363,7 +3363,7 @@ def func(A: R.Tensor(["N"], "float16"), _: R.Prim(value="threshold")): func_with_loop_jumps, func_with_loop_steps, *op_of_literal(), - *relax_match_cast_struct_info_proxy(), + *relax_match_cast_ty_proxy(), relax_symbolic_size_var, relax_float_symbolic_var, ) @@ -3372,10 +3372,10 @@ def func(A: R.Tensor(["N"], "float16"), _: R.Prim(value="threshold")): relax_extern_func, ) -show_all_relax_struct_info = tvm.testing.parameter( +show_all_relax_ty = tvm.testing.parameter( by_dict={ - "show_all_struct_info": True, - "hide_inferable_struct_info": False, + "show_all_ty": True, + "hide_inferable_ty": False, } ) @@ -3395,12 +3395,12 @@ def test_roundtrip(ir_generator): tvm.ir.assert_structural_equal(original, after_roundtrip, True) -def test_relax_roundtrip(relax_ir_generator, show_all_relax_struct_info): +def test_relax_roundtrip(relax_ir_generator, show_all_relax_ty): original = relax_ir_generator() after_roundtrip = tvm.script.from_source( original.script( show_meta=True, - show_all_struct_info=show_all_relax_struct_info, + show_all_ty=show_all_relax_ty, ) ) tvm.ir.assert_structural_equal(original, after_roundtrip, True)