Move truncate code context func for reusability across modules

It needs to be used across routers and processors. It being in
run_code tool makes it hard to be used in other chat provider contexts
due to circular dependency issues created by
send_message_to_model_wrapper func
This commit is contained in:
Debanjum 2024-11-21 14:27:39 -08:00
parent f434c3fab2
commit 5475a262d4
3 changed files with 28 additions and 27 deletions

View file

@ -1,5 +1,4 @@
import base64 import base64
import copy
import datetime import datetime
import json import json
import logging import logging
@ -20,7 +19,7 @@ from khoj.processor.conversation.utils import (
construct_chat_history, construct_chat_history,
) )
from khoj.routers.helpers import send_message_to_model_wrapper from khoj.routers.helpers import send_message_to_model_wrapper
from khoj.utils.helpers import is_none_or_empty, timer from khoj.utils.helpers import is_none_or_empty, timer, truncate_code_context
from khoj.utils.rawconfig import LocationData from khoj.utils.rawconfig import LocationData
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -180,26 +179,3 @@ async def execute_sandboxed_python(code: str, input_data: list[dict], sandbox_ur
"std_err": f"Failed to execute code with {response.status}", "std_err": f"Failed to execute code with {response.status}",
"output_files": [], "output_files": [],
} }
def truncate_code_context(original_code_results: dict[str, Any], max_chars=10000) -> dict[str, Any]:
"""
Truncate large output files and drop image file data from code results.
"""
# Create a deep copy of the code results to avoid modifying the original data
code_results = copy.deepcopy(original_code_results)
for code_result in code_results.values():
for idx, output_file in enumerate(code_result["results"]["output_files"]):
# Drop image files from code results
if Path(output_file["filename"]).suffix in {".png", ".jpg", ".jpeg", ".webp"}:
code_result["results"]["output_files"][idx] = {
"filename": output_file["filename"],
"b64_data": "[placeholder for generated image data for brevity]",
}
# Truncate large output files
elif len(output_file["b64_data"]) > max_chars:
code_result["results"]["output_files"][idx] = {
"filename": output_file["filename"],
"b64_data": output_file["b64_data"][:max_chars] + "...",
}
return code_results

View file

@ -16,7 +16,7 @@ from khoj.processor.conversation.utils import (
construct_tool_chat_history, construct_tool_chat_history,
) )
from khoj.processor.tools.online_search import read_webpages, search_online from khoj.processor.tools.online_search import read_webpages, search_online
from khoj.processor.tools.run_code import run_code, truncate_code_context from khoj.processor.tools.run_code import run_code
from khoj.routers.api import extract_references_and_questions from khoj.routers.api import extract_references_and_questions
from khoj.routers.helpers import ( from khoj.routers.helpers import (
ChatEvent, ChatEvent,
@ -28,6 +28,7 @@ from khoj.utils.helpers import (
function_calling_description_for_llm, function_calling_description_for_llm,
is_none_or_empty, is_none_or_empty,
timer, timer,
truncate_code_context,
) )
from khoj.utils.rawconfig import LocationData from khoj.utils.rawconfig import LocationData

View file

@ -1,5 +1,6 @@
from __future__ import annotations # to avoid quoting type hints from __future__ import annotations # to avoid quoting type hints
import copy
import datetime import datetime
import io import io
import ipaddress import ipaddress
@ -18,7 +19,7 @@ from itertools import islice
from os import path from os import path
from pathlib import Path from pathlib import Path
from time import perf_counter from time import perf_counter
from typing import TYPE_CHECKING, Optional, Union from typing import TYPE_CHECKING, Any, Optional, Union
from urllib.parse import urlparse from urllib.parse import urlparse
import psutil import psutil
@ -527,6 +528,29 @@ def convert_image_to_webp(image_bytes):
return webp_image_bytes return webp_image_bytes
def truncate_code_context(original_code_results: dict[str, Any], max_chars=10000) -> dict[str, Any]:
"""
Truncate large output files and drop image file data from code results.
"""
# Create a deep copy of the code results to avoid modifying the original data
code_results = copy.deepcopy(original_code_results)
for code_result in code_results.values():
for idx, output_file in enumerate(code_result["results"]["output_files"]):
# Drop image files from code results
if Path(output_file["filename"]).suffix in {".png", ".jpg", ".jpeg", ".webp"}:
code_result["results"]["output_files"][idx] = {
"filename": output_file["filename"],
"b64_data": "[placeholder for generated image data for brevity]",
}
# Truncate large output files
elif len(output_file["b64_data"]) > max_chars:
code_result["results"]["output_files"][idx] = {
"filename": output_file["filename"],
"b64_data": output_file["b64_data"][:max_chars] + "...",
}
return code_results
@lru_cache @lru_cache
def tz_to_cc_map() -> dict[str, str]: def tz_to_cc_map() -> dict[str, str]:
"""Create a mapping of timezone to country code""" """Create a mapping of timezone to country code"""