rahul7star committed on
Commit
2216164
·
verified ·
1 Parent(s): dbcbf49

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +85 -0
README.md CHANGED
@@ -10,6 +10,91 @@ license: apache-2.0
10
  language:
11
  - en
12
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  # Uploaded model
15
 
 
10
  language:
11
  - en
12
  ---
13
+ ## Run in Kaggle
14
+ ```
15
+ # =========================================================
16
+ # Install dependencies (Kaggle usually already has some)
17
+ # =========================================================
18
+ !pip install -q transformers peft accelerate bitsandbytes
19
+
20
+ # =========================================================
21
+ # Imports
22
+ # =========================================================
23
+ import torch
24
+ from transformers import AutoProcessor, AutoModelForCausalLM
25
+ from peft import PeftModel
26
+
27
+ # =========================================================
28
+ # Config
29
+ # =========================================================
30
+ BASE_MODEL = "google/gemma-4-E2B-it"
31
+ LORA_MODEL = "rahul7star/gemma_4_lora"
32
+
33
+ # =========================================================
34
+ # Load processor
35
+ # =========================================================
36
+ processor = AutoProcessor.from_pretrained(BASE_MODEL)
37
+
38
+ # =========================================================
39
+ # Load base model
40
+ # =========================================================
41
+ model = AutoModelForCausalLM.from_pretrained(
42
+ BASE_MODEL,
43
+ torch_dtype=torch.float16, # safer for Kaggle GPU
44
+ device_map="auto"
45
+ )
46
+
47
+ # =========================================================
48
+ # Load LoRA adapter on top of base model
49
+ # =========================================================
50
+ model = PeftModel.from_pretrained(model, LORA_MODEL)
51
+
52
+ # optional: merge LoRA for faster inference
53
+ model = model.merge_and_unload()
54
+
55
+ print("Model + LoRA loaded successfully 🚀")
56
+
57
+ # =========================================================
58
+ # Inference function
59
+ # =========================================================
60
+ def generate_response(user_input):
61
+ messages = [
62
+ {"role": "system", "content": "You are a helpful assistant."},
63
+ {"role": "user", "content": user_input},
64
+ ]
65
+
66
+ text = processor.apply_chat_template(
67
+ messages,
68
+ tokenize=False,
69
+ add_generation_prompt=True,
70
+ enable_thinking=False
71
+ )
72
+
73
+ inputs = processor(text=text, return_tensors="pt").to(model.device)
74
+ input_len = inputs["input_ids"].shape[-1]
75
+
76
+ with torch.no_grad():
77
+ outputs = model.generate(
78
+ **inputs,
79
+ max_new_tokens=512,
80
+ temperature=0.7,
81
+ top_p=0.9
82
+ )
83
+
84
+ response = processor.decode(
85
+ outputs[0][input_len:],
86
+ skip_special_tokens=True
87
+ )
88
+
89
+ return response
90
+
91
+
92
+ # =========================================================
93
+ # Test
94
+ # =========================================================
95
+ print(generate_response("Write a short joke about saving RAM."))
96
+ ```
97
+
98
 
99
  # Uploaded model
100