Spaces:
Paused
Paused
| # ============================================================================ | |
| # CONTENTFORGE AI - FINAL WORKING VERSION | |
| # Multi-modal AI platform with fine-tuned models | |
| # ============================================================================ | |
| import gradio as gr | |
| import torch | |
| import os | |
| from huggingface_hub import login | |
| # ============================================================================ | |
| # AUTHENTICATION | |
| # ============================================================================ | |
# The Hub token is optional: it is read from the environment, and the app
# keeps starting without it (only a warning is printed).
HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    print("⚠️ No HF_TOKEN found - some models may fail to load\n")
else:
    print("🔐 Authenticating with HuggingFace...")
    login(token=HF_TOKEN)
    print("✅ Authenticated!\n")
| from transformers import ( | |
| T5Tokenizer, T5ForConditionalGeneration, | |
| Qwen2VLForConditionalGeneration, Qwen2VLProcessor, | |
| AutoProcessor, MusicgenForConditionalGeneration | |
| ) | |
| from peft import PeftModel | |
| from qwen_vl_utils import process_vision_info | |
| from diffusers import StableDiffusionPipeline | |
| from PIL import Image | |
| import numpy as np | |
# Single device string shared by every model and by the inference functions.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"🖥️ Using device: {device}")
print("📦 Loading models... This may take 2-3 minutes on first run.\n")

# ============================================================================
# MODEL LOADING
# ============================================================================
# Everything below runs at import time; each model is switched to eval mode
# right after loading because the app only performs inference.

# 1. T5 Summarization Model (tokenizer and weights come from the same repo)
print("📝 Loading T5 model...")
_T5_REPO = "Bashaarat1/t5-small-arxiv-summarizer"
t5_tokenizer = T5Tokenizer.from_pretrained(_T5_REPO)
t5_model = T5ForConditionalGeneration.from_pretrained(_T5_REPO)
t5_model = t5_model.to(device)
t5_model.eval()
print("✅ T5 loaded!")

# 2. Qwen VLM Q&A Model: base checkpoint wrapped with the LoRA adapter
print("🤖 Loading Qwen2-VL base model...")
_QWEN_BASE_REPO = "Qwen/Qwen2-VL-2B-Instruct"
qwen_base = Qwen2VLForConditionalGeneration.from_pretrained(
    _QWEN_BASE_REPO,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
print("🔧 Loading YOUR fine-tuned LoRA adapter...")
qwen_model = PeftModel.from_pretrained(
    qwen_base,
    "Bashaarat1/qwen-finetuned-scienceqa",
)
# The processor is taken from the base repo, not from the adapter repo.
qwen_processor = Qwen2VLProcessor.from_pretrained(_QWEN_BASE_REPO)
qwen_model.eval()
print("✅ Qwen loaded!")

# 3. MusicGen Model (processor and weights come from the same repo)
print("🎵 Loading MusicGen model...")
_MUSICGEN_REPO = "Bashaarat1/fine-tuned-musicgen-small"
music_processor = AutoProcessor.from_pretrained(_MUSICGEN_REPO)
music_model = MusicgenForConditionalGeneration.from_pretrained(_MUSICGEN_REPO)
music_model = music_model.to(device)
music_model.eval()
print("✅ MusicGen loaded!")

# 4. Stable Diffusion Model
print("🎨 Loading Stable Diffusion model...")
# Half precision is only requested when running on CUDA.
_sd_dtype = torch.float16 if device == "cuda" else torch.float32
# NOTE(review): the pipeline is loaded with safety_checker=None, and the repo
# id below should be confirmed to still resolve on the Hub.
sd_pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=_sd_dtype,
    safety_checker=None,
)
sd_pipe = sd_pipe.to(device)
print("✅ Stable Diffusion loaded!")

