import argparse
import os
import stat
import subprocess
import sys
import tempfile


# --- Fixed SCHP inference settings -------------------------------------------
DATASET = "lip"  # label-set name passed as --dataset (the checkpoint file name also says "lip")
GPU_ID = "0"  # passed as --gpu and exported as CUDA_VISIBLE_DEVICES for the subprocess
# Suffixes of files treated as input images; compared against the lower-cased file name.
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".webp")

# --- Paths, anchored at this script's own directory --------------------------
# Using abspath(__file__) makes every path independent of the caller's CWD.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
DEFAULT_INPUT_DIR = os.path.join(SCRIPT_DIR, "inputs")
DEFAULT_OUTPUT_DIR = os.path.join(SCRIPT_DIR, "outputs")
SCHP_DIR = os.path.join(SCRIPT_DIR, "SCHP")  # checkout of the SCHP project
SIMPLE_EXTRACTOR = os.path.join(SCHP_DIR, "simple_extractor.py")
CHECKPOINT = os.path.join(SCHP_DIR, "checkpoints/exp-schp-201908261155-lip.pth")


def parse_args(argv=None):
    """Parse the command-line options of this wrapper.

    Args:
        argv: list of argument strings to parse, without the program name.
            ``None`` (the default) lets argparse read ``sys.argv[1:]``, which
            is exactly what the script does when run from the command line.

    Returns:
        argparse.Namespace with ``input_dir`` and ``output_dir``, defaulting
        to ``DEFAULT_INPUT_DIR`` and ``DEFAULT_OUTPUT_DIR``.
    """
    parser = argparse.ArgumentParser()
    # Help text (Korean): "input image folder".
    parser.add_argument("--input-dir", default=DEFAULT_INPUT_DIR, help="입력 이미지 폴더")
    # Help text (Korean): "folder where clothing-segmentation results are saved".
    parser.add_argument("--output-dir", default=DEFAULT_OUTPUT_DIR, help="의상 분리 결과 저장 폴더")
    return parser.parse_args(argv)


def _ensure_private_dir(path):
    """Create ``path`` if missing and verify that only the current user controls it.

    The directories used here live in the shared system temp dir under a
    predictable name. Another local user could therefore create one first, or
    plant a symlink there, and have this process read files they control. That
    includes whatever torch finds in TORCH_EXTENSIONS_DIR. ``exist_ok=True``
    alone would silently accept such a directory, so its ownership and mode are
    checked explicitly.

    Args:
        path: directory to create or reuse.

    Returns:
        ``path`` unchanged, for convenient chaining.

    Raises:
        PermissionError: ``path`` is a symlink or non-directory, is not owned by
            the current user, or is writable by group/others.
    """
    # 0o700 for newly created dirs. Dirs left by earlier runs with 0o755 are
    # still accepted below, because they are not group/other-writable.
    os.makedirs(path, mode=0o700, exist_ok=True)
    info = os.lstat(path)  # lstat: do not follow a planted symlink
    if (
        not stat.S_ISDIR(info.st_mode)
        or info.st_uid != os.getuid()
        or info.st_mode & (stat.S_IWGRP | stat.S_IWOTH)
    ):
        raise PermissionError(f"안전하지 않은 임시 폴더입니다 (소유자/권한 확인 필요): {path}")
    return path


def build_env():
    """Return the environment mapping for the SCHP subprocess.

    Starts from a copy of ``os.environ`` and then:

    * points HOME and XDG_CACHE_HOME at a per-user cache dir, and
      TORCH_EXTENSIONS_DIR at a per-user extensions dir. Both are created under
      the system temp dir and validated by ``_ensure_private_dir``.
    * sets TORCH_CUDA_ARCH_LIST to "8.9" unless the caller already set it.
    * restricts the subprocess to GPU ``GPU_ID`` via CUDA_VISIBLE_DEVICES.
    * puts the directory of the running interpreter first on PATH.

    The parent process's ``os.environ`` is not modified.

    Raises:
        PermissionError: a cache directory exists but is not safely owned.
    """
    uid = os.getuid()
    temp_root = tempfile.gettempdir()
    cache_home = _ensure_private_dir(os.path.join(temp_root, f"schp_cache_{uid}"))
    torch_extensions_dir = _ensure_private_dir(
        os.path.join(temp_root, f"schp_torch_extensions_{uid}")
    )

    env = os.environ.copy()
    # HOME is redirected too, presumably so tools that write under ~ (not just
    # XDG caches) also land in the temp cache. TODO confirm the intent.
    env["HOME"] = cache_home
    env["XDG_CACHE_HOME"] = cache_home
    env["TORCH_EXTENSIONS_DIR"] = torch_extensions_dir
    # "8.9" presumably matches the author's GPU (compute capability 8.9). An
    # explicit user setting is respected, because extensions built only for
    # 8.9 cannot run on other GPU generations.
    env.setdefault("TORCH_CUDA_ARCH_LIST", "8.9")
    env["CUDA_VISIBLE_DEVICES"] = GPU_ID

    # Prefer executables installed next to this interpreter (e.g. in its venv).
    python_bin_dir = os.path.dirname(sys.executable)
    existing_path = env.get("PATH")
    # Avoid a trailing separator when PATH is unset/empty: a zero-length PATH
    # entry means "current directory" on POSIX.
    env["PATH"] = (
        python_bin_dir + os.pathsep + existing_path if existing_path else python_bin_dir
    )
    return env


