Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -391,10 +391,15 @@ template <typename _KtTag, bool __is_ascending, ::std::uint8_t __radix_bits, ::s
// __work_group_size, _InRngPack, _OutRngPack>
template <typename _KtTag, bool __is_ascending, ::std::uint8_t __radix_bits, ::std::uint16_t __data_per_work_item,
::std::uint16_t __work_group_size, typename _InRngPack, typename _OutRngPack>
struct __radix_sort_onesweep_kernel
struct __radix_sort_onesweep_kernel;

template <bool __is_ascending, ::std::uint8_t __radix_bits, ::std::uint16_t __data_per_work_item,
::std::uint16_t __work_group_size, typename _InRngPack, typename _OutRngPack>
struct __radix_sort_onesweep_kernel<__esimd_tag, __is_ascending, __radix_bits, __data_per_work_item, __work_group_size, _InRngPack, _OutRngPack>
{
using _LocOffsetT = ::std::uint16_t;
using _GlobOffsetT = ::std::uint32_t;
using _AtomicIdT = ::std::uint32_t;

using _KeyT = typename _InRngPack::_KeyT;
using _ValT = typename _InRngPack::_ValT;
Expand Down Expand Up @@ -457,7 +462,7 @@ struct __radix_sort_onesweep_kernel
_OutRngPack __out_pack;

__radix_sort_onesweep_kernel(::std::uint32_t __n, ::std::uint32_t __stage, _GlobOffsetT* __p_global_hist,
_GlobOffsetT* __p_group_hists, const _InRngPack& __in_pack,
_GlobOffsetT* __p_group_hists, _AtomicIdT* /*__p_atomic_id*/, const _InRngPack& __in_pack,
const _OutRngPack& __out_pack)
: __n(__n), __stage(__stage), __p_global_hist(__p_global_hist), __p_group_hists(__p_group_hists),
__in_pack(__in_pack), __out_pack(__out_pack)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,20 +86,23 @@ class __onesweep_memory_holder
// Memory to store intermediate results of sorting
_KeyT* __m_keys_ptr = nullptr;
_ValT* __m_vals_ptr = nullptr;
std::uint32_t* __m_atomic_id_pointer = nullptr;

::std::size_t __m_raw_mem_bytes = 0;
::std::size_t __m_keys_bytes = 0;
::std::size_t __m_vals_bytes = 0;
::std::size_t __m_global_hist_bytes = 0;
::std::size_t __m_group_hist_bytes = 0;

::std::size_t __m_atomic_id_bytes = 4;

sycl::queue __m_q;

void
__calculate_raw_memory_amount() noexcept
{
// Extra bytes are added for potentiall padding
__m_raw_mem_bytes = __m_keys_bytes + __m_global_hist_bytes + __m_group_hist_bytes + sizeof(_KeyT);
__m_raw_mem_bytes = __m_keys_bytes + __m_global_hist_bytes + __m_group_hist_bytes + __m_atomic_id_bytes + sizeof(std::uint32_t) + sizeof(_KeyT);
if constexpr (__has_values)
{
__m_raw_mem_bytes += (__m_vals_bytes + sizeof(_ValT));
Expand Down Expand Up @@ -135,6 +138,9 @@ class __onesweep_memory_holder
__aligned_ptr = ::std::align(::std::alignment_of_v<_ValT>, __m_vals_bytes, __base_ptr, __remainder);
__m_vals_ptr = reinterpret_cast<_ValT*>(__aligned_ptr);
}
std::size_t __atomic_id_offset = __m_raw_mem_bytes - __m_atomic_id_bytes;
__atomic_id_offset -= (__atomic_id_offset % alignof(std::uint32_t));
__m_atomic_id_pointer = reinterpret_cast<std::uint32_t*>(__m_raw_mem_ptr + __atomic_id_offset);
}

public:
Expand Down Expand Up @@ -181,6 +187,11 @@ class __onesweep_memory_holder
{
return __m_group_hist_ptr;
}
std::uint32_t*
__atomic_id_pointer() const noexcept
{
return __m_atomic_id_pointer;
}

void
__allocate()
Expand Down Expand Up @@ -249,7 +260,7 @@ __onesweep_impl(_KtTag __kt_tag, sycl::queue __q, _RngPack1&& __input_pack, _Rng
// TODO: consider adding a more versatile API, e.g. passing special kernel_config parameters for histogram computation
// ESIMD work-group size: 64 XVEs ~ 2048 SIMD lanes
// SYCL work-group size: Programming model enables 1024, so 128 required for PVC-1550 full concurrency. 10x HW oversubscription
constexpr ::std::uint32_t __hist_work_group_count = std::is_same_v<_KtTag, __sycl_tag> ? 128 * 10 : 32;
constexpr ::std::uint32_t __hist_work_group_count = std::is_same_v<_KtTag, __sycl_tag> ? 128 * 10 : 64;
constexpr ::std::uint32_t __hist_work_group_size = std::is_same_v<_KtTag, __sycl_tag> ? 1024 : 64;
__event_chain = __radix_sort_histogram_submitter<__is_ascending, __radix_bits, __hist_work_group_count,
__hist_work_group_size, _RadixSortHistogram>()(
Expand All @@ -260,7 +271,7 @@ __onesweep_impl(_KtTag __kt_tag, sycl::queue __q, _RngPack1&& __input_pack, _Rng

__event_chain = __radix_sort_onesweep_submitter<__is_ascending, __radix_bits, __data_per_work_item,
__work_group_size, _RadixSortSweepInitial>()(
__kt_tag, __q, __input_pack, __virt_pack1, __mem_holder.__global_hist_ptr(), __mem_holder.__group_hist_ptr(),
__kt_tag, __q, __input_pack, __virt_pack1, __mem_holder.__global_hist_ptr(), __mem_holder.__group_hist_ptr(), __mem_holder.__atomic_id_pointer(),
__sweep_work_group_count, __n, 0, __event_chain);

for (::std::uint32_t __stage = 1; __stage < __stage_count; __stage++)
Expand All @@ -273,14 +284,14 @@ __onesweep_impl(_KtTag __kt_tag, sycl::queue __q, _RngPack1&& __input_pack, _Rng
{
__event_chain = __radix_sort_onesweep_submitter<__is_ascending, __radix_bits, __data_per_work_item,
__work_group_size, _RadixSortSweepOdd>()(
__kt_tag, __q, __virt_pack1, __virt_pack2, __p_global_hist, __p_group_hists, __sweep_work_group_count, __n,
__kt_tag, __q, __virt_pack1, __virt_pack2, __p_global_hist, __p_group_hists, __mem_holder.__atomic_id_pointer(), __sweep_work_group_count, __n,
__stage, __event_chain);
}
else
{
__event_chain = __radix_sort_onesweep_submitter<__is_ascending, __radix_bits, __data_per_work_item,
__work_group_size, _RadixSortSweepEven>()(
__kt_tag, __q, __virt_pack2, __virt_pack1, __p_global_hist, __p_group_hists, __sweep_work_group_count, __n,
__kt_tag, __q, __virt_pack2, __virt_pack1, __p_global_hist, __p_group_hists, __mem_holder.__atomic_id_pointer(), __sweep_work_group_count, __n,
__stage, __event_chain);
}
}
Expand Down Expand Up @@ -389,23 +400,23 @@ __radix_sort(_KtTag __kt_tag, sycl::queue __q, _RngPack1&& __pack_in, _RngPack2&
else
{
constexpr ::std::uint32_t __one_wg_cap = __data_per_workitem * __workgroup_size;
if (__n <= __one_wg_cap)
// TODO: this is temporary in the prototype until we have a SYCL one wg version to plugin.
if constexpr (std::is_same_v<_KtTag, __esimd_tag>)
{
// TODO: support different RadixBits values (only 7, 8, 9 are currently supported)
// TODO: support more granular DataPerWorkItem and WorkGroupSize

return __one_wg<_KernelName, __is_ascending, __radix_bits, __data_per_workitem, __workgroup_size>(
__kt_tag, __q, ::std::forward<_RngPack1>(__pack_in), ::std::forward<_RngPack2>(__pack_out), __n);
}
else
{
// TODO: avoid kernel duplication (generate the output storage with the same type as input storage and use swap)
// TODO: support different RadixBits
// TODO: support more granular DataPerWorkItem and WorkGroupSize
return __onesweep<_KernelName, __is_ascending, __radix_bits, __data_per_workitem, __workgroup_size,
__in_place>(__kt_tag, __q, ::std::forward<_RngPack1>(__pack_in),
::std::forward<_RngPack2>(__pack_out), __n);
if (__n <= __one_wg_cap)
{
// TODO: support different RadixBits values (only 7, 8, 9 are currently supported)
// TODO: support more granular DataPerWorkItem and WorkGroupSize

return __one_wg<_KernelName, __is_ascending, __radix_bits, __data_per_workitem, __workgroup_size>(
__kt_tag, __q, ::std::forward<_RngPack1>(__pack_in), ::std::forward<_RngPack2>(__pack_out), __n);
}
}
// TODO: avoid kernel duplication (generate the output storage with the same type as input storage and use swap)
// TODO: support different RadixBits
// TODO: support more granular DataPerWorkItem and WorkGroupSize
return __onesweep<_KernelName, __is_ascending, __radix_bits, __data_per_workitem, __workgroup_size, __in_place>(
__kt_tag, __q, ::std::forward<_RngPack1>(__pack_in), ::std::forward<_RngPack2>(__pack_out), __n);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,27 +170,80 @@ template <bool __is_ascending, ::std::uint8_t __radix_bits, ::std::uint16_t __da
struct __radix_sort_onesweep_submitter<__is_ascending, __radix_bits, __data_per_work_item, __work_group_size,
oneapi::dpl::__par_backend_hetero::__internal::__optional_kernel_name<_Name...>>
{
template <typename _KtTag, typename _InRngPack, typename _OutRngPack, typename _GlobalHistT>
private:
// ESIMD kernel dispatch
template <typename _InRngPack, typename _OutRngPack, typename _GlobalHistT, typename _AtomicIdT>
sycl::event
__submit_esimd(sycl::queue& __q, _InRngPack&& __in_pack, _OutRngPack&& __out_pack, _GlobalHistT* __p_global_hist,
_GlobalHistT* __p_group_hists, _AtomicIdT* __p_atomic_id, ::std::uint32_t __sweep_work_group_count,
::std::size_t __n, ::std::uint32_t __stage, const sycl::event& __e) const
{
sycl::nd_range<1> __nd_range(__sweep_work_group_count * __work_group_size, __work_group_size);
return __q.submit([&](sycl::handler& __cgh) {
oneapi::dpl::__ranges::__require_access(__cgh, __in_pack.__keys_rng(), __out_pack.__keys_rng());
if constexpr (::std::decay_t<_InRngPack>::__has_values)
{
oneapi::dpl::__ranges::__require_access(__cgh, __in_pack.__vals_rng(), __out_pack.__vals_rng());
}
__cgh.depends_on(__e);
__radix_sort_onesweep_kernel<__esimd_tag, __is_ascending, __radix_bits, __data_per_work_item,
__work_group_size, ::std::decay_t<_InRngPack>, ::std::decay_t<_OutRngPack>>
__kernel(__n, __stage, __p_global_hist, __p_group_hists, __p_atomic_id,
::std::forward<_InRngPack>(__in_pack), ::std::forward<_OutRngPack>(__out_pack));
__cgh.parallel_for<_Name...>(__nd_range, __kernel);
});
}

// SYCL kernel dispatch
template <typename _InRngPack, typename _OutRngPack, typename _GlobalHistT, typename _AtomicIdT>
sycl::event
operator()(_KtTag, sycl::queue& __q, _InRngPack&& __in_pack, _OutRngPack&& __out_pack, _GlobalHistT* __p_global_hist,
_GlobalHistT* __p_group_hists, ::std::uint32_t __sweep_work_group_count, ::std::size_t __n,
::std::uint32_t __stage, const sycl::event& __e) const
__submit_sycl(sycl::queue& __q, _InRngPack&& __in_pack, _OutRngPack&& __out_pack, _GlobalHistT* __p_global_hist,
_GlobalHistT* __p_group_hists, _AtomicIdT* __p_atomic_id, ::std::uint32_t __sweep_work_group_count,
::std::size_t __n, ::std::uint32_t __stage, const sycl::event& __e) const
{
using _KernelType =
__radix_sort_onesweep_kernel<__sycl_tag, __is_ascending, __radix_bits, __data_per_work_item,
__work_group_size, ::std::decay_t<_InRngPack>, ::std::decay_t<_OutRngPack>>;
constexpr ::std::uint32_t __slm_size_bytes = _KernelType::__calc_slm_alloc();
constexpr ::std::uint32_t __slm_size_elements = __slm_size_bytes / sizeof(::std::uint32_t);

sycl::nd_range<1> __nd_range(__sweep_work_group_count * __work_group_size, __work_group_size);
return __q.submit([&](sycl::handler& __cgh) {
sycl::local_accessor<unsigned char, 1> __slm_accessor(__slm_size_bytes, __cgh);
oneapi::dpl::__ranges::__require_access(__cgh, __in_pack.__keys_rng(), __out_pack.__keys_rng());
if constexpr (::std::decay_t<_InRngPack>::__has_values)
{
oneapi::dpl::__ranges::__require_access(__cgh, __in_pack.__vals_rng(), __out_pack.__vals_rng());
}
__cgh.depends_on(__e);
__radix_sort_onesweep_kernel<_KtTag, __is_ascending, __radix_bits, __data_per_work_item, __work_group_size,
::std::decay_t<_InRngPack>, ::std::decay_t<_OutRngPack>>
__kernel(__n, __stage, __p_global_hist, __p_group_hists, ::std::forward<_InRngPack>(__in_pack),
::std::forward<_OutRngPack>(__out_pack));
_KernelType __kernel(__n, __stage, __p_global_hist, __p_group_hists, __p_atomic_id,
::std::forward<_InRngPack>(__in_pack), ::std::forward<_OutRngPack>(__out_pack),
__slm_accessor);
__cgh.parallel_for<_Name...>(__nd_range, __kernel);
});
}

public:
template <typename _KtTag, typename _InRngPack, typename _OutRngPack, typename _GlobalHistT, typename _AtomicIdT>
sycl::event
operator()(_KtTag, sycl::queue& __q, _InRngPack&& __in_pack, _OutRngPack&& __out_pack,
_GlobalHistT* __p_global_hist, _GlobalHistT* __p_group_hists, _AtomicIdT* __p_atomic_id,
::std::uint32_t __sweep_work_group_count, ::std::size_t __n, ::std::uint32_t __stage,
const sycl::event& __e) const
{
if constexpr (std::is_same_v<_KtTag, __sycl_tag>)
{
return __submit_sycl(__q, ::std::forward<_InRngPack>(__in_pack), ::std::forward<_OutRngPack>(__out_pack),
__p_global_hist, __p_group_hists, __p_atomic_id, __sweep_work_group_count, __n,
__stage, __e);
}
else
{
return __submit_esimd(__q, ::std::forward<_InRngPack>(__in_pack), ::std::forward<_OutRngPack>(__out_pack),
__p_global_hist, __p_group_hists, __p_atomic_id, __sweep_work_group_count, __n,
__stage, __e);
}
}
};

template <typename _KernelName>
Expand Down
55 changes: 53 additions & 2 deletions include/oneapi/dpl/experimental/kt/internal/radix_sort_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ constexpr void
__check_onesweep_params()
{
static_assert(__radix_bits == 8);
static_assert(__data_per_workitem % 32 == 0);
static_assert(__workgroup_size == 32 || __workgroup_size == 64);
//static_assert(__data_per_workitem % 32 == 0);
//static_assert(__workgroup_size == 32 || __workgroup_size == 64);
}

//-----------------------------------------------------------------------------
Expand Down Expand Up @@ -125,6 +125,57 @@ __order_preserving_cast_scalar(_Float __src)
return __uint64_src ^ __mask;
}

template <std::uint16_t _N, typename _KeyT>
struct __keys_pack
{
_KeyT __keys[_N];
};

template <std::uint16_t _N, typename _KeyT, typename _ValT>
struct __pairs_pack
{
_KeyT __keys[_N];
_ValT __vals[_N];
};

template <std::uint16_t _N, typename _T1, typename _T2 = void>
auto
__make_key_value_pack()
{
if constexpr (::std::is_void_v<_T2>)
{
return __keys_pack<_N, _T1>{};
}
else
{
return __pairs_pack<_N, _T1, _T2>{};
}
}

template <std::uint32_t __segment_width, std::uint32_t __num_segments, std::uint32_t __sub_group_size,
typename _ScanBuffer>
void
__sub_group_cross_segment_exclusive_scan(sycl::sub_group& __sub_group, _ScanBuffer* __scan_elements)
{
// TODO: make it work if this static assert is not true
static_assert(__segment_width == __sub_group_size);
using _ElemT = std::remove_reference_t<decltype(__scan_elements[0])>;
_ElemT __carry = 0;
auto __sub_group_local_id = __sub_group.get_local_linear_id();

_ONEDPL_PRAGMA_UNROLL
for (std::uint32_t __i = 0; __i < __num_segments; ++__i)
{
auto __element = __scan_elements[__i * __segment_width + __sub_group_local_id];
auto __element_right_shift = sycl::shift_group_right(__sub_group, __element, 1);
if (__sub_group_local_id == 0)
__element_right_shift = 0;
__scan_elements[__i * __segment_width + __sub_group_local_id] = __element_right_shift + __carry;

__carry += sycl::group_broadcast(__sub_group, __element, __sub_group_size - 1);
}
}

} // namespace oneapi::dpl::experimental::kt::gpu::__impl

#endif // _ONEDPL_KT_SYCL_RADIX_SORT_UTILS_H
Loading