#!/usr/bin/python3 -u
"""
Description: Reset PCI device by vendor and product ID
Author: thnikk
"""
from subprocess import run, CalledProcessError
from glob import glob
import os
import sys
import argparse
import time


def parse_args(argv=None) -> argparse.Namespace:
    """ Parse command-line arguments.

    argv: list of arguments to parse; None (the default) means the
    process arguments (sys.argv[1:]), as with ArgumentParser.parse_args.
    Returns a namespace with string attributes vid, pid and driver.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('vid', type=str, help='Vendor ID')
    parser.add_argument('pid', type=str, help='Product ID')
    parser.add_argument('driver', type=str, help='Driver to load')
    return parser.parse_args(argv)


def get_card(dev, *, drm_dir='/sys/class/drm'):
    """ Return the DRM card name (e.g. 'card0') backed by PCI device *dev*.

    dev: PCI address as listed under /sys/bus/pci/devices
         (e.g. '0000:03:00.0').
    drm_dir: directory holding the DRM class entries.
    Returns None when no card entry links to that device.
    """
    for path in sorted(glob(os.path.join(drm_dir, 'card*'))):
        name = os.path.basename(path)
        # Keep 'card0' ... 'card12'; skip connector entries like 'card0-DP-1'.
        if not name[len('card'):].isdigit():
            continue
        # Class entries link to .../<pci address>/drm/<name>.  Compare that
        # exact tail so an upstream bridge whose address also appears earlier
        # in the link path is not mistaken for the GPU itself.
        if os.readlink(path).endswith(f'/{dev}/drm/{name}'):
            return name
    return None


def _read_pci_id(path) -> str:
    """ Read a sysfs hex ID file (e.g. '0x00a0') and return '00a0'. """
    with open(path, 'r', encoding='utf-8') as file:
        value = file.read().strip().lower()
    # str.lstrip('0x') strips every leading '0' and 'x' character, turning
    # '0x00a0' into 'a0'; remove only the literal prefix instead.
    if value.startswith('0x'):
        value = value[2:]
    return value


def get_cards() -> dict:
    """ Get dict of PCI devices structured as {vid: {pid: {...}}}.

    Each leaf is {"dev": <PCI address>, "card": <DRM card name or None>}.
    IDs are lowercase hex without the '0x' prefix, leading zeros kept.
    If several devices share a vid/pid pair, the last one scanned wins.
    """
    devices = {}
    for path in glob('/sys/bus/pci/devices/*'):
        vid = _read_pci_id(f'{path}/vendor')
        pid = _read_pci_id(f'{path}/device')
        dev = os.path.basename(path)
        devices.setdefault(vid, {})[pid] = {"dev": dev, "card": get_card(dev)}
    return devices


def unbind(dev):
    """ Detach *dev* from the driver it is currently bound to.

    Writes the PCI address into the device's driver/unbind file. If that
    file does not exist (e.g. no driver is bound), a notice is printed
    instead of raising.
    """
    unbind_path = f'/sys/bus/pci/devices/{dev}/driver/unbind'
    try:
        with open(unbind_path, 'w', encoding='utf-8') as handle:
            handle.write(dev)
    except FileNotFoundError:
        print(f'Driver not bound to {dev}')


def bind(dev, driver):
    """ Attach *dev* to *driver* by writing its PCI address to the
    driver's sysfs bind file. Errors from open/write propagate. """
    target = f'/sys/bus/pci/drivers/{driver}/bind'
    with open(target, 'w', encoding='utf-8') as sysfs_file:
        sysfs_file.write(dev)


def rescan():
    """ Request a PCI bus rescan by writing '1' to the sysfs rescan file. """
    rescan_path = '/sys/bus/pci/rescan'
    with open(rescan_path, 'w', encoding='utf-8') as trigger:
        trigger.write('1')


