1# coding: utf-8
2"""0MQ Socket pure Python methods."""
3
4# Copyright (C) PyZMQ Developers
5# Distributed under the terms of the Modified BSD License.
6
7
8import codecs
9import random
10import warnings
11
12import zmq
13from zmq.backend import Socket as SocketBase
14from .poll import Poller
15from . import constants
16from .attrsettr import AttributeSetter
17from zmq.error import ZMQError, ZMQBindError
18from zmq.utils import jsonapi
19from zmq.utils.strtypes import bytes,unicode,basestring
20from zmq.utils.interop import cast_int_addr
21
22from .constants import (
23    SNDMORE, ENOTSUP, POLLIN,
24    int64_sockopt_names,
25    int_sockopt_names,
26    bytes_sockopt_names,
27    fd_sockopt_names,
28)
# Prefer cPickle where it is importable (Python 2); otherwise fall back to the
# standard-library pickle module. Only a missing module should trigger the
# fallback, so catch ImportError rather than every exception.
try:
    import cPickle
    pickle = cPickle
except ImportError:
    cPickle = None
    import pickle

# Default protocol for Socket.send_pyobj: pickle.DEFAULT_PROTOCOL where the
# module defines it, otherwise pickle.HIGHEST_PROTOCOL.
try:
    DEFAULT_PROTOCOL = pickle.DEFAULT_PROTOCOL
except AttributeError:
    DEFAULT_PROTOCOL = pickle.HIGHEST_PROTOCOL
