#!/usr/bin/python3 -u
import os
import glob
import time
import signal

# GPUs to leave untouched, written in the same "VID:PID" form the main loop
# builds from sysfs, for example:  blacklist = ["1002:731f"]
blacklist = list()


# Disable PWM when terminated
def exit_handler(signum, frame):
    try:
        for device, info in devices.items():
            open(info["path"] + "/pwm1_enable", 'w').write("2")
            print("Disabling for " + device)
    except:
        pass
    exit(1)


# Give fan control back to the driver when the service manager (or kill)
# sends SIGTERM.
signal.signal(signal.SIGTERM, exit_handler)

# Smoothed temperature from the previous pass, per device; lets the fan
# speed ramp down gradually instead of dropping instantly.
last_temp = dict()

# Fan curve: 0 duty at or below MIN_TEMP, full duty (255) at or above
# MAX_TEMP, linear in between. Temperatures are in degrees Celsius.
MIN_TEMP = 60
MAX_TEMP = 90


def _read_sysfs(path):
    """Return the contents of sysfs file *path* with surrounding whitespace removed."""
    with open(path, 'r') as sysfs_file:
        return sysfs_file.read().strip()


def _write_sysfs(path, value):
    """Write the string *value* to sysfs file *path*, closing it so errors surface."""
    with open(path, 'w') as sysfs_file:
        sysfs_file.write(value)


def _strip_hex_prefix(value):
    """Remove a leading "0x" from a sysfs ID.

    str.lstrip("0x") is not a prefix strip: it would also eat leading zeros,
    turning "0x0001" into "1".
    """
    return value[2:] if value.startswith("0x") else value


def _scan_devices():
    """Find non-blacklisted amdgpu hwmon nodes.

    :returns: dict keyed by hwmon path (unique even for identical GPUs) whose
        values are ``{"path": <hwmon path>, "id": "VID:PID"}``.
    """
    found = {}
    for path in sorted(glob.glob("/sys/class/hwmon/hwmon*")):
        try:
            vendor = _strip_hex_prefix(_read_sysfs(path + "/device/vendor"))
            product = _strip_hex_prefix(_read_sysfs(path + "/device/device"))
            name = _read_sysfs(path + "/name")
        except OSError:
            # Skip nodes that lack a VID/PID or cannot be read
            continue
        device_id = "{}:{}".format(vendor, product)
        if name == "amdgpu" and device_id not in blacklist:
            found[path] = {"path": path, "id": device_id}
    return found


def _enable_manual_control(info):
    """Switch a device to manual PWM ("1") and raise its power cap to the maximum."""
    print("Enabling for device {} ({})".format(info["id"], info["path"]))
    _write_sysfs(info["path"] + "/pwm1_enable", "1")
    # The power cap is best effort: if power1_cap_max is missing or the write
    # is rejected, fan control must still proceed.
    try:
        power_max = _read_sysfs(info["path"] + "/power1_cap_max")
        _write_sysfs(info["path"] + "/power1_cap", power_max)
    except OSError as err:
        print("Could not raise power cap for {}: {}".format(info["id"], err))


def _fan_speed(temp):
    """Map a temperature (°C) to a PWM duty value in 0..255 along the fan curve."""
    clamped = min(max(temp, MIN_TEMP), MAX_TEMP)
    fan_multiplier = (clamped - MIN_TEMP) / (MAX_TEMP - MIN_TEMP)
    return int(255 * fan_multiplier)


def _update_device(key, info):
    """Run one control step for a device: ensure manual mode, smooth, set PWM."""
    path = info["path"]
    if _read_sysfs(path + "/pwm1_enable") != "1":
        _enable_manual_control(info)

    # temp1_input is divided by 1000 to get degrees (hwmon millidegrees)
    temp = int(_read_sysfs(path + "/temp1_input")) / 1000
    previous = last_temp.get(key)
    if previous is not None and previous > temp:
        # Fade down slowly: fall at most 1 °C per step, but never below the
        # actual reading (plain `previous - 1` can undershoot it).
        temp = max(temp, previous - 1)
    last_temp[key] = temp

    _write_sysfs(path + "/pwm1", str(_fan_speed(temp)))


# Main loop
while True:
    try:
        # Build the new map fully before publishing it, so the SIGTERM handler
        # never sees an empty or half-populated device list.
        devices = _scan_devices()

        for key, info in devices.items():
            try:
                _update_device(key, info)
            except (OSError, ValueError) as err:
                # One GPU vanishing or returning garbage must not stop fan
                # control for the others.
                print("Error updating {}: {}".format(info["id"], err))

        # Run every second
        time.sleep(1)
    except KeyboardInterrupt:
        exit_handler(signal.SIGTERM, None)
