# contentforge-ai / app.py
# Last commit by Bashaarat1: "Update app.py" (db75ce3, verified) — 14.7 kB
# ============================================================================
# CONTENTFORGE AI - FINAL WORKING VERSION
# Multi-modal AI platform with fine-tuned models
# ============================================================================
import gradio as gr
import torch
import os
from huggingface_hub import login
# ============================================================================
# AUTHENTICATION
# ============================================================================
# Log in to the Hugging Face Hub when an HF_TOKEN environment variable is set
# (and non-empty); otherwise warn and continue without logging in.
HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    print("⚠️ No HF_TOKEN found - some models may fail to load\n")
else:
    print("🔐 Authenticating with HuggingFace...")
    login(token=HF_TOKEN)
    print("✅ Authenticated!\n")
from transformers import (
T5Tokenizer, T5ForConditionalGeneration,
Qwen2VLForConditionalGeneration, Qwen2VLProcessor,
AutoProcessor, MusicgenForConditionalGeneration
)
from peft import PeftModel
from qwen_vl_utils import process_vision_info
from diffusers import StableDiffusionPipeline
from PIL import Image
import numpy as np
# Choose the torch device used by the model-loading and inference code:
# CUDA when a GPU is visible to torch, otherwise the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(f"🖥️ Using device: {device}")
print("📦 Loading models... This may take 2-3 minutes on first run.\n")
# ============================================================================
# MODEL LOADING
# ============================================================================
# 1. T5 Summarization Model
# Tokenizer and weights both come from the same fine-tuned Hub repo.
print("📝 Loading T5 model...")
t5_tokenizer = T5Tokenizer.from_pretrained("Bashaarat1/t5-small-arxiv-summarizer")
t5_model = T5ForConditionalGeneration.from_pretrained(
    "Bashaarat1/t5-small-arxiv-summarizer"
).to(device)
t5_model.eval()  # inference mode (disables dropout)
print("✅ T5 loaded!")
# 2. Qwen VLM Q&A Model with YOUR LoRA adapter
print("🤖 Loading Qwen2-VL base model...")
# Unlike the other models, placement is delegated to device_map="auto"
# instead of an explicit .to(device).
# NOTE(review): bfloat16 is requested regardless of `device` (the Stable
# Diffusion load below picks its dtype per device) — confirm this is
# intended on CPU-only hosts.
qwen_base = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-2B-Instruct",
    device_map="auto",
    torch_dtype=torch.bfloat16
)
print("🔧 Loading YOUR fine-tuned LoRA adapter...")
# Wrap the base model with the fine-tuned LoRA adapter weights.
qwen_model = PeftModel.from_pretrained(
    qwen_base,
    "Bashaarat1/qwen-finetuned-scienceqa"
)
# The processor is loaded from the base-model repo, not the adapter repo.
qwen_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
qwen_model.eval()
print("✅ Qwen loaded!")
# 3. MusicGen Model
print("🎵 Loading MusicGen model...")
# Processor and model both come from the fine-tuned MusicGen repo.
music_processor = AutoProcessor.from_pretrained("Bashaarat1/fine-tuned-musicgen-small")
music_model = MusicgenForConditionalGeneration.from_pretrained(
    "Bashaarat1/fine-tuned-musicgen-small"
).to(device)
music_model.eval()
print("✅ MusicGen loaded!")
# 4. Stable Diffusion Model
print("🎨 Loading Stable Diffusion model...")
# Half precision on CUDA, full precision on CPU.
# safety_checker=None builds the pipeline without a safety checker, so
# generated images are returned unfiltered.
# NOTE(review): the "runwayml/stable-diffusion-v1-5" repo id may no longer
# be served by the Hub — verify; a mirror is published as
# "stable-diffusion-v1-5/stable-diffusion-v1-5".
sd_pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    safety_checker=None
).to(device)
print("✅ Stable Diffusion loaded!")
# Reached only if every load above succeeded; none of them is wrapped in
# try/except, so any failure propagates at import time.
print("\n🎉 All 4 models loaded successfully!\n")
# ============================================================================
# INFERENCE FUNCTIONS
# ============================================================================
def summarize_text(text, max_length=128):
    """Summarize ``text`` with the fine-tuned T5 model.

    Args:
        text: Source text. ``None`` or whitespace-only input is rejected with
            a warning message instead of raising.
        max_length: Upper bound on the generated summary length, measured in
            T5 *tokens* (not words). The value is coerced to ``int`` because
            UI sliders may deliver floats.

    Returns:
        A Markdown string with the summary and a word-count footer, or a
        warning/error message. Exceptions are reported as text rather than
        raised, since this function is a UI callback.
    """
    if not text or not text.strip():
        return "⚠️ Please enter some text to summarize."
    try:
        max_length = int(max_length)
        # generate() needs min_length <= max_length; clamp the fixed floor so
        # small max_length values do not produce an inconsistent request.
        min_length = min(30, max_length)
        inputs = t5_tokenizer(
            f"summarize: {text}",
            return_tensors="pt",
            max_length=512,  # long inputs are truncated to 512 tokens
            truncation=True,
        ).to(device)
        with torch.no_grad():
            outputs = t5_model.generate(
                **inputs,
                max_length=max_length,
                min_length=min_length,
                num_beams=4,
                early_stopping=True,
            )
        summary = t5_tokenizer.decode(outputs[0], skip_special_tokens=True)
        return (
            f"📝 **Summary:**\n\n{summary}\n\n---\n"
            f"*Original: {len(text.split())} words → Summary: {len(summary.split())} words*"
        )
    except Exception as e:
        return f"❌ Error: {str(e)}"