print("\n🎉 All 4 models loaded successfully!\n")
| # ============================================================================ | |
| # INFERENCE FUNCTIONS | |
| # ============================================================================ | |
def summarize_text(text, max_length=128):
    """Summarize ``text`` with the fine-tuned T5 model.

    Args:
        text: Text to summarize. ``None``, empty, or whitespace-only input
            returns a warning message without touching the model.
        max_length: Value forwarded to ``generate(max_length=...)``. It is
            coerced to ``int`` first, because UI sliders may hand over a
            float.

    Returns:
        A Markdown string: the summary followed by original/summary word
        counts, a warning for blank input, or an ``❌ Error: ...`` message
        if anything raises during tokenization or generation.
    """
    # `not text` also covers None, which would otherwise raise on .strip()
    # before the try block could turn it into a message.
    if not text or not text.strip():
        return "⚠️ Please enter some text to summarize."
    try:
        length_cap = int(max_length)
        # The input is prefixed with "summarize: " and truncated to 512 tokens.
        inputs = t5_tokenizer(
            f"summarize: {text}",
            return_tensors="pt",
            max_length=512,
            truncation=True,
        ).to(device)
        with torch.no_grad():
            outputs = t5_model.generate(
                **inputs,
                max_length=length_cap,
                # Keep the floor at 30 but never above the requested cap.
                min_length=min(30, length_cap),
                num_beams=4,
                early_stopping=True,
            )
        summary = t5_tokenizer.decode(outputs[0], skip_special_tokens=True)
        return f"📝 **Summary:**\n\n{summary}\n\n---\n*Original: {len(text.split())} words → Summary: {len(summary.split())} words*"
    except Exception as e:
        # UI boundary: report the failure as text instead of raising.
        return f"❌ Error: {str(e)}"
def answer_question(question, image=None):
    """Answer ``question`` with the Qwen2-VL model, optionally about an image.

    Args:
        question: The user's question. ``None``, empty, or whitespace-only
            input returns a warning message without touching the model.
        image: Optional image. A NumPy array is converted to an RGB PIL
            image; any other non-``None`` value is passed through unchanged.

    Returns:
        A Markdown string with the answer, a warning for a blank question,
        or an ``❌ Error: ...`` message if anything raises.
    """
    # `not question` also covers None, which would otherwise raise on
    # .strip() before the try block could turn it into a message.
    if not question or not question.strip():
        return "⚠️ Please enter a question."
    try:
        has_image = image is not None
        if has_image and isinstance(image, np.ndarray):
            image = Image.fromarray(image).convert('RGB')

        # One user turn; the image entry (if any) precedes the text entry.
        content = []
        if has_image:
            content.append({"type": "image", "image": image})
        content.append({"type": "text", "text": question})
        messages = [{"role": "user", "content": content}]

        text_prompt = qwen_processor.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

        # The `images` argument is only supplied when an image was given.
        processor_kwargs = {"text": [text_prompt], "return_tensors": "pt"}
        if has_image:
            img_inputs, _ = process_vision_info(messages)
            processor_kwargs["images"] = img_inputs
        inputs = qwen_processor(**processor_kwargs).to(device)

        with torch.no_grad():
            outputs = qwen_model.generate(**inputs, max_new_tokens=200)

        # Drop the prompt positions so only newly generated tokens are decoded.
        prompt_len = inputs.input_ids.size(1)
        answer = qwen_processor.batch_decode(
            outputs[:, prompt_len:],
            skip_special_tokens=True,
        )[0].strip()
        return f"💡 **Answer:**\n\n{answer}"
    except Exception as e:
        # UI boundary: report the failure as text instead of raising.
        return f"❌ Error: {str(e)}"
def generate_image(prompt, negative_prompt="", num_steps=25):
    """Generate one image from a text prompt with the Stable Diffusion pipeline.

    Args:
        prompt: Image description. ``None``, empty, or whitespace-only input
            returns a warning without running the pipeline.
        negative_prompt: Forwarded unchanged as the pipeline's
            ``negative_prompt``.
        num_steps: Forwarded as ``num_inference_steps`` after coercion to
            ``int``, because UI sliders may hand over a float.

    Returns:
        ``(image, status)``: the first image of the pipeline result and a
        Markdown status line, or ``(None, message)`` for blank input or when
        anything raises.
    """
    # `not prompt` also covers None, which would otherwise raise on .strip()
    # before the try block could turn it into a message.
    if not prompt or not prompt.strip():
        return None, "⚠️ Please enter an image description."
    try:
        with torch.no_grad():
            result = sd_pipe(
                prompt,
                negative_prompt=negative_prompt,
                num_inference_steps=int(num_steps),
                guidance_scale=7.5,
            )
        image = result.images[0]
        return image, f"✅ **Image generated!**\n\n*Prompt: {prompt}*"
    except Exception as e:
        # UI boundary: report the failure as text instead of raising.
        return None, f"❌ Error: {str(e)}"
