diff --git a/devito/arch/archinfo.py b/devito/arch/archinfo.py index 1a17ba6da1..3166bf255a 100644 --- a/devito/arch/archinfo.py +++ b/devito/arch/archinfo.py @@ -1008,6 +1008,12 @@ class Device(Platform): "warp" in NVidia GPUs and a "wavefront" in AMD GPUs. """ + thread_group_slots = None + """ + Number of thread groups issue/execution slots per compute engine (e.g. + SM for Nvidia GPUs, CU for AMD GPUs). + """ + def __init__(self, name, cores_logical=None, cores_physical=None, isa='cpp', max_threads_per_block=1024, max_threads_dimx=1024, max_threads_dimy=1024, max_threads_dimz=64, @@ -1091,6 +1097,7 @@ def march(self): class NvidiaDevice(Device): thread_group_size = 32 + thread_group_slots = 4 max_mem_trans_nbytes = 128 @@ -1182,6 +1189,7 @@ class Blackwell(Hopper): class AmdDevice(Device): thread_group_size = 64 + thread_group_slots = 4 max_mem_trans_nbytes = 256 diff --git a/devito/types/array.py b/devito/types/array.py index bd7747d40d..13d3b105a2 100644 --- a/devito/types/array.py +++ b/devito/types/array.py @@ -653,7 +653,13 @@ def indices(self): @property def dtype(self): - return self.function.dtype + try: + return self.function.c0.dtype + except AttributeError: + # Vector-component access over a scalar symbol, e.g. a float4 register. + if self.function.is_Symbol: + return dtypes_vector_mapper.get_base_dtype(self.function.dtype) + raise @cacheit def sort_key(self, order=None): diff --git a/tests/test_symbolics.py b/tests/test_symbolics.py index 49f134148f..48709d7cd1 100644 --- a/tests/test_symbolics.py +++ b/tests/test_symbolics.py @@ -20,7 +20,7 @@ SizeOf, VectorAccess, evalrel, pow_to_mul, retrieve_derivatives, retrieve_functions, retrieve_indexed, uxreplace ) -from devito.tools import CustomDtype, as_tuple +from devito.tools import CustomDtype, as_tuple, dtypes_vector_mapper from devito.types import ( Array, Bundle, ComponentAccess, FIndexed, LocalObject, Object, StencilDimension ) @@ -575,6 +575,13 @@ def test_component_access(): assert cf2 == cf1 +def test_component_access_symbol_printing(): + acc = dSymbol(name='acc', dtype=dtypes_vector_mapper[(np.float32, 4)]) + expr = ComponentAccess(acc, 0) + + assert ccode(sympy.Float('1.25')*expr, dtype=expr.dtype) == '1.250F*acc.x' + + def test_vector_access(): grid = Grid(shape=(3, 3, 3))