def answer_question(question, image=None, max_new_tokens=200):
    """Answer ``question`` with the LoRA-tuned Qwen2-VL model.

    Args:
        question: User question. ``None`` or whitespace-only input is rejected
            with a warning message instead of raising.
        image: Optional image. NumPy arrays are converted to PIL, and every
            PIL image (e.g. RGBA or grayscale uploads) is normalized to RGB.
            Other values are passed through unchanged.
        max_new_tokens: Maximum number of tokens generated for the answer.

    Returns:
        A Markdown string with the answer, or a warning/error message.
        Exceptions are reported as text rather than raised, since this
        function is a UI callback.
    """
    if not question or not question.strip():
        return "⚠️ Please enter a question."
    try:
        content = [{"type": "text", "text": question}]
        if image is not None:
            if isinstance(image, np.ndarray):
                image = Image.fromarray(image)
            if isinstance(image, Image.Image):
                image = image.convert("RGB")
            # Keep the original ordering: image entry first, then the text.
            content.insert(0, {"type": "image", "image": image})
        messages = [{"role": "user", "content": content}]

        text_prompt = qwen_processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        processor_kwargs = {"text": [text_prompt], "return_tensors": "pt"}
        if image is not None:
            image_inputs, _video_inputs = process_vision_info(messages)
            processor_kwargs["images"] = image_inputs
        inputs = qwen_processor(**processor_kwargs).to(device)

        with torch.no_grad():
            outputs = qwen_model.generate(**inputs, max_new_tokens=int(max_new_tokens))
        # generate() returns prompt + continuation; decode only the new tokens.
        prompt_len = inputs.input_ids.size(1)
        answer = qwen_processor.batch_decode(
            outputs[:, prompt_len:], skip_special_tokens=True
        )[0].strip()
        return f"💡 **Answer:**\n\n{answer}"
    except Exception as e:
        return f"❌ Error: {str(e)}"
