10 #ifndef EIGEN_SOLVETRIANGULAR_H
11 #define EIGEN_SOLVETRIANGULAR_H
19 template<
typename LhsScalar,
typename RhsScalar,
typename Index,
int S
ide,
int Mode,
bool Conjugate,
int StorageOrder>
20 struct triangular_solve_vector;
22 template <
typename Scalar,
typename Index,
int S
ide,
int Mode,
bool Conjugate,
int TriStorageOrder,
int OtherStorageOrder>
23 struct triangular_solve_matrix;
26 template<
typename Lhs,
typename Rhs,
int S
ide>
31 RhsIsVectorAtCompileTime = (Side==
OnTheLeft ? Rhs::ColsAtCompileTime : Rhs::RowsAtCompileTime)==1
35 Unrolling = (RhsIsVectorAtCompileTime && Rhs::SizeAtCompileTime !=
Dynamic && Rhs::SizeAtCompileTime <= 8)
36 ? CompleteUnrolling : NoUnrolling,
37 RhsVectors = RhsIsVectorAtCompileTime ? 1 :
Dynamic
41 template<
typename Lhs,
typename Rhs,
44 int Unrolling = trsolve_traits<Lhs,Rhs,Side>::Unrolling,
45 int RhsVectors = trsolve_traits<Lhs,Rhs,Side>::RhsVectors
47 struct triangular_solver_selector;
49 template<
typename Lhs,
typename Rhs,
int S
ide,
int Mode>
50 struct triangular_solver_selector<Lhs,Rhs,Side,Mode,NoUnrolling,1>
52 typedef typename Lhs::Scalar LhsScalar;
53 typedef typename Rhs::Scalar RhsScalar;
54 typedef blas_traits<Lhs> LhsProductTraits;
55 typedef typename LhsProductTraits::ExtractType ActualLhsType;
56 typedef Map<Matrix<RhsScalar,Dynamic,1>,
Aligned> MappedRhs;
57 static void run(
const Lhs& lhs, Rhs& rhs)
59 ActualLhsType actualLhs = LhsProductTraits::extract(lhs);
63 bool useRhsDirectly = Rhs::InnerStrideAtCompileTime==1 || rhs.innerStride()==1;
65 ei_declare_aligned_stack_constructed_variable(RhsScalar,actualRhs,rhs.size(),
66 (useRhsDirectly ? rhs.data() : 0));
69 MappedRhs(actualRhs,rhs.size()) = rhs;
71 triangular_solve_vector<LhsScalar, RhsScalar,
typename Lhs::Index, Side, Mode, LhsProductTraits::NeedToConjugate,
73 ::run(actualLhs.cols(), actualLhs.data(), actualLhs.outerStride(), actualRhs);
76 rhs = MappedRhs(actualRhs, rhs.size());
81 template<
typename Lhs,
typename Rhs,
int S
ide,
int Mode>
82 struct triangular_solver_selector<Lhs,Rhs,Side,Mode,NoUnrolling,
Dynamic>
84 typedef typename Rhs::Scalar Scalar;
85 typedef typename Rhs::Index Index;
86 typedef blas_traits<Lhs> LhsProductTraits;
87 typedef typename LhsProductTraits::DirectLinearAccessType ActualLhsType;
89 static void run(
const Lhs& lhs, Rhs& rhs)
91 typename internal::add_const_on_value_type<ActualLhsType>::type actualLhs = LhsProductTraits::extract(lhs);
93 const Index size = lhs.rows();
94 const Index othersize = Side==
OnTheLeft? rhs.cols() : rhs.rows();
97 Rhs::MaxRowsAtCompileTime, Rhs::MaxColsAtCompileTime, Lhs::MaxRowsAtCompileTime,4> BlockingType;
99 BlockingType blocking(rhs.rows(), rhs.cols(), size);
101 triangular_solve_matrix<Scalar,Index,Side,Mode,LhsProductTraits::NeedToConjugate,(int(Lhs::Flags) &
RowMajorBit) ?
RowMajor : ColMajor,
103 ::run(size, othersize, &actualLhs.coeffRef(0,0), actualLhs.outerStride(), &rhs.coeffRef(0,0), rhs.outerStride(), blocking);
111 template<
typename Lhs,
typename Rhs,
int Mode,
int Index,
int Size,
112 bool Stop = Index==Size>
113 struct triangular_solver_unroller;
115 template<
typename Lhs,
typename Rhs,
int Mode,
int Index,
int Size>
116 struct triangular_solver_unroller<Lhs,Rhs,Mode,Index,Size,false> {
119 RowIndex = IsLower ? Index : Size - Index - 1,
120 S = IsLower ? 0 : RowIndex+1
122 static void run(
const Lhs& lhs, Rhs& rhs)
125 rhs.coeffRef(RowIndex) -= lhs.row(RowIndex).template segment<Index>(S).transpose()
126 .cwiseProduct(rhs.template segment<Index>(S)).sum();
129 rhs.coeffRef(RowIndex) /= lhs.coeff(RowIndex,RowIndex);
131 triangular_solver_unroller<Lhs,Rhs,Mode,Index+1,Size>::run(lhs,rhs);
135 template<
typename Lhs,
typename Rhs,
int Mode,
int Index,
int Size>
136 struct triangular_solver_unroller<Lhs,Rhs,Mode,Index,Size,true> {
137 static void run(
const Lhs&, Rhs&) {}
140 template<
typename Lhs,
typename Rhs,
int Mode>
141 struct triangular_solver_selector<Lhs,Rhs,
OnTheLeft,Mode,CompleteUnrolling,1> {
142 static void run(
const Lhs& lhs, Rhs& rhs)
143 { triangular_solver_unroller<Lhs,Rhs,Mode,0,Rhs::SizeAtCompileTime>::run(lhs,rhs); }
146 template<
typename Lhs,
typename Rhs,
int Mode>
147 struct triangular_solver_selector<Lhs,Rhs,
OnTheRight,Mode,CompleteUnrolling,1> {
148 static void run(
const Lhs& lhs, Rhs& rhs)
150 Transpose<const Lhs> trLhs(lhs);
151 Transpose<Rhs> trRhs(rhs);
153 triangular_solver_unroller<Transpose<const Lhs>,Transpose<Rhs>,
155 0,Rhs::SizeAtCompileTime>::run(trLhs,trRhs);
172 template<
typename MatrixType,
unsigned int Mode>
173 template<
int S
ide,
typename OtherDerived>
176 OtherDerived& other = _other.const_cast_derived();
177 eigen_assert( cols() == rows() && ((Side==
OnTheLeft && cols() == other.rows()) || (Side==
OnTheRight && cols() == other.cols())) );
180 enum { copy = internal::traits<OtherDerived>::Flags & RowMajorBit && OtherDerived::IsVectorAtCompileTime };
181 typedef typename internal::conditional<copy,
183 OtherCopy otherCopy(other);
185 internal::triangular_solver_selector<MatrixType, typename internal::remove_reference<OtherCopy>::type,
186 Side, Mode>::run(nestedExpression(), otherCopy);
213 template<
typename Derived,
unsigned int Mode>
214 template<
int S
ide,
typename Other>
215 const internal::triangular_solve_retval<Side,TriangularView<Derived,Mode>,Other>
218 return internal::triangular_solve_retval<Side,TriangularView,Other>(*
this, other.derived());
224 template<
int S
ide,
typename TriangularType,
typename Rhs>
225 struct traits<triangular_solve_retval<Side, TriangularType, Rhs> >
230 template<
int S
ide,
typename TriangularType,
typename Rhs>
struct triangular_solve_retval
231 :
public ReturnByValue<triangular_solve_retval<Side, TriangularType, Rhs> >
233 typedef typename remove_all<typename Rhs::Nested>::type RhsNestedCleaned;
234 typedef ReturnByValue<triangular_solve_retval> Base;
235 typedef typename Base::Index Index;
237 triangular_solve_retval(
const TriangularType& tri,
const Rhs& rhs)
238 : m_triangularMatrix(tri), m_rhs(rhs)
241 inline Index rows()
const {
return m_rhs.rows(); }
242 inline Index cols()
const {
return m_rhs.cols(); }
244 template<
typename Dest>
inline void evalTo(Dest& dst)
const
246 if(!(is_same<RhsNestedCleaned,Dest>::value && extract_data(dst) == extract_data(m_rhs)))
248 m_triangularMatrix.template solveInPlace<Side>(dst);
252 const TriangularType& m_triangularMatrix;
253 typename Rhs::Nested m_rhs;
260 #endif // EIGEN_SOLVETRIANGULAR_H
Definition: Constants.h:167
void solveInPlace(const MatrixBase< OtherDerived > &other) const
Definition: SolveTriangular.h:174
const int Dynamic
Definition: Constants.h:21
Definition: Constants.h:264
Definition: Constants.h:173
Definition: Constants.h:169
Definition: Constants.h:194
Definition: Constants.h:279
Definition: Constants.h:171
Base class for triangular part in a matrix.
Definition: TriangularMatrix.h:158
Definition: Constants.h:266
const unsigned int RowMajorBit
Definition: Constants.h:53
The matrix class, also used for vectors and row-vectors.
Definition: Matrix.h:127
Definition: Constants.h:277
Base class for all dense matrices, vectors, and expressions.
Definition: MatrixBase.h:48