Use Black to format Khoj server code and tests

Debanjum Singh Solanky 2023-02-17 10:04:26 -06:00
parent 6130fddf45
commit 5e83baab21
44 changed files with 1167 additions and 915 deletions

View file

@ -64,7 +64,8 @@ khoj = "khoj.main:run"
[project.optional-dependencies]
test = [
"pytest == 7.1.2",
"pytest >= 7.1.2",
"black >= 23.1.0",
]
dev = ["khoj-assistant[test]"]
@ -88,3 +89,6 @@ exclude = [
"src/khoj/interface/desktop/file_browser.py",
"src/khoj/interface/desktop/system_tray.py",
]
[tool.black]
line-length = 120
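
With this configuration in place, Black reads its settings from the [tool.black] table in pyproject.toml, so plain invocations from the repository root apply the 120-character line length automatically. A minimal sketch of how the formatter might be run, assuming the server code lives under src/khoj and the test suite under tests (paths inferred from this diff, not confirmed by it):

    # install the test extras, which now pull in black >= 23.1.0
    pip install -e ".[test]"

    # reformat server code and tests in place
    black src/khoj tests

    # or only verify formatting without rewriting files (useful in CI)
    black --check src/khoj tests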

View file

@ -26,10 +26,12 @@ logger = logging.getLogger(__name__)
def configure_server(args, required=False):
if args.config is None:
if required:
logger.error(f'Exiting as Khoj is not configured.\nConfigure it via GUI or by editing {state.config_file}.')
logger.error(f"Exiting as Khoj is not configured.\nConfigure it via GUI or by editing {state.config_file}.")
sys.exit(1)
else:
logger.warn(f'Khoj is not configured.\nConfigure it via khoj GUI, plugins or by editing {state.config_file}.')
logger.warn(
f"Khoj is not configured.\nConfigure it via khoj GUI, plugins or by editing {state.config_file}."
)
return
else:
state.config = args.config
@ -60,7 +62,8 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool,
config.content_type.org,
search_config=config.search_type.asymmetric,
regenerate=regenerate,
filters=[DateFilter(), WordFilter(), FileFilter()])
filters=[DateFilter(), WordFilter(), FileFilter()],
)
# Initialize Org Music Search
if (t == SearchType.Music or t == None) and config.content_type.music:
@ -70,7 +73,8 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool,
config.content_type.music,
search_config=config.search_type.asymmetric,
regenerate=regenerate,
filters=[DateFilter(), WordFilter()])
filters=[DateFilter(), WordFilter()],
)
# Initialize Markdown Search
if (t == SearchType.Markdown or t == None) and config.content_type.markdown:
@ -80,7 +84,8 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool,
config.content_type.markdown,
search_config=config.search_type.asymmetric,
regenerate=regenerate,
filters=[DateFilter(), WordFilter(), FileFilter()])
filters=[DateFilter(), WordFilter(), FileFilter()],
)
# Initialize Ledger Search
if (t == SearchType.Ledger or t == None) and config.content_type.ledger:
@ -90,15 +95,15 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool,
config.content_type.ledger,
search_config=config.search_type.symmetric,
regenerate=regenerate,
filters=[DateFilter(), WordFilter(), FileFilter()])
filters=[DateFilter(), WordFilter(), FileFilter()],
)
# Initialize Image Search
if (t == SearchType.Image or t == None) and config.content_type.image:
# Extract Entries, Generate Image Embeddings
model.image_search = image_search.setup(
config.content_type.image,
search_config=config.search_type.image,
regenerate=regenerate)
config.content_type.image, search_config=config.search_type.image, regenerate=regenerate
)
# Invalidate Query Cache
state.query_cache = LRU()
@ -125,9 +130,9 @@ def configure_conversation_processor(conversation_processor_config):
if conversation_logfile.is_file():
# Load Metadata Logs from Conversation Logfile
with conversation_logfile.open('r') as f:
with conversation_logfile.open("r") as f:
conversation_processor.meta_log = json.load(f)
logger.info('Conversation logs loaded from disk.')
logger.info("Conversation logs loaded from disk.")
else:
# Initialize Conversation Logs
conversation_processor.meta_log = {}

View file

@ -8,7 +8,7 @@ from khoj.utils.helpers import is_none_or_empty
class FileBrowser(QtWidgets.QWidget):
def __init__(self, title, search_type: SearchType=None, default_files:list=[]):
def __init__(self, title, search_type: SearchType = None, default_files: list = []):
QtWidgets.QWidget.__init__(self)
layout = QtWidgets.QHBoxLayout()
self.setLayout(layout)
@ -26,47 +26,50 @@ class FileBrowser(QtWidgets.QWidget):
self.lineEdit = QtWidgets.QPlainTextEdit(self)
self.lineEdit.setFixedWidth(330)
self.setFiles(default_files)
self.lineEdit.setFixedHeight(min(7+20*len(self.lineEdit.toPlainText().split('\n')),90))
self.lineEdit.setFixedHeight(min(7 + 20 * len(self.lineEdit.toPlainText().split("\n")), 90))
self.lineEdit.textChanged.connect(self.updateFieldHeight)
layout.addWidget(self.lineEdit)
self.button = QtWidgets.QPushButton('Add')
self.button = QtWidgets.QPushButton("Add")
self.button.clicked.connect(self.storeFilesSelectedInFileDialog)
layout.addWidget(self.button)
layout.addStretch()
def getFileFilter(self, search_type):
if search_type == SearchType.Org:
return 'Org-Mode Files (*.org)'
return "Org-Mode Files (*.org)"
elif search_type == SearchType.Ledger:
return 'Beancount Files (*.bean *.beancount)'
return "Beancount Files (*.bean *.beancount)"
elif search_type == SearchType.Markdown:
return 'Markdown Files (*.md *.markdown)'
return "Markdown Files (*.md *.markdown)"
elif search_type == SearchType.Music:
return 'Org-Music Files (*.org)'
return "Org-Music Files (*.org)"
elif search_type == SearchType.Image:
return 'Images (*.jp[e]g)'
return "Images (*.jp[e]g)"
def storeFilesSelectedInFileDialog(self):
filepaths = self.getPaths()
if self.search_type == SearchType.Image:
filepaths.append(QtWidgets.QFileDialog.getExistingDirectory(self, caption='Choose Folder',
directory=self.dirpath))
filepaths.append(
QtWidgets.QFileDialog.getExistingDirectory(self, caption="Choose Folder", directory=self.dirpath)
)
else:
filepaths.extend(QtWidgets.QFileDialog.getOpenFileNames(self, caption='Choose Files',
directory=self.dirpath,
filter=self.filter_name)[0])
filepaths.extend(
QtWidgets.QFileDialog.getOpenFileNames(
self, caption="Choose Files", directory=self.dirpath, filter=self.filter_name
)[0]
)
self.setFiles(filepaths)
def setFiles(self, paths:list):
def setFiles(self, paths: list):
self.filepaths = [path for path in paths if not is_none_or_empty(path)]
self.lineEdit.setPlainText("\n".join(self.filepaths))
def getPaths(self) -> list:
if self.lineEdit.toPlainText() == '':
if self.lineEdit.toPlainText() == "":
return []
else:
return self.lineEdit.toPlainText().split('\n')
return self.lineEdit.toPlainText().split("\n")
def updateFieldHeight(self):
self.lineEdit.setFixedHeight(min(7+20*len(self.lineEdit.toPlainText().split('\n')),90))
self.lineEdit.setFixedHeight(min(7 + 20 * len(self.lineEdit.toPlainText().split("\n")), 90))

View file

@ -6,7 +6,7 @@ from khoj.utils.config import ProcessorType
class LabelledTextField(QtWidgets.QWidget):
def __init__(self, title, processor_type: ProcessorType=None, default_value: str=None):
def __init__(self, title, processor_type: ProcessorType = None, default_value: str = None):
QtWidgets.QWidget.__init__(self)
layout = QtWidgets.QHBoxLayout()
self.setLayout(layout)

View file

@ -31,9 +31,9 @@ class MainWindow(QtWidgets.QMainWindow):
self.config_file = config_file
# Set regenerate flag to regenerate embeddings everytime user clicks configure
if state.cli_args:
state.cli_args += ['--regenerate']
state.cli_args += ["--regenerate"]
else:
state.cli_args = ['--regenerate']
state.cli_args = ["--regenerate"]
# Load config from existing config, if exists, else load from default config
if resolve_absolute_path(self.config_file).exists():
@ -49,8 +49,8 @@ class MainWindow(QtWidgets.QMainWindow):
self.setFixedWidth(600)
# Set Window Icon
icon_path = constants.web_directory / 'assets/icons/favicon-144x144.png'
self.setWindowIcon(QtGui.QIcon(f'{icon_path.absolute()}'))
icon_path = constants.web_directory / "assets/icons/favicon-144x144.png"
self.setWindowIcon(QtGui.QIcon(f"{icon_path.absolute()}"))
# Initialize Configure Window Layout
self.layout = QtWidgets.QVBoxLayout()
@ -58,13 +58,13 @@ class MainWindow(QtWidgets.QMainWindow):
# Add Settings Panels for each Search Type to Configure Window Layout
self.search_settings_panels = []
for search_type in SearchType:
current_content_config = self.current_config['content-type'].get(search_type, {})
current_content_config = self.current_config["content-type"].get(search_type, {})
self.search_settings_panels += [self.add_settings_panel(current_content_config, search_type)]
# Add Conversation Processor Panel to Configure Screen
self.processor_settings_panels = []
conversation_type = ProcessorType.Conversation
current_conversation_config = self.current_config['processor'].get(conversation_type, {})
current_conversation_config = self.current_config["processor"].get(conversation_type, {})
self.processor_settings_panels += [self.add_processor_panel(current_conversation_config, conversation_type)]
# Add Action Buttons Panel
@ -81,11 +81,11 @@ class MainWindow(QtWidgets.QMainWindow):
"Add Settings Panel for specified Search Type. Toggle Editable Search Types"
# Get current files from config for given search type
if search_type == SearchType.Image:
current_content_files = current_content_config.get('input-directories', [])
file_input_text = f'{search_type.name} Folders'
current_content_files = current_content_config.get("input-directories", [])
file_input_text = f"{search_type.name} Folders"
else:
current_content_files = current_content_config.get('input-files', [])
file_input_text = f'{search_type.name} Files'
current_content_files = current_content_config.get("input-files", [])
file_input_text = f"{search_type.name} Files"
# Create widgets to display settings for given search type
search_type_settings = QtWidgets.QWidget()
@ -109,7 +109,7 @@ class MainWindow(QtWidgets.QMainWindow):
def add_processor_panel(self, current_conversation_config: dict, processor_type: ProcessorType):
"Add Conversation Processor Panel"
# Get current settings from config for given processor type
current_openai_api_key = current_conversation_config.get('openai-api-key', None)
current_openai_api_key = current_conversation_config.get("openai-api-key", None)
# Create widgets to display settings for given processor type
processor_type_settings = QtWidgets.QWidget()
@ -137,20 +137,22 @@ class MainWindow(QtWidgets.QMainWindow):
action_bar_layout = QtWidgets.QHBoxLayout(action_bar)
self.configure_button = QtWidgets.QPushButton("Configure", clicked=self.configure_app)
self.search_button = QtWidgets.QPushButton("Search", clicked=lambda: webbrowser.open(f'http://{state.host}:{state.port}/'))
self.search_button = QtWidgets.QPushButton(
"Search", clicked=lambda: webbrowser.open(f"http://{state.host}:{state.port}/")
)
self.search_button.setEnabled(not self.first_run)
action_bar_layout.addWidget(self.configure_button)
action_bar_layout.addWidget(self.search_button)
self.layout.addWidget(action_bar)
def get_default_config(self, search_type:SearchType=None, processor_type:ProcessorType=None):
def get_default_config(self, search_type: SearchType = None, processor_type: ProcessorType = None):
"Get default config"
config = constants.default_config
if search_type:
return config['content-type'][search_type]
return config["content-type"][search_type]
elif processor_type:
return config['processor'][processor_type]
return config["processor"][processor_type]
else:
return config
@ -160,7 +162,9 @@ class MainWindow(QtWidgets.QMainWindow):
for message_prefix in ErrorType:
for i in reversed(range(self.layout.count())):
current_widget = self.layout.itemAt(i).widget()
if isinstance(current_widget, QtWidgets.QLabel) and current_widget.text().startswith(message_prefix.value):
if isinstance(current_widget, QtWidgets.QLabel) and current_widget.text().startswith(
message_prefix.value
):
self.layout.removeWidget(current_widget)
current_widget.deleteLater()
@ -180,18 +184,24 @@ class MainWindow(QtWidgets.QMainWindow):
continue
if isinstance(child, SearchCheckBox):
# Search Type Disabled
if not child.isChecked() and child.search_type in self.new_config['content-type']:
del self.new_config['content-type'][child.search_type]
if not child.isChecked() and child.search_type in self.new_config["content-type"]:
del self.new_config["content-type"][child.search_type]
# Search Type (re)-Enabled
if child.isChecked():
current_search_config = self.current_config['content-type'].get(child.search_type, {})
default_search_config = self.get_default_config(search_type = child.search_type)
self.new_config['content-type'][child.search_type.value] = merge_dicts(current_search_config, default_search_config)
elif isinstance(child, FileBrowser) and child.search_type in self.new_config['content-type']:
current_search_config = self.current_config["content-type"].get(child.search_type, {})
default_search_config = self.get_default_config(search_type=child.search_type)
self.new_config["content-type"][child.search_type.value] = merge_dicts(
current_search_config, default_search_config
)
elif isinstance(child, FileBrowser) and child.search_type in self.new_config["content-type"]:
if child.search_type.value == SearchType.Image:
self.new_config['content-type'][child.search_type.value]['input-directories'] = child.getPaths() if child.getPaths() != [] else None
self.new_config["content-type"][child.search_type.value]["input-directories"] = (
child.getPaths() if child.getPaths() != [] else None
)
else:
self.new_config['content-type'][child.search_type.value]['input-files'] = child.getPaths() if child.getPaths() != [] else None
self.new_config["content-type"][child.search_type.value]["input-files"] = (
child.getPaths() if child.getPaths() != [] else None
)
def update_processor_settings(self):
"Update config with conversation settings from UI"
@ -201,16 +211,20 @@ class MainWindow(QtWidgets.QMainWindow):
continue
if isinstance(child, ProcessorCheckBox):
# Processor Type Disabled
if not child.isChecked() and child.processor_type in self.new_config['processor']:
del self.new_config['processor'][child.processor_type]
if not child.isChecked() and child.processor_type in self.new_config["processor"]:
del self.new_config["processor"][child.processor_type]
# Processor Type (re)-Enabled
if child.isChecked():
current_processor_config = self.current_config['processor'].get(child.processor_type, {})
default_processor_config = self.get_default_config(processor_type = child.processor_type)
self.new_config['processor'][child.processor_type.value] = merge_dicts(current_processor_config, default_processor_config)
elif isinstance(child, LabelledTextField) and child.processor_type in self.new_config['processor']:
current_processor_config = self.current_config["processor"].get(child.processor_type, {})
default_processor_config = self.get_default_config(processor_type=child.processor_type)
self.new_config["processor"][child.processor_type.value] = merge_dicts(
current_processor_config, default_processor_config
)
elif isinstance(child, LabelledTextField) and child.processor_type in self.new_config["processor"]:
if child.processor_type == ProcessorType.Conversation:
self.new_config['processor'][child.processor_type.value]['openai-api-key'] = child.input_field.toPlainText() if child.input_field.toPlainText() != '' else None
self.new_config["processor"][child.processor_type.value]["openai-api-key"] = (
child.input_field.toPlainText() if child.input_field.toPlainText() != "" else None
)
def save_settings_to_file(self) -> bool:
"Save validated settings to file"
@ -312,6 +326,7 @@ class ProcessorCheckBox(QtWidgets.QCheckBox):
self.processor_type = processor_type
super(ProcessorCheckBox, self).__init__(text, parent=parent)
class ErrorType(Enum):
"Error Types"
ConfigLoadingError = "Config Loading Error"

View file

@ -17,17 +17,17 @@ def create_system_tray(gui: QtWidgets.QApplication, main_window: MainWindow):
"""
# Create the system tray with icon
icon_path = constants.web_directory / 'assets/icons/favicon-144x144.png'
icon = QtGui.QIcon(f'{icon_path.absolute()}')
icon_path = constants.web_directory / "assets/icons/favicon-144x144.png"
icon = QtGui.QIcon(f"{icon_path.absolute()}")
tray = QtWidgets.QSystemTrayIcon(icon)
tray.setVisible(True)
# Create the menu and menu actions
menu = QtWidgets.QMenu()
menu_actions = [
('Search', lambda: webbrowser.open(f'http://{state.host}:{state.port}/')),
('Configure', main_window.show_on_top),
('Quit', gui.quit),
("Search", lambda: webbrowser.open(f"http://{state.host}:{state.port}/")),
("Configure", main_window.show_on_top),
("Quit", gui.quit),
]
# Add the menu actions to the menu

View file

@ -8,8 +8,8 @@ import warnings
from platform import system
# Ignore non-actionable warnings
warnings.filterwarnings("ignore", message=r'snapshot_download.py has been made private', category=FutureWarning)
warnings.filterwarnings("ignore", message=r'legacy way to download files from the HF hub,', category=FutureWarning)
warnings.filterwarnings("ignore", message=r"snapshot_download.py has been made private", category=FutureWarning)
warnings.filterwarnings("ignore", message=r"legacy way to download files from the HF hub,", category=FutureWarning)
# External Packages
import uvicorn
@ -43,11 +43,12 @@ rich_handler = RichHandler(rich_tracebacks=True)
rich_handler.setFormatter(fmt=logging.Formatter(fmt="%(message)s", datefmt="[%X]"))
logging.basicConfig(handlers=[rich_handler])
logger = logging.getLogger('khoj')
logger = logging.getLogger("khoj")
def run():
# Turn Tokenizers Parallelism Off. App does not support it.
os.environ["TOKENIZERS_PARALLELISM"] = 'false'
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Load config from CLI
state.cli_args = sys.argv[1:]
@ -66,7 +67,7 @@ def run():
logger.setLevel(logging.DEBUG)
# Set Log File
fh = logging.FileHandler(state.config_file.parent / 'khoj.log')
fh = logging.FileHandler(state.config_file.parent / "khoj.log")
fh.setLevel(logging.DEBUG)
logger.addHandler(fh)
@ -87,7 +88,7 @@ def run():
# On Linux (Gnome) the System tray is not supported.
# Since only the Main Window is available
# Quitting it should quit the application
if system() in ['Windows', 'Darwin']:
if system() in ["Windows", "Darwin"]:
gui.setQuitOnLastWindowClosed(False)
tray = create_system_tray(gui, main_window)
tray.show()
@ -97,7 +98,7 @@ def run():
server = ServerThread(app, args.host, args.port, args.socket)
# Show Main Window on First Run Experience or if on Linux
if args.config is None or system() not in ['Windows', 'Darwin']:
if args.config is None or system() not in ["Windows", "Darwin"]:
main_window.show()
# Setup Signal Handlers
@ -112,9 +113,10 @@ def run():
gui.aboutToQuit.connect(server.terminate)
# Close Splash Screen if still open
if system() != 'Darwin':
if system() != "Darwin":
try:
import pyi_splash
# Update the text on the splash screen
pyi_splash.update_text("Khoj setup complete")
# Close Splash Screen
@ -167,5 +169,5 @@ class ServerThread(QThread):
start_server(self.app, self.host, self.port, self.socket)
if __name__ == '__main__':
if __name__ == "__main__":
run()

View file

@ -19,31 +19,27 @@ def summarize(text, summary_type, model, user_query=None, api_key=None, temperat
# Setup Prompt based on Summary Type
if summary_type == "chat":
prompt = f'''
prompt = f"""
You are an AI. Summarize the conversation below from your perspective:
{text}
Summarize the conversation from the AI's first-person perspective:'''
Summarize the conversation from the AI's first-person perspective:"""
elif summary_type == "notes":
prompt = f'''
prompt = f"""
Summarize the below notes about {user_query}:
{text}
Summarize the notes in second person perspective:'''
Summarize the notes in second person perspective:"""
# Get Response from GPT
response = openai.Completion.create(
prompt=prompt,
model=model,
temperature=temperature,
max_tokens=max_tokens,
frequency_penalty=0.2,
stop="\"\"\"")
prompt=prompt, model=model, temperature=temperature, max_tokens=max_tokens, frequency_penalty=0.2, stop='"""'
)
# Extract, Clean Message from GPT's Response
story = response['choices'][0]['text']
story = response["choices"][0]["text"]
return str(story).replace("\n\n", "")
@ -53,7 +49,7 @@ def extract_search_type(text, model, api_key=None, temperature=0.5, max_tokens=1
"""
# Initialize Variables
openai.api_key = api_key or os.getenv("OPENAI_API_KEY")
understand_primer = '''
understand_primer = """
Objective: Extract search type from user query and return information as JSON
Allowed search types are listed below:
@ -73,7 +69,7 @@ A:{ "search-type": "notes" }
Q: When did I buy Groceries last?
A:{ "search-type": "ledger" }
Q:When did I go surfing last?
A:{ "search-type": "notes" }'''
A:{ "search-type": "notes" }"""
# Setup Prompt with Understand Primer
prompt = message_to_prompt(text, understand_primer, start_sequence="\nA:", restart_sequence="\nQ:")
@ -82,15 +78,11 @@ A:{ "search-type": "notes" }'''
# Get Response from GPT
response = openai.Completion.create(
prompt=prompt,
model=model,
temperature=temperature,
max_tokens=max_tokens,
frequency_penalty=0.2,
stop=["\n"])
prompt=prompt, model=model, temperature=temperature, max_tokens=max_tokens, frequency_penalty=0.2, stop=["\n"]
)
# Extract, Clean Message from GPT's Response
story = str(response['choices'][0]['text'])
story = str(response["choices"][0]["text"])
return json.loads(story.strip(empty_escape_sequences))
@ -100,7 +92,7 @@ def understand(text, model, api_key=None, temperature=0.5, max_tokens=100, verbo
"""
# Initialize Variables
openai.api_key = api_key or os.getenv("OPENAI_API_KEY")
understand_primer = '''
understand_primer = """
Objective: Extract intent and trigger emotion information as JSON from each chat message
Potential intent types and valid argument values are listed below:
@ -142,7 +134,7 @@ A: { "intent": {"type": "remember", "memory-type": "notes", "query": "recommend
Q: When did I go surfing last?
A: { "intent": {"type": "remember", "memory-type": "notes", "query": "When did I go surfing last"}, "trigger-emotion": "calm" }
Q: Can you dance for me?
A: { "intent": {"type": "generate", "activity": "chat", "query": "Can you dance for me?"}, "trigger-emotion": "sad" }'''
A: { "intent": {"type": "generate", "activity": "chat", "query": "Can you dance for me?"}, "trigger-emotion": "sad" }"""
# Setup Prompt with Understand Primer
prompt = message_to_prompt(text, understand_primer, start_sequence="\nA:", restart_sequence="\nQ:")
@ -151,15 +143,11 @@ A: { "intent": {"type": "generate", "activity": "chat", "query": "Can you dance
# Get Response from GPT
response = openai.Completion.create(
prompt=prompt,
model=model,
temperature=temperature,
max_tokens=max_tokens,
frequency_penalty=0.2,
stop=["\n"])
prompt=prompt, model=model, temperature=temperature, max_tokens=max_tokens, frequency_penalty=0.2, stop=["\n"]
)
# Extract, Clean Message from GPT's Response
story = str(response['choices'][0]['text'])
story = str(response["choices"][0]["text"])
return json.loads(story.strip(empty_escape_sequences))
@ -171,15 +159,15 @@ def converse(text, model, conversation_history=None, api_key=None, temperature=0
max_words = 500
openai.api_key = api_key or os.getenv("OPENAI_API_KEY")
conversation_primer = f'''
conversation_primer = f"""
The following is a conversation with an AI assistant. The assistant is helpful, creative, clever, and a very friendly companion.
Human: Hello, who are you?
AI: Hi, I am an AI conversational companion created by OpenAI. How can I help you today?'''
AI: Hi, I am an AI conversational companion created by OpenAI. How can I help you today?"""
# Setup Prompt with Primer or Conversation History
prompt = message_to_prompt(text, conversation_history or conversation_primer)
prompt = ' '.join(prompt.split()[:max_words])
prompt = " ".join(prompt.split()[:max_words])
# Get Response from GPT
response = openai.Completion.create(
@ -188,14 +176,17 @@ AI: Hi, I am an AI conversational companion created by OpenAI. How can I help yo
temperature=temperature,
max_tokens=max_tokens,
presence_penalty=0.6,
stop=["\n", "Human:", "AI:"])
stop=["\n", "Human:", "AI:"],
)
# Extract, Clean Message from GPT's Response
story = str(response['choices'][0]['text'])
story = str(response["choices"][0]["text"])
return story.strip(empty_escape_sequences)
def message_to_prompt(user_message, conversation_history="", gpt_message=None, start_sequence="\nAI:", restart_sequence="\nHuman:"):
def message_to_prompt(
user_message, conversation_history="", gpt_message=None, start_sequence="\nAI:", restart_sequence="\nHuman:"
):
"""Create prompt for GPT from messages and conversation history"""
gpt_message = f" {gpt_message}" if gpt_message else ""
@ -205,12 +196,8 @@ def message_to_prompt(user_message, conversation_history="", gpt_message=None, s
def message_to_log(user_message, gpt_message, user_message_metadata={}, conversation_log=[]):
"""Create json logs from messages, metadata for conversation log"""
default_user_message_metadata = {
"intent": {
"type": "remember",
"memory-type": "notes",
"query": user_message
},
"trigger-emotion": "calm"
"intent": {"type": "remember", "memory-type": "notes", "query": user_message},
"trigger-emotion": "calm",
}
current_dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
@ -229,5 +216,4 @@ def message_to_log(user_message, gpt_message, user_message_metadata={}, conversa
def extract_summaries(metadata):
"""Extract summaries from metadata"""
return ''.join(
[f'\n{session["summary"]}' for session in metadata])
return "".join([f'\n{session["summary"]}' for session in metadata])

View file

@ -19,7 +19,11 @@ class BeancountToJsonl(TextToJsonl):
# Define Functions
def process(self, previous_entries=None):
# Extract required fields from config
beancount_files, beancount_file_filter, output_file = self.config.input_files, self.config.input_filter,self.config.compressed_jsonl
beancount_files, beancount_file_filter, output_file = (
self.config.input_files,
self.config.input_filter,
self.config.compressed_jsonl,
)
# Input Validation
if is_none_or_empty(beancount_files) and is_none_or_empty(beancount_file_filter):
@ -31,7 +35,9 @@ class BeancountToJsonl(TextToJsonl):
# Extract Entries from specified Beancount files
with timer("Parse transactions from Beancount files into dictionaries", logger):
current_entries = BeancountToJsonl.convert_transactions_to_maps(*BeancountToJsonl.extract_beancount_transactions(beancount_files))
current_entries = BeancountToJsonl.convert_transactions_to_maps(
*BeancountToJsonl.extract_beancount_transactions(beancount_files)
)
# Split entries by max tokens supported by model
with timer("Split entries by max token size supported by model", logger):
@ -42,7 +48,9 @@ class BeancountToJsonl(TextToJsonl):
if not previous_entries:
entries_with_ids = list(enumerate(current_entries))
else:
entries_with_ids = self.mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger)
entries_with_ids = self.mark_entries_for_update(
current_entries, previous_entries, key="compiled", logger=logger
)
with timer("Write transactions to JSONL file", logger):
# Process Each Entry from All Notes Files
@ -62,9 +70,7 @@ class BeancountToJsonl(TextToJsonl):
"Get Beancount files to process"
absolute_beancount_files, filtered_beancount_files = set(), set()
if beancount_files:
absolute_beancount_files = {get_absolute_path(beancount_file)
for beancount_file
in beancount_files}
absolute_beancount_files = {get_absolute_path(beancount_file) for beancount_file in beancount_files}
if beancount_file_filters:
filtered_beancount_files = {
filtered_file
@ -76,14 +82,13 @@ class BeancountToJsonl(TextToJsonl):
files_with_non_beancount_extensions = {
beancount_file
for beancount_file
in all_beancount_files
for beancount_file in all_beancount_files
if not beancount_file.endswith(".bean") and not beancount_file.endswith(".beancount")
}
if any(files_with_non_beancount_extensions):
print(f"[Warning] There maybe non beancount files in the input set: {files_with_non_beancount_extensions}")
logger.info(f'Processing files: {all_beancount_files}')
logger.info(f"Processing files: {all_beancount_files}")
return all_beancount_files
@ -92,19 +97,20 @@ class BeancountToJsonl(TextToJsonl):
"Extract entries from specified Beancount files"
# Initialize Regex for extracting Beancount Entries
transaction_regex = r'^\n?\d{4}-\d{2}-\d{2} [\*|\!] '
empty_newline = f'^[\n\r\t\ ]*$'
transaction_regex = r"^\n?\d{4}-\d{2}-\d{2} [\*|\!] "
empty_newline = f"^[\n\r\t\ ]*$"
entries = []
transaction_to_file_map = []
for beancount_file in beancount_files:
with open(beancount_file) as f:
ledger_content = f.read()
transactions_per_file = [entry.strip(empty_escape_sequences)
for entry
in re.split(empty_newline, ledger_content, flags=re.MULTILINE)
if re.match(transaction_regex, entry)]
transaction_to_file_map += zip(transactions_per_file, [beancount_file]*len(transactions_per_file))
transactions_per_file = [
entry.strip(empty_escape_sequences)
for entry in re.split(empty_newline, ledger_content, flags=re.MULTILINE)
if re.match(transaction_regex, entry)
]
transaction_to_file_map += zip(transactions_per_file, [beancount_file] * len(transactions_per_file))
entries.extend(transactions_per_file)
return entries, dict(transaction_to_file_map)
@ -113,7 +119,9 @@ class BeancountToJsonl(TextToJsonl):
"Convert each parsed Beancount transaction into a Entry"
entries = []
for parsed_entry in parsed_entries:
entries.append(Entry(compiled=parsed_entry, raw=parsed_entry, file=f'{transaction_to_file_map[parsed_entry]}'))
entries.append(
Entry(compiled=parsed_entry, raw=parsed_entry, file=f"{transaction_to_file_map[parsed_entry]}")
)
logger.info(f"Converted {len(parsed_entries)} transactions to dictionaries")
@ -122,4 +130,4 @@ class BeancountToJsonl(TextToJsonl):
@staticmethod
def convert_transaction_maps_to_jsonl(entries: List[Entry]) -> str:
"Convert each Beancount transaction entry to JSON and collate as JSONL"
return ''.join([f'{entry.to_json()}\n' for entry in entries])
return "".join([f"{entry.to_json()}\n" for entry in entries])

View file

@ -20,7 +20,11 @@ class MarkdownToJsonl(TextToJsonl):
# Define Functions
def process(self, previous_entries=None):
# Extract required fields from config
markdown_files, markdown_file_filter, output_file = self.config.input_files, self.config.input_filter, self.config.compressed_jsonl
markdown_files, markdown_file_filter, output_file = (
self.config.input_files,
self.config.input_filter,
self.config.compressed_jsonl,
)
# Input Validation
if is_none_or_empty(markdown_files) and is_none_or_empty(markdown_file_filter):
@ -32,7 +36,9 @@ class MarkdownToJsonl(TextToJsonl):
# Extract Entries from specified Markdown files
with timer("Parse entries from Markdown files into dictionaries", logger):
current_entries = MarkdownToJsonl.convert_markdown_entries_to_maps(*MarkdownToJsonl.extract_markdown_entries(markdown_files))
current_entries = MarkdownToJsonl.convert_markdown_entries_to_maps(
*MarkdownToJsonl.extract_markdown_entries(markdown_files)
)
# Split entries by max tokens supported by model
with timer("Split entries by max token size supported by model", logger):
@ -43,7 +49,9 @@ class MarkdownToJsonl(TextToJsonl):
if not previous_entries:
entries_with_ids = list(enumerate(current_entries))
else:
entries_with_ids = self.mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger)
entries_with_ids = self.mark_entries_for_update(
current_entries, previous_entries, key="compiled", logger=logger
)
with timer("Write markdown entries to JSONL file", logger):
# Process Each Entry from All Notes Files
@ -75,15 +83,16 @@ class MarkdownToJsonl(TextToJsonl):
files_with_non_markdown_extensions = {
md_file
for md_file
in all_markdown_files
if not md_file.endswith(".md") and not md_file.endswith('.markdown')
for md_file in all_markdown_files
if not md_file.endswith(".md") and not md_file.endswith(".markdown")
}
if any(files_with_non_markdown_extensions):
logger.warn(f"[Warning] There maybe non markdown-mode files in the input set: {files_with_non_markdown_extensions}")
logger.warn(
f"[Warning] There maybe non markdown-mode files in the input set: {files_with_non_markdown_extensions}"
)
logger.info(f'Processing files: {all_markdown_files}')
logger.info(f"Processing files: {all_markdown_files}")
return all_markdown_files
@ -92,20 +101,20 @@ class MarkdownToJsonl(TextToJsonl):
"Extract entries by heading from specified Markdown files"
# Regex to extract Markdown Entries by Heading
markdown_heading_regex = r'^#'
markdown_heading_regex = r"^#"
entries = []
entry_to_file_map = []
for markdown_file in markdown_files:
with open(markdown_file, 'r', encoding='utf8') as f:
with open(markdown_file, "r", encoding="utf8") as f:
markdown_content = f.read()
markdown_entries_per_file = []
for entry in re.split(markdown_heading_regex, markdown_content, flags=re.MULTILINE):
prefix = '#' if entry.startswith('#') else '# '
if entry.strip(empty_escape_sequences) != '':
markdown_entries_per_file.append(f'{prefix}{entry.strip(empty_escape_sequences)}')
prefix = "#" if entry.startswith("#") else "# "
if entry.strip(empty_escape_sequences) != "":
markdown_entries_per_file.append(f"{prefix}{entry.strip(empty_escape_sequences)}")
entry_to_file_map += zip(markdown_entries_per_file, [markdown_file]*len(markdown_entries_per_file))
entry_to_file_map += zip(markdown_entries_per_file, [markdown_file] * len(markdown_entries_per_file))
entries.extend(markdown_entries_per_file)
return entries, dict(entry_to_file_map)
@ -115,7 +124,7 @@ class MarkdownToJsonl(TextToJsonl):
"Convert each Markdown entries into a dictionary"
entries = []
for parsed_entry in parsed_entries:
entries.append(Entry(compiled=parsed_entry, raw=parsed_entry, file=f'{entry_to_file_map[parsed_entry]}'))
entries.append(Entry(compiled=parsed_entry, raw=parsed_entry, file=f"{entry_to_file_map[parsed_entry]}"))
logger.info(f"Converted {len(parsed_entries)} markdown entries to dictionaries")
@ -124,4 +133,4 @@ class MarkdownToJsonl(TextToJsonl):
@staticmethod
def convert_markdown_maps_to_jsonl(entries: List[Entry]):
"Convert each Markdown entry to JSON and collate as JSONL"
return ''.join([f'{entry.to_json()}\n' for entry in entries])
return "".join([f"{entry.to_json()}\n" for entry in entries])

View file

@ -18,9 +18,13 @@ logger = logging.getLogger(__name__)
class OrgToJsonl(TextToJsonl):
# Define Functions
def process(self, previous_entries: List[Entry]=None):
def process(self, previous_entries: List[Entry] = None):
# Extract required fields from config
org_files, org_file_filter, output_file = self.config.input_files, self.config.input_filter, self.config.compressed_jsonl
org_files, org_file_filter, output_file = (
self.config.input_files,
self.config.input_filter,
self.config.compressed_jsonl,
)
index_heading_entries = self.config.index_heading_entries
# Input Validation
@ -46,7 +50,9 @@ class OrgToJsonl(TextToJsonl):
if not previous_entries:
entries_with_ids = list(enumerate(current_entries))
else:
entries_with_ids = self.mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger)
entries_with_ids = self.mark_entries_for_update(
current_entries, previous_entries, key="compiled", logger=logger
)
# Process Each Entry from All Notes Files
with timer("Write org entries to JSONL file", logger):
@ -66,11 +72,7 @@ class OrgToJsonl(TextToJsonl):
"Get Org files to process"
absolute_org_files, filtered_org_files = set(), set()
if org_files:
absolute_org_files = {
get_absolute_path(org_file)
for org_file
in org_files
}
absolute_org_files = {get_absolute_path(org_file) for org_file in org_files}
if org_file_filters:
filtered_org_files = {
filtered_file
@ -84,7 +86,7 @@ class OrgToJsonl(TextToJsonl):
if any(files_with_non_org_extensions):
logger.warn(f"There maybe non org-mode files in the input set: {files_with_non_org_extensions}")
logger.info(f'Processing files: {all_org_files}')
logger.info(f"Processing files: {all_org_files}")
return all_org_files
@ -95,13 +97,15 @@ class OrgToJsonl(TextToJsonl):
entry_to_file_map = []
for org_file in org_files:
org_file_entries = orgnode.makelist(str(org_file))
entry_to_file_map += zip(org_file_entries, [org_file]*len(org_file_entries))
entry_to_file_map += zip(org_file_entries, [org_file] * len(org_file_entries))
entries.extend(org_file_entries)
return entries, dict(entry_to_file_map)
@staticmethod
def convert_org_nodes_to_entries(parsed_entries: List[orgnode.Orgnode], entry_to_file_map, index_heading_entries=False) -> List[Entry]:
def convert_org_nodes_to_entries(
parsed_entries: List[orgnode.Orgnode], entry_to_file_map, index_heading_entries=False
) -> List[Entry]:
"Convert Org-Mode nodes into list of Entry objects"
entries: List[Entry] = []
for parsed_entry in parsed_entries:
@ -109,13 +113,13 @@ class OrgToJsonl(TextToJsonl):
# Ignore title notes i.e notes with just headings and empty body
continue
compiled = f'{parsed_entry.heading}.'
compiled = f"{parsed_entry.heading}."
if state.verbose > 2:
logger.debug(f"Title: {parsed_entry.heading}")
if parsed_entry.tags:
tags_str = " ".join(parsed_entry.tags)
compiled += f'\t {tags_str}.'
compiled += f"\t {tags_str}."
if state.verbose > 2:
logger.debug(f"Tags: {tags_str}")
@ -130,19 +134,16 @@ class OrgToJsonl(TextToJsonl):
logger.debug(f'Scheduled: {parsed_entry.scheduled.strftime("%Y-%m-%d")}')
if parsed_entry.hasBody:
compiled += f'\n {parsed_entry.body}'
compiled += f"\n {parsed_entry.body}"
if state.verbose > 2:
logger.debug(f"Body: {parsed_entry.body}")
if compiled:
entries += [Entry(
compiled=compiled,
raw=f'{parsed_entry}',
file=f'{entry_to_file_map[parsed_entry]}')]
entries += [Entry(compiled=compiled, raw=f"{parsed_entry}", file=f"{entry_to_file_map[parsed_entry]}")]
return entries
@staticmethod
def convert_org_entries_to_jsonl(entries: Iterable[Entry]) -> str:
"Convert each Org-Mode entry to JSON and collate as JSONL"
return ''.join([f'{entry_dict.to_json()}\n' for entry_dict in entries])
return "".join([f"{entry_dict.to_json()}\n" for entry_dict in entries])

View file

@ -39,18 +39,20 @@ from pathlib import Path
from os.path import relpath
from typing import List
indent_regex = re.compile(r'^ *')
indent_regex = re.compile(r"^ *")
def normalize_filename(filename):
"Normalize and escape filename for rendering"
if not Path(filename).is_absolute():
# Normalize relative filename to be relative to current directory
normalized_filename = f'~/{relpath(filename, start=Path.home())}'
normalized_filename = f"~/{relpath(filename, start=Path.home())}"
else:
normalized_filename = filename
escaped_filename = f'{normalized_filename}'.replace("[","\[").replace("]","\]")
escaped_filename = f"{normalized_filename}".replace("[", "\[").replace("]", "\]")
return escaped_filename
def makelist(filename):
"""
Read an org-mode file and return a list of Orgnode objects
@ -58,124 +60,136 @@ def makelist(filename):
"""
ctr = 0
f = open(filename, 'r')
f = open(filename, "r")
todos = { "TODO": "", "WAITING": "", "ACTIVE": "",
"DONE": "", "CANCELLED": "", "FAILED": ""} # populated from #+SEQ_TODO line
todos = {
"TODO": "",
"WAITING": "",
"ACTIVE": "",
"DONE": "",
"CANCELLED": "",
"FAILED": "",
} # populated from #+SEQ_TODO line
level = ""
heading = ""
bodytext = ""
tags = list() # set of all tags in headline
closed_date = ''
sched_date = ''
deadline_date = ''
closed_date = ""
sched_date = ""
deadline_date = ""
logbook = list()
nodelist: List[Orgnode] = list()
property_map = dict()
in_properties_drawer = False
in_logbook_drawer = False
file_title = f'{filename}'
file_title = f"{filename}"
for line in f:
ctr += 1
heading_search = re.search(r'^(\*+)\s(.*?)\s*$', line)
heading_search = re.search(r"^(\*+)\s(.*?)\s*$", line)
if heading_search: # we are processing a heading line
if heading: # if we have are on second heading, append first heading to headings list
thisNode = Orgnode(level, heading, bodytext, tags)
if closed_date:
thisNode.closed = closed_date
closed_date = ''
closed_date = ""
if sched_date:
thisNode.scheduled = sched_date
sched_date = ""
if deadline_date:
thisNode.deadline = deadline_date
deadline_date = ''
deadline_date = ""
if logbook:
thisNode.logbook = logbook
logbook = list()
thisNode.properties = property_map
nodelist.append( thisNode )
property_map = {'LINE': f'file:{normalize_filename(filename)}::{ctr}'}
nodelist.append(thisNode)
property_map = {"LINE": f"file:{normalize_filename(filename)}::{ctr}"}
level = heading_search.group(1)
heading = heading_search.group(2)
bodytext = ""
tags = list() # set of all tags in headline
tag_search = re.search(r'(.*?)\s*:([a-zA-Z0-9].*?):$',heading)
tag_search = re.search(r"(.*?)\s*:([a-zA-Z0-9].*?):$", heading)
if tag_search:
heading = tag_search.group(1)
parsedtags = tag_search.group(2)
if parsedtags:
for parsedtag in parsedtags.split(':'):
if parsedtag != '': tags.append(parsedtag)
for parsedtag in parsedtags.split(":"):
if parsedtag != "":
tags.append(parsedtag)
else: # we are processing a non-heading line
if line[:10] == '#+SEQ_TODO':
kwlist = re.findall(r'([A-Z]+)\(', line)
for kw in kwlist: todos[kw] = ""
if line[:10] == "#+SEQ_TODO":
kwlist = re.findall(r"([A-Z]+)\(", line)
for kw in kwlist:
todos[kw] = ""
# Set file title to TITLE property, if it exists
title_search = re.search(r'^#\+TITLE:\s*(.*)$', line)
if title_search and title_search.group(1).strip() != '':
title_search = re.search(r"^#\+TITLE:\s*(.*)$", line)
if title_search and title_search.group(1).strip() != "":
title_text = title_search.group(1).strip()
if file_title == f'{filename}':
if file_title == f"{filename}":
file_title = title_text
else:
file_title += f' {title_text}'
file_title += f" {title_text}"
continue
# Ignore Properties Drawers Completely
if re.search(':PROPERTIES:', line):
in_properties_drawer=True
if re.search(":PROPERTIES:", line):
in_properties_drawer = True
continue
if in_properties_drawer and re.search(':END:', line):
in_properties_drawer=False
if in_properties_drawer and re.search(":END:", line):
in_properties_drawer = False
continue
# Ignore Logbook Drawer Start, End Lines
if re.search(':LOGBOOK:', line):
in_logbook_drawer=True
if re.search(":LOGBOOK:", line):
in_logbook_drawer = True
continue
if in_logbook_drawer and re.search(':END:', line):
in_logbook_drawer=False
if in_logbook_drawer and re.search(":END:", line):
in_logbook_drawer = False
continue
# Extract Clocking Lines
clocked_re = re.search(r'CLOCK:\s*\[([0-9]{4}-[0-9]{2}-[0-9]{2} [a-zA-Z]{3} [0-9]{2}:[0-9]{2})\]--\[([0-9]{4}-[0-9]{2}-[0-9]{2} [a-zA-Z]{3} [0-9]{2}:[0-9]{2})\]', line)
clocked_re = re.search(
r"CLOCK:\s*\[([0-9]{4}-[0-9]{2}-[0-9]{2} [a-zA-Z]{3} [0-9]{2}:[0-9]{2})\]--\[([0-9]{4}-[0-9]{2}-[0-9]{2} [a-zA-Z]{3} [0-9]{2}:[0-9]{2})\]",
line,
)
if clocked_re:
# convert clock in, clock out strings to datetime objects
clocked_in = datetime.datetime.strptime(clocked_re.group(1), '%Y-%m-%d %a %H:%M')
clocked_out = datetime.datetime.strptime(clocked_re.group(2), '%Y-%m-%d %a %H:%M')
clocked_in = datetime.datetime.strptime(clocked_re.group(1), "%Y-%m-%d %a %H:%M")
clocked_out = datetime.datetime.strptime(clocked_re.group(2), "%Y-%m-%d %a %H:%M")
# add clocked time to the entries logbook list
logbook += [(clocked_in, clocked_out)]
line = ""
property_search = re.search(r'^\s*:([a-zA-Z0-9]+):\s*(.*?)\s*$', line)
property_search = re.search(r"^\s*:([a-zA-Z0-9]+):\s*(.*?)\s*$", line)
if property_search:
# Set ID property to an id based org-mode link to the entry
if property_search.group(1) == 'ID':
property_map['ID'] = f'id:{property_search.group(2)}'
if property_search.group(1) == "ID":
property_map["ID"] = f"id:{property_search.group(2)}"
else:
property_map[property_search.group(1)] = property_search.group(2)
continue
cd_re = re.search(r'CLOSED:\s*\[([0-9]{4})-([0-9]{2})-([0-9]{2})', line)
cd_re = re.search(r"CLOSED:\s*\[([0-9]{4})-([0-9]{2})-([0-9]{2})", line)
if cd_re:
closed_date = datetime.date(int(cd_re.group(1)),
int(cd_re.group(2)),
int(cd_re.group(3)) )
sd_re = re.search(r'SCHEDULED:\s*<([0-9]+)\-([0-9]+)\-([0-9]+)', line)
closed_date = datetime.date(int(cd_re.group(1)), int(cd_re.group(2)), int(cd_re.group(3)))
sd_re = re.search(r"SCHEDULED:\s*<([0-9]+)\-([0-9]+)\-([0-9]+)", line)
if sd_re:
sched_date = datetime.date(int(sd_re.group(1)),
int(sd_re.group(2)),
int(sd_re.group(3)) )
dd_re = re.search(r'DEADLINE:\s*<(\d+)\-(\d+)\-(\d+)', line)
sched_date = datetime.date(int(sd_re.group(1)), int(sd_re.group(2)), int(sd_re.group(3)))
dd_re = re.search(r"DEADLINE:\s*<(\d+)\-(\d+)\-(\d+)", line)
if dd_re:
deadline_date = datetime.date(int(dd_re.group(1)),
int(dd_re.group(2)),
int(dd_re.group(3)) )
deadline_date = datetime.date(int(dd_re.group(1)), int(dd_re.group(2)), int(dd_re.group(3)))
# Ignore property drawer, scheduled, closed, deadline, logbook entries and # lines from body
if not in_properties_drawer and not cd_re and not sd_re and not dd_re and not clocked_re and line[:1] != '#':
if (
not in_properties_drawer
and not cd_re
and not sd_re
and not dd_re
and not clocked_re
and line[:1] != "#"
):
bodytext = bodytext + line
# write out last node
@ -189,39 +203,41 @@ def makelist(filename):
thisNode.closed = closed_date
if logbook:
thisNode.logbook = logbook
nodelist.append( thisNode )
nodelist.append(thisNode)
# using the list of TODO keywords found in the file
# process the headings searching for TODO keywords
for n in nodelist:
todo_search = re.search(r'([A-Z]+)\s(.*?)$', n.heading)
todo_search = re.search(r"([A-Z]+)\s(.*?)$", n.heading)
if todo_search:
if todo_search.group(1) in todos:
n.heading = todo_search.group(2)
n.todo = todo_search.group(1)
# extract, set priority from heading, update heading if necessary
priority_search = re.search(r'^\[\#(A|B|C)\] (.*?)$', n.heading)
priority_search = re.search(r"^\[\#(A|B|C)\] (.*?)$", n.heading)
if priority_search:
n.priority = priority_search.group(1)
n.heading = priority_search.group(2)
# Set SOURCE property to a file+heading based org-mode link to the entry
if n.level == 0:
n.properties['LINE'] = f'file:{normalize_filename(filename)}::0'
n.properties['SOURCE'] = f'[[file:{normalize_filename(filename)}]]'
n.properties["LINE"] = f"file:{normalize_filename(filename)}::0"
n.properties["SOURCE"] = f"[[file:{normalize_filename(filename)}]]"
else:
escaped_heading = n.heading.replace("[","\\[").replace("]","\\]")
n.properties['SOURCE'] = f'[[file:{normalize_filename(filename)}::*{escaped_heading}]]'
escaped_heading = n.heading.replace("[", "\\[").replace("]", "\\]")
n.properties["SOURCE"] = f"[[file:{normalize_filename(filename)}::*{escaped_heading}]]"
return nodelist
######################
class Orgnode(object):
"""
Orgnode class represents a headline, tags and text associated
with the headline.
"""
def __init__(self, level, headline, body, tags):
"""
Create an Orgnode object given the parameters of level (as the
@ -270,7 +286,7 @@ class Orgnode(object):
"""
Returns True if node has non empty body, else False
"""
return self._body and re.sub(r'\n|\t|\r| ', '', self._body) != ''
return self._body and re.sub(r"\n|\t|\r| ", "", self._body) != ""
@property
def level(self):
@ -417,20 +433,20 @@ class Orgnode(object):
text as used to construct the node.
"""
# Output heading line
n = ''
n = ""
for _ in range(0, self._level):
n = n + '*'
n = n + ' '
n = n + "*"
n = n + " "
if self._todo:
n = n + self._todo + ' '
n = n + self._todo + " "
if self._priority:
n = n + '[#' + self._priority + '] '
n = n + "[#" + self._priority + "] "
n = n + self._heading
n = "%-60s " % n # hack - tags will start in column 62
closecolon = ''
closecolon = ""
for t in self._tags:
n = n + ':' + t
closecolon = ':'
n = n + ":" + t
closecolon = ":"
n = n + closecolon
n = n + "\n"
@ -447,7 +463,7 @@ class Orgnode(object):
if self._deadline:
n = n + f'DEADLINE: <{self._deadline.strftime("%Y-%m-%d %a")}> '
if self._closed or self._scheduled or self._deadline:
n = n + '\n'
n = n + "\n"
# Ouput Property Drawer
n = n + indent + ":PROPERTIES:\n"

View file

@ -17,14 +17,17 @@ class TextToJsonl(ABC):
self.config = config
@abstractmethod
def process(self, previous_entries: List[Entry]=None) -> List[Tuple[int, Entry]]: ...
def process(self, previous_entries: List[Entry] = None) -> List[Tuple[int, Entry]]:
...
@staticmethod
def hash_func(key: str) -> Callable:
return lambda entry: hashlib.md5(bytes(getattr(entry, key), encoding='utf-8')).hexdigest()
return lambda entry: hashlib.md5(bytes(getattr(entry, key), encoding="utf-8")).hexdigest()
@staticmethod
def split_entries_by_max_tokens(entries: List[Entry], max_tokens: int=256, max_word_length: int=500) -> List[Entry]:
def split_entries_by_max_tokens(
entries: List[Entry], max_tokens: int = 256, max_word_length: int = 500
) -> List[Entry]:
"Split entries if compiled entry length exceeds the max tokens supported by the ML model."
chunked_entries: List[Entry] = []
for entry in entries:
@ -32,13 +35,15 @@ class TextToJsonl(ABC):
# Drop long words instead of having entry truncated to maintain quality of entry processed by models
compiled_entry_words = [word for word in compiled_entry_words if len(word) <= max_word_length]
for chunk_index in range(0, len(compiled_entry_words), max_tokens):
compiled_entry_words_chunk = compiled_entry_words[chunk_index:chunk_index + max_tokens]
compiled_entry_chunk = ' '.join(compiled_entry_words_chunk)
compiled_entry_words_chunk = compiled_entry_words[chunk_index : chunk_index + max_tokens]
compiled_entry_chunk = " ".join(compiled_entry_words_chunk)
entry_chunk = Entry(compiled=compiled_entry_chunk, raw=entry.raw, file=entry.file)
chunked_entries.append(entry_chunk)
return chunked_entries
def mark_entries_for_update(self, current_entries: List[Entry], previous_entries: List[Entry], key='compiled', logger=None) -> List[Tuple[int, Entry]]:
def mark_entries_for_update(
self, current_entries: List[Entry], previous_entries: List[Entry], key="compiled", logger=None
) -> List[Tuple[int, Entry]]:
# Hash all current and previous entries to identify new entries
with timer("Hash previous, current entries", logger):
current_entry_hashes = list(map(TextToJsonl.hash_func(key), current_entries))
@ -54,10 +59,7 @@ class TextToJsonl(ABC):
existing_entry_hashes = set(current_entry_hashes) & set(previous_entry_hashes)
# Mark new entries with -1 id to flag for later embeddings generation
new_entries = [
(-1, hash_to_current_entries[entry_hash])
for entry_hash in new_entry_hashes
]
new_entries = [(-1, hash_to_current_entries[entry_hash]) for entry_hash in new_entry_hashes]
# Set id of existing entries to their previous ids to reuse their existing encoded embeddings
existing_entries = [
(previous_entry_hashes.index(entry_hash), hash_to_previous_entries[entry_hash])

View file

@ -22,27 +22,30 @@ logger = logging.getLogger(__name__)
# Create Routes
@api.get('/config/data/default')
@api.get("/config/data/default")
def get_default_config_data():
return constants.default_config
@api.get('/config/data', response_model=FullConfig)
@api.get("/config/data", response_model=FullConfig)
def get_config_data():
return state.config
@api.post('/config/data')
@api.post("/config/data")
async def set_config_data(updated_config: FullConfig):
state.config = updated_config
with open(state.config_file, 'w') as outfile:
with open(state.config_file, "w") as outfile:
yaml.dump(yaml.safe_load(state.config.json(by_alias=True)), outfile)
outfile.close()
return state.config
@api.get('/search', response_model=List[SearchResponse])
@api.get("/search", response_model=List[SearchResponse])
def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Optional[bool] = False):
results: List[SearchResponse] = []
if q is None or q == '':
logger.info(f'No query param (q) passed in API call to initiate search')
if q is None or q == "":
logger.info(f"No query param (q) passed in API call to initiate search")
return results
# initialize variables
@ -50,9 +53,9 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
results_count = n
# return cached results, if available
query_cache_key = f'{user_query}-{n}-{t}-{r}'
query_cache_key = f"{user_query}-{n}-{t}-{r}"
if query_cache_key in state.query_cache:
logger.info(f'Return response from query cache')
logger.info(f"Return response from query cache")
return state.query_cache[query_cache_key]
if (t == SearchType.Org or t == None) and state.model.orgmode_search:
@ -95,7 +98,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
# query images
with timer("Query took", logger):
hits = image_search.query(user_query, results_count, state.model.image_search)
output_directory = constants.web_directory / 'images'
output_directory = constants.web_directory / "images"
# collate and return results
with timer("Collating results took", logger):
@ -103,8 +106,9 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
hits,
image_names=state.model.image_search.image_names,
output_directory=output_directory,
image_files_url='/static/images',
count=results_count)
image_files_url="/static/images",
count=results_count,
)
# Cache results
state.query_cache[query_cache_key] = results
@ -112,7 +116,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
return results
@api.get('/update')
@api.get("/update")
def update(t: Optional[SearchType] = None, force: Optional[bool] = False):
try:
state.search_index_lock.acquire()
@ -132,4 +136,4 @@ def update(t: Optional[SearchType] = None, force: Optional[bool] = False):
else:
logger.info("Processor reconfigured via API call")
return {'status': 'ok', 'message': 'khoj reloaded'}
return {"status": "ok", "message": "khoj reloaded"}

View file

@ -9,7 +9,14 @@ from fastapi import APIRouter
# Internal Packages
from khoj.routers.api import search
from khoj.processor.conversation.gpt import converse, extract_search_type, message_to_log, message_to_prompt, understand, summarize
from khoj.processor.conversation.gpt import (
converse,
extract_search_type,
message_to_log,
message_to_prompt,
understand,
summarize,
)
from khoj.utils.config import SearchType
from khoj.utils.helpers import get_from_dict, resolve_absolute_path
from khoj.utils import state
@ -21,7 +28,7 @@ logger = logging.getLogger(__name__)
# Create Routes
@api_beta.get('/search')
@api_beta.get("/search")
def search_beta(q: str, n: Optional[int] = 1):
# Initialize Variables
model = state.processor_config.conversation.model
@ -32,16 +39,16 @@ def search_beta(q: str, n: Optional[int] = 1):
metadata = extract_search_type(q, model=model, api_key=api_key, verbose=state.verbose)
search_type = get_from_dict(metadata, "search-type")
except Exception as e:
return {'status': 'error', 'result': [str(e)], 'type': None}
return {"status": "error", "result": [str(e)], "type": None}
# Search
search_results = search(q, n=n, t=SearchType(search_type))
# Return response
return {'status': 'ok', 'result': search_results, 'type': search_type}
return {"status": "ok", "result": search_results, "type": search_type}
@api_beta.get('/summarize')
@api_beta.get("/summarize")
def summarize_beta(q: str):
# Initialize Variables
model = state.processor_config.conversation.model
@ -54,23 +61,25 @@ def summarize_beta(q: str):
# Converse with OpenAI GPT
result_list = search(q, n=1, r=True)
collated_result = "\n".join([item.entry for item in result_list])
logger.debug(f'Semantically Similar Notes:\n{collated_result}')
logger.debug(f"Semantically Similar Notes:\n{collated_result}")
try:
gpt_response = summarize(collated_result, summary_type="notes", user_query=q, model=model, api_key=api_key)
status = 'ok'
status = "ok"
except Exception as e:
gpt_response = str(e)
status = 'error'
status = "error"
# Update Conversation History
state.processor_config.conversation.chat_session = message_to_prompt(q, chat_session, gpt_message=gpt_response)
state.processor_config.conversation.meta_log['chat'] = message_to_log(q, gpt_response, conversation_log=meta_log.get('chat', []))
state.processor_config.conversation.meta_log["chat"] = message_to_log(
q, gpt_response, conversation_log=meta_log.get("chat", [])
)
return {'status': status, 'response': gpt_response}
return {"status": status, "response": gpt_response}
@api_beta.get('/chat')
def chat(q: Optional[str]=None):
@api_beta.get("/chat")
def chat(q: Optional[str] = None):
# Initialize Variables
model = state.processor_config.conversation.model
api_key = state.processor_config.conversation.openai_api_key
@ -81,10 +90,10 @@ def chat(q: Optional[str]=None):
# If user query is empty, return chat history
if not q:
if meta_log.get('chat'):
return {'status': 'ok', 'response': meta_log["chat"]}
if meta_log.get("chat"):
return {"status": "ok", "response": meta_log["chat"]}
else:
return {'status': 'ok', 'response': []}
return {"status": "ok", "response": []}
# Converse with OpenAI GPT
metadata = understand(q, model=model, api_key=api_key, verbose=state.verbose)
@ -94,32 +103,39 @@ def chat(q: Optional[str]=None):
query = get_from_dict(metadata, "intent", "query")
result_list = search(query, n=1, t=SearchType.Org, r=True)
collated_result = "\n".join([item.entry for item in result_list])
logger.debug(f'Semantically Similar Notes:\n{collated_result}')
logger.debug(f"Semantically Similar Notes:\n{collated_result}")
try:
gpt_response = summarize(collated_result, summary_type="notes", user_query=q, model=model, api_key=api_key)
status = 'ok'
status = "ok"
except Exception as e:
gpt_response = str(e)
status = 'error'
status = "error"
else:
try:
gpt_response = converse(q, model, chat_session, api_key=api_key)
status = 'ok'
status = "ok"
except Exception as e:
gpt_response = str(e)
status = 'error'
status = "error"
# Update Conversation History
state.processor_config.conversation.chat_session = message_to_prompt(q, chat_session, gpt_message=gpt_response)
state.processor_config.conversation.meta_log['chat'] = message_to_log(q, gpt_response, metadata, meta_log.get('chat', []))
state.processor_config.conversation.meta_log["chat"] = message_to_log(
q, gpt_response, metadata, meta_log.get("chat", [])
)
return {'status': status, 'response': gpt_response}
return {"status": status, "response": gpt_response}
@schedule.repeat(schedule.every(5).minutes)
def save_chat_session():
# No need to create empty log file
if not (state.processor_config and state.processor_config.conversation and state.processor_config.conversation.meta_log and state.processor_config.conversation.chat_session):
if not (
state.processor_config
and state.processor_config.conversation
and state.processor_config.conversation.meta_log
and state.processor_config.conversation.chat_session
):
return
# Summarize Conversation Logs for this Session
@ -130,19 +146,19 @@ def save_chat_session():
session = {
"summary": summarize(chat_session, summary_type="chat", model=model, api_key=openai_api_key),
"session-start": conversation_log.get("session", [{"session-end": 0}])[-1]["session-end"],
"session-end": len(conversation_log["chat"])
"session-end": len(conversation_log["chat"]),
}
if 'session' in conversation_log:
conversation_log['session'].append(session)
if "session" in conversation_log:
conversation_log["session"].append(session)
else:
conversation_log['session'] = [session]
logger.info('Added new chat session to conversation logs')
conversation_log["session"] = [session]
logger.info("Added new chat session to conversation logs")
# Save Conversation Metadata Logs to Disk
conversation_logfile = resolve_absolute_path(state.processor_config.conversation.conversation_logfile)
conversation_logfile.parent.mkdir(parents=True, exist_ok=True) # create conversation directory if doesn't exist
with open(conversation_logfile, "w+", encoding='utf-8') as logfile:
with open(conversation_logfile, "w+", encoding="utf-8") as logfile:
json.dump(conversation_log, logfile)
state.processor_config.conversation.chat_session = None
logger.info('Saved updated conversation logs to disk.')
logger.info("Saved updated conversation logs to disk.")

View file

@ -18,9 +18,11 @@ templates = Jinja2Templates(directory=constants.web_directory)
def index():
return FileResponse(constants.web_directory / "index.html")
@web_client.get('/config', response_class=HTMLResponse)
@web_client.get("/config", response_class=HTMLResponse)
def config_page(request: Request):
return templates.TemplateResponse("config.html", context={'request': request})
return templates.TemplateResponse("config.html", context={"request": request})
@web_client.get("/chat", response_class=FileResponse)
def chat_page():

View file

@ -8,10 +8,13 @@ from khoj.utils.rawconfig import Entry
class BaseFilter(ABC):
@abstractmethod
def load(self, entries: List[Entry], *args, **kwargs): ...
def load(self, entries: List[Entry], *args, **kwargs):
...
@abstractmethod
def can_filter(self, raw_query:str) -> bool: ...
def can_filter(self, raw_query: str) -> bool:
...
@abstractmethod
def apply(self, query:str, entries: List[Entry]) -> Tuple[str, Set[int]]: ...
def apply(self, query: str, entries: List[Entry]) -> Tuple[str, Set[int]]:
...
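For reference, a minimal sketch of a concrete filter implementing this interface. The TagFilter name and its tag:"..." syntax are hypothetical, and the base class is assumed to be importable as khoj.search_filter.base_filter.

import re
from typing import List, Set, Tuple

from khoj.search_filter.base_filter import BaseFilter
from khoj.utils.rawconfig import Entry


class TagFilter(BaseFilter):
    # Hypothetical filter: keep entries containing any tag:"..." term from the query
    tag_regex = r'tag:"(.+?)" ?'

    def load(self, entries: List[Entry], *args, **kwargs):
        pass  # this toy filter needs no precomputed index

    def can_filter(self, raw_query: str) -> bool:
        return re.search(self.tag_regex, raw_query) is not None

    def apply(self, query: str, entries: List[Entry]) -> Tuple[str, Set[int]]:
        tags = re.findall(self.tag_regex, query)
        query = re.sub(self.tag_regex, "", query).strip()
        included = {id for id, entry in enumerate(entries) if any(tag in entry.raw for tag in tags)}
        return query, included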

View file

@ -26,21 +26,19 @@ class DateFilter(BaseFilter):
# - dt:"2 years ago"
date_regex = r"dt([:><=]{1,2})\"(.*?)\""
def __init__(self, entry_key='raw'):
def __init__(self, entry_key="raw"):
self.entry_key = entry_key
self.date_to_entry_ids = defaultdict(set)
self.cache = LRU()
def load(self, entries, *args, **kwargs):
with timer("Created date filter index", logger):
for id, entry in enumerate(entries):
# Extract dates from entry
for date_in_entry_string in re.findall(r'\d{4}-\d{2}-\d{2}', getattr(entry, self.entry_key)):
for date_in_entry_string in re.findall(r"\d{4}-\d{2}-\d{2}", getattr(entry, self.entry_key)):
# Convert date string in entry to unix timestamp
try:
date_in_entry = datetime.strptime(date_in_entry_string, '%Y-%m-%d').timestamp()
date_in_entry = datetime.strptime(date_in_entry_string, "%Y-%m-%d").timestamp()
except ValueError:
continue
self.date_to_entry_ids[date_in_entry].add(id)
@ -49,7 +47,6 @@ class DateFilter(BaseFilter):
"Check if query contains date filters"
return self.extract_date_range(raw_query) is not None
def apply(self, query, entries):
"Find entries containing any dates that fall within date range specified in query"
# extract date range specified in date filter of query
@ -61,8 +58,8 @@ class DateFilter(BaseFilter):
return query, set(range(len(entries)))
# remove date range filter from query
query = re.sub(rf'\s+{self.date_regex}', ' ', query)
query = re.sub(r'\s{2,}', ' ', query).strip() # remove multiple spaces
query = re.sub(rf"\s+{self.date_regex}", " ", query)
query = re.sub(r"\s{2,}", " ", query).strip() # remove multiple spaces
# return results from cache if exists
cache_key = tuple(query_daterange)
@ -87,7 +84,6 @@ class DateFilter(BaseFilter):
return query, entries_to_include
def extract_date_range(self, query):
# find date range filter in query
date_range_matches = re.findall(self.date_regex, query)
@ -98,7 +94,7 @@ class DateFilter(BaseFilter):
# extract, parse natural dates ranges from date range filter passed in query
# e.g today maps to (start_of_day, start_of_tomorrow)
date_ranges_from_filter = []
for (cmp, date_str) in date_range_matches:
for cmp, date_str in date_range_matches:
if self.parse(date_str):
dt_start, dt_end = self.parse(date_str)
date_ranges_from_filter += [[cmp, (dt_start.timestamp(), dt_end.timestamp())]]
@ -111,15 +107,15 @@ class DateFilter(BaseFilter):
effective_date_range = [0, inf]
date_range_considering_comparator = []
for cmp, (dtrange_start, dtrange_end) in date_ranges_from_filter:
if cmp == '>':
if cmp == ">":
date_range_considering_comparator += [[dtrange_end, inf]]
elif cmp == '>=':
elif cmp == ">=":
date_range_considering_comparator += [[dtrange_start, inf]]
elif cmp == '<':
elif cmp == "<":
date_range_considering_comparator += [[0, dtrange_start]]
elif cmp == '<=':
elif cmp == "<=":
date_range_considering_comparator += [[0, dtrange_end]]
elif cmp == '=' or cmp == ':' or cmp == '==':
elif cmp == "=" or cmp == ":" or cmp == "==":
date_range_considering_comparator += [[dtrange_start, dtrange_end]]
# Combine above intervals (via AND/intersect)
@ -129,48 +125,48 @@ class DateFilter(BaseFilter):
for date_range in date_range_considering_comparator:
effective_date_range = [
max(effective_date_range[0], date_range[0]),
min(effective_date_range[1], date_range[1])]
min(effective_date_range[1], date_range[1]),
]
if effective_date_range == [0, inf] or effective_date_range[0] > effective_date_range[1]:
return None
else:
return effective_date_range
def parse(self, date_str, relative_base=None):
"Parse date string passed in date filter of query to datetime object"
# clean date string to handle future date parsing by date parser
future_strings = ['later', 'from now', 'from today']
prefer_dates_from = {True: 'future', False: 'past'}[any([True for fstr in future_strings if fstr in date_str])]
clean_date_str = re.sub('|'.join(future_strings), '', date_str)
future_strings = ["later", "from now", "from today"]
prefer_dates_from = {True: "future", False: "past"}[any([True for fstr in future_strings if fstr in date_str])]
clean_date_str = re.sub("|".join(future_strings), "", date_str)
# parse date passed in query date filter
parsed_date = dtparse.parse(
clean_date_str,
settings= {
'RELATIVE_BASE': relative_base or datetime.now(),
'PREFER_DAY_OF_MONTH': 'first',
'PREFER_DATES_FROM': prefer_dates_from
})
settings={
"RELATIVE_BASE": relative_base or datetime.now(),
"PREFER_DAY_OF_MONTH": "first",
"PREFER_DATES_FROM": prefer_dates_from,
},
)
if parsed_date is None:
return None
return self.date_to_daterange(parsed_date, date_str)
def date_to_daterange(self, parsed_date, date_str):
"Convert parsed date to date ranges at natural granularity (day, week, month or year)"
start_of_day = parsed_date.replace(hour=0, minute=0, second=0, microsecond=0)
if 'year' in date_str:
return (datetime(parsed_date.year, 1, 1, 0, 0, 0), datetime(parsed_date.year+1, 1, 1, 0, 0, 0))
if 'month' in date_str:
if "year" in date_str:
return (datetime(parsed_date.year, 1, 1, 0, 0, 0), datetime(parsed_date.year + 1, 1, 1, 0, 0, 0))
if "month" in date_str:
start_of_month = datetime(parsed_date.year, parsed_date.month, 1, 0, 0, 0)
next_month = start_of_month + relativedelta(months=1)
return (start_of_month, next_month)
if 'week' in date_str:
if "week" in date_str:
# if week in date string, dateparser parses it to next week start
# so today = end of this week
start_of_week = start_of_day - timedelta(days=7)
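For reference, a minimal sketch of the date filter above on a hypothetical query; the comparators and ISO dates are handled by extract_date_range as shown in the hunks.

from khoj.search_filter.date_filter import DateFilter

date_filter = DateFilter()
query = 'dentist appointment dt>="2022-01-01" dt<"2023-01-01"'

assert date_filter.can_filter(query)
# Effective [start, end] range as unix timestamps, or None if no valid date filter
date_range = date_filter.extract_date_range(query)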

View file

@ -16,7 +16,7 @@ logger = logging.getLogger(__name__)
class FileFilter(BaseFilter):
file_filter_regex = r'file:"(.+?)" ?'
def __init__(self, entry_key='file'):
def __init__(self, entry_key="file"):
self.entry_key = entry_key
self.file_to_entry_map = defaultdict(set)
self.cache = LRU()
@ -40,13 +40,13 @@ class FileFilter(BaseFilter):
# e.g. "file:notes.org" -> "file:.*notes.org"
files_to_search = []
for file in sorted(raw_files_to_search):
if '/' not in file and '\\' not in file and '*' not in file:
files_to_search += [f'*{file}']
if "/" not in file and "\\" not in file and "*" not in file:
files_to_search += [f"*{file}"]
else:
files_to_search += [file]
# Return item from cache if exists
query = re.sub(self.file_filter_regex, '', query).strip()
query = re.sub(self.file_filter_regex, "", query).strip()
cache_key = tuple(files_to_search)
if cache_key in self.cache:
logger.info(f"Return file filter results from cache")
@ -58,10 +58,15 @@ class FileFilter(BaseFilter):
# Mark entries that contain any blocked_words for exclusion
with timer("Mark entries satisfying filter", logger):
included_entry_indices = set.union(*[self.file_to_entry_map[entry_file]
included_entry_indices = set.union(
*[
self.file_to_entry_map[entry_file]
for entry_file in self.file_to_entry_map.keys()
for search_file in files_to_search
if fnmatch.fnmatch(entry_file, search_file)], set())
if fnmatch.fnmatch(entry_file, search_file)
],
set(),
)
if not included_entry_indices:
return query, {}
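For reference, a minimal sketch of the file filter above on hypothetical entries; note how a bare file name is expanded to a glob before matching, as in the hunk above.

from khoj.search_filter.file_filter import FileFilter
from khoj.utils.rawconfig import Entry

entries = [
    Entry(raw="* Trip to Goa", compiled="* Trip to Goa", file="travel.org"),
    Entry(raw="* Sprint planning", compiled="* Sprint planning", file="work.org"),
]

file_filter = FileFilter()
file_filter.load(entries)

# "travel.org" becomes the glob "*travel.org" before matching entry files
cleaned_query, included_ids = file_filter.apply('trip notes file:"travel.org"', entries)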

View file

@ -17,26 +17,26 @@ class WordFilter(BaseFilter):
required_regex = r'\+"([a-zA-Z0-9_-]+)" ?'
blocked_regex = r'\-"([a-zA-Z0-9_-]+)" ?'
def __init__(self, entry_key='raw'):
def __init__(self, entry_key="raw"):
self.entry_key = entry_key
self.word_to_entry_index = defaultdict(set)
self.cache = LRU()
def load(self, entries, *args, **kwargs):
with timer("Created word filter index", logger):
self.cache = {} # Clear cache on filter (re-)load
entry_splitter = r',|\.| |\]|\[\(|\)|\{|\}|\<|\>|\t|\n|\:|\;|\?|\!|\(|\)|\&|\^|\$|\@|\%|\+|\=|\/|\\|\||\~|\`|\"|\''
entry_splitter = (
r",|\.| |\]|\[\(|\)|\{|\}|\<|\>|\t|\n|\:|\;|\?|\!|\(|\)|\&|\^|\$|\@|\%|\+|\=|\/|\\|\||\~|\`|\"|\'"
)
# Create map of words to entries they exist in
for entry_index, entry in enumerate(entries):
for word in re.split(entry_splitter, getattr(entry, self.entry_key).lower()):
if word == '':
if word == "":
continue
self.word_to_entry_index[word].add(entry_index)
return self.word_to_entry_index
def can_filter(self, raw_query):
"Check if query contains word filters"
required_words = re.findall(self.required_regex, raw_query)
@ -44,14 +44,13 @@ class WordFilter(BaseFilter):
return len(required_words) != 0 or len(blocked_words) != 0
def apply(self, query, entries):
"Find entries containing required and not blocked words specified in query"
# Separate natural query from required, blocked words filters
with timer("Extract required, blocked filters from query", logger):
required_words = set([word.lower() for word in re.findall(self.required_regex, query)])
blocked_words = set([word.lower() for word in re.findall(self.blocked_regex, query)])
query = re.sub(self.blocked_regex, '', re.sub(self.required_regex, '', query)).strip()
query = re.sub(self.blocked_regex, "", re.sub(self.required_regex, "", query)).strip()
if len(required_words) == 0 and len(blocked_words) == 0:
return query, set(range(len(entries)))
@ -70,12 +69,16 @@ class WordFilter(BaseFilter):
with timer("Mark entries satisfying filter", logger):
entries_with_all_required_words = set(range(len(entries)))
if len(required_words) > 0:
entries_with_all_required_words = set.intersection(*[self.word_to_entry_index.get(word, set()) for word in required_words])
entries_with_all_required_words = set.intersection(
*[self.word_to_entry_index.get(word, set()) for word in required_words]
)
# mark entries that contain any blocked_words for exclusion
entries_with_any_blocked_words = set()
if len(blocked_words) > 0:
entries_with_any_blocked_words = set.union(*[self.word_to_entry_index.get(word, set()) for word in blocked_words])
entries_with_any_blocked_words = set.union(
*[self.word_to_entry_index.get(word, set()) for word in blocked_words]
)
# get entries satisfying inclusion and exclusion filters
included_entry_indices = entries_with_all_required_words - entries_with_any_blocked_words
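For reference, a minimal sketch of the word filter above on hypothetical entries; required words use +"word" and blocked words use -"word".

from khoj.search_filter.word_filter import WordFilter
from khoj.utils.rawconfig import Entry

entries = [
    Entry(raw="Dentist appointment on Monday", compiled="Dentist appointment on Monday", file="notes.org"),
    Entry(raw="Dentist bill paid via bank", compiled="Dentist bill paid via bank", file="ledger.org"),
]

word_filter = WordFilter()
word_filter.load(entries)

# Keep entries containing "dentist" but drop those containing "bill"
cleaned_query, included_ids = word_filter.apply('+"dentist" -"bill" when is my appointment?', entries)
assert included_ids == {0}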

View file

@ -35,9 +35,10 @@ def initialize_model(search_config: ImageSearchConfig):
# Load the CLIP model
encoder = load_model(
model_dir = search_config.model_directory,
model_name = search_config.encoder,
model_type = search_config.encoder_type or SentenceTransformer)
model_dir=search_config.model_directory,
model_name=search_config.encoder,
model_type=search_config.encoder_type or SentenceTransformer,
)
return encoder
@ -46,12 +47,12 @@ def extract_entries(image_directories):
image_names = []
for image_directory in image_directories:
image_directory = resolve_absolute_path(image_directory, strict=True)
image_names.extend(list(image_directory.glob('*.jpg')))
image_names.extend(list(image_directory.glob('*.jpeg')))
image_names.extend(list(image_directory.glob("*.jpg")))
image_names.extend(list(image_directory.glob("*.jpeg")))
if logger.level >= logging.INFO:
image_directory_names = ', '.join([str(image_directory) for image_directory in image_directories])
logger.info(f'Found {len(image_names)} images in {image_directory_names}')
image_directory_names = ", ".join([str(image_directory) for image_directory in image_directories])
logger.info(f"Found {len(image_names)} images in {image_directory_names}")
return sorted(image_names)
@ -59,7 +60,9 @@ def compute_embeddings(image_names, encoder, embeddings_file, batch_size=50, use
"Compute (and Save) Embeddings or Load Pre-Computed Embeddings"
image_embeddings = compute_image_embeddings(image_names, encoder, embeddings_file, batch_size, regenerate)
image_metadata_embeddings = compute_metadata_embeddings(image_names, encoder, embeddings_file, batch_size, use_xmp_metadata, regenerate)
image_metadata_embeddings = compute_metadata_embeddings(
image_names, encoder, embeddings_file, batch_size, use_xmp_metadata, regenerate
)
return image_embeddings, image_metadata_embeddings
@ -74,15 +77,12 @@ def compute_image_embeddings(image_names, encoder, embeddings_file, batch_size=5
image_embeddings = []
for index in trange(0, len(image_names), batch_size):
images = []
for image_name in image_names[index:index+batch_size]:
for image_name in image_names[index : index + batch_size]:
image = Image.open(image_name)
# Resize images to max width of 640px for faster processing
image.thumbnail((640, image.height))
images += [image]
image_embeddings += encoder.encode(
images,
convert_to_tensor=True,
batch_size=min(len(images), batch_size))
image_embeddings += encoder.encode(images, convert_to_tensor=True, batch_size=min(len(images), batch_size))
# Create directory for embeddings file, if it doesn't exist
embeddings_file.parent.mkdir(parents=True, exist_ok=True)
@ -94,7 +94,9 @@ def compute_image_embeddings(image_names, encoder, embeddings_file, batch_size=5
return image_embeddings
def compute_metadata_embeddings(image_names, encoder, embeddings_file, batch_size=50, use_xmp_metadata=False, regenerate=False, verbose=0):
def compute_metadata_embeddings(
image_names, encoder, embeddings_file, batch_size=50, use_xmp_metadata=False, regenerate=False, verbose=0
):
image_metadata_embeddings = None
# Load pre-computed image metadata embedding file if exists
@ -106,14 +108,17 @@ def compute_metadata_embeddings(image_names, encoder, embeddings_file, batch_siz
if use_xmp_metadata and image_metadata_embeddings is None:
image_metadata_embeddings = []
for index in trange(0, len(image_names), batch_size):
image_metadata = [extract_metadata(image_name, verbose) for image_name in image_names[index:index+batch_size]]
image_metadata = [
extract_metadata(image_name, verbose) for image_name in image_names[index : index + batch_size]
]
try:
image_metadata_embeddings += encoder.encode(
image_metadata,
convert_to_tensor=True,
batch_size=min(len(image_metadata), batch_size))
image_metadata, convert_to_tensor=True, batch_size=min(len(image_metadata), batch_size)
)
except RuntimeError as e:
logger.error(f"Error encoding metadata for images starting from\n\tindex: {index},\n\timages: {image_names[index:index+batch_size]}\nException: {e}")
logger.error(
f"Error encoding metadata for images starting from\n\tindex: {index},\n\timages: {image_names[index:index+batch_size]}\nException: {e}"
)
continue
torch.save(image_metadata_embeddings, f"{embeddings_file}_metadata")
logger.info(f"Saved computed metadata embeddings to {embeddings_file}_metadata")
@ -123,8 +128,10 @@ def compute_metadata_embeddings(image_names, encoder, embeddings_file, batch_siz
def extract_metadata(image_name):
image_xmp_metadata = Image.open(image_name).getxmp()
image_description = get_from_dict(image_xmp_metadata, 'xmpmeta', 'RDF', 'Description', 'description', 'Alt', 'li', 'text')
image_subjects = get_from_dict(image_xmp_metadata, 'xmpmeta', 'RDF', 'Description', 'subject', 'Bag', 'li')
image_description = get_from_dict(
image_xmp_metadata, "xmpmeta", "RDF", "Description", "description", "Alt", "li", "text"
)
image_subjects = get_from_dict(image_xmp_metadata, "xmpmeta", "RDF", "Description", "subject", "Bag", "li")
image_metadata_subjects = set([subject.split(":")[1] for subject in image_subjects if ":" in subject])
image_processed_metadata = image_description
@ -155,36 +162,42 @@ def query(raw_query, count, model: ImageSearchModel):
# Compute top_k ranked images based on cosine-similarity b/w query and all image embeddings.
with timer("Search Time", logger):
image_hits = {result['corpus_id']: {'image_score': result['score'], 'score': result['score']}
for result
in util.semantic_search(query_embedding, model.image_embeddings, top_k=count)[0]}
image_hits = {
result["corpus_id"]: {"image_score": result["score"], "score": result["score"]}
for result in util.semantic_search(query_embedding, model.image_embeddings, top_k=count)[0]
}
# Compute top_k ranked images based on cosine-similarity b/w query and all image metadata embeddings.
if model.image_metadata_embeddings:
with timer("Metadata Search Time", logger):
metadata_hits = {result['corpus_id']: result['score']
for result
in util.semantic_search(query_embedding, model.image_metadata_embeddings, top_k=count)[0]}
metadata_hits = {
result["corpus_id"]: result["score"]
for result in util.semantic_search(query_embedding, model.image_metadata_embeddings, top_k=count)[0]
}
# Sum metadata, image scores of the highest ranked images
for corpus_id, score in metadata_hits.items():
scaling_factor = 0.33
if 'corpus_id' in image_hits:
image_hits[corpus_id].update({
'metadata_score': score,
'score': image_hits[corpus_id].get('score', 0) + scaling_factor*score,
})
if "corpus_id" in image_hits:
image_hits[corpus_id].update(
{
"metadata_score": score,
"score": image_hits[corpus_id].get("score", 0) + scaling_factor * score,
}
)
else:
image_hits[corpus_id] = {'metadata_score': score, 'score': scaling_factor*score}
image_hits[corpus_id] = {"metadata_score": score, "score": scaling_factor * score}
# Reformat results in original form from sentence transformer semantic_search()
hits = [
{
'corpus_id': corpus_id,
'score': scores['score'],
'image_score': scores.get('image_score', 0),
'metadata_score': scores.get('metadata_score', 0),
} for corpus_id, scores in image_hits.items()]
"corpus_id": corpus_id,
"score": scores["score"],
"image_score": scores.get("image_score", 0),
"metadata_score": scores.get("metadata_score", 0),
}
for corpus_id, scores in image_hits.items()
]
# Sort the images based on their combined metadata, image scores
return sorted(hits, key=lambda hit: hit["score"], reverse=True)
@ -194,7 +207,7 @@ def collate_results(hits, image_names, output_directory, image_files_url, count=
results: List[SearchResponse] = []
for index, hit in enumerate(hits[:count]):
source_path = image_names[hit['corpus_id']]
source_path = image_names[hit["corpus_id"]]
target_image_name = f"{index}{source_path.suffix}"
target_path = resolve_absolute_path(f"{output_directory}/{target_image_name}")
@ -207,17 +220,18 @@ def collate_results(hits, image_names, output_directory, image_files_url, count=
shutil.copy(source_path, target_path)
# Add the image metadata to the results
results += [SearchResponse.parse_obj(
results += [
SearchResponse.parse_obj(
{
"entry": f'{image_files_url}/{target_image_name}',
"entry": f"{image_files_url}/{target_image_name}",
"score": f"{hit['score']:.9f}",
"additional":
{
"additional": {
"image_score": f"{hit['image_score']:.9f}",
"metadata_score": f"{hit['metadata_score']:.9f}",
},
}
}
)]
)
]
return results
@ -248,9 +262,7 @@ def setup(config: ImageContentConfig, search_config: ImageSearchConfig, regenera
embeddings_file,
batch_size=config.batch_size,
regenerate=regenerate,
use_xmp_metadata=config.use_xmp_metadata)
use_xmp_metadata=config.use_xmp_metadata,
)
return ImageSearchModel(all_image_files,
image_embeddings,
image_metadata_embeddings,
encoder)
return ImageSearchModel(all_image_files, image_embeddings, image_metadata_embeddings, encoder)

View file

@ -38,17 +38,19 @@ def initialize_model(search_config: TextSearchConfig):
# The bi-encoder encodes all entries to use for semantic search
bi_encoder = load_model(
model_dir = search_config.model_directory,
model_name = search_config.encoder,
model_type = search_config.encoder_type or SentenceTransformer,
device=f'{state.device}')
model_dir=search_config.model_directory,
model_name=search_config.encoder,
model_type=search_config.encoder_type or SentenceTransformer,
device=f"{state.device}",
)
# The cross-encoder re-ranks the results to improve quality
cross_encoder = load_model(
model_dir = search_config.model_directory,
model_name = search_config.cross_encoder,
model_type = CrossEncoder,
device=f'{state.device}')
model_dir=search_config.model_directory,
model_name=search_config.cross_encoder,
model_type=CrossEncoder,
device=f"{state.device}",
)
return bi_encoder, cross_encoder, top_k
@ -58,7 +60,9 @@ def extract_entries(jsonl_file) -> List[Entry]:
return list(map(Entry.from_dict, load_jsonl(jsonl_file)))
def compute_embeddings(entries_with_ids: List[Tuple[int, Entry]], bi_encoder: BaseEncoder, embeddings_file: Path, regenerate=False):
def compute_embeddings(
entries_with_ids: List[Tuple[int, Entry]], bi_encoder: BaseEncoder, embeddings_file: Path, regenerate=False
):
"Compute (and Save) Embeddings or Load Pre-Computed Embeddings"
new_entries = []
# Load pre-computed embeddings from file if exists and update them if required
@ -69,17 +73,23 @@ def compute_embeddings(entries_with_ids: List[Tuple[int, Entry]], bi_encoder: Ba
# Encode any new entries in the corpus and update corpus embeddings
new_entries = [entry.compiled for id, entry in entries_with_ids if id == -1]
if new_entries:
new_embeddings = bi_encoder.encode(new_entries, convert_to_tensor=True, device=state.device, show_progress_bar=True)
new_embeddings = bi_encoder.encode(
new_entries, convert_to_tensor=True, device=state.device, show_progress_bar=True
)
existing_entry_ids = [id for id, _ in entries_with_ids if id != -1]
if existing_entry_ids:
existing_embeddings = torch.index_select(corpus_embeddings, 0, torch.tensor(existing_entry_ids, device=state.device))
existing_embeddings = torch.index_select(
corpus_embeddings, 0, torch.tensor(existing_entry_ids, device=state.device)
)
else:
existing_embeddings = torch.tensor([], device=state.device)
corpus_embeddings = torch.cat([existing_embeddings, new_embeddings], dim=0)
# Else compute the corpus embeddings from scratch
else:
new_entries = [entry.compiled for _, entry in entries_with_ids]
corpus_embeddings = bi_encoder.encode(new_entries, convert_to_tensor=True, device=state.device, show_progress_bar=True)
corpus_embeddings = bi_encoder.encode(
new_entries, convert_to_tensor=True, device=state.device, show_progress_bar=True
)
# Save regenerated or updated embeddings to file
if new_entries:
@ -112,7 +122,9 @@ def query(raw_query: str, model: TextSearchModel, rank_results: bool = False) ->
# Find relevant entries for the query
with timer("Search Time", logger, state.device):
hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=model.top_k, score_function=util.dot_score)[0]
hits = util.semantic_search(
question_embedding, corpus_embeddings, top_k=model.top_k, score_function=util.dot_score
)[0]
# Score all retrieved entries using the cross-encoder
if rank_results:
@ -128,26 +140,33 @@ def query(raw_query: str, model: TextSearchModel, rank_results: bool = False) ->
def collate_results(hits, entries: List[Entry], count=5) -> List[SearchResponse]:
return [SearchResponse.parse_obj(
return [
SearchResponse.parse_obj(
{
"entry": entries[hit['corpus_id']].raw,
"entry": entries[hit["corpus_id"]].raw,
"score": f"{hit['cross-score'] if 'cross-score' in hit else hit['score']:.3f}",
"additional": {
"file": entries[hit['corpus_id']].file,
"compiled": entries[hit['corpus_id']].compiled
"additional": {"file": entries[hit["corpus_id"]].file, "compiled": entries[hit["corpus_id"]].compiled},
}
})
for hit
in hits[0:count]]
)
for hit in hits[0:count]
]
def setup(text_to_jsonl: Type[TextToJsonl], config: TextContentConfig, search_config: TextSearchConfig, regenerate: bool, filters: List[BaseFilter] = []) -> TextSearchModel:
def setup(
text_to_jsonl: Type[TextToJsonl],
config: TextContentConfig,
search_config: TextSearchConfig,
regenerate: bool,
filters: List[BaseFilter] = [],
) -> TextSearchModel:
# Initialize Model
bi_encoder, cross_encoder, top_k = initialize_model(search_config)
# Map notes in text files to (compressed) JSONL formatted file
config.compressed_jsonl = resolve_absolute_path(config.compressed_jsonl)
previous_entries = extract_entries(config.compressed_jsonl) if config.compressed_jsonl.exists() and not regenerate else None
previous_entries = (
extract_entries(config.compressed_jsonl) if config.compressed_jsonl.exists() and not regenerate else None
)
entries_with_indices = text_to_jsonl(config).process(previous_entries)
# Extract Updated Entries
@ -158,7 +177,9 @@ def setup(text_to_jsonl: Type[TextToJsonl], config: TextContentConfig, search_co
# Compute or Load Embeddings
config.embeddings_file = resolve_absolute_path(config.embeddings_file)
corpus_embeddings = compute_embeddings(entries_with_indices, bi_encoder, config.embeddings_file, regenerate=regenerate)
corpus_embeddings = compute_embeddings(
entries_with_indices, bi_encoder, config.embeddings_file, regenerate=regenerate
)
for filter in filters:
filter.load(entries, regenerate=regenerate)
@ -166,8 +187,10 @@ def setup(text_to_jsonl: Type[TextToJsonl], config: TextContentConfig, search_co
return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, filters, top_k)
def apply_filters(query: str, entries: List[Entry], corpus_embeddings: torch.Tensor, filters: List[BaseFilter]) -> Tuple[str, List[Entry], torch.Tensor]:
'''Filter query, entries and embeddings before semantic search'''
def apply_filters(
query: str, entries: List[Entry], corpus_embeddings: torch.Tensor, filters: List[BaseFilter]
) -> Tuple[str, List[Entry], torch.Tensor]:
"""Filter query, entries and embeddings before semantic search"""
with timer("Total Filter Time", logger, state.device):
included_entry_indices = set(range(len(entries)))
@ -178,45 +201,50 @@ def apply_filters(query: str, entries: List[Entry], corpus_embeddings: torch.Ten
# Get entries (and associated embeddings) satisfying all filters
if not included_entry_indices:
return '', [], torch.tensor([], device=state.device)
return "", [], torch.tensor([], device=state.device)
else:
entries = [entries[id] for id in included_entry_indices]
corpus_embeddings = torch.index_select(corpus_embeddings, 0, torch.tensor(list(included_entry_indices), device=state.device))
corpus_embeddings = torch.index_select(
corpus_embeddings, 0, torch.tensor(list(included_entry_indices), device=state.device)
)
return query, entries, corpus_embeddings
def cross_encoder_score(cross_encoder: CrossEncoder, query: str, entries: List[Entry], hits: List[dict]) -> List[dict]:
'''Score all retrieved entries using the cross-encoder'''
"""Score all retrieved entries using the cross-encoder"""
with timer("Cross-Encoder Predict Time", logger, state.device):
cross_inp = [[query, entries[hit['corpus_id']].compiled] for hit in hits]
cross_inp = [[query, entries[hit["corpus_id"]].compiled] for hit in hits]
cross_scores = cross_encoder.predict(cross_inp)
# Store cross-encoder scores in results dictionary for ranking
for idx in range(len(cross_scores)):
hits[idx]['cross-score'] = cross_scores[idx]
hits[idx]["cross-score"] = cross_scores[idx]
return hits
def sort_results(rank_results: bool, hits: List[dict]) -> List[dict]:
'''Order results by cross-encoder score followed by bi-encoder score'''
"""Order results by cross-encoder score followed by bi-encoder score"""
with timer("Rank Time", logger, state.device):
hits.sort(key=lambda x: x['score'], reverse=True) # sort by bi-encoder score
hits.sort(key=lambda x: x["score"], reverse=True) # sort by bi-encoder score
if rank_results:
hits.sort(key=lambda x: x['cross-score'], reverse=True) # sort by cross-encoder score
hits.sort(key=lambda x: x["cross-score"], reverse=True) # sort by cross-encoder score
return hits
def deduplicate_results(entries: List[Entry], hits: List[dict]) -> List[dict]:
'''Deduplicate entries by raw entry text before showing to users
"""Deduplicate entries by raw entry text before showing to users
Compiled entries are split by max tokens supported by ML models.
This can result in duplicate hits, entries shown to user.'''
This can result in duplicate hits, entries shown to user."""
with timer("Deduplication Time", logger, state.device):
seen, original_hits_count = set(), len(hits)
hits = [hit for hit in hits
if entries[hit['corpus_id']].raw not in seen and not seen.add(entries[hit['corpus_id']].raw)] # type: ignore[func-returns-value]
hits = [
hit
for hit in hits
if entries[hit["corpus_id"]].raw not in seen and not seen.add(entries[hit["corpus_id"]].raw) # type: ignore[func-returns-value]
]
duplicate_hits = original_hits_count - len(hits)
logger.debug(f"Removed {duplicate_hits} duplicates")

View file

@ -10,21 +10,36 @@ from khoj.utils.yaml import parse_config_from_file
def cli(args=None):
# Setup Argument Parser for the Commandline Interface
parser = argparse.ArgumentParser(description="Start Khoj; A Natural Language Search Engine for your personal Notes, Transactions and Photos")
parser.add_argument('--config-file', '-c', default='~/.khoj/khoj.yml', type=pathlib.Path, help="YAML file to configure Khoj")
parser.add_argument('--no-gui', action='store_true', default=False, help="Do not show native desktop GUI. Default: false")
parser.add_argument('--regenerate', action='store_true', default=False, help="Regenerate model embeddings from source files. Default: false")
parser.add_argument('--verbose', '-v', action='count', default=0, help="Show verbose conversion logs. Default: 0")
parser.add_argument('--host', type=str, default='127.0.0.1', help="Host address of the server. Default: 127.0.0.1")
parser.add_argument('--port', '-p', type=int, default=8000, help="Port of the server. Default: 8000")
parser.add_argument('--socket', type=pathlib.Path, help="Path to UNIX socket for server. Use to run server behind reverse proxy. Default: /tmp/uvicorn.sock")
parser.add_argument('--version', '-V', action='store_true', help="Print the installed Khoj version and exit")
parser = argparse.ArgumentParser(
description="Start Khoj; A Natural Language Search Engine for your personal Notes, Transactions and Photos"
)
parser.add_argument(
"--config-file", "-c", default="~/.khoj/khoj.yml", type=pathlib.Path, help="YAML file to configure Khoj"
)
parser.add_argument(
"--no-gui", action="store_true", default=False, help="Do not show native desktop GUI. Default: false"
)
parser.add_argument(
"--regenerate",
action="store_true",
default=False,
help="Regenerate model embeddings from source files. Default: false",
)
parser.add_argument("--verbose", "-v", action="count", default=0, help="Show verbose conversion logs. Default: 0")
parser.add_argument("--host", type=str, default="127.0.0.1", help="Host address of the server. Default: 127.0.0.1")
parser.add_argument("--port", "-p", type=int, default=8000, help="Port of the server. Default: 8000")
parser.add_argument(
"--socket",
type=pathlib.Path,
help="Path to UNIX socket for server. Use to run server behind reverse proxy. Default: /tmp/uvicorn.sock",
)
parser.add_argument("--version", "-V", action="store_true", help="Print the installed Khoj version and exit")
args = parser.parse_args(args)
if args.version:
# Show version of khoj installed and exit
print(version('khoj-assistant'))
print(version("khoj-assistant"))
exit(0)
# Normalize config_file path to absolute path

View file

@ -28,8 +28,16 @@ class ProcessorType(str, Enum):
Conversation = "conversation"
class TextSearchModel():
def __init__(self, entries: List[Entry], corpus_embeddings: torch.Tensor, bi_encoder: BaseEncoder, cross_encoder: CrossEncoder, filters: List[BaseFilter], top_k):
class TextSearchModel:
def __init__(
self,
entries: List[Entry],
corpus_embeddings: torch.Tensor,
bi_encoder: BaseEncoder,
cross_encoder: CrossEncoder,
filters: List[BaseFilter],
top_k,
):
self.entries = entries
self.corpus_embeddings = corpus_embeddings
self.bi_encoder = bi_encoder
@ -38,7 +46,7 @@ class TextSearchModel():
self.top_k = top_k
class ImageSearchModel():
class ImageSearchModel:
def __init__(self, image_names, image_embeddings, image_metadata_embeddings, image_encoder: BaseEncoder):
self.image_encoder = image_encoder
self.image_names = image_names
@ -48,7 +56,7 @@ class ImageSearchModel():
@dataclass
class SearchModels():
class SearchModels:
orgmode_search: TextSearchModel = None
ledger_search: TextSearchModel = None
music_search: TextSearchModel = None
@ -56,15 +64,15 @@ class SearchModels():
image_search: ImageSearchModel = None
class ConversationProcessorConfigModel():
class ConversationProcessorConfigModel:
def __init__(self, processor_config: ConversationProcessorConfig):
self.openai_api_key = processor_config.openai_api_key
self.model = processor_config.model
self.conversation_logfile = Path(processor_config.conversation_logfile)
self.chat_session = ''
self.chat_session = ""
self.meta_log: dict = {}
@dataclass
class ProcessorConfigModel():
class ProcessorConfigModel:
conversation: ConversationProcessorConfigModel = None

View file

@ -1,65 +1,62 @@
from pathlib import Path
app_root_directory = Path(__file__).parent.parent.parent
web_directory = app_root_directory / 'khoj/interface/web/'
empty_escape_sequences = '\n|\r|\t| '
web_directory = app_root_directory / "khoj/interface/web/"
empty_escape_sequences = "\n|\r|\t| "
# default app config to use
default_config = {
'content-type': {
'org': {
'input-files': None,
'input-filter': None,
'compressed-jsonl': '~/.khoj/content/org/org.jsonl.gz',
'embeddings-file': '~/.khoj/content/org/org_embeddings.pt',
'index_heading_entries': False
"content-type": {
"org": {
"input-files": None,
"input-filter": None,
"compressed-jsonl": "~/.khoj/content/org/org.jsonl.gz",
"embeddings-file": "~/.khoj/content/org/org_embeddings.pt",
"index_heading_entries": False,
},
'markdown': {
'input-files': None,
'input-filter': None,
'compressed-jsonl': '~/.khoj/content/markdown/markdown.jsonl.gz',
'embeddings-file': '~/.khoj/content/markdown/markdown_embeddings.pt'
"markdown": {
"input-files": None,
"input-filter": None,
"compressed-jsonl": "~/.khoj/content/markdown/markdown.jsonl.gz",
"embeddings-file": "~/.khoj/content/markdown/markdown_embeddings.pt",
},
'ledger': {
'input-files': None,
'input-filter': None,
'compressed-jsonl': '~/.khoj/content/ledger/ledger.jsonl.gz',
'embeddings-file': '~/.khoj/content/ledger/ledger_embeddings.pt'
"ledger": {
"input-files": None,
"input-filter": None,
"compressed-jsonl": "~/.khoj/content/ledger/ledger.jsonl.gz",
"embeddings-file": "~/.khoj/content/ledger/ledger_embeddings.pt",
},
'image': {
'input-directories': None,
'input-filter': None,
'embeddings-file': '~/.khoj/content/image/image_embeddings.pt',
'batch-size': 50,
'use-xmp-metadata': False
"image": {
"input-directories": None,
"input-filter": None,
"embeddings-file": "~/.khoj/content/image/image_embeddings.pt",
"batch-size": 50,
"use-xmp-metadata": False,
},
'music': {
'input-files': None,
'input-filter': None,
'compressed-jsonl': '~/.khoj/content/music/music.jsonl.gz',
'embeddings-file': '~/.khoj/content/music/music_embeddings.pt'
"music": {
"input-files": None,
"input-filter": None,
"compressed-jsonl": "~/.khoj/content/music/music.jsonl.gz",
"embeddings-file": "~/.khoj/content/music/music_embeddings.pt",
},
},
"search-type": {
"symmetric": {
"encoder": "sentence-transformers/all-MiniLM-L6-v2",
"cross-encoder": "cross-encoder/ms-marco-MiniLM-L-6-v2",
"model_directory": "~/.khoj/search/symmetric/",
},
"asymmetric": {
"encoder": "sentence-transformers/multi-qa-MiniLM-L6-cos-v1",
"cross-encoder": "cross-encoder/ms-marco-MiniLM-L-6-v2",
"model_directory": "~/.khoj/search/asymmetric/",
},
"image": {"encoder": "sentence-transformers/clip-ViT-B-32", "model_directory": "~/.khoj/search/image/"},
},
"processor": {
"conversation": {
"openai-api-key": None,
"conversation-logfile": "~/.khoj/processor/conversation/conversation_logs.json",
}
},
'search-type': {
'symmetric': {
'encoder': 'sentence-transformers/all-MiniLM-L6-v2',
'cross-encoder': 'cross-encoder/ms-marco-MiniLM-L-6-v2',
'model_directory': '~/.khoj/search/symmetric/'
},
'asymmetric': {
'encoder': 'sentence-transformers/multi-qa-MiniLM-L6-cos-v1',
'cross-encoder': 'cross-encoder/ms-marco-MiniLM-L-6-v2',
'model_directory': '~/.khoj/search/asymmetric/'
},
'image': {
'encoder': 'sentence-transformers/clip-ViT-B-32',
'model_directory': '~/.khoj/search/image/'
}
},
'processor': {
'conversation': {
'openai-api-key': None,
'conversation-logfile': '~/.khoj/processor/conversation/conversation_logs.json'
}
}
}
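For reference, a minimal sketch of reading nested keys out of default_config with the null-aware get_from_dict helper (assuming this module is khoj.utils.constants); the missing "pdf" key is just an example of a lookup that falls through to None.

from khoj.utils.constants import default_config
from khoj.utils.helpers import get_from_dict

get_from_dict(default_config, "content-type", "org", "compressed-jsonl")
# -> "~/.khoj/content/org/org.jsonl.gz"

get_from_dict(default_config, "content-type", "pdf")
# -> None, since the key is missing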

View file

@ -13,16 +13,17 @@ from typing import Optional, Union, TYPE_CHECKING
if TYPE_CHECKING:
# External Packages
from sentence_transformers import CrossEncoder
# Internal Packages
from khoj.utils.models import BaseEncoder
def is_none_or_empty(item):
return item == None or (hasattr(item, '__iter__') and len(item) == 0) or item == ''
return item == None or (hasattr(item, "__iter__") and len(item) == 0) or item == ""
def to_snake_case_from_dash(item: str):
return item.replace('_', '-')
return item.replace("_", "-")
def get_absolute_path(filepath: Union[str, Path]) -> str:
@ -34,11 +35,11 @@ def resolve_absolute_path(filepath: Union[str, Optional[Path]], strict=False) ->
def get_from_dict(dictionary, *args):
'''null-aware get from a nested dictionary
Returns: dictionary[args[0]][args[1]]... or None if any keys missing'''
"""null-aware get from a nested dictionary
Returns: dictionary[args[0]][args[1]]... or None if any keys missing"""
current = dictionary
for arg in args:
if not hasattr(current, '__iter__') or not arg in current:
if not hasattr(current, "__iter__") or not arg in current:
return None
current = current[arg]
return current
@ -54,7 +55,7 @@ def merge_dicts(priority_dict: dict, default_dict: dict):
return merged_dict
def load_model(model_name: str, model_type, model_dir=None, device:str=None) -> Union[BaseEncoder, CrossEncoder]:
def load_model(model_name: str, model_type, model_dir=None, device: str = None) -> Union[BaseEncoder, CrossEncoder]:
"Load model from disk or huggingface"
# Construct model path
model_path = join(model_dir, model_name.replace("/", "_")) if model_dir is not None else None
@ -74,17 +75,18 @@ def load_model(model_name: str, model_type, model_dir=None, device:str=None) ->
def is_pyinstaller_app():
"Returns true if the app is running from Native GUI created by PyInstaller"
return getattr(sys, 'frozen', False) and hasattr(sys, '_MEIPASS')
return getattr(sys, "frozen", False) and hasattr(sys, "_MEIPASS")
def get_class_by_name(name: str) -> object:
"Returns the class object from name string"
module_name, class_name = name.rsplit('.', 1)
module_name, class_name = name.rsplit(".", 1)
return getattr(import_module(module_name), class_name)
class timer:
'''Context manager to log time taken for a block of code to run'''
"""Context manager to log time taken for a block of code to run"""
def __init__(self, message: str, logger: logging.Logger, device: torch.device = None):
self.message = message
self.logger = logger
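For reference, a minimal sketch of the timer context manager and get_from_dict helpers above; the workload is a stand-in.

import logging

from khoj.utils.helpers import get_from_dict, timer

logger = logging.getLogger(__name__)

with timer("Computed squares", logger):
    squares = [i * i for i in range(10_000)]  # stand-in workload; elapsed time is logged on exit

assert get_from_dict({"intent": {"query": "notes"}}, "intent", "query") == "notes"
assert get_from_dict({"intent": {}}, "intent", "memory-type") is None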

View file

@ -19,9 +19,9 @@ def load_jsonl(input_path):
# Open JSONL file
if input_path.suffix == ".gz":
jsonl_file = gzip.open(get_absolute_path(input_path), 'rt', encoding='utf-8')
jsonl_file = gzip.open(get_absolute_path(input_path), "rt", encoding="utf-8")
elif input_path.suffix == ".jsonl":
jsonl_file = open(get_absolute_path(input_path), 'r', encoding='utf-8')
jsonl_file = open(get_absolute_path(input_path), "r", encoding="utf-8")
# Read JSONL file
for line in jsonl_file:
@ -31,7 +31,7 @@ def load_jsonl(input_path):
jsonl_file.close()
# Log JSONL entries loaded
logger.info(f'Loaded {len(data)} records from {input_path}')
logger.info(f"Loaded {len(data)} records from {input_path}")
return data
@ -41,17 +41,17 @@ def dump_jsonl(jsonl_data, output_path):
# Create output directory, if it doesn't exist
output_path.parent.mkdir(parents=True, exist_ok=True)
with open(output_path, 'w', encoding='utf-8') as f:
with open(output_path, "w", encoding="utf-8") as f:
f.write(jsonl_data)
logger.info(f'Wrote jsonl data to {output_path}')
logger.info(f"Wrote jsonl data to {output_path}")
def compress_jsonl_data(jsonl_data, output_path):
# Create output directory, if it doesn't exist
output_path.parent.mkdir(parents=True, exist_ok=True)
with gzip.open(output_path, 'wt', encoding='utf-8') as gzip_file:
with gzip.open(output_path, "wt", encoding="utf-8") as gzip_file:
gzip_file.write(jsonl_data)
logger.info(f'Wrote jsonl data to gzip compressed jsonl at {output_path}')
logger.info(f"Wrote jsonl data to gzip compressed jsonl at {output_path}")

View file

@ -13,17 +13,25 @@ from khoj.utils.state import processor_config, config_file
class BaseEncoder(ABC):
@abstractmethod
def __init__(self, model_name: str, device: torch.device=None, **kwargs): ...
def __init__(self, model_name: str, device: torch.device = None, **kwargs):
...
@abstractmethod
def encode(self, entries: List[str], device:torch.device=None, **kwargs) -> torch.Tensor: ...
def encode(self, entries: List[str], device: torch.device = None, **kwargs) -> torch.Tensor:
...
class OpenAI(BaseEncoder):
def __init__(self, model_name, device=None):
self.model_name = model_name
if not processor_config or not processor_config.conversation or not processor_config.conversation.openai_api_key:
raise Exception(f"Set OpenAI API key under processor-config > conversation > openai-api-key in config file: {config_file}")
if (
not processor_config
or not processor_config.conversation
or not processor_config.conversation.openai_api_key
):
raise Exception(
f"Set OpenAI API key under processor-config > conversation > openai-api-key in config file: {config_file}"
)
openai.api_key = processor_config.conversation.openai_api_key
self.embedding_dimensions = None
@ -32,7 +40,7 @@ class OpenAI(BaseEncoder):
for index in trange(0, len(entries)):
# OpenAI models create better embeddings for entries without newlines
processed_entry = entries[index].replace('\n', ' ')
processed_entry = entries[index].replace("\n", " ")
try:
response = openai.Embedding.create(input=processed_entry, model=self.model_name)
@ -41,7 +49,9 @@ class OpenAI(BaseEncoder):
# Else default to embedding dimensions of the text-embedding-ada-002 model
self.embedding_dimensions = len(response.data[0].embedding) if not self.embedding_dimensions else 1536
except Exception as e:
print(f"Failed to encode entry {index} of length: {len(entries[index])}\n\n{entries[index][:1000]}...\n\n{e}")
print(
f"Failed to encode entry {index} of length: {len(entries[index])}\n\n{entries[index][:1000]}...\n\n{e}"
)
# Use zero embedding vector for entries with failed embeddings
# This ensures entry embeddings match the order of the source entries
# And they have minimal similarity to other entries (as zero vectors are always orthogonal to other vector)

View file

@ -9,11 +9,13 @@ from pydantic import BaseModel, validator
# Internal Packages
from khoj.utils.helpers import to_snake_case_from_dash, is_none_or_empty
class ConfigBase(BaseModel):
class Config:
alias_generator = to_snake_case_from_dash
allow_population_by_field_name = True
class TextContentConfig(ConfigBase):
input_files: Optional[List[Path]]
input_filter: Optional[List[str]]
@ -21,12 +23,15 @@ class TextContentConfig(ConfigBase):
embeddings_file: Path
index_heading_entries: Optional[bool] = False
@validator('input_filter')
@validator("input_filter")
def input_filter_or_files_required(cls, input_filter, values, **kwargs):
if is_none_or_empty(input_filter) and ('input_files' not in values or values["input_files"] is None):
raise ValueError("Either input_filter or input_files required in all content-type.<text_search> section of Khoj config file")
if is_none_or_empty(input_filter) and ("input_files" not in values or values["input_files"] is None):
raise ValueError(
"Either input_filter or input_files required in all content-type.<text_search> section of Khoj config file"
)
return input_filter
class ImageContentConfig(ConfigBase):
input_directories: Optional[List[Path]]
input_filter: Optional[List[str]]
@ -34,12 +39,17 @@ class ImageContentConfig(ConfigBase):
use_xmp_metadata: bool
batch_size: int
@validator('input_filter')
@validator("input_filter")
def input_filter_or_directories_required(cls, input_filter, values, **kwargs):
if is_none_or_empty(input_filter) and ('input_directories' not in values or values["input_directories"] is None):
raise ValueError("Either input_filter or input_directories required in all content-type.image section of Khoj config file")
if is_none_or_empty(input_filter) and (
"input_directories" not in values or values["input_directories"] is None
):
raise ValueError(
"Either input_filter or input_directories required in all content-type.image section of Khoj config file"
)
return input_filter
class ContentConfig(ConfigBase):
org: Optional[TextContentConfig]
ledger: Optional[TextContentConfig]
@ -47,41 +57,49 @@ class ContentConfig(ConfigBase):
music: Optional[TextContentConfig]
markdown: Optional[TextContentConfig]
class TextSearchConfig(ConfigBase):
encoder: str
cross_encoder: str
encoder_type: Optional[str]
model_directory: Optional[Path]
class ImageSearchConfig(ConfigBase):
encoder: str
encoder_type: Optional[str]
model_directory: Optional[Path]
class SearchConfig(ConfigBase):
asymmetric: Optional[TextSearchConfig]
symmetric: Optional[TextSearchConfig]
image: Optional[ImageSearchConfig]
class ConversationProcessorConfig(ConfigBase):
openai_api_key: str
conversation_logfile: Path
model: Optional[str] = "text-davinci-003"
class ProcessorConfig(ConfigBase):
conversation: Optional[ConversationProcessorConfig]
class FullConfig(ConfigBase):
content_type: Optional[ContentConfig]
search_type: Optional[SearchConfig]
processor: Optional[ProcessorConfig]
class SearchResponse(ConfigBase):
entry: str
score: str
additional: Optional[dict]
class Entry():
class Entry:
raw: str
compiled: str
file: Optional[str]
@ -99,8 +117,4 @@ class Entry():
@classmethod
def from_dict(cls, dictionary: dict):
return cls(
raw=dictionary['raw'],
compiled=dictionary['compiled'],
file=dictionary.get('file', None)
)
return cls(raw=dictionary["raw"], compiled=dictionary["compiled"], file=dictionary.get("file", None))

View file

@ -17,14 +17,14 @@ def save_config_to_file(yaml_config: dict, yaml_config_file: Path):
# Create output directory, if it doesn't exist
yaml_config_file.parent.mkdir(parents=True, exist_ok=True)
with open(yaml_config_file, 'w', encoding='utf-8') as config_file:
with open(yaml_config_file, "w", encoding="utf-8") as config_file:
yaml.safe_dump(yaml_config, config_file, allow_unicode=True)
def load_config_from_file(yaml_config_file: Path) -> dict:
"Read config from YML file"
config_from_file = None
with open(yaml_config_file, 'r', encoding='utf-8') as config_file:
with open(yaml_config_file, "r", encoding="utf-8") as config_file:
config_from_file = yaml.safe_load(config_file)
return config_from_file

View file

@ -6,59 +6,67 @@ import pytest
# Internal Packages
from khoj.search_type import image_search, text_search
from khoj.utils.helpers import resolve_absolute_path
from khoj.utils.rawconfig import ContentConfig, TextContentConfig, ImageContentConfig, SearchConfig, TextSearchConfig, ImageSearchConfig
from khoj.utils.rawconfig import (
ContentConfig,
TextContentConfig,
ImageContentConfig,
SearchConfig,
TextSearchConfig,
ImageSearchConfig,
)
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
from khoj.search_filter.date_filter import DateFilter
from khoj.search_filter.word_filter import WordFilter
from khoj.search_filter.file_filter import FileFilter
@pytest.fixture(scope='session')
@pytest.fixture(scope="session")
def search_config() -> SearchConfig:
model_dir = resolve_absolute_path('~/.khoj/search')
model_dir = resolve_absolute_path("~/.khoj/search")
model_dir.mkdir(parents=True, exist_ok=True)
search_config = SearchConfig()
search_config.symmetric = TextSearchConfig(
encoder = "sentence-transformers/all-MiniLM-L6-v2",
cross_encoder = "cross-encoder/ms-marco-MiniLM-L-6-v2",
model_directory = model_dir / 'symmetric/'
encoder="sentence-transformers/all-MiniLM-L6-v2",
cross_encoder="cross-encoder/ms-marco-MiniLM-L-6-v2",
model_directory=model_dir / "symmetric/",
)
search_config.asymmetric = TextSearchConfig(
encoder = "sentence-transformers/multi-qa-MiniLM-L6-cos-v1",
cross_encoder = "cross-encoder/ms-marco-MiniLM-L-6-v2",
model_directory = model_dir / 'asymmetric/'
encoder="sentence-transformers/multi-qa-MiniLM-L6-cos-v1",
cross_encoder="cross-encoder/ms-marco-MiniLM-L-6-v2",
model_directory=model_dir / "asymmetric/",
)
search_config.image = ImageSearchConfig(
encoder = "sentence-transformers/clip-ViT-B-32",
model_directory = model_dir / 'image/'
encoder="sentence-transformers/clip-ViT-B-32", model_directory=model_dir / "image/"
)
return search_config
@pytest.fixture(scope='session')
@pytest.fixture(scope="session")
def content_config(tmp_path_factory, search_config: SearchConfig):
content_dir = tmp_path_factory.mktemp('content')
content_dir = tmp_path_factory.mktemp("content")
# Generate Image Embeddings from Test Images
content_config = ContentConfig()
content_config.image = ImageContentConfig(
input_directories = ['tests/data/images'],
embeddings_file = content_dir.joinpath('image_embeddings.pt'),
batch_size = 1,
use_xmp_metadata = False)
input_directories=["tests/data/images"],
embeddings_file=content_dir.joinpath("image_embeddings.pt"),
batch_size=1,
use_xmp_metadata=False,
)
image_search.setup(content_config.image, search_config.image, regenerate=False)
# Generate Notes Embeddings from Test Notes
content_config.org = TextContentConfig(
input_files = None,
input_filter = ['tests/data/org/*.org'],
compressed_jsonl = content_dir.joinpath('notes.jsonl.gz'),
embeddings_file = content_dir.joinpath('note_embeddings.pt'))
input_files=None,
input_filter=["tests/data/org/*.org"],
compressed_jsonl=content_dir.joinpath("notes.jsonl.gz"),
embeddings_file=content_dir.joinpath("note_embeddings.pt"),
)
filters = [DateFilter(), WordFilter(), FileFilter()]
text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters)
@ -66,7 +74,7 @@ def content_config(tmp_path_factory, search_config: SearchConfig):
return content_config
@pytest.fixture(scope='function')
@pytest.fixture(scope="function")
def new_org_file(content_config: ContentConfig):
# Setup
new_org_file = Path(content_config.org.input_filter[0]).parent / "new_file.org"
@ -79,9 +87,9 @@ def new_org_file(content_config: ContentConfig):
new_org_file.unlink()
@pytest.fixture(scope='function')
@pytest.fixture(scope="function")
def org_config_with_only_new_file(content_config: ContentConfig, new_org_file: Path):
new_org_config = deepcopy(content_config.org)
new_org_config.input_files = [f'{new_org_file}']
new_org_config.input_files = [f"{new_org_file}"]
new_org_config.input_filter = None
return new_org_config

View file

@ -8,10 +8,10 @@ from khoj.processor.ledger.beancount_to_jsonl import BeancountToJsonl
def test_no_transactions_in_file(tmp_path):
"Handle file with no transactions."
# Arrange
entry = f'''
entry = f"""
- Bullet point 1
- Bullet point 2
'''
"""
beancount_file = create_file(tmp_path, entry)
# Act
@ -20,7 +20,8 @@ def test_no_transactions_in_file(tmp_path):
# Process Each Entry from All Beancount Files
jsonl_string = BeancountToJsonl.convert_transaction_maps_to_jsonl(
BeancountToJsonl.convert_transactions_to_maps(entry_nodes, file_to_entries))
BeancountToJsonl.convert_transactions_to_maps(entry_nodes, file_to_entries)
)
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
# Assert
@ -30,11 +31,11 @@ def test_no_transactions_in_file(tmp_path):
def test_single_beancount_transaction_to_jsonl(tmp_path):
"Convert transaction from single file to jsonl."
# Arrange
entry = f'''
entry = f"""
1984-04-01 * "Payee" "Narration"
Expenses:Test:Test 1.00 KES
Assets:Test:Test -1.00 KES
'''
"""
beancount_file = create_file(tmp_path, entry)
# Act
@ -43,7 +44,8 @@ Assets:Test:Test -1.00 KES
# Process Each Entry from All Beancount Files
jsonl_string = BeancountToJsonl.convert_transaction_maps_to_jsonl(
BeancountToJsonl.convert_transactions_to_maps(entries, entry_to_file_map))
BeancountToJsonl.convert_transactions_to_maps(entries, entry_to_file_map)
)
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
# Assert
@ -53,7 +55,7 @@ Assets:Test:Test -1.00 KES
def test_multiple_transactions_to_jsonl(tmp_path):
"Convert multiple transactions from single file to jsonl."
# Arrange
entry = f'''
entry = f"""
1984-04-01 * "Payee" "Narration"
Expenses:Test:Test 1.00 KES
Assets:Test:Test -1.00 KES
@ -61,7 +63,7 @@ Assets:Test:Test -1.00 KES
1984-04-01 * "Payee" "Narration"
Expenses:Test:Test 1.00 KES
Assets:Test:Test -1.00 KES
'''
"""
beancount_file = create_file(tmp_path, entry)
@ -71,7 +73,8 @@ Assets:Test:Test -1.00 KES
# Process Each Entry from All Beancount Files
jsonl_string = BeancountToJsonl.convert_transaction_maps_to_jsonl(
BeancountToJsonl.convert_transactions_to_maps(entries, entry_to_file_map))
BeancountToJsonl.convert_transactions_to_maps(entries, entry_to_file_map)
)
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
# Assert
@ -95,8 +98,8 @@ def test_get_beancount_files(tmp_path):
expected_files = sorted(map(str, [group1_file1, group1_file2, group2_file1, group2_file2, file1]))
# Setup input-files, input-filters
input_files = [tmp_path / 'ledger.bean']
input_filter = [tmp_path / 'group1*.bean', tmp_path / 'group2*.beancount']
input_files = [tmp_path / "ledger.bean"]
input_filter = [tmp_path / "group1*.bean", tmp_path / "group2*.beancount"]
# Act
extracted_org_files = BeancountToJsonl.get_beancount_files(input_files, input_filter)

View file

@ -6,7 +6,7 @@ from khoj.processor.conversation.gpt import converse, understand, message_to_pro
# Initialize variables for tests
model = 'text-davinci-003'
model = "text-davinci-003"
api_key = None # Input your OpenAI API key to run the tests below
@ -14,19 +14,22 @@ api_key = None # Input your OpenAI API key to run the tests below
# ----------------------------------------------------------------------------------------------------
def test_message_to_understand_prompt():
# Arrange
understand_primer = "Extract information from each chat message\n\nremember(memory-type, data);\nmemory-type=[\"companion\", \"notes\", \"ledger\", \"image\", \"music\"]\nsearch(search-type, data);\nsearch-type=[\"google\", \"youtube\"]\ngenerate(activity);\nactivity=[\"paint\",\"write\", \"chat\"]\ntrigger-emotion(emotion);\nemotion=[\"happy\",\"confidence\",\"fear\",\"surprise\",\"sadness\",\"disgust\",\"anger\", \"curiosity\", \"calm\"]\n\nQ: How are you doing?\nA: activity(\"chat\"); trigger-emotion(\"surprise\")\nQ: Do you remember what I told you about my brother Antoine when we were at the beach?\nA: remember(\"notes\", \"Brother Antoine when we were at the beach\"); trigger-emotion(\"curiosity\");\nQ: what did we talk about last time?\nA: remember(\"notes\", \"talk last time\"); trigger-emotion(\"curiosity\");\nQ: Let's make some drawings!\nA: generate(\"paint\"); trigger-emotion(\"happy\");\nQ: Do you know anything about Lebanon?\nA: search(\"google\", \"lebanon\"); trigger-emotion(\"confidence\");\nQ: Find a video about a panda rolling in the grass\nA: search(\"youtube\",\"panda rolling in the grass\"); trigger-emotion(\"happy\"); \nQ: Tell me a scary story\nA: generate(\"write\" \"A story about some adventure\"); trigger-emotion(\"fear\");\nQ: What fiction book was I reading last week about AI starship?\nA: remember(\"notes\", \"read fiction book about AI starship last week\"); trigger-emotion(\"curiosity\");\nQ: How much did I spend at Subway for dinner last time?\nA: remember(\"ledger\", \"last Subway dinner\"); trigger-emotion(\"curiosity\");\nQ: I'm feeling sleepy\nA: activity(\"chat\"); trigger-emotion(\"calm\")\nQ: What was that popular Sri lankan song that Alex showed me recently?\nA: remember(\"music\", \"popular Sri lankan song that Alex showed recently\"); trigger-emotion(\"curiosity\"); \nQ: You're pretty funny!\nA: activity(\"chat\"); trigger-emotion(\"pride\")"
expected_response = "Extract information from each chat message\n\nremember(memory-type, data);\nmemory-type=[\"companion\", \"notes\", \"ledger\", \"image\", \"music\"]\nsearch(search-type, data);\nsearch-type=[\"google\", \"youtube\"]\ngenerate(activity);\nactivity=[\"paint\",\"write\", \"chat\"]\ntrigger-emotion(emotion);\nemotion=[\"happy\",\"confidence\",\"fear\",\"surprise\",\"sadness\",\"disgust\",\"anger\", \"curiosity\", \"calm\"]\n\nQ: How are you doing?\nA: activity(\"chat\"); trigger-emotion(\"surprise\")\nQ: Do you remember what I told you about my brother Antoine when we were at the beach?\nA: remember(\"notes\", \"Brother Antoine when we were at the beach\"); trigger-emotion(\"curiosity\");\nQ: what did we talk about last time?\nA: remember(\"notes\", \"talk last time\"); trigger-emotion(\"curiosity\");\nQ: Let's make some drawings!\nA: generate(\"paint\"); trigger-emotion(\"happy\");\nQ: Do you know anything about Lebanon?\nA: search(\"google\", \"lebanon\"); trigger-emotion(\"confidence\");\nQ: Find a video about a panda rolling in the grass\nA: search(\"youtube\",\"panda rolling in the grass\"); trigger-emotion(\"happy\"); \nQ: Tell me a scary story\nA: generate(\"write\" \"A story about some adventure\"); trigger-emotion(\"fear\");\nQ: What fiction book was I reading last week about AI starship?\nA: remember(\"notes\", \"read fiction book about AI starship last week\"); trigger-emotion(\"curiosity\");\nQ: How much did I spend at Subway for dinner last time?\nA: remember(\"ledger\", \"last Subway dinner\"); trigger-emotion(\"curiosity\");\nQ: I'm feeling sleepy\nA: activity(\"chat\"); trigger-emotion(\"calm\")\nQ: What was that popular Sri lankan song that Alex showed me recently?\nA: remember(\"music\", \"popular Sri lankan song that Alex showed recently\"); trigger-emotion(\"curiosity\"); \nQ: You're pretty funny!\nA: activity(\"chat\"); trigger-emotion(\"pride\")\nQ: When did I last dine at Burger King?\nA:"
understand_primer = 'Extract information from each chat message\n\nremember(memory-type, data);\nmemory-type=["companion", "notes", "ledger", "image", "music"]\nsearch(search-type, data);\nsearch-type=["google", "youtube"]\ngenerate(activity);\nactivity=["paint","write", "chat"]\ntrigger-emotion(emotion);\nemotion=["happy","confidence","fear","surprise","sadness","disgust","anger", "curiosity", "calm"]\n\nQ: How are you doing?\nA: activity("chat"); trigger-emotion("surprise")\nQ: Do you remember what I told you about my brother Antoine when we were at the beach?\nA: remember("notes", "Brother Antoine when we were at the beach"); trigger-emotion("curiosity");\nQ: what did we talk about last time?\nA: remember("notes", "talk last time"); trigger-emotion("curiosity");\nQ: Let\'s make some drawings!\nA: generate("paint"); trigger-emotion("happy");\nQ: Do you know anything about Lebanon?\nA: search("google", "lebanon"); trigger-emotion("confidence");\nQ: Find a video about a panda rolling in the grass\nA: search("youtube","panda rolling in the grass"); trigger-emotion("happy"); \nQ: Tell me a scary story\nA: generate("write" "A story about some adventure"); trigger-emotion("fear");\nQ: What fiction book was I reading last week about AI starship?\nA: remember("notes", "read fiction book about AI starship last week"); trigger-emotion("curiosity");\nQ: How much did I spend at Subway for dinner last time?\nA: remember("ledger", "last Subway dinner"); trigger-emotion("curiosity");\nQ: I\'m feeling sleepy\nA: activity("chat"); trigger-emotion("calm")\nQ: What was that popular Sri lankan song that Alex showed me recently?\nA: remember("music", "popular Sri lankan song that Alex showed recently"); trigger-emotion("curiosity"); \nQ: You\'re pretty funny!\nA: activity("chat"); trigger-emotion("pride")'
expected_response = 'Extract information from each chat message\n\nremember(memory-type, data);\nmemory-type=["companion", "notes", "ledger", "image", "music"]\nsearch(search-type, data);\nsearch-type=["google", "youtube"]\ngenerate(activity);\nactivity=["paint","write", "chat"]\ntrigger-emotion(emotion);\nemotion=["happy","confidence","fear","surprise","sadness","disgust","anger", "curiosity", "calm"]\n\nQ: How are you doing?\nA: activity("chat"); trigger-emotion("surprise")\nQ: Do you remember what I told you about my brother Antoine when we were at the beach?\nA: remember("notes", "Brother Antoine when we were at the beach"); trigger-emotion("curiosity");\nQ: what did we talk about last time?\nA: remember("notes", "talk last time"); trigger-emotion("curiosity");\nQ: Let\'s make some drawings!\nA: generate("paint"); trigger-emotion("happy");\nQ: Do you know anything about Lebanon?\nA: search("google", "lebanon"); trigger-emotion("confidence");\nQ: Find a video about a panda rolling in the grass\nA: search("youtube","panda rolling in the grass"); trigger-emotion("happy"); \nQ: Tell me a scary story\nA: generate("write" "A story about some adventure"); trigger-emotion("fear");\nQ: What fiction book was I reading last week about AI starship?\nA: remember("notes", "read fiction book about AI starship last week"); trigger-emotion("curiosity");\nQ: How much did I spend at Subway for dinner last time?\nA: remember("ledger", "last Subway dinner"); trigger-emotion("curiosity");\nQ: I\'m feeling sleepy\nA: activity("chat"); trigger-emotion("calm")\nQ: What was that popular Sri lankan song that Alex showed me recently?\nA: remember("music", "popular Sri lankan song that Alex showed recently"); trigger-emotion("curiosity"); \nQ: You\'re pretty funny!\nA: activity("chat"); trigger-emotion("pride")\nQ: When did I last dine at Burger King?\nA:'
# Act
actual_response = message_to_prompt("When did I last dine at Burger King?", understand_primer, start_sequence="\nA:", restart_sequence="\nQ:")
actual_response = message_to_prompt(
"When did I last dine at Burger King?", understand_primer, start_sequence="\nA:", restart_sequence="\nQ:"
)
# Assert
assert actual_response == expected_response
# ----------------------------------------------------------------------------------------------------
@pytest.mark.skipif(api_key is None,
reason="Set api_key variable to your OpenAI API key from https://beta.openai.com/account/api-keys")
@pytest.mark.skipif(
api_key is None, reason="Set api_key variable to your OpenAI API key from https://beta.openai.com/account/api-keys"
)
def test_minimal_chat_with_gpt():
# Act
response = converse("What will happen when the stars go out?", model=model, api_key=api_key)
@ -36,21 +39,29 @@ def test_minimal_chat_with_gpt():
# ----------------------------------------------------------------------------------------------------
@pytest.mark.skipif(api_key is None,
reason="Set api_key variable to your OpenAI API key from https://beta.openai.com/account/api-keys")
@pytest.mark.skipif(
api_key is None, reason="Set api_key variable to your OpenAI API key from https://beta.openai.com/account/api-keys"
)
def test_chat_with_history():
# Arrange
ai_prompt="AI:"
human_prompt="Human:"
ai_prompt = "AI:"
human_prompt = "Human:"
conversation_primer = f'''
conversation_primer = f"""
The following is a conversation with an AI assistant. The assistant is helpful, creative, clever, and very friendly companion.
{human_prompt} Hello, I am Testatron. Who are you?
{ai_prompt} Hi, I am Khoj, an AI conversational companion created by OpenAI. How can I help you today?'''
{ai_prompt} Hi, I am Khoj, an AI conversational companion created by OpenAI. How can I help you today?"""
# Act
response = converse("Hi Khoj, What is my name?", model=model, conversation_history=conversation_primer, api_key=api_key, temperature=0, max_tokens=50)
response = converse(
"Hi Khoj, What is my name?",
model=model,
conversation_history=conversation_primer,
api_key=api_key,
temperature=0,
max_tokens=50,
)
# Assert
assert len(response) > 0
@ -58,12 +69,13 @@ The following is a conversation with an AI assistant. The assistant is helpful,
# ----------------------------------------------------------------------------------------------------
@pytest.mark.skipif(api_key is None,
reason="Set api_key variable to your OpenAI API key from https://beta.openai.com/account/api-keys")
@pytest.mark.skipif(
api_key is None, reason="Set api_key variable to your OpenAI API key from https://beta.openai.com/account/api-keys"
)
def test_understand_message_using_gpt():
# Act
response = understand("When did I last dine at Subway?", model=model, api_key=api_key)
# Assert
assert len(response) > 0
assert response['intent']['memory-type'] == 'ledger'
assert response["intent"]["memory-type"] == "ledger"


@ -14,35 +14,37 @@ def test_cli_minimal_default():
actual_args = cli([])
# Assert
assert actual_args.config_file == resolve_absolute_path(Path('~/.khoj/khoj.yml'))
assert actual_args.config_file == resolve_absolute_path(Path("~/.khoj/khoj.yml"))
assert actual_args.regenerate == False
assert actual_args.no_gui == False
assert actual_args.verbose == 0
# ----------------------------------------------------------------------------------------------------
def test_cli_invalid_config_file_path():
# Arrange
non_existent_config_file = f"non-existent-khoj-{random()}.yml"
# Act
actual_args = cli([f'-c={non_existent_config_file}'])
actual_args = cli([f"-c={non_existent_config_file}"])
# Assert
assert actual_args.config_file == resolve_absolute_path(non_existent_config_file)
assert actual_args.config == None
# ----------------------------------------------------------------------------------------------------
def test_cli_config_from_file():
# Act
actual_args = cli(['-c=tests/data/config.yml',
'--regenerate',
'--no-gui',
'-vvv'])
actual_args = cli(["-c=tests/data/config.yml", "--regenerate", "--no-gui", "-vvv"])
# Assert
assert actual_args.config_file == resolve_absolute_path(Path('tests/data/config.yml'))
assert actual_args.config_file == resolve_absolute_path(Path("tests/data/config.yml"))
assert actual_args.no_gui == True
assert actual_args.regenerate == True
assert actual_args.config is not None
assert actual_args.config.content_type.org.input_files == [Path('~/first_from_config.org'), Path('~/second_from_config.org')]
assert actual_args.config.content_type.org.input_files == [
Path("~/first_from_config.org"),
Path("~/second_from_config.org"),
]
assert actual_args.verbose == 3


@ -21,6 +21,7 @@ from khoj.search_filter.file_filter import FileFilter
# ----------------------------------------------------------------------------------------------------
client = TestClient(app)
# Test
# ----------------------------------------------------------------------------------------------------
def test_search_with_invalid_content_type():
@ -98,9 +99,11 @@ def test_image_search(content_config: ContentConfig, search_config: SearchConfig
config.content_type = content_config
config.search_type = search_config
model.image_search = image_search.setup(content_config.image, search_config.image, regenerate=False)
query_expected_image_pairs = [("kitten", "kitten_park.jpg"),
query_expected_image_pairs = [
("kitten", "kitten_park.jpg"),
("a horse and dog on a leash", "horse_dog.jpg"),
("A guinea pig eating grass", "guineapig_grass.jpg")]
("A guinea pig eating grass", "guineapig_grass.jpg"),
]
for query, expected_image_name in query_expected_image_pairs:
# Act
@ -135,7 +138,9 @@ def test_notes_search(content_config: ContentConfig, search_config: SearchConfig
def test_notes_search_with_only_filters(content_config: ContentConfig, search_config: SearchConfig):
# Arrange
filters = [WordFilter(), FileFilter()]
model.orgmode_search = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters)
model.orgmode_search = text_search.setup(
OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters
)
user_query = quote('+"Emacs" file:"*.org"')
# Act
@ -152,7 +157,9 @@ def test_notes_search_with_only_filters(content_config: ContentConfig, search_co
def test_notes_search_with_include_filter(content_config: ContentConfig, search_config: SearchConfig):
# Arrange
filters = [WordFilter()]
model.orgmode_search = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters)
model.orgmode_search = text_search.setup(
OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters
)
user_query = quote('How to git install application? +"Emacs"')
# Act
@ -169,7 +176,9 @@ def test_notes_search_with_include_filter(content_config: ContentConfig, search_
def test_notes_search_with_exclude_filter(content_config: ContentConfig, search_config: SearchConfig):
# Arrange
filters = [WordFilter()]
model.orgmode_search = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters)
model.orgmode_search = text_search.setup(
OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters
)
user_query = quote('How to git install application? -"clone"')
# Act


@ -10,53 +10,59 @@ from khoj.utils.rawconfig import Entry
def test_date_filter():
entries = [
Entry(compiled='', raw='Entry with no date'),
Entry(compiled='', raw='April Fools entry: 1984-04-01'),
Entry(compiled='', raw='Entry with date:1984-04-02')
Entry(compiled="", raw="Entry with no date"),
Entry(compiled="", raw="April Fools entry: 1984-04-01"),
Entry(compiled="", raw="Entry with date:1984-04-02"),
]
q_with_no_date_filter = 'head tail'
q_with_no_date_filter = "head tail"
ret_query, entry_indices = DateFilter().apply(q_with_no_date_filter, entries)
assert ret_query == 'head tail'
assert ret_query == "head tail"
assert entry_indices == {0, 1, 2}
q_with_dtrange_non_overlapping_at_boundary = 'head dt>"1984-04-01" dt<"1984-04-02" tail'
ret_query, entry_indices = DateFilter().apply(q_with_dtrange_non_overlapping_at_boundary, entries)
assert ret_query == 'head tail'
assert ret_query == "head tail"
assert entry_indices == set()
query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<"1984-04-03" tail'
ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries)
assert ret_query == 'head tail'
assert ret_query == "head tail"
assert entry_indices == {2}
query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<"1984-04-02" tail'
ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries)
assert ret_query == 'head tail'
assert ret_query == "head tail"
assert entry_indices == {1}
query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<="1984-04-02" tail'
ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries)
assert ret_query == 'head tail'
assert ret_query == "head tail"
assert entry_indices == {2}
query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<="1984-04-02" tail'
ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries)
assert ret_query == 'head tail'
assert ret_query == "head tail"
assert entry_indices == {1, 2}
def test_extract_date_range():
assert DateFilter().extract_date_range('head dt>"1984-01-04" dt<"1984-01-07" tail') == [datetime(1984, 1, 5, 0, 0, 0).timestamp(), datetime(1984, 1, 7, 0, 0, 0).timestamp()]
assert DateFilter().extract_date_range('head dt>"1984-01-04" dt<"1984-01-07" tail') == [
datetime(1984, 1, 5, 0, 0, 0).timestamp(),
datetime(1984, 1, 7, 0, 0, 0).timestamp(),
]
assert DateFilter().extract_date_range('head dt<="1984-01-01"') == [0, datetime(1984, 1, 2, 0, 0, 0).timestamp()]
assert DateFilter().extract_date_range('head dt>="1984-01-01"') == [datetime(1984, 1, 1, 0, 0, 0).timestamp(), inf]
assert DateFilter().extract_date_range('head dt:"1984-01-01"') == [datetime(1984, 1, 1, 0, 0, 0).timestamp(), datetime(1984, 1, 2, 0, 0, 0).timestamp()]
assert DateFilter().extract_date_range('head dt:"1984-01-01"') == [
datetime(1984, 1, 1, 0, 0, 0).timestamp(),
datetime(1984, 1, 2, 0, 0, 0).timestamp(),
]
# Unparseable date filter specified in query
assert DateFilter().extract_date_range('head dt:"Summer of 69" tail') == None
# No date filter specified in query
assert DateFilter().extract_date_range('head tail') == None
assert DateFilter().extract_date_range("head tail") == None
# Non intersecting date ranges
assert DateFilter().extract_date_range('head dt>"1984-01-01" dt<"1984-01-01" tail') == None
@ -66,43 +72,79 @@ def test_parse():
test_now = datetime(1984, 4, 1, 21, 21, 21)
# day variations
assert DateFilter().parse('today', relative_base=test_now) == (datetime(1984, 4, 1, 0, 0, 0), datetime(1984, 4, 2, 0, 0, 0))
assert DateFilter().parse('tomorrow', relative_base=test_now) == (datetime(1984, 4, 2, 0, 0, 0), datetime(1984, 4, 3, 0, 0, 0))
assert DateFilter().parse('yesterday', relative_base=test_now) == (datetime(1984, 3, 31, 0, 0, 0), datetime(1984, 4, 1, 0, 0, 0))
assert DateFilter().parse('5 days ago', relative_base=test_now) == (datetime(1984, 3, 27, 0, 0, 0), datetime(1984, 3, 28, 0, 0, 0))
assert DateFilter().parse("today", relative_base=test_now) == (
datetime(1984, 4, 1, 0, 0, 0),
datetime(1984, 4, 2, 0, 0, 0),
)
assert DateFilter().parse("tomorrow", relative_base=test_now) == (
datetime(1984, 4, 2, 0, 0, 0),
datetime(1984, 4, 3, 0, 0, 0),
)
assert DateFilter().parse("yesterday", relative_base=test_now) == (
datetime(1984, 3, 31, 0, 0, 0),
datetime(1984, 4, 1, 0, 0, 0),
)
assert DateFilter().parse("5 days ago", relative_base=test_now) == (
datetime(1984, 3, 27, 0, 0, 0),
datetime(1984, 3, 28, 0, 0, 0),
)
# week variations
assert DateFilter().parse('last week', relative_base=test_now) == (datetime(1984, 3, 18, 0, 0, 0), datetime(1984, 3, 25, 0, 0, 0))
assert DateFilter().parse('2 weeks ago', relative_base=test_now) == (datetime(1984, 3, 11, 0, 0, 0), datetime(1984, 3, 18, 0, 0, 0))
assert DateFilter().parse("last week", relative_base=test_now) == (
datetime(1984, 3, 18, 0, 0, 0),
datetime(1984, 3, 25, 0, 0, 0),
)
assert DateFilter().parse("2 weeks ago", relative_base=test_now) == (
datetime(1984, 3, 11, 0, 0, 0),
datetime(1984, 3, 18, 0, 0, 0),
)
# month variations
assert DateFilter().parse('next month', relative_base=test_now) == (datetime(1984, 5, 1, 0, 0, 0), datetime(1984, 6, 1, 0, 0, 0))
assert DateFilter().parse('2 months ago', relative_base=test_now) == (datetime(1984, 2, 1, 0, 0, 0), datetime(1984, 3, 1, 0, 0, 0))
assert DateFilter().parse("next month", relative_base=test_now) == (
datetime(1984, 5, 1, 0, 0, 0),
datetime(1984, 6, 1, 0, 0, 0),
)
assert DateFilter().parse("2 months ago", relative_base=test_now) == (
datetime(1984, 2, 1, 0, 0, 0),
datetime(1984, 3, 1, 0, 0, 0),
)
# year variations
assert DateFilter().parse('this year', relative_base=test_now) == (datetime(1984, 1, 1, 0, 0, 0), datetime(1985, 1, 1, 0, 0, 0))
assert DateFilter().parse('20 years later', relative_base=test_now) == (datetime(2004, 1, 1, 0, 0, 0), datetime(2005, 1, 1, 0, 0, 0))
assert DateFilter().parse("this year", relative_base=test_now) == (
datetime(1984, 1, 1, 0, 0, 0),
datetime(1985, 1, 1, 0, 0, 0),
)
assert DateFilter().parse("20 years later", relative_base=test_now) == (
datetime(2004, 1, 1, 0, 0, 0),
datetime(2005, 1, 1, 0, 0, 0),
)
# specific month/date variation
assert DateFilter().parse('in august', relative_base=test_now) == (datetime(1983, 8, 1, 0, 0, 0), datetime(1983, 8, 2, 0, 0, 0))
assert DateFilter().parse('on 1983-08-01', relative_base=test_now) == (datetime(1983, 8, 1, 0, 0, 0), datetime(1983, 8, 2, 0, 0, 0))
assert DateFilter().parse("in august", relative_base=test_now) == (
datetime(1983, 8, 1, 0, 0, 0),
datetime(1983, 8, 2, 0, 0, 0),
)
assert DateFilter().parse("on 1983-08-01", relative_base=test_now) == (
datetime(1983, 8, 1, 0, 0, 0),
datetime(1983, 8, 2, 0, 0, 0),
)
def test_date_filter_regex():
dtrange_match = re.findall(DateFilter().date_regex, 'multi word head dt>"today" dt:"1984-01-01"')
assert dtrange_match == [('>', 'today'), (':', '1984-01-01')]
assert dtrange_match == [(">", "today"), (":", "1984-01-01")]
dtrange_match = re.findall(DateFilter().date_regex, 'head dt>"today" dt:"1984-01-01" multi word tail')
assert dtrange_match == [('>', 'today'), (':', '1984-01-01')]
assert dtrange_match == [(">", "today"), (":", "1984-01-01")]
dtrange_match = re.findall(DateFilter().date_regex, 'multi word head dt>="today" dt="1984-01-01"')
assert dtrange_match == [('>=', 'today'), ('=', '1984-01-01')]
assert dtrange_match == [(">=", "today"), ("=", "1984-01-01")]
dtrange_match = re.findall(DateFilter().date_regex, 'dt<"multi word date" multi word tail')
assert dtrange_match == [('<', 'multi word date')]
assert dtrange_match == [("<", "multi word date")]
dtrange_match = re.findall(DateFilter().date_regex, 'head dt<="multi word date"')
assert dtrange_match == [('<=', 'multi word date')]
assert dtrange_match == [("<=", "multi word date")]
dtrange_match = re.findall(DateFilter().date_regex, 'head tail')
dtrange_match = re.findall(DateFilter().date_regex, "head tail")
assert dtrange_match == []


@ -7,7 +7,7 @@ def test_no_file_filter():
# Arrange
file_filter = FileFilter()
entries = arrange_content()
q_with_no_filter = 'head tail'
q_with_no_filter = "head tail"
# Act
can_filter = file_filter.can_filter(q_with_no_filter)
@ -15,7 +15,7 @@ def test_no_file_filter():
# Assert
assert can_filter == False
assert ret_query == 'head tail'
assert ret_query == "head tail"
assert entry_indices == {0, 1, 2, 3}
@ -31,7 +31,7 @@ def test_file_filter_with_non_existent_file():
# Assert
assert can_filter == True
assert ret_query == 'head tail'
assert ret_query == "head tail"
assert entry_indices == {}
@ -47,7 +47,7 @@ def test_single_file_filter():
# Assert
assert can_filter == True
assert ret_query == 'head tail'
assert ret_query == "head tail"
assert entry_indices == {0, 2}
@ -63,7 +63,7 @@ def test_file_filter_with_partial_match():
# Assert
assert can_filter == True
assert ret_query == 'head tail'
assert ret_query == "head tail"
assert entry_indices == {0, 2}
@ -79,7 +79,7 @@ def test_file_filter_with_regex_match():
# Assert
assert can_filter == True
assert ret_query == 'head tail'
assert ret_query == "head tail"
assert entry_indices == {0, 1, 2, 3}
@ -95,16 +95,16 @@ def test_multiple_file_filter():
# Assert
assert can_filter == True
assert ret_query == 'head tail'
assert ret_query == "head tail"
assert entry_indices == {0, 1, 2, 3}
def arrange_content():
entries = [
Entry(compiled='', raw='First Entry', file= 'file 1.org'),
Entry(compiled='', raw='Second Entry', file= 'file2.org'),
Entry(compiled='', raw='Third Entry', file= 'file 1.org'),
Entry(compiled='', raw='Fourth Entry', file= 'file2.org')
Entry(compiled="", raw="First Entry", file="file 1.org"),
Entry(compiled="", raw="Second Entry", file="file2.org"),
Entry(compiled="", raw="Third Entry", file="file 1.org"),
Entry(compiled="", raw="Fourth Entry", file="file2.org"),
]
return entries


@ -1,5 +1,6 @@
from khoj.utils import helpers
def test_get_from_null_dict():
# null handling
assert helpers.get_from_dict(dict()) == dict()
@ -7,39 +8,39 @@ def test_get_from_null_dict():
# key present in nested dictionary
# 1-level dictionary
assert helpers.get_from_dict({'a': 1, 'b': 2}, 'a') == 1
assert helpers.get_from_dict({'a': 1, 'b': 2}, 'c') == None
assert helpers.get_from_dict({"a": 1, "b": 2}, "a") == 1
assert helpers.get_from_dict({"a": 1, "b": 2}, "c") == None
# 2-level dictionary
assert helpers.get_from_dict({'a': {'a_a': 1}, 'b': 2}, 'a') == {'a_a': 1}
assert helpers.get_from_dict({'a': {'a_a': 1}, 'b': 2}, 'a', 'a_a') == 1
assert helpers.get_from_dict({"a": {"a_a": 1}, "b": 2}, "a") == {"a_a": 1}
assert helpers.get_from_dict({"a": {"a_a": 1}, "b": 2}, "a", "a_a") == 1
# key not present in nested dictionary
# 2-level_dictionary
assert helpers.get_from_dict({'a': {'a_a': 1}, 'b': 2}, 'b', 'b_a') == None
assert helpers.get_from_dict({"a": {"a_a": 1}, "b": 2}, "b", "b_a") == None
def test_merge_dicts():
# basic merge of dicts with non-overlapping keys
assert helpers.merge_dicts(priority_dict={'a': 1}, default_dict={'b': 2}) == {'a': 1, 'b': 2}
assert helpers.merge_dicts(priority_dict={"a": 1}, default_dict={"b": 2}) == {"a": 1, "b": 2}
# use default dict items when not present in priority dict
assert helpers.merge_dicts(priority_dict={}, default_dict={'b': 2}) == {'b': 2}
assert helpers.merge_dicts(priority_dict={}, default_dict={"b": 2}) == {"b": 2}
# do not override existing key in priority_dict with default dict
assert helpers.merge_dicts(priority_dict={'a': 1}, default_dict={'a': 2}) == {'a': 1}
assert helpers.merge_dicts(priority_dict={"a": 1}, default_dict={"a": 2}) == {"a": 1}
def test_lru_cache():
# Test initializing cache
cache = helpers.LRU({'a': 1, 'b': 2}, capacity=2)
assert cache == {'a': 1, 'b': 2}
cache = helpers.LRU({"a": 1, "b": 2}, capacity=2)
assert cache == {"a": 1, "b": 2}
# Test capacity overflow
cache['c'] = 3
assert cache == {'b': 2, 'c': 3}
cache["c"] = 3
assert cache == {"b": 2, "c": 3}
# Test delete least recently used item from LRU cache on capacity overflow
cache['b'] # accessing 'b' makes it the most recently used item
cache['d'] = 4 # so 'c' is deleted from the cache instead of 'b'
assert cache == {'b': 2, 'd': 4}
cache["b"] # accessing 'b' makes it the most recently used item
cache["d"] = 4 # so 'c' is deleted from the cache instead of 'b'
assert cache == {"b": 2, "d": 4}


@ -30,7 +30,8 @@ def test_image_metadata(content_config: ContentConfig):
expected_metadata_image_name_pairs = [
(["Billi Ka Bacha.", "Cat", "Grass"], "kitten_park.jpg"),
(["Pasture.", "Horse", "Dog"], "horse_dog.jpg"),
(["Guinea Pig Eating Celery.", "Rodent", "Whiskers"], "guineapig_grass.jpg")]
(["Guinea Pig Eating Celery.", "Rodent", "Whiskers"], "guineapig_grass.jpg"),
]
test_image_paths = [
Path(content_config.image.input_directories[0] / image_name[1])
@ -51,23 +52,23 @@ def test_image_search(content_config: ContentConfig, search_config: SearchConfig
# Arrange
output_directory = resolve_absolute_path(web_directory)
model.image_search = image_search.setup(content_config.image, search_config.image, regenerate=False)
query_expected_image_pairs = [("kitten", "kitten_park.jpg"),
query_expected_image_pairs = [
("kitten", "kitten_park.jpg"),
("horse and dog in a farm", "horse_dog.jpg"),
("A guinea pig eating grass", "guineapig_grass.jpg")]
("A guinea pig eating grass", "guineapig_grass.jpg"),
]
# Act
for query, expected_image_name in query_expected_image_pairs:
hits = image_search.query(
query,
count = 1,
model = model.image_search)
hits = image_search.query(query, count=1, model=model.image_search)
results = image_search.collate_results(
hits,
model.image_search.image_names,
output_directory=output_directory,
image_files_url='/static/images',
count=1)
image_files_url="/static/images",
count=1,
)
actual_image_path = output_directory.joinpath(Path(results[0].entry).name)
actual_image = Image.open(actual_image_path)
@ -86,16 +87,13 @@ def test_image_search_query_truncated(content_config: ContentConfig, search_conf
# Arrange
model.image_search = image_search.setup(content_config.image, search_config.image, regenerate=False)
max_words_supported = 10
query = " ".join(["hello"]*100)
truncated_query = " ".join(["hello"]*max_words_supported)
query = " ".join(["hello"] * 100)
truncated_query = " ".join(["hello"] * max_words_supported)
# Act
try:
with caplog.at_level(logging.INFO, logger="khoj.search_type.image_search"):
image_search.query(
query,
count = 1,
model = model.image_search)
image_search.query(query, count=1, model=model.image_search)
# Assert
except RuntimeError as e:
if "The size of tensor a (102) must match the size of tensor b (77)" in str(e):
@ -115,17 +113,15 @@ def test_image_search_by_filepath(content_config: ContentConfig, search_config:
# Act
with caplog.at_level(logging.INFO, logger="khoj.search_type.image_search"):
hits = image_search.query(
query,
count = 1,
model = model.image_search)
hits = image_search.query(query, count=1, model=model.image_search)
results = image_search.collate_results(
hits,
model.image_search.image_names,
output_directory=output_directory,
image_files_url='/static/images',
count=1)
image_files_url="/static/images",
count=1,
)
actual_image_path = output_directory.joinpath(Path(results[0].entry).name)
actual_image = Image.open(actual_image_path)
@ -133,7 +129,9 @@ def test_image_search_by_filepath(content_config: ContentConfig, search_config:
# Assert
# Ensure file search triggered instead of query with file path as string
assert f"Find Images by Image: {resolve_absolute_path(expected_image_path)}" in caplog.text, "File search not triggered"
assert (
f"Find Images by Image: {resolve_absolute_path(expected_image_path)}" in caplog.text
), "File search not triggered"
# Ensure the correct image is returned
assert expected_image == actual_image, "Incorrect image returned by file search"


@ -8,10 +8,10 @@ from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl
def test_markdown_file_with_no_headings_to_jsonl(tmp_path):
"Convert files with no heading to jsonl."
# Arrange
entry = f'''
entry = f"""
- Bullet point 1
- Bullet point 2
'''
"""
markdownfile = create_file(tmp_path, entry)
# Act
@ -20,7 +20,8 @@ def test_markdown_file_with_no_headings_to_jsonl(tmp_path):
# Process Each Entry from All Notes Files
jsonl_string = MarkdownToJsonl.convert_markdown_maps_to_jsonl(
MarkdownToJsonl.convert_markdown_entries_to_maps(entry_nodes, file_to_entries))
MarkdownToJsonl.convert_markdown_entries_to_maps(entry_nodes, file_to_entries)
)
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
# Assert
@ -30,10 +31,10 @@ def test_markdown_file_with_no_headings_to_jsonl(tmp_path):
def test_single_markdown_entry_to_jsonl(tmp_path):
"Convert markdown entry from single file to jsonl."
# Arrange
entry = f'''### Heading
entry = f"""### Heading
\t\r
Body Line 1
'''
"""
markdownfile = create_file(tmp_path, entry)
# Act
@ -42,7 +43,8 @@ def test_single_markdown_entry_to_jsonl(tmp_path):
# Process Each Entry from All Notes Files
jsonl_string = MarkdownToJsonl.convert_markdown_maps_to_jsonl(
MarkdownToJsonl.convert_markdown_entries_to_maps(entries, entry_to_file_map))
MarkdownToJsonl.convert_markdown_entries_to_maps(entries, entry_to_file_map)
)
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
# Assert
@ -52,14 +54,14 @@ def test_single_markdown_entry_to_jsonl(tmp_path):
def test_multiple_markdown_entries_to_jsonl(tmp_path):
"Convert multiple markdown entries from single file to jsonl."
# Arrange
entry = f'''
entry = f"""
### Heading 1
\t\r
Heading 1 Body Line 1
### Heading 2
\t\r
Heading 2 Body Line 2
'''
"""
markdownfile = create_file(tmp_path, entry)
# Act
@ -68,7 +70,8 @@ def test_multiple_markdown_entries_to_jsonl(tmp_path):
# Process Each Entry from All Notes Files
jsonl_string = MarkdownToJsonl.convert_markdown_maps_to_jsonl(
MarkdownToJsonl.convert_markdown_entries_to_maps(entries, entry_to_file_map))
MarkdownToJsonl.convert_markdown_entries_to_maps(entries, entry_to_file_map)
)
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
# Assert
@ -92,8 +95,8 @@ def test_get_markdown_files(tmp_path):
expected_files = sorted(map(str, [group1_file1, group1_file2, group2_file1, group2_file2, file1]))
# Setup input-files, input-filters
input_files = [tmp_path / 'notes.md']
input_filter = [tmp_path / 'group1*.md', tmp_path / 'group2*.markdown']
input_files = [tmp_path / "notes.md"]
input_filter = [tmp_path / "group1*.md", tmp_path / "group2*.markdown"]
# Act
extracted_org_files = MarkdownToJsonl.get_markdown_files(input_files, input_filter)
@ -106,10 +109,10 @@ def test_get_markdown_files(tmp_path):
def test_extract_entries_with_different_level_headings(tmp_path):
"Extract markdown entries with different level headings."
# Arrange
entry = f'''
entry = f"""
# Heading 1
## Heading 2
'''
"""
markdownfile = create_file(tmp_path, entry)
# Act


@ -9,23 +9,25 @@ from khoj.utils.rawconfig import Entry
def test_configure_heading_entry_to_jsonl(tmp_path):
'''Ensure entries with empty body are ignored, unless explicitly configured to index heading entries.
Property drawers not considered Body. Ignore control characters for evaluating if Body empty.'''
"""Ensure entries with empty body are ignored, unless explicitly configured to index heading entries.
Property drawers not considered Body. Ignore control characters for evaluating if Body empty."""
# Arrange
entry = f'''*** Heading
entry = f"""*** Heading
:PROPERTIES:
:ID: 42-42-42
:END:
\t \r
'''
"""
orgfile = create_file(tmp_path, entry)
for index_heading_entries in [True, False]:
# Act
# Extract entries into jsonl from specified Org files
jsonl_string = OrgToJsonl.convert_org_entries_to_jsonl(OrgToJsonl.convert_org_nodes_to_entries(
*OrgToJsonl.extract_org_entries(org_files=[orgfile]),
index_heading_entries=index_heading_entries))
jsonl_string = OrgToJsonl.convert_org_entries_to_jsonl(
OrgToJsonl.convert_org_nodes_to_entries(
*OrgToJsonl.extract_org_entries(org_files=[orgfile]), index_heading_entries=index_heading_entries
)
)
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
# Assert
@ -40,10 +42,10 @@ def test_configure_heading_entry_to_jsonl(tmp_path):
def test_entry_split_when_exceeds_max_words(tmp_path):
"Ensure entries with compiled words exceeding max_words are split."
# Arrange
entry = f'''*** Heading
entry = f"""*** Heading
\t\r
Body Line 1
'''
"""
orgfile = create_file(tmp_path, entry)
# Act
@ -53,8 +55,8 @@ def test_entry_split_when_exceeds_max_words(tmp_path):
# Split each entry from specified Org files by max words
jsonl_string = OrgToJsonl.convert_org_entries_to_jsonl(
TextToJsonl.split_entries_by_max_tokens(
OrgToJsonl.convert_org_nodes_to_entries(entries, entry_to_file_map),
max_tokens = 2)
OrgToJsonl.convert_org_nodes_to_entries(entries, entry_to_file_map), max_tokens=2
)
)
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
@ -65,15 +67,15 @@ def test_entry_split_when_exceeds_max_words(tmp_path):
def test_entry_split_drops_large_words(tmp_path):
"Ensure entries drops words larger than specified max word length from compiled version."
# Arrange
entry_text = f'''*** Heading
entry_text = f"""*** Heading
\t\r
Body Line 1
'''
"""
entry = Entry(raw=entry_text, compiled=entry_text)
# Act
# Split entry by max words and drop words larger than max word length
processed_entry = TextToJsonl.split_entries_by_max_tokens([entry], max_word_length = 5)[0]
processed_entry = TextToJsonl.split_entries_by_max_tokens([entry], max_word_length=5)[0]
# Assert
# "Heading" dropped from compiled version because its over the set max word limit
@ -83,13 +85,13 @@ def test_entry_split_drops_large_words(tmp_path):
def test_entry_with_body_to_jsonl(tmp_path):
"Ensure entries with valid body text are loaded."
# Arrange
entry = f'''*** Heading
entry = f"""*** Heading
:PROPERTIES:
:ID: 42-42-42
:END:
\t\r
Body Line 1
'''
"""
orgfile = create_file(tmp_path, entry)
# Act
@ -97,7 +99,9 @@ def test_entry_with_body_to_jsonl(tmp_path):
entries, entry_to_file_map = OrgToJsonl.extract_org_entries(org_files=[orgfile])
# Process Each Entry from All Notes Files
jsonl_string = OrgToJsonl.convert_org_entries_to_jsonl(OrgToJsonl.convert_org_nodes_to_entries(entries, entry_to_file_map))
jsonl_string = OrgToJsonl.convert_org_entries_to_jsonl(
OrgToJsonl.convert_org_nodes_to_entries(entries, entry_to_file_map)
)
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
# Assert
@ -107,10 +111,10 @@ def test_entry_with_body_to_jsonl(tmp_path):
def test_file_with_no_headings_to_jsonl(tmp_path):
"Ensure files with no heading, only body text are loaded."
# Arrange
entry = f'''
entry = f"""
- Bullet point 1
- Bullet point 2
'''
"""
orgfile = create_file(tmp_path, entry)
# Act
@ -143,8 +147,8 @@ def test_get_org_files(tmp_path):
expected_files = sorted(map(str, [group1_file1, group1_file2, group2_file1, group2_file2, orgfile1]))
# Setup input-files, input-filters
input_files = [tmp_path / 'orgfile1.org']
input_filter = [tmp_path / 'group1*.org', tmp_path / 'group2*.org']
input_files = [tmp_path / "orgfile1.org"]
input_filter = [tmp_path / "group1*.org", tmp_path / "group2*.org"]
# Act
extracted_org_files = OrgToJsonl.get_org_files(input_files, input_filter)
@ -157,10 +161,10 @@ def test_get_org_files(tmp_path):
def test_extract_entries_with_different_level_headings(tmp_path):
"Extract org entries with different level headings."
# Arrange
entry = f'''
entry = f"""
* Heading 1
** Heading 2
'''
"""
orgfile = create_file(tmp_path, entry)
# Act
@ -169,8 +173,8 @@ def test_extract_entries_with_different_level_headings(tmp_path):
# Assert
assert len(entries) == 2
assert f'{entries[0]}'.startswith("* Heading 1")
assert f'{entries[1]}'.startswith("** Heading 2")
assert f"{entries[0]}".startswith("* Heading 1")
assert f"{entries[1]}".startswith("** Heading 2")
# Helper Functions


@ -10,7 +10,7 @@ from khoj.processor.org_mode import orgnode
def test_parse_entry_with_no_headings(tmp_path):
"Test parsing of entry with minimal fields"
# Arrange
entry = f'''Body Line 1'''
entry = f"""Body Line 1"""
orgfile = create_file(tmp_path, entry)
# Act
@ -18,7 +18,7 @@ def test_parse_entry_with_no_headings(tmp_path):
# Assert
assert len(entries) == 1
assert entries[0].heading == f'{orgfile}'
assert entries[0].heading == f"{orgfile}"
assert entries[0].tags == list()
assert entries[0].body == "Body Line 1"
assert entries[0].priority == ""
@ -32,9 +32,9 @@ def test_parse_entry_with_no_headings(tmp_path):
def test_parse_minimal_entry(tmp_path):
"Test parsing of entry with minimal fields"
# Arrange
entry = f'''
entry = f"""
* Heading
Body Line 1'''
Body Line 1"""
orgfile = create_file(tmp_path, entry)
# Act
@ -56,7 +56,7 @@ Body Line 1'''
def test_parse_complete_entry(tmp_path):
"Test parsing of entry with all important fields"
# Arrange
entry = f'''
entry = f"""
*** DONE [#A] Heading :Tag1:TAG2:tag3:
CLOSED: [1984-04-01 Sun 12:00] SCHEDULED: <1984-04-01 Sun 09:00> DEADLINE: <1984-04-01 Sun>
:PROPERTIES:
@ -67,7 +67,7 @@ CLOCK: [1984-04-01 Sun 09:00]--[1984-04-01 Sun 12:00] => 3:00
- Clocked Log 1
:END:
Body Line 1
Body Line 2'''
Body Line 2"""
orgfile = create_file(tmp_path, entry)
# Act
@ -81,45 +81,45 @@ Body Line 2'''
assert entries[0].body == "- Clocked Log 1\nBody Line 1\nBody Line 2"
assert entries[0].priority == "A"
assert entries[0].Property("ID") == "id:123-456-789-4234-1231"
assert entries[0].closed == datetime.date(1984,4,1)
assert entries[0].scheduled == datetime.date(1984,4,1)
assert entries[0].deadline == datetime.date(1984,4,1)
assert entries[0].logbook == [(datetime.datetime(1984,4,1,9,0,0), datetime.datetime(1984,4,1,12,0,0))]
assert entries[0].closed == datetime.date(1984, 4, 1)
assert entries[0].scheduled == datetime.date(1984, 4, 1)
assert entries[0].deadline == datetime.date(1984, 4, 1)
assert entries[0].logbook == [(datetime.datetime(1984, 4, 1, 9, 0, 0), datetime.datetime(1984, 4, 1, 12, 0, 0))]
# ----------------------------------------------------------------------------------------------------
def test_render_entry_with_property_drawer_and_empty_body(tmp_path):
"Render heading entry with property drawer"
# Arrange
entry_to_render = f'''
entry_to_render = f"""
*** [#A] Heading1 :tag1:
:PROPERTIES:
:ID: 111-111-111-1111-1111
:END:
\t\r \n
'''
"""
orgfile = create_file(tmp_path, entry_to_render)
expected_entry = f'''*** [#A] Heading1 :tag1:
expected_entry = f"""*** [#A] Heading1 :tag1:
:PROPERTIES:
:LINE: file:{orgfile}::2
:ID: id:111-111-111-1111-1111
:SOURCE: [[file:{orgfile}::*Heading1]]
:END:
'''
"""
# Act
parsed_entries = orgnode.makelist(orgfile)
# Assert
assert f'{parsed_entries[0]}' == expected_entry
assert f"{parsed_entries[0]}" == expected_entry
# ----------------------------------------------------------------------------------------------------
def test_all_links_to_entry_rendered(tmp_path):
"Ensure all links to entry rendered in property drawer from entry"
# Arrange
entry = f'''
entry = f"""
*** [#A] Heading :tag1:
:PROPERTIES:
:ID: 123-456-789-4234-1231
@ -127,7 +127,7 @@ def test_all_links_to_entry_rendered(tmp_path):
Body Line 1
*** Heading2
Body Line 2
'''
"""
orgfile = create_file(tmp_path, entry)
# Act
@ -135,23 +135,23 @@ Body Line 2
# Assert
# SOURCE link rendered with Heading
assert f':SOURCE: [[file:{orgfile}::*{entries[0].heading}]]' in f'{entries[0]}'
assert f":SOURCE: [[file:{orgfile}::*{entries[0].heading}]]" in f"{entries[0]}"
# ID link rendered with ID
assert f':ID: id:123-456-789-4234-1231' in f'{entries[0]}'
assert f":ID: id:123-456-789-4234-1231" in f"{entries[0]}"
# LINE link rendered with line number
assert f':LINE: file:{orgfile}::2' in f'{entries[0]}'
assert f":LINE: file:{orgfile}::2" in f"{entries[0]}"
# ----------------------------------------------------------------------------------------------------
def test_source_link_to_entry_escaped_for_rendering(tmp_path):
"Test SOURCE link renders with square brackets in filename, heading escaped for org-mode rendering"
# Arrange
entry = f'''
entry = f"""
*** [#A] Heading[1] :tag1:
:PROPERTIES:
:ID: 123-456-789-4234-1231
:END:
Body Line 1'''
Body Line 1"""
orgfile = create_file(tmp_path, entry, filename="test[1].org")
# Act
@ -162,15 +162,15 @@ Body Line 1'''
# parsed heading from entry
assert entries[0].heading == "Heading[1]"
# ensure SOURCE link has square brackets in filename, heading escaped in rendered entries
escaped_orgfile = f'{orgfile}'.replace("[1]", "\\[1\\]")
assert f':SOURCE: [[file:{escaped_orgfile}::*Heading\[1\]' in f'{entries[0]}'
escaped_orgfile = f"{orgfile}".replace("[1]", "\\[1\\]")
assert f":SOURCE: [[file:{escaped_orgfile}::*Heading\[1\]" in f"{entries[0]}"
# ----------------------------------------------------------------------------------------------------
def test_parse_multiple_entries(tmp_path):
"Test parsing of multiple entries"
# Arrange
content = f'''
content = f"""
*** FAILED [#A] Heading1 :tag1:
CLOSED: [1984-04-01 Sun 12:00] SCHEDULED: <1984-04-01 Sun 09:00> DEADLINE: <1984-04-01 Sun>
:PROPERTIES:
@ -193,7 +193,7 @@ CLOCK: [1984-04-02 Mon 09:00]--[1984-04-02 Mon 12:00] => 3:00
:END:
Body 2
'''
"""
orgfile = create_file(tmp_path, content)
# Act
@ -208,18 +208,20 @@ Body 2
assert entry.body == f"- Clocked Log {index+1}\nBody {index+1}\n\n"
assert entry.priority == "A"
assert entry.Property("ID") == f"id:123-456-789-4234-000{index+1}"
assert entry.closed == datetime.date(1984,4,index+1)
assert entry.scheduled == datetime.date(1984,4,index+1)
assert entry.deadline == datetime.date(1984,4,index+1)
assert entry.logbook == [(datetime.datetime(1984,4,index+1,9,0,0), datetime.datetime(1984,4,index+1,12,0,0))]
assert entry.closed == datetime.date(1984, 4, index + 1)
assert entry.scheduled == datetime.date(1984, 4, index + 1)
assert entry.deadline == datetime.date(1984, 4, index + 1)
assert entry.logbook == [
(datetime.datetime(1984, 4, index + 1, 9, 0, 0), datetime.datetime(1984, 4, index + 1, 12, 0, 0))
]
# ----------------------------------------------------------------------------------------------------
def test_parse_entry_with_empty_title(tmp_path):
"Test parsing of entry with minimal fields"
# Arrange
entry = f'''#+TITLE:
Body Line 1'''
entry = f"""#+TITLE:
Body Line 1"""
orgfile = create_file(tmp_path, entry)
# Act
@ -227,7 +229,7 @@ Body Line 1'''
# Assert
assert len(entries) == 1
assert entries[0].heading == f'{orgfile}'
assert entries[0].heading == f"{orgfile}"
assert entries[0].tags == list()
assert entries[0].body == "Body Line 1"
assert entries[0].priority == ""
@ -241,8 +243,8 @@ Body Line 1'''
def test_parse_entry_with_title_and_no_headings(tmp_path):
"Test parsing of entry with minimal fields"
# Arrange
entry = f'''#+TITLE: test
Body Line 1'''
entry = f"""#+TITLE: test
Body Line 1"""
orgfile = create_file(tmp_path, entry)
# Act
@ -250,7 +252,7 @@ Body Line 1'''
# Assert
assert len(entries) == 1
assert entries[0].heading == 'test'
assert entries[0].heading == "test"
assert entries[0].tags == list()
assert entries[0].body == "Body Line 1"
assert entries[0].priority == ""
@ -264,9 +266,9 @@ Body Line 1'''
def test_parse_entry_with_multiple_titles_and_no_headings(tmp_path):
"Test parsing of entry with minimal fields"
# Arrange
entry = f'''#+TITLE: title1
entry = f"""#+TITLE: title1
Body Line 1
#+TITLE: title2 '''
#+TITLE: title2 """
orgfile = create_file(tmp_path, entry)
# Act
@ -274,7 +276,7 @@ Body Line 1
# Assert
assert len(entries) == 1
assert entries[0].heading == 'title1 title2'
assert entries[0].heading == "title1 title2"
assert entries[0].tags == list()
assert entries[0].body == "Body Line 1\n"
assert entries[0].priority == ""


@ -14,7 +14,9 @@ from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
# Test
# ----------------------------------------------------------------------------------------------------
def test_asymmetric_setup_with_missing_file_raises_error(org_config_with_only_new_file: TextContentConfig, search_config: SearchConfig):
def test_asymmetric_setup_with_missing_file_raises_error(
org_config_with_only_new_file: TextContentConfig, search_config: SearchConfig
):
# Arrange
# Ensure file mentioned in org.input-files is missing
single_new_file = Path(org_config_with_only_new_file.input_files[0])
@ -27,10 +29,12 @@ def test_asymmetric_setup_with_missing_file_raises_error(org_config_with_only_ne
# ----------------------------------------------------------------------------------------------------
def test_asymmetric_setup_with_empty_file_raises_error(org_config_with_only_new_file: TextContentConfig, search_config: SearchConfig):
def test_asymmetric_setup_with_empty_file_raises_error(
org_config_with_only_new_file: TextContentConfig, search_config: SearchConfig
):
# Act
# Generate notes embeddings during asymmetric setup
with pytest.raises(ValueError, match=r'^No valid entries found*'):
with pytest.raises(ValueError, match=r"^No valid entries found*"):
text_search.setup(OrgToJsonl, org_config_with_only_new_file, search_config.asymmetric, regenerate=True)
@ -52,15 +56,9 @@ def test_asymmetric_search(content_config: ContentConfig, search_config: SearchC
query = "How to git install application?"
# Act
hits, entries = text_search.query(
query,
model = model.notes_search,
rank_results=True)
hits, entries = text_search.query(query, model=model.notes_search, rank_results=True)
results = text_search.collate_results(
hits,
entries,
count=1)
results = text_search.collate_results(hits, entries, count=1)
# Assert
# Actual_data should contain "Khoj via Emacs" entry
@ -76,12 +74,14 @@ def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: TextContent
new_file_to_index = Path(org_config_with_only_new_file.input_files[0])
with open(new_file_to_index, "w") as f:
f.write(f"* Entry more than {max_tokens} words\n")
for index in range(max_tokens+1):
for index in range(max_tokens + 1):
f.write(f"{index} ")
# Act
# reload embeddings, entries, notes model after adding new org-mode file
initial_notes_model = text_search.setup(OrgToJsonl, org_config_with_only_new_file, search_config.asymmetric, regenerate=False)
initial_notes_model = text_search.setup(
OrgToJsonl, org_config_with_only_new_file, search_config.asymmetric, regenerate=False
)
# Assert
# verify newly added org-mode entry is split by max tokens
@ -92,18 +92,20 @@ def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: TextContent
# ----------------------------------------------------------------------------------------------------
def test_asymmetric_reload(content_config: ContentConfig, search_config: SearchConfig, new_org_file: Path):
# Arrange
initial_notes_model= text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True)
initial_notes_model = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True)
assert len(initial_notes_model.entries) == 10
assert len(initial_notes_model.corpus_embeddings) == 10
# append org-mode entry to first org input file in config
content_config.org.input_files = [f'{new_org_file}']
content_config.org.input_files = [f"{new_org_file}"]
with open(new_org_file, "w") as f:
f.write("\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n")
# regenerate notes jsonl, model embeddings and model to include entry from new file
regenerated_notes_model = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True)
regenerated_notes_model = text_search.setup(
OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True
)
# Act
# reload embeddings, entries, notes model from previously generated notes jsonl and model embeddings files
@ -137,7 +139,7 @@ def test_incremental_update(content_config: ContentConfig, search_config: Search
# Act
# update embeddings, entries with the newly added note
content_config.org.input_files = [f'{new_org_file}']
content_config.org.input_files = [f"{new_org_file}"]
initial_notes_model = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False)
# Assert


@ -7,7 +7,7 @@ def test_no_word_filter():
# Arrange
word_filter = WordFilter()
entries = arrange_content()
q_with_no_filter = 'head tail'
q_with_no_filter = "head tail"
# Act
can_filter = word_filter.can_filter(q_with_no_filter)
@ -15,7 +15,7 @@ def test_no_word_filter():
# Assert
assert can_filter == False
assert ret_query == 'head tail'
assert ret_query == "head tail"
assert entry_indices == {0, 1, 2, 3}
@ -31,7 +31,7 @@ def test_word_exclude_filter():
# Assert
assert can_filter == True
assert ret_query == 'head tail'
assert ret_query == "head tail"
assert entry_indices == {0, 2}
@ -47,7 +47,7 @@ def test_word_include_filter():
# Assert
assert can_filter == True
assert ret_query == 'head tail'
assert ret_query == "head tail"
assert entry_indices == {2, 3}
@ -63,16 +63,16 @@ def test_word_include_and_exclude_filter():
# Assert
assert can_filter == True
assert ret_query == 'head tail'
assert ret_query == "head tail"
assert entry_indices == {2}
def arrange_content():
entries = [
Entry(compiled='', raw='Minimal Entry'),
Entry(compiled='', raw='Entry with exclude_word'),
Entry(compiled='', raw='Entry with include_word'),
Entry(compiled='', raw='Entry with include_word and exclude_word')
Entry(compiled="", raw="Minimal Entry"),
Entry(compiled="", raw="Entry with exclude_word"),
Entry(compiled="", raw="Entry with include_word"),
Entry(compiled="", raw="Entry with include_word and exclude_word"),
]
return entries