#===============================================================================
# Copyright 2012 NetApp, Inc. All Rights Reserved,
# contribution by Jorge Mora <mora@netapp.com>
#
# This program is free software; you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the Free Software
# Foundation; either version 2 of the License, or (at your option) any later
# version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.
#===============================================================================
"""
Packet trace module

The Packet trace module is a python module that takes a trace file created
by tcpdump and unpacks the contents of each packet. You can decode one packet
at a time, or do a search for specific packets. The main difference between
these modules and other tools used to decode trace files is that you can use
this module to completely automate your tests.

How does it work? It opens the trace file and reads one record at a time
keeping track where each record starts. This way, very large trace files
can be opened without having to wait for the file to load and avoid loading
the whole file into memory.

Packet layers supported:
    - ETHERNET II (RFC 894)
    - IP layer (supports IPv4 and IPv6)
    - UDP layer
    - TCP layer
    - RPC layer
    - NFS v4.0
    - NFS v4.1 including pNFS file layouts
    - NFS v4.2
    - PORTMAP v2
    - MOUNT v3
    - NLM v4
"""
import os
import re
import gzip
import time
import token
import struct
import parser
import symbol
from formatstr import *
import nfstest_config as c
from baseobj import BaseObj
from packet.unpack import Unpack
from packet.record import Record
from packet.pkt import Pkt, PKT_layers
from packet.link.ethernet import ETHERNET

# Module constants
__author__    = "Jorge Mora (%s)" % c.NFSTEST_AUTHOR_EMAIL
__copyright__ = "Copyright (C) 2012 NetApp, Inc."
__license__   = "GPL v2"
__version__   = "2.0"

# Register the debug levels used by this module: one bit per level,
# 0xF00000000 ('pktt') is the union of the four bits above it
BaseObj.debug_map(0x100000000, 'pkt1', "PKT1: ")
BaseObj.debug_map(0x200000000, 'pkt2', "PKT2: ")
BaseObj.debug_map(0x400000000, 'pkt3', "PKT3: ")
BaseObj.debug_map(0x800000000, 'pkt4', "PKT4: ")
BaseObj.debug_map(0xF00000000, 'pktt', "PKTT: ")

# Map of tokens: token numbers first, then grammar symbol numbers (same
# precedence as before: a symbol entry wins over a token entry).
# Built with update() instead of "items() + items()" because the latter
# only works when items() returns a list.
_token_map = dict(token.tok_name)
_token_map.update(symbol.sym_name)
# Map of items not in the array of the compound
_nfsopmap = {'status': 1, 'tag': 1}
# Match function map: layer name -> name of the match function placeholder
# which _convert_match() substitutes for the layer name in an expression
_match_func_map = dict((x, "self._match_%s" % x) for x in PKT_layers)

class Header(BaseObj):
    """Trace file header.

       Decodes the 20 bytes read from the given packet trace object using
       its header format (pktt.header_fmt) and stores the six resulting
       values as attributes of this object.
    """
    # Class attributes
    _attrlist = (
        "major",
        "minor",
        "zone_offset",
        "accuracy",
        "dump_length",
        "link_type",
    )

    def __init__(self, pktt):
        """Read and decode the header from the packet trace object pktt."""
        raw = pktt._read(20)
        values = struct.unpack(pktt.header_fmt, raw)
        # Fields are stored in the same order they are laid out in the file
        (self.major, self.minor, self.zone_offset,
         self.accuracy, self.dump_length, self.link_type) = values[:6]

