#!/usr/libexec/platform-python

import os
import sys
import time
import subprocess
import json


class Output(object):
    """Accumulate gauge samples and print them in Prometheus text format.

    Samples are grouped by metric name.  flush() prints every metric, in the
    order its name was first added, as a HELP/TYPE header followed by all
    samples recorded under that name.  Every printed metric name is
    ``<prefix>_<name>``.
    """

    def __init__(self, prefix):
        # name -> list of (value, labels dict), in insertion order
        self.metrics = {}
        # metric names in the order they were first added
        self.order = []
        # prepended to every metric name as "<prefix>_<name>"
        self.prefix = prefix

    def add(self, name, val, **kwargs):
        """Record sample *val* for metric *name*; keyword args become labels."""
        if name not in self.metrics:
            self.order.append(name)
            self.metrics[name] = []
        self.metrics[name].append((val, kwargs))

    @staticmethod
    def _escape_label_value(value):
        """Escape backslash, double quote and newline in a label value.

        The text exposition format requires these escapes.  Without them a
        device model or serial containing one of these characters would
        produce an unparsable line.
        """
        return (str(value).replace('\\', '\\\\')
                .replace('"', '\\"')
                .replace('\n', '\\n'))

    def flush(self):
        """Print all recorded metrics to stdout."""
        for name in self.order:
            print('# HELP {0}_{1} metric {1}'.format(self.prefix, name))
            print('# TYPE {0}_{1} gauge'.format(self.prefix, name))
            for val, lbls in self.metrics[name]:
                slbls = ','.join(
                    '{}="{}"'.format(key, self._escape_label_value(value))
                    for key, value in lbls.items())
                print('{0}_{1}{{{2}}} {3}'.format(self.prefix, name,
                                                  slbls, val))


def maybe_fraction(which):
    """Build a raw-value parser for '/'-separated SMART raw values.

    The returned callable splits its string argument on '/', picks the part
    at index *which* (or the first part if there are not enough parts), and
    converts it with int(..., 0), so hex prefixes such as '0x' are accepted.
    """
    def process(val):
        fields = val.split('/')
        chosen = fields[which] if which < len(fields) else fields[0]
        return int(chosen, 0)
    return process


