e2evideo.feature_extractor
This module contains the code for extracting features from images using DINOv2 and img2vec.
"""
This module contains the code for extracting features from images using
DINOv2 and img2vec.
"""
# pylint: disable=no-member
# pylint: disable=protected-access
import os
import argparse
from dataclasses import dataclass
import logging
import pandas as pd
import torch
from torchvision import models
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
import fastdup
from sklearn.manifold import TSNE
import plotly.express as px


logging.basicConfig(
    level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
logger.debug("Loaded dependencies successfully!")


@dataclass
class FeatureExtractorConfig:
    """Class to hold the configuration of the feature extractor."""

    input_path: str
    output_path: str


class FeatureExtractor:
    """feature extractor class"""

    def __init__(self, config: FeatureExtractorConfig):
        self.config = config
        self.scaler = transforms.Resize((224, 224))
        self.normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )
        self.to_tensor = transforms.ToTensor()

    def get_vector(self, image_name, layer, model):
        """get vector from image"""
        img = Image.open(image_name)
        t_img = self.normalize(self.to_tensor(self.scaler(img))).unsqueeze(0)
        my_embedding = torch.zeros(512)

        def copy_data(m__, i__, output):
            # Copy the hooked layer's output into the pre-allocated embedding tensor.
            my_embedding.copy_(output.data.squeeze())

        temp_out = layer.register_forward_hook(copy_data)
        model(t_img)
        temp_out.remove()
        return my_embedding

    def extract_img_vector(self):
        """extract image vector using resnet18"""
        model = models.resnet18(pretrained=True)
        layer = model._modules.get("avgpool")
        model.eval()
        logger.debug("%s loaded successfully!", model)

        vec_list = []

        for img_name in tqdm(os.listdir(self.config.input_path)):
            img_path = os.path.join(self.config.input_path, img_name)
            # Skip images that do not have exactly 3 channels.
            img = Image.open(img_path)
            if len(img.split()) != 3:
                continue
            vec = self.get_vector(img_path, layer, model)
            vec = vec.numpy().tolist()
            vec_dict = {"Image_Name": img_name}
            for i, vector_ in enumerate(vec):
                vec_dict[f"Vector_{i}"] = vector_
            vec_list.append(vec_dict)

        vec_df = pd.DataFrame(vec_list)
        vec_df.to_csv(f"{self.config.output_path}/vec_df.csv", index=False)
        return vec_df

    def extract_dinov2_features(self):
        """extract features using DINOv2"""
        fd_model = fastdup.create(
            input_dir=self.config.input_path, work_dir=self.config.output_path
        )
        fd_model.run(model_path="dinov2s", cc_threshold=0.8)

        filenames, feature_vec = fastdup.load_binary_feature(
            f"{self.config.output_path}/atrain_features.dat", d=384
        )
        logger.info("Embedding dimensions %s", feature_vec.shape)
        return filenames, feature_vec


def plot_tsne_3d(feature_vec, connected_components_df, filenames, output_path):
    """
    Function to plot a 3D t-SNE scatter plot and save it as an HTML file.
    """
    tsne = TSNE(n_components=3, verbose=1, perplexity=40, n_iter=300)
    tsne_result = tsne.fit_transform(feature_vec)
    component_id = connected_components_df["component_id"].to_numpy()
    results = pd.DataFrame(
        {
            "tsne_1": tsne_result[:, 0],
            "tsne_2": tsne_result[:, 1],
            "tsne_3": tsne_result[:, 2],
            "component": component_id,
            "filename": filenames,
        }
    )

    fig = px.scatter_3d(
        results,
        x="tsne_1",
        y="tsne_2",
        z="tsne_3",
        color="component",
        opacity=0.5,
        hover_data=["component", "filename"],
    )

    fig.write_html(output_path)


def main():
    parser_ = argparse.ArgumentParser()
    parser_.add_argument("--input_path", type=str)
    parser_.add_argument("--output_path", type=str, default="./work_dir")
    parser_.add_argument("--feature_extractor", type=str, default="dinov2")
    args = parser_.parse_args()

    feature_config = FeatureExtractorConfig(args.input_path, args.output_path)
    fe = FeatureExtractor(feature_config)

    if args.feature_extractor == "dinov2":
        filenames_, feature_vec_ = fe.extract_dinov2_features()

        connected_components_df_ = pd.read_csv(
            os.path.join(args.output_path, "connected_components.csv")
        )
        plot_tsne_3d(
            feature_vec_,
            connected_components_df_,
            filenames_,
            f"{args.output_path}/embeddings_dinov2.html",
        )
    elif args.feature_extractor == "img2vec":
        fe.extract_img_vector()


if __name__ == "__main__":
    main()
logger = <Logger e2evideo.feature_extractor (DEBUG)>
@dataclass
class FeatureExtractorConfig:
Class to hold the configuration of the feature extractor.
class FeatureExtractor:
feature extractor class
FeatureExtractor(config: FeatureExtractorConfig)
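main() wires these classes together behind a command-line interface; used as a library, a minimal setup looks like the sketch below, where "./frames" and "./work_dir" are hypothetical placeholder paths.

from e2evideo.feature_extractor import FeatureExtractor, FeatureExtractorConfig

# Hypothetical paths: a directory of images/frames to embed and a working
# directory that the extractor writes its outputs to.
config = FeatureExtractorConfig(input_path="./frames", output_path="./work_dir")
extractor = FeatureExtractor(config)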
def get_vector(self, image_name, layer, model):
get vector from image
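get_vector registers a temporary forward hook on the given layer, runs the model once on the resized and normalised image, and copies the layer's output into a 512-element tensor. A sketch of calling it directly, continuing the setup above and assuming a ResNet-18 hooked at its avgpool layer (the image path is a placeholder):

from torchvision import models

resnet = models.resnet18(pretrained=True)
resnet.eval()
avgpool = resnet._modules.get("avgpool")  # the layer hooked in extract_img_vector
# "frame_0001.jpg" is a hypothetical file inside config.input_path.
embedding = extractor.get_vector("./frames/frame_0001.jpg", avgpool, resnet)
print(embedding.shape)  # torch.Size([512])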
def extract_img_vector(self):
extract image vector using resnet18
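A sketch of the batch variant, which walks config.input_path, skips images that are not 3-channel, and writes one row per image to vec_df.csv in the output directory (paths as in the setup above):

vec_df = extractor.extract_img_vector()
# One row per RGB image: an Image_Name column plus Vector_0 .. Vector_511.
print(vec_df.shape)
# The same table is also written to ./work_dir/vec_df.csv.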
def extract_dinov2_features(self):
extract features using DINOv2
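A sketch of the DINOv2 path, which delegates to fastdup and then loads the 384-dimensional embeddings that fastdup writes to atrain_features.dat in the work directory (continuing the setup above):

filenames, feature_vec = extractor.extract_dinov2_features()
# N image paths and an (N, 384) embedding matrix.
print(len(filenames), feature_vec.shape)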
def plot_tsne_3d(feature_vec, connected_components_df, filenames, output_path):
Function to plot a 3D t-SNE scatter plot and save it as an HTML file.
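A sketch combining the pieces above: the fastdup run also writes connected_components.csv to the work directory, and its component_id column is used to colour the 3-D t-SNE scatter plot. filenames and feature_vec come from extract_dinov2_features; the output file name is a placeholder.

import os
import pandas as pd

components = pd.read_csv(os.path.join(config.output_path, "connected_components.csv"))
plot_tsne_3d(feature_vec, components, filenames,
             f"{config.output_path}/embeddings_dinov2.html")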
def main():