33 #ifndef EIGEN_TRIANGULAR_MATRIX_VECTOR_MKL_H
34 #define EIGEN_TRIANGULAR_MATRIX_VECTOR_MKL_H
46 template<
typename Index,
int Mode,
typename LhsScalar,
bool ConjLhs,
typename RhsScalar,
bool ConjRhs,
int StorageOrder>
47 struct triangular_matrix_vector_product_trmv :
48 triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,StorageOrder,BuiltIn> {};
// EIGEN_MKL_TRMV_SPECIALIZE(Scalar)
//
// Purpose: for one scalar type used on both sides of the product, defines the
// `Specialized` variants of triangular_matrix_vector_product for ColMajor and
// RowMajor lhs storage. Each run() forwards every argument unchanged to
// triangular_matrix_vector_product_trmv with the matching storage order, so
// the choice between the MKL path and the BuiltIn fallback is made there.
// The macro is expanded below for double, float, dcomplex and scomplex.
//
// NOTE(review): this copy of the header looks damaged and will not compile
// as is. Each line carries a stray leading number (e.g. "50 #define"). The
// closing "}" / "};" of both run() bodies and both structs are not present.
// The last macro line still ends in a line continuation, so the first
// expansion line below is spliced into the macro. Restore from the upstream
// Eigen source before building.
50 #define EIGEN_MKL_TRMV_SPECIALIZE(Scalar) \
51 template<typename Index, int Mode, bool ConjLhs, bool ConjRhs> \
52 struct triangular_matrix_vector_product<Index,Mode,Scalar,ConjLhs,Scalar,ConjRhs,ColMajor,Specialized> { \
53 static void run(Index _rows, Index _cols, const Scalar* _lhs, Index lhsStride, \
54 const Scalar* _rhs, Index rhsIncr, Scalar* _res, Index resIncr, Scalar alpha) { \
55 triangular_matrix_vector_product_trmv<Index,Mode,Scalar,ConjLhs,Scalar,ConjRhs,ColMajor>::run( \
56 _rows, _cols, _lhs, lhsStride, _rhs, rhsIncr, _res, resIncr, alpha); \
59 template<typename Index, int Mode, bool ConjLhs, bool ConjRhs> \
60 struct triangular_matrix_vector_product<Index,Mode,Scalar,ConjLhs,Scalar,ConjRhs,RowMajor,Specialized> { \
61 static void run(Index _rows, Index _cols, const Scalar* _lhs, Index lhsStride, \
62 const Scalar* _rhs, Index rhsIncr, Scalar* _res, Index resIncr, Scalar alpha) { \
63 triangular_matrix_vector_product_trmv<Index,Mode,Scalar,ConjLhs,Scalar,ConjRhs,RowMajor>::run( \
64 _rows, _cols, _lhs, lhsStride, _rhs, rhsIncr, _res, resIncr, alpha); \
68 EIGEN_MKL_TRMV_SPECIALIZE(
double)
69 EIGEN_MKL_TRMV_SPECIALIZE(
float)
70 EIGEN_MKL_TRMV_SPECIALIZE(dcomplex)
71 EIGEN_MKL_TRMV_SPECIALIZE(scomplex)
// EIGEN_MKL_TRMV_CM(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX)
//
// Purpose: specializes triangular_matrix_vector_product_trmv for a ColMajor
// lhs whose lhs and rhs scalar type are both EIGTYPE. Its run() uses BLAS
// calls. The product of the triangular lhs with rhs (conjugated when ConjRhs),
// scaled by alpha, is accumulated into res: ?axpy and ?gemv with beta_ = 1
// both add into res rather than overwrite it.
//
// Macro parameters:
//   EIGTYPE    Eigen scalar type (double, dcomplex, float, scomplex below)
//   MKLTYPE    MKL type the scalar pointers/values are converted to
//   EIGPREFIX  suffix pasted onto VectorX to name the rhs copy vector type
//   MKLPREFIX  BLAS precision prefix (d, z, s, c) pasted onto trmv/axpy/gemv
//
// run() outline, as far as the lines in this copy show:
//   * ConjLhs or ZeroDiag modes are delegated to the BuiltIn kernel.
//   * size = min(_rows, _cols) is the square triangular block. rows/cols are
//     the extents actually read: Lower uses all rows x size cols, Upper uses
//     size rows x all cols.
//   * rhs is read with stride rhsIncr and copied into the contiguous
//     temporary x_tmp, conjugated if ConjRhs. This copy is needed because
//     ?trmv overwrites its vector argument.
//   * ?trmv applies the triangular lhs to x in place. uplo is 'L'/'U' from
//     IsLower; diag is 'U' for UnitDiag, else 'N'.
//   * ?axpy then adds alpha_ * x into res with increment incy.
//   * If the lhs is not square (size < max(rows, cols)), rhs is copied again.
//     ?gemv with beta_ = 1 then accumulates the rectangular remainder, using
//     pointers offset past the square block (y = _res + size*resIncr,
//     a = _lhs + size*lda).
//
// NOTE(review): this copy of the macro is incomplete and will not compile as
// is:
//   * every line carries a stray leading number;
//   * the `enum {` opener, the function/branch braces, and any `return`
//     after the BuiltIn fallback are absent;
//   * x, y, a and x_tmp are used without a visible declaration;
//   * trans, m, n, lda, incx and incy are never assigned here;
//   * the last macro line ends in a continuation that splices the first
//     expansion line into the macro.
// SetDiag, LowUp and MatrixLhs are unused in the visible lines. Restore from
// upstream Eigen.
// NOTE(review): Index values are presumably narrowed into the MKL_INT
// variables on the missing assignment lines; confirm sizes/strides fit.
74 #define EIGEN_MKL_TRMV_CM(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX) \
75 template<typename Index, int Mode, bool ConjLhs, bool ConjRhs> \
76 struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,ConjRhs,ColMajor> { \
78 IsLower = (Mode&Lower) == Lower, \
79 SetDiag = (Mode&(ZeroDiag|UnitDiag)) ? 0 : 1, \
80 IsUnitDiag = (Mode&UnitDiag) ? 1 : 0, \
81 IsZeroDiag = (Mode&ZeroDiag) ? 1 : 0, \
82 LowUp = IsLower ? Lower : Upper \
84 static void run(Index _rows, Index _cols, const EIGTYPE* _lhs, Index lhsStride, \
85 const EIGTYPE* _rhs, Index rhsIncr, EIGTYPE* _res, Index resIncr, EIGTYPE alpha) \
/* ConjLhs and ZeroDiag are delegated to Eigen's BuiltIn kernel. */ \
87 if (ConjLhs || IsZeroDiag) { \
88 triangular_matrix_vector_product<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,ConjRhs,ColMajor,BuiltIn>::run( \
89 _rows, _cols, _lhs, lhsStride, _rhs, rhsIncr, _res, resIncr, alpha); \
92 Index size = (std::min)(_rows,_cols); \
93 Index rows = IsLower ? _rows : size; \
94 Index cols = IsLower ? size : _cols; \
96 typedef VectorX##EIGPREFIX VectorRhs; \
100 Map<const VectorRhs, 0, InnerStride<> > rhs(_rhs,cols,InnerStride<>(rhsIncr)); \
102 if (ConjRhs) x_tmp = rhs.conjugate(); else x_tmp = rhs; \
107 char trans, uplo, diag; \
108 MKL_INT m, n, lda, incx, incy; \
110 MKLTYPE alpha_, beta_; \
111 assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(alpha_, alpha); \
112 assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(beta_, EIGTYPE(1)); \
122 uplo = IsLower ? 'L' : 'U'; \
123 diag = IsUnitDiag ? 'U' : 'N'; \
/* ?trmv writes its result back into x (presumably the x_tmp buffer). */ \
126 MKLPREFIX##trmv(&uplo, &trans, &diag, &n, (const MKLTYPE*)_lhs, &lda, (MKLTYPE*)x, &incx); \
129 MKLPREFIX##axpy(&n, &alpha_,(const MKLTYPE*)x, &incx, (MKLTYPE*)_res, &incy); \
/* ?trmv only covers the square block; ?gemv adds any rectangular rest. */ \
131 if (size<(std::max)(rows,cols)) { \
132 typedef Matrix<EIGTYPE, Dynamic, Dynamic> MatrixLhs; \
133 if (ConjRhs) x_tmp = rhs.conjugate(); else x_tmp = rhs; \
136 y = _res + size*resIncr; \
144 a = _lhs + size*lda; \
148 MKLPREFIX##gemv(&trans, &m, &n, &alpha_, (const MKLTYPE*)a, &lda, (const MKLTYPE*)x, &incx, &beta_, (MKLTYPE*)y, &incy); \
153 EIGEN_MKL_TRMV_CM(
double,
double, d, d)
154 EIGEN_MKL_TRMV_CM(dcomplex, MKL_Complex16, cd, z)
155 EIGEN_MKL_TRMV_CM(
float,
float, f, s)
156 EIGEN_MKL_TRMV_CM(scomplex, MKL_Complex8, cf, c)
// EIGEN_MKL_TRMV_RM(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX)
//
// Purpose: specializes triangular_matrix_vector_product_trmv for a RowMajor
// lhs whose lhs and rhs scalar type are both EIGTYPE. Its run() uses BLAS
// calls that accumulate alpha * (triangular lhs) * rhs into res.
//
// Macro parameters:
//   EIGTYPE    Eigen scalar type (double, dcomplex, float, scomplex below)
//   MKLTYPE    MKL type the scalar pointers/values are converted to
//   EIGPREFIX  suffix pasted onto VectorX to name the rhs copy vector type
//   MKLPREFIX  BLAS precision prefix (d, z, s, c) pasted onto trmv/axpy/gemv
//
// A row-major matrix has the memory layout of its column-major transpose, so
// the lhs is handed to BLAS as that transpose:
//   * trans is 'T', or 'C' when ConjLhs, so lhs conjugation is done by BLAS;
//   * uplo is swapped ('U' for a Lower mode, 'L' for Upper);
//   * ?gemv receives m and n in swapped order.
//
// run() outline, as far as the lines in this copy show:
//   * rhs is read with stride rhsIncr and copied into the contiguous x_tmp,
//     conjugated if ConjRhs.
//   * ?trmv runs in place on x, then ?axpy adds alpha_ * x into res.
//   * For a non-square lhs (size < max(rows, cols)), ?gemv with beta_ = 1
//     adds the remainder at y = _res + size*resIncr and a = _lhs + size*lda.
//
// NOTE(review): this copy is incomplete and will not compile as is:
//   * the condition guarding the BuiltIn fallback call is missing, so as
//     shown the call is unconditional;
//   * every line carries a stray leading number;
//   * the `enum {` opener and the function/branch braces are absent;
//   * x, y, a and x_tmp are used without a visible declaration;
//   * m, n, lda, incx and incy are never assigned here;
//   * the last macro line ends in a continuation that splices the first
//     expansion line into the macro.
// SetDiag, LowUp and MatrixLhs are unused in the visible lines. Restore from
// upstream Eigen.
// NOTE(review): Index values are presumably narrowed into MKL_INT on the
// missing assignment lines; confirm sizes/strides fit.
159 #define EIGEN_MKL_TRMV_RM(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX) \
160 template<typename Index, int Mode, bool ConjLhs, bool ConjRhs> \
161 struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,ConjRhs,RowMajor> { \
163 IsLower = (Mode&Lower) == Lower, \
164 SetDiag = (Mode&(ZeroDiag|UnitDiag)) ? 0 : 1, \
165 IsUnitDiag = (Mode&UnitDiag) ? 1 : 0, \
166 IsZeroDiag = (Mode&ZeroDiag) ? 1 : 0, \
167 LowUp = IsLower ? Lower : Upper \
169 static void run(Index _rows, Index _cols, const EIGTYPE* _lhs, Index lhsStride, \
170 const EIGTYPE* _rhs, Index rhsIncr, EIGTYPE* _res, Index resIncr, EIGTYPE alpha) \
/* BuiltIn fallback (its guarding condition is missing from this copy). */ \
173 triangular_matrix_vector_product<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,ConjRhs,RowMajor,BuiltIn>::run( \
174 _rows, _cols, _lhs, lhsStride, _rhs, rhsIncr, _res, resIncr, alpha); \
177 Index size = (std::min)(_rows,_cols); \
178 Index rows = IsLower ? _rows : size; \
179 Index cols = IsLower ? size : _cols; \
181 typedef VectorX##EIGPREFIX VectorRhs; \
185 Map<const VectorRhs, 0, InnerStride<> > rhs(_rhs,cols,InnerStride<>(rhsIncr)); \
187 if (ConjRhs) x_tmp = rhs.conjugate(); else x_tmp = rhs; \
192 char trans, uplo, diag; \
193 MKL_INT m, n, lda, incx, incy; \
195 MKLTYPE alpha_, beta_; \
196 assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(alpha_, alpha); \
197 assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(beta_, EIGTYPE(1)); \
/* Row-major lhs == column-major transpose: transpose op, swapped uplo. */ \
206 trans = ConjLhs ? 'C' : 'T'; \
207 uplo = IsLower ? 'U' : 'L'; \
208 diag = IsUnitDiag ? 'U' : 'N'; \
211 MKLPREFIX##trmv(&uplo, &trans, &diag, &n, (const MKLTYPE*)_lhs, &lda, (MKLTYPE*)x, &incx); \
214 MKLPREFIX##axpy(&n, &alpha_,(const MKLTYPE*)x, &incx, (MKLTYPE*)_res, &incy); \
216 if (size<(std::max)(rows,cols)) { \
217 typedef Matrix<EIGTYPE, Dynamic, Dynamic> MatrixLhs; \
218 if (ConjRhs) x_tmp = rhs.conjugate(); else x_tmp = rhs; \
221 y = _res + size*resIncr; \
222 a = _lhs + size*lda; \
/* m and n are swapped vs. the ColMajor variant because of the transposed view. */ \
233 MKLPREFIX##gemv(&trans, &n, &m, &alpha_, (const MKLTYPE*)a, &lda, (const MKLTYPE*)x, &incx, &beta_, (MKLTYPE*)y, &incy); \
238 EIGEN_MKL_TRMV_RM(
double,
double, d, d)
239 EIGEN_MKL_TRMV_RM(dcomplex, MKL_Complex16, cd, z)
240 EIGEN_MKL_TRMV_RM(
float,
float, f, s)
241 EIGEN_MKL_TRMV_RM(scomplex, MKL_Complex8, cf, c)
247 #endif // EIGEN_TRIANGULAR_MATRIX_VECTOR_MKL_H