Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions lib/cunumeric_jl_wrapper/include/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,19 +42,19 @@ DEFINE_CODE_TO_CXX(UINT8, uint8_t)
DEFINE_CODE_TO_CXX(UINT16, uint16_t)
DEFINE_CODE_TO_CXX(UINT32, uint32_t)
DEFINE_CODE_TO_CXX(UINT64, uint64_t)
#ifdef HAVE_CUDA
#if LEGATE_DEFINED(LEGATE_USE_CUDA)
DEFINE_CODE_TO_CXX(FLOAT16, __half)
#else
// Dummy type for FLOAT16 when CUDA is not available
// This allows compilation but we throw an error if actually used
struct __half_dummy {};
DEFINE_CODE_TO_CXX(FLOAT16, __half_dummy)
struct __dummy {};
DEFINE_CODE_TO_CXX(FLOAT16, __dummy)
#endif
DEFINE_CODE_TO_CXX(FLOAT32, float)
DEFINE_CODE_TO_CXX(FLOAT64, double)
DEFINE_CODE_TO_CXX(COMPLEX64, std::complex<float>)
DEFINE_CODE_TO_CXX(COMPLEX128, std::complex<double>)
#undef DEFINE_CODE_TO_CXX

using HalfType = typename code_to_cxx<legate::Type::Code::FLOAT16>::type;

} // namespace legate_util

