Thanks to visit codestin.com
Credit goes to github.com

Skip to content

[POC] image based search (GSOC) #28009

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ doc/gallery
doc/modules
doc/plot_types
doc/pyplots/tex_demo.png
doc/image_search/
doc/tutorials
doc/users/explain
lib/dateutil
Expand Down
3 changes: 3 additions & 0 deletions doc/_static/image_search.css
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
/* Keep every element carrying this class out of the rendered page
   (display: none). The class is set on the wrapper div emitted by
   THUMBNAIL_TEMPLATE in doc/sphinxext/image_search.py. */
.sphx-glr-imgsearch-resultelement{
display: none;
}
55 changes: 55 additions & 0 deletions doc/_static/image_search.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/**
 * Cosine similarity between two numeric vectors.
 *
 * @param {number[]} vec1 - First vector.
 * @param {number[]} vec2 - Second vector; indexed with vec1's positions.
 * @returns {number} dot(vec1, vec2) / (|vec1| * |vec2|), or 0 when either
 *     vector has zero magnitude (the quotient would otherwise be NaN).
 */
function cosineSimilarity(vec1, vec2) {
    let dotProduct = 0;
    for (let i = 0; i < vec1.length; i++) {
        dotProduct += vec1[i] * vec2[i];
    }

    const denominator = calcVectorSize(vec1) * calcVectorSize(vec2);
    // Guard the division: a zero-length or all-zero vector has no direction.
    if (denominator === 0) {
        return 0;
    }
    return dotProduct / denominator;
}

/**
 * Euclidean (L2) length of a numeric vector.
 *
 * @param {number[]} vec - Input vector.
 * @returns {number} Square root of the sum of squared components; 0 for [].
 */
function calcVectorSize(vec) {
    return Math.sqrt(vec.reduce((sum, component) => sum + Math.pow(component, 2), 0));
}


// Feature table filled in from data.json: an array of
// [thumbnailId, featureVector] pairs. Declared with `var` rather than `let`
// so that including this script twice on one page does not throw a
// redeclaration error.
var data = [];

(function loadImageSearchData() {
    // Resolve data.json relative to this script instead of the site root, so
    // the lookup keeps working when the docs are served from a sub-path.
    // Assumes data.json sits in the same _static directory as this file.
    const script = typeof document !== 'undefined' ? document.currentScript : null;
    const url = script && script.src
        ? new URL('data.json', script.src).href
        : '/_static/data.json';

    fetch(url)
        .then((response) => {
            if (!response.ok) {
                throw new Error(`HTTP ${response.status} while fetching ${url}`);
            }
            return response.json();
        })
        .then((payload) => { data = payload; })
        .catch((error) => {
            // Leave `data` empty; handle_search() returns early in that case.
            console.error('image search: could not load feature data', error);
        });
})();


/**
 * Fill the search-result container with the thumbnails most similar to the
 * first entry of `data`.
 *
 * Demo behaviour: the query is always data[0]. Every entry is scored with
 * cosineSimilarity against it, entries are ranked from most to least similar,
 * and the markup of the top matches is copied into the result container.
 * Does nothing while `data` is still empty or the container is missing.
 */
function handle_search() {
    // Nothing to rank until data.json has been loaded.
    if (data.length === 0) {
        return;
    }

    const container = document.getElementById('sphx-glr-imgsearchresult-container');
    if (container === null) {
        return;
    }

    // Same count as before (indices 0..5 were rendered).
    const MAX_RESULTS = 6;

    // just find the similar images to the image at the beginning of data
    const queryFeature = data[0][1];

    // Keep (id, score) pairs in an array: using the score as an object key
    // collapses ties and sorts the scores as strings.
    const ranked = data
        .map(([id, feature]) => ({ id, score: cosineSimilarity(queryFeature, feature) }))
        .filter(({ score }) => Number.isFinite(score))
        .sort((a, b) => b.score - a.score)
        .slice(0, MAX_RESULTS);

    // Build the markup once instead of re-parsing innerHTML per result, and
    // skip ids that have no matching element on this page.
    container.innerHTML = ranked
        .map(({ id }) => document.getElementById(id))
        .filter((elem) => elem !== null)
        .map((elem) => elem.innerHTML)
        .join('');
}

