Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion deps/build_local.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ if !isfile(joinpath(conda_dir, "condarc-julia.yml"))
touch(joinpath(conda_dir, "conda-meta", "history"))
end
Conda.add_channel("https://software.repos.intel.com/python/conda/", conda_dir)
Conda.add(["dpcpp_linux-64=2025.2.0", "mkl-devel-dpcpp=2025.2.0"], conda_dir)
Conda.add(["dpcpp_linux-64=2026.0.0", "mkl-devel-dpcpp=2026.0.0"], conda_dir)

Conda.list(conda_dir)

Expand Down
4 changes: 2 additions & 2 deletions deps/src/onemkl.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,8 @@ typedef struct omatconvert_descr *omatconvert_descr_t;
struct omatadd_descr;
typedef struct omatadd_descr *omatadd_descr_t;

const int64_t ONEMKL_VERSION_MAJOR = 2025;
const int64_t ONEMKL_VERSION_MINOR = 2;
const int64_t ONEMKL_VERSION_MAJOR = 2026;
const int64_t ONEMKL_VERSION_MINOR = 0;
const int64_t ONEMKL_VERSION_PATCH = 0;
void onemkl_version(int64_t *major, int64_t *minor, int64_t *patch);

Expand Down
33 changes: 28 additions & 5 deletions deps/src/onemkl_dft.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,10 @@ static inline config_param to_param(onemklDftConfigParam p) {
case ONEMKL_DFT_PARAM_NUMBER_OF_TRANSFORMS: return config_param::NUMBER_OF_TRANSFORMS;
case ONEMKL_DFT_PARAM_COMPLEX_STORAGE: return config_param::COMPLEX_STORAGE;
case ONEMKL_DFT_PARAM_PLACEMENT: return config_param::PLACEMENT;
case ONEMKL_DFT_PARAM_INPUT_STRIDES: return config_param::INPUT_STRIDES;
case ONEMKL_DFT_PARAM_OUTPUT_STRIDES: return config_param::OUTPUT_STRIDES;
// oneMKL >= 2026.0 dropped the deprecated INPUT_STRIDES/OUTPUT_STRIDES;
// map the legacy parameters onto their FWD_STRIDES/BWD_STRIDES successors.
case ONEMKL_DFT_PARAM_INPUT_STRIDES: return config_param::FWD_STRIDES;
case ONEMKL_DFT_PARAM_OUTPUT_STRIDES: return config_param::BWD_STRIDES;
case ONEMKL_DFT_PARAM_FWD_DISTANCE: return config_param::FWD_DISTANCE;
case ONEMKL_DFT_PARAM_BWD_DISTANCE: return config_param::BWD_DISTANCE;
case ONEMKL_DFT_PARAM_WORKSPACE: return config_param::WORKSPACE;
Expand Down Expand Up @@ -210,7 +212,28 @@ int onemklDftGetValueInt64(onemklDftDescriptor_t desc, onemklDftConfigParam para

int onemklDftGetValueDouble(onemklDftDescriptor_t desc, onemklDftConfigParam param, double *value) {
if (!desc || !value) return -2; if (!desc->ptr) return -3;
try { ONEMKL_DFT_DISPATCH_CFG(desc->ptr, d->get_value(to_param(param), value)); return 0; } catch (...) { return -1; }
try {
config_param p = to_param(param);
// oneMKL >= 2026.0 requires a pointer to the descriptor's real scalar type
// (float for single precision, double for double); use a matching temporary
// and widen to double for the C interface.
if (desc->prec == precision::SINGLE) {
float tmp;
if (desc->dom == domain::REAL)
static_cast< descriptor<precision::SINGLE, domain::REAL>* >(desc->ptr)->get_value(p, &tmp);
else
static_cast< descriptor<precision::SINGLE, domain::COMPLEX>* >(desc->ptr)->get_value(p, &tmp);
*value = static_cast<double>(tmp);
} else {
double tmp;
if (desc->dom == domain::REAL)
static_cast< descriptor<precision::DOUBLE, domain::REAL>* >(desc->ptr)->get_value(p, &tmp);
else
static_cast< descriptor<precision::DOUBLE, domain::COMPLEX>* >(desc->ptr)->get_value(p, &tmp);
*value = tmp;
}
return 0;
} catch (...) { return -1; }
}

int onemklDftGetValueInt64Array(onemklDftDescriptor_t desc, onemklDftConfigParam param, int64_t *values, int64_t *n) {
Expand Down Expand Up @@ -443,8 +466,8 @@ int onemklDftQueryParamIndices(int64_t *out, int64_t n) {
config_param::NUMBER_OF_TRANSFORMS,
config_param::COMPLEX_STORAGE,
config_param::PLACEMENT,
config_param::INPUT_STRIDES,
config_param::OUTPUT_STRIDES,
config_param::FWD_STRIDES, // was INPUT_STRIDES (removed in oneMKL 2026.0)
config_param::BWD_STRIDES, // was OUTPUT_STRIDES (removed in oneMKL 2026.0)
config_param::FWD_DISTANCE,
config_param::BWD_DISTANCE,
config_param::WORKSPACE,
Expand Down
33 changes: 23 additions & 10 deletions lib/mkl/fft.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,14 @@ function _create_descriptor(sz::NTuple{N,Int}, T::Type, complex::Bool) where {N}
prec = T<:Float64 || T<:ComplexF64 ? ONEMKL_DFT_PRECISION_DOUBLE : ONEMKL_DFT_PRECISION_SINGLE
dom = complex ? ONEMKL_DFT_DOMAIN_COMPLEX : ONEMKL_DFT_DOMAIN_REAL
desc_ref = Ref{onemklDftDescriptor_t}()
# Create descriptor for the full array dimensions
# Create descriptor for the full array dimensions. `lengths` must stay rooted
# across the ccall: oneMKL copies the dimensions out of `pointer(lengths)`, and
# without GC.@preserve the array can be collected first, leaving the descriptor
# with garbage dimensions (commit then fails with FFT_INVALID_DESCRIPTOR).
lengths = collect(Int64, sz)
st = length(lengths) == 1 ? onemklDftCreate1D(desc_ref, prec, dom, lengths[1]) : onemklDftCreateND(desc_ref, prec, dom, length(lengths), pointer(lengths))
st = GC.@preserve lengths (length(lengths) == 1 ?
onemklDftCreate1D(desc_ref, prec, dom, lengths[1]) :
onemklDftCreateND(desc_ref, prec, dom, length(lengths), pointer(lengths)))
st == 0 || error("onemkl DFT create failed (status $st)")
desc = desc_ref[]
# Do not program descriptor scaling; we'll perform inverse normalization manually.
Expand Down Expand Up @@ -125,8 +130,10 @@ function plan_fft(X::oneAPI.oneArray{T,N}, region) where {T<:Union{ComplexF32,Co
strides[i+1] = prod
prod *= size(X,i)
end
onemklDftSetValueInt64Array(desc, ONEMKL_DFT_PARAM_FWD_STRIDES, pointer(strides), length(strides))
onemklDftSetValueInt64Array(desc, ONEMKL_DFT_PARAM_BWD_STRIDES, pointer(strides), length(strides))
GC.@preserve strides begin
onemklDftSetValueInt64Array(desc, ONEMKL_DFT_PARAM_FWD_STRIDES, pointer(strides), length(strides))
onemklDftSetValueInt64Array(desc, ONEMKL_DFT_PARAM_BWD_STRIDES, pointer(strides), length(strides))
end
end
stc = onemklDftCommit(desc, q); stc == 0 || error("commit failed ($stc)")
return cMKLFFTPlan{T,MKLFFT_FORWARD,false,N,R,Nothing}(desc,q,size(X),size(X),false,reg,nothing,nothing)
Expand All @@ -144,8 +151,10 @@ function plan_bfft(X::oneAPI.oneArray{T,N}, region) where {T<:Union{ComplexF32,C
@inbounds for i in 1:N
strides[i+1]=prod; prod*=size(X,i)
end
onemklDftSetValueInt64Array(desc, ONEMKL_DFT_PARAM_FWD_STRIDES, pointer(strides), length(strides))
onemklDftSetValueInt64Array(desc, ONEMKL_DFT_PARAM_BWD_STRIDES, pointer(strides), length(strides))
GC.@preserve strides begin
onemklDftSetValueInt64Array(desc, ONEMKL_DFT_PARAM_FWD_STRIDES, pointer(strides), length(strides))
onemklDftSetValueInt64Array(desc, ONEMKL_DFT_PARAM_BWD_STRIDES, pointer(strides), length(strides))
end
end
stc = onemklDftCommit(desc, q); stc == 0 || error("commit failed ($stc)")
return cMKLFFTPlan{T,MKLFFT_INVERSE,false,N,R,Nothing}(desc,q,size(X),size(X),false,reg,nothing,nothing)
Expand All @@ -165,8 +174,10 @@ function plan_fft!(X::oneAPI.oneArray{T,N}, region) where {T<:Union{ComplexF32,C
@inbounds for i in 1:N
strides[i+1]=prod; prod*=size(X,i)
end
onemklDftSetValueInt64Array(desc, ONEMKL_DFT_PARAM_FWD_STRIDES, pointer(strides), length(strides))
onemklDftSetValueInt64Array(desc, ONEMKL_DFT_PARAM_BWD_STRIDES, pointer(strides), length(strides))
GC.@preserve strides begin
onemklDftSetValueInt64Array(desc, ONEMKL_DFT_PARAM_FWD_STRIDES, pointer(strides), length(strides))
onemklDftSetValueInt64Array(desc, ONEMKL_DFT_PARAM_BWD_STRIDES, pointer(strides), length(strides))
end
end
stc = onemklDftCommit(desc, q); stc == 0 || error("commit failed ($stc)")
cMKLFFTPlan{T,MKLFFT_FORWARD,true,N,R,Nothing}(desc,q,size(X),size(X),false,reg,nothing,nothing)
Expand All @@ -184,8 +195,10 @@ function plan_bfft!(X::oneAPI.oneArray{T,N}, region) where {T<:Union{ComplexF32,
@inbounds for i in 1:N
strides[i+1]=prod; prod*=size(X,i)
end
onemklDftSetValueInt64Array(desc, ONEMKL_DFT_PARAM_FWD_STRIDES, pointer(strides), length(strides))
onemklDftSetValueInt64Array(desc, ONEMKL_DFT_PARAM_BWD_STRIDES, pointer(strides), length(strides))
GC.@preserve strides begin
onemklDftSetValueInt64Array(desc, ONEMKL_DFT_PARAM_FWD_STRIDES, pointer(strides), length(strides))
onemklDftSetValueInt64Array(desc, ONEMKL_DFT_PARAM_BWD_STRIDES, pointer(strides), length(strides))
end
end
stc = onemklDftCommit(desc, q); stc == 0 || error("commit failed ($stc)")
cMKLFFTPlan{T,MKLFFT_INVERSE,true,N,R,Nothing}(desc,q,size(X),size(X),false,reg,nothing,nothing)
Expand Down
2 changes: 1 addition & 1 deletion test/onemkl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ k = 13

@testset "Version" begin
version_onemkl = oneMKL.version()
@test version_onemkl ≥ v"2025.2.0"
@test version_onemkl ≥ v"2026.0.0"
end

############################################################################################
Expand Down
Loading