#!/usr/bin/env /home/jovyan/step37_work/.venv/bin/python
"""Project a per-layer direction out of attention ``o_proj`` weights in place.

For every selected layer ``i`` the direction is ``r = -gradients[i]`` (averaged
over its first two axes when it is 3-D).  The tensor
``model.layers.{i}.self_attn.o_proj.weight`` is replaced by
``W - lam * outer(r_hat, r_hat @ W)`` with each row rescaled to its original
L2 norm, and the safetensors shard holding it is rewritten in place, one shard
at a time, to bound peak memory.
"""
import argparse
import gc
import json
import os
import pickle
import tempfile
from pathlib import Path

import torch
from safetensors import safe_open
from safetensors.torch import save_file

MODEL_PATH = '/home/jovyan/step37_bf16_split_model'


def mem_eff():
    """Return ``(current_gb, anon_gb, max_gb)`` for the cgroup at /sys/fs/cgroup.

    Reads cgroup-v2 ``memory.current``, the ``anon`` line of ``memory.stat``
    and ``memory.max``.  A ``memory.max`` of ``max`` (no limit) is reported as
    ``inf``.  Used only for progress logging: when the files are missing or
    unparsable, ``(0.0, 0.0, 0.0)`` is returned instead of raising.
    """
    root = '/sys/fs/cgroup'
    try:
        with open(f'{root}/memory.current') as fh:
            total = int(fh.read())
        with open(f'{root}/memory.max') as fh:
            raw_max = fh.read().strip()
        # cgroup v2 writes the literal string "max" when no limit is set;
        # int() on it would otherwise discard the valid readings as well.
        maxv = float('inf') if raw_max == 'max' else int(raw_max)
        anon = 0
        with open(f'{root}/memory.stat') as fh:
            for line in fh:
                if line.startswith('anon '):
                    anon = int(line.split()[1])
                    break
    except (OSError, ValueError, IndexError):
        return 0.0, 0.0, 0.0
    return total / 1e9, anon / 1e9, maxv / 1e9


def drop_file_cache(path):
    """Best-effort hint asking the kernel to evict ``path`` from the page cache.

    Uses ``posix_fadvise(POSIX_FADV_DONTNEED)``.  Any failure, including a
    platform without ``posix_fadvise``, is ignored.  Dirty pages are not
    guaranteed to be freed, so callers should fsync the file first.
    """
    fadvise = getattr(os, 'posix_fadvise', None)
    if fadvise is None:
        return
    try:
        fd = os.open(path, os.O_RDONLY)
    except OSError:
        return
    try:
        fadvise(fd, 0, 0, os.POSIX_FADV_DONTNEED)
    except OSError:
        pass
    finally:
        # Always release the descriptor, even if the advice call failed.
        os.close(fd)


def _fsync_file(path):
    """Flush ``path``'s contents to stable storage before it is renamed into place."""
    fd = os.open(path, os.O_RDONLY)
    try:
        os.fsync(fd)
    finally:
        os.close(fd)


def project(W, r, lam, dev='cuda:0'):
    """Remove a fraction ``lam`` of the component of ``W``'s columns along ``r``.

    Computes ``W - lam * outer(r_hat, r_hat @ W)`` with ``r_hat = r / ||r||``,
    then rescales every row back to its original L2 norm.  With ``lam=1.0``
    every column is made orthogonal to ``r_hat`` before that rescale.  The
    math runs in float32 on ``dev`` so the small correction is not rounded
    away; the result is returned on CPU in ``W``'s original dtype.

    Args:
        W: 2-D weight matrix of shape ``(n, m)``.
        r: 1-D direction with ``n`` entries.
        lam: projection strength.
        dev: torch device used for the computation.

    Returns:
        CPU tensor shaped and typed like ``W``.

    Raises:
        ValueError: if ``r`` is not 1-D with ``W.shape[0]`` entries.
    """
    if r.ndim != 1 or W.ndim != 2 or r.shape[0] != W.shape[0]:
        raise ValueError(
            f'direction of shape {tuple(r.shape)} does not match the rows '
            f'of weight of shape {tuple(W.shape)}'
        )
    out_dtype = W.dtype
    w32 = W.to(dev, dtype=torch.float32)
    r32 = r.to(dev, dtype=torch.float32)
    rn = r32 / (torch.norm(r32) + 1e-12)
    wn = w32 - lam * torch.outer(rn, rn @ w32)
    row_norm = torch.norm(w32, dim=1, keepdim=True)
    wn = wn / (torch.norm(wn, dim=1, keepdim=True) + 1e-12) * row_norm
    # Round once, back to the storage dtype of the checkpoint.
    return wn.to('cpu', dtype=out_dtype)