// Unary op codes
Expand Down
6 changes: 4 additions & 2 deletions lib/cunumeric_jl_wrapper/src/wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,21 +72,23 @@ JLCXX_MODULE define_julia_module(jlcxx::Module& mod) {
using jlcxx::ParameterList;
using jlcxx::Parametric;
using jlcxx::TypeVar;
using legate_util::HalfType;

// Map C++ complex types to Julia complex types
mod.map_type<std::complex<double>>("ComplexF64");
mod.map_type<std::complex<float>>("ComplexF32");
mod.map_type<HalfType>("Float16");

// These are the types/dims used to generate templated functions
// i.e. only these types/dims can be used from Julia side
using fp_types = ParameterList<double, float>;
using fp_types = ParameterList<double, float, HalfType>;
using int_types = ParameterList<int8_t, int16_t, int32_t, int64_t>;
using uint_types = ParameterList<uint8_t, uint16_t, uint32_t, uint64_t>;

using all_types =
ParameterList<double, float, int8_t, int16_t, int32_t, int64_t, uint8_t,
uint16_t, uint32_t, uint64_t, bool, std::complex<double>,
std::complex<float>>;
std::complex<float>, HalfType>;
using allowed_dims = ParameterList<
std::integral_constant<int_t, 1>, std::integral_constant<int_t, 2>,
std::integral_constant<int_t, 3>, std::integral_constant<int_t, 4>>;
Expand Down
17 changes: 5 additions & 12 deletions src/cuNumeric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,22 +54,15 @@ end
const DEFAULT_FLOAT = Float32
const DEFAULT_INT = Int32

const SUPPORTED_INT_TYPES = Union{Int32,Int64}
const SUPPORTED_FLOAT_TYPES = Union{Float32,Float64}
const SUPPORTED_INT_TYPES = Union{Int8,Int16,Int32,Int64,UInt8,UInt16,UInt32,UInt64}
const SUPPORTED_FLOAT_TYPES = Union{Float32,Float64} # Float16 not supported yet
const SUPPORTED_COMPLEX_TYPES = Union{ComplexF32,ComplexF64}

const SUPPORTED_NUMERIC_TYPES = Union{
SUPPORTED_INT_TYPES,SUPPORTED_FLOAT_TYPES,SUPPORTED_COMPLEX_TYPES
}
# const SUPPORTED_TYPES = Union{SUPPORTED_INT_TYPES,SUPPORTED_FLOAT_TYPES,Bool} #* TODO Test UInt, Complex

const SUPPORTED_TYPES = Union{
Bool,
Int8,Int16,Int32,Int64,
UInt8,UInt16,UInt32,UInt64,
Float16,Float32,Float64,
ComplexF32,ComplexF64,
String,
}
const SUPPORTED_ARRAY_TYPES = Union{Bool,SUPPORTED_NUMERIC_TYPES}
const SUPPORTED_TYPES = Union{SUPPORTED_ARRAY_TYPES,String}

# const MAX_DIM = 6 # idk what we compiled?

Expand Down
30 changes: 23 additions & 7 deletions src/ndarray/binary.jl
Original file line number Diff line number Diff line change
Expand Up @@ -154,20 +154,36 @@ LinearAlgebra.mul!(out, a, b)
function LinearAlgebra.mul!(
out::NDArray{T,2}, rhs1::NDArray{A,2}, rhs2::NDArray{B,2}
) where {T<:SUPPORTED_NUMERIC_TYPES,A,B}
#! This will probably need more checks once we support Complex number
size(rhs1, 2) == size(rhs2, 1) ||
throw(DimensionMismatch("Matrix dimensions incompatible: $(size(rhs1)) × $(size(rhs2))"))
(size(out, 1) == size(rhs1, 1) && size(out, 2) == size(rhs2, 2)) || throw(
DimensionMismatch(
"mul! output is $(size(out)), but inputs would produce $(size(rhs1,1))×$(size(rhs2,2))"
),
)
T_OUT = __my_promote_type(A, B)
((T_OUT <: AbstractFloat) && (T <: Integer)) && throw(
ArgumentError(
"mul! output has integer type $(T), but inputs promote to floating point type: $(T_OUT)"
),
)

T_REQUIRED = __my_promote_type(A, B)
if promote_type(T_REQUIRED, T) != T
if (T_REQUIRED <: Complex && !(T <: Complex))
throw(
ArgumentError(
"Implicit promotion: mul! output has real type $(T), but inputs promote to complex type: $(T_REQUIRED)"
),
)
elseif (T_REQUIRED <: AbstractFloat && T <: Integer)
throw(
ArgumentError(
"Implicit promotion: mul! output has integer type $(T), but inputs promote to floating point type: $(T_REQUIRED)"
),
)
end
# General case (e.g. Float64 result into Float32)
throw(
ArgumentError(
"mul! output type $(T) cannot hold the promoted input type $(T_REQUIRED). Implicit promotion to wider type or complex result is disallowed."
),
)
end
return nda_three_dot_arg(checked_promote_arr(rhs1, T), checked_promote_arr(rhs2, T), out)
end

Expand Down
51 changes: 27 additions & 24 deletions src/ndarray/detail/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,40 +33,43 @@ The NDArray type represents a multi-dimensional array in cuNumeric.
It is a wrapper around a Legate array and provides various methods for array manipulation and operations.
Finalizer calls `nda_destroy_array` to clean up the underlying Legate array when the NDArray is garbage collected.
"""
mutable struct NDArray{T,N,PADDED} <: AbstractNDArray{T,N}
mutable struct NDArray{T,N,PADDED,P} <: AbstractNDArray{T,N}
ptr::NDArray_t
nbytes::Int64
padding::Union{Nothing,NTuple{N,Int}}
parent::P

function NDArray(ptr::NDArray_t, ::Type{T}, ::Val{N}) where {T,N}
nbytes = cuNumeric.nda_nbytes(ptr)
cuNumeric.register_alloc!(nbytes)
handle = new{T,N,false}(ptr, nbytes, nothing)
handle = new{T,N,false,Nothing}(ptr, nbytes, nothing, nothing)
finalizer(handle) do h
cuNumeric.nda_destroy_array(h.ptr)
cuNumeric.register_free!(h.nbytes)
end
return handle
end
end

# Dynamic fallback, not great but required if we cannot infer things
NDArray(ptr::NDArray_t; T=get_julia_type(ptr), N::Integer=get_n_dim(ptr)) = NDArray(ptr, T, Val(N))

# struct WrappedNDArray{T,N} <: AbstractNDArray{T,N}
# ndarr::NDArray{T,N}
# jlarr::Array{T,N}

# function WrappedNDArray(ndarray::NDArray{T,N}, jlarray::Array{T,N}) where {T,N}
# ndarr = ndarray
# jlarr = jlarray
# end
# Explicit parent inner constructor
function NDArray(ptr::NDArray_t, ::Type{T}, ::Val{N}, parent::P) where {T,N,P}
nbytes = cuNumeric.nda_nbytes(ptr)
cuNumeric.register_alloc!(nbytes)
handle = new{T,N,false,P}(ptr, nbytes, nothing, parent)
finalizer(handle) do h
cuNumeric.nda_destroy_array(h.ptr)
cuNumeric.register_free!(h.nbytes)
end
return handle
end
end
# this here is to avoid if else patterns
@inline _NDArray(ptr, T, v, ::Nothing) = NDArray(ptr, T, v)
@inline _NDArray(ptr, T, v, parent) = NDArray(ptr, T, v, parent)

# function WrappedNDArray(ndarray::NDArray{T,N}) where {T,N}
# ndarr = ndarray
# jlarr = nothing
# end
# end
# Dynamic fallback
function NDArray(ptr::NDArray_t; T=get_julia_type(ptr), N::Integer=get_n_dim(ptr), parent=nothing)
return _NDArray(ptr, T, Val(N), parent)
end

#! JUST USE FULL TO MAKE a 0D?
# $ cuNumeric.nda_full_array(UInt64[], 2.0f0)
Expand Down Expand Up @@ -314,7 +317,7 @@ function nda_attach_external(arr::AbstractArray{T,N}) where {T,N}
# Use the CxxWrap method for type-safe interaction
# This returns a raw pointer compatible with the NDArray constructor
nda_ptr = cuNumeric.nda_store_to_ndarray(st.handle)
return NDArray(nda_ptr; T=T, N=N)
return NDArray(nda_ptr, T, Val(N), arr)
end

# return underlying logical store to the NDArray obj
Expand Down Expand Up @@ -448,8 +451,8 @@ Emits warnings when array sizes or element types differ.
- Iterates over elements using `CartesianIndices` to compare element-wise difference.
"""
function compare(
julia_array::AbstractArray{T,N}, arr::NDArray{T,N}, atol::Real, rtol::Real
) where {T,N}
julia_array::AbstractArray{T1,N}, arr::NDArray{T2,N}, atol::Real, rtol::Real
) where {T1,T2,N}
if (shape(arr) != Base.size(julia_array))
@warn "NDArray has shape $(shape(arr)) and Julia array has shape $(Base.size(julia_array))!\n"
return false
Expand All @@ -468,8 +471,8 @@ function compare(
end

function compare(
arr::NDArray{T,N}, julia_array::AbstractArray{T,N}, atol::Real, rtol::Real
) where {T,N}
arr::NDArray{T2,N}, julia_array::AbstractArray{T1,N}, atol::Real, rtol::Real
) where {T1,T2,N}
return compare(julia_array, arr, atol, rtol)
end

