Conversation
|
Your PR requires formatting changes to meet the project's style guidelines. Click here to view the suggested changes.diff --git a/lib/mkl/fft.jl b/lib/mkl/fft.jl
index 5f5614b..9d38e11 100644
--- a/lib/mkl/fft.jl
+++ b/lib/mkl/fft.jl
@@ -22,34 +22,34 @@ using ..Support
# Allow implicit conversion of SYCL queue object to raw handle when storing/passing
Base.convert(::Type{syclQueue_t}, q::SYCL.syclQueue) = Base.unsafe_convert(syclQueue_t, q)
-abstract type MKLFFTPlan{T,K,inplace} <: AbstractFFTs.Plan{T} end
+abstract type MKLFFTPlan{T, K, inplace} <: AbstractFFTs.Plan{T} end
-Base.eltype(::MKLFFTPlan{T}) where T = T
-is_inplace(::MKLFFTPlan{<:Any,<:Any,inplace}) where inplace = inplace
+Base.eltype(::MKLFFTPlan{T}) where {T} = T
+is_inplace(::MKLFFTPlan{<:Any, <:Any, inplace}) where {inplace} = inplace
# Forward / inverse flags
const MKLFFT_FORWARD = true
const MKLFFT_INVERSE = false
-mutable struct cMKLFFTPlan{T,K,inplace,N,R,B} <: MKLFFTPlan{T,K,inplace}
+mutable struct cMKLFFTPlan{T, K, inplace, N, R, B} <: MKLFFTPlan{T, K, inplace}
handle::onemklDftDescriptor_t
queue::syclQueue_t
- sz::NTuple{N,Int}
- osz::NTuple{N,Int}
+ sz::NTuple{N, Int}
+ osz::NTuple{N, Int}
realdomain::Bool
- region::NTuple{R,Int}
+ region::NTuple{R, Int}
buffer::B
pinv::Any
end
# Real transforms use separate struct (mirroring AMDGPU style) for buffer staging
-mutable struct rMKLFFTPlan{T,K,inplace,N,R,B} <: MKLFFTPlan{T,K,inplace}
+mutable struct rMKLFFTPlan{T, K, inplace, N, R, B} <: MKLFFTPlan{T, K, inplace}
handle::onemklDftDescriptor_t
queue::syclQueue_t
- sz::NTuple{N,Int}
- osz::NTuple{N,Int}
+ sz::NTuple{N, Int}
+ osz::NTuple{N, Int}
xtype::Symbol
- region::NTuple{R,Int}
+ region::NTuple{R, Int}
buffer::B
pinv::Any
end
@@ -57,40 +57,44 @@ end
# Inverse plan constructors (derive from existing plan)
function normalization_factor(sz, region)
# AbstractFFTs expects inverse to scale by 1/prod(lengths along region)
- prod(ntuple(i-> sz[region[i]], length(region)))
+ return prod(ntuple(i -> sz[region[i]], length(region)))
end
-function plan_inv(p::cMKLFFTPlan{T,MKLFFT_FORWARD,inplace,N,R,B}) where {T,inplace,N,R,B}
- q = cMKLFFTPlan{T,MKLFFT_INVERSE,inplace,N,R,B}(p.handle,p.queue,p.sz,p.osz,p.realdomain,p.region,p.buffer,p)
+function plan_inv(p::cMKLFFTPlan{T, MKLFFT_FORWARD, inplace, N, R, B}) where {T, inplace, N, R, B}
+ q = cMKLFFTPlan{T, MKLFFT_INVERSE, inplace, N, R, B}(p.handle, p.queue, p.sz, p.osz, p.realdomain, p.region, p.buffer, p)
p.pinv = q
- ScaledPlan(q, 1/normalization_factor(p.sz, p.region))
+ return ScaledPlan(q, 1 / normalization_factor(p.sz, p.region))
end
-function plan_inv(p::cMKLFFTPlan{T,MKLFFT_INVERSE,inplace,N,R,B}) where {T,inplace,N,R,B}
- q = cMKLFFTPlan{T,MKLFFT_FORWARD,inplace,N,R,B}(p.handle,p.queue,p.sz,p.osz,p.realdomain,p.region,p.buffer,p)
+function plan_inv(p::cMKLFFTPlan{T, MKLFFT_INVERSE, inplace, N, R, B}) where {T, inplace, N, R, B}
+ q = cMKLFFTPlan{T, MKLFFT_FORWARD, inplace, N, R, B}(p.handle, p.queue, p.sz, p.osz, p.realdomain, p.region, p.buffer, p)
p.pinv = q
- ScaledPlan(q, 1/normalization_factor(p.sz, p.region))
+ return ScaledPlan(q, 1 / normalization_factor(p.sz, p.region))
end
-function plan_inv(p::rMKLFFTPlan{T,MKLFFT_FORWARD,inplace,N,R,B}) where {T,inplace,N,R,B}
- q = rMKLFFTPlan{T,MKLFFT_INVERSE,inplace,N,R,B}(p.handle,p.queue,p.sz,p.osz,:brfft,p.region,p.buffer,p)
+function plan_inv(p::rMKLFFTPlan{T, MKLFFT_FORWARD, inplace, N, R, B}) where {T, inplace, N, R, B}
+ q = rMKLFFTPlan{T, MKLFFT_INVERSE, inplace, N, R, B}(p.handle, p.queue, p.sz, p.osz, :brfft, p.region, p.buffer, p)
p.pinv = q
- ScaledPlan(q, 1/normalization_factor(p.sz, p.region))
+ return ScaledPlan(q, 1 / normalization_factor(p.sz, p.region))
end
-function plan_inv(p::rMKLFFTPlan{T,MKLFFT_INVERSE,inplace,N,R,B}) where {T,inplace,N,R,B}
- q = rMKLFFTPlan{T,MKLFFT_FORWARD,inplace,N,R,B}(p.handle,p.queue,p.sz,p.osz,:rfft,p.region,p.buffer,p)
+function plan_inv(p::rMKLFFTPlan{T, MKLFFT_INVERSE, inplace, N, R, B}) where {T, inplace, N, R, B}
+ q = rMKLFFTPlan{T, MKLFFT_FORWARD, inplace, N, R, B}(p.handle, p.queue, p.sz, p.osz, :rfft, p.region, p.buffer, p)
p.pinv = q
- ScaledPlan(q, 1/normalization_factor(p.sz, p.region))
+ return ScaledPlan(q, 1 / normalization_factor(p.sz, p.region))
end
-function Base.show(io::IO, p::MKLFFTPlan{T,K,inplace}) where {T,K,inplace}
+function Base.show(io::IO, p::MKLFFTPlan{T, K, inplace}) where {T, K, inplace}
print(io, inplace ? "oneMKL FFT in-place " : "oneMKL FFT ", K ? "forward" : "inverse", " plan for ")
- if isempty(p.sz); print(io, "0-dimensional") else print(io, join(p.sz, "×")) end
- print(io, " oneArray of ", T)
+ if isempty(p.sz)
+ print(io, "0-dimensional")
+ else
+ print(io, join(p.sz, "×"))
+ end
+ return print(io, " oneArray of ", T)
end
# Plan constructors
-function _create_descriptor(sz::NTuple{N,Int}, T::Type, complex::Bool) where {N}
- prec = T<:Float64 || T<:ComplexF64 ? ONEMKL_DFT_PRECISION_DOUBLE : ONEMKL_DFT_PRECISION_SINGLE
+function _create_descriptor(sz::NTuple{N, Int}, T::Type, complex::Bool) where {N}
+ prec = T <: Float64 || T <: ComplexF64 ? ONEMKL_DFT_PRECISION_DOUBLE : ONEMKL_DFT_PRECISION_SINGLE
dom = complex ? ONEMKL_DFT_DOMAIN_COMPLEX : ONEMKL_DFT_DOMAIN_REAL
desc_ref = Ref{onemklDftDescriptor_t}()
# Create descriptor for the full array dimensions
@@ -109,8 +113,8 @@ function _create_descriptor(sz::NTuple{N,Int}, T::Type, complex::Bool) where {N}
end
# Complex plans
-function plan_fft(X::oneAPI.oneArray{T,N}, region) where {T<:Union{ComplexF32,ComplexF64},N}
- R = length(region); reg = NTuple{R,Int}(region)
+function plan_fft(X::oneAPI.oneArray{T, N}, region) where {T <: Union{ComplexF32, ComplexF64}, N}
+ R = length(region); reg = NTuple{R, Int}(region)
# For now, only support full transforms (all dimensions)
if reg != ntuple(identity, N)
error("Partial dimension FFT not yet supported. Region $reg must be $(ntuple(identity, N))")
@@ -119,20 +123,20 @@ function plan_fft(X::oneAPI.oneArray{T,N}, region) where {T<:Union{ComplexF32,Co
onemklDftSetValueConfigValue(desc, ONEMKL_DFT_PARAM_PLACEMENT, ONEMKL_DFT_VALUE_NOT_INPLACE)
if N > 1
# Column-major strides: stride along dimension i is product of sizes of previous dims
- strides = Vector{Int64}(undef, N+1); strides[1]=0
+ strides = Vector{Int64}(undef, N + 1); strides[1] = 0
prod = 1
@inbounds for i in 1:N
- strides[i+1] = prod
- prod *= size(X,i)
+ strides[i + 1] = prod
+ prod *= size(X, i)
end
onemklDftSetValueInt64Array(desc, ONEMKL_DFT_PARAM_FWD_STRIDES, pointer(strides), length(strides))
onemklDftSetValueInt64Array(desc, ONEMKL_DFT_PARAM_BWD_STRIDES, pointer(strides), length(strides))
end
stc = onemklDftCommit(desc, q); stc == 0 || error("commit failed ($stc)")
- return cMKLFFTPlan{T,MKLFFT_FORWARD,false,N,R,Nothing}(desc,q,size(X),size(X),false,reg,nothing,nothing)
+ return cMKLFFTPlan{T, MKLFFT_FORWARD, false, N, R, Nothing}(desc, q, size(X), size(X), false, reg, nothing, nothing)
end
-function plan_bfft(X::oneAPI.oneArray{T,N}, region) where {T<:Union{ComplexF32,ComplexF64},N}
- R = length(region); reg = NTuple{R,Int}(region)
+function plan_bfft(X::oneAPI.oneArray{T, N}, region) where {T <: Union{ComplexF32, ComplexF64}, N}
+ R = length(region); reg = NTuple{R, Int}(region)
# For now, only support full transforms (all dimensions)
if reg != ntuple(identity, N)
error("Partial dimension FFT not yet supported. Region $reg must be $(ntuple(identity, N))")
@@ -140,87 +144,87 @@ function plan_bfft(X::oneAPI.oneArray{T,N}, region) where {T<:Union{ComplexF32,C
desc, q = _create_descriptor(size(X), T, true)
onemklDftSetValueConfigValue(desc, ONEMKL_DFT_PARAM_PLACEMENT, ONEMKL_DFT_VALUE_NOT_INPLACE)
if N > 1
- strides = Vector{Int64}(undef, N+1); strides[1]=0; prod=1
+ strides = Vector{Int64}(undef, N + 1); strides[1] = 0; prod = 1
@inbounds for i in 1:N
- strides[i+1]=prod; prod*=size(X,i)
+ strides[i + 1] = prod; prod *= size(X, i)
end
onemklDftSetValueInt64Array(desc, ONEMKL_DFT_PARAM_FWD_STRIDES, pointer(strides), length(strides))
onemklDftSetValueInt64Array(desc, ONEMKL_DFT_PARAM_BWD_STRIDES, pointer(strides), length(strides))
end
stc = onemklDftCommit(desc, q); stc == 0 || error("commit failed ($stc)")
- return cMKLFFTPlan{T,MKLFFT_INVERSE,false,N,R,Nothing}(desc,q,size(X),size(X),false,reg,nothing,nothing)
+ return cMKLFFTPlan{T, MKLFFT_INVERSE, false, N, R, Nothing}(desc, q, size(X), size(X), false, reg, nothing, nothing)
end
# In-place (provide separate methods)
-function plan_fft!(X::oneAPI.oneArray{T,N}, region) where {T<:Union{ComplexF32,ComplexF64},N}
- R = length(region); reg = NTuple{R,Int}(region)
+function plan_fft!(X::oneAPI.oneArray{T, N}, region) where {T <: Union{ComplexF32, ComplexF64}, N}
+ R = length(region); reg = NTuple{R, Int}(region)
# For now, only support full transforms (all dimensions)
if reg != ntuple(identity, N)
error("Partial dimension FFT not yet supported. Region $reg must be $(ntuple(identity, N))")
end
- desc,q = _create_descriptor(size(X),T,true)
+ desc, q = _create_descriptor(size(X), T, true)
onemklDftSetValueConfigValue(desc, ONEMKL_DFT_PARAM_PLACEMENT, ONEMKL_DFT_VALUE_INPLACE)
if N > 1
- strides = Vector{Int64}(undef, N+1); strides[1]=0; prod=1
+ strides = Vector{Int64}(undef, N + 1); strides[1] = 0; prod = 1
@inbounds for i in 1:N
- strides[i+1]=prod; prod*=size(X,i)
+ strides[i + 1] = prod; prod *= size(X, i)
end
onemklDftSetValueInt64Array(desc, ONEMKL_DFT_PARAM_FWD_STRIDES, pointer(strides), length(strides))
onemklDftSetValueInt64Array(desc, ONEMKL_DFT_PARAM_BWD_STRIDES, pointer(strides), length(strides))
end
stc = onemklDftCommit(desc, q); stc == 0 || error("commit failed ($stc)")
- cMKLFFTPlan{T,MKLFFT_FORWARD,true,N,R,Nothing}(desc,q,size(X),size(X),false,reg,nothing,nothing)
+ return cMKLFFTPlan{T, MKLFFT_FORWARD, true, N, R, Nothing}(desc, q, size(X), size(X), false, reg, nothing, nothing)
end
-function plan_bfft!(X::oneAPI.oneArray{T,N}, region) where {T<:Union{ComplexF32,ComplexF64},N}
- R = length(region); reg = NTuple{R,Int}(region)
+function plan_bfft!(X::oneAPI.oneArray{T, N}, region) where {T <: Union{ComplexF32, ComplexF64}, N}
+ R = length(region); reg = NTuple{R, Int}(region)
# For now, only support full transforms (all dimensions)
if reg != ntuple(identity, N)
error("Partial dimension FFT not yet supported. Region $reg must be $(ntuple(identity, N))")
end
- desc,q = _create_descriptor(size(X),T,true)
+ desc, q = _create_descriptor(size(X), T, true)
onemklDftSetValueConfigValue(desc, ONEMKL_DFT_PARAM_PLACEMENT, ONEMKL_DFT_VALUE_INPLACE)
if N > 1
- strides = Vector{Int64}(undef, N+1); strides[1]=0; prod=1
+ strides = Vector{Int64}(undef, N + 1); strides[1] = 0; prod = 1
@inbounds for i in 1:N
- strides[i+1]=prod; prod*=size(X,i)
+ strides[i + 1] = prod; prod *= size(X, i)
end
onemklDftSetValueInt64Array(desc, ONEMKL_DFT_PARAM_FWD_STRIDES, pointer(strides), length(strides))
onemklDftSetValueInt64Array(desc, ONEMKL_DFT_PARAM_BWD_STRIDES, pointer(strides), length(strides))
end
stc = onemklDftCommit(desc, q); stc == 0 || error("commit failed ($stc)")
- cMKLFFTPlan{T,MKLFFT_INVERSE,true,N,R,Nothing}(desc,q,size(X),size(X),false,reg,nothing,nothing)
+ return cMKLFFTPlan{T, MKLFFT_INVERSE, true, N, R, Nothing}(desc, q, size(X), size(X), false, reg, nothing, nothing)
end
# Real input methods - convert to complex like FFTW does
-function plan_fft(X::oneAPI.oneArray{T,N}, region) where {T<:Union{Float32,Float64},N}
+function plan_fft(X::oneAPI.oneArray{T, N}, region) where {T <: Union{Float32, Float64}, N}
CT = Complex{T}
# Create a complex plan by converting the real array to complex
X_complex = oneAPI.oneArray{CT}(undef, size(X))
- plan_fft(X_complex, region)
+ return plan_fft(X_complex, region)
end
-function plan_bfft(X::oneAPI.oneArray{T,N}, region) where {T<:Union{Float32,Float64},N}
+function plan_bfft(X::oneAPI.oneArray{T, N}, region) where {T <: Union{Float32, Float64}, N}
CT = Complex{T}
# Create a complex plan by converting the real array to complex
X_complex = oneAPI.oneArray{CT}(undef, size(X))
- plan_bfft(X_complex, region)
+ return plan_bfft(X_complex, region)
end
-function plan_fft!(X::oneAPI.oneArray{T,N}, region) where {T<:Union{Float32,Float64},N}
+function plan_fft!(X::oneAPI.oneArray{T, N}, region) where {T <: Union{Float32, Float64}, N}
error("In-place FFT not supported for real input arrays. Use plan_fft instead.")
end
-function plan_bfft!(X::oneAPI.oneArray{T,N}, region) where {T<:Union{Float32,Float64},N}
+function plan_bfft!(X::oneAPI.oneArray{T, N}, region) where {T <: Union{Float32, Float64}, N}
error("In-place FFT not supported for real input arrays. Use plan_bfft instead.")
end
# Real forward (out-of-place) - supports multi-dimensional transforms
-function plan_rfft(X::oneAPI.oneArray{T,N}, region) where {T<:Union{Float32,Float64},N}
+function plan_rfft(X::oneAPI.oneArray{T, N}, region) where {T <: Union{Float32, Float64}, N}
# Convert region to tuple if it's a range
if isa(region, AbstractUnitRange)
region = tuple(region...)
end
- R = length(region); reg = NTuple{R,Int}(region)
+ R = length(region); reg = NTuple{R, Int}(region)
# For single dimension transforms, use the optimized oneMKL real FFT
if R == 1 && reg[1] == 1
@@ -234,12 +238,12 @@ function plan_rfft(X::oneAPI.oneArray{T,N}, region) where {T<:Union{Float32,Floa
end
# Single-dimension real FFT using oneMKL (optimized path)
-function _plan_rfft_1d(X::oneAPI.oneArray{T,N}, reg::NTuple{1,Int}) where {T<:Union{Float32,Float64},N}
+function _plan_rfft_1d(X::oneAPI.oneArray{T, N}, reg::NTuple{1, Int}) where {T <: Union{Float32, Float64}, N}
# Create 1D descriptor for the transform dimension
- desc,q = _create_descriptor((size(X, reg[1]),), T, false)
+ desc, q = _create_descriptor((size(X, reg[1]),), T, false)
xdims = size(X)
# output along first dim becomes N/2+1
- ydims = Base.setindex(xdims, div(xdims[1],2)+1, 1)
+ ydims = Base.setindex(xdims, div(xdims[1], 2) + 1, 1)
buffer = oneAPI.oneArray{Complex{T}}(undef, ydims)
onemklDftSetValueConfigValue(desc, ONEMKL_DFT_PARAM_PLACEMENT, ONEMKL_DFT_VALUE_NOT_INPLACE)
@@ -255,18 +259,18 @@ function _plan_rfft_1d(X::oneAPI.oneArray{T,N}, reg::NTuple{1,Int}) where {T<:Un
stc = onemklDftCommit(desc, q); stc == 0 || error("commit failed ($stc)")
R = length(reg)
- rMKLFFTPlan{T,MKLFFT_FORWARD,false,N,R,typeof(buffer)}(desc,q,xdims,ydims,:rfft,reg,buffer,nothing)
+ return rMKLFFTPlan{T, MKLFFT_FORWARD, false, N, R, typeof(buffer)}(desc, q, xdims, ydims, :rfft, reg, buffer, nothing)
end
# Multi-dimensional real FFT using complex FFT approach
-struct ComplexBasedRealFFTPlan{T,N,R} <: MKLFFTPlan{T,MKLFFT_FORWARD,false}
- complex_plan::cMKLFFTPlan{Complex{T},MKLFFT_FORWARD,false,N,R,Nothing}
- sz::NTuple{N,Int}
- osz::NTuple{N,Int}
- region::NTuple{R,Int}
+struct ComplexBasedRealFFTPlan{T, N, R} <: MKLFFTPlan{T, MKLFFT_FORWARD, false}
+ complex_plan::cMKLFFTPlan{Complex{T}, MKLFFT_FORWARD, false, N, R, Nothing}
+ sz::NTuple{N, Int}
+ osz::NTuple{N, Int}
+ region::NTuple{R, Int}
end
-function _plan_rfft_nd(X::oneAPI.oneArray{T,N}, reg::NTuple{R,Int}) where {T<:Union{Float32,Float64},N,R}
+function _plan_rfft_nd(X::oneAPI.oneArray{T, N}, reg::NTuple{R, Int}) where {T <: Union{Float32, Float64}, N, R}
# Create complex version for planning
X_complex = oneAPI.oneArray{Complex{T}}(undef, size(X))
complex_plan = plan_fft(X_complex, reg)
@@ -281,18 +285,22 @@ function _plan_rfft_nd(X::oneAPI.oneArray{T,N}, reg::NTuple{R,Int}) where {T<:Un
end
end
- ComplexBasedRealFFTPlan{T,N,R}(complex_plan, xdims, ydims, reg)
+ return ComplexBasedRealFFTPlan{T, N, R}(complex_plan, xdims, ydims, reg)
end
# Show method for complex-based plan
function Base.show(io::IO, p::ComplexBasedRealFFTPlan{T}) where {T}
print(io, "oneMKL FFT forward plan for ")
- if isempty(p.sz); print(io, "0-dimensional") else print(io, join(p.sz, "×")) end
- print(io, " oneArray of ", T, " (multi-dimensional via complex FFT)")
+ if isempty(p.sz)
+ print(io, "0-dimensional")
+ else
+ print(io, join(p.sz, "×"))
+ end
+ return print(io, " oneArray of ", T, " (multi-dimensional via complex FFT)")
end
# Execution for complex-based real FFT plan
-function Base.:*(p::ComplexBasedRealFFTPlan{T,N,R}, X::oneAPI.oneArray{T}) where {T,N,R}
+function Base.:*(p::ComplexBasedRealFFTPlan{T, N, R}, X::oneAPI.oneArray{T}) where {T, N, R}
# Convert to complex
X_complex = Complex{T}.(X)
@@ -316,14 +324,13 @@ function Base.:*(p::ComplexBasedRealFFTPlan{T,N,R}, X::oneAPI.oneArray{T}) where
end
-
# Real inverse (complex->real) requires complex input shape - supports multi-dimensional transforms
-function plan_brfft(X::oneAPI.oneArray{T,N}, d::Integer, region) where {T<:Union{ComplexF32,ComplexF64},N}
+function plan_brfft(X::oneAPI.oneArray{T, N}, d::Integer, region) where {T <: Union{ComplexF32, ComplexF64}, N}
# Convert region to tuple if it's a range
if isa(region, AbstractUnitRange)
region = tuple(region...)
end
- R = length(region); reg = NTuple{R,Int}(region)
+ R = length(region); reg = NTuple{R, Int}(region)
# For single dimension transforms along first dim, use optimized oneMKL path
if R == 1 && reg[1] == 1
@@ -335,13 +342,13 @@ function plan_brfft(X::oneAPI.oneArray{T,N}, d::Integer, region) where {T<:Union
end
# Single-dimension real inverse FFT using oneMKL (optimized path)
-function _plan_brfft_1d(X::oneAPI.oneArray{T,N}, d::Integer, reg::NTuple{1,Int}) where {T<:Union{ComplexF32,ComplexF64},N}
+function _plan_brfft_1d(X::oneAPI.oneArray{T, N}, d::Integer, reg::NTuple{1, Int}) where {T <: Union{ComplexF32, ComplexF64}, N}
# Extract underlying real type R from Complex{R}
@assert T <: Complex
RT = T.parameters[1]
# Create 1D descriptor for the transform dimension
- desc,q = _create_descriptor((d,), RT, false)
+ desc, q = _create_descriptor((d,), RT, false)
xdims = size(X)
ydims = Base.setindex(xdims, d, 1)
buffer = oneAPI.oneArray{T}(undef, xdims) # copy for safety
@@ -355,19 +362,19 @@ function _plan_brfft_1d(X::oneAPI.oneArray{T,N}, d::Integer, reg::NTuple{1,Int})
stc = onemklDftCommit(desc, q); stc == 0 || error("commit failed ($stc)")
R = length(reg)
- rMKLFFTPlan{T,MKLFFT_INVERSE,false,N,R,typeof(buffer)}(desc,q,xdims,ydims,:brfft,reg,buffer,nothing)
+ return rMKLFFTPlan{T, MKLFFT_INVERSE, false, N, R, typeof(buffer)}(desc, q, xdims, ydims, :brfft, reg, buffer, nothing)
end
# Multi-dimensional real inverse FFT using complex FFT approach
-struct ComplexBasedRealIFFTPlan{T,N,R} <: MKLFFTPlan{T,MKLFFT_INVERSE,false}
- complex_plan::cMKLFFTPlan{T,MKLFFT_INVERSE,false,N,R,Nothing}
- sz::NTuple{N,Int}
- osz::NTuple{N,Int}
- region::NTuple{R,Int}
+struct ComplexBasedRealIFFTPlan{T, N, R} <: MKLFFTPlan{T, MKLFFT_INVERSE, false}
+ complex_plan::cMKLFFTPlan{T, MKLFFT_INVERSE, false, N, R, Nothing}
+ sz::NTuple{N, Int}
+ osz::NTuple{N, Int}
+ region::NTuple{R, Int}
d::Int # Original size of the reduced dimension
end
-function _plan_brfft_nd(X::oneAPI.oneArray{T,N}, d::Integer, reg::NTuple{R,Int}) where {T<:Union{ComplexF32,ComplexF64},N,R}
+function _plan_brfft_nd(X::oneAPI.oneArray{T, N}, d::Integer, reg::NTuple{R, Int}) where {T <: Union{ComplexF32, ComplexF64}, N, R}
# Calculate the full complex array size (before real FFT reduction)
xdims = size(X)
full_complex_dims = ntuple(N) do i
@@ -382,18 +389,22 @@ function _plan_brfft_nd(X::oneAPI.oneArray{T,N}, d::Integer, reg::NTuple{R,Int})
X_complex_full = oneAPI.oneArray{T}(undef, full_complex_dims)
complex_plan = plan_bfft(X_complex_full, reg)
- ComplexBasedRealIFFTPlan{T,N,R}(complex_plan, xdims, full_complex_dims, reg, d)
+ return ComplexBasedRealIFFTPlan{T, N, R}(complex_plan, xdims, full_complex_dims, reg, d)
end
# Show method for complex-based inverse plan
function Base.show(io::IO, p::ComplexBasedRealIFFTPlan{T}) where {T}
print(io, "oneMKL FFT inverse plan for ")
- if isempty(p.sz); print(io, "0-dimensional") else print(io, join(p.sz, "×")) end
- print(io, " oneArray of ", T, " (multi-dimensional via complex FFT)")
+ if isempty(p.sz)
+ print(io, "0-dimensional")
+ else
+ print(io, join(p.sz, "×"))
+ end
+ return print(io, " oneArray of ", T, " (multi-dimensional via complex FFT)")
end
# Execution for complex-based real inverse FFT plan
-function Base.:*(p::ComplexBasedRealIFFTPlan{T,N,R}, X::oneAPI.oneArray{T}) where {T,N,R}
+function Base.:*(p::ComplexBasedRealIFFTPlan{T, N, R}, X::oneAPI.oneArray{T}) where {T, N, R}
# Reconstruct full complex array by exploiting conjugate symmetry
# This is a simplified approach - for full accuracy, we'd need to properly
# reconstruct the conjugate symmetric part
@@ -435,7 +446,7 @@ function Base.:*(p::ComplexBasedRealIFFTPlan{T,N,R}, X::oneAPI.oneArray{T}) wher
end
# Inverse plan for complex-based real FFT plans
-function plan_inv(p::ComplexBasedRealFFTPlan{T,N,R}) where {T,N,R}
+function plan_inv(p::ComplexBasedRealFFTPlan{T, N, R}) where {T, N, R}
# For real FFT inverse, we need plan_brfft functionality
# The first dimension in the region should be the one that was reduced
first_dim = minimum(p.region)
@@ -443,18 +454,17 @@ function plan_inv(p::ComplexBasedRealFFTPlan{T,N,R}) where {T,N,R}
# Create inverse plan using our new multi-dimensional brfft
brfft_plan = _plan_brfft_nd(oneAPI.oneArray{Complex{T}}(undef, p.osz), d, p.region)
- ScaledPlan(brfft_plan, 1/normalization_factor(p.sz, p.region))
+ return ScaledPlan(brfft_plan, 1 / normalization_factor(p.sz, p.region))
end
# Inverse plan for complex-based real inverse FFT plans
-function plan_inv(p::ComplexBasedRealIFFTPlan{T,N,R}) where {T,N,R}
+function plan_inv(p::ComplexBasedRealIFFTPlan{T, N, R}) where {T, N, R}
# Create forward plan
forward_plan = _plan_rfft_nd(oneAPI.oneArray{real(T)}(undef, p.osz), p.region)
- ScaledPlan(forward_plan, 1/normalization_factor(p.osz, p.region))
+ return ScaledPlan(forward_plan, 1 / normalization_factor(p.osz, p.region))
end
-
# Convenience no-region methods use all dimensions in order
plan_fft(X::oneAPI.oneArray) = plan_fft(X, ntuple(identity, ndims(X)))
plan_bfft(X::oneAPI.oneArray) = plan_bfft(X, ntuple(identity, ndims(X)))
@@ -467,111 +477,119 @@ plan_brfft(X::oneAPI.oneArray, d::Integer) = plan_brfft(X, d, ntuple(identity, n
const plan_ifft = plan_bfft
const plan_ifft! = plan_bfft!
# plan_irfft should be normalized, unlike plan_brfft
-plan_irfft(X::oneAPI.oneArray{T,N}, d::Integer, region) where {T,N} = begin
+plan_irfft(X::oneAPI.oneArray{T, N}, d::Integer, region) where {T, N} = begin
p = plan_brfft(X, d, region)
- ScaledPlan(p, 1/normalization_factor(p.sz, p.region))
+ ScaledPlan(p, 1 / normalization_factor(p.sz, p.region))
end
-plan_irfft(X::oneAPI.oneArray{T,N}, d::Integer) where {T,N} = plan_irfft(X, d, (1,))
+plan_irfft(X::oneAPI.oneArray{T, N}, d::Integer) where {T, N} = plan_irfft(X, d, (1,))
# Inversion
Base.inv(p::MKLFFTPlan) = plan_inv(p)
# High-level wrappers operating like CPU FFTW versions.
-function fft(X::oneAPI.oneArray{T}) where {T<:Union{ComplexF32,ComplexF64}}
- (plan_fft(X) * X)
+function fft(X::oneAPI.oneArray{T}) where {T <: Union{ComplexF32, ComplexF64}}
+ return (plan_fft(X) * X)
end
-function ifft(X::oneAPI.oneArray{T}) where {T<:Union{ComplexF32,ComplexF64}}
+function ifft(X::oneAPI.oneArray{T}) where {T <: Union{ComplexF32, ComplexF64}}
p = plan_bfft(X)
# Apply normalization for ifft (unlike bfft which is unnormalized)
scaling = one(T) / normalization_factor(size(X), ntuple(identity, ndims(X)))
- scaling * (p * X)
+ return scaling * (p * X)
end
-function fft!(X::oneAPI.oneArray{T}) where {T<:Union{ComplexF32,ComplexF64}}
+function fft!(X::oneAPI.oneArray{T}) where {T <: Union{ComplexF32, ComplexF64}}
(plan_fft!(X) * X; X)
end
-function ifft!(X::oneAPI.oneArray{T}) where {T<:Union{ComplexF32,ComplexF64}}
+function ifft!(X::oneAPI.oneArray{T}) where {T <: Union{ComplexF32, ComplexF64}}
p = plan_bfft!(X)
# Apply normalization for ifft! (unlike bfft! which is unnormalized)
scaling = one(T) / normalization_factor(size(X), ntuple(identity, ndims(X)))
p * X
X .*= scaling
- X
+ return X
end
-function rfft(X::oneAPI.oneArray{T}) where {T<:Union{Float32,Float64}}
- (plan_rfft(X) * X)
+function rfft(X::oneAPI.oneArray{T}) where {T <: Union{Float32, Float64}}
+ return (plan_rfft(X) * X)
end
-function irfft(X::oneAPI.oneArray{T}, d::Integer) where {T<:Union{ComplexF32,ComplexF64}}
+function irfft(X::oneAPI.oneArray{T}, d::Integer) where {T <: Union{ComplexF32, ComplexF64}}
# Use the normalized plan_irfft instead of unnormalized plan_brfft
- (plan_irfft(X, d) * X)
+ return (plan_irfft(X, d) * X)
end
# Execution helpers
-_rawptr(a::oneAPI.oneArray{T}) where T = reinterpret(Ptr{Cvoid}, pointer(a))
+_rawptr(a::oneAPI.oneArray{T}) where {T} = reinterpret(Ptr{Cvoid}, pointer(a))
-function _exec!(p::cMKLFFTPlan{T,MKLFFT_FORWARD,true}, X::oneAPI.oneArray{T}) where T
- st = onemklDftComputeForward(p.handle, _rawptr(X)); st==0 || error("forward FFT failed ($st)"); X
+function _exec!(p::cMKLFFTPlan{T, MKLFFT_FORWARD, true}, X::oneAPI.oneArray{T}) where {T}
+ st = onemklDftComputeForward(p.handle, _rawptr(X)); st == 0 || error("forward FFT failed ($st)")
+ return X
end
-function _exec!(p::cMKLFFTPlan{T,MKLFFT_INVERSE,true}, X::oneAPI.oneArray{T}) where T
- st = onemklDftComputeBackward(p.handle, _rawptr(X)); st==0 || error("inverse FFT failed ($st)"); X
+function _exec!(p::cMKLFFTPlan{T, MKLFFT_INVERSE, true}, X::oneAPI.oneArray{T}) where {T}
+ st = onemklDftComputeBackward(p.handle, _rawptr(X)); st == 0 || error("inverse FFT failed ($st)")
+ return X
end
-function _exec!(p::cMKLFFTPlan{T,K,false}, X::oneAPI.oneArray{T}, Y::oneAPI.oneArray{T}) where {T,K}
- st = (K==MKLFFT_FORWARD ? onemklDftComputeForwardOutOfPlace : onemklDftComputeBackwardOutOfPlace)(p.handle, _rawptr(X), _rawptr(Y)); st==0 || error("FFT failed ($st)"); Y
+function _exec!(p::cMKLFFTPlan{T, K, false}, X::oneAPI.oneArray{T}, Y::oneAPI.oneArray{T}) where {T, K}
+ st = (K == MKLFFT_FORWARD ? onemklDftComputeForwardOutOfPlace : onemklDftComputeBackwardOutOfPlace)(p.handle, _rawptr(X), _rawptr(Y)); st == 0 || error("FFT failed ($st)")
+ return Y
end
# Real forward
-function _exec!(p::rMKLFFTPlan{T,MKLFFT_FORWARD,false}, X::oneAPI.oneArray{T}, Y::oneAPI.oneArray{Complex{T}}) where T
- st = onemklDftComputeForwardOutOfPlace(p.handle, _rawptr(X), _rawptr(Y)); st==0 || error("rfft failed ($st)"); Y
+function _exec!(p::rMKLFFTPlan{T, MKLFFT_FORWARD, false}, X::oneAPI.oneArray{T}, Y::oneAPI.oneArray{Complex{T}}) where {T}
+ st = onemklDftComputeForwardOutOfPlace(p.handle, _rawptr(X), _rawptr(Y)); st == 0 || error("rfft failed ($st)")
+ return Y
end
# Real inverse (complex -> real)
-function _exec!(p::rMKLFFTPlan{T,MKLFFT_INVERSE,false}, X::oneAPI.oneArray{T}, Y::oneAPI.oneArray{R}) where {R,T<:Complex{R}}
- st = onemklDftComputeBackwardOutOfPlace(p.handle, _rawptr(X), _rawptr(Y)); st==0 || error("brfft failed ($st)"); Y
+function _exec!(p::rMKLFFTPlan{T, MKLFFT_INVERSE, false}, X::oneAPI.oneArray{T}, Y::oneAPI.oneArray{R}) where {R, T <: Complex{R}}
+ st = onemklDftComputeBackwardOutOfPlace(p.handle, _rawptr(X), _rawptr(Y)); st == 0 || error("brfft failed ($st)")
+ return Y
end
# Public API similar to AMDGPU
-function Base.:*(p::cMKLFFTPlan{T,K,true}, X::oneAPI.oneArray{T}) where {T,K}
- _exec!(p,X)
+function Base.:*(p::cMKLFFTPlan{T, K, true}, X::oneAPI.oneArray{T}) where {T, K}
+ return _exec!(p, X)
end
-function Base.:*(p::cMKLFFTPlan{T,K,false}, X::oneAPI.oneArray{T}) where {T,K}
- Y = oneAPI.oneArray{T}(undef, p.osz); _exec!(p,X,Y)
+function Base.:*(p::cMKLFFTPlan{T, K, false}, X::oneAPI.oneArray{T}) where {T, K}
+ Y = oneAPI.oneArray{T}(undef, p.osz)
+ return _exec!(p, X, Y)
end
-function LinearAlgebra.mul!(Y::oneAPI.oneArray{T}, p::cMKLFFTPlan{T,K,false}, X::oneAPI.oneArray{T}) where {T,K}
- _exec!(p,X,Y)
+function LinearAlgebra.mul!(Y::oneAPI.oneArray{T}, p::cMKLFFTPlan{T, K, false}, X::oneAPI.oneArray{T}) where {T, K}
+ return _exec!(p, X, Y)
end
# Real forward
-function Base.:*(p::rMKLFFTPlan{T,MKLFFT_FORWARD,false}, X::oneAPI.oneArray{T}) where {T<:Union{Float32,Float64}}
- Y = oneAPI.oneArray{Complex{T}}(undef, p.osz); _exec!(p,X,Y)
+function Base.:*(p::rMKLFFTPlan{T, MKLFFT_FORWARD, false}, X::oneAPI.oneArray{T}) where {T <: Union{Float32, Float64}}
+ Y = oneAPI.oneArray{Complex{T}}(undef, p.osz)
+ return _exec!(p, X, Y)
end
-function LinearAlgebra.mul!(Y::oneAPI.oneArray{Complex{T}}, p::rMKLFFTPlan{T,MKLFFT_FORWARD,false}, X::oneAPI.oneArray{T}) where {T<:Union{Float32,Float64}}
- _exec!(p,X,Y)
+function LinearAlgebra.mul!(Y::oneAPI.oneArray{Complex{T}}, p::rMKLFFTPlan{T, MKLFFT_FORWARD, false}, X::oneAPI.oneArray{T}) where {T <: Union{Float32, Float64}}
+ return _exec!(p, X, Y)
end
# Real inverse
-function Base.:*(p::rMKLFFTPlan{T,MKLFFT_INVERSE,false}, X::oneAPI.oneArray{T}) where {R,T<:Complex{R}}
- Y = oneAPI.oneArray{R}(undef, p.osz); _exec!(p,X,Y)
+function Base.:*(p::rMKLFFTPlan{T, MKLFFT_INVERSE, false}, X::oneAPI.oneArray{T}) where {R, T <: Complex{R}}
+ Y = oneAPI.oneArray{R}(undef, p.osz)
+ return _exec!(p, X, Y)
end
-function LinearAlgebra.mul!(Y::oneAPI.oneArray{R}, p::rMKLFFTPlan{T,MKLFFT_INVERSE,false}, X::oneAPI.oneArray{T}) where {R,T<:Complex{R}}
- _exec!(p,X,Y)
+function LinearAlgebra.mul!(Y::oneAPI.oneArray{R}, p::rMKLFFTPlan{T, MKLFFT_INVERSE, false}, X::oneAPI.oneArray{T}) where {R, T <: Complex{R}}
+ return _exec!(p, X, Y)
end
# Support for applying complex plans to real arrays (convert real to complex first)
-function Base.:*(p::cMKLFFTPlan{T,K,false}, X::oneAPI.oneArray{R}) where {T,K,R<:Union{Float32,Float64}}
+function Base.:*(p::cMKLFFTPlan{T, K, false}, X::oneAPI.oneArray{R}) where {T, K, R <: Union{Float32, Float64}}
# Only allow if T is the complex version of R
if T != Complex{R}
error("Type mismatch: plan expects $(T) but got $(R)")
end
# Convert real input to complex
X_complex = complex.(X)
- p * X_complex
+ return p * X_complex
end
-function LinearAlgebra.mul!(Y::oneAPI.oneArray{T}, p::cMKLFFTPlan{T,K,false}, X::oneAPI.oneArray{R}) where {T,K,R<:Union{Float32,Float64}}
+function LinearAlgebra.mul!(Y::oneAPI.oneArray{T}, p::cMKLFFTPlan{T, K, false}, X::oneAPI.oneArray{R}) where {T, K, R <: Union{Float32, Float64}}
# Only allow if T is the complex version of R
if T != Complex{R}
error("Type mismatch: plan expects $(T) but got $(R)")
end
# Convert real input to complex
X_complex = complex.(X)
- _exec!(p, X_complex, Y)
+ return _exec!(p, X_complex, Y)
end
end # module FFT
diff --git a/lib/support/liboneapi_support.jl b/lib/support/liboneapi_support.jl
index 06d8bee..0ea694b 100644
--- a/lib/support/liboneapi_support.jl
+++ b/lib/support/liboneapi_support.jl
@@ -7111,122 +7111,160 @@ mutable struct onemklDftDescriptor_st end
const onemklDftDescriptor_t = Ptr{onemklDftDescriptor_st}
function onemklDftCreate1D(desc, precision, domain, length)
- @ccall liboneapi_support.onemklDftCreate1D(desc::Ptr{onemklDftDescriptor_t},
- precision::onemklDftPrecision,
- domain::onemklDftDomain, length::Int64)::Cint
+ return @ccall liboneapi_support.onemklDftCreate1D(
+ desc::Ptr{onemklDftDescriptor_t},
+ precision::onemklDftPrecision,
+ domain::onemklDftDomain, length::Int64
+ )::Cint
end
function onemklDftCreateND(desc, precision, domain, dim, lengths)
- @ccall liboneapi_support.onemklDftCreateND(desc::Ptr{onemklDftDescriptor_t},
- precision::onemklDftPrecision,
- domain::onemklDftDomain, dim::Int64,
- lengths::Ptr{Int64})::Cint
+ return @ccall liboneapi_support.onemklDftCreateND(
+ desc::Ptr{onemklDftDescriptor_t},
+ precision::onemklDftPrecision,
+ domain::onemklDftDomain, dim::Int64,
+ lengths::Ptr{Int64}
+ )::Cint
end
function onemklDftDestroy(desc)
- @ccall liboneapi_support.onemklDftDestroy(desc::onemklDftDescriptor_t)::Cint
+ return @ccall liboneapi_support.onemklDftDestroy(desc::onemklDftDescriptor_t)::Cint
end
function onemklDftCommit(desc, queue)
- @ccall liboneapi_support.onemklDftCommit(desc::onemklDftDescriptor_t,
- queue::syclQueue_t)::Cint
+ return @ccall liboneapi_support.onemklDftCommit(
+ desc::onemklDftDescriptor_t,
+ queue::syclQueue_t
+ )::Cint
end
function onemklDftSetValueInt64(desc, param, value)
- @ccall liboneapi_support.onemklDftSetValueInt64(desc::onemklDftDescriptor_t,
- param::onemklDftConfigParam,
- value::Int64)::Cint
+ return @ccall liboneapi_support.onemklDftSetValueInt64(
+ desc::onemklDftDescriptor_t,
+ param::onemklDftConfigParam,
+ value::Int64
+ )::Cint
end
function onemklDftSetValueDouble(desc, param, value)
- @ccall liboneapi_support.onemklDftSetValueDouble(desc::onemklDftDescriptor_t,
- param::onemklDftConfigParam,
- value::Cdouble)::Cint
+ return @ccall liboneapi_support.onemklDftSetValueDouble(
+ desc::onemklDftDescriptor_t,
+ param::onemklDftConfigParam,
+ value::Cdouble
+ )::Cint
end
function onemklDftSetValueInt64Array(desc, param, values, n)
- @ccall liboneapi_support.onemklDftSetValueInt64Array(desc::onemklDftDescriptor_t,
- param::onemklDftConfigParam,
- values::Ptr{Int64}, n::Int64)::Cint
+ return @ccall liboneapi_support.onemklDftSetValueInt64Array(
+ desc::onemklDftDescriptor_t,
+ param::onemklDftConfigParam,
+ values::Ptr{Int64}, n::Int64
+ )::Cint
end
function onemklDftSetValueConfigValue(desc, param, value)
- @ccall liboneapi_support.onemklDftSetValueConfigValue(desc::onemklDftDescriptor_t,
- param::onemklDftConfigParam,
- value::onemklDftConfigValue)::Cint
+ return @ccall liboneapi_support.onemklDftSetValueConfigValue(
+ desc::onemklDftDescriptor_t,
+ param::onemklDftConfigParam,
+ value::onemklDftConfigValue
+ )::Cint
end
function onemklDftGetValueInt64(desc, param, value)
- @ccall liboneapi_support.onemklDftGetValueInt64(desc::onemklDftDescriptor_t,
- param::onemklDftConfigParam,
- value::Ptr{Int64})::Cint
+ return @ccall liboneapi_support.onemklDftGetValueInt64(
+ desc::onemklDftDescriptor_t,
+ param::onemklDftConfigParam,
+ value::Ptr{Int64}
+ )::Cint
end
function onemklDftGetValueDouble(desc, param, value)
- @ccall liboneapi_support.onemklDftGetValueDouble(desc::onemklDftDescriptor_t,
- param::onemklDftConfigParam,
- value::Ptr{Cdouble})::Cint
+ return @ccall liboneapi_support.onemklDftGetValueDouble(
+ desc::onemklDftDescriptor_t,
+ param::onemklDftConfigParam,
+ value::Ptr{Cdouble}
+ )::Cint
end
function onemklDftGetValueInt64Array(desc, param, values, n)
- @ccall liboneapi_support.onemklDftGetValueInt64Array(desc::onemklDftDescriptor_t,
- param::onemklDftConfigParam,
- values::Ptr{Int64},
- n::Ptr{Int64})::Cint
+ return @ccall liboneapi_support.onemklDftGetValueInt64Array(
+ desc::onemklDftDescriptor_t,
+ param::onemklDftConfigParam,
+ values::Ptr{Int64},
+ n::Ptr{Int64}
+ )::Cint
end
function onemklDftGetValueConfigValue(desc, param, value)
- @ccall liboneapi_support.onemklDftGetValueConfigValue(desc::onemklDftDescriptor_t,
- param::onemklDftConfigParam,
- value::Ptr{onemklDftConfigValue})::Cint
+ return @ccall liboneapi_support.onemklDftGetValueConfigValue(
+ desc::onemklDftDescriptor_t,
+ param::onemklDftConfigParam,
+ value::Ptr{onemklDftConfigValue}
+ )::Cint
end
function onemklDftComputeForward(desc, inout)
- @ccall liboneapi_support.onemklDftComputeForward(desc::onemklDftDescriptor_t,
- inout::Ptr{Cvoid})::Cint
+ return @ccall liboneapi_support.onemklDftComputeForward(
+ desc::onemklDftDescriptor_t,
+ inout::Ptr{Cvoid}
+ )::Cint
end
function onemklDftComputeForwardOutOfPlace(desc, in, out)
- @ccall liboneapi_support.onemklDftComputeForwardOutOfPlace(desc::onemklDftDescriptor_t,
- in::Ptr{Cvoid},
- out::Ptr{Cvoid})::Cint
+ return @ccall liboneapi_support.onemklDftComputeForwardOutOfPlace(
+ desc::onemklDftDescriptor_t,
+ in::Ptr{Cvoid},
+ out::Ptr{Cvoid}
+ )::Cint
end
function onemklDftComputeBackward(desc, inout)
- @ccall liboneapi_support.onemklDftComputeBackward(desc::onemklDftDescriptor_t,
- inout::Ptr{Cvoid})::Cint
+ return @ccall liboneapi_support.onemklDftComputeBackward(
+ desc::onemklDftDescriptor_t,
+ inout::Ptr{Cvoid}
+ )::Cint
end
function onemklDftComputeBackwardOutOfPlace(desc, in, out)
- @ccall liboneapi_support.onemklDftComputeBackwardOutOfPlace(desc::onemklDftDescriptor_t,
- in::Ptr{Cvoid},
- out::Ptr{Cvoid})::Cint
+ return @ccall liboneapi_support.onemklDftComputeBackwardOutOfPlace(
+ desc::onemklDftDescriptor_t,
+ in::Ptr{Cvoid},
+ out::Ptr{Cvoid}
+ )::Cint
end
function onemklDftComputeForwardBuffer(desc, inout)
- @ccall liboneapi_support.onemklDftComputeForwardBuffer(desc::onemklDftDescriptor_t,
- inout::Ptr{Cvoid})::Cint
+ return @ccall liboneapi_support.onemklDftComputeForwardBuffer(
+ desc::onemklDftDescriptor_t,
+ inout::Ptr{Cvoid}
+ )::Cint
end
function onemklDftComputeForwardOutOfPlaceBuffer(desc, in, out)
- @ccall liboneapi_support.onemklDftComputeForwardOutOfPlaceBuffer(desc::onemklDftDescriptor_t,
- in::Ptr{Cvoid},
- out::Ptr{Cvoid})::Cint
+ return @ccall liboneapi_support.onemklDftComputeForwardOutOfPlaceBuffer(
+ desc::onemklDftDescriptor_t,
+ in::Ptr{Cvoid},
+ out::Ptr{Cvoid}
+ )::Cint
end
function onemklDftComputeBackwardBuffer(desc, inout)
- @ccall liboneapi_support.onemklDftComputeBackwardBuffer(desc::onemklDftDescriptor_t,
- inout::Ptr{Cvoid})::Cint
+ return @ccall liboneapi_support.onemklDftComputeBackwardBuffer(
+ desc::onemklDftDescriptor_t,
+ inout::Ptr{Cvoid}
+ )::Cint
end
function onemklDftComputeBackwardOutOfPlaceBuffer(desc, in, out)
- @ccall liboneapi_support.onemklDftComputeBackwardOutOfPlaceBuffer(desc::onemklDftDescriptor_t,
- in::Ptr{Cvoid},
- out::Ptr{Cvoid})::Cint
+ return @ccall liboneapi_support.onemklDftComputeBackwardOutOfPlaceBuffer(
+ desc::onemklDftDescriptor_t,
+ in::Ptr{Cvoid},
+ out::Ptr{Cvoid}
+ )::Cint
end
function onemklDftQueryParamIndices(out, n)
- @ccall liboneapi_support.onemklDftQueryParamIndices(out::Ptr{Int64}, n::Int64)::Cint
+ return @ccall liboneapi_support.onemklDftQueryParamIndices(out::Ptr{Int64}, n::Int64)::Cint
end
const ONEMKL_DFT_STATUS_SUCCESS = 0
diff --git a/res/wrap.jl b/res/wrap.jl
index 1d48315..2e9b29f 100644
--- a/res/wrap.jl
+++ b/res/wrap.jl
@@ -112,14 +112,14 @@ using oneAPI_Level_Zero_Headers_jll
function main()
wrap("ze", oneAPI_Level_Zero_Headers_jll.ze_api)
- wrap(
- "support",
- joinpath(dirname(@__DIR__), "deps", "src", "sycl.h"),
- joinpath(dirname(@__DIR__), "deps", "src", "onemkl.h"),
- joinpath(dirname(@__DIR__), "deps", "src", "onemkl_dft.h");
- dependents=false,
- include_dirs=[dirname(dirname(oneAPI_Level_Zero_Headers_jll.ze_api))]
- )
+ return wrap(
+ "support",
+ joinpath(dirname(@__DIR__), "deps", "src", "sycl.h"),
+ joinpath(dirname(@__DIR__), "deps", "src", "onemkl.h"),
+ joinpath(dirname(@__DIR__), "deps", "src", "onemkl_dft.h");
+ dependents = false,
+ include_dirs = [dirname(dirname(oneAPI_Level_Zero_Headers_jll.ze_api))]
+ )
end
isinteractive() || main()
diff --git a/test/fft.jl b/test/fft.jl
index 1b148df..ef81c21 100644
--- a/test/fft.jl
+++ b/test/fft.jl
@@ -7,39 +7,39 @@ using Random
Random.seed!(1234)
# Helper to move data to GPU
-gpu(A::AbstractArray{T}) where T = oneAPI.oneArray{T}(A)
+gpu(A::AbstractArray{T}) where {T} = oneAPI.oneArray{T}(A)
struct _Plan end
struct _FFT end
-const MYRTOL = 1e-5
-const MYATOL = 1e-8
+const MYRTOL = 1.0e-5
+const MYATOL = 1.0e-8
-function cmp(a,b; rtol=MYRTOL, atol=MYATOL)
- @test isapprox(Array(a), Array(b); rtol=rtol, atol=atol)
+function cmp(a, b; rtol = MYRTOL, atol = MYATOL)
+ return @test isapprox(Array(a), Array(b); rtol = rtol, atol = atol)
end
-function test_plan(::_Plan, plan, X::AbstractArray{T,N}) where {T,N}
+function test_plan(::_Plan, plan, X::AbstractArray{T, N}) where {T, N}
p = plan(X)
Y = p * X
return Y
end
-function test_plan(::_FFT, f, X::AbstractArray{T,N}) where {T,N}
+function test_plan(::_FFT, f, X::AbstractArray{T, N}) where {T, N}
Y = if f === AbstractFFTs.irfft || f === AbstractFFTs.brfft
- f(X, size(X, ndims(X))*2 - 2)
+ f(X, size(X, ndims(X)) * 2 - 2)
else
f(X)
end
return Y
end
-function test_plan(t, plan::Function, dim::Tuple, T::Type, iplan=nothing)
+function test_plan(t, plan::Function, dim::Tuple, T::Type, iplan = nothing)
X = rand(T, dim)
dX = gpu(X)
Y = test_plan(t, plan, X)
dY = test_plan(t, plan, dX)
cmp(dY, Y)
- if iplan !== nothing
+ return if iplan !== nothing
iX = test_plan(t, iplan, Y)
idX = test_plan(t, iplan, dY)
cmp(idX, iX)
@@ -47,36 +47,36 @@ function test_plan(t, plan::Function, dim::Tuple, T::Type, iplan=nothing)
end
@testset "FFT" begin
-@testset "$(length(dim))D" for dim in [(8,), (8,32), (8,32,64)]
- test_plan(_Plan(), AbstractFFTs.plan_fft, dim, ComplexF32, AbstractFFTs.plan_ifft)
- test_plan(_Plan(), AbstractFFTs.plan_fft, dim, ComplexF32, AbstractFFTs.plan_bfft)
- test_plan(_Plan(), AbstractFFTs.plan_fft, dim, Float32, AbstractFFTs.plan_ifft)
- test_plan(_Plan(), AbstractFFTs.plan_fft, dim, Float32, AbstractFFTs.plan_bfft)
- test_plan(_Plan(), AbstractFFTs.plan_rfft, dim, Float32)
- test_plan(_Plan(), AbstractFFTs.plan_fft!, dim, ComplexF32, AbstractFFTs.plan_bfft!)
- # Not part of FFTW
- # test_plan(AbstractFFTs.plan_rfft!, Float32)
- test_plan(_FFT(), AbstractFFTs.fft, dim, ComplexF32, AbstractFFTs.ifft)
- test_plan(_FFT(), AbstractFFTs.fft, dim, ComplexF32, AbstractFFTs.bfft)
- if length(dim) == 1 # irfft/brfft only for 1D
- test_plan(_FFT(), AbstractFFTs.rfft, dim, Float32, AbstractFFTs.irfft)
- test_plan(_FFT(), AbstractFFTs.rfft, dim, Float32, AbstractFFTs.brfft)
- end
- if (ComplexF64 in eltypes) && (Float64 in eltypes)
- test_plan(_Plan(), AbstractFFTs.plan_fft, dim, ComplexF64, AbstractFFTs.plan_ifft)
- test_plan(_Plan(), AbstractFFTs.plan_fft, dim, ComplexF64, AbstractFFTs.plan_bfft)
- test_plan(_Plan(), AbstractFFTs.plan_fft, dim, Float64, AbstractFFTs.plan_ifft)
- test_plan(_Plan(), AbstractFFTs.plan_fft, dim, Float64, AbstractFFTs.plan_bfft)
- test_plan(_Plan(), AbstractFFTs.plan_rfft, dim, Float64)
- test_plan(_Plan(), AbstractFFTs.plan_fft!, dim, ComplexF64, AbstractFFTs.plan_bfft!)
+ @testset "$(length(dim))D" for dim in [(8,), (8, 32), (8, 32, 64)]
+ test_plan(_Plan(), AbstractFFTs.plan_fft, dim, ComplexF32, AbstractFFTs.plan_ifft)
+ test_plan(_Plan(), AbstractFFTs.plan_fft, dim, ComplexF32, AbstractFFTs.plan_bfft)
+ test_plan(_Plan(), AbstractFFTs.plan_fft, dim, Float32, AbstractFFTs.plan_ifft)
+ test_plan(_Plan(), AbstractFFTs.plan_fft, dim, Float32, AbstractFFTs.plan_bfft)
+ test_plan(_Plan(), AbstractFFTs.plan_rfft, dim, Float32)
+ test_plan(_Plan(), AbstractFFTs.plan_fft!, dim, ComplexF32, AbstractFFTs.plan_bfft!)
# Not part of FFTW
- # test_plan(AbstractFFTs.plan_rfft!, Float64)
- test_plan(_FFT(), AbstractFFTs.fft, dim, ComplexF64, AbstractFFTs.ifft)
- test_plan(_FFT(), AbstractFFTs.fft, dim, ComplexF64, AbstractFFTs.bfft)
+ # test_plan(AbstractFFTs.plan_rfft!, Float32)
+ test_plan(_FFT(), AbstractFFTs.fft, dim, ComplexF32, AbstractFFTs.ifft)
+ test_plan(_FFT(), AbstractFFTs.fft, dim, ComplexF32, AbstractFFTs.bfft)
if length(dim) == 1 # irfft/brfft only for 1D
- test_plan(_FFT(), AbstractFFTs.rfft, dim, Float64, AbstractFFTs.irfft)
- test_plan(_FFT(), AbstractFFTs.rfft, dim, Float64, AbstractFFTs.brfft)
+ test_plan(_FFT(), AbstractFFTs.rfft, dim, Float32, AbstractFFTs.irfft)
+ test_plan(_FFT(), AbstractFFTs.rfft, dim, Float32, AbstractFFTs.brfft)
+ end
+ if (ComplexF64 in eltypes) && (Float64 in eltypes)
+ test_plan(_Plan(), AbstractFFTs.plan_fft, dim, ComplexF64, AbstractFFTs.plan_ifft)
+ test_plan(_Plan(), AbstractFFTs.plan_fft, dim, ComplexF64, AbstractFFTs.plan_bfft)
+ test_plan(_Plan(), AbstractFFTs.plan_fft, dim, Float64, AbstractFFTs.plan_ifft)
+ test_plan(_Plan(), AbstractFFTs.plan_fft, dim, Float64, AbstractFFTs.plan_bfft)
+ test_plan(_Plan(), AbstractFFTs.plan_rfft, dim, Float64)
+ test_plan(_Plan(), AbstractFFTs.plan_fft!, dim, ComplexF64, AbstractFFTs.plan_bfft!)
+ # Not part of FFTW
+ # test_plan(AbstractFFTs.plan_rfft!, Float64)
+ test_plan(_FFT(), AbstractFFTs.fft, dim, ComplexF64, AbstractFFTs.ifft)
+ test_plan(_FFT(), AbstractFFTs.fft, dim, ComplexF64, AbstractFFTs.bfft)
+ if length(dim) == 1 # irfft/brfft only for 1D
+ test_plan(_FFT(), AbstractFFTs.rfft, dim, Float64, AbstractFFTs.irfft)
+ test_plan(_FFT(), AbstractFFTs.rfft, dim, Float64, AbstractFFTs.brfft)
+ end
end
end
end
-end |
| } | ||
| *out = desc; | ||
| return 0; | ||
| } catch (...) { |
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #515 +/- ##
==========================================
- Coverage 81.73% 79.70% -2.04%
==========================================
Files 44 45 +1
Lines 2540 2818 +278
==========================================
+ Hits 2076 2246 +170
- Misses 464 572 +108 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
lib/mkl/fft.jl
Outdated
| ccall_create1d(desc_ref, prec::Int32, dom::Int32, length::Int64) = ccall((:onemklDftCreate1D, lib), Cint, (Ref{Ptr{Cvoid}}, Cint, Cint, Int64), desc_ref, prec, dom, length) | ||
| ccall_creatend(desc_ref, prec::Int32, dom::Int32, dim::Int64, lengths::Ptr{Int64}) = ccall((:onemklDftCreateND, lib), Cint, (Ref{Ptr{Cvoid}}, Cint, Cint, Int64, Ptr{Int64}), desc_ref, prec, dom, dim, lengths) | ||
| ccall_destroy(desc) = ccall((:onemklDftDestroy, lib), Cint, (Ptr{Cvoid},), desc) | ||
| ccall_commit(desc, q) = ccall((:onemklDftCommit, lib), Cint, (Ptr{Cvoid}, syclQueue_t), desc, q) | ||
| ccall_fwd(desc, ptr) = ccall((:onemklDftComputeForward, lib), Cint, (Ptr{Cvoid}, Ptr{Cvoid}), desc, ptr) | ||
| ccall_fwd_oop(desc, pin, pout) = ccall((:onemklDftComputeForwardOutOfPlace, lib), Cint, (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}), desc, pin, pout) | ||
| ccall_bwd(desc, ptr) = ccall((:onemklDftComputeBackward, lib), Cint, (Ptr{Cvoid}, Ptr{Cvoid}), desc, ptr) | ||
| ccall_bwd_oop(desc, pin, pout) = ccall((:onemklDftComputeBackwardOutOfPlace, lib), Cint, (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}), desc, pin, pout) | ||
| ccall_set_double(desc, param::Int32, value::Float64) = ccall((:onemklDftSetValueDouble, lib), Cint, (Ptr{Cvoid}, Cint, Float64), desc, param, value) | ||
| ccall_set_int(desc, param::Int32, value::Int64) = ccall((:onemklDftSetValueInt64, lib), Cint, (Ptr{Cvoid}, Cint, Int64), desc, param, value) | ||
| ccall_set_int64_array(desc, param::Int32, values::Vector{Int64}) = ccall((:onemklDftSetValueInt64Array, lib), Cint, (Ptr{Cvoid}, Cint, Ptr{Int64}, Int64), desc, param, pointer(values), length(values)) | ||
| ccall_set_cfg(desc, param::Int32, value::Int32) = ccall((:onemklDftSetValueConfigValue, lib), Cint, (Ptr{Cvoid}, Cint, Cint), desc, param, value) |
There was a problem hiding this comment.
@michel2323 Please use the wrappers generated by Clang.jl.
There was a problem hiding this comment.
This is generated with Clang. I'm not sure I understand.
Line 119 in 7822c44
lib/mkl/fft.jl
Outdated
| R = length(region); reg = NTuple{R,Int}(region) | ||
| # Only support single dimension transforms for now | ||
| if R != 1 | ||
| error("Multi-dimensional real FFT not yet supported") |
There was a problem hiding this comment.
@michel2323 Do we know if it is feature not yet implemented by Intel or something to improve on our side?
There was a problem hiding this comment.
I get wrong values. Maybe I did something wrong, or the Intel library is wrong.
amontoison
left a comment
There was a problem hiding this comment.
@michel2323 Please address the two comments and it should be good for me.
It passes some tests. GPT 5 helped quite a bit here.