# Use a PyTorch image with CUDA support for faster training and better compatibility.
# PyTorch 2.1.0 with CUDA 12.1 is fully compatible with NVIDIA A10G (Ampere).
FROM pytorch/pytorch:2.1.0-cuda12.1-cudnn8-runtime

# Set architecture list for A10G (Ampere, Compute Capability 8.6).
ENV TORCH_CUDA_ARCH_LIST="8.6"

# Install system dependencies (ffmpeg for imageio/visualization, git for pip).
# --no-install-recommends keeps the layer small; ca-certificates is listed
# explicitly because it is only a "recommended" package of git and HTTPS
# clones need it. DEBIAN_FRONTEND is set inline so it does not leak into the
# runtime environment. The apt lists are removed in the same layer.
RUN apt-get update \
    && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
        ca-certificates \
        ffmpeg \
        git \
    && rm -rf /var/lib/apt/lists/*

# Create a non-root user to match HF Spaces default (user 1000).
RUN useradd -m -u 1000 user

# Set working directory.
WORKDIR /app

# Copy only the dependency manifest first so the dependency layer below stays
# cached until requirements.txt itself changes.
COPY requirements.txt .

# 1. Remove torch and torchvision from requirements.txt to prevent pip from
#    upgrading the versions shipped in the base image. The pattern is anchored
#    to exactly those two package names (optionally followed by a version
#    specifier, extras, marker or comment), so other lines that merely contain
#    "torch" (e.g. torchmetrics, pytorch-lightning) are kept.
# 2. Install the rest of the requirements.
# 3. Explicitly ensure compatible torchvision is installed (0.16.0 matches torch 2.1.0).
RUN sed -i -E '/^[[:space:]]*torch(vision)?([^[:alnum:]._-].*)?$/d' requirements.txt \
    && pip install --no-cache-dir -r requirements.txt \
    && pip install --no-cache-dir torchvision==0.16.0

# Copy all project files into the container, owned by the runtime user.
# This also places entrypoint.sh at /app/entrypoint.sh, so it needs no
# separate COPY, and --chown avoids a recursive chown layer afterwards.
COPY --chown=user:user . .

# Set up environment variables for the user.
ENV HOME=/home/user \
    PATH=/home/user/.local/bin:$PATH \
    MPLCONFIGDIR=/tmp/matplotlib \
    NUMBA_CACHE_DIR=/tmp/numba_cache

# Final root-only steps, grouped into a single layer:
# - make the entrypoint executable;
# - hand /app itself to the user (WORKDIR created it as root) so the process
#   can create new files in it;
# - create the matplotlib/numba cache directories, world-writable.
RUN chmod +x /app/entrypoint.sh \
    && chown user:user /app \
    && mkdir -p /tmp/matplotlib /tmp/numba_cache \
    && chmod 777 /tmp/matplotlib /tmp/numba_cache

# Switch to the non-root user.
USER user

# Accelerate configuration is handled in entrypoint.sh at runtime.
# CMD supplies the default arguments passed to the entrypoint.
ENTRYPOINT ["/app/entrypoint.sh"]
CMD ["--energy_head_enabled", "--loss_type", "energy_contrastive", "--push_to_hub", "--hub_model_id", "Uday/ctm-energy-based-halting"]