A ViT model that summarizes a 22-minute video into 1 minute
Data
AI Data Search - AI-Hub (aihub.or.kr)
Trained model
Video summarization model - Google Drive
run
import torch
from training.summary.datamodule import SummaryDataset
from transformers import ViTImageProcessor
from tqdm import tqdm
import matplotlib.pyplot as plt
import cv2
import seaborn as sns
import numpy as np
from moviepy.editor import VideoFileClip, concatenate_videoclips

from v2021 import SummaryModel

# ViT preprocessor: resizes and normalizes frames into 224x224 pixel_values.
preprocessor = ViTImageProcessor.from_pretrained(
    "google/vit-base-patch16-224", size=224, device='cuda'
)

SAMPLE_EVERY_SEC = 2  # sample one frame every 2 seconds

video_path = 'videos/news2.mp4'
cap = cv2.VideoCapture(video_path)
n_frames = cap.get(cv2.CAP_PROP_FRAME_COUNT)
fps = cap.get(cv2.CAP_PROP_FPS)
video_len = n_frames / fps
print(f'Video length {video_len:.2f} seconds!')

# Collect one frame per SAMPLE_EVERY_SEC seconds of video.
frames = []
last_collected = -1
while cap.isOpened():
    ret, frame = cap.read()
    if not ret:
        break
    timestamp = cap.get(cv2.CAP_PROP_POS_MSEC)
    second = timestamp // 1000
    if second % SAMPLE_EVERY_SEC == 0 and second != last_collected:
        last_collected = second
        frames.append(frame)

# Convert sampled frames to ViT inputs: shape (num_frames, 3, 224, 224).
features = preprocessor(images=frames, return_tensors="pt")["pixel_values"]
print(features.shape)

# Load the trained summarizer and score each sampled frame.
model = SummaryModel.load_from_checkpoint('summary.ckpt')
model.to('cuda')
model.eval()

features = features.to('cuda')

y_pred = []
for frame in tqdm(features):
    y_p = model(frame.unsqueeze(0))
    y_p = torch.sigmoid(y_p)
    y_pred.append(y_p.cpu().detach().numpy().squeeze())

y_pred = np.array(y_pred)

# Inspect the score distribution to pick a sensible threshold.
sns.displot(y_pred)

THRESHOLD = 0.67

# Total duration (in seconds) that will be kept in the summary.
total_secs = 0
for i, y_p in enumerate(y_pred):
    if y_p >= THRESHOLD:
        print(i * SAMPLE_EVERY_SEC)
        total_secs += SAMPLE_EVERY_SEC
total_secs  # notebook cell output: length of the summary in seconds

# Cut out the selected 2-second segments and concatenate them.
clip = VideoFileClip(video_path)
subclips = []
for i, y_p in enumerate(y_pred):
    sec = i * SAMPLE_EVERY_SEC
    if y_p >= THRESHOLD:
        subclip = clip.subclip(sec, sec + SAMPLE_EVERY_SEC)
        subclips.append(subclip)

result = concatenate_videoclips(subclips)
result.write_videofile("videos/result.mp4")
result.ipython_display(width=640, maxduration=240)
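The loop above scores the sampled frames one at a time. As a minimal alternative sketch, the same scores can be computed in small batches; the batch size of 16 below is an assumed, adjustable value, not something from the original post:

# Batched scoring sketch: produces the same y_pred as the per-frame loop,
# just with fewer forward passes. BATCH is an illustrative assumption.
BATCH = 16
scores = []
with torch.no_grad():
    for i in range(0, features.shape[0], BATCH):
        out = model(features[i:i + BATCH])           # (batch, 1) logits
        scores.append(torch.sigmoid(out).squeeze(-1).cpu())
y_pred = torch.cat(scores).numpy()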
model
import argparse
from collections import defaultdict

import numpy as np
import torch
import torch.nn as nn
from pytorch_lightning.core.lightning import LightningModule
from torch import optim
from torchmetrics import F1  # requires a torchmetrics version that still ships F1 (pre-0.11)
from transformers import ViTModel


class SummaryModel(LightningModule):
    def __init__(self, hidden_dim=768, individual_logs=None):
        super().__init__()
        # Pretrained ViT backbone plus a linear head that scores each frame.
        self.vit = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
        self.scorer = nn.Linear(hidden_dim, 1)
        # self.sigmoid = nn.Sigmoid()
        self.loss = nn.BCEWithLogitsLoss()
        self.train_f1 = F1()
        self.val_f1 = F1()
        self.test_f1 = F1()
        self.individual_logs = individual_logs
        self.tta_logs = defaultdict(list)

    def forward(self, x):
        # Pooled ViT representation -> one logit per frame.
        x = self.vit(x).pooler_output
        x = self.scorer(x)
        # x = self.sigmoid(x)
        return x

    def run_batch(self, batch, batch_idx, metric, training=False):
        video_name, image_features, labels = batch
        video_name = video_name[0]
        image_features = image_features.squeeze(0)
        labels = labels.squeeze(0)

        # Score - aggregated labels, clamped to at most 1 per frame.
        score = torch.sum(labels, dim=0)
        score = torch.min(
            score,
            torch.ones(
                score.shape[0],
            ).to(score.device),
        )

        out = self(image_features).squeeze(1)
        try:
            loss = self.loss(out.double(), score)
            preds = (torch.sigmoid(out) > 0.7).int()
            metric.update(preds, score.int())
            f1 = metric.compute()
            tp, fp, tn, fn = metric._get_final_stats()
            self.tta_logs[video_name].append((tp.item(), fp.item(), fn.item()))
        except Exception as e:
            print(e)
            loss = 0
        return loss

    def training_step(self, batch, batch_idx):
        loss = self.run_batch(batch, batch_idx, self.train_f1, training=True)
        self.log("train_loss", loss)
        return loss

    def training_epoch_end(self, training_step_outputs):
        self.log("train_f1", self.train_f1.compute())
        self.train_f1.reset()

    def validation_step(self, batch, batch_idx):
        loss = self.run_batch(batch, batch_idx, self.val_f1)
        self.log("val_loss", loss)
        return loss

    def validation_epoch_end(self, validation_step_outputs):
        self.log("val_f1", self.val_f1.compute())
        self.val_f1.reset()

    def test_step(self, batch, batch_idx):
        loss = self.run_batch(batch, batch_idx, self.test_f1)
        self.log("test_loss", loss)
        return loss

    def test_epoch_end(self, outputs):
        f1 = self.test_f1.compute()
        self.log("test_f1", f1)
        tp, fp, tn, fn = self.test_f1._get_final_stats()
        print(f"\nTest f1: {f1}, TP: {tp}, FP: {fp}, TN: {tn}, fn: {fn}")
        self.test_f1.reset()

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-4)
        return optimizer
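The post only shows inference against an existing summary.ckpt. For completeness, a minimal training sketch with PyTorch Lightning could look like the following; the SummaryDataset arguments are hypothetical placeholders (the real signature lives in training/summary/datamodule.py), and batch_size=1 reflects that run_batch treats each batch as one whole video:

import pytorch_lightning as pl
from torch.utils.data import DataLoader
from training.summary.datamodule import SummaryDataset

# "data/train" and "data/val" are placeholder arguments; the actual
# SummaryDataset constructor may differ.
train_loader = DataLoader(SummaryDataset("data/train"), batch_size=1, shuffle=True)
val_loader = DataLoader(SummaryDataset("data/val"), batch_size=1)

model = SummaryModel()
# gpus=1 matches the PL 1.x API implied by the *_epoch_end hooks above.
trainer = pl.Trainer(max_epochs=10, gpus=1)
trainer.fit(model, train_loader, val_loader)
trainer.save_checkpoint("summary.ckpt")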
Source