# Path to the smartctl binary used for every SMART query in this script.
SMARTCTL = '/usr/sbin/smartctl'
# ATA SMART attributes exported as metrics.  Keys are attribute names as
# printed by `smartctl -a`, lowercased with '-' and '/' replaced by '_'
# (see smart_print_counters); attributes not listed here are ignored.
# Values select the parser for the RAW_VALUE column: None means
# int(raw, 0); maybe_fraction(i) takes the i-th part of a '/'-separated raw
# value (falling back to the first part).
SMARTATTR = {
    'airflow_temperature_cel': None,
    'available_reservd_space': None,
    'average_erase_count': None,
    'average_slc_erase_ct': None,
    'calibration_retry_count': None,
    'command_timeout': None,
    'crc_error_count': None,
    'current_pending_sector': None,
    'disk_shift': None,
    'ecc_uncorr_error_count': maybe_fraction(0),
    'end_to_end_error': None,
    'erase_fail_count': None,
    'erase_fail_count_total': None,
    'flash_writes_gib': None,
    'g_sense_error_rate': None,
    'hardware_ecc_recovered': None,
    'head_flying_hours': None,
    'high_fly_writes': None,
    'host_reads_mib': None,
    'host_reads_32mib': None,
    'host_writes_mib': None,
    'host_writes_32mib': None,
    'initial_bad_block_count': None,
    'life_curve_status': None,
    'lifetime_writes_gib': None,
    'lifetime_reads_gib': None,
    'load_cycle_count': None,
    'load_friction': None,
    'load_in_time': None,
    'load_retry_count': None,
    'loaded_hours': None,
    'max_erase_count': None,
    'max_slc_erase_ct': None,
    'min_erase_count': None,
    'min_slc_erase_ct': None,
    'maxavgerase_ct': None,
    'media_wearout_indicator': None,
    'multi_zone_error_rate': None,
    'nand_writes_1gib': None,
    'nand_writes_32mib': None,
    'offline_uncorrectable': None,
    'power_cycle_count': None,
    'power_loss_cap_test': None,
    'power_off_retract_count': None,
    'power_on_hours': None,
    'program_fail_count': None,
    'program_fail_cnt_total': None,
    'raid_recoverty_ct': None,
    'raw_read_error_rate': maybe_fraction(0),
    'read_soft_error_rate': None,
    'reallocated_event_count': None,
    'reallocated_sector_ct': None,
    'remaining_lifetime_perc': None,
    'reported_uncorrect': None,
    'retired_block_count': None,
    'runtime_bad_block': None,
    'sandforce_internal': None,
    'sata_downshift_count': None,
    'sata_phy_error_count': None,
    'seek_error_rate': None,
    'seek_time_performance': None,
    'slc_writes_32mib': None,
    'soft_ecc_correct_rate': maybe_fraction(0),
    'spin_retry_count': None,
    'spin_up_time': None,
    'ssd_life_left': None,
    'start_stop_count': None,
    'temperature_case': None,
    'temperature_celsius': None,
    'temperature_internal': None,
    'thermal_throttle': maybe_fraction(1),
    'throughput_performance': None,
    'tlc_writes_32mib': None,
    'total_erase_count': None,
    'total_lbas_read': None,
    'total_lbas_written': None,
    'total_slc_erase_ct': None,
    'udma_crc_error_count': None,
    'unc_soft_read_err_rate': maybe_fraction(0),
    'uncorrectable_error_cnt': None,
    'unexpect_power_loss_ct': None,
    'unsafe_shutdown_count': None,
    'unused_rsvd_blk_cnt_tot': None,
    'used_rsvd_blk_cnt_tot': None,
    'valid_spare_block_cnt': None,
    'wear_range_delta': None,
    'workld_host_reads_perc': None,
    'workld_media_wear_indic': None,
    'workload_minutes': None,
}
# Collector for the smartctl-based metrics; names are printed as "smart_<name>".
smartout = Output("smart")


def smart_list_devices():
    """Return the devices reported by ``smartctl --scan-open``.

    Each non-blank line that does not start with '#' yields one dict.  Its
    'device' is the first space-separated field and its 'type' is the third
    (the value later passed to smartctl's -d option).
    """
    proc = subprocess.Popen([SMARTCTL, '--scan-open'],
                            stdout=subprocess.PIPE)
    stdout, _ = proc.communicate()

    found = []
    for raw_line in stdout.decode('utf-8').splitlines():
        entry = raw_line.strip()
        # Skip blank and '#' comment lines.
        if not entry or entry.startswith('#'):
            continue
        fields = entry.split(' ')
        found.append({'device': fields[0], 'type': fields[2]})
    return found


