RavinduSen commited on
Commit
5fc44bd
·
verified ·
1 Parent(s): 6466132

Upload 4 files

Browse files
examples/basic_inference.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Basic inference example for JaneGPT v2 Intent Classifier.
3
+ """
4
+ import sys
5
+ import os
6
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
7
+ from model.classifier import JaneGPTClassifier
8
+
9
+ def main():
10
+ # Load model
11
+ classifier = JaneGPTClassifier()
12
+ print(f"Model loaded: {classifier}")
13
+ print(f"Supported intents: {len(classifier.get_supported_intents())}\n")
14
+
15
+ # Test commands
16
+ test_inputs = [
17
+ "turn up the volume",
18
+ "make it louder",
19
+ "set volume to 50",
20
+ "mute",
21
+ "turn down the brightness",
22
+ "open chrome",
23
+ "play shape of you on youtube",
24
+ "search for python tutorials",
25
+ "set a reminder for 10 minutes",
26
+ "take a screenshot",
27
+ "read this for me",
28
+ "explain what's on my screen",
29
+ "undo that",
30
+ "shut down",
31
+ "hello",
32
+ "what time is it",
33
+ ]
34
+
35
+ print(f"{'Input':<45} {'Intent':<20} {'Confidence':<10}")
36
+ print("-" * 75)
37
+
38
+ for text in test_inputs:
39
+ intent, confidence = classifier.predict(text)
40
+ print(f"{text:<45} {intent:<20} {confidence:.1%}")
41
+
42
+ # Context-aware classification
43
+ print("\n--- Context-Aware ---")
44
+
45
+ # After volume up, user says "not enough"
46
+ intent, conf = classifier.predict(
47
+ "not enough",
48
+ context={"last_intent": "volume_up"}
49
+ )
50
+ print(f"{'not enough [after volume_up]':<45} {intent:<20} {conf:.1%}")
51
+
52
+ # Top-k predictions
53
+ print("\n--- Top-3 Predictions ---")
54
+ results = classifier.predict_top_k("play something nice", k=3)
55
+ for intent, conf in results:
56
+ print(f" {intent}: {conf:.1%}")
57
+
58
+
59
+ if __name__ == "__main__":
60
+ main()
examples/batch_inference.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Batch inference example for JaneGPT v2 Intent Classifier.
3
+
4
+ Classifies multiple inputs efficiently.
5
+ """
6
+ import sys
7
+ import os
8
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
9
+ import time
10
+ import json
11
+ from pathlib import Path
12
+ from typing import List, Dict
13
+
14
+ import torch
15
+
16
+ from model.classifier import JaneGPTClassifier
17
+
18
+
19
+ def classify_batch(
20
+ classifier: JaneGPTClassifier,
21
+ texts: List[str],
22
+ context: dict = None
23
+ ) -> List[Dict]:
24
+ """
25
+ Classify a batch of texts.
26
+
27
+ Note: Current implementation processes sequentially.
28
+ For true batch processing with padding, see classify_batch_parallel().
29
+
30
+ Args:
31
+ classifier: Loaded JaneGPTClassifier
32
+ texts: List of user utterances
33
+ context: Optional shared context
34
+
35
+ Returns:
36
+ List of result dictionaries
37
+ """
38
+ results = []
39
+
40
+ for text in texts:
41
+ intent, confidence = classifier.predict(text, context)
42
+ results.append({
43
+ "text": text,
44
+ "intent": intent,
45
+ "confidence": round(confidence, 4),
46
+ })
47
+
48
+ return results
49
+
50
+
51
+ def classify_batch_parallel(
52
+ classifier: JaneGPTClassifier,
53
+ texts: List[str],
54
+ context: dict = None
55
+ ) -> List[Dict]:
56
+ """
57
+ Classify a batch of texts in parallel (single forward pass).
58
+
59
+ More efficient for large batches on GPU.
60
+
61
+ Args:
62
+ classifier: Loaded JaneGPTClassifier
63
+ texts: List of user utterances
64
+ context: Optional shared context
65
+
66
+ Returns:
67
+ List of result dictionaries
68
+ """
69
+ if not classifier.is_ready:
70
+ raise RuntimeError("Model not loaded")
71
+
72
+ # Format and tokenize all inputs
73
+ all_ids = []
74
+ for text in texts:
75
+ formatted = classifier._format_input(text, context)
76
+ ids = classifier.tokenizer.encode(formatted).ids
77
+
78
+ if len(ids) > classifier.MAX_LEN:
79
+ ids = ids[:classifier.MAX_LEN]
80
+ else:
81
+ ids = ids + [classifier.PAD_ID] * (classifier.MAX_LEN - len(ids))
82
+
83
+ all_ids.append(ids)
84
+
85
+ # Create batch tensor
86
+ batch_tensor = torch.tensor(all_ids, dtype=torch.long, device=classifier.device)
87
+
88
+ # Single forward pass
89
+ with torch.no_grad():
90
+ logits, _ = classifier.model(batch_tensor)
91
+ probs = torch.softmax(logits, dim=-1)
92
+ confidences, predicted = torch.max(probs, dim=-1)
93
+
94
+ # Build results
95
+ results = []
96
+ for i, text in enumerate(texts):
97
+ idx = predicted[i].item()
98
+ conf = confidences[i].item()
99
+ intent = classifier.id_to_intent.get(idx, 'chat')
100
+
101
+ results.append({
102
+ "text": text,
103
+ "intent": intent,
104
+ "confidence": round(conf, 4),
105
+ })
106
+
107
+ return results
108
+
109
+
110
+ def main():
111
+ # Load model
112
+ classifier = JaneGPTClassifier()
113
+ print(f"Model loaded: {classifier}\n")
114
+
115
+ # Example batch
116
+ commands = [
117
+ "turn up the volume",
118
+ "make it louder",
119
+ "open chrome",
120
+ "play shape of you",
121
+ "search for python tutorials on google",
122
+ "set brightness to 50",
123
+ "take a screenshot",
124
+ "set a reminder for 10 minutes",
125
+ "mute",
126
+ "read this for me",
127
+ "explain what's on my screen",
128
+ "undo that",
129
+ "shut down",
130
+ "hello",
131
+ "what can you do",
132
+ "close notepad",
133
+ "skip to the next song",
134
+ "dim the screen",
135
+ "pause the music",
136
+ "what time is it",
137
+ ]
138
+
139
+ # --- Sequential processing ---
140
+ print("=" * 65)
141
+ print(" Sequential Batch Processing")
142
+ print("=" * 65)
143
+
144
+ start = time.perf_counter()
145
+ results = classify_batch(classifier, commands)
146
+ elapsed = time.perf_counter() - start
147
+
148
+ print(f"\n {'Text':<42} {'Intent':<20} {'Conf':>6}")
149
+ print(f" {'-'*68}")
150
+
151
+ for r in results:
152
+ print(f" {r['text']:<42} {r['intent']:<20} {r['confidence']:>5.1%}")
153
+
154
+ print(f"\n Processed {len(commands)} commands in {elapsed*1000:.1f}ms")
155
+ print(f" Average: {elapsed/len(commands)*1000:.1f}ms per command")
156
+
157
+ # --- Parallel processing ---
158
+ print(f"\n{'=' * 65}")
159
+ print(" Parallel Batch Processing (single forward pass)")
160
+ print("=" * 65)
161
+
162
+ start = time.perf_counter()
163
+ results_parallel = classify_batch_parallel(classifier, commands)
164
+ elapsed_parallel = time.perf_counter() - start
165
+
166
+ print(f"\n Processed {len(commands)} commands in {elapsed_parallel*1000:.1f}ms")
167
+ print(f" Average: {elapsed_parallel/len(commands)*1000:.1f}ms per command")
168
+ print(f" Speedup: {elapsed/elapsed_parallel:.1f}x faster than sequential")
169
+
170
+ # Verify both methods give same results
171
+ match = all(
172
+ r1['intent'] == r2['intent']
173
+ for r1, r2 in zip(results, results_parallel)
174
+ )
175
+ print(f" Results match: {'YES' if match else 'NO'}")
176
+
177
+ # --- Save results to JSON ---
178
+ output_file = Path("examples/batch_results.json")
179
+ with open(output_file, 'w') as f:
180
+ json.dump(results, f, indent=2)
181
+ print(f"\n Results saved to: {output_file}")
182
+
183
+ # --- Batch with context ---
184
+ print(f"\n{'=' * 65}")
185
+ print(" Context-Aware Batch")
186
+ print("=" * 65)
187
+
188
+ # Simulate: user just adjusted volume, now giving follow-up commands
189
+ context = {"last_intent": "volume_up"}
190
+ follow_ups = [
191
+ "not enough",
192
+ "too much",
193
+ "a bit more",
194
+ "the other one",
195
+ "perfect",
196
+ ]
197
+
198
+ print(f"\n Context: last_intent = volume_up\n")
199
+
200
+ ctx_results = classify_batch(classifier, follow_ups, context)
201
+ for r in ctx_results:
202
+ print(f" {r['text']:<42} {r['intent']:<20} {r['confidence']:>5.1%}")
203
+
204
+
205
+ if __name__ == "__main__":
206
+ main()
examples/benchmark.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Speed benchmark for JaneGPT v2.
3
+ """
4
+ import sys
5
+ import os
6
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
7
+ import time
8
+ from model.classifier import JaneGPTClassifier
9
+
10
+ def main():
11
+ classifier = JaneGPTClassifier()
12
+
13
+ test_inputs = [
14
+ "turn up the volume",
15
+ "open chrome",
16
+ "play some music",
17
+ "set brightness to 50",
18
+ "search for cats",
19
+ "take a screenshot",
20
+ "hello",
21
+ "undo that",
22
+ ]
23
+
24
+ # Warmup
25
+ for text in test_inputs:
26
+ classifier.predict(text)
27
+
28
+ # Benchmark
29
+ iterations = 100
30
+ start = time.perf_counter()
31
+
32
+ for _ in range(iterations):
33
+ for text in test_inputs:
34
+ classifier.predict(text)
35
+
36
+ elapsed = time.perf_counter() - start
37
+ total_predictions = iterations * len(test_inputs)
38
+
39
+ print(f"Device: {classifier.device}")
40
+ print(f"Total predictions: {total_predictions}")
41
+ print(f"Total time: {elapsed:.2f}s")
42
+ print(f"Average per prediction: {elapsed/total_predictions*1000:.2f}ms")
43
+ print(f"Predictions per second: {total_predictions/elapsed:.0f}")
44
+
45
+
46
+ if __name__ == "__main__":
47
+ main()
examples/model_info.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Display detailed information about JaneGPT v2 model.
3
+
4
+ Shows architecture, parameters, training info, and size comparisons.
5
+ """
6
+
7
+ import os
8
+ import torch
9
+ from model.architecture import JaneGPTv2Classifier, INTENT_LABELS
10
+
11
+
12
+ def main():
13
+ # Load checkpoint
14
+ checkpoint_path = "weights/janegpt_v2_classifier.pt"
15
+ checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
16
+ config = checkpoint.get('config', {})
17
+
18
+ # Create model
19
+ model = JaneGPTv2Classifier(
20
+ vocab_size=config.get('vocab_size', 8192),
21
+ embed_dim=config.get('embed_dim', 256),
22
+ num_heads=config.get('num_heads', 8),
23
+ num_kv_heads=config.get('num_kv_heads', 4),
24
+ num_layers=config.get('num_layers', 8),
25
+ ff_hidden=config.get('ff_hidden', 672),
26
+ max_seq_len=config.get('max_seq_len', 256),
27
+ dropout=config.get('dropout', 0.1),
28
+ rope_theta=config.get('rope_theta', 10000.0),
29
+ )
30
+ model.load_state_dict(checkpoint['model_state_dict'])
31
+
32
+ # Calculate parameters
33
+ total_params = sum(p.numel() for p in model.parameters())
34
+ trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
35
+ buffers = sum(b.numel() for b in model.buffers())
36
+
37
+ print("=" * 60)
38
+ print(" JANEGPT v2 - MODEL INFORMATION")
39
+ print("=" * 60)
40
+
41
+ # Architecture
42
+ print("\n ARCHITECTURE")
43
+ print(f" Type: Decoder-only Transformer (Classifier)")
44
+ print(f" Vocab Size: {config.get('vocab_size', 8192):,}")
45
+ print(f" Embedding Dim: {config.get('embed_dim', 256)}")
46
+ print(f" Attention Heads: {config.get('num_heads', 8)}")
47
+ print(f" KV Heads (GQA): {config.get('num_kv_heads', 4)}")
48
+ print(f" Head Dim: {config.get('embed_dim', 256) // config.get('num_heads', 8)}")
49
+ print(f" Layers: {config.get('num_layers', 8)}")
50
+ print(f" FF Hidden: {config.get('ff_hidden', 672)}")
51
+ print(f" Max Seq Length: {config.get('max_seq_len', 256)}")
52
+ print(f" Dropout: {config.get('dropout', 0.1)}")
53
+ print(f" RoPE Theta: {config.get('rope_theta', 10000.0)}")
54
+
55
+ # Features
56
+ print("\n FEATURES")
57
+ print(f" Position Encoding: RoPE (Rotary Position Embedding)")
58
+ print(f" Normalization: RMSNorm")
59
+ print(f" Attention: Grouped Query Attention (GQA)")
60
+ print(f" Feed-Forward: SwiGLU")
61
+ print(f" Classifier Head: Linear -> GELU -> Dropout -> Linear")
62
+ print(f" Output Classes: {len(INTENT_LABELS)}")
63
+
64
+ # Parameters
65
+ print("\n PARAMETERS")
66
+ print(f" Total Parameters: {total_params:>12,}")
67
+ print(f" Trainable Parameters: {trainable_params:>12,}")
68
+ print(f" Non-trainable Buffers: {buffers:>12,}")
69
+ print(f" Model Size (float32): {total_params * 4 / 1024 / 1024:.2f} MB")
70
+ print(f" Model Size (float16): {total_params * 2 / 1024 / 1024:.2f} MB")
71
+
72
+ # Breakdown
73
+ print("\n PARAMETER BREAKDOWN")
74
+ print(f" {'Component':<35} {'Params':>12} {'%':>8}")
75
+ print(f" {'-' * 55}")
76
+
77
+ emb_params = sum(p.numel() for p in model.token_embedding.parameters())
78
+ print(f" {'Token Embedding':<35} {emb_params:>12,} {emb_params/total_params*100:>7.1f}%")
79
+
80
+ all_layers_params = sum(p.numel() for p in model.layers.parameters())
81
+ print(f" {'Transformer Layers (total)':<35} {all_layers_params:>12,} {all_layers_params/total_params*100:>7.1f}%")
82
+
83
+ # Single layer breakdown
84
+ layer0_params = sum(p.numel() for p in model.layers[0].parameters())
85
+ attn_params = sum(p.numel() for p in model.layers[0].attn.parameters()) - sum(
86
+ b.numel() for b in model.layers[0].attn.buffers()
87
+ )
88
+ ff_params = sum(p.numel() for p in model.layers[0].ff.parameters())
89
+ norm_params = model.layers[0].norm1.weight.numel() + model.layers[0].norm2.weight.numel()
90
+
91
+ print(f" {' Per layer (x8):':<33} {layer0_params:>12,}")
92
+ print(f" {' Attention (Q/K/V/Out)':<33} {attn_params:>12,}")
93
+ print(f" {' Feed-Forward (SwiGLU)':<33} {ff_params:>12,}")
94
+ print(f" {' Norms (RMSNorm x2)':<33} {norm_params:>12,}")
95
+
96
+ final_norm_params = model.norm.weight.numel()
97
+ print(f" {'Final RMSNorm':<35} {final_norm_params:>12,} {final_norm_params/total_params*100:>7.1f}%")
98
+
99
+ head_params = sum(p.numel() for p in model.intent_head.parameters())
100
+ print(f" {'Classification Head':<35} {head_params:>12,} {head_params/total_params*100:>7.1f}%")
101
+ print(f" {' Linear(256, 256) + bias':<33} {256 * 256 + 256:>12,}")
102
+ print(f" {' Linear(256, 22) + bias':<33} {256 * 22 + 22:>12,}")
103
+
104
+ # Training
105
+ print("\n TRAINING")
106
+ print(f" Best Val Accuracy: {checkpoint.get('val_acc', 0):.2f}%")
107
+ print(f" Best Val Loss: {checkpoint.get('val_loss', 0):.4f}")
108
+ print(f" Best Epoch: {checkpoint.get('epoch', 'N/A')}")
109
+
110
+ # Intent classes
111
+ print(f"\n INTENT CLASSES ({len(INTENT_LABELS)})")
112
+ for i, label in enumerate(INTENT_LABELS):
113
+ print(f" {i:>2}: {label}")
114
+
115
+ # File info
116
+ print(f"\n FILES")
117
+ if os.path.exists(checkpoint_path):
118
+ model_size = os.path.getsize(checkpoint_path)
119
+ print(f" Checkpoint: {model_size / 1024 / 1024:.2f} MB")
120
+
121
+ tokenizer_path = "weights/tokenizer.json"
122
+ if os.path.exists(tokenizer_path):
123
+ tok_size = os.path.getsize(tokenizer_path)
124
+ print(f" Tokenizer: {tok_size / 1024:.1f} KB")
125
+
126
+ # Size comparison
127
+ print(f"\n SIZE COMPARISON")
128
+ print(f" {'Model':<30} {'Parameters':>15} {'Size':>10}")
129
+ print(f" {'-' * 55}")
130
+ print(f" {'JaneGPT v2 (this model)':<30} {total_params:>12,} {total_params * 4 / 1024 / 1024:>5.1f} MB")
131
+ print(f" {'DistilBERT':<30} {'66,000,000':>15} {'260.0 MB':>10}")
132
+ print(f" {'BERT Base':<30} {'110,000,000':>15} {'440.0 MB':>10}")
133
+ print(f" {'GPT-2 Small':<30} {'124,000,000':>15} {'500.0 MB':>10}")
134
+ print(f" {'Llama 3 8B':<30} {'8,000,000,000':>15} {' 16.0 GB':>10}")
135
+ print(f" {'GPT-4':<30} {'~1,800,000,000,000':>15} {'~ 3.6 TB':>10}")
136
+
137
+ print(f"\n Created by: Ravindu Senanayake")
138
+ print("=" * 60)
139
+
140
+
141
+ if __name__ == "__main__":
142
+ main()