35 #if defined(CUTLASS_ARCH_WMMA_ENABLED)    81   typename Enable = 
bool    83 class MmaTensorOpWmma {
    89   using ElementA = ElementA_;
    92   using LayoutA = LayoutA_;
    95   using ElementB = ElementB_;
    98   using LayoutB = LayoutB_;
   101   using ElementC = ElementC_;
   104   using LayoutC = LayoutC_;
   107   using Policy = Policy_;
   110   using OperatorClass = arch::OpClassTensorOp;
   113   static int const kThreadCount = 32;
   116   static int const kPartitionsK = PartitionsK_;
   119   static int const kPartitionsN = PartitionsN_;
   124   using IteratorA = MmaTensorOpWmmaMultiplicandTileIterator<
   125      MatrixShape<Shape::kM, Shape::kK>, 
Operand::kA, ElementA, LayoutA,
   126      Policy::OpDelta::kRow, kThreadCount, Policy>;
   129   using FragmentA = 
typename IteratorA::Fragment;
   132   using IteratorB = MmaTensorOpWmmaMultiplicandTileIterator<
   133      MatrixShape<Shape::kK, Shape::kN>, 
Operand::kB, ElementB, LayoutB,
   134      Policy::OpDelta::kRow, kThreadCount, Policy>;
   137   using FragmentB = 
typename IteratorB::Fragment;
   140   using IteratorC = MmaTensorOpWmmaAccumulatorTileIterator<
   141      MatrixShape<Shape::kM, Shape::kN>, ElementC, LayoutC,
   142     typename Policy::OpDelta, Policy>;
   145   using FragmentC = 
typename IteratorC::Fragment;
   150     !(Shape::kM % Policy::Operator::Shape::kM) && 
   151     !(Shape::kN % Policy::Operator::Shape::kN),
   152     "Shape of warp-level Wmma must be divisible by operator shape (wmma native size)");
   155   using WmmaIterations = MatrixShape<
   156     Shape::kM / Policy::Operator::Shape::kM,
   157     (Shape::kN / Policy::Operator::Shape::kN / kPartitionsN > 0) ?
   158      Shape::kN / Policy::Operator::Shape::kN / kPartitionsN :
   165   typename Policy::Operator wmma;
   184     int const &partitionN_idx = 0)
 const {
   187     for (
int n = 0; n < WmmaIterations::kColumn; ++n) {
   189       for (
int m = 0; m < WmmaIterations::kRow; ++m) {
   192         wmma(D[m * WmmaIterations::kColumn + n], A[m], B[n], C[m * WmmaIterations::kColumn + n]);
   205 #endif // if defined(CUTLASS_ARCH_WMMA_ENABLED) Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...
Definition: aligned_buffer.h:35
Architecture-specific operators on memory added for SM75. 
Defines common types used for all GEMM-like operators. 
Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores. 
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
Templates exposing architecture support for warp-level multiply-add operations. 
Defines a Shape template for matrix tiles. 
Top-level include for all CUTLASS numeric types. 
Matrix multiply for SM75. 
Templates exposing architecture support for warp matrix multiply-add (WMMA) operations. 
Basic include for CUTLASS. 
Policy describing implementation details of warp-level GEMM targeting Tensor Cores.