challenge 魔方分割数

1 .. N^2までの数をN個の数字の和が等しいN個のグループに分けたいと思います。

たとえば、N=3のときは、
(1) { 1, 5, 9 }, { 2, 6, 7 }, { 3, 4, 8 } 
(2) { 1, 6, 8 }, { 2, 4, 9 }, { 3, 5, 7 }
の2通りの方法があります。

ここで指定されたNに対して、何通りのグループ分けの方法があるかを数えるプログラムを作ってください。
(何通りかという値だけが出力されればよいのですが、予め計算してある結果を返すのはダメですよ。)
また、N=5を指定したときの実行時間もあわせて教えてください。

なお、数え上げるときの注意として、

・{ 1, 5, 9 } と { 1, 9, 5 }は同じもの

・{ 1, 5, 9 }, { 2, 6, 7 }, { 3, 4, 8 }と
 { 1, 5, 9 }, { 3, 4, 8 }, { 2, 6, 7 }は同じもの
とすることに注意してください。

Posted feedbacks - Flatten

Nested Hidden
実装がナイーブ過ぎてN=4が実行できませんでした。

平均値計算の式が↓のようになっていて、これに気づくのに三十分かかった。
ave = N*(N*2+1) / 2

枝狩りをもうちょっとがんばるかなぁ。
アルゴリズムはこんな感じ。
1.総当りでペアを出す
2.ペアをソートする
3.文字列にキャストしてsetにぶち込んでユニーク化
4.何個残ってるか?

コード書き終わってから、魔方陣の書き方って確かあったよなぁ。
と思って、Wikipedia先生に聞いてみたら二次元用であった。
これを応用すれば、うまいこと出てこないかなぁ。
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import copy
import time
def mahobunkatsu(N):
    m = set()
    for i in xrange(1,N**2+1):
        m.add(i)
    ave = N*(N**2+1) / 2
    
    def createPair(restNumber,numberSum, pair, count, pairList):
        for i in restNumber:
            if count and (count+1)%N == 0:
                if numberSum + i == ave:
                    if count+1 == N**2:
                        pair.append(i)
                        pairList.append(copy.copy(pair))
                        pair.pop()
                        continue
                else:
                    continue
            if numberSum + i > ave:
                continue
            restNumber.remove(i)
            pair.append(i)
            createPair(restNumber, (numberSum + i) % ave, pair, count + 1, pairList)
            restNumber.add(i)
            pair.pop()
        return

    pairList = []
    createPair(m, 0, [], 0, pairList)

    def listToSortedPairs(pairList):
        temp = []
        for l in pairList:
            temp.append([])
            for i in range(N):
                temp[-1].append(l[i*N:(i+1)*N])
                temp[-1][-1].sort()
                temp[-1].sort()
        return temp

    sortedList = listToSortedPairs(pairList)

    uniqueList = set()
    for x in sortedList:
        uniqueList.add(str(x))

    print uniqueList
    return uniqueList


t = time.time()
for x in range(2,6):
    n = len(mahobunkatsu(x))
    print "Size=",x,"Mahozin_num=",n,"time=",time.time()-t
    t = time.time()

投稿してから、プログラムにあまり影響しないバグを見つけた。
39行目、タブが一個多いです。
要素を一個追加するごとにソートしているので、不毛。

グループ未定の数字の最初のものpは、まだ決まってないグループ(X)に。
Xの残りの(n-1)個は、残っている数字から、和が(s/n-p)になるものとする。
最後まで行ったらカウンタに1加算。
初めに戻る。

(ComplementやSubsetsは常にソートされている)

In[3]:= f[3]
Out[3]= 2

In[4]:= f[4]
Out[4]= 392(Core2 6700で0.06秒)

In[5]:= f[5]
Out[5]= 3245664(Core2 6700で690秒)
1
2
3
4
5
6
f[n_] := f[Range[n^2], n, Total@Range[n^2]/n]

