mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 23:48:56 +01:00
Use Black to format Khoj server code and tests
This commit is contained in:
parent
6130fddf45
commit
5e83baab21
44 changed files with 1167 additions and 915 deletions
|
@ -64,7 +64,8 @@ khoj = "khoj.main:run"
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
test = [
|
test = [
|
||||||
"pytest == 7.1.2",
|
"pytest >= 7.1.2",
|
||||||
|
"black >= 23.1.0",
|
||||||
]
|
]
|
||||||
dev = ["khoj-assistant[test]"]
|
dev = ["khoj-assistant[test]"]
|
||||||
|
|
||||||
|
@ -88,3 +89,6 @@ exclude = [
|
||||||
"src/khoj/interface/desktop/file_browser.py",
|
"src/khoj/interface/desktop/file_browser.py",
|
||||||
"src/khoj/interface/desktop/system_tray.py",
|
"src/khoj/interface/desktop/system_tray.py",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[tool.black]
|
||||||
|
line-length = 120
|
|
@ -26,10 +26,12 @@ logger = logging.getLogger(__name__)
|
||||||
def configure_server(args, required=False):
|
def configure_server(args, required=False):
|
||||||
if args.config is None:
|
if args.config is None:
|
||||||
if required:
|
if required:
|
||||||
logger.error(f'Exiting as Khoj is not configured.\nConfigure it via GUI or by editing {state.config_file}.')
|
logger.error(f"Exiting as Khoj is not configured.\nConfigure it via GUI or by editing {state.config_file}.")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
else:
|
else:
|
||||||
logger.warn(f'Khoj is not configured.\nConfigure it via khoj GUI, plugins or by editing {state.config_file}.')
|
logger.warn(
|
||||||
|
f"Khoj is not configured.\nConfigure it via khoj GUI, plugins or by editing {state.config_file}."
|
||||||
|
)
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
state.config = args.config
|
state.config = args.config
|
||||||
|
@ -60,7 +62,8 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool,
|
||||||
config.content_type.org,
|
config.content_type.org,
|
||||||
search_config=config.search_type.asymmetric,
|
search_config=config.search_type.asymmetric,
|
||||||
regenerate=regenerate,
|
regenerate=regenerate,
|
||||||
filters=[DateFilter(), WordFilter(), FileFilter()])
|
filters=[DateFilter(), WordFilter(), FileFilter()],
|
||||||
|
)
|
||||||
|
|
||||||
# Initialize Org Music Search
|
# Initialize Org Music Search
|
||||||
if (t == SearchType.Music or t == None) and config.content_type.music:
|
if (t == SearchType.Music or t == None) and config.content_type.music:
|
||||||
|
@ -70,7 +73,8 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool,
|
||||||
config.content_type.music,
|
config.content_type.music,
|
||||||
search_config=config.search_type.asymmetric,
|
search_config=config.search_type.asymmetric,
|
||||||
regenerate=regenerate,
|
regenerate=regenerate,
|
||||||
filters=[DateFilter(), WordFilter()])
|
filters=[DateFilter(), WordFilter()],
|
||||||
|
)
|
||||||
|
|
||||||
# Initialize Markdown Search
|
# Initialize Markdown Search
|
||||||
if (t == SearchType.Markdown or t == None) and config.content_type.markdown:
|
if (t == SearchType.Markdown or t == None) and config.content_type.markdown:
|
||||||
|
@ -80,7 +84,8 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool,
|
||||||
config.content_type.markdown,
|
config.content_type.markdown,
|
||||||
search_config=config.search_type.asymmetric,
|
search_config=config.search_type.asymmetric,
|
||||||
regenerate=regenerate,
|
regenerate=regenerate,
|
||||||
filters=[DateFilter(), WordFilter(), FileFilter()])
|
filters=[DateFilter(), WordFilter(), FileFilter()],
|
||||||
|
)
|
||||||
|
|
||||||
# Initialize Ledger Search
|
# Initialize Ledger Search
|
||||||
if (t == SearchType.Ledger or t == None) and config.content_type.ledger:
|
if (t == SearchType.Ledger or t == None) and config.content_type.ledger:
|
||||||
|
@ -90,15 +95,15 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool,
|
||||||
config.content_type.ledger,
|
config.content_type.ledger,
|
||||||
search_config=config.search_type.symmetric,
|
search_config=config.search_type.symmetric,
|
||||||
regenerate=regenerate,
|
regenerate=regenerate,
|
||||||
filters=[DateFilter(), WordFilter(), FileFilter()])
|
filters=[DateFilter(), WordFilter(), FileFilter()],
|
||||||
|
)
|
||||||
|
|
||||||
# Initialize Image Search
|
# Initialize Image Search
|
||||||
if (t == SearchType.Image or t == None) and config.content_type.image:
|
if (t == SearchType.Image or t == None) and config.content_type.image:
|
||||||
# Extract Entries, Generate Image Embeddings
|
# Extract Entries, Generate Image Embeddings
|
||||||
model.image_search = image_search.setup(
|
model.image_search = image_search.setup(
|
||||||
config.content_type.image,
|
config.content_type.image, search_config=config.search_type.image, regenerate=regenerate
|
||||||
search_config=config.search_type.image,
|
)
|
||||||
regenerate=regenerate)
|
|
||||||
|
|
||||||
# Invalidate Query Cache
|
# Invalidate Query Cache
|
||||||
state.query_cache = LRU()
|
state.query_cache = LRU()
|
||||||
|
@ -125,9 +130,9 @@ def configure_conversation_processor(conversation_processor_config):
|
||||||
|
|
||||||
if conversation_logfile.is_file():
|
if conversation_logfile.is_file():
|
||||||
# Load Metadata Logs from Conversation Logfile
|
# Load Metadata Logs from Conversation Logfile
|
||||||
with conversation_logfile.open('r') as f:
|
with conversation_logfile.open("r") as f:
|
||||||
conversation_processor.meta_log = json.load(f)
|
conversation_processor.meta_log = json.load(f)
|
||||||
logger.info('Conversation logs loaded from disk.')
|
logger.info("Conversation logs loaded from disk.")
|
||||||
else:
|
else:
|
||||||
# Initialize Conversation Logs
|
# Initialize Conversation Logs
|
||||||
conversation_processor.meta_log = {}
|
conversation_processor.meta_log = {}
|
||||||
|
|
|
@ -26,36 +26,39 @@ class FileBrowser(QtWidgets.QWidget):
|
||||||
self.lineEdit = QtWidgets.QPlainTextEdit(self)
|
self.lineEdit = QtWidgets.QPlainTextEdit(self)
|
||||||
self.lineEdit.setFixedWidth(330)
|
self.lineEdit.setFixedWidth(330)
|
||||||
self.setFiles(default_files)
|
self.setFiles(default_files)
|
||||||
self.lineEdit.setFixedHeight(min(7+20*len(self.lineEdit.toPlainText().split('\n')),90))
|
self.lineEdit.setFixedHeight(min(7 + 20 * len(self.lineEdit.toPlainText().split("\n")), 90))
|
||||||
self.lineEdit.textChanged.connect(self.updateFieldHeight)
|
self.lineEdit.textChanged.connect(self.updateFieldHeight)
|
||||||
layout.addWidget(self.lineEdit)
|
layout.addWidget(self.lineEdit)
|
||||||
|
|
||||||
self.button = QtWidgets.QPushButton('Add')
|
self.button = QtWidgets.QPushButton("Add")
|
||||||
self.button.clicked.connect(self.storeFilesSelectedInFileDialog)
|
self.button.clicked.connect(self.storeFilesSelectedInFileDialog)
|
||||||
layout.addWidget(self.button)
|
layout.addWidget(self.button)
|
||||||
layout.addStretch()
|
layout.addStretch()
|
||||||
|
|
||||||
def getFileFilter(self, search_type):
|
def getFileFilter(self, search_type):
|
||||||
if search_type == SearchType.Org:
|
if search_type == SearchType.Org:
|
||||||
return 'Org-Mode Files (*.org)'
|
return "Org-Mode Files (*.org)"
|
||||||
elif search_type == SearchType.Ledger:
|
elif search_type == SearchType.Ledger:
|
||||||
return 'Beancount Files (*.bean *.beancount)'
|
return "Beancount Files (*.bean *.beancount)"
|
||||||
elif search_type == SearchType.Markdown:
|
elif search_type == SearchType.Markdown:
|
||||||
return 'Markdown Files (*.md *.markdown)'
|
return "Markdown Files (*.md *.markdown)"
|
||||||
elif search_type == SearchType.Music:
|
elif search_type == SearchType.Music:
|
||||||
return 'Org-Music Files (*.org)'
|
return "Org-Music Files (*.org)"
|
||||||
elif search_type == SearchType.Image:
|
elif search_type == SearchType.Image:
|
||||||
return 'Images (*.jp[e]g)'
|
return "Images (*.jp[e]g)"
|
||||||
|
|
||||||
def storeFilesSelectedInFileDialog(self):
|
def storeFilesSelectedInFileDialog(self):
|
||||||
filepaths = self.getPaths()
|
filepaths = self.getPaths()
|
||||||
if self.search_type == SearchType.Image:
|
if self.search_type == SearchType.Image:
|
||||||
filepaths.append(QtWidgets.QFileDialog.getExistingDirectory(self, caption='Choose Folder',
|
filepaths.append(
|
||||||
directory=self.dirpath))
|
QtWidgets.QFileDialog.getExistingDirectory(self, caption="Choose Folder", directory=self.dirpath)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
filepaths.extend(QtWidgets.QFileDialog.getOpenFileNames(self, caption='Choose Files',
|
filepaths.extend(
|
||||||
directory=self.dirpath,
|
QtWidgets.QFileDialog.getOpenFileNames(
|
||||||
filter=self.filter_name)[0])
|
self, caption="Choose Files", directory=self.dirpath, filter=self.filter_name
|
||||||
|
)[0]
|
||||||
|
)
|
||||||
self.setFiles(filepaths)
|
self.setFiles(filepaths)
|
||||||
|
|
||||||
def setFiles(self, paths: list):
|
def setFiles(self, paths: list):
|
||||||
|
@ -63,10 +66,10 @@ class FileBrowser(QtWidgets.QWidget):
|
||||||
self.lineEdit.setPlainText("\n".join(self.filepaths))
|
self.lineEdit.setPlainText("\n".join(self.filepaths))
|
||||||
|
|
||||||
def getPaths(self) -> list:
|
def getPaths(self) -> list:
|
||||||
if self.lineEdit.toPlainText() == '':
|
if self.lineEdit.toPlainText() == "":
|
||||||
return []
|
return []
|
||||||
else:
|
else:
|
||||||
return self.lineEdit.toPlainText().split('\n')
|
return self.lineEdit.toPlainText().split("\n")
|
||||||
|
|
||||||
def updateFieldHeight(self):
|
def updateFieldHeight(self):
|
||||||
self.lineEdit.setFixedHeight(min(7+20*len(self.lineEdit.toPlainText().split('\n')),90))
|
self.lineEdit.setFixedHeight(min(7 + 20 * len(self.lineEdit.toPlainText().split("\n")), 90))
|
||||||
|
|
|
@ -31,9 +31,9 @@ class MainWindow(QtWidgets.QMainWindow):
|
||||||
self.config_file = config_file
|
self.config_file = config_file
|
||||||
# Set regenerate flag to regenerate embeddings everytime user clicks configure
|
# Set regenerate flag to regenerate embeddings everytime user clicks configure
|
||||||
if state.cli_args:
|
if state.cli_args:
|
||||||
state.cli_args += ['--regenerate']
|
state.cli_args += ["--regenerate"]
|
||||||
else:
|
else:
|
||||||
state.cli_args = ['--regenerate']
|
state.cli_args = ["--regenerate"]
|
||||||
|
|
||||||
# Load config from existing config, if exists, else load from default config
|
# Load config from existing config, if exists, else load from default config
|
||||||
if resolve_absolute_path(self.config_file).exists():
|
if resolve_absolute_path(self.config_file).exists():
|
||||||
|
@ -49,8 +49,8 @@ class MainWindow(QtWidgets.QMainWindow):
|
||||||
self.setFixedWidth(600)
|
self.setFixedWidth(600)
|
||||||
|
|
||||||
# Set Window Icon
|
# Set Window Icon
|
||||||
icon_path = constants.web_directory / 'assets/icons/favicon-144x144.png'
|
icon_path = constants.web_directory / "assets/icons/favicon-144x144.png"
|
||||||
self.setWindowIcon(QtGui.QIcon(f'{icon_path.absolute()}'))
|
self.setWindowIcon(QtGui.QIcon(f"{icon_path.absolute()}"))
|
||||||
|
|
||||||
# Initialize Configure Window Layout
|
# Initialize Configure Window Layout
|
||||||
self.layout = QtWidgets.QVBoxLayout()
|
self.layout = QtWidgets.QVBoxLayout()
|
||||||
|
@ -58,13 +58,13 @@ class MainWindow(QtWidgets.QMainWindow):
|
||||||
# Add Settings Panels for each Search Type to Configure Window Layout
|
# Add Settings Panels for each Search Type to Configure Window Layout
|
||||||
self.search_settings_panels = []
|
self.search_settings_panels = []
|
||||||
for search_type in SearchType:
|
for search_type in SearchType:
|
||||||
current_content_config = self.current_config['content-type'].get(search_type, {})
|
current_content_config = self.current_config["content-type"].get(search_type, {})
|
||||||
self.search_settings_panels += [self.add_settings_panel(current_content_config, search_type)]
|
self.search_settings_panels += [self.add_settings_panel(current_content_config, search_type)]
|
||||||
|
|
||||||
# Add Conversation Processor Panel to Configure Screen
|
# Add Conversation Processor Panel to Configure Screen
|
||||||
self.processor_settings_panels = []
|
self.processor_settings_panels = []
|
||||||
conversation_type = ProcessorType.Conversation
|
conversation_type = ProcessorType.Conversation
|
||||||
current_conversation_config = self.current_config['processor'].get(conversation_type, {})
|
current_conversation_config = self.current_config["processor"].get(conversation_type, {})
|
||||||
self.processor_settings_panels += [self.add_processor_panel(current_conversation_config, conversation_type)]
|
self.processor_settings_panels += [self.add_processor_panel(current_conversation_config, conversation_type)]
|
||||||
|
|
||||||
# Add Action Buttons Panel
|
# Add Action Buttons Panel
|
||||||
|
@ -81,11 +81,11 @@ class MainWindow(QtWidgets.QMainWindow):
|
||||||
"Add Settings Panel for specified Search Type. Toggle Editable Search Types"
|
"Add Settings Panel for specified Search Type. Toggle Editable Search Types"
|
||||||
# Get current files from config for given search type
|
# Get current files from config for given search type
|
||||||
if search_type == SearchType.Image:
|
if search_type == SearchType.Image:
|
||||||
current_content_files = current_content_config.get('input-directories', [])
|
current_content_files = current_content_config.get("input-directories", [])
|
||||||
file_input_text = f'{search_type.name} Folders'
|
file_input_text = f"{search_type.name} Folders"
|
||||||
else:
|
else:
|
||||||
current_content_files = current_content_config.get('input-files', [])
|
current_content_files = current_content_config.get("input-files", [])
|
||||||
file_input_text = f'{search_type.name} Files'
|
file_input_text = f"{search_type.name} Files"
|
||||||
|
|
||||||
# Create widgets to display settings for given search type
|
# Create widgets to display settings for given search type
|
||||||
search_type_settings = QtWidgets.QWidget()
|
search_type_settings = QtWidgets.QWidget()
|
||||||
|
@ -109,7 +109,7 @@ class MainWindow(QtWidgets.QMainWindow):
|
||||||
def add_processor_panel(self, current_conversation_config: dict, processor_type: ProcessorType):
|
def add_processor_panel(self, current_conversation_config: dict, processor_type: ProcessorType):
|
||||||
"Add Conversation Processor Panel"
|
"Add Conversation Processor Panel"
|
||||||
# Get current settings from config for given processor type
|
# Get current settings from config for given processor type
|
||||||
current_openai_api_key = current_conversation_config.get('openai-api-key', None)
|
current_openai_api_key = current_conversation_config.get("openai-api-key", None)
|
||||||
|
|
||||||
# Create widgets to display settings for given processor type
|
# Create widgets to display settings for given processor type
|
||||||
processor_type_settings = QtWidgets.QWidget()
|
processor_type_settings = QtWidgets.QWidget()
|
||||||
|
@ -137,7 +137,9 @@ class MainWindow(QtWidgets.QMainWindow):
|
||||||
action_bar_layout = QtWidgets.QHBoxLayout(action_bar)
|
action_bar_layout = QtWidgets.QHBoxLayout(action_bar)
|
||||||
|
|
||||||
self.configure_button = QtWidgets.QPushButton("Configure", clicked=self.configure_app)
|
self.configure_button = QtWidgets.QPushButton("Configure", clicked=self.configure_app)
|
||||||
self.search_button = QtWidgets.QPushButton("Search", clicked=lambda: webbrowser.open(f'http://{state.host}:{state.port}/'))
|
self.search_button = QtWidgets.QPushButton(
|
||||||
|
"Search", clicked=lambda: webbrowser.open(f"http://{state.host}:{state.port}/")
|
||||||
|
)
|
||||||
self.search_button.setEnabled(not self.first_run)
|
self.search_button.setEnabled(not self.first_run)
|
||||||
|
|
||||||
action_bar_layout.addWidget(self.configure_button)
|
action_bar_layout.addWidget(self.configure_button)
|
||||||
|
@ -148,9 +150,9 @@ class MainWindow(QtWidgets.QMainWindow):
|
||||||
"Get default config"
|
"Get default config"
|
||||||
config = constants.default_config
|
config = constants.default_config
|
||||||
if search_type:
|
if search_type:
|
||||||
return config['content-type'][search_type]
|
return config["content-type"][search_type]
|
||||||
elif processor_type:
|
elif processor_type:
|
||||||
return config['processor'][processor_type]
|
return config["processor"][processor_type]
|
||||||
else:
|
else:
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
@ -160,7 +162,9 @@ class MainWindow(QtWidgets.QMainWindow):
|
||||||
for message_prefix in ErrorType:
|
for message_prefix in ErrorType:
|
||||||
for i in reversed(range(self.layout.count())):
|
for i in reversed(range(self.layout.count())):
|
||||||
current_widget = self.layout.itemAt(i).widget()
|
current_widget = self.layout.itemAt(i).widget()
|
||||||
if isinstance(current_widget, QtWidgets.QLabel) and current_widget.text().startswith(message_prefix.value):
|
if isinstance(current_widget, QtWidgets.QLabel) and current_widget.text().startswith(
|
||||||
|
message_prefix.value
|
||||||
|
):
|
||||||
self.layout.removeWidget(current_widget)
|
self.layout.removeWidget(current_widget)
|
||||||
current_widget.deleteLater()
|
current_widget.deleteLater()
|
||||||
|
|
||||||
|
@ -180,18 +184,24 @@ class MainWindow(QtWidgets.QMainWindow):
|
||||||
continue
|
continue
|
||||||
if isinstance(child, SearchCheckBox):
|
if isinstance(child, SearchCheckBox):
|
||||||
# Search Type Disabled
|
# Search Type Disabled
|
||||||
if not child.isChecked() and child.search_type in self.new_config['content-type']:
|
if not child.isChecked() and child.search_type in self.new_config["content-type"]:
|
||||||
del self.new_config['content-type'][child.search_type]
|
del self.new_config["content-type"][child.search_type]
|
||||||
# Search Type (re)-Enabled
|
# Search Type (re)-Enabled
|
||||||
if child.isChecked():
|
if child.isChecked():
|
||||||
current_search_config = self.current_config['content-type'].get(child.search_type, {})
|
current_search_config = self.current_config["content-type"].get(child.search_type, {})
|
||||||
default_search_config = self.get_default_config(search_type=child.search_type)
|
default_search_config = self.get_default_config(search_type=child.search_type)
|
||||||
self.new_config['content-type'][child.search_type.value] = merge_dicts(current_search_config, default_search_config)
|
self.new_config["content-type"][child.search_type.value] = merge_dicts(
|
||||||
elif isinstance(child, FileBrowser) and child.search_type in self.new_config['content-type']:
|
current_search_config, default_search_config
|
||||||
|
)
|
||||||
|
elif isinstance(child, FileBrowser) and child.search_type in self.new_config["content-type"]:
|
||||||
if child.search_type.value == SearchType.Image:
|
if child.search_type.value == SearchType.Image:
|
||||||
self.new_config['content-type'][child.search_type.value]['input-directories'] = child.getPaths() if child.getPaths() != [] else None
|
self.new_config["content-type"][child.search_type.value]["input-directories"] = (
|
||||||
|
child.getPaths() if child.getPaths() != [] else None
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
self.new_config['content-type'][child.search_type.value]['input-files'] = child.getPaths() if child.getPaths() != [] else None
|
self.new_config["content-type"][child.search_type.value]["input-files"] = (
|
||||||
|
child.getPaths() if child.getPaths() != [] else None
|
||||||
|
)
|
||||||
|
|
||||||
def update_processor_settings(self):
|
def update_processor_settings(self):
|
||||||
"Update config with conversation settings from UI"
|
"Update config with conversation settings from UI"
|
||||||
|
@ -201,16 +211,20 @@ class MainWindow(QtWidgets.QMainWindow):
|
||||||
continue
|
continue
|
||||||
if isinstance(child, ProcessorCheckBox):
|
if isinstance(child, ProcessorCheckBox):
|
||||||
# Processor Type Disabled
|
# Processor Type Disabled
|
||||||
if not child.isChecked() and child.processor_type in self.new_config['processor']:
|
if not child.isChecked() and child.processor_type in self.new_config["processor"]:
|
||||||
del self.new_config['processor'][child.processor_type]
|
del self.new_config["processor"][child.processor_type]
|
||||||
# Processor Type (re)-Enabled
|
# Processor Type (re)-Enabled
|
||||||
if child.isChecked():
|
if child.isChecked():
|
||||||
current_processor_config = self.current_config['processor'].get(child.processor_type, {})
|
current_processor_config = self.current_config["processor"].get(child.processor_type, {})
|
||||||
default_processor_config = self.get_default_config(processor_type=child.processor_type)
|
default_processor_config = self.get_default_config(processor_type=child.processor_type)
|
||||||
self.new_config['processor'][child.processor_type.value] = merge_dicts(current_processor_config, default_processor_config)
|
self.new_config["processor"][child.processor_type.value] = merge_dicts(
|
||||||
elif isinstance(child, LabelledTextField) and child.processor_type in self.new_config['processor']:
|
current_processor_config, default_processor_config
|
||||||
|
)
|
||||||
|
elif isinstance(child, LabelledTextField) and child.processor_type in self.new_config["processor"]:
|
||||||
if child.processor_type == ProcessorType.Conversation:
|
if child.processor_type == ProcessorType.Conversation:
|
||||||
self.new_config['processor'][child.processor_type.value]['openai-api-key'] = child.input_field.toPlainText() if child.input_field.toPlainText() != '' else None
|
self.new_config["processor"][child.processor_type.value]["openai-api-key"] = (
|
||||||
|
child.input_field.toPlainText() if child.input_field.toPlainText() != "" else None
|
||||||
|
)
|
||||||
|
|
||||||
def save_settings_to_file(self) -> bool:
|
def save_settings_to_file(self) -> bool:
|
||||||
"Save validated settings to file"
|
"Save validated settings to file"
|
||||||
|
@ -312,6 +326,7 @@ class ProcessorCheckBox(QtWidgets.QCheckBox):
|
||||||
self.processor_type = processor_type
|
self.processor_type = processor_type
|
||||||
super(ProcessorCheckBox, self).__init__(text, parent=parent)
|
super(ProcessorCheckBox, self).__init__(text, parent=parent)
|
||||||
|
|
||||||
|
|
||||||
class ErrorType(Enum):
|
class ErrorType(Enum):
|
||||||
"Error Types"
|
"Error Types"
|
||||||
ConfigLoadingError = "Config Loading Error"
|
ConfigLoadingError = "Config Loading Error"
|
||||||
|
|
|
@ -17,17 +17,17 @@ def create_system_tray(gui: QtWidgets.QApplication, main_window: MainWindow):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Create the system tray with icon
|
# Create the system tray with icon
|
||||||
icon_path = constants.web_directory / 'assets/icons/favicon-144x144.png'
|
icon_path = constants.web_directory / "assets/icons/favicon-144x144.png"
|
||||||
icon = QtGui.QIcon(f'{icon_path.absolute()}')
|
icon = QtGui.QIcon(f"{icon_path.absolute()}")
|
||||||
tray = QtWidgets.QSystemTrayIcon(icon)
|
tray = QtWidgets.QSystemTrayIcon(icon)
|
||||||
tray.setVisible(True)
|
tray.setVisible(True)
|
||||||
|
|
||||||
# Create the menu and menu actions
|
# Create the menu and menu actions
|
||||||
menu = QtWidgets.QMenu()
|
menu = QtWidgets.QMenu()
|
||||||
menu_actions = [
|
menu_actions = [
|
||||||
('Search', lambda: webbrowser.open(f'http://{state.host}:{state.port}/')),
|
("Search", lambda: webbrowser.open(f"http://{state.host}:{state.port}/")),
|
||||||
('Configure', main_window.show_on_top),
|
("Configure", main_window.show_on_top),
|
||||||
('Quit', gui.quit),
|
("Quit", gui.quit),
|
||||||
]
|
]
|
||||||
|
|
||||||
# Add the menu actions to the menu
|
# Add the menu actions to the menu
|
||||||
|
|
|
@ -8,8 +8,8 @@ import warnings
|
||||||
from platform import system
|
from platform import system
|
||||||
|
|
||||||
# Ignore non-actionable warnings
|
# Ignore non-actionable warnings
|
||||||
warnings.filterwarnings("ignore", message=r'snapshot_download.py has been made private', category=FutureWarning)
|
warnings.filterwarnings("ignore", message=r"snapshot_download.py has been made private", category=FutureWarning)
|
||||||
warnings.filterwarnings("ignore", message=r'legacy way to download files from the HF hub,', category=FutureWarning)
|
warnings.filterwarnings("ignore", message=r"legacy way to download files from the HF hub,", category=FutureWarning)
|
||||||
|
|
||||||
# External Packages
|
# External Packages
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
@ -43,11 +43,12 @@ rich_handler = RichHandler(rich_tracebacks=True)
|
||||||
rich_handler.setFormatter(fmt=logging.Formatter(fmt="%(message)s", datefmt="[%X]"))
|
rich_handler.setFormatter(fmt=logging.Formatter(fmt="%(message)s", datefmt="[%X]"))
|
||||||
logging.basicConfig(handlers=[rich_handler])
|
logging.basicConfig(handlers=[rich_handler])
|
||||||
|
|
||||||
logger = logging.getLogger('khoj')
|
logger = logging.getLogger("khoj")
|
||||||
|
|
||||||
|
|
||||||
def run():
|
def run():
|
||||||
# Turn Tokenizers Parallelism Off. App does not support it.
|
# Turn Tokenizers Parallelism Off. App does not support it.
|
||||||
os.environ["TOKENIZERS_PARALLELISM"] = 'false'
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
|
||||||
# Load config from CLI
|
# Load config from CLI
|
||||||
state.cli_args = sys.argv[1:]
|
state.cli_args = sys.argv[1:]
|
||||||
|
@ -66,7 +67,7 @@ def run():
|
||||||
logger.setLevel(logging.DEBUG)
|
logger.setLevel(logging.DEBUG)
|
||||||
|
|
||||||
# Set Log File
|
# Set Log File
|
||||||
fh = logging.FileHandler(state.config_file.parent / 'khoj.log')
|
fh = logging.FileHandler(state.config_file.parent / "khoj.log")
|
||||||
fh.setLevel(logging.DEBUG)
|
fh.setLevel(logging.DEBUG)
|
||||||
logger.addHandler(fh)
|
logger.addHandler(fh)
|
||||||
|
|
||||||
|
@ -87,7 +88,7 @@ def run():
|
||||||
# On Linux (Gnome) the System tray is not supported.
|
# On Linux (Gnome) the System tray is not supported.
|
||||||
# Since only the Main Window is available
|
# Since only the Main Window is available
|
||||||
# Quitting it should quit the application
|
# Quitting it should quit the application
|
||||||
if system() in ['Windows', 'Darwin']:
|
if system() in ["Windows", "Darwin"]:
|
||||||
gui.setQuitOnLastWindowClosed(False)
|
gui.setQuitOnLastWindowClosed(False)
|
||||||
tray = create_system_tray(gui, main_window)
|
tray = create_system_tray(gui, main_window)
|
||||||
tray.show()
|
tray.show()
|
||||||
|
@ -97,7 +98,7 @@ def run():
|
||||||
server = ServerThread(app, args.host, args.port, args.socket)
|
server = ServerThread(app, args.host, args.port, args.socket)
|
||||||
|
|
||||||
# Show Main Window on First Run Experience or if on Linux
|
# Show Main Window on First Run Experience or if on Linux
|
||||||
if args.config is None or system() not in ['Windows', 'Darwin']:
|
if args.config is None or system() not in ["Windows", "Darwin"]:
|
||||||
main_window.show()
|
main_window.show()
|
||||||
|
|
||||||
# Setup Signal Handlers
|
# Setup Signal Handlers
|
||||||
|
@ -112,9 +113,10 @@ def run():
|
||||||
gui.aboutToQuit.connect(server.terminate)
|
gui.aboutToQuit.connect(server.terminate)
|
||||||
|
|
||||||
# Close Splash Screen if still open
|
# Close Splash Screen if still open
|
||||||
if system() != 'Darwin':
|
if system() != "Darwin":
|
||||||
try:
|
try:
|
||||||
import pyi_splash
|
import pyi_splash
|
||||||
|
|
||||||
# Update the text on the splash screen
|
# Update the text on the splash screen
|
||||||
pyi_splash.update_text("Khoj setup complete")
|
pyi_splash.update_text("Khoj setup complete")
|
||||||
# Close Splash Screen
|
# Close Splash Screen
|
||||||
|
@ -167,5 +169,5 @@ class ServerThread(QThread):
|
||||||
start_server(self.app, self.host, self.port, self.socket)
|
start_server(self.app, self.host, self.port, self.socket)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
run()
|
run()
|
||||||
|
|
|
@ -19,31 +19,27 @@ def summarize(text, summary_type, model, user_query=None, api_key=None, temperat
|
||||||
|
|
||||||
# Setup Prompt based on Summary Type
|
# Setup Prompt based on Summary Type
|
||||||
if summary_type == "chat":
|
if summary_type == "chat":
|
||||||
prompt = f'''
|
prompt = f"""
|
||||||
You are an AI. Summarize the conversation below from your perspective:
|
You are an AI. Summarize the conversation below from your perspective:
|
||||||
|
|
||||||
{text}
|
{text}
|
||||||
|
|
||||||
Summarize the conversation from the AI's first-person perspective:'''
|
Summarize the conversation from the AI's first-person perspective:"""
|
||||||
elif summary_type == "notes":
|
elif summary_type == "notes":
|
||||||
prompt = f'''
|
prompt = f"""
|
||||||
Summarize the below notes about {user_query}:
|
Summarize the below notes about {user_query}:
|
||||||
|
|
||||||
{text}
|
{text}
|
||||||
|
|
||||||
Summarize the notes in second person perspective:'''
|
Summarize the notes in second person perspective:"""
|
||||||
|
|
||||||
# Get Response from GPT
|
# Get Response from GPT
|
||||||
response = openai.Completion.create(
|
response = openai.Completion.create(
|
||||||
prompt=prompt,
|
prompt=prompt, model=model, temperature=temperature, max_tokens=max_tokens, frequency_penalty=0.2, stop='"""'
|
||||||
model=model,
|
)
|
||||||
temperature=temperature,
|
|
||||||
max_tokens=max_tokens,
|
|
||||||
frequency_penalty=0.2,
|
|
||||||
stop="\"\"\"")
|
|
||||||
|
|
||||||
# Extract, Clean Message from GPT's Response
|
# Extract, Clean Message from GPT's Response
|
||||||
story = response['choices'][0]['text']
|
story = response["choices"][0]["text"]
|
||||||
return str(story).replace("\n\n", "")
|
return str(story).replace("\n\n", "")
|
||||||
|
|
||||||
|
|
||||||
|
@ -53,7 +49,7 @@ def extract_search_type(text, model, api_key=None, temperature=0.5, max_tokens=1
|
||||||
"""
|
"""
|
||||||
# Initialize Variables
|
# Initialize Variables
|
||||||
openai.api_key = api_key or os.getenv("OPENAI_API_KEY")
|
openai.api_key = api_key or os.getenv("OPENAI_API_KEY")
|
||||||
understand_primer = '''
|
understand_primer = """
|
||||||
Objective: Extract search type from user query and return information as JSON
|
Objective: Extract search type from user query and return information as JSON
|
||||||
|
|
||||||
Allowed search types are listed below:
|
Allowed search types are listed below:
|
||||||
|
@ -73,7 +69,7 @@ A:{ "search-type": "notes" }
|
||||||
Q: When did I buy Groceries last?
|
Q: When did I buy Groceries last?
|
||||||
A:{ "search-type": "ledger" }
|
A:{ "search-type": "ledger" }
|
||||||
Q:When did I go surfing last?
|
Q:When did I go surfing last?
|
||||||
A:{ "search-type": "notes" }'''
|
A:{ "search-type": "notes" }"""
|
||||||
|
|
||||||
# Setup Prompt with Understand Primer
|
# Setup Prompt with Understand Primer
|
||||||
prompt = message_to_prompt(text, understand_primer, start_sequence="\nA:", restart_sequence="\nQ:")
|
prompt = message_to_prompt(text, understand_primer, start_sequence="\nA:", restart_sequence="\nQ:")
|
||||||
|
@ -82,15 +78,11 @@ A:{ "search-type": "notes" }'''
|
||||||
|
|
||||||
# Get Response from GPT
|
# Get Response from GPT
|
||||||
response = openai.Completion.create(
|
response = openai.Completion.create(
|
||||||
prompt=prompt,
|
prompt=prompt, model=model, temperature=temperature, max_tokens=max_tokens, frequency_penalty=0.2, stop=["\n"]
|
||||||
model=model,
|
)
|
||||||
temperature=temperature,
|
|
||||||
max_tokens=max_tokens,
|
|
||||||
frequency_penalty=0.2,
|
|
||||||
stop=["\n"])
|
|
||||||
|
|
||||||
# Extract, Clean Message from GPT's Response
|
# Extract, Clean Message from GPT's Response
|
||||||
story = str(response['choices'][0]['text'])
|
story = str(response["choices"][0]["text"])
|
||||||
return json.loads(story.strip(empty_escape_sequences))
|
return json.loads(story.strip(empty_escape_sequences))
|
||||||
|
|
||||||
|
|
||||||
|
@ -100,7 +92,7 @@ def understand(text, model, api_key=None, temperature=0.5, max_tokens=100, verbo
|
||||||
"""
|
"""
|
||||||
# Initialize Variables
|
# Initialize Variables
|
||||||
openai.api_key = api_key or os.getenv("OPENAI_API_KEY")
|
openai.api_key = api_key or os.getenv("OPENAI_API_KEY")
|
||||||
understand_primer = '''
|
understand_primer = """
|
||||||
Objective: Extract intent and trigger emotion information as JSON from each chat message
|
Objective: Extract intent and trigger emotion information as JSON from each chat message
|
||||||
|
|
||||||
Potential intent types and valid argument values are listed below:
|
Potential intent types and valid argument values are listed below:
|
||||||
|
@ -142,7 +134,7 @@ A: { "intent": {"type": "remember", "memory-type": "notes", "query": "recommend
|
||||||
Q: When did I go surfing last?
|
Q: When did I go surfing last?
|
||||||
A: { "intent": {"type": "remember", "memory-type": "notes", "query": "When did I go surfing last"}, "trigger-emotion": "calm" }
|
A: { "intent": {"type": "remember", "memory-type": "notes", "query": "When did I go surfing last"}, "trigger-emotion": "calm" }
|
||||||
Q: Can you dance for me?
|
Q: Can you dance for me?
|
||||||
A: { "intent": {"type": "generate", "activity": "chat", "query": "Can you dance for me?"}, "trigger-emotion": "sad" }'''
|
A: { "intent": {"type": "generate", "activity": "chat", "query": "Can you dance for me?"}, "trigger-emotion": "sad" }"""
|
||||||
|
|
||||||
# Setup Prompt with Understand Primer
|
# Setup Prompt with Understand Primer
|
||||||
prompt = message_to_prompt(text, understand_primer, start_sequence="\nA:", restart_sequence="\nQ:")
|
prompt = message_to_prompt(text, understand_primer, start_sequence="\nA:", restart_sequence="\nQ:")
|
||||||
|
@ -151,15 +143,11 @@ A: { "intent": {"type": "generate", "activity": "chat", "query": "Can you dance
|
||||||
|
|
||||||
# Get Response from GPT
|
# Get Response from GPT
|
||||||
response = openai.Completion.create(
|
response = openai.Completion.create(
|
||||||
prompt=prompt,
|
prompt=prompt, model=model, temperature=temperature, max_tokens=max_tokens, frequency_penalty=0.2, stop=["\n"]
|
||||||
model=model,
|
)
|
||||||
temperature=temperature,
|
|
||||||
max_tokens=max_tokens,
|
|
||||||
frequency_penalty=0.2,
|
|
||||||
stop=["\n"])
|
|
||||||
|
|
||||||
# Extract, Clean Message from GPT's Response
|
# Extract, Clean Message from GPT's Response
|
||||||
story = str(response['choices'][0]['text'])
|
story = str(response["choices"][0]["text"])
|
||||||
return json.loads(story.strip(empty_escape_sequences))
|
return json.loads(story.strip(empty_escape_sequences))
|
||||||
|
|
||||||
|
|
||||||
|
@ -171,15 +159,15 @@ def converse(text, model, conversation_history=None, api_key=None, temperature=0
|
||||||
max_words = 500
|
max_words = 500
|
||||||
openai.api_key = api_key or os.getenv("OPENAI_API_KEY")
|
openai.api_key = api_key or os.getenv("OPENAI_API_KEY")
|
||||||
|
|
||||||
conversation_primer = f'''
|
conversation_primer = f"""
|
||||||
The following is a conversation with an AI assistant. The assistant is helpful, creative, clever, and a very friendly companion.
|
The following is a conversation with an AI assistant. The assistant is helpful, creative, clever, and a very friendly companion.
|
||||||
|
|
||||||
Human: Hello, who are you?
|
Human: Hello, who are you?
|
||||||
AI: Hi, I am an AI conversational companion created by OpenAI. How can I help you today?'''
|
AI: Hi, I am an AI conversational companion created by OpenAI. How can I help you today?"""
|
||||||
|
|
||||||
# Setup Prompt with Primer or Conversation History
|
# Setup Prompt with Primer or Conversation History
|
||||||
prompt = message_to_prompt(text, conversation_history or conversation_primer)
|
prompt = message_to_prompt(text, conversation_history or conversation_primer)
|
||||||
prompt = ' '.join(prompt.split()[:max_words])
|
prompt = " ".join(prompt.split()[:max_words])
|
||||||
|
|
||||||
# Get Response from GPT
|
# Get Response from GPT
|
||||||
response = openai.Completion.create(
|
response = openai.Completion.create(
|
||||||
|
@ -188,14 +176,17 @@ AI: Hi, I am an AI conversational companion created by OpenAI. How can I help yo
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
presence_penalty=0.6,
|
presence_penalty=0.6,
|
||||||
stop=["\n", "Human:", "AI:"])
|
stop=["\n", "Human:", "AI:"],
|
||||||
|
)
|
||||||
|
|
||||||
# Extract, Clean Message from GPT's Response
|
# Extract, Clean Message from GPT's Response
|
||||||
story = str(response['choices'][0]['text'])
|
story = str(response["choices"][0]["text"])
|
||||||
return story.strip(empty_escape_sequences)
|
return story.strip(empty_escape_sequences)
|
||||||
|
|
||||||
|
|
||||||
def message_to_prompt(user_message, conversation_history="", gpt_message=None, start_sequence="\nAI:", restart_sequence="\nHuman:"):
|
def message_to_prompt(
|
||||||
|
user_message, conversation_history="", gpt_message=None, start_sequence="\nAI:", restart_sequence="\nHuman:"
|
||||||
|
):
|
||||||
"""Create prompt for GPT from messages and conversation history"""
|
"""Create prompt for GPT from messages and conversation history"""
|
||||||
gpt_message = f" {gpt_message}" if gpt_message else ""
|
gpt_message = f" {gpt_message}" if gpt_message else ""
|
||||||
|
|
||||||
|
@ -205,12 +196,8 @@ def message_to_prompt(user_message, conversation_history="", gpt_message=None, s
|
||||||
def message_to_log(user_message, gpt_message, user_message_metadata={}, conversation_log=[]):
|
def message_to_log(user_message, gpt_message, user_message_metadata={}, conversation_log=[]):
|
||||||
"""Create json logs from messages, metadata for conversation log"""
|
"""Create json logs from messages, metadata for conversation log"""
|
||||||
default_user_message_metadata = {
|
default_user_message_metadata = {
|
||||||
"intent": {
|
"intent": {"type": "remember", "memory-type": "notes", "query": user_message},
|
||||||
"type": "remember",
|
"trigger-emotion": "calm",
|
||||||
"memory-type": "notes",
|
|
||||||
"query": user_message
|
|
||||||
},
|
|
||||||
"trigger-emotion": "calm"
|
|
||||||
}
|
}
|
||||||
current_dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
current_dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
|
||||||
|
@ -229,5 +216,4 @@ def message_to_log(user_message, gpt_message, user_message_metadata={}, conversa
|
||||||
|
|
||||||
def extract_summaries(metadata):
|
def extract_summaries(metadata):
|
||||||
"""Extract summaries from metadata"""
|
"""Extract summaries from metadata"""
|
||||||
return ''.join(
|
return "".join([f'\n{session["summary"]}' for session in metadata])
|
||||||
[f'\n{session["summary"]}' for session in metadata])
|
|
||||||
|
|
|
@ -19,7 +19,11 @@ class BeancountToJsonl(TextToJsonl):
|
||||||
# Define Functions
|
# Define Functions
|
||||||
def process(self, previous_entries=None):
|
def process(self, previous_entries=None):
|
||||||
# Extract required fields from config
|
# Extract required fields from config
|
||||||
beancount_files, beancount_file_filter, output_file = self.config.input_files, self.config.input_filter,self.config.compressed_jsonl
|
beancount_files, beancount_file_filter, output_file = (
|
||||||
|
self.config.input_files,
|
||||||
|
self.config.input_filter,
|
||||||
|
self.config.compressed_jsonl,
|
||||||
|
)
|
||||||
|
|
||||||
# Input Validation
|
# Input Validation
|
||||||
if is_none_or_empty(beancount_files) and is_none_or_empty(beancount_file_filter):
|
if is_none_or_empty(beancount_files) and is_none_or_empty(beancount_file_filter):
|
||||||
|
@ -31,7 +35,9 @@ class BeancountToJsonl(TextToJsonl):
|
||||||
|
|
||||||
# Extract Entries from specified Beancount files
|
# Extract Entries from specified Beancount files
|
||||||
with timer("Parse transactions from Beancount files into dictionaries", logger):
|
with timer("Parse transactions from Beancount files into dictionaries", logger):
|
||||||
current_entries = BeancountToJsonl.convert_transactions_to_maps(*BeancountToJsonl.extract_beancount_transactions(beancount_files))
|
current_entries = BeancountToJsonl.convert_transactions_to_maps(
|
||||||
|
*BeancountToJsonl.extract_beancount_transactions(beancount_files)
|
||||||
|
)
|
||||||
|
|
||||||
# Split entries by max tokens supported by model
|
# Split entries by max tokens supported by model
|
||||||
with timer("Split entries by max token size supported by model", logger):
|
with timer("Split entries by max token size supported by model", logger):
|
||||||
|
@ -42,7 +48,9 @@ class BeancountToJsonl(TextToJsonl):
|
||||||
if not previous_entries:
|
if not previous_entries:
|
||||||
entries_with_ids = list(enumerate(current_entries))
|
entries_with_ids = list(enumerate(current_entries))
|
||||||
else:
|
else:
|
||||||
entries_with_ids = self.mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger)
|
entries_with_ids = self.mark_entries_for_update(
|
||||||
|
current_entries, previous_entries, key="compiled", logger=logger
|
||||||
|
)
|
||||||
|
|
||||||
with timer("Write transactions to JSONL file", logger):
|
with timer("Write transactions to JSONL file", logger):
|
||||||
# Process Each Entry from All Notes Files
|
# Process Each Entry from All Notes Files
|
||||||
|
@ -62,9 +70,7 @@ class BeancountToJsonl(TextToJsonl):
|
||||||
"Get Beancount files to process"
|
"Get Beancount files to process"
|
||||||
absolute_beancount_files, filtered_beancount_files = set(), set()
|
absolute_beancount_files, filtered_beancount_files = set(), set()
|
||||||
if beancount_files:
|
if beancount_files:
|
||||||
absolute_beancount_files = {get_absolute_path(beancount_file)
|
absolute_beancount_files = {get_absolute_path(beancount_file) for beancount_file in beancount_files}
|
||||||
for beancount_file
|
|
||||||
in beancount_files}
|
|
||||||
if beancount_file_filters:
|
if beancount_file_filters:
|
||||||
filtered_beancount_files = {
|
filtered_beancount_files = {
|
||||||
filtered_file
|
filtered_file
|
||||||
|
@ -76,14 +82,13 @@ class BeancountToJsonl(TextToJsonl):
|
||||||
|
|
||||||
files_with_non_beancount_extensions = {
|
files_with_non_beancount_extensions = {
|
||||||
beancount_file
|
beancount_file
|
||||||
for beancount_file
|
for beancount_file in all_beancount_files
|
||||||
in all_beancount_files
|
|
||||||
if not beancount_file.endswith(".bean") and not beancount_file.endswith(".beancount")
|
if not beancount_file.endswith(".bean") and not beancount_file.endswith(".beancount")
|
||||||
}
|
}
|
||||||
if any(files_with_non_beancount_extensions):
|
if any(files_with_non_beancount_extensions):
|
||||||
print(f"[Warning] There maybe non beancount files in the input set: {files_with_non_beancount_extensions}")
|
print(f"[Warning] There maybe non beancount files in the input set: {files_with_non_beancount_extensions}")
|
||||||
|
|
||||||
logger.info(f'Processing files: {all_beancount_files}')
|
logger.info(f"Processing files: {all_beancount_files}")
|
||||||
|
|
||||||
return all_beancount_files
|
return all_beancount_files
|
||||||
|
|
||||||
|
@ -92,18 +97,19 @@ class BeancountToJsonl(TextToJsonl):
|
||||||
"Extract entries from specified Beancount files"
|
"Extract entries from specified Beancount files"
|
||||||
|
|
||||||
# Initialize Regex for extracting Beancount Entries
|
# Initialize Regex for extracting Beancount Entries
|
||||||
transaction_regex = r'^\n?\d{4}-\d{2}-\d{2} [\*|\!] '
|
transaction_regex = r"^\n?\d{4}-\d{2}-\d{2} [\*|\!] "
|
||||||
empty_newline = f'^[\n\r\t\ ]*$'
|
empty_newline = f"^[\n\r\t\ ]*$"
|
||||||
|
|
||||||
entries = []
|
entries = []
|
||||||
transaction_to_file_map = []
|
transaction_to_file_map = []
|
||||||
for beancount_file in beancount_files:
|
for beancount_file in beancount_files:
|
||||||
with open(beancount_file) as f:
|
with open(beancount_file) as f:
|
||||||
ledger_content = f.read()
|
ledger_content = f.read()
|
||||||
transactions_per_file = [entry.strip(empty_escape_sequences)
|
transactions_per_file = [
|
||||||
for entry
|
entry.strip(empty_escape_sequences)
|
||||||
in re.split(empty_newline, ledger_content, flags=re.MULTILINE)
|
for entry in re.split(empty_newline, ledger_content, flags=re.MULTILINE)
|
||||||
if re.match(transaction_regex, entry)]
|
if re.match(transaction_regex, entry)
|
||||||
|
]
|
||||||
transaction_to_file_map += zip(transactions_per_file, [beancount_file] * len(transactions_per_file))
|
transaction_to_file_map += zip(transactions_per_file, [beancount_file] * len(transactions_per_file))
|
||||||
entries.extend(transactions_per_file)
|
entries.extend(transactions_per_file)
|
||||||
return entries, dict(transaction_to_file_map)
|
return entries, dict(transaction_to_file_map)
|
||||||
|
@ -113,7 +119,9 @@ class BeancountToJsonl(TextToJsonl):
|
||||||
"Convert each parsed Beancount transaction into a Entry"
|
"Convert each parsed Beancount transaction into a Entry"
|
||||||
entries = []
|
entries = []
|
||||||
for parsed_entry in parsed_entries:
|
for parsed_entry in parsed_entries:
|
||||||
entries.append(Entry(compiled=parsed_entry, raw=parsed_entry, file=f'{transaction_to_file_map[parsed_entry]}'))
|
entries.append(
|
||||||
|
Entry(compiled=parsed_entry, raw=parsed_entry, file=f"{transaction_to_file_map[parsed_entry]}")
|
||||||
|
)
|
||||||
|
|
||||||
logger.info(f"Converted {len(parsed_entries)} transactions to dictionaries")
|
logger.info(f"Converted {len(parsed_entries)} transactions to dictionaries")
|
||||||
|
|
||||||
|
@ -122,4 +130,4 @@ class BeancountToJsonl(TextToJsonl):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def convert_transaction_maps_to_jsonl(entries: List[Entry]) -> str:
|
def convert_transaction_maps_to_jsonl(entries: List[Entry]) -> str:
|
||||||
"Convert each Beancount transaction entry to JSON and collate as JSONL"
|
"Convert each Beancount transaction entry to JSON and collate as JSONL"
|
||||||
return ''.join([f'{entry.to_json()}\n' for entry in entries])
|
return "".join([f"{entry.to_json()}\n" for entry in entries])
|
||||||
|
|
|
@ -20,7 +20,11 @@ class MarkdownToJsonl(TextToJsonl):
|
||||||
# Define Functions
|
# Define Functions
|
||||||
def process(self, previous_entries=None):
|
def process(self, previous_entries=None):
|
||||||
# Extract required fields from config
|
# Extract required fields from config
|
||||||
markdown_files, markdown_file_filter, output_file = self.config.input_files, self.config.input_filter, self.config.compressed_jsonl
|
markdown_files, markdown_file_filter, output_file = (
|
||||||
|
self.config.input_files,
|
||||||
|
self.config.input_filter,
|
||||||
|
self.config.compressed_jsonl,
|
||||||
|
)
|
||||||
|
|
||||||
# Input Validation
|
# Input Validation
|
||||||
if is_none_or_empty(markdown_files) and is_none_or_empty(markdown_file_filter):
|
if is_none_or_empty(markdown_files) and is_none_or_empty(markdown_file_filter):
|
||||||
|
@ -32,7 +36,9 @@ class MarkdownToJsonl(TextToJsonl):
|
||||||
|
|
||||||
# Extract Entries from specified Markdown files
|
# Extract Entries from specified Markdown files
|
||||||
with timer("Parse entries from Markdown files into dictionaries", logger):
|
with timer("Parse entries from Markdown files into dictionaries", logger):
|
||||||
current_entries = MarkdownToJsonl.convert_markdown_entries_to_maps(*MarkdownToJsonl.extract_markdown_entries(markdown_files))
|
current_entries = MarkdownToJsonl.convert_markdown_entries_to_maps(
|
||||||
|
*MarkdownToJsonl.extract_markdown_entries(markdown_files)
|
||||||
|
)
|
||||||
|
|
||||||
# Split entries by max tokens supported by model
|
# Split entries by max tokens supported by model
|
||||||
with timer("Split entries by max token size supported by model", logger):
|
with timer("Split entries by max token size supported by model", logger):
|
||||||
|
@ -43,7 +49,9 @@ class MarkdownToJsonl(TextToJsonl):
|
||||||
if not previous_entries:
|
if not previous_entries:
|
||||||
entries_with_ids = list(enumerate(current_entries))
|
entries_with_ids = list(enumerate(current_entries))
|
||||||
else:
|
else:
|
||||||
entries_with_ids = self.mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger)
|
entries_with_ids = self.mark_entries_for_update(
|
||||||
|
current_entries, previous_entries, key="compiled", logger=logger
|
||||||
|
)
|
||||||
|
|
||||||
with timer("Write markdown entries to JSONL file", logger):
|
with timer("Write markdown entries to JSONL file", logger):
|
||||||
# Process Each Entry from All Notes Files
|
# Process Each Entry from All Notes Files
|
||||||
|
@ -75,15 +83,16 @@ class MarkdownToJsonl(TextToJsonl):
|
||||||
|
|
||||||
files_with_non_markdown_extensions = {
|
files_with_non_markdown_extensions = {
|
||||||
md_file
|
md_file
|
||||||
for md_file
|
for md_file in all_markdown_files
|
||||||
in all_markdown_files
|
if not md_file.endswith(".md") and not md_file.endswith(".markdown")
|
||||||
if not md_file.endswith(".md") and not md_file.endswith('.markdown')
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if any(files_with_non_markdown_extensions):
|
if any(files_with_non_markdown_extensions):
|
||||||
logger.warn(f"[Warning] There maybe non markdown-mode files in the input set: {files_with_non_markdown_extensions}")
|
logger.warn(
|
||||||
|
f"[Warning] There maybe non markdown-mode files in the input set: {files_with_non_markdown_extensions}"
|
||||||
|
)
|
||||||
|
|
||||||
logger.info(f'Processing files: {all_markdown_files}')
|
logger.info(f"Processing files: {all_markdown_files}")
|
||||||
|
|
||||||
return all_markdown_files
|
return all_markdown_files
|
||||||
|
|
||||||
|
@ -92,18 +101,18 @@ class MarkdownToJsonl(TextToJsonl):
|
||||||
"Extract entries by heading from specified Markdown files"
|
"Extract entries by heading from specified Markdown files"
|
||||||
|
|
||||||
# Regex to extract Markdown Entries by Heading
|
# Regex to extract Markdown Entries by Heading
|
||||||
markdown_heading_regex = r'^#'
|
markdown_heading_regex = r"^#"
|
||||||
|
|
||||||
entries = []
|
entries = []
|
||||||
entry_to_file_map = []
|
entry_to_file_map = []
|
||||||
for markdown_file in markdown_files:
|
for markdown_file in markdown_files:
|
||||||
with open(markdown_file, 'r', encoding='utf8') as f:
|
with open(markdown_file, "r", encoding="utf8") as f:
|
||||||
markdown_content = f.read()
|
markdown_content = f.read()
|
||||||
markdown_entries_per_file = []
|
markdown_entries_per_file = []
|
||||||
for entry in re.split(markdown_heading_regex, markdown_content, flags=re.MULTILINE):
|
for entry in re.split(markdown_heading_regex, markdown_content, flags=re.MULTILINE):
|
||||||
prefix = '#' if entry.startswith('#') else '# '
|
prefix = "#" if entry.startswith("#") else "# "
|
||||||
if entry.strip(empty_escape_sequences) != '':
|
if entry.strip(empty_escape_sequences) != "":
|
||||||
markdown_entries_per_file.append(f'{prefix}{entry.strip(empty_escape_sequences)}')
|
markdown_entries_per_file.append(f"{prefix}{entry.strip(empty_escape_sequences)}")
|
||||||
|
|
||||||
entry_to_file_map += zip(markdown_entries_per_file, [markdown_file] * len(markdown_entries_per_file))
|
entry_to_file_map += zip(markdown_entries_per_file, [markdown_file] * len(markdown_entries_per_file))
|
||||||
entries.extend(markdown_entries_per_file)
|
entries.extend(markdown_entries_per_file)
|
||||||
|
@ -115,7 +124,7 @@ class MarkdownToJsonl(TextToJsonl):
|
||||||
"Convert each Markdown entries into a dictionary"
|
"Convert each Markdown entries into a dictionary"
|
||||||
entries = []
|
entries = []
|
||||||
for parsed_entry in parsed_entries:
|
for parsed_entry in parsed_entries:
|
||||||
entries.append(Entry(compiled=parsed_entry, raw=parsed_entry, file=f'{entry_to_file_map[parsed_entry]}'))
|
entries.append(Entry(compiled=parsed_entry, raw=parsed_entry, file=f"{entry_to_file_map[parsed_entry]}"))
|
||||||
|
|
||||||
logger.info(f"Converted {len(parsed_entries)} markdown entries to dictionaries")
|
logger.info(f"Converted {len(parsed_entries)} markdown entries to dictionaries")
|
||||||
|
|
||||||
|
@ -124,4 +133,4 @@ class MarkdownToJsonl(TextToJsonl):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def convert_markdown_maps_to_jsonl(entries: List[Entry]):
|
def convert_markdown_maps_to_jsonl(entries: List[Entry]):
|
||||||
"Convert each Markdown entry to JSON and collate as JSONL"
|
"Convert each Markdown entry to JSON and collate as JSONL"
|
||||||
return ''.join([f'{entry.to_json()}\n' for entry in entries])
|
return "".join([f"{entry.to_json()}\n" for entry in entries])
|
||||||
|
|
|
@ -20,7 +20,11 @@ class OrgToJsonl(TextToJsonl):
|
||||||
# Define Functions
|
# Define Functions
|
||||||
def process(self, previous_entries: List[Entry] = None):
|
def process(self, previous_entries: List[Entry] = None):
|
||||||
# Extract required fields from config
|
# Extract required fields from config
|
||||||
org_files, org_file_filter, output_file = self.config.input_files, self.config.input_filter, self.config.compressed_jsonl
|
org_files, org_file_filter, output_file = (
|
||||||
|
self.config.input_files,
|
||||||
|
self.config.input_filter,
|
||||||
|
self.config.compressed_jsonl,
|
||||||
|
)
|
||||||
index_heading_entries = self.config.index_heading_entries
|
index_heading_entries = self.config.index_heading_entries
|
||||||
|
|
||||||
# Input Validation
|
# Input Validation
|
||||||
|
@ -46,7 +50,9 @@ class OrgToJsonl(TextToJsonl):
|
||||||
if not previous_entries:
|
if not previous_entries:
|
||||||
entries_with_ids = list(enumerate(current_entries))
|
entries_with_ids = list(enumerate(current_entries))
|
||||||
else:
|
else:
|
||||||
entries_with_ids = self.mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger)
|
entries_with_ids = self.mark_entries_for_update(
|
||||||
|
current_entries, previous_entries, key="compiled", logger=logger
|
||||||
|
)
|
||||||
|
|
||||||
# Process Each Entry from All Notes Files
|
# Process Each Entry from All Notes Files
|
||||||
with timer("Write org entries to JSONL file", logger):
|
with timer("Write org entries to JSONL file", logger):
|
||||||
|
@ -66,11 +72,7 @@ class OrgToJsonl(TextToJsonl):
|
||||||
"Get Org files to process"
|
"Get Org files to process"
|
||||||
absolute_org_files, filtered_org_files = set(), set()
|
absolute_org_files, filtered_org_files = set(), set()
|
||||||
if org_files:
|
if org_files:
|
||||||
absolute_org_files = {
|
absolute_org_files = {get_absolute_path(org_file) for org_file in org_files}
|
||||||
get_absolute_path(org_file)
|
|
||||||
for org_file
|
|
||||||
in org_files
|
|
||||||
}
|
|
||||||
if org_file_filters:
|
if org_file_filters:
|
||||||
filtered_org_files = {
|
filtered_org_files = {
|
||||||
filtered_file
|
filtered_file
|
||||||
|
@ -84,7 +86,7 @@ class OrgToJsonl(TextToJsonl):
|
||||||
if any(files_with_non_org_extensions):
|
if any(files_with_non_org_extensions):
|
||||||
logger.warn(f"There maybe non org-mode files in the input set: {files_with_non_org_extensions}")
|
logger.warn(f"There maybe non org-mode files in the input set: {files_with_non_org_extensions}")
|
||||||
|
|
||||||
logger.info(f'Processing files: {all_org_files}')
|
logger.info(f"Processing files: {all_org_files}")
|
||||||
|
|
||||||
return all_org_files
|
return all_org_files
|
||||||
|
|
||||||
|
@ -101,7 +103,9 @@ class OrgToJsonl(TextToJsonl):
|
||||||
return entries, dict(entry_to_file_map)
|
return entries, dict(entry_to_file_map)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def convert_org_nodes_to_entries(parsed_entries: List[orgnode.Orgnode], entry_to_file_map, index_heading_entries=False) -> List[Entry]:
|
def convert_org_nodes_to_entries(
|
||||||
|
parsed_entries: List[orgnode.Orgnode], entry_to_file_map, index_heading_entries=False
|
||||||
|
) -> List[Entry]:
|
||||||
"Convert Org-Mode nodes into list of Entry objects"
|
"Convert Org-Mode nodes into list of Entry objects"
|
||||||
entries: List[Entry] = []
|
entries: List[Entry] = []
|
||||||
for parsed_entry in parsed_entries:
|
for parsed_entry in parsed_entries:
|
||||||
|
@ -109,13 +113,13 @@ class OrgToJsonl(TextToJsonl):
|
||||||
# Ignore title notes i.e notes with just headings and empty body
|
# Ignore title notes i.e notes with just headings and empty body
|
||||||
continue
|
continue
|
||||||
|
|
||||||
compiled = f'{parsed_entry.heading}.'
|
compiled = f"{parsed_entry.heading}."
|
||||||
if state.verbose > 2:
|
if state.verbose > 2:
|
||||||
logger.debug(f"Title: {parsed_entry.heading}")
|
logger.debug(f"Title: {parsed_entry.heading}")
|
||||||
|
|
||||||
if parsed_entry.tags:
|
if parsed_entry.tags:
|
||||||
tags_str = " ".join(parsed_entry.tags)
|
tags_str = " ".join(parsed_entry.tags)
|
||||||
compiled += f'\t {tags_str}.'
|
compiled += f"\t {tags_str}."
|
||||||
if state.verbose > 2:
|
if state.verbose > 2:
|
||||||
logger.debug(f"Tags: {tags_str}")
|
logger.debug(f"Tags: {tags_str}")
|
||||||
|
|
||||||
|
@ -130,19 +134,16 @@ class OrgToJsonl(TextToJsonl):
|
||||||
logger.debug(f'Scheduled: {parsed_entry.scheduled.strftime("%Y-%m-%d")}')
|
logger.debug(f'Scheduled: {parsed_entry.scheduled.strftime("%Y-%m-%d")}')
|
||||||
|
|
||||||
if parsed_entry.hasBody:
|
if parsed_entry.hasBody:
|
||||||
compiled += f'\n {parsed_entry.body}'
|
compiled += f"\n {parsed_entry.body}"
|
||||||
if state.verbose > 2:
|
if state.verbose > 2:
|
||||||
logger.debug(f"Body: {parsed_entry.body}")
|
logger.debug(f"Body: {parsed_entry.body}")
|
||||||
|
|
||||||
if compiled:
|
if compiled:
|
||||||
entries += [Entry(
|
entries += [Entry(compiled=compiled, raw=f"{parsed_entry}", file=f"{entry_to_file_map[parsed_entry]}")]
|
||||||
compiled=compiled,
|
|
||||||
raw=f'{parsed_entry}',
|
|
||||||
file=f'{entry_to_file_map[parsed_entry]}')]
|
|
||||||
|
|
||||||
return entries
|
return entries
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def convert_org_entries_to_jsonl(entries: Iterable[Entry]) -> str:
|
def convert_org_entries_to_jsonl(entries: Iterable[Entry]) -> str:
|
||||||
"Convert each Org-Mode entry to JSON and collate as JSONL"
|
"Convert each Org-Mode entry to JSON and collate as JSONL"
|
||||||
return ''.join([f'{entry_dict.to_json()}\n' for entry_dict in entries])
|
return "".join([f"{entry_dict.to_json()}\n" for entry_dict in entries])
|
||||||
|
|
|
@ -39,18 +39,20 @@ from pathlib import Path
|
||||||
from os.path import relpath
|
from os.path import relpath
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
indent_regex = re.compile(r'^ *')
|
indent_regex = re.compile(r"^ *")
|
||||||
|
|
||||||
|
|
||||||
def normalize_filename(filename):
|
def normalize_filename(filename):
|
||||||
"Normalize and escape filename for rendering"
|
"Normalize and escape filename for rendering"
|
||||||
if not Path(filename).is_absolute():
|
if not Path(filename).is_absolute():
|
||||||
# Normalize relative filename to be relative to current directory
|
# Normalize relative filename to be relative to current directory
|
||||||
normalized_filename = f'~/{relpath(filename, start=Path.home())}'
|
normalized_filename = f"~/{relpath(filename, start=Path.home())}"
|
||||||
else:
|
else:
|
||||||
normalized_filename = filename
|
normalized_filename = filename
|
||||||
escaped_filename = f'{normalized_filename}'.replace("[","\[").replace("]","\]")
|
escaped_filename = f"{normalized_filename}".replace("[", "\[").replace("]", "\]")
|
||||||
return escaped_filename
|
return escaped_filename
|
||||||
|
|
||||||
|
|
||||||
def makelist(filename):
|
def makelist(filename):
|
||||||
"""
|
"""
|
||||||
Read an org-mode file and return a list of Orgnode objects
|
Read an org-mode file and return a list of Orgnode objects
|
||||||
|
@ -58,124 +60,136 @@ def makelist(filename):
|
||||||
"""
|
"""
|
||||||
ctr = 0
|
ctr = 0
|
||||||
|
|
||||||
f = open(filename, 'r')
|
f = open(filename, "r")
|
||||||
|
|
||||||
todos = { "TODO": "", "WAITING": "", "ACTIVE": "",
|
todos = {
|
||||||
"DONE": "", "CANCELLED": "", "FAILED": ""} # populated from #+SEQ_TODO line
|
"TODO": "",
|
||||||
|
"WAITING": "",
|
||||||
|
"ACTIVE": "",
|
||||||
|
"DONE": "",
|
||||||
|
"CANCELLED": "",
|
||||||
|
"FAILED": "",
|
||||||
|
} # populated from #+SEQ_TODO line
|
||||||
level = ""
|
level = ""
|
||||||
heading = ""
|
heading = ""
|
||||||
bodytext = ""
|
bodytext = ""
|
||||||
tags = list() # set of all tags in headline
|
tags = list() # set of all tags in headline
|
||||||
closed_date = ''
|
closed_date = ""
|
||||||
sched_date = ''
|
sched_date = ""
|
||||||
deadline_date = ''
|
deadline_date = ""
|
||||||
logbook = list()
|
logbook = list()
|
||||||
nodelist: List[Orgnode] = list()
|
nodelist: List[Orgnode] = list()
|
||||||
property_map = dict()
|
property_map = dict()
|
||||||
in_properties_drawer = False
|
in_properties_drawer = False
|
||||||
in_logbook_drawer = False
|
in_logbook_drawer = False
|
||||||
file_title = f'{filename}'
|
file_title = f"{filename}"
|
||||||
|
|
||||||
for line in f:
|
for line in f:
|
||||||
ctr += 1
|
ctr += 1
|
||||||
heading_search = re.search(r'^(\*+)\s(.*?)\s*$', line)
|
heading_search = re.search(r"^(\*+)\s(.*?)\s*$", line)
|
||||||
if heading_search: # we are processing a heading line
|
if heading_search: # we are processing a heading line
|
||||||
if heading: # if we have are on second heading, append first heading to headings list
|
if heading: # if we have are on second heading, append first heading to headings list
|
||||||
thisNode = Orgnode(level, heading, bodytext, tags)
|
thisNode = Orgnode(level, heading, bodytext, tags)
|
||||||
if closed_date:
|
if closed_date:
|
||||||
thisNode.closed = closed_date
|
thisNode.closed = closed_date
|
||||||
closed_date = ''
|
closed_date = ""
|
||||||
if sched_date:
|
if sched_date:
|
||||||
thisNode.scheduled = sched_date
|
thisNode.scheduled = sched_date
|
||||||
sched_date = ""
|
sched_date = ""
|
||||||
if deadline_date:
|
if deadline_date:
|
||||||
thisNode.deadline = deadline_date
|
thisNode.deadline = deadline_date
|
||||||
deadline_date = ''
|
deadline_date = ""
|
||||||
if logbook:
|
if logbook:
|
||||||
thisNode.logbook = logbook
|
thisNode.logbook = logbook
|
||||||
logbook = list()
|
logbook = list()
|
||||||
thisNode.properties = property_map
|
thisNode.properties = property_map
|
||||||
nodelist.append(thisNode)
|
nodelist.append(thisNode)
|
||||||
property_map = {'LINE': f'file:{normalize_filename(filename)}::{ctr}'}
|
property_map = {"LINE": f"file:{normalize_filename(filename)}::{ctr}"}
|
||||||
level = heading_search.group(1)
|
level = heading_search.group(1)
|
||||||
heading = heading_search.group(2)
|
heading = heading_search.group(2)
|
||||||
bodytext = ""
|
bodytext = ""
|
||||||
tags = list() # set of all tags in headline
|
tags = list() # set of all tags in headline
|
||||||
tag_search = re.search(r'(.*?)\s*:([a-zA-Z0-9].*?):$',heading)
|
tag_search = re.search(r"(.*?)\s*:([a-zA-Z0-9].*?):$", heading)
|
||||||
if tag_search:
|
if tag_search:
|
||||||
heading = tag_search.group(1)
|
heading = tag_search.group(1)
|
||||||
parsedtags = tag_search.group(2)
|
parsedtags = tag_search.group(2)
|
||||||
if parsedtags:
|
if parsedtags:
|
||||||
for parsedtag in parsedtags.split(':'):
|
for parsedtag in parsedtags.split(":"):
|
||||||
if parsedtag != '': tags.append(parsedtag)
|
if parsedtag != "":
|
||||||
|
tags.append(parsedtag)
|
||||||
else: # we are processing a non-heading line
|
else: # we are processing a non-heading line
|
||||||
if line[:10] == '#+SEQ_TODO':
|
if line[:10] == "#+SEQ_TODO":
|
||||||
kwlist = re.findall(r'([A-Z]+)\(', line)
|
kwlist = re.findall(r"([A-Z]+)\(", line)
|
||||||
for kw in kwlist: todos[kw] = ""
|
for kw in kwlist:
|
||||||
|
todos[kw] = ""
|
||||||
|
|
||||||
# Set file title to TITLE property, if it exists
|
# Set file title to TITLE property, if it exists
|
||||||
title_search = re.search(r'^#\+TITLE:\s*(.*)$', line)
|
title_search = re.search(r"^#\+TITLE:\s*(.*)$", line)
|
||||||
if title_search and title_search.group(1).strip() != '':
|
if title_search and title_search.group(1).strip() != "":
|
||||||
title_text = title_search.group(1).strip()
|
title_text = title_search.group(1).strip()
|
||||||
if file_title == f'{filename}':
|
if file_title == f"{filename}":
|
||||||
file_title = title_text
|
file_title = title_text
|
||||||
else:
|
else:
|
||||||
file_title += f' {title_text}'
|
file_title += f" {title_text}"
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Ignore Properties Drawers Completely
|
# Ignore Properties Drawers Completely
|
||||||
if re.search(':PROPERTIES:', line):
|
if re.search(":PROPERTIES:", line):
|
||||||
in_properties_drawer = True
|
in_properties_drawer = True
|
||||||
continue
|
continue
|
||||||
if in_properties_drawer and re.search(':END:', line):
|
if in_properties_drawer and re.search(":END:", line):
|
||||||
in_properties_drawer = False
|
in_properties_drawer = False
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Ignore Logbook Drawer Start, End Lines
|
# Ignore Logbook Drawer Start, End Lines
|
||||||
if re.search(':LOGBOOK:', line):
|
if re.search(":LOGBOOK:", line):
|
||||||
in_logbook_drawer = True
|
in_logbook_drawer = True
|
||||||
continue
|
continue
|
||||||
if in_logbook_drawer and re.search(':END:', line):
|
if in_logbook_drawer and re.search(":END:", line):
|
||||||
in_logbook_drawer = False
|
in_logbook_drawer = False
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Extract Clocking Lines
|
# Extract Clocking Lines
|
||||||
clocked_re = re.search(r'CLOCK:\s*\[([0-9]{4}-[0-9]{2}-[0-9]{2} [a-zA-Z]{3} [0-9]{2}:[0-9]{2})\]--\[([0-9]{4}-[0-9]{2}-[0-9]{2} [a-zA-Z]{3} [0-9]{2}:[0-9]{2})\]', line)
|
clocked_re = re.search(
|
||||||
|
r"CLOCK:\s*\[([0-9]{4}-[0-9]{2}-[0-9]{2} [a-zA-Z]{3} [0-9]{2}:[0-9]{2})\]--\[([0-9]{4}-[0-9]{2}-[0-9]{2} [a-zA-Z]{3} [0-9]{2}:[0-9]{2})\]",
|
||||||
|
line,
|
||||||
|
)
|
||||||
if clocked_re:
|
if clocked_re:
|
||||||
# convert clock in, clock out strings to datetime objects
|
# convert clock in, clock out strings to datetime objects
|
||||||
clocked_in = datetime.datetime.strptime(clocked_re.group(1), '%Y-%m-%d %a %H:%M')
|
clocked_in = datetime.datetime.strptime(clocked_re.group(1), "%Y-%m-%d %a %H:%M")
|
||||||
clocked_out = datetime.datetime.strptime(clocked_re.group(2), '%Y-%m-%d %a %H:%M')
|
clocked_out = datetime.datetime.strptime(clocked_re.group(2), "%Y-%m-%d %a %H:%M")
|
||||||
# add clocked time to the entries logbook list
|
# add clocked time to the entries logbook list
|
||||||
logbook += [(clocked_in, clocked_out)]
|
logbook += [(clocked_in, clocked_out)]
|
||||||
line = ""
|
line = ""
|
||||||
|
|
||||||
property_search = re.search(r'^\s*:([a-zA-Z0-9]+):\s*(.*?)\s*$', line)
|
property_search = re.search(r"^\s*:([a-zA-Z0-9]+):\s*(.*?)\s*$", line)
|
||||||
if property_search:
|
if property_search:
|
||||||
# Set ID property to an id based org-mode link to the entry
|
# Set ID property to an id based org-mode link to the entry
|
||||||
if property_search.group(1) == 'ID':
|
if property_search.group(1) == "ID":
|
||||||
property_map['ID'] = f'id:{property_search.group(2)}'
|
property_map["ID"] = f"id:{property_search.group(2)}"
|
||||||
else:
|
else:
|
||||||
property_map[property_search.group(1)] = property_search.group(2)
|
property_map[property_search.group(1)] = property_search.group(2)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
cd_re = re.search(r'CLOSED:\s*\[([0-9]{4})-([0-9]{2})-([0-9]{2})', line)
|
cd_re = re.search(r"CLOSED:\s*\[([0-9]{4})-([0-9]{2})-([0-9]{2})", line)
|
||||||
if cd_re:
|
if cd_re:
|
||||||
closed_date = datetime.date(int(cd_re.group(1)),
|
closed_date = datetime.date(int(cd_re.group(1)), int(cd_re.group(2)), int(cd_re.group(3)))
|
||||||
int(cd_re.group(2)),
|
sd_re = re.search(r"SCHEDULED:\s*<([0-9]+)\-([0-9]+)\-([0-9]+)", line)
|
||||||
int(cd_re.group(3)) )
|
|
||||||
sd_re = re.search(r'SCHEDULED:\s*<([0-9]+)\-([0-9]+)\-([0-9]+)', line)
|
|
||||||
if sd_re:
|
if sd_re:
|
||||||
sched_date = datetime.date(int(sd_re.group(1)),
|
sched_date = datetime.date(int(sd_re.group(1)), int(sd_re.group(2)), int(sd_re.group(3)))
|
||||||
int(sd_re.group(2)),
|
dd_re = re.search(r"DEADLINE:\s*<(\d+)\-(\d+)\-(\d+)", line)
|
||||||
int(sd_re.group(3)) )
|
|
||||||
dd_re = re.search(r'DEADLINE:\s*<(\d+)\-(\d+)\-(\d+)', line)
|
|
||||||
if dd_re:
|
if dd_re:
|
||||||
deadline_date = datetime.date(int(dd_re.group(1)),
|
deadline_date = datetime.date(int(dd_re.group(1)), int(dd_re.group(2)), int(dd_re.group(3)))
|
||||||
int(dd_re.group(2)),
|
|
||||||
int(dd_re.group(3)) )
|
|
||||||
|
|
||||||
# Ignore property drawer, scheduled, closed, deadline, logbook entries and # lines from body
|
# Ignore property drawer, scheduled, closed, deadline, logbook entries and # lines from body
|
||||||
if not in_properties_drawer and not cd_re and not sd_re and not dd_re and not clocked_re and line[:1] != '#':
|
if (
|
||||||
|
not in_properties_drawer
|
||||||
|
and not cd_re
|
||||||
|
and not sd_re
|
||||||
|
and not dd_re
|
||||||
|
and not clocked_re
|
||||||
|
and line[:1] != "#"
|
||||||
|
):
|
||||||
bodytext = bodytext + line
|
bodytext = bodytext + line
|
||||||
|
|
||||||
# write out last node
|
# write out last node
|
||||||
|
@ -194,34 +208,36 @@ def makelist(filename):
|
||||||
# using the list of TODO keywords found in the file
|
# using the list of TODO keywords found in the file
|
||||||
# process the headings searching for TODO keywords
|
# process the headings searching for TODO keywords
|
||||||
for n in nodelist:
|
for n in nodelist:
|
||||||
todo_search = re.search(r'([A-Z]+)\s(.*?)$', n.heading)
|
todo_search = re.search(r"([A-Z]+)\s(.*?)$", n.heading)
|
||||||
if todo_search:
|
if todo_search:
|
||||||
if todo_search.group(1) in todos:
|
if todo_search.group(1) in todos:
|
||||||
n.heading = todo_search.group(2)
|
n.heading = todo_search.group(2)
|
||||||
n.todo = todo_search.group(1)
|
n.todo = todo_search.group(1)
|
||||||
|
|
||||||
# extract, set priority from heading, update heading if necessary
|
# extract, set priority from heading, update heading if necessary
|
||||||
priority_search = re.search(r'^\[\#(A|B|C)\] (.*?)$', n.heading)
|
priority_search = re.search(r"^\[\#(A|B|C)\] (.*?)$", n.heading)
|
||||||
if priority_search:
|
if priority_search:
|
||||||
n.priority = priority_search.group(1)
|
n.priority = priority_search.group(1)
|
||||||
n.heading = priority_search.group(2)
|
n.heading = priority_search.group(2)
|
||||||
|
|
||||||
# Set SOURCE property to a file+heading based org-mode link to the entry
|
# Set SOURCE property to a file+heading based org-mode link to the entry
|
||||||
if n.level == 0:
|
if n.level == 0:
|
||||||
n.properties['LINE'] = f'file:{normalize_filename(filename)}::0'
|
n.properties["LINE"] = f"file:{normalize_filename(filename)}::0"
|
||||||
n.properties['SOURCE'] = f'[[file:{normalize_filename(filename)}]]'
|
n.properties["SOURCE"] = f"[[file:{normalize_filename(filename)}]]"
|
||||||
else:
|
else:
|
||||||
escaped_heading = n.heading.replace("[", "\\[").replace("]", "\\]")
|
escaped_heading = n.heading.replace("[", "\\[").replace("]", "\\]")
|
||||||
n.properties['SOURCE'] = f'[[file:{normalize_filename(filename)}::*{escaped_heading}]]'
|
n.properties["SOURCE"] = f"[[file:{normalize_filename(filename)}::*{escaped_heading}]]"
|
||||||
|
|
||||||
return nodelist
|
return nodelist
|
||||||
|
|
||||||
|
|
||||||
######################
|
######################
|
||||||
class Orgnode(object):
|
class Orgnode(object):
|
||||||
"""
|
"""
|
||||||
Orgnode class represents a headline, tags and text associated
|
Orgnode class represents a headline, tags and text associated
|
||||||
with the headline.
|
with the headline.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, level, headline, body, tags):
|
def __init__(self, level, headline, body, tags):
|
||||||
"""
|
"""
|
||||||
Create an Orgnode object given the parameters of level (as the
|
Create an Orgnode object given the parameters of level (as the
|
||||||
|
@ -270,7 +286,7 @@ class Orgnode(object):
|
||||||
"""
|
"""
|
||||||
Returns True if node has non empty body, else False
|
Returns True if node has non empty body, else False
|
||||||
"""
|
"""
|
||||||
return self._body and re.sub(r'\n|\t|\r| ', '', self._body) != ''
|
return self._body and re.sub(r"\n|\t|\r| ", "", self._body) != ""
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def level(self):
|
def level(self):
|
||||||
|
@ -417,20 +433,20 @@ class Orgnode(object):
|
||||||
text as used to construct the node.
|
text as used to construct the node.
|
||||||
"""
|
"""
|
||||||
# Output heading line
|
# Output heading line
|
||||||
n = ''
|
n = ""
|
||||||
for _ in range(0, self._level):
|
for _ in range(0, self._level):
|
||||||
n = n + '*'
|
n = n + "*"
|
||||||
n = n + ' '
|
n = n + " "
|
||||||
if self._todo:
|
if self._todo:
|
||||||
n = n + self._todo + ' '
|
n = n + self._todo + " "
|
||||||
if self._priority:
|
if self._priority:
|
||||||
n = n + '[#' + self._priority + '] '
|
n = n + "[#" + self._priority + "] "
|
||||||
n = n + self._heading
|
n = n + self._heading
|
||||||
n = "%-60s " % n # hack - tags will start in column 62
|
n = "%-60s " % n # hack - tags will start in column 62
|
||||||
closecolon = ''
|
closecolon = ""
|
||||||
for t in self._tags:
|
for t in self._tags:
|
||||||
n = n + ':' + t
|
n = n + ":" + t
|
||||||
closecolon = ':'
|
closecolon = ":"
|
||||||
n = n + closecolon
|
n = n + closecolon
|
||||||
n = n + "\n"
|
n = n + "\n"
|
||||||
|
|
||||||
|
@ -447,7 +463,7 @@ class Orgnode(object):
|
||||||
if self._deadline:
|
if self._deadline:
|
||||||
n = n + f'DEADLINE: <{self._deadline.strftime("%Y-%m-%d %a")}> '
|
n = n + f'DEADLINE: <{self._deadline.strftime("%Y-%m-%d %a")}> '
|
||||||
if self._closed or self._scheduled or self._deadline:
|
if self._closed or self._scheduled or self._deadline:
|
||||||
n = n + '\n'
|
n = n + "\n"
|
||||||
|
|
||||||
# Ouput Property Drawer
|
# Ouput Property Drawer
|
||||||
n = n + indent + ":PROPERTIES:\n"
|
n = n + indent + ":PROPERTIES:\n"
|
||||||
|
|
|
@ -17,14 +17,17 @@ class TextToJsonl(ABC):
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def process(self, previous_entries: List[Entry]=None) -> List[Tuple[int, Entry]]: ...
|
def process(self, previous_entries: List[Entry] = None) -> List[Tuple[int, Entry]]:
|
||||||
|
...
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def hash_func(key: str) -> Callable:
|
def hash_func(key: str) -> Callable:
|
||||||
return lambda entry: hashlib.md5(bytes(getattr(entry, key), encoding='utf-8')).hexdigest()
|
return lambda entry: hashlib.md5(bytes(getattr(entry, key), encoding="utf-8")).hexdigest()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def split_entries_by_max_tokens(entries: List[Entry], max_tokens: int=256, max_word_length: int=500) -> List[Entry]:
|
def split_entries_by_max_tokens(
|
||||||
|
entries: List[Entry], max_tokens: int = 256, max_word_length: int = 500
|
||||||
|
) -> List[Entry]:
|
||||||
"Split entries if compiled entry length exceeds the max tokens supported by the ML model."
|
"Split entries if compiled entry length exceeds the max tokens supported by the ML model."
|
||||||
chunked_entries: List[Entry] = []
|
chunked_entries: List[Entry] = []
|
||||||
for entry in entries:
|
for entry in entries:
|
||||||
|
@ -33,12 +36,14 @@ class TextToJsonl(ABC):
|
||||||
compiled_entry_words = [word for word in compiled_entry_words if len(word) <= max_word_length]
|
compiled_entry_words = [word for word in compiled_entry_words if len(word) <= max_word_length]
|
||||||
for chunk_index in range(0, len(compiled_entry_words), max_tokens):
|
for chunk_index in range(0, len(compiled_entry_words), max_tokens):
|
||||||
compiled_entry_words_chunk = compiled_entry_words[chunk_index : chunk_index + max_tokens]
|
compiled_entry_words_chunk = compiled_entry_words[chunk_index : chunk_index + max_tokens]
|
||||||
compiled_entry_chunk = ' '.join(compiled_entry_words_chunk)
|
compiled_entry_chunk = " ".join(compiled_entry_words_chunk)
|
||||||
entry_chunk = Entry(compiled=compiled_entry_chunk, raw=entry.raw, file=entry.file)
|
entry_chunk = Entry(compiled=compiled_entry_chunk, raw=entry.raw, file=entry.file)
|
||||||
chunked_entries.append(entry_chunk)
|
chunked_entries.append(entry_chunk)
|
||||||
return chunked_entries
|
return chunked_entries
|
||||||
|
|
||||||
def mark_entries_for_update(self, current_entries: List[Entry], previous_entries: List[Entry], key='compiled', logger=None) -> List[Tuple[int, Entry]]:
|
def mark_entries_for_update(
|
||||||
|
self, current_entries: List[Entry], previous_entries: List[Entry], key="compiled", logger=None
|
||||||
|
) -> List[Tuple[int, Entry]]:
|
||||||
# Hash all current and previous entries to identify new entries
|
# Hash all current and previous entries to identify new entries
|
||||||
with timer("Hash previous, current entries", logger):
|
with timer("Hash previous, current entries", logger):
|
||||||
current_entry_hashes = list(map(TextToJsonl.hash_func(key), current_entries))
|
current_entry_hashes = list(map(TextToJsonl.hash_func(key), current_entries))
|
||||||
|
@ -54,10 +59,7 @@ class TextToJsonl(ABC):
|
||||||
existing_entry_hashes = set(current_entry_hashes) & set(previous_entry_hashes)
|
existing_entry_hashes = set(current_entry_hashes) & set(previous_entry_hashes)
|
||||||
|
|
||||||
# Mark new entries with -1 id to flag for later embeddings generation
|
# Mark new entries with -1 id to flag for later embeddings generation
|
||||||
new_entries = [
|
new_entries = [(-1, hash_to_current_entries[entry_hash]) for entry_hash in new_entry_hashes]
|
||||||
(-1, hash_to_current_entries[entry_hash])
|
|
||||||
for entry_hash in new_entry_hashes
|
|
||||||
]
|
|
||||||
# Set id of existing entries to their previous ids to reuse their existing encoded embeddings
|
# Set id of existing entries to their previous ids to reuse their existing encoded embeddings
|
||||||
existing_entries = [
|
existing_entries = [
|
||||||
(previous_entry_hashes.index(entry_hash), hash_to_previous_entries[entry_hash])
|
(previous_entry_hashes.index(entry_hash), hash_to_previous_entries[entry_hash])
|
||||||
|
|
|
@ -22,27 +22,30 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
# Create Routes
|
# Create Routes
|
||||||
@api.get('/config/data/default')
|
@api.get("/config/data/default")
|
||||||
def get_default_config_data():
|
def get_default_config_data():
|
||||||
return constants.default_config
|
return constants.default_config
|
||||||
|
|
||||||
@api.get('/config/data', response_model=FullConfig)
|
|
||||||
|
@api.get("/config/data", response_model=FullConfig)
|
||||||
def get_config_data():
|
def get_config_data():
|
||||||
return state.config
|
return state.config
|
||||||
|
|
||||||
@api.post('/config/data')
|
|
||||||
|
@api.post("/config/data")
|
||||||
async def set_config_data(updated_config: FullConfig):
|
async def set_config_data(updated_config: FullConfig):
|
||||||
state.config = updated_config
|
state.config = updated_config
|
||||||
with open(state.config_file, 'w') as outfile:
|
with open(state.config_file, "w") as outfile:
|
||||||
yaml.dump(yaml.safe_load(state.config.json(by_alias=True)), outfile)
|
yaml.dump(yaml.safe_load(state.config.json(by_alias=True)), outfile)
|
||||||
outfile.close()
|
outfile.close()
|
||||||
return state.config
|
return state.config
|
||||||
|
|
||||||
@api.get('/search', response_model=List[SearchResponse])
|
|
||||||
|
@api.get("/search", response_model=List[SearchResponse])
|
||||||
def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Optional[bool] = False):
|
def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Optional[bool] = False):
|
||||||
results: List[SearchResponse] = []
|
results: List[SearchResponse] = []
|
||||||
if q is None or q == '':
|
if q is None or q == "":
|
||||||
logger.info(f'No query param (q) passed in API call to initiate search')
|
logger.info(f"No query param (q) passed in API call to initiate search")
|
||||||
return results
|
return results
|
||||||
|
|
||||||
# initialize variables
|
# initialize variables
|
||||||
|
@ -50,9 +53,9 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
|
||||||
results_count = n
|
results_count = n
|
||||||
|
|
||||||
# return cached results, if available
|
# return cached results, if available
|
||||||
query_cache_key = f'{user_query}-{n}-{t}-{r}'
|
query_cache_key = f"{user_query}-{n}-{t}-{r}"
|
||||||
if query_cache_key in state.query_cache:
|
if query_cache_key in state.query_cache:
|
||||||
logger.info(f'Return response from query cache')
|
logger.info(f"Return response from query cache")
|
||||||
return state.query_cache[query_cache_key]
|
return state.query_cache[query_cache_key]
|
||||||
|
|
||||||
if (t == SearchType.Org or t == None) and state.model.orgmode_search:
|
if (t == SearchType.Org or t == None) and state.model.orgmode_search:
|
||||||
|
@ -95,7 +98,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
|
||||||
# query images
|
# query images
|
||||||
with timer("Query took", logger):
|
with timer("Query took", logger):
|
||||||
hits = image_search.query(user_query, results_count, state.model.image_search)
|
hits = image_search.query(user_query, results_count, state.model.image_search)
|
||||||
output_directory = constants.web_directory / 'images'
|
output_directory = constants.web_directory / "images"
|
||||||
|
|
||||||
# collate and return results
|
# collate and return results
|
||||||
with timer("Collating results took", logger):
|
with timer("Collating results took", logger):
|
||||||
|
@ -103,8 +106,9 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
|
||||||
hits,
|
hits,
|
||||||
image_names=state.model.image_search.image_names,
|
image_names=state.model.image_search.image_names,
|
||||||
output_directory=output_directory,
|
output_directory=output_directory,
|
||||||
image_files_url='/static/images',
|
image_files_url="/static/images",
|
||||||
count=results_count)
|
count=results_count,
|
||||||
|
)
|
||||||
|
|
||||||
# Cache results
|
# Cache results
|
||||||
state.query_cache[query_cache_key] = results
|
state.query_cache[query_cache_key] = results
|
||||||
|
@ -112,7 +116,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
@api.get('/update')
|
@api.get("/update")
|
||||||
def update(t: Optional[SearchType] = None, force: Optional[bool] = False):
|
def update(t: Optional[SearchType] = None, force: Optional[bool] = False):
|
||||||
try:
|
try:
|
||||||
state.search_index_lock.acquire()
|
state.search_index_lock.acquire()
|
||||||
|
@ -132,4 +136,4 @@ def update(t: Optional[SearchType] = None, force: Optional[bool] = False):
|
||||||
else:
|
else:
|
||||||
logger.info("Processor reconfigured via API call")
|
logger.info("Processor reconfigured via API call")
|
||||||
|
|
||||||
return {'status': 'ok', 'message': 'khoj reloaded'}
|
return {"status": "ok", "message": "khoj reloaded"}
|
||||||
|
|
|
@ -9,7 +9,14 @@ from fastapi import APIRouter
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from khoj.routers.api import search
|
from khoj.routers.api import search
|
||||||
from khoj.processor.conversation.gpt import converse, extract_search_type, message_to_log, message_to_prompt, understand, summarize
|
from khoj.processor.conversation.gpt import (
|
||||||
|
converse,
|
||||||
|
extract_search_type,
|
||||||
|
message_to_log,
|
||||||
|
message_to_prompt,
|
||||||
|
understand,
|
||||||
|
summarize,
|
||||||
|
)
|
||||||
from khoj.utils.config import SearchType
|
from khoj.utils.config import SearchType
|
||||||
from khoj.utils.helpers import get_from_dict, resolve_absolute_path
|
from khoj.utils.helpers import get_from_dict, resolve_absolute_path
|
||||||
from khoj.utils import state
|
from khoj.utils import state
|
||||||
|
@ -21,7 +28,7 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
# Create Routes
|
# Create Routes
|
||||||
@api_beta.get('/search')
|
@api_beta.get("/search")
|
||||||
def search_beta(q: str, n: Optional[int] = 1):
|
def search_beta(q: str, n: Optional[int] = 1):
|
||||||
# Initialize Variables
|
# Initialize Variables
|
||||||
model = state.processor_config.conversation.model
|
model = state.processor_config.conversation.model
|
||||||
|
@ -32,16 +39,16 @@ def search_beta(q: str, n: Optional[int] = 1):
|
||||||
metadata = extract_search_type(q, model=model, api_key=api_key, verbose=state.verbose)
|
metadata = extract_search_type(q, model=model, api_key=api_key, verbose=state.verbose)
|
||||||
search_type = get_from_dict(metadata, "search-type")
|
search_type = get_from_dict(metadata, "search-type")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return {'status': 'error', 'result': [str(e)], 'type': None}
|
return {"status": "error", "result": [str(e)], "type": None}
|
||||||
|
|
||||||
# Search
|
# Search
|
||||||
search_results = search(q, n=n, t=SearchType(search_type))
|
search_results = search(q, n=n, t=SearchType(search_type))
|
||||||
|
|
||||||
# Return response
|
# Return response
|
||||||
return {'status': 'ok', 'result': search_results, 'type': search_type}
|
return {"status": "ok", "result": search_results, "type": search_type}
|
||||||
|
|
||||||
|
|
||||||
@api_beta.get('/summarize')
|
@api_beta.get("/summarize")
|
||||||
def summarize_beta(q: str):
|
def summarize_beta(q: str):
|
||||||
# Initialize Variables
|
# Initialize Variables
|
||||||
model = state.processor_config.conversation.model
|
model = state.processor_config.conversation.model
|
||||||
|
@ -54,22 +61,24 @@ def summarize_beta(q: str):
|
||||||
# Converse with OpenAI GPT
|
# Converse with OpenAI GPT
|
||||||
result_list = search(q, n=1, r=True)
|
result_list = search(q, n=1, r=True)
|
||||||
collated_result = "\n".join([item.entry for item in result_list])
|
collated_result = "\n".join([item.entry for item in result_list])
|
||||||
logger.debug(f'Semantically Similar Notes:\n{collated_result}')
|
logger.debug(f"Semantically Similar Notes:\n{collated_result}")
|
||||||
try:
|
try:
|
||||||
gpt_response = summarize(collated_result, summary_type="notes", user_query=q, model=model, api_key=api_key)
|
gpt_response = summarize(collated_result, summary_type="notes", user_query=q, model=model, api_key=api_key)
|
||||||
status = 'ok'
|
status = "ok"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
gpt_response = str(e)
|
gpt_response = str(e)
|
||||||
status = 'error'
|
status = "error"
|
||||||
|
|
||||||
# Update Conversation History
|
# Update Conversation History
|
||||||
state.processor_config.conversation.chat_session = message_to_prompt(q, chat_session, gpt_message=gpt_response)
|
state.processor_config.conversation.chat_session = message_to_prompt(q, chat_session, gpt_message=gpt_response)
|
||||||
state.processor_config.conversation.meta_log['chat'] = message_to_log(q, gpt_response, conversation_log=meta_log.get('chat', []))
|
state.processor_config.conversation.meta_log["chat"] = message_to_log(
|
||||||
|
q, gpt_response, conversation_log=meta_log.get("chat", [])
|
||||||
|
)
|
||||||
|
|
||||||
return {'status': status, 'response': gpt_response}
|
return {"status": status, "response": gpt_response}
|
||||||
|
|
||||||
|
|
||||||
@api_beta.get('/chat')
|
@api_beta.get("/chat")
|
||||||
def chat(q: Optional[str] = None):
|
def chat(q: Optional[str] = None):
|
||||||
# Initialize Variables
|
# Initialize Variables
|
||||||
model = state.processor_config.conversation.model
|
model = state.processor_config.conversation.model
|
||||||
|
@ -81,10 +90,10 @@ def chat(q: Optional[str]=None):
|
||||||
|
|
||||||
# If user query is empty, return chat history
|
# If user query is empty, return chat history
|
||||||
if not q:
|
if not q:
|
||||||
if meta_log.get('chat'):
|
if meta_log.get("chat"):
|
||||||
return {'status': 'ok', 'response': meta_log["chat"]}
|
return {"status": "ok", "response": meta_log["chat"]}
|
||||||
else:
|
else:
|
||||||
return {'status': 'ok', 'response': []}
|
return {"status": "ok", "response": []}
|
||||||
|
|
||||||
# Converse with OpenAI GPT
|
# Converse with OpenAI GPT
|
||||||
metadata = understand(q, model=model, api_key=api_key, verbose=state.verbose)
|
metadata = understand(q, model=model, api_key=api_key, verbose=state.verbose)
|
||||||
|
@ -94,32 +103,39 @@ def chat(q: Optional[str]=None):
|
||||||
query = get_from_dict(metadata, "intent", "query")
|
query = get_from_dict(metadata, "intent", "query")
|
||||||
result_list = search(query, n=1, t=SearchType.Org, r=True)
|
result_list = search(query, n=1, t=SearchType.Org, r=True)
|
||||||
collated_result = "\n".join([item.entry for item in result_list])
|
collated_result = "\n".join([item.entry for item in result_list])
|
||||||
logger.debug(f'Semantically Similar Notes:\n{collated_result}')
|
logger.debug(f"Semantically Similar Notes:\n{collated_result}")
|
||||||
try:
|
try:
|
||||||
gpt_response = summarize(collated_result, summary_type="notes", user_query=q, model=model, api_key=api_key)
|
gpt_response = summarize(collated_result, summary_type="notes", user_query=q, model=model, api_key=api_key)
|
||||||
status = 'ok'
|
status = "ok"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
gpt_response = str(e)
|
gpt_response = str(e)
|
||||||
status = 'error'
|
status = "error"
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
gpt_response = converse(q, model, chat_session, api_key=api_key)
|
gpt_response = converse(q, model, chat_session, api_key=api_key)
|
||||||
status = 'ok'
|
status = "ok"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
gpt_response = str(e)
|
gpt_response = str(e)
|
||||||
status = 'error'
|
status = "error"
|
||||||
|
|
||||||
# Update Conversation History
|
# Update Conversation History
|
||||||
state.processor_config.conversation.chat_session = message_to_prompt(q, chat_session, gpt_message=gpt_response)
|
state.processor_config.conversation.chat_session = message_to_prompt(q, chat_session, gpt_message=gpt_response)
|
||||||
state.processor_config.conversation.meta_log['chat'] = message_to_log(q, gpt_response, metadata, meta_log.get('chat', []))
|
state.processor_config.conversation.meta_log["chat"] = message_to_log(
|
||||||
|
q, gpt_response, metadata, meta_log.get("chat", [])
|
||||||
|
)
|
||||||
|
|
||||||
return {'status': status, 'response': gpt_response}
|
return {"status": status, "response": gpt_response}
|
||||||
|
|
||||||
|
|
||||||
@schedule.repeat(schedule.every(5).minutes)
|
@schedule.repeat(schedule.every(5).minutes)
|
||||||
def save_chat_session():
|
def save_chat_session():
|
||||||
# No need to create empty log file
|
# No need to create empty log file
|
||||||
if not (state.processor_config and state.processor_config.conversation and state.processor_config.conversation.meta_log and state.processor_config.conversation.chat_session):
|
if not (
|
||||||
|
state.processor_config
|
||||||
|
and state.processor_config.conversation
|
||||||
|
and state.processor_config.conversation.meta_log
|
||||||
|
and state.processor_config.conversation.chat_session
|
||||||
|
):
|
||||||
return
|
return
|
||||||
|
|
||||||
# Summarize Conversation Logs for this Session
|
# Summarize Conversation Logs for this Session
|
||||||
|
@ -130,19 +146,19 @@ def save_chat_session():
|
||||||
session = {
|
session = {
|
||||||
"summary": summarize(chat_session, summary_type="chat", model=model, api_key=openai_api_key),
|
"summary": summarize(chat_session, summary_type="chat", model=model, api_key=openai_api_key),
|
||||||
"session-start": conversation_log.get("session", [{"session-end": 0}])[-1]["session-end"],
|
"session-start": conversation_log.get("session", [{"session-end": 0}])[-1]["session-end"],
|
||||||
"session-end": len(conversation_log["chat"])
|
"session-end": len(conversation_log["chat"]),
|
||||||
}
|
}
|
||||||
if 'session' in conversation_log:
|
if "session" in conversation_log:
|
||||||
conversation_log['session'].append(session)
|
conversation_log["session"].append(session)
|
||||||
else:
|
else:
|
||||||
conversation_log['session'] = [session]
|
conversation_log["session"] = [session]
|
||||||
logger.info('Added new chat session to conversation logs')
|
logger.info("Added new chat session to conversation logs")
|
||||||
|
|
||||||
# Save Conversation Metadata Logs to Disk
|
# Save Conversation Metadata Logs to Disk
|
||||||
conversation_logfile = resolve_absolute_path(state.processor_config.conversation.conversation_logfile)
|
conversation_logfile = resolve_absolute_path(state.processor_config.conversation.conversation_logfile)
|
||||||
conversation_logfile.parent.mkdir(parents=True, exist_ok=True) # create conversation directory if doesn't exist
|
conversation_logfile.parent.mkdir(parents=True, exist_ok=True) # create conversation directory if doesn't exist
|
||||||
with open(conversation_logfile, "w+", encoding='utf-8') as logfile:
|
with open(conversation_logfile, "w+", encoding="utf-8") as logfile:
|
||||||
json.dump(conversation_log, logfile)
|
json.dump(conversation_log, logfile)
|
||||||
|
|
||||||
state.processor_config.conversation.chat_session = None
|
state.processor_config.conversation.chat_session = None
|
||||||
logger.info('Saved updated conversation logs to disk.')
|
logger.info("Saved updated conversation logs to disk.")
|
||||||
|
|
|
@ -18,9 +18,11 @@ templates = Jinja2Templates(directory=constants.web_directory)
|
||||||
def index():
|
def index():
|
||||||
return FileResponse(constants.web_directory / "index.html")
|
return FileResponse(constants.web_directory / "index.html")
|
||||||
|
|
||||||
@web_client.get('/config', response_class=HTMLResponse)
|
|
||||||
|
@web_client.get("/config", response_class=HTMLResponse)
|
||||||
def config_page(request: Request):
|
def config_page(request: Request):
|
||||||
return templates.TemplateResponse("config.html", context={'request': request})
|
return templates.TemplateResponse("config.html", context={"request": request})
|
||||||
|
|
||||||
|
|
||||||
@web_client.get("/chat", response_class=FileResponse)
|
@web_client.get("/chat", response_class=FileResponse)
|
||||||
def chat_page():
|
def chat_page():
|
||||||
|
|
|
@ -8,10 +8,13 @@ from khoj.utils.rawconfig import Entry
|
||||||
|
|
||||||
class BaseFilter(ABC):
|
class BaseFilter(ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def load(self, entries: List[Entry], *args, **kwargs): ...
|
def load(self, entries: List[Entry], *args, **kwargs):
|
||||||
|
...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def can_filter(self, raw_query:str) -> bool: ...
|
def can_filter(self, raw_query: str) -> bool:
|
||||||
|
...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def apply(self, query:str, entries: List[Entry]) -> Tuple[str, Set[int]]: ...
|
def apply(self, query: str, entries: List[Entry]) -> Tuple[str, Set[int]]:
|
||||||
|
...
|
||||||
|
|
|
@ -26,21 +26,19 @@ class DateFilter(BaseFilter):
|
||||||
# - dt:"2 years ago"
|
# - dt:"2 years ago"
|
||||||
date_regex = r"dt([:><=]{1,2})\"(.*?)\""
|
date_regex = r"dt([:><=]{1,2})\"(.*?)\""
|
||||||
|
|
||||||
|
def __init__(self, entry_key="raw"):
|
||||||
def __init__(self, entry_key='raw'):
|
|
||||||
self.entry_key = entry_key
|
self.entry_key = entry_key
|
||||||
self.date_to_entry_ids = defaultdict(set)
|
self.date_to_entry_ids = defaultdict(set)
|
||||||
self.cache = LRU()
|
self.cache = LRU()
|
||||||
|
|
||||||
|
|
||||||
def load(self, entries, *args, **kwargs):
|
def load(self, entries, *args, **kwargs):
|
||||||
with timer("Created date filter index", logger):
|
with timer("Created date filter index", logger):
|
||||||
for id, entry in enumerate(entries):
|
for id, entry in enumerate(entries):
|
||||||
# Extract dates from entry
|
# Extract dates from entry
|
||||||
for date_in_entry_string in re.findall(r'\d{4}-\d{2}-\d{2}', getattr(entry, self.entry_key)):
|
for date_in_entry_string in re.findall(r"\d{4}-\d{2}-\d{2}", getattr(entry, self.entry_key)):
|
||||||
# Convert date string in entry to unix timestamp
|
# Convert date string in entry to unix timestamp
|
||||||
try:
|
try:
|
||||||
date_in_entry = datetime.strptime(date_in_entry_string, '%Y-%m-%d').timestamp()
|
date_in_entry = datetime.strptime(date_in_entry_string, "%Y-%m-%d").timestamp()
|
||||||
except ValueError:
|
except ValueError:
|
||||||
continue
|
continue
|
||||||
self.date_to_entry_ids[date_in_entry].add(id)
|
self.date_to_entry_ids[date_in_entry].add(id)
|
||||||
|
@ -49,7 +47,6 @@ class DateFilter(BaseFilter):
|
||||||
"Check if query contains date filters"
|
"Check if query contains date filters"
|
||||||
return self.extract_date_range(raw_query) is not None
|
return self.extract_date_range(raw_query) is not None
|
||||||
|
|
||||||
|
|
||||||
def apply(self, query, entries):
|
def apply(self, query, entries):
|
||||||
"Find entries containing any dates that fall within date range specified in query"
|
"Find entries containing any dates that fall within date range specified in query"
|
||||||
# extract date range specified in date filter of query
|
# extract date range specified in date filter of query
|
||||||
|
@ -61,8 +58,8 @@ class DateFilter(BaseFilter):
|
||||||
return query, set(range(len(entries)))
|
return query, set(range(len(entries)))
|
||||||
|
|
||||||
# remove date range filter from query
|
# remove date range filter from query
|
||||||
query = re.sub(rf'\s+{self.date_regex}', ' ', query)
|
query = re.sub(rf"\s+{self.date_regex}", " ", query)
|
||||||
query = re.sub(r'\s{2,}', ' ', query).strip() # remove multiple spaces
|
query = re.sub(r"\s{2,}", " ", query).strip() # remove multiple spaces
|
||||||
|
|
||||||
# return results from cache if exists
|
# return results from cache if exists
|
||||||
cache_key = tuple(query_daterange)
|
cache_key = tuple(query_daterange)
|
||||||
|
@ -87,7 +84,6 @@ class DateFilter(BaseFilter):
|
||||||
|
|
||||||
return query, entries_to_include
|
return query, entries_to_include
|
||||||
|
|
||||||
|
|
||||||
def extract_date_range(self, query):
|
def extract_date_range(self, query):
|
||||||
# find date range filter in query
|
# find date range filter in query
|
||||||
date_range_matches = re.findall(self.date_regex, query)
|
date_range_matches = re.findall(self.date_regex, query)
|
||||||
|
@ -98,7 +94,7 @@ class DateFilter(BaseFilter):
|
||||||
# extract, parse natural dates ranges from date range filter passed in query
|
# extract, parse natural dates ranges from date range filter passed in query
|
||||||
# e.g today maps to (start_of_day, start_of_tomorrow)
|
# e.g today maps to (start_of_day, start_of_tomorrow)
|
||||||
date_ranges_from_filter = []
|
date_ranges_from_filter = []
|
||||||
for (cmp, date_str) in date_range_matches:
|
for cmp, date_str in date_range_matches:
|
||||||
if self.parse(date_str):
|
if self.parse(date_str):
|
||||||
dt_start, dt_end = self.parse(date_str)
|
dt_start, dt_end = self.parse(date_str)
|
||||||
date_ranges_from_filter += [[cmp, (dt_start.timestamp(), dt_end.timestamp())]]
|
date_ranges_from_filter += [[cmp, (dt_start.timestamp(), dt_end.timestamp())]]
|
||||||
|
@ -111,15 +107,15 @@ class DateFilter(BaseFilter):
|
||||||
effective_date_range = [0, inf]
|
effective_date_range = [0, inf]
|
||||||
date_range_considering_comparator = []
|
date_range_considering_comparator = []
|
||||||
for cmp, (dtrange_start, dtrange_end) in date_ranges_from_filter:
|
for cmp, (dtrange_start, dtrange_end) in date_ranges_from_filter:
|
||||||
if cmp == '>':
|
if cmp == ">":
|
||||||
date_range_considering_comparator += [[dtrange_end, inf]]
|
date_range_considering_comparator += [[dtrange_end, inf]]
|
||||||
elif cmp == '>=':
|
elif cmp == ">=":
|
||||||
date_range_considering_comparator += [[dtrange_start, inf]]
|
date_range_considering_comparator += [[dtrange_start, inf]]
|
||||||
elif cmp == '<':
|
elif cmp == "<":
|
||||||
date_range_considering_comparator += [[0, dtrange_start]]
|
date_range_considering_comparator += [[0, dtrange_start]]
|
||||||
elif cmp == '<=':
|
elif cmp == "<=":
|
||||||
date_range_considering_comparator += [[0, dtrange_end]]
|
date_range_considering_comparator += [[0, dtrange_end]]
|
||||||
elif cmp == '=' or cmp == ':' or cmp == '==':
|
elif cmp == "=" or cmp == ":" or cmp == "==":
|
||||||
date_range_considering_comparator += [[dtrange_start, dtrange_end]]
|
date_range_considering_comparator += [[dtrange_start, dtrange_end]]
|
||||||
|
|
||||||
# Combine above intervals (via AND/intersect)
|
# Combine above intervals (via AND/intersect)
|
||||||
|
@ -129,48 +125,48 @@ class DateFilter(BaseFilter):
|
||||||
for date_range in date_range_considering_comparator:
|
for date_range in date_range_considering_comparator:
|
||||||
effective_date_range = [
|
effective_date_range = [
|
||||||
max(effective_date_range[0], date_range[0]),
|
max(effective_date_range[0], date_range[0]),
|
||||||
min(effective_date_range[1], date_range[1])]
|
min(effective_date_range[1], date_range[1]),
|
||||||
|
]
|
||||||
|
|
||||||
if effective_date_range == [0, inf] or effective_date_range[0] > effective_date_range[1]:
|
if effective_date_range == [0, inf] or effective_date_range[0] > effective_date_range[1]:
|
||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
return effective_date_range
|
return effective_date_range
|
||||||
|
|
||||||
|
|
||||||
def parse(self, date_str, relative_base=None):
|
def parse(self, date_str, relative_base=None):
|
||||||
"Parse date string passed in date filter of query to datetime object"
|
"Parse date string passed in date filter of query to datetime object"
|
||||||
# clean date string to handle future date parsing by date parser
|
# clean date string to handle future date parsing by date parser
|
||||||
future_strings = ['later', 'from now', 'from today']
|
future_strings = ["later", "from now", "from today"]
|
||||||
prefer_dates_from = {True: 'future', False: 'past'}[any([True for fstr in future_strings if fstr in date_str])]
|
prefer_dates_from = {True: "future", False: "past"}[any([True for fstr in future_strings if fstr in date_str])]
|
||||||
clean_date_str = re.sub('|'.join(future_strings), '', date_str)
|
clean_date_str = re.sub("|".join(future_strings), "", date_str)
|
||||||
|
|
||||||
# parse date passed in query date filter
|
# parse date passed in query date filter
|
||||||
parsed_date = dtparse.parse(
|
parsed_date = dtparse.parse(
|
||||||
clean_date_str,
|
clean_date_str,
|
||||||
settings={
|
settings={
|
||||||
'RELATIVE_BASE': relative_base or datetime.now(),
|
"RELATIVE_BASE": relative_base or datetime.now(),
|
||||||
'PREFER_DAY_OF_MONTH': 'first',
|
"PREFER_DAY_OF_MONTH": "first",
|
||||||
'PREFER_DATES_FROM': prefer_dates_from
|
"PREFER_DATES_FROM": prefer_dates_from,
|
||||||
})
|
},
|
||||||
|
)
|
||||||
|
|
||||||
if parsed_date is None:
|
if parsed_date is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return self.date_to_daterange(parsed_date, date_str)
|
return self.date_to_daterange(parsed_date, date_str)
|
||||||
|
|
||||||
|
|
||||||
def date_to_daterange(self, parsed_date, date_str):
|
def date_to_daterange(self, parsed_date, date_str):
|
||||||
"Convert parsed date to date ranges at natural granularity (day, week, month or year)"
|
"Convert parsed date to date ranges at natural granularity (day, week, month or year)"
|
||||||
|
|
||||||
start_of_day = parsed_date.replace(hour=0, minute=0, second=0, microsecond=0)
|
start_of_day = parsed_date.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||||
|
|
||||||
if 'year' in date_str:
|
if "year" in date_str:
|
||||||
return (datetime(parsed_date.year, 1, 1, 0, 0, 0), datetime(parsed_date.year + 1, 1, 1, 0, 0, 0))
|
return (datetime(parsed_date.year, 1, 1, 0, 0, 0), datetime(parsed_date.year + 1, 1, 1, 0, 0, 0))
|
||||||
if 'month' in date_str:
|
if "month" in date_str:
|
||||||
start_of_month = datetime(parsed_date.year, parsed_date.month, 1, 0, 0, 0)
|
start_of_month = datetime(parsed_date.year, parsed_date.month, 1, 0, 0, 0)
|
||||||
next_month = start_of_month + relativedelta(months=1)
|
next_month = start_of_month + relativedelta(months=1)
|
||||||
return (start_of_month, next_month)
|
return (start_of_month, next_month)
|
||||||
if 'week' in date_str:
|
if "week" in date_str:
|
||||||
# if week in date string, dateparser parses it to next week start
|
# if week in date string, dateparser parses it to next week start
|
||||||
# so today = end of this week
|
# so today = end of this week
|
||||||
start_of_week = start_of_day - timedelta(days=7)
|
start_of_week = start_of_day - timedelta(days=7)
|
||||||
|
|
|
@ -16,7 +16,7 @@ logger = logging.getLogger(__name__)
|
||||||
class FileFilter(BaseFilter):
|
class FileFilter(BaseFilter):
|
||||||
file_filter_regex = r'file:"(.+?)" ?'
|
file_filter_regex = r'file:"(.+?)" ?'
|
||||||
|
|
||||||
def __init__(self, entry_key='file'):
|
def __init__(self, entry_key="file"):
|
||||||
self.entry_key = entry_key
|
self.entry_key = entry_key
|
||||||
self.file_to_entry_map = defaultdict(set)
|
self.file_to_entry_map = defaultdict(set)
|
||||||
self.cache = LRU()
|
self.cache = LRU()
|
||||||
|
@ -40,13 +40,13 @@ class FileFilter(BaseFilter):
|
||||||
# e.g. "file:notes.org" -> "file:.*notes.org"
|
# e.g. "file:notes.org" -> "file:.*notes.org"
|
||||||
files_to_search = []
|
files_to_search = []
|
||||||
for file in sorted(raw_files_to_search):
|
for file in sorted(raw_files_to_search):
|
||||||
if '/' not in file and '\\' not in file and '*' not in file:
|
if "/" not in file and "\\" not in file and "*" not in file:
|
||||||
files_to_search += [f'*{file}']
|
files_to_search += [f"*{file}"]
|
||||||
else:
|
else:
|
||||||
files_to_search += [file]
|
files_to_search += [file]
|
||||||
|
|
||||||
# Return item from cache if exists
|
# Return item from cache if exists
|
||||||
query = re.sub(self.file_filter_regex, '', query).strip()
|
query = re.sub(self.file_filter_regex, "", query).strip()
|
||||||
cache_key = tuple(files_to_search)
|
cache_key = tuple(files_to_search)
|
||||||
if cache_key in self.cache:
|
if cache_key in self.cache:
|
||||||
logger.info(f"Return file filter results from cache")
|
logger.info(f"Return file filter results from cache")
|
||||||
|
@ -58,10 +58,15 @@ class FileFilter(BaseFilter):
|
||||||
|
|
||||||
# Mark entries that contain any blocked_words for exclusion
|
# Mark entries that contain any blocked_words for exclusion
|
||||||
with timer("Mark entries satisfying filter", logger):
|
with timer("Mark entries satisfying filter", logger):
|
||||||
included_entry_indices = set.union(*[self.file_to_entry_map[entry_file]
|
included_entry_indices = set.union(
|
||||||
|
*[
|
||||||
|
self.file_to_entry_map[entry_file]
|
||||||
for entry_file in self.file_to_entry_map.keys()
|
for entry_file in self.file_to_entry_map.keys()
|
||||||
for search_file in files_to_search
|
for search_file in files_to_search
|
||||||
if fnmatch.fnmatch(entry_file, search_file)], set())
|
if fnmatch.fnmatch(entry_file, search_file)
|
||||||
|
],
|
||||||
|
set(),
|
||||||
|
)
|
||||||
if not included_entry_indices:
|
if not included_entry_indices:
|
||||||
return query, {}
|
return query, {}
|
||||||
|
|
||||||
|
|
|
@ -17,26 +17,26 @@ class WordFilter(BaseFilter):
|
||||||
required_regex = r'\+"([a-zA-Z0-9_-]+)" ?'
|
required_regex = r'\+"([a-zA-Z0-9_-]+)" ?'
|
||||||
blocked_regex = r'\-"([a-zA-Z0-9_-]+)" ?'
|
blocked_regex = r'\-"([a-zA-Z0-9_-]+)" ?'
|
||||||
|
|
||||||
def __init__(self, entry_key='raw'):
|
def __init__(self, entry_key="raw"):
|
||||||
self.entry_key = entry_key
|
self.entry_key = entry_key
|
||||||
self.word_to_entry_index = defaultdict(set)
|
self.word_to_entry_index = defaultdict(set)
|
||||||
self.cache = LRU()
|
self.cache = LRU()
|
||||||
|
|
||||||
|
|
||||||
def load(self, entries, *args, **kwargs):
|
def load(self, entries, *args, **kwargs):
|
||||||
with timer("Created word filter index", logger):
|
with timer("Created word filter index", logger):
|
||||||
self.cache = {} # Clear cache on filter (re-)load
|
self.cache = {} # Clear cache on filter (re-)load
|
||||||
entry_splitter = r',|\.| |\]|\[\(|\)|\{|\}|\<|\>|\t|\n|\:|\;|\?|\!|\(|\)|\&|\^|\$|\@|\%|\+|\=|\/|\\|\||\~|\`|\"|\''
|
entry_splitter = (
|
||||||
|
r",|\.| |\]|\[\(|\)|\{|\}|\<|\>|\t|\n|\:|\;|\?|\!|\(|\)|\&|\^|\$|\@|\%|\+|\=|\/|\\|\||\~|\`|\"|\'"
|
||||||
|
)
|
||||||
# Create map of words to entries they exist in
|
# Create map of words to entries they exist in
|
||||||
for entry_index, entry in enumerate(entries):
|
for entry_index, entry in enumerate(entries):
|
||||||
for word in re.split(entry_splitter, getattr(entry, self.entry_key).lower()):
|
for word in re.split(entry_splitter, getattr(entry, self.entry_key).lower()):
|
||||||
if word == '':
|
if word == "":
|
||||||
continue
|
continue
|
||||||
self.word_to_entry_index[word].add(entry_index)
|
self.word_to_entry_index[word].add(entry_index)
|
||||||
|
|
||||||
return self.word_to_entry_index
|
return self.word_to_entry_index
|
||||||
|
|
||||||
|
|
||||||
def can_filter(self, raw_query):
|
def can_filter(self, raw_query):
|
||||||
"Check if query contains word filters"
|
"Check if query contains word filters"
|
||||||
required_words = re.findall(self.required_regex, raw_query)
|
required_words = re.findall(self.required_regex, raw_query)
|
||||||
|
@ -44,14 +44,13 @@ class WordFilter(BaseFilter):
|
||||||
|
|
||||||
return len(required_words) != 0 or len(blocked_words) != 0
|
return len(required_words) != 0 or len(blocked_words) != 0
|
||||||
|
|
||||||
|
|
||||||
def apply(self, query, entries):
|
def apply(self, query, entries):
|
||||||
"Find entries containing required and not blocked words specified in query"
|
"Find entries containing required and not blocked words specified in query"
|
||||||
# Separate natural query from required, blocked words filters
|
# Separate natural query from required, blocked words filters
|
||||||
with timer("Extract required, blocked filters from query", logger):
|
with timer("Extract required, blocked filters from query", logger):
|
||||||
required_words = set([word.lower() for word in re.findall(self.required_regex, query)])
|
required_words = set([word.lower() for word in re.findall(self.required_regex, query)])
|
||||||
blocked_words = set([word.lower() for word in re.findall(self.blocked_regex, query)])
|
blocked_words = set([word.lower() for word in re.findall(self.blocked_regex, query)])
|
||||||
query = re.sub(self.blocked_regex, '', re.sub(self.required_regex, '', query)).strip()
|
query = re.sub(self.blocked_regex, "", re.sub(self.required_regex, "", query)).strip()
|
||||||
|
|
||||||
if len(required_words) == 0 and len(blocked_words) == 0:
|
if len(required_words) == 0 and len(blocked_words) == 0:
|
||||||
return query, set(range(len(entries)))
|
return query, set(range(len(entries)))
|
||||||
|
@ -70,12 +69,16 @@ class WordFilter(BaseFilter):
|
||||||
with timer("Mark entries satisfying filter", logger):
|
with timer("Mark entries satisfying filter", logger):
|
||||||
entries_with_all_required_words = set(range(len(entries)))
|
entries_with_all_required_words = set(range(len(entries)))
|
||||||
if len(required_words) > 0:
|
if len(required_words) > 0:
|
||||||
entries_with_all_required_words = set.intersection(*[self.word_to_entry_index.get(word, set()) for word in required_words])
|
entries_with_all_required_words = set.intersection(
|
||||||
|
*[self.word_to_entry_index.get(word, set()) for word in required_words]
|
||||||
|
)
|
||||||
|
|
||||||
# mark entries that contain any blocked_words for exclusion
|
# mark entries that contain any blocked_words for exclusion
|
||||||
entries_with_any_blocked_words = set()
|
entries_with_any_blocked_words = set()
|
||||||
if len(blocked_words) > 0:
|
if len(blocked_words) > 0:
|
||||||
entries_with_any_blocked_words = set.union(*[self.word_to_entry_index.get(word, set()) for word in blocked_words])
|
entries_with_any_blocked_words = set.union(
|
||||||
|
*[self.word_to_entry_index.get(word, set()) for word in blocked_words]
|
||||||
|
)
|
||||||
|
|
||||||
# get entries satisfying inclusion and exclusion filters
|
# get entries satisfying inclusion and exclusion filters
|
||||||
included_entry_indices = entries_with_all_required_words - entries_with_any_blocked_words
|
included_entry_indices = entries_with_all_required_words - entries_with_any_blocked_words
|
||||||
|
|
|
@ -37,7 +37,8 @@ def initialize_model(search_config: ImageSearchConfig):
|
||||||
encoder = load_model(
|
encoder = load_model(
|
||||||
model_dir=search_config.model_directory,
|
model_dir=search_config.model_directory,
|
||||||
model_name=search_config.encoder,
|
model_name=search_config.encoder,
|
||||||
model_type = search_config.encoder_type or SentenceTransformer)
|
model_type=search_config.encoder_type or SentenceTransformer,
|
||||||
|
)
|
||||||
|
|
||||||
return encoder
|
return encoder
|
||||||
|
|
||||||
|
@ -46,12 +47,12 @@ def extract_entries(image_directories):
|
||||||
image_names = []
|
image_names = []
|
||||||
for image_directory in image_directories:
|
for image_directory in image_directories:
|
||||||
image_directory = resolve_absolute_path(image_directory, strict=True)
|
image_directory = resolve_absolute_path(image_directory, strict=True)
|
||||||
image_names.extend(list(image_directory.glob('*.jpg')))
|
image_names.extend(list(image_directory.glob("*.jpg")))
|
||||||
image_names.extend(list(image_directory.glob('*.jpeg')))
|
image_names.extend(list(image_directory.glob("*.jpeg")))
|
||||||
|
|
||||||
if logger.level >= logging.INFO:
|
if logger.level >= logging.INFO:
|
||||||
image_directory_names = ', '.join([str(image_directory) for image_directory in image_directories])
|
image_directory_names = ", ".join([str(image_directory) for image_directory in image_directories])
|
||||||
logger.info(f'Found {len(image_names)} images in {image_directory_names}')
|
logger.info(f"Found {len(image_names)} images in {image_directory_names}")
|
||||||
return sorted(image_names)
|
return sorted(image_names)
|
||||||
|
|
||||||
|
|
||||||
|
@ -59,7 +60,9 @@ def compute_embeddings(image_names, encoder, embeddings_file, batch_size=50, use
|
||||||
"Compute (and Save) Embeddings or Load Pre-Computed Embeddings"
|
"Compute (and Save) Embeddings or Load Pre-Computed Embeddings"
|
||||||
|
|
||||||
image_embeddings = compute_image_embeddings(image_names, encoder, embeddings_file, batch_size, regenerate)
|
image_embeddings = compute_image_embeddings(image_names, encoder, embeddings_file, batch_size, regenerate)
|
||||||
image_metadata_embeddings = compute_metadata_embeddings(image_names, encoder, embeddings_file, batch_size, use_xmp_metadata, regenerate)
|
image_metadata_embeddings = compute_metadata_embeddings(
|
||||||
|
image_names, encoder, embeddings_file, batch_size, use_xmp_metadata, regenerate
|
||||||
|
)
|
||||||
|
|
||||||
return image_embeddings, image_metadata_embeddings
|
return image_embeddings, image_metadata_embeddings
|
||||||
|
|
||||||
|
@ -79,10 +82,7 @@ def compute_image_embeddings(image_names, encoder, embeddings_file, batch_size=5
|
||||||
# Resize images to max width of 640px for faster processing
|
# Resize images to max width of 640px for faster processing
|
||||||
image.thumbnail((640, image.height))
|
image.thumbnail((640, image.height))
|
||||||
images += [image]
|
images += [image]
|
||||||
image_embeddings += encoder.encode(
|
image_embeddings += encoder.encode(images, convert_to_tensor=True, batch_size=min(len(images), batch_size))
|
||||||
images,
|
|
||||||
convert_to_tensor=True,
|
|
||||||
batch_size=min(len(images), batch_size))
|
|
||||||
|
|
||||||
# Create directory for embeddings file, if it doesn't exist
|
# Create directory for embeddings file, if it doesn't exist
|
||||||
embeddings_file.parent.mkdir(parents=True, exist_ok=True)
|
embeddings_file.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
@ -94,7 +94,9 @@ def compute_image_embeddings(image_names, encoder, embeddings_file, batch_size=5
|
||||||
return image_embeddings
|
return image_embeddings
|
||||||
|
|
||||||
|
|
||||||
def compute_metadata_embeddings(image_names, encoder, embeddings_file, batch_size=50, use_xmp_metadata=False, regenerate=False, verbose=0):
|
def compute_metadata_embeddings(
|
||||||
|
image_names, encoder, embeddings_file, batch_size=50, use_xmp_metadata=False, regenerate=False, verbose=0
|
||||||
|
):
|
||||||
image_metadata_embeddings = None
|
image_metadata_embeddings = None
|
||||||
|
|
||||||
# Load pre-computed image metadata embedding file if exists
|
# Load pre-computed image metadata embedding file if exists
|
||||||
|
@ -106,14 +108,17 @@ def compute_metadata_embeddings(image_names, encoder, embeddings_file, batch_siz
|
||||||
if use_xmp_metadata and image_metadata_embeddings is None:
|
if use_xmp_metadata and image_metadata_embeddings is None:
|
||||||
image_metadata_embeddings = []
|
image_metadata_embeddings = []
|
||||||
for index in trange(0, len(image_names), batch_size):
|
for index in trange(0, len(image_names), batch_size):
|
||||||
image_metadata = [extract_metadata(image_name, verbose) for image_name in image_names[index:index+batch_size]]
|
image_metadata = [
|
||||||
|
extract_metadata(image_name, verbose) for image_name in image_names[index : index + batch_size]
|
||||||
|
]
|
||||||
try:
|
try:
|
||||||
image_metadata_embeddings += encoder.encode(
|
image_metadata_embeddings += encoder.encode(
|
||||||
image_metadata,
|
image_metadata, convert_to_tensor=True, batch_size=min(len(image_metadata), batch_size)
|
||||||
convert_to_tensor=True,
|
)
|
||||||
batch_size=min(len(image_metadata), batch_size))
|
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
logger.error(f"Error encoding metadata for images starting from\n\tindex: {index},\n\timages: {image_names[index:index+batch_size]}\nException: {e}")
|
logger.error(
|
||||||
|
f"Error encoding metadata for images starting from\n\tindex: {index},\n\timages: {image_names[index:index+batch_size]}\nException: {e}"
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
torch.save(image_metadata_embeddings, f"{embeddings_file}_metadata")
|
torch.save(image_metadata_embeddings, f"{embeddings_file}_metadata")
|
||||||
logger.info(f"Saved computed metadata embeddings to {embeddings_file}_metadata")
|
logger.info(f"Saved computed metadata embeddings to {embeddings_file}_metadata")
|
||||||
|
@ -123,8 +128,10 @@ def compute_metadata_embeddings(image_names, encoder, embeddings_file, batch_siz
|
||||||
|
|
||||||
def extract_metadata(image_name):
|
def extract_metadata(image_name):
|
||||||
image_xmp_metadata = Image.open(image_name).getxmp()
|
image_xmp_metadata = Image.open(image_name).getxmp()
|
||||||
image_description = get_from_dict(image_xmp_metadata, 'xmpmeta', 'RDF', 'Description', 'description', 'Alt', 'li', 'text')
|
image_description = get_from_dict(
|
||||||
image_subjects = get_from_dict(image_xmp_metadata, 'xmpmeta', 'RDF', 'Description', 'subject', 'Bag', 'li')
|
image_xmp_metadata, "xmpmeta", "RDF", "Description", "description", "Alt", "li", "text"
|
||||||
|
)
|
||||||
|
image_subjects = get_from_dict(image_xmp_metadata, "xmpmeta", "RDF", "Description", "subject", "Bag", "li")
|
||||||
image_metadata_subjects = set([subject.split(":")[1] for subject in image_subjects if ":" in subject])
|
image_metadata_subjects = set([subject.split(":")[1] for subject in image_subjects if ":" in subject])
|
||||||
|
|
||||||
image_processed_metadata = image_description
|
image_processed_metadata = image_description
|
||||||
|
@ -155,36 +162,42 @@ def query(raw_query, count, model: ImageSearchModel):
|
||||||
|
|
||||||
# Compute top_k ranked images based on cosine-similarity b/w query and all image embeddings.
|
# Compute top_k ranked images based on cosine-similarity b/w query and all image embeddings.
|
||||||
with timer("Search Time", logger):
|
with timer("Search Time", logger):
|
||||||
image_hits = {result['corpus_id']: {'image_score': result['score'], 'score': result['score']}
|
image_hits = {
|
||||||
for result
|
result["corpus_id"]: {"image_score": result["score"], "score": result["score"]}
|
||||||
in util.semantic_search(query_embedding, model.image_embeddings, top_k=count)[0]}
|
for result in util.semantic_search(query_embedding, model.image_embeddings, top_k=count)[0]
|
||||||
|
}
|
||||||
|
|
||||||
# Compute top_k ranked images based on cosine-similarity b/w query and all image metadata embeddings.
|
# Compute top_k ranked images based on cosine-similarity b/w query and all image metadata embeddings.
|
||||||
if model.image_metadata_embeddings:
|
if model.image_metadata_embeddings:
|
||||||
with timer("Metadata Search Time", logger):
|
with timer("Metadata Search Time", logger):
|
||||||
metadata_hits = {result['corpus_id']: result['score']
|
metadata_hits = {
|
||||||
for result
|
result["corpus_id"]: result["score"]
|
||||||
in util.semantic_search(query_embedding, model.image_metadata_embeddings, top_k=count)[0]}
|
for result in util.semantic_search(query_embedding, model.image_metadata_embeddings, top_k=count)[0]
|
||||||
|
}
|
||||||
|
|
||||||
# Sum metadata, image scores of the highest ranked images
|
# Sum metadata, image scores of the highest ranked images
|
||||||
for corpus_id, score in metadata_hits.items():
|
for corpus_id, score in metadata_hits.items():
|
||||||
scaling_factor = 0.33
|
scaling_factor = 0.33
|
||||||
if 'corpus_id' in image_hits:
|
if "corpus_id" in image_hits:
|
||||||
image_hits[corpus_id].update({
|
image_hits[corpus_id].update(
|
||||||
'metadata_score': score,
|
{
|
||||||
'score': image_hits[corpus_id].get('score', 0) + scaling_factor*score,
|
"metadata_score": score,
|
||||||
})
|
"score": image_hits[corpus_id].get("score", 0) + scaling_factor * score,
|
||||||
|
}
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
image_hits[corpus_id] = {'metadata_score': score, 'score': scaling_factor*score}
|
image_hits[corpus_id] = {"metadata_score": score, "score": scaling_factor * score}
|
||||||
|
|
||||||
# Reformat results in original form from sentence transformer semantic_search()
|
# Reformat results in original form from sentence transformer semantic_search()
|
||||||
hits = [
|
hits = [
|
||||||
{
|
{
|
||||||
'corpus_id': corpus_id,
|
"corpus_id": corpus_id,
|
||||||
'score': scores['score'],
|
"score": scores["score"],
|
||||||
'image_score': scores.get('image_score', 0),
|
"image_score": scores.get("image_score", 0),
|
||||||
'metadata_score': scores.get('metadata_score', 0),
|
"metadata_score": scores.get("metadata_score", 0),
|
||||||
} for corpus_id, scores in image_hits.items()]
|
}
|
||||||
|
for corpus_id, scores in image_hits.items()
|
||||||
|
]
|
||||||
|
|
||||||
# Sort the images based on their combined metadata, image scores
|
# Sort the images based on their combined metadata, image scores
|
||||||
return sorted(hits, key=lambda hit: hit["score"], reverse=True)
|
return sorted(hits, key=lambda hit: hit["score"], reverse=True)
|
||||||
|
@ -194,7 +207,7 @@ def collate_results(hits, image_names, output_directory, image_files_url, count=
|
||||||
results: List[SearchResponse] = []
|
results: List[SearchResponse] = []
|
||||||
|
|
||||||
for index, hit in enumerate(hits[:count]):
|
for index, hit in enumerate(hits[:count]):
|
||||||
source_path = image_names[hit['corpus_id']]
|
source_path = image_names[hit["corpus_id"]]
|
||||||
|
|
||||||
target_image_name = f"{index}{source_path.suffix}"
|
target_image_name = f"{index}{source_path.suffix}"
|
||||||
target_path = resolve_absolute_path(f"{output_directory}/{target_image_name}")
|
target_path = resolve_absolute_path(f"{output_directory}/{target_image_name}")
|
||||||
|
@ -207,17 +220,18 @@ def collate_results(hits, image_names, output_directory, image_files_url, count=
|
||||||
shutil.copy(source_path, target_path)
|
shutil.copy(source_path, target_path)
|
||||||
|
|
||||||
# Add the image metadata to the results
|
# Add the image metadata to the results
|
||||||
results += [SearchResponse.parse_obj(
|
results += [
|
||||||
|
SearchResponse.parse_obj(
|
||||||
{
|
{
|
||||||
"entry": f'{image_files_url}/{target_image_name}',
|
"entry": f"{image_files_url}/{target_image_name}",
|
||||||
"score": f"{hit['score']:.9f}",
|
"score": f"{hit['score']:.9f}",
|
||||||
"additional":
|
"additional": {
|
||||||
{
|
|
||||||
"image_score": f"{hit['image_score']:.9f}",
|
"image_score": f"{hit['image_score']:.9f}",
|
||||||
"metadata_score": f"{hit['metadata_score']:.9f}",
|
"metadata_score": f"{hit['metadata_score']:.9f}",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
)
|
||||||
)]
|
]
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
@ -248,9 +262,7 @@ def setup(config: ImageContentConfig, search_config: ImageSearchConfig, regenera
|
||||||
embeddings_file,
|
embeddings_file,
|
||||||
batch_size=config.batch_size,
|
batch_size=config.batch_size,
|
||||||
regenerate=regenerate,
|
regenerate=regenerate,
|
||||||
use_xmp_metadata=config.use_xmp_metadata)
|
use_xmp_metadata=config.use_xmp_metadata,
|
||||||
|
)
|
||||||
|
|
||||||
return ImageSearchModel(all_image_files,
|
return ImageSearchModel(all_image_files, image_embeddings, image_metadata_embeddings, encoder)
|
||||||
image_embeddings,
|
|
||||||
image_metadata_embeddings,
|
|
||||||
encoder)
|
|
||||||
|
|
|
@ -41,14 +41,16 @@ def initialize_model(search_config: TextSearchConfig):
|
||||||
model_dir=search_config.model_directory,
|
model_dir=search_config.model_directory,
|
||||||
model_name=search_config.encoder,
|
model_name=search_config.encoder,
|
||||||
model_type=search_config.encoder_type or SentenceTransformer,
|
model_type=search_config.encoder_type or SentenceTransformer,
|
||||||
device=f'{state.device}')
|
device=f"{state.device}",
|
||||||
|
)
|
||||||
|
|
||||||
# The cross-encoder re-ranks the results to improve quality
|
# The cross-encoder re-ranks the results to improve quality
|
||||||
cross_encoder = load_model(
|
cross_encoder = load_model(
|
||||||
model_dir=search_config.model_directory,
|
model_dir=search_config.model_directory,
|
||||||
model_name=search_config.cross_encoder,
|
model_name=search_config.cross_encoder,
|
||||||
model_type=CrossEncoder,
|
model_type=CrossEncoder,
|
||||||
device=f'{state.device}')
|
device=f"{state.device}",
|
||||||
|
)
|
||||||
|
|
||||||
return bi_encoder, cross_encoder, top_k
|
return bi_encoder, cross_encoder, top_k
|
||||||
|
|
||||||
|
@ -58,7 +60,9 @@ def extract_entries(jsonl_file) -> List[Entry]:
|
||||||
return list(map(Entry.from_dict, load_jsonl(jsonl_file)))
|
return list(map(Entry.from_dict, load_jsonl(jsonl_file)))
|
||||||
|
|
||||||
|
|
||||||
def compute_embeddings(entries_with_ids: List[Tuple[int, Entry]], bi_encoder: BaseEncoder, embeddings_file: Path, regenerate=False):
|
def compute_embeddings(
|
||||||
|
entries_with_ids: List[Tuple[int, Entry]], bi_encoder: BaseEncoder, embeddings_file: Path, regenerate=False
|
||||||
|
):
|
||||||
"Compute (and Save) Embeddings or Load Pre-Computed Embeddings"
|
"Compute (and Save) Embeddings or Load Pre-Computed Embeddings"
|
||||||
new_entries = []
|
new_entries = []
|
||||||
# Load pre-computed embeddings from file if exists and update them if required
|
# Load pre-computed embeddings from file if exists and update them if required
|
||||||
|
@ -69,17 +73,23 @@ def compute_embeddings(entries_with_ids: List[Tuple[int, Entry]], bi_encoder: Ba
|
||||||
# Encode any new entries in the corpus and update corpus embeddings
|
# Encode any new entries in the corpus and update corpus embeddings
|
||||||
new_entries = [entry.compiled for id, entry in entries_with_ids if id == -1]
|
new_entries = [entry.compiled for id, entry in entries_with_ids if id == -1]
|
||||||
if new_entries:
|
if new_entries:
|
||||||
new_embeddings = bi_encoder.encode(new_entries, convert_to_tensor=True, device=state.device, show_progress_bar=True)
|
new_embeddings = bi_encoder.encode(
|
||||||
|
new_entries, convert_to_tensor=True, device=state.device, show_progress_bar=True
|
||||||
|
)
|
||||||
existing_entry_ids = [id for id, _ in entries_with_ids if id != -1]
|
existing_entry_ids = [id for id, _ in entries_with_ids if id != -1]
|
||||||
if existing_entry_ids:
|
if existing_entry_ids:
|
||||||
existing_embeddings = torch.index_select(corpus_embeddings, 0, torch.tensor(existing_entry_ids, device=state.device))
|
existing_embeddings = torch.index_select(
|
||||||
|
corpus_embeddings, 0, torch.tensor(existing_entry_ids, device=state.device)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
existing_embeddings = torch.tensor([], device=state.device)
|
existing_embeddings = torch.tensor([], device=state.device)
|
||||||
corpus_embeddings = torch.cat([existing_embeddings, new_embeddings], dim=0)
|
corpus_embeddings = torch.cat([existing_embeddings, new_embeddings], dim=0)
|
||||||
# Else compute the corpus embeddings from scratch
|
# Else compute the corpus embeddings from scratch
|
||||||
else:
|
else:
|
||||||
new_entries = [entry.compiled for _, entry in entries_with_ids]
|
new_entries = [entry.compiled for _, entry in entries_with_ids]
|
||||||
corpus_embeddings = bi_encoder.encode(new_entries, convert_to_tensor=True, device=state.device, show_progress_bar=True)
|
corpus_embeddings = bi_encoder.encode(
|
||||||
|
new_entries, convert_to_tensor=True, device=state.device, show_progress_bar=True
|
||||||
|
)
|
||||||
|
|
||||||
# Save regenerated or updated embeddings to file
|
# Save regenerated or updated embeddings to file
|
||||||
if new_entries:
|
if new_entries:
|
||||||
|
@ -112,7 +122,9 @@ def query(raw_query: str, model: TextSearchModel, rank_results: bool = False) ->
|
||||||
|
|
||||||
# Find relevant entries for the query
|
# Find relevant entries for the query
|
||||||
with timer("Search Time", logger, state.device):
|
with timer("Search Time", logger, state.device):
|
||||||
hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=model.top_k, score_function=util.dot_score)[0]
|
hits = util.semantic_search(
|
||||||
|
question_embedding, corpus_embeddings, top_k=model.top_k, score_function=util.dot_score
|
||||||
|
)[0]
|
||||||
|
|
||||||
# Score all retrieved entries using the cross-encoder
|
# Score all retrieved entries using the cross-encoder
|
||||||
if rank_results:
|
if rank_results:
|
||||||
|
@ -128,26 +140,33 @@ def query(raw_query: str, model: TextSearchModel, rank_results: bool = False) ->
|
||||||
|
|
||||||
|
|
||||||
def collate_results(hits, entries: List[Entry], count=5) -> List[SearchResponse]:
|
def collate_results(hits, entries: List[Entry], count=5) -> List[SearchResponse]:
|
||||||
return [SearchResponse.parse_obj(
|
return [
|
||||||
|
SearchResponse.parse_obj(
|
||||||
{
|
{
|
||||||
"entry": entries[hit['corpus_id']].raw,
|
"entry": entries[hit["corpus_id"]].raw,
|
||||||
"score": f"{hit['cross-score'] if 'cross-score' in hit else hit['score']:.3f}",
|
"score": f"{hit['cross-score'] if 'cross-score' in hit else hit['score']:.3f}",
|
||||||
"additional": {
|
"additional": {"file": entries[hit["corpus_id"]].file, "compiled": entries[hit["corpus_id"]].compiled},
|
||||||
"file": entries[hit['corpus_id']].file,
|
|
||||||
"compiled": entries[hit['corpus_id']].compiled
|
|
||||||
}
|
}
|
||||||
})
|
)
|
||||||
for hit
|
for hit in hits[0:count]
|
||||||
in hits[0:count]]
|
]
|
||||||
|
|
||||||
|
|
||||||
def setup(text_to_jsonl: Type[TextToJsonl], config: TextContentConfig, search_config: TextSearchConfig, regenerate: bool, filters: List[BaseFilter] = []) -> TextSearchModel:
|
def setup(
|
||||||
|
text_to_jsonl: Type[TextToJsonl],
|
||||||
|
config: TextContentConfig,
|
||||||
|
search_config: TextSearchConfig,
|
||||||
|
regenerate: bool,
|
||||||
|
filters: List[BaseFilter] = [],
|
||||||
|
) -> TextSearchModel:
|
||||||
# Initialize Model
|
# Initialize Model
|
||||||
bi_encoder, cross_encoder, top_k = initialize_model(search_config)
|
bi_encoder, cross_encoder, top_k = initialize_model(search_config)
|
||||||
|
|
||||||
# Map notes in text files to (compressed) JSONL formatted file
|
# Map notes in text files to (compressed) JSONL formatted file
|
||||||
config.compressed_jsonl = resolve_absolute_path(config.compressed_jsonl)
|
config.compressed_jsonl = resolve_absolute_path(config.compressed_jsonl)
|
||||||
previous_entries = extract_entries(config.compressed_jsonl) if config.compressed_jsonl.exists() and not regenerate else None
|
previous_entries = (
|
||||||
|
extract_entries(config.compressed_jsonl) if config.compressed_jsonl.exists() and not regenerate else None
|
||||||
|
)
|
||||||
entries_with_indices = text_to_jsonl(config).process(previous_entries)
|
entries_with_indices = text_to_jsonl(config).process(previous_entries)
|
||||||
|
|
||||||
# Extract Updated Entries
|
# Extract Updated Entries
|
||||||
|
@ -158,7 +177,9 @@ def setup(text_to_jsonl: Type[TextToJsonl], config: TextContentConfig, search_co
|
||||||
|
|
||||||
# Compute or Load Embeddings
|
# Compute or Load Embeddings
|
||||||
config.embeddings_file = resolve_absolute_path(config.embeddings_file)
|
config.embeddings_file = resolve_absolute_path(config.embeddings_file)
|
||||||
corpus_embeddings = compute_embeddings(entries_with_indices, bi_encoder, config.embeddings_file, regenerate=regenerate)
|
corpus_embeddings = compute_embeddings(
|
||||||
|
entries_with_indices, bi_encoder, config.embeddings_file, regenerate=regenerate
|
||||||
|
)
|
||||||
|
|
||||||
for filter in filters:
|
for filter in filters:
|
||||||
filter.load(entries, regenerate=regenerate)
|
filter.load(entries, regenerate=regenerate)
|
||||||
|
@ -166,8 +187,10 @@ def setup(text_to_jsonl: Type[TextToJsonl], config: TextContentConfig, search_co
|
||||||
return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, filters, top_k)
|
return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, filters, top_k)
|
||||||
|
|
||||||
|
|
||||||
def apply_filters(query: str, entries: List[Entry], corpus_embeddings: torch.Tensor, filters: List[BaseFilter]) -> Tuple[str, List[Entry], torch.Tensor]:
|
def apply_filters(
|
||||||
'''Filter query, entries and embeddings before semantic search'''
|
query: str, entries: List[Entry], corpus_embeddings: torch.Tensor, filters: List[BaseFilter]
|
||||||
|
) -> Tuple[str, List[Entry], torch.Tensor]:
|
||||||
|
"""Filter query, entries and embeddings before semantic search"""
|
||||||
|
|
||||||
with timer("Total Filter Time", logger, state.device):
|
with timer("Total Filter Time", logger, state.device):
|
||||||
included_entry_indices = set(range(len(entries)))
|
included_entry_indices = set(range(len(entries)))
|
||||||
|
@ -178,45 +201,50 @@ def apply_filters(query: str, entries: List[Entry], corpus_embeddings: torch.Ten
|
||||||
|
|
||||||
# Get entries (and associated embeddings) satisfying all filters
|
# Get entries (and associated embeddings) satisfying all filters
|
||||||
if not included_entry_indices:
|
if not included_entry_indices:
|
||||||
return '', [], torch.tensor([], device=state.device)
|
return "", [], torch.tensor([], device=state.device)
|
||||||
else:
|
else:
|
||||||
entries = [entries[id] for id in included_entry_indices]
|
entries = [entries[id] for id in included_entry_indices]
|
||||||
corpus_embeddings = torch.index_select(corpus_embeddings, 0, torch.tensor(list(included_entry_indices), device=state.device))
|
corpus_embeddings = torch.index_select(
|
||||||
|
corpus_embeddings, 0, torch.tensor(list(included_entry_indices), device=state.device)
|
||||||
|
)
|
||||||
|
|
||||||
return query, entries, corpus_embeddings
|
return query, entries, corpus_embeddings
|
||||||
|
|
||||||
|
|
||||||
def cross_encoder_score(cross_encoder: CrossEncoder, query: str, entries: List[Entry], hits: List[dict]) -> List[dict]:
|
def cross_encoder_score(cross_encoder: CrossEncoder, query: str, entries: List[Entry], hits: List[dict]) -> List[dict]:
|
||||||
'''Score all retrieved entries using the cross-encoder'''
|
"""Score all retrieved entries using the cross-encoder"""
|
||||||
with timer("Cross-Encoder Predict Time", logger, state.device):
|
with timer("Cross-Encoder Predict Time", logger, state.device):
|
||||||
cross_inp = [[query, entries[hit['corpus_id']].compiled] for hit in hits]
|
cross_inp = [[query, entries[hit["corpus_id"]].compiled] for hit in hits]
|
||||||
cross_scores = cross_encoder.predict(cross_inp)
|
cross_scores = cross_encoder.predict(cross_inp)
|
||||||
|
|
||||||
# Store cross-encoder scores in results dictionary for ranking
|
# Store cross-encoder scores in results dictionary for ranking
|
||||||
for idx in range(len(cross_scores)):
|
for idx in range(len(cross_scores)):
|
||||||
hits[idx]['cross-score'] = cross_scores[idx]
|
hits[idx]["cross-score"] = cross_scores[idx]
|
||||||
|
|
||||||
return hits
|
return hits
|
||||||
|
|
||||||
|
|
||||||
def sort_results(rank_results: bool, hits: List[dict]) -> List[dict]:
|
def sort_results(rank_results: bool, hits: List[dict]) -> List[dict]:
|
||||||
'''Order results by cross-encoder score followed by bi-encoder score'''
|
"""Order results by cross-encoder score followed by bi-encoder score"""
|
||||||
with timer("Rank Time", logger, state.device):
|
with timer("Rank Time", logger, state.device):
|
||||||
hits.sort(key=lambda x: x['score'], reverse=True) # sort by bi-encoder score
|
hits.sort(key=lambda x: x["score"], reverse=True) # sort by bi-encoder score
|
||||||
if rank_results:
|
if rank_results:
|
||||||
hits.sort(key=lambda x: x['cross-score'], reverse=True) # sort by cross-encoder score
|
hits.sort(key=lambda x: x["cross-score"], reverse=True) # sort by cross-encoder score
|
||||||
return hits
|
return hits
|
||||||
|
|
||||||
|
|
||||||
def deduplicate_results(entries: List[Entry], hits: List[dict]) -> List[dict]:
|
def deduplicate_results(entries: List[Entry], hits: List[dict]) -> List[dict]:
|
||||||
'''Deduplicate entries by raw entry text before showing to users
|
"""Deduplicate entries by raw entry text before showing to users
|
||||||
Compiled entries are split by max tokens supported by ML models.
|
Compiled entries are split by max tokens supported by ML models.
|
||||||
This can result in duplicate hits, entries shown to user.'''
|
This can result in duplicate hits, entries shown to user."""
|
||||||
|
|
||||||
with timer("Deduplication Time", logger, state.device):
|
with timer("Deduplication Time", logger, state.device):
|
||||||
seen, original_hits_count = set(), len(hits)
|
seen, original_hits_count = set(), len(hits)
|
||||||
hits = [hit for hit in hits
|
hits = [
|
||||||
if entries[hit['corpus_id']].raw not in seen and not seen.add(entries[hit['corpus_id']].raw)] # type: ignore[func-returns-value]
|
hit
|
||||||
|
for hit in hits
|
||||||
|
if entries[hit["corpus_id"]].raw not in seen and not seen.add(entries[hit["corpus_id"]].raw) # type: ignore[func-returns-value]
|
||||||
|
]
|
||||||
duplicate_hits = original_hits_count - len(hits)
|
duplicate_hits = original_hits_count - len(hits)
|
||||||
|
|
||||||
logger.debug(f"Removed {duplicate_hits} duplicates")
|
logger.debug(f"Removed {duplicate_hits} duplicates")
|
||||||
|
|
|
@ -10,21 +10,36 @@ from khoj.utils.yaml import parse_config_from_file
|
||||||
|
|
||||||
def cli(args=None):
|
def cli(args=None):
|
||||||
# Setup Argument Parser for the Commandline Interface
|
# Setup Argument Parser for the Commandline Interface
|
||||||
parser = argparse.ArgumentParser(description="Start Khoj; A Natural Language Search Engine for your personal Notes, Transactions and Photos")
|
parser = argparse.ArgumentParser(
|
||||||
parser.add_argument('--config-file', '-c', default='~/.khoj/khoj.yml', type=pathlib.Path, help="YAML file to configure Khoj")
|
description="Start Khoj; A Natural Language Search Engine for your personal Notes, Transactions and Photos"
|
||||||
parser.add_argument('--no-gui', action='store_true', default=False, help="Do not show native desktop GUI. Default: false")
|
)
|
||||||
parser.add_argument('--regenerate', action='store_true', default=False, help="Regenerate model embeddings from source files. Default: false")
|
parser.add_argument(
|
||||||
parser.add_argument('--verbose', '-v', action='count', default=0, help="Show verbose conversion logs. Default: 0")
|
"--config-file", "-c", default="~/.khoj/khoj.yml", type=pathlib.Path, help="YAML file to configure Khoj"
|
||||||
parser.add_argument('--host', type=str, default='127.0.0.1', help="Host address of the server. Default: 127.0.0.1")
|
)
|
||||||
parser.add_argument('--port', '-p', type=int, default=8000, help="Port of the server. Default: 8000")
|
parser.add_argument(
|
||||||
parser.add_argument('--socket', type=pathlib.Path, help="Path to UNIX socket for server. Use to run server behind reverse proxy. Default: /tmp/uvicorn.sock")
|
"--no-gui", action="store_true", default=False, help="Do not show native desktop GUI. Default: false"
|
||||||
parser.add_argument('--version', '-V', action='store_true', help="Print the installed Khoj version and exit")
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--regenerate",
|
||||||
|
action="store_true",
|
||||||
|
default=False,
|
||||||
|
help="Regenerate model embeddings from source files. Default: false",
|
||||||
|
)
|
||||||
|
parser.add_argument("--verbose", "-v", action="count", default=0, help="Show verbose conversion logs. Default: 0")
|
||||||
|
parser.add_argument("--host", type=str, default="127.0.0.1", help="Host address of the server. Default: 127.0.0.1")
|
||||||
|
parser.add_argument("--port", "-p", type=int, default=8000, help="Port of the server. Default: 8000")
|
||||||
|
parser.add_argument(
|
||||||
|
"--socket",
|
||||||
|
type=pathlib.Path,
|
||||||
|
help="Path to UNIX socket for server. Use to run server behind reverse proxy. Default: /tmp/uvicorn.sock",
|
||||||
|
)
|
||||||
|
parser.add_argument("--version", "-V", action="store_true", help="Print the installed Khoj version and exit")
|
||||||
|
|
||||||
args = parser.parse_args(args)
|
args = parser.parse_args(args)
|
||||||
|
|
||||||
if args.version:
|
if args.version:
|
||||||
# Show version of khoj installed and exit
|
# Show version of khoj installed and exit
|
||||||
print(version('khoj-assistant'))
|
print(version("khoj-assistant"))
|
||||||
exit(0)
|
exit(0)
|
||||||
|
|
||||||
# Normalize config_file path to absolute path
|
# Normalize config_file path to absolute path
|
||||||
|
|
|
@ -28,8 +28,16 @@ class ProcessorType(str, Enum):
|
||||||
Conversation = "conversation"
|
Conversation = "conversation"
|
||||||
|
|
||||||
|
|
||||||
class TextSearchModel():
|
class TextSearchModel:
|
||||||
def __init__(self, entries: List[Entry], corpus_embeddings: torch.Tensor, bi_encoder: BaseEncoder, cross_encoder: CrossEncoder, filters: List[BaseFilter], top_k):
|
def __init__(
|
||||||
|
self,
|
||||||
|
entries: List[Entry],
|
||||||
|
corpus_embeddings: torch.Tensor,
|
||||||
|
bi_encoder: BaseEncoder,
|
||||||
|
cross_encoder: CrossEncoder,
|
||||||
|
filters: List[BaseFilter],
|
||||||
|
top_k,
|
||||||
|
):
|
||||||
self.entries = entries
|
self.entries = entries
|
||||||
self.corpus_embeddings = corpus_embeddings
|
self.corpus_embeddings = corpus_embeddings
|
||||||
self.bi_encoder = bi_encoder
|
self.bi_encoder = bi_encoder
|
||||||
|
@ -38,7 +46,7 @@ class TextSearchModel():
|
||||||
self.top_k = top_k
|
self.top_k = top_k
|
||||||
|
|
||||||
|
|
||||||
class ImageSearchModel():
|
class ImageSearchModel:
|
||||||
def __init__(self, image_names, image_embeddings, image_metadata_embeddings, image_encoder: BaseEncoder):
|
def __init__(self, image_names, image_embeddings, image_metadata_embeddings, image_encoder: BaseEncoder):
|
||||||
self.image_encoder = image_encoder
|
self.image_encoder = image_encoder
|
||||||
self.image_names = image_names
|
self.image_names = image_names
|
||||||
|
@ -48,7 +56,7 @@ class ImageSearchModel():
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SearchModels():
|
class SearchModels:
|
||||||
orgmode_search: TextSearchModel = None
|
orgmode_search: TextSearchModel = None
|
||||||
ledger_search: TextSearchModel = None
|
ledger_search: TextSearchModel = None
|
||||||
music_search: TextSearchModel = None
|
music_search: TextSearchModel = None
|
||||||
|
@ -56,15 +64,15 @@ class SearchModels():
|
||||||
image_search: ImageSearchModel = None
|
image_search: ImageSearchModel = None
|
||||||
|
|
||||||
|
|
||||||
class ConversationProcessorConfigModel():
|
class ConversationProcessorConfigModel:
|
||||||
def __init__(self, processor_config: ConversationProcessorConfig):
|
def __init__(self, processor_config: ConversationProcessorConfig):
|
||||||
self.openai_api_key = processor_config.openai_api_key
|
self.openai_api_key = processor_config.openai_api_key
|
||||||
self.model = processor_config.model
|
self.model = processor_config.model
|
||||||
self.conversation_logfile = Path(processor_config.conversation_logfile)
|
self.conversation_logfile = Path(processor_config.conversation_logfile)
|
||||||
self.chat_session = ''
|
self.chat_session = ""
|
||||||
self.meta_log: dict = {}
|
self.meta_log: dict = {}
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ProcessorConfigModel():
|
class ProcessorConfigModel:
|
||||||
conversation: ConversationProcessorConfigModel = None
|
conversation: ConversationProcessorConfigModel = None
|
||||||
|
|
|
@ -1,65 +1,62 @@
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
app_root_directory = Path(__file__).parent.parent.parent
|
app_root_directory = Path(__file__).parent.parent.parent
|
||||||
web_directory = app_root_directory / 'khoj/interface/web/'
|
web_directory = app_root_directory / "khoj/interface/web/"
|
||||||
empty_escape_sequences = '\n|\r|\t| '
|
empty_escape_sequences = "\n|\r|\t| "
|
||||||
|
|
||||||
# default app config to use
|
# default app config to use
|
||||||
default_config = {
|
default_config = {
|
||||||
'content-type': {
|
"content-type": {
|
||||||
'org': {
|
"org": {
|
||||||
'input-files': None,
|
"input-files": None,
|
||||||
'input-filter': None,
|
"input-filter": None,
|
||||||
'compressed-jsonl': '~/.khoj/content/org/org.jsonl.gz',
|
"compressed-jsonl": "~/.khoj/content/org/org.jsonl.gz",
|
||||||
'embeddings-file': '~/.khoj/content/org/org_embeddings.pt',
|
"embeddings-file": "~/.khoj/content/org/org_embeddings.pt",
|
||||||
'index_heading_entries': False
|
"index_heading_entries": False,
|
||||||
},
|
},
|
||||||
'markdown': {
|
"markdown": {
|
||||||
'input-files': None,
|
"input-files": None,
|
||||||
'input-filter': None,
|
"input-filter": None,
|
||||||
'compressed-jsonl': '~/.khoj/content/markdown/markdown.jsonl.gz',
|
"compressed-jsonl": "~/.khoj/content/markdown/markdown.jsonl.gz",
|
||||||
'embeddings-file': '~/.khoj/content/markdown/markdown_embeddings.pt'
|
"embeddings-file": "~/.khoj/content/markdown/markdown_embeddings.pt",
|
||||||
},
|
},
|
||||||
'ledger': {
|
"ledger": {
|
||||||
'input-files': None,
|
"input-files": None,
|
||||||
'input-filter': None,
|
"input-filter": None,
|
||||||
'compressed-jsonl': '~/.khoj/content/ledger/ledger.jsonl.gz',
|
"compressed-jsonl": "~/.khoj/content/ledger/ledger.jsonl.gz",
|
||||||
'embeddings-file': '~/.khoj/content/ledger/ledger_embeddings.pt'
|
"embeddings-file": "~/.khoj/content/ledger/ledger_embeddings.pt",
|
||||||
},
|
},
|
||||||
'image': {
|
"image": {
|
||||||
'input-directories': None,
|
"input-directories": None,
|
||||||
'input-filter': None,
|
"input-filter": None,
|
||||||
'embeddings-file': '~/.khoj/content/image/image_embeddings.pt',
|
"embeddings-file": "~/.khoj/content/image/image_embeddings.pt",
|
||||||
'batch-size': 50,
|
"batch-size": 50,
|
||||||
'use-xmp-metadata': False
|
"use-xmp-metadata": False,
|
||||||
},
|
},
|
||||||
'music': {
|
"music": {
|
||||||
'input-files': None,
|
"input-files": None,
|
||||||
'input-filter': None,
|
"input-filter": None,
|
||||||
'compressed-jsonl': '~/.khoj/content/music/music.jsonl.gz',
|
"compressed-jsonl": "~/.khoj/content/music/music.jsonl.gz",
|
||||||
'embeddings-file': '~/.khoj/content/music/music_embeddings.pt'
|
"embeddings-file": "~/.khoj/content/music/music_embeddings.pt",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"search-type": {
|
||||||
|
"symmetric": {
|
||||||
|
"encoder": "sentence-transformers/all-MiniLM-L6-v2",
|
||||||
|
"cross-encoder": "cross-encoder/ms-marco-MiniLM-L-6-v2",
|
||||||
|
"model_directory": "~/.khoj/search/symmetric/",
|
||||||
|
},
|
||||||
|
"asymmetric": {
|
||||||
|
"encoder": "sentence-transformers/multi-qa-MiniLM-L6-cos-v1",
|
||||||
|
"cross-encoder": "cross-encoder/ms-marco-MiniLM-L-6-v2",
|
||||||
|
"model_directory": "~/.khoj/search/asymmetric/",
|
||||||
|
},
|
||||||
|
"image": {"encoder": "sentence-transformers/clip-ViT-B-32", "model_directory": "~/.khoj/search/image/"},
|
||||||
|
},
|
||||||
|
"processor": {
|
||||||
|
"conversation": {
|
||||||
|
"openai-api-key": None,
|
||||||
|
"conversation-logfile": "~/.khoj/processor/conversation/conversation_logs.json",
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
'search-type': {
|
|
||||||
'symmetric': {
|
|
||||||
'encoder': 'sentence-transformers/all-MiniLM-L6-v2',
|
|
||||||
'cross-encoder': 'cross-encoder/ms-marco-MiniLM-L-6-v2',
|
|
||||||
'model_directory': '~/.khoj/search/symmetric/'
|
|
||||||
},
|
|
||||||
'asymmetric': {
|
|
||||||
'encoder': 'sentence-transformers/multi-qa-MiniLM-L6-cos-v1',
|
|
||||||
'cross-encoder': 'cross-encoder/ms-marco-MiniLM-L-6-v2',
|
|
||||||
'model_directory': '~/.khoj/search/asymmetric/'
|
|
||||||
},
|
|
||||||
'image': {
|
|
||||||
'encoder': 'sentence-transformers/clip-ViT-B-32',
|
|
||||||
'model_directory': '~/.khoj/search/image/'
|
|
||||||
}
|
|
||||||
},
|
|
||||||
'processor': {
|
|
||||||
'conversation': {
|
|
||||||
'openai-api-key': None,
|
|
||||||
'conversation-logfile': '~/.khoj/processor/conversation/conversation_logs.json'
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
|
@ -13,16 +13,17 @@ from typing import Optional, Union, TYPE_CHECKING
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
# External Packages
|
# External Packages
|
||||||
from sentence_transformers import CrossEncoder
|
from sentence_transformers import CrossEncoder
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from khoj.utils.models import BaseEncoder
|
from khoj.utils.models import BaseEncoder
|
||||||
|
|
||||||
|
|
||||||
def is_none_or_empty(item):
|
def is_none_or_empty(item):
|
||||||
return item == None or (hasattr(item, '__iter__') and len(item) == 0) or item == ''
|
return item == None or (hasattr(item, "__iter__") and len(item) == 0) or item == ""
|
||||||
|
|
||||||
|
|
||||||
def to_snake_case_from_dash(item: str):
|
def to_snake_case_from_dash(item: str):
|
||||||
return item.replace('_', '-')
|
return item.replace("_", "-")
|
||||||
|
|
||||||
|
|
||||||
def get_absolute_path(filepath: Union[str, Path]) -> str:
|
def get_absolute_path(filepath: Union[str, Path]) -> str:
|
||||||
|
@ -34,11 +35,11 @@ def resolve_absolute_path(filepath: Union[str, Optional[Path]], strict=False) ->
|
||||||
|
|
||||||
|
|
||||||
def get_from_dict(dictionary, *args):
|
def get_from_dict(dictionary, *args):
|
||||||
'''null-aware get from a nested dictionary
|
"""null-aware get from a nested dictionary
|
||||||
Returns: dictionary[args[0]][args[1]]... or None if any keys missing'''
|
Returns: dictionary[args[0]][args[1]]... or None if any keys missing"""
|
||||||
current = dictionary
|
current = dictionary
|
||||||
for arg in args:
|
for arg in args:
|
||||||
if not hasattr(current, '__iter__') or not arg in current:
|
if not hasattr(current, "__iter__") or not arg in current:
|
||||||
return None
|
return None
|
||||||
current = current[arg]
|
current = current[arg]
|
||||||
return current
|
return current
|
||||||
|
@ -74,17 +75,18 @@ def load_model(model_name: str, model_type, model_dir=None, device:str=None) ->
|
||||||
|
|
||||||
def is_pyinstaller_app():
|
def is_pyinstaller_app():
|
||||||
"Returns true if the app is running from Native GUI created by PyInstaller"
|
"Returns true if the app is running from Native GUI created by PyInstaller"
|
||||||
return getattr(sys, 'frozen', False) and hasattr(sys, '_MEIPASS')
|
return getattr(sys, "frozen", False) and hasattr(sys, "_MEIPASS")
|
||||||
|
|
||||||
|
|
||||||
def get_class_by_name(name: str) -> object:
|
def get_class_by_name(name: str) -> object:
|
||||||
"Returns the class object from name string"
|
"Returns the class object from name string"
|
||||||
module_name, class_name = name.rsplit('.', 1)
|
module_name, class_name = name.rsplit(".", 1)
|
||||||
return getattr(import_module(module_name), class_name)
|
return getattr(import_module(module_name), class_name)
|
||||||
|
|
||||||
|
|
||||||
class timer:
|
class timer:
|
||||||
'''Context manager to log time taken for a block of code to run'''
|
"""Context manager to log time taken for a block of code to run"""
|
||||||
|
|
||||||
def __init__(self, message: str, logger: logging.Logger, device: torch.device = None):
|
def __init__(self, message: str, logger: logging.Logger, device: torch.device = None):
|
||||||
self.message = message
|
self.message = message
|
||||||
self.logger = logger
|
self.logger = logger
|
||||||
|
|
|
@ -19,9 +19,9 @@ def load_jsonl(input_path):
|
||||||
|
|
||||||
# Open JSONL file
|
# Open JSONL file
|
||||||
if input_path.suffix == ".gz":
|
if input_path.suffix == ".gz":
|
||||||
jsonl_file = gzip.open(get_absolute_path(input_path), 'rt', encoding='utf-8')
|
jsonl_file = gzip.open(get_absolute_path(input_path), "rt", encoding="utf-8")
|
||||||
elif input_path.suffix == ".jsonl":
|
elif input_path.suffix == ".jsonl":
|
||||||
jsonl_file = open(get_absolute_path(input_path), 'r', encoding='utf-8')
|
jsonl_file = open(get_absolute_path(input_path), "r", encoding="utf-8")
|
||||||
|
|
||||||
# Read JSONL file
|
# Read JSONL file
|
||||||
for line in jsonl_file:
|
for line in jsonl_file:
|
||||||
|
@ -31,7 +31,7 @@ def load_jsonl(input_path):
|
||||||
jsonl_file.close()
|
jsonl_file.close()
|
||||||
|
|
||||||
# Log JSONL entries loaded
|
# Log JSONL entries loaded
|
||||||
logger.info(f'Loaded {len(data)} records from {input_path}')
|
logger.info(f"Loaded {len(data)} records from {input_path}")
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@ -41,17 +41,17 @@ def dump_jsonl(jsonl_data, output_path):
|
||||||
# Create output directory, if it doesn't exist
|
# Create output directory, if it doesn't exist
|
||||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
with open(output_path, 'w', encoding='utf-8') as f:
|
with open(output_path, "w", encoding="utf-8") as f:
|
||||||
f.write(jsonl_data)
|
f.write(jsonl_data)
|
||||||
|
|
||||||
logger.info(f'Wrote jsonl data to {output_path}')
|
logger.info(f"Wrote jsonl data to {output_path}")
|
||||||
|
|
||||||
|
|
||||||
def compress_jsonl_data(jsonl_data, output_path):
|
def compress_jsonl_data(jsonl_data, output_path):
|
||||||
# Create output directory, if it doesn't exist
|
# Create output directory, if it doesn't exist
|
||||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
with gzip.open(output_path, 'wt', encoding='utf-8') as gzip_file:
|
with gzip.open(output_path, "wt", encoding="utf-8") as gzip_file:
|
||||||
gzip_file.write(jsonl_data)
|
gzip_file.write(jsonl_data)
|
||||||
|
|
||||||
logger.info(f'Wrote jsonl data to gzip compressed jsonl at {output_path}')
|
logger.info(f"Wrote jsonl data to gzip compressed jsonl at {output_path}")
|
||||||
|
|
|
@ -13,17 +13,25 @@ from khoj.utils.state import processor_config, config_file
|
||||||
|
|
||||||
class BaseEncoder(ABC):
|
class BaseEncoder(ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def __init__(self, model_name: str, device: torch.device=None, **kwargs): ...
|
def __init__(self, model_name: str, device: torch.device = None, **kwargs):
|
||||||
|
...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def encode(self, entries: List[str], device:torch.device=None, **kwargs) -> torch.Tensor: ...
|
def encode(self, entries: List[str], device: torch.device = None, **kwargs) -> torch.Tensor:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
class OpenAI(BaseEncoder):
|
class OpenAI(BaseEncoder):
|
||||||
def __init__(self, model_name, device=None):
|
def __init__(self, model_name, device=None):
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
if not processor_config or not processor_config.conversation or not processor_config.conversation.openai_api_key:
|
if (
|
||||||
raise Exception(f"Set OpenAI API key under processor-config > conversation > openai-api-key in config file: {config_file}")
|
not processor_config
|
||||||
|
or not processor_config.conversation
|
||||||
|
or not processor_config.conversation.openai_api_key
|
||||||
|
):
|
||||||
|
raise Exception(
|
||||||
|
f"Set OpenAI API key under processor-config > conversation > openai-api-key in config file: {config_file}"
|
||||||
|
)
|
||||||
openai.api_key = processor_config.conversation.openai_api_key
|
openai.api_key = processor_config.conversation.openai_api_key
|
||||||
self.embedding_dimensions = None
|
self.embedding_dimensions = None
|
||||||
|
|
||||||
|
@ -32,7 +40,7 @@ class OpenAI(BaseEncoder):
|
||||||
|
|
||||||
for index in trange(0, len(entries)):
|
for index in trange(0, len(entries)):
|
||||||
# OpenAI models create better embeddings for entries without newlines
|
# OpenAI models create better embeddings for entries without newlines
|
||||||
processed_entry = entries[index].replace('\n', ' ')
|
processed_entry = entries[index].replace("\n", " ")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = openai.Embedding.create(input=processed_entry, model=self.model_name)
|
response = openai.Embedding.create(input=processed_entry, model=self.model_name)
|
||||||
|
@ -41,7 +49,9 @@ class OpenAI(BaseEncoder):
|
||||||
# Else default to embedding dimensions of the text-embedding-ada-002 model
|
# Else default to embedding dimensions of the text-embedding-ada-002 model
|
||||||
self.embedding_dimensions = len(response.data[0].embedding) if not self.embedding_dimensions else 1536
|
self.embedding_dimensions = len(response.data[0].embedding) if not self.embedding_dimensions else 1536
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Failed to encode entry {index} of length: {len(entries[index])}\n\n{entries[index][:1000]}...\n\n{e}")
|
print(
|
||||||
|
f"Failed to encode entry {index} of length: {len(entries[index])}\n\n{entries[index][:1000]}...\n\n{e}"
|
||||||
|
)
|
||||||
# Use zero embedding vector for entries with failed embeddings
|
# Use zero embedding vector for entries with failed embeddings
|
||||||
# This ensures entry embeddings match the order of the source entries
|
# This ensures entry embeddings match the order of the source entries
|
||||||
# And they have minimal similarity to other entries (as zero vectors are always orthogonal to other vector)
|
# And they have minimal similarity to other entries (as zero vectors are always orthogonal to other vector)
|
||||||
|
|
|
@ -9,11 +9,13 @@ from pydantic import BaseModel, validator
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from khoj.utils.helpers import to_snake_case_from_dash, is_none_or_empty
|
from khoj.utils.helpers import to_snake_case_from_dash, is_none_or_empty
|
||||||
|
|
||||||
|
|
||||||
class ConfigBase(BaseModel):
|
class ConfigBase(BaseModel):
|
||||||
class Config:
|
class Config:
|
||||||
alias_generator = to_snake_case_from_dash
|
alias_generator = to_snake_case_from_dash
|
||||||
allow_population_by_field_name = True
|
allow_population_by_field_name = True
|
||||||
|
|
||||||
|
|
||||||
class TextContentConfig(ConfigBase):
|
class TextContentConfig(ConfigBase):
|
||||||
input_files: Optional[List[Path]]
|
input_files: Optional[List[Path]]
|
||||||
input_filter: Optional[List[str]]
|
input_filter: Optional[List[str]]
|
||||||
|
@ -21,12 +23,15 @@ class TextContentConfig(ConfigBase):
|
||||||
embeddings_file: Path
|
embeddings_file: Path
|
||||||
index_heading_entries: Optional[bool] = False
|
index_heading_entries: Optional[bool] = False
|
||||||
|
|
||||||
@validator('input_filter')
|
@validator("input_filter")
|
||||||
def input_filter_or_files_required(cls, input_filter, values, **kwargs):
|
def input_filter_or_files_required(cls, input_filter, values, **kwargs):
|
||||||
if is_none_or_empty(input_filter) and ('input_files' not in values or values["input_files"] is None):
|
if is_none_or_empty(input_filter) and ("input_files" not in values or values["input_files"] is None):
|
||||||
raise ValueError("Either input_filter or input_files required in all content-type.<text_search> section of Khoj config file")
|
raise ValueError(
|
||||||
|
"Either input_filter or input_files required in all content-type.<text_search> section of Khoj config file"
|
||||||
|
)
|
||||||
return input_filter
|
return input_filter
|
||||||
|
|
||||||
|
|
||||||
class ImageContentConfig(ConfigBase):
|
class ImageContentConfig(ConfigBase):
|
||||||
input_directories: Optional[List[Path]]
|
input_directories: Optional[List[Path]]
|
||||||
input_filter: Optional[List[str]]
|
input_filter: Optional[List[str]]
|
||||||
|
@ -34,12 +39,17 @@ class ImageContentConfig(ConfigBase):
|
||||||
use_xmp_metadata: bool
|
use_xmp_metadata: bool
|
||||||
batch_size: int
|
batch_size: int
|
||||||
|
|
||||||
@validator('input_filter')
|
@validator("input_filter")
|
||||||
def input_filter_or_directories_required(cls, input_filter, values, **kwargs):
|
def input_filter_or_directories_required(cls, input_filter, values, **kwargs):
|
||||||
if is_none_or_empty(input_filter) and ('input_directories' not in values or values["input_directories"] is None):
|
if is_none_or_empty(input_filter) and (
|
||||||
raise ValueError("Either input_filter or input_directories required in all content-type.image section of Khoj config file")
|
"input_directories" not in values or values["input_directories"] is None
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"Either input_filter or input_directories required in all content-type.image section of Khoj config file"
|
||||||
|
)
|
||||||
return input_filter
|
return input_filter
|
||||||
|
|
||||||
|
|
||||||
class ContentConfig(ConfigBase):
|
class ContentConfig(ConfigBase):
|
||||||
org: Optional[TextContentConfig]
|
org: Optional[TextContentConfig]
|
||||||
ledger: Optional[TextContentConfig]
|
ledger: Optional[TextContentConfig]
|
||||||
|
@ -47,41 +57,49 @@ class ContentConfig(ConfigBase):
|
||||||
music: Optional[TextContentConfig]
|
music: Optional[TextContentConfig]
|
||||||
markdown: Optional[TextContentConfig]
|
markdown: Optional[TextContentConfig]
|
||||||
|
|
||||||
|
|
||||||
class TextSearchConfig(ConfigBase):
|
class TextSearchConfig(ConfigBase):
|
||||||
encoder: str
|
encoder: str
|
||||||
cross_encoder: str
|
cross_encoder: str
|
||||||
encoder_type: Optional[str]
|
encoder_type: Optional[str]
|
||||||
model_directory: Optional[Path]
|
model_directory: Optional[Path]
|
||||||
|
|
||||||
|
|
||||||
class ImageSearchConfig(ConfigBase):
|
class ImageSearchConfig(ConfigBase):
|
||||||
encoder: str
|
encoder: str
|
||||||
encoder_type: Optional[str]
|
encoder_type: Optional[str]
|
||||||
model_directory: Optional[Path]
|
model_directory: Optional[Path]
|
||||||
|
|
||||||
|
|
||||||
class SearchConfig(ConfigBase):
|
class SearchConfig(ConfigBase):
|
||||||
asymmetric: Optional[TextSearchConfig]
|
asymmetric: Optional[TextSearchConfig]
|
||||||
symmetric: Optional[TextSearchConfig]
|
symmetric: Optional[TextSearchConfig]
|
||||||
image: Optional[ImageSearchConfig]
|
image: Optional[ImageSearchConfig]
|
||||||
|
|
||||||
|
|
||||||
class ConversationProcessorConfig(ConfigBase):
|
class ConversationProcessorConfig(ConfigBase):
|
||||||
openai_api_key: str
|
openai_api_key: str
|
||||||
conversation_logfile: Path
|
conversation_logfile: Path
|
||||||
model: Optional[str] = "text-davinci-003"
|
model: Optional[str] = "text-davinci-003"
|
||||||
|
|
||||||
|
|
||||||
class ProcessorConfig(ConfigBase):
|
class ProcessorConfig(ConfigBase):
|
||||||
conversation: Optional[ConversationProcessorConfig]
|
conversation: Optional[ConversationProcessorConfig]
|
||||||
|
|
||||||
|
|
||||||
class FullConfig(ConfigBase):
|
class FullConfig(ConfigBase):
|
||||||
content_type: Optional[ContentConfig]
|
content_type: Optional[ContentConfig]
|
||||||
search_type: Optional[SearchConfig]
|
search_type: Optional[SearchConfig]
|
||||||
processor: Optional[ProcessorConfig]
|
processor: Optional[ProcessorConfig]
|
||||||
|
|
||||||
|
|
||||||
class SearchResponse(ConfigBase):
|
class SearchResponse(ConfigBase):
|
||||||
entry: str
|
entry: str
|
||||||
score: str
|
score: str
|
||||||
additional: Optional[dict]
|
additional: Optional[dict]
|
||||||
|
|
||||||
class Entry():
|
|
||||||
|
class Entry:
|
||||||
raw: str
|
raw: str
|
||||||
compiled: str
|
compiled: str
|
||||||
file: Optional[str]
|
file: Optional[str]
|
||||||
|
@ -99,8 +117,4 @@ class Entry():
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, dictionary: dict):
|
def from_dict(cls, dictionary: dict):
|
||||||
return cls(
|
return cls(raw=dictionary["raw"], compiled=dictionary["compiled"], file=dictionary.get("file", None))
|
||||||
raw=dictionary['raw'],
|
|
||||||
compiled=dictionary['compiled'],
|
|
||||||
file=dictionary.get('file', None)
|
|
||||||
)
|
|
||||||
|
|
|
@ -17,14 +17,14 @@ def save_config_to_file(yaml_config: dict, yaml_config_file: Path):
|
||||||
# Create output directory, if it doesn't exist
|
# Create output directory, if it doesn't exist
|
||||||
yaml_config_file.parent.mkdir(parents=True, exist_ok=True)
|
yaml_config_file.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
with open(yaml_config_file, 'w', encoding='utf-8') as config_file:
|
with open(yaml_config_file, "w", encoding="utf-8") as config_file:
|
||||||
yaml.safe_dump(yaml_config, config_file, allow_unicode=True)
|
yaml.safe_dump(yaml_config, config_file, allow_unicode=True)
|
||||||
|
|
||||||
|
|
||||||
def load_config_from_file(yaml_config_file: Path) -> dict:
|
def load_config_from_file(yaml_config_file: Path) -> dict:
|
||||||
"Read config from YML file"
|
"Read config from YML file"
|
||||||
config_from_file = None
|
config_from_file = None
|
||||||
with open(yaml_config_file, 'r', encoding='utf-8') as config_file:
|
with open(yaml_config_file, "r", encoding="utf-8") as config_file:
|
||||||
config_from_file = yaml.safe_load(config_file)
|
config_from_file = yaml.safe_load(config_file)
|
||||||
return config_from_file
|
return config_from_file
|
||||||
|
|
||||||
|
|
|
@ -6,59 +6,67 @@ import pytest
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from khoj.search_type import image_search, text_search
|
from khoj.search_type import image_search, text_search
|
||||||
from khoj.utils.helpers import resolve_absolute_path
|
from khoj.utils.helpers import resolve_absolute_path
|
||||||
from khoj.utils.rawconfig import ContentConfig, TextContentConfig, ImageContentConfig, SearchConfig, TextSearchConfig, ImageSearchConfig
|
from khoj.utils.rawconfig import (
|
||||||
|
ContentConfig,
|
||||||
|
TextContentConfig,
|
||||||
|
ImageContentConfig,
|
||||||
|
SearchConfig,
|
||||||
|
TextSearchConfig,
|
||||||
|
ImageSearchConfig,
|
||||||
|
)
|
||||||
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
|
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
|
||||||
from khoj.search_filter.date_filter import DateFilter
|
from khoj.search_filter.date_filter import DateFilter
|
||||||
from khoj.search_filter.word_filter import WordFilter
|
from khoj.search_filter.word_filter import WordFilter
|
||||||
from khoj.search_filter.file_filter import FileFilter
|
from khoj.search_filter.file_filter import FileFilter
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope='session')
|
@pytest.fixture(scope="session")
|
||||||
def search_config() -> SearchConfig:
|
def search_config() -> SearchConfig:
|
||||||
model_dir = resolve_absolute_path('~/.khoj/search')
|
model_dir = resolve_absolute_path("~/.khoj/search")
|
||||||
model_dir.mkdir(parents=True, exist_ok=True)
|
model_dir.mkdir(parents=True, exist_ok=True)
|
||||||
search_config = SearchConfig()
|
search_config = SearchConfig()
|
||||||
|
|
||||||
search_config.symmetric = TextSearchConfig(
|
search_config.symmetric = TextSearchConfig(
|
||||||
encoder="sentence-transformers/all-MiniLM-L6-v2",
|
encoder="sentence-transformers/all-MiniLM-L6-v2",
|
||||||
cross_encoder="cross-encoder/ms-marco-MiniLM-L-6-v2",
|
cross_encoder="cross-encoder/ms-marco-MiniLM-L-6-v2",
|
||||||
model_directory = model_dir / 'symmetric/'
|
model_directory=model_dir / "symmetric/",
|
||||||
)
|
)
|
||||||
|
|
||||||
search_config.asymmetric = TextSearchConfig(
|
search_config.asymmetric = TextSearchConfig(
|
||||||
encoder="sentence-transformers/multi-qa-MiniLM-L6-cos-v1",
|
encoder="sentence-transformers/multi-qa-MiniLM-L6-cos-v1",
|
||||||
cross_encoder="cross-encoder/ms-marco-MiniLM-L-6-v2",
|
cross_encoder="cross-encoder/ms-marco-MiniLM-L-6-v2",
|
||||||
model_directory = model_dir / 'asymmetric/'
|
model_directory=model_dir / "asymmetric/",
|
||||||
)
|
)
|
||||||
|
|
||||||
search_config.image = ImageSearchConfig(
|
search_config.image = ImageSearchConfig(
|
||||||
encoder = "sentence-transformers/clip-ViT-B-32",
|
encoder="sentence-transformers/clip-ViT-B-32", model_directory=model_dir / "image/"
|
||||||
model_directory = model_dir / 'image/'
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return search_config
|
return search_config
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope='session')
|
@pytest.fixture(scope="session")
|
||||||
def content_config(tmp_path_factory, search_config: SearchConfig):
|
def content_config(tmp_path_factory, search_config: SearchConfig):
|
||||||
content_dir = tmp_path_factory.mktemp('content')
|
content_dir = tmp_path_factory.mktemp("content")
|
||||||
|
|
||||||
# Generate Image Embeddings from Test Images
|
# Generate Image Embeddings from Test Images
|
||||||
content_config = ContentConfig()
|
content_config = ContentConfig()
|
||||||
content_config.image = ImageContentConfig(
|
content_config.image = ImageContentConfig(
|
||||||
input_directories = ['tests/data/images'],
|
input_directories=["tests/data/images"],
|
||||||
embeddings_file = content_dir.joinpath('image_embeddings.pt'),
|
embeddings_file=content_dir.joinpath("image_embeddings.pt"),
|
||||||
batch_size=1,
|
batch_size=1,
|
||||||
use_xmp_metadata = False)
|
use_xmp_metadata=False,
|
||||||
|
)
|
||||||
|
|
||||||
image_search.setup(content_config.image, search_config.image, regenerate=False)
|
image_search.setup(content_config.image, search_config.image, regenerate=False)
|
||||||
|
|
||||||
# Generate Notes Embeddings from Test Notes
|
# Generate Notes Embeddings from Test Notes
|
||||||
content_config.org = TextContentConfig(
|
content_config.org = TextContentConfig(
|
||||||
input_files=None,
|
input_files=None,
|
||||||
input_filter = ['tests/data/org/*.org'],
|
input_filter=["tests/data/org/*.org"],
|
||||||
compressed_jsonl = content_dir.joinpath('notes.jsonl.gz'),
|
compressed_jsonl=content_dir.joinpath("notes.jsonl.gz"),
|
||||||
embeddings_file = content_dir.joinpath('note_embeddings.pt'))
|
embeddings_file=content_dir.joinpath("note_embeddings.pt"),
|
||||||
|
)
|
||||||
|
|
||||||
filters = [DateFilter(), WordFilter(), FileFilter()]
|
filters = [DateFilter(), WordFilter(), FileFilter()]
|
||||||
text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters)
|
text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters)
|
||||||
|
@ -66,7 +74,7 @@ def content_config(tmp_path_factory, search_config: SearchConfig):
|
||||||
return content_config
|
return content_config
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope='function')
|
@pytest.fixture(scope="function")
|
||||||
def new_org_file(content_config: ContentConfig):
|
def new_org_file(content_config: ContentConfig):
|
||||||
# Setup
|
# Setup
|
||||||
new_org_file = Path(content_config.org.input_filter[0]).parent / "new_file.org"
|
new_org_file = Path(content_config.org.input_filter[0]).parent / "new_file.org"
|
||||||
|
@ -79,9 +87,9 @@ def new_org_file(content_config: ContentConfig):
|
||||||
new_org_file.unlink()
|
new_org_file.unlink()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope='function')
|
@pytest.fixture(scope="function")
|
||||||
def org_config_with_only_new_file(content_config: ContentConfig, new_org_file: Path):
|
def org_config_with_only_new_file(content_config: ContentConfig, new_org_file: Path):
|
||||||
new_org_config = deepcopy(content_config.org)
|
new_org_config = deepcopy(content_config.org)
|
||||||
new_org_config.input_files = [f'{new_org_file}']
|
new_org_config.input_files = [f"{new_org_file}"]
|
||||||
new_org_config.input_filter = None
|
new_org_config.input_filter = None
|
||||||
return new_org_config
|
return new_org_config
|
|
@ -8,10 +8,10 @@ from khoj.processor.ledger.beancount_to_jsonl import BeancountToJsonl
|
||||||
def test_no_transactions_in_file(tmp_path):
|
def test_no_transactions_in_file(tmp_path):
|
||||||
"Handle file with no transactions."
|
"Handle file with no transactions."
|
||||||
# Arrange
|
# Arrange
|
||||||
entry = f'''
|
entry = f"""
|
||||||
- Bullet point 1
|
- Bullet point 1
|
||||||
- Bullet point 2
|
- Bullet point 2
|
||||||
'''
|
"""
|
||||||
beancount_file = create_file(tmp_path, entry)
|
beancount_file = create_file(tmp_path, entry)
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
|
@ -20,7 +20,8 @@ def test_no_transactions_in_file(tmp_path):
|
||||||
|
|
||||||
# Process Each Entry from All Beancount Files
|
# Process Each Entry from All Beancount Files
|
||||||
jsonl_string = BeancountToJsonl.convert_transaction_maps_to_jsonl(
|
jsonl_string = BeancountToJsonl.convert_transaction_maps_to_jsonl(
|
||||||
BeancountToJsonl.convert_transactions_to_maps(entry_nodes, file_to_entries))
|
BeancountToJsonl.convert_transactions_to_maps(entry_nodes, file_to_entries)
|
||||||
|
)
|
||||||
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
|
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
|
@ -30,11 +31,11 @@ def test_no_transactions_in_file(tmp_path):
|
||||||
def test_single_beancount_transaction_to_jsonl(tmp_path):
|
def test_single_beancount_transaction_to_jsonl(tmp_path):
|
||||||
"Convert transaction from single file to jsonl."
|
"Convert transaction from single file to jsonl."
|
||||||
# Arrange
|
# Arrange
|
||||||
entry = f'''
|
entry = f"""
|
||||||
1984-04-01 * "Payee" "Narration"
|
1984-04-01 * "Payee" "Narration"
|
||||||
Expenses:Test:Test 1.00 KES
|
Expenses:Test:Test 1.00 KES
|
||||||
Assets:Test:Test -1.00 KES
|
Assets:Test:Test -1.00 KES
|
||||||
'''
|
"""
|
||||||
beancount_file = create_file(tmp_path, entry)
|
beancount_file = create_file(tmp_path, entry)
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
|
@ -43,7 +44,8 @@ Assets:Test:Test -1.00 KES
|
||||||
|
|
||||||
# Process Each Entry from All Beancount Files
|
# Process Each Entry from All Beancount Files
|
||||||
jsonl_string = BeancountToJsonl.convert_transaction_maps_to_jsonl(
|
jsonl_string = BeancountToJsonl.convert_transaction_maps_to_jsonl(
|
||||||
BeancountToJsonl.convert_transactions_to_maps(entries, entry_to_file_map))
|
BeancountToJsonl.convert_transactions_to_maps(entries, entry_to_file_map)
|
||||||
|
)
|
||||||
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
|
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
|
@ -53,7 +55,7 @@ Assets:Test:Test -1.00 KES
|
||||||
def test_multiple_transactions_to_jsonl(tmp_path):
|
def test_multiple_transactions_to_jsonl(tmp_path):
|
||||||
"Convert multiple transactions from single file to jsonl."
|
"Convert multiple transactions from single file to jsonl."
|
||||||
# Arrange
|
# Arrange
|
||||||
entry = f'''
|
entry = f"""
|
||||||
1984-04-01 * "Payee" "Narration"
|
1984-04-01 * "Payee" "Narration"
|
||||||
Expenses:Test:Test 1.00 KES
|
Expenses:Test:Test 1.00 KES
|
||||||
Assets:Test:Test -1.00 KES
|
Assets:Test:Test -1.00 KES
|
||||||
|
@ -61,7 +63,7 @@ Assets:Test:Test -1.00 KES
|
||||||
1984-04-01 * "Payee" "Narration"
|
1984-04-01 * "Payee" "Narration"
|
||||||
Expenses:Test:Test 1.00 KES
|
Expenses:Test:Test 1.00 KES
|
||||||
Assets:Test:Test -1.00 KES
|
Assets:Test:Test -1.00 KES
|
||||||
'''
|
"""
|
||||||
|
|
||||||
beancount_file = create_file(tmp_path, entry)
|
beancount_file = create_file(tmp_path, entry)
|
||||||
|
|
||||||
|
@ -71,7 +73,8 @@ Assets:Test:Test -1.00 KES
|
||||||
|
|
||||||
# Process Each Entry from All Beancount Files
|
# Process Each Entry from All Beancount Files
|
||||||
jsonl_string = BeancountToJsonl.convert_transaction_maps_to_jsonl(
|
jsonl_string = BeancountToJsonl.convert_transaction_maps_to_jsonl(
|
||||||
BeancountToJsonl.convert_transactions_to_maps(entries, entry_to_file_map))
|
BeancountToJsonl.convert_transactions_to_maps(entries, entry_to_file_map)
|
||||||
|
)
|
||||||
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
|
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
|
@ -95,8 +98,8 @@ def test_get_beancount_files(tmp_path):
|
||||||
expected_files = sorted(map(str, [group1_file1, group1_file2, group2_file1, group2_file2, file1]))
|
expected_files = sorted(map(str, [group1_file1, group1_file2, group2_file1, group2_file2, file1]))
|
||||||
|
|
||||||
# Setup input-files, input-filters
|
# Setup input-files, input-filters
|
||||||
input_files = [tmp_path / 'ledger.bean']
|
input_files = [tmp_path / "ledger.bean"]
|
||||||
input_filter = [tmp_path / 'group1*.bean', tmp_path / 'group2*.beancount']
|
input_filter = [tmp_path / "group1*.bean", tmp_path / "group2*.beancount"]
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
extracted_org_files = BeancountToJsonl.get_beancount_files(input_files, input_filter)
|
extracted_org_files = BeancountToJsonl.get_beancount_files(input_files, input_filter)
|
||||||
|
|
|
@ -6,7 +6,7 @@ from khoj.processor.conversation.gpt import converse, understand, message_to_pro
|
||||||
|
|
||||||
|
|
||||||
# Initialize variables for tests
|
# Initialize variables for tests
|
||||||
model = 'text-davinci-003'
|
model = "text-davinci-003"
|
||||||
api_key = None # Input your OpenAI API key to run the tests below
|
api_key = None # Input your OpenAI API key to run the tests below
|
||||||
|
|
||||||
|
|
||||||
|
@ -14,19 +14,22 @@ api_key = None # Input your OpenAI API key to run the tests below
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
def test_message_to_understand_prompt():
|
def test_message_to_understand_prompt():
|
||||||
# Arrange
|
# Arrange
|
||||||
understand_primer = "Extract information from each chat message\n\nremember(memory-type, data);\nmemory-type=[\"companion\", \"notes\", \"ledger\", \"image\", \"music\"]\nsearch(search-type, data);\nsearch-type=[\"google\", \"youtube\"]\ngenerate(activity);\nactivity=[\"paint\",\"write\", \"chat\"]\ntrigger-emotion(emotion);\nemotion=[\"happy\",\"confidence\",\"fear\",\"surprise\",\"sadness\",\"disgust\",\"anger\", \"curiosity\", \"calm\"]\n\nQ: How are you doing?\nA: activity(\"chat\"); trigger-emotion(\"surprise\")\nQ: Do you remember what I told you about my brother Antoine when we were at the beach?\nA: remember(\"notes\", \"Brother Antoine when we were at the beach\"); trigger-emotion(\"curiosity\");\nQ: what did we talk about last time?\nA: remember(\"notes\", \"talk last time\"); trigger-emotion(\"curiosity\");\nQ: Let's make some drawings!\nA: generate(\"paint\"); trigger-emotion(\"happy\");\nQ: Do you know anything about Lebanon?\nA: search(\"google\", \"lebanon\"); trigger-emotion(\"confidence\");\nQ: Find a video about a panda rolling in the grass\nA: search(\"youtube\",\"panda rolling in the grass\"); trigger-emotion(\"happy\"); \nQ: Tell me a scary story\nA: generate(\"write\" \"A story about some adventure\"); trigger-emotion(\"fear\");\nQ: What fiction book was I reading last week about AI starship?\nA: remember(\"notes\", \"read fiction book about AI starship last week\"); trigger-emotion(\"curiosity\");\nQ: How much did I spend at Subway for dinner last time?\nA: remember(\"ledger\", \"last Subway dinner\"); trigger-emotion(\"curiosity\");\nQ: I'm feeling sleepy\nA: activity(\"chat\"); trigger-emotion(\"calm\")\nQ: What was that popular Sri lankan song that Alex showed me recently?\nA: remember(\"music\", \"popular Sri lankan song that Alex showed recently\"); trigger-emotion(\"curiosity\"); \nQ: You're pretty funny!\nA: activity(\"chat\"); trigger-emotion(\"pride\")"
|
understand_primer = 'Extract information from each chat message\n\nremember(memory-type, data);\nmemory-type=["companion", "notes", "ledger", "image", "music"]\nsearch(search-type, data);\nsearch-type=["google", "youtube"]\ngenerate(activity);\nactivity=["paint","write", "chat"]\ntrigger-emotion(emotion);\nemotion=["happy","confidence","fear","surprise","sadness","disgust","anger", "curiosity", "calm"]\n\nQ: How are you doing?\nA: activity("chat"); trigger-emotion("surprise")\nQ: Do you remember what I told you about my brother Antoine when we were at the beach?\nA: remember("notes", "Brother Antoine when we were at the beach"); trigger-emotion("curiosity");\nQ: what did we talk about last time?\nA: remember("notes", "talk last time"); trigger-emotion("curiosity");\nQ: Let\'s make some drawings!\nA: generate("paint"); trigger-emotion("happy");\nQ: Do you know anything about Lebanon?\nA: search("google", "lebanon"); trigger-emotion("confidence");\nQ: Find a video about a panda rolling in the grass\nA: search("youtube","panda rolling in the grass"); trigger-emotion("happy"); \nQ: Tell me a scary story\nA: generate("write" "A story about some adventure"); trigger-emotion("fear");\nQ: What fiction book was I reading last week about AI starship?\nA: remember("notes", "read fiction book about AI starship last week"); trigger-emotion("curiosity");\nQ: How much did I spend at Subway for dinner last time?\nA: remember("ledger", "last Subway dinner"); trigger-emotion("curiosity");\nQ: I\'m feeling sleepy\nA: activity("chat"); trigger-emotion("calm")\nQ: What was that popular Sri lankan song that Alex showed me recently?\nA: remember("music", "popular Sri lankan song that Alex showed recently"); trigger-emotion("curiosity"); \nQ: You\'re pretty funny!\nA: activity("chat"); trigger-emotion("pride")'
|
||||||
expected_response = "Extract information from each chat message\n\nremember(memory-type, data);\nmemory-type=[\"companion\", \"notes\", \"ledger\", \"image\", \"music\"]\nsearch(search-type, data);\nsearch-type=[\"google\", \"youtube\"]\ngenerate(activity);\nactivity=[\"paint\",\"write\", \"chat\"]\ntrigger-emotion(emotion);\nemotion=[\"happy\",\"confidence\",\"fear\",\"surprise\",\"sadness\",\"disgust\",\"anger\", \"curiosity\", \"calm\"]\n\nQ: How are you doing?\nA: activity(\"chat\"); trigger-emotion(\"surprise\")\nQ: Do you remember what I told you about my brother Antoine when we were at the beach?\nA: remember(\"notes\", \"Brother Antoine when we were at the beach\"); trigger-emotion(\"curiosity\");\nQ: what did we talk about last time?\nA: remember(\"notes\", \"talk last time\"); trigger-emotion(\"curiosity\");\nQ: Let's make some drawings!\nA: generate(\"paint\"); trigger-emotion(\"happy\");\nQ: Do you know anything about Lebanon?\nA: search(\"google\", \"lebanon\"); trigger-emotion(\"confidence\");\nQ: Find a video about a panda rolling in the grass\nA: search(\"youtube\",\"panda rolling in the grass\"); trigger-emotion(\"happy\"); \nQ: Tell me a scary story\nA: generate(\"write\" \"A story about some adventure\"); trigger-emotion(\"fear\");\nQ: What fiction book was I reading last week about AI starship?\nA: remember(\"notes\", \"read fiction book about AI starship last week\"); trigger-emotion(\"curiosity\");\nQ: How much did I spend at Subway for dinner last time?\nA: remember(\"ledger\", \"last Subway dinner\"); trigger-emotion(\"curiosity\");\nQ: I'm feeling sleepy\nA: activity(\"chat\"); trigger-emotion(\"calm\")\nQ: What was that popular Sri lankan song that Alex showed me recently?\nA: remember(\"music\", \"popular Sri lankan song that Alex showed recently\"); trigger-emotion(\"curiosity\"); \nQ: You're pretty funny!\nA: activity(\"chat\"); trigger-emotion(\"pride\")\nQ: When did I last dine at Burger King?\nA:"
|
expected_response = 'Extract information from each chat message\n\nremember(memory-type, data);\nmemory-type=["companion", "notes", "ledger", "image", "music"]\nsearch(search-type, data);\nsearch-type=["google", "youtube"]\ngenerate(activity);\nactivity=["paint","write", "chat"]\ntrigger-emotion(emotion);\nemotion=["happy","confidence","fear","surprise","sadness","disgust","anger", "curiosity", "calm"]\n\nQ: How are you doing?\nA: activity("chat"); trigger-emotion("surprise")\nQ: Do you remember what I told you about my brother Antoine when we were at the beach?\nA: remember("notes", "Brother Antoine when we were at the beach"); trigger-emotion("curiosity");\nQ: what did we talk about last time?\nA: remember("notes", "talk last time"); trigger-emotion("curiosity");\nQ: Let\'s make some drawings!\nA: generate("paint"); trigger-emotion("happy");\nQ: Do you know anything about Lebanon?\nA: search("google", "lebanon"); trigger-emotion("confidence");\nQ: Find a video about a panda rolling in the grass\nA: search("youtube","panda rolling in the grass"); trigger-emotion("happy"); \nQ: Tell me a scary story\nA: generate("write" "A story about some adventure"); trigger-emotion("fear");\nQ: What fiction book was I reading last week about AI starship?\nA: remember("notes", "read fiction book about AI starship last week"); trigger-emotion("curiosity");\nQ: How much did I spend at Subway for dinner last time?\nA: remember("ledger", "last Subway dinner"); trigger-emotion("curiosity");\nQ: I\'m feeling sleepy\nA: activity("chat"); trigger-emotion("calm")\nQ: What was that popular Sri lankan song that Alex showed me recently?\nA: remember("music", "popular Sri lankan song that Alex showed recently"); trigger-emotion("curiosity"); \nQ: You\'re pretty funny!\nA: activity("chat"); trigger-emotion("pride")\nQ: When did I last dine at Burger King?\nA:'
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
actual_response = message_to_prompt("When did I last dine at Burger King?", understand_primer, start_sequence="\nA:", restart_sequence="\nQ:")
|
actual_response = message_to_prompt(
|
||||||
|
"When did I last dine at Burger King?", understand_primer, start_sequence="\nA:", restart_sequence="\nQ:"
|
||||||
|
)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert actual_response == expected_response
|
assert actual_response == expected_response
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
@pytest.mark.skipif(api_key is None,
|
@pytest.mark.skipif(
|
||||||
reason="Set api_key variable to your OpenAI API key from https://beta.openai.com/account/api-keys")
|
api_key is None, reason="Set api_key variable to your OpenAI API key from https://beta.openai.com/account/api-keys"
|
||||||
|
)
|
||||||
def test_minimal_chat_with_gpt():
|
def test_minimal_chat_with_gpt():
|
||||||
# Act
|
# Act
|
||||||
response = converse("What will happen when the stars go out?", model=model, api_key=api_key)
|
response = converse("What will happen when the stars go out?", model=model, api_key=api_key)
|
||||||
|
@ -36,21 +39,29 @@ def test_minimal_chat_with_gpt():
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
@pytest.mark.skipif(api_key is None,
|
@pytest.mark.skipif(
|
||||||
reason="Set api_key variable to your OpenAI API key from https://beta.openai.com/account/api-keys")
|
api_key is None, reason="Set api_key variable to your OpenAI API key from https://beta.openai.com/account/api-keys"
|
||||||
|
)
|
||||||
def test_chat_with_history():
|
def test_chat_with_history():
|
||||||
# Arrange
|
# Arrange
|
||||||
ai_prompt = "AI:"
|
ai_prompt = "AI:"
|
||||||
human_prompt = "Human:"
|
human_prompt = "Human:"
|
||||||
|
|
||||||
conversation_primer = f'''
|
conversation_primer = f"""
|
||||||
The following is a conversation with an AI assistant. The assistant is helpful, creative, clever, and very friendly companion.
|
The following is a conversation with an AI assistant. The assistant is helpful, creative, clever, and very friendly companion.
|
||||||
|
|
||||||
{human_prompt} Hello, I am Testatron. Who are you?
|
{human_prompt} Hello, I am Testatron. Who are you?
|
||||||
{ai_prompt} Hi, I am Khoj, an AI conversational companion created by OpenAI. How can I help you today?'''
|
{ai_prompt} Hi, I am Khoj, an AI conversational companion created by OpenAI. How can I help you today?"""
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
response = converse("Hi Khoj, What is my name?", model=model, conversation_history=conversation_primer, api_key=api_key, temperature=0, max_tokens=50)
|
response = converse(
|
||||||
|
"Hi Khoj, What is my name?",
|
||||||
|
model=model,
|
||||||
|
conversation_history=conversation_primer,
|
||||||
|
api_key=api_key,
|
||||||
|
temperature=0,
|
||||||
|
max_tokens=50,
|
||||||
|
)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert len(response) > 0
|
assert len(response) > 0
|
||||||
|
@ -58,12 +69,13 @@ The following is a conversation with an AI assistant. The assistant is helpful,
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
@pytest.mark.skipif(api_key is None,
|
@pytest.mark.skipif(
|
||||||
reason="Set api_key variable to your OpenAI API key from https://beta.openai.com/account/api-keys")
|
api_key is None, reason="Set api_key variable to your OpenAI API key from https://beta.openai.com/account/api-keys"
|
||||||
|
)
|
||||||
def test_understand_message_using_gpt():
|
def test_understand_message_using_gpt():
|
||||||
# Act
|
# Act
|
||||||
response = understand("When did I last dine at Subway?", model=model, api_key=api_key)
|
response = understand("When did I last dine at Subway?", model=model, api_key=api_key)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert len(response) > 0
|
assert len(response) > 0
|
||||||
assert response['intent']['memory-type'] == 'ledger'
|
assert response["intent"]["memory-type"] == "ledger"
|
||||||
|
|
|
@ -14,35 +14,37 @@ def test_cli_minimal_default():
|
||||||
actual_args = cli([])
|
actual_args = cli([])
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert actual_args.config_file == resolve_absolute_path(Path('~/.khoj/khoj.yml'))
|
assert actual_args.config_file == resolve_absolute_path(Path("~/.khoj/khoj.yml"))
|
||||||
assert actual_args.regenerate == False
|
assert actual_args.regenerate == False
|
||||||
assert actual_args.no_gui == False
|
assert actual_args.no_gui == False
|
||||||
assert actual_args.verbose == 0
|
assert actual_args.verbose == 0
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
def test_cli_invalid_config_file_path():
|
def test_cli_invalid_config_file_path():
|
||||||
# Arrange
|
# Arrange
|
||||||
non_existent_config_file = f"non-existent-khoj-{random()}.yml"
|
non_existent_config_file = f"non-existent-khoj-{random()}.yml"
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
actual_args = cli([f'-c={non_existent_config_file}'])
|
actual_args = cli([f"-c={non_existent_config_file}"])
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert actual_args.config_file == resolve_absolute_path(non_existent_config_file)
|
assert actual_args.config_file == resolve_absolute_path(non_existent_config_file)
|
||||||
assert actual_args.config == None
|
assert actual_args.config == None
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
def test_cli_config_from_file():
|
def test_cli_config_from_file():
|
||||||
# Act
|
# Act
|
||||||
actual_args = cli(['-c=tests/data/config.yml',
|
actual_args = cli(["-c=tests/data/config.yml", "--regenerate", "--no-gui", "-vvv"])
|
||||||
'--regenerate',
|
|
||||||
'--no-gui',
|
|
||||||
'-vvv'])
|
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert actual_args.config_file == resolve_absolute_path(Path('tests/data/config.yml'))
|
assert actual_args.config_file == resolve_absolute_path(Path("tests/data/config.yml"))
|
||||||
assert actual_args.no_gui == True
|
assert actual_args.no_gui == True
|
||||||
assert actual_args.regenerate == True
|
assert actual_args.regenerate == True
|
||||||
assert actual_args.config is not None
|
assert actual_args.config is not None
|
||||||
assert actual_args.config.content_type.org.input_files == [Path('~/first_from_config.org'), Path('~/second_from_config.org')]
|
assert actual_args.config.content_type.org.input_files == [
|
||||||
|
Path("~/first_from_config.org"),
|
||||||
|
Path("~/second_from_config.org"),
|
||||||
|
]
|
||||||
assert actual_args.verbose == 3
|
assert actual_args.verbose == 3
|
||||||
|
|
|
@ -21,6 +21,7 @@ from khoj.search_filter.file_filter import FileFilter
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
client = TestClient(app)
|
client = TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
# Test
|
# Test
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
def test_search_with_invalid_content_type():
|
def test_search_with_invalid_content_type():
|
||||||
|
@ -98,9 +99,11 @@ def test_image_search(content_config: ContentConfig, search_config: SearchConfig
|
||||||
config.content_type = content_config
|
config.content_type = content_config
|
||||||
config.search_type = search_config
|
config.search_type = search_config
|
||||||
model.image_search = image_search.setup(content_config.image, search_config.image, regenerate=False)
|
model.image_search = image_search.setup(content_config.image, search_config.image, regenerate=False)
|
||||||
query_expected_image_pairs = [("kitten", "kitten_park.jpg"),
|
query_expected_image_pairs = [
|
||||||
|
("kitten", "kitten_park.jpg"),
|
||||||
("a horse and dog on a leash", "horse_dog.jpg"),
|
("a horse and dog on a leash", "horse_dog.jpg"),
|
||||||
("A guinea pig eating grass", "guineapig_grass.jpg")]
|
("A guinea pig eating grass", "guineapig_grass.jpg"),
|
||||||
|
]
|
||||||
|
|
||||||
for query, expected_image_name in query_expected_image_pairs:
|
for query, expected_image_name in query_expected_image_pairs:
|
||||||
# Act
|
# Act
|
||||||
|
@ -135,7 +138,9 @@ def test_notes_search(content_config: ContentConfig, search_config: SearchConfig
|
||||||
def test_notes_search_with_only_filters(content_config: ContentConfig, search_config: SearchConfig):
|
def test_notes_search_with_only_filters(content_config: ContentConfig, search_config: SearchConfig):
|
||||||
# Arrange
|
# Arrange
|
||||||
filters = [WordFilter(), FileFilter()]
|
filters = [WordFilter(), FileFilter()]
|
||||||
model.orgmode_search = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters)
|
model.orgmode_search = text_search.setup(
|
||||||
|
OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters
|
||||||
|
)
|
||||||
user_query = quote('+"Emacs" file:"*.org"')
|
user_query = quote('+"Emacs" file:"*.org"')
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
|
@ -152,7 +157,9 @@ def test_notes_search_with_only_filters(content_config: ContentConfig, search_co
|
||||||
def test_notes_search_with_include_filter(content_config: ContentConfig, search_config: SearchConfig):
|
def test_notes_search_with_include_filter(content_config: ContentConfig, search_config: SearchConfig):
|
||||||
# Arrange
|
# Arrange
|
||||||
filters = [WordFilter()]
|
filters = [WordFilter()]
|
||||||
model.orgmode_search = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters)
|
model.orgmode_search = text_search.setup(
|
||||||
|
OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters
|
||||||
|
)
|
||||||
user_query = quote('How to git install application? +"Emacs"')
|
user_query = quote('How to git install application? +"Emacs"')
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
|
@ -169,7 +176,9 @@ def test_notes_search_with_include_filter(content_config: ContentConfig, search_
|
||||||
def test_notes_search_with_exclude_filter(content_config: ContentConfig, search_config: SearchConfig):
|
def test_notes_search_with_exclude_filter(content_config: ContentConfig, search_config: SearchConfig):
|
||||||
# Arrange
|
# Arrange
|
||||||
filters = [WordFilter()]
|
filters = [WordFilter()]
|
||||||
model.orgmode_search = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters)
|
model.orgmode_search = text_search.setup(
|
||||||
|
OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters
|
||||||
|
)
|
||||||
user_query = quote('How to git install application? -"clone"')
|
user_query = quote('How to git install application? -"clone"')
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
|
|
|
@ -10,53 +10,59 @@ from khoj.utils.rawconfig import Entry
|
||||||
|
|
||||||
def test_date_filter():
|
def test_date_filter():
|
||||||
entries = [
|
entries = [
|
||||||
Entry(compiled='', raw='Entry with no date'),
|
Entry(compiled="", raw="Entry with no date"),
|
||||||
Entry(compiled='', raw='April Fools entry: 1984-04-01'),
|
Entry(compiled="", raw="April Fools entry: 1984-04-01"),
|
||||||
Entry(compiled='', raw='Entry with date:1984-04-02')
|
Entry(compiled="", raw="Entry with date:1984-04-02"),
|
||||||
]
|
]
|
||||||
|
|
||||||
q_with_no_date_filter = 'head tail'
|
q_with_no_date_filter = "head tail"
|
||||||
ret_query, entry_indices = DateFilter().apply(q_with_no_date_filter, entries)
|
ret_query, entry_indices = DateFilter().apply(q_with_no_date_filter, entries)
|
||||||
assert ret_query == 'head tail'
|
assert ret_query == "head tail"
|
||||||
assert entry_indices == {0, 1, 2}
|
assert entry_indices == {0, 1, 2}
|
||||||
|
|
||||||
q_with_dtrange_non_overlapping_at_boundary = 'head dt>"1984-04-01" dt<"1984-04-02" tail'
|
q_with_dtrange_non_overlapping_at_boundary = 'head dt>"1984-04-01" dt<"1984-04-02" tail'
|
||||||
ret_query, entry_indices = DateFilter().apply(q_with_dtrange_non_overlapping_at_boundary, entries)
|
ret_query, entry_indices = DateFilter().apply(q_with_dtrange_non_overlapping_at_boundary, entries)
|
||||||
assert ret_query == 'head tail'
|
assert ret_query == "head tail"
|
||||||
assert entry_indices == set()
|
assert entry_indices == set()
|
||||||
|
|
||||||
query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<"1984-04-03" tail'
|
query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<"1984-04-03" tail'
|
||||||
ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries)
|
ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries)
|
||||||
assert ret_query == 'head tail'
|
assert ret_query == "head tail"
|
||||||
assert entry_indices == {2}
|
assert entry_indices == {2}
|
||||||
|
|
||||||
query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<"1984-04-02" tail'
|
query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<"1984-04-02" tail'
|
||||||
ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries)
|
ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries)
|
||||||
assert ret_query == 'head tail'
|
assert ret_query == "head tail"
|
||||||
assert entry_indices == {1}
|
assert entry_indices == {1}
|
||||||
|
|
||||||
query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<="1984-04-02" tail'
|
query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<="1984-04-02" tail'
|
||||||
ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries)
|
ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries)
|
||||||
assert ret_query == 'head tail'
|
assert ret_query == "head tail"
|
||||||
assert entry_indices == {2}
|
assert entry_indices == {2}
|
||||||
|
|
||||||
query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<="1984-04-02" tail'
|
query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<="1984-04-02" tail'
|
||||||
ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries)
|
ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries)
|
||||||
assert ret_query == 'head tail'
|
assert ret_query == "head tail"
|
||||||
assert entry_indices == {1, 2}
|
assert entry_indices == {1, 2}
|
||||||
|
|
||||||
|
|
||||||
def test_extract_date_range():
|
def test_extract_date_range():
|
||||||
assert DateFilter().extract_date_range('head dt>"1984-01-04" dt<"1984-01-07" tail') == [datetime(1984, 1, 5, 0, 0, 0).timestamp(), datetime(1984, 1, 7, 0, 0, 0).timestamp()]
|
assert DateFilter().extract_date_range('head dt>"1984-01-04" dt<"1984-01-07" tail') == [
|
||||||
|
datetime(1984, 1, 5, 0, 0, 0).timestamp(),
|
||||||
|
datetime(1984, 1, 7, 0, 0, 0).timestamp(),
|
||||||
|
]
|
||||||
assert DateFilter().extract_date_range('head dt<="1984-01-01"') == [0, datetime(1984, 1, 2, 0, 0, 0).timestamp()]
|
assert DateFilter().extract_date_range('head dt<="1984-01-01"') == [0, datetime(1984, 1, 2, 0, 0, 0).timestamp()]
|
||||||
assert DateFilter().extract_date_range('head dt>="1984-01-01"') == [datetime(1984, 1, 1, 0, 0, 0).timestamp(), inf]
|
assert DateFilter().extract_date_range('head dt>="1984-01-01"') == [datetime(1984, 1, 1, 0, 0, 0).timestamp(), inf]
|
||||||
assert DateFilter().extract_date_range('head dt:"1984-01-01"') == [datetime(1984, 1, 1, 0, 0, 0).timestamp(), datetime(1984, 1, 2, 0, 0, 0).timestamp()]
|
assert DateFilter().extract_date_range('head dt:"1984-01-01"') == [
|
||||||
|
datetime(1984, 1, 1, 0, 0, 0).timestamp(),
|
||||||
|
datetime(1984, 1, 2, 0, 0, 0).timestamp(),
|
||||||
|
]
|
||||||
|
|
||||||
# Unparseable date filter specified in query
|
# Unparseable date filter specified in query
|
||||||
assert DateFilter().extract_date_range('head dt:"Summer of 69" tail') == None
|
assert DateFilter().extract_date_range('head dt:"Summer of 69" tail') == None
|
||||||
|
|
||||||
# No date filter specified in query
|
# No date filter specified in query
|
||||||
assert DateFilter().extract_date_range('head tail') == None
|
assert DateFilter().extract_date_range("head tail") == None
|
||||||
|
|
||||||
# Non intersecting date ranges
|
# Non intersecting date ranges
|
||||||
assert DateFilter().extract_date_range('head dt>"1984-01-01" dt<"1984-01-01" tail') == None
|
assert DateFilter().extract_date_range('head dt>"1984-01-01" dt<"1984-01-01" tail') == None
|
||||||
|
@ -66,43 +72,79 @@ def test_parse():
|
||||||
test_now = datetime(1984, 4, 1, 21, 21, 21)
|
test_now = datetime(1984, 4, 1, 21, 21, 21)
|
||||||
|
|
||||||
# day variations
|
# day variations
|
||||||
assert DateFilter().parse('today', relative_base=test_now) == (datetime(1984, 4, 1, 0, 0, 0), datetime(1984, 4, 2, 0, 0, 0))
|
assert DateFilter().parse("today", relative_base=test_now) == (
|
||||||
assert DateFilter().parse('tomorrow', relative_base=test_now) == (datetime(1984, 4, 2, 0, 0, 0), datetime(1984, 4, 3, 0, 0, 0))
|
datetime(1984, 4, 1, 0, 0, 0),
|
||||||
assert DateFilter().parse('yesterday', relative_base=test_now) == (datetime(1984, 3, 31, 0, 0, 0), datetime(1984, 4, 1, 0, 0, 0))
|
datetime(1984, 4, 2, 0, 0, 0),
|
||||||
assert DateFilter().parse('5 days ago', relative_base=test_now) == (datetime(1984, 3, 27, 0, 0, 0), datetime(1984, 3, 28, 0, 0, 0))
|
)
|
||||||
|
assert DateFilter().parse("tomorrow", relative_base=test_now) == (
|
||||||
|
datetime(1984, 4, 2, 0, 0, 0),
|
||||||
|
datetime(1984, 4, 3, 0, 0, 0),
|
||||||
|
)
|
||||||
|
assert DateFilter().parse("yesterday", relative_base=test_now) == (
|
||||||
|
datetime(1984, 3, 31, 0, 0, 0),
|
||||||
|
datetime(1984, 4, 1, 0, 0, 0),
|
||||||
|
)
|
||||||
|
assert DateFilter().parse("5 days ago", relative_base=test_now) == (
|
||||||
|
datetime(1984, 3, 27, 0, 0, 0),
|
||||||
|
datetime(1984, 3, 28, 0, 0, 0),
|
||||||
|
)
|
||||||
|
|
||||||
# week variations
|
# week variations
|
||||||
assert DateFilter().parse('last week', relative_base=test_now) == (datetime(1984, 3, 18, 0, 0, 0), datetime(1984, 3, 25, 0, 0, 0))
|
assert DateFilter().parse("last week", relative_base=test_now) == (
|
||||||
assert DateFilter().parse('2 weeks ago', relative_base=test_now) == (datetime(1984, 3, 11, 0, 0, 0), datetime(1984, 3, 18, 0, 0, 0))
|
datetime(1984, 3, 18, 0, 0, 0),
|
||||||
|
datetime(1984, 3, 25, 0, 0, 0),
|
||||||
|
)
|
||||||
|
assert DateFilter().parse("2 weeks ago", relative_base=test_now) == (
|
||||||
|
datetime(1984, 3, 11, 0, 0, 0),
|
||||||
|
datetime(1984, 3, 18, 0, 0, 0),
|
||||||
|
)
|
||||||
|
|
||||||
# month variations
|
# month variations
|
||||||
assert DateFilter().parse('next month', relative_base=test_now) == (datetime(1984, 5, 1, 0, 0, 0), datetime(1984, 6, 1, 0, 0, 0))
|
assert DateFilter().parse("next month", relative_base=test_now) == (
|
||||||
assert DateFilter().parse('2 months ago', relative_base=test_now) == (datetime(1984, 2, 1, 0, 0, 0), datetime(1984, 3, 1, 0, 0, 0))
|
datetime(1984, 5, 1, 0, 0, 0),
|
||||||
|
datetime(1984, 6, 1, 0, 0, 0),
|
||||||
|
)
|
||||||
|
assert DateFilter().parse("2 months ago", relative_base=test_now) == (
|
||||||
|
datetime(1984, 2, 1, 0, 0, 0),
|
||||||
|
datetime(1984, 3, 1, 0, 0, 0),
|
||||||
|
)
|
||||||
|
|
||||||
# year variations
|
# year variations
|
||||||
assert DateFilter().parse('this year', relative_base=test_now) == (datetime(1984, 1, 1, 0, 0, 0), datetime(1985, 1, 1, 0, 0, 0))
|
assert DateFilter().parse("this year", relative_base=test_now) == (
|
||||||
assert DateFilter().parse('20 years later', relative_base=test_now) == (datetime(2004, 1, 1, 0, 0, 0), datetime(2005, 1, 1, 0, 0, 0))
|
datetime(1984, 1, 1, 0, 0, 0),
|
||||||
|
datetime(1985, 1, 1, 0, 0, 0),
|
||||||
|
)
|
||||||
|
assert DateFilter().parse("20 years later", relative_base=test_now) == (
|
||||||
|
datetime(2004, 1, 1, 0, 0, 0),
|
||||||
|
datetime(2005, 1, 1, 0, 0, 0),
|
||||||
|
)
|
||||||
|
|
||||||
# specific month/date variation
|
# specific month/date variation
|
||||||
assert DateFilter().parse('in august', relative_base=test_now) == (datetime(1983, 8, 1, 0, 0, 0), datetime(1983, 8, 2, 0, 0, 0))
|
assert DateFilter().parse("in august", relative_base=test_now) == (
|
||||||
assert DateFilter().parse('on 1983-08-01', relative_base=test_now) == (datetime(1983, 8, 1, 0, 0, 0), datetime(1983, 8, 2, 0, 0, 0))
|
datetime(1983, 8, 1, 0, 0, 0),
|
||||||
|
datetime(1983, 8, 2, 0, 0, 0),
|
||||||
|
)
|
||||||
|
assert DateFilter().parse("on 1983-08-01", relative_base=test_now) == (
|
||||||
|
datetime(1983, 8, 1, 0, 0, 0),
|
||||||
|
datetime(1983, 8, 2, 0, 0, 0),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_date_filter_regex():
|
def test_date_filter_regex():
|
||||||
dtrange_match = re.findall(DateFilter().date_regex, 'multi word head dt>"today" dt:"1984-01-01"')
|
dtrange_match = re.findall(DateFilter().date_regex, 'multi word head dt>"today" dt:"1984-01-01"')
|
||||||
assert dtrange_match == [('>', 'today'), (':', '1984-01-01')]
|
assert dtrange_match == [(">", "today"), (":", "1984-01-01")]
|
||||||
|
|
||||||
dtrange_match = re.findall(DateFilter().date_regex, 'head dt>"today" dt:"1984-01-01" multi word tail')
|
dtrange_match = re.findall(DateFilter().date_regex, 'head dt>"today" dt:"1984-01-01" multi word tail')
|
||||||
assert dtrange_match == [('>', 'today'), (':', '1984-01-01')]
|
assert dtrange_match == [(">", "today"), (":", "1984-01-01")]
|
||||||
|
|
||||||
dtrange_match = re.findall(DateFilter().date_regex, 'multi word head dt>="today" dt="1984-01-01"')
|
dtrange_match = re.findall(DateFilter().date_regex, 'multi word head dt>="today" dt="1984-01-01"')
|
||||||
assert dtrange_match == [('>=', 'today'), ('=', '1984-01-01')]
|
assert dtrange_match == [(">=", "today"), ("=", "1984-01-01")]
|
||||||
|
|
||||||
dtrange_match = re.findall(DateFilter().date_regex, 'dt<"multi word date" multi word tail')
|
dtrange_match = re.findall(DateFilter().date_regex, 'dt<"multi word date" multi word tail')
|
||||||
assert dtrange_match == [('<', 'multi word date')]
|
assert dtrange_match == [("<", "multi word date")]
|
||||||
|
|
||||||
dtrange_match = re.findall(DateFilter().date_regex, 'head dt<="multi word date"')
|
dtrange_match = re.findall(DateFilter().date_regex, 'head dt<="multi word date"')
|
||||||
assert dtrange_match == [('<=', 'multi word date')]
|
assert dtrange_match == [("<=", "multi word date")]
|
||||||
|
|
||||||
dtrange_match = re.findall(DateFilter().date_regex, 'head tail')
|
dtrange_match = re.findall(DateFilter().date_regex, "head tail")
|
||||||
assert dtrange_match == []
|
assert dtrange_match == []
|
|
@ -7,7 +7,7 @@ def test_no_file_filter():
|
||||||
# Arrange
|
# Arrange
|
||||||
file_filter = FileFilter()
|
file_filter = FileFilter()
|
||||||
entries = arrange_content()
|
entries = arrange_content()
|
||||||
q_with_no_filter = 'head tail'
|
q_with_no_filter = "head tail"
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
can_filter = file_filter.can_filter(q_with_no_filter)
|
can_filter = file_filter.can_filter(q_with_no_filter)
|
||||||
|
@ -15,7 +15,7 @@ def test_no_file_filter():
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert can_filter == False
|
assert can_filter == False
|
||||||
assert ret_query == 'head tail'
|
assert ret_query == "head tail"
|
||||||
assert entry_indices == {0, 1, 2, 3}
|
assert entry_indices == {0, 1, 2, 3}
|
||||||
|
|
||||||
|
|
||||||
|
@ -31,7 +31,7 @@ def test_file_filter_with_non_existent_file():
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert can_filter == True
|
assert can_filter == True
|
||||||
assert ret_query == 'head tail'
|
assert ret_query == "head tail"
|
||||||
assert entry_indices == {}
|
assert entry_indices == {}
|
||||||
|
|
||||||
|
|
||||||
|
@ -47,7 +47,7 @@ def test_single_file_filter():
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert can_filter == True
|
assert can_filter == True
|
||||||
assert ret_query == 'head tail'
|
assert ret_query == "head tail"
|
||||||
assert entry_indices == {0, 2}
|
assert entry_indices == {0, 2}
|
||||||
|
|
||||||
|
|
||||||
|
@ -63,7 +63,7 @@ def test_file_filter_with_partial_match():
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert can_filter == True
|
assert can_filter == True
|
||||||
assert ret_query == 'head tail'
|
assert ret_query == "head tail"
|
||||||
assert entry_indices == {0, 2}
|
assert entry_indices == {0, 2}
|
||||||
|
|
||||||
|
|
||||||
|
@ -79,7 +79,7 @@ def test_file_filter_with_regex_match():
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert can_filter == True
|
assert can_filter == True
|
||||||
assert ret_query == 'head tail'
|
assert ret_query == "head tail"
|
||||||
assert entry_indices == {0, 1, 2, 3}
|
assert entry_indices == {0, 1, 2, 3}
|
||||||
|
|
||||||
|
|
||||||
|
@ -95,16 +95,16 @@ def test_multiple_file_filter():
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert can_filter == True
|
assert can_filter == True
|
||||||
assert ret_query == 'head tail'
|
assert ret_query == "head tail"
|
||||||
assert entry_indices == {0, 1, 2, 3}
|
assert entry_indices == {0, 1, 2, 3}
|
||||||
|
|
||||||
|
|
||||||
def arrange_content():
|
def arrange_content():
|
||||||
entries = [
|
entries = [
|
||||||
Entry(compiled='', raw='First Entry', file= 'file 1.org'),
|
Entry(compiled="", raw="First Entry", file="file 1.org"),
|
||||||
Entry(compiled='', raw='Second Entry', file= 'file2.org'),
|
Entry(compiled="", raw="Second Entry", file="file2.org"),
|
||||||
Entry(compiled='', raw='Third Entry', file= 'file 1.org'),
|
Entry(compiled="", raw="Third Entry", file="file 1.org"),
|
||||||
Entry(compiled='', raw='Fourth Entry', file= 'file2.org')
|
Entry(compiled="", raw="Fourth Entry", file="file2.org"),
|
||||||
]
|
]
|
||||||
|
|
||||||
return entries
|
return entries
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
from khoj.utils import helpers
|
from khoj.utils import helpers
|
||||||
|
|
||||||
|
|
||||||
def test_get_from_null_dict():
|
def test_get_from_null_dict():
|
||||||
# null handling
|
# null handling
|
||||||
assert helpers.get_from_dict(dict()) == dict()
|
assert helpers.get_from_dict(dict()) == dict()
|
||||||
|
@ -7,39 +8,39 @@ def test_get_from_null_dict():
|
||||||
|
|
||||||
# key present in nested dictionary
|
# key present in nested dictionary
|
||||||
# 1-level dictionary
|
# 1-level dictionary
|
||||||
assert helpers.get_from_dict({'a': 1, 'b': 2}, 'a') == 1
|
assert helpers.get_from_dict({"a": 1, "b": 2}, "a") == 1
|
||||||
assert helpers.get_from_dict({'a': 1, 'b': 2}, 'c') == None
|
assert helpers.get_from_dict({"a": 1, "b": 2}, "c") == None
|
||||||
|
|
||||||
# 2-level dictionary
|
# 2-level dictionary
|
||||||
assert helpers.get_from_dict({'a': {'a_a': 1}, 'b': 2}, 'a') == {'a_a': 1}
|
assert helpers.get_from_dict({"a": {"a_a": 1}, "b": 2}, "a") == {"a_a": 1}
|
||||||
assert helpers.get_from_dict({'a': {'a_a': 1}, 'b': 2}, 'a', 'a_a') == 1
|
assert helpers.get_from_dict({"a": {"a_a": 1}, "b": 2}, "a", "a_a") == 1
|
||||||
|
|
||||||
# key not present in nested dictionary
|
# key not present in nested dictionary
|
||||||
# 2-level_dictionary
|
# 2-level_dictionary
|
||||||
assert helpers.get_from_dict({'a': {'a_a': 1}, 'b': 2}, 'b', 'b_a') == None
|
assert helpers.get_from_dict({"a": {"a_a": 1}, "b": 2}, "b", "b_a") == None
|
||||||
|
|
||||||
|
|
||||||
def test_merge_dicts():
|
def test_merge_dicts():
|
||||||
# basic merge of dicts with non-overlapping keys
|
# basic merge of dicts with non-overlapping keys
|
||||||
assert helpers.merge_dicts(priority_dict={'a': 1}, default_dict={'b': 2}) == {'a': 1, 'b': 2}
|
assert helpers.merge_dicts(priority_dict={"a": 1}, default_dict={"b": 2}) == {"a": 1, "b": 2}
|
||||||
|
|
||||||
# use default dict items when not present in priority dict
|
# use default dict items when not present in priority dict
|
||||||
assert helpers.merge_dicts(priority_dict={}, default_dict={'b': 2}) == {'b': 2}
|
assert helpers.merge_dicts(priority_dict={}, default_dict={"b": 2}) == {"b": 2}
|
||||||
|
|
||||||
# do not override existing key in priority_dict with default dict
|
# do not override existing key in priority_dict with default dict
|
||||||
assert helpers.merge_dicts(priority_dict={'a': 1}, default_dict={'a': 2}) == {'a': 1}
|
assert helpers.merge_dicts(priority_dict={"a": 1}, default_dict={"a": 2}) == {"a": 1}
|
||||||
|
|
||||||
|
|
||||||
def test_lru_cache():
|
def test_lru_cache():
|
||||||
# Test initializing cache
|
# Test initializing cache
|
||||||
cache = helpers.LRU({'a': 1, 'b': 2}, capacity=2)
|
cache = helpers.LRU({"a": 1, "b": 2}, capacity=2)
|
||||||
assert cache == {'a': 1, 'b': 2}
|
assert cache == {"a": 1, "b": 2}
|
||||||
|
|
||||||
# Test capacity overflow
|
# Test capacity overflow
|
||||||
cache['c'] = 3
|
cache["c"] = 3
|
||||||
assert cache == {'b': 2, 'c': 3}
|
assert cache == {"b": 2, "c": 3}
|
||||||
|
|
||||||
# Test delete least recently used item from LRU cache on capacity overflow
|
# Test delete least recently used item from LRU cache on capacity overflow
|
||||||
cache['b'] # accessing 'b' makes it the most recently used item
|
cache["b"] # accessing 'b' makes it the most recently used item
|
||||||
cache['d'] = 4 # so 'c' is deleted from the cache instead of 'b'
|
cache["d"] = 4 # so 'c' is deleted from the cache instead of 'b'
|
||||||
assert cache == {'b': 2, 'd': 4}
|
assert cache == {"b": 2, "d": 4}
|
||||||
|
|
|
@ -30,7 +30,8 @@ def test_image_metadata(content_config: ContentConfig):
|
||||||
expected_metadata_image_name_pairs = [
|
expected_metadata_image_name_pairs = [
|
||||||
(["Billi Ka Bacha.", "Cat", "Grass"], "kitten_park.jpg"),
|
(["Billi Ka Bacha.", "Cat", "Grass"], "kitten_park.jpg"),
|
||||||
(["Pasture.", "Horse", "Dog"], "horse_dog.jpg"),
|
(["Pasture.", "Horse", "Dog"], "horse_dog.jpg"),
|
||||||
(["Guinea Pig Eating Celery.", "Rodent", "Whiskers"], "guineapig_grass.jpg")]
|
(["Guinea Pig Eating Celery.", "Rodent", "Whiskers"], "guineapig_grass.jpg"),
|
||||||
|
]
|
||||||
|
|
||||||
test_image_paths = [
|
test_image_paths = [
|
||||||
Path(content_config.image.input_directories[0] / image_name[1])
|
Path(content_config.image.input_directories[0] / image_name[1])
|
||||||
|
@ -51,23 +52,23 @@ def test_image_search(content_config: ContentConfig, search_config: SearchConfig
|
||||||
# Arrange
|
# Arrange
|
||||||
output_directory = resolve_absolute_path(web_directory)
|
output_directory = resolve_absolute_path(web_directory)
|
||||||
model.image_search = image_search.setup(content_config.image, search_config.image, regenerate=False)
|
model.image_search = image_search.setup(content_config.image, search_config.image, regenerate=False)
|
||||||
query_expected_image_pairs = [("kitten", "kitten_park.jpg"),
|
query_expected_image_pairs = [
|
||||||
|
("kitten", "kitten_park.jpg"),
|
||||||
("horse and dog in a farm", "horse_dog.jpg"),
|
("horse and dog in a farm", "horse_dog.jpg"),
|
||||||
("A guinea pig eating grass", "guineapig_grass.jpg")]
|
("A guinea pig eating grass", "guineapig_grass.jpg"),
|
||||||
|
]
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
for query, expected_image_name in query_expected_image_pairs:
|
for query, expected_image_name in query_expected_image_pairs:
|
||||||
hits = image_search.query(
|
hits = image_search.query(query, count=1, model=model.image_search)
|
||||||
query,
|
|
||||||
count = 1,
|
|
||||||
model = model.image_search)
|
|
||||||
|
|
||||||
results = image_search.collate_results(
|
results = image_search.collate_results(
|
||||||
hits,
|
hits,
|
||||||
model.image_search.image_names,
|
model.image_search.image_names,
|
||||||
output_directory=output_directory,
|
output_directory=output_directory,
|
||||||
image_files_url='/static/images',
|
image_files_url="/static/images",
|
||||||
count=1)
|
count=1,
|
||||||
|
)
|
||||||
|
|
||||||
actual_image_path = output_directory.joinpath(Path(results[0].entry).name)
|
actual_image_path = output_directory.joinpath(Path(results[0].entry).name)
|
||||||
actual_image = Image.open(actual_image_path)
|
actual_image = Image.open(actual_image_path)
|
||||||
|
@ -92,10 +93,7 @@ def test_image_search_query_truncated(content_config: ContentConfig, search_conf
|
||||||
# Act
|
# Act
|
||||||
try:
|
try:
|
||||||
with caplog.at_level(logging.INFO, logger="khoj.search_type.image_search"):
|
with caplog.at_level(logging.INFO, logger="khoj.search_type.image_search"):
|
||||||
image_search.query(
|
image_search.query(query, count=1, model=model.image_search)
|
||||||
query,
|
|
||||||
count = 1,
|
|
||||||
model = model.image_search)
|
|
||||||
# Assert
|
# Assert
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
if "The size of tensor a (102) must match the size of tensor b (77)" in str(e):
|
if "The size of tensor a (102) must match the size of tensor b (77)" in str(e):
|
||||||
|
@ -115,17 +113,15 @@ def test_image_search_by_filepath(content_config: ContentConfig, search_config:
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
with caplog.at_level(logging.INFO, logger="khoj.search_type.image_search"):
|
with caplog.at_level(logging.INFO, logger="khoj.search_type.image_search"):
|
||||||
hits = image_search.query(
|
hits = image_search.query(query, count=1, model=model.image_search)
|
||||||
query,
|
|
||||||
count = 1,
|
|
||||||
model = model.image_search)
|
|
||||||
|
|
||||||
results = image_search.collate_results(
|
results = image_search.collate_results(
|
||||||
hits,
|
hits,
|
||||||
model.image_search.image_names,
|
model.image_search.image_names,
|
||||||
output_directory=output_directory,
|
output_directory=output_directory,
|
||||||
image_files_url='/static/images',
|
image_files_url="/static/images",
|
||||||
count=1)
|
count=1,
|
||||||
|
)
|
||||||
|
|
||||||
actual_image_path = output_directory.joinpath(Path(results[0].entry).name)
|
actual_image_path = output_directory.joinpath(Path(results[0].entry).name)
|
||||||
actual_image = Image.open(actual_image_path)
|
actual_image = Image.open(actual_image_path)
|
||||||
|
@ -133,7 +129,9 @@ def test_image_search_by_filepath(content_config: ContentConfig, search_config:
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
# Ensure file search triggered instead of query with file path as string
|
# Ensure file search triggered instead of query with file path as string
|
||||||
assert f"Find Images by Image: {resolve_absolute_path(expected_image_path)}" in caplog.text, "File search not triggered"
|
assert (
|
||||||
|
f"Find Images by Image: {resolve_absolute_path(expected_image_path)}" in caplog.text
|
||||||
|
), "File search not triggered"
|
||||||
# Ensure the correct image is returned
|
# Ensure the correct image is returned
|
||||||
assert expected_image == actual_image, "Incorrect image returned by file search"
|
assert expected_image == actual_image, "Incorrect image returned by file search"
|
||||||
|
|
||||||
|
|
|
@ -8,10 +8,10 @@ from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl
|
||||||
def test_markdown_file_with_no_headings_to_jsonl(tmp_path):
|
def test_markdown_file_with_no_headings_to_jsonl(tmp_path):
|
||||||
"Convert files with no heading to jsonl."
|
"Convert files with no heading to jsonl."
|
||||||
# Arrange
|
# Arrange
|
||||||
entry = f'''
|
entry = f"""
|
||||||
- Bullet point 1
|
- Bullet point 1
|
||||||
- Bullet point 2
|
- Bullet point 2
|
||||||
'''
|
"""
|
||||||
markdownfile = create_file(tmp_path, entry)
|
markdownfile = create_file(tmp_path, entry)
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
|
@ -20,7 +20,8 @@ def test_markdown_file_with_no_headings_to_jsonl(tmp_path):
|
||||||
|
|
||||||
# Process Each Entry from All Notes Files
|
# Process Each Entry from All Notes Files
|
||||||
jsonl_string = MarkdownToJsonl.convert_markdown_maps_to_jsonl(
|
jsonl_string = MarkdownToJsonl.convert_markdown_maps_to_jsonl(
|
||||||
MarkdownToJsonl.convert_markdown_entries_to_maps(entry_nodes, file_to_entries))
|
MarkdownToJsonl.convert_markdown_entries_to_maps(entry_nodes, file_to_entries)
|
||||||
|
)
|
||||||
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
|
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
|
@ -30,10 +31,10 @@ def test_markdown_file_with_no_headings_to_jsonl(tmp_path):
|
||||||
def test_single_markdown_entry_to_jsonl(tmp_path):
|
def test_single_markdown_entry_to_jsonl(tmp_path):
|
||||||
"Convert markdown entry from single file to jsonl."
|
"Convert markdown entry from single file to jsonl."
|
||||||
# Arrange
|
# Arrange
|
||||||
entry = f'''### Heading
|
entry = f"""### Heading
|
||||||
\t\r
|
\t\r
|
||||||
Body Line 1
|
Body Line 1
|
||||||
'''
|
"""
|
||||||
markdownfile = create_file(tmp_path, entry)
|
markdownfile = create_file(tmp_path, entry)
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
|
@ -42,7 +43,8 @@ def test_single_markdown_entry_to_jsonl(tmp_path):
|
||||||
|
|
||||||
# Process Each Entry from All Notes Files
|
# Process Each Entry from All Notes Files
|
||||||
jsonl_string = MarkdownToJsonl.convert_markdown_maps_to_jsonl(
|
jsonl_string = MarkdownToJsonl.convert_markdown_maps_to_jsonl(
|
||||||
MarkdownToJsonl.convert_markdown_entries_to_maps(entries, entry_to_file_map))
|
MarkdownToJsonl.convert_markdown_entries_to_maps(entries, entry_to_file_map)
|
||||||
|
)
|
||||||
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
|
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
|
@ -52,14 +54,14 @@ def test_single_markdown_entry_to_jsonl(tmp_path):
|
||||||
def test_multiple_markdown_entries_to_jsonl(tmp_path):
|
def test_multiple_markdown_entries_to_jsonl(tmp_path):
|
||||||
"Convert multiple markdown entries from single file to jsonl."
|
"Convert multiple markdown entries from single file to jsonl."
|
||||||
# Arrange
|
# Arrange
|
||||||
entry = f'''
|
entry = f"""
|
||||||
### Heading 1
|
### Heading 1
|
||||||
\t\r
|
\t\r
|
||||||
Heading 1 Body Line 1
|
Heading 1 Body Line 1
|
||||||
### Heading 2
|
### Heading 2
|
||||||
\t\r
|
\t\r
|
||||||
Heading 2 Body Line 2
|
Heading 2 Body Line 2
|
||||||
'''
|
"""
|
||||||
markdownfile = create_file(tmp_path, entry)
|
markdownfile = create_file(tmp_path, entry)
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
|
@ -68,7 +70,8 @@ def test_multiple_markdown_entries_to_jsonl(tmp_path):
|
||||||
|
|
||||||
# Process Each Entry from All Notes Files
|
# Process Each Entry from All Notes Files
|
||||||
jsonl_string = MarkdownToJsonl.convert_markdown_maps_to_jsonl(
|
jsonl_string = MarkdownToJsonl.convert_markdown_maps_to_jsonl(
|
||||||
MarkdownToJsonl.convert_markdown_entries_to_maps(entries, entry_to_file_map))
|
MarkdownToJsonl.convert_markdown_entries_to_maps(entries, entry_to_file_map)
|
||||||
|
)
|
||||||
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
|
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
|
@ -92,8 +95,8 @@ def test_get_markdown_files(tmp_path):
|
||||||
expected_files = sorted(map(str, [group1_file1, group1_file2, group2_file1, group2_file2, file1]))
|
expected_files = sorted(map(str, [group1_file1, group1_file2, group2_file1, group2_file2, file1]))
|
||||||
|
|
||||||
# Setup input-files, input-filters
|
# Setup input-files, input-filters
|
||||||
input_files = [tmp_path / 'notes.md']
|
input_files = [tmp_path / "notes.md"]
|
||||||
input_filter = [tmp_path / 'group1*.md', tmp_path / 'group2*.markdown']
|
input_filter = [tmp_path / "group1*.md", tmp_path / "group2*.markdown"]
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
extracted_org_files = MarkdownToJsonl.get_markdown_files(input_files, input_filter)
|
extracted_org_files = MarkdownToJsonl.get_markdown_files(input_files, input_filter)
|
||||||
|
@ -106,10 +109,10 @@ def test_get_markdown_files(tmp_path):
|
||||||
def test_extract_entries_with_different_level_headings(tmp_path):
|
def test_extract_entries_with_different_level_headings(tmp_path):
|
||||||
"Extract markdown entries with different level headings."
|
"Extract markdown entries with different level headings."
|
||||||
# Arrange
|
# Arrange
|
||||||
entry = f'''
|
entry = f"""
|
||||||
# Heading 1
|
# Heading 1
|
||||||
## Heading 2
|
## Heading 2
|
||||||
'''
|
"""
|
||||||
markdownfile = create_file(tmp_path, entry)
|
markdownfile = create_file(tmp_path, entry)
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
|
|
|
@ -9,23 +9,25 @@ from khoj.utils.rawconfig import Entry
|
||||||
|
|
||||||
|
|
||||||
def test_configure_heading_entry_to_jsonl(tmp_path):
|
def test_configure_heading_entry_to_jsonl(tmp_path):
|
||||||
'''Ensure entries with empty body are ignored, unless explicitly configured to index heading entries.
|
"""Ensure entries with empty body are ignored, unless explicitly configured to index heading entries.
|
||||||
Property drawers not considered Body. Ignore control characters for evaluating if Body empty.'''
|
Property drawers not considered Body. Ignore control characters for evaluating if Body empty."""
|
||||||
# Arrange
|
# Arrange
|
||||||
entry = f'''*** Heading
|
entry = f"""*** Heading
|
||||||
:PROPERTIES:
|
:PROPERTIES:
|
||||||
:ID: 42-42-42
|
:ID: 42-42-42
|
||||||
:END:
|
:END:
|
||||||
\t \r
|
\t \r
|
||||||
'''
|
"""
|
||||||
orgfile = create_file(tmp_path, entry)
|
orgfile = create_file(tmp_path, entry)
|
||||||
|
|
||||||
for index_heading_entries in [True, False]:
|
for index_heading_entries in [True, False]:
|
||||||
# Act
|
# Act
|
||||||
# Extract entries into jsonl from specified Org files
|
# Extract entries into jsonl from specified Org files
|
||||||
jsonl_string = OrgToJsonl.convert_org_entries_to_jsonl(OrgToJsonl.convert_org_nodes_to_entries(
|
jsonl_string = OrgToJsonl.convert_org_entries_to_jsonl(
|
||||||
*OrgToJsonl.extract_org_entries(org_files=[orgfile]),
|
OrgToJsonl.convert_org_nodes_to_entries(
|
||||||
index_heading_entries=index_heading_entries))
|
*OrgToJsonl.extract_org_entries(org_files=[orgfile]), index_heading_entries=index_heading_entries
|
||||||
|
)
|
||||||
|
)
|
||||||
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
|
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
|
@ -40,10 +42,10 @@ def test_configure_heading_entry_to_jsonl(tmp_path):
|
||||||
def test_entry_split_when_exceeds_max_words(tmp_path):
|
def test_entry_split_when_exceeds_max_words(tmp_path):
|
||||||
"Ensure entries with compiled words exceeding max_words are split."
|
"Ensure entries with compiled words exceeding max_words are split."
|
||||||
# Arrange
|
# Arrange
|
||||||
entry = f'''*** Heading
|
entry = f"""*** Heading
|
||||||
\t\r
|
\t\r
|
||||||
Body Line 1
|
Body Line 1
|
||||||
'''
|
"""
|
||||||
orgfile = create_file(tmp_path, entry)
|
orgfile = create_file(tmp_path, entry)
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
|
@ -53,8 +55,8 @@ def test_entry_split_when_exceeds_max_words(tmp_path):
|
||||||
# Split each entry from specified Org files by max words
|
# Split each entry from specified Org files by max words
|
||||||
jsonl_string = OrgToJsonl.convert_org_entries_to_jsonl(
|
jsonl_string = OrgToJsonl.convert_org_entries_to_jsonl(
|
||||||
TextToJsonl.split_entries_by_max_tokens(
|
TextToJsonl.split_entries_by_max_tokens(
|
||||||
OrgToJsonl.convert_org_nodes_to_entries(entries, entry_to_file_map),
|
OrgToJsonl.convert_org_nodes_to_entries(entries, entry_to_file_map), max_tokens=2
|
||||||
max_tokens = 2)
|
)
|
||||||
)
|
)
|
||||||
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
|
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
|
||||||
|
|
||||||
|
@ -65,10 +67,10 @@ def test_entry_split_when_exceeds_max_words(tmp_path):
|
||||||
def test_entry_split_drops_large_words(tmp_path):
|
def test_entry_split_drops_large_words(tmp_path):
|
||||||
"Ensure entries drops words larger than specified max word length from compiled version."
|
"Ensure entries drops words larger than specified max word length from compiled version."
|
||||||
# Arrange
|
# Arrange
|
||||||
entry_text = f'''*** Heading
|
entry_text = f"""*** Heading
|
||||||
\t\r
|
\t\r
|
||||||
Body Line 1
|
Body Line 1
|
||||||
'''
|
"""
|
||||||
entry = Entry(raw=entry_text, compiled=entry_text)
|
entry = Entry(raw=entry_text, compiled=entry_text)
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
|
@ -83,13 +85,13 @@ def test_entry_split_drops_large_words(tmp_path):
|
||||||
def test_entry_with_body_to_jsonl(tmp_path):
|
def test_entry_with_body_to_jsonl(tmp_path):
|
||||||
"Ensure entries with valid body text are loaded."
|
"Ensure entries with valid body text are loaded."
|
||||||
# Arrange
|
# Arrange
|
||||||
entry = f'''*** Heading
|
entry = f"""*** Heading
|
||||||
:PROPERTIES:
|
:PROPERTIES:
|
||||||
:ID: 42-42-42
|
:ID: 42-42-42
|
||||||
:END:
|
:END:
|
||||||
\t\r
|
\t\r
|
||||||
Body Line 1
|
Body Line 1
|
||||||
'''
|
"""
|
||||||
orgfile = create_file(tmp_path, entry)
|
orgfile = create_file(tmp_path, entry)
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
|
@ -97,7 +99,9 @@ def test_entry_with_body_to_jsonl(tmp_path):
|
||||||
entries, entry_to_file_map = OrgToJsonl.extract_org_entries(org_files=[orgfile])
|
entries, entry_to_file_map = OrgToJsonl.extract_org_entries(org_files=[orgfile])
|
||||||
|
|
||||||
# Process Each Entry from All Notes Files
|
# Process Each Entry from All Notes Files
|
||||||
jsonl_string = OrgToJsonl.convert_org_entries_to_jsonl(OrgToJsonl.convert_org_nodes_to_entries(entries, entry_to_file_map))
|
jsonl_string = OrgToJsonl.convert_org_entries_to_jsonl(
|
||||||
|
OrgToJsonl.convert_org_nodes_to_entries(entries, entry_to_file_map)
|
||||||
|
)
|
||||||
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
|
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
|
@ -107,10 +111,10 @@ def test_entry_with_body_to_jsonl(tmp_path):
|
||||||
def test_file_with_no_headings_to_jsonl(tmp_path):
|
def test_file_with_no_headings_to_jsonl(tmp_path):
|
||||||
"Ensure files with no heading, only body text are loaded."
|
"Ensure files with no heading, only body text are loaded."
|
||||||
# Arrange
|
# Arrange
|
||||||
entry = f'''
|
entry = f"""
|
||||||
- Bullet point 1
|
- Bullet point 1
|
||||||
- Bullet point 2
|
- Bullet point 2
|
||||||
'''
|
"""
|
||||||
orgfile = create_file(tmp_path, entry)
|
orgfile = create_file(tmp_path, entry)
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
|
@ -143,8 +147,8 @@ def test_get_org_files(tmp_path):
|
||||||
expected_files = sorted(map(str, [group1_file1, group1_file2, group2_file1, group2_file2, orgfile1]))
|
expected_files = sorted(map(str, [group1_file1, group1_file2, group2_file1, group2_file2, orgfile1]))
|
||||||
|
|
||||||
# Setup input-files, input-filters
|
# Setup input-files, input-filters
|
||||||
input_files = [tmp_path / 'orgfile1.org']
|
input_files = [tmp_path / "orgfile1.org"]
|
||||||
input_filter = [tmp_path / 'group1*.org', tmp_path / 'group2*.org']
|
input_filter = [tmp_path / "group1*.org", tmp_path / "group2*.org"]
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
extracted_org_files = OrgToJsonl.get_org_files(input_files, input_filter)
|
extracted_org_files = OrgToJsonl.get_org_files(input_files, input_filter)
|
||||||
|
@ -157,10 +161,10 @@ def test_get_org_files(tmp_path):
|
||||||
def test_extract_entries_with_different_level_headings(tmp_path):
|
def test_extract_entries_with_different_level_headings(tmp_path):
|
||||||
"Extract org entries with different level headings."
|
"Extract org entries with different level headings."
|
||||||
# Arrange
|
# Arrange
|
||||||
entry = f'''
|
entry = f"""
|
||||||
* Heading 1
|
* Heading 1
|
||||||
** Heading 2
|
** Heading 2
|
||||||
'''
|
"""
|
||||||
orgfile = create_file(tmp_path, entry)
|
orgfile = create_file(tmp_path, entry)
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
|
@ -169,8 +173,8 @@ def test_extract_entries_with_different_level_headings(tmp_path):
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert len(entries) == 2
|
assert len(entries) == 2
|
||||||
assert f'{entries[0]}'.startswith("* Heading 1")
|
assert f"{entries[0]}".startswith("* Heading 1")
|
||||||
assert f'{entries[1]}'.startswith("** Heading 2")
|
assert f"{entries[1]}".startswith("** Heading 2")
|
||||||
|
|
||||||
|
|
||||||
# Helper Functions
|
# Helper Functions
|
||||||
|
|
|
@ -10,7 +10,7 @@ from khoj.processor.org_mode import orgnode
|
||||||
def test_parse_entry_with_no_headings(tmp_path):
|
def test_parse_entry_with_no_headings(tmp_path):
|
||||||
"Test parsing of entry with minimal fields"
|
"Test parsing of entry with minimal fields"
|
||||||
# Arrange
|
# Arrange
|
||||||
entry = f'''Body Line 1'''
|
entry = f"""Body Line 1"""
|
||||||
orgfile = create_file(tmp_path, entry)
|
orgfile = create_file(tmp_path, entry)
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
|
@ -18,7 +18,7 @@ def test_parse_entry_with_no_headings(tmp_path):
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert len(entries) == 1
|
assert len(entries) == 1
|
||||||
assert entries[0].heading == f'{orgfile}'
|
assert entries[0].heading == f"{orgfile}"
|
||||||
assert entries[0].tags == list()
|
assert entries[0].tags == list()
|
||||||
assert entries[0].body == "Body Line 1"
|
assert entries[0].body == "Body Line 1"
|
||||||
assert entries[0].priority == ""
|
assert entries[0].priority == ""
|
||||||
|
@ -32,9 +32,9 @@ def test_parse_entry_with_no_headings(tmp_path):
|
||||||
def test_parse_minimal_entry(tmp_path):
|
def test_parse_minimal_entry(tmp_path):
|
||||||
"Test parsing of entry with minimal fields"
|
"Test parsing of entry with minimal fields"
|
||||||
# Arrange
|
# Arrange
|
||||||
entry = f'''
|
entry = f"""
|
||||||
* Heading
|
* Heading
|
||||||
Body Line 1'''
|
Body Line 1"""
|
||||||
orgfile = create_file(tmp_path, entry)
|
orgfile = create_file(tmp_path, entry)
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
|
@ -56,7 +56,7 @@ Body Line 1'''
|
||||||
def test_parse_complete_entry(tmp_path):
|
def test_parse_complete_entry(tmp_path):
|
||||||
"Test parsing of entry with all important fields"
|
"Test parsing of entry with all important fields"
|
||||||
# Arrange
|
# Arrange
|
||||||
entry = f'''
|
entry = f"""
|
||||||
*** DONE [#A] Heading :Tag1:TAG2:tag3:
|
*** DONE [#A] Heading :Tag1:TAG2:tag3:
|
||||||
CLOSED: [1984-04-01 Sun 12:00] SCHEDULED: <1984-04-01 Sun 09:00> DEADLINE: <1984-04-01 Sun>
|
CLOSED: [1984-04-01 Sun 12:00] SCHEDULED: <1984-04-01 Sun 09:00> DEADLINE: <1984-04-01 Sun>
|
||||||
:PROPERTIES:
|
:PROPERTIES:
|
||||||
|
@ -67,7 +67,7 @@ CLOCK: [1984-04-01 Sun 09:00]--[1984-04-01 Sun 12:00] => 3:00
|
||||||
- Clocked Log 1
|
- Clocked Log 1
|
||||||
:END:
|
:END:
|
||||||
Body Line 1
|
Body Line 1
|
||||||
Body Line 2'''
|
Body Line 2"""
|
||||||
orgfile = create_file(tmp_path, entry)
|
orgfile = create_file(tmp_path, entry)
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
|
@ -91,35 +91,35 @@ Body Line 2'''
|
||||||
def test_render_entry_with_property_drawer_and_empty_body(tmp_path):
|
def test_render_entry_with_property_drawer_and_empty_body(tmp_path):
|
||||||
"Render heading entry with property drawer"
|
"Render heading entry with property drawer"
|
||||||
# Arrange
|
# Arrange
|
||||||
entry_to_render = f'''
|
entry_to_render = f"""
|
||||||
*** [#A] Heading1 :tag1:
|
*** [#A] Heading1 :tag1:
|
||||||
:PROPERTIES:
|
:PROPERTIES:
|
||||||
:ID: 111-111-111-1111-1111
|
:ID: 111-111-111-1111-1111
|
||||||
:END:
|
:END:
|
||||||
\t\r \n
|
\t\r \n
|
||||||
'''
|
"""
|
||||||
orgfile = create_file(tmp_path, entry_to_render)
|
orgfile = create_file(tmp_path, entry_to_render)
|
||||||
|
|
||||||
expected_entry = f'''*** [#A] Heading1 :tag1:
|
expected_entry = f"""*** [#A] Heading1 :tag1:
|
||||||
:PROPERTIES:
|
:PROPERTIES:
|
||||||
:LINE: file:{orgfile}::2
|
:LINE: file:{orgfile}::2
|
||||||
:ID: id:111-111-111-1111-1111
|
:ID: id:111-111-111-1111-1111
|
||||||
:SOURCE: [[file:{orgfile}::*Heading1]]
|
:SOURCE: [[file:{orgfile}::*Heading1]]
|
||||||
:END:
|
:END:
|
||||||
'''
|
"""
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
parsed_entries = orgnode.makelist(orgfile)
|
parsed_entries = orgnode.makelist(orgfile)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert f'{parsed_entries[0]}' == expected_entry
|
assert f"{parsed_entries[0]}" == expected_entry
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
def test_all_links_to_entry_rendered(tmp_path):
|
def test_all_links_to_entry_rendered(tmp_path):
|
||||||
"Ensure all links to entry rendered in property drawer from entry"
|
"Ensure all links to entry rendered in property drawer from entry"
|
||||||
# Arrange
|
# Arrange
|
||||||
entry = f'''
|
entry = f"""
|
||||||
*** [#A] Heading :tag1:
|
*** [#A] Heading :tag1:
|
||||||
:PROPERTIES:
|
:PROPERTIES:
|
||||||
:ID: 123-456-789-4234-1231
|
:ID: 123-456-789-4234-1231
|
||||||
|
@ -127,7 +127,7 @@ def test_all_links_to_entry_rendered(tmp_path):
|
||||||
Body Line 1
|
Body Line 1
|
||||||
*** Heading2
|
*** Heading2
|
||||||
Body Line 2
|
Body Line 2
|
||||||
'''
|
"""
|
||||||
orgfile = create_file(tmp_path, entry)
|
orgfile = create_file(tmp_path, entry)
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
|
@ -135,23 +135,23 @@ Body Line 2
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
# SOURCE link rendered with Heading
|
# SOURCE link rendered with Heading
|
||||||
assert f':SOURCE: [[file:{orgfile}::*{entries[0].heading}]]' in f'{entries[0]}'
|
assert f":SOURCE: [[file:{orgfile}::*{entries[0].heading}]]" in f"{entries[0]}"
|
||||||
# ID link rendered with ID
|
# ID link rendered with ID
|
||||||
assert f':ID: id:123-456-789-4234-1231' in f'{entries[0]}'
|
assert f":ID: id:123-456-789-4234-1231" in f"{entries[0]}"
|
||||||
# LINE link rendered with line number
|
# LINE link rendered with line number
|
||||||
assert f':LINE: file:{orgfile}::2' in f'{entries[0]}'
|
assert f":LINE: file:{orgfile}::2" in f"{entries[0]}"
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
def test_source_link_to_entry_escaped_for_rendering(tmp_path):
|
def test_source_link_to_entry_escaped_for_rendering(tmp_path):
|
||||||
"Test SOURCE link renders with square brackets in filename, heading escaped for org-mode rendering"
|
"Test SOURCE link renders with square brackets in filename, heading escaped for org-mode rendering"
|
||||||
# Arrange
|
# Arrange
|
||||||
entry = f'''
|
entry = f"""
|
||||||
*** [#A] Heading[1] :tag1:
|
*** [#A] Heading[1] :tag1:
|
||||||
:PROPERTIES:
|
:PROPERTIES:
|
||||||
:ID: 123-456-789-4234-1231
|
:ID: 123-456-789-4234-1231
|
||||||
:END:
|
:END:
|
||||||
Body Line 1'''
|
Body Line 1"""
|
||||||
orgfile = create_file(tmp_path, entry, filename="test[1].org")
|
orgfile = create_file(tmp_path, entry, filename="test[1].org")
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
|
@ -162,15 +162,15 @@ Body Line 1'''
|
||||||
# parsed heading from entry
|
# parsed heading from entry
|
||||||
assert entries[0].heading == "Heading[1]"
|
assert entries[0].heading == "Heading[1]"
|
||||||
# ensure SOURCE link has square brackets in filename, heading escaped in rendered entries
|
# ensure SOURCE link has square brackets in filename, heading escaped in rendered entries
|
||||||
escaped_orgfile = f'{orgfile}'.replace("[1]", "\\[1\\]")
|
escaped_orgfile = f"{orgfile}".replace("[1]", "\\[1\\]")
|
||||||
assert f':SOURCE: [[file:{escaped_orgfile}::*Heading\[1\]' in f'{entries[0]}'
|
assert f":SOURCE: [[file:{escaped_orgfile}::*Heading\[1\]" in f"{entries[0]}"
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
def test_parse_multiple_entries(tmp_path):
|
def test_parse_multiple_entries(tmp_path):
|
||||||
"Test parsing of multiple entries"
|
"Test parsing of multiple entries"
|
||||||
# Arrange
|
# Arrange
|
||||||
content = f'''
|
content = f"""
|
||||||
*** FAILED [#A] Heading1 :tag1:
|
*** FAILED [#A] Heading1 :tag1:
|
||||||
CLOSED: [1984-04-01 Sun 12:00] SCHEDULED: <1984-04-01 Sun 09:00> DEADLINE: <1984-04-01 Sun>
|
CLOSED: [1984-04-01 Sun 12:00] SCHEDULED: <1984-04-01 Sun 09:00> DEADLINE: <1984-04-01 Sun>
|
||||||
:PROPERTIES:
|
:PROPERTIES:
|
||||||
|
@ -193,7 +193,7 @@ CLOCK: [1984-04-02 Mon 09:00]--[1984-04-02 Mon 12:00] => 3:00
|
||||||
:END:
|
:END:
|
||||||
Body 2
|
Body 2
|
||||||
|
|
||||||
'''
|
"""
|
||||||
orgfile = create_file(tmp_path, content)
|
orgfile = create_file(tmp_path, content)
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
|
@ -211,15 +211,17 @@ Body 2
|
||||||
assert entry.closed == datetime.date(1984, 4, index + 1)
|
assert entry.closed == datetime.date(1984, 4, index + 1)
|
||||||
assert entry.scheduled == datetime.date(1984, 4, index + 1)
|
assert entry.scheduled == datetime.date(1984, 4, index + 1)
|
||||||
assert entry.deadline == datetime.date(1984, 4, index + 1)
|
assert entry.deadline == datetime.date(1984, 4, index + 1)
|
||||||
assert entry.logbook == [(datetime.datetime(1984,4,index+1,9,0,0), datetime.datetime(1984,4,index+1,12,0,0))]
|
assert entry.logbook == [
|
||||||
|
(datetime.datetime(1984, 4, index + 1, 9, 0, 0), datetime.datetime(1984, 4, index + 1, 12, 0, 0))
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
def test_parse_entry_with_empty_title(tmp_path):
|
def test_parse_entry_with_empty_title(tmp_path):
|
||||||
"Test parsing of entry with minimal fields"
|
"Test parsing of entry with minimal fields"
|
||||||
# Arrange
|
# Arrange
|
||||||
entry = f'''#+TITLE:
|
entry = f"""#+TITLE:
|
||||||
Body Line 1'''
|
Body Line 1"""
|
||||||
orgfile = create_file(tmp_path, entry)
|
orgfile = create_file(tmp_path, entry)
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
|
@ -227,7 +229,7 @@ Body Line 1'''
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert len(entries) == 1
|
assert len(entries) == 1
|
||||||
assert entries[0].heading == f'{orgfile}'
|
assert entries[0].heading == f"{orgfile}"
|
||||||
assert entries[0].tags == list()
|
assert entries[0].tags == list()
|
||||||
assert entries[0].body == "Body Line 1"
|
assert entries[0].body == "Body Line 1"
|
||||||
assert entries[0].priority == ""
|
assert entries[0].priority == ""
|
||||||
|
@ -241,8 +243,8 @@ Body Line 1'''
|
||||||
def test_parse_entry_with_title_and_no_headings(tmp_path):
|
def test_parse_entry_with_title_and_no_headings(tmp_path):
|
||||||
"Test parsing of entry with minimal fields"
|
"Test parsing of entry with minimal fields"
|
||||||
# Arrange
|
# Arrange
|
||||||
entry = f'''#+TITLE: test
|
entry = f"""#+TITLE: test
|
||||||
Body Line 1'''
|
Body Line 1"""
|
||||||
orgfile = create_file(tmp_path, entry)
|
orgfile = create_file(tmp_path, entry)
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
|
@ -250,7 +252,7 @@ Body Line 1'''
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert len(entries) == 1
|
assert len(entries) == 1
|
||||||
assert entries[0].heading == 'test'
|
assert entries[0].heading == "test"
|
||||||
assert entries[0].tags == list()
|
assert entries[0].tags == list()
|
||||||
assert entries[0].body == "Body Line 1"
|
assert entries[0].body == "Body Line 1"
|
||||||
assert entries[0].priority == ""
|
assert entries[0].priority == ""
|
||||||
|
@ -264,9 +266,9 @@ Body Line 1'''
|
||||||
def test_parse_entry_with_multiple_titles_and_no_headings(tmp_path):
|
def test_parse_entry_with_multiple_titles_and_no_headings(tmp_path):
|
||||||
"Test parsing of entry with minimal fields"
|
"Test parsing of entry with minimal fields"
|
||||||
# Arrange
|
# Arrange
|
||||||
entry = f'''#+TITLE: title1
|
entry = f"""#+TITLE: title1
|
||||||
Body Line 1
|
Body Line 1
|
||||||
#+TITLE: title2 '''
|
#+TITLE: title2 """
|
||||||
orgfile = create_file(tmp_path, entry)
|
orgfile = create_file(tmp_path, entry)
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
|
@ -274,7 +276,7 @@ Body Line 1
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert len(entries) == 1
|
assert len(entries) == 1
|
||||||
assert entries[0].heading == 'title1 title2'
|
assert entries[0].heading == "title1 title2"
|
||||||
assert entries[0].tags == list()
|
assert entries[0].tags == list()
|
||||||
assert entries[0].body == "Body Line 1\n"
|
assert entries[0].body == "Body Line 1\n"
|
||||||
assert entries[0].priority == ""
|
assert entries[0].priority == ""
|
||||||
|
|
|
@ -14,7 +14,9 @@ from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
|
||||||
|
|
||||||
# Test
|
# Test
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
def test_asymmetric_setup_with_missing_file_raises_error(org_config_with_only_new_file: TextContentConfig, search_config: SearchConfig):
|
def test_asymmetric_setup_with_missing_file_raises_error(
|
||||||
|
org_config_with_only_new_file: TextContentConfig, search_config: SearchConfig
|
||||||
|
):
|
||||||
# Arrange
|
# Arrange
|
||||||
# Ensure file mentioned in org.input-files is missing
|
# Ensure file mentioned in org.input-files is missing
|
||||||
single_new_file = Path(org_config_with_only_new_file.input_files[0])
|
single_new_file = Path(org_config_with_only_new_file.input_files[0])
|
||||||
|
@ -27,10 +29,12 @@ def test_asymmetric_setup_with_missing_file_raises_error(org_config_with_only_ne
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
def test_asymmetric_setup_with_empty_file_raises_error(org_config_with_only_new_file: TextContentConfig, search_config: SearchConfig):
|
def test_asymmetric_setup_with_empty_file_raises_error(
|
||||||
|
org_config_with_only_new_file: TextContentConfig, search_config: SearchConfig
|
||||||
|
):
|
||||||
# Act
|
# Act
|
||||||
# Generate notes embeddings during asymmetric setup
|
# Generate notes embeddings during asymmetric setup
|
||||||
with pytest.raises(ValueError, match=r'^No valid entries found*'):
|
with pytest.raises(ValueError, match=r"^No valid entries found*"):
|
||||||
text_search.setup(OrgToJsonl, org_config_with_only_new_file, search_config.asymmetric, regenerate=True)
|
text_search.setup(OrgToJsonl, org_config_with_only_new_file, search_config.asymmetric, regenerate=True)
|
||||||
|
|
||||||
|
|
||||||
|
@ -52,15 +56,9 @@ def test_asymmetric_search(content_config: ContentConfig, search_config: SearchC
|
||||||
query = "How to git install application?"
|
query = "How to git install application?"
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
hits, entries = text_search.query(
|
hits, entries = text_search.query(query, model=model.notes_search, rank_results=True)
|
||||||
query,
|
|
||||||
model = model.notes_search,
|
|
||||||
rank_results=True)
|
|
||||||
|
|
||||||
results = text_search.collate_results(
|
results = text_search.collate_results(hits, entries, count=1)
|
||||||
hits,
|
|
||||||
entries,
|
|
||||||
count=1)
|
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
# Actual_data should contain "Khoj via Emacs" entry
|
# Actual_data should contain "Khoj via Emacs" entry
|
||||||
|
@ -81,7 +79,9 @@ def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: TextContent
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
# reload embeddings, entries, notes model after adding new org-mode file
|
# reload embeddings, entries, notes model after adding new org-mode file
|
||||||
initial_notes_model = text_search.setup(OrgToJsonl, org_config_with_only_new_file, search_config.asymmetric, regenerate=False)
|
initial_notes_model = text_search.setup(
|
||||||
|
OrgToJsonl, org_config_with_only_new_file, search_config.asymmetric, regenerate=False
|
||||||
|
)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
# verify newly added org-mode entry is split by max tokens
|
# verify newly added org-mode entry is split by max tokens
|
||||||
|
@ -98,12 +98,14 @@ def test_asymmetric_reload(content_config: ContentConfig, search_config: SearchC
|
||||||
assert len(initial_notes_model.corpus_embeddings) == 10
|
assert len(initial_notes_model.corpus_embeddings) == 10
|
||||||
|
|
||||||
# append org-mode entry to first org input file in config
|
# append org-mode entry to first org input file in config
|
||||||
content_config.org.input_files = [f'{new_org_file}']
|
content_config.org.input_files = [f"{new_org_file}"]
|
||||||
with open(new_org_file, "w") as f:
|
with open(new_org_file, "w") as f:
|
||||||
f.write("\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n")
|
f.write("\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n")
|
||||||
|
|
||||||
# regenerate notes jsonl, model embeddings and model to include entry from new file
|
# regenerate notes jsonl, model embeddings and model to include entry from new file
|
||||||
regenerated_notes_model = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True)
|
regenerated_notes_model = text_search.setup(
|
||||||
|
OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True
|
||||||
|
)
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
# reload embeddings, entries, notes model from previously generated notes jsonl and model embeddings files
|
# reload embeddings, entries, notes model from previously generated notes jsonl and model embeddings files
|
||||||
|
@ -137,7 +139,7 @@ def test_incremental_update(content_config: ContentConfig, search_config: Search
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
# update embeddings, entries with the newly added note
|
# update embeddings, entries with the newly added note
|
||||||
content_config.org.input_files = [f'{new_org_file}']
|
content_config.org.input_files = [f"{new_org_file}"]
|
||||||
initial_notes_model = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False)
|
initial_notes_model = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
|
|
|
@ -7,7 +7,7 @@ def test_no_word_filter():
|
||||||
# Arrange
|
# Arrange
|
||||||
word_filter = WordFilter()
|
word_filter = WordFilter()
|
||||||
entries = arrange_content()
|
entries = arrange_content()
|
||||||
q_with_no_filter = 'head tail'
|
q_with_no_filter = "head tail"
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
can_filter = word_filter.can_filter(q_with_no_filter)
|
can_filter = word_filter.can_filter(q_with_no_filter)
|
||||||
|
@ -15,7 +15,7 @@ def test_no_word_filter():
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert can_filter == False
|
assert can_filter == False
|
||||||
assert ret_query == 'head tail'
|
assert ret_query == "head tail"
|
||||||
assert entry_indices == {0, 1, 2, 3}
|
assert entry_indices == {0, 1, 2, 3}
|
||||||
|
|
||||||
|
|
||||||
|
@ -31,7 +31,7 @@ def test_word_exclude_filter():
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert can_filter == True
|
assert can_filter == True
|
||||||
assert ret_query == 'head tail'
|
assert ret_query == "head tail"
|
||||||
assert entry_indices == {0, 2}
|
assert entry_indices == {0, 2}
|
||||||
|
|
||||||
|
|
||||||
|
@ -47,7 +47,7 @@ def test_word_include_filter():
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert can_filter == True
|
assert can_filter == True
|
||||||
assert ret_query == 'head tail'
|
assert ret_query == "head tail"
|
||||||
assert entry_indices == {2, 3}
|
assert entry_indices == {2, 3}
|
||||||
|
|
||||||
|
|
||||||
|
@ -63,16 +63,16 @@ def test_word_include_and_exclude_filter():
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert can_filter == True
|
assert can_filter == True
|
||||||
assert ret_query == 'head tail'
|
assert ret_query == "head tail"
|
||||||
assert entry_indices == {2}
|
assert entry_indices == {2}
|
||||||
|
|
||||||
|
|
||||||
def arrange_content():
|
def arrange_content():
|
||||||
entries = [
|
entries = [
|
||||||
Entry(compiled='', raw='Minimal Entry'),
|
Entry(compiled="", raw="Minimal Entry"),
|
||||||
Entry(compiled='', raw='Entry with exclude_word'),
|
Entry(compiled="", raw="Entry with exclude_word"),
|
||||||
Entry(compiled='', raw='Entry with include_word'),
|
Entry(compiled="", raw="Entry with include_word"),
|
||||||
Entry(compiled='', raw='Entry with include_word and exclude_word')
|
Entry(compiled="", raw="Entry with include_word and exclude_word"),
|
||||||
]
|
]
|
||||||
|
|
||||||
return entries
|
return entries
|
||||||
|
|
Loading…
Reference in a new issue