def _rewrite_shard(path, ops, grads, lam, dev):
    """Project every ``(layer, key)`` in ``ops`` and atomically rewrite shard ``path``."""
    mods = {}
    with safe_open(path, framework='pt', device='cpu') as f:
        for i, k in ops:
            r = -grads[i]
            if r.ndim == 3:
                r = r.mean(dim=(0, 1))
            W = f.get_tensor(k)
            mods[k] = project(W, r, lam, dev)
            del W
            torch.cuda.empty_cache()
            gc.collect()

    out = {}
    with safe_open(path, framework='pt', device='cpu') as f:
        # save_file writes no header metadata unless it is passed explicitly;
        # keep the shard's original metadata (HF shards carry {"format": "pt"}).
        metadata = f.metadata()
        for k in f.keys():
            out[k] = mods.pop(k) if k in mods else f.get_tensor(k)

    tmp = path + '.tmp'
    try:
        save_file(out, tmp, metadata=metadata)
        # fsync before the rename so the replacement is durable, and so the
        # page-cache drop below sees clean pages it can actually evict.
        _fsync_file(tmp)
        os.replace(tmp, path)
    except BaseException:
        try:
            os.remove(tmp)
        except FileNotFoundError:
            pass
        raise
    drop_file_cache(path)
    del out
    gc.collect()
    torch.cuda.empty_cache()


def main():
    """Parse CLI arguments and rewrite every shard that holds a selected o_proj weight."""
    ap = argparse.ArgumentParser(description=__doc__)
    ap.add_argument('--gradients', required=True,
                    help="pickle file whose 'gradients' entry is indexed by layer")
    ap.add_argument('--layers', default='all',
                    help="'all' or a comma-separated list of layer indices")
    ap.add_argument('--lambda', dest='lam', type=float, default=0.1)
    ap.add_argument('--device', default='cuda:0',
                    help='torch device used for the projection math')
    ap.add_argument('--model-path', default=MODEL_PATH,
                    help='directory containing model.safetensors.index.json')
    args = ap.parse_args()

    # SECURITY: pickle.load can execute arbitrary code; only pass trusted files.
    with open(args.gradients, 'rb') as fh:
        grads = pickle.load(fh)['gradients']
    with open(f'{args.model_path}/model.safetensors.index.json') as fh:
        idx = json.load(fh)['weight_map']

    if args.layers == 'all':
        layers = range(len(grads))
    else:
        layers = [int(x) for x in args.layers.split(',') if x.strip()]

    by_shard = {}
    for i in layers:
        k = f'model.layers.{i}.self_attn.o_proj.weight'
        if k in idx:
            by_shard.setdefault(idx[k], []).append((i, k))
        else:
            print(f'SKIP layer {i}: {k} not in weight map', flush=True)

    for shard, ops in sorted(by_shard.items()):
        # NOTE(review): realpath means a symlinked shard is rewritten at its
        # target rather than replaced inside the model directory — confirm intended.
        path = os.path.realpath(f'{args.model_path}/{shard}')
        if not os.path.exists(path):
            raise FileNotFoundError(path)
        total, anon, maxv = mem_eff()
        print(f'BEGIN {shard} ops={len(ops)} mem total={total:.1f} anon={anon:.1f}/{maxv:.1f}GB', flush=True)
        _rewrite_shard(path, ops, grads, args.lam, args.device)
        total, anon, maxv = mem_eff()
        print(f'DONE {shard} mem total={total:.1f} anon={anon:.1f}/{maxv:.1f}GB', flush=True)


if __name__ == '__main__':
    main()