/*
** Copyright (C) 2005-2012 by Carnegie Mellon University.
**
** @OPENSOURCE_HEADER_START@
**
** Use of the SILK system and related source code is subject to the terms
** of the following licenses:
**
** GNU Public License (GPL) Rights pursuant to Version 2, June 1991
** Government Purpose License Rights (GPLR) pursuant to DFARS 252.227.7013
**
** NO WARRANTY
**
** ANY INFORMATION, MATERIALS, SERVICES, INTELLECTUAL PROPERTY OR OTHER
** PROPERTY OR RIGHTS GRANTED OR PROVIDED BY CARNEGIE MELLON UNIVERSITY
** PURSUANT TO THIS LICENSE (HEREINAFTER THE "DELIVERABLES") ARE ON AN
** "AS-IS" BASIS. CARNEGIE MELLON UNIVERSITY MAKES NO WARRANTIES OF ANY
** KIND, EITHER EXPRESS OR IMPLIED AS TO ANY MATTER INCLUDING, BUT NOT
** LIMITED TO, WARRANTY OF FITNESS FOR A PARTICULAR PURPOSE,
** MERCHANTABILITY, INFORMATIONAL CONTENT, NONINFRINGEMENT, OR ERROR-FREE
** OPERATION. CARNEGIE MELLON UNIVERSITY SHALL NOT BE LIABLE FOR INDIRECT,
** SPECIAL OR CONSEQUENTIAL DAMAGES, SUCH AS LOSS OF PROFITS OR INABILITY
** TO USE SAID INTELLECTUAL PROPERTY, UNDER THIS LICENSE, REGARDLESS OF
** WHETHER SUCH PARTY WAS AWARE OF THE POSSIBILITY OF SUCH DAMAGES.
** LICENSEE AGREES THAT IT WILL NOT MAKE ANY WARRANTY ON BEHALF OF
** CARNEGIE MELLON UNIVERSITY, EXPRESS OR IMPLIED, TO ANY PERSON
** CONCERNING THE APPLICATION OF OR THE RESULTS TO BE OBTAINED WITH THE
** DELIVERABLES UNDER THIS LICENSE.
**
** Licensee hereby agrees to defend, indemnify, and hold harmless Carnegie
** Mellon University, its trustees, officers, employees, and agents from
** all claims or demands made against them (and any related losses,
** expenses, or attorney's fees) arising out of, or relating to Licensee's
** and/or its sub licensees' negligent use or willful misuse of or
** negligent conduct or willful misconduct regarding the Software,
** facilities, or other rights or assistance granted by Carnegie Mellon
** University under this License, including, but not limited to, any
** claims of product liability, personal injury, death, damage to
** property, or violation of any laws or regulations.
**
** Carnegie Mellon University Software Engineering Institute authored
** documents are sponsored by the U.S. Department of Defense under
** Contract FA8721-05-C-0003. Carnegie Mellon University retains
** copyrights in all material produced under this contract. The U.S.
** Government retains a non-exclusive, royalty-free license to publish or
** reproduce these documents, or allow others to do so, for U.S.
** Government purposes only pursuant to the copyright license under the
** contract clause at 252.227.7013.
**
** @OPENSOURCE_HEADER_END@
*/

/*
**  Functions for printing CIDR blocks with indentation noting the
**  level of the block.  Also does summation for smaller CIDR blocks.
**  Used by rwbagcat and rwsetcat.
**
*/


#include <silk/silk.h>

RCSIDENT("$SiLK: skprintnets.c 372a8bc31d8a 2012-02-10 21:55:28Z mthomas $");

#include <silk/skprintnets.h>
#include <silk/iptree.h>


/* Levels tracked for the summary when a specification has no '/':
 * A (/8), B (/16), C (/24), X (/27), and H (/32, the host). */
#define  PRINTNET_DEFAULT_SUMMARY  "ABCXH"

/* Specification used when the caller supplies none: print only the
 * total (T) and turn on the summary (S), summarizing the levels in
 * PRINTNET_DEFAULT_SUMMARY. */
#define  PRINTNET_DEFAULT_INPUT    "TS" "/" PRINTNET_DEFAULT_SUMMARY

