1## This file is part of Scapy
## See http://www.secdev.org/projects/scapy for more information
3## Copyright (C) Philippe Biondi <phil@secdev.org>
4## This program is published under a GPLv2 license
5
6"""
7Automata with states, transitions and actions.
8"""
9
10from __future__ import with_statement
11import types,itertools,time,os,sys,socket
12from select import select
13from collections import deque
14import thread
15from config import conf
16from utils import do_graph
17from error import log_interactive
18from plist import PacketList
19from data import MTU
20from supersocket import SuperSocket
21
class ObjectPipe:
    """Queue of Python objects that can be waited on with select().

    send() stores the object in a deque and writes a one-byte token to an
    OS pipe; recv() consumes one token and pops the oldest object.  The
    read end of the pipe (fileno()) is therefore readable while objects
    are pending.
    """
    def __init__(self):
        pipe_fds = os.pipe()
        self.rd = pipe_fds[0]
        self.wr = pipe_fds[1]
        self.queue = deque()

    def fileno(self):
        """Return the read end of the token pipe (for select())."""
        return self.rd

    def send(self, obj):
        """Enqueue obj, then signal it with a one-byte token."""
        self.queue.append(obj)
        os.write(self.wr,"X")

    def recv(self, n=0):
        """Block until a token arrives, then dequeue and return one object.

        n exists only for socket-like call compatibility and is ignored.
        """
        os.read(self.rd,1)
        return self.queue.popleft()
34
35
class Message:
    """Simple record whose attributes are the keyword arguments given.

    Used for the commands and replies exchanged with an automaton's
    control thread.
    """
    def __init__(self, **args):
        for key, value in args.items():
            setattr(self, key, value)

    def __repr__(self):
        # Attributes whose name starts with "_" are left out of the repr.
        fields = ["%s=%r" % (key, value)
                  for (key, value) in self.__dict__.items()
                  if not key.startswith("_")]
        return "<Message %s>" % " ".join(fields)
