#!/usr/libexec/platform-python
import os
import subprocess
import socket
import re
import math
import rpm
import sys

# Named ANSI SGR escape sequences; colored() wraps text in one of these
# followed by 'reset'.
COLORS = dict(
    bold="\033[01m",
    black="\033[30m",
    red="\033[31m",
    green="\033[32m",
    yellow="\033[33m",
    blue="\033[34m",
    purple="\033[35m",
    cyan="\033[36m",
    white="\033[37m",
    reset="\033[0m",
    system="\033[38;5;120m",
)


# colorize
def colored(col, s):
    """Return *s* wrapped in the COLORS escape named *col* and a reset code.

    Raises KeyError when *col* is not a COLORS key.
    """
    start, end = COLORS[col], COLORS['reset']
    return ''.join((start, s, end))


def humanise(num):
    """Format a byte count as a short human-readable size.

    The value is divided by 1024 until it is below 1024 (or the largest
    unit is reached). It is then printed with one decimal place and the
    unit appended without a space.

    :param num: size in bytes (int or float).
    :returns: e.g. ``'512.0bytes'``, ``'1.5KB'``, ``'3.0GB'``; sizes of
        1024 TB and above are reported in ``PB``.
    """
    for unit in ['bytes', 'KB', 'MB', 'GB', 'TB']:
        if num < 1024.0:
            return "%3.1f%s" % (num, unit)
        num /= 1024.0
    # Previously the loop fell through here and the function returned None
    # (rendered as "None" in the MOTD); report anything larger in PB.
    return "%3.1f%s" % (num, 'PB')


def smartlen(line):
    """Return the on-screen width of *line*.

    Tabs count as four columns. ANSI SGR sequences (``ESC [ digits/; m``)
    count as zero. This covers every code in COLORS and also sequences
    such as ``'\\033[0;31m'`` that are not in COLORS and were previously
    counted as visible characters.
    """
    line = line.replace("\t", ' ' * 4)
    return len(re.sub(r'\x1b\[[0-9;]*m', '', line))


def column_display(rows, num_columns=2):
    """Lay out ``[key, value]`` rows as aligned ``key: value`` text columns.

    The rows are split, in order, into at most *num_columns* columns of
    ``ceil(len(rows) / num_columns)`` rows each. Within a column the keys
    are padded so the values line up. Each column except the last one on a
    line is padded with spaces to its widest entry, followed by a
    four-space gap, so every column starts at the same offset on each line.

    Widths are measured with ``len()``. Values that contain ANSI colour
    codes (e.g. from colored()) therefore count their escape bytes when a
    column is padded.

    :param rows: sequence of ``[key, value]`` pairs; extra items in a row
        are ignored.
    :param num_columns: maximum number of side-by-side columns; values
        below 1 are treated as 1.
    :returns: one ``"\\n"``-terminated line per row of the tallest column,
        or ``""`` when *rows* is empty.
    """
    rows = list(rows)
    if not rows:
        return ""
    num_columns = max(1, int(num_columns))
    # ceil, not floor: floor produced more than num_columns columns (e.g. 5
    # rows / 2 -> 3 columns) and one column per row when rows < num_columns.
    per_column = int(math.ceil(float(len(rows)) / num_columns))
    columns = [rows[start:start + per_column]
               for start in range(0, len(rows), per_column)]

    col_texts = []
    for column in columns:
        key_width = max(len(row[0]) for row in column)
        col_texts.append(["%s: %s%s" % (row[0], ' ' * (key_width - len(row[0])), row[1])
                          for row in column])

    col_widths = [max(len(text) for text in texts) for texts in col_texts]
    gap = ' ' * 4
    lines = []
    # Column heights never increase left to right. The cells on a line
    # therefore always come from a prefix of the columns, and the first
    # column is the tallest.
    for i in range(len(col_texts[0])):
        cells = [texts[i] for texts in col_texts if i < len(texts)]
        # Pad with spaces rather than tabs: tab stops made columns drift
        # whenever neighbouring values differed in length.
        padded = [cell.ljust(col_widths[j]) for j, cell in enumerate(cells[:-1])]
        lines.append(gap.join(padded + cells[-1:]) + "\n")
    return "".join(lines)


def center_by(width, uncentered):
    """Indent every line of *uncentered* so the block is centred in *width*.

    Every line gets the same left padding, ``int((width - widest) / 2)``
    spaces, where *widest* is the largest smartlen() among the lines. The
    block is therefore centred as a whole, not line by line. Each piece
    produced by splitting on ``"\\n"`` is emitted followed by ``"\\n"``.
    This includes the empty piece after a trailing newline.
    """
    pieces = uncentered.split("\n")
    widest = max(smartlen(piece) for piece in pieces)
    indent = ' ' * int((width - widest) / 2)
    return "".join(indent + piece + "\n" for piece in pieces)


