Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 23 additions & 4 deletions src/backend/metal/runtime/metal_module.mm
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,26 @@
#include "../../../support/bytes_io.h"
#include "metal_common.h"

#if (defined(__MAC_OS_X_VERSION_MAX_ALLOWED) && __MAC_OS_X_VERSION_MAX_ALLOWED >= 260000) || \
(defined(__IPHONE_OS_VERSION_MAX_ALLOWED) && __IPHONE_OS_VERSION_MAX_ALLOWED >= 260000)
#define TVM_METAL_HAS_MSL_4_0 1
#endif

namespace tvm {
namespace runtime {

/*! \brief Maximum number of GPU supported in MetalModule. */
static constexpr const int kMetalMaxNumDevice = 32;

static bool MetalDeviceSupportsMetal4(id<MTLDevice> device) {
#if defined(TVM_METAL_HAS_MSL_4_0)
if (@available(macOS 26.0, iOS 26.0, *)) {
return [device supportsFamily:MTLGPUFamilyMetal4];
}
#endif
return false;
}

// Module to support thread-safe multi-GPU execution.
// The runtime will contain a per-device module table
// The modules will be lazily loaded
Expand Down Expand Up @@ -123,17 +137,22 @@ int GetPropertyMask() const final {
const ffi::Bytes& source = (*kernel).second;

if (fmt_ == "metal") {
MTLCompileOptions* opts = [MTLCompileOptions alloc];
opts.languageVersion = MTLLanguageVersion2_3;
MTLCompileOptions* opts = [[MTLCompileOptions alloc] init];
MTLLanguageVersion language_version = MTLLanguageVersion2_3;
#if defined(TVM_METAL_HAS_MSL_4_0)
if (MetalDeviceSupportsMetal4(w->devices[device_id])) {
language_version = MTLLanguageVersion4_0;
}
#endif
opts.languageVersion = language_version;
opts.fastMathEnabled = YES;
// opts = nil;
// Per-kernel payload is bytes; treat as UTF-8 MSL source.
std::string source_str(source.data(), source.size());
lib = [w->devices[device_id]
newLibraryWithSource:[NSString stringWithUTF8String:source_str.c_str()]
options:opts
error:&err_msg];
[opts dealloc];
[opts release];
if (lib == nil) {
LOG(FATAL) << "Fail to compile metal source:"
<< [[err_msg localizedDescription] UTF8String];
Expand Down
Loading