Mirror of https://github.com/khoj-ai/khoj.git, synced 2024-11-23 23:48:56 +01:00
Test memory leak on MPS device when generating vector embeddings

Slope threshold of 2.0 determined qualitatively on local Mac device.
Minor unused import removal and clean-up.
parent ef24485ada
commit a4f407f595

4 changed files with 41 additions and 7 deletions
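The new test (in the last hunk below) samples the process's memory usage after every batch of embeddings, fits a line to those samples, and fails if the slope exceeds roughly 2 MB per iteration. A minimal sketch of that slope check on made-up numbers (the RSS values below are illustrative only, not measurements from this commit):

# Illustrative sketch of the slope-based leak check; the real test records RSS with psutil.
import numpy as np
from scipy.stats import linregress

# Memory usage (MB) sampled once per encode iteration
flat_usage = np.array([512.0, 512.4, 511.9, 512.2, 512.1, 512.3])     # stable process
leaking_usage = np.array([512.0, 515.1, 518.0, 521.2, 524.1, 527.0])  # ~3 MB growth per iteration

for label, samples in [("stable", flat_usage), ("leaking", leaking_usage)]:
    slope, _, _, _, _ = linregress(np.arange(len(samples)), samples)
    # A slope above the ~2 MB/iteration threshold suggests memory is not being released
    print(f"{label}: slope = {slope:.2f} MB/iteration -> {'leak suspected' if slope > 2 else 'ok'}")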
@@ -92,6 +92,7 @@ test = [
     "factory-boy >= 3.2.1",
     "trio >= 0.22.0",
     "pytest-xdist",
+    "psutil >= 5.8.0",
 ]
 dev = [
     "khoj-assistant[test]",

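psutil is added to the test dependencies so the new memory-leak test can sample the process's resident set size (RSS) on each iteration. For reference, the RSS read used in the test below, converted to MB:

import psutil
rss_mb = psutil.Process().memory_info().rss / (1024 * 1024)  # resident set size of the current process, in MB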
@@ -1,4 +1,3 @@
-import secrets
 from typing import Type, TypeVar, List
 from datetime import date
 import secrets

@@ -36,9 +35,6 @@ from database.models import (
     OfflineChatProcessorConversationConfig,
 )
 from khoj.utils.helpers import generate_random_name
-from khoj.utils.rawconfig import (
-    ConversationProcessorConfig as UserConversationProcessorConfig,
-)
 from khoj.search_filter.word_filter import WordFilter
 from khoj.search_filter.file_filter import FileFilter
 from khoj.search_filter.date_filter import DateFilter

@@ -8,10 +8,10 @@ from khoj.utils.rawconfig import SearchResponse

 class EmbeddingsModel:
     def __init__(self):
-        self.model_name = "thenlper/gte-small"
         self.encode_kwargs = {"normalize_embeddings": True}
-        model_kwargs = {"device": get_device()}
-        self.embeddings_model = SentenceTransformer(self.model_name, **model_kwargs)
+        self.model_kwargs = {"device": get_device()}
+        self.model_name = "thenlper/gte-small"
+        self.embeddings_model = SentenceTransformer(self.model_name, **self.model_kwargs)

     def embed_query(self, query):
         return self.embeddings_model.encode([query], show_progress_bar=False, **self.encode_kwargs)[0]

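The hunk above keeps the device kwargs on the instance so the SentenceTransformer is loaded on whatever get_device() returns, which is how the model ends up on MPS on Apple Silicon machines in the first place. A hypothetical sketch of that kind of device-selection helper (khoj's actual get_device() may be implemented differently):

# Hypothetical device-selection helper; khoj's real get_device() may differ.
import torch

def get_device() -> torch.device:
    if torch.cuda.is_available():           # NVIDIA GPU available
        return torch.device("cuda")
    if torch.backends.mps.is_available():   # Apple Silicon GPU (Metal Performance Shaders)
        return torch.device("mps")
    return torch.device("cpu")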
@@ -1,3 +1,14 @@
+# Standard Packages
+import numpy as np
+import psutil
+from scipy.stats import linregress
+import secrets
+
+# External Packages
+import pytest
+
+# Internal Packages
+from khoj.processor.embeddings import EmbeddingsModel
 from khoj.utils import helpers


@@ -44,3 +55,29 @@ def test_lru_cache():
     cache["b"]  # accessing 'b' makes it the most recently used item
     cache["d"] = 4  # so 'c' is deleted from the cache instead of 'b'
     assert cache == {"b": 2, "d": 4}
+
+
+@pytest.mark.skip(reason="Memory leak exists on GPU, MPS devices")
+def test_encode_docs_memory_leak():
+    # Arrange
+    iterations = 50
+    batch_size = 20
+    embeddings_model = EmbeddingsModel()
+    memory_usage_trend = []
+
+    # Act
+    # Encode random strings repeatedly and record memory usage trend
+    for iteration in range(iterations):
+        random_docs = [" ".join(secrets.token_hex(5) for _ in range(10)) for _ in range(batch_size)]
+        a = [embeddings_model.embed_documents(random_docs)]
+        memory_usage_trend += [psutil.Process().memory_info().rss / (1024 * 1024)]
+        print(f"{iteration:02d}, {memory_usage_trend[-1]:.2f}", flush=True)
+
+    # Calculate slope of line fitting memory usage history
+    memory_usage_trend = np.array(memory_usage_trend)
+    slope, _, _, _, _ = linregress(np.arange(len(memory_usage_trend)), memory_usage_trend)
+
+    # Assert
+    # If slope is positive memory utilization is increasing
+    # Positive threshold of 2, from observing memory usage trend on MPS vs CPU device
+    assert slope < 2, f"Memory usage increasing at ~{slope:.2f} MB per iteration"
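Because the leak is already known to exist on GPU and MPS devices, the test ships with a pytest.mark.skip marker and does not run in CI; reproducing the measurement locally means removing or commenting out the decorator before invoking pytest. The 2 MB/iteration threshold is qualitative: per the commit message and the inline comment, it comes from comparing the memory-usage trend of repeated encode calls on an MPS device against a CPU-only run on a local Mac, where CPU usage stays roughly flat while MPS usage climbs.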