Skip to content
Merged
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ click >= 8.1.7
colorama >= 0.4.6, < 0.4.7
cryptography >= 48.0.1
fastapi[all] >= 0.94.0
filelock >= 3.19.1
paramiko >= 3.3.1
prettytable >= 3.16.0
psutil >= 5.9.0
Expand Down
12 changes: 9 additions & 3 deletions runpod/serverless/modules/rp_fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from .rp_handler import is_generator
from .rp_job import run_job, run_job_generator
from .rp_ping import Heartbeat
from .worker_state import JobsProgress
from .worker_state import JobsProgress, PingJobMirror

RUNPOD_ENDPOINT_ID = os.environ.get("RUNPOD_ENDPOINT_ID", None)

Expand Down Expand Up @@ -184,8 +184,14 @@ def __init__(self, config: Dict[str, Any]):
2. Initializes the FastAPI web server.
3. Sets the handler for processing jobs.
"""
# Start the heartbeat thread.
heartbeat.start_ping()
# One per-worker mirror so the separate ping process reports the
# in-progress job ids tracked here. Attaching to job_list means every
# add/remove syncs it automatically.
mirror = PingJobMirror()
Comment thread
deanq marked this conversation as resolved.
job_list.set_mirror(mirror)

# Start the heartbeat process.
heartbeat.start_ping(mirror)

self.config = config

Expand Down
15 changes: 9 additions & 6 deletions runpod/serverless/modules/rp_ping.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from runpod.http_client import SyncClientSession
from runpod.serverless.modules.rp_logger import RunPodLogger
from runpod.serverless.modules.worker_state import WORKER_ID, JobsProgress
from runpod.serverless.modules.worker_state import WORKER_ID
from runpod.version import __version__ as runpod_version

log = RunPodLogger()
Expand All @@ -30,6 +30,9 @@ def __init__(self, pool_connections=10, retries=3) -> None:
self.PING_URL = self.PING_URL.replace("$RUNPOD_POD_ID", WORKER_ID)
self.PING_INTERVAL = int(os.environ.get("RUNPOD_PING_INTERVAL", 10000)) // 1000

# In-progress job-id snapshot, injected by the main process at start.
self._mirror = None

# Create a new HTTP session
self._session = SyncClientSession()
self._session.headers.update(
Expand All @@ -52,15 +55,16 @@ def __init__(self, pool_connections=10, retries=3) -> None:
self._session.mount("https://", adapter)

@staticmethod
def process_loop(test=False):
def process_loop(mirror=None, test=False):
"""
Static helper to run the ping loop in a separate process.
Creates a new Heartbeat instance to avoid pickling issues.
"""
hb = Heartbeat()
hb._mirror = mirror
hb.ping_loop(test)

def start_ping(self, test=False):
def start_ping(self, mirror=None, test=False):
"""
Sends heartbeat pings to the Runpod server in a separate process.
"""
Comment thread
deanq marked this conversation as resolved.
Expand All @@ -77,7 +81,7 @@ def start_ping(self, test=False):
return

if not Heartbeat._process_started:
process = Process(target=Heartbeat.process_loop, args=(test,))
process = Process(target=Heartbeat.process_loop, args=(mirror, test))
process.daemon = True
process.start()
Heartbeat._process_started = True
Expand All @@ -96,8 +100,7 @@ def _send_ping(self):
"""
Sends a heartbeat to the Runpod server.
"""
jobs = JobsProgress() # Get the singleton instance
job_ids = jobs.get_job_list()
job_ids = self._mirror.get() if self._mirror is not None else None
Comment thread
deanq marked this conversation as resolved.
ping_params = {"job_id": job_ids, "runpod_version": runpod_version}

try:
Expand Down
147 changes: 78 additions & 69 deletions runpod/serverless/modules/worker_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,12 @@
Handles getting stuff from environment variables and updating the global state like job id.
"""

import multiprocessing
import os
import time
import uuid
import pickle
import tempfile
from typing import Any, Dict, Optional, Set

