1## This file is part of Scapy
## See http://www.secdev.org/projects/scapy for more information
3## Copyright (C) Philippe Biondi <phil@secdev.org>
4## This program is published under a GPLv2 license
5
6"""
7Automata with states, transitions and actions.
8"""
9
10from __future__ import with_statement
11import types,itertools,time,os,sys,socket,functools
12from select import select
13from collections import deque
14import _thread
15from .config import conf
16from .utils import do_graph
17from .error import log_interactive
18from .plist import PacketList
19from .data import MTU
20from .supersocket import SuperSocket
21
class ObjectPipe:
    """Queue of Python objects that can be waited on with select().

    Every object pushed with send() is kept in an in-memory deque and
    signalled by writing one byte into an OS pipe, so the pipe's read end
    is readable exactly while objects are pending.
    """
    def __init__(self):
        read_fd, write_fd = os.pipe()
        self.rd = read_fd
        self.wr = write_fd
        self.queue = deque()

    def fileno(self):
        """Return the pipe's read descriptor, for use with select()."""
        return self.rd

    def send(self, obj):
        """Enqueue obj and wake up anybody selecting on this pipe."""
        self.queue.append(obj)
        os.write(self.wr, b"X")

    def recv(self, n=0):
        """Dequeue the oldest object, blocking until one is available.

        n is accepted for socket-like compatibility and ignored.
        """
        os.read(self.rd, 1)  # consume one wake-up byte (blocks if none)
        return self.queue.popleft()
34
35
class Message:
    """Simple record whose attributes are given as keyword arguments.

    Attributes whose name starts with an underscore are hidden from repr().
    """
    def __init__(self, **args):
        vars(self).update(args)

    def __repr__(self):
        public = ["%s=%r" % (name, val)
                  for name, val in vars(self).items()
                  if not name.startswith("_")]
        return "<Message %s>" % " ".join(public)
