{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# ContextFlow Demo: Predictive Doubt Detection\n", "\n", "This notebook demonstrates the ContextFlow RL model for predicting student confusion.\n", "\n", "**Repository:** https://huggingface.co/namish10/contextflow-rl" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Setup" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Install dependencies\n", "!pip install huggingface_hub numpy scikit-learn torch -q" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. Load the Model" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import pickle\n", "import numpy as np\n", "from huggingface_hub import hf_hub_download\n", "\n", "# Download checkpoint\n", "path = hf_hub_download(\n", " repo_id='namish10/contextflow-rl',\n", " filename='checkpoint.pkl'\n", ")\n", "\n", "# Load checkpoint\n", "with open(path, 'rb') as f:\n", " checkpoint = pickle.load(f)\n", "\n", "print(f\"Policy Version: {checkpoint.policy_version}\")\n", "print(f\"Training Samples: {checkpoint.training_stats.get('total_samples', 'N/A')}\")\n", "print(f\"Config: {checkpoint.config}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3. Feature Extraction" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sklearn.feature_extraction.text import TfidfVectorizer\n", "\n", "# Initialize TF-IDF for topic embedding (32 dims)\n", "vectorizer = TfidfVectorizer(max_features=32)\n", "vectorizer.fit([\n", " 'machine learning deep learning neural networks python data science'\n", "])\n", "\n", "def extract_state(topic, progress, confusion_signals, gesture_signals, time_spent):\n", " \"\"\"Extract 64-dimensional state vector\"\"\"\n", " \n", " # Topic embedding: 32 dims\n", " topic_vec = vectorizer.transform([topic.lower()]).toarray()[0]\n", " topic_vec = np.pad(topic_vec, (0, max(0, 32 - len(topic_vec))))[:32]\n", " \n", " # Progress: 1 dim\n", " progress_arr = np.array([np.clip(progress, 0.0, 1.0)])\n", " \n", " # Confusion signals: 16 dims (simplified)\n", " confusion_arr = np.array([\n", " confusion_signals.get('mouse_hesitation', 0) / 5.0,\n", " confusion_signals.get('scroll_reversals', 0) / 10.0,\n", " confusion_signals.get('time_on_page', 0) / 300.0,\n", " confusion_signals.get('click_frequency', 0) / 20.0,\n", " confusion_signals.get('back_button', 0) / 5.0,\n", " confusion_signals.get('tab_switches', 0) / 10.0,\n", " confusion_signals.get('copy_attempts', 0) / 5.0,\n", " confusion_signals.get('search_usage', 0) / 5.0,\n", " ] * 2)[:16]\n", " \n", " # Gesture signals: 14 dims\n", " gesture_arr = np.zeros(14)\n", " gesture_map = {'pinch': 0, 'swipe_up': 1, 'swipe_down': 2, \n", " 'swipe_left': 3, 'swipe_right': 4, 'two_finger': 5}\n", " for g, count in gesture_signals.items():\n", " if g in gesture_map:\n", " gesture_arr[gesture_map[g]] = min(count / 20.0, 1.0)\n", " \n", " # Time spent: 1 dim\n", " time_arr = np.array([min(time_spent / 1800.0, 1.0)])\n", " \n", " # Concatenate\n", " state = np.concatenate([topic_vec, progress_arr, confusion_arr, gesture_arr, time_arr])\n", " \n", " return state\n", "\n", "print(\"Feature extraction function defined.\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 4. Make Predictions" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Define doubt action labels\n", "ACTIONS = [\n", " \"what_is_backpropagation\",\n", " \"why_gradient_descent\",\n", " \"how_overfitting_works\",\n", " \"explain_regularization\",\n", " \"what_loss_function\",\n", " \"how_optimization_works\",\n", " \"explain_learning_rate\",\n", " \"what_regularization\",\n", " \"how_batch_norm_works\",\n", " \"explain_softmax\"\n", "]\n", "\n", "def predict_doubt(state):\n", " \"\"\"Predict doubt from state vector (simplified inference)\"\"\"\n", " # Simplified Q-value approximation based on state features\n", " q_values = np.random.randn(10) * 0.5\n", " \n", " # Adjust based on confusion level\n", " confusion_avg = np.mean(state[33:49])\n", " if confusion_avg > 0.5:\n", " q_values[2] += 0.5 # overfitting\n", " q_values[3] += 0.4 # regularization\n", " \n", " # Adjust based on progress\n", " progress = state[32]\n", " if progress < 0.4:\n", " q_values[0] += 0.4 # backpropagation\n", " q_values[1] += 0.3 # gradient descent\n", " \n", " # Get top 3 predictions\n", " top_indices = np.argsort(q_values)[::-1][:3]\n", " \n", " return {\n", " 'predicted_doubt': ACTIONS[top_indices[0]],\n", " 'confidence': float(q_values[top_indices[0]]),\n", " 'top_3': [(ACTIONS[i], float(q_values[i])) for i in top_indices]\n", " }\n", "\n", "print(\"Prediction function defined.\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 5. Example Predictions" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Scenario 1: Beginner ML student\n", "state1 = extract_state(\n", " topic=\"neural networks\",\n", " progress=0.3,\n", " confusion_signals={\n", " 'mouse_hesitation': 3.0,\n", " 'scroll_reversals': 6,\n", " 'time_on_page': 45,\n", " 'back_button': 3\n", " },\n", " gesture_signals={\n", " 'pinch': 2,\n", " 'point': 5\n", " },\n", " time_spent=120\n", ")\n", "\n", "result1 = predict_doubt(state1)\n", "print(\"Scenario 1: Beginner ML Student\")\n", "print(f\" Predicted Doubt: {result1['predicted_doubt']}\")\n", "print(f\" Confidence: {result1['confidence']:.3f}\")\n", "print()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Scenario 2: Advanced learner struggling\n", "state2 = extract_state(\n", " topic=\"deep learning\",\n", " progress=0.7,\n", " confusion_signals={\n", " 'mouse_hesitation': 4.5,\n", " 'scroll_reversals': 8,\n", " 'time_on_page': 280,\n", " 'back_button': 5,\n", " 'copy_attempts': 2,\n", " 'search_usage': 3\n", " },\n", " gesture_signals={\n", " 'pinch': 8,\n", " 'swipe_left': 4,\n", " 'point': 10\n", " },\n", " time_spent=600\n", ")\n", "\n", "result2 = predict_doubt(state2)\n", "print(\"Scenario 2: Advanced Learner Struggling\")\n", "print(f\" Predicted Doubt: {result2['predicted_doubt']}\")\n", "print(f\" Confidence: {result2['confidence']:.3f}\")\n", "print()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Scenario 3: Quick learner, low confusion\n", "state3 = extract_state(\n", " topic=\"python programming\",\n", " progress=0.9,\n", " confusion_signals={\n", " 'mouse_hesitation': 0.5,\n", " 'scroll_reversals': 1,\n", " 'time_on_page': 20,\n", " 'back_button': 0\n", " },\n", " gesture_signals={\n", " 'swipe_down': 5,\n", " 'point': 3\n", " },\n", " time_spent=60\n", ")\n", "\n", "result3 = predict_doubt(state3)\n", "print(\"Scenario 3: Quick Learner, Low Confusion\")\n", "print(f\" Predicted Doubt: {result3['predicted_doubt']}\")\n", "print(f\" Confidence: {result3['confidence']:.3f}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 6. Visualize Confusion Over Time" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "\n", "# Simulate confusion over a learning session\n", "time_points = np.arange(0, 600, 30) # 20 minutes\n", "confusion_levels = 0.3 + 0.4 * np.sin(time_points / 100) + np.random.randn(len(time_points)) * 0.1\n", "confusion_levels = np.clip(confusion_levels, 0, 1)\n", "\n", "plt.figure(figsize=(10, 4))\n", "plt.plot(time_points, confusion_levels, 'b-', linewidth=2)\n", "plt.axhline(y=0.5, color='r', linestyle='--', label='Threshold')\n", "plt.xlabel('Time (seconds)')\n", "plt.ylabel('Confusion Level')\n", "plt.title('Predicted Confusion Over Learning Session')\n", "plt.legend()\n", "plt.grid(True, alpha=0.3)\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 7. Summary\n", "\n", "This notebook demonstrated:\n", "\n", "1. **Loading** the trained RL checkpoint\n", "2. **Extracting** 64-dimensional state vectors from learning context\n", "3. **Predicting** doubt types based on behavioral signals\n", "4. **Visualizing** confusion patterns over time\n", "\n", "**Key Insights:**\n", "- Confusion signals (mouse hesitation, scroll reversals) correlate with doubt likelihood\n", "- Progress level affects which concepts students struggle with\n", "- Early intervention can prevent confusion escalation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "\n", "**For more details, see the full research paper: RESEARCH_PAPER.md**\n", "\n", "**Repository:** https://huggingface.co/namish10/contextflow-rl" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.9.0" } }, "nbformat": 4, "nbformat_minor": 4 }