def generate_music(prompt, duration=10):
    """Generate an audio clip from a text prompt with the MusicGen model.

    Args:
        prompt: Music description. ``None``, empty, or whitespace-only input
            returns a warning without running the model.
        duration: Requested length in seconds; it is turned into a
            generation budget of ``duration * 50`` new tokens.

    Returns:
        ``(audio, status)`` where ``audio`` is ``(sampling_rate, samples)``
        with ``samples`` a NumPy array, plus a Markdown status line; or
        ``(None, message)`` for blank input or when anything raises.
    """
    # `not prompt` also covers None, which would otherwise raise on .strip()
    # before the try block could turn it into a message.
    if not prompt or not prompt.strip():
        return None, "⚠️ Please enter a music description."
    try:
        inputs = music_processor(
            text=[prompt],
            padding=True,
            return_tensors="pt",
        ).to(device)

        # NOTE(review): 50 tokens per second is assumed here; confirm it
        # matches the loaded checkpoint's audio encoder frame rate.
        tokens_per_second = 50
        max_tokens = int(duration * tokens_per_second)

        with torch.no_grad():
            audio_values = music_model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                do_sample=True,
            )

        sampling_rate = music_model.config.audio_encoder.sampling_rate
        # Index [0, 0] picks a single waveform out of the generated tensor
        # (presumably batch item 0, channel 0 — confirm the output layout).
        audio_data = audio_values[0, 0].cpu().numpy()
        return (sampling_rate, audio_data), f"✅ **Music generated!**\n\n*Prompt: {prompt}*\n*Duration: ~{duration} seconds*"
    except Exception as e:
        # UI boundary: report the failure as text instead of raising.
        return None, f"❌ Error: {str(e)}"
| # ============================================================================ | |
| # GRADIO UI | |
| # ============================================================================ | |
# Page copy lives at module level so the Markdown text starts at column 0.
_INTRO_MARKDOWN = """
# 🎨 ContentForge AI
**Multi-modal AI platform for education and social media content generation**
Powered by state-of-the-art fine-tuned models:
- 📝 Fine-tuned T5 (+46% improvement)
- 🤖 Qwen2-VL with LoRA for science Q&A
- 🎨 Stable Diffusion v1.5
- 🎵 Fine-tuned MusicGen
"""

_FOOTER_MARKDOWN = """
---
**About ContentForge AI**
Multi-modal AI platform demonstrating fine-tuned models for education and social media.
*Built with ❤️ using Gradio and Transformers*
"""