43
44# Currently does not seem necessary
45# class _meta_instance_state(type):
46#     def __init__(cls, name, bases, dct):
47#         def special_gen(special_method):
48#             def special_wrapper(self):
49#                 print("Calling %s" % special_method)
50#                 return getattr(getattr(self, "im_func"), special_method)
51#             return special_wrapper
52
53#         type.__init__(cls, name, bases, dct)
54#         for i in ["__int__", "__repr__", "__str__", "__index__", "__add__", "__radd__", "__bytes__"]:
55#             setattr(cls, i, property(special_gen(i)))
56
57class _instance_state():
58    def __init__(self, instance):
59        self.im_self = instance.__self__
60        self.im_func = instance.__func__
61        self.im_class = instance.__self__.__class__
62    def __getattr__(self, attr):
63        return getattr(self.im_func, attr)
64
65    def __call__(self, *args, **kargs):
66        return self.im_func(self.im_self, *args, **kargs)
67    def breaks(self):
68        return self.im_self.add_breakpoints(self.im_func)
69    def intercepts(self):
70        return self.im_self.add_interception_points(self.im_func)
71    def unbreaks(self):
72        return self.im_self.remove_breakpoints(self.im_func)
73    def unintercepts(self):
74        return self.im_self.remove_interception_points(self.im_func)
75
76
77##############
78## Automata ##
79##############
80
class ATMT:
    """Decorators used to declare the parts of an automaton.

    Each decorator tags the function it wraps with attributes whose names
    start with ``atmt_``. The ``atmt_type`` attribute holds one of the
    string constants below.
    """
    STATE = "State"
    ACTION = "Action"
    CONDITION = "Condition"
    RECV = "Receive condition"
    TIMEOUT = "Timeout condition"
    IOEVENT = "I/O event"

    class NewStateRequested(Exception):
        """Request for a transition to the state implemented by state_func.

        The target state's positional and keyword arguments are kept in
        args / kargs for run(). The arguments given to the transition's
        actions are set with action_parameters().
        """
        def __init__(self, state_func, automaton, *args, **kargs):
            self.func = state_func
            self.state = state_func.atmt_state
            self.initial = state_func.atmt_initial
            self.error = state_func.atmt_error
            self.final = state_func.atmt_final
            Exception.__init__(self, "Request state [%s]" % self.state)
            self.automaton = automaton
            # NOTE: overwrites the Exception.args set just above with the
            # target state's positional arguments, which run() relies on.
            self.args = args
            self.kargs = kargs
            self.action_parameters()  # no action arguments by default

        def action_parameters(self, *args, **kargs):
            """Record the arguments for the transition's actions; return self."""
            self.action_args = args
            self.action_kargs = kargs
            return self

        def run(self):
            """Call the target state function and return its output."""
            return self.func(self.automaton, *self.args, **self.kargs)

        def __repr__(self):
            return "NewStateRequested(%s)" % self.state

    @staticmethod
    def state(initial=0, final=0, error=0):
        """Declare a state; the decorated method returns NewStateRequested."""
        def deco(f, initial=initial, final=final):
            flags = dict(atmt_type=ATMT.STATE,
                         atmt_state=f.__name__,
                         atmt_initial=initial,
                         atmt_final=final,
                         atmt_error=error)
            for key, val in flags.items():
                setattr(f, key, val)

            def state_wrapper(self, *args, **kargs):
                return ATMT.NewStateRequested(f, self, *args, **kargs)

            state_wrapper.__qualname__ = "%s_wrapper" % f.__name__
            for key, val in flags.items():
                setattr(state_wrapper, key, val)
            # The undecorated function, used when drawing the graph.
            state_wrapper.atmt_origfunc = f
            return state_wrapper
        return deco

    @staticmethod
    def action(cond, prio=0):
        """Attach the decorated method as an action of condition cond.

        A method may be stacked under several action decorators; prio
        orders the actions of one condition.
        """
        def deco(f, cond=cond):
            first_tag = not hasattr(f, "atmt_type")
            if first_tag:
                f.atmt_cond = {}
            f.atmt_type = ATMT.ACTION
            f.atmt_cond[cond.atmt_condname] = prio
            return f
        return deco

    @staticmethod
    def _tag_condition(f, kind, state, prio):
        # Attributes shared by every prioritised transition condition.
        f.atmt_type = kind
        f.atmt_state = state.atmt_state
        f.atmt_condname = f.__name__
        f.atmt_prio = prio
        return f

    @staticmethod
    def condition(state, prio=0):
        """Declare a condition checked right after entering state."""
        def deco(f, state=state):
            return ATMT._tag_condition(f, ATMT.CONDITION, state, prio)
        return deco

    @staticmethod
    def receive_condition(state, prio=0):
        """Declare a condition evaluated on packets received in state."""
        def deco(f, state=state):
            return ATMT._tag_condition(f, ATMT.RECV, state, prio)
        return deco

    @staticmethod
    def ioevent(state, name, prio=0, as_supersocket=None):
        """Declare a condition triggered by input on I/O channel name in state."""
        def deco(f, state=state):
            ATMT._tag_condition(f, ATMT.IOEVENT, state, prio)
            f.atmt_ioname = name
            f.atmt_as_supersocket = as_supersocket
            return f
        return deco

    @staticmethod
    def timeout(state, timeout):
        """Declare a condition fired timeout seconds after entering state."""
        def deco(f, state=state, timeout=timeout):
            f.atmt_type = ATMT.TIMEOUT
            f.atmt_state = state.atmt_state
            f.atmt_timeout = timeout
            f.atmt_condname = f.__name__
            return f
        return deco
178
179class _ATMT_Command:
180    RUN = "RUN"
181    NEXT = "NEXT"
182    FREEZE = "FREEZE"
183    STOP = "STOP"
184    END = "END"
185    EXCEPTION = "EXCEPTION"
186    SINGLESTEP = "SINGLESTEP"
187    BREAKPOINT = "BREAKPOINT"
188    INTERCEPT = "INTERCEPT"
189    ACCEPT = "ACCEPT"
190    REPLACE = "REPLACE"
191    REJECT = "REJECT"
192
class _ATMT_supersocket(SuperSocket):
    """SuperSocket whose peer is an automaton running in the background.

    A datagram socket pair links the two sides. This object keeps spa, and
    spb is handed to the automaton as the external fd of its ioevent
    channel.
    """
    def __init__(self, name, ioevent, automaton, proto, args, kargs):
        self.name = name
        self.ioevent = ioevent
        self.proto = proto
        self.spa, self.spb = socket.socketpair(socket.AF_UNIX, socket.SOCK_DGRAM)
        kargs["external_fd"] = {ioevent: self.spb}
        self.atmt = automaton(*args, **kargs)
        self.atmt.runbg()

    def fileno(self):
        return self.spa.fileno()

    def send(self, s):
        """Send s to the automaton: str is encoded, other non-bytes go through bytes()."""
        kind = type(s)
        if kind is str:
            payload = s.encode()
        elif kind is bytes:
            payload = s
        else:
            payload = bytes(s)
        return self.spa.send(payload)

    def recv(self, n=MTU):
        """Receive up to n bytes, dissected with proto when one was given."""
        data = self.spa.recv(n)
        return data if self.proto is None else self.proto(data)

    def close(self):
        # Currently a no-op: neither the socket pair nor the background
        # automaton is released here.
        pass
