Auto-update: Sun Jul 28 14:32:21 PDT 2024

This commit is contained in:
sanj 2024-07-28 14:32:21 -07:00
parent d338b5bab3
commit 5792657c6e
2 changed files with 34 additions and 23 deletions

View file

@@ -5,7 +5,6 @@ import yaml
import math import math
import os import os
import re import re
import traceback
import aiofiles import aiofiles
import aiohttp import aiohttp
import asyncpg import asyncpg
@@ -31,7 +30,6 @@ def warn(text: str): logger.warning(text)
def err(text: str): logger.error(text) def err(text: str): logger.error(text)
def crit(text: str): logger.critical(text) def crit(text: str): logger.critical(text)
T = TypeVar('T', bound='Configuration') T = TypeVar('T', bound='Configuration')
BASE_DIR = Path(__file__).resolve().parent BASE_DIR = Path(__file__).resolve().parent
@@ -398,11 +396,12 @@ class APIConfig(BaseModel):
""", last_synced_version, source_pool_entry['ts_id']) """, last_synced_version, source_pool_entry['ts_id'])
for change in changes: for change in changes:
# Convert change.keys() to a list
columns = list(change.keys()) columns = list(change.keys())
values = [change[col] for col in columns] values = [change[col] for col in columns]
# Construct the SQL query # Log the target database and table name
debug(f"Attempting to insert data into table: {table_name} in database: {dest_conn._params['database']} (host: {dest_conn._params['host']})")
insert_query = f""" insert_query = f"""
INSERT INTO "{table_name}" ({', '.join(columns)}) INSERT INTO "{table_name}" ({', '.join(columns)})
VALUES ({', '.join(f'${i+1}' for i in range(len(columns)))}) VALUES ({', '.join(f'${i+1}' for i in range(len(columns)))})
@@ -410,12 +409,16 @@ class APIConfig(BaseModel):
{', '.join(f"{col} = EXCLUDED.{col}" for col in columns if col != 'id')} {', '.join(f"{col} = EXCLUDED.{col}" for col in columns if col != 'id')}
""" """
# Execute the query try:
await dest_conn.execute(insert_query, *values) await dest_conn.execute(insert_query, *values)
except asyncpg.exceptions.UndefinedColumnError as e:
err(f"UndefinedColumnError in table: {table_name} in database: {dest_conn._params['database']} (host: {dest_conn._params['host']})")
raise e
if changes: if changes:
await self.update_sync_status(table_name, source_pool_entry['ts_id'], changes[-1]['version']) await self.update_sync_status(table_name, source_pool_entry['ts_id'], changes[-1]['version'])
async def push_changes_to_all(self): async def push_changes_to_all(self):
async with self.get_connection() as local_conn: async with self.get_connection() as local_conn:
tables = await local_conn.fetch(""" tables = await local_conn.fetch("""
@@ -473,10 +476,6 @@ class APIConfig(BaseModel):
""", table_name, server_id, version) """, table_name, server_id, version)
async def sync_schema(self): async def sync_schema(self):
for pool_entry in self.POOL:
async with self.get_connection(pool_entry) as conn:
await conn.execute('CREATE EXTENSION IF NOT EXISTS postgis')
source_entry = self.local_db source_entry = self.local_db
source_schema = await self.get_schema(source_entry) source_schema = await self.get_schema(source_entry)
@@ -549,7 +548,16 @@ class APIConfig(BaseModel):
col_def += f" DEFAULT {t['column_default']}" col_def += f" DEFAULT {t['column_default']}"
columns.append(col_def) columns.append(col_def)
sql = f'CREATE TABLE "{table_name}" ({", ".join(columns)})' primary_key_constraint = next(
(con['definition'] for con in source_schema['constraints'] if con['table_name'] == table_name and con['contype'] == 'p'),
None
)
sql = f'CREATE TABLE "{table_name}" ({", ".join(columns)}'
if primary_key_constraint:
sql += f', {primary_key_constraint}'
sql += ')'
info(f"Executing SQL: {sql}") info(f"Executing SQL: {sql}")
await conn.execute(sql) await conn.execute(sql)
else: else:
@@ -584,6 +592,16 @@ class APIConfig(BaseModel):
sql = f'ALTER TABLE "{table_name}" ALTER COLUMN "{col_name}" {default_clause}' sql = f'ALTER TABLE "{table_name}" ALTER COLUMN "{col_name}" {default_clause}'
debug(f"Executing SQL: {sql}") debug(f"Executing SQL: {sql}")
await conn.execute(sql) await conn.execute(sql)
# Ensure primary key constraint exists
primary_key_constraint = next(
(con['definition'] for con in source_schema['constraints'] if con['table_name'] == table_name and con['contype'] == 'p'),
None
)
if primary_key_constraint and primary_key_constraint not in target_schema['constraints']:
sql = f'ALTER TABLE "{table_name}" ADD CONSTRAINT {primary_key_constraint}'
debug(f"Executing SQL: {sql}")
await conn.execute(sql)
except Exception as e: except Exception as e:
err(f"Error processing table {table_name}: {str(e)}") err(f"Error processing table {table_name}: {str(e)}")
@@ -636,6 +654,7 @@ class APIConfig(BaseModel):
""") """)
class Location(BaseModel): class Location(BaseModel):
latitude: float latitude: float
longitude: float longitude: float

