#!/usr/bin/env python

import os
import curses
import time
from collections import defaultdict

def fetch_btrfs_filesystems(btrfs_path="/sys/fs/btrfs"):
    """Discover Btrfs filesystems, their device stat files, and their labels.

    Args:
        btrfs_path: directory containing one entry per filesystem UUID.
            Defaults to the Linux sysfs location; overridable for testing.

    Returns:
        A ``(btrfs_fs, labels)`` tuple:

        * ``btrfs_fs`` maps UUID -> list of ``<btrfs_path>/<uuid>/devices/<dev>/stat``
          paths, for entries that have a ``devices`` directory.
        * ``labels`` maps every entry of ``btrfs_path`` -> its label, or the
          entry name itself when the label file is missing, unreadable or empty.

        Entries (and devices) are sorted by name so the display order is
        stable; ``os.listdir`` returns them in arbitrary order. Both dicts
        are empty when ``btrfs_path`` does not exist.
    """
    btrfs_fs = {}
    labels = {}
    if not os.path.exists(btrfs_path):
        return btrfs_fs, labels

    for uuid in sorted(os.listdir(btrfs_path)):
        fs_dir = os.path.join(btrfs_path, uuid)

        label = None
        try:
            with open(os.path.join(fs_dir, "label"), "r") as f:
                label = f.read().strip()
        except OSError:
            # Missing or unreadable label: fall back to the UUID below
            # instead of aborting discovery of the other filesystems.
            pass
        labels[uuid] = label or uuid  # Use UUID if no label is available

        devices_path = os.path.join(fs_dir, "devices")
        if os.path.exists(devices_path):
            btrfs_fs[uuid] = [
                os.path.join(devices_path, dev, "stat")
                for dev in sorted(os.listdir(devices_path))
            ]
    return btrfs_fs, labels

def get_sector_size(device, sysfs_root="/sys/block"):
    """Return the hardware sector size, in bytes, reported by sysfs for a device.

    Args:
        device: whole-disk block device name (e.g. ``sda``, ``nvme0n1``),
            looked up as ``<sysfs_root>/<device>/queue/hw_sector_size``.
            Partition names are not top-level entries there and so fall
            back to the default.
        sysfs_root: directory holding per-device entries; overridable for
            testing.

    Returns:
        The parsed integer, or 512 when the attribute is missing, unreadable
        or malformed.

    Note:
        The sector counters in a block device's ``stat`` file are always in
        512-byte units regardless of this value, so it must not be used to
        scale those counters.
    """
    path = os.path.join(sysfs_root, device, "queue", "hw_sector_size")
    try:
        with open(path, "r") as f:
            return int(f.read().strip())
    except (OSError, ValueError):
        return 512  # Default to 512 bytes if not found or not parseable

def calculate_btrfs_stats(btrfs_fs):
    """Sum the bytes read and written across each Btrfs filesystem's devices.

    Args:
        btrfs_fs: mapping of filesystem UUID -> list of device ``stat`` file
            paths, as built by fetch_btrfs_filesystems().

    Returns:
        dict mapping each UUID to ``{"read_bytes": int, "write_bytes": int}``,
        the totals of the sector counters read from its devices' stat files.
        Devices whose stat file is missing, unreadable, too short or not
        numeric are skipped, so a filesystem with no readable device
        reports zeros.
    """
    # The sector fields of a block device stat file are always counted in
    # 512-byte units, independent of the device's hardware sector size, so
    # get_sector_size() is deliberately not used here.
    sector_bytes = 512

    fs_stats = {}
    for uuid, stat_paths in btrfs_fs.items():
        read_bytes = write_bytes = 0

        for stat_path in stat_paths:
            try:
                with open(stat_path, "r") as f:
                    fields = f.read().split()
                # 0-based field 2 = sectors read, field 6 = sectors written.
                read_sectors = int(fields[2])
                write_sectors = int(fields[6])
            except (OSError, ValueError, IndexError):
                # Device vanished, permission denied, or a truncated/garbled
                # stat line: skip this device rather than abort the refresh.
                continue

            read_bytes += read_sectors * sector_bytes
            write_bytes += write_sectors * sector_bytes

        fs_stats[uuid] = {
            "read_bytes": read_bytes,
            "write_bytes": write_bytes,
        }

    return fs_stats

def init_colors():
    """Enable curses colors and register color pairs 1-5, all on black.

    Pair numbers and foregrounds:
        1 yellow (headings), 2 green (read chart), 3 red (write chart),
        4 cyan (selected item), 5 white.
    """
    curses.start_color()
    foreground_by_pair = (
        curses.COLOR_YELLOW,  # 1: headings
        curses.COLOR_GREEN,   # 2: read chart
        curses.COLOR_RED,     # 3: write chart
        curses.COLOR_CYAN,    # 4: selected item
        curses.COLOR_WHITE,   # 5: plain text
    )
    for pair_number, foreground in enumerate(foreground_by_pair, start=1):
        curses.init_pair(pair_number, foreground, curses.COLOR_BLACK)

