File size: 5,113 Bytes
66b9f14
 
368003e
66b9f14
 
 
 
 
 
368003e
66b9f14
368003e
 
 
66b9f14
 
 
 
368003e
66b9f14
 
 
 
368003e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66b9f14
368003e
 
66b9f14
 
368003e
66b9f14
368003e
 
 
 
 
 
 
 
 
66b9f14
368003e
 
 
 
 
 
 
 
 
 
 
 
 
 
66b9f14
 
 
368003e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66b9f14
 
 
368003e
 
 
 
66b9f14
 
368003e
66b9f14
368003e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66b9f14
 
 
368003e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
"""
BharatGen AyurParam — Swastik.fit AI Vaidya
Hosted on HuggingFace Spaces (ZeroGPU)

Model: bharatgenai/AyurParam (2.9B params, trained on 1,000+ Ayurvedic texts)
License: CC-BY-4.0 (commercial OK)
Prompt format: <user> {question} <assistant>

This Space is called by the Swastik Cloud Function (ayurParamProxy).
The /gradio_api/call/predict endpoint receives: { data: ["<user> ... <assistant>"] }
Returns: { data: ["response text"] }

ZeroGPU: GPU is allocated on-demand per request (no cold-start, shared GPU pool).
Model loads into GPU memory on first call, cached for duration of GPU slot.
"""

import gradio as gr
import torch
import spaces
from transformers import AutoTokenizer, AutoModelForCausalLM

MODEL_ID = "bharatgenai/AyurParam"

# Lazily populated by _ensure_model() and reused by later calls in the same
# process. Intended to persist across ZeroGPU calls within a session.
# NOTE(review): whether it actually does depends on how `spaces` executes
# @spaces.GPU functions — confirm.
_tokenizer = _model = None


def _ensure_model():
    """Load the AyurParam tokenizer and model into the module-level cache.

    Returns immediately if ``_model`` is already set. Otherwise it loads
    ``MODEL_ID`` in float16 with ``device_map="auto"`` and switches it to eval
    mode. Both objects are published to the globals together, and only after
    loading succeeded. Meant to be called from inside the ``@spaces.GPU``
    function so the weights land on the GPU allocated for that call.

    NOTE(review): Hugging Face's ZeroGPU guidance is to load models at import
    time, outside any ``@spaces.GPU`` function. Whether globals assigned here
    survive between GPU calls depends on how ``spaces`` runs the decorated
    function — confirm the model is not being reloaded on every request.
    """
    global _tokenizer, _model
    if _model is not None:
        return

    print("[AyurParam] Loading tokenizer...")
    # The tokenizer is loaded without remote code while the model requires it;
    # presumably this mirrors the model card's usage — TODO confirm.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=False)

    print("[AyurParam] Loading model to GPU...")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        trust_remote_code=True,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    model.eval()  # inference only

    # Publish both at once so a failed model load never leaves a half-filled cache.
    _tokenizer, _model = tokenizer, model
    print("[AyurParam] Model ready on GPU.")


# Substrings that mark the end of the assistant's turn if they leak into the
# decoded text.
_STOP_SEQUENCES = ("<user>", "<context>", "</s>")


def _format_prompt(prompt: str) -> str:
    """Return *prompt* in AyurParam's ``<user> ... <assistant>`` format.

    A plain question is wrapped in ``<user> ... <assistant>``. Text that already
    contains ``<user>`` is kept as is, with a trailing ``<assistant>`` appended
    if missing.
    """
    text = prompt.strip()
    if "<user>" not in text:
        return f"<user> {text} <assistant>"
    if not text.endswith("<assistant>"):
        text += " <assistant>"
    return text


def _truncate_at_stop(text: str) -> str:
    """Cut *text* at the earliest stop sequence and strip surrounding whitespace."""
    for stop in _STOP_SEQUENCES:
        if stop in text:
            text = text[: text.index(stop)].strip()
    return text


@spaces.GPU(duration=120)
def generate(prompt: str) -> str:
    """Generate an AyurParam answer for *prompt* on a ZeroGPU-allocated GPU.

    Args:
        prompt: Either an already formatted ``"<user> ... <assistant>"`` prompt
            or a plain-text question, which is wrapped automatically.

    Returns:
        The assistant's reply only (the prompt is not echoed). If *prompt* is
        empty or blank, returns the message ``"Please provide a question."``.
    """
    # Validate before loading anything, so an empty request never pays for a
    # model load.
    if not prompt or not prompt.strip():
        return "Please provide a question."

    _ensure_model()

    inputs = _tokenizer(_format_prompt(prompt), return_tensors="pt")
    # Move the encoded tensors onto the device holding the model's weights.
    device = next(_model.parameters()).device
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
    input_len = inputs["input_ids"].shape[1]

    with torch.no_grad():
        output = _model.generate(
            **inputs,
            max_new_tokens=256,
            do_sample=True,
            top_k=50,
            top_p=0.95,
            temperature=0.6,
            eos_token_id=_tokenizer.eos_token_id,
            # EOS doubles as padding; presumably the tokenizer has no
            # dedicated pad token — TODO confirm.
            pad_token_id=_tokenizer.eos_token_id,
            use_cache=True,
        )

    # output[0] is prompt + continuation; keep only the newly generated tokens.
    new_tokens = output[0][input_len:]
    response = _tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    return _truncate_at_stop(response)


# Gradio interface — the Swastik Cloud Function calls /gradio_api/call/predict,
# so api_name="predict" below must stay in sync with that path.
demo = gr.Interface(
    fn=generate,
    inputs=gr.Textbox(
        label="Prompt",
        placeholder="<user> What foods should I eat for better digestion? <assistant>",
        lines=3,
    ),
    outputs=gr.Textbox(label="AyurParam Response", lines=8),
    # Title and examples use a real em dash; the previous text contained the
    # mis-encoded sequence "β€”", which was shown to users and sent to the model.
    title="BharatGen AyurParam — Ayurveda AI",
    description=(
        "**AyurParam** is India's first AI trained on 1,000+ Ayurvedic texts (54.5M words). "
        "2.9B parameter model fine-tuned on classical Ayurveda knowledge.\n\n"
        "Prompt format: `<user> your question <assistant>`\n\n"
        "This Space powers the AI Vaidya at [swastik.fit](https://swastik.fit)."
    ),
    examples=[
        ["<user> What foods should I eat to improve digestion according to Ayurveda? <assistant>"],
        ["<user> I have vata imbalance — what daily routine do you recommend? <assistant>"],
        ["<user> What are the benefits of turmeric in Ayurvedic medicine? <assistant>"],
        ["<user> namaste <assistant>"],
    ],
    cache_examples=False,
    api_name="predict",
)

if __name__ == "__main__":
    demo.launch()