
CUDA lambda expressions lose restrict information

So here is a very naive matrix multiplication implementation, where I tried to use a lambda expression to define the kernel so that I can avoid passing the arguments explicitly:

// allocate device memory and yield the new pointer (GNU statement expression; no error checking)
#define cuda_malloc(size) ({ void *_x; cudaMalloc(&_x, size); _x; })

// generic launcher: run an arbitrary callable on the device
template <typename F> __global__ void exec_kern(F f) { f(); }

void matmul(int *host_a, int *host_b, int *host_c) {
  int *__restrict__ a = (int *)cuda_malloc(2048 * 2048 * sizeof(int));
  int *__restrict__ b = (int *)cuda_malloc(2048 * 2048 * sizeof(int));
  int *__restrict__ c = (int *)cuda_malloc(2048 * 2048 * sizeof(int));
  cudaMemcpy(a, host_a, 2048 * 2048 * sizeof(int), cudaMemcpyHostToDevice);
  cudaMemcpy(b, host_b, 2048 * 2048 * sizeof(int), cudaMemcpyHostToDevice);
  cudaMemset(c, 0, 2048 * 2048 * sizeof(int));  // the kernel accumulates into c, so it must start zeroed
  {
    auto _kern = [=] __device__ {
      int i0 = blockIdx.x;
      int i1 = threadIdx.x;
      for (int i2 = 0; i2 <= 255; i2 += 1) {
        for (int i3 = 0; i3 <= 7; i3 += 1) {
          for (int i4 = 0; i4 <= 7; i4 += 1) {
            for (int i5 = 0; i5 <= 7; i5 += 1) {
              c[((8 * i0) + i3) * 2048 + ((8 * i1) + i4)] =
                  ((a[((8 * i0) + i3) * 2048 + ((8 * i2) + i5)] *
                    b[((8 * i2) + i5) * 2048 + ((8 * i1) + i4)]) +
                   c[((8 * i0) + i3) * 2048 + ((8 * i1) + i4)]);
            }
          }
        }
      }
    };
    exec_kern<<<dim3(256, 1, 1), dim3(256, 1, 1)>>>(_kern);
  }
  cudaMemcpy(host_c, c, 2048 * 2048 * sizeof(int), cudaMemcpyDeviceToHost);
  cudaFree(a);
  cudaFree(b);
  cudaFree(c);
}

I annotated the global memory pointers a, b and c with __restrict__, in the hope that the compiler could use this no-aliasing information to optimize the generated code.
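
To illustrate what I was hoping for, here is a made-up mini-kernel (not part of my program, the names are only for illustration): with the no-aliasing guarantee the compiler can hoist a[i] into a register, keep the running value in a register, and sink the store to c[i] out of the loop; without __restrict__, every store to c[i] might alias a or b, so the loads have to be repeated and the store has to happen on every iteration.

// illustrative only, not the kernel from my program
__global__ void accumulate(const int *__restrict__ a,
                           const int *__restrict__ b,
                           int *__restrict__ c, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= n) return;
  for (int k = 0; k < n; ++k)
    c[i] += a[i] * b[k];  // with __restrict__, a[i] and the sum can stay in
                          // registers and c[i] can be written back once;
                          // without it, the store forces reloads of a and b
}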

I tested the code on CUDA 10.2; it takes about 0.7 s to run. However, when I switch to the normal way of defining a kernel, like this:

__global__ void kern(int *__restrict__ a, int *__restrict__ b, int *__restrict__ c) {
  int i0 = blockIdx.x;
  int i1 = threadIdx.x;
  for (int i2 = 0; i2 <= 255; i2 += 1) {
    for (int i3 = 0; i3 <= 7; i3 += 1) {
      for (int i4 = 0; i4 <= 7; i4 += 1) {
        for (int i5 = 0; i5 <= 7; i5 += 1) {
          c[((8 * i0) + i3) * 2048 + ((8 * i1) + i4)] =
              ((a[((8 * i0) + i3) * 2048 + ((8 * i2) + i5)] *
                b[((8 * i2) + i5) * 2048 + ((8 * i1) + i4)]) +
               c[((8 * i0) + i3) * 2048 + ((8 * i1) + i4)]);
        }
      }
    }
  }
}
...
kern<<<dim3(256, 1, 1), dim3(256, 1, 1)>>>(a, b, c);

it only takes about 0.4 s. Furthermore, I tried manually re-annotating the lambda-captured pointers as __restrict__ inside the lambda, like this:

auto _kern = [=] __device__ {
  int *__restrict__ _a = a;
  int *__restrict__ _b = b;
  int *__restrict__ _c = c;
  int i0 = blockIdx.x;
  int i1 = threadIdx.x;
  for (int i2 = 0; i2 <= 255; i2 += 1) {
    for (int i3 = 0; i3 <= 7; i3 += 1) {
      for (int i4 = 0; i4 <= 7; i4 += 1) {
        for (int i5 = 0; i5 <= 7; i5 += 1) {
          _c[((8 * i0) + i3) * 2048 + ((8 * i1) + i4)] =
              ((_a[((8 * i0) + i3) * 2048 + ((8 * i2) + i5)] *
                _b[((8 * i2) + i5) * 2048 + ((8 * i1) + i4)]) +
               _c[((8 * i0) + i3) * 2048 + ((8 * i1) + i4)]);
        }
      }
    }
  }
};

and it also takes about 0.4 s. This seems to confirm my guess: the first version is slower because the __restrict__ qualification is lost when the pointers are captured by the lambda.

So my question is: is there any way to keep this information without having to manually re-annotate the lambda-captured pointers?
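
For completeness, the only other workaround I can think of (untested, just a sketch; exec_kern_ptrs and the parameter names are made up) is to pass the pointers to the lambda as __restrict__-qualified parameters instead of capturing them, which keeps the qualification but brings back exactly the argument passing I wanted to avoid:

// untested sketch: forward the pointers as restrict-qualified parameters
template <typename F>
__global__ void exec_kern_ptrs(F f, int *__restrict__ a,
                               int *__restrict__ b, int *__restrict__ c) {
  f(a, b, c);
}
...
auto _kern = [] __device__ (int *__restrict__ a, int *__restrict__ b,
                            int *__restrict__ c) {
  // same loop nest as above, accumulating into c
};
exec_kern_ptrs<<<dim3(256, 1, 1), dim3(256, 1, 1)>>>(_kern, a, b, c);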



1 Reply

Waiting for an expert to answer this one.
