10 #ifndef EIGEN_MATRIX_SQUARE_ROOT
11 #define EIGEN_MATRIX_SQUARE_ROOT
26 template <
typename MatrixType>
42 eigen_assert(A.rows() == A.cols());
53 template <
typename ResultType>
void compute(ResultType &result);
56 typedef typename MatrixType::Index Index;
57 typedef typename MatrixType::Scalar Scalar;
59 void computeDiagonalPartOfSqrt(MatrixType& sqrtT,
const MatrixType& T);
60 void computeOffDiagonalPartOfSqrt(MatrixType& sqrtT,
const MatrixType& T);
61 void compute2x2diagonalBlock(MatrixType& sqrtT,
const MatrixType& T,
typename MatrixType::Index i);
62 void compute1x1offDiagonalBlock(MatrixType& sqrtT,
const MatrixType& T,
63 typename MatrixType::Index i,
typename MatrixType::Index j);
64 void compute1x2offDiagonalBlock(MatrixType& sqrtT,
const MatrixType& T,
65 typename MatrixType::Index i,
typename MatrixType::Index j);
66 void compute2x1offDiagonalBlock(MatrixType& sqrtT,
const MatrixType& T,
67 typename MatrixType::Index i,
typename MatrixType::Index j);
68 void compute2x2offDiagonalBlock(MatrixType& sqrtT,
const MatrixType& T,
69 typename MatrixType::Index i,
typename MatrixType::Index j);
71 template <
typename SmallMatrixType>
72 static void solveAuxiliaryEquation(SmallMatrixType& X,
const SmallMatrixType& A,
73 const SmallMatrixType& B,
const SmallMatrixType& C);
75 const MatrixType& m_A;
78 template <
typename MatrixType>
79 template <
typename ResultType>
82 result.resize(m_A.rows(), m_A.cols());
83 computeDiagonalPartOfSqrt(result, m_A);
84 computeOffDiagonalPartOfSqrt(result, m_A);
89 template <
typename MatrixType>
94 const Index size = m_A.rows();
95 for (Index i = 0; i < size; i++) {
96 if (i == size - 1 || T.coeff(i+1, i) == 0) {
97 eigen_assert(T(i,i) >= 0);
98 sqrtT.coeffRef(i,i) = sqrt(T.coeff(i,i));
101 compute2x2diagonalBlock(sqrtT, T, i);
109 template <
typename MatrixType>
110 void MatrixSquareRootQuasiTriangular<MatrixType>::computeOffDiagonalPartOfSqrt(MatrixType& sqrtT,
113 const Index size = m_A.rows();
114 for (Index j = 1; j < size; j++) {
115 if (T.coeff(j, j-1) != 0)
117 for (Index i = j-1; i >= 0; i--) {
118 if (i > 0 && T.coeff(i, i-1) != 0)
120 bool iBlockIs2x2 = (i < size - 1) && (T.coeff(i+1, i) != 0);
121 bool jBlockIs2x2 = (j < size - 1) && (T.coeff(j+1, j) != 0);
122 if (iBlockIs2x2 && jBlockIs2x2)
123 compute2x2offDiagonalBlock(sqrtT, T, i, j);
124 else if (iBlockIs2x2 && !jBlockIs2x2)
125 compute2x1offDiagonalBlock(sqrtT, T, i, j);
126 else if (!iBlockIs2x2 && jBlockIs2x2)
127 compute1x2offDiagonalBlock(sqrtT, T, i, j);
128 else if (!iBlockIs2x2 && !jBlockIs2x2)
129 compute1x1offDiagonalBlock(sqrtT, T, i, j);
136 template <
typename MatrixType>
137 void MatrixSquareRootQuasiTriangular<MatrixType>
138 ::compute2x2diagonalBlock(MatrixType& sqrtT,
const MatrixType& T,
typename MatrixType::Index i)
142 Matrix<Scalar,2,2> block = T.template block<2,2>(i,i);
143 EigenSolver<Matrix<Scalar,2,2> > es(block);
144 sqrtT.template block<2,2>(i,i)
145 = (es.eigenvectors() * es.eigenvalues().cwiseSqrt().asDiagonal() * es.eigenvectors().inverse()).real();
151 template <
typename MatrixType>
152 void MatrixSquareRootQuasiTriangular<MatrixType>
153 ::compute1x1offDiagonalBlock(MatrixType& sqrtT,
const MatrixType& T,
154 typename MatrixType::Index i,
typename MatrixType::Index j)
156 Scalar tmp = (sqrtT.row(i).segment(i+1,j-i-1) * sqrtT.col(j).segment(i+1,j-i-1)).value();
157 sqrtT.coeffRef(i,j) = (T.coeff(i,j) - tmp) / (sqrtT.coeff(i,i) + sqrtT.coeff(j,j));
161 template <
typename MatrixType>
162 void MatrixSquareRootQuasiTriangular<MatrixType>
163 ::compute1x2offDiagonalBlock(MatrixType& sqrtT,
const MatrixType& T,
164 typename MatrixType::Index i,
typename MatrixType::Index j)
166 Matrix<Scalar,1,2> rhs = T.template block<1,2>(i,j);
168 rhs -= sqrtT.block(i, i+1, 1, j-i-1) * sqrtT.block(i+1, j, j-i-1, 2);
169 Matrix<Scalar,2,2> A = sqrtT.coeff(i,i) * Matrix<Scalar,2,2>::Identity();
170 A += sqrtT.template block<2,2>(j,j).transpose();
171 sqrtT.template block<1,2>(i,j).transpose() = A.fullPivLu().solve(rhs.transpose());
175 template <
typename MatrixType>
176 void MatrixSquareRootQuasiTriangular<MatrixType>
177 ::compute2x1offDiagonalBlock(MatrixType& sqrtT,
const MatrixType& T,
178 typename MatrixType::Index i,
typename MatrixType::Index j)
180 Matrix<Scalar,2,1> rhs = T.template block<2,1>(i,j);
182 rhs -= sqrtT.block(i, i+2, 2, j-i-2) * sqrtT.block(i+2, j, j-i-2, 1);
183 Matrix<Scalar,2,2> A = sqrtT.coeff(j,j) * Matrix<Scalar,2,2>::Identity();
184 A += sqrtT.template block<2,2>(i,i);
185 sqrtT.template block<2,1>(i,j) = A.fullPivLu().solve(rhs);
189 template <
typename MatrixType>
190 void MatrixSquareRootQuasiTriangular<MatrixType>
191 ::compute2x2offDiagonalBlock(MatrixType& sqrtT,
const MatrixType& T,
192 typename MatrixType::Index i,
typename MatrixType::Index j)
194 Matrix<Scalar,2,2> A = sqrtT.template block<2,2>(i,i);
195 Matrix<Scalar,2,2> B = sqrtT.template block<2,2>(j,j);
196 Matrix<Scalar,2,2> C = T.template block<2,2>(i,j);
198 C -= sqrtT.block(i, i+2, 2, j-i-2) * sqrtT.block(i+2, j, j-i-2, 2);
199 Matrix<Scalar,2,2> X;
200 solveAuxiliaryEquation(X, A, B, C);
201 sqrtT.template block<2,2>(i,j) = X;
205 template <
typename MatrixType>
206 template <
typename SmallMatrixType>
207 void MatrixSquareRootQuasiTriangular<MatrixType>
208 ::solveAuxiliaryEquation(SmallMatrixType& X,
const SmallMatrixType& A,
209 const SmallMatrixType& B,
const SmallMatrixType& C)
211 EIGEN_STATIC_ASSERT((internal::is_same<SmallMatrixType, Matrix<Scalar,2,2> >::value),
212 EIGEN_INTERNAL_ERROR_PLEASE_FILE_A_BUG_REPORT);
214 Matrix<Scalar,4,4> coeffMatrix = Matrix<Scalar,4,4>::Zero();
215 coeffMatrix.coeffRef(0,0) = A.coeff(0,0) + B.coeff(0,0);
216 coeffMatrix.coeffRef(1,1) = A.coeff(0,0) + B.coeff(1,1);
217 coeffMatrix.coeffRef(2,2) = A.coeff(1,1) + B.coeff(0,0);
218 coeffMatrix.coeffRef(3,3) = A.coeff(1,1) + B.coeff(1,1);
219 coeffMatrix.coeffRef(0,1) = B.coeff(1,0);
220 coeffMatrix.coeffRef(0,2) = A.coeff(0,1);
221 coeffMatrix.coeffRef(1,0) = B.coeff(0,1);
222 coeffMatrix.coeffRef(1,3) = A.coeff(0,1);
223 coeffMatrix.coeffRef(2,0) = A.coeff(1,0);
224 coeffMatrix.coeffRef(2,3) = B.coeff(1,0);
225 coeffMatrix.coeffRef(3,1) = A.coeff(1,0);
226 coeffMatrix.coeffRef(3,2) = B.coeff(0,1);
228 Matrix<Scalar,4,1> rhs;
229 rhs.coeffRef(0) = C.coeff(0,0);
230 rhs.coeffRef(1) = C.coeff(0,1);
231 rhs.coeffRef(2) = C.coeff(1,0);
232 rhs.coeffRef(3) = C.coeff(1,1);
234 Matrix<Scalar,4,1> result;
235 result = coeffMatrix.fullPivLu().solve(rhs);
237 X.coeffRef(0,0) = result.coeff(0);
238 X.coeffRef(0,1) = result.coeff(1);
239 X.coeffRef(1,0) = result.coeff(2);
240 X.coeffRef(1,1) = result.coeff(3);
255 template <
typename MatrixType>
262 eigen_assert(A.rows() == A.cols());
274 template <
typename ResultType>
void compute(ResultType &result);
277 const MatrixType& m_A;
280 template <
typename MatrixType>
281 template <
typename ResultType>
288 result.resize(m_A.rows(), m_A.cols());
289 typedef typename MatrixType::Index Index;
290 for (Index i = 0; i < m_A.rows(); i++) {
291 result.coeffRef(i,i) = sqrt(m_A.coeff(i,i));
293 for (Index j = 1; j < m_A.cols(); j++) {
294 for (Index i = j-1; i >= 0; i--) {
295 typedef typename MatrixType::Scalar Scalar;
297 Scalar tmp = (result.row(i).segment(i+1,j-i-1) * result.col(j).segment(i+1,j-i-1)).value();
299 result.coeffRef(i,j) = (m_A.coeff(i,j) - tmp) / (result.coeff(i,i) + result.coeff(j,j));
312 template <typename MatrixType, int IsComplex = NumTraits<typename internal::traits<MatrixType>::Scalar>::IsComplex>
333 template <
typename ResultType>
void compute(ResultType &result);
339 template <
typename MatrixType>
347 eigen_assert(A.rows() == A.cols());
350 template <
typename ResultType>
void compute(ResultType &result)
353 const RealSchur<MatrixType> schurOfA(m_A);
354 const MatrixType& T = schurOfA.matrixT();
355 const MatrixType& U = schurOfA.matrixU();
358 MatrixType sqrtT = MatrixType::Zero(m_A.rows(), m_A.cols());
359 MatrixSquareRootQuasiTriangular<MatrixType>(T).
compute(sqrtT);
362 result = U * sqrtT * U.adjoint();
366 const MatrixType& m_A;
372 template <
typename MatrixType>
373 class MatrixSquareRoot<MatrixType, 1>
380 eigen_assert(A.rows() == A.cols());
383 template <
typename ResultType>
void compute(ResultType &result)
386 const ComplexSchur<MatrixType> schurOfA(m_A);
387 const MatrixType& T = schurOfA.matrixT();
388 const MatrixType& U = schurOfA.matrixU();
392 MatrixSquareRootTriangular<MatrixType>(T).
compute(sqrtT);
395 result = U * (sqrtT.template triangularView<Upper>() * U.adjoint());
399 const MatrixType& m_A;
416 :
public ReturnByValue<MatrixSquareRootReturnValue<Derived> >
418 typedef typename Derived::Index Index;
432 template <
typename ResultType>
433 inline void evalTo(ResultType& result)
const
435 const typename Derived::PlainObject srcEvaluated = m_src.eval();
440 Index rows()
const {
return m_src.rows(); }
441 Index cols()
const {
return m_src.cols(); }
444 const Derived& m_src;
450 template<
typename Derived>
451 struct traits<MatrixSquareRootReturnValue<Derived> >
453 typedef typename Derived::PlainObject ReturnType;
457 template <
typename Derived>
458 const MatrixSquareRootReturnValue<Derived> MatrixBase<Derived>::sqrt()
const
460 eigen_assert(rows() == cols());
461 return MatrixSquareRootReturnValue<Derived>(derived());
466 #endif // EIGEN_MATRIX_FUNCTION
void compute(ResultType &result)
Compute the matrix square root.
Definition: MatrixSquareRoot.h:80
MatrixSquareRoot(const MatrixType &A)
Constructor.
Proxy for the matrix square root of some matrix (expression).
Definition: MatrixSquareRoot.h:415
void compute(ResultType &result)
Compute the matrix square root.
void compute(ResultType &result)
Compute the matrix square root.
Definition: MatrixSquareRoot.h:282
Class for computing matrix square roots of upper quasi-triangular matrices.
Definition: MatrixSquareRoot.h:27
void evalTo(ResultType &result) const
Compute the matrix square root.
Definition: MatrixSquareRoot.h:433
Class for computing matrix square roots of upper triangular matrices.
Definition: MatrixSquareRoot.h:256
MatrixSquareRootQuasiTriangular(const MatrixType &A)
Constructor.
Definition: MatrixSquareRoot.h:39
Class for computing matrix square roots of general matrices.
Definition: MatrixSquareRoot.h:313
MatrixSquareRootReturnValue(const Derived &src)
Constructor.
Definition: MatrixSquareRoot.h:425