From f57f9f672da4c94ae5df8102936298aa2d0bf697 Mon Sep 17 00:00:00 2001 From: sabaimran <65192171+sabaimran@users.noreply.github.com> Date: Thu, 4 Apr 2024 23:40:03 -0700 Subject: [PATCH] Address Notion, Image tech debt in indexing code path (#687) * Add support for using OAuth2.0 in the Notion integration * Add notion to the admin page * Remove unnecessary content_index and image search/setup references * Trigger background job to start indexing Notion after user configures it * Add a log line when a new Notion integration is setup * Fix references to the configure_content methods --- src/khoj/configure.py | 20 +- src/khoj/database/adapters/__init__.py | 4 + src/khoj/database/admin.py | 4 + src/khoj/interface/web/config.html | 17 +- .../web/content_source_notion_input.html | 9 +- src/khoj/routers/api.py | 42 +-- src/khoj/routers/api_config.py | 2 +- src/khoj/routers/indexer.py | 60 +--- src/khoj/routers/notion.py | 89 ++++++ src/khoj/routers/web_client.py | 6 +- src/khoj/search_type/image_search.py | 272 ------------------ src/khoj/utils/config.py | 6 - src/khoj/utils/state.py | 3 +- tests/conftest.py | 17 +- tests/test_client.py | 31 +- tests/test_image_search.py | 162 ----------- 16 files changed, 145 insertions(+), 599 deletions(-) create mode 100644 src/khoj/routers/notion.py delete mode 100644 src/khoj/search_type/image_search.py delete mode 100644 tests/test_image_search.py diff --git a/src/khoj/configure.py b/src/khoj/configure.py index 0adbe889..7768a014 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -34,7 +34,7 @@ from khoj.database.adapters import ( ) from khoj.database.models import ClientApplication, KhojUser, Subscription from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel -from khoj.routers.indexer import configure_content, configure_search, load_content +from khoj.routers.indexer import configure_content, configure_search from khoj.routers.twilio import is_twilio_enabled from khoj.utils import constants, state from khoj.utils.config import SearchType @@ -245,16 +245,12 @@ def initialize_content(regenerate: bool, search_type: Optional[SearchType] = Non if state.search_models: try: if init: - logger.info("📬 Initializing content index...") - state.content_index = load_content(state.config.content_type, state.content_index, state.search_models) + logger.info("📬 No-op...") else: logger.info("📬 Updating content index...") all_files = collect_files(user=user) - state.content_index, status = configure_content( - state.content_index, - state.config.content_type, + status = configure_content( all_files, - state.search_models, regenerate, search_type, user=user, @@ -272,6 +268,7 @@ def configure_routes(app): from khoj.routers.api_chat import api_chat from khoj.routers.api_config import api_config from khoj.routers.indexer import indexer + from khoj.routers.notion import notion_router from khoj.routers.web_client import web_client app.include_router(api, prefix="/api") @@ -279,6 +276,7 @@ def configure_routes(app): app.include_router(api_agents, prefix="/api/agents") app.include_router(api_config, prefix="/api/config") app.include_router(indexer, prefix="/api/v1/index") + app.include_router(notion_router, prefix="/api/notion") app.include_router(web_client) if not state.anonymous_mode: @@ -311,13 +309,9 @@ def update_search_index(): logger.info("📬 Updating content index via Scheduler") for user in get_all_users(): all_files = collect_files(user=user) - state.content_index, success = configure_content( - state.content_index, state.config.content_type, all_files, state.search_models, user=user - ) + success = configure_content(all_files, user=user) all_files = collect_files(user=None) - state.content_index, success = configure_content( - state.content_index, state.config.content_type, all_files, state.search_models, user=None - ) + success = configure_content(all_files, user=None) if not success: raise RuntimeError("Failed to update content index") logger.info("📪 Content index updated via Scheduler") diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index 95b06a9e..72e2a360 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -259,6 +259,10 @@ async def get_user_by_email(email: str) -> KhojUser: return await KhojUser.objects.filter(email=email).afirst() +async def aget_user_by_uuid(uuid: str) -> KhojUser: + return await KhojUser.objects.filter(uuid=uuid).afirst() + + async def get_user_by_token(token: dict) -> KhojUser: google_user = await GoogleUser.objects.filter(sub=token.get("sub")).select_related("user").afirst() if not google_user: diff --git a/src/khoj/database/admin.py b/src/khoj/database/admin.py index cc1be7e4..0e440872 100644 --- a/src/khoj/database/admin.py +++ b/src/khoj/database/admin.py @@ -11,7 +11,9 @@ from khoj.database.models import ( ClientApplication, Conversation, Entry, + GithubConfig, KhojUser, + NotionConfig, OfflineChatProcessorConversationConfig, OpenAIProcessorConversationConfig, ReflectiveQuestion, @@ -52,6 +54,8 @@ admin.site.register(UserSearchModelConfig) admin.site.register(TextToImageModelConfig) admin.site.register(ClientApplication) admin.site.register(Agent) +admin.site.register(GithubConfig) +admin.site.register(NotionConfig) @admin.register(Entry) diff --git a/src/khoj/interface/web/config.html b/src/khoj/interface/web/config.html index 1f1d1f1e..2e310ce6 100644 --- a/src/khoj/interface/web/config.html +++ b/src/khoj/interface/web/config.html @@ -109,14 +109,23 @@

Sync your Notion pages

+ {% if current_model_state.notion %} - {% if current_model_state.notion %} Update - {% else %} - Setup - {% endif %} + {% elif notion_oauth_url %} + + Connect + + + {% else %} + + Setup + + + {% endif %} +
diff --git a/src/khoj/interface/web/content_source_notion_input.html b/src/khoj/interface/web/content_source_notion_input.html index 220e8eb7..1303730c 100644 --- a/src/khoj/interface/web/content_source_notion_input.html +++ b/src/khoj/interface/web/content_source_notion_input.html @@ -5,11 +5,6 @@

Notion Notion -
- ⓘ Help -
-

-
@@ -22,7 +17,7 @@
- +
@@ -43,7 +38,7 @@ const submitButton = document.getElementById("submit"); submitButton.disabled = true; - submitButton.innerHTML = "Saving..."; + submitButton.innerHTML = "Syncing..."; // Save Notion config on server const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1]; diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index 328f2548..357a04e5 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -33,7 +33,7 @@ from khoj.routers.helpers import ( from khoj.search_filter.date_filter import DateFilter from khoj.search_filter.file_filter import FileFilter from khoj.search_filter.word_filter import WordFilter -from khoj.search_type import image_search, text_search +from khoj.search_type import text_search from khoj.utils import constants, state from khoj.utils.config import OfflineChatProcessorModel from khoj.utils.helpers import ConversationCommand, timer @@ -145,41 +145,17 @@ async def execute_search( ) ] - elif (t == SearchType.Image) and state.content_index.image and state.search_models.image_search: - # query images - search_futures += [ - executor.submit( - image_search.query, - user_query, - results_count, - state.search_models.image_search, - state.content_index.image, - ) - ] - # Query across each requested content types in parallel with timer("Query took", logger): for search_future in concurrent.futures.as_completed(search_futures): - if t == SearchType.Image and state.content_index.image: - hits = await search_future.result() - output_directory = constants.web_directory / "images" - # Collate results - results += image_search.collate_results( - hits, - image_names=state.content_index.image.image_names, - output_directory=output_directory, - image_files_url="/static/images", - count=results_count, - ) - else: - hits = await search_future.result() - # Collate results - results += text_search.collate_results(hits, dedupe=dedupe) + hits = await search_future.result() + # Collate results + results += text_search.collate_results(hits, dedupe=dedupe) - # Sort results across all content types and take top results - results = text_search.rerank_and_sort_results( - results, query=defiltered_query, rank_results=r, search_model_name=search_model.name - )[:results_count] + # Sort results across all content types and take top results + results = text_search.rerank_and_sort_results( + results, query=defiltered_query, rank_results=r, search_model_name=search_model.name + )[:results_count] # Cache results if user: @@ -214,8 +190,6 @@ def update( components = [] if state.search_models: components.append("Search models") - if state.content_index: - components.append("Content index") components_msg = ", ".join(components) logger.info(f"📪 {components_msg} updated via API") diff --git a/src/khoj/routers/api_config.py b/src/khoj/routers/api_config.py index d23f131b..e72ba3b8 100644 --- a/src/khoj/routers/api_config.py +++ b/src/khoj/routers/api_config.py @@ -38,7 +38,7 @@ logger = logging.getLogger(__name__) def map_config_to_object(content_source: str): if content_source == DbEntry.EntrySource.GITHUB: return GithubConfig - if content_source == DbEntry.EntrySource.GITHUB: + if content_source == DbEntry.EntrySource.NOTION: return NotionConfig if content_source == DbEntry.EntrySource.COMPUTER: return "Computer" diff --git a/src/khoj/routers/indexer.py b/src/khoj/routers/indexer.py index cf8bc015..1bca6a25 100644 --- a/src/khoj/routers/indexer.py +++ b/src/khoj/routers/indexer.py @@ -14,9 +14,9 @@ from khoj.processor.content.org_mode.org_to_entries import OrgToEntries from khoj.processor.content.pdf.pdf_to_entries import PdfToEntries from khoj.processor.content.plaintext.plaintext_to_entries import PlaintextToEntries from khoj.routers.helpers import ApiIndexedDataLimiter, update_telemetry_state -from khoj.search_type import image_search, text_search +from khoj.search_type import text_search from khoj.utils import constants, state -from khoj.utils.config import ContentIndex, SearchModels +from khoj.utils.config import SearchModels from khoj.utils.helpers import LRU, get_file_type from khoj.utils.rawconfig import ContentConfig, FullConfig, SearchConfig from khoj.utils.yaml import save_config_to_file_updated_state @@ -105,13 +105,10 @@ async def update( # Extract required fields from config loop = asyncio.get_event_loop() - state.content_index, success = await loop.run_in_executor( + success = await loop.run_in_executor( None, configure_content, - state.content_index, - state.config.content_type, indexer_input.model_dump(), - state.search_models, force, t, False, @@ -159,23 +156,17 @@ def configure_search(search_models: SearchModels, search_config: Optional[Search if search_config and search_config.image: logger.info("🔍 🌄 Setting up image search model") - search_models.image_search = image_search.initialize_model(search_config.image) return search_models def configure_content( - content_index: Optional[ContentIndex], - content_config: Optional[ContentConfig], files: Optional[dict[str, dict[str, str]]], - search_models: SearchModels, regenerate: bool = False, t: Optional[state.SearchType] = state.SearchType.All, full_corpus: bool = True, user: KhojUser = None, -) -> tuple[Optional[ContentIndex], bool]: - content_index = ContentIndex() - +) -> bool: success = True if t == None: t = state.SearchType.All @@ -185,7 +176,7 @@ def configure_content( if t is not None and not t.value in [type.value for type in state.SearchType]: logger.warning(f"🚨 Invalid search type: {t}") - return None, False + return False search_type = t.value if t else None @@ -193,7 +184,7 @@ def configure_content( if files is None: logger.warning(f"🚨 No files to process for {search_type} search.") - return None, True + return True try: # Initialize Org Notes Search @@ -266,24 +257,6 @@ def configure_content( logger.error(f"🚨 Failed to setup plaintext: {e}", exc_info=True) success = False - try: - # Initialize Image Search - if ( - (search_type == state.SearchType.All.value or search_type == state.SearchType.Image.value) - and content_config - and content_config.image - and search_models.image_search - ): - logger.info("🌄 Setting up search for images") - # Extract Entries, Generate Image Embeddings - content_index.image = image_search.setup( - content_config.image, search_models.image_search.image_encoder, regenerate=regenerate - ) - - except Exception as e: - logger.error(f"🚨 Failed to setup images: {e}", exc_info=True) - success = False - try: if no_documents: github_config = GithubConfig.objects.filter(user=user).prefetch_related("githubrepoconfig").first() @@ -330,23 +303,4 @@ def configure_content( if user: state.query_cache[user.uuid] = LRU() - return content_index, success - - -def load_content( - content_config: Optional[ContentConfig], - content_index: Optional[ContentIndex], - search_models: SearchModels, -): - if content_config is None: - logger.debug("🚨 No Content configuration available.") - return None - if content_index is None: - content_index = ContentIndex() - - if content_config.image: - logger.info("🌄 Loading images") - content_index.image = image_search.setup( - content_config.image, search_models.image_search.image_encoder, regenerate=False - ) - return content_index + return success diff --git a/src/khoj/routers/notion.py b/src/khoj/routers/notion.py new file mode 100644 index 00000000..d8fd0c26 --- /dev/null +++ b/src/khoj/routers/notion.py @@ -0,0 +1,89 @@ +import asyncio +import base64 +import json +import logging +import os +from concurrent.futures import ThreadPoolExecutor + +import requests +from fastapi import APIRouter, BackgroundTasks, Request, Response +from starlette.responses import RedirectResponse + +from khoj.database.adapters import aget_user_by_uuid +from khoj.database.models import KhojUser, NotionConfig +from khoj.routers.indexer import configure_content +from khoj.utils.state import SearchType + +NOTION_OAUTH_CLIENT_ID = os.getenv("NOTION_OAUTH_CLIENT_ID") +NOTION_OAUTH_CLIENT_SECRET = os.getenv("NOTION_OAUTH_CLIENT_SECRET") +NOTION_REDIRECT_URI = os.getenv("NOTION_REDIRECT_URI") + +notion_router = APIRouter() + +executor = ThreadPoolExecutor() + +logger = logging.getLogger(__name__) + + +def get_notion_auth_url(user: KhojUser): + if not NOTION_OAUTH_CLIENT_ID or not NOTION_OAUTH_CLIENT_SECRET or not NOTION_REDIRECT_URI: + return None + return f"https://api.notion.com/v1/oauth/authorize?client_id={NOTION_OAUTH_CLIENT_ID}&redirect_uri={NOTION_REDIRECT_URI}&response_type=code&state={user.uuid}" + + +async def run_in_executor(func, *args): + loop = asyncio.get_event_loop() + return await loop.run_in_executor(executor, func, *args) + + +@notion_router.get("/auth/callback") +async def notion_auth_callback(request: Request, background_tasks: BackgroundTasks): + code = request.query_params.get("code") + state = request.query_params.get("state") + if not code or not state: + return Response("Missing code or state", status_code=400) + + user: KhojUser = await aget_user_by_uuid(state) + + NotionConfig.objects.filter(user=user).adelete() + + if not user: + raise Exception("User not found") + + bearer_token = f"{NOTION_OAUTH_CLIENT_ID}:{NOTION_OAUTH_CLIENT_SECRET}" + base64_encoded_token = base64.b64encode(bearer_token.encode()).decode() + + headers = { + "Accept": "application/json", + "Content-Type": "application/json", + "Authorization": f"Basic {base64_encoded_token}", + } + + data = { + "grant_type": "authorization_code", + "code": code, + "redirect_uri": NOTION_REDIRECT_URI, + } + + response = requests.post("https://api.notion.com/v1/oauth/token", data=json.dumps(data), headers=headers) + + final_response = response.json() + + access_token = final_response.get("access_token") + NotionConfig.objects.acreate(token=access_token, user=user) + + owner = final_response.get("owner") + workspace_id = final_response.get("workspace_id") + workspace_name = final_response.get("workspace_name") + bot_id = final_response.get("bot_id") + + logger.info( + f"Notion integration. Owner: {owner}, Workspace ID: {workspace_id}, Workspace Name: {workspace_name}, Bot ID: {bot_id}" + ) + + notion_redirect = str(request.app.url_path_for("notion_config_page")) + + # Trigger an async job to configure_content. Let it run without blocking the response. + background_tasks.add_task(run_in_executor, configure_content, {}, False, SearchType.Notion, True, user) + + return RedirectResponse(notion_redirect) diff --git a/src/khoj/routers/web_client.py b/src/khoj/routers/web_client.py index cb03cb89..4df0e3e1 100644 --- a/src/khoj/routers/web_client.py +++ b/src/khoj/routers/web_client.py @@ -19,6 +19,7 @@ from khoj.database.adapters import ( get_user_subscription_state, ) from khoj.database.models import KhojUser +from khoj.routers.notion import get_notion_auth_url from khoj.routers.twilio import is_twilio_enabled from khoj.utils import constants, state from khoj.utils.rawconfig import ( @@ -244,6 +245,8 @@ def config_page(request: Request): current_search_model_option = adapters.get_user_search_model_or_default(user) + notion_oauth_url = get_notion_auth_url(user) + return templates.TemplateResponse( "config.html", context={ @@ -267,6 +270,7 @@ def config_page(request: Request): "phone_number": user.phone_number, "is_phone_number_verified": user.verified_phone_number, "khoj_version": state.khoj_version, + "notion_oauth_url": notion_oauth_url, }, ) @@ -324,7 +328,7 @@ def notion_config_page(request: Request): token=current_notion_config.token if current_notion_config else "", ) - current_config = json.loads(current_config.json()) + current_config = json.loads(current_config.model_dump_json()) return templates.TemplateResponse( "content_source_notion_input.html", diff --git a/src/khoj/search_type/image_search.py b/src/khoj/search_type/image_search.py deleted file mode 100644 index 76a72538..00000000 --- a/src/khoj/search_type/image_search.py +++ /dev/null @@ -1,272 +0,0 @@ -import copy -import glob -import logging -import math -import pathlib -import shutil -from typing import List - -import torch -from PIL import Image -from sentence_transformers import SentenceTransformer, util -from tqdm import trange - -from khoj.utils.config import ImageContent, ImageSearchModel -from khoj.utils.helpers import ( - get_absolute_path, - get_from_dict, - load_model, - resolve_absolute_path, - timer, -) -from khoj.utils.models import BaseEncoder -from khoj.utils.rawconfig import ImageContentConfig, ImageSearchConfig, SearchResponse - -# Create Logger -logger = logging.getLogger(__name__) - - -def initialize_model(search_config: ImageSearchConfig): - # Convert model directory to absolute path - search_config.model_directory = resolve_absolute_path(search_config.model_directory) - - # Create model directory if it doesn't exist - search_config.model_directory.parent.mkdir(parents=True, exist_ok=True) - - # Load the CLIP model - encoder = load_model( - model_dir=search_config.model_directory, - model_name=search_config.encoder, - model_type=search_config.encoder_type or SentenceTransformer, - ) - - return ImageSearchModel(encoder) - - -def extract_entries(image_directories): - image_names = [] - for image_directory in image_directories: - image_directory = resolve_absolute_path(image_directory, strict=True) - image_names.extend(list(image_directory.glob("*.jpg"))) - image_names.extend(list(image_directory.glob("*.jpeg"))) - - if logger.level >= logging.DEBUG: - image_directory_names = ", ".join([str(image_directory) for image_directory in image_directories]) - logger.debug(f"Found {len(image_names)} images in {image_directory_names}") - return sorted(image_names) - - -def compute_embeddings(image_names, encoder, embeddings_file, batch_size=50, use_xmp_metadata=False, regenerate=False): - "Compute (and Save) Embeddings or Load Pre-Computed Embeddings" - - image_embeddings = compute_image_embeddings(image_names, encoder, embeddings_file, batch_size, regenerate) - image_metadata_embeddings = compute_metadata_embeddings( - image_names, encoder, embeddings_file, batch_size, use_xmp_metadata, regenerate - ) - - return image_embeddings, image_metadata_embeddings - - -def compute_image_embeddings(image_names, encoder, embeddings_file, batch_size=50, regenerate=False): - # Load pre-computed image embeddings from file if exists - if resolve_absolute_path(embeddings_file).exists() and not regenerate: - image_embeddings = torch.load(embeddings_file) - logger.debug(f"Loaded {len(image_embeddings)} image embeddings from {embeddings_file}") - # Else compute the image embeddings from scratch, which can take a while - else: - image_embeddings = [] - for index in trange(0, len(image_names), batch_size): - images = [] - for image_name in image_names[index : index + batch_size]: - image = Image.open(image_name) - # Resize images to max width of 640px for faster processing - image.thumbnail((640, image.height)) - images += [image] - image_embeddings += encoder.encode(images, convert_to_tensor=True, batch_size=min(len(images), batch_size)) - - # Create directory for embeddings file, if it doesn't exist - embeddings_file.parent.mkdir(parents=True, exist_ok=True) - - # Save computed image embeddings to file - torch.save(image_embeddings, embeddings_file) - logger.info(f"📩 Saved computed image embeddings to {embeddings_file}") - - return image_embeddings - - -def compute_metadata_embeddings( - image_names, encoder, embeddings_file, batch_size=50, use_xmp_metadata=False, regenerate=False, verbose=0 -): - image_metadata_embeddings = None - - # Load pre-computed image metadata embedding file if exists - if use_xmp_metadata and resolve_absolute_path(f"{embeddings_file}_metadata").exists() and not regenerate: - image_metadata_embeddings = torch.load(f"{embeddings_file}_metadata") - logger.debug(f"Loaded image metadata embeddings from {embeddings_file}_metadata") - - # Else compute the image metadata embeddings from scratch, which can take a while - if use_xmp_metadata and image_metadata_embeddings is None: - image_metadata_embeddings = [] - for index in trange(0, len(image_names), batch_size): - image_metadata = [ - extract_metadata(image_name, verbose) for image_name in image_names[index : index + batch_size] - ] - try: - image_metadata_embeddings += encoder.encode( - image_metadata, convert_to_tensor=True, batch_size=min(len(image_metadata), batch_size) - ) - except RuntimeError as e: - logger.error( - f"Error encoding metadata for images starting from\n\tindex: {index},\n\timages: {image_names[index:index+batch_size]}\nException: {e}" - ) - continue - torch.save(image_metadata_embeddings, f"{embeddings_file}_metadata") - logger.info(f"📩 Saved computed image metadata embeddings to {embeddings_file}_metadata") - - return image_metadata_embeddings - - -def extract_metadata(image_name): - image_xmp_metadata = Image.open(image_name).getxmp() - image_description = get_from_dict( - image_xmp_metadata, "xmpmeta", "RDF", "Description", "description", "Alt", "li", "text" - ) - image_subjects = get_from_dict(image_xmp_metadata, "xmpmeta", "RDF", "Description", "subject", "Bag", "li") - image_metadata_subjects = set([subject.split(":")[1] for subject in image_subjects if ":" in subject]) - - image_processed_metadata = image_description - if len(image_metadata_subjects) > 0: - image_processed_metadata += ". " + ", ".join(image_metadata_subjects) - - logger.debug(f"{image_name}:\t{image_processed_metadata}") - - return image_processed_metadata - - -async def query( - raw_query, count, search_model: ImageSearchModel, content: ImageContent, score_threshold: float = math.inf -): - # Set query to image content if query is of form file:/path/to/file.png - if raw_query.startswith("file:") and pathlib.Path(raw_query[5:]).is_file(): - query_imagepath = resolve_absolute_path(pathlib.Path(raw_query[5:]), strict=True) - query = copy.deepcopy(Image.open(query_imagepath)) - query.thumbnail((640, query.height)) # scale down image for faster processing - logger.info(f"🔎 Find Images by Image: {query_imagepath}") - else: - # Truncate words in query to stay below max_tokens supported by ML model - max_words = 20 - query = " ".join(raw_query.split()[:max_words]) - logger.info(f"🔎 Find Images by Text: {query}") - - # Now we encode the query (which can either be an image or a text string) - with timer("Query Encode Time", logger): - query_embedding = search_model.image_encoder.encode([query], convert_to_tensor=True, show_progress_bar=False) - - # Compute top_k ranked images based on cosine-similarity b/w query and all image embeddings. - with timer("Search Time", logger): - image_hits = { - # Map scores to distance metric by multiplying by -1 - result["corpus_id"]: {"image_score": -1 * result["score"], "score": -1 * result["score"]} - for result in util.semantic_search(query_embedding, content.image_embeddings, top_k=count)[0] - } - - # Compute top_k ranked images based on cosine-similarity b/w query and all image metadata embeddings. - if content.image_metadata_embeddings: - with timer("Metadata Search Time", logger): - metadata_hits = { - result["corpus_id"]: result["score"] - for result in util.semantic_search(query_embedding, content.image_metadata_embeddings, top_k=count)[0] - } - - # Sum metadata, image scores of the highest ranked images - for corpus_id, score in metadata_hits.items(): - scaling_factor = 0.33 - if "corpus_id" in image_hits: - image_hits[corpus_id].update( - { - "metadata_score": score, - "score": image_hits[corpus_id].get("score", 0) + scaling_factor * score, - } - ) - else: - image_hits[corpus_id] = {"metadata_score": score, "score": scaling_factor * score} - - # Reformat results in original form from sentence transformer semantic_search() - hits = [ - { - "corpus_id": corpus_id, - "score": scores["score"], - "image_score": scores.get("image_score", 0), - "metadata_score": scores.get("metadata_score", 0), - } - for corpus_id, scores in image_hits.items() - ] - - # Filter results by score threshold - hits = [hit for hit in hits if hit["image_score"] <= score_threshold] - - # Sort the images based on their combined metadata, image scores - return sorted(hits, key=lambda hit: hit["score"], reverse=True) - - -def collate_results(hits, image_names, output_directory, image_files_url, count=5) -> List[SearchResponse]: - results: List[SearchResponse] = [] - - for index, hit in enumerate(hits[:count]): - source_path = image_names[hit["corpus_id"]] - - target_image_name = f"{index}{source_path.suffix}" - target_path = resolve_absolute_path(f"{output_directory}/{target_image_name}") - - # Create output directory, if it doesn't exist - if not target_path.parent.exists(): - target_path.parent.mkdir(exist_ok=True) - - # Copy the image to the output directory - shutil.copy(source_path, target_path) - - # Add the image metadata to the results - results += [ - SearchResponse.model_validate( - { - "entry": f"{image_files_url}/{target_image_name}", - "score": f"{hit['score']:.9f}", - "additional": { - "image_score": f"{hit['image_score']:.9f}", - "metadata_score": f"{hit['metadata_score']:.9f}", - }, - "corpus_id": str(hit["corpus_id"]), - } - ) - ] - - return results - - -def setup(config: ImageContentConfig, encoder: BaseEncoder, regenerate: bool) -> ImageContent: - # Extract Entries - absolute_image_files, filtered_image_files = set(), set() - if config.input_directories: - image_directories = [resolve_absolute_path(directory, strict=True) for directory in config.input_directories] - absolute_image_files = set(extract_entries(image_directories)) - if config.input_filter: - filtered_image_files = { - filtered_file - for input_filter in config.input_filter - for filtered_file in glob.glob(get_absolute_path(input_filter)) - } - - all_image_files = sorted(list(absolute_image_files | filtered_image_files)) - - # Compute or Load Embeddings - embeddings_file = resolve_absolute_path(config.embeddings_file) - image_embeddings, image_metadata_embeddings = compute_embeddings( - all_image_files, - encoder, - embeddings_file, - batch_size=config.batch_size, - regenerate=regenerate, - use_xmp_metadata=config.use_xmp_metadata, - ) - - return ImageContent(all_image_files, image_embeddings, image_metadata_embeddings) diff --git a/src/khoj/utils/config.py b/src/khoj/utils/config.py index d04e6698..3f95030f 100644 --- a/src/khoj/utils/config.py +++ b/src/khoj/utils/config.py @@ -58,15 +58,9 @@ class ImageSearchModel: image_encoder: BaseEncoder -@dataclass -class ContentIndex: - image: Optional[ImageContent] = None - - @dataclass class SearchModels: text_search: Optional[TextSearchModel] = None - image_search: Optional[ImageSearchModel] = None @dataclass diff --git a/src/khoj/utils/state.py b/src/khoj/utils/state.py index 7edc3483..b431225d 100644 --- a/src/khoj/utils/state.py +++ b/src/khoj/utils/state.py @@ -9,7 +9,7 @@ from whisper import Whisper from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel from khoj.utils import config as utils_config -from khoj.utils.config import ContentIndex, OfflineChatProcessorModel, SearchModels +from khoj.utils.config import OfflineChatProcessorModel, SearchModels from khoj.utils.helpers import LRU, get_device from khoj.utils.rawconfig import FullConfig @@ -18,7 +18,6 @@ config = FullConfig() search_models = SearchModels() embeddings_model: Dict[str, EmbeddingsModel] = None cross_encoder_model: Dict[str, CrossEncoderModel] = None -content_index = ContentIndex() openai_client: OpenAI = None offline_chat_processor_config: OfflineChatProcessorModel = None whisper_model: Whisper = None diff --git a/tests/conftest.py b/tests/conftest.py index 8533ba4f..1c84c7d1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -25,7 +25,7 @@ from khoj.processor.content.org_mode.org_to_entries import OrgToEntries from khoj.processor.content.plaintext.plaintext_to_entries import PlaintextToEntries from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel from khoj.routers.indexer import configure_content -from khoj.search_type import image_search, text_search +from khoj.search_type import text_search from khoj.utils import fs_syncer, state from khoj.utils.config import SearchModels from khoj.utils.constants import web_directory @@ -207,7 +207,6 @@ def openai_agent(): @pytest.fixture(scope="session") def search_models(search_config: SearchConfig): search_models = SearchModels() - search_models.image_search = image_search.initialize_model(search_config.image) return search_models @@ -232,8 +231,6 @@ def content_config(tmp_path_factory, search_models: SearchModels, default_user: use_xmp_metadata=False, ) - image_search.setup(content_config.image, search_models.image_search.image_encoder, regenerate=False) - LocalOrgConfig.objects.create( input_files=None, input_filter=["tests/data/org/*.org"], @@ -305,9 +302,7 @@ def chat_client_builder(search_config, user, index_content=True, require_auth=Fa # Index Markdown Content for Search all_files = fs_syncer.collect_files(user=user) - state.content_index, _ = configure_content( - state.content_index, state.config.content_type, all_files, state.search_models, user=user - ) + success = configure_content(all_files, user=user) # Initialize Processor from Config if os.getenv("OPENAI_API_KEY"): @@ -349,16 +344,12 @@ def client( state.cross_encoder_model["default"] = CrossEncoderModel() # These lines help us Mock the Search models for these search types - state.search_models.image_search = image_search.initialize_model(search_config.image) text_search.setup( OrgToEntries, get_sample_data("org"), regenerate=False, user=api_user.user, ) - state.content_index.image = image_search.setup( - content_config.image, state.search_models.image_search, regenerate=False - ) text_search.setup( PlaintextToEntries, get_sample_data("plaintext"), @@ -388,9 +379,7 @@ def client_offline_chat(search_config: SearchConfig, default_user2: KhojUser): ) all_files = fs_syncer.collect_files(user=default_user2) - configure_content( - state.content_index, state.config.content_type, all_files, state.search_models, user=default_user2 - ) + configure_content(all_files, user=default_user2) # Initialize Processor from Config OfflineChatProcessorConversationConfigFactory(enabled=True) diff --git a/tests/test_client.py b/tests/test_client.py index abbd1fec..d1f07a4b 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -12,10 +12,9 @@ from khoj.configure import configure_routes, configure_search_types from khoj.database.adapters import EntryAdapters from khoj.database.models import KhojApiUser, KhojUser from khoj.processor.content.org_mode.org_to_entries import OrgToEntries -from khoj.search_type import image_search, text_search +from khoj.search_type import text_search from khoj.utils import state from khoj.utils.rawconfig import ContentConfig, SearchConfig -from khoj.utils.state import config, content_index, search_models # Test @@ -298,34 +297,6 @@ def test_get_configured_types_with_no_content_config(fastapi_app: FastAPI): assert response.json() == ["all"] -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.django_db(transaction=True) -def test_image_search(client, content_config: ContentConfig, search_config: SearchConfig): - # Arrange - headers = {"Authorization": "Bearer kk-secret"} - search_models.image_search = image_search.initialize_model(search_config.image) - content_index.image = image_search.setup( - content_config.image, search_models.image_search.image_encoder, regenerate=False - ) - query_expected_image_pairs = [ - ("kitten", "kitten_park.jpg"), - ("a horse and dog on a leash", "horse_dog.jpg"), - ("A guinea pig eating grass", "guineapig_grass.jpg"), - ] - - for query, expected_image_name in query_expected_image_pairs: - # Act - response = client.get(f"/api/search?q={query}&n=1&t=image", headers=headers) - - # Assert - assert response.status_code == 200 - actual_image = Image.open(BytesIO(client.get(response.json()[0]["entry"]).content)) - expected_image = Image.open(content_config.image.input_directories[0].joinpath(expected_image_name)) - - # Assert - assert expected_image == actual_image - - # ---------------------------------------------------------------------------------------------------- @pytest.mark.django_db(transaction=True) def test_notes_search(client, search_config: SearchConfig, sample_org_data, default_user: KhojUser): diff --git a/tests/test_image_search.py b/tests/test_image_search.py deleted file mode 100644 index 5fe9ac7a..00000000 --- a/tests/test_image_search.py +++ /dev/null @@ -1,162 +0,0 @@ -# Standard Modules -import logging -from pathlib import Path - -import pytest -from PIL import Image - -from khoj.search_type import image_search -from khoj.utils.config import SearchModels -from khoj.utils.constants import web_directory -from khoj.utils.helpers import resolve_absolute_path -from khoj.utils.rawconfig import ContentConfig, SearchConfig -from khoj.utils.state import content_index, search_models - - -# Test -# ---------------------------------------------------------------------------------------------------- -def test_image_search_setup(content_config: ContentConfig, search_models: SearchModels): - # Act - # Regenerate image search embeddings during image setup - image_search_model = image_search.setup( - content_config.image, search_models.image_search.image_encoder, regenerate=True - ) - - # Assert - assert len(image_search_model.image_names) == 3 - assert len(image_search_model.image_embeddings) == 3 - - -# ---------------------------------------------------------------------------------------------------- -def test_image_metadata(content_config: ContentConfig): - "Verify XMP Description and Subjects Extracted from Image" - # Arrange - expected_metadata_image_name_pairs = [ - (["Billi Ka Bacha.", "Cat", "Grass"], "kitten_park.jpg"), - (["Pasture.", "Horse", "Dog"], "horse_dog.jpg"), - (["Guinea Pig Eating Celery.", "Rodent", "Whiskers"], "guineapig_grass.jpg"), - ] - - test_image_paths = [ - Path(content_config.image.input_directories[0] / image_name[1]) - for image_name in expected_metadata_image_name_pairs - ] - - for expected_metadata, test_image_path in zip(expected_metadata_image_name_pairs, test_image_paths): - # Act - actual_metadata = image_search.extract_metadata(test_image_path) - - # Assert - for expected_snippet in expected_metadata[0]: - assert expected_snippet in actual_metadata - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.anyio -async def test_image_search(content_config: ContentConfig, search_config: SearchConfig): - # Arrange - search_models.image_search = image_search.initialize_model(search_config.image) - content_index.image = image_search.setup( - content_config.image, search_models.image_search.image_encoder, regenerate=False - ) - output_directory = resolve_absolute_path(web_directory) - query_expected_image_pairs = [ - ("kitten", "kitten_park.jpg"), - ("horse and dog in a farm", "horse_dog.jpg"), - ("A guinea pig eating grass", "guineapig_grass.jpg"), - ] - - # Act - for query, expected_image_name in query_expected_image_pairs: - hits = await image_search.query( - query, count=1, search_model=search_models.image_search, content=content_index.image - ) - - results = image_search.collate_results( - hits, - content_index.image.image_names, - output_directory=output_directory, - image_files_url="/static/images", - count=1, - ) - - actual_image_path = output_directory.joinpath(Path(results[0].entry).name) - actual_image = Image.open(actual_image_path) - expected_image = Image.open(content_config.image.input_directories[0].joinpath(expected_image_name)) - - # Assert - assert expected_image == actual_image - - # Cleanup - # Delete the image files copied to results directory - actual_image_path.unlink() - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.anyio -async def test_image_search_query_truncated(content_config: ContentConfig, search_config: SearchConfig, caplog): - # Arrange - search_models.image_search = image_search.initialize_model(search_config.image) - content_index.image = image_search.setup( - content_config.image, search_models.image_search.image_encoder, regenerate=False - ) - max_words_supported = 10 - query = " ".join(["hello"] * 100) - truncated_query = " ".join(["hello"] * max_words_supported) - - # Act - try: - with caplog.at_level(logging.INFO, logger="khoj.search_type.image_search"): - await image_search.query( - query, count=1, search_model=search_models.image_search, content=content_index.image - ) - # Assert - except RuntimeError as e: - if "The size of tensor a (102) must match the size of tensor b (77)" in str(e): - assert False, f"Query length exceeds max tokens supported by model\n" - assert f"Find Images by Text: {truncated_query}" in caplog.text, "Query not truncated" - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.anyio -async def test_image_search_by_filepath(content_config: ContentConfig, search_config: SearchConfig, caplog): - # Arrange - search_models.image_search = image_search.initialize_model(search_config.image) - content_index.image = image_search.setup( - content_config.image, search_models.image_search.image_encoder, regenerate=False - ) - output_directory = resolve_absolute_path(web_directory) - image_directory = content_config.image.input_directories[0] - - query = f"file:{image_directory.joinpath('kitten_park.jpg')}" - expected_image_path = f"{image_directory.joinpath('kitten_park.jpg')}" - - # Act - with caplog.at_level(logging.INFO, logger="khoj.search_type.image_search"): - hits = await image_search.query( - query, count=1, search_model=search_models.image_search, content=content_index.image - ) - - results = image_search.collate_results( - hits, - content_index.image.image_names, - output_directory=output_directory, - image_files_url="/static/images", - count=1, - ) - - actual_image_path = output_directory.joinpath(Path(results[0].entry).name) - actual_image = Image.open(actual_image_path) - expected_image = Image.open(expected_image_path) - - # Assert - # Ensure file search triggered instead of query with file path as string - assert ( - f"Find Images by Image: {resolve_absolute_path(expected_image_path)}" in caplog.text - ), "File search not triggered" - # Ensure the correct image is returned - assert expected_image == actual_image, "Incorrect image returned by file search" - - # Cleanup - # Delete the image files copied to results directory - actual_image_path.unlink()