All Classes Namespaces Functions Variables Typedefs Enumerator Groups Pages
MatrixSquareRoot.h
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2011 Jitse Niesen <[email protected]>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_MATRIX_SQUARE_ROOT
11 #define EIGEN_MATRIX_SQUARE_ROOT
12 
13 namespace Eigen {
14 
/** \brief Class for computing matrix square roots of upper quasi-triangular matrices.
  * \tparam MatrixType  type of the argument of the matrix square root; its
  *                     \c Index and \c Scalar typedefs are used internally.
  *
  * The argument is read block by block: a non-zero sub-diagonal entry
  * T(i+1,i) marks a 2-by-2 diagonal block starting at row i, otherwise
  * T(i,i) is a 1-by-1 diagonal block.
  *
  * \sa MatrixSquareRoot, MatrixSquareRootTriangular
  */
template <typename MatrixType>
class MatrixSquareRootQuasiTriangular
{
  public:

    /** \brief Constructor.
      *
      * \param[in]  A  upper quasi-triangular matrix whose square root is to be computed.
      *
      * Only a reference to \p A is stored, so \p A must stay alive (and
      * unchanged) until compute() has returned.
      */
    MatrixSquareRootQuasiTriangular(const MatrixType& A)
      : m_A(A)
    {
      eigen_assert(A.rows() == A.cols());
    }

    /** \brief Compute the matrix square root.
      *
      * \param[out] result  square root of \p A, as specified in the constructor.
      *
      * \p result is resized to the size of \p A. Only the diagonal blocks and
      * the blocks above them are written; entries below the block diagonal
      * are not touched (and are uninitialised after a resize), so pass a
      * zero-initialised \p result when a full matrix is required.
      */
    template <typename ResultType> void compute(ResultType &result);

  private:
    typedef typename MatrixType::Index Index;
    typedef typename MatrixType::Scalar Scalar;

    // Square roots of the 1-by-1 and 2-by-2 diagonal blocks.
    void computeDiagonalPartOfSqrt(MatrixType& sqrtT, const MatrixType& T);
    // Blocks above the block diagonal, given the diagonal blocks.
    void computeOffDiagonalPartOfSqrt(MatrixType& sqrtT, const MatrixType& T);
    void compute2x2diagonalBlock(MatrixType& sqrtT, const MatrixType& T, typename MatrixType::Index i);
    // Off-diagonal block (i,j); the name gives the size (rows x cols) of that block.
    void compute1x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
                                    typename MatrixType::Index i, typename MatrixType::Index j);
    void compute1x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
                                    typename MatrixType::Index i, typename MatrixType::Index j);
    void compute2x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
                                    typename MatrixType::Index i, typename MatrixType::Index j);
    void compute2x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
                                    typename MatrixType::Index i, typename MatrixType::Index j);

    // Solves A X + X B = C for 2-by-2 matrices.
    template <typename SmallMatrixType>
    static void solveAuxiliaryEquation(SmallMatrixType& X, const SmallMatrixType& A,
                                       const SmallMatrixType& B, const SmallMatrixType& C);

    // Argument of the square root; referenced, not copied.
    const MatrixType& m_A;
};
77 
78 template <typename MatrixType>
79 template <typename ResultType>
81 {
82  result.resize(m_A.rows(), m_A.cols());
83  computeDiagonalPartOfSqrt(result, m_A);
84  computeOffDiagonalPartOfSqrt(result, m_A);
85 }
86 
87 // pre: T is quasi-upper-triangular and sqrtT is a zero matrix of the same size
88 // post: the diagonal blocks of sqrtT are the square roots of the diagonal blocks of T
89 template <typename MatrixType>
91  const MatrixType& T)
92 {
93  using std::sqrt;
94  const Index size = m_A.rows();
95  for (Index i = 0; i < size; i++) {
96  if (i == size - 1 || T.coeff(i+1, i) == 0) {
97  eigen_assert(T(i,i) >= 0);
98  sqrtT.coeffRef(i,i) = sqrt(T.coeff(i,i));
99  }
100  else {
101  compute2x2diagonalBlock(sqrtT, T, i);
102  ++i;
103  }
104  }
105 }
106 
107 // pre: T is quasi-upper-triangular and diagonal blocks of sqrtT are square root of diagonal blocks of T.
108 // post: sqrtT is the square root of T.
109 template <typename MatrixType>
110 void MatrixSquareRootQuasiTriangular<MatrixType>::computeOffDiagonalPartOfSqrt(MatrixType& sqrtT,
111  const MatrixType& T)
112 {
113  const Index size = m_A.rows();
114  for (Index j = 1; j < size; j++) {
115  if (T.coeff(j, j-1) != 0) // if T(j-1:j, j-1:j) is a 2-by-2 block
116  continue;
117  for (Index i = j-1; i >= 0; i--) {
118  if (i > 0 && T.coeff(i, i-1) != 0) // if T(i-1:i, i-1:i) is a 2-by-2 block
119  continue;
120  bool iBlockIs2x2 = (i < size - 1) && (T.coeff(i+1, i) != 0);
121  bool jBlockIs2x2 = (j < size - 1) && (T.coeff(j+1, j) != 0);
122  if (iBlockIs2x2 && jBlockIs2x2)
123  compute2x2offDiagonalBlock(sqrtT, T, i, j);
124  else if (iBlockIs2x2 && !jBlockIs2x2)
125  compute2x1offDiagonalBlock(sqrtT, T, i, j);
126  else if (!iBlockIs2x2 && jBlockIs2x2)
127  compute1x2offDiagonalBlock(sqrtT, T, i, j);
128  else if (!iBlockIs2x2 && !jBlockIs2x2)
129  compute1x1offDiagonalBlock(sqrtT, T, i, j);
130  }
131  }
132 }
133 
134 // pre: T.block(i,i,2,2) has complex conjugate eigenvalues
135 // post: sqrtT.block(i,i,2,2) is square root of T.block(i,i,2,2)
136 template <typename MatrixType>
137 void MatrixSquareRootQuasiTriangular<MatrixType>
138  ::compute2x2diagonalBlock(MatrixType& sqrtT, const MatrixType& T, typename MatrixType::Index i)
139 {
140  // TODO: This case (2-by-2 blocks with complex conjugate eigenvalues) is probably hidden somewhere
141  // in EigenSolver. If we expose it, we could call it directly from here.
142  Matrix<Scalar,2,2> block = T.template block<2,2>(i,i);
143  EigenSolver<Matrix<Scalar,2,2> > es(block);
144  sqrtT.template block<2,2>(i,i)
145  = (es.eigenvectors() * es.eigenvalues().cwiseSqrt().asDiagonal() * es.eigenvectors().inverse()).real();
146 }
147 
148 // pre: block structure of T is such that (i,j) is a 1x1 block,
149 // all blocks of sqrtT to left of and below (i,j) are correct
150 // post: sqrtT(i,j) has the correct value
151 template <typename MatrixType>
152 void MatrixSquareRootQuasiTriangular<MatrixType>
153  ::compute1x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
154  typename MatrixType::Index i, typename MatrixType::Index j)
155 {
156  Scalar tmp = (sqrtT.row(i).segment(i+1,j-i-1) * sqrtT.col(j).segment(i+1,j-i-1)).value();
157  sqrtT.coeffRef(i,j) = (T.coeff(i,j) - tmp) / (sqrtT.coeff(i,i) + sqrtT.coeff(j,j));
158 }
159 
160 // similar to compute1x1offDiagonalBlock()
161 template <typename MatrixType>
162 void MatrixSquareRootQuasiTriangular<MatrixType>
163  ::compute1x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
164  typename MatrixType::Index i, typename MatrixType::Index j)
165 {
166  Matrix<Scalar,1,2> rhs = T.template block<1,2>(i,j);
167  if (j-i > 1)
168  rhs -= sqrtT.block(i, i+1, 1, j-i-1) * sqrtT.block(i+1, j, j-i-1, 2);
169  Matrix<Scalar,2,2> A = sqrtT.coeff(i,i) * Matrix<Scalar,2,2>::Identity();
170  A += sqrtT.template block<2,2>(j,j).transpose();
171  sqrtT.template block<1,2>(i,j).transpose() = A.fullPivLu().solve(rhs.transpose());
172 }
173 
174 // similar to compute1x1offDiagonalBlock()
175 template <typename MatrixType>
176 void MatrixSquareRootQuasiTriangular<MatrixType>
177  ::compute2x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
178  typename MatrixType::Index i, typename MatrixType::Index j)
179 {
180  Matrix<Scalar,2,1> rhs = T.template block<2,1>(i,j);
181  if (j-i > 2)
182  rhs -= sqrtT.block(i, i+2, 2, j-i-2) * sqrtT.block(i+2, j, j-i-2, 1);
183  Matrix<Scalar,2,2> A = sqrtT.coeff(j,j) * Matrix<Scalar,2,2>::Identity();
184  A += sqrtT.template block<2,2>(i,i);
185  sqrtT.template block<2,1>(i,j) = A.fullPivLu().solve(rhs);
186 }
187 
188 // similar to compute1x1offDiagonalBlock()
189 template <typename MatrixType>
190 void MatrixSquareRootQuasiTriangular<MatrixType>
191  ::compute2x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
192  typename MatrixType::Index i, typename MatrixType::Index j)
193 {
194  Matrix<Scalar,2,2> A = sqrtT.template block<2,2>(i,i);
195  Matrix<Scalar,2,2> B = sqrtT.template block<2,2>(j,j);
196  Matrix<Scalar,2,2> C = T.template block<2,2>(i,j);
197  if (j-i > 2)
198  C -= sqrtT.block(i, i+2, 2, j-i-2) * sqrtT.block(i+2, j, j-i-2, 2);
199  Matrix<Scalar,2,2> X;
200  solveAuxiliaryEquation(X, A, B, C);
201  sqrtT.template block<2,2>(i,j) = X;
202 }
203 
204 // solves the equation A X + X B = C where all matrices are 2-by-2
205 template <typename MatrixType>
206 template <typename SmallMatrixType>
207 void MatrixSquareRootQuasiTriangular<MatrixType>
208  ::solveAuxiliaryEquation(SmallMatrixType& X, const SmallMatrixType& A,
209  const SmallMatrixType& B, const SmallMatrixType& C)
210 {
211  EIGEN_STATIC_ASSERT((internal::is_same<SmallMatrixType, Matrix<Scalar,2,2> >::value),
212  EIGEN_INTERNAL_ERROR_PLEASE_FILE_A_BUG_REPORT);
213 
214  Matrix<Scalar,4,4> coeffMatrix = Matrix<Scalar,4,4>::Zero();
215  coeffMatrix.coeffRef(0,0) = A.coeff(0,0) + B.coeff(0,0);
216  coeffMatrix.coeffRef(1,1) = A.coeff(0,0) + B.coeff(1,1);
217  coeffMatrix.coeffRef(2,2) = A.coeff(1,1) + B.coeff(0,0);
218  coeffMatrix.coeffRef(3,3) = A.coeff(1,1) + B.coeff(1,1);
219  coeffMatrix.coeffRef(0,1) = B.coeff(1,0);
220  coeffMatrix.coeffRef(0,2) = A.coeff(0,1);
221  coeffMatrix.coeffRef(1,0) = B.coeff(0,1);
222  coeffMatrix.coeffRef(1,3) = A.coeff(0,1);
223  coeffMatrix.coeffRef(2,0) = A.coeff(1,0);
224  coeffMatrix.coeffRef(2,3) = B.coeff(1,0);
225  coeffMatrix.coeffRef(3,1) = A.coeff(1,0);
226  coeffMatrix.coeffRef(3,2) = B.coeff(0,1);
227 
228  Matrix<Scalar,4,1> rhs;
229  rhs.coeffRef(0) = C.coeff(0,0);
230  rhs.coeffRef(1) = C.coeff(0,1);
231  rhs.coeffRef(2) = C.coeff(1,0);
232  rhs.coeffRef(3) = C.coeff(1,1);
233 
234  Matrix<Scalar,4,1> result;
235  result = coeffMatrix.fullPivLu().solve(rhs);
236 
237  X.coeffRef(0,0) = result.coeff(0);
238  X.coeffRef(0,1) = result.coeff(1);
239  X.coeffRef(1,0) = result.coeff(2);
240  X.coeffRef(1,1) = result.coeff(3);
241 }
242 
243 
/** \brief Class for computing matrix square roots of upper triangular matrices.
  * \tparam MatrixType  type of the argument of the matrix square root.
  *
  * \sa MatrixSquareRoot, MatrixSquareRootQuasiTriangular
  */
