This kernel comes from someone on /r/CUDA who writes:
Hello everyone,
I'm working on a project where I need to calculate the pairwise distance matrix between two 2D matrices on the GPU. I've written some basic CUDA C++ code to achieve this, but I've noticed that its performance is currently slower than what I can get using PyTorch's cdist function.
As I'm relatively new to C++ and CUDA development, I'm trying to understand the best practices and common pitfalls for GPU performance optimization. I'm looking for advice on how I can make my custom CUDA implementation faster.
Any insights or suggestions would be greatly appreciated!
Thank you in advance.
code: https://gist.github.com/goktugyildirim4d/f7a370f494612d11ad51dbc0ae467285
For completeness, the code in that gist is given below, verbatim:
__global__ void pairwise_distance_kernel(const float* d_descriptors_live,
const float* d_descriptors_in_view,
float* d_distance_matrix,
int num_live,
int num_in_view,
int descriptor_dim) {
// Global thread index for the output matrix
int row = blockIdx.y * blockDim.y + threadIdx.y; // Corresponds to descriptorsLive
int col = blockIdx.x * blockDim.x + threadIdx.x; // Corresponds to descriptorsInView
if (row < num_live && col < num_in_view) {
float current_dist_sq = 0.0f;
// Calculate squared Euclidean distance
for (int i = 0; i < descriptor_dim; ++i) {
float diff = d_descriptors_live[row * descriptor_dim + i] -
d_descriptors_in_view[col * descriptor_dim + i];
current_dist_sq += diff * diff;
}
// Store the L2 distance (with square root)
d_distance_matrix[row * num_in_view + col] = sqrtf(current_dist_sq);
}
}
// NEWLY ADDED: Pairwise L2 Distance Host wrapper
void pairwiseDistanceWrapper(const float* d_descriptors_live, const float* d_descriptors_in_view, float* d_distance_matrix,
int num_live, int num_in_view, int descriptor_dim) {
// Use a 2D grid for the kernel launch
int blockDimX = 32; // For columns (num_in_view)
int blockDimY = 32; // For rows (num_live)
dim3 blockSize(blockDimX, blockDimY);
dim3 gridSize( (num_in_view + blockDimX - 1) / blockDimX,
(num_live + blockDimY - 1) / blockDimY );
pairwise_distance_kernel<<<gridSize, blockSize>>>(
d_descriptors_live, d_descriptors_in_view, d_distance_matrix,
num_live, num_in_view, descriptor_dim
);
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
fprintf(stderr, "CUDA kernel execution error from pairwiseDistanceWrapper: %s\n", cudaGetErrorString(err));
throw std::runtime_error("CUDA kernel error: " + std::string(cudaGetErrorString(err)));
}
}
After looking at this code a little, here are a few things I notice.
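The core calculation is

$$D_{ij} = \sqrt{\sum_{k=0}^{d-1} \left(A_{ik} - B_{jk}\right)^2},$$

where row $i$ of $A$ is the $i$-th descriptor in d_descriptors_live, row $j$ of $B$ is the $j$-th descriptor in d_descriptors_in_view, and $d$ = descriptor_dim. This is very similar to dense matrix multiplication (here, $C = AB^T$),

$$C_{ij} = \sum_{k=0}^{d-1} A_{ik} B_{jk},$$

in the sense that both calculations depend on exactly the same inputs, with only slightly different arithmetic.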
As a result, we can apply many of the same optimizations that make sense for matrix multiplication.
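Before optimizing anything, it helps to have a trusted result to compare against. Below is a minimal CPU reference sketch for $D_{ij}$. The function name pairwise_distance_reference is hypothetical and is not part of the original post. It uses the same row-major layout as the original kernel, and its loop nest makes the resemblance to $C = AB^T$ explicit.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical CPU reference for the pairwise L2 distance (not from the original post).
// Same row-major layout as the original kernel:
//   live[row * dim + k], in_view[col * dim + k], dist[row * num_in_view + col]
std::vector<float> pairwise_distance_reference(const std::vector<float>& live,
                                               const std::vector<float>& in_view,
                                               int num_live, int num_in_view, int dim) {
  assert(live.size() == static_cast<std::size_t>(num_live) * dim);
  assert(in_view.size() == static_cast<std::size_t>(num_in_view) * dim);
  std::vector<float> dist(static_cast<std::size_t>(num_live) * num_in_view);
  for (int row = 0; row < num_live; ++row) {
    for (int col = 0; col < num_in_view; ++col) {
      // same loop nest as C = A * B^T, with (a - b)^2 in place of a * b
      float sum = 0.0f;
      for (int k = 0; k < dim; ++k) {
        float diff = live[row * dim + k] - in_view[col * dim + k];
        sum += diff * diff;
      }
      dist[static_cast<std::size_t>(row) * num_in_view + col] = std::sqrt(sum);
    }
  }
  return dist;
}
```

Checking each GPU variant against this reference with a small floating-point tolerance is a cheap way to confirm that an optimization changed the speed but not the answer.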
The surrounding code that calls this kernel is missing, so we don't know the actual values of the parameters num_live, num_in_view, and descriptor_dim. I'll arbitrarily assume the following values for the timings reported in the remainder of this document: num_live = 2048, num_in_view = 1024, descriptor_dim = 16.
When descriptor_dim is small, the kernel is memory-bound (most of the time is spent writing the matrix entries out to global memory). When descriptor_dim is large, the kernel becomes compute-bound.
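A rough way to see where the crossover comes from is to compare the work done with the minimum memory traffic. The sketch below makes its assumptions explicit: every input float is read from global memory exactly once, every output float is written exactly once, and each descriptor component costs one subtract, one multiply and one add. The names pairwise_cost and arithmetic_intensity are hypothetical.

```cpp
#include <cassert>
#include <cstdint>

// Rough work/traffic model (an assumption-laden sketch): each input float is read
// from global memory exactly once and each output float is written exactly once.
struct KernelCost {
  std::int64_t bytes;  // minimum global memory traffic
  std::int64_t flops;  // subtract + multiply + add per descriptor component
};

KernelCost pairwise_cost(std::int64_t num_live, std::int64_t num_in_view,
                         std::int64_t descriptor_dim) {
  assert(num_live >= 0 && num_in_view >= 0 && descriptor_dim >= 0);
  const std::int64_t outputs = num_live * num_in_view;
  const std::int64_t input_bytes = 4 * (num_live + num_in_view) * descriptor_dim;  // 4 bytes per float
  const std::int64_t output_bytes = 4 * outputs;
  return {input_bytes + output_bytes, 3 * outputs * descriptor_dim};
}

// flops per byte; compare against the GPU's ratio of peak flops to peak bandwidth
double arithmetic_intensity(const KernelCost& c) {
  return static_cast<double>(c.flops) / static_cast<double>(c.bytes);
}
```

For the assumed sizes, this gives about 11.7 flops per byte. The 8 MiB output matrix accounts for almost all of the traffic (the inputs are only 192 KiB). Because the input traffic is so small, doubling descriptor_dim roughly doubles the intensity, which is why a large descriptor_dim pushes the kernel toward being compute-bound.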
With that in mind, when I run the original code with the assumed problem size, I get a baseline runtime of 201us on a Titan V.
When I look at the kernel code, the first thing that jumps out at me is that the data loads are not coalesced (i.e. adjacent threads within a warp are not accessing adjacent memory locations):
int row = blockIdx.y * blockDim.y + threadIdx.y;
int col = blockIdx.x * blockDim.x + threadIdx.x;
...
for (int d = 0; d < descriptor_dim; ++d) {
float diff = d_descriptors_live[row * descriptor_dim + d] -
d_descriptors_in_view[col * descriptor_dim + d];
...
}
Depending on the value of descriptor_dim, each thread within a warp could potentially be accessing its own cache line! For small values of descriptor_dim this kernel is memory-bound, which means that its performance is dictated by how efficiently it uses the available memory bandwidth, so wasted memory transactions translate directly into lost performance.
This is a common issue, and there are a few ways to address it.
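The cost of this access pattern can be made concrete with a small host-side model. The function sectors_per_warp_load below is hypothetical and is a simplification, not a hardware simulator. It counts how many 32-byte memory sectors a single warp-wide load touches. In the original kernel, lane t of a warp reads d_descriptors_in_view[(col0 + t) * descriptor_dim + d], so adjacent lanes are descriptor_dim floats apart. The ideal pattern has a stride of 1.

```cpp
#include <cassert>
#include <set>

// Simplified coalescing model (not a hardware simulator): count the distinct 32-byte
// sectors touched when lane t of a warp loads the float at index base + t * stride.
// Index 0 is treated as 32-byte aligned, and caching across loop iterations is ignored.
int sectors_per_warp_load(long long base, long long stride) {
  assert(base >= 0 && stride >= 0);
  std::set<long long> sectors;
  for (long long lane = 0; lane < 32; ++lane) {
    const long long byte_address = 4 * (base + lane * stride);  // 4 bytes per float
    sectors.insert(byte_address / 32);
  }
  return static_cast<int>(sectors.size());
}
```

With descriptor_dim = 16, sectors_per_warp_load(0, 16) returns 32, compared with 4 for a unit-stride load of 32 floats. The strided warp therefore moves 8x as much data as it actually uses, ignoring any reuse of those sectors through the L1 cache on later iterations of the d loop.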
Loads are said to be "coalesced" when the stride associated with the thread index is 1. However, in this kernel it's the descriptor index that has stride 1. So, one way to coalesce these loads is to "transpose" the layout of the input arrays:
// not coalesced (threadIdx.x has stride = descriptor_dim)
d_descriptors_in_view[col * descriptor_dim + d]
// coalesced (threadIdx.x has unit stride)
d_descriptors_in_view[d * num_in_view + col]
This approach is very simple and brings our runtime down to 52.2us, about a 4x speedup. The downside is that changing the data layout of the input arrays affects how other parts of the code access that data as well.
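The layout change itself is a plain transpose. In the sketch below, element (i, d) moves from index i * dim + d to index d * n + i, which is the layout that the coalesced indexing d_descriptors_in_view[d * num_in_view + col] expects. The function name to_descriptor_major is hypothetical, and this host-side version is for illustration only. A real application would more likely produce this layout directly or transpose on the device.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: convert a row-major [n x dim] descriptor array
// (element (i, d) at i * dim + d) into descriptor-major order (element (i, d) at d * n + i).
std::vector<float> to_descriptor_major(const std::vector<float>& row_major, int n, int dim) {
  assert(row_major.size() == static_cast<std::size_t>(n) * dim);
  std::vector<float> out(row_major.size());
  for (int i = 0; i < n; ++i) {
    for (int d = 0; d < dim; ++d) {
      out[static_cast<std::size_t>(d) * n + i] = row_major[static_cast<std::size_t>(i) * dim + d];
    }
  }
  return out;
}
```

After this conversion, the threads of a warp (consecutive col) read consecutive floats for each fixed d.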
If you don't want to change the data layout, another option is to load tiles of the input arrays into a __shared__ memory buffer and perform the transposition inside the kernel.
template < int BX, int BY, int descriptor_dim >
__global__ void pairwise_distance_kernel(const float* d_descriptors_live,
const float* d_descriptors_in_view,
float* d_distance_matrix, int num_live,
int num_in_view) {
// Global thread index for the output matrix
int row = blockIdx.y * blockDim.y + threadIdx.y;
int col = blockIdx.x * blockDim.x + threadIdx.x;
// load values into shared memory first
__shared__ float shr_live[BY * descriptor_dim];
__shared__ float shr_in_view[BX * descriptor_dim];
int tid = threadIdx.y * blockDim.x + threadIdx.x;
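for (int i = tid; i < descriptor_dim * blockDim.y; i += blockDim.x * blockDim.y) {
int global_id = blockIdx.y * blockDim.y * descriptor_dim + i;
// guard: the last block may extend past num_live
if (global_id < num_live * descriptor_dim) shr_live[i] = d_descriptors_live[global_id];
}
for (int i = tid; i < descriptor_dim * blockDim.x; i += blockDim.x * blockDim.y) {
int global_id = blockIdx.x * blockDim.x * descriptor_dim + i;
// guard: the last block may extend past num_in_view
if (global_id < num_in_view * descriptor_dim) shr_in_view[i] = d_descriptors_in_view[global_id];
}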
// wait for shared memory to be ready before accessing it
__syncthreads();
if (row < num_live && col < num_in_view) {
float current_dist_sq = 0.0f;
// Calculate squared Euclidean distance
for (int i = 0; i < descriptor_dim; ++i) {
float diff = shr_live[threadIdx.y * descriptor_dim + i] -
shr_in_view[threadIdx.x * descriptor_dim + i];
current_dist_sq += diff * diff;
}
// Store the L2 distance (with square root)
d_distance_matrix[row * num_in_view + col] = sqrtf(current_dist_sq);
}
}
In this version, we start by allocating shared memory buffers for local "tiles" of values to process. Adjacent threads within a warp access global memory with unit stride when copying the data to the shared buffer, so this transfer is coalesced.
int tid = threadIdx.y * blockDim.x + threadIdx.x;
for (int i = tid; i < descriptor_dim * blockDim.y; i += blockDim.x * blockDim.y) {
shr_live[i] = d_descriptors_live[blockIdx.y * blockDim.y * descriptor_dim + i];
}
The __syncthreads() call prevents threads from reading shared memory locations before the values have been written.
This version of the kernel takes 70.6us to finish. That is a significant speedup over the original, but it is slower than explicitly transposing the data layout. Part of the reason is that, while this version fixed the strided access to global memory, it now has strided access to shared memory! Depending on the value of descriptor_dim, this can cause bank conflicts, which reduce shared memory throughput.
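In the tiled kernel, lane t reads shr_in_view[t * descriptor_dim + i], so adjacent lanes are descriptor_dim words apart in shared memory. The small model below uses a hypothetical function, bank_conflict_degree, and assumes the usual 32 banks of 4-byte words. It computes how many distinct words the busiest bank must serve. For descriptor_dim = 16, the 32 lanes fall into only 2 banks, which gives a 16-way conflict.

```cpp
#include <algorithm>
#include <cassert>
#include <map>
#include <set>

// Simplified shared-memory bank model: 32 banks, each 4 bytes wide, so word w lives
// in bank w % 32. Lane t of a warp reads word base + t * stride. Lanes reading the
// same word are served by a broadcast; distinct words in one bank are serialized.
// Returns the worst-case number of distinct words any bank must serve (1 = no conflict).
int bank_conflict_degree(int base, int stride) {
  assert(base >= 0 && stride >= 0);
  std::map<int, std::set<int>> words_per_bank;
  for (int lane = 0; lane < 32; ++lane) {
    const int word = base + lane * stride;
    words_per_bank[word % 32].insert(word);
  }
  int degree = 1;
  for (const auto& entry : words_per_bank) {
    degree = std::max(degree, static_cast<int>(entry.second.size()));
  }
  return degree;
}
```

bank_conflict_degree(0, 16) gives 16, while a unit stride gives 1. Any odd stride, such as 17, is also conflict-free, because an odd stride maps the 32 lanes onto 32 different banks. This is why padding each row of a shared buffer by one float is a common alternative to transposing it.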
We can avoid the bank conflicts by transposing the layout of the shared memory buffer as follows:
for (int i = tid; i < descriptor_dim * blockDim.x; i += blockDim.x * blockDim.y) {
int r = i % descriptor_dim;
int c = i / descriptor_dim;
int shr_id = r * blockDim.x + c;
shr_in_view[shr_id] = d_descriptors_in_view[blockIdx.x * blockDim.x * descriptor_dim + i];
}
...
for (int i = 0; i < descriptor_dim; ++i) {
float diff = shr_live[threadIdx.y * descriptor_dim + i] -
shr_in_view[i * blockDim.x + threadIdx.x];
current_dist_sq += diff * diff;
}
This small change brings the runtime down to 45.1us, the best total speedup so far.
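The index arithmetic in the transposed fill is easy to get wrong. The host-side sketch below mirrors that fill, and the function name fill_transposed_tile is hypothetical. It can be used to check that the compute loop's read shr_in_view[i * blockDim.x + threadIdx.x] returns component i of descriptor threadIdx.x. In this layout, adjacent lanes read adjacent words, so those reads no longer conflict.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Host-side model of the transposed shared-memory fill (names are hypothetical).
// `tile` holds BX descriptors of length dim in the original row-major order
// (component r of descriptor c at c * dim + r); the result stores it at r * BX + c.
std::vector<float> fill_transposed_tile(const std::vector<float>& tile, int BX, int dim) {
  assert(tile.size() == static_cast<std::size_t>(BX) * dim);
  std::vector<float> shr(tile.size());
  for (int i = 0; i < BX * dim; ++i) {  // i plays the role of the kernel's loop index
    const int r = i % dim;  // descriptor component
    const int c = i / dim;  // which descriptor in the tile (local column)
    shr[r * BX + c] = tile[i];
  }
  return shr;
}
```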
A question for the reader: why did I transpose the layout of one of the shared memory buffers but not the other?
It's important for memory-bound kernels to saturate the memory bus as much as practical to maximize throughput. One way to do this is to increase the amount of data processed by an individual thread, so that it can issue multiple memory transactions before stalling.
With this in mind, here's a version of the kernel where each thread is responsible for writing a small rectangular chunk of the output array as opposed to just 1 entry:
template < int n >
struct alignas(n * sizeof(float)) vec {
float values[n];
__host__ __device__ float & operator[](int i) { return values[i]; }
__host__ __device__ const float & operator[](int i) const { return values[i]; }
};
template < int BX, int BY, int rows_per_thread, int cols_per_thread, int descriptor_dim >
__global__ void pairwise_distance_kernel(const float* d_descriptors_live,
const float* d_descriptors_in_view,
float* d_distance_matrix, int num_live,
int num_in_view) {
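// index of the first output row/column handled by this block
// note: this version reads the inputs in the transposed (descriptor-major) layout,
// e.g. d_descriptors_live[d * num_live + row], and has no bounds checks, so it assumes
// num_live and num_in_view are multiples of BY * rows_per_thread and BX * cols_per_thread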
int base_row = rows_per_thread * blockDim.y * blockIdx.y;
int base_col = cols_per_thread * blockDim.x * blockIdx.x;
// load values into shared memory first
__shared__ vec<rows_per_thread> shr_live[BY * descriptor_dim];
__shared__ vec<cols_per_thread> shr_in_view[BX * descriptor_dim];
int tid = threadIdx.y * blockDim.x + threadIdx.x;
float * shr_live_float = (float *)shr_live;
for (int i = tid; i < BY * rows_per_thread * descriptor_dim; i += blockDim.x * blockDim.y) {
int descrip = i % descriptor_dim;
int local_row = i / descriptor_dim;
int shr_id = descrip * (BY * rows_per_thread) + local_row;
int global_id = descrip * num_live + base_row + local_row;
shr_live_float[shr_id] = d_descriptors_live[global_id];
}
float * shr_in_view_float = (float *)shr_in_view;
for (int i = tid; i < BX * cols_per_thread * descriptor_dim; i += blockDim.x * blockDim.y) {
int descrip = i % descriptor_dim;
int local_col = i / descriptor_dim;
int shr_id = descrip * (BX * cols_per_thread) + local_col;
int global_id = descrip * num_in_view + base_col + local_col;
shr_in_view_float[shr_id] = d_descriptors_in_view[global_id];
}
// wait for shared memory to be ready before accessing it
__syncthreads();
vec<cols_per_thread> output[rows_per_thread] = {};
for (int i = 0; i < descriptor_dim; ++i) {
vec<rows_per_thread> live = shr_live[i * BY + threadIdx.y];
vec<cols_per_thread> in_view = shr_in_view[i * BX + threadIdx.x];
for (int r = 0; r < rows_per_thread; r++) {
for (int c = 0; c < cols_per_thread; c++) {
float diff = live[r] - in_view[c];
output[r][c] += diff * diff;
}
}
}
for (int r = 0; r < rows_per_thread; r++) {
for (int c = 0; c < cols_per_thread; c++) {
output[r][c] = sqrt(output[r][c]);
}
}
for (int r = 0; r < rows_per_thread; r++) {
int row = base_row + rows_per_thread * threadIdx.y + r;
int col = cols_per_thread * (blockDim.x * blockIdx.x + threadIdx.x);
*((vec<cols_per_thread> *)(d_distance_matrix + row * num_in_view + col)) = output[r];
}
}
When the compiler sees accesses to suitably aligned 8-byte or 16-byte types, it can emit vectorized load and store instructions that move several floats per instruction. Larger aligned types, like a 32-byte vec<8>, are handled with multiple such instructions on many GPUs. More information about vectorized loads can be found here. The vec class template is a C++ type that satisfies those alignment requirements, so the compiler can issue vectorized loads. Other types with appropriate alignment, like float4 or double2, also work.
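The alignment requirement can be checked on the host. The sketch below copies vec without the CUDA qualifiers so that it compiles as plain C++. It adds a hypothetical helper, is_vec_aligned, showing that a pointer can only be treated as a vec<4>* at 16-byte boundaries. The same logic applies to the kernel's output cast, (vec<cols_per_thread> *)(d_distance_matrix + row * num_in_view + col), which implicitly assumes that num_in_view is a multiple of cols_per_thread. Allocations returned by cudaMalloc are themselves at least 256-byte aligned.

```cpp
#include <cassert>
#include <cstdint>

// Host-only copy of the vec helper (the __host__ __device__ qualifiers are dropped so
// it compiles as plain C++). alignas requires n * sizeof(float) to be a power of two,
// so n must be 1, 2, 4, 8, ...
template <int n>
struct alignas(n * sizeof(float)) vec {
  float values[n];
  float& operator[](int i) { return values[i]; }
  const float& operator[](int i) const { return values[i]; }
};

static_assert(sizeof(vec<2>) == 8 && alignof(vec<2>) == 8, "vec<2> matches a 64-bit access");
static_assert(sizeof(vec<4>) == 16 && alignof(vec<4>) == 16, "vec<4> matches a 128-bit access");

// Hypothetical helper: true if p may be reinterpreted as a vec<n>* (as the kernel's store does).
template <int n>
bool is_vec_aligned(const float* p) {
  return reinterpret_cast<std::uintptr_t>(p) % alignof(vec<n>) == 0;
}
```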
An additional benefit of this implementation is that some of the input values are loaded from shared memory into registers once and then reused multiple times. Each live value is used cols_per_thread times, and each in_view value is used rows_per_thread times. Accessing shared memory is relatively fast, but registers are even faster.
This version of the kernel runs in 34.7us, about a 5.8x speedup over the original implementation.
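To put numbers on the register reuse described above, count the shared memory reads each thread performs over the descriptor loop. The post doesn't state which rows_per_thread and cols_per_thread produced the 34.7us timing, so the 4 x 4 chunk used below is purely hypothetical, as is the function name reuse_stats.

```cpp
#include <cassert>

// Shared memory reads per thread in the tiled kernel's descriptor loop, counted in floats:
// each iteration loads rows_per_thread live values and cols_per_thread in_view values
// into registers, then uses them for rows_per_thread * cols_per_thread updates.
struct ReuseStats {
  int shared_floats_read;  // per thread, whole descriptor loop
  int updates;             // diff * diff accumulations per thread
};

ReuseStats reuse_stats(int rows_per_thread, int cols_per_thread, int descriptor_dim) {
  assert(rows_per_thread > 0 && cols_per_thread > 0 && descriptor_dim >= 0);
  return {descriptor_dim * (rows_per_thread + cols_per_thread),
          descriptor_dim * rows_per_thread * cols_per_thread};
}
```

With one output per thread, as in the earlier kernels, every update needs two shared memory reads. With a hypothetical 4 x 4 chunk, each update needs only half a read, a 4x reduction in shared memory traffic.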
Here, we looked at a couple of simple optimizations that resulted in a significant speedup:
- Coalescing loads
  - by transposing the underlying data layout
  - by keeping the existing data layout and transposing inside the kernel
- Increasing the work per thread
  - reduces the number of accesses to shared memory
  - additional instruction-level parallelism increases the number of bytes in flight
If we run the problem again with descriptor_dim = 1 to make this a completely memory-bound kernel, the runtime drops to 17.4us. That is close to the optimistic theoretical estimate for my GPU, which is obtained by dividing the minimum number of bytes moved (about 8.4 MB, almost all of it the output matrix) by the peak memory bandwidth. So there is still room for further optimization, but these changes have brought us "close" to the theoretical best-case performance, at least when descriptor_dim is significantly smaller than the output matrix dimensions.
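The best-case estimate can also be written out as code: the minimum bytes moved divided by the peak DRAM bandwidth. The 652.8 GB/s figure used below is the Titan V's published peak bandwidth and is only an assumption here, because the text doesn't state the exact value its estimate was based on. The function name min_runtime_seconds is hypothetical.

```cpp
#include <cassert>
#include <cstdint>

// Optimistic lower bound on runtime for a memory-bound kernel: minimum bytes moved
// divided by peak DRAM bandwidth. Counts each input float read once and each
// output float written once (4 bytes per float).
double min_runtime_seconds(std::int64_t num_live, std::int64_t num_in_view,
                           std::int64_t descriptor_dim, double bandwidth_bytes_per_s) {
  assert(bandwidth_bytes_per_s > 0.0);
  const std::int64_t bytes =
      4 * (num_live * num_in_view + (num_live + num_in_view) * descriptor_dim);
  return static_cast<double>(bytes) / bandwidth_bytes_per_s;
}
```

Under that bandwidth assumption, descriptor_dim = 1 gives a bound of about 12.9us. The measured 17.4us would then correspond to roughly 74% of peak.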
There are also many other detailed analyses of how to optimize dense matrix-matrix multiplication on GPUs which could directly apply to this kernel as well!
Code for the kernel implementations is available here:
https://github.com/samuelpmish/cuda_kernel_optimization_examples/tree/main/pairwise_distance_kernel