TEMP commit just to show a proof of concept of how sqlite can be supported with the khoj backend

- note that vector search / lookup is not yet working, but probably should be feasiable with an abstraction layer on top of the APIs which are doing searches. If vendor.sqlite, search this way, else search this way
This commit is contained in:
sabaimran 2024-11-18 16:05:12 -08:00
parent 3f70d2f685
commit 3806cce2e6
7 changed files with 114 additions and 100 deletions

View file

@ -115,16 +115,23 @@ CLOSE_CONNECTIONS_AFTER_REQUEST = True
# Database # Database
# https://docs.djangoproject.com/en/4.2/ref/settings/#databases # https://docs.djangoproject.com/en/4.2/ref/settings/#databases
DATA_UPLOAD_MAX_NUMBER_FIELDS = 20000 DATA_UPLOAD_MAX_NUMBER_FIELDS = 20000
# DATABASES = {
# "default": {
# "ENGINE": "django.db.backends.postgresql",
# "HOST": os.getenv("POSTGRES_HOST", "localhost"),
# "PORT": os.getenv("POSTGRES_PORT", "5432"),
# "USER": os.getenv("POSTGRES_USER", "postgres"),
# "NAME": os.getenv("POSTGRES_DB", "khoj"),
# "PASSWORD": os.getenv("POSTGRES_PASSWORD", "postgres"),
# "CONN_MAX_AGE": 0,
# "CONN_HEALTH_CHECKS": True,
# }
# }
DATABASES = { DATABASES = {
"default": { "default": {
"ENGINE": "django.db.backends.postgresql", "ENGINE": "django.db.backends.sqlite3",
"HOST": os.getenv("POSTGRES_HOST", "localhost"), "NAME": os.path.join(BASE_DIR, "db.sqlite3"),
"PORT": os.getenv("POSTGRES_PORT", "5432"),
"USER": os.getenv("POSTGRES_USER", "postgres"),
"NAME": os.getenv("POSTGRES_DB", "khoj"),
"PASSWORD": os.getenv("POSTGRES_PASSWORD", "postgres"),
"CONN_MAX_AGE": 0,
"CONN_HEALTH_CHECKS": True,
} }
} }

View file

