#!/bin/python3
import os
import sys
import logging
import shutil
import argparse
import subprocess
from tempfile import gettempdir
from subprocess import Popen
from pathlib import Path
from functools import lru_cache

# Module-level logger for this script; main() attaches the console and
# log-file handlers and sets its level.
logger = logging.getLogger(__name__)

@lru_cache()
def tmp_dir() -> Path:
    """Return the scratch directory used by this test run.

    The directory is ``<system temp dir>/lief-binaryninja-tests``. On the
    first call any existing directory with that name is removed (best
    effort, errors ignored) and recreated; the result is cached, so later
    calls return the same Path without touching the filesystem again.
    """
    scratch = Path(gettempdir()).joinpath("lief-binaryninja-tests")
    stale = scratch.is_dir()
    if stale:
        # Start from a clean slate; leftovers that cannot be removed are tolerated.
        shutil.rmtree(scratch, ignore_errors=True)
    scratch.mkdir(exist_ok=True)
    return scratch

@lru_cache()
def _lief_samples_dir() -> Path:
    """Return the LIEF samples directory taken from ``$LIEF_SAMPLES_DIR``.

    Raises KeyError when the variable is not set, and fails an assertion
    when it does not name an existing directory. The result is cached.
    """
    raw_value = os.environ["LIEF_SAMPLES_DIR"]
    samples_root = Path(raw_value)
    assert samples_root.is_dir()
    return samples_root

@lru_cache()
def find_program(name: str) -> Path:
    """Locate the executable *name*.

    The directories from ``$PATH`` are searched first, then
    ``$LIEF_SAMPLES_DIR/ELF``. The result is cached per name.

    Raises:
        FileNotFoundError: if no executable called *name* is found.
    """
    # Drop empty entries: shutil.which() joins "" with the program name,
    # i.e. searches the current directory, and an unset PATH would
    # otherwise turn into exactly such an entry.
    search_dirs = [d for d in os.getenv("PATH", "").split(os.pathsep) if d]
    search_dirs.append(str(_lief_samples_dir() / "ELF"))

    target = shutil.which(name, path=os.pathsep.join(search_dirs))
    # An explicit raise (rather than assert) still fires under `python -O`.
    if target is None:
        raise FileNotFoundError(
            f"Can't find program '{name}' (searched: {os.pathsep.join(search_dirs)})"
        )
    return Path(target)

@lru_cache()
def find_bndb(name: str) -> Path:
    """Locate the Binary Ninja database file *name* in the samples directory.

    ``$LIEF_SAMPLES_DIR/private/binaryninja`` is searched before
    ``$LIEF_SAMPLES_DIR/binaryninja``; the first existing file wins. The
    result is cached per name.

    Raises:
        FileNotFoundError: if *name* exists in neither location.
    """
    samples = _lief_samples_dir()
    candidates = [
        samples / "private/binaryninja" / name,
        samples / "binaryninja" / name,
    ]

    for candidate in candidates:
        if candidate.is_file():
            logger.debug("%s -> %s", name, candidate.resolve().absolute())
            return candidate

    searched = ", ".join(str(candidate) for candidate in candidates)
    raise FileNotFoundError(f"Can't find path for '{name}' (searched: {searched})")

@lru_cache()
def find_check_file(name: str) -> Path:
    """Locate the FileCheck pattern file *name* in the samples directory.

    ``$LIEF_SAMPLES_DIR/private/binaryninja`` is searched before
    ``$LIEF_SAMPLES_DIR/binaryninja``; the first existing file wins. The
    result is cached per name.

    Raises:
        FileNotFoundError: if *name* exists in neither location.
    """
    samples = _lief_samples_dir()
    candidates = [
        samples / "private/binaryninja" / name,
        samples / "binaryninja" / name,
    ]

    for candidate in candidates:
        if candidate.is_file():
            logger.debug("%s -> %s", name, candidate.resolve().absolute())
            return candidate

    searched = ", ".join(str(candidate) for candidate in candidates)
    raise FileNotFoundError(f"Can't find path for '{name}' (searched: {searched})")

@lru_cache()
def ld_library_path() -> str:
    """Return the ``LD_LIBRARY_PATH`` value used to run the diff tool.

    The current ``$LD_LIBRARY_PATH`` entries are kept, and the LIEF build
    directory, located relative to the diff tool as
    ``<tool dir>/../../../build/LIEF``, is appended. The result is cached.
    """
    diff_tool = find_program("lief-diff-analysis-tool-linux-x86_64")

    # Skip empty entries: "".split(":") yields [""] when the variable is
    # unset, which would produce a leading ":"; the dynamic loader reads an
    # empty element as the current working directory.
    search_dirs = [
        d for d in os.getenv("LD_LIBRARY_PATH", "").split(os.pathsep) if d
    ]
    lief_build_dir = (diff_tool.parent / "../../../build/LIEF/").resolve().absolute()
    search_dirs.append(lief_build_dir.as_posix())
    return os.pathsep.join(search_dirs)

