52   typename ElementOutput_,                             
    54   typename ElementAccumulator_ = ElementOutput_,       
    55   typename ElementCompute_ = ElementOutput_,           
    96     ): alpha(alpha), beta(beta), alpha_ptr(
nullptr), beta_ptr(
nullptr) {
   104     ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {
   154     ComputeFragment converted_accumulator = accumulator_converter(accumulator);
   166     intermediate = mul_add_source(beta_, converted_source);                             
   167     intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate);    
   172     intermediate = max_accumulator(intermediate, -kClamp);
   173     intermediate = min_accumulator(intermediate, kClamp - 
ElementCompute(1));
   178     return destination_converter(intermediate);
   186 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 720) && (__CUDACC_VER_MAJOR__ >= 10) && (__CUDACC_VER_MINOR__ >= 2)   194   typename ElementOutput_,                             
   205   static int const kCount = Count;
   294     ComputeFragment converted_accumulator = accumulator_converter(accumulator);
   303     intermediate = mul_add_source(beta_, converted_source);                             
   304     intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate);    
   310     for (
int i = 0; i < 
kCount; ++i) {
   311       scaled_accumulator[i] = 
static_cast<int>(intermediate[i]);
   317     return destination_converter(scaled_accumulator);
   321 #endif // Conditional guards to enable partial specialization for packed integers
Fused multiply-add. 
Definition: functional.h:92
ElementCompute_ ElementCompute
Definition: linear_combination_clamp.h:63
Definition: aligned_buffer.h:35
ElementCompute beta
scales source tensor 
Definition: linear_combination_clamp.h:77
CUTLASS_HOST_DEVICE Params(ElementCompute const *alpha_ptr, ElementCompute const *beta_ptr)
Definition: linear_combination_clamp.h:101
CUTLASS_HOST_DEVICE Params(ElementCompute alpha, ElementCompute beta)
Definition: linear_combination_clamp.h:93
Definition: linear_combination_clamp.h:58
Definition: functional.h:298
Definition: functional.h:235
static int const kCount
Definition: linear_combination_clamp.h:65
CUTLASS_HOST_DEVICE FragmentOutput operator()(FragmentAccumulator const &accumulator, FragmentOutput const &source, ElementCompute uniform=ElementCompute(0)) const 
Computes linear scaling D = alpha * accumulator + beta * source, then clamps the result to [-kClamp, kClamp - 1]. 
Definition: linear_combination_clamp.h:144
CUTLASS_HOST_DEVICE Params()
Definition: linear_combination_clamp.h:86
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
Boost-like numeric conversion operator for CUTLASS numeric types. 
Defines the size of an element in bits. 
Definition: numeric_types.h:42
CUTLASS_HOST_DEVICE LinearCombinationClamp(Params const &params)
Constructs the function object, possibly loading from pointers in host memory. 
Definition: linear_combination_clamp.h:122
Definition: functional.h:64
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
Top-level include for all CUTLASS numeric types. 
Array< ElementCompute, kCount > ComputeFragment
Definition: linear_combination_clamp.h:69
ElementOutput_ ElementOutput
Definition: linear_combination_clamp.h:61
Array< ElementOutput, kCount > FragmentOutput
Definition: linear_combination_clamp.h:67
ElementAccumulator_ ElementAccumulator
Definition: linear_combination_clamp.h:62
ElementCompute const * beta_ptr
pointer to source scalar - if not null, loads it from memory 
Definition: linear_combination_clamp.h:79
CUTLASS_HOST_DEVICE void set_k_partition(int k_partition)
Functionally required for serial reduction in the epilogue. 
Definition: linear_combination_clamp.h:136
FloatRoundStyle
Definition: numeric_conversion.h:43
Conversion operator for Array. 
Definition: numeric_conversion.h:294
Host-constructable parameters structure. 
Definition: linear_combination_clamp.h:74
static FloatRoundStyle const kRound
Definition: linear_combination_clamp.h:71
CUTLASS_HOST_DEVICE bool is_source_needed() const 
Returns true if source is needed. 
Definition: linear_combination_clamp.h:130
Basic include for CUTLASS. 
ElementCompute const * alpha_ptr
pointer to accumulator scalar - if not null, loads it from memory 
Definition: linear_combination_clamp.h:78
Define basic numeric operators with specializations for Array<T, N>. SIMD-ize where possible...
ElementCompute alpha
scales accumulators 
Definition: linear_combination_clamp.h:76
Array< ElementAccumulator, kCount > FragmentAccumulator
Definition: linear_combination_clamp.h:68