
1. A ViT model that summarizes a 22-minute video into a 1-minute clip

 

2. Data

Data - Find AI Data - AI-Hub (aihub.or.kr)

 

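The run script (section 5) imports SummaryDataset from training.summary.datamodule, and the training code (section 6) unpacks each batch as (video_name, image_features, labels). The real datamodule is not included in this post, so the following is only a hypothetical sketch of that shape, inferred from run_batch(); the class and field names are illustrative.

<python />
# Hypothetical sketch only - the real SummaryDataset (training/summary/datamodule.py)
# is not included in this post. Shapes are inferred from run_batch() in section 6.
import torch
from torch.utils.data import Dataset


class ToySummaryDataset(Dataset):
    def __init__(self, items):
        # items: list of (video_name, frames [T, 3, 224, 224], labels [A, T]),
        # where each of the A label rows marks a frame as in-summary (1) or not (0).
        self.items = items

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        name, frames, labels = self.items[idx]
        return name, frames.float(), labels.float()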

3. Trained model

영상요약모델 (video summarization model) - Google Drive

 

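The run script below loads this checkpoint with SummaryModel.load_from_checkpoint('summary.ckpt') and moves it to the GPU. If you only have a CPU, map_location (a standard PyTorch Lightning load_from_checkpoint argument) lets you load it anyway; the file name summary.ckpt is the one used in section 5 and may differ from the name of the file you download.

<python />
# Minimal sketch: loading the downloaded checkpoint on a CPU-only machine.
# 'summary.ckpt' is the file name used in the run script below.
from v2021 import SummaryModel

model = SummaryModel.load_from_checkpoint("summary.ckpt", map_location="cpu")
model.eval()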

4.  

5. run

<python />
import torch
from training.summary.datamodule import SummaryDataset
from transformers import ViTImageProcessor
from tqdm import tqdm
import matplotlib.pyplot as plt
import cv2
import seaborn as sns
import numpy as np
from moviepy.editor import VideoFileClip, concatenate_videoclips
from v2021 import SummaryModel

# ViT preprocessor: resizes/normalizes frames to 224x224 pixel_values
preprocessor = ViTImageProcessor.from_pretrained(
    "google/vit-base-patch16-224", size=224, device='cuda'
)

SAMPLE_EVERY_SEC = 2  # sample one frame every 2 seconds

video_path = 'videos/news2.mp4'
cap = cv2.VideoCapture(video_path)
n_frames = cap.get(cv2.CAP_PROP_FRAME_COUNT)
fps = cap.get(cv2.CAP_PROP_FPS)
video_len = n_frames / fps
print(f'Video length {video_len:.2f} seconds!')

# Collect one frame per SAMPLE_EVERY_SEC seconds
frames = []
last_collected = -1
while cap.isOpened():
    ret, frame = cap.read()
    if not ret:
        break
    timestamp = cap.get(cv2.CAP_PROP_POS_MSEC)
    second = timestamp // 1000
    if second % SAMPLE_EVERY_SEC == 0 and second != last_collected:
        last_collected = second
        frames.append(frame)

features = preprocessor(images=frames, return_tensors="pt")["pixel_values"]
print(features.shape)

# Load the trained checkpoint and score each sampled frame
model = SummaryModel.load_from_checkpoint('summary.ckpt')
model.to('cuda')
model.eval()

features = features.to('cuda')

y_pred = []
for frame in tqdm(features):
    y_p = model(frame.unsqueeze(0))
    y_p = torch.sigmoid(y_p)  # logit -> 0..1 importance score
    y_pred.append(y_p.cpu().detach().numpy().squeeze())

y_pred = np.array(y_pred)

sns.displot(y_pred)  # score distribution, useful for picking THRESHOLD

THRESHOLD = 0.67

# Total length of the summary in seconds
total_secs = 0
for i, y_p in enumerate(y_pred):
    if y_p >= THRESHOLD:
        print(i * SAMPLE_EVERY_SEC)
        total_secs += SAMPLE_EVERY_SEC
total_secs  # (notebook cell output)

# Cut out the high-scoring 2-second segments and concatenate them
clip = VideoFileClip(video_path)

subclips = []
for i, y_p in enumerate(y_pred):
    sec = i * SAMPLE_EVERY_SEC
    if y_p >= THRESHOLD:
        subclip = clip.subclip(sec, sec + SAMPLE_EVERY_SEC)
        subclips.append(subclip)

result = concatenate_videoclips(subclips)
result.write_videofile("videos/result.mp4")
result.ipython_display(width=640, maxduration=240)
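THRESHOLD = 0.67 above is hand-picked from the score histogram. If you would rather target a fixed summary length (for example, about 1 minute out of the 22-minute video), a helper like the hypothetical one below derives a cutoff from the scores; pick_threshold and target_secs are illustrative names, not part of the original script.

<python />
# Hypothetical helper (not in the original script): derive THRESHOLD from a
# target summary length instead of hand-picking it from the histogram.
import numpy as np


def pick_threshold(scores, target_secs=60, sec_per_segment=2):  # 2 = SAMPLE_EVERY_SEC
    n_keep = max(1, int(target_secs // sec_per_segment))  # segments that fit the budget
    if n_keep >= len(scores):
        return 0.0  # short video: keep everything
    top = np.sort(np.asarray(scores))[::-1]
    # Using the n_keep-th best score as the cutoff keeps roughly n_keep segments.
    return float(top[n_keep - 1])


# e.g. THRESHOLD = pick_threshold(y_pred, target_secs=60)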

 

6. model

<python />
import argparse
from collections import defaultdict

import numpy as np
import torch
import torch.nn as nn
from pytorch_lightning.core.lightning import LightningModule
from torch import optim

# from torchmetrics import F1
from transformers import ViTModel


class SummaryModel(LightningModule):
    def __init__(self, hidden_dim=768, individual_logs=None):
        super().__init__()
        # ViT backbone + a single linear head that scores each frame
        self.vit = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
        self.scorer = nn.Linear(hidden_dim, 1)
        # self.sigmoid = nn.Sigmoid()
        self.loss = nn.BCEWithLogitsLoss()
        # The F1 metrics below are commented out in this snippet; they must be
        # defined (e.g. via torchmetrics) before the training/validation/test steps can run.
        # self.train_f1 = F1()
        # self.val_f1 = F1()
        # self.test_f1 = F1()
        self.individual_logs = individual_logs
        self.tta_logs = defaultdict(list)

    def forward(self, x):
        x = self.vit(x).pooler_output  # (B, hidden_dim) pooled [CLS] output
        x = self.scorer(x)             # (B, 1) importance logit per frame
        # x = self.sigmoid(x)
        return x

    def run_batch(self, batch, batch_idx, metric, training=False):
        video_name, image_features, labels = batch
        video_name = video_name[0]
        image_features = image_features.squeeze(0)
        labels = labels.squeeze(0)
        # Score - aggregated labels (summed over label rows, clamped to 1).
        score = torch.sum(labels, dim=0)
        score = torch.min(
            score,
            torch.ones(score.shape[0]).to(score.device),
        )
        out = self(image_features).squeeze(1)
        try:
            loss = self.loss(out.double(), score)
            preds = (torch.sigmoid(out) > 0.7).int()
            metric.update(preds, score.int())
            f1 = metric.compute()
            # _get_final_stats() is a private API of the torchmetrics version
            # used by the original repository.
            tp, fp, tn, fn = metric._get_final_stats()
            self.tta_logs[video_name].append((tp.item(), fp.item(), fn.item()))
        except Exception as e:
            print(e)
            loss = 0
        return loss

    def training_step(self, batch, batch_idx):
        loss = self.run_batch(batch, batch_idx, self.train_f1, training=True)
        self.log("train_loss", loss)
        return loss

    def training_epoch_end(self, training_step_outputs):
        self.log("train_f1", self.train_f1.compute())
        self.train_f1.reset()

    def validation_step(self, batch, batch_idx):
        loss = self.run_batch(batch, batch_idx, self.val_f1)
        self.log("val_loss", loss)
        return loss

    def validation_epoch_end(self, validation_step_outputs):
        self.log("val_f1", self.val_f1.compute())
        self.val_f1.reset()

    def test_step(self, batch, batch_idx):
        loss = self.run_batch(batch, batch_idx, self.test_f1)
        self.log("test_loss", loss)
        return loss

    def test_epoch_end(self, outputs):
        f1 = self.test_f1.compute()
        self.log("test_f1", f1)
        tp, fp, tn, fn = self.test_f1._get_final_stats()
        print(f"\nTest f1: {f1}, TP: {tp}, FP: {fp}, TN: {tn}, fn: {fn}")
        self.test_f1.reset()

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-4)
        return optimizer
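To make the scoring path concrete, here is a small sanity check (not from the original post): each 224x224 RGB frame maps to a single logit, and torch.sigmoid turns it into the 0-1 importance score that the run script in section 5 thresholds. It assumes SummaryModel from the block above is in scope and uses randomly initialised head weights rather than summary.ckpt.

<python />
# Hedged sanity check (not in the original post): one logit per frame,
# sigmoid gives the 0-1 importance score used in section 5.
import torch

model = SummaryModel()  # random head; in practice the weights come from summary.ckpt
model.eval()
with torch.no_grad():
    frames = torch.randn(4, 3, 224, 224)       # 4 dummy RGB frames
    logits = model(frames)                     # shape: (4, 1)
    scores = torch.sigmoid(logits).squeeze(1)  # shape: (4,), values in [0, 1]
print(logits.shape, scores)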

 

 

Source

빵형의 개발도상국 - YouTube
