__init__.py revision 781d71db
1# Copyright (c) PyZMQ Developers.
2# Distributed under the terms of the Modified BSD License.
3
4import functools
5import sys
6import time
7from threading import Thread
8
9from unittest import TestCase
10
11import zmq
12from zmq.utils import jsonapi
13
14try:
15    import gevent
16    from zmq import green as gzmq
17    have_gevent = True
18except ImportError:
19    have_gevent = False
20
21try:
22    from unittest import SkipTest
23except ImportError:
24    try:
25        from nose import SkipTest
26    except ImportError:
27        class SkipTest(Exception):
28            pass
29
# True when the running interpreter reports itself as PyPy in sys.version.
PYPY = "PyPy" in sys.version
31
32#-----------------------------------------------------------------------------
33# skip decorators (directly from unittest)
34#-----------------------------------------------------------------------------
35
def _id(x):
    """Identity decorator: return the decorated object unchanged."""
    return x
37
def skip(reason):
    """
    Build a decorator that unconditionally skips a test.

    A TestCase subclass is returned as-is, only flagged as skipped; any
    other callable is replaced by a wrapper that raises SkipTest(reason).
    """
    def decorator(test_item):
        is_test_class = (
            isinstance(test_item, type) and issubclass(test_item, TestCase)
        )
        marked = test_item
        if not is_test_class:
            # Plain functions/methods: swap in a stub that raises on call.
            @functools.wraps(test_item)
            def skip_wrapper(*args, **kwargs):
                raise SkipTest(reason)
            marked = skip_wrapper

        # Attributes the unittest runner inspects to report the skip.
        marked.__unittest_skip_why__ = reason
        marked.__unittest_skip__ = True
        return marked
    return decorator
53
def skip_if(condition, reason="Skipped"):
    """
    Return a skipping decorator when `condition` is truthy.

    Otherwise return the identity decorator, leaving the test untouched.
    """
    return skip(reason) if condition else _id
61
# Decorator for tests that are skipped when running on PyPy.
skip_pypy = skip_if(PYPY, reason="Doesn't work on PyPy")
63
64#-----------------------------------------------------------------------------
65# Base test class
66#-----------------------------------------------------------------------------
67
class BaseZMQTestCase(TestCase):
    """Base TestCase for zmq tests.

    setUp obtains the shared context via ``Context.instance()`` and starts
    an empty socket registry (``self.sockets``).  tearDown closes every
    registered socket and terminates each context those sockets belong to.
    """

    # When True, tests run against zmq.green's Context instead of zmq's.
    green = False

    @property
    def Context(self):
        """The Context class under test: zmq.green's if ``green``, else zmq's."""
        if self.green:
            return gzmq.Context
        return zmq.Context

    def socket(self, socket_type):
        """Create a socket on ``self.context`` and register it for tearDown."""
        s = self.context.socket(socket_type)
        self.sockets.append(s)
        return s

    def setUp(self):
        """Skip green tests when gevent is missing, then create the fixtures."""
        if self.green and not have_gevent:
            raise SkipTest("requires gevent")
        self.context = self.Context.instance()
        self.sockets = []

    def tearDown(self):
        """Close all registered sockets and terminate their contexts.

        Raises RuntimeError if a context has not terminated within 2 seconds.
        """
        contexts = set([self.context])
        while self.sockets:
            sock = self.sockets.pop()
            contexts.add(sock.context)  # in case additional contexts are created
            sock.close(0)
        for ctx in contexts:
            # Terminate in a daemon thread so that a context which cannot
            # terminate fails this test instead of hanging the whole run.
            t = Thread(target=ctx.term)
            t.daemon = True
            t.start()
            t.join(timeout=2)
            if t.is_alive():
                # reset Context.instance, so the failure to term doesn't corrupt subsequent tests
                zmq.sugar.context.Context._instance = None
                raise RuntimeError("context could not terminate, open sockets likely remain in test")

    def create_bound_pair(self, type1=zmq.PAIR, type2=zmq.PAIR, interface='tcp://127.0.0.1'):
        """Create a bound socket pair using a random port.

        Returns ``(s1, s2)`` where s1 is bound to a random port on
        `interface` and s2 is connected to it.  Both have LINGER set to 0.
        """
        # Register each socket as soon as it exists: if bind or connect
        # raises, tearDown must still close it, otherwise the context
        # cannot terminate.
        s1 = self.context.socket(type1)
        self.sockets.append(s1)
        s1.setsockopt(zmq.LINGER, 0)
        port = s1.bind_to_random_port(interface)
        s2 = self.context.socket(type2)
        self.sockets.append(s2)
        s2.setsockopt(zmq.LINGER, 0)
        s2.connect('%s:%s' % (interface, port))
        return s1, s2

    def ping_pong(self, s1, s2, msg):
        """Send `msg` s1 -> s2 -> s1 and return what s1 receives back."""
        s1.send(msg)
        msg2 = s2.recv()
        s2.send(msg2)
        msg3 = s1.recv()
        return msg3

    def ping_pong_json(self, s1, s2, o):
        """Round-trip `o` as JSON s1 -> s2 -> s1; skip if no json library."""
        if jsonapi.jsonmod is None:
            raise SkipTest("No json library")
        s1.send_json(o)
        o2 = s2.recv_json()
        s2.send_json(o2)
        o3 = s1.recv_json()
        return o3

    def ping_pong_pyobj(self, s1, s2, o):
        """Round-trip `o` via send_pyobj/recv_pyobj s1 -> s2 -> s1."""
        s1.send_pyobj(o)
        o2 = s2.recv_pyobj()
        s2.send_pyobj(o2)
        o3 = s1.recv_pyobj()
        return o3

    def assertRaisesErrno(self, errno, func, *args, **kwargs):
        """Assert that ``func(*args, **kwargs)`` raises ZMQError with `errno`.

        Fails the test if no ZMQError is raised or its errno differs.
        """
        try:
            func(*args, **kwargs)
        except zmq.ZMQError as e:
            self.assertEqual(
                e.errno, errno,
                "wrong error raised, expected '%s' got '%s'"
                % (zmq.ZMQError(errno), zmq.ZMQError(e.errno)))
        else:
            self.fail("Function did not raise any error")

    def _select_recv(self, multipart, socket, **kwargs):
        """call recv[_multipart] in a way that raises if there is nothing to receive"""
        if zmq.zmq_version_info() >= (3,1,0):
            # zmq 3.1 has a bug, where poll can return false positives,
            # so we wait a little bit just in case
            # See LIBZMQ-280 on JIRA
            time.sleep(0.1)

        readable, _, _ = zmq.select([socket], [], [], timeout=5)
        assert len(readable) > 0, "Should have received a message"
        # Never block in the actual recv: select already did the waiting.
        kwargs['flags'] = zmq.DONTWAIT | kwargs.get('flags', 0)

        recv = socket.recv_multipart if multipart else socket.recv
        return recv(**kwargs)

    def recv(self, socket, **kwargs):
        """call recv in a way that raises if there is nothing to receive"""
        return self._select_recv(False, socket, **kwargs)

    def recv_multipart(self, socket, **kwargs):
        """call recv_multipart in a way that raises if there is nothing to receive"""
        return self._select_recv(True, socket, **kwargs)