def smart_print_device_info(dev):
    """Record identity and health metrics for one device.

    Runs ``smartctl -i -H`` on *dev* and emits ``device_info``,
    ``device_smart_available``, ``device_smart_enabled``,
    ``device_smart_healthy`` and, when a capacity can be parsed,
    ``device_capacity_bytes``.

    Args:
        dev: dict with 'device' (device path) and 'type' (value for
            smartctl's -d option), as returned by smart_list_devices().

    Returns:
        True if the output reports SMART as enabled, so counters should be
        collected.  False otherwise, including when smartctl's exit status
        has bit 0 or 1 set.
    """
    encoding = {'PYTHONIOENCODING': 'utf-8'}
    # NOTE: env= replaces the whole environment, so smartctl runs with only
    # this variable set.
    p = subprocess.Popen([SMARTCTL, '-i', '-H', '-d',
                          dev['type'], dev['device']],
                         stdout=subprocess.PIPE, env=encoding)
    out, _ = p.communicate()
    # Exit-status bits 0-1 mean a command-line parse error or a device that
    # could not be opened (smartctl(8)); the output is then not usable.
    if p.returncode & 0x3 != 0:
        return False

    # Build {lowercased key: value} from the "Key: value" lines.  Split on
    # the first colon only, so values that contain ':' are kept.
    # errors='replace' keeps odd vendor bytes from aborting the run.
    infos = {}
    for line in out.decode('utf-8', errors='replace').splitlines():
        key, sep, value = line.partition(':')
        if not sep:
            continue
        infos[key.strip().lower()] = value.strip()

    model = infos.get('device model', '')
    serial = infos.get('serial number', "UNKNOWN")
    if model == '':
        # Fallbacks for output without a "Device Model" line (presumably
        # SCSI vendor/product and NVMe model number).
        if 'vendor' in infos and 'product' in infos:
            model = infos['vendor'] + ' ' + infos['product']
        elif 'model number' in infos:
            model = infos['model number']
    enabled = False
    available = False
    if 'smart support is' in infos:
        # If several "SMART support is" lines appear, only the last one is
        # kept in infos.
        enabled = 'Enabled' in infos['smart support is']
        available = enabled or 'Disabled' in infos['smart support is']
    elif 'smart overall-health self-assessment test result' in infos:
        enabled = True
    healthy = \
        infos.get('smart overall-health self-assessment test result', '') == \
        'PASSED' or infos.get('smart health status', '') == 'OK'
    smartout.add("device_info", 1, disk=dev['device'], type=dev['type'],
                 device_model=model,
                 serial_number=serial)
    smartout.add('device_smart_available', int(available), disk=dev['device'],
                 type=dev['type'])
    smartout.add('device_smart_enabled', int(enabled), disk=dev['device'],
                 device_model=model,
                 serial_number=serial,
                 type=dev['type'])
    smartout.add('device_smart_healthy', int(healthy), disk=dev['device'],
                 device_model=model,
                 serial_number=serial,
                 type=dev['type'])

    def parse_bytes(text):
        # "500,107,862,016 bytes [500 GB]" -> 500107862016.  Returns None
        # when the leading token is not a number.
        try:
            return int(text.split(' ')[0].replace(',', ''))
        except ValueError:
            return None

    capacity = None
    # Keys are lowercased with spaces kept, so the key is 'user capacity'.
    if 'user capacity' in infos:
        capacity = parse_bytes(infos['user capacity'])
    elif 'number of namespaces' in infos:
        try:
            namespaces = int(infos['number of namespaces'])
        except ValueError:
            namespaces = None
        if namespaces is not None:
            # Sum the sizes of all reported namespaces.
            capacity = 0
            for idx in range(1, namespaces + 1):
                size = infos.get('namespace %d size/capacity' % idx)
                if size is not None:
                    capacity += parse_bytes(size) or 0
    if capacity is not None:
        smartout.add('device_capacity_bytes', capacity,
                     disk=dev['device'], type=dev['type'])

    return enabled


def scsi_parse_temperature(name):
    """Build a parser for 'Label: <N> C' temperature lines.

    The parser returns [(name, N)], or [(name, 0)] when the reading is
    '<not available>'.
    """
    def parse(line):
        reading = line.split(':')[1].strip()
        if reading == '<not available>':
            return [(name, 0)]
        # Drop the trailing two characters (the " C" unit suffix).
        return [(name, int(reading[:-2]))]
    return parse


def scsi_parse_percent(name):
    """Build a parser turning 'Label: N%' (percent used) into 100 - N."""
    def parse(line):
        field = line.split(':')[1].strip()
        used_percent = int(field[:-1])  # drop the trailing '%'
        return [(name, 100 - used_percent)]
    return parse


def scsi_parse_int(name):
    """Build a parser for 'Label: <integer>' lines, returning [(name, int)]."""
    def parse(line):
        field = line.split(':')[1]
        return [(name, int(field.strip()))]
    return parse


