"""Gradio demo that runs a fine-tuned DeepSeek-OCR model on an uploaded image.

The model is loaded once at import time (Hugging Face Hub first, a local
``./deepseek_ocr`` checkout as fallback) and then reused by every request.
"""

import contextlib
import io
import os
import tempfile

import gradio as gr
import torch
from PIL import Image

# NOTE: the original import order is kept on purpose -- unsloth is imported
# ahead of transformers (Unsloth warns when it is imported after it).
from unsloth import FastVisionModel
from transformers import AutoTokenizer  # not used for loading; FastVisionModel returns the tokenizer

# Replace with your actual Hugging Face repo ID.
HF_REPO_ID = "obay123/finetuned-deepseek-ocr-model"

# Keyword arguments shared by the primary and the fallback load so the two
# calls cannot drift apart.
_MODEL_LOAD_KWARGS = {
    "load_in_4bit": False,  # full-precision inference; True trades quality for VRAM
    "auto_model": None,  # FastVisionModel picks the model class itself
    "trust_remote_code": True,  # DeepSeek-OCR ships custom modeling code
    "unsloth_force_compile": False,  # not needed for inference after finetuning
}

# Lines printed by ``model.infer`` around the recognized text.
_DIVIDER = "====================="
_SAVE_MARKER = "===============save results:==============="
_DIAGNOSTIC_PREFIXES = ("BASE:", "PATCHES:")

# The DeepSeek-OCR model card uses an ``<image>`` placeholder in front of the
# instruction; without it the prompt carries no image slot.
_OCR_PROMPT = "<image>\nFree OCR. "

try:
    model, tokenizer = FastVisionModel.from_pretrained(
        HF_REPO_ID,
        **_MODEL_LOAD_KWARGS,
    )
except Exception as e:  # top-level boundary: report, then try the local copy
    print(f"Error loading model from Hugging Face Hub: {e}")
    print("Attempting to load from local deepseek_ocr if available for debugging/local testing.")
    # This path needs to be available in the Space or downloaded beforehand.
    model, tokenizer = FastVisionModel.from_pretrained(
        "./deepseek_ocr",
        **_MODEL_LOAD_KWARGS,
    )
    print("Loaded model from local deepseek_ocr as a fallback.")

# Set model to evaluation mode.
model.eval()


def _extract_ocr_text(captured: str) -> str:
    """Pull the recognized text out of the stdout captured from ``model.infer``.

    The expected layout is: divider, ``BASE:``/``PATCHES:`` diagnostics,
    divider, recognized text, and optionally a "save results" marker line.
    The marker is assumed to be printed only when results are saved, so its
    absence simply means the text runs to the end of the output.

    Args:
        captured: Everything written to stdout during inference.

    Returns:
        The recognized text, or an empty string when no divider was printed.
    """
    lines = [line for line in captured.strip().split("\n") if line.strip()]

    divider_positions = []
    end = len(lines)
    for index, line in enumerate(lines):
        if line == _DIVIDER:
            divider_positions.append(index)
        elif line == _SAVE_MARKER:
            end = index
            break

    if len(divider_positions) >= 2:
        # Normal case: the text starts right after the second divider.
        body = lines[divider_positions[1] + 1:end]
    elif divider_positions:
        # Only one divider was printed: take what follows it and drop the
        # diagnostic lines that would otherwise leak into the result.
        body = [
            line
            for line in lines[divider_positions[0] + 1:end]
            if not line.startswith(_DIAGNOSTIC_PREFIXES)
        ]
    else:
        body = []

    return "\n".join(body).strip()


def ocr_image(image: Image.Image) -> str:
    """Run OCR on *image* and return the recognized text.

    Args:
        image: The uploaded picture, or ``None`` when nothing was uploaded.

    Returns:
        The recognized text, or a short prompt asking for an upload when
        *image* is ``None``.
    """
    if image is None:
        return "Please upload an image."

    captured = io.StringIO()
    # ``model.infer`` expects a file path. A per-request temporary directory
    # keeps overlapping requests from overwriting each other's input and is
    # removed even when saving or inference raises.
    with tempfile.TemporaryDirectory() as scratch_dir:
        image_path = os.path.join(scratch_dir, "input.png")
        image.save(image_path)

        # The recognized text is read from what ``model.infer`` prints.
        # NOTE(review): redirect_stdout swaps the process-wide sys.stdout, so
        # truly parallel requests could interleave -- confirm the queue
        # settings used in deployment.
        with contextlib.redirect_stdout(captured):
            model.infer(
                tokenizer,
                prompt=_OCR_PROMPT,
                image_file=image_path,
                output_path=".",  # generated files go here; unused for the text result
                image_size=640,
                base_size=1024,
                crop_mode=True,
                save_results=False,  # only the text output is needed
                test_compress=False,
                max_new_tokens=512,  # adjust as needed for longer texts
            )

    return _extract_ocr_text(captured.getvalue())


# Create Gradio interface.
if __name__ == "__main__":
    gr.Interface(
        fn=ocr_image,
        inputs=gr.Image(type="pil", label="Upload Image"),
        outputs="text",
        title="DeepSeek-OCR Fine-tuned Model",
        description="Upload an image to perform OCR using a fine-tuned DeepSeek-OCR model. The model will extract text from the image.",
        live=False,  # Set to False for better stability in deployed Spaces
    ).launch()