challenge メソッド名一覧の表示

リフレクション系のお題の続編です。

「ある与えられたオブジェクトtargetのメソッドのうち、 "test_"で始まるものをすべて呼びだす」というコードを書いてください。 引数に関しては都合のいいように仮定して構いません(全部0個、など)。

メソッドという概念がない言語の場合は、 「複数の関数への参照を持っているようなオブジェクト(たとえばパッケージとかモジュールとか)から"test_"で始まる関数をすべて呼び出す」と読み替えても構いません。

Posted feedbacks - Python

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import types

class A:
    def __init__(self):
        self.bar = 0
        self.test_bar = 1
        self.baz = []
        self.test_baz = {}
    def foo(self):
        print "foo"
    def test_foo(self):
        print "test_foo"
    def boo(self):
        print "boo"
    def test_boo(self):
        print "test_boo"

def call_tests(obj):
    for name in dir(obj):
        if name.startswith("test_"):
            attr = getattr(obj, name)
            if isinstance(attr, types.MethodType):
                attr()

def main():
    call_tests(A())

if __name__ == '__main__':
    main()

inspect モジュールの出番。 classmethod も呼んでしまうのはどうにかならんかな。
1
2
3
4
5
import inspect

def call_tests(target):
    methods = [method for name, method in inspect.getmembers(target, inspect.ismethod) if name.startswith('test_')]
    for method in methods: method()

以下の特徴を持つ。
・call_tests()の第一引数にはクラスとインスタンスのどちらでも渡せる。
・ベースクラスのメソッドも呼ぶか、指定したクラスのみのメソッドしか呼ばないかを指定可能。
・staticmethodを呼ぶかどうかを指定可能。
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import types

class A:
	a1 = 1
	test_a1 = 1
	def __init__(self):
		self.a2 = 2
		self.test_a2 = 2
	def a3(self):
		print "a3 called!"
	def test_a3(self):
		print "test_a3 called!"

class B(A):
	def b1(self):
		print 'b1 called!'
	def test_b1(self):
		print 'test_b1 called!'
	@staticmethod
	def test_b2():
		print 'test_b2 called!'

def call_tests(obj, single_level=False, call_static=True):
	if isinstance(obj, types.InstanceType):
		cls = obj.__class__
	elif isinstance(obj, types.ClassType):
		cls = obj
		obj = obj()  # obj is bound to instance object.
	else:
		return
	# make dict of attributes
	names = cls.__dict__
	for base in cls.__bases__:
		names.update(base.__dict__)
	if single_level:
		for base in cls.__bases__:
			for name in base.__dict__:
				if hasattr(base, name):
					del names[name]
	for name in sorted(names):
		if name[:5] == 'test_':
			if call_static:
				try:
					getattr(obj, name)()  # EAFP
				except:
					continue
			else:
				f = names[name]
				if callable(f):
					f(obj)

if __name__ == '__main__':
	call_tests(B, 0, 0)
	print '==='
	call_tests(B, 1, 0)
	print '==='
	call_tests(B, 0, 1)
	print '==='
	call_tests(B, 1, 1)
# output:
# test_a3 called!
# test_b1 called!
# ===
# test_b1 called!
# ===
# test_a3 called!
# test_b1 called!
# test_b2 called!
# ===
# test_b1 called!
# test_b2 called!

以下を修正。
・cls.__bases__で、1つ上の階層のベースクラスしかとれていなかったので、再帰的にベースクラスを探すようにした。 
・継承階層で同じ名称のメソッドをオーバーライドしていた場合、一番子供側のメソッドを呼び出すようにした。 
・新型クラスに対して、 isinstance(obj, types.InstanceType) と isinstance(obj, types.ClassType) の
  両方ともFalseになってしまって対応していなかったので対応した。 
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
def call_tests(obj, single_level=False, call_static=True):
    if hasattr(obj, '__bases__'):
        # class object
        cls = obj
        obj = obj()  # obj is bound to instance object.
    else:
        if not hasattr(obj, '__class__'): return  # non class type
        # instance object
        cls = obj.__class__
    
    # make dict of attributes
    bases = []
    def base_classes(cls):
        if cls.__bases__:
            bases.extend(cls.__bases__)
            for base in cls.__bases__:
                base_classes(base)
    base_classes(cls)
    
    names = {}
    names.update(cls.__dict__)
    for base in bases:
        for attr, val in base.__dict__.items():
            if not names.has_key(attr):
                names[attr] = val
    
    if single_level:
        for base in bases:
            for name in base.__dict__:
                if names.has_key(name):
                    del names[name]
    
    # call function of attributes
    for name in sorted(names):
        if name[:5] == 'test_':
            if call_static:
                f = getattr(obj, name)
                if callable(f): f()
            else:
                f = names[name]
                if callable(f): f(obj)

Pythonのクラスメソッドはあくまで
「呼んだときに第一引数にクラスオブジェクトが渡されるメソッド」
というだけなので、
単純に「xのメソッド」と言った場合には
classmethodやstaticmethodでラップしてあるメソッドも含まれるわけです。

なのでもしそれらを省きたければim_selfがインスタンス自身かどうかをチェックすればいいと思います。
下のコードならばtest_methodだけが呼ばれます。
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
import inspect

class Foo():
    @classmethod
    def test_classmethod(cls):
        print cls

    @staticmethod
    def test_staticmethod(x):
        print x

    def test_method(self):
        print self

target = Foo()
for name, method in inspect.getmembers(target):
    if inspect.ismethod(method):
        if method.im_self == target:
            print "call", name
            method()
    

Index

Feed

Other

Link

Pathtraq

loading...