Auto-update: Wed Jul 24 23:49:59 PDT 2024
This commit is contained in:
parent
fb0dd4ece8
commit
f4010c8a3f
2 changed files with 191 additions and 93 deletions
|
@ -125,7 +125,7 @@ class Configuration(BaseModel):
|
||||||
elif len(parts) == 2 and parts[0] == 'ENV':
|
elif len(parts) == 2 and parts[0] == 'ENV':
|
||||||
replacement = os.getenv(parts[1], '')
|
replacement = os.getenv(parts[1], '')
|
||||||
else:
|
else:
|
||||||
replacement = value # Keep original if not recognized
|
replacement = value
|
||||||
|
|
||||||
value = value.replace('{{' + match + '}}', str(replacement))
|
value = value.replace('{{' + match + '}}', str(replacement))
|
||||||
|
|
||||||
|
@ -154,6 +154,7 @@ class Configuration(BaseModel):
|
||||||
extra = "allow"
|
extra = "allow"
|
||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
|
||||||
class APIConfig(BaseModel):
|
class APIConfig(BaseModel):
|
||||||
HOST: str
|
HOST: str
|
||||||
PORT: int
|
PORT: int
|
||||||
|
@ -161,9 +162,9 @@ class APIConfig(BaseModel):
|
||||||
URL: str
|
URL: str
|
||||||
PUBLIC: List[str]
|
PUBLIC: List[str]
|
||||||
TRUSTED_SUBNETS: List[str]
|
TRUSTED_SUBNETS: List[str]
|
||||||
MODULES: Any # This will be replaced with a dynamic model
|
MODULES: Any
|
||||||
POOL: List[Dict[str, Any]]
|
POOL: List[Dict[str, Any]]
|
||||||
EXTENSIONS: Any # This will be replaced with a dynamic model
|
EXTENSIONS: Any
|
||||||
TZ: str
|
TZ: str
|
||||||
KEYS: List[str]
|
KEYS: List[str]
|
||||||
GARBAGE: Dict[str, Any]
|
GARBAGE: Dict[str, Any]
|
||||||
|
@ -173,11 +174,10 @@ class APIConfig(BaseModel):
|
||||||
config_path = cls._resolve_path(config_path, 'config')
|
config_path = cls._resolve_path(config_path, 'config')
|
||||||
secrets_path = cls._resolve_path(secrets_path, 'config')
|
secrets_path = cls._resolve_path(secrets_path, 'config')
|
||||||
|
|
||||||
# Load main configuration
|
|
||||||
with open(config_path, 'r') as file:
|
with open(config_path, 'r') as file:
|
||||||
config_data = yaml.safe_load(file)
|
config_data = yaml.safe_load(file)
|
||||||
|
|
||||||
print(f"Loaded main config: {config_data}")
|
print(f"Loaded main config: {config_data}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with open(secrets_path, 'r') as file:
|
with open(secrets_path, 'r') as file:
|
||||||
|
@ -191,7 +191,7 @@ class APIConfig(BaseModel):
|
||||||
secrets_data = {}
|
secrets_data = {}
|
||||||
|
|
||||||
config_data = cls.resolve_placeholders(config_data)
|
config_data = cls.resolve_placeholders(config_data)
|
||||||
print(f"Resolved config: {config_data}")
|
print(f"Resolved config: {config_data}")
|
||||||
if isinstance(config_data.get('KEYS'), list) and len(config_data['KEYS']) == 1:
|
if isinstance(config_data.get('KEYS'), list) and len(config_data['KEYS']) == 1:
|
||||||
placeholder = config_data['KEYS'][0]
|
placeholder = config_data['KEYS'][0]
|
||||||
if placeholder.startswith('{{') and placeholder.endswith('}}'):
|
if placeholder.startswith('{{') and placeholder.endswith('}}'):
|
||||||
|
@ -227,7 +227,7 @@ class APIConfig(BaseModel):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _resolve_path(cls, path: Union[str, Path], default_dir: str) -> Path:
|
def _resolve_path(cls, path: Union[str, Path], default_dir: str) -> Path:
|
||||||
base_path = Path(__file__).parent.parent # This will be two levels up from this file
|
base_path = Path(__file__).parent.parent
|
||||||
path = Path(path)
|
path = Path(path)
|
||||||
if not path.suffix:
|
if not path.suffix:
|
||||||
path = base_path / "sijapi" / default_dir / f"{path.name}.yaml"
|
path = base_path / "sijapi" / default_dir / f"{path.name}.yaml"
|
||||||
|
@ -255,7 +255,6 @@ class APIConfig(BaseModel):
|
||||||
else:
|
else:
|
||||||
resolved_data[key] = resolve_value(value)
|
resolved_data[key] = resolve_value(value)
|
||||||
|
|
||||||
# Resolve BIND separately to ensure HOST and PORT are used
|
|
||||||
if 'BIND' in resolved_data:
|
if 'BIND' in resolved_data:
|
||||||
resolved_data['BIND'] = resolved_data['BIND'].replace('{{ HOST }}', str(resolved_data['HOST']))
|
resolved_data['BIND'] = resolved_data['BIND'].replace('{{ HOST }}', str(resolved_data['HOST']))
|
||||||
resolved_data['BIND'] = resolved_data['BIND'].replace('{{ PORT }}', str(resolved_data['PORT']))
|
resolved_data['BIND'] = resolved_data['BIND'].replace('{{ PORT }}', str(resolved_data['PORT']))
|
||||||
|
@ -265,7 +264,9 @@ class APIConfig(BaseModel):
|
||||||
def __getattr__(self, name: str) -> Any:
|
def __getattr__(self, name: str) -> Any:
|
||||||
if name in ['MODULES', 'EXTENSIONS']:
|
if name in ['MODULES', 'EXTENSIONS']:
|
||||||
return self.__dict__[name]
|
return self.__dict__[name]
|
||||||
return super().__getattr__(name)
|
if name in self.__dict__:
|
||||||
|
return self.__dict__[name]
|
||||||
|
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def active_modules(self) -> List[str]:
|
def active_modules(self) -> List[str]:
|
||||||
|
@ -303,7 +304,6 @@ class APIConfig(BaseModel):
|
||||||
crit(f"Error: {str(e)}")
|
crit(f"Error: {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
async def push_changes(self, query: str, *args):
|
async def push_changes(self, query: str, *args):
|
||||||
connections = []
|
connections = []
|
||||||
try:
|
try:
|
||||||
|
@ -337,41 +337,72 @@ class APIConfig(BaseModel):
|
||||||
continue
|
continue
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def pull_changes(self, source_pool_entry: Dict[str, Any] = None):
|
|
||||||
if source_pool_entry is None:
|
async def pull_changes(self, source_pool_entry: Dict[str, Any] = None):
|
||||||
source_pool_entry = await self.get_default_source()
|
if source_pool_entry is None:
|
||||||
|
source_pool_entry = await self.get_default_source()
|
||||||
if source_pool_entry is None:
|
|
||||||
err("No available source for pulling changes")
|
if source_pool_entry is None:
|
||||||
return
|
err("No available source for pulling changes")
|
||||||
|
return
|
||||||
async with self.get_connection(source_pool_entry) as source_conn:
|
|
||||||
async with self.get_connection() as dest_conn:
|
async with self.get_connection(source_pool_entry) as source_conn:
|
||||||
# This is a simplistic approach. You might need a more sophisticated
|
async with self.get_connection() as dest_conn:
|
||||||
# method to determine what data needs to be synced.
|
tables = await source_conn.fetch(
|
||||||
tables = await source_conn.fetch(
|
"SELECT tablename FROM pg_tables WHERE schemaname = 'public'"
|
||||||
"SELECT tablename FROM pg_tables WHERE schemaname = 'public'"
|
)
|
||||||
)
|
for table in tables:
|
||||||
for table in tables:
|
table_name = table['tablename']
|
||||||
table_name = table['tablename']
|
info(f"Processing table: {table_name}")
|
||||||
await dest_conn.execute(f"TRUNCATE TABLE {table_name}")
|
|
||||||
rows = await source_conn.fetch(f"SELECT * FROM {table_name}")
|
# Get primary key column(s)
|
||||||
if rows:
|
pk_columns = await source_conn.fetch("""
|
||||||
columns = rows[0].keys()
|
SELECT a.attname
|
||||||
await dest_conn.copy_records_to_table(
|
FROM pg_index i
|
||||||
table_name, records=rows, columns=columns
|
JOIN pg_attribute a ON a.attrelid = i.indrelid
|
||||||
)
|
AND a.attnum = ANY(i.indkey)
|
||||||
info(f"Successfully pulled changes from {source_pool_entry['ts_ip']}")
|
WHERE i.indrelid = $1::regclass
|
||||||
|
AND i.indisprimary;
|
||||||
|
""", table_name)
|
||||||
|
|
||||||
|
pk_cols = [col['attname'] for col in pk_columns]
|
||||||
|
if not pk_cols:
|
||||||
|
warn(f"No primary key found for table {table_name}. Skipping.")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Fetch all rows from the source table
|
||||||
|
rows = await source_conn.fetch(f"SELECT * FROM {table_name}")
|
||||||
|
if rows:
|
||||||
|
columns = rows[0].keys()
|
||||||
|
# Upsert records to the destination table
|
||||||
|
await dest_conn.executemany(f"""
|
||||||
|
INSERT INTO {table_name} ({', '.join(columns)})
|
||||||
|
VALUES ({', '.join(f'${i+1}' for i in range(len(columns)))})
|
||||||
|
ON CONFLICT ({', '.join(pk_cols)}) DO UPDATE SET
|
||||||
|
{', '.join(f"{col} = EXCLUDED.{col}" for col in columns if col not in pk_cols)}
|
||||||
|
""", [tuple(row[col] for col in columns) for row in rows])
|
||||||
|
|
||||||
|
info(f"Completed processing table: {table_name}")
|
||||||
|
|
||||||
|
info(f"Successfully pulled changes from {source_pool_entry['ts_ip']}")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def sync_schema(self):
|
async def sync_schema(self):
|
||||||
source_entry = self.POOL[0] # Use the local database as the source
|
for pool_entry in self.POOL:
|
||||||
|
async with self.get_connection(pool_entry) as conn:
|
||||||
|
await conn.execute('CREATE EXTENSION IF NOT EXISTS postgis')
|
||||||
|
|
||||||
|
source_entry = self.local_db
|
||||||
source_schema = await self.get_schema(source_entry)
|
source_schema = await self.get_schema(source_entry)
|
||||||
|
|
||||||
for pool_entry in self.POOL[1:]:
|
for pool_entry in self.POOL[1:]:
|
||||||
target_schema = await self.get_schema(pool_entry)
|
try:
|
||||||
await self.apply_schema_changes(pool_entry, source_schema, target_schema)
|
target_schema = await self.get_schema(pool_entry)
|
||||||
info(f"Synced schema to {pool_entry['ts_ip']}")
|
await self.apply_schema_changes(pool_entry, source_schema, target_schema)
|
||||||
|
info(f"Synced schema to {pool_entry['ts_ip']}")
|
||||||
|
except Exception as e:
|
||||||
|
err(f"Failed to sync schema to {pool_entry['ts_ip']}: {str(e)}")
|
||||||
|
|
||||||
async def get_schema(self, pool_entry: Dict[str, Any]):
|
async def get_schema(self, pool_entry: Dict[str, Any]):
|
||||||
async with self.get_connection(pool_entry) as conn:
|
async with self.get_connection(pool_entry) as conn:
|
||||||
|
@ -402,61 +433,126 @@ class APIConfig(BaseModel):
|
||||||
'constraints': constraints
|
'constraints': constraints
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async def create_sequence_if_not_exists(self, conn, sequence_name):
|
||||||
|
await conn.execute(f"""
|
||||||
|
DO $$
|
||||||
|
BEGIN
|
||||||
|
IF NOT EXISTS (SELECT 1 FROM pg_sequences WHERE schemaname = 'public' AND sequencename = '{sequence_name}') THEN
|
||||||
|
CREATE SEQUENCE {sequence_name};
|
||||||
|
END IF;
|
||||||
|
END $$;
|
||||||
|
""")
|
||||||
|
|
||||||
async def apply_schema_changes(self, pool_entry: Dict[str, Any], source_schema, target_schema):
|
async def apply_schema_changes(self, pool_entry: Dict[str, Any], source_schema, target_schema):
|
||||||
async with self.get_connection(pool_entry) as conn:
|
async with self.get_connection(pool_entry) as conn:
|
||||||
source_tables = {t['table_name']: t for t in source_schema['tables']}
|
source_tables = {t['table_name']: t for t in source_schema['tables']}
|
||||||
target_tables = {t['table_name']: t for t in target_schema['tables']}
|
target_tables = {t['table_name']: t for t in target_schema['tables']}
|
||||||
|
|
||||||
for table_name, source_table in source_tables.items():
|
def get_column_type(data_type):
|
||||||
if table_name not in target_tables:
|
if data_type == 'ARRAY':
|
||||||
columns = [f"\"{t['column_name']}\" {t['data_type']}" +
|
return 'text[]' # or another appropriate type
|
||||||
(f"({t['character_maximum_length']})" if t['character_maximum_length'] else "") +
|
elif data_type == 'USER-DEFINED':
|
||||||
(" NOT NULL" if t['is_nullable'] == 'NO' else "") +
|
return 'geometry'
|
||||||
(f" DEFAULT {t['column_default']}" if t['column_default'] else "")
|
|
||||||
for t in source_schema['tables'] if t['table_name'] == table_name]
|
|
||||||
await conn.execute(f'CREATE TABLE "{table_name}" ({", ".join(columns)})')
|
|
||||||
else:
|
else:
|
||||||
target_table = target_tables[table_name]
|
return data_type
|
||||||
source_columns = {t['column_name']: t for t in source_schema['tables'] if t['table_name'] == table_name}
|
|
||||||
target_columns = {t['column_name']: t for t in target_schema['tables'] if t['table_name'] == table_name}
|
|
||||||
|
|
||||||
for col_name, source_col in source_columns.items():
|
for table_name, source_table in source_tables.items():
|
||||||
if col_name not in target_columns:
|
try:
|
||||||
col_def = f"\"{col_name}\" {source_col['data_type']}" + \
|
if table_name not in target_tables:
|
||||||
(f"({source_col['character_maximum_length']})" if source_col['character_maximum_length'] else "") + \
|
columns = []
|
||||||
(" NOT NULL" if source_col['is_nullable'] == 'NO' else "") + \
|
for t in source_schema['tables']:
|
||||||
(f" DEFAULT {source_col['column_default']}" if source_col['column_default'] else "")
|
if t['table_name'] == table_name:
|
||||||
await conn.execute(f'ALTER TABLE "{table_name}" ADD COLUMN {col_def}')
|
col_type = get_column_type(t['data_type'])
|
||||||
else:
|
col_def = f"\"{t['column_name']}\" {col_type}"
|
||||||
target_col = target_columns[col_name]
|
if t['character_maximum_length']:
|
||||||
if source_col != target_col:
|
col_def += f"({t['character_maximum_length']})"
|
||||||
await conn.execute(f'ALTER TABLE "{table_name}" ALTER COLUMN "{col_name}" TYPE {source_col["data_type"]}')
|
if t['is_nullable'] == 'NO':
|
||||||
if source_col['is_nullable'] != target_col['is_nullable']:
|
col_def += " NOT NULL"
|
||||||
null_constraint = "DROP NOT NULL" if source_col['is_nullable'] == 'YES' else "SET NOT NULL"
|
if t['column_default']:
|
||||||
await conn.execute(f'ALTER TABLE "{table_name}" ALTER COLUMN "{col_name}" {null_constraint}')
|
if 'nextval' in t['column_default']:
|
||||||
if source_col['column_default'] != target_col['column_default']:
|
sequence_name = t['column_default'].split("'")[1]
|
||||||
default_clause = f"SET DEFAULT {source_col['column_default']}" if source_col['column_default'] else "DROP DEFAULT"
|
await self.create_sequence_if_not_exists(conn, sequence_name)
|
||||||
await conn.execute(f'ALTER TABLE "{table_name}" ALTER COLUMN "{col_name}" {default_clause}')
|
col_def += f" DEFAULT {t['column_default']}"
|
||||||
|
columns.append(col_def)
|
||||||
|
|
||||||
|
sql = f'CREATE TABLE "{table_name}" ({", ".join(columns)})'
|
||||||
|
print(f"Executing SQL: {sql}")
|
||||||
|
await conn.execute(sql)
|
||||||
|
else:
|
||||||
|
target_table = target_tables[table_name]
|
||||||
|
source_columns = {t['column_name']: t for t in source_schema['tables'] if t['table_name'] == table_name}
|
||||||
|
target_columns = {t['column_name']: t for t in target_schema['tables'] if t['table_name'] == table_name}
|
||||||
|
|
||||||
source_indexes = {idx['indexname']: idx['indexdef'] for idx in source_schema['indexes']}
|
for col_name, source_col in source_columns.items():
|
||||||
target_indexes = {idx['indexname']: idx['indexdef'] for idx in target_schema['indexes']}
|
if col_name not in target_columns:
|
||||||
|
col_type = get_column_type(source_col['data_type'])
|
||||||
|
col_def = f"\"{col_name}\" {col_type}" + \
|
||||||
|
(f"({source_col['character_maximum_length']})" if source_col['character_maximum_length'] else "") + \
|
||||||
|
(" NOT NULL" if source_col['is_nullable'] == 'NO' else "") + \
|
||||||
|
(f" DEFAULT {source_col['column_default']}" if source_col['column_default'] else "")
|
||||||
|
sql = f'ALTER TABLE "{table_name}" ADD COLUMN {col_def}'
|
||||||
|
print(f"Executing SQL: {sql}")
|
||||||
|
await conn.execute(sql)
|
||||||
|
else:
|
||||||
|
target_col = target_columns[col_name]
|
||||||
|
if source_col != target_col:
|
||||||
|
col_type = get_column_type(source_col['data_type'])
|
||||||
|
sql = f'ALTER TABLE "{table_name}" ALTER COLUMN "{col_name}" TYPE {col_type}'
|
||||||
|
print(f"Executing SQL: {sql}")
|
||||||
|
await conn.execute(sql)
|
||||||
|
if source_col['is_nullable'] != target_col['is_nullable']:
|
||||||
|
null_constraint = "DROP NOT NULL" if source_col['is_nullable'] == 'YES' else "SET NOT NULL"
|
||||||
|
sql = f'ALTER TABLE "{table_name}" ALTER COLUMN "{col_name}" {null_constraint}'
|
||||||
|
print(f"Executing SQL: {sql}")
|
||||||
|
await conn.execute(sql)
|
||||||
|
if source_col['column_default'] != target_col['column_default']:
|
||||||
|
default_clause = f"SET DEFAULT {source_col['column_default']}" if source_col['column_default'] else "DROP DEFAULT"
|
||||||
|
sql = f'ALTER TABLE "{table_name}" ALTER COLUMN "{col_name}" {default_clause}'
|
||||||
|
print(f"Executing SQL: {sql}")
|
||||||
|
await conn.execute(sql)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error processing table {table_name}: {str(e)}")
|
||||||
|
# Optionally, you might want to raise this exception if you want to stop the entire process
|
||||||
|
# raise
|
||||||
|
|
||||||
for idx_name, idx_def in source_indexes.items():
|
try:
|
||||||
if idx_name not in target_indexes:
|
source_indexes = {idx['indexname']: idx['indexdef'] for idx in source_schema['indexes']}
|
||||||
await conn.execute(idx_def)
|
target_indexes = {idx['indexname']: idx['indexdef'] for idx in target_schema['indexes']}
|
||||||
elif idx_def != target_indexes[idx_name]:
|
|
||||||
await conn.execute(f'DROP INDEX "{idx_name}"')
|
|
||||||
await conn.execute(idx_def)
|
|
||||||
|
|
||||||
source_constraints = {con['conname']: con for con in source_schema['constraints']}
|
for idx_name, idx_def in source_indexes.items():
|
||||||
target_constraints = {con['conname']: con for con in target_schema['constraints']}
|
if idx_name not in target_indexes:
|
||||||
|
print(f"Executing SQL: {idx_def}")
|
||||||
|
await conn.execute(idx_def)
|
||||||
|
elif idx_def != target_indexes[idx_name]:
|
||||||
|
sql = f'DROP INDEX IF EXISTS "{idx_name}"'
|
||||||
|
print(f"Executing SQL: {sql}")
|
||||||
|
await conn.execute(sql)
|
||||||
|
print(f"Executing SQL: {idx_def}")
|
||||||
|
await conn.execute(idx_def)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error processing indexes: {str(e)}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
source_constraints = {con['conname']: con for con in source_schema['constraints']}
|
||||||
|
target_constraints = {con['conname']: con for con in target_schema['constraints']}
|
||||||
|
|
||||||
|
for con_name, source_con in source_constraints.items():
|
||||||
|
if con_name not in target_constraints:
|
||||||
|
sql = f'ALTER TABLE "{source_con["table_name"]}" ADD CONSTRAINT "{con_name}" {source_con["definition"]}'
|
||||||
|
print(f"Executing SQL: {sql}")
|
||||||
|
await conn.execute(sql)
|
||||||
|
elif source_con != target_constraints[con_name]:
|
||||||
|
sql = f'ALTER TABLE "{source_con["table_name"]}" DROP CONSTRAINT IF EXISTS "{con_name}"'
|
||||||
|
print(f"Executing SQL: {sql}")
|
||||||
|
await conn.execute(sql)
|
||||||
|
sql = f'ALTER TABLE "{source_con["table_name"]}" ADD CONSTRAINT "{con_name}" {source_con["definition"]}'
|
||||||
|
print(f"Executing SQL: {sql}")
|
||||||
|
await conn.execute(sql)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error processing constraints: {str(e)}")
|
||||||
|
|
||||||
|
print(f"Schema synchronization completed for {pool_entry['ts_ip']}")
|
||||||
|
|
||||||
for con_name, source_con in source_constraints.items():
|
|
||||||
if con_name not in target_constraints:
|
|
||||||
await conn.execute(f'ALTER TABLE "{source_con["table_name"]}" ADD CONSTRAINT "{con_name}" {source_con["definition"]}')
|
|
||||||
elif source_con != target_constraints[con_name]:
|
|
||||||
await conn.execute(f'ALTER TABLE "{source_con["table_name"]}" DROP CONSTRAINT "{con_name}"')
|
|
||||||
await conn.execute(f'ALTER TABLE "{source_con["table_name"]}" ADD CONSTRAINT "{con_name}" {source_con["definition"]}')
|
|
||||||
|
|
||||||
class Location(BaseModel):
|
class Location(BaseModel):
|
||||||
latitude: float
|
latitude: float
|
||||||
|
|
|
@ -24,6 +24,9 @@ def warn(text: str): logger.warning(text)
|
||||||
def err(text: str): logger.error(text)
|
def err(text: str): logger.error(text)
|
||||||
def crit(text: str): logger.critical(text)
|
def crit(text: str): logger.critical(text)
|
||||||
|
|
||||||
|
# Global dictionary to store transcription results
|
||||||
|
transcription_results = {}
|
||||||
|
|
||||||
class TranscribeParams(BaseModel):
|
class TranscribeParams(BaseModel):
|
||||||
model: str = Field(default="small")
|
model: str = Field(default="small")
|
||||||
output_srt: Optional[bool] = Field(default=False)
|
output_srt: Optional[bool] = Field(default=False)
|
||||||
|
@ -40,8 +43,6 @@ class TranscribeParams(BaseModel):
|
||||||
dtw: Optional[str] = Field(None)
|
dtw: Optional[str] = Field(None)
|
||||||
threads: Optional[int] = Field(None)
|
threads: Optional[int] = Field(None)
|
||||||
|
|
||||||
# Global dictionary to store transcription results
|
|
||||||
transcription_results = {}
|
|
||||||
|
|
||||||
@asr.post("/asr")
|
@asr.post("/asr")
|
||||||
@asr.post("/transcribe")
|
@asr.post("/transcribe")
|
||||||
|
@ -63,12 +64,11 @@ async def transcribe_endpoint(
|
||||||
temp_file.write(await file.read())
|
temp_file.write(await file.read())
|
||||||
temp_file_path = temp_file.name
|
temp_file_path = temp_file.name
|
||||||
|
|
||||||
transcription_job = await transcribe_audio(file_path=temp_file_path, params=parameters)
|
job_id = await transcribe_audio(file_path=temp_file_path, params=parameters)
|
||||||
job_id = transcription_job["job_id"]
|
|
||||||
|
|
||||||
# Poll for completion
|
# Poll for completion
|
||||||
max_wait_time = 600 # 10 minutes
|
max_wait_time = 3600 # 60 minutes
|
||||||
poll_interval = 2 # 2 seconds
|
poll_interval = 10 # 2 seconds
|
||||||
elapsed_time = 0
|
elapsed_time = 0
|
||||||
|
|
||||||
while elapsed_time < max_wait_time:
|
while elapsed_time < max_wait_time:
|
||||||
|
@ -85,6 +85,7 @@ async def transcribe_endpoint(
|
||||||
# If we've reached this point, the transcription has taken too long
|
# If we've reached this point, the transcription has taken too long
|
||||||
return JSONResponse(content={"status": "timeout", "message": "Transcription is taking longer than expected. Please check back later."}, status_code=202)
|
return JSONResponse(content={"status": "timeout", "message": "Transcription is taking longer than expected. Please check back later."}, status_code=202)
|
||||||
|
|
||||||
|
|
||||||
async def transcribe_audio(file_path, params: TranscribeParams):
|
async def transcribe_audio(file_path, params: TranscribeParams):
|
||||||
debug(f"Transcribing audio file from {file_path}...")
|
debug(f"Transcribing audio file from {file_path}...")
|
||||||
file_path = await convert_to_wav(file_path)
|
file_path = await convert_to_wav(file_path)
|
||||||
|
@ -136,8 +137,8 @@ async def transcribe_audio(file_path, params: TranscribeParams):
|
||||||
# Start the transcription process immediately
|
# Start the transcription process immediately
|
||||||
transcription_task = asyncio.create_task(process_transcription(command, file_path, job_id))
|
transcription_task = asyncio.create_task(process_transcription(command, file_path, job_id))
|
||||||
|
|
||||||
max_wait_time = 300 # 5 minutes
|
max_wait_time = 3600 # 1 hour
|
||||||
poll_interval = 1 # 1 second
|
poll_interval = 10 # 10 seconds
|
||||||
start_time = asyncio.get_event_loop().time()
|
start_time = asyncio.get_event_loop().time()
|
||||||
|
|
||||||
debug(f"Starting to poll for job {job_id}")
|
debug(f"Starting to poll for job {job_id}")
|
||||||
|
@ -147,7 +148,7 @@ async def transcribe_audio(file_path, params: TranscribeParams):
|
||||||
debug(f"Current status for job {job_id}: {job_status['status']}")
|
debug(f"Current status for job {job_id}: {job_status['status']}")
|
||||||
if job_status["status"] == "completed":
|
if job_status["status"] == "completed":
|
||||||
info(f"Transcription completed for job {job_id}")
|
info(f"Transcription completed for job {job_id}")
|
||||||
return job_status["result"]
|
return job_id # This is the only change
|
||||||
elif job_status["status"] == "failed":
|
elif job_status["status"] == "failed":
|
||||||
err(f"Transcription failed for job {job_id}: {job_status.get('error', 'Unknown error')}")
|
err(f"Transcription failed for job {job_id}: {job_status.get('error', 'Unknown error')}")
|
||||||
raise Exception(f"Transcription failed: {job_status.get('error', 'Unknown error')}")
|
raise Exception(f"Transcription failed: {job_status.get('error', 'Unknown error')}")
|
||||||
|
@ -162,6 +163,7 @@ async def transcribe_audio(file_path, params: TranscribeParams):
|
||||||
# This line should never be reached, but just in case:
|
# This line should never be reached, but just in case:
|
||||||
raise Exception("Unexpected exit from transcription function")
|
raise Exception("Unexpected exit from transcription function")
|
||||||
|
|
||||||
|
|
||||||
async def process_transcription(command, file_path, job_id):
|
async def process_transcription(command, file_path, job_id):
|
||||||
try:
|
try:
|
||||||
debug(f"Starting transcription process for job {job_id}")
|
debug(f"Starting transcription process for job {job_id}")
|
||||||
|
|
Loading…
Reference in a new issue