refactor(turbomind): consolidate CUDA error handling and add manual stacktracing#4565

Open
lzhangzz wants to merge 1 commit into InternLM:main from lzhangzz:trace

Conversation

@lzhangzz
Collaborator

  • Add TM_CUDA_CHECK / TM_CUDRV_CHECK unified error macros
  • Migrate all FT_CHECK, check_cuda_error, sync_check_cuda_error sites
  • Make all kernel invoke* wrappers return cudaError_t with [[nodiscard]]
  • Add RAII Scope class with scope stack, signal handler, and trace output
  • Strip function return types and params from scope stack traces

Example:

[TM][FATAL][0429.13:00:23.088713][unified_attention_layer.cc:586] CUDA error: invalid argument
*** stacktrace of thread 0x7f807b7fe640 ***
  [ 0] TM_CUDA_CHECK @ unified_attention_layer.cc:586
  [ 1] UnifiedAttentionLayer::core_attention() @ unified_attention_layer.cc:376
  [ 2] UnifiedAttentionLayer::Forward() @ unified_attention_layer.cc:302
  [ 3] layer_0 @ unified_decoder.cc:191
  [ 4] UnifiedDecoder::Forward() @ unified_decoder.cc:139
  [ 5] LanguageModel::Impl::Forward() @ language_model.cc:406
  [ 6] ModelExecutor::Impl::Run() @ model_executor.cc:59
  [ 7] ModelExecutor::Impl::InternalThreadEntry() @ model_executor.cc:35
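
The frames above come from the manually maintained scope stack, not from debug-info unwinding. A minimal sketch of how the pieces could fit together; the macro names match the PR, but the bodies are illustrative guesses, not the actual implementation:

#include <cuda_runtime.h>
#include <cstdio>
#include <cstdlib>
#include <vector>

// One frame per active scope on this thread.
struct Frame { const char* label; const char* file; int line; };
inline thread_local std::vector<Frame> g_scopes;

// RAII guard: push on entry, pop on exit, so the stack always mirrors
// the live call path of the current thread.
struct Scope {
    Scope(const char* label, const char* file, int line) { g_scopes.push_back({label, file, line}); }
    ~Scope() { g_scopes.pop_back(); }
};

#define TM_FUNCTION_SCOPE() ::Scope _tm_scope_{__func__, __FILE__, __LINE__}

// On failure: report the CUDA error, dump the scope stack innermost-first
// (matching the trace above), then abort.
#define TM_CUDA_CHECK(expr)                                                    \
    do {                                                                       \
        if (cudaError_t _e = (expr); _e != cudaSuccess) {                      \
            std::fprintf(stderr, "[TM][FATAL][%s:%d] CUDA error: %s\n",        \
                         __FILE__, __LINE__, cudaGetErrorString(_e));          \
            for (size_t i = g_scopes.size(); i-- > 0;) {                       \
                std::fprintf(stderr, "  [%2zu] %s @ %s:%d\n",                  \
                             g_scopes.size() - 1 - i, g_scopes[i].label,       \
                             g_scopes[i].file, g_scopes[i].line);              \
            }                                                                  \
            std::abort();                                                      \
        }                                                                      \
    } while (0)

Per the PR bullets, the real implementation also strips return types and parameters from function signatures in the trace and hooks a signal handler; those parts are omitted from this sketch.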

@lvhan028
Collaborator

Please resolve the conflict.

Copilot AI (Contributor) left a comment


Pull request overview

This PR refactors TurboMind's CUDA runtime/driver error handling into the unified TM_CUDA_CHECK / TM_CUDRV_CHECK macros, converts many kernel invoke* wrappers to return cudaError_t (often marked [[nodiscard]]), and adds a lightweight manual scope stack (TM_SCOPE / TM_FUNCTION_SCOPE) that prints a readable stack trace on fatal errors and signals.

Changes:

  • Introduce unified CUDA runtime + driver error reporting (TM_CUDA_CHECK, TM_CUDRV_CHECK) and migrate call sites away from check_cuda_error / sync_check_cuda_error / FT_CHECK.
  • Update many kernel launcher APIs to return cudaError_t and update callers to check results via TM_CUDA_CHECK(...) (see the sketch after this list).
  • Add scope stack tracing infrastructure (core::Scope, Context::scope_trace()), integrate it with fatal logging/signal handling, and add tests for formatting.
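
A schematic before/after of that call-site migration (invokeFoo and fooKernel are hypothetical stand-ins for the real invoke* wrappers, not names from the PR):

// Before: the launcher returns void and errors only surface at the next
// sync_check_cuda_error() call, far from the failing launch.
//
//     invokeFoo(out, in, n, stream);
//     sync_check_cuda_error();

// After: the launcher reports the launch status, and [[nodiscard]] turns a
// silently ignored result into a compiler warning.
[[nodiscard]] cudaError_t invokeFoo(float* out, const float* in, int n, cudaStream_t stream)
{
    fooKernel<<<(n + 255) / 256, 256, 0, stream>>>(out, in, n);
    return cudaGetLastError();
}

// Call site: the returned status feeds straight into the unified macro.
TM_CUDA_CHECK(invokeFoo(out, in, n, stream));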

Reviewed changes

Copilot reviewed 103 out of 103 changed files in this pull request and generated 7 comments.

Summary per file (each row: file, then description):
src/turbomind/utils/memory_utils.h Change invoke wrapper signature to cudaError_t + [[nodiscard]].
src/turbomind/utils/memory_utils.cu Return cudaGetLastError() and migrate memcpy error checks.
src/turbomind/utils/cuda_utils.h Add unified CUDA/CUdriver check macros + reporting API.
src/turbomind/utils/cuda_utils.cc Implement unified fatal CUDA/CUdriver error reporting using Scope+Logger.
src/turbomind/utils/anomaly_handler.cu Replace FT_CHECK/CUDA checks and adjust invoke wrappers to return cudaError_t.
src/turbomind/turbomind.cc Replace FT_CHECK and migrate device property check.
src/turbomind/python/bind.cpp Migrate CUDA error checks in safe_memcpy.
src/turbomind/models/llama/unified_decoder.cc Add function/layer scopes and migrate CUDA error checks to TM_CUDA_CHECK.
src/turbomind/models/llama/unified_attention_layer.cc Add function scopes and migrate CUDA checks/macros.
src/turbomind/models/llama/moe_ffn_layer.cc Migrate MoE kernel calls to TM_CUDA_CHECK and drop sync checks.
src/turbomind/models/llama/mla_utils.cu Make MLA copy launcher return cudaError_t and check it.
src/turbomind/models/llama/llama_utils.cu Migrate sync/check macros to TM_CUDA_CHECK.
src/turbomind/models/llama/llama_kernels.h Convert multiple llama kernel APIs to [[nodiscard]] cudaError_t.
src/turbomind/models/llama/gated_delta_net_kernels.h Convert gated-delta-net kernel APIs to [[nodiscard]] cudaError_t.
src/turbomind/models/llama/gated_delta_net_kernels.cu Return cudaError_t from launchers and standardize error reporting.
src/turbomind/models/llama/context.h Replace check_cuda_error with TM_CUDA_CHECK.
src/turbomind/models/llama/bench_gated_delta_net.cc Add TM_CUDA_CHECK around kernel launchers.
src/turbomind/models/llama/bench_conv1d_silu.cc Add TM_CUDA_CHECK around kernel launchers.
src/turbomind/models/llama/LlamaWeight.cc Replace CUDA error checks with TM_CUDA_CHECK.
src/turbomind/models/llama/LlamaLinear.cu Migrate workspace alloc/memset and MoE dispatch calls to TM_CUDA_CHECK.
src/turbomind/models/llama/LlamaFfnLayer.cc Remove sync checks after kernel calls (relying on checked launchers).
src/turbomind/models/llama/LlamaDenseWeight.cc Migrate memcpy/memset/transpose calls to TM_CUDA_CHECK.
src/turbomind/models/llama/GatedDeltaNetWeight.cc Migrate transpose/copy checks to TM_CUDA_CHECK.
src/turbomind/models/llama/GatedDeltaNetLayer.cc Add function scope and migrate async stream/event + kernel checks.
src/turbomind/models/llama/Barrier.h Replace FT_CHECK with TM_CHECK and adjust includes.
src/turbomind/models/language_model.cc Wrap invoke calls with TM_CUDA_CHECK and add function scopes.
src/turbomind/kernels/unfused_attention_kernels.h Convert selected unfused attention APIs to [[nodiscard]] cudaError_t.
src/turbomind/kernels/unfused_attention_kernels.cu Return cudaGetLastError() from launchers and migrate checks.
src/turbomind/kernels/stop_criteria_kernels.h Convert stop-criteria APIs to [[nodiscard]] cudaError_t.
src/turbomind/kernels/stop_criteria_kernels.cu Return cudaGetLastError() from stop-criteria kernels.
src/turbomind/kernels/sampling_topp_kernels.h Convert TopP APIs to [[nodiscard]] cudaError_t.
src/turbomind/kernels/sampling_topp_kernels.cu Return cudaGetLastError() and migrate CUB calls to TM_CUDA_CHECK.
src/turbomind/kernels/sampling_topk_kernels.h Convert TopK API to [[nodiscard]] cudaError_t.
src/turbomind/kernels/sampling_topk_kernels.cu Return cudaGetLastError() from TopK kernel launcher.
src/turbomind/kernels/sampling_penalty_kernels.h Convert penalty APIs to [[nodiscard]] cudaError_t.
src/turbomind/kernels/sampling_penalty_kernels.cu Return cudaGetLastError() from penalty launchers.
src/turbomind/kernels/sampling_kernels.h Convert sampling API to [[nodiscard]] cudaError_t.
src/turbomind/kernels/sampling_kernels.cu Return cudaGetLastError() from sampling launcher.
src/turbomind/kernels/norm/rms_norm.h Convert RMSNorm APIs to [[nodiscard]] cudaError_t.
src/turbomind/kernels/norm/rms_norm.cu Return cudaError_t from norm launchers and migrate checks.
src/turbomind/kernels/logprob_kernels.h Convert logprob API to [[nodiscard]] cudaError_t.
src/turbomind/kernels/logprob_kernels.cu Return cudaGetLastError() from logprob kernels.
src/turbomind/kernels/gpt_kernels.h Convert transpose/embedding APIs to [[nodiscard]] cudaError_t.
src/turbomind/kernels/gpt_kernels.cu Return cudaGetLastError() from embedding/transpose launchers.
src/turbomind/kernels/gemm/test/testbed_v3.h Add TM_CUDA_CHECK for transpose/MoE helpers in tests.
src/turbomind/kernels/gemm/test/test_moe_utils.cu Add TM_CUDA_CHECK around MoE helper kernel call.
src/turbomind/kernels/gemm/moe_utils_v2.h Convert MoE dispatch/combine/gate helpers to [[nodiscard]] cudaError_t.
src/turbomind/kernels/gemm/moe_utils_v2.cu Return cudaError_t from MoE helper launchers and adjust dispatch macro usage.
src/turbomind/kernels/gemm/kernel_impl_sm90.h Replace manual CUDA status checks with TM_CUDA_CHECK.
src/turbomind/kernels/gemm/cast.cu Return cudaError_t from cast launcher and check it.
src/turbomind/kernels/decoding_kernels.h Convert decoding helper APIs to [[nodiscard]] cudaError_t.
src/turbomind/kernels/decoding_kernels.cu Return cudaError_t from decoding helper launchers.
src/turbomind/kernels/attention/test_attention.cu Wrap attention-kernel helpers with TM_CUDA_CHECK.
src/turbomind/kernels/attention/registry.cu Replace direct CUDA status comparisons with TM_CUDA_CHECK.
src/turbomind/kernels/attention/reference.h Convert rotary embedding reference invoke API to [[nodiscard]] cudaError_t.
src/turbomind/kernels/attention/reference.cu Return cudaGetLastError() and check invoke calls in reference implementation.
src/turbomind/kernels/attention/reduce.h Convert reduce API to [[nodiscard]] cudaError_t.
src/turbomind/kernels/attention/reduce.cu Return cudaGetLastError() from reduce launcher.
src/turbomind/kernels/attention/kv_cache_utils_v2.h Convert KV cache APIs to [[nodiscard]] cudaError_t and return through helpers.
src/turbomind/kernels/attention/kv_cache_utils_v2.cu Return cudaGetLastError() from KV cache launchers and replace FT_CHECK with TM_UNREACHABLE.
src/turbomind/kernels/attention/kernel_impl.h Change kernel interface to return cudaError_t.
src/turbomind/kernels/attention/kernel.h Change base kernel interface to return cudaError_t.
src/turbomind/kernels/attention/decoding_template.h Replace manual abort-on-error with TM_CUDA_CHECK and check reduce call.
src/turbomind/kernels/attention/decoding.cu Check kernel launch return via TM_CUDA_CHECK.
src/turbomind/kernels/attention/cp_utils.cu Remove sync error checking call.
src/turbomind/kernels/attention/attention_template.h Return cudaError_t from attention invoke and check reduce call.
src/turbomind/kernels/attention/attention.cu Check kernel launch return via TM_CUDA_CHECK.
src/turbomind/kernels/activation_kernels.h Convert activation launcher to [[nodiscard]] cudaError_t.
src/turbomind/kernels/activation_kernels.cu Return cudaGetLastError() from activation launchers.
src/turbomind/generation/utils.h Replace FT_CHECK with TM_CHECK.
src/turbomind/generation/stop_criteria.cc Wrap stop-criteria kernel calls with TM_CUDA_CHECK.
src/turbomind/generation/sampling.cc Wrap sampling kernels with TM_CUDA_CHECK.
src/turbomind/generation/logits_processor.cc Wrap penalty kernels with TM_CUDA_CHECK and remove sync checks.
src/turbomind/generation/generation.cc Wrap cast kernel with TM_CUDA_CHECK and remove sync checks.
src/turbomind/engine/model_executor.cc Add function scope and replace device set check with TM_CUDA_CHECK.
src/turbomind/engine/engine.cc Replace CUDA mem info/device set checks with TM_CUDA_CHECK.
src/turbomind/core/test_scope.cc Add Catch2 tests for scope stack + signature stripping.
src/turbomind/core/tensor.cu Replace CUDA error checks with TM_CUDA_CHECK.
src/turbomind/core/tensor.cc Replace CUDA copy/memset checks with TM_CUDA_CHECK.
src/turbomind/core/stream.h Replace stream/event CUDA checks with TM_CUDA_CHECK.
src/turbomind/core/stream.cc Replace stream wait check with TM_CUDA_CHECK.
src/turbomind/core/scope.h Add Scope API + macros for named/function scopes.
src/turbomind/core/scope.cc Implement Scope push/pop and function signature stripping.
src/turbomind/core/logger.h Route fatal logging through traced fatal implementation.
src/turbomind/core/logger.cc Add traced fatal path + print scope trace on fatal signals.
src/turbomind/core/core.h Export Scope in core umbrella header.
src/turbomind/core/context.h Add scope stack API (push_scope, scope_trace, etc.).
src/turbomind/core/context.cc Implement per-thread scope stack and trace formatting.
src/turbomind/core/check.cc Integrate TM_CHECK/TM_CHECK_NOTNULL with scope tracing and fatal logging.
src/turbomind/core/buffer.cc Replace CUDA memcpy/memset checks with TM_CUDA_CHECK.
src/turbomind/core/allocator.cc Replace CUDA allocator checks with TM_CUDA_CHECK.
src/turbomind/core/CMakeLists.txt Build new scope.cc and add test_scope target.
src/turbomind/comm/test_comm.cu Replace CUDA checks with TM_CUDA_CHECK.
src/turbomind/comm/nccl/nccl.cu Replace CUDA checks + RMSNorm invoke check with TM_CUDA_CHECK.
src/turbomind/comm/cuda_ipc/semaphore.h Replace CUDA alloc/copy/sync/free checks with TM_CUDA_CHECK.
src/turbomind/comm/cuda_ipc/fused_allreduce_ex.cu Replace FT_CHECK with TM_CHECK and drop sync-check call.
src/turbomind/comm/cuda_ipc/fused_allreduce.cu Wrap RMSNorm invoke call with TM_CUDA_CHECK.
src/turbomind/comm/cuda_ipc/cuda_ipc_comm.cu Replace checks with TM_CHECK/TM_CUDA_CHECK/TM_CUDRV_CHECK.
src/turbomind/comm/cuda_ipc/broadcast.cu Replace CUDA memcpy checks with TM_CUDA_CHECK.
src/turbomind/comm/cuda_ipc/bootstrap.h Replace FT_CHECK with TM_CHECK and add check include.
src/turbomind/comm/cuda_ipc/allreduce.cu Replace FT_CHECK with TM_CHECK.
src/turbomind/comm/cuda_ipc/allgather.cu Replace CUDA memcpy checks with TM_CUDA_CHECK.


Comment on lines +136 to +140
return {};
}
std::ostringstream oss;
oss << std::hex << std::this_thread::get_id();
std::string s = fmt::format("*** stacktrace of thread 0x{} ***\n", oss.str());

Copilot AI Apr 30, 2026


Context::scope_trace() uses std::ostringstream but this translation unit doesn't include <sstream>, which will fail to compile on standard-conforming toolchains. Add #include <sstream> (or avoid ostringstream by formatting the thread id differently).
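
A minimal sketch of either fix (the std::hash variant yields a hashed per-thread tag rather than the native handle, which is usually good enough for correlating log lines):

// Option 1: make the existing code well-formed.
#include <sstream>  // required for std::ostringstream

// Option 2: drop the stream entirely and format a stable integer id.
#include <functional>
#include <thread>
std::string s = fmt::format("*** stacktrace of thread {:#x} ***\n",
                            std::hash<std::thread::id>{}(std::this_thread::get_id()));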

Comment on lines +301 to 305
auto trace = Context::scope_trace();
if (!trace.empty()) {
fmt::print(stderr, "{}", trace);
}
AsyncLogWorker::Instance().OnSignal();

Copilot AI Apr 30, 2026


OnFatalSignal is a POSIX signal handler, but it calls Context::scope_trace() and fmt::print, both of which may allocate, lock, or touch non-async-signal-safe state. This can deadlock or crash during a SIGSEGV/SIGABRT path. Consider capturing/printing the trace in the normal fatal path (before raising signals), or in the handler only using async-signal-safe APIs (e.g., write(2, ...)) with a preallocated buffer.

Suggested change
auto trace = Context::scope_trace();
if (!trace.empty()) {
fmt::print(stderr, "{}", trace);
}
AsyncLogWorker::Instance().OnSignal();
#ifndef _WIN32
static constexpr char kFatalSignalMessage[] = "Fatal signal received\n";
(void)::write(STDERR_FILENO, kFatalSignalMessage, sizeof(kFatalSignalMessage) - 1);
#endif

Comment on lines 229 to +233
const dim3 block(std::min(batch_size * end_ids_size, 1024));
const dim3 grid((batch_size * end_ids_size + block.x - 1) / block.x);
batchApplyMinLengthPenalty<<<block, grid, 0, stream>>>(
logits, min_lengths, sequnece_lengths, vocab_size_padded, batch_size, end_ids, end_ids_size);
return cudaGetLastError();

Copilot AI Apr 30, 2026


The kernel launch configuration is swapped: batchApplyMinLengthPenalty<<<block, grid>>> uses block as gridDim and grid as blockDim. This can exceed the max threads-per-block and/or run the wrong number of threads. Swap to <<<grid, block>>> (and consider an early return when batch_size*end_ids_size == 0).
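
A sketch of the corrected launch, keeping the variable names from the snippet (the n alias and the zero-size guard are additions suggested above, not existing code):

const int n = batch_size * end_ids_size;
if (n == 0) {
    return cudaSuccess;  // avoid launching with gridDim = 0
}
const dim3 block(std::min(n, 1024));
const dim3 grid((n + block.x - 1) / block.x);
// grid comes first in the launch configuration, block second
batchApplyMinLengthPenalty<<<grid, block, 0, stream>>>(
    logits, min_lengths, sequnece_lengths, vocab_size_padded, batch_size, end_ids, end_ids_size);
return cudaGetLastError();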

Comment on lines 394 to +397
dim3 block(min(256, size));
dim3 grid(ceil(size / 256.));
plusScalar<<<block, grid, 0, stream>>>(buf, val, size);
return cudaGetLastError();

Copilot AI Apr 30, 2026


The kernel launch configuration is swapped: plusScalar<<<block, grid>>> passes the intended threads-per-block as gridDim and the intended blocks as blockDim. Swap to <<<grid, block>>> (and consider returning cudaSuccess when size == 0 to avoid launching with gridDim=0).
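
A sketch of the corrected launch, also replacing the floating-point ceil with integer ceil-division:

if (size == 0) {
    return cudaSuccess;  // nothing to do; gridDim = 0 is an invalid launch
}
const dim3 block(std::min(256, size));
const dim3 grid((size + 255) / 256);  // integer ceil(size / 256)
plusScalar<<<grid, block, 0, stream>>>(buf, val, size);
return cudaGetLastError();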

Comment on lines +584 to +587
// TM_CHECK(0);
// TM_CHECK_NOTNULL(nullptr);
TM_CUDA_CHECK(cudaErrorInvalidValue);
// TM_LOG_FATAL("End of the Road");

Copilot AI Apr 30, 2026


This hard-coded TM_CUDA_CHECK(cudaErrorInvalidValue); will always fatal at runtime and looks like leftover debug instrumentation. It should be removed (or gated behind a dedicated debug flag/test-only path) before merging.
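
If the trigger is worth keeping to exercise the fatal path, a compile-time gate keeps it out of normal builds (TM_DEBUG_FATAL_PATH is a hypothetical flag, not an existing option in the PR):

#ifdef TM_DEBUG_FATAL_PATH
// Deliberately trip the fatal path to verify trace output;
// never defined in release builds.
TM_CUDA_CHECK(cudaErrorInvalidValue);
#endif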

Comment on lines 286 to 291
g_pinf_val_,
g_ninf_val_,
g_nan_val_);

sync_check_cuda_error();
}
return cudaGetLastError();
}

Copilot AI Apr 30, 2026


invokeCountAndFixAnomaly calls cudaGetLastError() even when g_level is false / no kernel is launched. cudaGetLastError() clears the thread’s last CUDA error and can mask unrelated failures. Only query the last error when a kernel was launched (or use cudaPeekAtLastError()), and return cudaSuccess otherwise.
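
A sketch of the guarded pattern; it applies equally to invokeFixLogitsAnomaly below. g_level and the kernel call stand in for the PR's existing code, and the launch parameters are placeholders:

cudaError_t invokeCountAndFixAnomaly(/* ... */)
{
    if (!g_level) {
        return cudaSuccess;  // nothing launched, nothing to query
    }
    countAndFixAnomalyKernel<<<grid, block, 0, stream>>>(/* ... */);
    return cudaGetLastError();  // reflects only the launch above
}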

Comment on lines 306 to 310
batch_size,
fallback_);

sync_check_cuda_error();
}
return cudaGetLastError();
}

Copilot AI Apr 30, 2026


invokeFixLogitsAnomaly calls cudaGetLastError() even when g_level is false / no kernel is launched. This clears the last CUDA error and can hide failures from earlier CUDA calls. Only query/return the last error when the kernel is actually launched (or use cudaPeekAtLastError()), and return cudaSuccess otherwise.
