Source file
src/math/big/int.go
1
2
3
4
5
6
7 package big
8
9 import (
10 "fmt"
11 "io"
12 "math/rand"
13 "strings"
14 )
15
16
17
// An Int represents a signed multi-precision integer as a sign flag plus a
// magnitude. The zero value (neg == false, abs empty) represents 0.
//
// The methods in this file keep zero unsigned: results are written with
// patterns like `z.neg = len(z.abs) > 0 && ...`, so neg is false whenever
// abs is empty.
type Int struct {
	neg bool // sign: true when the value is negative
	abs nat  // absolute value (magnitude)
}
22
23 var intOne = &Int{false, natOne}
24
25
26
27
28
29
30
31 func (x *Int) Sign() int {
32 if len(x.abs) == 0 {
33 return 0
34 }
35 if x.neg {
36 return -1
37 }
38 return 1
39 }
40
41
42 func (z *Int) SetInt64(x int64) *Int {
43 neg := false
44 if x < 0 {
45 neg = true
46 x = -x
47 }
48 z.abs = z.abs.setUint64(uint64(x))
49 z.neg = neg
50 return z
51 }
52
53
54 func (z *Int) SetUint64(x uint64) *Int {
55 z.abs = z.abs.setUint64(x)
56 z.neg = false
57 return z
58 }
59
60
61 func NewInt(x int64) *Int {
62 return new(Int).SetInt64(x)
63 }
64
65
66 func (z *Int) Set(x *Int) *Int {
67 if z != x {
68 z.abs = z.abs.set(x.abs)
69 z.neg = x.neg
70 }
71 return z
72 }
73
74
75
76
77
78
79 func (x *Int) Bits() []Word {
80 return x.abs
81 }
82
83
84
85
86
87
88 func (z *Int) SetBits(abs []Word) *Int {
89 z.abs = nat(abs).norm()
90 z.neg = false
91 return z
92 }
93
94
95 func (z *Int) Abs(x *Int) *Int {
96 z.Set(x)
97 z.neg = false
98 return z
99 }
100
101
102 func (z *Int) Neg(x *Int) *Int {
103 z.Set(x)
104 z.neg = len(z.abs) > 0 && !z.neg
105 return z
106 }
107
108
109 func (z *Int) Add(x, y *Int) *Int {
110 neg := x.neg
111 if x.neg == y.neg {
112
113
114 z.abs = z.abs.add(x.abs, y.abs)
115 } else {
116
117
118 if x.abs.cmp(y.abs) >= 0 {
119 z.abs = z.abs.sub(x.abs, y.abs)
120 } else {
121 neg = !neg
122 z.abs = z.abs.sub(y.abs, x.abs)
123 }
124 }
125 z.neg = len(z.abs) > 0 && neg
126 return z
127 }
128
129
130 func (z *Int) Sub(x, y *Int) *Int {
131 neg := x.neg
132 if x.neg != y.neg {
133
134
135 z.abs = z.abs.add(x.abs, y.abs)
136 } else {
137
138
139 if x.abs.cmp(y.abs) >= 0 {
140 z.abs = z.abs.sub(x.abs, y.abs)
141 } else {
142 neg = !neg
143 z.abs = z.abs.sub(y.abs, x.abs)
144 }
145 }
146 z.neg = len(z.abs) > 0 && neg
147 return z
148 }
149
150
151 func (z *Int) Mul(x, y *Int) *Int {
152
153
154
155
156 if x == y {
157 z.abs = z.abs.sqr(x.abs)
158 z.neg = false
159 return z
160 }
161 z.abs = z.abs.mul(x.abs, y.abs)
162 z.neg = len(z.abs) > 0 && x.neg != y.neg
163 return z
164 }
165
166
167
168
169 func (z *Int) MulRange(a, b int64) *Int {
170 switch {
171 case a > b:
172 return z.SetInt64(1)
173 case a <= 0 && b >= 0:
174 return z.SetInt64(0)
175 }
176
177
178 neg := false
179 if a < 0 {
180 neg = (b-a)&1 == 0
181 a, b = -b, -a
182 }
183
184 z.abs = z.abs.mulRange(uint64(a), uint64(b))
185 z.neg = neg
186 return z
187 }
188
189
190 func (z *Int) Binomial(n, k int64) *Int {
191
192 if n/2 < k && k <= n {
193 k = n - k
194 }
195 var a, b Int
196 a.MulRange(n-k+1, n)
197 b.MulRange(1, k)
198 return z.Quo(&a, &b)
199 }
200
201
202
203
204 func (z *Int) Quo(x, y *Int) *Int {
205 z.abs, _ = z.abs.div(nil, x.abs, y.abs)
206 z.neg = len(z.abs) > 0 && x.neg != y.neg
207 return z
208 }
209
210
211
212
213 func (z *Int) Rem(x, y *Int) *Int {
214 _, z.abs = nat(nil).div(z.abs, x.abs, y.abs)
215 z.neg = len(z.abs) > 0 && x.neg
216 return z
217 }
218
219
220
221
222
223
224
225
226
227
228
229
230
231 func (z *Int) QuoRem(x, y, r *Int) (*Int, *Int) {
232 z.abs, r.abs = z.abs.div(r.abs, x.abs, y.abs)
233 z.neg, r.neg = len(z.abs) > 0 && x.neg != y.neg, len(r.abs) > 0 && x.neg
234 return z, r
235 }
236
237
238
239
240 func (z *Int) Div(x, y *Int) *Int {
241 y_neg := y.neg
242 var r Int
243 z.QuoRem(x, y, &r)
244 if r.neg {
245 if y_neg {
246 z.Add(z, intOne)
247 } else {
248 z.Sub(z, intOne)
249 }
250 }
251 return z
252 }
253
254
255
256
257 func (z *Int) Mod(x, y *Int) *Int {
258 y0 := y
259 if z == y || alias(z.abs, y.abs) {
260 y0 = new(Int).Set(y)
261 }
262 var q Int
263 q.QuoRem(x, y, z)
264 if z.neg {
265 if y0.neg {
266 z.Sub(z, y0)
267 } else {
268 z.Add(z, y0)
269 }
270 }
271 return z
272 }
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289 func (z *Int) DivMod(x, y, m *Int) (*Int, *Int) {
290 y0 := y
291 if z == y || alias(z.abs, y.abs) {
292 y0 = new(Int).Set(y)
293 }
294 z.QuoRem(x, y, m)
295 if m.neg {
296 if y0.neg {
297 z.Add(z, intOne)
298 m.Sub(m, y0)
299 } else {
300 z.Sub(z, intOne)
301 m.Add(m, y0)
302 }
303 }
304 return z, m
305 }
306
307
308
309
310
311
312
313 func (x *Int) Cmp(y *Int) (r int) {
314
315
316
317
318 switch {
319 case x.neg == y.neg:
320 r = x.abs.cmp(y.abs)
321 if x.neg {
322 r = -r
323 }
324 case x.neg:
325 r = -1
326 default:
327 r = 1
328 }
329 return
330 }
331
332
333
334
335
336
337
338 func (x *Int) CmpAbs(y *Int) int {
339 return x.abs.cmp(y.abs)
340 }
341
342
343 func low32(x nat) uint32 {
344 if len(x) == 0 {
345 return 0
346 }
347 return uint32(x[0])
348 }
349
350
351 func low64(x nat) uint64 {
352 if len(x) == 0 {
353 return 0
354 }
355 v := uint64(x[0])
356 if _W == 32 && len(x) > 1 {
357 return uint64(x[1])<<32 | v
358 }
359 return v
360 }
361
362
363
364 func (x *Int) Int64() int64 {
365 v := int64(low64(x.abs))
366 if x.neg {
367 v = -v
368 }
369 return v
370 }
371
372
373
374 func (x *Int) Uint64() uint64 {
375 return low64(x.abs)
376 }
377
378
379 func (x *Int) IsInt64() bool {
380 if len(x.abs) <= 64/_W {
381 w := int64(low64(x.abs))
382 return w >= 0 || x.neg && w == -w
383 }
384 return false
385 }
386
387
388 func (x *Int) IsUint64() bool {
389 return !x.neg && len(x.abs) <= 64/_W
390 }
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407 func (z *Int) SetString(s string, base int) (*Int, bool) {
408 return z.setFromScanner(strings.NewReader(s), base)
409 }
410
411
412
413 func (z *Int) setFromScanner(r io.ByteScanner, base int) (*Int, bool) {
414 if _, _, err := z.scan(r, base); err != nil {
415 return nil, false
416 }
417
418 if _, err := r.ReadByte(); err != io.EOF {
419 return nil, false
420 }
421 return z, true
422 }
423
424
425
426 func (z *Int) SetBytes(buf []byte) *Int {
427 z.abs = z.abs.setBytes(buf)
428 z.neg = false
429 return z
430 }
431
432
433 func (x *Int) Bytes() []byte {
434 buf := make([]byte, len(x.abs)*_S)
435 return buf[x.abs.bytes(buf):]
436 }
437
438
439
440 func (x *Int) BitLen() int {
441 return x.abs.bitLen()
442 }
443
444
445
446
447
448
449 func (z *Int) Exp(x, y, m *Int) *Int {
450
451 var yWords nat
452 if !y.neg {
453 yWords = y.abs
454 }
455
456
457 var mWords nat
458 if m != nil {
459 mWords = m.abs
460 }
461
462 z.abs = z.abs.expNN(x.abs, yWords, mWords)
463 z.neg = len(z.abs) > 0 && x.neg && len(yWords) > 0 && yWords[0]&1 == 1
464 if z.neg && len(mWords) > 0 {
465
466 z.abs = z.abs.sub(mWords, z.abs)
467 z.neg = false
468 }
469
470 return z
471 }
472
473
474
475
476
477 func (z *Int) GCD(x, y, a, b *Int) *Int {
478 if a.Sign() <= 0 || b.Sign() <= 0 {
479 z.SetInt64(0)
480 if x != nil {
481 x.SetInt64(0)
482 }
483 if y != nil {
484 y.SetInt64(0)
485 }
486 return z
487 }
488 if x == nil && y == nil {
489 return z.lehmerGCD(a, b)
490 }
491
492 A := new(Int).Set(a)
493 B := new(Int).Set(b)
494
495 X := new(Int)
496 lastX := new(Int).SetInt64(1)
497
498 q := new(Int)
499 temp := new(Int)
500
501 r := new(Int)
502 for len(B.abs) > 0 {
503 q, r = q.QuoRem(A, B, r)
504
505 A, B, r = B, r, A
506
507 temp.Set(X)
508 X.Mul(X, q)
509 X.Sub(lastX, X)
510 lastX.Set(temp)
511 }
512
513 if x != nil {
514 *x = *lastX
515 }
516
517 if y != nil {
518
519 y.Mul(a, lastX)
520 y.Sub(A, y)
521 y.Div(y, b)
522 }
523
524 *z = *A
525 return z
526 }
527
528
529
530
531
532
533
534
// lehmerGCD sets z to the greatest common divisor of a and b and returns z.
// GCD only calls it after checking a > 0 and b > 0.
//
// It follows Lehmer's approach (Knuth, TAOCP Vol. 2, §4.5.2, Algorithm L).
// While B spans more than one Word, the leading Word of A and B is used to
// simulate several Euclidean steps in single precision. The accumulated
// cosequence coefficients are then applied to the full-precision values in
// one multiprecision update. NOTE(review): the inner stopping condition
// appears to be Collins' variant as described by Jebelean — confirm against
// the reference.
func (z *Int) lehmerGCD(a, b *Int) *Int {
	// Order the operands so that |a| >= |b|.
	if a.abs.cmp(b.abs) < 0 {
		a, b = b, a
	}

	// Work on copies so the caller's a and b are untouched. B is copied
	// before z is overwritten with a, in case b is the same Int as z.
	B := new(Int).Set(b)
	A := z.Set(a)

	// Scratch values for the multiprecision update.
	t := new(Int)
	r := new(Int)
	s := new(Int)
	w := new(Int)

	// Invariant of the loop (per the algorithm): A >= B.
	for len(B.abs) > 1 {
		// Single-precision leading digits and cosequence coefficients.
		var a1, a2, u0, u1, u2, v0, v1, v2 Word

		m := len(B.abs) // m >= 2
		n := len(A.abs) // n >= m, since A >= B

		// a1 is A's leading Word of bits, shifted left by the leading-zero
		// count h of A's top word. For h == 0 the right shift count is _W,
		// which Go defines to yield 0.
		h := nlz(A.abs[n-1])
		a1 = (A.abs[n-1] << h) | (A.abs[n-2] >> (_W - h))
		// a2 is B aligned to the same bit position. B has implicit zero high
		// words when it is shorter than A.
		switch {
		case n == m:
			a2 = (B.abs[n-1] << h) | (B.abs[n-2] >> (_W - h))
		case n == m+1:
			a2 = (B.abs[n-2] >> (_W - h))
		default:
			a2 = 0
		}

		// The cosequences are kept as unsigned Words. 'even' tracks the
		// parity of the step count, which fixes their signs:
		// even steps: u0, v1 >= 0 and u1, v0 <= 0; odd steps: the reverse.
		even := false

		u0, u1, u2 = 0, 1, 0
		v0, v1, v2 = 0, 0, 1

		// Simulate Euclidean steps on the leading digits while the stopping
		// condition holds, updating the cosequences alongside.
		for a2 >= v2 && a1-a2 >= v1+v2 {
			q := a1 / a2
			a1, a2 = a2, a1-q*a2
			u0, u1, u2 = u1, u2, u1+q*u2
			v0, v1, v2 = v1, v2, v1+q*v2
			even = !even
		}

		// v0 != 0 only after at least two simulated steps; apply them.
		if v0 != 0 {
			// A' = ±u0*A ∓ v0*B and B' = ∓u1*A ± v1*B, with the signs chosen
			// from 'even' as described above.

			t.abs = t.abs.setWord(u0)
			s.abs = s.abs.setWord(v0)
			t.neg = !even
			s.neg = even

			t.Mul(A, t)
			s.Mul(B, s)

			r.abs = r.abs.setWord(u1)
			w.abs = w.abs.setWord(v1)
			r.neg = even
			w.neg = !even

			r.Mul(A, r)
			w.Mul(B, w)

			A.Add(t, s)
			B.Add(r, w)

		} else {
			// No step could be simulated in single precision, so do one
			// full-precision Euclidean step: (A, B) = (B, A mod B).
			t.Rem(A, B)
			A, B, t = B, t, A
		}
	}

	if len(B.abs) > 0 {
		// B fits in a single Word: finish with the plain Euclidean algorithm.
		if len(A.abs) > 1 {
			// A is still multi-word; one full-precision step brings both
			// values down to a single Word.
			t.Rem(A, B)
			A, B, t = B, t, A
		}
		if len(B.abs) > 0 {
			// Both A and B are single Words: run Euclid on machine words and
			// store the result in A's only word.
			a1, a2 := A.abs[0], B.abs[0]
			for a2 != 0 {
				a1, a2 = a2, a1%a2
			}
			A.abs[0] = a1
		}
	}
	*z = *A
	return z
}
646
647
648
649
650
651 func (z *Int) Rand(rnd *rand.Rand, n *Int) *Int {
652 z.neg = false
653 if n.neg || len(n.abs) == 0 {
654 z.abs = nil
655 return z
656 }
657 z.abs = z.abs.random(rnd, n.abs, n.abs.bitLen())
658 return z
659 }
660
661
662
663 func (z *Int) ModInverse(g, n *Int) *Int {
664 if g.neg {
665
666 var g2 Int
667 g = g2.Mod(g, n)
668 }
669 var d Int
670 d.GCD(z, nil, g, n)
671
672
673
674 if z.neg {
675 z.Add(z, n)
676 }
677 return z
678 }
679
680
681
// Jacobi returns the Jacobi symbol (x/y), which is +1, -1, or 0.
// y must be odd; Jacobi panics if y is zero or even.
func Jacobi(x, y *Int) int {
	// Reject y == 0 and even y (low bit of the least-significant word clear).
	if len(y.abs) == 0 || y.abs[0]&1 == 0 {
		panic(fmt.Sprintf("big: invalid 2nd argument to Int.Jacobi: need odd integer but got %s", y))
	}

	// Work on copies so x and y are not modified:
	// a is the numerator, b the denominator, c scratch.
	var a, b, c Int
	a.Set(x)
	b.Set(y)
	j := 1

	// A negative y contributes a factor of -1 only when x is also negative;
	// from here on b holds |y|.
	if b.neg {
		if a.neg {
			j = -1
		}
		b.neg = false
	}

	for {
		// (a/1) == 1.
		if b.Cmp(intOne) == 0 {
			return j
		}
		if len(a.abs) == 0 {
			return 0
		}
		// Reduce a modulo b (Mod yields a non-negative result).
		a.Mod(&a, &b)
		if len(a.abs) == 0 {
			return 0
		}
		// a > 0 here.

		// Strip factors of two: a = 2^s * c. Each factor of 2 flips the sign
		// when b ≡ 3 or 5 (mod 8), so only an odd s matters.
		s := a.abs.trailingZeroBits()
		if s&1 != 0 {
			bmod8 := b.abs[0] & 7
			if bmod8 == 3 || bmod8 == 5 {
				j = -j
			}
		}
		c.Rsh(&a, s)

		// Reciprocity: swapping numerator and denominator flips the sign when
		// both are ≡ 3 (mod 4).
		if b.abs[0]&3 == 3 && c.abs[0]&3 == 3 {
			j = -j
		}
		a.Set(&b)
		b.Set(&c)
	}
}
734
735
736
737
738
739
740
741 func (z *Int) modSqrt3Mod4Prime(x, p *Int) *Int {
742 e := new(Int).Add(p, intOne)
743 e.Rsh(e, 2)
744 z.Exp(x, e, p)
745 return z
746 }
747
748
749
// modSqrtTonelliShanks sets z to a square root of x modulo p using the
// Tonelli-Shanks algorithm and returns z. ModSqrt calls it with 0 <= x < p
// and Jacobi(x, p) == 1. p is assumed to be an odd prime; primality is not
// checked here.
func (z *Int) modSqrtTonelliShanks(x, p *Int) *Int {
	// Write p-1 as s*2^e with s odd.
	var s Int
	s.Sub(p, intOne)
	e := s.abs.trailingZeroBits()
	s.Rsh(&s, e)

	// Find the smallest n >= 2 with Jacobi(n, p) == -1 (a non-residue).
	// NOTE(review): this loop only exits on -1. If no such n exists
	// (e.g. p is a perfect square, so every symbol is 0 or 1), it never
	// terminates.
	var n Int
	n.SetInt64(2)
	for Jacobi(&n, p) != -1 {
		n.Add(&n, intOne)
	}

	// Initial state, all mod p: y = x^((s+1)/2), b = x^s, g = n^s; r = e.
	var y, b, g, t Int
	y.Add(&s, intOne)
	y.Rsh(&y, 1)
	y.Exp(x, &y, p)
	b.Exp(x, &s, p)
	g.Exp(&n, &s, p)
	r := e
	for {
		// Find the least m with b^(2^m) == 1 (mod p) by repeated squaring.
		var m uint
		t.Set(&b)
		for t.Cmp(intOne) != 0 {
			t.Mul(&t, &t).Mod(&t, p)
			m++
		}

		// b == 1: y is the square root.
		if m == 0 {
			return z.Set(&y)
		}

		// t = g^(2^(r-m-1)) mod p. The exponent 2^(r-m-1) is built by setting
		// a single bit in 0.
		t.SetInt64(0).SetBit(&t, int(r-m-1), 1).Exp(&g, &t, p)
		// g = t^2, y = y*t, b = b*g (all mod p); then r = m.
		g.Mul(&t, &t).Mod(&g, p)
		y.Mul(&y, &t).Mod(&y, p)
		b.Mul(&b, &g).Mod(&b, p)
		r = m
	}
}
796
797
798
799
800
801 func (z *Int) ModSqrt(x, p *Int) *Int {
802 switch Jacobi(x, p) {
803 case -1:
804 return nil
805 case 0:
806 return z.SetInt64(0)
807 case 1:
808 break
809 }
810 if x.neg || x.Cmp(p) >= 0 {
811 x = new(Int).Mod(x, p)
812 }
813
814
815 if len(p.abs) > 0 && p.abs[0]%4 == 3 {
816 return z.modSqrt3Mod4Prime(x, p)
817 }
818
819 return z.modSqrtTonelliShanks(x, p)
820 }
821
822
823 func (z *Int) Lsh(x *Int, n uint) *Int {
824 z.abs = z.abs.shl(x.abs, n)
825 z.neg = x.neg
826 return z
827 }
828
829
830 func (z *Int) Rsh(x *Int, n uint) *Int {
831 if x.neg {
832
833 t := z.abs.sub(x.abs, natOne)
834 t = t.shr(t, n)
835 z.abs = t.add(t, natOne)
836 z.neg = true
837 return z
838 }
839
840 z.abs = z.abs.shr(x.abs, n)
841 z.neg = false
842 return z
843 }
844
845
846
847 func (x *Int) Bit(i int) uint {
848 if i == 0 {
849
850 if len(x.abs) > 0 {
851 return uint(x.abs[0] & 1)
852 }
853 return 0
854 }
855 if i < 0 {
856 panic("negative bit index")
857 }
858 if x.neg {
859 t := nat(nil).sub(x.abs, natOne)
860 return t.bit(uint(i)) ^ 1
861 }
862
863 return x.abs.bit(uint(i))
864 }
865
866
867
868
869
870 func (z *Int) SetBit(x *Int, i int, b uint) *Int {
871 if i < 0 {
872 panic("negative bit index")
873 }
874 if x.neg {
875 t := z.abs.sub(x.abs, natOne)
876 t = t.setBit(t, uint(i), b^1)
877 z.abs = t.add(t, natOne)
878 z.neg = len(z.abs) > 0
879 return z
880 }
881 z.abs = z.abs.setBit(x.abs, uint(i), b)
882 z.neg = false
883 return z
884 }
885
886
887 func (z *Int) And(x, y *Int) *Int {
888 if x.neg == y.neg {
889 if x.neg {
890
891 x1 := nat(nil).sub(x.abs, natOne)
892 y1 := nat(nil).sub(y.abs, natOne)
893 z.abs = z.abs.add(z.abs.or(x1, y1), natOne)
894 z.neg = true
895 return z
896 }
897
898
899 z.abs = z.abs.and(x.abs, y.abs)
900 z.neg = false
901 return z
902 }
903
904
905 if x.neg {
906 x, y = y, x
907 }
908
909
910 y1 := nat(nil).sub(y.abs, natOne)
911 z.abs = z.abs.andNot(x.abs, y1)
912 z.neg = false
913 return z
914 }
915
916
917 func (z *Int) AndNot(x, y *Int) *Int {
918 if x.neg == y.neg {
919 if x.neg {
920
921 x1 := nat(nil).sub(x.abs, natOne)
922 y1 := nat(nil).sub(y.abs, natOne)
923 z.abs = z.abs.andNot(y1, x1)
924 z.neg = false
925 return z
926 }
927
928
929 z.abs = z.abs.andNot(x.abs, y.abs)
930 z.neg = false
931 return z
932 }
933
934 if x.neg {
935
936 x1 := nat(nil).sub(x.abs, natOne)
937 z.abs = z.abs.add(z.abs.or(x1, y.abs), natOne)
938 z.neg = true
939 return z
940 }
941
942
943 y1 := nat(nil).sub(y.abs, natOne)
944 z.abs = z.abs.and(x.abs, y1)
945 z.neg = false
946 return z
947 }
948
949
950 func (z *Int) Or(x, y *Int) *Int {
951 if x.neg == y.neg {
952 if x.neg {
953
954 x1 := nat(nil).sub(x.abs, natOne)
955 y1 := nat(nil).sub(y.abs, natOne)
956 z.abs = z.abs.add(z.abs.and(x1, y1), natOne)
957 z.neg = true
958 return z
959 }
960
961
962 z.abs = z.abs.or(x.abs, y.abs)
963 z.neg = false
964 return z
965 }
966
967
968 if x.neg {
969 x, y = y, x
970 }
971
972
973 y1 := nat(nil).sub(y.abs, natOne)
974 z.abs = z.abs.add(z.abs.andNot(y1, x.abs), natOne)
975 z.neg = true
976 return z
977 }
978
979
980 func (z *Int) Xor(x, y *Int) *Int {
981 if x.neg == y.neg {
982 if x.neg {
983
984 x1 := nat(nil).sub(x.abs, natOne)
985 y1 := nat(nil).sub(y.abs, natOne)
986 z.abs = z.abs.xor(x1, y1)
987 z.neg = false
988 return z
989 }
990
991
992 z.abs = z.abs.xor(x.abs, y.abs)
993 z.neg = false
994 return z
995 }
996
997
998 if x.neg {
999 x, y = y, x
1000 }
1001
1002
1003 y1 := nat(nil).sub(y.abs, natOne)
1004 z.abs = z.abs.add(z.abs.xor(x.abs, y1), natOne)
1005 z.neg = true
1006 return z
1007 }
1008
1009
1010 func (z *Int) Not(x *Int) *Int {
1011 if x.neg {
1012
1013 z.abs = z.abs.sub(x.abs, natOne)
1014 z.neg = false
1015 return z
1016 }
1017
1018
1019 z.abs = z.abs.add(x.abs, natOne)
1020 z.neg = true
1021 return z
1022 }
1023
1024
1025
1026 func (z *Int) Sqrt(x *Int) *Int {
1027 if x.neg {
1028 panic("square root of negative number")
1029 }
1030 z.neg = false
1031 z.abs = z.abs.sqrt(x.abs)
1032 return z
1033 }
1034
View as plain text