def format_iec(value):
    """Format a byte count with one decimal and a binary (IEC) unit, B through PiB.

    The value is divided by 1024 until it drops below 1024 or the largest
    unit (PiB) is reached; anything beyond that stays expressed in PiB.
    """
    suffixes = ("B", "KiB", "MiB", "GiB", "TiB", "PiB")
    last = len(suffixes) - 1
    index = 0
    # `not value < 1024` (rather than `>=`) keeps NaN flowing to PiB exactly as before.
    while index < last and not value < 1024:
        value /= 1024
        index += 1
    return f"{value:.1f} {suffixes[index]}"

_HISTORY_LEN = 60  # Samples kept per filesystem (one per _SAMPLE_INTERVAL)
_SAMPLE_INTERVAL = 1.0  # Seconds between stat samples
_CHART_ROWS = 8  # Bar rows per chart; each chart also has a title and an x-axis row
_HEADER_ROWS = 3  # Title line, blank line, column headings
_CHARTS_HEIGHT = 2 * (_CHART_ROWS + 2)  # Read chart + write chart, incl. titles/axes
# Header, at least one list row, one gap row, and both charts.
_MIN_HEIGHT = _HEADER_ROWS + 1 + 1 + _CHARTS_HEIGHT
# Widest line: 10-char y-axis label, a space, one column per history sample.
_MIN_WIDTH = 10 + 1 + _HISTORY_LEN


def _addstr(stdscr, y, x, text, attr=0):
    """Draw text at (y, x), tolerating writes that run off the window.

    curses raises curses.error when a write touches the bottom-right cell or
    falls outside the window (e.g. while the terminal is being resized);
    losing part of one frame is preferable to crashing the monitor.
    """
    try:
        stdscr.addstr(y, x, text, attr)
    except curses.error:
        pass


def _compute_rates(current, previous, elapsed):
    """Return {uuid: {"read": B/s, "write": B/s}} from two stat snapshots.

    ``current`` and ``previous`` are calculate_btrfs_stats() results taken
    ``elapsed`` seconds apart. UUIDs absent from ``previous`` get no entry.
    Negative deltas (totals shrink when a device's stat file becomes
    unreadable) are clamped to zero rather than shown as negative throughput.
    """
    rates = {}
    if elapsed <= 0:
        return rates
    for uuid, stats in current.items():
        prev = previous.get(uuid)
        if prev is None:
            continue
        rates[uuid] = {
            "read": max(stats["read_bytes"] - prev["read_bytes"], 0) / elapsed,
            "write": max(stats["write_bytes"] - prev["write_bytes"], 0) / elapsed,
        }
    return rates


def _draw_list(stdscr, uuids, first, count, selected_idx, fs_labels, use_labels, rates):
    """Draw ``count`` filesystem rows, starting at list index ``first``, under the header."""
    for row, idx in enumerate(range(first, first + count)):
        uuid = uuids[idx]
        label = fs_labels.get(uuid, uuid) if use_labels else uuid
        label_display = label if len(label) < 37 else label[:34] + "..."
        rate = rates.get(uuid, {"read": 0.0, "write": 0.0})
        line = (
            f"{label_display:<37}"
            f"{format_iec(rate['read']):>12}"
            f"{format_iec(rate['write']):>14}"
        )
        y = _HEADER_ROWS + row
        if idx == selected_idx:
            _addstr(stdscr, y, 0, ">" + line, curses.color_pair(4))
        else:
            _addstr(stdscr, y, 1, line)


def _draw_chart(stdscr, top, title, values):
    """Draw a titled bar chart of ``values`` (oldest first) starting at row ``top``.

    Row i is labelled with peak * (rows-1-i) / (rows-1) and shows "|" where a
    sample exceeds that label; the bottom row shows "_" for idle samples.
    """
    peak = max(max(values), 1)  # Avoid division by zero on an idle filesystem
    _addstr(stdscr, top, 11, title)
    for row in range(_CHART_ROWS):
        level = peak * (_CHART_ROWS - 1 - row) / (_CHART_ROWS - 1)
        is_bottom = row == _CHART_ROWS - 1
        cells = "".join(
            "|" if value > level else "_" if is_bottom else "."
            for value in values[-_HISTORY_LEN:]
        ).ljust(_HISTORY_LEN)
        _addstr(stdscr, top + 1 + row, 0, f"{format_iec(level):>10} {cells}")
    # "Seconds ago" labels, counting down from _HISTORY_LEN to 0 every 5 samples;
    # each right-aligned label ends over the chart column it refers to.
    x_axis = " ".join(f"{x:>4}" for x in range(_HISTORY_LEN, -1, -5))
    _addstr(stdscr, top + _CHART_ROWS + 1, 7, x_axis)


