trex_stl_streams.py revision f0ab9eba
1#!/router/bin/python
2
3from trex_stl_exceptions import *
4from trex_stl_types import verify_exclusive_arg, validate_type
5from trex_stl_packet_builder_interface import CTrexPktBuilderInterface
6from trex_stl_packet_builder_scapy import CScapyTRexPktBuilder, Ether, IP, UDP, TCP, RawPcapReader
7from collections import OrderedDict, namedtuple
8
9from dpkt import pcap
10import random
11import yaml
12import base64
13import string
14import traceback
15from types import NoneType
16import copy
17
18
19# base class for TX mode
class STLTXMode(object):
    """Base class for TX modes: holds the transmit rate of a stream.

    At most one of ``pps``, ``bps_L1``, ``bps_L2`` or ``percentage`` is
    expected; when all are omitted the rate defaults to 1.0 pps. The chosen
    rate is stored as ``self.fields['rate'] = {'type': <arg name>, 'value': <number>}``.
    Subclasses add their mode-specific keys to ``self.fields``.

    Raises STLArgumentError if ``percentage`` is outside (0, 100].
    """

    def __init__ (self, pps = None, bps_L1 = None, bps_L2 = None, percentage = None):
        rate_args = [pps, bps_L1, bps_L2, percentage]

        if any(arg is not None for arg in rate_args):
            # presumably rejects more than one non-None rate - see trex_stl_types
            verify_exclusive_arg(rate_args)
        else:
            pps = 1.0

        self.fields = {'rate': {}}

        # the first non-None rate wins, checked in this priority order
        for rate_type, rate_value in (('pps', pps),
                                      ('bps_L1', bps_L1),
                                      ('bps_L2', bps_L2),
                                      ('percentage', percentage)):
            if rate_value is None:
                continue

            validate_type(rate_type, rate_value, [float, int])

            if rate_type == 'percentage' and not (0 < rate_value <= 100):
                raise STLArgumentError('percentage', rate_value)

            self.fields['rate']['type'] = rate_type
            self.fields['rate']['value'] = rate_value
            break

    def to_json (self):
        """Return the mode's fields dict (the live object, not a copy)."""
        return self.fields
62
63
64# continuous mode
class STLTXCont(STLTXMode):
    """Continuous TX mode (``fields['type'] == 'continuous'``).

    Takes the same rate keyword arguments as STLTXMode.
    """

    def __init__ (self, **kwargs):
        super(STLTXCont, self).__init__(**kwargs)
        self.fields.update(type = 'continuous')

    def __str__ (self):
        return "Continuous"
75
76# single burst mode
class STLTXSingleBurst(STLTXMode):
    """Single-burst TX mode: one burst of ``total_pkts`` packets.

    Extra keyword arguments are the rate arguments of STLTXMode.
    Raises STLArgumentError if ``total_pkts`` is not an int.
    """

    def __init__ (self, total_pkts = 1, **kwargs):
        # validate before initializing the base so a bad count fails first
        if isinstance(total_pkts, int):
            super(STLTXSingleBurst, self).__init__(**kwargs)
        else:
            raise STLArgumentError('total_pkts', total_pkts)

        self.fields.update({'type': 'single_burst', 'total_pkts': total_pkts})

    def __str__ (self):
        return "Single Burst"
91
92# multi burst mode
class STLTXMultiBurst(STLTXMode):
    """Multi-burst TX mode: ``count`` bursts of ``pkts_per_burst`` packets.

    ``ibg`` is the inter-burst gap in microseconds (not seconds).
    Extra keyword arguments are the rate arguments of STLTXMode.
    Raises STLArgumentError on a wrongly-typed burst parameter.
    """

    def __init__ (self,
                  pkts_per_burst = 1,
                  ibg = 0.0,   # usec not SEC
                  count = 1,
                  **kwargs):

        # checked in this order so the first bad argument is the one reported
        expected_types = (('pkts_per_burst', pkts_per_burst, int),
                          ('ibg',            ibg,            (int, float)),
                          ('count',          count,          int))

        for arg_name, arg_value, arg_type in expected_types:
            if not isinstance(arg_value, arg_type):
                raise STLArgumentError(arg_name, arg_value)

        super(STLTXMultiBurst, self).__init__(**kwargs)

        self.fields['type'] = 'multi_burst'
        self.fields.update(pkts_per_burst = pkts_per_burst, ibg = ibg, count = count)

    def __str__ (self):
        return "Multi Burst"