// Hook the demo button up once the page has loaded. The button only exists on
// the image-search page, while this script is also listed in html_js_files
// (presumably loaded site-wide), so a missing button must not raise.
window.addEventListener('load', () => {
    const button = document.getElementById('sphx-glr-imgsearchbutton');
    if (button !== null) {
        button.addEventListener('click', handle_search);
    }
});
7 changes: 6 additions & 1 deletion doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ def _parse_skip_subdirs_file():
'sphinxext.mock_gui_toolkits',
'sphinxext.skip_deprecated',
'sphinxext.redirect_from',
'sphinxext.image_search',
'sphinx_copybutton',
'sphinx_design',
'sphinx_tags',
Expand Down Expand Up @@ -245,7 +246,7 @@ def matplotlib_reduced_latex_scraper(block, block_vars, gallery_conf,
return matplotlib_scraper(block, block_vars, gallery_conf, **kwargs)

gallery_dirs = [f'{ed}' for ed in
['gallery', 'tutorials', 'plot_types', 'users/explain']
['gallery', 'tutorials', 'plot_types', 'users/explain', 'image_search']
if f'{ed}/*' not in skip_subdirs]

example_dirs = []
Expand Down Expand Up @@ -477,6 +478,10 @@ def js_tag_with_cache_busting(js):
"mpl.css",
]

html_js_files = [
"image_search.js"
]

html_theme = "mpl_sphinx_theme"

# The name for this set of Sphinx documents. If None, it defaults to
Expand Down
257 changes: 257 additions & 0 deletions doc/sphinxext/image_search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,257 @@
import os
import json
import pandas as pd
import numpy as np
import torch
import timm

from xml.sax.saxutils import escape
from PIL import Image
from tqdm import tqdm
from torchvision import transforms
from torch.autograd import Variable

from sphinx.util import logging as sphinx_logging
from sphinx.errors import ExtensionError
from sphinx_gallery import gen_gallery
from sphinx_gallery.py_source_parser import split_code_and_text_blocks
from sphinx_gallery.gen_rst import extract_intro_and_title
from sphinx_gallery.backreferences import BACKREF_THUMBNAIL_TEMPLATE, _thumbnail_div, THUMBNAIL_PARENT_DIV, THUMBNAIL_PARENT_DIV_CLOSE
from sphinx_gallery.scrapers import _find_image_ext


# Module-level logger obtained through Sphinx's logging helper.
logger = sphinx_logging.getLogger(__name__)



class SearchSetup:
    """Queue images and export their feature vectors as JSON.

    Images are registered with `add_image`. `start_feature_extraction` then
    runs every queued image through a truncated ``timm`` model and writes a
    list of ``[thumbnail_id, feature]`` pairs to a JSON file.
    """

    def __init__(self, model_name='vgg19', pretrained=True):
        """
        Parameters
        ----------
        model_name : str, default: 'vgg19'
            Name passed to ``timm.create_model``.
        pretrained : bool, default: True
            Whether ``timm.create_model`` should load pre-trained weights.
        """
        self.model_name = model_name
        self.pretrained = pretrained
        # Paths handed to the most recent _get_feature() call.
        self.image_data = []
        self.d = None
        # (thumbnail_id, image_path) pairs awaiting feature extraction.
        self.queue = []

        base_model = timm.create_model(self.model_name,
                                       pretrained=self.pretrained)
        # Keep every child module except the last one (presumably the
        # classification head), so the network outputs features.
        self.model = torch.nn.Sequential(*list(base_model.children())[:-1])
        # Inference mode for layers such as dropout/batch-norm. This does NOT
        # disable gradient tracking; _extract() uses torch.no_grad() for that.
        self.model.eval()

        # Built once instead of on every _extract() call. The mean/std values
        # look like the usual ImageNet statistics.
        self._preprocess = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

    def _extract(self, img):
        """Return the L2-normalised, flattened feature vector of *img*.

        Parameters
        ----------
        img : PIL.Image.Image
            Image to embed; it is resized to 224x224 and converted to RGB.

        Returns
        -------
        numpy.ndarray
            1-D feature vector with unit norm. An all-zero feature is returned
            unchanged, because dividing by its norm would produce NaN, which
            is not valid JSON.
        """
        img = img.resize((224, 224)).convert('RGB')
        batch = torch.unsqueeze(self._preprocess(img), dim=0).float()

        # No gradients are needed for pure feature extraction.
        with torch.no_grad():
            feature = self.model(batch)

        feature = feature.numpy().flatten()
        norm = np.linalg.norm(feature)
        if norm == 0:
            return feature
        return feature / norm

    def _get_feature(self, image_data: list):
        """Extract one feature vector per path in *image_data*.

        Parameters
        ----------
        image_data : list
            Image file paths.

        Returns
        -------
        list
            One entry per input path, in order: the feature vector, or
            ``None`` when the image could not be processed (a warning is
            logged in that case).
        """
        self.image_data = image_data
        features = []
        for img_path in tqdm(self.image_data):
            try:
                with Image.open(img_path) as img:
                    feature = self._extract(img=img)
            except Exception as err:
                # Best effort: one unreadable image must not abort the build.
                logger.warning(
                    "image search: could not extract features from %s: %s",
                    img_path, err)
                feature = None
            features.append(feature)
        return features

    def add_image(self, thumbnail_id, image_path):
        """Queue *image_path* for extraction under the key *thumbnail_id*."""
        self.queue.append((thumbnail_id, image_path))

    def start_feature_extraction(self, output_path='./_static/data.json'):
        """Extract features for all queued images and write them as JSON.

        Parameters
        ----------
        output_path : str, default: './_static/data.json'
            File that receives a JSON list of ``[thumbnail_id, feature]``
            pairs. Images whose extraction failed are left out.
        """
        thumbnail_ids = [thumbnail_id for thumbnail_id, _ in self.queue]
        image_paths = [image_path for _, image_path in self.queue]

        features = self._get_feature(image_paths)

        data_json = [
            [thumbnail_id, feature.tolist()]
            for thumbnail_id, feature in zip(thumbnail_ids, features)
            if feature is not None
        ]

        # The context manager guarantees the data is flushed and closed.
        with open(output_path, "w") as f:
            json.dump(data_json, f)





