A ViT model that summarizes a 22-minute video into 1 minute
Data
Data domain - Find AI Data - AI-Hub (aihub.or.kr)
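The post does not show the datamodule itself, but from run_batch in the model code below, each batch is a (video_name, image_features, labels) triple with a leading batch dimension of 1. Here is a minimal dummy batch matching that layout; the frame count and annotator count are illustrative assumptions, not taken from the actual AI-Hub data:

import torch

# Hypothetical batch matching the layout consumed by SummaryModel.run_batch below.
# All shapes are assumptions for illustration, not the real dataset's.
video_name = ["news_0001"]                          # the DataLoader wraps the string in a list
image_features = torch.randn(1, 120, 3, 224, 224)   # 1 video x 120 sampled frames x CHW
labels = torch.randint(0, 2, (1, 3, 120)).float()   # e.g. 3 annotators x 120 frames;
                                                    # summed over annotators, then clamped to {0, 1}
batch = (video_name, image_features, labels)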
Training model
run
import torch
from training.summary.datamodule import SummaryDataset  # unused at inference time; needed only for training
from transformers import ViTImageProcessor
from tqdm import tqdm
import matplotlib.pyplot as plt
import cv2
import seaborn as sns
import numpy as np
from moviepy.editor import VideoFileClip, concatenate_videoclips
from v2021 import SummaryModel

# The processor resizes and normalizes frames on the CPU; ViT expects 224x224 RGB inputs.
preprocessor = ViTImageProcessor.from_pretrained(
    "google/vit-base-patch16-224", size=224
)

SAMPLE_EVERY_SEC = 2

video_path = 'videos/news2.mp4'
cap = cv2.VideoCapture(video_path)
n_frames = cap.get(cv2.CAP_PROP_FRAME_COUNT)
fps = cap.get(cv2.CAP_PROP_FPS)
video_len = n_frames / fps
print(f'Video length: {video_len:.2f} seconds')

# Sample one frame every SAMPLE_EVERY_SEC seconds.
frames = []
last_collected = -1
while cap.isOpened():
    ret, frame = cap.read()
    if not ret:
        break
    timestamp = cap.get(cv2.CAP_PROP_POS_MSEC)
    second = timestamp // 1000
    if second % SAMPLE_EVERY_SEC == 0 and second != last_collected:
        last_collected = second
        # OpenCV decodes frames as BGR; convert to RGB for the ViT preprocessor
        # (assuming the model was trained on RGB frames).
        frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
cap.release()

features = preprocessor(images=frames, return_tensors="pt")["pixel_values"]
print(features.shape)  # (num_sampled_frames, 3, 224, 224)

model = SummaryModel.load_from_checkpoint('summary.ckpt')
model.to('cuda')
model.eval()

features = features.to('cuda')

# Score each sampled frame; sigmoid maps the raw logit to a keep-probability.
y_pred = []
with torch.no_grad():
    for frame in tqdm(features):
        y_p = model(frame.unsqueeze(0))
        y_p = torch.sigmoid(y_p)
        y_pred.append(y_p.cpu().numpy().squeeze())

y_pred = np.array(y_pred)

# Inspect the score distribution to pick a threshold.
sns.displot(y_pred)
plt.show()

THRESHOLD = 0.67

# Report which timestamps pass the threshold and the total summary length.
total_secs = 0
for i, y_p in enumerate(y_pred):
    if y_p >= THRESHOLD:
        print(i * SAMPLE_EVERY_SEC)
        total_secs += SAMPLE_EVERY_SEC
print(f'Summary length: {total_secs} seconds')

# Stitch the selected 2-second segments back into one clip.
clip = VideoFileClip(video_path)
subclips = []
for i, y_p in enumerate(y_pred):
    sec = i * SAMPLE_EVERY_SEC
    if y_p >= THRESHOLD:
        subclip = clip.subclip(sec, sec + SAMPLE_EVERY_SEC)
        subclips.append(subclip)
result = concatenate_videoclips(subclips)
result.write_videofile("videos/result.mp4")
result.ipython_display(width=640, maxduration=240)  # displays only inside a notebook
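The 0.67 cutoff is read off the score distribution plotted above. A quick way to see how summary length trades off against the cutoff before committing to one; this helper is a sketch, not part of the original script:

def summary_length(y_pred, threshold, sample_every_sec=SAMPLE_EVERY_SEC):
    """Seconds of footage kept for a given score cutoff."""
    return int((np.asarray(y_pred) >= threshold).sum()) * sample_every_sec

for t in (0.5, 0.6, 0.67, 0.7, 0.8):
    print(f'threshold {t}: {summary_length(y_pred, t)} s kept')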
model
from collections import defaultdict

import torch
import torch.nn as nn
from pytorch_lightning import LightningModule
from torchmetrics import F1  # renamed to F1Score in newer torchmetrics releases
from transformers import ViTModel


class SummaryModel(LightningModule):
    def __init__(self, hidden_dim=768, individual_logs=None):
        super().__init__()
        self.vit = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
        # A single linear head scores each frame; sigmoid is applied outside the
        # model because BCEWithLogitsLoss expects raw logits.
        self.scorer = nn.Linear(hidden_dim, 1)
        self.loss = nn.BCEWithLogitsLoss()
        self.train_f1 = F1()
        self.val_f1 = F1()
        self.test_f1 = F1()
        self.individual_logs = individual_logs
        self.tta_logs = defaultdict(list)

    def forward(self, x):
        x = self.vit(x).pooler_output
        x = self.scorer(x)
        return x

    def run_batch(self, batch, batch_idx, metric, training=False):
        video_name, image_features, labels = batch
        video_name = video_name[0]
        image_features = image_features.squeeze(0)
        labels = labels.squeeze(0)
        # Aggregate per-annotator labels into a single 0/1 score per frame.
        score = torch.sum(labels, dim=0)
        score = torch.min(
            score,
            torch.ones(score.shape[0]).to(score.device),
        )
        out = self(image_features).squeeze(1)
        try:
            loss = self.loss(out.double(), score)
            preds = (torch.sigmoid(out) > 0.7).int()
            metric.update(preds, score.int())
            f1 = metric.compute()
            # _get_final_stats() is a private API of older torchmetrics versions.
            tp, fp, tn, fn = metric._get_final_stats()
            self.tta_logs[video_name].append((tp.item(), fp.item(), fn.item()))
        except Exception as e:
            print(e)
            loss = 0  # skip malformed batches rather than aborting the run
        return loss

    def training_step(self, batch, batch_idx):
        loss = self.run_batch(batch, batch_idx, self.train_f1, training=True)
        self.log("train_loss", loss)
        return loss

    def training_epoch_end(self, training_step_outputs):
        self.log("train_f1", self.train_f1.compute())
        self.train_f1.reset()

    def validation_step(self, batch, batch_idx):
        loss = self.run_batch(batch, batch_idx, self.val_f1)
        self.log("val_loss", loss)
        return loss

    def validation_epoch_end(self, validation_step_outputs):
        self.log("val_f1", self.val_f1.compute())
        self.val_f1.reset()

    def test_step(self, batch, batch_idx):
        loss = self.run_batch(batch, batch_idx, self.test_f1)
        self.log("test_loss", loss)
        return loss

    def test_epoch_end(self, outputs):
        f1 = self.test_f1.compute()
        self.log("test_f1", f1)
        tp, fp, tn, fn = self.test_f1._get_final_stats()
        print(f"\nTest f1: {f1}, TP: {tp}, FP: {fp}, TN: {tn}, FN: {fn}")
        self.test_f1.reset()

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-4)
        return optimizer
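The post stops at the model definition. Below is a minimal training sketch, assuming SummaryDataset (imported in the run script above) yields the (video_name, image_features, labels) batches that run_batch expects, with one video per batch; its constructor arguments are not shown in the post, so the no-argument call is a placeholder:

import pytorch_lightning as pl
from torch.utils.data import DataLoader
from training.summary.datamodule import SummaryDataset  # dataset class referenced in the run script

# Hypothetical wiring; SummaryDataset's real constructor arguments are unknown here.
train_ds = SummaryDataset()  # assumption: defaults point at the prepared AI-Hub data
train_loader = DataLoader(train_ds, batch_size=1, shuffle=True)  # one video per batch

model = SummaryModel()
# gpus= is the PyTorch Lightning 1.x argument, matching the *_epoch_end hooks used above.
trainer = pl.Trainer(max_epochs=10, gpus=1)
trainer.fit(model, train_loader)
# The run script loads the resulting checkpoint, saved by the Trainer, as summary.ckpt.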
Source