147 lines
4.5 KiB
Python
147 lines
4.5 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import os
|
|
import sys
|
|
from typing import Any, Dict
|
|
|
|
import aio_pika
|
|
from aio_pika.exceptions import AMQPConnectionError
|
|
import psycopg
|
|
from psycopg.rows import dict_row
|
|
|
|
from services.anilist_importer import AniListImporter
|
|
|
|
# PostgreSQL connection string; the project-specific variable takes priority
# over the conventional DATABASE_URL. Stays None/empty when neither is set.
PG_DSN = os.environ.get("NYANIMEDB_PG_DSN") or os.environ.get("DATABASE_URL")

# Used only when neither RabbitMQ variable below is set (or both are empty):
# the guest account on host "rabbitmq", port 5672.
_DEFAULT_RMQ_URL = "amqp://guest:guest@rabbitmq:5672/"

# AMQP URL of the broker, first non-empty of the two variables, else the default.
RMQ_URL = (
    os.environ.get("NYANIMEDB_RMQ_URL")
    or os.environ.get("RABBITMQ_URL")
    or _DEFAULT_RMQ_URL
)

# Queue this worker consumes RPC import requests from.
RPC_QUEUE_NAME = os.environ.get("NYANIMEDB_IMPORT_RPC_QUEUE", "anime_import_rpc")
|
|
|
|
|
|
def rmq_request_to_filters(payload: Dict[str, Any]) -> Dict[str, Any]:
    """Translate a decoded RPC request into importer filters.

    Recognised request keys (anything else is ignored):

    * ``name``   -- non-blank string; stripped and emitted as ``query``.
    * ``year``   -- positive integer; emitted as ``year``. Booleans are
      rejected even though ``bool`` subclasses ``int`` (JSON ``true`` would
      otherwise become year 1).
    * ``season`` -- non-blank string; stripped, lower-cased and emitted as
      ``season``.

    ``limit`` is always set to 10.

    Args:
        payload: the JSON object decoded from the message body.

    Returns:
        A new dict of filters; the handler passes it to
        ``AniListImporter.import_by_filters_in_tx``.
    """
    filters: Dict[str, Any] = {}

    name = payload.get("name")
    if isinstance(name, str) and name.strip():
        filters["query"] = name.strip()

    year = payload.get("year")
    # bool is a subclass of int, so exclude it explicitly.
    if isinstance(year, int) and not isinstance(year, bool) and year > 0:
        filters["year"] = year

    season = payload.get("season")
    if isinstance(season, str) and season.strip():
        # Strip like ``name`` so " Winter " and "winter" are treated the same.
        filters["season"] = season.strip().lower()

    filters["limit"] = 10
    return filters
|
|
|
|
|
|
def create_handler(channel: aio_pika.Channel):
    """Build the RPC consumer callback bound to *channel*.

    The returned coroutine decodes a JSON request and derives importer
    filters with ``rmq_request_to_filters``. It then runs
    ``AniListImporter.import_by_filters_in_tx`` on a fresh PostgreSQL
    connection. When the request carries ``reply_to``, it publishes a JSON
    response on the default exchange, echoing the request's
    ``correlation_id``.

    Response shape::

        {"timestamp": <request "timestamp" or None>, "ok": bool,
         "titles": list, "error": None | {"code": str, "message": str}}

    Error codes:
        ``invalid_request``: the body is not UTF-8 JSON, or not a JSON object.
        ``import_failed``: connecting to PostgreSQL or the import raised.
        ``serialization_failed``: the import result could not be JSON-encoded.
    """

    def _error_response(timestamp: Any, code: str, text: str) -> dict[str, Any]:
        """Build a failed response envelope carrying *code* and *text*."""
        return {
            "timestamp": timestamp,
            "ok": False,
            "titles": [],
            "error": {"code": code, "message": text},
        }

    def _encode(response: dict[str, Any]) -> bytes:
        """Serialize *response*, falling back to an error envelope if it is not JSON-encodable."""
        try:
            return json.dumps(response).encode("utf-8")
        except (TypeError, ValueError) as exc:
            # E.g. the importer returned values json cannot encode. Without this
            # fallback the exception would escape and the caller would get no reply.
            print(f"[worker] Failed to serialize response: {exc!r}", file=sys.stderr, flush=True)
            fallback = _error_response(response.get("timestamp"), "serialization_failed", str(exc))
            return json.dumps(fallback).encode("utf-8")

    async def _reply(message: aio_pika.IncomingMessage, response: dict[str, Any]) -> None:
        """Publish *response* to the request's ``reply_to`` queue, if one was given."""
        if not message.reply_to:
            return
        await channel.default_exchange.publish(
            aio_pika.Message(
                body=_encode(response),
                content_type="application/json",
                correlation_id=message.correlation_id,
            ),
            routing_key=message.reply_to,
        )

    async def handle_message(message: aio_pika.IncomingMessage) -> None:
        async with message.process():
            try:
                payload = json.loads(message.body.decode("utf-8"))
            except (UnicodeDecodeError, json.JSONDecodeError) as exc:
                # Reply instead of dropping silently, so the RPC caller is not
                # left waiting until its own timeout.
                print(f"[worker] Rejected malformed request: {exc!r}", file=sys.stderr, flush=True)
                await _reply(message, _error_response(None, "invalid_request", f"Malformed JSON body: {exc}"))
                return

            if not isinstance(payload, dict):
                print("[worker] Rejected request: body is not a JSON object", file=sys.stderr, flush=True)
                await _reply(
                    message,
                    _error_response(None, "invalid_request", "Request body must be a JSON object"),
                )
                return

            filters = rmq_request_to_filters(payload)
            timestamp = payload.get("timestamp")

            try:
                async with await psycopg.AsyncConnection.connect(
                    PG_DSN,
                    row_factory=dict_row,
                ) as conn:
                    importer = AniListImporter()
                    titles = await importer.import_by_filters_in_tx(conn, filters)
            except Exception as exc:  # worker boundary: report to the caller, keep consuming
                print(f"[worker] Import failed for {filters!r}: {exc!r}", file=sys.stderr, flush=True)
                response = _error_response(timestamp, "import_failed", str(exc))
            else:
                response = {
                    "timestamp": timestamp,
                    "ok": True,
                    "titles": titles,
                    "error": None,
                }

            await _reply(message, response)

    return handle_message
