mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 23:48:56 +01:00
Merge branch 'master' of github.com:debanjum/khoj
This commit is contained in:
commit
ba47f2ab39
12 changed files with 361 additions and 196 deletions
|
@ -20,9 +20,15 @@ from khoj.processor.github.github_to_jsonl import GithubToJsonl
|
|||
from khoj.processor.notion.notion_to_jsonl import NotionToJsonl
|
||||
from khoj.search_type import image_search, text_search
|
||||
from khoj.utils import constants, state
|
||||
from khoj.utils.config import SearchType, SearchModels, ProcessorConfigModel, ConversationProcessorConfigModel
|
||||
from khoj.utils.config import (
|
||||
ContentIndex,
|
||||
SearchType,
|
||||
SearchModels,
|
||||
ProcessorConfigModel,
|
||||
ConversationProcessorConfigModel,
|
||||
)
|
||||
from khoj.utils.helpers import LRU, resolve_absolute_path, merge_dicts
|
||||
from khoj.utils.rawconfig import FullConfig, ProcessorConfig
|
||||
from khoj.utils.rawconfig import FullConfig, ProcessorConfig, SearchConfig, ContentConfig
|
||||
from khoj.search_filter.date_filter import DateFilter
|
||||
from khoj.search_filter.word_filter import WordFilter
|
||||
from khoj.search_filter.file_filter import FileFilter
|
||||
|
@ -49,11 +55,27 @@ def configure_server(args, required=False):
|
|||
# Initialize Processor from Config
|
||||
state.processor_config = configure_processor(args.config.processor)
|
||||
|
||||
# Initialize the search type and model from Config
|
||||
state.search_index_lock.acquire()
|
||||
state.SearchType = configure_search_types(state.config)
|
||||
state.model = configure_search(state.model, state.config, args.regenerate)
|
||||
state.search_index_lock.release()
|
||||
# Initialize Search Models from Config
|
||||
try:
|
||||
state.search_index_lock.acquire()
|
||||
state.SearchType = configure_search_types(state.config)
|
||||
state.search_models = configure_search(state.search_models, state.config.search_type)
|
||||
except Exception as e:
|
||||
logger.error(f"🚨 Error configuring search models on app load: {e}")
|
||||
finally:
|
||||
state.search_index_lock.release()
|
||||
|
||||
# Initialize Content from Config
|
||||
if state.search_models:
|
||||
try:
|
||||
state.search_index_lock.acquire()
|
||||
state.content_index = configure_content(
|
||||
state.content_index, state.config.content_type, state.search_models, args.regenerate
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"🚨 Error configuring content index on app load: {e}")
|
||||
finally:
|
||||
state.search_index_lock.release()
|
||||
|
||||
|
||||
def configure_routes(app):
|
||||
|
@ -72,10 +94,16 @@ if not state.demo:
|
|||
|
||||
@schedule.repeat(schedule.every(61).minutes)
|
||||
def update_search_index():
|
||||
state.search_index_lock.acquire()
|
||||
state.model = configure_search(state.model, state.config, regenerate=False)
|
||||
state.search_index_lock.release()
|
||||
logger.info("📬 Search index updated via Scheduler")
|
||||
try:
|
||||
state.search_index_lock.acquire()
|
||||
state.content_index = configure_content(
|
||||
state.content_index, state.config.content_type, state.search_models, regenerate=False
|
||||
)
|
||||
logger.info("📬 Content index updated via Scheduler")
|
||||
except Exception as e:
|
||||
logger.error(f"🚨 Error updating content index via Scheduler: {e}")
|
||||
finally:
|
||||
state.search_index_lock.release()
|
||||
|
||||
|
||||
def configure_search_types(config: FullConfig):
|
||||
|
@ -90,94 +118,116 @@ def configure_search_types(config: FullConfig):
|
|||
return Enum("SearchType", merge_dicts(core_search_types, plugin_search_types))
|
||||
|
||||
|
||||
def configure_search(model: SearchModels, config: FullConfig, regenerate: bool, t: Optional[state.SearchType] = None):
|
||||
if config is None or config.content_type is None or config.search_type is None:
|
||||
logger.warning("🚨 No Content or Search type is configured.")
|
||||
return
|
||||
def configure_search(search_models: SearchModels, search_config: SearchConfig) -> Optional[SearchModels]:
|
||||
# Run Validation Checks
|
||||
if search_config is None:
|
||||
logger.warning("🚨 No Search type is configured.")
|
||||
return None
|
||||
if search_models is None:
|
||||
search_models = SearchModels()
|
||||
|
||||
if model is None:
|
||||
model = SearchModels()
|
||||
# Initialize Search Models
|
||||
if search_config.asymmetric:
|
||||
logger.info("🔍 📜 Setting up text search model")
|
||||
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
|
||||
|
||||
if search_config.image:
|
||||
logger.info("🔍 🌄 Setting up image search model")
|
||||
search_models.image_search = image_search.initialize_model(search_config.image)
|
||||
|
||||
return search_models
|
||||
|
||||
|
||||
def configure_content(
|
||||
content_index: Optional[ContentIndex],
|
||||
content_config: Optional[ContentConfig],
|
||||
search_models: SearchModels,
|
||||
regenerate: bool,
|
||||
t: Optional[state.SearchType] = None,
|
||||
) -> Optional[ContentIndex]:
|
||||
# Run Validation Checks
|
||||
if content_config is None:
|
||||
logger.warning("🚨 No Content type is configured.")
|
||||
return None
|
||||
if content_index is None:
|
||||
content_index = ContentIndex()
|
||||
|
||||
try:
|
||||
# Initialize Org Notes Search
|
||||
if (t == state.SearchType.Org or t == None) and config.content_type.org and config.search_type.asymmetric:
|
||||
if (t == state.SearchType.Org or t == None) and content_config.org and search_models.text_search:
|
||||
logger.info("🦄 Setting up search for orgmode notes")
|
||||
# Extract Entries, Generate Notes Embeddings
|
||||
model.org_search = text_search.setup(
|
||||
content_index.org = text_search.setup(
|
||||
OrgToJsonl,
|
||||
config.content_type.org,
|
||||
search_config=config.search_type.asymmetric,
|
||||
content_config.org,
|
||||
search_models.text_search.bi_encoder,
|
||||
regenerate=regenerate,
|
||||
filters=[DateFilter(), WordFilter(), FileFilter()],
|
||||
)
|
||||
|
||||
# Initialize Markdown Search
|
||||
if (
|
||||
(t == state.SearchType.Markdown or t == None)
|
||||
and config.content_type.markdown
|
||||
and config.search_type.asymmetric
|
||||
):
|
||||
if (t == state.SearchType.Markdown or t == None) and content_config.markdown and search_models.text_search:
|
||||
logger.info("💎 Setting up search for markdown notes")
|
||||
# Extract Entries, Generate Markdown Embeddings
|
||||
model.markdown_search = text_search.setup(
|
||||
content_index.markdown = text_search.setup(
|
||||
MarkdownToJsonl,
|
||||
config.content_type.markdown,
|
||||
search_config=config.search_type.asymmetric,
|
||||
content_config.markdown,
|
||||
search_models.text_search.bi_encoder,
|
||||
regenerate=regenerate,
|
||||
filters=[DateFilter(), WordFilter(), FileFilter()],
|
||||
)
|
||||
|
||||
# Initialize PDF Search
|
||||
if (t == state.SearchType.Pdf or t == None) and config.content_type.pdf and config.search_type.asymmetric:
|
||||
if (t == state.SearchType.Pdf or t == None) and content_config.pdf and search_models.text_search:
|
||||
logger.info("🖨️ Setting up search for pdf")
|
||||
# Extract Entries, Generate PDF Embeddings
|
||||
model.pdf_search = text_search.setup(
|
||||
content_index.pdf = text_search.setup(
|
||||
PdfToJsonl,
|
||||
config.content_type.pdf,
|
||||
search_config=config.search_type.asymmetric,
|
||||
content_config.pdf,
|
||||
search_models.text_search.bi_encoder,
|
||||
regenerate=regenerate,
|
||||
filters=[DateFilter(), WordFilter(), FileFilter()],
|
||||
)
|
||||
|
||||
# Initialize Image Search
|
||||
if (t == state.SearchType.Image or t == None) and config.content_type.image and config.search_type.image:
|
||||
if (t == state.SearchType.Image or t == None) and content_config.image and search_models.image_search:
|
||||
logger.info("🌄 Setting up search for images")
|
||||
# Extract Entries, Generate Image Embeddings
|
||||
model.image_search = image_search.setup(
|
||||
config.content_type.image, search_config=config.search_type.image, regenerate=regenerate
|
||||
content_index.image = image_search.setup(
|
||||
content_config.image, search_models.image_search.image_encoder, regenerate=regenerate
|
||||
)
|
||||
|
||||
if (t == state.SearchType.Github or t == None) and config.content_type.github and config.search_type.asymmetric:
|
||||
if (t == state.SearchType.Github or t == None) and content_config.github and search_models.text_search:
|
||||
logger.info("🐙 Setting up search for github")
|
||||
# Extract Entries, Generate Github Embeddings
|
||||
model.github_search = text_search.setup(
|
||||
content_index.github = text_search.setup(
|
||||
GithubToJsonl,
|
||||
config.content_type.github,
|
||||
search_config=config.search_type.asymmetric,
|
||||
content_config.github,
|
||||
search_models.text_search.bi_encoder,
|
||||
regenerate=regenerate,
|
||||
filters=[DateFilter(), WordFilter(), FileFilter()],
|
||||
)
|
||||
|
||||
# Initialize External Plugin Search
|
||||
if (t == None or t in state.SearchType) and config.content_type.plugins:
|
||||
if (t == None or t in state.SearchType) and content_config.plugins and search_models.text_search:
|
||||
logger.info("🔌 Setting up search for plugins")
|
||||
model.plugin_search = {}
|
||||
for plugin_type, plugin_config in config.content_type.plugins.items():
|
||||
model.plugin_search[plugin_type] = text_search.setup(
|
||||
content_index.plugins = {}
|
||||
for plugin_type, plugin_config in content_config.plugins.items():
|
||||
content_index.plugins[plugin_type] = text_search.setup(
|
||||
JsonlToJsonl,
|
||||
plugin_config,
|
||||
search_config=config.search_type.asymmetric,
|
||||
search_models.text_search.bi_encoder,
|
||||
regenerate=regenerate,
|
||||
filters=[DateFilter(), WordFilter(), FileFilter()],
|
||||
)
|
||||
|
||||
# Initialize Notion Search
|
||||
if (t == None or t in state.SearchType) and config.content_type.notion:
|
||||
if (t == None or t in state.SearchType) and content_config.notion and search_models.text_search:
|
||||
logger.info("🔌 Setting up search for notion")
|
||||
model.notion_search = text_search.setup(
|
||||
content_index.notion = text_search.setup(
|
||||
NotionToJsonl,
|
||||
config.content_type.notion,
|
||||
search_config=config.search_type.asymmetric,
|
||||
content_config.notion,
|
||||
search_models.text_search.bi_encoder,
|
||||
regenerate=regenerate,
|
||||
filters=[DateFilter(), WordFilter(), FileFilter()],
|
||||
)
|
||||
|
@ -189,7 +239,7 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool,
|
|||
# Invalidate Query Cache
|
||||
state.query_cache = LRU()
|
||||
|
||||
return model
|
||||
return content_index
|
||||
|
||||
|
||||
def configure_processor(processor_config: ProcessorConfig):
|
||||
|
|
|
@ -12,7 +12,7 @@ from fastapi import APIRouter, HTTPException, Header, Request
|
|||
from sentence_transformers import util
|
||||
|
||||
# Internal Packages
|
||||
from khoj.configure import configure_processor, configure_search
|
||||
from khoj.configure import configure_content, configure_processor, configure_search
|
||||
from khoj.search_type import image_search, text_search
|
||||
from khoj.search_filter.date_filter import DateFilter
|
||||
from khoj.search_filter.file_filter import FileFilter
|
||||
|
@ -163,17 +163,17 @@ if not state.demo:
|
|||
state.config.content_type[content_type] = None
|
||||
|
||||
if content_type == "github":
|
||||
state.model.github_search = None
|
||||
state.content_index.github = None
|
||||
elif content_type == "notion":
|
||||
state.model.notion_search = None
|
||||
state.content_index.notion = None
|
||||
elif content_type == "plugins":
|
||||
state.model.plugin_search = None
|
||||
state.content_index.plugins = None
|
||||
elif content_type == "pdf":
|
||||
state.model.pdf_search = None
|
||||
state.content_index.pdf = None
|
||||
elif content_type == "markdown":
|
||||
state.model.markdown_search = None
|
||||
state.content_index.markdown = None
|
||||
elif content_type == "org":
|
||||
state.model.org_search = None
|
||||
state.content_index.org = None
|
||||
|
||||
try:
|
||||
save_config_to_file_updated_state()
|
||||
|
@ -280,7 +280,7 @@ def get_config_types():
|
|||
for search_type in SearchType
|
||||
if (
|
||||
search_type.value in configured_content_types
|
||||
and getattr(state.model, f"{search_type.value}_search") is not None
|
||||
and getattr(state.content_index, search_type.value) is not None
|
||||
)
|
||||
or ("plugins" in configured_content_types and search_type.name in configured_content_types["plugins"])
|
||||
or search_type == SearchType.All
|
||||
|
@ -308,7 +308,7 @@ async def search(
|
|||
if q is None or q == "":
|
||||
logger.warning(f"No query param (q) passed in API call to initiate search")
|
||||
return results
|
||||
if not state.model or not any(state.model.__dict__.values()):
|
||||
if not state.search_models or not any(state.search_models.__dict__.values()):
|
||||
logger.warning(f"No search models loaded. Configure a search model before initiating search")
|
||||
return results
|
||||
|
||||
|
@ -332,7 +332,7 @@ async def search(
|
|||
encoded_asymmetric_query = None
|
||||
if t == SearchType.All or t != SearchType.Image:
|
||||
text_search_models: List[TextSearchModel] = [
|
||||
model for model in state.model.__dict__.values() if isinstance(model, TextSearchModel)
|
||||
model for model in state.search_models.__dict__.values() if isinstance(model, TextSearchModel)
|
||||
]
|
||||
if text_search_models:
|
||||
with timer("Encoding query took", logger=logger):
|
||||
|
@ -345,13 +345,14 @@ async def search(
|
|||
)
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
if (t == SearchType.Org or t == SearchType.All) and state.model.org_search:
|
||||
if (t == SearchType.Org or t == SearchType.All) and state.content_index.org and state.search_models.text_search:
|
||||
# query org-mode notes
|
||||
search_futures += [
|
||||
executor.submit(
|
||||
text_search.query,
|
||||
user_query,
|
||||
state.model.org_search,
|
||||
state.search_models.text_search,
|
||||
state.content_index.org,
|
||||
question_embedding=encoded_asymmetric_query,
|
||||
rank_results=r or False,
|
||||
score_threshold=score_threshold,
|
||||
|
@ -359,13 +360,18 @@ async def search(
|
|||
)
|
||||
]
|
||||
|
||||
if (t == SearchType.Markdown or t == SearchType.All) and state.model.markdown_search:
|
||||
if (
|
||||
(t == SearchType.Markdown or t == SearchType.All)
|
||||
and state.content_index.markdown
|
||||
and state.search_models.text_search
|
||||
):
|
||||
# query markdown notes
|
||||
search_futures += [
|
||||
executor.submit(
|
||||
text_search.query,
|
||||
user_query,
|
||||
state.model.markdown_search,
|
||||
state.search_models.text_search,
|
||||
state.content_index.markdown,
|
||||
question_embedding=encoded_asymmetric_query,
|
||||
rank_results=r or False,
|
||||
score_threshold=score_threshold,
|
||||
|
@ -373,13 +379,18 @@ async def search(
|
|||
)
|
||||
]
|
||||
|
||||
if (t == SearchType.Github or t == SearchType.All) and state.model.github_search:
|
||||
if (
|
||||
(t == SearchType.Github or t == SearchType.All)
|
||||
and state.content_index.github
|
||||
and state.search_models.text_search
|
||||
):
|
||||
# query github issues
|
||||
search_futures += [
|
||||
executor.submit(
|
||||
text_search.query,
|
||||
user_query,
|
||||
state.model.github_search,
|
||||
state.search_models.text_search,
|
||||
state.content_index.github,
|
||||
question_embedding=encoded_asymmetric_query,
|
||||
rank_results=r or False,
|
||||
score_threshold=score_threshold,
|
||||
|
@ -387,13 +398,14 @@ async def search(
|
|||
)
|
||||
]
|
||||
|
||||
if (t == SearchType.Pdf or t == SearchType.All) and state.model.pdf_search:
|
||||
if (t == SearchType.Pdf or t == SearchType.All) and state.content_index.pdf and state.search_models.text_search:
|
||||
# query pdf files
|
||||
search_futures += [
|
||||
executor.submit(
|
||||
text_search.query,
|
||||
user_query,
|
||||
state.model.pdf_search,
|
||||
state.search_models.text_search,
|
||||
state.content_index.pdf,
|
||||
question_embedding=encoded_asymmetric_query,
|
||||
rank_results=r or False,
|
||||
score_threshold=score_threshold,
|
||||
|
@ -401,26 +413,38 @@ async def search(
|
|||
)
|
||||
]
|
||||
|
||||
if (t == SearchType.Image) and state.model.image_search:
|
||||
if (t == SearchType.Image) and state.content_index.image and state.search_models.image_search:
|
||||
# query images
|
||||
search_futures += [
|
||||
executor.submit(
|
||||
image_search.query,
|
||||
user_query,
|
||||
results_count,
|
||||
state.model.image_search,
|
||||
state.search_models.image_search,
|
||||
state.content_index.image,
|
||||
score_threshold=score_threshold,
|
||||
)
|
||||
]
|
||||
|
||||
if (t == SearchType.All or t in SearchType) and state.model.plugin_search:
|
||||
if (
|
||||
(t == SearchType.All or t in SearchType)
|
||||
and state.content_index.plugins
|
||||
and state.search_models.plugin_search
|
||||
):
|
||||
# query specified plugin type
|
||||
# Get plugin content, search model for specified search type, or the first one if none specified
|
||||
plugin_search = state.search_models.plugin_search.get(t.value) or next(
|
||||
iter(state.search_models.plugin_search.values())
|
||||
)
|
||||
plugin_content = state.content_index.plugins.get(t.value) or next(
|
||||
iter(state.content_index.plugins.values())
|
||||
)
|
||||
search_futures += [
|
||||
executor.submit(
|
||||
text_search.query,
|
||||
user_query,
|
||||
# Get plugin search model for specified search type, or the first one if none specified
|
||||
state.model.plugin_search.get(t.value) or next(iter(state.model.plugin_search.values())),
|
||||
plugin_search,
|
||||
plugin_content,
|
||||
question_embedding=encoded_asymmetric_query,
|
||||
rank_results=r or False,
|
||||
score_threshold=score_threshold,
|
||||
|
@ -428,13 +452,18 @@ async def search(
|
|||
)
|
||||
]
|
||||
|
||||
if (t == SearchType.Notion or t == SearchType.All) and state.model.notion_search:
|
||||
if (
|
||||
(t == SearchType.Notion or t == SearchType.All)
|
||||
and state.content_index.notion
|
||||
and state.search_models.text_search
|
||||
):
|
||||
# query notion pages
|
||||
search_futures += [
|
||||
executor.submit(
|
||||
text_search.query,
|
||||
user_query,
|
||||
state.model.notion_search,
|
||||
state.search_models.text_search,
|
||||
state.content_index.notion,
|
||||
question_embedding=encoded_asymmetric_query,
|
||||
rank_results=r or False,
|
||||
score_threshold=score_threshold,
|
||||
|
@ -445,13 +474,13 @@ async def search(
|
|||
# Query across each requested content types in parallel
|
||||
with timer("Query took", logger):
|
||||
for search_future in concurrent.futures.as_completed(search_futures):
|
||||
if t == SearchType.Image:
|
||||
if t == SearchType.Image and state.content_index.image:
|
||||
hits = await search_future.result()
|
||||
output_directory = constants.web_directory / "images"
|
||||
# Collate results
|
||||
results += image_search.collate_results(
|
||||
hits,
|
||||
image_names=state.model.image_search.image_names,
|
||||
image_names=state.content_index.image.image_names,
|
||||
output_directory=output_directory,
|
||||
image_files_url="/static/images",
|
||||
count=results_count,
|
||||
|
@ -498,7 +527,12 @@ def update(
|
|||
try:
|
||||
state.search_index_lock.acquire()
|
||||
try:
|
||||
state.model = configure_search(state.model, state.config, regenerate=force or False, t=t)
|
||||
if state.config and state.config.search_type:
|
||||
state.search_models = configure_search(state.search_models, state.config.search_type)
|
||||
if state.search_models:
|
||||
state.content_index = configure_content(
|
||||
state.content_index, state.config.content_type, state.search_models, regenerate=force or False, t=t
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
|
|
@ -12,10 +12,12 @@ from sentence_transformers import SentenceTransformer, util
|
|||
from PIL import Image
|
||||
from tqdm import trange
|
||||
import torch
|
||||
from khoj.utils import state
|
||||
|
||||
# Internal Packages
|
||||
from khoj.utils.helpers import get_absolute_path, get_from_dict, resolve_absolute_path, load_model, timer
|
||||
from khoj.utils.config import ImageSearchModel
|
||||
from khoj.utils.config import ImageContent, ImageSearchModel
|
||||
from khoj.utils.models import BaseEncoder
|
||||
from khoj.utils.rawconfig import ImageContentConfig, ImageSearchConfig, SearchResponse
|
||||
|
||||
|
||||
|
@ -40,7 +42,7 @@ def initialize_model(search_config: ImageSearchConfig):
|
|||
model_type=search_config.encoder_type or SentenceTransformer,
|
||||
)
|
||||
|
||||
return encoder
|
||||
return ImageSearchModel(encoder)
|
||||
|
||||
|
||||
def extract_entries(image_directories):
|
||||
|
@ -143,7 +145,9 @@ def extract_metadata(image_name):
|
|||
return image_processed_metadata
|
||||
|
||||
|
||||
async def query(raw_query, count, model: ImageSearchModel, score_threshold: float = -math.inf):
|
||||
async def query(
|
||||
raw_query, count, search_model: ImageSearchModel, content: ImageContent, score_threshold: float = -math.inf
|
||||
):
|
||||
# Set query to image content if query is of form file:/path/to/file.png
|
||||
if raw_query.startswith("file:") and pathlib.Path(raw_query[5:]).is_file():
|
||||
query_imagepath = resolve_absolute_path(pathlib.Path(raw_query[5:]), strict=True)
|
||||
|
@ -158,21 +162,21 @@ async def query(raw_query, count, model: ImageSearchModel, score_threshold: floa
|
|||
|
||||
# Now we encode the query (which can either be an image or a text string)
|
||||
with timer("Query Encode Time", logger):
|
||||
query_embedding = model.image_encoder.encode([query], convert_to_tensor=True, show_progress_bar=False)
|
||||
query_embedding = search_model.image_encoder.encode([query], convert_to_tensor=True, show_progress_bar=False)
|
||||
|
||||
# Compute top_k ranked images based on cosine-similarity b/w query and all image embeddings.
|
||||
with timer("Search Time", logger):
|
||||
image_hits = {
|
||||
result["corpus_id"]: {"image_score": result["score"], "score": result["score"]}
|
||||
for result in util.semantic_search(query_embedding, model.image_embeddings, top_k=count)[0]
|
||||
for result in util.semantic_search(query_embedding, content.image_embeddings, top_k=count)[0]
|
||||
}
|
||||
|
||||
# Compute top_k ranked images based on cosine-similarity b/w query and all image metadata embeddings.
|
||||
if model.image_metadata_embeddings:
|
||||
if content.image_metadata_embeddings:
|
||||
with timer("Metadata Search Time", logger):
|
||||
metadata_hits = {
|
||||
result["corpus_id"]: result["score"]
|
||||
for result in util.semantic_search(query_embedding, model.image_metadata_embeddings, top_k=count)[0]
|
||||
for result in util.semantic_search(query_embedding, content.image_metadata_embeddings, top_k=count)[0]
|
||||
}
|
||||
|
||||
# Sum metadata, image scores of the highest ranked images
|
||||
|
@ -239,10 +243,7 @@ def collate_results(hits, image_names, output_directory, image_files_url, count=
|
|||
return results
|
||||
|
||||
|
||||
def setup(config: ImageContentConfig, search_config: ImageSearchConfig, regenerate: bool) -> ImageSearchModel:
|
||||
# Initialize Model
|
||||
encoder = initialize_model(search_config)
|
||||
|
||||
def setup(config: ImageContentConfig, encoder: BaseEncoder, regenerate: bool) -> ImageContent:
|
||||
# Extract Entries
|
||||
absolute_image_files, filtered_image_files = set(), set()
|
||||
if config.input_directories:
|
||||
|
@ -268,4 +269,4 @@ def setup(config: ImageContentConfig, search_config: ImageSearchConfig, regenera
|
|||
use_xmp_metadata=config.use_xmp_metadata,
|
||||
)
|
||||
|
||||
return ImageSearchModel(all_image_files, image_embeddings, image_metadata_embeddings, encoder)
|
||||
return ImageContent(all_image_files, image_embeddings, image_metadata_embeddings)
|
||||
|
|
|
@ -13,7 +13,7 @@ from khoj.search_filter.base_filter import BaseFilter
|
|||
# Internal Packages
|
||||
from khoj.utils import state
|
||||
from khoj.utils.helpers import get_absolute_path, is_none_or_empty, resolve_absolute_path, load_model, timer
|
||||
from khoj.utils.config import TextSearchModel
|
||||
from khoj.utils.config import TextContent, TextSearchModel
|
||||
from khoj.utils.models import BaseEncoder
|
||||
from khoj.utils.rawconfig import SearchResponse, TextSearchConfig, TextConfigBase, Entry
|
||||
from khoj.utils.jsonl import load_jsonl
|
||||
|
@ -26,9 +26,6 @@ def initialize_model(search_config: TextSearchConfig):
|
|||
"Initialize model for semantic search on text"
|
||||
torch.set_num_threads(4)
|
||||
|
||||
# Number of entries we want to retrieve with the bi-encoder
|
||||
top_k = 15
|
||||
|
||||
# If model directory is configured
|
||||
if search_config.model_directory:
|
||||
# Convert model directory to absolute path
|
||||
|
@ -52,7 +49,7 @@ def initialize_model(search_config: TextSearchConfig):
|
|||
device=f"{state.device}",
|
||||
)
|
||||
|
||||
return bi_encoder, cross_encoder, top_k
|
||||
return TextSearchModel(bi_encoder, cross_encoder)
|
||||
|
||||
|
||||
def extract_entries(jsonl_file) -> List[Entry]:
|
||||
|
@ -67,7 +64,7 @@ def compute_embeddings(
|
|||
new_entries = []
|
||||
# Load pre-computed embeddings from file if exists and update them if required
|
||||
if embeddings_file.exists() and not regenerate:
|
||||
corpus_embeddings = torch.load(get_absolute_path(embeddings_file), map_location=state.device)
|
||||
corpus_embeddings: torch.Tensor = torch.load(get_absolute_path(embeddings_file), map_location=state.device)
|
||||
logger.debug(f"Loaded {len(corpus_embeddings)} text embeddings from {embeddings_file}")
|
||||
|
||||
# Encode any new entries in the corpus and update corpus embeddings
|
||||
|
@ -104,17 +101,18 @@ def compute_embeddings(
|
|||
|
||||
async def query(
|
||||
raw_query: str,
|
||||
model: TextSearchModel,
|
||||
search_model: TextSearchModel,
|
||||
content: TextContent,
|
||||
question_embedding: Union[torch.Tensor, None] = None,
|
||||
rank_results: bool = False,
|
||||
score_threshold: float = -math.inf,
|
||||
dedupe: bool = True,
|
||||
) -> Tuple[List[dict], List[Entry]]:
|
||||
"Search for entries that answer the query"
|
||||
query, entries, corpus_embeddings = raw_query, model.entries, model.corpus_embeddings
|
||||
query, entries, corpus_embeddings = raw_query, content.entries, content.corpus_embeddings
|
||||
|
||||
# Filter query, entries and embeddings before semantic search
|
||||
query, entries, corpus_embeddings = apply_filters(query, entries, corpus_embeddings, model.filters)
|
||||
query, entries, corpus_embeddings = apply_filters(query, entries, corpus_embeddings, content.filters)
|
||||
|
||||
# If no entries left after filtering, return empty results
|
||||
if entries is None or len(entries) == 0:
|
||||
|
@ -127,18 +125,17 @@ async def query(
|
|||
# Encode the query using the bi-encoder
|
||||
if question_embedding is None:
|
||||
with timer("Query Encode Time", logger, state.device):
|
||||
question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True, device=state.device)
|
||||
question_embedding = search_model.bi_encoder.encode([query], convert_to_tensor=True, device=state.device)
|
||||
question_embedding = util.normalize_embeddings(question_embedding)
|
||||
|
||||
# Find relevant entries for the query
|
||||
top_k = min(len(entries), search_model.top_k or 10) # top_k hits can't be more than the total entries in corpus
|
||||
with timer("Search Time", logger, state.device):
|
||||
hits = util.semantic_search(
|
||||
question_embedding, corpus_embeddings, top_k=model.top_k, score_function=util.dot_score
|
||||
)[0]
|
||||
hits = util.semantic_search(question_embedding, corpus_embeddings, top_k, score_function=util.dot_score)[0]
|
||||
|
||||
# Score all retrieved entries using the cross-encoder
|
||||
if rank_results:
|
||||
hits = cross_encoder_score(model.cross_encoder, query, entries, hits)
|
||||
if rank_results and search_model.cross_encoder:
|
||||
hits = cross_encoder_score(search_model.cross_encoder, query, entries, hits)
|
||||
|
||||
# Filter results by score threshold
|
||||
hits = [hit for hit in hits if hit.get("cross-score", hit.get("score")) >= score_threshold]
|
||||
|
@ -173,13 +170,10 @@ def collate_results(hits, entries: List[Entry], count=5) -> List[SearchResponse]
|
|||
def setup(
|
||||
text_to_jsonl: Type[TextToJsonl],
|
||||
config: TextConfigBase,
|
||||
search_config: TextSearchConfig,
|
||||
bi_encoder: BaseEncoder,
|
||||
regenerate: bool,
|
||||
filters: List[BaseFilter] = [],
|
||||
) -> TextSearchModel:
|
||||
# Initialize Model
|
||||
bi_encoder, cross_encoder, top_k = initialize_model(search_config)
|
||||
|
||||
) -> TextContent:
|
||||
# Map notes in text files to (compressed) JSONL formatted file
|
||||
config.compressed_jsonl = resolve_absolute_path(config.compressed_jsonl)
|
||||
previous_entries = (
|
||||
|
@ -192,7 +186,6 @@ def setup(
|
|||
if is_none_or_empty(entries):
|
||||
config_params = ", ".join([f"{key}={value}" for key, value in config.dict().items()])
|
||||
raise ValueError(f"No valid entries found in specified files: {config_params}")
|
||||
top_k = min(len(entries), top_k) # top_k hits can't be more than the total entries in corpus
|
||||
|
||||
# Compute or Load Embeddings
|
||||
config.embeddings_file = resolve_absolute_path(config.embeddings_file)
|
||||
|
@ -203,7 +196,7 @@ def setup(
|
|||
for filter in filters:
|
||||
filter.load(entries, regenerate=regenerate)
|
||||
|
||||
return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, filters, top_k)
|
||||
return TextContent(entries, corpus_embeddings, filters)
|
||||
|
||||
|
||||
def apply_filters(
|
||||
|
|
|
@ -3,7 +3,7 @@ from __future__ import annotations # to avoid quoting type hints
|
|||
from enum import Enum
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Dict, List, Union
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Union
|
||||
|
||||
# External Packages
|
||||
import torch
|
||||
|
@ -30,42 +30,48 @@ class ProcessorType(str, Enum):
|
|||
Conversation = "conversation"
|
||||
|
||||
|
||||
@dataclass
|
||||
class TextContent:
|
||||
entries: List[Entry]
|
||||
corpus_embeddings: torch.Tensor
|
||||
filters: List[BaseFilter]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ImageContent:
|
||||
image_names: List[str]
|
||||
image_embeddings: torch.Tensor
|
||||
image_metadata_embeddings: torch.Tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
class TextSearchModel:
|
||||
def __init__(
|
||||
self,
|
||||
entries: List[Entry],
|
||||
corpus_embeddings: torch.Tensor,
|
||||
bi_encoder: BaseEncoder,
|
||||
cross_encoder: CrossEncoder,
|
||||
filters: List[BaseFilter],
|
||||
top_k,
|
||||
):
|
||||
self.entries = entries
|
||||
self.corpus_embeddings = corpus_embeddings
|
||||
self.bi_encoder = bi_encoder
|
||||
self.cross_encoder = cross_encoder
|
||||
self.filters = filters
|
||||
self.top_k = top_k
|
||||
bi_encoder: BaseEncoder
|
||||
cross_encoder: Optional[CrossEncoder] = None
|
||||
top_k: Optional[int] = 15
|
||||
|
||||
|
||||
@dataclass
|
||||
class ImageSearchModel:
|
||||
def __init__(self, image_names, image_embeddings, image_metadata_embeddings, image_encoder: BaseEncoder):
|
||||
self.image_encoder = image_encoder
|
||||
self.image_names = image_names
|
||||
self.image_embeddings = image_embeddings
|
||||
self.image_metadata_embeddings = image_metadata_embeddings
|
||||
self.image_encoder = image_encoder
|
||||
image_encoder: BaseEncoder
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContentIndex:
|
||||
org: Optional[TextContent] = None
|
||||
markdown: Optional[TextContent] = None
|
||||
pdf: Optional[TextContent] = None
|
||||
github: Optional[TextContent] = None
|
||||
notion: Optional[TextContent] = None
|
||||
image: Optional[ImageContent] = None
|
||||
plugins: Optional[Dict[str, TextContent]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class SearchModels:
|
||||
org_search: Union[TextSearchModel, None] = None
|
||||
markdown_search: Union[TextSearchModel, None] = None
|
||||
pdf_search: Union[TextSearchModel, None] = None
|
||||
image_search: Union[ImageSearchModel, None] = None
|
||||
github_search: Union[TextSearchModel, None] = None
|
||||
notion_search: Union[TextSearchModel, None] = None
|
||||
plugin_search: Union[Dict[str, TextSearchModel], None] = None
|
||||
text_search: Optional[TextSearchModel] = None
|
||||
image_search: Optional[ImageSearchModel] = None
|
||||
plugin_search: Optional[Dict[str, TextSearchModel]] = None
|
||||
|
||||
|
||||
class ConversationProcessorConfigModel:
|
||||
|
|
|
@ -20,7 +20,7 @@ from khoj.utils import constants
|
|||
|
||||
if TYPE_CHECKING:
|
||||
# External Packages
|
||||
from sentence_transformers import CrossEncoder
|
||||
from sentence_transformers import SentenceTransformer, CrossEncoder
|
||||
|
||||
# Internal Packages
|
||||
from khoj.utils.models import BaseEncoder
|
||||
|
@ -64,7 +64,9 @@ def merge_dicts(priority_dict: dict, default_dict: dict):
|
|||
return merged_dict
|
||||
|
||||
|
||||
def load_model(model_name: str, model_type, model_dir=None, device: str = None) -> Union[BaseEncoder, CrossEncoder]:
|
||||
def load_model(
|
||||
model_name: str, model_type, model_dir=None, device: str = None
|
||||
) -> Union[BaseEncoder, SentenceTransformer, CrossEncoder]:
|
||||
"Load model from disk or huggingface"
|
||||
# Construct model path
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
|
@ -119,9 +119,9 @@ class AppConfig(ConfigBase):
|
|||
|
||||
|
||||
class FullConfig(ConfigBase):
|
||||
content_type: Optional[ContentConfig]
|
||||
search_type: Optional[SearchConfig]
|
||||
processor: Optional[ProcessorConfig]
|
||||
content_type: Optional[ContentConfig] = None
|
||||
search_type: Optional[SearchConfig] = None
|
||||
processor: Optional[ProcessorConfig] = None
|
||||
app: Optional[AppConfig] = AppConfig(should_log_telemetry=True)
|
||||
|
||||
|
||||
|
|
|
@ -9,13 +9,14 @@ from pathlib import Path
|
|||
|
||||
# Internal Packages
|
||||
from khoj.utils import config as utils_config
|
||||
from khoj.utils.config import SearchModels, ProcessorConfigModel
|
||||
from khoj.utils.config import ContentIndex, SearchModels, ProcessorConfigModel
|
||||
from khoj.utils.helpers import LRU
|
||||
from khoj.utils.rawconfig import FullConfig
|
||||
|
||||
# Application Global State
|
||||
config = FullConfig()
|
||||
model = SearchModels()
|
||||
search_models = SearchModels()
|
||||
content_index = ContentIndex()
|
||||
processor_config = ProcessorConfigModel()
|
||||
config_file: Path = None
|
||||
verbose: int = 0
|
||||
|
|
|
@ -10,6 +10,7 @@ from khoj.main import app
|
|||
from khoj.configure import configure_processor, configure_routes, configure_search_types
|
||||
from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl
|
||||
from khoj.search_type import image_search, text_search
|
||||
from khoj.utils.config import ImageContent, SearchModels, TextContent
|
||||
from khoj.utils.helpers import resolve_absolute_path
|
||||
from khoj.utils.rawconfig import (
|
||||
ContentConfig,
|
||||
|
@ -41,35 +42,49 @@ def search_config() -> SearchConfig:
|
|||
encoder="sentence-transformers/all-MiniLM-L6-v2",
|
||||
cross_encoder="cross-encoder/ms-marco-MiniLM-L-6-v2",
|
||||
model_directory=model_dir / "symmetric/",
|
||||
encoder_type=None,
|
||||
)
|
||||
|
||||
search_config.asymmetric = TextSearchConfig(
|
||||
encoder="sentence-transformers/multi-qa-MiniLM-L6-cos-v1",
|
||||
cross_encoder="cross-encoder/ms-marco-MiniLM-L-6-v2",
|
||||
model_directory=model_dir / "asymmetric/",
|
||||
encoder_type=None,
|
||||
)
|
||||
|
||||
search_config.image = ImageSearchConfig(
|
||||
encoder="sentence-transformers/clip-ViT-B-32", model_directory=model_dir / "image/"
|
||||
encoder="sentence-transformers/clip-ViT-B-32",
|
||||
model_directory=model_dir / "image/",
|
||||
encoder_type=None,
|
||||
)
|
||||
|
||||
return search_config
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def content_config(tmp_path_factory, search_config: SearchConfig):
|
||||
def search_models(search_config: SearchConfig):
|
||||
search_models = SearchModels()
|
||||
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
|
||||
search_models.image_search = image_search.initialize_model(search_config.image)
|
||||
|
||||
return search_models
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def content_config(tmp_path_factory, search_models: SearchModels, search_config: SearchConfig):
|
||||
content_dir = tmp_path_factory.mktemp("content")
|
||||
|
||||
# Generate Image Embeddings from Test Images
|
||||
content_config = ContentConfig()
|
||||
content_config.image = ImageContentConfig(
|
||||
input_filter=None,
|
||||
input_directories=["tests/data/images"],
|
||||
embeddings_file=content_dir.joinpath("image_embeddings.pt"),
|
||||
batch_size=1,
|
||||
use_xmp_metadata=False,
|
||||
)
|
||||
|
||||
image_search.setup(content_config.image, search_config.image, regenerate=False)
|
||||
image_search.setup(content_config.image, search_models.image_search.image_encoder, regenerate=False)
|
||||
|
||||
# Generate Notes Embeddings from Test Notes
|
||||
content_config.org = TextContentConfig(
|
||||
|
@ -80,7 +95,9 @@ def content_config(tmp_path_factory, search_config: SearchConfig):
|
|||
)
|
||||
|
||||
filters = [DateFilter(), WordFilter(), FileFilter()]
|
||||
text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters)
|
||||
text_search.setup(
|
||||
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False, filters=filters
|
||||
)
|
||||
|
||||
content_config.plugins = {
|
||||
"plugin1": TextContentConfig(
|
||||
|
@ -106,7 +123,11 @@ def content_config(tmp_path_factory, search_config: SearchConfig):
|
|||
|
||||
filters = [DateFilter(), WordFilter(), FileFilter()]
|
||||
text_search.setup(
|
||||
JsonlToJsonl, content_config.plugins["plugin1"], search_config.asymmetric, regenerate=False, filters=filters
|
||||
JsonlToJsonl,
|
||||
content_config.plugins["plugin1"],
|
||||
search_models.text_search.bi_encoder,
|
||||
regenerate=False,
|
||||
filters=filters,
|
||||
)
|
||||
|
||||
return content_config
|
||||
|
@ -157,8 +178,13 @@ def chat_client(md_content_config: ContentConfig, search_config: SearchConfig, p
|
|||
|
||||
# Index Markdown Content for Search
|
||||
filters = [DateFilter(), WordFilter(), FileFilter()]
|
||||
state.model.markdown_search = text_search.setup(
|
||||
MarkdownToJsonl, md_content_config.markdown, search_config.asymmetric, regenerate=False, filters=filters
|
||||
state.search_models.text_search = text_search.initialize_model(search_config.asymmetric)
|
||||
state.content_index.markdown = text_search.setup(
|
||||
MarkdownToJsonl,
|
||||
md_content_config.markdown,
|
||||
state.search_models.text_search.bi_encoder,
|
||||
regenerate=False,
|
||||
filters=filters,
|
||||
)
|
||||
|
||||
# Initialize Processor from Config
|
||||
|
@ -175,8 +201,14 @@ def client(content_config: ContentConfig, search_config: SearchConfig, processor
|
|||
state.SearchType = configure_search_types(state.config)
|
||||
|
||||
# These lines help us Mock the Search models for these search types
|
||||
state.model.org_search = {}
|
||||
state.model.image_search = {}
|
||||
state.search_models.text_search = text_search.initialize_model(search_config.asymmetric)
|
||||
state.search_models.image_search = image_search.initialize_model(search_config.image)
|
||||
state.content_index.org = text_search.setup(
|
||||
OrgToJsonl, content_config.org, state.search_models.text_search.bi_encoder, regenerate=False
|
||||
)
|
||||
state.content_index.image = image_search.setup(
|
||||
content_config.image, state.search_models.image_search, regenerate=False
|
||||
)
|
||||
|
||||
configure_routes(app)
|
||||
return TestClient(app)
|
||||
|
|
|
@ -11,7 +11,8 @@ from fastapi.testclient import TestClient
|
|||
from khoj.main import app
|
||||
from khoj.configure import configure_routes, configure_search_types
|
||||
from khoj.utils import state
|
||||
from khoj.utils.state import model, config
|
||||
from khoj.utils.config import SearchModels
|
||||
from khoj.utils.state import search_models, content_index, config
|
||||
from khoj.search_type import text_search, image_search
|
||||
from khoj.utils.rawconfig import ContentConfig, SearchConfig
|
||||
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
|
||||
|
@ -143,7 +144,10 @@ def test_get_configured_types_with_no_content_config():
|
|||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_image_search(client, content_config: ContentConfig, search_config: SearchConfig):
|
||||
# Arrange
|
||||
model.image_search = image_search.setup(content_config.image, search_config.image, regenerate=False)
|
||||
search_models.image_search = image_search.initialize_model(search_config.image)
|
||||
content_index.image = image_search.setup(
|
||||
content_config.image, search_models.image_search.image_encoder, regenerate=False
|
||||
)
|
||||
query_expected_image_pairs = [
|
||||
("kitten", "kitten_park.jpg"),
|
||||
("a horse and dog on a leash", "horse_dog.jpg"),
|
||||
|
@ -166,7 +170,10 @@ def test_image_search(client, content_config: ContentConfig, search_config: Sear
|
|||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_notes_search(client, content_config: ContentConfig, search_config: SearchConfig):
|
||||
# Arrange
|
||||
model.org_search = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False)
|
||||
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
|
||||
content_index.org = text_search.setup(
|
||||
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False
|
||||
)
|
||||
user_query = quote("How to git install application?")
|
||||
|
||||
# Act
|
||||
|
@ -183,8 +190,9 @@ def test_notes_search(client, content_config: ContentConfig, search_config: Sear
|
|||
def test_notes_search_with_only_filters(client, content_config: ContentConfig, search_config: SearchConfig):
|
||||
# Arrange
|
||||
filters = [WordFilter(), FileFilter()]
|
||||
model.org_search = text_search.setup(
|
||||
OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters
|
||||
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
|
||||
content_index.org = text_search.setup(
|
||||
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False, filters=filters
|
||||
)
|
||||
user_query = quote('+"Emacs" file:"*.org"')
|
||||
|
||||
|
@ -202,8 +210,9 @@ def test_notes_search_with_only_filters(client, content_config: ContentConfig, s
|
|||
def test_notes_search_with_include_filter(client, content_config: ContentConfig, search_config: SearchConfig):
|
||||
# Arrange
|
||||
filters = [WordFilter()]
|
||||
model.org_search = text_search.setup(
|
||||
OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters
|
||||
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
|
||||
content_index.org = text_search.setup(
|
||||
OrgToJsonl, content_config.org, search_models.text_search, regenerate=False, filters=filters
|
||||
)
|
||||
user_query = quote('How to git install application? +"Emacs"')
|
||||
|
||||
|
@ -221,8 +230,9 @@ def test_notes_search_with_include_filter(client, content_config: ContentConfig,
|
|||
def test_notes_search_with_exclude_filter(client, content_config: ContentConfig, search_config: SearchConfig):
|
||||
# Arrange
|
||||
filters = [WordFilter()]
|
||||
model.org_search = text_search.setup(
|
||||
OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters
|
||||
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
|
||||
content_index.org = text_search.setup(
|
||||
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False, filters=filters
|
||||
)
|
||||
user_query = quote('How to git install application? -"clone"')
|
||||
|
||||
|
|
|
@ -5,9 +5,10 @@ from PIL import Image
|
|||
|
||||
# External Packages
|
||||
import pytest
|
||||
from khoj.utils.config import SearchModels
|
||||
|
||||
# Internal Packages
|
||||
from khoj.utils.state import model
|
||||
from khoj.utils.state import content_index, search_models
|
||||
from khoj.utils.constants import web_directory
|
||||
from khoj.search_type import image_search
|
||||
from khoj.utils.helpers import resolve_absolute_path
|
||||
|
@ -16,10 +17,12 @@ from khoj.utils.rawconfig import ContentConfig, SearchConfig
|
|||
|
||||
# Test
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_image_search_setup(content_config: ContentConfig, search_config: SearchConfig):
|
||||
def test_image_search_setup(content_config: ContentConfig, search_models: SearchModels):
|
||||
# Act
|
||||
# Regenerate image search embeddings during image setup
|
||||
image_search_model = image_search.setup(content_config.image, search_config.image, regenerate=True)
|
||||
image_search_model = image_search.setup(
|
||||
content_config.image, search_models.image_search.image_encoder, regenerate=True
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(image_search_model.image_names) == 3
|
||||
|
@ -54,8 +57,11 @@ def test_image_metadata(content_config: ContentConfig):
|
|||
@pytest.mark.anyio
|
||||
async def test_image_search(content_config: ContentConfig, search_config: SearchConfig):
|
||||
# Arrange
|
||||
search_models.image_search = image_search.initialize_model(search_config.image)
|
||||
content_index.image = image_search.setup(
|
||||
content_config.image, search_models.image_search.image_encoder, regenerate=False
|
||||
)
|
||||
output_directory = resolve_absolute_path(web_directory)
|
||||
model.image_search = image_search.setup(content_config.image, search_config.image, regenerate=False)
|
||||
query_expected_image_pairs = [
|
||||
("kitten", "kitten_park.jpg"),
|
||||
("horse and dog in a farm", "horse_dog.jpg"),
|
||||
|
@ -64,11 +70,13 @@ async def test_image_search(content_config: ContentConfig, search_config: Search
|
|||
|
||||
# Act
|
||||
for query, expected_image_name in query_expected_image_pairs:
|
||||
hits = await image_search.query(query, count=1, model=model.image_search)
|
||||
hits = await image_search.query(
|
||||
query, count=1, search_model=search_models.image_search, content=content_index.image
|
||||
)
|
||||
|
||||
results = image_search.collate_results(
|
||||
hits,
|
||||
model.image_search.image_names,
|
||||
content_index.image.image_names,
|
||||
output_directory=output_directory,
|
||||
image_files_url="/static/images",
|
||||
count=1,
|
||||
|
@ -90,7 +98,10 @@ async def test_image_search(content_config: ContentConfig, search_config: Search
|
|||
@pytest.mark.anyio
|
||||
async def test_image_search_query_truncated(content_config: ContentConfig, search_config: SearchConfig, caplog):
|
||||
# Arrange
|
||||
model.image_search = image_search.setup(content_config.image, search_config.image, regenerate=False)
|
||||
search_models.image_search = image_search.initialize_model(search_config.image)
|
||||
content_index.image = image_search.setup(
|
||||
content_config.image, search_models.image_search.image_encoder, regenerate=False
|
||||
)
|
||||
max_words_supported = 10
|
||||
query = " ".join(["hello"] * 100)
|
||||
truncated_query = " ".join(["hello"] * max_words_supported)
|
||||
|
@ -98,7 +109,9 @@ async def test_image_search_query_truncated(content_config: ContentConfig, searc
|
|||
# Act
|
||||
try:
|
||||
with caplog.at_level(logging.INFO, logger="khoj.search_type.image_search"):
|
||||
await image_search.query(query, count=1, model=model.image_search)
|
||||
await image_search.query(
|
||||
query, count=1, search_model=search_models.image_search, content=content_index.image
|
||||
)
|
||||
# Assert
|
||||
except RuntimeError as e:
|
||||
if "The size of tensor a (102) must match the size of tensor b (77)" in str(e):
|
||||
|
@ -110,8 +123,11 @@ async def test_image_search_query_truncated(content_config: ContentConfig, searc
|
|||
@pytest.mark.anyio
|
||||
async def test_image_search_by_filepath(content_config: ContentConfig, search_config: SearchConfig, caplog):
|
||||
# Arrange
|
||||
search_models.image_search = image_search.initialize_model(search_config.image)
|
||||
content_index.image = image_search.setup(
|
||||
content_config.image, search_models.image_search.image_encoder, regenerate=False
|
||||
)
|
||||
output_directory = resolve_absolute_path(web_directory)
|
||||
model.image_search = image_search.setup(content_config.image, search_config.image, regenerate=False)
|
||||
image_directory = content_config.image.input_directories[0]
|
||||
|
||||
query = f"file:{image_directory.joinpath('kitten_park.jpg')}"
|
||||
|
@ -119,11 +135,13 @@ async def test_image_search_by_filepath(content_config: ContentConfig, search_co
|
|||
|
||||
# Act
|
||||
with caplog.at_level(logging.INFO, logger="khoj.search_type.image_search"):
|
||||
hits = await image_search.query(query, count=1, model=model.image_search)
|
||||
hits = await image_search.query(
|
||||
query, count=1, search_model=search_models.image_search, content=content_index.image
|
||||
)
|
||||
|
||||
results = image_search.collate_results(
|
||||
hits,
|
||||
model.image_search.image_names,
|
||||
content_index.image.image_names,
|
||||
output_directory=output_directory,
|
||||
image_files_url="/static/images",
|
||||
count=1,
|
||||
|
|
|
@ -5,9 +5,10 @@ import os
|
|||
|
||||
# External Packages
|
||||
import pytest
|
||||
from khoj.utils.config import SearchModels
|
||||
|
||||
# Internal Packages
|
||||
from khoj.utils.state import model
|
||||
from khoj.utils.state import content_index, search_models
|
||||
from khoj.search_type import text_search
|
||||
from khoj.utils.rawconfig import ContentConfig, SearchConfig, TextContentConfig
|
||||
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
|
||||
|
@ -41,10 +42,12 @@ def test_asymmetric_setup_with_empty_file_raises_error(
|
|||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_asymmetric_setup(content_config: ContentConfig, search_config: SearchConfig):
|
||||
def test_asymmetric_setup(content_config: ContentConfig, search_models: SearchModels):
|
||||
# Act
|
||||
# Regenerate notes embeddings during asymmetric setup
|
||||
notes_model = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True)
|
||||
notes_model = text_search.setup(
|
||||
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(notes_model.entries) == 10
|
||||
|
@ -52,18 +55,18 @@ def test_asymmetric_setup(content_config: ContentConfig, search_config: SearchCo
|
|||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_text_content_index_only_updates_on_changes(content_config: ContentConfig, search_config: SearchConfig, caplog):
|
||||
def test_text_content_index_only_updates_on_changes(content_config: ContentConfig, search_models: SearchModels, caplog):
|
||||
# Arrange
|
||||
caplog.set_level(logging.INFO, logger="khoj")
|
||||
|
||||
# Act
|
||||
# Generate initial notes embeddings during asymmetric setup
|
||||
text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True)
|
||||
text_search.setup(OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True)
|
||||
initial_logs = caplog.text
|
||||
caplog.clear() # Clear logs
|
||||
|
||||
# Run asymmetric setup again with no changes to data source. Ensure index is not updated
|
||||
text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False)
|
||||
text_search.setup(OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False)
|
||||
final_logs = caplog.text
|
||||
|
||||
# Assert
|
||||
|
@ -75,11 +78,16 @@ def test_text_content_index_only_updates_on_changes(content_config: ContentConfi
|
|||
@pytest.mark.anyio
|
||||
async def test_asymmetric_search(content_config: ContentConfig, search_config: SearchConfig):
|
||||
# Arrange
|
||||
model.notes_search = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True)
|
||||
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
|
||||
content_index.org = text_search.setup(
|
||||
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True
|
||||
)
|
||||
query = "How to git install application?"
|
||||
|
||||
# Act
|
||||
hits, entries = await text_search.query(query, model=model.notes_search, rank_results=True)
|
||||
hits, entries = await text_search.query(
|
||||
query, search_model=search_models.text_search, content=content_index.org, rank_results=True
|
||||
)
|
||||
|
||||
results = text_search.collate_results(hits, entries, count=1)
|
||||
|
||||
|
@ -90,7 +98,7 @@ async def test_asymmetric_search(content_config: ContentConfig, search_config: S
|
|||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: TextContentConfig, search_config: SearchConfig):
|
||||
def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: TextContentConfig, search_models: SearchModels):
|
||||
# Arrange
|
||||
# Insert org-mode entry with size exceeding max token limit to new org file
|
||||
max_tokens = 256
|
||||
|
@ -103,7 +111,7 @@ def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: TextContent
|
|||
# Act
|
||||
# reload embeddings, entries, notes model after adding new org-mode file
|
||||
initial_notes_model = text_search.setup(
|
||||
OrgToJsonl, org_config_with_only_new_file, search_config.asymmetric, regenerate=False
|
||||
OrgToJsonl, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=False
|
||||
)
|
||||
|
||||
# Assert
|
||||
|
@ -113,9 +121,11 @@ def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: TextContent
|
|||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_asymmetric_reload(content_config: ContentConfig, search_config: SearchConfig, new_org_file: Path):
|
||||
def test_asymmetric_reload(content_config: ContentConfig, search_models: SearchModels, new_org_file: Path):
|
||||
# Arrange
|
||||
initial_notes_model = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True)
|
||||
initial_notes_model = text_search.setup(
|
||||
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True
|
||||
)
|
||||
|
||||
assert len(initial_notes_model.entries) == 10
|
||||
assert len(initial_notes_model.corpus_embeddings) == 10
|
||||
|
@ -127,12 +137,14 @@ def test_asymmetric_reload(content_config: ContentConfig, search_config: SearchC
|
|||
|
||||
# regenerate notes jsonl, model embeddings and model to include entry from new file
|
||||
regenerated_notes_model = text_search.setup(
|
||||
OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True
|
||||
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True
|
||||
)
|
||||
|
||||
# Act
|
||||
# reload embeddings, entries, notes model from previously generated notes jsonl and model embeddings files
|
||||
initial_notes_model = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False)
|
||||
initial_notes_model = text_search.setup(
|
||||
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(regenerated_notes_model.entries) == 11
|
||||
|
@ -149,9 +161,11 @@ def test_asymmetric_reload(content_config: ContentConfig, search_config: SearchC
|
|||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_incremental_update(content_config: ContentConfig, search_config: SearchConfig, new_org_file: Path):
|
||||
def test_incremental_update(content_config: ContentConfig, search_models: SearchModels, new_org_file: Path):
|
||||
# Arrange
|
||||
initial_notes_model = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True)
|
||||
initial_notes_model = text_search.setup(
|
||||
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True
|
||||
)
|
||||
|
||||
assert len(initial_notes_model.entries) == 10
|
||||
assert len(initial_notes_model.corpus_embeddings) == 10
|
||||
|
@ -163,7 +177,9 @@ def test_incremental_update(content_config: ContentConfig, search_config: Search
|
|||
# Act
|
||||
# update embeddings, entries with the newly added note
|
||||
content_config.org.input_files = [f"{new_org_file}"]
|
||||
initial_notes_model = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False)
|
||||
initial_notes_model = text_search.setup(
|
||||
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False
|
||||
)
|
||||
|
||||
# Assert
|
||||
# verify new entry added in updated embeddings, entries
|
||||
|
@ -177,10 +193,12 @@ def test_incremental_update(content_config: ContentConfig, search_config: Search
|
|||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.skipif(os.getenv("GITHUB_PAT_TOKEN") is None, reason="GITHUB_PAT_TOKEN not set")
|
||||
def test_asymmetric_setup_github(content_config: ContentConfig, search_config: SearchConfig):
|
||||
def test_asymmetric_setup_github(content_config: ContentConfig, search_models: SearchModels):
|
||||
# Act
|
||||
# Regenerate github embeddings to test asymmetric setup without caching
|
||||
github_model = text_search.setup(GithubToJsonl, content_config.github, search_config.asymmetric, regenerate=True)
|
||||
github_model = text_search.setup(
|
||||
GithubToJsonl, content_config.github, search_models.text_search.bi_encoder, regenerate=True
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(github_model.entries) > 1
|
||||
|
|
Loading…
Reference in a new issue