170
171
class PollZMQTestCase(BaseZMQTestCase):
    """BaseZMQTestCase under a separate name; adds no behavior of its own."""
174
class GreenTest:
    """Mixin for making green versions of test classes"""
    green = True

    def assertRaisesErrno(self, errno, func, *args, **kwargs):
        """Assert that ``func(*args, **kwargs)`` raises ZMQError with `errno`.

        Checks for EAGAIN are skipped entirely when running green.
        """
        if errno == zmq.EAGAIN:
            raise SkipTest("Skipping because we're green.")
        try:
            func(*args, **kwargs)
        except zmq.ZMQError as e:
            self.assertEqual(
                e.errno, errno,
                "wrong error raised, expected '%s' got '%s'"
                % (zmq.ZMQError(errno), zmq.ZMQError(e.errno)))
        else:
            self.fail("Function did not raise any error")

    def tearDown(self):
        """Close all registered sockets and terminate their contexts.

        Raises RuntimeError if any context has not terminated within
        2 seconds.
        """
        contexts = set([self.context])
        while self.sockets:
            sock = self.sockets.pop()
            contexts.add(sock.context) # in case additional contexts are created
            sock.close()
        terminators = [gevent.spawn(ctx.term) for ctx in contexts]
        try:
            gevent.joinall(terminators, timeout=2, raise_error=True)
        except gevent.Timeout:
            raise RuntimeError("context could not terminate, open sockets likely remain in test")
        # Newer gevent returns from joinall on timeout instead of raising
        # Timeout, so also verify that every term() actually finished.
        if not all(g.ready() for g in terminators):
            raise RuntimeError("context could not terminate, open sockets likely remain in test")

    def skip_green(self):
        """Unconditionally skip the calling test."""
        raise SkipTest("Skipping because we are green")
204
def skip_green(f):
    """Decorate a test method so it is skipped when ``self.green`` is true.

    When the instance is not green, the call is forwarded to `f` unchanged
    and its result returned.
    """
    # wraps() keeps the test's own name/docstring for reporting.
    @functools.wraps(f)
    def skipping_test(self, *args, **kwargs):
        if self.green:
            raise SkipTest("Skipping because we are green")
        else:
            return f(self, *args, **kwargs)
    return skipping_test
212