def generate_image(prompt, negative_prompt="", num_steps=25, guidance_scale=7.5):
    """Generate an image from ``prompt`` with the Stable Diffusion pipeline.

    Args:
        prompt: Image description. ``None`` or whitespace-only input is
            rejected with a warning message instead of raising.
        negative_prompt: Concepts to steer away from (may be empty).
        num_steps: Number of denoising steps. The value is coerced to ``int``
            because UI sliders may deliver floats.
        guidance_scale: Classifier-free guidance strength (default 7.5).

    Returns:
        ``(image, status_markdown)``. On failure ``image`` is ``None`` and the
        status describes the problem. Exceptions are reported as text rather
        than raised, since this function is a UI callback.
    """
    if not prompt or not prompt.strip():
        return None, "⚠️ Please enter an image description."
    try:
        with torch.no_grad():
            result = sd_pipe(
                prompt,
                negative_prompt=negative_prompt,
                num_inference_steps=int(num_steps),
                guidance_scale=guidance_scale,
            )
        image = result.images[0]
        return image, f"✅ **Image generated!**\n\n*Prompt: {prompt}*"
    except Exception as e:
        return None, f"❌ Error: {str(e)}"
def generate_music(prompt, duration=10):
    """Generate a music clip from ``prompt`` with the fine-tuned MusicGen model.

    Args:
        prompt: Music description. ``None`` or whitespace-only input is
            rejected with a warning message instead of raising.
        duration: Target clip length in seconds (approximate).

    Returns:
        ``((sampling_rate, samples), status_markdown)`` on success, where
        ``samples`` is the first channel of the first generated clip as a
        NumPy array. On failure returns ``(None, message)``. Exceptions are
        reported as text rather than raised, since this function is a UI
        callback.
    """
    if not prompt or not prompt.strip():
        return None, "⚠️ Please enter a music description."
    try:
        inputs = music_processor(
            text=[prompt],
            padding=True,
            return_tensors="pt",
        ).to(device)
        # The token budget is duration × codec frame rate. Read the frame rate
        # from the audio-encoder config when it exposes one; otherwise fall
        # back to the previously hard-coded 50 tokens per second.
        frame_rate = getattr(music_model.config.audio_encoder, "frame_rate", 50)
        max_tokens = int(duration * frame_rate)
        with torch.no_grad():
            audio_values = music_model.generate(
                **inputs, max_new_tokens=max_tokens, do_sample=True
            )
        sampling_rate = music_model.config.audio_encoder.sampling_rate
        audio_data = audio_values[0, 0].cpu().numpy()
        return (sampling_rate, audio_data), f"✅ **Music generated!**\n\n*Prompt: {prompt}*\n*Duration: ~{duration} seconds*"
    except Exception as e:
        return None, f"❌ Error: {str(e)}"
