socket.py revision 781d71db
1# coding: utf-8
2"""zmq Socket class"""
3
4# Copyright (C) PyZMQ Developers
5# Distributed under the terms of the Modified BSD License.
6
7import random
8import codecs
9
10import errno as errno_mod
11
12from ._cffi import (C, ffi, new_uint64_pointer, new_int64_pointer,
13                    new_int_pointer, new_binary_data, value_uint64_pointer,
14                    value_int64_pointer, value_int_pointer, value_binary_data,
15                    IPC_PATH_MAX_LEN)
16
17from .message import Frame
18from .constants import *
19
20import zmq
21from zmq.error import ZMQError, _check_rc, _check_version
22from zmq.utils.strtypes import unicode
23
24
def new_pointer_from_opt(option, length=0):
    """Allocate storage to receive the value of socket option *option*.

    The C type is chosen from the sugar-layer option tables:

    * options listed in ``int64_sockopts`` get an ``int64_t`` slot,
    * options listed in ``bytes_sockopts`` get a binary buffer sized by
      *length*,
    * every other option gets a plain ``int`` slot.

    The allocator's result is returned unchanged; ``Socket.get`` unpacks it
    as ``(value_pointer, size_pointer)``.
    """
    # Imported inside the function as in the original layout (presumably to
    # avoid a circular import between the backend and zmq.sugar).
    from zmq.sugar.constants import int64_sockopts, bytes_sockopts

    if option in int64_sockopts:
        return new_int64_pointer()
    if option in bytes_sockopts:
        return new_binary_data(length)
    # Fallback: most libzmq options are plain C ints.
    return new_int_pointer()
36
def value_from_opt_pointer(option, opt_pointer, length=0):
    """Convert a filled-in option buffer back into a Python value.

    Bytes options (``bytes_sockopts``) yield a ``bytes`` copy of the first
    *length* bytes behind *opt_pointer*; int64 options and all remaining
    options yield ``int(opt_pointer[0])``.
    """
    from zmq.sugar.constants import int64_sockopts, bytes_sockopts

    # int64 membership is tested first, so an option present in both tables
    # is read as an integer.
    treat_as_bytes = (option not in int64_sockopts
                      and option in bytes_sockopts)
    if treat_as_bytes:
        return ffi.buffer(opt_pointer, length)[:]
    # int and int64 slots are dereferenced the same way.
    return int(opt_pointer[0])
47
def initialize_opt_pointer(option, value, length=0):
    """Build the C-side representation of *value* for ``zmq_setsockopt``.

    Dispatches on which sugar-layer table contains *option*: int64 options
    use ``value_int64_pointer``, bytes options use ``value_binary_data``
    (with *length*), and everything else uses ``value_int_pointer``. The
    int64 table is consulted first. ``Socket.set`` reads the result as
    ``(value_pointer, size)``.
    """
    from zmq.sugar.constants import int64_sockopts, bytes_sockopts

    if option in int64_sockopts:
        return value_int64_pointer(value)
    if option in bytes_sockopts:
        return value_binary_data(value, length)
    return value_int_pointer(value)