43
44class _instance_state:
45    def __init__(self, instance):
46        self.im_self = instance.im_self
47        self.im_func = instance.im_func
48        self.im_class = instance.im_class
49    def __getattr__(self, attr):
50        return getattr(self.im_func, attr)
51
52    def __call__(self, *args, **kargs):
53        return self.im_func(self.im_self, *args, **kargs)
54    def breaks(self):
55        return self.im_self.add_breakpoints(self.im_func)
56    def intercepts(self):
57        return self.im_self.add_interception_points(self.im_func)
58    def unbreaks(self):
59        return self.im_self.remove_breakpoints(self.im_func)
60    def unintercepts(self):
61        return self.im_self.remove_interception_points(self.im_func)
62
63
64##############
65## Automata ##
66##############
67
class ATMT:
    """Decorators used to declare the structure of an Automaton.

    The decorators only tag functions with atmt_* attributes.
    Automaton_metaclass later gathers the tagged functions into the class
    tables that drive the state machine.
    """
    # Values stored in the atmt_type attribute of decorated functions.
    STATE = "State"
    ACTION = "Action"
    CONDITION = "Condition"
    RECV = "Receive condition"
    TIMEOUT = "Timeout condition"
    IOEVENT = "I/O event"

    class NewStateRequested(Exception):
        """Transition request returned by a state wrapper.

        A condition raises it to move the automaton to the state it names.
        It carries the target state function and its flags, the automaton,
        the arguments for the state function (see run()) and the arguments
        for the actions attached to the condition (see action_parameters()).
        """
        def __init__(self, state_func, automaton, *args, **kargs):
            self.func = state_func
            self.state = state_func.atmt_state
            self.initial = state_func.atmt_initial
            self.error = state_func.atmt_error
            self.final = state_func.atmt_final
            Exception.__init__(self, "Request state [%s]" % self.state)
            self.automaton = automaton
            # NOTE: replaces Exception.args with the state function's arguments.
            self.args = args
            self.kargs = kargs
            self.action_parameters() # init action parameters
        def action_parameters(self, *args, **kargs):
            """Set the arguments given to the condition's actions.

            Returns self, so the call can be chained onto the request.
            """
            self.action_args = args
            self.action_kargs = kargs
            return self
        def run(self):
            """Run the target state function and return its output."""
            return self.func(self.automaton, *self.args, **self.kargs)
        def __repr__(self):
            return "NewStateRequested(%s)" % self.state

    @staticmethod
    def state(initial=0,final=0,error=0):
        """Declare a state.

        The decorated function is replaced by a wrapper that returns a
        NewStateRequested for this state when called on an automaton.  The
        original function (the state body) is kept in atmt_origfunc.
        initial/final/error flag the start state, a terminal state and an
        error state.
        """
        def deco(f, initial=initial, final=final, error=error):
            f.atmt_type = ATMT.STATE
            f.atmt_state = f.__name__
            f.atmt_initial = initial
            f.atmt_final = final
            f.atmt_error = error
            def state_wrapper(self, *args, **kargs):
                return ATMT.NewStateRequested(f, self, *args, **kargs)

            state_wrapper.__name__ = "%s_wrapper" % f.__name__
            state_wrapper.atmt_type = ATMT.STATE
            state_wrapper.atmt_state = f.__name__
            state_wrapper.atmt_initial = initial
            state_wrapper.atmt_final = final
            state_wrapper.atmt_error = error
            state_wrapper.atmt_origfunc = f
            return state_wrapper
        return deco
    @staticmethod
    def action(cond, prio=0):
        """Attach the decorated function as an action of condition cond.

        A function may be attached to several conditions.  atmt_cond maps
        each condition name to this action's priority for that condition;
        actions are sorted by it in ascending order.
        """
        def deco(f,cond=cond):
            # Test for the table itself rather than atmt_type: a function
            # already tagged by another decorator has atmt_type but no
            # atmt_cond yet, which used to raise AttributeError.
            if not hasattr(f, "atmt_cond"):
                f.atmt_cond = {}
            f.atmt_type = ATMT.ACTION
            f.atmt_cond[cond.atmt_condname] = prio
            return f
        return deco
    @staticmethod
    def condition(state, prio=0):
        """Declare an immediate condition, checked right after state runs."""
        def deco(f, state=state):
            f.atmt_type = ATMT.CONDITION
            f.atmt_state = state.atmt_state
            f.atmt_condname = f.__name__
            f.atmt_prio = prio
            return f
        return deco
    @staticmethod
    def receive_condition(state, prio=0):
        """Declare a condition checked for each packet received in state."""
        def deco(f, state=state):
            f.atmt_type = ATMT.RECV
            f.atmt_state = state.atmt_state
            f.atmt_condname = f.__name__
            f.atmt_prio = prio
            return f
        return deco
    @staticmethod
    def ioevent(state, name, prio=0, as_supersocket=None):
        """Declare a condition triggered by activity on I/O channel name.

        If as_supersocket is given, Automaton_metaclass installs a factory
        attribute with that name on the automaton class.
        """
        def deco(f, state=state):
            f.atmt_type = ATMT.IOEVENT
            f.atmt_state = state.atmt_state
            f.atmt_condname = f.__name__
            f.atmt_ioname = name
            f.atmt_prio = prio
            f.atmt_as_supersocket = as_supersocket
            return f
        return deco
    @staticmethod
    def timeout(state, timeout):
        """Declare a condition checked timeout seconds after entering state."""
        def deco(f, state=state, timeout=timeout):
            f.atmt_type = ATMT.TIMEOUT
            f.atmt_state = state.atmt_state
            f.atmt_timeout = timeout
            f.atmt_condname = f.__name__
            return f
        return deco