/*
 *  IP_TO_STRING(ns, ip, cidr);
 *
 *    Write the IPv4 address 'ip' into '(ns)->ip_buf' using the format
 *    selected by '(ns)->ip_format'.  When 'cidr' is 32 only the address
 *    is written; otherwise "/<cidr>" is appended.  An unrecognized
 *    format aborts via skAbortBadCase().
 *
 *    The body is wrapped in do { ... } while (0) so that the macro
 *    expands to exactly one statement and is safe to use as the
 *    unbraced body of an if/else.
 *
 *    NOTE: 'ip' and 'cidr' may be evaluated more than once; do not
 *    pass arguments that have side effects.
 */
#define IP_TO_STRING(ns, ip, cidr)                                          \
    do {                                                                    \
        switch ((ns)->ip_format) {                                          \
          case SKIPADDR_CANONICAL:                                          \
            if ((cidr) == 32) {                                             \
                num2dot_r((ip), (ns)->ip_buf);                              \
            } else {                                                        \
                snprintf((ns)->ip_buf, sizeof((ns)->ip_buf), "%s/%d",       \
                         num2dot(ip), (cidr));                              \
            }                                                               \
            break;                                                          \
                                                                            \
          case SKIPADDR_ZEROPAD:                                            \
            if ((cidr) == 32) {                                             \
                num2dot0_r((ip), (ns)->ip_buf);                             \
            } else {                                                        \
                snprintf((ns)->ip_buf, sizeof((ns)->ip_buf), "%s/%d",       \
                         num2dot0(ip), (cidr));                             \
            }                                                               \
            break;                                                          \
                                                                            \
          case SKIPADDR_DECIMAL:                                            \
            if ((cidr) == 32) {                                             \
                snprintf((ns)->ip_buf, sizeof((ns)->ip_buf), "%" PRIu32,    \
                         (ip));                                             \
            } else {                                                        \
                snprintf((ns)->ip_buf, sizeof((ns)->ip_buf),                \
                         "%" PRIu32 "/%d", (ip), (cidr));                   \
            }                                                               \
            break;                                                          \
                                                                            \
          default:                                                          \
            skAbortBadCase((ns)->ip_format);                                \
        }                                                                   \
    } while (0)


/* Suffix for pluralizing a noun: "s" for any count other than one. */
#define PLURAL(x) (((x) != 1) ? "s" : "")

/* Label printed in place of a CIDR block on the grand-total line. */
#define NET_TOTAL_TITLE "TOTAL"

/*
 *  One level of the network structure: a CIDR prefix length, its
 *  mask, and the counters accumulated for the block of that size that
 *  contains the most recently added IP.
 */
typedef struct net_struct_cidr_st {
    /* The sum of the counters ('*count' values) added to the current
     * block; only maintained when the caller supplies counts. */
    uint64_t    cb_sum;

    /* number of bits: the CIDR prefix length of this level (32 for
     * the host level, 0 for the TOTAL level) */
    int         cb_bits;

    /* mask that passes the most significant 'cb_bits' bits */
    uint32_t    cb_mask;

    /* whether to output the data for this CIDR block. */
    int         cb_print;

    /* the number of spaces by which to indent this CIDR block. */
    int         cb_indent;

    /* the number of characters to allow for the CIDR block, used as a
     * printf field width: negative values left-justify.  cb_indent
     * plus the magnitude of cb_width should be identical for all CIDR
     * blocks. */
    int         cb_width;

    /* For level i (i >= 1), cb_ips[j] with 0 <= j < i is the number of
     * distinct level-j blocks (hosts when j is 0) seen so far in the
     * current level-i block.  The array has i entries; it is NULL for
     * the host level (i == 0). */
    uint64_t   *cb_ips;
} net_struct_cidr_t;


/*
 *  State for printing a network structure.  Obtained from
 *  netStructureCreate() and released with netStructureDestroy().
 */
struct netStruct_st {
    /* output stream */
    skstream_t         *outstrm;

    /* the levels, total_level+1 entries ordered from the host level
     * (index 0, /32) up to the TOTAL level (index total_level, /0) */
    net_struct_cidr_t  *cblock;