# Layout: an intro, two top-level tabs (education / social media) that each
# hold two tool tabs, and a footer. Every tool tab wires one button to one of
# the inference functions defined above.
# NOTE(review): the nesting below was reconstructed from statement order;
# confirm it against the deployed layout.
with gr.Blocks(title="ContentForge AI") as demo:
    gr.Markdown(_INTRO_MARKDOWN)

    with gr.Tabs():
        with gr.Tab("📚 Education Tools"):
            gr.Markdown("## AI-powered tools for learning and research")

            with gr.Tab("📝 Text Summarizer"):
                gr.Markdown("### Summarize academic papers, articles, and long texts")
                with gr.Row():
                    with gr.Column():
                        sum_input = gr.Textbox(
                            label="Text to Summarize",
                            placeholder="Paste your academic paper, article, or long text here...",
                            lines=10,
                        )
                        # This value becomes summarize_text's `max_length`,
                        # which is handed to generate(max_length=...), so it
                        # limits tokens rather than words; the label says so.
                        sum_length = gr.Slider(
                            minimum=50,
                            maximum=200,
                            value=128,
                            step=10,
                            label="Max Summary Length (tokens)",
                        )
                        sum_button = gr.Button("🪄 Generate Summary", variant="primary", size="lg")
                    with gr.Column():
                        sum_output = gr.Markdown(label="Summary")
                gr.Examples(
                    examples=[
                        ["We present a novel approach to neural network optimization using adaptive learning rates. Our method dynamically adjusts the learning rate based on gradient statistics during training. Experiments on ImageNet show 15% improvement over standard SGD with minimal computational overhead."]
                    ],
                    inputs=sum_input,
                )
                sum_button.click(
                    fn=summarize_text,
                    inputs=[sum_input, sum_length],
                    outputs=sum_output,
                )

            with gr.Tab("🤖 Q&A Assistant"):
                gr.Markdown("### Ask questions with optional image support")
                with gr.Row():
                    with gr.Column():
                        qa_question = gr.Textbox(
                            label="Your Question",
                            placeholder="Ask anything...",
                            lines=3,
                        )
                        qa_image = gr.Image(
                            label="Upload Image (Optional)",
                            type="pil",
                        )
                        qa_button = gr.Button("💬 Get Answer", variant="primary", size="lg")
                    with gr.Column():
                        qa_output = gr.Markdown(label="Answer")
                # Both examples are text-only (no image supplied).
                gr.Examples(
                    examples=[
                        ["What is machine learning?", None],
                        ["Explain photosynthesis in simple terms.", None],
                    ],
                    inputs=[qa_question, qa_image],
                )
                qa_button.click(
                    fn=answer_question,
                    inputs=[qa_question, qa_image],
                    outputs=qa_output,
                )

        with gr.Tab("🎨 Social Media Tools"):
            gr.Markdown("## Create stunning content for your audience")

            with gr.Tab("🖼️ Image Generator"):
                gr.Markdown("### Generate professional images from text descriptions")
                with gr.Row():
                    with gr.Column():
                        img_prompt = gr.Textbox(
                            label="Image Description",
                            placeholder="Describe the image you want to generate...",
                            lines=3,
                        )
                        img_negative = gr.Textbox(
                            label="Negative Prompt (Optional)",
                            placeholder="What to avoid (e.g., blur, low quality, distorted)",
                            lines=2,
                        )
                        img_steps = gr.Slider(
                            minimum=10,
                            maximum=50,
                            value=25,
                            step=5,
                            label="Quality (inference steps)",
                        )
                        img_button = gr.Button("🎨 Generate Image", variant="primary", size="lg")
                    with gr.Column():
                        img_output = gr.Image(label="Generated Image")
                        img_status = gr.Markdown()
                gr.Examples(
                    examples=[
                        ["A serene mountain landscape at sunset, photorealistic, 4k"]
                    ],
                    inputs=img_prompt,
                )
                img_button.click(
                    fn=generate_image,
                    inputs=[img_prompt, img_negative, img_steps],
                    outputs=[img_output, img_status],
                )

            with gr.Tab("🎵 Music Generator"):
                gr.Markdown("### Generate royalty-free music from text descriptions")
                with gr.Row():
                    with gr.Column():
                        music_prompt = gr.Textbox(
                            label="Music Description",
                            placeholder="Describe the music you want (mood, genre, instruments)...",
                            lines=3,
                        )
                        music_duration = gr.Slider(
                            minimum=5,
                            maximum=20,
                            value=10,
                            step=5,
                            label="Duration (seconds)",
                        )
                        music_button = gr.Button("🎼 Generate Music", variant="primary", size="lg")
                    with gr.Column():
                        music_output = gr.Audio(label="Generated Music")
                        music_status = gr.Markdown()
                gr.Examples(
                    examples=[
                        ["upbeat electronic dance music with energetic drums"]
                    ],
                    inputs=music_prompt,
                )
                music_button.click(
                    fn=generate_music,
                    inputs=[music_prompt, music_duration],
                    outputs=[music_output, music_status],
                )

    gr.Markdown(_FOOTER_MARKDOWN)

# Start the server only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()