Spaces:
Runtime error
Runtime error
add sidebar
Browse files
app.py
CHANGED
|
@@ -21,6 +21,13 @@ model, model_params, tokenizer = load_model(model_name)
|
|
| 21 |
# neuron_dim = col3.text_input("Dim: ", value='0')
|
| 22 |
# neurons = model_params.K_heads[int(neuron_layer), int(neuron_dim)]
|
| 23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
prompt = st.text_area("Prompt: ")
|
| 25 |
submitted = st.button("Send!")
|
| 26 |
|
|
@@ -28,9 +35,9 @@ if submitted:
|
|
| 28 |
with st.spinner('Wait for it..'):
|
| 29 |
model, model_params, tokenizer = map(deepcopy, (model, model_params, tokenizer))
|
| 30 |
decoded = speaking_probe(model, model_params, tokenizer, prompt,
|
| 31 |
-
repetition_penalty=
|
| 32 |
min_length=1, do_sample=True,
|
| 33 |
-
max_new_tokens=
|
| 34 |
|
| 35 |
for text in decoded:
|
| 36 |
st.code('\n'.join(textwrap.wrap(text, width=70)), language=None)
|
|
|
|
| 21 |
# neuron_dim = col3.text_input("Dim: ", value='0')
|
| 22 |
# neurons = model_params.K_heads[int(neuron_layer), int(neuron_dim)]
|
| 23 |
|
| 24 |
+
with st.sidebar:
|
| 25 |
+
temperature = st.slider("Temperature", min_value=0., max_value=2., value=0.5, step=0.05)
|
| 26 |
+
repetition_penalty = st.slider("Repetition Penalty", min_value=0., max_value=4., value=2., step=0.1)
|
| 27 |
+
sidebar_cols = st.columns(2)
|
| 28 |
+
num_generations = sidebar_cols[0].number_input("Number of Answers", min_value=1, value=3, format='%d')
|
| 29 |
+
max_new_tokens = sidebar_cols[1].number_input("Max Answer Length", min_value=1, value=50, format='%d')
|
| 30 |
+
|
| 31 |
prompt = st.text_area("Prompt: ")
|
| 32 |
submitted = st.button("Send!")
|
| 33 |
|
|
|
|
| 35 |
with st.spinner('Wait for it..'):
|
| 36 |
model, model_params, tokenizer = map(deepcopy, (model, model_params, tokenizer))
|
| 37 |
decoded = speaking_probe(model, model_params, tokenizer, prompt,
|
| 38 |
+
repetition_penalty=repetition_penalty, num_generations=num_generations,
|
| 39 |
min_length=1, do_sample=True,
|
| 40 |
+
max_new_tokens=max_new_tokens, temperature=temperature)
|
| 41 |
|
| 42 |
for text in decoded:
|
| 43 |
st.code('\n'.join(textwrap.wrap(text, width=70)), language=None)
|