    /* previous key: the IP most recently passed to
     * netStructurePrintIP() */
    uint32_t            prev_ip;

    /* the index of the entry in the 'cblock[]' array that holds the
     * totals.  This is one less than the total number of entries in
     * cblock[]. */
    int                 total_level;

    /* the width of the 'count' column */
    int                 count_width;

    /* how to print the IP address. Values from "enum skipaddr_flags_t" */
    int                 ip_format;

    /* a buffer for IPs as strings. Must allow extra space for CIDR. */
    char                ip_buf[2*SK_NUM2DOT_STRLEN];

    /* delimiter to print between columns */
    char                delimiter;

    /* delimiter or empty string to go between IP and Count */
    char                ip_count_delim[2];

    /* delimiter or empty string to go between Count and EOL */
    char                count_eol_delim[2];

    /* whether the blocks to print have been initialized */
    unsigned            parsed_input        :1;

    /* whether this entry is the first entry to be printed */
    unsigned            first_entry         :1;

    /* whether this entry is final entry to be printed */
    unsigned            final_entry         :1;

    /* whether to suppress fixed width columnar output */
    unsigned            no_columns          :1;

    /* whether to suppress the final delimiter */
    unsigned            no_final_delimiter  :1;

    /* whether to print the summary */
    unsigned            print_summary       :1;

    /* whether the caller will be passing a valid 'count' value. */
    unsigned            use_count           :1;

    /* whether to print the number of IPs */
    unsigned            print_ip_count      :1;
};


/* LOCAL FUNCTION PROTOTYPES */

/* reset '*ns' to its default, empty state */
static void netStructureInitialize(netStruct_t *ns, int has_count);
/* finish configuring '*ns' (blocks, stream, widths) before printing */
static void netStructurePreparePrint(netStruct_t *ns);


/* FUNCTION DEFINITIONS */

/*
 *  netStructurePrintIP(ip, count, ns);
 *
 *    Add the IPv4 address 'ip' to the network structure 'ns'.  When
 *    'ns' was created with a count, '*count' is the value associated
 *    with 'ip' and is added to the running sum of every level.
 *
 *    'ip' is compared only with the previously added address.  The
 *    largest tracked level whose prefix differs decides which levels
 *    are "closed".  Each closed level that is marked for printing has
 *    its line written using the previous address.  The per-level
 *    host/block counters and sums are then restarted for the new
 *    address.  NOTE(review): this presumes callers pass addresses in
 *    ascending order; confirm with the callers.
 *
 *    On the first call nothing is printed; the output layout is set
 *    up by netStructurePreparePrint().  When 'ns->final_entry' is set
 *    (see netStructurePrintFinalize()), every level is closed and
 *    printed, and neither 'ip' nor 'count' is used.
 */
