GNU Octave  3.8.0
A high-level interpreted language, primarily intended for numerical computations, mostly compatible with Matlab
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Pages
kron.cc
Go to the documentation of this file.
1 /*
2 
3 Copyright (C) 2002-2013 John W. Eaton
4 
5 This file is part of Octave.
6 
7 Octave is free software; you can redistribute it and/or modify it
8 under the terms of the GNU General Public License as published by the
9 Free Software Foundation; either version 3 of the License, or (at your
10 option) any later version.
11 
12 Octave is distributed in the hope that it will be useful, but WITHOUT
13 ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
14 FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
15 for more details.
16 
17 You should have received a copy of the GNU General Public License
18 along with Octave; see the file COPYING. If not, see
19 <http://www.gnu.org/licenses/>.
20 
21 */
22 
23 // Author: Paul Kienzle <[email protected]>
24 
25 #ifdef HAVE_CONFIG_H
26 #include <config.h>
27 #endif
28 
29 #include "dMatrix.h"
30 #include "fMatrix.h"
31 #include "CMatrix.h"
32 #include "fCMatrix.h"
33 
34 #include "dSparse.h"
35 #include "CSparse.h"
36 
37 #include "dDiagMatrix.h"
38 #include "fDiagMatrix.h"
39 #include "CDiagMatrix.h"
40 #include "fCDiagMatrix.h"
41 
42 #include "PermMatrix.h"
43 
44 #include "mx-inlines.cc"
45 #include "quit.h"
46 
47 #include "defun.h"
48 #include "error.h"
49 #include "oct-obj.h"
50 
51 template <class R, class T>
52 static MArray<T>
53 kron (const MArray<R>& a, const MArray<T>& b)
54 {
55  assert (a.ndims () == 2);
56  assert (b.ndims () == 2);
57 
58  octave_idx_type nra = a.rows (), nrb = b.rows ();
59  octave_idx_type nca = a.cols (), ncb = b.cols ();
60 
61  MArray<T> c (dim_vector (nra*nrb, nca*ncb));
62  T *cv = c.fortran_vec ();
63 
64  for (octave_idx_type ja = 0; ja < nca; ja++)
65  for (octave_idx_type jb = 0; jb < ncb; jb++)
66  for (octave_idx_type ia = 0; ia < nra; ia++)
67  {
68  octave_quit ();
69  mx_inline_mul (nrb, cv, a(ia, ja), b.data () + nrb*jb);
70  cv += nrb;
71  }
72 
73  return c;
74 }
75 
76 template <class R, class T>
77 static MArray<T>
78 kron (const MDiagArray2<R>& a, const MArray<T>& b)
79 {
80  assert (b.ndims () == 2);
81 
82  octave_idx_type nra = a.rows (), nrb = b.rows (), dla = a.diag_length ();
83  octave_idx_type nca = a.cols (), ncb = b.cols ();
84 
85  MArray<T> c (dim_vector (nra*nrb, nca*ncb), T ());
86 
87  for (octave_idx_type ja = 0; ja < dla; ja++)
88  for (octave_idx_type jb = 0; jb < ncb; jb++)
89  {
90  octave_quit ();
91  mx_inline_mul (nrb, &c.xelem (ja*nrb, ja*ncb + jb), a.dgelem (ja),
92  b.data () + nrb*jb);
93  }
94 
95  return c;
96 }
97 
98 template <class T>
99 static MSparse<T>
100 kron (const MSparse<T>& A, const MSparse<T>& B)
101 {
102  octave_idx_type idx = 0;
103  MSparse<T> C (A.rows () * B.rows (), A.columns () * B.columns (),
104  A.nnz () * B.nnz ());
105 
106  C.cidx (0) = 0;
107 
108  for (octave_idx_type Aj = 0; Aj < A.columns (); Aj++)
109  for (octave_idx_type Bj = 0; Bj < B.columns (); Bj++)
110  {
111  octave_quit ();
112  for (octave_idx_type Ai = A.cidx (Aj); Ai < A.cidx (Aj+1); Ai++)
113  {
114  octave_idx_type Ci = A.ridx (Ai) * B.rows ();
115  const T v = A.data (Ai);
116 
117  for (octave_idx_type Bi = B.cidx (Bj); Bi < B.cidx (Bj+1); Bi++)
118  {
119  C.data (idx) = v * B.data (Bi);
120  C.ridx (idx++) = Ci + B.ridx (Bi);
121  }
122  }
123  C.cidx (Aj * B.columns () + Bj + 1) = idx;
124  }
125 
126  return C;
127 }
128 
129 static PermMatrix
130 kron (const PermMatrix& a, const PermMatrix& b)
131 {
132  octave_idx_type na = a.rows (), nb = b.rows ();
133  const octave_idx_type *pa = a.data (), *pb = b.data ();
134  PermMatrix c(na*nb); // Row permutation.
135  octave_idx_type *pc = c.fortran_vec ();
136 
137  bool cola = a.is_col_perm (), colb = b.is_col_perm ();
138  if (cola && colb)
139  {
140  for (octave_idx_type i = 0; i < na; i++)
141  for (octave_idx_type j = 0; j < nb; j++)
142  pc[pa[i]*nb+pb[j]] = i*nb+j;
143  }
144  else if (cola)
145  {
146  for (octave_idx_type i = 0; i < na; i++)
147  for (octave_idx_type j = 0; j < nb; j++)
148  pc[pa[i]*nb+j] = i*nb+pb[j];
149  }
150  else if (colb)
151  {
152  for (octave_idx_type i = 0; i < na; i++)
153  for (octave_idx_type j = 0; j < nb; j++)
154  pc[i*nb+pb[j]] = pa[i]*nb+j;
155  }
156  else
157  {
158  for (octave_idx_type i = 0; i < na; i++)
159  for (octave_idx_type j = 0; j < nb; j++)
160  pc[i*nb+j] = pa[i]*nb+pb[j];
161  }
162 
163  return c;
164 }
165 
166 template <class MTA, class MTB>
168 do_kron (const octave_value& a, const octave_value& b)
169 {
170  MTA am = octave_value_extract<MTA> (a);
171  MTB bm = octave_value_extract<MTB> (b);
172  return octave_value (kron (am, bm));
173 }
174 
177 {
178  octave_value retval;
179  if (a.is_perm_matrix () && b.is_perm_matrix ())
180  retval = do_kron<PermMatrix, PermMatrix> (a, b);
181  else if (a.is_sparse_type () || b.is_sparse_type ())
182  {
183  if (a.is_complex_type () || b.is_complex_type ())
184  retval = do_kron<SparseComplexMatrix, SparseComplexMatrix> (a, b);
185  else
186  retval = do_kron<SparseMatrix, SparseMatrix> (a, b);
187  }
188  else if (a.is_diag_matrix ())
189  {
190  if (b.is_diag_matrix () && a.rows () == a.columns ()
191  && b.rows () == b.columns ())
192  {
193  // We have two diagonal matrices, the product of those will be
194  // another diagonal matrix. To do that efficiently, extract
195  // the diagonals as vectors and compute the product. That
196  // will be another vector, which we then use to construct a
197  // diagonal matrix object. Note that this will fail if our
198  // digaonal matrix object is modified to allow the non-zero
199  // values to be stored off of the principal diagonal (i.e., if
200  // diag ([1,2], 3) is modified to return a diagonal matrix
201  // object instead of a full matrix object).
202 
203  octave_value tmp = dispatch_kron (a.diag (), b.diag ());
204  retval = tmp.diag ();
205  }
206  else if (a.is_single_type () || b.is_single_type ())
207  {
208  if (a.is_complex_type ())
209  retval = do_kron<FloatComplexDiagMatrix, FloatComplexMatrix> (a, b);
210  else if (b.is_complex_type ())
211  retval = do_kron<FloatDiagMatrix, FloatComplexMatrix> (a, b);
212  else
213  retval = do_kron<FloatDiagMatrix, FloatMatrix> (a, b);
214  }
215  else
216  {
217  if (a.is_complex_type ())
218  retval = do_kron<ComplexDiagMatrix, ComplexMatrix> (a, b);
219  else if (b.is_complex_type ())
220  retval = do_kron<DiagMatrix, ComplexMatrix> (a, b);
221  else
222  retval = do_kron<DiagMatrix, Matrix> (a, b);
223  }
224  }
225  else if (a.is_single_type () || b.is_single_type ())
226  {
227  if (a.is_complex_type ())
228  retval = do_kron<FloatComplexMatrix, FloatComplexMatrix> (a, b);
229  else if (b.is_complex_type ())
230  retval = do_kron<FloatMatrix, FloatComplexMatrix> (a, b);
231  else
232  retval = do_kron<FloatMatrix, FloatMatrix> (a, b);
233  }
234  else
235  {
236  if (a.is_complex_type ())
237  retval = do_kron<ComplexMatrix, ComplexMatrix> (a, b);
238  else if (b.is_complex_type ())
239  retval = do_kron<Matrix, ComplexMatrix> (a, b);
240  else
241  retval = do_kron<Matrix, Matrix> (a, b);
242  }
243  return retval;
244 }
245 
246 
247 DEFUN (kron, args, , "-*- texinfo -*-\n\
248 @deftypefn {Built-in Function} {} kron (@var{A}, @var{B})\n\
249 @deftypefnx {Built-in Function} {} kron (@var{A1}, @var{A2}, @dots{})\n\
250 Form the Kronecker product of two or more matrices, defined block by \n\
251 block as\n\
252 \n\
253 @example\n\
254 x = [ a(i,j)*b ]\n\
255 @end example\n\
256 \n\
257 For example:\n\
258 \n\
259 @example\n\
260 @group\n\
261 kron (1:4, ones (3, 1))\n\
262  @result{} 1 2 3 4\n\
263  1 2 3 4\n\
264  1 2 3 4\n\
265 @end group\n\
266 @end example\n\
267 \n\
268 If there are more than two input arguments @var{A1}, @var{A2}, @dots{}, \n\
269 @var{An} the Kronecker product is computed as\n\
270 \n\
271 @example\n\
272 kron (kron (@var{A1}, @var{A2}), @dots{}, @var{An})\n\
273 @end example\n\
274 \n\
275 @noindent\n\
276 Since the Kronecker product is associative, this is well-defined.\n\
277 @end deftypefn")
278 {
279  octave_value retval;
280 
281  int nargin = args.length ();
282 
283  if (nargin >= 2)
284  {
285  octave_value a = args(0), b = args(1);
286  retval = dispatch_kron (a, b);
287  for (octave_idx_type i = 2; i < nargin; i++)
288  retval = dispatch_kron (retval, args(i));
289  }
290  else
291  print_usage ();
292 
293  return retval;
294 }
295 
296 
297 /*
298 %!test
299 %! x = ones (2);
300 %! assert (kron (x, x), ones (4));
301 
302 %!shared x, y, z
303 %! x = [1, 2];
304 %! y = [-1, -2];
305 %! z = [1, 2, 3, 4; 1, 2, 3, 4; 1, 2, 3, 4];
306 %!assert (kron (1:4, ones (3, 1)), z)
307 %!assert (kron (x, y, z), kron (kron (x, y), z))
308 %!assert (kron (x, y, z), kron (x, kron (y, z)))
309 
310 %!assert (kron (diag ([1, 2]), diag ([3, 4])), diag ([3, 4, 6, 8]))
311 
312 %% Test for two diag matrices. See the comments above in
313 %% dispatch_kron for this case.
314 %%
315 %!test
316 %! expected = zeros (16, 16);
317 %! expected (1, 11) = 3;
318 %! expected (2, 12) = 4;
319 %! expected (5, 15) = 6;
320 %! expected (6, 16) = 8;
321 %! assert (kron (diag ([1, 2], 2), diag ([3, 4], 2)), expected)
322 */