1# -*- coding: utf8 -*-
2
3# Copyright (C) PyZMQ Developers
4# Distributed under the terms of the Modified BSD License.
5
import functools
import logging
import os
import shutil
import sys
import tempfile

import zmq.auth
from zmq.auth.ioloop import IOLoopAuthenticator
from zmq.auth.thread import ThreadAuthenticator

from zmq.eventloop import ioloop, zmqstream
from zmq.tests import (BaseZMQTestCase, SkipTest)
18
class BaseAuthTestCase(BaseZMQTestCase):
    """Common fixture for ZAP authentication tests.

    Skips when libzmq lacks security support, starts the authenticator
    returned by ``make_auth()`` and generates fresh CURVE certificates in a
    temporary directory that is removed again in ``tearDown()``.
    """

    def setUp(self):
        if zmq.zmq_version_info() < (4,0):
            raise SkipTest("security is new in libzmq 4.0")
        try:
            zmq.curve_keypair()
        except zmq.ZMQError:
            raise SkipTest("security requires libzmq to be linked against libsodium")
        super(BaseAuthTestCase, self).setUp()
        # enable debug logging while we run tests
        logging.getLogger('zmq.auth').setLevel(logging.DEBUG)
        self.auth = self.make_auth()
        self.auth.start()
        try:
            certs = self.create_certs()
        except Exception:
            # unittest does not call tearDown() when setUp() fails, so stop the
            # authenticator here instead of letting it outlive the test.
            self.auth.stop()
            self.auth = None
            raise
        self.base_dir, self.public_keys_dir, self.secret_keys_dir = certs

    def make_auth(self):
        """Return the (not yet started) authenticator under test.

        Subclasses must override this.
        """
        raise NotImplementedError()

    def tearDown(self):
        """Stop the authenticator, delete the certificates, run base teardown.

        Each step runs even if an earlier one raises.
        """
        try:
            if self.auth:
                self.auth.stop()
                self.auth = None
        finally:
            try:
                self.remove_certs(self.base_dir)
            finally:
                super(BaseAuthTestCase, self).tearDown()

    def create_certs(self):
        """Create CURVE certificates for a test.

        Generates "server" and "client" keypairs with
        ``zmq.auth.create_certificates`` inside a fresh temporary directory,
        then moves the public (``*.key``) files into ``public_keys`` and the
        secret (``*.key_secret``) files into ``private_keys``.

        Returns:
            tuple: ``(base_dir, public_keys_dir, secret_keys_dir)``. The caller
            owns ``base_dir`` and is responsible for removing it.
        """
        base_dir = tempfile.mkdtemp()
        try:
            keys_dir = os.path.join(base_dir, 'certificates')
            public_keys_dir = os.path.join(base_dir, 'public_keys')
            secret_keys_dir = os.path.join(base_dir, 'private_keys')

            for directory in (keys_dir, public_keys_dir, secret_keys_dir):
                os.mkdir(directory)

            # The certificate files are written into keys_dir; the returned
            # paths are not needed because we sort the directory afterwards.
            zmq.auth.create_certificates(keys_dir, "server")
            zmq.auth.create_certificates(keys_dir, "client")

            # Note: ".key_secret" does not end with ".key", so the two
            # suffix checks are mutually exclusive.
            for key_file in os.listdir(keys_dir):
                if key_file.endswith(".key"):
                    destination = public_keys_dir
                elif key_file.endswith(".key_secret"):
                    destination = secret_keys_dir
                else:
                    continue
                shutil.move(os.path.join(keys_dir, key_file), destination)
        except Exception:
            # Don't leak the temporary directory if generation fails.
            shutil.rmtree(base_dir, ignore_errors=True)
            raise

        return (base_dir, public_keys_dir, secret_keys_dir)

    def remove_certs(self, base_dir):
        """Remove certificates for a test (recursively deletes ``base_dir``)."""
        shutil.rmtree(base_dir)

    def load_certs(self, secret_keys_dir):
        """Return server and client certificate keys.

        Reads ``server.key_secret`` and ``client.key_secret`` from
        ``secret_keys_dir`` via ``zmq.auth.load_certificate``.

        Returns:
            tuple: ``(server_public, server_secret, client_public, client_secret)``.
        """
        server_secret_file = os.path.join(secret_keys_dir, "server.key_secret")
        client_secret_file = os.path.join(secret_keys_dir, "client.key_secret")

        server_public, server_secret = zmq.auth.load_certificate(server_secret_file)
        client_public, client_secret = zmq.auth.load_certificate(client_secret_file)

        return server_public, server_secret, client_public, client_secret
