Address Notion, Image tech debt in indexing code path (#687)

* Add support for using OAuth2.0 in the Notion integration
* Add notion to the admin page
* Remove unnecessary content_index and image search/setup references
* Trigger background job to start indexing Notion after user configures it
* Add a log line when a new Notion integration is set up
* Fix references to the configure_content methods
This commit is contained in:
sabaimran 2024-04-04 23:40:03 -07:00 committed by GitHub
parent 69dee75c34
commit f57f9f672d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 145 additions and 599 deletions

View file

@ -34,7 +34,7 @@ from khoj.database.adapters import (
) )
from khoj.database.models import ClientApplication, KhojUser, Subscription from khoj.database.models import ClientApplication, KhojUser, Subscription
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
from khoj.routers.indexer import configure_content, configure_search, load_content from khoj.routers.indexer import configure_content, configure_search
from khoj.routers.twilio import is_twilio_enabled from khoj.routers.twilio import is_twilio_enabled
from khoj.utils import constants, state from khoj.utils import constants, state
from khoj.utils.config import SearchType from khoj.utils.config import SearchType
@ -245,16 +245,12 @@ def initialize_content(regenerate: bool, search_type: Optional[SearchType] = Non
if state.search_models: if state.search_models:
try: try:
if init: if init:
logger.info("📬 Initializing content index...") logger.info("📬 No-op...")
state.content_index = load_content(state.config.content_type, state.content_index, state.search_models)
else: else:
logger.info("📬 Updating content index...") logger.info("📬 Updating content index...")
all_files = collect_files(user=user) all_files = collect_files(user=user)
state.content_index, status = configure_content( status = configure_content(
state.content_index,
state.config.content_type,
all_files, all_files,
state.search_models,
regenerate, regenerate,
search_type, search_type,
user=user, user=user,
@ -272,6 +268,7 @@ def configure_routes(app):
from khoj.routers.api_chat import api_chat from khoj.routers.api_chat import api_chat
from khoj.routers.api_config import api_config from khoj.routers.api_config import api_config
from khoj.routers.indexer import indexer from khoj.routers.indexer import indexer
from khoj.routers.notion import notion_router
from khoj.routers.web_client import web_client from khoj.routers.web_client import web_client
app.include_router(api, prefix="/api") app.include_router(api, prefix="/api")
@ -279,6 +276,7 @@ def configure_routes(app):
app.include_router(api_agents, prefix="/api/agents") app.include_router(api_agents, prefix="/api/agents")
app.include_router(api_config, prefix="/api/config") app.include_router(api_config, prefix="/api/config")
app.include_router(indexer, prefix="/api/v1/index") app.include_router(indexer, prefix="/api/v1/index")
app.include_router(notion_router, prefix="/api/notion")
app.include_router(web_client) app.include_router(web_client)
if not state.anonymous_mode: if not state.anonymous_mode:
@ -311,13 +309,9 @@ def update_search_index():
logger.info("📬 Updating content index via Scheduler") logger.info("📬 Updating content index via Scheduler")
for user in get_all_users(): for user in get_all_users():
all_files = collect_files(user=user) all_files = collect_files(user=user)
state.content_index, success = configure_content( success = configure_content(all_files, user=user)
state.content_index, state.config.content_type, all_files, state.search_models, user=user
)
all_files = collect_files(user=None) all_files = collect_files(user=None)
state.content_index, success = configure_content( success = configure_content(all_files, user=None)
state.content_index, state.config.content_type, all_files, state.search_models, user=None
)
if not success: if not success:
raise RuntimeError("Failed to update content index") raise RuntimeError("Failed to update content index")
logger.info("📪 Content index updated via Scheduler") logger.info("📪 Content index updated via Scheduler")

View file

@ -259,6 +259,10 @@ async def get_user_by_email(email: str) -> KhojUser:
return await KhojUser.objects.filter(email=email).afirst() return await KhojUser.objects.filter(email=email).afirst()
async def aget_user_by_uuid(uuid: str) -> KhojUser:
    """Fetch the first user whose `uuid` field matches the given value."""
    users_with_uuid = KhojUser.objects.filter(uuid=uuid)
    return await users_with_uuid.afirst()
async def get_user_by_token(token: dict) -> KhojUser: async def get_user_by_token(token: dict) -> KhojUser:
google_user = await GoogleUser.objects.filter(sub=token.get("sub")).select_related("user").afirst() google_user = await GoogleUser.objects.filter(sub=token.get("sub")).select_related("user").afirst()
if not google_user: if not google_user:

View file

@ -11,7 +11,9 @@ from khoj.database.models import (
ClientApplication, ClientApplication,
Conversation, Conversation,
Entry, Entry,
GithubConfig,
KhojUser, KhojUser,
NotionConfig,
OfflineChatProcessorConversationConfig, OfflineChatProcessorConversationConfig,
OpenAIProcessorConversationConfig, OpenAIProcessorConversationConfig,
ReflectiveQuestion, ReflectiveQuestion,
@ -52,6 +54,8 @@ admin.site.register(UserSearchModelConfig)
admin.site.register(TextToImageModelConfig) admin.site.register(TextToImageModelConfig)
admin.site.register(ClientApplication) admin.site.register(ClientApplication)
admin.site.register(Agent) admin.site.register(Agent)
admin.site.register(GithubConfig)
admin.site.register(NotionConfig)
@admin.register(Entry) @admin.register(Entry)

View file

@ -109,14 +109,23 @@
<p class="card-description">Sync your Notion pages</p> <p class="card-description">Sync your Notion pages</p>
</div> </div>
<div class="card-action-row"> <div class="card-action-row">
{% if current_model_state.notion %}
<a class="card-button" href="/config/content-source/notion"> <a class="card-button" href="/config/content-source/notion">
{% if current_model_state.notion %}
Update Update
{% else %}
Setup
{% endif %}
<svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M5 12h14M12 5l7 7-7 7"></path></svg> <svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M5 12h14M12 5l7 7-7 7"></path></svg>
</a> </a>
{% elif notion_oauth_url %}
<a class="card-button" href="{{ notion_oauth_url }}">
Connect
<svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M5 12h14M12 5l7 7-7 7"></path></svg>
</a>
{% else %}
<a class="card-button" href="/config/content-source/notion">
Setup
<svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M5 12h14M12 5l7 7-7 7"></path></svg>
</a>
{% endif %}
<div id="clear-notion" <div id="clear-notion"
class="card-action-row" class="card-action-row"
style="display: {% if not current_model_state.notion %}none{% endif %}"> style="display: {% if not current_model_state.notion %}none{% endif %}">

View file

@ -5,11 +5,6 @@
<h2 class="section-title"> <h2 class="section-title">
<img class="card-icon" src="/static/assets/icons/notion.svg?v={{ khoj_version }}" alt="Notion"> <img class="card-icon" src="/static/assets/icons/notion.svg?v={{ khoj_version }}" alt="Notion">
<span class="card-title-text">Notion</span> <span class="card-title-text">Notion</span>
<div class="instructions">
<a href="https://docs.khoj.dev/#/notion_integration">ⓘ Help</a>
</div>
</h2>
<form>
<table> <table>
<tr> <tr>
<td> <td>
@ -22,7 +17,7 @@
</table> </table>
<div class="section"> <div class="section">
<div id="success" style="display: none;"></div> <div id="success" style="display: none;"></div>
<button id="submit" type="submit">Save</button> <button id="submit" type="submit">Sync to Update</button>
</div> </div>
</form> </form>
</div> </div>
@ -43,7 +38,7 @@
const submitButton = document.getElementById("submit"); const submitButton = document.getElementById("submit");
submitButton.disabled = true; submitButton.disabled = true;
submitButton.innerHTML = "Saving..."; submitButton.innerHTML = "Syncing...";
// Save Notion config on server // Save Notion config on server
const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1]; const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1];

View file

@ -33,7 +33,7 @@ from khoj.routers.helpers import (
from khoj.search_filter.date_filter import DateFilter from khoj.search_filter.date_filter import DateFilter
from khoj.search_filter.file_filter import FileFilter from khoj.search_filter.file_filter import FileFilter
from khoj.search_filter.word_filter import WordFilter from khoj.search_filter.word_filter import WordFilter
from khoj.search_type import image_search, text_search from khoj.search_type import text_search
from khoj.utils import constants, state from khoj.utils import constants, state
from khoj.utils.config import OfflineChatProcessorModel from khoj.utils.config import OfflineChatProcessorModel
from khoj.utils.helpers import ConversationCommand, timer from khoj.utils.helpers import ConversationCommand, timer
@ -145,41 +145,17 @@ async def execute_search(
) )
] ]
elif (t == SearchType.Image) and state.content_index.image and state.search_models.image_search:
# query images
search_futures += [
executor.submit(
image_search.query,
user_query,
results_count,
state.search_models.image_search,
state.content_index.image,
)
]
# Query across each requested content types in parallel # Query across each requested content types in parallel
with timer("Query took", logger): with timer("Query took", logger):
for search_future in concurrent.futures.as_completed(search_futures): for search_future in concurrent.futures.as_completed(search_futures):
if t == SearchType.Image and state.content_index.image: hits = await search_future.result()
hits = await search_future.result() # Collate results
output_directory = constants.web_directory / "images" results += text_search.collate_results(hits, dedupe=dedupe)
# Collate results
results += image_search.collate_results(
hits,
image_names=state.content_index.image.image_names,
output_directory=output_directory,
image_files_url="/static/images",
count=results_count,
)
else:
hits = await search_future.result()
# Collate results
results += text_search.collate_results(hits, dedupe=dedupe)
# Sort results across all content types and take top results # Sort results across all content types and take top results
results = text_search.rerank_and_sort_results( results = text_search.rerank_and_sort_results(
results, query=defiltered_query, rank_results=r, search_model_name=search_model.name results, query=defiltered_query, rank_results=r, search_model_name=search_model.name
)[:results_count] )[:results_count]
# Cache results # Cache results
if user: if user:
@ -214,8 +190,6 @@ def update(
components = [] components = []
if state.search_models: if state.search_models:
components.append("Search models") components.append("Search models")
if state.content_index:
components.append("Content index")
components_msg = ", ".join(components) components_msg = ", ".join(components)
logger.info(f"📪 {components_msg} updated via API") logger.info(f"📪 {components_msg} updated via API")

View file

@ -38,7 +38,7 @@ logger = logging.getLogger(__name__)
def map_config_to_object(content_source: str): def map_config_to_object(content_source: str):
if content_source == DbEntry.EntrySource.GITHUB: if content_source == DbEntry.EntrySource.GITHUB:
return GithubConfig return GithubConfig
if content_source == DbEntry.EntrySource.GITHUB: if content_source == DbEntry.EntrySource.NOTION:
return NotionConfig return NotionConfig
if content_source == DbEntry.EntrySource.COMPUTER: if content_source == DbEntry.EntrySource.COMPUTER:
return "Computer" return "Computer"

View file

@ -14,9 +14,9 @@ from khoj.processor.content.org_mode.org_to_entries import OrgToEntries
from khoj.processor.content.pdf.pdf_to_entries import PdfToEntries from khoj.processor.content.pdf.pdf_to_entries import PdfToEntries
from khoj.processor.content.plaintext.plaintext_to_entries import PlaintextToEntries from khoj.processor.content.plaintext.plaintext_to_entries import PlaintextToEntries
from khoj.routers.helpers import ApiIndexedDataLimiter, update_telemetry_state from khoj.routers.helpers import ApiIndexedDataLimiter, update_telemetry_state
from khoj.search_type import image_search, text_search from khoj.search_type import text_search
from khoj.utils import constants, state from khoj.utils import constants, state
from khoj.utils.config import ContentIndex, SearchModels from khoj.utils.config import SearchModels
from khoj.utils.helpers import LRU, get_file_type from khoj.utils.helpers import LRU, get_file_type
from khoj.utils.rawconfig import ContentConfig, FullConfig, SearchConfig from khoj.utils.rawconfig import ContentConfig, FullConfig, SearchConfig
from khoj.utils.yaml import save_config_to_file_updated_state from khoj.utils.yaml import save_config_to_file_updated_state
@ -105,13 +105,10 @@ async def update(
# Extract required fields from config # Extract required fields from config
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
state.content_index, success = await loop.run_in_executor( success = await loop.run_in_executor(
None, None,
configure_content, configure_content,
state.content_index,
state.config.content_type,
indexer_input.model_dump(), indexer_input.model_dump(),
state.search_models,
force, force,
t, t,
False, False,
@ -159,23 +156,17 @@ def configure_search(search_models: SearchModels, search_config: Optional[Search
if search_config and search_config.image: if search_config and search_config.image:
logger.info("🔍 🌄 Setting up image search model") logger.info("🔍 🌄 Setting up image search model")
search_models.image_search = image_search.initialize_model(search_config.image)
return search_models return search_models
def configure_content( def configure_content(
content_index: Optional[ContentIndex],
content_config: Optional[ContentConfig],
files: Optional[dict[str, dict[str, str]]], files: Optional[dict[str, dict[str, str]]],
search_models: SearchModels,
regenerate: bool = False, regenerate: bool = False,
t: Optional[state.SearchType] = state.SearchType.All, t: Optional[state.SearchType] = state.SearchType.All,
full_corpus: bool = True, full_corpus: bool = True,
user: KhojUser = None, user: KhojUser = None,
) -> tuple[Optional[ContentIndex], bool]: ) -> bool:
content_index = ContentIndex()
success = True success = True
if t == None: if t == None:
t = state.SearchType.All t = state.SearchType.All
@ -185,7 +176,7 @@ def configure_content(
if t is not None and not t.value in [type.value for type in state.SearchType]: if t is not None and not t.value in [type.value for type in state.SearchType]:
logger.warning(f"🚨 Invalid search type: {t}") logger.warning(f"🚨 Invalid search type: {t}")
return None, False return False
search_type = t.value if t else None search_type = t.value if t else None
@ -193,7 +184,7 @@ def configure_content(
if files is None: if files is None:
logger.warning(f"🚨 No files to process for {search_type} search.") logger.warning(f"🚨 No files to process for {search_type} search.")
return None, True return True
try: try:
# Initialize Org Notes Search # Initialize Org Notes Search
@ -266,24 +257,6 @@ def configure_content(
logger.error(f"🚨 Failed to setup plaintext: {e}", exc_info=True) logger.error(f"🚨 Failed to setup plaintext: {e}", exc_info=True)
success = False success = False
try:
# Initialize Image Search
if (
(search_type == state.SearchType.All.value or search_type == state.SearchType.Image.value)
and content_config
and content_config.image
and search_models.image_search
):
logger.info("🌄 Setting up search for images")
# Extract Entries, Generate Image Embeddings
content_index.image = image_search.setup(
content_config.image, search_models.image_search.image_encoder, regenerate=regenerate
)
except Exception as e:
logger.error(f"🚨 Failed to setup images: {e}", exc_info=True)
success = False
try: try:
if no_documents: if no_documents:
github_config = GithubConfig.objects.filter(user=user).prefetch_related("githubrepoconfig").first() github_config = GithubConfig.objects.filter(user=user).prefetch_related("githubrepoconfig").first()
@ -330,23 +303,4 @@ def configure_content(
if user: if user:
state.query_cache[user.uuid] = LRU() state.query_cache[user.uuid] = LRU()
return content_index, success return success
def load_content(
content_config: Optional[ContentConfig],
content_index: Optional[ContentIndex],
search_models: SearchModels,
):
if content_config is None:
logger.debug("🚨 No Content configuration available.")
return None
if content_index is None:
content_index = ContentIndex()
if content_config.image:
logger.info("🌄 Loading images")
content_index.image = image_search.setup(
content_config.image, search_models.image_search.image_encoder, regenerate=False
)
return content_index

View file

@ -0,0 +1,89 @@
import asyncio
import base64
import json
import logging
import os
from concurrent.futures import ThreadPoolExecutor
import requests
from fastapi import APIRouter, BackgroundTasks, Request, Response
from starlette.responses import RedirectResponse
from khoj.database.adapters import aget_user_by_uuid
from khoj.database.models import KhojUser, NotionConfig
from khoj.routers.indexer import configure_content
from khoj.utils.state import SearchType
NOTION_OAUTH_CLIENT_ID = os.getenv("NOTION_OAUTH_CLIENT_ID")
NOTION_OAUTH_CLIENT_SECRET = os.getenv("NOTION_OAUTH_CLIENT_SECRET")
NOTION_REDIRECT_URI = os.getenv("NOTION_REDIRECT_URI")
notion_router = APIRouter()
executor = ThreadPoolExecutor()
logger = logging.getLogger(__name__)
def get_notion_auth_url(user: KhojUser):
    """Build the Notion OAuth authorization URL for the given user.

    Returns None when any of NOTION_OAUTH_CLIENT_ID, NOTION_OAUTH_CLIENT_SECRET or
    NOTION_REDIRECT_URI is unset. Otherwise returns the authorize URL, with the user's
    uuid passed as the OAuth `state` parameter.
    """
    # Imported locally so the function does not depend on a new module-level import
    from urllib.parse import urlencode

    if not NOTION_OAUTH_CLIENT_ID or not NOTION_OAUTH_CLIENT_SECRET or not NOTION_REDIRECT_URI:
        return None

    # Percent-encode every value. The redirect URI holds reserved characters ("://", and
    # possibly "?" or "&") that would otherwise corrupt the query string.
    query_string = urlencode(
        {
            "client_id": NOTION_OAUTH_CLIENT_ID,
            "redirect_uri": NOTION_REDIRECT_URI,
            "response_type": "code",
            "state": user.uuid,
        }
    )
    return f"https://api.notion.com/v1/oauth/authorize?{query_string}"
async def run_in_executor(func, *args):
    """Run the blocking callable `func(*args)` on the module's executor and await its result."""
    # Inside a coroutine a loop is always running; get_running_loop() returns exactly that
    # loop and never implicitly creates a new one.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(executor, func, *args)
@notion_router.get("/auth/callback")
async def notion_auth_callback(request: Request, background_tasks: BackgroundTasks):
    """Handle the redirect back from Notion's OAuth consent screen.

    Reads `code` and `state` from the query string, exchanges the code for an access token,
    stores the token in a NotionConfig for the user identified by `state`, schedules a
    background Notion indexing job and redirects to the Notion config page.

    Responds with 400 when `code` or `state` is missing or Notion returns no access token,
    404 when no user matches `state`, and 502 when the token request itself fails.
    """
    code = request.query_params.get("code")
    state = request.query_params.get("state")
    if not code or not state:
        return Response("Missing code or state", status_code=400)

    # NOTE(review): the user is identified purely by the `state` query parameter (their uuid).
    # Consider also verifying it against the authenticated session user - TODO confirm.
    user: KhojUser = await aget_user_by_uuid(state)
    if not user:
        return Response("User not found", status_code=404)

    bearer_token = f"{NOTION_OAUTH_CLIENT_ID}:{NOTION_OAUTH_CLIENT_SECRET}"
    base64_encoded_token = base64.b64encode(bearer_token.encode()).decode()

    headers = {
        "Accept": "application/json",
        "Content-Type": "application/json",
        "Authorization": f"Basic {base64_encoded_token}",
    }

    data = {
        "grant_type": "authorization_code",
        "code": code,
        "redirect_uri": NOTION_REDIRECT_URI,
    }

    def exchange_code_for_token():
        # Bounded by a timeout so an unresponsive upstream cannot hang the worker forever
        return requests.post(
            "https://api.notion.com/v1/oauth/token", data=json.dumps(data), headers=headers, timeout=30
        )

    try:
        # requests is blocking; run it off the event loop so other requests keep being served
        response = await run_in_executor(exchange_code_for_token)
        final_response = response.json()
    except (requests.RequestException, ValueError) as e:
        logger.error(f"Failed to exchange Notion OAuth code for a token: {e}", exc_info=True)
        return Response("Failed to connect to Notion", status_code=502)

    access_token = final_response.get("access_token") if isinstance(final_response, dict) else None
    if not response.ok or not access_token:
        logger.error(f"Notion OAuth token exchange failed with status {response.status_code}")
        return Response("Failed to get access token from Notion", status_code=400)

    # Replace the previous config only once a new token is in hand. Both ORM calls return
    # coroutines and must be awaited, otherwise they never execute.
    await NotionConfig.objects.filter(user=user).adelete()
    await NotionConfig.objects.acreate(token=access_token, user=user)

    owner = final_response.get("owner")
    workspace_id = final_response.get("workspace_id")
    workspace_name = final_response.get("workspace_name")
    bot_id = final_response.get("bot_id")

    logger.info(
        f"Notion integration. Owner: {owner}, Workspace ID: {workspace_id}, Workspace Name: {workspace_name}, Bot ID: {bot_id}"
    )

    notion_redirect = str(request.app.url_path_for("notion_config_page"))

    # Trigger an async job to configure_content. Let it run without blocking the response.
    background_tasks.add_task(run_in_executor, configure_content, {}, False, SearchType.Notion, True, user)

    return RedirectResponse(notion_redirect)

View file

@ -19,6 +19,7 @@ from khoj.database.adapters import (
get_user_subscription_state, get_user_subscription_state,
) )
from khoj.database.models import KhojUser from khoj.database.models import KhojUser
from khoj.routers.notion import get_notion_auth_url
from khoj.routers.twilio import is_twilio_enabled from khoj.routers.twilio import is_twilio_enabled
from khoj.utils import constants, state from khoj.utils import constants, state
from khoj.utils.rawconfig import ( from khoj.utils.rawconfig import (
@ -244,6 +245,8 @@ def config_page(request: Request):
current_search_model_option = adapters.get_user_search_model_or_default(user) current_search_model_option = adapters.get_user_search_model_or_default(user)
notion_oauth_url = get_notion_auth_url(user)
return templates.TemplateResponse( return templates.TemplateResponse(
"config.html", "config.html",
context={ context={
@ -267,6 +270,7 @@ def config_page(request: Request):
"phone_number": user.phone_number, "phone_number": user.phone_number,
"is_phone_number_verified": user.verified_phone_number, "is_phone_number_verified": user.verified_phone_number,
"khoj_version": state.khoj_version, "khoj_version": state.khoj_version,
"notion_oauth_url": notion_oauth_url,
}, },
) )
@ -324,7 +328,7 @@ def notion_config_page(request: Request):
token=current_notion_config.token if current_notion_config else "", token=current_notion_config.token if current_notion_config else "",
) )
current_config = json.loads(current_config.json()) current_config = json.loads(current_config.model_dump_json())
return templates.TemplateResponse( return templates.TemplateResponse(
"content_source_notion_input.html", "content_source_notion_input.html",

View file

@ -1,272 +0,0 @@
import copy
import glob
import logging
import math
import pathlib
import shutil
from typing import List
import torch
from PIL import Image
from sentence_transformers import SentenceTransformer, util
from tqdm import trange
from khoj.utils.config import ImageContent, ImageSearchModel
from khoj.utils.helpers import (
get_absolute_path,
get_from_dict,
load_model,
resolve_absolute_path,
timer,
)
from khoj.utils.models import BaseEncoder
from khoj.utils.rawconfig import ImageContentConfig, ImageSearchConfig, SearchResponse
# Create Logger
logger = logging.getLogger(__name__)
def initialize_model(search_config: ImageSearchConfig):
    """Load the image encoder described by `search_config` and wrap it in an ImageSearchModel.

    Side effect: `search_config.model_directory` is overwritten with the result of
    `resolve_absolute_path`, and the parent of that directory is created if missing.
    """
    # Switch the configured model directory over to its absolute form
    search_config.model_directory = resolve_absolute_path(search_config.model_directory)
    model_directory = search_config.model_directory

    # Ensure the folder that will contain the model directory exists
    model_directory.parent.mkdir(parents=True, exist_ok=True)

    # Load the CLIP model, falling back to SentenceTransformer when no encoder type is configured
    encoder_type = search_config.encoder_type or SentenceTransformer
    encoder = load_model(
        model_dir=model_directory,
        model_name=search_config.encoder,
        model_type=encoder_type,
    )

    return ImageSearchModel(encoder)
def extract_entries(image_directories):
    """Collect the `*.jpg` and `*.jpeg` files found directly inside each given directory.

    Each directory is passed through `resolve_absolute_path(..., strict=True)` before globbing.
    Returns the matched paths as a sorted list.
    """
    image_names = []
    for image_directory in image_directories:
        image_directory = resolve_absolute_path(image_directory, strict=True)
        image_names.extend(image_directory.glob("*.jpg"))
        image_names.extend(image_directory.glob("*.jpeg"))

    # isEnabledFor honours the effective (inherited) level. The previous check,
    # `logger.level >= logging.DEBUG`, was true for INFO/WARNING loggers and false for
    # loggers left at NOTSET, which is the opposite of "debug logging is on".
    if logger.isEnabledFor(logging.DEBUG):
        image_directory_names = ", ".join(str(directory) for directory in image_directories)
        logger.debug(f"Found {len(image_names)} images in {image_directory_names}")
    return sorted(image_names)
def compute_embeddings(image_names, encoder, embeddings_file, batch_size=50, use_xmp_metadata=False, regenerate=False):
    """Compute (and save) or load both the image embeddings and the image metadata embeddings.

    Returns a `(image_embeddings, image_metadata_embeddings)` tuple.
    """
    embeddings_of_images = compute_image_embeddings(
        image_names,
        encoder,
        embeddings_file,
        batch_size=batch_size,
        regenerate=regenerate,
    )
    embeddings_of_metadata = compute_metadata_embeddings(
        image_names,
        encoder,
        embeddings_file,
        batch_size=batch_size,
        use_xmp_metadata=use_xmp_metadata,
        regenerate=regenerate,
    )
    return embeddings_of_images, embeddings_of_metadata
def compute_image_embeddings(image_names, encoder, embeddings_file, batch_size=50, regenerate=False):
    """Return embeddings for `image_names`, loading them from `embeddings_file` when possible.

    When the embeddings file exists and `regenerate` is false the saved embeddings are loaded.
    Otherwise every image is encoded in batches of `batch_size` and the result is saved to
    `embeddings_file` before being returned.
    """
    # Reuse the embeddings saved on disk unless a rebuild was requested
    if resolve_absolute_path(embeddings_file).exists() and not regenerate:
        cached_embeddings = torch.load(embeddings_file)
        logger.debug(f"Loaded {len(cached_embeddings)} image embeddings from {embeddings_file}")
        return cached_embeddings

    # Otherwise encode every image from scratch, one batch at a time. This can take a while
    computed_embeddings = []
    for batch_start in trange(0, len(image_names), batch_size):
        batch_images = []
        for image_path in image_names[batch_start : batch_start + batch_size]:
            image = Image.open(image_path)
            # Resize images to max width of 640px for faster processing
            image.thumbnail((640, image.height))
            batch_images.append(image)
        computed_embeddings += encoder.encode(
            batch_images, convert_to_tensor=True, batch_size=min(len(batch_images), batch_size)
        )

    # Make sure the destination folder exists, then persist the freshly computed embeddings
    embeddings_file.parent.mkdir(parents=True, exist_ok=True)
    torch.save(computed_embeddings, embeddings_file)
    logger.info(f"📩 Saved computed image embeddings to {embeddings_file}")

    return computed_embeddings
def compute_metadata_embeddings(
    image_names, encoder, embeddings_file, batch_size=50, use_xmp_metadata=False, regenerate=False, verbose=0
):
    """Return embeddings of each image's XMP metadata text, or None when XMP metadata is disabled.

    With `use_xmp_metadata` set, embeddings are loaded from `<embeddings_file>_metadata` if that
    file exists and `regenerate` is false. Otherwise they are computed in batches of `batch_size`
    and saved to that file. Batches that raise RuntimeError while encoding are logged and skipped.

    `verbose` is accepted for backward compatibility and is not used.
    """
    image_metadata_embeddings = None

    # Load pre-computed image metadata embedding file if exists
    if use_xmp_metadata and resolve_absolute_path(f"{embeddings_file}_metadata").exists() and not regenerate:
        image_metadata_embeddings = torch.load(f"{embeddings_file}_metadata")
        logger.debug(f"Loaded image metadata embeddings from {embeddings_file}_metadata")

    # Else compute the image metadata embeddings from scratch, which can take a while
    if use_xmp_metadata and image_metadata_embeddings is None:
        image_metadata_embeddings = []
        for index in trange(0, len(image_names), batch_size):
            # extract_metadata takes only the image name; passing `verbose` as a second
            # positional argument raised TypeError on the first image.
            image_metadata = [extract_metadata(image_name) for image_name in image_names[index : index + batch_size]]
            try:
                image_metadata_embeddings += encoder.encode(
                    image_metadata, convert_to_tensor=True, batch_size=min(len(image_metadata), batch_size)
                )
            except RuntimeError as e:
                logger.error(
                    f"Error encoding metadata for images starting from\n\tindex: {index},\n\timages: {image_names[index:index+batch_size]}\nException: {e}"
                )
                continue
        torch.save(image_metadata_embeddings, f"{embeddings_file}_metadata")
        logger.info(f"📩 Saved computed image metadata embeddings to {embeddings_file}_metadata")

    return image_metadata_embeddings
def extract_metadata(image_name):
    """Build a text summary of an image from its XMP description and subjects.

    The result is the XMP description, followed by ". " and a comma separated list of the
    part after the first ":" of every subject that contains a ":", when any such subject exists.
    """
    # Image.open keeps the underlying file open until the image is loaded or closed.
    # Close it explicitly so indexing many images does not accumulate file handles.
    with Image.open(image_name) as image:
        image_xmp_metadata = image.getxmp()

    # NOTE(review): assumes get_from_dict yields a string description and an iterable of
    # subject strings; confirm how it behaves when the XMP keys are absent.
    image_description = get_from_dict(
        image_xmp_metadata, "xmpmeta", "RDF", "Description", "description", "Alt", "li", "text"
    )
    image_subjects = get_from_dict(image_xmp_metadata, "xmpmeta", "RDF", "Description", "subject", "Bag", "li")
    image_metadata_subjects = {subject.split(":")[1] for subject in image_subjects if ":" in subject}

    image_processed_metadata = image_description
    if image_metadata_subjects:
        image_processed_metadata += ". " + ", ".join(image_metadata_subjects)

    logger.debug(f"{image_name}:\t{image_processed_metadata}")

    return image_processed_metadata
async def query(
    raw_query, count, search_model: ImageSearchModel, content: ImageContent, score_threshold: float = math.inf
):
    """Find the images that best match a text or image query.

    Args:
        raw_query: Text to search with, or `file:<path>` naming an existing file to search by image.
        count: Number of top hits requested from each semantic search.
        search_model: Provides the `image_encoder` used to embed the query.
        content: Provides `image_embeddings` and, optionally, `image_metadata_embeddings`.
        score_threshold: Hits whose `image_score` is above this value are dropped.

    Returns:
        A list of dicts with `corpus_id`, `score`, `image_score` and `metadata_score` keys,
        sorted by `score` in descending order.
    """
    # Set query to image content if query is of form file:/path/to/file.png
    if raw_query.startswith("file:") and pathlib.Path(raw_query[5:]).is_file():
        query_imagepath = resolve_absolute_path(pathlib.Path(raw_query[5:]), strict=True)
        query = copy.deepcopy(Image.open(query_imagepath))
        query.thumbnail((640, query.height))  # scale down image for faster processing
        logger.info(f"🔎 Find Images by Image: {query_imagepath}")
    else:
        # Truncate words in query to stay below max_tokens supported by ML model
        max_words = 20
        query = " ".join(raw_query.split()[:max_words])
        logger.info(f"🔎 Find Images by Text: {query}")

    # Now we encode the query (which can either be an image or a text string)
    with timer("Query Encode Time", logger):
        query_embedding = search_model.image_encoder.encode([query], convert_to_tensor=True, show_progress_bar=False)

    # Compute top_k ranked images based on cosine-similarity b/w query and all image embeddings.
    with timer("Search Time", logger):
        image_hits = {
            # Map scores to distance metric by multiplying by -1
            result["corpus_id"]: {"image_score": -1 * result["score"], "score": -1 * result["score"]}
            for result in util.semantic_search(query_embedding, content.image_embeddings, top_k=count)[0]
        }

    # Compute top_k ranked images based on cosine-similarity b/w query and all image metadata embeddings.
    if content.image_metadata_embeddings:
        with timer("Metadata Search Time", logger):
            metadata_hits = {
                result["corpus_id"]: result["score"]
                for result in util.semantic_search(query_embedding, content.image_metadata_embeddings, top_k=count)[0]
            }

        # Sum metadata, image scores of the highest ranked images
        scaling_factor = 0.33
        for corpus_id, score in metadata_hits.items():
            # Test the id itself, not the literal string "corpus_id". The literal is never a
            # key of image_hits, so the image score of every metadata hit used to be discarded.
            if corpus_id in image_hits:
                image_hits[corpus_id].update(
                    {
                        "metadata_score": score,
                        "score": image_hits[corpus_id].get("score", 0) + scaling_factor * score,
                    }
                )
            else:
                image_hits[corpus_id] = {"metadata_score": score, "score": scaling_factor * score}

    # Reformat results in original form from sentence transformer semantic_search()
    hits = [
        {
            "corpus_id": corpus_id,
            "score": scores["score"],
            "image_score": scores.get("image_score", 0),
            "metadata_score": scores.get("metadata_score", 0),
        }
        for corpus_id, scores in image_hits.items()
    ]

    # Filter results by score threshold
    hits = [hit for hit in hits if hit["image_score"] <= score_threshold]

    # Sort the images based on their combined metadata, image scores
    return sorted(hits, key=lambda hit: hit["score"], reverse=True)
def collate_results(hits, image_names, output_directory, image_files_url, count=5) -> List[SearchResponse]:
    """Copy the top `count` hit images into `output_directory` and describe them as search responses.

    Each copied image is named after its rank (`0.jpg`, `1.jpeg`, ...). The returned entries
    point at `<image_files_url>/<rank><suffix>` and carry the hit's scores formatted to 9 decimals.
    """
    results: List[SearchResponse] = []

    for index, hit in enumerate(hits[:count]):
        source_path = image_names[hit["corpus_id"]]

        target_image_name = f"{index}{source_path.suffix}"
        target_path = resolve_absolute_path(f"{output_directory}/{target_image_name}")

        # Create output directory, including any missing ancestors, if it doesn't exist.
        # exist_ok already covers the "directory is present" case, so no separate check is needed.
        target_path.parent.mkdir(parents=True, exist_ok=True)

        # Copy the image to the output directory
        shutil.copy(source_path, target_path)

        # Add the image metadata to the results
        results.append(
            SearchResponse.model_validate(
                {
                    "entry": f"{image_files_url}/{target_image_name}",
                    "score": f"{hit['score']:.9f}",
                    "additional": {
                        "image_score": f"{hit['image_score']:.9f}",
                        "metadata_score": f"{hit['metadata_score']:.9f}",
                    },
                    "corpus_id": str(hit["corpus_id"]),
                }
            )
        )

    return results
def setup(config: ImageContentConfig, encoder: BaseEncoder, regenerate: bool) -> ImageContent:
    """Gather the configured image files and compute or load their embeddings.

    Images come from `config.input_directories` (via `extract_entries`) and from the glob
    patterns in `config.input_filter`. The combined, de-duplicated, sorted file list is passed
    to `compute_embeddings` and returned with the embeddings inside an ImageContent.
    """
    # Extract Entries
    absolute_image_files, filtered_image_files = set(), set()
    if config.input_directories:
        image_directories = [resolve_absolute_path(directory, strict=True) for directory in config.input_directories]
        absolute_image_files = set(extract_entries(image_directories))
    if config.input_filter:
        # glob.glob yields plain strings. Wrap them in Path so they can be merged and sorted
        # with the Path objects found in the input directories (str and Path are not orderable
        # against each other) and so consumers can rely on Path attributes such as `.suffix`.
        filtered_image_files = {
            pathlib.Path(filtered_file)
            for input_filter in config.input_filter
            for filtered_file in glob.glob(get_absolute_path(input_filter))
        }

    all_image_files = sorted(absolute_image_files | filtered_image_files)

    # Compute or Load Embeddings
    embeddings_file = resolve_absolute_path(config.embeddings_file)
    image_embeddings, image_metadata_embeddings = compute_embeddings(
        all_image_files,
        encoder,
        embeddings_file,
        batch_size=config.batch_size,
        regenerate=regenerate,
        use_xmp_metadata=config.use_xmp_metadata,
    )

    return ImageContent(all_image_files, image_embeddings, image_metadata_embeddings)

View file

@ -58,15 +58,9 @@ class ImageSearchModel:
image_encoder: BaseEncoder image_encoder: BaseEncoder
@dataclass
class ContentIndex:
image: Optional[ImageContent] = None
@dataclass @dataclass
class SearchModels: class SearchModels:
text_search: Optional[TextSearchModel] = None text_search: Optional[TextSearchModel] = None
image_search: Optional[ImageSearchModel] = None
@dataclass @dataclass

View file

@ -9,7 +9,7 @@ from whisper import Whisper
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
from khoj.utils import config as utils_config from khoj.utils import config as utils_config
from khoj.utils.config import ContentIndex, OfflineChatProcessorModel, SearchModels from khoj.utils.config import OfflineChatProcessorModel, SearchModels
from khoj.utils.helpers import LRU, get_device from khoj.utils.helpers import LRU, get_device
from khoj.utils.rawconfig import FullConfig from khoj.utils.rawconfig import FullConfig
@ -18,7 +18,6 @@ config = FullConfig()
search_models = SearchModels() search_models = SearchModels()
embeddings_model: Dict[str, EmbeddingsModel] = None embeddings_model: Dict[str, EmbeddingsModel] = None
cross_encoder_model: Dict[str, CrossEncoderModel] = None cross_encoder_model: Dict[str, CrossEncoderModel] = None
content_index = ContentIndex()
openai_client: OpenAI = None openai_client: OpenAI = None
offline_chat_processor_config: OfflineChatProcessorModel = None offline_chat_processor_config: OfflineChatProcessorModel = None
whisper_model: Whisper = None whisper_model: Whisper = None

View file

@ -25,7 +25,7 @@ from khoj.processor.content.org_mode.org_to_entries import OrgToEntries
from khoj.processor.content.plaintext.plaintext_to_entries import PlaintextToEntries from khoj.processor.content.plaintext.plaintext_to_entries import PlaintextToEntries
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
from khoj.routers.indexer import configure_content from khoj.routers.indexer import configure_content
from khoj.search_type import image_search, text_search from khoj.search_type import text_search
from khoj.utils import fs_syncer, state from khoj.utils import fs_syncer, state
from khoj.utils.config import SearchModels from khoj.utils.config import SearchModels
from khoj.utils.constants import web_directory from khoj.utils.constants import web_directory
@ -207,7 +207,6 @@ def openai_agent():
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def search_models(search_config: SearchConfig): def search_models(search_config: SearchConfig):
search_models = SearchModels() search_models = SearchModels()
search_models.image_search = image_search.initialize_model(search_config.image)
return search_models return search_models
@ -232,8 +231,6 @@ def content_config(tmp_path_factory, search_models: SearchModels, default_user:
use_xmp_metadata=False, use_xmp_metadata=False,
) )
image_search.setup(content_config.image, search_models.image_search.image_encoder, regenerate=False)
LocalOrgConfig.objects.create( LocalOrgConfig.objects.create(
input_files=None, input_files=None,
input_filter=["tests/data/org/*.org"], input_filter=["tests/data/org/*.org"],
@ -305,9 +302,7 @@ def chat_client_builder(search_config, user, index_content=True, require_auth=Fa
# Index Markdown Content for Search # Index Markdown Content for Search
all_files = fs_syncer.collect_files(user=user) all_files = fs_syncer.collect_files(user=user)
state.content_index, _ = configure_content( success = configure_content(all_files, user=user)
state.content_index, state.config.content_type, all_files, state.search_models, user=user
)
# Initialize Processor from Config # Initialize Processor from Config
if os.getenv("OPENAI_API_KEY"): if os.getenv("OPENAI_API_KEY"):
@ -349,16 +344,12 @@ def client(
state.cross_encoder_model["default"] = CrossEncoderModel() state.cross_encoder_model["default"] = CrossEncoderModel()
# These lines help us Mock the Search models for these search types # These lines help us Mock the Search models for these search types
state.search_models.image_search = image_search.initialize_model(search_config.image)
text_search.setup( text_search.setup(
OrgToEntries, OrgToEntries,
get_sample_data("org"), get_sample_data("org"),
regenerate=False, regenerate=False,
user=api_user.user, user=api_user.user,
) )
state.content_index.image = image_search.setup(
content_config.image, state.search_models.image_search, regenerate=False
)
text_search.setup( text_search.setup(
PlaintextToEntries, PlaintextToEntries,
get_sample_data("plaintext"), get_sample_data("plaintext"),
@ -388,9 +379,7 @@ def client_offline_chat(search_config: SearchConfig, default_user2: KhojUser):
) )
all_files = fs_syncer.collect_files(user=default_user2) all_files = fs_syncer.collect_files(user=default_user2)
configure_content( configure_content(all_files, user=default_user2)
state.content_index, state.config.content_type, all_files, state.search_models, user=default_user2
)
# Initialize Processor from Config # Initialize Processor from Config
OfflineChatProcessorConversationConfigFactory(enabled=True) OfflineChatProcessorConversationConfigFactory(enabled=True)

View file

@ -12,10 +12,9 @@ from khoj.configure import configure_routes, configure_search_types
from khoj.database.adapters import EntryAdapters from khoj.database.adapters import EntryAdapters
from khoj.database.models import KhojApiUser, KhojUser from khoj.database.models import KhojApiUser, KhojUser
from khoj.processor.content.org_mode.org_to_entries import OrgToEntries from khoj.processor.content.org_mode.org_to_entries import OrgToEntries
from khoj.search_type import image_search, text_search from khoj.search_type import text_search
from khoj.utils import state from khoj.utils import state
from khoj.utils.rawconfig import ContentConfig, SearchConfig from khoj.utils.rawconfig import ContentConfig, SearchConfig
from khoj.utils.state import config, content_index, search_models
# Test # Test
@ -298,34 +297,6 @@ def test_get_configured_types_with_no_content_config(fastapi_app: FastAPI):
assert response.json() == ["all"] assert response.json() == ["all"]
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
def test_image_search(client, content_config: ContentConfig, search_config: SearchConfig):
    """Verify the search API returns the expected image for each text query."""
    # Arrange
    auth_headers = {"Authorization": "Bearer kk-secret"}
    search_models.image_search = image_search.initialize_model(search_config.image)
    content_index.image = image_search.setup(
        content_config.image, search_models.image_search.image_encoder, regenerate=False
    )
    expected_image_name_by_query = {
        "kitten": "kitten_park.jpg",
        "a horse and dog on a leash": "horse_dog.jpg",
        "A guinea pig eating grass": "guineapig_grass.jpg",
    }

    for query, expected_image_name in expected_image_name_by_query.items():
        # Act
        search_response = client.get(f"/api/search?q={query}&n=1&t=image", headers=auth_headers)

        # Assert
        assert search_response.status_code == 200

        # Fetch the top result through the client and compare it with the source image
        top_result_url = search_response.json()[0]["entry"]
        image_response = client.get(top_result_url)
        actual_image = Image.open(BytesIO(image_response.content))
        expected_image_path = content_config.image.input_directories[0].joinpath(expected_image_name)
        expected_image = Image.open(expected_image_path)

        # Assert
        assert expected_image == actual_image
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True) @pytest.mark.django_db(transaction=True)
def test_notes_search(client, search_config: SearchConfig, sample_org_data, default_user: KhojUser): def test_notes_search(client, search_config: SearchConfig, sample_org_data, default_user: KhojUser):

View file

@ -1,162 +0,0 @@
# Standard Modules
import logging
from pathlib import Path
import pytest
from PIL import Image
from khoj.search_type import image_search
from khoj.utils.config import SearchModels
from khoj.utils.constants import web_directory
from khoj.utils.helpers import resolve_absolute_path
from khoj.utils.rawconfig import ContentConfig, SearchConfig
from khoj.utils.state import content_index, search_models
# Test
# ----------------------------------------------------------------------------------------------------
def test_image_search_setup(content_config: ContentConfig, search_models: SearchModels):
    """Verify image setup with regeneration yields three image names and three embeddings."""
    # Act
    # Regenerate image search embeddings during image setup
    image_content = image_search.setup(
        content_config.image,
        search_models.image_search.image_encoder,
        regenerate=True,
    )

    # Assert
    indexed_name_count = len(image_content.image_names)
    indexed_embedding_count = len(image_content.image_embeddings)
    assert indexed_name_count == 3
    assert indexed_embedding_count == 3
# ----------------------------------------------------------------------------------------------------
def test_image_metadata(content_config: ContentConfig):
    """Verify XMP Description and Subjects Extracted from Image"""
    # Arrange
    image_directory = content_config.image.input_directories[0]
    expected_snippets_by_image_name = {
        "kitten_park.jpg": ["Billi Ka Bacha.", "Cat", "Grass"],
        "horse_dog.jpg": ["Pasture.", "Horse", "Dog"],
        "guineapig_grass.jpg": ["Guinea Pig Eating Celery.", "Rodent", "Whiskers"],
    }

    for image_name, expected_snippets in expected_snippets_by_image_name.items():
        # Act
        image_path = Path(image_directory / image_name)
        actual_metadata = image_search.extract_metadata(image_path)

        # Assert
        # Every expected description and subject must appear in the extracted metadata
        for expected_snippet in expected_snippets:
            assert expected_snippet in actual_metadata
# ----------------------------------------------------------------------------------------------------
@pytest.mark.anyio
async def test_image_search(content_config: ContentConfig, search_config: SearchConfig):
    """Verify each text query returns the expected image from the indexed test images."""
    # Arrange
    search_models.image_search = image_search.initialize_model(search_config.image)
    content_index.image = image_search.setup(
        content_config.image, search_models.image_search.image_encoder, regenerate=False
    )
    output_directory = resolve_absolute_path(web_directory)
    query_expected_image_pairs = [
        ("kitten", "kitten_park.jpg"),
        ("horse and dog in a farm", "horse_dog.jpg"),
        ("A guinea pig eating grass", "guineapig_grass.jpg"),
    ]

    for query, expected_image_name in query_expected_image_pairs:
        # Act
        hits = await image_search.query(
            query, count=1, search_model=search_models.image_search, content=content_index.image
        )

        results = image_search.collate_results(
            hits,
            content_index.image.image_names,
            output_directory=output_directory,
            image_files_url="/static/images",
            count=1,
        )

        actual_image_path = output_directory.joinpath(Path(results[0].entry).name)
        expected_image_path = content_config.image.input_directories[0].joinpath(expected_image_name)

        try:
            # Open both images as context managers so their file handles are closed
            with Image.open(actual_image_path) as actual_image, Image.open(expected_image_path) as expected_image:
                # Assert
                assert expected_image == actual_image
        finally:
            # Cleanup
            # Delete the image file copied to the results directory, even when the assertion fails
            if actual_image_path.exists():
                actual_image_path.unlink()
# ----------------------------------------------------------------------------------------------------
@pytest.mark.anyio
async def test_image_search_query_truncated(content_config: ContentConfig, search_config: SearchConfig, caplog):
    """Verify an over-long text query is truncated before it is run against the image model."""
    # Arrange
    search_models.image_search = image_search.initialize_model(search_config.image)
    content_index.image = image_search.setup(
        content_config.image, search_models.image_search.image_encoder, regenerate=False
    )
    max_words_supported = 10
    query = " ".join(["hello"] * 100)
    truncated_query = " ".join(["hello"] * max_words_supported)

    # Act
    try:
        with caplog.at_level(logging.INFO, logger="khoj.search_type.image_search"):
            await image_search.query(
                query, count=1, search_model=search_models.image_search, content=content_index.image
            )
    # Assert
    except RuntimeError as e:
        if "The size of tensor a (102) must match the size of tensor b (77)" in str(e):
            # pytest.fail is not stripped under `python -O`, unlike `assert False`
            pytest.fail("Query length exceeds max tokens supported by model\n")
        # Any other runtime error is unexpected; surface it instead of swallowing it
        raise

    assert f"Find Images by Text: {truncated_query}" in caplog.text, "Query not truncated"
# ----------------------------------------------------------------------------------------------------
@pytest.mark.anyio
async def test_image_search_by_filepath(content_config: ContentConfig, search_config: SearchConfig, caplog):
    """Verify a `file:` prefixed query triggers search by image and returns that same image."""
    # Arrange
    search_models.image_search = image_search.initialize_model(search_config.image)
    content_index.image = image_search.setup(
        content_config.image, search_models.image_search.image_encoder, regenerate=False
    )
    output_directory = resolve_absolute_path(web_directory)
    image_directory = content_config.image.input_directories[0]

    query = f"file:{image_directory.joinpath('kitten_park.jpg')}"
    expected_image_path = f"{image_directory.joinpath('kitten_park.jpg')}"

    # Act
    with caplog.at_level(logging.INFO, logger="khoj.search_type.image_search"):
        hits = await image_search.query(
            query, count=1, search_model=search_models.image_search, content=content_index.image
        )

    results = image_search.collate_results(
        hits,
        content_index.image.image_names,
        output_directory=output_directory,
        image_files_url="/static/images",
        count=1,
    )

    actual_image_path = output_directory.joinpath(Path(results[0].entry).name)

    try:
        # Open both images as context managers so their file handles are closed
        with Image.open(actual_image_path) as actual_image, Image.open(expected_image_path) as expected_image:
            # Assert
            # Ensure file search triggered instead of query with file path as string
            assert (
                f"Find Images by Image: {resolve_absolute_path(expected_image_path)}" in caplog.text
            ), "File search not triggered"

            # Ensure the correct image is returned
            assert expected_image == actual_image, "Incorrect image returned by file search"
    finally:
        # Cleanup
        # Delete the image file copied to the results directory, even when an assertion fails
        if actual_image_path.exists():
            actual_image_path.unlink()