# -*- coding: utf-8 -*-
# SPDX-License-Identifier: GPL-3.0-or-later

"""
btrfs_io: Netdata python.d collector for Btrfs I/O
--------------------------------------------------

Monitors I/O activity of Btrfs filesystems,
displaying IOPS and bandwidth for each filesystem
and its member block devices.

Chart families:
- btrfs: global  → overview charts with read/write IOPS and bandwidth per filesystem
- btrfs: details → per-filesystem charts with read/write IOPS and bandwidth per member device
"""

# Module metadata: description, author, license and version of this collector.
__description__ = "Netdata python.d collector for Btrfs I/O"
__author__      = "Forza <forza@tnonline.net>"
__license__     = "GPL-3.0-or-later"
__version__     = "0.0.1"

import os
import time
import uuid
from collections import defaultdict

from bases.FrameworkServices.SimpleService import SimpleService


# --- Netdata python.d module settings ---

# Ordering hint for this module's charts in the Netdata dashboard.
priority = 90000

# Retry count handed to the python.d framework. The author's note describes it
# as the number of check() retries during startup before the collector is
# disabled. NOTE(review): confirm these semantics against the python.d docs.
retries = 60

# --- sysfs locations ---

# Root of the Btrfs sysfs tree; its subdirectories are treated as FSIDs.
SYSFS_PATH_BTRFS = "/sys/fs/btrfs"

# Root of the block-device sysfs tree; counters are read from <dev>/stat.
SYSFS_PATH_BLOCK = "/sys/block"

# Unit of the sector counters in <dev>/stat. The Linux kernel always counts in
# 512-byte sectors, whatever the device's physical sector size (4 K drives,
# etc.), so bytes = sectors * SECTOR_SIZE when computing bandwidth.
SECTOR_SIZE = 512

###
# Helper functions
###

def _read_first_line(path):
    """Read and return the first line of a file, or None if missing."""
    try:
        with open(path, 'r') as f:
            return f.readline().strip()
    except Exception:
        return None


def _list_btrfs_fsids():
    """Return a list of Btrfs FSIDs from SYSFS_PATH_BTRFS."""
    base = SYSFS_PATH_BTRFS
    try:
        return [d for d in os.listdir(base) if os.path.isdir(os.path.join(base, d))]
    except Exception:
        return []


def _fs_label(fsid):
    """Return a human-readable name for the Btrfs filesystem *fsid*.

    Prefers the volume label from <SYSFS_PATH_BTRFS>/<fsid>/label. When that
    is missing or empty, falls back to the part of *fsid* before its first
    '-' (the first group of the UUID).
    """
    label = _read_first_line(os.path.join(SYSFS_PATH_BTRFS, fsid, 'label'))
    if label:
        return label
    short_id, _, _ = fsid.partition('-')
    return short_id


def _resolve_device_symlink(device_symlink_path):
    """Resolve the block device name from a SYSFS_PATH_BTRFS/.../device symlink."""
    try:
        real = os.path.realpath(device_symlink_path)
        parts = real.split(os.sep)
        for i in range(len(parts) - 1):
            if parts[i] == 'block' and i + 1 < len(parts):
                return parts[i + 1]
        # fallback: handle dev-mapper or NVMe paths
        for name in os.listdir(SYSFS_PATH_BLOCK):
            if real.endswith(os.sep + name):
                return name
    except Exception:
        pass
    return None


def _collect_btrfs_topology():
    """Map every Btrfs filesystem to its label and member block devices.

    Returns:
        dict keyed by FSID. Each value is ``{'label': str, 'devices': list}``,
        where ``devices`` is the sorted list of block-device names obtained
        from ``_resolve_device_symlink``. Filesystems for which no member
        device could be resolved are omitted.
    """
    topology = {}
    for fsid in _list_btrfs_fsids():
        members = set()
        devices_dir = os.path.join(SYSFS_PATH_BTRFS, fsid, 'devices')
        if os.path.isdir(devices_dir):
            try:
                for entry in os.listdir(devices_dir):
                    entry_path = os.path.join(devices_dir, entry)
                    if os.path.isdir(entry_path):
                        name = _resolve_device_symlink(os.path.join(entry_path, 'device'))
                        if name:
                            members.add(name)
            except Exception:
                # The devices directory can disappear mid-scan; keep whatever
                # members were collected before that happened.
                pass

        if not members:
            continue  # nothing to chart for this filesystem
        topology[fsid] = {'label': _fs_label(fsid), 'devices': sorted(members)}
    return topology