template <typename MatrixType>
class MatrixSquareRootTriangular
{
  public:
    /** \brief Constructor.
      *
      * \param[in]  A  upper triangular matrix whose square root is to be computed.
      *
      * Only a reference to \p A is stored, so \p A must stay alive (and
      * unchanged) until compute() has returned.
      */
    MatrixSquareRootTriangular(const MatrixType& A)
      : m_A(A)
    {
      eigen_assert(A.rows() == A.cols());
    }

    /** \brief Compute the matrix square root.
      *
      * \param[out] result  square root of \p A, stored in its upper triangular part.
      *
      * \p result is resized to the size of \p A. Its strictly lower triangular
      * part is not written, so it should be read through an upper triangular
      * view (e.g. \c triangularView<Upper>()).
      */
    template <typename ResultType> void compute(ResultType &result);

  private:
    // Argument of the square root; referenced, not copied.
    const MatrixType& m_A;
};
279 
280 template <typename MatrixType>
281 template <typename ResultType>
283 {
284  using std::sqrt;
285 
286  // Compute square root of m_A and store it in upper triangular part of result
287  // This uses that the square root of triangular matrices can be computed directly.
288  result.resize(m_A.rows(), m_A.cols());
289  typedef typename MatrixType::Index Index;
290  for (Index i = 0; i < m_A.rows(); i++) {
291  result.coeffRef(i,i) = sqrt(m_A.coeff(i,i));
292  }
293  for (Index j = 1; j < m_A.cols(); j++) {
294  for (Index i = j-1; i >= 0; i--) {
295  typedef typename MatrixType::Scalar Scalar;
296  // if i = j-1, then segment has length 0 so tmp = 0
297  Scalar tmp = (result.row(i).segment(i+1,j-i-1) * result.col(j).segment(i+1,j-i-1)).value();
298  // denominator may be zero if original matrix is singular
299  result.coeffRef(i,j) = (m_A.coeff(i,j) - tmp) / (result.coeff(i,i) + result.coeff(j,j));
300  }
301  }
302 }
303 
304 
312 template <typename MatrixType, int IsComplex = NumTraits<typename internal::traits<MatrixType>::Scalar>::IsComplex>
314 {
315  public:
316 
324  MatrixSquareRoot(const MatrixType& A);
325 
333  template <typename ResultType> void compute(ResultType &result);
334 };
335 
336 
337 // ********** Partial specialization for real matrices **********
338 
339 template <typename MatrixType>
340 class MatrixSquareRoot<MatrixType, 0>
341 {
342  public:
343 
344  MatrixSquareRoot(const MatrixType& A)
345  : m_A(A)
346  {
347  eigen_assert(A.rows() == A.cols());
348  }
349 
350  template <typename ResultType> void compute(ResultType &result)
351  {
352  // Compute Schur decomposition of m_A
353  const RealSchur<MatrixType> schurOfA(m_A);
354  const MatrixType& T = schurOfA.matrixT();
355  const MatrixType& U = schurOfA.matrixU();
356 
357  // Compute square root of T
358  MatrixType sqrtT = MatrixType::Zero(m_A.rows(), m_A.cols());
359  MatrixSquareRootQuasiTriangular<MatrixType>(T).compute(sqrtT);
360 
361  // Compute square root of m_A
362  result = U * sqrtT * U.adjoint();
363  }
364 
365  private:
366  const MatrixType& m_A;
367 };
368 
369 
370 // ********** Partial specialization for complex matrices **********
371 
372 template <typename MatrixType>
373 class MatrixSquareRoot<MatrixType, 1>
374 {
375  public:
376 
377  MatrixSquareRoot(const MatrixType& A)
378  : m_A(A)
379  {
380  eigen_assert(A.rows() == A.cols());
381  }
382 
383  template <typename ResultType> void compute(ResultType &result)
384  {
385  // Compute Schur decomposition of m_A
386  const ComplexSchur<MatrixType> schurOfA(m_A);
387  const MatrixType& T = schurOfA.matrixT();
388  const MatrixType& U = schurOfA.matrixU();
389 
390  // Compute square root of T
391  MatrixType sqrtT;
392  MatrixSquareRootTriangular<MatrixType>(T).compute(sqrtT);
393 
394  // Compute square root of m_A
395  result = U * (sqrtT.template triangularView<Upper>() * U.adjoint());
396  }
397 
398  private:
399  const MatrixType& m_A;
400 };
401 
402 
415 template<typename Derived> class MatrixSquareRootReturnValue
416 : public ReturnByValue<MatrixSquareRootReturnValue<Derived> >
417 {
418  typedef typename Derived::Index Index;
419  public:
425  MatrixSquareRootReturnValue(const Derived& src) : m_src(src) { }
426 
432  template <typename ResultType>
433  inline void evalTo(ResultType& result) const
434  {
435  const typename Derived::PlainObject srcEvaluated = m_src.eval();
437  me.compute(result);
438  }
439 
440  Index rows() const { return m_src.rows(); }
441  Index cols() const { return m_src.cols(); }
442 
443  protected:
444  const Derived& m_src;
445  private:
447 };
448 
449 namespace internal {
450 template<typename Derived>
451 struct traits<MatrixSquareRootReturnValue<Derived> >
452 {
453  typedef typename Derived::PlainObject ReturnType;
454 };
455 }
456 
457 template <typename Derived>
458 const MatrixSquareRootReturnValue<Derived> MatrixBase<Derived>::sqrt() const
459 {
460  eigen_assert(rows() == cols());
461  return MatrixSquareRootReturnValue<Derived>(derived());
462 }
463 
464 } // end namespace Eigen
465 
#endif // EIGEN_MATRIX_SQUARE_ROOT
void compute(ResultType &result)
Compute the matrix square root.
Definition: MatrixSquareRoot.h:80
MatrixSquareRoot(const MatrixType &A)
Constructor.
Proxy for the matrix square root of some matrix (expression).
Definition: MatrixSquareRoot.h:415
void compute(ResultType &result)
Compute the matrix square root.
void compute(ResultType &result)
Compute the matrix square root.
Definition: MatrixSquareRoot.h:282
Class for computing matrix square roots of upper quasi-triangular matrices.
Definition: MatrixSquareRoot.h:27
void evalTo(ResultType &result) const
Compute the matrix square root.
Definition: MatrixSquareRoot.h:433
Class for computing matrix square roots of upper triangular matrices.
Definition: MatrixSquareRoot.h:256
MatrixSquareRootQuasiTriangular(const MatrixType &A)
Constructor.
Definition: MatrixSquareRoot.h:39
Class for computing matrix square roots of general matrices.
Definition: MatrixSquareRoot.h:313
MatrixSquareRootReturnValue(const Derived &src)
Constructor.
Definition: MatrixSquareRoot.h:425