# RST snippet for one search-result thumbnail. The outer div carries
# class "sphx-glr-imgsearch-resultelement" and id "imgsearchref-({ref_name})"
# (note the parentheses). generate_search_page passes that same id string to
# SearchSetup.add_image, so the two must stay in sync.
# NOTE(review): RST directives are indentation-sensitive; confirm the
# indentation of the directive bodies below against the sphinx-gallery
# template this was presumably derived from.
THUMBNAIL_TEMPLATE = """
.. raw:: html

<div class="sphx-glr-imgsearch-resultelement" id="imgsearchref-({ref_name})">
<div class="sphx-glr-thumbcontainer" tooltip="{snippet}">

.. only:: html

.. image:: /{thumbnail}
:alt:

:ref:`sphx_glr_{ref_name}`

.. raw:: html

<div class="sphx-glr-thumbnail-title">{title}</div>
</div>
</div>

"""

def _thumbnail_div(target_dir, src_dir, fname, snippet, title,
                   is_backref=False, check=True):
    """Build the RST thumbnail entry for one example script.

    This definition shadows the ``_thumbnail_div`` imported from
    ``sphinx_gallery.backreferences`` and, unlike a plain RST string,
    returns the reference name alongside the RST.

    Parameters
    ----------
    target_dir : str
        Directory that contains ``images/thumb/`` for the example.
    src_dir : str
        Directory that the thumbnail and reference paths are made relative to.
    fname : str
        Example file name; its last three characters are dropped to build the
        thumbnail name (callers pass names ending in ``.py``).
    snippet : str
        Tooltip text; XML-escaped before insertion.
    title : str
        Title inserted into the template as is.
    is_backref : bool, default: False
        Use ``BACKREF_THUMBNAIL_TEMPLATE`` instead of ``THUMBNAIL_TEMPLATE``.
    check : bool, default: True
        Raise if the thumbnail file does not exist.

    Returns
    -------
    tuple
        ``(ref_name, rst)``.

    Raises
    ------
    ExtensionError
        If *check* is true and the thumbnail file is missing.
    """
    stem = fname[:-3]
    expected_thumb = os.path.join(
        target_dir, 'images', 'thumb', f'sphx_glr_{stem}_thumb.png')
    thumb, _ = _find_image_ext(expected_thumb)

    if check and not os.path.isfile(thumb):
        # This means we have done something wrong in creating our thumbnail!
        raise ExtensionError(
            'Could not find internal Sphinx-Gallery thumbnail'
            f' file:\n{thumb}')

    # Inside rst files forward slash defines paths
    rel_thumb = os.path.relpath(thumb, src_dir).replace(os.sep, "/")

    rel_target = os.path.relpath(target_dir, src_dir)
    ref_name = os.path.join(rel_target, fname).replace(os.path.sep, '_')

    if is_backref:
        template = BACKREF_THUMBNAIL_TEMPLATE
    else:
        template = THUMBNAIL_TEMPLATE

    rst = template.format(
        snippet=escape(snippet),
        thumbnail=rel_thumb,
        title=title,
        ref_name=ref_name,
    )
    return (ref_name, rst)


