# Mirror of https://github.com/Sosokker/image-caption-generator.git, synced 2025-12-18 01:54:05 +01:00
from fastapi import FastAPI, UploadFile, Request
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
from PIL import Image
import torch
import io

app = FastAPI()

# Serve static assets and Jinja2 HTML templates for the web front end.
app.mount("/static", StaticFiles(directory="static"), name="static")

templates = Jinja2Templates(directory="templates")

# Pretrained ViT-GPT2 captioning model, image processor, and tokenizer from the
# Hugging Face Hub; the weights are downloaded and cached on first use.
model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
processor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")

# Run inference on the GPU when one is available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)


def generate_caption(image: Image.Image) -> str:
    """Generate a caption for the uploaded image."""
    # Preprocess the image into pixel values, generate caption token ids,
    # and decode them back into a plain-text caption.
    inputs = processor(images=image, return_tensors="pt").to(device)
    outputs = model.generate(**inputs)
    caption = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return caption
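

# Not in the original file: a sketch of the same captioning call with explicit
# decoding settings. max_length and num_beams are standard arguments to
# transformers' generate(); the default values below are illustrative only.
def generate_caption_beam(image: Image.Image, max_length: int = 16, num_beams: int = 4) -> str:
    """Variant of generate_caption using beam search and a length cap."""
    inputs = processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():  # gradients are not needed for inference
        outputs = model.generate(**inputs, max_length=max_length, num_beams=num_beams)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)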


# Renders templates/index.html as the landing page.
@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
    return templates.TemplateResponse("index.html", {"request": request})


@app.post("/upload/")
async def upload_image(file: UploadFile):
    # Read the uploaded bytes, decode them as an RGB image, and caption it.
    image = Image.open(io.BytesIO(await file.read())).convert("RGB")
    caption = generate_caption(image)
    return {"caption": caption}
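

# Not part of the original file: a convenience entry point for running the app
# directly with uvicorn (assumed to be installed alongside FastAPI). With the
# server up, the upload endpoint can be exercised with, for example:
#   curl -F "file=@example.jpg" http://127.0.0.1:8000/upload/
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="127.0.0.1", port=8000)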