39 #if defined(CUTLASS_ARCH_WMMA_ENABLED)    55 namespace threadblock {
    73     typename InstructionShape_,
    86 struct DefaultMmaCore<Shape_, WarpShape_, InstructionShape_, ElementA_,
    87                       layout::ColumnMajor, ElementB_, layout::RowMajor,
    88                       ElementC_, LayoutC_, arch::OpClassWmmaTensorOp, Stages,
    91   using WarpShape = WarpShape_;
    92   using InstructionShape = InstructionShape_;
    93   using ElementA = ElementA_;
    94   using LayoutA = layout::ColumnMajor;
    95   using ElementB = ElementB_;
    96   using LayoutB = layout::RowMajor;
    97   using ElementC = ElementC_;
    98   using LayoutC = LayoutC_;
    99   using OperatorClass = arch::OpClassWmmaTensorOp;
   102   using WarpCount = GemmShape<
   103     Shape::kM / WarpShape::kM,
   104     Shape::kN / WarpShape::kN,
   105     Shape::kK / WarpShape::kK
   110     !(Shape::kM % WarpShape::kM) &&
   111     !(Shape::kN % WarpShape::kN),
   112     "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."   119   static int const kThreads = WarpCount::kCount * kWarpSize;
   122   static int const kAccessSizeInBits = 128;
   125   using Operator = Operator_;
   131   using SmemLayoutA = LayoutA;
   132   using SmemLayoutB = LayoutB;
   143   using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap<
   144     layout::PitchLinearShape<Shape::kM, Shape::kK>,
   150   using SmemIteratorA = transform::threadblock::RegularTileIterator<
   151     MatrixShape<Shape::kM, Shape::kK>, 
   159   using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap<
   160     layout::PitchLinearShape<Shape::kN, Shape::kK>,
   166   using SmemIteratorB = transform::threadblock::RegularTileIterator<
   167     MatrixShape<Shape::kK, Shape::kN>, 
   193   using MmaTensorOp = cutlass::gemm::warp::MmaTensorOpWmma<
   205   using MmaPolicy = MmaPolicy<
   207     MatrixShape<kPaddingA, 0>,
   208     MatrixShape<0, kPaddingB>,
   230     typename InstructionShape_,
   243 struct DefaultMmaCore<Shape_, WarpShape_, InstructionShape_, ElementA_,
   244                       layout::RowMajor, ElementB_, layout::ColumnMajor,
   245                       ElementC_, LayoutC_, arch::OpClassWmmaTensorOp, Stages,
   247   using Shape = Shape_;
   248   using WarpShape = WarpShape_;
   249   using InstructionShape = InstructionShape_;
   250   using ElementA = ElementA_;
   251   using LayoutA = layout::RowMajor;
   252   using ElementB = ElementB_;
   253   using LayoutB = layout::ColumnMajor;
   254   using ElementC = ElementC_;
   255   using LayoutC = LayoutC_;
   256   using OperatorClass = arch::OpClassWmmaTensorOp;
   259   using WarpCount = GemmShape<
   260     Shape::kM / WarpShape::kM,
   261     Shape::kN / WarpShape::kN,
   262     Shape::kK / WarpShape::kK
   267     !(Shape::kM % WarpShape::kM) &&
   268     !(Shape::kN % WarpShape::kN),
   269     "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."   276   static int const kThreads = WarpCount::kCount * kWarpSize;
   280   static int const kAccessSizeInBits = 128;
   283   using Operator = Operator_;
   286   static int const kWarpThreadArrangementContiguousA =
   289   static int const kWarpThreadArrangementStridedA =
   290       kWarpSize / kWarpThreadArrangementContiguousA;
   292   static int const kWarpThreadArrangementContiguousB =
   295   static int const kWarpThreadArrangementStridedB =
   296       kWarpSize / kWarpThreadArrangementContiguousB;
   303   using SmemLayoutA = LayoutA;
   304   using SmemLayoutB = LayoutB;
   313   using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap<
   314     layout::PitchLinearShape<Shape::kK, Shape::kM>,
   320   using SmemIteratorA = transform::threadblock::RegularTileIterator<
   321     MatrixShape<Shape::kM, Shape::kK>, 
   329   using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap<
   330     layout::PitchLinearShape<Shape::kK, Shape::kN>,
   336   using SmemIteratorB = transform::threadblock::RegularTileIterator<
   337     MatrixShape<Shape::kK, Shape::kN>, 
   363   using MmaTensorOp = cutlass::gemm::warp::MmaTensorOpWmma<
   375   using MmaPolicy = MmaPolicy<
   377     MatrixShape<0, kPaddingA>,
   378     MatrixShape<kPaddingB, 0>,
   401     typename InstructionShape_,
   414 struct DefaultMmaCore<Shape_, WarpShape_, InstructionShape_, ElementA_,
   415                       layout::RowMajor, ElementB_, layout::RowMajor, ElementC_,
   416                       LayoutC_, arch::OpClassWmmaTensorOp, Stages, Operator_> {
   417   using Shape = Shape_;
   418   using WarpShape = WarpShape_;
   419   using InstructionShape = InstructionShape_;
   420   using ElementA = ElementA_;
   421   using LayoutA = layout::RowMajor;
   422   using ElementB = ElementB_;
   423   using LayoutB = layout::RowMajor;
   424   using ElementC = ElementC_;
   425   using LayoutC = LayoutC_;
   426   using OperatorClass = arch::OpClassWmmaTensorOp;
   429   using WarpCount = GemmShape<
   430     Shape::kM / WarpShape::kM,
   431     Shape::kN / WarpShape::kN,
   432     Shape::kK / WarpShape::kK
   437     !(Shape::kM % WarpShape::kM) &&
   438     !(Shape::kN % WarpShape::kN),
   439     "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."   446   static int const kThreads = WarpCount::kCount * kWarpSize;
   449   static int const kAccessSizeInBits = 128;
   452   using Operator = Operator_;
   455   static int const kWarpThreadArrangementContiguousA =
   458   static int const kWarpThreadArrangementStridedA =
   459       kWarpSize / kWarpThreadArrangementContiguousA;
   466   using SmemLayoutA = LayoutA;
   467   using SmemLayoutB = LayoutB;
   478   using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap<
   479     layout::PitchLinearShape<Shape::kK, Shape::kM>,
   486   using SmemIteratorA = transform::threadblock::RegularTileIterator<
   487     MatrixShape<Shape::kM, Shape::kK>, 
   495   using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap<
   496     layout::PitchLinearShape<Shape::kN, Shape::kK>,
   502   using SmemIteratorB = transform::threadblock::RegularTileIterator<
   503     MatrixShape<Shape::kK, Shape::kN>, 
   529   using MmaTensorOp = cutlass::gemm::warp::MmaTensorOpWmma<
   541   using MmaPolicy = MmaPolicy<
   543     MatrixShape<0, kPaddingA>,
   544     MatrixShape<0, kPaddingB>,
   565     typename InstructionShape_,
   578 struct DefaultMmaCore<Shape_, WarpShape_, InstructionShape_, ElementA_,
   579                       layout::ColumnMajor, ElementB_, layout::ColumnMajor,
   580                       ElementC_, LayoutC_, arch::OpClassWmmaTensorOp, Stages,
   582   using Shape = Shape_;
   583   using WarpShape = WarpShape_;
   584   using InstructionShape = InstructionShape_;
   585   using ElementA = ElementA_;
   586   using LayoutA = layout::ColumnMajor;
   587   using ElementB = ElementB_;
   588   using LayoutB = layout::ColumnMajor;
   589   using ElementC = ElementC_;
   590   using LayoutC = LayoutC_;
   591   using OperatorClass = arch::OpClassWmmaTensorOp;
   595       GemmShape<Shape::kM / WarpShape::kM, Shape::kN / WarpShape::kN,
   596                 Shape::kK / WarpShape::kK>;
   600       !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN),
   601       "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size.");
   607   static int const kThreads = WarpCount::kCount * kWarpSize;
   610   static int const kAccessSizeInBits = 128;
   613   using Operator = Operator_; 
   616   static int const kWarpThreadArrangementContiguousB =
   619   static int const kWarpThreadArrangementStridedB =
   620       kWarpSize / kWarpThreadArrangementContiguousB;
   627   using SmemLayoutA = LayoutA;
   628   using SmemLayoutB = LayoutB;
   639   using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap<
   640     layout::PitchLinearShape<Shape::kM, Shape::kK>,
   646   using SmemIteratorA = transform::threadblock::RegularTileIterator<
   647       MatrixShape<Shape::kM, Shape::kK>, ElementA, SmemLayoutA, 1,
   651   using IteratorThreadMapB =  transform::PitchLinearStripminedThreadMap<
   652     layout::PitchLinearShape<Shape::kK, Shape::kN>,
   658   using SmemIteratorB = transform::threadblock::RegularTileIterator<
   659       MatrixShape<Shape::kK, Shape::kN>, ElementB, SmemLayoutB, 0,
   681   using MmaTensorOp = cutlass::gemm::warp::MmaTensorOpWmma<
   693   using MmaPolicy = MmaPolicy<
   695     MatrixShape<kPaddingA, 0>,
   696     MatrixShape<kPaddingB, 0>,
   705 #endif // defined(CUTLASS_ARCH_WMMA_ENABLED) Describes the size of a matrix tile. 
Definition: matrix_shape.h:42
Templates implementing loading of tiles from pitch-linear rank=2 tensors. 
Definition: aligned_buffer.h:35
static int const value
Definition: numeric_types.h:43
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...
Defines a Shape template for matrix tiles. 
Defines basic properties needed by CTA-level GEMMs assuming expectations about data layout of the glo...
static int const value
Definition: gemm/warp/mma.h:44
Top-level include for all CUTLASS numeric types. 
Policy. 
Definition: mma_tensor_op_policy.h:48
Templates implementing warp-level matrix multiply-accumulate operations targeting Tensor Cores...
Templates exposing architecture support for warp matrix multiply-add (WMMA) operations. 
Basic include for CUTLASS. 
Policy describing implementation details of warp-level GEMM targeting Tensor Cores.