class Pktt(BaseObj, Unpack):
    """Packet trace object

       Usage:
           from packet.pktt import Pktt

           x = Pktt("/traces/tracefile.cap")

           # Iterate over all packets found in the trace file
           for pkt in x:
               print pkt
    """
    def __init__(self, tfile, live=False, state=True):
        """Constructor

           Initialize object's private data, note that this will not check the
           file for existence nor will open the file to verify if it is a valid
           tcpdump file. The tcpdump trace file will be opened the first time a
           packet is retrieved.

           tracefile:
               Name of tcpdump trace file or a list of trace file names
               (little or big endian format)
           live:
               If set to True, methods will not return if encountered <EOF>,
               they will keep on trying until more data is available in the
               file. This is useful when running tcpdump in parallel,
               especially when tcpdump is run with the '-C' option, in which
               case when <EOF> is encountered the next trace file created by
               tcpdump will be opened and the object will be re-initialized,
               all private data referencing the previous file is lost.
        """
        self.tfile   = tfile  # Current trace file name
        self.bfile   = tfile  # Base trace file name
        self.live    = live   # Set to True if dealing with a live tcpdump file
        self.offset  = 0      # Current file offset
        self.boffset = 0      # File offset of current packet
        self.ioffset = 0      # File offset of first packet
        self.index   = 0      # Current packet index
        self.mindex  = 0      # Maximum packet index for current trace file
        self.findex  = 0      # Current tcpdump file index (used with self.live)
        self.fh      = None   # Current file handle
        self.eof     = False  # End of file marker for current packet trace
        self.serial  = False  # Processing trace files serially
        self.pkt     = None   # Current packet
        self.pkt_call  = None # The current packet call if self.pkt is a reply
        self.pktt_list = []   # List of Pktt objects created
        self.tfiles    = []   # List of packet trace files

        # TCP stream map: to keep track of the different TCP streams within
        # the trace file -- used to deal with RPC packets spanning multiple
        # TCP packets or to handle a TCP packet having multiple RPC packets
        self._tcp_stream_map = {}

        # RPC xid map: to keep track of packet calls
        self._rpc_xid_map = {}
        # List of outstanding xids to match
        self._match_xid_list = []

        # Process tfile argument
        if isinstance(tfile, list):
            # The argument tfile is given as a list of packet trace files
            self.tfiles = tfile
            if len(self.tfiles) == 1:
                # Only one file is given
                self.tfile = self.tfiles[0]
            else:
                # Create all packet trace objects
                for tfile in self.tfiles:
                    self.pktt_list.append(Pktt(tfile))

    def __del__(self):
        """Destructor

           Gracefully close the tcpdump trace file if it is opened.
        """
        if self.fh:
            self.fh.close()

    def __iter__(self):
        """Make this object iterable."""
        return self

    def __contains__(self, expr):
        """Implement membership test operator.
           Return true if expr matches a packet in the trace file,
           false otherwise.

           The packet is also stored in the object attribute pkt.

           Examples:
               # Find the next READ request
               if ("NFS.argop == 25" in x):
                   print x.pkt.nfs

           See match() method for more information
        """
        pkt = self.match(expr)
        return (pkt is not None)

    def __getitem__(self, index):
        """Get the packet from the trace file given by the index
           or raise IndexError.

           The packet is also stored in the object attribute pkt.

           Examples:
               pkt = x[index]
        """
        self.dprint('PKT4', ">>> __getitem__(%d)" % index)
        if index < 0:
            # No negative index is allowed
            raise IndexError

        try:
            if index == self.pkt.record.index:
                # The requested packet is in memory, just return it
                return self.pkt
        except:
            pass

        if index < self.index:
            # Reset the current packet index and offset
            # The index is less than the current packet offset so position
            # the file pointer to the offset of the packet given by index
            self.rewind(index)

        # Move to the packet specified by the index
        pkt = None
        while self.index <= index:
            try:
                pkt = self.next()
            except:
                break

        if pkt is None:
            raise IndexError
        return pkt

    def next(self):
        """Get the next packet from the trace file or raise StopIteration.

           The packet is also stored in the object attribute pkt.

           When the object was created with more than one trace file, the
           packet returned is the one with the lowest record time stamp
           among the next packets of all trace files and its record index
           is replaced by a cumulative index.

           Raises StopIteration when there are no more packets or when the
           next record is truncated.

           Examples:
               # Iterate over all packets found in the trace file using
               # the iterable properties of the object
               for pkt in x:
                   print pkt

               # Iterate over all packets found in the trace file using it
               # as a method and using the object variable as the packet
               # Must use the try statement to catch StopIteration exception
               try:
                   while (x.next()):
                       print x.pkt
               except StopIteration:
                   pass

               # Iterate over all packets found in the trace file using it
               # as a method and using the return value as the packet
               # Must use the try statement to catch StopIteration exception
               while True:
                   try:
                       print x.next()
                   except StopIteration:
                       break

           NOTE:
               Supports only single active iteration
        """
        self.dprint('PKT4', ">>> %d: next()" % self.index)
        # Initialize next packet
        self.pkt = Pkt()

        if len(self.pktt_list) > 1:
            # Dealing with multiple trace files
            minsecs  = None
            pktt_obj = None
            for obj in self.pktt_list:
                if obj.pkt is None:
                    # Get first packet for this packet trace object
                    try:
                        obj.next()
                    except StopIteration:
                        obj.mindex = self.index
                if obj.eof:
                    # Nothing left on this trace file
                    continue
                if minsecs is None or obj.pkt.record.secs < minsecs:
                    # Select the trace file having the earliest packet
                    minsecs = obj.pkt.record.secs
                    pktt_obj = obj
            if pktt_obj is None:
                # All packet trace files have been processed
                raise StopIteration
            elif len(self._tcp_stream_map):
                # This packet trace file should be processed serially
                # Have all state transferred to next packet object
                pktt_obj.rewind()
                pktt_obj._tcp_stream_map = self._tcp_stream_map
                pktt_obj._rpc_xid_map    = self._rpc_xid_map
                self._tcp_stream_map = {}
                self._rpc_xid_map    = {}
                # Decode the first packet again, now using the saved state
                pktt_obj.next()

            # Overwrite attributes seen by the caller with the attributes
            # from the current packet trace object
            self.pkt = pktt_obj.pkt
            self.pkt_call = pktt_obj.pkt_call
            self.tfile = pktt_obj.tfile
            self.pkt.record.index = self.index  # Use a cumulative index

            try:
                # Get next packet for this packet trace object
                pktt_obj.next()
            except StopIteration:
                # Set maximum packet index for this packet trace object to
                # be used by rewind to select the proper packet trace object
                pktt_obj.mindex = self.index
                # Check if objects should be serially processed
                pktt_obj.serial = False
                for obj in self.pktt_list:
                    if not obj.eof:
                        if obj.index > 1:
                            # Another trace file is already past its first
                            # packet: the files overlap in time
                            pktt_obj.serial = False
                            break
                        elif obj.index == 1:
                            pktt_obj.serial = True
                if pktt_obj.serial:
                    # Save current state
                    self._tcp_stream_map = pktt_obj._tcp_stream_map
                    self._rpc_xid_map    = pktt_obj._rpc_xid_map

            # Increment cumulative packet index
            self.index += 1
            return self.pkt

        # Save file offset for this packet
        self.boffset = self.offset

        # Get record header
        data = self._read(16)
        if len(data) < 16:
            self.eof = True
            raise StopIteration
        # Decode record header
        record = Record(self, data)

        # Get record data and create Unpack object
        self.unpack = Unpack(self._read(record.length_inc))
        if self.unpack.size() < record.length_inc:
            # Record has been truncated, stop iteration
            # Nothing else can be read from this file so mark it as
            # exhausted, otherwise the multiple trace file processing above
            # keeps on using this object as if it had a valid packet
            self.eof = True
            raise StopIteration

        if self.header.link_type == 1:
            # Decode ethernet layer
            ETHERNET(self)
        else:
            # Unknown link layer
            record.data = self.unpack.getbytes()

        # Increment packet index
        self.index += 1

        return self.pkt

    def rewind(self, index=0):
        """Rewind the trace file by setting the file pointer to the start of
           the given packet index. Returns False if unable to rewind the file,
           e.g., when the given index is greater than the maximum number
           of packets processed so far.
        """
        self.dprint('PKT1', ">>> rewind(%d)" % index)
        if index >= 0 and index < self.index:
            if len(self.pktt_list) > 1:
                # Dealing with multiple trace files
                self.index = 0
                for obj in self.pktt_list:
                    if not obj.eof or index <= obj.mindex:
                        obj.rewind()
                        try:
                            obj.next()
                        except StopIteration:
                            pass
                    elif obj.serial and index > obj.mindex:
                        self.index = obj.mindex + 1
            else:
                # Reset the current packet index and offset to the first packet
                self.offset = self.ioffset
                self.index  = 0
                self.eof    = False

                # Position the file pointer to the offset of the first packet
                self._getfh().seek(self.offset)

                # Clear state
                self._tcp_stream_map = {}
                self._rpc_xid_map    = {}

            # Move to the packet before the specified by the index so the
            # next packet fetched will be the one given by index
            while self.index < index:
                try:
                    pkt = self.next()
                except:
                    break

            # Rewind succeeded
            return True
        return False

    def _getfh(self):
        """Get the filehandle of the trace file, open file if necessary.

           The first time it is called, the file is opened, the file
           identifier is used to select the byte order of the header and
           record formats, and the file header is decoded. If the file
           identifier is not recognized the file is read again as a gzip
           compressed file before giving up.

           Raises an exception if the file is empty or if it is not a
           tcpdump file. On failure the file is closed and the object is
           left without a file handle so the next call starts from scratch.
        """
        if self.fh is None:
            # Check size of file
            fstat = os.stat(self.tfile)
            if fstat.st_size == 0:
                raise Exception("Packet trace file is empty")

            # Open trace file
            rawfh = open(self.tfile, 'rb')
            self.fh = rawfh

            try:
                iszip = False
                self.header_fmt = None
                while self.header_fmt is None:
                    # Initialize offset
                    self.offset = 0

                    # Get file identifier
                    try:
                        self.ident = self._read(4)
                    except Exception:
                        # E.g., reading a file which is not gzip compressed
                        # through the gzip file object
                        self.ident = ""

                    if self.ident == '\324\303\262\241':
                        # Little endian
                        self.header_fmt = '<HHIIII'
                        self.header_rec = '<IIII'
                    elif self.ident == '\241\262\303\324':
                        # Big endian
                        self.header_fmt = '>HHIIII'
                        self.header_rec = '>IIII'
                    else:
                        if iszip:
                            # Already tried as a gzip compressed file
                            raise Exception('Not a tcpdump file')
                        iszip = True
                        self.fh.seek(0)
                        # Try if this is a gzip compress file
                        self.fh = gzip.GzipFile(fileobj=self.fh)

                # Get header information
                self.header = Header(self)
            except Exception:
                # Do not leave a file handle with no header information
                # behind, a later call would return it as if the file had
                # been successfully opened
                self.fh = None
                rawfh.close()
                raise

            # Initialize packet number
            self.index   = 0
            self.tstart  = None
            self.ioffset = self.offset

        return self.fh

    def _read(self, count):
        """Wrapper for read in order to increment the object's offset. It also
           takes care of <EOF> when 'live' option is set which keeps on trying
           to read and switching files when needed.
        """
        while True:
            # Read number of bytes specified
            data = self._getfh().read(count)
            ldata = len(data)
            if self.live and ldata != count:
                # Not all data was read (<EOF>)
                tracefile = "%s%d" % (self.bfile, self.findex+1)
                # Check if next trace file exists
                if os.path.isfile(tracefile):
                    # Save information that keeps track of the next trace file
                    basefile = self.bfile
                    findex = self.findex + 1
                    # Re-initialize the object
                    self.__del__()
                    self.__init__(tracefile, live=self.live)
                    # Overwrite next trace file info
                    self.bfile = basefile
                    self.findex = findex
                # Re-position file pointer to last known offset
                self._getfh().seek(self.offset)
                time.sleep(1)
            else:
                break

        # Increment object's offset by the amount of data read
        self.offset += ldata
        return data

    def _split_match(self, uargs):
        """Split match arguments and return a tuple (lhs, opr, rhs)
           where lhs is the left hand side of the given argument expression,
           opr is the operation and rhs is the right hand side of the given
           argument expression:

               <lhs> <opr> <rhs>

           Valid opr values are: ==, !=, <, >, <=, >=, in
        """
        m = re.search(r"([^!=<>]+)\s*([!=<>]+|in)\s*(.*)", uargs)
        lhs = m.group(1).rstrip()
        opr = m.group(2)
        rhs = m.group(3)
        return (lhs, opr, rhs)

    def _process_match(self, obj, lhs, opr, rhs):
        """Process "regex" and 'in' operator on match expression.
           Regular expression is given as re('regex') and converted to a
           proper regex re.search('regex', data), where data is the object
           compose of obj and lhs|rhs depending on opr. The argument obj
           is an object prefix.

           If opr is a comparison operation (==, !=, etc.), both obj and lhs
           will be the actual LHS and rhs will be the actual RHS.
           If opr is 'in', lhs will be the actual LHS and both obj and rhs
           will be the actual RHS.

           Return the processed match expression.

           Examples:
               # Regular expression processing
               expr = x._process_match('self.pkt.ip.', 'src', '==', "re(r'192\.*')")

               Returns the following expression ready to be evaluated:
               expr = "re.search(r'192\.*', str(self.pkt,ip.src))"

               # Object prefix processing
               expr = x._process_match('item.', 'argop', '==', '25')

               Returns the following expression ready to be evaluated:
               expr = "item.argop==25"

               # Membership (in) processing
               expr = x._process_match('item.', '62', 'in', 'attributes')

               Returns the following expression ready to be evaluated:
               expr = "62 in item.attributes"
        """
        func = None
        if opr != 'in':
            regex = re.search(r"(\w+)\((.*)\)", lhs)
            if regex:
                lhs = regex.group(2)
                func = regex.group(1)

        if rhs[:3] == 're(':
            # Regular expression, it must be in rhs
            rhs = "re.search" + rhs[2:]
            if opr == "!=":
                rhs = "not " + rhs
            LHS = rhs[:-1] + ", str(" + obj + lhs +  "))"
            RHS = ""
            opr = ""
        elif opr == 'in':
            opr = " in "
            if self.inlhs:
                LHS = obj + lhs
                RHS = rhs
            else:
                LHS = lhs
                RHS = obj + rhs
        else:
            LHS = obj + lhs
            RHS = rhs

        if func is not None:
            LHS = "%s(%s)" % (func, LHS)

        return LHS + opr + RHS

    def _match(self, layer, uargs):
        """Default match function."""
        if not hasattr(self.pkt, layer):
            return False

        if layer == "nfs":
            # Use special matching function for NFS
            texpr = self.match_nfs(uargs)
        else:
            # Use general match
            obj = "self.pkt.%s." % layer.lower()
            lhs, opr, rhs = self._split_match(uargs)
            expr = self._process_match(obj, lhs, opr, rhs)
            texpr = eval(expr)
        self.dprint('PKT2', "    %d: match_%s(%s) -> %r" % (self.pkt.record.index, layer, uargs, texpr))
        return texpr

    def clear_xid_list(self):
        """Clear list of outstanding xids"""
        self._match_xid_list = []

    def _match_nfs(self, uargs):
        """Match NFS values on current packet."""
        array = None
        isarg = True
        lhs, opr, rhs = self._split_match(uargs)

        if self.pkt.rpc.version == 3 or _nfsopmap.get(lhs):
            try:
                # Top level NFSv4 packet info or NFSv3 packet
                expr = self._process_match("self.pkt.nfs.", lhs, opr, rhs)
                if eval(expr):
                    # Set NFSop and NFSidx
                    self.pkt.NFSop = self.pkt.nfs
                    self.pkt.NFSidx = 0
                    return True
                return False
            except Exception:
                return False

        idx = 0
        obj_prefix = "item."
        for item in self.pkt.nfs.array:
            try:
                # Get expression to eval
                expr = self._process_match(obj_prefix, lhs, opr, rhs)
                if eval(expr):
                    self.pkt.NFSop = item
                    self.pkt.NFSidx = idx
                    return True
            except Exception:
                # Continue searching
                pass
            idx += 1
        return False

    def match_nfs(self, uargs):
        """Match NFS values on current packet.

           In NFSv4, there is a single compound procedure with multiple
           operations, matching becomes a little bit tricky in order to make
           the matching expression easy to use. The NFS object's name space
           gets converted into a flat name space for the sole purpose of
           matching. In other words, all operation objects in array are
           treated as being part of the NFS object's top level attributes.

           Consider the following NFS object:
               nfsobj = COMPOUND4res(
                   status=NFS4_OK,
                   tag='NFSv4_tag',
                   array = [
                       nfs_resop4(
                           resop=OP_SEQUENCE,
                           opsequence=SEQUENCE4res(
                               status=NFS4_OK,
                               resok=SEQUENCE4resok(
                                   sessionid='sessionid',
                                   sequenceid=29,
                                   slotid=0,
                                   highest_slotid=179,
                                   target_highest_slotid=179,
                                   status_flags=0,
                               ),
                           ),
                       ),
                       nfs_resop4(
                           resop=OP_PUTFH,
                           opputfh = PUTFH4res(
                               status=NFS4_OK,
                           ),
                       ),
                       ...
                   ]
               ),

           The result for operation PUTFH is the second in the list:
               putfh = nfsobj.array[1]

           From this putfh object the status operation can be accessed as:
               status = putfh.opputfh.status

           or simply as (this is how the NFS object works):
               status = putfh.status

           In this example, the following match expression 'NFS.status == 0'
           could match the top level status of the compound (nfsobj.status)
           or the putfh status (nfsobj.array[1].status)

           The following match expression 'NFS.sequenceid == 25' will also
           match this packet as well, even though the actual expression should
           be 'nfsobj.array[0].opsequence.resok.sequenceid == 25' or
           simply 'nfsobj.array[0].sequenceid == 25'.

           This approach makes the match expressions simpler at the expense of
           having some ambiguities on where the actual matched occurred. If a
           match is desired on a specific operation, a more qualified name can
           be given. In the above example, in order to match the status of the
           PUTFH operation the match expression 'NFS.opputfh.status == 0' can
           be used. On the other hand, consider a compound having multiple
           PUTFH results the above match expression will always match the first
           occurrence of PUTFH where the status is 0. There is no way to tell
           the match engine to match the second or Nth occurrence of an
           operation.
        """
        texpr = self._match_nfs(uargs)
        self.dprint('PKT2', "    %d: match_nfs(%s) -> %r" % (self.pkt.record.index, uargs, texpr))
        return texpr

    def _convert_match(self, ast):
        """Convert a parser list match expression into their corresponding
           function calls.

           The list is processed recursively: a string is returned as is
           unless it is the name of a packet layer (case insensitive), in
           which case it is replaced by the name of its match function
           placeholder. A list of two items is replaced by the conversion
           of its second item and for a longer list, the conversions of
           all items but the first one are concatenated. When the first
           item of the list maps to "comparison", the concatenated string
           is converted into a call to self._match() having the layer and
           the rest of the comparison as quoted arguments.

           As a side effect, the object attribute inlhs is set every time
           an 'in' operator is found: True if the text to the left of the
           operator starts with "self.", False otherwise.

           NOTE(review): inlhs is a single attribute for the whole
           expression, so with more than one 'in' operator the value of
           the last one processed seems to be used for all of them when
           the expression is evaluated -- confirm.

           Example:
               expr = "TCP.flags.ACK == 1 and NFS.argop == 50"
               st = parser.expr(expr)
               ast = parser.st2list(st)
               data =  self._convert_match(ast)

               Returns:
               data = "(self._match('tcp','flags.ACK==1'))and(self._match('nfs','argop==50'))"
        """
        ret = ''
        isin = False
        if not isinstance(ast, list):
            # This is a string
            if ast.lower() in _match_func_map:
                # Replace name by its corresponding function name
                return _match_func_map[ast.lower()]
            return ast
        if len(ast) == 2:
            # Only one item after the first one, just convert that item
            return self._convert_match(ast[1])

        # Concatenate the conversion of all items after the first one
        for a in ast[1:]:
            data = self._convert_match(a)
            if data == 'in':
                # Items are concatenated with no spaces in between,
                # the 'in' operator is the one needing them
                data = ' in '
                isin = True
                if ret[:5] == "self.":
                    # LHS in the 'in' operator is a packet object
                    self.inlhs = True
                else:
                    # LHS in the 'in' operator is a constant value
                    self.inlhs = False
            ret += data

        if _token_map[ast[0]] == "comparison":
            # Comparison
            if isin:
                # The layer could be on either side of the operator: take
                # the function name and layer out and keep everything else
                regex = re.search(r'(.*)(self\._match)_(\w+)\.(.*)', ret)
                data  = regex.groups()
                func  = data[1]
                layer = data[2]
                uargs = data[0] + data[3]
            else:
                # The layer must be at the start of the comparison where it
                # could be the argument to a function call, e.g., "len("
                regex = re.search(r"^((\w+)\()?(self\._match)_(\w+)\.(.*)", ret)
                data  = regex.groups()
                func  = data[2]
                layer = data[3]
                if data[0] is None:
                    uargs = data[4]
                else:
                    # Keep the function call as part of the arguments
                    uargs = data[0] + data[4]
            # Escape all single quotes since the whole string will be quoted
            uargs = re.sub(r"'", "\\'", uargs)
            ret = "(%s('%s','%s'))" % (func, layer, uargs)

        return ret

    def match(self, expr, maxindex=None, rewind=True, reply=False):
        """Return the packet that matches the given expression, also the packet
           index points to the next packet after the matched packet.
           Returns None if packet is not found and the packet index points
           to the packet at the beginning of the search.

           expr:
               String of expressions to be evaluated
           maxindex:
               The match fails if packet index hits this limit, there is
               no limit if it is None or 0 [default: None]
           rewind:
               Rewind to index where matching started if match fails
               [default: True]
           reply:
               Match RPC replies of previously matched calls as well: the
               xid of every RPC call matched is saved and a reply having
               one of the saved xids is returned even if it does not match
               the expression, in which case the object attribute
               reply_matched is set to True [default: False]

           An error evaluating the expression on a packet is ignored and
           the search continues with the next packet.

           Examples:
               # Find the packet with both the ACK and SYN TCP flags set to 1
               pkt = x.match("TCP.flags.ACK == 1 and TCP.flags.SYN == 1")

               # Find the next NFS EXCHANGE_ID request
               pkt = x.match("NFS.argop == 42")

               # Find the next NFS EXCHANGE_ID or CREATE_SESSION request
               pkt = x.match("NFS.argop in [42,43]")

               # Find the next NFS OPEN request or reply
               pkt = x.match("NFS.op == 18")

               # Find all packets coming from subnet 192.168.1.0/24 using
               # a regular expression
               while x.match(r"IP.src == re('192\.168\.1\.\d*')"):
                   print x.pkt.tcp

               # Find packet having a GETATTR asking for FATTR4_FS_LAYOUT_TYPES(bit 62)
               pkt_call = x.match("NFS.attr_request & 0x4000000000000000L != 0")
               if pkt_call:
                   # Find GETATTR reply
                   xid = pkt_call.rpc.xid
                   # Find reply where the number 62 is in the array NFS.attributes
                   pkt_reply = x.match("RPC.xid == %d and 62 in NFS.attributes" % xid)

               # Find the next WRITE request
               pkt = x.match("NFS.argop == 38")
               if pkt:
                   print pkt.nfs

               # Same as above, but using membership test operator instead
               if ("NFS.argop == 38" in x):
                   print x.pkt.nfs

           See also:
               match_ethernet(), match_ip(), match_tcp(), match_rpc(), match_nfs()
        """
        # Save current position
        save_index = self.index

        # Parse match expression
        st = parser.expr(expr)
        smap = parser.st2list(st)
        # Expression to evaluate on each packet where every comparison has
        # been replaced by a call to self._match()
        pdata = self._convert_match(smap)
        self.dprint('PKT1', ">>> %d: match(%s)" % (self.index, expr))
        self.reply_matched = False

        # Search one packet at a time
        for pkt in self:
            # The packet index is already pointing to the next packet
            if maxindex and self.index > maxindex:
                # Hit maxindex limit
                break
            try:
                if reply and pkt == "rpc" and pkt.rpc.type == 1 and pkt.rpc.xid in self._match_xid_list:
                    # This is the reply of a call previously matched
                    self.dprint('PKT1', ">>> %d: match() -> True: reply" % pkt.record.index)
                    self._match_xid_list.remove(pkt.rpc.xid)
                    self.reply_matched = True
                    return pkt
                # The expression is evaluated here since it references self
                if eval(pdata):
                    # Return matched packet
                    self.dprint('PKT1', ">>> %d: match() -> True" % pkt.record.index)
                    if reply and pkt == "rpc" and pkt.rpc.type == 0:
                        # Save xid of matched call
                        self._match_xid_list.append(pkt.rpc.xid)
                    return pkt
            except Exception:
                # Not able to evaluate the expression on this packet,
                # take it as if the packet did not match
                pass

        if rewind:
            # No packet matched, re-position the file pointer back to where
            # the search started
            self.rewind(save_index)
        # There is no current packet when the match fails
        self.pkt = None
        self.dprint('PKT1', ">>> match() -> False")
        return None

    @staticmethod
    def escape(data):
        """Escape special characters.

           Examples:
               # Call as an instance
               escaped_data = x.escape(data)

               # Call as a class
               escaped_data = Pktt.escape(data)
        """
        # repr() can escape or not a single quote depending if a double
        # quote is present, just make sure both quotes are escaped correctly
        rdata = repr(data)
        if rdata[0] == '"':
            # Double quotes are escaped
            dquote = r'x22'
            squote = r'\x27'
        else:
            # Single quotes are escaped
            dquote = r'\x22'
            squote = r'x27'
        # Replace all double quotes to its corresponding hex value
        data = re.sub(r'"', dquote, rdata[1:-1])
        # Replace all single quotes to its corresponding hex value
        data = re.sub(r"'", squote, data)
        # Escape all backslashes
        data = re.sub(r'\\', r'\\\\', data)
        return data

    @staticmethod
    def ip_tcp_src_expr(ipaddr, port):
        """Return a match expression to find a packet coming from ipaddr:port.

           Examples:
               # Call as an instance
               expr = x.ip_tcp_src_expr('192.168.1.50', 2049)

               # Call as a class
               expr = Pktt.ip_tcp_src_expr('192.168.1.50', 2049)

               # Returns "IP.src == '192.168.1.50' and TCP.src_port == 2049"
               # Expression ready for x.match()
               pkt = x.match(expr)
        """
        return "IP.src == '%s' and TCP.src_port == %d" % (ipaddr, port)

    @staticmethod
    def ip_tcp_dst_expr(ipaddr, port):
        """Return a match expression to find a packet going to ipaddr:port.

           Examples:
               # Call as an instance
               expr = x.ip_tcp_dst_expr('192.168.1.50', 2049)

               # Call as a class
               expr = Pktt.ip_tcp_dst_expr('192.168.1.50', 2049)

               # Returns "IP.dst == '192.168.1.50' and TCP.dst_port == 2049"
               # Expression ready for x.match()
               pkt = x.match(expr)
        """
        return "IP.dst == '%s' and TCP.dst_port == %d" % (ipaddr, port)

if __name__ == '__main__':
    # Self test of module: every string escaped by Pktt.escape() must
    # evaluate back to the original string when enclosed in either single
    # or double quotes
    l_escape = [
        "hello",
        "\x00\\test",
        "single'quote",
        'double"quote',
        'back`quote',
        'single\'double"quote',
        'double"single\'quote',
        'single\'double"back`quote',
        'double"single\'back`quote',
    ]
    ntests = 2*len(l_escape)

    tcount = 0
    for quote in ["'", '"']:
        for data in l_escape:
            expr = "data == %s%s%s" % (quote, Pktt.escape(data), quote)
            # Backslashes are doubled by escape(), convert them back
            # to what is seen after one level of evaluation
            expr = re.sub(r'\\\\', r'\\', expr)
            if eval(expr):
                tcount += 1

    # A single parenthesized argument prints the same with both the print
    # statement and the print function. SystemExit is raised directly
    # since exit() is only available when the site module is loaded
    if tcount == ntests:
        print("All tests passed!")
        raise SystemExit(0)
    else:
        print("%d tests failed" % (ntests-tcount))
        raise SystemExit(1)
