audreyt committed on
Commit
861518d
·
verified ·
1 Parent(s): 6cfdef4

Upload make_attn_repeat.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. make_attn_repeat.py +212 -0
make_attn_repeat.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Duplicate layer 5 (GQA Attention) in nemotron-cascade-2 GGUF.
4
+ Produces a 53-layer model where layers 5 and 6 are identical attention blocks.
5
+ All layers >= 6 in the original are shifted up by 1.
6
+
7
+ Per-layer metadata arrays (head_count_kv, feed_forward_length) are also updated.
8
+ """
9
+
10
+ import re
11
+ import numpy as np
12
+ from pathlib import Path
13
+ from gguf import GGUFReader, GGUFWriter, GGUFValueType
14
+
15
# Input GGUF blob (an Ollama model blob, addressed by its sha256 digest).
SRC = Path("/usr/share/ollama/.ollama/models/blobs/sha256-9e0c827cfd6a6d000032be3da3d0914668b0c1112977e927186d29c4487466c4")
# Output GGUF written by this script.
DST = Path("/home/j/nemotron-cascade-2-attn-repeat-L5.gguf")
# Zero-based index of the block to duplicate; the copy is inserted at
# DUP_LAYER + 1 and every later block is shifted up by one.
DUP_LAYER = 5
# Architecture string; it is also the prefix of the architecture-specific
# metadata keys (e.g. "<ARCH>.block_count").
ARCH = "nemotron_h_moe"

# Per-layer array fields that need an extra element inserted
PER_LAYER_ARRAY_FIELDS = {
    f"{ARCH}.attention.head_count_kv",
    f"{ARCH}.feed_forward_length",
}
25
+
26
+
27
def _part_scalar(part):
    """Return the single value held by one GGUF field part.

    Parts are normally one-element sequences, but zero-dimensional values
    (plain Python numbers, NumPy scalars) cannot be indexed with ``[0]``;
    those are returned unchanged.
    """
    try:
        return part[0]
    except (TypeError, IndexError):
        # TypeError: not subscriptable (e.g. int).
        # IndexError: subscriptable but zero-dimensional (e.g. NumPy scalar).
        return part


def read_per_layer_array(field):
    """Read a per-layer integer array from GGUF field parts.

    In gguf-py, array fields have parts:
      [0]: key name length
      [1]: key name bytes
      [2]: array type marker (9 = ARRAY)
      [3]: element type (4 = UINT32)
      [4]: array length
      [5..]: individual element values

    Args:
        field: object exposing a ``parts`` sequence laid out as above.

    Returns:
        A list of Python ``int`` values, one per array element, in order.

    Raises:
        IndexError: if ``parts`` holds fewer elements than the declared
            array length.
    """
    parts = field.parts
    arr_len = int(_part_scalar(parts[4]))
    # Index explicitly (rather than slicing) so that a truncated parts list
    # fails loudly instead of silently returning a short array.
    return [int(_part_scalar(parts[5 + i])) for i in range(arr_len)]
46
+
47
+
48
# Matches per-block tensor names such as "blk.12.attn_q.weight".
_BLK_NAME = re.compile(r'^blk\.(\d+)\.(.*)')


def _first_value(part):
    """Return the single value held by one GGUF field part (sequence or scalar)."""
    try:
        return part[0]
    except (TypeError, IndexError):
        return part


def _plan_tensors(tensors, dup_layer):
    """Return ([(new_name, tensor), ...], n_copied) with block `dup_layer` duplicated.

    Blocks below `dup_layer` keep their name, block `dup_layer` is emitted
    twice (as itself and as `dup_layer + 1`), later blocks are shifted up by
    one, and non-block tensors are passed through unchanged.
    """
    planned = []
    n_copied = 0
    for t in tensors:
        m = _BLK_NAME.match(t.name)
        if m is None:
            planned.append((t.name, t))
            continue
        blk = int(m.group(1))
        suffix = m.group(2)
        if blk <= dup_layer:
            planned.append((t.name, t))
        if blk >= dup_layer:
            planned.append((f"blk.{blk + 1}.{suffix}", t))
        if blk == dup_layer:
            n_copied += 1
    return planned, n_copied


def _numeric_array(field, cast):
    """Read a numeric array field (parts[4] = length, parts[5..] = elements)."""
    parts = field.parts
    count = int(_first_value(parts[4]))
    return [cast(_first_value(parts[5 + i])) for i in range(count)]


def _string_array(field_name, field):
    """Read a string array field; elements are (length, bytes) part pairs from parts[5]."""
    parts = field.parts
    count = int(_first_value(parts[4]))
    strings = []
    for i in range(count):
        raw = parts[5 + 2 * i + 1]  # skip the length part of each pair
        try:
            strings.append(raw.tobytes().decode('utf-8'))
        except UnicodeDecodeError as err:
            # Dropping or re-aligning here would shift every later element
            # (e.g. token ids), so fail instead of writing a corrupt array.
            raise ValueError(
                f"{field_name}[{i}] is not valid UTF-8; refusing to write a misaligned array"
            ) from err
    return strings


def _copy_metadata(reader, writer, per_layer_arrays, block_key, old_blocks, new_blocks):
    """Copy key/value metadata to `writer`, patching block count and per-layer arrays.

    Returns the list of field names that could not be copied.
    """
    scalar_writers = {
        GGUFValueType.UINT8: (int, writer.add_uint8),
        GGUFValueType.INT8: (int, writer.add_int8),
        GGUFValueType.UINT16: (int, writer.add_uint16),
        GGUFValueType.INT16: (int, writer.add_int16),
        GGUFValueType.UINT32: (int, writer.add_uint32),
        GGUFValueType.INT32: (int, writer.add_int32),
        GGUFValueType.UINT64: (int, writer.add_uint64),
        GGUFValueType.INT64: (int, writer.add_int64),
        GGUFValueType.FLOAT32: (float, writer.add_float32),
        GGUFValueType.FLOAT64: (float, writer.add_float64),
        GGUFValueType.BOOL: (bool, writer.add_bool),
    }
    # NOTE(review): add_array() receives plain Python ints/floats, so the
    # element type written is whatever the writer infers from them, which may
    # differ from the source element type (e.g. UINT32) -- confirm that the
    # consumer accepts this.
    array_casts = {
        GGUFValueType.UINT32: int,
        GGUFValueType.INT32: int,
        GGUFValueType.FLOAT32: float,
    }

    not_copied = []
    for field_name, field in reader.fields.items():
        if field_name.startswith("GGUF."):
            continue

        # Skip architecture (writer adds automatically)
        if field_name == "general.architecture":
            continue

        # Update block_count
        if field_name == block_key:
            print(f" block_count: {old_blocks} -> {new_blocks}")
            writer.add_uint32(block_key, new_blocks)
            continue

        # Handle per-layer arrays with inserted element
        if field_name in per_layer_arrays:
            writer.add_array(field_name, per_layer_arrays[field_name])
            continue

        types = field.types

        # Handle other array types
        if len(types) > 1 and types[0] == GGUFValueType.ARRAY:
            arr_type = types[1]
            if arr_type == GGUFValueType.STRING:
                strings = _string_array(field_name, field)
                if strings:
                    writer.add_array(field_name, strings)
            elif arr_type in array_casts:
                writer.add_array(field_name, _numeric_array(field, array_casts[arr_type]))
            else:
                print(f" SKIP array: {field_name} (elem type {arr_type})")
                not_copied.append(field_name)
            continue

        # Scalar types
        field_type = types[-1] if types else None
        try:
            if field_type == GGUFValueType.STRING:
                writer.add_string(field_name, field.parts[-1].tobytes().decode('utf-8'))
            elif field_type in scalar_writers:
                cast, add = scalar_writers[field_type]
                add(field_name, cast(_first_value(field.parts[-1])))
            else:
                print(f" SKIP: {field_name} (type {field_type})")
                not_copied.append(field_name)
        except Exception as e:
            # Deliberately best-effort per field: report, remember, carry on.
            print(f" ERROR on {field_name}: {e}")
            not_copied.append(field_name)
    return not_copied


def main():
    """Write DST as a copy of SRC with block DUP_LAYER duplicated.

    Reads the GGUF at SRC, duplicates every ``blk.<DUP_LAYER>.*`` tensor as
    ``blk.<DUP_LAYER + 1>.*``, shifts all later blocks up by one, inserts the
    matching element into each array in PER_LAYER_ARRAY_FIELDS, increments
    ``<ARCH>.block_count`` by one, and writes the result to DST.

    Raises:
        ValueError: if SRC has no ``<ARCH>.block_count`` field, has no tensors
            for block DUP_LAYER, has a per-layer array too short to contain
            DUP_LAYER, or contains a string array element that is not UTF-8.
    """
    print(f"Reading {SRC} ...")
    reader = GGUFReader(str(SRC))

    # --- Layer count comes from the file, not from a hard-coded constant ---
    block_key = f"{ARCH}.block_count"
    block_field = reader.fields.get(block_key)
    if block_field is None:
        # Without this key the output would silently keep a stale (or no)
        # layer count; most likely ARCH does not match the source model.
        raise ValueError(f"{block_key} not found in {SRC}; check ARCH")
    old_blocks = int(_first_value(block_field.parts[-1]))
    new_blocks = old_blocks + 1

    # --- Build tensor list with duplication ---
    tensors_to_write, n_copied = _plan_tensors(reader.tensors, DUP_LAYER)
    if n_copied == 0:
        raise ValueError(f"no tensors named blk.{DUP_LAYER}.* in {SRC}; nothing to duplicate")

    print(f"Original: {len(reader.tensors)} tensors, {old_blocks} layers")
    print(f"New: {len(tensors_to_write)} tensors, {new_blocks} layers")

    # --- Read per-layer arrays and insert duplicate ---
    per_layer_arrays = {}
    for field_name in sorted(PER_LAYER_ARRAY_FIELDS):  # sorted: stable log order
        field = reader.fields.get(field_name)
        if field is None or len(field.types) <= 1 or field.types[0] != GGUFValueType.ARRAY:
            continue  # absent or scalar: copied verbatim with the other fields
        orig = read_per_layer_array(field)
        if DUP_LAYER >= len(orig):
            raise ValueError(
                f"{field_name} has {len(orig)} elements; cannot duplicate index {DUP_LAYER}"
            )
        # Insert duplicate of DUP_LAYER at DUP_LAYER+1
        new_arr = orig[:DUP_LAYER + 1] + [orig[DUP_LAYER]] + orig[DUP_LAYER + 1:]
        per_layer_arrays[field_name] = new_arr
        print(f" {field_name}: {len(orig)} -> {len(new_arr)} elements")
        # Show attention/moe layer positions for verification
        nonzero = [(i, v) for i, v in enumerate(new_arr) if v != 0]
        print(f" non-zero positions: {nonzero}")

    # --- Write new GGUF ---
    print(f"Writing {DST} ...")
    writer = GGUFWriter(str(DST), ARCH)
    try:
        not_copied = _copy_metadata(
            reader, writer, per_layer_arrays, block_key, old_blocks, new_blocks
        )

        # --- Add tensors ---
        for new_name, tensor in tensors_to_write:
            writer.add_tensor(new_name, tensor.data, raw_dtype=tensor.tensor_type)

        print("Finalizing ...")
        writer.write_header_to_file()
        writer.write_kv_data_to_file()
        writer.write_tensors_to_file()
    finally:
        # Release the output handle even when writing fails part-way.
        writer.close()

    if not_copied:
        print(f"WARNING: {len(not_copied)} metadata field(s) were not copied: {not_copied}")

    size_gb = DST.stat().st_size / (1024**3)
    print(f"Done! {DST} ({size_gb:.1f} GB)")
209
+
210
+
211
# Run the conversion only when executed as a script, not on import.
if __name__ == "__main__":
    main()