diff --git a/ffcx/codegeneration/utils.py b/ffcx/codegeneration/utils.py index ec7d72358..509c2cd5b 100644 --- a/ffcx/codegeneration/utils.py +++ b/ffcx/codegeneration/utils.py @@ -159,3 +159,56 @@ def codegen(context, builder, signature, args): sig = numba.types.voidptr(arr) return sig, codegen + + def _create_voidptr_to_dtype_ptr_caster( + target_dtype: numba.types.Type, + ) -> numba.extending.intrinsic: + """Factory that creates a Numba intrinsic casting void* to CPointer(target_dtype). + + The produced intrinsic accepts either `CPointer(void)` or `voidptr` as input + (matching how UFCx kernels receive `custom_data`) and returns a + `CPointer(target_dtype)` for convenient indexed access in Numba cfuncs. + + The voidptr cast is needed when: + - UFCx kernels pass custom_data as void* (the last parameter in tabulate_tensor) + - Users want to access structured runtime data (e.g., custom_data with element + tables) inside Numba-compiled kernels + - Type-safe indexed access is required (e.g., custom_data[0], custom_data[1]) + + Args: + target_dtype: A Numba scalar type (e.g. `numba.types.float64`). + + Returns: + A Numba intrinsic function that performs the cast. + """ + + @numba.extending.intrinsic + def voidptr_to_dtype_ptr(typingctx, src): + # Accept void pointers in various Numba representations: + # - CPointer(void): from UFCx cfunc signatures (shows as 'none*') + # - voidptr: from numba.cfunc("...(voidptr)") signatures + is_cpointer_void = ( + isinstance(src, numba.types.CPointer) and src.dtype == numba.types.void + ) + is_voidptr = src == numba.types.voidptr + + # Raise a clear error if the source type is not a void pointer + if not is_cpointer_void and not is_voidptr: + msg = ( + "voidptr_to_dtype_ptr expects a void pointer (CPointer(void) or voidptr), " + f"got {src}. Ensure you are passing a void* " + "(e.g., custom_data from UFCx kernel signature)." + ) + raise numba.core.errors.TypingError(msg) + + result_type = numba.types.CPointer(target_dtype) + sig = result_type(src) + + def codegen(context, builder, signature, args): + [src_val] = args + dst_type = context.get_value_type(result_type) + return builder.bitcast(src_val, dst_type) + + return sig, codegen + + return voidptr_to_dtype_ptr diff --git a/test/test_numba_custom_data.py b/test/test_numba_custom_data.py new file mode 100644 index 000000000..3b99df8c3 --- /dev/null +++ b/test/test_numba_custom_data.py @@ -0,0 +1,60 @@ +# Test that the Numba voidptr -> typed pointer caster factory works in ffcx utils +import ctypes + +import numpy as np +import pytest + +import ffcx.codegeneration.utils as codegen_utils + +# Skip the tests if Numba is not available in the environment. +numba = pytest.importorskip("numba") + +float64_ptr_caster = codegen_utils._create_voidptr_to_dtype_ptr_caster(numba.types.float64) +int32_ptr_caster = codegen_utils._create_voidptr_to_dtype_ptr_caster(numba.types.int32) + + +def test_numba_voidptr_struct_like_mixed_types(): + """Test reading a struct-like mixed-type buffer: float64 + int32. + + We create a NumPy structured array with fields ('scale', float64) and + ('id', int32) with padding to align to 16 bytes. The kernel casts the + void* to float64* and int32* and reads the corresponding offsets. + """ + sig = codegen_utils.numba_ufcx_kernel_signature(np.float64, np.float64) + + @numba.cfunc(sig, nopython=True) + def tabulate(b_, w_, c_, coords_, local_index, orientation, custom_data): + b = numba.carray(b_, (1,), dtype=np.float64) + fptr = float64_ptr_caster(custom_data) + iptr = int32_ptr_caster(custom_data) + scale = fptr[0] + # int32 index for offset 8 bytes == 8/4 == 2 + id_val = iptr[2] + b[0] = scale + id_val + + b = np.zeros(1, dtype=np.float64) + w = np.zeros(1, dtype=np.float64) + c = np.zeros(1, dtype=np.float64) + coords = np.zeros(9, dtype=np.float64) + local_index = np.array([0], dtype=np.int32) + orientation = np.array([0], dtype=np.uint8) + + # structured dtype with C-compatible alignment + dtype = np.dtype([("scale", np.float64), ("id", np.int32)], align=True) + arr = np.zeros(1, dtype=dtype) + arr["scale"][0] = 1.25 + arr["id"][0] = 5 + + ptr = arr.ctypes.data + + tabulate.ctypes( + b.ctypes.data_as(ctypes.POINTER(ctypes.c_double)), + w.ctypes.data_as(ctypes.POINTER(ctypes.c_double)), + c.ctypes.data_as(ctypes.POINTER(ctypes.c_double)), + coords.ctypes.data_as(ctypes.POINTER(ctypes.c_double)), + local_index.ctypes.data_as(ctypes.POINTER(ctypes.c_int)), + orientation.ctypes.data_as(ctypes.POINTER(ctypes.c_uint8)), + ctypes.c_void_p(ptr), + ) + + assert b[0] == pytest.approx(6.25)