Commit 58d3593

bugfix: fix cu118 cub usage (#410)
Related issue: sgl-project/sglang#771

This PR fixes the usage of the `FlagHeads` cub API in the sampling kernels. As [documented](https://nvidia.github.io/cccl/cub/api/classcub_1_1BlockDiscontinuity.html), the default `FlagHeads` overload always flags the first element, which is wrong when the first element is not `true`:

> For thread0, item input[0] is always flagged.

This PR passes the `tile_predecessor_item` argument (set to 0), which is compared against input[0] instead of flagging it unconditionally. CUDA 12+ builds do not have this issue because they use the newer `SubtractLeft` API instead of `FlagHeads`.
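The difference between the two overloads can be illustrated with a small host-side model. This is only a sequential sketch, not cub code: `BoolDiffOpModel`, `flag_heads_model`, and `first_flagged` are hypothetical names, and the comparison operator is assumed to return `lhs != rhs`, which the commit does not show.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical stand-in for BoolDiffOp; assumed to flag a change between neighbors.
struct BoolDiffOpModel {
  bool operator()(bool lhs, bool rhs) const { return lhs != rhs; }
};

// Sequential model of head flagging over a whole tile.
// Without a predecessor, item 0 is flagged unconditionally (the documented default).
// With a predecessor, item 0 is compared against it like any other item.
std::vector<bool> flag_heads_model(const std::vector<bool>& input,
                                   bool has_predecessor,
                                   bool predecessor = false) {
  std::vector<bool> flags(input.size());
  BoolDiffOpModel op;
  for (std::size_t i = 0; i < input.size(); ++i) {
    if (i == 0) {
      flags[i] = has_predecessor ? op(predecessor, input[0]) : true;
    } else {
      flags[i] = op(input[i - 1], input[i]);
    }
  }
  return flags;
}

// Index of the first flagged item, or -1 if none is flagged.
int first_flagged(const std::vector<bool>& flags) {
  for (std::size_t i = 0; i < flags.size(); ++i) {
    if (flags[i]) return static_cast<int>(i);
  }
  return -1;
}
```

For an input such as `{false, false, true, true}`, the default model flags index 0 even though the first `true` sits at index 2, while the model with a `false` predecessor flags index 2 only. When the first item is already `true`, both variants flag index 0.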
1 parent aaa929a commit 58d3593

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

include/flashinfer/sampling.cuh

Lines changed: 1 addition & 1 deletion
@@ -118,7 +118,7 @@ __device__ __forceinline__ void DeviceSamplingFromProb(
       .SubtractLeft<VEC_SIZE>(greater_than_u, greater_than_u_diff, BoolDiffOp());
 #else
   BlockAdjacentDifference<bool, BLOCK_THREADS>(temp_storage->block_prim.adj_diff)
-      .FlagHeads<VEC_SIZE>(greater_than_u_diff, greater_than_u, BoolDiffOp());
+      .FlagHeads<VEC_SIZE>(greater_than_u_diff, greater_than_u, BoolDiffOp(), 0);
 #endif
   __syncthreads();