e2evideo.feature_extractor

This module contains the code for extracting features from images using DINOv2 and img2vec.
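A minimal usage sketch (assuming the package is installed; ./frames and ./work_dir are placeholder paths):

from e2evideo.feature_extractor import FeatureExtractor, FeatureExtractorConfig

config = FeatureExtractorConfig(input_path="./frames", output_path="./work_dir")
extractor = FeatureExtractor(config)

# DINOv2 features via fastdup: image filenames plus a (num_images, 384) array
filenames, features = extractor.extract_dinov2_features()

# ResNet-18 (img2vec) features: a DataFrame that is also written to ./work_dir/vec_df.csv
vec_df = extractor.extract_img_vector()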

"""
This module contains the code for extracting features from images using
DINOv2 and img2vec.
"""
# pylint: disable=no-member
# pylint: disable=protected-access
import os
import argparse
from dataclasses import dataclass
import logging
import pandas as pd
import torch
from torchvision import models
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
import fastdup
from sklearn.manifold import TSNE
import plotly.express as px


logging.basicConfig(
    level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
logger.debug("Loaded dependencies successfully!")


@dataclass
class FeatureExtractorConfig:
    """Class to hold the configuration of the feature extractor."""

    input_path: str
    output_path: str


class FeatureExtractor:
    """Extract image features with a pretrained ResNet-18 (img2vec) or DINOv2 via fastdup."""

    def __init__(self, config: FeatureExtractorConfig):
        self.config = config
        self.scaler = transforms.Resize((224, 224))
        self.normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )
        self.to_tensor = transforms.ToTensor()

    def get_vector(self, image_name, layer, model):
        """Return the 512-dimensional embedding of a single image, captured
        from the given layer with a forward hook."""
        img = Image.open(image_name)
        t_img = self.normalize(self.to_tensor(self.scaler(img))).unsqueeze(0)
        my_embedding = torch.zeros(512)

        def copy_data(m__, i__, output):
            my_embedding.copy_(output.data.squeeze())

        temp_out = layer.register_forward_hook(copy_data)
        model(t_img)
        temp_out.remove()
        return my_embedding

    def extract_img_vector(self):
        """Extract a ResNet-18 feature vector for every RGB image in the
        input directory and save the result to vec_df.csv."""
        model = models.resnet18(pretrained=True)
        layer = model._modules.get("avgpool")
        model.eval()
        logger.debug("%s loaded successfully!", model)

        vec_list = []

        for img_name in tqdm(os.listdir(self.config.input_path)):
            img_path = os.path.join(self.config.input_path, img_name)
            # skip images that do not have exactly three channels
            img = Image.open(img_path)
            if len(img.split()) != 3:
                continue
            vec = self.get_vector(img_path, layer, model)
            vec = vec.numpy().tolist()
            vec_dict = {"Image_Name": img_name}
            for i, vector_ in enumerate(vec):
                vec_dict[f"Vector_{i}"] = vector_
            vec_list.append(vec_dict)

        vec_df = pd.DataFrame(vec_list)
        vec_df.to_csv(f"{self.config.output_path}/vec_df.csv", index=False)
        return vec_df

    def extract_dinov2_features(self):
        """Extract 384-dimensional DINOv2 features for all images in the
        input directory using fastdup."""
        fd_model = fastdup.create(
            input_dir=self.config.input_path, work_dir=self.config.output_path
        )
        fd_model.run(model_path="dinov2s", cc_threshold=0.8)

        filenames, feature_vec = fastdup.load_binary_feature(
            f"{self.config.output_path}/atrain_features.dat", d=384
        )
        logger.info("Embedding dimensions %s", feature_vec.shape)
        return filenames, feature_vec


def plot_tsne_3d(feature_vec, connected_components_df, filenames, output_path):
    """
    Plot a 3D t-SNE projection of the feature vectors, colored by connected
    component, and save it as an HTML file.
    """
    tsne = TSNE(n_components=3, verbose=1, perplexity=40, n_iter=300)
    tsne_result = tsne.fit_transform(feature_vec)
    component_id = connected_components_df["component_id"].to_numpy()
    results = pd.DataFrame(
        {
            "tsne_1": tsne_result[:, 0],
            "tsne_2": tsne_result[:, 1],
            "tsne_3": tsne_result[:, 2],
            "component": component_id,
            "filename": filenames,
        }
    )

    fig = px.scatter_3d(
        results,
        x="tsne_1",
        y="tsne_2",
        z="tsne_3",
        color="component",
        opacity=0.5,
        hover_data=["component", "filename"],
    )

    fig.write_html(output_path)


def main():
    """Parse command-line arguments and run the selected feature extractor."""
    parser_ = argparse.ArgumentParser()
    parser_.add_argument("--input_path", type=str)
    parser_.add_argument("--output_path", type=str, default="./work_dir")
    parser_.add_argument("--feature_extractor", type=str, default="dinov2")
    args = parser_.parse_args()

    feature_config = FeatureExtractorConfig(args.input_path, args.output_path)
    fe = FeatureExtractor(feature_config)

    if args.feature_extractor == "dinov2":
        filenames_, feature_vec_ = fe.extract_dinov2_features()

        connected_components_df_ = pd.read_csv(
            os.path.join(args.output_path, "connected_components.csv")
        )
        plot_tsne_3d(
            feature_vec_,
            connected_components_df_,
            filenames_,
            f"{args.output_path}/embeddings_dinov2.html",
        )
    elif args.feature_extractor == "img2vec":
        fe.extract_img_vector()


if __name__ == "__main__":
    main()
logger = <Logger e2evideo.feature_extractor (DEBUG)>
@dataclass
class FeatureExtractorConfig:

Class to hold the configuration of the feature extractor.

FeatureExtractorConfig(input_path: str, output_path: str)
input_path: str
output_path: str
class FeatureExtractor:

Extract image features with a pretrained ResNet-18 (img2vec) or DINOv2 via fastdup.

FeatureExtractor(config: FeatureExtractorConfig)
config
scaler
normalize
to_tensor
def get_vector(self, image_name, layer, model):

Return the 512-dimensional embedding of a single image, captured from the given layer with a forward hook.
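A sketch of how the module itself calls this method, hooking ResNet-18's avgpool layer (placeholder paths):

from torchvision import models

from e2evideo.feature_extractor import FeatureExtractor, FeatureExtractorConfig

extractor = FeatureExtractor(FeatureExtractorConfig("./frames", "./work_dir"))

model = models.resnet18(pretrained=True)
model.eval()
layer = model._modules.get("avgpool")  # 512-dimensional pooled features

embedding = extractor.get_vector("./frames/frame_0001.jpg", layer, model)
print(embedding.shape)  # torch.Size([512])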

def extract_img_vector(self):

Extract a ResNet-18 feature vector for every RGB image in the input directory and save the result to vec_df.csv.
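Usage sketch (placeholder paths; the output directory must already exist for the CSV write):

from e2evideo.feature_extractor import FeatureExtractor, FeatureExtractorConfig

extractor = FeatureExtractor(FeatureExtractorConfig("./frames", "./work_dir"))
vec_df = extractor.extract_img_vector()

# One row per RGB image: an Image_Name column plus Vector_0 ... Vector_511
print(vec_df.shape)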

def extract_dinov2_features(self):

Extract 384-dimensional DINOv2 features for all images in the input directory using fastdup.
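Usage sketch (placeholder paths); fastdup writes its artefacts, including atrain_features.dat, into the work directory:

from e2evideo.feature_extractor import FeatureExtractor, FeatureExtractorConfig

extractor = FeatureExtractor(FeatureExtractorConfig("./frames", "./work_dir"))
filenames, feature_vec = extractor.extract_dinov2_features()

print(feature_vec.shape)  # (num_images, 384)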

def plot_tsne_3d(feature_vec, connected_components_df, filenames, output_path):

Plot a 3D t-SNE projection of the feature vectors, colored by connected component, and save it as an HTML file.
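Usage sketch mirroring main(): the connected-components table read here is the one fastdup leaves in the work directory (placeholder paths):

import os
import pandas as pd

from e2evideo.feature_extractor import (
    FeatureExtractor,
    FeatureExtractorConfig,
    plot_tsne_3d,
)

work_dir = "./work_dir"
extractor = FeatureExtractor(FeatureExtractorConfig("./frames", work_dir))
filenames, feature_vec = extractor.extract_dinov2_features()

components = pd.read_csv(os.path.join(work_dir, "connected_components.csv"))
plot_tsne_3d(feature_vec, components, filenames, f"{work_dir}/embeddings_dinov2.html")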

def main():
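Parse the command-line arguments --input_path, --output_path (default ./work_dir) and --feature_extractor (dinov2 by default, or img2vec), then run the selected extractor; for dinov2 the embeddings are also projected with t-SNE and written as an interactive HTML plot to the output directory.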