49   typename ThreadblockSwizzle_,   
    69     typename Mma::IteratorA::TensorRef 
ref_A;
    71     typename Mma::IteratorB::TensorRef 
ref_B;
    72     typename Epilogue::OutputTileIterator::Params 
params_C;
    73     typename Epilogue::OutputTileIterator::TensorRef 
ref_C;
    74     typename Epilogue::OutputTileIterator::Params 
params_D;
    75     typename Epilogue::OutputTileIterator::TensorRef 
ref_D;
    92       typename Mma::IteratorA::TensorRef ref_A,
    93       typename Mma::IteratorB::TensorRef ref_B,
    94       typename Epilogue::OutputTileIterator::TensorRef ref_C,
    95       typename Epilogue::OutputTileIterator::TensorRef ref_D,
    96       typename OutputOp::Params output_op = 
typename OutputOp::Params(),
    97       int *semaphore = 
nullptr    99       problem_size(problem_size),
   100       grid_tiled_shape(grid_tiled_shape),
   101       params_A(ref_A.layout()),
   103       params_B(ref_B.layout()),
   105       params_C(ref_C.layout()),
   107       params_D(ref_D.layout()),
   109       output_op(output_op),
   110       semaphore(semaphore) {
   112       int total_gemm_k_iterations = (problem_size.
k() + Mma::Shape::kK - 1) / Mma::Shape::kK;
   113       int gemm_k_iterations = (total_gemm_k_iterations + grid_tiled_shape.
k() - 1) / grid_tiled_shape.
k();
   115       gemm_k_size = gemm_k_iterations * Mma::Shape::kK;
   135       typename Mma::IteratorA::TensorRef ref_A,
   136       typename Mma::IteratorB::TensorRef ref_B,
   137       typename Epilogue::OutputTileIterator::TensorRef ref_C,
   138       typename Epilogue::OutputTileIterator::TensorRef ref_D) {
   140     static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
   141     static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
   142     static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
   160     if ((problem_size.
m() % kAlignmentA) || (problem_size.
k() % kAlignmentA) ||
   161       (problem_size.
n() % kAlignmentB) || (problem_size.
k() % kAlignmentB) ||
   162       (problem_size.
m() % kAlignmentC) || (problem_size.
n() % kAlignmentC)) {
   188       threadblock_tile_offset.
m() * Mma::Shape::kM,
   194       threadblock_tile_offset.
n() * Mma::Shape::kN
   198     int problem_size_k = 
min(
   200       (threadblock_tile_offset.
k() + 1) * params.
gemm_k_size);
   203     int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK;
   206     int thread_idx = threadIdx.x;
   209     typename Mma::IteratorA iterator_A(
   216     typename Mma::IteratorB iterator_B(
   223     int warp_idx = threadIdx.x / 32;
   224     int lane_idx = threadIdx.x % 32;
   231     Mma mma(shared_storage.
main_loop, thread_idx, warp_idx, lane_idx);
   233     typename Mma::FragmentC accumulators;
   235     accumulators.clear();
   237     if (!kSplitKSerial || gemm_k_iterations > 0) {
   239       mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators);
   252     threadblock_tile_offset = threadblock_swizzle.get_tile_offset();
   256       threadblock_tile_offset.
m() * Mma::Shape::kM,
   257       threadblock_tile_offset.
n() * Mma::Shape::kN
   260     int block_idx = threadblock_tile_offset.
m() + threadblock_tile_offset.
n() * params.
grid_tiled_shape.
m();
   272       output_op.set_k_partition(threadblock_tile_offset.
k());
   276     typename Epilogue::OutputTileIterator iterator_C(
   285     typename Epilogue::OutputTileIterator iterator_D(
   303       if (threadblock_tile_offset.
k()) {
   304         iterator_C = iterator_D;
   307       semaphore.
wait(threadblock_tile_offset.
k());
   313     epilogue(output_op, iterator_D, accumulators, iterator_C); 
   329         lock = threadblock_tile_offset.
k() + 1;
 Epilogue::OutputTileIterator::TensorRef ref_C
Definition: include/cutlass/gemm/kernel/gemm.h:73
Definition: aligned_buffer.h:35
Epilogue::SharedStorage epilogue
Definition: include/cutlass/gemm/kernel/gemm.h:122
Epilogue::OutputTileIterator::Params params_D
Definition: include/cutlass/gemm/kernel/gemm.h:74
Mma::IteratorA::Params params_A
Definition: include/cutlass/gemm/kernel/gemm.h:68
Epilogue_ Epilogue
Definition: include/cutlass/gemm/kernel/gemm.h:55
Mma::IteratorB::Params params_B
Definition: include/cutlass/gemm/kernel/gemm.h:70
CUTLASS_HOST_DEVICE Params(cutlass::gemm::GemmCoord const &problem_size, cutlass::gemm::GemmCoord const &grid_tiled_shape, typename Mma::IteratorA::TensorRef ref_A, typename Mma::IteratorB::TensorRef ref_B, typename Epilogue::OutputTileIterator::TensorRef ref_C, typename Epilogue::OutputTileIterator::TensorRef ref_D, typename OutputOp::Params output_op=typename OutputOp::Params(), int *semaphore=nullptr)
Definition: include/cutlass/gemm/kernel/gemm.h:89
Definition: include/cutlass/gemm/gemm.h:94
CUTLASS_HOST_DEVICE Coord< 2 > mn() const 
Obtains a Coord<2> from GemmCoord. 
Definition: include/cutlass/gemm/gemm.h:171
Epilogue::OutputTileIterator::Params params_C
Definition: include/cutlass/gemm/kernel/gemm.h:72
static int const kThreadCount
Definition: include/cutlass/gemm/kernel/gemm.h:62
Defines common types used for all GEMM-like operators. 
CUTLASS_DEVICE void fetch()
Permit fetching the synchronization mechanism early. 
Definition: semaphore.h:68
CUTLASS_HOST_DEVICE Index const & n() const 
Returns the GEMM N coordinate. 
Definition: include/cutlass/gemm/gemm.h:137
cutlass::gemm::GemmCoord grid_tiled_shape
Definition: include/cutlass/gemm/kernel/gemm.h:67
int gemm_k_iterations
Definition: include/cutlass/gemm/kernel/gemm.h:78
Mma::IteratorB::TensorRef ref_B
Definition: include/cutlass/gemm/kernel/gemm.h:71
static Status can_implement(cutlass::gemm::GemmCoord const &problem_size, typename Mma::IteratorA::TensorRef ref_A, typename Mma::IteratorB::TensorRef ref_B, typename Epilogue::OutputTileIterator::TensorRef ref_C, typename Epilogue::OutputTileIterator::TensorRef ref_D)
Determines whether kernel satisfies alignment. 
Definition: include/cutlass/gemm/kernel/gemm.h:133
CUTLASS_HOST_DEVICE Gemm()
Definition: include/cutlass/gemm/kernel/gemm.h:130
CUTLASS_HOST_DEVICE Index const & k() const 
Returns the GEMM K coordinate. 
Definition: include/cutlass/gemm/gemm.h:145
static bool const kSplitKSerial
Definition: include/cutlass/gemm/kernel/gemm.h:58
typename Epilogue::OutputOp OutputOp
Definition: include/cutlass/gemm/kernel/gemm.h:56
Parameters structure. 
Definition: include/cutlass/gemm/kernel/gemm.h:65
OutputOp::Params output_op
Definition: include/cutlass/gemm/kernel/gemm.h:76
Operands fail alignment requirements. 
Shared memory storage structure. 
Definition: include/cutlass/gemm/kernel/gemm.h:120
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
int gemm_k_size
Definition: include/cutlass/gemm/kernel/gemm.h:79
int * semaphore
Definition: include/cutlass/gemm/kernel/gemm.h:77
CUTLASS_DEVICE void operator()(Params const &params, SharedStorage &shared_storage)
Executes one GEMM. 
Definition: include/cutlass/gemm/kernel/gemm.h:172
CTA-wide semaphore for inter-CTA synchronization. 
Definition: semaphore.h:48
Implementation of a CTA-wide semaphore for inter-CTA synchronization. 
Defines a canonical coordinate for rank=2 matrices offering named indices. 
CUTLASS_DEVICE void release(int status=0)
Updates the lock with the given result. 
Definition: semaphore.h:98
cutlass::gemm::GemmCoord problem_size
Definition: include/cutlass/gemm/kernel/gemm.h:66
ThreadblockSwizzle_ ThreadblockSwizzle
Definition: include/cutlass/gemm/kernel/gemm.h:57
Definition: include/cutlass/gemm/kernel/gemm.h:52
Mma::IteratorA::TensorRef ref_A
Definition: include/cutlass/gemm/kernel/gemm.h:69
bool TensorRef_aligned(TensorRef< Element, Layout > const &ref, int alignment)
Definition: tensor_ref.h:382
CUTLASS_DEVICE void wait(int status=0)
Waits until the semaphore is equal to the given value. 
Definition: semaphore.h:81
Operation was successful. 
CUTLASS_HOST_DEVICE Index const & m() const 
Returns the GEMM M coordinate. 
Definition: include/cutlass/gemm/gemm.h:129
Mma_ Mma
Definition: include/cutlass/gemm/kernel/gemm.h:54
typename Mma::WarpCount WarpCount
Warp count (concept: GemmShape) 
Definition: include/cutlass/gemm/kernel/gemm.h:61
Basic include for CUTLASS. 
Definition: matrix_coord.h:39
CUTLASS_HOST_DEVICE Params()
Definition: include/cutlass/gemm/kernel/gemm.h:86
Status
Status code returned by CUTLASS operations. 
Definition: cutlass.h:39
Mma::SharedStorage main_loop
Definition: include/cutlass/gemm/kernel/gemm.h:121
Epilogue::OutputTileIterator::TensorRef ref_D
Definition: include/cutlass/gemm/kernel/gemm.h:75