88
89
class TestThreadAuthentication(BaseAuthTestCase):
    """Test authentication running in a thread"""

    def make_auth(self):
        """Return a ThreadAuthenticator bound to the test's context."""
        return ThreadAuthenticator(self.context)

    def can_connect(self, server, client, timeout=1000):
        """Check if client can connect to server using tcp transport.

        Binds ``server`` to a random loopback port, connects ``client`` to it
        and sends one multipart message from server to client.

        Args:
            server: sending socket (PUSH in these tests).
            client: receiving socket (PULL in these tests).
            timeout: milliseconds to wait for the message on ``client``.

        Returns:
            True if the message arrived within ``timeout`` (it is also
            asserted to be unchanged), False otherwise.
        """
        iface = 'tcp://127.0.0.1'
        port = server.bind_to_random_port(iface)
        client.connect("%s:%i" % (iface, port))
        msg = [b"Hello World"]
        server.send_multipart(msg)
        if not client.poll(timeout):
            return False
        self.assertEqual(client.recv_multipart(), msg)
        return True

    def _plain_pair(self, password):
        """Return a (server, client) PUSH/PULL pair set up for PLAIN as user 'admin'."""
        server = self.socket(zmq.PUSH)
        server.plain_server = True
        client = self.socket(zmq.PULL)
        client.plain_username = b'admin'
        client.plain_password = password
        return server, client

    def _curve_pair(self, certs):
        """Return a (server, client) PUSH/PULL pair configured for CURVE.

        ``certs`` is the 4-tuple returned by ``load_certs``.
        """
        server_public, server_secret, client_public, client_secret = certs
        server = self.socket(zmq.PUSH)
        server.curve_publickey = server_public
        server.curve_secretkey = server_secret
        server.curve_server = True
        client = self.socket(zmq.PULL)
        client.curve_publickey = client_public
        client.curve_secretkey = client_secret
        client.curve_serverkey = server_public
        return server, client

    def test_null(self):
        """threaded auth - NULL"""
        # A default NULL connection should always succeed, and not
        # go through our authentication infrastructure at all.
        self.auth.stop()
        self.auth = None

        server = self.socket(zmq.PUSH)
        client = self.socket(zmq.PULL)
        self.assertTrue(self.can_connect(server, client))

        # By setting a domain we switch on authentication for NULL sockets,
        # but the authenticator has been stopped, so no policy is applied.
        # The client connection should still be allowed.
        server = self.socket(zmq.PUSH)
        server.zap_domain = b'global'
        client = self.socket(zmq.PULL)
        self.assertTrue(self.can_connect(server, client))

    def test_blacklist(self):
        """threaded auth - Blacklist"""
        # Blacklist 127.0.0.1, connection should fail
        self.auth.deny('127.0.0.1')
        server = self.socket(zmq.PUSH)
        # Setting a domain switches on authentication for NULL sockets,
        # so the deny rule above is enforced.
        server.zap_domain = b'global'
        client = self.socket(zmq.PULL)
        self.assertFalse(self.can_connect(server, client))

    def test_whitelist(self):
        """threaded auth - Whitelist"""
        # Whitelist 127.0.0.1, connection should pass
        self.auth.allow('127.0.0.1')
        server = self.socket(zmq.PUSH)
        # Setting a domain switches on authentication for NULL sockets,
        # so the allow rule above is consulted.
        server.zap_domain = b'global'
        client = self.socket(zmq.PULL)
        self.assertTrue(self.can_connect(server, client))

    def test_plain(self):
        """threaded auth - PLAIN"""

        # Try PLAIN authentication - without configuring server, connection should fail
        server, client = self._plain_pair(b'Password')
        self.assertFalse(self.can_connect(server, client))

        # Try PLAIN authentication - with server configured, connection should pass
        server, client = self._plain_pair(b'Password')
        self.auth.configure_plain(domain='*', passwords={'admin': 'Password'})
        self.assertTrue(self.can_connect(server, client))

        # Try PLAIN authentication - with bogus credentials, connection should fail
        server, client = self._plain_pair(b'Bogus')
        self.assertFalse(self.can_connect(server, client))

        # Remove authenticator and check that a normal connection works
        self.auth.stop()
        self.auth = None

        server = self.socket(zmq.PUSH)
        client = self.socket(zmq.PULL)
        self.assertTrue(self.can_connect(server, client))
        client.close()
        server.close()

    def test_curve(self):
        """threaded auth - CURVE"""
        self.auth.allow('127.0.0.1')
        certs = self.load_certs(self.secret_keys_dir)

        # Try CURVE authentication - without configuring server, connection should fail
        server, client = self._curve_pair(certs)
        self.assertFalse(self.can_connect(server, client))

        # Try CURVE authentication - with server configured to CURVE_ALLOW_ANY, connection should pass
        self.auth.configure_curve(domain='*', location=zmq.auth.CURVE_ALLOW_ANY)
        server, client = self._curve_pair(certs)
        self.assertTrue(self.can_connect(server, client))

        # Try CURVE authentication - with server configured, connection should pass
        self.auth.configure_curve(domain='*', location=self.public_keys_dir)
        server, client = self._curve_pair(certs)
        self.assertTrue(self.can_connect(server, client))

        # Remove authenticator and check that a normal connection works
        self.auth.stop()
        self.auth = None

        # Try connecting using NULL and no authentication enabled, connection should pass
        server = self.socket(zmq.PUSH)
        client = self.socket(zmq.PULL)
        self.assertTrue(self.can_connect(server, client))