void netStructurePrintIP(
    uint32_t        ip,
    uint64_t       *count,
    netStruct_t    *ns)
{
    uint32_t xor_ips;
    int first_change = -1;
    int i, j;

    /* determine the largest level whose block has changed */
    if (ns->first_entry) {
        /* first entry is considered a change at every level, but
         * don't print close blocks since nothing is open yet. */
        first_change = ns->total_level;
        ns->first_entry = 0;
        netStructurePreparePrint(ns);
    } else {

        if (ns->final_entry) {
            /* for the last entry, we need to close all the blocks and
             * print all summaries. */
            first_change = ns->total_level;
        } else {
            /* determine into which block the most significant bit to
             * change from previous value falls. */
            xor_ips = ip ^ ns->prev_ip;

            /* search from the largest non-TOTAL level downward; if no
             * masked bits differ, first_change ends at 0 (the host) */
            for (first_change = ns->total_level - 1;
                 first_change > 0;
                 --first_change)
            {
                if (xor_ips & ns->cblock[first_change].cb_mask) {
                    break;
                }
            }
        }

        /*
         * Close out levels 0..first_change: print one line for each
         * such level that is marked for printing, describing the
         * previous IP's block and the sums/counters accumulated for it.
         */
        for (i = 0; i <= first_change; ++i) {
            /* only print if requested */
            if ( !ns->cblock[i].cb_print) {
                continue;
            }

            /* Convert IP and CIDR block to string, or use NET_TOTAL_TITLE */
            if (ns->total_level == i) {
                /* snprintf() always NUL-terminates and, unlike
                 * strncpy(), does not zero-fill the rest of ip_buf */
                snprintf(ns->ip_buf, sizeof(ns->ip_buf), "%s",
                         NET_TOTAL_TITLE);
            } else {
                IP_TO_STRING(ns, ns->prev_ip & ns->cblock[i].cb_mask,
                             ns->cblock[i].cb_bits);
            }

            if (ns->use_count) {
                skStreamPrint(ns->outstrm, ("%*s%*s%s%*" PRIu64 "%s"),
                              ns->cblock[i].cb_indent, "",
                              ns->cblock[i].cb_width, ns->ip_buf,
                              ns->ip_count_delim,
                              ns->count_width, ns->cblock[i].cb_sum,
                              ns->count_eol_delim);
            } else {
                skStreamPrint(ns->outstrm, "%*s%*s%s",
                              ns->cblock[i].cb_indent, "",
                              ns->cblock[i].cb_width, ns->ip_buf,
                              ns->ip_count_delim);
            }

            if (0 == i) {
                /* this is the host, so do nothing else */
            } else if (ns->print_summary) {
                /* e.g. " 5 hosts in 2 /27s and 1 /24" */
                const char *join_strings[] = {" in", ",", " and", ", and"};
                const char *joiner = NULL;

                skStreamPrint(ns->outstrm, (" %" PRIu64 " host%s"),
                              ns->cblock[i].cb_ips[0],
                              PLURAL(ns->cblock[i].cb_ips[0]));

                for (j = i-1; j > 0; --j) {
                    /* determine what text to use between counts */
                    if (NULL == joiner) {
                        joiner = join_strings[0];
                    } else if (j > 1) {
                        joiner = join_strings[1];
                    } else if (join_strings[0] == joiner) {
                        joiner = join_strings[2];
                    } else {
                        joiner = join_strings[3];
                    }

                    skStreamPrint(ns->outstrm, ("%s %" PRIu64 " /%d%s"),
                                  joiner,
                                  ns->cblock[i].cb_ips[j],
                                  ns->cblock[j].cb_bits,
                                  PLURAL(ns->cblock[i].cb_ips[j]));
                }
            } else if (ns->print_ip_count) {
                skStreamPrint(ns->outstrm, (" %" PRIu64),
                              ns->cblock[i].cb_ips[0]);
            }

            skStreamPrint(ns->outstrm, "\n");
        }
    } /* if !first_entry */

    /*
     * Now that we've closed the footers, if we are at the end of
     * the data set, we can quit.
     */
    if (ns->final_entry) {
        return;
    }

    /* store this IP */
    ns->prev_ip = ip;

    /* Levels 1..first_change start a new block that holds exactly one
     * of each smaller level: the new IP. */
    for (i = 1; i <= first_change; ++i) {
        for (j = 0; j < i; ++j) {
            ns->cblock[i].cb_ips[j] = 1;
        }
    }
    /* The larger levels gain one new block at each level that changed */
    for ( ; i <= ns->total_level; ++i) {
        for (j = 0; j <= first_change; ++j) {
            ++ns->cblock[i].cb_ips[j];
        }
    }

    /* If the caller is providing counts, reset the sums of the levels
     * that changed and add to the sums of the larger levels */
    if (ns->use_count) {
        for (i = 0; i <= first_change; ++i) {
            ns->cblock[i].cb_sum = *count;
        }
        for ( ; i <= ns->total_level; ++i) {
            ns->cblock[i].cb_sum += *count;
        }
    }
}


/*
 *  status = netStructureCreate(&ns, has_count);
 *
 *    Reset the module's single network-structure object and store its
 *    address in '*ns'.  'has_count' says whether the caller will pass
 *    a count with each IP.  Always returns 0.
 *
 *    Every call hands back the same file-static object, so the code
 *    is not reentrant and only one structure may be in use at a time.
 */
int netStructureCreate(
    netStruct_t   **ns,
    int             has_count)
{
    /* Single-threaded, non-reentrant code: reuse one static object
     * instead of allocating. */
    static netStruct_t the_ns;

    netStructureInitialize(&the_ns, has_count);
    *ns = &the_ns;

    return 0;
}