|
|
|
|
|
|
async def connect_rmq_with_retry(
    url: str,
    retries: int = 20,
    delay: float = 3.0,
) -> aio_pika.RobustConnection:
    """Open a robust RabbitMQ connection, retrying on ``AMQPConnectionError``.

    Args:
        url: AMQP URL passed to ``aio_pika.connect_robust``.
        retries: maximum number of connection attempts.
        delay: seconds to wait between consecutive failed attempts. There is
            no wait after the final attempt.

    Returns:
        The established connection.

    Raises:
        AMQPConnectionError: the last connection error, once every attempt failed.
        RuntimeError: if ``retries`` < 1, so that no attempt was made.
    """
    import re

    # Mask the password in "scheme://user:password@host" so it never reaches the logs.
    shown_url = re.sub(r"(://[^:/?#@\s]*:)[^@/\s]*@", r"\1***@", url)
    last_exc: Exception | None = None

    for attempt in range(1, retries + 1):
        print(f"[worker] Connecting to RabbitMQ ({attempt}/{retries}) {shown_url}", flush=True)
        try:
            conn = await aio_pika.connect_robust(url)
        except AMQPConnectionError as e:
            last_exc = e
            if attempt < retries:
                print(f"[worker] RabbitMQ connection failed: {e!r}, retry in {delay}s", flush=True)
                await asyncio.sleep(delay)
            else:
                # Last attempt: give up now rather than sleeping for nothing.
                print(f"[worker] RabbitMQ connection failed: {e!r}", flush=True)
        else:
            print("[worker] Connected to RabbitMQ", flush=True)
            return conn

    print("[worker] Failed to connect to RabbitMQ after retries", file=sys.stderr, flush=True)
    if last_exc is not None:
        raise last_exc
    raise RuntimeError("Failed to connect to RabbitMQ")
|
|
|
|
|
|
async def main() -> None:
    """Run the import RPC worker until cancelled.

    Requires a PostgreSQL DSN. Connects to RabbitMQ with retries and declares
    the durable ``RPC_QUEUE_NAME`` queue. It then consumes that queue with the
    handler from ``create_handler`` and waits forever. The RabbitMQ connection
    is closed on every exit path, including failures during channel or queue
    setup.

    Raises:
        RuntimeError: if neither NYANIMEDB_PG_DSN nor DATABASE_URL is set.
    """
    import re

    def _redact(dsn: str) -> str:
        """Mask passwords in URL-style and key=value DSNs so they are safe to log."""
        # scheme://user:password@host...  ->  scheme://user:***@host...
        dsn = re.sub(r"(://[^:/?#@\s]*:)[^@/\s]*@", r"\1***@", dsn)
        # libpq "password=..." keywords and "?password=..." query parameters.
        return re.sub(r"(password\s*=\s*)('[^']*'|[^\s&]+)", r"\1***", dsn, flags=re.IGNORECASE)

    if not PG_DSN:
        raise RuntimeError("PG_DSN is not set (NYANIMEDB_PG_DSN / DATABASE_URL)")

    print(
        f"[worker] Starting. PG_DSN={_redact(PG_DSN)!r}, "
        f"RMQ_URL={_redact(RMQ_URL)!r}, queue={RPC_QUEUE_NAME!r}",
        flush=True,
    )

    connection = await connect_rmq_with_retry(RMQ_URL)
    # Everything after a successful connect sits inside try/finally, so a
    # failing channel/queue setup still closes the connection.
    try:
        channel = await connection.channel()
        queue = await channel.declare_queue(
            RPC_QUEUE_NAME,
            durable=True,
        )
        await queue.consume(create_handler(channel))

        print(f"[*] Waiting for messages in '{RPC_QUEUE_NAME}'. Ctrl+C to exit.", flush=True)
        await asyncio.Future()  # run forever
    finally:
        await connection.close()
|
|
|
|
|
|
if __name__ == "__main__":
    # On Windows, use the selector-based event loop instead of the default
    # Proactor loop. psycopg's async connections do not support Proactor
    # (see the psycopg docs).
    if sys.platform.startswith("win"):
        selector_policy = asyncio.WindowsSelectorEventLoopPolicy()
        asyncio.set_event_loop_policy(selector_policy)

    asyncio.run(main())
|