#!/usr/bin/env python

import sys
import struct
import socket

def num_to_dotted_quad(network):
    """Render an unsigned 32-bit integer as a 'dotted quad' string.

    3232235777 -> '192.168.1.1'
    The integer is packed big-endian, so its most significant byte
    becomes the first octet of the string.
    """
    packed = struct.pack('>L', network)
    return socket.inet_ntoa(packed)

# This will _not_ make sure the netmask is valid.
def num_to_bitcount(netmask):
    """Return how many of the low 32 bits of 'netmask' are set.

    0xffffff00 -> 24
    Bits above position 31 are ignored.  The value is not checked for
    being a contiguous netmask; Address._isValidNetmask() does that.
    """
    return sum((netmask >> position) & 1 for position in range(32))

def dotted_quad_to_num(network):
    """Convert a 'dotted quad' string to an unsigned integer.

    '192.168.1.1' -> 3232235777
    The string is parsed by socket.inet_aton() and the resulting four
    bytes are read big-endian, so the first octet ends up in the most
    significant byte of the returned integer.

    Raises ValueError if inet_aton() rejects the string.
    """
    try:
        packed = socket.inet_aton(network)
    except socket.error as e:
        # No project-specific exception type is in scope in this file, so
        # a bad address is reported with the builtin ValueError.
        raise ValueError('%s' % e)
    return struct.unpack('>L', packed)[0]

def bitcount_to_num(netmask):
    """Return an unsigned integer with the top 'netmask' of 32 bits set.

    ie. convert a '/24' netmask count to 0xffffff00.
    A count of zero or less yields 0; a count above 32 raises ValueError
    (negative shift count).
    """
    if netmask <= 0:
        return 0
    return (0xffffffff << (32 - netmask)) & 0xffffffff

def dotted_quad_cidr_to_num(network):
    """Convert the network string a.b.c.d/nn to network, netmask integers.

    The host bits of the address are cleared, so '192.168.1.77/24' yields
    the numbers for 192.168.1.0 and 255.255.255.0.

    Returns a tuple of the form (address-NUM, netmask-NUM) or (None, None)
    when the string has no single '/', the address does not parse, or the
    bit count is not an integer in the range 0..32.
    """
    try:
        address, bits = network.split('/')
    except ValueError:
        # Zero or several '/' separators: the unpacking fails.
        return (None, None)

    try:
        address = dotted_quad_to_num(address)
        bits = int(bits)
    except ValueError:
        return (None, None)
    if bits < 0 or bits > 32:
        return (None, None)

    netmask = bitcount_to_num(bits)
    return (address & netmask, netmask)

