import json
import logging
import os
import time
from configparser import ConfigParser
from datetime import datetime, timezone
from enum import Enum
from functools import partial
from logging.config import fileConfig

# A nosec comment is appended to the following line in order to disable the B404 check.
# In this file the input of the module subprocess is trusted.
from subprocess import CalledProcessError  # nosec B404
from typing import Dict, List

from botocore.config import Config
from common.schedulers.slurm_commands import (
    PartitionNodelistMapping,
    get_nodes_info,
    get_partitions_info,
    reset_nodes,
    set_nodes_down,
    set_nodes_drain,
    set_nodes_power_down,
    update_all_partitions,
    update_partitions,
)
from common.time_utils import seconds
from common.utils import check_command_output, read_json, sleep_remaining_loop_time, time_is_up, wait_remaining_time
from retrying import retry
from slurm_plugin.capacity_block_manager import CapacityBlockManager
from slurm_plugin.cluster_event_publisher import ClusterEventPublisher
from slurm_plugin.common import TIMESTAMP_FORMAT, ScalingStrategy, log_exception, print_with_count
from slurm_plugin.console_logger import ConsoleLogger
from slurm_plugin.instance_manager import InstanceManager
from slurm_plugin.slurm_resources import (
    CONFIG_FILE_DIR,
    ComputeResourceFailureEvent,
    DynamicNode,
    EC2InstanceHealthState,
    PartitionStatus,
    SlurmNode,
    SlurmPartition,
    StaticNode,
)
from slurm_plugin.task_executor import TaskExecutor

LOOP_TIME = 60
CONSOLE_OUTPUT_WAIT_TIME = 5 * 60
MAXIMUM_TASK_BACKLOG = 100

log = logging.getLogger(__name__)
compute_logger = log.getChild("console_output")
event_logger = log.getChild("events")
class ComputeFleetStatus ( Enum ):
"""Represents the status of the cluster compute fleet."""
STOPPED = "STOPPED" # Fleet is stopped, partitions are inactive.
RUNNING = "RUNNING" # Fleet is running, partitions are active.
STOPPING = "STOPPING" # clustermgtd is handling the stop request.
STARTING = "STARTING" # clustermgtd is handling the start request.
STOP_REQUESTED = "STOP_REQUESTED" # A request to stop the fleet has been submitted.
START_REQUESTED = "START_REQUESTED" # A request to start the fleet has been submitted.
# PROTECTED indicates that some partitions have consistent bootstrap failures. Affected partitions are inactive.
PROTECTED = "PROTECTED"
def __str__ ( self ):
return str ( self . value )
@ staticmethod
def is_start_requested ( status ):
return status == ComputeFleetStatus . START_REQUESTED
@ staticmethod
def is_stop_requested ( status ):
return status == ComputeFleetStatus . STOP_REQUESTED
@ staticmethod
def is_protected ( status ):
return status == ComputeFleetStatus . PROTECTED
class ComputeFleetStatusManager :
COMPUTE_FLEET_STATUS_ATTRIBUTE = "status"
COMPUTE_FLEET_LAST_UPDATED_TIME_ATTRIBUTE = "lastStatusUpdatedTime"
@ staticmethod
def get_status ( fallback = None ):
try :
compute_fleet_raw_data = check_command_output ( "get-compute-fleet-status.sh" )
log . debug ( "Retrieved compute fleet data: %s" , compute_fleet_raw_data )
return ComputeFleetStatus (
json . loads ( compute_fleet_raw_data ). get ( ComputeFleetStatusManager . COMPUTE_FLEET_STATUS_ATTRIBUTE )
)
except Exception as e :
if isinstance ( e , CalledProcessError ):
error = e . stdout . rstrip ()
else :
error = e
log . error (
"Failed when retrieving fleet status with error: %s, using fallback value %s" ,
error ,
fallback ,
)
return fallback
@ staticmethod
def update_status ( status ):
try :
check_command_output ( f"update-compute-fleet-status.sh --status { status } " )
except Exception as e :
log . error ( "Failed when updating fleet status to status %s, with error: %s" , status , e . stdout . rstrip ())
raise
class ClustermgtdConfig :
DEFAULTS = {
# Basic configs
"max_retry" : 1 ,
"loop_time" : LOOP_TIME ,
"proxy" : "NONE" ,
"logging_config" : os . path . join (
os . path . dirname ( __file__ ), "logging" , "parallelcluster_clustermgtd_logging.conf"
),
# Launch configs
"launch_max_batch_size" : 500 ,
"assign_node_max_batch_size" : 500 ,
"update_node_address" : True ,
"run_instances_overrides" : "/opt/slurm/etc/pcluster/run_instances_overrides.json" ,
"create_fleet_overrides" : "/opt/slurm/etc/pcluster/create_fleet_overrides.json" ,
"fleet_config_file" : "/etc/parallelcluster/slurm_plugin/fleet-config.json" ,
# Terminate configs
"terminate_max_batch_size" : 1000 ,
# Timeout to wait for node initialization, should be the same as ResumeTimeout
"node_replacement_timeout" : 1800 ,
"terminate_drain_nodes" : True ,
"terminate_down_nodes" : True ,
"orphaned_instance_timeout" : 300 ,
"ec2_instance_missing_max_count" : 0 ,
# Health check configs
"disable_ec2_health_check" : False ,
"disable_scheduled_event_health_check" : False ,
"disable_all_cluster_management" : False ,
"health_check_timeout" : 180 ,
"health_check_timeout_after_slurmdstarttime" : 180 ,
"disable_capacity_blocks_management" : False ,
# DNS domain configs
"hosted_zone" : None ,
"dns_domain" : None ,
"use_private_hostname" : False ,
"protected_failure_count" : 10 ,
"insufficient_capacity_timeout" : 600 ,
# Compute console logging configs
"compute_console_logging_enabled" : True ,
"compute_console_logging_max_sample_size" : 1 ,
"compute_console_wait_time" : CONSOLE_OUTPUT_WAIT_TIME ,
# Task executor configs
"worker_pool_size" : 5 ,
"worker_pool_max_backlog" : MAXIMUM_TASK_BACKLOG ,
}
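# Illustrative only (hypothetical values): a minimal [clustermgtd] section that the getters below
# would read, overriding a few of the DEFAULTS above. Option names match the keys used in the
# _get_*_config methods.
#
# [clustermgtd]
# cluster_name = mycluster
# region = us-east-1
# heartbeat_file_path = /opt/slurm/etc/pcluster/.slurm_plugin/clustermgtd_heartbeat
# loop_time = 60
# terminate_drain_nodes = true
# protected_failure_count = 10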
def __init__ ( self , config_file_path ):
self . _get_config ( config_file_path )
def __repr__ ( self ):
attrs = ", " . join ([ "{key}={value}" . format ( key = key , value = repr ( value )) for key , value in self . __dict__ . items ()])
return "{class_name}({attrs})" . format ( class_name = self . __class__ . __name__ , attrs = attrs )
def __eq__ ( self , other ):
if type ( other ) is type ( self ):
return (
self . _config == other . _config
and self . fleet_config == other . fleet_config
and self . run_instances_overrides == other . run_instances_overrides
and self . create_fleet_overrides == other . create_fleet_overrides
)
return False
def __ne__ ( self , other ):
return not self . __eq__ ( other )
def _get_basic_config ( self , config ):
"""Get basic config options."""
self . region = config . get ( "clustermgtd" , "region" )
self . cluster_name = config . get ( "clustermgtd" , "cluster_name" )
self . dynamodb_table = config . get ( "clustermgtd" , "dynamodb_table" )
self . head_node_private_ip = config . get ( "clustermgtd" , "head_node_private_ip" )
self . head_node_hostname = config . get ( "clustermgtd" , "head_node_hostname" )
self . head_node_instance_id = config . get ( "clustermgtd" , "instance_id" , fallback = "unknown" )
# Configure boto3 to retry 1 time by default
self . _boto3_retry = config . getint ( "clustermgtd" , "boto3_retry" , fallback = self . DEFAULTS . get ( "max_retry" ))
self . _boto3_config = { "retries" : { "max_attempts" : self . _boto3_retry , "mode" : "standard" }}
self . loop_time = config . getint ( "clustermgtd" , "loop_time" , fallback = self . DEFAULTS . get ( "loop_time" ))
self . disable_all_cluster_management = config . getboolean (
"clustermgtd" ,
"disable_all_cluster_management" ,
fallback = self . DEFAULTS . get ( "disable_all_cluster_management" ),
)
self . heartbeat_file_path = config . get ( "clustermgtd" , "heartbeat_file_path" )
proxy = config . get ( "clustermgtd" , "proxy" , fallback = self . DEFAULTS . get ( "proxy" ))
if proxy != "NONE" :
self . _boto3_config [ "proxies" ] = { "https" : proxy }
self . boto3_config = Config ( ** self . _boto3_config )
self . logging_config = config . get ( "clustermgtd" , "logging_config" , fallback = self . DEFAULTS . get ( "logging_config" ))
def _get_launch_config ( self , config ):
"""Get config options related to launching instances."""
self . launch_max_batch_size = config . getint (
"clustermgtd" , "launch_max_batch_size" , fallback = self . DEFAULTS . get ( "launch_max_batch_size" )
)
self . assign_node_max_batch_size = config . getint (
"clustermgtd" , "assign_node_max_batch_size" , fallback = self . DEFAULTS . get ( "assign_node_max_batch_size" )
)
self . update_node_address = config . getboolean (
"clustermgtd" , "update_node_address" , fallback = self . DEFAULTS . get ( "update_node_address" )
)
fleet_config_file = config . get (
"clustermgtd" , "fleet_config_file" , fallback = self . DEFAULTS . get ( "fleet_config_file" )
)
self . fleet_config = read_json ( fleet_config_file )
# run_instances_overrides_file and create_fleet_overrides_file contain a json with the following format:
# {
# "queue_name": {
# "compute_resource_name": {
# <arbitrary-json-with-boto3-api-params-to-override>
# },
# ...
# },
# ...
# }
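# Illustrative example (hypothetical queue/compute resource names), e.g. adding extra tags through
# the RunInstances TagSpecifications parameter:
# {
#     "queue1": {
#         "compute-resource-1": {
#             "TagSpecifications": [{"ResourceType": "instance", "Tags": [{"Key": "Owner", "Value": "me"}]}]
#         }
#     }
# }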
run_instances_overrides_file = config . get (
"clustermgtd" , "run_instances_overrides" , fallback = self . DEFAULTS . get ( "run_instances_overrides" )
)
self . run_instances_overrides = read_json ( run_instances_overrides_file , default = {})
create_fleet_overrides_file = config . get (
"clustermgtd" , "create_fleet_overrides" , fallback = self . DEFAULTS . get ( "create_fleet_overrides" )
)
self . create_fleet_overrides = read_json ( create_fleet_overrides_file , default = {})
def _get_health_check_config ( self , config ):
self . disable_ec2_health_check = config . getboolean (
"clustermgtd" , "disable_ec2_health_check" , fallback = self . DEFAULTS . get ( "disable_ec2_health_check" )
)
self . disable_scheduled_event_health_check = config . getboolean (
"clustermgtd" ,
"disable_scheduled_event_health_check" ,
fallback = self . DEFAULTS . get ( "disable_scheduled_event_health_check" ),
)
self . health_check_timeout = config . getint (
"clustermgtd" , "health_check_timeout" , fallback = self . DEFAULTS . get ( "health_check_timeout" )
)
self . health_check_timeout_after_slurmdstarttime = config . getint (
"clustermgtd" ,
"health_check_timeout_after_slurmdstarttime" ,
fallback = self . DEFAULTS . get ( "health_check_timeout_after_slurmdstarttime" ),
)
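# If disable_all_health_checks is not set explicitly, it defaults to True only when both the EC2
# and the scheduled-event health checks are individually disabled.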
self . disable_all_health_checks = config . getboolean (
"clustermgtd" ,
"disable_all_health_checks" ,
fallback = ( self . disable_ec2_health_check and self . disable_scheduled_event_health_check ),
)
self . disable_capacity_blocks_management = config . getboolean (
"clustermgtd" ,
"disable_capacity_blocks_management" ,
fallback = self . DEFAULTS . get ( "disable_capacity_block_management" ),
)
def _get_terminate_config ( self , config ):
"""Get config option related to instance termination and node replacement."""
self . terminate_max_batch_size = config . getint (
"clustermgtd" , "terminate_max_batch_size" , fallback = self . DEFAULTS . get ( "terminate_max_batch_size" )
)
self . node_replacement_timeout = config . getint (
"clustermgtd" , "node_replacement_timeout" , fallback = self . DEFAULTS . get ( "node_replacement_timeout" )
)
self . terminate_drain_nodes = config . getboolean (
"clustermgtd" , "terminate_drain_nodes" , fallback = self . DEFAULTS . get ( "terminate_drain_nodes" )
)
self . terminate_down_nodes = config . getboolean (
"clustermgtd" , "terminate_down_nodes" , fallback = self . DEFAULTS . get ( "terminate_down_nodes" )
)
self . orphaned_instance_timeout = config . getint (
"clustermgtd" , "orphaned_instance_timeout" , fallback = self . DEFAULTS . get ( "orphaned_instance_timeout" )
)
self . protected_failure_count = config . getint (
"clustermgtd" , "protected_failure_count" , fallback = self . DEFAULTS . get ( "protected_failure_count" )
)
self . insufficient_capacity_timeout = config . getfloat (
"clustermgtd" , "insufficient_capacity_timeout" , fallback = self . DEFAULTS . get ( "insufficient_capacity_timeout" )
)
self . ec2_instance_missing_max_count = config . getint (
"clustermgtd" ,
"ec2_instance_missing_max_count" ,
fallback = self . DEFAULTS . get ( "ec2_instance_missing_max_count" ),
)
self . disable_nodes_on_insufficient_capacity = self . insufficient_capacity_timeout > 0
def _get_dns_config ( self , config ):
"""Get config option related to Route53 DNS domain."""
self . hosted_zone = config . get ( "clustermgtd" , "hosted_zone" , fallback = self . DEFAULTS . get ( "hosted_zone" ))
self . dns_domain = config . get ( "clustermgtd" , "dns_domain" , fallback = self . DEFAULTS . get ( "dns_domain" ))
self . use_private_hostname = config . getboolean (
"clustermgtd" , "use_private_hostname" , fallback = self . DEFAULTS . get ( "use_private_hostname" )
)
def _get_compute_console_output_config ( self , config ):
"""Get config options related to logging console output from compute nodes."""
self . compute_console_logging_enabled = config . getboolean (
"clustermgtd" ,
"compute_console_logging_enabled" ,
fallback = self . DEFAULTS . get ( "compute_console_logging_enabled" ),
)
self . compute_console_logging_max_sample_size = config . getint (
"clustermgtd" ,
"compute_console_logging_max_sample_size" ,
fallback = self . DEFAULTS . get ( "compute_console_logging_max_sample_size" ),
)
self . compute_console_wait_time = config . getint (
"clustermgtd" , "compute_console_wait_time" , fallback = self . DEFAULTS . get ( "compute_console_wait_time" )
)
def _get_worker_pool_config ( self , config ):
self . worker_pool_size = config . getint (
"clustermgtd" , "worker_pool_size" , fallback = self . DEFAULTS . get ( "worker_pool_size" )
)
self . worker_pool_max_backlog = config . getint (
"clustermgtd" , "worker_pool_max_backlog" , fallback = self . DEFAULTS . get ( "worker_pool_max_backlog" )
)
@ log_exception ( log , "reading cluster manager configuration file" , catch_exception = IOError , raise_on_error = True )
def _get_config ( self , config_file_path ):
"""Get clustermgtd configuration."""
log . info ( "Reading %s" , config_file_path )
self . _config = ConfigParser ()
self . _config . read_file ( open ( config_file_path , "r" ))
# Get config settings
self . _get_basic_config ( self . _config )
self . _get_health_check_config ( self . _config )
self . _get_launch_config ( self . _config )
self . _get_terminate_config ( self . _config )
self . _get_dns_config ( self . _config )
self . _get_compute_console_output_config ( self . _config )
self . _get_worker_pool_config ( self . _config )
class ClusterManager :
"""Class for all cluster management related actions."""
_config : ClustermgtdConfig
class HealthCheckTypes ( Enum ):
"""Enum for health check types."""
scheduled_event = "scheduled_events_check"
ec2_health = "ec2_health_check"
def __str__ ( self ):
return self . value
class EC2InstancesInfoUnavailable ( Exception ):
"""Exception raised when unable to retrieve cluster instance info from EC2."""
def __init__ ( self , config ):
"""
Initialize ClusterManager.
self.static_nodes_in_replacement is persistent across multiple iteration of manage_cluster
This state is required because we need to ignore static nodes that might have long bootstrap time
"""
self . _insufficient_capacity_compute_resources = {}
self . _static_nodes_in_replacement = set ()
self . _partitions_protected_failure_count_map = {}
self . _nodes_without_backing_instance_count_map = {}
self . _compute_fleet_status = ComputeFleetStatus . RUNNING
self . _current_time = None
self . _config = None
self . _compute_fleet_status_manager = None
self . _instance_manager = None
self . _task_executor = None
self . _console_logger = None
self . _event_publisher = None
self . _partition_nodelist_mapping_instance = None
self . _capacity_block_manager = None
self . set_config ( config )
def set_config ( self , config : ClustermgtdConfig ):
if self . _config != config :
log . info ( "Applying new clustermgtd config: %s" , config )
# If a new task executor is needed, the old one will be shutdown.
# This should cause any pending tasks to be cancelled (Py 3.9) and
# any executing tasks to raise an exception as long as they test
# the executor for shutdown. If tracking of cancelled or failed
# tasks is needed, then the Future object returned from task_executor
# can be queried to determine how the task exited.
# The shutdown on the task_executor is by default, non-blocking, so
# it is possible for some tasks to continue executing even after the
# shutdown request has returned.
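# Illustrative only: assuming the object returned when submitting a task follows
# concurrent.futures.Future semantics (as noted above), callers could track cancelled or failed
# tasks with e.g. future.cancelled(), future.done() or future.exception().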
self . _task_executor = self . _initialize_executor ( config )
self . _config = config
self . _event_publisher = ClusterEventPublisher . create_with_default_publisher (
event_logger , config . cluster_name , "HeadNode" , "clustermgtd" , config . head_node_instance_id
)
self . _compute_fleet_status_manager = ComputeFleetStatusManager ()
self . _instance_manager = self . _initialize_instance_manager ( config )
self . _console_logger = self . _initialize_console_logger ( config )
self . _capacity_block_manager = self . _initialize_capacity_block_manager ( config )
def shutdown ( self ):
if self . _task_executor :
self . _task_executor . shutdown ()
self . _task_executor = None
@ staticmethod
def _initialize_instance_manager ( config ):
"""Initialize instance manager class that will be used to launch/terminate/describe instances."""
return InstanceManager (
config . region ,
config . cluster_name ,
config . boto3_config ,
table_name = config . dynamodb_table ,
hosted_zone = config . hosted_zone ,
dns_domain = config . dns_domain ,
use_private_hostname = config . use_private_hostname ,
head_node_private_ip = config . head_node_private_ip ,
head_node_hostname = config . head_node_hostname ,
run_instances_overrides = config . run_instances_overrides ,
create_fleet_overrides = config . create_fleet_overrides ,
fleet_config = config . fleet_config ,
)
def _initialize_executor ( self , config ):
if (
self . _config is None
or self . _config . worker_pool_size != config . worker_pool_size
or self . _config . worker_pool_max_backlog != config . worker_pool_max_backlog
):
if self . _task_executor :
self . _task_executor . shutdown ()
return TaskExecutor (
worker_pool_size = config . worker_pool_size ,
max_backlog = config . worker_pool_max_backlog ,
)
return self . _task_executor
@ staticmethod
def _initialize_console_logger ( config ):
return ConsoleLogger (
region = config . region ,
enabled = config . compute_console_logging_enabled ,
console_output_consumer = lambda name , instance_id , output : compute_logger . info (
"Console output for node %s (Instance Id %s): \r %s" , name , instance_id , output
),
)
@ staticmethod
def _initialize_capacity_block_manager ( config ):
return CapacityBlockManager (
region = config . region , fleet_config = config . fleet_config , boto3_config = config . boto3_config
)
def _update_compute_fleet_status ( self , status ):
log . info ( "Updating compute fleet status from %s to %s" , self . _compute_fleet_status , status )
self . _compute_fleet_status_manager . update_status ( status )
self . _compute_fleet_status = status
def _handle_successfully_launched_nodes ( self , partitions_name_map ):
"""
Handle nodes have been failed in bootstrap are launched successfully during this iteration.
Includes resetting partition bootstrap failure count for these nodes type and updating the
bootstrap_failure_nodes_types.
If a node has been failed successfully launched, the partition count of this node will be reset.
So the successfully launched nodes type will be removed from bootstrap_failure_nodes_types.
If there's node types failed in this partition later, it will keep count again.
"""
# Find nodes types which have been failed during bootstrap now become online.
partitions_protected_failure_count_map = self . _partitions_protected_failure_count_map . copy ()
for partition , failures_per_compute_resource in partitions_protected_failure_count_map . items ():
partition_online_compute_resources = partitions_name_map [ partition ]. get_online_node_by_type (
self . _config . terminate_drain_nodes ,
self . _config . terminate_down_nodes ,
self . _config . ec2_instance_missing_max_count ,
self . _nodes_without_backing_instance_count_map ,
)
for compute_resource in failures_per_compute_resource . keys ():
if compute_resource in partition_online_compute_resources :
self . _reset_partition_failure_count ( partition )
break
def manage_cluster ( self ):
"""Manage cluster by syncing scheduler states with EC2 states and performing node maintenance actions."""
# Initialization
log . info ( "Managing cluster..." )
self . _current_time = datetime . now ( tz = timezone . utc )
self . _compute_fleet_status = self . _compute_fleet_status_manager . get_status ( fallback = self . _compute_fleet_status )
log . info ( "Current compute fleet status: %s" , self . _compute_fleet_status )
if not self . _config . disable_all_cluster_management :
if self . _compute_fleet_status in {
None ,
ComputeFleetStatus . RUNNING ,
ComputeFleetStatus . PROTECTED ,
}:
# Get partition_nodelist_mapping between PC-managed Slurm partitions and PC-managed Slurm nodelists
# Initialize PartitionNodelistMapping singleton
self . _partition_nodelist_mapping_instance = PartitionNodelistMapping . instance ()
# Get node states for nodes in inactive and active partitions
# Initialize nodes
try :
log . info ( "Retrieving nodes info from the scheduler" )
nodes = self . _get_node_info_with_retry ()
log . debug ( "Nodes: %s" , nodes )
partitions_name_map , compute_resource_nodes_map = self . _parse_scheduler_nodes_data ( nodes )
except Exception as e :
log . error (
"Unable to get partition/node info from slurm, no other action can be performed. Sleeping... "
"Exception: %s" ,
e ,
)
return
# Get all non-terminating instances in EC2
try :
cluster_instances = self . _get_ec2_instances ()
except ClusterManager . EC2InstancesInfoUnavailable :
log . error ( "Unable to get instances info from EC2, no other action can be performed. Sleeping..." )
return
log . debug ( "Current cluster instances in EC2: %s" , cluster_instances )
partitions = list ( partitions_name_map . values ())
self . _update_slurm_nodes_with_ec2_info ( nodes , cluster_instances )
self . _event_publisher . publish_compute_node_events ( nodes , cluster_instances )
# Handle inactive partition and terminate backing instances
self . _clean_up_inactive_partition ( partitions )
# Perform health check actions
if not self . _config . disable_all_health_checks :
self . _perform_health_check_actions ( partitions )
# Maintain slurm nodes
self . _maintain_nodes ( partitions_name_map , compute_resource_nodes_map )
# Clean up orphaned instances
self . _terminate_orphaned_instances ( cluster_instances )
elif self . _compute_fleet_status in {
ComputeFleetStatus . STOPPED ,
}:
# Since the Slurm partition status might have been manually modified, when STOPPED we keep checking
# partitions and EC2 instances to take into account changes that the user can apply manually
# by re-activating Slurm partitions.
# When partitions are INACTIVE, always try to reset nodeaddr/nodehostname to avoid issues.
self . _maintain_nodes_down ()
# Write clustermgtd heartbeat to file
self . _write_timestamp_to_file ()
def _write_timestamp_to_file ( self ):
"""Write timestamp into shared file so compute nodes can determine if head node is online."""
# Make clustermgtd heartbeat readable to all users
with open ( os . open ( self . _config . heartbeat_file_path , os . O_WRONLY | os . O_CREAT , 0o644 ), "w" ) as timestamp_file :
# Note: heartbeat must be written with datetime.strftime to convert localized datetime into str
# datetime.strptime will not work with str(datetime)
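# Illustrative round trip (assuming TIMESTAMP_FORMAT carries the timezone offset):
#   write: datetime.now(tz=timezone.utc).strftime(TIMESTAMP_FORMAT)
#   read:  datetime.strptime(heartbeat_str, TIMESTAMP_FORMAT)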
timestamp_file . write ( datetime . now ( tz = timezone . utc ). strftime ( TIMESTAMP_FORMAT ))
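# The retrying decorators below retry the Slurm queries once more on failure (2 attempts total)
# with a fixed 1-second wait between attempts.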
@ staticmethod
@ retry ( stop_max_attempt_number = 2 , wait_fixed = 1000 )
def _get_node_info_with_retry ( nodes = "" ):
return get_nodes_info ( nodes )
@ staticmethod
@ retry ( stop_max_attempt_number = 2 , wait_fixed = 1000 )
def _get_partitions_info_with_retry () -> Dict [ str , SlurmPartition ]:
return { part . name : part for part in get_partitions_info ()}
def _clean_up_inactive_partition ( self , partitions ):
"""Terminate all other instances associated with nodes in INACTIVE partition directly through EC2."""
inactive_instance_ids , inactive_nodes = ClusterManager . _get_inactive_instances_and_nodes ( partitions )
if inactive_nodes :
try :
log . info ( "Cleaning up INACTIVE partitions." )
if inactive_instance_ids :
log . info (
"Clean up instances associated with nodes in INACTIVE partitions: %s" ,
print_with_count ( inactive_instance_ids ),
)
self . _instance_manager . delete_instances (
inactive_instance_ids , terminate_batch_size = self . _config . terminate_max_batch_size
)
self . _reset_nodes_in_inactive_partitions ( list ( inactive_nodes ))
except Exception as e :
log . error ( "Failed to clean up INACTIVE nodes %s with exception %s" , print_with_count ( inactive_nodes ), e )
@ staticmethod
def _reset_nodes_in_inactive_partitions ( inactive_nodes ):
# Try to reset nodeaddr if possible to avoid potential problems
nodes_to_reset = set ()
for node in inactive_nodes :
if node . needs_reset_when_inactive ():
nodes_to_reset . add ( node . name )
if nodes_to_reset :
# Set nodes to down rather than power_down, because power_down does not seem to be applied while the partition is inactive
log . info (
"Resetting nodeaddr/nodehostname and setting to down the following nodes: %s" ,
print_with_count ( nodes_to_reset ),
)
try :
reset_nodes (
nodes_to_reset ,
raise_on_error = False ,
state = "down" ,
reason = "inactive partition" ,
)
except Exception as e :
log . error (
"Encountered exception when resetting nodeaddr for INACTIVE nodes %s: %s" ,
print_with_count ( nodes_to_reset ),
e ,
)
def _get_ec2_instances ( self ):
"""
Get EC2 instance by describe_instances API.
Call is made by filtering on tags and includes non-terminating instances only
Instances returned will not contain instances previously terminated in _clean_up_inactive_partition
"""
# After reading Slurm nodes wait for 5 seconds to let instances appear in EC2 describe_instances call
time . sleep ( 5 )
log . info ( "Retrieving list of EC2 instances associated with the cluster" )
try :
return self . _instance_manager . get_cluster_instances ( include_head_node = False , alive_states_only = True )
except Exception as e :
log . error ( "Failed when getting instance info from EC2 with exception %s" , e )
raise ClusterManager . EC2InstancesInfoUnavailable
@ log_exception ( log , "maintaining slurm nodes down" , catch_exception = Exception , raise_on_error = False )
def _maintain_nodes_down ( self ):
update_all_partitions ( PartitionStatus . INACTIVE , reset_node_addrs_hostname = True )
self . _instance_manager . terminate_all_compute_nodes ( self . _config . terminate_max_batch_size )
@ log_exception ( log , "performing health check action" , catch_exception = Exception , raise_on_error = False )
def _perform_health_check_actions ( self , partitions : List [ SlurmPartition ]):
"""Run health check actions."""
log . info ( "Performing instance health check actions" )
instance_id_to_active_node_map = ClusterManager . get_instance_id_to_active_node_map ( partitions )
if not instance_id_to_active_node_map :
return
# Get health states for instances that might be considered unhealthy
unhealthy_instances_status = self . _instance_manager . get_unhealthy_cluster_instance_status (
list ( instance_id_to_active_node_map . keys ())
)
log . debug ( "Cluster instances that might be considered unhealthy: %s" , unhealthy_instances_status )
if unhealthy_instances_status :
# Perform EC2 health check actions
if not self . _config . disable_ec2_health_check :
self . _handle_health_check (
unhealthy_instances_status ,
instance_id_to_active_node_map ,
health_check_type = ClusterManager . HealthCheckTypes . ec2_health ,
)
# Perform scheduled event actions
if not self . _config . disable_scheduled_event_health_check :
self . _handle_health_check (
unhealthy_instances_status ,
instance_id_to_active_node_map ,
health_check_type = ClusterManager . HealthCheckTypes . scheduled_event ,
)
def _get_nodes_failing_health_check (
self , unhealthy_instances_status , instance_id_to_active_node_map , health_check_type
):
"""Get nodes fail health check."""
log . info ( "Performing actions for health check type: %s" , health_check_type )
nodes_failing_health_check = []
for instance_status in unhealthy_instances_status :
unhealthy_node = instance_id_to_active_node_map . get ( instance_status . id )
if unhealthy_node and self . _is_instance_unhealthy ( instance_status , health_check_type ):
nodes_failing_health_check . append ( unhealthy_node )
unhealthy_node . is_failing_health_check = True
log . warning (
"Node %s(%s) is associated with instance %s that is failing %s. EC2 health state: %s" ,
unhealthy_node . name ,
unhealthy_node . nodeaddr ,
instance_status . id ,
health_check_type ,
[
instance_status . id ,
instance_status . state ,
instance_status . instance_status ,
instance_status . system_status ,
instance_status . scheduled_events ,
],
)
return nodes_failing_health_check
def _is_instance_unhealthy ( self , instance_status : EC2InstanceHealthState , health_check_type ):
"""Check if instance status is unhealthy based on picked corresponding health check function."""
is_instance_status_unhealthy = False
if health_check_type == ClusterManager . HealthCheckTypes . scheduled_event :
is_instance_status_unhealthy = instance_status . fail_scheduled_events_check ()
elif health_check_type == ClusterManager . HealthCheckTypes . ec2_health :
is_instance_status_unhealthy = instance_status . fail_ec2_health_check (
self . _current_time , self . _config . health_check_timeout
)
return is_instance_status_unhealthy
def _update_static_nodes_in_replacement ( self , slurm_nodes : List [ SlurmNode ]):
"""Remove from self.static_nodes_in_replacement nodes that are up or that are in maintenance."""
nodename_to_slurm_nodes_map = { node . name : node for node in slurm_nodes }
nodes_still_in_replacement = set ()
for nodename in self . _static_nodes_in_replacement :
node = nodename_to_slurm_nodes_map . get ( nodename )
# Consider nodename still in replacement if node is not up and not in maintenance
if node and not node . is_up () and not node . is_in_maintenance ():
nodes_still_in_replacement . add ( nodename )
# override self._static_nodes_in_replacement with updated list
self . _static_nodes_in_replacement = nodes_still_in_replacement
# update node attributes
for node in slurm_nodes :
node . is_static_nodes_in_replacement = node . name in self . _static_nodes_in_replacement
node . is_being_replaced = self . _is_node_being_replaced ( node )
node . _is_replacement_timeout = self . _is_node_replacement_timeout ( node )
def _find_unhealthy_slurm_nodes ( self , slurm_nodes ):
"""
Find unhealthy static slurm nodes and dynamic slurm nodes.
Check and return slurm nodes with unhealthy and healthy scheduler state, grouping unhealthy nodes
by node type (static/dynamic).
"""
unhealthy_static_nodes = []
unhealthy_dynamic_nodes = []
ice_compute_resources_and_nodes_map = {}
all_unhealthy_nodes = []
# Remove the nodes part of inactive Capacity Blocks from the list of unhealthy nodes.
# Nodes from active Capacity Blocks will be instead managed as unhealthy instances.
reserved_nodenames = []
if not self . _config . disable_capacity_blocks_management :
reserved_nodenames = self . _capacity_block_manager . get_reserved_nodenames ( slurm_nodes )
if reserved_nodenames :
log . info (
"The nodes associated with inactive Capacity Blocks and not considered as unhealthy nodes are: %s" ,
"," . join ( reserved_nodenames ),
)
else :
log . debug ( "No nodes found associated with inactive Capacity Blocks." )
for node in slurm_nodes :
if not node . is_healthy (
consider_drain_as_unhealthy = self . _config . terminate_drain_nodes ,
consider_down_as_unhealthy = self . _config . terminate_down_nodes ,
ec2_instance_missing_max_count = self . _config . ec2_instance_missing_max_count ,
nodes_without_backing_instance_count_map = self . _nodes_without_backing_instance_count_map ,
log_warn_if_unhealthy = node . name not in reserved_nodenames ,
):
if not self . _config . disable_capacity_blocks_management and node . name in reserved_nodenames :
# do not consider as unhealthy the nodes reserved for capacity blocks
continue
all_unhealthy_nodes . append ( node )
if isinstance ( node , StaticNode ):
unhealthy_static_nodes . append ( node )
elif self . _config . disable_nodes_on_insufficient_capacity and node . is_ice ():
ice_compute_resources_and_nodes_map . setdefault ( node . queue_name , {}). setdefault (
node . compute_resource_name , []
). append ( node )
else :
unhealthy_dynamic_nodes . append ( node )
self . _event_publisher . publish_unhealthy_node_events (
all_unhealthy_nodes ,
self . _config . ec2_instance_missing_max_count ,
self . _nodes_without_backing_instance_count_map ,
)
return (
unhealthy_dynamic_nodes ,
unhealthy_static_nodes ,
ice_compute_resources_and_nodes_map ,
)
def _increase_partitions_protected_failure_count ( self , bootstrap_failure_nodes ):
"""Keep count of boostrap failures."""
for node in bootstrap_failure_nodes :
compute_resource = node . compute_resource_name
queue_name = node . queue_name
if queue_name in self . _partitions_protected_failure_count_map :
self . _partitions_protected_failure_count_map [ queue_name ][ compute_resource ] = (
self . _partitions_protected_failure_count_map [ queue_name ]. get ( compute_resource , 0 ) + 1
)
else :
self . _partitions_protected_failure_count_map [ queue_name ] = {}
self . _partitions_protected_failure_count_map [ queue_name ][ compute_resource ] = 1
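# Illustrative shape of the resulting map (hypothetical names):
# {"queue1": {"compute-resource-1": 2, "compute-resource-2": 1}}
# When the count for a compute resource reaches protected_failure_count, its partition becomes a
# candidate for protected mode.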
@ log_exception ( log , "maintaining unhealthy dynamic nodes" , raise_on_error = False )
def _handle_unhealthy_dynamic_nodes ( self , unhealthy_dynamic_nodes ):
"""
Maintain any unhealthy dynamic node.
Terminate instances backing dynamic nodes.
Setting node to down will let slurm requeue jobs allocated to node.
Setting node to power_down will terminate backing instance and reset dynamic node for future use.
"""
instances_to_terminate = [ node . instance . id for node in unhealthy_dynamic_nodes if node . instance ]
if instances_to_terminate :
log . info ( "Terminating instances that are backing unhealthy dynamic nodes" )
self . _instance_manager . delete_instances (
instances_to_terminate , terminate_batch_size = self . _config . terminate_max_batch_size
)
log . info ( "Setting unhealthy dynamic nodes to down and power_down." )
set_nodes_power_down ([ node . name for node in unhealthy_dynamic_nodes ], reason = "Scheduler health check failed" )
@ log_exception ( log , "maintaining powering down nodes" , raise_on_error = False )
def _handle_powering_down_nodes ( self , slurm_nodes ):
"""
Handle nodes that are powering down and not already being replaced.
Terminate instances backing the powering down node if any.
Reset the nodeaddr for the powering down node. Node state is not changed.
"""
powering_down_nodes = [
node for node in slurm_nodes if node . is_powering_down_with_nodeaddr () and not node . is_being_replaced
]
if powering_down_nodes :
log . info ( "Resetting powering down nodes: %s" , print_with_count ( powering_down_nodes ))
reset_nodes ( nodes = [ node . name for node in powering_down_nodes ])
instances_to_terminate = [ node . instance . id for node in powering_down_nodes if node . instance ]
log . info ( "Terminating instances that are backing powering down nodes" )
self . _instance_manager . delete_instances (
instances_to_terminate , terminate_batch_size = self . _config . terminate_max_batch_size
)
@ log_exception ( log , "maintaining unhealthy static nodes" , raise_on_error = False )
def _handle_unhealthy_static_nodes ( self , unhealthy_static_nodes ):
"""
Maintain any unhealthy static node.
Set node to down, terminate backing instance, and launch new instance for static node.
"""
try :
wait_function = partial (
wait_remaining_time ,
wait_start_time = datetime . now ( tz = timezone . utc ),
total_wait_time = self . _config . compute_console_wait_time ,
wait_function = self . _task_executor . wait_unless_shutdown ,
)
self . _console_logger . report_console_output_from_nodes (
compute_instances = self . _instance_manager . get_compute_node_instances (
unhealthy_static_nodes ,
self . _config . compute_console_logging_max_sample_size ,
),
task_controller = self . _task_executor ,
task_wait_function = wait_function ,
)
except Exception as e :
log . error ( "Encountered exception when retrieving console output from unhealthy static nodes: %s" , e )
node_list = [ node . name for node in unhealthy_static_nodes ]
# Set nodes into down state so jobs can be requeued immediately
try :
log . info ( "Setting unhealthy static nodes to DOWN" )
reset_nodes ( node_list , state = "down" , reason = "Static node maintenance: unhealthy node is being replaced" )
except Exception as e :
log . error ( "Encountered exception when setting unhealthy static nodes into down state: %s" , e )
instances_to_terminate = [ node . instance . id for node in unhealthy_static_nodes if node . instance ]
if instances_to_terminate :
log . info ( "Terminating instances backing unhealthy static nodes" )
self . _instance_manager . delete_instances (
instances_to_terminate , terminate_batch_size = self . _config . terminate_max_batch_size
)
log . info ( "Launching new instances for unhealthy static nodes" )
self . _instance_manager . add_instances (
node_list = node_list ,
launch_batch_size = self . _config . launch_max_batch_size ,
assign_node_batch_size = self . _config . assign_node_max_batch_size ,
update_node_address = self . _config . update_node_address ,
scaling_strategy = ScalingStrategy . BEST_EFFORT ,
)
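# ScalingStrategy.BEST_EFFORT is expected to launch as many of the requested instances as possible
# rather than failing the whole batch; nodes that could not be launched are reported in
# self._instance_manager.failed_nodes and handled below.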
# Add launched nodes to list of nodes being replaced, excluding any nodes that failed to launch
failed_nodes = set (). union ( * self . _instance_manager . failed_nodes . values ())
launched_nodes = set ( node_list ) - failed_nodes
self . _static_nodes_in_replacement |= launched_nodes
log . info (
"After node maintenance, following nodes are currently in replacement: %s" ,
print_with_count ( self . _static_nodes_in_replacement ),
)
self . _event_publisher . publish_unhealthy_static_node_events (
unhealthy_static_nodes ,
self . _static_nodes_in_replacement ,
launched_nodes ,
self . _instance_manager . failed_nodes ,
)
@ log_exception ( log , "maintaining slurm nodes" , catch_exception = Exception , raise_on_error = False )
def _maintain_nodes ( self , partitions_name_map , compute_resource_nodes_map ):
"""
Call functions to maintain unhealthy nodes.
This function needs to handle the case where two Slurm nodes have the same IP/nodeaddr.
A list of Slurm nodes is passed in, and building a node map keyed by IP/nodeaddr should be avoided.
"""
log . info ( "Performing node maintenance actions" )
# Retrieve nodes from Slurm partitions in ACTIVE state
active_nodes = self . _find_active_nodes ( partitions_name_map )
# Update self.static_nodes_in_replacement by removing from the set any node that is up or in maintenance
self . _update_static_nodes_in_replacement ( active_nodes )
log . info (
"Following nodes are currently in replacement: %s" , print_with_count ( self . _static_nodes_in_replacement )
)
# terminate powering down instances
self . _handle_powering_down_nodes ( active_nodes )
# retrieve and manage unhealthy nodes
(
unhealthy_dynamic_nodes ,
unhealthy_static_nodes ,
ice_compute_resources_and_nodes_map ,
) = self . _find_unhealthy_slurm_nodes ( active_nodes )
if unhealthy_dynamic_nodes :
log . info ( "Found the following unhealthy dynamic nodes: %s" , print_with_count ( unhealthy_dynamic_nodes ))
self . _handle_unhealthy_dynamic_nodes ( unhealthy_dynamic_nodes )
if unhealthy_static_nodes :
log . info ( "Found the following unhealthy static nodes: %s" , print_with_count ( unhealthy_static_nodes ))
self . _handle_unhealthy_static_nodes ( unhealthy_static_nodes )
# Evaluate partitions to put in protected mode and ICE nodes to terminate
if self . _is_protected_mode_enabled ():
self . _handle_protected_mode_process ( active_nodes , partitions_name_map )
if self . _config . disable_nodes_on_insufficient_capacity :
self . _handle_ice_nodes ( ice_compute_resources_and_nodes_map , compute_resource_nodes_map )
self . _handle_failed_health_check_nodes_in_replacement ( active_nodes )
@ log_exception ( log , "terminating orphaned instances" , catch_exception = Exception , raise_on_error = False )
def _terminate_orphaned_instances ( self , cluster_instances ):
"""Terminate instance not associated with any node and running longer than orphaned_instance_timeout."""
log . info ( "Checking for orphaned instance" )
instances_to_terminate = []
for instance in cluster_instances :
if not instance . slurm_node and time_is_up (
instance . launch_time , self . _current_time , self . _config . orphaned_instance_timeout
):
instances_to_terminate . append ( instance . id )
if instances_to_terminate :
log . info ( "Terminating orphaned instances" )
self . _instance_manager . delete_instances (
instances_to_terminate , terminate_batch_size = self . _config . terminate_max_batch_size
)
def _enter_protected_mode ( self , partitions_to_disable ):
"""Entering protected mode if no active running job in queue."""
# Place partitions into inactive
log . info ( "Placing bootstrap failure partitions to INACTIVE: %s" , partitions_to_disable )
update_partitions ( partitions_to_disable , PartitionStatus . INACTIVE )
# Change compute fleet status to protected
if not ComputeFleetStatus . is_protected ( self . _compute_fleet_status ):
log . warning (
"Setting cluster into protected mode due to failures detected in node provisioning. "
"Please investigate the issue and then use 'pcluster update-compute-fleet --status START_REQUESTED' "
"command to re-enable the fleet."
)
self . _update_compute_fleet_status ( ComputeFleetStatus . PROTECTED )
def _reset_partition_failure_count ( self , partition ):
"""Reset bootstrap failure count for partition which has bootstrap failure nodes successfully launched."""
log . info ( "Find successfully launched node in partition %s, reset partition protected failure count" , partition )
self . _partitions_protected_failure_count_map . pop ( partition , None )
def _handle_bootstrap_failure_nodes ( self , active_nodes ):
"""
Find bootstrap failure nodes and increase partition failure count.
There are two kinds of bootstrap failure nodes:
Nodes fail in bootstrap during health check,