40
41
class Socket(SocketBase, AttributeSetter):
    """The ZMQ socket object

    To create a Socket, first create a Context::

        ctx = zmq.Context.instance()

    then call ``ctx.socket(socket_type)``::

        s = ctx.socket(zmq.ROUTER)

    """
    # When true, ``__del__`` does not close the socket. Presumably set by the
    # backend for sockets built via ``Socket.shadow`` (the ``shadow=`` argument),
    # whose underlying libzmq socket is owned elsewhere -- TODO confirm.
    _shadow = False

    def __del__(self):
        """Close the socket on garbage collection, unless it is a shadow."""
        if not self._shadow:
            self.close()

    # socket as context manager:
    def __enter__(self):
        """Sockets are context managers

        .. versionadded:: 14.4
        """
        return self

    def __exit__(self, *args, **kwargs):
        """Close the socket when leaving a ``with`` block.

        Returns None, so exceptions raised inside the block propagate.
        """
        self.close()

    #-------------------------------------------------------------------------
    # Socket creation
    #-------------------------------------------------------------------------

    @classmethod
    def shadow(cls, address):
        """Shadow an existing libzmq socket

        address is the integer address of the libzmq socket
        or an FFI pointer to it.

        .. versionadded:: 14.1
        """
        address = cast_int_addr(address)
        return cls(shadow=address)

    #-------------------------------------------------------------------------
    # Deprecated aliases
    #-------------------------------------------------------------------------

    @property
    def socket_type(self):
        """Deprecated alias for ``Socket.type``; emits a DeprecationWarning."""
        warnings.warn("Socket.socket_type is deprecated, use Socket.type",
            DeprecationWarning,
            # attribute the warning to the caller, not to this property
            stacklevel=2,
        )
        return self.type

    #-------------------------------------------------------------------------
    # Hooks for sockopt completion
    #-------------------------------------------------------------------------

    def __dir__(self):
        """List class attributes plus all known socket option names.

        Adding the option names lets interactive completion offer them as
        attributes.
        """
        keys = dir(self.__class__)
        for collection in (
            bytes_sockopt_names,
            int_sockopt_names,
            int64_sockopt_names,
            fd_sockopt_names,
        ):
            keys.extend(collection)
        return keys

    #-------------------------------------------------------------------------
    # Getting/Setting options
    #-------------------------------------------------------------------------
    setsockopt = SocketBase.set
    getsockopt = SocketBase.get

    def set_string(self, option, optval, encoding='utf-8'):
        """set socket options with a unicode object

        This is simply a wrapper for setsockopt to protect from encoding ambiguity.

        See the 0MQ documentation for details on specific options.

        Parameters
        ----------
        option : int
            The name of the option to set. Can be any of: SUBSCRIBE,
            UNSUBSCRIBE, IDENTITY
        optval : unicode string (unicode on py2, str on py3)
            The value of the option to set.
        encoding : str
            The encoding to be used, default is utf8

        Raises
        ------
        TypeError
            if `optval` is not a unicode string.
        """
        if not isinstance(optval, unicode):
            raise TypeError("unicode strings only")
        return self.set(option, optval.encode(encoding))

    setsockopt_unicode = setsockopt_string = set_string

    def get_string(self, option, encoding='utf-8'):
        """get the value of a socket option

        See the 0MQ documentation for details on specific options.

        Parameters
        ----------
        option : int
            The option to retrieve.
        encoding : str
            The encoding used to decode the raw bytes, default is utf8.

        Returns
        -------
        optval : unicode string (unicode on py2, str on py3)
            The value of the option as a unicode string.

        Raises
        ------
        TypeError
            if `option` is not a bytes-valued socket option.
        """

        if option not in constants.bytes_sockopts:
            raise TypeError("option %i will not return a string to be decoded"%option)
        return self.getsockopt(option).decode(encoding)

    getsockopt_unicode = getsockopt_string = get_string

    def bind_to_random_port(self, addr, min_port=49152, max_port=65536, max_tries=100):
        """bind this socket to a random port in a range

        Parameters
        ----------
        addr : str
            The address string without the port to pass to ``Socket.bind()``.
        min_port : int, optional
            The minimum port in the range of ports to try (inclusive).
        max_port : int, optional
            The maximum port in the range of ports to try (exclusive).
        max_tries : int, optional
            The maximum number of bind attempts to make.

        Returns
        -------
        port : int
            The port the socket was bound to.

        Raises
        ------
        ZMQBindError
            if `max_tries` reached before successful bind
        """
        for _ in range(max_tries):
            port = random.randrange(min_port, max_port)
            try:
                self.bind('%s:%s' % (addr, port))
            except ZMQError as exception:
                # Port already taken: try another one. Any other error is fatal.
                if exception.errno != zmq.EADDRINUSE:
                    raise
            else:
                return port
        raise ZMQBindError("Could not bind socket to random port.")

    def get_hwm(self):
        """get the High Water Mark

        On libzmq ≥ 3, this gets SNDHWM if available, otherwise RCVHWM
        """
        major = zmq.zmq_version_info()[0]
        if major >= 3:
            # return sndhwm, fallback on rcvhwm if reading SNDHWM fails
            try:
                return self.getsockopt(zmq.SNDHWM)
            except zmq.ZMQError:
                pass

            return self.getsockopt(zmq.RCVHWM)
        else:
            return self.getsockopt(zmq.HWM)

    def set_hwm(self, value):
        """set the High Water Mark

        On libzmq ≥ 3, this sets both SNDHWM and RCVHWM. Both are attempted
        even if the first fails; if either fails, the most recent error is
        re-raised after both attempts.
        """
        major = zmq.zmq_version_info()[0]
        if major >= 3:
            raised = None
            try:
                self.sndhwm = value
            except Exception as e:
                raised = e
            try:
                self.rcvhwm = value
            except Exception as e:
                # Bind the name here: on Python 3 the ``e`` from the clause
                # above is deleted when that clause ends.
                raised = e

            if raised is not None:
                raise raised
        else:
            return self.setsockopt(zmq.HWM, value)

    hwm = property(get_hwm, set_hwm,
        """property for High Water Mark

        Setting hwm sets both SNDHWM and RCVHWM as appropriate.
        It gets SNDHWM if available, otherwise RCVHWM.
        """
    )

    #-------------------------------------------------------------------------
    # Sending and receiving messages
    #-------------------------------------------------------------------------

    def send_multipart(self, msg_parts, flags=0, copy=True, track=False):
        """send a sequence of buffers as a multipart message

        The zmq.SNDMORE flag is added to all msg parts before the last.

        Parameters
        ----------
        msg_parts : iterable
            A sequence (or any iterable) of objects to send as a multipart
            message. Each element can be any sendable object (Frame, bytes,
            buffer-providers)
        flags : int, optional
            SNDMORE is handled automatically for frames before the last.
        copy : bool, optional
            Should the frame(s) be sent in a copying or non-copying manner.
        track : bool, optional
            Should the frame(s) be tracked for notification that ZMQ has
            finished with it (ignored if copy=True).

        Returns
        -------
        None : if copy or not track
        MessageTracker : if track and not copy
            a MessageTracker object, whose `pending` property will
            be True until the last send is completed.

        Raises
        ------
        IndexError
            if `msg_parts` contains no parts.
        """
        # Materialize once so generators and other non-sliceable iterables
        # work, and so the final part can be singled out.
        parts = list(msg_parts)
        if not parts:
            raise IndexError("send_multipart requires at least one message part")
        for part in parts[:-1]:
            self.send(part, SNDMORE | flags, copy=copy, track=track)
        # Send the last part without the extra SNDMORE flag.
        return self.send(parts[-1], flags, copy=copy, track=track)

    def recv_multipart(self, flags=0, copy=True, track=False):
        """receive a multipart message as a list of bytes or Frame objects

        Parameters
        ----------
        flags : int, optional
            Any supported flag: NOBLOCK. If NOBLOCK is set, this method
            will raise a ZMQError with EAGAIN if a message is not ready.
            If NOBLOCK is not set, then this method will block until a
            message arrives.
        copy : bool, optional
            Should the message frame(s) be received in a copying or non-copying manner?
            If False a Frame object is returned for each part, if True a copy of
            the bytes is made for each frame.
        track : bool, optional
            Should the message frame(s) be tracked for notification that ZMQ has
            finished with it? (ignored if copy=True)

        Returns
        -------
        msg_parts : list
            A list of frames in the multipart message; either Frames or bytes,
            depending on `copy`.

        """
        parts = [self.recv(flags, copy=copy, track=track)]
        # have first part already, only loop while more to receive
        while self.getsockopt(zmq.RCVMORE):
            part = self.recv(flags, copy=copy, track=track)
            parts.append(part)

        return parts

    def send_string(self, u, flags=0, copy=True, encoding='utf-8'):
        """send a Python unicode string as a message with an encoding

        0MQ communicates with raw bytes, so you must encode/decode
        text (unicode on py2, str on py3) around 0MQ.

        Parameters
        ----------
        u : Python unicode string (unicode on py2, str on py3)
            The unicode string to send.
        flags : int, optional
            Any valid send flag.
        copy : bool, optional
            Passed through to ``Socket.send``.
        encoding : str [default: 'utf-8']
            The encoding to be used
        """
        if not isinstance(u, basestring):
            raise TypeError("unicode/str objects only")
        return self.send(u.encode(encoding), flags=flags, copy=copy)

    send_unicode = send_string

    def recv_string(self, flags=0, encoding='utf-8'):
        """receive a unicode string, as sent by send_string

        Parameters
        ----------
        flags : int
            Any valid recv flag.
        encoding : str [default: 'utf-8']
            The encoding to be used

        Returns
        -------
        s : unicode string (unicode on py2, str on py3)
            The Python unicode string that arrives as encoded bytes.
        """
        b = self.recv(flags=flags)
        return b.decode(encoding)

    recv_unicode = recv_string

    def send_pyobj(self, obj, flags=0, protocol=DEFAULT_PROTOCOL):
        """send a Python object as a message using pickle to serialize

        Parameters
        ----------
        obj : Python object
            The Python object to send.
        flags : int
            Any valid send flag.
        protocol : int
            The pickle protocol number to use. The default is pickle.DEFAULT_PROTOCOL
            where defined, and pickle.HIGHEST_PROTOCOL elsewhere.
        """
        msg = pickle.dumps(obj, protocol)
        return self.send(msg, flags)

    def recv_pyobj(self, flags=0):
        """receive a Python object as a message using pickle to serialize

        .. warning::

            This unpickles whatever bytes arrive on the socket. Unpickling
            untrusted data can execute arbitrary code, so only use this on
            sockets whose peers are fully trusted.

        Parameters
        ----------
        flags : int
            Any valid recv flag.

        Returns
        -------
        obj : Python object
            The Python object that arrives as a message.
        """
        s = self.recv(flags)
        return pickle.loads(s)

    def send_json(self, obj, flags=0, **kwargs):
        """send a Python object as a message using json to serialize

        Keyword arguments are passed on to json.dumps

        Parameters
        ----------
        obj : Python object
            The Python object to send
        flags : int
            Any valid send flag
        """
        msg = jsonapi.dumps(obj, **kwargs)
        return self.send(msg, flags)

    def recv_json(self, flags=0, **kwargs):
        """receive a Python object as a message using json to serialize

        Keyword arguments are passed on to json.loads

        Parameters
        ----------
        flags : int
            Any valid recv flag.

        Returns
        -------
        obj : Python object
            The Python object that arrives as a message.
        """
        msg = self.recv(flags)
        return jsonapi.loads(msg, **kwargs)

    # Poller implementation used by ``poll``; kept as a class attribute so
    # subclasses can substitute their own.
    _poller_class = Poller

    def poll(self, timeout=None, flags=POLLIN):
        """poll the socket for events

        The default is to poll forever for incoming
        events.  Timeout is in milliseconds, if specified.

        Parameters
        ----------
        timeout : int [default: None]
            The timeout (in milliseconds) to wait for an event. If unspecified
            (or specified None), will wait forever for an event.
        flags : bitfield (int) [default: POLLIN]
            The event flags to poll for (any combination of POLLIN|POLLOUT).
            The default is to check for incoming events (POLLIN).

        Returns
        -------
        events : bitfield (int)
            The events that are ready and waiting.  Will be 0 if no events were ready
            by the time timeout was reached.

        Raises
        ------
        ZMQError
            with ENOTSUP if the socket is closed.
        """

        if self.closed:
            raise ZMQError(ENOTSUP)

        p = self._poller_class()
        p.register(self, flags)
        evts = dict(p.poll(timeout))
        # return 0 if no events, otherwise return event bitfield
        return evts.get(self, 0)

    def get_monitor_socket(self, events=None, addr=None):
        """Return a connected PAIR socket ready to receive the event notifications.

        .. versionadded:: libzmq-4.0
        .. versionadded:: 14.0

        Parameters
        ----------
        events : bitfield (int) [default: zmq.EVENT_ALL]
            The bitmask defining which events are wanted.
        addr :  string [default: None]
            The optional endpoint for the monitoring sockets. Defaults to an
            ``inproc://`` endpoint derived from this socket's FD.

        Returns
        -------
        socket :  (PAIR)
            The socket is already connected and ready to receive messages.

        Raises
        ------
        NotImplementedError
            if the linked libzmq is older than 4.
        """
        # safe-guard, method only available on libzmq >= 4
        if zmq.zmq_version_info() < (4,):
            raise NotImplementedError("get_monitor_socket requires libzmq >= 4, have %s" % zmq.zmq_version())
        if addr is None:
            # create endpoint name from internal fd
            addr = "inproc://monitor.s-%d" % self.FD
        if events is None:
            # use all events
            events = zmq.EVENT_ALL
        # attach monitoring socket
        self.monitor(addr, events)
        # create new PAIR socket and connect it
        ret = self.context.socket(zmq.PAIR)
        try:
            ret.connect(addr)
        except Exception:
            # don't leak the freshly created PAIR socket if connecting fails
            ret.close()
            raise
        return ret

    def disable_monitor(self):
        """Shutdown the PAIR socket (created using get_monitor_socket)
        that is serving socket events.

        .. versionadded:: 14.4
        """
        self.monitor(None, 0)
494
# Public API of this module: ``from ... import *`` exports only Socket.
__all__ = ['Socket']
496