Expand Down
31 changes: 19 additions & 12 deletions src/ndarray/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -149,20 +149,17 @@ function (::Type{<:Array})(arr::NDArray{B}) where {B}
end

# conversion from Base Julia array to NDArray
function (::Type{<:NDArray{A}})(arr::Array{B}) where {A,B}
dims = Base.size(arr)
out = cuNumeric.zeros(A, dims)
attached = cuNumeric.nda_attach_external(arr)
copyto!(out, attached) # copy elems of attached to resulting out
return out
function (::Type{<:NDArray{T}})(arr::Array{T,N}) where {T,N}
return cuNumeric.nda_attach_external(arr)
end

function (::Type{<:NDArray})(arr::Array{B}) where {B}
dims = Base.size(arr)
out = cuNumeric.zeros(B, dims)
attached = cuNumeric.nda_attach_external(arr)
copyto!(out, attached)
return out
function (::Type{<:NDArray{A}})(arr::Array{B,N}) where {A,B,N}
# If types differ, we cast in Julia first (creating a temp) then attach
return cuNumeric.nda_attach_external(A.(arr))
end

function (::Type{<:NDArray})(arr::Array{T,N}) where {T,N}
return cuNumeric.nda_attach_external(arr)
end

# Base.convert(::Type{<:NDArray{T}}, a::A) where {T, A} = NDArray(T(a))::NDArray{T}
Expand Down Expand Up @@ -339,6 +336,16 @@ function Base.setindex!(arr::NDArray{T,N}, value::T, idxs::Vararg{Int,N}) where
_setindex!(Val{N}(), arr, value, idxs...)
end

