mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-30 10:53:02 +01:00
Accept file deletion requests by clients during sync
- Remove unused full_corpus boolean. The full_corpus=False code path wasn't being used (except for in a test) - The full_corpus=True code path used was ignoring file deletion requests sent by clients during sync. Unclear why this was done - Added unit test to prevent regression and show file deletion by clients during sync not ignored now
This commit is contained in:
parent
5923b6d89e
commit
bba4e0b529
14 changed files with 84 additions and 80 deletions
|
@ -19,16 +19,11 @@ class DocxToEntries(TextToEntries):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
# Define Functions
|
# Define Functions
|
||||||
def process(
|
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
|
||||||
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
|
|
||||||
) -> Tuple[int, int]:
|
|
||||||
# Extract required fields from config
|
# Extract required fields from config
|
||||||
if not full_corpus:
|
deletion_file_names = set([file for file in files if files[file] == b""])
|
||||||
deletion_file_names = set([file for file in files if files[file] == b""])
|
files_to_process = set(files) - deletion_file_names
|
||||||
files_to_process = set(files) - deletion_file_names
|
files = {file: files[file] for file in files_to_process}
|
||||||
files = {file: files[file] for file in files_to_process}
|
|
||||||
else:
|
|
||||||
deletion_file_names = None
|
|
||||||
|
|
||||||
# Extract Entries from specified Docx files
|
# Extract Entries from specified Docx files
|
||||||
with timer("Extract entries from specified DOCX files", logger):
|
with timer("Extract entries from specified DOCX files", logger):
|
||||||
|
|
|
@ -48,9 +48,7 @@ class GithubToEntries(TextToEntries):
|
||||||
else:
|
else:
|
||||||
return
|
return
|
||||||
|
|
||||||
def process(
|
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
|
||||||
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
|
|
||||||
) -> Tuple[int, int]:
|
|
||||||
if self.config.pat_token is None or self.config.pat_token == "":
|
if self.config.pat_token is None or self.config.pat_token == "":
|
||||||
logger.error(f"Github PAT token is not set. Skipping github content")
|
logger.error(f"Github PAT token is not set. Skipping github content")
|
||||||
raise ValueError("Github PAT token is not set. Skipping github content")
|
raise ValueError("Github PAT token is not set. Skipping github content")
|
||||||
|
|
|
@ -20,16 +20,11 @@ class ImageToEntries(TextToEntries):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
# Define Functions
|
# Define Functions
|
||||||
def process(
|
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
|
||||||
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
|
|
||||||
) -> Tuple[int, int]:
|
|
||||||
# Extract required fields from config
|
# Extract required fields from config
|
||||||
if not full_corpus:
|
deletion_file_names = set([file for file in files if files[file] == b""])
|
||||||
deletion_file_names = set([file for file in files if files[file] == b""])
|
files_to_process = set(files) - deletion_file_names
|
||||||
files_to_process = set(files) - deletion_file_names
|
files = {file: files[file] for file in files_to_process}
|
||||||
files = {file: files[file] for file in files_to_process}
|
|
||||||
else:
|
|
||||||
deletion_file_names = None
|
|
||||||
|
|
||||||
# Extract Entries from specified image files
|
# Extract Entries from specified image files
|
||||||
with timer("Extract entries from specified Image files", logger):
|
with timer("Extract entries from specified Image files", logger):
|
||||||
|
|
|
@ -19,16 +19,11 @@ class MarkdownToEntries(TextToEntries):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
# Define Functions
|
# Define Functions
|
||||||
def process(
|
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
|
||||||
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
|
|
||||||
) -> Tuple[int, int]:
|
|
||||||
# Extract required fields from config
|
# Extract required fields from config
|
||||||
if not full_corpus:
|
deletion_file_names = set([file for file in files if files[file] == ""])
|
||||||
deletion_file_names = set([file for file in files if files[file] == ""])
|
files_to_process = set(files) - deletion_file_names
|
||||||
files_to_process = set(files) - deletion_file_names
|
files = {file: files[file] for file in files_to_process}
|
||||||
files = {file: files[file] for file in files_to_process}
|
|
||||||
else:
|
|
||||||
deletion_file_names = None
|
|
||||||
|
|
||||||
max_tokens = 256
|
max_tokens = 256
|
||||||
# Extract Entries from specified Markdown files
|
# Extract Entries from specified Markdown files
|
||||||
|
|
|
@ -78,9 +78,7 @@ class NotionToEntries(TextToEntries):
|
||||||
|
|
||||||
self.body_params = {"page_size": 100}
|
self.body_params = {"page_size": 100}
|
||||||
|
|
||||||
def process(
|
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
|
||||||
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
|
|
||||||
) -> Tuple[int, int]:
|
|
||||||
current_entries = []
|
current_entries = []
|
||||||
|
|
||||||
# Get all pages
|
# Get all pages
|
||||||
|
|
|
@ -20,15 +20,10 @@ class OrgToEntries(TextToEntries):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
# Define Functions
|
# Define Functions
|
||||||
def process(
|
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
|
||||||
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
|
deletion_file_names = set([file for file in files if files[file] == ""])
|
||||||
) -> Tuple[int, int]:
|
files_to_process = set(files) - deletion_file_names
|
||||||
if not full_corpus:
|
files = {file: files[file] for file in files_to_process}
|
||||||
deletion_file_names = set([file for file in files if files[file] == ""])
|
|
||||||
files_to_process = set(files) - deletion_file_names
|
|
||||||
files = {file: files[file] for file in files_to_process}
|
|
||||||
else:
|
|
||||||
deletion_file_names = None
|
|
||||||
|
|
||||||
# Extract Entries from specified Org files
|
# Extract Entries from specified Org files
|
||||||
max_tokens = 256
|
max_tokens = 256
|
||||||
|
|
|
@ -22,16 +22,11 @@ class PdfToEntries(TextToEntries):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
# Define Functions
|
# Define Functions
|
||||||
def process(
|
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
|
||||||
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
|
|
||||||
) -> Tuple[int, int]:
|
|
||||||
# Extract required fields from config
|
# Extract required fields from config
|
||||||
if not full_corpus:
|
deletion_file_names = set([file for file in files if files[file] == b""])
|
||||||
deletion_file_names = set([file for file in files if files[file] == b""])
|
files_to_process = set(files) - deletion_file_names
|
||||||
files_to_process = set(files) - deletion_file_names
|
files = {file: files[file] for file in files_to_process}
|
||||||
files = {file: files[file] for file in files_to_process}
|
|
||||||
else:
|
|
||||||
deletion_file_names = None
|
|
||||||
|
|
||||||
# Extract Entries from specified Pdf files
|
# Extract Entries from specified Pdf files
|
||||||
with timer("Extract entries from specified PDF files", logger):
|
with timer("Extract entries from specified PDF files", logger):
|
||||||
|
|
|
@ -20,15 +20,10 @@ class PlaintextToEntries(TextToEntries):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
# Define Functions
|
# Define Functions
|
||||||
def process(
|
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
|
||||||
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
|
deletion_file_names = set([file for file in files if files[file] == ""])
|
||||||
) -> Tuple[int, int]:
|
files_to_process = set(files) - deletion_file_names
|
||||||
if not full_corpus:
|
files = {file: files[file] for file in files_to_process}
|
||||||
deletion_file_names = set([file for file in files if files[file] == ""])
|
|
||||||
files_to_process = set(files) - deletion_file_names
|
|
||||||
files = {file: files[file] for file in files_to_process}
|
|
||||||
else:
|
|
||||||
deletion_file_names = None
|
|
||||||
|
|
||||||
# Extract Entries from specified plaintext files
|
# Extract Entries from specified plaintext files
|
||||||
with timer("Extract entries from specified Plaintext files", logger):
|
with timer("Extract entries from specified Plaintext files", logger):
|
||||||
|
|
|
@ -31,9 +31,7 @@ class TextToEntries(ABC):
|
||||||
self.date_filter = DateFilter()
|
self.date_filter = DateFilter()
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def process(
|
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
|
||||||
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
|
|
||||||
) -> Tuple[int, int]:
|
|
||||||
...
|
...
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
|
@ -1313,7 +1313,6 @@ def configure_content(
|
||||||
files: Optional[dict[str, dict[str, str]]],
|
files: Optional[dict[str, dict[str, str]]],
|
||||||
regenerate: bool = False,
|
regenerate: bool = False,
|
||||||
t: Optional[state.SearchType] = state.SearchType.All,
|
t: Optional[state.SearchType] = state.SearchType.All,
|
||||||
full_corpus: bool = True,
|
|
||||||
user: KhojUser = None,
|
user: KhojUser = None,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
success = True
|
success = True
|
||||||
|
@ -1344,7 +1343,6 @@ def configure_content(
|
||||||
OrgToEntries,
|
OrgToEntries,
|
||||||
files.get("org"),
|
files.get("org"),
|
||||||
regenerate=regenerate,
|
regenerate=regenerate,
|
||||||
full_corpus=full_corpus,
|
|
||||||
user=user,
|
user=user,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -1362,7 +1360,6 @@ def configure_content(
|
||||||
MarkdownToEntries,
|
MarkdownToEntries,
|
||||||
files.get("markdown"),
|
files.get("markdown"),
|
||||||
regenerate=regenerate,
|
regenerate=regenerate,
|
||||||
full_corpus=full_corpus,
|
|
||||||
user=user,
|
user=user,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1379,7 +1376,6 @@ def configure_content(
|
||||||
PdfToEntries,
|
PdfToEntries,
|
||||||
files.get("pdf"),
|
files.get("pdf"),
|
||||||
regenerate=regenerate,
|
regenerate=regenerate,
|
||||||
full_corpus=full_corpus,
|
|
||||||
user=user,
|
user=user,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1398,7 +1394,6 @@ def configure_content(
|
||||||
PlaintextToEntries,
|
PlaintextToEntries,
|
||||||
files.get("plaintext"),
|
files.get("plaintext"),
|
||||||
regenerate=regenerate,
|
regenerate=regenerate,
|
||||||
full_corpus=full_corpus,
|
|
||||||
user=user,
|
user=user,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1418,7 +1413,6 @@ def configure_content(
|
||||||
GithubToEntries,
|
GithubToEntries,
|
||||||
None,
|
None,
|
||||||
regenerate=regenerate,
|
regenerate=regenerate,
|
||||||
full_corpus=full_corpus,
|
|
||||||
user=user,
|
user=user,
|
||||||
config=github_config,
|
config=github_config,
|
||||||
)
|
)
|
||||||
|
@ -1439,7 +1433,6 @@ def configure_content(
|
||||||
NotionToEntries,
|
NotionToEntries,
|
||||||
None,
|
None,
|
||||||
regenerate=regenerate,
|
regenerate=regenerate,
|
||||||
full_corpus=full_corpus,
|
|
||||||
user=user,
|
user=user,
|
||||||
config=notion_config,
|
config=notion_config,
|
||||||
)
|
)
|
||||||
|
@ -1459,7 +1452,6 @@ def configure_content(
|
||||||
ImageToEntries,
|
ImageToEntries,
|
||||||
files.get("image"),
|
files.get("image"),
|
||||||
regenerate=regenerate,
|
regenerate=regenerate,
|
||||||
full_corpus=full_corpus,
|
|
||||||
user=user,
|
user=user,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -1472,7 +1464,6 @@ def configure_content(
|
||||||
DocxToEntries,
|
DocxToEntries,
|
||||||
files.get("docx"),
|
files.get("docx"),
|
||||||
regenerate=regenerate,
|
regenerate=regenerate,
|
||||||
full_corpus=full_corpus,
|
|
||||||
user=user,
|
user=user,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
@ -80,6 +80,6 @@ async def notion_auth_callback(request: Request, background_tasks: BackgroundTas
|
||||||
notion_redirect = str(request.app.url_path_for("notion_config_page"))
|
notion_redirect = str(request.app.url_path_for("notion_config_page"))
|
||||||
|
|
||||||
# Trigger an async job to configure_content. Let it run without blocking the response.
|
# Trigger an async job to configure_content. Let it run without blocking the response.
|
||||||
background_tasks.add_task(run_in_executor, configure_content, {}, False, SearchType.Notion, True, user)
|
background_tasks.add_task(run_in_executor, configure_content, {}, False, SearchType.Notion, user)
|
||||||
|
|
||||||
return RedirectResponse(notion_redirect)
|
return RedirectResponse(notion_redirect)
|
||||||
|
|
|
@ -199,17 +199,16 @@ def setup(
|
||||||
text_to_entries: Type[TextToEntries],
|
text_to_entries: Type[TextToEntries],
|
||||||
files: dict[str, str],
|
files: dict[str, str],
|
||||||
regenerate: bool,
|
regenerate: bool,
|
||||||
full_corpus: bool = True,
|
|
||||||
user: KhojUser = None,
|
user: KhojUser = None,
|
||||||
config=None,
|
config=None,
|
||||||
) -> None:
|
) -> Tuple[int, int]:
|
||||||
if config:
|
if config:
|
||||||
num_new_embeddings, num_deleted_embeddings = text_to_entries(config).process(
|
num_new_embeddings, num_deleted_embeddings = text_to_entries(config).process(
|
||||||
files=files, full_corpus=full_corpus, user=user, regenerate=regenerate
|
files=files, user=user, regenerate=regenerate
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
num_new_embeddings, num_deleted_embeddings = text_to_entries().process(
|
num_new_embeddings, num_deleted_embeddings = text_to_entries().process(
|
||||||
files=files, full_corpus=full_corpus, user=user, regenerate=regenerate
|
files=files, user=user, regenerate=regenerate
|
||||||
)
|
)
|
||||||
|
|
||||||
if files:
|
if files:
|
||||||
|
@ -219,6 +218,8 @@ def setup(
|
||||||
f"Deleted {num_deleted_embeddings} entries. Created {num_new_embeddings} new entries for user {user} from files {file_names[:10]} ..."
|
f"Deleted {num_deleted_embeddings} entries. Created {num_new_embeddings} new entries for user {user} from files {file_names[:10]} ..."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return num_new_embeddings, num_deleted_embeddings
|
||||||
|
|
||||||
|
|
||||||
def cross_encoder_score(query: str, hits: List[SearchResponse], search_model_name: str) -> List[SearchResponse]:
|
def cross_encoder_score(query: str, hits: List[SearchResponse], search_model_name: str) -> List[SearchResponse]:
|
||||||
"""Score all retrieved entries using the cross-encoder"""
|
"""Score all retrieved entries using the cross-encoder"""
|
||||||
|
|
|
@ -122,7 +122,7 @@ def get_org_files(config: TextContentConfig):
|
||||||
logger.debug("At least one of org-files or org-file-filter is required to be specified")
|
logger.debug("At least one of org-files or org-file-filter is required to be specified")
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
"Get Org files to process"
|
# Get Org files to process
|
||||||
absolute_org_files, filtered_org_files = set(), set()
|
absolute_org_files, filtered_org_files = set(), set()
|
||||||
if org_files:
|
if org_files:
|
||||||
absolute_org_files = {get_absolute_path(org_file) for org_file in org_files}
|
absolute_org_files = {get_absolute_path(org_file) for org_file in org_files}
|
||||||
|
|
|
@ -6,9 +6,16 @@ from pathlib import Path
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from khoj.database.adapters import EntryAdapters
|
||||||
from khoj.database.models import Entry, GithubConfig, KhojUser, LocalOrgConfig
|
from khoj.database.models import Entry, GithubConfig, KhojUser, LocalOrgConfig
|
||||||
|
from khoj.processor.content.docx.docx_to_entries import DocxToEntries
|
||||||
from khoj.processor.content.github.github_to_entries import GithubToEntries
|
from khoj.processor.content.github.github_to_entries import GithubToEntries
|
||||||
|
from khoj.processor.content.images.image_to_entries import ImageToEntries
|
||||||
|
from khoj.processor.content.markdown.markdown_to_entries import MarkdownToEntries
|
||||||
from khoj.processor.content.org_mode.org_to_entries import OrgToEntries
|
from khoj.processor.content.org_mode.org_to_entries import OrgToEntries
|
||||||
|
from khoj.processor.content.pdf.pdf_to_entries import PdfToEntries
|
||||||
|
from khoj.processor.content.plaintext.plaintext_to_entries import PlaintextToEntries
|
||||||
|
from khoj.processor.content.text_to_entries import TextToEntries
|
||||||
from khoj.search_type import text_search
|
from khoj.search_type import text_search
|
||||||
from khoj.utils.fs_syncer import collect_files, get_org_files
|
from khoj.utils.fs_syncer import collect_files, get_org_files
|
||||||
from khoj.utils.rawconfig import ContentConfig, SearchConfig
|
from khoj.utils.rawconfig import ContentConfig, SearchConfig
|
||||||
|
@ -151,7 +158,6 @@ async def test_text_search(search_config: SearchConfig):
|
||||||
OrgToEntries,
|
OrgToEntries,
|
||||||
data,
|
data,
|
||||||
True,
|
True,
|
||||||
True,
|
|
||||||
default_user,
|
default_user,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -240,7 +246,6 @@ conda activate khoj
|
||||||
OrgToEntries,
|
OrgToEntries,
|
||||||
data,
|
data,
|
||||||
regenerate=False,
|
regenerate=False,
|
||||||
full_corpus=False,
|
|
||||||
user=default_user,
|
user=default_user,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -396,6 +401,49 @@ def test_update_index_with_new_entry(content_config: ContentConfig, new_org_file
|
||||||
verify_embeddings(3, default_user)
|
verify_embeddings(3, default_user)
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------------------------------------------------------------------------------------------
|
||||||
|
@pytest.mark.django_db
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"text_to_entries",
|
||||||
|
[
|
||||||
|
(OrgToEntries),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_update_index_with_deleted_file(
|
||||||
|
org_config_with_only_new_file: LocalOrgConfig, text_to_entries: TextToEntries, default_user: KhojUser
|
||||||
|
):
|
||||||
|
"Delete entries associated with new file when file path with empty content passed."
|
||||||
|
# Arrange
|
||||||
|
file_to_index = "test"
|
||||||
|
new_entry = "* TODO A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n"
|
||||||
|
initial_data = {file_to_index: new_entry}
|
||||||
|
final_data = {file_to_index: ""}
|
||||||
|
|
||||||
|
# Act
|
||||||
|
# load entries after adding file
|
||||||
|
initial_added_entries, _ = text_search.setup(text_to_entries, initial_data, regenerate=True, user=default_user)
|
||||||
|
initial_total_entries = EntryAdapters.get_existing_entry_hashes_by_file(default_user, file_to_index).count()
|
||||||
|
|
||||||
|
# load entries after deleting file
|
||||||
|
final_added_entries, final_deleted_entries = text_search.setup(
|
||||||
|
text_to_entries, final_data, regenerate=False, user=default_user
|
||||||
|
)
|
||||||
|
final_total_entries = EntryAdapters.get_existing_entry_hashes_by_file(default_user, file_to_index).count()
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert initial_total_entries > 0, "File entries not indexed"
|
||||||
|
assert initial_added_entries > 0, "No entries got added"
|
||||||
|
|
||||||
|
assert final_total_entries == 0, "File did not get deleted"
|
||||||
|
assert final_added_entries == 0, "Entries were unexpectedly added in delete entries pass"
|
||||||
|
assert final_deleted_entries == initial_added_entries, "All added entries were not deleted"
|
||||||
|
|
||||||
|
verify_embeddings(0, default_user), "Embeddings still exist for user"
|
||||||
|
|
||||||
|
# Clean up
|
||||||
|
EntryAdapters.delete_all_entries(default_user)
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
@pytest.mark.skipif(os.getenv("GITHUB_PAT_TOKEN") is None, reason="GITHUB_PAT_TOKEN not set")
|
@pytest.mark.skipif(os.getenv("GITHUB_PAT_TOKEN") is None, reason="GITHUB_PAT_TOKEN not set")
|
||||||
def test_text_search_setup_github(content_config: ContentConfig, default_user: KhojUser):
|
def test_text_search_setup_github(content_config: ContentConfig, default_user: KhojUser):
|
||||||
|
|
Loading…
Reference in a new issue