challenge 魔方分割数

1 .. N^2までの数をN個の数字の和が等しいN個のグループに分けたいと思います。

たとえば、N=3のときは、
(1) { 1, 5, 9 }, { 2, 6, 7 }, { 3, 4, 8 } 
(2) { 1, 6, 8 }, { 2, 4, 9 }, { 3, 5, 7 }
の2通りの方法があります。

ここで指定されたNに対して、何通りのグループ分けの方法があるかを数えるプログラムを作ってください。
(何通りかという値だけが出力されればよいのですが、予め計算してある結果を返すのはダメですよ。)
また、N=5を指定したときの実行時間もあわせて教えてください。

なお、数え上げるときの注意として、

・{ 1, 5, 9 } と { 1, 9, 5 }は同じもの

・{ 1, 5, 9 }, { 2, 6, 7 }, { 3, 4, 8 }と
 { 1, 5, 9 }, { 3, 4, 8 }, { 2, 6, 7 }は同じもの
とすることに注意してください。

Posted feedbacks - Nested

Flatten Hidden
実装がナイーブ過ぎてN=4が実行できませんでした。

平均値計算の式が↓のようになっていて、これに気づくのに三十分かかった。
ave = N*(N*2+1) / 2

枝狩りをもうちょっとがんばるかなぁ。
アルゴリズムはこんな感じ。
1.総当りでペアを出す
2.ペアをソートする
3.文字列にキャストしてsetにぶち込んでユニーク化
4.何個残ってるか?

コード書き終わってから、魔方陣の書き方って確かあったよなぁ。
と思って、Wikipedia先生に聞いてみたら二次元用であった。
これを応用すれば、うまいこと出てこないかなぁ。
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import copy
import time
def mahobunkatsu(N):
    m = set()
    for i in xrange(1,N**2+1):
        m.add(i)
    ave = N*(N**2+1) / 2
    
    def createPair(restNumber,numberSum, pair, count, pairList):
        for i in restNumber:
            if count and (count+1)%N == 0:
                if numberSum + i == ave:
                    if count+1 == N**2:
                        pair.append(i)
                        pairList.append(copy.copy(pair))
                        pair.pop()
                        continue
                else:
                    continue
            if numberSum + i > ave:
                continue
            restNumber.remove(i)
            pair.append(i)
            createPair(restNumber, (numberSum + i) % ave, pair, count + 1, pairList)
            restNumber.add(i)
            pair.pop()
        return

    pairList = []
    createPair(m, 0, [], 0, pairList)

    def listToSortedPairs(pairList):
        temp = []
        for l in pairList:
            temp.append([])
            for i in range(N):
                temp[-1].append(l[i*N:(i+1)*N])
                temp[-1][-1].sort()
                temp[-1].sort()
        return temp

    sortedList = listToSortedPairs(pairList)

    uniqueList = set()
    for x in sortedList:
        uniqueList.add(str(x))

    print uniqueList
    return uniqueList


t = time.time()
for x in range(2,6):
    n = len(mahobunkatsu(x))
    print "Size=",x,"Mahozin_num=",n,"time=",time.time()-t
    t = time.time()
投稿してから、プログラムにあまり影響しないバグを見つけた。
39行目、タブが一個多いです。
要素を一個追加するごとにソートしているので、不毛。

タブが多いとバグになるというのが、python使わない人からすると新鮮です。

グループ未定の数字の最初のものpは、まだ決まってないグループ(X)に。
Xの残りの(n-1)個は、残っている数字から、和が(s/n-p)になるものとする。
最後まで行ったらカウンタに1加算。
初めに戻る。

(ComplementやSubsetsは常にソートされている)

In[3]:= f[3]
Out[3]= 2

In[4]:= f[4]
Out[4]= 392(Core2 6700で0.06秒)

In[5]:= f[5]
Out[5]= 3245664(Core2 6700で690秒)
1
2
3
4
5
6
f[n_] := f[Range[n^2], n, Total@Range[n^2]/n]