# ============================================================================
# GRADIO UI
# ============================================================================
# Build the two-level tabbed UI: an "Education" group (summarizer, Q&A) and a
# "Social Media" group (image, music). Each button forwards its inputs to the
# matching inference function defined above.
with gr.Blocks(title="ContentForge AI") as demo:
    gr.Markdown("""
# 🎨 ContentForge AI
**Multi-modal AI platform for education and social media content generation**
Powered by state-of-the-art fine-tuned models:
- 📝 Fine-tuned T5 (+46% improvement)
- 🤖 Qwen2-VL with LoRA for science Q&A
- 🎨 Stable Diffusion v1.5
- 🎵 Fine-tuned MusicGen
""")
    with gr.Tabs():
        with gr.Tab("📚 Education Tools"):
            gr.Markdown("## AI-powered tools for learning and research")
            # Explicit container for the nested tab group.
            with gr.Tabs():
                with gr.Tab("📝 Text Summarizer"):
                    gr.Markdown("### Summarize academic papers, articles, and long texts")
                    with gr.Row():
                        with gr.Column():
                            sum_input = gr.Textbox(
                                label="Text to Summarize",
                                placeholder="Paste your academic paper, article, or long text here...",
                                lines=10,
                            )
                            # summarize_text() passes this value to T5's
                            # generate(max_length=...), which counts tokens,
                            # not words, so the label says tokens.
                            sum_length = gr.Slider(
                                minimum=50,
                                maximum=200,
                                value=128,
                                step=10,
                                label="Max Summary Length (tokens)",
                            )
                            sum_button = gr.Button("🪄 Generate Summary", variant="primary", size="lg")
                        with gr.Column():
                            sum_output = gr.Markdown(label="Summary")
                    gr.Examples(
                        examples=[
                            ["We present a novel approach to neural network optimization using adaptive learning rates. Our method dynamically adjusts the learning rate based on gradient statistics during training. Experiments on ImageNet show 15% improvement over standard SGD with minimal computational overhead."]
                        ],
                        inputs=sum_input,
                    )
                    sum_button.click(
                        fn=summarize_text,
                        inputs=[sum_input, sum_length],
                        outputs=sum_output,
                    )
                with gr.Tab("🤖 Q&A Assistant"):
                    gr.Markdown("### Ask questions with optional image support")
                    with gr.Row():
                        with gr.Column():
                            qa_question = gr.Textbox(
                                label="Your Question",
                                placeholder="Ask anything...",
                                lines=3,
                            )
                            qa_image = gr.Image(
                                label="Upload Image (Optional)",
                                type="pil",
                            )
                            qa_button = gr.Button("💬 Get Answer", variant="primary", size="lg")
                        with gr.Column():
                            qa_output = gr.Markdown(label="Answer")
                    gr.Examples(
                        examples=[
                            ["What is machine learning?", None],
                            ["Explain photosynthesis in simple terms.", None],
                        ],
                        inputs=[qa_question, qa_image],
                    )
                    qa_button.click(
                        fn=answer_question,
                        inputs=[qa_question, qa_image],
                        outputs=qa_output,
                    )
        with gr.Tab("🎨 Social Media Tools"):
            gr.Markdown("## Create stunning content for your audience")
            with gr.Tabs():
                with gr.Tab("🖼️ Image Generator"):
                    gr.Markdown("### Generate professional images from text descriptions")
                    with gr.Row():
                        with gr.Column():
                            img_prompt = gr.Textbox(
                                label="Image Description",
                                placeholder="Describe the image you want to generate...",
                                lines=3,
                            )
                            img_negative = gr.Textbox(
                                label="Negative Prompt (Optional)",
                                placeholder="What to avoid (e.g., blur, low quality, distorted)",
                                lines=2,
                            )
                            img_steps = gr.Slider(
                                minimum=10,
                                maximum=50,
                                value=25,
                                step=5,
                                label="Quality (inference steps)",
                            )
                            img_button = gr.Button("🎨 Generate Image", variant="primary", size="lg")
                        with gr.Column():
                            img_output = gr.Image(label="Generated Image")
                            img_status = gr.Markdown()
                    gr.Examples(
                        examples=[
                            ["A serene mountain landscape at sunset, photorealistic, 4k"]
                        ],
                        inputs=img_prompt,
                    )
                    img_button.click(
                        fn=generate_image,
                        inputs=[img_prompt, img_negative, img_steps],
                        outputs=[img_output, img_status],
                    )
                with gr.Tab("🎵 Music Generator"):
                    gr.Markdown("### Generate royalty-free music from text descriptions")
                    with gr.Row():
                        with gr.Column():
                            music_prompt = gr.Textbox(
                                label="Music Description",
                                placeholder="Describe the music you want (mood, genre, instruments)...",
                                lines=3,
                            )
                            music_duration = gr.Slider(
                                minimum=5,
                                maximum=20,
                                value=10,
                                step=5,
                                label="Duration (seconds)",
                            )
                            music_button = gr.Button("🎼 Generate Music", variant="primary", size="lg")
                        with gr.Column():
                            music_output = gr.Audio(label="Generated Music")
                            music_status = gr.Markdown()
                    gr.Examples(
                        examples=[
                            ["upbeat electronic dance music with energetic drums"]
                        ],
                        inputs=music_prompt,
                    )
                    music_button.click(
                        fn=generate_music,
                        inputs=[music_prompt, music_duration],
                        outputs=[music_output, music_status],
                    )
    gr.Markdown("""
---
**About ContentForge AI**
Multi-modal AI platform demonstrating fine-tuned models for education and social media.
*Built with ❤️ using Gradio and Transformers*
""")

if __name__ == "__main__":
    demo.launch()