#!/usr/bin/env bash
# scripts/install_flash_attn.sh
#
# Build/install flash-attention into the active conda env, detached
# under nohup so SSH disconnects don't kill the build (which can take
# 30-60 minutes on a fresh checkout).
#
# Usage:
#   conda activate causalgrok                  # do this in your shell first
#   bash scripts/install_flash_attn.sh         # installs the default version (2.7.4)
#   FLASH_ATTN_VERSION=2.7.4 bash scripts/install_flash_attn.sh
#
# The build/install logs land in (<STAMP> = UTC launch time, YYYYmmdd-HHMMSS):
#   logs/install/<STAMP>_flash_attn/{logs/train.log,logs/train.err,run.pid,env.txt}
# Watch progress with:
#   tail -f logs/install/<STAMP>_flash_attn/logs/train.log
# Kill if needed:
#   kill "$(cat logs/install/<STAMP>_flash_attn/run.pid)"
#
# Requirements:
#   - active conda env with PyTorch already installed (matching CUDA)
#   - nvcc / CUDA toolkit visible (the build needs it)
#   - Ampere+ GPU (A100/H100/RTX30xx/40xx). On A100 (sm_80) flash-attn
#     ships prebuilt wheels for most PyTorch×CUDA combos, so the
#     install is usually fast.

set -euo pipefail

# Run from the repository root (the parent of this script's directory) so
# that the relative logs/ paths used below resolve consistently.
ROOT="$(cd "$(dirname "$0")/.." && pwd)"
cd "${ROOT}"
source "${ROOT}/scripts/lib/nohup_runner.sh"

FLASH_ATTN_VERSION="${FLASH_ATTN_VERSION:-2.7.4}"

# ── Sanity checks ─────────────────────────────────────────────────────
# The version comes straight from the caller's environment and is later
# spliced verbatim into a `bash -c` command string, so accept only the
# characters a pip version specifier can contain (digits, letters and
# . + ! _ * -). Anything else — quotes, spaces, ';', '$' — is rejected
# before it can break or hijack the detached install script.
version_re='^[0-9A-Za-z][0-9A-Za-z.+!_*-]*$'
if [[ ! "${FLASH_ATTN_VERSION}" =~ ${version_re} ]]; then
  echo " invalid FLASH_ATTN_VERSION: '${FLASH_ATTN_VERSION}'" >&2
  echo " expected something like 2.7.4 or 2.7.4.post1" >&2
  exit 1
fi

if ! command -v python >/dev/null 2>&1; then
  echo " python not found in PATH. Activate the conda env first:" >&2
  echo " conda activate causalgrok" >&2
  exit 1
fi

if [[ -z "${CONDA_PREFIX:-}" ]]; then
  echo " CONDA_PREFIX is empty — no conda env appears to be active." >&2
  echo " conda activate causalgrok" >&2
  exit 1
fi

# ── Where the install logs live ───────────────────────────────────────
# logs/install/ is reserved for environment / dependency setup output.
# It is deliberately separate from experiments/runs/ so an install
# never gets confused with a training run.
# One directory per invocation, keyed by the UTC launch time.
STAMP="$(date -u +%Y%m%d-%H%M%S)"
INSTALL_DIR="logs/install/${STAMP}_flash_attn"
mkdir -p "${INSTALL_DIR}"

# Snapshot the env we're installing into (so we can debug later).
# Every probe in this group is best-effort: a missing or failing tool is
# recorded in env.txt instead of aborting the script.
{
  echo "# captured: $(date -u +%FT%TZ)"
  echo "# host: $(hostname)"
  echo "# CONDA_PREFIX:${CONDA_PREFIX}"
  echo "# CONDA_DEFAULT_ENV: ${CONDA_DEFAULT_ENV:-}"
  echo "# python: $(python --version 2>&1)"
  echo "# which python: $(command -v python)"

  if command -v nvcc >/dev/null 2>&1; then
    echo "# nvcc: $(nvcc --version | tail -1)"
  else
    echo "# nvcc: NOT FOUND (CUDA toolkit may be missing)"
  fi

  if command -v nvidia-smi >/dev/null 2>&1; then
    echo "# nvidia-smi:"
    # nvidia-smi can be installed yet still fail (e.g. no driver loaded,
    # or a query field this build does not recognise). Under
    # `set -e -o pipefail` a bare failing pipeline would kill the script
    # mid-snapshot, so fall back to a marker line instead.
    nvidia-smi --query-gpu=name,driver_version,compute_cap --format=csv,noheader \
      | sed 's/^/# /' \
      || echo "# (nvidia-smi query failed)"
  fi

  # stderr is dropped on purpose: an ImportError traceback would only
  # clutter the snapshot, and the fallback line records the outcome.
  python -c "import torch, sys; print(f'# torch: {torch.__version__}'); print(f'# torch.cuda: {torch.version.cuda}'); print(f'# CUDA avail: {torch.cuda.is_available()}')" 2>/dev/null \
    || echo "# torch: NOT INSTALLED — install torch first"
} > "${INSTALL_DIR}/env.txt"

# Echo the snapshot to the terminal so the user sees what was detected.
cat "${INSTALL_DIR}/env.txt"
echo

# ── Pre-flight: torch must be present ─────────────────────────────────
if ! python -c "import torch" >/dev/null 2>&1; then
  echo " torch is not installed in this env. Install it first:" >&2
  echo " pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118" >&2
  exit 1
fi

# ── Launch the build under nohup ──────────────────────────────────────
# packaging/ninja/wheel are required for the source build path; the
# resolver will pick the prebuilt wheel where one exists.
echo "Starting flash-attn ${FLASH_ATTN_VERSION} install (detached)..."
# Hand the install script to launch_detached (not defined in this file;
# presumably provided by the sourced scripts/lib/nohup_runner.sh), passing
# the per-run directory followed by the command to run.
#
# The script body is double-quoted, so ${FLASH_ATTN_VERSION} is expanded
# here, at launch time, and the detached shell receives the literal value.
#
# pip is invoked as `python -m pip` so packages are installed into the
# same interpreter that the earlier checks (and the final import check)
# use; a bare `pip` on PATH can belong to a different environment.
launch_detached "${INSTALL_DIR}" \
  bash -c "
    set -euo pipefail
    echo '== installing build deps =='
    python -m pip install --upgrade pip wheel setuptools packaging ninja
    echo
    echo '== installing flash-attn ${FLASH_ATTN_VERSION} =='
    python -m pip install --no-build-isolation 'flash-attn==${FLASH_ATTN_VERSION}'
    echo
    echo '== sanity check =='
    python -c 'import flash_attn; print(\"flash-attn version:\", flash_attn.__version__)'
    echo 'DONE'
  "

# Tell the user where to find the run's artifacts and how to check on it.
echo
echo "Outputs:"
echo " env snapshot : ${INSTALL_DIR}/env.txt"
echo " build log : ${INSTALL_DIR}/logs/train.log"
echo " build err : ${INSTALL_DIR}/logs/train.err"
echo " PID : ${INSTALL_DIR}/run.pid"
echo
echo "Verify completion with:"
echo " grep -E 'DONE|ERROR' ${INSTALL_DIR}/logs/train.log"