Description
I am seeing incorrect dgemm_ results when one application thread is inside dsyevd_ and another application thread calls dgemm_, using the same loaded OpenBLAS library in the same process.
This originally surfaced in jax-ml/jax#38056, where JAX async dispatch could leave a CPU LAPACK eigensolver call in flight while a later NumPy/OpenBLAS matmul started. The C reproducer below removes JAX, NumPy, and SciPy. It links directly against OpenBLAS and calls the Fortran BLAS/LAPACK symbols.
Environment
OpenBLAS config: OpenBLAS 0.3.33 NO_AFFINITY USE_OPENMP VORTEX MAX_THREADS=128
Platform: macOS / Apple ARM64
Install source: conda-forge
Summary
The reproducer does the following:
background pthread: dsyevd_ on its own matrix/work arrays
main thread: repeatedly calls dgemm_ while dsyevd_ is in flight
The dgemm_ inputs are all ones, so every output entry should equal K.
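As a sanity check on that expected value, here is a minimal sketch of a naive column-major reference product that does not touch OpenBLAS at all. The helper names `ref_gemm` and `ones_product_equals_k` are hypothetical and are not part of the reproducer. Summing K ones is exact in double precision for the sizes used here, so the comparison can be exact.

```c
#include <stddef.h>
#include <stdlib.h>

/* Naive column-major C = A * B. Hypothetical reference helper; no BLAS involved. */
static void ref_gemm(const double *A, const double *B, double *C,
                     int M, int N, int K) {
    for (int j = 0; j < N; ++j) {
        for (int i = 0; i < M; ++i) {
            double acc = 0.0;
            for (int p = 0; p < K; ++p) {
                acc += A[i + (size_t)p * M] * B[p + (size_t)j * K];
            }
            C[i + (size_t)j * M] = acc;
        }
    }
}

/* Returns 1 if every entry of ones(M,K) @ ones(K,N) equals K, 0 if not,
 * and -1 on allocation failure. */
static int ones_product_equals_k(int M, int N, int K) {
    size_t na = (size_t)M * (size_t)K;
    size_t nb = (size_t)K * (size_t)N;
    size_t nc = (size_t)M * (size_t)N;
    double *A = malloc(na * sizeof *A);
    double *B = malloc(nb * sizeof *B);
    double *C = malloc(nc * sizeof *C);
    int ok = -1;
    if (A && B && C) {
        for (size_t i = 0; i < na; ++i) A[i] = 1.0;
        for (size_t i = 0; i < nb; ++i) B[i] = 1.0;
        ref_gemm(A, B, C, M, N, K);
        ok = 1;
        for (size_t i = 0; i < nc; ++i) {
            if (C[i] != (double)K) { ok = 0; break; }
        }
    }
    free(A);
    free(B);
    free(C);
    return ok;
}
```

Calling `ones_product_equals_k(64, 64, 128)` exercises the same shape as the default failing case.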
The user arrays for the eigensolver and GEMM are separate. The only intended shared object is the OpenBLAS library/runtime itself.
I used OMP_NUM_THREADS to control the thread count because this OpenBLAS build reports USE_OPENMP.
Observed behavior
This fails for me:
OMP_NUM_THREADS=8 ./openblas_dsyevd_dgemm_repro 0 64 64 64 128 50
Output from one run:
OpenBLAS config: OpenBLAS 0.3.33 NO_AFFINITY USE_OPENMP VORTEX MAX_THREADS=128
OpenBLAS num_threads: 8
use_mutex=0 eig_n=64 M=64 N=64 K=128 iterations=50 GEMM_work=524288
FAIL iter=1 gemm_iter=1 max_abs_err=96.4034 bad=128 first_bad_index=1056 first_bad_value=31.7662 expected=128
Observed failure after 10 GEMM calls.
With a mutex around both OpenBLAS calls, the same case passes:
OMP_NUM_THREADS=8 ./openblas_dsyevd_dgemm_repro 1 64 64 64 128 50
Output:
OpenBLAS config: OpenBLAS 0.3.33 NO_AFFINITY USE_OPENMP VORTEX MAX_THREADS=128
OpenBLAS num_threads: 8
use_mutex=1 eig_n=64 M=64 N=64 K=128 iterations=50 GEMM_work=524288
No incorrect result observed after 50 GEMM calls.
Starting the process with one OpenMP thread also passes:
OMP_NUM_THREADS=1 ./openblas_dsyevd_dgemm_repro 0 64 64 64 128 50
Output:
OpenBLAS config: OpenBLAS 0.3.33 NO_AFFINITY USE_OPENMP VORTEX MAX_THREADS=128
OpenBLAS num_threads: 1
use_mutex=0 eig_n=64 M=64 N=64 K=128 iterations=50 GEMM_work=524288
No incorrect result observed after 314 GEMM calls.
I also see the same threshold behavior for several OpenMP thread counts. With M=N=64, K=128 fails for OMP_NUM_THREADS=2,3,4,8, while the adjacent K=127 case passes for all tested thread counts. With OMP_NUM_THREADS=1, both cases pass.
Threshold-like behavior
The failing case above has:
eig_n = 64
M * N * K = 64 * 64 * 128 = 524288 = 2^19
I also see failures for other factorizations of the same GEMM work:
OMP_NUM_THREADS=8 ./openblas_dsyevd_dgemm_repro 0 64 64 64 128 50
OMP_NUM_THREADS=8 ./openblas_dsyevd_dgemm_repro 0 64 64 128 64 50
OMP_NUM_THREADS=8 ./openblas_dsyevd_dgemm_repro 0 64 128 64 64 50
Outputs from one run:
use_mutex=0 eig_n=64 M=64 N=64 K=128 iterations=50 GEMM_work=524288
FAIL iter=1 gemm_iter=1 max_abs_err=96.4034 bad=128 first_bad_index=1056 first_bad_value=31.7662 expected=128
Observed failure after 10 GEMM calls.
use_mutex=0 eig_n=64 M=64 N=128 K=64 iterations=50 GEMM_work=524288
FAIL iter=16 gemm_iter=0 max_abs_err=64.1916 bad=192 first_bad_index=7424 first_bad_value=47.9644 expected=64
Observed failure after 104 GEMM calls.
use_mutex=0 eig_n=64 M=128 N=64 K=64 iterations=50 GEMM_work=524288
FAIL iter=1 gemm_iter=1 max_abs_err=64.1916 bad=1152 first_bad_index=64 first_bad_value=-0.137265 expected=64
Observed failure after 10 GEMM calls.
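To read the `first_bad_index` values: `check_gemm` walks C as a flat column-major array with leading dimension M, so a linear index maps back to a (row, column) pair as sketched below. `Entry` and `decode_index` are hypothetical names used only for illustration. I am not claiming the decoded positions say anything about how OpenBLAS partitions the work.

```c
#include <stddef.h>

/* Hypothetical helper type: a (row, column) position in C. */
typedef struct {
    size_t row;
    size_t col;
} Entry;

/* Column-major layout: linear index = row + col * M,
 * the same layout the reproducer passes to dgemm_ (LDC = M). */
static Entry decode_index(size_t index, int M) {
    Entry e;
    e.row = index % (size_t)M;
    e.col = index / (size_t)M;
    return e;
}
```

Under that reading, index 1056 with M=64 is row 32, column 16; index 7424 with M=64 is row 0, column 116; and index 64 with M=128 is row 64, column 0.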
Adjacent cases just below the apparent thresholds pass in my tests:
OMP_NUM_THREADS=8 ./openblas_dsyevd_dgemm_repro 0 63 64 64 128 50
OMP_NUM_THREADS=8 ./openblas_dsyevd_dgemm_repro 0 64 63 64 128 50
OMP_NUM_THREADS=8 ./openblas_dsyevd_dgemm_repro 0 64 64 63 128 50
OMP_NUM_THREADS=8 ./openblas_dsyevd_dgemm_repro 0 64 64 64 127 50
OMP_NUM_THREADS=8 ./openblas_dsyevd_dgemm_repro 0 64 64 127 64 50
OMP_NUM_THREADS=8 ./openblas_dsyevd_dgemm_repro 0 64 127 64 64 50
Outputs from one run:
use_mutex=0 eig_n=63 M=64 N=64 K=128 iterations=50 GEMM_work=524288
No incorrect result observed after 380 GEMM calls.
use_mutex=0 eig_n=64 M=63 N=64 K=128 iterations=50 GEMM_work=516096
No incorrect result observed after 465 GEMM calls.
use_mutex=0 eig_n=64 M=64 N=63 K=128 iterations=50 GEMM_work=516096
No incorrect result observed after 516 GEMM calls.
use_mutex=0 eig_n=64 M=64 N=64 K=127 iterations=50 GEMM_work=520192
No incorrect result observed after 539 GEMM calls.
use_mutex=0 eig_n=64 M=64 N=127 K=64 iterations=50 GEMM_work=520192
No incorrect result observed after 368 GEMM calls.
use_mutex=0 eig_n=64 M=127 N=64 K=64 iterations=50 GEMM_work=520192
No incorrect result observed after 495 GEMM calls.
On this machine/build, the issue appears to show up when both of these are true:
eig_n >= 64
M * N * K >= 2^19
under concurrent dsyevd_ / dgemm_ entry with multiple OpenMP threads.
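Written as a predicate, the pattern I observed looks like the sketch below. This only summarizes my test results on this machine and build. The function name `matches_observed_pattern` is hypothetical, and the constants are empirical, not values taken from the OpenBLAS source.

```c
#include <stdbool.h>
#include <stddef.h>

/* Empirical summary of the failing region seen in this report.
 * Not an OpenBLAS API, and the thresholds are observations only. */
static bool matches_observed_pattern(int eig_n, int M, int N, int K,
                                     int omp_threads) {
    const size_t work = (size_t)M * (size_t)N * (size_t)K;
    return omp_threads > 1            /* OMP_NUM_THREADS=1 passed       */
        && eig_n >= 64                /* eig_n=63 passed                */
        && work >= ((size_t)1 << 19); /* M*N*K just below 2^19 passed   */
}
```

For example, `(64, 64, 64, 128, 8)` falls inside the pattern, while `(64, 64, 64, 127, 8)` and `(63, 64, 64, 128, 8)` fall outside it, matching the runs above.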
Expected behavior
I would not expect dgemm_ to silently return incorrect values when it is operating on independent user arrays.
I realize concurrent calls into a multithreaded BLAS library from an already-multithreaded application might be discouraged or unsupported. Still, the failure mode here is silent incorrect output from dgemm_, not just oversubscription or poor performance. If this usage pattern is unsupported for USE_OPENMP builds, it would be useful to know what the intended safe usage pattern is.
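For now, the only workaround I have verified is the one the mutex control uses: route every BLAS/LAPACK call in the process through a single lock. A minimal sketch of that idea follows. The names `run_blas_serialized`, `blas_call_fn`, and `fake_blas_call` are hypothetical, and I do not know whether this is the pattern the maintainers would recommend.

```c
#include <pthread.h>

/* One process-wide lock shared by every caller of BLAS/LAPACK. */
static pthread_mutex_t blas_lock = PTHREAD_MUTEX_INITIALIZER;

/* Hypothetical callback type: a small wrapper around one BLAS/LAPACK call. */
typedef void (*blas_call_fn)(void *ctx);

/* Run fn(ctx) while holding the lock, so no two wrapped calls overlap. */
static void run_blas_serialized(blas_call_fn fn, void *ctx) {
    pthread_mutex_lock(&blas_lock);
    fn(ctx); /* in real code: a wrapper that calls dgemm_ or dsyevd_ */
    pthread_mutex_unlock(&blas_lock);
}

/* Stand-in for a real BLAS wrapper so the sketch builds without OpenBLAS. */
static void fake_blas_call(void *ctx) {
    int *calls = ctx;
    ++*calls;
}
```

This trades away any overlap between the eigensolver and the GEMM, which is exactly what the `use_mutex=1` runs do.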
Reproducer
// openblas_dsyevd_dgemm_repro.c
//
// Reproducer for incorrect dgemm_ results when dgemm_ is called while
// another application thread is inside dsyevd_.
//
//   thread 1: dsyevd_
//   thread 0: dgemm_ while dsyevd_ is in flight
//
// GEMM computes ones(M,K) @ ones(K,N), so every C entry should equal K.
//
// Usage:
//   ./openblas_dsyevd_dgemm_repro [use_mutex] [eig_n] [M] [N] [K] [iterations]
//
// Default:
//   use_mutex  = 0
//   eig_n      = 64
//   M          = 64
//   N          = 64
//   K          = 128
//   iterations = 50
//
// For the default GEMM, M*N*K = 64*64*128 = 524288 = 2^19.
//
// Example failing case:
//   OMP_NUM_THREADS=8 ./openblas_dsyevd_dgemm_repro 0
//
// Mutex control:
//   OMP_NUM_THREADS=8 ./openblas_dsyevd_dgemm_repro 1
//
// Single-thread OpenMP control:
//   OMP_NUM_THREADS=1 ./openblas_dsyevd_dgemm_repro 0
//
// Adjacent below-threshold case:
//   OMP_NUM_THREADS=8 ./openblas_dsyevd_dgemm_repro 0 64 64 64 127 50

#define _GNU_SOURCE

#include <math.h>
#include <pthread.h>
#include <stdatomic.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

extern char *openblas_get_config(void);
extern int openblas_get_num_threads(void);

extern void dgemm_(
    const char *TRANSA, const char *TRANSB,
    const int *M, const int *N, const int *K,
    const double *ALPHA,
    const double *A, const int *LDA,
    const double *B, const int *LDB,
    const double *BETA,
    double *C, const int *LDC
);

extern void dsyevd_(
    const char *JOBZ, const char *UPLO,
    const int *N,
    double *A, const int *LDA,
    double *W,
    double *WORK, const int *LWORK,
    int *IWORK, const int *LIWORK,
    int *INFO
);

static pthread_mutex_t blas_mutex = PTHREAD_MUTEX_INITIALIZER;
static atomic_int eig_entered = 0;
static atomic_int eig_done = 0;
static int use_mutex = 0;

static void *xalloc(size_t nbytes) {
    void *p = NULL;
    if (posix_memalign(&p, 64, nbytes) != 0 || p == NULL) {
        fprintf(stderr, "allocation failed for %zu bytes\n", nbytes);
        exit(1);
    }
    return p;
}

static void fill(double *x, size_t n, double v) {
    for (size_t i = 0; i < n; ++i) {
        x[i] = v;
    }
}

static void make_symmetric(double *A, int n) {
    for (int j = 0; j < n; ++j) {
        for (int i = 0; i < n; ++i) {
            A[i + (size_t)j * n] =
                (i == j) ? (double)n + 1.0
                         : 1.0 / (1.0 + fabs((double)i - (double)j));
        }
    }
}

typedef struct {
    int n;
    double *A0;
    double *A;
    double *W;
    double *WORK;
    int *IWORK;
    int LWORK;
    int LIWORK;
    int INFO;
} EigCtx;

static void *eig_thread(void *arg) {
    EigCtx *e = (EigCtx *)arg;
    char jobz = 'V';
    char uplo = 'U';
    int lda = e->n;

    // dsyevd_ overwrites A, so restore the input matrix on every run.
    memcpy(e->A, e->A0, sizeof(double) * (size_t)e->n * (size_t)e->n);
    e->INFO = 0;

    // In mutex mode the lock is taken before eig_entered is published,
    // so the main thread's dgemm_ blocks until dsyevd_ has returned.
    if (use_mutex) {
        pthread_mutex_lock(&blas_mutex);
    }
    atomic_store_explicit(&eig_entered, 1, memory_order_release);

    dsyevd_(
        &jobz, &uplo,
        &e->n,
        e->A, &lda,
        e->W,
        e->WORK, &e->LWORK,
        e->IWORK, &e->LIWORK,
        &e->INFO
    );

    if (use_mutex) {
        pthread_mutex_unlock(&blas_mutex);
    }
    atomic_store_explicit(&eig_done, 1, memory_order_release);
    return NULL;
}

static void call_dgemm(double *A, double *B, double *C, int M, int N, int K) {
    char trans = 'N';
    double alpha = 1.0;
    double beta = 0.0;

    if (use_mutex) {
        pthread_mutex_lock(&blas_mutex);
    }
    dgemm_(
        &trans, &trans,
        &M, &N, &K,
        &alpha,
        A, &M,
        B, &K,
        &beta,
        C, &M
    );
    if (use_mutex) {
        pthread_mutex_unlock(&blas_mutex);
    }
}

static int check_gemm(double *C, int M, int N, int K, int iter, int gemm_iter) {
    const double expected = (double)K;
    const size_t total = (size_t)M * (size_t)N;
    double max_abs_err = 0.0;
    size_t bad = 0;
    size_t first_bad = 0;
    double first_value = 0.0;

    for (size_t i = 0; i < total; ++i) {
        double err = fabs(C[i] - expected);
        if (err > max_abs_err) {
            max_abs_err = err;
        }
        if (err > 1e-9) {
            if (bad == 0) {
                first_bad = i;
                first_value = C[i];
            }
            ++bad;
        }
    }

    if (bad) {
        fprintf(stderr,
                "FAIL iter=%d gemm_iter=%d max_abs_err=%g bad=%zu "
                "first_bad_index=%zu first_bad_value=%g expected=%g\n",
                iter, gemm_iter, max_abs_err, bad,
                first_bad, first_value, expected);
        return 1;
    }
    return 0;
}

int main(int argc, char **argv) {
    use_mutex = (argc > 1) ? atoi(argv[1]) : 0;
    int eig_n = (argc > 2) ? atoi(argv[2]) : 64;
    int M = (argc > 3) ? atoi(argv[3]) : 64;
    int N = (argc > 4) ? atoi(argv[4]) : 64;
    int K = (argc > 5) ? atoi(argv[5]) : 128;
    int iterations = (argc > 6) ? atoi(argv[6]) : 50;

    if (eig_n < 1 || M < 1 || N < 1 || K < 1 || iterations < 1) {
        fprintf(stderr,
                "Usage: %s [use_mutex] [eig_n] [M] [N] [K] [iterations]\n",
                argv[0]);
        return 2;
    }

    fprintf(stderr, "OpenBLAS config: %s\n", openblas_get_config());
    fprintf(stderr, "OpenBLAS num_threads: %d\n", openblas_get_num_threads());
    fprintf(stderr,
            "use_mutex=%d eig_n=%d M=%d N=%d K=%d iterations=%d "
            "GEMM_work=%zu\n",
            use_mutex, eig_n, M, N, K, iterations,
            (size_t)M * (size_t)N * (size_t)K);

    EigCtx e;
    e.n = eig_n;
    e.INFO = 0;
    e.A0 = xalloc(sizeof(double) * (size_t)eig_n * (size_t)eig_n);
    e.A = xalloc(sizeof(double) * (size_t)eig_n * (size_t)eig_n);
    e.W = xalloc(sizeof(double) * (size_t)eig_n);
    make_symmetric(e.A0, eig_n);

    // Workspace query: LWORK = LIWORK = -1 makes dsyevd_ return the
    // optimal sizes in WORK[0] and IWORK[0] without doing the solve.
    char jobz = 'V';
    char uplo = 'U';
    int lda = eig_n;
    int lwork_query = -1;
    int liwork_query = -1;
    double work_query = 0.0;
    int iwork_query = 0;
    memcpy(e.A, e.A0, sizeof(double) * (size_t)eig_n * (size_t)eig_n);
    dsyevd_(
        &jobz, &uplo,
        &eig_n,
        e.A, &lda,
        e.W,
        &work_query, &lwork_query,
        &iwork_query, &liwork_query,
        &e.INFO
    );
    if (e.INFO != 0) {
        fprintf(stderr, "dsyevd workspace query failed: info=%d\n", e.INFO);
        return 1;
    }
    e.LWORK = (int)work_query;
    e.LIWORK = iwork_query;
    e.WORK = xalloc(sizeof(double) * (size_t)e.LWORK);
    e.IWORK = xalloc(sizeof(int) * (size_t)e.LIWORK);

    // GEMM operands live in their own buffers, separate from the eigensolver's.
    double *GA = xalloc(sizeof(double) * (size_t)M * (size_t)K);
    double *GB = xalloc(sizeof(double) * (size_t)K * (size_t)N);
    double *GC = xalloc(sizeof(double) * (size_t)M * (size_t)N);
    fill(GA, (size_t)M * (size_t)K, 1.0);
    fill(GB, (size_t)K * (size_t)N, 1.0);

    long total_gemm = 0;
    for (int iter = 0; iter < iterations; ++iter) {
        pthread_t th;
        atomic_store_explicit(&eig_entered, 0, memory_order_release);
        atomic_store_explicit(&eig_done, 0, memory_order_release);
        if (pthread_create(&th, NULL, eig_thread, &e) != 0) {
            fprintf(stderr, "pthread_create failed\n");
            return 1;
        }

        // Spin until the eigensolver thread is about to call dsyevd_.
        while (atomic_load_explicit(&eig_entered, memory_order_acquire) == 0) {
        }

        // Call dgemm_ repeatedly for as long as dsyevd_ is in flight.
        int gemm_iter = 0;
        while (atomic_load_explicit(&eig_done, memory_order_acquire) == 0) {
            memset(GC, 0, sizeof(double) * (size_t)M * (size_t)N);
            call_dgemm(GA, GB, GC, M, N, K);
            ++total_gemm;
            if (check_gemm(GC, M, N, K, iter, gemm_iter)) {
                pthread_join(th, NULL);
                fprintf(stderr, "Observed failure after %ld GEMM calls.\n", total_gemm);
                free(e.A0);
                free(e.A);
                free(e.W);
                free(e.WORK);
                free(e.IWORK);
                free(GA);
                free(GB);
                free(GC);
                return 1;
            }
            ++gemm_iter;
        }

        pthread_join(th, NULL);
        if (e.INFO != 0) {
            fprintf(stderr, "dsyevd failed: info=%d\n", e.INFO);
            return 1;
        }
    }

    fprintf(stderr, "No incorrect result observed after %ld GEMM calls.\n", total_gemm);
    free(e.A0);
    free(e.A);
    free(e.W);
    free(e.WORK);
    free(e.IWORK);
    free(GA);
    free(GB);
    free(GC);
    return 0;
}
Compile command I used:
clang -O2 -std=c11 -pthread openblas_dsyevd_dgemm_repro.c \
"$CONDA_PREFIX/lib/libopenblas.0.dylib" \
-Wl,-rpath,"$CONDA_PREFIX/lib" \
-lm \
-o openblas_dsyevd_dgemm_repro