1 #ifndef VIENNACL_LINALG_OPENCL_KERNELS_MATRIX_PROD_HPP
2 #define VIENNACL_LINALG_OPENCL_KERNELS_MATRIX_PROD_HPP
22 template <
typename StringType>
24 bool row_major_A,
bool row_major_B,
bool row_major_C,
25 bool transpose_A,
bool transpose_B)
28 source.append(
"__kernel void prod_");
38 source.append(
"( \n");
39 source.append(
" "); source.append(numeric_string); source.append(
" alpha, \n");
40 source.append(
" __global const "); source.append(numeric_string); source.append(
" * A, \n");
41 source.append(
" unsigned int A_row_start, \n");
42 source.append(
" unsigned int A_col_start, \n");
43 source.append(
" unsigned int A_row_inc, \n");
44 source.append(
" unsigned int A_col_inc, \n");
45 source.append(
" unsigned int A_row_size, \n");
46 source.append(
" unsigned int A_col_size, \n");
47 source.append(
" unsigned int A_internal_rows, \n");
48 source.append(
" unsigned int A_internal_cols, \n");
50 source.append(
" __global const "); source.append(numeric_string); source.append(
" * B, \n");
51 source.append(
" unsigned int B_row_start, \n");
52 source.append(
" unsigned int B_col_start, \n");
53 source.append(
" unsigned int B_row_inc, \n");
54 source.append(
" unsigned int B_col_inc, \n");
55 source.append(
" unsigned int B_row_size, \n");
56 source.append(
" unsigned int B_col_size, \n");
57 source.append(
" unsigned int B_internal_rows, \n");
58 source.append(
" unsigned int B_internal_cols, \n");
60 source.append(
" "); source.append(numeric_string); source.append(
" beta, \n");
61 source.append(
" __global "); source.append(numeric_string); source.append(
" * C, \n");
62 source.append(
" unsigned int C_row_start, \n");
63 source.append(
" unsigned int C_col_start, \n");
64 source.append(
" unsigned int C_row_inc, \n");
65 source.append(
" unsigned int C_col_inc, \n");
66 source.append(
" unsigned int C_row_size, \n");
67 source.append(
" unsigned int C_col_size, \n");
68 source.append(
" unsigned int C_internal_rows, \n");
69 source.append(
" unsigned int C_internal_cols) \n");
70 source.append(
"{ \n");
72 source.append(
" __local "); source.append(numeric_string); source.append(
" bufA[272]; \n");
73 source.append(
" __local "); source.append(numeric_string); source.append(
" bufB[272]; \n");
75 source.append(
" size_t block_size = 16; \n");
77 source.append(
" size_t row_block_id = get_group_id(0); \n");
78 source.append(
" size_t col_block_id = get_group_id(1); \n");
79 source.append(
" size_t row_thread_id = get_local_id(0); \n");
80 source.append(
" size_t col_thread_id = get_local_id(1); \n");
83 if (row_major_A && transpose_A)
85 source.append(
" size_t aBegin = (row_block_id * block_size * A_col_inc + A_col_start) + A_row_start * A_internal_cols; \n");
86 source.append(
" size_t aStep = block_size * A_row_inc * A_internal_cols; \n");
88 else if (row_major_A && !transpose_A)
90 source.append(
" size_t aBegin = (row_block_id * block_size * A_row_inc + A_row_start) * A_internal_cols + A_col_start; \n");
91 source.append(
" size_t aStep = block_size * A_col_inc; \n");
93 else if (!row_major_A && transpose_A)
95 source.append(
" size_t aBegin = (row_block_id * block_size * A_col_inc + A_col_start) * A_internal_rows + A_row_start; \n");
96 source.append(
" size_t aStep = block_size * A_row_inc; \n");
98 else if (!row_major_A && !transpose_A)
100 source.append(
" size_t aBegin = (row_block_id * block_size * A_row_inc + A_row_start) + A_col_start * A_internal_rows; \n");
101 source.append(
" size_t aStep = block_size * A_col_inc * A_internal_rows; \n");
105 if (row_major_B && transpose_B)
107 source.append(
" size_t bBegin = (col_block_id * block_size * B_row_inc + B_row_start) * B_internal_cols + B_col_start; \n");
108 source.append(
" size_t bStep = block_size * B_col_inc; \n");
110 else if (row_major_B && !transpose_B)
112 source.append(
" size_t bBegin = (col_block_id * block_size * B_col_inc + B_col_start) + B_row_start * B_internal_cols; \n");
113 source.append(
" size_t bStep = block_size * B_internal_cols * B_row_inc; \n");
115 else if (!row_major_B && transpose_B)
117 source.append(
" size_t bBegin = (col_block_id * block_size * B_row_inc + B_row_start) + B_col_start * B_internal_rows; \n");
118 source.append(
" size_t bStep = block_size * B_internal_rows * B_col_inc; \n");
120 else if (!row_major_B && !transpose_B)
122 source.append(
" size_t bBegin = (col_block_id * block_size * B_col_inc + B_col_start) * B_internal_rows + B_row_start; \n");
123 source.append(
" size_t bStep = block_size * B_row_inc; \n");
128 source.append(
" size_t block_num = (A_row_size + block_size - 1) / block_size; \n");
130 source.append(
" size_t block_num = (A_col_size + block_size - 1) / block_size; \n");
132 source.append(
" "); source.append(numeric_string); source.append(
" Csub = 0; \n");
136 source.append(
" size_t aOffset = row_thread_id * A_col_inc + col_thread_id * A_row_inc * A_internal_cols; \n");
138 source.append(
" size_t aOffset = row_thread_id * A_row_inc + col_thread_id * A_col_inc * A_internal_rows; \n");
141 source.append(
" size_t bOffset = row_thread_id * B_col_inc + col_thread_id * B_row_inc * B_internal_cols; \n");
143 source.append(
" size_t bOffset = row_thread_id * B_row_inc + col_thread_id * B_col_inc * B_internal_rows; \n");
145 source.append(
" size_t row_thread_id_times_block_size = row_thread_id * (block_size + 1); \n");
146 source.append(
" size_t col_thread_id_times_block_size = col_thread_id * (block_size + 1); \n");
148 source.append(
" for (size_t block = 0; \n");
149 source.append(
" block < block_num; \n");
150 source.append(
" ++block) \n");
151 source.append(
" { \n");
155 if (transpose_A && row_major_A)
156 source.append(
" bufA[row_thread_id_times_block_size + col_thread_id] = ((block * block_size + col_thread_id < A_row_size) && (row_block_id * block_size + row_thread_id < A_col_size)) ? A[aBegin + aOffset] : 0; \n");
157 else if (transpose_A && !row_major_A)
158 source.append(
" bufA[col_thread_id_times_block_size + row_thread_id] = ((block * block_size + row_thread_id < A_row_size) && (row_block_id * block_size + col_thread_id < A_col_size)) ? A[aBegin + aOffset] : 0; \n");
159 else if (!transpose_A && row_major_A)
160 source.append(
" bufA[col_thread_id_times_block_size + row_thread_id] = ((block * block_size + row_thread_id < A_col_size) && (row_block_id * block_size + col_thread_id < A_row_size)) ? A[aBegin + aOffset] : 0; \n");
161 else if (!transpose_A && !row_major_A)
162 source.append(
" bufA[row_thread_id_times_block_size + col_thread_id] = ((block * block_size + col_thread_id < A_col_size) && (row_block_id * block_size + row_thread_id < A_row_size)) ? A[aBegin + aOffset] : 0; \n");
165 if (transpose_B && row_major_B)
166 source.append(
" bufB[col_thread_id_times_block_size + row_thread_id] = ((block * block_size + row_thread_id < B_col_size) && (col_block_id * block_size + col_thread_id < B_row_size)) ? B[bBegin + bOffset] : 0; \n");
167 else if (transpose_B && !row_major_B)
168 source.append(
" bufB[row_thread_id_times_block_size + col_thread_id] = ((block * block_size + col_thread_id < B_col_size) && (col_block_id * block_size + row_thread_id < B_row_size)) ? B[bBegin + bOffset] : 0; \n");
169 else if (!transpose_B && row_major_B)
170 source.append(
" bufB[row_thread_id_times_block_size + col_thread_id] = ((block * block_size + col_thread_id < B_row_size) && (col_block_id * block_size + row_thread_id < B_col_size)) ? B[bBegin + bOffset] : 0; \n");
171 else if (!transpose_B && !row_major_B)
172 source.append(
" bufB[col_thread_id_times_block_size + row_thread_id] = ((block * block_size + row_thread_id < B_row_size) && (col_block_id * block_size + col_thread_id < B_col_size)) ? B[bBegin + bOffset] : 0; \n");
175 source.append(
" barrier(CLK_LOCAL_MEM_FENCE); \n");
178 source.append(
" __local "); source.append(numeric_string); source.append(
" * bufAptr = bufA + row_thread_id_times_block_size; \n");
179 source.append(
" __local "); source.append(numeric_string); source.append(
" * bufBptr = bufB + col_thread_id_times_block_size; \n");
181 for (
size_t unroll = 0; unroll < 16; ++unroll) {
182 source.append(
" Csub += (*bufAptr) * (*bufBptr); ++bufAptr; ++bufBptr; \n");
185 source.append(
" barrier(CLK_LOCAL_MEM_FENCE); \n");
186 source.append(
" aBegin += aStep; \n");
187 source.append(
" bBegin += bStep; \n");
188 source.append(
" } \n");
193 source.append(
" if (get_global_id(0) < A_col_size && ");
197 source.append(
" if (get_global_id(0) < A_row_size && ");
202 source.append(
"get_global_id(1) < B_row_size) \n");
206 source.append(
"get_global_id(1) < B_col_size) \n");
211 source.append(
" C[(get_global_id(0) * C_row_inc + C_row_start) * C_internal_cols + get_global_id(1) * C_col_inc + C_col_start] = (beta == 0) ? alpha * Csub : alpha * Csub + beta * C[(get_global_id(0) * C_row_inc + C_row_start) * C_internal_cols + get_global_id(1) * C_col_inc + C_col_start]; \n");
215 source.append(
" C[get_global_id(0) * C_row_inc + C_row_start + (get_global_id(1) * C_col_inc + C_col_start) * C_internal_rows] = (beta == 0) ? alpha * Csub : alpha * Csub + beta * C[get_global_id(0) * C_row_inc + C_row_start + (get_global_id(1) * C_col_inc + C_col_start) * C_internal_rows]; \n");
217 source.append(
"} \n");
220 template <
typename StringType>
222 bool row_major_A,
bool row_major_B,
bool row_major_C,
223 bool transpose_A,
bool transpose_B)
229 source.append(
"__kernel void prod16_");
239 source.append(
"( "); source.append(numeric_string); source.append(
" alpha, \n");
240 source.append(
" __global const "); source.append(numeric_string); source.append(
" * A, \n");
241 source.append(
" unsigned int A_row_start, \n");
242 source.append(
" unsigned int A_col_start, \n");
243 source.append(
" unsigned int A_row_inc, \n");
244 source.append(
" unsigned int A_col_inc, \n");
245 source.append(
" unsigned int A_row_size, \n");
246 source.append(
" unsigned int A_col_size, \n");
247 source.append(
" unsigned int A_internal_rows, \n");
248 source.append(
" unsigned int A_internal_cols, \n");
249 source.append(
" __global const "); source.append(numeric_string); source.append(
" * B, \n");
250 source.append(
" unsigned int B_row_start, \n");
251 source.append(
" unsigned int B_col_start, \n");
252 source.append(
" unsigned int B_row_inc, \n");
253 source.append(
" unsigned int B_col_inc, \n");
254 source.append(
" unsigned int B_row_size, \n");
255 source.append(
" unsigned int B_col_size, \n");
256 source.append(
" unsigned int B_internal_rows, \n");
257 source.append(
" unsigned int B_internal_cols, \n");
258 source.append(
" "); source.append(numeric_string); source.append(
" beta, \n");
259 source.append(
" __global "); source.append(numeric_string); source.append(
" * C, \n");
260 source.append(
" unsigned int C_row_start, \n");
261 source.append(
" unsigned int C_col_start, \n");
262 source.append(
" unsigned int C_row_inc, \n");
263 source.append(
" unsigned int C_col_inc, \n");
264 source.append(
" unsigned int C_row_size, \n");
265 source.append(
" unsigned int C_col_size, \n");
266 source.append(
" unsigned int C_internal_rows, \n");
267 source.append(
" unsigned int C_internal_cols) \n");
268 source.append(
"{ \n");
270 source.append(
" size_t row_block_id = get_group_id(1); \n");
271 source.append(
" size_t col_block_id = get_group_id(0); \n");
272 source.append(
" size_t row_thread_id = get_local_id(1); \n");
273 source.append(
" size_t col_thread_id = get_local_id(0); \n");
275 source.append(
" __local "); source.append(numeric_string); source.append(
" As[256]; \n");
277 source.append(
" "); source.append(numeric_string); source.append(
" cv[16] = {");
280 source.append(
"0}; \n");
283 if (row_major_A && transpose_A)
285 source.append(
" size_t aBegin = (row_block_id * 16 * A_col_inc + A_col_start) + A_row_start * A_internal_cols; \n");
286 source.append(
" size_t aStep = 16 * A_internal_cols * A_row_inc; \n");
287 source.append(
" size_t aEnd = aBegin + A_internal_cols * A_row_inc * A_row_size; \n");
289 else if (row_major_A && !transpose_A)
291 source.append(
" size_t aBegin = (row_block_id * 16 * A_row_inc + A_row_start) * A_internal_cols + A_col_start; \n");
292 source.append(
" size_t aStep = 16 * A_col_inc; \n");
293 source.append(
" size_t aEnd = aBegin + A_col_inc * A_col_size; \n");
295 else if (!row_major_A && transpose_A)
297 source.append(
" size_t aBegin = (row_block_id * 16 * A_col_inc + A_col_start) * A_internal_rows + A_row_start; \n");
298 source.append(
" size_t aStep = 16 * A_row_inc; \n");
299 source.append(
" size_t aEnd = aBegin + A_row_inc * A_row_size; \n");
301 else if (!row_major_A && !transpose_A)
303 source.append(
" size_t aBegin = (row_block_id * 16 * A_row_inc + A_row_start) + A_col_start * A_internal_rows; \n");
304 source.append(
" size_t aStep = 16 * A_internal_rows * A_col_inc; \n");
305 source.append(
" size_t aEnd = aBegin + A_internal_rows * A_col_inc * A_col_size; \n");
309 if (row_major_B && transpose_B)
311 source.append(
" size_t bBegin = (col_block_id * 64 * B_row_inc + B_row_start) * B_internal_cols + B_col_start; \n");
312 source.append(
" size_t bStep = 16 * B_col_inc; \n");
314 else if (row_major_B && !transpose_B)
316 source.append(
" size_t bBegin = (col_block_id * 64 * B_col_inc + B_col_start) + B_row_start * B_internal_cols; \n");
317 source.append(
" size_t bStep = 16 * B_row_inc * B_internal_cols; \n");
319 else if (!row_major_B && transpose_B)
321 source.append(
" size_t bBegin = (col_block_id * 64 * B_row_inc + B_row_start) + B_col_start * B_internal_rows; \n");
322 source.append(
" size_t bStep = 16 * B_col_inc * B_internal_rows; \n");
324 else if (!row_major_B && !transpose_B)
326 source.append(
" size_t bBegin = (col_block_id * 64 * B_col_inc + B_col_start) * B_internal_rows + B_row_start; \n");
327 source.append(
" size_t bStep = 16 * B_row_inc; \n");
330 source.append(
" for(size_t a = aBegin, b = bBegin; a < aEnd; a += aStep, b += bStep) { \n");
333 source.append(
" for(size_t i = 0; i < 4; i++) \n");
334 if (row_major_A && transpose_A)
335 source.append(
" As[ (i*4 + row_thread_id) + 16 * col_thread_id] = (A[a + A_col_inc * (i * 4 + row_thread_id) + A_internal_cols * A_row_inc * col_thread_id]);");
336 else if (row_major_A && !transpose_A)
337 source.append(
" As[ (i*4 + row_thread_id) + 16 * col_thread_id] = (A[a + A_internal_cols * A_row_inc * (i * 4 + row_thread_id) + A_col_inc * col_thread_id]);");
338 else if (!row_major_A && transpose_A)
339 source.append(
" As[ (i*4 + row_thread_id) + 16 * col_thread_id] = (A[a + A_internal_rows * A_col_inc * (i * 4 + row_thread_id) + A_row_inc * col_thread_id]);");
340 else if (!row_major_A && !transpose_A)
341 source.append(
" As[ (i*4 + row_thread_id) + 16 * col_thread_id] = (A[a + A_row_inc * (i * 4 + row_thread_id) + A_internal_rows * A_col_inc * col_thread_id]);");
343 source.append(
" barrier(CLK_LOCAL_MEM_FENCE); \n");
346 source.append(
" __local "); source.append(numeric_string); source.append(
" *ap = As; \n");
347 if (row_major_B && transpose_B)
349 source.append(
" __global const "); source.append(numeric_string); source.append(
" *bp = B + (b + (16 * row_thread_id + col_thread_id) * B_row_inc * B_internal_cols); \n");
351 else if (row_major_B && !transpose_B)
353 source.append(
" __global const "); source.append(numeric_string); source.append(
" *bp = B + (b + (16 * row_thread_id + col_thread_id) * B_col_inc); \n");
355 else if (!row_major_B && transpose_B)
357 source.append(
" __global const "); source.append(numeric_string); source.append(
" *bp = B + (b + (16 * row_thread_id + col_thread_id) * B_row_inc); \n");
359 else if (!row_major_B && !transpose_B)
361 source.append(
" __global const "); source.append(numeric_string); source.append(
" *bp = B + (b + (16 * row_thread_id + col_thread_id) * B_col_inc * B_internal_rows); \n");
365 source.append(
" for(size_t i = 0; i < 16; i++) { \n");
366 if (row_major_B && transpose_B)
368 source.append(
" "); source.append(numeric_string); source.append(
" bv = bp[i * B_col_inc]; \n");
370 else if (row_major_B && !transpose_B)
372 source.append(
" "); source.append(numeric_string); source.append(
" bv = bp[i * B_row_inc * B_internal_cols]; \n");
374 else if (!row_major_B && transpose_B)
376 source.append(
" "); source.append(numeric_string); source.append(
" bv = bp[i * B_col_inc * B_internal_rows]; \n");
378 else if (!row_major_B && !transpose_B)
380 source.append(
" "); source.append(numeric_string); source.append(
" bv = bp[i * B_row_inc]; \n");
383 source.append(
" for(size_t k = 0; k < 16; k++) \n");
384 source.append(
" cv[k] += ap[k] * bv; \n");
386 source.append(
" ap += 16; \n");
387 source.append(
" } \n");
389 source.append(
" barrier(CLK_LOCAL_MEM_FENCE); \n");
390 source.append(
" } \n");
395 source.append(
" int c = C_internal_cols * (C_row_inc * 16 * row_block_id + C_row_start) + 64 * C_col_inc * col_block_id + C_col_start \n");
396 source.append(
" + C_col_inc * (16 * row_thread_id + col_thread_id); \n");
400 source.append(
" int c = C_row_inc * 16 * row_block_id + C_row_start + (64 * C_col_inc * col_block_id + C_col_start) * C_internal_rows \n");
401 source.append(
" + C_internal_rows * C_col_inc * (16 * row_thread_id + col_thread_id); \n");
404 source.append(
" for(size_t i = 0; i < 16; i++) { \n");
408 source.append(
" C[c] = (beta == 0) ? alpha * cv[i] : alpha * cv[i] + beta * C[c]; \n");
409 source.append(
" c += C_internal_cols * C_row_inc; \n");
413 source.append(
" C[c] = (beta == 0) ? alpha * cv[i] : alpha * cv[i] + beta * C[c]; \n");
414 source.append(
" c += C_row_inc; \n");
417 source.append(
" } \n");
418 source.append(
"} \n");
430 template <
class NumericT,
typename F_A,
typename F_B,
typename F_C>
447 static std::map<cl_context, bool> init_done;
451 source.reserve(8192);
453 viennacl::ocl::append_double_precision_pragma<NumericT>(ctx, source);
456 if (numeric_string ==
"float" || numeric_string ==
"double")
471 #ifdef VIENNACL_BUILD_INFO
472 std::cout <<
"Creating program " << prog_name << std::endl;
474 ctx.add_program(source, prog_name);
475 init_done[ctx.handle().get()] =
true;
std::size_t vcl_size_t
Definition: forwards.h:58
Helper class for checking whether a matrix has a row-major layout.
Definition: forwards.h:399
Manages an OpenCL context and provides the respective convenience functions for creating buffers...
Definition: context.hpp:51
Provides OpenCL-related utilities.
void generate_matrix_prod_blas3(StringType &source, std::string const &numeric_string, bool row_major_A, bool row_major_B, bool row_major_C, bool transpose_A, bool transpose_B)
Definition: matrix_prod.hpp:23
Main kernel class for the generation of matrix-matrix product kernels C = A * B.
Definition: matrix_prod.hpp:431
const OCL_TYPE & get() const
Definition: handle.hpp:189
const viennacl::ocl::handle< cl_context > & handle() const
Returns the context handle.
Definition: context.hpp:476
Main namespace in ViennaCL. Holds all the basic types such as vector, matrix, etc. and defines operations upon them.
Definition: cpu_ram.hpp:29
void generate_matrix_prod16_blas3(StringType &source, std::string const &numeric_string, bool row_major_A, bool row_major_B, bool row_major_C, bool transpose_A, bool transpose_B)
Definition: matrix_prod.hpp:221
static void apply(viennacl::ocl::context const &)
Definition: utils.hpp:40
static void init(viennacl::ocl::context &ctx)
Definition: matrix_prod.hpp:438
Representation of an OpenCL kernel in ViennaCL.
std::string type_to_string(viennacl::row_major)
Definition: matrix.hpp:868
Helper class for converting a type to its string representation.
Definition: utils.hpp:57
static std::string program_name()
Definition: matrix_prod.hpp:433
Runtime generation of OpenCL kernels for matrix operations.