mirror of
https://github.com/khoj-ai/khoj.git
synced 2025-02-17 08:04:21 +00:00
Merge branch 'master' of github.com:debanjum/khoj
This commit is contained in:
commit
ba47f2ab39
12 changed files with 361 additions and 196 deletions
|
@ -20,9 +20,15 @@ from khoj.processor.github.github_to_jsonl import GithubToJsonl
|
||||||
from khoj.processor.notion.notion_to_jsonl import NotionToJsonl
|
from khoj.processor.notion.notion_to_jsonl import NotionToJsonl
|
||||||
from khoj.search_type import image_search, text_search
|
from khoj.search_type import image_search, text_search
|
||||||
from khoj.utils import constants, state
|
from khoj.utils import constants, state
|
||||||
from khoj.utils.config import SearchType, SearchModels, ProcessorConfigModel, ConversationProcessorConfigModel
|
from khoj.utils.config import (
|
||||||
|
ContentIndex,
|
||||||
|
SearchType,
|
||||||
|
SearchModels,
|
||||||
|
ProcessorConfigModel,
|
||||||
|
ConversationProcessorConfigModel,
|
||||||
|
)
|
||||||
from khoj.utils.helpers import LRU, resolve_absolute_path, merge_dicts
|
from khoj.utils.helpers import LRU, resolve_absolute_path, merge_dicts
|
||||||
from khoj.utils.rawconfig import FullConfig, ProcessorConfig
|
from khoj.utils.rawconfig import FullConfig, ProcessorConfig, SearchConfig, ContentConfig
|
||||||
from khoj.search_filter.date_filter import DateFilter
|
from khoj.search_filter.date_filter import DateFilter
|
||||||
from khoj.search_filter.word_filter import WordFilter
|
from khoj.search_filter.word_filter import WordFilter
|
||||||
from khoj.search_filter.file_filter import FileFilter
|
from khoj.search_filter.file_filter import FileFilter
|
||||||
|
@ -49,11 +55,27 @@ def configure_server(args, required=False):
|
||||||
# Initialize Processor from Config
|
# Initialize Processor from Config
|
||||||
state.processor_config = configure_processor(args.config.processor)
|
state.processor_config = configure_processor(args.config.processor)
|
||||||
|
|
||||||
# Initialize the search type and model from Config
|
# Initialize Search Models from Config
|
||||||
state.search_index_lock.acquire()
|
try:
|
||||||
state.SearchType = configure_search_types(state.config)
|
state.search_index_lock.acquire()
|
||||||
state.model = configure_search(state.model, state.config, args.regenerate)
|
state.SearchType = configure_search_types(state.config)
|
||||||
state.search_index_lock.release()
|
state.search_models = configure_search(state.search_models, state.config.search_type)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"🚨 Error configuring search models on app load: {e}")
|
||||||
|
finally:
|
||||||
|
state.search_index_lock.release()
|
||||||
|
|
||||||
|
# Initialize Content from Config
|
||||||
|
if state.search_models:
|
||||||
|
try:
|
||||||
|
state.search_index_lock.acquire()
|
||||||
|
state.content_index = configure_content(
|
||||||
|
state.content_index, state.config.content_type, state.search_models, args.regenerate
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"🚨 Error configuring content index on app load: {e}")
|
||||||
|
finally:
|
||||||
|
state.search_index_lock.release()
|
||||||
|
|
||||||
|
|
||||||
def configure_routes(app):
|
def configure_routes(app):
|
||||||
|
@ -72,10 +94,16 @@ if not state.demo:
|
||||||
|
|
||||||
@schedule.repeat(schedule.every(61).minutes)
|
@schedule.repeat(schedule.every(61).minutes)
|
||||||
def update_search_index():
|
def update_search_index():
|
||||||
state.search_index_lock.acquire()
|
try:
|
||||||
state.model = configure_search(state.model, state.config, regenerate=False)
|
state.search_index_lock.acquire()
|
||||||
state.search_index_lock.release()
|
state.content_index = configure_content(
|
||||||
logger.info("📬 Search index updated via Scheduler")
|
state.content_index, state.config.content_type, state.search_models, regenerate=False
|
||||||
|
)
|
||||||
|
logger.info("📬 Content index updated via Scheduler")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"🚨 Error updating content index via Scheduler: {e}")
|
||||||
|
finally:
|
||||||
|
state.search_index_lock.release()
|
||||||
|
|
||||||
|
|
||||||
def configure_search_types(config: FullConfig):
|
def configure_search_types(config: FullConfig):
|
||||||
|
@ -90,94 +118,116 @@ def configure_search_types(config: FullConfig):
|
||||||
return Enum("SearchType", merge_dicts(core_search_types, plugin_search_types))
|
return Enum("SearchType", merge_dicts(core_search_types, plugin_search_types))
|
||||||
|
|
||||||
|
|
||||||
def configure_search(model: SearchModels, config: FullConfig, regenerate: bool, t: Optional[state.SearchType] = None):
|
def configure_search(search_models: SearchModels, search_config: SearchConfig) -> Optional[SearchModels]:
|
||||||
if config is None or config.content_type is None or config.search_type is None:
|
# Run Validation Checks
|
||||||
logger.warning("🚨 No Content or Search type is configured.")
|
if search_config is None:
|
||||||
return
|
logger.warning("🚨 No Search type is configured.")
|
||||||
|
return None
|
||||||
|
if search_models is None:
|
||||||
|
search_models = SearchModels()
|
||||||
|
|
||||||
if model is None:
|
# Initialize Search Models
|
||||||
model = SearchModels()
|
if search_config.asymmetric:
|
||||||
|
logger.info("🔍 📜 Setting up text search model")
|
||||||
|
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
|
||||||
|
|
||||||
|
if search_config.image:
|
||||||
|
logger.info("🔍 🌄 Setting up image search model")
|
||||||
|
search_models.image_search = image_search.initialize_model(search_config.image)
|
||||||
|
|
||||||
|
return search_models
|
||||||
|
|
||||||
|
|
||||||
|
def configure_content(
|
||||||
|
content_index: Optional[ContentIndex],
|
||||||
|
content_config: Optional[ContentConfig],
|
||||||
|
search_models: SearchModels,
|
||||||
|
regenerate: bool,
|
||||||
|
t: Optional[state.SearchType] = None,
|
||||||
|
) -> Optional[ContentIndex]:
|
||||||
|
# Run Validation Checks
|
||||||
|
if content_config is None:
|
||||||
|
logger.warning("🚨 No Content type is configured.")
|
||||||
|
return None
|
||||||
|
if content_index is None:
|
||||||
|
content_index = ContentIndex()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Initialize Org Notes Search
|
# Initialize Org Notes Search
|
||||||
if (t == state.SearchType.Org or t == None) and config.content_type.org and config.search_type.asymmetric:
|
if (t == state.SearchType.Org or t == None) and content_config.org and search_models.text_search:
|
||||||
logger.info("🦄 Setting up search for orgmode notes")
|
logger.info("🦄 Setting up search for orgmode notes")
|
||||||
# Extract Entries, Generate Notes Embeddings
|
# Extract Entries, Generate Notes Embeddings
|
||||||
model.org_search = text_search.setup(
|
content_index.org = text_search.setup(
|
||||||
OrgToJsonl,
|
OrgToJsonl,
|
||||||
config.content_type.org,
|
content_config.org,
|
||||||
search_config=config.search_type.asymmetric,
|
search_models.text_search.bi_encoder,
|
||||||
regenerate=regenerate,
|
regenerate=regenerate,
|
||||||
filters=[DateFilter(), WordFilter(), FileFilter()],
|
filters=[DateFilter(), WordFilter(), FileFilter()],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize Markdown Search
|
# Initialize Markdown Search
|
||||||
if (
|
if (t == state.SearchType.Markdown or t == None) and content_config.markdown and search_models.text_search:
|
||||||
(t == state.SearchType.Markdown or t == None)
|
|
||||||
and config.content_type.markdown
|
|
||||||
and config.search_type.asymmetric
|
|
||||||
):
|
|
||||||
logger.info("💎 Setting up search for markdown notes")
|
logger.info("💎 Setting up search for markdown notes")
|
||||||
# Extract Entries, Generate Markdown Embeddings
|
# Extract Entries, Generate Markdown Embeddings
|
||||||
model.markdown_search = text_search.setup(
|
content_index.markdown = text_search.setup(
|
||||||
MarkdownToJsonl,
|
MarkdownToJsonl,
|
||||||
config.content_type.markdown,
|
content_config.markdown,
|
||||||
search_config=config.search_type.asymmetric,
|
search_models.text_search.bi_encoder,
|
||||||
regenerate=regenerate,
|
regenerate=regenerate,
|
||||||
filters=[DateFilter(), WordFilter(), FileFilter()],
|
filters=[DateFilter(), WordFilter(), FileFilter()],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize PDF Search
|
# Initialize PDF Search
|
||||||
if (t == state.SearchType.Pdf or t == None) and config.content_type.pdf and config.search_type.asymmetric:
|
if (t == state.SearchType.Pdf or t == None) and content_config.pdf and search_models.text_search:
|
||||||
logger.info("🖨️ Setting up search for pdf")
|
logger.info("🖨️ Setting up search for pdf")
|
||||||
# Extract Entries, Generate PDF Embeddings
|
# Extract Entries, Generate PDF Embeddings
|
||||||
model.pdf_search = text_search.setup(
|
content_index.pdf = text_search.setup(
|
||||||
PdfToJsonl,
|
PdfToJsonl,
|
||||||
config.content_type.pdf,
|
content_config.pdf,
|
||||||
search_config=config.search_type.asymmetric,
|
search_models.text_search.bi_encoder,
|
||||||
regenerate=regenerate,
|
regenerate=regenerate,
|
||||||
filters=[DateFilter(), WordFilter(), FileFilter()],
|
filters=[DateFilter(), WordFilter(), FileFilter()],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize Image Search
|
# Initialize Image Search
|
||||||
if (t == state.SearchType.Image or t == None) and config.content_type.image and config.search_type.image:
|
if (t == state.SearchType.Image or t == None) and content_config.image and search_models.image_search:
|
||||||
logger.info("🌄 Setting up search for images")
|
logger.info("🌄 Setting up search for images")
|
||||||
# Extract Entries, Generate Image Embeddings
|
# Extract Entries, Generate Image Embeddings
|
||||||
model.image_search = image_search.setup(
|
content_index.image = image_search.setup(
|
||||||
config.content_type.image, search_config=config.search_type.image, regenerate=regenerate
|
content_config.image, search_models.image_search.image_encoder, regenerate=regenerate
|
||||||
)
|
)
|
||||||
|
|
||||||
if (t == state.SearchType.Github or t == None) and config.content_type.github and config.search_type.asymmetric:
|
if (t == state.SearchType.Github or t == None) and content_config.github and search_models.text_search:
|
||||||
logger.info("🐙 Setting up search for github")
|
logger.info("🐙 Setting up search for github")
|
||||||
# Extract Entries, Generate Github Embeddings
|
# Extract Entries, Generate Github Embeddings
|
||||||
model.github_search = text_search.setup(
|
content_index.github = text_search.setup(
|
||||||
GithubToJsonl,
|
GithubToJsonl,
|
||||||
config.content_type.github,
|
content_config.github,
|
||||||
search_config=config.search_type.asymmetric,
|
search_models.text_search.bi_encoder,
|
||||||
regenerate=regenerate,
|
regenerate=regenerate,
|
||||||
filters=[DateFilter(), WordFilter(), FileFilter()],
|
filters=[DateFilter(), WordFilter(), FileFilter()],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize External Plugin Search
|
# Initialize External Plugin Search
|
||||||
if (t == None or t in state.SearchType) and config.content_type.plugins:
|
if (t == None or t in state.SearchType) and content_config.plugins and search_models.text_search:
|
||||||
logger.info("🔌 Setting up search for plugins")
|
logger.info("🔌 Setting up search for plugins")
|
||||||
model.plugin_search = {}
|
content_index.plugins = {}
|
||||||
for plugin_type, plugin_config in config.content_type.plugins.items():
|
for plugin_type, plugin_config in content_config.plugins.items():
|
||||||
model.plugin_search[plugin_type] = text_search.setup(
|
content_index.plugins[plugin_type] = text_search.setup(
|
||||||
JsonlToJsonl,
|
JsonlToJsonl,
|
||||||
plugin_config,
|
plugin_config,
|
||||||
search_config=config.search_type.asymmetric,
|
search_models.text_search.bi_encoder,
|
||||||
regenerate=regenerate,
|
regenerate=regenerate,
|
||||||
filters=[DateFilter(), WordFilter(), FileFilter()],
|
filters=[DateFilter(), WordFilter(), FileFilter()],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize Notion Search
|
# Initialize Notion Search
|
||||||
if (t == None or t in state.SearchType) and config.content_type.notion:
|
if (t == None or t in state.SearchType) and content_config.notion and search_models.text_search:
|
||||||
logger.info("🔌 Setting up search for notion")
|
logger.info("🔌 Setting up search for notion")
|
||||||
model.notion_search = text_search.setup(
|
content_index.notion = text_search.setup(
|
||||||
NotionToJsonl,
|
NotionToJsonl,
|
||||||
config.content_type.notion,
|
content_config.notion,
|
||||||
search_config=config.search_type.asymmetric,
|
search_models.text_search.bi_encoder,
|
||||||
regenerate=regenerate,
|
regenerate=regenerate,
|
||||||
filters=[DateFilter(), WordFilter(), FileFilter()],
|
filters=[DateFilter(), WordFilter(), FileFilter()],
|
||||||
)
|
)
|
||||||
|
@ -189,7 +239,7 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool,
|
||||||
# Invalidate Query Cache
|
# Invalidate Query Cache
|
||||||
state.query_cache = LRU()
|
state.query_cache = LRU()
|
||||||
|
|
||||||
return model
|
return content_index
|
||||||
|
|
||||||
|
|
||||||
def configure_processor(processor_config: ProcessorConfig):
|
def configure_processor(processor_config: ProcessorConfig):
|
||||||
|
|
|
@ -12,7 +12,7 @@ from fastapi import APIRouter, HTTPException, Header, Request
|
||||||
from sentence_transformers import util
|
from sentence_transformers import util
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from khoj.configure import configure_processor, configure_search
|
from khoj.configure import configure_content, configure_processor, configure_search
|
||||||
from khoj.search_type import image_search, text_search
|
from khoj.search_type import image_search, text_search
|
||||||
from khoj.search_filter.date_filter import DateFilter
|
from khoj.search_filter.date_filter import DateFilter
|
||||||
from khoj.search_filter.file_filter import FileFilter
|
from khoj.search_filter.file_filter import FileFilter
|
||||||
|
@ -163,17 +163,17 @@ if not state.demo:
|
||||||
state.config.content_type[content_type] = None
|
state.config.content_type[content_type] = None
|
||||||
|
|
||||||
if content_type == "github":
|
if content_type == "github":
|
||||||
state.model.github_search = None
|
state.content_index.github = None
|
||||||
elif content_type == "notion":
|
elif content_type == "notion":
|
||||||
state.model.notion_search = None
|
state.content_index.notion = None
|
||||||
elif content_type == "plugins":
|
elif content_type == "plugins":
|
||||||
state.model.plugin_search = None
|
state.content_index.plugins = None
|
||||||
elif content_type == "pdf":
|
elif content_type == "pdf":
|
||||||
state.model.pdf_search = None
|
state.content_index.pdf = None
|
||||||
elif content_type == "markdown":
|
elif content_type == "markdown":
|
||||||
state.model.markdown_search = None
|
state.content_index.markdown = None
|
||||||
elif content_type == "org":
|
elif content_type == "org":
|
||||||
state.model.org_search = None
|
state.content_index.org = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
save_config_to_file_updated_state()
|
save_config_to_file_updated_state()
|
||||||
|
@ -280,7 +280,7 @@ def get_config_types():
|
||||||
for search_type in SearchType
|
for search_type in SearchType
|
||||||
if (
|
if (
|
||||||
search_type.value in configured_content_types
|
search_type.value in configured_content_types
|
||||||
and getattr(state.model, f"{search_type.value}_search") is not None
|
and getattr(state.content_index, search_type.value) is not None
|
||||||
)
|
)
|
||||||
or ("plugins" in configured_content_types and search_type.name in configured_content_types["plugins"])
|
or ("plugins" in configured_content_types and search_type.name in configured_content_types["plugins"])
|
||||||
or search_type == SearchType.All
|
or search_type == SearchType.All
|
||||||
|
@ -308,7 +308,7 @@ async def search(
|
||||||
if q is None or q == "":
|
if q is None or q == "":
|
||||||
logger.warning(f"No query param (q) passed in API call to initiate search")
|
logger.warning(f"No query param (q) passed in API call to initiate search")
|
||||||
return results
|
return results
|
||||||
if not state.model or not any(state.model.__dict__.values()):
|
if not state.search_models or not any(state.search_models.__dict__.values()):
|
||||||
logger.warning(f"No search models loaded. Configure a search model before initiating search")
|
logger.warning(f"No search models loaded. Configure a search model before initiating search")
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
@ -332,7 +332,7 @@ async def search(
|
||||||
encoded_asymmetric_query = None
|
encoded_asymmetric_query = None
|
||||||
if t == SearchType.All or t != SearchType.Image:
|
if t == SearchType.All or t != SearchType.Image:
|
||||||
text_search_models: List[TextSearchModel] = [
|
text_search_models: List[TextSearchModel] = [
|
||||||
model for model in state.model.__dict__.values() if isinstance(model, TextSearchModel)
|
model for model in state.search_models.__dict__.values() if isinstance(model, TextSearchModel)
|
||||||
]
|
]
|
||||||
if text_search_models:
|
if text_search_models:
|
||||||
with timer("Encoding query took", logger=logger):
|
with timer("Encoding query took", logger=logger):
|
||||||
|
@ -345,13 +345,14 @@ async def search(
|
||||||
)
|
)
|
||||||
|
|
||||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||||
if (t == SearchType.Org or t == SearchType.All) and state.model.org_search:
|
if (t == SearchType.Org or t == SearchType.All) and state.content_index.org and state.search_models.text_search:
|
||||||
# query org-mode notes
|
# query org-mode notes
|
||||||
search_futures += [
|
search_futures += [
|
||||||
executor.submit(
|
executor.submit(
|
||||||
text_search.query,
|
text_search.query,
|
||||||
user_query,
|
user_query,
|
||||||
state.model.org_search,
|
state.search_models.text_search,
|
||||||
|
state.content_index.org,
|
||||||
question_embedding=encoded_asymmetric_query,
|
question_embedding=encoded_asymmetric_query,
|
||||||
rank_results=r or False,
|
rank_results=r or False,
|
||||||
score_threshold=score_threshold,
|
score_threshold=score_threshold,
|
||||||
|
@ -359,13 +360,18 @@ async def search(
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
if (t == SearchType.Markdown or t == SearchType.All) and state.model.markdown_search:
|
if (
|
||||||
|
(t == SearchType.Markdown or t == SearchType.All)
|
||||||
|
and state.content_index.markdown
|
||||||
|
and state.search_models.text_search
|
||||||
|
):
|
||||||
# query markdown notes
|
# query markdown notes
|
||||||
search_futures += [
|
search_futures += [
|
||||||
executor.submit(
|
executor.submit(
|
||||||
text_search.query,
|
text_search.query,
|
||||||
user_query,
|
user_query,
|
||||||
state.model.markdown_search,
|
state.search_models.text_search,
|
||||||
|
state.content_index.markdown,
|
||||||
question_embedding=encoded_asymmetric_query,
|
question_embedding=encoded_asymmetric_query,
|
||||||
rank_results=r or False,
|
rank_results=r or False,
|
||||||
score_threshold=score_threshold,
|
score_threshold=score_threshold,
|
||||||
|
@ -373,13 +379,18 @@ async def search(
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
if (t == SearchType.Github or t == SearchType.All) and state.model.github_search:
|
if (
|
||||||
|
(t == SearchType.Github or t == SearchType.All)
|
||||||
|
and state.content_index.github
|
||||||
|
and state.search_models.text_search
|
||||||
|
):
|
||||||
# query github issues
|
# query github issues
|
||||||
search_futures += [
|
search_futures += [
|
||||||
executor.submit(
|
executor.submit(
|
||||||
text_search.query,
|
text_search.query,
|
||||||
user_query,
|
user_query,
|
||||||
state.model.github_search,
|
state.search_models.text_search,
|
||||||
|
state.content_index.github,
|
||||||
question_embedding=encoded_asymmetric_query,
|
question_embedding=encoded_asymmetric_query,
|
||||||
rank_results=r or False,
|
rank_results=r or False,
|
||||||
score_threshold=score_threshold,
|
score_threshold=score_threshold,
|
||||||
|
@ -387,13 +398,14 @@ async def search(
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
if (t == SearchType.Pdf or t == SearchType.All) and state.model.pdf_search:
|
if (t == SearchType.Pdf or t == SearchType.All) and state.content_index.pdf and state.search_models.text_search:
|
||||||
# query pdf files
|
# query pdf files
|
||||||
search_futures += [
|
search_futures += [
|
||||||
executor.submit(
|
executor.submit(
|
||||||
text_search.query,
|
text_search.query,
|
||||||
user_query,
|
user_query,
|
||||||
state.model.pdf_search,
|
state.search_models.text_search,
|
||||||
|
state.content_index.pdf,
|
||||||
question_embedding=encoded_asymmetric_query,
|
question_embedding=encoded_asymmetric_query,
|
||||||
rank_results=r or False,
|
rank_results=r or False,
|
||||||
score_threshold=score_threshold,
|
score_threshold=score_threshold,
|
||||||
|
@ -401,26 +413,38 @@ async def search(
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
if (t == SearchType.Image) and state.model.image_search:
|
if (t == SearchType.Image) and state.content_index.image and state.search_models.image_search:
|
||||||
# query images
|
# query images
|
||||||
search_futures += [
|
search_futures += [
|
||||||
executor.submit(
|
executor.submit(
|
||||||
image_search.query,
|
image_search.query,
|
||||||
user_query,
|
user_query,
|
||||||
results_count,
|
results_count,
|
||||||
state.model.image_search,
|
state.search_models.image_search,
|
||||||
|
state.content_index.image,
|
||||||
score_threshold=score_threshold,
|
score_threshold=score_threshold,
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
if (t == SearchType.All or t in SearchType) and state.model.plugin_search:
|
if (
|
||||||
|
(t == SearchType.All or t in SearchType)
|
||||||
|
and state.content_index.plugins
|
||||||
|
and state.search_models.plugin_search
|
||||||
|
):
|
||||||
# query specified plugin type
|
# query specified plugin type
|
||||||
|
# Get plugin content, search model for specified search type, or the first one if none specified
|
||||||
|
plugin_search = state.search_models.plugin_search.get(t.value) or next(
|
||||||
|
iter(state.search_models.plugin_search.values())
|
||||||
|
)
|
||||||
|
plugin_content = state.content_index.plugins.get(t.value) or next(
|
||||||
|
iter(state.content_index.plugins.values())
|
||||||
|
)
|
||||||
search_futures += [
|
search_futures += [
|
||||||
executor.submit(
|
executor.submit(
|
||||||
text_search.query,
|
text_search.query,
|
||||||
user_query,
|
user_query,
|
||||||
# Get plugin search model for specified search type, or the first one if none specified
|
plugin_search,
|
||||||
state.model.plugin_search.get(t.value) or next(iter(state.model.plugin_search.values())),
|
plugin_content,
|
||||||
question_embedding=encoded_asymmetric_query,
|
question_embedding=encoded_asymmetric_query,
|
||||||
rank_results=r or False,
|
rank_results=r or False,
|
||||||
score_threshold=score_threshold,
|
score_threshold=score_threshold,
|
||||||
|
@ -428,13 +452,18 @@ async def search(
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
if (t == SearchType.Notion or t == SearchType.All) and state.model.notion_search:
|
if (
|
||||||
|
(t == SearchType.Notion or t == SearchType.All)
|
||||||
|
and state.content_index.notion
|
||||||
|
and state.search_models.text_search
|
||||||
|
):
|
||||||
# query notion pages
|
# query notion pages
|
||||||
search_futures += [
|
search_futures += [
|
||||||
executor.submit(
|
executor.submit(
|
||||||
text_search.query,
|
text_search.query,
|
||||||
user_query,
|
user_query,
|
||||||
state.model.notion_search,
|
state.search_models.text_search,
|
||||||
|
state.content_index.notion,
|
||||||
question_embedding=encoded_asymmetric_query,
|
question_embedding=encoded_asymmetric_query,
|
||||||
rank_results=r or False,
|
rank_results=r or False,
|
||||||
score_threshold=score_threshold,
|
score_threshold=score_threshold,
|
||||||
|
@ -445,13 +474,13 @@ async def search(
|
||||||
# Query across each requested content types in parallel
|
# Query across each requested content types in parallel
|
||||||
with timer("Query took", logger):
|
with timer("Query took", logger):
|
||||||
for search_future in concurrent.futures.as_completed(search_futures):
|
for search_future in concurrent.futures.as_completed(search_futures):
|
||||||
if t == SearchType.Image:
|
if t == SearchType.Image and state.content_index.image:
|
||||||
hits = await search_future.result()
|
hits = await search_future.result()
|
||||||
output_directory = constants.web_directory / "images"
|
output_directory = constants.web_directory / "images"
|
||||||
# Collate results
|
# Collate results
|
||||||
results += image_search.collate_results(
|
results += image_search.collate_results(
|
||||||
hits,
|
hits,
|
||||||
image_names=state.model.image_search.image_names,
|
image_names=state.content_index.image.image_names,
|
||||||
output_directory=output_directory,
|
output_directory=output_directory,
|
||||||
image_files_url="/static/images",
|
image_files_url="/static/images",
|
||||||
count=results_count,
|
count=results_count,
|
||||||
|
@ -498,7 +527,12 @@ def update(
|
||||||
try:
|
try:
|
||||||
state.search_index_lock.acquire()
|
state.search_index_lock.acquire()
|
||||||
try:
|
try:
|
||||||
state.model = configure_search(state.model, state.config, regenerate=force or False, t=t)
|
if state.config and state.config.search_type:
|
||||||
|
state.search_models = configure_search(state.search_models, state.config.search_type)
|
||||||
|
if state.search_models:
|
||||||
|
state.content_index = configure_content(
|
||||||
|
state.content_index, state.config.content_type, state.search_models, regenerate=force or False, t=t
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(e)
|
logger.error(e)
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
|
@ -12,10 +12,12 @@ from sentence_transformers import SentenceTransformer, util
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from tqdm import trange
|
from tqdm import trange
|
||||||
import torch
|
import torch
|
||||||
|
from khoj.utils import state
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from khoj.utils.helpers import get_absolute_path, get_from_dict, resolve_absolute_path, load_model, timer
|
from khoj.utils.helpers import get_absolute_path, get_from_dict, resolve_absolute_path, load_model, timer
|
||||||
from khoj.utils.config import ImageSearchModel
|
from khoj.utils.config import ImageContent, ImageSearchModel
|
||||||
|
from khoj.utils.models import BaseEncoder
|
||||||
from khoj.utils.rawconfig import ImageContentConfig, ImageSearchConfig, SearchResponse
|
from khoj.utils.rawconfig import ImageContentConfig, ImageSearchConfig, SearchResponse
|
||||||
|
|
||||||
|
|
||||||
|
@ -40,7 +42,7 @@ def initialize_model(search_config: ImageSearchConfig):
|
||||||
model_type=search_config.encoder_type or SentenceTransformer,
|
model_type=search_config.encoder_type or SentenceTransformer,
|
||||||
)
|
)
|
||||||
|
|
||||||
return encoder
|
return ImageSearchModel(encoder)
|
||||||
|
|
||||||
|
|
||||||
def extract_entries(image_directories):
|
def extract_entries(image_directories):
|
||||||
|
@ -143,7 +145,9 @@ def extract_metadata(image_name):
|
||||||
return image_processed_metadata
|
return image_processed_metadata
|
||||||
|
|
||||||
|
|
||||||
async def query(raw_query, count, model: ImageSearchModel, score_threshold: float = -math.inf):
|
async def query(
|
||||||
|
raw_query, count, search_model: ImageSearchModel, content: ImageContent, score_threshold: float = -math.inf
|
||||||
|
):
|
||||||
# Set query to image content if query is of form file:/path/to/file.png
|
# Set query to image content if query is of form file:/path/to/file.png
|
||||||
if raw_query.startswith("file:") and pathlib.Path(raw_query[5:]).is_file():
|
if raw_query.startswith("file:") and pathlib.Path(raw_query[5:]).is_file():
|
||||||
query_imagepath = resolve_absolute_path(pathlib.Path(raw_query[5:]), strict=True)
|
query_imagepath = resolve_absolute_path(pathlib.Path(raw_query[5:]), strict=True)
|
||||||
|
@ -158,21 +162,21 @@ async def query(raw_query, count, model: ImageSearchModel, score_threshold: floa
|
||||||
|
|
||||||
# Now we encode the query (which can either be an image or a text string)
|
# Now we encode the query (which can either be an image or a text string)
|
||||||
with timer("Query Encode Time", logger):
|
with timer("Query Encode Time", logger):
|
||||||
query_embedding = model.image_encoder.encode([query], convert_to_tensor=True, show_progress_bar=False)
|
query_embedding = search_model.image_encoder.encode([query], convert_to_tensor=True, show_progress_bar=False)
|
||||||
|
|
||||||
# Compute top_k ranked images based on cosine-similarity b/w query and all image embeddings.
|
# Compute top_k ranked images based on cosine-similarity b/w query and all image embeddings.
|
||||||
with timer("Search Time", logger):
|
with timer("Search Time", logger):
|
||||||
image_hits = {
|
image_hits = {
|
||||||
result["corpus_id"]: {"image_score": result["score"], "score": result["score"]}
|
result["corpus_id"]: {"image_score": result["score"], "score": result["score"]}
|
||||||
for result in util.semantic_search(query_embedding, model.image_embeddings, top_k=count)[0]
|
for result in util.semantic_search(query_embedding, content.image_embeddings, top_k=count)[0]
|
||||||
}
|
}
|
||||||
|
|
||||||
# Compute top_k ranked images based on cosine-similarity b/w query and all image metadata embeddings.
|
# Compute top_k ranked images based on cosine-similarity b/w query and all image metadata embeddings.
|
||||||
if model.image_metadata_embeddings:
|
if content.image_metadata_embeddings:
|
||||||
with timer("Metadata Search Time", logger):
|
with timer("Metadata Search Time", logger):
|
||||||
metadata_hits = {
|
metadata_hits = {
|
||||||
result["corpus_id"]: result["score"]
|
result["corpus_id"]: result["score"]
|
||||||
for result in util.semantic_search(query_embedding, model.image_metadata_embeddings, top_k=count)[0]
|
for result in util.semantic_search(query_embedding, content.image_metadata_embeddings, top_k=count)[0]
|
||||||
}
|
}
|
||||||
|
|
||||||
# Sum metadata, image scores of the highest ranked images
|
# Sum metadata, image scores of the highest ranked images
|
||||||
|
@ -239,10 +243,7 @@ def collate_results(hits, image_names, output_directory, image_files_url, count=
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
def setup(config: ImageContentConfig, search_config: ImageSearchConfig, regenerate: bool) -> ImageSearchModel:
|
def setup(config: ImageContentConfig, encoder: BaseEncoder, regenerate: bool) -> ImageContent:
|
||||||
# Initialize Model
|
|
||||||
encoder = initialize_model(search_config)
|
|
||||||
|
|
||||||
# Extract Entries
|
# Extract Entries
|
||||||
absolute_image_files, filtered_image_files = set(), set()
|
absolute_image_files, filtered_image_files = set(), set()
|
||||||
if config.input_directories:
|
if config.input_directories:
|
||||||
|
@ -268,4 +269,4 @@ def setup(config: ImageContentConfig, search_config: ImageSearchConfig, regenera
|
||||||
use_xmp_metadata=config.use_xmp_metadata,
|
use_xmp_metadata=config.use_xmp_metadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
return ImageSearchModel(all_image_files, image_embeddings, image_metadata_embeddings, encoder)
|
return ImageContent(all_image_files, image_embeddings, image_metadata_embeddings)
|
||||||
|
|
|
@ -13,7 +13,7 @@ from khoj.search_filter.base_filter import BaseFilter
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from khoj.utils import state
|
from khoj.utils import state
|
||||||
from khoj.utils.helpers import get_absolute_path, is_none_or_empty, resolve_absolute_path, load_model, timer
|
from khoj.utils.helpers import get_absolute_path, is_none_or_empty, resolve_absolute_path, load_model, timer
|
||||||
from khoj.utils.config import TextSearchModel
|
from khoj.utils.config import TextContent, TextSearchModel
|
||||||
from khoj.utils.models import BaseEncoder
|
from khoj.utils.models import BaseEncoder
|
||||||
from khoj.utils.rawconfig import SearchResponse, TextSearchConfig, TextConfigBase, Entry
|
from khoj.utils.rawconfig import SearchResponse, TextSearchConfig, TextConfigBase, Entry
|
||||||
from khoj.utils.jsonl import load_jsonl
|
from khoj.utils.jsonl import load_jsonl
|
||||||
|
@ -26,9 +26,6 @@ def initialize_model(search_config: TextSearchConfig):
|
||||||
"Initialize model for semantic search on text"
|
"Initialize model for semantic search on text"
|
||||||
torch.set_num_threads(4)
|
torch.set_num_threads(4)
|
||||||
|
|
||||||
# Number of entries we want to retrieve with the bi-encoder
|
|
||||||
top_k = 15
|
|
||||||
|
|
||||||
# If model directory is configured
|
# If model directory is configured
|
||||||
if search_config.model_directory:
|
if search_config.model_directory:
|
||||||
# Convert model directory to absolute path
|
# Convert model directory to absolute path
|
||||||
|
@ -52,7 +49,7 @@ def initialize_model(search_config: TextSearchConfig):
|
||||||
device=f"{state.device}",
|
device=f"{state.device}",
|
||||||
)
|
)
|
||||||
|
|
||||||
return bi_encoder, cross_encoder, top_k
|
return TextSearchModel(bi_encoder, cross_encoder)
|
||||||
|
|
||||||
|
|
||||||
def extract_entries(jsonl_file) -> List[Entry]:
|
def extract_entries(jsonl_file) -> List[Entry]:
|
||||||
|
@ -67,7 +64,7 @@ def compute_embeddings(
|
||||||
new_entries = []
|
new_entries = []
|
||||||
# Load pre-computed embeddings from file if exists and update them if required
|
# Load pre-computed embeddings from file if exists and update them if required
|
||||||
if embeddings_file.exists() and not regenerate:
|
if embeddings_file.exists() and not regenerate:
|
||||||
corpus_embeddings = torch.load(get_absolute_path(embeddings_file), map_location=state.device)
|
corpus_embeddings: torch.Tensor = torch.load(get_absolute_path(embeddings_file), map_location=state.device)
|
||||||
logger.debug(f"Loaded {len(corpus_embeddings)} text embeddings from {embeddings_file}")
|
logger.debug(f"Loaded {len(corpus_embeddings)} text embeddings from {embeddings_file}")
|
||||||
|
|
||||||
# Encode any new entries in the corpus and update corpus embeddings
|
# Encode any new entries in the corpus and update corpus embeddings
|
||||||
|
@ -104,17 +101,18 @@ def compute_embeddings(
|
||||||
|
|
||||||
async def query(
|
async def query(
|
||||||
raw_query: str,
|
raw_query: str,
|
||||||
model: TextSearchModel,
|
search_model: TextSearchModel,
|
||||||
|
content: TextContent,
|
||||||
question_embedding: Union[torch.Tensor, None] = None,
|
question_embedding: Union[torch.Tensor, None] = None,
|
||||||
rank_results: bool = False,
|
rank_results: bool = False,
|
||||||
score_threshold: float = -math.inf,
|
score_threshold: float = -math.inf,
|
||||||
dedupe: bool = True,
|
dedupe: bool = True,
|
||||||
) -> Tuple[List[dict], List[Entry]]:
|
) -> Tuple[List[dict], List[Entry]]:
|
||||||
"Search for entries that answer the query"
|
"Search for entries that answer the query"
|
||||||
query, entries, corpus_embeddings = raw_query, model.entries, model.corpus_embeddings
|
query, entries, corpus_embeddings = raw_query, content.entries, content.corpus_embeddings
|
||||||
|
|
||||||
# Filter query, entries and embeddings before semantic search
|
# Filter query, entries and embeddings before semantic search
|
||||||
query, entries, corpus_embeddings = apply_filters(query, entries, corpus_embeddings, model.filters)
|
query, entries, corpus_embeddings = apply_filters(query, entries, corpus_embeddings, content.filters)
|
||||||
|
|
||||||
# If no entries left after filtering, return empty results
|
# If no entries left after filtering, return empty results
|
||||||
if entries is None or len(entries) == 0:
|
if entries is None or len(entries) == 0:
|
||||||
|
@ -127,18 +125,17 @@ async def query(
|
||||||
# Encode the query using the bi-encoder
|
# Encode the query using the bi-encoder
|
||||||
if question_embedding is None:
|
if question_embedding is None:
|
||||||
with timer("Query Encode Time", logger, state.device):
|
with timer("Query Encode Time", logger, state.device):
|
||||||
question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True, device=state.device)
|
question_embedding = search_model.bi_encoder.encode([query], convert_to_tensor=True, device=state.device)
|
||||||
question_embedding = util.normalize_embeddings(question_embedding)
|
question_embedding = util.normalize_embeddings(question_embedding)
|
||||||
|
|
||||||
# Find relevant entries for the query
|
# Find relevant entries for the query
|
||||||
|
top_k = min(len(entries), search_model.top_k or 10) # top_k hits can't be more than the total entries in corpus
|
||||||
with timer("Search Time", logger, state.device):
|
with timer("Search Time", logger, state.device):
|
||||||
hits = util.semantic_search(
|
hits = util.semantic_search(question_embedding, corpus_embeddings, top_k, score_function=util.dot_score)[0]
|
||||||
question_embedding, corpus_embeddings, top_k=model.top_k, score_function=util.dot_score
|
|
||||||
)[0]
|
|
||||||
|
|
||||||
# Score all retrieved entries using the cross-encoder
|
# Score all retrieved entries using the cross-encoder
|
||||||
if rank_results:
|
if rank_results and search_model.cross_encoder:
|
||||||
hits = cross_encoder_score(model.cross_encoder, query, entries, hits)
|
hits = cross_encoder_score(search_model.cross_encoder, query, entries, hits)
|
||||||
|
|
||||||
# Filter results by score threshold
|
# Filter results by score threshold
|
||||||
hits = [hit for hit in hits if hit.get("cross-score", hit.get("score")) >= score_threshold]
|
hits = [hit for hit in hits if hit.get("cross-score", hit.get("score")) >= score_threshold]
|
||||||
|
@ -173,13 +170,10 @@ def collate_results(hits, entries: List[Entry], count=5) -> List[SearchResponse]
|
||||||
def setup(
|
def setup(
|
||||||
text_to_jsonl: Type[TextToJsonl],
|
text_to_jsonl: Type[TextToJsonl],
|
||||||
config: TextConfigBase,
|
config: TextConfigBase,
|
||||||
search_config: TextSearchConfig,
|
bi_encoder: BaseEncoder,
|
||||||
regenerate: bool,
|
regenerate: bool,
|
||||||
filters: List[BaseFilter] = [],
|
filters: List[BaseFilter] = [],
|
||||||
) -> TextSearchModel:
|
) -> TextContent:
|
||||||
# Initialize Model
|
|
||||||
bi_encoder, cross_encoder, top_k = initialize_model(search_config)
|
|
||||||
|
|
||||||
# Map notes in text files to (compressed) JSONL formatted file
|
# Map notes in text files to (compressed) JSONL formatted file
|
||||||
config.compressed_jsonl = resolve_absolute_path(config.compressed_jsonl)
|
config.compressed_jsonl = resolve_absolute_path(config.compressed_jsonl)
|
||||||
previous_entries = (
|
previous_entries = (
|
||||||
|
@ -192,7 +186,6 @@ def setup(
|
||||||
if is_none_or_empty(entries):
|
if is_none_or_empty(entries):
|
||||||
config_params = ", ".join([f"{key}={value}" for key, value in config.dict().items()])
|
config_params = ", ".join([f"{key}={value}" for key, value in config.dict().items()])
|
||||||
raise ValueError(f"No valid entries found in specified files: {config_params}")
|
raise ValueError(f"No valid entries found in specified files: {config_params}")
|
||||||
top_k = min(len(entries), top_k) # top_k hits can't be more than the total entries in corpus
|
|
||||||
|
|
||||||
# Compute or Load Embeddings
|
# Compute or Load Embeddings
|
||||||
config.embeddings_file = resolve_absolute_path(config.embeddings_file)
|
config.embeddings_file = resolve_absolute_path(config.embeddings_file)
|
||||||
|
@ -203,7 +196,7 @@ def setup(
|
||||||
for filter in filters:
|
for filter in filters:
|
||||||
filter.load(entries, regenerate=regenerate)
|
filter.load(entries, regenerate=regenerate)
|
||||||
|
|
||||||
return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, filters, top_k)
|
return TextContent(entries, corpus_embeddings, filters)
|
||||||
|
|
||||||
|
|
||||||
def apply_filters(
|
def apply_filters(
|
||||||
|
|
|
@ -3,7 +3,7 @@ from __future__ import annotations # to avoid quoting type hints
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Dict, List, Union
|
from typing import TYPE_CHECKING, Dict, List, Optional, Union
|
||||||
|
|
||||||
# External Packages
|
# External Packages
|
||||||
import torch
|
import torch
|
||||||
|
@ -30,42 +30,48 @@ class ProcessorType(str, Enum):
|
||||||
Conversation = "conversation"
|
Conversation = "conversation"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TextContent:
|
||||||
|
entries: List[Entry]
|
||||||
|
corpus_embeddings: torch.Tensor
|
||||||
|
filters: List[BaseFilter]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ImageContent:
|
||||||
|
image_names: List[str]
|
||||||
|
image_embeddings: torch.Tensor
|
||||||
|
image_metadata_embeddings: torch.Tensor
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
class TextSearchModel:
|
class TextSearchModel:
|
||||||
def __init__(
|
bi_encoder: BaseEncoder
|
||||||
self,
|
cross_encoder: Optional[CrossEncoder] = None
|
||||||
entries: List[Entry],
|
top_k: Optional[int] = 15
|
||||||
corpus_embeddings: torch.Tensor,
|
|
||||||
bi_encoder: BaseEncoder,
|
|
||||||
cross_encoder: CrossEncoder,
|
|
||||||
filters: List[BaseFilter],
|
|
||||||
top_k,
|
|
||||||
):
|
|
||||||
self.entries = entries
|
|
||||||
self.corpus_embeddings = corpus_embeddings
|
|
||||||
self.bi_encoder = bi_encoder
|
|
||||||
self.cross_encoder = cross_encoder
|
|
||||||
self.filters = filters
|
|
||||||
self.top_k = top_k
|
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
class ImageSearchModel:
|
class ImageSearchModel:
|
||||||
def __init__(self, image_names, image_embeddings, image_metadata_embeddings, image_encoder: BaseEncoder):
|
image_encoder: BaseEncoder
|
||||||
self.image_encoder = image_encoder
|
|
||||||
self.image_names = image_names
|
|
||||||
self.image_embeddings = image_embeddings
|
@dataclass
|
||||||
self.image_metadata_embeddings = image_metadata_embeddings
|
class ContentIndex:
|
||||||
self.image_encoder = image_encoder
|
org: Optional[TextContent] = None
|
||||||
|
markdown: Optional[TextContent] = None
|
||||||
|
pdf: Optional[TextContent] = None
|
||||||
|
github: Optional[TextContent] = None
|
||||||
|
notion: Optional[TextContent] = None
|
||||||
|
image: Optional[ImageContent] = None
|
||||||
|
plugins: Optional[Dict[str, TextContent]] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SearchModels:
|
class SearchModels:
|
||||||
org_search: Union[TextSearchModel, None] = None
|
text_search: Optional[TextSearchModel] = None
|
||||||
markdown_search: Union[TextSearchModel, None] = None
|
image_search: Optional[ImageSearchModel] = None
|
||||||
pdf_search: Union[TextSearchModel, None] = None
|
plugin_search: Optional[Dict[str, TextSearchModel]] = None
|
||||||
image_search: Union[ImageSearchModel, None] = None
|
|
||||||
github_search: Union[TextSearchModel, None] = None
|
|
||||||
notion_search: Union[TextSearchModel, None] = None
|
|
||||||
plugin_search: Union[Dict[str, TextSearchModel], None] = None
|
|
||||||
|
|
||||||
|
|
||||||
class ConversationProcessorConfigModel:
|
class ConversationProcessorConfigModel:
|
||||||
|
|
|
@ -20,7 +20,7 @@ from khoj.utils import constants
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
# External Packages
|
# External Packages
|
||||||
from sentence_transformers import CrossEncoder
|
from sentence_transformers import SentenceTransformer, CrossEncoder
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from khoj.utils.models import BaseEncoder
|
from khoj.utils.models import BaseEncoder
|
||||||
|
@ -64,7 +64,9 @@ def merge_dicts(priority_dict: dict, default_dict: dict):
|
||||||
return merged_dict
|
return merged_dict
|
||||||
|
|
||||||
|
|
||||||
def load_model(model_name: str, model_type, model_dir=None, device: str = None) -> Union[BaseEncoder, CrossEncoder]:
|
def load_model(
|
||||||
|
model_name: str, model_type, model_dir=None, device: str = None
|
||||||
|
) -> Union[BaseEncoder, SentenceTransformer, CrossEncoder]:
|
||||||
"Load model from disk or huggingface"
|
"Load model from disk or huggingface"
|
||||||
# Construct model path
|
# Construct model path
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
|
@ -119,9 +119,9 @@ class AppConfig(ConfigBase):
|
||||||
|
|
||||||
|
|
||||||
class FullConfig(ConfigBase):
|
class FullConfig(ConfigBase):
|
||||||
content_type: Optional[ContentConfig]
|
content_type: Optional[ContentConfig] = None
|
||||||
search_type: Optional[SearchConfig]
|
search_type: Optional[SearchConfig] = None
|
||||||
processor: Optional[ProcessorConfig]
|
processor: Optional[ProcessorConfig] = None
|
||||||
app: Optional[AppConfig] = AppConfig(should_log_telemetry=True)
|
app: Optional[AppConfig] = AppConfig(should_log_telemetry=True)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -9,13 +9,14 @@ from pathlib import Path
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from khoj.utils import config as utils_config
|
from khoj.utils import config as utils_config
|
||||||
from khoj.utils.config import SearchModels, ProcessorConfigModel
|
from khoj.utils.config import ContentIndex, SearchModels, ProcessorConfigModel
|
||||||
from khoj.utils.helpers import LRU
|
from khoj.utils.helpers import LRU
|
||||||
from khoj.utils.rawconfig import FullConfig
|
from khoj.utils.rawconfig import FullConfig
|
||||||
|
|
||||||
# Application Global State
|
# Application Global State
|
||||||
config = FullConfig()
|
config = FullConfig()
|
||||||
model = SearchModels()
|
search_models = SearchModels()
|
||||||
|
content_index = ContentIndex()
|
||||||
processor_config = ProcessorConfigModel()
|
processor_config = ProcessorConfigModel()
|
||||||
config_file: Path = None
|
config_file: Path = None
|
||||||
verbose: int = 0
|
verbose: int = 0
|
||||||
|
|
|
@ -10,6 +10,7 @@ from khoj.main import app
|
||||||
from khoj.configure import configure_processor, configure_routes, configure_search_types
|
from khoj.configure import configure_processor, configure_routes, configure_search_types
|
||||||
from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl
|
from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl
|
||||||
from khoj.search_type import image_search, text_search
|
from khoj.search_type import image_search, text_search
|
||||||
|
from khoj.utils.config import ImageContent, SearchModels, TextContent
|
||||||
from khoj.utils.helpers import resolve_absolute_path
|
from khoj.utils.helpers import resolve_absolute_path
|
||||||
from khoj.utils.rawconfig import (
|
from khoj.utils.rawconfig import (
|
||||||
ContentConfig,
|
ContentConfig,
|
||||||
|
@ -41,35 +42,49 @@ def search_config() -> SearchConfig:
|
||||||
encoder="sentence-transformers/all-MiniLM-L6-v2",
|
encoder="sentence-transformers/all-MiniLM-L6-v2",
|
||||||
cross_encoder="cross-encoder/ms-marco-MiniLM-L-6-v2",
|
cross_encoder="cross-encoder/ms-marco-MiniLM-L-6-v2",
|
||||||
model_directory=model_dir / "symmetric/",
|
model_directory=model_dir / "symmetric/",
|
||||||
|
encoder_type=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
search_config.asymmetric = TextSearchConfig(
|
search_config.asymmetric = TextSearchConfig(
|
||||||
encoder="sentence-transformers/multi-qa-MiniLM-L6-cos-v1",
|
encoder="sentence-transformers/multi-qa-MiniLM-L6-cos-v1",
|
||||||
cross_encoder="cross-encoder/ms-marco-MiniLM-L-6-v2",
|
cross_encoder="cross-encoder/ms-marco-MiniLM-L-6-v2",
|
||||||
model_directory=model_dir / "asymmetric/",
|
model_directory=model_dir / "asymmetric/",
|
||||||
|
encoder_type=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
search_config.image = ImageSearchConfig(
|
search_config.image = ImageSearchConfig(
|
||||||
encoder="sentence-transformers/clip-ViT-B-32", model_directory=model_dir / "image/"
|
encoder="sentence-transformers/clip-ViT-B-32",
|
||||||
|
model_directory=model_dir / "image/",
|
||||||
|
encoder_type=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
return search_config
|
return search_config
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def content_config(tmp_path_factory, search_config: SearchConfig):
|
def search_models(search_config: SearchConfig):
|
||||||
|
search_models = SearchModels()
|
||||||
|
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
|
||||||
|
search_models.image_search = image_search.initialize_model(search_config.image)
|
||||||
|
|
||||||
|
return search_models
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def content_config(tmp_path_factory, search_models: SearchModels, search_config: SearchConfig):
|
||||||
content_dir = tmp_path_factory.mktemp("content")
|
content_dir = tmp_path_factory.mktemp("content")
|
||||||
|
|
||||||
# Generate Image Embeddings from Test Images
|
# Generate Image Embeddings from Test Images
|
||||||
content_config = ContentConfig()
|
content_config = ContentConfig()
|
||||||
content_config.image = ImageContentConfig(
|
content_config.image = ImageContentConfig(
|
||||||
|
input_filter=None,
|
||||||
input_directories=["tests/data/images"],
|
input_directories=["tests/data/images"],
|
||||||
embeddings_file=content_dir.joinpath("image_embeddings.pt"),
|
embeddings_file=content_dir.joinpath("image_embeddings.pt"),
|
||||||
batch_size=1,
|
batch_size=1,
|
||||||
use_xmp_metadata=False,
|
use_xmp_metadata=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
image_search.setup(content_config.image, search_config.image, regenerate=False)
|
image_search.setup(content_config.image, search_models.image_search.image_encoder, regenerate=False)
|
||||||
|
|
||||||
# Generate Notes Embeddings from Test Notes
|
# Generate Notes Embeddings from Test Notes
|
||||||
content_config.org = TextContentConfig(
|
content_config.org = TextContentConfig(
|
||||||
|
@ -80,7 +95,9 @@ def content_config(tmp_path_factory, search_config: SearchConfig):
|
||||||
)
|
)
|
||||||
|
|
||||||
filters = [DateFilter(), WordFilter(), FileFilter()]
|
filters = [DateFilter(), WordFilter(), FileFilter()]
|
||||||
text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters)
|
text_search.setup(
|
||||||
|
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False, filters=filters
|
||||||
|
)
|
||||||
|
|
||||||
content_config.plugins = {
|
content_config.plugins = {
|
||||||
"plugin1": TextContentConfig(
|
"plugin1": TextContentConfig(
|
||||||
|
@ -106,7 +123,11 @@ def content_config(tmp_path_factory, search_config: SearchConfig):
|
||||||
|
|
||||||
filters = [DateFilter(), WordFilter(), FileFilter()]
|
filters = [DateFilter(), WordFilter(), FileFilter()]
|
||||||
text_search.setup(
|
text_search.setup(
|
||||||
JsonlToJsonl, content_config.plugins["plugin1"], search_config.asymmetric, regenerate=False, filters=filters
|
JsonlToJsonl,
|
||||||
|
content_config.plugins["plugin1"],
|
||||||
|
search_models.text_search.bi_encoder,
|
||||||
|
regenerate=False,
|
||||||
|
filters=filters,
|
||||||
)
|
)
|
||||||
|
|
||||||
return content_config
|
return content_config
|
||||||
|
@ -157,8 +178,13 @@ def chat_client(md_content_config: ContentConfig, search_config: SearchConfig, p
|
||||||
|
|
||||||
# Index Markdown Content for Search
|
# Index Markdown Content for Search
|
||||||
filters = [DateFilter(), WordFilter(), FileFilter()]
|
filters = [DateFilter(), WordFilter(), FileFilter()]
|
||||||
state.model.markdown_search = text_search.setup(
|
state.search_models.text_search = text_search.initialize_model(search_config.asymmetric)
|
||||||
MarkdownToJsonl, md_content_config.markdown, search_config.asymmetric, regenerate=False, filters=filters
|
state.content_index.markdown = text_search.setup(
|
||||||
|
MarkdownToJsonl,
|
||||||
|
md_content_config.markdown,
|
||||||
|
state.search_models.text_search.bi_encoder,
|
||||||
|
regenerate=False,
|
||||||
|
filters=filters,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize Processor from Config
|
# Initialize Processor from Config
|
||||||
|
@ -175,8 +201,14 @@ def client(content_config: ContentConfig, search_config: SearchConfig, processor
|
||||||
state.SearchType = configure_search_types(state.config)
|
state.SearchType = configure_search_types(state.config)
|
||||||
|
|
||||||
# These lines help us Mock the Search models for these search types
|
# These lines help us Mock the Search models for these search types
|
||||||
state.model.org_search = {}
|
state.search_models.text_search = text_search.initialize_model(search_config.asymmetric)
|
||||||
state.model.image_search = {}
|
state.search_models.image_search = image_search.initialize_model(search_config.image)
|
||||||
|
state.content_index.org = text_search.setup(
|
||||||
|
OrgToJsonl, content_config.org, state.search_models.text_search.bi_encoder, regenerate=False
|
||||||
|
)
|
||||||
|
state.content_index.image = image_search.setup(
|
||||||
|
content_config.image, state.search_models.image_search, regenerate=False
|
||||||
|
)
|
||||||
|
|
||||||
configure_routes(app)
|
configure_routes(app)
|
||||||
return TestClient(app)
|
return TestClient(app)
|
||||||
|
|
|
@ -11,7 +11,8 @@ from fastapi.testclient import TestClient
|
||||||
from khoj.main import app
|
from khoj.main import app
|
||||||
from khoj.configure import configure_routes, configure_search_types
|
from khoj.configure import configure_routes, configure_search_types
|
||||||
from khoj.utils import state
|
from khoj.utils import state
|
||||||
from khoj.utils.state import model, config
|
from khoj.utils.config import SearchModels
|
||||||
|
from khoj.utils.state import search_models, content_index, config
|
||||||
from khoj.search_type import text_search, image_search
|
from khoj.search_type import text_search, image_search
|
||||||
from khoj.utils.rawconfig import ContentConfig, SearchConfig
|
from khoj.utils.rawconfig import ContentConfig, SearchConfig
|
||||||
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
|
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
|
||||||
|
@ -143,7 +144,10 @@ def test_get_configured_types_with_no_content_config():
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
def test_image_search(client, content_config: ContentConfig, search_config: SearchConfig):
|
def test_image_search(client, content_config: ContentConfig, search_config: SearchConfig):
|
||||||
# Arrange
|
# Arrange
|
||||||
model.image_search = image_search.setup(content_config.image, search_config.image, regenerate=False)
|
search_models.image_search = image_search.initialize_model(search_config.image)
|
||||||
|
content_index.image = image_search.setup(
|
||||||
|
content_config.image, search_models.image_search.image_encoder, regenerate=False
|
||||||
|
)
|
||||||
query_expected_image_pairs = [
|
query_expected_image_pairs = [
|
||||||
("kitten", "kitten_park.jpg"),
|
("kitten", "kitten_park.jpg"),
|
||||||
("a horse and dog on a leash", "horse_dog.jpg"),
|
("a horse and dog on a leash", "horse_dog.jpg"),
|
||||||
|
@ -166,7 +170,10 @@ def test_image_search(client, content_config: ContentConfig, search_config: Sear
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
def test_notes_search(client, content_config: ContentConfig, search_config: SearchConfig):
|
def test_notes_search(client, content_config: ContentConfig, search_config: SearchConfig):
|
||||||
# Arrange
|
# Arrange
|
||||||
model.org_search = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False)
|
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
|
||||||
|
content_index.org = text_search.setup(
|
||||||
|
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False
|
||||||
|
)
|
||||||
user_query = quote("How to git install application?")
|
user_query = quote("How to git install application?")
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
|
@ -183,8 +190,9 @@ def test_notes_search(client, content_config: ContentConfig, search_config: Sear
|
||||||
def test_notes_search_with_only_filters(client, content_config: ContentConfig, search_config: SearchConfig):
|
def test_notes_search_with_only_filters(client, content_config: ContentConfig, search_config: SearchConfig):
|
||||||
# Arrange
|
# Arrange
|
||||||
filters = [WordFilter(), FileFilter()]
|
filters = [WordFilter(), FileFilter()]
|
||||||
model.org_search = text_search.setup(
|
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
|
||||||
OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters
|
content_index.org = text_search.setup(
|
||||||
|
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False, filters=filters
|
||||||
)
|
)
|
||||||
user_query = quote('+"Emacs" file:"*.org"')
|
user_query = quote('+"Emacs" file:"*.org"')
|
||||||
|
|
||||||
|
@ -202,8 +210,9 @@ def test_notes_search_with_only_filters(client, content_config: ContentConfig, s
|
||||||
def test_notes_search_with_include_filter(client, content_config: ContentConfig, search_config: SearchConfig):
|
def test_notes_search_with_include_filter(client, content_config: ContentConfig, search_config: SearchConfig):
|
||||||
# Arrange
|
# Arrange
|
||||||
filters = [WordFilter()]
|
filters = [WordFilter()]
|
||||||
model.org_search = text_search.setup(
|
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
|
||||||
OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters
|
content_index.org = text_search.setup(
|
||||||
|
OrgToJsonl, content_config.org, search_models.text_search, regenerate=False, filters=filters
|
||||||
)
|
)
|
||||||
user_query = quote('How to git install application? +"Emacs"')
|
user_query = quote('How to git install application? +"Emacs"')
|
||||||
|
|
||||||
|
@ -221,8 +230,9 @@ def test_notes_search_with_include_filter(client, content_config: ContentConfig,
|
||||||
def test_notes_search_with_exclude_filter(client, content_config: ContentConfig, search_config: SearchConfig):
|
def test_notes_search_with_exclude_filter(client, content_config: ContentConfig, search_config: SearchConfig):
|
||||||
# Arrange
|
# Arrange
|
||||||
filters = [WordFilter()]
|
filters = [WordFilter()]
|
||||||
model.org_search = text_search.setup(
|
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
|
||||||
OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters
|
content_index.org = text_search.setup(
|
||||||
|
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False, filters=filters
|
||||||
)
|
)
|
||||||
user_query = quote('How to git install application? -"clone"')
|
user_query = quote('How to git install application? -"clone"')
|
||||||
|
|
||||||
|
|
|
@ -5,9 +5,10 @@ from PIL import Image
|
||||||
|
|
||||||
# External Packages
|
# External Packages
|
||||||
import pytest
|
import pytest
|
||||||
|
from khoj.utils.config import SearchModels
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from khoj.utils.state import model
|
from khoj.utils.state import content_index, search_models
|
||||||
from khoj.utils.constants import web_directory
|
from khoj.utils.constants import web_directory
|
||||||
from khoj.search_type import image_search
|
from khoj.search_type import image_search
|
||||||
from khoj.utils.helpers import resolve_absolute_path
|
from khoj.utils.helpers import resolve_absolute_path
|
||||||
|
@ -16,10 +17,12 @@ from khoj.utils.rawconfig import ContentConfig, SearchConfig
|
||||||
|
|
||||||
# Test
|
# Test
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
def test_image_search_setup(content_config: ContentConfig, search_config: SearchConfig):
|
def test_image_search_setup(content_config: ContentConfig, search_models: SearchModels):
|
||||||
# Act
|
# Act
|
||||||
# Regenerate image search embeddings during image setup
|
# Regenerate image search embeddings during image setup
|
||||||
image_search_model = image_search.setup(content_config.image, search_config.image, regenerate=True)
|
image_search_model = image_search.setup(
|
||||||
|
content_config.image, search_models.image_search.image_encoder, regenerate=True
|
||||||
|
)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert len(image_search_model.image_names) == 3
|
assert len(image_search_model.image_names) == 3
|
||||||
|
@ -54,8 +57,11 @@ def test_image_metadata(content_config: ContentConfig):
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_image_search(content_config: ContentConfig, search_config: SearchConfig):
|
async def test_image_search(content_config: ContentConfig, search_config: SearchConfig):
|
||||||
# Arrange
|
# Arrange
|
||||||
|
search_models.image_search = image_search.initialize_model(search_config.image)
|
||||||
|
content_index.image = image_search.setup(
|
||||||
|
content_config.image, search_models.image_search.image_encoder, regenerate=False
|
||||||
|
)
|
||||||
output_directory = resolve_absolute_path(web_directory)
|
output_directory = resolve_absolute_path(web_directory)
|
||||||
model.image_search = image_search.setup(content_config.image, search_config.image, regenerate=False)
|
|
||||||
query_expected_image_pairs = [
|
query_expected_image_pairs = [
|
||||||
("kitten", "kitten_park.jpg"),
|
("kitten", "kitten_park.jpg"),
|
||||||
("horse and dog in a farm", "horse_dog.jpg"),
|
("horse and dog in a farm", "horse_dog.jpg"),
|
||||||
|
@ -64,11 +70,13 @@ async def test_image_search(content_config: ContentConfig, search_config: Search
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
for query, expected_image_name in query_expected_image_pairs:
|
for query, expected_image_name in query_expected_image_pairs:
|
||||||
hits = await image_search.query(query, count=1, model=model.image_search)
|
hits = await image_search.query(
|
||||||
|
query, count=1, search_model=search_models.image_search, content=content_index.image
|
||||||
|
)
|
||||||
|
|
||||||
results = image_search.collate_results(
|
results = image_search.collate_results(
|
||||||
hits,
|
hits,
|
||||||
model.image_search.image_names,
|
content_index.image.image_names,
|
||||||
output_directory=output_directory,
|
output_directory=output_directory,
|
||||||
image_files_url="/static/images",
|
image_files_url="/static/images",
|
||||||
count=1,
|
count=1,
|
||||||
|
@ -90,7 +98,10 @@ async def test_image_search(content_config: ContentConfig, search_config: Search
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_image_search_query_truncated(content_config: ContentConfig, search_config: SearchConfig, caplog):
|
async def test_image_search_query_truncated(content_config: ContentConfig, search_config: SearchConfig, caplog):
|
||||||
# Arrange
|
# Arrange
|
||||||
model.image_search = image_search.setup(content_config.image, search_config.image, regenerate=False)
|
search_models.image_search = image_search.initialize_model(search_config.image)
|
||||||
|
content_index.image = image_search.setup(
|
||||||
|
content_config.image, search_models.image_search.image_encoder, regenerate=False
|
||||||
|
)
|
||||||
max_words_supported = 10
|
max_words_supported = 10
|
||||||
query = " ".join(["hello"] * 100)
|
query = " ".join(["hello"] * 100)
|
||||||
truncated_query = " ".join(["hello"] * max_words_supported)
|
truncated_query = " ".join(["hello"] * max_words_supported)
|
||||||
|
@ -98,7 +109,9 @@ async def test_image_search_query_truncated(content_config: ContentConfig, searc
|
||||||
# Act
|
# Act
|
||||||
try:
|
try:
|
||||||
with caplog.at_level(logging.INFO, logger="khoj.search_type.image_search"):
|
with caplog.at_level(logging.INFO, logger="khoj.search_type.image_search"):
|
||||||
await image_search.query(query, count=1, model=model.image_search)
|
await image_search.query(
|
||||||
|
query, count=1, search_model=search_models.image_search, content=content_index.image
|
||||||
|
)
|
||||||
# Assert
|
# Assert
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
if "The size of tensor a (102) must match the size of tensor b (77)" in str(e):
|
if "The size of tensor a (102) must match the size of tensor b (77)" in str(e):
|
||||||
|
@ -110,8 +123,11 @@ async def test_image_search_query_truncated(content_config: ContentConfig, searc
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_image_search_by_filepath(content_config: ContentConfig, search_config: SearchConfig, caplog):
|
async def test_image_search_by_filepath(content_config: ContentConfig, search_config: SearchConfig, caplog):
|
||||||
# Arrange
|
# Arrange
|
||||||
|
search_models.image_search = image_search.initialize_model(search_config.image)
|
||||||
|
content_index.image = image_search.setup(
|
||||||
|
content_config.image, search_models.image_search.image_encoder, regenerate=False
|
||||||
|
)
|
||||||
output_directory = resolve_absolute_path(web_directory)
|
output_directory = resolve_absolute_path(web_directory)
|
||||||
model.image_search = image_search.setup(content_config.image, search_config.image, regenerate=False)
|
|
||||||
image_directory = content_config.image.input_directories[0]
|
image_directory = content_config.image.input_directories[0]
|
||||||
|
|
||||||
query = f"file:{image_directory.joinpath('kitten_park.jpg')}"
|
query = f"file:{image_directory.joinpath('kitten_park.jpg')}"
|
||||||
|
@ -119,11 +135,13 @@ async def test_image_search_by_filepath(content_config: ContentConfig, search_co
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
with caplog.at_level(logging.INFO, logger="khoj.search_type.image_search"):
|
with caplog.at_level(logging.INFO, logger="khoj.search_type.image_search"):
|
||||||
hits = await image_search.query(query, count=1, model=model.image_search)
|
hits = await image_search.query(
|
||||||
|
query, count=1, search_model=search_models.image_search, content=content_index.image
|
||||||
|
)
|
||||||
|
|
||||||
results = image_search.collate_results(
|
results = image_search.collate_results(
|
||||||
hits,
|
hits,
|
||||||
model.image_search.image_names,
|
content_index.image.image_names,
|
||||||
output_directory=output_directory,
|
output_directory=output_directory,
|
||||||
image_files_url="/static/images",
|
image_files_url="/static/images",
|
||||||
count=1,
|
count=1,
|
||||||
|
|
|
@ -5,9 +5,10 @@ import os
|
||||||
|
|
||||||
# External Packages
|
# External Packages
|
||||||
import pytest
|
import pytest
|
||||||
|
from khoj.utils.config import SearchModels
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from khoj.utils.state import model
|
from khoj.utils.state import content_index, search_models
|
||||||
from khoj.search_type import text_search
|
from khoj.search_type import text_search
|
||||||
from khoj.utils.rawconfig import ContentConfig, SearchConfig, TextContentConfig
|
from khoj.utils.rawconfig import ContentConfig, SearchConfig, TextContentConfig
|
||||||
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
|
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
|
||||||
|
@ -41,10 +42,12 @@ def test_asymmetric_setup_with_empty_file_raises_error(
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
def test_asymmetric_setup(content_config: ContentConfig, search_config: SearchConfig):
|
def test_asymmetric_setup(content_config: ContentConfig, search_models: SearchModels):
|
||||||
# Act
|
# Act
|
||||||
# Regenerate notes embeddings during asymmetric setup
|
# Regenerate notes embeddings during asymmetric setup
|
||||||
notes_model = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True)
|
notes_model = text_search.setup(
|
||||||
|
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True
|
||||||
|
)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert len(notes_model.entries) == 10
|
assert len(notes_model.entries) == 10
|
||||||
|
@ -52,18 +55,18 @@ def test_asymmetric_setup(content_config: ContentConfig, search_config: SearchCo
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
def test_text_content_index_only_updates_on_changes(content_config: ContentConfig, search_config: SearchConfig, caplog):
|
def test_text_content_index_only_updates_on_changes(content_config: ContentConfig, search_models: SearchModels, caplog):
|
||||||
# Arrange
|
# Arrange
|
||||||
caplog.set_level(logging.INFO, logger="khoj")
|
caplog.set_level(logging.INFO, logger="khoj")
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
# Generate initial notes embeddings during asymmetric setup
|
# Generate initial notes embeddings during asymmetric setup
|
||||||
text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True)
|
text_search.setup(OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True)
|
||||||
initial_logs = caplog.text
|
initial_logs = caplog.text
|
||||||
caplog.clear() # Clear logs
|
caplog.clear() # Clear logs
|
||||||
|
|
||||||
# Run asymmetric setup again with no changes to data source. Ensure index is not updated
|
# Run asymmetric setup again with no changes to data source. Ensure index is not updated
|
||||||
text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False)
|
text_search.setup(OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False)
|
||||||
final_logs = caplog.text
|
final_logs = caplog.text
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
|
@ -75,11 +78,16 @@ def test_text_content_index_only_updates_on_changes(content_config: ContentConfi
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_asymmetric_search(content_config: ContentConfig, search_config: SearchConfig):
|
async def test_asymmetric_search(content_config: ContentConfig, search_config: SearchConfig):
|
||||||
# Arrange
|
# Arrange
|
||||||
model.notes_search = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True)
|
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
|
||||||
|
content_index.org = text_search.setup(
|
||||||
|
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True
|
||||||
|
)
|
||||||
query = "How to git install application?"
|
query = "How to git install application?"
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
hits, entries = await text_search.query(query, model=model.notes_search, rank_results=True)
|
hits, entries = await text_search.query(
|
||||||
|
query, search_model=search_models.text_search, content=content_index.org, rank_results=True
|
||||||
|
)
|
||||||
|
|
||||||
results = text_search.collate_results(hits, entries, count=1)
|
results = text_search.collate_results(hits, entries, count=1)
|
||||||
|
|
||||||
|
@ -90,7 +98,7 @@ async def test_asymmetric_search(content_config: ContentConfig, search_config: S
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: TextContentConfig, search_config: SearchConfig):
|
def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: TextContentConfig, search_models: SearchModels):
|
||||||
# Arrange
|
# Arrange
|
||||||
# Insert org-mode entry with size exceeding max token limit to new org file
|
# Insert org-mode entry with size exceeding max token limit to new org file
|
||||||
max_tokens = 256
|
max_tokens = 256
|
||||||
|
@ -103,7 +111,7 @@ def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: TextContent
|
||||||
# Act
|
# Act
|
||||||
# reload embeddings, entries, notes model after adding new org-mode file
|
# reload embeddings, entries, notes model after adding new org-mode file
|
||||||
initial_notes_model = text_search.setup(
|
initial_notes_model = text_search.setup(
|
||||||
OrgToJsonl, org_config_with_only_new_file, search_config.asymmetric, regenerate=False
|
OrgToJsonl, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=False
|
||||||
)
|
)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
|
@ -113,9 +121,11 @@ def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: TextContent
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
def test_asymmetric_reload(content_config: ContentConfig, search_config: SearchConfig, new_org_file: Path):
|
def test_asymmetric_reload(content_config: ContentConfig, search_models: SearchModels, new_org_file: Path):
|
||||||
# Arrange
|
# Arrange
|
||||||
initial_notes_model = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True)
|
initial_notes_model = text_search.setup(
|
||||||
|
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True
|
||||||
|
)
|
||||||
|
|
||||||
assert len(initial_notes_model.entries) == 10
|
assert len(initial_notes_model.entries) == 10
|
||||||
assert len(initial_notes_model.corpus_embeddings) == 10
|
assert len(initial_notes_model.corpus_embeddings) == 10
|
||||||
|
@ -127,12 +137,14 @@ def test_asymmetric_reload(content_config: ContentConfig, search_config: SearchC
|
||||||
|
|
||||||
# regenerate notes jsonl, model embeddings and model to include entry from new file
|
# regenerate notes jsonl, model embeddings and model to include entry from new file
|
||||||
regenerated_notes_model = text_search.setup(
|
regenerated_notes_model = text_search.setup(
|
||||||
OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True
|
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True
|
||||||
)
|
)
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
# reload embeddings, entries, notes model from previously generated notes jsonl and model embeddings files
|
# reload embeddings, entries, notes model from previously generated notes jsonl and model embeddings files
|
||||||
initial_notes_model = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False)
|
initial_notes_model = text_search.setup(
|
||||||
|
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False
|
||||||
|
)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert len(regenerated_notes_model.entries) == 11
|
assert len(regenerated_notes_model.entries) == 11
|
||||||
|
@ -149,9 +161,11 @@ def test_asymmetric_reload(content_config: ContentConfig, search_config: SearchC
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
def test_incremental_update(content_config: ContentConfig, search_config: SearchConfig, new_org_file: Path):
|
def test_incremental_update(content_config: ContentConfig, search_models: SearchModels, new_org_file: Path):
|
||||||
# Arrange
|
# Arrange
|
||||||
initial_notes_model = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True)
|
initial_notes_model = text_search.setup(
|
||||||
|
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True
|
||||||
|
)
|
||||||
|
|
||||||
assert len(initial_notes_model.entries) == 10
|
assert len(initial_notes_model.entries) == 10
|
||||||
assert len(initial_notes_model.corpus_embeddings) == 10
|
assert len(initial_notes_model.corpus_embeddings) == 10
|
||||||
|
@ -163,7 +177,9 @@ def test_incremental_update(content_config: ContentConfig, search_config: Search
|
||||||
# Act
|
# Act
|
||||||
# update embeddings, entries with the newly added note
|
# update embeddings, entries with the newly added note
|
||||||
content_config.org.input_files = [f"{new_org_file}"]
|
content_config.org.input_files = [f"{new_org_file}"]
|
||||||
initial_notes_model = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False)
|
initial_notes_model = text_search.setup(
|
||||||
|
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False
|
||||||
|
)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
# verify new entry added in updated embeddings, entries
|
# verify new entry added in updated embeddings, entries
|
||||||
|
@ -177,10 +193,12 @@ def test_incremental_update(content_config: ContentConfig, search_config: Search
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
@pytest.mark.skipif(os.getenv("GITHUB_PAT_TOKEN") is None, reason="GITHUB_PAT_TOKEN not set")
|
@pytest.mark.skipif(os.getenv("GITHUB_PAT_TOKEN") is None, reason="GITHUB_PAT_TOKEN not set")
|
||||||
def test_asymmetric_setup_github(content_config: ContentConfig, search_config: SearchConfig):
|
def test_asymmetric_setup_github(content_config: ContentConfig, search_models: SearchModels):
|
||||||
# Act
|
# Act
|
||||||
# Regenerate github embeddings to test asymmetric setup without caching
|
# Regenerate github embeddings to test asymmetric setup without caching
|
||||||
github_model = text_search.setup(GithubToJsonl, content_config.github, search_config.asymmetric, regenerate=True)
|
github_model = text_search.setup(
|
||||||
|
GithubToJsonl, content_config.github, search_models.text_search.bi_encoder, regenerate=True
|
||||||
|
)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert len(github_model.entries) > 1
|
assert len(github_model.entries) > 1
|
||||||
|
|
Loading…
Add table
Reference in a new issue