238
239
240def with_ioloop(method, expect_success=True):
241    """decorator for running tests with an IOLoop"""
242    def test_method(self):
243        r = method(self)
244
245        loop = self.io_loop
246        if expect_success:
247            self.pullstream.on_recv(self.on_message_succeed)
248        else:
249            self.pullstream.on_recv(self.on_message_fail)
250
251        t = loop.time()
252        loop.add_callback(self.attempt_connection)
253        loop.add_callback(self.send_msg)
254        if expect_success:
255            loop.add_timeout(t + 1, self.on_test_timeout_fail)
256        else:
257            loop.add_timeout(t + 1, self.on_test_timeout_succeed)
258
259        loop.start()
260        if self.fail_msg:
261            self.fail(self.fail_msg)
262
263        return r
264    return test_method
265
def should_auth(method):
    """Wrap *method* as an IOLoop test in which the connection must succeed."""
    return with_ioloop(method, expect_success=True)
268
def should_not_auth(method):
    """Wrap *method* as an IOLoop test in which the connection must be rejected."""
    return with_ioloop(method, expect_success=False)
271
class TestIOLoopAuthentication(BaseAuthTestCase):
    """Test authentication running in ioloop"""

    def setUp(self):
        self.fail_msg = None
        # The loop must exist before the base setUp, because make_auth() uses it.
        self.io_loop = ioloop.IOLoop()
        try:
            super(TestIOLoopAuthentication, self).setUp()
        except BaseException:
            # unittest does not call tearDown() when setUp() fails (including
            # a SkipTest raised by the base class), so release the loop here.
            self.io_loop.close(all_fds=True)
            raise
        self.server = self.socket(zmq.PUSH)
        self.client = self.socket(zmq.PULL)
        self.pushstream = zmqstream.ZMQStream(self.server, self.io_loop)
        self.pullstream = zmqstream.ZMQStream(self.client, self.io_loop)

    def make_auth(self):
        """Return an IOLoopAuthenticator running on this test's loop."""
        return IOLoopAuthenticator(self.context, io_loop=self.io_loop)

    def tearDown(self):
        """Stop the authenticator, close the loop, then run base teardown.

        The loop is closed and the base teardown runs even if stopping the
        authenticator raises.
        """
        try:
            if self.auth:
                self.auth.stop()
                self.auth = None
        finally:
            # all_fds=True also closes the file descriptors registered with the loop.
            self.io_loop.close(all_fds=True)
            super(TestIOLoopAuthentication, self).tearDown()

    def attempt_connection(self):
        """Bind the server to a random loopback TCP port and connect the client to it."""
        iface = 'tcp://127.0.0.1'
        port = self.server.bind_to_random_port(iface)
        self.client.connect("%s:%i" % (iface, port))

    def send_msg(self):
        """Send a message from server to a client"""
        msg = [b"Hello World"]
        self.pushstream.send_multipart(msg)

    def on_message_succeed(self, frames):
        """A message was received, as expected."""
        if frames != [b"Hello World"]:
            self.fail_msg = "Unexpected message received"
        self.io_loop.stop()

    def on_message_fail(self, frames):
        """A message was received, unexpectedly."""
        self.fail_msg = 'Received messaged unexpectedly, security failed'
        self.io_loop.stop()

    def on_test_timeout_succeed(self):
        """Test timer expired, indicates test success"""
        self.io_loop.stop()

    def on_test_timeout_fail(self):
        """Test timer expired, indicates test failure"""
        self.fail_msg = 'Test timed out'
        self.io_loop.stop()

    def _set_plain_credentials(self, password):
        """Log the client in as 'admin' with *password* and make the server a PLAIN server."""
        self.client.plain_username = b'admin'
        self.client.plain_password = password
        self.server.plain_server = True

    def _set_curve_keys(self):
        """Load the test certificates and configure server and client for CURVE."""
        certs = self.load_certs(self.secret_keys_dir)
        server_public, server_secret, client_public, client_secret = certs

        self.server.curve_publickey = server_public
        self.server.curve_secretkey = server_secret
        self.server.curve_server = True

        self.client.curve_publickey = client_public
        self.client.curve_secretkey = client_secret
        self.client.curve_serverkey = server_public

    @should_auth
    def test_none(self):
        """ioloop auth - NONE"""
        # A default NULL connection should always succeed, and not
        # go through our authentication infrastructure at all.
        # no auth should be running
        self.auth.stop()
        self.auth = None

    @should_auth
    def test_null(self):
        """ioloop auth - NULL"""
        # By setting a domain we switch on authentication for NULL sockets,
        # though no policies are configured yet. The client connection
        # should still be allowed.
        self.server.zap_domain = b'global'

    @should_not_auth
    def test_blacklist(self):
        """ioloop auth - Blacklist"""
        # Blacklist 127.0.0.1, connection should fail
        self.auth.deny('127.0.0.1')
        self.server.zap_domain = b'global'

    @should_auth
    def test_whitelist(self):
        """ioloop auth - Whitelist"""
        # Whitelist 127.0.0.1, connection should pass
        self.auth.allow('127.0.0.1')

        self.server.setsockopt(zmq.ZAP_DOMAIN, b'global')

    @should_not_auth
    def test_plain_unconfigured_server(self):
        """ioloop auth - PLAIN, unconfigured server"""
        # Try PLAIN authentication - without configuring server, connection should fail
        self._set_plain_credentials(b'Password')

    @should_auth
    def test_plain_configured_server(self):
        """ioloop auth - PLAIN, configured server"""
        # Try PLAIN authentication - with server configured, connection should pass
        self._set_plain_credentials(b'Password')
        self.auth.configure_plain(domain='*', passwords={'admin': 'Password'})

    @should_not_auth
    def test_plain_bogus_credentials(self):
        """ioloop auth - PLAIN, bogus credentials"""
        self._set_plain_credentials(b'Bogus')

        self.auth.configure_plain(domain='*', passwords={'admin': 'Password'})

    @should_not_auth
    def test_curve_unconfigured_server(self):
        """ioloop auth - CURVE, unconfigured server"""
        self.auth.allow('127.0.0.1')
        self._set_curve_keys()

    @should_auth
    def test_curve_allow_any(self):
        """ioloop auth - CURVE, CURVE_ALLOW_ANY"""
        self.auth.allow('127.0.0.1')
        self.auth.configure_curve(domain='*', location=zmq.auth.CURVE_ALLOW_ANY)
        self._set_curve_keys()

    @should_auth
    def test_curve_configured_server(self):
        """ioloop auth - CURVE, configured server"""
        self.auth.allow('127.0.0.1')
        self.auth.configure_curve(domain='*', location=self.public_keys_dir)
        self._set_curve_keys()
432