"""FastAPI app that captions uploaded images with the ViT-GPT2 model."""

from fastapi import FastAPI, HTTPException, Request, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
from PIL import Image, UnidentifiedImageError
import torch
import io

app = FastAPI()
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")

# Load the captioning model once at import time; it lives for the process
# lifetime and is shared by all requests (read-only after setup).
MODEL_NAME = "nlpconnect/vit-gpt2-image-captioning"
model = VisionEncoderDecoderModel.from_pretrained(MODEL_NAME)
processor = ViTImageProcessor.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()  # explicit: make sure dropout/batch-norm are in inference mode


def generate_caption(image: Image.Image) -> str:
    """Generate a caption for *image*.

    Blocking and compute-bound — callers on the event loop must dispatch
    this to a worker thread (see ``upload_image``).
    """
    inputs = processor(images=image, return_tensors="pt").to(device)
    # inference_mode skips autograd bookkeeping: less memory, faster generate.
    with torch.inference_mode():
        outputs = model.generate(**inputs)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
    """Serve the upload page."""
    return templates.TemplateResponse("index.html", {"request": request})


@app.post("/upload/")
async def upload_image(file: UploadFile):
    """Caption an uploaded image.

    Returns ``{"caption": <str>}`` on success, or HTTP 400 when the upload
    is not a decodable image.
    """
    data = await file.read()
    try:
        image = Image.open(io.BytesIO(data)).convert("RGB")
    except UnidentifiedImageError:
        # Reject non-image payloads cleanly instead of surfacing a 500.
        raise HTTPException(status_code=400, detail="Uploaded file is not a valid image.")
    # model.generate is blocking; run it in the threadpool so the event loop
    # keeps serving other requests during inference.
    caption = await run_in_threadpool(generate_caption, image)
    return {"caption": caption}