/*
 *  netStructureDestroy(&ns);
 *
 *    Release the level arrays owned by '*ns_ptr' and set '*ns_ptr' to
 *    NULL.  Does nothing when 'ns_ptr' or '*ns_ptr' is NULL.  The
 *    object itself is not freed, and the output stream is not
 *    touched.
 */
void netStructureDestroy(netStruct_t **ns_ptr)
{
    netStruct_t *ns;
    int lvl;

    if (NULL == ns_ptr || NULL == *ns_ptr) {
        return;
    }

    ns = *ns_ptr;
    *ns_ptr = NULL;

    if (NULL == ns->cblock) {
        return;
    }

    /* free(NULL) is a no-op, so levels without a cb_ips array (the
     * host level) need no special test */
    for (lvl = ns->total_level; lvl >= 0; --lvl) {
        free(ns->cblock[lvl].cb_ips);
        ns->cblock[lvl].cb_ips = NULL;
    }
    free(ns->cblock);
    ns->cblock = NULL;
}


/*
 *  netStructureInitialize(ns, has_count);
 *
 *    Zero every member of '*ns', then apply the defaults:
 *    - canonical dotted-quad IPs;
 *    - '|' as the column delimiter;
 *    - a 15-character count column;
 *    - the "first entry" state.
 *    'has_count' records whether the caller will supply counts.
 */
static void netStructureInitialize(
    netStruct_t    *ns,
    int             has_count)
{
    assert(ns);

    memset(ns, 0, sizeof(*ns));

    ns->ip_format   = SKIPADDR_CANONICAL;
    ns->delimiter   = '|';
    ns->count_width = 15;
    ns->first_entry = 1;
    ns->use_count   = (0 != has_count);
}


/*
 *  netStructurePrintFinalize(ns);
 *
 *    Mark the data as complete and print the lines for every block
 *    that is still open.  When no IP was ever added, nothing is
 *    printed.
 */
void netStructurePrintFinalize(netStruct_t *ns)
{
    ns->final_entry = 1;

    /* If the first entry was never seen, no blocks are open. */
    if (!ns->first_entry) {
        netStructurePrintIP(0, NULL, ns);
    }
}


/*
 *  netStructurePrintStatistics(ns);
 *
 *    Currently prints nothing.  It reports an error when called before
 *    netStructurePrintFinalize().
 */
void netStructurePrintStatistics(netStruct_t *ns)
{
    if (ns->final_entry) {
        return;
    }
    skAppPrintErr("Must call netStructurePrintFinalize"
                  " before netStructurePrintStatistics");
}


/*
 *  netStructureSetCountWidth(ns, width);
 *
 *    Set the printf field width used for the count column.
 */
void netStructureSetCountWidth(
    netStruct_t             *ns,
    int                      width)
{
    assert(NULL != ns);

    ns->count_width = width;
}


/*
 *  netStructureSetNoColumns(ns);
 *
 *    Request output without fixed-width columns: indentation and all
 *    column widths are set to zero when printing starts.
 */
void netStructureSetNoColumns(netStruct_t *ns)
{
    assert(NULL != ns);

    ns->no_columns = 1;
}


/*
 *  netStructureSetNoFinalDelimiter(ns);
 *
 *    Request that the delimiter after the last column be omitted when
 *    no summary is printed.
 */
void netStructureSetNoFinalDelimiter(netStruct_t *ns)
{
    assert(NULL != ns);

    ns->no_final_delimiter = 1;
}


/*
 *  netStructureSetOutputStream(ns, stream);
 *
 *    Send output to 'stream'.  Without this, printing opens a text
 *    stream on stdout.
 */
void netStructureSetOutputStream(
    netStruct_t    *ns,
    skstream_t     *stream)
{
    assert(NULL != ns);
    assert(NULL != stream);

    ns->outstrm = stream;
}


/*
 *  netStructureSetIpFormat(ns, format);
 *
 *    Choose how IP addresses are rendered.  The supported values are
 *    SKIPADDR_CANONICAL, SKIPADDR_ZEROPAD, and SKIPADDR_DECIMAL.
 */