119
# Values for STLStream's mac_dst_override_mode; STLStream packs the chosen
# mode into bits 1-2 of the stream's 'flags' field.
STLStreamDstMAC_CFG_FILE, STLStreamDstMAC_PKT, STLStreamDstMAC_ARP = 0, 1, 2
123
124# RX stats class
class STLRxStats(object):
    """RX statistics settings for a stream, serialized via ``to_json``.

    ``user_id`` is stored as ``stream_id``; stats are enabled, with the
    sequence and latency options off.
    """

    def __init__ (self, user_id):
        self.fields = dict(stream_id       = user_id,
                           enabled         = True,
                           seq_enabled     = False,
                           latency_enabled = False)

    def to_json (self):
        """Return a shallow copy of the settings dict."""
        return self.fields.copy()

    @staticmethod
    def defaults ():
        """Return a fresh dict describing 'RX stats disabled'."""
        return {'enabled' : False}
139
class STLStream(object):
    """A single stateless stream: packet, TX mode and control fields.

    The JSON-ready representation lives in ``self.fields``. ``name`` and
    ``next`` are free-form tags (load_pcap uses them to chain streams).
    """

    def __init__ (self,
                  name = None,
                  packet = None,
                  mode = STLTXCont(pps = 1),
                  enabled = True,
                  self_start = True,
                  isg = 0.0,
                  rx_stats = None,
                  next = None,
                  stream_id = None,
                  action_count = 0,
                  mac_src_override_by_pkt=None,
                  mac_dst_override_mode=None    #see  STLStreamDstMAC_xx
                  ):
        """Build the stream and compile its packet.

        :param name: tag of this stream (anything).
        :param packet: a CTrexPktBuilderInterface, or None for an Ether()/IP() packet.
        :param mode: an STLTXMode instance (continuous/single/multi burst).
        :param enabled: bool, stored in fields['enabled'].
        :param self_start: bool, stored in fields['self_start'].
        :param isg: int/float inter-stream gap; load_pcap in this file passes microseconds.
        :param rx_stats: an STLRxStats, an already-serialized dict, or None (RX stats disabled).
        :param next: tag of the following stream; not allowed for continuous mode.
        :param stream_id: optional int id, returned by get_id().
        :param action_count: stored in fields['action_count'].
        :param mac_src_override_by_pkt: None = auto-detect from the packet, otherwise int-convertible flag.
        :param mac_dst_override_mode: None = auto-detect, otherwise one of STLStreamDstMAC_*.
        :raises STLError: if a continuous stream is given a ``next``.
        """
        # type checking
        validate_type('mode', mode, STLTXMode)
        validate_type('packet', packet, (NoneType, CTrexPktBuilderInterface))
        validate_type('enabled', enabled, bool)
        validate_type('self_start', self_start, bool)
        validate_type('isg', isg, (int, float))
        validate_type('stream_id', stream_id, (NoneType, int))

        # isinstance so subclasses of STLTXCont are covered too
        if isinstance(mode, STLTXCont) and next is not None:
            raise STLError("continuous stream cannot have a next stream ID")

        # tag for the stream and next - can be anything
        self.name = name
        self.next = next

        self.id = stream_id

        self.fields = {}

        if mac_src_override_by_pkt is None:
            # auto: override only when the packet carries a non-default src MAC
            int_mac_src_override_by_pkt = 0
            if packet and packet.is_def_src_mac() == False:
                int_mac_src_override_by_pkt = 1
        else:
            int_mac_src_override_by_pkt = int(mac_src_override_by_pkt)

        if mac_dst_override_mode is None:
            # auto: take the dst MAC from the packet when it is non-default
            int_mac_dst_override_mode = 0
            if packet and packet.is_def_dst_mac() == False:
                int_mac_dst_override_mode = STLStreamDstMAC_PKT
        else:
            int_mac_dst_override_mode = int(mac_dst_override_mode)

        # bit 0: src-MAC override, bits 1-2: dst-MAC override mode
        self.fields['flags'] = (int_mac_src_override_by_pkt & 1) + ((int_mac_dst_override_mode & 3) << 1)

        self.fields['action_count'] = action_count

        # basic fields
        self.fields['enabled'] = enabled
        self.fields['self_start'] = self_start
        self.fields['isg'] = isg

        # mode - to_json() returns the mode object's own dict, and the default
        # 'mode' argument is one instance shared by every call, so copy it to
        # keep streams from aliasing each other's mode fields
        self.fields['mode'] = copy.deepcopy(mode.to_json())
        self.mode_desc      = str(mode)

        if not packet:
            packet = CScapyTRexPktBuilder(pkt = Ether()/IP())

        # packet builder
        packet.compile()

        # packet and VM
        self.fields['packet'] = packet.dump_pkt()
        self.fields['vm']     = packet.get_vm_data()

        self.pkt = base64.b64decode(self.fields['packet']['binary'])

        # this is heavy, calculate lazy
        self.packet_desc = None

        if not rx_stats:
            self.fields['rx_stats'] = STLRxStats.defaults()
        elif isinstance(rx_stats, dict):
            # already in JSON form (e.g. read back from a YAML profile)
            self.fields['rx_stats'] = dict(rx_stats)
        else:
            self.fields['rx_stats'] = rx_stats.to_json()


    def __str__ (self):
        # local import: json is not imported at the top of this module
        import json

        s =  "Stream Name: {0}\n".format(self.name)
        s += "Stream Next: {0}\n".format(self.next)
        s += "Stream JSON:\n{0}\n".format(json.dumps(self.fields, indent = 4, separators=(',', ': '), sort_keys = True))
        return s

    def to_json (self):
        """Return a shallow copy of the stream's JSON fields."""
        return dict(self.fields)

    def get_id (self):
        return self.id

    def get_name (self):
        return self.name

    def get_next (self):
        return self.next

    def get_pkt (self):
        """Return the raw packet bytes (decoded from fields['packet']['binary'])."""
        return self.pkt

    def get_pkt_len (self, count_crc = True):
        """Return the packet length, plus 4 bytes for the CRC when count_crc is set."""
        pkt_len = len(self.get_pkt())
        if count_crc:
            pkt_len += 4

        return pkt_len

    def get_pkt_type (self):
        """Return (and cache) a textual description of the packet's layers."""
        if self.packet_desc is None:
            self.packet_desc = CScapyTRexPktBuilder.pkt_layers_desc_from_buffer(self.get_pkt())

        return self.packet_desc

    def get_mode (self):
        return self.mode_desc

    @staticmethod
    def get_rate_from_field (rate_json):
        """Format a {'type', 'value'} rate dict for display; None for an unknown type."""
        suffixes = {'pps'        : "pps",
                    'bps_L1'     : "bps (L1)",
                    'bps_L2'     : "bps (L2)",
                    'percentage' : "%"}

        rate_type = rate_json['type']
        rate_value = rate_json['value']

        if rate_type in suffixes:
            return format_num(rate_value, suffix = suffixes[rate_type])

    def get_rate (self):
        return self.get_rate_from_field(self.fields['mode']['rate'])

    def to_yaml (self):
        """Return a YAML-ready dict: optional name/next plus a deep copy of the fields.

        The mode's nested 'rate' dict is flattened to ``<rate_type>: <value>``.
        """
        y = {}

        # 'is not None' so falsy-but-valid tags (e.g. 0) are kept
        if self.name is not None:
            y['name'] = self.name

        if self.next is not None:
            y['next'] = self.next

        y['stream'] = copy.deepcopy(self.fields)

        # some shortcuts for YAML
        rate_type  = self.fields['mode']['rate']['type']
        rate_value = self.fields['mode']['rate']['value']

        y['stream']['mode'][rate_type] = rate_value
        del y['stream']['mode']['rate']

        return y

    def dump_to_yaml (self, yaml_file = None):
        """Dump this stream as a one-item YAML list; also write it to yaml_file if given."""
        yaml_dump = yaml.dump([self.to_yaml()], default_flow_style = False)

        # write to file if provided
        if yaml_file:
            with open(yaml_file, 'w') as f:
                f.write(yaml_dump)

        return yaml_dump