164
165class _ATMT_Command:
166    RUN = "RUN"
167    NEXT = "NEXT"
168    FREEZE = "FREEZE"
169    STOP = "STOP"
170    END = "END"
171    EXCEPTION = "EXCEPTION"
172    SINGLESTEP = "SINGLESTEP"
173    BREAKPOINT = "BREAKPOINT"
174    INTERCEPT = "INTERCEPT"
175    ACCEPT = "ACCEPT"
176    REPLACE = "REPLACE"
177    REJECT = "REJECT"
178
class _ATMT_supersocket(SuperSocket):
    """SuperSocket whose peer is an automaton running in the background.

    An AF_UNIX datagram socketpair links this object (end spa) to a new
    automaton instance.  The automaton receives the other end (spb) as the
    external fd of its I/O channel ioevent.
    """
    def __init__(self, name, ioevent, automaton, proto, args, kargs):
        self.name = name
        self.ioevent = ioevent
        self.proto = proto
        pair = socket.socketpair(socket.AF_UNIX, socket.SOCK_DGRAM)
        self.spa, self.spb = pair
        kargs["external_fd"] = {ioevent: self.spb}
        self.atmt = automaton(*args, **kargs)
        self.atmt.runbg()

    def fileno(self):
        return self.spa.fileno()

    def send(self, s):
        """Send s to the automaton, converting it with str() if needed."""
        payload = s if type(s) is str else str(s)
        return self.spa.send(payload)

    def recv(self, n=MTU):
        """Read up to n bytes; decode them with proto when one was given."""
        raw = self.spa.recv(n)
        if self.proto is None:
            return raw
        return self.proto(raw)

    def close(self):
        # NOTE(review): no-op -- spa/spb stay open and self.atmt keeps running.
        pass