void netStructureSetIpFormat(
    netStruct_t            *ns,
    skipaddr_flags_t        format)
{
    assert(NULL != ns);

    ns->ip_format = format;
}


/*
 *  netStructureSetDelimiter(ns, delimiter);
 *
 *    Set the character printed between output columns (default '|').
 */
void netStructureSetDelimiter(
    netStruct_t    *ns,
    char            delimiter)
{
    assert(NULL != ns);

    ns->delimiter = delimiter;
}


/*
 *  status = netStructureParse(ns, input);
 *
 *    Parse the network-structure specification 'input' and build the
 *    per-level 'cblock' array of 'ns'.  When 'input' is NULL,
 *    PRINTNET_DEFAULT_INPUT is used.  Return 0 on success, 1 on error.
 *
 *    Level letters: T=/0 (total), A=/8, B=/16, C=/24, X=/27, and
 *    H=/32 (host).  A decimal number from 1 to 32 selects that prefix
 *    length directly.  'S' turns on the summary.  Commas and
 *    whitespace are ignored.
 *
 *    Levels named before an optional '/' are printed.  Levels named
 *    after it are tracked but, unless also named before the '/', not
 *    printed.  A '/' also turns on the summary.  When there is no '/',
 *    the levels in PRINTNET_DEFAULT_SUMMARY are tracked.  The host and
 *    total levels are always tracked.
 *
 *    Calling this again releases the blocks from the earlier call.
 */
int netStructureParse(
    netStruct_t    *ns,
    const char     *input)
{
    /* block[N] holds flags for CIDR prefix length N: bit 1 when named
     * before the '/', bit 2 when named after it (or implied) */
    int block[129];
    const char *cp;
    uint32_t val;
    int num_levels = 0;
    int print_levels = 0;
    int rv;
    int i, j;

    ns->parsed_input = 1;

    /* Clear printing */
    memset(block, 0, sizeof(block));

    /* If input is NULL, use the default. */
    if (NULL == input) {
        cp = PRINTNET_DEFAULT_INPUT;
    } else {
        cp = input;
    }

    /* must have a host and total level */
    block[0] = 2;
    block[32] = 2;

    /* loop twice, once to parse the values before the '/' and the
     * second time to parse those after */
    for (i = 1; i <= 2; ++i) {
        while (*cp && *cp != '/') {
            switch (*cp) {
              case ',':
                break;
              case 'S':
                ns->print_summary = 1;
                break;
              case 'T':
                block[0] |= i;
                break;
              case 'A':
                block[8] |= i;
                break;
              case 'B':
                block[16] |= i;
                break;
              case 'C':
                block[24] |= i;
                break;
              case 'X':
                block[27] |= i;
                break;
              case 'H':
                block[32] |= i;
                break;
              default:
                /* <ctype.h> functions require a value representable as
                 * unsigned char; a negative plain char is undefined */
                if (isspace((unsigned char)*cp)) {
                    break;
                }
                if (!isdigit((unsigned char)*cp)) {
                    skAppPrintErr("Invalid network-structure character '%c'",
                                  *cp);
                    return 1;
                }
                rv = skStringParseUint32(&val, cp, 1, 32);
                if (rv == 0) {
                    /* parsed to end of string; move to final char */
                    cp += strlen(cp) - 1;
                } else if (rv > 0) {
                    /* parsed a value, move to last char of the value */
                    cp += rv - 1;
                } else {
                    skAppPrintErr("Invalid network-structure '%s': %s",
                                  input, skStringParseStrerror(rv));
                    return 1;
                }
                block[val] |= i;
                break;
            }

            ++cp;
        }

        if ('/' == *cp) {
            ns->print_summary = 1;
            ++cp;
            if (2 == i) {
                /* The '/' character appears twice */
                skAppPrintErr(("Invalid network-structure '%s':"
                               " Only one '/' is allowed"),
                              input);
                return 1;
            }
        } else {
            if (1 == i) {
                /* No summary definition provided. use default */
                cp = PRINTNET_DEFAULT_SUMMARY;
            }
        }
    }


    for (i = 0; i <= 32; ++i) {
        if (block[i]) {
            ++num_levels;
            if (block[i] & 1) {
                ++print_levels;
            }
        }
    }

    /* Make certain we have something other than just 'S' */
    if (print_levels == 0) {
        skAppPrintErr("Network structure must include one of TABCXH");
        return 1;
    }

    /* Release the blocks from an earlier call so they are not leaked.
     * cblock is only left non-NULL by a successful call, so
     * total_level matches it. */
    if (ns->cblock) {
        for (i = 0; i <= ns->total_level; ++i) {
            free(ns->cblock[i].cb_ips);
        }
        free(ns->cblock);
        ns->cblock = NULL;
        ns->total_level = 0;
    }

    ns->cblock = calloc(num_levels, sizeof(net_struct_cidr_t));
    if (NULL == ns->cblock) {
        return 1;
    }
    /* level i keeps one counter for each smaller level (i entries);
     * the host level (i == 0) keeps none */
    for (i = 1; i < num_levels; ++i) {
        ns->cblock[i].cb_ips = calloc(i, sizeof(uint64_t));
        if (NULL == ns->cblock[i].cb_ips) {
            /* Undo the partial allocation.  calloc() zeroed cblock, so
             * entries not yet allocated are NULL and freeing them is a
             * no-op. */
            for (j = 1; j < num_levels; ++j) {
                free(ns->cblock[j].cb_ips);
            }
            free(ns->cblock);
            ns->cblock = NULL;
            return 1;
        }
    }

    ns->total_level = num_levels - 1;

    /* Fill levels from the host (/32, index 0) up to TOTAL (/0) */
    j = 0;
    for (i = 32; i >= 0; --i) {
        if (block[i]) {
            if (block[i] & 1) {
                ns->cblock[j].cb_print = 1;
            }
            ns->cblock[j].cb_bits = i;
            /* avoid shifting by 32, which is undefined for uint32_t */
            ns->cblock[j].cb_mask = ((32 == i)
                                     ? UINT32_MAX
                                     : ~(UINT32_MAX >> i));
            ++j;
        }
    }

    if (!ns->print_summary && !ns->use_count) {
        /* Without summary nor counts, print the number of IPs seen in
         * the block (otherwise, net structure serves little
         * purpose). */
        ns->print_ip_count = 1;
    }

    return 0;
}


