diff --git a/src/interface/desktop/main.js b/src/interface/desktop/main.js index 6ee8ff75..12d486e9 100644 --- a/src/interface/desktop/main.js +++ b/src/interface/desktop/main.js @@ -19,7 +19,7 @@ const textFileTypes = [ 'org', 'md', 'markdown', 'txt', 'html', 'xml', // Other valid text file extensions from https://google.github.io/magika/model/config.json 'appleplist', 'asm', 'asp', 'batch', 'c', 'cs', 'css', 'csv', 'eml', 'go', 'html', 'ini', 'internetshortcut', 'java', 'javascript', 'json', 'latex', 'lisp', 'makefile', 'markdown', 'mht', 'mum', 'pem', 'perl', 'php', 'powershell', 'python', 'rdf', 'rst', 'rtf', 'ruby', 'rust', 'scala', 'shell', 'smali', 'sql', 'svg', 'symlinktext', 'txt', 'vba', 'winregistry', 'xml', 'yaml'] -const binaryFileTypes = ['pdf'] +const binaryFileTypes = ['pdf', 'jpg', 'jpeg', 'png'] const validFileTypes = textFileTypes.concat(binaryFileTypes); const schema = { diff --git a/src/khoj/interface/web/chat.html b/src/khoj/interface/web/chat.html index 1ae4397f..cf7d1598 100644 --- a/src/khoj/interface/web/chat.html +++ b/src/khoj/interface/web/chat.html @@ -48,8 +48,8 @@ Get the Khoj [Desktop](https://khoj.dev/downloads), [Obsidian](https://docs.khoj To get started, just start typing below. You can also type / to see a list of commands. `.trim() - const allowedExtensions = ['text/org', 'text/markdown', 'text/plain', 'text/html', 'application/pdf', 'application/vnd.openxmlformats-officedocument.wordprocessingml.document']; - const allowedFileEndings = ['org', 'md', 'txt', 'html', 'pdf', 'docx']; + const allowedExtensions = ['text/org', 'text/markdown', 'text/plain', 'text/html', 'application/pdf', 'image/jpeg', 'image/png', 'application/vnd.openxmlformats-officedocument.wordprocessingml.document']; + const allowedFileEndings = ['org', 'md', 'txt', 'html', 'pdf', 'jpg', 'jpeg', 'png', 'docx']; let chatOptions = []; function createCopyParentText(message) { return function(event) { @@ -974,7 +974,12 @@ To get started, just start typing below. You can also type / to see a list of co fileType = "text/html"; } else if (fileExtension === "pdf") { fileType = "application/pdf"; - } else { + } else if (fileExtension === "jpg" || fileExtension === "jpeg"){ + fileType = "image/jpeg"; + } else if (fileExtension === "png") { + fileType = "image/png"; + } + else { // Skip this file if its type is not supported resolve(); return; diff --git a/src/khoj/processor/content/images/__init__.py b/src/khoj/processor/content/images/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/khoj/processor/content/images/image_to_entries.py b/src/khoj/processor/content/images/image_to_entries.py new file mode 100644 index 00000000..20705a0f --- /dev/null +++ b/src/khoj/processor/content/images/image_to_entries.py @@ -0,0 +1,118 @@ +import base64 +import logging +import os +from datetime import datetime +from typing import Dict, List, Tuple + +from rapidocr_onnxruntime import RapidOCR + +from khoj.database.models import Entry as DbEntry +from khoj.database.models import KhojUser +from khoj.processor.content.text_to_entries import TextToEntries +from khoj.utils.helpers import timer +from khoj.utils.rawconfig import Entry + +logger = logging.getLogger(__name__) + + +class ImageToEntries(TextToEntries): + def __init__(self): + super().__init__() + + # Define Functions + def process( + self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False + ) -> Tuple[int, int]: + # Extract required fields from config + if not full_corpus: + deletion_file_names = set([file for file in files if files[file] == b""]) + files_to_process = set(files) - deletion_file_names + files = {file: files[file] for file in files_to_process} + else: + deletion_file_names = None + + # Extract Entries from specified image files + with timer("Extract entries from specified Image files", logger): + file_to_text_map, current_entries = ImageToEntries.extract_image_entries(files) + + # Split entries by max tokens supported by model + with timer("Split entries by max token size supported by model", logger): + current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256) + + # Identify, mark and merge any new entries with previous entries + with timer("Identify new or updated entries", logger): + num_new_embeddings, num_deleted_embeddings = self.update_embeddings( + current_entries, + DbEntry.EntryType.IMAGE, + DbEntry.EntrySource.COMPUTER, + "compiled", + logger, + deletion_file_names, + user, + regenerate=regenerate, + file_to_text_map=file_to_text_map, + ) + + return num_new_embeddings, num_deleted_embeddings + + @staticmethod + def extract_image_entries(image_files) -> Tuple[Dict, List[Entry]]: # important function + """Extract entries by page from specified image files""" + file_to_text_map = dict() + entries: List[str] = [] + entry_to_location_map: List[Tuple[str, str]] = [] + for image_file in image_files: + try: + loader = RapidOCR() + bytes = image_files[image_file] + # write the image to a temporary file + timestamp_now = datetime.utcnow().timestamp() + # use either png or jpg + if image_file.endswith(".png"): + tmp_file = f"tmp_image_file_{timestamp_now}.png" + elif image_file.endswith(".jpg") or image_file.endswith(".jpeg"): + tmp_file = f"tmp_image_file_{timestamp_now}.jpg" + with open(tmp_file, "wb") as f: + bytes = image_files[image_file] + f.write(bytes) + try: + image_entries_per_file = "" + result, _ = loader(tmp_file) + if result: + expanded_entries = [text[1] for text in result] + image_entries_per_file = " ".join(expanded_entries) + except ImportError: + logger.warning(f"Unable to process file: {image_file}. This file will not be indexed.") + continue + entry_to_location_map.append((image_entries_per_file, image_file)) + entries.extend([image_entries_per_file]) + file_to_text_map[image_file] = image_entries_per_file + except Exception as e: + logger.warning(f"Unable to process file: {image_file}. This file will not be indexed.") + logger.warning(e, exc_info=True) + finally: + if os.path.exists(tmp_file): + os.remove(tmp_file) + return file_to_text_map, ImageToEntries.convert_image_entries_to_maps(entries, dict(entry_to_location_map)) + + @staticmethod + def convert_image_entries_to_maps(parsed_entries: List[str], entry_to_file_map) -> List[Entry]: + "Convert each image entries into a dictionary" + entries = [] + for parsed_entry in parsed_entries: + entry_filename = entry_to_file_map[parsed_entry] + # Append base filename to compiled entry for context to model + heading = f"{entry_filename}\n" + compiled_entry = f"{heading}{parsed_entry}" + entries.append( + Entry( + compiled=compiled_entry, + raw=parsed_entry, + heading=heading, + file=f"{entry_filename}", + ) + ) + + logger.debug(f"Converted {len(parsed_entries)} image entries to dictionaries") + + return entries diff --git a/src/khoj/routers/indexer.py b/src/khoj/routers/indexer.py index 4391b1ec..2046ad41 100644 --- a/src/khoj/routers/indexer.py +++ b/src/khoj/routers/indexer.py @@ -9,6 +9,7 @@ from starlette.authentication import requires from khoj.database.models import GithubConfig, KhojUser, NotionConfig from khoj.processor.content.docx.docx_to_entries import DocxToEntries from khoj.processor.content.github.github_to_entries import GithubToEntries +from khoj.processor.content.images.image_to_entries import ImageToEntries from khoj.processor.content.markdown.markdown_to_entries import MarkdownToEntries from khoj.processor.content.notion.notion_to_entries import NotionToEntries from khoj.processor.content.org_mode.org_to_entries import OrgToEntries @@ -41,6 +42,7 @@ class IndexerInput(BaseModel): markdown: Optional[dict[str, str]] = None pdf: Optional[dict[str, bytes]] = None plaintext: Optional[dict[str, str]] = None + image: Optional[dict[str, bytes]] = None docx: Optional[dict[str, bytes]] = None @@ -65,7 +67,14 @@ async def update( ), ): user = request.user.object - index_files: Dict[str, Dict[str, str]] = {"org": {}, "markdown": {}, "pdf": {}, "plaintext": {}, "docx": {}} + index_files: Dict[str, Dict[str, str]] = { + "org": {}, + "markdown": {}, + "pdf": {}, + "plaintext": {}, + "image": {}, + "docx": {}, + } try: logger.info(f"📬 Updating content index via API call by {client} client") for file in files: @@ -81,6 +90,7 @@ async def update( markdown=index_files["markdown"], pdf=index_files["pdf"], plaintext=index_files["plaintext"], + image=index_files["image"], docx=index_files["docx"], ) @@ -133,6 +143,7 @@ async def update( "num_markdown": len(index_files["markdown"]), "num_pdf": len(index_files["pdf"]), "num_plaintext": len(index_files["plaintext"]), + "num_image": len(index_files["image"]), "num_docx": len(index_files["docx"]), } @@ -300,6 +311,23 @@ def configure_content( logger.error(f"🚨 Failed to setup Notion: {e}", exc_info=True) success = False + try: + # Initialize Image Search + if (search_type == state.SearchType.All.value or search_type == state.SearchType.Image.value) and files[ + "image" + ]: + logger.info("🖼️ Setting up search for images") + # Extract Entries, Generate Image Embeddings + text_search.setup( + ImageToEntries, + files.get("image"), + regenerate=regenerate, + full_corpus=full_corpus, + user=user, + ) + except Exception as e: + logger.error(f"🚨 Failed to setup images: {e}", exc_info=True) + success = False try: if (search_type == state.SearchType.All.value or search_type == state.SearchType.Docx.value) and files["docx"]: logger.info("📄 Setting up search for docx") diff --git a/src/khoj/utils/helpers.py b/src/khoj/utils/helpers.py index 12eaae91..a98b715c 100644 --- a/src/khoj/utils/helpers.py +++ b/src/khoj/utils/helpers.py @@ -118,9 +118,9 @@ def get_file_type(file_type: str, file_content: bytes) -> tuple[str, str]: elif file_type in ["application/msword", "application/vnd.openxmlformats-officedocument.wordprocessingml.document"]: return "docx", encoding elif file_type in ["image/jpeg"]: - return "jpeg", encoding + return "image", encoding elif file_type in ["image/png"]: - return "png", encoding + return "image", encoding elif content_group in ["code", "text"]: return "plaintext", encoding else: diff --git a/src/khoj/utils/rawconfig.py b/src/khoj/utils/rawconfig.py index b4bdcbea..8ed635e6 100644 --- a/src/khoj/utils/rawconfig.py +++ b/src/khoj/utils/rawconfig.py @@ -65,6 +65,7 @@ class ContentConfig(ConfigBase): plaintext: Optional[TextContentConfig] = None github: Optional[GithubContentConfig] = None notion: Optional[NotionContentConfig] = None + image: Optional[TextContentConfig] = None docx: Optional[TextContentConfig] = None diff --git a/tests/data/images/nasdaq.jpg b/tests/data/images/nasdaq.jpg new file mode 100644 index 00000000..8b46ed06 Binary files /dev/null and b/tests/data/images/nasdaq.jpg differ diff --git a/tests/data/images/testocr.png b/tests/data/images/testocr.png new file mode 100644 index 00000000..d4b66ead Binary files /dev/null and b/tests/data/images/testocr.png differ diff --git a/tests/test_image_to_entries.py b/tests/test_image_to_entries.py new file mode 100644 index 00000000..77254c0a --- /dev/null +++ b/tests/test_image_to_entries.py @@ -0,0 +1,21 @@ +import os + +from khoj.processor.content.images.image_to_entries import ImageToEntries + + +def test_png_to_jsonl(): + with open("tests/data/images/testocr.png", "rb") as f: + image_bytes = f.read() + data = {"tests/data/images/testocr.png": image_bytes} + entries = ImageToEntries.extract_image_entries(image_files=data) + assert len(entries) == 2 + assert "opencv-python" in entries[1][0].raw + + +def test_jpg_to_jsonl(): + with open("tests/data/images/nasdaq.jpg", "rb") as f: + image_bytes = f.read() + data = {"tests/data/images/nasdaq.jpg": image_bytes} + entries = ImageToEntries.extract_image_entries(image_files=data) + assert len(entries) == 2 + assert "investments" in entries[1][0].raw