325
class YAMLLoader(object):
    """Build STLStream objects from a YAML profile file.

    Relative 'pcap' paths inside the YAML are resolved against the directory
    that contains the YAML file.
    """

    def __init__ (self, yaml_file):
        self.yaml_path = os.path.dirname(yaml_file)
        self.yaml_file = yaml_file


    def __parse_packet (self, packet_dict):
        """Return a packet builder from a 'packet' section holding exactly one of 'binary' (base64) or 'pcap'."""
        packet_type = set(packet_dict).intersection(['binary', 'pcap'])
        if len(packet_type) != 1:
            raise STLError("packet section must contain either 'binary' or 'pcap'")

        if 'binary' in packet_type:
            try:
                pkt_str = base64.b64decode(packet_dict['binary'])
            except (TypeError, ValueError):
                # Python 2 raises TypeError on bad base64; Python 3 raises binascii.Error (a ValueError)
                raise STLError("'binary' field is not a valid packet format")

            return CScapyTRexPktBuilder(pkt_buffer = pkt_str)

        pcap_file = os.path.join(self.yaml_path, packet_dict['pcap'])

        if not os.path.exists(pcap_file):
            raise STLError("'pcap' - cannot find '{0}'".format(pcap_file))

        return CScapyTRexPktBuilder(pkt = pcap_file)


    def __parse_mode (self, mode_obj):
        """Build an STLTXMode from a 'mode' section.

        The section needs exactly one rate key (pps, bps_L1, bps_L2, percentage)
        and a 'type' of continuous, single_burst or multi_burst. Missing burst
        parameters fall back to the mode class defaults.
        """
        rate_keys = set(mode_obj).intersection(['pps', 'bps_L1', 'bps_L2', 'percentage'])
        if len(rate_keys) != 1:
            raise STLError("'rate' must contain exactly one from 'pps', 'bps_L1', 'bps_L2', 'percentage'")

        rate_type  = rate_keys.pop()
        rate = {rate_type : mode_obj[rate_type]}

        mode_type = mode_obj.get('type')

        if mode_type == 'continuous':
            return STLTXCont(**rate)

        if mode_type == 'single_burst':
            defaults = STLTXSingleBurst()
            return STLTXSingleBurst(total_pkts  = mode_obj.get('total_pkts', defaults.fields['total_pkts']),
                                    **rate)

        if mode_type == 'multi_burst':
            defaults = STLTXMultiBurst()
            return STLTXMultiBurst(pkts_per_burst = mode_obj.get('pkts_per_burst', defaults.fields['pkts_per_burst']),
                                   ibg            = mode_obj.get('ibg', defaults.fields['ibg']),
                                   count          = mode_obj.get('count', defaults.fields['count']),
                                   **rate)

        raise STLError("mode type can be 'continuous', 'single_burst' or 'multi_burst'")


    def __parse_stream (self, yaml_object):
        """Build one STLStream from a YAML list item ({'name', 'next', 'stream': {...}})."""
        if not isinstance(yaml_object, dict) or yaml_object.get('stream') is None:
            raise STLError("YAML file must contain 'stream' field")

        s_obj = yaml_object['stream']

        # parse packet
        packet = s_obj.get('packet')
        if not packet:
            raise STLError("YAML file must contain 'packet' field")

        builder = self.__parse_packet(packet)

        # mode
        mode_obj = s_obj.get('mode')
        if not mode_obj:
            raise STLError("YAML file must contain 'mode' field")

        mode = self.__parse_mode(mode_obj)

        defaults = STLStream()

        # create the stream; rx_stats is applied below because YAML holds it
        # as a plain dict, not as an STLRxStats object
        stream = STLStream(name       = yaml_object.get('name'),
                           packet     = builder,
                           mode       = mode,
                           enabled    = s_obj.get('enabled', defaults.fields['enabled']),
                           self_start = s_obj.get('self_start', defaults.fields['self_start']),
                           isg        = s_obj.get('isg', defaults.fields['isg']),
                           rx_stats   = None,
                           next       = yaml_object.get('next'),
                           action_count = s_obj.get('action_count', defaults.fields['action_count']),
                           mac_src_override_by_pkt = s_obj.get('mac_src_override_by_pkt', 0),
                           mac_dst_override_mode = s_obj.get('mac_dst_override_mode', 0)
                           )

        rx_stats = s_obj.get('rx_stats')
        if rx_stats:
            stream.fields['rx_stats'] = dict(rx_stats)

        # hack the VM fields for now
        if 'vm' in s_obj:
            stream.fields['vm'].update(s_obj['vm'])

        return stream


    def parse (self):
        """Read the YAML file and return a list of STLStream objects ([] for an empty file)."""
        with open(self.yaml_file, 'r') as f:
            yaml_str = f.read()

        try:
            # NOTE(review): yaml.Loader can construct arbitrary Python objects -
            # only load profiles from trusted sources. It is kept instead of
            # safe_load because yaml.dump output may carry python/* tags.
            # Passing Loader explicitly is required by PyYAML >= 6.
            objects = yaml.load(yaml_str, Loader = yaml.Loader)
        except yaml.YAMLError as e:
            # covers scanner errors too, not only parser errors
            raise STLError(str(e))

        if objects is None:
            return []

        if not isinstance(objects, list):
            raise STLError("YAML profile must be a list of streams")

        return [self.__parse_stream(yaml_object) for yaml_object in objects]
