Support Indexing Images via OCR (#823)

- Added support for uploading .jpeg, .jpg, and .png files to Khoj from Web, Desktop app
- Updating indexer to generate raw text and entries using RapidOCR
- Details
  * added support for indexing images via ocr
  * fixed pyproject.toml
  * Update src/khoj/processor/content/images/image_to_entries.py
     Co-authored-by: Debanjum <debanjum@gmail.com>
  * Update src/khoj/processor/content/images/image_to_entries.py
     Co-authored-by: Debanjum <debanjum@gmail.com>
  * removed redudant try except blocks
  * updated desktop js file to support image formats
  * added tests for jpg and png
  * Fix processing for image to entries files
  * Update unit tests with working image indexer
  * Change png test from version verificaition to open-cv verification

---------

Co-authored-by: Debanjum <debanjum@gmail.com>
Co-authored-by: sabaimran <narmiabas@gmail.com>
This commit is contained in:
Raghav Tirumale 2024-07-01 09:00:00 -04:00 committed by GitHub
parent c83b8f2768
commit 8eccd8a5e4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 180 additions and 7 deletions

View file

@ -19,7 +19,7 @@ const textFileTypes = [
'org', 'md', 'markdown', 'txt', 'html', 'xml',
// Other valid text file extensions from https://google.github.io/magika/model/config.json
'appleplist', 'asm', 'asp', 'batch', 'c', 'cs', 'css', 'csv', 'eml', 'go', 'html', 'ini', 'internetshortcut', 'java', 'javascript', 'json', 'latex', 'lisp', 'makefile', 'markdown', 'mht', 'mum', 'pem', 'perl', 'php', 'powershell', 'python', 'rdf', 'rst', 'rtf', 'ruby', 'rust', 'scala', 'shell', 'smali', 'sql', 'svg', 'symlinktext', 'txt', 'vba', 'winregistry', 'xml', 'yaml']
const binaryFileTypes = ['pdf']
const binaryFileTypes = ['pdf', 'jpg', 'jpeg', 'png']
const validFileTypes = textFileTypes.concat(binaryFileTypes);
const schema = {

View file

@ -48,8 +48,8 @@ Get the Khoj [Desktop](https://khoj.dev/downloads), [Obsidian](https://docs.khoj
To get started, just start typing below. You can also type / to see a list of commands.
`.trim()
const allowedExtensions = ['text/org', 'text/markdown', 'text/plain', 'text/html', 'application/pdf', 'application/vnd.openxmlformats-officedocument.wordprocessingml.document'];
const allowedFileEndings = ['org', 'md', 'txt', 'html', 'pdf', 'docx'];
const allowedExtensions = ['text/org', 'text/markdown', 'text/plain', 'text/html', 'application/pdf', 'image/jpeg', 'image/png', 'application/vnd.openxmlformats-officedocument.wordprocessingml.document'];
const allowedFileEndings = ['org', 'md', 'txt', 'html', 'pdf', 'jpg', 'jpeg', 'png', 'docx'];
let chatOptions = [];
function createCopyParentText(message) {
return function(event) {
@ -974,7 +974,12 @@ To get started, just start typing below. You can also type / to see a list of co
fileType = "text/html";
} else if (fileExtension === "pdf") {
fileType = "application/pdf";
} else {
} else if (fileExtension === "jpg" || fileExtension === "jpeg"){
fileType = "image/jpeg";
} else if (fileExtension === "png") {
fileType = "image/png";
}
else {
// Skip this file if its type is not supported
resolve();
return;

View file

@ -0,0 +1,118 @@
import base64
import logging
import os
from datetime import datetime
from typing import Dict, List, Tuple
from rapidocr_onnxruntime import RapidOCR
from khoj.database.models import Entry as DbEntry
from khoj.database.models import KhojUser
from khoj.processor.content.text_to_entries import TextToEntries
from khoj.utils.helpers import timer
from khoj.utils.rawconfig import Entry
logger = logging.getLogger(__name__)
class ImageToEntries(TextToEntries):
def __init__(self):
super().__init__()
# Define Functions
def process(
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
) -> Tuple[int, int]:
# Extract required fields from config
if not full_corpus:
deletion_file_names = set([file for file in files if files[file] == b""])
files_to_process = set(files) - deletion_file_names
files = {file: files[file] for file in files_to_process}
else:
deletion_file_names = None
# Extract Entries from specified image files
with timer("Extract entries from specified Image files", logger):
file_to_text_map, current_entries = ImageToEntries.extract_image_entries(files)
# Split entries by max tokens supported by model
with timer("Split entries by max token size supported by model", logger):
current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256)
# Identify, mark and merge any new entries with previous entries
with timer("Identify new or updated entries", logger):
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
current_entries,
DbEntry.EntryType.IMAGE,
DbEntry.EntrySource.COMPUTER,
"compiled",
logger,
deletion_file_names,
user,
regenerate=regenerate,
file_to_text_map=file_to_text_map,
)
return num_new_embeddings, num_deleted_embeddings
@staticmethod
def extract_image_entries(image_files) -> Tuple[Dict, List[Entry]]: # important function
"""Extract entries by page from specified image files"""
file_to_text_map = dict()
entries: List[str] = []
entry_to_location_map: List[Tuple[str, str]] = []
for image_file in image_files:
try:
loader = RapidOCR()
bytes = image_files[image_file]
# write the image to a temporary file
timestamp_now = datetime.utcnow().timestamp()
# use either png or jpg
if image_file.endswith(".png"):
tmp_file = f"tmp_image_file_{timestamp_now}.png"
elif image_file.endswith(".jpg") or image_file.endswith(".jpeg"):
tmp_file = f"tmp_image_file_{timestamp_now}.jpg"
with open(tmp_file, "wb") as f:
bytes = image_files[image_file]
f.write(bytes)
try:
image_entries_per_file = ""
result, _ = loader(tmp_file)
if result:
expanded_entries = [text[1] for text in result]
image_entries_per_file = " ".join(expanded_entries)
except ImportError:
logger.warning(f"Unable to process file: {image_file}. This file will not be indexed.")
continue
entry_to_location_map.append((image_entries_per_file, image_file))
entries.extend([image_entries_per_file])
file_to_text_map[image_file] = image_entries_per_file
except Exception as e:
logger.warning(f"Unable to process file: {image_file}. This file will not be indexed.")
logger.warning(e, exc_info=True)
finally:
if os.path.exists(tmp_file):
os.remove(tmp_file)
return file_to_text_map, ImageToEntries.convert_image_entries_to_maps(entries, dict(entry_to_location_map))
@staticmethod
def convert_image_entries_to_maps(parsed_entries: List[str], entry_to_file_map) -> List[Entry]:
"Convert each image entries into a dictionary"
entries = []
for parsed_entry in parsed_entries:
entry_filename = entry_to_file_map[parsed_entry]
# Append base filename to compiled entry for context to model
heading = f"{entry_filename}\n"
compiled_entry = f"{heading}{parsed_entry}"
entries.append(
Entry(
compiled=compiled_entry,
raw=parsed_entry,
heading=heading,
file=f"{entry_filename}",
)
)
logger.debug(f"Converted {len(parsed_entries)} image entries to dictionaries")
return entries

View file

@ -9,6 +9,7 @@ from starlette.authentication import requires
from khoj.database.models import GithubConfig, KhojUser, NotionConfig
from khoj.processor.content.docx.docx_to_entries import DocxToEntries
from khoj.processor.content.github.github_to_entries import GithubToEntries
from khoj.processor.content.images.image_to_entries import ImageToEntries
from khoj.processor.content.markdown.markdown_to_entries import MarkdownToEntries
from khoj.processor.content.notion.notion_to_entries import NotionToEntries
from khoj.processor.content.org_mode.org_to_entries import OrgToEntries
@ -41,6 +42,7 @@ class IndexerInput(BaseModel):
markdown: Optional[dict[str, str]] = None
pdf: Optional[dict[str, bytes]] = None
plaintext: Optional[dict[str, str]] = None
image: Optional[dict[str, bytes]] = None
docx: Optional[dict[str, bytes]] = None
@ -65,7 +67,14 @@ async def update(
),
):
user = request.user.object
index_files: Dict[str, Dict[str, str]] = {"org": {}, "markdown": {}, "pdf": {}, "plaintext": {}, "docx": {}}
index_files: Dict[str, Dict[str, str]] = {
"org": {},
"markdown": {},
"pdf": {},
"plaintext": {},
"image": {},
"docx": {},
}
try:
logger.info(f"📬 Updating content index via API call by {client} client")
for file in files:
@ -81,6 +90,7 @@ async def update(
markdown=index_files["markdown"],
pdf=index_files["pdf"],
plaintext=index_files["plaintext"],
image=index_files["image"],
docx=index_files["docx"],
)
@ -133,6 +143,7 @@ async def update(
"num_markdown": len(index_files["markdown"]),
"num_pdf": len(index_files["pdf"]),
"num_plaintext": len(index_files["plaintext"]),
"num_image": len(index_files["image"]),
"num_docx": len(index_files["docx"]),
}
@ -300,6 +311,23 @@ def configure_content(
logger.error(f"🚨 Failed to setup Notion: {e}", exc_info=True)
success = False
try:
# Initialize Image Search
if (search_type == state.SearchType.All.value or search_type == state.SearchType.Image.value) and files[
"image"
]:
logger.info("🖼️ Setting up search for images")
# Extract Entries, Generate Image Embeddings
text_search.setup(
ImageToEntries,
files.get("image"),
regenerate=regenerate,
full_corpus=full_corpus,
user=user,
)
except Exception as e:
logger.error(f"🚨 Failed to setup images: {e}", exc_info=True)
success = False
try:
if (search_type == state.SearchType.All.value or search_type == state.SearchType.Docx.value) and files["docx"]:
logger.info("📄 Setting up search for docx")

View file

@ -118,9 +118,9 @@ def get_file_type(file_type: str, file_content: bytes) -> tuple[str, str]:
elif file_type in ["application/msword", "application/vnd.openxmlformats-officedocument.wordprocessingml.document"]:
return "docx", encoding
elif file_type in ["image/jpeg"]:
return "jpeg", encoding
return "image", encoding
elif file_type in ["image/png"]:
return "png", encoding
return "image", encoding
elif content_group in ["code", "text"]:
return "plaintext", encoding
else:

View file

@ -65,6 +65,7 @@ class ContentConfig(ConfigBase):
plaintext: Optional[TextContentConfig] = None
github: Optional[GithubContentConfig] = None
notion: Optional[NotionContentConfig] = None
image: Optional[TextContentConfig] = None
docx: Optional[TextContentConfig] = None

BIN
tests/data/images/nasdaq.jpg vendored Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.4 MiB

BIN
tests/data/images/testocr.png vendored Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 59 KiB

View file

@ -0,0 +1,21 @@
import os
from khoj.processor.content.images.image_to_entries import ImageToEntries
def test_png_to_jsonl():
with open("tests/data/images/testocr.png", "rb") as f:
image_bytes = f.read()
data = {"tests/data/images/testocr.png": image_bytes}
entries = ImageToEntries.extract_image_entries(image_files=data)
assert len(entries) == 2
assert "opencv-python" in entries[1][0].raw
def test_jpg_to_jsonl():
with open("tests/data/images/nasdaq.jpg", "rb") as f:
image_bytes = f.read()
data = {"tests/data/images/nasdaq.jpg": image_bytes}
entries = ImageToEntries.extract_image_entries(image_files=data)
assert len(entries) == 2
assert "investments" in entries[1][0].raw