thread.py revision 781d71db
1"""ZAP Authenticator in a Python Thread.
2
3.. versionadded:: 14.1
4"""
5
6# Copyright (C) PyZMQ Developers
7# Distributed under the terms of the Modified BSD License.
8
9import logging
10from threading import Thread
11
12import zmq
13from zmq.utils import jsonapi
14from zmq.utils.strtypes import bytes, unicode, b, u
15
16from .base import Authenticator
17
class AuthenticationThread(Thread):
    """A Thread for running a zmq Authenticator

    This is run in the background by ThreadAuthenticator. It owns an
    Authenticator (and its ZAP socket) and receives configuration commands
    from the front-end over a PAIR socket connected to ``endpoint``.
    """

    def __init__(self, context, endpoint, encoding='utf-8', log=None):
        """Create the thread (it is not started here).

        :param context: zmq Context to use; falls back to the global instance.
        :param endpoint: endpoint the front-end PAIR socket is bound to.
        :param encoding: text encoding used to decode command frames.
        :param log: logger; defaults to the 'zmq.auth' logger.
        """
        super(AuthenticationThread, self).__init__()
        # Resolve the context once and use the resolved value everywhere, so
        # that passing context=None actually gets the global Context instead
        # of crashing on None.socket().
        self.context = context or zmq.Context.instance()
        self.encoding = encoding
        self.log = log = log or logging.getLogger('zmq.auth')
        self.authenticator = Authenticator(self.context, encoding=encoding, log=log)

        # create a socket to communicate back to main thread.
        self.pipe = self.context.socket(zmq.PAIR)
        self.pipe.linger = 1
        self.pipe.connect(endpoint)

    def run(self):
        """Start the Authentication Agent thread task

        Polls the command pipe and the ZAP socket until a TERMINATE command
        arrives or polling raises a ZMQError, then closes the pipe and stops
        the authenticator.
        """
        self.authenticator.start()
        zap = self.authenticator.zap_socket
        poller = zmq.Poller()
        poller.register(self.pipe, zmq.POLLIN)
        poller.register(zap, zmq.POLLIN)
        while True:
            try:
                socks = dict(poller.poll())
            except zmq.ZMQError:
                break  # interrupted

            # The pipe is checked first: a TERMINATE exits the loop before
            # any ZAP request that became ready in the same poll is handled.
            if self.pipe in socks and socks[self.pipe] == zmq.POLLIN:
                terminate = self._handle_pipe()
                if terminate:
                    break

            if zap in socks and socks[zap] == zmq.POLLIN:
                self._handle_zap()

        self.pipe.close()
        self.authenticator.stop()

    def _handle_zap(self):
        """
        Handle a message from the ZAP socket.

        Empty messages are ignored; anything else is passed to the
        Authenticator.
        """
        msg = self.authenticator.zap_socket.recv_multipart()
        if not msg:
            return
        self.authenticator.handle_zap_message(msg)

    def _handle_pipe(self):
        """
        Handle a message from front-end API.

        Failures while applying a command are logged rather than raised, so a
        malformed or failing command cannot kill the authentication thread.

        :returns: True if the thread should terminate, False otherwise.
        """
        terminate = False

        # Get the whole message off the pipe in one go
        msg = self.pipe.recv_multipart()

        if msg is None:
            terminate = True
            return terminate

        command = msg[0]
        self.log.debug("auth received API command %r", command)

        if command == b'ALLOW':
            addresses = [u(m, self.encoding) for m in msg[1:]]
            try:
                self.authenticator.allow(*addresses)
            except Exception:
                self.log.exception("Failed to allow %s", addresses)

        elif command == b'DENY':
            addresses = [u(m, self.encoding) for m in msg[1:]]
            try:
                self.authenticator.deny(*addresses)
            except Exception:
                self.log.exception("Failed to deny %s", addresses)

        elif command == b'PLAIN':
            # Frames: [b'PLAIN', domain, json-encoded passwords].
            try:
                domain = u(msg[1], self.encoding)
                json_passwords = msg[2]
                self.authenticator.configure_plain(domain, jsonapi.loads(json_passwords))
            except Exception:
                # Deliberately not logging the frames: they contain passwords.
                self.log.exception("Failed to configure PLAIN authentication")

        elif command == b'CURVE':
            # Frames: [b'CURVE', domain, location].
            try:
                # For now we don't do anything with domains
                domain = u(msg[1], self.encoding)

                # If location is CURVE_ALLOW_ANY, allow all clients. Otherwise
                # treat location as a directory that holds the certificates.
                location = u(msg[2], self.encoding)
                self.authenticator.configure_curve(domain, location)
            except Exception:
                self.log.exception("Failed to configure CURVE authentication")

        elif command == b'TERMINATE':
            terminate = True

        else:
            self.log.error("Invalid auth command from API: %r", command)

        return terminate