201
202class _ATMT_to_supersocket:
203    def __init__(self, name, ioevent, automaton):
204        self.name = name
205        self.ioevent = ioevent
206        self.automaton = automaton
207    def __call__(self, proto, *args, **kargs):
208        return _ATMT_supersocket(self.name, self.ioevent, self.automaton, proto, args, kargs)
209
class Automaton_metaclass(type):
    """Metaclass that compiles ATMT-decorated functions into per-class tables.

    Tables built on every class (keyed by state name unless noted):
      states          -- state wrapper functions
      conditions      -- immediate conditions, sorted by priority
      recv_conditions -- packet receive conditions, sorted by priority
      ioevents        -- I/O event conditions, sorted by priority
      timeout         -- (delay, function) pairs sorted by delay, terminated
                         by a (None, None) sentinel
      actions         -- keyed by condition name: attached actions, sorted by
                         the priority given for that condition
      initial_states, ionames, iosupersockets -- flat lists
    """
    def __new__(cls, name, bases, dct):
        cls = super(Automaton_metaclass, cls).__new__(cls, name, bases, dct)
        cls.states={}
        cls.state = None
        cls.recv_conditions={}
        cls.conditions={}
        cls.ioevents={}
        cls.timeout={}
        cls.actions={}
        cls.initial_states=[]
        cls.ionames = []
        cls.iosupersockets = []

        # Collect members breadth-first from the class and its bases.  The
        # first definition found wins, so subclass overrides hide base ones.
        members = {}
        classes = [cls]
        while classes:
            c = classes.pop(0) # order is important to avoid breaking method overloading
            classes += list(c.__bases__)
            for k,v in c.__dict__.iteritems():
                if k not in members:
                    members[k] = v

        decorated = [v for v in members.itervalues()
                     if type(v) is types.FunctionType and hasattr(v, "atmt_type")]

        # First pass: create the containers for every state and condition.
        for m in decorated:
            if m.atmt_type == ATMT.STATE:
                s = m.atmt_state
                cls.states[s] = m
                cls.recv_conditions[s]=[]
                cls.ioevents[s]=[]
                cls.conditions[s]=[]
                cls.timeout[s]=[]
                if m.atmt_initial:
                    cls.initial_states.append(m)
            elif m.atmt_type in [ATMT.CONDITION, ATMT.RECV, ATMT.TIMEOUT, ATMT.IOEVENT]:
                cls.actions[m.atmt_condname] = []

        # Second pass: file each condition under its state, each action
        # under its condition(s).
        for m in decorated:
            if m.atmt_type == ATMT.CONDITION:
                cls.conditions[m.atmt_state].append(m)
            elif m.atmt_type == ATMT.RECV:
                cls.recv_conditions[m.atmt_state].append(m)
            elif m.atmt_type == ATMT.IOEVENT:
                cls.ioevents[m.atmt_state].append(m)
                cls.ionames.append(m.atmt_ioname)
                if m.atmt_as_supersocket is not None:
                    cls.iosupersockets.append(m)
            elif m.atmt_type == ATMT.TIMEOUT:
                cls.timeout[m.atmt_state].append((m.atmt_timeout, m))
            elif m.atmt_type == ATMT.ACTION:
                for c in m.atmt_cond:
                    cls.actions[c].append(m)

        # Timeouts in increasing delay order; (None, None) means no more.
        # key= sorts are stable, like the cmp-based sorts they replace.
        for v in cls.timeout.itervalues():
            v.sort(key=lambda entry: entry[0])
            v.append((None, None))
        for v in itertools.chain(cls.conditions.itervalues(),
                                 cls.recv_conditions.itervalues(),
                                 cls.ioevents.itervalues()):
            v.sort(key=lambda cond: cond.atmt_prio)
        for condname,actlst in cls.actions.iteritems():
            actlst.sort(key=lambda act, condname=condname: act.atmt_cond[condname])

        # ioevents declared with as_supersocket become factory attributes.
        for ioev in cls.iosupersockets:
            setattr(cls, ioev.atmt_as_supersocket, _ATMT_to_supersocket(ioev.atmt_as_supersocket, ioev.atmt_ioname, cls))

        return cls

    def graph(self, **kargs):
        """Render the transition graph of this automaton class with do_graph().

        Edges are found by scanning each function's code names/constants for
        state names:
          green  -- state to state
          purple -- condition
          red    -- receive condition
          orange -- I/O event
          blue   -- timeout
        Labels list the condition and its actions.  kargs are forwarded to
        do_graph().
        """
        # graph is a metaclass method: self is the automaton class itself,
        # so its own name (not the metaclass's) titles the graph.
        s = 'digraph "%s" {\n'  % self.__name__

        se = "" # Keep initial nodes at the beginning for better rendering
        for st in self.states.itervalues():
            if st.atmt_initial:
                se = ('\t"%s" [ style=filled, fillcolor=blue, shape=box, root=true];\n' % st.atmt_state)+se
            elif st.atmt_final:
                se += '\t"%s" [ style=filled, fillcolor=green, shape=octagon ];\n' % st.atmt_state
            elif st.atmt_error:
                se += '\t"%s" [ style=filled, fillcolor=red, shape=octagon ];\n' % st.atmt_state
        s += se

        for st in self.states.values():
            for n in st.atmt_origfunc.func_code.co_names+st.atmt_origfunc.func_code.co_consts:
                if n in self.states:
                    s += '\t"%s" -> "%s" [ color=green ];\n' % (st.atmt_state,n)


        for c,k,v in ([("purple",k,v) for k,v in self.conditions.items()]+
                      [("red",k,v) for k,v in self.recv_conditions.items()]+
                      [("orange",k,v) for k,v in self.ioevents.items()]):
            for f in v:
                for n in f.func_code.co_names+f.func_code.co_consts:
                    if n in self.states:
                        l = f.atmt_condname
                        for x in self.actions[f.atmt_condname]:
                            l += "\\l>[%s]" % x.func_name
                        s += '\t"%s" -> "%s" [label="%s", color=%s];\n' % (k,n,l,c)
        for k,v in self.timeout.iteritems():
            for t,f in v:
                if f is None:
                    continue
                for n in f.func_code.co_names+f.func_code.co_consts:
                    if n in self.states:
                        l = "%s/%.1fs" % (f.atmt_condname,t)
                        for x in self.actions[f.atmt_condname]:
                            l += "\\l>[%s]" % x.func_name
                        s += '\t"%s" -> "%s" [label="%s",color=blue];\n' % (k,n,l)
        s += "}\n"
        return do_graph(s, **kargs)
