#!/usr/bin/python
#
# Copyright (c) 2011 Citrix Systems, Inc.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation; version 2.1 only. with the special
# exception on linking described in file LICENSE.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#

import fcntl
import os
import os.path
import string
import subprocess
import sys
import syslog
import time

# Name this script was invoked as; used as the prefix of syslog messages.
command_name = sys.argv[0]
# Base name of the script, shown in the usage text.
short_command_name = os.path.basename(command_name)

# ebtables binary, and the lock file used to serialise calls to it.
ebtables = "/sbin/ebtables"
ebtables_lock_path = "/var/lock/ebtables.lock"

# Open vSwitch command line tools.
vsctl = "/usr/bin/ovs-vsctl"
ofctl = "/usr/bin/ovs-ofctl"

# Tools used to change link state and to read xenstore keys.
ip = "/sbin/ip"
xenstore_read_cmd = "/usr/bin/xenstore-read"

# File whose first line names the host network mode
# ("bridge" or "openvswitch" are the values handled by main()).
host_network_config = "/etc/xensource/network.conf"

def get_host_network_mode():
    """Return the first line of the host network config file, stripped."""
    conf = open(host_network_config, "r")
    try:
        first_line = conf.readline()
        return first_line.strip()
    finally:
        # Always release the file handle, even if reading fails.
        conf.close()

def send_to_syslog(msg):
    """Log msg to syslog, prefixed with this command's name and pid."""
    entry = "%s[%d] - %s" % (command_name, os.getpid(), msg)
    syslog.syslog(entry)

def doexec(args):
    """Execute a subprocess, then return its return code, stdout and stderr.

    The command line is logged to syslog before it is run.

    Returns a tuple (rc, stdout, stderr) where rc is the exit status and
    stdout/stderr are readable file-like objects holding everything the
    child wrote, so callers can keep using readline()/read() on them.
    """
    # Local import keeps this function self-contained and works on both
    # Python 2 (cStringIO) and Python 3 (io.BytesIO).
    try:
        from cStringIO import StringIO as _Buffer
    except ImportError:
        from io import BytesIO as _Buffer
    send_to_syslog(args)
    proc = subprocess.Popen(args, stdin=None, stdout=subprocess.PIPE,
                            stderr=subprocess.PIPE, close_fds=True)
    # communicate() drains both pipes while waiting. Calling wait() with
    # both streams piped can deadlock once the child fills a pipe buffer,
    # and leaves the pipe file descriptors open.
    out, err = proc.communicate()
    return (proc.returncode, _Buffer(out), _Buffer(err))

def xenstore_read(path):
    """Return the first line printed by xenstore-read for path, stripped.

    The exit status of the command is not checked.
    """
    result = doexec([xenstore_read_cmd, path])
    output = result[1]
    return output.readline().strip()

def get_words(value, separator):
    """Split value on separator and return the resulting list of words.

    A value that is empty or consists only of whitespace yields an empty
    list. The individual words are returned as-is (they are not stripped).
    """
    # Use str methods rather than the deprecated string-module functions,
    # which no longer exist in Python 3.
    if value.strip() == "":
        return []
    return value.split(separator)

def get_locking_config(domid, devid):
    """Read the locking configuration of a VIF from xenstore.

    Returns a dict with the keys "mac", "locking_mode", "ipv4_allowed" and
    "ipv6_allowed"; the two address entries are lists obtained by splitting
    the xenstore values on commas.
    """
    mac_path = "/local/domain/0/backend/vif/%s/%s/mac" % (domid, devid)
    mac = xenstore_read(mac_path)
    private_path = "/xapi/%s/private/vif/%s" % (domid, devid)
    # Read the private keys in a fixed order: mode, then IPv4, then IPv6.
    private_values = [
        xenstore_read("/".join([private_path, key]))
        for key in ("locking-mode", "ipv4-allowed", "ipv6-allowed")
    ]
    locking_mode, ipv4_allowed, ipv6_allowed = private_values
    send_to_syslog("Got locking config: MAC=%s; locking_mode=%s; ipv4_allowed=%s; ipv6_allowed=%s"
                    % (mac, locking_mode, ipv4_allowed, ipv6_allowed))
    config = {}
    config["mac"] = mac
    config["locking_mode"] = locking_mode
    config["ipv4_allowed"] = get_words(ipv4_allowed, ",")
    config["ipv6_allowed"] = get_words(ipv6_allowed, ",")
    return config

def ip_link_set(device, direction):
    """Run "ip link set <device> <direction>" (callers pass "up" or "down")."""
    command = [ip, "link", "set", device, direction]
    doexec(command)

