test_monqueue.py revision d3907f0d
1# Copyright (C) PyZMQ Developers
2# Distributed under the terms of the Modified BSD License.
3
4import time
5from unittest import TestCase
6
7import zmq
8from zmq import devices
9
10from zmq.tests import BaseZMQTestCase, SkipTest, PYPY
11from zmq.utils.strtypes import unicode
12
13
# Have devices build their Context via zmq.Context instead of using the
# shared one, because:
#   * on PyPy, cleanup of the shared Context doesn't work;
#   * libzmq >= 4.1 also seems to have a cleanup bug (zeromq/libzmq#1052).
if PYPY or zmq.zmq_version_info() >= (4, 1):
    devices.Device.context_factory = zmq.Context
18
19
class TestMonitoredQueue(BaseZMQTestCase):
    """Tests for ``devices.ThreadMonitoredQueue`` and ``devices.monitoredqueue``.

    Most tests use ``build_device`` to run a PAIR/PAIR queue with a PUB
    monitor between two local PAIR sockets ("alice" on the in side, "bob" on
    the out side) and a SUB socket ("mon"). They then check that messages are
    relayed in both directions and mirrored, with a prefix frame, on the
    monitor.
    """

    # Sockets created by the tests; teardown_device() closes and empties it.
    # NOTE(review): this class-level list is shared by all test instances
    # unless BaseZMQTestCase assigns a per-test ``self.sockets`` (not visible
    # here -- confirm). Emptying it in place keeps closed sockets from
    # accumulating either way.
    sockets = []

    def build_device(self, mon_sub=b"", in_prefix=b'in', out_prefix=b'out'):
        """Start a PAIR/PAIR/PUB monitored queue and return ``(alice, bob, mon)``.

        Parameters
        ----------
        mon_sub : bytes
            Subscription filter set on the SUB monitor socket.
        in_prefix, out_prefix : bytes
            Prefix frames the device prepends to monitored in/out traffic.

        The returned sockets bind random local TCP ports, and the device
        connects its in/out/mon sockets to them. The sockets are recorded in
        ``self.sockets`` so that teardown_device() closes them.
        """
        self.device = devices.ThreadMonitoredQueue(zmq.PAIR, zmq.PAIR, zmq.PUB,
                                                   in_prefix, out_prefix)
        alice = self.context.socket(zmq.PAIR)
        bob = self.context.socket(zmq.PAIR)
        mon = self.context.socket(zmq.SUB)
        # Register immediately so the sockets are still cleaned up if a bind
        # or connect below raises.
        self.sockets.extend([alice, bob, mon])

        aport = alice.bind_to_random_port('tcp://127.0.0.1')
        bport = bob.bind_to_random_port('tcp://127.0.0.1')
        mport = mon.bind_to_random_port('tcp://127.0.0.1')
        mon.setsockopt(zmq.SUBSCRIBE, mon_sub)

        self.device.connect_in("tcp://127.0.0.1:%i" % aport)
        self.device.connect_out("tcp://127.0.0.1:%i" % bport)
        self.device.connect_mon("tcp://127.0.0.1:%i" % mport)
        self.device.start()
        # Give the device thread a moment to start and connect before traffic.
        time.sleep(.2)
        try:
            # this is currently necessary to ensure no dropped monitor messages
            # see LIBZMQ-248 for more info
            mon.recv_multipart(zmq.NOBLOCK)
        except zmq.ZMQError:
            # Nothing pending on the monitor is the expected case.
            pass
        return alice, bob, mon

    def teardown_device(self):
        """Close every socket in ``self.sockets``, empty it, and drop ``self.device``.

        The list is cleared in place, so later calls never re-close sockets
        that were already closed.
        NOTE(review): this only removes the attribute; it does not explicitly
        stop the device thread.
        """
        for sock in self.sockets:
            sock.close()
        del self.sockets[:]
        del self.device

    def test_reply(self):
        """A message relays alice -> bob, and a reply relays bob -> alice."""
        alice, bob, mon = self.build_device()
        alices = b"hello bob".split()
        alice.send_multipart(alices)
        bobs = self.recv_multipart(bob)
        self.assertEqual(alices, bobs)
        bobs = b"hello alice".split()
        bob.send_multipart(bobs)
        alices = self.recv_multipart(alice)
        self.assertEqual(alices, bobs)
        self.teardown_device()

    def test_queue(self):
        """Several queued messages reach bob in order; a reply reaches alice."""
        alice, bob, mon = self.build_device()
        alices = b"hello bob".split()
        alice.send_multipart(alices)
        alices2 = b"hello again".split()
        alice.send_multipart(alices2)
        alices3 = b"hello again and again".split()
        alice.send_multipart(alices3)
        bobs = self.recv_multipart(bob)
        self.assertEqual(alices, bobs)
        bobs = self.recv_multipart(bob)
        self.assertEqual(alices2, bobs)
        bobs = self.recv_multipart(bob)
        self.assertEqual(alices3, bobs)
        bobs = b"hello alice".split()
        bob.send_multipart(bobs)
        alices = self.recv_multipart(alice)
        self.assertEqual(alices, bobs)
        self.teardown_device()

    def test_monitor(self):
        """Relayed traffic appears on the monitor with the default b'in'/b'out' prefixes."""
        alice, bob, mon = self.build_device()
        alices = b"hello bob".split()
        alice.send_multipart(alices)
        alices2 = b"hello again".split()
        alice.send_multipart(alices2)
        alices3 = b"hello again and again".split()
        alice.send_multipart(alices3)
        bobs = self.recv_multipart(bob)
        self.assertEqual(alices, bobs)
        mons = self.recv_multipart(mon)
        self.assertEqual([b'in'] + bobs, mons)
        bobs = self.recv_multipart(bob)
        self.assertEqual(alices2, bobs)
        bobs = self.recv_multipart(bob)
        self.assertEqual(alices3, bobs)
        mons = self.recv_multipart(mon)
        self.assertEqual([b'in'] + alices2, mons)
        bobs = b"hello alice".split()
        bob.send_multipart(bobs)
        alices = self.recv_multipart(alice)
        self.assertEqual(alices, bobs)
        mons = self.recv_multipart(mon)
        self.assertEqual([b'in'] + alices3, mons)
        mons = self.recv_multipart(mon)
        self.assertEqual([b'out'] + bobs, mons)
        self.teardown_device()

    def test_prefix(self):
        """Custom in/out prefixes (b'foo'/b'bar') are used on the monitor."""
        alice, bob, mon = self.build_device(b"", b'foo', b'bar')
        alices = b"hello bob".split()
        alice.send_multipart(alices)
        alices2 = b"hello again".split()
        alice.send_multipart(alices2)
        alices3 = b"hello again and again".split()
        alice.send_multipart(alices3)
        bobs = self.recv_multipart(bob)
        self.assertEqual(alices, bobs)
        mons = self.recv_multipart(mon)
        self.assertEqual([b'foo'] + bobs, mons)
        bobs = self.recv_multipart(bob)
        self.assertEqual(alices2, bobs)
        bobs = self.recv_multipart(bob)
        self.assertEqual(alices3, bobs)
        mons = self.recv_multipart(mon)
        self.assertEqual([b'foo'] + alices2, mons)
        bobs = b"hello alice".split()
        bob.send_multipart(bobs)
        alices = self.recv_multipart(alice)
        self.assertEqual(alices, bobs)
        mons = self.recv_multipart(mon)
        self.assertEqual([b'foo'] + alices3, mons)
        mons = self.recv_multipart(mon)
        self.assertEqual([b'bar'] + bobs, mons)
        self.teardown_device()

    def test_monitor_subscribe(self):
        """A monitor subscribed to b"out" sees only the out-side traffic."""
        alice, bob, mon = self.build_device(b"out")
        alices = b"hello bob".split()
        alice.send_multipart(alices)
        alices2 = b"hello again".split()
        alice.send_multipart(alices2)
        alices3 = b"hello again and again".split()
        alice.send_multipart(alices3)
        bobs = self.recv_multipart(bob)
        self.assertEqual(alices, bobs)
        bobs = self.recv_multipart(bob)
        self.assertEqual(alices2, bobs)
        bobs = self.recv_multipart(bob)
        self.assertEqual(alices3, bobs)
        bobs = b"hello alice".split()
        bob.send_multipart(bobs)
        alices = self.recv_multipart(alice)
        self.assertEqual(alices, bobs)
        # The three in-side messages were filtered out by the subscription.
        mons = self.recv_multipart(mon)
        self.assertEqual([b'out'] + bobs, mons)
        self.teardown_device()

    def test_router_router(self):
        """test router-router MQ devices"""
        dev = devices.ThreadMonitoredQueue(zmq.ROUTER, zmq.ROUTER, zmq.PUB, b'in', b'out')
        self.device = dev
        dev.setsockopt_in(zmq.LINGER, 0)
        dev.setsockopt_out(zmq.LINGER, 0)
        dev.setsockopt_mon(zmq.LINGER, 0)

        # Grab two free ports, then release them so the device can bind them.
        binder = self.context.socket(zmq.DEALER)
        porta = binder.bind_to_random_port('tcp://127.0.0.1')
        portb = binder.bind_to_random_port('tcp://127.0.0.1')
        binder.close()
        time.sleep(0.1)
        a = self.context.socket(zmq.DEALER)
        a.identity = b'a'
        b = self.context.socket(zmq.DEALER)
        b.identity = b'b'
        self.sockets.extend([a, b])

        a.connect('tcp://127.0.0.1:%i' % porta)
        dev.bind_in('tcp://127.0.0.1:%i' % porta)
        b.connect('tcp://127.0.0.1:%i' % portb)
        dev.bind_out('tcp://127.0.0.1:%i' % portb)
        dev.start()
        time.sleep(0.2)
        if zmq.zmq_version_info() >= (3, 1, 0):
            # flush erroneous poll state, due to LIBZMQ-280
            ping_msg = [b'ping', b'pong']
            for s in (a, b):
                s.send_multipart(ping_msg)
                try:
                    s.recv(zmq.NOBLOCK)
                except zmq.ZMQError:
                    pass
        # The first frame addresses the peer by identity; the device swaps it
        # for the sender's identity on the way through.
        msg = [b'hello', b'there']
        a.send_multipart([b'b'] + msg)
        bmsg = self.recv_multipart(b)
        self.assertEqual(bmsg, [b'a'] + msg)
        b.send_multipart(bmsg)
        amsg = self.recv_multipart(a)
        self.assertEqual(amsg, [b'b'] + msg)
        self.teardown_device()

    def test_default_mq_args(self):
        """The device starts with the default in/out prefixes."""
        self.device = dev = devices.ThreadMonitoredQueue(zmq.ROUTER, zmq.DEALER, zmq.PUB)
        dev.setsockopt_in(zmq.LINGER, 0)
        dev.setsockopt_out(zmq.LINGER, 0)
        dev.setsockopt_mon(zmq.LINGER, 0)
        # this will raise if default args are wrong
        dev.start()
        self.teardown_device()

    def test_mq_check_prefix(self):
        """Text (non-bytes) prefixes are rejected with TypeError."""
        ins = self.context.socket(zmq.ROUTER)
        outs = self.context.socket(zmq.DEALER)
        mons = self.context.socket(zmq.PUB)
        self.sockets.extend([ins, outs, mons])

        # Pass the real sockets so the TypeError can only come from the
        # prefixes. Previously the sockets were replaced by the strings and no
        # prefixes were passed, so the prefix check was never isolated.
        in_prefix = unicode('in')
        out_prefix = unicode('out')
        # NOTE(review): confirm that ``devices.monitoredqueue`` is the
        # monitored-queue function and not the submodule of the same name
        # (calling a module would also raise TypeError). Also, a backend that
        # does not validate prefix types eagerly would start relaying instead.
        self.assertRaises(TypeError, devices.monitoredqueue,
                          ins, outs, mons, in_prefix, out_prefix)
227        self.assertRaises(TypeError, devices.monitoredqueue, ins, outs, mons)
228