2023-11-05 12:32:29 +01:00
|
|
|
# Standard Packages
|
|
|
|
import numpy as np
|
|
|
|
import psutil
|
|
|
|
from scipy.stats import linregress
|
|
|
|
import secrets
|
|
|
|
|
|
|
|
# External Packages
|
|
|
|
import pytest
|
|
|
|
|
|
|
|
# Internal Packages
|
|
|
|
from khoj.processor.embeddings import EmbeddingsModel
|
2023-02-14 21:50:51 +01:00
|
|
|
from khoj.utils import helpers
|
2021-08-22 00:32:23 +02:00
|
|
|
|
2023-02-17 17:04:26 +01:00
|
|
|
|
2021-08-22 00:32:23 +02:00
|
|
|
def test_get_from_null_dict():
    """Check helpers.get_from_dict on empty input, flat dicts and nested dicts.

    Covers three cases:
    - Calling with no keys returns the (empty) dictionary itself.
    - A missing key yields ``None``.
    - Successive keys walk into nested dictionaries.
    """
    # null handling
    assert helpers.get_from_dict({}) == {}
    # PEP 8: compare against the None singleton by identity, not equality
    assert helpers.get_from_dict({}, None) is None

    # key present in nested dictionary
    # 1-level dictionary
    assert helpers.get_from_dict({"a": 1, "b": 2}, "a") == 1
    assert helpers.get_from_dict({"a": 1, "b": 2}, "c") is None

    # 2-level dictionary
    assert helpers.get_from_dict({"a": {"a_a": 1}, "b": 2}, "a") == {"a_a": 1}
    assert helpers.get_from_dict({"a": {"a_a": 1}, "b": 2}, "a", "a_a") == 1

    # key not present in nested dictionary
    # 2-level_dictionary: "b" maps to a non-dict value, so descending into "b_a" must yield None
    assert helpers.get_from_dict({"a": {"a_a": 1}, "b": 2}, "b", "b_a") is None
|
2021-08-22 00:32:23 +02:00
|
|
|
|
|
|
|
|
|
|
|
def test_merge_dicts():
    """Verify helpers.merge_dicts: keys from priority_dict win, default_dict fills any gaps."""
    # Each case: (priority_dict, default_dict, expected merged result)
    cases = [
        # basic merge of dicts with non-overlapping keys
        ({"a": 1}, {"b": 2}, {"a": 1, "b": 2}),
        # use default dict items when not present in priority dict
        ({}, {"b": 2}, {"b": 2}),
        # do not override existing key in priority_dict with default dict
        ({"a": 1}, {"a": 2}, {"a": 1}),
    ]

    for priority, default, expected in cases:
        merged = helpers.merge_dicts(priority_dict=priority, default_dict=default)
        assert merged == expected
|
2022-09-04 15:31:46 +02:00
|
|
|
|
|
|
|
|
|
|
|
def test_lru_cache():
    """Exercise helpers.LRU: initial contents, eviction on overflow, and recency tracking."""
    # Seed a two-slot cache; both initial entries fit
    lru = helpers.LRU({"a": 1, "b": 2}, capacity=2)
    assert lru == {"a": 1, "b": 2}

    # Inserting a third key overflows capacity, so the oldest key 'a' is evicted
    lru["c"] = 3
    assert lru == {"b": 2, "c": 3}

    # Reading 'b' marks it as most recently used, so inserting 'd' evicts 'c' rather than 'b'
    _ = lru["b"]
    lru["d"] = 4
    assert lru == {"b": 2, "d": 4}
|
2023-11-05 12:32:29 +01:00
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.skip(reason="Memory leak exists on GPU, MPS devices")
def test_encode_docs_memory_leak():
    """Flag a suspected memory leak in EmbeddingsModel.embed_documents.

    The test works in three steps:
    1. Repeatedly embed batches of random strings.
    2. After each batch, record the process resident set size (RSS, in MiB).
    3. Fit a straight line through those samples.

    A slope of 2 MiB per iteration or more fails the test as a suspected leak.
    """
    # Arrange
    iterations = 50
    batch_size = 20
    # Positive threshold of 2, from observing memory usage trend on MPS vs CPU device
    max_slope_per_iteration = 2
    embeddings_model = EmbeddingsModel()
    # One handle to the current process, reused for every RSS sample
    process = psutil.Process()
    memory_usage_trend = []
    device = str(helpers.get_device()).upper()

    # Act
    # Encode random strings repeatedly and record memory usage trend
    for iteration in range(iterations):
        random_docs = [" ".join(secrets.token_hex(5) for _ in range(10)) for _ in range(batch_size)]
        # Result is intentionally discarded: holding on to it would inflate the measured RSS,
        # and only the memory footprint of the call itself is under test
        embeddings_model.embed_documents(random_docs)
        memory_usage_trend.append(process.memory_info().rss / (1024 * 1024))
        print(f"{iteration:02d}, {memory_usage_trend[-1]:.2f}", flush=True)

    # Calculate slope of line fitting memory usage history
    usage = np.array(memory_usage_trend)
    slope = linregress(np.arange(len(usage)), usage).slope
    print(f"Memory usage increased at ~{slope:.2f} MB per iteration on {device}")

    # Assert
    # If slope is positive memory utilization is increasing
    assert (
        slope < max_slope_per_iteration
    ), f"Memory leak suspected on {device}. Memory usage increased at ~{slope:.2f} MB per iteration"
|