GNU Octave  4.0.0
A high-level interpreted language, primarily intended for numerical computations, mostly compatible with Matlab
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Pages
kron.cc
Go to the documentation of this file.
1 /*
2 
3 Copyright (C) 2002-2015 John W. Eaton
4 
5 This file is part of Octave.
6 
7 Octave is free software; you can redistribute it and/or modify it
8 under the terms of the GNU General Public License as published by the
9 Free Software Foundation; either version 3 of the License, or (at your
10 option) any later version.
11 
12 Octave is distributed in the hope that it will be useful, but WITHOUT
13 ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
14 FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
15 for more details.
16 
17 You should have received a copy of the GNU General Public License
18 along with Octave; see the file COPYING. If not, see
19 <http://www.gnu.org/licenses/>.
20 
21 */
22 
23 // Author: Paul Kienzle <[email protected]>
24 
25 #ifdef HAVE_CONFIG_H
26 #include <config.h>
27 #endif
28 
29 #include "dMatrix.h"
30 #include "fMatrix.h"
31 #include "CMatrix.h"
32 #include "fCMatrix.h"
33 
34 #include "dSparse.h"
35 #include "CSparse.h"
36 
37 #include "dDiagMatrix.h"
38 #include "fDiagMatrix.h"
39 #include "CDiagMatrix.h"
40 #include "fCDiagMatrix.h"
41 
42 #include "PermMatrix.h"
43 
44 #include "mx-inlines.cc"
45 #include "quit.h"
46 
47 #include "defun.h"
48 #include "error.h"
49 #include "oct-obj.h"
50 
51 template <class R, class T>
52 static MArray<T>
53 kron (const MArray<R>& a, const MArray<T>& b)
54 {
55  assert (a.ndims () == 2);
56  assert (b.ndims () == 2);
57 
58  octave_idx_type nra = a.rows ();
59  octave_idx_type nrb = b.rows ();
60  octave_idx_type nca = a.cols ();
61  octave_idx_type ncb = b.cols ();
62 
63  MArray<T> c (dim_vector (nra*nrb, nca*ncb));
64  T *cv = c.fortran_vec ();
65 
66  for (octave_idx_type ja = 0; ja < nca; ja++)
67  for (octave_idx_type jb = 0; jb < ncb; jb++)
68  for (octave_idx_type ia = 0; ia < nra; ia++)
69  {
70  octave_quit ();
71  mx_inline_mul (nrb, cv, a(ia, ja), b.data () + nrb*jb);
72  cv += nrb;
73  }
74 
75  return c;
76 }
77 
78 template <class R, class T>
79 static MArray<T>
80 kron (const MDiagArray2<R>& a, const MArray<T>& b)
81 {
82  assert (b.ndims () == 2);
83 
84  octave_idx_type nra = a.rows ();
85  octave_idx_type nrb = b.rows ();
86  octave_idx_type dla = a.diag_length ();
87  octave_idx_type nca = a.cols ();
88  octave_idx_type ncb = b.cols ();
89 
90  MArray<T> c (dim_vector (nra*nrb, nca*ncb), T ());
91 
92  for (octave_idx_type ja = 0; ja < dla; ja++)
93  for (octave_idx_type jb = 0; jb < ncb; jb++)
94  {
95  octave_quit ();
96  mx_inline_mul (nrb, &c.xelem (ja*nrb, ja*ncb + jb), a.dgelem (ja),
97  b.data () + nrb*jb);
98  }
99 
100  return c;
101 }
102 
103 template <class T>
104 static MSparse<T>
105 kron (const MSparse<T>& A, const MSparse<T>& B)
106 {
107  octave_idx_type idx = 0;
108  MSparse<T> C (A.rows () * B.rows (), A.columns () * B.columns (),
109  A.nnz () * B.nnz ());
110 
111  C.cidx (0) = 0;
112 
113  for (octave_idx_type Aj = 0; Aj < A.columns (); Aj++)
114  for (octave_idx_type Bj = 0; Bj < B.columns (); Bj++)
115  {
116  octave_quit ();
117  for (octave_idx_type Ai = A.cidx (Aj); Ai < A.cidx (Aj+1); Ai++)
118  {
119  octave_idx_type Ci = A.ridx (Ai) * B.rows ();
120  const T v = A.data (Ai);
121 
122  for (octave_idx_type Bi = B.cidx (Bj); Bi < B.cidx (Bj+1); Bi++)
123  {
124  C.data (idx) = v * B.data (Bi);
125  C.ridx (idx++) = Ci + B.ridx (Bi);
126  }
127  }
128  C.cidx (Aj * B.columns () + Bj + 1) = idx;
129  }
130 
131  return C;
132 }
133 
134 static PermMatrix
135 kron (const PermMatrix& a, const PermMatrix& b)
136 {
137  octave_idx_type na = a.rows ();
138  octave_idx_type nb = b.rows ();
139  const Array<octave_idx_type>& pa = a.col_perm_vec ();
140  const Array<octave_idx_type>& pb = b.col_perm_vec ();
141  Array<octave_idx_type> res_perm (dim_vector (na * nb, 1));
142  octave_idx_type rescol = 0;
143  for (octave_idx_type i = 0; i < na; i++)
144  {
145  octave_idx_type a_add = pa(i) * nb;
146  for (octave_idx_type j = 0; j < nb; j++)
147  res_perm.xelem (rescol++) = a_add + pb(j);
148  }
149 
150  return PermMatrix (res_perm, true);
151 }
152 
153 template <class MTA, class MTB>
155 do_kron (const octave_value& a, const octave_value& b)
156 {
157  MTA am = octave_value_extract<MTA> (a);
158  MTB bm = octave_value_extract<MTB> (b);
159  return octave_value (kron (am, bm));
160 }
161 
164 {
165  octave_value retval;
166  if (a.is_perm_matrix () && b.is_perm_matrix ())
167  retval = do_kron<PermMatrix, PermMatrix> (a, b);
168  else if (a.is_sparse_type () || b.is_sparse_type ())
169  {
170  if (a.is_complex_type () || b.is_complex_type ())
171  retval = do_kron<SparseComplexMatrix, SparseComplexMatrix> (a, b);
172  else
173  retval = do_kron<SparseMatrix, SparseMatrix> (a, b);
174  }
175  else if (a.is_diag_matrix ())
176  {
177  if (b.is_diag_matrix () && a.rows () == a.columns ()
178  && b.rows () == b.columns ())
179  {
180  // We have two diagonal matrices, the product of those will be
181  // another diagonal matrix. To do that efficiently, extract
182  // the diagonals as vectors and compute the product. That
183  // will be another vector, which we then use to construct a
184  // diagonal matrix object. Note that this will fail if our
185  // digaonal matrix object is modified to allow the nonzero
186  // values to be stored off of the principal diagonal (i.e., if
187  // diag ([1,2], 3) is modified to return a diagonal matrix
188  // object instead of a full matrix object).
189 
190  octave_value tmp = dispatch_kron (a.diag (), b.diag ());
191  retval = tmp.diag ();
192  }
193  else if (a.is_single_type () || b.is_single_type ())
194  {
195  if (a.is_complex_type ())
196  retval = do_kron<FloatComplexDiagMatrix, FloatComplexMatrix> (a, b);
197  else if (b.is_complex_type ())
198  retval = do_kron<FloatDiagMatrix, FloatComplexMatrix> (a, b);
199  else
200  retval = do_kron<FloatDiagMatrix, FloatMatrix> (a, b);
201  }
202  else
203  {
204  if (a.is_complex_type ())
205  retval = do_kron<ComplexDiagMatrix, ComplexMatrix> (a, b);
206  else if (b.is_complex_type ())
207  retval = do_kron<DiagMatrix, ComplexMatrix> (a, b);
208  else
209  retval = do_kron<DiagMatrix, Matrix> (a, b);
210  }
211  }
212  else if (a.is_single_type () || b.is_single_type ())
213  {
214  if (a.is_complex_type ())
215  retval = do_kron<FloatComplexMatrix, FloatComplexMatrix> (a, b);
216  else if (b.is_complex_type ())
217  retval = do_kron<FloatMatrix, FloatComplexMatrix> (a, b);
218  else
219  retval = do_kron<FloatMatrix, FloatMatrix> (a, b);
220  }
221  else
222  {
223  if (a.is_complex_type ())
224  retval = do_kron<ComplexMatrix, ComplexMatrix> (a, b);
225  else if (b.is_complex_type ())
226  retval = do_kron<Matrix, ComplexMatrix> (a, b);
227  else
228  retval = do_kron<Matrix, Matrix> (a, b);
229  }
230  return retval;
231 }
232 
233 
234 DEFUN (kron, args, , "-*- texinfo -*-\n\
235 @deftypefn {Built-in Function} {} kron (@var{A}, @var{B})\n\
236 @deftypefnx {Built-in Function} {} kron (@var{A1}, @var{A2}, @dots{})\n\
237 Form the Kronecker product of two or more matrices.\n\
238 \n\
239 This is defined block by block as\n\
240 \n\
241 @example\n\
242 x = [ a(i,j)*b ]\n\
243 @end example\n\
244 \n\
245 For example:\n\
246 \n\
247 @example\n\
248 @group\n\
249 kron (1:4, ones (3, 1))\n\
250  @result{} 1 2 3 4\n\
251  1 2 3 4\n\
252  1 2 3 4\n\
253 @end group\n\
254 @end example\n\
255 \n\
256 If there are more than two input arguments @var{A1}, @var{A2}, @dots{},\n\
257 @var{An} the Kronecker product is computed as\n\
258 \n\
259 @example\n\
260 kron (kron (@var{A1}, @var{A2}), @dots{}, @var{An})\n\
261 @end example\n\
262 \n\
263 @noindent\n\
264 Since the Kronecker product is associative, this is well-defined.\n\
265 @end deftypefn")
266 {
267  octave_value retval;
268 
269  int nargin = args.length ();
270 
271  if (nargin >= 2)
272  {
273  octave_value a = args(0);
274  octave_value b = args(1);
275  retval = dispatch_kron (a, b);
276  for (octave_idx_type i = 2; i < nargin; i++)
277  retval = dispatch_kron (retval, args(i));
278  }
279  else
280  print_usage ();
281 
282  return retval;
283 }
284 
285 
286 /*
287 %!test
288 %! x = ones (2);
289 %! assert (kron (x, x), ones (4));
290 
291 %!shared x, y, z, p1, p2, d1, d2
292 %! x = [1, 2];
293 %! y = [-1, -2];
294 %! z = [1, 2, 3, 4; 1, 2, 3, 4; 1, 2, 3, 4];
295 %! p1 = eye (3)([2, 3, 1], :); ## Permutation matrix
296 %! p2 = [0 1 0; 0 0 1; 1 0 0]; ## Non-permutation equivalent
297 %! d1 = diag ([1 2 3]); ## Diag type matrix
298 %! d2 = [1 0 0; 0 2 0; 0 0 3]; ## Non-diag equivalent
299 %!assert (kron (1:4, ones (3, 1)), z)
300 %!assert (kron (single (1:4), ones (3, 1)), single (z))
301 %!assert (kron (sparse (1:4), ones (3, 1)), sparse (z))
302 %!assert (kron (complex (1:4), ones (3, 1)), z)
303 %!assert (kron (complex (single(1:4)), ones (3, 1)), single(z))
304 %!assert (kron (x, y, z), kron (kron (x, y), z))
305 %!assert (kron (x, y, z), kron (x, kron (y, z)))
306 %!assert (kron (p1, p1), kron (p2, p2))
307 %!assert (kron (p1, p2), kron (p2, p1))
308 %!assert (kron (d1, d1), kron (d2, d2))
309 %!assert (kron (d1, d2), kron (d2, d1))
310 
311 
312 %!assert (kron (diag ([1, 2]), diag ([3, 4])), diag ([3, 4, 6, 8]))
313 
314 %% Test for two diag matrices. See the comments above in
315 %% dispatch_kron for this case.
316 %%
317 %!test
318 %! expected = zeros (16, 16);
319 %! expected (1, 11) = 3;
320 %! expected (2, 12) = 4;
321 %! expected (5, 15) = 6;
322 %! expected (6, 16) = 8;
323 %! assert (kron (diag ([1, 2], 2), diag ([3, 4], 2)), expected)
324 */
octave_value dispatch_kron(const octave_value &a, const octave_value &b)
Definition: kron.cc:163
T * data(void)
Definition: Sparse.h:509
octave_idx_type rows(void) const
Definition: Sparse.h:263
#define C(a, b)
Definition: Faddeeva.cc:255
octave_idx_type rows(void) const
Definition: ov.h:473
octave_idx_type rows(void) const
Definition: PermMatrix.h:55
int ndims(void) const
Definition: Array.h:487
OCTINTERP_API void print_usage(void)
Definition: defun.cc:51
octave_value diag(octave_idx_type k=0) const
Definition: ov.h:1126
F77_RET_T const octave_idx_type Complex * A
Definition: CmplxGEPBAL.cc:39
#define DEFUN(name, args_name, nargout_name, doc)
Definition: defun.h:44
bool is_perm_matrix(void) const
Definition: ov.h:559
octave_idx_type rows(void) const
Definition: DiagArray2.h:86
octave_idx_type * cidx(void)
Definition: Sparse.h:531
octave_idx_type columns(void) const
Definition: Sparse.h:265
Definition: MArray.h:36
octave_idx_type rows(void) const
Definition: Array.h:313
void mx_inline_mul(size_t n, R *r, const X *x, const Y *y)
Definition: mx-inlines.cc:84
octave_idx_type nnz(void) const
Definition: Sparse.h:248
T dgelem(octave_idx_type i) const
Definition: DiagArray2.h:121
octave_idx_type columns(void) const
Definition: ov.h:475
bool is_sparse_type(void) const
Definition: ov.h:666
const Array< octave_idx_type > & col_perm_vec(void) const
Definition: PermMatrix.h:72
const T * data(void) const
Definition: Array.h:479
bool is_complex_type(void) const
Definition: ov.h:654
octave_idx_type length(void) const
Definition: ov.cc:1525
static MArray< T > kron(const MArray< R > &a, const MArray< T > &b)
Definition: kron.cc:53
T & xelem(octave_idx_type n)
Definition: Array.h:353
octave_idx_type cols(void) const
Definition: DiagArray2.h:87
octave_idx_type * ridx(void)
Definition: Sparse.h:518
F77_RET_T const octave_idx_type Complex const octave_idx_type Complex * B
Definition: CmplxGEPBAL.cc:39
octave_value do_kron(const octave_value &a, const octave_value &b)
Definition: kron.cc:155
octave_idx_type diag_length(void) const
Definition: DiagArray2.h:90
const T * fortran_vec(void) const
Definition: Array.h:481
bool is_single_type(void) const
Definition: ov.h:611
octave_idx_type cols(void) const
Definition: Array.h:321
bool is_diag_matrix(void) const
Definition: ov.h:556
return octave_value(v1.char_array_value().concat(v2.char_array_value(), ra_idx),((a1.is_sq_string()||a2.is_sq_string())? '\'': '"'))