###############################################################################
# Creation of ebtables rules in the case of bridge.
###############################################################################

def get_chain_name(vif_name):
    """Return the name of the per-VIF ebtables chain for vif_name."""
    chain_template = "FORWARD_%s"
    return chain_template % vif_name

def do_chain_action(executable, action, chain_name, args=None):
    """Run "<executable> <action> <chain_name> [args...]" via doexec.

    args is an optional sequence of extra command line arguments; omitting
    it (or passing None) runs the command with no extra arguments.

    Returns whatever doexec returns: (rc, stdout, stderr).
    """
    # Avoid a shared mutable default; None stands for "no extra arguments".
    if args is None:
        args = []
    return doexec([executable, action, chain_name] + list(args))

def chain_exists(executable, chain_name):
    """Return True if listing chain_name with executable exits with status 0."""
    listing = do_chain_action(executable, "-L", chain_name)
    exit_status = listing[0]
    return exit_status == 0

def clear_bridge_rules(vif_name):
    """Remove the per-VIF ebtables chain for vif_name, if it exists."""
    vif_chain = get_chain_name(vif_name)
    if not chain_exists(ebtables, vif_chain):
        return
    # Stop forwarding to this VIF's chain (input side first, then output).
    for direction_flag in ("-i", "-o"):
        do_chain_action(ebtables, "-D", "FORWARD",
                        [direction_flag, vif_name, "-j", vif_chain])
    # Flush the VIF's chain, then delete it.
    for chain_op in ("-F", "-X"):
        do_chain_action(ebtables, chain_op, vif_chain)

def create_bridge_rules(vif_name, config):
    """Create the ebtables chain and filtering rules for a locked VIF.

    config is a dict with the keys "mac", "ipv4_allowed" and "ipv6_allowed".
    Only IPv4 is filtered per address; all IPv6 traffic is accepted as soon
    as any IPv6 address is associated with the VIF.
    """
    vif_chain = get_chain_name(vif_name)

    def append_rule(rule):
        # Append one rule to this VIF's own chain.
        do_chain_action(ebtables, "-A", vif_chain, rule)

    # Forward all traffic on this VIF to a new chain, with default policy DROP.
    do_chain_action(ebtables, "-N", vif_chain)
    for direction_flag in ("-i", "-o"):
        do_chain_action(ebtables, "-A", "FORWARD",
                        [direction_flag, vif_name, "-j", vif_chain])
    do_chain_action(ebtables, "-P", vif_chain, ["DROP"])

    # The remaining rules allow valid traffic.
    mac = config["mac"]
    ipv4_allowed = config["ipv4_allowed"]
    ipv6_allowed = config["ipv6_allowed"]

    # Accept all traffic going to the VM.
    append_rule(["-o", vif_name, "-j", "ACCEPT"])
    # Drop everything not coming from the correct MAC.
    append_rule(["-s", "!", mac, "-i", vif_name, "-j", "DROP"])
    # Accept DHCP.
    append_rule(["-p", "IPv4", "-i", vif_name,
                 "--ip-protocol", "UDP", "--ip-dport", "67", "-j", "ACCEPT"])
    for ipv4 in ipv4_allowed:
        # Accept ARP requests from known IP addresses.
        append_rule(["-p", "ARP", "-i", vif_name, "--arp-opcode", "Request",
                     "--arp-ip-src", ipv4, "-j", "ACCEPT"])
        # Accept ARP replies from known IP addresses, also checking the MAC.
        append_rule(["-p", "ARP", "-i", vif_name, "--arp-opcode", "Reply",
                     "--arp-ip-src", ipv4, "--arp-mac-src", mac, "-j", "ACCEPT"])
        # Accept IP travelling from known IP addresses.
        append_rule(["-p", "IPv4", "-i", vif_name, "--ip-src", ipv4, "-j", "ACCEPT"])
    if ipv6_allowed != []:
        # Accept all IPv6 traffic.
        append_rule(["-p", "IPv6", "-i", vif_name, "-j", "ACCEPT"])

def acquire_lock(path):
    """Block until an exclusive lock on the file at path is obtained.

    The file is opened for writing and a non-blocking exclusive lock is
    attempted once a second until it succeeds. Returns the open file
    object on which the lock was taken.
    """
    lock_file = open(path, 'w')
    lock_flags = fcntl.LOCK_EX | fcntl.LOCK_NB
    while True:
        send_to_syslog("attempting to acquire lock %s" % path)
        try:
            fcntl.lockf(lock_file, lock_flags)
            send_to_syslog("acquired lock %s" % path)
            return lock_file
        except IOError:
            # Fetch the exception this way so the code parses on both
            # Python 2 and Python 3.
            e = sys.exc_info()[1]
            send_to_syslog("failed to acquire lock %s because '%s' - waiting 1 second" % (path, e))
            time.sleep(1)