f[in_, n_, s_] :=
  If[Length@in == n, 1,
    Total[f[Complement[Rest@in, #], n, s] & /@ (
    Select[Subsets[Rest@in, {n - 1}], First@in + Total@# == s &])]]

それなりの工夫はしたつもりだけど、それでも超遅い...orz
答が合ってるのかすら心配。

Mac OS X 10.5 / PPC G5/1.6GHz mem 1GB な環境で、

% time ./numbers.native 4
n = 4 => 151 patterns
./numbers.native 4  0.03s user 0.01s system 27% cpu 0.141 total

n = 5 は 5 分くらい待っても終わらなかったのであきらめた。
方針としては、

* 各セットの合計は 1 ~ n までの合計を n で割ったものになるので、そうなる組み合わせを生成
* 先頭が 1 ~ n のものが並ぶはずなので、そこまでしか計算しない
* OCaml の Set モジュールは整列済みなので、それを利用して多少枝刈りしているつもり

といった感じ。
n = 5 が終わらないのは、メモリ使用量の問題もありそうなので、完全なセットを作ってしまわずに逐次表示するようなアプローチにすれば、もう少し何とかなるかも。

でももう頭がパンクしそうなので、とりあえずこれで。
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
module NS = Set.Make(
struct
   type t = int
   let compare = Pervasives.compare
end)

module SS = Set.Make(
struct
   type t = NS.t
   let compare = NS.compare
end)

module LS = Set.Make(
struct
   type t = SS.t
   let compare = SS.compare
end)

let make_set_with_sum n =
   let rec loop set sum = function
      | 0  -> (set, sum)
      | n' -> loop (NS.add n' set) (sum + n') (n' - 1)
   in
   loop NS.empty 0 n

let ns_map f ns = NS.fold (fun s acc -> NS.add (f s) acc) ns NS.empty
let ss_map f ss = SS.fold (fun s acc -> SS.add (f s) acc) ss SS.empty
let ls_map f ls = LS.fold (fun s acc -> LS.add (f s) acc) ls LS.empty

let rec take_subsets set num limit =
   if NS.is_empty set || NS.cardinal set < num || NS.choose set > limit
   then SS.empty
   else begin
      match num with
      | 1 ->
           if NS.mem limit set
           then SS.singleton (NS.singleton limit)
           else SS.empty
      | n when n > 1 ->
           let result = ref SS.empty in
           NS.iter begin fun i ->
              let set' = NS.remove i set in
              let num' = num - 1 in
              let limit' = limit - i in
              begin match take_subsets set' num' limit' with
              | s when SS.is_empty s -> ()
              | ss -> result := SS.union !result (ss_map (NS.add i) ss)
              end
           end set;
           !result
      | _ -> invalid_arg "num is required positive number."
   end

let make_sets n =
   if n < 2 then invalid_arg "required greater than or equal to 2.";
   let full, max = make_set_with_sum (n * n) in
   let heads, _  = make_set_with_sum n in
   let diff  = NS.diff full heads in
   let limit = max / n in
   let subsets =
      NS.fold begin fun i acc ->
         LS.add
            (ss_map (NS.add i) (take_subsets diff (n - 1) (limit - i)))
            acc
      end heads LS.empty
   in
   let inters =
      let hd = LS.choose subsets in
      let tl = LS.remove hd subsets in
      LS.fold begin fun ss acc ->
         let result = ref LS.empty in
         LS.iter begin fun ss' ->
            SS.iter begin fun ns ->
               if SS.for_all (fun s -> NS.is_empty (NS.inter s ns)) ss'
               then result := LS.add (SS.add ns ss') !result
            end ss
         end acc;
         !result
      end tl (SS.fold (fun e acc -> LS.add (SS.singleton e) acc) hd LS.empty)
   in
   inters

let ns_print ns =
   print_string "{ ";
   NS.iter (Printf.printf "%d, ") ns;
   print_string "}"
let ss_print ss =
   SS.iter begin fun s ->
      ns_print s;
      print_string ", "
   end ss;
   print_newline ()
let ls_print ls = LS.iter ss_print ls

let exam n = LS.cardinal (make_sets n)

let main () =
   let num =
      match Sys.argv with
      | [|_; n |] -> int_of_string n
      | _ -> 3
   in
   Printf.printf "n = %d => %d patterns\n" num (exam num)
let () = if not !Sys.interactive then main ()

とりあえず単純に
(1)合計が平均値になる組み合わせを生成
(2)(1)から要素が重ならない組み合わせを生成

N=4で62秒(PenM 1.7GHz)、解は392
N=5はこのやり方では無理ですね。
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
def gen_comb(list,n)
  if n==0
    yield([])
  else
    (list.size-n+1).times do |i|
      gen_comb(list[i+1..-1],n-1) do|ls|
        yield([list[i]]+ls)
      end
    end
  end
end

def comb_num(list,n,sum)
  ret=[]
  gen_comb(list,n) do |a|
    ret<<=a if (a.inject(0){|r,e| r+=e})==sum
  end
  ret
end

def comb_array(list,n)
  ret=[]
  gen_comb(list,n) do |a|
    ret<<=a if a.inject{|r,e| break unless (r&e).empty?;r+=e}
  end
  ret
end

def maho(n)
  m=n**2
  sum=m*(m+1)/2/n
  comb_array(comb_num((1..m).to_a,n,sum),n)
end

start_time = Time.now
puts maho(4).size
puts Time.now-start_time

>先頭が 1 ~ n のものが並ぶはず
確かに、そうなるかな?と思いたくなるんですが、違うみたいです:
{1, 2, 15, 16}, {3, 4, 13, 14}, {5, 6, 11, 12}, {7, 8, 9, 10}

和が N*(N^2+1)/2 になる組み合わせを昇順に列挙して交わらないものを探す。nido さんの #4819 と同じ方針かな?

5 のときは 20 分かかって答えが出ました。 3245664 だそうです。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
(* list solutions of
     x1+...+xn=m
     1<=x1<...<xn<=l *)
let list_solutions m n k =
  let rec loop m n lbd tail =
    if m < n || m < lbd then []
    else if n = 1 then
      if lbd <= m && m <= k then [m::tail] else []
    else
      let acc = ref [] in
        for i = lbd to (min (m-n+1) k) do
          acc :=
            List.map (fun sols -> i::sols)
              (loop (m-i) (n-1) (i+1) tail) :: !acc
        done;
        List.concat !acc
  in loop m n 1

let rec disjoint xs ys =
  match xs, ys with
    | [], _ | _, [] -> true
    | x::xs', y::ys' ->
        if x > y then disjoint xs ys'
        else if x < y then disjoint xs' ys
        else false

let (@<) = List.merge compare

let rec count_choices n lists ex =
  if n = 1 then
    List.fold_left (fun i x -> if disjoint x ex then i+1 else i) 0 lists
  else
    match lists with
      | [] -> 0
      | x::rest ->
          let b = count_choices n rest ex in
            if disjoint x ex then
              b + count_choices (n-1) rest (x@<ex)
            else
              b

let count_partitions n =
  count_choices n (list_solutions (n * (n*n + 1) / 2) n (n*n) []) []

let _ = print_int (count_partitions 5); exit 0

filterとcombinationsを覚えた

;Pen4 3GHzで
;(time (maho 4)) => 392
;real 0.109/user 0.109/sys 0.000
;(time (maho 5)) => 3245664
;real 1528.250/user 1423.391/sys 104.187
;(time (maho-by-enm 4)) => 392
;real 254.094/user 249.000/sys 1.531

maho-by-enmは#4819や#4821と同じ方法です
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
(use srfi-1)
(use util.combinations)
(define (maho n)
  (define (maho-in m l)
    (if (= m 0) 1
      (letrec
       ((cn (/ (* n (+ (* n n) 1)) 2))
        (mylist
         (map (lambda (li) (cons (car l) li))
              (filter
               (lambda (li) (= (apply + li) (- cn (car l))))
               (combinations (cdr l) (- n 1))))))
       (apply + (map (lambda (a) (maho-in (- m n) (lset-difference equal? l a))) mylist)))))
  (maho-in (* n n) (iota (* n n) 1)))

(define (maho-by-enm n)
  (define (center n) (/ (* n (+ (* n n) 1)) 2))
  (define (flatten2 l c)
    (define (flatten1 l c)
      (if (null? l) c
        (cons (car l) (flatten1 (cdr l) c))))
    (if (null? l) c
      (flatten2 (cdr l) (flatten1 (car l) c))))
  (define (my-equal? l1 l2)
    (null? (lset-xor eq? (flatten2 l1 '()) l2)))
  (define (enm-n n)
    (filter (lambda (l) (= (apply + l) (center n)))
            (combinations (iota (* n n) 1) n)))
  (filter (lambda (l) (my-equal? l (iota (* n n) 1)))
          (combinations (enm-n n) n)))

がーん……ダメすぎる...orz


タグ付け忘れた. Gaucheです.

同じ計算を何回もしているので#1671を参考にメモ化しました.
n=5で2分と少々です. メモリ使用量は11MB (多分).
Pen4 3GHzです.

gosh> ;(time (maho-memo 4))
; real   0.078
; user   0.078
; sys    0.000
392
gosh> ;(time (maho-memo 5))
; real 130.656
; user 124.562
; sys    4.828
3245664
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
(define (maho-memo n)
  (define cn (/ (* n (+ (* n n) 1)) 2))
  (define maho-in-memo
    (let1
     tab (make-hash-table 'equal?)
     (define (memo m l v) (hash-table-put! tab (cons m l) v) v)
     (lambda (m l)
       (cond ((= m 0) 1)
             ((hash-table-get tab (cons m l) #f))
             (else
              (memo m l
                    (apply
                     +
                     (map
                      (lambda (a) (maho-in-memo (- m n) (lset-difference equal? l a)))
                      (map
                       (lambda (li) (cons (car l) li))
                       (filter
                        (lambda (li) (= (apply + li) (- cn (car l))))
                        (combinations (cdr l) (- n 1))))))))))))
  (maho-in-memo (* n n) (iota (* n n) 1)))

二重投稿orz 申し訳ないんですが, #4828の方を消してもらえませんでしょうか?


タブが多いとバグになるというのが、python使わない人からすると新鮮です。


Pen4 2.4Gで実行して、

4の場合
392 : 31(ms)
5の場合
3245664 : 224157(ms)

と、なりました。 簡単に枝狩りをしただけの、総当りです。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class Answer108 {
    private final int size_;
    private final int maxNumber_;
    private final int average_;

    private int count_ = 0;

    public Answer108(int size) {
        size_ = size;
        maxNumber_ = size * size;
        average_ = size * (maxNumber_ + 1) / 2;

        countPair();
    }

    private void countPair() {
        if (size_ <= 1) return;
        List<List<Integer>> array = new ArrayList<List<Integer>>();
        for (int index = 0; index < size_; index++) {
            array.add(new ArrayList<Integer>());
        }
        array.get(0).add(1);
        countPair(array, 2);
    }
    private void countPair(List<List<Integer>> array, int nextNumber) {
        if (nextNumber <= maxNumber_) {
            for (int index = 0; index < size_; index++) {
                List<Integer> list = array.get(index);
                int size = list.size();
                if (size == size_) continue;
                if (size == size_ - 1) {
                    if (sum(list) + nextNumber != average_) continue;
                } else {
                    int rest = 0;
                    for (int lastIndex = 0; lastIndex < size_ - size - 1; lastIndex++) {
                        rest += maxNumber_ - lastIndex;
                    }
                    if (sum(list) + nextNumber + rest < average_) continue;
                }

                if (index >= nextNumber) continue;
                list.add(nextNumber);
                countPair(array, nextNumber + 1);
                list.remove(Integer.valueOf(nextNumber));
                if (list.size() == 0) break;
            }
        } else {
            //System.out.println(toString(array));
            count_++;
        }
    }
    private int sum(List<Integer> array) {
        int sum = 0;
        for (int num: array) {
            sum += num;
        }
        return sum;
    }

    public int getCount() {
        return count_;
    }


    public static String toString(List<List<Integer>> array) {
        String[] strs = new String[array.size()];
        for (int index = 0; index < strs.length; index++) {
            strs[index] = array.get(index).toString();
        }
        return Arrays.toString(strs);
    }

    public static void main(String[] args) {
        long start = System.currentTimeMillis();
        Answer108 ans = new Answer108(5);
        System.out.println(ans.getCount());
        long end = System.currentTimeMillis();
        System.out.println("elapse: " + (end - start) + "(ms)");
    }
}

枝刈りしようとしたらわけがわからなくなったので Common Lisp で書き直し。

最初に組み合わせを求めた後で「1を含む組」「1を含まなくて2を含む組」「1, 2を含まなくて3を含む組」……と分類します。こうすると、グループごとに探索の対象とするかしないかを決めることができてかなり範囲が狭まるようです。

disjoint がボトルネックになるようなのでここだけ最適化をかけています。実行時間は SBCL で 28 秒でした。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
(defun list-solutions (m n lbd ubd)
  (cond ((not (<= (* lbd n) m (* ubd n))) ())
        ((= n 1) `((,m)))
        (t
         (loop for i from lbd to ubd nconc
           (mapl (lambda (l) (push i (car l)))
                 (list-solutions (- m i) (1- n) (1+ i) ubd))))))

(defun group-solutions (m n lbd ubd)
  (loop with sols = (list-solutions m n lbd ubd)
    for i from 1
    for s1 = sols then s2
    for s2 = (member i s1 :key #'car :test #'/=)
    while s1 collect (cons i (ldiff s1 s2))))

(defun disjoint (l1 l2) ; l1, l2 must be sorted
  (declare (optimize (speed 3) (safety 0)))
  (or (null l1)
      (null l2)
      (let ((a (car l1)) (b (car l2)))
        (declare (fixnum a b))
        (cond ((> a b) (disjoint l1 (cdr l2)))
              ((< a b) (disjoint (cdr l1) l2))
              (t nil)))))

(defun merge-list (l1 l2) ; l1, l2 must be sorted
  (do* ((head (cons () ()))
        (tail head (cdr tail)))
      (())
    (cond ((endp l1) (setf (cdr tail) l2) (return (cdr head)))
          ((endp l2) (setf (cdr tail) l1) (return (cdr head)))
          (t
           (let ((a (car l1)) (b (car l2)))
             (cond ((> a b)
                    (setf (cdr tail) (list b)
                          l2 (cdr l2)))
                   ((< a b)
                    (setf (cdr tail) (list a)
                          l1 (cdr l1)))
                   (t
                    (setf (cdr tail) (list a)
                          l1 (cdr l1)
                          l2 (cdr l2)))))))))

(defun count-choices (n lists ex)
  (if (= n 1)
      (loop for x in (cdar lists) count (disjoint x ex))
    (loop for x in (cdar lists) if (disjoint x ex) sum
      (count-choices (1- n)
                     (remove-if (lambda (a) (member (car a) x)) (cdr lists))
                     (merge-list x ex)))))

(defun count-partitions (n)
  (count-choices n (group-solutions (/ (* n (1+ (* n n))) 2) n 1 (* n n)) ()))

無駄に読みにくい書き方をしてました。こっちのほうが普通でしょう。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
(defun merge-list (l1 l2)
  (let (acc)
    (loop
      (if (endp l1) (return (nreconc acc l2)))
      (if (endp l2) (return (nreconc acc l1)))
      (let ((a (car l1)) (b (car l2)))
        (cond ((> a b)
               (push b acc)
               (setf l2 (cdr l2)))
              ((< a b)
               (push a acc)
               (setf l1 (cdr l1)))
              (t
               (push a acc)
               (setf l1 (cdr l1) l2 (cdr l2))))))))

$ g++ -O3 maho.cpp && time ./a.out
3245664

real    0m26.930s
user    0m22.310s
sys     0m4.560s

最初に和が (総和/n) となるn個の値の組を
ビット列として全パターン生成してしまいます。
そうしておいて、ビットパターンの中から排他的なものを選んでいくアルゴリズム。

それなりに速いかと。
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
#include <iostream>
#include <vector>
typedef unsigned int bit_t;

std::vector<bit_t> bits;
int n, n2, cnt = 0;

void comb(int a, int k, int s, bit_t b) {
  if (k < n-1)
    for (int i = a; i < n2; ++i)
      comb(i+1, k+1, s-i, b | 1 << (i-1));
  else
    if (a <= s && s <= n2)
      bits.push_back(b | 1 << (s-1));
}

void calc(int s, int k, bit_t b) {
  if (k == n) { ++cnt; return; }

  for (int i = s; i < (int)bits.size(); ++i)
    if (!(b & bits[i]))
      calc(i+1, k+1, b | bits[i]);
}

int main() {
  n = 5;
  n2 = n * n;
  int m = n * (n2+1) / 2;

  comb(1, 0, m, 0);
  calc(0, 0, 0);

  std::cout << cnt << std::endl;
}

ビットの集合を生成後、ソートするようにしたら17秒ぐらいになりました:p


n-1個の組を見つけた後、最後の一つを二分探索するようにしたら n=5 で5秒ぐらいになりました。
real    0m4.920s
user    0m2.920s
sys     0m1.980s

n=6 もやろうとしてみたのですが、32ビットでは全然足りなくて、
64ビットに収まるかどうかも怪しい感じですので、
解を1つずつカウントする方針では根本的に駄目そうです。
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
#include <iostream>
#include <vector>
#include <algorithm>
using namespace std;
typedef unsigned int bit_t;

vector<bit_t> bits;
bit_t mask;
int n, n2;
long long cnt = 0LL;

void comb(int a, int k, bit_t b, int rest) {
  if (k == n-1) {
    if (rest <= n2)
      bits.push_back(b | 1 << (rest-1));
    return;
  }
  for (int i = a; i < n2; ++i)
    if (rest-i > i)
      comb(i+1, k+1, (b | 1 << (i-1)), rest-i);
}

void calc(int a, int k, bit_t b) {
  if (k == n-1) {
    if (binary_search(bits.begin()+a, bits.end(), mask & ~b))
      ++cnt;
    return;
  }
  for (int i = a; i < (int)bits.size(); ++i)
    if (!(b & bits[i]))
      calc(i+1, k+1, b | bits[i]);
}

int main(int argc, char **argv) {
  n = argc > 1 ? atoi(argv[1]) : 5;
  n2 = n * n;
  mask = (1 << n2) - 1;
  int m = n * (n2+1) / 2;

  comb(1, 0, 0, m);
  sort(bits.begin(), bits.end());
  calc(0, 0, 0);

  cout << cnt << endl;
}