def scsi_parse_equal_sign(name, term, typ):
    """Build a parser for 'label = value' lines.

    The line is split on '='; piece *term* is stripped and converted with
    *typ*.  *term* is 2 for labels that themselves contain '<='.
    """
    def parse(line):
        pieces = line.split('=')
        return [(name, typ(pieces[term].strip()))]
    return parse


def scsi_parse_errors_log(name):
    """Build a parser for whitespace-separated error-counter rows.

    Columns 1, 2, 3, 5 and 7 are read as integers and column 6 as a float
    scaled by 10**9.  Column 4 is not exported.  Metric names are prefixed
    with *name*.
    """
    def parse(line):
        cols = line.split()
        return [
            (name + '_errors_corrected_ecc_fast', int(cols[1])),
            (name + '_errors_corrected_ecc_delayed', int(cols[2])),
            (name + '_errors_corrected_reread_rewrite', int(cols[3])),
            (name + '_errors_corrections', int(cols[5])),
            (name + '_errors_data_processed', float(cols[6]) * (10**9)),
            (name + '_errors_uncorrected', int(cols[7])),
        ]
    return parse


# Line parsers for `smartctl -a` output of SCSI and megaraid devices.
# smart_scsi_print_counters tests each key as a substring of every output
# line.  On a match it calls the parser, which returns (metric_name, value)
# pairs that are exported with an "scsi_" prefix.  The 'read:', 'write:' and
# 'verify:' keys presumably match the rows of smartctl's error counter log.
# NOTE(review): 'larger_that_segment_reqs' looks like a typo for "than", but
# it is an exported metric name; renaming it would change the output.
SMART_SCSI_ATTR = {
    'Current Drive Temperature': scsi_parse_temperature('temperature_celsius'),
    'Drive Trip Temperature':
    scsi_parse_temperature('trip_temperature_celsius'),
    'Percentage used endurance indicator': scsi_parse_percent('life_left'),
    'Specified cycle count': scsi_parse_int('start_stop_spec'),
    'Accumulated start-stop cycles': scsi_parse_int('start_stop_cycles'),
    'Specified load-unload count': scsi_parse_int('load_cycle_spec'),
    'Accumulated load-unload cycles': scsi_parse_int('load_unload_cycles'),
    'Elements in grown defect list': scsi_parse_int('reallocated_sector_ct'),
    'Blocks sent to initiator':
    scsi_parse_equal_sign('total_blocks_read', 1, int),
    'Blocks received from initiator':
    scsi_parse_equal_sign('total_blocks_written', 1, int),
    'Blocks read from cache and sent to initiator':
    scsi_parse_equal_sign('total_blocks_cache_read', 1, int),
    # Label contains '<=', so the value is the third '='-separated piece.
    'Number of read and write commands whose size <= segment size':
    scsi_parse_equal_sign('smaller_than_segment_reqs', 2, int),
    'Number of read and write commands whose size > segment size':
    scsi_parse_equal_sign('larger_that_segment_reqs', 1, int),
    'number of hours powered up':
    scsi_parse_equal_sign('power_on_hours', 1, float),
    'read:': scsi_parse_errors_log('read'),
    'write:': scsi_parse_errors_log('write'),
    'verify:': scsi_parse_errors_log('verify'),
    'Non-medium error count': scsi_parse_int('non_medium_errors'),
}


def smart_scsi_print_counters(dev, out):
    """Export counters parsed from a SCSI/megaraid device's smartctl output.

    Args:
        dev: device dict with 'device' and 'type' keys.
        out: `smartctl -a` output, as str or as undecoded bytes.

    Every line is tested against each SMART_SCSI_ATTR key (substring match).
    Matching lines are parsed and recorded as ``scsi_<metric>`` samples.
    """
    if isinstance(out, bytes):
        # The keys are str; `str in bytes` raises TypeError on Python 3.
        out = out.decode('utf-8', errors='replace')
    for line in out.splitlines():
        for key, parse in SMART_SCSI_ATTR.items():
            if key not in line:
                continue
            try:
                metrics = parse(line)
            except (ValueError, IndexError):
                # A line that matches a key but has an unexpected layout is
                # skipped rather than aborting collection for all devices.
                continue
            for metric, value in metrics:
                smartout.add('scsi_' + metric, value, disk=dev['device'],
                             type=dev['type'])