class Address(object):
    """An IPv4 address/netmask pair stored as plain integers.

    Attributes refreshed by _calcAddrData():
      network   -- address & netmask
      start     -- same value as network
      broadcast -- network + (0xffffffff - netmask)
      end       -- same value as broadcast

    The rich comparison operators compare the [start, end] spans and
    express containment rather than numeric order: 'a < b' means a is a
    proper subnet of b, 'a > b' that a is a proper supernet of b.
    """

    def __init__(self, address, netmask, mask=True, validate=True):
        """Store the pair and compute the derived attributes.

        mask     -- when true, replace the stored address by the network
                    address (host bits cleared).
        validate -- when true, raise ValueError('invalid netmask') unless
                    the netmask passes _isValidNetmask().
        """
        self.address = address
        self.netmask = netmask
        self._calcAddrData()

        if mask:
            self.address = self.network

        if validate and not self._isValidNetmask(netmask):
            raise ValueError('invalid netmask')

    def clone(self):
        """Return a new Address with the same address and netmask.

        The copy is built without masking and without validation, so the
        stored address is carried over unchanged.
        """
        return Address(self.address, self.netmask,
                       mask=False, validate=False)

    def _calcAddrData(self):
        """Recompute network/start and broadcast/end from the current
        address and netmask."""
        base = self.address & self.netmask
        top = base + (0xffffffff - self.netmask)
        self.network = base
        self.start = base
        self.broadcast = top
        self.end = top

    def __repr__(self):
        return '<IPV4.Address(%s, %s)>' % (self.address, self.netmask)

    def __str__(self):
        return self.printableCIDR()

    def __lt__(self, other):
        """True if the current address is a proper subnet of 'other'."""
        inside = self.start >= other.start and self.end <= other.end
        return inside and (self.start > other.start or self.end < other.end)

    def __le__(self, other):
        """True if the current address is a subnet of, or equal to, 'other'."""
        return self.start >= other.start and self.end <= other.end

    def __eq__(self, other):
        """True if both addresses cover the same start..end span."""
        return self.start == other.start and self.end == other.end

    def __ne__(self, other):
        """True if the addresses differ in start or end."""
        return self.start != other.start or self.end != other.end

    def __gt__(self, other):
        """True if the current address is a proper supernet of 'other'."""
        covers = other.start >= self.start and other.end <= self.end
        return covers and (other.start > self.start or other.end < self.end)

    def __ge__(self, other):
        """True if the current address is a supernet of, or equal to, 'other'."""
        return other.start >= self.start and other.end <= self.end

    def _isValidNetmask(self, netmask):
        """True if, within the low 32 bits, no set bit follows a clear bit
        when reading from the most significant bit downwards."""
        host_bits = (netmask & 0xffffffff) ^ 0xffffffff
        # A contiguous mask leaves host bits of the form 0..01..1 (2**k - 1),
        # and only such a value shares no bit with its successor.
        return (host_bits & (host_bits + 1)) == 0

    def inc(self, step=1):
        """Return a clone whose address is 'step' higher; self is untouched."""
        shifted = self.clone()
        shifted.address = shifted.address + step
        shifted._calcAddrData()
        return shifted

    def dec(self, step=1):
        """Return a clone whose address is 'step' lower; self is untouched."""
        shifted = self.clone()
        shifted.address = shifted.address - step
        shifted._calcAddrData()
        return shifted

    def isHigher(self, other):
        """True if this object's address is numerically above other's."""
        return self.address > other.address

    def printableCIDR(self):
        """Return the address in 'a.b.c.d/nn' form."""
        quad = self.numToDottedQuad(self.address)
        bits = self.numToBitcount(self.netmask)
        return '%s/%s' % (quad, bits)
    printable = printableCIDR

    def printableNonCIDR(self):
        """Return the address and netmask as two dotted quads separated by
        a space."""
        quad = self.numToDottedQuad(self.address)
        mask_quad = self.numToDottedQuad(self.netmask)
        return '%s %s' % (quad, mask_quad)

    def numToDottedQuad(self, network):
        """Render an unsigned 32-bit integer as a 'dotted quad' string.

        3232235777 -> '192.168.1.1'
        The integer is packed big-endian, so its most significant byte
        becomes the first octet.
        """
        packed = struct.pack('>L', network)
        return socket.inet_ntoa(packed)

    def numToBitcount(self, netmask):
        """Return how many of the low 32 bits of 'netmask' are set.

        The value is not checked for being a contiguous netmask.
        """
        return sum((netmask >> position) & 1 for position in range(32))

def get_ranges(total, exclude):
    """Subtract every span in 'exclude' from the span of 'total'.

    'total' and each item of 'exclude' only need 'start' and 'end'
    attributes holding inclusive integer bounds.

    Returns a list of inclusive (start, end) tuples describing what is
    left of total.start..total.end after all exclusions were cut out.
    """
    remaining = [(total.start, total.end)]

    for hole in exclude:
        survivors = []
        for lo, hi in remaining:
            if hole.end < lo or hole.start > hi:
                # No overlap: the piece survives untouched.
                survivors.append((lo, hi))
                continue
            # Overlap: keep whatever sticks out on either side of the
            # hole; a piece the hole covers completely adds nothing.
            if lo < hole.start:
                survivors.append((lo, hole.start - 1))
            if hi > hole.end:
                survivors.append((hole.end + 1, hi))
        remaining = survivors
    return remaining

def fits_in_range(start, end, size, mask):
    """Tell whether a mask-aligned block lies entirely inside start..end.

    Probes the positions start, start + size, ... up to 'end'.  Each
    probe is rounded down with 'mask' to the first address of its block,
    and the block ends 0xffffffff - mask addresses later.

    Returns True as soon as one such block lies within the inclusive
    range, False if none does.
    """
    cursor = start
    while cursor <= end:
        block_lo = cursor & mask
        block_hi = block_lo + (0xffffffff - mask)
        if start <= block_lo and block_hi <= end:
            return True
        cursor += size

    return False