function Base.setindex!(arr::NDArray{Complex{T},N}, value::T, idxs::Vararg{Int,N}) where {T,N}
assertscalar("setindex!")
_setindex!(Val{N}(), arr, Complex{T}(value), idxs...)
end

function Base.setindex!(arr::NDArray{T,N}, value, idxs::Vararg{Int,N}) where {T,N}
assertscalar("setindex!")
_setindex!(Val{N}(), arr, convert(T, value), idxs...)
end

function _setindex!(::Val{0}, arr::NDArray{T,0}, value::T) where {T<:SUPPORTED_NUMERIC_TYPES}
acc = NDArrayAccessor{T,1}()
write(acc, arr.ptr, StdVector(UInt64[0]), value)
Expand Down
49 changes: 47 additions & 2 deletions src/ndarray/unary.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,51 @@ function Base.:(-)(input::NDArray{T}) where {T}
return nda_unary_op!(out, cuNumeric.NEGATIVE, input)
end

function Base.real(input::NDArray{T}) where {T<:Complex}
T_OUT = Base.promote_op(real, T)
out = cuNumeric.zeros(T_OUT, size(input))
return nda_unary_op!(out, cuNumeric.REAL, input)
end
Base.real(input::NDArray{<:Real}) = input

function Base.imag(input::NDArray{T}) where {T<:Complex}
T_OUT = Base.promote_op(imag, T)
out = cuNumeric.zeros(T_OUT, size(input))
return nda_unary_op!(out, cuNumeric.IMAG, input)
end
Base.imag(input::NDArray{T}) where {T<:Real} = cuNumeric.zeros(T, size(input))

function Base.conj(input::NDArray{T}) where {T<:Complex}
out = cuNumeric.zeros(T, size(input))
return nda_unary_op!(out, cuNumeric.CONJ, input)
end
Base.conj(input::NDArray{<:Real}) = input

# Broadcoast support for complex ops
@inline function __broadcast(f::typeof(Base.real), out::NDArray, input::NDArray{<:Complex})
return nda_unary_op!(out, cuNumeric.REAL, input)
end
@inline function __broadcast(f::typeof(Base.imag), out::NDArray, input::NDArray{<:Complex})
return nda_unary_op!(out, cuNumeric.IMAG, input)
end
@inline function __broadcast(f::typeof(Base.conj), out::NDArray, input::NDArray{<:Complex})
return nda_unary_op!(out, cuNumeric.CONJ, input)
end

# Fallbacks for Real types
@inline function __broadcast(f::typeof(Base.real), out::NDArray, input::NDArray{<:Real})
# real(real_array) is just the array
return nda_unary_op!(out, cuNumeric.IDENTITY, input)
end
@inline function __broadcast(f::typeof(Base.imag), out::NDArray, input::NDArray{<:Real})
# imag(real_array) is all zeros
return nda_binary_op!(out, cuNumeric.SUBTRACT, input, input)
end
@inline function __broadcast(f::typeof(Base.conj), out::NDArray, input::NDArray{<:Real})
# conj(real_array) is just the array
return nda_unary_op!(out, cuNumeric.IDENTITY, input)
end

function Base.:(-)(input::NDArray{Bool})
return -(checked_promote_arr(input, DEFAULT_INT))
end
Expand Down Expand Up @@ -157,8 +202,8 @@ end
for (julia_fn, op_code) in unary_op_map_no_args
@eval begin
@inline function __broadcast(
f::typeof($julia_fn), out::NDArray{T}, input::NDArray{T}
) where {T}
f::typeof($julia_fn), out::NDArray{A}, input::NDArray{B}
) where {A,B}
return nda_unary_op!(out, $(op_code), input)
end
end
Expand Down
Loading
Loading