Merge branch 'master' of github.com:debanjum/khoj

This commit is contained in:
sabaimran 2023-07-14 22:28:05 -07:00
commit ba47f2ab39
12 changed files with 361 additions and 196 deletions

View file

@ -20,9 +20,15 @@ from khoj.processor.github.github_to_jsonl import GithubToJsonl
from khoj.processor.notion.notion_to_jsonl import NotionToJsonl from khoj.processor.notion.notion_to_jsonl import NotionToJsonl
from khoj.search_type import image_search, text_search from khoj.search_type import image_search, text_search
from khoj.utils import constants, state from khoj.utils import constants, state
from khoj.utils.config import SearchType, SearchModels, ProcessorConfigModel, ConversationProcessorConfigModel from khoj.utils.config import (
ContentIndex,
SearchType,
SearchModels,
ProcessorConfigModel,
ConversationProcessorConfigModel,
)
from khoj.utils.helpers import LRU, resolve_absolute_path, merge_dicts from khoj.utils.helpers import LRU, resolve_absolute_path, merge_dicts
from khoj.utils.rawconfig import FullConfig, ProcessorConfig from khoj.utils.rawconfig import FullConfig, ProcessorConfig, SearchConfig, ContentConfig
from khoj.search_filter.date_filter import DateFilter from khoj.search_filter.date_filter import DateFilter
from khoj.search_filter.word_filter import WordFilter from khoj.search_filter.word_filter import WordFilter
from khoj.search_filter.file_filter import FileFilter from khoj.search_filter.file_filter import FileFilter
@ -49,11 +55,27 @@ def configure_server(args, required=False):
# Initialize Processor from Config # Initialize Processor from Config
state.processor_config = configure_processor(args.config.processor) state.processor_config = configure_processor(args.config.processor)
# Initialize the search type and model from Config # Initialize Search Models from Config
state.search_index_lock.acquire() try:
state.SearchType = configure_search_types(state.config) state.search_index_lock.acquire()
state.model = configure_search(state.model, state.config, args.regenerate) state.SearchType = configure_search_types(state.config)
state.search_index_lock.release() state.search_models = configure_search(state.search_models, state.config.search_type)
except Exception as e:
logger.error(f"🚨 Error configuring search models on app load: {e}")
finally:
state.search_index_lock.release()
# Initialize Content from Config
if state.search_models:
try:
state.search_index_lock.acquire()
state.content_index = configure_content(
state.content_index, state.config.content_type, state.search_models, args.regenerate
)
except Exception as e:
logger.error(f"🚨 Error configuring content index on app load: {e}")
finally:
state.search_index_lock.release()
def configure_routes(app): def configure_routes(app):
@ -72,10 +94,16 @@ if not state.demo:
@schedule.repeat(schedule.every(61).minutes) @schedule.repeat(schedule.every(61).minutes)
def update_search_index(): def update_search_index():
state.search_index_lock.acquire() try:
state.model = configure_search(state.model, state.config, regenerate=False) state.search_index_lock.acquire()
state.search_index_lock.release() state.content_index = configure_content(
logger.info("📬 Search index updated via Scheduler") state.content_index, state.config.content_type, state.search_models, regenerate=False
)
logger.info("📬 Content index updated via Scheduler")
except Exception as e:
logger.error(f"🚨 Error updating content index via Scheduler: {e}")
finally:
state.search_index_lock.release()
def configure_search_types(config: FullConfig): def configure_search_types(config: FullConfig):
@ -90,94 +118,116 @@ def configure_search_types(config: FullConfig):
return Enum("SearchType", merge_dicts(core_search_types, plugin_search_types)) return Enum("SearchType", merge_dicts(core_search_types, plugin_search_types))
def configure_search(model: SearchModels, config: FullConfig, regenerate: bool, t: Optional[state.SearchType] = None): def configure_search(search_models: SearchModels, search_config: SearchConfig) -> Optional[SearchModels]:
if config is None or config.content_type is None or config.search_type is None: # Run Validation Checks
logger.warning("🚨 No Content or Search type is configured.") if search_config is None:
return logger.warning("🚨 No Search type is configured.")
return None
if search_models is None:
search_models = SearchModels()
if model is None: # Initialize Search Models
model = SearchModels() if search_config.asymmetric:
logger.info("🔍 📜 Setting up text search model")
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
if search_config.image:
logger.info("🔍 🌄 Setting up image search model")
search_models.image_search = image_search.initialize_model(search_config.image)
return search_models
def configure_content(
content_index: Optional[ContentIndex],
content_config: Optional[ContentConfig],
search_models: SearchModels,
regenerate: bool,
t: Optional[state.SearchType] = None,
) -> Optional[ContentIndex]:
# Run Validation Checks
if content_config is None:
logger.warning("🚨 No Content type is configured.")
return None
if content_index is None:
content_index = ContentIndex()
try: try:
# Initialize Org Notes Search # Initialize Org Notes Search
if (t == state.SearchType.Org or t == None) and config.content_type.org and config.search_type.asymmetric: if (t == state.SearchType.Org or t == None) and content_config.org and search_models.text_search:
logger.info("🦄 Setting up search for orgmode notes") logger.info("🦄 Setting up search for orgmode notes")
# Extract Entries, Generate Notes Embeddings # Extract Entries, Generate Notes Embeddings
model.org_search = text_search.setup( content_index.org = text_search.setup(
OrgToJsonl, OrgToJsonl,
config.content_type.org, content_config.org,
search_config=config.search_type.asymmetric, search_models.text_search.bi_encoder,
regenerate=regenerate, regenerate=regenerate,
filters=[DateFilter(), WordFilter(), FileFilter()], filters=[DateFilter(), WordFilter(), FileFilter()],
) )
# Initialize Markdown Search # Initialize Markdown Search
if ( if (t == state.SearchType.Markdown or t == None) and content_config.markdown and search_models.text_search:
(t == state.SearchType.Markdown or t == None)
and config.content_type.markdown
and config.search_type.asymmetric
):
logger.info("💎 Setting up search for markdown notes") logger.info("💎 Setting up search for markdown notes")
# Extract Entries, Generate Markdown Embeddings # Extract Entries, Generate Markdown Embeddings
model.markdown_search = text_search.setup( content_index.markdown = text_search.setup(
MarkdownToJsonl, MarkdownToJsonl,
config.content_type.markdown, content_config.markdown,
search_config=config.search_type.asymmetric, search_models.text_search.bi_encoder,
regenerate=regenerate, regenerate=regenerate,
filters=[DateFilter(), WordFilter(), FileFilter()], filters=[DateFilter(), WordFilter(), FileFilter()],
) )
# Initialize PDF Search # Initialize PDF Search
if (t == state.SearchType.Pdf or t == None) and config.content_type.pdf and config.search_type.asymmetric: if (t == state.SearchType.Pdf or t == None) and content_config.pdf and search_models.text_search:
logger.info("🖨️ Setting up search for pdf") logger.info("🖨️ Setting up search for pdf")
# Extract Entries, Generate PDF Embeddings # Extract Entries, Generate PDF Embeddings
model.pdf_search = text_search.setup( content_index.pdf = text_search.setup(
PdfToJsonl, PdfToJsonl,
config.content_type.pdf, content_config.pdf,
search_config=config.search_type.asymmetric, search_models.text_search.bi_encoder,
regenerate=regenerate, regenerate=regenerate,
filters=[DateFilter(), WordFilter(), FileFilter()], filters=[DateFilter(), WordFilter(), FileFilter()],
) )
# Initialize Image Search # Initialize Image Search
if (t == state.SearchType.Image or t == None) and config.content_type.image and config.search_type.image: if (t == state.SearchType.Image or t == None) and content_config.image and search_models.image_search:
logger.info("🌄 Setting up search for images") logger.info("🌄 Setting up search for images")
# Extract Entries, Generate Image Embeddings # Extract Entries, Generate Image Embeddings
model.image_search = image_search.setup( content_index.image = image_search.setup(
config.content_type.image, search_config=config.search_type.image, regenerate=regenerate content_config.image, search_models.image_search.image_encoder, regenerate=regenerate
) )
if (t == state.SearchType.Github or t == None) and config.content_type.github and config.search_type.asymmetric: if (t == state.SearchType.Github or t == None) and content_config.github and search_models.text_search:
logger.info("🐙 Setting up search for github") logger.info("🐙 Setting up search for github")
# Extract Entries, Generate Github Embeddings # Extract Entries, Generate Github Embeddings
model.github_search = text_search.setup( content_index.github = text_search.setup(
GithubToJsonl, GithubToJsonl,
config.content_type.github, content_config.github,
search_config=config.search_type.asymmetric, search_models.text_search.bi_encoder,
regenerate=regenerate, regenerate=regenerate,
filters=[DateFilter(), WordFilter(), FileFilter()], filters=[DateFilter(), WordFilter(), FileFilter()],
) )
# Initialize External Plugin Search # Initialize External Plugin Search
if (t == None or t in state.SearchType) and config.content_type.plugins: if (t == None or t in state.SearchType) and content_config.plugins and search_models.text_search:
logger.info("🔌 Setting up search for plugins") logger.info("🔌 Setting up search for plugins")
model.plugin_search = {} content_index.plugins = {}
for plugin_type, plugin_config in config.content_type.plugins.items(): for plugin_type, plugin_config in content_config.plugins.items():
model.plugin_search[plugin_type] = text_search.setup( content_index.plugins[plugin_type] = text_search.setup(
JsonlToJsonl, JsonlToJsonl,
plugin_config, plugin_config,
search_config=config.search_type.asymmetric, search_models.text_search.bi_encoder,
regenerate=regenerate, regenerate=regenerate,
filters=[DateFilter(), WordFilter(), FileFilter()], filters=[DateFilter(), WordFilter(), FileFilter()],
) )
# Initialize Notion Search # Initialize Notion Search
if (t == None or t in state.SearchType) and config.content_type.notion: if (t == None or t in state.SearchType) and content_config.notion and search_models.text_search:
logger.info("🔌 Setting up search for notion") logger.info("🔌 Setting up search for notion")
model.notion_search = text_search.setup( content_index.notion = text_search.setup(
NotionToJsonl, NotionToJsonl,
config.content_type.notion, content_config.notion,
search_config=config.search_type.asymmetric, search_models.text_search.bi_encoder,
regenerate=regenerate, regenerate=regenerate,
filters=[DateFilter(), WordFilter(), FileFilter()], filters=[DateFilter(), WordFilter(), FileFilter()],
) )
@ -189,7 +239,7 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool,
# Invalidate Query Cache # Invalidate Query Cache
state.query_cache = LRU() state.query_cache = LRU()
return model return content_index
def configure_processor(processor_config: ProcessorConfig): def configure_processor(processor_config: ProcessorConfig):

