[CMSIS-NN] Fix stateful execution and batch-major striding for CMSIS-NN LSTM#3564
[CMSIS-NN] Fix stateful execution and batch-major striding for CMSIS-NN LSTM#3564veblush wants to merge 1 commit into
Conversation
This PR fixes two critical issues in `arm_lstm_unidirectional_s8` and `s16` that prevent state persistence in streaming models and cause out-of-bounds reads during non-time-major inference. These issues are closely related to in tensorflow/tflite-micro#3564. Problem: - State Wiping: By default, `arm_lstm_unidirectional_*` unconditionally sets `hidden_in` to `NULL` and memsets `cell_state` to 0. This discards the `HiddenStateTensor` and `CellStateTensor` that TFLM relies on to persist state across `Invoke()` calls for streaming models. - Striding Bug: In the `time_major` = `false` block of `arm_lstm_unidirectional_*`, CMSIS-NN attempts to jump between batches by passing `batch_offset` = `params->time_steps` to `arm_nn_lstm_step_*`. However, `arm_nn_lstm_step_*` forwards this `batch_offset` to `arm_nn_vec_mat_mul_result_acc_s8_s16` for both the `data_in` and `hidden_in` pointers. Since the `hidden_state` buffer is contiguous (stride 1) and not strided like `data_in`, passing `batch_offset` = `params->time_steps` causes out-of-bounds reads on the hidden_in buffer at `timestep` t=0. Solution: - Adding a `hidden_state` pointer to `cmsis_nn_lstm_context`. - Forwarding this `hidden_state` as `hidden_in` when present, skipping the `cell_state` wiping if so. - Explicitly iterating over the `batch_size` in the `time_major` = `false` case when computing step sizes, which forces `batch_offset` = 1 and avoids the buggy out-of-bounds stride entirely while writing to the final memory buffer sequentially.
| if (params.time_steps > 0) { | ||
| std::copy_n(step_hidden_in, params.batch_size * params.hidden_size, | ||
| hidden_state); | ||
| } |
There was a problem hiding this comment.
not sure why this is here. When using the greedy memory planner, the hidden_state may be overwritten by subsequent operator's output(s). See next comment for more info.
| if (params.time_steps > 0) { | ||
| std::copy_n(step_hidden_in, params.hidden_size, | ||
| hidden_state + b * params.hidden_size); | ||
| } |
There was a problem hiding this comment.
Same as the above comment with this additional info: I have not been able to produce a Colab where the. converter will produce a stateful, fused LSTM operation with quantization. The converter (and the Colab session) crash every time. The only time I can make a stateful LSTM in Colab, always produces an unfused LSTM.
| // Update hidden state for next step | ||
| std::copy_n(hidden_out, params.batch_size * params.hidden_size, | ||
| hidden_state); |
There was a problem hiding this comment.
Don't understand why this is inside the step loop. Why not just update the hidden state input pointer as was done in the s8 code?
| // Update hidden state for next step | ||
| std::copy_n(hidden_out, params.hidden_size, current_hidden); |
There was a problem hiding this comment.
Don't understand why this is inside the step loop. Why not just update the hidden state input pointer as was done in the s8 code?
|
Could we add test case where its failing before/working after the fix? |
Problem
The current CMSIS-NN LSTM wrapper uses arm_lstm_unidirectional_s8 and arm_lstm_unidirectional_s16. These CMSIS-NN functions are designed for stateless sequence evaluation: they explicitly wipe the cell state at t=0 and ignore any initial hidden state, returning only the sequence outputs.
This breaks TFLM's streaming/embedded ML workloads which rely on stateful LSTMs where the CellStateTensor and HiddenStateTensor persist as variable tensors across Invoke() calls.
Furthermore, CMSIS-NN's internal implementation for batch-major tensors (time_major=false with batch_size > 1) incorrectly jumps memory by time_steps, causing an out-of-bounds read on the contiguous hidden_state buffer.
Solution
BUG=N/A