def fuser(card) -> list:
    """ Return the PIDs (as strings) that `fuser` prints for /dev/dri/<card>.

    A non-zero exit status from fuser (e.g. when nothing has the node open)
    yields an empty list.
    """
    command = ['fuser', f'/dev/dri/{card}']
    try:
        completed = run(command, check=True, capture_output=True)
    except CalledProcessError:
        return []
    return completed.stdout.decode('utf-8').strip().split()


def get_pids(card) -> set:
    """ Get PIDs of processes holding the card's DRM nodes open.

    Checks /dev/dri/<card> plus its render node(s). Render nodes are looked
    up through sysfs (<card>/device/drm/renderD*); only if none is found
    does this fall back to assuming renderD(128 + card number), which is
    wrong whenever card and render numbering diverge.

    Returns a set of PID strings; empty when *card* is falsy (no DRM card).
    """
    if not card:
        return set()
    render_nodes = sorted(
        os.path.basename(path)
        for path in glob(f'/sys/class/drm/{card}/device/drm/renderD*')
    )
    if not render_nodes:
        number = int(''.join(s for s in card if s.isdigit()))
        render_nodes = [f'renderD{128 + number}']
    pids = set(fuser(card))
    for node in render_nodes:
        pids.update(fuser(node))
    return pids


def fix_sway(minor=1):
    """ Send a udev 'remove' event for the DRM device with MINOR == *minor*.

    Runs `udevadm trigger --action=remove` limited to the drm subsystem and
    to devices whose MINOR property equals *minor* (default 1, the value the
    original call hard-coded). Best-effort: failures are printed, not raised.
    """
    command = [
        'udevadm', 'trigger', '--verbose', '--type=devices',
        '--action=remove', '--subsystem-match=drm',
        # No shell is involved, so the value must not be quoted: the former
        # '"MINOR=1"' form handed literal quote characters to udevadm, so
        # the filter could never match any device.
        f'--property-match=MINOR={minor}',
    ]
    try:
        output = run(command, check=True, capture_output=True)
    except CalledProcessError as e:
        print(e.stderr.decode('utf-8'))
        return
    except FileNotFoundError:
        print('udevadm not found, skipping udev trigger', file=sys.stderr)
        return
    # --verbose lists the triggered devices on stdout.
    print(output.stdout.decode('utf-8'))
    if output.stderr:
        print(output.stderr.decode('utf-8'), file=sys.stderr)


def main():
    """ Rebind the PCI device matching the given vid/pid to a driver.

    Exits with status 1 if the driver directory is missing, no matching
    device exists, processes still hold the card's DRM nodes, or a sysfs
    operation fails (including insufficient permissions).
    """
    args = parse_args()

    if not os.path.exists(f'/sys/bus/pci/drivers/{args.driver}'):
        print('Driver does not exist, exiting', file=sys.stderr)
        sys.exit(1)

    def normalize(hex_id):
        # Accept '0x'-prefixed and/or uppercase IDs from the command line.
        hex_id = hex_id.strip().lower()
        return hex_id[2:] if hex_id.startswith('0x') else hex_id

    vid, pid = normalize(args.vid), normalize(args.pid)
    cards = get_cards()
    card = cards.get(vid, {}).get(pid)
    if card is None:
        print(f'No PCI device found with ID {vid}:{pid}', file=sys.stderr)
        sys.exit(1)

    try:
        pids = get_pids(card['card'])
        if pids:
            print(f'Processes running:\n{pids}')
            sys.exit(1)
        unbind(card['dev'])
        bind(card['dev'], args.driver)
        # NOTE(review): presumably gives the driver time to finish probing
        # before the udev trigger; the 3 s value is not derived here.
        time.sleep(3)
        fix_sway()
    except PermissionError:
        print('Please run as root.', file=sys.stderr)
        sys.exit(1)
    except OSError as err:
        print(f'Failed to reset {card["dev"]}: {err}', file=sys.stderr)
        sys.exit(1)


# Run the rebind only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
