#!/usr/libexec/platform-python
#
# Test virtio-scsi and virtio-blk queue settings for all machine types
#
# Copyright (c) 2019 Virtuozzo International GmbH
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#

import sys
import os
import re
import iotests

sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..', 'scripts'))
from qemu.machine import QEMUMachine

# Declare qcow2 as the only image format this test supports
# (handled by iotests.script_initialize).
iotests.script_initialize(
    supported_fmts=['qcow2'],
)

# Device types under test.  For each one, map the generic key 'vq_size'
# to the property name that reports the virtqueue size in 'info qtree'.
virtio_scsi_props = dict(vq_size='virtqueue_size')
virtio_blk_props = dict(vq_size='queue-size')

dev_types = {
    'virtio-scsi-pci': virtio_scsi_props,
    'virtio-blk-pci': virtio_blk_props,
}

# Extra QEMU command-line arguments that create each device under test.
# virtio-blk-pci needs a backing drive, so a null-co drive is attached.
vm_dev_params = {
    'virtio-scsi-pci': [
        '-device', 'virtio-scsi-pci,id=scsi0',
    ],
    'virtio-blk-pci': [
        '-device', 'virtio-blk-pci,id=scsi0,drive=drive0',
        '-drive', 'driver=null-co,id=drive0,if=none',
    ],
}

def make_pattern(props):
    """Return a regex matching ``<prop> = <digits>`` for any name in *props*.

    Each property name is passed through ``re.escape`` so characters such as
    '-' or '.' match literally.  The pattern has no capture groups, so
    ``re.findall`` yields the full ``"<prop> = <value>"`` text of each match.
    An empty *props* yields a pattern that never matches.
    """
    # Raw string: '\d' in a plain string literal is an invalid escape.
    pattern_items = [r'{0} = \d+'.format(re.escape(prop)) for prop in props]
    if not pattern_items:
        # Joining nothing gives '', which would match the empty string at
        # every position of the input.
        return r'(?!)'
    return '|'.join(pattern_items)


def query_virtqueue(vm, dev_type_name):
    """Read the virtqueue properties of *dev_type_name* from 'info qtree'.

    :param vm: a launched VM object providing ``qmp()``.
    :param dev_type_name: key into ``dev_types``.
    :return: ``(1, {prop_name: int_value})`` when each property listed in
        ``dev_types[dev_type_name]`` is matched, and the number of matches
        equals the number of properties.  Otherwise ``(0, error_message)``.
    """
    output = vm.qmp('human-monitor-command', command_line='info qtree')
    output = output['return']

    # Property names as they appear in the 'info qtree' output.
    prop_names = list(dev_types[dev_type_name].values())

    matches = re.findall(make_pattern(prop_names), output)

    # Each match has the form "<name> = <digits>"; the regex guarantees it.
    found = {}
    for match in matches:
        name, _, value = match.partition(' = ')
        found[name] = int(value)

    # Compare property *names*, not whole "name = value" strings.
    missing = [p for p in prop_names if p not in found]
    if missing:
        return (0, '({0}): The following properties not found: {1}'
                   .format(dev_type_name, ', '.join(missing)))

    # Every property was seen, but some more than once: the result is ambiguous.
    if len(matches) != len(prop_names):
        return (0, '({0}): Unexpected property matches: {1}'
                   .format(dev_type_name, ', '.join(matches)))

    return (1, found)


def check_mt(mt, dev_type_name):
    """Boot QEMU with machine type ``mt['name']`` and check virtqueue props.

    :param mt: dict holding the machine type under ``'name'`` plus, for each
        checked property, the expected integer value keyed by the property
        name used in 'info qtree'.
    :param dev_type_name: key into ``vm_dev_params`` / ``dev_types``.
    :return: number of errors found; 0 means every property matched.
    """
    vm_params = ['-machine', mt['name']] + vm_dev_params[dev_type_name]

    vm = QEMUMachine(iotests.qemu_prog, vm_params)
    vm.launch()
    try:
        ok, result = query_virtqueue(vm, dev_type_name)
    finally:
        # Stop QEMU even if querying the monitor raises, so no process leaks.
        vm.shutdown()

    if not ok:
        print('Error ({0}): {1}'.format(mt['name'], result))
        return 1

    errors = 0
    for prop_name, prop_val in result.items():
        expected = mt[prop_name]
        if expected != prop_val:
            print('Error [{0}, {1}]: {2}={3} (expected {4})'.format(
                mt['name'], dev_type_name, prop_name, prop_val, expected))
            errors += 1

    return errors

def expected_vq_size(machine):
    """Return the virtqueue size this test expects for machine type *machine*.

    The name is split at the first '.' into a prefix and a version whose
    leading component is the major number:
      * pc-i440fx-vz7 / pc-q35-vz7 with major >= 16 -> 512
      * pc-q35-rhel8 with major >= 3                -> 256
      * pc-q35-vz8                                  -> 512
      * any other machine type                      -> 128
    """
    prefix, _, version = machine.partition('.')
    major = version.split('.')[0]

    if prefix in ('pc-i440fx-vz7', 'pc-q35-vz7'):
        return 512 if int(major) >= 16 else 128
    if prefix == 'pc-q35-rhel8':
        return 256 if int(major) >= 3 else 128
    # Exact comparison: a substring test ("in 'pc-q35-vz8'") would also
    # accept prefixes such as 'pc' or 'q35'.
    if prefix == 'pc-q35-vz8':
        return 512
    return 128


# Collect every machine type the binary supports, except 'none'.
vm = iotests.VM()
vm.launch()
try:
    machines = [m['name'] for m in vm.qmp('query-machines')['return']]
finally:
    vm.shutdown()
if 'none' in machines:
    machines.remove('none')

failed = 0

for dev_type in dev_types:
    # Property name carrying the virtqueue size for this device type.
    vq_prop = dev_types[dev_type]['vq_size']
    # Test each machine type against its expected virtqueue size.
    for m in machines:
        mt = {'name': m, vq_prop: expected_vq_size(m)}
        failed += check_mt(mt, dev_type)

if failed > 0:
    print('Failed')
else:
    print('Success')
