test_security.py revision d3907f0d
"""Test libzmq security (requires libzmq >= 4.0)"""
2# -*- coding: utf8 -*-
3
4# Copyright (C) PyZMQ Developers
5# Distributed under the terms of the Modified BSD License.
6
7import os
8from threading import Thread
9
10import zmq
11from zmq.tests import (
12    BaseZMQTestCase, SkipTest, PYPY
13)
14from zmq.utils import z85
15
16
# Credentials the ZAP handler accepts for the PLAIN mechanism.
USER, PASS = b"admin", b"password"
19
class TestSecurity(BaseZMQTestCase):
    """Exercise the NULL, PLAIN and CURVE security mechanisms.

    Each test wires two DEALER sockets together over TCP loopback and,
    where authentication is involved, answers the request with a
    one-shot ZAP handler running in a background thread.
    """

    def setUp(self):
        """Skip unless libzmq is >= 4.0 and can generate CURVE keypairs."""
        if zmq.zmq_version_info() < (4, 0):
            raise SkipTest("security is new in libzmq 4.0")
        try:
            zmq.curve_keypair()
        except zmq.ZMQError:
            raise SkipTest("security requires libzmq to be linked against libsodium")
        super(TestSecurity, self).setUp()

    def zap_handler(self):
        """Answer exactly one ZAP request on ``inproc://zeromq.zap.01``.

        The request is accepted (status 200) when the mechanism is CURVE
        or NULL, or when it is PLAIN and the supplied credentials equal
        ``USER`` / ``PASS``.  Anything else is rejected with status 400.
        The REP socket is always closed before returning.
        """
        socket = self.context.socket(zmq.REP)
        socket.bind("inproc://zeromq.zap.01")
        try:
            msg = self.recv_multipart(socket)

            version, sequence, domain, address, identity, mechanism = msg[:6]
            if mechanism == b'PLAIN':
                # PLAIN carries exactly two credential frames after the header.
                username, password = msg[6:]
                authorized = (username == USER and password == PASS)
            else:
                # CURVE and NULL are accepted without inspecting credentials.
                authorized = mechanism in (b'CURVE', b'NULL')

            self.assertEqual(version, b"1.0")
            self.assertEqual(identity, b"IDENT")

            reply = [version, sequence]
            if authorized:
                reply.extend([
                    b"200",
                    b"OK",
                    # user id; bounce() expects to read it back as 'User-Id'
                    b"anonymous",
                    # metadata blob; bounce() expects it to surface as Hello=World
                    b"\5Hello\0\0\0\5World",
                ])
            else:
                reply.extend([
                    b"400",
                    b"Invalid username or password",
                    b"",
                    b"",
                ])
            socket.send_multipart(reply)
        finally:
            socket.close()

    def start_zap(self):
        """Launch ``zap_handler`` in a background thread."""
        self.zap_thread = Thread(target=self.zap_handler)
        self.zap_thread.start()

    def stop_zap(self):
        """Block until the ZAP handler thread has finished."""
        self.zap_thread.join()

    def _connect_loopback(self, server, client):
        """Bind *server* to a random TCP loopback port and connect *client* to it."""
        iface = 'tcp://127.0.0.1'
        port = server.bind_to_random_port(iface)
        client.connect("%s:%i" % (iface, port))

    def bounce(self, server, client, test_metadata=True):
        """Send a random two-frame message client -> server -> client.

        Asserts that the payload survives both hops unchanged.  When
        *test_metadata* is true (and not running on PyPy), also asserts
        that every received frame carries the metadata supplied by
        ``zap_handler`` plus the peer's socket type.
        """
        msg = [os.urandom(64), os.urandom(64)]
        client.send_multipart(msg)
        # copy=False yields frame objects, which are needed for metadata lookup.
        frames = self.recv_multipart(server, copy=False)
        recvd = [frame.bytes for frame in frames]

        if test_metadata and not PYPY:
            try:
                for frame in frames:
                    self.assertEqual(frame.get('User-Id'), 'anonymous')
                    self.assertEqual(frame.get('Hello'), 'World')
                    self.assertEqual(frame['Socket-Type'], 'DEALER')
            except zmq.ZMQVersionError:
                # NOTE(review): presumably raised when the linked libzmq is
                # too old to expose frame metadata -- confirm; the metadata
                # checks are simply skipped in that case.
                pass

        self.assertEqual(recvd, msg)
        server.send_multipart(recvd)
        msg2 = self.recv_multipart(client)
        self.assertEqual(msg2, msg)

    def test_null(self):
        """test NULL (default) security"""
        server = self.socket(zmq.DEALER)
        client = self.socket(zmq.DEALER)
        # Both spellings of the option attribute are read on purpose here.
        self.assertEqual(client.MECHANISM, zmq.NULL)
        self.assertEqual(server.mechanism, zmq.NULL)
        self.assertEqual(client.plain_server, 0)
        self.assertEqual(server.plain_server, 0)
        self._connect_loopback(server, client)
        # No ZAP handler is running, so there is no metadata to check.
        self.bounce(server, client, False)

    def test_plain(self):
        """test PLAIN authentication"""
        server = self.socket(zmq.DEALER)
        server.identity = b'IDENT'
        client = self.socket(zmq.DEALER)
        self.assertEqual(client.plain_username, b'')
        self.assertEqual(client.plain_password, b'')
        client.plain_username = USER
        client.plain_password = PASS
        self.assertEqual(client.getsockopt(zmq.PLAIN_USERNAME), USER)
        self.assertEqual(client.getsockopt(zmq.PLAIN_PASSWORD), PASS)
        self.assertEqual(client.plain_server, 0)
        self.assertEqual(server.plain_server, 0)
        server.plain_server = True
        self.assertEqual(server.mechanism, zmq.PLAIN)
        self.assertEqual(client.mechanism, zmq.PLAIN)

        assert not client.plain_server
        assert server.plain_server

        self.start_zap()
        try:
            self._connect_loopback(server, client)
            self.bounce(server, client)
        finally:
            # Join the handler even when an assertion fails, so the thread
            # does not outlive the test.
            self.stop_zap()

    def skip_plain_inauth(self):
        """test PLAIN failed authentication"""
        server = self.socket(zmq.DEALER)
        server.identity = b'IDENT'
        client = self.socket(zmq.DEALER)
        self.sockets.extend([server, client])
        client.plain_username = USER
        client.plain_password = b'incorrect'
        server.plain_server = True
        self.assertEqual(server.mechanism, zmq.PLAIN)
        self.assertEqual(client.mechanism, zmq.PLAIN)

        self.start_zap()
        try:
            self._connect_loopback(server, client)
            client.send(b'ping')
            # The rejected client's message must never arrive: recv times out.
            server.rcvtimeo = 250
            self.assertRaisesErrno(zmq.EAGAIN, server.recv)
        finally:
            self.stop_zap()

    def test_keypair(self):
        """test curve_keypair"""
        try:
            public, secret = zmq.curve_keypair()
        except zmq.ZMQError:
            raise SkipTest("CURVE unsupported")

        self.assertEqual(type(secret), bytes)
        self.assertEqual(type(public), bytes)
        self.assertEqual(len(secret), 40)
        self.assertEqual(len(public), 40)

        # verify that it is indeed Z85
        bpublic, bsecret = [z85.decode(key) for key in (public, secret)]
        self.assertEqual(type(bsecret), bytes)
        self.assertEqual(type(bpublic), bytes)
        self.assertEqual(len(bsecret), 32)
        self.assertEqual(len(bpublic), 32)

    def test_curve(self):
        """test CURVE encryption"""
        server = self.socket(zmq.DEALER)
        server.identity = b'IDENT'
        client = self.socket(zmq.DEALER)
        self.sockets.extend([server, client])
        try:
            server.curve_server = True
        except zmq.ZMQError as e:
            # will raise EINVAL if not linked against libsodium
            if e.errno == zmq.EINVAL:
                raise SkipTest("CURVE unsupported")
            # Any other failure is a real error; do not continue with a
            # socket that was never switched into CURVE server mode.
            raise

        server_public, server_secret = zmq.curve_keypair()
        client_public, client_secret = zmq.curve_keypair()

        server.curve_secretkey = server_secret
        server.curve_publickey = server_public
        client.curve_serverkey = server_public
        client.curve_publickey = client_public
        client.curve_secretkey = client_secret

        self.assertEqual(server.mechanism, zmq.CURVE)
        self.assertEqual(client.mechanism, zmq.CURVE)

        self.assertEqual(server.get(zmq.CURVE_SERVER), True)
        self.assertEqual(client.get(zmq.CURVE_SERVER), False)

        self.start_zap()
        try:
            self._connect_loopback(server, client)
            self.bounce(server, client)
        finally:
            # Join the handler even when an assertion fails, so the thread
            # does not outlive the test.
            self.stop_zap()
212
213