Auto-update: Wed Jul 24 23:49:59 PDT 2024

This commit is contained in:
sanj 2024-07-24 23:49:59 -07:00
parent fb0dd4ece8
commit f4010c8a3f
2 changed files with 191 additions and 93 deletions

View file

@ -125,7 +125,7 @@ class Configuration(BaseModel):
elif len(parts) == 2 and parts[0] == 'ENV': elif len(parts) == 2 and parts[0] == 'ENV':
replacement = os.getenv(parts[1], '') replacement = os.getenv(parts[1], '')
else: else:
replacement = value # Keep original if not recognized replacement = value
value = value.replace('{{' + match + '}}', str(replacement)) value = value.replace('{{' + match + '}}', str(replacement))
@ -154,6 +154,7 @@ class Configuration(BaseModel):
extra = "allow" extra = "allow"
arbitrary_types_allowed = True arbitrary_types_allowed = True
class APIConfig(BaseModel): class APIConfig(BaseModel):
HOST: str HOST: str
PORT: int PORT: int
@ -161,9 +162,9 @@ class APIConfig(BaseModel):
URL: str URL: str
PUBLIC: List[str] PUBLIC: List[str]
TRUSTED_SUBNETS: List[str] TRUSTED_SUBNETS: List[str]
MODULES: Any # This will be replaced with a dynamic model MODULES: Any
POOL: List[Dict[str, Any]] POOL: List[Dict[str, Any]]
EXTENSIONS: Any # This will be replaced with a dynamic model EXTENSIONS: Any
TZ: str TZ: str
KEYS: List[str] KEYS: List[str]
GARBAGE: Dict[str, Any] GARBAGE: Dict[str, Any]
@ -173,11 +174,10 @@ class APIConfig(BaseModel):
config_path = cls._resolve_path(config_path, 'config') config_path = cls._resolve_path(config_path, 'config')
secrets_path = cls._resolve_path(secrets_path, 'config') secrets_path = cls._resolve_path(secrets_path, 'config')
# Load main configuration
with open(config_path, 'r') as file: with open(config_path, 'r') as file:
config_data = yaml.safe_load(file) config_data = yaml.safe_load(file)
print(f"Loaded main config: {config_data}") print(f"Loaded main config: {config_data}")
try: try:
with open(secrets_path, 'r') as file: with open(secrets_path, 'r') as file:
@ -191,7 +191,7 @@ class APIConfig(BaseModel):
secrets_data = {} secrets_data = {}
config_data = cls.resolve_placeholders(config_data) config_data = cls.resolve_placeholders(config_data)
print(f"Resolved config: {config_data}") print(f"Resolved config: {config_data}")
if isinstance(config_data.get('KEYS'), list) and len(config_data['KEYS']) == 1: if isinstance(config_data.get('KEYS'), list) and len(config_data['KEYS']) == 1:
placeholder = config_data['KEYS'][0] placeholder = config_data['KEYS'][0]
if placeholder.startswith('{{') and placeholder.endswith('}}'): if placeholder.startswith('{{') and placeholder.endswith('}}'):
@ -227,7 +227,7 @@ class APIConfig(BaseModel):
@classmethod @classmethod
def _resolve_path(cls, path: Union[str, Path], default_dir: str) -> Path: def _resolve_path(cls, path: Union[str, Path], default_dir: str) -> Path:
base_path = Path(__file__).parent.parent # This will be two levels up from this file base_path = Path(__file__).parent.parent
path = Path(path) path = Path(path)
if not path.suffix: if not path.suffix:
path = base_path / "sijapi" / default_dir / f"{path.name}.yaml" path = base_path / "sijapi" / default_dir / f"{path.name}.yaml"
@ -255,7 +255,6 @@ class APIConfig(BaseModel):
else: else:
resolved_data[key] = resolve_value(value) resolved_data[key] = resolve_value(value)
# Resolve BIND separately to ensure HOST and PORT are used
if 'BIND' in resolved_data: if 'BIND' in resolved_data:
resolved_data['BIND'] = resolved_data['BIND'].replace('{{ HOST }}', str(resolved_data['HOST'])) resolved_data['BIND'] = resolved_data['BIND'].replace('{{ HOST }}', str(resolved_data['HOST']))
resolved_data['BIND'] = resolved_data['BIND'].replace('{{ PORT }}', str(resolved_data['PORT'])) resolved_data['BIND'] = resolved_data['BIND'].replace('{{ PORT }}', str(resolved_data['PORT']))
@ -265,7 +264,9 @@ class APIConfig(BaseModel):
def __getattr__(self, name: str) -> Any: def __getattr__(self, name: str) -> Any:
if name in ['MODULES', 'EXTENSIONS']: if name in ['MODULES', 'EXTENSIONS']:
return self.__dict__[name] return self.__dict__[name]
return super().__getattr__(name) if name in self.__dict__:
return self.__dict__[name]
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
@property @property
def active_modules(self) -> List[str]: def active_modules(self) -> List[str]:
@ -303,7 +304,6 @@ class APIConfig(BaseModel):
crit(f"Error: {str(e)}") crit(f"Error: {str(e)}")
raise raise
async def push_changes(self, query: str, *args): async def push_changes(self, query: str, *args):
connections = [] connections = []
try: try:
@ -337,41 +337,72 @@ class APIConfig(BaseModel):
continue continue
return None return None
async def pull_changes(self, source_pool_entry: Dict[str, Any] = None):
if source_pool_entry is None: async def pull_changes(self, source_pool_entry: Dict[str, Any] = None):
source_pool_entry = await self.get_default_source() if source_pool_entry is None:
source_pool_entry = await self.get_default_source()
if source_pool_entry is None:
err("No available source for pulling changes") if source_pool_entry is None:
return err("No available source for pulling changes")
return
async with self.get_connection(source_pool_entry) as source_conn:
async with self.get_connection() as dest_conn: async with self.get_connection(source_pool_entry) as source_conn:
# This is a simplistic approach. You might need a more sophisticated async with self.get_connection() as dest_conn:
# method to determine what data needs to be synced. tables = await source_conn.fetch(
tables = await source_conn.fetch( "SELECT tablename FROM pg_tables WHERE schemaname = 'public'"
"SELECT tablename FROM pg_tables WHERE schemaname = 'public'" )
) for table in tables:
for table in tables: table_name = table['tablename']
table_name = table['tablename'] info(f"Processing table: {table_name}")
await dest_conn.execute(f"TRUNCATE TABLE {table_name}")
rows = await source_conn.fetch(f"SELECT * FROM {table_name}") # Get primary key column(s)
if rows: pk_columns = await source_conn.fetch("""
columns = rows[0].keys() SELECT a.attname
await dest_conn.copy_records_to_table( FROM pg_index i
table_name, records=rows, columns=columns JOIN pg_attribute a ON a.attrelid = i.indrelid
) AND a.attnum = ANY(i.indkey)
info(f"Successfully pulled changes from {source_pool_entry['ts_ip']}") WHERE i.indrelid = $1::regclass
AND i.indisprimary;
""", table_name)
pk_cols = [col['attname'] for col in pk_columns]
if not pk_cols:
warn(f"No primary key found for table {table_name}. Skipping.")
continue
# Fetch all rows from the source table
rows = await source_conn.fetch(f"SELECT * FROM {table_name}")
if rows:
columns = rows[0].keys()
# Upsert records to the destination table
await dest_conn.executemany(f"""
INSERT INTO {table_name} ({', '.join(columns)})
VALUES ({', '.join(f'${i+1}' for i in range(len(columns)))})
ON CONFLICT ({', '.join(pk_cols)}) DO UPDATE SET
{', '.join(f"{col} = EXCLUDED.{col}" for col in columns if col not in pk_cols)}
""", [tuple(row[col] for col in columns) for row in rows])
info(f"Completed processing table: {table_name}")
info(f"Successfully pulled changes from {source_pool_entry['ts_ip']}")
async def sync_schema(self): async def sync_schema(self):
source_entry = self.POOL[0] # Use the local database as the source for pool_entry in self.POOL:
async with self.get_connection(pool_entry) as conn:
await conn.execute('CREATE EXTENSION IF NOT EXISTS postgis')
source_entry = self.local_db
source_schema = await self.get_schema(source_entry) source_schema = await self.get_schema(source_entry)
for pool_entry in self.POOL[1:]: for pool_entry in self.POOL[1:]:
target_schema = await self.get_schema(pool_entry) try:
await self.apply_schema_changes(pool_entry, source_schema, target_schema) target_schema = await self.get_schema(pool_entry)
info(f"Synced schema to {pool_entry['ts_ip']}") await self.apply_schema_changes(pool_entry, source_schema, target_schema)
info(f"Synced schema to {pool_entry['ts_ip']}")
except Exception as e:
err(f"Failed to sync schema to {pool_entry['ts_ip']}: {str(e)}")
async def get_schema(self, pool_entry: Dict[str, Any]): async def get_schema(self, pool_entry: Dict[str, Any]):
async with self.get_connection(pool_entry) as conn: async with self.get_connection(pool_entry) as conn:
@ -402,61 +433,126 @@ class APIConfig(BaseModel):
'constraints': constraints 'constraints': constraints
} }
async def create_sequence_if_not_exists(self, conn, sequence_name):
await conn.execute(f"""
DO $$
BEGIN
IF NOT EXISTS (SELECT 1 FROM pg_sequences WHERE schemaname = 'public' AND sequencename = '{sequence_name}') THEN
CREATE SEQUENCE {sequence_name};
END IF;
END $$;
""")
async def apply_schema_changes(self, pool_entry: Dict[str, Any], source_schema, target_schema): async def apply_schema_changes(self, pool_entry: Dict[str, Any], source_schema, target_schema):
async with self.get_connection(pool_entry) as conn: async with self.get_connection(pool_entry) as conn:
source_tables = {t['table_name']: t for t in source_schema['tables']} source_tables = {t['table_name']: t for t in source_schema['tables']}
target_tables = {t['table_name']: t for t in target_schema['tables']} target_tables = {t['table_name']: t for t in target_schema['tables']}
for table_name, source_table in source_tables.items(): def get_column_type(data_type):
if table_name not in target_tables: if data_type == 'ARRAY':
columns = [f"\"{t['column_name']}\" {t['data_type']}" + return 'text[]' # or another appropriate type
(f"({t['character_maximum_length']})" if t['character_maximum_length'] else "") + elif data_type == 'USER-DEFINED':
(" NOT NULL" if t['is_nullable'] == 'NO' else "") + return 'geometry'
(f" DEFAULT {t['column_default']}" if t['column_default'] else "")
for t in source_schema['tables'] if t['table_name'] == table_name]
await conn.execute(f'CREATE TABLE "{table_name}" ({", ".join(columns)})')
else: else:
target_table = target_tables[table_name] return data_type
source_columns = {t['column_name']: t for t in source_schema['tables'] if t['table_name'] == table_name}
target_columns = {t['column_name']: t for t in target_schema['tables'] if t['table_name'] == table_name}
for col_name, source_col in source_columns.items(): for table_name, source_table in source_tables.items():
if col_name not in target_columns: try:
col_def = f"\"{col_name}\" {source_col['data_type']}" + \ if table_name not in target_tables:
(f"({source_col['character_maximum_length']})" if source_col['character_maximum_length'] else "") + \ columns = []
(" NOT NULL" if source_col['is_nullable'] == 'NO' else "") + \ for t in source_schema['tables']:
(f" DEFAULT {source_col['column_default']}" if source_col['column_default'] else "") if t['table_name'] == table_name:
await conn.execute(f'ALTER TABLE "{table_name}" ADD COLUMN {col_def}') col_type = get_column_type(t['data_type'])
else: col_def = f"\"{t['column_name']}\" {col_type}"
target_col = target_columns[col_name] if t['character_maximum_length']:
if source_col != target_col: col_def += f"({t['character_maximum_length']})"
await conn.execute(f'ALTER TABLE "{table_name}" ALTER COLUMN "{col_name}" TYPE {source_col["data_type"]}') if t['is_nullable'] == 'NO':
if source_col['is_nullable'] != target_col['is_nullable']: col_def += " NOT NULL"
null_constraint = "DROP NOT NULL" if source_col['is_nullable'] == 'YES' else "SET NOT NULL" if t['column_default']:
await conn.execute(f'ALTER TABLE "{table_name}" ALTER COLUMN "{col_name}" {null_constraint}') if 'nextval' in t['column_default']:
if source_col['column_default'] != target_col['column_default']: sequence_name = t['column_default'].split("'")[1]
default_clause = f"SET DEFAULT {source_col['column_default']}" if source_col['column_default'] else "DROP DEFAULT" await self.create_sequence_if_not_exists(conn, sequence_name)
await conn.execute(f'ALTER TABLE "{table_name}" ALTER COLUMN "{col_name}" {default_clause}') col_def += f" DEFAULT {t['column_default']}"
columns.append(col_def)
sql = f'CREATE TABLE "{table_name}" ({", ".join(columns)})'
print(f"Executing SQL: {sql}")
await conn.execute(sql)
else:
target_table = target_tables[table_name]
source_columns = {t['column_name']: t for t in source_schema['tables'] if t['table_name'] == table_name}
target_columns = {t['column_name']: t for t in target_schema['tables'] if t['table_name'] == table_name}
source_indexes = {idx['indexname']: idx['indexdef'] for idx in source_schema['indexes']} for col_name, source_col in source_columns.items():
target_indexes = {idx['indexname']: idx['indexdef'] for idx in target_schema['indexes']} if col_name not in target_columns:
col_type = get_column_type(source_col['data_type'])
col_def = f"\"{col_name}\" {col_type}" + \
(f"({source_col['character_maximum_length']})" if source_col['character_maximum_length'] else "") + \
(" NOT NULL" if source_col['is_nullable'] == 'NO' else "") + \
(f" DEFAULT {source_col['column_default']}" if source_col['column_default'] else "")
sql = f'ALTER TABLE "{table_name}" ADD COLUMN {col_def}'
print(f"Executing SQL: {sql}")
await conn.execute(sql)
else:
target_col = target_columns[col_name]
if source_col != target_col:
col_type = get_column_type(source_col['data_type'])
sql = f'ALTER TABLE "{table_name}" ALTER COLUMN "{col_name}" TYPE {col_type}'
print(f"Executing SQL: {sql}")
await conn.execute(sql)
if source_col['is_nullable'] != target_col['is_nullable']:
null_constraint = "DROP NOT NULL" if source_col['is_nullable'] == 'YES' else "SET NOT NULL"
sql = f'ALTER TABLE "{table_name}" ALTER COLUMN "{col_name}" {null_constraint}'
print(f"Executing SQL: {sql}")
await conn.execute(sql)
if source_col['column_default'] != target_col['column_default']:
default_clause = f"SET DEFAULT {source_col['column_default']}" if source_col['column_default'] else "DROP DEFAULT"
sql = f'ALTER TABLE "{table_name}" ALTER COLUMN "{col_name}" {default_clause}'
print(f"Executing SQL: {sql}")
await conn.execute(sql)
except Exception as e:
print(f"Error processing table {table_name}: {str(e)}")
# Optionally, you might want to raise this exception if you want to stop the entire process
# raise
for idx_name, idx_def in source_indexes.items(): try:
if idx_name not in target_indexes: source_indexes = {idx['indexname']: idx['indexdef'] for idx in source_schema['indexes']}
await conn.execute(idx_def) target_indexes = {idx['indexname']: idx['indexdef'] for idx in target_schema['indexes']}
elif idx_def != target_indexes[idx_name]:
await conn.execute(f'DROP INDEX "{idx_name}"')
await conn.execute(idx_def)
source_constraints = {con['conname']: con for con in source_schema['constraints']} for idx_name, idx_def in source_indexes.items():
target_constraints = {con['conname']: con for con in target_schema['constraints']} if idx_name not in target_indexes:
print(f"Executing SQL: {idx_def}")
await conn.execute(idx_def)
elif idx_def != target_indexes[idx_name]:
sql = f'DROP INDEX IF EXISTS "{idx_name}"'
print(f"Executing SQL: {sql}")
await conn.execute(sql)
print(f"Executing SQL: {idx_def}")
await conn.execute(idx_def)
except Exception as e:
print(f"Error processing indexes: {str(e)}")
try:
source_constraints = {con['conname']: con for con in source_schema['constraints']}
target_constraints = {con['conname']: con for con in target_schema['constraints']}
for con_name, source_con in source_constraints.items():
if con_name not in target_constraints:
sql = f'ALTER TABLE "{source_con["table_name"]}" ADD CONSTRAINT "{con_name}" {source_con["definition"]}'
print(f"Executing SQL: {sql}")
await conn.execute(sql)
elif source_con != target_constraints[con_name]:
sql = f'ALTER TABLE "{source_con["table_name"]}" DROP CONSTRAINT IF EXISTS "{con_name}"'
print(f"Executing SQL: {sql}")
await conn.execute(sql)
sql = f'ALTER TABLE "{source_con["table_name"]}" ADD CONSTRAINT "{con_name}" {source_con["definition"]}'
print(f"Executing SQL: {sql}")
await conn.execute(sql)
except Exception as e:
print(f"Error processing constraints: {str(e)}")
print(f"Schema synchronization completed for {pool_entry['ts_ip']}")
for con_name, source_con in source_constraints.items():
if con_name not in target_constraints:
await conn.execute(f'ALTER TABLE "{source_con["table_name"]}" ADD CONSTRAINT "{con_name}" {source_con["definition"]}')
elif source_con != target_constraints[con_name]:
await conn.execute(f'ALTER TABLE "{source_con["table_name"]}" DROP CONSTRAINT "{con_name}"')
await conn.execute(f'ALTER TABLE "{source_con["table_name"]}" ADD CONSTRAINT "{con_name}" {source_con["definition"]}')
class Location(BaseModel): class Location(BaseModel):
latitude: float latitude: float