448
449
450# profile class
class STLProfile(object):
    """A collection of STLStream objects, loadable from .py, .yaml or pcap files."""

    def __init__ (self, streams = None):
        """Wrap a stream, a list/tuple of streams, or None (empty profile).

        :raises STLArgumentError: if any item is not an STLStream.
        """
        if streams is None:
            streams = []
        elif isinstance(streams, tuple):
            streams = list(streams)
        elif not isinstance(streams, list):
            streams = [streams]

        if not all(isinstance(stream, STLStream) for stream in streams):
            raise STLArgumentError('streams', streams, valid_values = STLStream)

        self.streams = streams


    def get_streams (self):
        return self.streams

    def __str__ (self):
        return '\n'.join([str(stream) for stream in self.streams])


    @staticmethod
    def load_yaml (yaml_file):
        """Load a profile from a YAML file (see YAMLLoader)."""
        # check filename
        if not os.path.isfile(yaml_file):
            raise STLError("file '{0}' does not exists".format(yaml_file))

        yaml_loader = YAMLLoader(yaml_file)
        streams = yaml_loader.parse()

        return STLProfile(streams)


    @staticmethod
    def load_py (python_file):
        """Import a Python profile and build a profile from its register().get_streams().

        The file's directory is added to sys.path for the import only. Any
        error in the profile is re-raised as STLError carrying its traceback.
        """
        # check filename
        if not os.path.isfile(python_file):
            raise STLError("file '{0}' does not exists".format(python_file))

        basedir = os.path.dirname(python_file)
        sys.path.append(basedir)

        try:
            module_name = os.path.basename(python_file).split('.')[0]
            module = __import__(module_name, globals(), locals(), [], -1)
            reload(module) # reload the update

            streams = module.register().get_streams()

            return STLProfile(streams)

        except Exception:
            a, b, tb = sys.exc_info()
            # drop the first frame (this function) from the user-facing traceback
            x = ''.join(traceback.format_list(traceback.extract_tb(tb)[1:])) + a.__name__ + ": " + str(b) + "\n"

            summary = "\nPython Traceback follows:\n\n" + x
            raise STLError(summary)

        finally:
            # remove the entry appended above (the last occurrence), not an
            # earlier pre-existing one; tolerate it already being gone
            for idx in range(len(sys.path) - 1, -1, -1):
                if sys.path[idx] == basedir:
                    del sys.path[idx]
                    break


    # loop_count = 0 means loop forever
    @staticmethod
    def load_pcap (pcap_file, ipg_usec = None, speedup = 1.0, loop_count = 1, vm = None):
        """Build a profile with one single-packet stream per packet of a pcap.

        Streams are chained 1 -> 2 -> ... -> N -> 1. The last stream carries
        ``action_count = loop_count``. The gap before each packet (``isg``, in
        usec) comes from the capture timestamps, or from ``ipg_usec`` when it
        is given, and is divided by ``speedup``.

        :raises STLError: if the file does not exist.
        :raises STLArgumentError: if speedup is not positive.
        """
        # check filename
        if not os.path.isfile(pcap_file):
            raise STLError("file '{0}' does not exists".format(pcap_file))

        if speedup <= 0:
            raise STLArgumentError('speedup', speedup)

        streams = []
        prev_ts_usec = None

        pkts = RawPcapReader(pcap_file).read_all()
        num_pkts = len(pkts)

        for i, (cap, meta) in enumerate(pkts, start = 1):
            # IPG - if not provided, take from cap (meta[0] sec, meta[1] usec)
            if ipg_usec is None:
                ts_usec = (meta[0] * 1e6 + meta[1]) / float(speedup)
            else:
                ts_usec = (ipg_usec * i) / float(speedup)

            if prev_ts_usec is None:
                # first packet: capture timestamps are wall-clock values, so
                # start immediately instead of waiting the whole absolute
                # offset; with a fixed IPG the first gap stays one IPG
                prev_ts_usec = ts_usec if ipg_usec is None else 0

            # the last stream loops back to the first one
            if i == num_pkts:
                next_id = 1
                action_count = loop_count
            else:
                next_id = i + 1
                action_count = 0

            streams.append(STLStream(name = i,
                                     packet = CScapyTRexPktBuilder(pkt_buffer = cap, vm = vm),
                                     mode = STLTXSingleBurst(total_pkts = 1, percentage = 100),
                                     self_start = (i == 1),
                                     isg = (ts_usec - prev_ts_usec),  # usec
                                     action_count = action_count,
                                     next = next_id))

            prev_ts_usec = ts_usec

        return STLProfile(streams)


    @staticmethod
    def load (filename):
        """Load a profile, dispatching on the (case-insensitive) file extension.

        Supported: .py, .yaml/.yml, .cap/.pcap (pcap uses a fixed 1 sec IPG).
        :raises STLError: for an unknown extension.
        """
        ext = os.path.splitext(filename)[1].lower()
        suffix = ext[1:] if ext else None

        if suffix == 'py':
            profile = STLProfile.load_py(filename)

        elif suffix in ('yaml', 'yml'):
            profile = STLProfile.load_yaml(filename)

        elif suffix in ('cap', 'pcap'):
            profile = STLProfile.load_pcap(filename, speedup = 1, ipg_usec = 1e6)

        else:
            raise STLError("unknown profile file type: '{0}'".format(suffix))

        return profile


    def dump_to_yaml (self, yaml_file = None):
        """Dump all streams as a YAML list; also write it to yaml_file if given."""
        yaml_list = [stream.to_yaml() for stream in self.streams]
        yaml_str = yaml.dump(yaml_list, default_flow_style = False)

        # write to file if provided
        if yaml_file:
            with open(yaml_file, 'w') as f:
                f.write(yaml_str)

        return yaml_str


    def __len__ (self):
        return len(self.streams)
591
592