58
59
class Socket(object):
    """CFFI-backed wrapper around a raw libzmq socket handle.

    Provides bind/connect, socket-option access and single-frame send/recv
    on top of the libzmq C API loaded through cffi.
    """

    # Class-level defaults so attribute lookups stay safe even when __init__
    # did not finish (e.g. close() on a half-constructed instance).
    context = None
    socket_type = None
    _zmq_socket = None
    _closed = None
    _ref = None
    _shadow = False

    def __init__(self, context=None, socket_type=None, shadow=None):
        """Create a new libzmq socket, or wrap an existing one.

        Parameters
        ----------
        context : Context or None
            Owning context. Required when *shadow* is not given (its
            ``_zmq_ctx`` is used to create the socket). When truthy, the
            socket is registered via ``context._add_socket``.
        socket_type : int
            libzmq socket type; only used when creating a new socket.
        shadow : int or None
            Address of an existing libzmq socket to wrap instead of creating
            one (e.g. the integer returned by ``underlying``).

        Raises
        ------
        ZMQError
            If the resulting socket handle is NULL.
        """
        self.context = context
        if shadow is not None:
            self._zmq_socket = ffi.cast("void *", shadow)
            self._shadow = True
        else:
            self._shadow = False
            self._zmq_socket = C.zmq_socket(context._zmq_ctx, socket_type)
        if self._zmq_socket == ffi.NULL:
            raise ZMQError()
        self._closed = False
        if context:
            self._ref = context._add_socket(self)

    @staticmethod
    def _encode_address(address):
        """Return *address* as bytes, UTF-8 encoding text addresses."""
        if isinstance(address, unicode):
            address = address.encode('utf8')
        return address

    @property
    def underlying(self):
        """The address of the underlying libzmq socket"""
        return int(ffi.cast('size_t', self._zmq_socket))

    @property
    def closed(self):
        """True once close() has run (None if __init__ never completed)."""
        return self._closed

    def close(self, linger=None):
        """Close the underlying libzmq socket; a no-op if already closed.

        Parameters
        ----------
        linger : int or None
            If given, applied as ``zmq.LINGER`` immediately before closing.
            (Previously this argument was accepted but silently ignored.)

        Returns
        -------
        int
            The return code of ``zmq_close``, or 0 if nothing was closed.
        """
        rc = 0
        if not self._closed and hasattr(self, '_zmq_socket'):
            if self._zmq_socket is not None:
                if linger is not None:
                    try:
                        self.set(zmq.LINGER, linger)
                    except ZMQError:
                        # Best effort: failing to adjust LINGER (e.g. the
                        # socket is already gone) must not prevent close.
                        pass
                rc = C.zmq_close(self._zmq_socket)
            self._closed = True
            if self.context:
                self.context._rm_socket(self._ref)
        return rc

    def bind(self, address):
        """Bind the socket to *address* (text is UTF-8 encoded).

        Raises
        ------
        ZMQError
            On failure. When IPC_PATH_MAX_LEN is known and libzmq reports
            ENAMETOOLONG, the message explains the ipc path length limit.
        """
        address = self._encode_address(address)
        rc = C.zmq_bind(self._zmq_socket, address)
        if rc < 0:
            # Read errno once, before anything else can overwrite it.
            err = C.zmq_errno()
            if IPC_PATH_MAX_LEN and err == errno_mod.ENAMETOOLONG:
                # py3compat: address is bytes, but msg wants str
                if str is unicode:
                    address = address.decode('utf-8', 'replace')
                path = address.split('://', 1)[-1]
                msg = ('ipc path "{0}" is longer than {1} '
                                'characters (sizeof(sockaddr_un.sun_path)).'
                                .format(path, IPC_PATH_MAX_LEN))
                raise ZMQError(err, msg=msg)
            else:
                _check_rc(rc)

    def unbind(self, address):
        """Unbind from *address* (requires libzmq >= 3.2)."""
        _check_version((3,2), "unbind")
        address = self._encode_address(address)
        rc = C.zmq_unbind(self._zmq_socket, address)
        _check_rc(rc)

    def connect(self, address):
        """Connect the socket to *address* (text is UTF-8 encoded)."""
        address = self._encode_address(address)
        rc = C.zmq_connect(self._zmq_socket, address)
        _check_rc(rc)

    def disconnect(self, address):
        """Disconnect from *address* (requires libzmq >= 3.2)."""
        _check_version((3,2), "disconnect")
        address = self._encode_address(address)
        rc = C.zmq_disconnect(self._zmq_socket, address)
        _check_rc(rc)

    def set(self, option, value):
        """Set socket option *option* to *value*.

        Raises
        ------
        TypeError
            If *value* is text, or if a bytes value is given for an option
            not listed in ``zmq.constants.bytes_sockopts``.
        ZMQError
            If ``zmq_setsockopt`` fails.
        """
        length = None
        if isinstance(value, unicode):
            raise TypeError("unicode not allowed, use bytes")

        if isinstance(value, bytes):
            if option not in zmq.constants.bytes_sockopts:
                raise TypeError("not a bytes sockopt: %s" % option)
            length = len(value)

        c_data = initialize_opt_pointer(option, value, length)

        c_value_pointer = c_data[0]
        c_sizet = c_data[1]

        rc = C.zmq_setsockopt(self._zmq_socket,
                               option,
                               ffi.cast('void*', c_value_pointer),
                               c_sizet)
        _check_rc(rc)

    def get(self, option):
        """Return the current value of socket option *option*.

        Bytes options are read into a 255-byte buffer and returned as bytes,
        with one trailing NUL stripped, except for ``zmq.IDENTITY``
        (presumably because identities are arbitrary binary data). Other
        options are returned as ints.
        """
        c_data = new_pointer_from_opt(option, length=255)

        c_value_pointer = c_data[0]
        c_sizet_pointer = c_data[1]

        rc = C.zmq_getsockopt(self._zmq_socket,
                               option,
                               c_value_pointer,
                               c_sizet_pointer)
        _check_rc(rc)

        sz = c_sizet_pointer[0]
        v = value_from_opt_pointer(option, c_value_pointer, sz)
        if option != zmq.IDENTITY and option in zmq.constants.bytes_sockopts and v.endswith(b'\0'):
            v = v[:-1]
        return v

    def send(self, message, flags=0, copy=False, track=False):
        """Send *message* as a single frame.

        Parameters
        ----------
        message : bytes or Frame
            Payload; a Frame is sent via its ``bytes``. Text is rejected.
        flags : int
            Passed straight through to ``zmq_msg_send``.
        copy : bool
            Ignored; the payload is always copied into a new libzmq message.
        track : bool
            If true, return a freshly constructed ``zmq.MessageTracker()``
            (it is not tied to this message).

        Raises
        ------
        TypeError
            If *message* is text.
        ZMQError
            If libzmq fails to allocate or send the message.
        """
        if isinstance(message, unicode):
            raise TypeError("Message must be in bytes, not an unicode Object")

        if isinstance(message, Frame):
            message = message.bytes

        zmq_msg = ffi.new('zmq_msg_t*')
        c_message = ffi.new('char[]', message)
        rc = C.zmq_msg_init_size(zmq_msg, len(message))
        # Don't copy into a message libzmq failed to initialize.
        _check_rc(rc)
        C.memcpy(C.zmq_msg_data(zmq_msg), c_message, len(message))

        rc = C.zmq_msg_send(zmq_msg, self._zmq_socket, flags)
        try:
            # Inspect errno before zmq_msg_close gets a chance to change it.
            _check_rc(rc)
        finally:
            C.zmq_msg_close(zmq_msg)

        if track:
            return zmq.MessageTracker()

    def recv(self, flags=0, copy=True, track=False):
        """Receive a single frame.

        Parameters
        ----------
        flags : int
            Passed straight through to ``zmq_msg_recv``.
        copy : bool
            If true return the payload as bytes, otherwise return a Frame.
        track : bool
            Forwarded to the Frame constructor.

        Raises
        ------
        ZMQError
            If ``zmq_msg_recv`` fails.
        """
        zmq_msg = ffi.new('zmq_msg_t*')
        C.zmq_msg_init(zmq_msg)

        rc = C.zmq_msg_recv(zmq_msg, self._zmq_socket, flags)
        try:
            # Raises on failure; errno is read before the message is closed.
            _check_rc(rc)
            _buffer = ffi.buffer(C.zmq_msg_data(zmq_msg),
                                 C.zmq_msg_size(zmq_msg))
            value = _buffer[:]
        finally:
            C.zmq_msg_close(zmq_msg)

        frame = Frame(value, track=track)
        # Use this class's own accessor: `getsockopt` only exists on
        # subclasses, so calling it broke direct use of the backend class.
        frame.more = self.get(RCVMORE)

        if copy:
            return frame.bytes
        else:
            return frame

    def monitor(self, addr, events=-1):
        """s.monitor(addr, events)

        Start publishing socket events on inproc.
        See libzmq docs for zmq_monitor for details.

        Note: requires libzmq >= 3.2

        Parameters
        ----------
        addr : str or None
            The inproc url used for monitoring. Passing None as
            the addr will cause an existing socket monitor to be
            deregistered. Text is UTF-8 encoded.
        events : int [default: zmq.EVENT_ALL]
            The zmq event bitmask for which events will be sent to the monitor.
            Any negative value selects zmq.EVENT_ALL.

        Raises
        ------
        ZMQError
            If ``zmq_socket_monitor`` fails (previously ignored).
        """

        _check_version((3,2), "monitor")
        if events < 0:
            events = zmq.EVENT_ALL
        if addr is None:
            addr = ffi.NULL
        else:
            addr = self._encode_address(addr)
        rc = C.zmq_socket_monitor(self._zmq_socket, addr, events)
        _check_rc(rc)
242
243
# Public names exported by this backend module.
__all__ = [
    'Socket',
    'IPC_PATH_MAX_LEN',
]
245