| import gradio as gr |
| import spaces |
| from transformers import AutoTokenizer, AutoModelForCausalLM |
| import torch |
|
|
# Checkpoint to serve; the tokenizer and the weights are both loaded from it.
model_id = "Equall/SaulLM-54B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(model_id)

# Weights are held in bfloat16; device placement is delegated to
# transformers via device_map="auto".
# NOTE(review): going by the "54B" in the repo name, bf16 weights alone are on
# the order of 100 GB -- confirm the target GPU allocation can hold them.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
|
|
def _build_messages(message, history, system_prompt):
    """Assemble the chat-template message list for a single request.

    Earlier turns whose assistant reply is ``None`` (e.g. a turn whose
    generation failed) are dropped together with their user message. This
    keeps user/assistant roles strictly alternating, which some chat
    templates require.
    """
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    for human, assistant in history or ():
        if assistant is None:
            continue
        messages.append({"role": "user", "content": human})
        messages.append({"role": "assistant", "content": assistant})
    messages.append({"role": "user", "content": message})
    return messages


@spaces.GPU()
def generate_response(message, history, system_prompt, max_tokens, temperature):
    """Generate legal analysis using Saul-54B.

    Args:
        message: The newest user question.
        history: Earlier turns as ``(user, assistant)`` pairs. Pairs whose
            assistant reply is ``None`` are skipped.
        system_prompt: Optional system message; omitted when empty.
        max_tokens: Maximum number of new tokens. It is coerced to ``int``
            because UI sliders may deliver a float.
        temperature: Sampling temperature. A value of ``0`` (or less) selects
            greedy decoding, and then no temperature is passed to ``generate``.

    Returns:
        Only the generated text, with the prompt tokens sliced off and
        special tokens skipped during decoding.
    """
    messages = _build_messages(message, history, system_prompt)

    input_text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # Text rendered by a chat template is expected to already contain its
    # special tokens (e.g. BOS). The Transformers chat-template guide
    # recommends add_special_tokens=False here so BOS is not added twice.
    inputs = tokenizer(
        input_text, return_tensors="pt", add_special_tokens=False
    ).to(model.device)

    gen_kwargs = {
        "max_new_tokens": int(max_tokens),
        "pad_token_id": tokenizer.eos_token_id,
    }
    if temperature > 0:
        gen_kwargs["do_sample"] = True
        gen_kwargs["temperature"] = temperature
    else:
        # Greedy decoding: temperature is unused, and transformers warns
        # when it is set without sampling.
        gen_kwargs["do_sample"] = False

    outputs = model.generate(**inputs, **gen_kwargs)

    # Decode only the tokens generated after the prompt.
    prompt_length = inputs["input_ids"].shape[1]
    return tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
|
|
# Default system prompt shown (and editable) in the UI sidebar. It describes
# the model's role and lists the informational-use disclaimers.
DEFAULT_SYSTEM = (
    "You are SaulLM-54B, a specialized legal language model. "
    "You provide accurate legal analysis based on U.S. and European legal systems.\n"
    "\n"
    "IMPORTANT DISCLAIMERS:\n"
    "- This is for informational purposes only, not legal advice\n"
    "- Information may not reflect recent legal developments\n"
    "- Users should consult qualified legal professionals for actual legal advice\n"
    "- Do not use this for decisions that could affect legal rights"
)
|
|
# Gradio UI: chat pane on the left, generation settings on the right.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# SaulLM-54B Legal Assistant")
    gr.Markdown("*Specialized AI for legal reasoning and analysis. Private queries, powered by Zero GPU (25 min/day free).*")

    with gr.Row():
        # Conversation display and input controls.
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(label="Legal Analysis", height=500)
            msg = gr.Textbox(
                label="Your Legal Question",
                placeholder="Ask about statutes, case law, legal concepts, or compliance...",
                lines=3,
            )
            with gr.Row():
                submit = gr.Button("Submit", variant="primary")
                clear = gr.Button("Clear Chat")

        # Generation settings and usage notes.
        with gr.Column(scale=1):
            system_prompt = gr.Textbox(
                label="System Prompt",
                value=DEFAULT_SYSTEM,
                lines=12,
                max_lines=12,
            )
            max_tokens = gr.Slider(
                label="Max Response Tokens",
                minimum=100,
                maximum=2000,
                value=1000,
                step=100,
            )
            temperature = gr.Slider(
                label="Temperature",
                minimum=0.0,
                maximum=1.0,
                value=0.7,
                step=0.1,
            )
            gr.Markdown("### Usage Tips")
            gr.Markdown("""
- Be specific about jurisdiction
- Cite relevant statutes/cases if known
- Zero GPU resets after 60s idle
- 25 min/day free compute limit
""")

    def user_submit(message, history):
        """Clear the textbox and append the message as a pending turn.

        Blank or whitespace-only input is ignored, so no turn is added and
        the model is not called. A ``None`` history (e.g. after "Clear Chat")
        is treated as empty.
        """
        history = history or []
        if not message or not message.strip():
            return "", history
        return "", history + [[message, None]]

    def bot_respond(history, system_prompt, max_tokens, temperature):
        """Generate the reply for the pending (unanswered) last turn.

        This is a no-op when there is no pending turn, for example when
        user_submit ignored a blank message. In that case it also avoids
        regenerating an already-answered turn.
        """
        if not history or history[-1][1] is not None:
            return history
        message = history[-1][0]
        response = generate_response(
            message, history[:-1], system_prompt, max_tokens, temperature
        )
        history[-1] = [message, response]
        return history

    # Enter in the textbox and the Submit button share the same two-step chain.
    # The first step skips the queue so the message shows immediately; the
    # model call then runs as a queued follow-up.
    for trigger in (msg.submit, submit.click):
        trigger(user_submit, [msg, chatbot], [msg, chatbot], queue=False).then(
            bot_respond, [chatbot, system_prompt, max_tokens, temperature], chatbot
        )
    clear.click(lambda: None, None, chatbot, queue=False)
|
|
if __name__ == "__main__":
    # Enable Gradio's request queue on the app, then start the server.
    queued_app = demo.queue()
    queued_app.launch()
|
|