11 #ifndef EIGEN_GENERAL_PRODUCT_H
12 #define EIGEN_GENERAL_PRODUCT_H
// Template parameter list: two operand types plus an integer ProductType that
// defaults to internal::product_type<Lhs,Rhs>::value.
// NOTE(review): the declaration this header introduces is not part of this
// listing (presumably the GeneralProduct class template - confirm).
35 template<typename Lhs, typename Rhs, int ProductType = internal::product_type<Lhs,Rhs>::value>
/** \internal
  * Primary declaration of the product-kind selector.
  *
  * \tparam Rows   size category of the number of rows of the product
  * \tparam Cols   size category of the number of columns of the product
  * \tparam Depth  size category of the inner dimension of the product
  *
  * Only declared here: every usable combination is provided by a
  * specialization that exposes the selected product kind as the enumerator
  * \c ret.
  */
template<int Rows, int Cols, int Depth> struct product_type_selector;
// Classifies one dimension of a product from its compile-time size (Size)
// and compile-time maximum size (MaxSize).
47 template<
int Size,
int MaxSize>
struct product_size_category
// is_large holds when the maximum size is Dynamic or when Size reaches
// EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD.
49 enum { is_large = MaxSize ==
Dynamic ||
50 Size >= EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD,
// value is Large when is_large holds; the remaining alternatives of this
// conditional are not part of this listing.
51 value = is_large ? Large
// Derives, for the product Lhs * Rhs, the size category of its rows, columns
// and depth, and names the product_type_selector instantiation for them.
57 template<
typename Lhs,
typename Rhs>
struct product_type
// Operand types with remove_all applied.
59 typedef typename remove_all<Lhs>::type _Lhs;
60 typedef typename remove_all<Rhs>::type _Rhs;
// Rows come from the left operand, columns from the right operand.
62 MaxRows = _Lhs::MaxRowsAtCompileTime,
63 Rows = _Lhs::RowsAtCompileTime,
64 MaxCols = _Rhs::MaxColsAtCompileTime,
65 Cols = _Rhs::ColsAtCompileTime,
// Depth combines the columns of the left operand with the rows of the right
// operand through EIGEN_SIZE_MIN_PREFER_FIXED.
66 MaxDepth = EIGEN_SIZE_MIN_PREFER_FIXED(_Lhs::MaxColsAtCompileTime,
67 _Rhs::MaxRowsAtCompileTime),
68 Depth = EIGEN_SIZE_MIN_PREFER_FIXED(_Lhs::ColsAtCompileTime,
69 _Rhs::RowsAtCompileTime),
70 LargeThreshold = EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD
// Size category of each dimension, computed by product_size_category.
77 rows_select = product_size_category<Rows,MaxRows>::value,
78 cols_select = product_size_category<Cols,MaxCols>::value,
79 depth_select = product_size_category<Depth,MaxDepth>::value
// Selector instantiated with the three categories above.
81 typedef product_type_selector<rows_select, cols_select, depth_select> selector;
// Debug output, only compiled when EIGEN_DEBUG_PRODUCT is defined.
// NOTE(review): the definition of `value` printed below is not part of this
// listing.
87 #ifdef EIGEN_DEBUG_PRODUCT
90 EIGEN_DEBUG_VAR(Rows);
91 EIGEN_DEBUG_VAR(Cols);
92 EIGEN_DEBUG_VAR(Depth);
93 EIGEN_DEBUG_VAR(rows_select);
94 EIGEN_DEBUG_VAR(cols_select);
95 EIGEN_DEBUG_VAR(depth_select);
96 EIGEN_DEBUG_VAR(value);
106 template<
int M,
int N>
struct product_type_selector<M,N,1> {
enum { ret = OuterProduct }; };
107 template<
int Depth>
struct product_type_selector<1, 1, Depth> {
enum { ret = InnerProduct }; };
108 template<>
struct product_type_selector<1, 1, 1> {
enum { ret = InnerProduct }; };
109 template<>
struct product_type_selector<Small,1, Small> {
enum { ret = CoeffBasedProductMode }; };
110 template<>
struct product_type_selector<1, Small,Small> {
enum { ret = CoeffBasedProductMode }; };
111 template<>
struct product_type_selector<Small,Small,Small> {
enum { ret = CoeffBasedProductMode }; };
112 template<>
struct product_type_selector<Small, Small, 1> {
enum { ret = LazyCoeffBasedProductMode }; };
113 template<>
struct product_type_selector<Small, Large, 1> {
enum { ret = LazyCoeffBasedProductMode }; };
114 template<>
struct product_type_selector<Large, Small, 1> {
enum { ret = LazyCoeffBasedProductMode }; };
115 template<>
struct product_type_selector<1, Large,Small> {
enum { ret = CoeffBasedProductMode }; };
116 template<>
struct product_type_selector<1, Large,Large> {
enum { ret = GemvProduct }; };
117 template<>
struct product_type_selector<1, Small,Large> {
enum { ret = CoeffBasedProductMode }; };
118 template<>
struct product_type_selector<Large,1, Small> {
enum { ret = CoeffBasedProductMode }; };
119 template<>
struct product_type_selector<Large,1, Large> {
enum { ret = GemvProduct }; };
120 template<>
struct product_type_selector<Small,1, Large> {
enum { ret = CoeffBasedProductMode }; };
121 template<>
struct product_type_selector<Small,Small,Large> {
enum { ret = GemmProduct }; };
122 template<>
struct product_type_selector<Large,Small,Large> {
enum { ret = GemmProduct }; };
123 template<>
struct product_type_selector<Small,Large,Large> {
enum { ret = GemmProduct }; };
124 template<>
struct product_type_selector<Large,Large,Large> {
enum { ret = GemmProduct }; };
125 template<>
struct product_type_selector<Large,Small,Small> {
enum { ret = GemmProduct }; };
126 template<>
struct product_type_selector<Small,Large,Small> {
enum { ret = GemmProduct }; };
127 template<>
struct product_type_selector<Large,Large,Small> {
enum { ret = GemmProduct }; };
// Template header taking both operand types and an integer product kind.
// NOTE(review): the declaration it introduces is not part of this listing
// (presumably the primary ProductReturnType template - confirm).
148 template<
typename Lhs,
typename Rhs,
int ProductType>
// NOTE(review): the struct header of this definition is not part of this
// listing; judging from `Type` below it is presumably the
// CoeffBasedProductMode specialization of ProductReturnType - confirm.
158 template<
typename Lhs,
typename Rhs>
// Nested operand types: each operand is passed through internal::nested
// together with the relevant dimension of the other operand and its own
// plain matrix type.
161 typedef typename internal::nested<Lhs, Rhs::ColsAtCompileTime, typename internal::plain_matrix_type<Lhs>::type >::type LhsNested;
162 typedef typename internal::nested<Rhs, Lhs::RowsAtCompileTime, typename internal::plain_matrix_type<Rhs>::type >::type RhsNested;
// Resulting expression: a CoeffBasedProduct flagged with both
// EvalBeforeAssigningBit and EvalBeforeNestingBit.
163 typedef CoeffBasedProduct<LhsNested, RhsNested, EvalBeforeAssigningBit | EvalBeforeNestingBit> Type;
// Return-type helper for the LazyCoeffBasedProductMode product kind.
166 template<
typename Lhs,
typename Rhs>
167 struct ProductReturnType<Lhs,Rhs,LazyCoeffBasedProductMode>
// Nested operand types, obtained through internal::nested.
169 typedef typename internal::nested<Lhs, Rhs::ColsAtCompileTime, typename internal::plain_matrix_type<Lhs>::type >::type LhsNested;
170 typedef typename internal::nested<Rhs, Lhs::RowsAtCompileTime, typename internal::plain_matrix_type<Rhs>::type >::type RhsNested;
// Resulting expression: a CoeffBasedProduct carrying only NestByRefBit
// (none of the EvalBefore* flags).
171 typedef CoeffBasedProduct<LhsNested, RhsNested, NestByRefBit> Type;
// Shorthand that inherits everything from
// ProductReturnType<Lhs,Rhs,LazyCoeffBasedProductMode>.
175 template<
typename Lhs,
typename Rhs>
176 struct LazyProductReturnType :
public ProductReturnType<Lhs,Rhs,LazyCoeffBasedProductMode>
// Traits of an inner product: identical to the traits of a 1x1 Matrix whose
// scalar is the ReturnType of scalar_product_traits for the two operand
// scalar types.
192 template<
typename Lhs,
typename Rhs>
193 struct traits<GeneralProduct<Lhs,Rhs,InnerProduct> >
194 : traits<Matrix<typename scalar_product_traits<typename Lhs::Scalar, typename Rhs::Scalar>::ReturnType,1,1> >
// Inner-product specialization of GeneralProduct. The object is itself a 1x1
// Matrix holding the result, and inherits internal::no_assignment_operator.
199 template<
typename Lhs,
typename Rhs>
200 class GeneralProduct<Lhs, Rhs, InnerProduct>
201 : internal::no_assignment_operator,
202 public Matrix<typename internal::scalar_product_traits<typename Lhs::Scalar, typename Rhs::Scalar>::ReturnType,1,1>
204 typedef Matrix<typename internal::scalar_product_traits<typename Lhs::Scalar, typename Rhs::Scalar>::ReturnType,1,1> Base;
// Constructor: computes the product immediately.
206 GeneralProduct(
const Lhs& lhs,
const Rhs& rhs)
// Compile-time check that both operands have the same RealScalar type.
208 EIGEN_STATIC_ASSERT((internal::is_same<typename Lhs::RealScalar, typename Rhs::RealScalar>::value),
209 YOU_MIXED_DIFFERENT_NUMERIC_TYPES__YOU_NEED_TO_USE_THE_CAST_METHOD_OF_MATRIXBASE_TO_CAST_NUMERIC_TYPES_EXPLICITLY)
// Stores sum(lhs^T .* rhs) into the single coefficient (0,0).
211 Base::coeffRef(0,0) = (lhs.transpose().cwiseProduct(rhs)).sum();
// Conversion to the scalar type: returns the stored coefficient (0,0).
215 operator const typename Base::Scalar()
const {
216 return Base::coeff(0,0);
// Outer-product kernel, overload selected by a false_type tag: walks the
// destination column by column.
//   prod  product expression providing lhs() and rhs()
//   dest  destination; its cols() bounds the loop
//   func  functor invoked as func(destination column, scaled lhs)
227 template<
typename ProductType,
typename Dest,
typename Func>
228 EIGEN_DONT_INLINE
void outer_product_selector_run(
const ProductType& prod, Dest& dest,
const Func& func,
const false_type&)
230 typedef typename Dest::Index Index;
233 const Index cols = dest.cols();
// Column j receives rhs(0,j) * lhs, combined into dest.col(j) by func.
234 for (Index j=0; j<cols; ++j)
235 func(dest.col(j), prod.rhs().coeff(0,j) * prod.lhs());
// Outer-product kernel, overload selected by a true_type tag: walks the
// destination row by row.
//   prod  product expression providing lhs() and rhs()
//   dest  destination; its rows() bounds the loop
//   func  functor invoked as func(destination row, scaled rhs)
239 template<
typename ProductType,
typename Dest,
typename Func>
240 EIGEN_DONT_INLINE
void outer_product_selector_run(
const ProductType& prod, Dest& dest,
const Func& func,
const true_type&) {
241 typedef typename Dest::Index Index;
244 const Index rows = dest.rows();
// Row i receives lhs(i,0) * rhs, combined into dest.row(i) by func.
245 for (Index i=0; i<rows; ++i)
246 func(dest.row(i), prod.lhs().coeff(i,0) * prod.rhs());
// Traits of an outer product: inherited from the traits of its ProductBase.
249 template<
typename Lhs,
typename Rhs>
250 struct traits<GeneralProduct<Lhs,Rhs,OuterProduct> >
251 : traits<ProductBase<GeneralProduct<Lhs,Rhs,OuterProduct>, Lhs, Rhs> >
// Outer-product specialization of GeneralProduct, built on ProductBase.
256 template<
typename Lhs,
typename Rhs>
257 class GeneralProduct<Lhs, Rhs, OuterProduct>
258 :
public ProductBase<GeneralProduct<Lhs,Rhs,OuterProduct>, Lhs, Rhs>
260 template<
typename T>
struct is_row_major : internal::conditional<(int(T::Flags)&RowMajorBit), internal::true_type, internal::false_type>::type {};
// Public product typedefs/interface are injected by this macro.
263 EIGEN_PRODUCT_PUBLIC_INTERFACE(GeneralProduct)
// Constructor: forwards both operands to Base.
265 GeneralProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs)
// Compile-time check that both operands have the same RealScalar type.
267 EIGEN_STATIC_ASSERT((internal::is_same<typename Lhs::RealScalar, typename Rhs::RealScalar>::value),
268 YOU_MIXED_DIFFERENT_NUMERIC_TYPES__YOU_NEED_TO_USE_THE_CAST_METHOD_OF_MATRIXBASE_TO_CAST_NUMERIC_TYPES_EXPLICITLY)
/** Functor for the outer-product kernel: overwrites \a dst with \a src.
  * \a dst is accepted by const reference and written through
  * const_cast_derived(), so expressions such as dest.col(j) / dest.row(i)
  * can be passed directly as the destination.
  */
struct set {
  template<typename Dst, typename Src>
  void operator()(const Dst& dst, const Src& src) const
  {
    dst.const_cast_derived() = src;
  }
};
/** Functor for the outer-product kernel: accumulates \a src into \a dst
  * (dst += src).
  * \a dst is accepted by const reference and written through
  * const_cast_derived(), so expressions such as dest.col(j) / dest.row(i)
  * can be passed directly as the destination.
  */
struct add {
  template<typename Dst, typename Src>
  void operator()(const Dst& dst, const Src& src) const
  {
    dst.const_cast_derived() += src;
  }
};
/** Functor for the outer-product kernel: subtracts \a src from \a dst
  * (dst -= src).
  * \a dst is accepted by const reference and written through
  * const_cast_derived(), so expressions such as dest.col(j) / dest.row(i)
  * can be passed directly as the destination.
  */
struct sub {
  template<typename Dst, typename Src>
  void operator()(const Dst& dst, const Src& src) const
  {
    dst.const_cast_derived() -= src;
  }
};
// Scaled-accumulate functor (dst += scale * src).
// NOTE(review): the struct header and the declaration of m_scale are not part
// of this listing.
// Constructor: stores the scale factor s in m_scale.
276 adds(
const Scalar& s) : m_scale(s) {}
// Adds m_scale * src to dst; dst is written through const_cast_derived().
277 template<
typename Dst,
typename Src>
void operator()(
const Dst& dst,
const Src& src)
const {
278 dst.const_cast_derived() += m_scale * src;
// Runs the outer-product kernel with the `set` functor; the traversal
// overload is chosen by is_row_major<Dest>.
282 template<
typename Dest>
283 inline void evalTo(Dest& dest)
const {
284 internal::outer_product_selector_run(*
this, dest, set(), is_row_major<Dest>());
// Runs the outer-product kernel with the `add` functor; the traversal
// overload is chosen by is_row_major<Dest>.
287 template<
typename Dest>
288 inline void addTo(Dest& dest)
const {
289 internal::outer_product_selector_run(*
this, dest, add(), is_row_major<Dest>());
// Runs the outer-product kernel with the `sub` functor; the traversal
// overload is chosen by is_row_major<Dest>.
292 template<
typename Dest>
293 inline void subTo(Dest& dest)
const {
294 internal::outer_product_selector_run(*
this, dest, sub(), is_row_major<Dest>());
// Runs the outer-product kernel with an `adds` functor constructed from
// alpha; the traversal overload is chosen by is_row_major<Dest>.
297 template<
typename Dest>
void scaleAndAddTo(Dest& dest,
const Scalar& alpha)
const
299 internal::outer_product_selector_run(*
this, dest, adds(alpha), is_row_major<Dest>());
// Traits of a matrix-vector product: inherited from the traits of its
// ProductBase.
316 template<
typename Lhs,
typename Rhs>
317 struct traits<GeneralProduct<Lhs,Rhs,GemvProduct> >
318 : traits<ProductBase<GeneralProduct<Lhs,Rhs,GemvProduct>, Lhs, Rhs> >
/** \internal
  * Forward declaration of the matrix-vector product kernel selector.
  *
  * \tparam Side            side selector; the specializations in this file use
  *                         OnTheLeft and OnTheRight
  * \tparam StorageOrder    storage order selector; the specializations in this
  *                         file use ColMajor and RowMajor
  * \tparam BlasCompatible  selects between the specializations that forward to
  *                         general_matrix_vector_product and the plain
  *                         coefficient-loop fallbacks
  *
  * Each specialization provides a static run(prod, dest, alpha).
  */
template<int Side, int StorageOrder, bool BlasCompatible>
struct gemv_selector;
// Matrix-vector specialization of GeneralProduct, built on ProductBase.
326 template<
typename Lhs,
typename Rhs>
327 class GeneralProduct<Lhs, Rhs, GemvProduct>
328 :
public ProductBase<GeneralProduct<Lhs,Rhs,GemvProduct>, Lhs, Rhs>
331 EIGEN_PRODUCT_PUBLIC_INTERFACE(GeneralProduct)
333 typedef typename Lhs::Scalar LhsScalar;
334 typedef typename Rhs::Scalar RhsScalar;
// Constructor: forwards both operands to Base.
336 GeneralProduct(const Lhs& a_lhs, const Rhs& a_rhs) : Base(a_lhs,a_rhs)
// MatrixType is _LhsNested when Side == OnTheRight, _RhsNested otherwise.
// NOTE(review): the definition of Side is not part of this listing.
343 typedef typename internal::conditional<int(Side)==OnTheRight,_LhsNested,_RhsNested>::type MatrixType;
// Accumulates alpha * (lhs * rhs) into dst via a statically selected run().
// NOTE(review): the start of the dispatch expression (the selector type and
// its leading template arguments) is not part of this listing; only its last
// argument, blas_traits<MatrixType>::HasUsableDirectAccess, is visible.
345 template<
typename Dest>
void scaleAndAddTo(Dest& dst,
const Scalar& alpha)
const
// The destination must already have the dimensions of the product.
347 eigen_assert(m_lhs.rows() == dst.rows() && m_rhs.cols() == dst.cols());
349 bool(internal::blas_traits<MatrixType>::HasUsableDirectAccess)>::run(*
this, dst, alpha);
// OnTheLeft case: handled by transposing the whole problem and delegating to
// the OnTheRight selector.
356 template<
int StorageOrder,
bool BlasCompatible>
357 struct gemv_selector<
OnTheLeft,StorageOrder,BlasCompatible>
359 template<
typename ProductType,
typename Dest>
360 static void run(
const ProductType& prod, Dest& dest,
const typename ProductType::Scalar& alpha)
// Transposed view of the destination.
362 Transpose<Dest> destT(dest);
// Builds a GemvProduct of rhs^T and lhs^T and runs the OnTheRight selector on
// it, writing into destT.
// NOTE(review): the definition of OtherStorageOrder is not part of this
// listing.
364 gemv_selector<OnTheRight,OtherStorageOrder,BlasCompatible>
365 ::run(GeneralProduct<Transpose<const typename ProductType::_RhsNested>,Transpose<const typename ProductType::_LhsNested>, GemvProduct>
366 (prod.rhs().transpose(), prod.lhs().transpose()), destT, alpha);
/** \internal
  * Primary declaration of an optional scratch buffer used by the
  * matrix-vector kernels.
  *
  * \tparam Scalar   element type of the buffer
  * \tparam Size     compile-time size
  * \tparam MaxSize  compile-time maximum size
  * \tparam Cond     selects which specialization is used
  *
  * Only declared here; each specialization provides a data() accessor.
  */
template<typename Scalar, int Size, int MaxSize, bool Cond> struct gemv_static_vector_if;
// Cond == false: data() is not expected to be called. It trips an internal
// assertion and returns a null pointer.
372 template<
typename Scalar,
int Size,
int MaxSize>
373 struct gemv_static_vector_if<Scalar,Size,MaxSize,false>
375 EIGEN_STRONG_INLINE Scalar* data() { eigen_internal_assert(
false &&
"should never be called");
return 0; }
// Cond == true with MaxSize == Dynamic: data() returns a null pointer.
378 template<
typename Scalar,
int Size>
379 struct gemv_static_vector_if<Scalar,Size,
Dynamic,true>
381 EIGEN_STRONG_INLINE Scalar* data() {
return 0; }
// Cond == true (general MaxSize): holds the buffer as an
// internal::plain_array member and exposes it through data().
384 template<
typename Scalar,
int Size,
int MaxSize>
385 struct gemv_static_vector_if<Scalar,Size,MaxSize,true>
// With EIGEN_ALIGN_STATICALLY: the array has
// EIGEN_SIZE_MIN_PREFER_FIXED(Size,MaxSize) elements and data() returns it
// directly.
387 #if EIGEN_ALIGN_STATICALLY
388 internal::plain_array<Scalar,EIGEN_SIZE_MIN_PREFER_FIXED(Size,MaxSize),0> m_data;
389 EIGEN_STRONG_INLINE Scalar* data() {
return m_data.array; }
// NOTE(review): the preprocessor directive separating the two variants is not
// part of this listing; the lines below are presumably the alternative
// branch - confirm.
394 ForceAlignment = internal::packet_traits<Scalar>::Vectorizable,
395 PacketSize = internal::packet_traits<Scalar>::size
// The array is over-allocated by PacketSize elements when ForceAlignment
// holds.
397 internal::plain_array<Scalar,EIGEN_SIZE_MIN_PREFER_FIXED(Size,MaxSize)+(ForceAlignment?PacketSize:0),0> m_data;
// When ForceAlignment holds, the returned pointer is the array address with
// its low 4 bits cleared, plus 16. The other branch of the conditional is not
// part of this listing.
398 EIGEN_STRONG_INLINE Scalar* data() {
399 return ForceAlignment
400 ?
reinterpret_cast<Scalar*
>((
reinterpret_cast<size_t>(m_data.array) & ~(
size_t(15))) + 16)
// Matrix-vector kernel that forwards to the ColMajor
// general_matrix_vector_product routine, going through a temporary
// destination when dest cannot be used directly.
// NOTE(review): the enclosing struct header is not part of this listing
// (presumably a gemv_selector specialization - confirm).
//   prod   product expression providing lhs() and rhs()
//   dest   destination vector
//   alpha  scale factor applied to the product
408 template<
typename ProductType,
typename Dest>
409 static inline void run(
const ProductType& prod, Dest& dest,
const typename ProductType::Scalar& alpha)
411 typedef typename ProductType::Index Index;
412 typedef typename ProductType::LhsScalar LhsScalar;
413 typedef typename ProductType::RhsScalar RhsScalar;
414 typedef typename ProductType::Scalar ResScalar;
415 typedef typename ProductType::RealScalar RealScalar;
416 typedef typename ProductType::ActualLhsType ActualLhsType;
417 typedef typename ProductType::ActualRhsType ActualRhsType;
418 typedef typename ProductType::LhsBlasTraits LhsBlasTraits;
419 typedef typename ProductType::RhsBlasTraits RhsBlasTraits;
// Aligned map over a dynamic-size column vector of ResScalar.
420 typedef Map<Matrix<ResScalar,Dynamic,1>,
Aligned> MappedDest;
// Operands as extracted by their blas traits.
422 ActualLhsType actualLhs = LhsBlasTraits::extract(prod.lhs());
423 ActualRhsType actualRhs = RhsBlasTraits::extract(prod.rhs());
// Fold the scalar factors extracted from both operands into alpha.
425 ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs())
426 * RhsBlasTraits::extractScalarFactor(prod.rhs());
// dest is written directly only if its inner stride is 1 at compile time;
// ComplexByReal flags a complex lhs scalar combined with a real rhs scalar.
431 EvalToDestAtCompileTime = Dest::InnerStrideAtCompileTime==1,
432 ComplexByReal = (NumTraits<LhsScalar>::IsComplex) && (!NumTraits<RhsScalar>::IsComplex),
433 MightCannotUseDest = (Dest::InnerStrideAtCompileTime!=1) || ComplexByReal
// Scratch buffer type, instantiated with Cond = MightCannotUseDest.
436 gemv_static_vector_if<ResScalar,Dest::SizeAtCompileTime,Dest::MaxSizeAtCompileTime,MightCannotUseDest> static_dest;
// alpha is compatible unless ComplexByReal holds and actualAlpha has a
// non-zero imaginary part.
438 bool alphaIsCompatible = (!ComplexByReal) || (numext::imag(actualAlpha)==RealScalar(0));
439 bool evalToDest = EvalToDestAtCompileTime && alphaIsCompatible;
441 RhsScalar compatibleAlpha = get_factor<ResScalar,RhsScalar>::run(actualAlpha);
// Declares actualDestPtr for dest.size() elements, handing the macro
// dest.data() when evalToDest holds and static_dest.data() otherwise.
// NOTE(review): what the macro does when handed a null pointer is not
// visible here.
443 ei_declare_aligned_stack_constructed_variable(ResScalar,actualDestPtr,dest.size(),
444 evalToDest ? dest.data() : static_dest.data());
// NOTE(review): the control flow enclosing the statements below (braces and
// any if/else lines) is not part of this listing.
448 #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
449 int size = dest.size();
450 EIGEN_DENSE_STORAGE_CTOR_PLUGIN
// Incompatible alpha: zero the temporary and run the kernel with factor 1.
452 if(!alphaIsCompatible)
454 MappedDest(actualDestPtr, dest.size()).setZero();
455 compatibleAlpha = RhsScalar(1);
// Copies the current content of dest into the temporary.
458 MappedDest(actualDestPtr, dest.size()) = dest;
// Kernel call; the trailing arguments are not part of this listing.
461 general_matrix_vector_product
462 <Index,LhsScalar,
ColMajor,LhsBlasTraits::NeedToConjugate,RhsScalar,RhsBlasTraits::NeedToConjugate>::run(
463 actualLhs.rows(), actualLhs.cols(),
464 actualLhs.data(), actualLhs.outerStride(),
465 actualRhs.data(), actualRhs.innerStride(),
// Write-back: with an incompatible alpha the temporary is scaled by
// actualAlpha and added to dest; the last statement copies the temporary
// into dest.
471 if(!alphaIsCompatible)
472 dest += actualAlpha * MappedDest(actualDestPtr, dest.size());
474 dest = MappedDest(actualDestPtr, dest.size());
// Matrix-vector kernel that forwards to the RowMajor
// general_matrix_vector_product routine, copying the rhs into a temporary
// when it cannot be used directly.
// NOTE(review): the enclosing struct header is not part of this listing
// (presumably a gemv_selector specialization - confirm).
//   prod   product expression providing lhs() and rhs()
//   dest   destination vector
//   alpha  scale factor applied to the product
481 template<
typename ProductType,
typename Dest>
482 static void run(
const ProductType& prod, Dest& dest,
const typename ProductType::Scalar& alpha)
484 typedef typename ProductType::LhsScalar LhsScalar;
485 typedef typename ProductType::RhsScalar RhsScalar;
486 typedef typename ProductType::Scalar ResScalar;
487 typedef typename ProductType::Index Index;
488 typedef typename ProductType::ActualLhsType ActualLhsType;
489 typedef typename ProductType::ActualRhsType ActualRhsType;
490 typedef typename ProductType::_ActualRhsType _ActualRhsType;
491 typedef typename ProductType::LhsBlasTraits LhsBlasTraits;
492 typedef typename ProductType::RhsBlasTraits RhsBlasTraits;
// Operands as extracted by their blas traits, held as const.
494 typename add_const<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(prod.lhs());
495 typename add_const<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(prod.rhs());
// Fold the scalar factors extracted from both operands into alpha.
497 ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs())
498 * RhsBlasTraits::extractScalarFactor(prod.rhs());
// The rhs is used in place only if its inner stride is 1 at compile time.
503 DirectlyUseRhs = _ActualRhsType::InnerStrideAtCompileTime==1
// Scratch buffer type, instantiated with Cond = !DirectlyUseRhs.
506 gemv_static_vector_if<RhsScalar,_ActualRhsType::SizeAtCompileTime,_ActualRhsType::MaxSizeAtCompileTime,!DirectlyUseRhs> static_rhs;
// Declares actualRhsPtr for actualRhs.size() elements, handing the macro the
// rhs data (constness cast away) when DirectlyUseRhs holds and
// static_rhs.data() otherwise.
508 ei_declare_aligned_stack_constructed_variable(RhsScalar,actualRhsPtr,actualRhs.size(),
509 DirectlyUseRhs ?
const_cast<RhsScalar*
>(actualRhs.data()) : static_rhs.data());
// NOTE(review): the condition guarding the copy below is not part of this
// listing.
513 #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
514 int size = actualRhs.size();
515 EIGEN_DENSE_STORAGE_CTOR_PLUGIN
// Copies the rhs into the buffer behind actualRhsPtr.
517 Map<typename _ActualRhsType::PlainObject>(actualRhsPtr, actualRhs.size()) = actualRhs;
// Kernel call; some of its arguments are not part of this listing.
520 general_matrix_vector_product
521 <Index,LhsScalar,
RowMajor,LhsBlasTraits::NeedToConjugate,RhsScalar,RhsBlasTraits::NeedToConjugate>::run(
522 actualLhs.rows(), actualLhs.cols(),
523 actualLhs.data(), actualLhs.outerStride(),
525 dest.data(), dest.innerStride(),
// Fallback for OnTheRight / ColMajor with BlasCompatible == false:
// accumulates the product one lhs column at a time.
530 template<>
struct gemv_selector<
OnTheRight,ColMajor,false>
532 template<
typename ProductType,
typename Dest>
533 static void run(
const ProductType& prod, Dest& dest,
const typename ProductType::Scalar& alpha)
535 typedef typename Dest::Index Index;
// One iteration per rhs row: dest += (alpha * rhs(k)) * lhs.col(k).
537 const Index size = prod.rhs().rows();
538 for(Index k=0; k<size; ++k)
539 dest += (alpha*prod.rhs().coeff(k)) * prod.lhs().col(k);
// Fallback for OnTheRight / RowMajor with BlasCompatible == false: computes
// the product one destination coefficient at a time.
543 template<>
struct gemv_selector<
OnTheRight,RowMajor,false>
545 template<
typename ProductType,
typename Dest>
546 static void run(
const ProductType& prod, Dest& dest,
const typename ProductType::Scalar& alpha)
548 typedef typename Dest::Index Index;
// One iteration per product row:
// dest(i) += alpha * sum(lhs.row(i) .* rhs^T).
550 const Index rows = prod.rows();
551 for(Index i=0; i<rows; ++i)
552 dest.coeffRef(i) += alpha * (prod.lhs().row(i).cwiseProduct(prod.rhs().transpose())).sum();
// Out-of-class member template returning
// ProductReturnType<Derived, OtherDerived>::Type.
// NOTE(review): the function name and parameter list are not part of this
// listing (presumably the matrix product operator of MatrixBase - confirm).
568 template<
typename Derived>
569 template<
typename OtherDerived>
570 inline const typename ProductReturnType<Derived, OtherDerived>::Type
// ProductIsValid: either inner dimension is Dynamic, or the columns of the
// left operand equal the rows of the right operand at compile time.
578 ProductIsValid = Derived::ColsAtCompileTime==
Dynamic
579 || OtherDerived::RowsAtCompileTime==
Dynamic
580 || int(Derived::ColsAtCompileTime)==int(OtherDerived::RowsAtCompileTime),
581 AreVectors = Derived::IsVectorAtCompileTime && OtherDerived::IsVectorAtCompileTime,
582 SameSizes = EIGEN_PREDICATE_SAME_MATRIX_SIZE(Derived,OtherDerived)
// For an invalid product, the three assertions choose the diagnostic:
// same-size vectors, same-size non-vectors, or any other size mismatch.
587 EIGEN_STATIC_ASSERT(ProductIsValid || !(AreVectors && SameSizes),
588 INVALID_VECTOR_VECTOR_PRODUCT__IF_YOU_WANTED_A_DOT_OR_COEFF_WISE_PRODUCT_YOU_MUST_USE_THE_EXPLICIT_FUNCTIONS)
589 EIGEN_STATIC_ASSERT(ProductIsValid || !(SameSizes && !AreVectors),
590 INVALID_MATRIX_PRODUCT__IF_YOU_WANTED_A_COEFF_WISE_PRODUCT_YOU_MUST_USE_THE_EXPLICIT_FUNCTION)
591 EIGEN_STATIC_ASSERT(ProductIsValid || SameSizes, INVALID_MATRIX_PRODUCT)
// Optional debug output from product_type, only with EIGEN_DEBUG_PRODUCT.
592 #ifdef EIGEN_DEBUG_PRODUCT
593 internal::product_type<Derived,OtherDerived>::debug();
// Out-of-class member template performing the same compile-time size checks
// as the product operator defined earlier in this file.
// NOTE(review): the return type, function name and parameter list are not
// part of this listing (presumably MatrixBase::lazyProduct - confirm).
609 template<
typename Derived>
610 template<
typename OtherDerived>
// ProductIsValid: either inner dimension is Dynamic, or the columns of the
// left operand equal the rows of the right operand at compile time.
615 ProductIsValid = Derived::ColsAtCompileTime==
Dynamic
616 || OtherDerived::RowsAtCompileTime==
Dynamic
617 || int(Derived::ColsAtCompileTime)==int(OtherDerived::RowsAtCompileTime),
618 AreVectors = Derived::IsVectorAtCompileTime && OtherDerived::IsVectorAtCompileTime,
619 SameSizes = EIGEN_PREDICATE_SAME_MATRIX_SIZE(Derived,OtherDerived)
// For an invalid product, the three assertions choose the diagnostic:
// same-size vectors, same-size non-vectors, or any other size mismatch.
624 EIGEN_STATIC_ASSERT(ProductIsValid || !(AreVectors && SameSizes),
625 INVALID_VECTOR_VECTOR_PRODUCT__IF_YOU_WANTED_A_DOT_OR_COEFF_WISE_PRODUCT_YOU_MUST_USE_THE_EXPLICIT_FUNCTIONS)
626 EIGEN_STATIC_ASSERT(ProductIsValid || !(SameSizes && !AreVectors),
627 INVALID_MATRIX_PRODUCT__IF_YOU_WANTED_A_COEFF_WISE_PRODUCT_YOU_MUST_USE_THE_EXPLICIT_FUNCTION)
628 EIGEN_STATIC_ASSERT(ProductIsValid || SameSizes, INVALID_MATRIX_PRODUCT)
635 #endif // EIGEN_GENERAL_PRODUCT_H
Expression of the product of two general matrices or vectors.
Definition: GeneralProduct.h:36
const int Dynamic
Definition: Constants.h:21
Definition: Constants.h:264
Definition: Constants.h:194
Definition: Constants.h:279
const ScalarMultipleReturnType operator*(const Scalar &scalar) const
Definition: MatrixBase.h:50
const LazyProductReturnType< Derived, OtherDerived >::Type lazyProduct(const MatrixBase< OtherDerived > &other) const
Definition: GeneralProduct.h:612
Definition: Constants.h:266
const unsigned int RowMajorBit
Definition: Constants.h:53
Helper class to get the correct and optimized returned type of operator*.
Definition: GeneralProduct.h:149
Definition: Constants.h:277
Base class for all dense matrices, vectors, and expressions.
Definition: MatrixBase.h:48