def display_ui(stdscr, btrfs_fs, fs_labels):
    """Run the interactive curses I/O monitor until the user presses 'q'.

    Stats are sampled every _SAMPLE_INTERVAL seconds and converted to bytes
    per second using the measured time between samples. Between samples the
    loop blocks in getch() with a timeout, so key presses take effect
    immediately. History is kept for every filesystem, so the chart is
    current whichever one is selected.

    Args:
        stdscr: the curses window supplied by curses.wrapper().
        btrfs_fs: mapping of filesystem UUID -> list of device stat paths.
        fs_labels: mapping of filesystem UUID -> display label.
    """
    try:
        curses.curs_set(0)
    except curses.error:
        pass  # Some terminals cannot hide the cursor; not fatal.
    init_colors()

    uuids = list(btrfs_fs)
    selected_idx = 0
    first_visible = 0  # Index of the top list row, for scrolling
    use_labels = True  # Toggle between UUIDs and labels
    history = {
        uuid: {"read": [0] * _HISTORY_LEN, "write": [0] * _HISTORY_LEN}
        for uuid in uuids
    }
    rates = {}

    # Prime the counters so the first displayed rate is a real delta rather
    # than the cumulative total.
    prev_stats = calculate_btrfs_stats(btrfs_fs)
    prev_time = time.monotonic()
    next_sample = prev_time + _SAMPLE_INTERVAL

    while True:
        # Sample stats when due (also while the terminal is too small, so
        # the history stays continuous).
        now = time.monotonic()
        if now >= next_sample:
            current_stats = calculate_btrfs_stats(btrfs_fs)
            rates = _compute_rates(current_stats, prev_stats, now - prev_time)
            prev_stats, prev_time = current_stats, now
            for uuid in uuids:
                rate = rates.get(uuid, {"read": 0.0, "write": 0.0})
                for kind in ("read", "write"):
                    series = history[uuid][kind]
                    series.append(rate[kind])
                    del series[:-_HISTORY_LEN]  # Keep only the last _HISTORY_LEN samples
            next_sample += _SAMPLE_INTERVAL
            if next_sample <= now:
                # Fell behind (e.g. the process was suspended): resync
                # instead of taking a burst of back-to-back samples.
                next_sample = now + _SAMPLE_INTERVAL

        height, width = stdscr.getmaxyx()
        stdscr.erase()  # erase() avoids the full repaint (flicker) of clear()

        if height < _MIN_HEIGHT or width < _MIN_WIDTH:
            _addstr(
                stdscr, 0, 0,
                f"Terminal size too small. Please resize to at least "
                f"{_MIN_WIDTH}x{_MIN_HEIGHT} (columns x rows).",
            )
        else:
            heading = curses.color_pair(1)
            _addstr(stdscr, 0, 0, "Btrfs Filesystem I/O Monitor (Press 'q' to quit, 'L' to toggle labels)", heading)
            _addstr(stdscr, 2, 0, f"{'Filesystem':<38}{'Read/s':>12}{'Write/s':>14}", heading)

            if not uuids:
                _addstr(stdscr, _HEADER_ROWS, 1, "No Btrfs filesystems found.")
            else:
                # Rows left for the list after the header, a gap row and
                # both charts; at least 1 given the size check above.
                list_rows = height - _HEADER_ROWS - 1 - _CHARTS_HEIGHT

                # Scroll so the selected filesystem is always visible.
                first_visible = min(first_visible, max(len(uuids) - list_rows, 0))
                if selected_idx < first_visible:
                    first_visible = selected_idx
                elif selected_idx >= first_visible + list_rows:
                    first_visible = selected_idx - list_rows + 1
                visible = min(list_rows, len(uuids) - first_visible)

                _draw_list(stdscr, uuids, first_visible, visible, selected_idx,
                           fs_labels, use_labels, rates)

                chart_top = _HEADER_ROWS + visible + 1  # One gap row below the list
                selected = history[uuids[selected_idx]]
                _draw_chart(stdscr, chart_top, "Read bytes/second", selected["read"])
                _draw_chart(stdscr, chart_top + _CHART_ROWS + 2, "Write bytes/second", selected["write"])

        stdscr.refresh()

        # Wait for a key, but no longer than until the next sample is due.
        remaining = next_sample - time.monotonic()
        stdscr.timeout(int(remaining * 1000) + 1 if remaining > 0 else 0)
        key = stdscr.getch()
        if key in (ord("q"), ord("Q")):
            break  # Quit on 'q' or 'Q'
        if key == curses.KEY_UP and selected_idx > 0:
            selected_idx -= 1
        elif key == curses.KEY_DOWN and selected_idx < len(uuids) - 1:
            selected_idx += 1
        elif key in (ord("l"), ord("L")):
            use_labels = not use_labels  # Toggle between UUID and label display

if __name__ == "__main__":
    # Discover filesystems once at startup; the UI then polls their stat files.
    btrfs_fs, fs_labels = fetch_btrfs_filesystems()

    # curses.wrapper sets up the terminal, forwards the extra positional
    # arguments to display_ui after stdscr, and restores the terminal on exit.
    curses.wrapper(display_ui, btrfs_fs, fs_labels)
 
 