From e7c299870ed9f5ebdfe251c8b97a0465159a31a6 Mon Sep 17 00:00:00 2001
From: sanj <67624670+iodrift@users.noreply.github.com>
Date: Wed, 24 Jul 2024 23:49:59 -0700
Subject: [PATCH] Auto-update: Wed Jul 24 23:49:59 PDT 2024

---
 sijapi/classes.py     | 264 ++++++++++++++++++++++++++++--------------
 sijapi/routers/asr.py |  20 ++--
 2 files changed, 191 insertions(+), 93 deletions(-)

diff --git a/sijapi/classes.py b/sijapi/classes.py
index 525eb42..4629570 100644
--- a/sijapi/classes.py
+++ b/sijapi/classes.py
@@ -125,7 +125,7 @@ class Configuration(BaseModel):
             elif len(parts) == 2 and parts[0] == 'ENV':
                 replacement = os.getenv(parts[1], '')
             else:
-                replacement = value  # Keep original if not recognized
+                replacement = value
 
             value = value.replace('{{' + match + '}}', str(replacement))
 
@@ -154,6 +154,7 @@ class Configuration(BaseModel):
         extra = "allow"
         arbitrary_types_allowed = True
 
+
 class APIConfig(BaseModel):
     HOST: str
     PORT: int
@@ -161,9 +162,9 @@ class APIConfig(BaseModel):
     URL: str
     PUBLIC: List[str]
     TRUSTED_SUBNETS: List[str]
-    MODULES: Any  # This will be replaced with a dynamic model
+    MODULES: Any
     POOL: List[Dict[str, Any]]
-    EXTENSIONS: Any  # This will be replaced with a dynamic model
+    EXTENSIONS: Any
     TZ: str
     KEYS: List[str]
     GARBAGE: Dict[str, Any]
@@ -173,11 +174,10 @@ class APIConfig(BaseModel):
         config_path = cls._resolve_path(config_path, 'config')
         secrets_path = cls._resolve_path(secrets_path, 'config')
 
-        # Load main configuration
         with open(config_path, 'r') as file:
             config_data = yaml.safe_load(file)
 
-        print(f"Loaded main config: {config_data}") 
+        print(f"Loaded main config: {config_data}")
 
         try:
             with open(secrets_path, 'r') as file:
@@ -191,7 +191,7 @@ class APIConfig(BaseModel):
             secrets_data = {}
 
         config_data = cls.resolve_placeholders(config_data)
-        print(f"Resolved config: {config_data}") 
+        print(f"Resolved config: {config_data}")
         if isinstance(config_data.get('KEYS'), list) and len(config_data['KEYS']) == 1:
             placeholder = config_data['KEYS'][0]
             if placeholder.startswith('{{') and placeholder.endswith('}}'):
@@ -227,7 +227,7 @@ class APIConfig(BaseModel):
 
     @classmethod
     def _resolve_path(cls, path: Union[str, Path], default_dir: str) -> Path:
-        base_path = Path(__file__).parent.parent  # This will be two levels up from this file
+        base_path = Path(__file__).parent.parent
         path = Path(path)
         if not path.suffix:
             path = base_path / "sijapi" / default_dir / f"{path.name}.yaml"
@@ -255,7 +255,6 @@ class APIConfig(BaseModel):
             else:
                 resolved_data[key] = resolve_value(value)
 
-        # Resolve BIND separately to ensure HOST and PORT are used
         if 'BIND' in resolved_data:
             resolved_data['BIND'] = resolved_data['BIND'].replace('{{ HOST }}', str(resolved_data['HOST']))
             resolved_data['BIND'] = resolved_data['BIND'].replace('{{ PORT }}', str(resolved_data['PORT']))
@@ -265,7 +264,9 @@ class APIConfig(BaseModel):
     def __getattr__(self, name: str) -> Any:
         if name in ['MODULES', 'EXTENSIONS']:
             return self.__dict__[name]
-        return super().__getattr__(name)
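+        # Fall back to the instance __dict__ for any other dynamically added attribute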
+        if name in self.__dict__:
+            return self.__dict__[name]
+        raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
 
     @property
     def active_modules(self) -> List[str]:
@@ -303,7 +304,6 @@ class APIConfig(BaseModel):
             crit(f"Error: {str(e)}")
             raise
 
-
     async def push_changes(self, query: str, *args):
         connections = []
         try:
@@ -337,41 +337,72 @@ class APIConfig(BaseModel):
                     continue
         return None
 
-    async def pull_changes(self, source_pool_entry: Dict[str, Any] = None):
-        if source_pool_entry is None:
-            source_pool_entry = await self.get_default_source()
-        
-        if source_pool_entry is None:
-            err("No available source for pulling changes")
-            return
-        
-        async with self.get_connection(source_pool_entry) as source_conn:
-            async with self.get_connection() as dest_conn:
-                # This is a simplistic approach. You might need a more sophisticated
-                # method to determine what data needs to be synced.
-                tables = await source_conn.fetch(
-                    "SELECT tablename FROM pg_tables WHERE schemaname = 'public'"
-                )
-                for table in tables:
-                    table_name = table['tablename']
-                    await dest_conn.execute(f"TRUNCATE TABLE {table_name}")
-                    rows = await source_conn.fetch(f"SELECT * FROM {table_name}")
-                    if rows:
-                        columns = rows[0].keys()
-                        await dest_conn.copy_records_to_table(
-                            table_name, records=rows, columns=columns
-                        )
-                info(f"Successfully pulled changes from {source_pool_entry['ts_ip']}")
+
+    async def pull_changes(self, source_pool_entry: Dict[str, Any] = None):
+        if source_pool_entry is None:
+            source_pool_entry = await self.get_default_source()
+
+        if source_pool_entry is None:
+            err("No available source for pulling changes")
+            return
+
+        async with self.get_connection(source_pool_entry) as source_conn:
+            async with self.get_connection() as dest_conn:
+                tables = await source_conn.fetch(
+                    "SELECT tablename FROM pg_tables WHERE schemaname = 'public'"
+                )
+                for table in tables:
+                    table_name = table['tablename']
+                    info(f"Processing table: {table_name}")
+
+                    # Get primary key column(s)
+                    pk_columns = await source_conn.fetch("""
+                        SELECT a.attname
+                        FROM   pg_index i
+                        JOIN   pg_attribute a ON a.attrelid = i.indrelid
+                                            AND a.attnum = ANY(i.indkey)
+                        WHERE  i.indrelid = $1::regclass
+                        AND    i.indisprimary;
+                    """, table_name)
+
+                    pk_cols = [col['attname'] for col in pk_columns]
+                    if not pk_cols:
+                        warn(f"No primary key found for table {table_name}. Skipping.")
+                        continue
+
+                    # Fetch all rows from the source table
+                    rows = await source_conn.fetch(f"SELECT * FROM {table_name}")
+                    if rows:
+                        columns = rows[0].keys()
+                        # Upsert records to the destination table
+                        await dest_conn.executemany(f"""
+                            INSERT INTO {table_name} ({', '.join(columns)})
+                            VALUES ({', '.join(f'${i+1}' for i in range(len(columns)))})
+                            ON CONFLICT ({', '.join(pk_cols)}) DO UPDATE SET
+                            {', '.join(f"{col} = EXCLUDED.{col}" for col in columns if col not in pk_cols)}
+                        """, [tuple(row[col] for col in columns) for row in rows])
+
+                    info(f"Completed processing table: {table_name}")
+
+                info(f"Successfully pulled changes from {source_pool_entry['ts_ip']}")
 
     async def sync_schema(self):
-        source_entry = self.POOL[0]  # Use the local database as the source
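+        # Make sure PostGIS is installed on every pool member before comparing schemas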
+        for pool_entry in self.POOL:
+            async with self.get_connection(pool_entry) as conn:
+                await conn.execute('CREATE EXTENSION IF NOT EXISTS postgis')
+
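+        # The local database is used as the canonical schema source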
+        source_entry = self.local_db
         source_schema = await self.get_schema(source_entry)
         
         for pool_entry in self.POOL[1:]:
-            target_schema = await self.get_schema(pool_entry)
-            await self.apply_schema_changes(pool_entry, source_schema, target_schema)
-            info(f"Synced schema to {pool_entry['ts_ip']}")
-
+            try:
+                target_schema = await self.get_schema(pool_entry)
+                await self.apply_schema_changes(pool_entry, source_schema, target_schema)
+                info(f"Synced schema to {pool_entry['ts_ip']}")
+            except Exception as e:
+                err(f"Failed to sync schema to {pool_entry['ts_ip']}: {str(e)}")
 
     async def get_schema(self, pool_entry: Dict[str, Any]):
         async with self.get_connection(pool_entry) as conn:
@@ -402,61 +433,126 @@ class APIConfig(BaseModel):
                 'constraints': constraints
             }
 
+    async def create_sequence_if_not_exists(self, conn, sequence_name):
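+        # Create the sequence on the target if it is missing, so nextval() column defaults can be applied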
+        await conn.execute(f"""
+        DO $$
+        BEGIN
+            IF NOT EXISTS (SELECT 1 FROM pg_sequences WHERE schemaname = 'public' AND sequencename = '{sequence_name}') THEN
+                CREATE SEQUENCE {sequence_name};
+            END IF;
+        END $$;
+        """)
+
     async def apply_schema_changes(self, pool_entry: Dict[str, Any], source_schema, target_schema):
         async with self.get_connection(pool_entry) as conn:
             source_tables = {t['table_name']: t for t in source_schema['tables']}
             target_tables = {t['table_name']: t for t in target_schema['tables']}
 
-            for table_name, source_table in source_tables.items():
-                if table_name not in target_tables:
-                    columns = [f"\"{t['column_name']}\" {t['data_type']}" +
-                            (f"({t['character_maximum_length']})" if t['character_maximum_length'] else "") +
-                            (" NOT NULL" if t['is_nullable'] == 'NO' else "") +
-                            (f" DEFAULT {t['column_default']}" if t['column_default'] else "")
-                            for t in source_schema['tables'] if t['table_name'] == table_name]
-                    await conn.execute(f'CREATE TABLE "{table_name}" ({", ".join(columns)})')
+            def get_column_type(data_type):
+                # Map information_schema's generic type names to concrete SQL types
+                if data_type == 'ARRAY':
+                    return 'text[]'  # element type is not tracked here, so fall back to text[]
+                elif data_type == 'USER-DEFINED':
+                    return 'geometry'  # USER-DEFINED columns are assumed to be PostGIS geometries
                 else:
-                    target_table = target_tables[table_name]
-                    source_columns = {t['column_name']: t for t in source_schema['tables'] if t['table_name'] == table_name}
-                    target_columns = {t['column_name']: t for t in target_schema['tables'] if t['table_name'] == table_name}
+                    return data_type
 
-                    for col_name, source_col in source_columns.items():
-                        if col_name not in target_columns:
-                            col_def = f"\"{col_name}\" {source_col['data_type']}" + \
-                                    (f"({source_col['character_maximum_length']})" if source_col['character_maximum_length'] else "") + \
-                                    (" NOT NULL" if source_col['is_nullable'] == 'NO' else "") + \
-                                    (f" DEFAULT {source_col['column_default']}" if source_col['column_default'] else "")
-                            await conn.execute(f'ALTER TABLE "{table_name}" ADD COLUMN {col_def}')
-                        else:
-                            target_col = target_columns[col_name]
-                            if source_col != target_col:
-                                await conn.execute(f'ALTER TABLE "{table_name}" ALTER COLUMN "{col_name}" TYPE {source_col["data_type"]}')
-                                if source_col['is_nullable'] != target_col['is_nullable']:
-                                    null_constraint = "DROP NOT NULL" if source_col['is_nullable'] == 'YES' else "SET NOT NULL"
-                                    await conn.execute(f'ALTER TABLE "{table_name}" ALTER COLUMN "{col_name}" {null_constraint}')
-                                if source_col['column_default'] != target_col['column_default']:
-                                    default_clause = f"SET DEFAULT {source_col['column_default']}" if source_col['column_default'] else "DROP DEFAULT"
-                                    await conn.execute(f'ALTER TABLE "{table_name}" ALTER COLUMN "{col_name}" {default_clause}')
+            for table_name, source_table in source_tables.items():
+                try:
+                    if table_name not in target_tables:
+                        columns = []
+                        for t in source_schema['tables']:
+                            if t['table_name'] == table_name:
+                                col_type = get_column_type(t['data_type'])
+                                col_def = f"\"{t['column_name']}\" {col_type}"
+                                if t['character_maximum_length']:
+                                    col_def += f"({t['character_maximum_length']})"
+                                if t['is_nullable'] == 'NO':
+                                    col_def += " NOT NULL"
+                                if t['column_default']:
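+                                    # Serial-style defaults reference a sequence; make sure it exists on the target first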
+                                    if 'nextval' in t['column_default']:
+                                        sequence_name = t['column_default'].split("'")[1]
+                                        await self.create_sequence_if_not_exists(conn, sequence_name)
+                                    col_def += f" DEFAULT {t['column_default']}"
+                                columns.append(col_def)
+                        
+                        sql = f'CREATE TABLE "{table_name}" ({", ".join(columns)})'
+                        print(f"Executing SQL: {sql}")
+                        await conn.execute(sql)
+                    else:
+                        target_table = target_tables[table_name]
+                        source_columns = {t['column_name']: t for t in source_schema['tables'] if t['table_name'] == table_name}
+                        target_columns = {t['column_name']: t for t in target_schema['tables'] if t['table_name'] == table_name}
 
-            source_indexes = {idx['indexname']: idx['indexdef'] for idx in source_schema['indexes']}
-            target_indexes = {idx['indexname']: idx['indexdef'] for idx in target_schema['indexes']}
+                        for col_name, source_col in source_columns.items():
+                            if col_name not in target_columns:
+                                col_type = get_column_type(source_col['data_type'])
+                                col_def = f"\"{col_name}\" {col_type}" + \
+                                        (f"({source_col['character_maximum_length']})" if source_col['character_maximum_length'] else "") + \
+                                        (" NOT NULL" if source_col['is_nullable'] == 'NO' else "") + \
+                                        (f" DEFAULT {source_col['column_default']}" if source_col['column_default'] else "")
+                                sql = f'ALTER TABLE "{table_name}" ADD COLUMN {col_def}'
+                                print(f"Executing SQL: {sql}")
+                                await conn.execute(sql)
+                            else:
+                                target_col = target_columns[col_name]
+                                if source_col != target_col:
+                                    col_type = get_column_type(source_col['data_type'])
+                                    sql = f'ALTER TABLE "{table_name}" ALTER COLUMN "{col_name}" TYPE {col_type}'
+                                    print(f"Executing SQL: {sql}")
+                                    await conn.execute(sql)
+                                    if source_col['is_nullable'] != target_col['is_nullable']:
+                                        null_constraint = "DROP NOT NULL" if source_col['is_nullable'] == 'YES' else "SET NOT NULL"
+                                        sql = f'ALTER TABLE "{table_name}" ALTER COLUMN "{col_name}" {null_constraint}'
+                                        print(f"Executing SQL: {sql}")
+                                        await conn.execute(sql)
+                                    if source_col['column_default'] != target_col['column_default']:
+                                        default_clause = f"SET DEFAULT {source_col['column_default']}" if source_col['column_default'] else "DROP DEFAULT"
+                                        sql = f'ALTER TABLE "{table_name}" ALTER COLUMN "{col_name}" {default_clause}'
+                                        print(f"Executing SQL: {sql}")
+                                        await conn.execute(sql)
+                except Exception as e:
+                    print(f"Error processing table {table_name}: {str(e)}")
+                    # Re-raise here instead if a failure in one table should abort the whole sync
+                    # raise
 
-            for idx_name, idx_def in source_indexes.items():
-                if idx_name not in target_indexes:
-                    await conn.execute(idx_def)
-                elif idx_def != target_indexes[idx_name]:
-                    await conn.execute(f'DROP INDEX "{idx_name}"')
-                    await conn.execute(idx_def)
+            try:
+                source_indexes = {idx['indexname']: idx['indexdef'] for idx in source_schema['indexes']}
+                target_indexes = {idx['indexname']: idx['indexdef'] for idx in target_schema['indexes']}
 
-            source_constraints = {con['conname']: con for con in source_schema['constraints']}
-            target_constraints = {con['conname']: con for con in target_schema['constraints']}
+                for idx_name, idx_def in source_indexes.items():
+                    if idx_name not in target_indexes:
+                        print(f"Executing SQL: {idx_def}")
+                        await conn.execute(idx_def)
+                    elif idx_def != target_indexes[idx_name]:
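+                        # Definition drifted from the source: drop and recreate the index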
+                        sql = f'DROP INDEX IF EXISTS "{idx_name}"'
+                        print(f"Executing SQL: {sql}")
+                        await conn.execute(sql)
+                        print(f"Executing SQL: {idx_def}")
+                        await conn.execute(idx_def)
+            except Exception as e:
+                print(f"Error processing indexes: {str(e)}")
+
+            try:
+                source_constraints = {con['conname']: con for con in source_schema['constraints']}
+                target_constraints = {con['conname']: con for con in target_schema['constraints']}
+
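+                # Add missing constraints; recreate any whose definition changed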
+                for con_name, source_con in source_constraints.items():
+                    if con_name not in target_constraints:
+                        sql = f'ALTER TABLE "{source_con["table_name"]}" ADD CONSTRAINT "{con_name}" {source_con["definition"]}'
+                        print(f"Executing SQL: {sql}")
+                        await conn.execute(sql)
+                    elif source_con != target_constraints[con_name]:
+                        sql = f'ALTER TABLE "{source_con["table_name"]}" DROP CONSTRAINT IF EXISTS "{con_name}"'
+                        print(f"Executing SQL: {sql}")
+                        await conn.execute(sql)
+                        sql = f'ALTER TABLE "{source_con["table_name"]}" ADD CONSTRAINT "{con_name}" {source_con["definition"]}'
+                        print(f"Executing SQL: {sql}")
+                        await conn.execute(sql)
+            except Exception as e:
+                print(f"Error processing constraints: {str(e)}")
+
+        print(f"Schema synchronization completed for {pool_entry['ts_ip']}")
 
-            for con_name, source_con in source_constraints.items():
-                if con_name not in target_constraints:
-                    await conn.execute(f'ALTER TABLE "{source_con["table_name"]}" ADD CONSTRAINT "{con_name}" {source_con["definition"]}')
-                elif source_con != target_constraints[con_name]:
-                    await conn.execute(f'ALTER TABLE "{source_con["table_name"]}" DROP CONSTRAINT "{con_name}"')
-                    await conn.execute(f'ALTER TABLE "{source_con["table_name"]}" ADD CONSTRAINT "{con_name}" {source_con["definition"]}')
 
 class Location(BaseModel):
     latitude: float
diff --git a/sijapi/routers/asr.py b/sijapi/routers/asr.py
index 1981825..7278537 100644
--- a/sijapi/routers/asr.py
+++ b/sijapi/routers/asr.py
@@ -24,6 +24,9 @@ def warn(text: str): logger.warning(text)
 def err(text: str): logger.error(text)
 def crit(text: str): logger.critical(text)
 
+# Global dictionary to store transcription results
+transcription_results = {}
+
 class TranscribeParams(BaseModel):
     model: str = Field(default="small")
     output_srt: Optional[bool] = Field(default=False)
@@ -40,8 +43,6 @@ class TranscribeParams(BaseModel):
     dtw: Optional[str] = Field(None)
     threads: Optional[int] = Field(None)
 
-# Global dictionary to store transcription results
-transcription_results = {}
 
 @asr.post("/asr")
 @asr.post("/transcribe")
@@ -63,12 +64,11 @@ async def transcribe_endpoint(
         temp_file.write(await file.read())
         temp_file_path = temp_file.name
     
-    transcription_job = await transcribe_audio(file_path=temp_file_path, params=parameters)
-    job_id = transcription_job["job_id"]
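+    # transcribe_audio now returns just the job ID once the transcription finishes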
+    job_id = await transcribe_audio(file_path=temp_file_path, params=parameters)
 
     # Poll for completion
-    max_wait_time = 600  # 10 minutes
-    poll_interval = 2  # 2 seconds
+    max_wait_time = 3600  # 60 minutes
+    poll_interval = 10  # 10 seconds
     elapsed_time = 0
 
     while elapsed_time < max_wait_time:
@@ -85,6 +85,7 @@ async def transcribe_endpoint(
     # If we've reached this point, the transcription has taken too long
     return JSONResponse(content={"status": "timeout", "message": "Transcription is taking longer than expected. Please check back later."}, status_code=202)
 
+
 async def transcribe_audio(file_path, params: TranscribeParams):
     debug(f"Transcribing audio file from {file_path}...")
     file_path = await convert_to_wav(file_path)
@@ -136,8 +137,8 @@ async def transcribe_audio(file_path, params: TranscribeParams):
     # Start the transcription process immediately
     transcription_task = asyncio.create_task(process_transcription(command, file_path, job_id))
 
-    max_wait_time = 300  # 5 minutes
-    poll_interval = 1  # 1 second
+    max_wait_time = 3600  # 1 hour
+    poll_interval = 10  # 10 seconds
     start_time = asyncio.get_event_loop().time()
 
     debug(f"Starting to poll for job {job_id}")
@@ -147,7 +148,7 @@ async def transcribe_audio(file_path, params: TranscribeParams):
             debug(f"Current status for job {job_id}: {job_status['status']}")
             if job_status["status"] == "completed":
                 info(f"Transcription completed for job {job_id}")
-                return job_status["result"]
+                return job_id  # return the job ID rather than the raw result; callers poll for the result
             elif job_status["status"] == "failed":
                 err(f"Transcription failed for job {job_id}: {job_status.get('error', 'Unknown error')}")
                 raise Exception(f"Transcription failed: {job_status.get('error', 'Unknown error')}")
@@ -162,6 +163,7 @@ async def transcribe_audio(file_path, params: TranscribeParams):
     # This line should never be reached, but just in case:
     raise Exception("Unexpected exit from transcription function")
 
+
 async def process_transcription(command, file_path, job_id):
     try:
         debug(f"Starting transcription process for job {job_id}")