def _run_and_record(cmd, env: dict, out_dir: Path, prefix: str) -> int:
    """Run *cmd*, log its output, save it as ``<prefix>.std{out,err}.txt`` in *out_dir*, and return its exit code."""
    # subprocess.run() drains stdout and stderr concurrently. Reading the
    # two pipes one after the other can deadlock once the child fills the
    # pipe that is not being read yet (FileCheck -vv writes a large trace
    # to stderr).
    result = subprocess.run(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        universal_newlines=True,
        env=env,
    )

    logger.debug("-- STDOUT --\n%s", result.stdout)
    logger.debug("-- STDERR --\n%s", result.stderr)

    (out_dir / f"{prefix}.stdout.txt").write_text(result.stdout)
    (out_dir / f"{prefix}.stderr.txt").write_text(result.stderr)
    return result.returncode

def _resolve_sample(value, finder) -> Path:
    """Return *value* if it is an existing file Path, otherwise look it up with *finder*."""
    # Callers may pass an already-resolved Path; re-resolving it through the
    # samples directory breaks when LIEF_SAMPLES_DIR is a relative path.
    if isinstance(value, Path) and value.is_file():
        return value
    return finder(value)

def run_test(name: str, bndb: Path, check_file: Path, verbose: bool):
    """Run the diff-analysis tool on one database and verify its output with FileCheck.

    Args:
        name: Test name; also the sub-directory of tmp_dir() that receives
            the tool outputs and captured stdout/stderr.
        bndb: Database file name (looked up with find_bndb) or an
            already-resolved Path to it.
        check_file: FileCheck pattern file name (looked up with
            find_check_file) or an already-resolved Path to it.
        verbose: Currently unused.

    Raises:
        AssertionError: if the diff tool or FileCheck exits with a non-zero code.
        FileNotFoundError: if a program or sample file cannot be located.
    """
    logger.info("== %s ==", name)

    filecheck = find_program("FileCheck")
    diff_tool = find_program("lief-diff-analysis-tool-linux-x86_64")

    bndb_path = _resolve_sample(bndb, find_bndb)
    check_file_path = _resolve_sample(check_file, find_check_file)

    tmp_test_dir = tmp_dir() / name
    tmp_test_dir.mkdir(exist_ok=True)

    original = tmp_test_dir / "original.txt"
    new = tmp_test_dir / "new.txt"

    env = dict(os.environ)
    env['LD_LIBRARY_PATH'] = ld_library_path()

    # The diff tool is given both output paths; `new` is what FileCheck
    # verifies below (presumably the tool writes both files — confirm).
    returncode = _run_and_record(
        [diff_tool, bndb_path, original, new], env, tmp_test_dir, "analysis"
    )
    # Explicit raise instead of assert so failures are not skipped under -O.
    if returncode != 0:
        raise AssertionError(f"{name}: {diff_tool.name} exited with code {returncode}")

    file_check_args = [
        filecheck,
        '-vv',
        '--color',
        "--input-file", new,
        check_file_path
    ]
    logger.debug("FileCheck: %s", " ".join(map(str, file_check_args)))

    returncode = _run_and_record(file_check_args, env, tmp_test_dir, "file_check")
    if returncode != 0:
        raise AssertionError(f"{name}: FileCheck exited with code {returncode}")

def main() -> int:
    """Parse the command line, configure logging and run the selected tests.

    ``--name`` runs a single test; without it every test runs in order.
    ``--verbose`` switches logging (console and ``log.txt`` in tmp_dir())
    from INFO to DEBUG. An unknown ``--name`` is a usage error (exit 2).

    Returns:
        0 once every selected test has run; failing tests raise instead.
    """
    # (test name, database file name, FileCheck file name). The file names
    # are resolved under $LIEF_SAMPLES_DIR by run_test.
    tests = (
        ("arm64ec_hello_world_2025",
         "arm64ec_hello_world_2025.exe.bndb",
         "arm64ec_hello_world_2025.check"),
        ("hostfxr",
         "hostfxr.dll.bndb",
         "hostfxr.check"),
        ("hello-android-packed.aarch64.elf",
         "hello-android-packed.aarch64.elf.bndb",
         "hello-android-packed.aarch64.elf.check"),
        ("libnl.so",
         "libnl.so.bndb",
         "libnl.so.check"),
    )
    known_names = [test[0] for test in tests]

    parser = argparse.ArgumentParser()
    parser.add_argument("--verbose", action='store_true', default=False)
    parser.add_argument("--name", type=str, default='',
                        help="run only this test (default: run all tests)")
    args = parser.parse_args()

    test_name = args.name
    # Reject unknown names: otherwise nothing would run and the script would
    # still exit 0, making a typo look like a passing run.
    if test_name and test_name not in known_names:
        parser.error(
            f"unknown test name '{test_name}'; expected one of: {', '.join(known_names)}"
        )

    log_level = logging.DEBUG if args.verbose else logging.INFO

    formatter = logging.Formatter('%(asctime)s - %(funcName)s - %(levelname)s - %(message)s')

    stdout = logging.StreamHandler(stream=sys.stdout)
    stdout.setLevel(log_level)
    stdout.setFormatter(formatter)

    fh = logging.FileHandler(tmp_dir() / "log.txt")
    fh.setFormatter(formatter)
    fh.setLevel(log_level)

    logger.addHandler(stdout)
    logger.addHandler(fh)
    logger.setLevel(log_level)

    for name, bndb, check_file in tests:
        if test_name in ('', name):
            run_test(name, bndb, check_file, args.verbose)
    return 0

if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    raise SystemExit(main())
