
A ViT model that summarizes a 22-minute video into 1 minute

The idea: sample one frame every two seconds, score each frame with a ViT encoder plus a linear head, keep the seconds whose sigmoid score clears a threshold, and concatenate those segments into a short summary clip.

 

Data

Data - Find AI Data - AI-Hub (aihub.or.kr)

 


Trained model

Video summarization model - Google Drive

 


 

Run

import torch
from training.summary.datamodule import SummaryDataset
from transformers import ViTImageProcessor
from tqdm import tqdm
import matplotlib.pyplot as plt
import cv2
import seaborn as sns
import numpy as np
from moviepy.editor import VideoFileClip, concatenate_videoclips
from v2021 import SummaryModel

# The image processor runs on the CPU; tensors are moved to the GPU after preprocessing.
preprocessor = ViTImageProcessor.from_pretrained(
    "google/vit-base-patch16-224", size=224
)

SAMPLE_EVERY_SEC = 2
video_path = 'videos/news2.mp4'
cap = cv2.VideoCapture(video_path)
n_frames = cap.get(cv2.CAP_PROP_FRAME_COUNT)
fps = cap.get(cv2.CAP_PROP_FPS)
video_len = n_frames / fps
print(f'Video length {video_len:.2f} seconds!')

frames = []
last_collected = -1
while cap.isOpened():
    ret, frame = cap.read()
    if not ret:
        break

    timestamp = cap.get(cv2.CAP_PROP_POS_MSEC)
    second = timestamp // 1000

    if second % SAMPLE_EVERY_SEC == 0 and second != last_collected:
        last_collected = second
        # OpenCV decodes frames as BGR; the ViT preprocessor expects RGB.
        frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

cap.release()

features = preprocessor(images=frames, return_tensors="pt")["pixel_values"]
print(features.shape)


model = SummaryModel.load_from_checkpoint('summary.ckpt')
model.to('cuda')
model.eval()


features = features.to('cuda')
y_pred = []
for frame in tqdm(features):
    y_p = model(frame.unsqueeze(0))
    y_p = torch.sigmoid(y_p)

    y_pred.append(y_p.cpu().detach().numpy().squeeze())

y_pred = np.array(y_pred)
sns.displot(y_pred)  # distribution of per-frame keep scores
plt.show()

THRESHOLD = 0.67
total_secs = 0

for i, y_p in enumerate(y_pred):
    if y_p >= THRESHOLD:
        print(i * SAMPLE_EVERY_SEC)
        total_secs += SAMPLE_EVERY_SEC

print(total_secs)  # total length (in seconds) of the selected segments


clip = VideoFileClip(video_path)
subclips = []
for i, y_p in enumerate(y_pred):
    sec = i * SAMPLE_EVERY_SEC

    if y_p >= THRESHOLD:
        # Clamp the end time so the last selected segment cannot run past the end of the video.
        subclips.append(clip.subclip(sec, min(sec + SAMPLE_EVERY_SEC, clip.duration)))

result = concatenate_videoclips(subclips)
result.write_videofile("videos/result.mp4")
result.ipython_display(width=640, maxduration=240)
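
To see which two-second segments clear the cut and to pick THRESHOLD more deliberately, a per-second plot of the scores can help. This optional snippet only uses y_pred, SAMPLE_EVERY_SEC, and THRESHOLD from the script above; the figure styling is just an illustration.

secs = np.arange(len(y_pred)) * SAMPLE_EVERY_SEC
plt.figure(figsize=(10, 3))
plt.plot(secs, y_pred, label="keep score")
plt.axhline(THRESHOLD, color="red", linestyle="--", label=f"threshold = {THRESHOLD}")
plt.xlabel("video time (s)")
plt.ylabel("sigmoid score")
plt.legend()
plt.tight_layout()
plt.show()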

 

Model

import argparse
from collections import defaultdict

import numpy as np
import torch
import torch.nn as nn
from pytorch_lightning.core.lightning import LightningModule
from torch import optim
from torchmetrics import F1  # renamed to F1Score in newer torchmetrics releases
from transformers import ViTModel


class SummaryModel(LightningModule):
    def __init__(self, hidden_dim=768, individual_logs=None):
        super().__init__()
        # ViT backbone gives a pooled 768-dim representation per frame;
        # a single linear layer turns it into one "keep this frame" logit.
        self.vit = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
        self.scorer = nn.Linear(hidden_dim, 1)
        # BCEWithLogitsLoss applies the sigmoid internally, so forward() returns raw logits.
        self.loss = nn.BCEWithLogitsLoss()
        # The F1 metrics are required by the training/validation/test steps below.
        self.train_f1 = F1()
        self.val_f1 = F1()
        self.test_f1 = F1()
        self.individual_logs = individual_logs
        self.tta_logs = defaultdict(list)

    def forward(self, x):
        # x: (n_frames, 3, 224, 224) -> pooled ViT features -> one logit per frame.
        x = self.vit(x).pooler_output
        x = self.scorer(x)
        return x

    def run_batch(self, batch, batch_idx, metric, training=False):
        video_name, image_features, labels = batch
        video_name = video_name[0]
        image_features = image_features.squeeze(0)
        labels = labels.squeeze(0)

        # Collapse the stacked labels into a per-frame binary target:
        # sum over the first label dimension and clip the result to 1.
        score = torch.sum(labels, dim=0)
        score = torch.min(
            score,
            torch.ones(
                score.shape[0],
            ).to(score.device),
        )
        out = self(image_features).squeeze(1)
        try:
            loss = self.loss(out.double(), score)
            preds = (torch.sigmoid(out) > 0.7).int()
            metric.update(preds, score.int())
            f1 = metric.compute()
            tp, fp, tn, fn = metric._get_final_stats()
            self.tta_logs[video_name].append((tp.item(), fp.item(), fn.item()))
        except Exception as e:
            print(e)
            loss = 0
        return loss

    def training_step(self, batch, batch_idx):
        loss = self.run_batch(batch, batch_idx, self.train_f1, training=True)
        self.log("train_loss", loss)
        return loss

    def training_epoch_end(self, training_step_outputs):
        self.log("train_f1", self.train_f1.compute())
        self.train_f1.reset()

    def validation_step(self, batch, batch_idx):
        loss = self.run_batch(batch, batch_idx, self.val_f1)
        self.log("val_loss", loss)
        return loss

    def validation_epoch_end(self, validation_step_outputs):
        self.log("val_f1", self.val_f1.compute())
        self.val_f1.reset()

    def test_step(self, batch, batch_idx):
        loss = self.run_batch(batch, batch_idx, self.test_f1)
        self.log("test_loss", loss)
        return loss

    def test_epoch_end(self, outputs):
        f1 = self.test_f1.compute()
        self.log("test_f1", f1)
        tp, fp, tn, fn = self.test_f1._get_final_stats()
        print(f"\nTest f1: {f1}, TP: {tp}, FP: {fp}, TN: {tn}, fn: {fn}")
        self.test_f1.reset()

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-4)
        return optimizer
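
run_batch above unpacks each batch as (video_name, image_features, labels) and squeezes out the batch dimension, so the datamodule is expected to deliver one whole video per batch. The sketch below fabricates a single such batch with random tensors purely to illustrate the shapes; the shapes and dummy values are assumptions inferred from the squeeze/sum calls, not something documented by the source, and the real data comes from the repo's SummaryDataset.

import torch

# One fabricated batch (batch size 1), shapes inferred from run_batch:
video_name = ["sample_video"]                     # list of one video name
image_features = torch.randn(1, 8, 3, 224, 224)   # (batch, n_frames, C, H, W) ViT-ready frames
labels = torch.randint(0, 2, (1, 3, 8)).double()  # (batch, n_label_rows, n_frames) binary labels

model = SummaryModel()
loss = model.run_batch((video_name, image_features, labels), 0, model.train_f1)
print(loss)

With the F1 metrics enabled in __init__, batches of this form can be wrapped in a DataLoader with batch_size=1 and passed to a pytorch_lightning Trainer for actual training.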

 

 

Source

빵형의 개발도상국 - YouTube
