Support Indexing Images via OCR (#823)

- Added support for uploading .jpeg, .jpg, and .png files to Khoj from Web, Desktop app
- Updating indexer to generate raw text and entries using RapidOCR
- Details
  * added support for indexing images via ocr
  * fixed pyproject.toml
  * Update src/khoj/processor/content/images/image_to_entries.py
     Co-authored-by: Debanjum <debanjum@gmail.com>
  * Update src/khoj/processor/content/images/image_to_entries.py
     Co-authored-by: Debanjum <debanjum@gmail.com>
  * removed redudant try except blocks
  * updated desktop js file to support image formats
  * added tests for jpg and png
  * Fix processing for image to entries files
  * Update unit tests with working image indexer
  * Change png test from version verificaition to open-cv verification

---------

Co-authored-by: Debanjum <debanjum@gmail.com>
Co-authored-by: sabaimran <narmiabas@gmail.com>
This commit is contained in:
Raghav Tirumale 2024-07-01 09:00:00 -04:00 committed by GitHub
parent c83b8f2768
commit 8eccd8a5e4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 180 additions and 7 deletions

View file

@ -19,7 +19,7 @@ const textFileTypes = [
'org', 'md', 'markdown', 'txt', 'html', 'xml', 'org', 'md', 'markdown', 'txt', 'html', 'xml',
// Other valid text file extensions from https://google.github.io/magika/model/config.json // Other valid text file extensions from https://google.github.io/magika/model/config.json
'appleplist', 'asm', 'asp', 'batch', 'c', 'cs', 'css', 'csv', 'eml', 'go', 'html', 'ini', 'internetshortcut', 'java', 'javascript', 'json', 'latex', 'lisp', 'makefile', 'markdown', 'mht', 'mum', 'pem', 'perl', 'php', 'powershell', 'python', 'rdf', 'rst', 'rtf', 'ruby', 'rust', 'scala', 'shell', 'smali', 'sql', 'svg', 'symlinktext', 'txt', 'vba', 'winregistry', 'xml', 'yaml'] 'appleplist', 'asm', 'asp', 'batch', 'c', 'cs', 'css', 'csv', 'eml', 'go', 'html', 'ini', 'internetshortcut', 'java', 'javascript', 'json', 'latex', 'lisp', 'makefile', 'markdown', 'mht', 'mum', 'pem', 'perl', 'php', 'powershell', 'python', 'rdf', 'rst', 'rtf', 'ruby', 'rust', 'scala', 'shell', 'smali', 'sql', 'svg', 'symlinktext', 'txt', 'vba', 'winregistry', 'xml', 'yaml']
const binaryFileTypes = ['pdf'] const binaryFileTypes = ['pdf', 'jpg', 'jpeg', 'png']
const validFileTypes = textFileTypes.concat(binaryFileTypes); const validFileTypes = textFileTypes.concat(binaryFileTypes);
const schema = { const schema = {

View file

@ -48,8 +48,8 @@ Get the Khoj [Desktop](https://khoj.dev/downloads), [Obsidian](https://docs.khoj
To get started, just start typing below. You can also type / to see a list of commands. To get started, just start typing below. You can also type / to see a list of commands.
`.trim() `.trim()
const allowedExtensions = ['text/org', 'text/markdown', 'text/plain', 'text/html', 'application/pdf', 'application/vnd.openxmlformats-officedocument.wordprocessingml.document']; const allowedExtensions = ['text/org', 'text/markdown', 'text/plain', 'text/html', 'application/pdf', 'image/jpeg', 'image/png', 'application/vnd.openxmlformats-officedocument.wordprocessingml.document'];
const allowedFileEndings = ['org', 'md', 'txt', 'html', 'pdf', 'docx']; const allowedFileEndings = ['org', 'md', 'txt', 'html', 'pdf', 'jpg', 'jpeg', 'png', 'docx'];
let chatOptions = []; let chatOptions = [];
function createCopyParentText(message) { function createCopyParentText(message) {
return function(event) { return function(event) {
@ -974,7 +974,12 @@ To get started, just start typing below. You can also type / to see a list of co
fileType = "text/html"; fileType = "text/html";
} else if (fileExtension === "pdf") { } else if (fileExtension === "pdf") {
fileType = "application/pdf"; fileType = "application/pdf";
} else { } else if (fileExtension === "jpg" || fileExtension === "jpeg"){
fileType = "image/jpeg";
} else if (fileExtension === "png") {
fileType = "image/png";
}
else {
// Skip this file if its type is not supported // Skip this file if its type is not supported
resolve(); resolve();
return; return;

View file

@ -0,0 +1,118 @@
import base64
import logging
import os
from datetime import datetime
from typing import Dict, List, Tuple
from rapidocr_onnxruntime import RapidOCR
from khoj.database.models import Entry as DbEntry
from khoj.database.models import KhojUser
from khoj.processor.content.text_to_entries import TextToEntries
from khoj.utils.helpers import timer
from khoj.utils.rawconfig import Entry
logger = logging.getLogger(__name__)
class ImageToEntries(TextToEntries):
def __init__(self):
super().__init__()
# Define Functions
def process(
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
) -> Tuple[int, int]:
# Extract required fields from config
if not full_corpus:
deletion_file_names = set([file for file in files if files[file] == b""])
files_to_process = set(files) - deletion_file_names
files = {file: files[file] for file in files_to_process}
else:
deletion_file_names = None
# Extract Entries from specified image files
with timer("Extract entries from specified Image files", logger):
file_to_text_map, current_entries = ImageToEntries.extract_image_entries(files)
# Split entries by max tokens supported by model
with timer("Split entries by max token size supported by model", logger):
current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256)
# Identify, mark and merge any new entries with previous entries
with timer("Identify new or updated entries", logger):
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
current_entries,
DbEntry.EntryType.IMAGE,
DbEntry.EntrySource.COMPUTER,
"compiled",
logger,
deletion_file_names,
user,
regenerate=regenerate,
file_to_text_map=file_to_text_map,
)
return num_new_embeddings, num_deleted_embeddings
@staticmethod
def extract_image_entries(image_files) -> Tuple[Dict, List[Entry]]: # important function
"""Extract entries by page from specified image files"""
file_to_text_map = dict()
entries: List[str] = []
entry_to_location_map: List[Tuple[str, str]] = []
for image_file in image_files:
try:
loader = RapidOCR()
bytes = image_files[image_file]
# write the image to a temporary file
timestamp_now = datetime.utcnow().timestamp()
# use either png or jpg
if image_file.endswith(".png"):
tmp_file = f"tmp_image_file_{timestamp_now}.png"
elif image_file.endswith(".jpg") or image_file.endswith(".jpeg"):
tmp_file = f"tmp_image_file_{timestamp_now}.jpg"
with open(tmp_file, "wb") as f:
bytes = image_files[image_file]
f.write(bytes)
try:
image_entries_per_file = ""
result, _ = loader(tmp_file)
if result:
expanded_entries = [text[1] for text in result]
image_entries_per_file = " ".join(expanded_entries)
except ImportError:
logger.warning(f"Unable to process file: {image_file}. This file will not be indexed.")
continue
entry_to_location_map.append((image_entries_per_file, image_file))
entries.extend([image_entries_per_file])
file_to_text_map[image_file] = image_entries_per_file
except Exception as e:
logger.warning(f"Unable to process file: {image_file}. This file will not be indexed.")
logger.warning(e, exc_info=True)
finally:
if os.path.exists(tmp_file):
os.remove(tmp_file)
return file_to_text_map, ImageToEntries.convert_image_entries_to_maps(entries, dict(entry_to_location_map))
@staticmethod
def convert_image_entries_to_maps(parsed_entries: List[str], entry_to_file_map) -> List[Entry]:
"Convert each image entries into a dictionary"
entries = []
for parsed_entry in parsed_entries:
entry_filename = entry_to_file_map[parsed_entry]
# Append base filename to compiled entry for context to model
heading = f"{entry_filename}\n"
compiled_entry = f"{heading}{parsed_entry}"
entries.append(
Entry(
compiled=compiled_entry,
raw=parsed_entry,
heading=heading,
file=f"{entry_filename}",
)
)
logger.debug(f"Converted {len(parsed_entries)} image entries to dictionaries")
return entries

View file

@ -9,6 +9,7 @@ from starlette.authentication import requires
from khoj.database.models import GithubConfig, KhojUser, NotionConfig from khoj.database.models import GithubConfig, KhojUser, NotionConfig
from khoj.processor.content.docx.docx_to_entries import DocxToEntries from khoj.processor.content.docx.docx_to_entries import DocxToEntries
from khoj.processor.content.github.github_to_entries import GithubToEntries from khoj.processor.content.github.github_to_entries import GithubToEntries
from khoj.processor.content.images.image_to_entries import ImageToEntries
from khoj.processor.content.markdown.markdown_to_entries import MarkdownToEntries from khoj.processor.content.markdown.markdown_to_entries import MarkdownToEntries
from khoj.processor.content.notion.notion_to_entries import NotionToEntries from khoj.processor.content.notion.notion_to_entries import NotionToEntries
from khoj.processor.content.org_mode.org_to_entries import OrgToEntries from khoj.processor.content.org_mode.org_to_entries import OrgToEntries
@ -41,6 +42,7 @@ class IndexerInput(BaseModel):
markdown: Optional[dict[str, str]] = None markdown: Optional[dict[str, str]] = None
pdf: Optional[dict[str, bytes]] = None pdf: Optional[dict[str, bytes]] = None
plaintext: Optional[dict[str, str]] = None plaintext: Optional[dict[str, str]] = None
image: Optional[dict[str, bytes]] = None
docx: Optional[dict[str, bytes]] = None docx: Optional[dict[str, bytes]] = None
@ -65,7 +67,14 @@ async def update(
), ),
): ):
user = request.user.object user = request.user.object
index_files: Dict[str, Dict[str, str]] = {"org": {}, "markdown": {}, "pdf": {}, "plaintext": {}, "docx": {}} index_files: Dict[str, Dict[str, str]] = {
"org": {},
"markdown": {},
"pdf": {},
"plaintext": {},
"image": {},
"docx": {},
}
try: try:
logger.info(f"📬 Updating content index via API call by {client} client") logger.info(f"📬 Updating content index via API call by {client} client")
for file in files: for file in files:
@ -81,6 +90,7 @@ async def update(
markdown=index_files["markdown"], markdown=index_files["markdown"],
pdf=index_files["pdf"], pdf=index_files["pdf"],
plaintext=index_files["plaintext"], plaintext=index_files["plaintext"],
image=index_files["image"],
docx=index_files["docx"], docx=index_files["docx"],
) )
@ -133,6 +143,7 @@ async def update(
"num_markdown": len(index_files["markdown"]), "num_markdown": len(index_files["markdown"]),
"num_pdf": len(index_files["pdf"]), "num_pdf": len(index_files["pdf"]),
"num_plaintext": len(index_files["plaintext"]), "num_plaintext": len(index_files["plaintext"]),
"num_image": len(index_files["image"]),
"num_docx": len(index_files["docx"]), "num_docx": len(index_files["docx"]),
} }
@ -300,6 +311,23 @@ def configure_content(
logger.error(f"🚨 Failed to setup Notion: {e}", exc_info=True) logger.error(f"🚨 Failed to setup Notion: {e}", exc_info=True)
success = False success = False
try:
# Initialize Image Search
if (search_type == state.SearchType.All.value or search_type == state.SearchType.Image.value) and files[
"image"
]:
logger.info("🖼️ Setting up search for images")
# Extract Entries, Generate Image Embeddings
text_search.setup(
ImageToEntries,
files.get("image"),
regenerate=regenerate,
full_corpus=full_corpus,
user=user,
)
except Exception as e:
logger.error(f"🚨 Failed to setup images: {e}", exc_info=True)
success = False
try: try:
if (search_type == state.SearchType.All.value or search_type == state.SearchType.Docx.value) and files["docx"]: if (search_type == state.SearchType.All.value or search_type == state.SearchType.Docx.value) and files["docx"]:
logger.info("📄 Setting up search for docx") logger.info("📄 Setting up search for docx")

View file

@ -118,9 +118,9 @@ def get_file_type(file_type: str, file_content: bytes) -> tuple[str, str]:
elif file_type in ["application/msword", "application/vnd.openxmlformats-officedocument.wordprocessingml.document"]: elif file_type in ["application/msword", "application/vnd.openxmlformats-officedocument.wordprocessingml.document"]:
return "docx", encoding return "docx", encoding
elif file_type in ["image/jpeg"]: elif file_type in ["image/jpeg"]:
return "jpeg", encoding return "image", encoding
elif file_type in ["image/png"]: elif file_type in ["image/png"]:
return "png", encoding return "image", encoding
elif content_group in ["code", "text"]: elif content_group in ["code", "text"]:
return "plaintext", encoding return "plaintext", encoding
else: else:

View file

@ -65,6 +65,7 @@ class ContentConfig(ConfigBase):
plaintext: Optional[TextContentConfig] = None plaintext: Optional[TextContentConfig] = None
github: Optional[GithubContentConfig] = None github: Optional[GithubContentConfig] = None
notion: Optional[NotionContentConfig] = None notion: Optional[NotionContentConfig] = None
image: Optional[TextContentConfig] = None
docx: Optional[TextContentConfig] = None docx: Optional[TextContentConfig] = None

BIN
tests/data/images/nasdaq.jpg vendored Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.4 MiB

BIN
tests/data/images/testocr.png vendored Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 59 KiB

View file

@ -0,0 +1,21 @@
import os
from khoj.processor.content.images.image_to_entries import ImageToEntries
def test_png_to_jsonl():
with open("tests/data/images/testocr.png", "rb") as f:
image_bytes = f.read()
data = {"tests/data/images/testocr.png": image_bytes}
entries = ImageToEntries.extract_image_entries(image_files=data)
assert len(entries) == 2
assert "opencv-python" in entries[1][0].raw
def test_jpg_to_jsonl():
with open("tests/data/images/nasdaq.jpg", "rb") as f:
image_bytes = f.read()
data = {"tests/data/images/nasdaq.jpg": image_bytes}
entries = ImageToEntries.extract_image_entries(image_files=data)
assert len(entries) == 2
assert "investments" in entries[1][0].raw