from filelock import FileLock

from .rp_logger import RunPodLogger


Expand All @@ -20,6 +17,8 @@

WORKER_ID = os.environ.get("RUNPOD_POD_ID", str(uuid.uuid4()))

PING_MIRROR_CAPACITY = 65536 # bytes; ample headroom for a job-id snapshot


# ----------------------------------- Flags ---------------------------------- #
IS_LOCAL_TEST = os.environ.get("RUNPOD_WEBHOOK_GET_JOB", None) is None
Expand Down Expand Up @@ -63,87 +62,50 @@ def __str__(self) -> str:


# ---------------------------------------------------------------------------- #
# Tracker #
# Tracker #
# ---------------------------------------------------------------------------- #
class JobsProgress(Set[Job]):
"""Track the state of current jobs in progress with persistent state."""
"""Track the state of current jobs in progress (in-memory, per process)."""

_instance = None
_STATE_DIR = os.getcwd()
_STATE_FILE = os.path.join(_STATE_DIR, ".runpod_jobs.pkl")

def __new__(cls):
if JobsProgress._instance is None:
os.makedirs(cls._STATE_DIR, exist_ok=True)
JobsProgress._instance = set.__new__(cls)
# Initialize as empty set before loading state
set.__init__(JobsProgress._instance)
JobsProgress._instance._load_state()
# One-way snapshot to the ping process; attached in the main
# process via set_mirror(). Stays None off-Runpod and in tests.
JobsProgress._instance._mirror = None
return JobsProgress._instance

def __init__(self):
# This should never clear data in a singleton
# Don't call parent __init__ as it would clear the set
# Singleton: never re-initialize, it would clear the set.
pass

def __repr__(self) -> str:
return f"<{self.__class__.__name__}>: {self.get_job_list()}"

def _load_state(self):
"""Load jobs state from pickle file with file locking."""
try:
if (
os.path.exists(self._STATE_FILE)
and os.path.getsize(self._STATE_FILE) > 0
):
with FileLock(self._STATE_FILE + '.lock'):
with open(self._STATE_FILE, "rb") as f:
try:
loaded_jobs = pickle.load(f)
# Clear current state and add loaded jobs
super().clear()
for job in loaded_jobs:
set.add(
self, job
) # Use set.add to avoid triggering _save_state

except (EOFError, pickle.UnpicklingError):
# Handle empty or corrupted file
log.debug(
"JobsProgress: Failed to load state file, starting with empty state"
)
pass

except FileNotFoundError:
log.debug("JobsProgress: No state file found, starting with empty state")
pass

def _save_state(self):
"""Save jobs state to pickle file with atomic write and file locking."""
try:
# Use temporary file for atomic write
with FileLock(self._STATE_FILE + '.lock'):
with tempfile.NamedTemporaryFile(
dir=self._STATE_DIR, delete=False, mode="wb"
) as temp_f:
pickle.dump(set(self), temp_f)

# Atomically replace the state file
os.replace(temp_f.name, self._STATE_FILE)
except Exception as e:
log.error(f"Failed to save job state: {e}")
def set_mirror(self, mirror) -> None:
"""Attach a PingJobMirror that mirrors the in-progress job ids to the
ping process. Every add/remove/clear then pushes the snapshot to it."""
self._mirror = mirror
self._notify_mirror()

def _notify_mirror(self) -> None:
"""Push the current job-id snapshot to the attached mirror, if any."""
if self._mirror is not None:
self._mirror.set(self.get_job_list())

def clear(self) -> None:
super().clear()
self._save_state()
self._notify_mirror()

def add(self, element: Any):
"""
Adds a Job object to the set.

If the added element is a string, then `Job(id=element)` is added

If the added element is a dict, that `Job(**element)` is added
If the added element is a string, then `Job(id=element)` is added.
If the added element is a dict, then `Job(**element)` is added.
"""
if isinstance(element, str):
element = Job(id=element)
Expand All @@ -155,16 +117,15 @@ def add(self, element: Any):
raise TypeError("Only Job objects can be added to JobsProgress.")

