#!/usr/bin/env python3
# group: rw vvfat
#
# Test vvfat driver implementation
# Here, we use a simple FAT16 implementation and check the behavior of
# the vvfat driver.
#
# Copyright (C) 2024 Amjad Alsharafi <amjadsharafi10@gmail.com>
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import os
import shutil
import iotests
from iotests import imgfmt, QMPTestCase
from fat16 import MBR, Fat16, DIRENTRY_SIZE

# Host directory that is handed to the vvfat driver as its "dir" option
filesystem = os.path.join(iotests.test_dir, "filesystem")

# UNIX socket (and matching URI) of the NBD server that exports the "disk"
# node, so that qemu-io can access the raw sectors
nbd_sock = iotests.file_path("nbd.sock", base_dir=iotests.sock_dir)
nbd_uri = f"nbd+unix:///disk?socket={nbd_sock}"

# Sector size in bytes, used to convert sector numbers to byte offsets
SECTOR_SIZE = 512


class TestVVFatDriver(QMPTestCase):
    # pylint: disable=broad-exception-raised
    def setUp(self) -> None:
        """
        Prepare a fresh host directory, launch a VM with that directory as
        a writable FAT16 block node named "disk", export the node over NBD
        and open an interactive qemu-io session on the export.
        """
        # Start from an empty directory; anything else at that path that is
        # not a directory is treated as an error
        if os.path.exists(filesystem):
            if not os.path.isdir(filesystem):
                raise Exception(f"{filesystem} exists and is not a directory")
            shutil.rmtree(filesystem)

        os.mkdir(filesystem)

        # Add some text files to the filesystem
        for idx in range(10):
            small_path = os.path.join(filesystem, f"file{idx}.txt")
            with open(small_path, "w", encoding="ascii") as f:
                f.write(f"Hello, world! {idx}\n")

        # Add 2 large files, above the cluster size (8KB): large1.txt takes
        # two clusters and large2.txt three.  The content is
        # 'A' * 1KB, 'B' * 1KB, 'C' * 1KB, ...
        for name, clusters in (("large1.txt", 2), ("large2.txt", 3)):
            with open(os.path.join(filesystem, name), "wb") as f:
                for kb in range(8 * clusters):
                    f.write(bytes([0x41 + kb] * 1024))

        blockdev_opts = {
            "driver": imgfmt,
            "node-name": "disk",
            "rw": "true",
            "fat-type": "16",
            "dir": filesystem,
        }

        self.vm = iotests.VM()
        self.vm.add_blockdev(self.vm.qmp_to_opts(blockdev_opts))
        self.vm.launch()

        # Bitmap that is attached to the NBD export below
        self.vm.qmp_log(
            "block-dirty-bitmap-add",
            node="disk",
            name="bitmap0",
        )

        # attach nbd server
        nbd_addr = {"type": "unix", "data": {"path": nbd_sock}}
        self.vm.qmp_log("nbd-server-start", addr=nbd_addr, filters=[])

        self.vm.qmp_log(
            "nbd-server-add",
            device="disk",
            writable=True,
            bitmap="bitmap0",
        )

        # All sector reads and writes of the tests go through this session
        self.qio = iotests.QemuIoInteractive("-f", "raw", nbd_uri)

    def tearDown(self) -> None:
        """
        Close the qemu-io session, shut the VM down and delete the host
        directory.

        Each step runs even if an earlier one raises, so a failure while
        closing qemu-io does not leave the VM running or the directory
        behind for the next test.
        """
        try:
            self.qio.close()
        finally:
            try:
                self.vm.shutdown()
                # When debugging, the VM log can be dumped at this point
                # with print(self.vm.get_log())
            finally:
                shutil.rmtree(filesystem)

    def read_sectors(self, sector: int, num: int = 1) -> bytes:
        """
        Read `num` sectors starting from `sector` from the `disk`.

        The data is fetched with the `read -v` command of the
        `QemuIoInteractive` session, whose hex dump output is parsed back
        into bytes.  Returns exactly `num * SECTOR_SIZE` bytes.
        """
        self.assertGreater(num, 0)

        offset = sector * SECTOR_SIZE
        length = num * SECTOR_SIZE
        dump = self.qio.cmd(f"read -v {offset} {length}")

        # Each dump row holds 16 bytes, so only the first `length // 16`
        # lines carry data; whatever follows them is ignored
        row_count = length // 16

        hex_columns = []
        for line in dump.split("\n")[:row_count]:
            # Drop everything up to the first ": ", then keep the part in
            # front of the next double space: that is the hex column
            after_offset = line.split(": ")[1]
            hex_columns.append(after_offset.split("  ")[0])

        data = bytes.fromhex("".join(hex_columns))

        self.assertEqual(len(data), length)

        return data

    def write_sectors(self, sector: int, data: bytes) -> None:
        """
        Write `data` to the `disk` starting from `sector`.

        `data` must be non-empty and a multiple of `SECTOR_SIZE` long.  It is
        staged in a temporary file inside the test directory and written
        with the `write -s` command of the `QemuIoInteractive` session.  The
        temporary file is removed even when staging or writing fails.
        """

        self.assertGreater(len(data), 0)
        self.assertEqual(len(data) % SECTOR_SIZE, 0)

        temp_file = os.path.join(iotests.test_dir, "temp.bin")
        try:
            with open(temp_file, "wb") as f:
                f.write(data)

            self.qio.cmd(
                f"write -s {temp_file} {sector * SECTOR_SIZE} {len(data)}"
            )
        finally:
            # Do not leave the staging file behind on errors; the existence
            # check keeps a failed open() from being masked by os.remove()
            if os.path.exists(temp_file):
                os.remove(temp_file)

    def init_fat16(self):
        """
        Build a `Fat16` object for the first partition listed in the MBR of
        the disk, wired to `read_sectors` and `write_sectors` for its I/O.
        """
        first_partition = MBR(self.read_sectors(0)).partition_table[0]
        return Fat16(
            first_partition["start_lba"],
            first_partition["size"],
            self.read_sectors,
            self.write_sectors,
        )

    # Tests

    def test_fat_filesystem(self):
        """
        Test that vvfat produces valid MBR and FAT16 boot sectors
        """
        partition = MBR(self.read_sectors(0)).partition_table[0]

        # Status and type values expected for the first partition entry
        self.assertEqual(partition["status"], 0x80)
        self.assertEqual(partition["type"], 6)

        fs = Fat16(
            partition["start_lba"],
            partition["size"],
            self.read_sectors,
            self.write_sectors,
        )
        boot = fs.boot_sector
        self.assertEqual(boot.bytes_per_sector, 512)
        self.assertEqual(boot.volume_label, "QEMU VVFAT")

    def test_read_root_directory(self):
        """
        Test the content of the root directory

        The root directory must contain exactly the files created in `setUp`
        plus the special empty entry, each of them once and with the right
        size.
        """
        fat16 = self.init_fat16()

        root_dir = fat16.read_root_directory()

        self.assertEqual(len(root_dir), 13)  # 12 + 1 special file

        files = {
            "QEMU VVF.AT": 0,  # special empty file
            "FILE0.TXT": 16,
            "FILE1.TXT": 16,
            "FILE2.TXT": 16,
            "FILE3.TXT": 16,
            "FILE4.TXT": 16,
            "FILE5.TXT": 16,
            "FILE6.TXT": 16,
            "FILE7.TXT": 16,
            "FILE8.TXT": 16,
            "FILE9.TXT": 16,
            "LARGE1.TXT": 0x2000 * 2,
            "LARGE2.TXT": 0x2000 * 3,
        }

        # Comparing the complete name -> size mapping (together with the
        # length check above) also catches a duplicated or missing entry,
        # which a per-entry membership check would let through
        found = {entry.whole_name(): entry.size_bytes for entry in root_dir}
        self.assertEqual(found, files)

    def test_direntry_as_bytes(self):
        """
        Test that a Direntry converts back to the bytes it was parsed from,
        so that it can be written back to the disk safely.
        """
        fs = self.init_fat16()

        entries = fs.read_root_directory()
        root_start = fs.boot_sector.root_dir_start()
        raw_sector = fs.read_sectors(root_start, 1)

        # The first entry won't be deleted, so the first parsed entry
        # corresponds to the first DIRENTRY_SIZE bytes of the root directory
        self.assertEqual(entries[0].as_bytes(), raw_sector[:DIRENTRY_SIZE])

    def test_read_files(self):
        """
        Test reading the content of the files

        The small files are compared with the text written in `setUp`, the
        large files with their copies in the host directory.
        """
        fat16 = self.init_fat16()

        for i in range(10):
            file = fat16.find_direntry(f"/FILE{i}.TXT")
            self.assertIsNotNone(file)
            self.assertEqual(
                fat16.read_file(file), f"Hello, world! {i}\n".encode("ascii")
            )

        # test large files
        for host_name in ("large1.txt", "large2.txt"):
            large = fat16.find_direntry(f"/{host_name.upper()}")
            # Check both large files the same way, so that a missing entry
            # is reported here and not by whatever read_file() does with it
            self.assertIsNotNone(large)
            with open(os.path.join(filesystem, host_name), "rb") as f:
                self.assertEqual(fat16.read_file(large), f.read())

    def test_write_file_same_content_direct(self):
        """
        Similar to `test_write_file_in_same_content`, but the cluster of the
        file is written directly, so the direntry is not modified along the
        way.
        """
        fs = self.init_fat16()

        entry = fs.find_direntry("/FILE0.TXT")
        self.assertIsNotNone(entry)

        # Write the cluster back with exactly the bytes it already holds
        cluster = entry.cluster
        unchanged = fs.read_cluster(cluster)
        fs.write_cluster(cluster, unchanged)

        host_path = os.path.join(filesystem, "file0.txt")
        with open(host_path, "rb") as f:
            self.assertEqual(fs.read_file(entry), f.read())

    def test_write_file_in_same_content(self):
        """
        Test writing back to a file the content it already has
        """
        fs = self.init_fat16()

        entry = fs.find_direntry("/FILE0.TXT")
        self.assertIsNotNone(entry)

        self.assertEqual(fs.read_file(entry), b"Hello, world! 0\n")

        # Neither the guest view nor the host file may change
        fs.write_file(entry, b"Hello, world! 0\n")
        self.assertEqual(fs.read_file(entry), b"Hello, world! 0\n")

        host_path = os.path.join(filesystem, "file0.txt")
        with open(host_path, "rb") as f:
            self.assertEqual(f.read(), b"Hello, world! 0\n")

    def test_modify_content_same_clusters(self):
        """
        Test replacing the content of a file while keeping the same number
        of clusters
        """
        fs = self.init_fat16()

        entry = fs.find_direntry("/FILE0.TXT")
        self.assertIsNotNone(entry)

        self.assertEqual(fs.read_file(entry), b"Hello, world! 0\n")

        modified = b"Hello, world! Modified\n"
        fs.write_file(entry, modified)
        self.assertEqual(fs.read_file(entry), modified)

        # The change must have reached the file in the host directory
        host_path = os.path.join(filesystem, "file0.txt")
        with open(host_path, "rb") as f:
            self.assertEqual(f.read(), modified)

    def test_truncate_file_same_clusters_less(self):
        """
        Test shrinking a file without changing its number of clusters
        """
        fs = self.init_fat16()

        entry = fs.find_direntry("/FILE0.TXT")
        self.assertIsNotNone(entry)

        self.assertEqual(fs.read_file(entry), b"Hello, world! 0\n")

        # Keep only the first 5 bytes
        fs.truncate_file(entry, 5)
        truncated = fs.read_file(entry)
        self.assertEqual(truncated, b"Hello")

        host_path = os.path.join(filesystem, "file0.txt")
        with open(host_path, "rb") as f:
            self.assertEqual(f.read(), truncated)

    def test_truncate_file_same_clusters_more(self):
        """
        Test growing a file without changing its number of clusters
        """
        fs = self.init_fat16()

        entry = fs.find_direntry("/FILE0.TXT")
        self.assertIsNotNone(entry)

        self.assertEqual(fs.read_file(entry), b"Hello, world! 0\n")

        fs.truncate_file(entry, 20)
        grown = fs.read_file(entry)
        self.assertIsNotNone(grown)

        # The bytes appended by the extension are not always the same, so
        # only the original prefix and the new length are checked
        self.assertEqual(grown[:16], b"Hello, world! 0\n")
        self.assertEqual(len(grown), 20)

        # Whatever was appended, the host file must match the guest view
        host_path = os.path.join(filesystem, "file0.txt")
        with open(host_path, "rb") as f:
            self.assertEqual(f.read(), grown)

    def test_write_large_file(self):
        """
        Test overwriting a multi-cluster file with content of the same size
        """
        fs = self.init_fat16()

        entry = fs.find_direntry("/LARGE1.TXT")
        self.assertIsNotNone(entry)

        # LARGE1 holds A * 1KB, B * 1KB, C * 1KB, ..., P * 1KB.
        # Replace it with Z * 1KB, Y * 1KB, X * 1KB, ..., K * 1KB, which
        # keeps both the file size and the number of clusters
        reversed_runs = b"".join(
            bytes([0x5A - step]) * 1024 for step in range(16)
        )
        fs.write_file(entry, reversed_runs)
        self.assertEqual(fs.read_file(entry), reversed_runs)

        host_path = os.path.join(filesystem, "large1.txt")
        with open(host_path, "rb") as f:
            self.assertEqual(f.read(), reversed_runs)

    def test_truncate_file_change_clusters_less(self):
        """
        Test truncating a file so that it needs fewer clusters
        """
        fs = self.init_fat16()

        entry = fs.find_direntry("/LARGE1.TXT")
        self.assertIsNotNone(entry)

        # Cut the file down to a single byte; LARGE1 starts with 'A'
        fs.truncate_file(entry, 1)
        self.assertEqual(fs.read_file(entry), b"A")

        host_path = os.path.join(filesystem, "large1.txt")
        with open(host_path, "rb") as f:
            self.assertEqual(f.read(), b"A")

    def test_write_file_change_clusters_less(self):
        """
        Test writing content that needs fewer clusters than the file had
        """
        fs = self.init_fat16()

        entry = fs.find_direntry("/LARGE2.TXT")
        self.assertIsNotNone(entry)

        # LARGE2 was created with 3 clusters of 8KB; write only 2
        cluster_bytes = 8 * 1024
        shorter = b"X" * cluster_bytes + b"Y" * cluster_bytes
        fs.write_file(entry, shorter)
        self.assertEqual(fs.read_file(entry), shorter)

        host_path = os.path.join(filesystem, "large2.txt")
        with open(host_path, "rb") as f:
            self.assertEqual(f.read(), shorter)

    def test_write_file_change_clusters_more(self):
        """
        Test writing content that needs more clusters than the file had
        """
        fs = self.init_fat16()

        entry = fs.find_direntry("/LARGE2.TXT")
        self.assertIsNotNone(entry)

        # from 3 clusters to 4 clusters, one letter per 8KB cluster
        cluster_bytes = 8 * 1024
        longer = b"".join(
            letter * cluster_bytes for letter in (b"W", b"X", b"Y", b"Z")
        )
        fs.write_file(entry, longer)
        self.assertEqual(fs.read_file(entry), longer)

        host_path = os.path.join(filesystem, "large2.txt")
        with open(host_path, "rb") as f:
            self.assertEqual(f.read(), longer)

    def test_write_file_change_clusters_more_non_contiguous_2_mappings(self):
        """
        Test growing a file by one cluster and then writing it.  The new
        cluster is allocated non-contiguously, so that the file ends up
        with 2 cluster mappings
        """
        fs = self.init_fat16()

        entry = fs.find_direntry("/LARGE1.TXT")
        self.assertIsNotNone(entry)

        # from 2 clusters to 3 clusters with non-contiguous allocation
        fs.truncate_file(entry, 3 * 0x2000, allocate_non_continuous=True)

        cluster_bytes = 8 * 1024
        payload = b"".join(
            letter * cluster_bytes for letter in (b"X", b"Y", b"Z")
        )
        fs.write_file(entry, payload)
        self.assertEqual(fs.read_file(entry), payload)

        host_path = os.path.join(filesystem, "large1.txt")
        with open(host_path, "rb") as f:
            self.assertEqual(f.read(), payload)

    def test_write_file_change_clusters_more_non_contiguous_3_mappings(self):
        """
        Test growing a file by two clusters and then writing it.  The new
        clusters are allocated non-contiguously, so that the file ends up
        with 3 cluster mappings
        """
        fs = self.init_fat16()

        entry = fs.find_direntry("/LARGE1.TXT")
        self.assertIsNotNone(entry)

        # from 2 clusters to 4 clusters with non-contiguous allocation
        fs.truncate_file(entry, 4 * 0x2000, allocate_non_continuous=True)

        cluster_bytes = 8 * 1024
        payload = b"".join(
            letter * cluster_bytes for letter in (b"W", b"X", b"Y", b"Z")
        )
        fs.write_file(entry, payload)
        self.assertEqual(fs.read_file(entry), payload)

        host_path = os.path.join(filesystem, "large1.txt")
        with open(host_path, "rb") as f:
            self.assertEqual(f.read(), payload)

    def test_create_file(self):
        """
        Test creating a new file and writing to it
        """
        fs = self.init_fat16()

        created = fs.create_file("/NEWFILE.TXT")

        # A new file starts out empty
        self.assertIsNotNone(created)
        self.assertEqual(created.size_bytes, 0)

        text = b"Hello, world! New file\n"
        fs.write_file(created, text)
        self.assertEqual(fs.read_file(created), text)

        # The file must have appeared in the host directory
        host_path = os.path.join(filesystem, "newfile.txt")
        with open(host_path, "rb") as f:
            self.assertEqual(f.read(), text)

    # TODO: support deleting files


if __name__ == "__main__":
    # This test is specific to the vvfat driver, so the only supported format
    # is "vvfat" (with the "file" protocol)
    iotests.main(supported_fmts=["vvfat"], supported_protocols=["file"])
