63   typename ConvertOp = NumericConverter<ElementC, ScalarType>,
    64   typename InnerProductOp = multiply_add<ComputeType>
    75   ComputeType initial_accum) {
    78     LayoutA::kRank == 2 &&
    79     LayoutB::kRank == 2 &&
    80     LayoutC::kRank == 2, 
"Tensors must be of rank 2");
    83   int const M = problem_size.
m();
    84   int const N = problem_size.
n();
    85   int const K = problem_size.
k();
    88   int const Mblock = 16;
    89   int const Nblock = 16;
    92   InnerProductOp inner_product_op;
    94   for (
int row_block = 0; row_block < M; row_block += Mblock) {
    95     for (
int col_block = 0; col_block < N; col_block += Nblock) {
    97       ComputeType accum[Mblock][Nblock];
    99       for (
int j = 0; j < Nblock; j++) {
   100         for (
int i = 0; i < Mblock; i++) {
   101           accum[i][j] = initial_accum;
   105       for (
int k_block = 0; k_block < K; ++k_block) {
   106         for (
int j = 0; j < Nblock; j++) {
   107           for (
int i = 0; i < Mblock; i++) {
   108             int row = row_block + i;
   109             int col = col_block + j;
   111             if (row < M && col < N) {
   115               ComputeType a_ik = ComputeType(a);
   116               ComputeType b_kj = ComputeType(b);
   126               accum[i][j] = inner_product_op(a_ik, b_kj,  accum[i][j]);
   132       for (
int j = 0; j < Nblock; j++) {
   133         for (
int i = 0; i < Mblock; i++) {
   134           int row = row_block + i;
   135           int col = col_block + j;
   139           if (row < M && col < N) {
   141             tensor_c.
at(coord) = convert_op(
   142               alpha * ScalarType(accum[i][j]) + 
   143               beta * ScalarType(tensor_c.
at(coord)));
   176   GemmComplex(problem_size, alpha, tensor_a, transform_a, tensor_b, transform_b, beta, tensor_c, ScalarType(0));
 Definition: aligned_buffer.h:35
ComplexTransform
Enumerated type describing a transformation on a complex value. 
Definition: complex.h:43
A Coord is a coordinate of arbitrary rank into a tensor or matrix. 
Definition: include/cutlass/gemm/gemm.h:94
Defines common types used for all GEMM-like operators. 
CUTLASS_HOST_DEVICE Index const & n() const 
Returns the GEMM N coordinate. 
Definition: include/cutlass/gemm/gemm.h:137
Defines a structure containing strides and a pointer to tensor data. 
CUTLASS_HOST_DEVICE Index const & k() const 
Returns the GEMM K coordinate. 
Definition: include/cutlass/gemm/gemm.h:145
void GemmComplex(gemm::GemmCoord problem_size, ScalarType alpha, TensorRef< ElementA, LayoutA > tensor_a, ComplexTransform transform_a, TensorRef< ElementB, LayoutB > tensor_b, ComplexTransform transform_b, ScalarType beta, TensorRef< ElementC, LayoutC > tensor_c, ComputeType initial_accum)
Definition: tools/util/include/cutlass/util/reference/host/gemm_complex.h:66
Boost-like numeric conversion operator for CUTLASS numeric types. 
CUTLASS_HOST_DEVICE complex< T > conj(complex< T > const &z)
Returns the complex conjugate. 
Definition: complex.h:356
Top-level include for all CUTLASS numeric types. 
CUTLASS_HOST_DEVICE Reference at(TensorCoord const &coord) const 
Returns a reference to the element at a given Coord. 
Definition: tensor_ref.h:307
CUTLASS_HOST_DEVICE Index const & m() const 
Returns the GEMM M coordinate. 
Definition: include/cutlass/gemm/gemm.h:129
Defines properties of matrices used to denote layout and operands to GEMM kernels. 
Definition: matrix_coord.h:39
Define basic numeric operators with specializations for Array<T, N>. SIMD-ize where possible...