View file

@ -12,7 +12,7 @@ from fastapi import APIRouter, HTTPException, Header, Request
from sentence_transformers import util from sentence_transformers import util
# Internal Packages # Internal Packages
from khoj.configure import configure_processor, configure_search from khoj.configure import configure_content, configure_processor, configure_search
from khoj.search_type import image_search, text_search from khoj.search_type import image_search, text_search
from khoj.search_filter.date_filter import DateFilter from khoj.search_filter.date_filter import DateFilter
from khoj.search_filter.file_filter import FileFilter from khoj.search_filter.file_filter import FileFilter
@ -163,17 +163,17 @@ if not state.demo:
state.config.content_type[content_type] = None state.config.content_type[content_type] = None
if content_type == "github": if content_type == "github":
state.model.github_search = None state.content_index.github = None
elif content_type == "notion": elif content_type == "notion":
state.model.notion_search = None state.content_index.notion = None
elif content_type == "plugins": elif content_type == "plugins":
state.model.plugin_search = None state.content_index.plugins = None
elif content_type == "pdf": elif content_type == "pdf":
state.model.pdf_search = None state.content_index.pdf = None
elif content_type == "markdown": elif content_type == "markdown":
state.model.markdown_search = None state.content_index.markdown = None
elif content_type == "org": elif content_type == "org":
state.model.org_search = None state.content_index.org = None
try: try:
save_config_to_file_updated_state() save_config_to_file_updated_state()
@ -280,7 +280,7 @@ def get_config_types():
for search_type in SearchType for search_type in SearchType
if ( if (
search_type.value in configured_content_types search_type.value in configured_content_types
and getattr(state.model, f"{search_type.value}_search") is not None and getattr(state.content_index, search_type.value) is not None
) )
or ("plugins" in configured_content_types and search_type.name in configured_content_types["plugins"]) or ("plugins" in configured_content_types and search_type.name in configured_content_types["plugins"])
or search_type == SearchType.All or search_type == SearchType.All
@ -308,7 +308,7 @@ async def search(
if q is None or q == "": if q is None or q == "":
logger.warning(f"No query param (q) passed in API call to initiate search") logger.warning(f"No query param (q) passed in API call to initiate search")
return results return results
if not state.model or not any(state.model.__dict__.values()): if not state.search_models or not any(state.search_models.__dict__.values()):
logger.warning(f"No search models loaded. Configure a search model before initiating search") logger.warning(f"No search models loaded. Configure a search model before initiating search")
return results return results
@ -332,7 +332,7 @@ async def search(
encoded_asymmetric_query = None encoded_asymmetric_query = None
if t == SearchType.All or t != SearchType.Image: if t == SearchType.All or t != SearchType.Image:
text_search_models: List[TextSearchModel] = [ text_search_models: List[TextSearchModel] = [
model for model in state.model.__dict__.values() if isinstance(model, TextSearchModel) model for model in state.search_models.__dict__.values() if isinstance(model, TextSearchModel)
] ]
if text_search_models: if text_search_models:
with timer("Encoding query took", logger=logger): with timer("Encoding query took", logger=logger):
@ -345,13 +345,14 @@ async def search(
) )
with concurrent.futures.ThreadPoolExecutor() as executor: with concurrent.futures.ThreadPoolExecutor() as executor:
if (t == SearchType.Org or t == SearchType.All) and state.model.org_search: if (t == SearchType.Org or t == SearchType.All) and state.content_index.org and state.search_models.text_search:
# query org-mode notes # query org-mode notes
search_futures += [ search_futures += [
executor.submit( executor.submit(
text_search.query, text_search.query,
user_query, user_query,
state.model.org_search, state.search_models.text_search,
state.content_index.org,
question_embedding=encoded_asymmetric_query, question_embedding=encoded_asymmetric_query,
rank_results=r or False, rank_results=r or False,
score_threshold=score_threshold, score_threshold=score_threshold,
@ -359,13 +360,18 @@ async def search(
) )
] ]
if (t == SearchType.Markdown or t == SearchType.All) and state.model.markdown_search: if (
(t == SearchType.Markdown or t == SearchType.All)
and state.content_index.markdown
and state.search_models.text_search
):
# query markdown notes # query markdown notes
search_futures += [ search_futures += [
executor.submit( executor.submit(
text_search.query, text_search.query,
user_query, user_query,
state.model.markdown_search, state.search_models.text_search,
state.content_index.markdown,
question_embedding=encoded_asymmetric_query, question_embedding=encoded_asymmetric_query,
rank_results=r or False, rank_results=r or False,
score_threshold=score_threshold, score_threshold=score_threshold,
@ -373,13 +379,18 @@ async def search(
) )
] ]
if (t == SearchType.Github or t == SearchType.All) and state.model.github_search: if (
(t == SearchType.Github or t == SearchType.All)
and state.content_index.github
and state.search_models.text_search
):
# query github issues # query github issues
search_futures += [ search_futures += [
executor.submit( executor.submit(
text_search.query, text_search.query,
user_query, user_query,
state.model.github_search, state.search_models.text_search,
state.content_index.github,
question_embedding=encoded_asymmetric_query, question_embedding=encoded_asymmetric_query,
rank_results=r or False, rank_results=r or False,
score_threshold=score_threshold, score_threshold=score_threshold,
@ -387,13 +398,14 @@ async def search(
) )
] ]
if (t == SearchType.Pdf or t == SearchType.All) and state.model.pdf_search: if (t == SearchType.Pdf or t == SearchType.All) and state.content_index.pdf and state.search_models.text_search:
# query pdf files # query pdf files
search_futures += [ search_futures += [
executor.submit( executor.submit(
text_search.query, text_search.query,
user_query, user_query,
state.model.pdf_search, state.search_models.text_search,
state.content_index.pdf,
question_embedding=encoded_asymmetric_query, question_embedding=encoded_asymmetric_query,
rank_results=r or False, rank_results=r or False,
score_threshold=score_threshold, score_threshold=score_threshold,
@ -401,26 +413,38 @@ async def search(
) )
] ]
if (t == SearchType.Image) and state.model.image_search: if (t == SearchType.Image) and state.content_index.image and state.search_models.image_search:
# query images # query images
search_futures += [ search_futures += [
executor.submit( executor.submit(
image_search.query, image_search.query,
user_query, user_query,
results_count, results_count,
state.model.image_search, state.search_models.image_search,
state.content_index.image,
score_threshold=score_threshold, score_threshold=score_threshold,
) )
] ]
if (t == SearchType.All or t in SearchType) and state.model.plugin_search: if (
(t == SearchType.All or t in SearchType)
and state.content_index.plugins
and state.search_models.plugin_search
):
# query specified plugin type # query specified plugin type
# Get plugin content, search model for specified search type, or the first one if none specified
plugin_search = state.search_models.plugin_search.get(t.value) or next(
iter(state.search_models.plugin_search.values())
)
plugin_content = state.content_index.plugins.get(t.value) or next(
iter(state.content_index.plugins.values())
)
search_futures += [ search_futures += [
executor.submit( executor.submit(
text_search.query, text_search.query,
user_query, user_query,
# Get plugin search model for specified search type, or the first one if none specified plugin_search,
state.model.plugin_search.get(t.value) or next(iter(state.model.plugin_search.values())), plugin_content,
question_embedding=encoded_asymmetric_query, question_embedding=encoded_asymmetric_query,
rank_results=r or False, rank_results=r or False,
score_threshold=score_threshold, score_threshold=score_threshold,
@ -428,13 +452,18 @@ async def search(
) )
] ]
if (t == SearchType.Notion or t == SearchType.All) and state.model.notion_search: if (
(t == SearchType.Notion or t == SearchType.All)
and state.content_index.notion
and state.search_models.text_search
):
# query notion pages # query notion pages
search_futures += [ search_futures += [
executor.submit( executor.submit(
text_search.query, text_search.query,
user_query, user_query,
state.model.notion_search, state.search_models.text_search,
state.content_index.notion,
question_embedding=encoded_asymmetric_query, question_embedding=encoded_asymmetric_query,
rank_results=r or False, rank_results=r or False,
score_threshold=score_threshold, score_threshold=score_threshold,
@ -445,13 +474,13 @@ async def search(
# Query across each requested content types in parallel # Query across each requested content types in parallel
with timer("Query took", logger): with timer("Query took", logger):
for search_future in concurrent.futures.as_completed(search_futures): for search_future in concurrent.futures.as_completed(search_futures):
if t == SearchType.Image: if t == SearchType.Image and state.content_index.image:
hits = await search_future.result() hits = await search_future.result()
output_directory = constants.web_directory / "images" output_directory = constants.web_directory / "images"
# Collate results # Collate results
results += image_search.collate_results( results += image_search.collate_results(
hits, hits,
image_names=state.model.image_search.image_names, image_names=state.content_index.image.image_names,
output_directory=output_directory, output_directory=output_directory,
image_files_url="/static/images", image_files_url="/static/images",
count=results_count, count=results_count,
@ -498,7 +527,12 @@ def update(
try: try:
state.search_index_lock.acquire() state.search_index_lock.acquire()
try: try:
state.model = configure_search(state.model, state.config, regenerate=force or False, t=t) if state.config and state.config.search_type:
state.search_models = configure_search(state.search_models, state.config.search_type)
if state.search_models:
state.content_index = configure_content(
state.content_index, state.config.content_type, state.search_models, regenerate=force or False, t=t
)
except Exception as e: except Exception as e:
logger.error(e) logger.error(e)
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))

View file

@ -12,10 +12,12 @@ from sentence_transformers import SentenceTransformer, util
from PIL import Image from PIL import Image
from tqdm import trange from tqdm import trange
import torch import torch
from khoj.utils import state
# Internal Packages # Internal Packages
from khoj.utils.helpers import get_absolute_path, get_from_dict, resolve_absolute_path, load_model, timer from khoj.utils.helpers import get_absolute_path, get_from_dict, resolve_absolute_path, load_model, timer
from khoj.utils.config import ImageSearchModel from khoj.utils.config import ImageContent, ImageSearchModel
from khoj.utils.models import BaseEncoder
from khoj.utils.rawconfig import ImageContentConfig, ImageSearchConfig, SearchResponse from khoj.utils.rawconfig import ImageContentConfig, ImageSearchConfig, SearchResponse
@ -40,7 +42,7 @@ def initialize_model(search_config: ImageSearchConfig):
model_type=search_config.encoder_type or SentenceTransformer, model_type=search_config.encoder_type or SentenceTransformer,
) )
return encoder return ImageSearchModel(encoder)
def extract_entries(image_directories): def extract_entries(image_directories):
@ -143,7 +145,9 @@ def extract_metadata(image_name):
return image_processed_metadata return image_processed_metadata
async def query(raw_query, count, model: ImageSearchModel, score_threshold: float = -math.inf): async def query(
raw_query, count, search_model: ImageSearchModel, content: ImageContent, score_threshold: float = -math.inf
):
# Set query to image content if query is of form file:/path/to/file.png # Set query to image content if query is of form file:/path/to/file.png
if raw_query.startswith("file:") and pathlib.Path(raw_query[5:]).is_file(): if raw_query.startswith("file:") and pathlib.Path(raw_query[5:]).is_file():
query_imagepath = resolve_absolute_path(pathlib.Path(raw_query[5:]), strict=True) query_imagepath = resolve_absolute_path(pathlib.Path(raw_query[5:]), strict=True)
@ -158,21 +162,21 @@ async def query(raw_query, count, model: ImageSearchModel, score_threshold: floa
# Now we encode the query (which can either be an image or a text string) # Now we encode the query (which can either be an image or a text string)
with timer("Query Encode Time", logger): with timer("Query Encode Time", logger):
query_embedding = model.image_encoder.encode([query], convert_to_tensor=True, show_progress_bar=False) query_embedding = search_model.image_encoder.encode([query], convert_to_tensor=True, show_progress_bar=False)
# Compute top_k ranked images based on cosine-similarity b/w query and all image embeddings. # Compute top_k ranked images based on cosine-similarity b/w query and all image embeddings.
with timer("Search Time", logger): with timer("Search Time", logger):
image_hits = { image_hits = {
result["corpus_id"]: {"image_score": result["score"], "score": result["score"]} result["corpus_id"]: {"image_score": result["score"], "score": result["score"]}
for result in util.semantic_search(query_embedding, model.image_embeddings, top_k=count)[0] for result in util.semantic_search(query_embedding, content.image_embeddings, top_k=count)[0]
} }
# Compute top_k ranked images based on cosine-similarity b/w query and all image metadata embeddings. # Compute top_k ranked images based on cosine-similarity b/w query and all image metadata embeddings.
if model.image_metadata_embeddings: if content.image_metadata_embeddings:
with timer("Metadata Search Time", logger): with timer("Metadata Search Time", logger):
metadata_hits = { metadata_hits = {
result["corpus_id"]: result["score"] result["corpus_id"]: result["score"]
for result in util.semantic_search(query_embedding, model.image_metadata_embeddings, top_k=count)[0] for result in util.semantic_search(query_embedding, content.image_metadata_embeddings, top_k=count)[0]
} }
# Sum metadata, image scores of the highest ranked images # Sum metadata, image scores of the highest ranked images
@ -239,10 +243,7 @@ def collate_results(hits, image_names, output_directory, image_files_url, count=
return results return results
def setup(config: ImageContentConfig, search_config: ImageSearchConfig, regenerate: bool) -> ImageSearchModel: def setup(config: ImageContentConfig, encoder: BaseEncoder, regenerate: bool) -> ImageContent:
# Initialize Model
encoder = initialize_model(search_config)
# Extract Entries # Extract Entries
absolute_image_files, filtered_image_files = set(), set() absolute_image_files, filtered_image_files = set(), set()
if config.input_directories: if config.input_directories:
@ -268,4 +269,4 @@ def setup(config: ImageContentConfig, search_config: ImageSearchConfig, regenera
use_xmp_metadata=config.use_xmp_metadata, use_xmp_metadata=config.use_xmp_metadata,
) )
return ImageSearchModel(all_image_files, image_embeddings, image_metadata_embeddings, encoder) return ImageContent(all_image_files, image_embeddings, image_metadata_embeddings)

View file

@ -13,7 +13,7 @@ from khoj.search_filter.base_filter import BaseFilter
# Internal Packages # Internal Packages
from khoj.utils import state from khoj.utils import state
from khoj.utils.helpers import get_absolute_path, is_none_or_empty, resolve_absolute_path, load_model, timer from khoj.utils.helpers import get_absolute_path, is_none_or_empty, resolve_absolute_path, load_model, timer
from khoj.utils.config import TextSearchModel from khoj.utils.config import TextContent, TextSearchModel
from khoj.utils.models import BaseEncoder from khoj.utils.models import BaseEncoder
from khoj.utils.rawconfig import SearchResponse, TextSearchConfig, TextConfigBase, Entry from khoj.utils.rawconfig import SearchResponse, TextSearchConfig, TextConfigBase, Entry
from khoj.utils.jsonl import load_jsonl from khoj.utils.jsonl import load_jsonl
@ -26,9 +26,6 @@ def initialize_model(search_config: TextSearchConfig):
"Initialize model for semantic search on text" "Initialize model for semantic search on text"
torch.set_num_threads(4) torch.set_num_threads(4)
# Number of entries we want to retrieve with the bi-encoder
top_k = 15
# If model directory is configured # If model directory is configured
if search_config.model_directory: if search_config.model_directory:
# Convert model directory to absolute path # Convert model directory to absolute path
@ -52,7 +49,7 @@ def initialize_model(search_config: TextSearchConfig):
device=f"{state.device}", device=f"{state.device}",
) )
return bi_encoder, cross_encoder, top_k return TextSearchModel(bi_encoder, cross_encoder)
def extract_entries(jsonl_file) -> List[Entry]: def extract_entries(jsonl_file) -> List[Entry]:
@ -67,7 +64,7 @@ def compute_embeddings(
new_entries = [] new_entries = []
# Load pre-computed embeddings from file if exists and update them if required # Load pre-computed embeddings from file if exists and update them if required
if embeddings_file.exists() and not regenerate: if embeddings_file.exists() and not regenerate:
corpus_embeddings = torch.load(get_absolute_path(embeddings_file), map_location=state.device) corpus_embeddings: torch.Tensor = torch.load(get_absolute_path(embeddings_file), map_location=state.device)
logger.debug(f"Loaded {len(corpus_embeddings)} text embeddings from {embeddings_file}") logger.debug(f"Loaded {len(corpus_embeddings)} text embeddings from {embeddings_file}")
# Encode any new entries in the corpus and update corpus embeddings # Encode any new entries in the corpus and update corpus embeddings
@ -104,17 +101,18 @@ def compute_embeddings(
async def query( async def query(
raw_query: str, raw_query: str,
model: TextSearchModel, search_model: TextSearchModel,
content: TextContent,
question_embedding: Union[torch.Tensor, None] = None, question_embedding: Union[torch.Tensor, None] = None,
rank_results: bool = False, rank_results: bool = False,
score_threshold: float = -math.inf, score_threshold: float = -math.inf,
dedupe: bool = True, dedupe: bool = True,
) -> Tuple[List[dict], List[Entry]]: ) -> Tuple[List[dict], List[Entry]]:
"Search for entries that answer the query" "Search for entries that answer the query"
query, entries, corpus_embeddings = raw_query, model.entries, model.corpus_embeddings query, entries, corpus_embeddings = raw_query, content.entries, content.corpus_embeddings
# Filter query, entries and embeddings before semantic search # Filter query, entries and embeddings before semantic search
query, entries, corpus_embeddings = apply_filters(query, entries, corpus_embeddings, model.filters) query, entries, corpus_embeddings = apply_filters(query, entries, corpus_embeddings, content.filters)
# If no entries left after filtering, return empty results # If no entries left after filtering, return empty results
if entries is None or len(entries) == 0: if entries is None or len(entries) == 0:
@ -127,18 +125,17 @@ async def query(
# Encode the query using the bi-encoder # Encode the query using the bi-encoder
if question_embedding is None: if question_embedding is None:
with timer("Query Encode Time", logger, state.device): with timer("Query Encode Time", logger, state.device):
question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True, device=state.device) question_embedding = search_model.bi_encoder.encode([query], convert_to_tensor=True, device=state.device)
question_embedding = util.normalize_embeddings(question_embedding) question_embedding = util.normalize_embeddings(question_embedding)
# Find relevant entries for the query # Find relevant entries for the query
top_k = min(len(entries), search_model.top_k or 10) # top_k hits can't be more than the total entries in corpus
with timer("Search Time", logger, state.device): with timer("Search Time", logger, state.device):
hits = util.semantic_search( hits = util.semantic_search(question_embedding, corpus_embeddings, top_k, score_function=util.dot_score)[0]
question_embedding, corpus_embeddings, top_k=model.top_k, score_function=util.dot_score
)[0]
# Score all retrieved entries using the cross-encoder # Score all retrieved entries using the cross-encoder
if rank_results: if rank_results and search_model.cross_encoder:
hits = cross_encoder_score(model.cross_encoder, query, entries, hits) hits = cross_encoder_score(search_model.cross_encoder, query, entries, hits)
# Filter results by score threshold # Filter results by score threshold
hits = [hit for hit in hits if hit.get("cross-score", hit.get("score")) >= score_threshold] hits = [hit for hit in hits if hit.get("cross-score", hit.get("score")) >= score_threshold]
@ -173,13 +170,10 @@ def collate_results(hits, entries: List[Entry], count=5) -> List[SearchResponse]
def setup( def setup(
text_to_jsonl: Type[TextToJsonl], text_to_jsonl: Type[TextToJsonl],
config: TextConfigBase, config: TextConfigBase,
search_config: TextSearchConfig, bi_encoder: BaseEncoder,
regenerate: bool, regenerate: bool,
filters: List[BaseFilter] = [], filters: List[BaseFilter] = [],
) -> TextSearchModel: ) -> TextContent:
# Initialize Model
bi_encoder, cross_encoder, top_k = initialize_model(search_config)
# Map notes in text files to (compressed) JSONL formatted file # Map notes in text files to (compressed) JSONL formatted file
config.compressed_jsonl = resolve_absolute_path(config.compressed_jsonl) config.compressed_jsonl = resolve_absolute_path(config.compressed_jsonl)
previous_entries = ( previous_entries = (
@ -192,7 +186,6 @@ def setup(
if is_none_or_empty(entries): if is_none_or_empty(entries):
config_params = ", ".join([f"{key}={value}" for key, value in config.dict().items()]) config_params = ", ".join([f"{key}={value}" for key, value in config.dict().items()])
raise ValueError(f"No valid entries found in specified files: {config_params}") raise ValueError(f"No valid entries found in specified files: {config_params}")
top_k = min(len(entries), top_k) # top_k hits can't be more than the total entries in corpus
# Compute or Load Embeddings # Compute or Load Embeddings
config.embeddings_file = resolve_absolute_path(config.embeddings_file) config.embeddings_file = resolve_absolute_path(config.embeddings_file)
@ -203,7 +196,7 @@ def setup(
for filter in filters: for filter in filters:
filter.load(entries, regenerate=regenerate) filter.load(entries, regenerate=regenerate)
return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, filters, top_k) return TextContent(entries, corpus_embeddings, filters)
def apply_filters( def apply_filters(

View file

@ -3,7 +3,7 @@ from __future__ import annotations # to avoid quoting type hints
from enum import Enum from enum import Enum
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Dict, List, Union from typing import TYPE_CHECKING, Dict, List, Optional, Union
# External Packages # External Packages
import torch import torch
@ -30,42 +30,48 @@ class ProcessorType(str, Enum):
Conversation = "conversation" Conversation = "conversation"
@dataclass
class TextContent:
entries: List[Entry]
corpus_embeddings: torch.Tensor
filters: List[BaseFilter]
@dataclass
class ImageContent:
image_names: List[str]
image_embeddings: torch.Tensor
image_metadata_embeddings: torch.Tensor
@dataclass
class TextSearchModel: class TextSearchModel:
def __init__( bi_encoder: BaseEncoder
self, cross_encoder: Optional[CrossEncoder] = None
entries: List[Entry], top_k: Optional[int] = 15
corpus_embeddings: torch.Tensor,
bi_encoder: BaseEncoder,
cross_encoder: CrossEncoder,
filters: List[BaseFilter],
top_k,
):
self.entries = entries
self.corpus_embeddings = corpus_embeddings
self.bi_encoder = bi_encoder
self.cross_encoder = cross_encoder
self.filters = filters
self.top_k = top_k
@dataclass
class ImageSearchModel: class ImageSearchModel:
def __init__(self, image_names, image_embeddings, image_metadata_embeddings, image_encoder: BaseEncoder): image_encoder: BaseEncoder
self.image_encoder = image_encoder
self.image_names = image_names
self.image_embeddings = image_embeddings @dataclass
self.image_metadata_embeddings = image_metadata_embeddings class ContentIndex:
self.image_encoder = image_encoder org: Optional[TextContent] = None
markdown: Optional[TextContent] = None
pdf: Optional[TextContent] = None
github: Optional[TextContent] = None
notion: Optional[TextContent] = None
image: Optional[ImageContent] = None
plugins: Optional[Dict[str, TextContent]] = None
@dataclass @dataclass
class SearchModels: class SearchModels:
org_search: Union[TextSearchModel, None] = None text_search: Optional[TextSearchModel] = None
markdown_search: Union[TextSearchModel, None] = None image_search: Optional[ImageSearchModel] = None
pdf_search: Union[TextSearchModel, None] = None plugin_search: Optional[Dict[str, TextSearchModel]] = None
image_search: Union[ImageSearchModel, None] = None
github_search: Union[TextSearchModel, None] = None
notion_search: Union[TextSearchModel, None] = None
plugin_search: Union[Dict[str, TextSearchModel], None] = None
class ConversationProcessorConfigModel: class ConversationProcessorConfigModel:

View file

@ -20,7 +20,7 @@ from khoj.utils import constants
if TYPE_CHECKING: if TYPE_CHECKING:
# External Packages # External Packages
from sentence_transformers import CrossEncoder from sentence_transformers import SentenceTransformer, CrossEncoder
# Internal Packages # Internal Packages
from khoj.utils.models import BaseEncoder from khoj.utils.models import BaseEncoder
@ -64,7 +64,9 @@ def merge_dicts(priority_dict: dict, default_dict: dict):
return merged_dict return merged_dict
def load_model(model_name: str, model_type, model_dir=None, device: str = None) -> Union[BaseEncoder, CrossEncoder]: def load_model(
model_name: str, model_type, model_dir=None, device: str = None
) -> Union[BaseEncoder, SentenceTransformer, CrossEncoder]:
"Load model from disk or huggingface" "Load model from disk or huggingface"
# Construct model path # Construct model path
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View file

@ -119,9 +119,9 @@ class AppConfig(ConfigBase):
class FullConfig(ConfigBase): class FullConfig(ConfigBase):
content_type: Optional[ContentConfig] content_type: Optional[ContentConfig] = None
search_type: Optional[SearchConfig] search_type: Optional[SearchConfig] = None
processor: Optional[ProcessorConfig] processor: Optional[ProcessorConfig] = None
app: Optional[AppConfig] = AppConfig(should_log_telemetry=True) app: Optional[AppConfig] = AppConfig(should_log_telemetry=True)

View file

@ -9,13 +9,14 @@ from pathlib import Path
# Internal Packages # Internal Packages
from khoj.utils import config as utils_config from khoj.utils import config as utils_config
from khoj.utils.config import SearchModels, ProcessorConfigModel from khoj.utils.config import ContentIndex, SearchModels, ProcessorConfigModel
from khoj.utils.helpers import LRU from khoj.utils.helpers import LRU
from khoj.utils.rawconfig import FullConfig from khoj.utils.rawconfig import FullConfig
# Application Global State # Application Global State
config = FullConfig() config = FullConfig()
model = SearchModels() search_models = SearchModels()
content_index = ContentIndex()
processor_config = ProcessorConfigModel() processor_config = ProcessorConfigModel()
config_file: Path = None config_file: Path = None
verbose: int = 0 verbose: int = 0

View file

@ -10,6 +10,7 @@ from khoj.main import app
from khoj.configure import configure_processor, configure_routes, configure_search_types from khoj.configure import configure_processor, configure_routes, configure_search_types
from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl
from khoj.search_type import image_search, text_search from khoj.search_type import image_search, text_search
from khoj.utils.config import ImageContent, SearchModels, TextContent
from khoj.utils.helpers import resolve_absolute_path from khoj.utils.helpers import resolve_absolute_path
from khoj.utils.rawconfig import ( from khoj.utils.rawconfig import (
ContentConfig, ContentConfig,
@ -41,35 +42,49 @@ def search_config() -> SearchConfig:
encoder="sentence-transformers/all-MiniLM-L6-v2", encoder="sentence-transformers/all-MiniLM-L6-v2",
cross_encoder="cross-encoder/ms-marco-MiniLM-L-6-v2", cross_encoder="cross-encoder/ms-marco-MiniLM-L-6-v2",
model_directory=model_dir / "symmetric/", model_directory=model_dir / "symmetric/",
encoder_type=None,
) )
search_config.asymmetric = TextSearchConfig( search_config.asymmetric = TextSearchConfig(
encoder="sentence-transformers/multi-qa-MiniLM-L6-cos-v1", encoder="sentence-transformers/multi-qa-MiniLM-L6-cos-v1",
cross_encoder="cross-encoder/ms-marco-MiniLM-L-6-v2", cross_encoder="cross-encoder/ms-marco-MiniLM-L-6-v2",
model_directory=model_dir / "asymmetric/", model_directory=model_dir / "asymmetric/",
encoder_type=None,
) )
search_config.image = ImageSearchConfig( search_config.image = ImageSearchConfig(
encoder="sentence-transformers/clip-ViT-B-32", model_directory=model_dir / "image/" encoder="sentence-transformers/clip-ViT-B-32",
model_directory=model_dir / "image/",
encoder_type=None,
) )
return search_config return search_config
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def content_config(tmp_path_factory, search_config: SearchConfig): def search_models(search_config: SearchConfig):
search_models = SearchModels()
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
search_models.image_search = image_search.initialize_model(search_config.image)
return search_models
@pytest.fixture(scope="session")
def content_config(tmp_path_factory, search_models: SearchModels, search_config: SearchConfig):
content_dir = tmp_path_factory.mktemp("content") content_dir = tmp_path_factory.mktemp("content")
# Generate Image Embeddings from Test Images # Generate Image Embeddings from Test Images
content_config = ContentConfig() content_config = ContentConfig()
content_config.image = ImageContentConfig( content_config.image = ImageContentConfig(
input_filter=None,
input_directories=["tests/data/images"], input_directories=["tests/data/images"],
embeddings_file=content_dir.joinpath("image_embeddings.pt"), embeddings_file=content_dir.joinpath("image_embeddings.pt"),
batch_size=1, batch_size=1,
use_xmp_metadata=False, use_xmp_metadata=False,
) )
image_search.setup(content_config.image, search_config.image, regenerate=False) image_search.setup(content_config.image, search_models.image_search.image_encoder, regenerate=False)
# Generate Notes Embeddings from Test Notes # Generate Notes Embeddings from Test Notes
content_config.org = TextContentConfig( content_config.org = TextContentConfig(
@ -80,7 +95,9 @@ def content_config(tmp_path_factory, search_config: SearchConfig):
) )
filters = [DateFilter(), WordFilter(), FileFilter()] filters = [DateFilter(), WordFilter(), FileFilter()]
text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters) text_search.setup(
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False, filters=filters
)
content_config.plugins = { content_config.plugins = {
"plugin1": TextContentConfig( "plugin1": TextContentConfig(
@ -106,7 +123,11 @@ def content_config(tmp_path_factory, search_config: SearchConfig):
filters = [DateFilter(), WordFilter(), FileFilter()] filters = [DateFilter(), WordFilter(), FileFilter()]
text_search.setup( text_search.setup(
JsonlToJsonl, content_config.plugins["plugin1"], search_config.asymmetric, regenerate=False, filters=filters JsonlToJsonl,
content_config.plugins["plugin1"],
search_models.text_search.bi_encoder,
regenerate=False,
filters=filters,
) )
return content_config return content_config
@ -157,8 +178,13 @@ def chat_client(md_content_config: ContentConfig, search_config: SearchConfig, p
# Index Markdown Content for Search # Index Markdown Content for Search
filters = [DateFilter(), WordFilter(), FileFilter()] filters = [DateFilter(), WordFilter(), FileFilter()]
state.model.markdown_search = text_search.setup( state.search_models.text_search = text_search.initialize_model(search_config.asymmetric)
MarkdownToJsonl, md_content_config.markdown, search_config.asymmetric, regenerate=False, filters=filters state.content_index.markdown = text_search.setup(
MarkdownToJsonl,
md_content_config.markdown,
state.search_models.text_search.bi_encoder,
regenerate=False,
filters=filters,
) )
# Initialize Processor from Config # Initialize Processor from Config
@ -175,8 +201,14 @@ def client(content_config: ContentConfig, search_config: SearchConfig, processor
state.SearchType = configure_search_types(state.config) state.SearchType = configure_search_types(state.config)
# These lines help us Mock the Search models for these search types # These lines help us Mock the Search models for these search types
state.model.org_search = {} state.search_models.text_search = text_search.initialize_model(search_config.asymmetric)
state.model.image_search = {} state.search_models.image_search = image_search.initialize_model(search_config.image)
state.content_index.org = text_search.setup(
OrgToJsonl, content_config.org, state.search_models.text_search.bi_encoder, regenerate=False
)
state.content_index.image = image_search.setup(
content_config.image, state.search_models.image_search, regenerate=False
)
configure_routes(app) configure_routes(app)
return TestClient(app) return TestClient(app)

View file

@ -11,7 +11,8 @@ from fastapi.testclient import TestClient
from khoj.main import app from khoj.main import app
from khoj.configure import configure_routes, configure_search_types from khoj.configure import configure_routes, configure_search_types
from khoj.utils import state from khoj.utils import state
from khoj.utils.state import model, config from khoj.utils.config import SearchModels
from khoj.utils.state import search_models, content_index, config
from khoj.search_type import text_search, image_search from khoj.search_type import text_search, image_search
from khoj.utils.rawconfig import ContentConfig, SearchConfig from khoj.utils.rawconfig import ContentConfig, SearchConfig
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
@ -143,7 +144,10 @@ def test_get_configured_types_with_no_content_config():
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
def test_image_search(client, content_config: ContentConfig, search_config: SearchConfig): def test_image_search(client, content_config: ContentConfig, search_config: SearchConfig):
# Arrange # Arrange
model.image_search = image_search.setup(content_config.image, search_config.image, regenerate=False) search_models.image_search = image_search.initialize_model(search_config.image)
content_index.image = image_search.setup(
content_config.image, search_models.image_search.image_encoder, regenerate=False
)
query_expected_image_pairs = [ query_expected_image_pairs = [
("kitten", "kitten_park.jpg"), ("kitten", "kitten_park.jpg"),
("a horse and dog on a leash", "horse_dog.jpg"), ("a horse and dog on a leash", "horse_dog.jpg"),
@ -166,7 +170,10 @@ def test_image_search(client, content_config: ContentConfig, search_config: Sear
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
def test_notes_search(client, content_config: ContentConfig, search_config: SearchConfig): def test_notes_search(client, content_config: ContentConfig, search_config: SearchConfig):
# Arrange # Arrange
model.org_search = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False) search_models.text_search = text_search.initialize_model(search_config.asymmetric)
content_index.org = text_search.setup(
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False
)
user_query = quote("How to git install application?") user_query = quote("How to git install application?")
# Act # Act
@ -183,8 +190,9 @@ def test_notes_search(client, content_config: ContentConfig, search_config: Sear
def test_notes_search_with_only_filters(client, content_config: ContentConfig, search_config: SearchConfig): def test_notes_search_with_only_filters(client, content_config: ContentConfig, search_config: SearchConfig):
# Arrange # Arrange
filters = [WordFilter(), FileFilter()] filters = [WordFilter(), FileFilter()]
model.org_search = text_search.setup( search_models.text_search = text_search.initialize_model(search_config.asymmetric)
OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters content_index.org = text_search.setup(
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False, filters=filters
) )
user_query = quote('+"Emacs" file:"*.org"') user_query = quote('+"Emacs" file:"*.org"')
@ -202,8 +210,9 @@ def test_notes_search_with_only_filters(client, content_config: ContentConfig, s
def test_notes_search_with_include_filter(client, content_config: ContentConfig, search_config: SearchConfig): def test_notes_search_with_include_filter(client, content_config: ContentConfig, search_config: SearchConfig):
# Arrange # Arrange
filters = [WordFilter()] filters = [WordFilter()]
model.org_search = text_search.setup( search_models.text_search = text_search.initialize_model(search_config.asymmetric)
OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters content_index.org = text_search.setup(
OrgToJsonl, content_config.org, search_models.text_search, regenerate=False, filters=filters
) )
user_query = quote('How to git install application? +"Emacs"') user_query = quote('How to git install application? +"Emacs"')
@ -221,8 +230,9 @@ def test_notes_search_with_include_filter(client, content_config: ContentConfig,
def test_notes_search_with_exclude_filter(client, content_config: ContentConfig, search_config: SearchConfig): def test_notes_search_with_exclude_filter(client, content_config: ContentConfig, search_config: SearchConfig):
# Arrange # Arrange
filters = [WordFilter()] filters = [WordFilter()]
model.org_search = text_search.setup( search_models.text_search = text_search.initialize_model(search_config.asymmetric)
OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters content_index.org = text_search.setup(
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False, filters=filters
) )
user_query = quote('How to git install application? -"clone"') user_query = quote('How to git install application? -"clone"')

View file

@ -5,9 +5,10 @@ from PIL import Image
# External Packages # External Packages
import pytest import pytest
from khoj.utils.config import SearchModels
# Internal Packages # Internal Packages
from khoj.utils.state import model from khoj.utils.state import content_index, search_models
from khoj.utils.constants import web_directory from khoj.utils.constants import web_directory
from khoj.search_type import image_search from khoj.search_type import image_search
from khoj.utils.helpers import resolve_absolute_path from khoj.utils.helpers import resolve_absolute_path
@ -16,10 +17,12 @@ from khoj.utils.rawconfig import ContentConfig, SearchConfig
# Test # Test
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
def test_image_search_setup(content_config: ContentConfig, search_config: SearchConfig): def test_image_search_setup(content_config: ContentConfig, search_models: SearchModels):
# Act # Act
# Regenerate image search embeddings during image setup # Regenerate image search embeddings during image setup
image_search_model = image_search.setup(content_config.image, search_config.image, regenerate=True) image_search_model = image_search.setup(
content_config.image, search_models.image_search.image_encoder, regenerate=True
)
# Assert # Assert
assert len(image_search_model.image_names) == 3 assert len(image_search_model.image_names) == 3
@ -54,8 +57,11 @@ def test_image_metadata(content_config: ContentConfig):
@pytest.mark.anyio @pytest.mark.anyio
async def test_image_search(content_config: ContentConfig, search_config: SearchConfig): async def test_image_search(content_config: ContentConfig, search_config: SearchConfig):
# Arrange # Arrange
search_models.image_search = image_search.initialize_model(search_config.image)
content_index.image = image_search.setup(
content_config.image, search_models.image_search.image_encoder, regenerate=False
)
output_directory = resolve_absolute_path(web_directory) output_directory = resolve_absolute_path(web_directory)
model.image_search = image_search.setup(content_config.image, search_config.image, regenerate=False)
query_expected_image_pairs = [ query_expected_image_pairs = [
("kitten", "kitten_park.jpg"), ("kitten", "kitten_park.jpg"),
("horse and dog in a farm", "horse_dog.jpg"), ("horse and dog in a farm", "horse_dog.jpg"),
@ -64,11 +70,13 @@ async def test_image_search(content_config: ContentConfig, search_config: Search
# Act # Act
for query, expected_image_name in query_expected_image_pairs: for query, expected_image_name in query_expected_image_pairs:
hits = await image_search.query(query, count=1, model=model.image_search) hits = await image_search.query(
query, count=1, search_model=search_models.image_search, content=content_index.image
)
results = image_search.collate_results( results = image_search.collate_results(
hits, hits,
model.image_search.image_names, content_index.image.image_names,
output_directory=output_directory, output_directory=output_directory,
image_files_url="/static/images", image_files_url="/static/images",
count=1, count=1,
@ -90,7 +98,10 @@ async def test_image_search(content_config: ContentConfig, search_config: Search
@pytest.mark.anyio @pytest.mark.anyio
async def test_image_search_query_truncated(content_config: ContentConfig, search_config: SearchConfig, caplog): async def test_image_search_query_truncated(content_config: ContentConfig, search_config: SearchConfig, caplog):
# Arrange # Arrange
model.image_search = image_search.setup(content_config.image, search_config.image, regenerate=False) search_models.image_search = image_search.initialize_model(search_config.image)
content_index.image = image_search.setup(
content_config.image, search_models.image_search.image_encoder, regenerate=False
)
max_words_supported = 10 max_words_supported = 10
query = " ".join(["hello"] * 100) query = " ".join(["hello"] * 100)
truncated_query = " ".join(["hello"] * max_words_supported) truncated_query = " ".join(["hello"] * max_words_supported)
@ -98,7 +109,9 @@ async def test_image_search_query_truncated(content_config: ContentConfig, searc
# Act # Act
try: try:
with caplog.at_level(logging.INFO, logger="khoj.search_type.image_search"): with caplog.at_level(logging.INFO, logger="khoj.search_type.image_search"):
await image_search.query(query, count=1, model=model.image_search) await image_search.query(
query, count=1, search_model=search_models.image_search, content=content_index.image
)
# Assert # Assert
except RuntimeError as e: except RuntimeError as e:
if "The size of tensor a (102) must match the size of tensor b (77)" in str(e): if "The size of tensor a (102) must match the size of tensor b (77)" in str(e):
@ -110,8 +123,11 @@ async def test_image_search_query_truncated(content_config: ContentConfig, searc
@pytest.mark.anyio @pytest.mark.anyio
async def test_image_search_by_filepath(content_config: ContentConfig, search_config: SearchConfig, caplog): async def test_image_search_by_filepath(content_config: ContentConfig, search_config: SearchConfig, caplog):
# Arrange # Arrange
search_models.image_search = image_search.initialize_model(search_config.image)
content_index.image = image_search.setup(
content_config.image, search_models.image_search.image_encoder, regenerate=False
)
output_directory = resolve_absolute_path(web_directory) output_directory = resolve_absolute_path(web_directory)
model.image_search = image_search.setup(content_config.image, search_config.image, regenerate=False)
image_directory = content_config.image.input_directories[0] image_directory = content_config.image.input_directories[0]
query = f"file:{image_directory.joinpath('kitten_park.jpg')}" query = f"file:{image_directory.joinpath('kitten_park.jpg')}"
@ -119,11 +135,13 @@ async def test_image_search_by_filepath(content_config: ContentConfig, search_co
# Act # Act
with caplog.at_level(logging.INFO, logger="khoj.search_type.image_search"): with caplog.at_level(logging.INFO, logger="khoj.search_type.image_search"):
hits = await image_search.query(query, count=1, model=model.image_search) hits = await image_search.query(
query, count=1, search_model=search_models.image_search, content=content_index.image
)
results = image_search.collate_results( results = image_search.collate_results(
hits, hits,
model.image_search.image_names, content_index.image.image_names,
output_directory=output_directory, output_directory=output_directory,
image_files_url="/static/images", image_files_url="/static/images",
count=1, count=1,

View file

@ -5,9 +5,10 @@ import os
# External Packages # External Packages
import pytest import pytest
from khoj.utils.config import SearchModels
# Internal Packages # Internal Packages
from khoj.utils.state import model from khoj.utils.state import content_index, search_models
from khoj.search_type import text_search from khoj.search_type import text_search
from khoj.utils.rawconfig import ContentConfig, SearchConfig, TextContentConfig from khoj.utils.rawconfig import ContentConfig, SearchConfig, TextContentConfig
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
@ -41,10 +42,12 @@ def test_asymmetric_setup_with_empty_file_raises_error(
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
def test_asymmetric_setup(content_config: ContentConfig, search_config: SearchConfig): def test_asymmetric_setup(content_config: ContentConfig, search_models: SearchModels):
# Act # Act
# Regenerate notes embeddings during asymmetric setup # Regenerate notes embeddings during asymmetric setup
notes_model = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True) notes_model = text_search.setup(
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True
)
# Assert # Assert
assert len(notes_model.entries) == 10 assert len(notes_model.entries) == 10
@ -52,18 +55,18 @@ def test_asymmetric_setup(content_config: ContentConfig, search_config: SearchCo
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
def test_text_content_index_only_updates_on_changes(content_config: ContentConfig, search_config: SearchConfig, caplog): def test_text_content_index_only_updates_on_changes(content_config: ContentConfig, search_models: SearchModels, caplog):
# Arrange # Arrange
caplog.set_level(logging.INFO, logger="khoj") caplog.set_level(logging.INFO, logger="khoj")
# Act # Act
# Generate initial notes embeddings during asymmetric setup # Generate initial notes embeddings during asymmetric setup
text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True) text_search.setup(OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True)
initial_logs = caplog.text initial_logs = caplog.text
caplog.clear() # Clear logs caplog.clear() # Clear logs
# Run asymmetric setup again with no changes to data source. Ensure index is not updated # Run asymmetric setup again with no changes to data source. Ensure index is not updated
text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False) text_search.setup(OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False)
final_logs = caplog.text final_logs = caplog.text
# Assert # Assert
@ -75,11 +78,16 @@ def test_text_content_index_only_updates_on_changes(content_config: ContentConfi
@pytest.mark.anyio @pytest.mark.anyio
async def test_asymmetric_search(content_config: ContentConfig, search_config: SearchConfig): async def test_asymmetric_search(content_config: ContentConfig, search_config: SearchConfig):
# Arrange # Arrange
model.notes_search = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True) search_models.text_search = text_search.initialize_model(search_config.asymmetric)
content_index.org = text_search.setup(
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True
)
query = "How to git install application?" query = "How to git install application?"
# Act # Act
hits, entries = await text_search.query(query, model=model.notes_search, rank_results=True) hits, entries = await text_search.query(
query, search_model=search_models.text_search, content=content_index.org, rank_results=True
)
results = text_search.collate_results(hits, entries, count=1) results = text_search.collate_results(hits, entries, count=1)
@ -90,7 +98,7 @@ async def test_asymmetric_search(content_config: ContentConfig, search_config: S
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: TextContentConfig, search_config: SearchConfig): def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: TextContentConfig, search_models: SearchModels):
# Arrange # Arrange
# Insert org-mode entry with size exceeding max token limit to new org file # Insert org-mode entry with size exceeding max token limit to new org file
max_tokens = 256 max_tokens = 256
@ -103,7 +111,7 @@ def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: TextContent
# Act # Act
# reload embeddings, entries, notes model after adding new org-mode file # reload embeddings, entries, notes model after adding new org-mode file
initial_notes_model = text_search.setup( initial_notes_model = text_search.setup(
OrgToJsonl, org_config_with_only_new_file, search_config.asymmetric, regenerate=False OrgToJsonl, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=False
) )
# Assert # Assert
@ -113,9 +121,11 @@ def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: TextContent
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
def test_asymmetric_reload(content_config: ContentConfig, search_config: SearchConfig, new_org_file: Path): def test_asymmetric_reload(content_config: ContentConfig, search_models: SearchModels, new_org_file: Path):
# Arrange # Arrange
initial_notes_model = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True) initial_notes_model = text_search.setup(
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True
)
assert len(initial_notes_model.entries) == 10 assert len(initial_notes_model.entries) == 10
assert len(initial_notes_model.corpus_embeddings) == 10 assert len(initial_notes_model.corpus_embeddings) == 10
@ -127,12 +137,14 @@ def test_asymmetric_reload(content_config: ContentConfig, search_config: SearchC
# regenerate notes jsonl, model embeddings and model to include entry from new file # regenerate notes jsonl, model embeddings and model to include entry from new file
regenerated_notes_model = text_search.setup( regenerated_notes_model = text_search.setup(
OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True
) )
# Act # Act
# reload embeddings, entries, notes model from previously generated notes jsonl and model embeddings files # reload embeddings, entries, notes model from previously generated notes jsonl and model embeddings files
initial_notes_model = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False) initial_notes_model = text_search.setup(
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False
)
# Assert # Assert
assert len(regenerated_notes_model.entries) == 11 assert len(regenerated_notes_model.entries) == 11
@ -149,9 +161,11 @@ def test_asymmetric_reload(content_config: ContentConfig, search_config: SearchC
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
def test_incremental_update(content_config: ContentConfig, search_config: SearchConfig, new_org_file: Path): def test_incremental_update(content_config: ContentConfig, search_models: SearchModels, new_org_file: Path):
# Arrange # Arrange
initial_notes_model = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True) initial_notes_model = text_search.setup(
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True
)
assert len(initial_notes_model.entries) == 10 assert len(initial_notes_model.entries) == 10
assert len(initial_notes_model.corpus_embeddings) == 10 assert len(initial_notes_model.corpus_embeddings) == 10
@ -163,7 +177,9 @@ def test_incremental_update(content_config: ContentConfig, search_config: Search
# Act # Act
# update embeddings, entries with the newly added note # update embeddings, entries with the newly added note
content_config.org.input_files = [f"{new_org_file}"] content_config.org.input_files = [f"{new_org_file}"]
initial_notes_model = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False) initial_notes_model = text_search.setup(
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False
)
# Assert # Assert
# verify new entry added in updated embeddings, entries # verify new entry added in updated embeddings, entries
@ -177,10 +193,12 @@ def test_incremental_update(content_config: ContentConfig, search_config: Search
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
@pytest.mark.skipif(os.getenv("GITHUB_PAT_TOKEN") is None, reason="GITHUB_PAT_TOKEN not set") @pytest.mark.skipif(os.getenv("GITHUB_PAT_TOKEN") is None, reason="GITHUB_PAT_TOKEN not set")
def test_asymmetric_setup_github(content_config: ContentConfig, search_config: SearchConfig): def test_asymmetric_setup_github(content_config: ContentConfig, search_models: SearchModels):
# Act # Act
# Regenerate github embeddings to test asymmetric setup without caching # Regenerate github embeddings to test asymmetric setup without caching
github_model = text_search.setup(GithubToJsonl, content_config.github, search_config.asymmetric, regenerate=True) github_model = text_search.setup(
GithubToJsonl, content_config.github, search_models.text_search.bi_encoder, regenerate=True
)
# Assert # Assert
assert len(github_model.entries) > 1 assert len(github_model.entries) > 1