import os
import subprocess
import argparse
import shutil
import sys
import tempfile


# All paths are anchored at this script's own location, so the wrapper
# resolves the same files no matter which working directory it is run from.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
SCHP_DIR = os.path.join(SCRIPT_DIR, "SCHP")  # directory named SCHP beside this script

# Files inside the SCHP directory: the extractor script that gets executed,
# and the model checkpoint handed to it via --model-restore.
SIMPLE_EXTRACTOR = os.path.join(SCHP_DIR, "simple_extractor.py")
CHECKPOINT = os.path.join(SCHP_DIR, "checkpoints/exp-schp-201908261155-lip.pth")

# Settings forwarded to the extractor. GPU_ID is a string because it is used
# both as a command-line argument (--gpu) and as an environment-variable value
# (CUDA_VISIBLE_DEVICES).
DATASET = "lip"
GPU_ID = "0"

# Command-line interface: the image to parse and the directory that receives
# the extractor's output. The (Korean) help strings are user-facing text.
parser = argparse.ArgumentParser()
parser.add_argument("image_path", help="입력 이미지 경로")  # "input image path"
parser.add_argument("output_dir", help="출력 폴더")  # "output folder"

args = parser.parse_args()

# Validate paths up front. A bad path then yields a usage error (argparse exits
# with status 2) instead of a traceback from shutil.copy further down the
# script, after the temp cache directories have already been created.
# A trailing slash or a directory passed as image_path is also rejected here.
if not os.path.isfile(args.image_path):
    parser.error(f"image_path is not an existing file: {args.image_path}")
# An existing regular file cannot serve as the output directory.
if os.path.exists(args.output_dir) and not os.path.isdir(args.output_dir):
    parser.error(f"output_dir exists but is not a directory: {args.output_dir}")

def _ensure_private_dir(path):
    """Create *path* if missing and make sure it is safe to hand to the child.

    The directory lives under the shared system temp dir with a predictable,
    uid-based name. Another local user could therefore create it first (as a
    symlink, or with permissive modes) and plant files in it. This matters
    because PyTorch builds JIT C++/CUDA extensions into TORCH_EXTENSIONS_DIR
    and loads the resulting shared libraries. New directories are created
    0o700, and an existing path is rejected if it is a symlink, is owned by
    another user, or is writable by group/others. Directories left by earlier
    runs (0o755, owned by us) still pass.

    Raises:
        SystemExit: if the path exists but fails the safety checks.
    """
    os.makedirs(path, mode=0o700, exist_ok=True)
    info = os.lstat(path)
    unsafe = (
        os.path.islink(path)
        or info.st_uid != os.getuid()
        or info.st_mode & 0o022  # group- or world-writable
    )
    if unsafe:
        sys.exit(f"Refusing to use unsafe cache directory: {path}")


# Per-user scratch directories under the system temp dir (POSIX-only:
# os.getuid does not exist on Windows).
_uid = os.getuid()
cache_home = os.path.join(tempfile.gettempdir(), f"schp_cache_{_uid}")
torch_extensions_dir = os.path.join(tempfile.gettempdir(), f"schp_torch_extensions_{_uid}")
_ensure_private_dir(cache_home)
_ensure_private_dir(torch_extensions_dir)

# Environment for the extractor subprocess. It is a copy, so this process's
# own os.environ is left untouched.
env = os.environ.copy()
# Redirect HOME and XDG_CACHE_HOME so per-user caches written by the child
# land in cache_home.
env["HOME"] = cache_home
env["XDG_CACHE_HOME"] = cache_home
env["TORCH_EXTENSIONS_DIR"] = torch_extensions_dir
# NOTE(review): extensions are compiled only for compute capability 8.9.
# This presumably matches the deployment GPU; confirm before running elsewhere.
env["TORCH_CUDA_ARCH_LIST"] = "8.9"
env["CUDA_VISIBLE_DEVICES"] = GPU_ID
# Put this interpreter's directory first on PATH. Empty entries are dropped:
# the old form produced "dir:" when PATH was empty or unset, and an empty
# PATH entry means "current directory" on POSIX.
env["PATH"] = os.pathsep.join(
    entry
    for entry in (os.path.dirname(sys.executable), env.get("PATH", ""))
    if entry
)

# The extractor is given a whole --input-dir, so the single input image is
# staged in a fresh per-run temp directory. That directory (and the copy)
# is removed when the with-block exits, including on error.
with tempfile.TemporaryDirectory(prefix="schp_inputs_") as staging_dir:
    staged_image = os.path.join(staging_dir, os.path.basename(args.image_path))
    shutil.copy(args.image_path, staged_image)

    extractor_cmd = [sys.executable, SIMPLE_EXTRACTOR]
    extractor_cmd += ["--dataset", DATASET]
    extractor_cmd += ["--gpu", GPU_ID]
    extractor_cmd += ["--model-restore", CHECKPOINT]
    extractor_cmd += ["--input-dir", staging_dir]
    extractor_cmd += ["--output-dir", args.output_dir]

    print("cmd: ", extractor_cmd)
    print("")
    # check=True turns a non-zero exit into CalledProcessError, so the
    # completion messages below are only reached when the extractor succeeds.
    subprocess.run(extractor_cmd, check=True, env=env)

print("SCHP 실행 완료")  # "SCHP run finished"
print("결과 폴더:", args.output_dir)  # "Result folder:"