119
def _inherit_docstrings(cls):
    """Inherit docstrings from Authenticator, so we don't duplicate them.

    For each public attribute of ``cls`` that has no docstring, copy the
    docstring of the same-named attribute on ``Authenticator`` when that
    attribute exists and is documented. Public attributes with no upstream
    counterpart are left untouched instead of raising AttributeError.

    :param cls: class to patch in place.
    :returns: ``cls``, so this can be used as a class decorator.
    """
    for name, method in cls.__dict__.items():
        if name.startswith('_'):
            continue
        if getattr(method, '__doc__', None):
            # Already documented locally; keep the local docstring.
            continue
        upstream_method = getattr(Authenticator, name, None)
        if upstream_method is None:
            # No counterpart on Authenticator: nothing to inherit.
            continue
        upstream_doc = getattr(upstream_method, '__doc__', None)
        if upstream_doc:
            method.__doc__ = upstream_doc
    return cls
129
@_inherit_docstrings
class ThreadAuthenticator(object):
    """Run ZAP authentication in a background thread"""

    def __init__(self, context=None, encoding='utf-8', log=None):
        self.context = context or zmq.Context.instance()
        self.log = log
        self.encoding = encoding
        self.pipe = None
        # id(self) keeps the inproc endpoint distinct between live instances.
        self.pipe_endpoint = "inproc://{0}.inproc".format(id(self))
        self.thread = None

    # The public configuration methods below intentionally have no
    # docstrings: @_inherit_docstrings copies them from Authenticator.
    # They forward commands over self.pipe, so start() must be called first.

    def allow(self, *addresses):
        self.pipe.send_multipart([b'ALLOW'] + [b(a, self.encoding) for a in addresses])

    def deny(self, *addresses):
        self.pipe.send_multipart([b'DENY'] + [b(a, self.encoding) for a in addresses])

    def configure_plain(self, domain='*', passwords=None):
        self.pipe.send_multipart([b'PLAIN', b(domain, self.encoding), jsonapi.dumps(passwords or {})])

    def configure_curve(self, domain='*', location=''):
        domain = b(domain, self.encoding)
        location = b(location, self.encoding)
        self.pipe.send_multipart([b'CURVE', domain, location])

    def start(self):
        """Start the authentication thread"""
        # create a socket to communicate with auth thread.
        self.pipe = self.context.socket(zmq.PAIR)
        self.pipe.linger = 1
        self.pipe.bind(self.pipe_endpoint)
        self.thread = AuthenticationThread(self.context, self.pipe_endpoint, encoding=self.encoding, log=self.log)
        self.thread.start()

    def stop(self):
        """Stop the authentication thread"""
        if self.pipe:
            # Only a live thread can act on TERMINATE. Sending on a PAIR socket
            # whose peer is gone may block (or fail if the context was
            # terminated), so a dead thread is simply discarded.
            if self.is_alive():
                self.pipe.send(b'TERMINATE')
                self.thread.join()
            self.thread = None
            self.pipe.close()
            self.pipe = None

    def is_alive(self):
        """Is the ZAP thread currently running?"""
        if self.thread and self.thread.is_alive():
            return True
        return False

    def __del__(self):
        # If __init__ did not complete there is no 'pipe' attribute; there is
        # then nothing to clean up, and stop() would raise AttributeError.
        if getattr(self, 'pipe', None) is not None:
            self.stop()
183
# Public API of this module: only the front-end class is exported.
__all__ = ["ThreadAuthenticator"]
185