1#
2# Copyright 2009 Facebook
3#
4# Licensed under the Apache License, Version 2.0 (the "License"); you may
5# not use this file except in compliance with the License. You may obtain
6# a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13# License for the specific language governing permissions and limitations
14# under the License.
15
16"""A utility class to send to and recv from a non-blocking socket."""
17
18from __future__ import with_statement
19
20import sys
21
22import zmq
23from zmq.utils import jsonapi
24
25try:
26    import cPickle as pickle
27except ImportError:
28    import pickle
29
30from .ioloop import IOLoop
31
32try:
33    # gen_log will only import from >= 3.0
34    from tornado.log import gen_log
35    from tornado import stack_context
36except ImportError:
37    from .minitornado.log import gen_log
38    from .minitornado import stack_context
39
40try:
41    from queue import Queue
42except ImportError:
43    from Queue import Queue
44
45from zmq.utils.strtypes import bytes, unicode, basestring
46
try:
    callable
except NameError:
    # Some Python 3 releases (3.0/3.1) lack the builtin callable();
    # provide an equivalent based on the presence of __call__.
    def callable(obj):
        """Return True if ``obj`` has a ``__call__`` attribute."""
        return hasattr(obj, '__call__')
51
52
class ZMQStream(object):
    """A utility class to register callbacks when a zmq socket sends and receives

    For use with zmq.eventloop.ioloop

    There are three main methods

    Methods:

    * **on_recv(callback, copy=True):**
        register a callback to be run every time the socket has something to receive
    * **on_send(callback):**
        register a callback to be run every time a queued send is performed
    * **send(msg, flags=0, copy=True, track=False, callback=None):**
        queue a send that will trigger the send callback once performed.
        If callback is passed, it is first registered via on_send.

        There are also send_multipart(), send_string(), send_json(), send_pyobj()

    Two other methods for deactivating the callbacks:

    * **stop_on_recv():**
        turn off the recv callback
    * **stop_on_send():**
        turn off the send callback

    which simply call ``on_<evt>(None)``.

    Several socket methods (bind, bind_to_random_port, connect and the
    setsockopt/getsockopt family) are also exposed directly on the stream,
    by linking to the underlying socket's bound methods.
    e.g.

    >>> stream.bind == stream.socket.bind
    True

    """

    socket = None
    io_loop = None
    poller = None

    def __init__(self, socket, io_loop=None):
        """Wrap ``socket`` and register it with an IOLoop.

        Parameters
        ----------
        socket : zmq.Socket
            The socket to wrap.
        io_loop : IOLoop, optional
            The loop to register with; defaults to ``IOLoop.instance()``.
        """
        self.socket = socket
        self.io_loop = io_loop or IOLoop.instance()
        # private poller, used only by flush() to poll the socket directly
        self.poller = zmq.Poller()

        # pending sends, as (msg_parts, send_multipart kwargs) tuples
        self._send_queue = Queue()
        self._recv_callback = None
        self._send_callback = None
        self._close_callback = None
        self._recv_copy = False
        # set by flush() so the regular handlers skip work until cleared
        self._flushed = False

        # start with only ERROR registered; READ/WRITE are added on demand
        self._state = self.io_loop.ERROR
        self._init_io_state()

        # shortcircuit some socket methods
        self.bind = self.socket.bind
        self.bind_to_random_port = self.socket.bind_to_random_port
        self.connect = self.socket.connect
        self.setsockopt = self.socket.setsockopt
        self.getsockopt = self.socket.getsockopt
        self.setsockopt_string = self.socket.setsockopt_string
        self.getsockopt_string = self.socket.getsockopt_string
        self.setsockopt_unicode = self.socket.setsockopt_unicode
        self.getsockopt_unicode = self.socket.getsockopt_unicode


    def stop_on_recv(self):
        """Disable callback and automatic receiving."""
        return self.on_recv(None)

    def stop_on_send(self):
        """Disable callback on sending."""
        return self.on_send(None)

    def stop_on_err(self):
        """DEPRECATED, does nothing"""
        gen_log.warning("on_err does nothing, and will be removed")

    def on_err(self, callback):
        """DEPRECATED, does nothing"""
        gen_log.warning("on_err does nothing, and will be removed")

    def on_recv(self, callback, copy=True):
        """Register a callback for when a message is ready to recv.

        There can be only one callback registered at a time, so each
        call to `on_recv` replaces previously registered callbacks.

        on_recv(None) disables recv event polling.

        Use on_recv_stream(callback) instead, to register a callback that will receive
        both this ZMQStream and the message, instead of just the message.

        Parameters
        ----------

        callback : callable
            callback must take exactly one argument, which will be a
            list, as returned by socket.recv_multipart()
            if callback is None, recv callbacks are disabled.
        copy : bool
            copy is passed directly to recv, so if copy is False,
            callback will receive Message objects. If copy is True,
            then callback will receive bytes/str objects.

        Returns : None

        Raises
        ------
        IOError : if the stream is closed.
        """

        self._check_closed()
        assert callback is None or callable(callback)
        self._recv_callback = stack_context.wrap(callback)
        self._recv_copy = copy
        if callback is None:
            self._drop_io_state(self.io_loop.READ)
        else:
            self._add_io_state(self.io_loop.READ)

    def on_recv_stream(self, callback, copy=True):
        """Same as on_recv, but callback will get this stream as first argument

        callback must take exactly two arguments, as it will be called as::

            callback(stream, msg)

        Useful when a single callback should be used with multiple streams.
        """
        if callback is None:
            self.stop_on_recv()
        else:
            self.on_recv(lambda msg: callback(self, msg), copy=copy)

    def on_send(self, callback):
        """Register a callback to be called on each send

        There will be two arguments::

            callback(msg, status)

        * `msg` will be the list of sendable objects that was just sent
        * `status` will be the return result of socket.send_multipart(msg) -
          MessageTracker or None - or the ZMQError raised if the send failed.

        Non-copying sends return a MessageTracker object whose
        `done` attribute will be True when the send is complete.
        This allows users to track when an object is safe to write to
        again.

        The second argument will always be None if copy=True
        on the send.

        Use on_send_stream(callback) to register a callback that will be passed
        this ZMQStream as the first argument, in addition to the other two.

        on_send(None) disables the send callback; queued messages are
        still sent.

        Parameters
        ----------

        callback : callable
            callback must take exactly two arguments, which will be
            the message being sent (always a list),
            and the return result of socket.send_multipart(msg) -
            MessageTracker or None.

            if callback is None, send callbacks are disabled.

        Raises
        ------
        IOError : if the stream is closed.
        """

        self._check_closed()
        assert callback is None or callable(callback)
        self._send_callback = stack_context.wrap(callback)


    def on_send_stream(self, callback):
        """Same as on_send, but callback will get this stream as first argument

        Callback will be passed three arguments::

            callback(stream, msg, status)

        Useful when a single callback should be used with multiple streams.
        """
        if callback is None:
            self.stop_on_send()
        else:
            self.on_send(lambda msg, status: callback(self, msg, status))


    def send(self, msg, flags=0, copy=True, track=False, callback=None):
        """Send a message, optionally also register a new callback for sends.
        See zmq.socket.send for details.
        """
        return self.send_multipart([msg], flags=flags, copy=copy, track=track, callback=callback)

    def send_multipart(self, msg, flags=0, copy=True, track=False, callback=None):
        """Queue a multipart message, optionally also register a new callback for sends.

        The message is actually sent once the IOLoop reports the socket as
        writable (or by flush()). A `callback` passed here is registered via
        on_send, so it replaces the stream's send callback for later sends too.
        See zmq.socket.send_multipart for details.

        Raises
        ------
        IOError : if the stream is closed (nothing is queued in that case).
        """
        # check first, so a closed stream doesn't keep a message it can never send
        self._check_closed()
        kwargs = dict(flags=flags, copy=copy, track=track)
        self._send_queue.put((msg, kwargs))
        callback = callback or self._send_callback
        if callback is not None:
            self.on_send(callback)
        else:
            # noop callback
            self.on_send(lambda *args: None)
        self._add_io_state(self.io_loop.WRITE)

    def send_string(self, u, flags=0, encoding='utf-8', callback=None):
        """Send a unicode message with an encoding.
        See zmq.socket.send_string for details.

        Raises
        ------
        TypeError : if `u` is not a string.
        """
        if not isinstance(u, basestring):
            raise TypeError("unicode/str objects only")
        return self.send(u.encode(encoding), flags=flags, callback=callback)

    send_unicode = send_string

    def send_json(self, obj, flags=0, callback=None):
        """Send json-serialized version of an object.
        See zmq.socket.send_json for details.
        """
        if jsonapi is None:
            raise ImportError('jsonlib{1,2}, json or simplejson library is required.')
        else:
            msg = jsonapi.dumps(obj)
            return self.send(msg, flags=flags, callback=callback)

    def send_pyobj(self, obj, flags=0, protocol=-1, callback=None):
        """Send a Python object as a message using pickle to serialize.

        See zmq.socket.send_pyobj for details.
        """
        msg = pickle.dumps(obj, protocol)
        return self.send(msg, flags=flags, callback=callback)

    def _finish_flush(self):
        """callback for unsetting _flushed flag."""
        self._flushed = False

    def flush(self, flag=zmq.POLLIN|zmq.POLLOUT, limit=None):
        """Flush pending messages.

        This method safely handles all pending incoming and/or outgoing messages,
        bypassing the inner loop, passing them to the registered callbacks.

        A limit can be specified, to prevent blocking under high load.

        flush will return the first time ANY of these conditions are met:
            * No more events matching the flag are pending.
            * the total number of events handled reaches the limit.

        Note that if ``flag & POLLIN`` is set, recv events will be flushed even
        if no callback is registered, unlike normal IOLoop operation. This allows
        flush to be used to remove *and ignore* incoming messages.

        If anything was flushed, the regular IOLoop handlers skip their
        send/recv work until a callback scheduled on the IOLoop clears the
        flushed state.

        Parameters
        ----------
        flag : int, default=POLLIN|POLLOUT
                0MQ poll flags.
                If flag & POLLIN,  recv events will be flushed.
                If flag & POLLOUT, send events will be flushed.
                Both flags can be set at once, which is the default.
        limit : None or int, optional
                The maximum number of messages to send or receive.
                Both send and recv count against this limit.
                None (or 0) means no limit.

        Returns
        -------
        int : count of events handled (both send and recv)

        Raises
        ------
        IOError : if the stream is closed.
        """
        self._check_closed()
        # unset self._flushed, so callbacks will execute, in case flush has
        # already been called this iteration
        already_flushed = self._flushed
        self._flushed = False
        # initialize counters
        count = 0
        # Keep the caller's request separately: `flag` is narrowed below, and
        # POLLOUT must be able to come back if a recv callback queues a send.
        requested = flag

        def update_flag():
            """Update the poll flag, to prevent registering POLLOUT events
            if we don't have pending sends."""
            return requested & zmq.POLLIN | (self.sending() and requested & zmq.POLLOUT)

        flag = update_flag()
        if not flag:
            # nothing to do
            return 0
        self.poller.register(self.socket, flag)
        events = self.poller.poll(0)
        while events and (not limit or count < limit):
            _, event = events[0]
            if event & zmq.POLLIN: # receiving
                self._handle_recv()
                count += 1
                if self.socket is None:
                    # break if socket was closed during callback
                    break
            if event & zmq.POLLOUT and self.sending():
                self._handle_send()
                count += 1
                if self.socket is None:
                    # break if socket was closed during callback
                    break

            flag = update_flag()
            if flag:
                self.poller.register(self.socket, flag)
                events = self.poller.poll(0)
            else:
                events = []
        if count: # only bypass loop if we actually flushed something
            # skip send/recv callbacks this iteration
            self._flushed = True
            # reregister them at the end of the loop
            if not already_flushed: # don't need to do it again
                self.io_loop.add_callback(self._finish_flush)
        elif already_flushed:
            self._flushed = True

        # update ioloop poll state, which may have changed
        self._rebuild_io_state()
        return count

    def set_close_callback(self, callback):
        """Call the given callback when the stream is closed."""
        self._close_callback = stack_context.wrap(callback)

    def close(self, linger=None):
        """Close this stream.

        Removes the socket from the IOLoop, closes it with `linger`, and runs
        the close callback if one is set. Closing an already-closed stream
        does nothing.
        """
        if self.socket is not None:
            self.io_loop.remove_handler(self.socket)
            self.socket.close(linger)
            self.socket = None
            if self._close_callback:
                self._run_callback(self._close_callback)

    def receiving(self):
        """Returns True if we are currently receiving from the stream."""
        return self._recv_callback is not None

    def sending(self):
        """Returns True if we are currently sending to the stream."""
        return not self._send_queue.empty()

    def closed(self):
        """Returns True if the stream has been closed."""
        return self.socket is None

    def _run_callback(self, callback, *args, **kwargs):
        """Wrap running callbacks in try/except to allow us to
        close our socket."""
        try:
            # Use a NullContext to ensure that all StackContexts are run
            # inside our blanket exception handler rather than outside.
            with stack_context.NullContext():
                callback(*args, **kwargs)
        except:
            gen_log.error("Uncaught exception, closing connection.",
                          exc_info=True)
            # Close the socket on an uncaught exception from a user callback
            # (It would eventually get closed when the socket object is
            # gc'd, but we don't want to rely on gc happening before we
            # run out of file descriptors)
            self.close()
            # Re-raise the exception so that IOLoop.handle_callback_exception
            # can see it and log the error
            raise

    def _handle_events(self, fd, events):
        """This method is the actual handler for IOLoop, that gets called whenever
        an event on my socket is posted. It dispatches to _handle_recv, etc."""
        if not self.socket:
            gen_log.warning("Got events for closed stream %s", fd)
            return
        try:
            # dispatch events:
            if events & IOLoop.ERROR:
                gen_log.error("got POLLERR event on ZMQStream, which doesn't make sense")
                return
            if events & IOLoop.READ:
                self._handle_recv()
                if not self.socket:
                    return
            if events & IOLoop.WRITE:
                self._handle_send()
                if not self.socket:
                    return

            # rebuild the poll state
            self._rebuild_io_state()
        except:
            gen_log.error("Uncaught exception, closing connection.",
                          exc_info=True)
            self.close()
            raise

    def _handle_recv(self):
        """Handle a recv event: receive one multipart message without
        blocking and pass it to the recv callback, if any."""
        if self._flushed:
            return
        try:
            msg = self.socket.recv_multipart(zmq.NOBLOCK, copy=self._recv_copy)
        except zmq.ZMQError as e:
            if e.errno == zmq.EAGAIN:
                # state changed since poll event
                pass
            else:
                gen_log.error("RECV Error: %s", zmq.strerror(e.errno))
        else:
            if self._recv_callback:
                callback = self._recv_callback
                self._run_callback(callback, msg)

    def _handle_send(self):
        """Handle a send event: send the oldest queued message and pass it,
        with the send result (or the ZMQError), to the send callback."""
        if self._flushed:
            return
        if not self.sending():
            gen_log.error("Shouldn't have handled a send event")
            return

        msg, kwargs = self._send_queue.get()
        try:
            status = self.socket.send_multipart(msg, **kwargs)
        except zmq.ZMQError as e:
            gen_log.error("SEND Error: %s", e)
            status = e
        if self._send_callback:
            callback = self._send_callback
            self._run_callback(callback, msg, status)

    def _check_closed(self):
        """Raise IOError if the stream has been closed."""
        if not self.socket:
            raise IOError("Stream is closed")

    def _rebuild_io_state(self):
        """rebuild io state based on self.sending() and receiving()"""
        if self.socket is None:
            return
        state = self.io_loop.ERROR
        if self.receiving():
            state |= self.io_loop.READ
        if self.sending():
            state |= self.io_loop.WRITE
        if state != self._state:
            self._state = state
            self._update_handler(state)

    def _add_io_state(self, state):
        """Add io_state to poller."""
        if not self._state & state:
            self._state = self._state | state
            self._update_handler(self._state)

    def _drop_io_state(self, state):
        """Stop poller from watching an io_state."""
        if self._state & state:
            self._state = self._state & (~state)
            self._update_handler(self._state)

    def _update_handler(self, state):
        """Update IOLoop handler with state (no-op once the stream is closed)."""
        if self.socket is None:
            return
        self.io_loop.update_handler(self.socket, state)

    def _init_io_state(self):
        """initialize the ioloop event handler"""
        # register outside any active StackContext, so the handler is not
        # bound to whatever context happened to be current at construction
        with stack_context.NullContext():
            self.io_loop.add_handler(self.socket, self._handle_events, self._state)
530