@ -689,11 +689,12 @@ class AgentAdapters:
# TODO Update this to allow any public agent that's officially approved once that experience is launched # TODO Update this to allow any public agent that's officially approved once that experience is launched
public_query &= Q(managed_by_admin=True) public_query &= Q(managed_by_admin=True)
if user: if user:
return ( return list(
Agent.objects.filter(public_query | Q(creator=user)) set(
.distinct() Agent.objects.filter(public_query | Q(creator=user))
.order_by("created_at") .order_by("created_at")
.prefetch_related("creator", "chat_model", "fileobject_set") .prefetch_related("creator", "chat_model", "fileobject_set")
)
) )
return ( return (
Agent.objects.filter(public_query) Agent.objects.filter(public_query)
@ -1584,18 +1585,12 @@ class EntryAdapters:
async def aget_agent_entry_filepaths(agent: Agent): async def aget_agent_entry_filepaths(agent: Agent):
if agent is None: if agent is None:
return [] return []
return await sync_to_async(set)( return await sync_to_async(set)(Entry.objects.filter(agent=agent).values_list("file_path", flat=True))
Entry.objects.filter(agent=agent).distinct("file_path").values_list("file_path", flat=True)
)
@staticmethod @staticmethod
@require_valid_user @require_valid_user
def get_all_filenames_by_source(user: KhojUser, file_source: str): def get_all_filenames_by_source(user: KhojUser, file_source: str):
return ( return Entry.objects.filter(user=user, file_source=file_source).values_list("file_path", flat=True)
Entry.objects.filter(user=user, file_source=file_source)
.distinct("file_path")
.values_list("file_path", flat=True)
)
@staticmethod @staticmethod
@require_valid_user @require_valid_user
@ -1698,12 +1693,12 @@ class EntryAdapters:
@staticmethod @staticmethod
@require_valid_user @require_valid_user
def get_unique_file_types(user: KhojUser): def get_unique_file_types(user: KhojUser):
return Entry.objects.filter(user=user).values_list("file_type", flat=True).distinct() return list(set(Entry.objects.filter(user=user).values_list("file_type", flat=True)))
@staticmethod @staticmethod
@require_valid_user @require_valid_user
def get_unique_file_sources(user: KhojUser): def get_unique_file_sources(user: KhojUser):
return Entry.objects.filter(user=user).values_list("file_source", flat=True).distinct().all() return list(set(Entry.objects.filter(user=user).values_list("file_source", flat=True)))
class AutomationAdapters: class AutomationAdapters:

View file

@ -57,6 +57,22 @@ def enable_triggers(apps, schema_editor):
schema_editor.execute('ALTER TABLE "database_conversation" ENABLE TRIGGER ALL;') schema_editor.execute('ALTER TABLE "database_conversation" ENABLE TRIGGER ALL;')
def rename_conversation_id_to_temp_id(apps, schema_editor):
# if sqlite
if schema_editor.connection.vendor == "sqlite":
pass
else:
schema_editor.execute('ALTER TABLE "database_conversation" RENAME COLUMN "id" TO "temp_id";')
def rename_temp_id_to_id(apps, schema_editor):
# if sqlite
if schema_editor.connection.vendor == "sqlite":
pass
else:
schema_editor.execute('ALTER TABLE "database_conversation" RENAME COLUMN "temp_id" TO "id";')
class Migration(migrations.Migration): class Migration(migrations.Migration):
dependencies = [ dependencies = [
("database", "0063_conversation_temp_id"), ("database", "0063_conversation_temp_id"),

View file

@ -14,33 +14,33 @@ class Migration(migrations.Migration):
model_name="agent", model_name="agent",
name="tools", name="tools",
), ),
migrations.AddField( # migrations.AddField(
model_name="agent", # model_name="agent",
name="input_tools", # name="input_tools",
field=django.contrib.postgres.fields.ArrayField( # field=django.contrib.postgres.fields.ArrayField(
base_field=models.CharField( # base_field=models.CharField(
choices=[ # choices=[
("general", "General"), # ("general", "General"),
("online", "Online"), # ("online", "Online"),
("notes", "Notes"), # ("notes", "Notes"),
("summarize", "Summarize"), # ("summarize", "Summarize"),
("webpage", "Webpage"), # ("webpage", "Webpage"),
], # ],
max_length=200, # max_length=200,
), # ),
default=list, # default=list,
size=None, # size=None,
), # ),
), # ),
migrations.AddField( # migrations.AddField(
model_name="agent", # model_name="agent",
name="output_modes", # name="output_modes",
field=django.contrib.postgres.fields.ArrayField( # field=django.contrib.postgres.fields.ArrayField(
base_field=models.CharField(choices=[("text", "Text"), ("image", "Image")], max_length=200), # base_field=models.CharField(choices=[("text", "Text"), ("image", "Image")], max_length=200),
default=list, # default=list,
size=None, # size=None,
), # ),
), # ),
migrations.AlterField( migrations.AlterField(
model_name="agent", model_name="agent",
name="style_icon", name="style_icon",

View file

@ -10,15 +10,15 @@ class Migration(migrations.Migration):
] ]
operations = [ operations = [
migrations.AlterField( # migrations.AlterField(
model_name="agent", # model_name="agent",
name="output_modes", # name="output_modes",
field=django.contrib.postgres.fields.ArrayField( # field=django.contrib.postgres.fields.ArrayField(
base_field=models.CharField( # base_field=models.CharField(
choices=[("text", "Text"), ("image", "Image"), ("automation", "Automation")], max_length=200 # choices=[("text", "Text"), ("image", "Image"), ("automation", "Automation")], max_length=200
), # ),
default=list, # default=list,
size=None, # size=None,
), # ),
), # ),
] ]

View file

@ -10,37 +10,37 @@ class Migration(migrations.Migration):
] ]
operations = [ operations = [
migrations.AlterField( # migrations.AlterField(
model_name="agent", # model_name="agent",
name="input_tools", # name="input_tools",
field=django.contrib.postgres.fields.ArrayField( # field=django.contrib.postgres.fields.ArrayField(
base_field=models.CharField( # base_field=models.CharField(
choices=[ # choices=[
("general", "General"), # ("general", "General"),
("online", "Online"), # ("online", "Online"),
("notes", "Notes"), # ("notes", "Notes"),
("summarize", "Summarize"), # ("summarize", "Summarize"),
("webpage", "Webpage"), # ("webpage", "Webpage"),
], # ],
max_length=200, # max_length=200,
), # ),
blank=True, # blank=True,
default=list, # default=list,
null=True, # null=True,
size=None, # size=None,
), # ),
), # ),
migrations.AlterField( # migrations.AlterField(
model_name="agent", # model_name="agent",
name="output_modes", # name="output_modes",
field=django.contrib.postgres.fields.ArrayField( # field=django.contrib.postgres.fields.ArrayField(
base_field=models.CharField( # base_field=models.CharField(
choices=[("text", "Text"), ("image", "Image"), ("automation", "Automation")], max_length=200 # choices=[("text", "Text"), ("image", "Image"), ("automation", "Automation")], max_length=200
), # ),
blank=True, # blank=True,
default=list, # default=list,
null=True, # null=True,
size=None, # size=None,
), # ),
), # ),
] ]

View file

@ -181,12 +181,8 @@ class Agent(BaseModel):
) # Creator will only be null when the agents are managed by admin ) # Creator will only be null when the agents are managed by admin
name = models.CharField(max_length=200) name = models.CharField(max_length=200)
personality = models.TextField() personality = models.TextField()
input_tools = ArrayField( input_tools = models.JSONField(default=list, null=True)
models.CharField(max_length=200, choices=InputToolOptions.choices), default=list, null=True, blank=True output_modes = models.JSONField(default=list, null=True)
)
output_modes = ArrayField(
models.CharField(max_length=200, choices=OutputModeOptions.choices), default=list, null=True, blank=True
)
managed_by_admin = models.BooleanField(default=False) managed_by_admin = models.BooleanField(default=False)
chat_model = models.ForeignKey(ChatModelOptions, on_delete=models.CASCADE) chat_model = models.ForeignKey(ChatModelOptions, on_delete=models.CASCADE)
slug = models.CharField(max_length=200, unique=True) slug = models.CharField(max_length=200, unique=True)