  ComputeType initial_accum) {

  static_assert(
    LayoutA::kRank == 2 &&
    LayoutB::kRank == 2 &&
    LayoutC::kRank == 2, "Tensors must be of rank 2");
  int const M = problem_size.m();
  int const N = problem_size.n();
  int const K = problem_size.k();
  // Blocking improves cache locality of the reference implementation.
  int const Mblock = 16;
  int const Nblock = 16;

  ConvertOp convert_op;
  InnerProductOp inner_product_op;
  for (int row_block = 0; row_block < M; row_block += Mblock) {
    for (int col_block = 0; col_block < N; col_block += Nblock) {

      ComputeType accum[Mblock][Nblock];

      // Initialize the accumulator tile.
      for (int j = 0; j < Nblock; j++) {
        for (int i = 0; i < Mblock; i++) {
          accum[i][j] = initial_accum;
        }
      }
      // Accumulate the inner product over the K dimension.
      for (int k_block = 0; k_block < K; ++k_block) {
        for (int j = 0; j < Nblock; j++) {
          for (int i = 0; i < Mblock; i++) {
            int row = row_block + i;
            int col = col_block + j;

            if (row < M && col < N) {
              ElementA a = tensor_a.at(MatrixCoord(row, k_block));
              ElementB b = tensor_b.at(MatrixCoord(k_block, col));

              accum[i][j] = inner_product_op(ComputeType(a), ComputeType(b), accum[i][j]);
            }
          }
        }
      }
      // Epilogue: scale the accumulators and the source tensor, then store to tensor_d.
      for (int j = 0; j < Nblock; j++) {
        for (int i = 0; i < Mblock; i++) {
          int row = row_block + i;
          int col = col_block + j;

          MatrixCoord coord = MatrixCoord(row, col);

          if (row < M && col < N) {
            tensor_d.at(coord) = convert_op(
              alpha * ScalarType(accum[i][j]) +
              beta * ScalarType(tensor_c.at(coord)));
          }
        }
      }
    }
  }
}

/// Overload that writes the result back into tensor_c (tensor_d aliases tensor_c).
template <typename ElementA, typename LayoutA, typename ElementB, typename LayoutB,
          typename ElementC, typename LayoutC, typename ScalarType, typename ComputeType,
          typename InnerProductOp = multiply_add<ComputeType>,
          typename ConvertOp = NumericConverter<ElementC, ScalarType>>
void compute_gemm(
    gemm::GemmCoord problem_size, ScalarType alpha,
    TensorRef<ElementA, LayoutA> tensor_a, TensorRef<ElementB, LayoutB> tensor_b,
    ScalarType beta, TensorRef<ElementC, LayoutC> tensor_c,
    ComputeType initial_accum) {

  compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
               ScalarType, ComputeType, InnerProductOp, ConvertOp>(
      problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_c,
      initial_accum);
}
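
//
// Usage sketch (illustration only, not part of this header): calling the
// reference compute_gemm on host tensors. Assumes cutlass/util/host_tensor.h;
// the problem shape, scalars, and tensor names are invented for the example.
//
inline void example_reference_compute_gemm() {
  cutlass::gemm::GemmCoord problem(128, 96, 64);   // M-by-N GEMM with K = 64

  cutlass::HostTensor<float, cutlass::layout::ColumnMajor> A(problem.mk());
  cutlass::HostTensor<float, cutlass::layout::ColumnMajor> B(problem.kn());
  cutlass::HostTensor<float, cutlass::layout::ColumnMajor> C(problem.mn());
  cutlass::HostTensor<float, cutlass::layout::ColumnMajor> D(problem.mn());

  // D = 1.5 * (A B) - 0.5 * C, with products accumulated in float.
  cutlass::reference::host::compute_gemm<
      float, cutlass::layout::ColumnMajor,
      float, cutlass::layout::ColumnMajor,
      float, cutlass::layout::ColumnMajor,
      float, float>(
      problem, 1.5f, A.host_ref(), B.host_ref(), -0.5f, C.host_ref(),
      D.host_ref(), 0.0f);
}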

/// General GEMM functor, specialized below on the inner-product operation tag.
template <typename ElementA, typename LayoutA, typename ElementB, typename LayoutB,
          typename ElementC, typename LayoutC, typename ScalarType, typename ComputeType,
          typename InnerProductOp = cutlass::arch::OpMultiplyAdd>
struct Gemm;

/// Partial specialization for multiply-add.
template <typename ElementA, typename LayoutA, typename ElementB,
          typename LayoutB, typename ElementC, typename LayoutC,
          typename ScalarType, typename ComputeType>
struct Gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ScalarType,
            ComputeType, arch::OpMultiplyAdd> {
  void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
                  TensorRef<ElementA, LayoutA> tensor_a,
                  TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
                  TensorRef<ElementC, LayoutC> tensor_c,
                  ComputeType initial_accum = ComputeType(0)) {
    static_assert(
        LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
        "Tensors must be of rank 2");

    compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
                 ScalarType, ComputeType, multiply_add<ComputeType>>(
        problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum);
  }
  void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
                  TensorRef<ElementA, LayoutA> tensor_a,
                  TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
                  TensorRef<ElementC, LayoutC> tensor_c,
                  TensorRef<ElementC, LayoutC> tensor_d,
                  ComputeType initial_accum = ComputeType(0)) {
    static_assert(
        LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
        "Tensors must be of rank 2");

    compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
                 ScalarType, ComputeType, multiply_add<ComputeType>>(
        problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum);
  }
};
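
//
// Usage sketch (illustration only, not part of this header): the functor form
// of the reference GEMM with mixed precision. The element/layout choices and
// the helper name below are invented for the example; cutlass::half_t comes
// from cutlass/numeric_types.h.
//
inline void example_reference_gemm_functor(
    cutlass::gemm::GemmCoord problem_size,
    cutlass::TensorRef<cutlass::half_t, cutlass::layout::ColumnMajor> ref_A,
    cutlass::TensorRef<cutlass::half_t, cutlass::layout::RowMajor> ref_B,
    cutlass::TensorRef<float, cutlass::layout::ColumnMajor> ref_C,
    cutlass::TensorRef<float, cutlass::layout::ColumnMajor> ref_D) {

  // Half-precision operands accumulated in float via the multiply-add path.
  cutlass::reference::host::Gemm<
      cutlass::half_t, cutlass::layout::ColumnMajor,
      cutlass::half_t, cutlass::layout::RowMajor,
      float, cutlass::layout::ColumnMajor,
      float, float> reference_gemm;

  // Writes alpha * (A B) + beta * C into tensor D, leaving C unmodified.
  reference_gemm(problem_size, 1.0f, ref_A, ref_B, 0.0f, ref_C, ref_D);
}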

/// Partial specialization for multiply-add with saturating conversion to the output element type.
template <typename ElementA, typename LayoutA, typename ElementB,
          typename LayoutB, typename ElementC, typename LayoutC,
          typename ScalarType, typename ComputeType>
struct Gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ScalarType,
            ComputeType, arch::OpMultiplyAddSaturate> {
  void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
                  TensorRef<ElementA, LayoutA> tensor_a,
                  TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
                  TensorRef<ElementC, LayoutC> tensor_c,
                  ComputeType initial_accum = ComputeType(0)) {
    static_assert(
        LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
        "Tensors must be of rank 2");

    compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
                 ScalarType, ComputeType, multiply_add<ComputeType>,
                 NumericConverterClamp<ElementC, ScalarType>>(
        problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum);
  }
  void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
                  TensorRef<ElementA, LayoutA> tensor_a,
                  TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
                  TensorRef<ElementC, LayoutC> tensor_c,
                  TensorRef<ElementC, LayoutC> tensor_d,
                  ComputeType initial_accum = ComputeType(0)) {
    static_assert(
        LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
        "Tensors must be of rank 2");

    compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
                 ScalarType, ComputeType, multiply_add<ComputeType>,
                 NumericConverterClamp<ElementC, ScalarType>>(
        problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum);
  }
};
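
//
// Usage sketch (illustration only, not part of this header): the saturating
// specialization serves as the host reference for integer GEMMs whose epilogue
// clamps to the range of the output element, e.g. int8_t operands with int32_t
// accumulation. Names below are invented for the example.
//
inline void example_reference_gemm_saturate(
    cutlass::gemm::GemmCoord problem_size,
    cutlass::TensorRef<int8_t, cutlass::layout::RowMajor> ref_A,
    cutlass::TensorRef<int8_t, cutlass::layout::ColumnMajor> ref_B,
    cutlass::TensorRef<int8_t, cutlass::layout::ColumnMajor> ref_C) {

  cutlass::reference::host::Gemm<
      int8_t, cutlass::layout::RowMajor,
      int8_t, cutlass::layout::ColumnMajor,
      int8_t, cutlass::layout::ColumnMajor,
      int32_t, int32_t,
      cutlass::arch::OpMultiplyAddSaturate> reference_gemm;

  // Accumulates in int32_t, then clamps alpha * (A B) + beta * C to int8_t,
  // overwriting tensor C in place.
  reference_gemm(problem_size, 2, ref_A, ref_B, 1, ref_C);
}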

/// Partial specialization for XOR-popc (binary GEMM).
template <typename ElementA, typename LayoutA, typename ElementB,
          typename LayoutB, typename ElementC, typename LayoutC,
          typename ScalarType, typename ComputeType>
struct Gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ScalarType,
            ComputeType, arch::OpXorPopc> {
  void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
                  TensorRef<ElementA, LayoutA> tensor_a,
                  TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
                  TensorRef<ElementC, LayoutC> tensor_c,
                  ComputeType initial_accum = ComputeType(0)) {
    static_assert(
        LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
        "Tensors must be of rank 2");

    compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
                 ScalarType, ComputeType, xor_add<ComputeType>>(
        problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum);
  }
  void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
                  TensorRef<ElementA, LayoutA> tensor_a,
                  TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
                  TensorRef<ElementC, LayoutC> tensor_c,
                  TensorRef<ElementC, LayoutC> tensor_d,
                  ComputeType initial_accum = ComputeType(0)) {
    static_assert(
        LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
        "Tensors must be of rank 2");

    compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
                 ScalarType, ComputeType, xor_add<ComputeType>>(
        problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum);
  }
};

/// Computes a batch of GEMMs over a set of matrices of common dimension.
///
/// TensorRefCollection* is any type satisfying the TensorRefCollection concept
/// (nested Element, Layout, and ConstIterator types, plus begin()).
template <
  typename TensorRefCollectionA,
  typename TensorRefCollectionB,
  typename TensorRefCollectionC,
  typename ScalarType,
  typename AccumulatorType
>
void BatchedGemm(
  gemm::GemmCoord problem_size,
  int batch_count,
  ScalarType alpha,
  TensorRefCollectionA const& tensor_a,
  TensorRefCollectionB const& tensor_b,
  ScalarType beta,
  TensorRefCollectionC &tensor_c,
  AccumulatorType initial_accum) {

  typename TensorRefCollectionA::ConstIterator tensor_a_it = tensor_a.begin();
  typename TensorRefCollectionB::ConstIterator tensor_b_it = tensor_b.begin();
  typename TensorRefCollectionC::ConstIterator tensor_c_it = tensor_c.begin();

  for (int batch = 0; batch < batch_count;
       ++batch, ++tensor_a_it, ++tensor_b_it, ++tensor_c_it) {
    // Instantiate the single-matrix reference GEMM using the element and layout
    // types carried by each collection; ScalarType and ComputeType are both
    // taken from the C element type.
    Gemm<typename TensorRefCollectionA::Element,
         typename TensorRefCollectionA::Layout,
         typename TensorRefCollectionB::Element,
         typename TensorRefCollectionB::Layout,
         typename TensorRefCollectionC::Element,
         typename TensorRefCollectionC::Layout,
         typename TensorRefCollectionC::Element,
         typename TensorRefCollectionC::Element>
        gemm;

    gemm(problem_size, alpha, *tensor_a_it, *tensor_b_it, beta, *tensor_c_it,
         initial_accum);
  }
}

/// Computes a batch of GEMMs with zero initial accumulation.
template <
  typename TensorRefCollectionA,
  typename TensorRefCollectionB,
  typename TensorRefCollectionC,
  typename ScalarType,
  typename AccumulatorType
>
void BatchedGemm(
  gemm::GemmCoord problem_size,
  int batch_count,
  ScalarType alpha,
  TensorRefCollectionA const& tensor_a,
  TensorRefCollectionB const& tensor_b,
  ScalarType beta,
  TensorRefCollectionC &tensor_c) {

  BatchedGemm(problem_size, batch_count, alpha, tensor_a, tensor_b, beta, tensor_c, ScalarType(0));
}
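
//
// Usage sketch (illustration only, not part of this header): BatchedGemm
// expects collection types exposing nested Element, Layout, and ConstIterator
// types plus begin(), with the iterator dereferencing to one batch's TensorRef.
// BatchRefs below is a hypothetical stand-in for such a collection, built on
// std::vector (requires <vector>); it is not a CUTLASS type.
//
template <typename Element_, typename Layout_>
struct BatchRefs {
  using Element = Element_;
  using Layout = Layout_;
  using ConstIterator =
      typename std::vector<cutlass::TensorRef<Element, Layout>>::const_iterator;

  std::vector<cutlass::TensorRef<Element, Layout>> refs;

  ConstIterator begin() const { return refs.begin(); }
};

inline void example_batched_reference_gemm(
    cutlass::gemm::GemmCoord problem_size,
    BatchRefs<float, cutlass::layout::ColumnMajor> const &batch_A,
    BatchRefs<float, cutlass::layout::ColumnMajor> const &batch_B,
    BatchRefs<float, cutlass::layout::ColumnMajor> &batch_C) {

  int batch_count = static_cast<int>(batch_C.refs.size());

  // One reference GEMM per batch index: C[b] = 1.0 * (A[b] B[b]) + 0.0 * C[b].
  cutlass::reference::host::BatchedGemm(
      problem_size, batch_count, 1.0f, batch_A, batch_B, 0.0f, batch_C, 0.0f);
}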