Use Black to format Khoj server code and tests

Debanjum Singh Solanky 2023-02-17 10:04:26 -06:00
parent 6130fddf45
commit 5e83baab21
44 changed files with 1167 additions and 915 deletions

View file

@@ -64,7 +64,8 @@ khoj = "khoj.main:run"
 [project.optional-dependencies]
 test = [
-    "pytest == 7.1.2",
+    "pytest >= 7.1.2",
+    "black >= 23.1.0",
 ]
 dev = ["khoj-assistant[test]"]
@@ -88,3 +89,6 @@ exclude = [
     "src/khoj/interface/desktop/file_browser.py",
     "src/khoj/interface/desktop/system_tray.py",
 ]
+
+[tool.black]
+line-length = 120
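
The [tool.black] table added above is read automatically by Black, so running "black ." from the repository root formats code to the configured 120-character line length. As a rough, illustrative sketch only (not part of this commit; the sample source string is made up), the same setting can also be exercised through Black's Python API, assuming the black >= 23.1.0 dependency pinned above:

    # Illustrative sketch: format a throwaway snippet with the repository's 120-character line length.
    import black

    sample = "x = {  'a':37,'b':42,\n'c':927}\n"  # hypothetical, badly formatted input
    print(black.format_str(sample, mode=black.Mode(line_length=120)))  # prints: x = {"a": 37, "b": 42, "c": 927}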

View file

@@ -26,10 +26,12 @@ logger = logging.getLogger(__name__)
 def configure_server(args, required=False):
     if args.config is None:
         if required:
-            logger.error(f'Exiting as Khoj is not configured.\nConfigure it via GUI or by editing {state.config_file}.')
+            logger.error(f"Exiting as Khoj is not configured.\nConfigure it via GUI or by editing {state.config_file}.")
             sys.exit(1)
         else:
-            logger.warn(f'Khoj is not configured.\nConfigure it via khoj GUI, plugins or by editing {state.config_file}.')
+            logger.warn(
+                f"Khoj is not configured.\nConfigure it via khoj GUI, plugins or by editing {state.config_file}."
+            )
             return
     else:
         state.config = args.config
@@ -60,7 +62,8 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool,
             config.content_type.org,
             search_config=config.search_type.asymmetric,
            regenerate=regenerate,
-            filters=[DateFilter(), WordFilter(), FileFilter()])
+            filters=[DateFilter(), WordFilter(), FileFilter()],
+        )

     # Initialize Org Music Search
     if (t == SearchType.Music or t == None) and config.content_type.music:
@@ -70,7 +73,8 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool,
             config.content_type.music,
             search_config=config.search_type.asymmetric,
             regenerate=regenerate,
-            filters=[DateFilter(), WordFilter()])
+            filters=[DateFilter(), WordFilter()],
+        )

     # Initialize Markdown Search
     if (t == SearchType.Markdown or t == None) and config.content_type.markdown:
@@ -80,7 +84,8 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool,
             config.content_type.markdown,
             search_config=config.search_type.asymmetric,
             regenerate=regenerate,
-            filters=[DateFilter(), WordFilter(), FileFilter()])
+            filters=[DateFilter(), WordFilter(), FileFilter()],
+        )

     # Initialize Ledger Search
     if (t == SearchType.Ledger or t == None) and config.content_type.ledger:
@@ -90,15 +95,15 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool,
             config.content_type.ledger,
             search_config=config.search_type.symmetric,
             regenerate=regenerate,
-            filters=[DateFilter(), WordFilter(), FileFilter()])
+            filters=[DateFilter(), WordFilter(), FileFilter()],
+        )

     # Initialize Image Search
     if (t == SearchType.Image or t == None) and config.content_type.image:
         # Extract Entries, Generate Image Embeddings
         model.image_search = image_search.setup(
-            config.content_type.image,
-            search_config=config.search_type.image,
-            regenerate=regenerate)
+            config.content_type.image, search_config=config.search_type.image, regenerate=regenerate
+        )

     # Invalidate Query Cache
     state.query_cache = LRU()
@@ -125,9 +130,9 @@ def configure_conversation_processor(conversation_processor_config):
     if conversation_logfile.is_file():
         # Load Metadata Logs from Conversation Logfile
-        with conversation_logfile.open('r') as f:
+        with conversation_logfile.open("r") as f:
             conversation_processor.meta_log = json.load(f)
-        logger.info('Conversation logs loaded from disk.')
+        logger.info("Conversation logs loaded from disk.")
     else:
         # Initialize Conversation Logs
         conversation_processor.meta_log = {}

View file

@@ -8,7 +8,7 @@ from khoj.utils.helpers import is_none_or_empty
 class FileBrowser(QtWidgets.QWidget):
-    def __init__(self, title, search_type: SearchType=None, default_files:list=[]):
+    def __init__(self, title, search_type: SearchType = None, default_files: list = []):
         QtWidgets.QWidget.__init__(self)
         layout = QtWidgets.QHBoxLayout()
         self.setLayout(layout)
@@ -22,51 +22,54 @@ class FileBrowser(QtWidgets.QWidget):
         self.label.setFixedWidth(95)
         self.label.setWordWrap(True)
         layout.addWidget(self.label)

         self.lineEdit = QtWidgets.QPlainTextEdit(self)
         self.lineEdit.setFixedWidth(330)
         self.setFiles(default_files)

-        self.lineEdit.setFixedHeight(min(7+20*len(self.lineEdit.toPlainText().split('\n')),90))
+        self.lineEdit.setFixedHeight(min(7 + 20 * len(self.lineEdit.toPlainText().split("\n")), 90))
         self.lineEdit.textChanged.connect(self.updateFieldHeight)
         layout.addWidget(self.lineEdit)

-        self.button = QtWidgets.QPushButton('Add')
+        self.button = QtWidgets.QPushButton("Add")
         self.button.clicked.connect(self.storeFilesSelectedInFileDialog)
         layout.addWidget(self.button)
         layout.addStretch()

     def getFileFilter(self, search_type):
         if search_type == SearchType.Org:
-            return 'Org-Mode Files (*.org)'
+            return "Org-Mode Files (*.org)"
         elif search_type == SearchType.Ledger:
-            return 'Beancount Files (*.bean *.beancount)'
+            return "Beancount Files (*.bean *.beancount)"
         elif search_type == SearchType.Markdown:
-            return 'Markdown Files (*.md *.markdown)'
+            return "Markdown Files (*.md *.markdown)"
         elif search_type == SearchType.Music:
-            return 'Org-Music Files (*.org)'
+            return "Org-Music Files (*.org)"
         elif search_type == SearchType.Image:
-            return 'Images (*.jp[e]g)'
+            return "Images (*.jp[e]g)"

     def storeFilesSelectedInFileDialog(self):
         filepaths = self.getPaths()
         if self.search_type == SearchType.Image:
-            filepaths.append(QtWidgets.QFileDialog.getExistingDirectory(self, caption='Choose Folder',
-                                                                        directory=self.dirpath))
+            filepaths.append(
+                QtWidgets.QFileDialog.getExistingDirectory(self, caption="Choose Folder", directory=self.dirpath)
+            )
         else:
-            filepaths.extend(QtWidgets.QFileDialog.getOpenFileNames(self, caption='Choose Files',
-                                                                    directory=self.dirpath,
-                                                                    filter=self.filter_name)[0])
+            filepaths.extend(
+                QtWidgets.QFileDialog.getOpenFileNames(
+                    self, caption="Choose Files", directory=self.dirpath, filter=self.filter_name
+                )[0]
+            )
         self.setFiles(filepaths)

-    def setFiles(self, paths:list):
+    def setFiles(self, paths: list):
         self.filepaths = [path for path in paths if not is_none_or_empty(path)]
         self.lineEdit.setPlainText("\n".join(self.filepaths))

     def getPaths(self) -> list:
-        if self.lineEdit.toPlainText() == '':
+        if self.lineEdit.toPlainText() == "":
             return []
         else:
-            return self.lineEdit.toPlainText().split('\n')
+            return self.lineEdit.toPlainText().split("\n")

     def updateFieldHeight(self):
-        self.lineEdit.setFixedHeight(min(7+20*len(self.lineEdit.toPlainText().split('\n')),90))
+        self.lineEdit.setFixedHeight(min(7 + 20 * len(self.lineEdit.toPlainText().split("\n")), 90))

View file

@@ -6,7 +6,7 @@ from khoj.utils.config import ProcessorType
 class LabelledTextField(QtWidgets.QWidget):
-    def __init__(self, title, processor_type: ProcessorType=None, default_value: str=None):
+    def __init__(self, title, processor_type: ProcessorType = None, default_value: str = None):
         QtWidgets.QWidget.__init__(self)
         layout = QtWidgets.QHBoxLayout()
         self.setLayout(layout)

View file

@@ -31,9 +31,9 @@ class MainWindow(QtWidgets.QMainWindow):
         self.config_file = config_file
         # Set regenerate flag to regenerate embeddings everytime user clicks configure
         if state.cli_args:
-            state.cli_args += ['--regenerate']
+            state.cli_args += ["--regenerate"]
         else:
-            state.cli_args = ['--regenerate']
+            state.cli_args = ["--regenerate"]

         # Load config from existing config, if exists, else load from default config
         if resolve_absolute_path(self.config_file).exists():
@@ -49,8 +49,8 @@ class MainWindow(QtWidgets.QMainWindow):
         self.setFixedWidth(600)

         # Set Window Icon
-        icon_path = constants.web_directory / 'assets/icons/favicon-144x144.png'
-        self.setWindowIcon(QtGui.QIcon(f'{icon_path.absolute()}'))
+        icon_path = constants.web_directory / "assets/icons/favicon-144x144.png"
+        self.setWindowIcon(QtGui.QIcon(f"{icon_path.absolute()}"))

         # Initialize Configure Window Layout
         self.layout = QtWidgets.QVBoxLayout()
@@ -58,13 +58,13 @@ class MainWindow(QtWidgets.QMainWindow):
         # Add Settings Panels for each Search Type to Configure Window Layout
         self.search_settings_panels = []
         for search_type in SearchType:
-            current_content_config = self.current_config['content-type'].get(search_type, {})
+            current_content_config = self.current_config["content-type"].get(search_type, {})
             self.search_settings_panels += [self.add_settings_panel(current_content_config, search_type)]

         # Add Conversation Processor Panel to Configure Screen
         self.processor_settings_panels = []
         conversation_type = ProcessorType.Conversation
-        current_conversation_config = self.current_config['processor'].get(conversation_type, {})
+        current_conversation_config = self.current_config["processor"].get(conversation_type, {})
         self.processor_settings_panels += [self.add_processor_panel(current_conversation_config, conversation_type)]

         # Add Action Buttons Panel
@@ -81,11 +81,11 @@ class MainWindow(QtWidgets.QMainWindow):
         "Add Settings Panel for specified Search Type. Toggle Editable Search Types"
         # Get current files from config for given search type
         if search_type == SearchType.Image:
-            current_content_files = current_content_config.get('input-directories', [])
-            file_input_text = f'{search_type.name} Folders'
+            current_content_files = current_content_config.get("input-directories", [])
+            file_input_text = f"{search_type.name} Folders"
         else:
-            current_content_files = current_content_config.get('input-files', [])
-            file_input_text = f'{search_type.name} Files'
+            current_content_files = current_content_config.get("input-files", [])
+            file_input_text = f"{search_type.name} Files"

         # Create widgets to display settings for given search type
         search_type_settings = QtWidgets.QWidget()
@@ -109,7 +109,7 @@ class MainWindow(QtWidgets.QMainWindow):
     def add_processor_panel(self, current_conversation_config: dict, processor_type: ProcessorType):
         "Add Conversation Processor Panel"
         # Get current settings from config for given processor type
-        current_openai_api_key = current_conversation_config.get('openai-api-key', None)
+        current_openai_api_key = current_conversation_config.get("openai-api-key", None)

         # Create widgets to display settings for given processor type
         processor_type_settings = QtWidgets.QWidget()
@@ -137,20 +137,22 @@ class MainWindow(QtWidgets.QMainWindow):
         action_bar_layout = QtWidgets.QHBoxLayout(action_bar)

         self.configure_button = QtWidgets.QPushButton("Configure", clicked=self.configure_app)
-        self.search_button = QtWidgets.QPushButton("Search", clicked=lambda: webbrowser.open(f'http://{state.host}:{state.port}/'))
+        self.search_button = QtWidgets.QPushButton(
+            "Search", clicked=lambda: webbrowser.open(f"http://{state.host}:{state.port}/")
+        )
         self.search_button.setEnabled(not self.first_run)

         action_bar_layout.addWidget(self.configure_button)
         action_bar_layout.addWidget(self.search_button)
         self.layout.addWidget(action_bar)

-    def get_default_config(self, search_type:SearchType=None, processor_type:ProcessorType=None):
+    def get_default_config(self, search_type: SearchType = None, processor_type: ProcessorType = None):
         "Get default config"
         config = constants.default_config
         if search_type:
-            return config['content-type'][search_type]
+            return config["content-type"][search_type]
         elif processor_type:
-            return config['processor'][processor_type]
+            return config["processor"][processor_type]
         else:
             return config
@@ -160,7 +162,9 @@ class MainWindow(QtWidgets.QMainWindow):
         for message_prefix in ErrorType:
             for i in reversed(range(self.layout.count())):
                 current_widget = self.layout.itemAt(i).widget()
-                if isinstance(current_widget, QtWidgets.QLabel) and current_widget.text().startswith(message_prefix.value):
+                if isinstance(current_widget, QtWidgets.QLabel) and current_widget.text().startswith(
+                    message_prefix.value
+                ):
                     self.layout.removeWidget(current_widget)
                     current_widget.deleteLater()
@@ -180,18 +184,24 @@ class MainWindow(QtWidgets.QMainWindow):
                     continue
                if isinstance(child, SearchCheckBox):
                     # Search Type Disabled
-                    if not child.isChecked() and child.search_type in self.new_config['content-type']:
-                        del self.new_config['content-type'][child.search_type]
+                    if not child.isChecked() and child.search_type in self.new_config["content-type"]:
+                        del self.new_config["content-type"][child.search_type]
                     # Search Type (re)-Enabled
                     if child.isChecked():
-                        current_search_config = self.current_config['content-type'].get(child.search_type, {})
-                        default_search_config = self.get_default_config(search_type = child.search_type)
-                        self.new_config['content-type'][child.search_type.value] = merge_dicts(current_search_config, default_search_config)
-                elif isinstance(child, FileBrowser) and child.search_type in self.new_config['content-type']:
+                        current_search_config = self.current_config["content-type"].get(child.search_type, {})
+                        default_search_config = self.get_default_config(search_type=child.search_type)
+                        self.new_config["content-type"][child.search_type.value] = merge_dicts(
+                            current_search_config, default_search_config
+                        )
+                elif isinstance(child, FileBrowser) and child.search_type in self.new_config["content-type"]:
                     if child.search_type.value == SearchType.Image:
-                        self.new_config['content-type'][child.search_type.value]['input-directories'] = child.getPaths() if child.getPaths() != [] else None
+                        self.new_config["content-type"][child.search_type.value]["input-directories"] = (
+                            child.getPaths() if child.getPaths() != [] else None
+                        )
                     else:
-                        self.new_config['content-type'][child.search_type.value]['input-files'] = child.getPaths() if child.getPaths() != [] else None
+                        self.new_config["content-type"][child.search_type.value]["input-files"] = (
+                            child.getPaths() if child.getPaths() != [] else None
+                        )

     def update_processor_settings(self):
         "Update config with conversation settings from UI"
@@ -201,16 +211,20 @@ class MainWindow(QtWidgets.QMainWindow):
                     continue
                 if isinstance(child, ProcessorCheckBox):
                     # Processor Type Disabled
-                    if not child.isChecked() and child.processor_type in self.new_config['processor']:
-                        del self.new_config['processor'][child.processor_type]
+                    if not child.isChecked() and child.processor_type in self.new_config["processor"]:
+                        del self.new_config["processor"][child.processor_type]
                     # Processor Type (re)-Enabled
                     if child.isChecked():
-                        current_processor_config = self.current_config['processor'].get(child.processor_type, {})
-                        default_processor_config = self.get_default_config(processor_type = child.processor_type)
-                        self.new_config['processor'][child.processor_type.value] = merge_dicts(current_processor_config, default_processor_config)
-                elif isinstance(child, LabelledTextField) and child.processor_type in self.new_config['processor']:
+                        current_processor_config = self.current_config["processor"].get(child.processor_type, {})
+                        default_processor_config = self.get_default_config(processor_type=child.processor_type)
+                        self.new_config["processor"][child.processor_type.value] = merge_dicts(
+                            current_processor_config, default_processor_config
+                        )
+                elif isinstance(child, LabelledTextField) and child.processor_type in self.new_config["processor"]:
                     if child.processor_type == ProcessorType.Conversation:
-                        self.new_config['processor'][child.processor_type.value]['openai-api-key'] = child.input_field.toPlainText() if child.input_field.toPlainText() != '' else None
+                        self.new_config["processor"][child.processor_type.value]["openai-api-key"] = (
+                            child.input_field.toPlainText() if child.input_field.toPlainText() != "" else None
+                        )

     def save_settings_to_file(self) -> bool:
         "Save validated settings to file"
@@ -278,7 +292,7 @@ class MainWindow(QtWidgets.QMainWindow):
         self.show()
         self.setWindowState(Qt.WindowState.WindowActive)
-        self.activateWindow() # For Bringing to Top on Windows
-        self.raise_() # For Bringing to Top from Minimized State on OSX
+        self.activateWindow()  # For Bringing to Top on Windows
+        self.raise_()  # For Bringing to Top from Minimized State on OSX


 class SettingsLoader(QObject):
@@ -312,6 +326,7 @@ class ProcessorCheckBox(QtWidgets.QCheckBox):
         self.processor_type = processor_type
         super(ProcessorCheckBox, self).__init__(text, parent=parent)

+
 class ErrorType(Enum):
     "Error Types"
     ConfigLoadingError = "Config Loading Error"

View file

@@ -17,17 +17,17 @@ def create_system_tray(gui: QtWidgets.QApplication, main_window: MainWindow):
     """

     # Create the system tray with icon
-    icon_path = constants.web_directory / 'assets/icons/favicon-144x144.png'
-    icon = QtGui.QIcon(f'{icon_path.absolute()}')
+    icon_path = constants.web_directory / "assets/icons/favicon-144x144.png"
+    icon = QtGui.QIcon(f"{icon_path.absolute()}")
     tray = QtWidgets.QSystemTrayIcon(icon)
     tray.setVisible(True)

     # Create the menu and menu actions
     menu = QtWidgets.QMenu()
     menu_actions = [
-        ('Search', lambda: webbrowser.open(f'http://{state.host}:{state.port}/')),
-        ('Configure', main_window.show_on_top),
-        ('Quit', gui.quit),
+        ("Search", lambda: webbrowser.open(f"http://{state.host}:{state.port}/")),
+        ("Configure", main_window.show_on_top),
+        ("Quit", gui.quit),
     ]

     # Add the menu actions to the menu

View file

@@ -8,8 +8,8 @@ import warnings
 from platform import system

 # Ignore non-actionable warnings
-warnings.filterwarnings("ignore", message=r'snapshot_download.py has been made private', category=FutureWarning)
-warnings.filterwarnings("ignore", message=r'legacy way to download files from the HF hub,', category=FutureWarning)
+warnings.filterwarnings("ignore", message=r"snapshot_download.py has been made private", category=FutureWarning)
+warnings.filterwarnings("ignore", message=r"legacy way to download files from the HF hub,", category=FutureWarning)

 # External Packages
 import uvicorn
@@ -43,11 +43,12 @@ rich_handler = RichHandler(rich_tracebacks=True)
 rich_handler.setFormatter(fmt=logging.Formatter(fmt="%(message)s", datefmt="[%X]"))
 logging.basicConfig(handlers=[rich_handler])

-logger = logging.getLogger('khoj')
+logger = logging.getLogger("khoj")
+

 def run():
     # Turn Tokenizers Parallelism Off. App does not support it.
-    os.environ["TOKENIZERS_PARALLELISM"] = 'false'
+    os.environ["TOKENIZERS_PARALLELISM"] = "false"

     # Load config from CLI
     state.cli_args = sys.argv[1:]
@@ -66,7 +67,7 @@ def run():
         logger.setLevel(logging.DEBUG)

         # Set Log File
-        fh = logging.FileHandler(state.config_file.parent / 'khoj.log')
+        fh = logging.FileHandler(state.config_file.parent / "khoj.log")
         fh.setLevel(logging.DEBUG)
         logger.addHandler(fh)
@@ -87,7 +88,7 @@ def run():
         # On Linux (Gnome) the System tray is not supported.
         # Since only the Main Window is available
         # Quitting it should quit the application
-        if system() in ['Windows', 'Darwin']:
+        if system() in ["Windows", "Darwin"]:
            gui.setQuitOnLastWindowClosed(False)
            tray = create_system_tray(gui, main_window)
            tray.show()
@@ -97,7 +98,7 @@ def run():
         server = ServerThread(app, args.host, args.port, args.socket)

         # Show Main Window on First Run Experience or if on Linux
-        if args.config is None or system() not in ['Windows', 'Darwin']:
+        if args.config is None or system() not in ["Windows", "Darwin"]:
             main_window.show()

         # Setup Signal Handlers
@@ -112,9 +113,10 @@ def run():
         gui.aboutToQuit.connect(server.terminate)

         # Close Splash Screen if still open
-        if system() != 'Darwin':
+        if system() != "Darwin":
             try:
                 import pyi_splash
+
                 # Update the text on the splash screen
                 pyi_splash.update_text("Khoj setup complete")
                 # Close Splash Screen
@@ -167,5 +169,5 @@ class ServerThread(QThread):
         start_server(self.app, self.host, self.port, self.socket)


-if __name__ == '__main__':
+if __name__ == "__main__":
     run()

View file

@@ -19,31 +19,27 @@ def summarize(text, summary_type, model, user_query=None, api_key=None, temperat
     # Setup Prompt based on Summary Type
     if summary_type == "chat":
-        prompt = f'''
+        prompt = f"""
 You are an AI. Summarize the conversation below from your perspective:

 {text}

-Summarize the conversation from the AI's first-person perspective:'''
+Summarize the conversation from the AI's first-person perspective:"""
     elif summary_type == "notes":
-        prompt = f'''
+        prompt = f"""
 Summarize the below notes about {user_query}:

 {text}

-Summarize the notes in second person perspective:'''
+Summarize the notes in second person perspective:"""

     # Get Response from GPT
     response = openai.Completion.create(
-        prompt=prompt,
-        model=model,
-        temperature=temperature,
-        max_tokens=max_tokens,
-        frequency_penalty=0.2,
-        stop="\"\"\"")
+        prompt=prompt, model=model, temperature=temperature, max_tokens=max_tokens, frequency_penalty=0.2, stop='"""'
+    )

     # Extract, Clean Message from GPT's Response
-    story = response['choices'][0]['text']
+    story = response["choices"][0]["text"]
     return str(story).replace("\n\n", "")
@@ -53,7 +49,7 @@ def extract_search_type(text, model, api_key=None, temperature=0.5, max_tokens=1
     """
     # Initialize Variables
     openai.api_key = api_key or os.getenv("OPENAI_API_KEY")
-    understand_primer = '''
+    understand_primer = """
 Objective: Extract search type from user query and return information as JSON

 Allowed search types are listed below:
@@ -73,7 +69,7 @@ A:{ "search-type": "notes" }
 Q: When did I buy Groceries last?
 A:{ "search-type": "ledger" }

 Q:When did I go surfing last?
-A:{ "search-type": "notes" }'''
+A:{ "search-type": "notes" }"""

     # Setup Prompt with Understand Primer
     prompt = message_to_prompt(text, understand_primer, start_sequence="\nA:", restart_sequence="\nQ:")
@@ -82,15 +78,11 @@ A:{ "search-type": "notes" }'''

     # Get Response from GPT
     response = openai.Completion.create(
-        prompt=prompt,
-        model=model,
-        temperature=temperature,
-        max_tokens=max_tokens,
-        frequency_penalty=0.2,
-        stop=["\n"])
+        prompt=prompt, model=model, temperature=temperature, max_tokens=max_tokens, frequency_penalty=0.2, stop=["\n"]
+    )

     # Extract, Clean Message from GPT's Response
-    story = str(response['choices'][0]['text'])
+    story = str(response["choices"][0]["text"])
     return json.loads(story.strip(empty_escape_sequences))
@@ -100,7 +92,7 @@ def understand(text, model, api_key=None, temperature=0.5, max_tokens=100, verbo
     """
     # Initialize Variables
     openai.api_key = api_key or os.getenv("OPENAI_API_KEY")
-    understand_primer = '''
+    understand_primer = """
 Objective: Extract intent and trigger emotion information as JSON from each chat message

 Potential intent types and valid argument values are listed below:
@@ -142,7 +134,7 @@ A: { "intent": {"type": "remember", "memory-type": "notes", "query": "recommend
 Q: When did I go surfing last?
 A: { "intent": {"type": "remember", "memory-type": "notes", "query": "When did I go surfing last"}, "trigger-emotion": "calm" }

 Q: Can you dance for me?
-A: { "intent": {"type": "generate", "activity": "chat", "query": "Can you dance for me?"}, "trigger-emotion": "sad" }'''
+A: { "intent": {"type": "generate", "activity": "chat", "query": "Can you dance for me?"}, "trigger-emotion": "sad" }"""

     # Setup Prompt with Understand Primer
     prompt = message_to_prompt(text, understand_primer, start_sequence="\nA:", restart_sequence="\nQ:")
@@ -151,15 +143,11 @@ A: { "intent": {"type": "generate", "activity": "chat", "query": "Can you dance

     # Get Response from GPT
     response = openai.Completion.create(
-        prompt=prompt,
-        model=model,
-        temperature=temperature,
-        max_tokens=max_tokens,
-        frequency_penalty=0.2,
-        stop=["\n"])
+        prompt=prompt, model=model, temperature=temperature, max_tokens=max_tokens, frequency_penalty=0.2, stop=["\n"]
+    )

     # Extract, Clean Message from GPT's Response
-    story = str(response['choices'][0]['text'])
+    story = str(response["choices"][0]["text"])
     return json.loads(story.strip(empty_escape_sequences))
@@ -171,15 +159,15 @@ def converse(text, model, conversation_history=None, api_key=None, temperature=0
     max_words = 500
     openai.api_key = api_key or os.getenv("OPENAI_API_KEY")

-    conversation_primer = f'''
+    conversation_primer = f"""
 The following is a conversation with an AI assistant. The assistant is helpful, creative, clever, and a very friendly companion.

 Human: Hello, who are you?
-AI: Hi, I am an AI conversational companion created by OpenAI. How can I help you today?'''
+AI: Hi, I am an AI conversational companion created by OpenAI. How can I help you today?"""

     # Setup Prompt with Primer or Conversation History
     prompt = message_to_prompt(text, conversation_history or conversation_primer)
-    prompt = ' '.join(prompt.split()[:max_words])
+    prompt = " ".join(prompt.split()[:max_words])

     # Get Response from GPT
     response = openai.Completion.create(
@@ -188,14 +176,17 @@ AI: Hi, I am an AI conversational companion created by OpenAI. How can I help yo
         temperature=temperature,
         max_tokens=max_tokens,
         presence_penalty=0.6,
-        stop=["\n", "Human:", "AI:"])
+        stop=["\n", "Human:", "AI:"],
+    )

     # Extract, Clean Message from GPT's Response
-    story = str(response['choices'][0]['text'])
+    story = str(response["choices"][0]["text"])
     return story.strip(empty_escape_sequences)


-def message_to_prompt(user_message, conversation_history="", gpt_message=None, start_sequence="\nAI:", restart_sequence="\nHuman:"):
+def message_to_prompt(
+    user_message, conversation_history="", gpt_message=None, start_sequence="\nAI:", restart_sequence="\nHuman:"
+):
     """Create prompt for GPT from messages and conversation history"""
     gpt_message = f" {gpt_message}" if gpt_message else ""
@@ -205,12 +196,8 @@ def message_to_prompt(user_message, conversation_history="", gpt_message=None, s
 def message_to_log(user_message, gpt_message, user_message_metadata={}, conversation_log=[]):
     """Create json logs from messages, metadata for conversation log"""
     default_user_message_metadata = {
-        "intent": {
-            "type": "remember",
-            "memory-type": "notes",
-            "query": user_message
-        },
-        "trigger-emotion": "calm"
+        "intent": {"type": "remember", "memory-type": "notes", "query": user_message},
+        "trigger-emotion": "calm",
     }
     current_dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
@@ -229,5 +216,4 @@ def message_to_log(user_message, gpt_message, user_message_metadata={}, conversa
 def extract_summaries(metadata):
     """Extract summaries from metadata"""
-    return ''.join(
-        [f'\n{session["summary"]}' for session in metadata])
+    return "".join([f'\n{session["summary"]}' for session in metadata])

View file

@@ -19,7 +19,11 @@ class BeancountToJsonl(TextToJsonl):
     # Define Functions
     def process(self, previous_entries=None):
         # Extract required fields from config
-        beancount_files, beancount_file_filter, output_file = self.config.input_files, self.config.input_filter,self.config.compressed_jsonl
+        beancount_files, beancount_file_filter, output_file = (
+            self.config.input_files,
+            self.config.input_filter,
+            self.config.compressed_jsonl,
+        )

         # Input Validation
         if is_none_or_empty(beancount_files) and is_none_or_empty(beancount_file_filter):
@@ -31,7 +35,9 @@ class BeancountToJsonl(TextToJsonl):

         # Extract Entries from specified Beancount files
         with timer("Parse transactions from Beancount files into dictionaries", logger):
-            current_entries = BeancountToJsonl.convert_transactions_to_maps(*BeancountToJsonl.extract_beancount_transactions(beancount_files))
+            current_entries = BeancountToJsonl.convert_transactions_to_maps(
+                *BeancountToJsonl.extract_beancount_transactions(beancount_files)
+            )

         # Split entries by max tokens supported by model
         with timer("Split entries by max token size supported by model", logger):
@@ -42,7 +48,9 @@ class BeancountToJsonl(TextToJsonl):
         if not previous_entries:
             entries_with_ids = list(enumerate(current_entries))
         else:
-            entries_with_ids = self.mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger)
+            entries_with_ids = self.mark_entries_for_update(
+                current_entries, previous_entries, key="compiled", logger=logger
+            )

         with timer("Write transactions to JSONL file", logger):
             # Process Each Entry from All Notes Files
@@ -62,9 +70,7 @@ class BeancountToJsonl(TextToJsonl):
         "Get Beancount files to process"
         absolute_beancount_files, filtered_beancount_files = set(), set()
         if beancount_files:
-            absolute_beancount_files = {get_absolute_path(beancount_file)
-                                        for beancount_file
-                                        in beancount_files}
+            absolute_beancount_files = {get_absolute_path(beancount_file) for beancount_file in beancount_files}
         if beancount_file_filters:
             filtered_beancount_files = {
                 filtered_file
@@ -76,14 +82,13 @@ class BeancountToJsonl(TextToJsonl):
         files_with_non_beancount_extensions = {
             beancount_file
-            for beancount_file
-            in all_beancount_files
+            for beancount_file in all_beancount_files
             if not beancount_file.endswith(".bean") and not beancount_file.endswith(".beancount")
         }

         if any(files_with_non_beancount_extensions):
             print(f"[Warning] There maybe non beancount files in the input set: {files_with_non_beancount_extensions}")

-        logger.info(f'Processing files: {all_beancount_files}')
+        logger.info(f"Processing files: {all_beancount_files}")

         return all_beancount_files
@@ -92,19 +97,20 @@ class BeancountToJsonl(TextToJsonl):
         "Extract entries from specified Beancount files"

         # Initialize Regex for extracting Beancount Entries
-        transaction_regex = r'^\n?\d{4}-\d{2}-\d{2} [\*|\!] '
-        empty_newline = f'^[\n\r\t\ ]*$'
+        transaction_regex = r"^\n?\d{4}-\d{2}-\d{2} [\*|\!] "
+        empty_newline = f"^[\n\r\t\ ]*$"

         entries = []
         transaction_to_file_map = []
         for beancount_file in beancount_files:
             with open(beancount_file) as f:
                 ledger_content = f.read()
-                transactions_per_file = [entry.strip(empty_escape_sequences)
-                                         for entry
-                                         in re.split(empty_newline, ledger_content, flags=re.MULTILINE)
-                                         if re.match(transaction_regex, entry)]
-                transaction_to_file_map += zip(transactions_per_file, [beancount_file]*len(transactions_per_file))
+                transactions_per_file = [
+                    entry.strip(empty_escape_sequences)
+                    for entry in re.split(empty_newline, ledger_content, flags=re.MULTILINE)
+                    if re.match(transaction_regex, entry)
+                ]
+                transaction_to_file_map += zip(transactions_per_file, [beancount_file] * len(transactions_per_file))
                 entries.extend(transactions_per_file)

         return entries, dict(transaction_to_file_map)
@@ -113,7 +119,9 @@ class BeancountToJsonl(TextToJsonl):
         "Convert each parsed Beancount transaction into a Entry"
         entries = []
         for parsed_entry in parsed_entries:
-            entries.append(Entry(compiled=parsed_entry, raw=parsed_entry, file=f'{transaction_to_file_map[parsed_entry]}'))
+            entries.append(
+                Entry(compiled=parsed_entry, raw=parsed_entry, file=f"{transaction_to_file_map[parsed_entry]}")
+            )

         logger.info(f"Converted {len(parsed_entries)} transactions to dictionaries")
@@ -122,4 +130,4 @@ class BeancountToJsonl(TextToJsonl):
     @staticmethod
     def convert_transaction_maps_to_jsonl(entries: List[Entry]) -> str:
         "Convert each Beancount transaction entry to JSON and collate as JSONL"
-        return ''.join([f'{entry.to_json()}\n' for entry in entries])
+        return "".join([f"{entry.to_json()}\n" for entry in entries])

View file

@@ -20,7 +20,11 @@ class MarkdownToJsonl(TextToJsonl):
     # Define Functions
     def process(self, previous_entries=None):
         # Extract required fields from config
-        markdown_files, markdown_file_filter, output_file = self.config.input_files, self.config.input_filter, self.config.compressed_jsonl
+        markdown_files, markdown_file_filter, output_file = (
+            self.config.input_files,
+            self.config.input_filter,
+            self.config.compressed_jsonl,
+        )

         # Input Validation
         if is_none_or_empty(markdown_files) and is_none_or_empty(markdown_file_filter):
@@ -32,7 +36,9 @@ class MarkdownToJsonl(TextToJsonl):

         # Extract Entries from specified Markdown files
         with timer("Parse entries from Markdown files into dictionaries", logger):
-            current_entries = MarkdownToJsonl.convert_markdown_entries_to_maps(*MarkdownToJsonl.extract_markdown_entries(markdown_files))
+            current_entries = MarkdownToJsonl.convert_markdown_entries_to_maps(
+                *MarkdownToJsonl.extract_markdown_entries(markdown_files)
+            )

         # Split entries by max tokens supported by model
         with timer("Split entries by max token size supported by model", logger):
@@ -43,7 +49,9 @@ class MarkdownToJsonl(TextToJsonl):
         if not previous_entries:
             entries_with_ids = list(enumerate(current_entries))
         else:
-            entries_with_ids = self.mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger)
+            entries_with_ids = self.mark_entries_for_update(
+                current_entries, previous_entries, key="compiled", logger=logger
+            )

         with timer("Write markdown entries to JSONL file", logger):
             # Process Each Entry from All Notes Files
@@ -75,15 +83,16 @@ class MarkdownToJsonl(TextToJsonl):
         files_with_non_markdown_extensions = {
             md_file
-            for md_file
-            in all_markdown_files
-            if not md_file.endswith(".md") and not md_file.endswith('.markdown')
+            for md_file in all_markdown_files
+            if not md_file.endswith(".md") and not md_file.endswith(".markdown")
         }

         if any(files_with_non_markdown_extensions):
-            logger.warn(f"[Warning] There maybe non markdown-mode files in the input set: {files_with_non_markdown_extensions}")
+            logger.warn(
+                f"[Warning] There maybe non markdown-mode files in the input set: {files_with_non_markdown_extensions}"
+            )

-        logger.info(f'Processing files: {all_markdown_files}')
+        logger.info(f"Processing files: {all_markdown_files}")

         return all_markdown_files
@@ -92,20 +101,20 @@ class MarkdownToJsonl(TextToJsonl):
         "Extract entries by heading from specified Markdown files"

         # Regex to extract Markdown Entries by Heading
-        markdown_heading_regex = r'^#'
+        markdown_heading_regex = r"^#"

         entries = []
         entry_to_file_map = []
         for markdown_file in markdown_files:
-            with open(markdown_file, 'r', encoding='utf8') as f:
+            with open(markdown_file, "r", encoding="utf8") as f:
                 markdown_content = f.read()
                 markdown_entries_per_file = []
                 for entry in re.split(markdown_heading_regex, markdown_content, flags=re.MULTILINE):
-                    prefix = '#' if entry.startswith('#') else '# '
-                    if entry.strip(empty_escape_sequences) != '':
-                        markdown_entries_per_file.append(f'{prefix}{entry.strip(empty_escape_sequences)}')
-                entry_to_file_map += zip(markdown_entries_per_file, [markdown_file]*len(markdown_entries_per_file))
+                    prefix = "#" if entry.startswith("#") else "# "
+                    if entry.strip(empty_escape_sequences) != "":
+                        markdown_entries_per_file.append(f"{prefix}{entry.strip(empty_escape_sequences)}")
+                entry_to_file_map += zip(markdown_entries_per_file, [markdown_file] * len(markdown_entries_per_file))
                 entries.extend(markdown_entries_per_file)

         return entries, dict(entry_to_file_map)
@@ -115,7 +124,7 @@ class MarkdownToJsonl(TextToJsonl):
         "Convert each Markdown entries into a dictionary"
         entries = []
         for parsed_entry in parsed_entries:
-            entries.append(Entry(compiled=parsed_entry, raw=parsed_entry, file=f'{entry_to_file_map[parsed_entry]}'))
+            entries.append(Entry(compiled=parsed_entry, raw=parsed_entry, file=f"{entry_to_file_map[parsed_entry]}"))

         logger.info(f"Converted {len(parsed_entries)} markdown entries to dictionaries")
@@ -124,4 +133,4 @@ class MarkdownToJsonl(TextToJsonl):
     @staticmethod
     def convert_markdown_maps_to_jsonl(entries: List[Entry]):
         "Convert each Markdown entry to JSON and collate as JSONL"
-        return ''.join([f'{entry.to_json()}\n' for entry in entries])
+        return "".join([f"{entry.to_json()}\n" for entry in entries])

View file

@@ -18,9 +18,13 @@ logger = logging.getLogger(__name__)
 class OrgToJsonl(TextToJsonl):
     # Define Functions
-    def process(self, previous_entries: List[Entry]=None):
+    def process(self, previous_entries: List[Entry] = None):
         # Extract required fields from config
-        org_files, org_file_filter, output_file = self.config.input_files, self.config.input_filter, self.config.compressed_jsonl
+        org_files, org_file_filter, output_file = (
+            self.config.input_files,
+            self.config.input_filter,
+            self.config.compressed_jsonl,
+        )
         index_heading_entries = self.config.index_heading_entries

         # Input Validation
@@ -46,7 +50,9 @@ class OrgToJsonl(TextToJsonl):
         if not previous_entries:
             entries_with_ids = list(enumerate(current_entries))
         else:
-            entries_with_ids = self.mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger)
+            entries_with_ids = self.mark_entries_for_update(
+                current_entries, previous_entries, key="compiled", logger=logger
+            )

         # Process Each Entry from All Notes Files
         with timer("Write org entries to JSONL file", logger):
@@ -66,11 +72,7 @@ class OrgToJsonl(TextToJsonl):
         "Get Org files to process"
         absolute_org_files, filtered_org_files = set(), set()
         if org_files:
-            absolute_org_files = {
-                get_absolute_path(org_file)
-                for org_file
-                in org_files
-            }
+            absolute_org_files = {get_absolute_path(org_file) for org_file in org_files}
         if org_file_filters:
             filtered_org_files = {
                 filtered_file
@@ -84,7 +86,7 @@ class OrgToJsonl(TextToJsonl):
         if any(files_with_non_org_extensions):
             logger.warn(f"There maybe non org-mode files in the input set: {files_with_non_org_extensions}")

-        logger.info(f'Processing files: {all_org_files}')
+        logger.info(f"Processing files: {all_org_files}")

         return all_org_files
@@ -95,13 +97,15 @@ class OrgToJsonl(TextToJsonl):
         entry_to_file_map = []
         for org_file in org_files:
             org_file_entries = orgnode.makelist(str(org_file))
-            entry_to_file_map += zip(org_file_entries, [org_file]*len(org_file_entries))
+            entry_to_file_map += zip(org_file_entries, [org_file] * len(org_file_entries))
             entries.extend(org_file_entries)

         return entries, dict(entry_to_file_map)

     @staticmethod
-    def convert_org_nodes_to_entries(parsed_entries: List[orgnode.Orgnode], entry_to_file_map, index_heading_entries=False) -> List[Entry]:
+    def convert_org_nodes_to_entries(
+        parsed_entries: List[orgnode.Orgnode], entry_to_file_map, index_heading_entries=False
+    ) -> List[Entry]:
         "Convert Org-Mode nodes into list of Entry objects"
         entries: List[Entry] = []
         for parsed_entry in parsed_entries:
@@ -109,13 +113,13 @@ class OrgToJsonl(TextToJsonl):
                 # Ignore title notes i.e notes with just headings and empty body
                 continue

-            compiled = f'{parsed_entry.heading}.'
+            compiled = f"{parsed_entry.heading}."
             if state.verbose > 2:
                 logger.debug(f"Title: {parsed_entry.heading}")

             if parsed_entry.tags:
                 tags_str = " ".join(parsed_entry.tags)
-                compiled += f'\t {tags_str}.'
+                compiled += f"\t {tags_str}."
                 if state.verbose > 2:
                     logger.debug(f"Tags: {tags_str}")
@@ -130,19 +134,16 @@ class OrgToJsonl(TextToJsonl):
                 logger.debug(f'Scheduled: {parsed_entry.scheduled.strftime("%Y-%m-%d")}')

             if parsed_entry.hasBody:
-                compiled += f'\n {parsed_entry.body}'
+                compiled += f"\n {parsed_entry.body}"
                 if state.verbose > 2:
                     logger.debug(f"Body: {parsed_entry.body}")

             if compiled:
-                entries += [Entry(
-                    compiled=compiled,
-                    raw=f'{parsed_entry}',
-                    file=f'{entry_to_file_map[parsed_entry]}')]
+                entries += [Entry(compiled=compiled, raw=f"{parsed_entry}", file=f"{entry_to_file_map[parsed_entry]}")]

         return entries

     @staticmethod
     def convert_org_entries_to_jsonl(entries: Iterable[Entry]) -> str:
         "Convert each Org-Mode entry to JSON and collate as JSONL"
-        return ''.join([f'{entry_dict.to_json()}\n' for entry_dict in entries])
+        return "".join([f"{entry_dict.to_json()}\n" for entry_dict in entries])

View file

@@ -39,182 +39,197 @@ from pathlib import Path
 from os.path import relpath
 from typing import List

-indent_regex = re.compile(r'^ *')
+indent_regex = re.compile(r"^ *")


 def normalize_filename(filename):
     "Normalize and escape filename for rendering"
     if not Path(filename).is_absolute():
         # Normalize relative filename to be relative to current directory
-        normalized_filename = f'~/{relpath(filename, start=Path.home())}'
+        normalized_filename = f"~/{relpath(filename, start=Path.home())}"
     else:
         normalized_filename = filename
-    escaped_filename = f'{normalized_filename}'.replace("[","\[").replace("]","\]")
+    escaped_filename = f"{normalized_filename}".replace("[", "\[").replace("]", "\]")
     return escaped_filename


 def makelist(filename):
     """
     Read an org-mode file and return a list of Orgnode objects
     created from this file.
     """
     ctr = 0

-    f = open(filename, 'r')
+    f = open(filename, "r")

-    todos = { "TODO": "", "WAITING": "", "ACTIVE": "",
-              "DONE": "", "CANCELLED": "", "FAILED": ""} # populated from #+SEQ_TODO line
-    level = ""
-    heading = ""
-    bodytext = ""
-    tags = list() # set of all tags in headline
-    closed_date = ''
-    sched_date = ''
-    deadline_date = ''
-    logbook = list()
-    nodelist: List[Orgnode] = list()
-    property_map = dict()
-    in_properties_drawer = False
-    in_logbook_drawer = False
-    file_title = f'{filename}'
+    todos = {
+        "TODO": "",
+        "WAITING": "",
+        "ACTIVE": "",
+        "DONE": "",
+        "CANCELLED": "",
+        "FAILED": "",
+    }  # populated from #+SEQ_TODO line
+    level = ""
+    heading = ""
+    bodytext = ""
+    tags = list()  # set of all tags in headline
+    closed_date = ""
+    sched_date = ""
+    deadline_date = ""
+    logbook = list()
+    nodelist: List[Orgnode] = list()
+    property_map = dict()
+    in_properties_drawer = False
+    in_logbook_drawer = False
+    file_title = f"{filename}"

     for line in f:
         ctr += 1
-        heading_search = re.search(r'^(\*+)\s(.*?)\s*$', line)
-        if heading_search: # we are processing a heading line
-            if heading: # if we have are on second heading, append first heading to headings list
+        heading_search = re.search(r"^(\*+)\s(.*?)\s*$", line)
+        if heading_search:  # we are processing a heading line
+            if heading:  # if we have are on second heading, append first heading to headings list
                 thisNode = Orgnode(level, heading, bodytext, tags)
                 if closed_date:
                     thisNode.closed = closed_date
-                    closed_date = ''
+                    closed_date = ""
                 if sched_date:
                     thisNode.scheduled = sched_date
                     sched_date = ""
                 if deadline_date:
                     thisNode.deadline = deadline_date
-                    deadline_date = ''
+                    deadline_date = ""
                 if logbook:
                     thisNode.logbook = logbook
                     logbook = list()
                 thisNode.properties = property_map
-                nodelist.append( thisNode )
-            property_map = {'LINE': f'file:{normalize_filename(filename)}::{ctr}'}
+                nodelist.append(thisNode)
+            property_map = {"LINE": f"file:{normalize_filename(filename)}::{ctr}"}
             level = heading_search.group(1)
             heading = heading_search.group(2)
             bodytext = ""
-            tags = list() # set of all tags in headline
-            tag_search = re.search(r'(.*?)\s*:([a-zA-Z0-9].*?):$',heading)
+            tags = list()  # set of all tags in headline
+            tag_search = re.search(r"(.*?)\s*:([a-zA-Z0-9].*?):$", heading)
             if tag_search:
                 heading = tag_search.group(1)
                 parsedtags = tag_search.group(2)
                 if parsedtags:
-                    for parsedtag in parsedtags.split(':'):
-                        if parsedtag != '': tags.append(parsedtag)
-        else: # we are processing a non-heading line
-            if line[:10] == '#+SEQ_TODO':
-                kwlist = re.findall(r'([A-Z]+)\(', line)
-                for kw in kwlist: todos[kw] = ""
+                    for parsedtag in parsedtags.split(":"):
+                        if parsedtag != "":
+                            tags.append(parsedtag)
+        else:  # we are processing a non-heading line
+            if line[:10] == "#+SEQ_TODO":
+                kwlist = re.findall(r"([A-Z]+)\(", line)
+                for kw in kwlist:
+                    todos[kw] = ""

             # Set file title to TITLE property, if it exists
-            title_search = re.search(r'^#\+TITLE:\s*(.*)$', line)
-            if title_search and title_search.group(1).strip() != '':
+            title_search = re.search(r"^#\+TITLE:\s*(.*)$", line)
+            if title_search and title_search.group(1).strip() != "":
                 title_text = title_search.group(1).strip()
-                if file_title == f'{filename}':
+                if file_title == f"{filename}":
                     file_title = title_text
                 else:
-                    file_title += f' {title_text}'
+                    file_title += f" {title_text}"
                 continue

             # Ignore Properties Drawers Completely
-            if re.search(':PROPERTIES:', line):
-                in_properties_drawer=True
+            if re.search(":PROPERTIES:", line):
+                in_properties_drawer = True
                 continue
-            if in_properties_drawer and re.search(':END:', line):
-                in_properties_drawer=False
+            if in_properties_drawer and re.search(":END:", line):
+                in_properties_drawer = False
                 continue

             # Ignore Logbook Drawer Start, End Lines
-            if re.search(':LOGBOOK:', line):
-                in_logbook_drawer=True
+            if re.search(":LOGBOOK:", line):
+                in_logbook_drawer = True
                 continue
-            if in_logbook_drawer and re.search(':END:', line):
-                in_logbook_drawer=False
+            if in_logbook_drawer and re.search(":END:", line):
+                in_logbook_drawer = False
                 continue

             # Extract Clocking Lines
-            clocked_re = re.search(r'CLOCK:\s*\[([0-9]{4}-[0-9]{2}-[0-9]{2} [a-zA-Z]{3} [0-9]{2}:[0-9]{2})\]--\[([0-9]{4}-[0-9]{2}-[0-9]{2} [a-zA-Z]{3} [0-9]{2}:[0-9]{2})\]', line)
-            if clocked_re:
-                # convert clock in, clock out strings to datetime objects
-                clocked_in = datetime.datetime.strptime(clocked_re.group(1), '%Y-%m-%d %a %H:%M')
-                clocked_out = datetime.datetime.strptime(clocked_re.group(2), '%Y-%m-%d %a %H:%M')
-                # add clocked time to the entries logbook list
-                logbook += [(clocked_in, clocked_out)]
-                line = ""
+            clocked_re = re.search(
+                r"CLOCK:\s*\[([0-9]{4}-[0-9]{2}-[0-9]{2} [a-zA-Z]{3} [0-9]{2}:[0-9]{2})\]--\[([0-9]{4}-[0-9]{2}-[0-9]{2} [a-zA-Z]{3} [0-9]{2}:[0-9]{2})\]",
+                line,
+            )
+            if clocked_re:
+                # convert clock in, clock out strings to datetime objects
+                clocked_in = datetime.datetime.strptime(clocked_re.group(1), "%Y-%m-%d %a %H:%M")
+                clocked_out = datetime.datetime.strptime(clocked_re.group(2), "%Y-%m-%d %a %H:%M")
+                # add clocked time to the entries logbook list
+                logbook += [(clocked_in, clocked_out)]
+                line = ""

-            property_search = re.search(r'^\s*:([a-zA-Z0-9]+):\s*(.*?)\s*$', line)
+            property_search = re.search(r"^\s*:([a-zA-Z0-9]+):\s*(.*?)\s*$", line)
             if property_search:
                 # Set ID property to an id based org-mode link to the entry
-                if property_search.group(1) == 'ID':
-                    property_map['ID'] = f'id:{property_search.group(2)}'
+                if property_search.group(1) == "ID":
+                    property_map["ID"] = f"id:{property_search.group(2)}"
                 else:
                     property_map[property_search.group(1)] = property_search.group(2)
                 continue

-            cd_re = re.search(r'CLOSED:\s*\[([0-9]{4})-([0-9]{2})-([0-9]{2})', line)
+            cd_re = re.search(r"CLOSED:\s*\[([0-9]{4})-([0-9]{2})-([0-9]{2})", line)
             if cd_re:
-                closed_date = datetime.date(int(cd_re.group(1)),
-                                            int(cd_re.group(2)),
-                                            int(cd_re.group(3)) )
+                closed_date = datetime.date(int(cd_re.group(1)), int(cd_re.group(2)), int(cd_re.group(3)))
+            sd_re = re.search(r"SCHEDULED:\s*<([0-9]+)\-([0-9]+)\-([0-9]+)", line)
+            if sd_re:
sd_re = re.search(r'SCHEDULED:\s*<([0-9]+)\-([0-9]+)\-([0-9]+)', line) sched_date = datetime.date(int(sd_re.group(1)), int(sd_re.group(2)), int(sd_re.group(3)))
if sd_re: dd_re = re.search(r"DEADLINE:\s*<(\d+)\-(\d+)\-(\d+)", line)
sched_date = datetime.date(int(sd_re.group(1)), if dd_re:
int(sd_re.group(2)), deadline_date = datetime.date(int(dd_re.group(1)), int(dd_re.group(2)), int(dd_re.group(3)))
int(sd_re.group(3)) )
dd_re = re.search(r'DEADLINE:\s*<(\d+)\-(\d+)\-(\d+)', line)
if dd_re:
deadline_date = datetime.date(int(dd_re.group(1)),
int(dd_re.group(2)),
int(dd_re.group(3)) )
# Ignore property drawer, scheduled, closed, deadline, logbook entries and # lines from body # Ignore property drawer, scheduled, closed, deadline, logbook entries and # lines from body
if not in_properties_drawer and not cd_re and not sd_re and not dd_re and not clocked_re and line[:1] != '#': if (
bodytext = bodytext + line not in_properties_drawer
and not cd_re
and not sd_re
and not dd_re
and not clocked_re
and line[:1] != "#"
):
bodytext = bodytext + line
# write out last node # write out last node
thisNode = Orgnode(level, heading or file_title, bodytext, tags) thisNode = Orgnode(level, heading or file_title, bodytext, tags)
thisNode.properties = property_map thisNode.properties = property_map
if sched_date: if sched_date:
thisNode.scheduled = sched_date thisNode.scheduled = sched_date
if deadline_date: if deadline_date:
thisNode.deadline = deadline_date thisNode.deadline = deadline_date
if closed_date: if closed_date:
thisNode.closed = closed_date thisNode.closed = closed_date
if logbook: if logbook:
thisNode.logbook = logbook thisNode.logbook = logbook
nodelist.append( thisNode ) nodelist.append(thisNode)
# using the list of TODO keywords found in the file # using the list of TODO keywords found in the file
# process the headings searching for TODO keywords # process the headings searching for TODO keywords
for n in nodelist: for n in nodelist:
todo_search = re.search(r'([A-Z]+)\s(.*?)$', n.heading) todo_search = re.search(r"([A-Z]+)\s(.*?)$", n.heading)
if todo_search: if todo_search:
if todo_search.group(1) in todos: if todo_search.group(1) in todos:
n.heading = todo_search.group(2) n.heading = todo_search.group(2)
n.todo = todo_search.group(1) n.todo = todo_search.group(1)
# extract, set priority from heading, update heading if necessary # extract, set priority from heading, update heading if necessary
priority_search = re.search(r'^\[\#(A|B|C)\] (.*?)$', n.heading) priority_search = re.search(r"^\[\#(A|B|C)\] (.*?)$", n.heading)
if priority_search: if priority_search:
n.priority = priority_search.group(1) n.priority = priority_search.group(1)
n.heading = priority_search.group(2) n.heading = priority_search.group(2)
# Set SOURCE property to a file+heading based org-mode link to the entry # Set SOURCE property to a file+heading based org-mode link to the entry
if n.level == 0: if n.level == 0:
n.properties['LINE'] = f'file:{normalize_filename(filename)}::0' n.properties["LINE"] = f"file:{normalize_filename(filename)}::0"
n.properties['SOURCE'] = f'[[file:{normalize_filename(filename)}]]' n.properties["SOURCE"] = f"[[file:{normalize_filename(filename)}]]"
else: else:
escaped_heading = n.heading.replace("[","\\[").replace("]","\\]") escaped_heading = n.heading.replace("[", "\\[").replace("]", "\\]")
n.properties['SOURCE'] = f'[[file:{normalize_filename(filename)}::*{escaped_heading}]]' n.properties["SOURCE"] = f"[[file:{normalize_filename(filename)}::*{escaped_heading}]]"
return nodelist
return nodelist
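For orientation, a minimal usage sketch of the parser above; the module path, file path and printed attributes are illustrative assumptions, not taken from this commit.

    # Hypothetical usage of makelist(); import path and org file path are assumptions.
    from khoj.processor.org_mode import orgnode

    nodes = orgnode.makelist("/home/user/notes/tasks.org")
    for node in nodes:
        # heading and properties are set on each Orgnode by makelist() above
        print(node.heading, node.properties.get("SOURCE"))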
###################### ######################
class Orgnode(object): class Orgnode(object):
@ -222,6 +237,7 @@ class Orgnode(object):
Orgnode class represents a headline, tags and text associated Orgnode class represents a headline, tags and text associated
with the headline. with the headline.
""" """
def __init__(self, level, headline, body, tags): def __init__(self, level, headline, body, tags):
""" """
Create an Orgnode object given the parameters of level (as the Create an Orgnode object given the parameters of level (as the
@ -232,14 +248,14 @@ class Orgnode(object):
self._level = len(level) self._level = len(level)
self._heading = headline self._heading = headline
self._body = body self._body = body
self._tags = tags # All tags in the headline self._tags = tags # All tags in the headline
self._todo = "" self._todo = ""
self._priority = "" # empty of A, B or C self._priority = "" # empty of A, B or C
self._scheduled = "" # Scheduled date self._scheduled = "" # Scheduled date
self._deadline = "" # Deadline date self._deadline = "" # Deadline date
self._closed = "" # Closed date self._closed = "" # Closed date
self._properties = dict() self._properties = dict()
self._logbook = list() # List of clock-in, clock-out tuples representing logbook entries self._logbook = list() # List of clock-in, clock-out tuples representing logbook entries
# Look for priority in headline and transfer to prty field # Look for priority in headline and transfer to prty field
@ -270,7 +286,7 @@ class Orgnode(object):
""" """
Returns True if node has non empty body, else False Returns True if node has non empty body, else False
""" """
return self._body and re.sub(r'\n|\t|\r| ', '', self._body) != '' return self._body and re.sub(r"\n|\t|\r| ", "", self._body) != ""
@property @property
def level(self): def level(self):
@ -417,20 +433,20 @@ class Orgnode(object):
text as used to construct the node. text as used to construct the node.
""" """
# Output heading line # Output heading line
n = '' n = ""
for _ in range(0, self._level): for _ in range(0, self._level):
n = n + '*' n = n + "*"
n = n + ' ' n = n + " "
if self._todo: if self._todo:
n = n + self._todo + ' ' n = n + self._todo + " "
if self._priority: if self._priority:
n = n + '[#' + self._priority + '] ' n = n + "[#" + self._priority + "] "
n = n + self._heading n = n + self._heading
n = "%-60s " % n # hack - tags will start in column 62 n = "%-60s " % n # hack - tags will start in column 62
closecolon = '' closecolon = ""
for t in self._tags: for t in self._tags:
n = n + ':' + t n = n + ":" + t
closecolon = ':' closecolon = ":"
n = n + closecolon n = n + closecolon
n = n + "\n" n = n + "\n"
@ -439,24 +455,24 @@ class Orgnode(object):
# Output Closed Date, Scheduled Date, Deadline Date # Output Closed Date, Scheduled Date, Deadline Date
if self._closed or self._scheduled or self._deadline: if self._closed or self._scheduled or self._deadline:
n = n + indent n = n + indent
if self._closed: if self._closed:
n = n + f'CLOSED: [{self._closed.strftime("%Y-%m-%d %a")}] ' n = n + f'CLOSED: [{self._closed.strftime("%Y-%m-%d %a")}] '
if self._scheduled: if self._scheduled:
n = n + f'SCHEDULED: <{self._scheduled.strftime("%Y-%m-%d %a")}> ' n = n + f'SCHEDULED: <{self._scheduled.strftime("%Y-%m-%d %a")}> '
if self._deadline: if self._deadline:
n = n + f'DEADLINE: <{self._deadline.strftime("%Y-%m-%d %a")}> ' n = n + f'DEADLINE: <{self._deadline.strftime("%Y-%m-%d %a")}> '
if self._closed or self._scheduled or self._deadline: if self._closed or self._scheduled or self._deadline:
n = n + '\n' n = n + "\n"
# Output Property Drawer # Output Property Drawer
n = n + indent + ":PROPERTIES:\n" n = n + indent + ":PROPERTIES:\n"
for key, value in self._properties.items(): for key, value in self._properties.items():
n = n + indent + f":{key}: {value}\n" n = n + indent + f":{key}: {value}\n"
n = n + indent + ":END:\n" n = n + indent + ":END:\n"
# Output Body # Output Body
if self.hasBody: if self.hasBody:
n = n + self._body n = n + self._body
return n return n
@ -17,14 +17,17 @@ class TextToJsonl(ABC):
self.config = config self.config = config
@abstractmethod @abstractmethod
def process(self, previous_entries: List[Entry]=None) -> List[Tuple[int, Entry]]: ... def process(self, previous_entries: List[Entry] = None) -> List[Tuple[int, Entry]]:
...
@staticmethod @staticmethod
def hash_func(key: str) -> Callable: def hash_func(key: str) -> Callable:
return lambda entry: hashlib.md5(bytes(getattr(entry, key), encoding='utf-8')).hexdigest() return lambda entry: hashlib.md5(bytes(getattr(entry, key), encoding="utf-8")).hexdigest()
@staticmethod @staticmethod
def split_entries_by_max_tokens(entries: List[Entry], max_tokens: int=256, max_word_length: int=500) -> List[Entry]: def split_entries_by_max_tokens(
entries: List[Entry], max_tokens: int = 256, max_word_length: int = 500
) -> List[Entry]:
"Split entries if compiled entry length exceeds the max tokens supported by the ML model." "Split entries if compiled entry length exceeds the max tokens supported by the ML model."
chunked_entries: List[Entry] = [] chunked_entries: List[Entry] = []
for entry in entries: for entry in entries:
@ -32,13 +35,15 @@ class TextToJsonl(ABC):
# Drop long words instead of having entry truncated to maintain quality of entry processed by models # Drop long words instead of having entry truncated to maintain quality of entry processed by models
compiled_entry_words = [word for word in compiled_entry_words if len(word) <= max_word_length] compiled_entry_words = [word for word in compiled_entry_words if len(word) <= max_word_length]
for chunk_index in range(0, len(compiled_entry_words), max_tokens): for chunk_index in range(0, len(compiled_entry_words), max_tokens):
compiled_entry_words_chunk = compiled_entry_words[chunk_index:chunk_index + max_tokens] compiled_entry_words_chunk = compiled_entry_words[chunk_index : chunk_index + max_tokens]
compiled_entry_chunk = ' '.join(compiled_entry_words_chunk) compiled_entry_chunk = " ".join(compiled_entry_words_chunk)
entry_chunk = Entry(compiled=compiled_entry_chunk, raw=entry.raw, file=entry.file) entry_chunk = Entry(compiled=compiled_entry_chunk, raw=entry.raw, file=entry.file)
chunked_entries.append(entry_chunk) chunked_entries.append(entry_chunk)
return chunked_entries return chunked_entries
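A quick illustration of the chunking rule above; this is a sketch, and the import paths are assumptions based on the class names in this diff.

    # Sketch: a 600-word entry is split into ceil(600 / 256) = 3 chunks.
    # Import paths are assumptions; Entry(compiled=..., raw=..., file=...) matches the usage above.
    from khoj.processor.text_to_jsonl import TextToJsonl
    from khoj.utils.rawconfig import Entry

    long_entry = Entry(compiled=" ".join(["word"] * 600), raw="raw text", file="notes.org")
    chunks = TextToJsonl.split_entries_by_max_tokens([long_entry], max_tokens=256)
    print(len(chunks))  # expected: 3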
def mark_entries_for_update(self, current_entries: List[Entry], previous_entries: List[Entry], key='compiled', logger=None) -> List[Tuple[int, Entry]]: def mark_entries_for_update(
self, current_entries: List[Entry], previous_entries: List[Entry], key="compiled", logger=None
) -> List[Tuple[int, Entry]]:
# Hash all current and previous entries to identify new entries # Hash all current and previous entries to identify new entries
with timer("Hash previous, current entries", logger): with timer("Hash previous, current entries", logger):
current_entry_hashes = list(map(TextToJsonl.hash_func(key), current_entries)) current_entry_hashes = list(map(TextToJsonl.hash_func(key), current_entries))
@ -54,10 +59,7 @@ class TextToJsonl(ABC):
existing_entry_hashes = set(current_entry_hashes) & set(previous_entry_hashes) existing_entry_hashes = set(current_entry_hashes) & set(previous_entry_hashes)
# Mark new entries with -1 id to flag for later embeddings generation # Mark new entries with -1 id to flag for later embeddings generation
new_entries = [ new_entries = [(-1, hash_to_current_entries[entry_hash]) for entry_hash in new_entry_hashes]
(-1, hash_to_current_entries[entry_hash])
for entry_hash in new_entry_hashes
]
# Set id of existing entries to their previous ids to reuse their existing encoded embeddings # Set id of existing entries to their previous ids to reuse their existing encoded embeddings
existing_entries = [ existing_entries = [
(previous_entry_hashes.index(entry_hash), hash_to_previous_entries[entry_hash]) (previous_entry_hashes.index(entry_hash), hash_to_previous_entries[entry_hash])
@ -67,4 +69,4 @@ class TextToJsonl(ABC):
existing_entries_sorted = sorted(existing_entries, key=lambda e: e[0]) existing_entries_sorted = sorted(existing_entries, key=lambda e: e[0])
entries_with_ids = existing_entries_sorted + new_entries entries_with_ids = existing_entries_sorted + new_entries
return entries_with_ids return entries_with_ids
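The id convention above can be seen with a toy concrete subclass; everything in this sketch (subclass, constructor argument, import paths) is an illustrative assumption.

    # Sketch: existing entries keep their previous index, new entries get id -1.
    import logging
    from khoj.processor.text_to_jsonl import TextToJsonl
    from khoj.utils.rawconfig import Entry

    class DummyToJsonl(TextToJsonl):
        # Minimal concrete subclass so the abstract base can be instantiated
        def process(self, previous_entries=None):
            return []

    previous = [Entry(compiled="alpha", raw="alpha", file="notes.org")]
    current = previous + [Entry(compiled="beta", raw="beta", file="notes.org")]

    entries_with_ids = DummyToJsonl(None).mark_entries_for_update(
        current, previous, key="compiled", logger=logging.getLogger(__name__)
    )
    # expected: [(0, <alpha entry>), (-1, <beta entry>)]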
@ -22,27 +22,30 @@ logger = logging.getLogger(__name__)
# Create Routes # Create Routes
@api.get('/config/data/default') @api.get("/config/data/default")
def get_default_config_data(): def get_default_config_data():
return constants.default_config return constants.default_config
@api.get('/config/data', response_model=FullConfig)
@api.get("/config/data", response_model=FullConfig)
def get_config_data(): def get_config_data():
return state.config return state.config
@api.post('/config/data')
@api.post("/config/data")
async def set_config_data(updated_config: FullConfig): async def set_config_data(updated_config: FullConfig):
state.config = updated_config state.config = updated_config
with open(state.config_file, 'w') as outfile: with open(state.config_file, "w") as outfile:
yaml.dump(yaml.safe_load(state.config.json(by_alias=True)), outfile) yaml.dump(yaml.safe_load(state.config.json(by_alias=True)), outfile)
outfile.close() outfile.close()
return state.config return state.config
@api.get('/search', response_model=List[SearchResponse])
@api.get("/search", response_model=List[SearchResponse])
def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Optional[bool] = False): def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Optional[bool] = False):
results: List[SearchResponse] = [] results: List[SearchResponse] = []
if q is None or q == '': if q is None or q == "":
logger.info(f'No query param (q) passed in API call to initiate search') logger.info(f"No query param (q) passed in API call to initiate search")
return results return results
# initialize variables # initialize variables
@ -50,9 +53,9 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
results_count = n results_count = n
# return cached results, if available # return cached results, if available
query_cache_key = f'{user_query}-{n}-{t}-{r}' query_cache_key = f"{user_query}-{n}-{t}-{r}"
if query_cache_key in state.query_cache: if query_cache_key in state.query_cache:
logger.info(f'Return response from query cache') logger.info(f"Return response from query cache")
return state.query_cache[query_cache_key] return state.query_cache[query_cache_key]
if (t == SearchType.Org or t == None) and state.model.orgmode_search: if (t == SearchType.Org or t == None) and state.model.orgmode_search:
@ -95,7 +98,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
# query images # query images
with timer("Query took", logger): with timer("Query took", logger):
hits = image_search.query(user_query, results_count, state.model.image_search) hits = image_search.query(user_query, results_count, state.model.image_search)
output_directory = constants.web_directory / 'images' output_directory = constants.web_directory / "images"
# collate and return results # collate and return results
with timer("Collating results took", logger): with timer("Collating results took", logger):
@ -103,8 +106,9 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
hits, hits,
image_names=state.model.image_search.image_names, image_names=state.model.image_search.image_names,
output_directory=output_directory, output_directory=output_directory,
image_files_url='/static/images', image_files_url="/static/images",
count=results_count) count=results_count,
)
# Cache results # Cache results
state.query_cache[query_cache_key] = results state.query_cache[query_cache_key] = results
@ -112,7 +116,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
return results return results
@api.get('/update') @api.get("/update")
def update(t: Optional[SearchType] = None, force: Optional[bool] = False): def update(t: Optional[SearchType] = None, force: Optional[bool] = False):
try: try:
state.search_index_lock.acquire() state.search_index_lock.acquire()
@ -132,4 +136,4 @@ def update(t: Optional[SearchType] = None, force: Optional[bool] = False):
else: else:
logger.info("Processor reconfigured via API call") logger.info("Processor reconfigured via API call")
return {'status': 'ok', 'message': 'khoj reloaded'} return {"status": "ok", "message": "khoj reloaded"}
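For reference, a hypothetical client call against the search endpoint above; the host, port and the /api mount prefix are assumptions, not taken from this diff.

    # Hypothetical search request; URL prefix, host and port are assumptions.
    import requests

    response = requests.get(
        "http://localhost:8000/api/search",
        params={"q": "rent payments", "n": 5},
    )
    print(response.json())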
@ -9,7 +9,14 @@ from fastapi import APIRouter
# Internal Packages # Internal Packages
from khoj.routers.api import search from khoj.routers.api import search
from khoj.processor.conversation.gpt import converse, extract_search_type, message_to_log, message_to_prompt, understand, summarize from khoj.processor.conversation.gpt import (
converse,
extract_search_type,
message_to_log,
message_to_prompt,
understand,
summarize,
)
from khoj.utils.config import SearchType from khoj.utils.config import SearchType
from khoj.utils.helpers import get_from_dict, resolve_absolute_path from khoj.utils.helpers import get_from_dict, resolve_absolute_path
from khoj.utils import state from khoj.utils import state
@ -21,7 +28,7 @@ logger = logging.getLogger(__name__)
# Create Routes # Create Routes
@api_beta.get('/search') @api_beta.get("/search")
def search_beta(q: str, n: Optional[int] = 1): def search_beta(q: str, n: Optional[int] = 1):
# Initialize Variables # Initialize Variables
model = state.processor_config.conversation.model model = state.processor_config.conversation.model
@ -32,16 +39,16 @@ def search_beta(q: str, n: Optional[int] = 1):
metadata = extract_search_type(q, model=model, api_key=api_key, verbose=state.verbose) metadata = extract_search_type(q, model=model, api_key=api_key, verbose=state.verbose)
search_type = get_from_dict(metadata, "search-type") search_type = get_from_dict(metadata, "search-type")
except Exception as e: except Exception as e:
return {'status': 'error', 'result': [str(e)], 'type': None} return {"status": "error", "result": [str(e)], "type": None}
# Search # Search
search_results = search(q, n=n, t=SearchType(search_type)) search_results = search(q, n=n, t=SearchType(search_type))
# Return response # Return response
return {'status': 'ok', 'result': search_results, 'type': search_type} return {"status": "ok", "result": search_results, "type": search_type}
@api_beta.get('/summarize') @api_beta.get("/summarize")
def summarize_beta(q: str): def summarize_beta(q: str):
# Initialize Variables # Initialize Variables
model = state.processor_config.conversation.model model = state.processor_config.conversation.model
@ -54,23 +61,25 @@ def summarize_beta(q: str):
# Converse with OpenAI GPT # Converse with OpenAI GPT
result_list = search(q, n=1, r=True) result_list = search(q, n=1, r=True)
collated_result = "\n".join([item.entry for item in result_list]) collated_result = "\n".join([item.entry for item in result_list])
logger.debug(f'Semantically Similar Notes:\n{collated_result}') logger.debug(f"Semantically Similar Notes:\n{collated_result}")
try: try:
gpt_response = summarize(collated_result, summary_type="notes", user_query=q, model=model, api_key=api_key) gpt_response = summarize(collated_result, summary_type="notes", user_query=q, model=model, api_key=api_key)
status = 'ok' status = "ok"
except Exception as e: except Exception as e:
gpt_response = str(e) gpt_response = str(e)
status = 'error' status = "error"
# Update Conversation History # Update Conversation History
state.processor_config.conversation.chat_session = message_to_prompt(q, chat_session, gpt_message=gpt_response) state.processor_config.conversation.chat_session = message_to_prompt(q, chat_session, gpt_message=gpt_response)
state.processor_config.conversation.meta_log['chat'] = message_to_log(q, gpt_response, conversation_log=meta_log.get('chat', [])) state.processor_config.conversation.meta_log["chat"] = message_to_log(
q, gpt_response, conversation_log=meta_log.get("chat", [])
)
return {'status': status, 'response': gpt_response} return {"status": status, "response": gpt_response}
@api_beta.get('/chat') @api_beta.get("/chat")
def chat(q: Optional[str]=None): def chat(q: Optional[str] = None):
# Initialize Variables # Initialize Variables
model = state.processor_config.conversation.model model = state.processor_config.conversation.model
api_key = state.processor_config.conversation.openai_api_key api_key = state.processor_config.conversation.openai_api_key
@ -81,10 +90,10 @@ def chat(q: Optional[str]=None):
# If user query is empty, return chat history # If user query is empty, return chat history
if not q: if not q:
if meta_log.get('chat'): if meta_log.get("chat"):
return {'status': 'ok', 'response': meta_log["chat"]} return {"status": "ok", "response": meta_log["chat"]}
else: else:
return {'status': 'ok', 'response': []} return {"status": "ok", "response": []}
# Converse with OpenAI GPT # Converse with OpenAI GPT
metadata = understand(q, model=model, api_key=api_key, verbose=state.verbose) metadata = understand(q, model=model, api_key=api_key, verbose=state.verbose)
@ -94,32 +103,39 @@ def chat(q: Optional[str]=None):
query = get_from_dict(metadata, "intent", "query") query = get_from_dict(metadata, "intent", "query")
result_list = search(query, n=1, t=SearchType.Org, r=True) result_list = search(query, n=1, t=SearchType.Org, r=True)
collated_result = "\n".join([item.entry for item in result_list]) collated_result = "\n".join([item.entry for item in result_list])
logger.debug(f'Semantically Similar Notes:\n{collated_result}') logger.debug(f"Semantically Similar Notes:\n{collated_result}")
try: try:
gpt_response = summarize(collated_result, summary_type="notes", user_query=q, model=model, api_key=api_key) gpt_response = summarize(collated_result, summary_type="notes", user_query=q, model=model, api_key=api_key)
status = 'ok' status = "ok"
except Exception as e: except Exception as e:
gpt_response = str(e) gpt_response = str(e)
status = 'error' status = "error"
else: else:
try: try:
gpt_response = converse(q, model, chat_session, api_key=api_key) gpt_response = converse(q, model, chat_session, api_key=api_key)
status = 'ok' status = "ok"
except Exception as e: except Exception as e:
gpt_response = str(e) gpt_response = str(e)
status = 'error' status = "error"
# Update Conversation History # Update Conversation History
state.processor_config.conversation.chat_session = message_to_prompt(q, chat_session, gpt_message=gpt_response) state.processor_config.conversation.chat_session = message_to_prompt(q, chat_session, gpt_message=gpt_response)
state.processor_config.conversation.meta_log['chat'] = message_to_log(q, gpt_response, metadata, meta_log.get('chat', [])) state.processor_config.conversation.meta_log["chat"] = message_to_log(
q, gpt_response, metadata, meta_log.get("chat", [])
)
return {'status': status, 'response': gpt_response} return {"status": status, "response": gpt_response}
@schedule.repeat(schedule.every(5).minutes) @schedule.repeat(schedule.every(5).minutes)
def save_chat_session(): def save_chat_session():
# No need to create empty log file # No need to create empty log file
if not (state.processor_config and state.processor_config.conversation and state.processor_config.conversation.meta_log and state.processor_config.conversation.chat_session): if not (
state.processor_config
and state.processor_config.conversation
and state.processor_config.conversation.meta_log
and state.processor_config.conversation.chat_session
):
return return
# Summarize Conversation Logs for this Session # Summarize Conversation Logs for this Session
@ -130,19 +146,19 @@ def save_chat_session():
session = { session = {
"summary": summarize(chat_session, summary_type="chat", model=model, api_key=openai_api_key), "summary": summarize(chat_session, summary_type="chat", model=model, api_key=openai_api_key),
"session-start": conversation_log.get("session", [{"session-end": 0}])[-1]["session-end"], "session-start": conversation_log.get("session", [{"session-end": 0}])[-1]["session-end"],
"session-end": len(conversation_log["chat"]) "session-end": len(conversation_log["chat"]),
} }
if 'session' in conversation_log: if "session" in conversation_log:
conversation_log['session'].append(session) conversation_log["session"].append(session)
else: else:
conversation_log['session'] = [session] conversation_log["session"] = [session]
logger.info('Added new chat session to conversation logs') logger.info("Added new chat session to conversation logs")
# Save Conversation Metadata Logs to Disk # Save Conversation Metadata Logs to Disk
conversation_logfile = resolve_absolute_path(state.processor_config.conversation.conversation_logfile) conversation_logfile = resolve_absolute_path(state.processor_config.conversation.conversation_logfile)
conversation_logfile.parent.mkdir(parents=True, exist_ok=True) # create conversation directory if doesn't exist conversation_logfile.parent.mkdir(parents=True, exist_ok=True) # create conversation directory if doesn't exist
with open(conversation_logfile, "w+", encoding='utf-8') as logfile: with open(conversation_logfile, "w+", encoding="utf-8") as logfile:
json.dump(conversation_log, logfile) json.dump(conversation_log, logfile)
state.processor_config.conversation.chat_session = None state.processor_config.conversation.chat_session = None
logger.info('Saved updated conversation logs to disk.') logger.info("Saved updated conversation logs to disk.")
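Similarly, a hypothetical call to the chat endpoint above, assuming the beta router is mounted under /api/beta on a local server; the prefix, host and port are all assumptions.

    # Hypothetical chat request; the /api/beta prefix, host and port are assumptions.
    import requests

    reply = requests.get(
        "http://localhost:8000/api/beta/chat",
        params={"q": "What did I note about org-mode?"},
    )
    print(reply.json().get("response"))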
@ -18,10 +18,12 @@ templates = Jinja2Templates(directory=constants.web_directory)
def index(): def index():
return FileResponse(constants.web_directory / "index.html") return FileResponse(constants.web_directory / "index.html")
@web_client.get('/config', response_class=HTMLResponse)
@web_client.get("/config", response_class=HTMLResponse)
def config_page(request: Request): def config_page(request: Request):
return templates.TemplateResponse("config.html", context={'request': request}) return templates.TemplateResponse("config.html", context={"request": request})
@web_client.get("/chat", response_class=FileResponse) @web_client.get("/chat", response_class=FileResponse)
def chat_page(): def chat_page():
return FileResponse(constants.web_directory / "chat.html") return FileResponse(constants.web_directory / "chat.html")
@ -8,10 +8,13 @@ from khoj.utils.rawconfig import Entry
class BaseFilter(ABC): class BaseFilter(ABC):
@abstractmethod @abstractmethod
def load(self, entries: List[Entry], *args, **kwargs): ... def load(self, entries: List[Entry], *args, **kwargs):
...
@abstractmethod @abstractmethod
def can_filter(self, raw_query:str) -> bool: ... def can_filter(self, raw_query: str) -> bool:
...
@abstractmethod @abstractmethod
def apply(self, query:str, entries: List[Entry]) -> Tuple[str, Set[int]]: ... def apply(self, query: str, entries: List[Entry]) -> Tuple[str, Set[int]]:
...
@ -26,21 +26,19 @@ class DateFilter(BaseFilter):
# - dt:"2 years ago" # - dt:"2 years ago"
date_regex = r"dt([:><=]{1,2})\"(.*?)\"" date_regex = r"dt([:><=]{1,2})\"(.*?)\""
def __init__(self, entry_key="raw"):
def __init__(self, entry_key='raw'):
self.entry_key = entry_key self.entry_key = entry_key
self.date_to_entry_ids = defaultdict(set) self.date_to_entry_ids = defaultdict(set)
self.cache = LRU() self.cache = LRU()
def load(self, entries, *args, **kwargs): def load(self, entries, *args, **kwargs):
with timer("Created date filter index", logger): with timer("Created date filter index", logger):
for id, entry in enumerate(entries): for id, entry in enumerate(entries):
# Extract dates from entry # Extract dates from entry
for date_in_entry_string in re.findall(r'\d{4}-\d{2}-\d{2}', getattr(entry, self.entry_key)): for date_in_entry_string in re.findall(r"\d{4}-\d{2}-\d{2}", getattr(entry, self.entry_key)):
# Convert date string in entry to unix timestamp # Convert date string in entry to unix timestamp
try: try:
date_in_entry = datetime.strptime(date_in_entry_string, '%Y-%m-%d').timestamp() date_in_entry = datetime.strptime(date_in_entry_string, "%Y-%m-%d").timestamp()
except ValueError: except ValueError:
continue continue
self.date_to_entry_ids[date_in_entry].add(id) self.date_to_entry_ids[date_in_entry].add(id)
@ -49,7 +47,6 @@ class DateFilter(BaseFilter):
"Check if query contains date filters" "Check if query contains date filters"
return self.extract_date_range(raw_query) is not None return self.extract_date_range(raw_query) is not None
def apply(self, query, entries): def apply(self, query, entries):
"Find entries containing any dates that fall within date range specified in query" "Find entries containing any dates that fall within date range specified in query"
# extract date range specified in date filter of query # extract date range specified in date filter of query
@ -61,8 +58,8 @@ class DateFilter(BaseFilter):
return query, set(range(len(entries))) return query, set(range(len(entries)))
# remove date range filter from query # remove date range filter from query
query = re.sub(rf'\s+{self.date_regex}', ' ', query) query = re.sub(rf"\s+{self.date_regex}", " ", query)
query = re.sub(r'\s{2,}', ' ', query).strip() # remove multiple spaces query = re.sub(r"\s{2,}", " ", query).strip() # remove multiple spaces
# return results from cache if exists # return results from cache if exists
cache_key = tuple(query_daterange) cache_key = tuple(query_daterange)
@ -87,7 +84,6 @@ class DateFilter(BaseFilter):
return query, entries_to_include return query, entries_to_include
def extract_date_range(self, query): def extract_date_range(self, query):
# find date range filter in query # find date range filter in query
date_range_matches = re.findall(self.date_regex, query) date_range_matches = re.findall(self.date_regex, query)
@ -98,7 +94,7 @@ class DateFilter(BaseFilter):
# extract, parse natural dates ranges from date range filter passed in query # extract, parse natural dates ranges from date range filter passed in query
# e.g today maps to (start_of_day, start_of_tomorrow) # e.g today maps to (start_of_day, start_of_tomorrow)
date_ranges_from_filter = [] date_ranges_from_filter = []
for (cmp, date_str) in date_range_matches: for cmp, date_str in date_range_matches:
if self.parse(date_str): if self.parse(date_str):
dt_start, dt_end = self.parse(date_str) dt_start, dt_end = self.parse(date_str)
date_ranges_from_filter += [[cmp, (dt_start.timestamp(), dt_end.timestamp())]] date_ranges_from_filter += [[cmp, (dt_start.timestamp(), dt_end.timestamp())]]
@ -111,15 +107,15 @@ class DateFilter(BaseFilter):
effective_date_range = [0, inf] effective_date_range = [0, inf]
date_range_considering_comparator = [] date_range_considering_comparator = []
for cmp, (dtrange_start, dtrange_end) in date_ranges_from_filter: for cmp, (dtrange_start, dtrange_end) in date_ranges_from_filter:
if cmp == '>': if cmp == ">":
date_range_considering_comparator += [[dtrange_end, inf]] date_range_considering_comparator += [[dtrange_end, inf]]
elif cmp == '>=': elif cmp == ">=":
date_range_considering_comparator += [[dtrange_start, inf]] date_range_considering_comparator += [[dtrange_start, inf]]
elif cmp == '<': elif cmp == "<":
date_range_considering_comparator += [[0, dtrange_start]] date_range_considering_comparator += [[0, dtrange_start]]
elif cmp == '<=': elif cmp == "<=":
date_range_considering_comparator += [[0, dtrange_end]] date_range_considering_comparator += [[0, dtrange_end]]
elif cmp == '=' or cmp == ':' or cmp == '==': elif cmp == "=" or cmp == ":" or cmp == "==":
date_range_considering_comparator += [[dtrange_start, dtrange_end]] date_range_considering_comparator += [[dtrange_start, dtrange_end]]
# Combine above intervals (via AND/intersect) # Combine above intervals (via AND/intersect)
@ -129,48 +125,48 @@ class DateFilter(BaseFilter):
for date_range in date_range_considering_comparator: for date_range in date_range_considering_comparator:
effective_date_range = [ effective_date_range = [
max(effective_date_range[0], date_range[0]), max(effective_date_range[0], date_range[0]),
min(effective_date_range[1], date_range[1])] min(effective_date_range[1], date_range[1]),
]
if effective_date_range == [0, inf] or effective_date_range[0] > effective_date_range[1]: if effective_date_range == [0, inf] or effective_date_range[0] > effective_date_range[1]:
return None return None
else: else:
return effective_date_range return effective_date_range
def parse(self, date_str, relative_base=None): def parse(self, date_str, relative_base=None):
"Parse date string passed in date filter of query to datetime object" "Parse date string passed in date filter of query to datetime object"
# clean date string to handle future date parsing by date parser # clean date string to handle future date parsing by date parser
future_strings = ['later', 'from now', 'from today'] future_strings = ["later", "from now", "from today"]
prefer_dates_from = {True: 'future', False: 'past'}[any([True for fstr in future_strings if fstr in date_str])] prefer_dates_from = {True: "future", False: "past"}[any([True for fstr in future_strings if fstr in date_str])]
clean_date_str = re.sub('|'.join(future_strings), '', date_str) clean_date_str = re.sub("|".join(future_strings), "", date_str)
# parse date passed in query date filter # parse date passed in query date filter
parsed_date = dtparse.parse( parsed_date = dtparse.parse(
clean_date_str, clean_date_str,
settings= { settings={
'RELATIVE_BASE': relative_base or datetime.now(), "RELATIVE_BASE": relative_base or datetime.now(),
'PREFER_DAY_OF_MONTH': 'first', "PREFER_DAY_OF_MONTH": "first",
'PREFER_DATES_FROM': prefer_dates_from "PREFER_DATES_FROM": prefer_dates_from,
}) },
)
if parsed_date is None: if parsed_date is None:
return None return None
return self.date_to_daterange(parsed_date, date_str) return self.date_to_daterange(parsed_date, date_str)
def date_to_daterange(self, parsed_date, date_str): def date_to_daterange(self, parsed_date, date_str):
"Convert parsed date to date ranges at natural granularity (day, week, month or year)" "Convert parsed date to date ranges at natural granularity (day, week, month or year)"
start_of_day = parsed_date.replace(hour=0, minute=0, second=0, microsecond=0) start_of_day = parsed_date.replace(hour=0, minute=0, second=0, microsecond=0)
if 'year' in date_str: if "year" in date_str:
return (datetime(parsed_date.year, 1, 1, 0, 0, 0), datetime(parsed_date.year+1, 1, 1, 0, 0, 0)) return (datetime(parsed_date.year, 1, 1, 0, 0, 0), datetime(parsed_date.year + 1, 1, 1, 0, 0, 0))
if 'month' in date_str: if "month" in date_str:
start_of_month = datetime(parsed_date.year, parsed_date.month, 1, 0, 0, 0) start_of_month = datetime(parsed_date.year, parsed_date.month, 1, 0, 0, 0)
next_month = start_of_month + relativedelta(months=1) next_month = start_of_month + relativedelta(months=1)
return (start_of_month, next_month) return (start_of_month, next_month)
if 'week' in date_str: if "week" in date_str:
# if week in date string, dateparser parses it to next week start # if week in date string, dateparser parses it to next week start
# so today = end of this week # so today = end of this week
start_of_week = start_of_day - timedelta(days=7) start_of_week = start_of_day - timedelta(days=7)
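An illustrative sketch of the dt filter syntax handled above; the import paths, entries and expected output are assumptions, and the plain-date case is assumed to map to a single-day range.

    # Sketch of the date filter in use; entries, paths and query are illustrative.
    from khoj.search_filter.date_filter import DateFilter
    from khoj.utils.rawconfig import Entry

    entries = [
        Entry(compiled="Paid rent on 2023-01-02", raw="Paid rent on 2023-01-02", file="ledger.org"),
        Entry(compiled="Trip notes from 2021-06-15", raw="Trip notes from 2021-06-15", file="notes.org"),
    ]
    date_filter = DateFilter()
    date_filter.load(entries)

    query, included = date_filter.apply('rent dt>="2023-01-01" dt<"2023-02-01"', entries)
    # expected: query == "rent", included == {0}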
@ -16,7 +16,7 @@ logger = logging.getLogger(__name__)
class FileFilter(BaseFilter): class FileFilter(BaseFilter):
file_filter_regex = r'file:"(.+?)" ?' file_filter_regex = r'file:"(.+?)" ?'
def __init__(self, entry_key='file'): def __init__(self, entry_key="file"):
self.entry_key = entry_key self.entry_key = entry_key
self.file_to_entry_map = defaultdict(set) self.file_to_entry_map = defaultdict(set)
self.cache = LRU() self.cache = LRU()
@ -40,13 +40,13 @@ class FileFilter(BaseFilter):
# e.g. "file:notes.org" -> "file:.*notes.org" # e.g. "file:notes.org" -> "file:.*notes.org"
files_to_search = [] files_to_search = []
for file in sorted(raw_files_to_search): for file in sorted(raw_files_to_search):
if '/' not in file and '\\' not in file and '*' not in file: if "/" not in file and "\\" not in file and "*" not in file:
files_to_search += [f'*{file}'] files_to_search += [f"*{file}"]
else: else:
files_to_search += [file] files_to_search += [file]
# Return item from cache if exists # Return item from cache if exists
query = re.sub(self.file_filter_regex, '', query).strip() query = re.sub(self.file_filter_regex, "", query).strip()
cache_key = tuple(files_to_search) cache_key = tuple(files_to_search)
if cache_key in self.cache: if cache_key in self.cache:
logger.info(f"Return file filter results from cache") logger.info(f"Return file filter results from cache")
@ -58,10 +58,15 @@ class FileFilter(BaseFilter):
# Mark entries that contain any blocked_words for exclusion # Mark entries that contain any blocked_words for exclusion
with timer("Mark entries satisfying filter", logger): with timer("Mark entries satisfying filter", logger):
included_entry_indices = set.union(*[self.file_to_entry_map[entry_file] included_entry_indices = set.union(
*[
self.file_to_entry_map[entry_file]
for entry_file in self.file_to_entry_map.keys() for entry_file in self.file_to_entry_map.keys()
for search_file in files_to_search for search_file in files_to_search
if fnmatch.fnmatch(entry_file, search_file)], set()) if fnmatch.fnmatch(entry_file, search_file)
],
set(),
)
if not included_entry_indices: if not included_entry_indices:
return query, {} return query, {}
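An illustrative sketch of the file filter in use; a bare file name in file:"..." is turned into a glob, so it matches entries from any directory. Import paths and entries are assumptions.

    # Sketch of the file filter; imports and entry contents are illustrative.
    from khoj.search_filter.file_filter import FileFilter
    from khoj.utils.rawconfig import Entry

    entries = [
        Entry(compiled="* Rent", raw="* Rent", file="/home/user/ledger.org"),
        Entry(compiled="* Ideas", raw="* Ideas", file="/home/user/notes.org"),
    ]
    file_filter = FileFilter()
    file_filter.load(entries)

    query, included = file_filter.apply('rent file:"ledger.org"', entries)
    # expected: query == "rent", included == {0}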
@ -17,26 +17,26 @@ class WordFilter(BaseFilter):
required_regex = r'\+"([a-zA-Z0-9_-]+)" ?' required_regex = r'\+"([a-zA-Z0-9_-]+)" ?'
blocked_regex = r'\-"([a-zA-Z0-9_-]+)" ?' blocked_regex = r'\-"([a-zA-Z0-9_-]+)" ?'
def __init__(self, entry_key='raw'): def __init__(self, entry_key="raw"):
self.entry_key = entry_key self.entry_key = entry_key
self.word_to_entry_index = defaultdict(set) self.word_to_entry_index = defaultdict(set)
self.cache = LRU() self.cache = LRU()
def load(self, entries, *args, **kwargs): def load(self, entries, *args, **kwargs):
with timer("Created word filter index", logger): with timer("Created word filter index", logger):
self.cache = {} # Clear cache on filter (re-)load self.cache = {} # Clear cache on filter (re-)load
entry_splitter = r',|\.| |\]|\[\(|\)|\{|\}|\<|\>|\t|\n|\:|\;|\?|\!|\(|\)|\&|\^|\$|\@|\%|\+|\=|\/|\\|\||\~|\`|\"|\'' entry_splitter = (
r",|\.| |\]|\[\(|\)|\{|\}|\<|\>|\t|\n|\:|\;|\?|\!|\(|\)|\&|\^|\$|\@|\%|\+|\=|\/|\\|\||\~|\`|\"|\'"
)
# Create map of words to entries they exist in # Create map of words to entries they exist in
for entry_index, entry in enumerate(entries): for entry_index, entry in enumerate(entries):
for word in re.split(entry_splitter, getattr(entry, self.entry_key).lower()): for word in re.split(entry_splitter, getattr(entry, self.entry_key).lower()):
if word == '': if word == "":
continue continue
self.word_to_entry_index[word].add(entry_index) self.word_to_entry_index[word].add(entry_index)
return self.word_to_entry_index return self.word_to_entry_index
def can_filter(self, raw_query): def can_filter(self, raw_query):
"Check if query contains word filters" "Check if query contains word filters"
required_words = re.findall(self.required_regex, raw_query) required_words = re.findall(self.required_regex, raw_query)
@ -44,14 +44,13 @@ class WordFilter(BaseFilter):
return len(required_words) != 0 or len(blocked_words) != 0 return len(required_words) != 0 or len(blocked_words) != 0
def apply(self, query, entries): def apply(self, query, entries):
"Find entries containing required and not blocked words specified in query" "Find entries containing required and not blocked words specified in query"
# Separate natural query from required, blocked words filters # Separate natural query from required, blocked words filters
with timer("Extract required, blocked filters from query", logger): with timer("Extract required, blocked filters from query", logger):
required_words = set([word.lower() for word in re.findall(self.required_regex, query)]) required_words = set([word.lower() for word in re.findall(self.required_regex, query)])
blocked_words = set([word.lower() for word in re.findall(self.blocked_regex, query)]) blocked_words = set([word.lower() for word in re.findall(self.blocked_regex, query)])
query = re.sub(self.blocked_regex, '', re.sub(self.required_regex, '', query)).strip() query = re.sub(self.blocked_regex, "", re.sub(self.required_regex, "", query)).strip()
if len(required_words) == 0 and len(blocked_words) == 0: if len(required_words) == 0 and len(blocked_words) == 0:
return query, set(range(len(entries))) return query, set(range(len(entries)))
@ -70,12 +69,16 @@ class WordFilter(BaseFilter):
with timer("Mark entries satisfying filter", logger): with timer("Mark entries satisfying filter", logger):
entries_with_all_required_words = set(range(len(entries))) entries_with_all_required_words = set(range(len(entries)))
if len(required_words) > 0: if len(required_words) > 0:
entries_with_all_required_words = set.intersection(*[self.word_to_entry_index.get(word, set()) for word in required_words]) entries_with_all_required_words = set.intersection(
*[self.word_to_entry_index.get(word, set()) for word in required_words]
)
# mark entries that contain any blocked_words for exclusion # mark entries that contain any blocked_words for exclusion
entries_with_any_blocked_words = set() entries_with_any_blocked_words = set()
if len(blocked_words) > 0: if len(blocked_words) > 0:
entries_with_any_blocked_words = set.union(*[self.word_to_entry_index.get(word, set()) for word in blocked_words]) entries_with_any_blocked_words = set.union(
*[self.word_to_entry_index.get(word, set()) for word in blocked_words]
)
# get entries satisfying inclusion and exclusion filters # get entries satisfying inclusion and exclusion filters
included_entry_indices = entries_with_all_required_words - entries_with_any_blocked_words included_entry_indices = entries_with_all_required_words - entries_with_any_blocked_words
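And the word filter for comparison, again as a sketch with made-up imports and entries: +"..." requires a word, -"..." excludes entries containing a word.

    # Sketch of the word filter; imports and entry contents are illustrative.
    from khoj.search_filter.word_filter import WordFilter
    from khoj.utils.rawconfig import Entry

    entries = [
        Entry(compiled="Khoj search demo", raw="Khoj search demo", file="notes.org"),
        Entry(compiled="Khoj chat demo", raw="Khoj chat demo", file="notes.org"),
    ]
    word_filter = WordFilter()
    word_filter.load(entries)

    query, included = word_filter.apply('demo +"search" -"chat"', entries)
    # expected: query == "demo", included == {0}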
@ -35,9 +35,10 @@ def initialize_model(search_config: ImageSearchConfig):
# Load the CLIP model # Load the CLIP model
encoder = load_model( encoder = load_model(
model_dir = search_config.model_directory, model_dir=search_config.model_directory,
model_name = search_config.encoder, model_name=search_config.encoder,
model_type = search_config.encoder_type or SentenceTransformer) model_type=search_config.encoder_type or SentenceTransformer,
)
return encoder return encoder
@ -46,12 +47,12 @@ def extract_entries(image_directories):
image_names = [] image_names = []
for image_directory in image_directories: for image_directory in image_directories:
image_directory = resolve_absolute_path(image_directory, strict=True) image_directory = resolve_absolute_path(image_directory, strict=True)
image_names.extend(list(image_directory.glob('*.jpg'))) image_names.extend(list(image_directory.glob("*.jpg")))
image_names.extend(list(image_directory.glob('*.jpeg'))) image_names.extend(list(image_directory.glob("*.jpeg")))
if logger.level >= logging.INFO: if logger.level >= logging.INFO:
image_directory_names = ', '.join([str(image_directory) for image_directory in image_directories]) image_directory_names = ", ".join([str(image_directory) for image_directory in image_directories])
logger.info(f'Found {len(image_names)} images in {image_directory_names}') logger.info(f"Found {len(image_names)} images in {image_directory_names}")
return sorted(image_names) return sorted(image_names)
@ -59,7 +60,9 @@ def compute_embeddings(image_names, encoder, embeddings_file, batch_size=50, use
"Compute (and Save) Embeddings or Load Pre-Computed Embeddings" "Compute (and Save) Embeddings or Load Pre-Computed Embeddings"
image_embeddings = compute_image_embeddings(image_names, encoder, embeddings_file, batch_size, regenerate) image_embeddings = compute_image_embeddings(image_names, encoder, embeddings_file, batch_size, regenerate)
image_metadata_embeddings = compute_metadata_embeddings(image_names, encoder, embeddings_file, batch_size, use_xmp_metadata, regenerate) image_metadata_embeddings = compute_metadata_embeddings(
image_names, encoder, embeddings_file, batch_size, use_xmp_metadata, regenerate
)
return image_embeddings, image_metadata_embeddings return image_embeddings, image_metadata_embeddings
@ -74,15 +77,12 @@ def compute_image_embeddings(image_names, encoder, embeddings_file, batch_size=5
image_embeddings = [] image_embeddings = []
for index in trange(0, len(image_names), batch_size): for index in trange(0, len(image_names), batch_size):
images = [] images = []
for image_name in image_names[index:index+batch_size]: for image_name in image_names[index : index + batch_size]:
image = Image.open(image_name) image = Image.open(image_name)
# Resize images to max width of 640px for faster processing # Resize images to max width of 640px for faster processing
image.thumbnail((640, image.height)) image.thumbnail((640, image.height))
images += [image] images += [image]
image_embeddings += encoder.encode( image_embeddings += encoder.encode(images, convert_to_tensor=True, batch_size=min(len(images), batch_size))
images,
convert_to_tensor=True,
batch_size=min(len(images), batch_size))
# Create directory for embeddings file, if it doesn't exist # Create directory for embeddings file, if it doesn't exist
embeddings_file.parent.mkdir(parents=True, exist_ok=True) embeddings_file.parent.mkdir(parents=True, exist_ok=True)
@ -94,7 +94,9 @@ def compute_image_embeddings(image_names, encoder, embeddings_file, batch_size=5
return image_embeddings return image_embeddings
def compute_metadata_embeddings(image_names, encoder, embeddings_file, batch_size=50, use_xmp_metadata=False, regenerate=False, verbose=0): def compute_metadata_embeddings(
image_names, encoder, embeddings_file, batch_size=50, use_xmp_metadata=False, regenerate=False, verbose=0
):
image_metadata_embeddings = None image_metadata_embeddings = None
# Load pre-computed image metadata embedding file if exists # Load pre-computed image metadata embedding file if exists
@ -106,14 +108,17 @@ def compute_metadata_embeddings(image_names, encoder, embeddings_file, batch_siz
if use_xmp_metadata and image_metadata_embeddings is None: if use_xmp_metadata and image_metadata_embeddings is None:
image_metadata_embeddings = [] image_metadata_embeddings = []
for index in trange(0, len(image_names), batch_size): for index in trange(0, len(image_names), batch_size):
image_metadata = [extract_metadata(image_name, verbose) for image_name in image_names[index:index+batch_size]] image_metadata = [
extract_metadata(image_name, verbose) for image_name in image_names[index : index + batch_size]
]
try: try:
image_metadata_embeddings += encoder.encode( image_metadata_embeddings += encoder.encode(
image_metadata, image_metadata, convert_to_tensor=True, batch_size=min(len(image_metadata), batch_size)
convert_to_tensor=True, )
batch_size=min(len(image_metadata), batch_size))
except RuntimeError as e: except RuntimeError as e:
logger.error(f"Error encoding metadata for images starting from\n\tindex: {index},\n\timages: {image_names[index:index+batch_size]}\nException: {e}") logger.error(
f"Error encoding metadata for images starting from\n\tindex: {index},\n\timages: {image_names[index:index+batch_size]}\nException: {e}"
)
continue continue
torch.save(image_metadata_embeddings, f"{embeddings_file}_metadata") torch.save(image_metadata_embeddings, f"{embeddings_file}_metadata")
logger.info(f"Saved computed metadata embeddings to {embeddings_file}_metadata") logger.info(f"Saved computed metadata embeddings to {embeddings_file}_metadata")
@ -123,8 +128,10 @@ def compute_metadata_embeddings(image_names, encoder, embeddings_file, batch_siz
def extract_metadata(image_name): def extract_metadata(image_name):
image_xmp_metadata = Image.open(image_name).getxmp() image_xmp_metadata = Image.open(image_name).getxmp()
image_description = get_from_dict(image_xmp_metadata, 'xmpmeta', 'RDF', 'Description', 'description', 'Alt', 'li', 'text') image_description = get_from_dict(
image_subjects = get_from_dict(image_xmp_metadata, 'xmpmeta', 'RDF', 'Description', 'subject', 'Bag', 'li') image_xmp_metadata, "xmpmeta", "RDF", "Description", "description", "Alt", "li", "text"
)
image_subjects = get_from_dict(image_xmp_metadata, "xmpmeta", "RDF", "Description", "subject", "Bag", "li")
image_metadata_subjects = set([subject.split(":")[1] for subject in image_subjects if ":" in subject]) image_metadata_subjects = set([subject.split(":")[1] for subject in image_subjects if ":" in subject])
image_processed_metadata = image_description image_processed_metadata = image_description
@ -141,7 +148,7 @@ def query(raw_query, count, model: ImageSearchModel):
if raw_query.startswith("file:") and pathlib.Path(raw_query[5:]).is_file(): if raw_query.startswith("file:") and pathlib.Path(raw_query[5:]).is_file():
query_imagepath = resolve_absolute_path(pathlib.Path(raw_query[5:]), strict=True) query_imagepath = resolve_absolute_path(pathlib.Path(raw_query[5:]), strict=True)
query = copy.deepcopy(Image.open(query_imagepath)) query = copy.deepcopy(Image.open(query_imagepath))
query.thumbnail((640, query.height)) # scale down image for faster processing query.thumbnail((640, query.height)) # scale down image for faster processing
logger.info(f"Find Images by Image: {query_imagepath}") logger.info(f"Find Images by Image: {query_imagepath}")
else: else:
# Truncate words in query to stay below max_tokens supported by ML model # Truncate words in query to stay below max_tokens supported by ML model
@ -155,36 +162,42 @@ def query(raw_query, count, model: ImageSearchModel):
# Compute top_k ranked images based on cosine-similarity b/w query and all image embeddings. # Compute top_k ranked images based on cosine-similarity b/w query and all image embeddings.
with timer("Search Time", logger): with timer("Search Time", logger):
image_hits = {result['corpus_id']: {'image_score': result['score'], 'score': result['score']} image_hits = {
for result result["corpus_id"]: {"image_score": result["score"], "score": result["score"]}
in util.semantic_search(query_embedding, model.image_embeddings, top_k=count)[0]} for result in util.semantic_search(query_embedding, model.image_embeddings, top_k=count)[0]
}
# Compute top_k ranked images based on cosine-similarity b/w query and all image metadata embeddings. # Compute top_k ranked images based on cosine-similarity b/w query and all image metadata embeddings.
if model.image_metadata_embeddings: if model.image_metadata_embeddings:
with timer("Metadata Search Time", logger): with timer("Metadata Search Time", logger):
metadata_hits = {result['corpus_id']: result['score'] metadata_hits = {
for result result["corpus_id"]: result["score"]
in util.semantic_search(query_embedding, model.image_metadata_embeddings, top_k=count)[0]} for result in util.semantic_search(query_embedding, model.image_metadata_embeddings, top_k=count)[0]
}
# Sum metadata, image scores of the highest ranked images # Sum metadata, image scores of the highest ranked images
for corpus_id, score in metadata_hits.items(): for corpus_id, score in metadata_hits.items():
scaling_factor = 0.33 scaling_factor = 0.33
if 'corpus_id' in image_hits: if "corpus_id" in image_hits:
image_hits[corpus_id].update({ image_hits[corpus_id].update(
'metadata_score': score, {
'score': image_hits[corpus_id].get('score', 0) + scaling_factor*score, "metadata_score": score,
}) "score": image_hits[corpus_id].get("score", 0) + scaling_factor * score,
}
)
else: else:
image_hits[corpus_id] = {'metadata_score': score, 'score': scaling_factor*score} image_hits[corpus_id] = {"metadata_score": score, "score": scaling_factor * score}
# Reformat results in original form from sentence transformer semantic_search() # Reformat results in original form from sentence transformer semantic_search()
hits = [ hits = [
{ {
'corpus_id': corpus_id, "corpus_id": corpus_id,
'score': scores['score'], "score": scores["score"],
'image_score': scores.get('image_score', 0), "image_score": scores.get("image_score", 0),
'metadata_score': scores.get('metadata_score', 0), "metadata_score": scores.get("metadata_score", 0),
} for corpus_id, scores in image_hits.items()] }
for corpus_id, scores in image_hits.items()
]
# Sort the images based on their combined metadata, image scores # Sort the images based on their combined metadata, image scores
return sorted(hits, key=lambda hit: hit["score"], reverse=True) return sorted(hits, key=lambda hit: hit["score"], reverse=True)
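To make the ranking rule concrete, a toy recomputation of the combined score used above; the similarity values are made up.

    # Toy score combination: metadata similarity is down-weighted by the 0.33
    # scaling factor before being added to the image similarity.
    scaling_factor = 0.33
    image_score = 0.62
    metadata_score = 0.45
    combined_score = image_score + scaling_factor * metadata_score
    print(round(combined_score, 4))  # 0.7685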
@ -194,7 +207,7 @@ def collate_results(hits, image_names, output_directory, image_files_url, count=
results: List[SearchResponse] = [] results: List[SearchResponse] = []
for index, hit in enumerate(hits[:count]): for index, hit in enumerate(hits[:count]):
source_path = image_names[hit['corpus_id']] source_path = image_names[hit["corpus_id"]]
target_image_name = f"{index}{source_path.suffix}" target_image_name = f"{index}{source_path.suffix}"
target_path = resolve_absolute_path(f"{output_directory}/{target_image_name}") target_path = resolve_absolute_path(f"{output_directory}/{target_image_name}")
@@ -207,17 +220,18 @@ def collate_results(hits, image_names, output_directory, image_files_url, count=
        shutil.copy(source_path, target_path)

        # Add the image metadata to the results
        results += [
            SearchResponse.parse_obj(
                {
                    "entry": f"{image_files_url}/{target_image_name}",
                    "score": f"{hit['score']:.9f}",
                    "additional": {
                        "image_score": f"{hit['image_score']:.9f}",
                        "metadata_score": f"{hit['metadata_score']:.9f}",
                    },
                }
            )
        ]

    return results
@@ -248,9 +262,7 @@ def setup(config: ImageContentConfig, search_config: ImageSearchConfig, regenera
        embeddings_file,
        batch_size=config.batch_size,
        regenerate=regenerate,
        use_xmp_metadata=config.use_xmp_metadata,
    )

    return ImageSearchModel(all_image_files, image_embeddings, image_metadata_embeddings, encoder)

View file

@@ -38,17 +38,19 @@ def initialize_model(search_config: TextSearchConfig):
    # The bi-encoder encodes all entries to use for semantic search
    bi_encoder = load_model(
        model_dir=search_config.model_directory,
        model_name=search_config.encoder,
        model_type=search_config.encoder_type or SentenceTransformer,
        device=f"{state.device}",
    )

    # The cross-encoder re-ranks the results to improve quality
    cross_encoder = load_model(
        model_dir=search_config.model_directory,
        model_name=search_config.cross_encoder,
        model_type=CrossEncoder,
        device=f"{state.device}",
    )

    return bi_encoder, cross_encoder, top_k
@@ -58,7 +60,9 @@ def extract_entries(jsonl_file) -> List[Entry]:
    return list(map(Entry.from_dict, load_jsonl(jsonl_file)))


def compute_embeddings(
    entries_with_ids: List[Tuple[int, Entry]], bi_encoder: BaseEncoder, embeddings_file: Path, regenerate=False
):
    "Compute (and Save) Embeddings or Load Pre-Computed Embeddings"
    new_entries = []
    # Load pre-computed embeddings from file if exists and update them if required
@@ -69,17 +73,23 @@ def compute_embeddings(entries_with_ids: List[Tuple[int, Entry]], bi_encoder: Ba
        # Encode any new entries in the corpus and update corpus embeddings
        new_entries = [entry.compiled for id, entry in entries_with_ids if id == -1]
        if new_entries:
            new_embeddings = bi_encoder.encode(
                new_entries, convert_to_tensor=True, device=state.device, show_progress_bar=True
            )
            existing_entry_ids = [id for id, _ in entries_with_ids if id != -1]
            if existing_entry_ids:
                existing_embeddings = torch.index_select(
                    corpus_embeddings, 0, torch.tensor(existing_entry_ids, device=state.device)
                )
            else:
                existing_embeddings = torch.tensor([], device=state.device)
            corpus_embeddings = torch.cat([existing_embeddings, new_embeddings], dim=0)
    # Else compute the corpus embeddings from scratch
    else:
        new_entries = [entry.compiled for _, entry in entries_with_ids]
        corpus_embeddings = bi_encoder.encode(
            new_entries, convert_to_tensor=True, device=state.device, show_progress_bar=True
        )

    # Save regenerated or updated embeddings to file
    if new_entries:
@@ -112,7 +122,9 @@ def query(raw_query: str, model: TextSearchModel, rank_results: bool = False) ->
    # Find relevant entries for the query
    with timer("Search Time", logger, state.device):
        hits = util.semantic_search(
            question_embedding, corpus_embeddings, top_k=model.top_k, score_function=util.dot_score
        )[0]

    # Score all retrieved entries using the cross-encoder
    if rank_results:
@@ -128,26 +140,33 @@ def query(raw_query: str, model: TextSearchModel, rank_results: bool = False) ->
def collate_results(hits, entries: List[Entry], count=5) -> List[SearchResponse]:
    return [
        SearchResponse.parse_obj(
            {
                "entry": entries[hit["corpus_id"]].raw,
                "score": f"{hit['cross-score'] if 'cross-score' in hit else hit['score']:.3f}",
                "additional": {"file": entries[hit["corpus_id"]].file, "compiled": entries[hit["corpus_id"]].compiled},
            }
        )
        for hit in hits[0:count]
    ]


def setup(
    text_to_jsonl: Type[TextToJsonl],
    config: TextContentConfig,
    search_config: TextSearchConfig,
    regenerate: bool,
    filters: List[BaseFilter] = [],
) -> TextSearchModel:
    # Initialize Model
    bi_encoder, cross_encoder, top_k = initialize_model(search_config)

    # Map notes in text files to (compressed) JSONL formatted file
    config.compressed_jsonl = resolve_absolute_path(config.compressed_jsonl)
    previous_entries = (
        extract_entries(config.compressed_jsonl) if config.compressed_jsonl.exists() and not regenerate else None
    )
    entries_with_indices = text_to_jsonl(config).process(previous_entries)

    # Extract Updated Entries
@@ -158,7 +177,9 @@ def setup(text_to_jsonl: Type[TextToJsonl], config: TextContentConfig, search_co
    # Compute or Load Embeddings
    config.embeddings_file = resolve_absolute_path(config.embeddings_file)
    corpus_embeddings = compute_embeddings(
        entries_with_indices, bi_encoder, config.embeddings_file, regenerate=regenerate
    )

    for filter in filters:
        filter.load(entries, regenerate=regenerate)
@@ -166,8 +187,10 @@ def setup(text_to_jsonl: Type[TextToJsonl], config: TextContentConfig, search_co
    return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, filters, top_k)


def apply_filters(
    query: str, entries: List[Entry], corpus_embeddings: torch.Tensor, filters: List[BaseFilter]
) -> Tuple[str, List[Entry], torch.Tensor]:
    """Filter query, entries and embeddings before semantic search"""
    with timer("Total Filter Time", logger, state.device):
        included_entry_indices = set(range(len(entries)))
@@ -178,45 +201,50 @@ def apply_filters(query: str, entries: List[Entry], corpus_embeddings: torch.Ten
        # Get entries (and associated embeddings) satisfying all filters
        if not included_entry_indices:
            return "", [], torch.tensor([], device=state.device)
        else:
            entries = [entries[id] for id in included_entry_indices]
            corpus_embeddings = torch.index_select(
                corpus_embeddings, 0, torch.tensor(list(included_entry_indices), device=state.device)
            )

    return query, entries, corpus_embeddings


def cross_encoder_score(cross_encoder: CrossEncoder, query: str, entries: List[Entry], hits: List[dict]) -> List[dict]:
    """Score all retrieved entries using the cross-encoder"""
    with timer("Cross-Encoder Predict Time", logger, state.device):
        cross_inp = [[query, entries[hit["corpus_id"]].compiled] for hit in hits]
        cross_scores = cross_encoder.predict(cross_inp)

    # Store cross-encoder scores in results dictionary for ranking
    for idx in range(len(cross_scores)):
        hits[idx]["cross-score"] = cross_scores[idx]

    return hits


def sort_results(rank_results: bool, hits: List[dict]) -> List[dict]:
    """Order results by cross-encoder score followed by bi-encoder score"""
    with timer("Rank Time", logger, state.device):
        hits.sort(key=lambda x: x["score"], reverse=True)  # sort by bi-encoder score
        if rank_results:
            hits.sort(key=lambda x: x["cross-score"], reverse=True)  # sort by cross-encoder score
    return hits


def deduplicate_results(entries: List[Entry], hits: List[dict]) -> List[dict]:
    """Deduplicate entries by raw entry text before showing to users
    Compiled entries are split by max tokens supported by ML models.
    This can result in duplicate hits, entries shown to user."""
    with timer("Deduplication Time", logger, state.device):
        seen, original_hits_count = set(), len(hits)
        hits = [
            hit
            for hit in hits
            if entries[hit["corpus_id"]].raw not in seen and not seen.add(entries[hit["corpus_id"]].raw)  # type: ignore[func-returns-value]
        ]

        duplicate_hits = original_hits_count - len(hits)
        logger.debug(f"Removed {duplicate_hits} duplicates")

View file

@@ -10,21 +10,36 @@ from khoj.utils.yaml import parse_config_from_file
def cli(args=None):
    # Setup Argument Parser for the Commandline Interface
    parser = argparse.ArgumentParser(
        description="Start Khoj; A Natural Language Search Engine for your personal Notes, Transactions and Photos"
    )
    parser.add_argument(
        "--config-file", "-c", default="~/.khoj/khoj.yml", type=pathlib.Path, help="YAML file to configure Khoj"
    )
    parser.add_argument(
        "--no-gui", action="store_true", default=False, help="Do not show native desktop GUI. Default: false"
    )
    parser.add_argument(
        "--regenerate",
        action="store_true",
        default=False,
        help="Regenerate model embeddings from source files. Default: false",
    )
    parser.add_argument("--verbose", "-v", action="count", default=0, help="Show verbose conversion logs. Default: 0")
    parser.add_argument("--host", type=str, default="127.0.0.1", help="Host address of the server. Default: 127.0.0.1")
    parser.add_argument("--port", "-p", type=int, default=8000, help="Port of the server. Default: 8000")
    parser.add_argument(
        "--socket",
        type=pathlib.Path,
        help="Path to UNIX socket for server. Use to run server behind reverse proxy. Default: /tmp/uvicorn.sock",
    )
    parser.add_argument("--version", "-V", action="store_true", help="Print the installed Khoj version and exit")

    args = parser.parse_args(args)

    if args.version:
        # Show version of khoj installed and exit
        print(version("khoj-assistant"))
        exit(0)

    # Normalize config_file path to absolute path
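For a quick sanity check of the flags reflowed above, the parser can be exercised directly from Python, mirroring how the CLI tests call it later in this diff. A sketch only; the khoj.utils.cli import path and the flag values are assumptions:

# Illustrative only: exercise the argument parser with arbitrary flag values.
from khoj.utils.cli import cli  # import path assumed

args = cli(["-c=/tmp/does-not-exist-khoj.yml", "--no-gui", "--port=8080", "-vv"])
print(args.config_file, args.no_gui, args.port, args.verbose)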

View file

@@ -28,8 +28,16 @@ class ProcessorType(str, Enum):
    Conversation = "conversation"


class TextSearchModel:
    def __init__(
        self,
        entries: List[Entry],
        corpus_embeddings: torch.Tensor,
        bi_encoder: BaseEncoder,
        cross_encoder: CrossEncoder,
        filters: List[BaseFilter],
        top_k,
    ):
        self.entries = entries
        self.corpus_embeddings = corpus_embeddings
        self.bi_encoder = bi_encoder
@@ -38,7 +46,7 @@ class TextSearchModel():
        self.top_k = top_k


class ImageSearchModel:
    def __init__(self, image_names, image_embeddings, image_metadata_embeddings, image_encoder: BaseEncoder):
        self.image_encoder = image_encoder
        self.image_names = image_names
@@ -48,7 +56,7 @@ class ImageSearchModel():
@dataclass
class SearchModels:
    orgmode_search: TextSearchModel = None
    ledger_search: TextSearchModel = None
    music_search: TextSearchModel = None
@@ -56,15 +64,15 @@ class SearchModels():
    image_search: ImageSearchModel = None


class ConversationProcessorConfigModel:
    def __init__(self, processor_config: ConversationProcessorConfig):
        self.openai_api_key = processor_config.openai_api_key
        self.model = processor_config.model
        self.conversation_logfile = Path(processor_config.conversation_logfile)
        self.chat_session = ""
        self.meta_log: dict = {}


@dataclass
class ProcessorConfigModel:
    conversation: ConversationProcessorConfigModel = None

View file

@@ -1,65 +1,62 @@
from pathlib import Path

app_root_directory = Path(__file__).parent.parent.parent
web_directory = app_root_directory / "khoj/interface/web/"
empty_escape_sequences = "\n|\r|\t| "

# default app config to use
default_config = {
    "content-type": {
        "org": {
            "input-files": None,
            "input-filter": None,
            "compressed-jsonl": "~/.khoj/content/org/org.jsonl.gz",
            "embeddings-file": "~/.khoj/content/org/org_embeddings.pt",
            "index_heading_entries": False,
        },
        "markdown": {
            "input-files": None,
            "input-filter": None,
            "compressed-jsonl": "~/.khoj/content/markdown/markdown.jsonl.gz",
            "embeddings-file": "~/.khoj/content/markdown/markdown_embeddings.pt",
        },
        "ledger": {
            "input-files": None,
            "input-filter": None,
            "compressed-jsonl": "~/.khoj/content/ledger/ledger.jsonl.gz",
            "embeddings-file": "~/.khoj/content/ledger/ledger_embeddings.pt",
        },
        "image": {
            "input-directories": None,
            "input-filter": None,
            "embeddings-file": "~/.khoj/content/image/image_embeddings.pt",
            "batch-size": 50,
            "use-xmp-metadata": False,
        },
        "music": {
            "input-files": None,
            "input-filter": None,
            "compressed-jsonl": "~/.khoj/content/music/music.jsonl.gz",
            "embeddings-file": "~/.khoj/content/music/music_embeddings.pt",
        },
    },
    "search-type": {
        "symmetric": {
            "encoder": "sentence-transformers/all-MiniLM-L6-v2",
            "cross-encoder": "cross-encoder/ms-marco-MiniLM-L-6-v2",
            "model_directory": "~/.khoj/search/symmetric/",
        },
        "asymmetric": {
            "encoder": "sentence-transformers/multi-qa-MiniLM-L6-cos-v1",
            "cross-encoder": "cross-encoder/ms-marco-MiniLM-L-6-v2",
            "model_directory": "~/.khoj/search/asymmetric/",
        },
        "image": {"encoder": "sentence-transformers/clip-ViT-B-32", "model_directory": "~/.khoj/search/image/"},
    },
    "processor": {
        "conversation": {
            "openai-api-key": None,
            "conversation-logfile": "~/.khoj/processor/conversation/conversation_logs.json",
        }
    },
}

View file

@@ -13,16 +13,17 @@ from typing import Optional, Union, TYPE_CHECKING
if TYPE_CHECKING:
    # External Packages
    from sentence_transformers import CrossEncoder

    # Internal Packages
    from khoj.utils.models import BaseEncoder


def is_none_or_empty(item):
    return item == None or (hasattr(item, "__iter__") and len(item) == 0) or item == ""


def to_snake_case_from_dash(item: str):
    return item.replace("_", "-")


def get_absolute_path(filepath: Union[str, Path]) -> str:
@@ -34,11 +35,11 @@ def resolve_absolute_path(filepath: Union[str, Optional[Path]], strict=False) ->
def get_from_dict(dictionary, *args):
    """null-aware get from a nested dictionary
    Returns: dictionary[args[0]][args[1]]... or None if any keys missing"""
    current = dictionary
    for arg in args:
        if not hasattr(current, "__iter__") or not arg in current:
            return None
        current = current[arg]
    return current
@@ -54,7 +55,7 @@ def merge_dicts(priority_dict: dict, default_dict: dict):
    return merged_dict


def load_model(model_name: str, model_type, model_dir=None, device: str = None) -> Union[BaseEncoder, CrossEncoder]:
    "Load model from disk or huggingface"
    # Construct model path
    model_path = join(model_dir, model_name.replace("/", "_")) if model_dir is not None else None
@@ -74,17 +75,18 @@ def load_model(model_name: str, model_type, model_dir=None, device:str=None) ->
def is_pyinstaller_app():
    "Returns true if the app is running from Native GUI created by PyInstaller"
    return getattr(sys, "frozen", False) and hasattr(sys, "_MEIPASS")


def get_class_by_name(name: str) -> object:
    "Returns the class object from name string"
    module_name, class_name = name.rsplit(".", 1)
    return getattr(import_module(module_name), class_name)


class timer:
    """Context manager to log time taken for a block of code to run"""

    def __init__(self, message: str, logger: logging.Logger, device: torch.device = None):
        self.message = message
        self.logger = logger
@@ -116,4 +118,4 @@ class LRU(OrderedDict):
        super().__setitem__(key, value)
        if len(self) > self.capacity:
            oldest = next(iter(self))
            del self[oldest]
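The LRU hunk above only shows the eviction check in __setitem__. A usage sketch of that behaviour follows; the capacity keyword argument and the import path are assumptions, since the constructor is not part of this hunk:

# Illustrative only: the oldest key is evicted once the cache grows past its capacity.
from khoj.utils.helpers import LRU  # import path assumed

cache = LRU(capacity=2)  # capacity kwarg assumed
cache["a"] = 1
cache["b"] = 2
cache["c"] = 3  # len(cache) > capacity, so the oldest key ("a") is dropped

assert "a" not in cache and list(cache) == ["b", "c"]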

View file

@@ -19,9 +19,9 @@ def load_jsonl(input_path):
    # Open JSONL file
    if input_path.suffix == ".gz":
        jsonl_file = gzip.open(get_absolute_path(input_path), "rt", encoding="utf-8")
    elif input_path.suffix == ".jsonl":
        jsonl_file = open(get_absolute_path(input_path), "r", encoding="utf-8")

    # Read JSONL file
    for line in jsonl_file:
@@ -31,7 +31,7 @@ def load_jsonl(input_path):
    jsonl_file.close()

    # Log JSONL entries loaded
    logger.info(f"Loaded {len(data)} records from {input_path}")

    return data
@@ -41,17 +41,17 @@ def dump_jsonl(jsonl_data, output_path):
    # Create output directory, if it doesn't exist
    output_path.parent.mkdir(parents=True, exist_ok=True)

    with open(output_path, "w", encoding="utf-8") as f:
        f.write(jsonl_data)

    logger.info(f"Wrote jsonl data to {output_path}")


def compress_jsonl_data(jsonl_data, output_path):
    # Create output directory, if it doesn't exist
    output_path.parent.mkdir(parents=True, exist_ok=True)

    with gzip.open(output_path, "wt", encoding="utf-8") as gzip_file:
        gzip_file.write(jsonl_data)

    logger.info(f"Wrote jsonl data to gzip compressed jsonl at {output_path}")

View file

@@ -13,17 +13,25 @@ from khoj.utils.state import processor_config, config_file
class BaseEncoder(ABC):
    @abstractmethod
    def __init__(self, model_name: str, device: torch.device = None, **kwargs):
        ...

    @abstractmethod
    def encode(self, entries: List[str], device: torch.device = None, **kwargs) -> torch.Tensor:
        ...


class OpenAI(BaseEncoder):
    def __init__(self, model_name, device=None):
        self.model_name = model_name
        if (
            not processor_config
            or not processor_config.conversation
            or not processor_config.conversation.openai_api_key
        ):
            raise Exception(
                f"Set OpenAI API key under processor-config > conversation > openai-api-key in config file: {config_file}"
            )
        openai.api_key = processor_config.conversation.openai_api_key
        self.embedding_dimensions = None
@@ -32,7 +40,7 @@ class OpenAI(BaseEncoder):
        for index in trange(0, len(entries)):
            # OpenAI models create better embeddings for entries without newlines
            processed_entry = entries[index].replace("\n", " ")

            try:
                response = openai.Embedding.create(input=processed_entry, model=self.model_name)
@@ -41,10 +49,12 @@ class OpenAI(BaseEncoder):
                # Else default to embedding dimensions of the text-embedding-ada-002 model
                self.embedding_dimensions = len(response.data[0].embedding) if not self.embedding_dimensions else 1536
            except Exception as e:
                print(
                    f"Failed to encode entry {index} of length: {len(entries[index])}\n\n{entries[index][:1000]}...\n\n{e}"
                )
                # Use zero embedding vector for entries with failed embeddings
                # This ensures entry embeddings match the order of the source entries
                # And they have minimal similarity to other entries (as zero vectors are always orthogonal to other vector)
                embedding_tensors += [torch.zeros(self.embedding_dimensions, device=device)]

        return torch.stack(embedding_tensors)

View file

@@ -9,11 +9,13 @@ from pydantic import BaseModel, validator
# Internal Packages
from khoj.utils.helpers import to_snake_case_from_dash, is_none_or_empty


class ConfigBase(BaseModel):
    class Config:
        alias_generator = to_snake_case_from_dash
        allow_population_by_field_name = True


class TextContentConfig(ConfigBase):
    input_files: Optional[List[Path]]
    input_filter: Optional[List[str]]
@@ -21,12 +23,15 @@ class TextContentConfig(ConfigBase):
    embeddings_file: Path
    index_heading_entries: Optional[bool] = False

    @validator("input_filter")
    def input_filter_or_files_required(cls, input_filter, values, **kwargs):
        if is_none_or_empty(input_filter) and ("input_files" not in values or values["input_files"] is None):
            raise ValueError(
                "Either input_filter or input_files required in all content-type.<text_search> section of Khoj config file"
            )
        return input_filter


class ImageContentConfig(ConfigBase):
    input_directories: Optional[List[Path]]
    input_filter: Optional[List[str]]
@@ -34,12 +39,17 @@ class ImageContentConfig(ConfigBase):
    use_xmp_metadata: bool
    batch_size: int

    @validator("input_filter")
    def input_filter_or_directories_required(cls, input_filter, values, **kwargs):
        if is_none_or_empty(input_filter) and (
            "input_directories" not in values or values["input_directories"] is None
        ):
            raise ValueError(
                "Either input_filter or input_directories required in all content-type.image section of Khoj config file"
            )
        return input_filter


class ContentConfig(ConfigBase):
    org: Optional[TextContentConfig]
    ledger: Optional[TextContentConfig]
@@ -47,41 +57,49 @@ class ContentConfig(ConfigBase):
    music: Optional[TextContentConfig]
    markdown: Optional[TextContentConfig]


class TextSearchConfig(ConfigBase):
    encoder: str
    cross_encoder: str
    encoder_type: Optional[str]
    model_directory: Optional[Path]


class ImageSearchConfig(ConfigBase):
    encoder: str
    encoder_type: Optional[str]
    model_directory: Optional[Path]


class SearchConfig(ConfigBase):
    asymmetric: Optional[TextSearchConfig]
    symmetric: Optional[TextSearchConfig]
    image: Optional[ImageSearchConfig]


class ConversationProcessorConfig(ConfigBase):
    openai_api_key: str
    conversation_logfile: Path
    model: Optional[str] = "text-davinci-003"


class ProcessorConfig(ConfigBase):
    conversation: Optional[ConversationProcessorConfig]


class FullConfig(ConfigBase):
    content_type: Optional[ContentConfig]
    search_type: Optional[SearchConfig]
    processor: Optional[ProcessorConfig]


class SearchResponse(ConfigBase):
    entry: str
    score: str
    additional: Optional[dict]


class Entry:
    raw: str
    compiled: str
    file: Optional[str]
@@ -99,8 +117,4 @@ class Entry():
    @classmethod
    def from_dict(cls, dictionary: dict):
        return cls(raw=dictionary["raw"], compiled=dictionary["compiled"], file=dictionary.get("file", None))
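A small usage sketch for the Entry.from_dict classmethod reformatted above; the module path is taken from the test imports later in this diff, and the field values are made up:

# Illustrative only: build an Entry from a plain dict, as from_dict above allows.
from khoj.utils.rawconfig import Entry

entry = Entry.from_dict({"raw": "* Heading\nBody", "compiled": "Heading Body", "file": "notes.org"})
assert entry.raw.startswith("* Heading") and entry.file == "notes.org"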

View file

@@ -17,14 +17,14 @@ def save_config_to_file(yaml_config: dict, yaml_config_file: Path):
    # Create output directory, if it doesn't exist
    yaml_config_file.parent.mkdir(parents=True, exist_ok=True)

    with open(yaml_config_file, "w", encoding="utf-8") as config_file:
        yaml.safe_dump(yaml_config, config_file, allow_unicode=True)


def load_config_from_file(yaml_config_file: Path) -> dict:
    "Read config from YML file"
    config_from_file = None
    with open(yaml_config_file, "r", encoding="utf-8") as config_file:
        config_from_file = yaml.safe_load(config_file)
    return config_from_file

View file

@@ -6,59 +6,67 @@ import pytest
# Internal Packages
from khoj.search_type import image_search, text_search
from khoj.utils.helpers import resolve_absolute_path
from khoj.utils.rawconfig import (
    ContentConfig,
    TextContentConfig,
    ImageContentConfig,
    SearchConfig,
    TextSearchConfig,
    ImageSearchConfig,
)
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
from khoj.search_filter.date_filter import DateFilter
from khoj.search_filter.word_filter import WordFilter
from khoj.search_filter.file_filter import FileFilter


@pytest.fixture(scope="session")
def search_config() -> SearchConfig:
    model_dir = resolve_absolute_path("~/.khoj/search")
    model_dir.mkdir(parents=True, exist_ok=True)

    search_config = SearchConfig()

    search_config.symmetric = TextSearchConfig(
        encoder="sentence-transformers/all-MiniLM-L6-v2",
        cross_encoder="cross-encoder/ms-marco-MiniLM-L-6-v2",
        model_directory=model_dir / "symmetric/",
    )

    search_config.asymmetric = TextSearchConfig(
        encoder="sentence-transformers/multi-qa-MiniLM-L6-cos-v1",
        cross_encoder="cross-encoder/ms-marco-MiniLM-L-6-v2",
        model_directory=model_dir / "asymmetric/",
    )

    search_config.image = ImageSearchConfig(
        encoder="sentence-transformers/clip-ViT-B-32", model_directory=model_dir / "image/"
    )

    return search_config


@pytest.fixture(scope="session")
def content_config(tmp_path_factory, search_config: SearchConfig):
    content_dir = tmp_path_factory.mktemp("content")

    # Generate Image Embeddings from Test Images
    content_config = ContentConfig()
    content_config.image = ImageContentConfig(
        input_directories=["tests/data/images"],
        embeddings_file=content_dir.joinpath("image_embeddings.pt"),
        batch_size=1,
        use_xmp_metadata=False,
    )

    image_search.setup(content_config.image, search_config.image, regenerate=False)

    # Generate Notes Embeddings from Test Notes
    content_config.org = TextContentConfig(
        input_files=None,
        input_filter=["tests/data/org/*.org"],
        compressed_jsonl=content_dir.joinpath("notes.jsonl.gz"),
        embeddings_file=content_dir.joinpath("note_embeddings.pt"),
    )

    filters = [DateFilter(), WordFilter(), FileFilter()]
    text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters)
@@ -66,7 +74,7 @@ def content_config(tmp_path_factory, search_config: SearchConfig):
    return content_config


@pytest.fixture(scope="function")
def new_org_file(content_config: ContentConfig):
    # Setup
    new_org_file = Path(content_config.org.input_filter[0]).parent / "new_file.org"
@@ -79,9 +87,9 @@ def new_org_file(content_config: ContentConfig):
    new_org_file.unlink()


@pytest.fixture(scope="function")
def org_config_with_only_new_file(content_config: ContentConfig, new_org_file: Path):
    new_org_config = deepcopy(content_config.org)
    new_org_config.input_files = [f"{new_org_file}"]
    new_org_config.input_filter = None
    return new_org_config

View file

@@ -8,10 +8,10 @@ from khoj.processor.ledger.beancount_to_jsonl import BeancountToJsonl
def test_no_transactions_in_file(tmp_path):
    "Handle file with no transactions."
    # Arrange
    entry = f"""
- Bullet point 1
- Bullet point 2
"""
    beancount_file = create_file(tmp_path, entry)

    # Act
@@ -20,7 +20,8 @@ def test_no_transactions_in_file(tmp_path):
    # Process Each Entry from All Beancount Files
    jsonl_string = BeancountToJsonl.convert_transaction_maps_to_jsonl(
        BeancountToJsonl.convert_transactions_to_maps(entry_nodes, file_to_entries)
    )
    jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]

    # Assert
@@ -30,11 +31,11 @@ def test_no_transactions_in_file(tmp_path):
def test_single_beancount_transaction_to_jsonl(tmp_path):
    "Convert transaction from single file to jsonl."
    # Arrange
    entry = f"""
1984-04-01 * "Payee" "Narration"
Expenses:Test:Test 1.00 KES
Assets:Test:Test -1.00 KES
"""
    beancount_file = create_file(tmp_path, entry)

    # Act
@@ -43,7 +44,8 @@ Assets:Test:Test -1.00 KES
    # Process Each Entry from All Beancount Files
    jsonl_string = BeancountToJsonl.convert_transaction_maps_to_jsonl(
        BeancountToJsonl.convert_transactions_to_maps(entries, entry_to_file_map)
    )
    jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]

    # Assert
@@ -53,7 +55,7 @@ Assets:Test:Test -1.00 KES
def test_multiple_transactions_to_jsonl(tmp_path):
    "Convert multiple transactions from single file to jsonl."
    # Arrange
    entry = f"""
1984-04-01 * "Payee" "Narration"
Expenses:Test:Test 1.00 KES
Assets:Test:Test -1.00 KES
@@ -61,7 +63,7 @@ Assets:Test:Test -1.00 KES
1984-04-01 * "Payee" "Narration"
Expenses:Test:Test 1.00 KES
Assets:Test:Test -1.00 KES
"""
    beancount_file = create_file(tmp_path, entry)
@@ -71,7 +73,8 @@ Assets:Test:Test -1.00 KES
    # Process Each Entry from All Beancount Files
    jsonl_string = BeancountToJsonl.convert_transaction_maps_to_jsonl(
        BeancountToJsonl.convert_transactions_to_maps(entries, entry_to_file_map)
    )
    jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]

    # Assert
@@ -95,8 +98,8 @@ def test_get_beancount_files(tmp_path):
    expected_files = sorted(map(str, [group1_file1, group1_file2, group2_file1, group2_file2, file1]))

    # Setup input-files, input-filters
    input_files = [tmp_path / "ledger.bean"]
    input_filter = [tmp_path / "group1*.bean", tmp_path / "group2*.beancount"]

    # Act
    extracted_org_files = BeancountToJsonl.get_beancount_files(input_files, input_filter)

View file

@@ -6,7 +6,7 @@ from khoj.processor.conversation.gpt import converse, understand, message_to_pro
# Initialize variables for tests
model = "text-davinci-003"
api_key = None  # Input your OpenAI API key to run the tests below
@@ -14,19 +14,22 @@ api_key = None  # Input your OpenAI API key to run the tests below
# ----------------------------------------------------------------------------------------------------
def test_message_to_understand_prompt():
    # Arrange
understand_primer = "Extract information from each chat message\n\nremember(memory-type, data);\nmemory-type=[\"companion\", \"notes\", \"ledger\", \"image\", \"music\"]\nsearch(search-type, data);\nsearch-type=[\"google\", \"youtube\"]\ngenerate(activity);\nactivity=[\"paint\",\"write\", \"chat\"]\ntrigger-emotion(emotion);\nemotion=[\"happy\",\"confidence\",\"fear\",\"surprise\",\"sadness\",\"disgust\",\"anger\", \"curiosity\", \"calm\"]\n\nQ: How are you doing?\nA: activity(\"chat\"); trigger-emotion(\"surprise\")\nQ: Do you remember what I told you about my brother Antoine when we were at the beach?\nA: remember(\"notes\", \"Brother Antoine when we were at the beach\"); trigger-emotion(\"curiosity\");\nQ: what did we talk about last time?\nA: remember(\"notes\", \"talk last time\"); trigger-emotion(\"curiosity\");\nQ: Let's make some drawings!\nA: generate(\"paint\"); trigger-emotion(\"happy\");\nQ: Do you know anything about Lebanon?\nA: search(\"google\", \"lebanon\"); trigger-emotion(\"confidence\");\nQ: Find a video about a panda rolling in the grass\nA: search(\"youtube\",\"panda rolling in the grass\"); trigger-emotion(\"happy\"); \nQ: Tell me a scary story\nA: generate(\"write\" \"A story about some adventure\"); trigger-emotion(\"fear\");\nQ: What fiction book was I reading last week about AI starship?\nA: remember(\"notes\", \"read fiction book about AI starship last week\"); trigger-emotion(\"curiosity\");\nQ: How much did I spend at Subway for dinner last time?\nA: remember(\"ledger\", \"last Subway dinner\"); trigger-emotion(\"curiosity\");\nQ: I'm feeling sleepy\nA: activity(\"chat\"); trigger-emotion(\"calm\")\nQ: What was that popular Sri lankan song that Alex showed me recently?\nA: remember(\"music\", \"popular Sri lankan song that Alex showed recently\"); trigger-emotion(\"curiosity\"); \nQ: You're pretty funny!\nA: activity(\"chat\"); trigger-emotion(\"pride\")" understand_primer = 'Extract information from each chat message\n\nremember(memory-type, data);\nmemory-type=["companion", "notes", "ledger", "image", "music"]\nsearch(search-type, data);\nsearch-type=["google", "youtube"]\ngenerate(activity);\nactivity=["paint","write", "chat"]\ntrigger-emotion(emotion);\nemotion=["happy","confidence","fear","surprise","sadness","disgust","anger", "curiosity", "calm"]\n\nQ: How are you doing?\nA: activity("chat"); trigger-emotion("surprise")\nQ: Do you remember what I told you about my brother Antoine when we were at the beach?\nA: remember("notes", "Brother Antoine when we were at the beach"); trigger-emotion("curiosity");\nQ: what did we talk about last time?\nA: remember("notes", "talk last time"); trigger-emotion("curiosity");\nQ: Let\'s make some drawings!\nA: generate("paint"); trigger-emotion("happy");\nQ: Do you know anything about Lebanon?\nA: search("google", "lebanon"); trigger-emotion("confidence");\nQ: Find a video about a panda rolling in the grass\nA: search("youtube","panda rolling in the grass"); trigger-emotion("happy"); \nQ: Tell me a scary story\nA: generate("write" "A story about some adventure"); trigger-emotion("fear");\nQ: What fiction book was I reading last week about AI starship?\nA: remember("notes", "read fiction book about AI starship last week"); trigger-emotion("curiosity");\nQ: How much did I spend at Subway for dinner last time?\nA: remember("ledger", "last Subway dinner"); trigger-emotion("curiosity");\nQ: I\'m feeling sleepy\nA: activity("chat"); trigger-emotion("calm")\nQ: What was that popular Sri lankan song that Alex showed me 
recently?\nA: remember("music", "popular Sri lankan song that Alex showed recently"); trigger-emotion("curiosity"); \nQ: You\'re pretty funny!\nA: activity("chat"); trigger-emotion("pride")'
expected_response = "Extract information from each chat message\n\nremember(memory-type, data);\nmemory-type=[\"companion\", \"notes\", \"ledger\", \"image\", \"music\"]\nsearch(search-type, data);\nsearch-type=[\"google\", \"youtube\"]\ngenerate(activity);\nactivity=[\"paint\",\"write\", \"chat\"]\ntrigger-emotion(emotion);\nemotion=[\"happy\",\"confidence\",\"fear\",\"surprise\",\"sadness\",\"disgust\",\"anger\", \"curiosity\", \"calm\"]\n\nQ: How are you doing?\nA: activity(\"chat\"); trigger-emotion(\"surprise\")\nQ: Do you remember what I told you about my brother Antoine when we were at the beach?\nA: remember(\"notes\", \"Brother Antoine when we were at the beach\"); trigger-emotion(\"curiosity\");\nQ: what did we talk about last time?\nA: remember(\"notes\", \"talk last time\"); trigger-emotion(\"curiosity\");\nQ: Let's make some drawings!\nA: generate(\"paint\"); trigger-emotion(\"happy\");\nQ: Do you know anything about Lebanon?\nA: search(\"google\", \"lebanon\"); trigger-emotion(\"confidence\");\nQ: Find a video about a panda rolling in the grass\nA: search(\"youtube\",\"panda rolling in the grass\"); trigger-emotion(\"happy\"); \nQ: Tell me a scary story\nA: generate(\"write\" \"A story about some adventure\"); trigger-emotion(\"fear\");\nQ: What fiction book was I reading last week about AI starship?\nA: remember(\"notes\", \"read fiction book about AI starship last week\"); trigger-emotion(\"curiosity\");\nQ: How much did I spend at Subway for dinner last time?\nA: remember(\"ledger\", \"last Subway dinner\"); trigger-emotion(\"curiosity\");\nQ: I'm feeling sleepy\nA: activity(\"chat\"); trigger-emotion(\"calm\")\nQ: What was that popular Sri lankan song that Alex showed me recently?\nA: remember(\"music\", \"popular Sri lankan song that Alex showed recently\"); trigger-emotion(\"curiosity\"); \nQ: You're pretty funny!\nA: activity(\"chat\"); trigger-emotion(\"pride\")\nQ: When did I last dine at Burger King?\nA:" expected_response = 'Extract information from each chat message\n\nremember(memory-type, data);\nmemory-type=["companion", "notes", "ledger", "image", "music"]\nsearch(search-type, data);\nsearch-type=["google", "youtube"]\ngenerate(activity);\nactivity=["paint","write", "chat"]\ntrigger-emotion(emotion);\nemotion=["happy","confidence","fear","surprise","sadness","disgust","anger", "curiosity", "calm"]\n\nQ: How are you doing?\nA: activity("chat"); trigger-emotion("surprise")\nQ: Do you remember what I told you about my brother Antoine when we were at the beach?\nA: remember("notes", "Brother Antoine when we were at the beach"); trigger-emotion("curiosity");\nQ: what did we talk about last time?\nA: remember("notes", "talk last time"); trigger-emotion("curiosity");\nQ: Let\'s make some drawings!\nA: generate("paint"); trigger-emotion("happy");\nQ: Do you know anything about Lebanon?\nA: search("google", "lebanon"); trigger-emotion("confidence");\nQ: Find a video about a panda rolling in the grass\nA: search("youtube","panda rolling in the grass"); trigger-emotion("happy"); \nQ: Tell me a scary story\nA: generate("write" "A story about some adventure"); trigger-emotion("fear");\nQ: What fiction book was I reading last week about AI starship?\nA: remember("notes", "read fiction book about AI starship last week"); trigger-emotion("curiosity");\nQ: How much did I spend at Subway for dinner last time?\nA: remember("ledger", "last Subway dinner"); trigger-emotion("curiosity");\nQ: I\'m feeling sleepy\nA: activity("chat"); trigger-emotion("calm")\nQ: What was that popular 
Sri lankan song that Alex showed me recently?\nA: remember("music", "popular Sri lankan song that Alex showed recently"); trigger-emotion("curiosity"); \nQ: You\'re pretty funny!\nA: activity("chat"); trigger-emotion("pride")\nQ: When did I last dine at Burger King?\nA:'
    # Act
    actual_response = message_to_prompt(
        "When did I last dine at Burger King?", understand_primer, start_sequence="\nA:", restart_sequence="\nQ:"
    )

    # Assert
    assert actual_response == expected_response


# ----------------------------------------------------------------------------------------------------
@pytest.mark.skipif(
    api_key is None, reason="Set api_key variable to your OpenAI API key from https://beta.openai.com/account/api-keys"
)
def test_minimal_chat_with_gpt():
    # Act
    response = converse("What will happen when the stars go out?", model=model, api_key=api_key)
@@ -36,21 +39,29 @@ def test_minimal_chat_with_gpt():
# ----------------------------------------------------------------------------------------------------
@pytest.mark.skipif(
    api_key is None, reason="Set api_key variable to your OpenAI API key from https://beta.openai.com/account/api-keys"
)
def test_chat_with_history():
    # Arrange
    ai_prompt = "AI:"
    human_prompt = "Human:"
    conversation_primer = f"""
The following is a conversation with an AI assistant. The assistant is helpful, creative, clever, and very friendly companion.
{human_prompt} Hello, I am Testatron. Who are you?
{ai_prompt} Hi, I am Khoj, an AI conversational companion created by OpenAI. How can I help you today?"""

    # Act
    response = converse(
        "Hi Khoj, What is my name?",
        model=model,
        conversation_history=conversation_primer,
        api_key=api_key,
        temperature=0,
        max_tokens=50,
    )

    # Assert
    assert len(response) > 0
@@ -58,12 +69,13 @@ The following is a conversation with an AI assistant. The assistant is helpful,
# ----------------------------------------------------------------------------------------------------
@pytest.mark.skipif(
    api_key is None, reason="Set api_key variable to your OpenAI API key from https://beta.openai.com/account/api-keys"
)
def test_understand_message_using_gpt():
    # Act
    response = understand("When did I last dine at Subway?", model=model, api_key=api_key)

    # Assert
    assert len(response) > 0
    assert response["intent"]["memory-type"] == "ledger"
@ -14,35 +14,37 @@ def test_cli_minimal_default():
actual_args = cli([]) actual_args = cli([])
# Assert # Assert
assert actual_args.config_file == resolve_absolute_path(Path('~/.khoj/khoj.yml')) assert actual_args.config_file == resolve_absolute_path(Path("~/.khoj/khoj.yml"))
assert actual_args.regenerate == False assert actual_args.regenerate == False
assert actual_args.no_gui == False assert actual_args.no_gui == False
assert actual_args.verbose == 0 assert actual_args.verbose == 0
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
def test_cli_invalid_config_file_path(): def test_cli_invalid_config_file_path():
# Arrange # Arrange
non_existent_config_file = f"non-existent-khoj-{random()}.yml" non_existent_config_file = f"non-existent-khoj-{random()}.yml"
# Act # Act
actual_args = cli([f'-c={non_existent_config_file}']) actual_args = cli([f"-c={non_existent_config_file}"])
# Assert # Assert
assert actual_args.config_file == resolve_absolute_path(non_existent_config_file) assert actual_args.config_file == resolve_absolute_path(non_existent_config_file)
assert actual_args.config == None assert actual_args.config == None
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
def test_cli_config_from_file(): def test_cli_config_from_file():
# Act # Act
actual_args = cli(['-c=tests/data/config.yml', actual_args = cli(["-c=tests/data/config.yml", "--regenerate", "--no-gui", "-vvv"])
'--regenerate',
'--no-gui',
'-vvv'])
# Assert # Assert
assert actual_args.config_file == resolve_absolute_path(Path('tests/data/config.yml')) assert actual_args.config_file == resolve_absolute_path(Path("tests/data/config.yml"))
assert actual_args.no_gui == True assert actual_args.no_gui == True
assert actual_args.regenerate == True assert actual_args.regenerate == True
assert actual_args.config is not None assert actual_args.config is not None
assert actual_args.config.content_type.org.input_files == [Path('~/first_from_config.org'), Path('~/second_from_config.org')] assert actual_args.config.content_type.org.input_files == [
Path("~/first_from_config.org"),
Path("~/second_from_config.org"),
]
assert actual_args.verbose == 3 assert actual_args.verbose == 3
@ -21,6 +21,7 @@ from khoj.search_filter.file_filter import FileFilter
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
client = TestClient(app) client = TestClient(app)
# Test # Test
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
def test_search_with_invalid_content_type(): def test_search_with_invalid_content_type():
@ -98,9 +99,11 @@ def test_image_search(content_config: ContentConfig, search_config: SearchConfig
config.content_type = content_config config.content_type = content_config
config.search_type = search_config config.search_type = search_config
model.image_search = image_search.setup(content_config.image, search_config.image, regenerate=False) model.image_search = image_search.setup(content_config.image, search_config.image, regenerate=False)
query_expected_image_pairs = [("kitten", "kitten_park.jpg"), query_expected_image_pairs = [
("a horse and dog on a leash", "horse_dog.jpg"), ("kitten", "kitten_park.jpg"),
("A guinea pig eating grass", "guineapig_grass.jpg")] ("a horse and dog on a leash", "horse_dog.jpg"),
("A guinea pig eating grass", "guineapig_grass.jpg"),
]
for query, expected_image_name in query_expected_image_pairs: for query, expected_image_name in query_expected_image_pairs:
# Act # Act
@ -135,7 +138,9 @@ def test_notes_search(content_config: ContentConfig, search_config: SearchConfig
def test_notes_search_with_only_filters(content_config: ContentConfig, search_config: SearchConfig): def test_notes_search_with_only_filters(content_config: ContentConfig, search_config: SearchConfig):
# Arrange # Arrange
filters = [WordFilter(), FileFilter()] filters = [WordFilter(), FileFilter()]
model.orgmode_search = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters) model.orgmode_search = text_search.setup(
OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters
)
user_query = quote('+"Emacs" file:"*.org"') user_query = quote('+"Emacs" file:"*.org"')
# Act # Act
@ -152,7 +157,9 @@ def test_notes_search_with_only_filters(content_config: ContentConfig, search_co
def test_notes_search_with_include_filter(content_config: ContentConfig, search_config: SearchConfig): def test_notes_search_with_include_filter(content_config: ContentConfig, search_config: SearchConfig):
# Arrange # Arrange
filters = [WordFilter()] filters = [WordFilter()]
model.orgmode_search = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters) model.orgmode_search = text_search.setup(
OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters
)
user_query = quote('How to git install application? +"Emacs"') user_query = quote('How to git install application? +"Emacs"')
# Act # Act
@ -169,7 +176,9 @@ def test_notes_search_with_include_filter(content_config: ContentConfig, search_
def test_notes_search_with_exclude_filter(content_config: ContentConfig, search_config: SearchConfig): def test_notes_search_with_exclude_filter(content_config: ContentConfig, search_config: SearchConfig):
# Arrange # Arrange
filters = [WordFilter()] filters = [WordFilter()]
model.orgmode_search = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters) model.orgmode_search = text_search.setup(
OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters
)
user_query = quote('How to git install application? -"clone"') user_query = quote('How to git install application? -"clone"')
# Act # Act
@ -10,53 +10,59 @@ from khoj.utils.rawconfig import Entry
def test_date_filter(): def test_date_filter():
entries = [ entries = [
Entry(compiled='', raw='Entry with no date'), Entry(compiled="", raw="Entry with no date"),
Entry(compiled='', raw='April Fools entry: 1984-04-01'), Entry(compiled="", raw="April Fools entry: 1984-04-01"),
Entry(compiled='', raw='Entry with date:1984-04-02') Entry(compiled="", raw="Entry with date:1984-04-02"),
] ]
q_with_no_date_filter = 'head tail' q_with_no_date_filter = "head tail"
ret_query, entry_indices = DateFilter().apply(q_with_no_date_filter, entries) ret_query, entry_indices = DateFilter().apply(q_with_no_date_filter, entries)
assert ret_query == 'head tail' assert ret_query == "head tail"
assert entry_indices == {0, 1, 2} assert entry_indices == {0, 1, 2}
q_with_dtrange_non_overlapping_at_boundary = 'head dt>"1984-04-01" dt<"1984-04-02" tail' q_with_dtrange_non_overlapping_at_boundary = 'head dt>"1984-04-01" dt<"1984-04-02" tail'
ret_query, entry_indices = DateFilter().apply(q_with_dtrange_non_overlapping_at_boundary, entries) ret_query, entry_indices = DateFilter().apply(q_with_dtrange_non_overlapping_at_boundary, entries)
assert ret_query == 'head tail' assert ret_query == "head tail"
assert entry_indices == set() assert entry_indices == set()
query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<"1984-04-03" tail' query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<"1984-04-03" tail'
ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries) ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries)
assert ret_query == 'head tail' assert ret_query == "head tail"
assert entry_indices == {2} assert entry_indices == {2}
query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<"1984-04-02" tail' query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<"1984-04-02" tail'
ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries) ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries)
assert ret_query == 'head tail' assert ret_query == "head tail"
assert entry_indices == {1} assert entry_indices == {1}
query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<="1984-04-02" tail' query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<="1984-04-02" tail'
ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries) ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries)
assert ret_query == 'head tail' assert ret_query == "head tail"
assert entry_indices == {2} assert entry_indices == {2}
query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<="1984-04-02" tail' query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<="1984-04-02" tail'
ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries) ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries)
assert ret_query == 'head tail' assert ret_query == "head tail"
assert entry_indices == {1, 2} assert entry_indices == {1, 2}
def test_extract_date_range(): def test_extract_date_range():
assert DateFilter().extract_date_range('head dt>"1984-01-04" dt<"1984-01-07" tail') == [datetime(1984, 1, 5, 0, 0, 0).timestamp(), datetime(1984, 1, 7, 0, 0, 0).timestamp()] assert DateFilter().extract_date_range('head dt>"1984-01-04" dt<"1984-01-07" tail') == [
datetime(1984, 1, 5, 0, 0, 0).timestamp(),
datetime(1984, 1, 7, 0, 0, 0).timestamp(),
]
assert DateFilter().extract_date_range('head dt<="1984-01-01"') == [0, datetime(1984, 1, 2, 0, 0, 0).timestamp()] assert DateFilter().extract_date_range('head dt<="1984-01-01"') == [0, datetime(1984, 1, 2, 0, 0, 0).timestamp()]
assert DateFilter().extract_date_range('head dt>="1984-01-01"') == [datetime(1984, 1, 1, 0, 0, 0).timestamp(), inf] assert DateFilter().extract_date_range('head dt>="1984-01-01"') == [datetime(1984, 1, 1, 0, 0, 0).timestamp(), inf]
assert DateFilter().extract_date_range('head dt:"1984-01-01"') == [datetime(1984, 1, 1, 0, 0, 0).timestamp(), datetime(1984, 1, 2, 0, 0, 0).timestamp()] assert DateFilter().extract_date_range('head dt:"1984-01-01"') == [
datetime(1984, 1, 1, 0, 0, 0).timestamp(),
datetime(1984, 1, 2, 0, 0, 0).timestamp(),
]
# Unparseable date filter specified in query # Unparseable date filter specified in query
assert DateFilter().extract_date_range('head dt:"Summer of 69" tail') == None assert DateFilter().extract_date_range('head dt:"Summer of 69" tail') == None
# No date filter specified in query # No date filter specified in query
assert DateFilter().extract_date_range('head tail') == None assert DateFilter().extract_date_range("head tail") == None
# Non intersecting date ranges # Non intersecting date ranges
assert DateFilter().extract_date_range('head dt>"1984-01-01" dt<"1984-01-01" tail') == None assert DateFilter().extract_date_range('head dt>"1984-01-01" dt<"1984-01-01" tail') == None
@ -66,43 +72,79 @@ def test_parse():
test_now = datetime(1984, 4, 1, 21, 21, 21) test_now = datetime(1984, 4, 1, 21, 21, 21)
# day variations # day variations
assert DateFilter().parse('today', relative_base=test_now) == (datetime(1984, 4, 1, 0, 0, 0), datetime(1984, 4, 2, 0, 0, 0)) assert DateFilter().parse("today", relative_base=test_now) == (
assert DateFilter().parse('tomorrow', relative_base=test_now) == (datetime(1984, 4, 2, 0, 0, 0), datetime(1984, 4, 3, 0, 0, 0)) datetime(1984, 4, 1, 0, 0, 0),
assert DateFilter().parse('yesterday', relative_base=test_now) == (datetime(1984, 3, 31, 0, 0, 0), datetime(1984, 4, 1, 0, 0, 0)) datetime(1984, 4, 2, 0, 0, 0),
assert DateFilter().parse('5 days ago', relative_base=test_now) == (datetime(1984, 3, 27, 0, 0, 0), datetime(1984, 3, 28, 0, 0, 0)) )
assert DateFilter().parse("tomorrow", relative_base=test_now) == (
datetime(1984, 4, 2, 0, 0, 0),
datetime(1984, 4, 3, 0, 0, 0),
)
assert DateFilter().parse("yesterday", relative_base=test_now) == (
datetime(1984, 3, 31, 0, 0, 0),
datetime(1984, 4, 1, 0, 0, 0),
)
assert DateFilter().parse("5 days ago", relative_base=test_now) == (
datetime(1984, 3, 27, 0, 0, 0),
datetime(1984, 3, 28, 0, 0, 0),
)
# week variations # week variations
assert DateFilter().parse('last week', relative_base=test_now) == (datetime(1984, 3, 18, 0, 0, 0), datetime(1984, 3, 25, 0, 0, 0)) assert DateFilter().parse("last week", relative_base=test_now) == (
assert DateFilter().parse('2 weeks ago', relative_base=test_now) == (datetime(1984, 3, 11, 0, 0, 0), datetime(1984, 3, 18, 0, 0, 0)) datetime(1984, 3, 18, 0, 0, 0),
datetime(1984, 3, 25, 0, 0, 0),
)
assert DateFilter().parse("2 weeks ago", relative_base=test_now) == (
datetime(1984, 3, 11, 0, 0, 0),
datetime(1984, 3, 18, 0, 0, 0),
)
# month variations # month variations
assert DateFilter().parse('next month', relative_base=test_now) == (datetime(1984, 5, 1, 0, 0, 0), datetime(1984, 6, 1, 0, 0, 0)) assert DateFilter().parse("next month", relative_base=test_now) == (
assert DateFilter().parse('2 months ago', relative_base=test_now) == (datetime(1984, 2, 1, 0, 0, 0), datetime(1984, 3, 1, 0, 0, 0)) datetime(1984, 5, 1, 0, 0, 0),
datetime(1984, 6, 1, 0, 0, 0),
)
assert DateFilter().parse("2 months ago", relative_base=test_now) == (
datetime(1984, 2, 1, 0, 0, 0),
datetime(1984, 3, 1, 0, 0, 0),
)
# year variations # year variations
assert DateFilter().parse('this year', relative_base=test_now) == (datetime(1984, 1, 1, 0, 0, 0), datetime(1985, 1, 1, 0, 0, 0)) assert DateFilter().parse("this year", relative_base=test_now) == (
assert DateFilter().parse('20 years later', relative_base=test_now) == (datetime(2004, 1, 1, 0, 0, 0), datetime(2005, 1, 1, 0, 0, 0)) datetime(1984, 1, 1, 0, 0, 0),
datetime(1985, 1, 1, 0, 0, 0),
)
assert DateFilter().parse("20 years later", relative_base=test_now) == (
datetime(2004, 1, 1, 0, 0, 0),
datetime(2005, 1, 1, 0, 0, 0),
)
# specific month/date variation # specific month/date variation
assert DateFilter().parse('in august', relative_base=test_now) == (datetime(1983, 8, 1, 0, 0, 0), datetime(1983, 8, 2, 0, 0, 0)) assert DateFilter().parse("in august", relative_base=test_now) == (
assert DateFilter().parse('on 1983-08-01', relative_base=test_now) == (datetime(1983, 8, 1, 0, 0, 0), datetime(1983, 8, 2, 0, 0, 0)) datetime(1983, 8, 1, 0, 0, 0),
datetime(1983, 8, 2, 0, 0, 0),
)
assert DateFilter().parse("on 1983-08-01", relative_base=test_now) == (
datetime(1983, 8, 1, 0, 0, 0),
datetime(1983, 8, 2, 0, 0, 0),
)
def test_date_filter_regex(): def test_date_filter_regex():
dtrange_match = re.findall(DateFilter().date_regex, 'multi word head dt>"today" dt:"1984-01-01"') dtrange_match = re.findall(DateFilter().date_regex, 'multi word head dt>"today" dt:"1984-01-01"')
assert dtrange_match == [('>', 'today'), (':', '1984-01-01')] assert dtrange_match == [(">", "today"), (":", "1984-01-01")]
dtrange_match = re.findall(DateFilter().date_regex, 'head dt>"today" dt:"1984-01-01" multi word tail') dtrange_match = re.findall(DateFilter().date_regex, 'head dt>"today" dt:"1984-01-01" multi word tail')
assert dtrange_match == [('>', 'today'), (':', '1984-01-01')] assert dtrange_match == [(">", "today"), (":", "1984-01-01")]
dtrange_match = re.findall(DateFilter().date_regex, 'multi word head dt>="today" dt="1984-01-01"') dtrange_match = re.findall(DateFilter().date_regex, 'multi word head dt>="today" dt="1984-01-01"')
assert dtrange_match == [('>=', 'today'), ('=', '1984-01-01')] assert dtrange_match == [(">=", "today"), ("=", "1984-01-01")]
dtrange_match = re.findall(DateFilter().date_regex, 'dt<"multi word date" multi word tail') dtrange_match = re.findall(DateFilter().date_regex, 'dt<"multi word date" multi word tail')
assert dtrange_match == [('<', 'multi word date')] assert dtrange_match == [("<", "multi word date")]
dtrange_match = re.findall(DateFilter().date_regex, 'head dt<="multi word date"') dtrange_match = re.findall(DateFilter().date_regex, 'head dt<="multi word date"')
assert dtrange_match == [('<=', 'multi word date')] assert dtrange_match == [("<=", "multi word date")]
dtrange_match = re.findall(DateFilter().date_regex, 'head tail') dtrange_match = re.findall(DateFilter().date_regex, "head tail")
assert dtrange_match == [] assert dtrange_match == []
@ -7,7 +7,7 @@ def test_no_file_filter():
# Arrange # Arrange
file_filter = FileFilter() file_filter = FileFilter()
entries = arrange_content() entries = arrange_content()
q_with_no_filter = 'head tail' q_with_no_filter = "head tail"
# Act # Act
can_filter = file_filter.can_filter(q_with_no_filter) can_filter = file_filter.can_filter(q_with_no_filter)
@ -15,7 +15,7 @@ def test_no_file_filter():
# Assert # Assert
assert can_filter == False assert can_filter == False
assert ret_query == 'head tail' assert ret_query == "head tail"
assert entry_indices == {0, 1, 2, 3} assert entry_indices == {0, 1, 2, 3}
@ -31,7 +31,7 @@ def test_file_filter_with_non_existent_file():
# Assert # Assert
assert can_filter == True assert can_filter == True
assert ret_query == 'head tail' assert ret_query == "head tail"
assert entry_indices == {} assert entry_indices == {}
@ -47,7 +47,7 @@ def test_single_file_filter():
# Assert # Assert
assert can_filter == True assert can_filter == True
assert ret_query == 'head tail' assert ret_query == "head tail"
assert entry_indices == {0, 2} assert entry_indices == {0, 2}
@ -63,7 +63,7 @@ def test_file_filter_with_partial_match():
# Assert # Assert
assert can_filter == True assert can_filter == True
assert ret_query == 'head tail' assert ret_query == "head tail"
assert entry_indices == {0, 2} assert entry_indices == {0, 2}
@ -79,7 +79,7 @@ def test_file_filter_with_regex_match():
# Assert # Assert
assert can_filter == True assert can_filter == True
assert ret_query == 'head tail' assert ret_query == "head tail"
assert entry_indices == {0, 1, 2, 3} assert entry_indices == {0, 1, 2, 3}
@ -95,16 +95,16 @@ def test_multiple_file_filter():
# Assert # Assert
assert can_filter == True assert can_filter == True
assert ret_query == 'head tail' assert ret_query == "head tail"
assert entry_indices == {0, 1, 2, 3} assert entry_indices == {0, 1, 2, 3}
def arrange_content(): def arrange_content():
entries = [ entries = [
Entry(compiled='', raw='First Entry', file= 'file 1.org'), Entry(compiled="", raw="First Entry", file="file 1.org"),
Entry(compiled='', raw='Second Entry', file= 'file2.org'), Entry(compiled="", raw="Second Entry", file="file2.org"),
Entry(compiled='', raw='Third Entry', file= 'file 1.org'), Entry(compiled="", raw="Third Entry", file="file 1.org"),
Entry(compiled='', raw='Fourth Entry', file= 'file2.org') Entry(compiled="", raw="Fourth Entry", file="file2.org"),
] ]
return entries return entries
@ -1,5 +1,6 @@
from khoj.utils import helpers from khoj.utils import helpers
def test_get_from_null_dict(): def test_get_from_null_dict():
# null handling # null handling
assert helpers.get_from_dict(dict()) == dict() assert helpers.get_from_dict(dict()) == dict()
@ -7,39 +8,39 @@ def test_get_from_null_dict():
# key present in nested dictionary # key present in nested dictionary
# 1-level dictionary # 1-level dictionary
assert helpers.get_from_dict({'a': 1, 'b': 2}, 'a') == 1 assert helpers.get_from_dict({"a": 1, "b": 2}, "a") == 1
assert helpers.get_from_dict({'a': 1, 'b': 2}, 'c') == None assert helpers.get_from_dict({"a": 1, "b": 2}, "c") == None
# 2-level dictionary # 2-level dictionary
assert helpers.get_from_dict({'a': {'a_a': 1}, 'b': 2}, 'a') == {'a_a': 1} assert helpers.get_from_dict({"a": {"a_a": 1}, "b": 2}, "a") == {"a_a": 1}
assert helpers.get_from_dict({'a': {'a_a': 1}, 'b': 2}, 'a', 'a_a') == 1 assert helpers.get_from_dict({"a": {"a_a": 1}, "b": 2}, "a", "a_a") == 1
# key not present in nested dictionary # key not present in nested dictionary
# 2-level_dictionary # 2-level_dictionary
assert helpers.get_from_dict({'a': {'a_a': 1}, 'b': 2}, 'b', 'b_a') == None assert helpers.get_from_dict({"a": {"a_a": 1}, "b": 2}, "b", "b_a") == None
def test_merge_dicts(): def test_merge_dicts():
# basic merge of dicts with non-overlapping keys # basic merge of dicts with non-overlapping keys
assert helpers.merge_dicts(priority_dict={'a': 1}, default_dict={'b': 2}) == {'a': 1, 'b': 2} assert helpers.merge_dicts(priority_dict={"a": 1}, default_dict={"b": 2}) == {"a": 1, "b": 2}
# use default dict items when not present in priority dict # use default dict items when not present in priority dict
assert helpers.merge_dicts(priority_dict={}, default_dict={'b': 2}) == {'b': 2} assert helpers.merge_dicts(priority_dict={}, default_dict={"b": 2}) == {"b": 2}
# do not override existing key in priority_dict with default dict # do not override existing key in priority_dict with default dict
assert helpers.merge_dicts(priority_dict={'a': 1}, default_dict={'a': 2}) == {'a': 1} assert helpers.merge_dicts(priority_dict={"a": 1}, default_dict={"a": 2}) == {"a": 1}
def test_lru_cache(): def test_lru_cache():
# Test initializing cache # Test initializing cache
cache = helpers.LRU({'a': 1, 'b': 2}, capacity=2) cache = helpers.LRU({"a": 1, "b": 2}, capacity=2)
assert cache == {'a': 1, 'b': 2} assert cache == {"a": 1, "b": 2}
# Test capacity overflow # Test capacity overflow
cache['c'] = 3 cache["c"] = 3
assert cache == {'b': 2, 'c': 3} assert cache == {"b": 2, "c": 3}
# Test delete least recently used item from LRU cache on capacity overflow # Test delete least recently used item from LRU cache on capacity overflow
cache['b'] # accessing 'b' makes it the most recently used item cache["b"] # accessing 'b' makes it the most recently used item
cache['d'] = 4 # so 'c' is deleted from the cache instead of 'b' cache["d"] = 4 # so 'c' is deleted from the cache instead of 'b'
assert cache == {'b': 2, 'd': 4} assert cache == {"b": 2, "d": 4}
@ -30,7 +30,8 @@ def test_image_metadata(content_config: ContentConfig):
expected_metadata_image_name_pairs = [ expected_metadata_image_name_pairs = [
(["Billi Ka Bacha.", "Cat", "Grass"], "kitten_park.jpg"), (["Billi Ka Bacha.", "Cat", "Grass"], "kitten_park.jpg"),
(["Pasture.", "Horse", "Dog"], "horse_dog.jpg"), (["Pasture.", "Horse", "Dog"], "horse_dog.jpg"),
(["Guinea Pig Eating Celery.", "Rodent", "Whiskers"], "guineapig_grass.jpg")] (["Guinea Pig Eating Celery.", "Rodent", "Whiskers"], "guineapig_grass.jpg"),
]
test_image_paths = [ test_image_paths = [
Path(content_config.image.input_directories[0] / image_name[1]) Path(content_config.image.input_directories[0] / image_name[1])
@ -51,23 +52,23 @@ def test_image_search(content_config: ContentConfig, search_config: SearchConfig
# Arrange # Arrange
output_directory = resolve_absolute_path(web_directory) output_directory = resolve_absolute_path(web_directory)
model.image_search = image_search.setup(content_config.image, search_config.image, regenerate=False) model.image_search = image_search.setup(content_config.image, search_config.image, regenerate=False)
query_expected_image_pairs = [("kitten", "kitten_park.jpg"), query_expected_image_pairs = [
("horse and dog in a farm", "horse_dog.jpg"), ("kitten", "kitten_park.jpg"),
("A guinea pig eating grass", "guineapig_grass.jpg")] ("horse and dog in a farm", "horse_dog.jpg"),
("A guinea pig eating grass", "guineapig_grass.jpg"),
]
# Act # Act
for query, expected_image_name in query_expected_image_pairs: for query, expected_image_name in query_expected_image_pairs:
hits = image_search.query( hits = image_search.query(query, count=1, model=model.image_search)
query,
count = 1,
model = model.image_search)
results = image_search.collate_results( results = image_search.collate_results(
hits, hits,
model.image_search.image_names, model.image_search.image_names,
output_directory=output_directory, output_directory=output_directory,
image_files_url='/static/images', image_files_url="/static/images",
count=1) count=1,
)
actual_image_path = output_directory.joinpath(Path(results[0].entry).name) actual_image_path = output_directory.joinpath(Path(results[0].entry).name)
actual_image = Image.open(actual_image_path) actual_image = Image.open(actual_image_path)
@ -86,16 +87,13 @@ def test_image_search_query_truncated(content_config: ContentConfig, search_conf
# Arrange # Arrange
model.image_search = image_search.setup(content_config.image, search_config.image, regenerate=False) model.image_search = image_search.setup(content_config.image, search_config.image, regenerate=False)
max_words_supported = 10 max_words_supported = 10
query = " ".join(["hello"]*100) query = " ".join(["hello"] * 100)
truncated_query = " ".join(["hello"]*max_words_supported) truncated_query = " ".join(["hello"] * max_words_supported)
# Act # Act
try: try:
with caplog.at_level(logging.INFO, logger="khoj.search_type.image_search"): with caplog.at_level(logging.INFO, logger="khoj.search_type.image_search"):
image_search.query( image_search.query(query, count=1, model=model.image_search)
query,
count = 1,
model = model.image_search)
# Assert # Assert
except RuntimeError as e: except RuntimeError as e:
if "The size of tensor a (102) must match the size of tensor b (77)" in str(e): if "The size of tensor a (102) must match the size of tensor b (77)" in str(e):
@ -115,17 +113,15 @@ def test_image_search_by_filepath(content_config: ContentConfig, search_config:
# Act # Act
with caplog.at_level(logging.INFO, logger="khoj.search_type.image_search"): with caplog.at_level(logging.INFO, logger="khoj.search_type.image_search"):
hits = image_search.query( hits = image_search.query(query, count=1, model=model.image_search)
query,
count = 1,
model = model.image_search)
results = image_search.collate_results( results = image_search.collate_results(
hits, hits,
model.image_search.image_names, model.image_search.image_names,
output_directory=output_directory, output_directory=output_directory,
image_files_url='/static/images', image_files_url="/static/images",
count=1) count=1,
)
actual_image_path = output_directory.joinpath(Path(results[0].entry).name) actual_image_path = output_directory.joinpath(Path(results[0].entry).name)
actual_image = Image.open(actual_image_path) actual_image = Image.open(actual_image_path)
@ -133,7 +129,9 @@ def test_image_search_by_filepath(content_config: ContentConfig, search_config:
# Assert # Assert
# Ensure file search triggered instead of query with file path as string # Ensure file search triggered instead of query with file path as string
assert f"Find Images by Image: {resolve_absolute_path(expected_image_path)}" in caplog.text, "File search not triggered" assert (
f"Find Images by Image: {resolve_absolute_path(expected_image_path)}" in caplog.text
), "File search not triggered"
# Ensure the correct image is returned # Ensure the correct image is returned
assert expected_image == actual_image, "Incorrect image returned by file search" assert expected_image == actual_image, "Incorrect image returned by file search"
@ -8,10 +8,10 @@ from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl
def test_markdown_file_with_no_headings_to_jsonl(tmp_path): def test_markdown_file_with_no_headings_to_jsonl(tmp_path):
"Convert files with no heading to jsonl." "Convert files with no heading to jsonl."
# Arrange # Arrange
entry = f''' entry = f"""
- Bullet point 1 - Bullet point 1
- Bullet point 2 - Bullet point 2
''' """
markdownfile = create_file(tmp_path, entry) markdownfile = create_file(tmp_path, entry)
# Act # Act
@ -20,7 +20,8 @@ def test_markdown_file_with_no_headings_to_jsonl(tmp_path):
# Process Each Entry from All Notes Files # Process Each Entry from All Notes Files
jsonl_string = MarkdownToJsonl.convert_markdown_maps_to_jsonl( jsonl_string = MarkdownToJsonl.convert_markdown_maps_to_jsonl(
MarkdownToJsonl.convert_markdown_entries_to_maps(entry_nodes, file_to_entries)) MarkdownToJsonl.convert_markdown_entries_to_maps(entry_nodes, file_to_entries)
)
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()] jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
# Assert # Assert
@ -30,10 +31,10 @@ def test_markdown_file_with_no_headings_to_jsonl(tmp_path):
def test_single_markdown_entry_to_jsonl(tmp_path): def test_single_markdown_entry_to_jsonl(tmp_path):
"Convert markdown entry from single file to jsonl." "Convert markdown entry from single file to jsonl."
# Arrange # Arrange
entry = f'''### Heading entry = f"""### Heading
\t\r \t\r
Body Line 1 Body Line 1
''' """
markdownfile = create_file(tmp_path, entry) markdownfile = create_file(tmp_path, entry)
# Act # Act
@ -42,7 +43,8 @@ def test_single_markdown_entry_to_jsonl(tmp_path):
# Process Each Entry from All Notes Files # Process Each Entry from All Notes Files
jsonl_string = MarkdownToJsonl.convert_markdown_maps_to_jsonl( jsonl_string = MarkdownToJsonl.convert_markdown_maps_to_jsonl(
MarkdownToJsonl.convert_markdown_entries_to_maps(entries, entry_to_file_map)) MarkdownToJsonl.convert_markdown_entries_to_maps(entries, entry_to_file_map)
)
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()] jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
# Assert # Assert
@ -52,14 +54,14 @@ def test_single_markdown_entry_to_jsonl(tmp_path):
def test_multiple_markdown_entries_to_jsonl(tmp_path): def test_multiple_markdown_entries_to_jsonl(tmp_path):
"Convert multiple markdown entries from single file to jsonl." "Convert multiple markdown entries from single file to jsonl."
# Arrange # Arrange
entry = f''' entry = f"""
### Heading 1 ### Heading 1
\t\r \t\r
Heading 1 Body Line 1 Heading 1 Body Line 1
### Heading 2 ### Heading 2
\t\r \t\r
Heading 2 Body Line 2 Heading 2 Body Line 2
''' """
markdownfile = create_file(tmp_path, entry) markdownfile = create_file(tmp_path, entry)
# Act # Act
@ -68,7 +70,8 @@ def test_multiple_markdown_entries_to_jsonl(tmp_path):
# Process Each Entry from All Notes Files # Process Each Entry from All Notes Files
jsonl_string = MarkdownToJsonl.convert_markdown_maps_to_jsonl( jsonl_string = MarkdownToJsonl.convert_markdown_maps_to_jsonl(
MarkdownToJsonl.convert_markdown_entries_to_maps(entries, entry_to_file_map)) MarkdownToJsonl.convert_markdown_entries_to_maps(entries, entry_to_file_map)
)
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()] jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
# Assert # Assert
@ -92,8 +95,8 @@ def test_get_markdown_files(tmp_path):
expected_files = sorted(map(str, [group1_file1, group1_file2, group2_file1, group2_file2, file1])) expected_files = sorted(map(str, [group1_file1, group1_file2, group2_file1, group2_file2, file1]))
# Setup input-files, input-filters # Setup input-files, input-filters
input_files = [tmp_path / 'notes.md'] input_files = [tmp_path / "notes.md"]
input_filter = [tmp_path / 'group1*.md', tmp_path / 'group2*.markdown'] input_filter = [tmp_path / "group1*.md", tmp_path / "group2*.markdown"]
# Act # Act
extracted_org_files = MarkdownToJsonl.get_markdown_files(input_files, input_filter) extracted_org_files = MarkdownToJsonl.get_markdown_files(input_files, input_filter)
@ -106,10 +109,10 @@ def test_get_markdown_files(tmp_path):
def test_extract_entries_with_different_level_headings(tmp_path): def test_extract_entries_with_different_level_headings(tmp_path):
"Extract markdown entries with different level headings." "Extract markdown entries with different level headings."
# Arrange # Arrange
entry = f''' entry = f"""
# Heading 1 # Heading 1
## Heading 2 ## Heading 2
''' """
markdownfile = create_file(tmp_path, entry) markdownfile = create_file(tmp_path, entry)
# Act # Act
@ -9,23 +9,25 @@ from khoj.utils.rawconfig import Entry
def test_configure_heading_entry_to_jsonl(tmp_path): def test_configure_heading_entry_to_jsonl(tmp_path):
'''Ensure entries with empty body are ignored, unless explicitly configured to index heading entries. """Ensure entries with empty body are ignored, unless explicitly configured to index heading entries.
Property drawers not considered Body. Ignore control characters for evaluating if Body empty.''' Property drawers not considered Body. Ignore control characters for evaluating if Body empty."""
# Arrange # Arrange
entry = f'''*** Heading entry = f"""*** Heading
:PROPERTIES: :PROPERTIES:
:ID: 42-42-42 :ID: 42-42-42
:END: :END:
\t \r \t \r
''' """
orgfile = create_file(tmp_path, entry) orgfile = create_file(tmp_path, entry)
for index_heading_entries in [True, False]: for index_heading_entries in [True, False]:
# Act # Act
# Extract entries into jsonl from specified Org files # Extract entries into jsonl from specified Org files
jsonl_string = OrgToJsonl.convert_org_entries_to_jsonl(OrgToJsonl.convert_org_nodes_to_entries( jsonl_string = OrgToJsonl.convert_org_entries_to_jsonl(
*OrgToJsonl.extract_org_entries(org_files=[orgfile]), OrgToJsonl.convert_org_nodes_to_entries(
index_heading_entries=index_heading_entries)) *OrgToJsonl.extract_org_entries(org_files=[orgfile]), index_heading_entries=index_heading_entries
)
)
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()] jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
# Assert # Assert
@ -40,10 +42,10 @@ def test_configure_heading_entry_to_jsonl(tmp_path):
def test_entry_split_when_exceeds_max_words(tmp_path): def test_entry_split_when_exceeds_max_words(tmp_path):
"Ensure entries with compiled words exceeding max_words are split." "Ensure entries with compiled words exceeding max_words are split."
# Arrange # Arrange
entry = f'''*** Heading entry = f"""*** Heading
\t\r \t\r
Body Line 1 Body Line 1
''' """
orgfile = create_file(tmp_path, entry) orgfile = create_file(tmp_path, entry)
# Act # Act
@ -53,9 +55,9 @@ def test_entry_split_when_exceeds_max_words(tmp_path):
# Split each entry from specified Org files by max words # Split each entry from specified Org files by max words
jsonl_string = OrgToJsonl.convert_org_entries_to_jsonl( jsonl_string = OrgToJsonl.convert_org_entries_to_jsonl(
TextToJsonl.split_entries_by_max_tokens( TextToJsonl.split_entries_by_max_tokens(
OrgToJsonl.convert_org_nodes_to_entries(entries, entry_to_file_map), OrgToJsonl.convert_org_nodes_to_entries(entries, entry_to_file_map), max_tokens=2
max_tokens = 2)
) )
)
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()] jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
# Assert # Assert
@ -65,15 +67,15 @@ def test_entry_split_when_exceeds_max_words(tmp_path):
def test_entry_split_drops_large_words(tmp_path): def test_entry_split_drops_large_words(tmp_path):
"Ensure entries drops words larger than specified max word length from compiled version." "Ensure entries drops words larger than specified max word length from compiled version."
# Arrange # Arrange
entry_text = f'''*** Heading entry_text = f"""*** Heading
\t\r \t\r
Body Line 1 Body Line 1
''' """
entry = Entry(raw=entry_text, compiled=entry_text) entry = Entry(raw=entry_text, compiled=entry_text)
# Act # Act
# Split entry by max words and drop words larger than max word length # Split entry by max words and drop words larger than max word length
processed_entry = TextToJsonl.split_entries_by_max_tokens([entry], max_word_length = 5)[0] processed_entry = TextToJsonl.split_entries_by_max_tokens([entry], max_word_length=5)[0]
# Assert # Assert
# "Heading" dropped from compiled version because its over the set max word limit # "Heading" dropped from compiled version because its over the set max word limit
@ -83,13 +85,13 @@ def test_entry_split_drops_large_words(tmp_path):
def test_entry_with_body_to_jsonl(tmp_path): def test_entry_with_body_to_jsonl(tmp_path):
"Ensure entries with valid body text are loaded." "Ensure entries with valid body text are loaded."
# Arrange # Arrange
entry = f'''*** Heading entry = f"""*** Heading
:PROPERTIES: :PROPERTIES:
:ID: 42-42-42 :ID: 42-42-42
:END: :END:
\t\r \t\r
Body Line 1 Body Line 1
''' """
orgfile = create_file(tmp_path, entry) orgfile = create_file(tmp_path, entry)
# Act # Act
@ -97,7 +99,9 @@ def test_entry_with_body_to_jsonl(tmp_path):
entries, entry_to_file_map = OrgToJsonl.extract_org_entries(org_files=[orgfile]) entries, entry_to_file_map = OrgToJsonl.extract_org_entries(org_files=[orgfile])
# Process Each Entry from All Notes Files # Process Each Entry from All Notes Files
jsonl_string = OrgToJsonl.convert_org_entries_to_jsonl(OrgToJsonl.convert_org_nodes_to_entries(entries, entry_to_file_map)) jsonl_string = OrgToJsonl.convert_org_entries_to_jsonl(
OrgToJsonl.convert_org_nodes_to_entries(entries, entry_to_file_map)
)
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()] jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
# Assert # Assert
@ -107,10 +111,10 @@ def test_entry_with_body_to_jsonl(tmp_path):
def test_file_with_no_headings_to_jsonl(tmp_path): def test_file_with_no_headings_to_jsonl(tmp_path):
"Ensure files with no heading, only body text are loaded." "Ensure files with no heading, only body text are loaded."
# Arrange # Arrange
entry = f''' entry = f"""
- Bullet point 1 - Bullet point 1
- Bullet point 2 - Bullet point 2
''' """
orgfile = create_file(tmp_path, entry) orgfile = create_file(tmp_path, entry)
# Act # Act
@ -120,7 +124,7 @@ def test_file_with_no_headings_to_jsonl(tmp_path):
# Process Each Entry from All Notes Files # Process Each Entry from All Notes Files
entries = OrgToJsonl.convert_org_nodes_to_entries(entry_nodes, file_to_entries) entries = OrgToJsonl.convert_org_nodes_to_entries(entry_nodes, file_to_entries)
jsonl_string = OrgToJsonl.convert_org_entries_to_jsonl(entries) jsonl_string = OrgToJsonl.convert_org_entries_to_jsonl(entries)
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()] jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
# Assert # Assert
assert len(jsonl_data) == 1 assert len(jsonl_data) == 1
@ -143,8 +147,8 @@ def test_get_org_files(tmp_path):
expected_files = sorted(map(str, [group1_file1, group1_file2, group2_file1, group2_file2, orgfile1])) expected_files = sorted(map(str, [group1_file1, group1_file2, group2_file1, group2_file2, orgfile1]))
# Setup input-files, input-filters # Setup input-files, input-filters
input_files = [tmp_path / 'orgfile1.org'] input_files = [tmp_path / "orgfile1.org"]
input_filter = [tmp_path / 'group1*.org', tmp_path / 'group2*.org'] input_filter = [tmp_path / "group1*.org", tmp_path / "group2*.org"]
# Act # Act
extracted_org_files = OrgToJsonl.get_org_files(input_files, input_filter) extracted_org_files = OrgToJsonl.get_org_files(input_files, input_filter)
@ -157,10 +161,10 @@ def test_get_org_files(tmp_path):
def test_extract_entries_with_different_level_headings(tmp_path): def test_extract_entries_with_different_level_headings(tmp_path):
"Extract org entries with different level headings." "Extract org entries with different level headings."
# Arrange # Arrange
entry = f''' entry = f"""
* Heading 1 * Heading 1
** Heading 2 ** Heading 2
''' """
orgfile = create_file(tmp_path, entry) orgfile = create_file(tmp_path, entry)
# Act # Act
@ -169,8 +173,8 @@ def test_extract_entries_with_different_level_headings(tmp_path):
# Assert # Assert
assert len(entries) == 2 assert len(entries) == 2
assert f'{entries[0]}'.startswith("* Heading 1") assert f"{entries[0]}".startswith("* Heading 1")
assert f'{entries[1]}'.startswith("** Heading 2") assert f"{entries[1]}".startswith("** Heading 2")
# Helper Functions # Helper Functions
@ -10,7 +10,7 @@ from khoj.processor.org_mode import orgnode
def test_parse_entry_with_no_headings(tmp_path): def test_parse_entry_with_no_headings(tmp_path):
"Test parsing of entry with minimal fields" "Test parsing of entry with minimal fields"
# Arrange # Arrange
entry = f'''Body Line 1''' entry = f"""Body Line 1"""
orgfile = create_file(tmp_path, entry) orgfile = create_file(tmp_path, entry)
# Act # Act
@ -18,7 +18,7 @@ def test_parse_entry_with_no_headings(tmp_path):
# Assert # Assert
assert len(entries) == 1 assert len(entries) == 1
assert entries[0].heading == f'{orgfile}' assert entries[0].heading == f"{orgfile}"
assert entries[0].tags == list() assert entries[0].tags == list()
assert entries[0].body == "Body Line 1" assert entries[0].body == "Body Line 1"
assert entries[0].priority == "" assert entries[0].priority == ""
@ -32,9 +32,9 @@ def test_parse_entry_with_no_headings(tmp_path):
def test_parse_minimal_entry(tmp_path): def test_parse_minimal_entry(tmp_path):
"Test parsing of entry with minimal fields" "Test parsing of entry with minimal fields"
# Arrange # Arrange
entry = f''' entry = f"""
* Heading * Heading
Body Line 1''' Body Line 1"""
orgfile = create_file(tmp_path, entry) orgfile = create_file(tmp_path, entry)
# Act # Act
@ -56,7 +56,7 @@ Body Line 1'''
def test_parse_complete_entry(tmp_path): def test_parse_complete_entry(tmp_path):
"Test parsing of entry with all important fields" "Test parsing of entry with all important fields"
# Arrange # Arrange
entry = f''' entry = f"""
*** DONE [#A] Heading :Tag1:TAG2:tag3: *** DONE [#A] Heading :Tag1:TAG2:tag3:
CLOSED: [1984-04-01 Sun 12:00] SCHEDULED: <1984-04-01 Sun 09:00> DEADLINE: <1984-04-01 Sun> CLOSED: [1984-04-01 Sun 12:00] SCHEDULED: <1984-04-01 Sun 09:00> DEADLINE: <1984-04-01 Sun>
:PROPERTIES: :PROPERTIES:
@ -67,7 +67,7 @@ CLOCK: [1984-04-01 Sun 09:00]--[1984-04-01 Sun 12:00] => 3:00
- Clocked Log 1 - Clocked Log 1
:END: :END:
Body Line 1 Body Line 1
Body Line 2''' Body Line 2"""
orgfile = create_file(tmp_path, entry) orgfile = create_file(tmp_path, entry)
# Act # Act
@ -81,45 +81,45 @@ Body Line 2'''
assert entries[0].body == "- Clocked Log 1\nBody Line 1\nBody Line 2" assert entries[0].body == "- Clocked Log 1\nBody Line 1\nBody Line 2"
assert entries[0].priority == "A" assert entries[0].priority == "A"
assert entries[0].Property("ID") == "id:123-456-789-4234-1231" assert entries[0].Property("ID") == "id:123-456-789-4234-1231"
assert entries[0].closed == datetime.date(1984,4,1) assert entries[0].closed == datetime.date(1984, 4, 1)
assert entries[0].scheduled == datetime.date(1984,4,1) assert entries[0].scheduled == datetime.date(1984, 4, 1)
assert entries[0].deadline == datetime.date(1984,4,1) assert entries[0].deadline == datetime.date(1984, 4, 1)
assert entries[0].logbook == [(datetime.datetime(1984,4,1,9,0,0), datetime.datetime(1984,4,1,12,0,0))] assert entries[0].logbook == [(datetime.datetime(1984, 4, 1, 9, 0, 0), datetime.datetime(1984, 4, 1, 12, 0, 0))]
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
def test_render_entry_with_property_drawer_and_empty_body(tmp_path): def test_render_entry_with_property_drawer_and_empty_body(tmp_path):
"Render heading entry with property drawer" "Render heading entry with property drawer"
# Arrange # Arrange
entry_to_render = f''' entry_to_render = f"""
*** [#A] Heading1 :tag1: *** [#A] Heading1 :tag1:
:PROPERTIES: :PROPERTIES:
:ID: 111-111-111-1111-1111 :ID: 111-111-111-1111-1111
:END: :END:
\t\r \n \t\r \n
''' """
orgfile = create_file(tmp_path, entry_to_render) orgfile = create_file(tmp_path, entry_to_render)
expected_entry = f'''*** [#A] Heading1 :tag1: expected_entry = f"""*** [#A] Heading1 :tag1:
:PROPERTIES: :PROPERTIES:
:LINE: file:{orgfile}::2 :LINE: file:{orgfile}::2
:ID: id:111-111-111-1111-1111 :ID: id:111-111-111-1111-1111
:SOURCE: [[file:{orgfile}::*Heading1]] :SOURCE: [[file:{orgfile}::*Heading1]]
:END: :END:
''' """
# Act # Act
parsed_entries = orgnode.makelist(orgfile) parsed_entries = orgnode.makelist(orgfile)
# Assert # Assert
assert f'{parsed_entries[0]}' == expected_entry assert f"{parsed_entries[0]}" == expected_entry
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
def test_all_links_to_entry_rendered(tmp_path): def test_all_links_to_entry_rendered(tmp_path):
"Ensure all links to entry rendered in property drawer from entry" "Ensure all links to entry rendered in property drawer from entry"
# Arrange # Arrange
entry = f''' entry = f"""
*** [#A] Heading :tag1: *** [#A] Heading :tag1:
:PROPERTIES: :PROPERTIES:
:ID: 123-456-789-4234-1231 :ID: 123-456-789-4234-1231
@ -127,7 +127,7 @@ def test_all_links_to_entry_rendered(tmp_path):
Body Line 1 Body Line 1
*** Heading2 *** Heading2
Body Line 2 Body Line 2
''' """
orgfile = create_file(tmp_path, entry) orgfile = create_file(tmp_path, entry)
# Act # Act
@ -135,23 +135,23 @@ Body Line 2
# Assert # Assert
# SOURCE link rendered with Heading # SOURCE link rendered with Heading
assert f':SOURCE: [[file:{orgfile}::*{entries[0].heading}]]' in f'{entries[0]}' assert f":SOURCE: [[file:{orgfile}::*{entries[0].heading}]]" in f"{entries[0]}"
# ID link rendered with ID # ID link rendered with ID
assert f':ID: id:123-456-789-4234-1231' in f'{entries[0]}' assert f":ID: id:123-456-789-4234-1231" in f"{entries[0]}"
# LINE link rendered with line number # LINE link rendered with line number
assert f':LINE: file:{orgfile}::2' in f'{entries[0]}' assert f":LINE: file:{orgfile}::2" in f"{entries[0]}"
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
def test_source_link_to_entry_escaped_for_rendering(tmp_path): def test_source_link_to_entry_escaped_for_rendering(tmp_path):
"Test SOURCE link renders with square brackets in filename, heading escaped for org-mode rendering" "Test SOURCE link renders with square brackets in filename, heading escaped for org-mode rendering"
# Arrange # Arrange
entry = f''' entry = f"""
*** [#A] Heading[1] :tag1: *** [#A] Heading[1] :tag1:
:PROPERTIES: :PROPERTIES:
:ID: 123-456-789-4234-1231 :ID: 123-456-789-4234-1231
:END: :END:
Body Line 1''' Body Line 1"""
orgfile = create_file(tmp_path, entry, filename="test[1].org") orgfile = create_file(tmp_path, entry, filename="test[1].org")
# Act # Act
@ -162,15 +162,15 @@ Body Line 1'''
# parsed heading from entry # parsed heading from entry
assert entries[0].heading == "Heading[1]" assert entries[0].heading == "Heading[1]"
# ensure SOURCE link has square brackets in filename, heading escaped in rendered entries # ensure SOURCE link has square brackets in filename, heading escaped in rendered entries
escaped_orgfile = f'{orgfile}'.replace("[1]", "\\[1\\]") escaped_orgfile = f"{orgfile}".replace("[1]", "\\[1\\]")
assert f':SOURCE: [[file:{escaped_orgfile}::*Heading\[1\]' in f'{entries[0]}' assert f":SOURCE: [[file:{escaped_orgfile}::*Heading\[1\]" in f"{entries[0]}"
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
def test_parse_multiple_entries(tmp_path): def test_parse_multiple_entries(tmp_path):
"Test parsing of multiple entries" "Test parsing of multiple entries"
# Arrange # Arrange
content = f''' content = f"""
*** FAILED [#A] Heading1 :tag1: *** FAILED [#A] Heading1 :tag1:
CLOSED: [1984-04-01 Sun 12:00] SCHEDULED: <1984-04-01 Sun 09:00> DEADLINE: <1984-04-01 Sun> CLOSED: [1984-04-01 Sun 12:00] SCHEDULED: <1984-04-01 Sun 09:00> DEADLINE: <1984-04-01 Sun>
:PROPERTIES: :PROPERTIES:
@ -193,7 +193,7 @@ CLOCK: [1984-04-02 Mon 09:00]--[1984-04-02 Mon 12:00] => 3:00
:END: :END:
Body 2 Body 2
''' """
orgfile = create_file(tmp_path, content) orgfile = create_file(tmp_path, content)
# Act # Act
@ -208,18 +208,20 @@ Body 2
assert entry.body == f"- Clocked Log {index+1}\nBody {index+1}\n\n" assert entry.body == f"- Clocked Log {index+1}\nBody {index+1}\n\n"
assert entry.priority == "A" assert entry.priority == "A"
assert entry.Property("ID") == f"id:123-456-789-4234-000{index+1}" assert entry.Property("ID") == f"id:123-456-789-4234-000{index+1}"
assert entry.closed == datetime.date(1984,4,index+1) assert entry.closed == datetime.date(1984, 4, index + 1)
assert entry.scheduled == datetime.date(1984,4,index+1) assert entry.scheduled == datetime.date(1984, 4, index + 1)
assert entry.deadline == datetime.date(1984,4,index+1) assert entry.deadline == datetime.date(1984, 4, index + 1)
assert entry.logbook == [(datetime.datetime(1984,4,index+1,9,0,0), datetime.datetime(1984,4,index+1,12,0,0))] assert entry.logbook == [
(datetime.datetime(1984, 4, index + 1, 9, 0, 0), datetime.datetime(1984, 4, index + 1, 12, 0, 0))
]
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
def test_parse_entry_with_empty_title(tmp_path): def test_parse_entry_with_empty_title(tmp_path):
"Test parsing of entry with minimal fields" "Test parsing of entry with minimal fields"
# Arrange # Arrange
entry = f'''#+TITLE: entry = f"""#+TITLE:
Body Line 1''' Body Line 1"""
orgfile = create_file(tmp_path, entry) orgfile = create_file(tmp_path, entry)
# Act # Act
@ -227,7 +229,7 @@ Body Line 1'''
# Assert # Assert
assert len(entries) == 1 assert len(entries) == 1
assert entries[0].heading == f'{orgfile}' assert entries[0].heading == f"{orgfile}"
assert entries[0].tags == list() assert entries[0].tags == list()
assert entries[0].body == "Body Line 1" assert entries[0].body == "Body Line 1"
assert entries[0].priority == "" assert entries[0].priority == ""
@ -241,8 +243,8 @@ Body Line 1'''
def test_parse_entry_with_title_and_no_headings(tmp_path): def test_parse_entry_with_title_and_no_headings(tmp_path):
"Test parsing of entry with minimal fields" "Test parsing of entry with minimal fields"
# Arrange # Arrange
entry = f'''#+TITLE: test entry = f"""#+TITLE: test
Body Line 1''' Body Line 1"""
orgfile = create_file(tmp_path, entry) orgfile = create_file(tmp_path, entry)
# Act # Act
@ -250,7 +252,7 @@ Body Line 1'''
# Assert # Assert
assert len(entries) == 1 assert len(entries) == 1
assert entries[0].heading == 'test' assert entries[0].heading == "test"
assert entries[0].tags == list() assert entries[0].tags == list()
assert entries[0].body == "Body Line 1" assert entries[0].body == "Body Line 1"
assert entries[0].priority == "" assert entries[0].priority == ""
@ -264,9 +266,9 @@ Body Line 1'''
def test_parse_entry_with_multiple_titles_and_no_headings(tmp_path): def test_parse_entry_with_multiple_titles_and_no_headings(tmp_path):
"Test parsing of entry with minimal fields" "Test parsing of entry with minimal fields"
# Arrange # Arrange
entry = f'''#+TITLE: title1 entry = f"""#+TITLE: title1
Body Line 1 Body Line 1
#+TITLE: title2 ''' #+TITLE: title2 """
orgfile = create_file(tmp_path, entry) orgfile = create_file(tmp_path, entry)
# Act # Act
@ -274,7 +276,7 @@ Body Line 1
# Assert # Assert
assert len(entries) == 1 assert len(entries) == 1
assert entries[0].heading == 'title1 title2' assert entries[0].heading == "title1 title2"
assert entries[0].tags == list() assert entries[0].tags == list()
assert entries[0].body == "Body Line 1\n" assert entries[0].body == "Body Line 1\n"
assert entries[0].priority == "" assert entries[0].priority == ""
@ -14,7 +14,9 @@ from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
# Test # Test
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
def test_asymmetric_setup_with_missing_file_raises_error(org_config_with_only_new_file: TextContentConfig, search_config: SearchConfig): def test_asymmetric_setup_with_missing_file_raises_error(
org_config_with_only_new_file: TextContentConfig, search_config: SearchConfig
):
# Arrange # Arrange
# Ensure file mentioned in org.input-files is missing # Ensure file mentioned in org.input-files is missing
single_new_file = Path(org_config_with_only_new_file.input_files[0]) single_new_file = Path(org_config_with_only_new_file.input_files[0])
@ -27,10 +29,12 @@ def test_asymmetric_setup_with_missing_file_raises_error(org_config_with_only_ne
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
def test_asymmetric_setup_with_empty_file_raises_error(org_config_with_only_new_file: TextContentConfig, search_config: SearchConfig): def test_asymmetric_setup_with_empty_file_raises_error(
org_config_with_only_new_file: TextContentConfig, search_config: SearchConfig
):
# Act # Act
# Generate notes embeddings during asymmetric setup # Generate notes embeddings during asymmetric setup
with pytest.raises(ValueError, match=r'^No valid entries found*'): with pytest.raises(ValueError, match=r"^No valid entries found*"):
text_search.setup(OrgToJsonl, org_config_with_only_new_file, search_config.asymmetric, regenerate=True) text_search.setup(OrgToJsonl, org_config_with_only_new_file, search_config.asymmetric, regenerate=True)
@ -52,15 +56,9 @@ def test_asymmetric_search(content_config: ContentConfig, search_config: SearchC
query = "How to git install application?" query = "How to git install application?"
# Act # Act
hits, entries = text_search.query( hits, entries = text_search.query(query, model=model.notes_search, rank_results=True)
query,
model = model.notes_search,
rank_results=True)
results = text_search.collate_results( results = text_search.collate_results(hits, entries, count=1)
hits,
entries,
count=1)
# Assert # Assert
# Actual_data should contain "Khoj via Emacs" entry # Actual_data should contain "Khoj via Emacs" entry
@ -76,12 +74,14 @@ def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: TextContent
new_file_to_index = Path(org_config_with_only_new_file.input_files[0]) new_file_to_index = Path(org_config_with_only_new_file.input_files[0])
with open(new_file_to_index, "w") as f: with open(new_file_to_index, "w") as f:
f.write(f"* Entry more than {max_tokens} words\n") f.write(f"* Entry more than {max_tokens} words\n")
for index in range(max_tokens+1): for index in range(max_tokens + 1):
f.write(f"{index} ") f.write(f"{index} ")
# Act # Act
# reload embeddings, entries, notes model after adding new org-mode file # reload embeddings, entries, notes model after adding new org-mode file
initial_notes_model = text_search.setup(OrgToJsonl, org_config_with_only_new_file, search_config.asymmetric, regenerate=False) initial_notes_model = text_search.setup(
OrgToJsonl, org_config_with_only_new_file, search_config.asymmetric, regenerate=False
)
# Assert # Assert
# verify newly added org-mode entry is split by max tokens # verify newly added org-mode entry is split by max tokens
@@ -92,18 +92,20 @@ def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: TextContent
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
def test_asymmetric_reload(content_config: ContentConfig, search_config: SearchConfig, new_org_file: Path): def test_asymmetric_reload(content_config: ContentConfig, search_config: SearchConfig, new_org_file: Path):
# Arrange # Arrange
initial_notes_model= text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True) initial_notes_model = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True)
assert len(initial_notes_model.entries) == 10 assert len(initial_notes_model.entries) == 10
assert len(initial_notes_model.corpus_embeddings) == 10 assert len(initial_notes_model.corpus_embeddings) == 10
# append org-mode entry to first org input file in config # append org-mode entry to first org input file in config
content_config.org.input_files = [f'{new_org_file}'] content_config.org.input_files = [f"{new_org_file}"]
with open(new_org_file, "w") as f: with open(new_org_file, "w") as f:
f.write("\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n") f.write("\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n")
# regenerate notes jsonl, model embeddings and model to include entry from new file # regenerate notes jsonl, model embeddings and model to include entry from new file
regenerated_notes_model = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True) regenerated_notes_model = text_search.setup(
OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True
)
# Act # Act
# reload embeddings, entries, notes model from previously generated notes jsonl and model embeddings files # reload embeddings, entries, notes model from previously generated notes jsonl and model embeddings files
@@ -137,7 +139,7 @@ def test_incremental_update(content_config: ContentConfig, search_config: Search
# Act # Act
# update embeddings, entries with the newly added note # update embeddings, entries with the newly added note
content_config.org.input_files = [f'{new_org_file}'] content_config.org.input_files = [f"{new_org_file}"]
initial_notes_model = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False) initial_notes_model = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False)
# Assert # Assert
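The signature and call rewrites in the hunks above follow Black's wrapping rules: a def or call that would overrun the configured line length has its arguments moved to an indented continuation line (or, if they still do not fit, one argument per line with a trailing comma), and a multi-line call that fits within the limit is collapsed back onto a single line. A self-contained sketch with a made-up function name, assuming a 120-character line length:

    # Before formatting, this hypothetical signature runs past 120 characters:
    # def build_search_model(org_config_with_only_new_file, search_config, regenerate=False, filters=None, max_tokens=256, verbose=0): ...

    # After Black, the parameters move to an indented continuation line:
    def build_search_model(
        org_config_with_only_new_file, search_config, regenerate=False, filters=None, max_tokens=256, verbose=0
    ):
        # Stand-in body, for illustration only
        return (org_config_with_only_new_file, search_config, regenerate, filters, max_tokens, verbose)
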
@@ -7,7 +7,7 @@ def test_no_word_filter():
# Arrange # Arrange
word_filter = WordFilter() word_filter = WordFilter()
entries = arrange_content() entries = arrange_content()
q_with_no_filter = 'head tail' q_with_no_filter = "head tail"
# Act # Act
can_filter = word_filter.can_filter(q_with_no_filter) can_filter = word_filter.can_filter(q_with_no_filter)
@@ -15,7 +15,7 @@ def test_no_word_filter():
# Assert # Assert
assert can_filter == False assert can_filter == False
assert ret_query == 'head tail' assert ret_query == "head tail"
assert entry_indices == {0, 1, 2, 3} assert entry_indices == {0, 1, 2, 3}
@@ -31,7 +31,7 @@ def test_word_exclude_filter():
# Assert # Assert
assert can_filter == True assert can_filter == True
assert ret_query == 'head tail' assert ret_query == "head tail"
assert entry_indices == {0, 2} assert entry_indices == {0, 2}
@@ -47,7 +47,7 @@ def test_word_include_filter():
# Assert # Assert
assert can_filter == True assert can_filter == True
assert ret_query == 'head tail' assert ret_query == "head tail"
assert entry_indices == {2, 3} assert entry_indices == {2, 3}
@@ -63,16 +63,16 @@ def test_word_include_and_exclude_filter():
# Assert # Assert
assert can_filter == True assert can_filter == True
assert ret_query == 'head tail' assert ret_query == "head tail"
assert entry_indices == {2} assert entry_indices == {2}
def arrange_content(): def arrange_content():
entries = [ entries = [
Entry(compiled='', raw='Minimal Entry'), Entry(compiled="", raw="Minimal Entry"),
Entry(compiled='', raw='Entry with exclude_word'), Entry(compiled="", raw="Entry with exclude_word"),
Entry(compiled='', raw='Entry with include_word'), Entry(compiled="", raw="Entry with include_word"),
Entry(compiled='', raw='Entry with include_word and exclude_word') Entry(compiled="", raw="Entry with include_word and exclude_word"),
] ]
return entries return entries
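For reference, the same kind of rewrite can also be reproduced from Python rather than the command line, since Black exposes format_str and Mode for library use. A small sketch, assuming the black package is installed; the input string is made up and the 120-character line length is used purely as an example value:

    import black

    # A deliberately messy one-liner: single quotes, no space after the comma
    messy = "entries = [Entry(compiled='', raw='Minimal Entry'),Entry(compiled='', raw='Entry with exclude_word')]"

    # Prints the snippet with quotes and spacing normalized by Black
    print(black.format_str(messy, mode=black.Mode(line_length=120)))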