diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index 3bbb1e31..bfabc17a 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -1348,6 +1348,17 @@ class ConversationAdapters: conversation.save() return conversation.file_filters + @staticmethod + def delete_message_by_turn_id(user: KhojUser, conversation_id: str, turn_id: str): + conversation = ConversationAdapters.get_conversation_by_user(user, conversation_id=conversation_id) + if not conversation: + return False + conversation_log = conversation.conversation_log + updated_log = [msg for msg in conversation_log if msg.get("turnId") != turn_id] + conversation.conversation_log = updated_log + conversation.save() + return True + class FileObjectAdapters: @staticmethod diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index b4db71d9..dd420881 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -4,6 +4,7 @@ import math import mimetypes import os import queue +import uuid from dataclasses import dataclass from datetime import datetime from enum import Enum @@ -232,12 +233,14 @@ def save_to_conversation_log( train_of_thought: List[Any] = [], ): user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S") + turn_id = tracer.get("mid") or str(uuid.uuid4()) updated_conversation = message_to_log( user_message=q, chat_response=chat_response, user_message_metadata={ "created": user_message_time, "images": query_images, + "turnId": turn_id, }, khoj_message_metadata={ "context": compiled_references, @@ -246,6 +249,7 @@ def save_to_conversation_log( "codeContext": code_results, "automationId": automation_id, "trainOfThought": train_of_thought, + "turnId": turn_id, }, conversation_log=meta_log.get("chat", []), train_of_thought=train_of_thought, diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index e9d60a1b..92c024e9 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -38,6 +38,7 @@ from khoj.routers.helpers import ( ChatRequestBody, CommonQueryParams, ConversationCommandRateLimiter, + DeleteMessageRequestBody, agenerate_chat_response, aget_relevant_information_sources, aget_relevant_output_modes, @@ -534,6 +535,19 @@ async def set_conversation_title( ) +@api_chat.delete("/conversation/message", response_class=Response) +@requires(["authenticated"]) +def delete_message(request: Request, delete_request: DeleteMessageRequestBody) -> Response: + user = request.user.object + success = ConversationAdapters.delete_message_by_turn_id( + user, delete_request.conversation_id, delete_request.turn_id + ) + if success: + return Response(content=json.dumps({"status": "ok"}), media_type="application/json", status_code=200) + else: + return Response(content=json.dumps({"status": "error", "message": "Message not found"}), status_code=404) + + @api_chat.post("") @requires(["authenticated"]) async def chat( diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 1cb322b0..6aa25c5e 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -1264,6 +1264,11 @@ class ChatRequestBody(BaseModel): create_new: Optional[bool] = False +class DeleteMessageRequestBody(BaseModel): + conversation_id: str + turn_id: str + + class ApiUserRateLimiter: def __init__(self, requests: int, subscribed_requests: int, window: int, slug: str): self.requests = requests