View file

@@ -3,18 +3,16 @@ Uses whisper_cpp to create an OpenAI-compatible Whisper web service.
''' '''
# routers/asr.py # routers/asr.py
import os import os
import sys
import uuid import uuid
import json import json
import asyncio import asyncio
import tempfile import tempfile
import subprocess
from urllib.parse import unquote from urllib.parse import unquote
from fastapi import APIRouter, HTTPException, Form, UploadFile, File, BackgroundTasks from fastapi import APIRouter, HTTPException, Form, UploadFile, File, BackgroundTasks
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing import Optional from typing import Optional
from sijapi import L, ASR_DIR, WHISPER_CPP_MODELS, GARBAGE_COLLECTION_INTERVAL, GARBAGE_TTL, WHISPER_CPP_DIR, MAX_CPU_CORES from sijapi import L, ASR_DIR, WHISPER_CPP_MODELS, WHISPER_CPP_DIR, MAX_CPU_CORES
asr = APIRouter() asr = APIRouter()
logger = L.get_module_logger("asr") logger = L.get_module_logger("asr")
@@ -24,9 +22,7 @@ def warn(text: str): logger.warning(text)
def err(text: str): logger.error(text) def err(text: str): logger.error(text)
def crit(text: str): logger.critical(text) def crit(text: str): logger.critical(text)
# Global dictionary to store transcription results
transcription_results = {} transcription_results = {}
class TranscribeParams(BaseModel): class TranscribeParams(BaseModel):
model: str = Field(default="small") model: str = Field(default="small")
output_srt: Optional[bool] = Field(default=False) output_srt: Optional[bool] = Field(default=False)
@@ -67,8 +63,8 @@ async def transcribe_endpoint(
job_id = await transcribe_audio(file_path=temp_file_path, params=parameters) job_id = await transcribe_audio(file_path=temp_file_path, params=parameters)
# Poll for completion # Poll for completion
max_wait_time = 3600 # 60 minutes max_wait_time = 3600
poll_interval = 10 # 2 seconds poll_interval = 10
elapsed_time = 0 elapsed_time = 0
while elapsed_time < max_wait_time: while elapsed_time < max_wait_time:
@@ -85,7 +81,6 @@ async def transcribe_endpoint(
# If we've reached this point, the transcription has taken too long # If we've reached this point, the transcription has taken too long
return JSONResponse(content={"status": "timeout", "message": "Transcription is taking longer than expected. Please check back later."}, status_code=202) return JSONResponse(content={"status": "timeout", "message": "Transcription is taking longer than expected. Please check back later."}, status_code=202)
async def transcribe_audio(file_path, params: TranscribeParams): async def transcribe_audio(file_path, params: TranscribeParams):
debug(f"Transcribing audio file from {file_path}...") debug(f"Transcribing audio file from {file_path}...")
file_path = await convert_to_wav(file_path) file_path = await convert_to_wav(file_path)
@@ -94,8 +89,7 @@ async def transcribe_audio(file_path, params: TranscribeParams):
command = [str(WHISPER_CPP_DIR / 'build' / 'bin' / 'main')] command = [str(WHISPER_CPP_DIR / 'build' / 'bin' / 'main')]
command.extend(['-m', str(model_path)]) command.extend(['-m', str(model_path)])
command.extend(['-t', str(max(1, min(params.threads or MAX_CPU_CORES, MAX_CPU_CORES)))]) command.extend(['-t', str(max(1, min(params.threads or MAX_CPU_CORES, MAX_CPU_CORES)))])
command.extend(['-np']) # Always enable no-prints command.extend(['-np'])
if params.split_on_word: if params.split_on_word:
command.append('-sow') command.append('-sow')
@@ -159,8 +153,6 @@ async def transcribe_audio(file_path, params: TranscribeParams):
finally: finally:
# Ensure the task is cancelled if we exit the loop # Ensure the task is cancelled if we exit the loop
transcription_task.cancel() transcription_task.cancel()
# This line should never be reached, but just in case:
raise Exception("Unexpected exit from transcription function") raise Exception("Unexpected exit from transcription function")