f[in_, n_, s_] :=
  If[Length@in == n, 1,
    Total[f[Complement[Rest@in, #], n, s] & /@ (
    Select[Subsets[Rest@in, {n - 1}], First@in + Total@# == s &])]]
>先頭が 1 ~ n のものが並ぶはず
確かに、そうなるかな?と思いたくなるんですが、違うみたいです:
{1, 2, 15, 16}, {3, 4, 13, 14}, {5, 6, 11, 12}, {7, 8, 9, 10}

がーん……ダメすぎる...orz

とりあえず単純に
(1)合計が平均値になる組み合わせを生成
(2)(1)から要素が重ならない組み合わせを生成

N=4で62秒(PenM 1.7GHz)、解は392
N=5はこのやり方では無理ですね。
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
def gen_comb(list,n)
  if n==0
    yield([])
  else
    (list.size-n+1).times do |i|
      gen_comb(list[i+1..-1],n-1) do|ls|
        yield([list[i]]+ls)
      end
    end
  end
end

def comb_num(list,n,sum)
  ret=[]
  gen_comb(list,n) do |a|
    ret<<=a if (a.inject(0){|r,e| r+=e})==sum
  end
  ret
end

def comb_array(list,n)
  ret=[]
  gen_comb(list,n) do |a|
    ret<<=a if a.inject{|r,e| break unless (r&e).empty?;r+=e}
  end
  ret
end

def maho(n)
  m=n**2
  sum=m*(m+1)/2/n
  comb_array(comb_num((1..m).to_a,n,sum),n)
end

start_time = Time.now
puts maho(4).size
puts Time.now-start_time

和が N*(N^2+1)/2 になる組み合わせを昇順に列挙して交わらないものを探す。nido さんの #4819 と同じ方針かな?

5 のときは 20 分かかって答えが出ました。 3245664 だそうです。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
(* list solutions of
     x1+...+xn=m
     1<=x1<...<xn<=l *)
let list_solutions m n k =
  let rec loop m n lbd tail =
    if m < n || m < lbd then []
    else if n = 1 then
      if lbd <= m && m <= k then [m::tail] else []
    else
      let acc = ref [] in
        for i = lbd to (min (m-n+1) k) do
          acc :=
            List.map (fun sols -> i::sols)
              (loop (m-i) (n-1) (i+1) tail) :: !acc
        done;
        List.concat !acc
  in loop m n 1

let rec disjoint xs ys =
  match xs, ys with
    | [], _ | _, [] -> true
    | x::xs', y::ys' ->
        if x > y then disjoint xs ys'
        else if x < y then disjoint xs' ys
        else false

let (@<) = List.merge compare

let rec count_choices n lists ex =
  if n = 1 then
    List.fold_left (fun i x -> if disjoint x ex then i+1 else i) 0 lists
  else
    match lists with
      | [] -> 0
      | x::rest ->
          let b = count_choices n rest ex in
            if disjoint x ex then
              b + count_choices (n-1) rest (x@<ex)
            else
              b

let count_partitions n =
  count_choices n (list_solutions (n * (n*n + 1) / 2) n (n*n) []) []

let _ = print_int (count_partitions 5); exit 0

枝刈りしようとしたらわけがわからなくなったので Common Lisp で書き直し。

最初に組み合わせを求めた後で「1を含む組」「1を含まなくて2を含む組」「1, 2を含まなくて3を含む組」……と分類します。こうすると、グループごとに探索の対象とするかしないかを決めることができてかなり範囲が狭まるようです。

disjoint がボトルネックになるようなのでここだけ最適化をかけています。実行時間は SBCL で 28 秒でした。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
(defun list-solutions (m n lbd ubd)
  (cond ((not (<= (* lbd n) m (* ubd n))) ())
        ((= n 1) `((,m)))
        (t
         (loop for i from lbd to ubd nconc
           (mapl (lambda (l) (push i (car l)))
                 (list-solutions (- m i) (1- n) (1+ i) ubd))))))

(defun group-solutions (m n lbd ubd)
  (loop with sols = (list-solutions m n lbd ubd)
    for i from 1
    for s1 = sols then s2
    for s2 = (member i s1 :key #'car :test #'/=)
    while s1 collect (cons i (ldiff s1 s2))))

(defun disjoint (l1 l2) ; l1, l2 must be sorted
  (declare (optimize (speed 3) (safety 0)))
  (or (null l1)
      (null l2)
      (let ((a (car l1)) (b (car l2)))
        (declare (fixnum a b))
        (cond ((> a b) (disjoint l1 (cdr l2)))
              ((< a b) (disjoint (cdr l1) l2))
              (t nil)))))

(defun merge-list (l1 l2) ; l1, l2 must be sorted
  (do* ((head (cons () ()))
        (tail head (cdr tail)))
      (())
    (cond ((endp l1) (setf (cdr tail) l2) (return (cdr head)))
          ((endp l2) (setf (cdr tail) l1) (return (cdr head)))
          (t
           (let ((a (car l1)) (b (car l2)))
             (cond ((> a b)
                    (setf (cdr tail) (list b)
                          l2 (cdr l2)))
                   ((< a b)
                    (setf (cdr tail) (list a)
                          l1 (cdr l1)))
                   (t
                    (setf (cdr tail) (list a)
                          l1 (cdr l1)
                          l2 (cdr l2)))))))))

(defun count-choices (n lists ex)
  (if (= n 1)
      (loop for x in (cdar lists) count (disjoint x ex))
    (loop for x in (cdar lists) if (disjoint x ex) sum
      (count-choices (1- n)
                     (remove-if (lambda (a) (member (car a) x)) (cdr lists))
                     (merge-list x ex)))))

(defun count-partitions (n)
  (count-choices n (group-solutions (/ (* n (1+ (* n n))) 2) n 1 (* n n)) ()))

無駄に読みにくい書き方をしてました。こっちのほうが普通でしょう。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
(defun merge-list (l1 l2)
  (let (acc)
    (loop
      (if (endp l1) (return (nreconc acc l2)))
      (if (endp l2) (return (nreconc acc l1)))
      (let ((a (car l1)) (b (car l2)))
        (cond ((> a b)
               (push b acc)
               (setf l2 (cdr l2)))
              ((< a b)
               (push a acc)
               (setf l1 (cdr l1)))
              (t
               (push a acc)
               (setf l1 (cdr l1) l2 (cdr l2))))))))
filterとcombinationsを覚えた

;Pen4 3GHzで
;(time (maho 4)) => 392
;real 0.109/user 0.109/sys 0.000
;(time (maho 5)) => 3245664
;real 1528.250/user 1423.391/sys 104.187
;(time (maho-by-enm 4)) => 392
;real 254.094/user 249.000/sys 1.531

maho-by-enmは#4819や#4821と同じ方法です
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
(use srfi-1)
(use util.combinations)
(define (maho n)
  (define (maho-in m l)
    (if (= m 0) 1
      (letrec
       ((cn (/ (* n (+ (* n n) 1)) 2))
        (mylist
         (map (lambda (li) (cons (car l) li))
              (filter
               (lambda (li) (= (apply + li) (- cn (car l))))
               (combinations (cdr l) (- n 1))))))
       (apply + (map (lambda (a) (maho-in (- m n) (lset-difference equal? l a))) mylist)))))
  (maho-in (* n n) (iota (* n n) 1)))

(define (maho-by-enm n)
  (define (center n) (/ (* n (+ (* n n) 1)) 2))
  (define (flatten2 l c)
    (define (flatten1 l c)
      (if (null? l) c
        (cons (car l) (flatten1 (cdr l) c))))
    (if (null? l) c
      (flatten2 (cdr l) (flatten1 (car l) c))))
  (define (my-equal? l1 l2)
    (null? (lset-xor eq? (flatten2 l1 '()) l2)))
  (define (enm-n n)
    (filter (lambda (l) (= (apply + l) (center n)))
            (combinations (iota (* n n) 1) n)))
  (filter (lambda (l) (my-equal? l (iota (* n n) 1)))
          (combinations (enm-n n) n)))
タグ付け忘れた. Gaucheです.
同じ計算を何回もしているので#1671を参考にメモ化しました.
n=5で2分と少々です. メモリ使用量は11MB (多分).
Pen4 3GHzです.

gosh> ;(time (maho-memo 4))
; real   0.078
; user   0.078
; sys    0.000
392
gosh> ;(time (maho-memo 5))
; real 130.656
; user 124.562
; sys    4.828
3245664
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
(define (maho-memo n)
  (define cn (/ (* n (+ (* n n) 1)) 2))
  (define maho-in-memo
    (let1
     tab (make-hash-table 'equal?)
     (define (memo m l v) (hash-table-put! tab (cons m l) v) v)
     (lambda (m l)
       (cond ((= m 0) 1)
             ((hash-table-get tab (cons m l) #f))
             (else
              (memo m l
                    (apply
                     +
                     (map
                      (lambda (a) (maho-in-memo (- m n) (lset-difference equal? l a)))
                      (map
                       (lambda (li) (cons (car l) li))
                       (filter
                        (lambda (li) (= (apply + li) (- cn (car l))))
                        (combinations (cdr l) (- n 1))))))))))))
  (maho-in-memo (* n n) (iota (* n n) 1)))

二重投稿orz 申し訳ないんですが, #4828の方を消してもらえませんでしょうか?

無駄な計算を省いて1分切ったので再投稿

gosh>
;(time (maho-memo 5))
; real  43.391
; user  42.625
; sys    0.641
3245664
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
(use srfi-1)
(use util.combinations)

(define (maho-memo n)
  ;tableに足してcnになるリストを入れておく
  (define table 
    (let ((cn (/ (* n (+ (* n n) 1)) 2)))
      (filter
       (lambda (li) (= (apply + li) cn))
       (combinations (iota (* n n) 1) n))))
  (define maho-in-memo
    (let1
     tab (make-hash-table 'equal?)
     (define (memo m l v) (hash-table-put! tab (cons m l) v) v)
     (lambda (m l)
       (cond ((= m 0) 1)
             ((hash-table-get tab (cons m l) #f))
             (else
              (memo m l
                    (apply
                     +
                     (map ; 今のlからaを除いたリストでmaho-in-memoを呼ぶ
                      (lambda (a) (maho-in-memo (- m n) (lset-difference equal? l a)))
                      (filter ;(car l)で始まるtableの中のリストを抜き出す
                       (lambda (li) (and (equal? (car li) (car l)) (lset<= equal? li l)))
                       table)))))))))
  (maho-in-memo (* n n) (iota (* n n) 1)))
(time (maho-memo 5))

Pen4 2.4Gで実行して、

4の場合
392 : 31(ms)
5の場合
3245664 : 224157(ms)

と、なりました。 簡単に枝狩りをしただけの、総当りです。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class Answer108 {
    private final int size_;
    private final int maxNumber_;
    private final int average_;

    private int count_ = 0;

    public Answer108(int size) {
        size_ = size;
        maxNumber_ = size * size;
        average_ = size * (maxNumber_ + 1) / 2;

        countPair();
    }

    private void countPair() {
        if (size_ <= 1) return;
        List<List<Integer>> array = new ArrayList<List<Integer>>();
        for (int index = 0; index < size_; index++) {
            array.add(new ArrayList<Integer>());
        }
        array.get(0).add(1);
        countPair(array, 2);
    }
    private void countPair(List<List<Integer>> array, int nextNumber) {
        if (nextNumber <= maxNumber_) {
            for (int index = 0; index < size_; index++) {
                List<Integer> list = array.get(index);
                int size = list.size();
                if (size == size_) continue;
                if (size == size_ - 1) {
                    if (sum(list) + nextNumber != average_) continue;
                } else {
                    int rest = 0;
                    for (int lastIndex = 0; lastIndex < size_ - size - 1; lastIndex++) {
                        rest += maxNumber_ - lastIndex;
                    }
                    if (sum(list) + nextNumber + rest < average_) continue;
                }

                if (index >= nextNumber) continue;
                list.add(nextNumber);
                countPair(array, nextNumber + 1);
                list.remove(Integer.valueOf(nextNumber));
                if (list.size() == 0) break;
            }
        } else {
            //System.out.println(toString(array));
            count_++;
        }
    }
    private int sum(List<Integer> array) {
        int sum = 0;
        for (int num: array) {
            sum += num;
        }
        return sum;
    }

    public int getCount() {
        return count_;
    }


    public static String toString(List<List<Integer>> array) {
        String[] strs = new String[array.size()];
        for (int index = 0; index < strs.length; index++) {
            strs[index] = array.get(index).toString();
        }
        return Arrays.toString(strs);
    }

    public static void main(String[] args) {
        long start = System.currentTimeMillis();
        Answer108 ans = new Answer108(5);
        System.out.println(ans.getCount());
        long end = System.currentTimeMillis();
        System.out.println("elapse: " + (end - start) + "(ms)");
    }
}
$ g++ -O3 maho.cpp && time ./a.out
3245664

real    0m26.930s
user    0m22.310s
sys     0m4.560s

最初に和が (総和/n) となるn個の値の組を
ビット列として全パターン生成してしまいます。
そうしておいて、ビットパターンの中から排他的なものを選んでいくアルゴリズム。

それなりに速いかと。
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
#include <iostream>
#include <vector>
typedef unsigned int bit_t;

std::vector<bit_t> bits;
int n, n2, cnt = 0;

void comb(int a, int k, int s, bit_t b) {
  if (k < n-1)
    for (int i = a; i < n2; ++i)
      comb(i+1, k+1, s-i, b | 1 << (i-1));
  else
    if (a <= s && s <= n2)
      bits.push_back(b | 1 << (s-1));
}

void calc(int s, int k, bit_t b) {
  if (k == n) { ++cnt; return; }

  for (int i = s; i < (int)bits.size(); ++i)
    if (!(b & bits[i]))
      calc(i+1, k+1, b | bits[i]);
}

int main() {
  n = 5;
  n2 = n * n;
  int m = n * (n2+1) / 2;

  comb(1, 0, m, 0);
  calc(0, 0, 0);

  std::cout << cnt << std::endl;
}