def handle_bridge(vif_type, domid, devid, action):
    """Apply the "clear" or "filter" action to a VIF in bridge mode.

    Any other action is ignored. The VIF is taken down while its ebtables
    rules are rewritten, and is only brought back up when the action is
    "clear" or the locking mode is "locked" or "unlocked".
    """
    if (action != "clear") and (action != "filter"):
        return
    # ebtables fails if called concurrently, so acquire a lock before starting to call it.
    # The returned file object is bound to a local for the rest of this function.
    ebtables_lock_file = acquire_lock(ebtables_lock_path)
    vif_name = "%s%s.%s" % (vif_type, domid, devid)
    ip_link_set(vif_name, "down")
    clear_bridge_rules(vif_name)
    if action == "clear":
        ip_link_set(vif_name, "up")
        return
    # From here on the action is "filter".
    config = get_locking_config(domid, devid)
    locking_mode = config["locking_mode"]
    if locking_mode == "locked":
        create_bridge_rules(vif_name, config)
    if locking_mode in ["locked", "unlocked"]:
        ip_link_set(vif_name, "up")

###############################################################################
# Creation of openflow rules in the case of openvswitch.
###############################################################################

def get_vswitch_port(vif_name, wait_for_valid):
    """Return the ofport of vif_name as printed by ovs-vsctl (a string).

    When wait_for_valid is true, ovs-vsctl is first asked to wait until
    the interface's ofport is no longer -1.
    """
    cmd = [vsctl]
    if wait_for_valid:
        cmd = cmd + ["wait-until", "interface", vif_name, "ofport!=-1", "--"]
    cmd = cmd + ["get", "interface", vif_name, "ofport"]
    output = doexec(cmd)[1]
    return output.readline().strip()

def clear_vswitch_rules(bridge_name, port):
    """Delete every flow on bridge_name that matches in_port=<port>."""
    flow_match = "in_port=%s" % port
    doexec([ofctl, "del-flows", bridge_name, flow_match])

def add_flow(bridge_name, args):
    """Add one flow, described by the string args, to bridge_name."""
    command = [ofctl, "add-flow", bridge_name, args]
    doexec(command)

def create_vswitch_rules(bridge_name, port, config):
    """Install the OpenFlow rules that lock traffic entering on port.

    config is a dict with the keys "mac", "ipv4_allowed" and "ipv6_allowed".
    Traffic from the allowed MAC/IP combinations is handled normally;
    everything else arriving on the port is dropped by the final rule.
    """
    mac = config["mac"]
    ipv4_allowed = config["ipv4_allowed"]
    ipv6_allowed = config["ipv6_allowed"]

    def install(flow):
        # Add a single flow to this bridge.
        add_flow(bridge_name, flow)

    # Allow DHCP traffic (outgoing UDP on port 67).
    install("in_port=%s,priority=8000,dl_type=0x0800,nw_proto=0x11,"
            "tp_dst=67,dl_src=%s,idle_timeout=0,action=normal" % (port, mac))
    # Filter ARP requests.
    install("in_port=%s,priority=7000,dl_type=0x0806,dl_src=%s,arp_sha=%s,"
            "nw_src=0.0.0.0,idle_timeout=0,action=normal" % (port, mac, mac))

    for ipv4 in ipv4_allowed:
        # Filter ARP responses.
        install("in_port=%s,priority=7000,dl_type=0x0806,dl_src=%s,arp_sha=%s,"
                "nw_src=%s,idle_timeout=0,action=normal" % (port, mac, mac, ipv4))
        # Allow traffic from specified ipv4 addresses.
        install("in_port=%s,priority=6000,dl_type=0x0800,nw_src=%s,"
                "dl_src=%s,idle_timeout=0,action=normal" % (port, ipv4, mac))

    for ipv6 in ipv6_allowed:
        # Neighbour solicitation.
        install("in_port=%s,priority=8000,dl_src=%s,icmp6,ipv6_src=%s,"
                "icmp_type=135,nd_sll=%s,idle_timeout=0,action=normal" % (port, mac, ipv6, mac))
        # Neighbour advertisement.
        install("in_port=%s,priority=8000,dl_src=%s,icmp6,ipv6_src=%s,"
                "icmp_type=136,nd_target=%s,idle_timeout=0,action=normal" % (port, mac, ipv6, ipv6))
        # Allow ICMPv6, TCP and UDP traffic from specified ipv6 addresses.
        for protocol in ("icmp6", "tcp6", "udp6"):
            install("in_port=%s,priority=5000,dl_src=%s,ipv6_src=%s,%s,action=normal"
                    % (port, mac, ipv6, protocol))

    # Drop all other neighbour discovery (solicitation 135, advertisement 136).
    for icmp_type in (135, 136):
        install("in_port=%s,priority=7000,icmp6,icmp_type=%d,action=drop" % (port, icmp_type))

    # Drop other specific ICMPv6 types:
    #   134 router advertisement, 137 redirect gateway,
    #   146 mobile prefix solicitation, 147 mobile prefix advertisement,
    #   151 multicast router advertisement, 152 multicast router solicitation,
    #   153 multicast router termination.
    for icmp_type in (134, 137, 146, 147, 151, 152, 153):
        install("in_port=%s,priority=6000,icmp6,icmp_type=%d,action=drop" % (port, icmp_type))

    # Drop everything else.
    install("in_port=%s,priority=4000,idle_timeout=0,action=drop" % port)