def _read_block_stat(block_name):
    """Read the I/O counters of one block device.

    Parses ``<SYSFS_PATH_BLOCK>/<block_name>/stat`` (a single line of
    whitespace-separated integer counters) and extracts:

    * field 0 -> ``reads``           (read I/Os completed)
    * field 2 -> ``sectors_read``    (sectors read, 512-byte units)
    * field 4 -> ``writes``          (write I/Os completed)
    * field 6 -> ``sectors_written`` (sectors written, 512-byte units)

    Returns:
        dict with the four keys above, or None when the file is missing,
        empty, or not in the expected format.
    """
    line = _read_first_line(os.path.join(SYSFS_PATH_BLOCK, block_name, 'stat'))
    if not line:
        return None
    try:
        fields = [int(x) for x in line.split()]
        return {
            'reads': fields[0],
            'sectors_read': fields[2],
            'writes': fields[4],
            'sectors_written': fields[6],
        }
    except (ValueError, IndexError):
        # Non-numeric or truncated line: treat it like an unreadable device.
        return None

###
# main Netdata collector class
###

class Service(SimpleService):
    """python.d job that charts Btrfs I/O rates.

    The filesystem/device topology is discovered once, in ``__init__``.
    Each ``get_data()`` call samples the block-device counters and takes the
    difference from the previous sample. It converts those deltas to
    per-second rates for every member device and, summed, for every
    filesystem.
    """

    def __init__(self, configuration=None, name=None):
        super(Service, self).__init__(configuration=configuration, name=name)
        self.prev = {}          # previous sample: device name -> counters dict or None
        self.last_ts = None     # time.monotonic() of the previous sample
        self.fs_map = _collect_btrfs_topology()  # fsid -> {'label': ..., 'devices': [...]}
        self.order = []         # chart ids, in creation order
        self.definitions = {}   # chart id -> python.d chart definition
        self._define_charts()

    @staticmethod
    def check():
        """Return True if at least one Btrfs filesystem with a resolvable member device exists.

        NOTE(review): this does its own sysfs scan, while ``self.fs_map`` (and
        therefore the chart set) is built only in ``__init__``. If the
        framework re-runs check() on an existing job, a filesystem that
        appeared after construction passes the check but gets no charts.
        """
        if not os.path.isdir(SYSFS_PATH_BTRFS):
            return False
        return bool(_collect_btrfs_topology())

    def _define_charts(self):
        """Populate ``self.order`` and ``self.definitions`` from ``self.fs_map``.

        Charts:
          * ``fs_iops`` / ``fs_bw`` (family 'btrfs: global'): one read and one
            write dimension per filesystem, summed over its member devices.
          * ``fs_<fsid>_devices_iops`` / ``fs_<fsid>_devices_bw`` (family
            'btrfs: details'): one read and one write dimension per member
            device of that filesystem.

        Write dimensions use multiplier -1, so writes are plotted as negative
        values. Dimension ids must match the keys produced by ``get_data()``.
        """
        filesystems = sorted(self.fs_map.items())

        # GLOBAL CHARTS: one read/write dimension pair per filesystem.
        dims_global_iops = []
        dims_global_bw = []
        for fsid, meta in filesystems:
            label = meta['label']
            dims_global_iops.append([f'fs_{fsid}_riops', f'{label} read', 'absolute', 1, 1])
            dims_global_iops.append([f'fs_{fsid}_wiops', f'{label} write', 'absolute', -1, 1])
            dims_global_bw.append([f'fs_{fsid}_rbw', f'{label} read', 'absolute', 1, 1])
            dims_global_bw.append([f'fs_{fsid}_wbw', f'{label} write', 'absolute', -1, 1])

        # python.d 'options': [name, title, units, family, context, chart_type]
        self.order.append('fs_iops')
        self.definitions['fs_iops'] = {
            'options': ['btrfs.fs_iops',
                        'Btrfs filesystems IOPS',
                        'IOPS',
                        'btrfs: global',
                        'btrfs.fs_iops',
                        'area'],
            'lines': dims_global_iops,
        }

        self.order.append('fs_bw')
        self.definitions['fs_bw'] = {
            'options': ['btrfs.fs_bw',
                        'Btrfs filesystems Bandwidth',
                        'bytes/s',
                        'btrfs: global',
                        'btrfs.fs_bw',
                        'area'],
            'lines': dims_global_bw,
        }

        # PER-FILESYSTEM CHARTS: one read/write dimension pair per member device.
        for fsid, meta in filesystems:
            label = meta['label']
            devs = meta['devices']

            iops_lines = []
            bw_lines = []
            for d in devs:
                iops_lines.append([f'dev_{fsid}_{d}_riops', f'{d} read', 'absolute', 1, 1])
                iops_lines.append([f'dev_{fsid}_{d}_wiops', f'{d} write', 'absolute', -1, 1])
                bw_lines.append([f'dev_{fsid}_{d}_rbw', f'{d} read', 'absolute', 1, 1])
                bw_lines.append([f'dev_{fsid}_{d}_wbw', f'{d} write', 'absolute', -1, 1])

            cid = f'fs_{fsid}_devices_iops'
            self.order.append(cid)
            self.definitions[cid] = {
                'options': [f'btrfs.{fsid}.iops',
                            f'Btrfs {label} IOPS',
                            'IOPS',
                            'btrfs: details',
                            f'btrfs.{fsid}.iops',
                            'area'],
                'lines': iops_lines,
            }

            cid = f'fs_{fsid}_devices_bw'
            self.order.append(cid)
            self.definitions[cid] = {
                'options': [f'btrfs.{fsid}.bw',
                            f'Btrfs {label} Bandwidth',
                            'bytes/s',
                            'btrfs: details',
                            f'btrfs.{fsid}.bw',
                            'area'],
                'lines': bw_lines,
            }

    # Sampling
    def _delta(self, cur, prev):
        """Return the per-counter increase from sample *prev* to sample *cur*.

        Every counter is 0 when either sample is None (no previous sample, or
        the device could not be read). A negative difference (e.g. a counter
        reset) is also reported as 0.
        """
        keys = ('reads', 'writes', 'sectors_read', 'sectors_written')
        if cur is None or prev is None:
            return dict.fromkeys(keys, 0)
        return {k: max(cur.get(k, 0) - prev.get(k, 0), 0) for k in keys}

    def _sample_rawstat(self):
        """Read the current counters of every member device.

        Returns a dict mapping each device name to its ``_read_block_stat``
        result (None when unreadable). A device listed under several
        filesystems is read only once per sample.
        """
        rawstat = {}
        for meta in self.fs_map.values():
            for dev in meta['devices']:
                if dev not in rawstat:
                    rawstat[dev] = _read_block_stat(dev)
        return rawstat

    def get_data(self):
        """Collect one sample and return the python.d data dict.

        Keys are the dimension ids defined in ``_define_charts()``. Values are
        integer per-second rates (IOPS, or bytes/s) computed from the counter
        deltas since the previous call. The first call has no previous sample
        and reports 0 for every dimension.
        """
        # Use a monotonic clock: wall-clock (time.time) steps would distort dt.
        # A backwards step would be clamped to 1 ms and yield a huge
        # spurious rate.
        now = time.monotonic()
        if self.last_ts is None:
            self.last_ts = now
        dt = max(0.001, now - self.last_ts)

        rawstat = self._sample_rawstat()

        data = {}
        for fsid, meta in self.fs_map.items():
            totals = {'riops': 0.0, 'wiops': 0.0, 'rbw': 0.0, 'wbw': 0.0}
            for dev in meta['devices']:
                d = self._delta(rawstat.get(dev), self.prev.get(dev))
                rates = {
                    'riops': d['reads'] / dt,
                    'wiops': d['writes'] / dt,
                    'rbw': (d['sectors_read'] * SECTOR_SIZE) / dt,
                    'wbw': (d['sectors_written'] * SECTOR_SIZE) / dt,
                }
                for key, value in rates.items():
                    totals[key] += value
                    # Round rather than truncate. dt is usually slightly larger
                    # than the update interval, so int() would under-report
                    # systematically (e.g. 5 I/Os over 1.001 s -> 4).
                    data[f'dev_{fsid}_{dev}_{key}'] = round(value)
            for key, value in totals.items():
                data[f'fs_{fsid}_{key}'] = round(value)

        # Remember this sample as the baseline for the next call.
        self.prev = rawstat
        self.last_ts = now
        return data
