From ad7adec20ca40bfd31c634fb57876b28e036cc93 Mon Sep 17 00:00:00 2001 From: Lisandro Dalcin Date: Tue, 1 Aug 2023 17:52:45 +0300 Subject: [PATCH] cffi: typedef shmem_{ctx|team}_t as an opaque struct type This allows for OpenSHMEM implementations declaring ctx/team types as either a pointer types or an integral types. Declaring `shmem_{ctx|team}_t` as an opaque struct type has an annoying side-effect: now ctx/team handles can no longer be compared for equality. Therefore, add a couple auxiliary functions `eq_{ctx|team}` to compare handles for equality. Additionally, in case we need them in the future, add functions to convert to/from integer values. --- src/ffibuilder.py | 10 ++++++++-- src/libshmem.c | 1 + src/libshmem/hdltypes.h | 7 +++++++ src/shmem4py/shmem.py | 22 +++++++++++----------- 4 files changed, 27 insertions(+), 13 deletions(-) create mode 100644 src/libshmem/hdltypes.h diff --git a/src/ffibuilder.py b/src/ffibuilder.py index e37c544..8a0fe77 100644 --- a/src/ffibuilder.py +++ b/src/ffibuilder.py @@ -11,8 +11,8 @@ def build_api( module="api", shmem_h="shmem.h", - shmem_ctx_t='...*', - shmem_team_t='...*', + shmem_ctx_t='struct{...;}', + shmem_team_t='struct{...;}', ): from apicodegen import generate ffi = cffi.FFI() @@ -23,6 +23,12 @@ def build_api( ffi.cdef(code) for code in generate(): ffi.cdef(code) + for hdl in ('ctx', 'team'): + ffi.cdef(f""" + bool eq_{hdl}(shmem_{hdl}_t, shmem_{hdl}_t); + uintptr_t {hdl}2id(shmem_{hdl}_t); + shmem_{hdl}_t id2{hdl}(uintptr_t); + """) ffi.cdef(""" int shmem_alltoallsmem_x( shmem_team_t team, diff --git a/src/libshmem.c b/src/libshmem.c index 2dbacea..1dc1d7c 100644 --- a/src/libshmem.c +++ b/src/libshmem.c @@ -50,6 +50,7 @@ /* --- */ +#include "libshmem/hdltypes.h" #include "libshmem/fallback.h" #include "libshmem/initfini.h" #include "libshmem/memalloc.h" diff --git a/src/libshmem/hdltypes.h b/src/libshmem/hdltypes.h new file mode 100644 index 0000000..7a73d58 --- /dev/null +++ b/src/libshmem/hdltypes.h @@ -0,0 +1,7 @@ +#define eq_ctx(a, b) ((a) == (b)) +#define ctx2id(c) ((uintptr_t)(c)) +#define id2ctx(i) ((shmem_ctx_t)(i)) + +#define eq_team(a, b) ((a) == (b)) +#define team2id(t) ((uintptr_t)(t)) +#define id2team(i) ((shmem_team_t)i) diff --git a/src/shmem4py/shmem.py b/src/shmem4py/shmem.py index 3d373b7..67087c9 100644 --- a/src/shmem4py/shmem.py +++ b/src/shmem4py/shmem.py @@ -281,15 +281,15 @@ def __new__( def __eq__(self, other: Any) -> bool: if not isinstance(other, Ctx): return NotImplemented - return self.ob_ctx == other.ob_ctx + return lib.eq_ctx(self.ob_ctx, other.ob_ctx) def __ne__(self, other: Any) -> bool: if not isinstance(other, Ctx): return NotImplemented - return self.ob_ctx != other.ob_ctx + return not lib.eq_ctx(self.ob_ctx, other.ob_ctx) def __bool__(self) -> bool: - return self.ob_ctx != lib.SHMEM_CTX_INVALID + return not lib.eq_ctx(self.ob_ctx, lib.SHMEM_CTX_INVALID) def __enter__(self) -> Ctx: return self @@ -331,9 +331,9 @@ def destroy(self) -> None: return ctx = self.ob_ctx self.ob_ctx = lib.SHMEM_CTX_INVALID - if ctx == lib.SHMEM_CTX_DEFAULT: + if lib.eq_ctx(ctx, lib.SHMEM_CTX_DEFAULT): return - if ctx == lib.SHMEM_CTX_INVALID: + if lib.eq_ctx(ctx, lib.SHMEM_CTX_INVALID): return lib.shmem_ctx_destroy(ctx) @@ -400,15 +400,15 @@ def __new__( def __eq__(self, other: Any) -> bool: if not isinstance(other, Team): return NotImplemented - return self.ob_team == other.ob_team + return lib.eq_team(self.ob_team, other.ob_team) def __ne__(self, other: Any) -> bool: if not isinstance(other, Team): return NotImplemented - return self.ob_team != other.ob_team + return not lib.eq_team(self.ob_team, other.ob_team) def __bool__(self) -> bool: - return self.ob_team != lib.SHMEM_TEAM_INVALID + return not lib.eq_team(self.ob_team, lib.SHMEM_TEAM_INVALID) def __enter__(self) -> Team: return self @@ -426,11 +426,11 @@ def destroy(self) -> None: return team = self.ob_team self.ob_team = lib.SHMEM_TEAM_INVALID - if team == lib.SHMEM_TEAM_WORLD: + if lib.eq_team(team, lib.SHMEM_TEAM_WORLD): return - if team == lib.SHMEM_TEAM_SHARED: + if lib.eq_team(team, lib.SHMEM_TEAM_SHARED): return - if team == lib.SHMEM_TEAM_INVALID: + if lib.eq_team(team, lib.SHMEM_TEAM_INVALID): return lib.shmem_team_destroy(team)