result = super().add(element)
self._save_state()
self._notify_mirror()
return result

def remove(self, element: Any):
"""
Removes a Job object from the set.

If the element is a string, then `Job(id=element)` is removed

If the element is a dict, then `Job(**element)` is removed
If the element is a string, then `Job(id=element)` is removed.
If the element is a dict, then `Job(**element)` is removed.
"""
if isinstance(element, str):
element = Job(id=element)
Expand All @@ -176,7 +137,7 @@ def remove(self, element: Any):
raise TypeError("Only Job objects can be removed from JobsProgress.")

result = super().discard(element)
self._save_state()
self._notify_mirror()
return result

def get(self, element: Any) -> Optional[Job]:
Expand All @@ -193,10 +154,8 @@ def get(self, element: Any) -> Optional[Job]:

def get_job_list(self) -> Optional[str]:
"""
Returns the list of job IDs as comma-separated string.
Returns the list of job IDs as a comma-separated string, or None if empty.
"""
self._load_state()

if not len(self):
return None

Expand All @@ -207,3 +166,53 @@ def get_job_count(self) -> int:
Returns the number of jobs.
"""
return len(self)


# ---------------------------------------------------------------------------- #
# Ping Job Mirror #
# ---------------------------------------------------------------------------- #
class PingJobMirror:
"""
One-way snapshot of in-progress job ids from the worker (main) process to
the separate ping process.

Backed by a fixed-size shared-memory buffer created in the main process and
passed to the ping process via ``Process(args=...)``. It lives only in this
worker's own process tree, so it cannot be shared across workers and never
touches the filesystem. All operations are best-effort and never raise into
the caller (a failure here must not break job processing or kill the ping).
"""

def __init__(self, capacity: int = PING_MIRROR_CAPACITY, ctx=None):
ctx = ctx or multiprocessing
self._capacity = capacity
self._buffer = ctx.Array("c", capacity) # SynchronizedString with .get_lock()

def set(self, job_ids: Optional[str]) -> None:
"""Write the current job-id snapshot. Best-effort; never raises."""
try:
data = (job_ids or "").encode("utf-8")
limit = self._capacity - 1 # reserve a byte for the NUL terminator
if len(data) > limit:
data = data[:limit]
cut = data.rfind(b",")
if cut != -1:
data = data[:cut]
log.warn(
f"PingJobMirror: job-id snapshot exceeded {limit} bytes; truncated"
)
with self._buffer.get_lock():
self._buffer.value = data
except Exception as err: # never break job processing
log.error(f"PingJobMirror.set failed: {err}")

def get(self) -> Optional[str]:
"""Read the current job-id snapshot. Best-effort; never raises."""
try:
with self._buffer.get_lock():
data = self._buffer.value
text = data.decode("utf-8")
return text or None
except Exception as err: # never kill the ping loop
log.debug(f"PingJobMirror.get failed: {err}")
return None
8 changes: 7 additions & 1 deletion runpod/serverless/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,14 @@ def run_worker(config: Dict[str, Any]) -> None:
# Run fitness checks before starting worker (production only)
asyncio.run(run_fitness_checks())

# One per-worker mirror: the job tracker writes it, the ping process reads
# it. Attaching to JobsProgress means every add/remove syncs automatically.
from runpod.serverless.modules.worker_state import JobsProgress, PingJobMirror
mirror = PingJobMirror()
JobsProgress().set_mirror(mirror)

# Start pinging Runpod to show that the worker is alive.
heartbeat.start_ping()
heartbeat.start_ping(mirror)

# Create a JobScaler responsible for adjusting the concurrency
job_scaler = rp_scale.JobScaler(config)
Expand Down
1 change: 0 additions & 1 deletion tests/test_serverless/local_sim/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,5 @@ worker:
python worker.py

clean:
find . -type f -name ".runpod_jobs.pkl" -delete
find . -type f -name "*.pyc" -delete
find . -type d -name "__pycache__" -delete
Loading