View file

@ -24,6 +24,9 @@ def warn(text: str): logger.warning(text)
def err(text: str): logger.error(text) def err(text: str): logger.error(text)
def crit(text: str): logger.critical(text) def crit(text: str): logger.critical(text)
# Global dictionary to store transcription results
transcription_results = {}
class TranscribeParams(BaseModel): class TranscribeParams(BaseModel):
model: str = Field(default="small") model: str = Field(default="small")
output_srt: Optional[bool] = Field(default=False) output_srt: Optional[bool] = Field(default=False)
@ -40,8 +43,6 @@ class TranscribeParams(BaseModel):
dtw: Optional[str] = Field(None) dtw: Optional[str] = Field(None)
threads: Optional[int] = Field(None) threads: Optional[int] = Field(None)
# Global dictionary to store transcription results
transcription_results = {}
@asr.post("/asr") @asr.post("/asr")
@asr.post("/transcribe") @asr.post("/transcribe")
@ -63,12 +64,11 @@ async def transcribe_endpoint(
temp_file.write(await file.read()) temp_file.write(await file.read())
temp_file_path = temp_file.name temp_file_path = temp_file.name
transcription_job = await transcribe_audio(file_path=temp_file_path, params=parameters) job_id = await transcribe_audio(file_path=temp_file_path, params=parameters)
job_id = transcription_job["job_id"]
# Poll for completion # Poll for completion
max_wait_time = 600 # 10 minutes max_wait_time = 3600 # 60 minutes
poll_interval = 2 # 2 seconds poll_interval = 10 # 2 seconds
elapsed_time = 0 elapsed_time = 0
while elapsed_time < max_wait_time: while elapsed_time < max_wait_time:
@ -85,6 +85,7 @@ async def transcribe_endpoint(
# If we've reached this point, the transcription has taken too long # If we've reached this point, the transcription has taken too long
return JSONResponse(content={"status": "timeout", "message": "Transcription is taking longer than expected. Please check back later."}, status_code=202) return JSONResponse(content={"status": "timeout", "message": "Transcription is taking longer than expected. Please check back later."}, status_code=202)
async def transcribe_audio(file_path, params: TranscribeParams): async def transcribe_audio(file_path, params: TranscribeParams):
debug(f"Transcribing audio file from {file_path}...") debug(f"Transcribing audio file from {file_path}...")
file_path = await convert_to_wav(file_path) file_path = await convert_to_wav(file_path)
@ -136,8 +137,8 @@ async def transcribe_audio(file_path, params: TranscribeParams):
# Start the transcription process immediately # Start the transcription process immediately
transcription_task = asyncio.create_task(process_transcription(command, file_path, job_id)) transcription_task = asyncio.create_task(process_transcription(command, file_path, job_id))
max_wait_time = 300 # 5 minutes max_wait_time = 3600 # 1 hour
poll_interval = 1 # 1 second poll_interval = 10 # 10 seconds
start_time = asyncio.get_event_loop().time() start_time = asyncio.get_event_loop().time()
debug(f"Starting to poll for job {job_id}") debug(f"Starting to poll for job {job_id}")
@ -147,7 +148,7 @@ async def transcribe_audio(file_path, params: TranscribeParams):
debug(f"Current status for job {job_id}: {job_status['status']}") debug(f"Current status for job {job_id}: {job_status['status']}")
if job_status["status"] == "completed": if job_status["status"] == "completed":
info(f"Transcription completed for job {job_id}") info(f"Transcription completed for job {job_id}")
return job_status["result"] return job_id # This is the only change
elif job_status["status"] == "failed": elif job_status["status"] == "failed":
err(f"Transcription failed for job {job_id}: {job_status.get('error', 'Unknown error')}") err(f"Transcription failed for job {job_id}: {job_status.get('error', 'Unknown error')}")
raise Exception(f"Transcription failed: {job_status.get('error', 'Unknown error')}") raise Exception(f"Transcription failed: {job_status.get('error', 'Unknown error')}")
@ -162,6 +163,7 @@ async def transcribe_audio(file_path, params: TranscribeParams):
# This line should never be reached, but just in case: # This line should never be reached, but just in case:
raise Exception("Unexpected exit from transcription function") raise Exception("Unexpected exit from transcription function")
async def process_transcription(command, file_path, job_id): async def process_transcription(command, file_path, job_id):
try: try:
debug(f"Starting transcription process for job {job_id}") debug(f"Starting transcription process for job {job_id}")