53 namespace threadblock {
    68   using Shape = 
typename ThreadMap::Shape;
    81   static int const kThreads = ThreadMap::kThreads;
    84   static_assert( ThreadMap::Iterations::kRow > 0,
"ThreadMap::Iterations::kRow must be > 0");
    85   static_assert( ThreadMap::Iterations::kGroup > 0,
"ThreadMap::Iterations::kGroup must be > 0");
    86   static_assert( ThreadMap::Iterations::kCluster > 0,
"ThreadMap::Iterations::kCluster must be > 0");
    87   static_assert( ThreadMap::Iterations::kColumn > 0,
"ThreadMap::Iterations::kColumn must be > 0");
    92     ThreadMap::Iterations::kColumn * 
    93     ThreadMap::Iterations::kRow * 
    94     ThreadMap::Iterations::kGroup * 
    95     ThreadMap::Iterations::kCluster * ThreadMap::kElementsPerAccess>;
   130       increment_row = stride * ThreadMap::Delta::kRow;
   132       increment_group = stride * ThreadMap::Delta::kGroup
   133         - stride * ThreadMap::Delta::kRow * (ThreadMap::Iterations::kRow - 1);
   135       increment_cluster = stride * ThreadMap::Delta::kCluster
   136         - stride * ThreadMap::Delta::kGroup * (ThreadMap::Iterations::kGroup - 1)
   137         - stride * ThreadMap::Delta::kRow * (ThreadMap::Iterations::kRow - 1);
   139       advance_row = stride * ThreadMap::Shape::kRow;
   141       advance_group = stride * (ThreadMap::Shape::kGroup - 1) * ThreadMap::Shape::kRow * ThreadMap::Count::kRow;
   145         ThreadMap::Count::kGroup * ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * ThreadMap::Shape::kRow;;
   149         ThreadMap::Shape::kGroup * 
   150         ThreadMap::Shape::kRow * 
   151         ThreadMap::Shape::kCluster * 
   152         ThreadMap::Shape::kTile;
   172     static int const kCount = ThreadMap::Iterations::kColumn;
   175     bool predicates[kCount];
   188       for (
int i = 0; i < kCount; ++i) {
   189         predicates[i] = 
false;
   196       for (
int i = 0; i < kCount; ++i) {
   197         predicates[i] = 
true;
   212   uint8_t *byte_pointer_;
   221   Index thread_start_row_;
   249     TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx) + threadblock_offset;
   251     extent_row_ = extent.
row();
   252     thread_start_row_ = thread_offset.
row();
   256     for (
int c = 0; c < ThreadMap::Iterations::kColumn; ++c) {
   259         + ThreadMap::Delta::kColumn * c) < extent.
column());
   263     byte_pointer_ = 
reinterpret_cast<uint8_t *
>(pointer) + 
   268     state_[0] = state_[1] = state_[2] = 0;
   281     uint8_t *byte_pointer = byte_pointer_;
   285     for (
int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) {
   288       for (
int group = 0; group < ThreadMap::Iterations::kGroup; ++group) {
   291         for (
int row = 0; row < ThreadMap::Iterations::kRow; ++row) {
   294             (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster));
   296           int row_offset = row * ThreadMap::Delta::kRow 
   297             + group * ThreadMap::Delta::kGroup 
   298             + cluster * ThreadMap::Delta::kCluster;
   300           bool row_guard = ((row_offset + thread_start_row_) < extent_row_);
   305           for (
int column = 0; column < ThreadMap::Iterations::kColumn; ++column) {
   307             bool guard = row_guard && mask_.
predicates[column];
   310               frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column] = 
   315           if (row + 1 < ThreadMap::Iterations::kRow) {
   320         if (group + 1 < ThreadMap::Iterations::kGroup) {
   325       if (cluster + 1 < ThreadMap::Iterations::kCluster) {
   334     uint8_t *byte_pointer = byte_pointer_;
   338     for (
int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) {
   341       for (
int group = 0; group < ThreadMap::Iterations::kGroup; ++group) {
   344         for (
int row = 0; row < ThreadMap::Iterations::kRow; ++row) {
   347             (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster));
   349           int row_offset = row * ThreadMap::Delta::kRow 
   350             + group * ThreadMap::Delta::kGroup 
   351             + cluster * ThreadMap::Delta::kCluster;
   353           bool row_guard = ((row_offset + thread_start_row_) < extent_row_);
   358           for (
int column = 0; column < ThreadMap::Iterations::kColumn; ++column) {
   360             bool guard = row_guard && mask_.
predicates[column];
   365                 frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column];
   369           if (row + 1 < ThreadMap::Iterations::kRow) {
   374         if (group + 1 < ThreadMap::Iterations::kGroup) {
   379       if (cluster + 1 < ThreadMap::Iterations::kCluster) {
   391     thread_start_row_ += ThreadMap::Shape::kRow;
   393     if (state_[0] == ThreadMap::Count::kRow) {
   399       thread_start_row_ += (ThreadMap::Shape::kGroup - 1) * 
   400         ThreadMap::Shape::kRow * ThreadMap::Count::kRow;
   402       if (state_[1] == ThreadMap::Count::kGroup) {
   408         thread_start_row_ += ThreadMap::Count::kGroup * 
   409           ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * ThreadMap::Shape::kRow;
   411         if (state_[2] == ThreadMap::Count::kCluster) {
   456   using Element = Element_;
   466   static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;
   467   static int const kThreads = ThreadMap::kThreads;
   468   static int const kIterations = ThreadMap::Iterations::kCount;
   471   using Fragment = Array<Element, ThreadMap::kElementsPerAccess>;
   503           stride_ - ThreadMap::Iterations::kContiguous * kElementsPerAccess *
   523     static int const kCount = (ThreadMap::Iterations::kContiguous < 8)
   525                                   : ThreadMap::Iterations::kContiguous;
   528     bool predicates[kCount];
   541       for (
int i = 0; i < kCount; ++i) {
   542         predicates[i] = 
false;
   549       for (
int i = 0; i < kCount; ++i) {
   550         predicates[i] = 
true;
   565   uint8_t *byte_pointer_;
   575   Index thread_start_col_;
   578   int iteration_contiguous_;
   580   int iteration_strided_;
   604     TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx) +
   606                                  threadblock_offset.
strided() / InterleavedK);
   608     extent_col_ = extent.
strided() / InterleavedK;
   609     thread_start_col_ = thread_offset.
strided();
   613     for (
int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
   615           ((thread_offset.
contiguous() + ThreadMap::Delta::kContiguous * c) <
   620     byte_pointer_ = 
reinterpret_cast<uint8_t *
>(pointer) + 
   625     iteration_contiguous_ = iteration_strided_ = 0;
   637     uint8_t *byte_pointer = byte_pointer_;
   641     int col_offset = iteration_strided_ * ThreadMap::Delta::kStrided;
   643     bool col_guard = ((thread_start_col_ + col_offset) < extent_col_);
   645     bool guard = col_guard && mask_.
predicates[iteration_contiguous_];
   648       *frag_ptr = *memory_pointer;
   655     uint8_t *byte_pointer = byte_pointer_;
   659     int col_offset = iteration_strided_ * ThreadMap::Delta::kStrided;
   661     bool col_guard = ((thread_start_col_ + col_offset) < extent_col_);
   663     bool guard = col_guard && mask_.
predicates[iteration_contiguous_];
   666       *memory_pointer = *frag_ptr;
   673     iteration_contiguous_ = iteration % ThreadMap::Iterations::kContiguous;
   674     iteration_strided_ = iteration / ThreadMap::Iterations::kContiguous;
   681     ++iteration_contiguous_;
   684     if (iteration_contiguous_ == ThreadMap::Iterations::kContiguous) {
   686       iteration_contiguous_ = 0;
   687       ++iteration_strided_;
   690       if (iteration_strided_ == ThreadMap::Iterations::kStrided) {
   691         iteration_strided_ = 0;
 bool predicates[kCount]
Predicate state. 
Definition: epilogue/threadblock/predicated_tile_iterator.h:175
int64_t LongIndex
Long index type used for offsets. 
Definition: layout/matrix.h:62
static int const kElementsPerAccess
Definition: epilogue/threadblock/predicated_tile_iterator.h:80
CUTLASS_DEVICE void enable()
Definition: epilogue/threadblock/predicated_tile_iterator.h:194
CUTLASS_HOST_DEVICE Index const & column() const 
Returns the column of the coordinate. 
Definition: matrix_coord.h:85
Index advance_row
amount to add to move to the next 'row' position 
Definition: epilogue/threadblock/predicated_tile_iterator.h:116
CUTLASS_DEVICE void load(Fragment &frag)
Loads a fragment from memory. 
Definition: epilogue/threadblock/predicated_tile_iterator.h:279
Element_ Element
Definition: epilogue/threadblock/predicated_tile_iterator.h:70
Definition: aligned_buffer.h:35
Coordinate in pitch-linear space. 
Definition: pitch_linear.h:52
Defines a structure containing strides, bounds, and a pointer to tensor data. 
AlignedArray< Element, ThreadMap::kElementsPerAccess > AccessType
Memory access size. 
Definition: epilogue/threadblock/predicated_tile_iterator.h:98
CUTLASS_HOST_DEVICE Status initialize(Index stride_)
Definition: epilogue/threadblock/predicated_tile_iterator.h:496
CUTLASS_HOST_DEVICE void clear()
Efficiently disables all accesses guarded by mask. 
Definition: epilogue/threadblock/predicated_tile_iterator.h:539
Templates implementing how threads are mapped to a given tile. 
CUTLASS_DEVICE void get_mask(Mask &mask)
Gets the mask. 
Definition: epilogue/threadblock/predicated_tile_iterator.h:432
Array< Element, ThreadMap::kElementsPerAccess > Fragment
Fragment object. 
Definition: epilogue/threadblock/predicated_tile_iterator.h:471
ThreadMap_ ThreadMap
Definition: epilogue/threadblock/predicated_tile_iterator.h:454
Aligned array type. 
Definition: array.h:511
Mask object. 
Definition: epilogue/threadblock/predicated_tile_iterator.h:170
CUTLASS_HOST_DEVICE Index const & row() const 
Returns the row of the coordinate. 
Definition: matrix_coord.h:77
bool predicates[kCount]
Predicate state. 
Definition: epilogue/threadblock/predicated_tile_iterator.h:528
typename TensorRef::ConstTensorRef ConstTensorRef
Definition: epilogue/threadblock/predicated_tile_iterator.h:74
CUTLASS_HOST_DEVICE Mask()
Efficiently disables all accesses guarded by mask. 
Definition: epilogue/threadblock/predicated_tile_iterator.h:534
CUTLASS_HOST_DEVICE Stride stride() const 
Returns the stride of the layout. 
Definition: layout/matrix.h:418
CUTLASS_HOST_DEVICE Stride stride() const 
Returns the stride of the layout. 
Definition: layout/matrix.h:112
Array< Element, ThreadMap::Iterations::kColumn *ThreadMap::Iterations::kRow *ThreadMap::Iterations::kGroup *ThreadMap::Iterations::kCluster *ThreadMap::kElementsPerAccess > Fragment
Fragment object. 
Definition: epilogue/threadblock/predicated_tile_iterator.h:95
Definition: epilogue/threadblock/predicated_tile_iterator.h:480
CUTLASS_DEVICE void store(Fragment const &frag)
Stores a fragment to memory. 
Definition: epilogue/threadblock/predicated_tile_iterator.h:333
typename Layout::LongIndex LongIndex
Definition: epilogue/threadblock/predicated_tile_iterator.h:77
TensorRef< typename platform::remove_const< Element >::type const, Layout > ConstTensorRef
TensorRef to constant data. 
Definition: tensor_ref.h:179
CUTLASS_DEVICE void load(Fragment &frag)
Loads a fragment from memory. 
Definition: epilogue/threadblock/predicated_tile_iterator.h:636
CUTLASS_DEVICE InterleavedPredicatedTileIterator(Params const &params, Element *pointer, TensorCoord extent, int thread_idx, TensorCoord threadblock_offset)
Constructor. 
Definition: epilogue/threadblock/predicated_tile_iterator.h:596
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
typename Layout::Index Index
Definition: epilogue/threadblock/predicated_tile_iterator.h:462
int32_t Index
Index type used for coordinates. 
Definition: layout/matrix.h:59
Index advance_cluster
amount to add to move to the next 'cluster' position 
Definition: epilogue/threadblock/predicated_tile_iterator.h:118
Defines a Shape template for matrix tiles. 
CUTLASS_DEVICE void store(Fragment const &frag)
Stores a fragment to memory. 
Definition: epilogue/threadblock/predicated_tile_iterator.h:654
Defines the size of an element in bits. 
Definition: numeric_types.h:42
Index advance_row
amount to add to move to the next 'row' position 
Definition: epilogue/threadblock/predicated_tile_iterator.h:488
CUTLASS_DEVICE void clear_mask()
Efficiently disables all accesses guarded by mask. 
Definition: epilogue/threadblock/predicated_tile_iterator.h:699
ThreadMap_ ThreadMap
Definition: epilogue/threadblock/predicated_tile_iterator.h:67
Mask object. 
Definition: epilogue/threadblock/predicated_tile_iterator.h:522
CUTLASS_HOST_DEVICE Params()
Definition: epilogue/threadblock/predicated_tile_iterator.h:510
CUTLASS_DEVICE PredicatedTileIterator(Params const &params, Element *pointer, TensorCoord extent, int thread_idx, TensorCoord threadblock_offset=TensorCoord())
Constructor. 
Definition: epilogue/threadblock/predicated_tile_iterator.h:240
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
Top-level include for all CUTLASS numeric types. 
CUTLASS_HOST_DEVICE Index const & contiguous() const 
Returns the contiguous dimension. 
Definition: pitch_linear.h:89
Definition: epilogue/threadblock/predicated_tile_iterator.h:452
typename TensorRef::ConstTensorRef ConstTensorRef
Definition: epilogue/threadblock/predicated_tile_iterator.h:460
CUTLASS_HOST_DEVICE void set_iteration_index(int iteration)
Overrides the internal iteration index. 
Definition: epilogue/threadblock/predicated_tile_iterator.h:672
Index stride
stride in bytes between rows 
Definition: epilogue/threadblock/predicated_tile_iterator.h:110
CUTLASS_HOST_DEVICE PredicatedTileIterator & operator++()
Advances to the next position to load or store. 
Definition: epilogue/threadblock/predicated_tile_iterator.h:387
CUTLASS_HOST_DEVICE Params(Layout const &layout)
Definition: epilogue/threadblock/predicated_tile_iterator.h:163
Index stride
stride in bytes between columns 
Definition: epilogue/threadblock/predicated_tile_iterator.h:486
Index advance_column
amount to add to move to the next 'column' position 
Definition: epilogue/threadblock/predicated_tile_iterator.h:489
Definition: epilogue/threadblock/predicated_tile_iterator.h:104
static int const kIterations
Definition: epilogue/threadblock/predicated_tile_iterator.h:82
Index advance_tile
amount to add to move to the next 'tile' 
Definition: epilogue/threadblock/predicated_tile_iterator.h:119
Metaprogram for determining the mapping of output elements to threads for epilogue tiles...
CUTLASS_HOST_DEVICE InterleavedPredicatedTileIterator & operator++()
Advances to the next position to load or store. 
Definition: epilogue/threadblock/predicated_tile_iterator.h:679
CUTLASS_DEVICE void clear_mask()
Efficiently disables all accesses guarded by mask. 
Definition: epilogue/threadblock/predicated_tile_iterator.h:422
Index increment_group
increment quantity (in bytes) to advance when moving to the next group 
Definition: epilogue/threadblock/predicated_tile_iterator.h:113
Mapping function for row-major matrices. 
Definition: layout/matrix.h:50
typename Layout::Index Index
Definition: epilogue/threadblock/predicated_tile_iterator.h:76
Definition: epilogue/threadblock/predicated_tile_iterator.h:65
typename Layout::LongIndex LongIndex
Definition: epilogue/threadblock/predicated_tile_iterator.h:463
CUTLASS_DEVICE void set_mask(Mask const &mask)
Definition: epilogue/threadblock/predicated_tile_iterator.h:714
CUTLASS_HOST_DEVICE Params(Layout const &layout)
Definition: epilogue/threadblock/predicated_tile_iterator.h:515
CUTLASS_DEVICE void set_mask(Mask const &mask)
Definition: epilogue/threadblock/predicated_tile_iterator.h:437
CUTLASS_DEVICE void enable()
Definition: epilogue/threadblock/predicated_tile_iterator.h:547
Index advance_group
amount to add to move to the next 'group' position 
Definition: epilogue/threadblock/predicated_tile_iterator.h:117
Defines layout functions used by TensorRef and derived classes. 
CUTLASS_HOST_DEVICE void add_pointer_offset(LongIndex pointer_offset)
Adds a pointer offset in units of Element. 
Definition: epilogue/threadblock/predicated_tile_iterator.h:630
CUTLASS_HOST_DEVICE Status initialize(Index stride_)
Definition: epilogue/threadblock/predicated_tile_iterator.h:126
Operation was successful. 
CUTLASS_HOST_DEVICE void add_pointer_offset(LongIndex pointer_offset)
Adds a pointer offset in units of Element. 
Definition: epilogue/threadblock/predicated_tile_iterator.h:273
typename ThreadMap::Shape Shape
Definition: epilogue/threadblock/predicated_tile_iterator.h:68
Definition: layout/matrix.h:343
MatrixCoord TensorCoord
Definition: epilogue/threadblock/predicated_tile_iterator.h:78
CUTLASS_HOST_DEVICE Params()
Definition: epilogue/threadblock/predicated_tile_iterator.h:158
CUTLASS_DEVICE void enable_mask()
Efficiently enables all accesses guarded by mask. 
Definition: epilogue/threadblock/predicated_tile_iterator.h:427
Index increment_row
increment quantity (in bytes) to advance when moving between rows 
Definition: epilogue/threadblock/predicated_tile_iterator.h:112
Index increment_cluster
increment quantity (in bytes) to advance when moving to the next cluster 
Definition: epilogue/threadblock/predicated_tile_iterator.h:114
CUTLASS_DEVICE void get_mask(Mask &mask)
Gets the mask. 
Definition: epilogue/threadblock/predicated_tile_iterator.h:709
Basic include for CUTLASS. 
Definition: matrix_coord.h:39
CUTLASS_HOST_DEVICE void clear()
Efficiently disables all accesses guarded by mask. 
Definition: epilogue/threadblock/predicated_tile_iterator.h:186
CUTLASS_HOST_DEVICE Index const & strided() const 
Returns the strided dimension. 
Definition: pitch_linear.h:97
Status
Status code returned by CUTLASS operations. 
Definition: cutlass.h:39
CUTLASS_DEVICE void enable_mask()
Efficiently enables all accesses guarded by mask. 
Definition: epilogue/threadblock/predicated_tile_iterator.h:704
CUTLASS_HOST_DEVICE Mask()
Efficiently disables all accesses guarded by mask. 
Definition: epilogue/threadblock/predicated_tile_iterator.h:181
static int const kThreads
Definition: epilogue/threadblock/predicated_tile_iterator.h:81