217
218class _ATMT_to_supersocket:
219    def __init__(self, name, ioevent, automaton):
220        self.name = name
221        self.ioevent = ioevent
222        self.automaton = automaton
223    def __call__(self, proto, *args, **kargs):
224        return _ATMT_supersocket(self.name, self.ioevent, self.automaton, proto, args, kargs)
225
class Automaton_metaclass(type):
    """Metaclass collecting the ATMT-decorated methods of an automaton class.

    At class creation it builds these per-class tables:

    states          -- state name -> state function
    initial_states  -- state functions declared initial
    conditions      -- state name -> immediate conditions, sorted by prio
    recv_conditions -- state name -> receive conditions, sorted by prio
    ioevents        -- state name -> I/O event conditions, sorted by prio
    timeout         -- state name -> [(delay, func), ...] sorted by delay and
                       terminated by a (None, None) sentinel
    actions         -- condition name -> actions, sorted by their prio
    ionames         -- I/O channel names declared by ioevents
    iosupersockets  -- ioevents declared with as_supersocket; each one also
                       gets a factory attribute on the class
    """
    def __new__(cls, name, bases, dct):
        cls = super(Automaton_metaclass, cls).__new__(cls, name, bases, dct)
        cls.states={}
        cls.state = None
        cls.recv_conditions={}
        cls.conditions={}
        cls.ioevents={}
        cls.timeout={}
        cls.actions={}
        cls.initial_states=[]
        cls.ionames = []
        cls.iosupersockets = []

        # Breadth-first walk over the class and its bases. The first
        # definition found for a name wins, so subclass overrides take
        # precedence over inherited members.
        members = {}
        classes = [cls]
        while classes:
            c = classes.pop(0) # order is important to avoid breaking method overloading
            classes += list(c.__bases__)
            for k,v in c.__dict__.items():
                if k not in members:
                    members[k] = v

        decorated = [v for v in members.values()
                     if type(v) is types.FunctionType and hasattr(v, "atmt_type")]

        # First pass: register states and create the per-state and
        # per-condition tables the second pass appends to.
        for m in decorated:
            if m.atmt_type == ATMT.STATE:
                s = m.atmt_state
                cls.states[s] = m
                cls.recv_conditions[s]=[]
                cls.ioevents[s]=[]
                cls.conditions[s]=[]
                cls.timeout[s]=[]
                if m.atmt_initial:
                    cls.initial_states.append(m)
            elif m.atmt_type in [ATMT.CONDITION, ATMT.RECV, ATMT.TIMEOUT, ATMT.IOEVENT]:
                cls.actions[m.atmt_condname] = []

        # Second pass: attach conditions to their state and actions to
        # their conditions.
        for m in decorated:
            if m.atmt_type == ATMT.CONDITION:
                cls.conditions[m.atmt_state].append(m)
            elif m.atmt_type == ATMT.RECV:
                cls.recv_conditions[m.atmt_state].append(m)
            elif m.atmt_type == ATMT.IOEVENT:
                cls.ioevents[m.atmt_state].append(m)
                cls.ionames.append(m.atmt_ioname)
                if m.atmt_as_supersocket is not None:
                    cls.iosupersockets.append(m)
            elif m.atmt_type == ATMT.TIMEOUT:
                cls.timeout[m.atmt_state].append((m.atmt_timeout, m))
            elif m.atmt_type == ATMT.ACTION:
                for c in m.atmt_cond:
                    cls.actions[c].append(m)

        for v in cls.timeout.values():
            v.sort(key = lambda x: x[0])
            # Sentinel marking "no more timeouts" for the run loop
            v.append((None, None))
        for v in itertools.chain(cls.conditions.values(),
                                 cls.recv_conditions.values(),
                                 cls.ioevents.values()):
            v.sort(key = lambda x: x.atmt_prio)
        for condname,actlst in cls.actions.items():
            # The key is used immediately, so the loop variable is bound correctly.
            actlst.sort(key = lambda x: x.atmt_cond[condname])

        for ioev in cls.iosupersockets:
            setattr(cls, ioev.atmt_as_supersocket, _ATMT_to_supersocket(ioev.atmt_as_supersocket, ioev.atmt_ioname, cls))

        return cls

    def graph(self, **kargs):
        """Draw the automaton's state graph (Graphviz DOT) with do_graph().

        Called on an automaton class. Initial states are drawn blue, final
        states green and error states red. Edges are inferred by scanning
        each function's referenced names and constants for state names.
        Green edges come from state bodies, purple from conditions, red from
        receive conditions, orange from I/O events and blue from timeouts.
        Condition edges are labelled with the condition name followed by
        its actions.

        kargs are passed to do_graph(), whose return value is returned.
        """
        # `self` is the automaton class here; self.__class__ would be this
        # metaclass and give every graph the name "Automaton_metaclass".
        out = ['digraph "%s" {\n' % self.__name__]

        # Keep initial nodes at the beginning for better rendering
        head, tail = [], []
        for st in self.states.values():
            if st.atmt_initial:
                head.insert(0, '\t"%s" [ style=filled, fillcolor=blue, shape=box, root=true];\n' % st.atmt_state)
            elif st.atmt_final:
                tail.append('\t"%s" [ style=filled, fillcolor=green, shape=octagon ];\n' % st.atmt_state)
            elif st.atmt_error:
                tail.append('\t"%s" [ style=filled, fillcolor=red, shape=octagon ];\n' % st.atmt_state)
        out.extend(head)
        out.extend(tail)

        def referenced_states(func):
            # Names and constants used by func's code that are state names.
            code = func.__code__
            return [n for n in code.co_names + code.co_consts if n in self.states]

        def with_actions(label, func):
            # Append the actions attached to func's condition to the label.
            return label + "".join("\\l>[%s]" % act.__name__
                                   for act in self.actions[func.atmt_condname])

        for st in self.states.values():
            for n in referenced_states(st.atmt_origfunc):
                out.append('\t"%s" -> "%s" [ color=green ];\n' % (st.atmt_state, n))

        for color, table in (("purple", self.conditions),
                             ("red", self.recv_conditions),
                             ("orange", self.ioevents)):
            for k, v in table.items():
                for f in v:
                    for n in referenced_states(f):
                        l = with_actions(f.atmt_condname, f)
                        out.append('\t"%s" -> "%s" [label="%s", color=%s];\n' % (k, n, l, color))

        for k, v in self.timeout.items():
            for t, f in v:
                if f is None:
                    continue  # end-of-list sentinel
                for n in referenced_states(f):
                    l = with_actions("%s/%.1fs" % (f.atmt_condname, t), f)
                    out.append('\t"%s" -> "%s" [label="%s",color=blue];\n' % (k, n, l))
        out.append("}\n")
        return do_graph("".join(out), **kargs)
