Partial emulation of ATLAS/BLAS gemm(), using caching for speedup. Matrix 'C' is assumed to have been set to the correct size (i.e. taking into account any transposes).
#include <gemm.hpp>
Static Public Member Functions

    template<typename eT >
    static arma_hot void  apply (Mat< eT > &C, const Mat< eT > &A, const Mat< eT > &B, const eT alpha=eT(1), const eT beta=eT(0))
Detailed Description

Partial emulation of ATLAS/BLAS gemm(), using caching for speedup. Matrix 'C' is assumed to have been set to the correct size (i.e. taking into account any transposes).

Definition at line 27 of file gemm.hpp.
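For orientation, here is a minimal, hypothetical usage sketch. It assumes gemm.hpp is pulled in via the main <armadillo> header and that the class is reachable from user code; in practice gemm_emul_cache is an internal helper that Armadillo's multiplication machinery dispatches to, so the direct call is for illustration only. Note that C must already have the correct size before apply() is called.

    #include <armadillo>
    #include <iostream>

    int main()
      {
      using namespace arma;

      mat A  = randu<mat>(4, 3);
      mat B  = randu<mat>(3, 5);
      mat C0 = randu<mat>(4, 5);       // initial contents of C, consumed by the beta term

      const double alpha = 2.0;
      const double beta  = 0.5;

      mat C = C0;                      // C is pre-sized to 4x5, as required

      // C = alpha*A*B + beta*C
      // template arguments: do_trans_A=false, do_trans_B=false, use_alpha=true, use_beta=true
      gemm_emul_cache<false, false, true, true>::apply(C, A, B, alpha, beta);

      // compare against the result of Armadillo's normal operators
      mat R = alpha*(A*B) + beta*C0;
      std::cout << "total abs difference: " << accu(abs(C - R)) << std::endl;

      return 0;
      }

The four bool template parameters select at compile time which branch of the listing further down is instantiated; with use_alpha or use_beta set to false, the corresponding multiply or accumulate is dropped entirely.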
Member Function Documentation

template<typename eT >
static arma_hot void gemm_emul_cache< do_trans_A, do_trans_B, use_alpha, use_beta >::apply (
        Mat< eT > &        C,
        const Mat< eT > &  A,
        const Mat< eT > &  B,
        const eT           alpha = eT(1),
        const eT           beta = eT(0)
    )   [inline, static]
Definition at line 37 of file gemm.hpp.
References Mat< eT >::at(), Mat< eT >::colptr(), Mat< eT >::n_cols, Mat< eT >::n_rows, and trans().
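The listing below gains its speed by copying the active row of the column-major matrix A into a contiguous buffer, so the innermost dot product streams over contiguous memory for both operands. A stripped-down, hypothetical illustration of the same idea, independent of Armadillo (the function name and signature are invented for this sketch):

    #include <vector>
    #include <cstddef>

    // C = A*B for column-major matrices, caching one row of A at a time.
    // A is n_rows_A x n_cols_A, B is n_cols_A x n_cols_B, C is n_rows_A x n_cols_B.
    void mul_rowcache(double* C, const double* A, const double* B,
                      std::size_t n_rows_A, std::size_t n_cols_A, std::size_t n_cols_B)
      {
      std::vector<double> row(n_cols_A);

      for(std::size_t r = 0; r < n_rows_A; ++r)
        {
        // gather: the strided reads across A happen only once per row
        for(std::size_t k = 0; k < n_cols_A; ++k)  { row[k] = A[r + k*n_rows_A]; }

        for(std::size_t c = 0; c < n_cols_B; ++c)
          {
          const double* B_col = &B[c*n_cols_A];   // column-major: column c of B is contiguous

          double acc = 0.0;
          for(std::size_t k = 0; k < n_cols_A; ++k)  { acc += row[k] * B_col[k]; }

          C[r + c*n_rows_A] = acc;
          }
        }
      }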
  {
  arma_extra_debug_sigprint();

  const u32 A_n_rows = A.n_rows;
  const u32 A_n_cols = A.n_cols;

  const u32 B_n_rows = B.n_rows;
  const u32 B_n_cols = B.n_cols;

  if( (do_trans_A == false) && (do_trans_B == false) )
    {
    arma_aligned podarray<eT> tmp(A_n_cols);
    eT* A_rowdata = tmp.memptr();

    for(u32 row_A=0; row_A < A_n_rows; ++row_A)
      {
      // copy the current row of A into contiguous memory,
      // so the inner dot product streams over contiguous data for both operands
      for(u32 col_A=0; col_A < A_n_cols; ++col_A)
        {
        A_rowdata[col_A] = A.at(row_A,col_A);
        }

      for(u32 col_B=0; col_B < B_n_cols; ++col_B)
        {
        const eT* B_coldata = B.colptr(col_B);

        eT acc = eT(0);
        for(u32 i=0; i < B_n_rows; ++i)
          {
          acc += A_rowdata[i] * B_coldata[i];
          }

        if( (use_alpha == false) && (use_beta == false) )
          {
          C.at(row_A,col_B) = acc;
          }
        else
        if( (use_alpha == true) && (use_beta == false) )
          {
          C.at(row_A,col_B) = alpha * acc;
          }
        else
        if( (use_alpha == false) && (use_beta == true) )
          {
          C.at(row_A,col_B) = acc + beta*C.at(row_A,col_B);
          }
        else
        if( (use_alpha == true) && (use_beta == true) )
          {
          C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B);
          }
        }
      }
    }
  else
  if( (do_trans_A == true) && (do_trans_B == false) )
    {
    for(u32 col_A=0; col_A < A_n_cols; ++col_A)
      {
      // col_A is interpreted as row_A when storing the results in matrix C

      const eT* A_coldata = A.colptr(col_A);

      for(u32 col_B=0; col_B < B_n_cols; ++col_B)
        {
        const eT* B_coldata = B.colptr(col_B);

        eT acc = eT(0);
        for(u32 i=0; i < B_n_rows; ++i)
          {
          acc += A_coldata[i] * B_coldata[i];
          }

        if( (use_alpha == false) && (use_beta == false) )
          {
          C.at(col_A,col_B) = acc;
          }
        else
        if( (use_alpha == true) && (use_beta == false) )
          {
          C.at(col_A,col_B) = alpha * acc;
          }
        else
        if( (use_alpha == false) && (use_beta == true) )
          {
          C.at(col_A,col_B) = acc + beta*C.at(col_A,col_B);
          }
        else
        if( (use_alpha == true) && (use_beta == true) )
          {
          C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B);
          }
        }
      }
    }
  else
  if( (do_trans_A == false) && (do_trans_B == true) )
    {
    Mat<eT> B_tmp = trans(B);
    gemm_emul_cache<false, false, use_alpha, use_beta>::apply(C, A, B_tmp, alpha, beta);
    }
  else
  if( (do_trans_A == true) && (do_trans_B == true) )
    {
    // mat B_tmp = trans(B);
    // dgemm_arma<true, false, use_alpha, use_beta>::apply(C, A, B_tmp, alpha, beta);

    // By using the trans(A)*trans(B) = trans(B*A) equivalency,
    // transpose operations are not needed

    arma_aligned podarray<eT> tmp(B.n_cols);
    eT* B_rowdata = tmp.memptr();

    for(u32 row_B=0; row_B < B_n_rows; ++row_B)
      {
      // cache the current row of B in contiguous memory
      for(u32 col_B=0; col_B < B_n_cols; ++col_B)
        {
        B_rowdata[col_B] = B.at(row_B,col_B);
        }

      for(u32 col_A=0; col_A < A_n_cols; ++col_A)
        {
        const eT* A_coldata = A.colptr(col_A);

        eT acc = eT(0);
        for(u32 i=0; i < A_n_rows; ++i)
          {
          acc += B_rowdata[i] * A_coldata[i];
          }

        if( (use_alpha == false) && (use_beta == false) )
          {
          C.at(col_A,row_B) = acc;
          }
        else
        if( (use_alpha == true) && (use_beta == false) )
          {
          C.at(col_A,row_B) = alpha * acc;
          }
        else
        if( (use_alpha == false) && (use_beta == true) )
          {
          C.at(col_A,row_B) = acc + beta*C.at(col_A,row_B);
          }
        else
        if( (use_alpha == true) && (use_beta == true) )
          {
          C.at(col_A,row_B) = alpha*acc + beta*C.at(col_A,row_B);
          }
        }
      }
    }
  }
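Regarding the final branch: for conformable matrices, trans(A)*trans(B) equals trans(B*A), so element (col_A, row_B) of the transposed product is simply the dot product of row row_B of B with column col_A of A, and no physical transpose of either operand is needed. A quick numerical check of the identity, written as ordinary Armadillo user code (not part of gemm.hpp):

    #include <armadillo>
    #include <iostream>

    int main()
      {
      using namespace arma;

      mat A = randu<mat>(3, 4);      // trans(A) is 4x3
      mat B = randu<mat>(5, 3);      // trans(B) is 3x5

      mat X = trans(A) * trans(B);   // computed with explicit transposes
      mat Y = trans(B * A);          // same result via trans(A)*trans(B) = trans(B*A)

      std::cout << "total abs difference: " << accu(abs(X - Y)) << std::endl;
      return 0;
      }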