def smart_print_counters(dev):
    """Collect SMART counters for one device from ``smartctl -a``.

    SCSI and megaraid devices are handed to smart_scsi_print_counters().
    For every other device, each attribute-table row whose normalized name
    is in SMARTATTR is exported three times, labelled value="normalized",
    value="raw" and value="thresh".
    """
    p = subprocess.Popen([SMARTCTL, '-a', '-d', dev['type'],
                          '-v', '9,raw24(raw8)',  # Power_On_Hours
                          '-v', '240,raw56:3210r54',  # Head_Flying_Hours
                          dev['device']],
                         stdout=subprocess.PIPE)
    out, _ = p.communicate()
    # Decode once.  Undecodable vendor bytes must not abort the whole run.
    text = out.decode('utf-8', errors='replace')

    if dev['type'].startswith(('scsi', 'megaraid')):
        # Pass str: the SCSI parser does substring tests with str keys.
        smart_scsi_print_counters(dev, text)
        return

    for line in text.splitlines():
        # Attribute rows: ID NAME FLAG VALUE WORST THRESH TYPE UPDATED
        # WHEN_FAILED RAW_VALUE [...]; shorter lines are something else.
        ls = line.split()
        if len(ls) < 10:
            continue

        aname = ls[1].lower().replace("-", "_").replace("/", "_")
        if aname not in SMARTATTR:
            continue
        parse_raw = SMARTATTR[aname]
        try:
            normalized = int(ls[3])
            raw = int(ls[9], 0) if parse_raw is None else parse_raw(ls[9])
            thresh = int(ls[5])
        except ValueError:
            # Skip a row with non-numeric columns instead of crashing before
            # any metrics are flushed.
            continue
        for kind, val in (('normalized', normalized), ('raw', raw),
                          ('thresh', thresh)):
            smartout.add(aname, val, smart_id=ls[0], value=kind,
                         disk=dev['device'], type=dev['type'])


def smart_collect():
    """Collect smartctl metrics for every scanned device, then print them.

    A smartctl_run timestamp is recorded for every device.  Counters are
    collected only for devices whose info query reports SMART as enabled.
    """
    for device in smart_list_devices():
        smartout.add("smartctl_run", int(time.time()),
                     disk=device['device'], type=device['type'])
        if smart_print_device_info(device):
            smart_print_counters(device)
    smartout.flush()


def nvme_intel_print_counters(dev):
    """Export Intel vendor-specific NVMe counters for one device.

    Runs ``nvme intel smart-log-add -j`` on *dev*.  Each entry under
    "Device stats" is a dict of values that are exported as
    ``intel_<stat>`` with a ``type`` label.  Nested dict values also get a
    ``subtype`` label.  Nothing is emitted if the command fails.
    """
    p = subprocess.Popen([NVME, 'intel', 'smart-log-add', '-j',
                          dev['device']], stdout=subprocess.PIPE)
    out, _ = p.communicate()
    if p.returncode != 0 or not out:
        return
    nvme_smart_log_add = json.loads(out)

    # Tolerate output without a "Device stats" section (or with non-dict
    # entries) instead of raising and losing every NVMe metric.
    for stat, statv in nvme_smart_log_add.get('Device stats', {}).items():
        if not isinstance(statv, dict):
            continue
        for k, v in statv.items():
            if isinstance(v, dict):
                for k2, v2 in v.items():
                    nvmeout.add('intel_'+stat, v2, type=k, subtype=k2,
                                disk=dev['device'])
            else:
                nvmeout.add('intel_'+stat, v, type=k, disk=dev['device'])