def get_bridge_name_vswitch(vif_name):
    """Return the name of the real bridge that vif_name belongs to."""
    # Ask which bridge the interface is attached to.
    iface_result = doexec([vsctl, "iface-to-br", vif_name])
    temp_bridge_name = iface_result[1].readline().strip()
    # Get the bridge's parent, in case we were given a fake bridge device.
    # This returns the same name if it is already a real bridge.
    parent_result = doexec([vsctl, "br-to-parent", temp_bridge_name])
    return parent_result[1].readline().strip()

def handle_vswitch(vif_type, domid, devid, action):
    """Apply the "clear" or "filter" action to a VIF in openvswitch mode.

    Any other action is ignored. The VIF is taken down while its flows are
    rewritten, and is only brought back up when the action is "clear" or
    the locking mode is "locked" or "unlocked".
    """
    if (action == "clear") or (action == "filter"):
        vif_name = "%s%s.%s" % (vif_type, domid, devid)
        bridge_name = get_bridge_name_vswitch(vif_name)
        ip_link_set(vif_name, "down")
        # Only wait for a valid port number when we are about to filter.
        port = get_vswitch_port(vif_name, action == "filter")
        # get_vswitch_port returns the port as a string, so it must be
        # compared with strings: comparing with the integer -1 is always
        # unequal and would try to delete flows for a port that does not
        # exist. An empty result (no output from ovs-vsctl) is skipped too.
        if port and port != "-1":
            clear_vswitch_rules(bridge_name, port)
        if action == "filter":
            config = get_locking_config(domid, devid)
            locking_mode = config["locking_mode"]
            if locking_mode == "locked":
                create_vswitch_rules(bridge_name, port, config)
            if locking_mode in ["locked", "unlocked"]:
                ip_link_set(vif_name, "up")
        if action == "clear":
            ip_link_set(vif_name, "up")

###############################################################################
# Executable entry point.
###############################################################################

def main(vif_type, domid, devid, network_mode, action):
    """Dispatch to the handler for the host's network mode.

    Modes other than "bridge" and "openvswitch" are silently ignored.
    """
    if network_mode == "openvswitch":
        handle_vswitch(vif_type, domid, devid, action)
        return
    if network_mode == "bridge":
        handle_bridge(vif_type, domid, devid, action)

def usage():
    """Write the command line usage text to standard output."""
    usage_lines = [
        "Usage:",
        "%s vif_type domid devid action" % short_command_name,
        "",
        "vif_type: [vif|tap]",
        "    The type of the vif for which rules will be created.",
        "action: [clear|filter]",
        "    Specifies whether to create filtering rules for the vif, or to clear all rules associated with the vif.",
    ]
    for line in usage_lines:
        sys.stdout.write(line + "\n")

if __name__ == "__main__":
    # Expect exactly four arguments: vif_type, domid, devid and action.
    if len(sys.argv) == 5:
        network_mode = get_host_network_mode ()
        vif_type, domid, devid, action = sys.argv[1:5]
        send_to_syslog("Called with vif_type=%s, domid=%s, devid=%s, network_mode=%s, action=%s" %
            (vif_type, domid, devid, network_mode, action))
        main(vif_type, domid, devid, network_mode, action)
    else:
        usage()
        sys.exit(1)