/*
 *  netStructurePreparePrint(ns);
 *
 *    Called from netStructurePrintIP() when it handles the first
 *    entry, to finish setting up 'ns' for output:
 *    - parse the default structure if netStructureParse() was never
 *      called;
 *    - open a text stream on stdout if no output stream was set;
 *    - fill in the column delimiters;
 *    - compute each level's indentation (cb_indent) and column width
 *      (cb_width).
 *
 *    cb_width is used as a printf '*' width.  It is negative
 *    (left-justified) unless only a single level is printed, in which
 *    case it is positive (right-justified).
 */
static void netStructurePreparePrint(netStruct_t *ns)
{
    /* spaces added per nesting level.  NOTE(review): this macro is
     * never #undef'd, so it stays defined for the rest of the file. */
#define INDENT_LEVEL 2
    /* index of the largest printed level (first seen walking down from
     * TOTAL) and of the smallest printed level */
    int first_level = -1;
    int last_level = 256;
    int indent = 0;
    int width;
    int justify = -1; /* -1 for left-justified IPs; 1 for right */
    int i;

    assert(ns);

    /* initialize the blocks.  NOTE(review): the return value is
     * ignored; if netStructureParse() fails, ns->cblock may be NULL
     * or unusable and the loops below would dereference it. */
    if (!ns->parsed_input) {
        netStructureParse(ns, NULL);
    }

    /* open output stream */
    if (ns->outstrm == NULL) {
        int rv;
        if ((rv = skStreamCreate(&ns->outstrm, SK_IO_WRITE, SK_CONTENT_TEXT))
            || (rv = skStreamBind(ns->outstrm, "stdout"))
            || (rv = skStreamOpen(ns->outstrm)))
        {
            /* NOTE(review): skStreamDestroy() presumably leaves
             * outstrm NULL, and the caller keeps printing to it;
             * confirm this is safe. */
            skStreamPrintLastErr(ns->outstrm, rv, &skAppPrintErr);
            skStreamDestroy(&ns->outstrm);
            return;
        }
    }

    /* the delimiter between the IP and count, or between IP and
     * summary when handling sets */
    ns->ip_count_delim[0] = ns->delimiter;
    ns->ip_count_delim[1] = '\0';

    /* the delimiter between the count and summary or count and
     * end-of-line when no summary */
    ns->count_eol_delim[0] = ns->delimiter;
    ns->count_eol_delim[1] = '\0';

    /* Compute indentation for each level, walking from TOTAL down to
     * the host.  Levels above the first printed level get no indent.
     * After it, each level adds INDENT_LEVEL, except that while TOTAL
     * is the only printed level seen, unprinted levels add nothing. */
    for (i = ns->total_level; i >= 0; --i) {
        ns->cblock[i].cb_indent = indent;
        if (ns->cblock[i].cb_print) {
            last_level = i;
            if (first_level == -1) {
                first_level = i;
                indent += INDENT_LEVEL;
                continue;
            }
        }
        if (last_level < ns->total_level) {
            /* Once we have one thing indented, indent the remaining
             * levels by the offset, even if they are not printed. */
            indent += INDENT_LEVEL;
        }
    }

    /* NOTE(review): netStructureParse() sets print_ip_count whenever
     * both print_summary and use_count are false, so this condition
     * appears never to be true; confirm before relying on it. */
    if ((first_level == last_level) && !ns->use_count
        && !ns->print_ip_count && !ns->print_summary)
    {
        /* If there is no 'count' column and we are not printing
         * the summary---i.e., print IPs only---do no
         * formatting and disable the ip_count_delim. */
        ns->cblock[0].cb_width = 0;
        ns->ip_count_delim[0] = '\0';
        return;
    }

    /* if no summary is requested, modify the delimiter between the
     * count and the end-of-line if no_final_delimiter is set */
    if (ns->no_final_delimiter && !ns->print_summary) {
        ns->count_eol_delim[0] = '\0';
    }

    if (ns->no_columns) {
        /* If fixed-width output is not requested, set all widths and
         * indents to 0 and return. */
        for (i = 0; i <= ns->total_level; ++i) {
            ns->cblock[i].cb_indent = 0;
            ns->cblock[i].cb_width = 0;
        }
        ns->count_width = 0;
        return;
    }

    if (ns->total_level == last_level) {
        /* We are printing the total only.  Set that width and
         * return. */
        ns->cblock[ns->total_level].cb_width = strlen(NET_TOTAL_TITLE);
        return;
    }

    /* Width will be at least the size of the indentation, but don't
     * include trailing levels that aren't printed.  The loop above
     * added INDENT_LEVEL at last_level and at each of the last_level
     * levels beneath it, so this leaves the indent of last_level. */
    width = indent - (INDENT_LEVEL * (1 + last_level));

    /* Allow space for the IP address. */
    if (SKIPADDR_DECIMAL == ns->ip_format) {
        width += 10;
    } else {
        width += 15;
    }

    /* Allow space for the CIDR block */
    if (last_level == 0) {
        /* Since the host IP does not include the CIDR block, it may
         * be narrower than the next larger block.  Account for
         * that. */
        if (ns->cblock[1].cb_print && (INDENT_LEVEL < 3)) {
            width += (3 - INDENT_LEVEL);
        }
    } else {
        if (ns->cblock[last_level].cb_bits < 10) {
            /* Allow for something like "/8" */
            width += 2;
        } else if (ns->cblock[last_level].cb_bits < 100) {
            /* Allow for something like "/24" */
            width += 3;
        } else {
            /* Allow for something like "/120" */
            width += 4;
        }
    }

    /* When doing only one level, right justify the keys */
    if (first_level == last_level) {
        justify = 1;
    }

    /* Set the widths for every level so that cb_indent plus the
     * magnitude of cb_width is the same for all levels */
    for (i = 0; i <= ns->total_level; ++i) {
        ns->cblock[i].cb_width = justify * (width - ns->cblock[i].cb_indent);
    }
}


/*
** Local Variables:
** mode:c
** indent-tabs-mode:nil
** c-basic-offset:4
** End:
*/