def get_image_files(input_dir, extensions=None):
    """Return the sorted names of image files directly inside ``input_dir``.

    A file counts as an image when its name ends with one of the given suffixes,
    compared case-insensitively (``photo.JPG`` matches ``".jpg"``). Only regular
    files, or symlinks to them, are returned. A sub-directory named e.g.
    ``crops.jpg`` is skipped, so it cannot make a folder without images look
    non-empty.

    Args:
        input_dir: folder to scan. The scan is not recursive.
        extensions: iterable of suffixes such as ``".png"``. ``None`` (the
            default) means ``IMAGE_EXTENSIONS``.

    Returns:
        Sorted list of bare file names, without the directory component.
    """
    suffixes = IMAGE_EXTENSIONS if extensions is None else extensions
    # str.endswith needs a tuple. The suffixes are lower-cased because the file
    # name is lower-cased before comparison.
    suffixes = tuple(suffix.lower() for suffix in suffixes)
    return sorted(
        file_name
        for file_name in os.listdir(input_dir)
        if file_name.lower().endswith(suffixes)
        and os.path.isfile(os.path.join(input_dir, file_name))
    )


def main():
    """Run SCHP's ``simple_extractor.py`` on every image in an input folder.

    Steps:

    1. Parse the CLI.
    2. Validate the input folder and the required SCHP files.
    3. Create the output folder.
    4. Print a summary.
    5. Run the extractor as a subprocess, with working directory ``SCHP_DIR``
       and the environment from ``build_env()``.

    Raises:
        FileNotFoundError: the input folder, the extractor script, or the
            checkpoint does not exist.
        RuntimeError: the input folder contains no image files.
        subprocess.CalledProcessError: the extractor exits with a non-zero status.
    """
    args = parse_args()
    # Resolve against the caller's CWD now. The subprocess runs with
    # cwd=SCHP_DIR, where relative paths would point somewhere else.
    input_dir = os.path.abspath(args.input_dir)
    output_dir = os.path.abspath(args.output_dir)

    if not os.path.isdir(input_dir):
        raise FileNotFoundError(f"입력 폴더가 없습니다: {input_dir}")

    image_files = get_image_files(input_dir)
    if not image_files:
        raise RuntimeError(f"입력 폴더에 이미지가 없습니다: {input_dir}")

    # Check the SCHP script and weights up front. Otherwise a missing file only
    # shows up as an error from inside the subprocess, after its start-up cost.
    for required_file in (SIMPLE_EXTRACTOR, CHECKPOINT):
        if not os.path.isfile(required_file):
            raise FileNotFoundError(f"필요한 파일이 없습니다: {required_file}")

    os.makedirs(output_dir, exist_ok=True)

    cmd = [
        sys.executable,  # same interpreter (and virtualenv) as this wrapper
        SIMPLE_EXTRACTOR,
        "--dataset", DATASET,
        "--gpu", GPU_ID,
        "--model-restore", CHECKPOINT,
        "--input-dir", input_dir,
        "--output-dir", output_dir,
    ]

    print("입력 폴더:", input_dir)
    print("이미지 수:", len(image_files))
    for image_file in image_files:
        print(" -", image_file)
    print("결과 폴더:", output_dir)
    print("cmd:", cmd)
    print("")

    # check=True: an extractor failure surfaces as CalledProcessError, so this
    # script also exits with a non-zero status.
    subprocess.run(cmd, check=True, env=build_env(), cwd=SCHP_DIR)

    print("SCHP 실행 완료")
    print("결과 폴더:", output_dir)


# Run only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