339
340
341
class Automaton(metaclass = Automaton_metaclass):
    """Base class for state machines driven by packets, I/O and timeouts.

    Subclasses describe states, conditions, actions, timeouts and I/O
    events with the ATMT decorators. The machine runs in a separate control
    thread started by __init__. The public API (run, runbg, next, stop,
    accept_packet, reject_packet, ...) drives that thread by exchanging
    Message objects over the cmdin / cmdout object pipes.
    """

    ## Methods to overload
    def parse_args(self, debug=0, store=1, **kargs):
        """Handle the constructor (and start/restart) arguments.

        debug sets the verbosity used by debug(), and store is kept in
        store_packets. Every other keyword argument is kept in socket_kargs
        and later passed to conf.L2listen().
        """
        self.debug_level=debug
        self.socket_kargs = kargs
        self.store_packets = store

    def master_filter(self, pkt):
        """Return True if a received pkt should reach the receive conditions.

        Packets rejected here are only logged (debug level 4).
        """
        return True

    def my_send(self, pkt):
        """Emit pkt on the send socket; override to customise sending."""
        self.send_sock.send(pkt)


    ## Utility classes and exceptions
    class _IO_fdwrapper:
        """Socket-like wrapper around raw read and write file descriptors.

        rd and wr may each be an int, an object with fileno(), or None.
        """
        def __init__(self,rd,wr):
            if rd is not None and type(rd) is not int:
                rd = rd.fileno()
            if wr is not None and type(wr) is not int:
                wr = wr.fileno()
            self.rd = rd
            self.wr = wr
        def fileno(self):
            return self.rd
        def read(self, n=65535):
            return os.read(self.rd, n)
        def write(self, msg):
            return os.write(self.wr,msg)
        def recv(self, n=65535):
            return self.read(n)
        def send(self, msg):
            return self.write(msg)

    class _IO_mixer:
        """Duplex endpoint that reads from rd and writes to wr."""
        def __init__(self,rd,wr):
            self.rd = rd
            self.wr = wr
        def fileno(self):
            if type(self.rd) is int:
                return self.rd
            return self.rd.fileno()
        def recv(self, n=None):
            # Forward n only when given. _IO_fdwrapper would pass None
            # straight to os.read(), which rejects it, while each endpoint
            # has its own sensible default.
            if n is None:
                return self.rd.recv()
            return self.rd.recv(n)
        def read(self, n=None):
            return self.recv(n)
        def send(self, msg):
            return self.wr.send(msg)
        def write(self, msg):
            return self.wr.send(msg)


    class AutomatonException(Exception):
        """Base of the automaton's exceptions.

        state and result optionally record the state name and the value
        involved when the exception was raised.
        """
        def __init__(self, msg, state=None, result=None):
            Exception.__init__(self, msg)
            self.state = state
            self.result = result

    class AutomatonError(AutomatonException):
        pass
    class ErrorState(AutomatonException):   # an error state was reached
        pass
    class Stuck(AutomatonException):        # a state has no way out
        pass
    class AutomatonStopped(AutomatonException):
        pass

    class Breakpoint(AutomatonStopped):
        pass
    class Singlestep(AutomatonStopped):
        pass
    class InterceptionPoint(AutomatonStopped):
        def __init__(self, msg, state=None, result=None, packet=None):
            Automaton.AutomatonStopped.__init__(self, msg, state=state, result=result)
            self.packet = packet

    class CommandMessage(AutomatonException):
        pass


    ## Services
    def debug(self, lvl, msg):
        """Log msg with log_interactive.debug() if debug_level >= lvl."""
        if self.debug_level >= lvl:
            log_interactive.debug(msg)

    def send(self, pkt):
        """Send pkt through my_send() and record a copy in self.packets.

        If the current state is an interception point, the packet is first
        reported to the caller (INTERCEPT message on cmdout). The verdict
        read from cmdin then decides whether it is sent as is (ACCEPT),
        replaced (REPLACE) or dropped (REJECT). Any other verdict raises
        AutomatonError.
        """
        if self.state.state in self.interception_points:
            self.debug(3,"INTERCEPT: packet intercepted: %s" % pkt.summary())
            self.intercepted_packet = pkt
            cmd = Message(type = _ATMT_Command.INTERCEPT, state=self.state, pkt=pkt)
            self.cmdout.send(cmd)
            cmd = self.cmdin.recv()
            self.intercepted_packet = None
            if cmd.type == _ATMT_Command.REJECT:
                self.debug(3,"INTERCEPT: packet rejected")
                return
            elif cmd.type == _ATMT_Command.REPLACE:
                pkt = cmd.pkt
                self.debug(3,"INTERCEPT: packet replaced by: %s" % pkt.summary())
            elif cmd.type == _ATMT_Command.ACCEPT:
                self.debug(3,"INTERCEPT: packet accepted")
            else:
                raise self.AutomatonError("INTERCEPT: unkown verdict: %r" % cmd.type)
        self.my_send(pkt)
        self.debug(3,"SENT : %s" % pkt.summary())
        self.packets.append(pkt.copy())


    ## Internals
    def __init__(self, *args, **kargs):
        """Build the automaton and start its control thread.

        Keyword arguments removed before parse_args() sees them:
        external_fd -- dict mapping an I/O channel name to a descriptor, an
                       object with fileno(), or an (in, out) tuple of those,
                       used instead of an internal ObjectPipe for that channel;
        ll          -- socket class used for sending (default conf.L3socket).
        All other arguments are saved and passed to parse_args().
        """
        external_fd = kargs.pop("external_fd",{})
        self.send_sock_class = kargs.pop("ll", conf.L3socket)
        self.started = _thread.allocate_lock()
        self.threadid = None
        self.breakpointed = None
        self.breakpoints = set()
        self.interception_points = set()
        self.intercepted_packet = None
        self.debug_level=0
        self.init_args=args
        self.init_kargs=kargs
        self.io = type.__new__(type, "IOnamespace",(),{})
        self.oi = type.__new__(type, "IOnamespace",(),{})
        self.cmdin = ObjectPipe()
        self.cmdout = ObjectPipe()
        self.ioin = {}
        self.ioout = {}
        for n in self.ionames:
            # A single external fd serves as both input and output.
            extfd = external_fd.get(n)
            if type(extfd) is not tuple:
                extfd = (extfd,extfd)
            ioin,ioout = extfd
            if ioin is None:
                ioin = ObjectPipe()
            else:
                ioin = self._IO_fdwrapper(ioin,None)
            if ioout is None:
                ioout = ObjectPipe()
            else:
                ioout = self._IO_fdwrapper(None,ioout)

            self.ioin[n] = ioin
            self.ioout[n] = ioout
            ioin.ioname = n
            ioout.ioname = n
            # self.io.<name> writes into ioin (which the automaton selects on)
            # and reads from ioout; self.oi.<name> works the other way round.
            setattr(self.io, n, self._IO_mixer(ioout,ioin))
            setattr(self.oi, n, self._IO_mixer(ioin,ioout))

        # Replace every state method with a per-instance proxy, so that
        # e.g. self.SOME_STATE.breaks() works.
        for stname in self.states:
            setattr(self, stname,
                    _instance_state(getattr(self, stname)))

        self.parse_args(*args, **kargs)

        self.start()

    def __iter__(self):
        return self

    def __del__(self):
        self.stop()

    def _run_condition(self, cond, *args, **kargs):
        """Evaluate one condition.

        If the condition requests a new state (raises NewStateRequested),
        the received packet is recorded for receive conditions, the
        condition's actions run, and the request is re-raised.
        """
        try:
            self.debug(5, "Trying %s [%s]" % (cond.atmt_type, cond.atmt_condname))
            cond(self,*args, **kargs)
        except ATMT.NewStateRequested as state_req:
            self.debug(2, "%s [%s] taken to state [%s]" % (cond.atmt_type, cond.atmt_condname, state_req.state))
            if cond.atmt_type == ATMT.RECV:
                self.packets.append(args[0])
            for action in self.actions[cond.atmt_condname]:
                self.debug(2, "   + Running action [%s]" % action.__name__)
                action(self, *state_req.action_args, **state_req.action_kargs)
            raise
        except Exception as e:
            self.debug(2, "%s [%s] raised exception [%s]" % (cond.atmt_type, cond.atmt_condname, e))
            raise
        else:
            self.debug(2, "%s [%s] not taken" % (cond.atmt_type, cond.atmt_condname))

    def _do_start(self, *args, **kargs):
        """Spawn the control thread running _do_control(*args, **kargs)."""
        _thread.start_new_thread(self._do_control, args, kargs)


    def _do_control(self, *args, **kargs):
        """Body of the control thread.

        Holds self.started for its whole lifetime. It re-parses the
        arguments (args/kargs override the constructor ones), opens the
        sockets, then executes commands read from cmdin by driving the
        _do_iter() generator. It reports BREAKPOINT, SINGLESTEP, END or
        EXCEPTION messages on cmdout.
        """
        with self.started:
            self.threadid = _thread.get_ident()

            # Update default parameters
            a = args+self.init_args[len(args):]
            k = self.init_kargs.copy()
            k.update(kargs)
            self.parse_args(*a,**k)

            # Start the automaton
            self.state=self.initial_states[0](self)
            self.send_sock = self.send_sock_class()
            self.listen_sock = conf.L2listen(**self.socket_kargs)
            self.packets = PacketList(name="session[%s]"%self.__class__.__name__)

            singlestep = True
            iterator = self._do_iter()
            self.debug(3, "Starting control thread [tid=%i]" % self.threadid)
            try:
                while True:
                    c = self.cmdin.recv()
                    self.debug(5, "Received command %s" % c.type)
                    if c.type == _ATMT_Command.RUN:
                        singlestep = False
                    elif c.type == _ATMT_Command.NEXT:
                        singlestep = True
                    elif c.type == _ATMT_Command.FREEZE:
                        continue
                    elif c.type == _ATMT_Command.STOP:
                        break
                    while True:
                        state = next(iterator)
                        if isinstance(state, self.CommandMessage):
                            break
                        elif isinstance(state, self.Breakpoint):
                            c = Message(type=_ATMT_Command.BREAKPOINT,state=state)
                            self.cmdout.send(c)
                            break
                        if singlestep:
                            c = Message(type=_ATMT_Command.SINGLESTEP,state=state)
                            self.cmdout.send(c)
                            break
            except StopIteration as e:
                # _do_iter() returned: a final state was reached, and the
                # generator's return value is that state's output.
                c = Message(type=_ATMT_Command.END, result=e.value)
                self.cmdout.send(c)
            except Exception as e:
                self.debug(3, "Transfering exception [%s] from tid=%i"% (e,self.threadid))
                m = Message(type = _ATMT_Command.EXCEPTION, exception=e, exc_info=sys.exc_info())
                self.cmdout.send(m)
            self.debug(3, "Stopping control thread (tid=%i)"%self.threadid)
            self.threadid = None

    def _do_iter(self):
        """Generator executing the state machine.

        Yields a NewStateRequested on every state change, a Breakpoint when
        entering a breakpointed state, and a CommandMessage whenever a
        command is waiting on cmdin. It returns the final state's output
        when a final state is reached. It raises ErrorState on an error
        state, and Stuck when a state has no receive condition, I/O event
        or timeout left to wait on.
        """
        while True:
            try:
                self.debug(1, "## state=[%s]" % self.state.state)

                # Entering a new state. First, call new state function
                if self.state.state in self.breakpoints and self.state.state != self.breakpointed:
                    self.breakpointed = self.state.state
                    yield self.Breakpoint("breakpoint triggered on state %s" % self.state.state,
                                          state = self.state.state)
                self.breakpointed = None
                state_output = self.state.run()
                if self.state.error:
                    raise self.ErrorState("Reached %s: [%r]" % (self.state.state, state_output),
                                          result=state_output, state=self.state.state)
                if self.state.final:
                    # Ending the generator makes the consumer's next() raise
                    # StopIteration carrying state_output as its value. An
                    # explicit `raise StopIteration` here would be turned into
                    # a RuntimeError by PEP 479.
                    return state_output

                if state_output is None:
                    state_output = ()
                elif type(state_output) is not list:
                    state_output = state_output,

                # Then check immediate conditions
                for cond in self.conditions[self.state.state]:
                    self._run_condition(cond, *state_output)

                # If still there and no conditions left, we are stuck!
                # (timeout lists always hold the (None, None) sentinel)
                if ( len(self.recv_conditions[self.state.state]) == 0 and
                     len(self.ioevents[self.state.state]) == 0 and
                     len(self.timeout[self.state.state]) == 1 ):
                    raise self.Stuck("stuck in [%s]" % self.state.state,
                                     state=self.state.state, result=state_output)

                # Finally listen and pay attention to timeouts
                expirations = iter(self.timeout[self.state.state])
                next_timeout,timeout_func = next(expirations)
                t0 = time.time()

                fds = [self.cmdin]
                if len(self.recv_conditions[self.state.state]) > 0:
                    fds.append(self.listen_sock)
                for ioev in self.ioevents[self.state.state]:
                    fds.append(self.ioin[ioev.atmt_ioname])
                while 1:
                    t = time.time()-t0
                    if next_timeout is not None:
                        if next_timeout <= t:
                            self._run_condition(timeout_func, *state_output)
                            next_timeout,timeout_func = next(expirations)
                    if next_timeout is None:
                        remain = None
                    else:
                        # Several timeouts may have expired at once, but only
                        # one is handled per pass; never give select() a
                        # negative delay (it raises ValueError).
                        remain = max(0, next_timeout-t)

                    self.debug(5, "Select on %r" % fds)
                    r,_,_ = select(fds,[],[],remain)
                    self.debug(5, "Selected %r" % r)
                    for fd in r:
                        self.debug(5, "Looking at %r" % fd)
                        if fd == self.cmdin:
                            yield self.CommandMessage("Received command message")
                        elif fd == self.listen_sock:
                            pkt = self.listen_sock.recv(MTU)
                            if pkt is not None:
                                if self.master_filter(pkt):
                                    self.debug(3, "RECVD: %s" % pkt.summary())
                                    for rcvcond in self.recv_conditions[self.state.state]:
                                        self._run_condition(rcvcond, pkt, *state_output)
                                else:
                                    self.debug(4, "FILTR: %s" % pkt.summary())
                        else:
                            self.debug(3, "IOEVENT on %s" % fd.ioname)
                            for ioevt in self.ioevents[self.state.state]:
                                if ioevt.atmt_ioname == fd.ioname:
                                    self._run_condition(ioevt, fd, *state_output)

            except ATMT.NewStateRequested as state_req:
                self.debug(2, "switching from [%s] to [%s]" % (self.state.state,state_req.state))
                self.state = state_req
                yield state_req

    ## Public API
    # The following accept state names or state functions (anything with
    # an atmt_state attribute).
    def add_interception_points(self, *ipts):
        for ipt in ipts:
            if hasattr(ipt,"atmt_state"):
                ipt = ipt.atmt_state
            self.interception_points.add(ipt)

    def remove_interception_points(self, *ipts):
        for ipt in ipts:
            if hasattr(ipt,"atmt_state"):
                ipt = ipt.atmt_state
            self.interception_points.discard(ipt)

    def add_breakpoints(self, *bps):
        for bp in bps:
            if hasattr(bp,"atmt_state"):
                bp = bp.atmt_state
            self.breakpoints.add(bp)

    def remove_breakpoints(self, *bps):
        for bp in bps:
            if hasattr(bp,"atmt_state"):
                bp = bp.atmt_state
            self.breakpoints.discard(bp)

    def start(self, *args, **kargs):
        """Start the control thread unless self.started is already held."""
        if not self.started.locked():
            self._do_start(*args, **kargs)

    def run(self, resume=None, wait=True):
        """Send resume (a RUN command by default) to the control thread.

        If wait is true, block for the thread's answer. On END the result
        is returned. On INTERCEPT, SINGLESTEP or BREAKPOINT the matching
        exception is raised, and on EXCEPTION the exception that stopped
        the thread is re-raised. A KeyboardInterrupt while waiting sends
        FREEZE and returns None.
        """
        if resume is None:
            resume = Message(type = _ATMT_Command.RUN)
        self.cmdin.send(resume)
        if wait:
            try:
                c = self.cmdout.recv()
            except KeyboardInterrupt:
                self.cmdin.send(Message(type = _ATMT_Command.FREEZE))
                return
            if c.type == _ATMT_Command.END:
                return c.result
            elif c.type == _ATMT_Command.INTERCEPT:
                raise self.InterceptionPoint("packet intercepted", state=c.state.state, packet=c.pkt)
            elif c.type == _ATMT_Command.SINGLESTEP:
                raise self.Singlestep("singlestep state=[%s]"%c.state.state, state=c.state.state)
            elif c.type == _ATMT_Command.BREAKPOINT:
                raise self.Breakpoint("breakpoint triggered on state [%s]"%c.state.state, state=c.state.state)
            elif c.type == _ATMT_Command.EXCEPTION:
                # Re-raise the control thread's own exception object, with its
                # traceback, so its type and attributes reach the caller.
                raise c.exc_info[1].with_traceback(c.exc_info[2])

    def runbg(self, resume=None, wait=False):
        """Like run(), but does not wait for an answer by default."""
        self.run(resume, wait)

    def next(self):
        """Resume the automaton for a single step (see run())."""
        return self.run(resume = Message(type=_ATMT_Command.NEXT))

    # Python 3 name of the iterator protocol method, so the builtin next()
    # works on automata too (__iter__ returns self).
    __next__ = next

    def stop(self):
        """Ask the control thread to stop, wait for it, then drain both command pipes."""
        self.cmdin.send(Message(type=_ATMT_Command.STOP))
        with self.started:
            # Flush command pipes
            while True:
                r,_,_ = select([self.cmdin, self.cmdout],[],[],0)
                if not r:
                    break
                for fd in r:
                    fd.recv()

    def restart(self, *args, **kargs):
        """Stop the control thread and start a new one with the given overrides."""
        self.stop()
        self.start(*args, **kargs)

    def accept_packet(self, pkt=None, wait=False):
        """Resume after an interception: send the intercepted packet, or pkt instead."""
        rsm = Message()
        if pkt is None:
            rsm.type = _ATMT_Command.ACCEPT
        else:
            rsm.type = _ATMT_Command.REPLACE
            rsm.pkt = pkt
        return self.run(resume=rsm, wait=wait)

    def reject_packet(self, wait=False):
        """Resume after an interception, dropping the intercepted packet."""
        rsm = Message(type = _ATMT_Command.REJECT)
        return self.run(resume=rsm, wait=wait)
751
752
753
754