322
323
324
325class Automaton:
326    __metaclass__ = Automaton_metaclass
327
328    ## Methods to overload
329    def parse_args(self, debug=0, store=1, **kargs):
330        self.debug_level=debug
331        self.socket_kargs = kargs
332        self.store_packets = store
333
334    def master_filter(self, pkt):
335        return True
336
337    def my_send(self, pkt):
338        self.send_sock.send(pkt)
339
340
341    ## Utility classes and exceptions
342    class _IO_fdwrapper:
343        def __init__(self,rd,wr):
344            if rd is not None and type(rd) is not int:
345                rd = rd.fileno()
346            if wr is not None and type(wr) is not int:
347                wr = wr.fileno()
348            self.rd = rd
349            self.wr = wr
350        def fileno(self):
351            return self.rd
352        def read(self, n=65535):
353            return os.read(self.rd, n)
354        def write(self, msg):
355            return os.write(self.wr,msg)
356        def recv(self, n=65535):
357            return self.read(n)
358        def send(self, msg):
359            return self.write(msg)
360
361    class _IO_mixer:
362        def __init__(self,rd,wr):
363            self.rd = rd
364            self.wr = wr
365        def fileno(self):
366            if type(self.rd) is int:
367                return self.rd
368            return self.rd.fileno()
369        def recv(self, n=None):
370            return self.rd.recv(n)
371        def read(self, n=None):
372            return self.rd.recv(n)
373        def send(self, msg):
374            return self.wr.send(msg)
375        def write(self, msg):
376            return self.wr.send(msg)
377
378
379    class AutomatonException(Exception):
380        def __init__(self, msg, state=None, result=None):
381            Exception.__init__(self, msg)
382            self.state = state
383            self.result = result
384
385    class AutomatonError(AutomatonException):
386        pass
387    class ErrorState(AutomatonException):
388        pass
389    class Stuck(AutomatonException):
390        pass
391    class AutomatonStopped(AutomatonException):
392        pass
393
394    class Breakpoint(AutomatonStopped):
395        pass
396    class Singlestep(AutomatonStopped):
397        pass
398    class InterceptionPoint(AutomatonStopped):
399        def __init__(self, msg, state=None, result=None, packet=None):
400            Automaton.AutomatonStopped.__init__(self, msg, state=state, result=result)
401            self.packet = packet
402
403    class CommandMessage(AutomatonException):
404        pass
405
406
407    ## Services
408    def debug(self, lvl, msg):
409        if self.debug_level >= lvl:
410            log_interactive.debug(msg)
411
412    def send(self, pkt):
413        if self.state.state in self.interception_points:
414            self.debug(3,"INTERCEPT: packet intercepted: %s" % pkt.summary())
415            self.intercepted_packet = pkt
416            cmd = Message(type = _ATMT_Command.INTERCEPT, state=self.state, pkt=pkt)
417            self.cmdout.send(cmd)
418            cmd = self.cmdin.recv()
419            self.intercepted_packet = None
420            if cmd.type == _ATMT_Command.REJECT:
421                self.debug(3,"INTERCEPT: packet rejected")
422                return
423            elif cmd.type == _ATMT_Command.REPLACE:
424                pkt = cmd.pkt
425                self.debug(3,"INTERCEPT: packet replaced by: %s" % pkt.summary())
426            elif cmd.type == _ATMT_Command.ACCEPT:
427                self.debug(3,"INTERCEPT: packet accepted")
428            else:
429                raise self.AutomatonError("INTERCEPT: unkown verdict: %r" % cmd.type)
430        self.my_send(pkt)
431        self.debug(3,"SENT : %s" % pkt.summary())
432        self.packets.append(pkt.copy())
433
434
435    ## Internals
436    def __init__(self, *args, **kargs):
437        external_fd = kargs.pop("external_fd",{})
438        self.send_sock_class = kargs.pop("ll", conf.L3socket)
439        self.started = thread.allocate_lock()
440        self.threadid = None
441        self.breakpointed = None
442        self.breakpoints = set()
443        self.interception_points = set()
444        self.intercepted_packet = None
445        self.debug_level=0
446        self.init_args=args
447        self.init_kargs=kargs
448        self.io = type.__new__(type, "IOnamespace",(),{})
449        self.oi = type.__new__(type, "IOnamespace",(),{})
450        self.cmdin = ObjectPipe()
451        self.cmdout = ObjectPipe()
452        self.ioin = {}
453        self.ioout = {}
454        for n in self.ionames:
455            extfd = external_fd.get(n)
456            if type(extfd) is not tuple:
457                extfd = (extfd,extfd)
458            ioin,ioout = extfd
459            if ioin is None:
460                ioin = ObjectPipe()
461            elif type(ioin) is not types.InstanceType:
462                ioin = self._IO_fdwrapper(ioin,None)
463            if ioout is None:
464                ioout = ObjectPipe()
465            elif type(ioout) is not types.InstanceType:
466                ioout = self._IO_fdwrapper(None,ioout)
467
468            self.ioin[n] = ioin
469            self.ioout[n] = ioout
470            ioin.ioname = n
471            ioout.ioname = n
472            setattr(self.io, n, self._IO_mixer(ioout,ioin))
473            setattr(self.oi, n, self._IO_mixer(ioin,ioout))
474
475        for stname in self.states:
476            setattr(self, stname,
477                    _instance_state(getattr(self, stname)))
478
479        self.parse_args(*args, **kargs)
480
481        self.start()
482
483    def __iter__(self):
484        return self
485
486    def __del__(self):
487        self.stop()
488
489    def _run_condition(self, cond, *args, **kargs):
490        try:
491            self.debug(5, "Trying %s [%s]" % (cond.atmt_type, cond.atmt_condname))
492            cond(self,*args, **kargs)
493        except ATMT.NewStateRequested, state_req:
494            self.debug(2, "%s [%s] taken to state [%s]" % (cond.atmt_type, cond.atmt_condname, state_req.state))
495            if cond.atmt_type == ATMT.RECV:
496                self.packets.append(args[0])
497            for action in self.actions[cond.atmt_condname]:
498                self.debug(2, "   + Running action [%s]" % action.func_name)
499                action(self, *state_req.action_args, **state_req.action_kargs)
500            raise
501        except Exception,e:
502            self.debug(2, "%s [%s] raised exception [%s]" % (cond.atmt_type, cond.atmt_condname, e))
503            raise
504        else:
505            self.debug(2, "%s [%s] not taken" % (cond.atmt_type, cond.atmt_condname))
506
507    def _do_start(self, *args, **kargs):
508
509        thread.start_new_thread(self._do_control, args, kargs)
510
511
512    def _do_control(self, *args, **kargs):
513        with self.started:
514            self.threadid = thread.get_ident()
515
516            # Update default parameters
517            a = args+self.init_args[len(args):]
518            k = self.init_kargs.copy()
519            k.update(kargs)
520            self.parse_args(*a,**k)
521
522            # Start the automaton
523            self.state=self.initial_states[0](self)
524            self.send_sock = self.send_sock_class()
525            self.listen_sock = conf.L2listen(**self.socket_kargs)
526            self.packets = PacketList(name="session[%s]"%self.__class__.__name__)
527
528            singlestep = True
529            iterator = self._do_iter()
530            self.debug(3, "Starting control thread [tid=%i]" % self.threadid)
531            try:
532                while True:
533                    c = self.cmdin.recv()
534                    self.debug(5, "Received command %s" % c.type)
535                    if c.type == _ATMT_Command.RUN:
536                        singlestep = False
537                    elif c.type == _ATMT_Command.NEXT:
538                        singlestep = True
539                    elif c.type == _ATMT_Command.FREEZE:
540                        continue
541                    elif c.type == _ATMT_Command.STOP:
542                        break
543                    while True:
544                        state = iterator.next()
545                        if isinstance(state, self.CommandMessage):
546                            break
547                        elif isinstance(state, self.Breakpoint):
548                            c = Message(type=_ATMT_Command.BREAKPOINT,state=state)
549                            self.cmdout.send(c)
550                            break
551                        if singlestep:
552                            c = Message(type=_ATMT_Command.SINGLESTEP,state=state)
553                            self.cmdout.send(c)
554                            break
555            except StopIteration,e:
556                c = Message(type=_ATMT_Command.END, result=e.args[0])
557                self.cmdout.send(c)
558            except Exception,e:
559                self.debug(3, "Transfering exception [%s] from tid=%i"% (e,self.threadid))
560                m = Message(type = _ATMT_Command.EXCEPTION, exception=e, exc_info=sys.exc_info())
561                self.cmdout.send(m)
562            self.debug(3, "Stopping control thread (tid=%i)"%self.threadid)
563            self.threadid = None
564
565    def _do_iter(self):
566        while True:
567            try:
568                self.debug(1, "## state=[%s]" % self.state.state)
569
570                # Entering a new state. First, call new state function
571                if self.state.state in self.breakpoints and self.state.state != self.breakpointed:
572                    self.breakpointed = self.state.state
573                    yield self.Breakpoint("breakpoint triggered on state %s" % self.state.state,
574                                          state = self.state.state)
575                self.breakpointed = None
576                state_output = self.state.run()
577                if self.state.error:
578                    raise self.ErrorState("Reached %s: [%r]" % (self.state.state, state_output),
579                                          result=state_output, state=self.state.state)
580                if self.state.final:
581                    raise StopIteration(state_output)
582
583                if state_output is None:
584                    state_output = ()
585                elif type(state_output) is not list:
586                    state_output = state_output,
587
588                # Then check immediate conditions
589                for cond in self.conditions[self.state.state]:
590                    self._run_condition(cond, *state_output)
591
592                # If still there and no conditions left, we are stuck!
593                if ( len(self.recv_conditions[self.state.state]) == 0 and
594                     len(self.ioevents[self.state.state]) == 0 and
595                     len(self.timeout[self.state.state]) == 1 ):
596                    raise self.Stuck("stuck in [%s]" % self.state.state,
597                                     state=self.state.state, result=state_output)
598
599                # Finally listen and pay attention to timeouts
600                expirations = iter(self.timeout[self.state.state])
601                next_timeout,timeout_func = expirations.next()
602                t0 = time.time()
603
604                fds = [self.cmdin]
605                if len(self.recv_conditions[self.state.state]) > 0:
606                    fds.append(self.listen_sock)
607                for ioev in self.ioevents[self.state.state]:
608                    fds.append(self.ioin[ioev.atmt_ioname])
609                while 1:
610                    t = time.time()-t0
611                    if next_timeout is not None:
612                        if next_timeout <= t:
613                            self._run_condition(timeout_func, *state_output)
614                            next_timeout,timeout_func = expirations.next()
615                    if next_timeout is None:
616                        remain = None
617                    else:
618                        remain = next_timeout-t
619
620                    self.debug(5, "Select on %r" % fds)
621                    r,_,_ = select(fds,[],[],remain)
622                    self.debug(5, "Selected %r" % r)
623                    for fd in r:
624                        self.debug(5, "Looking at %r" % fd)
625                        if fd == self.cmdin:
626                            yield self.CommandMessage("Received command message")
627                        elif fd == self.listen_sock:
628                            pkt = self.listen_sock.recv(MTU)
629                            if pkt is not None:
630                                if self.master_filter(pkt):
631                                    self.debug(3, "RECVD: %s" % pkt.summary())
632                                    for rcvcond in self.recv_conditions[self.state.state]:
633                                        self._run_condition(rcvcond, pkt, *state_output)
634                                else:
635                                    self.debug(4, "FILTR: %s" % pkt.summary())
636                        else:
637                            self.debug(3, "IOEVENT on %s" % fd.ioname)
638                            for ioevt in self.ioevents[self.state.state]:
639                                if ioevt.atmt_ioname == fd.ioname:
640                                    self._run_condition(ioevt, fd, *state_output)
641
642            except ATMT.NewStateRequested,state_req:
643                self.debug(2, "switching from [%s] to [%s]" % (self.state.state,state_req.state))
644                self.state = state_req
645                yield state_req
646
647    ## Public API
648    def add_interception_points(self, *ipts):
649        for ipt in ipts:
650            if hasattr(ipt,"atmt_state"):
651                ipt = ipt.atmt_state
652            self.interception_points.add(ipt)
653
654    def remove_interception_points(self, *ipts):
655        for ipt in ipts:
656            if hasattr(ipt,"atmt_state"):
657                ipt = ipt.atmt_state
658            self.interception_points.discard(ipt)
659
660    def add_breakpoints(self, *bps):
661        for bp in bps:
662            if hasattr(bp,"atmt_state"):
663                bp = bp.atmt_state
664            self.breakpoints.add(bp)
665
666    def remove_breakpoints(self, *bps):
667        for bp in bps:
668            if hasattr(bp,"atmt_state"):
669                bp = bp.atmt_state
670            self.breakpoints.discard(bp)
671
672    def start(self, *args, **kargs):
673        if not self.started.locked():
674            self._do_start(*args, **kargs)
675
676    def run(self, resume=None, wait=True):
677        if resume is None:
678            resume = Message(type = _ATMT_Command.RUN)
679        self.cmdin.send(resume)
680        if wait:
681            try:
682                c = self.cmdout.recv()
683            except KeyboardInterrupt:
684                self.cmdin.send(Message(type = _ATMT_Command.FREEZE))
685                return
686            if c.type == _ATMT_Command.END:
687                return c.result
688            elif c.type == _ATMT_Command.INTERCEPT:
689                raise self.InterceptionPoint("packet intercepted", state=c.state.state, packet=c.pkt)
690            elif c.type == _ATMT_Command.SINGLESTEP:
691                raise self.Singlestep("singlestep state=[%s]"%c.state.state, state=c.state.state)
692            elif c.type == _ATMT_Command.BREAKPOINT:
693                raise self.Breakpoint("breakpoint triggered on state [%s]"%c.state.state, state=c.state.state)
694            elif c.type == _ATMT_Command.EXCEPTION:
695                raise c.exc_info[0],c.exc_info[1],c.exc_info[2]
696
697    def runbg(self, resume=None, wait=False):
698        self.run(resume, wait)
699
700    def next(self):
701        return self.run(resume = Message(type=_ATMT_Command.NEXT))
702
703    def stop(self):
704        self.cmdin.send(Message(type=_ATMT_Command.STOP))
705        with self.started:
706            # Flush command pipes
707            while True:
708                r,_,_ = select([self.cmdin, self.cmdout],[],[],0)
709                if not r:
710                    break
711                for fd in r:
712                    fd.recv()
713
714    def restart(self, *args, **kargs):
715        self.stop()
716        self.start(*args, **kargs)
717
718    def accept_packet(self, pkt=None, wait=False):
719        rsm = Message()
720        if pkt is None:
721            rsm.type = _ATMT_Command.ACCEPT
722        else:
723            rsm.type = _ATMT_Command.REPLACE
724            rsm.pkt = pkt
725        return self.run(resume=rsm, wait=wait)
726
727    def reject_packet(self, wait=False):
728        rsm = Message(type = _ATMT_Command.REJECT)
729        return self.run(resume=rsm, wait=wait)
730
731
732
733