Mirror of https://github.com/khoj-ai/khoj.git, synced 2024-11-27 17:35:07 +01:00

Address Notion, Image tech debt in indexing code path (#687)

* Add support for using OAuth2.0 in the Notion integration
* Add Notion to the admin page
* Remove unnecessary content_index and image search/setup references
* Trigger a background job to start indexing Notion after the user configures it
* Add a log line when a new Notion integration is set up
* Fix references to the configure_content methods

Parent 69dee75c34, commit f57f9f672d
16 changed files with 145 additions and 599 deletions
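The signature change driving most of these diffs is easiest to read from the caller's side. Here is a minimal sketch of the new indexing call, modeled on the scheduler path in src/khoj/configure.py shown below; it assumes an existing KhojUser instance named `user`:

# Minimal sketch of the simplified indexing call introduced by this commit.
# configure_content no longer threads a ContentIndex and SearchModels through
# every caller; it takes the collected files and returns a success boolean.
from khoj.routers.indexer import configure_content
from khoj.utils import fs_syncer

all_files = fs_syncer.collect_files(user=user)  # assumes an existing KhojUser `user`
success = configure_content(all_files, user=user)
if not success:
    raise RuntimeError("Failed to update content index")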
src/khoj/configure.py

@@ -34,7 +34,7 @@ from khoj.database.adapters import (
 )
 from khoj.database.models import ClientApplication, KhojUser, Subscription
 from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
-from khoj.routers.indexer import configure_content, configure_search, load_content
+from khoj.routers.indexer import configure_content, configure_search
 from khoj.routers.twilio import is_twilio_enabled
 from khoj.utils import constants, state
 from khoj.utils.config import SearchType

@@ -245,16 +245,12 @@ def initialize_content(regenerate: bool, search_type: Optional[SearchType] = Non
     if state.search_models:
         try:
             if init:
-                logger.info("📬 Initializing content index...")
-                state.content_index = load_content(state.config.content_type, state.content_index, state.search_models)
+                logger.info("📬 No-op...")
             else:
                 logger.info("📬 Updating content index...")
                 all_files = collect_files(user=user)
-                state.content_index, status = configure_content(
-                    state.content_index,
-                    state.config.content_type,
+                status = configure_content(
                     all_files,
-                    state.search_models,
                     regenerate,
                     search_type,
                     user=user,

@@ -272,6 +268,7 @@ def configure_routes(app):
     from khoj.routers.api_chat import api_chat
     from khoj.routers.api_config import api_config
     from khoj.routers.indexer import indexer
+    from khoj.routers.notion import notion_router
     from khoj.routers.web_client import web_client

     app.include_router(api, prefix="/api")

@@ -279,6 +276,7 @@ def configure_routes(app):
     app.include_router(api_agents, prefix="/api/agents")
     app.include_router(api_config, prefix="/api/config")
     app.include_router(indexer, prefix="/api/v1/index")
+    app.include_router(notion_router, prefix="/api/notion")
     app.include_router(web_client)

     if not state.anonymous_mode:

@@ -311,13 +309,9 @@ def update_search_index():
         logger.info("📬 Updating content index via Scheduler")
         for user in get_all_users():
             all_files = collect_files(user=user)
-            state.content_index, success = configure_content(
-                state.content_index, state.config.content_type, all_files, state.search_models, user=user
-            )
+            success = configure_content(all_files, user=user)
         all_files = collect_files(user=None)
-        state.content_index, success = configure_content(
-            state.content_index, state.config.content_type, all_files, state.search_models, user=None
-        )
+        success = configure_content(all_files, user=None)
         if not success:
             raise RuntimeError("Failed to update content index")
         logger.info("📪 Content index updated via Scheduler")
src/khoj/database/adapters/__init__.py

@@ -259,6 +259,10 @@ async def get_user_by_email(email: str) -> KhojUser:
     return await KhojUser.objects.filter(email=email).afirst()


+async def aget_user_by_uuid(uuid: str) -> KhojUser:
+    return await KhojUser.objects.filter(uuid=uuid).afirst()
+
+
 async def get_user_by_token(token: dict) -> KhojUser:
     google_user = await GoogleUser.objects.filter(sub=token.get("sub")).select_related("user").afirst()
     if not google_user:
src/khoj/database/admin.py

@@ -11,7 +11,9 @@ from khoj.database.models import (
     ClientApplication,
     Conversation,
     Entry,
+    GithubConfig,
     KhojUser,
+    NotionConfig,
     OfflineChatProcessorConversationConfig,
     OpenAIProcessorConversationConfig,
     ReflectiveQuestion,

@@ -52,6 +54,8 @@ admin.site.register(UserSearchModelConfig)
 admin.site.register(TextToImageModelConfig)
 admin.site.register(ClientApplication)
 admin.site.register(Agent)
+admin.site.register(GithubConfig)
+admin.site.register(NotionConfig)


 @admin.register(Entry)
src/khoj/interface/web/config.html

@@ -109,14 +109,23 @@
                 <p class="card-description">Sync your Notion pages</p>
             </div>
             <div class="card-action-row">
+                {% if current_model_state.notion %}
                 <a class="card-button" href="/config/content-source/notion">
-                    {% if current_model_state.notion %}
                     Update
-                    {% else %}
-                    Setup
-                    {% endif %}
                     <svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M5 12h14M12 5l7 7-7 7"></path></svg>
                 </a>
+                {% elif notion_oauth_url %}
+                <a class="card-button" href="{{ notion_oauth_url }}">
+                    Connect
+                    <svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M5 12h14M12 5l7 7-7 7"></path></svg>
+                </a>
+                {% else %}
+                <a class="card-button" href="/config/content-source/notion">
+                    Setup
+                    <svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M5 12h14M12 5l7 7-7 7"></path></svg>
+                </a>
+                {% endif %}
+
                 <div id="clear-notion"
                     class="card-action-row"
                     style="display: {% if not current_model_state.notion %}none{% endif %}">
src/khoj/interface/web/content_source_notion_input.html

@@ -5,11 +5,6 @@
     <h2 class="section-title">
         <img class="card-icon" src="/static/assets/icons/notion.svg?v={{ khoj_version }}" alt="Notion">
         <span class="card-title-text">Notion</span>
-        <div class="instructions">
-            <a href="https://docs.khoj.dev/#/notion_integration">ⓘ Help</a>
-        </div>
     </h2>
     <form>
         <table>
-            <tr>
-                <td>

@@ -22,7 +17,7 @@
         </table>
         <div class="section">
             <div id="success" style="display: none;"></div>
-            <button id="submit" type="submit">Save</button>
+            <button id="submit" type="submit">Sync to Update</button>
         </div>
     </form>
 </div>

@@ -43,7 +38,7 @@

         const submitButton = document.getElementById("submit");
         submitButton.disabled = true;
-        submitButton.innerHTML = "Saving...";
+        submitButton.innerHTML = "Syncing...";

         // Save Notion config on server
         const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1];
src/khoj/routers/api.py

@@ -33,7 +33,7 @@ from khoj.routers.helpers import (
 from khoj.search_filter.date_filter import DateFilter
 from khoj.search_filter.file_filter import FileFilter
 from khoj.search_filter.word_filter import WordFilter
-from khoj.search_type import image_search, text_search
+from khoj.search_type import text_search
 from khoj.utils import constants, state
 from khoj.utils.config import OfflineChatProcessorModel
 from khoj.utils.helpers import ConversationCommand, timer

@@ -145,41 +145,17 @@ async def execute_search(
                 )
             ]

-        elif (t == SearchType.Image) and state.content_index.image and state.search_models.image_search:
-            # query images
-            search_futures += [
-                executor.submit(
-                    image_search.query,
-                    user_query,
-                    results_count,
-                    state.search_models.image_search,
-                    state.content_index.image,
-                )
-            ]
-
         # Query across each requested content types in parallel
         with timer("Query took", logger):
             for search_future in concurrent.futures.as_completed(search_futures):
-                if t == SearchType.Image and state.content_index.image:
-                    hits = await search_future.result()
-                    output_directory = constants.web_directory / "images"
-                    # Collate results
-                    results += image_search.collate_results(
-                        hits,
-                        image_names=state.content_index.image.image_names,
-                        output_directory=output_directory,
-                        image_files_url="/static/images",
-                        count=results_count,
-                    )
-                else:
-                    hits = await search_future.result()
-                    # Collate results
-                    results += text_search.collate_results(hits, dedupe=dedupe)
+                hits = await search_future.result()
+                # Collate results
+                results += text_search.collate_results(hits, dedupe=dedupe)

-            # Sort results across all content types and take top results
-            results = text_search.rerank_and_sort_results(
-                results, query=defiltered_query, rank_results=r, search_model_name=search_model.name
-            )[:results_count]
+        # Sort results across all content types and take top results
+        results = text_search.rerank_and_sort_results(
+            results, query=defiltered_query, rank_results=r, search_model_name=search_model.name
+        )[:results_count]

         # Cache results
         if user:

@@ -214,8 +190,6 @@ def update(
     components = []
     if state.search_models:
         components.append("Search models")
-    if state.content_index:
-        components.append("Content index")
     components_msg = ", ".join(components)
     logger.info(f"📪 {components_msg} updated via API")

src/khoj/routers/api_config.py

@@ -38,7 +38,7 @@ logger = logging.getLogger(__name__)
 def map_config_to_object(content_source: str):
     if content_source == DbEntry.EntrySource.GITHUB:
         return GithubConfig
-    if content_source == DbEntry.EntrySource.GITHUB:
+    if content_source == DbEntry.EntrySource.NOTION:
         return NotionConfig
     if content_source == DbEntry.EntrySource.COMPUTER:
         return "Computer"
src/khoj/routers/indexer.py

@@ -14,9 +14,9 @@ from khoj.processor.content.org_mode.org_to_entries import OrgToEntries
 from khoj.processor.content.pdf.pdf_to_entries import PdfToEntries
 from khoj.processor.content.plaintext.plaintext_to_entries import PlaintextToEntries
 from khoj.routers.helpers import ApiIndexedDataLimiter, update_telemetry_state
-from khoj.search_type import image_search, text_search
+from khoj.search_type import text_search
 from khoj.utils import constants, state
-from khoj.utils.config import ContentIndex, SearchModels
+from khoj.utils.config import SearchModels
 from khoj.utils.helpers import LRU, get_file_type
 from khoj.utils.rawconfig import ContentConfig, FullConfig, SearchConfig
 from khoj.utils.yaml import save_config_to_file_updated_state

@@ -105,13 +105,10 @@ async def update(

     # Extract required fields from config
     loop = asyncio.get_event_loop()
-    state.content_index, success = await loop.run_in_executor(
+    success = await loop.run_in_executor(
         None,
         configure_content,
-        state.content_index,
-        state.config.content_type,
         indexer_input.model_dump(),
-        state.search_models,
         force,
         t,
         False,

@@ -159,23 +156,17 @@ def configure_search(search_models: SearchModels, search_config: Optional[Search

-    if search_config and search_config.image:
-        logger.info("🔍 🌄 Setting up image search model")
-        search_models.image_search = image_search.initialize_model(search_config.image)
-
     return search_models


 def configure_content(
-    content_index: Optional[ContentIndex],
-    content_config: Optional[ContentConfig],
     files: Optional[dict[str, dict[str, str]]],
-    search_models: SearchModels,
     regenerate: bool = False,
     t: Optional[state.SearchType] = state.SearchType.All,
     full_corpus: bool = True,
     user: KhojUser = None,
-) -> tuple[Optional[ContentIndex], bool]:
-    content_index = ContentIndex()
-
+) -> bool:
     success = True
     if t == None:
         t = state.SearchType.All

@@ -185,7 +176,7 @@ def configure_content(

     if t is not None and not t.value in [type.value for type in state.SearchType]:
         logger.warning(f"🚨 Invalid search type: {t}")
-        return None, False
+        return False

     search_type = t.value if t else None


@@ -193,7 +184,7 @@ def configure_content(

     if files is None:
         logger.warning(f"🚨 No files to process for {search_type} search.")
-        return None, True
+        return True

     try:
         # Initialize Org Notes Search

@@ -266,24 +257,6 @@ def configure_content(
             logger.error(f"🚨 Failed to setup plaintext: {e}", exc_info=True)
             success = False

-    try:
-        # Initialize Image Search
-        if (
-            (search_type == state.SearchType.All.value or search_type == state.SearchType.Image.value)
-            and content_config
-            and content_config.image
-            and search_models.image_search
-        ):
-            logger.info("🌄 Setting up search for images")
-            # Extract Entries, Generate Image Embeddings
-            content_index.image = image_search.setup(
-                content_config.image, search_models.image_search.image_encoder, regenerate=regenerate
-            )
-
-    except Exception as e:
-        logger.error(f"🚨 Failed to setup images: {e}", exc_info=True)
-        success = False
-
     try:
         if no_documents:
             github_config = GithubConfig.objects.filter(user=user).prefetch_related("githubrepoconfig").first()

@@ -330,23 +303,4 @@ def configure_content(
     if user:
         state.query_cache[user.uuid] = LRU()

-    return content_index, success
-
-
-def load_content(
-    content_config: Optional[ContentConfig],
-    content_index: Optional[ContentIndex],
-    search_models: SearchModels,
-):
-    if content_config is None:
-        logger.debug("🚨 No Content configuration available.")
-        return None
-    if content_index is None:
-        content_index = ContentIndex()
-
-    if content_config.image:
-        logger.info("🌄 Loading images")
-        content_index.image = image_search.setup(
-            content_config.image, search_models.image_search.image_encoder, regenerate=False
-        )
-    return content_index
+    return success
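For a single content source, the remaining positional parameters of the new signature, (files, regenerate, t, full_corpus, user), can be passed explicitly. A minimal sketch of a Notion-only reindex, mirroring the call the new Notion router (next file) schedules once OAuth completes; it assumes a KhojUser instance named `user`:

# Sketch: reindex only Notion content for one user, using the new
# configure_content(files, regenerate, t, full_corpus, user) -> bool signature.
from khoj.routers.indexer import configure_content
from khoj.utils.state import SearchType

# No local files are passed; Notion entries are fetched with the stored token.
success = configure_content({}, False, SearchType.Notion, True, user)  # assumes a KhojUser `user`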
src/khoj/routers/notion.py (new file, 89 lines)

@@ -0,0 +1,89 @@
+import asyncio
+import base64
+import json
+import logging
+import os
+from concurrent.futures import ThreadPoolExecutor
+
+import requests
+from fastapi import APIRouter, BackgroundTasks, Request, Response
+from starlette.responses import RedirectResponse
+
+from khoj.database.adapters import aget_user_by_uuid
+from khoj.database.models import KhojUser, NotionConfig
+from khoj.routers.indexer import configure_content
+from khoj.utils.state import SearchType
+
+NOTION_OAUTH_CLIENT_ID = os.getenv("NOTION_OAUTH_CLIENT_ID")
+NOTION_OAUTH_CLIENT_SECRET = os.getenv("NOTION_OAUTH_CLIENT_SECRET")
+NOTION_REDIRECT_URI = os.getenv("NOTION_REDIRECT_URI")
+
+notion_router = APIRouter()
+
+executor = ThreadPoolExecutor()
+
+logger = logging.getLogger(__name__)
+
+
+def get_notion_auth_url(user: KhojUser):
+    if not NOTION_OAUTH_CLIENT_ID or not NOTION_OAUTH_CLIENT_SECRET or not NOTION_REDIRECT_URI:
+        return None
+    return f"https://api.notion.com/v1/oauth/authorize?client_id={NOTION_OAUTH_CLIENT_ID}&redirect_uri={NOTION_REDIRECT_URI}&response_type=code&state={user.uuid}"
+
+
+async def run_in_executor(func, *args):
+    loop = asyncio.get_event_loop()
+    return await loop.run_in_executor(executor, func, *args)
+
+
+@notion_router.get("/auth/callback")
+async def notion_auth_callback(request: Request, background_tasks: BackgroundTasks):
+    code = request.query_params.get("code")
+    state = request.query_params.get("state")
+    if not code or not state:
+        return Response("Missing code or state", status_code=400)
+
+    user: KhojUser = await aget_user_by_uuid(state)
+
+    NotionConfig.objects.filter(user=user).adelete()
+
+    if not user:
+        raise Exception("User not found")
+
+    bearer_token = f"{NOTION_OAUTH_CLIENT_ID}:{NOTION_OAUTH_CLIENT_SECRET}"
+    base64_encoded_token = base64.b64encode(bearer_token.encode()).decode()
+
+    headers = {
+        "Accept": "application/json",
+        "Content-Type": "application/json",
+        "Authorization": f"Basic {base64_encoded_token}",
+    }
+
+    data = {
+        "grant_type": "authorization_code",
+        "code": code,
+        "redirect_uri": NOTION_REDIRECT_URI,
+    }
+
+    response = requests.post("https://api.notion.com/v1/oauth/token", data=json.dumps(data), headers=headers)
+
+    final_response = response.json()
+
+    access_token = final_response.get("access_token")
+    NotionConfig.objects.acreate(token=access_token, user=user)
+
+    owner = final_response.get("owner")
+    workspace_id = final_response.get("workspace_id")
+    workspace_name = final_response.get("workspace_name")
+    bot_id = final_response.get("bot_id")
+
+    logger.info(
+        f"Notion integration. Owner: {owner}, Workspace ID: {workspace_id}, Workspace Name: {workspace_name}, Bot ID: {bot_id}"
+    )
+
+    notion_redirect = str(request.app.url_path_for("notion_config_page"))
+
+    # Trigger an async job to configure_content. Let it run without blocking the response.
+    background_tasks.add_task(run_in_executor, configure_content, {}, False, SearchType.Notion, True, user)
+
+    return RedirectResponse(notion_redirect)
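The OAuth flow above only activates when all three environment variables are set before the server starts, since notion.py reads them at import time; otherwise get_notion_auth_url returns None and the config page falls back to the manual token form. A sketch of the expected configuration with placeholder values; the host name is an assumption:

# Sketch: environment the Notion OAuth router expects. Placeholder values only;
# set these in the deployment environment before khoj is launched.
import os

os.environ.setdefault("NOTION_OAUTH_CLIENT_ID", "<notion-oauth-client-id>")
os.environ.setdefault("NOTION_OAUTH_CLIENT_SECRET", "<notion-oauth-client-secret>")
# configure_routes mounts the callback at /api/notion/auth/callback.
os.environ.setdefault("NOTION_REDIRECT_URI", "https://your-khoj-host/api/notion/auth/callback")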
src/khoj/routers/web_client.py

@@ -19,6 +19,7 @@ from khoj.database.adapters import (
     get_user_subscription_state,
 )
 from khoj.database.models import KhojUser
+from khoj.routers.notion import get_notion_auth_url
 from khoj.routers.twilio import is_twilio_enabled
 from khoj.utils import constants, state
 from khoj.utils.rawconfig import (

@@ -244,6 +245,8 @@ def config_page(request: Request):

     current_search_model_option = adapters.get_user_search_model_or_default(user)

+    notion_oauth_url = get_notion_auth_url(user)
+
     return templates.TemplateResponse(
         "config.html",
         context={

@@ -267,6 +270,7 @@ def config_page(request: Request):
             "phone_number": user.phone_number,
             "is_phone_number_verified": user.verified_phone_number,
             "khoj_version": state.khoj_version,
+            "notion_oauth_url": notion_oauth_url,
         },
     )


@@ -324,7 +328,7 @@ def notion_config_page(request: Request):
         token=current_notion_config.token if current_notion_config else "",
     )

-    current_config = json.loads(current_config.json())
+    current_config = json.loads(current_config.model_dump_json())

     return templates.TemplateResponse(
         "content_source_notion_input.html",
src/khoj/search_type/image_search.py (deleted file)

@@ -1,272 +0,0 @@
-import copy
-import glob
-import logging
-import math
-import pathlib
-import shutil
-from typing import List
-
-import torch
-from PIL import Image
-from sentence_transformers import SentenceTransformer, util
-from tqdm import trange
-
-from khoj.utils.config import ImageContent, ImageSearchModel
-from khoj.utils.helpers import (
-    get_absolute_path,
-    get_from_dict,
-    load_model,
-    resolve_absolute_path,
-    timer,
-)
-from khoj.utils.models import BaseEncoder
-from khoj.utils.rawconfig import ImageContentConfig, ImageSearchConfig, SearchResponse
-
-# Create Logger
-logger = logging.getLogger(__name__)
-
-
-def initialize_model(search_config: ImageSearchConfig):
-    # Convert model directory to absolute path
-    search_config.model_directory = resolve_absolute_path(search_config.model_directory)
-
-    # Create model directory if it doesn't exist
-    search_config.model_directory.parent.mkdir(parents=True, exist_ok=True)
-
-    # Load the CLIP model
-    encoder = load_model(
-        model_dir=search_config.model_directory,
-        model_name=search_config.encoder,
-        model_type=search_config.encoder_type or SentenceTransformer,
-    )
-
-    return ImageSearchModel(encoder)
-
-
-def extract_entries(image_directories):
-    image_names = []
-    for image_directory in image_directories:
-        image_directory = resolve_absolute_path(image_directory, strict=True)
-        image_names.extend(list(image_directory.glob("*.jpg")))
-        image_names.extend(list(image_directory.glob("*.jpeg")))
-
-    if logger.level >= logging.DEBUG:
-        image_directory_names = ", ".join([str(image_directory) for image_directory in image_directories])
-        logger.debug(f"Found {len(image_names)} images in {image_directory_names}")
-    return sorted(image_names)
-
-
-def compute_embeddings(image_names, encoder, embeddings_file, batch_size=50, use_xmp_metadata=False, regenerate=False):
-    "Compute (and Save) Embeddings or Load Pre-Computed Embeddings"
-
-    image_embeddings = compute_image_embeddings(image_names, encoder, embeddings_file, batch_size, regenerate)
-    image_metadata_embeddings = compute_metadata_embeddings(
-        image_names, encoder, embeddings_file, batch_size, use_xmp_metadata, regenerate
-    )
-
-    return image_embeddings, image_metadata_embeddings
-
-
-def compute_image_embeddings(image_names, encoder, embeddings_file, batch_size=50, regenerate=False):
-    # Load pre-computed image embeddings from file if exists
-    if resolve_absolute_path(embeddings_file).exists() and not regenerate:
-        image_embeddings = torch.load(embeddings_file)
-        logger.debug(f"Loaded {len(image_embeddings)} image embeddings from {embeddings_file}")
-    # Else compute the image embeddings from scratch, which can take a while
-    else:
-        image_embeddings = []
-        for index in trange(0, len(image_names), batch_size):
-            images = []
-            for image_name in image_names[index : index + batch_size]:
-                image = Image.open(image_name)
-                # Resize images to max width of 640px for faster processing
-                image.thumbnail((640, image.height))
-                images += [image]
-            image_embeddings += encoder.encode(images, convert_to_tensor=True, batch_size=min(len(images), batch_size))
-
-        # Create directory for embeddings file, if it doesn't exist
-        embeddings_file.parent.mkdir(parents=True, exist_ok=True)
-
-        # Save computed image embeddings to file
-        torch.save(image_embeddings, embeddings_file)
-        logger.info(f"📩 Saved computed image embeddings to {embeddings_file}")
-
-    return image_embeddings
-
-
-def compute_metadata_embeddings(
-    image_names, encoder, embeddings_file, batch_size=50, use_xmp_metadata=False, regenerate=False, verbose=0
-):
-    image_metadata_embeddings = None
-
-    # Load pre-computed image metadata embedding file if exists
-    if use_xmp_metadata and resolve_absolute_path(f"{embeddings_file}_metadata").exists() and not regenerate:
-        image_metadata_embeddings = torch.load(f"{embeddings_file}_metadata")
-        logger.debug(f"Loaded image metadata embeddings from {embeddings_file}_metadata")
-
-    # Else compute the image metadata embeddings from scratch, which can take a while
-    if use_xmp_metadata and image_metadata_embeddings is None:
-        image_metadata_embeddings = []
-        for index in trange(0, len(image_names), batch_size):
-            image_metadata = [
-                extract_metadata(image_name, verbose) for image_name in image_names[index : index + batch_size]
-            ]
-            try:
-                image_metadata_embeddings += encoder.encode(
-                    image_metadata, convert_to_tensor=True, batch_size=min(len(image_metadata), batch_size)
-                )
-            except RuntimeError as e:
-                logger.error(
-                    f"Error encoding metadata for images starting from\n\tindex: {index},\n\timages: {image_names[index:index+batch_size]}\nException: {e}"
-                )
-                continue
-        torch.save(image_metadata_embeddings, f"{embeddings_file}_metadata")
-        logger.info(f"📩 Saved computed image metadata embeddings to {embeddings_file}_metadata")
-
-    return image_metadata_embeddings
-
-
-def extract_metadata(image_name):
-    image_xmp_metadata = Image.open(image_name).getxmp()
-    image_description = get_from_dict(
-        image_xmp_metadata, "xmpmeta", "RDF", "Description", "description", "Alt", "li", "text"
-    )
-    image_subjects = get_from_dict(image_xmp_metadata, "xmpmeta", "RDF", "Description", "subject", "Bag", "li")
-    image_metadata_subjects = set([subject.split(":")[1] for subject in image_subjects if ":" in subject])
-
-    image_processed_metadata = image_description
-    if len(image_metadata_subjects) > 0:
-        image_processed_metadata += ". " + ", ".join(image_metadata_subjects)
-
-    logger.debug(f"{image_name}:\t{image_processed_metadata}")
-
-    return image_processed_metadata
-
-
-async def query(
-    raw_query, count, search_model: ImageSearchModel, content: ImageContent, score_threshold: float = math.inf
-):
-    # Set query to image content if query is of form file:/path/to/file.png
-    if raw_query.startswith("file:") and pathlib.Path(raw_query[5:]).is_file():
-        query_imagepath = resolve_absolute_path(pathlib.Path(raw_query[5:]), strict=True)
-        query = copy.deepcopy(Image.open(query_imagepath))
-        query.thumbnail((640, query.height))  # scale down image for faster processing
-        logger.info(f"🔎 Find Images by Image: {query_imagepath}")
-    else:
-        # Truncate words in query to stay below max_tokens supported by ML model
-        max_words = 20
-        query = " ".join(raw_query.split()[:max_words])
-        logger.info(f"🔎 Find Images by Text: {query}")
-
-    # Now we encode the query (which can either be an image or a text string)
-    with timer("Query Encode Time", logger):
-        query_embedding = search_model.image_encoder.encode([query], convert_to_tensor=True, show_progress_bar=False)
-
-    # Compute top_k ranked images based on cosine-similarity b/w query and all image embeddings.
-    with timer("Search Time", logger):
-        image_hits = {
-            # Map scores to distance metric by multiplying by -1
-            result["corpus_id"]: {"image_score": -1 * result["score"], "score": -1 * result["score"]}
-            for result in util.semantic_search(query_embedding, content.image_embeddings, top_k=count)[0]
-        }
-
-    # Compute top_k ranked images based on cosine-similarity b/w query and all image metadata embeddings.
-    if content.image_metadata_embeddings:
-        with timer("Metadata Search Time", logger):
-            metadata_hits = {
-                result["corpus_id"]: result["score"]
-                for result in util.semantic_search(query_embedding, content.image_metadata_embeddings, top_k=count)[0]
-            }
-
-        # Sum metadata, image scores of the highest ranked images
-        for corpus_id, score in metadata_hits.items():
-            scaling_factor = 0.33
-            if "corpus_id" in image_hits:
-                image_hits[corpus_id].update(
-                    {
-                        "metadata_score": score,
-                        "score": image_hits[corpus_id].get("score", 0) + scaling_factor * score,
-                    }
-                )
-            else:
-                image_hits[corpus_id] = {"metadata_score": score, "score": scaling_factor * score}
-
-    # Reformat results in original form from sentence transformer semantic_search()
-    hits = [
-        {
-            "corpus_id": corpus_id,
-            "score": scores["score"],
-            "image_score": scores.get("image_score", 0),
-            "metadata_score": scores.get("metadata_score", 0),
-        }
-        for corpus_id, scores in image_hits.items()
-    ]
-
-    # Filter results by score threshold
-    hits = [hit for hit in hits if hit["image_score"] <= score_threshold]
-
-    # Sort the images based on their combined metadata, image scores
-    return sorted(hits, key=lambda hit: hit["score"], reverse=True)
-
-
-def collate_results(hits, image_names, output_directory, image_files_url, count=5) -> List[SearchResponse]:
-    results: List[SearchResponse] = []
-
-    for index, hit in enumerate(hits[:count]):
-        source_path = image_names[hit["corpus_id"]]
-
-        target_image_name = f"{index}{source_path.suffix}"
-        target_path = resolve_absolute_path(f"{output_directory}/{target_image_name}")
-
-        # Create output directory, if it doesn't exist
-        if not target_path.parent.exists():
-            target_path.parent.mkdir(exist_ok=True)
-
-        # Copy the image to the output directory
-        shutil.copy(source_path, target_path)
-
-        # Add the image metadata to the results
-        results += [
-            SearchResponse.model_validate(
-                {
-                    "entry": f"{image_files_url}/{target_image_name}",
-                    "score": f"{hit['score']:.9f}",
-                    "additional": {
-                        "image_score": f"{hit['image_score']:.9f}",
-                        "metadata_score": f"{hit['metadata_score']:.9f}",
-                    },
-                    "corpus_id": str(hit["corpus_id"]),
-                }
-            )
-        ]
-
-    return results
-
-
-def setup(config: ImageContentConfig, encoder: BaseEncoder, regenerate: bool) -> ImageContent:
-    # Extract Entries
-    absolute_image_files, filtered_image_files = set(), set()
-    if config.input_directories:
-        image_directories = [resolve_absolute_path(directory, strict=True) for directory in config.input_directories]
-        absolute_image_files = set(extract_entries(image_directories))
-    if config.input_filter:
-        filtered_image_files = {
-            filtered_file
-            for input_filter in config.input_filter
-            for filtered_file in glob.glob(get_absolute_path(input_filter))
-        }
-
-    all_image_files = sorted(list(absolute_image_files | filtered_image_files))
-
-    # Compute or Load Embeddings
-    embeddings_file = resolve_absolute_path(config.embeddings_file)
-    image_embeddings, image_metadata_embeddings = compute_embeddings(
-        all_image_files,
-        encoder,
-        embeddings_file,
-        batch_size=config.batch_size,
-        regenerate=regenerate,
-        use_xmp_metadata=config.use_xmp_metadata,
-    )
-
-    return ImageContent(all_image_files, image_embeddings, image_metadata_embeddings)
src/khoj/utils/config.py

@@ -58,15 +58,9 @@ class ImageSearchModel:
     image_encoder: BaseEncoder


-@dataclass
-class ContentIndex:
-    image: Optional[ImageContent] = None
-
-
 @dataclass
 class SearchModels:
     text_search: Optional[TextSearchModel] = None
-    image_search: Optional[ImageSearchModel] = None


 @dataclass
src/khoj/utils/state.py

@@ -9,7 +9,7 @@ from whisper import Whisper

 from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
 from khoj.utils import config as utils_config
-from khoj.utils.config import ContentIndex, OfflineChatProcessorModel, SearchModels
+from khoj.utils.config import OfflineChatProcessorModel, SearchModels
 from khoj.utils.helpers import LRU, get_device
 from khoj.utils.rawconfig import FullConfig


@@ -18,7 +18,6 @@ config = FullConfig()
 search_models = SearchModels()
 embeddings_model: Dict[str, EmbeddingsModel] = None
 cross_encoder_model: Dict[str, CrossEncoderModel] = None
-content_index = ContentIndex()
 openai_client: OpenAI = None
 offline_chat_processor_config: OfflineChatProcessorModel = None
 whisper_model: Whisper = None
tests/conftest.py

@@ -25,7 +25,7 @@ from khoj.processor.content.org_mode.org_to_entries import OrgToEntries
 from khoj.processor.content.plaintext.plaintext_to_entries import PlaintextToEntries
 from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
 from khoj.routers.indexer import configure_content
-from khoj.search_type import image_search, text_search
+from khoj.search_type import text_search
 from khoj.utils import fs_syncer, state
 from khoj.utils.config import SearchModels
 from khoj.utils.constants import web_directory

@@ -207,7 +207,6 @@ def openai_agent():
 @pytest.fixture(scope="session")
 def search_models(search_config: SearchConfig):
     search_models = SearchModels()
-    search_models.image_search = image_search.initialize_model(search_config.image)

     return search_models


@@ -232,8 +231,6 @@ def content_config(tmp_path_factory, search_models: SearchModels, default_user:
         use_xmp_metadata=False,
     )

-    image_search.setup(content_config.image, search_models.image_search.image_encoder, regenerate=False)
-
     LocalOrgConfig.objects.create(
         input_files=None,
         input_filter=["tests/data/org/*.org"],

@@ -305,9 +302,7 @@ def chat_client_builder(search_config, user, index_content=True, require_auth=Fa

         # Index Markdown Content for Search
         all_files = fs_syncer.collect_files(user=user)
-        state.content_index, _ = configure_content(
-            state.content_index, state.config.content_type, all_files, state.search_models, user=user
-        )
+        success = configure_content(all_files, user=user)

     # Initialize Processor from Config
     if os.getenv("OPENAI_API_KEY"):

@@ -349,16 +344,12 @@ def client(
     state.cross_encoder_model["default"] = CrossEncoderModel()

     # These lines help us Mock the Search models for these search types
-    state.search_models.image_search = image_search.initialize_model(search_config.image)
     text_search.setup(
         OrgToEntries,
         get_sample_data("org"),
         regenerate=False,
         user=api_user.user,
     )
-    state.content_index.image = image_search.setup(
-        content_config.image, state.search_models.image_search, regenerate=False
-    )
     text_search.setup(
         PlaintextToEntries,
         get_sample_data("plaintext"),

@@ -388,9 +379,7 @@ def client_offline_chat(search_config: SearchConfig, default_user2: KhojUser):
     )

     all_files = fs_syncer.collect_files(user=default_user2)
-    configure_content(
-        state.content_index, state.config.content_type, all_files, state.search_models, user=default_user2
-    )
+    configure_content(all_files, user=default_user2)

     # Initialize Processor from Config
     OfflineChatProcessorConversationConfigFactory(enabled=True)
tests/test_client.py

@@ -12,10 +12,9 @@ from khoj.configure import configure_routes, configure_search_types
 from khoj.database.adapters import EntryAdapters
 from khoj.database.models import KhojApiUser, KhojUser
 from khoj.processor.content.org_mode.org_to_entries import OrgToEntries
-from khoj.search_type import image_search, text_search
+from khoj.search_type import text_search
 from khoj.utils import state
 from khoj.utils.rawconfig import ContentConfig, SearchConfig
-from khoj.utils.state import config, content_index, search_models


 # Test

@@ -298,34 +297,6 @@ def test_get_configured_types_with_no_content_config(fastapi_app: FastAPI):
     assert response.json() == ["all"]


-# ----------------------------------------------------------------------------------------------------
-@pytest.mark.django_db(transaction=True)
-def test_image_search(client, content_config: ContentConfig, search_config: SearchConfig):
-    # Arrange
-    headers = {"Authorization": "Bearer kk-secret"}
-    search_models.image_search = image_search.initialize_model(search_config.image)
-    content_index.image = image_search.setup(
-        content_config.image, search_models.image_search.image_encoder, regenerate=False
-    )
-    query_expected_image_pairs = [
-        ("kitten", "kitten_park.jpg"),
-        ("a horse and dog on a leash", "horse_dog.jpg"),
-        ("A guinea pig eating grass", "guineapig_grass.jpg"),
-    ]
-
-    for query, expected_image_name in query_expected_image_pairs:
-        # Act
-        response = client.get(f"/api/search?q={query}&n=1&t=image", headers=headers)
-
-        # Assert
-        assert response.status_code == 200
-        actual_image = Image.open(BytesIO(client.get(response.json()[0]["entry"]).content))
-        expected_image = Image.open(content_config.image.input_directories[0].joinpath(expected_image_name))
-
-        # Assert
-        assert expected_image == actual_image
-
-
 # ----------------------------------------------------------------------------------------------------
 @pytest.mark.django_db(transaction=True)
 def test_notes_search(client, search_config: SearchConfig, sample_org_data, default_user: KhojUser):
tests/test_image_search.py (deleted file)

@@ -1,162 +0,0 @@
-# Standard Modules
-import logging
-from pathlib import Path
-
-import pytest
-from PIL import Image
-
-from khoj.search_type import image_search
-from khoj.utils.config import SearchModels
-from khoj.utils.constants import web_directory
-from khoj.utils.helpers import resolve_absolute_path
-from khoj.utils.rawconfig import ContentConfig, SearchConfig
-from khoj.utils.state import content_index, search_models
-
-
-# Test
-# ----------------------------------------------------------------------------------------------------
-def test_image_search_setup(content_config: ContentConfig, search_models: SearchModels):
-    # Act
-    # Regenerate image search embeddings during image setup
-    image_search_model = image_search.setup(
-        content_config.image, search_models.image_search.image_encoder, regenerate=True
-    )
-
-    # Assert
-    assert len(image_search_model.image_names) == 3
-    assert len(image_search_model.image_embeddings) == 3
-
-
-# ----------------------------------------------------------------------------------------------------
-def test_image_metadata(content_config: ContentConfig):
-    "Verify XMP Description and Subjects Extracted from Image"
-    # Arrange
-    expected_metadata_image_name_pairs = [
-        (["Billi Ka Bacha.", "Cat", "Grass"], "kitten_park.jpg"),
-        (["Pasture.", "Horse", "Dog"], "horse_dog.jpg"),
-        (["Guinea Pig Eating Celery.", "Rodent", "Whiskers"], "guineapig_grass.jpg"),
-    ]
-
-    test_image_paths = [
-        Path(content_config.image.input_directories[0] / image_name[1])
-        for image_name in expected_metadata_image_name_pairs
-    ]
-
-    for expected_metadata, test_image_path in zip(expected_metadata_image_name_pairs, test_image_paths):
-        # Act
-        actual_metadata = image_search.extract_metadata(test_image_path)
-
-        # Assert
-        for expected_snippet in expected_metadata[0]:
-            assert expected_snippet in actual_metadata
-
-
-# ----------------------------------------------------------------------------------------------------
-@pytest.mark.anyio
-async def test_image_search(content_config: ContentConfig, search_config: SearchConfig):
-    # Arrange
-    search_models.image_search = image_search.initialize_model(search_config.image)
-    content_index.image = image_search.setup(
-        content_config.image, search_models.image_search.image_encoder, regenerate=False
-    )
-    output_directory = resolve_absolute_path(web_directory)
-    query_expected_image_pairs = [
-        ("kitten", "kitten_park.jpg"),
-        ("horse and dog in a farm", "horse_dog.jpg"),
-        ("A guinea pig eating grass", "guineapig_grass.jpg"),
-    ]
-
-    # Act
-    for query, expected_image_name in query_expected_image_pairs:
-        hits = await image_search.query(
-            query, count=1, search_model=search_models.image_search, content=content_index.image
-        )
-
-        results = image_search.collate_results(
-            hits,
-            content_index.image.image_names,
-            output_directory=output_directory,
-            image_files_url="/static/images",
-            count=1,
-        )
-
-        actual_image_path = output_directory.joinpath(Path(results[0].entry).name)
-        actual_image = Image.open(actual_image_path)
-        expected_image = Image.open(content_config.image.input_directories[0].joinpath(expected_image_name))
-
-        # Assert
-        assert expected_image == actual_image
-
-        # Cleanup
-        # Delete the image files copied to results directory
-        actual_image_path.unlink()
-
-
-# ----------------------------------------------------------------------------------------------------
-@pytest.mark.anyio
-async def test_image_search_query_truncated(content_config: ContentConfig, search_config: SearchConfig, caplog):
-    # Arrange
-    search_models.image_search = image_search.initialize_model(search_config.image)
-    content_index.image = image_search.setup(
-        content_config.image, search_models.image_search.image_encoder, regenerate=False
-    )
-    max_words_supported = 10
-    query = " ".join(["hello"] * 100)
-    truncated_query = " ".join(["hello"] * max_words_supported)
-
-    # Act
-    try:
-        with caplog.at_level(logging.INFO, logger="khoj.search_type.image_search"):
-            await image_search.query(
-                query, count=1, search_model=search_models.image_search, content=content_index.image
-            )
-    # Assert
-    except RuntimeError as e:
-        if "The size of tensor a (102) must match the size of tensor b (77)" in str(e):
-            assert False, f"Query length exceeds max tokens supported by model\n"
-    assert f"Find Images by Text: {truncated_query}" in caplog.text, "Query not truncated"
-
-
-# ----------------------------------------------------------------------------------------------------
-@pytest.mark.anyio
-async def test_image_search_by_filepath(content_config: ContentConfig, search_config: SearchConfig, caplog):
-    # Arrange
-    search_models.image_search = image_search.initialize_model(search_config.image)
-    content_index.image = image_search.setup(
-        content_config.image, search_models.image_search.image_encoder, regenerate=False
-    )
-    output_directory = resolve_absolute_path(web_directory)
-    image_directory = content_config.image.input_directories[0]
-
-    query = f"file:{image_directory.joinpath('kitten_park.jpg')}"
-    expected_image_path = f"{image_directory.joinpath('kitten_park.jpg')}"
-
-    # Act
-    with caplog.at_level(logging.INFO, logger="khoj.search_type.image_search"):
-        hits = await image_search.query(
-            query, count=1, search_model=search_models.image_search, content=content_index.image
-        )
-
-    results = image_search.collate_results(
-        hits,
-        content_index.image.image_names,
-        output_directory=output_directory,
-        image_files_url="/static/images",
-        count=1,
-    )
-
-    actual_image_path = output_directory.joinpath(Path(results[0].entry).name)
-    actual_image = Image.open(actual_image_path)
-    expected_image = Image.open(expected_image_path)
-
-    # Assert
-    # Ensure file search triggered instead of query with file path as string
-    assert (
-        f"Find Images by Image: {resolve_absolute_path(expected_image_path)}" in caplog.text
-    ), "File search not triggered"
-    # Ensure the correct image is returned
-    assert expected_image == actual_image, "Incorrect image returned by file search"
-
-    # Cleanup
-    # Delete the image files copied to results directory
-    actual_image_path.unlink()