diff --git a/pyproject.toml b/pyproject.toml
index c816f4d2..6d1ef0a2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -92,6 +92,7 @@ test = [
     "factory-boy >= 3.2.1",
     "trio >= 0.22.0",
     "pytest-xdist",
+    "psutil >= 5.8.0",
 ]
 dev = [
     "khoj-assistant[test]",
diff --git a/src/database/adapters/__init__.py b/src/database/adapters/__init__.py
index 7fbc5287..5669414d 100644
--- a/src/database/adapters/__init__.py
+++ b/src/database/adapters/__init__.py
@@ -1,4 +1,3 @@
-import secrets
 from typing import Type, TypeVar, List
 from datetime import date
 import secrets
@@ -36,9 +35,6 @@ from database.models import (
     OfflineChatProcessorConversationConfig,
 )
 from khoj.utils.helpers import generate_random_name
-from khoj.utils.rawconfig import (
-    ConversationProcessorConfig as UserConversationProcessorConfig,
-)
 from khoj.search_filter.word_filter import WordFilter
 from khoj.search_filter.file_filter import FileFilter
 from khoj.search_filter.date_filter import DateFilter
diff --git a/src/khoj/processor/embeddings.py b/src/khoj/processor/embeddings.py
index fcd88d80..1e92f27d 100644
--- a/src/khoj/processor/embeddings.py
+++ b/src/khoj/processor/embeddings.py
@@ -8,10 +8,10 @@ from khoj.utils.rawconfig import SearchResponse
 
 class EmbeddingsModel:
     def __init__(self):
-        self.model_name = "thenlper/gte-small"
         self.encode_kwargs = {"normalize_embeddings": True}
-        model_kwargs = {"device": get_device()}
-        self.embeddings_model = SentenceTransformer(self.model_name, **model_kwargs)
+        self.model_kwargs = {"device": get_device()}
+        self.model_name = "thenlper/gte-small"
+        self.embeddings_model = SentenceTransformer(self.model_name, **self.model_kwargs)
 
     def embed_query(self, query):
         return self.embeddings_model.encode([query], show_progress_bar=False, **self.encode_kwargs)[0]
diff --git a/tests/test_helpers.py b/tests/test_helpers.py
index 622592b1..30499049 100644
--- a/tests/test_helpers.py
+++ b/tests/test_helpers.py
@@ -1,3 +1,14 @@
+# Standard Packages
+import secrets
+
+# External Packages
+import numpy as np
+import psutil
+import pytest
+from scipy.stats import linregress
+
+# Internal Packages
+from khoj.processor.embeddings import EmbeddingsModel
 from khoj.utils import helpers
 
 
@@ -44,3 +55,29 @@ def test_lru_cache():
     cache["b"]  # accessing 'b' makes it the most recently used item
     cache["d"] = 4  # so 'c' is deleted from the cache instead of 'b'
     assert cache == {"b": 2, "d": 4}
+
+
+@pytest.mark.skip(reason="Memory leak exists on GPU, MPS devices")
+def test_encode_docs_memory_leak():
+    # Arrange
+    iterations = 50
+    batch_size = 20
+    embeddings_model = EmbeddingsModel()
+    memory_usage_trend = []
+
+    # Act
+    # Encode random strings repeatedly and record memory usage trend
+    for iteration in range(iterations):
+        random_docs = [" ".join(secrets.token_hex(5) for _ in range(10)) for _ in range(batch_size)]
+        a = [embeddings_model.embed_documents(random_docs)]
+        memory_usage_trend += [psutil.Process().memory_info().rss / (1024 * 1024)]
+        print(f"{iteration:02d}, {memory_usage_trend[-1]:.2f}", flush=True)
+
+    # Calculate slope of line fitting memory usage history
+    memory_usage_trend = np.array(memory_usage_trend)
+    slope, _, _, _, _ = linregress(np.arange(len(memory_usage_trend)), memory_usage_trend)
+
+    # Assert
+    # If slope is positive memory utilization is increasing
+    # Positive threshold of 2, from observing memory usage trend on MPS vs CPU device
+    assert slope < 2, f"Memory usage increasing at ~{slope:.2f} MB per iteration"