From f57f9f672da4c94ae5df8102936298aa2d0bf697 Mon Sep 17 00:00:00 2001
From: sabaimran <65192171+sabaimran@users.noreply.github.com>
Date: Thu, 4 Apr 2024 23:40:03 -0700
Subject: [PATCH] Address Notion, Image tech debt in indexing code path (#687)
* Add support for using OAuth 2.0 in the Notion integration
* Add Notion to the admin page
* Remove unnecessary content_index and image search/setup references
* Trigger a background job to start indexing Notion after the user configures it
* Add a log line when a new Notion integration is set up
* Fix references to the configure_content method
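
For reviewers, a minimal sketch of the configure_content call-site change (it mirrors the diffs below; entries are now persisted in the database, so callers no longer thread the in-memory index or search models through):

    # Before
    state.content_index, success = configure_content(
        state.content_index, state.config.content_type, all_files, state.search_models, user=user
    )

    # After
    success = configure_content(all_files, user=user)

The Notion OAuth flow reads NOTION_OAUTH_CLIENT_ID, NOTION_OAUTH_CLIENT_SECRET and NOTION_REDIRECT_URI from the environment; if any of them is unset, the config page falls back to the manual token setup form.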
---
src/khoj/configure.py | 20 +-
src/khoj/database/adapters/__init__.py | 4 +
src/khoj/database/admin.py | 4 +
src/khoj/interface/web/config.html | 17 +-
.../web/content_source_notion_input.html | 9 +-
src/khoj/routers/api.py | 42 +--
src/khoj/routers/api_config.py | 2 +-
src/khoj/routers/indexer.py | 60 +---
src/khoj/routers/notion.py | 89 ++++++
src/khoj/routers/web_client.py | 6 +-
src/khoj/search_type/image_search.py | 272 ------------------
src/khoj/utils/config.py | 6 -
src/khoj/utils/state.py | 3 +-
tests/conftest.py | 17 +-
tests/test_client.py | 31 +-
tests/test_image_search.py | 162 -----------
16 files changed, 145 insertions(+), 599 deletions(-)
create mode 100644 src/khoj/routers/notion.py
delete mode 100644 src/khoj/search_type/image_search.py
delete mode 100644 tests/test_image_search.py
diff --git a/src/khoj/configure.py b/src/khoj/configure.py
index 0adbe889..7768a014 100644
--- a/src/khoj/configure.py
+++ b/src/khoj/configure.py
@@ -34,7 +34,7 @@ from khoj.database.adapters import (
)
from khoj.database.models import ClientApplication, KhojUser, Subscription
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
-from khoj.routers.indexer import configure_content, configure_search, load_content
+from khoj.routers.indexer import configure_content, configure_search
from khoj.routers.twilio import is_twilio_enabled
from khoj.utils import constants, state
from khoj.utils.config import SearchType
@@ -245,16 +245,12 @@ def initialize_content(regenerate: bool, search_type: Optional[SearchType] = Non
if state.search_models:
try:
if init:
- logger.info("📬 Initializing content index...")
- state.content_index = load_content(state.config.content_type, state.content_index, state.search_models)
+ logger.info("📬 No-op...")
else:
logger.info("📬 Updating content index...")
all_files = collect_files(user=user)
- state.content_index, status = configure_content(
- state.content_index,
- state.config.content_type,
+ status = configure_content(
all_files,
- state.search_models,
regenerate,
search_type,
user=user,
@@ -272,6 +268,7 @@ def configure_routes(app):
from khoj.routers.api_chat import api_chat
from khoj.routers.api_config import api_config
from khoj.routers.indexer import indexer
+ from khoj.routers.notion import notion_router
from khoj.routers.web_client import web_client
app.include_router(api, prefix="/api")
@@ -279,6 +276,7 @@ def configure_routes(app):
app.include_router(api_agents, prefix="/api/agents")
app.include_router(api_config, prefix="/api/config")
app.include_router(indexer, prefix="/api/v1/index")
+ app.include_router(notion_router, prefix="/api/notion")
app.include_router(web_client)
if not state.anonymous_mode:
@@ -311,13 +309,9 @@ def update_search_index():
logger.info("📬 Updating content index via Scheduler")
for user in get_all_users():
all_files = collect_files(user=user)
- state.content_index, success = configure_content(
- state.content_index, state.config.content_type, all_files, state.search_models, user=user
- )
+ success = configure_content(all_files, user=user)
all_files = collect_files(user=None)
- state.content_index, success = configure_content(
- state.content_index, state.config.content_type, all_files, state.search_models, user=None
- )
+ success = configure_content(all_files, user=None)
if not success:
raise RuntimeError("Failed to update content index")
logger.info("📪 Content index updated via Scheduler")
diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py
index 95b06a9e..72e2a360 100644
--- a/src/khoj/database/adapters/__init__.py
+++ b/src/khoj/database/adapters/__init__.py
@@ -259,6 +259,10 @@ async def get_user_by_email(email: str) -> KhojUser:
return await KhojUser.objects.filter(email=email).afirst()
+async def aget_user_by_uuid(uuid: str) -> KhojUser:
+ return await KhojUser.objects.filter(uuid=uuid).afirst()
+
+
async def get_user_by_token(token: dict) -> KhojUser:
google_user = await GoogleUser.objects.filter(sub=token.get("sub")).select_related("user").afirst()
if not google_user:
diff --git a/src/khoj/database/admin.py b/src/khoj/database/admin.py
index cc1be7e4..0e440872 100644
--- a/src/khoj/database/admin.py
+++ b/src/khoj/database/admin.py
@@ -11,7 +11,9 @@ from khoj.database.models import (
ClientApplication,
Conversation,
Entry,
+ GithubConfig,
KhojUser,
+ NotionConfig,
OfflineChatProcessorConversationConfig,
OpenAIProcessorConversationConfig,
ReflectiveQuestion,
@@ -52,6 +54,8 @@ admin.site.register(UserSearchModelConfig)
admin.site.register(TextToImageModelConfig)
admin.site.register(ClientApplication)
admin.site.register(Agent)
+admin.site.register(GithubConfig)
+admin.site.register(NotionConfig)
@admin.register(Entry)
diff --git a/src/khoj/interface/web/config.html b/src/khoj/interface/web/config.html
index 1f1d1f1e..2e310ce6 100644
--- a/src/khoj/interface/web/config.html
+++ b/src/khoj/interface/web/config.html
@@ -109,14 +109,23 @@
+ {% if current_model_state.notion %}
- {% if current_model_state.notion %}
Update
- {% else %}
- Setup
- {% endif %}
+ {% elif notion_oauth_url %}
+
+ Connect
+
+
+ {% else %}
+
+ Setup
+
+
+ {% endif %}
+
diff --git a/src/khoj/interface/web/content_source_notion_input.html b/src/khoj/interface/web/content_source_notion_input.html
index 220e8eb7..1303730c 100644
--- a/src/khoj/interface/web/content_source_notion_input.html
+++ b/src/khoj/interface/web/content_source_notion_input.html
@@ -5,11 +5,6 @@
Notion
-
-
-
@@ -43,7 +38,7 @@
const submitButton = document.getElementById("submit");
submitButton.disabled = true;
- submitButton.innerHTML = "Saving...";
+ submitButton.innerHTML = "Syncing...";
// Save Notion config on server
const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1];
diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py
index 328f2548..357a04e5 100644
--- a/src/khoj/routers/api.py
+++ b/src/khoj/routers/api.py
@@ -33,7 +33,7 @@ from khoj.routers.helpers import (
from khoj.search_filter.date_filter import DateFilter
from khoj.search_filter.file_filter import FileFilter
from khoj.search_filter.word_filter import WordFilter
-from khoj.search_type import image_search, text_search
+from khoj.search_type import text_search
from khoj.utils import constants, state
from khoj.utils.config import OfflineChatProcessorModel
from khoj.utils.helpers import ConversationCommand, timer
@@ -145,41 +145,17 @@ async def execute_search(
)
]
- elif (t == SearchType.Image) and state.content_index.image and state.search_models.image_search:
- # query images
- search_futures += [
- executor.submit(
- image_search.query,
- user_query,
- results_count,
- state.search_models.image_search,
- state.content_index.image,
- )
- ]
-
# Query across each requested content types in parallel
with timer("Query took", logger):
for search_future in concurrent.futures.as_completed(search_futures):
- if t == SearchType.Image and state.content_index.image:
- hits = await search_future.result()
- output_directory = constants.web_directory / "images"
- # Collate results
- results += image_search.collate_results(
- hits,
- image_names=state.content_index.image.image_names,
- output_directory=output_directory,
- image_files_url="/static/images",
- count=results_count,
- )
- else:
- hits = await search_future.result()
- # Collate results
- results += text_search.collate_results(hits, dedupe=dedupe)
+ hits = await search_future.result()
+ # Collate results
+ results += text_search.collate_results(hits, dedupe=dedupe)
- # Sort results across all content types and take top results
- results = text_search.rerank_and_sort_results(
- results, query=defiltered_query, rank_results=r, search_model_name=search_model.name
- )[:results_count]
+ # Sort results across all content types and take top results
+ results = text_search.rerank_and_sort_results(
+ results, query=defiltered_query, rank_results=r, search_model_name=search_model.name
+ )[:results_count]
# Cache results
if user:
@@ -214,8 +190,6 @@ def update(
components = []
if state.search_models:
components.append("Search models")
- if state.content_index:
- components.append("Content index")
components_msg = ", ".join(components)
logger.info(f"📪 {components_msg} updated via API")
diff --git a/src/khoj/routers/api_config.py b/src/khoj/routers/api_config.py
index d23f131b..e72ba3b8 100644
--- a/src/khoj/routers/api_config.py
+++ b/src/khoj/routers/api_config.py
@@ -38,7 +38,7 @@ logger = logging.getLogger(__name__)
def map_config_to_object(content_source: str):
if content_source == DbEntry.EntrySource.GITHUB:
return GithubConfig
- if content_source == DbEntry.EntrySource.GITHUB:
+ if content_source == DbEntry.EntrySource.NOTION:
return NotionConfig
if content_source == DbEntry.EntrySource.COMPUTER:
return "Computer"
diff --git a/src/khoj/routers/indexer.py b/src/khoj/routers/indexer.py
index cf8bc015..1bca6a25 100644
--- a/src/khoj/routers/indexer.py
+++ b/src/khoj/routers/indexer.py
@@ -14,9 +14,9 @@ from khoj.processor.content.org_mode.org_to_entries import OrgToEntries
from khoj.processor.content.pdf.pdf_to_entries import PdfToEntries
from khoj.processor.content.plaintext.plaintext_to_entries import PlaintextToEntries
from khoj.routers.helpers import ApiIndexedDataLimiter, update_telemetry_state
-from khoj.search_type import image_search, text_search
+from khoj.search_type import text_search
from khoj.utils import constants, state
-from khoj.utils.config import ContentIndex, SearchModels
+from khoj.utils.config import SearchModels
from khoj.utils.helpers import LRU, get_file_type
from khoj.utils.rawconfig import ContentConfig, FullConfig, SearchConfig
from khoj.utils.yaml import save_config_to_file_updated_state
@@ -105,13 +105,10 @@ async def update(
# Extract required fields from config
loop = asyncio.get_event_loop()
- state.content_index, success = await loop.run_in_executor(
+ success = await loop.run_in_executor(
None,
configure_content,
- state.content_index,
- state.config.content_type,
indexer_input.model_dump(),
- state.search_models,
force,
t,
False,
@@ -159,23 +156,17 @@ def configure_search(search_models: SearchModels, search_config: Optional[Search
if search_config and search_config.image:
logger.info("🔍 🌄 Setting up image search model")
- search_models.image_search = image_search.initialize_model(search_config.image)
return search_models
def configure_content(
- content_index: Optional[ContentIndex],
- content_config: Optional[ContentConfig],
files: Optional[dict[str, dict[str, str]]],
- search_models: SearchModels,
regenerate: bool = False,
t: Optional[state.SearchType] = state.SearchType.All,
full_corpus: bool = True,
user: KhojUser = None,
-) -> tuple[Optional[ContentIndex], bool]:
- content_index = ContentIndex()
-
+) -> bool:
success = True
if t == None:
t = state.SearchType.All
@@ -185,7 +176,7 @@ def configure_content(
if t is not None and not t.value in [type.value for type in state.SearchType]:
logger.warning(f"🚨 Invalid search type: {t}")
- return None, False
+ return False
search_type = t.value if t else None
@@ -193,7 +184,7 @@ def configure_content(
if files is None:
logger.warning(f"🚨 No files to process for {search_type} search.")
- return None, True
+ return True
try:
# Initialize Org Notes Search
@@ -266,24 +257,6 @@ def configure_content(
logger.error(f"🚨 Failed to setup plaintext: {e}", exc_info=True)
success = False
- try:
- # Initialize Image Search
- if (
- (search_type == state.SearchType.All.value or search_type == state.SearchType.Image.value)
- and content_config
- and content_config.image
- and search_models.image_search
- ):
- logger.info("🌄 Setting up search for images")
- # Extract Entries, Generate Image Embeddings
- content_index.image = image_search.setup(
- content_config.image, search_models.image_search.image_encoder, regenerate=regenerate
- )
-
- except Exception as e:
- logger.error(f"🚨 Failed to setup images: {e}", exc_info=True)
- success = False
-
try:
if no_documents:
github_config = GithubConfig.objects.filter(user=user).prefetch_related("githubrepoconfig").first()
@@ -330,23 +303,4 @@ def configure_content(
if user:
state.query_cache[user.uuid] = LRU()
- return content_index, success
-
-
-def load_content(
- content_config: Optional[ContentConfig],
- content_index: Optional[ContentIndex],
- search_models: SearchModels,
-):
- if content_config is None:
- logger.debug("🚨 No Content configuration available.")
- return None
- if content_index is None:
- content_index = ContentIndex()
-
- if content_config.image:
- logger.info("🌄 Loading images")
- content_index.image = image_search.setup(
- content_config.image, search_models.image_search.image_encoder, regenerate=False
- )
- return content_index
+ return success
diff --git a/src/khoj/routers/notion.py b/src/khoj/routers/notion.py
new file mode 100644
index 00000000..d8fd0c26
--- /dev/null
+++ b/src/khoj/routers/notion.py
@@ -0,0 +1,89 @@
+import asyncio
+import base64
+import json
+import logging
+import os
+from concurrent.futures import ThreadPoolExecutor
+
+import requests
+from fastapi import APIRouter, BackgroundTasks, Request, Response
+from starlette.responses import RedirectResponse
+
+from khoj.database.adapters import aget_user_by_uuid
+from khoj.database.models import KhojUser, NotionConfig
+from khoj.routers.indexer import configure_content
+from khoj.utils.state import SearchType
+
+NOTION_OAUTH_CLIENT_ID = os.getenv("NOTION_OAUTH_CLIENT_ID")
+NOTION_OAUTH_CLIENT_SECRET = os.getenv("NOTION_OAUTH_CLIENT_SECRET")
+NOTION_REDIRECT_URI = os.getenv("NOTION_REDIRECT_URI")
+
+notion_router = APIRouter()
+
+executor = ThreadPoolExecutor()
+
+logger = logging.getLogger(__name__)
+
+
+def get_notion_auth_url(user: KhojUser):
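+ """Return the Notion OAuth authorize URL, or None if the OAuth env vars are unset; user.uuid is sent as the state param."""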
+ if not NOTION_OAUTH_CLIENT_ID or not NOTION_OAUTH_CLIENT_SECRET or not NOTION_REDIRECT_URI:
+ return None
+ return f"https://api.notion.com/v1/oauth/authorize?client_id={NOTION_OAUTH_CLIENT_ID}&redirect_uri={NOTION_REDIRECT_URI}&response_type=code&state={user.uuid}"
+
+
+async def run_in_executor(func, *args):
+ loop = asyncio.get_event_loop()
+ return await loop.run_in_executor(executor, func, *args)
+
+
+@notion_router.get("/auth/callback")
+async def notion_auth_callback(request: Request, background_tasks: BackgroundTasks):
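+ """Handle Notion's OAuth redirect: exchange the code for an access token, store it, and schedule indexing."""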
+ code = request.query_params.get("code")
+ state = request.query_params.get("state")
+ if not code or not state:
+ return Response("Missing code or state", status_code=400)
+
+ user: KhojUser = await aget_user_by_uuid(state)
+
+ if not user:
+ raise Exception("User not found")
+
+ # Clear any existing Notion credentials for this user before saving the new token
+ await NotionConfig.objects.filter(user=user).adelete()
+
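+ # Notion's token endpoint authenticates the app with HTTP Basic auth over base64("client_id:client_secret")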
+ bearer_token = f"{NOTION_OAUTH_CLIENT_ID}:{NOTION_OAUTH_CLIENT_SECRET}"
+ base64_encoded_token = base64.b64encode(bearer_token.encode()).decode()
+
+ headers = {
+ "Accept": "application/json",
+ "Content-Type": "application/json",
+ "Authorization": f"Basic {base64_encoded_token}",
+ }
+
+ data = {
+ "grant_type": "authorization_code",
+ "code": code,
+ "redirect_uri": NOTION_REDIRECT_URI,
+ }
+
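+ # Exchange the authorization code for an access token via Notion's OAuth token endpoint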
+ response = requests.post("https://api.notion.com/v1/oauth/token", data=json.dumps(data), headers=headers)
+
+ final_response = response.json()
+
+ access_token = final_response.get("access_token")
+ await NotionConfig.objects.acreate(token=access_token, user=user)
+
+ owner = final_response.get("owner")
+ workspace_id = final_response.get("workspace_id")
+ workspace_name = final_response.get("workspace_name")
+ bot_id = final_response.get("bot_id")
+
+ logger.info(
+ f"Notion integration. Owner: {owner}, Workspace ID: {workspace_id}, Workspace Name: {workspace_name}, Bot ID: {bot_id}"
+ )
+
+ notion_redirect = str(request.app.url_path_for("notion_config_page"))
+
+ # Trigger an async job to configure_content. Let it run without blocking the response.
+ background_tasks.add_task(run_in_executor, configure_content, {}, False, SearchType.Notion, True, user)
+
+ return RedirectResponse(notion_redirect)
diff --git a/src/khoj/routers/web_client.py b/src/khoj/routers/web_client.py
index cb03cb89..4df0e3e1 100644
--- a/src/khoj/routers/web_client.py
+++ b/src/khoj/routers/web_client.py
@@ -19,6 +19,7 @@ from khoj.database.adapters import (
get_user_subscription_state,
)
from khoj.database.models import KhojUser
+from khoj.routers.notion import get_notion_auth_url
from khoj.routers.twilio import is_twilio_enabled
from khoj.utils import constants, state
from khoj.utils.rawconfig import (
@@ -244,6 +245,8 @@ def config_page(request: Request):
current_search_model_option = adapters.get_user_search_model_or_default(user)
+ notion_oauth_url = get_notion_auth_url(user)
+
return templates.TemplateResponse(
"config.html",
context={
@@ -267,6 +270,7 @@ def config_page(request: Request):
"phone_number": user.phone_number,
"is_phone_number_verified": user.verified_phone_number,
"khoj_version": state.khoj_version,
+ "notion_oauth_url": notion_oauth_url,
},
)
@@ -324,7 +328,7 @@ def notion_config_page(request: Request):
token=current_notion_config.token if current_notion_config else "",
)
- current_config = json.loads(current_config.json())
+ current_config = json.loads(current_config.model_dump_json())
return templates.TemplateResponse(
"content_source_notion_input.html",
diff --git a/src/khoj/search_type/image_search.py b/src/khoj/search_type/image_search.py
deleted file mode 100644
index 76a72538..00000000
--- a/src/khoj/search_type/image_search.py
+++ /dev/null
@@ -1,272 +0,0 @@
-import copy
-import glob
-import logging
-import math
-import pathlib
-import shutil
-from typing import List
-
-import torch
-from PIL import Image
-from sentence_transformers import SentenceTransformer, util
-from tqdm import trange
-
-from khoj.utils.config import ImageContent, ImageSearchModel
-from khoj.utils.helpers import (
- get_absolute_path,
- get_from_dict,
- load_model,
- resolve_absolute_path,
- timer,
-)
-from khoj.utils.models import BaseEncoder
-from khoj.utils.rawconfig import ImageContentConfig, ImageSearchConfig, SearchResponse
-
-# Create Logger
-logger = logging.getLogger(__name__)
-
-
-def initialize_model(search_config: ImageSearchConfig):
- # Convert model directory to absolute path
- search_config.model_directory = resolve_absolute_path(search_config.model_directory)
-
- # Create model directory if it doesn't exist
- search_config.model_directory.parent.mkdir(parents=True, exist_ok=True)
-
- # Load the CLIP model
- encoder = load_model(
- model_dir=search_config.model_directory,
- model_name=search_config.encoder,
- model_type=search_config.encoder_type or SentenceTransformer,
- )
-
- return ImageSearchModel(encoder)
-
-
-def extract_entries(image_directories):
- image_names = []
- for image_directory in image_directories:
- image_directory = resolve_absolute_path(image_directory, strict=True)
- image_names.extend(list(image_directory.glob("*.jpg")))
- image_names.extend(list(image_directory.glob("*.jpeg")))
-
- if logger.level >= logging.DEBUG:
- image_directory_names = ", ".join([str(image_directory) for image_directory in image_directories])
- logger.debug(f"Found {len(image_names)} images in {image_directory_names}")
- return sorted(image_names)
-
-
-def compute_embeddings(image_names, encoder, embeddings_file, batch_size=50, use_xmp_metadata=False, regenerate=False):
- "Compute (and Save) Embeddings or Load Pre-Computed Embeddings"
-
- image_embeddings = compute_image_embeddings(image_names, encoder, embeddings_file, batch_size, regenerate)
- image_metadata_embeddings = compute_metadata_embeddings(
- image_names, encoder, embeddings_file, batch_size, use_xmp_metadata, regenerate
- )
-
- return image_embeddings, image_metadata_embeddings
-
-
-def compute_image_embeddings(image_names, encoder, embeddings_file, batch_size=50, regenerate=False):
- # Load pre-computed image embeddings from file if exists
- if resolve_absolute_path(embeddings_file).exists() and not regenerate:
- image_embeddings = torch.load(embeddings_file)
- logger.debug(f"Loaded {len(image_embeddings)} image embeddings from {embeddings_file}")
- # Else compute the image embeddings from scratch, which can take a while
- else:
- image_embeddings = []
- for index in trange(0, len(image_names), batch_size):
- images = []
- for image_name in image_names[index : index + batch_size]:
- image = Image.open(image_name)
- # Resize images to max width of 640px for faster processing
- image.thumbnail((640, image.height))
- images += [image]
- image_embeddings += encoder.encode(images, convert_to_tensor=True, batch_size=min(len(images), batch_size))
-
- # Create directory for embeddings file, if it doesn't exist
- embeddings_file.parent.mkdir(parents=True, exist_ok=True)
-
- # Save computed image embeddings to file
- torch.save(image_embeddings, embeddings_file)
- logger.info(f"📩 Saved computed image embeddings to {embeddings_file}")
-
- return image_embeddings
-
-
-def compute_metadata_embeddings(
- image_names, encoder, embeddings_file, batch_size=50, use_xmp_metadata=False, regenerate=False, verbose=0
-):
- image_metadata_embeddings = None
-
- # Load pre-computed image metadata embedding file if exists
- if use_xmp_metadata and resolve_absolute_path(f"{embeddings_file}_metadata").exists() and not regenerate:
- image_metadata_embeddings = torch.load(f"{embeddings_file}_metadata")
- logger.debug(f"Loaded image metadata embeddings from {embeddings_file}_metadata")
-
- # Else compute the image metadata embeddings from scratch, which can take a while
- if use_xmp_metadata and image_metadata_embeddings is None:
- image_metadata_embeddings = []
- for index in trange(0, len(image_names), batch_size):
- image_metadata = [
- extract_metadata(image_name, verbose) for image_name in image_names[index : index + batch_size]
- ]
- try:
- image_metadata_embeddings += encoder.encode(
- image_metadata, convert_to_tensor=True, batch_size=min(len(image_metadata), batch_size)
- )
- except RuntimeError as e:
- logger.error(
- f"Error encoding metadata for images starting from\n\tindex: {index},\n\timages: {image_names[index:index+batch_size]}\nException: {e}"
- )
- continue
- torch.save(image_metadata_embeddings, f"{embeddings_file}_metadata")
- logger.info(f"📩 Saved computed image metadata embeddings to {embeddings_file}_metadata")
-
- return image_metadata_embeddings
-
-
-def extract_metadata(image_name):
- image_xmp_metadata = Image.open(image_name).getxmp()
- image_description = get_from_dict(
- image_xmp_metadata, "xmpmeta", "RDF", "Description", "description", "Alt", "li", "text"
- )
- image_subjects = get_from_dict(image_xmp_metadata, "xmpmeta", "RDF", "Description", "subject", "Bag", "li")
- image_metadata_subjects = set([subject.split(":")[1] for subject in image_subjects if ":" in subject])
-
- image_processed_metadata = image_description
- if len(image_metadata_subjects) > 0:
- image_processed_metadata += ". " + ", ".join(image_metadata_subjects)
-
- logger.debug(f"{image_name}:\t{image_processed_metadata}")
-
- return image_processed_metadata
-
-
-async def query(
- raw_query, count, search_model: ImageSearchModel, content: ImageContent, score_threshold: float = math.inf
-):
- # Set query to image content if query is of form file:/path/to/file.png
- if raw_query.startswith("file:") and pathlib.Path(raw_query[5:]).is_file():
- query_imagepath = resolve_absolute_path(pathlib.Path(raw_query[5:]), strict=True)
- query = copy.deepcopy(Image.open(query_imagepath))
- query.thumbnail((640, query.height)) # scale down image for faster processing
- logger.info(f"🔎 Find Images by Image: {query_imagepath}")
- else:
- # Truncate words in query to stay below max_tokens supported by ML model
- max_words = 20
- query = " ".join(raw_query.split()[:max_words])
- logger.info(f"🔎 Find Images by Text: {query}")
-
- # Now we encode the query (which can either be an image or a text string)
- with timer("Query Encode Time", logger):
- query_embedding = search_model.image_encoder.encode([query], convert_to_tensor=True, show_progress_bar=False)
-
- # Compute top_k ranked images based on cosine-similarity b/w query and all image embeddings.
- with timer("Search Time", logger):
- image_hits = {
- # Map scores to distance metric by multiplying by -1
- result["corpus_id"]: {"image_score": -1 * result["score"], "score": -1 * result["score"]}
- for result in util.semantic_search(query_embedding, content.image_embeddings, top_k=count)[0]
- }
-
- # Compute top_k ranked images based on cosine-similarity b/w query and all image metadata embeddings.
- if content.image_metadata_embeddings:
- with timer("Metadata Search Time", logger):
- metadata_hits = {
- result["corpus_id"]: result["score"]
- for result in util.semantic_search(query_embedding, content.image_metadata_embeddings, top_k=count)[0]
- }
-
- # Sum metadata, image scores of the highest ranked images
- for corpus_id, score in metadata_hits.items():
- scaling_factor = 0.33
- if "corpus_id" in image_hits:
- image_hits[corpus_id].update(
- {
- "metadata_score": score,
- "score": image_hits[corpus_id].get("score", 0) + scaling_factor * score,
- }
- )
- else:
- image_hits[corpus_id] = {"metadata_score": score, "score": scaling_factor * score}
-
- # Reformat results in original form from sentence transformer semantic_search()
- hits = [
- {
- "corpus_id": corpus_id,
- "score": scores["score"],
- "image_score": scores.get("image_score", 0),
- "metadata_score": scores.get("metadata_score", 0),
- }
- for corpus_id, scores in image_hits.items()
- ]
-
- # Filter results by score threshold
- hits = [hit for hit in hits if hit["image_score"] <= score_threshold]
-
- # Sort the images based on their combined metadata, image scores
- return sorted(hits, key=lambda hit: hit["score"], reverse=True)
-
-
-def collate_results(hits, image_names, output_directory, image_files_url, count=5) -> List[SearchResponse]:
- results: List[SearchResponse] = []
-
- for index, hit in enumerate(hits[:count]):
- source_path = image_names[hit["corpus_id"]]
-
- target_image_name = f"{index}{source_path.suffix}"
- target_path = resolve_absolute_path(f"{output_directory}/{target_image_name}")
-
- # Create output directory, if it doesn't exist
- if not target_path.parent.exists():
- target_path.parent.mkdir(exist_ok=True)
-
- # Copy the image to the output directory
- shutil.copy(source_path, target_path)
-
- # Add the image metadata to the results
- results += [
- SearchResponse.model_validate(
- {
- "entry": f"{image_files_url}/{target_image_name}",
- "score": f"{hit['score']:.9f}",
- "additional": {
- "image_score": f"{hit['image_score']:.9f}",
- "metadata_score": f"{hit['metadata_score']:.9f}",
- },
- "corpus_id": str(hit["corpus_id"]),
- }
- )
- ]
-
- return results
-
-
-def setup(config: ImageContentConfig, encoder: BaseEncoder, regenerate: bool) -> ImageContent:
- # Extract Entries
- absolute_image_files, filtered_image_files = set(), set()
- if config.input_directories:
- image_directories = [resolve_absolute_path(directory, strict=True) for directory in config.input_directories]
- absolute_image_files = set(extract_entries(image_directories))
- if config.input_filter:
- filtered_image_files = {
- filtered_file
- for input_filter in config.input_filter
- for filtered_file in glob.glob(get_absolute_path(input_filter))
- }
-
- all_image_files = sorted(list(absolute_image_files | filtered_image_files))
-
- # Compute or Load Embeddings
- embeddings_file = resolve_absolute_path(config.embeddings_file)
- image_embeddings, image_metadata_embeddings = compute_embeddings(
- all_image_files,
- encoder,
- embeddings_file,
- batch_size=config.batch_size,
- regenerate=regenerate,
- use_xmp_metadata=config.use_xmp_metadata,
- )
-
- return ImageContent(all_image_files, image_embeddings, image_metadata_embeddings)
diff --git a/src/khoj/utils/config.py b/src/khoj/utils/config.py
index d04e6698..3f95030f 100644
--- a/src/khoj/utils/config.py
+++ b/src/khoj/utils/config.py
@@ -58,15 +58,9 @@ class ImageSearchModel:
image_encoder: BaseEncoder
-@dataclass
-class ContentIndex:
- image: Optional[ImageContent] = None
-
-
@dataclass
class SearchModels:
text_search: Optional[TextSearchModel] = None
- image_search: Optional[ImageSearchModel] = None
@dataclass
diff --git a/src/khoj/utils/state.py b/src/khoj/utils/state.py
index 7edc3483..b431225d 100644
--- a/src/khoj/utils/state.py
+++ b/src/khoj/utils/state.py
@@ -9,7 +9,7 @@ from whisper import Whisper
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
from khoj.utils import config as utils_config
-from khoj.utils.config import ContentIndex, OfflineChatProcessorModel, SearchModels
+from khoj.utils.config import OfflineChatProcessorModel, SearchModels
from khoj.utils.helpers import LRU, get_device
from khoj.utils.rawconfig import FullConfig
@@ -18,7 +18,6 @@ config = FullConfig()
search_models = SearchModels()
embeddings_model: Dict[str, EmbeddingsModel] = None
cross_encoder_model: Dict[str, CrossEncoderModel] = None
-content_index = ContentIndex()
openai_client: OpenAI = None
offline_chat_processor_config: OfflineChatProcessorModel = None
whisper_model: Whisper = None
diff --git a/tests/conftest.py b/tests/conftest.py
index 8533ba4f..1c84c7d1 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -25,7 +25,7 @@ from khoj.processor.content.org_mode.org_to_entries import OrgToEntries
from khoj.processor.content.plaintext.plaintext_to_entries import PlaintextToEntries
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
from khoj.routers.indexer import configure_content
-from khoj.search_type import image_search, text_search
+from khoj.search_type import text_search
from khoj.utils import fs_syncer, state
from khoj.utils.config import SearchModels
from khoj.utils.constants import web_directory
@@ -207,7 +207,6 @@ def openai_agent():
@pytest.fixture(scope="session")
def search_models(search_config: SearchConfig):
search_models = SearchModels()
- search_models.image_search = image_search.initialize_model(search_config.image)
return search_models
@@ -232,8 +231,6 @@ def content_config(tmp_path_factory, search_models: SearchModels, default_user:
use_xmp_metadata=False,
)
- image_search.setup(content_config.image, search_models.image_search.image_encoder, regenerate=False)
-
LocalOrgConfig.objects.create(
input_files=None,
input_filter=["tests/data/org/*.org"],
@@ -305,9 +302,7 @@ def chat_client_builder(search_config, user, index_content=True, require_auth=Fa
# Index Markdown Content for Search
all_files = fs_syncer.collect_files(user=user)
- state.content_index, _ = configure_content(
- state.content_index, state.config.content_type, all_files, state.search_models, user=user
- )
+ success = configure_content(all_files, user=user)
# Initialize Processor from Config
if os.getenv("OPENAI_API_KEY"):
@@ -349,16 +344,12 @@ def client(
state.cross_encoder_model["default"] = CrossEncoderModel()
# These lines help us Mock the Search models for these search types
- state.search_models.image_search = image_search.initialize_model(search_config.image)
text_search.setup(
OrgToEntries,
get_sample_data("org"),
regenerate=False,
user=api_user.user,
)
- state.content_index.image = image_search.setup(
- content_config.image, state.search_models.image_search, regenerate=False
- )
text_search.setup(
PlaintextToEntries,
get_sample_data("plaintext"),
@@ -388,9 +379,7 @@ def client_offline_chat(search_config: SearchConfig, default_user2: KhojUser):
)
all_files = fs_syncer.collect_files(user=default_user2)
- configure_content(
- state.content_index, state.config.content_type, all_files, state.search_models, user=default_user2
- )
+ configure_content(all_files, user=default_user2)
# Initialize Processor from Config
OfflineChatProcessorConversationConfigFactory(enabled=True)
diff --git a/tests/test_client.py b/tests/test_client.py
index abbd1fec..d1f07a4b 100644
--- a/tests/test_client.py
+++ b/tests/test_client.py
@@ -12,10 +12,9 @@ from khoj.configure import configure_routes, configure_search_types
from khoj.database.adapters import EntryAdapters
from khoj.database.models import KhojApiUser, KhojUser
from khoj.processor.content.org_mode.org_to_entries import OrgToEntries
-from khoj.search_type import image_search, text_search
+from khoj.search_type import text_search
from khoj.utils import state
from khoj.utils.rawconfig import ContentConfig, SearchConfig
-from khoj.utils.state import config, content_index, search_models
# Test
@@ -298,34 +297,6 @@ def test_get_configured_types_with_no_content_config(fastapi_app: FastAPI):
assert response.json() == ["all"]
-# ----------------------------------------------------------------------------------------------------
-@pytest.mark.django_db(transaction=True)
-def test_image_search(client, content_config: ContentConfig, search_config: SearchConfig):
- # Arrange
- headers = {"Authorization": "Bearer kk-secret"}
- search_models.image_search = image_search.initialize_model(search_config.image)
- content_index.image = image_search.setup(
- content_config.image, search_models.image_search.image_encoder, regenerate=False
- )
- query_expected_image_pairs = [
- ("kitten", "kitten_park.jpg"),
- ("a horse and dog on a leash", "horse_dog.jpg"),
- ("A guinea pig eating grass", "guineapig_grass.jpg"),
- ]
-
- for query, expected_image_name in query_expected_image_pairs:
- # Act
- response = client.get(f"/api/search?q={query}&n=1&t=image", headers=headers)
-
- # Assert
- assert response.status_code == 200
- actual_image = Image.open(BytesIO(client.get(response.json()[0]["entry"]).content))
- expected_image = Image.open(content_config.image.input_directories[0].joinpath(expected_image_name))
-
- # Assert
- assert expected_image == actual_image
-
-
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
def test_notes_search(client, search_config: SearchConfig, sample_org_data, default_user: KhojUser):
diff --git a/tests/test_image_search.py b/tests/test_image_search.py
deleted file mode 100644
index 5fe9ac7a..00000000
--- a/tests/test_image_search.py
+++ /dev/null
@@ -1,162 +0,0 @@
-# Standard Modules
-import logging
-from pathlib import Path
-
-import pytest
-from PIL import Image
-
-from khoj.search_type import image_search
-from khoj.utils.config import SearchModels
-from khoj.utils.constants import web_directory
-from khoj.utils.helpers import resolve_absolute_path
-from khoj.utils.rawconfig import ContentConfig, SearchConfig
-from khoj.utils.state import content_index, search_models
-
-
-# Test
-# ----------------------------------------------------------------------------------------------------
-def test_image_search_setup(content_config: ContentConfig, search_models: SearchModels):
- # Act
- # Regenerate image search embeddings during image setup
- image_search_model = image_search.setup(
- content_config.image, search_models.image_search.image_encoder, regenerate=True
- )
-
- # Assert
- assert len(image_search_model.image_names) == 3
- assert len(image_search_model.image_embeddings) == 3
-
-
-# ----------------------------------------------------------------------------------------------------
-def test_image_metadata(content_config: ContentConfig):
- "Verify XMP Description and Subjects Extracted from Image"
- # Arrange
- expected_metadata_image_name_pairs = [
- (["Billi Ka Bacha.", "Cat", "Grass"], "kitten_park.jpg"),
- (["Pasture.", "Horse", "Dog"], "horse_dog.jpg"),
- (["Guinea Pig Eating Celery.", "Rodent", "Whiskers"], "guineapig_grass.jpg"),
- ]
-
- test_image_paths = [
- Path(content_config.image.input_directories[0] / image_name[1])
- for image_name in expected_metadata_image_name_pairs
- ]
-
- for expected_metadata, test_image_path in zip(expected_metadata_image_name_pairs, test_image_paths):
- # Act
- actual_metadata = image_search.extract_metadata(test_image_path)
-
- # Assert
- for expected_snippet in expected_metadata[0]:
- assert expected_snippet in actual_metadata
-
-
-# ----------------------------------------------------------------------------------------------------
-@pytest.mark.anyio
-async def test_image_search(content_config: ContentConfig, search_config: SearchConfig):
- # Arrange
- search_models.image_search = image_search.initialize_model(search_config.image)
- content_index.image = image_search.setup(
- content_config.image, search_models.image_search.image_encoder, regenerate=False
- )
- output_directory = resolve_absolute_path(web_directory)
- query_expected_image_pairs = [
- ("kitten", "kitten_park.jpg"),
- ("horse and dog in a farm", "horse_dog.jpg"),
- ("A guinea pig eating grass", "guineapig_grass.jpg"),
- ]
-
- # Act
- for query, expected_image_name in query_expected_image_pairs:
- hits = await image_search.query(
- query, count=1, search_model=search_models.image_search, content=content_index.image
- )
-
- results = image_search.collate_results(
- hits,
- content_index.image.image_names,
- output_directory=output_directory,
- image_files_url="/static/images",
- count=1,
- )
-
- actual_image_path = output_directory.joinpath(Path(results[0].entry).name)
- actual_image = Image.open(actual_image_path)
- expected_image = Image.open(content_config.image.input_directories[0].joinpath(expected_image_name))
-
- # Assert
- assert expected_image == actual_image
-
- # Cleanup
- # Delete the image files copied to results directory
- actual_image_path.unlink()
-
-
-# ----------------------------------------------------------------------------------------------------
-@pytest.mark.anyio
-async def test_image_search_query_truncated(content_config: ContentConfig, search_config: SearchConfig, caplog):
- # Arrange
- search_models.image_search = image_search.initialize_model(search_config.image)
- content_index.image = image_search.setup(
- content_config.image, search_models.image_search.image_encoder, regenerate=False
- )
- max_words_supported = 10
- query = " ".join(["hello"] * 100)
- truncated_query = " ".join(["hello"] * max_words_supported)
-
- # Act
- try:
- with caplog.at_level(logging.INFO, logger="khoj.search_type.image_search"):
- await image_search.query(
- query, count=1, search_model=search_models.image_search, content=content_index.image
- )
- # Assert
- except RuntimeError as e:
- if "The size of tensor a (102) must match the size of tensor b (77)" in str(e):
- assert False, f"Query length exceeds max tokens supported by model\n"
- assert f"Find Images by Text: {truncated_query}" in caplog.text, "Query not truncated"
-
-
-# ----------------------------------------------------------------------------------------------------
-@pytest.mark.anyio
-async def test_image_search_by_filepath(content_config: ContentConfig, search_config: SearchConfig, caplog):
- # Arrange
- search_models.image_search = image_search.initialize_model(search_config.image)
- content_index.image = image_search.setup(
- content_config.image, search_models.image_search.image_encoder, regenerate=False
- )
- output_directory = resolve_absolute_path(web_directory)
- image_directory = content_config.image.input_directories[0]
-
- query = f"file:{image_directory.joinpath('kitten_park.jpg')}"
- expected_image_path = f"{image_directory.joinpath('kitten_park.jpg')}"
-
- # Act
- with caplog.at_level(logging.INFO, logger="khoj.search_type.image_search"):
- hits = await image_search.query(
- query, count=1, search_model=search_models.image_search, content=content_index.image
- )
-
- results = image_search.collate_results(
- hits,
- content_index.image.image_names,
- output_directory=output_directory,
- image_files_url="/static/images",
- count=1,
- )
-
- actual_image_path = output_directory.joinpath(Path(results[0].entry).name)
- actual_image = Image.open(actual_image_path)
- expected_image = Image.open(expected_image_path)
-
- # Assert
- # Ensure file search triggered instead of query with file path as string
- assert (
- f"Find Images by Image: {resolve_absolute_path(expected_image_path)}" in caplog.text
- ), "File search not triggered"
- # Ensure the correct image is returned
- assert expected_image == actual_image, "Incorrect image returned by file search"
-
- # Cleanup
- # Delete the image files copied to results directory
- actual_image_path.unlink()