Source code for pynenc_redis.orchestrator.redis_orchestrator

import json
from collections.abc import Iterator
from datetime import UTC, datetime
from functools import cached_property
from time import time
from typing import TYPE_CHECKING

import redis
from pynenc.identifiers.invocation_id import InvocationId
from pynenc.invocation.dist_invocation import DistributedInvocation
from pynenc.invocation.status import (
    InvocationStatus,
    InvocationStatusRecord,
    status_record_transition,
)
from pynenc.orchestrator.atomic_service import ActiveRunnerInfo
from pynenc.orchestrator.base_orchestrator import (
    BaseBlockingControl,
    BaseOrchestrator,
)
from pynenc.types import Params, Result

from pynenc_redis.conf.config_orchestrator import ConfigOrchestratorRedis
from pynenc_redis.util.mongo_client import get_redis_client
from pynenc_redis.util.redis_keys import Key

if TYPE_CHECKING:
    from pynenc.app import Pynenc
    from pynenc.identifiers.call_id import CallId
    from pynenc.invocation.dist_invocation import DistributedInvocation
    from pynenc.task import Task, TaskId


[docs] class StatusNotFound(Exception): """Raised when a status is not found in Redis"""
[docs] class RedisBlockingControl(BaseBlockingControl): """ A Redis-based implementation of blocking control for task invocations. Manages invocation dependencies and blocking states in a Redis-backed environment, ensuring that invocations waiting for others are properly tracked and released. :param Pynenc app: The Pynenc application instance. :param redis.Redis client: The Redis client instance. """ def __init__(self, app: "Pynenc", client: redis.Redis) -> None: self.app = app self.key = Key(app.app_id, "blocking_control") self.client = client
[docs] def purge(self) -> None: """Purges all data related to blocking control from Redis.""" self.key.purge(self.client)
[docs] def waiting_for_results( self, caller_invocation_id: "InvocationId", result_invocation_ids: list["InvocationId"], ) -> None: """ Notifies the system that an invocation is waiting for the results of other invocations. :param caller_invocation_id: The ID of the invocation that is waiting. :param result_invocation_ids: The IDs of the invocations being waited on. """ for waited_invocation_id in result_invocation_ids: self.client.set( self.key.invocation(waited_invocation_id), waited_invocation_id ) self.client.sadd( self.key.waited_by(waited_invocation_id), caller_invocation_id ) self.client.zadd(self.key.all_waited(), {waited_invocation_id: time()}) if not self.client.exists(self.key.waiting_for(waited_invocation_id)): self.client.zadd(self.key.not_waiting(), {waited_invocation_id: time()}) if self.client.zscore(self.key.not_waiting(), caller_invocation_id) is not None: self.client.zrem(self.key.not_waiting(), caller_invocation_id) self.client.sadd( self.key.waiting_for(caller_invocation_id), *result_invocation_ids )
[docs] def release_waiters(self, waited_invocation_id: str) -> None: """ Releases any invocations that are waiting on the specified invocation. :param waited_invocation_id: The ID of the invocation that has finished and can release its waiters. """ for waiter_invocation_id in self.client.smembers( self.key.waited_by(waited_invocation_id) ): self.client.srem( self.key.waiting_for(waiter_invocation_id.decode()), waited_invocation_id, ) if not self.client.exists(self.key.waiting_for(waiter_invocation_id)): self.client.zadd(self.key.not_waiting(), {waiter_invocation_id: time()}) self.client.delete(self.key.invocation(waited_invocation_id)) self.client.delete(self.key.waiting_for(waited_invocation_id)) self.client.delete(self.key.waited_by(waited_invocation_id)) self.client.zrem(self.key.all_waited(), waited_invocation_id) self.client.zrem(self.key.not_waiting(), waited_invocation_id)
[docs] def get_blocking_invocations( self, max_num_invocations: int ) -> Iterator["InvocationId"]: """ Retrieves invocation IDs that are blocking others but are not blocked themselves. :param max_num_invocations: The maximum number of blocking invocation IDs to retrieve. :return: An iterator over unblocked, blocking invocation IDs, ordered by age (oldest first). """ index = 0 page_size = max(10, max_num_invocations) count = 0 while count < max_num_invocations: if not ( page := self.client.zrange( self.key.not_waiting(), index, index + page_size - 1 ) ): break index += page_size for waited_invocation_id in page: invocation_id = InvocationId(waited_invocation_id.decode()) val_inv_id = self.client.get(self.key.invocation(invocation_id)) if not val_inv_id: continue try: status_record = self.app.orchestrator.get_invocation_status_record( invocation_id ) if status_record.status.is_available_for_run(): yield invocation_id count += 1 if count == max_num_invocations: break except KeyError: self.app.logger.warning( f"Skipping invocation {invocation_id} in get_blocking_invocations: " "status not found in Redis" ) if max_num_invocations == 0: break
[docs] class RedisOrchestrator(BaseOrchestrator): """ Orchestrator implementation using Redis for distributed invocation management. Stores status records with ownership tracking by invocation_id, using atomic transactions for status changes. """ def __init__(self, app: "Pynenc") -> None: super().__init__(app) self._client: redis.Redis | None = None self._blocking_control: RedisBlockingControl | None = None self.key = Key(app.app_id, "orchestrator") @cached_property def conf(self) -> ConfigOrchestratorRedis: return ConfigOrchestratorRedis( config_values=self.app.config_values, config_filepath=self.app.config_filepath, ) @property def client(self) -> redis.Redis: if self._client is None: self._client = get_redis_client(self.conf) return self._client @property def blocking_control(self) -> "RedisBlockingControl": if not self._blocking_control: self._blocking_control = RedisBlockingControl(self.app, self.client) return self._blocking_control
[docs] def get_existing_invocations( self, task: "Task[Params, Result]", key_serialized_arguments: dict[str, str] | None = None, statuses: list[InvocationStatus] | None = None, ) -> Iterator["InvocationId"]: """ Retrieves existing invocation IDs based on task, arguments, and status. :param task: The task for which to retrieve invocations. :param key_serialized_arguments: Serialized arguments to filter invocations. :param statuses: The statuses to filter invocations. :return: An iterator over the matching invocation IDs. """ task_id_key: str = task.task_id.key invocation_ids: set[str] = set() for inv_id in self.client.smembers(self.key.task(task_id_key)): invocation_ids.add( InvocationId(inv_id.decode()) if isinstance(inv_id, bytes) else InvocationId(inv_id) ) if key_serialized_arguments: for arg, val in key_serialized_arguments.items(): arg_val_ids = { InvocationId(id.decode()) if isinstance(id, bytes) else InvocationId(id) for id in self.client.smembers(self.key.args(task_id_key, arg, val)) } invocation_ids &= arg_val_ids if statuses: status_ids: set[str] = set() for status in statuses: status_ids |= { InvocationId(id.decode()) if isinstance(id, bytes) else InvocationId(id) for id in self.client.smembers( self.key.status_to_invocations(status) ) } invocation_ids &= status_ids for inv_id in invocation_ids: yield InvocationId(inv_id)
[docs] def get_task_invocation_ids(self, task_id: "TaskId") -> Iterator["InvocationId"]: """ Retrieves all invocation IDs associated with a specific task ID. :param task_id: The task ID to filter invocations. :return: Iterator of invocation IDs for the specified task. """ for inv_id in self.client.smembers(self.key.task(task_id.key)): yield ( InvocationId(inv_id.decode()) if isinstance(inv_id, bytes) else InvocationId(inv_id) )
[docs] def get_invocation_ids_paginated( self, task_id: "TaskId | None" = None, statuses: list[InvocationStatus] | None = None, limit: int = 100, offset: int = 0, ) -> list["InvocationId"]: """ Retrieves invocation IDs with pagination support. Uses Redis sorted sets indexed by registration time for efficient pagination. Results are ordered by registration time (newest first). :param task_id: Optional task ID to filter by. :param statuses: Optional statuses to filter by. :param limit: Maximum number of results to return. :param offset: Number of results to skip. :return: List of matching invocation IDs. """ # Determine which sorted set to use based on task_id filter if task_id: source_key = self.key.task_invocations_by_time(task_id) else: source_key = self.key.all_invocations_by_time() # If no status filter, use direct range query (newest first = reverse order) if not statuses: raw_ids = self.client.zrevrange(source_key, offset, offset + limit - 1) return [ InvocationId(inv_id.decode()) if isinstance(inv_id, bytes) else InvocationId(inv_id) for inv_id in raw_ids ] # With status filter, need to intersect with status sets status_keys = [self.key.status_to_invocations(status) for status in statuses] # Get all invocation IDs from status sets status_inv_ids: set[str] = set() for status_key in status_keys: for inv_id in self.client.smembers(status_key): decoded_id = inv_id.decode() if isinstance(inv_id, bytes) else inv_id status_inv_ids.add(decoded_id) if not status_inv_ids: return [] # Get invocations from time-sorted set with their scores for ordering # Using zrevrange with scores to get all, then filter and paginate all_with_scores = self.client.zrevrange(source_key, 0, -1, withscores=True) # Filter by status and collect in order filtered_ids = [] for inv_id, _score in all_with_scores: decoded_id = inv_id.decode() if isinstance(inv_id, bytes) else inv_id if decoded_id in status_inv_ids: filtered_ids.append(decoded_id) # Apply pagination return filtered_ids[offset : offset + limit]
[docs] def count_invocations( self, task_id: "TaskId | None" = None, statuses: list[InvocationStatus] | None = None, ) -> int: """ Counts invocations matching the given filters. :param task_id: Optional task ID to filter by. :param statuses: Optional statuses to filter by. :return: The total count of matching invocations. """ # Determine source based on task_id filter if task_id: source_key = self.key.task_invocations_by_time(task_id.key) else: source_key = self.key.all_invocations_by_time() # If no status filter, return count from sorted set if not statuses: return self.client.zcard(source_key) # With status filter, need to count intersection # Get all invocation IDs from the source source_inv_ids: set["InvocationId"] = { InvocationId(inv_id.decode()) if isinstance(inv_id, bytes) else InvocationId(inv_id) for inv_id in self.client.zrange(source_key, 0, -1) } if not source_inv_ids: return 0 # Get all invocation IDs matching any of the statuses status_inv_ids: set[str] = set() for status in statuses: for inv_id in self.client.smembers(self.key.status_to_invocations(status)): decoded_id = inv_id.decode() if isinstance(inv_id, bytes) else inv_id status_inv_ids.add(decoded_id) # Return count of intersection return len(source_inv_ids & status_inv_ids)
[docs] def get_call_invocation_ids(self, call_id: "CallId") -> Iterator["InvocationId"]: """ Retrieves all invocation IDs associated with a specific call ID. :param call_id: The call ID to filter invocations. :return: Iterator of invocation IDs for the specified call. """ for inv_id in self.client.smembers(self.key.call_to_invocation(call_id.key)): yield ( InvocationId(inv_id.decode()) if isinstance(inv_id, bytes) else InvocationId(inv_id) )
[docs] def any_non_final_invocations(self, call_id: "CallId") -> bool: """ Checks if there are any non-final invocations for a specific call ID. :param call_id: The call ID to check for non-final invocations. :return: True if there are non-final invocations, False otherwise. """ for invocation_id in self.get_call_invocation_ids(call_id): status_record = self.get_invocation_status_record(invocation_id) if not status_record.status.is_final(): return True return False
[docs] def _register_new_invocations( self, invocations: list["DistributedInvocation[Params, Result]"], runner_id: str | None = None, ) -> InvocationStatusRecord: """ Register new invocations with status REGISTERED if they don't exist yet. Initializes the necessary Redis data structures for task-to-invocation, call-to-invocation mappings, status, and retry tracking. """ status_record = InvocationStatusRecord(InvocationStatus.REGISTERED, runner_id) for invocation in invocations: # Skip if already registered if self.client.exists( self.key.invocation_to_status(invocation.invocation_id) ): continue # Add to task's invocation set self.client.sadd( self.key.task(invocation.task.task_id.key), invocation.invocation_id ) # Add to call's invocation set self.client.sadd( self.key.call_to_invocation(invocation.call.call_id.key), invocation.invocation_id, ) # Store invocation_id -> call_id mapping self.client.set( self.key.invocation_to_call(invocation.invocation_id), invocation.call.call_id.key, ) # Set status to REGISTERED self._set_status_record(invocation.invocation_id, status_record) # Initialize retry count to 0 self.client.set(self.key.invocation_retries(invocation.invocation_id), 0) # Add to time-indexed sorted sets for pagination support registration_time = time() self.client.zadd( self.key.all_invocations_by_time(), {invocation.invocation_id: registration_time}, ) self.client.zadd( self.key.task_invocations_by_time(invocation.task.task_id.key), {invocation.invocation_id: registration_time}, ) return status_record
[docs] def _set_status_record( self, invocation_id: str, status_record: InvocationStatusRecord ) -> None: """Store a status record in Redis.""" pipeline = self.client.pipeline(transaction=True) pipeline.sadd( self.key.status_to_invocations(status_record.status), invocation_id ) pipeline.set( self.key.invocation_to_status(invocation_id), json.dumps(status_record.to_json()), ) pipeline.execute()
[docs] def _atomic_status_transition( self, invocation_id: str, status: InvocationStatus, runner_id: str | None = None ) -> InvocationStatusRecord: """ Perform atomic status transition with validation. Uses Redis transactions to ensure atomic updates with ownership validation. """ # Get current status record current_record = self.get_invocation_status_record(invocation_id) # Validate and compute new record new_record = status_record_transition(current_record, status, runner_id) # Use Redis transaction for atomic update pipeline = self.client.pipeline(transaction=True) # Remove from old status set pipeline.srem( self.key.status_to_invocations(current_record.status), invocation_id ) # Add to new status set and update status record pipeline.sadd(self.key.status_to_invocations(new_record.status), invocation_id) pipeline.set( self.key.invocation_to_status(invocation_id), json.dumps(new_record.to_json()), ) pipeline.execute() self.app.logger.debug( f"Transitioned invocation {invocation_id} from {current_record.status} to {status}" ) return new_record
[docs] def get_invocation_status_record( self, invocation_id: "InvocationId" ) -> InvocationStatusRecord: """ Retrieves the status record of a specific invocation. :param invocation_id: The id of the invocation whose status is to be retrieved. :return: The current status record of the invocation. :raises KeyError: If invocation not found. """ if encoded_status := self.client.get( self.key.invocation_to_status(invocation_id) ): status_dict = json.loads(encoded_status.decode()) return InvocationStatusRecord.from_json(status_dict) raise KeyError(f"Invocation status {invocation_id} not found in Redis")
[docs] def index_arguments_for_concurrency_control( self, invocation: "DistributedInvocation[Params, Result]", ) -> None: """ Caches the required data to implement concurrency control. :param invocation: The invocation to be cached. """ for key, value in invocation.call.serialized_arguments.items(): self.client.sadd( self.key.args(invocation.task.task_id.key, key, value), invocation.invocation_id, )
[docs] def set_up_invocation_auto_purge(self, invocation_id: str) -> None: """ Sets up automatic purging for an invocation after a defined period. :param invocation_id: The invocation to set up for auto purge. """ self.client.zadd( self.key.invocation_auto_purge(), {invocation_id: time()}, )
[docs] def auto_purge(self) -> None: """ Automatically purges invocations that have been in their final state beyond a specified duration. """ end_time = ( time() - self.app.orchestrator.conf.auto_final_invocation_purge_hours * 3600 ) for _invocation_id in self.client.zrangebyscore( self.key.invocation_auto_purge(), 0, end_time ): invocation_id = _invocation_id.decode() try: invocation = self.app.state_backend.get_invocation(invocation_id) task_id = invocation.task.task_id # clean up task keys self.client.srem(self.key.task(task_id.key), invocation_id) if not self.client.smembers(self.key.task(task_id.key)): self.client.delete(self.key.task(task_id.key)) # clean up task-status keys status_record = self.get_invocation_status_record(invocation_id) self.client.srem( self.key.status_to_invocations(status_record.status), invocation_id ) if not self.client.smembers( self.key.status_to_invocations(status_record.status) ): self.client.delete( self.key.status_to_invocations(status_record.status) ) except KeyError: self.app.logger.warning(f"{invocation_id=} not found during auto purge") self.client.delete(self.key.invocation_to_status(invocation_id)) self.client.zrem(self.key.invocation_auto_purge(), invocation_id)
[docs] def increment_invocation_retries(self, invocation_id: str) -> None: """ Increments the retry count of a specific invocation. :param invocation_id: The id of the invocation for which to increment retries. """ self.client.incr(self.key.invocation_retries(invocation_id))
[docs] def get_invocation_retries(self, invocation_id: str) -> int: """ Retrieves the number of retries for a specific invocation. :param invocation_id: The id of the invocation whose retry count is to be retrieved. :return: The number of retries for the invocation. """ if encoded_retries := self.client.get( self.key.invocation_retries(invocation_id) ): return int(encoded_retries.decode()) return 0
[docs] def filter_by_status( self, invocation_ids: list["InvocationId"], status_filter: frozenset["InvocationStatus"], ) -> list["InvocationId"]: """ Filters a list of invocation ids by their status in an optimized way. :param invocation_ids: The invocation ids to filter :param status_filter: The statuses to filter by. :return: List of invocation ids matching the status filter """ if not invocation_ids or not status_filter: return [] with self.client.pipeline(transaction=False) as pipe: for inv_id in invocation_ids: pipe.get(self.key.invocation_to_status(inv_id)) status_results = pipe.execute() filtered = [] for i, inv_id in enumerate(invocation_ids): status_json = status_results[i] if not status_json: continue status_dict = json.loads(status_json.decode()) status_record = InvocationStatusRecord.from_json(status_dict) if status_record.status in status_filter: filtered.append(inv_id) return filtered
[docs] def register_runner_heartbeats( self, runner_ids: list[str], can_run_atomic_service: bool = False ) -> None: """ Register or update runners' heartbeat timestamp and atomic service eligibility. :param runner_ids: List of runner IDs :param can_run_atomic_service: Whether runners are eligible for atomic service """ current_time = time() pipeline = self.client.pipeline(transaction=True) for runner_id in runner_ids: runner_key = self.key.runner_heartbeat(runner_id) # Add to sorted set only if doesn't exist (NX flag preserves creation order) pipeline.zadd( self.key.runner_heartbeats(), {runner_id: current_time}, nx=True ) # Always update the hash - use HSET with multiple fields # If key doesn't exist, this creates it with all fields # If key exists, this only updates the fields we specify pipeline.hsetnx(runner_key, "creation_timestamp", current_time) pipeline.hset( runner_key, mapping={ "last_heartbeat": current_time, "can_run_atomic_service": int(can_run_atomic_service), }, ) pipeline.execute()
[docs] def record_atomic_service_execution( self, runner_id: str, start_time: datetime, end_time: datetime ) -> None: """ Record the latest atomic service execution window for a runner. Replaces any previous execution record for this runner with the current one. Used for diagnostics and detecting potential collisions. :param str runner_id: The runner that executed the service :param datetime start_time: When execution started (UTC timezone-aware) :param datetime end_time: When execution ended (UTC timezone-aware) """ runner_key = self.key.runner_heartbeat(runner_id) self.client.hset( runner_key, mapping={ "last_service_start": start_time.timestamp(), "last_service_end": end_time.timestamp(), }, )
[docs] def _get_runner_heartbeat_data(self) -> list[tuple[str, dict[bytes, bytes]]]: """ Fetch all runner IDs with their heartbeat hash data. :return: List of (runner_id, hash_data) tuples for all registered runners. """ all_runner_ids = self.client.zrange(self.key.runner_heartbeats(), 0, -1) if not all_runner_ids: return [] pipeline = self.client.pipeline(transaction=False) for runner_id_bytes in all_runner_ids: runner_id = runner_id_bytes.decode() pipeline.hgetall(self.key.runner_heartbeat(runner_id)) return [ (runner_id_bytes.decode(), data) for runner_id_bytes, data in zip( all_runner_ids, pipeline.execute(), strict=True ) ]
[docs] def _is_runner_active(self, runner_data: dict[bytes, bytes], cutoff: float) -> bool: """Check if a runner is active based on its last heartbeat.""" if not runner_data or b"last_heartbeat" not in runner_data: return False try: return float(runner_data[b"last_heartbeat"].decode()) >= cutoff except (ValueError, TypeError): return False
[docs] def _get_active_runners( self, timeout_seconds: float, can_run_atomic_service: bool | None ) -> list["ActiveRunnerInfo"]: """ Retrieve runners that are considered active based on heartbeat activity. A runner is considered "active" if it has sent a heartbeat within the timeout period. This is used for atomic service scheduling to determine which runners are eligible to participate in time slot distribution. :param float timeout_seconds: Heartbeat timeout in seconds (typically from atomic_service_runner_considered_dead_after_minutes config) :param bool | None can_run_atomic_service: If specified, filters runners based on their eligibility to run atomic services :return: List of active runners ordered by creation time (oldest first) :rtype: list["ActiveRunnerInfo"] """ cutoff_time = time() - timeout_seconds active_runners = [] for runner_id, runner_data in self._get_runner_heartbeat_data(): if not self._is_runner_active(runner_data, cutoff_time): continue can_run_atomic = bool(int(runner_data[b"can_run_atomic_service"].decode())) if ( can_run_atomic_service is not None and can_run_atomic != can_run_atomic_service ): continue try: last_heartbeat = float(runner_data[b"last_heartbeat"].decode()) # Optional service execution timestamps last_service_start = ( datetime.fromtimestamp( float(runner_data[b"last_service_start"].decode()), tz=UTC ) if b"last_service_start" in runner_data else None ) last_service_end = ( datetime.fromtimestamp( float(runner_data[b"last_service_end"].decode()), tz=UTC ) if b"last_service_end" in runner_data else None ) active_runners.append( ActiveRunnerInfo( runner_id=runner_id, creation_time=datetime.fromtimestamp( float(runner_data[b"creation_timestamp"].decode()), tz=UTC ), allow_to_run_atomic_service=can_run_atomic, last_heartbeat=datetime.fromtimestamp(last_heartbeat, tz=UTC), last_service_start=last_service_start, last_service_end=last_service_end, ) ) except (ValueError, KeyError): continue return active_runners
[docs] def get_pending_invocations_for_recovery(self) -> Iterator["InvocationId"]: """Retrieve invocation IDs stuck in PENDING status beyond the allowed time.""" max_pending_seconds = self.app.conf.max_pending_seconds current_time = datetime.now(UTC) cutoff_timestamp = current_time.timestamp() - max_pending_seconds # Get all PENDING invocations pending_invocations = self.client.smembers( self.key.status_to_invocations(InvocationStatus.PENDING) ) for invocation_id_bytes in pending_invocations: invocation_id = invocation_id_bytes.decode() try: status_record = self.get_invocation_status_record(invocation_id) if status_record.timestamp.timestamp() <= cutoff_timestamp: yield invocation_id except KeyError: # Invocation no longer exists continue
[docs] def _get_running_invocations_for_recovery( self, timeout_seconds: float ) -> Iterator["InvocationId"]: """ Retrieve invocation IDs in RUNNING status owned by inactive runners. An inactive runner is one that hasn't sent a heartbeat within the configured timeout period. Invocations owned by such runners are considered stuck and need recovery. :param float timeout_seconds: Heartbeat timeout in seconds :return: Iterator of invocation IDs that need recovery. :rtype: Iterator[str] """ cutoff_time = time() - timeout_seconds # Build set of active runner IDs active_runner_ids = { runner_id for runner_id, data in self._get_runner_heartbeat_data() if self._is_runner_active(data, cutoff_time) } # Check RUNNING invocations for orphaned owners for inv_id_bytes in self.client.smembers( self.key.status_to_invocations(InvocationStatus.RUNNING) ): invocation_id = InvocationId(inv_id_bytes.decode()) try: status_record = self.get_invocation_status_record(invocation_id) if ( status_record.runner_id and status_record.runner_id not in active_runner_ids ): yield invocation_id except KeyError: continue
[docs] def purge(self) -> None: """Remove all invocations from the orchestrator""" self.key.purge(self.client) self.blocking_control.purge()