33 #ifndef EIGEN_TRIANGULAR_MATRIX_MATRIX_MKL_H
34 #define EIGEN_TRIANGULAR_MATRIX_MATRIX_MKL_H
// Primary template: by default a triangular * general product simply reuses
// Eigen's built-in kernel (product_triangular_matrix_matrix<..., BuiltIn>).
// Partial specializations generated by the EIGEN_MKL_TRMM_L / EIGEN_MKL_TRMM_R
// macros (LhsIsTriangular = true / false, column-major result) replace this
// with MKL ?trmm, a general matrix product, or the built-in kernel depending
// on the operand shape.
// NOTE(review): ResStorageOrder is used in the base-class argument list but is
// not declared in the visible template parameter list, which ends with a
// trailing comma -- a parameter line appears to be missing; verify.
41 template <
typename Scalar,
typename Index,
42 int Mode,
bool LhsIsTriangular,
43 int LhsStorageOrder,
bool ConjugateLhs,
44 int RhsStorageOrder,
bool ConjugateRhs,
46 struct product_triangular_matrix_matrix_trmm :
47 product_triangular_matrix_matrix<Scalar,Index,Mode,
48 LhsIsTriangular,LhsStorageOrder,ConjugateLhs,
49 RhsStorageOrder, ConjugateRhs, ResStorageOrder, BuiltIn> {};
// EIGEN_MKL_TRMM_SPECIALIZE(Scalar, LhsIsTriangular)
// Partially specializes product_triangular_matrix_matrix for one scalar type
// and one triangular side, restricted to a column-major result and the
// `Specialized` version tag. Its run() forwards every argument unchanged to
// product_triangular_matrix_matrix_trmm<..., ColMajor>::run, which holds the
// MKL-specific dispatch logic.
// NOTE(review): the closing braces of run() and of the struct are not visible
// in this excerpt (the last visible macro line still ends in a backslash);
// verify against the complete file.
// The invocations below instantiate it for double, dcomplex, float and
// scomplex, each with a triangular lhs (true) and a triangular rhs (false).
53 #define EIGEN_MKL_TRMM_SPECIALIZE(Scalar, LhsIsTriangular) \
54 template <typename Index, int Mode, \
55 int LhsStorageOrder, bool ConjugateLhs, \
56 int RhsStorageOrder, bool ConjugateRhs> \
57 struct product_triangular_matrix_matrix<Scalar,Index, Mode, LhsIsTriangular, \
58 LhsStorageOrder,ConjugateLhs, RhsStorageOrder,ConjugateRhs,ColMajor,Specialized> { \
59 static inline void run(Index _rows, Index _cols, Index _depth, const Scalar* _lhs, Index lhsStride,\
60 const Scalar* _rhs, Index rhsStride, Scalar* res, Index resStride, Scalar alpha, level3_blocking<Scalar,Scalar>& blocking) { \
61 product_triangular_matrix_matrix_trmm<Scalar,Index,Mode, \
62 LhsIsTriangular,LhsStorageOrder,ConjugateLhs, \
63 RhsStorageOrder, ConjugateRhs, ColMajor>::run( \
64 _rows, _cols, _depth, _lhs, lhsStride, _rhs, rhsStride, res, resStride, alpha, blocking); \
68 EIGEN_MKL_TRMM_SPECIALIZE(
double,
true)
69 EIGEN_MKL_TRMM_SPECIALIZE(
double, false)
70 EIGEN_MKL_TRMM_SPECIALIZE(dcomplex, true)
71 EIGEN_MKL_TRMM_SPECIALIZE(dcomplex, false)
72 EIGEN_MKL_TRMM_SPECIALIZE(
float, true)
73 EIGEN_MKL_TRMM_SPECIALIZE(
float, false)
74 EIGEN_MKL_TRMM_SPECIALIZE(scomplex, true)
75 EIGEN_MKL_TRMM_SPECIALIZE(scomplex, false)
// EIGEN_MKL_TRMM_L(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX)
// Specializes product_triangular_matrix_matrix_trmm for a triangular LEFT
// operand (LhsIsTriangular = true) with a column-major result.
// Macro parameters:
//   EIGTYPE   - Eigen scalar type (double, dcomplex, float, scomplex)
//   MKLTYPE   - matching MKL scalar type used in the ?trmm call
//   EIGPREFIX - suffix of the Eigen MatrixX* typedef used for temporaries
//               (d, cd, f, cf)
//   MKLPREFIX - BLAS routine prefix pasted into MKLPREFIX##trmm (d, z, s, c)
// Strategy (as far as visible in this excerpt):
//   * Only the square diagSize x diagSize part, diagSize = min(_rows,_depth),
//     can be handed to ?trmm. When the triangular operand is not square
//     (rows != depth):
//       - if MKL's BLAS domain runs single-threaded and the non-square excess
//         is below 50% of diagSize, Eigen's built-in kernel is used;
//       - presumably otherwise, the triangular view is copied into a dense
//         temporary and general_matrix_matrix_product is called on it.
//   * In the square case ?trmm is called with side='L' on a copy of the
//     (possibly conjugated) rhs, because ?trmm overwrites B in place. alpha is
//     applied by ?trmm, and the resulting temporary is added into res.
//   * diag='N' is always passed; ZeroDiag / UnitDiag modes are applied by
//     writing the diagonal of a dense copy of the lhs.
//   * A row-major lhs is handled via transa ('T', or 'C' when conjugated)
//     together with a flipped uplo.
// NOTE(review): several lines of this macro (braces, the enum / run() headers,
// else-branches, declarations of alpha_, n, a_tmp, a and b) are missing from
// this excerpt; verify the control flow against the complete file.
78 #define EIGEN_MKL_TRMM_L(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX) \
79 template <typename Index, int Mode, \
80 int LhsStorageOrder, bool ConjugateLhs, \
81 int RhsStorageOrder, bool ConjugateRhs> \
82 struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,true, \
83 LhsStorageOrder,ConjugateLhs,RhsStorageOrder,ConjugateRhs,ColMajor> \
86 IsLower = (Mode&Lower) == Lower, \
87 SetDiag = (Mode&(ZeroDiag|UnitDiag)) ? 0 : 1, \
88 IsUnitDiag = (Mode&UnitDiag) ? 1 : 0, \
89 IsZeroDiag = (Mode&ZeroDiag) ? 1 : 0, \
90 LowUp = IsLower ? Lower : Upper, \
91 conjA = ((LhsStorageOrder==ColMajor) && ConjugateLhs) ? 1 : 0 /* row-major conjugation is expressed via transa='C' instead */ \
95 Index _rows, Index _cols, Index _depth, \
96 const EIGTYPE* _lhs, Index lhsStride, \
97 const EIGTYPE* _rhs, Index rhsStride, \
98 EIGTYPE* res, Index resStride, \
99 EIGTYPE alpha, level3_blocking<EIGTYPE,EIGTYPE>& blocking) \
101 Index diagSize = (std::min)(_rows,_depth); /* side length of the square triangular part */ \
102 Index rows = IsLower ? _rows : diagSize; \
103 Index depth = IsLower ? diagSize : _depth; \
104 Index cols = _cols; \
106 typedef Matrix<EIGTYPE, Dynamic, Dynamic, LhsStorageOrder> MatrixLhs; \
107 typedef Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder> MatrixRhs; \
110 if (rows != depth) { \
112 int nthr = mkl_domain_get_max_threads(EIGEN_MKL_DOMAIN_BLAS); \
114 if (((nthr==1) && (((std::max)(rows,depth)-diagSize)/(double)diagSize < 0.5))) { /* single-threaded and nearly square: use Eigen's built-in kernel */ \
116 product_triangular_matrix_matrix<EIGTYPE,Index,Mode,true, \
117 LhsStorageOrder,ConjugateLhs, RhsStorageOrder, ConjugateRhs, ColMajor, BuiltIn>::run( \
118 _rows, _cols, _depth, _lhs, lhsStride, _rhs, rhsStride, res, resStride, alpha, blocking); \
122 Map<const MatrixLhs, 0, OuterStride<> > lhsMap(_lhs,rows,depth,OuterStride<>(lhsStride)); \
123 MatrixLhs aa_tmp=lhsMap.template triangularView<Mode>(); /* dense copy of the triangular part so a general product can be used */ \
124 MKL_INT aStride = aa_tmp.outerStride(); \
125 gemm_blocking_space<ColMajor,EIGTYPE,EIGTYPE,Dynamic,Dynamic,Dynamic> gemm_blocking(_rows,_cols,_depth); \
126 general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor>::run( \
127 rows, cols, depth, aa_tmp.data(), aStride, _rhs, rhsStride, res, resStride, alpha, gemm_blocking, 0); \
133 char side = 'L', transa, uplo, diag = 'N'; \
136 MKL_INT m, n, lda, ldb; \
140 assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(alpha_, alpha); \
143 m = (MKL_INT)diagSize; \
147 transa = (LhsStorageOrder==RowMajor) ? ((ConjugateLhs) ? 'C' : 'T') : 'N'; \
150 Map<const MatrixRhs, 0, OuterStride<> > rhs(_rhs,depth,cols,OuterStride<>(rhsStride)); \
151 MatrixX##EIGPREFIX b_tmp; \
153 if (ConjugateRhs) b_tmp = rhs.conjugate(); else b_tmp = rhs; /* ?trmm overwrites B in place, so work on a copy */ \
155 ldb = b_tmp.outerStride(); \
158 uplo = IsLower ? 'L' : 'U'; \
159 if (LhsStorageOrder==RowMajor) uplo = (uplo == 'L') ? 'U' : 'L'; /* BLAS sees a row-major matrix as its transpose */ \
161 Map<const MatrixLhs, 0, OuterStride<> > lhs(_lhs,rows,depth,OuterStride<>(lhsStride)); \
164 if ((conjA!=0) || (SetDiag==0)) { \
165 if (conjA) a_tmp = lhs.conjugate(); else a_tmp = lhs; \
167 a_tmp.diagonal().setZero(); \
168 else if (IsUnitDiag) \
169 a_tmp.diagonal().setOnes(); /* diag is always 'N', so the unit diagonal is materialized explicitly */\
171 lda = a_tmp.outerStride(); \
178 MKLPREFIX##trmm(&side, &uplo, &transa, &diag, &m, &n, &alpha_, (const MKLTYPE*)a, &lda, (MKLTYPE*)b, &ldb); \
181 Map<MatrixX##EIGPREFIX, 0, OuterStride<> > res_tmp(res,rows,cols,OuterStride<>(resStride)); \
182 res_tmp=res_tmp+b_tmp; /* accumulate; alpha was already applied by ?trmm */ \
186 EIGEN_MKL_TRMM_L(
double,
double, d, d)
187 EIGEN_MKL_TRMM_L(dcomplex, MKL_Complex16, cd, z)
188 EIGEN_MKL_TRMM_L(
float,
float, f, s)
189 EIGEN_MKL_TRMM_L(scomplex, MKL_Complex8, cf, c)
// EIGEN_MKL_TRMM_R(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX)
// Specializes product_triangular_matrix_matrix_trmm for a triangular RIGHT
// operand (LhsIsTriangular = false) with a column-major result.
// Macro parameters:
//   EIGTYPE   - Eigen scalar type (double, dcomplex, float, scomplex)
//   MKLTYPE   - matching MKL scalar type used in the ?trmm call
//   EIGPREFIX - suffix of the Eigen MatrixX* typedef used for temporaries
//               (d, cd, f, cf)
//   MKLPREFIX - BLAS routine prefix pasted into MKLPREFIX##trmm (d, z, s, c)
// Strategy (as far as visible in this excerpt):
//   * Only the square diagSize x diagSize part, diagSize = min(_cols,_depth),
//     can be handed to ?trmm. When the triangular operand is not square
//     (cols != depth):
//       - if MKL's BLAS domain runs single-threaded and the non-square excess
//         is below 50% of diagSize, Eigen's built-in kernel is used;
//       - presumably otherwise, the triangular view is copied into a dense
//         temporary and general_matrix_matrix_product is called on it.
//   * In the square case ?trmm is called with side='R' on a copy of the
//     (possibly conjugated) lhs, because ?trmm overwrites B in place. alpha is
//     applied by ?trmm, and the resulting temporary is added into res.
//   * diag='N' is always passed; ZeroDiag / UnitDiag modes are applied by
//     writing the diagonal of a dense copy of the rhs.
//   * A row-major rhs is handled via transa ('T', or 'C' when conjugated)
//     together with a flipped uplo.
// NOTE(review): several lines of this macro (braces, the enum / run() headers,
// else-branches, declarations of alpha_, m, a_tmp, a and b) are missing from
// this excerpt; verify the control flow against the complete file.
192 #define EIGEN_MKL_TRMM_R(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX) \
193 template <typename Index, int Mode, \
194 int LhsStorageOrder, bool ConjugateLhs, \
195 int RhsStorageOrder, bool ConjugateRhs> \
196 struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,false, \
197 LhsStorageOrder,ConjugateLhs,RhsStorageOrder,ConjugateRhs,ColMajor> \
200 IsLower = (Mode&Lower) == Lower, \
201 SetDiag = (Mode&(ZeroDiag|UnitDiag)) ? 0 : 1, \
202 IsUnitDiag = (Mode&UnitDiag) ? 1 : 0, \
203 IsZeroDiag = (Mode&ZeroDiag) ? 1 : 0, \
204 LowUp = IsLower ? Lower : Upper, \
205 conjA = ((RhsStorageOrder==ColMajor) && ConjugateRhs) ? 1 : 0 /* row-major conjugation is expressed via transa='C' instead */ \
209 Index _rows, Index _cols, Index _depth, \
210 const EIGTYPE* _lhs, Index lhsStride, \
211 const EIGTYPE* _rhs, Index rhsStride, \
212 EIGTYPE* res, Index resStride, \
213 EIGTYPE alpha, level3_blocking<EIGTYPE,EIGTYPE>& blocking) \
215 Index diagSize = (std::min)(_cols,_depth); /* side length of the square triangular part */ \
216 Index rows = _rows; \
217 Index depth = IsLower ? _depth : diagSize; \
218 Index cols = IsLower ? diagSize : _cols; \
220 typedef Matrix<EIGTYPE, Dynamic, Dynamic, LhsStorageOrder> MatrixLhs; \
221 typedef Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder> MatrixRhs; \
224 if (cols != depth) { \
226 int nthr = mkl_domain_get_max_threads(EIGEN_MKL_DOMAIN_BLAS); \
228 if ((nthr==1) && (((std::max)(cols,depth)-diagSize)/(double)diagSize < 0.5)) { /* single-threaded and nearly square: use Eigen's built-in kernel */ \
230 product_triangular_matrix_matrix<EIGTYPE,Index,Mode,false, \
231 LhsStorageOrder,ConjugateLhs, RhsStorageOrder, ConjugateRhs, ColMajor, BuiltIn>::run( \
232 _rows, _cols, _depth, _lhs, lhsStride, _rhs, rhsStride, res, resStride, alpha, blocking); \
236 Map<const MatrixRhs, 0, OuterStride<> > rhsMap(_rhs,depth,cols, OuterStride<>(rhsStride)); \
237 MatrixRhs aa_tmp=rhsMap.template triangularView<Mode>(); /* dense copy of the triangular part so a general product can be used */ \
238 MKL_INT aStride = aa_tmp.outerStride(); \
239 gemm_blocking_space<ColMajor,EIGTYPE,EIGTYPE,Dynamic,Dynamic,Dynamic> gemm_blocking(_rows,_cols,_depth); \
240 general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor>::run( \
241 rows, cols, depth, _lhs, lhsStride, aa_tmp.data(), aStride, res, resStride, alpha, gemm_blocking, 0); \
247 char side = 'R', transa, uplo, diag = 'N'; \
250 MKL_INT m, n, lda, ldb; \
254 assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(alpha_, alpha); \
258 n = (MKL_INT)diagSize; \
261 transa = (RhsStorageOrder==RowMajor) ? ((ConjugateRhs) ? 'C' : 'T') : 'N'; \
264 Map<const MatrixLhs, 0, OuterStride<> > lhs(_lhs,rows,depth,OuterStride<>(lhsStride)); \
265 MatrixX##EIGPREFIX b_tmp; \
267 if (ConjugateLhs) b_tmp = lhs.conjugate(); else b_tmp = lhs; /* ?trmm overwrites B in place, so work on a copy */ \
269 ldb = b_tmp.outerStride(); \
272 uplo = IsLower ? 'L' : 'U'; \
273 if (RhsStorageOrder==RowMajor) uplo = (uplo == 'L') ? 'U' : 'L'; /* BLAS sees a row-major matrix as its transpose */ \
275 Map<const MatrixRhs, 0, OuterStride<> > rhs(_rhs,depth,cols, OuterStride<>(rhsStride)); \
278 if ((conjA!=0) || (SetDiag==0)) { \
279 if (conjA) a_tmp = rhs.conjugate(); else a_tmp = rhs; \
281 a_tmp.diagonal().setZero(); \
282 else if (IsUnitDiag) \
283 a_tmp.diagonal().setOnes(); /* diag is always 'N', so the unit diagonal is materialized explicitly */\
285 lda = a_tmp.outerStride(); \
292 MKLPREFIX##trmm(&side, &uplo, &transa, &diag, &m, &n, &alpha_, (const MKLTYPE*)a, &lda, (MKLTYPE*)b, &ldb); \
295 Map<MatrixX##EIGPREFIX, 0, OuterStride<> > res_tmp(res,rows,cols,OuterStride<>(resStride)); \
296 res_tmp=res_tmp+b_tmp; /* accumulate; alpha was already applied by ?trmm */ \
300 EIGEN_MKL_TRMM_R(
double,
double, d, d)
301 EIGEN_MKL_TRMM_R(dcomplex, MKL_Complex16, cd, z)
302 EIGEN_MKL_TRMM_R(
float,
float, f, s)
303 EIGEN_MKL_TRMM_R(scomplex, MKL_Complex8, cf, c)
309 #endif // EIGEN_TRIANGULAR_MATRIX_MATRIX_MKL_H