From 5e83baab2148e2f4d2f744726b86f8b67f96e7d2 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Fri, 17 Feb 2023 10:04:26 -0600 Subject: [PATCH] Use Black to format Khoj server code and tests --- pyproject.toml | 6 +- src/khoj/configure.py | 27 +- src/khoj/interface/desktop/file_browser.py | 41 +- .../interface/desktop/labelled_text_field.py | 2 +- src/khoj/interface/desktop/main_window.py | 79 ++-- src/khoj/interface/desktop/system_tray.py | 10 +- src/khoj/main.py | 20 +- src/khoj/processor/conversation/gpt.py | 72 ++-- .../processor/ledger/beancount_to_jsonl.py | 44 +- .../processor/markdown/markdown_to_jsonl.py | 41 +- src/khoj/processor/org_mode/org_to_jsonl.py | 39 +- src/khoj/processor/org_mode/orgnode.py | 376 +++++++++--------- src/khoj/processor/text_to_jsonl.py | 24 +- src/khoj/routers/api.py | 32 +- src/khoj/routers/api_beta.py | 80 ++-- src/khoj/routers/web_client.py | 8 +- src/khoj/search_filter/base_filter.py | 9 +- src/khoj/search_filter/date_filter.py | 56 ++- src/khoj/search_filter/file_filter.py | 17 +- src/khoj/search_filter/word_filter.py | 21 +- src/khoj/search_type/image_search.py | 120 +++--- src/khoj/search_type/text_search.py | 108 +++-- src/khoj/utils/cli.py | 35 +- src/khoj/utils/config.py | 22 +- src/khoj/utils/constants.py | 103 +++-- src/khoj/utils/helpers.py | 22 +- src/khoj/utils/jsonl.py | 14 +- src/khoj/utils/models.py | 24 +- src/khoj/utils/rawconfig.py | 38 +- src/khoj/utils/yaml.py | 4 +- tests/conftest.py | 58 +-- tests/test_beancount_to_jsonl.py | 25 +- tests/test_chatbot.py | 44 +- tests/test_cli.py | 18 +- tests/test_client.py | 21 +- tests/test_date_filter.py | 106 +++-- tests/test_file_filter.py | 22 +- tests/test_helpers.py | 31 +- tests/test_image_search.py | 44 +- tests/test_markdown_to_jsonl.py | 29 +- tests/test_org_to_jsonl.py | 56 +-- tests/test_orgnode.py | 80 ++-- tests/test_text_search.py | 36 +- tests/test_word_filter.py | 18 +- 44 files changed, 1167 insertions(+), 915 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 775b5b21..f55ff8b3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,7 +64,8 @@ khoj = "khoj.main:run" [project.optional-dependencies] test = [ - "pytest == 7.1.2", + "pytest >= 7.1.2", + "black >= 23.1.0", ] dev = ["khoj-assistant[test]"] @@ -88,3 +89,6 @@ exclude = [ "src/khoj/interface/desktop/file_browser.py", "src/khoj/interface/desktop/system_tray.py", ] + +[tool.black] +line-length = 120 \ No newline at end of file diff --git a/src/khoj/configure.py b/src/khoj/configure.py index c776a4a0..5755b9c2 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -26,10 +26,12 @@ logger = logging.getLogger(__name__) def configure_server(args, required=False): if args.config is None: if required: - logger.error(f'Exiting as Khoj is not configured.\nConfigure it via GUI or by editing {state.config_file}.') + logger.error(f"Exiting as Khoj is not configured.\nConfigure it via GUI or by editing {state.config_file}.") sys.exit(1) else: - logger.warn(f'Khoj is not configured.\nConfigure it via khoj GUI, plugins or by editing {state.config_file}.') + logger.warn( + f"Khoj is not configured.\nConfigure it via khoj GUI, plugins or by editing {state.config_file}." 
+ ) return else: state.config = args.config @@ -60,7 +62,8 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool, config.content_type.org, search_config=config.search_type.asymmetric, regenerate=regenerate, - filters=[DateFilter(), WordFilter(), FileFilter()]) + filters=[DateFilter(), WordFilter(), FileFilter()], + ) # Initialize Org Music Search if (t == SearchType.Music or t == None) and config.content_type.music: @@ -70,7 +73,8 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool, config.content_type.music, search_config=config.search_type.asymmetric, regenerate=regenerate, - filters=[DateFilter(), WordFilter()]) + filters=[DateFilter(), WordFilter()], + ) # Initialize Markdown Search if (t == SearchType.Markdown or t == None) and config.content_type.markdown: @@ -80,7 +84,8 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool, config.content_type.markdown, search_config=config.search_type.asymmetric, regenerate=regenerate, - filters=[DateFilter(), WordFilter(), FileFilter()]) + filters=[DateFilter(), WordFilter(), FileFilter()], + ) # Initialize Ledger Search if (t == SearchType.Ledger or t == None) and config.content_type.ledger: @@ -90,15 +95,15 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool, config.content_type.ledger, search_config=config.search_type.symmetric, regenerate=regenerate, - filters=[DateFilter(), WordFilter(), FileFilter()]) + filters=[DateFilter(), WordFilter(), FileFilter()], + ) # Initialize Image Search if (t == SearchType.Image or t == None) and config.content_type.image: # Extract Entries, Generate Image Embeddings model.image_search = image_search.setup( - config.content_type.image, - search_config=config.search_type.image, - regenerate=regenerate) + config.content_type.image, search_config=config.search_type.image, regenerate=regenerate + ) # Invalidate Query Cache state.query_cache = LRU() @@ -125,9 +130,9 @@ def configure_conversation_processor(conversation_processor_config): if conversation_logfile.is_file(): # Load Metadata Logs from Conversation Logfile - with conversation_logfile.open('r') as f: + with conversation_logfile.open("r") as f: conversation_processor.meta_log = json.load(f) - logger.info('Conversation logs loaded from disk.') + logger.info("Conversation logs loaded from disk.") else: # Initialize Conversation Logs conversation_processor.meta_log = {} diff --git a/src/khoj/interface/desktop/file_browser.py b/src/khoj/interface/desktop/file_browser.py index 70e6b8d7..97c2ddda 100644 --- a/src/khoj/interface/desktop/file_browser.py +++ b/src/khoj/interface/desktop/file_browser.py @@ -8,7 +8,7 @@ from khoj.utils.helpers import is_none_or_empty class FileBrowser(QtWidgets.QWidget): - def __init__(self, title, search_type: SearchType=None, default_files:list=[]): + def __init__(self, title, search_type: SearchType = None, default_files: list = []): QtWidgets.QWidget.__init__(self) layout = QtWidgets.QHBoxLayout() self.setLayout(layout) @@ -22,51 +22,54 @@ class FileBrowser(QtWidgets.QWidget): self.label.setFixedWidth(95) self.label.setWordWrap(True) layout.addWidget(self.label) - + self.lineEdit = QtWidgets.QPlainTextEdit(self) self.lineEdit.setFixedWidth(330) self.setFiles(default_files) - self.lineEdit.setFixedHeight(min(7+20*len(self.lineEdit.toPlainText().split('\n')),90)) + self.lineEdit.setFixedHeight(min(7 + 20 * len(self.lineEdit.toPlainText().split("\n")), 90)) self.lineEdit.textChanged.connect(self.updateFieldHeight) 
layout.addWidget(self.lineEdit) - - self.button = QtWidgets.QPushButton('Add') + + self.button = QtWidgets.QPushButton("Add") self.button.clicked.connect(self.storeFilesSelectedInFileDialog) layout.addWidget(self.button) layout.addStretch() def getFileFilter(self, search_type): if search_type == SearchType.Org: - return 'Org-Mode Files (*.org)' + return "Org-Mode Files (*.org)" elif search_type == SearchType.Ledger: - return 'Beancount Files (*.bean *.beancount)' + return "Beancount Files (*.bean *.beancount)" elif search_type == SearchType.Markdown: - return 'Markdown Files (*.md *.markdown)' + return "Markdown Files (*.md *.markdown)" elif search_type == SearchType.Music: - return 'Org-Music Files (*.org)' + return "Org-Music Files (*.org)" elif search_type == SearchType.Image: - return 'Images (*.jp[e]g)' + return "Images (*.jp[e]g)" def storeFilesSelectedInFileDialog(self): filepaths = self.getPaths() if self.search_type == SearchType.Image: - filepaths.append(QtWidgets.QFileDialog.getExistingDirectory(self, caption='Choose Folder', - directory=self.dirpath)) + filepaths.append( + QtWidgets.QFileDialog.getExistingDirectory(self, caption="Choose Folder", directory=self.dirpath) + ) else: - filepaths.extend(QtWidgets.QFileDialog.getOpenFileNames(self, caption='Choose Files', - directory=self.dirpath, - filter=self.filter_name)[0]) + filepaths.extend( + QtWidgets.QFileDialog.getOpenFileNames( + self, caption="Choose Files", directory=self.dirpath, filter=self.filter_name + )[0] + ) self.setFiles(filepaths) - def setFiles(self, paths:list): + def setFiles(self, paths: list): self.filepaths = [path for path in paths if not is_none_or_empty(path)] self.lineEdit.setPlainText("\n".join(self.filepaths)) def getPaths(self) -> list: - if self.lineEdit.toPlainText() == '': + if self.lineEdit.toPlainText() == "": return [] else: - return self.lineEdit.toPlainText().split('\n') + return self.lineEdit.toPlainText().split("\n") def updateFieldHeight(self): - self.lineEdit.setFixedHeight(min(7+20*len(self.lineEdit.toPlainText().split('\n')),90)) + self.lineEdit.setFixedHeight(min(7 + 20 * len(self.lineEdit.toPlainText().split("\n")), 90)) diff --git a/src/khoj/interface/desktop/labelled_text_field.py b/src/khoj/interface/desktop/labelled_text_field.py index 3248c21c..4032c2c0 100644 --- a/src/khoj/interface/desktop/labelled_text_field.py +++ b/src/khoj/interface/desktop/labelled_text_field.py @@ -6,7 +6,7 @@ from khoj.utils.config import ProcessorType class LabelledTextField(QtWidgets.QWidget): - def __init__(self, title, processor_type: ProcessorType=None, default_value: str=None): + def __init__(self, title, processor_type: ProcessorType = None, default_value: str = None): QtWidgets.QWidget.__init__(self) layout = QtWidgets.QHBoxLayout() self.setLayout(layout) diff --git a/src/khoj/interface/desktop/main_window.py b/src/khoj/interface/desktop/main_window.py index 6ee45fb6..d2150ed1 100644 --- a/src/khoj/interface/desktop/main_window.py +++ b/src/khoj/interface/desktop/main_window.py @@ -31,9 +31,9 @@ class MainWindow(QtWidgets.QMainWindow): self.config_file = config_file # Set regenerate flag to regenerate embeddings everytime user clicks configure if state.cli_args: - state.cli_args += ['--regenerate'] + state.cli_args += ["--regenerate"] else: - state.cli_args = ['--regenerate'] + state.cli_args = ["--regenerate"] # Load config from existing config, if exists, else load from default config if resolve_absolute_path(self.config_file).exists(): @@ -49,8 +49,8 @@ class 
MainWindow(QtWidgets.QMainWindow): self.setFixedWidth(600) # Set Window Icon - icon_path = constants.web_directory / 'assets/icons/favicon-144x144.png' - self.setWindowIcon(QtGui.QIcon(f'{icon_path.absolute()}')) + icon_path = constants.web_directory / "assets/icons/favicon-144x144.png" + self.setWindowIcon(QtGui.QIcon(f"{icon_path.absolute()}")) # Initialize Configure Window Layout self.layout = QtWidgets.QVBoxLayout() @@ -58,13 +58,13 @@ class MainWindow(QtWidgets.QMainWindow): # Add Settings Panels for each Search Type to Configure Window Layout self.search_settings_panels = [] for search_type in SearchType: - current_content_config = self.current_config['content-type'].get(search_type, {}) + current_content_config = self.current_config["content-type"].get(search_type, {}) self.search_settings_panels += [self.add_settings_panel(current_content_config, search_type)] # Add Conversation Processor Panel to Configure Screen self.processor_settings_panels = [] conversation_type = ProcessorType.Conversation - current_conversation_config = self.current_config['processor'].get(conversation_type, {}) + current_conversation_config = self.current_config["processor"].get(conversation_type, {}) self.processor_settings_panels += [self.add_processor_panel(current_conversation_config, conversation_type)] # Add Action Buttons Panel @@ -81,11 +81,11 @@ class MainWindow(QtWidgets.QMainWindow): "Add Settings Panel for specified Search Type. Toggle Editable Search Types" # Get current files from config for given search type if search_type == SearchType.Image: - current_content_files = current_content_config.get('input-directories', []) - file_input_text = f'{search_type.name} Folders' + current_content_files = current_content_config.get("input-directories", []) + file_input_text = f"{search_type.name} Folders" else: - current_content_files = current_content_config.get('input-files', []) - file_input_text = f'{search_type.name} Files' + current_content_files = current_content_config.get("input-files", []) + file_input_text = f"{search_type.name} Files" # Create widgets to display settings for given search type search_type_settings = QtWidgets.QWidget() @@ -109,7 +109,7 @@ class MainWindow(QtWidgets.QMainWindow): def add_processor_panel(self, current_conversation_config: dict, processor_type: ProcessorType): "Add Conversation Processor Panel" # Get current settings from config for given processor type - current_openai_api_key = current_conversation_config.get('openai-api-key', None) + current_openai_api_key = current_conversation_config.get("openai-api-key", None) # Create widgets to display settings for given processor type processor_type_settings = QtWidgets.QWidget() @@ -137,20 +137,22 @@ class MainWindow(QtWidgets.QMainWindow): action_bar_layout = QtWidgets.QHBoxLayout(action_bar) self.configure_button = QtWidgets.QPushButton("Configure", clicked=self.configure_app) - self.search_button = QtWidgets.QPushButton("Search", clicked=lambda: webbrowser.open(f'http://{state.host}:{state.port}/')) + self.search_button = QtWidgets.QPushButton( + "Search", clicked=lambda: webbrowser.open(f"http://{state.host}:{state.port}/") + ) self.search_button.setEnabled(not self.first_run) action_bar_layout.addWidget(self.configure_button) action_bar_layout.addWidget(self.search_button) self.layout.addWidget(action_bar) - def get_default_config(self, search_type:SearchType=None, processor_type:ProcessorType=None): + def get_default_config(self, search_type: SearchType = None, processor_type: ProcessorType = None): "Get default 
config" config = constants.default_config if search_type: - return config['content-type'][search_type] + return config["content-type"][search_type] elif processor_type: - return config['processor'][processor_type] + return config["processor"][processor_type] else: return config @@ -160,7 +162,9 @@ class MainWindow(QtWidgets.QMainWindow): for message_prefix in ErrorType: for i in reversed(range(self.layout.count())): current_widget = self.layout.itemAt(i).widget() - if isinstance(current_widget, QtWidgets.QLabel) and current_widget.text().startswith(message_prefix.value): + if isinstance(current_widget, QtWidgets.QLabel) and current_widget.text().startswith( + message_prefix.value + ): self.layout.removeWidget(current_widget) current_widget.deleteLater() @@ -180,18 +184,24 @@ class MainWindow(QtWidgets.QMainWindow): continue if isinstance(child, SearchCheckBox): # Search Type Disabled - if not child.isChecked() and child.search_type in self.new_config['content-type']: - del self.new_config['content-type'][child.search_type] + if not child.isChecked() and child.search_type in self.new_config["content-type"]: + del self.new_config["content-type"][child.search_type] # Search Type (re)-Enabled if child.isChecked(): - current_search_config = self.current_config['content-type'].get(child.search_type, {}) - default_search_config = self.get_default_config(search_type = child.search_type) - self.new_config['content-type'][child.search_type.value] = merge_dicts(current_search_config, default_search_config) - elif isinstance(child, FileBrowser) and child.search_type in self.new_config['content-type']: + current_search_config = self.current_config["content-type"].get(child.search_type, {}) + default_search_config = self.get_default_config(search_type=child.search_type) + self.new_config["content-type"][child.search_type.value] = merge_dicts( + current_search_config, default_search_config + ) + elif isinstance(child, FileBrowser) and child.search_type in self.new_config["content-type"]: if child.search_type.value == SearchType.Image: - self.new_config['content-type'][child.search_type.value]['input-directories'] = child.getPaths() if child.getPaths() != [] else None + self.new_config["content-type"][child.search_type.value]["input-directories"] = ( + child.getPaths() if child.getPaths() != [] else None + ) else: - self.new_config['content-type'][child.search_type.value]['input-files'] = child.getPaths() if child.getPaths() != [] else None + self.new_config["content-type"][child.search_type.value]["input-files"] = ( + child.getPaths() if child.getPaths() != [] else None + ) def update_processor_settings(self): "Update config with conversation settings from UI" @@ -201,16 +211,20 @@ class MainWindow(QtWidgets.QMainWindow): continue if isinstance(child, ProcessorCheckBox): # Processor Type Disabled - if not child.isChecked() and child.processor_type in self.new_config['processor']: - del self.new_config['processor'][child.processor_type] + if not child.isChecked() and child.processor_type in self.new_config["processor"]: + del self.new_config["processor"][child.processor_type] # Processor Type (re)-Enabled if child.isChecked(): - current_processor_config = self.current_config['processor'].get(child.processor_type, {}) - default_processor_config = self.get_default_config(processor_type = child.processor_type) - self.new_config['processor'][child.processor_type.value] = merge_dicts(current_processor_config, default_processor_config) - elif isinstance(child, LabelledTextField) and child.processor_type in 
self.new_config['processor']: + current_processor_config = self.current_config["processor"].get(child.processor_type, {}) + default_processor_config = self.get_default_config(processor_type=child.processor_type) + self.new_config["processor"][child.processor_type.value] = merge_dicts( + current_processor_config, default_processor_config + ) + elif isinstance(child, LabelledTextField) and child.processor_type in self.new_config["processor"]: if child.processor_type == ProcessorType.Conversation: - self.new_config['processor'][child.processor_type.value]['openai-api-key'] = child.input_field.toPlainText() if child.input_field.toPlainText() != '' else None + self.new_config["processor"][child.processor_type.value]["openai-api-key"] = ( + child.input_field.toPlainText() if child.input_field.toPlainText() != "" else None + ) def save_settings_to_file(self) -> bool: "Save validated settings to file" @@ -278,7 +292,7 @@ class MainWindow(QtWidgets.QMainWindow): self.show() self.setWindowState(Qt.WindowState.WindowActive) self.activateWindow() # For Bringing to Top on Windows - self.raise_() # For Bringing to Top from Minimized State on OSX + self.raise_() # For Bringing to Top from Minimized State on OSX class SettingsLoader(QObject): @@ -312,6 +326,7 @@ class ProcessorCheckBox(QtWidgets.QCheckBox): self.processor_type = processor_type super(ProcessorCheckBox, self).__init__(text, parent=parent) + class ErrorType(Enum): "Error Types" ConfigLoadingError = "Config Loading Error" diff --git a/src/khoj/interface/desktop/system_tray.py b/src/khoj/interface/desktop/system_tray.py index c8559527..65f4f5c4 100644 --- a/src/khoj/interface/desktop/system_tray.py +++ b/src/khoj/interface/desktop/system_tray.py @@ -17,17 +17,17 @@ def create_system_tray(gui: QtWidgets.QApplication, main_window: MainWindow): """ # Create the system tray with icon - icon_path = constants.web_directory / 'assets/icons/favicon-144x144.png' - icon = QtGui.QIcon(f'{icon_path.absolute()}') + icon_path = constants.web_directory / "assets/icons/favicon-144x144.png" + icon = QtGui.QIcon(f"{icon_path.absolute()}") tray = QtWidgets.QSystemTrayIcon(icon) tray.setVisible(True) # Create the menu and menu actions menu = QtWidgets.QMenu() menu_actions = [ - ('Search', lambda: webbrowser.open(f'http://{state.host}:{state.port}/')), - ('Configure', main_window.show_on_top), - ('Quit', gui.quit), + ("Search", lambda: webbrowser.open(f"http://{state.host}:{state.port}/")), + ("Configure", main_window.show_on_top), + ("Quit", gui.quit), ] # Add the menu actions to the menu diff --git a/src/khoj/main.py b/src/khoj/main.py index 8c21aee3..b61363aa 100644 --- a/src/khoj/main.py +++ b/src/khoj/main.py @@ -8,8 +8,8 @@ import warnings from platform import system # Ignore non-actionable warnings -warnings.filterwarnings("ignore", message=r'snapshot_download.py has been made private', category=FutureWarning) -warnings.filterwarnings("ignore", message=r'legacy way to download files from the HF hub,', category=FutureWarning) +warnings.filterwarnings("ignore", message=r"snapshot_download.py has been made private", category=FutureWarning) +warnings.filterwarnings("ignore", message=r"legacy way to download files from the HF hub,", category=FutureWarning) # External Packages import uvicorn @@ -43,11 +43,12 @@ rich_handler = RichHandler(rich_tracebacks=True) rich_handler.setFormatter(fmt=logging.Formatter(fmt="%(message)s", datefmt="[%X]")) logging.basicConfig(handlers=[rich_handler]) -logger = logging.getLogger('khoj') +logger = logging.getLogger("khoj") + def 
run(): # Turn Tokenizers Parallelism Off. App does not support it. - os.environ["TOKENIZERS_PARALLELISM"] = 'false' + os.environ["TOKENIZERS_PARALLELISM"] = "false" # Load config from CLI state.cli_args = sys.argv[1:] @@ -66,7 +67,7 @@ def run(): logger.setLevel(logging.DEBUG) # Set Log File - fh = logging.FileHandler(state.config_file.parent / 'khoj.log') + fh = logging.FileHandler(state.config_file.parent / "khoj.log") fh.setLevel(logging.DEBUG) logger.addHandler(fh) @@ -87,7 +88,7 @@ def run(): # On Linux (Gnome) the System tray is not supported. # Since only the Main Window is available # Quitting it should quit the application - if system() in ['Windows', 'Darwin']: + if system() in ["Windows", "Darwin"]: gui.setQuitOnLastWindowClosed(False) tray = create_system_tray(gui, main_window) tray.show() @@ -97,7 +98,7 @@ def run(): server = ServerThread(app, args.host, args.port, args.socket) # Show Main Window on First Run Experience or if on Linux - if args.config is None or system() not in ['Windows', 'Darwin']: + if args.config is None or system() not in ["Windows", "Darwin"]: main_window.show() # Setup Signal Handlers @@ -112,9 +113,10 @@ def run(): gui.aboutToQuit.connect(server.terminate) # Close Splash Screen if still open - if system() != 'Darwin': + if system() != "Darwin": try: import pyi_splash + # Update the text on the splash screen pyi_splash.update_text("Khoj setup complete") # Close Splash Screen @@ -167,5 +169,5 @@ class ServerThread(QThread): start_server(self.app, self.host, self.port, self.socket) -if __name__ == '__main__': +if __name__ == "__main__": run() diff --git a/src/khoj/processor/conversation/gpt.py b/src/khoj/processor/conversation/gpt.py index a274074f..c942e08d 100644 --- a/src/khoj/processor/conversation/gpt.py +++ b/src/khoj/processor/conversation/gpt.py @@ -19,31 +19,27 @@ def summarize(text, summary_type, model, user_query=None, api_key=None, temperat # Setup Prompt based on Summary Type if summary_type == "chat": - prompt = f''' + prompt = f""" You are an AI. Summarize the conversation below from your perspective: {text} -Summarize the conversation from the AI's first-person perspective:''' +Summarize the conversation from the AI's first-person perspective:""" elif summary_type == "notes": - prompt = f''' + prompt = f""" Summarize the below notes about {user_query}: {text} -Summarize the notes in second person perspective:''' +Summarize the notes in second person perspective:""" # Get Response from GPT response = openai.Completion.create( - prompt=prompt, - model=model, - temperature=temperature, - max_tokens=max_tokens, - frequency_penalty=0.2, - stop="\"\"\"") + prompt=prompt, model=model, temperature=temperature, max_tokens=max_tokens, frequency_penalty=0.2, stop='"""' + ) # Extract, Clean Message from GPT's Response - story = response['choices'][0]['text'] + story = response["choices"][0]["text"] return str(story).replace("\n\n", "") @@ -53,7 +49,7 @@ def extract_search_type(text, model, api_key=None, temperature=0.5, max_tokens=1 """ # Initialize Variables openai.api_key = api_key or os.getenv("OPENAI_API_KEY") - understand_primer = ''' + understand_primer = """ Objective: Extract search type from user query and return information as JSON Allowed search types are listed below: @@ -73,7 +69,7 @@ A:{ "search-type": "notes" } Q: When did I buy Groceries last? A:{ "search-type": "ledger" } Q:When did I go surfing last? 
-A:{ "search-type": "notes" }''' +A:{ "search-type": "notes" }""" # Setup Prompt with Understand Primer prompt = message_to_prompt(text, understand_primer, start_sequence="\nA:", restart_sequence="\nQ:") @@ -82,15 +78,11 @@ A:{ "search-type": "notes" }''' # Get Response from GPT response = openai.Completion.create( - prompt=prompt, - model=model, - temperature=temperature, - max_tokens=max_tokens, - frequency_penalty=0.2, - stop=["\n"]) + prompt=prompt, model=model, temperature=temperature, max_tokens=max_tokens, frequency_penalty=0.2, stop=["\n"] + ) # Extract, Clean Message from GPT's Response - story = str(response['choices'][0]['text']) + story = str(response["choices"][0]["text"]) return json.loads(story.strip(empty_escape_sequences)) @@ -100,7 +92,7 @@ def understand(text, model, api_key=None, temperature=0.5, max_tokens=100, verbo """ # Initialize Variables openai.api_key = api_key or os.getenv("OPENAI_API_KEY") - understand_primer = ''' + understand_primer = """ Objective: Extract intent and trigger emotion information as JSON from each chat message Potential intent types and valid argument values are listed below: @@ -142,7 +134,7 @@ A: { "intent": {"type": "remember", "memory-type": "notes", "query": "recommend Q: When did I go surfing last? A: { "intent": {"type": "remember", "memory-type": "notes", "query": "When did I go surfing last"}, "trigger-emotion": "calm" } Q: Can you dance for me? -A: { "intent": {"type": "generate", "activity": "chat", "query": "Can you dance for me?"}, "trigger-emotion": "sad" }''' +A: { "intent": {"type": "generate", "activity": "chat", "query": "Can you dance for me?"}, "trigger-emotion": "sad" }""" # Setup Prompt with Understand Primer prompt = message_to_prompt(text, understand_primer, start_sequence="\nA:", restart_sequence="\nQ:") @@ -151,15 +143,11 @@ A: { "intent": {"type": "generate", "activity": "chat", "query": "Can you dance # Get Response from GPT response = openai.Completion.create( - prompt=prompt, - model=model, - temperature=temperature, - max_tokens=max_tokens, - frequency_penalty=0.2, - stop=["\n"]) + prompt=prompt, model=model, temperature=temperature, max_tokens=max_tokens, frequency_penalty=0.2, stop=["\n"] + ) # Extract, Clean Message from GPT's Response - story = str(response['choices'][0]['text']) + story = str(response["choices"][0]["text"]) return json.loads(story.strip(empty_escape_sequences)) @@ -171,15 +159,15 @@ def converse(text, model, conversation_history=None, api_key=None, temperature=0 max_words = 500 openai.api_key = api_key or os.getenv("OPENAI_API_KEY") - conversation_primer = f''' + conversation_primer = f""" The following is a conversation with an AI assistant. The assistant is helpful, creative, clever, and a very friendly companion. Human: Hello, who are you? -AI: Hi, I am an AI conversational companion created by OpenAI. How can I help you today?''' +AI: Hi, I am an AI conversational companion created by OpenAI. How can I help you today?""" # Setup Prompt with Primer or Conversation History prompt = message_to_prompt(text, conversation_history or conversation_primer) - prompt = ' '.join(prompt.split()[:max_words]) + prompt = " ".join(prompt.split()[:max_words]) # Get Response from GPT response = openai.Completion.create( @@ -188,14 +176,17 @@ AI: Hi, I am an AI conversational companion created by OpenAI. 
How can I help yo temperature=temperature, max_tokens=max_tokens, presence_penalty=0.6, - stop=["\n", "Human:", "AI:"]) + stop=["\n", "Human:", "AI:"], + ) # Extract, Clean Message from GPT's Response - story = str(response['choices'][0]['text']) + story = str(response["choices"][0]["text"]) return story.strip(empty_escape_sequences) -def message_to_prompt(user_message, conversation_history="", gpt_message=None, start_sequence="\nAI:", restart_sequence="\nHuman:"): +def message_to_prompt( + user_message, conversation_history="", gpt_message=None, start_sequence="\nAI:", restart_sequence="\nHuman:" +): """Create prompt for GPT from messages and conversation history""" gpt_message = f" {gpt_message}" if gpt_message else "" @@ -205,12 +196,8 @@ def message_to_prompt(user_message, conversation_history="", gpt_message=None, s def message_to_log(user_message, gpt_message, user_message_metadata={}, conversation_log=[]): """Create json logs from messages, metadata for conversation log""" default_user_message_metadata = { - "intent": { - "type": "remember", - "memory-type": "notes", - "query": user_message - }, - "trigger-emotion": "calm" + "intent": {"type": "remember", "memory-type": "notes", "query": user_message}, + "trigger-emotion": "calm", } current_dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S") @@ -229,5 +216,4 @@ def message_to_log(user_message, gpt_message, user_message_metadata={}, conversa def extract_summaries(metadata): """Extract summaries from metadata""" - return ''.join( - [f'\n{session["summary"]}' for session in metadata]) \ No newline at end of file + return "".join([f'\n{session["summary"]}' for session in metadata]) diff --git a/src/khoj/processor/ledger/beancount_to_jsonl.py b/src/khoj/processor/ledger/beancount_to_jsonl.py index 2703304f..006d0d70 100644 --- a/src/khoj/processor/ledger/beancount_to_jsonl.py +++ b/src/khoj/processor/ledger/beancount_to_jsonl.py @@ -19,7 +19,11 @@ class BeancountToJsonl(TextToJsonl): # Define Functions def process(self, previous_entries=None): # Extract required fields from config - beancount_files, beancount_file_filter, output_file = self.config.input_files, self.config.input_filter,self.config.compressed_jsonl + beancount_files, beancount_file_filter, output_file = ( + self.config.input_files, + self.config.input_filter, + self.config.compressed_jsonl, + ) # Input Validation if is_none_or_empty(beancount_files) and is_none_or_empty(beancount_file_filter): @@ -31,7 +35,9 @@ class BeancountToJsonl(TextToJsonl): # Extract Entries from specified Beancount files with timer("Parse transactions from Beancount files into dictionaries", logger): - current_entries = BeancountToJsonl.convert_transactions_to_maps(*BeancountToJsonl.extract_beancount_transactions(beancount_files)) + current_entries = BeancountToJsonl.convert_transactions_to_maps( + *BeancountToJsonl.extract_beancount_transactions(beancount_files) + ) # Split entries by max tokens supported by model with timer("Split entries by max token size supported by model", logger): @@ -42,7 +48,9 @@ class BeancountToJsonl(TextToJsonl): if not previous_entries: entries_with_ids = list(enumerate(current_entries)) else: - entries_with_ids = self.mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger) + entries_with_ids = self.mark_entries_for_update( + current_entries, previous_entries, key="compiled", logger=logger + ) with timer("Write transactions to JSONL file", logger): # Process Each Entry from All Notes Files @@ -62,9 +70,7 @@ class 
BeancountToJsonl(TextToJsonl): "Get Beancount files to process" absolute_beancount_files, filtered_beancount_files = set(), set() if beancount_files: - absolute_beancount_files = {get_absolute_path(beancount_file) - for beancount_file - in beancount_files} + absolute_beancount_files = {get_absolute_path(beancount_file) for beancount_file in beancount_files} if beancount_file_filters: filtered_beancount_files = { filtered_file @@ -76,14 +82,13 @@ class BeancountToJsonl(TextToJsonl): files_with_non_beancount_extensions = { beancount_file - for beancount_file - in all_beancount_files + for beancount_file in all_beancount_files if not beancount_file.endswith(".bean") and not beancount_file.endswith(".beancount") } if any(files_with_non_beancount_extensions): print(f"[Warning] There maybe non beancount files in the input set: {files_with_non_beancount_extensions}") - logger.info(f'Processing files: {all_beancount_files}') + logger.info(f"Processing files: {all_beancount_files}") return all_beancount_files @@ -92,19 +97,20 @@ class BeancountToJsonl(TextToJsonl): "Extract entries from specified Beancount files" # Initialize Regex for extracting Beancount Entries - transaction_regex = r'^\n?\d{4}-\d{2}-\d{2} [\*|\!] ' - empty_newline = f'^[\n\r\t\ ]*$' + transaction_regex = r"^\n?\d{4}-\d{2}-\d{2} [\*|\!] " + empty_newline = f"^[\n\r\t\ ]*$" entries = [] transaction_to_file_map = [] for beancount_file in beancount_files: with open(beancount_file) as f: ledger_content = f.read() - transactions_per_file = [entry.strip(empty_escape_sequences) - for entry - in re.split(empty_newline, ledger_content, flags=re.MULTILINE) - if re.match(transaction_regex, entry)] - transaction_to_file_map += zip(transactions_per_file, [beancount_file]*len(transactions_per_file)) + transactions_per_file = [ + entry.strip(empty_escape_sequences) + for entry in re.split(empty_newline, ledger_content, flags=re.MULTILINE) + if re.match(transaction_regex, entry) + ] + transaction_to_file_map += zip(transactions_per_file, [beancount_file] * len(transactions_per_file)) entries.extend(transactions_per_file) return entries, dict(transaction_to_file_map) @@ -113,7 +119,9 @@ class BeancountToJsonl(TextToJsonl): "Convert each parsed Beancount transaction into a Entry" entries = [] for parsed_entry in parsed_entries: - entries.append(Entry(compiled=parsed_entry, raw=parsed_entry, file=f'{transaction_to_file_map[parsed_entry]}')) + entries.append( + Entry(compiled=parsed_entry, raw=parsed_entry, file=f"{transaction_to_file_map[parsed_entry]}") + ) logger.info(f"Converted {len(parsed_entries)} transactions to dictionaries") @@ -122,4 +130,4 @@ class BeancountToJsonl(TextToJsonl): @staticmethod def convert_transaction_maps_to_jsonl(entries: List[Entry]) -> str: "Convert each Beancount transaction entry to JSON and collate as JSONL" - return ''.join([f'{entry.to_json()}\n' for entry in entries]) + return "".join([f"{entry.to_json()}\n" for entry in entries]) diff --git a/src/khoj/processor/markdown/markdown_to_jsonl.py b/src/khoj/processor/markdown/markdown_to_jsonl.py index 98e5d924..c661c1aa 100644 --- a/src/khoj/processor/markdown/markdown_to_jsonl.py +++ b/src/khoj/processor/markdown/markdown_to_jsonl.py @@ -20,7 +20,11 @@ class MarkdownToJsonl(TextToJsonl): # Define Functions def process(self, previous_entries=None): # Extract required fields from config - markdown_files, markdown_file_filter, output_file = self.config.input_files, self.config.input_filter, self.config.compressed_jsonl + markdown_files, markdown_file_filter, 
output_file = ( + self.config.input_files, + self.config.input_filter, + self.config.compressed_jsonl, + ) # Input Validation if is_none_or_empty(markdown_files) and is_none_or_empty(markdown_file_filter): @@ -32,7 +36,9 @@ class MarkdownToJsonl(TextToJsonl): # Extract Entries from specified Markdown files with timer("Parse entries from Markdown files into dictionaries", logger): - current_entries = MarkdownToJsonl.convert_markdown_entries_to_maps(*MarkdownToJsonl.extract_markdown_entries(markdown_files)) + current_entries = MarkdownToJsonl.convert_markdown_entries_to_maps( + *MarkdownToJsonl.extract_markdown_entries(markdown_files) + ) # Split entries by max tokens supported by model with timer("Split entries by max token size supported by model", logger): @@ -43,7 +49,9 @@ class MarkdownToJsonl(TextToJsonl): if not previous_entries: entries_with_ids = list(enumerate(current_entries)) else: - entries_with_ids = self.mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger) + entries_with_ids = self.mark_entries_for_update( + current_entries, previous_entries, key="compiled", logger=logger + ) with timer("Write markdown entries to JSONL file", logger): # Process Each Entry from All Notes Files @@ -75,15 +83,16 @@ class MarkdownToJsonl(TextToJsonl): files_with_non_markdown_extensions = { md_file - for md_file - in all_markdown_files - if not md_file.endswith(".md") and not md_file.endswith('.markdown') + for md_file in all_markdown_files + if not md_file.endswith(".md") and not md_file.endswith(".markdown") } if any(files_with_non_markdown_extensions): - logger.warn(f"[Warning] There maybe non markdown-mode files in the input set: {files_with_non_markdown_extensions}") + logger.warn( + f"[Warning] There maybe non markdown-mode files in the input set: {files_with_non_markdown_extensions}" + ) - logger.info(f'Processing files: {all_markdown_files}') + logger.info(f"Processing files: {all_markdown_files}") return all_markdown_files @@ -92,20 +101,20 @@ class MarkdownToJsonl(TextToJsonl): "Extract entries by heading from specified Markdown files" # Regex to extract Markdown Entries by Heading - markdown_heading_regex = r'^#' + markdown_heading_regex = r"^#" entries = [] entry_to_file_map = [] for markdown_file in markdown_files: - with open(markdown_file, 'r', encoding='utf8') as f: + with open(markdown_file, "r", encoding="utf8") as f: markdown_content = f.read() markdown_entries_per_file = [] for entry in re.split(markdown_heading_regex, markdown_content, flags=re.MULTILINE): - prefix = '#' if entry.startswith('#') else '# ' - if entry.strip(empty_escape_sequences) != '': - markdown_entries_per_file.append(f'{prefix}{entry.strip(empty_escape_sequences)}') + prefix = "#" if entry.startswith("#") else "# " + if entry.strip(empty_escape_sequences) != "": + markdown_entries_per_file.append(f"{prefix}{entry.strip(empty_escape_sequences)}") - entry_to_file_map += zip(markdown_entries_per_file, [markdown_file]*len(markdown_entries_per_file)) + entry_to_file_map += zip(markdown_entries_per_file, [markdown_file] * len(markdown_entries_per_file)) entries.extend(markdown_entries_per_file) return entries, dict(entry_to_file_map) @@ -115,7 +124,7 @@ class MarkdownToJsonl(TextToJsonl): "Convert each Markdown entries into a dictionary" entries = [] for parsed_entry in parsed_entries: - entries.append(Entry(compiled=parsed_entry, raw=parsed_entry, file=f'{entry_to_file_map[parsed_entry]}')) + entries.append(Entry(compiled=parsed_entry, raw=parsed_entry, 
file=f"{entry_to_file_map[parsed_entry]}")) logger.info(f"Converted {len(parsed_entries)} markdown entries to dictionaries") @@ -124,4 +133,4 @@ class MarkdownToJsonl(TextToJsonl): @staticmethod def convert_markdown_maps_to_jsonl(entries: List[Entry]): "Convert each Markdown entry to JSON and collate as JSONL" - return ''.join([f'{entry.to_json()}\n' for entry in entries]) + return "".join([f"{entry.to_json()}\n" for entry in entries]) diff --git a/src/khoj/processor/org_mode/org_to_jsonl.py b/src/khoj/processor/org_mode/org_to_jsonl.py index 2938413f..05eff771 100644 --- a/src/khoj/processor/org_mode/org_to_jsonl.py +++ b/src/khoj/processor/org_mode/org_to_jsonl.py @@ -18,9 +18,13 @@ logger = logging.getLogger(__name__) class OrgToJsonl(TextToJsonl): # Define Functions - def process(self, previous_entries: List[Entry]=None): + def process(self, previous_entries: List[Entry] = None): # Extract required fields from config - org_files, org_file_filter, output_file = self.config.input_files, self.config.input_filter, self.config.compressed_jsonl + org_files, org_file_filter, output_file = ( + self.config.input_files, + self.config.input_filter, + self.config.compressed_jsonl, + ) index_heading_entries = self.config.index_heading_entries # Input Validation @@ -46,7 +50,9 @@ class OrgToJsonl(TextToJsonl): if not previous_entries: entries_with_ids = list(enumerate(current_entries)) else: - entries_with_ids = self.mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger) + entries_with_ids = self.mark_entries_for_update( + current_entries, previous_entries, key="compiled", logger=logger + ) # Process Each Entry from All Notes Files with timer("Write org entries to JSONL file", logger): @@ -66,11 +72,7 @@ class OrgToJsonl(TextToJsonl): "Get Org files to process" absolute_org_files, filtered_org_files = set(), set() if org_files: - absolute_org_files = { - get_absolute_path(org_file) - for org_file - in org_files - } + absolute_org_files = {get_absolute_path(org_file) for org_file in org_files} if org_file_filters: filtered_org_files = { filtered_file @@ -84,7 +86,7 @@ class OrgToJsonl(TextToJsonl): if any(files_with_non_org_extensions): logger.warn(f"There maybe non org-mode files in the input set: {files_with_non_org_extensions}") - logger.info(f'Processing files: {all_org_files}') + logger.info(f"Processing files: {all_org_files}") return all_org_files @@ -95,13 +97,15 @@ class OrgToJsonl(TextToJsonl): entry_to_file_map = [] for org_file in org_files: org_file_entries = orgnode.makelist(str(org_file)) - entry_to_file_map += zip(org_file_entries, [org_file]*len(org_file_entries)) + entry_to_file_map += zip(org_file_entries, [org_file] * len(org_file_entries)) entries.extend(org_file_entries) return entries, dict(entry_to_file_map) @staticmethod - def convert_org_nodes_to_entries(parsed_entries: List[orgnode.Orgnode], entry_to_file_map, index_heading_entries=False) -> List[Entry]: + def convert_org_nodes_to_entries( + parsed_entries: List[orgnode.Orgnode], entry_to_file_map, index_heading_entries=False + ) -> List[Entry]: "Convert Org-Mode nodes into list of Entry objects" entries: List[Entry] = [] for parsed_entry in parsed_entries: @@ -109,13 +113,13 @@ class OrgToJsonl(TextToJsonl): # Ignore title notes i.e notes with just headings and empty body continue - compiled = f'{parsed_entry.heading}.' + compiled = f"{parsed_entry.heading}." 
if state.verbose > 2: logger.debug(f"Title: {parsed_entry.heading}") if parsed_entry.tags: tags_str = " ".join(parsed_entry.tags) - compiled += f'\t {tags_str}.' + compiled += f"\t {tags_str}." if state.verbose > 2: logger.debug(f"Tags: {tags_str}") @@ -130,19 +134,16 @@ class OrgToJsonl(TextToJsonl): logger.debug(f'Scheduled: {parsed_entry.scheduled.strftime("%Y-%m-%d")}') if parsed_entry.hasBody: - compiled += f'\n {parsed_entry.body}' + compiled += f"\n {parsed_entry.body}" if state.verbose > 2: logger.debug(f"Body: {parsed_entry.body}") if compiled: - entries += [Entry( - compiled=compiled, - raw=f'{parsed_entry}', - file=f'{entry_to_file_map[parsed_entry]}')] + entries += [Entry(compiled=compiled, raw=f"{parsed_entry}", file=f"{entry_to_file_map[parsed_entry]}")] return entries @staticmethod def convert_org_entries_to_jsonl(entries: Iterable[Entry]) -> str: "Convert each Org-Mode entry to JSON and collate as JSONL" - return ''.join([f'{entry_dict.to_json()}\n' for entry_dict in entries]) + return "".join([f"{entry_dict.to_json()}\n" for entry_dict in entries]) diff --git a/src/khoj/processor/org_mode/orgnode.py b/src/khoj/processor/org_mode/orgnode.py index 7675650b..e352ecf2 100644 --- a/src/khoj/processor/org_mode/orgnode.py +++ b/src/khoj/processor/org_mode/orgnode.py @@ -39,182 +39,197 @@ from pathlib import Path from os.path import relpath from typing import List -indent_regex = re.compile(r'^ *') +indent_regex = re.compile(r"^ *") + def normalize_filename(filename): - "Normalize and escape filename for rendering" - if not Path(filename).is_absolute(): - # Normalize relative filename to be relative to current directory - normalized_filename = f'~/{relpath(filename, start=Path.home())}' - else: - normalized_filename = filename - escaped_filename = f'{normalized_filename}'.replace("[","\[").replace("]","\]") - return escaped_filename + "Normalize and escape filename for rendering" + if not Path(filename).is_absolute(): + # Normalize relative filename to be relative to current directory + normalized_filename = f"~/{relpath(filename, start=Path.home())}" + else: + normalized_filename = filename + escaped_filename = f"{normalized_filename}".replace("[", "\[").replace("]", "\]") + return escaped_filename + def makelist(filename): - """ - Read an org-mode file and return a list of Orgnode objects - created from this file. - """ - ctr = 0 + """ + Read an org-mode file and return a list of Orgnode objects + created from this file. 
+ """ + ctr = 0 - f = open(filename, 'r') + f = open(filename, "r") - todos = { "TODO": "", "WAITING": "", "ACTIVE": "", - "DONE": "", "CANCELLED": "", "FAILED": ""} # populated from #+SEQ_TODO line - level = "" - heading = "" - bodytext = "" - tags = list() # set of all tags in headline - closed_date = '' - sched_date = '' - deadline_date = '' - logbook = list() - nodelist: List[Orgnode] = list() - property_map = dict() - in_properties_drawer = False - in_logbook_drawer = False - file_title = f'{filename}' + todos = { + "TODO": "", + "WAITING": "", + "ACTIVE": "", + "DONE": "", + "CANCELLED": "", + "FAILED": "", + } # populated from #+SEQ_TODO line + level = "" + heading = "" + bodytext = "" + tags = list() # set of all tags in headline + closed_date = "" + sched_date = "" + deadline_date = "" + logbook = list() + nodelist: List[Orgnode] = list() + property_map = dict() + in_properties_drawer = False + in_logbook_drawer = False + file_title = f"{filename}" - for line in f: - ctr += 1 - heading_search = re.search(r'^(\*+)\s(.*?)\s*$', line) - if heading_search: # we are processing a heading line - if heading: # if we have are on second heading, append first heading to headings list - thisNode = Orgnode(level, heading, bodytext, tags) - if closed_date: - thisNode.closed = closed_date - closed_date = '' - if sched_date: - thisNode.scheduled = sched_date - sched_date = "" - if deadline_date: - thisNode.deadline = deadline_date - deadline_date = '' - if logbook: - thisNode.logbook = logbook - logbook = list() - thisNode.properties = property_map - nodelist.append( thisNode ) - property_map = {'LINE': f'file:{normalize_filename(filename)}::{ctr}'} - level = heading_search.group(1) - heading = heading_search.group(2) - bodytext = "" - tags = list() # set of all tags in headline - tag_search = re.search(r'(.*?)\s*:([a-zA-Z0-9].*?):$',heading) - if tag_search: - heading = tag_search.group(1) - parsedtags = tag_search.group(2) - if parsedtags: - for parsedtag in parsedtags.split(':'): - if parsedtag != '': tags.append(parsedtag) - else: # we are processing a non-heading line - if line[:10] == '#+SEQ_TODO': - kwlist = re.findall(r'([A-Z]+)\(', line) - for kw in kwlist: todos[kw] = "" + for line in f: + ctr += 1 + heading_search = re.search(r"^(\*+)\s(.*?)\s*$", line) + if heading_search: # we are processing a heading line + if heading: # if we have are on second heading, append first heading to headings list + thisNode = Orgnode(level, heading, bodytext, tags) + if closed_date: + thisNode.closed = closed_date + closed_date = "" + if sched_date: + thisNode.scheduled = sched_date + sched_date = "" + if deadline_date: + thisNode.deadline = deadline_date + deadline_date = "" + if logbook: + thisNode.logbook = logbook + logbook = list() + thisNode.properties = property_map + nodelist.append(thisNode) + property_map = {"LINE": f"file:{normalize_filename(filename)}::{ctr}"} + level = heading_search.group(1) + heading = heading_search.group(2) + bodytext = "" + tags = list() # set of all tags in headline + tag_search = re.search(r"(.*?)\s*:([a-zA-Z0-9].*?):$", heading) + if tag_search: + heading = tag_search.group(1) + parsedtags = tag_search.group(2) + if parsedtags: + for parsedtag in parsedtags.split(":"): + if parsedtag != "": + tags.append(parsedtag) + else: # we are processing a non-heading line + if line[:10] == "#+SEQ_TODO": + kwlist = re.findall(r"([A-Z]+)\(", line) + for kw in kwlist: + todos[kw] = "" - # Set file title to TITLE property, if it exists - title_search = 
re.search(r'^#\+TITLE:\s*(.*)$', line) - if title_search and title_search.group(1).strip() != '': - title_text = title_search.group(1).strip() - if file_title == f'{filename}': - file_title = title_text - else: - file_title += f' {title_text}' - continue + # Set file title to TITLE property, if it exists + title_search = re.search(r"^#\+TITLE:\s*(.*)$", line) + if title_search and title_search.group(1).strip() != "": + title_text = title_search.group(1).strip() + if file_title == f"{filename}": + file_title = title_text + else: + file_title += f" {title_text}" + continue - # Ignore Properties Drawers Completely - if re.search(':PROPERTIES:', line): - in_properties_drawer=True - continue - if in_properties_drawer and re.search(':END:', line): - in_properties_drawer=False - continue + # Ignore Properties Drawers Completely + if re.search(":PROPERTIES:", line): + in_properties_drawer = True + continue + if in_properties_drawer and re.search(":END:", line): + in_properties_drawer = False + continue - # Ignore Logbook Drawer Start, End Lines - if re.search(':LOGBOOK:', line): - in_logbook_drawer=True - continue - if in_logbook_drawer and re.search(':END:', line): - in_logbook_drawer=False - continue + # Ignore Logbook Drawer Start, End Lines + if re.search(":LOGBOOK:", line): + in_logbook_drawer = True + continue + if in_logbook_drawer and re.search(":END:", line): + in_logbook_drawer = False + continue - # Extract Clocking Lines - clocked_re = re.search(r'CLOCK:\s*\[([0-9]{4}-[0-9]{2}-[0-9]{2} [a-zA-Z]{3} [0-9]{2}:[0-9]{2})\]--\[([0-9]{4}-[0-9]{2}-[0-9]{2} [a-zA-Z]{3} [0-9]{2}:[0-9]{2})\]', line) - if clocked_re: - # convert clock in, clock out strings to datetime objects - clocked_in = datetime.datetime.strptime(clocked_re.group(1), '%Y-%m-%d %a %H:%M') - clocked_out = datetime.datetime.strptime(clocked_re.group(2), '%Y-%m-%d %a %H:%M') - # add clocked time to the entries logbook list - logbook += [(clocked_in, clocked_out)] - line = "" + # Extract Clocking Lines + clocked_re = re.search( + r"CLOCK:\s*\[([0-9]{4}-[0-9]{2}-[0-9]{2} [a-zA-Z]{3} [0-9]{2}:[0-9]{2})\]--\[([0-9]{4}-[0-9]{2}-[0-9]{2} [a-zA-Z]{3} [0-9]{2}:[0-9]{2})\]", + line, + ) + if clocked_re: + # convert clock in, clock out strings to datetime objects + clocked_in = datetime.datetime.strptime(clocked_re.group(1), "%Y-%m-%d %a %H:%M") + clocked_out = datetime.datetime.strptime(clocked_re.group(2), "%Y-%m-%d %a %H:%M") + # add clocked time to the entries logbook list + logbook += [(clocked_in, clocked_out)] + line = "" - property_search = re.search(r'^\s*:([a-zA-Z0-9]+):\s*(.*?)\s*$', line) - if property_search: - # Set ID property to an id based org-mode link to the entry - if property_search.group(1) == 'ID': - property_map['ID'] = f'id:{property_search.group(2)}' - else: - property_map[property_search.group(1)] = property_search.group(2) - continue + property_search = re.search(r"^\s*:([a-zA-Z0-9]+):\s*(.*?)\s*$", line) + if property_search: + # Set ID property to an id based org-mode link to the entry + if property_search.group(1) == "ID": + property_map["ID"] = f"id:{property_search.group(2)}" + else: + property_map[property_search.group(1)] = property_search.group(2) + continue - cd_re = re.search(r'CLOSED:\s*\[([0-9]{4})-([0-9]{2})-([0-9]{2})', line) - if cd_re: - closed_date = datetime.date(int(cd_re.group(1)), - int(cd_re.group(2)), - int(cd_re.group(3)) ) - sd_re = re.search(r'SCHEDULED:\s*<([0-9]+)\-([0-9]+)\-([0-9]+)', line) - if sd_re: - sched_date = datetime.date(int(sd_re.group(1)), - int(sd_re.group(2)), - 
int(sd_re.group(3)) ) - dd_re = re.search(r'DEADLINE:\s*<(\d+)\-(\d+)\-(\d+)', line) - if dd_re: - deadline_date = datetime.date(int(dd_re.group(1)), - int(dd_re.group(2)), - int(dd_re.group(3)) ) + cd_re = re.search(r"CLOSED:\s*\[([0-9]{4})-([0-9]{2})-([0-9]{2})", line) + if cd_re: + closed_date = datetime.date(int(cd_re.group(1)), int(cd_re.group(2)), int(cd_re.group(3))) + sd_re = re.search(r"SCHEDULED:\s*<([0-9]+)\-([0-9]+)\-([0-9]+)", line) + if sd_re: + sched_date = datetime.date(int(sd_re.group(1)), int(sd_re.group(2)), int(sd_re.group(3))) + dd_re = re.search(r"DEADLINE:\s*<(\d+)\-(\d+)\-(\d+)", line) + if dd_re: + deadline_date = datetime.date(int(dd_re.group(1)), int(dd_re.group(2)), int(dd_re.group(3))) - # Ignore property drawer, scheduled, closed, deadline, logbook entries and # lines from body - if not in_properties_drawer and not cd_re and not sd_re and not dd_re and not clocked_re and line[:1] != '#': - bodytext = bodytext + line + # Ignore property drawer, scheduled, closed, deadline, logbook entries and # lines from body + if ( + not in_properties_drawer + and not cd_re + and not sd_re + and not dd_re + and not clocked_re + and line[:1] != "#" + ): + bodytext = bodytext + line - # write out last node - thisNode = Orgnode(level, heading or file_title, bodytext, tags) - thisNode.properties = property_map - if sched_date: - thisNode.scheduled = sched_date - if deadline_date: - thisNode.deadline = deadline_date - if closed_date: - thisNode.closed = closed_date - if logbook: - thisNode.logbook = logbook - nodelist.append( thisNode ) + # write out last node + thisNode = Orgnode(level, heading or file_title, bodytext, tags) + thisNode.properties = property_map + if sched_date: + thisNode.scheduled = sched_date + if deadline_date: + thisNode.deadline = deadline_date + if closed_date: + thisNode.closed = closed_date + if logbook: + thisNode.logbook = logbook + nodelist.append(thisNode) - # using the list of TODO keywords found in the file - # process the headings searching for TODO keywords - for n in nodelist: - todo_search = re.search(r'([A-Z]+)\s(.*?)$', n.heading) - if todo_search: - if todo_search.group(1) in todos: - n.heading = todo_search.group(2) - n.todo = todo_search.group(1) + # using the list of TODO keywords found in the file + # process the headings searching for TODO keywords + for n in nodelist: + todo_search = re.search(r"([A-Z]+)\s(.*?)$", n.heading) + if todo_search: + if todo_search.group(1) in todos: + n.heading = todo_search.group(2) + n.todo = todo_search.group(1) - # extract, set priority from heading, update heading if necessary - priority_search = re.search(r'^\[\#(A|B|C)\] (.*?)$', n.heading) - if priority_search: - n.priority = priority_search.group(1) - n.heading = priority_search.group(2) + # extract, set priority from heading, update heading if necessary + priority_search = re.search(r"^\[\#(A|B|C)\] (.*?)$", n.heading) + if priority_search: + n.priority = priority_search.group(1) + n.heading = priority_search.group(2) - # Set SOURCE property to a file+heading based org-mode link to the entry - if n.level == 0: - n.properties['LINE'] = f'file:{normalize_filename(filename)}::0' - n.properties['SOURCE'] = f'[[file:{normalize_filename(filename)}]]' - else: - escaped_heading = n.heading.replace("[","\\[").replace("]","\\]") - n.properties['SOURCE'] = f'[[file:{normalize_filename(filename)}::*{escaped_heading}]]' + # Set SOURCE property to a file+heading based org-mode link to the entry + if n.level == 0: + n.properties["LINE"] = 
f"file:{normalize_filename(filename)}::0" + n.properties["SOURCE"] = f"[[file:{normalize_filename(filename)}]]" + else: + escaped_heading = n.heading.replace("[", "\\[").replace("]", "\\]") + n.properties["SOURCE"] = f"[[file:{normalize_filename(filename)}::*{escaped_heading}]]" + + return nodelist - return nodelist ###################### class Orgnode(object): @@ -222,6 +237,7 @@ class Orgnode(object): Orgnode class represents a headline, tags and text associated with the headline. """ + def __init__(self, level, headline, body, tags): """ Create an Orgnode object given the parameters of level (as the @@ -232,14 +248,14 @@ class Orgnode(object): self._level = len(level) self._heading = headline self._body = body - self._tags = tags # All tags in the headline + self._tags = tags # All tags in the headline self._todo = "" - self._priority = "" # empty of A, B or C - self._scheduled = "" # Scheduled date - self._deadline = "" # Deadline date - self._closed = "" # Closed date + self._priority = "" # empty of A, B or C + self._scheduled = "" # Scheduled date + self._deadline = "" # Deadline date + self._closed = "" # Closed date self._properties = dict() - self._logbook = list() # List of clock-in, clock-out tuples representing logbook entries + self._logbook = list() # List of clock-in, clock-out tuples representing logbook entries # Look for priority in headline and transfer to prty field @@ -270,7 +286,7 @@ class Orgnode(object): """ Returns True if node has non empty body, else False """ - return self._body and re.sub(r'\n|\t|\r| ', '', self._body) != '' + return self._body and re.sub(r"\n|\t|\r| ", "", self._body) != "" @property def level(self): @@ -417,20 +433,20 @@ class Orgnode(object): text as used to construct the node. """ # Output heading line - n = '' + n = "" for _ in range(0, self._level): - n = n + '*' - n = n + ' ' + n = n + "*" + n = n + " " if self._todo: - n = n + self._todo + ' ' + n = n + self._todo + " " if self._priority: - n = n + '[#' + self._priority + '] ' + n = n + "[#" + self._priority + "] " n = n + self._heading - n = "%-60s " % n # hack - tags will start in column 62 - closecolon = '' + n = "%-60s " % n # hack - tags will start in column 62 + closecolon = "" for t in self._tags: - n = n + ':' + t - closecolon = ':' + n = n + ":" + t + closecolon = ":" n = n + closecolon n = n + "\n" @@ -439,24 +455,24 @@ class Orgnode(object): # Output Closed Date, Scheduled Date, Deadline Date if self._closed or self._scheduled or self._deadline: - n = n + indent + n = n + indent if self._closed: - n = n + f'CLOSED: [{self._closed.strftime("%Y-%m-%d %a")}] ' + n = n + f'CLOSED: [{self._closed.strftime("%Y-%m-%d %a")}] ' if self._scheduled: - n = n + f'SCHEDULED: <{self._scheduled.strftime("%Y-%m-%d %a")}> ' + n = n + f'SCHEDULED: <{self._scheduled.strftime("%Y-%m-%d %a")}> ' if self._deadline: - n = n + f'DEADLINE: <{self._deadline.strftime("%Y-%m-%d %a")}> ' + n = n + f'DEADLINE: <{self._deadline.strftime("%Y-%m-%d %a")}> ' if self._closed or self._scheduled or self._deadline: - n = n + '\n' + n = n + "\n" # Ouput Property Drawer n = n + indent + ":PROPERTIES:\n" for key, value in self._properties.items(): - n = n + indent + f":{key}: {value}\n" + n = n + indent + f":{key}: {value}\n" n = n + indent + ":END:\n" # Output Body if self.hasBody: - n = n + self._body + n = n + self._body return n diff --git a/src/khoj/processor/text_to_jsonl.py b/src/khoj/processor/text_to_jsonl.py index 80063128..570c22bb 100644 --- a/src/khoj/processor/text_to_jsonl.py +++ 
b/src/khoj/processor/text_to_jsonl.py @@ -17,14 +17,17 @@ class TextToJsonl(ABC): self.config = config @abstractmethod - def process(self, previous_entries: List[Entry]=None) -> List[Tuple[int, Entry]]: ... + def process(self, previous_entries: List[Entry] = None) -> List[Tuple[int, Entry]]: + ... @staticmethod def hash_func(key: str) -> Callable: - return lambda entry: hashlib.md5(bytes(getattr(entry, key), encoding='utf-8')).hexdigest() + return lambda entry: hashlib.md5(bytes(getattr(entry, key), encoding="utf-8")).hexdigest() @staticmethod - def split_entries_by_max_tokens(entries: List[Entry], max_tokens: int=256, max_word_length: int=500) -> List[Entry]: + def split_entries_by_max_tokens( + entries: List[Entry], max_tokens: int = 256, max_word_length: int = 500 + ) -> List[Entry]: "Split entries if compiled entry length exceeds the max tokens supported by the ML model." chunked_entries: List[Entry] = [] for entry in entries: @@ -32,13 +35,15 @@ class TextToJsonl(ABC): # Drop long words instead of having entry truncated to maintain quality of entry processed by models compiled_entry_words = [word for word in compiled_entry_words if len(word) <= max_word_length] for chunk_index in range(0, len(compiled_entry_words), max_tokens): - compiled_entry_words_chunk = compiled_entry_words[chunk_index:chunk_index + max_tokens] - compiled_entry_chunk = ' '.join(compiled_entry_words_chunk) + compiled_entry_words_chunk = compiled_entry_words[chunk_index : chunk_index + max_tokens] + compiled_entry_chunk = " ".join(compiled_entry_words_chunk) entry_chunk = Entry(compiled=compiled_entry_chunk, raw=entry.raw, file=entry.file) chunked_entries.append(entry_chunk) return chunked_entries - def mark_entries_for_update(self, current_entries: List[Entry], previous_entries: List[Entry], key='compiled', logger=None) -> List[Tuple[int, Entry]]: + def mark_entries_for_update( + self, current_entries: List[Entry], previous_entries: List[Entry], key="compiled", logger=None + ) -> List[Tuple[int, Entry]]: # Hash all current and previous entries to identify new entries with timer("Hash previous, current entries", logger): current_entry_hashes = list(map(TextToJsonl.hash_func(key), current_entries)) @@ -54,10 +59,7 @@ class TextToJsonl(ABC): existing_entry_hashes = set(current_entry_hashes) & set(previous_entry_hashes) # Mark new entries with -1 id to flag for later embeddings generation - new_entries = [ - (-1, hash_to_current_entries[entry_hash]) - for entry_hash in new_entry_hashes - ] + new_entries = [(-1, hash_to_current_entries[entry_hash]) for entry_hash in new_entry_hashes] # Set id of existing entries to their previous ids to reuse their existing encoded embeddings existing_entries = [ (previous_entry_hashes.index(entry_hash), hash_to_previous_entries[entry_hash]) @@ -67,4 +69,4 @@ class TextToJsonl(ABC): existing_entries_sorted = sorted(existing_entries, key=lambda e: e[0]) entries_with_ids = existing_entries_sorted + new_entries - return entries_with_ids \ No newline at end of file + return entries_with_ids diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index ea9b3334..3c946d9f 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -22,27 +22,30 @@ logger = logging.getLogger(__name__) # Create Routes -@api.get('/config/data/default') +@api.get("/config/data/default") def get_default_config_data(): return constants.default_config -@api.get('/config/data', response_model=FullConfig) + +@api.get("/config/data", response_model=FullConfig) def get_config_data(): return 
state.config -@api.post('/config/data') + +@api.post("/config/data") async def set_config_data(updated_config: FullConfig): state.config = updated_config - with open(state.config_file, 'w') as outfile: + with open(state.config_file, "w") as outfile: yaml.dump(yaml.safe_load(state.config.json(by_alias=True)), outfile) outfile.close() return state.config -@api.get('/search', response_model=List[SearchResponse]) + +@api.get("/search", response_model=List[SearchResponse]) def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Optional[bool] = False): results: List[SearchResponse] = [] - if q is None or q == '': - logger.info(f'No query param (q) passed in API call to initiate search') + if q is None or q == "": + logger.info(f"No query param (q) passed in API call to initiate search") return results # initialize variables @@ -50,9 +53,9 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti results_count = n # return cached results, if available - query_cache_key = f'{user_query}-{n}-{t}-{r}' + query_cache_key = f"{user_query}-{n}-{t}-{r}" if query_cache_key in state.query_cache: - logger.info(f'Return response from query cache') + logger.info(f"Return response from query cache") return state.query_cache[query_cache_key] if (t == SearchType.Org or t == None) and state.model.orgmode_search: @@ -95,7 +98,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti # query images with timer("Query took", logger): hits = image_search.query(user_query, results_count, state.model.image_search) - output_directory = constants.web_directory / 'images' + output_directory = constants.web_directory / "images" # collate and return results with timer("Collating results took", logger): @@ -103,8 +106,9 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti hits, image_names=state.model.image_search.image_names, output_directory=output_directory, - image_files_url='/static/images', - count=results_count) + image_files_url="/static/images", + count=results_count, + ) # Cache results state.query_cache[query_cache_key] = results @@ -112,7 +116,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti return results -@api.get('/update') +@api.get("/update") def update(t: Optional[SearchType] = None, force: Optional[bool] = False): try: state.search_index_lock.acquire() @@ -132,4 +136,4 @@ def update(t: Optional[SearchType] = None, force: Optional[bool] = False): else: logger.info("Processor reconfigured via API call") - return {'status': 'ok', 'message': 'khoj reloaded'} + return {"status": "ok", "message": "khoj reloaded"} diff --git a/src/khoj/routers/api_beta.py b/src/khoj/routers/api_beta.py index 7e4d4c7a..e05b53dd 100644 --- a/src/khoj/routers/api_beta.py +++ b/src/khoj/routers/api_beta.py @@ -9,7 +9,14 @@ from fastapi import APIRouter # Internal Packages from khoj.routers.api import search -from khoj.processor.conversation.gpt import converse, extract_search_type, message_to_log, message_to_prompt, understand, summarize +from khoj.processor.conversation.gpt import ( + converse, + extract_search_type, + message_to_log, + message_to_prompt, + understand, + summarize, +) from khoj.utils.config import SearchType from khoj.utils.helpers import get_from_dict, resolve_absolute_path from khoj.utils import state @@ -21,7 +28,7 @@ logger = logging.getLogger(__name__) # Create Routes -@api_beta.get('/search') +@api_beta.get("/search") def search_beta(q: str, n: Optional[int] = 1): # 
Initialize Variables model = state.processor_config.conversation.model @@ -32,16 +39,16 @@ def search_beta(q: str, n: Optional[int] = 1): metadata = extract_search_type(q, model=model, api_key=api_key, verbose=state.verbose) search_type = get_from_dict(metadata, "search-type") except Exception as e: - return {'status': 'error', 'result': [str(e)], 'type': None} + return {"status": "error", "result": [str(e)], "type": None} # Search search_results = search(q, n=n, t=SearchType(search_type)) # Return response - return {'status': 'ok', 'result': search_results, 'type': search_type} + return {"status": "ok", "result": search_results, "type": search_type} -@api_beta.get('/summarize') +@api_beta.get("/summarize") def summarize_beta(q: str): # Initialize Variables model = state.processor_config.conversation.model @@ -54,23 +61,25 @@ def summarize_beta(q: str): # Converse with OpenAI GPT result_list = search(q, n=1, r=True) collated_result = "\n".join([item.entry for item in result_list]) - logger.debug(f'Semantically Similar Notes:\n{collated_result}') + logger.debug(f"Semantically Similar Notes:\n{collated_result}") try: gpt_response = summarize(collated_result, summary_type="notes", user_query=q, model=model, api_key=api_key) - status = 'ok' + status = "ok" except Exception as e: gpt_response = str(e) - status = 'error' + status = "error" # Update Conversation History state.processor_config.conversation.chat_session = message_to_prompt(q, chat_session, gpt_message=gpt_response) - state.processor_config.conversation.meta_log['chat'] = message_to_log(q, gpt_response, conversation_log=meta_log.get('chat', [])) + state.processor_config.conversation.meta_log["chat"] = message_to_log( + q, gpt_response, conversation_log=meta_log.get("chat", []) + ) - return {'status': status, 'response': gpt_response} + return {"status": status, "response": gpt_response} -@api_beta.get('/chat') -def chat(q: Optional[str]=None): +@api_beta.get("/chat") +def chat(q: Optional[str] = None): # Initialize Variables model = state.processor_config.conversation.model api_key = state.processor_config.conversation.openai_api_key @@ -81,10 +90,10 @@ def chat(q: Optional[str]=None): # If user query is empty, return chat history if not q: - if meta_log.get('chat'): - return {'status': 'ok', 'response': meta_log["chat"]} + if meta_log.get("chat"): + return {"status": "ok", "response": meta_log["chat"]} else: - return {'status': 'ok', 'response': []} + return {"status": "ok", "response": []} # Converse with OpenAI GPT metadata = understand(q, model=model, api_key=api_key, verbose=state.verbose) @@ -94,32 +103,39 @@ def chat(q: Optional[str]=None): query = get_from_dict(metadata, "intent", "query") result_list = search(query, n=1, t=SearchType.Org, r=True) collated_result = "\n".join([item.entry for item in result_list]) - logger.debug(f'Semantically Similar Notes:\n{collated_result}') + logger.debug(f"Semantically Similar Notes:\n{collated_result}") try: gpt_response = summarize(collated_result, summary_type="notes", user_query=q, model=model, api_key=api_key) - status = 'ok' + status = "ok" except Exception as e: gpt_response = str(e) - status = 'error' + status = "error" else: try: gpt_response = converse(q, model, chat_session, api_key=api_key) - status = 'ok' + status = "ok" except Exception as e: gpt_response = str(e) - status = 'error' + status = "error" # Update Conversation History state.processor_config.conversation.chat_session = message_to_prompt(q, chat_session, gpt_message=gpt_response) - 
state.processor_config.conversation.meta_log['chat'] = message_to_log(q, gpt_response, metadata, meta_log.get('chat', [])) + state.processor_config.conversation.meta_log["chat"] = message_to_log( + q, gpt_response, metadata, meta_log.get("chat", []) + ) - return {'status': status, 'response': gpt_response} + return {"status": status, "response": gpt_response} @schedule.repeat(schedule.every(5).minutes) def save_chat_session(): # No need to create empty log file - if not (state.processor_config and state.processor_config.conversation and state.processor_config.conversation.meta_log and state.processor_config.conversation.chat_session): + if not ( + state.processor_config + and state.processor_config.conversation + and state.processor_config.conversation.meta_log + and state.processor_config.conversation.chat_session + ): return # Summarize Conversation Logs for this Session @@ -130,19 +146,19 @@ def save_chat_session(): session = { "summary": summarize(chat_session, summary_type="chat", model=model, api_key=openai_api_key), "session-start": conversation_log.get("session", [{"session-end": 0}])[-1]["session-end"], - "session-end": len(conversation_log["chat"]) - } - if 'session' in conversation_log: - conversation_log['session'].append(session) + "session-end": len(conversation_log["chat"]), + } + if "session" in conversation_log: + conversation_log["session"].append(session) else: - conversation_log['session'] = [session] - logger.info('Added new chat session to conversation logs') + conversation_log["session"] = [session] + logger.info("Added new chat session to conversation logs") # Save Conversation Metadata Logs to Disk conversation_logfile = resolve_absolute_path(state.processor_config.conversation.conversation_logfile) - conversation_logfile.parent.mkdir(parents=True, exist_ok=True) # create conversation directory if doesn't exist - with open(conversation_logfile, "w+", encoding='utf-8') as logfile: + conversation_logfile.parent.mkdir(parents=True, exist_ok=True) # create conversation directory if doesn't exist + with open(conversation_logfile, "w+", encoding="utf-8") as logfile: json.dump(conversation_log, logfile) state.processor_config.conversation.chat_session = None - logger.info('Saved updated conversation logs to disk.') + logger.info("Saved updated conversation logs to disk.") diff --git a/src/khoj/routers/web_client.py b/src/khoj/routers/web_client.py index 79e47375..c452efdf 100644 --- a/src/khoj/routers/web_client.py +++ b/src/khoj/routers/web_client.py @@ -18,10 +18,12 @@ templates = Jinja2Templates(directory=constants.web_directory) def index(): return FileResponse(constants.web_directory / "index.html") -@web_client.get('/config', response_class=HTMLResponse) + +@web_client.get("/config", response_class=HTMLResponse) def config_page(request: Request): - return templates.TemplateResponse("config.html", context={'request': request}) + return templates.TemplateResponse("config.html", context={"request": request}) + @web_client.get("/chat", response_class=FileResponse) def chat_page(): - return FileResponse(constants.web_directory / "chat.html") \ No newline at end of file + return FileResponse(constants.web_directory / "chat.html") diff --git a/src/khoj/search_filter/base_filter.py b/src/khoj/search_filter/base_filter.py index c27a051c..c273f9b8 100644 --- a/src/khoj/search_filter/base_filter.py +++ b/src/khoj/search_filter/base_filter.py @@ -8,10 +8,13 @@ from khoj.utils.rawconfig import Entry class BaseFilter(ABC): @abstractmethod - def load(self, entries: List[Entry], 
*args, **kwargs): ... + def load(self, entries: List[Entry], *args, **kwargs): + ... @abstractmethod - def can_filter(self, raw_query:str) -> bool: ... + def can_filter(self, raw_query: str) -> bool: + ... @abstractmethod - def apply(self, query:str, entries: List[Entry]) -> Tuple[str, Set[int]]: ... + def apply(self, query: str, entries: List[Entry]) -> Tuple[str, Set[int]]: + ... diff --git a/src/khoj/search_filter/date_filter.py b/src/khoj/search_filter/date_filter.py index 13067562..22bfda3f 100644 --- a/src/khoj/search_filter/date_filter.py +++ b/src/khoj/search_filter/date_filter.py @@ -26,21 +26,19 @@ class DateFilter(BaseFilter): # - dt:"2 years ago" date_regex = r"dt([:><=]{1,2})\"(.*?)\"" - - def __init__(self, entry_key='raw'): + def __init__(self, entry_key="raw"): self.entry_key = entry_key self.date_to_entry_ids = defaultdict(set) self.cache = LRU() - def load(self, entries, *args, **kwargs): with timer("Created date filter index", logger): for id, entry in enumerate(entries): # Extract dates from entry - for date_in_entry_string in re.findall(r'\d{4}-\d{2}-\d{2}', getattr(entry, self.entry_key)): + for date_in_entry_string in re.findall(r"\d{4}-\d{2}-\d{2}", getattr(entry, self.entry_key)): # Convert date string in entry to unix timestamp try: - date_in_entry = datetime.strptime(date_in_entry_string, '%Y-%m-%d').timestamp() + date_in_entry = datetime.strptime(date_in_entry_string, "%Y-%m-%d").timestamp() except ValueError: continue self.date_to_entry_ids[date_in_entry].add(id) @@ -49,7 +47,6 @@ class DateFilter(BaseFilter): "Check if query contains date filters" return self.extract_date_range(raw_query) is not None - def apply(self, query, entries): "Find entries containing any dates that fall within date range specified in query" # extract date range specified in date filter of query @@ -61,8 +58,8 @@ class DateFilter(BaseFilter): return query, set(range(len(entries))) # remove date range filter from query - query = re.sub(rf'\s+{self.date_regex}', ' ', query) - query = re.sub(r'\s{2,}', ' ', query).strip() # remove multiple spaces + query = re.sub(rf"\s+{self.date_regex}", " ", query) + query = re.sub(r"\s{2,}", " ", query).strip() # remove multiple spaces # return results from cache if exists cache_key = tuple(query_daterange) @@ -87,7 +84,6 @@ class DateFilter(BaseFilter): return query, entries_to_include - def extract_date_range(self, query): # find date range filter in query date_range_matches = re.findall(self.date_regex, query) @@ -98,7 +94,7 @@ class DateFilter(BaseFilter): # extract, parse natural dates ranges from date range filter passed in query # e.g today maps to (start_of_day, start_of_tomorrow) date_ranges_from_filter = [] - for (cmp, date_str) in date_range_matches: + for cmp, date_str in date_range_matches: if self.parse(date_str): dt_start, dt_end = self.parse(date_str) date_ranges_from_filter += [[cmp, (dt_start.timestamp(), dt_end.timestamp())]] @@ -111,15 +107,15 @@ class DateFilter(BaseFilter): effective_date_range = [0, inf] date_range_considering_comparator = [] for cmp, (dtrange_start, dtrange_end) in date_ranges_from_filter: - if cmp == '>': + if cmp == ">": date_range_considering_comparator += [[dtrange_end, inf]] - elif cmp == '>=': + elif cmp == ">=": date_range_considering_comparator += [[dtrange_start, inf]] - elif cmp == '<': + elif cmp == "<": date_range_considering_comparator += [[0, dtrange_start]] - elif cmp == '<=': + elif cmp == "<=": date_range_considering_comparator += [[0, dtrange_end]] - elif cmp == '=' or cmp == ':' or cmp == 
'==': + elif cmp == "=" or cmp == ":" or cmp == "==": date_range_considering_comparator += [[dtrange_start, dtrange_end]] # Combine above intervals (via AND/intersect) @@ -129,48 +125,48 @@ class DateFilter(BaseFilter): for date_range in date_range_considering_comparator: effective_date_range = [ max(effective_date_range[0], date_range[0]), - min(effective_date_range[1], date_range[1])] + min(effective_date_range[1], date_range[1]), + ] if effective_date_range == [0, inf] or effective_date_range[0] > effective_date_range[1]: return None else: return effective_date_range - def parse(self, date_str, relative_base=None): "Parse date string passed in date filter of query to datetime object" # clean date string to handle future date parsing by date parser - future_strings = ['later', 'from now', 'from today'] - prefer_dates_from = {True: 'future', False: 'past'}[any([True for fstr in future_strings if fstr in date_str])] - clean_date_str = re.sub('|'.join(future_strings), '', date_str) + future_strings = ["later", "from now", "from today"] + prefer_dates_from = {True: "future", False: "past"}[any([True for fstr in future_strings if fstr in date_str])] + clean_date_str = re.sub("|".join(future_strings), "", date_str) # parse date passed in query date filter parsed_date = dtparse.parse( clean_date_str, - settings= { - 'RELATIVE_BASE': relative_base or datetime.now(), - 'PREFER_DAY_OF_MONTH': 'first', - 'PREFER_DATES_FROM': prefer_dates_from - }) + settings={ + "RELATIVE_BASE": relative_base or datetime.now(), + "PREFER_DAY_OF_MONTH": "first", + "PREFER_DATES_FROM": prefer_dates_from, + }, + ) if parsed_date is None: return None return self.date_to_daterange(parsed_date, date_str) - def date_to_daterange(self, parsed_date, date_str): "Convert parsed date to date ranges at natural granularity (day, week, month or year)" start_of_day = parsed_date.replace(hour=0, minute=0, second=0, microsecond=0) - if 'year' in date_str: - return (datetime(parsed_date.year, 1, 1, 0, 0, 0), datetime(parsed_date.year+1, 1, 1, 0, 0, 0)) - if 'month' in date_str: + if "year" in date_str: + return (datetime(parsed_date.year, 1, 1, 0, 0, 0), datetime(parsed_date.year + 1, 1, 1, 0, 0, 0)) + if "month" in date_str: start_of_month = datetime(parsed_date.year, parsed_date.month, 1, 0, 0, 0) next_month = start_of_month + relativedelta(months=1) return (start_of_month, next_month) - if 'week' in date_str: + if "week" in date_str: # if week in date string, dateparser parses it to next week start # so today = end of this week start_of_week = start_of_day - timedelta(days=7) diff --git a/src/khoj/search_filter/file_filter.py b/src/khoj/search_filter/file_filter.py index 970da150..39bc10ba 100644 --- a/src/khoj/search_filter/file_filter.py +++ b/src/khoj/search_filter/file_filter.py @@ -16,7 +16,7 @@ logger = logging.getLogger(__name__) class FileFilter(BaseFilter): file_filter_regex = r'file:"(.+?)" ?' - def __init__(self, entry_key='file'): + def __init__(self, entry_key="file"): self.entry_key = entry_key self.file_to_entry_map = defaultdict(set) self.cache = LRU() @@ -40,13 +40,13 @@ class FileFilter(BaseFilter): # e.g. 
"file:notes.org" -> "file:.*notes.org" files_to_search = [] for file in sorted(raw_files_to_search): - if '/' not in file and '\\' not in file and '*' not in file: - files_to_search += [f'*{file}'] + if "/" not in file and "\\" not in file and "*" not in file: + files_to_search += [f"*{file}"] else: files_to_search += [file] # Return item from cache if exists - query = re.sub(self.file_filter_regex, '', query).strip() + query = re.sub(self.file_filter_regex, "", query).strip() cache_key = tuple(files_to_search) if cache_key in self.cache: logger.info(f"Return file filter results from cache") @@ -58,10 +58,15 @@ class FileFilter(BaseFilter): # Mark entries that contain any blocked_words for exclusion with timer("Mark entries satisfying filter", logger): - included_entry_indices = set.union(*[self.file_to_entry_map[entry_file] + included_entry_indices = set.union( + *[ + self.file_to_entry_map[entry_file] for entry_file in self.file_to_entry_map.keys() for search_file in files_to_search - if fnmatch.fnmatch(entry_file, search_file)], set()) + if fnmatch.fnmatch(entry_file, search_file) + ], + set(), + ) if not included_entry_indices: return query, {} diff --git a/src/khoj/search_filter/word_filter.py b/src/khoj/search_filter/word_filter.py index 16ed633d..9aab259a 100644 --- a/src/khoj/search_filter/word_filter.py +++ b/src/khoj/search_filter/word_filter.py @@ -17,26 +17,26 @@ class WordFilter(BaseFilter): required_regex = r'\+"([a-zA-Z0-9_-]+)" ?' blocked_regex = r'\-"([a-zA-Z0-9_-]+)" ?' - def __init__(self, entry_key='raw'): + def __init__(self, entry_key="raw"): self.entry_key = entry_key self.word_to_entry_index = defaultdict(set) self.cache = LRU() - def load(self, entries, *args, **kwargs): with timer("Created word filter index", logger): self.cache = {} # Clear cache on filter (re-)load - entry_splitter = r',|\.| |\]|\[\(|\)|\{|\}|\<|\>|\t|\n|\:|\;|\?|\!|\(|\)|\&|\^|\$|\@|\%|\+|\=|\/|\\|\||\~|\`|\"|\'' + entry_splitter = ( + r",|\.| |\]|\[\(|\)|\{|\}|\<|\>|\t|\n|\:|\;|\?|\!|\(|\)|\&|\^|\$|\@|\%|\+|\=|\/|\\|\||\~|\`|\"|\'" + ) # Create map of words to entries they exist in for entry_index, entry in enumerate(entries): for word in re.split(entry_splitter, getattr(entry, self.entry_key).lower()): - if word == '': + if word == "": continue self.word_to_entry_index[word].add(entry_index) return self.word_to_entry_index - def can_filter(self, raw_query): "Check if query contains word filters" required_words = re.findall(self.required_regex, raw_query) @@ -44,14 +44,13 @@ class WordFilter(BaseFilter): return len(required_words) != 0 or len(blocked_words) != 0 - def apply(self, query, entries): "Find entries containing required and not blocked words specified in query" # Separate natural query from required, blocked words filters with timer("Extract required, blocked filters from query", logger): required_words = set([word.lower() for word in re.findall(self.required_regex, query)]) blocked_words = set([word.lower() for word in re.findall(self.blocked_regex, query)]) - query = re.sub(self.blocked_regex, '', re.sub(self.required_regex, '', query)).strip() + query = re.sub(self.blocked_regex, "", re.sub(self.required_regex, "", query)).strip() if len(required_words) == 0 and len(blocked_words) == 0: return query, set(range(len(entries))) @@ -70,12 +69,16 @@ class WordFilter(BaseFilter): with timer("Mark entries satisfying filter", logger): entries_with_all_required_words = set(range(len(entries))) if len(required_words) > 0: - entries_with_all_required_words = 
set.intersection(*[self.word_to_entry_index.get(word, set()) for word in required_words]) + entries_with_all_required_words = set.intersection( + *[self.word_to_entry_index.get(word, set()) for word in required_words] + ) # mark entries that contain any blocked_words for exclusion entries_with_any_blocked_words = set() if len(blocked_words) > 0: - entries_with_any_blocked_words = set.union(*[self.word_to_entry_index.get(word, set()) for word in blocked_words]) + entries_with_any_blocked_words = set.union( + *[self.word_to_entry_index.get(word, set()) for word in blocked_words] + ) # get entries satisfying inclusion and exclusion filters included_entry_indices = entries_with_all_required_words - entries_with_any_blocked_words diff --git a/src/khoj/search_type/image_search.py b/src/khoj/search_type/image_search.py index d50b52ec..34264189 100644 --- a/src/khoj/search_type/image_search.py +++ b/src/khoj/search_type/image_search.py @@ -35,9 +35,10 @@ def initialize_model(search_config: ImageSearchConfig): # Load the CLIP model encoder = load_model( - model_dir = search_config.model_directory, - model_name = search_config.encoder, - model_type = search_config.encoder_type or SentenceTransformer) + model_dir=search_config.model_directory, + model_name=search_config.encoder, + model_type=search_config.encoder_type or SentenceTransformer, + ) return encoder @@ -46,12 +47,12 @@ def extract_entries(image_directories): image_names = [] for image_directory in image_directories: image_directory = resolve_absolute_path(image_directory, strict=True) - image_names.extend(list(image_directory.glob('*.jpg'))) - image_names.extend(list(image_directory.glob('*.jpeg'))) + image_names.extend(list(image_directory.glob("*.jpg"))) + image_names.extend(list(image_directory.glob("*.jpeg"))) if logger.level >= logging.INFO: - image_directory_names = ', '.join([str(image_directory) for image_directory in image_directories]) - logger.info(f'Found {len(image_names)} images in {image_directory_names}') + image_directory_names = ", ".join([str(image_directory) for image_directory in image_directories]) + logger.info(f"Found {len(image_names)} images in {image_directory_names}") return sorted(image_names) @@ -59,7 +60,9 @@ def compute_embeddings(image_names, encoder, embeddings_file, batch_size=50, use "Compute (and Save) Embeddings or Load Pre-Computed Embeddings" image_embeddings = compute_image_embeddings(image_names, encoder, embeddings_file, batch_size, regenerate) - image_metadata_embeddings = compute_metadata_embeddings(image_names, encoder, embeddings_file, batch_size, use_xmp_metadata, regenerate) + image_metadata_embeddings = compute_metadata_embeddings( + image_names, encoder, embeddings_file, batch_size, use_xmp_metadata, regenerate + ) return image_embeddings, image_metadata_embeddings @@ -74,15 +77,12 @@ def compute_image_embeddings(image_names, encoder, embeddings_file, batch_size=5 image_embeddings = [] for index in trange(0, len(image_names), batch_size): images = [] - for image_name in image_names[index:index+batch_size]: + for image_name in image_names[index : index + batch_size]: image = Image.open(image_name) # Resize images to max width of 640px for faster processing image.thumbnail((640, image.height)) images += [image] - image_embeddings += encoder.encode( - images, - convert_to_tensor=True, - batch_size=min(len(images), batch_size)) + image_embeddings += encoder.encode(images, convert_to_tensor=True, batch_size=min(len(images), batch_size)) # Create directory for embeddings file, if it doesn't 
exist embeddings_file.parent.mkdir(parents=True, exist_ok=True) @@ -94,7 +94,9 @@ def compute_image_embeddings(image_names, encoder, embeddings_file, batch_size=5 return image_embeddings -def compute_metadata_embeddings(image_names, encoder, embeddings_file, batch_size=50, use_xmp_metadata=False, regenerate=False, verbose=0): +def compute_metadata_embeddings( + image_names, encoder, embeddings_file, batch_size=50, use_xmp_metadata=False, regenerate=False, verbose=0 +): image_metadata_embeddings = None # Load pre-computed image metadata embedding file if exists @@ -106,14 +108,17 @@ def compute_metadata_embeddings(image_names, encoder, embeddings_file, batch_siz if use_xmp_metadata and image_metadata_embeddings is None: image_metadata_embeddings = [] for index in trange(0, len(image_names), batch_size): - image_metadata = [extract_metadata(image_name, verbose) for image_name in image_names[index:index+batch_size]] + image_metadata = [ + extract_metadata(image_name, verbose) for image_name in image_names[index : index + batch_size] + ] try: image_metadata_embeddings += encoder.encode( - image_metadata, - convert_to_tensor=True, - batch_size=min(len(image_metadata), batch_size)) + image_metadata, convert_to_tensor=True, batch_size=min(len(image_metadata), batch_size) + ) except RuntimeError as e: - logger.error(f"Error encoding metadata for images starting from\n\tindex: {index},\n\timages: {image_names[index:index+batch_size]}\nException: {e}") + logger.error( + f"Error encoding metadata for images starting from\n\tindex: {index},\n\timages: {image_names[index:index+batch_size]}\nException: {e}" + ) continue torch.save(image_metadata_embeddings, f"{embeddings_file}_metadata") logger.info(f"Saved computed metadata embeddings to {embeddings_file}_metadata") @@ -123,8 +128,10 @@ def compute_metadata_embeddings(image_names, encoder, embeddings_file, batch_siz def extract_metadata(image_name): image_xmp_metadata = Image.open(image_name).getxmp() - image_description = get_from_dict(image_xmp_metadata, 'xmpmeta', 'RDF', 'Description', 'description', 'Alt', 'li', 'text') - image_subjects = get_from_dict(image_xmp_metadata, 'xmpmeta', 'RDF', 'Description', 'subject', 'Bag', 'li') + image_description = get_from_dict( + image_xmp_metadata, "xmpmeta", "RDF", "Description", "description", "Alt", "li", "text" + ) + image_subjects = get_from_dict(image_xmp_metadata, "xmpmeta", "RDF", "Description", "subject", "Bag", "li") image_metadata_subjects = set([subject.split(":")[1] for subject in image_subjects if ":" in subject]) image_processed_metadata = image_description @@ -141,7 +148,7 @@ def query(raw_query, count, model: ImageSearchModel): if raw_query.startswith("file:") and pathlib.Path(raw_query[5:]).is_file(): query_imagepath = resolve_absolute_path(pathlib.Path(raw_query[5:]), strict=True) query = copy.deepcopy(Image.open(query_imagepath)) - query.thumbnail((640, query.height)) # scale down image for faster processing + query.thumbnail((640, query.height)) # scale down image for faster processing logger.info(f"Find Images by Image: {query_imagepath}") else: # Truncate words in query to stay below max_tokens supported by ML model @@ -155,36 +162,42 @@ def query(raw_query, count, model: ImageSearchModel): # Compute top_k ranked images based on cosine-similarity b/w query and all image embeddings. 
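A minimal sketch, outside the patch itself, of what the sentence-transformers util.semantic_search call in this hunk returns; the tensors are random stand-ins (assumed shapes, not Khoj's real CLIP embeddings) and serve only to illustrate the hit structure that the surrounding code reshapes into image_hits.

    from sentence_transformers import util
    import torch

    query_embedding = torch.rand(1, 512)     # stand-in for an encoded text query
    image_embeddings = torch.rand(10, 512)   # stand-in for the indexed image embeddings
    hits = util.semantic_search(query_embedding, image_embeddings, top_k=3)[0]
    # hits is a list of dicts such as {"corpus_id": 4, "score": 0.83}; the code in this
    # hunk remaps them into image_hits keyed by corpus_id with image_score/score fields.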
with timer("Search Time", logger): - image_hits = {result['corpus_id']: {'image_score': result['score'], 'score': result['score']} - for result - in util.semantic_search(query_embedding, model.image_embeddings, top_k=count)[0]} + image_hits = { + result["corpus_id"]: {"image_score": result["score"], "score": result["score"]} + for result in util.semantic_search(query_embedding, model.image_embeddings, top_k=count)[0] + } # Compute top_k ranked images based on cosine-similarity b/w query and all image metadata embeddings. if model.image_metadata_embeddings: with timer("Metadata Search Time", logger): - metadata_hits = {result['corpus_id']: result['score'] - for result - in util.semantic_search(query_embedding, model.image_metadata_embeddings, top_k=count)[0]} + metadata_hits = { + result["corpus_id"]: result["score"] + for result in util.semantic_search(query_embedding, model.image_metadata_embeddings, top_k=count)[0] + } # Sum metadata, image scores of the highest ranked images for corpus_id, score in metadata_hits.items(): scaling_factor = 0.33 - if 'corpus_id' in image_hits: - image_hits[corpus_id].update({ - 'metadata_score': score, - 'score': image_hits[corpus_id].get('score', 0) + scaling_factor*score, - }) + if "corpus_id" in image_hits: + image_hits[corpus_id].update( + { + "metadata_score": score, + "score": image_hits[corpus_id].get("score", 0) + scaling_factor * score, + } + ) else: - image_hits[corpus_id] = {'metadata_score': score, 'score': scaling_factor*score} + image_hits[corpus_id] = {"metadata_score": score, "score": scaling_factor * score} # Reformat results in original form from sentence transformer semantic_search() hits = [ { - 'corpus_id': corpus_id, - 'score': scores['score'], - 'image_score': scores.get('image_score', 0), - 'metadata_score': scores.get('metadata_score', 0), - } for corpus_id, scores in image_hits.items()] + "corpus_id": corpus_id, + "score": scores["score"], + "image_score": scores.get("image_score", 0), + "metadata_score": scores.get("metadata_score", 0), + } + for corpus_id, scores in image_hits.items() + ] # Sort the images based on their combined metadata, image scores return sorted(hits, key=lambda hit: hit["score"], reverse=True) @@ -194,7 +207,7 @@ def collate_results(hits, image_names, output_directory, image_files_url, count= results: List[SearchResponse] = [] for index, hit in enumerate(hits[:count]): - source_path = image_names[hit['corpus_id']] + source_path = image_names[hit["corpus_id"]] target_image_name = f"{index}{source_path.suffix}" target_path = resolve_absolute_path(f"{output_directory}/{target_image_name}") @@ -207,17 +220,18 @@ def collate_results(hits, image_names, output_directory, image_files_url, count= shutil.copy(source_path, target_path) # Add the image metadata to the results - results += [SearchResponse.parse_obj( - { - "entry": f'{image_files_url}/{target_image_name}', - "score": f"{hit['score']:.9f}", - "additional": + results += [ + SearchResponse.parse_obj( { - "image_score": f"{hit['image_score']:.9f}", - "metadata_score": f"{hit['metadata_score']:.9f}", + "entry": f"{image_files_url}/{target_image_name}", + "score": f"{hit['score']:.9f}", + "additional": { + "image_score": f"{hit['image_score']:.9f}", + "metadata_score": f"{hit['metadata_score']:.9f}", + }, } - } - )] + ) + ] return results @@ -248,9 +262,7 @@ def setup(config: ImageContentConfig, search_config: ImageSearchConfig, regenera embeddings_file, batch_size=config.batch_size, regenerate=regenerate, - use_xmp_metadata=config.use_xmp_metadata) + 
use_xmp_metadata=config.use_xmp_metadata, + ) - return ImageSearchModel(all_image_files, - image_embeddings, - image_metadata_embeddings, - encoder) + return ImageSearchModel(all_image_files, image_embeddings, image_metadata_embeddings, encoder) diff --git a/src/khoj/search_type/text_search.py b/src/khoj/search_type/text_search.py index 59022e13..882bcb64 100644 --- a/src/khoj/search_type/text_search.py +++ b/src/khoj/search_type/text_search.py @@ -38,17 +38,19 @@ def initialize_model(search_config: TextSearchConfig): # The bi-encoder encodes all entries to use for semantic search bi_encoder = load_model( - model_dir = search_config.model_directory, - model_name = search_config.encoder, - model_type = search_config.encoder_type or SentenceTransformer, - device=f'{state.device}') + model_dir=search_config.model_directory, + model_name=search_config.encoder, + model_type=search_config.encoder_type or SentenceTransformer, + device=f"{state.device}", + ) # The cross-encoder re-ranks the results to improve quality cross_encoder = load_model( - model_dir = search_config.model_directory, - model_name = search_config.cross_encoder, - model_type = CrossEncoder, - device=f'{state.device}') + model_dir=search_config.model_directory, + model_name=search_config.cross_encoder, + model_type=CrossEncoder, + device=f"{state.device}", + ) return bi_encoder, cross_encoder, top_k @@ -58,7 +60,9 @@ def extract_entries(jsonl_file) -> List[Entry]: return list(map(Entry.from_dict, load_jsonl(jsonl_file))) -def compute_embeddings(entries_with_ids: List[Tuple[int, Entry]], bi_encoder: BaseEncoder, embeddings_file: Path, regenerate=False): +def compute_embeddings( + entries_with_ids: List[Tuple[int, Entry]], bi_encoder: BaseEncoder, embeddings_file: Path, regenerate=False +): "Compute (and Save) Embeddings or Load Pre-Computed Embeddings" new_entries = [] # Load pre-computed embeddings from file if exists and update them if required @@ -69,17 +73,23 @@ def compute_embeddings(entries_with_ids: List[Tuple[int, Entry]], bi_encoder: Ba # Encode any new entries in the corpus and update corpus embeddings new_entries = [entry.compiled for id, entry in entries_with_ids if id == -1] if new_entries: - new_embeddings = bi_encoder.encode(new_entries, convert_to_tensor=True, device=state.device, show_progress_bar=True) + new_embeddings = bi_encoder.encode( + new_entries, convert_to_tensor=True, device=state.device, show_progress_bar=True + ) existing_entry_ids = [id for id, _ in entries_with_ids if id != -1] if existing_entry_ids: - existing_embeddings = torch.index_select(corpus_embeddings, 0, torch.tensor(existing_entry_ids, device=state.device)) + existing_embeddings = torch.index_select( + corpus_embeddings, 0, torch.tensor(existing_entry_ids, device=state.device) + ) else: existing_embeddings = torch.tensor([], device=state.device) corpus_embeddings = torch.cat([existing_embeddings, new_embeddings], dim=0) # Else compute the corpus embeddings from scratch else: new_entries = [entry.compiled for _, entry in entries_with_ids] - corpus_embeddings = bi_encoder.encode(new_entries, convert_to_tensor=True, device=state.device, show_progress_bar=True) + corpus_embeddings = bi_encoder.encode( + new_entries, convert_to_tensor=True, device=state.device, show_progress_bar=True + ) # Save regenerated or updated embeddings to file if new_entries: @@ -112,7 +122,9 @@ def query(raw_query: str, model: TextSearchModel, rank_results: bool = False) -> # Find relevant entries for the query with timer("Search Time", logger, state.device): - hits 
= util.semantic_search(question_embedding, corpus_embeddings, top_k=model.top_k, score_function=util.dot_score)[0] + hits = util.semantic_search( + question_embedding, corpus_embeddings, top_k=model.top_k, score_function=util.dot_score + )[0] # Score all retrieved entries using the cross-encoder if rank_results: @@ -128,26 +140,33 @@ def query(raw_query: str, model: TextSearchModel, rank_results: bool = False) -> def collate_results(hits, entries: List[Entry], count=5) -> List[SearchResponse]: - return [SearchResponse.parse_obj( - { - "entry": entries[hit['corpus_id']].raw, - "score": f"{hit['cross-score'] if 'cross-score' in hit else hit['score']:.3f}", - "additional": { - "file": entries[hit['corpus_id']].file, - "compiled": entries[hit['corpus_id']].compiled + return [ + SearchResponse.parse_obj( + { + "entry": entries[hit["corpus_id"]].raw, + "score": f"{hit['cross-score'] if 'cross-score' in hit else hit['score']:.3f}", + "additional": {"file": entries[hit["corpus_id"]].file, "compiled": entries[hit["corpus_id"]].compiled}, } - }) - for hit - in hits[0:count]] + ) + for hit in hits[0:count] + ] -def setup(text_to_jsonl: Type[TextToJsonl], config: TextContentConfig, search_config: TextSearchConfig, regenerate: bool, filters: List[BaseFilter] = []) -> TextSearchModel: +def setup( + text_to_jsonl: Type[TextToJsonl], + config: TextContentConfig, + search_config: TextSearchConfig, + regenerate: bool, + filters: List[BaseFilter] = [], +) -> TextSearchModel: # Initialize Model bi_encoder, cross_encoder, top_k = initialize_model(search_config) # Map notes in text files to (compressed) JSONL formatted file config.compressed_jsonl = resolve_absolute_path(config.compressed_jsonl) - previous_entries = extract_entries(config.compressed_jsonl) if config.compressed_jsonl.exists() and not regenerate else None + previous_entries = ( + extract_entries(config.compressed_jsonl) if config.compressed_jsonl.exists() and not regenerate else None + ) entries_with_indices = text_to_jsonl(config).process(previous_entries) # Extract Updated Entries @@ -158,7 +177,9 @@ def setup(text_to_jsonl: Type[TextToJsonl], config: TextContentConfig, search_co # Compute or Load Embeddings config.embeddings_file = resolve_absolute_path(config.embeddings_file) - corpus_embeddings = compute_embeddings(entries_with_indices, bi_encoder, config.embeddings_file, regenerate=regenerate) + corpus_embeddings = compute_embeddings( + entries_with_indices, bi_encoder, config.embeddings_file, regenerate=regenerate + ) for filter in filters: filter.load(entries, regenerate=regenerate) @@ -166,8 +187,10 @@ def setup(text_to_jsonl: Type[TextToJsonl], config: TextContentConfig, search_co return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, filters, top_k) -def apply_filters(query: str, entries: List[Entry], corpus_embeddings: torch.Tensor, filters: List[BaseFilter]) -> Tuple[str, List[Entry], torch.Tensor]: - '''Filter query, entries and embeddings before semantic search''' +def apply_filters( + query: str, entries: List[Entry], corpus_embeddings: torch.Tensor, filters: List[BaseFilter] +) -> Tuple[str, List[Entry], torch.Tensor]: + """Filter query, entries and embeddings before semantic search""" with timer("Total Filter Time", logger, state.device): included_entry_indices = set(range(len(entries))) @@ -178,45 +201,50 @@ def apply_filters(query: str, entries: List[Entry], corpus_embeddings: torch.Ten # Get entries (and associated embeddings) satisfying all filters if not included_entry_indices: - return '', [], 
torch.tensor([], device=state.device) + return "", [], torch.tensor([], device=state.device) else: entries = [entries[id] for id in included_entry_indices] - corpus_embeddings = torch.index_select(corpus_embeddings, 0, torch.tensor(list(included_entry_indices), device=state.device)) + corpus_embeddings = torch.index_select( + corpus_embeddings, 0, torch.tensor(list(included_entry_indices), device=state.device) + ) return query, entries, corpus_embeddings def cross_encoder_score(cross_encoder: CrossEncoder, query: str, entries: List[Entry], hits: List[dict]) -> List[dict]: - '''Score all retrieved entries using the cross-encoder''' + """Score all retrieved entries using the cross-encoder""" with timer("Cross-Encoder Predict Time", logger, state.device): - cross_inp = [[query, entries[hit['corpus_id']].compiled] for hit in hits] + cross_inp = [[query, entries[hit["corpus_id"]].compiled] for hit in hits] cross_scores = cross_encoder.predict(cross_inp) # Store cross-encoder scores in results dictionary for ranking for idx in range(len(cross_scores)): - hits[idx]['cross-score'] = cross_scores[idx] + hits[idx]["cross-score"] = cross_scores[idx] return hits def sort_results(rank_results: bool, hits: List[dict]) -> List[dict]: - '''Order results by cross-encoder score followed by bi-encoder score''' + """Order results by cross-encoder score followed by bi-encoder score""" with timer("Rank Time", logger, state.device): - hits.sort(key=lambda x: x['score'], reverse=True) # sort by bi-encoder score + hits.sort(key=lambda x: x["score"], reverse=True) # sort by bi-encoder score if rank_results: - hits.sort(key=lambda x: x['cross-score'], reverse=True) # sort by cross-encoder score + hits.sort(key=lambda x: x["cross-score"], reverse=True) # sort by cross-encoder score return hits def deduplicate_results(entries: List[Entry], hits: List[dict]) -> List[dict]: - '''Deduplicate entries by raw entry text before showing to users + """Deduplicate entries by raw entry text before showing to users Compiled entries are split by max tokens supported by ML models. - This can result in duplicate hits, entries shown to user.''' + This can result in duplicate hits, entries shown to user.""" with timer("Deduplication Time", logger, state.device): seen, original_hits_count = set(), len(hits) - hits = [hit for hit in hits - if entries[hit['corpus_id']].raw not in seen and not seen.add(entries[hit['corpus_id']].raw)] # type: ignore[func-returns-value] + hits = [ + hit + for hit in hits + if entries[hit["corpus_id"]].raw not in seen and not seen.add(entries[hit["corpus_id"]].raw) # type: ignore[func-returns-value] + ] duplicate_hits = original_hits_count - len(hits) logger.debug(f"Removed {duplicate_hits} duplicates") diff --git a/src/khoj/utils/cli.py b/src/khoj/utils/cli.py index ab2e7af9..37d65458 100644 --- a/src/khoj/utils/cli.py +++ b/src/khoj/utils/cli.py @@ -10,21 +10,36 @@ from khoj.utils.yaml import parse_config_from_file def cli(args=None): # Setup Argument Parser for the Commandline Interface - parser = argparse.ArgumentParser(description="Start Khoj; A Natural Language Search Engine for your personal Notes, Transactions and Photos") - parser.add_argument('--config-file', '-c', default='~/.khoj/khoj.yml', type=pathlib.Path, help="YAML file to configure Khoj") - parser.add_argument('--no-gui', action='store_true', default=False, help="Do not show native desktop GUI. Default: false") - parser.add_argument('--regenerate', action='store_true', default=False, help="Regenerate model embeddings from source files. 
Default: false") - parser.add_argument('--verbose', '-v', action='count', default=0, help="Show verbose conversion logs. Default: 0") - parser.add_argument('--host', type=str, default='127.0.0.1', help="Host address of the server. Default: 127.0.0.1") - parser.add_argument('--port', '-p', type=int, default=8000, help="Port of the server. Default: 8000") - parser.add_argument('--socket', type=pathlib.Path, help="Path to UNIX socket for server. Use to run server behind reverse proxy. Default: /tmp/uvicorn.sock") - parser.add_argument('--version', '-V', action='store_true', help="Print the installed Khoj version and exit") + parser = argparse.ArgumentParser( + description="Start Khoj; A Natural Language Search Engine for your personal Notes, Transactions and Photos" + ) + parser.add_argument( + "--config-file", "-c", default="~/.khoj/khoj.yml", type=pathlib.Path, help="YAML file to configure Khoj" + ) + parser.add_argument( + "--no-gui", action="store_true", default=False, help="Do not show native desktop GUI. Default: false" + ) + parser.add_argument( + "--regenerate", + action="store_true", + default=False, + help="Regenerate model embeddings from source files. Default: false", + ) + parser.add_argument("--verbose", "-v", action="count", default=0, help="Show verbose conversion logs. Default: 0") + parser.add_argument("--host", type=str, default="127.0.0.1", help="Host address of the server. Default: 127.0.0.1") + parser.add_argument("--port", "-p", type=int, default=8000, help="Port of the server. Default: 8000") + parser.add_argument( + "--socket", + type=pathlib.Path, + help="Path to UNIX socket for server. Use to run server behind reverse proxy. Default: /tmp/uvicorn.sock", + ) + parser.add_argument("--version", "-V", action="store_true", help="Print the installed Khoj version and exit") args = parser.parse_args(args) if args.version: # Show version of khoj installed and exit - print(version('khoj-assistant')) + print(version("khoj-assistant")) exit(0) # Normalize config_file path to absolute path diff --git a/src/khoj/utils/config.py b/src/khoj/utils/config.py index 22193f9e..1e23e077 100644 --- a/src/khoj/utils/config.py +++ b/src/khoj/utils/config.py @@ -28,8 +28,16 @@ class ProcessorType(str, Enum): Conversation = "conversation" -class TextSearchModel(): - def __init__(self, entries: List[Entry], corpus_embeddings: torch.Tensor, bi_encoder: BaseEncoder, cross_encoder: CrossEncoder, filters: List[BaseFilter], top_k): +class TextSearchModel: + def __init__( + self, + entries: List[Entry], + corpus_embeddings: torch.Tensor, + bi_encoder: BaseEncoder, + cross_encoder: CrossEncoder, + filters: List[BaseFilter], + top_k, + ): self.entries = entries self.corpus_embeddings = corpus_embeddings self.bi_encoder = bi_encoder @@ -38,7 +46,7 @@ class TextSearchModel(): self.top_k = top_k -class ImageSearchModel(): +class ImageSearchModel: def __init__(self, image_names, image_embeddings, image_metadata_embeddings, image_encoder: BaseEncoder): self.image_encoder = image_encoder self.image_names = image_names @@ -48,7 +56,7 @@ class ImageSearchModel(): @dataclass -class SearchModels(): +class SearchModels: orgmode_search: TextSearchModel = None ledger_search: TextSearchModel = None music_search: TextSearchModel = None @@ -56,15 +64,15 @@ class SearchModels(): image_search: ImageSearchModel = None -class ConversationProcessorConfigModel(): +class ConversationProcessorConfigModel: def __init__(self, processor_config: ConversationProcessorConfig): self.openai_api_key = processor_config.openai_api_key 
self.model = processor_config.model self.conversation_logfile = Path(processor_config.conversation_logfile) - self.chat_session = '' + self.chat_session = "" self.meta_log: dict = {} @dataclass -class ProcessorConfigModel(): +class ProcessorConfigModel: conversation: ConversationProcessorConfigModel = None diff --git a/src/khoj/utils/constants.py b/src/khoj/utils/constants.py index 86686e93..40e01c5d 100644 --- a/src/khoj/utils/constants.py +++ b/src/khoj/utils/constants.py @@ -1,65 +1,62 @@ from pathlib import Path app_root_directory = Path(__file__).parent.parent.parent -web_directory = app_root_directory / 'khoj/interface/web/' -empty_escape_sequences = '\n|\r|\t| ' +web_directory = app_root_directory / "khoj/interface/web/" +empty_escape_sequences = "\n|\r|\t| " # default app config to use default_config = { - 'content-type': { - 'org': { - 'input-files': None, - 'input-filter': None, - 'compressed-jsonl': '~/.khoj/content/org/org.jsonl.gz', - 'embeddings-file': '~/.khoj/content/org/org_embeddings.pt', - 'index_heading_entries': False + "content-type": { + "org": { + "input-files": None, + "input-filter": None, + "compressed-jsonl": "~/.khoj/content/org/org.jsonl.gz", + "embeddings-file": "~/.khoj/content/org/org_embeddings.pt", + "index_heading_entries": False, }, - 'markdown': { - 'input-files': None, - 'input-filter': None, - 'compressed-jsonl': '~/.khoj/content/markdown/markdown.jsonl.gz', - 'embeddings-file': '~/.khoj/content/markdown/markdown_embeddings.pt' + "markdown": { + "input-files": None, + "input-filter": None, + "compressed-jsonl": "~/.khoj/content/markdown/markdown.jsonl.gz", + "embeddings-file": "~/.khoj/content/markdown/markdown_embeddings.pt", }, - 'ledger': { - 'input-files': None, - 'input-filter': None, - 'compressed-jsonl': '~/.khoj/content/ledger/ledger.jsonl.gz', - 'embeddings-file': '~/.khoj/content/ledger/ledger_embeddings.pt' + "ledger": { + "input-files": None, + "input-filter": None, + "compressed-jsonl": "~/.khoj/content/ledger/ledger.jsonl.gz", + "embeddings-file": "~/.khoj/content/ledger/ledger_embeddings.pt", }, - 'image': { - 'input-directories': None, - 'input-filter': None, - 'embeddings-file': '~/.khoj/content/image/image_embeddings.pt', - 'batch-size': 50, - 'use-xmp-metadata': False + "image": { + "input-directories": None, + "input-filter": None, + "embeddings-file": "~/.khoj/content/image/image_embeddings.pt", + "batch-size": 50, + "use-xmp-metadata": False, }, - 'music': { - 'input-files': None, - 'input-filter': None, - 'compressed-jsonl': '~/.khoj/content/music/music.jsonl.gz', - 'embeddings-file': '~/.khoj/content/music/music_embeddings.pt' + "music": { + "input-files": None, + "input-filter": None, + "compressed-jsonl": "~/.khoj/content/music/music.jsonl.gz", + "embeddings-file": "~/.khoj/content/music/music_embeddings.pt", + }, + }, + "search-type": { + "symmetric": { + "encoder": "sentence-transformers/all-MiniLM-L6-v2", + "cross-encoder": "cross-encoder/ms-marco-MiniLM-L-6-v2", + "model_directory": "~/.khoj/search/symmetric/", + }, + "asymmetric": { + "encoder": "sentence-transformers/multi-qa-MiniLM-L6-cos-v1", + "cross-encoder": "cross-encoder/ms-marco-MiniLM-L-6-v2", + "model_directory": "~/.khoj/search/asymmetric/", + }, + "image": {"encoder": "sentence-transformers/clip-ViT-B-32", "model_directory": "~/.khoj/search/image/"}, + }, + "processor": { + "conversation": { + "openai-api-key": None, + "conversation-logfile": "~/.khoj/processor/conversation/conversation_logs.json", } }, - 'search-type': { - 'symmetric': { - 'encoder': 
'sentence-transformers/all-MiniLM-L6-v2', - 'cross-encoder': 'cross-encoder/ms-marco-MiniLM-L-6-v2', - 'model_directory': '~/.khoj/search/symmetric/' - }, - 'asymmetric': { - 'encoder': 'sentence-transformers/multi-qa-MiniLM-L6-cos-v1', - 'cross-encoder': 'cross-encoder/ms-marco-MiniLM-L-6-v2', - 'model_directory': '~/.khoj/search/asymmetric/' - }, - 'image': { - 'encoder': 'sentence-transformers/clip-ViT-B-32', - 'model_directory': '~/.khoj/search/image/' - } - }, - 'processor': { - 'conversation': { - 'openai-api-key': None, - 'conversation-logfile': '~/.khoj/processor/conversation/conversation_logs.json' - } - } -} \ No newline at end of file +} diff --git a/src/khoj/utils/helpers.py b/src/khoj/utils/helpers.py index 8e3e2bd5..d2f87126 100644 --- a/src/khoj/utils/helpers.py +++ b/src/khoj/utils/helpers.py @@ -13,16 +13,17 @@ from typing import Optional, Union, TYPE_CHECKING if TYPE_CHECKING: # External Packages from sentence_transformers import CrossEncoder + # Internal Packages from khoj.utils.models import BaseEncoder def is_none_or_empty(item): - return item == None or (hasattr(item, '__iter__') and len(item) == 0) or item == '' + return item == None or (hasattr(item, "__iter__") and len(item) == 0) or item == "" def to_snake_case_from_dash(item: str): - return item.replace('_', '-') + return item.replace("_", "-") def get_absolute_path(filepath: Union[str, Path]) -> str: @@ -34,11 +35,11 @@ def resolve_absolute_path(filepath: Union[str, Optional[Path]], strict=False) -> def get_from_dict(dictionary, *args): - '''null-aware get from a nested dictionary - Returns: dictionary[args[0]][args[1]]... or None if any keys missing''' + """null-aware get from a nested dictionary + Returns: dictionary[args[0]][args[1]]... or None if any keys missing""" current = dictionary for arg in args: - if not hasattr(current, '__iter__') or not arg in current: + if not hasattr(current, "__iter__") or not arg in current: return None current = current[arg] return current @@ -54,7 +55,7 @@ def merge_dicts(priority_dict: dict, default_dict: dict): return merged_dict -def load_model(model_name: str, model_type, model_dir=None, device:str=None) -> Union[BaseEncoder, CrossEncoder]: +def load_model(model_name: str, model_type, model_dir=None, device: str = None) -> Union[BaseEncoder, CrossEncoder]: "Load model from disk or huggingface" # Construct model path model_path = join(model_dir, model_name.replace("/", "_")) if model_dir is not None else None @@ -74,17 +75,18 @@ def load_model(model_name: str, model_type, model_dir=None, device:str=None) -> def is_pyinstaller_app(): "Returns true if the app is running from Native GUI created by PyInstaller" - return getattr(sys, 'frozen', False) and hasattr(sys, '_MEIPASS') + return getattr(sys, "frozen", False) and hasattr(sys, "_MEIPASS") def get_class_by_name(name: str) -> object: "Returns the class object from name string" - module_name, class_name = name.rsplit('.', 1) + module_name, class_name = name.rsplit(".", 1) return getattr(import_module(module_name), class_name) class timer: - '''Context manager to log time taken for a block of code to run''' + """Context manager to log time taken for a block of code to run""" + def __init__(self, message: str, logger: logging.Logger, device: torch.device = None): self.message = message self.logger = logger @@ -116,4 +118,4 @@ class LRU(OrderedDict): super().__setitem__(key, value) if len(self) > self.capacity: oldest = next(iter(self)) - del self[oldest] \ No newline at end of file + del self[oldest] diff --git 
a/src/khoj/utils/jsonl.py b/src/khoj/utils/jsonl.py index 41923186..a98910fb 100644 --- a/src/khoj/utils/jsonl.py +++ b/src/khoj/utils/jsonl.py @@ -19,9 +19,9 @@ def load_jsonl(input_path): # Open JSONL file if input_path.suffix == ".gz": - jsonl_file = gzip.open(get_absolute_path(input_path), 'rt', encoding='utf-8') + jsonl_file = gzip.open(get_absolute_path(input_path), "rt", encoding="utf-8") elif input_path.suffix == ".jsonl": - jsonl_file = open(get_absolute_path(input_path), 'r', encoding='utf-8') + jsonl_file = open(get_absolute_path(input_path), "r", encoding="utf-8") # Read JSONL file for line in jsonl_file: @@ -31,7 +31,7 @@ def load_jsonl(input_path): jsonl_file.close() # Log JSONL entries loaded - logger.info(f'Loaded {len(data)} records from {input_path}') + logger.info(f"Loaded {len(data)} records from {input_path}") return data @@ -41,17 +41,17 @@ def dump_jsonl(jsonl_data, output_path): # Create output directory, if it doesn't exist output_path.parent.mkdir(parents=True, exist_ok=True) - with open(output_path, 'w', encoding='utf-8') as f: + with open(output_path, "w", encoding="utf-8") as f: f.write(jsonl_data) - logger.info(f'Wrote jsonl data to {output_path}') + logger.info(f"Wrote jsonl data to {output_path}") def compress_jsonl_data(jsonl_data, output_path): # Create output directory, if it doesn't exist output_path.parent.mkdir(parents=True, exist_ok=True) - with gzip.open(output_path, 'wt', encoding='utf-8') as gzip_file: + with gzip.open(output_path, "wt", encoding="utf-8") as gzip_file: gzip_file.write(jsonl_data) - logger.info(f'Wrote jsonl data to gzip compressed jsonl at {output_path}') \ No newline at end of file + logger.info(f"Wrote jsonl data to gzip compressed jsonl at {output_path}") diff --git a/src/khoj/utils/models.py b/src/khoj/utils/models.py index 77e620fd..12935818 100644 --- a/src/khoj/utils/models.py +++ b/src/khoj/utils/models.py @@ -13,17 +13,25 @@ from khoj.utils.state import processor_config, config_file class BaseEncoder(ABC): @abstractmethod - def __init__(self, model_name: str, device: torch.device=None, **kwargs): ... + def __init__(self, model_name: str, device: torch.device = None, **kwargs): + ... @abstractmethod - def encode(self, entries: List[str], device:torch.device=None, **kwargs) -> torch.Tensor: ... + def encode(self, entries: List[str], device: torch.device = None, **kwargs) -> torch.Tensor: + ... 
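A minimal sketch, not present in the patch, of a concrete encoder satisfying the BaseEncoder contract declared just above; it assumes the module's existing torch and typing imports are in scope, and the length-based pseudo-embedding is a placeholder used only to show the expected __init__ and encode signatures.

    class DummyEncoder(BaseEncoder):
        def __init__(self, model_name: str, device: torch.device = None, **kwargs):
            self.model_name = model_name
            self.device = device

        def encode(self, entries: List[str], device: torch.device = None, **kwargs) -> torch.Tensor:
            # One 1-dimensional pseudo-embedding per entry, derived from its character count
            return torch.tensor([[float(len(entry))] for entry in entries], device=device or self.device)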
class OpenAI(BaseEncoder): def __init__(self, model_name, device=None): self.model_name = model_name - if not processor_config or not processor_config.conversation or not processor_config.conversation.openai_api_key: - raise Exception(f"Set OpenAI API key under processor-config > conversation > openai-api-key in config file: {config_file}") + if ( + not processor_config + or not processor_config.conversation + or not processor_config.conversation.openai_api_key + ): + raise Exception( + f"Set OpenAI API key under processor-config > conversation > openai-api-key in config file: {config_file}" + ) openai.api_key = processor_config.conversation.openai_api_key self.embedding_dimensions = None @@ -32,7 +40,7 @@ class OpenAI(BaseEncoder): for index in trange(0, len(entries)): # OpenAI models create better embeddings for entries without newlines - processed_entry = entries[index].replace('\n', ' ') + processed_entry = entries[index].replace("\n", " ") try: response = openai.Embedding.create(input=processed_entry, model=self.model_name) @@ -41,10 +49,12 @@ class OpenAI(BaseEncoder): # Else default to embedding dimensions of the text-embedding-ada-002 model self.embedding_dimensions = len(response.data[0].embedding) if not self.embedding_dimensions else 1536 except Exception as e: - print(f"Failed to encode entry {index} of length: {len(entries[index])}\n\n{entries[index][:1000]}...\n\n{e}") + print( + f"Failed to encode entry {index} of length: {len(entries[index])}\n\n{entries[index][:1000]}...\n\n{e}" + ) # Use zero embedding vector for entries with failed embeddings # This ensures entry embeddings match the order of the source entries # And they have minimal similarity to other entries (as zero vectors are always orthogonal to other vector) embedding_tensors += [torch.zeros(self.embedding_dimensions, device=device)] - return torch.stack(embedding_tensors) \ No newline at end of file + return torch.stack(embedding_tensors) diff --git a/src/khoj/utils/rawconfig.py b/src/khoj/utils/rawconfig.py index 82715617..4fbe6543 100644 --- a/src/khoj/utils/rawconfig.py +++ b/src/khoj/utils/rawconfig.py @@ -9,11 +9,13 @@ from pydantic import BaseModel, validator # Internal Packages from khoj.utils.helpers import to_snake_case_from_dash, is_none_or_empty + class ConfigBase(BaseModel): class Config: alias_generator = to_snake_case_from_dash allow_population_by_field_name = True + class TextContentConfig(ConfigBase): input_files: Optional[List[Path]] input_filter: Optional[List[str]] @@ -21,12 +23,15 @@ class TextContentConfig(ConfigBase): embeddings_file: Path index_heading_entries: Optional[bool] = False - @validator('input_filter') + @validator("input_filter") def input_filter_or_files_required(cls, input_filter, values, **kwargs): - if is_none_or_empty(input_filter) and ('input_files' not in values or values["input_files"] is None): - raise ValueError("Either input_filter or input_files required in all content-type. section of Khoj config file") + if is_none_or_empty(input_filter) and ("input_files" not in values or values["input_files"] is None): + raise ValueError( + "Either input_filter or input_files required in all content-type. 
section of Khoj config file" + ) return input_filter + class ImageContentConfig(ConfigBase): input_directories: Optional[List[Path]] input_filter: Optional[List[str]] @@ -34,12 +39,17 @@ class ImageContentConfig(ConfigBase): use_xmp_metadata: bool batch_size: int - @validator('input_filter') + @validator("input_filter") def input_filter_or_directories_required(cls, input_filter, values, **kwargs): - if is_none_or_empty(input_filter) and ('input_directories' not in values or values["input_directories"] is None): - raise ValueError("Either input_filter or input_directories required in all content-type.image section of Khoj config file") + if is_none_or_empty(input_filter) and ( + "input_directories" not in values or values["input_directories"] is None + ): + raise ValueError( + "Either input_filter or input_directories required in all content-type.image section of Khoj config file" + ) return input_filter + class ContentConfig(ConfigBase): org: Optional[TextContentConfig] ledger: Optional[TextContentConfig] @@ -47,41 +57,49 @@ class ContentConfig(ConfigBase): music: Optional[TextContentConfig] markdown: Optional[TextContentConfig] + class TextSearchConfig(ConfigBase): encoder: str cross_encoder: str encoder_type: Optional[str] model_directory: Optional[Path] + class ImageSearchConfig(ConfigBase): encoder: str encoder_type: Optional[str] model_directory: Optional[Path] + class SearchConfig(ConfigBase): asymmetric: Optional[TextSearchConfig] symmetric: Optional[TextSearchConfig] image: Optional[ImageSearchConfig] + class ConversationProcessorConfig(ConfigBase): openai_api_key: str conversation_logfile: Path model: Optional[str] = "text-davinci-003" + class ProcessorConfig(ConfigBase): conversation: Optional[ConversationProcessorConfig] + class FullConfig(ConfigBase): content_type: Optional[ContentConfig] search_type: Optional[SearchConfig] processor: Optional[ProcessorConfig] + class SearchResponse(ConfigBase): entry: str score: str additional: Optional[dict] -class Entry(): + +class Entry: raw: str compiled: str file: Optional[str] @@ -99,8 +117,4 @@ class Entry(): @classmethod def from_dict(cls, dictionary: dict): - return cls( - raw=dictionary['raw'], - compiled=dictionary['compiled'], - file=dictionary.get('file', None) - ) \ No newline at end of file + return cls(raw=dictionary["raw"], compiled=dictionary["compiled"], file=dictionary.get("file", None)) diff --git a/src/khoj/utils/yaml.py b/src/khoj/utils/yaml.py index ca2a3aaf..c91f4fef 100644 --- a/src/khoj/utils/yaml.py +++ b/src/khoj/utils/yaml.py @@ -17,14 +17,14 @@ def save_config_to_file(yaml_config: dict, yaml_config_file: Path): # Create output directory, if it doesn't exist yaml_config_file.parent.mkdir(parents=True, exist_ok=True) - with open(yaml_config_file, 'w', encoding='utf-8') as config_file: + with open(yaml_config_file, "w", encoding="utf-8") as config_file: yaml.safe_dump(yaml_config, config_file, allow_unicode=True) def load_config_from_file(yaml_config_file: Path) -> dict: "Read config from YML file" config_from_file = None - with open(yaml_config_file, 'r', encoding='utf-8') as config_file: + with open(yaml_config_file, "r", encoding="utf-8") as config_file: config_from_file = yaml.safe_load(config_file) return config_from_file diff --git a/tests/conftest.py b/tests/conftest.py index cad2cb58..a5963eaf 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,59 +6,67 @@ import pytest # Internal Packages from khoj.search_type import image_search, text_search from khoj.utils.helpers import resolve_absolute_path 
-from khoj.utils.rawconfig import ContentConfig, TextContentConfig, ImageContentConfig, SearchConfig, TextSearchConfig, ImageSearchConfig +from khoj.utils.rawconfig import ( + ContentConfig, + TextContentConfig, + ImageContentConfig, + SearchConfig, + TextSearchConfig, + ImageSearchConfig, +) from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl from khoj.search_filter.date_filter import DateFilter from khoj.search_filter.word_filter import WordFilter from khoj.search_filter.file_filter import FileFilter -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def search_config() -> SearchConfig: - model_dir = resolve_absolute_path('~/.khoj/search') + model_dir = resolve_absolute_path("~/.khoj/search") model_dir.mkdir(parents=True, exist_ok=True) search_config = SearchConfig() search_config.symmetric = TextSearchConfig( - encoder = "sentence-transformers/all-MiniLM-L6-v2", - cross_encoder = "cross-encoder/ms-marco-MiniLM-L-6-v2", - model_directory = model_dir / 'symmetric/' + encoder="sentence-transformers/all-MiniLM-L6-v2", + cross_encoder="cross-encoder/ms-marco-MiniLM-L-6-v2", + model_directory=model_dir / "symmetric/", ) search_config.asymmetric = TextSearchConfig( - encoder = "sentence-transformers/multi-qa-MiniLM-L6-cos-v1", - cross_encoder = "cross-encoder/ms-marco-MiniLM-L-6-v2", - model_directory = model_dir / 'asymmetric/' + encoder="sentence-transformers/multi-qa-MiniLM-L6-cos-v1", + cross_encoder="cross-encoder/ms-marco-MiniLM-L-6-v2", + model_directory=model_dir / "asymmetric/", ) search_config.image = ImageSearchConfig( - encoder = "sentence-transformers/clip-ViT-B-32", - model_directory = model_dir / 'image/' + encoder="sentence-transformers/clip-ViT-B-32", model_directory=model_dir / "image/" ) return search_config -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def content_config(tmp_path_factory, search_config: SearchConfig): - content_dir = tmp_path_factory.mktemp('content') + content_dir = tmp_path_factory.mktemp("content") # Generate Image Embeddings from Test Images content_config = ContentConfig() content_config.image = ImageContentConfig( - input_directories = ['tests/data/images'], - embeddings_file = content_dir.joinpath('image_embeddings.pt'), - batch_size = 1, - use_xmp_metadata = False) + input_directories=["tests/data/images"], + embeddings_file=content_dir.joinpath("image_embeddings.pt"), + batch_size=1, + use_xmp_metadata=False, + ) image_search.setup(content_config.image, search_config.image, regenerate=False) # Generate Notes Embeddings from Test Notes content_config.org = TextContentConfig( - input_files = None, - input_filter = ['tests/data/org/*.org'], - compressed_jsonl = content_dir.joinpath('notes.jsonl.gz'), - embeddings_file = content_dir.joinpath('note_embeddings.pt')) + input_files=None, + input_filter=["tests/data/org/*.org"], + compressed_jsonl=content_dir.joinpath("notes.jsonl.gz"), + embeddings_file=content_dir.joinpath("note_embeddings.pt"), + ) filters = [DateFilter(), WordFilter(), FileFilter()] text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters) @@ -66,7 +74,7 @@ def content_config(tmp_path_factory, search_config: SearchConfig): return content_config -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def new_org_file(content_config: ContentConfig): # Setup new_org_file = Path(content_config.org.input_filter[0]).parent / "new_file.org" @@ -79,9 +87,9 @@ def new_org_file(content_config: ContentConfig): new_org_file.unlink() 
-@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def org_config_with_only_new_file(content_config: ContentConfig, new_org_file: Path): new_org_config = deepcopy(content_config.org) - new_org_config.input_files = [f'{new_org_file}'] + new_org_config.input_files = [f"{new_org_file}"] new_org_config.input_filter = None - return new_org_config \ No newline at end of file + return new_org_config diff --git a/tests/test_beancount_to_jsonl.py b/tests/test_beancount_to_jsonl.py index 7150ea35..923adb5a 100644 --- a/tests/test_beancount_to_jsonl.py +++ b/tests/test_beancount_to_jsonl.py @@ -8,10 +8,10 @@ from khoj.processor.ledger.beancount_to_jsonl import BeancountToJsonl def test_no_transactions_in_file(tmp_path): "Handle file with no transactions." # Arrange - entry = f''' + entry = f""" - Bullet point 1 - Bullet point 2 - ''' + """ beancount_file = create_file(tmp_path, entry) # Act @@ -20,7 +20,8 @@ def test_no_transactions_in_file(tmp_path): # Process Each Entry from All Beancount Files jsonl_string = BeancountToJsonl.convert_transaction_maps_to_jsonl( - BeancountToJsonl.convert_transactions_to_maps(entry_nodes, file_to_entries)) + BeancountToJsonl.convert_transactions_to_maps(entry_nodes, file_to_entries) + ) jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()] # Assert @@ -30,11 +31,11 @@ def test_no_transactions_in_file(tmp_path): def test_single_beancount_transaction_to_jsonl(tmp_path): "Convert transaction from single file to jsonl." # Arrange - entry = f''' + entry = f""" 1984-04-01 * "Payee" "Narration" Expenses:Test:Test 1.00 KES Assets:Test:Test -1.00 KES - ''' + """ beancount_file = create_file(tmp_path, entry) # Act @@ -43,7 +44,8 @@ Assets:Test:Test -1.00 KES # Process Each Entry from All Beancount Files jsonl_string = BeancountToJsonl.convert_transaction_maps_to_jsonl( - BeancountToJsonl.convert_transactions_to_maps(entries, entry_to_file_map)) + BeancountToJsonl.convert_transactions_to_maps(entries, entry_to_file_map) + ) jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()] # Assert @@ -53,7 +55,7 @@ Assets:Test:Test -1.00 KES def test_multiple_transactions_to_jsonl(tmp_path): "Convert multiple transactions from single file to jsonl." 
# Arrange - entry = f''' + entry = f""" 1984-04-01 * "Payee" "Narration" Expenses:Test:Test 1.00 KES Assets:Test:Test -1.00 KES @@ -61,7 +63,7 @@ Assets:Test:Test -1.00 KES 1984-04-01 * "Payee" "Narration" Expenses:Test:Test 1.00 KES Assets:Test:Test -1.00 KES -''' +""" beancount_file = create_file(tmp_path, entry) @@ -71,7 +73,8 @@ Assets:Test:Test -1.00 KES # Process Each Entry from All Beancount Files jsonl_string = BeancountToJsonl.convert_transaction_maps_to_jsonl( - BeancountToJsonl.convert_transactions_to_maps(entries, entry_to_file_map)) + BeancountToJsonl.convert_transactions_to_maps(entries, entry_to_file_map) + ) jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()] # Assert @@ -95,8 +98,8 @@ def test_get_beancount_files(tmp_path): expected_files = sorted(map(str, [group1_file1, group1_file2, group2_file1, group2_file2, file1])) # Setup input-files, input-filters - input_files = [tmp_path / 'ledger.bean'] - input_filter = [tmp_path / 'group1*.bean', tmp_path / 'group2*.beancount'] + input_files = [tmp_path / "ledger.bean"] + input_filter = [tmp_path / "group1*.bean", tmp_path / "group2*.beancount"] # Act extracted_org_files = BeancountToJsonl.get_beancount_files(input_files, input_filter) diff --git a/tests/test_chatbot.py b/tests/test_chatbot.py index 4a135061..fdd99548 100644 --- a/tests/test_chatbot.py +++ b/tests/test_chatbot.py @@ -6,7 +6,7 @@ from khoj.processor.conversation.gpt import converse, understand, message_to_pro # Initialize variables for tests -model = 'text-davinci-003' +model = "text-davinci-003" api_key = None # Input your OpenAI API key to run the tests below @@ -14,19 +14,22 @@ api_key = None # Input your OpenAI API key to run the tests below # ---------------------------------------------------------------------------------------------------- def test_message_to_understand_prompt(): # Arrange - understand_primer = "Extract information from each chat message\n\nremember(memory-type, data);\nmemory-type=[\"companion\", \"notes\", \"ledger\", \"image\", \"music\"]\nsearch(search-type, data);\nsearch-type=[\"google\", \"youtube\"]\ngenerate(activity);\nactivity=[\"paint\",\"write\", \"chat\"]\ntrigger-emotion(emotion);\nemotion=[\"happy\",\"confidence\",\"fear\",\"surprise\",\"sadness\",\"disgust\",\"anger\", \"curiosity\", \"calm\"]\n\nQ: How are you doing?\nA: activity(\"chat\"); trigger-emotion(\"surprise\")\nQ: Do you remember what I told you about my brother Antoine when we were at the beach?\nA: remember(\"notes\", \"Brother Antoine when we were at the beach\"); trigger-emotion(\"curiosity\");\nQ: what did we talk about last time?\nA: remember(\"notes\", \"talk last time\"); trigger-emotion(\"curiosity\");\nQ: Let's make some drawings!\nA: generate(\"paint\"); trigger-emotion(\"happy\");\nQ: Do you know anything about Lebanon?\nA: search(\"google\", \"lebanon\"); trigger-emotion(\"confidence\");\nQ: Find a video about a panda rolling in the grass\nA: search(\"youtube\",\"panda rolling in the grass\"); trigger-emotion(\"happy\"); \nQ: Tell me a scary story\nA: generate(\"write\" \"A story about some adventure\"); trigger-emotion(\"fear\");\nQ: What fiction book was I reading last week about AI starship?\nA: remember(\"notes\", \"read fiction book about AI starship last week\"); trigger-emotion(\"curiosity\");\nQ: How much did I spend at Subway for dinner last time?\nA: remember(\"ledger\", \"last Subway dinner\"); trigger-emotion(\"curiosity\");\nQ: I'm feeling sleepy\nA: activity(\"chat\"); trigger-emotion(\"calm\")\nQ: What 
was that popular Sri lankan song that Alex showed me recently?\nA: remember(\"music\", \"popular Sri lankan song that Alex showed recently\"); trigger-emotion(\"curiosity\"); \nQ: You're pretty funny!\nA: activity(\"chat\"); trigger-emotion(\"pride\")" - expected_response = "Extract information from each chat message\n\nremember(memory-type, data);\nmemory-type=[\"companion\", \"notes\", \"ledger\", \"image\", \"music\"]\nsearch(search-type, data);\nsearch-type=[\"google\", \"youtube\"]\ngenerate(activity);\nactivity=[\"paint\",\"write\", \"chat\"]\ntrigger-emotion(emotion);\nemotion=[\"happy\",\"confidence\",\"fear\",\"surprise\",\"sadness\",\"disgust\",\"anger\", \"curiosity\", \"calm\"]\n\nQ: How are you doing?\nA: activity(\"chat\"); trigger-emotion(\"surprise\")\nQ: Do you remember what I told you about my brother Antoine when we were at the beach?\nA: remember(\"notes\", \"Brother Antoine when we were at the beach\"); trigger-emotion(\"curiosity\");\nQ: what did we talk about last time?\nA: remember(\"notes\", \"talk last time\"); trigger-emotion(\"curiosity\");\nQ: Let's make some drawings!\nA: generate(\"paint\"); trigger-emotion(\"happy\");\nQ: Do you know anything about Lebanon?\nA: search(\"google\", \"lebanon\"); trigger-emotion(\"confidence\");\nQ: Find a video about a panda rolling in the grass\nA: search(\"youtube\",\"panda rolling in the grass\"); trigger-emotion(\"happy\"); \nQ: Tell me a scary story\nA: generate(\"write\" \"A story about some adventure\"); trigger-emotion(\"fear\");\nQ: What fiction book was I reading last week about AI starship?\nA: remember(\"notes\", \"read fiction book about AI starship last week\"); trigger-emotion(\"curiosity\");\nQ: How much did I spend at Subway for dinner last time?\nA: remember(\"ledger\", \"last Subway dinner\"); trigger-emotion(\"curiosity\");\nQ: I'm feeling sleepy\nA: activity(\"chat\"); trigger-emotion(\"calm\")\nQ: What was that popular Sri lankan song that Alex showed me recently?\nA: remember(\"music\", \"popular Sri lankan song that Alex showed recently\"); trigger-emotion(\"curiosity\"); \nQ: You're pretty funny!\nA: activity(\"chat\"); trigger-emotion(\"pride\")\nQ: When did I last dine at Burger King?\nA:" + understand_primer = 'Extract information from each chat message\n\nremember(memory-type, data);\nmemory-type=["companion", "notes", "ledger", "image", "music"]\nsearch(search-type, data);\nsearch-type=["google", "youtube"]\ngenerate(activity);\nactivity=["paint","write", "chat"]\ntrigger-emotion(emotion);\nemotion=["happy","confidence","fear","surprise","sadness","disgust","anger", "curiosity", "calm"]\n\nQ: How are you doing?\nA: activity("chat"); trigger-emotion("surprise")\nQ: Do you remember what I told you about my brother Antoine when we were at the beach?\nA: remember("notes", "Brother Antoine when we were at the beach"); trigger-emotion("curiosity");\nQ: what did we talk about last time?\nA: remember("notes", "talk last time"); trigger-emotion("curiosity");\nQ: Let\'s make some drawings!\nA: generate("paint"); trigger-emotion("happy");\nQ: Do you know anything about Lebanon?\nA: search("google", "lebanon"); trigger-emotion("confidence");\nQ: Find a video about a panda rolling in the grass\nA: search("youtube","panda rolling in the grass"); trigger-emotion("happy"); \nQ: Tell me a scary story\nA: generate("write" "A story about some adventure"); trigger-emotion("fear");\nQ: What fiction book was I reading last week about AI starship?\nA: remember("notes", "read fiction book about AI starship last week"); 
trigger-emotion("curiosity");\nQ: How much did I spend at Subway for dinner last time?\nA: remember("ledger", "last Subway dinner"); trigger-emotion("curiosity");\nQ: I\'m feeling sleepy\nA: activity("chat"); trigger-emotion("calm")\nQ: What was that popular Sri lankan song that Alex showed me recently?\nA: remember("music", "popular Sri lankan song that Alex showed recently"); trigger-emotion("curiosity"); \nQ: You\'re pretty funny!\nA: activity("chat"); trigger-emotion("pride")' + expected_response = 'Extract information from each chat message\n\nremember(memory-type, data);\nmemory-type=["companion", "notes", "ledger", "image", "music"]\nsearch(search-type, data);\nsearch-type=["google", "youtube"]\ngenerate(activity);\nactivity=["paint","write", "chat"]\ntrigger-emotion(emotion);\nemotion=["happy","confidence","fear","surprise","sadness","disgust","anger", "curiosity", "calm"]\n\nQ: How are you doing?\nA: activity("chat"); trigger-emotion("surprise")\nQ: Do you remember what I told you about my brother Antoine when we were at the beach?\nA: remember("notes", "Brother Antoine when we were at the beach"); trigger-emotion("curiosity");\nQ: what did we talk about last time?\nA: remember("notes", "talk last time"); trigger-emotion("curiosity");\nQ: Let\'s make some drawings!\nA: generate("paint"); trigger-emotion("happy");\nQ: Do you know anything about Lebanon?\nA: search("google", "lebanon"); trigger-emotion("confidence");\nQ: Find a video about a panda rolling in the grass\nA: search("youtube","panda rolling in the grass"); trigger-emotion("happy"); \nQ: Tell me a scary story\nA: generate("write" "A story about some adventure"); trigger-emotion("fear");\nQ: What fiction book was I reading last week about AI starship?\nA: remember("notes", "read fiction book about AI starship last week"); trigger-emotion("curiosity");\nQ: How much did I spend at Subway for dinner last time?\nA: remember("ledger", "last Subway dinner"); trigger-emotion("curiosity");\nQ: I\'m feeling sleepy\nA: activity("chat"); trigger-emotion("calm")\nQ: What was that popular Sri lankan song that Alex showed me recently?\nA: remember("music", "popular Sri lankan song that Alex showed recently"); trigger-emotion("curiosity"); \nQ: You\'re pretty funny!\nA: activity("chat"); trigger-emotion("pride")\nQ: When did I last dine at Burger King?\nA:' # Act - actual_response = message_to_prompt("When did I last dine at Burger King?", understand_primer, start_sequence="\nA:", restart_sequence="\nQ:") + actual_response = message_to_prompt( + "When did I last dine at Burger King?", understand_primer, start_sequence="\nA:", restart_sequence="\nQ:" + ) # Assert assert actual_response == expected_response # ---------------------------------------------------------------------------------------------------- -@pytest.mark.skipif(api_key is None, - reason="Set api_key variable to your OpenAI API key from https://beta.openai.com/account/api-keys") +@pytest.mark.skipif( + api_key is None, reason="Set api_key variable to your OpenAI API key from https://beta.openai.com/account/api-keys" +) def test_minimal_chat_with_gpt(): # Act response = converse("What will happen when the stars go out?", model=model, api_key=api_key) @@ -36,21 +39,29 @@ def test_minimal_chat_with_gpt(): # ---------------------------------------------------------------------------------------------------- -@pytest.mark.skipif(api_key is None, - reason="Set api_key variable to your OpenAI API key from https://beta.openai.com/account/api-keys") +@pytest.mark.skipif( + api_key 
is None, reason="Set api_key variable to your OpenAI API key from https://beta.openai.com/account/api-keys" +) def test_chat_with_history(): # Arrange - ai_prompt="AI:" - human_prompt="Human:" + ai_prompt = "AI:" + human_prompt = "Human:" - conversation_primer = f''' + conversation_primer = f""" The following is a conversation with an AI assistant. The assistant is helpful, creative, clever, and very friendly companion. {human_prompt} Hello, I am Testatron. Who are you? -{ai_prompt} Hi, I am Khoj, an AI conversational companion created by OpenAI. How can I help you today?''' +{ai_prompt} Hi, I am Khoj, an AI conversational companion created by OpenAI. How can I help you today?""" # Act - response = converse("Hi Khoj, What is my name?", model=model, conversation_history=conversation_primer, api_key=api_key, temperature=0, max_tokens=50) + response = converse( + "Hi Khoj, What is my name?", + model=model, + conversation_history=conversation_primer, + api_key=api_key, + temperature=0, + max_tokens=50, + ) # Assert assert len(response) > 0 @@ -58,12 +69,13 @@ The following is a conversation with an AI assistant. The assistant is helpful, # ---------------------------------------------------------------------------------------------------- -@pytest.mark.skipif(api_key is None, - reason="Set api_key variable to your OpenAI API key from https://beta.openai.com/account/api-keys") +@pytest.mark.skipif( + api_key is None, reason="Set api_key variable to your OpenAI API key from https://beta.openai.com/account/api-keys" +) def test_understand_message_using_gpt(): # Act response = understand("When did I last dine at Subway?", model=model, api_key=api_key) # Assert assert len(response) > 0 - assert response['intent']['memory-type'] == 'ledger' + assert response["intent"]["memory-type"] == "ledger" diff --git a/tests/test_cli.py b/tests/test_cli.py index f51f76cb..b7e18460 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -14,35 +14,37 @@ def test_cli_minimal_default(): actual_args = cli([]) # Assert - assert actual_args.config_file == resolve_absolute_path(Path('~/.khoj/khoj.yml')) + assert actual_args.config_file == resolve_absolute_path(Path("~/.khoj/khoj.yml")) assert actual_args.regenerate == False assert actual_args.no_gui == False assert actual_args.verbose == 0 + # ---------------------------------------------------------------------------------------------------- def test_cli_invalid_config_file_path(): # Arrange non_existent_config_file = f"non-existent-khoj-{random()}.yml" # Act - actual_args = cli([f'-c={non_existent_config_file}']) + actual_args = cli([f"-c={non_existent_config_file}"]) # Assert assert actual_args.config_file == resolve_absolute_path(non_existent_config_file) assert actual_args.config == None + # ---------------------------------------------------------------------------------------------------- def test_cli_config_from_file(): # Act - actual_args = cli(['-c=tests/data/config.yml', - '--regenerate', - '--no-gui', - '-vvv']) + actual_args = cli(["-c=tests/data/config.yml", "--regenerate", "--no-gui", "-vvv"]) # Assert - assert actual_args.config_file == resolve_absolute_path(Path('tests/data/config.yml')) + assert actual_args.config_file == resolve_absolute_path(Path("tests/data/config.yml")) assert actual_args.no_gui == True assert actual_args.regenerate == True assert actual_args.config is not None - assert actual_args.config.content_type.org.input_files == [Path('~/first_from_config.org'), Path('~/second_from_config.org')] + assert 
actual_args.config.content_type.org.input_files == [ + Path("~/first_from_config.org"), + Path("~/second_from_config.org"), + ] assert actual_args.verbose == 3 diff --git a/tests/test_client.py b/tests/test_client.py index ac94bc27..ed765fa0 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -21,6 +21,7 @@ from khoj.search_filter.file_filter import FileFilter # ---------------------------------------------------------------------------------------------------- client = TestClient(app) + # Test # ---------------------------------------------------------------------------------------------------- def test_search_with_invalid_content_type(): @@ -98,9 +99,11 @@ def test_image_search(content_config: ContentConfig, search_config: SearchConfig config.content_type = content_config config.search_type = search_config model.image_search = image_search.setup(content_config.image, search_config.image, regenerate=False) - query_expected_image_pairs = [("kitten", "kitten_park.jpg"), - ("a horse and dog on a leash", "horse_dog.jpg"), - ("A guinea pig eating grass", "guineapig_grass.jpg")] + query_expected_image_pairs = [ + ("kitten", "kitten_park.jpg"), + ("a horse and dog on a leash", "horse_dog.jpg"), + ("A guinea pig eating grass", "guineapig_grass.jpg"), + ] for query, expected_image_name in query_expected_image_pairs: # Act @@ -135,7 +138,9 @@ def test_notes_search(content_config: ContentConfig, search_config: SearchConfig def test_notes_search_with_only_filters(content_config: ContentConfig, search_config: SearchConfig): # Arrange filters = [WordFilter(), FileFilter()] - model.orgmode_search = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters) + model.orgmode_search = text_search.setup( + OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters + ) user_query = quote('+"Emacs" file:"*.org"') # Act @@ -152,7 +157,9 @@ def test_notes_search_with_only_filters(content_config: ContentConfig, search_co def test_notes_search_with_include_filter(content_config: ContentConfig, search_config: SearchConfig): # Arrange filters = [WordFilter()] - model.orgmode_search = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters) + model.orgmode_search = text_search.setup( + OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters + ) user_query = quote('How to git install application? +"Emacs"') # Act @@ -169,7 +176,9 @@ def test_notes_search_with_include_filter(content_config: ContentConfig, search_ def test_notes_search_with_exclude_filter(content_config: ContentConfig, search_config: SearchConfig): # Arrange filters = [WordFilter()] - model.orgmode_search = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters) + model.orgmode_search = text_search.setup( + OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters + ) user_query = quote('How to git install application? 
-"clone"') # Act diff --git a/tests/test_date_filter.py b/tests/test_date_filter.py index 480719fd..d764dca7 100644 --- a/tests/test_date_filter.py +++ b/tests/test_date_filter.py @@ -10,53 +10,59 @@ from khoj.utils.rawconfig import Entry def test_date_filter(): entries = [ - Entry(compiled='', raw='Entry with no date'), - Entry(compiled='', raw='April Fools entry: 1984-04-01'), - Entry(compiled='', raw='Entry with date:1984-04-02') + Entry(compiled="", raw="Entry with no date"), + Entry(compiled="", raw="April Fools entry: 1984-04-01"), + Entry(compiled="", raw="Entry with date:1984-04-02"), ] - q_with_no_date_filter = 'head tail' + q_with_no_date_filter = "head tail" ret_query, entry_indices = DateFilter().apply(q_with_no_date_filter, entries) - assert ret_query == 'head tail' + assert ret_query == "head tail" assert entry_indices == {0, 1, 2} q_with_dtrange_non_overlapping_at_boundary = 'head dt>"1984-04-01" dt<"1984-04-02" tail' ret_query, entry_indices = DateFilter().apply(q_with_dtrange_non_overlapping_at_boundary, entries) - assert ret_query == 'head tail' + assert ret_query == "head tail" assert entry_indices == set() query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<"1984-04-03" tail' ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries) - assert ret_query == 'head tail' + assert ret_query == "head tail" assert entry_indices == {2} query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<"1984-04-02" tail' ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries) - assert ret_query == 'head tail' + assert ret_query == "head tail" assert entry_indices == {1} query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<="1984-04-02" tail' ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries) - assert ret_query == 'head tail' + assert ret_query == "head tail" assert entry_indices == {2} query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<="1984-04-02" tail' ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries) - assert ret_query == 'head tail' + assert ret_query == "head tail" assert entry_indices == {1, 2} def test_extract_date_range(): - assert DateFilter().extract_date_range('head dt>"1984-01-04" dt<"1984-01-07" tail') == [datetime(1984, 1, 5, 0, 0, 0).timestamp(), datetime(1984, 1, 7, 0, 0, 0).timestamp()] + assert DateFilter().extract_date_range('head dt>"1984-01-04" dt<"1984-01-07" tail') == [ + datetime(1984, 1, 5, 0, 0, 0).timestamp(), + datetime(1984, 1, 7, 0, 0, 0).timestamp(), + ] assert DateFilter().extract_date_range('head dt<="1984-01-01"') == [0, datetime(1984, 1, 2, 0, 0, 0).timestamp()] assert DateFilter().extract_date_range('head dt>="1984-01-01"') == [datetime(1984, 1, 1, 0, 0, 0).timestamp(), inf] - assert DateFilter().extract_date_range('head dt:"1984-01-01"') == [datetime(1984, 1, 1, 0, 0, 0).timestamp(), datetime(1984, 1, 2, 0, 0, 0).timestamp()] + assert DateFilter().extract_date_range('head dt:"1984-01-01"') == [ + datetime(1984, 1, 1, 0, 0, 0).timestamp(), + datetime(1984, 1, 2, 0, 0, 0).timestamp(), + ] # Unparseable date filter specified in query assert DateFilter().extract_date_range('head dt:"Summer of 69" tail') == None # No date filter specified in query - assert DateFilter().extract_date_range('head tail') == None + assert DateFilter().extract_date_range("head tail") == None # Non intersecting date ranges assert DateFilter().extract_date_range('head dt>"1984-01-01" dt<"1984-01-01" tail') == None @@ -66,43 +72,79 
@@ def test_parse(): test_now = datetime(1984, 4, 1, 21, 21, 21) # day variations - assert DateFilter().parse('today', relative_base=test_now) == (datetime(1984, 4, 1, 0, 0, 0), datetime(1984, 4, 2, 0, 0, 0)) - assert DateFilter().parse('tomorrow', relative_base=test_now) == (datetime(1984, 4, 2, 0, 0, 0), datetime(1984, 4, 3, 0, 0, 0)) - assert DateFilter().parse('yesterday', relative_base=test_now) == (datetime(1984, 3, 31, 0, 0, 0), datetime(1984, 4, 1, 0, 0, 0)) - assert DateFilter().parse('5 days ago', relative_base=test_now) == (datetime(1984, 3, 27, 0, 0, 0), datetime(1984, 3, 28, 0, 0, 0)) + assert DateFilter().parse("today", relative_base=test_now) == ( + datetime(1984, 4, 1, 0, 0, 0), + datetime(1984, 4, 2, 0, 0, 0), + ) + assert DateFilter().parse("tomorrow", relative_base=test_now) == ( + datetime(1984, 4, 2, 0, 0, 0), + datetime(1984, 4, 3, 0, 0, 0), + ) + assert DateFilter().parse("yesterday", relative_base=test_now) == ( + datetime(1984, 3, 31, 0, 0, 0), + datetime(1984, 4, 1, 0, 0, 0), + ) + assert DateFilter().parse("5 days ago", relative_base=test_now) == ( + datetime(1984, 3, 27, 0, 0, 0), + datetime(1984, 3, 28, 0, 0, 0), + ) # week variations - assert DateFilter().parse('last week', relative_base=test_now) == (datetime(1984, 3, 18, 0, 0, 0), datetime(1984, 3, 25, 0, 0, 0)) - assert DateFilter().parse('2 weeks ago', relative_base=test_now) == (datetime(1984, 3, 11, 0, 0, 0), datetime(1984, 3, 18, 0, 0, 0)) + assert DateFilter().parse("last week", relative_base=test_now) == ( + datetime(1984, 3, 18, 0, 0, 0), + datetime(1984, 3, 25, 0, 0, 0), + ) + assert DateFilter().parse("2 weeks ago", relative_base=test_now) == ( + datetime(1984, 3, 11, 0, 0, 0), + datetime(1984, 3, 18, 0, 0, 0), + ) # month variations - assert DateFilter().parse('next month', relative_base=test_now) == (datetime(1984, 5, 1, 0, 0, 0), datetime(1984, 6, 1, 0, 0, 0)) - assert DateFilter().parse('2 months ago', relative_base=test_now) == (datetime(1984, 2, 1, 0, 0, 0), datetime(1984, 3, 1, 0, 0, 0)) + assert DateFilter().parse("next month", relative_base=test_now) == ( + datetime(1984, 5, 1, 0, 0, 0), + datetime(1984, 6, 1, 0, 0, 0), + ) + assert DateFilter().parse("2 months ago", relative_base=test_now) == ( + datetime(1984, 2, 1, 0, 0, 0), + datetime(1984, 3, 1, 0, 0, 0), + ) # year variations - assert DateFilter().parse('this year', relative_base=test_now) == (datetime(1984, 1, 1, 0, 0, 0), datetime(1985, 1, 1, 0, 0, 0)) - assert DateFilter().parse('20 years later', relative_base=test_now) == (datetime(2004, 1, 1, 0, 0, 0), datetime(2005, 1, 1, 0, 0, 0)) + assert DateFilter().parse("this year", relative_base=test_now) == ( + datetime(1984, 1, 1, 0, 0, 0), + datetime(1985, 1, 1, 0, 0, 0), + ) + assert DateFilter().parse("20 years later", relative_base=test_now) == ( + datetime(2004, 1, 1, 0, 0, 0), + datetime(2005, 1, 1, 0, 0, 0), + ) # specific month/date variation - assert DateFilter().parse('in august', relative_base=test_now) == (datetime(1983, 8, 1, 0, 0, 0), datetime(1983, 8, 2, 0, 0, 0)) - assert DateFilter().parse('on 1983-08-01', relative_base=test_now) == (datetime(1983, 8, 1, 0, 0, 0), datetime(1983, 8, 2, 0, 0, 0)) + assert DateFilter().parse("in august", relative_base=test_now) == ( + datetime(1983, 8, 1, 0, 0, 0), + datetime(1983, 8, 2, 0, 0, 0), + ) + assert DateFilter().parse("on 1983-08-01", relative_base=test_now) == ( + datetime(1983, 8, 1, 0, 0, 0), + datetime(1983, 8, 2, 0, 0, 0), + ) def test_date_filter_regex(): dtrange_match = re.findall(DateFilter().date_regex, 'multi word 
head dt>"today" dt:"1984-01-01"') - assert dtrange_match == [('>', 'today'), (':', '1984-01-01')] + assert dtrange_match == [(">", "today"), (":", "1984-01-01")] dtrange_match = re.findall(DateFilter().date_regex, 'head dt>"today" dt:"1984-01-01" multi word tail') - assert dtrange_match == [('>', 'today'), (':', '1984-01-01')] + assert dtrange_match == [(">", "today"), (":", "1984-01-01")] dtrange_match = re.findall(DateFilter().date_regex, 'multi word head dt>="today" dt="1984-01-01"') - assert dtrange_match == [('>=', 'today'), ('=', '1984-01-01')] + assert dtrange_match == [(">=", "today"), ("=", "1984-01-01")] dtrange_match = re.findall(DateFilter().date_regex, 'dt<"multi word date" multi word tail') - assert dtrange_match == [('<', 'multi word date')] + assert dtrange_match == [("<", "multi word date")] dtrange_match = re.findall(DateFilter().date_regex, 'head dt<="multi word date"') - assert dtrange_match == [('<=', 'multi word date')] + assert dtrange_match == [("<=", "multi word date")] - dtrange_match = re.findall(DateFilter().date_regex, 'head tail') - assert dtrange_match == [] \ No newline at end of file + dtrange_match = re.findall(DateFilter().date_regex, "head tail") + assert dtrange_match == [] diff --git a/tests/test_file_filter.py b/tests/test_file_filter.py index 7ac98f56..2ae82f66 100644 --- a/tests/test_file_filter.py +++ b/tests/test_file_filter.py @@ -7,7 +7,7 @@ def test_no_file_filter(): # Arrange file_filter = FileFilter() entries = arrange_content() - q_with_no_filter = 'head tail' + q_with_no_filter = "head tail" # Act can_filter = file_filter.can_filter(q_with_no_filter) @@ -15,7 +15,7 @@ def test_no_file_filter(): # Assert assert can_filter == False - assert ret_query == 'head tail' + assert ret_query == "head tail" assert entry_indices == {0, 1, 2, 3} @@ -31,7 +31,7 @@ def test_file_filter_with_non_existent_file(): # Assert assert can_filter == True - assert ret_query == 'head tail' + assert ret_query == "head tail" assert entry_indices == {} @@ -47,7 +47,7 @@ def test_single_file_filter(): # Assert assert can_filter == True - assert ret_query == 'head tail' + assert ret_query == "head tail" assert entry_indices == {0, 2} @@ -63,7 +63,7 @@ def test_file_filter_with_partial_match(): # Assert assert can_filter == True - assert ret_query == 'head tail' + assert ret_query == "head tail" assert entry_indices == {0, 2} @@ -79,7 +79,7 @@ def test_file_filter_with_regex_match(): # Assert assert can_filter == True - assert ret_query == 'head tail' + assert ret_query == "head tail" assert entry_indices == {0, 1, 2, 3} @@ -95,16 +95,16 @@ def test_multiple_file_filter(): # Assert assert can_filter == True - assert ret_query == 'head tail' + assert ret_query == "head tail" assert entry_indices == {0, 1, 2, 3} def arrange_content(): entries = [ - Entry(compiled='', raw='First Entry', file= 'file 1.org'), - Entry(compiled='', raw='Second Entry', file= 'file2.org'), - Entry(compiled='', raw='Third Entry', file= 'file 1.org'), - Entry(compiled='', raw='Fourth Entry', file= 'file2.org') + Entry(compiled="", raw="First Entry", file="file 1.org"), + Entry(compiled="", raw="Second Entry", file="file2.org"), + Entry(compiled="", raw="Third Entry", file="file 1.org"), + Entry(compiled="", raw="Fourth Entry", file="file2.org"), ] return entries diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 2ee1569e..622592b1 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -1,5 +1,6 @@ from khoj.utils import helpers + def test_get_from_null_dict(): # null 
handling assert helpers.get_from_dict(dict()) == dict() @@ -7,39 +8,39 @@ def test_get_from_null_dict(): # key present in nested dictionary # 1-level dictionary - assert helpers.get_from_dict({'a': 1, 'b': 2}, 'a') == 1 - assert helpers.get_from_dict({'a': 1, 'b': 2}, 'c') == None + assert helpers.get_from_dict({"a": 1, "b": 2}, "a") == 1 + assert helpers.get_from_dict({"a": 1, "b": 2}, "c") == None # 2-level dictionary - assert helpers.get_from_dict({'a': {'a_a': 1}, 'b': 2}, 'a') == {'a_a': 1} - assert helpers.get_from_dict({'a': {'a_a': 1}, 'b': 2}, 'a', 'a_a') == 1 + assert helpers.get_from_dict({"a": {"a_a": 1}, "b": 2}, "a") == {"a_a": 1} + assert helpers.get_from_dict({"a": {"a_a": 1}, "b": 2}, "a", "a_a") == 1 # key not present in nested dictionary # 2-level_dictionary - assert helpers.get_from_dict({'a': {'a_a': 1}, 'b': 2}, 'b', 'b_a') == None + assert helpers.get_from_dict({"a": {"a_a": 1}, "b": 2}, "b", "b_a") == None def test_merge_dicts(): # basic merge of dicts with non-overlapping keys - assert helpers.merge_dicts(priority_dict={'a': 1}, default_dict={'b': 2}) == {'a': 1, 'b': 2} + assert helpers.merge_dicts(priority_dict={"a": 1}, default_dict={"b": 2}) == {"a": 1, "b": 2} # use default dict items when not present in priority dict - assert helpers.merge_dicts(priority_dict={}, default_dict={'b': 2}) == {'b': 2} + assert helpers.merge_dicts(priority_dict={}, default_dict={"b": 2}) == {"b": 2} # do not override existing key in priority_dict with default dict - assert helpers.merge_dicts(priority_dict={'a': 1}, default_dict={'a': 2}) == {'a': 1} + assert helpers.merge_dicts(priority_dict={"a": 1}, default_dict={"a": 2}) == {"a": 1} def test_lru_cache(): # Test initializing cache - cache = helpers.LRU({'a': 1, 'b': 2}, capacity=2) - assert cache == {'a': 1, 'b': 2} + cache = helpers.LRU({"a": 1, "b": 2}, capacity=2) + assert cache == {"a": 1, "b": 2} # Test capacity overflow - cache['c'] = 3 - assert cache == {'b': 2, 'c': 3} + cache["c"] = 3 + assert cache == {"b": 2, "c": 3} # Test delete least recently used item from LRU cache on capacity overflow - cache['b'] # accessing 'b' makes it the most recently used item - cache['d'] = 4 # so 'c' is deleted from the cache instead of 'b' - assert cache == {'b': 2, 'd': 4} + cache["b"] # accessing 'b' makes it the most recently used item + cache["d"] = 4 # so 'c' is deleted from the cache instead of 'b' + assert cache == {"b": 2, "d": 4} diff --git a/tests/test_image_search.py b/tests/test_image_search.py index cb19a55a..c29e93a1 100644 --- a/tests/test_image_search.py +++ b/tests/test_image_search.py @@ -30,7 +30,8 @@ def test_image_metadata(content_config: ContentConfig): expected_metadata_image_name_pairs = [ (["Billi Ka Bacha.", "Cat", "Grass"], "kitten_park.jpg"), (["Pasture.", "Horse", "Dog"], "horse_dog.jpg"), - (["Guinea Pig Eating Celery.", "Rodent", "Whiskers"], "guineapig_grass.jpg")] + (["Guinea Pig Eating Celery.", "Rodent", "Whiskers"], "guineapig_grass.jpg"), + ] test_image_paths = [ Path(content_config.image.input_directories[0] / image_name[1]) @@ -51,23 +52,23 @@ def test_image_search(content_config: ContentConfig, search_config: SearchConfig # Arrange output_directory = resolve_absolute_path(web_directory) model.image_search = image_search.setup(content_config.image, search_config.image, regenerate=False) - query_expected_image_pairs = [("kitten", "kitten_park.jpg"), - ("horse and dog in a farm", "horse_dog.jpg"), - ("A guinea pig eating grass", "guineapig_grass.jpg")] + query_expected_image_pairs = [ + ("kitten", 
"kitten_park.jpg"), + ("horse and dog in a farm", "horse_dog.jpg"), + ("A guinea pig eating grass", "guineapig_grass.jpg"), + ] # Act for query, expected_image_name in query_expected_image_pairs: - hits = image_search.query( - query, - count = 1, - model = model.image_search) + hits = image_search.query(query, count=1, model=model.image_search) results = image_search.collate_results( hits, model.image_search.image_names, output_directory=output_directory, - image_files_url='/static/images', - count=1) + image_files_url="/static/images", + count=1, + ) actual_image_path = output_directory.joinpath(Path(results[0].entry).name) actual_image = Image.open(actual_image_path) @@ -86,16 +87,13 @@ def test_image_search_query_truncated(content_config: ContentConfig, search_conf # Arrange model.image_search = image_search.setup(content_config.image, search_config.image, regenerate=False) max_words_supported = 10 - query = " ".join(["hello"]*100) - truncated_query = " ".join(["hello"]*max_words_supported) + query = " ".join(["hello"] * 100) + truncated_query = " ".join(["hello"] * max_words_supported) # Act try: with caplog.at_level(logging.INFO, logger="khoj.search_type.image_search"): - image_search.query( - query, - count = 1, - model = model.image_search) + image_search.query(query, count=1, model=model.image_search) # Assert except RuntimeError as e: if "The size of tensor a (102) must match the size of tensor b (77)" in str(e): @@ -115,17 +113,15 @@ def test_image_search_by_filepath(content_config: ContentConfig, search_config: # Act with caplog.at_level(logging.INFO, logger="khoj.search_type.image_search"): - hits = image_search.query( - query, - count = 1, - model = model.image_search) + hits = image_search.query(query, count=1, model=model.image_search) results = image_search.collate_results( hits, model.image_search.image_names, output_directory=output_directory, - image_files_url='/static/images', - count=1) + image_files_url="/static/images", + count=1, + ) actual_image_path = output_directory.joinpath(Path(results[0].entry).name) actual_image = Image.open(actual_image_path) @@ -133,7 +129,9 @@ def test_image_search_by_filepath(content_config: ContentConfig, search_config: # Assert # Ensure file search triggered instead of query with file path as string - assert f"Find Images by Image: {resolve_absolute_path(expected_image_path)}" in caplog.text, "File search not triggered" + assert ( + f"Find Images by Image: {resolve_absolute_path(expected_image_path)}" in caplog.text + ), "File search not triggered" # Ensure the correct image is returned assert expected_image == actual_image, "Incorrect image returned by file search" diff --git a/tests/test_markdown_to_jsonl.py b/tests/test_markdown_to_jsonl.py index a019ef5e..16f19ab1 100644 --- a/tests/test_markdown_to_jsonl.py +++ b/tests/test_markdown_to_jsonl.py @@ -8,10 +8,10 @@ from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl def test_markdown_file_with_no_headings_to_jsonl(tmp_path): "Convert files with no heading to jsonl." 
# Arrange - entry = f''' + entry = f""" - Bullet point 1 - Bullet point 2 - ''' + """ markdownfile = create_file(tmp_path, entry) # Act @@ -20,7 +20,8 @@ def test_markdown_file_with_no_headings_to_jsonl(tmp_path): # Process Each Entry from All Notes Files jsonl_string = MarkdownToJsonl.convert_markdown_maps_to_jsonl( - MarkdownToJsonl.convert_markdown_entries_to_maps(entry_nodes, file_to_entries)) + MarkdownToJsonl.convert_markdown_entries_to_maps(entry_nodes, file_to_entries) + ) jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()] # Assert @@ -30,10 +31,10 @@ def test_markdown_file_with_no_headings_to_jsonl(tmp_path): def test_single_markdown_entry_to_jsonl(tmp_path): "Convert markdown entry from single file to jsonl." # Arrange - entry = f'''### Heading + entry = f"""### Heading \t\r Body Line 1 - ''' + """ markdownfile = create_file(tmp_path, entry) # Act @@ -42,7 +43,8 @@ def test_single_markdown_entry_to_jsonl(tmp_path): # Process Each Entry from All Notes Files jsonl_string = MarkdownToJsonl.convert_markdown_maps_to_jsonl( - MarkdownToJsonl.convert_markdown_entries_to_maps(entries, entry_to_file_map)) + MarkdownToJsonl.convert_markdown_entries_to_maps(entries, entry_to_file_map) + ) jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()] # Assert @@ -52,14 +54,14 @@ def test_single_markdown_entry_to_jsonl(tmp_path): def test_multiple_markdown_entries_to_jsonl(tmp_path): "Convert multiple markdown entries from single file to jsonl." # Arrange - entry = f''' + entry = f""" ### Heading 1 \t\r Heading 1 Body Line 1 ### Heading 2 \t\r Heading 2 Body Line 2 - ''' + """ markdownfile = create_file(tmp_path, entry) # Act @@ -68,7 +70,8 @@ def test_multiple_markdown_entries_to_jsonl(tmp_path): # Process Each Entry from All Notes Files jsonl_string = MarkdownToJsonl.convert_markdown_maps_to_jsonl( - MarkdownToJsonl.convert_markdown_entries_to_maps(entries, entry_to_file_map)) + MarkdownToJsonl.convert_markdown_entries_to_maps(entries, entry_to_file_map) + ) jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()] # Assert @@ -92,8 +95,8 @@ def test_get_markdown_files(tmp_path): expected_files = sorted(map(str, [group1_file1, group1_file2, group2_file1, group2_file2, file1])) # Setup input-files, input-filters - input_files = [tmp_path / 'notes.md'] - input_filter = [tmp_path / 'group1*.md', tmp_path / 'group2*.markdown'] + input_files = [tmp_path / "notes.md"] + input_filter = [tmp_path / "group1*.md", tmp_path / "group2*.markdown"] # Act extracted_org_files = MarkdownToJsonl.get_markdown_files(input_files, input_filter) @@ -106,10 +109,10 @@ def test_get_markdown_files(tmp_path): def test_extract_entries_with_different_level_headings(tmp_path): "Extract markdown entries with different level headings." # Arrange - entry = f''' + entry = f""" # Heading 1 ## Heading 2 -''' +""" markdownfile = create_file(tmp_path, entry) # Act diff --git a/tests/test_org_to_jsonl.py b/tests/test_org_to_jsonl.py index 62cab75b..89a82a4d 100644 --- a/tests/test_org_to_jsonl.py +++ b/tests/test_org_to_jsonl.py @@ -9,23 +9,25 @@ from khoj.utils.rawconfig import Entry def test_configure_heading_entry_to_jsonl(tmp_path): - '''Ensure entries with empty body are ignored, unless explicitly configured to index heading entries. - Property drawers not considered Body. Ignore control characters for evaluating if Body empty.''' + """Ensure entries with empty body are ignored, unless explicitly configured to index heading entries. 
+ Property drawers not considered Body. Ignore control characters for evaluating if Body empty.""" # Arrange - entry = f'''*** Heading + entry = f"""*** Heading :PROPERTIES: :ID: 42-42-42 :END: \t \r - ''' + """ orgfile = create_file(tmp_path, entry) for index_heading_entries in [True, False]: # Act # Extract entries into jsonl from specified Org files - jsonl_string = OrgToJsonl.convert_org_entries_to_jsonl(OrgToJsonl.convert_org_nodes_to_entries( - *OrgToJsonl.extract_org_entries(org_files=[orgfile]), - index_heading_entries=index_heading_entries)) + jsonl_string = OrgToJsonl.convert_org_entries_to_jsonl( + OrgToJsonl.convert_org_nodes_to_entries( + *OrgToJsonl.extract_org_entries(org_files=[orgfile]), index_heading_entries=index_heading_entries + ) + ) jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()] # Assert @@ -40,10 +42,10 @@ def test_configure_heading_entry_to_jsonl(tmp_path): def test_entry_split_when_exceeds_max_words(tmp_path): "Ensure entries with compiled words exceeding max_words are split." # Arrange - entry = f'''*** Heading + entry = f"""*** Heading \t\r Body Line 1 - ''' + """ orgfile = create_file(tmp_path, entry) # Act @@ -53,9 +55,9 @@ def test_entry_split_when_exceeds_max_words(tmp_path): # Split each entry from specified Org files by max words jsonl_string = OrgToJsonl.convert_org_entries_to_jsonl( TextToJsonl.split_entries_by_max_tokens( - OrgToJsonl.convert_org_nodes_to_entries(entries, entry_to_file_map), - max_tokens = 2) + OrgToJsonl.convert_org_nodes_to_entries(entries, entry_to_file_map), max_tokens=2 ) + ) jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()] # Assert @@ -65,15 +67,15 @@ def test_entry_split_when_exceeds_max_words(tmp_path): def test_entry_split_drops_large_words(tmp_path): "Ensure entries drops words larger than specified max word length from compiled version." # Arrange - entry_text = f'''*** Heading + entry_text = f"""*** Heading \t\r Body Line 1 - ''' + """ entry = Entry(raw=entry_text, compiled=entry_text) # Act # Split entry by max words and drop words larger than max word length - processed_entry = TextToJsonl.split_entries_by_max_tokens([entry], max_word_length = 5)[0] + processed_entry = TextToJsonl.split_entries_by_max_tokens([entry], max_word_length=5)[0] # Assert # "Heading" dropped from compiled version because its over the set max word limit @@ -83,13 +85,13 @@ def test_entry_split_drops_large_words(tmp_path): def test_entry_with_body_to_jsonl(tmp_path): "Ensure entries with valid body text are loaded." # Arrange - entry = f'''*** Heading + entry = f"""*** Heading :PROPERTIES: :ID: 42-42-42 :END: \t\r Body Line 1 - ''' + """ orgfile = create_file(tmp_path, entry) # Act @@ -97,7 +99,9 @@ def test_entry_with_body_to_jsonl(tmp_path): entries, entry_to_file_map = OrgToJsonl.extract_org_entries(org_files=[orgfile]) # Process Each Entry from All Notes Files - jsonl_string = OrgToJsonl.convert_org_entries_to_jsonl(OrgToJsonl.convert_org_nodes_to_entries(entries, entry_to_file_map)) + jsonl_string = OrgToJsonl.convert_org_entries_to_jsonl( + OrgToJsonl.convert_org_nodes_to_entries(entries, entry_to_file_map) + ) jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()] # Assert @@ -107,10 +111,10 @@ def test_entry_with_body_to_jsonl(tmp_path): def test_file_with_no_headings_to_jsonl(tmp_path): "Ensure files with no heading, only body text are loaded." 
# Arrange - entry = f''' + entry = f""" - Bullet point 1 - Bullet point 2 - ''' + """ orgfile = create_file(tmp_path, entry) # Act @@ -120,7 +124,7 @@ def test_file_with_no_headings_to_jsonl(tmp_path): # Process Each Entry from All Notes Files entries = OrgToJsonl.convert_org_nodes_to_entries(entry_nodes, file_to_entries) jsonl_string = OrgToJsonl.convert_org_entries_to_jsonl(entries) - jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()] + jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()] # Assert assert len(jsonl_data) == 1 @@ -143,8 +147,8 @@ def test_get_org_files(tmp_path): expected_files = sorted(map(str, [group1_file1, group1_file2, group2_file1, group2_file2, orgfile1])) # Setup input-files, input-filters - input_files = [tmp_path / 'orgfile1.org'] - input_filter = [tmp_path / 'group1*.org', tmp_path / 'group2*.org'] + input_files = [tmp_path / "orgfile1.org"] + input_filter = [tmp_path / "group1*.org", tmp_path / "group2*.org"] # Act extracted_org_files = OrgToJsonl.get_org_files(input_files, input_filter) @@ -157,10 +161,10 @@ def test_get_org_files(tmp_path): def test_extract_entries_with_different_level_headings(tmp_path): "Extract org entries with different level headings." # Arrange - entry = f''' + entry = f""" * Heading 1 ** Heading 2 -''' +""" orgfile = create_file(tmp_path, entry) # Act @@ -169,8 +173,8 @@ def test_extract_entries_with_different_level_headings(tmp_path): # Assert assert len(entries) == 2 - assert f'{entries[0]}'.startswith("* Heading 1") - assert f'{entries[1]}'.startswith("** Heading 2") + assert f"{entries[0]}".startswith("* Heading 1") + assert f"{entries[1]}".startswith("** Heading 2") # Helper Functions diff --git a/tests/test_orgnode.py b/tests/test_orgnode.py index af67deed..53dee212 100644 --- a/tests/test_orgnode.py +++ b/tests/test_orgnode.py @@ -10,7 +10,7 @@ from khoj.processor.org_mode import orgnode def test_parse_entry_with_no_headings(tmp_path): "Test parsing of entry with minimal fields" # Arrange - entry = f'''Body Line 1''' + entry = f"""Body Line 1""" orgfile = create_file(tmp_path, entry) # Act @@ -18,7 +18,7 @@ def test_parse_entry_with_no_headings(tmp_path): # Assert assert len(entries) == 1 - assert entries[0].heading == f'{orgfile}' + assert entries[0].heading == f"{orgfile}" assert entries[0].tags == list() assert entries[0].body == "Body Line 1" assert entries[0].priority == "" @@ -32,9 +32,9 @@ def test_parse_entry_with_no_headings(tmp_path): def test_parse_minimal_entry(tmp_path): "Test parsing of entry with minimal fields" # Arrange - entry = f''' + entry = f""" * Heading -Body Line 1''' +Body Line 1""" orgfile = create_file(tmp_path, entry) # Act @@ -56,7 +56,7 @@ Body Line 1''' def test_parse_complete_entry(tmp_path): "Test parsing of entry with all important fields" # Arrange - entry = f''' + entry = f""" *** DONE [#A] Heading :Tag1:TAG2:tag3: CLOSED: [1984-04-01 Sun 12:00] SCHEDULED: <1984-04-01 Sun 09:00> DEADLINE: <1984-04-01 Sun> :PROPERTIES: @@ -67,7 +67,7 @@ CLOCK: [1984-04-01 Sun 09:00]--[1984-04-01 Sun 12:00] => 3:00 - Clocked Log 1 :END: Body Line 1 -Body Line 2''' +Body Line 2""" orgfile = create_file(tmp_path, entry) # Act @@ -81,45 +81,45 @@ Body Line 2''' assert entries[0].body == "- Clocked Log 1\nBody Line 1\nBody Line 2" assert entries[0].priority == "A" assert entries[0].Property("ID") == "id:123-456-789-4234-1231" - assert entries[0].closed == datetime.date(1984,4,1) - assert entries[0].scheduled == datetime.date(1984,4,1) - assert 
entries[0].deadline == datetime.date(1984,4,1) - assert entries[0].logbook == [(datetime.datetime(1984,4,1,9,0,0), datetime.datetime(1984,4,1,12,0,0))] + assert entries[0].closed == datetime.date(1984, 4, 1) + assert entries[0].scheduled == datetime.date(1984, 4, 1) + assert entries[0].deadline == datetime.date(1984, 4, 1) + assert entries[0].logbook == [(datetime.datetime(1984, 4, 1, 9, 0, 0), datetime.datetime(1984, 4, 1, 12, 0, 0))] # ---------------------------------------------------------------------------------------------------- def test_render_entry_with_property_drawer_and_empty_body(tmp_path): "Render heading entry with property drawer" # Arrange - entry_to_render = f''' + entry_to_render = f""" *** [#A] Heading1 :tag1: :PROPERTIES: :ID: 111-111-111-1111-1111 :END: \t\r \n -''' +""" orgfile = create_file(tmp_path, entry_to_render) - expected_entry = f'''*** [#A] Heading1 :tag1: + expected_entry = f"""*** [#A] Heading1 :tag1: :PROPERTIES: :LINE: file:{orgfile}::2 :ID: id:111-111-111-1111-1111 :SOURCE: [[file:{orgfile}::*Heading1]] :END: -''' +""" # Act parsed_entries = orgnode.makelist(orgfile) # Assert - assert f'{parsed_entries[0]}' == expected_entry + assert f"{parsed_entries[0]}" == expected_entry # ---------------------------------------------------------------------------------------------------- def test_all_links_to_entry_rendered(tmp_path): "Ensure all links to entry rendered in property drawer from entry" # Arrange - entry = f''' + entry = f""" *** [#A] Heading :tag1: :PROPERTIES: :ID: 123-456-789-4234-1231 @@ -127,7 +127,7 @@ def test_all_links_to_entry_rendered(tmp_path): Body Line 1 *** Heading2 Body Line 2 -''' +""" orgfile = create_file(tmp_path, entry) # Act @@ -135,23 +135,23 @@ Body Line 2 # Assert # SOURCE link rendered with Heading - assert f':SOURCE: [[file:{orgfile}::*{entries[0].heading}]]' in f'{entries[0]}' + assert f":SOURCE: [[file:{orgfile}::*{entries[0].heading}]]" in f"{entries[0]}" # ID link rendered with ID - assert f':ID: id:123-456-789-4234-1231' in f'{entries[0]}' + assert f":ID: id:123-456-789-4234-1231" in f"{entries[0]}" # LINE link rendered with line number - assert f':LINE: file:{orgfile}::2' in f'{entries[0]}' + assert f":LINE: file:{orgfile}::2" in f"{entries[0]}" # ---------------------------------------------------------------------------------------------------- def test_source_link_to_entry_escaped_for_rendering(tmp_path): "Test SOURCE link renders with square brackets in filename, heading escaped for org-mode rendering" # Arrange - entry = f''' + entry = f""" *** [#A] Heading[1] :tag1: :PROPERTIES: :ID: 123-456-789-4234-1231 :END: -Body Line 1''' +Body Line 1""" orgfile = create_file(tmp_path, entry, filename="test[1].org") # Act @@ -162,15 +162,15 @@ Body Line 1''' # parsed heading from entry assert entries[0].heading == "Heading[1]" # ensure SOURCE link has square brackets in filename, heading escaped in rendered entries - escaped_orgfile = f'{orgfile}'.replace("[1]", "\\[1\\]") - assert f':SOURCE: [[file:{escaped_orgfile}::*Heading\[1\]' in f'{entries[0]}' + escaped_orgfile = f"{orgfile}".replace("[1]", "\\[1\\]") + assert f":SOURCE: [[file:{escaped_orgfile}::*Heading\[1\]" in f"{entries[0]}" # ---------------------------------------------------------------------------------------------------- def test_parse_multiple_entries(tmp_path): "Test parsing of multiple entries" # Arrange - content = f''' + content = f""" *** FAILED [#A] Heading1 :tag1: CLOSED: [1984-04-01 Sun 12:00] SCHEDULED: <1984-04-01 Sun 09:00> DEADLINE: <1984-04-01 
Sun> :PROPERTIES: @@ -193,7 +193,7 @@ CLOCK: [1984-04-02 Mon 09:00]--[1984-04-02 Mon 12:00] => 3:00 :END: Body 2 -''' +""" orgfile = create_file(tmp_path, content) # Act @@ -208,18 +208,20 @@ Body 2 assert entry.body == f"- Clocked Log {index+1}\nBody {index+1}\n\n" assert entry.priority == "A" assert entry.Property("ID") == f"id:123-456-789-4234-000{index+1}" - assert entry.closed == datetime.date(1984,4,index+1) - assert entry.scheduled == datetime.date(1984,4,index+1) - assert entry.deadline == datetime.date(1984,4,index+1) - assert entry.logbook == [(datetime.datetime(1984,4,index+1,9,0,0), datetime.datetime(1984,4,index+1,12,0,0))] + assert entry.closed == datetime.date(1984, 4, index + 1) + assert entry.scheduled == datetime.date(1984, 4, index + 1) + assert entry.deadline == datetime.date(1984, 4, index + 1) + assert entry.logbook == [ + (datetime.datetime(1984, 4, index + 1, 9, 0, 0), datetime.datetime(1984, 4, index + 1, 12, 0, 0)) + ] # ---------------------------------------------------------------------------------------------------- def test_parse_entry_with_empty_title(tmp_path): "Test parsing of entry with minimal fields" # Arrange - entry = f'''#+TITLE: -Body Line 1''' + entry = f"""#+TITLE: +Body Line 1""" orgfile = create_file(tmp_path, entry) # Act @@ -227,7 +229,7 @@ Body Line 1''' # Assert assert len(entries) == 1 - assert entries[0].heading == f'{orgfile}' + assert entries[0].heading == f"{orgfile}" assert entries[0].tags == list() assert entries[0].body == "Body Line 1" assert entries[0].priority == "" @@ -241,8 +243,8 @@ Body Line 1''' def test_parse_entry_with_title_and_no_headings(tmp_path): "Test parsing of entry with minimal fields" # Arrange - entry = f'''#+TITLE: test -Body Line 1''' + entry = f"""#+TITLE: test +Body Line 1""" orgfile = create_file(tmp_path, entry) # Act @@ -250,7 +252,7 @@ Body Line 1''' # Assert assert len(entries) == 1 - assert entries[0].heading == 'test' + assert entries[0].heading == "test" assert entries[0].tags == list() assert entries[0].body == "Body Line 1" assert entries[0].priority == "" @@ -264,9 +266,9 @@ Body Line 1''' def test_parse_entry_with_multiple_titles_and_no_headings(tmp_path): "Test parsing of entry with minimal fields" # Arrange - entry = f'''#+TITLE: title1 + entry = f"""#+TITLE: title1 Body Line 1 -#+TITLE: title2 ''' +#+TITLE: title2 """ orgfile = create_file(tmp_path, entry) # Act @@ -274,7 +276,7 @@ Body Line 1 # Assert assert len(entries) == 1 - assert entries[0].heading == 'title1 title2' + assert entries[0].heading == "title1 title2" assert entries[0].tags == list() assert entries[0].body == "Body Line 1\n" assert entries[0].priority == "" diff --git a/tests/test_text_search.py b/tests/test_text_search.py index 412bceec..871af227 100644 --- a/tests/test_text_search.py +++ b/tests/test_text_search.py @@ -14,7 +14,9 @@ from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl # Test # ---------------------------------------------------------------------------------------------------- -def test_asymmetric_setup_with_missing_file_raises_error(org_config_with_only_new_file: TextContentConfig, search_config: SearchConfig): +def test_asymmetric_setup_with_missing_file_raises_error( + org_config_with_only_new_file: TextContentConfig, search_config: SearchConfig +): # Arrange # Ensure file mentioned in org.input-files is missing single_new_file = Path(org_config_with_only_new_file.input_files[0]) @@ -27,10 +29,12 @@ def test_asymmetric_setup_with_missing_file_raises_error(org_config_with_only_ne # 
---------------------------------------------------------------------------------------------------- -def test_asymmetric_setup_with_empty_file_raises_error(org_config_with_only_new_file: TextContentConfig, search_config: SearchConfig): +def test_asymmetric_setup_with_empty_file_raises_error( + org_config_with_only_new_file: TextContentConfig, search_config: SearchConfig +): # Act # Generate notes embeddings during asymmetric setup - with pytest.raises(ValueError, match=r'^No valid entries found*'): + with pytest.raises(ValueError, match=r"^No valid entries found*"): text_search.setup(OrgToJsonl, org_config_with_only_new_file, search_config.asymmetric, regenerate=True) @@ -52,15 +56,9 @@ def test_asymmetric_search(content_config: ContentConfig, search_config: SearchC query = "How to git install application?" # Act - hits, entries = text_search.query( - query, - model = model.notes_search, - rank_results=True) + hits, entries = text_search.query(query, model=model.notes_search, rank_results=True) - results = text_search.collate_results( - hits, - entries, - count=1) + results = text_search.collate_results(hits, entries, count=1) # Assert # Actual_data should contain "Khoj via Emacs" entry @@ -76,12 +74,14 @@ def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: TextContent new_file_to_index = Path(org_config_with_only_new_file.input_files[0]) with open(new_file_to_index, "w") as f: f.write(f"* Entry more than {max_tokens} words\n") - for index in range(max_tokens+1): + for index in range(max_tokens + 1): f.write(f"{index} ") # Act # reload embeddings, entries, notes model after adding new org-mode file - initial_notes_model = text_search.setup(OrgToJsonl, org_config_with_only_new_file, search_config.asymmetric, regenerate=False) + initial_notes_model = text_search.setup( + OrgToJsonl, org_config_with_only_new_file, search_config.asymmetric, regenerate=False + ) # Assert # verify newly added org-mode entry is split by max tokens @@ -92,18 +92,20 @@ def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: TextContent # ---------------------------------------------------------------------------------------------------- def test_asymmetric_reload(content_config: ContentConfig, search_config: SearchConfig, new_org_file: Path): # Arrange - initial_notes_model= text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True) + initial_notes_model = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True) assert len(initial_notes_model.entries) == 10 assert len(initial_notes_model.corpus_embeddings) == 10 # append org-mode entry to first org input file in config - content_config.org.input_files = [f'{new_org_file}'] + content_config.org.input_files = [f"{new_org_file}"] with open(new_org_file, "w") as f: f.write("\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n") # regenerate notes jsonl, model embeddings and model to include entry from new file - regenerated_notes_model = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True) + regenerated_notes_model = text_search.setup( + OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True + ) # Act # reload embeddings, entries, notes model from previously generated notes jsonl and model embeddings files @@ -137,7 +139,7 @@ def test_incremental_update(content_config: ContentConfig, search_config: Search # Act # update embeddings, entries with the newly added note - 
content_config.org.input_files = [f'{new_org_file}'] + content_config.org.input_files = [f"{new_org_file}"] initial_notes_model = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False) # Assert diff --git a/tests/test_word_filter.py b/tests/test_word_filter.py index 4a7c894b..82d0dce8 100644 --- a/tests/test_word_filter.py +++ b/tests/test_word_filter.py @@ -7,7 +7,7 @@ def test_no_word_filter(): # Arrange word_filter = WordFilter() entries = arrange_content() - q_with_no_filter = 'head tail' + q_with_no_filter = "head tail" # Act can_filter = word_filter.can_filter(q_with_no_filter) @@ -15,7 +15,7 @@ def test_no_word_filter(): # Assert assert can_filter == False - assert ret_query == 'head tail' + assert ret_query == "head tail" assert entry_indices == {0, 1, 2, 3} @@ -31,7 +31,7 @@ def test_word_exclude_filter(): # Assert assert can_filter == True - assert ret_query == 'head tail' + assert ret_query == "head tail" assert entry_indices == {0, 2} @@ -47,7 +47,7 @@ def test_word_include_filter(): # Assert assert can_filter == True - assert ret_query == 'head tail' + assert ret_query == "head tail" assert entry_indices == {2, 3} @@ -63,16 +63,16 @@ def test_word_include_and_exclude_filter(): # Assert assert can_filter == True - assert ret_query == 'head tail' + assert ret_query == "head tail" assert entry_indices == {2} def arrange_content(): entries = [ - Entry(compiled='', raw='Minimal Entry'), - Entry(compiled='', raw='Entry with exclude_word'), - Entry(compiled='', raw='Entry with include_word'), - Entry(compiled='', raw='Entry with include_word and exclude_word') + Entry(compiled="", raw="Minimal Entry"), + Entry(compiled="", raw="Entry with exclude_word"), + Entry(compiled="", raw="Entry with include_word"), + Entry(compiled="", raw="Entry with include_word and exclude_word"), ] return entries
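
The hunks above only re-quote strings, normalize spacing around keyword arguments, and re-wrap long call sites; no test logic changes. A minimal sketch for confirming the tree stays formatted after this patch is applied (the script name check_formatting.py and the choice to shell out to the black CLI are illustrative assumptions, not part of this change; Black picks up its line length from the [tool.black] section in pyproject.toml, so no extra flags are needed):

# check_formatting.py -- hedged sketch, not part of this patch: run Black in
# check mode over the reformatted packages. "--check --diff" only reports
# differences and writes nothing to disk.
import subprocess
import sys

result = subprocess.run(["black", "--check", "--diff", "src/khoj", "tests"])
sys.exit(result.returncode)

Run locally or in CI, this exits 0 when the formatting matches and prints the offending diff otherwise.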