47   typename ThreadblockSwizzle_    
    65     typename Mma::IteratorA::TensorRef 
ref_A;
    68     typename Mma::IteratorB::TensorRef 
ref_B;
    70     typename Epilogue::OutputTileIterator::Params 
params_C;
    71     typename Epilogue::OutputTileIterator::TensorRef 
ref_C;
    73     typename Epilogue::OutputTileIterator::Params 
params_D;
    74     typename Epilogue::OutputTileIterator::TensorRef 
ref_D;
    91       typename Mma::IteratorA::TensorRef ref_A_,
    93       typename Mma::IteratorB::TensorRef ref_B_,
    95       typename Epilogue::OutputTileIterator::TensorRef ref_C_,
    97       typename Epilogue::OutputTileIterator::TensorRef ref_D_,
    99       typename OutputOp::Params epilogue_,
   102       problem_size(problem_size_),
   103       grid_tiled_shape(grid_tiled_shape_),
   104       params_A(ref_A_.layout()),
   107       params_B(ref_B_.layout()),
   110       params_C(ref_C_.layout()),
   113       params_D(ref_D_.layout()),
   117       batch_count(batch_count_),
   118       gemm_k_iterations((problem_size.k() + 
Mma::Shape::kK - 1) / 
Mma::Shape::kK) {
   154     for (
int batch_idx = threadblock_swizzle.get_batch_idx(); 
   156       batch_idx += gridDim.z) {
   160         threadblock_tile_offset.
m() * Mma::Shape::kM,
   166         threadblock_tile_offset.
n() * Mma::Shape::kN
   170       int thread_idx = threadIdx.x;
   173       typename Mma::IteratorA iterator_A(
   180       iterator_A.add_pointer_offset(params.
stride_A * batch_idx);
   182       typename Mma::IteratorB iterator_B(
   189       iterator_B.add_pointer_offset(params.
stride_B * batch_idx);
   197       int warp_idx = threadIdx.x / 32;
   198       int lane_idx = threadIdx.x % 32;
   200       Mma mma(shared_storage.
main_loop, thread_idx, warp_idx, lane_idx);
   202       typename Mma::FragmentC accumulators;
   204       accumulators.clear();
   208       mma(params.
gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators);
   220       threadblock_tile_offset = threadblock_swizzle.get_tile_offset();
   224         threadblock_tile_offset.
m() * Mma::Shape::kM,
   225         threadblock_tile_offset.
n() * Mma::Shape::kN
   229       typename Epilogue::OutputTileIterator iterator_C(
   237       iterator_C.add_pointer_offset(params.
stride_C * batch_idx);
   240       typename Epilogue::OutputTileIterator iterator_D(
   248       iterator_D.add_pointer_offset(params.
stride_D * batch_idx);
   257       epilogue(output_op, iterator_D, accumulators, iterator_C);
 CUTLASS_DEVICE void operator()(Params const &params, SharedStorage &shared_storage)
Executes one GEMM. 
Definition: kernel/gemm_batched.h:138
Definition: aligned_buffer.h:35
CUTLASS_HOST_DEVICE Params()
Definition: kernel/gemm_batched.h:85
typename Epilogue::OutputOp OutputOp
Definition: kernel/gemm_batched.h:53
Epilogue::OutputTileIterator::TensorRef ref_D
Definition: kernel/gemm_batched.h:74
Definition: include/cutlass/gemm/gemm.h:94
CUTLASS_HOST_DEVICE Coord< 2 > mn() const 
Obtains a Coord<2> from GemmCoord. 
Definition: include/cutlass/gemm/gemm.h:171
Defines common types used for all GEMM-like operators. 
Mma::IteratorB::TensorRef ref_B
Definition: kernel/gemm_batched.h:68
CUTLASS_HOST_DEVICE Index const & n() const 
Returns the GEMM N coordinate. 
Definition: include/cutlass/gemm/gemm.h:137
int gemm_k_iterations
Definition: kernel/gemm_batched.h:78
Epilogue::OutputTileIterator::TensorRef ref_C
Definition: kernel/gemm_batched.h:71
CUTLASS_HOST_DEVICE GemmBatched()
Definition: kernel/gemm_batched.h:134
Epilogue_ Epilogue
Definition: kernel/gemm_batched.h:52
Shared memory storage structure. 
Definition: kernel/gemm_batched.h:124
cutlass::gemm::GemmCoord grid_tiled_shape
Definition: kernel/gemm_batched.h:63
Mma::SharedStorage main_loop
Definition: kernel/gemm_batched.h:125
static int const kThreadCount
Definition: kernel/gemm_batched.h:58
Parameters structure. 
Definition: kernel/gemm_batched.h:61
typename Mma::WarpCount WarpCount
Warp count (concept: GemmShape) 
Definition: kernel/gemm_batched.h:57
Epilogue::OutputTileIterator::Params params_D
Definition: kernel/gemm_batched.h:73
Epilogue::OutputTileIterator::Params params_C
Definition: kernel/gemm_batched.h:70
OutputOp::Params epilogue
Definition: kernel/gemm_batched.h:76
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
int64_t stride_C
Definition: kernel/gemm_batched.h:72
CUTLASS_HOST_DEVICE Coord< 2 > mk() const 
Obtains a Coord<2> from GemmCoord. 
Definition: include/cutlass/gemm/gemm.h:177
Mma_ Mma
Definition: kernel/gemm_batched.h:51
cutlass::gemm::GemmCoord problem_size
Definition: kernel/gemm_batched.h:62
Mma::IteratorA::Params params_A
Definition: kernel/gemm_batched.h:64
Defines a canonical coordinate for rank=2 matrices offering named indices. 
int batch_count
Definition: kernel/gemm_batched.h:77
Mma::IteratorB::Params params_B
Definition: kernel/gemm_batched.h:67
CUTLASS_HOST_DEVICE Coord< 2 > kn() const 
Obtains a Coord<2> from GemmCoord. 
Definition: include/cutlass/gemm/gemm.h:195
int64_t stride_B
Definition: kernel/gemm_batched.h:69
CUTLASS_HOST_DEVICE Index const & m() const 
Returns the GEMM M coordinate. 
Definition: include/cutlass/gemm/gemm.h:129
Mma::IteratorA::TensorRef ref_A
Definition: kernel/gemm_batched.h:65
CUTLASS_HOST_DEVICE Params(cutlass::gemm::GemmCoord const &problem_size_, cutlass::gemm::GemmCoord const &grid_tiled_shape_, typename Mma::IteratorA::TensorRef ref_A_, int64_t stride_A_, typename Mma::IteratorB::TensorRef ref_B_, int64_t stride_B_, typename Epilogue::OutputTileIterator::TensorRef ref_C_, int64_t stride_C_, typename Epilogue::OutputTileIterator::TensorRef ref_D_, int64_t stride_D_, typename OutputOp::Params epilogue_, int batch_count_)
Definition: kernel/gemm_batched.h:88
int64_t stride_A
Definition: kernel/gemm_batched.h:66
Definition: kernel/gemm_batched.h:49
Epilogue::SharedStorage epilogue
Definition: kernel/gemm_batched.h:126
int64_t stride_D
Definition: kernel/gemm_batched.h:75
ThreadblockSwizzle_ ThreadblockSwizzle
Definition: kernel/gemm_batched.h:54
Basic include for CUTLASS. 
Definition: matrix_coord.h:39