def generate_search_page(app):
    """Write ``image_search/index.recommendations`` and the feature data.

    Walks every sphinx-gallery output directory (and its subsections), and for
    each example script whose first generated image
    (``images/sphx_glr_<name>_001.png``) exists:

    * appends a thumbnail entry to the RST written to
      ``<srcdir>/image_search/index.recommendations``;
    * queues the image on a `SearchSetup` instance for feature extraction.

    Parameters
    ----------
    app : sphinx.application.Sphinx
        The running Sphinx application; ``app.config.sphinx_gallery_conf`` and
        ``app.builder.srcdir`` are read.
    """
    gallery_conf = app.config.sphinx_gallery_conf
    srcdir = app.builder.srcdir

    workdirs = gen_gallery._prepare_sphx_glr_dirs(gallery_conf, srcdir)

    heading = "Image Search page"

    image_search_path = os.path.join(srcdir, "image_search")
    os.makedirs(image_search_path, exist_ok=True)

    search_setup = SearchSetup(model_name='vgg19', pretrained=True)

    # Collected across ALL galleries; resetting this per gallery would drop
    # the thumbnails of every gallery but the last one.
    thumbnails = []

    for examples_dir, gallery_dir in workdirs:
        examples_dir_abs_path = os.path.join(srcdir, examples_dir)
        gallery_dir_abs_path = os.path.join(srcdir, gallery_dir)

        # list all paths to subsection index files in this array
        subsecs = gen_gallery.get_subsections(srcdir,
                                              examples_dir_abs_path,
                                              gallery_conf,
                                              check_for_index=True)

        # loop through the gallery root and every subfolder
        for subsection in [gallery_dir_abs_path] + subsecs:
            section_dir = os.path.join(gallery_dir_abs_path, subsection)
            if not os.path.isdir(section_dir):
                # Nothing generated here (yet), e.g. on a clean build.
                continue

            # Sorted so the page and data.json are reproducible between
            # builds; os.listdir() order is arbitrary.
            for file_name in sorted(os.listdir(section_dir)):
                if not file_name.endswith('.py'):
                    continue

                example = os.path.join(section_dir, file_name)
                # the name of example is determined by the name of the .py file
                example_name = os.path.splitext(file_name)[0]

                # name of the example image is generated in the following format
                # sphx_glr_{example_name}_{i where i = index of image in example}.png
                # for now only the first image of each example is used
                example_image_path = os.path.join(
                    section_dir, "images",
                    f"sphx_glr_{example_name}_001.png")
                if not os.path.isfile(example_image_path):
                    continue

                _, script = split_code_and_text_blocks(example)
                intro, title = extract_intro_and_title(example, script[0][1])

                # generates rst text with thumbnail and link to the example
                ref_name, thumbnail_rst = _thumbnail_div(
                    section_dir,
                    srcdir,
                    f"{example_name}.py",
                    intro,
                    title,
                )

                search_setup.add_image(f"imgsearchref-({ref_name})",
                                       example_image_path)
                thumbnails.append(thumbnail_rst)

    logger.info("STARTING FEATURE EXTRACTION")
    search_setup.start_feature_extraction()

    # Written in one go inside a context manager, so the file is complete and
    # closed before Sphinx reads it through the ``include`` directive.
    recommendations_path = os.path.join(image_search_path,
                                        "index.recommendations")
    with open(recommendations_path, "w") as f:
        f.write("\n\n" + heading + "\n")
        f.write("^" * len(heading) + "\n")
        # THUMBNAIL_PARENT_DIV can be modified to include search page specific
        # classnames for applying custom css
        f.write(THUMBNAIL_PARENT_DIV)
        f.write("".join(thumbnails))
        f.write(THUMBNAIL_PARENT_DIV_CLOSE)



def setup(app):
    """Register `generate_search_page` on Sphinx's ``builder-inited`` event.

    Parameters
    ----------
    app : sphinx.application.Sphinx
        The application the extension is being loaded into.
    """
    event_name = "builder-inited"
    # need to decide on the priority
    handler_priority = 100
    app.connect(event_name, generate_search_page, priority=handler_priority)
20 changes: 20 additions & 0 deletions galleries/image_search/index.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
..
This page is required by sphinx-gallery,
but is not rendered in the final doc build.

=================
Image Search Demo
=================

.. raw:: html

<button id="sphx-glr-imgsearchbutton" >show search results demo</button>
<div class="sphx-glr-thumbcontainer" id="sphx-glr-imgsearchresult-container" ></div>

.. include:: index.recommendations

.. raw:: html

<script src="../_static/image_search.js"></script>
<link rel="stylesheet" href="../_static/image_search.css" />