魔方分割数
Posted feedbacks - Nested
Flatten Hidden実装がナイーブ過ぎてN=4が実行できませんでした。 平均値計算の式が↓のようになっていて、これに気づくのに三十分かかった。 ave = N*(N*2+1) / 2 枝狩りをもうちょっとがんばるかなぁ。 アルゴリズムはこんな感じ。 1.総当りでペアを出す 2.ペアをソートする 3.文字列にキャストしてsetにぶち込んでユニーク化 4.何個残ってるか? コード書き終わってから、魔方陣の書き方って確かあったよなぁ。 と思って、Wikipedia先生に聞いてみたら二次元用であった。 これを応用すれば、うまいこと出てこないかなぁ。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 | import copy
import time
def mahobunkatsu(N):
m = set()
for i in xrange(1,N**2+1):
m.add(i)
ave = N*(N**2+1) / 2
def createPair(restNumber,numberSum, pair, count, pairList):
for i in restNumber:
if count and (count+1)%N == 0:
if numberSum + i == ave:
if count+1 == N**2:
pair.append(i)
pairList.append(copy.copy(pair))
pair.pop()
continue
else:
continue
if numberSum + i > ave:
continue
restNumber.remove(i)
pair.append(i)
createPair(restNumber, (numberSum + i) % ave, pair, count + 1, pairList)
restNumber.add(i)
pair.pop()
return
pairList = []
createPair(m, 0, [], 0, pairList)
def listToSortedPairs(pairList):
temp = []
for l in pairList:
temp.append([])
for i in range(N):
temp[-1].append(l[i*N:(i+1)*N])
temp[-1][-1].sort()
temp[-1].sort()
return temp
sortedList = listToSortedPairs(pairList)
uniqueList = set()
for x in sortedList:
uniqueList.add(str(x))
print uniqueList
return uniqueList
t = time.time()
for x in range(2,6):
n = len(mahobunkatsu(x))
print "Size=",x,"Mahozin_num=",n,"time=",time.time()-t
t = time.time()
|
投稿してから、プログラムにあまり影響しないバグを見つけた。 39行目、タブが一個多いです。 要素を一個追加するごとにソートしているので、不毛。
タブが多いとバグになるというのが、python使わない人からすると新鮮です。
グループ未定の数字の最初のものpは、まだ決まってないグループ(X)に。 Xの残りの(n-1)個は、残っている数字から、和が(s/n-p)になるものとする。 最後まで行ったらカウンタに1加算。 初めに戻る。 (ComplementやSubsetsは常にソートされている) In[3]:= f[3] Out[3]= 2 In[4]:= f[4] Out[4]= 392(Core2 6700で0.06秒) In[5]:= f[5] Out[5]= 3245664(Core2 6700で690秒)
1 2 3 4 5 6 | f[n_] := f[Range[n^2], n, Total@Range[n^2]/n]
f[in_, n_, s_] :=
If[Length@in == n, 1,
Total[f[Complement[Rest@in, #], n, s] & /@ (
Select[Subsets[Rest@in, {n - 1}], First@in + Total@# == s &])]]
|
>先頭が 1 ~ n のものが並ぶはず
確かに、そうなるかな?と思いたくなるんですが、違うみたいです:
{1, 2, 15, 16}, {3, 4, 13, 14}, {5, 6, 11, 12}, {7, 8, 9, 10}
がーん……ダメすぎる...orz
とりあえず単純に
(1)合計が平均値になる組み合わせを生成
(2)(1)から要素が重ならない組み合わせを生成
N=4で62秒(PenM 1.7GHz)、解は392
N=5はこのやり方では無理ですね。
(1)合計が平均値になる組み合わせを生成
(2)(1)から要素が重ならない組み合わせを生成
N=4で62秒(PenM 1.7GHz)、解は392
N=5はこのやり方では無理ですね。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 | def gen_comb(list,n)
if n==0
yield([])
else
(list.size-n+1).times do |i|
gen_comb(list[i+1..-1],n-1) do|ls|
yield([list[i]]+ls)
end
end
end
end
def comb_num(list,n,sum)
ret=[]
gen_comb(list,n) do |a|
ret<<=a if (a.inject(0){|r,e| r+=e})==sum
end
ret
end
def comb_array(list,n)
ret=[]
gen_comb(list,n) do |a|
ret<<=a if a.inject{|r,e| break unless (r&e).empty?;r+=e}
end
ret
end
def maho(n)
m=n**2
sum=m*(m+1)/2/n
comb_array(comb_num((1..m).to_a,n,sum),n)
end
start_time = Time.now
puts maho(4).size
puts Time.now-start_time
|
和が N*(N^2+1)/2 になる組み合わせを昇順に列挙して交わらないものを探す。nido さんの #4819 と同じ方針かな?
5 のときは 20 分かかって答えが出ました。 3245664 だそうです。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 | (* list solutions of
x1+...+xn=m
1<=x1<...<xn<=l *)
let list_solutions m n k =
let rec loop m n lbd tail =
if m < n || m < lbd then []
else if n = 1 then
if lbd <= m && m <= k then [m::tail] else []
else
let acc = ref [] in
for i = lbd to (min (m-n+1) k) do
acc :=
List.map (fun sols -> i::sols)
(loop (m-i) (n-1) (i+1) tail) :: !acc
done;
List.concat !acc
in loop m n 1
let rec disjoint xs ys =
match xs, ys with
| [], _ | _, [] -> true
| x::xs', y::ys' ->
if x > y then disjoint xs ys'
else if x < y then disjoint xs' ys
else false
let (@<) = List.merge compare
let rec count_choices n lists ex =
if n = 1 then
List.fold_left (fun i x -> if disjoint x ex then i+1 else i) 0 lists
else
match lists with
| [] -> 0
| x::rest ->
let b = count_choices n rest ex in
if disjoint x ex then
b + count_choices (n-1) rest (x@<ex)
else
b
let count_partitions n =
count_choices n (list_solutions (n * (n*n + 1) / 2) n (n*n) []) []
let _ = print_int (count_partitions 5); exit 0
|
枝刈りしようとしたらわけがわからなくなったので Common Lisp で書き直し。
最初に組み合わせを求めた後で「1を含む組」「1を含まなくて2を含む組」「1, 2を含まなくて3を含む組」……と分類します。こうすると、グループごとに探索の対象とするかしないかを決めることができてかなり範囲が狭まるようです。
disjoint がボトルネックになるようなのでここだけ最適化をかけています。実行時間は SBCL で 28 秒でした。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 | (defun list-solutions (m n lbd ubd)
(cond ((not (<= (* lbd n) m (* ubd n))) ())
((= n 1) `((,m)))
(t
(loop for i from lbd to ubd nconc
(mapl (lambda (l) (push i (car l)))
(list-solutions (- m i) (1- n) (1+ i) ubd))))))
(defun group-solutions (m n lbd ubd)
(loop with sols = (list-solutions m n lbd ubd)
for i from 1
for s1 = sols then s2
for s2 = (member i s1 :key #'car :test #'/=)
while s1 collect (cons i (ldiff s1 s2))))
(defun disjoint (l1 l2) ; l1, l2 must be sorted
(declare (optimize (speed 3) (safety 0)))
(or (null l1)
(null l2)
(let ((a (car l1)) (b (car l2)))
(declare (fixnum a b))
(cond ((> a b) (disjoint l1 (cdr l2)))
((< a b) (disjoint (cdr l1) l2))
(t nil)))))
(defun merge-list (l1 l2) ; l1, l2 must be sorted
(do* ((head (cons () ()))
(tail head (cdr tail)))
(())
(cond ((endp l1) (setf (cdr tail) l2) (return (cdr head)))
((endp l2) (setf (cdr tail) l1) (return (cdr head)))
(t
(let ((a (car l1)) (b (car l2)))
(cond ((> a b)
(setf (cdr tail) (list b)
l2 (cdr l2)))
((< a b)
(setf (cdr tail) (list a)
l1 (cdr l1)))
(t
(setf (cdr tail) (list a)
l1 (cdr l1)
l2 (cdr l2)))))))))
(defun count-choices (n lists ex)
(if (= n 1)
(loop for x in (cdar lists) count (disjoint x ex))
(loop for x in (cdar lists) if (disjoint x ex) sum
(count-choices (1- n)
(remove-if (lambda (a) (member (car a) x)) (cdr lists))
(merge-list x ex)))))
(defun count-partitions (n)
(count-choices n (group-solutions (/ (* n (1+ (* n n))) 2) n 1 (* n n)) ()))
|
無駄に読みにくい書き方をしてました。こっちのほうが普通でしょう。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 | (defun merge-list (l1 l2)
(let (acc)
(loop
(if (endp l1) (return (nreconc acc l2)))
(if (endp l2) (return (nreconc acc l1)))
(let ((a (car l1)) (b (car l2)))
(cond ((> a b)
(push b acc)
(setf l2 (cdr l2)))
((< a b)
(push a acc)
(setf l1 (cdr l1)))
(t
(push a acc)
(setf l1 (cdr l1) l2 (cdr l2))))))))
|
filterとcombinationsを覚えた
;Pen4 3GHzで
;(time (maho 4)) => 392
;real 0.109/user 0.109/sys 0.000
;(time (maho 5)) => 3245664
;real 1528.250/user 1423.391/sys 104.187
;(time (maho-by-enm 4)) => 392
;real 254.094/user 249.000/sys 1.531
maho-by-enmは#4819や#4821と同じ方法です
;Pen4 3GHzで
;(time (maho 4)) => 392
;real 0.109/user 0.109/sys 0.000
;(time (maho 5)) => 3245664
;real 1528.250/user 1423.391/sys 104.187
;(time (maho-by-enm 4)) => 392
;real 254.094/user 249.000/sys 1.531
maho-by-enmは#4819や#4821と同じ方法です
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 | (use srfi-1)
(use util.combinations)
(define (maho n)
(define (maho-in m l)
(if (= m 0) 1
(letrec
((cn (/ (* n (+ (* n n) 1)) 2))
(mylist
(map (lambda (li) (cons (car l) li))
(filter
(lambda (li) (= (apply + li) (- cn (car l))))
(combinations (cdr l) (- n 1))))))
(apply + (map (lambda (a) (maho-in (- m n) (lset-difference equal? l a))) mylist)))))
(maho-in (* n n) (iota (* n n) 1)))
(define (maho-by-enm n)
(define (center n) (/ (* n (+ (* n n) 1)) 2))
(define (flatten2 l c)
(define (flatten1 l c)
(if (null? l) c
(cons (car l) (flatten1 (cdr l) c))))
(if (null? l) c
(flatten2 (cdr l) (flatten1 (car l) c))))
(define (my-equal? l1 l2)
(null? (lset-xor eq? (flatten2 l1 '()) l2)))
(define (enm-n n)
(filter (lambda (l) (= (apply + l) (center n)))
(combinations (iota (* n n) 1) n)))
(filter (lambda (l) (my-equal? l (iota (* n n) 1)))
(combinations (enm-n n) n)))
|
タグ付け忘れた. Gaucheです.
同じ計算を何回もしているので#1671を参考にメモ化しました.
n=5で2分と少々です. メモリ使用量は11MB (多分).
Pen4 3GHzです.
gosh> ;(time (maho-memo 4))
; real 0.078
; user 0.078
; sys 0.000
392
gosh> ;(time (maho-memo 5))
; real 130.656
; user 124.562
; sys 4.828
3245664
n=5で2分と少々です. メモリ使用量は11MB (多分).
Pen4 3GHzです.
gosh> ;(time (maho-memo 4))
; real 0.078
; user 0.078
; sys 0.000
392
gosh> ;(time (maho-memo 5))
; real 130.656
; user 124.562
; sys 4.828
3245664
see: どう書く?org #1671 shiro: メモ化をちょっと試してみた。 単純にn...(「組合せ型の最小完全ハッシュ関数」の逆関数)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 | (define (maho-memo n)
(define cn (/ (* n (+ (* n n) 1)) 2))
(define maho-in-memo
(let1
tab (make-hash-table 'equal?)
(define (memo m l v) (hash-table-put! tab (cons m l) v) v)
(lambda (m l)
(cond ((= m 0) 1)
((hash-table-get tab (cons m l) #f))
(else
(memo m l
(apply
+
(map
(lambda (a) (maho-in-memo (- m n) (lset-difference equal? l a)))
(map
(lambda (li) (cons (car l) li))
(filter
(lambda (li) (= (apply + li) (- cn (car l))))
(combinations (cdr l) (- n 1))))))))))))
(maho-in-memo (* n n) (iota (* n n) 1)))
|
二重投稿orz 申し訳ないんですが, #4828の方を消してもらえませんでしょうか?
無駄な計算を省いて1分切ったので再投稿
gosh>
;(time (maho-memo 5))
; real 43.391
; user 42.625
; sys 0.641
3245664
gosh>
;(time (maho-memo 5))
; real 43.391
; user 42.625
; sys 0.641
3245664
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 | (use srfi-1)
(use util.combinations)
(define (maho-memo n)
;tableに足してcnになるリストを入れておく
(define table
(let ((cn (/ (* n (+ (* n n) 1)) 2)))
(filter
(lambda (li) (= (apply + li) cn))
(combinations (iota (* n n) 1) n))))
(define maho-in-memo
(let1
tab (make-hash-table 'equal?)
(define (memo m l v) (hash-table-put! tab (cons m l) v) v)
(lambda (m l)
(cond ((= m 0) 1)
((hash-table-get tab (cons m l) #f))
(else
(memo m l
(apply
+
(map ; 今のlからaを除いたリストでmaho-in-memoを呼ぶ
(lambda (a) (maho-in-memo (- m n) (lset-difference equal? l a)))
(filter ;(car l)で始まるtableの中のリストを抜き出す
(lambda (li) (and (equal? (car li) (car l)) (lset<= equal? li l)))
table)))))))))
(maho-in-memo (* n n) (iota (* n n) 1)))
(time (maho-memo 5))
|
Pen4 2.4Gで実行して、
- 4の場合
- 392 : 31(ms)
- 5の場合
- 3245664 : 224157(ms)
と、なりました。 簡単に枝狩りをしただけの、総当りです。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 | import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
public class Answer108 {
private final int size_;
private final int maxNumber_;
private final int average_;
private int count_ = 0;
public Answer108(int size) {
size_ = size;
maxNumber_ = size * size;
average_ = size * (maxNumber_ + 1) / 2;
countPair();
}
private void countPair() {
if (size_ <= 1) return;
List<List<Integer>> array = new ArrayList<List<Integer>>();
for (int index = 0; index < size_; index++) {
array.add(new ArrayList<Integer>());
}
array.get(0).add(1);
countPair(array, 2);
}
private void countPair(List<List<Integer>> array, int nextNumber) {
if (nextNumber <= maxNumber_) {
for (int index = 0; index < size_; index++) {
List<Integer> list = array.get(index);
int size = list.size();
if (size == size_) continue;
if (size == size_ - 1) {
if (sum(list) + nextNumber != average_) continue;
} else {
int rest = 0;
for (int lastIndex = 0; lastIndex < size_ - size - 1; lastIndex++) {
rest += maxNumber_ - lastIndex;
}
if (sum(list) + nextNumber + rest < average_) continue;
}
if (index >= nextNumber) continue;
list.add(nextNumber);
countPair(array, nextNumber + 1);
list.remove(Integer.valueOf(nextNumber));
if (list.size() == 0) break;
}
} else {
//System.out.println(toString(array));
count_++;
}
}
private int sum(List<Integer> array) {
int sum = 0;
for (int num: array) {
sum += num;
}
return sum;
}
public int getCount() {
return count_;
}
public static String toString(List<List<Integer>> array) {
String[] strs = new String[array.size()];
for (int index = 0; index < strs.length; index++) {
strs[index] = array.get(index).toString();
}
return Arrays.toString(strs);
}
public static void main(String[] args) {
long start = System.currentTimeMillis();
Answer108 ans = new Answer108(5);
System.out.println(ans.getCount());
long end = System.currentTimeMillis();
System.out.println("elapse: " + (end - start) + "(ms)");
}
}
|
$ g++ -O3 maho.cpp && time ./a.out 3245664 real 0m26.930s user 0m22.310s sys 0m4.560s 最初に和が (総和/n) となるn個の値の組を ビット列として全パターン生成してしまいます。 そうしておいて、ビットパターンの中から排他的なものを選んでいくアルゴリズム。 それなりに速いかと。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 | #include <iostream>
#include <vector>
typedef unsigned int bit_t;
std::vector<bit_t> bits;
int n, n2, cnt = 0;
void comb(int a, int k, int s, bit_t b) {
if (k < n-1)
for (int i = a; i < n2; ++i)
comb(i+1, k+1, s-i, b | 1 << (i-1));
else
if (a <= s && s <= n2)
bits.push_back(b | 1 << (s-1));
}
void calc(int s, int k, bit_t b) {
if (k == n) { ++cnt; return; }
for (int i = s; i < (int)bits.size(); ++i)
if (!(b & bits[i]))
calc(i+1, k+1, b | bits[i]);
}
int main() {
n = 5;
n2 = n * n;
int m = n * (n2+1) / 2;
comb(1, 0, m, 0);
calc(0, 0, 0);
std::cout << cnt << std::endl;
}
|




xsd
#4702()
Rating8/8=1.00
たとえば、N=3のときは、
(1) { 1, 5, 9 }, { 2, 6, 7 }, { 3, 4, 8 }
(2) { 1, 6, 8 }, { 2, 4, 9 }, { 3, 5, 7 }
の2通りの方法があります。
ここで指定されたNに対して、何通りのグループ分けの方法があるかを数えるプログラムを作ってください。
(何通りかという値だけが出力されればよいのですが、予め計算してある結果を返すのはダメですよ。)
また、N=5を指定したときの実行時間もあわせて教えてください。
なお、数え上げるときの注意として、
・{ 1, 5, 9 } と { 1, 9, 5 }は同じもの
・{ 1, 5, 9 }, { 2, 6, 7 }, { 3, 4, 8 }と
{ 1, 5, 9 }, { 3, 4, 8 }, { 2, 6, 7 }は同じもの
とすることに注意してください。
[ reply ]