def run_cmd(cmd):
    """Run the shell command string *cmd* and return its stdout as text.

    The output is decoded as UTF-8 and stripped of surrounding whitespace.
    stderr is captured, so it is not shown. A non-zero exit status raises
    subprocess.CalledProcessError.
    """
    raw = subprocess.check_output(cmd, stderr=subprocess.PIPE, shell=True)
    text = raw.decode('utf-8')
    return text.strip()


# return ip from external interface
def public_ip(probe_addr=("192.168.88.22", 80)):
    """Return the local IPv4 address used to reach *probe_addr*.

    connect() on a UDP socket sends no packets. It only makes the kernel
    choose a route and source address, which getsockname() reports. The
    default probe is a private-range address, so the result is the address
    of the interface routing towards it, not necessarily a public one.

    :param probe_addr: ``(host, port)`` to route towards.
    :returns: the selected local address as a dotted-quad string.
    :raises OSError: if there is no route to *probe_addr*.
    """
    # The with-block closes the socket, which was previously leaked.
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
        sock.connect(probe_addr)
        return sock.getsockname()[0]


def license_info():
    """Return the VZSRV license summary printed by vzlicview, or None.

    Runs ``/usr/sbin/vzlicview --class VZSRV`` and parses its ``key=value``
    lines; lines without ``=`` are skipped. Keys and values are
    whitespace-stripped, and surrounding double quotes are removed from
    values.

    :returns: dict of strings with ``'cert_expiration'``, ``'cert_status'``
        and ``'key'`` (from the ``expiration``, ``status`` and
        ``key_number`` fields). Returns ``None`` if the command fails or
        any of those fields is missing.
    """
    fields = {}
    try:
        out = run_cmd("/usr/sbin/vzlicview --class VZSRV")
        for line in out.splitlines():
            # partition() splits on the first '=' only. split('=') raised
            # ValueError for values containing '=', hiding the license.
            key, sep, value = line.strip().partition("=")
            if not sep:
                continue
            fields[key.strip()] = value.strip().strip('"')
        return {'cert_expiration': str(fields['expiration']),
                'cert_status': str(fields['status']),
                'key': str(fields['key_number'])}
    except Exception:
        # Best effort: a missing tool or unexpected output just yields None,
        # and motd_info() then omits the license rows.
        return None


def os_release(path="/etc/os-release"):
    """Return ``"NAME VERSION"`` read from an os-release file.

    Parsing rules:

    * blank lines, ``#`` comments and lines without ``=`` are skipped;
    * only the first ``=`` separates key from value;
    * surrounding single/double quote characters are stripped from values.

    Missing fields follow os-release(5): VERSION is optional, and NAME
    defaults to ``"Linux"``.

    :param path: file to read (defaults to the system ``/etc/os-release``).
    :returns: e.g. ``"CentOS Linux 8 (Core)"``, or just the name when the
        file has no VERSION.
    :raises IOError/OSError: if *path* cannot be read.
    """
    fields = {}
    with open(path) as f:
        for line in f:
            line = line.strip()
            if not line or line.startswith('#') or '=' not in line:
                continue
            key, value = line.split('=', 1)
            fields[key] = value.strip('"\'')
    name = fields.get('NAME', 'Linux')
    version = fields.get('VERSION', '')
    return ' '.join(part for part in (name, version) if part)


def mount_info(path='/proc/mounts', mountpoint='/vz'):
    """Return the mount options of *mountpoint* if they lack ``lazytime``.

    Reads a /proc/mounts-style table (device, mount point, fstype,
    options, ...) and considers only entries whose device starts with
    ``/``. If the mount point appears more than once, the last entry wins.

    :param path: mount table to read.
    :param mountpoint: mount point to inspect.
    :returns: the comma-separated options string when *mountpoint* is
        mounted without the ``lazytime`` option. Otherwise ``None``, which
        includes the cases where it is not mounted or *path* cannot be read.
    """
    options = None
    try:
        # open(), not the Python 2 file() builtin: under Python 3 file()
        # raised NameError, which was swallowed, so this check never fired.
        with open(path) as mounts:
            for line in mounts:
                fields = line.split()
                if len(fields) >= 4 and fields[0].startswith('/') and fields[1] == mountpoint:
                    options = fields[3]
    except (IOError, OSError):
        return None
    # Compare whole option names so e.g. "nolazytime" is not taken for "lazytime".
    if options is None or 'lazytime' in options.split(','):
        return None
    return options



