import asyncio
import json
import time
from collections.abc import Callable
from expiringdict import ExpiringDict
from pynostr.encrypted_dm import EncryptedDirectMessage
from pynostr.event import Event
from pynostr.filters import Filters
from pynostr.utils import get_public_key
from agentstr.logger import get_logger
from agentstr.relay import DecryptedMessage, EventRelay
logger = get_logger(__name__)
class RelayManager:
    """Manages connections to multiple Nostr relays and handles message passing.

    Network operations fan out to one ``EventRelay`` per configured URL and
    combine the per-relay outcomes.

    Args:
        relays: List of relay URLs to connect to.
        private_key: Optional private key for signing events.
            NOTE(review): annotated as ``str``, but this class reads
            ``.public_key`` from it and calls ``.hex()`` on it, so a key
            object exposing those members appears to be expected -- confirm
            against callers.
    """

    def __init__(self, relays: list[str], private_key: str | None = None):
        logger.debug(f"Initializing RelayManager with {len(relays)} relays")
        self._relays = relays
        self.private_key = private_key
        self.public_key = self.private_key.public_key if self.private_key else None

    @property
    def relays(self) -> list[EventRelay]:
        """Build one EventRelay per configured relay URL.

        Returns:
            A list of freshly constructed EventRelay instances, one for each
            relay URL (new objects are created on every access).
        """
        return [EventRelay(relay, self.private_key, self.public_key) for relay in self._relays]

    @staticmethod
    def _cancel_pending(tasks: list[asyncio.Task]) -> None:
        """Cancel every task in ``tasks`` that has not finished yet."""
        for task in tasks:
            if not task.done():
                task.cancel()

    async def get_events(self, filters: Filters, limit: int = 10, timeout: int = 30, close_on_eose: bool = True) -> list[Event]:
        """Fetch events matching the given filters from all relays.

        Results are merged across relays in completion order and
        de-duplicated by event id.

        Args:
            filters: The filters to apply when fetching events. A truthy
                ``filters.limit`` takes precedence over ``limit``.
            limit: Maximum number of events to return. Defaults to 10.
            timeout: Maximum time to wait for events in seconds. Defaults to 30.
                It is forwarded to each relay and also checked here after
                each relay finishes.
            close_on_eose: Whether to close the subscription after EOSE.
                Defaults to True.

        Returns:
            A list of up to `limit` unique events that match the filters.

        Note:
            Stops early if enough events are found before the timeout; relay
            requests still in flight at that point are cancelled.
        """
        limit = filters.limit if filters.limit else limit
        unique_events: dict = {}
        started = time.time()
        tasks = [
            asyncio.create_task(relay.get_events(filters, limit, timeout, close_on_eose))
            for relay in self.relays
        ]
        try:
            for next_done in asyncio.as_completed(tasks):
                batch = await next_done
                # A relay may hand back nothing at all; treat that as empty.
                for event in batch or []:
                    unique_events.setdefault(event.id, event)
                if len(unique_events) >= limit:
                    break
                if time.time() - started > timeout:
                    break
        finally:
            # Do not leave relay requests running once we stop consuming them.
            self._cancel_pending(tasks)
        return list(unique_events.values())[:limit]

    async def get_event(self, filters: Filters, timeout: int = 30, close_on_eose: bool = True) -> Event | None:
        """Get a single event matching the filters or None if not found.

        Args:
            filters: The filters to apply when fetching the event.
            timeout: Maximum time to wait in seconds. Defaults to 30.
            close_on_eose: Whether to close the subscription after EOSE.

        Returns:
            The first event returned by ``get_events``, or None when no
            event was found.
        """
        events = await self.get_events(filters, limit=1, timeout=timeout, close_on_eose=close_on_eose)
        return events[0] if events else None

    async def send_event(self, event: Event) -> Event:
        """Send an event to all relays.

        The event's ``created_at`` is set to the current time, its id is
        recomputed and it is signed with this manager's private key before
        being dispatched.

        Args:
            event: The event to stamp, sign and publish (mutated in place).

        Returns:
            The same event object, after signing and dispatch.

        Raises:
            AttributeError: If no private key was configured.
        """
        event.created_at = int(time.time())
        event.compute_id()
        event.sign(self.private_key.hex())
        tasks = [asyncio.create_task(relay.send_event(event)) for relay in self.relays]
        await asyncio.gather(*tasks)
        return event

    def encrypt_message(self, message: str | dict, recipient_pubkey: str, event_ref: str | None = None) -> Event:
        """Encrypt a message for the recipient and prepare it as a Nostr event.

        Args:
            message: Cleartext to send; a dict is serialized with
                ``json.dumps`` first.
            recipient_pubkey: Recipient public key, resolved through
                ``get_public_key``.
            event_ref: Optional id of an event this message refers to.

        Returns:
            A signed event carrying the encrypted direct message.

        Raises:
            AttributeError: If no private key was configured.
        """
        recipient = get_public_key(recipient_pubkey)
        dm = EncryptedDirectMessage(reference_event_id=event_ref)
        if isinstance(message, dict):
            message = json.dumps(message)
        dm.encrypt(self.private_key.hex(), cleartext_content=message, recipient_pubkey=recipient.hex())
        event = dm.to_event()
        event.created_at = int(time.time())
        event.compute_id()
        event.sign(self.private_key.hex())
        return event

    async def send_message(self, message: str | dict, recipient_pubkey: str, event_ref: str | None = None) -> Event:
        """Send an encrypted message to a recipient through all relays.

        Args:
            message: Cleartext to send (str, or dict to be JSON-encoded).
            recipient_pubkey: Recipient public key.
            event_ref: Optional id of an event this message refers to.

        Returns:
            The signed event that was dispatched.

        Raises:
            Exception: Any error from encryption or dispatch is logged and
                re-raised unchanged.
        """
        # NOTE(review): the cleartext message is written to the log here.
        logger.info(f"Sending message to {recipient_pubkey[:10]}: {message}")
        try:
            event = self.encrypt_message(message, recipient_pubkey, event_ref)
            logger.debug(f"Encrypted message event: {event.id}")
            tasks = []
            for relay in self.relays:
                logger.debug(f"Queueing message for relay: {relay.relay}")
                tasks.append(asyncio.create_task(relay.send_event(event)))
            logger.debug(f"Dispatching message to {len(tasks)} relays")
            await asyncio.gather(*tasks)
            logger.info(f"Successfully sent message to {recipient_pubkey[:10]} with event id: {event.id[:10]}")
            return event
        except Exception as e:
            logger.error(f"Failed to send message to {recipient_pubkey[:10]}: {e!s}", exc_info=True)
            raise

    async def receive_message(self, author_pubkey: str, timestamp: int | None = None, timeout: int = 30) -> DecryptedMessage | None:
        """Wait for and return the next message from the specified author.

        A receive request is started on every relay; the first truthy result
        wins and the remaining requests are cancelled.

        Args:
            author_pubkey: Public key of the expected sender.
            timestamp: Optional timestamp forwarded to each relay.
            timeout: Maximum time to wait in seconds. Defaults to 30.

        Returns:
            The first message produced by any relay, or None if none arrived.
        """
        logger.info(f"Waiting for message from {author_pubkey[:10]}...")
        logger.debug(f"Timeout: {timeout}s, Timestamp: {timestamp}")
        t0 = time.time()
        tasks = []
        try:
            # Start receive tasks for all relays
            for relay in self.relays:
                logger.debug(f"Starting receive task for relay: {relay.relay}")
                task = asyncio.create_task(relay.receive_message(author_pubkey, timestamp, timeout))
                tasks.append(task)
            # Wait for the first successful response
            for task in asyncio.as_completed(tasks):
                try:
                    result = await task
                    if result:
                        logger.info(f"Received message from {author_pubkey[:10]} with id {result.event.id[:10]}: {result.message}")
                        return result
                    # Check timeout
                    if time.time() - t0 > timeout:
                        logger.warning(f"Receive operation timed out after {timeout} seconds")
                        break
                except Exception as e:
                    # One failing relay must not abort the others.
                    logger.warning(f"Error in receive task: {e!s}")
                    continue
            logger.warning("No messages received before timeout")
            return None
        except Exception as e:
            logger.error(f"Error in receive_message: {e!s}", exc_info=True)
            raise
        finally:
            # Stop relay listeners that are still waiting once we are done.
            self._cancel_pending(tasks)

    async def send_receive_message(self, message: str | dict, recipient_pubkey: str, timeout: int = 3, event_ref: str | None = None) -> DecryptedMessage | None:
        """Send a message and wait for a response from the recipient.

        The sent event's ``created_at`` is passed as the timestamp to
        ``receive_message``.

        Returns:
            The first response received within the timeout period, or None.
        """
        dm_event = await self.send_message(message, recipient_pubkey, event_ref)
        timestamp = dm_event.created_at
        logger.debug(f"Sent receive DM event: {dm_event.to_dict()}")
        return await self.receive_message(recipient_pubkey, timestamp, timeout)

    async def event_listener(self, filters: Filters, callback: Callable[[Event], None]):
        """Start listening for events matching the given filters.

        One listener runs per relay; all of them share a single expiring
        cache (1000 entries, 300 seconds) and a single lock -- presumably so
        that an event seen on several relays is handled once; confirm in
        ``EventRelay.event_listener``.

        Args:
            filters: The filters selecting which events to listen for.
            callback: Called for each matching event.
        """
        event_cache = ExpiringDict(max_len=1000, max_age_seconds=300)
        lock = asyncio.Lock()
        tasks = [
            asyncio.create_task(relay.event_listener(filters, callback, event_cache, lock))
            for relay in self.relays
        ]
        await asyncio.gather(*tasks)

    async def direct_message_listener(self, filters: Filters, callback: Callable[[Event, str], None]):
        """Start listening for direct messages.

        One listener runs per relay; all of them share a single expiring
        cache (1000 entries, 300 seconds) and a single lock.

        Args:
            filters: The filters selecting which messages to listen for.
            callback: Called with each received message and its decrypted
                content.
        """
        event_cache = ExpiringDict(max_len=1000, max_age_seconds=300)
        lock = asyncio.Lock()
        tasks = [
            asyncio.create_task(relay.direct_message_listener(filters, callback, event_cache, lock))
            for relay in self.relays
        ]
        await asyncio.gather(*tasks)

    async def get_following(self, pubkey: str | None = None) -> list[str]:
        """Get the list of public keys that the specified user follows.

        Looks up a single kind-3 event authored by the user and collects the
        second element of every ``"p"`` tag.

        Args:
            pubkey: Public key of the user; defaults to this manager's own
                public key.

        Returns:
            The followed public keys, or an empty list when no event is found.
        """
        pubkey = get_public_key(pubkey).hex() if pubkey else self.public_key.hex()
        filters = Filters(authors=[pubkey], kinds=[3], limit=1)
        event = await self.get_event(filters)
        if event:
            # Skip malformed tags that carry no value after the tag name.
            return [tag[1] for tag in event.tags if len(tag) > 1 and tag[0] == "p"]
        return []