# Path to the nvme-cli binary; NVMe collection is skipped if it is absent.
NVME = '/usr/sbin/nvme'
# Extra counter collectors keyed by controller vendor ID (the 'vid' field of
# `nvme id-ctrl`); 0x8086 dispatches to the Intel collector.
NVME_VENDOR_SPECIFIC = {
    0x8086: nvme_intel_print_counters,
}
# Collector for NVMe metrics; names are printed as "smart_nvme_<name>".
nvmeout = Output("smart_nvme")


def _nvme_vendor_id(path):
    """Return the controller 'vid' reported by `nvme id-ctrl`, or None."""
    pi = subprocess.Popen([NVME, 'id-ctrl', '-o', 'json', path],
                          stdout=subprocess.PIPE)
    out, _ = pi.communicate()
    if pi.returncode != 0 or not out:
        return None
    return json.loads(out).get('vid')


def nvme_list_devices():
    """List NVMe devices and record their device_info/capacity_bytes metrics.

    Returns:
        list of dicts with 'device' (the DevicePath from `nvme list`) and
        'vendor' (the controller vendor ID, or None if `nvme id-ctrl`
        failed for that device).
    """
    devs = []

    # '-o' and 'json' must be separate argv entries.
    p = subprocess.Popen([NVME, 'list', '-o', 'json'],
                         stdout=subprocess.PIPE)
    out, _ = p.communicate()
    if not out:
        return devs
    nvme_list = json.loads(out)

    if 'Devices' not in nvme_list:
        return devs

    for device in nvme_list['Devices']:
        path = device['DevicePath']
        # A failing id-ctrl only disables vendor-specific counters for this
        # device instead of crashing the whole collection.
        vendor = _nvme_vendor_id(path)

        nvmeout.add("device_info", 1, disk=path,
                    device_model=device['ModelNumber'],
                    serial_number=device['SerialNumber'])
        nvmeout.add("capacity_bytes", device['PhysicalSize'],
                    disk=path)

        devs.append({'device': path, 'vendor': vendor})

    return devs


def nvme_print_counters(dev):
    """Export the NVMe SMART log for one device, plus vendor extras.

    Each numeric top-level field of ``nvme smart-log -o json`` becomes a
    metric of the same name.  Afterwards, the collector registered in
    NVME_VENDOR_SPECIFIC for the device's vendor (if any) is called.
    """
    p = subprocess.Popen([NVME, 'smart-log', '-o', 'json',
                          dev['device']], stdout=subprocess.PIPE)
    out, _ = p.communicate()
    if p.returncode == 0 and out:
        nvme_smart_log = json.loads(out)
        for k, v in nvme_smart_log.items():
            if isinstance(v, bool):
                v = int(v)
            if not isinstance(v, (int, float)):
                # Only numbers are valid sample values; printing anything
                # else would make the whole output unparsable.
                continue
            nvmeout.add(k, v, disk=dev['device'])

    if dev['vendor'] in NVME_VENDOR_SPECIFIC:
        NVME_VENDOR_SPECIFIC[dev['vendor']](dev)


def nvme_collect():
    """Collect NVMe metrics for every listed device, then print them.

    An nvme_run timestamp is recorded for each device before its counters.
    """
    for device in nvme_list_devices():
        nvmeout.add("nvme_run", int(time.time()),
                    disk=device['device'])
        nvme_print_counters(device)
    nvmeout.flush()


if __name__ == '__main__':
    # Collection only runs as root; anyone else gets an error and exit code 1.
    if os.geteuid() == 0:
        smart_collect()
        # NVMe metrics are collected only when nvme-cli is installed.
        if os.path.isfile(NVME):
            nvme_collect()
        sys.exit(0)
    print("You need to have root privileges to run "
          "this script.", file=sys.stderr)
    sys.exit(1)