def vzkernel():
    """List installed ``kernel`` packages as ``name-version-release`` labels.

    Queries the rpm database via rpm.TransactionSet().dbMatch('name',
    'kernel'). The order is the order in which the match iterator yields
    headers.
    """
    transaction = rpm.TransactionSet()
    return ["%s-%s-%s" % (hdr['name'], hdr['version'], hdr['release'])
            for hdr in transaction.dbMatch('name', 'kernel')]

def vercmp(p1, p2):
    """Compare two ``(epoch, version, release)`` tuples in rpm order.

    Thin wrapper over rpm.labelCompare(). The result is positive when
    *p1* is newer, 0 when equal, and negative when *p2* is newer.
    """
    result = rpm.labelCompare(p1, p2)
    return result


def vz_ver(booted_kernel, installed_kernel):
    """Compare two ``name-version-release`` kernel labels in rpm order.

    :param booted_kernel: label of the running kernel, e.g.
        ``'kernel-' + uname -r``. NOTE(review): any trailing ``.arch`` from
        uname stays part of the release string.
    :param installed_kernel: label of an installed package, as built by
        vzkernel().
    :returns: 11 if the booted kernel is newer, 0 if they compare equal,
        12 if the installed kernel is newer than the booted one.
    """
    def evr(label):
        # Take version and release from the right so '-' inside the name is
        # kept. Splitting from the left used the name as "version" and the
        # version as "release", so releases were never compared.
        parts = label.rsplit('-', 2)
        if len(parts) == 3:
            return (None, parts[1], parts[2])
        return (None, parts[-1], '')

    rc = vercmp(evr(booted_kernel), evr(installed_kernel))
    if rc > 0:
        # booted kernel is NEWER than the installed one
        return 11
    if rc == 0:
        return 0
    # installed kernel is NEWER than the booted one
    return 12

def issue_info():
    """Rewrite /etc/issue with a red greeting, the kernel and the host IPs.

    The addresses come from ``hostname --all-ip-addresses`` via run_cmd().
    If that command fails, CalledProcessError propagates. The ``\\r``,
    ``\\t`` and ``\\d`` sequences are written literally; agetty(8) expands
    such escapes in /etc/issue (kernel release, time, date). A failure to
    write the file is reported on stderr and otherwise ignored.
    """
    def red(text):
        return '\033[0;31m' + text + '\033[0m'

    dear = red('Dear OpenVZ user!')
    ip = red('IP: {}'.format(run_cmd('hostname --all-ip-addresses')))
    text = """{}\nkernel: \\r\n\nUse the following hostname and IP address to connect to this server:\n\n{}\n\n\\t \\d\n\n""".format(dear, ip)
    try:
        with open('/etc/issue', 'w') as issue_file:
            issue_file.write(text)
    except (IOError, OSError) as err:
        # Best effort: an unwritable /etc/issue must not abort the MOTD update.
        sys.stderr.write("cannot update /etc/issue: {}\n".format(err))


def check_swap_pages(vzlist_cmd=('vzlist', '-a', '-o', 'swappages.l')):
    """Return True if no container has a swap-pages limit above zero.

    Runs *vzlist_cmd* (by default ``vzlist -a -o swappages.l``), skips the
    first output token (the column header) and checks the rest.

    :param vzlist_cmd: argv of a command printing a header followed by one
        integer per container.
    :returns: False if any value is greater than 0, otherwise True.
    :raises subprocess.CalledProcessError/OSError: if the command fails or
        cannot be started.
    :raises ValueError: if a value is not an integer.
    """
    raw = subprocess.check_output(list(vzlist_cmd))
    # Decode before splitting. The old ``.split().decode()`` called decode()
    # on a list, so this function always raised AttributeError.
    limits = [int(value) for value in raw.decode('utf-8').split()[1:]]
    return not any(limit > 0 for limit in limits)


def warnings_info():
    """Append red warning paragraphs about host configuration to /etc/motd.

    Each check appends its own paragraph as soon as it fires:

    * /vz is mounted without ``lazytime`` (mount_info() is not None);
    * an installed ``kernel`` package is newer than the running kernel
      (vz_ver() returns 12); only the first such package is reported;
    * /proc/swaps lists no ``partition`` or ``file`` swap while some
      container has a non-zero swappages limit (check_swap_pages() is
      False).
    """
    def red(text):
        return '\033[0;31m' + text + '\033[0m'

    def append_motd(text):
        # Written immediately so earlier warnings survive a later failing check.
        with open('/etc/motd', 'a') as motd_file:
            motd_file.write(text)

    boot_ver = 'kernel-' + subprocess.check_output(['uname', '-r']).strip().decode('utf-8')
    if mount_info() is not None:
        append_motd(red('WARNING:\n/vz is mounted with non-optimal mount options\nand may lead to performance degradation.\nPlease remount with: defaults,noatime,lazytime for /vz\n\n'))
    # check kernel
    for kernel in vzkernel():
        # '==' not 'is': int identity is an implementation detail (and
        # 'is' with a literal is a SyntaxWarning since Python 3.8).
        if vz_ver(boot_ver, kernel) == 12:
            append_motd(red('WARNING:\nrunning kernel {} older than installed {}\n'.format(boot_ver, kernel)))
            break
    # check that swap enabled
    with open('/proc/swaps') as swaps:
        contents = swaps.read()
    if not any(kind in contents for kind in ('partition', 'file')) and not check_swap_pages():
        append_motd(red('WARNING:\nswap partition must be enabled\n'))


