Merge branch 'master' of github.com:debanjum/khoj

This commit is contained in:
sabaimran 2023-07-14 22:28:05 -07:00
commit ba47f2ab39
12 changed files with 361 additions and 196 deletions

View file

@ -20,9 +20,15 @@ from khoj.processor.github.github_to_jsonl import GithubToJsonl
from khoj.processor.notion.notion_to_jsonl import NotionToJsonl
from khoj.search_type import image_search, text_search
from khoj.utils import constants, state
from khoj.utils.config import SearchType, SearchModels, ProcessorConfigModel, ConversationProcessorConfigModel
from khoj.utils.config import (
ContentIndex,
SearchType,
SearchModels,
ProcessorConfigModel,
ConversationProcessorConfigModel,
)
from khoj.utils.helpers import LRU, resolve_absolute_path, merge_dicts
from khoj.utils.rawconfig import FullConfig, ProcessorConfig
from khoj.utils.rawconfig import FullConfig, ProcessorConfig, SearchConfig, ContentConfig
from khoj.search_filter.date_filter import DateFilter
from khoj.search_filter.word_filter import WordFilter
from khoj.search_filter.file_filter import FileFilter
@ -49,11 +55,27 @@ def configure_server(args, required=False):
# Initialize Processor from Config
state.processor_config = configure_processor(args.config.processor)
# Initialize the search type and model from Config
state.search_index_lock.acquire()
state.SearchType = configure_search_types(state.config)
state.model = configure_search(state.model, state.config, args.regenerate)
state.search_index_lock.release()
# Initialize Search Models from Config
try:
state.search_index_lock.acquire()
state.SearchType = configure_search_types(state.config)
state.search_models = configure_search(state.search_models, state.config.search_type)
except Exception as e:
logger.error(f"🚨 Error configuring search models on app load: {e}")
finally:
state.search_index_lock.release()
# Initialize Content from Config
if state.search_models:
try:
state.search_index_lock.acquire()
state.content_index = configure_content(
state.content_index, state.config.content_type, state.search_models, args.regenerate
)
except Exception as e:
logger.error(f"🚨 Error configuring content index on app load: {e}")
finally:
state.search_index_lock.release()
def configure_routes(app):
@ -72,10 +94,16 @@ if not state.demo:
@schedule.repeat(schedule.every(61).minutes)
def update_search_index():
state.search_index_lock.acquire()
state.model = configure_search(state.model, state.config, regenerate=False)
state.search_index_lock.release()
logger.info("📬 Search index updated via Scheduler")
try:
state.search_index_lock.acquire()
state.content_index = configure_content(
state.content_index, state.config.content_type, state.search_models, regenerate=False
)
logger.info("📬 Content index updated via Scheduler")
except Exception as e:
logger.error(f"🚨 Error updating content index via Scheduler: {e}")
finally:
state.search_index_lock.release()
def configure_search_types(config: FullConfig):
@ -90,94 +118,116 @@ def configure_search_types(config: FullConfig):
return Enum("SearchType", merge_dicts(core_search_types, plugin_search_types))
def configure_search(model: SearchModels, config: FullConfig, regenerate: bool, t: Optional[state.SearchType] = None):
if config is None or config.content_type is None or config.search_type is None:
logger.warning("🚨 No Content or Search type is configured.")
return
def configure_search(search_models: SearchModels, search_config: SearchConfig) -> Optional[SearchModels]:
# Run Validation Checks
if search_config is None:
logger.warning("🚨 No Search type is configured.")
return None
if search_models is None:
search_models = SearchModels()
if model is None:
model = SearchModels()
# Initialize Search Models
if search_config.asymmetric:
logger.info("🔍 📜 Setting up text search model")
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
if search_config.image:
logger.info("🔍 🌄 Setting up image search model")
search_models.image_search = image_search.initialize_model(search_config.image)
return search_models
def configure_content(
content_index: Optional[ContentIndex],
content_config: Optional[ContentConfig],
search_models: SearchModels,
regenerate: bool,
t: Optional[state.SearchType] = None,
) -> Optional[ContentIndex]:
# Run Validation Checks
if content_config is None:
logger.warning("🚨 No Content type is configured.")
return None
if content_index is None:
content_index = ContentIndex()
try:
# Initialize Org Notes Search
if (t == state.SearchType.Org or t == None) and config.content_type.org and config.search_type.asymmetric:
if (t == state.SearchType.Org or t == None) and content_config.org and search_models.text_search:
logger.info("🦄 Setting up search for orgmode notes")
# Extract Entries, Generate Notes Embeddings
model.org_search = text_search.setup(
content_index.org = text_search.setup(
OrgToJsonl,
config.content_type.org,
search_config=config.search_type.asymmetric,
content_config.org,
search_models.text_search.bi_encoder,
regenerate=regenerate,
filters=[DateFilter(), WordFilter(), FileFilter()],
)
# Initialize Markdown Search
if (
(t == state.SearchType.Markdown or t == None)
and config.content_type.markdown
and config.search_type.asymmetric
):
if (t == state.SearchType.Markdown or t == None) and content_config.markdown and search_models.text_search:
logger.info("💎 Setting up search for markdown notes")
# Extract Entries, Generate Markdown Embeddings
model.markdown_search = text_search.setup(
content_index.markdown = text_search.setup(
MarkdownToJsonl,
config.content_type.markdown,
search_config=config.search_type.asymmetric,
content_config.markdown,
search_models.text_search.bi_encoder,
regenerate=regenerate,
filters=[DateFilter(), WordFilter(), FileFilter()],
)
# Initialize PDF Search
if (t == state.SearchType.Pdf or t == None) and config.content_type.pdf and config.search_type.asymmetric:
if (t == state.SearchType.Pdf or t == None) and content_config.pdf and search_models.text_search:
logger.info("🖨️ Setting up search for pdf")
# Extract Entries, Generate PDF Embeddings
model.pdf_search = text_search.setup(
content_index.pdf = text_search.setup(
PdfToJsonl,
config.content_type.pdf,
search_config=config.search_type.asymmetric,
content_config.pdf,
search_models.text_search.bi_encoder,
regenerate=regenerate,
filters=[DateFilter(), WordFilter(), FileFilter()],
)
# Initialize Image Search
if (t == state.SearchType.Image or t == None) and config.content_type.image and config.search_type.image:
if (t == state.SearchType.Image or t == None) and content_config.image and search_models.image_search:
logger.info("🌄 Setting up search for images")
# Extract Entries, Generate Image Embeddings
model.image_search = image_search.setup(
config.content_type.image, search_config=config.search_type.image, regenerate=regenerate
content_index.image = image_search.setup(
content_config.image, search_models.image_search.image_encoder, regenerate=regenerate
)
if (t == state.SearchType.Github or t == None) and config.content_type.github and config.search_type.asymmetric:
if (t == state.SearchType.Github or t == None) and content_config.github and search_models.text_search:
logger.info("🐙 Setting up search for github")
# Extract Entries, Generate Github Embeddings
model.github_search = text_search.setup(
content_index.github = text_search.setup(
GithubToJsonl,
config.content_type.github,
search_config=config.search_type.asymmetric,
content_config.github,
search_models.text_search.bi_encoder,
regenerate=regenerate,
filters=[DateFilter(), WordFilter(), FileFilter()],
)
# Initialize External Plugin Search
if (t == None or t in state.SearchType) and config.content_type.plugins:
if (t == None or t in state.SearchType) and content_config.plugins and search_models.text_search:
logger.info("🔌 Setting up search for plugins")
model.plugin_search = {}
for plugin_type, plugin_config in config.content_type.plugins.items():
model.plugin_search[plugin_type] = text_search.setup(
content_index.plugins = {}
for plugin_type, plugin_config in content_config.plugins.items():
content_index.plugins[plugin_type] = text_search.setup(
JsonlToJsonl,
plugin_config,
search_config=config.search_type.asymmetric,
search_models.text_search.bi_encoder,
regenerate=regenerate,
filters=[DateFilter(), WordFilter(), FileFilter()],
)
# Initialize Notion Search
if (t == None or t in state.SearchType) and config.content_type.notion:
if (t == None or t in state.SearchType) and content_config.notion and search_models.text_search:
logger.info("🔌 Setting up search for notion")
model.notion_search = text_search.setup(
content_index.notion = text_search.setup(
NotionToJsonl,
config.content_type.notion,
search_config=config.search_type.asymmetric,
content_config.notion,
search_models.text_search.bi_encoder,
regenerate=regenerate,
filters=[DateFilter(), WordFilter(), FileFilter()],
)
@ -189,7 +239,7 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool,
# Invalidate Query Cache
state.query_cache = LRU()
return model
return content_index
def configure_processor(processor_config: ProcessorConfig):

View file

@ -12,7 +12,7 @@ from fastapi import APIRouter, HTTPException, Header, Request
from sentence_transformers import util
# Internal Packages
from khoj.configure import configure_processor, configure_search
from khoj.configure import configure_content, configure_processor, configure_search
from khoj.search_type import image_search, text_search
from khoj.search_filter.date_filter import DateFilter
from khoj.search_filter.file_filter import FileFilter
@ -163,17 +163,17 @@ if not state.demo:
state.config.content_type[content_type] = None
if content_type == "github":
state.model.github_search = None
state.content_index.github = None
elif content_type == "notion":
state.model.notion_search = None
state.content_index.notion = None
elif content_type == "plugins":
state.model.plugin_search = None
state.content_index.plugins = None
elif content_type == "pdf":
state.model.pdf_search = None
state.content_index.pdf = None
elif content_type == "markdown":
state.model.markdown_search = None
state.content_index.markdown = None
elif content_type == "org":
state.model.org_search = None
state.content_index.org = None
try:
save_config_to_file_updated_state()
@ -280,7 +280,7 @@ def get_config_types():
for search_type in SearchType
if (
search_type.value in configured_content_types
and getattr(state.model, f"{search_type.value}_search") is not None
and getattr(state.content_index, search_type.value) is not None
)
or ("plugins" in configured_content_types and search_type.name in configured_content_types["plugins"])
or search_type == SearchType.All
@ -308,7 +308,7 @@ async def search(
if q is None or q == "":
logger.warning(f"No query param (q) passed in API call to initiate search")
return results
if not state.model or not any(state.model.__dict__.values()):
if not state.search_models or not any(state.search_models.__dict__.values()):
logger.warning(f"No search models loaded. Configure a search model before initiating search")
return results
@ -332,7 +332,7 @@ async def search(
encoded_asymmetric_query = None
if t == SearchType.All or t != SearchType.Image:
text_search_models: List[TextSearchModel] = [
model for model in state.model.__dict__.values() if isinstance(model, TextSearchModel)
model for model in state.search_models.__dict__.values() if isinstance(model, TextSearchModel)
]
if text_search_models:
with timer("Encoding query took", logger=logger):
@ -345,13 +345,14 @@ async def search(
)
with concurrent.futures.ThreadPoolExecutor() as executor:
if (t == SearchType.Org or t == SearchType.All) and state.model.org_search:
if (t == SearchType.Org or t == SearchType.All) and state.content_index.org and state.search_models.text_search:
# query org-mode notes
search_futures += [
executor.submit(
text_search.query,
user_query,
state.model.org_search,
state.search_models.text_search,
state.content_index.org,
question_embedding=encoded_asymmetric_query,
rank_results=r or False,
score_threshold=score_threshold,
@ -359,13 +360,18 @@ async def search(
)
]
if (t == SearchType.Markdown or t == SearchType.All) and state.model.markdown_search:
if (
(t == SearchType.Markdown or t == SearchType.All)
and state.content_index.markdown
and state.search_models.text_search
):
# query markdown notes
search_futures += [
executor.submit(
text_search.query,
user_query,
state.model.markdown_search,
state.search_models.text_search,
state.content_index.markdown,
question_embedding=encoded_asymmetric_query,
rank_results=r or False,
score_threshold=score_threshold,
@ -373,13 +379,18 @@ async def search(
)
]
if (t == SearchType.Github or t == SearchType.All) and state.model.github_search:
if (
(t == SearchType.Github or t == SearchType.All)
and state.content_index.github
and state.search_models.text_search
):
# query github issues
search_futures += [
executor.submit(
text_search.query,
user_query,
state.model.github_search,
state.search_models.text_search,
state.content_index.github,
question_embedding=encoded_asymmetric_query,
rank_results=r or False,
score_threshold=score_threshold,
@ -387,13 +398,14 @@ async def search(
)
]
if (t == SearchType.Pdf or t == SearchType.All) and state.model.pdf_search:
if (t == SearchType.Pdf or t == SearchType.All) and state.content_index.pdf and state.search_models.text_search:
# query pdf files
search_futures += [
executor.submit(
text_search.query,
user_query,
state.model.pdf_search,
state.search_models.text_search,
state.content_index.pdf,
question_embedding=encoded_asymmetric_query,
rank_results=r or False,
score_threshold=score_threshold,
@ -401,26 +413,38 @@ async def search(
)
]
if (t == SearchType.Image) and state.model.image_search:
if (t == SearchType.Image) and state.content_index.image and state.search_models.image_search:
# query images
search_futures += [
executor.submit(
image_search.query,
user_query,
results_count,
state.model.image_search,
state.search_models.image_search,
state.content_index.image,
score_threshold=score_threshold,
)
]
if (t == SearchType.All or t in SearchType) and state.model.plugin_search:
if (
(t == SearchType.All or t in SearchType)
and state.content_index.plugins
and state.search_models.plugin_search
):
# query specified plugin type
# Get plugin content, search model for specified search type, or the first one if none specified
plugin_search = state.search_models.plugin_search.get(t.value) or next(
iter(state.search_models.plugin_search.values())
)
plugin_content = state.content_index.plugins.get(t.value) or next(
iter(state.content_index.plugins.values())
)
search_futures += [
executor.submit(
text_search.query,
user_query,
# Get plugin search model for specified search type, or the first one if none specified
state.model.plugin_search.get(t.value) or next(iter(state.model.plugin_search.values())),
plugin_search,
plugin_content,
question_embedding=encoded_asymmetric_query,
rank_results=r or False,
score_threshold=score_threshold,
@ -428,13 +452,18 @@ async def search(
)
]
if (t == SearchType.Notion or t == SearchType.All) and state.model.notion_search:
if (
(t == SearchType.Notion or t == SearchType.All)
and state.content_index.notion
and state.search_models.text_search
):
# query notion pages
search_futures += [
executor.submit(
text_search.query,
user_query,
state.model.notion_search,
state.search_models.text_search,
state.content_index.notion,
question_embedding=encoded_asymmetric_query,
rank_results=r or False,
score_threshold=score_threshold,
@ -445,13 +474,13 @@ async def search(
# Query across each requested content types in parallel
with timer("Query took", logger):
for search_future in concurrent.futures.as_completed(search_futures):
if t == SearchType.Image:
if t == SearchType.Image and state.content_index.image:
hits = await search_future.result()
output_directory = constants.web_directory / "images"
# Collate results
results += image_search.collate_results(
hits,
image_names=state.model.image_search.image_names,
image_names=state.content_index.image.image_names,
output_directory=output_directory,
image_files_url="/static/images",
count=results_count,
@ -498,7 +527,12 @@ def update(
try:
state.search_index_lock.acquire()
try:
state.model = configure_search(state.model, state.config, regenerate=force or False, t=t)
if state.config and state.config.search_type:
state.search_models = configure_search(state.search_models, state.config.search_type)
if state.search_models:
state.content_index = configure_content(
state.content_index, state.config.content_type, state.search_models, regenerate=force or False, t=t
)
except Exception as e:
logger.error(e)
raise HTTPException(status_code=500, detail=str(e))

View file

@ -12,10 +12,12 @@ from sentence_transformers import SentenceTransformer, util
from PIL import Image
from tqdm import trange
import torch
from khoj.utils import state
# Internal Packages
from khoj.utils.helpers import get_absolute_path, get_from_dict, resolve_absolute_path, load_model, timer
from khoj.utils.config import ImageSearchModel
from khoj.utils.config import ImageContent, ImageSearchModel
from khoj.utils.models import BaseEncoder
from khoj.utils.rawconfig import ImageContentConfig, ImageSearchConfig, SearchResponse
@ -40,7 +42,7 @@ def initialize_model(search_config: ImageSearchConfig):
model_type=search_config.encoder_type or SentenceTransformer,
)
return encoder
return ImageSearchModel(encoder)
def extract_entries(image_directories):
@ -143,7 +145,9 @@ def extract_metadata(image_name):
return image_processed_metadata
async def query(raw_query, count, model: ImageSearchModel, score_threshold: float = -math.inf):
async def query(
raw_query, count, search_model: ImageSearchModel, content: ImageContent, score_threshold: float = -math.inf
):
# Set query to image content if query is of form file:/path/to/file.png
if raw_query.startswith("file:") and pathlib.Path(raw_query[5:]).is_file():
query_imagepath = resolve_absolute_path(pathlib.Path(raw_query[5:]), strict=True)
@ -158,21 +162,21 @@ async def query(raw_query, count, model: ImageSearchModel, score_threshold: floa
# Now we encode the query (which can either be an image or a text string)
with timer("Query Encode Time", logger):
query_embedding = model.image_encoder.encode([query], convert_to_tensor=True, show_progress_bar=False)
query_embedding = search_model.image_encoder.encode([query], convert_to_tensor=True, show_progress_bar=False)
# Compute top_k ranked images based on cosine-similarity b/w query and all image embeddings.
with timer("Search Time", logger):
image_hits = {
result["corpus_id"]: {"image_score": result["score"], "score": result["score"]}
for result in util.semantic_search(query_embedding, model.image_embeddings, top_k=count)[0]
for result in util.semantic_search(query_embedding, content.image_embeddings, top_k=count)[0]
}
# Compute top_k ranked images based on cosine-similarity b/w query and all image metadata embeddings.
if model.image_metadata_embeddings:
if content.image_metadata_embeddings:
with timer("Metadata Search Time", logger):
metadata_hits = {
result["corpus_id"]: result["score"]
for result in util.semantic_search(query_embedding, model.image_metadata_embeddings, top_k=count)[0]
for result in util.semantic_search(query_embedding, content.image_metadata_embeddings, top_k=count)[0]
}
# Sum metadata, image scores of the highest ranked images
@ -239,10 +243,7 @@ def collate_results(hits, image_names, output_directory, image_files_url, count=
return results
def setup(config: ImageContentConfig, search_config: ImageSearchConfig, regenerate: bool) -> ImageSearchModel:
# Initialize Model
encoder = initialize_model(search_config)
def setup(config: ImageContentConfig, encoder: BaseEncoder, regenerate: bool) -> ImageContent:
# Extract Entries
absolute_image_files, filtered_image_files = set(), set()
if config.input_directories:
@ -268,4 +269,4 @@ def setup(config: ImageContentConfig, search_config: ImageSearchConfig, regenera
use_xmp_metadata=config.use_xmp_metadata,
)
return ImageSearchModel(all_image_files, image_embeddings, image_metadata_embeddings, encoder)
return ImageContent(all_image_files, image_embeddings, image_metadata_embeddings)

View file

@ -13,7 +13,7 @@ from khoj.search_filter.base_filter import BaseFilter
# Internal Packages
from khoj.utils import state
from khoj.utils.helpers import get_absolute_path, is_none_or_empty, resolve_absolute_path, load_model, timer
from khoj.utils.config import TextSearchModel
from khoj.utils.config import TextContent, TextSearchModel
from khoj.utils.models import BaseEncoder
from khoj.utils.rawconfig import SearchResponse, TextSearchConfig, TextConfigBase, Entry
from khoj.utils.jsonl import load_jsonl
@ -26,9 +26,6 @@ def initialize_model(search_config: TextSearchConfig):
"Initialize model for semantic search on text"
torch.set_num_threads(4)
# Number of entries we want to retrieve with the bi-encoder
top_k = 15
# If model directory is configured
if search_config.model_directory:
# Convert model directory to absolute path
@ -52,7 +49,7 @@ def initialize_model(search_config: TextSearchConfig):
device=f"{state.device}",
)
return bi_encoder, cross_encoder, top_k
return TextSearchModel(bi_encoder, cross_encoder)
def extract_entries(jsonl_file) -> List[Entry]:
@ -67,7 +64,7 @@ def compute_embeddings(
new_entries = []
# Load pre-computed embeddings from file if exists and update them if required
if embeddings_file.exists() and not regenerate:
corpus_embeddings = torch.load(get_absolute_path(embeddings_file), map_location=state.device)
corpus_embeddings: torch.Tensor = torch.load(get_absolute_path(embeddings_file), map_location=state.device)
logger.debug(f"Loaded {len(corpus_embeddings)} text embeddings from {embeddings_file}")
# Encode any new entries in the corpus and update corpus embeddings
@ -104,17 +101,18 @@ def compute_embeddings(
async def query(
raw_query: str,
model: TextSearchModel,
search_model: TextSearchModel,
content: TextContent,
question_embedding: Union[torch.Tensor, None] = None,
rank_results: bool = False,
score_threshold: float = -math.inf,
dedupe: bool = True,
) -> Tuple[List[dict], List[Entry]]:
"Search for entries that answer the query"
query, entries, corpus_embeddings = raw_query, model.entries, model.corpus_embeddings
query, entries, corpus_embeddings = raw_query, content.entries, content.corpus_embeddings
# Filter query, entries and embeddings before semantic search
query, entries, corpus_embeddings = apply_filters(query, entries, corpus_embeddings, model.filters)
query, entries, corpus_embeddings = apply_filters(query, entries, corpus_embeddings, content.filters)
# If no entries left after filtering, return empty results
if entries is None or len(entries) == 0:
@ -127,18 +125,17 @@ async def query(
# Encode the query using the bi-encoder
if question_embedding is None:
with timer("Query Encode Time", logger, state.device):
question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True, device=state.device)
question_embedding = search_model.bi_encoder.encode([query], convert_to_tensor=True, device=state.device)
question_embedding = util.normalize_embeddings(question_embedding)
# Find relevant entries for the query
top_k = min(len(entries), search_model.top_k or 10) # top_k hits can't be more than the total entries in corpus
with timer("Search Time", logger, state.device):
hits = util.semantic_search(
question_embedding, corpus_embeddings, top_k=model.top_k, score_function=util.dot_score
)[0]
hits = util.semantic_search(question_embedding, corpus_embeddings, top_k, score_function=util.dot_score)[0]
# Score all retrieved entries using the cross-encoder
if rank_results:
hits = cross_encoder_score(model.cross_encoder, query, entries, hits)
if rank_results and search_model.cross_encoder:
hits = cross_encoder_score(search_model.cross_encoder, query, entries, hits)
# Filter results by score threshold
hits = [hit for hit in hits if hit.get("cross-score", hit.get("score")) >= score_threshold]
@ -173,13 +170,10 @@ def collate_results(hits, entries: List[Entry], count=5) -> List[SearchResponse]
def setup(
text_to_jsonl: Type[TextToJsonl],
config: TextConfigBase,
search_config: TextSearchConfig,
bi_encoder: BaseEncoder,
regenerate: bool,
filters: List[BaseFilter] = [],
) -> TextSearchModel:
# Initialize Model
bi_encoder, cross_encoder, top_k = initialize_model(search_config)
) -> TextContent:
# Map notes in text files to (compressed) JSONL formatted file
config.compressed_jsonl = resolve_absolute_path(config.compressed_jsonl)
previous_entries = (
@ -192,7 +186,6 @@ def setup(
if is_none_or_empty(entries):
config_params = ", ".join([f"{key}={value}" for key, value in config.dict().items()])
raise ValueError(f"No valid entries found in specified files: {config_params}")
top_k = min(len(entries), top_k) # top_k hits can't be more than the total entries in corpus
# Compute or Load Embeddings
config.embeddings_file = resolve_absolute_path(config.embeddings_file)
@ -203,7 +196,7 @@ def setup(
for filter in filters:
filter.load(entries, regenerate=regenerate)
return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, filters, top_k)
return TextContent(entries, corpus_embeddings, filters)
def apply_filters(

View file

@ -3,7 +3,7 @@ from __future__ import annotations # to avoid quoting type hints
from enum import Enum
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Dict, List, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Union
# External Packages
import torch
@ -30,42 +30,48 @@ class ProcessorType(str, Enum):
Conversation = "conversation"
@dataclass
class TextContent:
entries: List[Entry]
corpus_embeddings: torch.Tensor
filters: List[BaseFilter]
@dataclass
class ImageContent:
image_names: List[str]
image_embeddings: torch.Tensor
image_metadata_embeddings: torch.Tensor
@dataclass
class TextSearchModel:
def __init__(
self,
entries: List[Entry],
corpus_embeddings: torch.Tensor,
bi_encoder: BaseEncoder,
cross_encoder: CrossEncoder,
filters: List[BaseFilter],
top_k,
):
self.entries = entries
self.corpus_embeddings = corpus_embeddings
self.bi_encoder = bi_encoder
self.cross_encoder = cross_encoder
self.filters = filters
self.top_k = top_k
bi_encoder: BaseEncoder
cross_encoder: Optional[CrossEncoder] = None
top_k: Optional[int] = 15
@dataclass
class ImageSearchModel:
def __init__(self, image_names, image_embeddings, image_metadata_embeddings, image_encoder: BaseEncoder):
self.image_encoder = image_encoder
self.image_names = image_names
self.image_embeddings = image_embeddings
self.image_metadata_embeddings = image_metadata_embeddings
self.image_encoder = image_encoder
image_encoder: BaseEncoder
@dataclass
class ContentIndex:
org: Optional[TextContent] = None
markdown: Optional[TextContent] = None
pdf: Optional[TextContent] = None
github: Optional[TextContent] = None
notion: Optional[TextContent] = None
image: Optional[ImageContent] = None
plugins: Optional[Dict[str, TextContent]] = None
@dataclass
class SearchModels:
org_search: Union[TextSearchModel, None] = None
markdown_search: Union[TextSearchModel, None] = None
pdf_search: Union[TextSearchModel, None] = None
image_search: Union[ImageSearchModel, None] = None
github_search: Union[TextSearchModel, None] = None
notion_search: Union[TextSearchModel, None] = None
plugin_search: Union[Dict[str, TextSearchModel], None] = None
text_search: Optional[TextSearchModel] = None
image_search: Optional[ImageSearchModel] = None
plugin_search: Optional[Dict[str, TextSearchModel]] = None
class ConversationProcessorConfigModel:

View file

@ -20,7 +20,7 @@ from khoj.utils import constants
if TYPE_CHECKING:
# External Packages
from sentence_transformers import CrossEncoder
from sentence_transformers import SentenceTransformer, CrossEncoder
# Internal Packages
from khoj.utils.models import BaseEncoder
@ -64,7 +64,9 @@ def merge_dicts(priority_dict: dict, default_dict: dict):
return merged_dict
def load_model(model_name: str, model_type, model_dir=None, device: str = None) -> Union[BaseEncoder, CrossEncoder]:
def load_model(
model_name: str, model_type, model_dir=None, device: str = None
) -> Union[BaseEncoder, SentenceTransformer, CrossEncoder]:
"Load model from disk or huggingface"
# Construct model path
logger = logging.getLogger(__name__)

View file

@ -119,9 +119,9 @@ class AppConfig(ConfigBase):
class FullConfig(ConfigBase):
content_type: Optional[ContentConfig]
search_type: Optional[SearchConfig]
processor: Optional[ProcessorConfig]
content_type: Optional[ContentConfig] = None
search_type: Optional[SearchConfig] = None
processor: Optional[ProcessorConfig] = None
app: Optional[AppConfig] = AppConfig(should_log_telemetry=True)

View file

@ -9,13 +9,14 @@ from pathlib import Path
# Internal Packages
from khoj.utils import config as utils_config
from khoj.utils.config import SearchModels, ProcessorConfigModel
from khoj.utils.config import ContentIndex, SearchModels, ProcessorConfigModel
from khoj.utils.helpers import LRU
from khoj.utils.rawconfig import FullConfig
# Application Global State
config = FullConfig()
model = SearchModels()
search_models = SearchModels()
content_index = ContentIndex()
processor_config = ProcessorConfigModel()
config_file: Path = None
verbose: int = 0

View file

@ -10,6 +10,7 @@ from khoj.main import app
from khoj.configure import configure_processor, configure_routes, configure_search_types
from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl
from khoj.search_type import image_search, text_search
from khoj.utils.config import ImageContent, SearchModels, TextContent
from khoj.utils.helpers import resolve_absolute_path
from khoj.utils.rawconfig import (
ContentConfig,
@ -41,35 +42,49 @@ def search_config() -> SearchConfig:
encoder="sentence-transformers/all-MiniLM-L6-v2",
cross_encoder="cross-encoder/ms-marco-MiniLM-L-6-v2",
model_directory=model_dir / "symmetric/",
encoder_type=None,
)
search_config.asymmetric = TextSearchConfig(
encoder="sentence-transformers/multi-qa-MiniLM-L6-cos-v1",
cross_encoder="cross-encoder/ms-marco-MiniLM-L-6-v2",
model_directory=model_dir / "asymmetric/",
encoder_type=None,
)
search_config.image = ImageSearchConfig(
encoder="sentence-transformers/clip-ViT-B-32", model_directory=model_dir / "image/"
encoder="sentence-transformers/clip-ViT-B-32",
model_directory=model_dir / "image/",
encoder_type=None,
)
return search_config
@pytest.fixture(scope="session")
def content_config(tmp_path_factory, search_config: SearchConfig):
def search_models(search_config: SearchConfig):
search_models = SearchModels()
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
search_models.image_search = image_search.initialize_model(search_config.image)
return search_models
@pytest.fixture(scope="session")
def content_config(tmp_path_factory, search_models: SearchModels, search_config: SearchConfig):
content_dir = tmp_path_factory.mktemp("content")
# Generate Image Embeddings from Test Images
content_config = ContentConfig()
content_config.image = ImageContentConfig(
input_filter=None,
input_directories=["tests/data/images"],
embeddings_file=content_dir.joinpath("image_embeddings.pt"),
batch_size=1,
use_xmp_metadata=False,
)
image_search.setup(content_config.image, search_config.image, regenerate=False)
image_search.setup(content_config.image, search_models.image_search.image_encoder, regenerate=False)
# Generate Notes Embeddings from Test Notes
content_config.org = TextContentConfig(
@ -80,7 +95,9 @@ def content_config(tmp_path_factory, search_config: SearchConfig):
)
filters = [DateFilter(), WordFilter(), FileFilter()]
text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters)
text_search.setup(
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False, filters=filters
)
content_config.plugins = {
"plugin1": TextContentConfig(
@ -106,7 +123,11 @@ def content_config(tmp_path_factory, search_config: SearchConfig):
filters = [DateFilter(), WordFilter(), FileFilter()]
text_search.setup(
JsonlToJsonl, content_config.plugins["plugin1"], search_config.asymmetric, regenerate=False, filters=filters
JsonlToJsonl,
content_config.plugins["plugin1"],
search_models.text_search.bi_encoder,
regenerate=False,
filters=filters,
)
return content_config
@ -157,8 +178,13 @@ def chat_client(md_content_config: ContentConfig, search_config: SearchConfig, p
# Index Markdown Content for Search
filters = [DateFilter(), WordFilter(), FileFilter()]
state.model.markdown_search = text_search.setup(
MarkdownToJsonl, md_content_config.markdown, search_config.asymmetric, regenerate=False, filters=filters
state.search_models.text_search = text_search.initialize_model(search_config.asymmetric)
state.content_index.markdown = text_search.setup(
MarkdownToJsonl,
md_content_config.markdown,
state.search_models.text_search.bi_encoder,
regenerate=False,
filters=filters,
)
# Initialize Processor from Config
@ -175,8 +201,14 @@ def client(content_config: ContentConfig, search_config: SearchConfig, processor
state.SearchType = configure_search_types(state.config)
# These lines help us Mock the Search models for these search types
state.model.org_search = {}
state.model.image_search = {}
state.search_models.text_search = text_search.initialize_model(search_config.asymmetric)
state.search_models.image_search = image_search.initialize_model(search_config.image)
state.content_index.org = text_search.setup(
OrgToJsonl, content_config.org, state.search_models.text_search.bi_encoder, regenerate=False
)
state.content_index.image = image_search.setup(
content_config.image, state.search_models.image_search, regenerate=False
)
configure_routes(app)
return TestClient(app)

View file

@ -11,7 +11,8 @@ from fastapi.testclient import TestClient
from khoj.main import app
from khoj.configure import configure_routes, configure_search_types
from khoj.utils import state
from khoj.utils.state import model, config
from khoj.utils.config import SearchModels
from khoj.utils.state import search_models, content_index, config
from khoj.search_type import text_search, image_search
from khoj.utils.rawconfig import ContentConfig, SearchConfig
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
@ -143,7 +144,10 @@ def test_get_configured_types_with_no_content_config():
# ----------------------------------------------------------------------------------------------------
def test_image_search(client, content_config: ContentConfig, search_config: SearchConfig):
# Arrange
model.image_search = image_search.setup(content_config.image, search_config.image, regenerate=False)
search_models.image_search = image_search.initialize_model(search_config.image)
content_index.image = image_search.setup(
content_config.image, search_models.image_search.image_encoder, regenerate=False
)
query_expected_image_pairs = [
("kitten", "kitten_park.jpg"),
("a horse and dog on a leash", "horse_dog.jpg"),
@ -166,7 +170,10 @@ def test_image_search(client, content_config: ContentConfig, search_config: Sear
# ----------------------------------------------------------------------------------------------------
def test_notes_search(client, content_config: ContentConfig, search_config: SearchConfig):
# Arrange
model.org_search = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False)
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
content_index.org = text_search.setup(
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False
)
user_query = quote("How to git install application?")
# Act
@ -183,8 +190,9 @@ def test_notes_search(client, content_config: ContentConfig, search_config: Sear
def test_notes_search_with_only_filters(client, content_config: ContentConfig, search_config: SearchConfig):
# Arrange
filters = [WordFilter(), FileFilter()]
model.org_search = text_search.setup(
OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
content_index.org = text_search.setup(
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False, filters=filters
)
user_query = quote('+"Emacs" file:"*.org"')
@ -202,8 +210,9 @@ def test_notes_search_with_only_filters(client, content_config: ContentConfig, s
def test_notes_search_with_include_filter(client, content_config: ContentConfig, search_config: SearchConfig):
# Arrange
filters = [WordFilter()]
model.org_search = text_search.setup(
OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
content_index.org = text_search.setup(
OrgToJsonl, content_config.org, search_models.text_search, regenerate=False, filters=filters
)
user_query = quote('How to git install application? +"Emacs"')
@ -221,8 +230,9 @@ def test_notes_search_with_include_filter(client, content_config: ContentConfig,
def test_notes_search_with_exclude_filter(client, content_config: ContentConfig, search_config: SearchConfig):
# Arrange
filters = [WordFilter()]
model.org_search = text_search.setup(
OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
content_index.org = text_search.setup(
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False, filters=filters
)
user_query = quote('How to git install application? -"clone"')

View file

@ -5,9 +5,10 @@ from PIL import Image
# External Packages
import pytest
from khoj.utils.config import SearchModels
# Internal Packages
from khoj.utils.state import model
from khoj.utils.state import content_index, search_models
from khoj.utils.constants import web_directory
from khoj.search_type import image_search
from khoj.utils.helpers import resolve_absolute_path
@ -16,10 +17,12 @@ from khoj.utils.rawconfig import ContentConfig, SearchConfig
# Test
# ----------------------------------------------------------------------------------------------------
def test_image_search_setup(content_config: ContentConfig, search_config: SearchConfig):
def test_image_search_setup(content_config: ContentConfig, search_models: SearchModels):
# Act
# Regenerate image search embeddings during image setup
image_search_model = image_search.setup(content_config.image, search_config.image, regenerate=True)
image_search_model = image_search.setup(
content_config.image, search_models.image_search.image_encoder, regenerate=True
)
# Assert
assert len(image_search_model.image_names) == 3
@ -54,8 +57,11 @@ def test_image_metadata(content_config: ContentConfig):
@pytest.mark.anyio
async def test_image_search(content_config: ContentConfig, search_config: SearchConfig):
# Arrange
search_models.image_search = image_search.initialize_model(search_config.image)
content_index.image = image_search.setup(
content_config.image, search_models.image_search.image_encoder, regenerate=False
)
output_directory = resolve_absolute_path(web_directory)
model.image_search = image_search.setup(content_config.image, search_config.image, regenerate=False)
query_expected_image_pairs = [
("kitten", "kitten_park.jpg"),
("horse and dog in a farm", "horse_dog.jpg"),
@ -64,11 +70,13 @@ async def test_image_search(content_config: ContentConfig, search_config: Search
# Act
for query, expected_image_name in query_expected_image_pairs:
hits = await image_search.query(query, count=1, model=model.image_search)
hits = await image_search.query(
query, count=1, search_model=search_models.image_search, content=content_index.image
)
results = image_search.collate_results(
hits,
model.image_search.image_names,
content_index.image.image_names,
output_directory=output_directory,
image_files_url="/static/images",
count=1,
@ -90,7 +98,10 @@ async def test_image_search(content_config: ContentConfig, search_config: Search
@pytest.mark.anyio
async def test_image_search_query_truncated(content_config: ContentConfig, search_config: SearchConfig, caplog):
# Arrange
model.image_search = image_search.setup(content_config.image, search_config.image, regenerate=False)
search_models.image_search = image_search.initialize_model(search_config.image)
content_index.image = image_search.setup(
content_config.image, search_models.image_search.image_encoder, regenerate=False
)
max_words_supported = 10
query = " ".join(["hello"] * 100)
truncated_query = " ".join(["hello"] * max_words_supported)
@ -98,7 +109,9 @@ async def test_image_search_query_truncated(content_config: ContentConfig, searc
# Act
try:
with caplog.at_level(logging.INFO, logger="khoj.search_type.image_search"):
await image_search.query(query, count=1, model=model.image_search)
await image_search.query(
query, count=1, search_model=search_models.image_search, content=content_index.image
)
# Assert
except RuntimeError as e:
if "The size of tensor a (102) must match the size of tensor b (77)" in str(e):
@ -110,8 +123,11 @@ async def test_image_search_query_truncated(content_config: ContentConfig, searc
@pytest.mark.anyio
async def test_image_search_by_filepath(content_config: ContentConfig, search_config: SearchConfig, caplog):
# Arrange
search_models.image_search = image_search.initialize_model(search_config.image)
content_index.image = image_search.setup(
content_config.image, search_models.image_search.image_encoder, regenerate=False
)
output_directory = resolve_absolute_path(web_directory)
model.image_search = image_search.setup(content_config.image, search_config.image, regenerate=False)
image_directory = content_config.image.input_directories[0]
query = f"file:{image_directory.joinpath('kitten_park.jpg')}"
@ -119,11 +135,13 @@ async def test_image_search_by_filepath(content_config: ContentConfig, search_co
# Act
with caplog.at_level(logging.INFO, logger="khoj.search_type.image_search"):
hits = await image_search.query(query, count=1, model=model.image_search)
hits = await image_search.query(
query, count=1, search_model=search_models.image_search, content=content_index.image
)
results = image_search.collate_results(
hits,
model.image_search.image_names,
content_index.image.image_names,
output_directory=output_directory,
image_files_url="/static/images",
count=1,

View file

@ -5,9 +5,10 @@ import os
# External Packages
import pytest
from khoj.utils.config import SearchModels
# Internal Packages
from khoj.utils.state import model
from khoj.utils.state import content_index, search_models
from khoj.search_type import text_search
from khoj.utils.rawconfig import ContentConfig, SearchConfig, TextContentConfig
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
@ -41,10 +42,12 @@ def test_asymmetric_setup_with_empty_file_raises_error(
# ----------------------------------------------------------------------------------------------------
def test_asymmetric_setup(content_config: ContentConfig, search_config: SearchConfig):
def test_asymmetric_setup(content_config: ContentConfig, search_models: SearchModels):
# Act
# Regenerate notes embeddings during asymmetric setup
notes_model = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True)
notes_model = text_search.setup(
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True
)
# Assert
assert len(notes_model.entries) == 10
@ -52,18 +55,18 @@ def test_asymmetric_setup(content_config: ContentConfig, search_config: SearchCo
# ----------------------------------------------------------------------------------------------------
def test_text_content_index_only_updates_on_changes(content_config: ContentConfig, search_config: SearchConfig, caplog):
def test_text_content_index_only_updates_on_changes(content_config: ContentConfig, search_models: SearchModels, caplog):
# Arrange
caplog.set_level(logging.INFO, logger="khoj")
# Act
# Generate initial notes embeddings during asymmetric setup
text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True)
text_search.setup(OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True)
initial_logs = caplog.text
caplog.clear() # Clear logs
# Run asymmetric setup again with no changes to data source. Ensure index is not updated
text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False)
text_search.setup(OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False)
final_logs = caplog.text
# Assert
@ -75,11 +78,16 @@ def test_text_content_index_only_updates_on_changes(content_config: ContentConfi
@pytest.mark.anyio
async def test_asymmetric_search(content_config: ContentConfig, search_config: SearchConfig):
# Arrange
model.notes_search = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True)
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
content_index.org = text_search.setup(
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True
)
query = "How to git install application?"
# Act
hits, entries = await text_search.query(query, model=model.notes_search, rank_results=True)
hits, entries = await text_search.query(
query, search_model=search_models.text_search, content=content_index.org, rank_results=True
)
results = text_search.collate_results(hits, entries, count=1)
@ -90,7 +98,7 @@ async def test_asymmetric_search(content_config: ContentConfig, search_config: S
# ----------------------------------------------------------------------------------------------------
def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: TextContentConfig, search_config: SearchConfig):
def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: TextContentConfig, search_models: SearchModels):
# Arrange
# Insert org-mode entry with size exceeding max token limit to new org file
max_tokens = 256
@ -103,7 +111,7 @@ def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: TextContent
# Act
# reload embeddings, entries, notes model after adding new org-mode file
initial_notes_model = text_search.setup(
OrgToJsonl, org_config_with_only_new_file, search_config.asymmetric, regenerate=False
OrgToJsonl, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=False
)
# Assert
@ -113,9 +121,11 @@ def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: TextContent
# ----------------------------------------------------------------------------------------------------
def test_asymmetric_reload(content_config: ContentConfig, search_config: SearchConfig, new_org_file: Path):
def test_asymmetric_reload(content_config: ContentConfig, search_models: SearchModels, new_org_file: Path):
# Arrange
initial_notes_model = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True)
initial_notes_model = text_search.setup(
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True
)
assert len(initial_notes_model.entries) == 10
assert len(initial_notes_model.corpus_embeddings) == 10
@ -127,12 +137,14 @@ def test_asymmetric_reload(content_config: ContentConfig, search_config: SearchC
# regenerate notes jsonl, model embeddings and model to include entry from new file
regenerated_notes_model = text_search.setup(
OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True
)
# Act
# reload embeddings, entries, notes model from previously generated notes jsonl and model embeddings files
initial_notes_model = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False)
initial_notes_model = text_search.setup(
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False
)
# Assert
assert len(regenerated_notes_model.entries) == 11
@ -149,9 +161,11 @@ def test_asymmetric_reload(content_config: ContentConfig, search_config: SearchC
# ----------------------------------------------------------------------------------------------------
def test_incremental_update(content_config: ContentConfig, search_config: SearchConfig, new_org_file: Path):
def test_incremental_update(content_config: ContentConfig, search_models: SearchModels, new_org_file: Path):
# Arrange
initial_notes_model = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True)
initial_notes_model = text_search.setup(
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True
)
assert len(initial_notes_model.entries) == 10
assert len(initial_notes_model.corpus_embeddings) == 10
@ -163,7 +177,9 @@ def test_incremental_update(content_config: ContentConfig, search_config: Search
# Act
# update embeddings, entries with the newly added note
content_config.org.input_files = [f"{new_org_file}"]
initial_notes_model = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False)
initial_notes_model = text_search.setup(
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False
)
# Assert
# verify new entry added in updated embeddings, entries
@ -177,10 +193,12 @@ def test_incremental_update(content_config: ContentConfig, search_config: Search
# ----------------------------------------------------------------------------------------------------
@pytest.mark.skipif(os.getenv("GITHUB_PAT_TOKEN") is None, reason="GITHUB_PAT_TOKEN not set")
def test_asymmetric_setup_github(content_config: ContentConfig, search_config: SearchConfig):
def test_asymmetric_setup_github(content_config: ContentConfig, search_models: SearchModels):
# Act
# Regenerate github embeddings to test asymmetric setup without caching
github_model = text_search.setup(GithubToJsonl, content_config.github, search_config.asymmetric, regenerate=True)
github_model = text_search.setup(
GithubToJsonl, content_config.github, search_models.text_search.bi_encoder, regenerate=True
)
# Assert
assert len(github_model.entries) > 1