def snarf(start, end, size, mask):
    """Carve the first mask-aligned block out of the range start..end.

    Uses the same probing as fits_in_range() and stops at the first block
    that lies entirely inside the inclusive range.  The caller is expected
    to have checked that such a block exists: if none does, the result
    describes the last block probed, and if start > end the function
    fails with UnboundLocalError.

    Returns a dict with
      'snarfed_start', 'snarfed_mask' -- the block that was taken
      'pre_start', 'pre_end'          -- leftover below the block, only
                                         present when there is one
      'post_start', 'post_end'        -- leftover above the block, only
                                         present when there is one
    """
    cursor = start
    while cursor <= end:
        block_lo = cursor & mask
        block_hi = block_lo + (0xffffffff - mask)
        if start <= block_lo and block_hi <= end:
            break
        cursor += size

    carved = {'snarfed_start': block_lo, 'snarfed_mask': mask}

    if block_lo > start:
        carved['pre_start'] = start
        carved['pre_end'] = block_lo - 1

    if end > block_hi:
        carved['post_start'] = block_hi + 1
        carved['post_end'] = end

    return carved

def get_cidr_ranges(input_ranges):
    """Print the CIDR blocks that exactly cover each given range.

    'input_ranges' is an iterable of inclusive (start, end) integer
    tuples.  Each range is split greedily: the largest aligned block that
    fits is taken out with snarf(), and whatever is left below and above
    it is queued and split the same way.

    For every input range the resulting blocks are sorted and printed,
    one 'a.b.c.d/nn' per line.  Returns None.
    """
    # (block size, netmask) for every prefix length, /0 first, so that the
    # search below always tries the largest block first.  A list is used
    # because the order matters and must not depend on dict iteration.
    cidrsizes = [(1 << (32 - n), bitcount_to_num(n)) for n in range(33)]

    for start, end in input_ranges:
        results = []
        ranges = [(start, end)]

        while ranges:
            # Popping up front also discards an inverted range
            # (end < start), which no block size can match and which
            # would otherwise never leave the queue.
            lo, hi = ranges.pop(0)
            size = hi - lo + 1
            for block_size, mask in cidrsizes:
                if size < block_size:
                    continue
                if not fits_in_range(lo, hi, block_size, mask):
                    # Big enough, but no aligned block of this size lies
                    # inside the range; try the next smaller one.
                    continue
                snarf_res = snarf(lo, hi, block_size, mask)
                results.append((snarf_res['snarfed_start'],
                                snarf_res['snarfed_mask']))

                if 'pre_start' in snarf_res:
                    ranges.append((snarf_res['pre_start'],
                                   snarf_res['pre_end']))
                if 'post_start' in snarf_res:
                    ranges.append((snarf_res['post_start'],
                                   snarf_res['post_end']))
                break

        results.sort()
        for block_start, mask in results:
            # Parenthesised single argument: valid on Python 2 and 3.
            print('%s/%d' % (num_to_dotted_quad(block_start),
                             num_to_bitcount(mask)))

def main():
    """Command-line entry point.

    sys.argv[1] is the total network and every further argument a network
    to exclude from it, all in a.b.c.d/nn form.  The CIDR blocks that
    remain are printed by get_cidr_ranges().

    Exits with status 1 when fewer than two networks are given or when an
    argument cannot be parsed.
    """
    if len(sys.argv) < 3:
        print('Usage: total-network exclude-range [exclude-range ...]')
        sys.exit(1)

    networks = []
    for net_str in sys.argv[1:]:
        network, netmask = dotted_quad_cidr_to_num(net_str)
        if network is None:
            # A parse failure is signalled with (None, None); handing
            # that to Address() would end in a TypeError traceback.
            sys.stderr.write('Invalid network: %s\n' % net_str)
            sys.exit(1)
        networks.append(Address(network, netmask))

    total = networks[0]
    exclude = networks[1:]
    get_cidr_ranges(get_ranges(total, exclude))

if __name__ == '__main__':
    main()