def motd_info():
    """Collect the ``[label, value]`` rows displayed in the MOTD table.

    Rows cover generation time, uptime, OS, IPs, hostname, kernel, 1-minute
    load, /vz usage, swap usage (red above 80%) and free RAM. The VZSRV
    license rows are added when license_info() returns data.

    :returns: list of two-item ``[label, value]`` string lists in display
        order.
    """
    # load (read the file directly instead of spawning `cat`)
    with open('/proc/loadavg') as loadavg_file:
        raw_loadavg = loadavg_file.read().split()
    load = {'1min': float(raw_loadavg[0]),
            '5min': float(raw_loadavg[1]),
            '15min': float(raw_loadavg[2])}
    # /vz
    raw_vz = run_cmd("/bin/df -P /vz | tail -1").split()
    raw_vz_human = run_cmd("/bin/df -Ph /vz | tail -1").split()
    vz = {'used': int(raw_vz[2]),
          'total': int(raw_vz[1]),
          'used_human': raw_vz_human[2],
          'total_human': raw_vz_human[1]}
    vz['ratio'] = float(vz['used']) / vz['total']
    # memory: `free -b` rows are "Mem:"/"Swap:" followed by total, used, ...
    raw_free = run_cmd("/bin/free -b").split("\n")
    raw_mem = raw_free[1].split()
    raw_swap = raw_free[2].split()
    memory = {'free': int(raw_mem[1]) - int(raw_mem[2]),  # total - used
              'total': int(raw_mem[1])}
    memory['ratio'] = float(memory['free']) / memory['total']
    swap = {'used': int(raw_swap[2]),
            'total': int(raw_swap[1])}
    swap['ratio'] = 0.0 if swap['total'] == 0 else float(swap['used']) / swap['total']
    swap['color'] = 'red' if swap['ratio'] > 0.8 else 'bold'

    raw_uptime = run_cmd('uptime')
    raw_date = run_cmd('date +"%T"')
    # text between "up" and the first comma, e.g. "3 days" or "2:03"
    uptime = raw_uptime.split(',')[0].split('up')[1].strip()

    rows = []
    rows.append(['MOTD generated at', raw_date])
    rows.append(['Uptime', uptime])
    rows.append(['OS', os_release()])
    # rows.append(['Public IP', public_ip()])
    rows.append(['IP', run_cmd('hostname --all-ip-addresses')])
    rows.append(['Hostname', socket.gethostname()])
    rows.append(['Kernel', run_cmd('uname -or')])
    rows.append(['System Load', str(load['1min'])])
    rows.append(["/vz Usage", "%d%% of %s" % (vz['ratio'] * 100, vz['total_human'])])
    rows.append(["Swap Usage", colored(swap['color'], "%d%%" % (swap['ratio'] * 100))])
    rows.append(['RAM Free', "%d%% of %s" % (memory['ratio'] * 100, humanise(memory['total']))])
    # Query the license once. It used to run vzlicview four times, and a
    # None on a later call would have raised TypeError.
    lic = license_info()
    if lic is not None:
        rows.append(['License', str(lic['cert_expiration']) + ' status:' + str(lic['cert_status'])])
        rows.append(['Key', str(lic['key'])])

    return rows

if __name__ == "__main__":
    banner_file = '/etc/motd'
    # Start from an empty (or newly created) /etc/motd on every run. The
    # table is written first and warnings_info() appends afterwards, so
    # truncating keeps both from piling up across runs. The old code's
    # 'w'-if-exists / 'a'-otherwise choice also always left an empty file.
    with open(banner_file, 'w'):
        pass
    # The file was just emptied, so the measured width is 0 and center_by()
    # adds no indentation. The measurement is kept so centring still works
    # if the truncation above is ever dropped.
    with open(banner_file) as motd:
        banner = motd.read()
    banner_length = max(smartlen(line) for line in banner.split("\n"))
    info = center_by(banner_length, column_display(motd_info(), num_columns=1))
    with open(banner_file, 'a') as motd:
        motd.write(info)
    issue_info()
    warnings_info()
