image-caption-generator/app.py
2025-01-14 13:48:54 +07:00

42 lines
1.4 KiB
Python

import io

import torch
from fastapi import FastAPI, HTTPException, UploadFile, Request
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from PIL import Image
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
# Web application object, with the "static" directory served under /static
# and HTML templates loaded from the "templates" directory.
app = FastAPI()
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")

# Run on the GPU when torch reports CUDA as available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# The tokenizer, image processor and encoder-decoder model are all loaded
# from the same pretrained checkpoint identifier, once, at import time.
tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
processor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")

# Move the model weights onto the selected device.
model.to(device)
def generate_caption(image: Image.Image) -> str:
    """Produce a text caption for a single PIL image.

    The image is run through the module-level ``processor``, the resulting
    inputs are moved to ``device`` and fed to ``model.generate``, and the
    first generated sequence is decoded with ``tokenizer``.

    Args:
        image: The image to describe.

    Returns:
        The decoded caption with special tokens removed.
    """
    # return_tensors="pt" requests PyTorch tensors from the processor.
    model_inputs = processor(images=image, return_tensors="pt").to(device)
    generated_ids = model.generate(**model_inputs)
    # Only the first generated sequence is decoded.
    return tokenizer.decode(generated_ids[0], skip_special_tokens=True)
@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
    """Serve the landing page by rendering the ``index.html`` template.

    Args:
        request: The incoming request, passed to the template context
            under the ``"request"`` key.

    Returns:
        The template response produced by ``templates.TemplateResponse``.
    """
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
@app.post("/upload/")
async def upload_image(file: UploadFile):
    """Caption an uploaded image.

    Reads the entire upload into memory, decodes it with Pillow, converts it
    to RGB and passes it to ``generate_caption``.

    Args:
        file: The uploaded file.

    Returns:
        A dict of the form ``{"caption": <str>}``.

    Raises:
        HTTPException: With status 400 when the payload cannot be decoded
            as an image (empty, corrupt, unsupported format, or rejected by
            Pillow's decompression-bomb guard).
    """
    payload = await file.read()
    try:
        # Decoding happens lazily in open(); convert() forces the pixel data
        # to load, so both calls must sit inside the try.
        image = Image.open(io.BytesIO(payload)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as err:
        # Pillow's UnidentifiedImageError is an OSError subclass, so this
        # also covers "not an image at all". A bad upload is a client error,
        # not a server fault.
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a valid image.",
        ) from err

    # NOTE(review): generate_caption is a synchronous call made directly from
    # this coroutine, so it occupies the event loop until inference finishes.
    caption = generate_caption(image)
    return {"caption": caption}