#ifndef EIGEN_GENERAL_MATRIX_VECTOR_H
#define EIGEN_GENERAL_MATRIX_VECTOR_H
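
/* Optimized column-major matrix * vector product.
 * This kernel processes four columns at once (columnsAtOnce below), which both
 * cuts the number of loads/stores of the result by a factor of four and
 * reduces instruction dependencies.
 */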
template<typename Index, typename LhsScalar, bool ConjugateLhs, typename RhsScalar, bool ConjugateRhs, int Version>
struct general_matrix_vector_product<Index,LhsScalar,ColMajor,ConjugateLhs,RhsScalar,ConjugateRhs,Version>
{
  typedef typename scalar_product_traits<LhsScalar, RhsScalar>::ReturnType ResScalar;

  enum {
    Vectorizable = packet_traits<LhsScalar>::Vectorizable && packet_traits<RhsScalar>::Vectorizable
                && int(packet_traits<LhsScalar>::size)==int(packet_traits<RhsScalar>::size),
    LhsPacketSize = Vectorizable ? packet_traits<LhsScalar>::size : 1,
    RhsPacketSize = Vectorizable ? packet_traits<RhsScalar>::size : 1,
    ResPacketSize = Vectorizable ? packet_traits<ResScalar>::size : 1
  };

  typedef typename packet_traits<LhsScalar>::type _LhsPacket;
  typedef typename packet_traits<RhsScalar>::type _RhsPacket;
  typedef typename packet_traits<ResScalar>::type _ResPacket;

  typedef typename conditional<Vectorizable,_LhsPacket,LhsScalar>::type LhsPacket;
  typedef typename conditional<Vectorizable,_RhsPacket,RhsScalar>::type RhsPacket;
  typedef typename conditional<Vectorizable,_ResPacket,ResScalar>::type ResPacket;
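
  // Computes res += alpha * lhs * rhs, where lhs is a column-major rows x cols
  // matrix whose columns are lhsStride apart, rhs is a vector read with stride
  // rhsIncr, and res is a dense output vector (resIncr is asserted to be 1).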
  EIGEN_DONT_INLINE static void run(
    Index rows, Index cols,
    const LhsScalar* lhs, Index lhsStride,
    const RhsScalar* rhs, Index rhsIncr,
    ResScalar* res, Index resIncr, RhsScalar alpha);
};
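
// Illustrative sketch only (not part of Eigen): assuming the kernel is reached
// through the Eigen::internal namespace (the namespace scope is not shown in
// this excerpt), a direct call computing y += alpha * A * x for a column-major
// float matrix stored in plain arrays could look like:
//
//   float* A = ...;   // rows*cols values, column-major
//   float* x = ...;   // cols values
//   float* y = ...;   // rows values
//   internal::general_matrix_vector_product<std::ptrdiff_t, float, ColMajor, false,
//                                            float, false, 0>
//     ::run(rows, cols, A, rows, x, 1, y, 1, alpha);
//
// In practice the kernel is only invoked by Eigen's product evaluators; the
// sketch merely documents the argument layout (lhsStride = rows for a full
// matrix, rhsIncr/resIncr = element strides of the two vectors).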
template<typename Index, typename LhsScalar, bool ConjugateLhs, typename RhsScalar, bool ConjugateRhs, int Version>
EIGEN_DONT_INLINE void general_matrix_vector_product<Index,LhsScalar,ColMajor,ConjugateLhs,RhsScalar,ConjugateRhs,Version>::run(
  Index rows, Index cols,
  const LhsScalar* lhs, Index lhsStride,
  const RhsScalar* rhs, Index rhsIncr,
  ResScalar* res, Index resIncr, RhsScalar alpha)
{
  EIGEN_UNUSED_VARIABLE(resIncr)
  eigen_internal_assert(resIncr==1);
  #ifdef _EIGEN_ACCUMULATE_PACKETS
  #error _EIGEN_ACCUMULATE_PACKETS has already been defined
  #endif
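
  // Accumulates one result packet from four lhs columns:
  //   res[j..j+ResPacketSize) += lhs0*ptmp0 + lhs1*ptmp1 + lhs2*ptmp2 + lhs3*ptmp3
  // The A0/A13/A2 tokens select an aligned ("d" -> pload) or unaligned
  // ("du" -> ploadu) load for lhs0, lhs1/lhs3 and lhs2 respectively.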
  #define _EIGEN_ACCUMULATE_PACKETS(A0,A13,A2) \
    pstore(&res[j], \
      padd(pload<ResPacket>(&res[j]), \
        padd( \
          padd(pcj.pmul(EIGEN_CAT(ploa , A0)<LhsPacket>(&lhs0[j]),  ptmp0), \
               pcj.pmul(EIGEN_CAT(ploa , A13)<LhsPacket>(&lhs1[j]), ptmp1)), \
          padd(pcj.pmul(EIGEN_CAT(ploa , A2)<LhsPacket>(&lhs2[j]),  ptmp2), \
               pcj.pmul(EIGEN_CAT(ploa , A13)<LhsPacket>(&lhs3[j]), ptmp3)) )))
  conj_helper<LhsScalar,RhsScalar,ConjugateLhs,ConjugateRhs> cj;
  conj_helper<LhsPacket,RhsPacket,ConjugateLhs,ConjugateRhs> pcj;
  if(ConjugateRhs)
    alpha = numext::conj(alpha);

  enum { AllAligned = 0, EvenAligned, FirstAligned, NoneAligned };
  const Index columnsAtOnce = 4;
  const Index peels = 2;
  const Index LhsPacketAlignedMask = LhsPacketSize-1;
  const Index ResPacketAlignedMask = ResPacketSize-1;
  const Index size = rows;

  // how many coefficients of the result do we have to skip to be aligned?
  // (the data is assumed to be at least aligned on the base scalar type)
  Index alignedStart = internal::first_aligned(res,size);
  Index alignedSize = ResPacketSize>1 ? alignedStart + ((size-alignedStart) & ~ResPacketAlignedMask) : 0;
  const Index peeledSize = alignedSize - RhsPacketSize*peels - RhsPacketSize + 1;
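
  // alignmentStep captures how the first aligned element within a column moves
  // from one column to the next (it depends on lhsStride modulo the packet
  // size); alignmentPattern classifies the resulting alignment of the four
  // columns processed together.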
  const Index alignmentStep = LhsPacketSize>1 ? (LhsPacketSize - lhsStride % LhsPacketSize) & LhsPacketAlignedMask : 0;
  Index alignmentPattern = alignmentStep==0 ? AllAligned
                         : alignmentStep==(LhsPacketSize/2) ? EvenAligned
                         : FirstAligned;

  // we cannot assume the first element is aligned because of sub-matrices
  const Index lhsAlignmentOffset = internal::first_aligned(lhs,size);

  // find how many columns have to be skipped to be aligned with the result (if possible)
  Index skipColumns = 0;
  // if the data cannot be aligned at all, fall back to the purely scalar path
  if( (size_t(lhs)%sizeof(LhsScalar)) || (size_t(res)%sizeof(ResScalar)) )
  {
    alignedSize = 0;
    alignedStart = 0;
  }
  else if (LhsPacketSize>1)
  {
    eigen_internal_assert(size_t(lhs+lhsAlignmentOffset)%sizeof(LhsPacket)==0 || size<LhsPacketSize);

    while (skipColumns<LhsPacketSize &&
           alignedStart != ((lhsAlignmentOffset + alignmentStep*skipColumns)%LhsPacketSize))
      ++skipColumns;
    if (skipColumns==LhsPacketSize)
    {
      // nothing can be aligned, no need to skip any column
      alignmentPattern = NoneAligned;
      skipColumns = 0;
    }
    else
      skipColumns = (std::min)(skipColumns,cols);

    eigen_internal_assert(  (alignmentPattern==NoneAligned)
                         || (skipColumns + columnsAtOnce >= cols)
                         || LhsPacketSize > size
                         || (size_t(lhs+alignedStart+lhsStride*skipColumns)%sizeof(LhsPacket))==0);
  }
  else if(Vectorizable)
  {
    alignedStart = 0;
    alignedSize = size;
    alignmentPattern = AllAligned;
  }
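
  // With the FirstAligned pattern and alignmentStep==1, columns 1 and 3 are
  // swapped (offset1/offset3) so that the fixed palign<1>/<2>/<3> shifts used
  // in the peeled loop below match each column's actual misalignment.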
  Index offset1 = (FirstAligned && alignmentStep==1?3:1);
  Index offset3 = (FirstAligned && alignmentStep==1?1:3);

  Index columnBound = ((cols-skipColumns)/columnsAtOnce)*columnsAtOnce + skipColumns;
  for (Index i=skipColumns; i<columnBound; i+=columnsAtOnce)
  {
    RhsPacket ptmp0 = pset1<RhsPacket>(alpha*rhs[i*rhsIncr]),
              ptmp1 = pset1<RhsPacket>(alpha*rhs[(i+offset1)*rhsIncr]),
              ptmp2 = pset1<RhsPacket>(alpha*rhs[(i+2)*rhsIncr]),
              ptmp3 = pset1<RhsPacket>(alpha*rhs[(i+offset3)*rhsIncr]);

    const LhsScalar *lhs0 = lhs + i*lhsStride,          *lhs1 = lhs + (i+offset1)*lhsStride,
                    *lhs2 = lhs + (i+2)*lhsStride,      *lhs3 = lhs + (i+offset3)*lhsStride;
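
    // process the initial, not-yet-aligned result coefficients one by one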
    for (Index j=0; j<alignedStart; ++j)
    {
      res[j] = cj.pmadd(lhs0[j], pfirst(ptmp0), res[j]);
      res[j] = cj.pmadd(lhs1[j], pfirst(ptmp1), res[j]);
      res[j] = cj.pmadd(lhs2[j], pfirst(ptmp2), res[j]);
      res[j] = cj.pmadd(lhs3[j], pfirst(ptmp3), res[j]);
    }
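
    // vectorized middle section: dispatch on how the four lhs columns are
    // aligned relative to the (aligned) result packets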
    if (alignedSize>alignedStart)
    {
      switch(alignmentPattern)
      {
        case AllAligned:
          for (Index j = alignedStart; j<alignedSize; j+=ResPacketSize)
            _EIGEN_ACCUMULATE_PACKETS(d,d,d);
          break;
        case EvenAligned:
          for (Index j = alignedStart; j<alignedSize; j+=ResPacketSize)
            _EIGEN_ACCUMULATE_PACKETS(d,du,d);
          break;
        case FirstAligned:
        {
          Index j = alignedStart;
          if(peels>1)
          {
            LhsPacket A00, A01, A02, A03, A10, A11, A12, A13;
            ResPacket T0, T1;

            A01 = pload<LhsPacket>(&lhs1[alignedStart-1]);
            A02 = pload<LhsPacket>(&lhs2[alignedStart-2]);
            A03 = pload<LhsPacket>(&lhs3[alignedStart-3]);
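
            // peeled loop: two result packets per iteration; lhs1..lhs3 are
            // loaded from their closest aligned addresses and shifted into
            // place with palign<k>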
            for (; j<peeledSize; j+=peels*ResPacketSize)
            {
              A11 = pload<LhsPacket>(&lhs1[j-1+LhsPacketSize]);  palign<1>(A01,A11);
              A12 = pload<LhsPacket>(&lhs2[j-2+LhsPacketSize]);  palign<2>(A02,A12);
              A13 = pload<LhsPacket>(&lhs3[j-3+LhsPacketSize]);  palign<3>(A03,A13);

              A00 = pload<LhsPacket>(&lhs0[j]);
              A10 = pload<LhsPacket>(&lhs0[j+LhsPacketSize]);
              T0  = pcj.pmadd(A00, ptmp0, pload<ResPacket>(&res[j]));
              T1  = pcj.pmadd(A10, ptmp0, pload<ResPacket>(&res[j+ResPacketSize]));

              T0  = pcj.pmadd(A01, ptmp1, T0);
              A01 = pload<LhsPacket>(&lhs1[j-1+2*LhsPacketSize]);  palign<1>(A11,A01);
              T0  = pcj.pmadd(A02, ptmp2, T0);
              A02 = pload<LhsPacket>(&lhs2[j-2+2*LhsPacketSize]);  palign<2>(A12,A02);
              T0  = pcj.pmadd(A03, ptmp3, T0);
              pstore(&res[j],T0);
              A03 = pload<LhsPacket>(&lhs3[j-3+2*LhsPacketSize]);  palign<3>(A13,A03);
              T1  = pcj.pmadd(A11, ptmp1, T1);
              T1  = pcj.pmadd(A12, ptmp2, T1);
              T1  = pcj.pmadd(A13, ptmp3, T1);
              pstore(&res[j+ResPacketSize],T1);
            }
          }
          for (; j<alignedSize; j+=ResPacketSize)
            _EIGEN_ACCUMULATE_PACKETS(d,du,du);
          break;
        }
        default:
          for (Index j = alignedStart; j<alignedSize; j+=ResPacketSize)
            _EIGEN_ACCUMULATE_PACKETS(du,du,du);
          break;
      }
    }
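
    // process the remaining coefficients (or all of them when there is no
    // explicit vectorization) with scalar operations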
    for (Index j=alignedSize; j<size; ++j)
    {
      res[j] = cj.pmadd(lhs0[j], pfirst(ptmp0), res[j]);
      res[j] = cj.pmadd(lhs1[j], pfirst(ptmp1), res[j]);
      res[j] = cj.pmadd(lhs2[j], pfirst(ptmp2), res[j]);
      res[j] = cj.pmadd(lhs3[j], pfirst(ptmp3), res[j]);
    }
  }
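
  // process the remaining columns one at a time: first those past columnBound,
  // then (second pass of the do/while) the skipColumns columns that were
  // skipped for alignment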
  Index end = cols;
  Index start = columnBound;
  do
  {
    for (Index k=start; k<end; ++k)
    {
      RhsPacket ptmp0 = pset1<RhsPacket>(alpha*rhs[k*rhsIncr]);
      const LhsScalar* lhs0 = lhs + k*lhsStride;

      // process first unaligned result coefficients
      for (Index j=0; j<alignedStart; ++j)
        res[j] += cj.pmul(lhs0[j], pfirst(ptmp0));

      // process the aligned middle section with packets
      if ((size_t(lhs0+alignedStart)%sizeof(LhsPacket))==0)
        for (Index i = alignedStart;i<alignedSize;i+=ResPacketSize)
          pstore(&res[i], pcj.pmadd(pload<LhsPacket>(&lhs0[i]), ptmp0, pload<ResPacket>(&res[i])));
      else
        for (Index i = alignedStart;i<alignedSize;i+=ResPacketSize)
          pstore(&res[i], pcj.pmadd(ploadu<LhsPacket>(&lhs0[i]), ptmp0, pload<ResPacket>(&res[i])));

      // process remaining scalars (or all if there is no explicit vectorization)
      for (Index i=alignedSize; i<size; ++i)
        res[i] += cj.pmul(lhs0[i], pfirst(ptmp0));
    }
    if (skipColumns)
    {
      start = 0;
      end = skipColumns;
      skipColumns = 0;
    }
    else
      break;
  } while(Vectorizable);

  #undef _EIGEN_ACCUMULATE_PACKETS
}
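
/* Optimized row-major matrix * vector product.
 * Four rows are processed at once (rowsAtOnce below); each row is reduced as a
 * dot product with rhs, using packet accumulators that are folded with predux.
 */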
template<typename Index, typename LhsScalar, bool ConjugateLhs, typename RhsScalar, bool ConjugateRhs, int Version>
struct general_matrix_vector_product<Index,LhsScalar,RowMajor,ConjugateLhs,RhsScalar,ConjugateRhs,Version>
{
  typedef typename scalar_product_traits<LhsScalar, RhsScalar>::ReturnType ResScalar;

  enum {
    Vectorizable = packet_traits<LhsScalar>::Vectorizable && packet_traits<RhsScalar>::Vectorizable
                && int(packet_traits<LhsScalar>::size)==int(packet_traits<RhsScalar>::size),
    LhsPacketSize = Vectorizable ? packet_traits<LhsScalar>::size : 1,
    RhsPacketSize = Vectorizable ? packet_traits<RhsScalar>::size : 1,
    ResPacketSize = Vectorizable ? packet_traits<ResScalar>::size : 1
  };

  typedef typename packet_traits<LhsScalar>::type _LhsPacket;
  typedef typename packet_traits<RhsScalar>::type _RhsPacket;
  typedef typename packet_traits<ResScalar>::type _ResPacket;

  typedef typename conditional<Vectorizable,_LhsPacket,LhsScalar>::type LhsPacket;
  typedef typename conditional<Vectorizable,_RhsPacket,RhsScalar>::type RhsPacket;
  typedef typename conditional<Vectorizable,_ResPacket,ResScalar>::type ResPacket;
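
  // Computes res[i*resIncr] += alpha * dot(row i of lhs, rhs) for each row,
  // where lhs is a row-major rows x cols matrix whose rows are lhsStride
  // apart; rhs must be a dense vector (rhsIncr is asserted to be 1).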
  EIGEN_DONT_INLINE static void run(
    Index rows, Index cols,
    const LhsScalar* lhs, Index lhsStride,
    const RhsScalar* rhs, Index rhsIncr,
    ResScalar* res, Index resIncr,
    RhsScalar alpha);
};
template<typename Index, typename LhsScalar, bool ConjugateLhs, typename RhsScalar, bool ConjugateRhs, int Version>
EIGEN_DONT_INLINE void general_matrix_vector_product<Index,LhsScalar,RowMajor,ConjugateLhs,RhsScalar,ConjugateRhs,Version>::run(
  Index rows, Index cols,
  const LhsScalar* lhs, Index lhsStride,
  const RhsScalar* rhs, Index rhsIncr,
  ResScalar* res, Index resIncr,
  RhsScalar alpha)
{
  EIGEN_UNUSED_VARIABLE(rhsIncr);
  eigen_internal_assert(rhsIncr==1);
  #ifdef _EIGEN_ACCUMULATE_PACKETS
  #error _EIGEN_ACCUMULATE_PACKETS has already been defined
  #endif
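
  // Accumulates one rhs packet into the four per-row packet accumulators:
  //   ptmpk += lhsk[j..j+LhsPacketSize) * rhs[j..j+RhsPacketSize)
  // The A0/A13/A2 tokens select an aligned ("d" -> pload) or unaligned
  // ("du" -> ploadu) load for lhs0, lhs1/lhs3 and lhs2 respectively.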
  #define _EIGEN_ACCUMULATE_PACKETS(A0,A13,A2) {\
    RhsPacket b = pload<RhsPacket>(&rhs[j]); \
    ptmp0 = pcj.pmadd(EIGEN_CAT(ploa,A0) <LhsPacket>(&lhs0[j]), b, ptmp0); \
    ptmp1 = pcj.pmadd(EIGEN_CAT(ploa,A13)<LhsPacket>(&lhs1[j]), b, ptmp1); \
    ptmp2 = pcj.pmadd(EIGEN_CAT(ploa,A2) <LhsPacket>(&lhs2[j]), b, ptmp2); \
    ptmp3 = pcj.pmadd(EIGEN_CAT(ploa,A13)<LhsPacket>(&lhs3[j]), b, ptmp3); }
  conj_helper<LhsScalar,RhsScalar,ConjugateLhs,ConjugateRhs> cj;
  conj_helper<LhsPacket,RhsPacket,ConjugateLhs,ConjugateRhs> pcj;

  enum { AllAligned=0, EvenAligned=1, FirstAligned=2, NoneAligned=3 };
  const Index rowsAtOnce = 4;
  const Index peels = 2;
  const Index RhsPacketAlignedMask = RhsPacketSize-1;
  const Index LhsPacketAlignedMask = LhsPacketSize-1;
  const Index depth = cols;

  Index alignedStart = internal::first_aligned(rhs, depth);
  Index alignedSize = RhsPacketSize>1 ? alignedStart + ((depth-alignedStart) & ~RhsPacketAlignedMask) : 0;
  const Index peeledSize = alignedSize - RhsPacketSize*peels - RhsPacketSize + 1;

  const Index alignmentStep = LhsPacketSize>1 ? (LhsPacketSize - lhsStride % LhsPacketSize) & LhsPacketAlignedMask : 0;
  Index alignmentPattern = alignmentStep==0 ? AllAligned
                         : alignmentStep==(LhsPacketSize/2) ? EvenAligned
                         : FirstAligned;

  // we cannot assume the first element is aligned because of sub-matrices
  const Index lhsAlignmentOffset = internal::first_aligned(lhs,depth);
  // find how many rows have to be skipped to be aligned with rhs (if possible)
  Index skipRows = 0;
  // if the data cannot be aligned at all, fall back to the purely scalar path
  if( (sizeof(LhsScalar)!=sizeof(RhsScalar)) ||
      (size_t(lhs)%sizeof(LhsScalar)) || (size_t(rhs)%sizeof(RhsScalar)) )
  {
    alignedSize = 0;
    alignedStart = 0;
  }
  else if (LhsPacketSize>1)
  {
    eigen_internal_assert(size_t(lhs+lhsAlignmentOffset)%sizeof(LhsPacket)==0 || depth<LhsPacketSize);

    while (skipRows<LhsPacketSize &&
           alignedStart != ((lhsAlignmentOffset + alignmentStep*skipRows)%LhsPacketSize))
      ++skipRows;
    if (skipRows==LhsPacketSize)
    {
      // nothing can be aligned, no need to skip any row
      alignmentPattern = NoneAligned;
      skipRows = 0;
    }
    else
      skipRows = (std::min)(skipRows,Index(rows));

    eigen_internal_assert(  alignmentPattern==NoneAligned
                         || (skipRows + rowsAtOnce >= rows)
                         || LhsPacketSize > depth
                         || (size_t(lhs+alignedStart+lhsStride*skipRows)%sizeof(LhsPacket))==0);
  }
  else if(Vectorizable)
  {
    alignedStart = 0;
    alignedSize = depth;
    alignmentPattern = AllAligned;
  }
  Index offset1 = (FirstAligned && alignmentStep==1?3:1);
  Index offset3 = (FirstAligned && alignmentStep==1?1:3);

  Index rowBound = ((rows-skipRows)/rowsAtOnce)*rowsAtOnce + skipRows;
  for (Index i=skipRows; i<rowBound; i+=rowsAtOnce)
  {
    EIGEN_ALIGN16 ResScalar tmp0 = ResScalar(0);
    ResScalar tmp1 = ResScalar(0), tmp2 = ResScalar(0), tmp3 = ResScalar(0);

    const LhsScalar *lhs0 = lhs + i*lhsStride,     *lhs1 = lhs + (i+offset1)*lhsStride,
                    *lhs2 = lhs + (i+2)*lhsStride, *lhs3 = lhs + (i+offset3)*lhsStride;
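
    // packet accumulators for the vectorized middle section; the scalar sums
    // tmp0..tmp3 handle the unaligned head and tail, and the packet sums are
    // folded in with predux() below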
    ResPacket ptmp0 = pset1<ResPacket>(ResScalar(0)), ptmp1 = pset1<ResPacket>(ResScalar(0)),
              ptmp2 = pset1<ResPacket>(ResScalar(0)), ptmp3 = pset1<ResPacket>(ResScalar(0));
    // process initial unaligned coefficients
    for (Index j=0; j<alignedStart; ++j)
    {
      RhsScalar b = rhs[j];
      tmp0 += cj.pmul(lhs0[j],b); tmp1 += cj.pmul(lhs1[j],b);
      tmp2 += cj.pmul(lhs2[j],b); tmp3 += cj.pmul(lhs3[j],b);
    }
    if (alignedSize>alignedStart)
    {
      switch(alignmentPattern)
      {
        case AllAligned:
          for (Index j = alignedStart; j<alignedSize; j+=RhsPacketSize)
            _EIGEN_ACCUMULATE_PACKETS(d,d,d);
          break;
        case EvenAligned:
          for (Index j = alignedStart; j<alignedSize; j+=RhsPacketSize)
            _EIGEN_ACCUMULATE_PACKETS(d,du,d);
          break;
        case FirstAligned:
        {
          Index j = alignedStart;
          if (peels>1)
          {
            LhsPacket A01, A02, A03, A11, A12, A13;
            A01 = pload<LhsPacket>(&lhs1[alignedStart-1]);
            A02 = pload<LhsPacket>(&lhs2[alignedStart-2]);
            A03 = pload<LhsPacket>(&lhs3[alignedStart-3]);
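
            // peeled loop: each iteration consumes two rhs packets and updates
            // all four row accumulators; lhs1..lhs3 are realigned with palign<k>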
            for (; j<peeledSize; j+=peels*RhsPacketSize)
            {
              RhsPacket b = pload<RhsPacket>(&rhs[j]);
              A11 = pload<LhsPacket>(&lhs1[j-1+LhsPacketSize]);  palign<1>(A01,A11);
              A12 = pload<LhsPacket>(&lhs2[j-2+LhsPacketSize]);  palign<2>(A02,A12);
              A13 = pload<LhsPacket>(&lhs3[j-3+LhsPacketSize]);  palign<3>(A03,A13);

              ptmp0 = pcj.pmadd(pload<LhsPacket>(&lhs0[j]), b, ptmp0);
              ptmp1 = pcj.pmadd(A01, b, ptmp1);
              A01 = pload<LhsPacket>(&lhs1[j-1+2*LhsPacketSize]);  palign<1>(A11,A01);
              ptmp2 = pcj.pmadd(A02, b, ptmp2);
              A02 = pload<LhsPacket>(&lhs2[j-2+2*LhsPacketSize]);  palign<2>(A12,A02);
              ptmp3 = pcj.pmadd(A03, b, ptmp3);
              A03 = pload<LhsPacket>(&lhs3[j-3+2*LhsPacketSize]);  palign<3>(A13,A03);

              b = pload<RhsPacket>(&rhs[j+RhsPacketSize]);
              ptmp0 = pcj.pmadd(pload<LhsPacket>(&lhs0[j+LhsPacketSize]), b, ptmp0);
              ptmp1 = pcj.pmadd(A11, b, ptmp1);
              ptmp2 = pcj.pmadd(A12, b, ptmp2);
              ptmp3 = pcj.pmadd(A13, b, ptmp3);
            }
          }
          for (; j<alignedSize; j+=RhsPacketSize)
            _EIGEN_ACCUMULATE_PACKETS(d,du,du);
          break;
        }
        default:
          for (Index j = alignedStart; j<alignedSize; j+=RhsPacketSize)
            _EIGEN_ACCUMULATE_PACKETS(du,du,du);
          break;
      }
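
      // horizontally reduce the packet accumulators into the scalar sums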
      tmp0 += predux(ptmp0);
      tmp1 += predux(ptmp1);
      tmp2 += predux(ptmp2);
      tmp3 += predux(ptmp3);
    }
    // process remaining coefficients (or all if there is no explicit vectorization)
    for (Index j=alignedSize; j<depth; ++j)
    {
      RhsScalar b = rhs[j];
      tmp0 += cj.pmul(lhs0[j],b); tmp1 += cj.pmul(lhs1[j],b);
      tmp2 += cj.pmul(lhs2[j],b); tmp3 += cj.pmul(lhs3[j],b);
    }
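
    // scale the four dot products by alpha and accumulate them into res,
    // honoring the output stride resIncr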
    res[i*resIncr]           += alpha*tmp0;
    res[(i+offset1)*resIncr] += alpha*tmp1;
    res[(i+2)*resIncr]       += alpha*tmp2;
    res[(i+offset3)*resIncr] += alpha*tmp3;
  }
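
  // process the remaining rows one at a time: first those past rowBound, then
  // (second pass of the do/while) the skipRows rows skipped for alignment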
  Index end = rows;
  Index start = rowBound;
  do
  {
    for (Index i=start; i<end; ++i)
    {
      EIGEN_ALIGN16 ResScalar tmp0 = ResScalar(0);
      ResPacket ptmp0 = pset1<ResPacket>(tmp0);
      const LhsScalar* lhs0 = lhs + i*lhsStride;

      // process first unaligned rhs coefficients
      for (Index j=0; j<alignedStart; ++j)
        tmp0 += cj.pmul(lhs0[j], rhs[j]);

      if (alignedSize>alignedStart)
      {
        // process the aligned rhs coefficients with packets
        if ((size_t(lhs0+alignedStart)%sizeof(LhsPacket))==0)
          for (Index j = alignedStart;j<alignedSize;j+=RhsPacketSize)
            ptmp0 = pcj.pmadd(pload<LhsPacket>(&lhs0[j]), pload<RhsPacket>(&rhs[j]), ptmp0);
        else
          for (Index j = alignedStart;j<alignedSize;j+=RhsPacketSize)
            ptmp0 = pcj.pmadd(ploadu<LhsPacket>(&lhs0[j]), pload<RhsPacket>(&rhs[j]), ptmp0);
        tmp0 += predux(ptmp0);
      }

      // process remaining scalars (or all if there is no explicit vectorization)
      for (Index j=alignedSize; j<depth; ++j)
        tmp0 += cj.pmul(lhs0[j], rhs[j]);
      res[i*resIncr] += alpha*tmp0;
    }
    if (skipRows)
    {
      start = 0;
      end = skipRows;
      skipRows = 0;
    }
    else
      break;
  } while(Vectorizable);

  #undef _EIGEN_ACCUMULATE_PACKETS
}
#endif // EIGEN_GENERAL_MATRIX_VECTOR_H